You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
287 lines
15 KiB
287 lines
15 KiB
#include "resnet.h"

#include <algorithm> // For std::find and std::replace
#include <filesystem> // For std::filesystem::path
#include <fstream> // Added for std::ifstream
#include <iostream> // For std::cout and std::cerr
#include <iterator> // For std::istreambuf_iterator
#include <map> // For std::map
#include <optional> // ensure this is included
#include <stdexcept> // For std::runtime_error
#include <string> // For std::string and std::to_string
#include <vector> // Added for std::vector

#include <torch/script.h> // For torch::jit::load and torch::jit::Module
|
|
|
|
namespace cimp {
|
|
namespace resnet {
|
|
|
|
namespace fs = std::filesystem; // Moved fs namespace alias here
|
|
|
|
// Helper function to load a tensor by its parameter name (e.g., "conv1.weight")
|
|
// Assumes .pt files are named like "conv1.weight.pt", "layer1.0.bn1.running_mean.pt"
|
|
torch::Tensor load_named_tensor(const std::string& base_weights_dir, const std::string& param_name_original, const torch::Device& device) {
|
|
fs::path tensor_file_path = fs::path(base_weights_dir) / (param_name_original + ".pt");
|
|
|
|
if (!fs::exists(tensor_file_path)) {
|
|
std::string param_name_underscore = param_name_original;
|
|
std::replace(param_name_underscore.begin(), param_name_underscore.end(), '.', '_');
|
|
fs::path tensor_file_path_underscore = fs::path(base_weights_dir) / (param_name_underscore + ".pt");
|
|
|
|
if (fs::exists(tensor_file_path_underscore)) {
|
|
std::cout << "INFO: Using underscore-named file for C++ loading: " << tensor_file_path_underscore.string() << std::endl;
|
|
tensor_file_path = tensor_file_path_underscore;
|
|
} else {
|
|
throw std::runtime_error("Weight file not found (tried direct and underscore versions): " +
|
|
(fs::path(base_weights_dir) / (param_name_original + ".pt")).string() +
|
|
" and " + tensor_file_path_underscore.string());
|
|
}
|
|
}
|
|
|
|
std::cout << "Attempting direct torch::pickle_load for tensor: " << tensor_file_path.string() << std::endl;
|
|
try {
|
|
// Read the file into a vector<char>
|
|
std::ifstream file_stream(tensor_file_path.string(), std::ios::binary);
|
|
if (!file_stream) {
|
|
throw std::runtime_error("Failed to open file: " + tensor_file_path.string());
|
|
}
|
|
std::vector<char> file_buffer((std::istreambuf_iterator<char>(file_stream)),
|
|
std::istreambuf_iterator<char>());
|
|
file_stream.close();
|
|
|
|
c10::IValue ivalue = torch::pickle_load(file_buffer);
|
|
return ivalue.toTensor().to(device);
|
|
} catch (const c10::Error& e) {
|
|
std::cerr << "CRITICAL ERROR: torch::pickle_load FAILED for '" << tensor_file_path.string() << "'. Error: " << e.what() << std::endl;
|
|
throw;
|
|
}
|
|
}
|
|
|
|
// --- BottleneckImpl Method Definitions ---
|
|
|
|
// Constructor implementation for BottleneckImpl
|
|
// Signature must match resnet.h:
|
|
// BottleneckImpl(int64_t inplanes, int64_t planes, const std::string& weights_dir_prefix, int64_t stride = 1,
|
|
// std::optional<torch::nn::Sequential> downsample_module_opt = std::nullopt, int64_t expansion_factor_arg = 4);
|
|
BottleneckImpl::BottleneckImpl(const std::string& base_weights_dir,
|
|
const std::string& block_param_prefix,
|
|
int64_t inplanes, int64_t planes,
|
|
const torch::Device& device,
|
|
int64_t stride_param, std::optional<torch::nn::Sequential> downsample_module_opt,
|
|
int64_t expansion_factor_arg)
|
|
: expansion_factor(expansion_factor_arg), stride_member(stride_param) {
|
|
|
|
// conv1
|
|
conv1 = torch::nn::Conv2d(torch::nn::Conv2dOptions(inplanes, planes, 1).bias(false));
|
|
bn1 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(planes));
|
|
conv1->weight = load_named_tensor(base_weights_dir, block_param_prefix + "conv1.weight", device);
|
|
bn1->weight = load_named_tensor(base_weights_dir, block_param_prefix + "bn1.weight", device);
|
|
bn1->bias = load_named_tensor(base_weights_dir, block_param_prefix + "bn1.bias", device);
|
|
bn1->named_buffers()["running_mean"] = load_named_tensor(base_weights_dir, block_param_prefix + "bn1.running_mean", device);
|
|
bn1->named_buffers()["running_var"] = load_named_tensor(base_weights_dir, block_param_prefix + "bn1.running_var", device);
|
|
register_module("conv1", conv1);
|
|
register_module("bn1", bn1);
|
|
|
|
// conv2
|
|
conv2 = torch::nn::Conv2d(torch::nn::Conv2dOptions(planes, planes, 3).stride(stride_member).padding(1).bias(false));
|
|
bn2 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(planes));
|
|
conv2->weight = load_named_tensor(base_weights_dir, block_param_prefix + "conv2.weight", device);
|
|
bn2->weight = load_named_tensor(base_weights_dir, block_param_prefix + "bn2.weight", device);
|
|
bn2->bias = load_named_tensor(base_weights_dir, block_param_prefix + "bn2.bias", device);
|
|
bn2->named_buffers()["running_mean"] = load_named_tensor(base_weights_dir, block_param_prefix + "bn2.running_mean", device);
|
|
bn2->named_buffers()["running_var"] = load_named_tensor(base_weights_dir, block_param_prefix + "bn2.running_var", device);
|
|
register_module("conv2", conv2);
|
|
register_module("bn2", bn2);
|
|
|
|
// conv3
|
|
conv3 = torch::nn::Conv2d(torch::nn::Conv2dOptions(planes, planes * expansion_factor, 1).bias(false));
|
|
bn3 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(planes * expansion_factor));
|
|
conv3->weight = load_named_tensor(base_weights_dir, block_param_prefix + "conv3.weight", device);
|
|
bn3->weight = load_named_tensor(base_weights_dir, block_param_prefix + "bn3.weight", device);
|
|
bn3->bias = load_named_tensor(base_weights_dir, block_param_prefix + "bn3.bias", device);
|
|
bn3->named_buffers()["running_mean"] = load_named_tensor(base_weights_dir, block_param_prefix + "bn3.running_mean", device);
|
|
bn3->named_buffers()["running_var"] = load_named_tensor(base_weights_dir, block_param_prefix + "bn3.running_var", device);
|
|
register_module("conv3", conv3);
|
|
register_module("bn3", bn3);
|
|
|
|
relu = torch::nn::ReLU(torch::nn::ReLUOptions(true));
|
|
register_module("relu", relu);
|
|
|
|
if (downsample_module_opt.has_value()) {
|
|
this->projection_shortcut = downsample_module_opt.value(); // Assign the passed Sequential module
|
|
|
|
// Weights for the submodules of projection_shortcut (conv & bn) are loaded by _make_layer
|
|
// before this module is passed. Here, we just register it.
|
|
register_module("projection_shortcut", this->projection_shortcut);
|
|
} else {
|
|
this->projection_shortcut = nullptr;
|
|
}
|
|
}
|
|
|
|
// Forward method implementation for BottleneckImpl
|
|
torch::Tensor BottleneckImpl::forward(torch::Tensor x) {
|
|
torch::Tensor identity = x;
|
|
|
|
x = conv1->forward(x);
|
|
x = bn1->forward(x);
|
|
x = relu->forward(x);
|
|
|
|
x = conv2->forward(x);
|
|
x = bn2->forward(x);
|
|
x = relu->forward(x);
|
|
|
|
x = conv3->forward(x);
|
|
x = bn3->forward(x);
|
|
|
|
if (this->projection_shortcut) {
|
|
identity = this->projection_shortcut->forward(identity);
|
|
}
|
|
|
|
x += identity;
|
|
x = relu->forward(x);
|
|
|
|
return x;
|
|
}
|
|
|
|
// --- ResNetImpl Method Definitions ---
|
|
// Builds the ResNet stem (7x7 conv, BatchNorm, ReLU, max-pool) and the four
// residual stages, loading every weight from `base_weights_dir_path`.
//
// Parameters:
//   base_weights_dir_path  directory holding one ".pt" file per parameter.
//   layers_dims            number of bottleneck blocks in stages 1-4; at
//                          least four entries are required.
//   output_layers_param    names of the intermediate results forward() returns.
//   device                 device the loaded tensors are placed on.
//
// Throws std::invalid_argument when `layers_dims` has fewer than four entries,
// and propagates any exception thrown by load_named_tensor.
ResNetImpl::ResNetImpl(const std::string& base_weights_dir_path,
                       const std::vector<int64_t>& layers_dims,
                       const std::vector<std::string>& output_layers_param,
                       const torch::Device& device)
    : _output_layers(output_layers_param), _base_weights_dir(base_weights_dir_path) {

    // layers_dims[0..3] are indexed below; reject short input instead of
    // reading out of bounds.
    if (layers_dims.size() < 4) {
        throw std::invalid_argument("ResNetImpl expects block counts for 4 stages, got " +
                                    std::to_string(layers_dims.size()));
    }

    // Loads a top-level parameter from the weights directory.
    const auto load = [&](const std::string& param_name) {
        return load_named_tensor(this->_base_weights_dir, param_name, device);
    };

    conv1 = torch::nn::Conv2d(torch::nn::Conv2dOptions(3, 64, 7).stride(2).padding(3).bias(false));
    bn1 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(64));
    this->conv1->weight = load("conv1.weight");
    this->bn1->weight = load("bn1.weight");
    this->bn1->bias = load("bn1.bias");
    // Assign the running statistics to the module's own members:
    // named_buffers() returns its dictionary by value, so writing into that
    // temporary would silently discard the loaded tensors.
    this->bn1->running_mean = load("bn1.running_mean");
    this->bn1->running_var = load("bn1.running_var");
    register_module("conv1", conv1);
    register_module("bn1", bn1);

    relu = torch::nn::ReLU(torch::nn::ReLUOptions().inplace(true));
    maxpool = torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(3).stride(2).padding(1));
    register_module("relu", relu);
    register_module("maxpool", maxpool);

    // Stages 2-4 use stride 2 in their first block.
    layer1 = _make_layer(64, layers_dims[0], "layer1.", device);
    layer2 = _make_layer(128, layers_dims[1], "layer2.", device, 2);
    layer3 = _make_layer(256, layers_dims[2], "layer3.", device, 2);
    layer4 = _make_layer(512, layers_dims[3], "layer4.", device, 2);
    register_module("layer1", layer1);
    register_module("layer2", layer2);
    register_module("layer3", layer3);
    register_module("layer4", layer4);
}
|
|
|
|
// Builds one residual stage as a Sequential of Bottleneck blocks.
//
// Parameters:
//   planes_for_block        bottleneck width of every block in the stage.
//   num_blocks              number of blocks; the first block is always
//                           created, blocks 1..num_blocks-1 follow it.
//   layer_param_prefix      parameter-name prefix of the stage, e.g. "layer1.".
//   device                  device the loaded tensors are placed on.
//   stride_for_first_block  stride of the first block; later blocks use 1.
//
// The first block gets a downsample shortcut (1x1 conv + BatchNorm, loaded
// from "<prefix>0.downsample.{0,1}.*") when the stride is not 1 or the current
// channel count differs from planes_for_block * expansion. After the first
// block, this->inplanes is updated to the stage's output channel count.
//
// Propagates any exception thrown by load_named_tensor.
torch::nn::Sequential ResNetImpl::_make_layer(int64_t planes_for_block, int64_t num_blocks,
                                              const std::string& layer_param_prefix,
                                              const torch::Device& device,
                                              int64_t stride_for_first_block) {
    torch::nn::Sequential layer_sequential;
    std::optional<torch::nn::Sequential> downsample_module_for_block_opt = std::nullopt;

    const int64_t out_planes = planes_for_block * ResNetImpl::expansion;

    if (stride_for_first_block != 1 || this->inplanes != out_planes) {
        const std::string ds_block_prefix = layer_param_prefix + "0.downsample.";
        const auto load = [&](const std::string& suffix) {
            return load_named_tensor(this->_base_weights_dir, ds_block_prefix + suffix, device);
        };

        auto conv_down = torch::nn::Conv2d(
            torch::nn::Conv2dOptions(this->inplanes, out_planes, 1).stride(stride_for_first_block).bias(false));
        auto bn_down = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(out_planes));

        conv_down->weight = load("0.weight");
        bn_down->weight = load("1.weight");
        bn_down->bias = load("1.bias");
        // Assign the running statistics to the module's own members:
        // named_buffers() returns its dictionary by value, so writing into
        // that temporary would silently discard the loaded tensors.
        bn_down->running_mean = load("1.running_mean");
        bn_down->running_var = load("1.running_var");

        torch::nn::Sequential ds_seq;
        ds_seq->push_back(conv_down);
        ds_seq->push_back(bn_down);
        downsample_module_for_block_opt = ds_seq;
    }

    // First block: carries the stage stride and the optional downsample shortcut.
    const std::string first_block_param_prefix = layer_param_prefix + "0.";
    layer_sequential->push_back(Bottleneck(this->_base_weights_dir, first_block_param_prefix,
                                           this->inplanes, planes_for_block, device,
                                           stride_for_first_block,
                                           downsample_module_for_block_opt, ResNetImpl::expansion));
    this->inplanes = out_planes;

    // Remaining blocks: stride 1, identity shortcut.
    for (int64_t i = 1; i < num_blocks; ++i) {
        const std::string current_block_param_prefix = layer_param_prefix + std::to_string(i) + ".";
        layer_sequential->push_back(Bottleneck(this->_base_weights_dir, current_block_param_prefix,
                                               this->inplanes, planes_for_block, device,
                                               1,
                                               std::nullopt, ResNetImpl::expansion));
    }
    return layer_sequential;
}
|
|
|
|
// Runs the network on `x` and returns the intermediate results whose names
// appear in _output_layers. Recognised names: "conv1_output", "bn1_output",
// "relu1_output", "maxpool_output", "layer1", "layer1_0_shortcut_output",
// "layer2", "layer3", "layer4" and "features" (same tensor as "layer4").
//
// "layer1_0_shortcut_output" is best-effort: if computing it throws, a warning
// is written to stderr and the entry is simply omitted from the result.
std::map<std::string, torch::Tensor> ResNetImpl::forward(torch::Tensor x) {
    std::map<std::string, torch::Tensor> outputs;

    const auto should_output = [this](const std::string& layer_name) {
        return std::find(_output_layers.begin(), _output_layers.end(), layer_name) != _output_layers.end();
    };
    // Stores `value` under `layer_name` when the caller requested that name.
    const auto record = [&](const std::string& layer_name, const torch::Tensor& value) {
        if (should_output(layer_name)) outputs[layer_name] = value;
    };

    // Stem. Each stage receives a clone of its input so that an in-place op
    // (the ReLU is constructed with inplace(true)) cannot alter a tensor that
    // was already stored in `outputs`.
    x = conv1->forward(x);
    record("conv1_output", x);

    x = bn1->forward(x.clone());
    record("bn1_output", x);

    x = relu->forward(x.clone());
    record("relu1_output", x);

    torch::Tensor x_pre_layer1 = maxpool->forward(x.clone());
    record("maxpool_output", x_pre_layer1);

    torch::Tensor x_after_layer1 = layer1->forward(x_pre_layer1.clone());
    record("layer1", x_after_layer1);

    // Optional debug output: the projection shortcut of layer1's first block,
    // evaluated on the same input that layer1 received.
    if (should_output("layer1_0_shortcut_output") && layer1 && !layer1->is_empty()) {
        try {
            std::shared_ptr<torch::nn::Module> first_block_module_ptr = layer1->ptr(0);
            auto first_block = std::dynamic_pointer_cast<cimp::resnet::BottleneckImpl>(first_block_module_ptr);

            if (first_block && first_block->projection_shortcut) {
                outputs["layer1_0_shortcut_output"] =
                    first_block->projection_shortcut->forward(x_pre_layer1.clone());
            }
        } catch (const std::exception& e) {
            // Keep the forward pass alive, but do not hide the failure.
            std::cerr << "WARNING: failed to compute layer1_0_shortcut_output: " << e.what() << std::endl;
        }
    }

    torch::Tensor x_current = x_after_layer1;

    x_current = layer2->forward(x_current.clone());
    record("layer2", x_current);

    x_current = layer3->forward(x_current.clone());
    record("layer3", x_current);

    x_current = layer4->forward(x_current.clone());
    record("layer4", x_current);

    // "features" is an alias for the layer4 output.
    record("features", x_current);

    return outputs;
}
|
|
|
|
// For ResNet-50, layers are [3, 4, 6, 3]
|
|
ResNet resnet50(const std::string& base_weights_dir,
|
|
const std::vector<std::string>& output_layers,
|
|
const torch::Device& device) {
|
|
return ResNet(ResNetImpl(base_weights_dir, {3, 4, 6, 3}, output_layers, device)); // Pass device
|
|
}
|
|
|
|
} // namespace resnet
|
|
} // namespace cimp
|