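// File-based model test harness: for each sample, load (or generate) the
// input tensors, run the ResNet-50 backbone, the Classifier, and the
// BBRegressor, and save every intermediate output to disk so that a Python
// reference run can be compared tensor-by-tensor.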
#include <iostream>
#include <vector>
#include <string>
#include <map>
#include <optional>
#include <filesystem> // For directory creation and path manipulation
#include <torch/torch.h>
#include <torch/script.h> // For torch::save and torch::load
#include <ATen/Context.h> // Required for globalContext
#include <algorithm> // For std::find
// Project headers
#include "../cimp/resnet/resnet.h"
#include "../cimp/classifier/classifier.h"
#include "../cimp/bb_regressor/bb_regressor.h"
namespace fs = std::filesystem;
// Helper to ensure a directory exists
void ensure_dir_exists(const std::string& path_str) {
    fs::path dir_path(path_str);
    if (!fs::exists(dir_path)) {
        if (fs::create_directories(dir_path)) {
            std::cout << "Created directory: " << dir_path.string() << std::endl;
        } else {
            std::cerr << "Error: Could not create directory: " << dir_path.string() << std::endl;
            // Optionally throw an exception or exit
        }
    }
}
// Helper to load a tensor. Tries a plain torch::save'd tensor first, then
// falls back to a TorchScript module holding a single "tensor" parameter or
// buffer (the TensorContainer convention used on the Python side).
torch::Tensor load_tensor_from_file(const std::string& file_path, torch::Device device) {
    std::cout << "  Loading tensor from: " << file_path << std::endl;
    torch::Tensor tensor;
    try {
        torch::load(tensor, file_path);
        return tensor.to(device);
    } catch (const c10::Error& e) {
        std::cerr << "  Direct torch::load failed, trying as JIT module: " << e.what() << std::endl;
        try {
            auto module = torch::jit::load(file_path, device);
            // Try to extract a parameter named 'tensor' (as in TensorContainer)
            for (const auto& p : module.named_parameters()) {
                if (p.name == "tensor") {
                    std::cout << "  Extracted 'tensor' parameter from JIT module." << std::endl;
                    return p.value.to(device);
                }
            }
            // Fall back to buffers if not found among the parameters
            for (const auto& b : module.named_buffers()) {
                if (b.name == "tensor") {
                    std::cout << "  Extracted 'tensor' buffer from JIT module." << std::endl;
                    return b.value.to(device);
                }
            }
            throw std::runtime_error("No 'tensor' parameter or buffer found in JIT module: " + file_path);
        } catch (const c10::Error& e2) {
            std::cerr << "  Failed to load as JIT module: " << e2.what() << std::endl;
            throw;
        }
    }
}
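// Note: this loader accepts both on-disk formats used by this harness:
// plain tensors written with torch::save (see save_tensor_to_file below)
// and JIT-wrapped TensorContainer modules (see save_tensor_as_jit_module).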
// Helper to save a tensor (moved to CPU so the file is device-independent)
void save_tensor_to_file(const torch::Tensor& tensor, const std::string& file_path) {
    try {
        torch::save(tensor.cpu(), file_path);
    } catch (const c10::Error& e) {
        std::cerr << "Error saving tensor to " << file_path << ": " << e.what() << std::endl;
    }
}
// Helper to load a tensor from disk, or generate (and persist) a random
// dummy tensor of the given shape when the file is missing or
// force_generate is set.
torch::Tensor get_or_generate_tensor(
        const std::string& file_path,
        const std::vector<int64_t>& shape,
        torch::Device device,
        bool force_generate = false) {
    if (!force_generate && fs::exists(file_path)) {
        return load_tensor_from_file(file_path, device);
    } else {
        std::cout << "  Generating dummy tensor for: " << file_path << " with shape [";
        for (size_t i = 0; i < shape.size(); ++i)
            std::cout << shape[i] << (i == shape.size() - 1 ? "" : ", ");
        std::cout << "]" << std::endl;
        auto tensor = torch::randn(shape, torch::TensorOptions().device(device).dtype(torch::kFloat32));
        save_tensor_to_file(tensor, file_path); // Persist the generated tensor for future runs
        return tensor;
    }
}
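// Note: generated dummies come from torch::randn, so they are not valid
// images; that is acceptable here, since the goal is numeric parity between
// the C++ and Python pipelines on identical inputs, not semantic correctness.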
// Helper to save a tensor as a JIT module with a single "tensor" parameter
// (matching the Python-side TensorContainer convention).
void save_tensor_as_jit_module(const torch::Tensor& tensor, const std::string& file_path) {
    // Build a bare TorchScript module and attach the tensor as a parameter.
    // (The QualifiedName construction is explicit; a bare string literal does
    // not convert implicitly.)
    torch::jit::Module scripted(c10::QualifiedName("TensorContainer"));
    scripted.register_parameter("tensor", tensor.clone(), /*is_buffer=*/false);
    try {
        scripted.save(file_path);
    } catch (const c10::Error& e) {
        std::cerr << "Error saving JIT-wrapped tensor to " << file_path << ": " << e.what() << std::endl;
    }
}
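// Usage sketch for the two helpers above (paths are illustrative only):
//   torch::Tensor t = torch::randn({1, 4});
//   save_tensor_as_jit_module(t, "/tmp/bb.pt");
//   torch::Tensor r = load_tensor_from_file("/tmp/bb.pt", torch::Device(torch::kCPU));
//   // r should equal t: the loader finds the "tensor" parameter fallback.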
int main(int argc, char* argv[]) {
    if (argc < 4) {
        // Usage message kept consistent with how run_tests.sh invokes this binary
        std::cerr << "Usage: " << argv[0] << " <base_input_dir_path> <base_output_dir_path> <num_samples> [model_to_test]" << std::endl;
        return 1;
    }

    // --- Set global PyTorch/cuDNN flags for reproducible runs ---
    at::globalContext().setDeterministicCuDNN(true);
    at::globalContext().setBenchmarkCuDNN(false); // Explicitly disable autotuning
    std::cout << "C++: Set globalContext DeterministicCuDNN=true, BenchmarkCuDNN=false" << std::endl;

    // Argument parsing
    std::string base_input_dir = argv[1];
    std::string base_output_dir = argv[2];
    int num_samples = 0;
    try {
        num_samples = std::stoi(argv[3]);
    } catch (const std::exception& e) {
        std::cerr << "Error: Invalid number of samples in argv[3]='" << argv[3] << "': " << e.what() << std::endl;
        return 1;
    }
    std::string model_to_test = "all"; // Note: parsed but not otherwise used here; all initialized models run
    if (argc > 4) {
        model_to_test = argv[4];
    }
    if (num_samples <= 0) {
        std::cout << "Number of samples is " << num_samples << ". Exiting." << std::endl;
        return 0;
    }
    std::cout << "Running for " << num_samples << " samples." << std::endl;
    std::cout << "Input directory: " << base_input_dir << std::endl;
    std::cout << "Output directory: " << base_output_dir << std::endl;

    torch::DeviceType device_type;
    if (torch::cuda::is_available()) {
        std::cout << "CUDA available! Running on GPU." << std::endl;
        device_type = torch::kCUDA;
    } else {
        std::cout << "CUDA not available. Running on CPU." << std::endl;
        device_type = torch::kCPU;
    }
    torch::Device device(device_type);
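    // With DeterministicCuDNN=true and BenchmarkCuDNN=false, cuDNN selects
    // deterministic convolution algorithms rather than autotuned ones. This
    // trades some speed for run-to-run (and C++-vs-Python) reproducibility,
    // which is the point of this comparison harness.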
    // --- Initialize Models (once) ---
    std::cout << "--- Initializing Models ---" << std::endl;
    std::string resnet_weights_path = "exported_weights/backbone_regenerated";
    std::string bb_reg_weights_path = "exported_weights/bb_regressor";
    std::string classifier_weights_path = "exported_weights/classifier";

    std::optional<cimp::resnet::ResNet> resnet_model_opt;
    // Use TORCH_MODULE wrapper types directly
    std::optional<Classifier> classifier_model_opt_wrapped;
    std::optional<BBRegressor> bb_regressor_model_opt_wrapped;
    bool models_initialized_ok = true;

    try {
        std::vector<std::string> output_layers_resnet = {
            "conv1_output", "bn1_output", "relu1_output", "maxpool_output",
            "layer1", "layer2", "layer3", "layer4", "features",
            "layer1_0_shortcut_output", "layer1_0_block_output",
            "debug_resnet_conv1_output_for_bn1_input",
            // BN1 intermediate outputs
            "bn1_centered_x", "bn1_variance_plus_eps", "bn1_inv_std", "bn1_normalized_x"
        };
        resnet_model_opt.emplace(cimp::resnet::resnet50(resnet_weights_path, output_layers_resnet, device));
        (*resnet_model_opt)->to(device);
        (*resnet_model_opt)->eval();
        std::cout << "ResNet-50 initialized successfully." << std::endl;
    } catch (const std::exception& e) {
        std::cerr << "CRITICAL: Error initializing ResNet-50: " << e.what() << std::endl;
        models_initialized_ok = false;
    }

    if (models_initialized_ok) {
        try {
            std::cout << "Attempting to initialize BBRegressor with weights from: " << bb_reg_weights_path << std::endl;
            // BBRegressor lives in the global namespace
            bb_regressor_model_opt_wrapped.emplace(BBRegressor(bb_reg_weights_path, device));
            (*bb_regressor_model_opt_wrapped).to(device);
            (*bb_regressor_model_opt_wrapped).eval();
            std::cout << "BBRegressor initialized successfully." << std::endl;
        } catch (const std::exception& e) {
            std::cerr << "CRITICAL: Error initializing BBRegressor: " << e.what() << std::endl;
            models_initialized_ok = false;
        }
    }

    if (models_initialized_ok) {
        try {
            std::cout << "Attempting to initialize Classifier with weights from: " << classifier_weights_path << std::endl;
            // Classifier lives in the global namespace
            classifier_model_opt_wrapped.emplace(Classifier(classifier_weights_path, device));
            (*classifier_model_opt_wrapped).to(device);
            // Note: the Classifier class does not expose an eval() method
            std::cout << "Classifier initialized successfully." << std::endl;
        } catch (const std::exception& e) {
            std::cerr << "CRITICAL: Error initializing Classifier: " << e.what() << std::endl;
            models_initialized_ok = false;
        }
    }

    if (!models_initialized_ok) {
        std::cerr << "Model initialization failed. Aborting tests." << std::endl;
        return 1;
    }
    std::cout << "All models initialized." << std::endl;
    // --- Define and Announce Output Directories Before Loop ---
    fs::path resnet_out_dir_main = fs::path(base_output_dir) / "resnet";
    fs::path clf_out_dir_main = fs::path(base_output_dir) / "classifier";
    fs::path bb_reg_out_dir_main = fs::path(base_output_dir) / "bb_regressor";
    std::cout << "C++ Model Output Directories Check:" << std::endl;
    std::cout << "  Expected ResNet output: " << resnet_out_dir_main.string() << std::endl;
    std::cout << "  Expected Classifier output: " << clf_out_dir_main.string() << std::endl;
    std::cout << "  Expected BB Regressor output: " << bb_reg_out_dir_main.string() << std::endl;
    // --- Main Loop over Samples ---
    for (int i = 0; i < num_samples; ++i) {
        std::cout << "--- Processing Sample " << i << " ---" << std::endl;
        std::string sample_suffix = "sample_" + std::to_string(i);

        // Define input file paths
        fs::path common_input_path = fs::path(base_input_dir) / "common";
        ensure_dir_exists(common_input_path.string());
        std::string image_file = (common_input_path / (sample_suffix + "_image.pt")).string();
        std::string bb_file = (common_input_path / (sample_suffix + "_bb.pt")).string();
        std::string proposals_file = (common_input_path / (sample_suffix + "_proposals.pt")).string();

        // Define output directories
        fs::path resnet_out_dir = fs::path(base_output_dir) / "resnet";
        fs::path clf_out_dir = fs::path(base_output_dir) / "classifier";
        fs::path bb_reg_out_dir = fs::path(base_output_dir) / "bb_regressor";
        ensure_dir_exists(resnet_out_dir.string());
        ensure_dir_exists(clf_out_dir.string());
        ensure_dir_exists(bb_reg_out_dir.string());

        // Inputs are loaded from disk when present; otherwise random dummy
        // tensors are generated and persisted for reuse.
        torch::Tensor image_tensor = get_or_generate_tensor(image_file, {1, 3, 256, 256}, device);
        torch::Tensor bb_tensor = get_or_generate_tensor(bb_file, {1, 4}, device);
        torch::Tensor proposals_tensor = get_or_generate_tensor(proposals_file, {1, 10, 4}, device);
        // --- Normalize the image tensor (standard ImageNet preprocessing) ---
        if (image_tensor.dim() == 3) { // Add a batch dimension if not present
            image_tensor = image_tensor.unsqueeze(0);
        }
        image_tensor = image_tensor / 255.0f; // Scale pixel values to [0, 1]
        torch::Tensor mean = torch::tensor({0.485, 0.456, 0.406}, image_tensor.options()).view({1, -1, 1, 1});
        torch::Tensor std_dev = torch::tensor({0.229, 0.224, 0.225}, image_tensor.options()).view({1, -1, 1, 1});
        image_tensor.sub_(mean).div_(std_dev); // In-place (x - mean) / std, per channel

        // Save the C++-side preprocessed input tensor for cross-checking
        fs::path cpp_preprocessed_save_path = resnet_out_dir / (sample_suffix + "_image_preprocessed_cpp.pt");
        save_tensor_to_file(image_tensor, cpp_preprocessed_save_path.string());
        std::cout << "Saved C++ preprocessed image for sample " << i << " to " << cpp_preprocessed_save_path.string() << std::endl;
        // For sample 0, re-save proposals and bbox as JIT-wrapped modules
        // (TensorContainer convention) so the Python side can load them too.
        if (i == 0) {
            if (fs::exists(proposals_file)) {
                auto proposals_reloaded = load_tensor_from_file(proposals_file, device);
                save_tensor_as_jit_module(proposals_reloaded, proposals_file);
                std::cout << "Saved proposals as JIT-wrapped module for sample 0." << std::endl;
            }
            if (fs::exists(bb_file)) {
                auto bb_reloaded = load_tensor_from_file(bb_file, device);
                save_tensor_as_jit_module(bb_reloaded, bb_file);
                std::cout << "Saved bbox as JIT-wrapped module for sample 0." << std::endl;
            }
        }
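        // Note: this overwrites the original input files in place with the
        // JIT-wrapped format; load_tensor_from_file can still read them on
        // later runs via its TensorContainer fallback.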
        // 1. ResNet Processing
        std::map<std::string, torch::Tensor> resnet_outputs;
        if (resnet_model_opt.has_value()) {
            try {
                std::cout << "Processing ResNet for sample " << i << std::endl;
                resnet_outputs = (*resnet_model_opt)->forward(image_tensor);
                // Save every requested output layer (see output_layers_resnet above)
                for (const auto& pair : resnet_outputs) {
                    std::string filename_key = pair.first;
                    std::replace(filename_key.begin(), filename_key.end(), '.', '_'); // Dots are not allowed in filename keys
                    fs::path output_path = resnet_out_dir / (sample_suffix + "_" + filename_key + ".pt");
                    save_tensor_to_file(pair.second, output_path.string());
                    std::cout << "  Saving tensor to: " << output_path.string() << " shape: " << pair.second.sizes() << std::endl;
                }
                std::cout << "ResNet processing done for sample " << i << std::endl;
            } catch (const std::exception& e) {
                std::cerr << "Error during ResNet processing for sample " << i << ": " << e.what() << std::endl;
                // Fall through: downstream stages check for the outputs they need.
            }
        } else {
            std::cerr << "Skipping ResNet processing for sample " << i << " as model was not initialized." << std::endl;
            // resnet_outputs stays empty; downstream models must handle this.
        }
        // 2. Classifier Processing (consumes ResNet layer3)
        if (classifier_model_opt_wrapped.has_value() && resnet_outputs.count("layer3")) {
            try {
                std::cout << "Processing Classifier for sample " << i << std::endl;
                auto clf_input = resnet_outputs["layer3"].clone();
                torch::Tensor clf_features = (*classifier_model_opt_wrapped).extract_features(clf_input);
                save_tensor_to_file(clf_features, (clf_out_dir / (sample_suffix + "_features.pt")).string());
                std::cout << "Classifier processing done for sample " << i << std::endl;
            } catch (const std::exception& e) {
                std::cerr << "Error during Classifier processing for sample " << i << ": " << e.what() << std::endl;
            }
        } else {
            std::cerr << "Skipping Classifier for sample " << i << " (model not initialized or ResNet layer3 output missing)." << std::endl;
        }
        // 3. BBRegressor Processing (consumes ResNet layer2 and layer3)
        if (bb_regressor_model_opt_wrapped.has_value() && resnet_outputs.count("layer2") && resnet_outputs.count("layer3")) {
            try {
                std::cout << "Processing BBRegressor for sample " << i << std::endl;
                std::vector<torch::Tensor> backbone_feats_for_bb = {
                    resnet_outputs["layer2"].clone(),
                    resnet_outputs["layer3"].clone()
                };
                // IoU features from the backbone feature maps
                std::vector<torch::Tensor> iou_feats = (*bb_regressor_model_opt_wrapped).get_iou_feat(backbone_feats_for_bb);
                if (iou_feats.size() >= 1) save_tensor_to_file(iou_feats[0], (bb_reg_out_dir / (sample_suffix + "_iou_feat0.pt")).string());
                if (iou_feats.size() >= 2) save_tensor_to_file(iou_feats[1], (bb_reg_out_dir / (sample_suffix + "_iou_feat1.pt")).string());
                // Modulation vectors conditioned on the reference bounding box
                std::vector<torch::Tensor> mod_vectors = (*bb_regressor_model_opt_wrapped).get_modulation(backbone_feats_for_bb, bb_tensor);
                if (mod_vectors.size() >= 1) save_tensor_to_file(mod_vectors[0], (bb_reg_out_dir / (sample_suffix + "_mod_vec0.pt")).string());
                if (mod_vectors.size() >= 2) save_tensor_to_file(mod_vectors[1], (bb_reg_out_dir / (sample_suffix + "_mod_vec1.pt")).string());
                // IoU prediction for the candidate proposals
                if (!iou_feats.empty() && !mod_vectors.empty()) {
                    torch::Tensor iou_scores = (*bb_regressor_model_opt_wrapped).predict_iou(mod_vectors, iou_feats, proposals_tensor);
                    save_tensor_to_file(iou_scores, (bb_reg_out_dir / (sample_suffix + "_iou_scores.pt")).string());
                } else {
                    std::cerr << "  Skipping BBRegressor predict_iou for sample " << i << " (iou_feats or mod_vectors empty)." << std::endl;
                }
                std::cout << "BBRegressor processing done for sample " << i << std::endl;
            } catch (const std::exception& e) {
                std::cerr << "Error during BBRegressor processing for sample " << i << ": " << e.what() << std::endl;
            }
        } else {
            std::cerr << "Skipping BBRegressor for sample " << i << " (model not initialized or ResNet outputs missing)." << std::endl;
        }
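        // Per sample, this stage emits up to five files: _iou_feat0/1.pt,
        // _mod_vec0/1.pt, and _iou_scores.pt.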
    }

    std::cout << "File-based test models run completed for " << num_samples << " samples." << std::endl;
    return 0;
}