You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

113 lines
3.4 KiB

#pragma once
#include <torch/torch.h>
#include <string>
#include <vector>
#include <filesystem>
namespace fs = std::filesystem;
// InstanceL2Norm class to match Python's implementation
/// @brief Instance-wise L2 normalization layer, declared to mirror the Python
///        implementation of the same name.
///
/// Only the interface is declared here; the normalization math lives in the
/// out-of-line definition of forward().
class InstanceL2Norm {
public:
    /// @brief Construct with the default configuration
    ///        (size_average = true, eps = 1e-5, scale = 1.0).
    ///
    /// Deliberately non-explicit so that default-initialization and `{}`
    /// initialization (e.g. of an aggregate that contains this type) keep
    /// working exactly as with the previous all-defaulted constructor.
    InstanceL2Norm() : InstanceL2Norm(true, 1e-5f, 1.0f) {}

    /// @brief Construct with explicit settings.
    ///
    /// `explicit` because the previous all-defaulted signature doubled as an
    /// implicit converting constructor from `bool` -- and therefore from any
    /// type convertible to bool (int, double, pointers) -- which silently
    /// produced a normalization layer from unrelated values.
    /// @param size_average Size-averaging flag; its exact effect is defined in forward().
    /// @param eps Small constant; presumably guards the normalization against
    ///            division by zero -- TODO confirm in forward().
    /// @param scale Scale factor; how it is applied is defined in forward().
    explicit InstanceL2Norm(bool size_average, float eps = 1e-5f, float scale = 1.0f);

    /// @brief Apply the normalization to @p input and return the result.
    /// @note Expected shape/dtype of @p input are not visible here -- TODO confirm.
    torch::Tensor forward(torch::Tensor input);

private:
    // In-class defaults match the constructor defaults, so these members are
    // never left indeterminate even if a constructor does not set them.
    bool size_average_{true};
    float eps_{1e-5f};
    float scale_{1.0f};
};
// Main classifier class that manages feature extraction
/// @brief Owns the classifier's sub-components (feature extractor, filter
///        initializer, filter optimizer) and exposes feature extraction,
///        weight loading and tensor-statistics helpers.
///
/// All member functions are defined out of line; this header only declares them.
class Classifier {
public:
    /// @brief Construct a classifier for the model rooted at @p base_dir.
    ///
    /// `explicit` so a lone std::string cannot silently convert into a whole
    /// Classifier (e.g. when passed to a function taking one by value).
    /// @param base_dir Base directory; presumably where weights are read from -- TODO confirm.
    /// @param device   Target device; defaults to CUDA.
    explicit Classifier(const std::string& base_dir, torch::Device device = torch::kCUDA);

    /// @brief Load all necessary weights.
    void load_weights();

    /// @brief Extract features from @p input.
    /// @param input Input tensor; expected shape/dtype are not visible here -- TODO confirm.
    /// @return The extracted feature tensor.
    torch::Tensor extract_features(torch::Tensor input);

    /// @brief Move the model to @p device.
    void to(torch::Device device);

    /// @brief Print model information.
    void print_model_info();

    /// @brief Read the file at @p file_path into a byte buffer.
    /// @return The file contents as raw bytes.
    static std::vector<char> read_file_to_bytes(const std::string& file_path);

    /// @brief Load a tensor from the file at @p file_path.
    /// @param device Presumably the device the loaded tensor is placed on -- TODO confirm.
    static torch::Tensor load_tensor(const std::string& file_path, torch::Device device);

    /// @brief Summary statistics describing a single tensor.
    ///
    /// Scalar fields default to 0 so a default-constructed (or partially
    /// aggregate-initialized) instance never carries indeterminate values.
    /// The type remains an aggregate, so brace-initialization still works.
    struct TensorStats {
        std::vector<int64_t> shape;  ///< Tensor dimensions.
        float mean{0.0f};
        float std_dev{0.0f};
        float min_val{0.0f};
        float max_val{0.0f};
        float sum{0.0f};
        std::vector<float> samples;  ///< Presumably a subset of element values -- TODO confirm.
    };

    /// @brief Compute summary statistics for @p tensor.
    TensorStats compute_stats(const torch::Tensor& tensor);

    /// @brief Save every entry of @p all_stats to the file at @p filepath.
    void save_stats(const std::vector<TensorStats>& all_stats, const std::string& filepath);

private:
    /// Feature-extractor component: a conv layer, a raw weight tensor and an
    /// InstanceL2Norm.
    struct FeatureExtractor {
        // Constructed from nullptr (empty holder); presumably populated by load_weights().
        torch::nn::Conv2d conv0{nullptr};
        torch::Tensor weight;
        InstanceL2Norm norm;

        torch::Tensor forward(torch::Tensor x);
        torch::Tensor extract_feat(torch::Tensor x);
        void load_weights(const std::string& weights_dir, torch::Device device);
    };

    /// Filter-initializer component.
    struct FilterInitializer {
        // Constructed from nullptr (empty holder); presumably populated by load_weights().
        torch::nn::Conv2d filter_conv{nullptr};
        torch::Tensor filter_conv_weight;

        torch::Tensor forward(torch::Tensor x);
        void load_weights(const std::string& weights_dir, torch::Device device);
    };

    /// Filter-optimizer component.
    struct FilterOptimizer {
        torch::nn::Conv2d label_map_predictor{nullptr};
        torch::nn::Conv2d target_mask_predictor{nullptr};
        torch::nn::Conv2d spatial_weight_predictor{nullptr};
        torch::Tensor filter_conv_weight;

        // NOTE(review): unlike FeatureExtractor/FilterInitializer this takes no
        // device argument -- confirm where these weights reach the target device.
        void load_weights(const std::string& weights_dir);
    };

    /// Linear filter that combines the three components above.
    struct LinearFilter {
        int filter_size;
        FeatureExtractor feature_extractor;
        FilterInitializer filter_initializer;
        FilterOptimizer filter_optimizer;
        torch::Tensor filter;

        LinearFilter(int filter_size = 4);
        void load_weights(const std::string& weights_dir);
        torch::Tensor extract_classification_feat(torch::Tensor feat);
    };

    // Main components. Non-static members are constructed in declaration
    // order, so model_dir and device already exist when linear_filter is
    // constructed; do not reorder without checking the Classifier constructor.
    std::string model_dir;
    torch::Device device;
    LinearFilter linear_filter;
};