
Refactor: Remove double precision in BBRegressor LinearBlock

mht committed 1 week ago
parent commit c4156feccc

cimp/bb_regressor/bb_regressor.cpp (430 changed lines)

@@ -30,37 +30,42 @@ torch::Tensor PrRoIPool2D::forward(torch::Tensor feat, torch::Tensor rois) {
int channels = feat.size(1);
int num_rois = rois.size(0);
// Ensure both tensors are on CUDA
// Ensure both tensors are on CUDA initially (as they come from GPU operations)
if (!feat.is_cuda() || !rois.is_cuda()) {
throw std::runtime_error("PrRoIPool2D requires CUDA tensors - CPU mode is not supported");
// This case should ideally not happen if inputs are from CUDA model parts
// but if it does, move them to CUDA first for consistency, then to CPU for the C function
std::cout << "Warning: PrRoIPool2D received non-CUDA tensor(s). Moving to CUDA then CPU." << std::endl;
feat = feat.to(torch::kCUDA);
rois = rois.to(torch::kCUDA);
}
// Print ROI values for debugging
std::cout << " ROI values: " << std::endl;
std::cout << " ROI values (on device " << rois.device() << "): " << std::endl;
auto rois_cpu_for_print = rois.to(torch::kCPU).contiguous(); // Temp CPU copy for printing
for (int i = 0; i < std::min(num_rois, 3); i++) {
std::cout << " ROI " << i << ": [";
for (int j = 0; j < rois.size(1); j++) {
std::cout << rois[i][j].item<float>();
if (j < rois.size(1) - 1) std::cout << ", ";
for (int j = 0; j < rois_cpu_for_print.size(1); j++) {
std::cout << rois_cpu_for_print[i][j].item<float>();
if (j < rois_cpu_for_print.size(1) - 1) std::cout << ", ";
}
std::cout << "]" << std::endl;
}
// Create output tensor on the same device
// Create output tensor on the same original device as feat (CUDA)
auto output = torch::zeros({num_rois, channels, pooled_height_, pooled_width_},
feat.options());
// Copy tensors to CPU for the C implementation
// REVERTED: Copy tensors to CPU for the C implementation, as prroi_pooling_forward_cuda expects CPU pointers
auto feat_cpu = feat.to(torch::kCPU).contiguous();
auto rois_cpu = rois.to(torch::kCPU).contiguous();
auto output_cpu = output.to(torch::kCPU).contiguous();
auto rois_cpu = rois.to(torch::kCPU).contiguous(); // Already on CPU for printing, ensure contiguous
auto output_cpu = output.to(torch::kCPU).contiguous(); // Create CPU version for the C function to fill
// Call the C wrapper function
std::cout << " Calling prroi_pooling_forward_cuda..." << std::endl;
// Call the C wrapper function (which is a CPU implementation)
std::cout << " Calling prroi_pooling_forward_cuda (CPU implementation)..." << std::endl;
prroi_pooling_forward_cuda(
feat_cpu.data_ptr<float>(),
static_cast<float*>(rois_cpu.data_ptr()),
static_cast<float*>(output_cpu.data_ptr()),
rois_cpu.data_ptr<float>(), // Pass the CPU tensor data
output_cpu.data_ptr<float>(), // Pass CPU output tensor data
channels,
feat.size(2),
feat.size(3),
@@ -71,7 +76,7 @@ torch::Tensor PrRoIPool2D::forward(torch::Tensor feat, torch::Tensor rois) {
);
std::cout << " prroi_pooling_forward_cuda completed" << std::endl;
// Copy result back to GPU
// Copy result back to original device (GPU)
output.copy_(output_cpu);
return output;
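The hunk above settles on a CPU round trip: the output is allocated on the feature tensor's device, contiguous CPU copies are handed to the C entry point, and the filled buffer is copied back. A minimal self-contained sketch of that pattern follows; pool_on_cpu and pooled_forward are illustrative stand-ins, not this file's prroi_pooling_forward_cuda.

#include <torch/torch.h>
#include <algorithm>

// Stand-in for a C pooling routine that only understands host (CPU) float pointers.
// The real prroi_pooling_forward_cuda in this file has a different, richer signature.
static void pool_on_cpu(const float* /*feat*/, const float* /*rois*/,
                        float* out, int64_t out_elems) {
    std::fill(out, out + out_elems, 0.0f);  // dummy body: just zero-fill the output
}

torch::Tensor pooled_forward(torch::Tensor feat, torch::Tensor rois,
                             int pooled_h, int pooled_w) {
    // Allocate the result on the same device as the input features (CUDA in this file).
    auto output = torch::zeros({rois.size(0), feat.size(1), pooled_h, pooled_w}, feat.options());

    // The C routine expects host pointers, so take contiguous CPU copies first.
    auto feat_cpu = feat.to(torch::kCPU).contiguous();
    auto rois_cpu = rois.to(torch::kCPU).contiguous();
    auto out_cpu  = output.to(torch::kCPU).contiguous();

    pool_on_cpu(feat_cpu.data_ptr<float>(), rois_cpu.data_ptr<float>(),
                out_cpu.data_ptr<float>(), out_cpu.numel());

    output.copy_(out_cpu);  // copy the filled CPU buffer back to the original device
    return output;
}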
@@ -87,6 +92,9 @@ LinearBlock::LinearBlock(int in_planes, int out_planes, int input_sz, bool bias,
if (use_bn) {
// Important: use BatchNorm2d to match Python implementation
bn = register_module("bn", torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(out_planes)));
// Initialize BatchNorm weights and biases like Python
bn->weight.data().uniform_();
bn->bias.data().zero_();
}
use_relu = relu;
@@ -96,46 +104,53 @@ LinearBlock::LinearBlock(int in_planes, int out_planes, int input_sz, bool bias,
}
torch::Tensor LinearBlock::forward(torch::Tensor x) {
// Store original dtype for later
auto original_dtype = x.dtype();
// Store original dtype for later (though we will stick to it)
// auto original_dtype = x.dtype();
// Use double precision for higher accuracy
auto x_double = x.to(torch::kFloat64);
// REMOVED: Conversions to double precision
// auto x_double = x.to(torch::kFloat64);
// Reshape exactly as in Python: x.reshape(x.shape[0], -1)
x_double = x_double.reshape({x_double.size(0), -1}).contiguous();
// x_double = x_double.reshape({x_double.size(0), -1}).contiguous();
x = x.reshape({x.size(0), -1}).contiguous(); // Operate on original tensor x
// Convert back to original precision for the linear operation
auto x_float = x_double.to(original_dtype);
x_float = linear->forward(x_float);
// REMOVED: Conversion back to original precision for the linear operation
// auto x_float = x_double.to(original_dtype);
// x_float = linear->forward(x_float);
x = linear->forward(x); // Operate on original tensor x
// Back to double precision for further operations
x_double = x_float.to(torch::kFloat64);
// REMOVED: Back to double precision for further operations
// x_double = x_float.to(torch::kFloat64);
if (use_bn) {
// This is crucial: reshape to 4D tensor for BatchNorm2d exactly as in Python
// In Python: x = self.bn(x.reshape(x.shape[0], x.shape[1], 1, 1))
x_double = x_double.reshape({x_double.size(0), x_double.size(1), 1, 1}).contiguous();
// x_double = x_double.reshape({x_double.size(0), x_double.size(1), 1, 1}).contiguous();
x = x.reshape({x.size(0), x.size(1), 1, 1}).contiguous(); // Operate on original tensor x
// Apply batch norm (convert to float32 for the operation)
x_float = x_double.to(original_dtype);
x_float = bn->forward(x_float);
x_double = x_float.to(torch::kFloat64);
// Apply batch norm (convert to float32 for the operation - NOT NEEDED if x is already float32)
// x_float = x_double.to(original_dtype);
// x_float = bn->forward(x_float);
// x_double = x_float.to(torch::kFloat64);
x = bn->forward(x); // Operate on original tensor x
}
// Apply ReLU if needed
if (use_relu) {
// Apply ReLU in float32 precision
x_float = x_double.to(original_dtype);
x_float = relu_->forward(x_float);
x_double = x_float.to(torch::kFloat64);
// Apply ReLU in float32 precision - NOT NEEDED if x is already float32
// x_float = x_double.to(original_dtype);
// x_float = relu_->forward(x_float);
// x_double = x_float.to(torch::kFloat64);
x = relu_->forward(x); // Operate on original tensor x
}
// Final reshape to 2D tensor, exactly matching Python's behavior
x_double = x_double.reshape({x_double.size(0), -1}).contiguous();
// x_double = x_double.reshape({x_double.size(0), -1}).contiguous();
x = x.reshape({x.size(0), -1}).contiguous(); // Operate on original tensor x
// Return tensor in original precision
return x_double.to(original_dtype);
// return x_double.to(original_dtype);
return x; // Return modified x directly
}
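Read together, the new lines above reduce LinearBlock::forward to a single-precision pipeline. A consolidated sketch of what the refactored function amounts to, using the member names from this diff (linear, bn, relu_, use_bn, use_relu):

torch::Tensor LinearBlock::forward(torch::Tensor x) {
    // Flatten to [batch, features] as in the Python reference: x.reshape(x.shape[0], -1)
    x = x.reshape({x.size(0), -1}).contiguous();
    x = linear->forward(x);
    if (use_bn) {
        // BatchNorm2d needs a 4D tensor, so temporarily view as [batch, channels, 1, 1]
        x = x.reshape({x.size(0), x.size(1), 1, 1}).contiguous();
        x = bn->forward(x);
    }
    if (use_relu) {
        x = relu_->forward(x);
    }
    // Back to 2D, matching the Python implementation's output shape
    return x.reshape({x.size(0), -1}).contiguous();
}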
// Create convolutional block
@@ -152,7 +167,11 @@ torch::nn::Sequential BBRegressor::create_conv_block(int in_planes, int out_plan
.stride(stride).padding(padding).dilation(dilation).bias(true)));
// Add batch normalization layer
seq->push_back(torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(out_planes)));
auto bn_layer = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(out_planes));
// Initialize BatchNorm weights and biases like Python
bn_layer->weight.data().uniform_();
bn_layer->bias.data().zero_();
seq->push_back(bn_layer);
// Add ReLU activation
seq->push_back(torch::nn::ReLU(torch::nn::ReLUOptions().inplace(true)));
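The hunk above switches create_conv_block to an explicitly initialized BatchNorm layer. A compact sketch of the resulting Conv+BN+ReLU builder, with illustrative default parameters (kernel size, stride, and padding here are assumptions, not the file's actual defaults):

torch::nn::Sequential make_conv_bn_relu(int in_planes, int out_planes,
                                        int kernel = 3, int stride = 1, int padding = 1) {
    torch::nn::Sequential seq;
    seq->push_back(torch::nn::Conv2d(torch::nn::Conv2dOptions(in_planes, out_planes, kernel)
                                         .stride(stride).padding(padding).bias(true)));
    auto bn = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(out_planes));
    bn->weight.data().uniform_();  // match the Python-side BatchNorm initialization
    bn->bias.data().zero_();
    seq->push_back(bn);
    seq->push_back(torch::nn::ReLU(torch::nn::ReLUOptions().inplace(true)));
    return seq;
}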
@@ -221,16 +240,11 @@ torch::Tensor BBRegressor::load_tensor(const std::string& file_path) {
}
// Constructor
BBRegressor::BBRegressor(const std::string& base_dir, torch::Device dev)
: device(dev), model_dir(base_dir + "/exported_weights/bb_regressor"),
BBRegressor::BBRegressor(const std::string& model_weights_dir, torch::Device dev)
: device(dev), model_dir(model_weights_dir),
fc3_rt(256, 256, 5, true, true, true),
fc4_rt(256, 256, 3, true, true, true) {
// Check if base directory exists
if (!fs::exists(base_dir)) {
throw std::runtime_error("Base directory does not exist: " + base_dir);
}
// Check if model directory exists
if (!fs::exists(model_dir)) {
throw std::runtime_error("Model directory does not exist: " + model_dir);
@@ -539,103 +553,90 @@ void BBRegressor::to(torch::Device device) {
}
// Get IoU features from backbone features
std::vector<torch::Tensor> BBRegressor::get_iou_feat(std::vector<torch::Tensor> feat2) {
// Convert to double precision for better numerical stability
auto feat2_double0 = feat2[0].to(torch::kFloat64);
auto feat2_double1 = feat2[1].to(torch::kFloat64);
// Reshape exactly as in Python implementation
// In Python: feat2 = [f.reshape(-1, *f.shape[-3:]) if f.dim()==5 else f for f in feat2]
if (feat2_double0.dim() == 5) {
auto shape = feat2_double0.sizes();
feat2_double0 = feat2_double0.reshape({-1, shape[2], shape[3], shape[4]}).contiguous();
}
if (feat2_double1.dim() == 5) {
auto shape = feat2_double1.sizes();
feat2_double1 = feat2_double1.reshape({-1, shape[2], shape[3], shape[4]}).contiguous();
}
// Convert back to float32 for convolution operations
feat2[0] = feat2_double0.to(torch::kFloat32).contiguous();
feat2[1] = feat2_double1.to(torch::kFloat32).contiguous();
// Apply convolutions exactly as in Python
torch::Tensor feat3_t = feat2[0];
torch::Tensor feat4_t = feat2[1];
// Ensure we're in evaluation mode
std::vector<torch::Tensor> BBRegressor::get_iou_feat(std::vector<torch::Tensor> feat_in) {
torch::NoGradGuard no_grad;
if (feat_in.size() != 2) {
throw std::runtime_error("get_iou_feat expects 2 input features (layer2, layer3).");
}
// feat_in[0] is backbone layer2 (e.g., [B, 512, H1, W1])
// feat_in[1] is backbone layer3 (e.g., [B, 1024, H2, W2])
auto feat3_t_in = feat_in[0].to(device);
auto feat4_t_in = feat_in[1].to(device);
// Process through conv layers
// conv3_1t should take 512 -> 256 channels
// conv3_2t should take 256 -> 256 channels (pred_input_dim[0])
auto c3_t = conv3_2t->forward(conv3_1t->forward(feat3_t_in));
// conv4_1t should take 1024 -> 256 channels
// conv4_2t should take 256 -> 256 channels (pred_input_dim[1])
auto c4_t = conv4_2t->forward(conv4_1t->forward(feat4_t_in));
// Apply convolutions just like Python version
torch::Tensor c3_t_1 = conv3_1t->forward(feat3_t);
c3_t_1 = c3_t_1.contiguous();
torch::Tensor c3_t = conv3_2t->forward(c3_t_1);
c3_t = c3_t.contiguous();
torch::Tensor c4_t_1 = conv4_1t->forward(feat4_t);
c4_t_1 = c4_t_1.contiguous();
torch::Tensor c4_t = conv4_2t->forward(c4_t_1);
c4_t = c4_t.contiguous();
// Return results
return {c3_t, c4_t};
return {c3_t.contiguous(), c4_t.contiguous()};
}
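The comments in the new get_iou_feat spell out the channel flow: layer2 features go 512 -> 256 -> 256 through conv3_1t/conv3_2t, and layer3 features go 1024 -> 256 -> 256 through conv4_1t/conv4_2t. A shape-only sketch of that flow with bare Conv2d stand-ins; the real members are Conv+BN+ReLU blocks, and the kernel and spatial sizes here are illustrative:

// Stand-ins for conv3_1t/conv3_2t and conv4_1t/conv4_2t, reduced to bare convolutions.
auto conv3_1 = torch::nn::Conv2d(torch::nn::Conv2dOptions(512, 256, 3).padding(1));
auto conv3_2 = torch::nn::Conv2d(torch::nn::Conv2dOptions(256, 256, 3).padding(1));
auto conv4_1 = torch::nn::Conv2d(torch::nn::Conv2dOptions(1024, 256, 3).padding(1));
auto conv4_2 = torch::nn::Conv2d(torch::nn::Conv2dOptions(256, 256, 3).padding(1));

auto layer2 = torch::randn({1, 512, 36, 36});   // backbone layer2 features (shape illustrative)
auto layer3 = torch::randn({1, 1024, 18, 18});  // backbone layer3 features (shape illustrative)

auto c3_t = conv3_2->forward(conv3_1->forward(layer2));  // [1, 256, 36, 36]
auto c4_t = conv4_2->forward(conv4_1->forward(layer3));  // [1, 256, 18, 18]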
// Get modulation vectors for the target
std::vector<torch::Tensor> BBRegressor::get_modulation(std::vector<torch::Tensor> feat, torch::Tensor bb) {
// Apply target branch to get modulation vectors
std::cout << " get_modulation input bb: " << bb.sizes() << std::endl;
// Convert bounding box from [x, y, w, h] to [batch_idx, x1, y1, x2, y2] format for ROI pooling
auto roi = torch::zeros({bb.size(0), 5}, bb.options());
// Set batch index to 0 (first element)
roi.index_put_({torch::indexing::Slice(), 0}, 0);
// Copy x, y coordinates
roi.index_put_({torch::indexing::Slice(), 1}, bb.index({torch::indexing::Slice(), 0}));
roi.index_put_({torch::indexing::Slice(), 2}, bb.index({torch::indexing::Slice(), 1}));
// Calculate x2, y2 from width and height
auto x2 = bb.index({torch::indexing::Slice(), 0}) + bb.index({torch::indexing::Slice(), 2});
auto y2 = bb.index({torch::indexing::Slice(), 1}) + bb.index({torch::indexing::Slice(), 3});
roi.index_put_({torch::indexing::Slice(), 3}, x2);
roi.index_put_({torch::indexing::Slice(), 4}, y2);
std::cout << " Converted ROI: [";
for (int i = 0; i < roi.size(1); i++) {
std::cout << roi[0][i].item<float>();
if (i < roi.size(1) - 1) std::cout << ", ";
std::vector<torch::Tensor> BBRegressor::get_modulation(std::vector<torch::Tensor> feat_in, torch::Tensor bb_in) {
torch::NoGradGuard no_grad;
auto feat3_r_in = feat_in[0].to(device); // Backbone layer2 features, e.g., [1, 512, H1, W1]
auto feat4_r_in = feat_in[1].to(device); // Backbone layer3 features, e.g., [1, 1024, H2, W2]
auto bb = bb_in.to(device); // Target bounding box, e.g., [1, 1, 4] (x,y,w,h)
// Ensure bb is [batch_size, 1, 4] then reshape to [batch_size, 4] for PrRoIPooling
// (as PrRoIPooling expects [batch_idx, x1, y1, x2, y2])
if (bb.dim() == 3 && bb.size(1) == 1) {
bb = bb.squeeze(1); // Now [batch_size, 4]
} else if (bb.dim() != 2 || bb.size(1) != 4) {
throw std::runtime_error("get_modulation: bb must be [batch, 1, 4] or [batch, 4]");
}
std::cout << "]" << std::endl;
// Apply target branch to get modulation vectors
auto feat1 = conv3_1t->forward(feat[0]);
auto feat2 = conv3_2t->forward(feat1);
// Apply target branch to get modulation vectors for second feature map
auto feat3 = conv4_1t->forward(feat[1]);
auto feat4 = conv4_2t->forward(feat3);
// ROI pool the features - use the same ROI for both feature maps
std::cout << " Applying ROI pooling to layer 3..." << std::endl;
auto pooled_feat1 = prroi_pool3t->forward(feat2, roi);
std::cout << " Applying ROI pooling to layer 4..." << std::endl;
auto pooled_feat2 = prroi_pool4t->forward(feat4, roi);
// Flatten and concatenate the pooled features
auto vec1 = pooled_feat1.reshape({pooled_feat1.size(0), -1});
auto vec2 = pooled_feat2.reshape({pooled_feat2.size(0), -1});
// Apply fully connected layer to get modulation vectors
auto modulation1 = fc3_rt.forward(vec1);
auto modulation2 = fc4_rt.forward(vec2);
// Python: c3_r = self.conv3_1r(feat3_r)
auto c3_r = conv3_1r->forward(feat3_r_in).contiguous(); // Output: [B, 128, H1, W1]
// Python: roi1 from bb (batch_idx, x1,y1,x2,y2)
auto batch_size = bb.size(0);
auto roi1 = torch::zeros({batch_size, 5}, bb.options());
for (int64_t i = 0; i < batch_size; ++i) {
roi1.index_put_({i, 0}, static_cast<float>(i));
}
roi1.index_put_({torch::indexing::Slice(), 1}, bb.index({torch::indexing::Slice(), 0})); // x1
roi1.index_put_({torch::indexing::Slice(), 2}, bb.index({torch::indexing::Slice(), 1})); // y1
roi1.index_put_({torch::indexing::Slice(), 3}, bb.index({torch::indexing::Slice(), 0}) + bb.index({torch::indexing::Slice(), 2})); // x2
roi1.index_put_({torch::indexing::Slice(), 4}, bb.index({torch::indexing::Slice(), 1}) + bb.index({torch::indexing::Slice(), 3})); // y2
// Python: roi3r = self.prroi_pool3r(c3_r, roi1)
// prroi_pool3r is (3,3, 1/8)
auto roi3r = prroi_pool3r->forward(c3_r, roi1).contiguous(); // Output: [B, 128, 3, 3]
// Python: c4_r = self.conv4_1r(feat4_r)
auto c4_r = conv4_1r->forward(feat4_r_in).contiguous(); // Output: [B, 256, H2, W2]
// Python: roi4r = self.prroi_pool4r(c4_r, roi1)
// prroi_pool4r is (1,1, 1/16)
auto roi4r = prroi_pool4r->forward(c4_r, roi1).contiguous(); // Output: [B, 256, 1, 1]
// Python: fc3_r = self.fc3_1r(roi3r)
// fc3_1r is conv(128, 256, kernel_size=3, stride=1, padding=0)
auto fc3_r = fc3_1r->forward(roi3r).contiguous(); // Output: [B, 256, 1, 1] (due to 3x3 kernel, padding 0 on 3x3 input)
// Python: fc34_r = torch.cat((fc3_r, roi4r), dim=1)
auto fc34_r = torch::cat({fc3_r, roi4r}, 1).contiguous(); // Output: [B, 256+256=512, 1, 1]
// Python: fc34_3_r = self.fc34_3r(fc34_r)
// fc34_3r is conv(512, 256, kernel_size=1, stride=1, padding=0)
auto fc34_3_r_out = fc34_3r->forward(fc34_r).contiguous(); // Output: [B, 256, 1, 1]
// Python: fc34_4_r = self.fc34_4r(fc34_r)
// fc34_4r is conv(512, 256, kernel_size=1, stride=1, padding=0)
auto fc34_4_r_out = fc34_4r->forward(fc34_r).contiguous(); // Output: [B, 256, 1, 1]
// Return modulation vectors
return {modulation1, modulation2};
std::cout << " get_modulation output shapes: " << std::endl;
std::cout << " fc34_3_r_out: " << fc34_3_r_out.sizes() << std::endl;
std::cout << " fc34_4_r_out: " << fc34_4_r_out.sizes() << std::endl;
return {fc34_3_r_out, fc34_4_r_out};
}
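Both the old and new get_modulation build PrRoIPooling ROIs by turning [x, y, w, h] boxes into [batch_idx, x1, y1, x2, y2] rows. A standalone sketch of that conversion; the helper name boxes_xywh_to_rois is illustrative and not part of this file:

// Convert a [batch, 4] tensor of (x, y, w, h) boxes into the [batch, 5]
// (batch_idx, x1, y1, x2, y2) layout expected by PrRoIPooling.
torch::Tensor boxes_xywh_to_rois(const torch::Tensor& bb) {
    using torch::indexing::Slice;
    auto roi = torch::zeros({bb.size(0), 5}, bb.options());
    roi.index_put_({Slice(), 0}, torch::arange(bb.size(0), bb.options()));        // batch index per row
    roi.index_put_({Slice(), 1}, bb.index({Slice(), 0}));                         // x1
    roi.index_put_({Slice(), 2}, bb.index({Slice(), 1}));                         // y1
    roi.index_put_({Slice(), 3}, bb.index({Slice(), 0}) + bb.index({Slice(), 2})); // x2 = x + w
    roi.index_put_({Slice(), 4}, bb.index({Slice(), 1}) + bb.index({Slice(), 3})); // y2 = y + h
    return roi;
}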
// Predict IoU for proposals
@@ -680,8 +681,9 @@ torch::Tensor BBRegressor::predict_iou(std::vector<torch::Tensor> modulation,
roi = roi.to(feat_device);
// Apply ROI pooling to get features for each proposal
auto pooled_feat1 = prroi_pool3r->forward(feat[0], roi);
auto pooled_feat2 = prroi_pool4r->forward(feat[1], roi);
// CORRECTED: Use prroi_pool3t and prroi_pool4t
auto pooled_feat1 = prroi_pool3t->forward(feat[0], roi); // Was prroi_pool3r
auto pooled_feat2 = prroi_pool4t->forward(feat[1], roi); // Was prroi_pool4r
// Make sure all tensors are on the same device (GPU)
torch::Device target_device = modulation[0].device();
@@ -701,14 +703,22 @@ torch::Tensor BBRegressor::predict_iou(std::vector<torch::Tensor> modulation,
std::cout << " bias: [" << iou_predictor->bias.size(0) << "]" << std::endl;
try {
// Flatten pooled features
auto vec1 = pooled_feat1.reshape({pooled_feat1.size(0), -1});
auto vec2 = pooled_feat2.reshape({pooled_feat2.size(0), -1});
// CORRECTED: Process pooled features through fc3_rt and fc4_rt (LinearBlocks)
// These will handle the reshape and linear transformation.
// pooled_feat1 is [B*N, 256, 5, 5] -> fc3_rt -> [B*N, 256]
// pooled_feat2 is [B*N, 256, 3, 3] -> fc4_rt -> [B*N, 256]
std::cout << " Applying fc3_rt to pooled_feat1 (shape: " << pooled_feat1.sizes() << ")" << std::endl;
auto mod_target_0 = fc3_rt.forward(pooled_feat1);
std::cout << " Applying fc4_rt to pooled_feat2 (shape: " << pooled_feat2.sizes() << ")" << std::endl;
auto mod_target_1 = fc4_rt.forward(pooled_feat2);
std::cout << " mod_target_0 shape: " << mod_target_0.sizes() << std::endl;
std::cout << " mod_target_1 shape: " << mod_target_1.sizes() << std::endl;
// Print flattened shapes
std::cout << " Flattened shapes:" << std::endl;
std::cout << " vec1: [" << vec1.size(0) << ", " << vec1.size(1) << "]" << std::endl;
std::cout << " vec2: [" << vec2.size(0) << ", " << vec2.size(1) << "]" << std::endl;
// std::cout << " Flattened shapes:" << std::endl;
// std::cout << " vec1: [" << vec1.size(0) << ", " << vec1.size(1) << "]" << std::endl;
// std::cout << " vec2: [" << vec2.size(0) << ", " << vec2.size(1) << "]" << std::endl;
// We need to adapt the input to match what the IoU predictor expects
// The IoU predictor has a weight matrix of size 512x1, so input should have 512 features
@@ -717,80 +727,96 @@ torch::Tensor BBRegressor::predict_iou(std::vector<torch::Tensor> modulation,
// This is based on the original Python implementation
// Get modulation shapes
std::cout << " Modulation vector shapes:" << std::endl;
std::cout << " mod1: [" << modulation[0].size(0) << ", " << modulation[0].size(1) << "]" << std::endl;
std::cout << " mod2: [" << modulation[1].size(0) << ", " << modulation[1].size(1) << "]" << std::endl;
std::cout << " Modulation vector shapes (from get_modulation):" << std::endl;
std::cout << " mod1 (input arg): [" << modulation[0].size(0) << ", " << modulation[0].size(1);
if (modulation[0].dim() > 2) std::cout << ", " << modulation[0].size(2) << ", " << modulation[0].size(3);
std::cout << "]" << std::endl;
std::cout << " mod2 (input arg): [" << modulation[1].size(0) << ", " << modulation[1].size(1);
if (modulation[1].dim() > 2) std::cout << ", " << modulation[1].size(2) << ", " << modulation[1].size(3);
std::cout << "]" << std::endl;
// Calculate expected dimensions
int mod1_dim = modulation[0].size(1); // Should be 256
int mod2_dim = modulation[1].size(1); // Should be 256
int total_mod_dim = mod1_dim + mod2_dim; // Should be 512, matching iou_predictor weight row count
// int mod1_dim = modulation[0].size(1); // Should be 256
// int mod2_dim = modulation[1].size(1); // Should be 256
// int total_mod_dim = mod1_dim + mod2_dim; // Should be 512, matching iou_predictor weight row count
std::cout << " Using correct input dimensions for IoU predictor (total_dim=" << total_mod_dim << ")" << std::endl;
// std::cout << " Using correct input dimensions for IoU predictor (total_dim=" << total_mod_dim << ")" << std::endl;
// Create processed features with correct dimensions
auto processed_feat1 = torch::zeros({num_proposals, mod1_dim}, vec1.options());
auto processed_feat2 = torch::zeros({num_proposals, mod2_dim}, vec2.options());
// auto processed_feat1 = torch::zeros({num_proposals, mod1_dim}, vec1.options());
// auto processed_feat2 = torch::zeros({num_proposals, mod2_dim}, vec2.options());
// We need to reduce the dimensionality of vec1 and vec2 to match mod1_dim and mod2_dim
// REMOVED Manual Averaging Logic
// We'll use average pooling across spatial dimensions
if (vec1.size(1) > mod1_dim) {
// Average every N values to reduce dimension
int pool_size = vec1.size(1) / mod1_dim;
std::cout << " Reducing vec1 features with pool_size=" << pool_size << std::endl;
// if (vec1.size(1) > mod1_dim) {
// // Average every N values to reduce dimension
// int pool_size = vec1.size(1) / mod1_dim;
// std::cout << " Reducing vec1 features with pool_size=" << pool_size << std::endl;
for (int i = 0; i < num_proposals; i++) {
for (int j = 0; j < mod1_dim; j++) {
float sum = 0.0f;
for (int k = 0; k < pool_size; k++) {
int idx = j * pool_size + k;
if (idx < vec1.size(1)) {
sum += vec1[i][idx].item<float>();
}
}
processed_feat1[i][j] = sum / pool_size;
}
}
} else {
// Just copy directly if dimensions already match
processed_feat1 = vec1;
}
// for (int i = 0; i < num_proposals; i++) {
// for (int j = 0; j < mod1_dim; j++) {
// float sum = 0.0f;
// for (int k = 0; k < pool_size; k++) {
// int idx = j * pool_size + k;
// if (idx < vec1.size(1)) {
// sum += vec1[i][idx].item<float>();
// }
// }
// processed_feat1[i][j] = sum / pool_size;
// }
// }
// } else {
// // Just copy directly if dimensions already match
// processed_feat1 = vec1;
// }
if (vec2.size(1) > mod2_dim) {
// Similar reduction for vec2
int pool_size = vec2.size(1) / mod2_dim;
std::cout << " Reducing vec2 features with pool_size=" << pool_size << std::endl;
// if (vec2.size(1) > mod2_dim) {
// // Similar reduction for vec2
// int pool_size = vec2.size(1) / mod2_dim;
// std::cout << " Reducing vec2 features with pool_size=" << pool_size << std::endl;
for (int i = 0; i < num_proposals; i++) {
for (int j = 0; j < mod2_dim; j++) {
float sum = 0.0f;
for (int k = 0; k < pool_size; k++) {
int idx = j * pool_size + k;
if (idx < vec2.size(1)) {
sum += vec2[i][idx].item<float>();
}
}
processed_feat2[i][j] = sum / pool_size;
}
}
} else {
// Just copy directly if dimensions already match
processed_feat2 = vec2;
}
// for (int i = 0; i < num_proposals; i++) {
// for (int j = 0; j < mod2_dim; j++) {
// float sum = 0.0f;
// for (int k = 0; k < pool_size; k++) {
// int idx = j * pool_size + k;
// if (idx < vec2.size(1)) {
// sum += vec2[i][idx].item<float>();
// }
// }
// processed_feat2[i][j] = sum / pool_size;
// }
// }
// } else {
// // Just copy directly if dimensions already match
// processed_feat2 = vec2;
// }
// Prepare modulation vectors for each proposal
auto mod1 = modulation[0].repeat({num_proposals, 1});
auto mod2 = modulation[1].repeat({num_proposals, 1});
auto m0_in = modulation[0]; // Shape can be [1, 256] or [1, 256, 1, 1]
auto m1_in = modulation[1];
if (m0_in.dim() == 4 && m0_in.size(2) == 1 && m0_in.size(3) == 1) {
m0_in = m0_in.squeeze(-1).squeeze(-1); // Now [1, 256]
}
if (m1_in.dim() == 4 && m1_in.size(2) == 1 && m1_in.size(3) == 1) {
m1_in = m1_in.squeeze(-1).squeeze(-1); // Now [1, 256]
}
// Now m0_in and m1_in are guaranteed to be 2D [Batch, Channels] e.g. [1, 256]
auto mod1_repeated_for_proposals = m0_in.repeat({num_proposals, 1}); // [num_proposals, 256]
auto mod2_repeated_for_proposals = m1_in.repeat({num_proposals, 1}); // [num_proposals, 256]
std::cout << " Final feature shapes:" << std::endl;
std::cout << " processed_feat1: [" << processed_feat1.size(0) << ", " << processed_feat1.size(1) << "]" << std::endl;
std::cout << " processed_feat2: [" << processed_feat2.size(0) << ", " << processed_feat2.size(1) << "]" << std::endl;
std::cout << " mod1: [" << mod1.size(0) << ", " << mod1.size(1) << "]" << std::endl;
std::cout << " mod2: [" << mod2.size(0) << ", " << mod2.size(1) << "]" << std::endl;
std::cout << " Final feature shapes (after LinearBlocks, before element-wise mult):" << std::endl;
std::cout << " mod_target_0 (from fc3_rt): [" << mod_target_0.size(0) << ", " << mod_target_0.size(1) << "]" << std::endl;
std::cout << " mod_target_1 (from fc4_rt): [" << mod_target_1.size(0) << ", " << mod_target_1.size(1) << "]" << std::endl;
std::cout << " mod1_repeated (from get_modulation input): [" << mod1_repeated_for_proposals.size(0) << ", " << mod1_repeated_for_proposals.size(1) << "]" << std::endl;
std::cout << " mod2_repeated (from get_modulation input): [" << mod2_repeated_for_proposals.size(0) << ", " << mod2_repeated_for_proposals.size(1) << "]" << std::endl;
// Element-wise multiply features with modulation vectors
auto mod_feat1 = processed_feat1 * mod1;
auto mod_feat2 = processed_feat2 * mod2;
// CORRECTED: Use mod_target_0 and mod_target_1 from fc3_rt/fc4_rt
auto mod_feat1 = mod_target_0 * mod1_repeated_for_proposals;
auto mod_feat2 = mod_target_1 * mod2_repeated_for_proposals;
// Concatenate to get final features for IoU prediction
auto ioufeat = torch::cat({mod_feat1, mod_feat2}, /*dim=*/1);
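The diff view is truncated here, but the comments above pin down the dimensions: two [num_proposals, 256] modulated feature blocks are concatenated into a [num_proposals, 512] tensor for an iou_predictor with a 512-input weight. A sketch of that final step, assuming iou_predictor is a torch::nn::Linear(512, 1) as the printed weight/bias shapes suggest; this is illustrative, not the file's actual continuation:

#include <torch/torch.h>

// Illustrative only: combine per-proposal features with modulation vectors and score them.
torch::Tensor score_proposals(const torch::Tensor& mod1,   // [1, 256] from get_modulation
                              const torch::Tensor& mod2,   // [1, 256] from get_modulation
                              const torch::Tensor& feat1,  // [num_proposals, 256] from fc3_rt
                              const torch::Tensor& feat2,  // [num_proposals, 256] from fc4_rt
                              torch::nn::Linear& iou_predictor) {  // assumed Linear(512, 1)
    int64_t num_proposals = feat1.size(0);
    auto mod_feat1 = feat1 * mod1.repeat({num_proposals, 1});  // modulate per-proposal features
    auto mod_feat2 = feat2 * mod2.repeat({num_proposals, 1});
    auto ioufeat = torch::cat({mod_feat1, mod_feat2}, /*dim=*/1);  // [num_proposals, 512]
    return iou_predictor->forward(ioufeat).squeeze(-1);           // one IoU score per proposal
}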
