#ifndef CPP_TRACKER_RESNET_H #define CPP_TRACKER_RESNET_H #include #include #include #include #include namespace cimp { namespace resnet { // ResNet-50 Bottleneck block IMPLementation struct BottleneckImpl : torch::nn::Module { // Constructor declaration BottleneckImpl(const std::string& base_weights_dir, const std::string& block_param_prefix, int64_t inplanes, int64_t planes, const torch::Device& device, int64_t stride = 1, std::optional downsample_module_opt = std::nullopt, int64_t expansion_factor_arg = 4); // Forward method declaration torch::Tensor forward(torch::Tensor x); // Member layers (must be declared in the Impl struct) torch::nn::Conv2d conv1{nullptr}, conv2{nullptr}, conv3{nullptr}; torch::nn::BatchNorm2d bn1{nullptr}, bn2{nullptr}, bn3{nullptr}; torch::nn::ReLU relu{nullptr}; torch::nn::Sequential projection_shortcut{nullptr}; int64_t expansion_factor; // Store the expansion factor int64_t stride_member; // To avoid conflict with constructor param name }; // This macro defines the 'Bottleneck' type based on 'BottleneckImpl' // It effectively creates: using Bottleneck = torch::nn::ModuleHolder; TORCH_MODULE(Bottleneck); struct ResNetImpl : torch::nn::Module { ResNetImpl(const std::string& base_weights_dir, const std::vector& layers, const std::vector& output_layers, const torch::Device& device); std::map forward(torch::Tensor x); // Initial layers torch::nn::Conv2d conv1{nullptr}; torch::nn::BatchNorm2d bn1{nullptr}; torch::nn::ReLU relu{nullptr}; torch::nn::MaxPool2d maxpool{nullptr}; // ResNet layers torch::nn::Sequential layer1{nullptr}; torch::nn::Sequential layer2{nullptr}; torch::nn::Sequential layer3{nullptr}; torch::nn::Sequential layer4{nullptr}; // We'll build it, even if not always outputting private: torch::nn::Sequential _make_layer(int64_t planes, int64_t blocks, const std::string& layer_param_prefix, const torch::Device& device, int64_t stride = 1); int64_t inplanes = 64; std::vector _output_layers; std::string _base_weights_dir; // Store base weights directory, e.g. ../exported_weights/raw_backbone static const int expansion = 4; // Bottleneck expansion factor for ResNet layers }; TORCH_MODULE(ResNet); // Factory function for ResNet-50 ResNet resnet50(const std::string& base_weights_dir, const std::vector& output_layers, const torch::Device& device); } // namespace resnet } // namespace cimp #endif //CPP_TRACKER_RESNET_H