You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

345 lines
14 KiB

#include "classifier.h"
#include <iostream>
#include <fstream>
#include <torch/script.h>
#include <torch/serialize.h>
#include <vector>
#include <stdexcept>
// InstanceL2Norm implementation
/// Stores the settings used by InstanceL2Norm::forward().
/// @param use_size_average  when true, forward() also rescales by the square root of
///                          size(1) * size(2) * size(3) of its input
/// @param epsilon           added to the per-instance sum of squares before the sqrt
/// @param scale_factor      multiplicative scale applied to the normalization factor
InstanceL2Norm::InstanceL2Norm(bool use_size_average, float epsilon, float scale_factor)
    : size_average_(use_size_average),
      eps_(epsilon),
      scale_(scale_factor) {}
/// Instance-wise L2 normalization.
///
/// Each slice input[i] along dim 0 is scaled independently. Let S_i be the sum of
/// squares of input[i] and D = size(1) * size(2) * size(3). Then:
///   size_average_ == true : out[i] = input[i] * scale_ * sqrt(D / (S_i + eps_))
///   size_average_ == false: out[i] = input[i] * scale_ / sqrt(S_i + eps_)
/// The arithmetic is done in float64 and the result is cast back to input's dtype.
///
/// @param input 4-D tensor.
/// @return tensor with the same shape and dtype as `input`.
/// @throws std::invalid_argument if `input` is not 4-D. The per-instance factor has
///         shape (N, 1, 1, 1) and would broadcast against the wrong dimensions for
///         any other rank.
torch::Tensor InstanceL2Norm::forward(torch::Tensor input) {
    if (input.dim() != 4) {
        throw std::invalid_argument("InstanceL2Norm::forward expects a 4-D tensor, got " +
                                    std::to_string(input.dim()) + "-D");
    }

    // One-time configuration dump. A function-local static is initialized exactly
    // once even under concurrent calls. A plain `static bool` flag is read and
    // written without synchronization, which is a data race.
    [[maybe_unused]] static const bool debug_info_printed = [&] {
        std::cout << "InstanceL2Norm debug info:" << std::endl;
        std::cout << " Input tensor type: " << input.dtype() << std::endl;
        std::cout << " Input device: " << input.device() << std::endl;
        std::cout << " Size average: " << (size_average_ ? "true" : "false") << std::endl;
        std::cout << " Epsilon: " << eps_ << std::endl;
        std::cout << " Scale factor: " << scale_ << std::endl;
        return true;
    }();

    // Compute in double precision for accuracy; cast back at the end.
    const torch::Tensor x = input.to(torch::kFloat64);

    // reshape() rather than view(): view() throws for non-contiguous layouts
    // (e.g. channels-last or permuted inputs). Elementwise ops and .to() preserve
    // those layouts. For contiguous inputs reshape() is identical to view().
    const torch::Tensor sum_squared =
        torch::sum((x * x).reshape({x.size(0), 1, 1, -1}), /*dim=*/3, /*keepdim=*/true);

    torch::Tensor norm_factor;
    if (size_average_) {
        const auto num_elements =
            static_cast<double>(input.size(1) * input.size(2) * input.size(3));
        norm_factor = scale_ * torch::sqrt(num_elements / (sum_squared + eps_));
    } else {
        norm_factor = scale_ / torch::sqrt(sum_squared + eps_);
    }

    return (x * norm_factor).to(input.dtype());
}
// Helper function to read file to bytes
/// Reads the whole file at `file_path` into memory as raw bytes.
/// @param file_path path of the file to read (opened in binary mode)
/// @return the file contents
/// @throws std::runtime_error if the file cannot be opened, its size cannot be
///         determined, or the read fails
std::vector<char> Classifier::read_file_to_bytes(const std::string& file_path) {
    // Open positioned at the end (ios::ate) so tellg() yields the file size.
    std::ifstream file(file_path, std::ios::binary | std::ios::ate);
    if (!file.is_open()) {
        throw std::runtime_error("Could not open file: " + file_path);
    }
    const std::streamsize size = file.tellg();
    // tellg() reports failure as -1. Without this check the negative value would be
    // converted to an enormous std::vector size (bad_alloc / length_error).
    if (size < 0) {
        throw std::runtime_error("Could not determine size of file: " + file_path);
    }
    file.seekg(0, std::ios::beg);
    std::vector<char> buffer(static_cast<std::size_t>(size));
    if (!file.read(buffer.data(), size)) {
        throw std::runtime_error("Could not read file: " + file_path);
    }
    return buffer;
}
// Feature Extractor implementation
/// Applies the feature extractor's convolution (conv0) to `x`.
/// No normalization is done here; extract_feat() adds it.
torch::Tensor Classifier::FeatureExtractor::forward(torch::Tensor x) {
    auto& conv = this->conv0;
    return conv->forward(x);
}
/// Produces normalized features: conv0 (via forward()) followed by `norm`.
torch::Tensor Classifier::FeatureExtractor::extract_feat(torch::Tensor x) {
    // Simplified: no special per-channel handling. The normalization is applied
    // to the full convolution output.
    torch::Tensor conv_out = forward(x);
    return norm.forward(conv_out);
}
/// Loads the feature-extractor convolution weight from
/// `<weights_dir>/feature_extractor_0_weight.pt` onto `device`.
///
/// Best effort: any exception is logged to stderr and swallowed. If loading fails,
/// conv0's weight is left unchanged.
/// @param weights_dir directory containing the exported weight files
/// @param device      device the loaded tensor is placed on
void Classifier::FeatureExtractor::load_weights(const std::string& weights_dir, torch::Device device) {
    try {
        const std::string file_path = weights_dir + "/feature_extractor_0_weight.pt";
        // Call static Classifier::load_tensor with the device
        weight = Classifier::load_tensor(file_path, device);

        // Install the values into conv0's registered "weight" parameter with
        // set_data() instead of rebinding the conv0->weight member. Rebinding leaves
        // the module's parameter table holding the random initial tensor. After that,
        // conv0->to(...) and conv0->parameters() act on a tensor forward() never reads.
        {
            torch::NoGradGuard no_grad;
            conv0->weight.set_data(weight);
            // Mirror the loaded tensor's requires_grad flag, as a plain rebinding would.
            conv0->weight.set_requires_grad(weight.requires_grad());
        }

        std::cout << "Loaded feature extractor weights with shape: "
                  << weight.sizes() << " on device " << weight.device() << std::endl;
    } catch (const std::exception& e) {
        std::cerr << "Error loading feature extractor weights: " << e.what() << std::endl;
    }
}
// Filter Initializer implementation
/// Applies the filter initializer's convolution (filter_conv) to `x`.
torch::Tensor Classifier::FilterInitializer::forward(torch::Tensor x) {
    torch::Tensor result = filter_conv->forward(x);
    return result;
}
/// Loads the filter-initializer convolution weight, and its bias if present, from
/// `weights_dir` onto `device`.
///
/// The weight comes from `filter_initializer_filter_conv_weight.pt`. The bias comes
/// from `filter_initializer_filter_conv_bias.pt`, and only if that file exists and
/// filter_conv was built with bias enabled.
/// Best effort: any exception is logged to stderr and swallowed.
/// @param weights_dir directory containing the exported weight files
/// @param device      device the loaded tensors are placed on
void Classifier::FilterInitializer::load_weights(const std::string& weights_dir, torch::Device device) {
    try {
        const std::string file_path_weight = weights_dir + "/filter_initializer_filter_conv_weight.pt";
        filter_conv_weight = Classifier::load_tensor(file_path_weight, device);

        // Write into the registered parameter with set_data() rather than rebinding
        // filter_conv->weight. Rebinding leaves the module's parameter table on the
        // random initial tensor, so filter_conv->to(...) would miss the loaded one.
        {
            torch::NoGradGuard no_grad;
            filter_conv->weight.set_data(filter_conv_weight);
            filter_conv->weight.set_requires_grad(filter_conv_weight.requires_grad());
        }
        std::cout << "Loaded filter initializer conv weights with shape: "
                  << filter_conv_weight.sizes() << " on device " << filter_conv_weight.device() << std::endl;

        // Also load bias if it exists (and if filter_conv is configured to use bias)
        const std::string file_path_bias = weights_dir + "/filter_initializer_filter_conv_bias.pt";
        if (fs::exists(file_path_bias)) {
            if (filter_conv->options.bias()) { // Check if bias is enabled for the layer
                torch::Tensor bias_tensor = Classifier::load_tensor(file_path_bias, device);
                {
                    torch::NoGradGuard no_grad;
                    filter_conv->bias.set_data(bias_tensor);
                    filter_conv->bias.set_requires_grad(bias_tensor.requires_grad());
                }
                std::cout << "Loaded filter initializer conv bias with shape: "
                          << bias_tensor.sizes() << " on device " << bias_tensor.device() << std::endl;
            } else {
                std::cout << "Skipping filter_initializer_filter_conv_bias.pt as bias is false for the layer." << std::endl;
            }
        } else {
            std::cout << "Filter initializer bias file not found: " << file_path_bias << std::endl;
        }
    } catch (const std::exception& e) {
        std::cerr << "Error loading filter initializer weights: " << e.what() << std::endl;
    }
}
// Filter Optimizer implementation
/// Intentionally loads nothing. The filter optimizer is not used for feature
/// extraction, so this only logs that its weights are skipped.
/// A streamed std::cout message cannot raise c10::Error, so no try/catch is needed.
void Classifier::FilterOptimizer::load_weights(const std::string& /* weights_dir */) {
    std::cout << "Skipping filter optimizer weights - not needed for feature extraction" << std::endl;
}
// Linear Filter implementation
/// Creates a linear filter of the given spatial size. Sub-modules are configured
/// separately by the owner.
Classifier::LinearFilter::LinearFilter(int size)
    : filter_size(size) {}
