You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
146 lines
4.3 KiB
146 lines
4.3 KiB
#pragma once

#include <cstdint>
#include <filesystem>
#include <memory>
#include <string>
#include <vector>

#include <torch/torch.h>

// Shorthand for std::filesystem.
// NOTE(review): a namespace alias at header scope is visible in every file that
// includes this header. It is kept because out-of-line definitions may rely on `fs::`.
namespace fs = std::filesystem;

// Linear block for IoU prediction
|
|
class LinearBlock : public torch::nn::Module {
|
|
public:
|
|
LinearBlock(int in_planes = 256, int out_planes = 256, int input_sz = 5, bool bias = true,
|
|
bool batch_norm = true, bool relu = true);
|
|
|
|
torch::Tensor forward(torch::Tensor x);
|
|
|
|
// Set to evaluation mode
|
|
void eval() {
|
|
linear->eval();
|
|
if (use_bn) {
|
|
bn->eval();
|
|
}
|
|
if (use_relu) {
|
|
relu_->eval();
|
|
}
|
|
}
|
|
|
|
// Move to device
|
|
void to(torch::Device device) {
|
|
linear->to(device);
|
|
if (use_bn) bn->to(device);
|
|
if (use_relu) relu_->to(device);
|
|
}
|
|
|
|
// Public members for direct access to weights
|
|
torch::nn::Linear linear{nullptr};
|
|
torch::nn::BatchNorm2d bn{nullptr};
|
|
torch::nn::ReLU relu_{nullptr};
|
|
bool use_bn;
|
|
bool use_relu;
|
|
};
|
|
|
|
// PrRoIPool2D implementation
|
|
class PrRoIPool2D {
|
|
public:
|
|
PrRoIPool2D(int pooled_height, int pooled_width, float spatial_scale);
|
|
torch::Tensor forward(torch::Tensor feat, torch::Tensor rois);
|
|
|
|
// CPU-based fallback implementation
|
|
torch::Tensor forward_cpu(torch::Tensor feat, torch::Tensor rois) {
|
|
// Simple implementation that returns zeros (for fallback only)
|
|
int channels = feat.size(1);
|
|
int num_rois = rois.size(0);
|
|
return torch::zeros({num_rois, channels, pooled_height_, pooled_width_}, feat.options());
|
|
}
|
|
|
|
private:
|
|
int pooled_height_;
|
|
int pooled_width_;
|
|
float spatial_scale_;
|
|
};
|
|
|
|
// BBRegressor class
|
|
class BBRegressor {
|
|
public:
|
|
// Statistics structure for tensors
|
|
struct TensorStats {
|
|
std::vector<int64_t> shape;
|
|
float mean;
|
|
float std_dev;
|
|
float min_val;
|
|
float max_val;
|
|
float sum;
|
|
std::vector<float> samples;
|
|
};
|
|
|
|
// Constructor with base directory and device specification
|
|
BBRegressor(const std::string& base_dir, torch::Device device = torch::kCUDA);
|
|
|
|
// Set model to evaluation mode
|
|
void eval();
|
|
|
|
// Get IoU features from backbone features
|
|
std::vector<torch::Tensor> get_iou_feat(std::vector<torch::Tensor> feat);
|
|
|
|
// Get modulation vectors for target
|
|
std::vector<torch::Tensor> get_modulation(std::vector<torch::Tensor> feat, torch::Tensor bb);
|
|
|
|
// Predict IoU for proposals
|
|
torch::Tensor predict_iou(std::vector<torch::Tensor> modulation,
|
|
std::vector<torch::Tensor> feat,
|
|
torch::Tensor proposals);
|
|
|
|
// Move model to device
|
|
void to(torch::Device device);
|
|
|
|
// Print model information
|
|
void print_model_info();
|
|
|
|
// Compute statistics for a tensor
|
|
TensorStats compute_stats(const torch::Tensor& tensor);
|
|
|
|
// Save tensor statistics to a file
|
|
void save_stats(const std::vector<TensorStats>& all_stats, const std::string& filepath);
|
|
|
|
private:
|
|
// Helper functions
|
|
torch::nn::Sequential create_conv_block(int in_planes, int out_planes, int kernel_size,
|
|
int stride, int padding, int dilation);
|
|
void verify_batchnorm_dimensions();
|
|
std::vector<char> read_file_to_bytes(const std::string& file_path);
|
|
torch::Tensor load_tensor(const std::string& file_path);
|
|
void load_weights();
|
|
|
|
// Model state
|
|
torch::Device device;
|
|
std::string model_dir;
|
|
|
|
// Convolution blocks
|
|
torch::nn::Sequential conv3_1r{nullptr};
|
|
torch::nn::Sequential conv3_1t{nullptr};
|
|
torch::nn::Sequential conv3_2t{nullptr};
|
|
torch::nn::Sequential fc3_1r{nullptr};
|
|
torch::nn::Sequential conv4_1r{nullptr};
|
|
torch::nn::Sequential conv4_1t{nullptr};
|
|
torch::nn::Sequential conv4_2t{nullptr};
|
|
torch::nn::Sequential fc34_3r{nullptr};
|
|
torch::nn::Sequential fc34_4r{nullptr};
|
|
|
|
// Pooling layers
|
|
std::shared_ptr<PrRoIPool2D> prroi_pool3r;
|
|
std::shared_ptr<PrRoIPool2D> prroi_pool3t;
|
|
std::shared_ptr<PrRoIPool2D> prroi_pool4r;
|
|
std::shared_ptr<PrRoIPool2D> prroi_pool4t;
|
|
|
|
// Linear blocks
|
|
LinearBlock fc3_rt;
|
|
LinearBlock fc4_rt;
|
|
|
|
// IoU predictor
|
|
torch::nn::Linear iou_predictor{nullptr};
|
|
};
|