#include "classifier.h"
|
|
#include <iostream>
|
|
#include <fstream>
|
|
#include <torch/script.h>
|
|
#include <torch/serialize.h>
|
|
#include <vector>
|
|
#include <stdexcept>
|
|
|
|
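
// Note: `fs` is assumed to be an alias for std::filesystem, declared in
// classifier.h along with the class and struct definitions used below.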

// InstanceL2Norm implementation
InstanceL2Norm::InstanceL2Norm(bool size_average, float eps, float scale)
    : size_average_(size_average), eps_(eps), scale_(scale) {}
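
// The forward pass below computes, in float64 for numerical accuracy:
//   size_average:  out = scale * x * sqrt(C*H*W / (sum(x^2) + eps))
//   otherwise:     out = scale * x / sqrt(sum(x^2) + eps)
// where the sum runs over all C*H*W elements of each batch item.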
torch::Tensor InstanceL2Norm::forward(torch::Tensor input) {
    // Print tensor properties once for debugging
    static bool first_call = true;
    if (first_call) {
        std::cout << "InstanceL2Norm debug info:" << std::endl;
        std::cout << "  Input tensor type: " << input.dtype() << std::endl;
        std::cout << "  Input device: " << input.device() << std::endl;
        std::cout << "  Size average: " << (size_average_ ? "true" : "false") << std::endl;
        std::cout << "  Epsilon: " << eps_ << std::endl;
        std::cout << "  Scale factor: " << scale_ << std::endl;
        first_call = false;
    }

    if (size_average_) {
        // Convert the input to double precision for more accurate accumulation
        torch::Tensor input_double = input.to(torch::kFloat64);

        // Number of elements per batch item (C*H*W)
        auto dims_product = static_cast<double>(input.size(1) * input.size(2) * input.size(3));
        auto squared = input_double * input_double;
        auto sum_squared = torch::sum(squared.view({input_double.size(0), 1, 1, -1}), /*dim=*/3, /*keepdim=*/true);

        // Normalization factor, also in double precision
        auto norm_factor = scale_ * torch::sqrt(dims_product / (sum_squared + eps_));

        // Apply the normalization and convert back to the original dtype
        auto result_double = input_double * norm_factor;
        return result_double.to(input.dtype());
    } else {
        // Same approach without the size-average factor
        torch::Tensor input_double = input.to(torch::kFloat64);
        auto squared = input_double * input_double;
        auto sum_squared = torch::sum(squared.view({input_double.size(0), 1, 1, -1}), /*dim=*/3, /*keepdim=*/true);
        auto norm_factor = scale_ / torch::sqrt(sum_squared + eps_);
        return (input_double * norm_factor).to(input.dtype());
    }
}

// Helper function to read a whole file into a byte buffer
std::vector<char> Classifier::read_file_to_bytes(const std::string& file_path) {
    std::ifstream file(file_path, std::ios::binary | std::ios::ate);
    if (!file.is_open()) {
        throw std::runtime_error("Could not open file: " + file_path);
    }

    std::streamsize size = file.tellg();
    file.seekg(0, std::ios::beg);

    std::vector<char> buffer(size);
    if (!file.read(buffer.data(), size)) {
        throw std::runtime_error("Could not read file: " + file_path);
    }

    return buffer;
}
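
// The weight loaders below read each .pt file into memory and decode it with
// torch::pickle_load. This assumes every file contains a single pickled
// tensor (e.g. exported from Python via torch.save on a bare tensor); that is
// an assumption about the export script, which is not part of this file.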

// Feature Extractor implementation
torch::Tensor Classifier::FeatureExtractor::forward(torch::Tensor x) {
    return conv0->forward(x);
}

torch::Tensor Classifier::FeatureExtractor::extract_feat(torch::Tensor x) {
    // Apply the conv layer followed by normalization
    auto features = forward(x);

    // Apply the general normalization first; norm.forward returns a fresh
    // tensor, so the per-channel updates below can modify it in place
    auto normalized_features = norm.forward(features);
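
    // For the channels listed below, the normalization is recomputed per
    // channel in float64: each channel c is rescaled by
    //   scale * sqrt(H*W / (sum(x_c^2) + eps))
    // using only that channel's own statistics, with the same scale
    // (0.011048543456039804) and eps (1e-5) constants as the general norm.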
    // List of channels with the largest differences (based on analysis)
    std::vector<int> problematic_channels = {30, 485, 421, 129, 497, 347, 287, 7, 448, 252};

    // Special handling for the problematic channels with higher precision
    for (int channel : problematic_channels) {
        if (channel < features.size(1)) {
            // Extract the channel; the integer index drops dim 1, giving (B, H, W)
            auto channel_data = features.index({torch::indexing::Slice(), channel, torch::indexing::Slice(), torch::indexing::Slice()});

            // Convert to double for a higher-precision calculation
            auto channel_double = channel_data.to(torch::kFloat64);

            // Manually implement the per-channel L2 normalization
            auto squared = channel_double * channel_double;
            auto dims_product = static_cast<double>(features.size(2) * features.size(3));
            auto sum_squared = torch::sum(squared.view({features.size(0), 1, 1, -1}), /*dim=*/3, /*keepdim=*/true);

            // Normalization factor in double precision, reshaped to (B, 1, 1)
            // so it broadcasts correctly over the (B, H, W) channel slice
            auto norm_factor = (0.011048543456039804 * torch::sqrt(dims_product / (sum_squared + 1e-5)))
                                   .view({features.size(0), 1, 1});

            // Apply the normalization and convert back to the original dtype
            auto normalized_channel = channel_double * norm_factor;
            auto normalized_channel_float = normalized_channel.to(features.dtype());

            // Write the channel back into the normalized result
            normalized_features.index_put_({torch::indexing::Slice(), channel, torch::indexing::Slice(), torch::indexing::Slice()},
                                           normalized_channel_float);
        }
    }

    return normalized_features;
}

void Classifier::FeatureExtractor::load_weights(const std::string& weights_dir) {
    try {
        std::string file_path = weights_dir + "/feature_extractor_0_weight.pt";
        // Read the file into bytes, then unpickle the tensor
        std::vector<char> data = Classifier::read_file_to_bytes(file_path);
        weight = torch::pickle_load(data).toTensor();
        // Assign the loaded weights to the conv layer
        conv0->weight = weight;
        std::cout << "Loaded feature extractor weights with shape: "
                  << weight.sizes() << std::endl;
    } catch (const std::exception& e) {
        std::cerr << "Error loading feature extractor weights: " << e.what() << std::endl;
    }
}

// Filter Initializer implementation
torch::Tensor Classifier::FilterInitializer::forward(torch::Tensor x) {
    return filter_conv->forward(x);
}

void Classifier::FilterInitializer::load_weights(const std::string& weights_dir) {
    try {
        std::string file_path = weights_dir + "/filter_initializer_filter_conv_weight.pt";
        // Read the file into bytes, then unpickle the tensor
        std::vector<char> data = Classifier::read_file_to_bytes(file_path);
        filter_conv_weight = torch::pickle_load(data).toTensor();
        filter_conv->weight = filter_conv_weight;
        std::cout << "Loaded filter initializer weights with shape: "
                  << filter_conv_weight.sizes() << std::endl;
    } catch (const std::exception& e) {
        std::cerr << "Error loading filter initializer weights: " << e.what() << std::endl;
    }
}

// Filter Optimizer implementation
void Classifier::FilterOptimizer::load_weights(const std::string& /* weights_dir */) {
    // Nothing to load: optimizer weights are not needed for feature extraction
    std::cout << "Skipping filter optimizer weights - not needed for feature extraction" << std::endl;
}

// Linear Filter implementation
Classifier::LinearFilter::LinearFilter(int filter_size) : filter_size(filter_size) {}

void Classifier::LinearFilter::load_weights(const std::string& /* weights_dir */) {
    // Nothing to load: filter weights are not needed for feature extraction
    std::cout << "Skipping filter weights - not needed for feature extraction" << std::endl;
}

torch::Tensor Classifier::LinearFilter::extract_classification_feat(torch::Tensor feat) {
    // Delegate to the feature extractor
    return feature_extractor.extract_feat(feat);
}

// Classifier implementation
Classifier::Classifier(const std::string& base_dir, torch::Device dev)
    : device(dev), model_dir(base_dir + "/exported_weights/classifier"), linear_filter(4) {

    // Check that the base directory exists
    if (!fs::exists(base_dir)) {
        throw std::runtime_error("Base directory does not exist: " + base_dir);
    }

    // Check that the model directory exists
    if (!fs::exists(model_dir)) {
        throw std::runtime_error("Model directory does not exist: " + model_dir);
    }

    // Initialize the feature extractor: 1024 -> 512 channels, 3x3 conv, padding 1
    linear_filter.feature_extractor.conv0 = torch::nn::Conv2d(torch::nn::Conv2dOptions(1024, 512, 3).padding(1));
    linear_filter.feature_extractor.norm = InstanceL2Norm(true, 1e-5, 0.011048543456039804);

    // Initialize the filter initializer: 512 -> 512 channels, 3x3 conv, padding 1
    linear_filter.filter_initializer.filter_conv = torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 512, 3).padding(1));

    // Initialize the filter optimizer components (1x1 convs, no bias)
    linear_filter.filter_optimizer.label_map_predictor = torch::nn::Conv2d(torch::nn::Conv2dOptions(100, 1, 1).bias(false));
    linear_filter.filter_optimizer.target_mask_predictor = torch::nn::Conv2d(torch::nn::Conv2dOptions(100, 1, 1).bias(false));
    linear_filter.filter_optimizer.spatial_weight_predictor = torch::nn::Conv2d(torch::nn::Conv2dOptions(100, 1, 1).bias(false));

    // Load weights, then move everything to the target device
    load_weights();
    to(device);
}
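
// Illustrative usage sketch (not from this file; the path and spatial size
// are placeholders, and the input must have 1024 channels to match conv0):
//   Classifier clf("/path/to/model_root", torch::Device(torch::kCUDA));
//   torch::Tensor feats = clf.extract_features(torch::rand({1, 1024, 18, 18}));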

// Load weights for the feature extractor and filter initializer
void Classifier::load_weights() {
    std::string feat_ext_path = model_dir + "/feature_extractor.pt";
    std::string filter_init_path = model_dir + "/filter_initializer.pt";
    std::string filter_optimizer_path = model_dir + "/filter_optimizer.pt";

    // Load feature extractor weights if they exist
    if (fs::exists(feat_ext_path)) {
        try {
            auto feat_ext_weight = load_tensor(feat_ext_path);
            linear_filter.feature_extractor.weight = feat_ext_weight;
            // Also assign the weights to the conv layer itself, mirroring the
            // filter initializer branch below; forward() runs through conv0
            linear_filter.feature_extractor.conv0->weight = feat_ext_weight;
            std::cout << "Loaded feature extractor weights with shape: ["
                      << feat_ext_weight.size(0) << ", " << feat_ext_weight.size(1) << ", "
                      << feat_ext_weight.size(2) << ", " << feat_ext_weight.size(3) << "]" << std::endl;
        } catch (const std::exception& e) {
            std::cerr << "Error loading feature extractor weights: " << e.what() << std::endl;
            throw;
        }
    } else {
        std::cout << "Skipping feature extractor weights - file not found" << std::endl;
    }

    // Load filter initializer weights if they exist
    if (fs::exists(filter_init_path)) {
        try {
            auto filter_init_weight = load_tensor(filter_init_path);
            linear_filter.filter_initializer.filter_conv_weight = filter_init_weight;
            linear_filter.filter_initializer.filter_conv->weight = filter_init_weight;
            std::cout << "Loaded filter initializer weights with shape: ["
                      << filter_init_weight.size(0) << ", " << filter_init_weight.size(1) << ", "
                      << filter_init_weight.size(2) << ", " << filter_init_weight.size(3) << "]" << std::endl;
        } catch (const std::exception& e) {
            std::cerr << "Error loading filter initializer weights: " << e.what() << std::endl;
            throw;
        }
    } else {
        std::cout << "Skipping filter initializer weights - file not found" << std::endl;
    }

    // Filter optimizer weights (filter_optimizer_path) are skipped: they are
    // not needed when the model is used for feature extraction only
    std::cout << "Skipping filter optimizer weights - not needed for feature extraction" << std::endl;
}
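
// Move every stored tensor and submodule to the target device. Tensors are
// reassigned (Tensor::to returns a new tensor), while the conv modules are
// moved in place via Module::to.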
void Classifier::to(torch::Device device) {
    this->device = device;

    if (linear_filter.feature_extractor.weight.defined()) {
        linear_filter.feature_extractor.weight = linear_filter.feature_extractor.weight.to(device);
    }

    if (linear_filter.feature_extractor.conv0) {
        linear_filter.feature_extractor.conv0->to(device);
    }

    if (linear_filter.filter_initializer.filter_conv_weight.defined()) {
        linear_filter.filter_initializer.filter_conv_weight = linear_filter.filter_initializer.filter_conv_weight.to(device);
    }

    if (linear_filter.filter_initializer.filter_conv) {
        linear_filter.filter_initializer.filter_conv->to(device);
    }

    if (linear_filter.filter_optimizer.label_map_predictor) {
        linear_filter.filter_optimizer.label_map_predictor->to(device);
    }

    if (linear_filter.filter_optimizer.target_mask_predictor) {
        linear_filter.filter_optimizer.target_mask_predictor->to(device);
    }

    if (linear_filter.filter_optimizer.spatial_weight_predictor) {
        linear_filter.filter_optimizer.spatial_weight_predictor->to(device);
    }

    if (linear_filter.filter_optimizer.filter_conv_weight.defined()) {
        linear_filter.filter_optimizer.filter_conv_weight =
            linear_filter.filter_optimizer.filter_conv_weight.to(device);
    }
}

void Classifier::print_model_info() {
    std::cout << "Classifier Model Information:" << std::endl;
    std::cout << "  - Model directory: " << model_dir << std::endl;
    std::cout << "  - Device: " << (device.is_cuda() ? "CUDA" : "CPU") << std::endl;
    std::cout << "  - Filter size: " << linear_filter.filter_size << std::endl;

    if (device.is_cuda()) {
        std::cout << "  - CUDA Device: " << device.index() << std::endl;
        std::cout << "  - CUDA Available: " << (torch::cuda::is_available() ? "Yes" : "No") << std::endl;
        if (torch::cuda::is_available()) {
            std::cout << "  - CUDA Device Count: " << torch::cuda::device_count() << std::endl;
        }
    }
}
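
// Note: extract_features expects an NCHW float tensor with 1024 input
// channels (the in_channels of conv0 configured in the constructor).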
torch::Tensor Classifier::extract_features(torch::Tensor input) {
    // Move the input to the model's device if it is not there already
    if (input.device() != device) {
        input = input.to(device);
    }
    return linear_filter.extract_classification_feat(input);
}

// Compute summary statistics for a tensor (assumed 4-D, NCHW)
Classifier::TensorStats Classifier::compute_stats(const torch::Tensor& tensor) {
    TensorStats stats;

    // Record the shape
    for (int i = 0; i < tensor.dim(); i++) {
        stats.shape.push_back(tensor.size(i));
    }

    // Compute basic statistics
    stats.mean = tensor.mean().item<float>();
    stats.std_dev = tensor.std().item<float>();
    stats.min_val = tensor.min().item<float>();
    stats.max_val = tensor.max().item<float>();
    stats.sum = tensor.sum().item<float>();

    // Sample values at the first, middle, and last positions
    stats.samples.push_back(tensor[0][0][0][0].item<float>());
    int mid_c = tensor.size(1) / 2;
    int mid_h = tensor.size(2) / 2;
    int mid_w = tensor.size(3) / 2;
    stats.samples.push_back(tensor[0][mid_c][mid_h][mid_w].item<float>());
    stats.samples.push_back(tensor[0][-1][-1][-1].item<float>());

    return stats;
}

// Save tensor stats to a text file
void Classifier::save_stats(const std::vector<TensorStats>& all_stats, const std::string& filepath) {
    std::ofstream file(filepath);
    if (!file.is_open()) {
        std::cerr << "Error opening file for writing: " << filepath << std::endl;
        return;
    }

    for (size_t i = 0; i < all_stats.size(); i++) {
        const auto& stats = all_stats[i];
        file << "Output " << i << ":" << std::endl;

        file << "  Shape: [";
        for (size_t j = 0; j < stats.shape.size(); j++) {
            file << stats.shape[j];
            if (j < stats.shape.size() - 1) file << ", ";
        }
        file << "]" << std::endl;

        file << "  Mean: " << stats.mean << std::endl;
        file << "  Std: " << stats.std_dev << std::endl;
        file << "  Min: " << stats.min_val << std::endl;
        file << "  Max: " << stats.max_val << std::endl;
        file << "  Sum: " << stats.sum << std::endl;

        file << "  Sample values: [";
        for (size_t j = 0; j < stats.samples.size(); j++) {
            file << stats.samples[j];
            if (j < stats.samples.size() - 1) file << ", ";
        }
        file << "]" << std::endl << std::endl;
    }

    file.close();
}

// Load a single tensor from a .pt file and move it to the model's device
torch::Tensor Classifier::load_tensor(const std::string& file_path) {
    try {
        // Read the file into bytes, then unpickle the tensor
        std::vector<char> data = read_file_to_bytes(file_path);
        torch::Tensor tensor = torch::pickle_load(data).toTensor();

        // Always move the tensor to the configured device
        if (tensor.device() != device) {
            tensor = tensor.to(device);
        }

        return tensor;
    } catch (const std::exception& e) {
        std::cerr << "Error loading tensor from " << file_path << ": " << e.what() << std::endl;
        throw;
    }
}