/// Intentionally loads nothing. The filter weights are not needed for feature
/// extraction, so this only logs that they are skipped.
/// A streamed std::cout message cannot raise c10::Error, so no try/catch is needed.
void Classifier::LinearFilter::load_weights(const std::string& /* weights_dir */) {
    std::cout << "Skipping filter weights - not needed for feature extraction" << std::endl;
}
/// Returns classification features for `feat` by delegating to the feature
/// extractor (convolution followed by instance L2 normalization).
torch::Tensor Classifier::LinearFilter::extract_classification_feat(torch::Tensor feat) {
    auto& extractor = feature_extractor;
    return extractor.extract_feat(feat);
}
// Classifier implementation
/// Builds the classifier sub-modules, loads their weights from `model_weights_dir`,
/// and moves everything to `dev`.
/// @param model_weights_dir directory holding the exported classifier weight files
/// @param dev               target device for all modules and tensors
/// @throws std::runtime_error if `model_weights_dir` does not exist
Classifier::Classifier(const std::string& model_weights_dir, torch::Device dev)
    : device(dev), model_dir(model_weights_dir), linear_filter(4) {
    // Fail fast: every weight file is looked up under model_dir.
    if (!fs::exists(model_dir)) {
        throw std::runtime_error("Model directory does not exist: " + model_dir);
    }

    auto& extractor = linear_filter.feature_extractor;
    auto& initializer = linear_filter.filter_initializer;
    auto& optimizer = linear_filter.filter_optimizer;

    // Feature extractor: 3x3 conv, 1024 -> 512 channels, padding 1, no bias,
    // followed by instance L2 normalization.
    extractor.conv0 = torch::nn::Conv2d(
        torch::nn::Conv2dOptions(1024, 512, 3).padding(1).bias(false));
    // 0.011048543456039804 ~= 1/sqrt(8192) = 1/sqrt(512 * 4 * 4); presumably
    // out_channels * filter_size^2 - confirm against the exporting model.
    extractor.norm = InstanceL2Norm(true, 1e-5, 0.011048543456039804);

    // Filter initializer: 3x3 conv, 512 -> 512 channels, padding 1 (bias left at default).
    initializer.filter_conv = torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 512, 3).padding(1));

    // Filter optimizer: three independent 1x1 convs, 100 -> 1 channel, no bias.
    const auto predictor_options = torch::nn::Conv2dOptions(100, 1, 1).bias(false);
    optimizer.label_map_predictor = torch::nn::Conv2d(predictor_options);
    optimizer.target_mask_predictor = torch::nn::Conv2d(predictor_options);
    optimizer.spatial_weight_predictor = torch::nn::Conv2d(predictor_options);

    // Weights must be loaded after the modules exist, then everything is placed on `device`.
    load_weights();
    to(device);
}
// Load weights for the feature extractor and filter
/// Loads weights for the sub-components needed for feature extraction: the feature
/// extractor and the filter initializer. Both read directly from model_dir.
/// The filter optimizer is deliberately skipped.
void Classifier::load_weights() {
    const auto& weights_dir = this->model_dir;
    std::cout << "Classifier::load_weights() called. Model directory: " << weights_dir << std::endl;

    auto& filter = linear_filter;
    filter.feature_extractor.load_weights(weights_dir, this->device);
    filter.filter_initializer.load_weights(weights_dir, this->device);

    // filter_optimizer.load_weights() is not called: those weights are not required
    // for feature extraction.
    std::cout << "Skipping main filter optimizer weights in Classifier::load_weights - not needed for feature extraction" << std::endl;
}
/// Moves every module and standalone tensor owned by the classifier to `device`
/// and records it as the classifier's device.
///
/// Module::to() only moves a module's *registered* parameters. If a loader rebound
/// a conv's `weight`/`bias` member to a different tensor, that tensor would stay
/// on the old device and forward() would then fail. Each conv's weight/bias member
/// is therefore checked after the module move.
void Classifier::to(torch::Device device) {
    this->device = device;

    // True when `t` already lives on `device`. An index-less target such as "cuda"
    // matches any index of the same type, so an already-moved tensor is not copied again.
    const auto is_on_target = [&device](const torch::Tensor& t) {
        const torch::Device current = t.device();
        return current.type() == device.type() &&
               (!device.has_index() || current.index() == device.index());
    };

    // Moves a conv module, then any weight/bias member left behind by Module::to().
    const auto move_conv = [&](torch::nn::Conv2d& conv) {
        if (!conv) {
            return;
        }
        conv->to(device);
        if (conv->weight.defined() && !is_on_target(conv->weight)) {
            conv->weight = conv->weight.to(device);
        }
        if (conv->bias.defined() && !is_on_target(conv->bias)) {
            conv->bias = conv->bias.to(device);
        }
    };

    // Moves a standalone tensor member if it has been set.
    const auto move_tensor = [&device](torch::Tensor& t) {
        if (t.defined()) {
            t = t.to(device);
        }
    };

    auto& extractor = linear_filter.feature_extractor;
    auto& initializer = linear_filter.filter_initializer;
    auto& optimizer = linear_filter.filter_optimizer;

    move_tensor(extractor.weight);
    move_conv(extractor.conv0);
    move_tensor(initializer.filter_conv_weight);
    move_conv(initializer.filter_conv);
    move_conv(optimizer.label_map_predictor);
    move_conv(optimizer.target_mask_predictor);
    move_conv(optimizer.spatial_weight_predictor);
    move_tensor(optimizer.filter_conv_weight);
}
/// Prints the classifier configuration to stdout: model directory, device and
/// filter size. For CUDA devices it also prints the device index, CUDA
/// availability and device count.
void Classifier::print_model_info() {
    std::cout << "Classifier Model Information:" << std::endl;
    std::cout << " - Model directory: " << model_dir << std::endl;
    std::cout << " - Device: " << (device.is_cuda() ? "CUDA" : "CPU") << std::endl;
    std::cout << " - Filter size: " << linear_filter.filter_size << std::endl;
    if (device.is_cuda()) {
        // c10::DeviceIndex is a narrow integer type (an 8-bit int in common LibTorch
        // releases), which operator<< emits as a raw character. Widen it so the index
        // prints as a number (-1 when no explicit index was given).
        std::cout << " - CUDA Device: " << static_cast<int>(device.index()) << std::endl;
        std::cout << " - CUDA Available: " << (torch::cuda::is_available() ? "Yes" : "No") << std::endl;
        if (torch::cuda::is_available()) {
            std::cout << " - CUDA Device Count: " << torch::cuda::device_count() << std::endl;
        }
    }
}
/// Runs the classification feature extractor on `input`, first moving `input`
/// to the classifier's device if it is elsewhere.
torch::Tensor Classifier::extract_features(torch::Tensor input) {
    // A single device comparison covers every case, including a CPU input with a
    // CUDA target, which always compares unequal.
    if (input.device() != device) {
        input = input.to(device);
    }
    return linear_filter.extract_classification_feat(input);
}
// Compute stats for a tensor
/// Computes summary statistics (shape, mean, std, min, max, sum) and three sample
/// values for `tensor`.
///
/// The samples come from the first entry along dim 0:
///   - its first element;
///   - its centre element (size/2 on every remaining dim);
///   - its last element (index -1 on every remaining dim).
/// For a 4-D tensor these are [0][0][0][0], [0][C/2][H/2][W/2] and [0][-1][-1][-1].
/// A 0-D tensor contributes its single value three times.
/// Integral and bool tensors are promoted to float64 for the statistics, because
/// mean()/std() reject those dtypes.
/// @throws std::invalid_argument if `tensor` is undefined or has no elements
Classifier::TensorStats Classifier::compute_stats(const torch::Tensor& tensor) {
    if (!tensor.defined() || tensor.numel() == 0) {
        throw std::invalid_argument("compute_stats: expected a defined, non-empty tensor");
    }

    TensorStats stats;
    for (int64_t d = 0; d < tensor.dim(); ++d) {
        stats.shape.push_back(tensor.size(d));
    }

    // Floating-point tensors are used as-is, so their statistics are unchanged.
    const torch::Tensor values = tensor.is_floating_point() ? tensor : tensor.to(torch::kFloat64);

    stats.mean = values.mean().item<float>();
    stats.std_dev = values.std().item<float>();
    stats.min_val = values.min().item<float>();
    stats.max_val = values.max().item<float>();
    stats.sum = values.sum().item<float>();

    // Walks values[0] down the remaining dims, picking index_for(extent) on each.
    const auto sample = [&values](auto index_for) -> float {
        if (values.dim() == 0) {
            return values.item<float>();
        }
        torch::Tensor element = values[0];
        for (int64_t d = 1; d < values.dim(); ++d) {
            element = element[index_for(values.size(d))];
        }
        return element.item<float>();
    };
    stats.samples.push_back(sample([](int64_t) { return int64_t{0}; }));
    stats.samples.push_back(sample([](int64_t extent) { return extent / 2; }));
    stats.samples.push_back(sample([](int64_t) { return int64_t{-1}; }));
    return stats;
}
// Save tensor stats to a file
void Classifier::save_stats(const std::vector<TensorStats>& all_stats, const std::string& filepath) {
std::ofstream file(filepath);
if (!file.is_open()) {
std::cerr << "Error opening file for writing: " << filepath << std::endl;
return;
}
for (size_t i = 0; i < all_stats.size(); i++) {
const auto& stats = all_stats[i];
file << "Output " << i << ":" << std::endl;
file << " Shape: [";
for (size_t j = 0; j < stats.shape.size(); j++) {
file << stats.shape[j];
if (j < stats.shape.size() - 1) file << ", ";
}
file << "]" << std::endl;
file << " Mean: " << stats.mean << std::endl;
file << " Std: " << stats.std_dev << std::endl;
file << " Min: " << stats.min_val << std::endl;
file << " Max: " << stats.max_val << std::endl;
file << " Sum: " << stats.sum << std::endl;
file << " Sample values: [";
for (size_t j = 0; j < stats.samples.size(); j++) {
file << stats.samples[j];
if (j < stats.samples.size() - 1) file << ", ";
}
file << "]" << std::endl << std::endl;
}
file.close();
}
// Load a single serialized tensor from disk and place it on the requested device
/// Deserializes a single tensor from `file_path` and places it on `device`.
/// @param file_path file whose bytes are decoded with torch::pickle_load
/// @param device    device the returned tensor should live on
/// @return the loaded tensor on `device`
/// @throws rethrows any std::exception raised while reading or decoding, after
///         logging it to stderr
torch::Tensor Classifier::load_tensor(const std::string& file_path, torch::Device device) {
    try {
        const std::vector<char> raw_bytes = read_file_to_bytes(file_path);
        torch::Tensor loaded = torch::pickle_load(raw_bytes).toTensor();
        // Only transfer when the tensor is not already on the requested device.
        return loaded.device() == device ? loaded : loaded.to(device);
    } catch (const std::exception& e) {
        std::cerr << "Error loading tensor from " << file_path << ": " << e.what() << std::endl;
        throw;
    }
}