diff --git a/cimp/resnet/resnet.cpp b/cimp/resnet/resnet.cpp index 7484f22..c4a4b47 100644 --- a/cimp/resnet/resnet.cpp +++ b/cimp/resnet/resnet.cpp @@ -238,7 +238,6 @@ std::map ResNetImpl::forward(torch::Tensor x) { }; x = conv1->forward(x); - if (should_output("conv1_output")) outputs["conv1_output"] = x; if (should_output("debug_resnet_conv1_output_for_bn1_input")) { outputs["debug_resnet_conv1_output_for_bn1_input"] = x.clone(); } @@ -255,6 +254,9 @@ std::map ResNetImpl::forward(torch::Tensor x) { x = relu->forward(x); if (should_output("relu1_output")) outputs["relu1_output"] = x; + // Save conv1_output AFTER bn1 and relu (matching Python behavior) + if (should_output("conv1_output")) outputs["conv1_output"] = x; + torch::Tensor x_pre_layer1 = maxpool->forward(x); if (should_output("maxpool_output")) outputs["maxpool_output"] = x_pre_layer1;