You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

23 lines
827 B

import torch
def get_tensor(x):
    """Return a ``torch.Tensor`` extracted from *x*.

    Strategies, tried in order:

    1. *x* itself, if it is already a tensor.
    2. The first item yielded by ``x.parameters()``, if *x* has a
       ``parameters`` attribute and it yields at least one item
       (e.g. an ``nn.Module``).
    3. The value returned by calling ``x()`` with no arguments, if *x* is
       callable and that call returns a tensor.

    Raises:
        RuntimeError: if none of the strategies yields a tensor. If the call
            in strategy 3 raised, that exception is attached as
            ``__cause__`` so the underlying reason is not lost.
    """
    if isinstance(x, torch.Tensor):
        return x

    if hasattr(x, 'parameters'):
        # Only the first parameter is needed; avoid building the full list
        # (previously twice) just to test for emptiness.
        first = next(iter(x.parameters()), None)
        if first is not None:
            return first

    cause = None
    if callable(x):
        try:
            out = x()
        except Exception as err:  # best-effort probe; remember why it failed
            cause = err
        else:
            if isinstance(out, torch.Tensor):
                return out

    raise RuntimeError(
        f'Could not extract tensor from object of type {type(x).__name__}'
    ) from cause
# Compare the Python- and C++-preprocessed tensors for the same sample image
# and print summary statistics plus their element-wise absolute difference.
PY_PATH = 'test/output/resnet/sample_0_image_preprocessed_python.pt'
CPP_PATH = 'test/output/resnet/sample_0_image_preprocessed_cpp.pt'


def _load(path):
    """Load the serialized object at *path*, mapping tensors to the CPU.

    Mapping to CPU keeps both sides on the same device so they can be
    subtracted even if one was saved from a GPU. If ``torch.load`` rejects
    the file with a ``RuntimeError`` (for example a TorchScript archive, as
    written by C++ ``torch::save``, under the ``weights_only=True`` default
    of recent PyTorch releases), retry with ``torch.jit.load``.
    """
    try:
        return torch.load(path, map_location='cpu')
    except RuntimeError:
        return torch.jit.load(path, map_location='cpu')


def _summary(t):
    """Return ``(shape, dtype, min, max, mean)`` for tensor *t*."""
    # Tensor.mean() is not defined for integer dtypes; promote those first.
    mean = t.mean() if t.is_floating_point() else t.double().mean()
    return t.shape, t.dtype, t.min().item(), t.max().item(), mean.item()


a = _load(PY_PATH)
b = _load(CPP_PATH)
a = get_tensor(a)
b = get_tensor(b)
print('py:', *_summary(a))
print('cpp:', *_summary(b))
if a.shape != b.shape:
    # The subtraction below would silently broadcast (or fail); make it visible.
    print('warning: shape mismatch:', tuple(a.shape), 'vs', tuple(b.shape))
# Subtract in float64 so integer inputs (e.g. uint8) cannot wrap around and
# mixed dtypes are compared on a common footing.
diff = (a.double() - b.double()).abs()
print('diff:', diff.max().item(), diff.mean().item())