You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
82 lines
2.9 KiB
82 lines
2.9 KiB
#ifndef CPP_TRACKER_RESNET_H
|
|
#define CPP_TRACKER_RESNET_H
|
|
|
|
#include <torch/torch.h>
|
|
#include <vector>
|
|
#include <string>
|
|
#include <map>
|
|
#include <optional>
|
|
|
|
namespace cimp {
|
|
namespace resnet {
|
|
|
|
// ResNet-50 Bottleneck block IMPLementation
|
|
struct BottleneckImpl : torch::nn::Module {
|
|
// Constructor declaration
|
|
BottleneckImpl(const std::string& base_weights_dir,
|
|
const std::string& block_param_prefix,
|
|
int64_t inplanes, int64_t planes,
|
|
const torch::Device& device,
|
|
int64_t stride = 1,
|
|
std::optional<torch::nn::Sequential> downsample_module_opt = std::nullopt,
|
|
int64_t expansion_factor_arg = 4);
|
|
// Forward method declaration
|
|
torch::Tensor forward(torch::Tensor x);
|
|
|
|
// Member layers (must be declared in the Impl struct)
|
|
torch::nn::Conv2d conv1{nullptr}, conv2{nullptr}, conv3{nullptr};
|
|
torch::nn::BatchNorm2d bn1{nullptr}, bn2{nullptr}, bn3{nullptr};
|
|
torch::nn::ReLU relu{nullptr};
|
|
torch::nn::Sequential projection_shortcut{nullptr};
|
|
int64_t expansion_factor; // Store the expansion factor
|
|
int64_t stride_member; // To avoid conflict with constructor param name
|
|
};
|
|
|
|
// This macro defines the 'Bottleneck' type based on 'BottleneckImpl'
|
|
// It effectively creates: using Bottleneck = torch::nn::ModuleHolder<BottleneckImpl>;
|
|
TORCH_MODULE(Bottleneck);
|
|
|
|
struct ResNetImpl : torch::nn::Module {
|
|
ResNetImpl(const std::string& base_weights_dir,
|
|
const std::vector<int64_t>& layers,
|
|
const std::vector<std::string>& output_layers,
|
|
const torch::Device& device);
|
|
|
|
std::map<std::string, torch::Tensor> forward(torch::Tensor x);
|
|
|
|
// Initial layers
|
|
torch::nn::Conv2d conv1{nullptr};
|
|
torch::nn::BatchNorm2d bn1{nullptr};
|
|
torch::nn::ReLU relu{nullptr};
|
|
torch::nn::MaxPool2d maxpool{nullptr};
|
|
|
|
// ResNet layers
|
|
torch::nn::Sequential layer1{nullptr};
|
|
torch::nn::Sequential layer2{nullptr};
|
|
torch::nn::Sequential layer3{nullptr};
|
|
torch::nn::Sequential layer4{nullptr}; // We'll build it, even if not always outputting
|
|
|
|
private:
|
|
torch::nn::Sequential _make_layer(int64_t planes, int64_t blocks,
|
|
const std::string& layer_param_prefix,
|
|
const torch::Device& device,
|
|
int64_t stride = 1);
|
|
int64_t inplanes = 64;
|
|
std::vector<std::string> _output_layers;
|
|
std::string _base_weights_dir; // Store base weights directory, e.g. ../exported_weights/raw_backbone
|
|
|
|
static const int expansion = 4; // Bottleneck expansion factor for ResNet layers
|
|
};
|
|
|
|
TORCH_MODULE(ResNet);
|
|
|
|
// Factory function for ResNet-50
|
|
ResNet resnet50(const std::string& base_weights_dir,
|
|
const std::vector<std::string>& output_layers,
|
|
const torch::Device& device);
|
|
|
|
|
|
} // namespace resnet
|
|
} // namespace cimp
|
|
|
|
#endif //CPP_TRACKER_RESNET_H
|