From 988d8622fafb765b0a734c3cb8e6e4da6e1f6840 Mon Sep 17 00:00:00 2001 From: mht Date: Wed, 28 May 2025 15:00:07 +0330 Subject: [PATCH] Fix: Correct classifier weight loading and device handling --- cimp/classifier/classifier.cpp | 143 ++++++++++----------------------- cimp/classifier/classifier.h | 6 +- 2 files changed, 46 insertions(+), 103 deletions(-) diff --git a/cimp/classifier/classifier.cpp b/cimp/classifier/classifier.cpp index cc2050c..87a0be8 100644 --- a/cimp/classifier/classifier.cpp +++ b/cimp/classifier/classifier.cpp @@ -78,57 +78,19 @@ torch::Tensor Classifier::FeatureExtractor::extract_feat(torch::Tensor x) { // Apply conv followed by normalization auto features = forward(x); - // Create a copy to hold the normalized result - auto normalized_features = features.clone(); - - // Apply the general normalization first (for most channels) - auto general_normalized = norm.forward(features); - normalized_features.copy_(general_normalized); - - // List of channels with the largest differences (based on analysis) - std::vector problematic_channels = {30, 485, 421, 129, 497, 347, 287, 7, 448, 252}; - - // Special handling for problematic channels with higher precision - for (int channel : problematic_channels) { - if (channel < features.size(1)) { - // Extract the channel - auto channel_data = features.index({torch::indexing::Slice(), channel, torch::indexing::Slice(), torch::indexing::Slice()}); - - // Convert to double for higher precision calculation - auto channel_double = channel_data.to(torch::kFloat64); - - // Manually implement the L2 normalization for this channel with higher precision - auto squared = channel_double * channel_double; - auto dims_product = static_cast(features.size(2) * features.size(3)); - auto sum_squared = torch::sum(squared.view({features.size(0), 1, 1, -1}), /*dim=*/3, /*keepdim=*/true); - - // Calculate the normalization factor with double precision - auto norm_factor = 0.011048543456039804 * torch::sqrt(dims_product / (sum_squared + 1e-5)); - - // Apply normalization and convert back to original dtype - auto normalized_channel = channel_double * norm_factor; - auto normalized_channel_float = normalized_channel.to(features.dtype()); - - // Update the specific channel in the normalized result - normalized_features.index_put_({torch::indexing::Slice(), channel, torch::indexing::Slice(), torch::indexing::Slice()}, - normalized_channel_float); - } - } - - return normalized_features; + // Apply the general normalization + return norm.forward(features); // Simplified: removed special channel handling } -void Classifier::FeatureExtractor::load_weights(const std::string& weights_dir) { +void Classifier::FeatureExtractor::load_weights(const std::string& weights_dir, torch::Device device) { try { std::string file_path = weights_dir + "/feature_extractor_0_weight.pt"; - // Read file into bytes first - std::vector data = Classifier::read_file_to_bytes(file_path); - // Use pickle_load with byte data - weight = torch::pickle_load(data).toTensor(); + // Call static Classifier::load_tensor with the device + weight = Classifier::load_tensor(file_path, device); // Assign weights to the conv layer conv0->weight = weight; std::cout << "Loaded feature extractor weights with shape: " - << weight.sizes() << std::endl; + << weight.sizes() << " on device " << weight.device() << std::endl; } catch (const std::exception& e) { std::cerr << "Error loading feature extractor weights: " << e.what() << std::endl; } @@ -139,16 +101,29 @@ torch::Tensor Classifier::FilterInitializer::forward(torch::Tensor x) { return filter_conv->forward(x); } -void Classifier::FilterInitializer::load_weights(const std::string& weights_dir) { +void Classifier::FilterInitializer::load_weights(const std::string& weights_dir, torch::Device device) { try { - std::string file_path = weights_dir + "/filter_initializer_filter_conv_weight.pt"; - // Read file into bytes first - std::vector data = Classifier::read_file_to_bytes(file_path); - // Use pickle_load with byte data - filter_conv_weight = torch::pickle_load(data).toTensor(); + std::string file_path_weight = weights_dir + "/filter_initializer_filter_conv_weight.pt"; + filter_conv_weight = Classifier::load_tensor(file_path_weight, device); filter_conv->weight = filter_conv_weight; - std::cout << "Loaded filter initializer weights with shape: " - << filter_conv_weight.sizes() << std::endl; + std::cout << "Loaded filter initializer conv weights with shape: " + << filter_conv_weight.sizes() << " on device " << filter_conv_weight.device() << std::endl; + + // Also load bias if it exists (and if filter_conv is configured to use bias) + std::string file_path_bias = weights_dir + "/filter_initializer_filter_conv_bias.pt"; + if (fs::exists(file_path_bias)) { + if (filter_conv->options.bias()) { // Check if bias is enabled for the layer + torch::Tensor bias_tensor = Classifier::load_tensor(file_path_bias, device); + filter_conv->bias = bias_tensor; + std::cout << "Loaded filter initializer conv bias with shape: " + << bias_tensor.sizes() << " on device " << bias_tensor.device() << std::endl; + } else { + std::cout << "Skipping filter_initializer_filter_conv_bias.pt as bias is false for the layer." << std::endl; + } + } else { + std::cout << "Filter initializer bias file not found: " << file_path_bias << std::endl; + } + } catch (const std::exception& e) { std::cerr << "Error loading filter initializer weights: " << e.what() << std::endl; } @@ -180,13 +155,8 @@ torch::Tensor Classifier::LinearFilter::extract_classification_feat(torch::Tenso } // Classifier implementation -Classifier::Classifier(const std::string& base_dir, torch::Device dev) - : device(dev), model_dir(base_dir + "/exported_weights/classifier"), linear_filter(4) { - - // Check if base directory exists - if (!fs::exists(base_dir)) { - throw std::runtime_error("Base directory does not exist: " + base_dir); - } +Classifier::Classifier(const std::string& model_weights_dir, torch::Device dev) + : device(dev), model_dir(model_weights_dir), linear_filter(4) { // Check if model directory exists if (!fs::exists(model_dir)) { @@ -194,7 +164,7 @@ Classifier::Classifier(const std::string& base_dir, torch::Device dev) } // Initialize feature extractor with appropriate parameters - linear_filter.feature_extractor.conv0 = torch::nn::Conv2d(torch::nn::Conv2dOptions(1024, 512, 3).padding(1)); + linear_filter.feature_extractor.conv0 = torch::nn::Conv2d(torch::nn::Conv2dOptions(1024, 512, 3).padding(1).bias(false)); linear_filter.feature_extractor.norm = InstanceL2Norm(true, 1e-5, 0.011048543456039804); // Initialize filter initializer @@ -214,45 +184,18 @@ Classifier::Classifier(const std::string& base_dir, torch::Device dev) // Load weights for the feature extractor and filter void Classifier::load_weights() { - std::string feat_ext_path = model_dir + "/feature_extractor.pt"; - std::string filter_init_path = model_dir + "/filter_initializer.pt"; - std::string filter_optimizer_path = model_dir + "/filter_optimizer.pt"; - - // Load feature extractor weights if they exist - if (fs::exists(feat_ext_path)) { - try { - auto feat_ext_weight = load_tensor(feat_ext_path); - linear_filter.feature_extractor.weight = feat_ext_weight; - std::cout << "Loaded feature extractor weights with shape: [" - << feat_ext_weight.size(0) << ", " << feat_ext_weight.size(1) << ", " - << feat_ext_weight.size(2) << ", " << feat_ext_weight.size(3) << "]" << std::endl; - } catch (const std::exception& e) { - std::cerr << "Error loading feature extractor weights: " << e.what() << std::endl; - throw; - } - } else { - std::cout << "Skipping feature extractor weights - file not found" << std::endl; - } - - // Load filter initializer weights if they exist - if (fs::exists(filter_init_path)) { - try { - auto filter_init_weight = load_tensor(filter_init_path); - linear_filter.filter_initializer.filter_conv_weight = filter_init_weight; - linear_filter.filter_initializer.filter_conv->weight = filter_init_weight; - std::cout << "Loaded filter initializer weights with shape: [" - << filter_init_weight.size(0) << ", " << filter_init_weight.size(1) << ", " - << filter_init_weight.size(2) << ", " << filter_init_weight.size(3) << "]" << std::endl; - } catch (const std::exception& e) { - std::cerr << "Error loading filter initializer weights: " << e.what() << std::endl; - throw; - } - } else { - std::cout << "Skipping filter initializer weights - file not found" << std::endl; - } - - // Skip filter optimizer weights since we don't use them for feature extraction only - std::cout << "Skipping filter optimizer weights - not needed for feature extraction" << std::endl; + // Call load_weights for sub-components directly + // The model_dir for Classifier is e.g., ".../exported_weights/classifier" + // The sub-component load_weights methods expect this directory. + std::cout << "Classifier::load_weights() called. Model directory: " << this->model_dir << std::endl; + + linear_filter.feature_extractor.load_weights(this->model_dir, this->device); + linear_filter.filter_initializer.load_weights(this->model_dir, this->device); + // Filter optimizer weights are not strictly needed for just feature extraction, + // but we can call its load_weights if it has one, for completeness or future use. + // linear_filter.filter_optimizer.load_weights(this->model_dir); // Assuming it exists and is safe to call + + std::cout << "Skipping main filter optimizer weights in Classifier::load_weights - not needed for feature extraction" << std::endl; } void Classifier::to(torch::Device device) { @@ -382,7 +325,7 @@ void Classifier::save_stats(const std::vector& all_stats, const std } // Load weights for the model -torch::Tensor Classifier::load_tensor(const std::string& file_path) { +torch::Tensor Classifier::load_tensor(const std::string& file_path, torch::Device device) { try { // Read file into bytes first std::vector data = read_file_to_bytes(file_path); diff --git a/cimp/classifier/classifier.h b/cimp/classifier/classifier.h index 7e11e0e..e83a2f5 100644 --- a/cimp/classifier/classifier.h +++ b/cimp/classifier/classifier.h @@ -43,7 +43,7 @@ public: static std::vector read_file_to_bytes(const std::string& file_path); // Helper function to load a tensor from a file - torch::Tensor load_tensor(const std::string& file_path); + static torch::Tensor load_tensor(const std::string& file_path, torch::Device device); // Statistics structure for tensors struct TensorStats { @@ -71,7 +71,7 @@ private: torch::Tensor forward(torch::Tensor x); torch::Tensor extract_feat(torch::Tensor x); - void load_weights(const std::string& weights_dir); + void load_weights(const std::string& weights_dir, torch::Device device); }; // Filter initializer component @@ -80,7 +80,7 @@ private: torch::Tensor filter_conv_weight; torch::Tensor forward(torch::Tensor x); - void load_weights(const std::string& weights_dir); + void load_weights(const std::string& weights_dir, torch::Device device); }; // Filter optimizer component