You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

42 lines
1.9 KiB

import torch
def get_tensor(x):
    """Extract a single tensor from an object returned by ``torch.load``.

    Strategies are tried in order:

    * ``x`` is already a ``torch.Tensor`` -- it is returned unchanged;
    * ``x`` has a ``parameters()`` method yielding at least one tensor
      (e.g. an ``nn.Module``) -- its first parameter is returned;
    * ``x`` is callable with no arguments and returns a ``torch.Tensor``.

    Args:
        x: The loaded object.

    Returns:
        The extracted tensor.

    Raises:
        RuntimeError: If no strategy yields a tensor. When calling ``x()``
            raised, that exception is chained as ``__cause__`` so the real
            reason is not lost.
    """
    if isinstance(x, torch.Tensor):
        return x
    if hasattr(x, 'parameters'):
        # Only the first parameter is needed; avoid materializing all of them.
        first = next(iter(x.parameters()), None)
        if first is not None:
            return first
    cause = None
    if callable(x):
        try:
            out = x()
        except Exception as err:  # best-effort probe: remember why it failed
            cause = err
        else:
            if isinstance(out, torch.Tensor):
                return out
    raise RuntimeError(f'Could not extract tensor from {type(x).__name__}') from cause
def _print_tensor_stats(tag, t):
    """Print *tag*, shape, dtype, then min/max/mean of *t* on one line.

    Stats are computed in float64 so integer tensors (where ``Tensor.mean``
    is not implemented) work too; an empty tensor prints ``empty`` instead
    of raising.
    """
    if t.numel() == 0:
        print(tag, t.shape, t.dtype, 'empty')
        return
    d = t.detach().double()
    print(tag, t.shape, t.dtype, d.min().item(), d.max().item(), d.mean().item())


def compare_tensors(py_path, cpp_path, label, map_location='cpu'):
    """Load a Python-side and a C++-side tensor and print how far apart they are.

    Prints a ``--- label ---`` header, one stats line per side (shape, dtype,
    min, max, mean), then the max and mean absolute element-wise difference.

    Args:
        py_path: Path to the Python-side file (loadable with ``torch.load``).
        cpp_path: Path to the C++-side file; loaded with ``weights_only=False``
            and unwrapped with ``get_tensor``.
        label: Name printed in the section header.
        map_location: Device both tensors are loaded onto (default ``'cpu'``),
            so files saved from different devices can still be subtracted.

    Returns:
        The max absolute difference as a float, or ``None`` when the tensors
        cannot be compared element-wise (different element counts).
    """
    a = get_tensor(torch.load(py_path, map_location=map_location))
    # NOTE(review): the C++ file is presumably written by libtorch's
    # torch::save, which wraps the tensor in a module -- hence
    # weights_only=False and the get_tensor() unwrapping. Confirm.
    b = get_tensor(torch.load(cpp_path, map_location=map_location, weights_only=False))

    print(f'--- {label} ---')
    _print_tensor_stats('py:', a)
    _print_tensor_stats('cpp:', b)

    if a.shape != b.shape:
        # Subtracting mismatched shapes would silently broadcast
        # (e.g. (N, 1) - (1, N) -> (N, N)) and report a meaningless diff.
        if a.numel() != b.numel():
            print('diff: shape mismatch', tuple(a.shape), 'vs', tuple(b.shape))
            return None
        print('note: shapes differ but element counts match; comparing flattened')
        a, b = a.reshape(-1), b.reshape(-1)

    # float64 avoids dtype-promotion surprises and integer wrap-around.
    diff = (a.detach().double() - b.detach().double()).abs()
    if diff.numel() == 0:
        print('diff: empty tensors')
        return 0.0
    max_diff = diff.max().item()
    print('diff:', max_diff, diff.mean().item())
    return max_diff
# Tensors dumped for sample 0 by the Python and C++ ResNet runs. Each stem is
# the comparison label and the file-name stem on both sides.
_COMPARED_OUTPUTS = (
    'conv1_output',
    'debug_resnet_conv1_output_for_bn1_input',
    'bn1_output',
)


def main():
    """Compare the Python/C++ tensor dumps, then print BN1 eps/momentum for both."""
    for stem in _COMPARED_OUTPUTS:
        try:
            compare_tensors(
                f'test/output_py/resnet_py/sample_0/{stem}.pt',
                f'test/output/resnet/sample_0_{stem}.pt',
                stem,
            )
        except FileNotFoundError as err:
            # One missing dump should not hide the remaining comparisons.
            print(f'--- {stem} ---\nskipped: {err}')

    # Print BN1 epsilon and momentum from the Python model. Best-effort: the
    # ltr package or the checkpoint may be unavailable, so any failure is
    # reported instead of aborting the script.
    try:
        from ltr.models.backbone import resnet50
        model = resnet50(output_layers=['layer1', 'layer2', 'layer3', 'layer4'], pretrained=False)
        # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only machine.
        model.load_state_dict(torch.load('backbone_pure_tensors/state_dict.pt', map_location='cpu'))
        print('\nPython BN1 epsilon:', model.bn1.eps)
        print('Python BN1 momentum:', model.bn1.momentum)
    except Exception as e:
        print('Could not print Python BN1 eps/momentum:', e)

    # Expected C++ BN1 settings, transcribed by hand from the C++ code
    # (these are literals, not read at runtime).
    print('C++ BN1 epsilon: 1e-5 (from C++ code)')
    print('C++ BN1 momentum: 0.1 (from C++ code)')


if __name__ == '__main__':
    main()