You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
417 lines
22 KiB
417 lines
22 KiB
#include "resnet.h"

#include <algorithm>   // For std::replace, std::find
#include <filesystem>  // For std::filesystem::path
#include <fstream>     // For std::ifstream
#include <iomanip>     // For std::fixed and std::setprecision
#include <iostream>    // For std::cout, std::cerr
#include <iterator>    // For std::istreambuf_iterator
#include <map>         // For std::map
#include <optional>    // For std::optional
#include <stdexcept>   // For std::runtime_error, std::invalid_argument
#include <string>      // For std::string, std::to_string
#include <typeinfo>    // For typeid
#include <vector>      // For std::vector

#include <torch/script.h>  // For torch::jit::load and torch::jit::Module
|
|
|
|
namespace cimp {
|
|
namespace resnet {
|
|
|
|
namespace fs = std::filesystem; // Moved fs namespace alias here
|
|
|
|
// Helper function to load a tensor by its parameter name (e.g., "conv1.weight")
|
|
// Assumes .pt files are named like "conv1.weight.pt", "layer1.0.bn1.running_mean.pt"
|
|
torch::Tensor load_named_tensor(const std::string& base_weights_dir, const std::string& param_name_original, const torch::Device& device) {
|
|
fs::path tensor_file_path = fs::path(base_weights_dir) / (param_name_original + ".pt");
|
|
|
|
if (!fs::exists(tensor_file_path)) {
|
|
std::string param_name_underscore = param_name_original;
|
|
std::replace(param_name_underscore.begin(), param_name_underscore.end(), '.', '_');
|
|
fs::path tensor_file_path_underscore = fs::path(base_weights_dir) / (param_name_underscore + ".pt");
|
|
|
|
if (fs::exists(tensor_file_path_underscore)) {
|
|
std::cout << "INFO: Using underscore-named file for C++ loading: " << tensor_file_path_underscore.string() << std::endl;
|
|
tensor_file_path = tensor_file_path_underscore;
|
|
} else {
|
|
throw std::runtime_error("Weight file not found (tried direct and underscore versions): " +
|
|
(fs::path(base_weights_dir) / (param_name_original + ".pt")).string() +
|
|
" and " + tensor_file_path_underscore.string());
|
|
}
|
|
}
|
|
|
|
std::cout << "Attempting direct torch::pickle_load for tensor: " << tensor_file_path.string() << std::endl;
|
|
try {
|
|
// Read the file into a vector<char>
|
|
std::ifstream file_stream(tensor_file_path.string(), std::ios::binary);
|
|
if (!file_stream) {
|
|
throw std::runtime_error("Failed to open file: " + tensor_file_path.string());
|
|
}
|
|
std::vector<char> file_buffer((std::istreambuf_iterator<char>(file_stream)),
|
|
std::istreambuf_iterator<char>());
|
|
file_stream.close();
|
|
|
|
c10::IValue ivalue = torch::pickle_load(file_buffer);
|
|
return ivalue.toTensor().to(device);
|
|
} catch (const c10::Error& e) {
|
|
std::cerr << "CRITICAL ERROR: torch::pickle_load FAILED for '" << tensor_file_path.string() << "'. Error: " << e.what() << std::endl;
|
|
throw;
|
|
}
|
|
}
|
|
|
|
// --- BottleneckImpl Method Definitions ---
|
|
|
|
// Constructor implementation for BottleneckImpl
|
|
// Signature must match resnet.h:
|
|
// BottleneckImpl(int64_t inplanes, int64_t planes, const std::string& weights_dir_prefix, int64_t stride = 1,
|
|
// std::optional<torch::nn::Sequential> downsample_module_opt = std::nullopt, int64_t expansion_factor_arg = 4);
|
|
BottleneckImpl::BottleneckImpl(const std::string& base_weights_dir,
|
|
const std::string& block_param_prefix,
|
|
int64_t inplanes, int64_t planes,
|
|
const torch::Device& device,
|
|
int64_t stride_param, std::optional<torch::nn::Sequential> downsample_module_opt,
|
|
int64_t expansion_factor_arg)
|
|
: expansion_factor(expansion_factor_arg), stride_member(stride_param) {
|
|
|
|
// conv1
|
|
conv1 = torch::nn::Conv2d(torch::nn::Conv2dOptions(inplanes, planes, 1).bias(false));
|
|
bn1 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(planes).eps(static_cast<float>(1e-5)).momentum(0.1).affine(true).track_running_stats(true));
|
|
conv1->weight = load_named_tensor(base_weights_dir, block_param_prefix + "conv1.weight", device);
|
|
bn1->weight = load_named_tensor(base_weights_dir, block_param_prefix + "bn1.weight", device);
|
|
bn1->bias = load_named_tensor(base_weights_dir, block_param_prefix + "bn1.bias", device);
|
|
bn1->running_mean = load_named_tensor(base_weights_dir, block_param_prefix + "bn1.running_mean", device);
|
|
bn1->running_var = load_named_tensor(base_weights_dir, block_param_prefix + "bn1.running_var", device);
|
|
bn1->num_batches_tracked = load_named_tensor(base_weights_dir, block_param_prefix + "bn1.num_batches_tracked", device);
|
|
register_module("conv1", conv1);
|
|
register_module("bn1", bn1);
|
|
|
|
// conv2
|
|
conv2 = torch::nn::Conv2d(torch::nn::Conv2dOptions(planes, planes, 3).stride(stride_member).padding(1).bias(false));
|
|
bn2 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(planes).eps(static_cast<float>(1e-5)).momentum(0.1).affine(true).track_running_stats(true));
|
|
conv2->weight = load_named_tensor(base_weights_dir, block_param_prefix + "conv2.weight", device);
|
|
bn2->weight = load_named_tensor(base_weights_dir, block_param_prefix + "bn2.weight", device);
|
|
bn2->bias = load_named_tensor(base_weights_dir, block_param_prefix + "bn2.bias", device);
|
|
bn2->running_mean = load_named_tensor(base_weights_dir, block_param_prefix + "bn2.running_mean", device);
|
|
bn2->running_var = load_named_tensor(base_weights_dir, block_param_prefix + "bn2.running_var", device);
|
|
bn2->num_batches_tracked = load_named_tensor(base_weights_dir, block_param_prefix + "bn2.num_batches_tracked", device);
|
|
register_module("conv2", conv2);
|
|
register_module("bn2", bn2);
|
|
|
|
// conv3
|
|
conv3 = torch::nn::Conv2d(torch::nn::Conv2dOptions(planes, planes * expansion_factor, 1).bias(false));
|
|
bn3 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(planes * expansion_factor).eps(static_cast<float>(1e-5)).momentum(0.1).affine(true).track_running_stats(true));
|
|
conv3->weight = load_named_tensor(base_weights_dir, block_param_prefix + "conv3.weight", device);
|
|
bn3->weight = load_named_tensor(base_weights_dir, block_param_prefix + "bn3.weight", device);
|
|
bn3->bias = load_named_tensor(base_weights_dir, block_param_prefix + "bn3.bias", device);
|
|
bn3->running_mean = load_named_tensor(base_weights_dir, block_param_prefix + "bn3.running_mean", device);
|
|
bn3->running_var = load_named_tensor(base_weights_dir, block_param_prefix + "bn3.running_var", device);
|
|
bn3->num_batches_tracked = load_named_tensor(base_weights_dir, block_param_prefix + "bn3.num_batches_tracked", device);
|
|
register_module("conv3", conv3);
|
|
register_module("bn3", bn3);
|
|
|
|
relu = torch::nn::ReLU(torch::nn::ReLUOptions(true));
|
|
register_module("relu", relu);
|
|
|
|
if (downsample_module_opt.has_value()) {
|
|
this->projection_shortcut = downsample_module_opt.value(); // Assign the passed Sequential module
|
|
|
|
// Weights for the submodules of projection_shortcut (conv & bn) are loaded by _make_layer
|
|
// before this module is passed. Here, we just register it.
|
|
register_module("projection_shortcut", this->projection_shortcut);
|
|
} else {
|
|
this->projection_shortcut = nullptr;
|
|
}
|
|
}
|
|
|
|
// Forward method implementation for BottleneckImpl
|
|
torch::Tensor BottleneckImpl::forward(torch::Tensor x) {
|
|
torch::Tensor identity = x;
|
|
torch::ScalarType original_dtype = x.scalar_type();
|
|
|
|
// conv1 -> bn1 -> relu
|
|
x = conv1->forward(x);
|
|
|
|
if (!this->is_training() && bn1) {
|
|
const auto& bn_module = *bn1;
|
|
torch::Tensor input_double = x.to(torch::kFloat64);
|
|
torch::Tensor weight_double = bn_module.weight.defined() ? bn_module.weight.to(torch::kFloat64) : torch::Tensor();
|
|
torch::Tensor bias_double = bn_module.bias.defined() ? bn_module.bias.to(torch::kFloat64) : torch::Tensor();
|
|
torch::Tensor running_mean_double = bn_module.running_mean.to(torch::kFloat64);
|
|
torch::Tensor running_var_double = bn_module.running_var.to(torch::kFloat64);
|
|
double eps_double = bn_module.options.eps();
|
|
|
|
auto c = x.size(1);
|
|
running_mean_double = running_mean_double.reshape({1, c, 1, 1});
|
|
running_var_double = running_var_double.reshape({1, c, 1, 1});
|
|
if (weight_double.defined()) weight_double = weight_double.reshape({1, c, 1, 1});
|
|
if (bias_double.defined()) bias_double = bias_double.reshape({1, c, 1, 1});
|
|
|
|
torch::Tensor out_double = (input_double - running_mean_double) / (torch::sqrt(running_var_double + eps_double));
|
|
if (weight_double.defined()) out_double = out_double * weight_double;
|
|
if (bias_double.defined()) out_double = out_double + bias_double;
|
|
x = out_double.to(original_dtype);
|
|
} else if (bn1) {
|
|
x = bn1->forward(x);
|
|
}
|
|
x = relu->forward(x);
|
|
|
|
// conv2 -> bn2 -> relu
|
|
x = conv2->forward(x);
|
|
if (!this->is_training() && bn2) {
|
|
const auto& bn_module = *bn2;
|
|
torch::Tensor input_double = x.to(torch::kFloat64);
|
|
torch::Tensor weight_double = bn_module.weight.defined() ? bn_module.weight.to(torch::kFloat64) : torch::Tensor();
|
|
torch::Tensor bias_double = bn_module.bias.defined() ? bn_module.bias.to(torch::kFloat64) : torch::Tensor();
|
|
torch::Tensor running_mean_double = bn_module.running_mean.to(torch::kFloat64);
|
|
torch::Tensor running_var_double = bn_module.running_var.to(torch::kFloat64);
|
|
double eps_double = bn_module.options.eps();
|
|
|
|
auto c = x.size(1);
|
|
running_mean_double = running_mean_double.reshape({1, c, 1, 1});
|
|
running_var_double = running_var_double.reshape({1, c, 1, 1});
|
|
if (weight_double.defined()) weight_double = weight_double.reshape({1, c, 1, 1});
|
|
if (bias_double.defined()) bias_double = bias_double.reshape({1, c, 1, 1});
|
|
|
|
torch::Tensor out_double = (input_double - running_mean_double) / (torch::sqrt(running_var_double + eps_double));
|
|
if (weight_double.defined()) out_double = out_double * weight_double;
|
|
if (bias_double.defined()) out_double = out_double + bias_double;
|
|
x = out_double.to(original_dtype);
|
|
} else if (bn2) {
|
|
x = bn2->forward(x);
|
|
}
|
|
x = relu->forward(x);
|
|
|
|
// conv3 -> bn3
|
|
x = conv3->forward(x);
|
|
if (!this->is_training() && bn3) {
|
|
const auto& bn_module = *bn3;
|
|
torch::Tensor input_double = x.to(torch::kFloat64);
|
|
torch::Tensor weight_double = bn_module.weight.defined() ? bn_module.weight.to(torch::kFloat64) : torch::Tensor();
|
|
torch::Tensor bias_double = bn_module.bias.defined() ? bn_module.bias.to(torch::kFloat64) : torch::Tensor();
|
|
torch::Tensor running_mean_double = bn_module.running_mean.to(torch::kFloat64);
|
|
torch::Tensor running_var_double = bn_module.running_var.to(torch::kFloat64);
|
|
double eps_double = bn_module.options.eps();
|
|
|
|
auto c = x.size(1);
|
|
running_mean_double = running_mean_double.reshape({1, c, 1, 1});
|
|
running_var_double = running_var_double.reshape({1, c, 1, 1});
|
|
if (weight_double.defined()) weight_double = weight_double.reshape({1, c, 1, 1});
|
|
if (bias_double.defined()) bias_double = bias_double.reshape({1, c, 1, 1});
|
|
|
|
torch::Tensor out_double = (input_double - running_mean_double) / (torch::sqrt(running_var_double + eps_double));
|
|
if (weight_double.defined()) out_double = out_double * weight_double;
|
|
if (bias_double.defined()) out_double = out_double + bias_double;
|
|
x = out_double.to(original_dtype);
|
|
} else if (bn3) {
|
|
x = bn3->forward(x);
|
|
}
|
|
|
|
if (this->projection_shortcut) {
|
|
identity = this->projection_shortcut->forward(identity);
|
|
}
|
|
|
|
x += identity;
|
|
x = relu->forward(x);
|
|
|
|
return x;
|
|
}
|
|
|
|
// --- ResNetImpl Method Definitions ---
|
|
// Builds the ResNet trunk and fills every parameter/buffer from per-tensor .pt
// files in `base_weights_dir_path` via load_named_tensor().
//
// Structure: conv1 7x7/2 (3 -> 64) -> bn1 -> in-place ReLU -> max-pool 3x3/2,
// then layer1..layer4 built by _make_layer() with widths 64/128/256/512 and
// strides 1/2/2/2.
//
// @param base_weights_dir_path Directory holding the per-tensor .pt files.
// @param layers_dims           Bottleneck block counts for layer1..layer4
//                              (e.g. {3, 4, 6, 3} for ResNet-50); only the
//                              first four entries are used.
// @param output_layers_param   Activation names forward() should return.
// @param device                Device the loaded tensors are placed on.
// @throws std::invalid_argument if layers_dims has fewer than four entries.
// @throws std::runtime_error / c10::Error propagated from load_named_tensor().
ResNetImpl::ResNetImpl(const std::string& base_weights_dir_path,
                       const std::vector<int64_t>& layers_dims,
                       const std::vector<std::string>& output_layers_param,
                       const torch::Device& device)
    : _output_layers(output_layers_param), _base_weights_dir(base_weights_dir_path) {
    // layers_dims[0..3] are read below; guard against out-of-bounds access.
    if (layers_dims.size() < 4) {
        throw std::invalid_argument("ResNetImpl: layers_dims must hold block counts for layer1..layer4 (4 entries), got " +
                                    std::to_string(layers_dims.size()));
    }

    // Replaces the contents of an already-registered parameter/buffer in place.
    // Plain assignment (`bn1->weight = t`) only rebinds the public member and leaves
    // the module's registered entry pointing at the random initialisation, so
    // parameters()/buffers()/to()/serialization would not see the loaded values.
    // set_data keeps member and registry on the same tensor. requires_grad is copied
    // from the loaded tensor so forward() builds the same autograd graph as before.
    auto load_into = [&](torch::Tensor& target, const std::string& name) {
        torch::Tensor loaded = load_named_tensor(this->_base_weights_dir, name, device);
        target.set_data(loaded);
        target.set_requires_grad(loaded.requires_grad());
    };

    // Stem
    conv1 = torch::nn::Conv2d(torch::nn::Conv2dOptions(3, 64, 7).stride(2).padding(3).bias(false));
    bn1 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(64).eps(static_cast<float>(1e-5)).momentum(0.1).affine(true).track_running_stats(true));
    load_into(this->conv1->weight, "conv1.weight");
    load_into(this->bn1->weight, "bn1.weight");
    load_into(this->bn1->bias, "bn1.bias");
    load_into(this->bn1->running_mean, "bn1.running_mean");
    load_into(this->bn1->running_var, "bn1.running_var");
    load_into(this->bn1->num_batches_tracked, "bn1.num_batches_tracked");
    register_module("conv1", conv1);
    register_module("bn1", bn1);

    relu = torch::nn::ReLU(torch::nn::ReLUOptions().inplace(true));
    maxpool = torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(3).stride(2).padding(1));
    register_module("relu", relu);
    register_module("maxpool", maxpool);

    // The stem emits 64 channels. _make_layer() advances `inplanes` as it builds
    // each stage, so start from the stem width explicitly.
    this->inplanes = 64;

    layer1 = _make_layer(64, layers_dims[0], "layer1.", device);
    layer2 = _make_layer(128, layers_dims[1], "layer2.", device, 2);
    layer3 = _make_layer(256, layers_dims[2], "layer3.", device, 2);
    layer4 = _make_layer(512, layers_dims[3], "layer4.", device, 2);
    register_module("layer1", layer1);
    register_module("layer2", layer2);
    register_module("layer3", layer3);
    register_module("layer4", layer4);
}
|
|
|
|
// Builds one ResNet stage ("layerN") as a Sequential of Bottleneck blocks.
//
// The first block uses `stride_for_first_block`. When that stride is not 1, or
// when this->inplanes differs from planes_for_block * expansion, the first block
// also gets a projection shortcut: 1x1 conv (stride = stride_for_first_block) + BN.
// Its weights are loaded from "<layer_param_prefix>0.downsample.0.weight" and
// "<layer_param_prefix>0.downsample.1.{weight,bias,running_mean,running_var,num_batches_tracked}".
// Blocks 1..num_blocks-1 use stride 1 and no projection.
//
// Side effect: sets this->inplanes to planes_for_block * expansion.
// Note: the first block is always created, even if num_blocks < 1.
//
// @param planes_for_block       Bottleneck width of every block in the stage.
// @param num_blocks             Number of blocks in the stage.
// @param layer_param_prefix     Weight-name prefix including the trailing '.', e.g. "layer2.".
// @param device                 Device the loaded tensors are placed on.
// @param stride_for_first_block Stride of the first block's 3x3 conv and projection.
// @return The assembled stage.
torch::nn::Sequential ResNetImpl::_make_layer(int64_t planes_for_block, int64_t num_blocks,
                                              const std::string& layer_param_prefix,
                                              const torch::Device& device,
                                              int64_t stride_for_first_block) {
    // Replaces the contents of an already-registered parameter/buffer in place.
    // Plain assignment only rebinds the public member and leaves the module's
    // registered entry on the random initialisation. set_data keeps both handles on
    // the same tensor, and requires_grad is copied from the loaded tensor.
    auto load_into = [&](torch::Tensor& target, const std::string& name) {
        torch::Tensor loaded = load_named_tensor(this->_base_weights_dir, name, device);
        target.set_data(loaded);
        target.set_requires_grad(loaded.requires_grad());
    };

    torch::nn::Sequential layer_sequential;
    std::optional<torch::nn::Sequential> downsample_module_for_block_opt = std::nullopt;

    const int64_t out_channels = planes_for_block * ResNetImpl::expansion;
    if (stride_for_first_block != 1 || this->inplanes != out_channels) {
        torch::nn::Sequential ds_seq;
        auto conv_down = torch::nn::Conv2d(torch::nn::Conv2dOptions(this->inplanes, out_channels, 1).stride(stride_for_first_block).bias(false));
        auto bn_down = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(out_channels).eps(static_cast<float>(1e-5)).momentum(0.1).affine(true).track_running_stats(true));

        const std::string ds_block_prefix = layer_param_prefix + "0.downsample.";
        load_into(conv_down->weight, ds_block_prefix + "0.weight");
        load_into(bn_down->weight, ds_block_prefix + "1.weight");
        load_into(bn_down->bias, ds_block_prefix + "1.bias");
        load_into(bn_down->running_mean, ds_block_prefix + "1.running_mean");
        load_into(bn_down->running_var, ds_block_prefix + "1.running_var");
        load_into(bn_down->num_batches_tracked, ds_block_prefix + "1.num_batches_tracked");

        ds_seq->push_back(conv_down);
        ds_seq->push_back(bn_down);
        downsample_module_for_block_opt = ds_seq;
    }

    const std::string first_block_param_prefix = layer_param_prefix + "0.";
    layer_sequential->push_back(Bottleneck(this->_base_weights_dir, first_block_param_prefix,
                                           this->inplanes, planes_for_block, device,
                                           stride_for_first_block,
                                           downsample_module_for_block_opt, ResNetImpl::expansion));
    this->inplanes = out_channels;

    for (int64_t i = 1; i < num_blocks; ++i) {
        const std::string current_block_param_prefix = layer_param_prefix + std::to_string(i) + ".";
        layer_sequential->push_back(Bottleneck(this->_base_weights_dir, current_block_param_prefix,
                                               this->inplanes, planes_for_block, device,
                                               1,
                                               std::nullopt, ResNetImpl::expansion));
    }
    return layer_sequential;
}
|
|
|
|
// Runs the ResNet trunk and returns the activations named in _output_layers.
//
// Keys that can be populated (each only if listed in _output_layers):
//   "conv1_output", "debug_resnet_conv1_output_for_bn1_input" (clone),
//   "bn1_centered_x", "bn1_variance_plus_eps", "bn1_inv_std", "bn1_normalized_x"
//     (float64 clones of the eval-mode manual bn1 intermediates),
//   "bn1_output" (pre-ReLU), "relu1_output", "maxpool_output",
//   "layer1_0_block_output", "layer1_0_shortcut_output",
//   "layer1", "layer2", "layer3", "layer4", "features" (same tensor as "layer4").
//
// In eval mode bn1 is computed manually in float64 from its running statistics
// and (optional) affine parameters, then cast back to the conv1 output dtype;
// in training mode bn1->forward() is used.
// "layer1_0_block_output" / "layer1_0_shortcut_output" re-run layer1's first block
// or its projection shortcut on the max-pool output. If that fails, the error
// is logged to stderr and the key is omitted.
std::map<std::string, torch::Tensor> ResNetImpl::forward(torch::Tensor x) {
    std::map<std::string, torch::Tensor> outputs;

    // True when the caller requested `layer_name` via _output_layers.
    auto should_output = [&](const std::string& layer_name) {
        return std::find(_output_layers.begin(), _output_layers.end(), layer_name) != _output_layers.end();
    };

    // --- Stem: conv1 -> bn1 -> relu -> maxpool ---
    x = conv1->forward(x);
    if (should_output("conv1_output")) outputs["conv1_output"] = x;
    if (should_output("debug_resnet_conv1_output_for_bn1_input")) {
        outputs["debug_resnet_conv1_output_for_bn1_input"] = x.clone();
    }
    const torch::ScalarType original_dtype_resnet_bn1 = x.scalar_type();

    if (!this->is_training() && bn1) {
        const auto& bn_module = *bn1;
        torch::Tensor input_double = x.to(torch::kFloat64);
        torch::Tensor weight_double = bn_module.weight.defined() ? bn_module.weight.to(torch::kFloat64) : torch::Tensor();
        torch::Tensor bias_double = bn_module.bias.defined() ? bn_module.bias.to(torch::kFloat64) : torch::Tensor();
        torch::Tensor running_mean_double = bn_module.running_mean.to(torch::kFloat64);
        torch::Tensor running_var_double = bn_module.running_var.to(torch::kFloat64);
        double eps_double = bn_module.options.eps();

        // Shape per-channel vectors to broadcast over dim 1 of the input.
        auto c = x.size(1);
        torch::Tensor reshaped_running_mean = running_mean_double.reshape({1, c, 1, 1});
        torch::Tensor reshaped_running_var = running_var_double.reshape({1, c, 1, 1});
        torch::Tensor reshaped_weight = weight_double.defined() ? weight_double.reshape({1, c, 1, 1}) : torch::Tensor();
        torch::Tensor reshaped_bias = bias_double.defined() ? bias_double.reshape({1, c, 1, 1}) : torch::Tensor();

        torch::Tensor centered_x = input_double - reshaped_running_mean;
        if (should_output("bn1_centered_x")) outputs["bn1_centered_x"] = centered_x.clone();

        torch::Tensor variance_plus_eps = reshaped_running_var + eps_double;
        if (should_output("bn1_variance_plus_eps")) outputs["bn1_variance_plus_eps"] = variance_plus_eps.clone();

        torch::Tensor inv_std = torch::rsqrt(variance_plus_eps);  // Using rsqrt for potential match
        if (should_output("bn1_inv_std")) outputs["bn1_inv_std"] = inv_std.clone();

        torch::Tensor normalized_x = centered_x * inv_std;
        if (should_output("bn1_normalized_x")) outputs["bn1_normalized_x"] = normalized_x.clone();

        torch::Tensor out_double = normalized_x;
        if (reshaped_weight.defined()) out_double = out_double * reshaped_weight;
        if (reshaped_bias.defined()) out_double = out_double + reshaped_bias;

        x = out_double.to(original_dtype_resnet_bn1);
    } else if (bn1) {  // Training mode: use the module's own BatchNorm.
        x = bn1->forward(x);
    }

    // `relu` was constructed with inplace(true) and overwrites `x`; store a copy so
    // "bn1_output" really holds the pre-activation values.
    if (should_output("bn1_output")) outputs["bn1_output"] = x.clone();

    x = relu->forward(x);
    if (should_output("relu1_output")) outputs["relu1_output"] = x;

    torch::Tensor x_pre_layer1 = maxpool->forward(x);
    if (should_output("maxpool_output")) outputs["maxpool_output"] = x_pre_layer1;

    // Optional debug output: layer1's first block evaluated on its own.
    if (should_output("layer1_0_block_output")) {
        if (layer1 && !layer1->is_empty()) {
            try {
                std::shared_ptr<torch::nn::Module> base_module_ptr = layer1->ptr(0);
                auto bottleneck_impl_ptr = std::dynamic_pointer_cast<cimp::resnet::BottleneckImpl>(base_module_ptr);

                if (bottleneck_impl_ptr) {
                    outputs["layer1_0_block_output"] = bottleneck_impl_ptr->forward(x_pre_layer1);
                } else {
                    std::cerr << "ERROR: layer1->ptr(0) could not be dynamically cast to BottleneckImpl! Module type: "
                              << (base_module_ptr ? typeid(*base_module_ptr).name() : "null") << std::endl;
                }
            } catch (const std::exception& e) {
                std::cerr << "EXCEPTION while getting layer1_0_block_output: " << e.what() << std::endl;
            }
        }
    }

    torch::Tensor x_after_layer1 = layer1->forward(x_pre_layer1);
    if (should_output("layer1")) outputs["layer1"] = x_after_layer1;

    // Optional debug output: the projection shortcut of layer1's first block.
    if (should_output("layer1_0_shortcut_output")) {
        if (layer1 && !layer1->is_empty()) {
            try {
                std::shared_ptr<torch::nn::Module> first_block_module_ptr = layer1->ptr(0);
                auto bottleneck_module_holder = std::dynamic_pointer_cast<cimp::resnet::BottleneckImpl>(first_block_module_ptr);

                if (bottleneck_module_holder && bottleneck_module_holder->projection_shortcut) {
                    outputs["layer1_0_shortcut_output"] =
                        bottleneck_module_holder->projection_shortcut->forward(x_pre_layer1);
                }
            } catch (const std::exception& e) {
                // Previously swallowed silently; report it like the block-output path does.
                std::cerr << "EXCEPTION while getting layer1_0_shortcut_output: " << e.what() << std::endl;
            }
        }
    }

    torch::Tensor x_current = x_after_layer1;

    x_current = layer2->forward(x_current);
    if (should_output("layer2")) outputs["layer2"] = x_current;

    x_current = layer3->forward(x_current);
    if (should_output("layer3")) outputs["layer3"] = x_current;

    x_current = layer4->forward(x_current);
    if (should_output("layer4")) outputs["layer4"] = x_current;

    if (should_output("features")) outputs["features"] = x_current;

    return outputs;
}
|
|
|
|
// For ResNet-50, layers are [3, 4, 6, 3]
|
|
ResNet resnet50(const std::string& base_weights_dir,
|
|
const std::vector<std::string>& output_layers,
|
|
const torch::Device& device) {
|
|
return ResNet(ResNetImpl(base_weights_dir, {3, 4, 6, 3}, output_layers, device)); // Pass device
|
|
}
|
|
|
|
} // namespace resnet
|
|
} // namespace cimp
|