You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

82 lines
2.9 KiB

#ifndef CPP_TRACKER_RESNET_H
#define CPP_TRACKER_RESNET_H
#include <torch/torch.h>
#include <vector>
#include <string>
#include <map>
#include <optional>
namespace cimp {
namespace resnet {
// ResNet-50 bottleneck block implementation.
//
// Holds three convolution + batch-norm stages (conv1/bn1, conv2/bn2,
// conv3/bn3), a ReLU, and an optional projection shortcut. The public
// `Bottleneck` module-holder type is generated from this struct by
// TORCH_MODULE(Bottleneck).
struct BottleneckImpl : torch::nn::Module {
    /// Builds the block.
    ///
    /// @param base_weights_dir      Root directory of the exported weights.
    /// @param block_param_prefix    Prefix naming this block's parameters inside
    ///                              the exported weights (NOTE(review): exact
    ///                              format, e.g. "layer1.0.", to be confirmed
    ///                              against the .cpp).
    /// @param inplanes              Number of input channels.
    /// @param planes                Bottleneck width; output channels are
    ///                              presumably planes * expansion_factor_arg.
    /// @param device                Device the block's tensors are placed on.
    /// @param stride                Convolution stride for this block (default 1).
    /// @param downsample_module_opt Optional shortcut projection; when empty the
    ///                              block is expected to use an identity shortcut.
    /// @param expansion_factor_arg  Channel expansion factor (default 4).
    BottleneckImpl(const std::string& base_weights_dir,
                   const std::string& block_param_prefix,
                   int64_t inplanes, int64_t planes,
                   const torch::Device& device,
                   int64_t stride = 1,
                   std::optional<torch::nn::Sequential> downsample_module_opt = std::nullopt,
                   int64_t expansion_factor_arg = 4);

    /// Runs the block on `x` and returns the resulting tensor.
    torch::Tensor forward(torch::Tensor x);

    // Member layers (must be declared in the Impl struct so they can be
    // registered as submodules). Null until the constructor creates them.
    torch::nn::Conv2d conv1{nullptr};
    torch::nn::Conv2d conv2{nullptr};
    torch::nn::Conv2d conv3{nullptr};
    torch::nn::BatchNorm2d bn1{nullptr};
    torch::nn::BatchNorm2d bn2{nullptr};
    torch::nn::BatchNorm2d bn3{nullptr};
    torch::nn::ReLU relu{nullptr};
    torch::nn::Sequential projection_shortcut{nullptr};

    // Brace-initialized to the constructor's defaults so the fields are never
    // read uninitialized; the constructor is expected to overwrite them.
    int64_t expansion_factor{4};  // Channel expansion factor of this block.
    int64_t stride_member{1};     // Named to avoid clashing with the `stride` ctor parameter.
};
// This macro defines the 'Bottleneck' module-holder type based on 'BottleneckImpl',
// so callers work with `Bottleneck` handles rather than the Impl struct directly.
// It behaves essentially like torch::nn::ModuleHolder<BottleneckImpl>.
// NOTE(review): libtorch's TORCH_MODULE emits a small class deriving from
// ModuleHolder<BottleneckImpl> rather than a plain `using` alias — confirm
// against the libtorch version in use.
TORCH_MODULE(Bottleneck);
// ResNet backbone: a stem (conv1 -> bn1 -> relu -> maxpool) followed by four
// residual stages (layer1..layer4) built from Bottleneck blocks.
// forward() returns intermediate feature maps keyed by layer name.
struct ResNetImpl : torch::nn::Module {
    /// Builds the backbone.
    ///
    /// @param base_weights_dir Root directory of the exported weights,
    ///                         e.g. ../exported_weights/raw_backbone.
    /// @param layers           Number of Bottleneck blocks per stage
    ///                         (presumably one entry per layer1..layer4 — TODO confirm).
    /// @param output_layers    Names of the layers whose outputs forward() reports.
    /// @param device           Device the module's tensors are placed on.
    ResNetImpl(const std::string& base_weights_dir,
               const std::vector<int64_t>& layers,
               const std::vector<std::string>& output_layers,
               const torch::Device& device);

    /// Runs the backbone on `x` and returns feature maps keyed by layer name
    /// (NOTE(review): presumably only the names listed in `output_layers`).
    std::map<std::string, torch::Tensor> forward(torch::Tensor x);

    // Initial (stem) layers.
    torch::nn::Conv2d conv1{nullptr};
    torch::nn::BatchNorm2d bn1{nullptr};
    torch::nn::ReLU relu{nullptr};
    torch::nn::MaxPool2d maxpool{nullptr};

    // Residual stages.
    torch::nn::Sequential layer1{nullptr};
    torch::nn::Sequential layer2{nullptr};
    torch::nn::Sequential layer3{nullptr};
    torch::nn::Sequential layer4{nullptr}; // Always built, even if not always requested as an output.

private:
    /// Builds one residual stage of `blocks` Bottleneck blocks with width
    /// `planes`, reading parameters under `layer_param_prefix`.
    /// @param stride Stride of the stage (default 1).
    torch::nn::Sequential _make_layer(int64_t planes, int64_t blocks,
                                      const std::string& layer_param_prefix,
                                      const torch::Device& device,
                                      int64_t stride = 1);

    // Input-channel count fed to the next block; starts at 64 (presumably the
    // stem's output width) and is expected to be advanced by _make_layer.
    int64_t inplanes = 64;
    std::vector<std::string> _output_layers;
    std::string _base_weights_dir; // Store base weights directory, e.g. ../exported_weights/raw_backbone

    // Bottleneck expansion factor for ResNet layers. `constexpr` makes this
    // static member implicitly inline (C++17), so ODR-uses such as binding it
    // to a const reference no longer require an out-of-line definition.
    static constexpr int expansion = 4;
};
// Defines the 'ResNet' module-holder type wrapping 'ResNetImpl', following the
// same pattern used for 'Bottleneck'. Callers (e.g. resnet50()) hand out
// `ResNet` handles rather than the Impl struct.
TORCH_MODULE(ResNet);
// Factory function for ResNet-50.
//
// @param base_weights_dir Root directory of the exported backbone weights.
// @param output_layers    Names of the layers whose features forward() reports.
// @param device           Device on which the module is created.
// @return A ResNet module holder for a ResNet-50 backbone.
//
// [[nodiscard]]: building the backbone is the only purpose of this call, so
// discarding the result is always a caller bug.
[[nodiscard]] ResNet resnet50(const std::string& base_weights_dir,
                              const std::vector<std::string>& output_layers,
                              const torch::Device& device);
} // namespace resnet
} // namespace cimp
#endif //CPP_TRACKER_RESNET_H