#pragma once #include #include #include #include namespace fs = std::filesystem; // InstanceL2Norm class to match Python's implementation class InstanceL2Norm { public: InstanceL2Norm(bool size_average = true, float eps = 1e-5, float scale = 1.0); // Forward function for normalization torch::Tensor forward(torch::Tensor input); private: bool size_average_; float eps_; float scale_; }; // Main classifier class that manages feature extraction class Classifier { public: // Constructor with base directory and device specification Classifier(const std::string& base_dir, torch::Device device = torch::kCUDA); // Load all necessary weights void load_weights(); // Extract features from an input tensor torch::Tensor extract_features(torch::Tensor input); // Move model to specified device void to(torch::Device device); // Print model information void print_model_info(); // Helper function to read file to bytes static std::vector read_file_to_bytes(const std::string& file_path); // Helper function to load a tensor from a file torch::Tensor load_tensor(const std::string& file_path); // Statistics structure for tensors struct TensorStats { std::vector shape; float mean; float std_dev; float min_val; float max_val; float sum; std::vector samples; }; // Compute statistics for a tensor TensorStats compute_stats(const torch::Tensor& tensor); // Save tensor statistics to a file void save_stats(const std::vector& all_stats, const std::string& filepath); private: // Feature extractor component struct FeatureExtractor { torch::nn::Conv2d conv0{nullptr}; torch::Tensor weight; InstanceL2Norm norm; torch::Tensor forward(torch::Tensor x); torch::Tensor extract_feat(torch::Tensor x); void load_weights(const std::string& weights_dir); }; // Filter initializer component struct FilterInitializer { torch::nn::Conv2d filter_conv{nullptr}; torch::Tensor filter_conv_weight; torch::Tensor forward(torch::Tensor x); void load_weights(const std::string& weights_dir); }; // Filter optimizer component struct FilterOptimizer { torch::nn::Conv2d label_map_predictor{nullptr}; torch::nn::Conv2d target_mask_predictor{nullptr}; torch::nn::Conv2d spatial_weight_predictor{nullptr}; torch::Tensor filter_conv_weight; void load_weights(const std::string& weights_dir); }; // Linear filter component that combines the above components struct LinearFilter { int filter_size; FeatureExtractor feature_extractor; FilterInitializer filter_initializer; FilterOptimizer filter_optimizer; torch::Tensor filter; LinearFilter(int filter_size = 4); void load_weights(const std::string& weights_dir); torch::Tensor extract_classification_feat(torch::Tensor feat); }; // Main components - order matters for initialization std::string model_dir; torch::Device device; LinearFilter linear_filter; };