#include "resnet.h"
#include <filesystem> // For std::filesystem::path
#include <stdexcept> // For std::runtime_error
#include <torch/script.h> // For torch::jit::load and torch::jit::Module
#include <optional> // ensure this is included
#include <fstream> // Added for std::ifstream
#include <vector> // Added for std::vector
#include <iomanip> // Added for std::fixed and std::setprecision
namespace cimp {
namespace resnet {

namespace fs = std::filesystem;

// Helper function to load a tensor by its parameter name (e.g., "conv1.weight").
// Assumes .pt files are named like "conv1.weight.pt" or "layer1.0.bn1.running_mean.pt".
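// Illustrative call (hypothetical path and parameter name, shown only to
// document the contract):
//   torch::Tensor w = load_named_tensor("/path/to/weights", "layer1.0.conv1.weight", torch::kCPU);
// The returned tensor is already moved to the requested device.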
torch::Tensor load_named_tensor(const std::string& base_weights_dir,
                                const std::string& param_name_original,
                                const torch::Device& device) {
    fs::path tensor_file_path = fs::path(base_weights_dir) / (param_name_original + ".pt");
    if (!fs::exists(tensor_file_path)) {
        // Fall back to an underscore-separated file name, e.g. "layer1_0_bn1_running_mean.pt".
        std::string param_name_underscore = param_name_original;
        std::replace(param_name_underscore.begin(), param_name_underscore.end(), '.', '_');
        fs::path tensor_file_path_underscore = fs::path(base_weights_dir) / (param_name_underscore + ".pt");
        if (fs::exists(tensor_file_path_underscore)) {
            std::cout << "INFO: Using underscore-named file for C++ loading: "
                      << tensor_file_path_underscore.string() << std::endl;
            tensor_file_path = tensor_file_path_underscore;
        } else {
            throw std::runtime_error("Weight file not found (tried direct and underscore versions): " +
                                     tensor_file_path.string() + " and " +
                                     tensor_file_path_underscore.string());
        }
    }
    std::cout << "Attempting direct torch::pickle_load for tensor: " << tensor_file_path.string() << std::endl;
    try {
        // Read the whole file into a byte buffer for torch::pickle_load.
        std::ifstream file_stream(tensor_file_path.string(), std::ios::binary);
        if (!file_stream) {
            throw std::runtime_error("Failed to open file: " + tensor_file_path.string());
        }
        std::vector<char> file_buffer((std::istreambuf_iterator<char>(file_stream)),
                                      std::istreambuf_iterator<char>());
        file_stream.close();
        c10::IValue ivalue = torch::pickle_load(file_buffer);
        return ivalue.toTensor().to(device);
    } catch (const c10::Error& e) {
        std::cerr << "CRITICAL ERROR: torch::pickle_load FAILED for '" << tensor_file_path.string()
                  << "'. Error: " << e.what() << std::endl;
        throw;
    }
}
// --- BottleneckImpl Method Definitions ---

// Constructor for BottleneckImpl. Builds the standard bottleneck stack
// (1x1 reduce -> 3x3 -> 1x1 expand) and loads every parameter and buffer
// tensor from individual .pt files under base_weights_dir, keyed by
// block_param_prefix (e.g. "layer1.0.").
BottleneckImpl::BottleneckImpl(const std::string& base_weights_dir,
                               const std::string& block_param_prefix,
                               int64_t inplanes, int64_t planes,
                               const torch::Device& device,
                               int64_t stride_param,
                               std::optional<torch::nn::Sequential> downsample_module_opt,
                               int64_t expansion_factor_arg)
    : expansion_factor(expansion_factor_arg), stride_member(stride_param) {
    // conv1: 1x1 convolution reducing channels to `planes`.
    conv1 = torch::nn::Conv2d(torch::nn::Conv2dOptions(inplanes, planes, 1).bias(false));
    bn1 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(planes).eps(1e-5).momentum(0.1).affine(true).track_running_stats(true));
    conv1->weight = load_named_tensor(base_weights_dir, block_param_prefix + "conv1.weight", device);
    bn1->weight = load_named_tensor(base_weights_dir, block_param_prefix + "bn1.weight", device);
    bn1->bias = load_named_tensor(base_weights_dir, block_param_prefix + "bn1.bias", device);
    bn1->running_mean = load_named_tensor(base_weights_dir, block_param_prefix + "bn1.running_mean", device);
    bn1->running_var = load_named_tensor(base_weights_dir, block_param_prefix + "bn1.running_var", device);
    bn1->num_batches_tracked = load_named_tensor(base_weights_dir, block_param_prefix + "bn1.num_batches_tracked", device);
    register_module("conv1", conv1);
    register_module("bn1", bn1);

    // conv2: 3x3 convolution; carries the block's stride.
    conv2 = torch::nn::Conv2d(torch::nn::Conv2dOptions(planes, planes, 3).stride(stride_member).padding(1).bias(false));
    bn2 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(planes).eps(1e-5).momentum(0.1).affine(true).track_running_stats(true));
    conv2->weight = load_named_tensor(base_weights_dir, block_param_prefix + "conv2.weight", device);
    bn2->weight = load_named_tensor(base_weights_dir, block_param_prefix + "bn2.weight", device);
    bn2->bias = load_named_tensor(base_weights_dir, block_param_prefix + "bn2.bias", device);
    bn2->running_mean = load_named_tensor(base_weights_dir, block_param_prefix + "bn2.running_mean", device);
    bn2->running_var = load_named_tensor(base_weights_dir, block_param_prefix + "bn2.running_var", device);
    bn2->num_batches_tracked = load_named_tensor(base_weights_dir, block_param_prefix + "bn2.num_batches_tracked", device);
    register_module("conv2", conv2);
    register_module("bn2", bn2);

    // conv3: 1x1 convolution expanding channels by expansion_factor.
    conv3 = torch::nn::Conv2d(torch::nn::Conv2dOptions(planes, planes * expansion_factor, 1).bias(false));
    bn3 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(planes * expansion_factor).eps(1e-5).momentum(0.1).affine(true).track_running_stats(true));
    conv3->weight = load_named_tensor(base_weights_dir, block_param_prefix + "conv3.weight", device);
    bn3->weight = load_named_tensor(base_weights_dir, block_param_prefix + "bn3.weight", device);
    bn3->bias = load_named_tensor(base_weights_dir, block_param_prefix + "bn3.bias", device);
    bn3->running_mean = load_named_tensor(base_weights_dir, block_param_prefix + "bn3.running_mean", device);
    bn3->running_var = load_named_tensor(base_weights_dir, block_param_prefix + "bn3.running_var", device);
    bn3->num_batches_tracked = load_named_tensor(base_weights_dir, block_param_prefix + "bn3.num_batches_tracked", device);
    register_module("conv3", conv3);
    register_module("bn3", bn3);

    relu = torch::nn::ReLU(torch::nn::ReLUOptions(true));
    register_module("relu", relu);

    if (downsample_module_opt.has_value()) {
        // Weights for the projection's conv/bn submodules are loaded by
        // _make_layer before the module is passed in; we only register it here.
        this->projection_shortcut = downsample_module_opt.value();
        register_module("projection_shortcut", this->projection_shortcut);
    } else {
        this->projection_shortcut = nullptr;
    }
}
// Forward method implementation for BottleneckImpl
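// Dataflow: relu(bn3(conv3(relu(bn2(conv2(relu(bn1(conv1(x)))))))) + identity),
// where identity is x itself, or projection_shortcut(x) when the block
// changes shape.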
torch::Tensor BottleneckImpl::forward(torch::Tensor x) {
    torch::Tensor identity = x;
    // conv1 -> bn1 -> relu
    x = conv1->forward(x);
    if (bn1) {
        x = bn1->forward(x);
    }
    x = relu->forward(x);
    // conv2 -> bn2 -> relu
    x = conv2->forward(x);
    if (bn2) {
        x = bn2->forward(x);
    }
    x = relu->forward(x);
    // conv3 -> bn3
    x = conv3->forward(x);
    if (bn3) {
        x = bn3->forward(x);
    }
    // Project the identity branch if the block changes shape, then add and activate.
    if (this->projection_shortcut) {
        identity = this->projection_shortcut->forward(identity);
    }
    x += identity;
    x = relu->forward(x);
    return x;
}
// --- ResNetImpl Method Definitions ---
ResNetImpl::ResNetImpl(const std::string& base_weights_dir_path,
                       const std::vector<int64_t>& layers_dims,
                       const std::vector<std::string>& output_layers_param,
                       const torch::Device& device)
    : _output_layers(output_layers_param), _base_weights_dir(base_weights_dir_path) {
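    // Stem: 7x7 stride-2 convolution followed by a 3x3 stride-2 max pool,
    // the standard ResNet stem (4x spatial downsampling before layer1).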
    conv1 = torch::nn::Conv2d(torch::nn::Conv2dOptions(3, 64, 7).stride(2).padding(3).bias(false));
    bn1 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(64).eps(1e-5).momentum(0.1).affine(true).track_running_stats(true));
    this->conv1->weight = load_named_tensor(this->_base_weights_dir, "conv1.weight", device);
    // Directly assign the parameter and buffer tensors of the bn1 module.
    this->bn1->weight = load_named_tensor(this->_base_weights_dir, "bn1.weight", device);
    this->bn1->bias = load_named_tensor(this->_base_weights_dir, "bn1.bias", device);
    this->bn1->running_mean = load_named_tensor(this->_base_weights_dir, "bn1.running_mean", device);
    this->bn1->running_var = load_named_tensor(this->_base_weights_dir, "bn1.running_var", device);
    this->bn1->num_batches_tracked = load_named_tensor(this->_base_weights_dir, "bn1.num_batches_tracked", device);
    register_module("conv1", conv1);
    register_module("bn1", bn1);
    relu = torch::nn::ReLU(torch::nn::ReLUOptions().inplace(true));
    maxpool = torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(3).stride(2).padding(1));
    register_module("relu", relu);
    register_module("maxpool", maxpool);
    // Four stages of bottleneck blocks; layer2-4 halve the spatial resolution.
    layer1 = _make_layer(64, layers_dims[0], "layer1.", device);
    layer2 = _make_layer(128, layers_dims[1], "layer2.", device, 2);
    layer3 = _make_layer(256, layers_dims[2], "layer3.", device, 2);
    layer4 = _make_layer(512, layers_dims[3], "layer4.", device, 2);
    register_module("layer1", layer1);
    register_module("layer2", layer2);
    register_module("layer3", layer3);
    register_module("layer4", layer4);
}
torch::nn::Sequential ResNetImpl::_make_layer(int64_t planes_for_block, int64_t num_blocks,
                                              const std::string& layer_param_prefix,
                                              const torch::Device& device,
                                              int64_t stride_for_first_block) {
    torch::nn::Sequential layer_sequential;
    std::optional<torch::nn::Sequential> downsample_module_for_block_opt = std::nullopt;
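    // A projection shortcut (1x1 conv + BN) is required whenever the first
    // block changes the spatial resolution (stride != 1) or the channel count
    // (inplanes != planes * expansion); otherwise the residual add would not
    // be shape-compatible.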
    if (stride_for_first_block != 1 || this->inplanes != planes_for_block * ResNetImpl::expansion) {
        torch::nn::Sequential ds_seq;
        auto conv_down = torch::nn::Conv2d(torch::nn::Conv2dOptions(this->inplanes, planes_for_block * ResNetImpl::expansion, 1).stride(stride_for_first_block).bias(false));
        auto bn_down = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(planes_for_block * ResNetImpl::expansion).eps(1e-5).momentum(0.1).affine(true).track_running_stats(true));
        std::string ds_block_prefix = layer_param_prefix + "0.downsample.";
        conv_down->weight = load_named_tensor(this->_base_weights_dir, ds_block_prefix + "0.weight", device);
        bn_down->weight = load_named_tensor(this->_base_weights_dir, ds_block_prefix + "1.weight", device);
        bn_down->bias = load_named_tensor(this->_base_weights_dir, ds_block_prefix + "1.bias", device);
        bn_down->running_mean = load_named_tensor(this->_base_weights_dir, ds_block_prefix + "1.running_mean", device);
        bn_down->running_var = load_named_tensor(this->_base_weights_dir, ds_block_prefix + "1.running_var", device);
        bn_down->num_batches_tracked = load_named_tensor(this->_base_weights_dir, ds_block_prefix + "1.num_batches_tracked", device);
        ds_seq->push_back(conv_down);
        ds_seq->push_back(bn_down);
        downsample_module_for_block_opt = ds_seq;
    }
    std::string first_block_param_prefix = layer_param_prefix + "0.";
    layer_sequential->push_back(Bottleneck(this->_base_weights_dir, first_block_param_prefix,
                                           this->inplanes, planes_for_block, device,
                                           stride_for_first_block,
                                           downsample_module_for_block_opt, ResNetImpl::expansion));
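    // The first block expands the channel count; the remaining blocks run at
    // stride 1 on the expanded width and need no projection shortcut.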
    this->inplanes = planes_for_block * ResNetImpl::expansion;
    for (int64_t i = 1; i < num_blocks; ++i) {
        std::string current_block_param_prefix = layer_param_prefix + std::to_string(i) + ".";
        layer_sequential->push_back(Bottleneck(this->_base_weights_dir, current_block_param_prefix,
                                               this->inplanes, planes_for_block, device,
                                               /*stride=*/1,
                                               std::nullopt, ResNetImpl::expansion));
    }
    return layer_sequential;
}
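// Runs the backbone and returns a map of the requested intermediate
// activations, keyed by the layer names passed at construction time.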
std::map<std::string, torch::Tensor> ResNetImpl::forward(torch::Tensor x) {
    std::map<std::string, torch::Tensor> outputs;
    auto should_output = [&](const std::string& layer_name) {
        return std::find(_output_layers.begin(), _output_layers.end(), layer_name) != _output_layers.end();
    };
    x = conv1->forward(x);
    if (should_output("debug_resnet_conv1_output_for_bn1_input")) {
        outputs["debug_resnet_conv1_output_for_bn1_input"] = x.clone();
    }
    if (bn1) {
        x = bn1->forward(x);
    }
    if (should_output("bn1_output")) outputs["bn1_output"] = x;
    x = relu->forward(x);
    if (should_output("relu1_output")) outputs["relu1_output"] = x;
    // Save conv1_output AFTER bn1 and relu (matching the Python reference behavior).
    if (should_output("conv1_output")) outputs["conv1_output"] = x;
    torch::Tensor x_pre_layer1 = maxpool->forward(x);
    if (should_output("maxpool_output")) outputs["maxpool_output"] = x_pre_layer1;
    // Save the output of the layer1.0 block if requested.
    if (should_output("layer1_0_block_output")) {
        if (layer1 && !layer1->is_empty()) {
            try {
                // Fetch the first submodule and downcast it to BottleneckImpl.
                std::shared_ptr<torch::nn::Module> base_module_ptr = layer1->ptr(0);
                auto bottleneck_impl_ptr = std::dynamic_pointer_cast<cimp::resnet::BottleneckImpl>(base_module_ptr);
                if (bottleneck_impl_ptr) {
                    outputs["layer1_0_block_output"] = bottleneck_impl_ptr->forward(x_pre_layer1);
                } else {
                    std::cerr << "ERROR: layer1->ptr(0) could not be dynamically cast to BottleneckImpl! Module type: "
                              << (base_module_ptr ? typeid(*base_module_ptr).name() : "null") << std::endl;
                }
            } catch (const std::exception& e) {
                std::cerr << "EXCEPTION while getting layer1_0_block_output: " << e.what() << std::endl;
            }
        }
    }
    torch::Tensor x_after_layer1 = layer1->forward(x_pre_layer1);
    if (should_output("layer1")) outputs["layer1"] = x_after_layer1;
    if (should_output("layer1_0_shortcut_output")) {
        if (layer1 && !layer1->is_empty()) {
            try {
                std::shared_ptr<torch::nn::Module> first_block_module_ptr = layer1->ptr(0);
                auto bottleneck_module_holder = std::dynamic_pointer_cast<cimp::resnet::BottleneckImpl>(first_block_module_ptr);
                if (bottleneck_module_holder && bottleneck_module_holder->projection_shortcut) {
                    outputs["layer1_0_shortcut_output"] =
                        bottleneck_module_holder->projection_shortcut->forward(x_pre_layer1);
                }
            } catch (const std::exception&) {
                // Best-effort debug tap; failures here are ignored.
            }
        }
    }
    torch::Tensor x_current = x_after_layer1;
    x_current = layer2->forward(x_current);
    if (should_output("layer2")) outputs["layer2"] = x_current;
    x_current = layer3->forward(x_current);
    if (should_output("layer3")) outputs["layer3"] = x_current;
    x_current = layer4->forward(x_current);
    if (should_output("layer4")) outputs["layer4"] = x_current;
    if (should_output("features")) outputs["features"] = x_current;
    return outputs;
}
// Factory for ResNet-50: the stage depths are [3, 4, 6, 3].
ResNet resnet50(const std::string& base_weights_dir,
                const std::vector<std::string>& output_layers,
                const torch::Device& device) {
    return ResNet(base_weights_dir, std::vector<int64_t>{3, 4, 6, 3}, output_layers, device);
}
} // namespace resnet
} // namespace cimp
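// Usage sketch (illustrative only; the weights directory path and layer names
// below are assumptions based on this file's loading convention):
//
//   torch::Device device(torch::kCPU);
//   auto model = cimp::resnet::resnet50("/path/to/exported_weights",
//                                       {"layer2", "layer3", "features"}, device);
//   model->eval();
//   torch::NoGradGuard no_grad;
//   auto outs = model->forward(torch::rand({1, 3, 224, 224}).to(device));
//   std::cout << outs.at("features").sizes() << std::endl; // e.g. [1, 2048, 7, 7]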