You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
113 lines
3.4 KiB
113 lines
3.4 KiB
#pragma once
|
|
|
|
#include <torch/torch.h>
|
|
#include <string>
|
|
#include <vector>
|
|
#include <filesystem>
|
|
|
|
namespace fs = std::filesystem;
|
|
|
|
// InstanceL2Norm class to match Python's implementation
|
|
class InstanceL2Norm {
|
|
public:
|
|
InstanceL2Norm(bool size_average = true, float eps = 1e-5, float scale = 1.0);
|
|
|
|
// Forward function for normalization
|
|
torch::Tensor forward(torch::Tensor input);
|
|
|
|
private:
|
|
bool size_average_;
|
|
float eps_;
|
|
float scale_;
|
|
};
|
|
|
|
// Main classifier class that manages feature extraction
|
|
class Classifier {
|
|
public:
|
|
// Constructor with base directory and device specification
|
|
Classifier(const std::string& base_dir, torch::Device device = torch::kCUDA);
|
|
|
|
// Load all necessary weights
|
|
void load_weights();
|
|
|
|
// Extract features from an input tensor
|
|
torch::Tensor extract_features(torch::Tensor input);
|
|
|
|
// Move model to specified device
|
|
void to(torch::Device device);
|
|
|
|
// Print model information
|
|
void print_model_info();
|
|
|
|
// Helper function to read file to bytes
|
|
static std::vector<char> read_file_to_bytes(const std::string& file_path);
|
|
|
|
// Helper function to load a tensor from a file
|
|
static torch::Tensor load_tensor(const std::string& file_path, torch::Device device);
|
|
|
|
// Statistics structure for tensors
|
|
struct TensorStats {
|
|
std::vector<int64_t> shape;
|
|
float mean;
|
|
float std_dev;
|
|
float min_val;
|
|
float max_val;
|
|
float sum;
|
|
std::vector<float> samples;
|
|
};
|
|
|
|
// Compute statistics for a tensor
|
|
TensorStats compute_stats(const torch::Tensor& tensor);
|
|
|
|
// Save tensor statistics to a file
|
|
void save_stats(const std::vector<TensorStats>& all_stats, const std::string& filepath);
|
|
|
|
private:
|
|
// Feature extractor component
|
|
struct FeatureExtractor {
|
|
torch::nn::Conv2d conv0{nullptr};
|
|
torch::Tensor weight;
|
|
InstanceL2Norm norm;
|
|
|
|
torch::Tensor forward(torch::Tensor x);
|
|
torch::Tensor extract_feat(torch::Tensor x);
|
|
void load_weights(const std::string& weights_dir, torch::Device device);
|
|
};
|
|
|
|
// Filter initializer component
|
|
struct FilterInitializer {
|
|
torch::nn::Conv2d filter_conv{nullptr};
|
|
torch::Tensor filter_conv_weight;
|
|
|
|
torch::Tensor forward(torch::Tensor x);
|
|
void load_weights(const std::string& weights_dir, torch::Device device);
|
|
};
|
|
|
|
// Filter optimizer component
|
|
struct FilterOptimizer {
|
|
torch::nn::Conv2d label_map_predictor{nullptr};
|
|
torch::nn::Conv2d target_mask_predictor{nullptr};
|
|
torch::nn::Conv2d spatial_weight_predictor{nullptr};
|
|
torch::Tensor filter_conv_weight;
|
|
|
|
void load_weights(const std::string& weights_dir);
|
|
};
|
|
|
|
// Linear filter component that combines the above components
|
|
struct LinearFilter {
|
|
int filter_size;
|
|
FeatureExtractor feature_extractor;
|
|
FilterInitializer filter_initializer;
|
|
FilterOptimizer filter_optimizer;
|
|
torch::Tensor filter;
|
|
|
|
LinearFilter(int filter_size = 4);
|
|
void load_weights(const std::string& weights_dir);
|
|
torch::Tensor extract_classification_feat(torch::Tensor feat);
|
|
};
|
|
|
|
// Main components - order matters for initialization
|
|
std::string model_dir;
|
|
torch::Device device;
|
|
LinearFilter linear_filter;
|
|
};
|