#include "resnet.h"

#include <algorithm>     // std::replace, std::find
#include <filesystem>    // std::filesystem::path
#include <fstream>       // std::ifstream
#include <iomanip>       // std::fixed and std::setprecision
#include <iostream>      // std::cout, std::cerr
#include <iterator>      // std::istreambuf_iterator
#include <map>           // std::map
#include <memory>        // std::shared_ptr, std::dynamic_pointer_cast
#include <optional>      // std::optional
#include <stdexcept>     // std::runtime_error, std::invalid_argument
#include <string>        // std::string, std::to_string
#include <system_error>  // std::error_code
#include <typeinfo>      // typeid
#include <vector>        // std::vector

#include <torch/script.h>  // torch::pickle_load, torch::jit
#include <torch/torch.h>   // torch::nn modules, torch::save

namespace cimp {
namespace resnet {

namespace fs = std::filesystem;

// Loads a single tensor by its parameter name (e.g. "conv1.weight") and moves it to `device`.
//
// The file is expected at `<base_weights_dir>/<param_name_original>.pt`. If that file does
// not exist, the dotted name with every '.' replaced by '_' is tried instead
// (e.g. "layer1_0_bn1_running_mean.pt"). The file content is parsed with
// torch::pickle_load and must hold a tensor.
//
// Throws std::runtime_error if neither file exists or the file cannot be opened.
// Rethrows c10::Error (after logging) if unpickling or the tensor conversion fails.
torch::Tensor load_named_tensor(const std::string& base_weights_dir,
                                const std::string& param_name_original,
                                const torch::Device& device) {
    const fs::path dotted_path = fs::path(base_weights_dir) / (param_name_original + ".pt");
    fs::path tensor_file_path = dotted_path;

    if (!fs::exists(tensor_file_path)) {
        std::string param_name_underscore = param_name_original;
        std::replace(param_name_underscore.begin(), param_name_underscore.end(), '.', '_');
        const fs::path underscore_path = fs::path(base_weights_dir) / (param_name_underscore + ".pt");
        if (!fs::exists(underscore_path)) {
            throw std::runtime_error("Weight file not found (tried direct and underscore versions): " +
                                     dotted_path.string() + " and " + underscore_path.string());
        }
        std::cout << "INFO: Using underscore-named file for C++ loading: " << underscore_path.string()
                  << std::endl;
        tensor_file_path = underscore_path;
    }

    std::cout << "Attempting direct torch::pickle_load for tensor: " << tensor_file_path.string() << std::endl;

    try {
        // The stream is closed automatically when it goes out of scope.
        std::ifstream file_stream(tensor_file_path, std::ios::binary);
        if (!file_stream) {
            throw std::runtime_error("Failed to open file: " + tensor_file_path.string());
        }
        const std::vector<char> file_buffer((std::istreambuf_iterator<char>(file_stream)),
                                            std::istreambuf_iterator<char>());
        const c10::IValue ivalue = torch::pickle_load(file_buffer);
        return ivalue.toTensor().to(device);
    } catch (const c10::Error& e) {
        std::cerr << "CRITICAL ERROR: torch::pickle_load FAILED for '" << tensor_file_path.string()
                  << "'. Error: " << e.what() << std::endl;
        throw;
    }
}

namespace {

// Directory that ResNetImpl::forward writes per-stage debug tensors into (relative to the CWD).
constexpr const char* kDebugDumpDir = "test/output/resnet_debug";

// BatchNorm2d options used for every BN layer in this network.
torch::nn::BatchNorm2dOptions bn_options(int64_t num_features) {
    return torch::nn::BatchNorm2dOptions(num_features)
        .eps(1e-5)
        .momentum(0.1)
        .affine(true)
        .track_running_stats(true);
}

// Builds a BatchNorm2d with `num_features` channels. It then overwrites the layer's weight,
// bias, running_mean, running_var and num_batches_tracked with tensors loaded as
// `<bn_prefix>weight`, `<bn_prefix>bias`, ... (in that order). `bn_prefix` must end with '.',
// e.g. "layer1.0.bn1.".
//
// NOTE(review): assigning to the public members changes the tensors forward() uses. The
// tensors registered via register_parameter/register_buffer when the layer was constructed
// are not replaced, so parameters()/buffers()/save() still see the default-initialised ones.
// Confirm this is acceptable, or switch to set_data/copy_ under torch::NoGradGuard.
torch::nn::BatchNorm2d make_loaded_batchnorm(int64_t num_features,
                                             const std::string& base_weights_dir,
                                             const std::string& bn_prefix,
                                             const torch::Device& device) {
    torch::nn::BatchNorm2d bn(bn_options(num_features));
    bn->weight = load_named_tensor(base_weights_dir, bn_prefix + "weight", device);
    bn->bias = load_named_tensor(base_weights_dir, bn_prefix + "bias", device);
    bn->running_mean = load_named_tensor(base_weights_dir, bn_prefix + "running_mean", device);
    bn->running_var = load_named_tensor(base_weights_dir, bn_prefix + "running_var", device);
    bn->num_batches_tracked = load_named_tensor(base_weights_dir, bn_prefix + "num_batches_tracked", device);
    return bn;
}

// Best-effort debug dump: saves `t[0]` (moved to CPU) as `<kDebugDumpDir>/<file_name>`.
//
// The dump is skipped when `t` is 0-dim, its first dimension is empty, or the debug directory
// does not exist. Serialization failures are logged rather than thrown, so a missing or
// read-only debug directory can never break inference. Because the skip happens first,
// no device->host copy is made when debugging is off.
void dump_first_sample(const torch::Tensor& t, const std::string& file_name) {
    if (t.dim() == 0 || t.size(0) == 0) {
        return;
    }
    std::error_code ec;
    if (!fs::is_directory(kDebugDumpDir, ec)) {
        return;
    }
    const fs::path out_path = fs::path(kDebugDumpDir) / file_name;
    try {
        torch::save(t[0].cpu(), out_path.string());
    } catch (const std::exception& e) {
        std::cerr << "WARNING: failed to write debug tensor '" << out_path.string() << "': " << e.what()
                  << std::endl;
    }
}

}  // namespace

// --- BottleneckImpl Method Definitions ---

// Builds a bottleneck block: 1x1 conv -> 3x3 conv (carries the stride) -> 1x1 conv that
// expands to `planes * expansion_factor_arg` channels. Each conv is followed by a BatchNorm.
//
// All weights are loaded from `base_weights_dir` under names prefixed by
// `block_param_prefix` (e.g. "layer1.0."). When `downsample_module_opt` is set, it is used as
// the residual projection. Its weights must already be loaded by the caller (see
// ResNetImpl::_make_layer); here it is only registered.
BottleneckImpl::BottleneckImpl(const std::string& base_weights_dir,
                               const std::string& block_param_prefix,
                               int64_t inplanes,
                               int64_t planes,
                               const torch::Device& device,
                               int64_t stride_param,
                               std::optional<torch::nn::Sequential> downsample_module_opt,
                               int64_t expansion_factor_arg)
    : expansion_factor(expansion_factor_arg), stride_member(stride_param) {
    const int64_t out_planes = planes * expansion_factor;

    // conv1 / bn1: 1x1 channel reduction.
    conv1 = torch::nn::Conv2d(torch::nn::Conv2dOptions(inplanes, planes, 1).bias(false));
    conv1->weight = load_named_tensor(base_weights_dir, block_param_prefix + "conv1.weight", device);
    bn1 = make_loaded_batchnorm(planes, base_weights_dir, block_param_prefix + "bn1.", device);
    register_module("conv1", conv1);
    register_module("bn1", bn1);

    // conv2 / bn2: 3x3 conv; this is where the block's spatial stride is applied.
    conv2 = torch::nn::Conv2d(
        torch::nn::Conv2dOptions(planes, planes, 3).stride(stride_member).padding(1).bias(false));
    conv2->weight = load_named_tensor(base_weights_dir, block_param_prefix + "conv2.weight", device);
    bn2 = make_loaded_batchnorm(planes, base_weights_dir, block_param_prefix + "bn2.", device);
    register_module("conv2", conv2);
    register_module("bn2", bn2);

    // conv3 / bn3: 1x1 channel expansion.
    conv3 = torch::nn::Conv2d(torch::nn::Conv2dOptions(planes, out_planes, 1).bias(false));
    conv3->weight = load_named_tensor(base_weights_dir, block_param_prefix + "conv3.weight", device);
    bn3 = make_loaded_batchnorm(out_planes, base_weights_dir, block_param_prefix + "bn3.", device);
    register_module("conv3", conv3);
    register_module("bn3", bn3);

    relu = torch::nn::ReLU(torch::nn::ReLUOptions(true));  // in-place
    register_module("relu", relu);

    if (downsample_module_opt.has_value()) {
        // Weights of the projection's conv & bn were loaded by _make_layer before being passed in.
        this->projection_shortcut = downsample_module_opt.value();
        register_module("projection_shortcut", this->projection_shortcut);
    } else {
        this->projection_shortcut = nullptr;
    }
}

// Bottleneck forward pass: relu(bn3(conv3(relu(bn2(conv2(relu(bn1(conv1(x)))))))) + shortcut(x)).
// The shortcut is the projection when present, otherwise the identity. The input tensor
// itself is never modified: the in-place ReLU and `+=` only touch tensors created inside
// this function.
torch::Tensor BottleneckImpl::forward(torch::Tensor x) {
    torch::Tensor identity = x;

    // conv1 -> bn1 -> relu
    x = conv1->forward(x);
    if (bn1) {
        x = bn1->forward(x);
    }
    x = relu->forward(x);

    // conv2 -> bn2 -> relu
    x = conv2->forward(x);
    if (bn2) {
        x = bn2->forward(x);
    }
    x = relu->forward(x);

    // conv3 -> bn3
    x = conv3->forward(x);
    if (bn3) {
        x = bn3->forward(x);
    }

    if (this->projection_shortcut) {
        identity = this->projection_shortcut->forward(identity);
    }

    x += identity;
    x = relu->forward(x);
    return x;
}

// --- ResNetImpl Method Definitions ---

// Builds the ResNet stem (7x7 conv, BN, ReLU, 3x3 max-pool) and four bottleneck stages with
// layers_dims[0..3] blocks each. All weights are loaded from `base_weights_dir_path`.
// Throws std::invalid_argument if fewer than four stage depths are given.
ResNetImpl::ResNetImpl(const std::string& base_weights_dir_path,
                       const std::vector<int64_t>& layers_dims,
                       const std::vector<std::string>& output_layers_param,
                       const torch::Device& device)
    : _output_layers(output_layers_param), _base_weights_dir(base_weights_dir_path) {
    if (layers_dims.size() < 4) {
        throw std::invalid_argument("ResNetImpl requires 4 stage depths in layers_dims, got " +
                                    std::to_string(layers_dims.size()));
    }

    conv1 = torch::nn::Conv2d(torch::nn::Conv2dOptions(3, 64, 7).stride(2).padding(3).bias(false));
    this->conv1->weight = load_named_tensor(this->_base_weights_dir, "conv1.weight", device);
    bn1 = make_loaded_batchnorm(64, this->_base_weights_dir, "bn1.", device);
    register_module("conv1", conv1);
    register_module("bn1", bn1);

    relu = torch::nn::ReLU(torch::nn::ReLUOptions().inplace(true));
    maxpool = torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(3).stride(2).padding(1));
    register_module("relu", relu);
    register_module("maxpool", maxpool);

