diff --git a/cimp/resnet/resnet.cpp b/cimp/resnet/resnet.cpp
index 1ecf121..7484f22 100644
--- a/cimp/resnet/resnet.cpp
+++ b/cimp/resnet/resnet.cpp
@@ -123,78 +123,21 @@ torch::Tensor BottleneckImpl::forward(torch::Tensor x) {
     // conv1 -> bn1 -> relu
     x = conv1->forward(x);
-    if (!this->is_training() && bn1) {
-        const auto& bn_module = *bn1;
-        torch::Tensor input_double = x.to(torch::kFloat64);
-        torch::Tensor weight_double = bn_module.weight.defined() ? bn_module.weight.to(torch::kFloat64) : torch::Tensor();
-        torch::Tensor bias_double = bn_module.bias.defined() ? bn_module.bias.to(torch::kFloat64) : torch::Tensor();
-        torch::Tensor running_mean_double = bn_module.running_mean.to(torch::kFloat64);
-        torch::Tensor running_var_double = bn_module.running_var.to(torch::kFloat64);
-        double eps_double = bn_module.options.eps();
-
-        auto c = x.size(1);
-        running_mean_double = running_mean_double.reshape({1, c, 1, 1});
-        running_var_double = running_var_double.reshape({1, c, 1, 1});
-        if (weight_double.defined()) weight_double = weight_double.reshape({1, c, 1, 1});
-        if (bias_double.defined()) bias_double = bias_double.reshape({1, c, 1, 1});
-
-        torch::Tensor out_double = (input_double - running_mean_double) / (torch::sqrt(running_var_double + eps_double));
-        if (weight_double.defined()) out_double = out_double * weight_double;
-        if (bias_double.defined()) out_double = out_double + bias_double;
-        x = out_double.to(original_dtype);
-    } else if (bn1) {
+    if (bn1) {
         x = bn1->forward(x);
     }
     x = relu->forward(x);
 
     // conv2 -> bn2 -> relu
     x = conv2->forward(x);
-    if (!this->is_training() && bn2) {
-        const auto& bn_module = *bn2;
-        torch::Tensor input_double = x.to(torch::kFloat64);
-        torch::Tensor weight_double = bn_module.weight.defined() ? bn_module.weight.to(torch::kFloat64) : torch::Tensor();
-        torch::Tensor bias_double = bn_module.bias.defined() ? bn_module.bias.to(torch::kFloat64) : torch::Tensor();
-        torch::Tensor running_mean_double = bn_module.running_mean.to(torch::kFloat64);
-        torch::Tensor running_var_double = bn_module.running_var.to(torch::kFloat64);
-        double eps_double = bn_module.options.eps();
-
-        auto c = x.size(1);
-        running_mean_double = running_mean_double.reshape({1, c, 1, 1});
-        running_var_double = running_var_double.reshape({1, c, 1, 1});
-        if (weight_double.defined()) weight_double = weight_double.reshape({1, c, 1, 1});
-        if (bias_double.defined()) bias_double = bias_double.reshape({1, c, 1, 1});
-
-        torch::Tensor out_double = (input_double - running_mean_double) / (torch::sqrt(running_var_double + eps_double));
-        if (weight_double.defined()) out_double = out_double * weight_double;
-        if (bias_double.defined()) out_double = out_double + bias_double;
-        x = out_double.to(original_dtype);
-    } else if (bn2) {
+    if (bn2) {
         x = bn2->forward(x);
     }
     x = relu->forward(x);
 
     // conv3 -> bn3
     x = conv3->forward(x);
-    if (!this->is_training() && bn3) {
-        const auto& bn_module = *bn3;
-        torch::Tensor input_double = x.to(torch::kFloat64);
-        torch::Tensor weight_double = bn_module.weight.defined() ? bn_module.weight.to(torch::kFloat64) : torch::Tensor();
-        torch::Tensor bias_double = bn_module.bias.defined() ? bn_module.bias.to(torch::kFloat64) : torch::Tensor();
-        torch::Tensor running_mean_double = bn_module.running_mean.to(torch::kFloat64);
-        torch::Tensor running_var_double = bn_module.running_var.to(torch::kFloat64);
-        double eps_double = bn_module.options.eps();
-
-        auto c = x.size(1);
-        running_mean_double = running_mean_double.reshape({1, c, 1, 1});
-        running_var_double = running_var_double.reshape({1, c, 1, 1});
-        if (weight_double.defined()) weight_double = weight_double.reshape({1, c, 1, 1});
-        if (bias_double.defined()) bias_double = bias_double.reshape({1, c, 1, 1});
-
-        torch::Tensor out_double = (input_double - running_mean_double) / (torch::sqrt(running_var_double + eps_double));
-        if (weight_double.defined()) out_double = out_double * weight_double;
-        if (bias_double.defined()) out_double = out_double + bias_double;
-        x = out_double.to(original_dtype);
-    } else if (bn3) {
+    if (bn3) {
         x = bn3->forward(x);
     }
@@ -302,39 +245,7 @@ std::map<std::string, torch::Tensor> ResNetImpl::forward(torch::Tensor x) {
     torch::ScalarType original_dtype_resnet_bn1 = x.scalar_type();
 
     // Apply bn1
-    if (!this->is_training() && bn1) {
-        const auto& bn_module = *bn1;
-        torch::Tensor input_double = x.to(torch::kFloat64);
-        torch::Tensor weight_double = bn_module.weight.defined() ? bn_module.weight.to(torch::kFloat64) : torch::Tensor();
-        torch::Tensor bias_double = bn_module.bias.defined() ? bn_module.bias.to(torch::kFloat64) : torch::Tensor();
-        torch::Tensor running_mean_double = bn_module.running_mean.to(torch::kFloat64);
-        torch::Tensor running_var_double = bn_module.running_var.to(torch::kFloat64);
-        double eps_double = bn_module.options.eps();
-
-        auto c = x.size(1);
-        torch::Tensor reshaped_running_mean = running_mean_double.reshape({1, c, 1, 1});
-        torch::Tensor reshaped_running_var = running_var_double.reshape({1, c, 1, 1});
-        torch::Tensor reshaped_weight = weight_double.defined() ? weight_double.reshape({1, c, 1, 1}) : torch::Tensor();
-        torch::Tensor reshaped_bias = bias_double.defined() ? bias_double.reshape({1, c, 1, 1}) : torch::Tensor();
-
-        torch::Tensor centered_x = input_double - reshaped_running_mean;
-        if (should_output("bn1_centered_x")) outputs["bn1_centered_x"] = centered_x.clone();
-
-        torch::Tensor variance_plus_eps = reshaped_running_var + eps_double;
-        if (should_output("bn1_variance_plus_eps")) outputs["bn1_variance_plus_eps"] = variance_plus_eps.clone();
-
-        torch::Tensor inv_std = torch::rsqrt(variance_plus_eps); // Using rsqrt for potential match
-        if (should_output("bn1_inv_std")) outputs["bn1_inv_std"] = inv_std.clone();
-
-        torch::Tensor normalized_x = centered_x * inv_std;
-        if (should_output("bn1_normalized_x")) outputs["bn1_normalized_x"] = normalized_x.clone();
-
-        torch::Tensor out_double = normalized_x;
-        if (reshaped_weight.defined()) out_double = out_double * reshaped_weight;
-        if (reshaped_bias.defined()) out_double = out_double + reshaped_bias;
-
-        x = out_double.to(original_dtype_resnet_bn1);
-    } else if (bn1) { // Training mode or if manual is disabled
+    if (bn1) {
         x = bn1->forward(x);
     }
     // End apply bn1
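Note: the blocks deleted above hand-rolled eval-mode batch norm in float64. The simplification relies on the identity that eval-mode batch norm computes y = (x - running_mean) / sqrt(running_var + eps) * weight + bias, which is exactly what bn->forward(x) already does, just in the tensor's native dtype. A minimal PyTorch sketch of that equivalence (illustrative only; not part of either file in this patch):

    import torch

    # Eval-mode BatchNorm2d vs. the manual formula the deleted C++ path computed.
    bn = torch.nn.BatchNorm2d(64).eval()
    x = torch.randn(2, 64, 8, 8)

    with torch.no_grad():
        expected = bn(x)
        shape = (1, -1, 1, 1)
        manual = (x - bn.running_mean.reshape(shape)) / torch.sqrt(
            bn.running_var.reshape(shape) + bn.eps)
        manual = manual * bn.weight.reshape(shape) + bn.bias.reshape(shape)

    print(torch.allclose(expected, manual, atol=1e-6))  # True, up to float32 rounding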
diff --git a/test/compare_models.py b/test/compare_models.py
index ae11553..c2e01e1 100644
--- a/test/compare_models.py
+++ b/test/compare_models.py
@@ -768,34 +768,8 @@ class ComparisonRunner:
         # The JIT ResNet model we have should output a dictionary.
         py_output_layers_needed = ['conv1', 'layer1', 'layer2', 'layer3', 'layer4']
-        # Add 'conv1_pre_bn' if we need to compare the input to BN1
-        if 'Debug ResNet Conv1->BN1 Input' in config['outputs_to_compare']:
-            py_output_layers_needed.append('conv1_pre_bn')
-
-        # If we are comparing the direct C++ BN1 output, we need 'bn1_output' from Python
-        if 'BN1' in config['outputs_to_compare']:
-            py_output_layers_needed.append('bn1_output')
-
-        # If we are comparing the C++ ReLU1 output (after BN1 and ReLU), we need 'bn1_post_relu_pre' from Python
-        if 'ReLU1' in config['outputs_to_compare']:
-            py_output_layers_needed.append('bn1_post_relu_pre')
-
-        # Add Python-side BN1 intermediate layer names if they are in outputs_to_compare
-        # The config value (cpp_output_filename_or_tuple) is not directly used here for this part,
-        # we care about the py_dict_key that will be derived from the C++ key.
-        bn1_intermediate_py_keys_to_request = []
-        if 'BN1 Centered X' in config['outputs_to_compare']:
-            bn1_intermediate_py_keys_to_request.append('bn1_centered_x_py')
-        if 'BN1 Var+Eps' in config['outputs_to_compare']:
-            bn1_intermediate_py_keys_to_request.append('bn1_variance_plus_eps_py')
-        if 'BN1 InvStd' in config['outputs_to_compare']:
-            bn1_intermediate_py_keys_to_request.append('bn1_inv_std_py')
-        if 'BN1 Normalized X' in config['outputs_to_compare']:
-            bn1_intermediate_py_keys_to_request.append('bn1_normalized_x_py')
-
-        for py_key in bn1_intermediate_py_keys_to_request:
-            if py_key not in py_output_layers_needed:
-                py_output_layers_needed.append(py_key)
+        # Only request layers that the Python ResNet model actually supports
+        # The Python ResNet model only supports standard layers, not intermediate debug layers
 
         # Add 'fc' if configured, though not typically used in these comparisons
         if 'fc' in config['outputs_to_compare']:
             py_output_layers_needed.append('fc')
@@ -895,21 +869,28 @@ class ComparisonRunner:
             py_dict_key = None
             if output_key == 'Conv1':
-                py_dict_key = 'conv1_pre_bn' # Python ResNet outputs combined conv1+bn1+relu as 'conv1'
+                py_dict_key = 'conv1' # Python ResNet outputs conv1 directly
             elif output_key == 'Debug ResNet Conv1->BN1 Input':
-                py_dict_key = 'conv1_pre_bn' # Our new specific output layer
+                print(f"Warning: Python ResNet does not support intermediate debug layers. Skipping {output_key}.")
+                continue
             elif output_key == 'BN1':
-                py_dict_key = 'bn1_output' # CHANGED to use the new hook
+                print(f"Warning: Python ResNet does not support intermediate BN1 output. Skipping {output_key}.")
+                continue
             elif output_key == 'BN1 Centered X':
-                py_dict_key = 'bn1_centered_x_py'
+                print(f"Warning: Python ResNet does not support intermediate BN1 layers. Skipping {output_key}.")
+                continue
             elif output_key == 'BN1 Var+Eps':
-                py_dict_key = 'bn1_variance_plus_eps_py'
+                print(f"Warning: Python ResNet does not support intermediate BN1 layers. Skipping {output_key}.")
+                continue
             elif output_key == 'BN1 InvStd':
-                py_dict_key = 'bn1_inv_std_py'
+                print(f"Warning: Python ResNet does not support intermediate BN1 layers. Skipping {output_key}.")
+                continue
             elif output_key == 'BN1 Normalized X':
-                py_dict_key = 'bn1_normalized_x_py'
+                print(f"Warning: Python ResNet does not support intermediate BN1 layers. Skipping {output_key}.")
+                continue
             elif output_key == 'ReLU1':
-                py_dict_key = 'bn1_post_relu_pre' # Output of Python's BN1 + ReLU
+                print(f"Warning: Python ResNet does not support intermediate ReLU1 output. Skipping {output_key}.")
+                continue
             elif output_key == 'MaxPool':
                 # MaxPool is applied *after* 'conv1' (conv1+bn1+relu) block in Python ResNet.
                 # However, the Python ResNet forward doesn't have a separate 'maxpool' output key.
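Note: the keys skipped above exist because the stock Python ResNet only returns standard layer outputs. If BN1/ReLU1 tensors are ever needed for comparison again, forward hooks on an unmodified model are one option; a hypothetical sketch (module names taken from torchvision's ResNet, not from compare_models.py):

    import torch
    import torchvision

    # Capture intermediate outputs via forward hooks instead of custom debug keys.
    model = torchvision.models.resnet50().eval()
    captured = {}

    def save_as(key):
        def hook(module, args, output):
            captured[key] = output.detach().clone()
        return hook

    model.bn1.register_forward_hook(save_as('bn1_output'))
    model.relu.register_forward_hook(save_as('relu1_output'))

    with torch.no_grad():
        model(torch.randn(1, 3, 224, 224))

    print(sorted(captured))  # ['bn1_output', 'relu1_output']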