@ -558,68 +558,44 @@ std::vector<torch::Tensor> BBRegressor::get_modulation(std::vector<torch::Tensor
// Python: c3_r = self.conv3_1r(feat3_r)
auto c3_r = conv3_1r - > forward ( feat3_r ) ;
// Prepare ROIs: convert bb from [x,y,w,h] to [batch_idx, x1,y1,x2,y2]
// Prepare ROIs: convert bb from [x,y,w,h] to [batch_idx, x1,y1,x2,y2] (matching Python)
int batch_size = current_bb . size ( 0 ) ;
auto batch_idx = torch : : arange ( 0 , batch_size , current_bb . options ( ) . dtype ( torch : : kFloat ) ) . unsqueeze ( 1 ) ;
auto rois = torch : : zeros ( { batch_size , 5 } , current_bb . options ( ) ) ;
rois . index_put_ ( { torch : : indexing : : Slice ( ) , 0 } , batch_idx . squeeze ( 1 ) ) ; // batch index
rois . index_put_ ( { torch : : indexing : : Slice ( ) , 1 } , current_bb . index ( { torch : : indexing : : Slice ( ) , 0 } ) ) ; // x1
rois . index_put_ ( { torch : : indexing : : Slice ( ) , 2 } , current_bb . index ( { torch : : indexing : : Slice ( ) , 1 } ) ) ; // y1
rois . index_put_ ( { torch : : indexing : : Slice ( ) , 3 } , current_bb . index ( { torch : : indexing : : Slice ( ) , 0 } ) + current_bb . index ( { torch : : indexing : : Slice ( ) , 2 } ) ) ; // x2 = x1 + w
rois . index_put_ ( { torch : : indexing : : Slice ( ) , 4 } , current_bb . index ( { torch : : indexing : : Slice ( ) , 1 } ) + current_bb . index ( { torch : : indexing : : Slice ( ) , 3 } ) ) ; // y2 = y1 + h
rois = rois . to ( device ) ; // Ensure ROIs are on the correct device
std : : cout < < " BBRegressor::get_modulation: Converted ROIs (first item): [ " ;
if ( batch_size > 0 ) {
for ( int j = 0 ; j < rois . size ( 1 ) ; j + + ) {
std : : cout < < rois [ 0 ] [ j ] . item < float > ( ) ;
if ( j < rois . size ( 1 ) - 1 ) std : : cout < < " , " ;
}
}
std : : cout < < " ] " < < std : : endl ;
std : : cout < < " BBRegressor::get_modulation: c3_r shape: " < < c3_r . sizes ( ) < < " , device: " < < c3_r . device ( ) < < std : : endl ;
auto batch_index = torch : : arange ( 0 , batch_size , current_bb . options ( ) . dtype ( torch : : kFloat ) ) . reshape ( { - 1 , 1 } ) ;
// Convert bb from xywh to xyxy format (matching Python: bb[:, 2:4] = bb[:, 0:2] + bb[:, 2:4])
auto bb_xyxy = current_bb . clone ( ) ;
bb_xyxy . index_put_ ( { torch : : indexing : : Slice ( ) , torch : : indexing : : Slice ( 2 , 4 ) } ,
bb_xyxy . index ( { torch : : indexing : : Slice ( ) , torch : : indexing : : Slice ( 0 , 2 ) } ) +
bb_xyxy . index ( { torch : : indexing : : Slice ( ) , torch : : indexing : : Slice ( 2 , 4 ) } ) ) ;
// Create ROI (matching Python: roi1 = torch.cat((batch_index, bb), dim=1))
std : : vector < torch : : Tensor > roi1_tensors = { batch_index , bb_xyxy } ;
auto roi1 = torch : : cat ( roi1_tensors , 1 ) ;
roi1 = roi1 . to ( device ) ;
// Python: roi3r = self.prroi_pool3r(c3_r, roi1)
auto roi3r = prroi_pool3r - > forward ( c3_r , rois ) ;
std : : cout < < " BBRegressor::get_modulation: roi3r shape: " < < roi3r . sizes ( ) < < std : : endl ;
auto roi3r = prroi_pool3r - > forward ( c3_r , roi1 ) ;
// Python: c4_r = self.conv4_1r(feat4_r)
auto c4_r = conv4_1r - > forward ( feat4_r ) ;
std : : cout < < " BBRegressor::get_modulation: c4_r shape: " < < c4_r . sizes ( ) < < " , device: " < < c4_r . device ( ) < < std : : endl ;
// Python: roi4r = self.prroi_pool4r(c4_r, roi1)
auto roi4r = prroi_pool4r - > forward ( c4_r , rois ) ;
std : : cout < < " BBRegressor::get_modulation: roi4r shape: " < < roi4r . sizes ( ) < < std : : endl ;
auto roi4r = prroi_pool4r - > forward ( c4_r , roi1 ) ;
// Python: fc3_r = self.fc3_1r(roi3r)
// fc3_1r is a conv block: conv(128, 256, kernel_size=3, stride=1, padding=0)
// Input roi3r is (batch, 128, 3, 3) -> Output fc3_r is (batch, 256, 1, 1)
auto fc3_r = fc3_1r - > forward ( roi3r ) ;
std : : cout < < " BBRegressor::get_modulation: fc3_r shape: " < < fc3_r . sizes ( ) < < std : : endl ;
// Python: fc34_r = torch.cat((fc3_r, roi4r), dim=1)
// fc3_r is (batch, 256, 1, 1), roi4r is (batch, 256, 1, 1)
// Result fc34_r is (batch, 512, 1, 1)
auto fc34_r = torch : : cat ( { fc3_r , roi4r } , 1 ) ;
std : : cout < < " BBRegressor::get_modulation: fc34_r shape: " < < fc34_r . sizes ( ) < < std : : endl ;
std : : vector < torch : : Tensor > fc34_r_tensors = { fc3_r , roi4r } ;
auto fc34_r = torch : : cat ( fc34_r_tensors , 1 ) ;
// Python: fc34_3_r = self.fc34_3r(fc34_r)
// fc34_3r is conv(512, 256, kernel_size=1, stride=1, padding=0)
// Output fc34_3_r is (batch, 256, 1, 1)
auto mod_vec1 = fc34_3r - > forward ( fc34_r ) ;
std : : cout < < " BBRegressor::get_modulation: mod_vec1 (fc34_3_r) shape: " < < mod_vec1 . sizes ( ) < < std : : endl ;
auto fc34_3_r = fc34_3r - > forward ( fc34_r ) ;
// Python: fc34_4_r = self.fc34_4r(fc34_r)
// fc34_4r is conv(512, 256, kernel_size=1, stride=1, padding=0)
// Output fc34_4_r is (batch, 256, 1, 1)
auto mod_vec2 = fc34_4r - > forward ( fc34_r ) ;
std : : cout < < " BBRegressor::get_modulation: mod_vec2 (fc34_4_r) shape: " < < mod_vec2 . sizes ( ) < < std : : endl ;
auto fc34_4_r = fc34_4r - > forward ( fc34_r ) ;
return { mod_vec1 , mod_vec2 } ;
return { fc34_3_r , fc34_4_r } ;
}
// Predict IoU for proposals
@ -627,7 +603,7 @@ torch::Tensor BBRegressor::predict_iou(std::vector<torch::Tensor> modulation,
std : : vector < torch : : Tensor > feat ,
torch : : Tensor proposals ) {
// Ensure all inputs are on the correct device
auto target_device = device ; // Assuming 'device' is a member of BBRegressor
auto target_device = device ;
for ( auto & t : feat ) { t = t . to ( target_device ) ; }
for ( auto & m : modulation ) { m = m . to ( target_device ) ; }
proposals = proposals . to ( target_device ) ;
@ -636,135 +612,56 @@ torch::Tensor BBRegressor::predict_iou(std::vector<torch::Tensor> modulation,
int batch_size = proposals . size ( 0 ) ;
int num_proposals = proposals . size ( 1 ) ;
// Reshape proposals to [batch_size * num_proposals, 4]
// and add batch index for PrRoIPooling
auto proposals_view = proposals . reshape ( { batch_size * num_proposals , 4 } ) ;
auto roi_batch_index = torch : : arange ( 0 , batch_size , proposals . options ( ) . dtype ( torch : : kInt ) ) . unsqueeze ( 1 ) ;
roi_batch_index = roi_batch_index . repeat_interleave ( num_proposals , 0 ) ;
auto roi = torch : : cat ( std : : vector < torch : : Tensor > { roi_batch_index . to ( proposals_view . options ( ) ) , proposals_view } , 1 ) ;
// Ensure ROI is on the correct device, matching features
auto feat_device = feat [ 0 ] . device ( ) ;
roi = roi . to ( feat_device ) ;
// Apply modulation vectors BEFORE PrRoIPooling
auto mod0_4d = modulation [ 0 ] . to ( feat_device ) ;
auto mod1_4d = modulation [ 1 ] . to ( feat_device ) ;
// Apply modulation BEFORE PrRoIPooling (matching Python implementation)
auto fc34_3_r = modulation [ 0 ] . to ( target_device ) ;
auto fc34_4_r = modulation [ 1 ] . to ( target_device ) ;
auto c3_t = feat [ 0 ] . to ( target_device ) ;
auto c4_t = feat [ 1 ] . to ( target_device ) ;
if ( mod0_4d . dim ( ) = = 2 ) {
mod0_4d = mod0_4d . reshape ( { mod0_4d . size ( 0 ) , mod0_4d . size ( 1 ) , 1 , 1 } ) ;
// Reshape modulation vectors to match Python: fc34_3_r.reshape(batch_size, -1, 1, 1)
if ( fc34_3_r . dim ( ) = = 2 ) {
fc34_3_r = fc34_3_r . reshape ( { batch_size , - 1 , 1 , 1 } ) ;
}
if ( mod1_4d . dim ( ) = = 2 ) {
mod1_4d = mod1_4d . reshape ( { mod1_4d . size ( 0 ) , mod1_4d . size ( 1 ) , 1 , 1 } ) ;
if ( fc34_4_r . dim ( ) = = 2 ) {
fc34_4_r = fc34_4_r . reshape ( { batch_size , - 1 , 1 , 1 } ) ;
}
// Ensure modulation vectors are broadcastable with features
// Features (feat[0], feat[1]) are [batch_size, channels, H, W]
// Modulation (mod0_4d, mod1_4d) should be [batch_size, channels, 1, 1]
// If num_proposals > 1, the pooling happens on features that are effectively repeated.
// The modulation is per-image, not per-proposal before pooling.
torch : : Tensor modulated_feat0 = feat [ 0 ] * mod0_4d ;
torch : : Tensor modulated_feat1 = feat [ 1 ] * mod1_4d ;
// Apply ROI pooling to get features for each proposal from MODULATED features
auto pooled_feat1 = prroi_pool3t - > forward ( modulated_feat0 , roi ) ; // Output: [batch_size * num_proposals, C, 5, 5]
auto pooled_feat2 = prroi_pool4t - > forward ( modulated_feat1 , roi ) ;
std : : cout < < " Modulated and Pooled shapes: " < < std : : endl ;
std : : cout < < " pooled_feat1 (from prroi_pool3t on modulated_feat0): [ " < < pooled_feat1 . sizes ( ) < < " ] dev: " < < pooled_feat1 . device ( ) < < std : : endl ;
std : : cout < < " pooled_feat2 (from prroi_pool4t on modulated_feat1): [ " < < pooled_feat2 . sizes ( ) < < " ] dev: " < < pooled_feat2 . device ( ) < < std : : endl ;
std : : cout < < " IoU predictor dimensions: " < < std : : endl ;
std : : cout < < " weight: [ " < < iou_predictor - > weight . sizes ( ) < < " ] " < < std : : endl ;
std : : cout < < " bias: [ " < < iou_predictor - > bias . sizes ( ) < < " ] " < < std : : endl ;
try {
// The feat_prod_0 and feat_prod_1 are now directly the pooled_feat1 and pooled_feat2
// as modulation was applied before pooling.
auto x0 = fc3_rt . forward ( pooled_feat1 ) ;
auto x1 = fc4_rt . forward ( pooled_feat2 ) ;
auto ioufeat_final = torch : : cat ( { x0 , x1 } , 1 ) . contiguous ( ) ;
// Ensure iou_predictor is on the correct device
iou_predictor - > to ( target_device ) ;
auto iou_scores = iou_predictor - > forward ( ioufeat_final ) ;
// Ensure iou_scores is on the correct device before returning
iou_scores = iou_scores . to ( target_device ) ;
// Apply modulation BEFORE pooling (matching Python: c3_t_att = c3_t * fc34_3_r.reshape(batch_size, -1, 1, 1))
auto c3_t_att = c3_t * fc34_3_r ;
auto c4_t_att = c4_t * fc34_4_r ;
// The following block for feat_prod_0 and feat_prod_1 is no longer needed as modulation is done pre-pool.
/*
auto mod0_4d = modulation [ 0 ] . to ( target_device ) ;
auto mod1_4d = modulation [ 1 ] . to ( target_device ) ;
// Convert proposals from xywh to xyxy format (matching Python)
auto proposals_xy = proposals . index ( { torch : : indexing : : Slice ( ) , torch : : indexing : : Slice ( ) , torch : : indexing : : Slice ( 0 , 2 ) } ) ;
auto proposals_wh = proposals . index ( { torch : : indexing : : Slice ( ) , torch : : indexing : : Slice ( ) , torch : : indexing : : Slice ( 2 , 4 ) } ) ;
auto proposals_xyxy = torch : : cat ( { proposals_xy , proposals_xy + proposals_wh } , 2 ) ;
if ( mod0_4d . dim ( ) = = 2 ) {
mod0_4d = mod0_4d . reshape ( { mod0_4d . size ( 0 ) , mod0_4d . size ( 1 ) , 1 , 1 } ) ;
}
if ( mod1_4d . dim ( ) = = 2 ) {
mod1_4d = mod1_4d . reshape ( { mod1_4d . size ( 0 ) , mod1_4d . size ( 1 ) , 1 , 1 } ) ;
}
if ( mod0_4d . size ( 0 ) = = 1 & & pooled_feat1 . size ( 0 ) > 1 ) {
mod0_4d = mod0_4d . repeat ( { pooled_feat1 . size ( 0 ) , 1 , 1 , 1 } ) ;
}
if ( mod1_4d . size ( 0 ) = = 1 & & pooled_feat2 . size ( 0 ) > 1 ) {
mod1_4d = mod1_4d . repeat ( { pooled_feat2 . size ( 0 ) , 1 , 1 , 1 } ) ;
}
// Add batch index (matching Python implementation)
auto batch_index = torch : : arange ( 0 , batch_size , proposals . options ( ) . dtype ( torch : : kFloat ) ) . reshape ( { - 1 , 1 } ) ;
auto batch_index_expanded = batch_index . reshape ( { batch_size , - 1 , 1 } ) . expand ( { - 1 , num_proposals , - 1 } ) ;
std : : vector < torch : : Tensor > roi2_tensors = { batch_index_expanded , proposals_xyxy } ;
auto roi2 = torch : : cat ( roi2_tensors , 2 ) ;
roi2 = roi2 . reshape ( { - 1 , 5 } ) . to ( proposals_xyxy . device ( ) ) ;
std : : cout < < " Modulation vector shapes (reshaped 4D): " < < std : : endl ;
std : : cout < < " mod0_4d: [ " < < mod0_4d . sizes ( ) < < " ] dev: " < < mod0_4d . device ( ) < < std : : endl ;
std : : cout < < " mod1_4d: [ " < < mod1_4d . sizes ( ) < < " ] dev: " < < mod1_4d . device ( ) < < std : : endl ;
auto feat_prod_0 = pooled_feat1 * mod0_4d ;
auto feat_prod_1 = pooled_feat2 * mod1_4d ;
// Apply PrRoIPooling to MODULATED features (matching Python)
auto roi3t = prroi_pool3t - > forward ( c3_t_att , roi2 ) ;
auto roi4t = prroi_pool4t - > forward ( c4_t_att , roi2 ) ;
std : : cout < < " Feature product shapes (pooled_feat * mod_vec): " < < std : : endl ;
std : : cout < < " feat_prod_0: [ " < < feat_prod_0 . sizes ( ) < < " ] dev: " < < feat_prod_0 . device ( ) < < std : : endl ;
std : : cout < < " feat_prod_1: [ " < < feat_prod_1 . sizes ( ) < < " ] dev: " < < feat_prod_1 . device ( ) < < std : : endl ;
// Forward through linear blocks
// Ensure fc3_rt and fc4_rt are on the correct device
fc3_rt . to ( target_device ) ;
fc4_rt . to ( target_device ) ;
// Forward through linear blocks
fc3_rt . to ( target_device ) ;
fc4_rt . to ( target_device ) ;
auto fc3_rt_output = fc3_rt . forward ( roi3t ) ;
auto fc4_rt_output = fc4_rt . forward ( roi4t ) ;
auto x0 = fc3_rt . forward ( feat_prod_0 ) ;
auto x1 = fc4_rt . forward ( feat_prod_1 ) ;
std : : cout < < " fc_rt output shapes: " < < std : : endl ;
std : : cout < < " x0 (fc3_rt output): [ " < < x0 . sizes ( ) < < " ] dev: " < < x0 . device ( ) < < std : : endl ;
std : : cout < < " x1 (fc4_rt output): [ " < < x1 . sizes ( ) < < " ] dev: " < < x1 . device ( ) < < std : : endl ;
// Concatenate features (matching Python)
std : : vector < torch : : Tensor > fc34_rt_tensors = { fc3_rt_output , fc4_rt_output } ;
auto fc34_rt_cat = torch : : cat ( fc34_rt_tensors , 1 ) ;
auto ioufeat_final = torch : : cat ( { x0 , x1 } , 1 ) . contiguous ( ) ;
std : : cout < < " ioufeat_final shape: [ " < < ioufeat_final . sizes ( ) < < " ] dev: " < < ioufeat_final . device ( ) < < std : : endl ;
// Ensure iou_predictor is on the correct device
iou_predictor - > to ( target_device ) ;
auto iou_scores = iou_predictor - > forward ( ioufeat_final ) ;
// Ensure iou_scores is on the correct device before returning
iou_scores = iou_scores . to ( target_device ) ;
*/
// Ensure iou_scores is on the correct device before returning.
// This was already done above, but as a final check:
if ( iou_scores . device ( ) ! = target_device ) {
iou_scores = iou_scores . to ( target_device ) ;
}
iou_scores = iou_scores . reshape ( { batch_size , num_proposals } ) ;
std : : cout < < " Final iou_scores shape: [ " < < iou_scores . size ( 0 ) < < " , " < < iou_scores . size ( 1 ) < < " ] " < < std : : endl ;
return iou_scores ;
} catch ( const std : : exception & e ) {
std : : cerr < < " CRITICAL: Unexpected error in predict_iou: " < < e . what ( ) < < std : : endl ;
std : : cout < < " Propagating critical error. No fallback available for this stage. " < < std : : endl ;
throw ;
}
// Predict IoU
iou_predictor - > to ( target_device ) ;
auto iou_pred = iou_predictor - > forward ( fc34_rt_cat ) . reshape ( { batch_size , num_proposals } ) ;
return iou_pred ;
}
// Print model information