    layer1 = _make_layer(64, layers_dims[0], "layer1.", device);
    layer2 = _make_layer(128, layers_dims[1], "layer2.", device, 2);
    layer3 = _make_layer(256, layers_dims[2], "layer3.", device, 2);
    layer4 = _make_layer(512, layers_dims[3], "layer4.", device, 2);
    register_module("layer1", layer1);
    register_module("layer2", layer2);
    register_module("layer3", layer3);
    register_module("layer4", layer4);
}

// Builds one ResNet stage of `num_blocks` bottleneck blocks. The first block always exists.
// It gets the stride and, when stride != 1 or the channel count changes, a
// 1x1-conv + BN projection shortcut loaded from `<layer_param_prefix>0.downsample.{0,1}.*`.
// Updates this->inplanes to the stage's output channel count.
torch::nn::Sequential ResNetImpl::_make_layer(int64_t planes_for_block,
                                              int64_t num_blocks,
                                              const std::string& layer_param_prefix,
                                              const torch::Device& device,
                                              int64_t stride_for_first_block) {
    torch::nn::Sequential layer_sequential;
    const int64_t out_planes = planes_for_block * ResNetImpl::expansion;

    std::optional<torch::nn::Sequential> downsample_module_for_block_opt = std::nullopt;
    if (stride_for_first_block != 1 || this->inplanes != out_planes) {
        const std::string ds_block_prefix = layer_param_prefix + "0.downsample.";

        auto conv_down = torch::nn::Conv2d(
            torch::nn::Conv2dOptions(this->inplanes, out_planes, 1).stride(stride_for_first_block).bias(false));
        conv_down->weight = load_named_tensor(this->_base_weights_dir, ds_block_prefix + "0.weight", device);
        auto bn_down = make_loaded_batchnorm(out_planes, this->_base_weights_dir, ds_block_prefix + "1.", device);

        torch::nn::Sequential ds_seq;
        ds_seq->push_back(conv_down);
        ds_seq->push_back(bn_down);
        downsample_module_for_block_opt = ds_seq;
    }

    layer_sequential->push_back(Bottleneck(this->_base_weights_dir, layer_param_prefix + "0.", this->inplanes,
                                           planes_for_block, device, stride_for_first_block,
                                           downsample_module_for_block_opt, ResNetImpl::expansion));
    this->inplanes = out_planes;

    for (int64_t i = 1; i < num_blocks; ++i) {
        const std::string current_block_param_prefix = layer_param_prefix + std::to_string(i) + ".";
        layer_sequential->push_back(Bottleneck(this->_base_weights_dir, current_block_param_prefix, this->inplanes,
                                               planes_for_block, device, 1, std::nullopt, ResNetImpl::expansion));
    }
    return layer_sequential;
}

// Runs the network and returns the intermediate activations whose names appear in
// _output_layers. Available names: "debug_resnet_conv1_output_for_bn1_input", "bn1_output",
// "conv1_output" (after bn1 + relu), "layer1_0_block_output", "layer1",
// "layer1_0_shortcut_output", "layer2", "layer3", "layer4", "features" (same as "layer4").
// When the debug directory exists, the first sample after each stage is also dumped there.
std::map<std::string, torch::Tensor> ResNetImpl::forward(torch::Tensor x) {
    std::map<std::string, torch::Tensor> outputs;
    auto should_output = [this](const std::string& layer_name) {
        return std::find(_output_layers.begin(), _output_layers.end(), layer_name) != _output_layers.end();
    };

    if (x.size(0) > 0) std::cout << "[DEBUG] Input shape: " << x.sizes() << std::endl;

    x = conv1->forward(x);
    dump_first_sample(x, "sample_0_after_conv1.pt");
    if (should_output("debug_resnet_conv1_output_for_bn1_input")) {
        outputs["debug_resnet_conv1_output_for_bn1_input"] = x.clone();
    }

    if (bn1) {
        x = bn1->forward(x);
    }
    dump_first_sample(x, "sample_0_after_bn1.pt");
    // relu is in-place: without a clone, the stored bn1 output would be overwritten with its ReLU.
    if (should_output("bn1_output")) outputs["bn1_output"] = x.clone();

    x = relu->forward(x);
    dump_first_sample(x, "sample_0_after_relu1.pt");
    // conv1_output is taken AFTER bn1 and relu (matching the Python reference).
    if (should_output("conv1_output")) outputs["conv1_output"] = x;

    const torch::Tensor x_pre_layer1 = maxpool->forward(x);
    dump_first_sample(x_pre_layer1, "sample_0_after_maxpool.pt");

    // Diagnostic: output of the first bottleneck of layer1 on its own.
    if (should_output("layer1_0_block_output") && layer1 && !layer1->is_empty()) {
        try {
            const std::shared_ptr<torch::nn::Module> base_module_ptr = layer1->ptr(0);
            const auto bottleneck_impl_ptr = std::dynamic_pointer_cast<BottleneckImpl>(base_module_ptr);
            if (bottleneck_impl_ptr) {
                outputs["layer1_0_block_output"] = bottleneck_impl_ptr->forward(x_pre_layer1);
            } else {
                std::cerr << "ERROR: layer1->ptr(0) could not be dynamically cast to BottleneckImpl! Module type: "
                          << (base_module_ptr ? typeid(*base_module_ptr).name() : "null") << std::endl;
            }
        } catch (const std::exception& e) {
            std::cerr << "EXCEPTION while getting layer1_0_block_output: " << e.what() << std::endl;
        }
    }

    const torch::Tensor x_after_layer1 = layer1->forward(x_pre_layer1);
    dump_first_sample(x_after_layer1, "sample_0_after_layer1.pt");
    if (should_output("layer1")) outputs["layer1"] = x_after_layer1;

    // Diagnostic: projection-shortcut output of the first bottleneck of layer1 (if it has one).
    if (should_output("layer1_0_shortcut_output") && layer1 && !layer1->is_empty()) {
        try {
            const auto first_block = std::dynamic_pointer_cast<BottleneckImpl>(layer1->ptr(0));
            if (first_block && first_block->projection_shortcut) {
                outputs["layer1_0_shortcut_output"] = first_block->projection_shortcut->forward(x_pre_layer1);
            }
        } catch (const std::exception& e) {
            std::cerr << "ERROR: Exception while getting layer1_0_shortcut_output: " << e.what() << std::endl;
        }
    }

    torch::Tensor x_current = layer2->forward(x_after_layer1);
    dump_first_sample(x_current, "sample_0_after_layer2.pt");
    if (should_output("layer2")) outputs["layer2"] = x_current;

    x_current = layer3->forward(x_current);
    dump_first_sample(x_current, "sample_0_after_layer3.pt");
    if (should_output("layer3")) outputs["layer3"] = x_current;

    x_current = layer4->forward(x_current);
    dump_first_sample(x_current, "sample_0_after_layer4.pt");
    if (should_output("layer4")) outputs["layer4"] = x_current;
    if (should_output("features")) outputs["features"] = x_current;

    return outputs;
}

// Builds a ResNet-50 (stage depths [3, 4, 6, 3]) with weights loaded from `base_weights_dir`.
ResNet resnet50(const std::string& base_weights_dir,
                const std::vector<std::string>& output_layers,
                const torch::Device& device) {
    return ResNet(ResNetImpl(base_weights_dir, {3, 4, 6, 3}, output_layers, device));
}

}  // namespace resnet
}  // namespace cimp