You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
23 lines
867 B
23 lines
867 B
import torch
|
|
|
|
def get_tensor(x):
    """Extract a ``torch.Tensor`` from an object returned by ``torch.load``.

    Strategies are tried in order:

    1. *x* is already a tensor -> return it unchanged.
    2. *x* has a ``parameters()`` method that yields at least one item ->
       return the first item (e.g. an ``nn.Module``). NOTE(review): presumably
       this is how a tensor saved from C++ comes back (as a scripted module
       holding the tensor as a parameter) — confirm against the C++ writer.
    3. *x* is callable -> call it with no arguments and return the result if
       it is a tensor.

    Args:
        x: The loaded object.

    Returns:
        The extracted tensor.

    Raises:
        RuntimeError: If no strategy yields a tensor. The message names the
            type of *x*; if the no-argument call raised, that exception is
            chained as ``__cause__``.
    """
    if isinstance(x, torch.Tensor):
        return x

    if hasattr(x, 'parameters'):
        # Only the first parameter is needed; iterate lazily instead of
        # building the full parameter list (twice, as before).
        for first_param in x.parameters():
            return first_param

    cause = None
    if callable(x):
        try:
            out = x()
        except Exception as exc:  # best-effort probe: x may need arguments
            # Keep the failure so the final error explains what went wrong.
            cause = exc
        else:
            if isinstance(out, torch.Tensor):
                return out

    raise RuntimeError(
        f'Could not extract tensor from {type(x).__name__}'
    ) from cause
|
|
|
|
# Compare the preprocessed image tensor written by the Python pipeline with the
# one written by the C++ pipeline: print shape/dtype/min/max/mean for each, then
# the max and mean absolute element-wise difference.
#
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects,
# which is unsafe for untrusted files; it is acceptable only because these are
# outputs produced locally by this test suite.
# map_location='cpu' makes loading work on CPU-only hosts and guarantees both
# tensors share a device, so they can be subtracted.
_PY_PATH = 'test/output/resnet/sample_0_image_preprocessed_python.pt'
_CPP_PATH = 'test/output/resnet/sample_0_image_preprocessed_cpp.pt'

a = get_tensor(torch.load(_PY_PATH, map_location='cpu', weights_only=False))
b = get_tensor(torch.load(_CPP_PATH, map_location='cpu', weights_only=False))


def _summary(t):
    """Return ``(shape, dtype, min, max, mean)`` of *t* for printing.

    The mean is accumulated in float64 because ``Tensor.mean`` rejects
    integer dtypes.
    """
    return (
        t.shape,
        t.dtype,
        t.min().item(),
        t.max().item(),
        t.mean(dtype=torch.float64).item(),
    )


print('py:', *_summary(a))
print('cpp:', *_summary(b))

if a.shape != b.shape:
    # The subtraction below would broadcast (or fail); flag it so a broadcast
    # result is not mistaken for an element-wise comparison.
    print('warning: shape mismatch', tuple(a.shape), 'vs', tuple(b.shape))

# Subtract in float64 so unsigned integer images do not wrap around, and
# compute the difference once for both statistics.
diff = (a.double() - b.double()).abs()
print('diff:', diff.max().item(), diff.mean().item())
|