@@ -78,57 +78,19 @@ torch::Tensor Classifier::FeatureExtractor::extract_feat(torch::Tensor x) {
     // Apply conv followed by normalization
     auto features = forward(x);
 
-    // Create a copy to hold the normalized result
-    auto normalized_features = features.clone();
-
-    // Apply the general normalization first (for most channels)
-    auto general_normalized = norm.forward(features);
-    normalized_features.copy_(general_normalized);
-
-    // List of channels with the largest differences (based on analysis)
-    std::vector<int> problematic_channels = {30, 485, 421, 129, 497, 347, 287, 7, 448, 252};
-
-    // Special handling for problematic channels with higher precision
-    for (int channel : problematic_channels) {
-        if (channel < features.size(1)) {
-            // Extract the channel
-            auto channel_data = features.index({torch::indexing::Slice(), channel, torch::indexing::Slice(), torch::indexing::Slice()});
-
-            // Convert to double for higher precision calculation
-            auto channel_double = channel_data.to(torch::kFloat64);
-
-            // Manually implement the L2 normalization for this channel with higher precision
-            auto squared = channel_double * channel_double;
-            auto dims_product = static_cast<double>(features.size(2) * features.size(3));
-            auto sum_squared = torch::sum(squared.view({features.size(0), 1, 1, -1}), /*dim=*/3, /*keepdim=*/true);
-
-            // Calculate the normalization factor with double precision
-            auto norm_factor = 0.011048543456039804 * torch::sqrt(dims_product / (sum_squared + 1e-5));
-
-            // Apply normalization and convert back to original dtype
-            auto normalized_channel = channel_double * norm_factor;
-            auto normalized_channel_float = normalized_channel.to(features.dtype());
-
-            // Update the specific channel in the normalized result
-            normalized_features.index_put_({torch::indexing::Slice(), channel, torch::indexing::Slice(), torch::indexing::Slice()},
-                                           normalized_channel_float);
-        }
-    }
-
-    return normalized_features;
+    // Apply the general normalization
+    return norm.forward(features); // Simplified: removed special channel handling
 }
 
-void Classifier::FeatureExtractor::load_weights(const std::string& weights_dir) {
+void Classifier::FeatureExtractor::load_weights(const std::string& weights_dir, torch::Device device) {
     try {
         std::string file_path = weights_dir + "/feature_extractor_0_weight.pt";
-        // Read file into bytes first
-        std::vector<char> data = Classifier::read_file_to_bytes(file_path);
-        // Use pickle_load with byte data
-        weight = torch::pickle_load(data).toTensor();
+        // Call static Classifier::load_tensor with the device
+        weight = Classifier::load_tensor(file_path, device);
         // Assign weights to the conv layer
         conv0->weight = weight;
         std::cout << "Loaded feature extractor weights with shape: "
-                  << weight.sizes() << std::endl;
+                  << weight.sizes() << " on device " << weight.device() << std::endl;
     } catch (const std::exception& e) {
         std::cerr << "Error loading feature extractor weights: " << e.what() << std::endl;
     }
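For context on what the deleted branch was doing: it recomputed, per channel and in double precision, essentially the same quantity that `InstanceL2Norm` already produces. Below is a minimal sketch of that normalization, assuming the usual instance-L2 formula and reusing the `eps`/`scale` values passed to `InstanceL2Norm` in the constructor later in this diff; the repo's actual `InstanceL2Norm::forward` may differ in detail.

```cpp
#include <torch/torch.h>

// Sketch only (assumed formula): L2-normalize each sample over all C*H*W
// elements and rescale by a fixed factor, mirroring the constructor call
// InstanceL2Norm(true, 1e-5, 0.011048543456039804) seen below.
torch::Tensor instance_l2_norm(const torch::Tensor& x,
                               double scale = 0.011048543456039804,
                               double eps = 1e-5) {
    // x: (B, C, H, W); rms has shape (B, 1, 1, 1)
    auto size = static_cast<double>(x.size(1) * x.size(2) * x.size(3));
    auto sum_sq = (x * x).sum({1, 2, 3}, /*keepdim=*/true);
    auto rms = torch::sqrt((sum_sq + eps) / size);  // per-sample RMS magnitude
    return scale * x / rms;
}
```

Since the deleted per-channel loop was just a higher-precision re-derivation of this factor for a hand-picked channel list, delegating everything to `norm.forward(features)` keeps one code path and one definition of the formula.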
@@ -139,16 +101,29 @@ torch::Tensor Classifier::FilterInitializer::forward(torch::Tensor x) {
     return filter_conv->forward(x);
 }
 
-void Classifier::FilterInitializer::load_weights(const std::string& weights_dir) {
+void Classifier::FilterInitializer::load_weights(const std::string& weights_dir, torch::Device device) {
     try {
-        std::string file_path = weights_dir + "/filter_initializer_filter_conv_weight.pt";
-        // Read file into bytes first
-        std::vector<char> data = Classifier::read_file_to_bytes(file_path);
-        // Use pickle_load with byte data
-        filter_conv_weight = torch::pickle_load(data).toTensor();
+        std::string file_path_weight = weights_dir + "/filter_initializer_filter_conv_weight.pt";
+        filter_conv_weight = Classifier::load_tensor(file_path_weight, device);
         filter_conv->weight = filter_conv_weight;
-        std::cout << "Loaded filter initializer weights with shape: "
-                  << filter_conv_weight.sizes() << std::endl;
+        std::cout << "Loaded filter initializer conv weights with shape: "
+                  << filter_conv_weight.sizes() << " on device " << filter_conv_weight.device() << std::endl;
+
+        // Also load bias if it exists (and if filter_conv is configured to use bias)
+        std::string file_path_bias = weights_dir + "/filter_initializer_filter_conv_bias.pt";
+        if (fs::exists(file_path_bias)) {
+            if (filter_conv->options.bias()) { // Check if bias is enabled for the layer
+                torch::Tensor bias_tensor = Classifier::load_tensor(file_path_bias, device);
+                filter_conv->bias = bias_tensor;
+                std::cout << "Loaded filter initializer conv bias with shape: "
+                          << bias_tensor.sizes() << " on device " << bias_tensor.device() << std::endl;
+            } else {
+                std::cout << "Skipping filter_initializer_filter_conv_bias.pt as bias is false for the layer." << std::endl;
+            }
+        } else {
+            std::cout << "Filter initializer bias file not found: " << file_path_bias << std::endl;
+        }
    } catch (const std::exception& e) {
         std::cerr << "Error loading filter initializer weights: " << e.what() << std::endl;
     }
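Both the removed `pickle_load` branches and the `Classifier::load_tensor` helper shown at the end of this diff go through `Classifier::read_file_to_bytes`. For reference, a plausible implementation is the standard whole-file `ifstream` read; this is an assumption about the helper's body, not the repo's verbatim code (it is also sketched as a free function rather than a static member).

```cpp
#include <fstream>
#include <stdexcept>
#include <string>
#include <vector>

// Assumed implementation: read the entire file into a byte buffer suitable
// for torch::pickle_load.
std::vector<char> read_file_to_bytes(const std::string& file_path) {
    std::ifstream file(file_path, std::ios::binary | std::ios::ate);
    if (!file) {
        throw std::runtime_error("Could not open file: " + file_path);
    }
    std::streamsize size = file.tellg();  // opened at the end, so tellg() is the file size
    file.seekg(0, std::ios::beg);
    std::vector<char> buffer(static_cast<size_t>(size));
    if (size > 0 && !file.read(buffer.data(), size)) {
        throw std::runtime_error("Could not read file: " + file_path);
    }
    return buffer;
}
```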
@@ -180,13 +155,8 @@ torch::Tensor Classifier::LinearFilter::extract_classification_feat(torch::Tenso
 }
 
 // Classifier implementation
-Classifier::Classifier(const std::string& base_dir, torch::Device dev)
-    : device(dev), model_dir(base_dir + "/exported_weights/classifier"), linear_filter(4) {
-
-    // Check if base directory exists
-    if (!fs::exists(base_dir)) {
-        throw std::runtime_error("Base directory does not exist: " + base_dir);
-    }
+Classifier::Classifier(const std::string& model_weights_dir, torch::Device dev)
+    : device(dev), model_dir(model_weights_dir), linear_filter(4) {
+
+    // Check if model directory exists
+    if (!fs::exists(model_dir)) {
@@ -194,7 +164,7 @@ Classifier::Classifier(const std::string& base_dir, torch::Device dev)
     }
 
     // Initialize feature extractor with appropriate parameters
-    linear_filter.feature_extractor.conv0 = torch::nn::Conv2d(torch::nn::Conv2dOptions(1024, 512, 3).padding(1));
+    linear_filter.feature_extractor.conv0 = torch::nn::Conv2d(torch::nn::Conv2dOptions(1024, 512, 3).padding(1).bias(false));
     linear_filter.feature_extractor.norm = InstanceL2Norm(true, 1e-5, 0.011048543456039804);
 
     // Initialize filter initializer
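One interaction worth noting: building a conv with `.bias(false)` means LibTorch registers no bias parameter at all, which is exactly why the bias-loading path in `FilterInitializer::load_weights` above guards on `filter_conv->options.bias()` before assigning. A small standalone check illustrating this behavior (illustrative snippet, not part of the patch):

```cpp
#include <torch/torch.h>
#include <cassert>

int main() {
    // With bias(false), the Conv2d module creates no bias tensor, so code
    // that assigns a loaded bias must check options.bias() first.
    auto conv = torch::nn::Conv2d(
        torch::nn::Conv2dOptions(1024, 512, 3).padding(1).bias(false));
    assert(!conv->options.bias());  // bias disabled in the options
    assert(!conv->bias.defined());  // no bias parameter was registered
}
```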
@@ -214,45 +184,18 @@ Classifier::Classifier(const std::string& base_dir, torch::Device dev)
 
 // Load weights for the feature extractor and filter
 void Classifier::load_weights() {
-    std::string feat_ext_path = model_dir + "/feature_extractor.pt";
-    std::string filter_init_path = model_dir + "/filter_initializer.pt";
-    std::string filter_optimizer_path = model_dir + "/filter_optimizer.pt";
-
-    // Load feature extractor weights if they exist
-    if (fs::exists(feat_ext_path)) {
-        try {
-            auto feat_ext_weight = load_tensor(feat_ext_path);
-            linear_filter.feature_extractor.weight = feat_ext_weight;
-            std::cout << "Loaded feature extractor weights with shape: ["
-                      << feat_ext_weight.size(0) << ", " << feat_ext_weight.size(1) << ", "
-                      << feat_ext_weight.size(2) << ", " << feat_ext_weight.size(3) << "]" << std::endl;
-        } catch (const std::exception& e) {
-            std::cerr << "Error loading feature extractor weights: " << e.what() << std::endl;
-            throw;
-        }
-    } else {
-        std::cout << "Skipping feature extractor weights - file not found" << std::endl;
-    }
-
-    // Load filter initializer weights if they exist
-    if (fs::exists(filter_init_path)) {
-        try {
-            auto filter_init_weight = load_tensor(filter_init_path);
-            linear_filter.filter_initializer.filter_conv_weight = filter_init_weight;
-            linear_filter.filter_initializer.filter_conv->weight = filter_init_weight;
-            std::cout << "Loaded filter initializer weights with shape: ["
-                      << filter_init_weight.size(0) << ", " << filter_init_weight.size(1) << ", "
-                      << filter_init_weight.size(2) << ", " << filter_init_weight.size(3) << "]" << std::endl;
-        } catch (const std::exception& e) {
-            std::cerr << "Error loading filter initializer weights: " << e.what() << std::endl;
-            throw;
-        }
-    } else {
-        std::cout << "Skipping filter initializer weights - file not found" << std::endl;
-    }
-
-    // Skip filter optimizer weights since we don't use them for feature extraction only
-    std::cout << "Skipping filter optimizer weights - not needed for feature extraction" << std::endl;
+    // Call load_weights for sub-components directly.
+    // The model_dir for Classifier is e.g., ".../exported_weights/classifier";
+    // the sub-component load_weights methods expect this directory.
+    std::cout << "Classifier::load_weights() called. Model directory: " << this->model_dir << std::endl;
+
+    linear_filter.feature_extractor.load_weights(this->model_dir, this->device);
+    linear_filter.filter_initializer.load_weights(this->model_dir, this->device);
+
+    // Filter optimizer weights are not strictly needed for just feature extraction,
+    // but we can call its load_weights if it has one, for completeness or future use.
+    // linear_filter.filter_optimizer.load_weights(this->model_dir); // Assuming it exists and is safe to call
+
+    std::cout << "Skipping main filter optimizer weights in Classifier::load_weights - not needed for feature extraction" << std::endl;
 }
 
 void Classifier::to(torch::Device device) {
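Taken together, the constructor and loader changes shift device placement to load time: the caller passes the weights directory itself (not a base directory) plus a device, and every tensor is materialized on that device. A hypothetical call site under these assumptions (the header name and path are illustrative):

```cpp
#include <torch/torch.h>
#include "classifier.h"  // hypothetical header for the Classifier class

int main() {
    // Pick a device up front; load_weights now places every tensor on it.
    torch::Device device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);

    // The constructor now takes the weights directory directly.
    Classifier classifier("/path/to/exported_weights/classifier", device);
    classifier.load_weights();  // delegates to the sub-component loaders
}
```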
@@ -382,7 +325,7 @@ void Classifier::save_stats(const std::vector<TensorStats>& all_stats, const std
 }
 
 // Load weights for the model
-torch::Tensor Classifier::load_tensor(const std::string& file_path) {
+torch::Tensor Classifier::load_tensor(const std::string& file_path, torch::Device device) {
     try {
         // Read file into bytes first
         std::vector<char> data = read_file_to_bytes(file_path);
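The hunk is truncated here, so the rest of `load_tensor` is not shown. Given the byte-buffer read above and the `torch::pickle_load` path this change replaces elsewhere, the body presumably continues roughly as follows; this is a sketch under those assumptions, not the verbatim implementation.

```cpp
#include <torch/torch.h>
#include <string>
#include <vector>

std::vector<char> read_file_to_bytes(const std::string& file_path);  // as sketched earlier

// Assumed continuation: deserialize the pickled bytes, then move the tensor
// to the requested device before returning.
torch::Tensor load_tensor(const std::string& file_path, torch::Device device) {
    std::vector<char> data = read_file_to_bytes(file_path);
    torch::Tensor tensor = torch::pickle_load(data).toTensor();
    return tensor.to(device);
}
```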