
Fix: Correct classifier weight loading and device handling

master · mht · 1 week ago · commit 988d8622fa
Changed files:
  1. cimp/classifier/classifier.cpp (143 changed lines)
  2. cimp/classifier/classifier.h (6 changed lines)

cimp/classifier/classifier.cpp

@@ -78,57 +78,19 @@ torch::Tensor Classifier::FeatureExtractor::extract_feat(torch::Tensor x) {
     // Apply conv followed by normalization
     auto features = forward(x);
-    // Create a copy to hold the normalized result
-    auto normalized_features = features.clone();
-    // Apply the general normalization first (for most channels)
-    auto general_normalized = norm.forward(features);
-    normalized_features.copy_(general_normalized);
-    // List of channels with the largest differences (based on analysis)
-    std::vector<int> problematic_channels = {30, 485, 421, 129, 497, 347, 287, 7, 448, 252};
-    // Special handling for problematic channels with higher precision
-    for (int channel : problematic_channels) {
-        if (channel < features.size(1)) {
-            // Extract the channel
-            auto channel_data = features.index({torch::indexing::Slice(), channel, torch::indexing::Slice(), torch::indexing::Slice()});
-            // Convert to double for higher precision calculation
-            auto channel_double = channel_data.to(torch::kFloat64);
-            // Manually implement the L2 normalization for this channel with higher precision
-            auto squared = channel_double * channel_double;
-            auto dims_product = static_cast<double>(features.size(2) * features.size(3));
-            auto sum_squared = torch::sum(squared.view({features.size(0), 1, 1, -1}), /*dim=*/3, /*keepdim=*/true);
-            // Calculate the normalization factor with double precision
-            auto norm_factor = 0.011048543456039804 * torch::sqrt(dims_product / (sum_squared + 1e-5));
-            // Apply normalization and convert back to original dtype
-            auto normalized_channel = channel_double * norm_factor;
-            auto normalized_channel_float = normalized_channel.to(features.dtype());
-            // Update the specific channel in the normalized result
-            normalized_features.index_put_({torch::indexing::Slice(), channel, torch::indexing::Slice(), torch::indexing::Slice()},
-                                           normalized_channel_float);
-        }
-    }
-    return normalized_features;
+    // Apply the general normalization
+    return norm.forward(features); // Simplified: removed special channel handling
 }
-void Classifier::FeatureExtractor::load_weights(const std::string& weights_dir) {
+void Classifier::FeatureExtractor::load_weights(const std::string& weights_dir, torch::Device device) {
     try {
         std::string file_path = weights_dir + "/feature_extractor_0_weight.pt";
-        // Read file into bytes first
-        std::vector<char> data = Classifier::read_file_to_bytes(file_path);
-        // Use pickle_load with byte data
-        weight = torch::pickle_load(data).toTensor();
+        // Call static Classifier::load_tensor with the device
+        weight = Classifier::load_tensor(file_path, device);
         // Assign weights to the conv layer
         conv0->weight = weight;
         std::cout << "Loaded feature extractor weights with shape: "
-                  << weight.sizes() << std::endl;
+                  << weight.sizes() << " on device " << weight.device() << std::endl;
     } catch (const std::exception& e) {
         std::cerr << "Error loading feature extractor weights: " << e.what() << std::endl;
     }
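
Note: after this change, extract_feat relies entirely on the general InstanceL2Norm pass. That class's forward is not shown in this diff; the sketch below is the standard DiMP-style instance L2 normalization using the constants that appear in this file (scale 0.011048543456039804, eps 1e-5). The free-function form and parameter names are placeholders, not the project's actual implementation.

    // Sketch only: scales each sample of a [N, C, H, W] map so that its mean
    // squared value matches scale*scale (the size_average = true case).
    torch::Tensor instance_l2_norm_sketch(torch::Tensor x, double scale, double eps) {
        auto n = static_cast<double>(x.size(1) * x.size(2) * x.size(3));   // C * H * W
        auto sum_sq = torch::sum((x * x).view({x.size(0), 1, 1, -1}),
                                 /*dim=*/3, /*keepdim=*/true);             // per-sample sum of squares
        return x * (scale * torch::sqrt(n / (sum_sq + eps)));
    }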
@@ -139,16 +101,29 @@ torch::Tensor Classifier::FilterInitializer::forward(torch::Tensor x) {
     return filter_conv->forward(x);
 }
-void Classifier::FilterInitializer::load_weights(const std::string& weights_dir) {
+void Classifier::FilterInitializer::load_weights(const std::string& weights_dir, torch::Device device) {
     try {
-        std::string file_path = weights_dir + "/filter_initializer_filter_conv_weight.pt";
-        // Read file into bytes first
-        std::vector<char> data = Classifier::read_file_to_bytes(file_path);
-        // Use pickle_load with byte data
-        filter_conv_weight = torch::pickle_load(data).toTensor();
+        std::string file_path_weight = weights_dir + "/filter_initializer_filter_conv_weight.pt";
+        filter_conv_weight = Classifier::load_tensor(file_path_weight, device);
         filter_conv->weight = filter_conv_weight;
-        std::cout << "Loaded filter initializer weights with shape: "
-                  << filter_conv_weight.sizes() << std::endl;
+        std::cout << "Loaded filter initializer conv weights with shape: "
+                  << filter_conv_weight.sizes() << " on device " << filter_conv_weight.device() << std::endl;
+        // Also load bias if it exists (and if filter_conv is configured to use bias)
+        std::string file_path_bias = weights_dir + "/filter_initializer_filter_conv_bias.pt";
+        if (fs::exists(file_path_bias)) {
+            if (filter_conv->options.bias()) { // Check if bias is enabled for the layer
+                torch::Tensor bias_tensor = Classifier::load_tensor(file_path_bias, device);
+                filter_conv->bias = bias_tensor;
+                std::cout << "Loaded filter initializer conv bias with shape: "
+                          << bias_tensor.sizes() << " on device " << bias_tensor.device() << std::endl;
+            } else {
+                std::cout << "Skipping filter_initializer_filter_conv_bias.pt as bias is false for the layer." << std::endl;
+            }
+        } else {
+            std::cout << "Filter initializer bias file not found: " << file_path_bias << std::endl;
+        }
     } catch (const std::exception& e) {
         std::cerr << "Error loading filter initializer weights: " << e.what() << std::endl;
     }
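
Note: the filter_conv->options.bias() guard above reads back the .bias(...) flag the Conv2d layer was constructed with (libtorch defaults it to true). A minimal sketch of that pattern; the channel counts and kernel size are purely illustrative and not taken from this codebase:

    #include <torch/torch.h>
    #include <iostream>

    void bias_flag_demo() {
        // Built without a bias term, so a loader like the one above would skip the bias file.
        auto conv_no_bias = torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 512, 4).bias(false));
        // Bias defaults to true when .bias(...) is not set.
        auto conv_default = torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 512, 4));
        std::cout << conv_no_bias->options.bias() << " "        // prints 0
                  << conv_default->options.bias() << std::endl; // prints 1
    }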
@@ -180,13 +155,8 @@ torch::Tensor Classifier::LinearFilter::extract_classification_feat(torch::Tenso
 }
 // Classifier implementation
-Classifier::Classifier(const std::string& base_dir, torch::Device dev)
-    : device(dev), model_dir(base_dir + "/exported_weights/classifier"), linear_filter(4) {
-    // Check if base directory exists
-    if (!fs::exists(base_dir)) {
-        throw std::runtime_error("Base directory does not exist: " + base_dir);
-    }
+Classifier::Classifier(const std::string& model_weights_dir, torch::Device dev)
+    : device(dev), model_dir(model_weights_dir), linear_filter(4) {
     // Check if model directory exists
     if (!fs::exists(model_dir)) {
@@ -194,7 +164,7 @@ Classifier::Classifier(const std::string& base_dir, torch::Device dev)
     }
     // Initialize feature extractor with appropriate parameters
-    linear_filter.feature_extractor.conv0 = torch::nn::Conv2d(torch::nn::Conv2dOptions(1024, 512, 3).padding(1));
+    linear_filter.feature_extractor.conv0 = torch::nn::Conv2d(torch::nn::Conv2dOptions(1024, 512, 3).padding(1).bias(false));
     linear_filter.feature_extractor.norm = InstanceL2Norm(true, 1e-5, 0.011048543456039804);
     // Initialize filter initializer
@@ -214,45 +184,18 @@ Classifier::Classifier(const std::string& base_dir, torch::Device dev)
 // Load weights for the feature extractor and filter
 void Classifier::load_weights() {
-    std::string feat_ext_path = model_dir + "/feature_extractor.pt";
-    std::string filter_init_path = model_dir + "/filter_initializer.pt";
-    std::string filter_optimizer_path = model_dir + "/filter_optimizer.pt";
-    // Load feature extractor weights if they exist
-    if (fs::exists(feat_ext_path)) {
-        try {
-            auto feat_ext_weight = load_tensor(feat_ext_path);
-            linear_filter.feature_extractor.weight = feat_ext_weight;
-            std::cout << "Loaded feature extractor weights with shape: ["
-                      << feat_ext_weight.size(0) << ", " << feat_ext_weight.size(1) << ", "
-                      << feat_ext_weight.size(2) << ", " << feat_ext_weight.size(3) << "]" << std::endl;
-        } catch (const std::exception& e) {
-            std::cerr << "Error loading feature extractor weights: " << e.what() << std::endl;
-            throw;
-        }
-    } else {
-        std::cout << "Skipping feature extractor weights - file not found" << std::endl;
-    }
-    // Load filter initializer weights if they exist
-    if (fs::exists(filter_init_path)) {
-        try {
-            auto filter_init_weight = load_tensor(filter_init_path);
-            linear_filter.filter_initializer.filter_conv_weight = filter_init_weight;
-            linear_filter.filter_initializer.filter_conv->weight = filter_init_weight;
-            std::cout << "Loaded filter initializer weights with shape: ["
-                      << filter_init_weight.size(0) << ", " << filter_init_weight.size(1) << ", "
-                      << filter_init_weight.size(2) << ", " << filter_init_weight.size(3) << "]" << std::endl;
-        } catch (const std::exception& e) {
-            std::cerr << "Error loading filter initializer weights: " << e.what() << std::endl;
-            throw;
-        }
-    } else {
-        std::cout << "Skipping filter initializer weights - file not found" << std::endl;
-    }
-    // Skip filter optimizer weights since we don't use them for feature extraction only
-    std::cout << "Skipping filter optimizer weights - not needed for feature extraction" << std::endl;
+    // Call load_weights for sub-components directly
+    // The model_dir for Classifier is e.g., ".../exported_weights/classifier"
+    // The sub-component load_weights methods expect this directory.
+    std::cout << "Classifier::load_weights() called. Model directory: " << this->model_dir << std::endl;
+    linear_filter.feature_extractor.load_weights(this->model_dir, this->device);
+    linear_filter.filter_initializer.load_weights(this->model_dir, this->device);
+    // Filter optimizer weights are not strictly needed for just feature extraction,
+    // but we can call its load_weights if it has one, for completeness or future use.
+    // linear_filter.filter_optimizer.load_weights(this->model_dir); // Assuming it exists and is safe to call
+    std::cout << "Skipping main filter optimizer weights in Classifier::load_weights - not needed for feature extraction" << std::endl;
 }
 void Classifier::to(torch::Device device) {
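
Note: with this commit, the device chosen at construction time is threaded through every tensor load. A minimal usage sketch; the weights path below is illustrative, and it assumes load_weights() is publicly callable as in the hunk above:

    #include <torch/torch.h>
    #include "cimp/classifier/classifier.h"

    int main() {
        // Pick CUDA when available, otherwise fall back to CPU.
        torch::Device device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);
        // The constructor now takes the classifier weights directory directly
        // (e.g. ".../exported_weights/classifier") plus the target device.
        Classifier classifier("/path/to/exported_weights/classifier", device);
        // Delegates to feature_extractor / filter_initializer load_weights with the stored device.
        classifier.load_weights();
        return 0;
    }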
@@ -382,7 +325,7 @@ void Classifier::save_stats(const std::vector<TensorStats>& all_stats, const std
 }
 // Load weights for the model
-torch::Tensor Classifier::load_tensor(const std::string& file_path) {
+torch::Tensor Classifier::load_tensor(const std::string& file_path, torch::Device device) {
     try {
         // Read file into bytes first
         std::vector<char> data = read_file_to_bytes(file_path);
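
Note: the trailing context of this hunk ends before the part of load_tensor that actually uses the new device parameter. Based on the pickle_load pattern removed elsewhere in this commit, the remainder presumably proceeds along these lines; this is a sketch, not the verbatim implementation:

    #include <torch/torch.h>
    #include "cimp/classifier/classifier.h"

    // Sketch only: mirrors the removed pickle_load calls plus a .to(device) move.
    torch::Tensor load_tensor_sketch(const std::string& file_path, torch::Device device) {
        // Reuse the project's static byte reader declared in classifier.h.
        std::vector<char> data = Classifier::read_file_to_bytes(file_path);
        // Deserialize the pickled tensor and move it to the requested device.
        return torch::pickle_load(data).toTensor().to(device);
    }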

cimp/classifier/classifier.h

@@ -43,7 +43,7 @@ public:
     static std::vector<char> read_file_to_bytes(const std::string& file_path);
     // Helper function to load a tensor from a file
-    torch::Tensor load_tensor(const std::string& file_path);
+    static torch::Tensor load_tensor(const std::string& file_path, torch::Device device);
     // Statistics structure for tensors
     struct TensorStats {
@@ -71,7 +71,7 @@ private:
         torch::Tensor forward(torch::Tensor x);
         torch::Tensor extract_feat(torch::Tensor x);
-        void load_weights(const std::string& weights_dir);
+        void load_weights(const std::string& weights_dir, torch::Device device);
     };
     // Filter initializer component
@@ -80,7 +80,7 @@ private:
         torch::Tensor filter_conv_weight;
         torch::Tensor forward(torch::Tensor x);
-        void load_weights(const std::string& weights_dir);
+        void load_weights(const std::string& weights_dir, torch::Device device);
     };
     // Filter optimizer component
