You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

146 lines
4.3 KiB

#pragma once
#include <torch/torch.h>
#include <string>
#include <vector>
#include <filesystem>
namespace fs = std::filesystem;
// Forward declaration of PrRoIPool2D
class PrRoIPool2D;
// Linear block for IoU prediction
class LinearBlock : public torch::nn::Module {
public:
LinearBlock(int in_planes = 256, int out_planes = 256, int input_sz = 5, bool bias = true,
bool batch_norm = true, bool relu = true);
torch::Tensor forward(torch::Tensor x);
// Set to evaluation mode
void eval() {
linear->eval();
if (use_bn) {
bn->eval();
}
if (use_relu) {
relu_->eval();
}
}
// Move to device
void to(torch::Device device) {
linear->to(device);
if (use_bn) bn->to(device);
if (use_relu) relu_->to(device);
}
// Public members for direct access to weights
torch::nn::Linear linear{nullptr};
torch::nn::BatchNorm2d bn{nullptr};
torch::nn::ReLU relu_{nullptr};
bool use_bn;
bool use_relu;
};
// PrRoIPool2D implementation
// PrRoIPool2D: RoI pooling front-end with an output grid of
// pooled_height x pooled_width and a spatial_scale factor. The constructor and
// forward() are defined out of line.
class PrRoIPool2D {
public:
    PrRoIPool2D(int pooled_height, int pooled_width, float spatial_scale);
    torch::Tensor forward(torch::Tensor feat, torch::Tensor rois);

    /// Placeholder CPU fallback. It does NOT perform pooling.
    /// Returns an all-zero tensor of shape
    /// (rois.size(0), feat.size(1), pooled_height, pooled_width) with feat's
    /// dtype/device options.
    /// @param feat Feature tensor; only size(1) is read (presumably the channel
    ///             axis of an (N, C, H, W) map — TODO confirm against callers).
    /// @param rois RoI tensor; only size(0) (the number of RoIs) is read.
    torch::Tensor forward_cpu(const torch::Tensor& feat, const torch::Tensor& rois) const {
        // Keep sizes as int64_t (the type Tensor::size returns) to avoid narrowing.
        const int64_t channels = feat.size(1);
        const int64_t num_rois = rois.size(0);
        return torch::zeros({num_rois, channels, pooled_height_, pooled_width_}, feat.options());
    }

private:
    int pooled_height_{0};
    int pooled_width_{0};
    float spatial_scale_{0.0f};
};
// BBRegressor class
// BBRegressor: bounding-box regressor that predicts IoU scores for box proposals.
// Presumably weights are read from files under the directory given to the
// constructor (see load_weights / load_tensor / read_file_to_bytes). All methods are
// defined out of line.
class BBRegressor {
public:
    // Summary statistics for a tensor (filled by compute_stats, written by save_stats).
    // Scalars default to 0 so a default-constructed value never holds indeterminate
    // data; the struct remains an aggregate, so brace-initialization still works.
    struct TensorStats {
        std::vector<int64_t> shape;
        float mean{0.0f};
        float std_dev{0.0f};
        float min_val{0.0f};
        float max_val{0.0f};
        float sum{0.0f};
        std::vector<float> samples;
    };

    // Constructor with base directory and device specification.
    // The device defaults to torch::kCUDA; pass torch::kCPU explicitly on CPU-only hosts.
    BBRegressor(const std::string& base_dir, torch::Device device = torch::kCUDA);

    // Set model to evaluation mode
    void eval();

    // Get IoU features from backbone features
    std::vector<torch::Tensor> get_iou_feat(std::vector<torch::Tensor> feat);

    // Get modulation vectors for the target box `bb`
    std::vector<torch::Tensor> get_modulation(std::vector<torch::Tensor> feat, torch::Tensor bb);

    // Predict IoU for proposals, given modulation vectors and IoU features
    torch::Tensor predict_iou(std::vector<torch::Tensor> modulation,
                              std::vector<torch::Tensor> feat,
                              torch::Tensor proposals);

    // Move model to device
    void to(torch::Device device);

    // Print model information
    void print_model_info();

    // Compute statistics for a tensor
    TensorStats compute_stats(const torch::Tensor& tensor);

    // Save tensor statistics to a file
    void save_stats(const std::vector<TensorStats>& all_stats, const std::string& filepath);

private:
    // Helper functions
    torch::nn::Sequential create_conv_block(int in_planes, int out_planes, int kernel_size,
                                            int stride, int padding, int dilation);
    void verify_batchnorm_dimensions();
    std::vector<char> read_file_to_bytes(const std::string& file_path);
    torch::Tensor load_tensor(const std::string& file_path);
    void load_weights();

    // Model state. torch::Device has no default constructor, so the constructor must
    // initialize `device` in its member-initializer list.
    // NOTE: member declaration order is initialization order — do not reorder.
    torch::Device device;
    std::string model_dir;

    // Convolution blocks
    torch::nn::Sequential conv3_1r{nullptr};
    torch::nn::Sequential conv3_1t{nullptr};
    torch::nn::Sequential conv3_2t{nullptr};
    torch::nn::Sequential fc3_1r{nullptr};
    torch::nn::Sequential conv4_1r{nullptr};
    torch::nn::Sequential conv4_1t{nullptr};
    torch::nn::Sequential conv4_2t{nullptr};
    torch::nn::Sequential fc34_3r{nullptr};
    torch::nn::Sequential fc34_4r{nullptr};

    // Pooling layers (null until assigned by the constructor / weight loading)
    std::shared_ptr<PrRoIPool2D> prroi_pool3r;
    std::shared_ptr<PrRoIPool2D> prroi_pool3t;
    std::shared_ptr<PrRoIPool2D> prroi_pool4r;
    std::shared_ptr<PrRoIPool2D> prroi_pool4t;

    // Linear blocks
    LinearBlock fc3_rt;
    LinearBlock fc4_rt;

    // IoU predictor
    torch::nn::Linear iou_predictor{nullptr};
};