"""Compare the Python- and C++-preprocessed image tensors for one sample."""
import sys

import torch

DEFAULT_PY_PATH = 'test/output/resnet/sample_0_image_preprocessed_python.pt'
DEFAULT_CPP_PATH = 'test/output/resnet/sample_0_image_preprocessed_cpp.pt'


def get_tensor(x):
    """Return a tensor extracted from *x*.

    The following are tried in order:
      * a ``torch.Tensor`` is returned unchanged;
      * an object with a ``parameters()`` method that yields at least one
        item has its first parameter returned;
      * a callable is invoked with no arguments, and its result is returned
        if it is a ``torch.Tensor``.

    Raises:
        RuntimeError: if none of the above yields a tensor. If calling *x*
            raised, that exception is attached as ``__cause__``.
    """
    if isinstance(x, torch.Tensor):
        return x
    if hasattr(x, 'parameters'):
        # Take only the first parameter instead of materializing the list twice.
        first = next(iter(x.parameters()), None)
        if first is not None:
            return first
    call_error = None
    if callable(x):
        try:
            out = x()
        except Exception as err:  # best-effort probe: remember why it failed
            call_error = err
        else:
            if isinstance(out, torch.Tensor):
                return out
    raise RuntimeError('Could not extract tensor') from call_error


def _load(path):
    """Load *path* onto the CPU, falling back to ``torch.jit.load``.

    NOTE(review): the C++ file is presumably written with ``torch::save``,
    which produces a TorchScript archive. ``torch.load`` used to dispatch such
    archives to ``torch.jit.load`` itself, but refuses them when
    ``weights_only=True`` (the default from torch 2.6). Confirm against the
    C++ writer.
    """
    try:
        return torch.load(path, map_location='cpu')
    except RuntimeError as load_err:
        try:
            return torch.jit.load(path, map_location='cpu')
        except RuntimeError:
            # Not a TorchScript archive either: report the original failure.
            raise load_err from None


def _as_float(t):
    """Return *t* if it is floating point, else a float64 copy.

    Integer tensors cannot be averaged, and unsigned ones wrap on subtraction.
    """
    return t if t.is_floating_point() else t.double()


def _stats(t):
    """Return ``(shape, dtype, min, max, mean)`` for *t* as printable values."""
    f = _as_float(t)
    return t.shape, t.dtype, f.min().item(), f.max().item(), f.mean().item()


def main(py_path=DEFAULT_PY_PATH, cpp_path=DEFAULT_CPP_PATH):
    """Print statistics for both tensors and their element-wise difference.

    Returns:
        0 on success; 1 if the shapes differ (in that case no diff is computed).
    """
    a = get_tensor(_load(py_path))
    b = get_tensor(_load(cpp_path))
    print('py:', *_stats(a))
    print('cpp:', *_stats(b))
    if a.shape != b.shape:
        # Broadcasting could silently "succeed" on mismatched shapes
        # (e.g. (N, 1) vs (1, N) -> (N, N)) and report a meaningless diff.
        print('diff: shape mismatch', tuple(a.shape), 'vs', tuple(b.shape))
        return 1
    diff = (_as_float(a) - _as_float(b)).abs()
    print('diff:', diff.max().item(), diff.mean().item())
    return 0


if __name__ == '__main__':
    # Optional CLI override: script.py [PY_PATH [CPP_PATH]]
    raise SystemExit(main(*sys.argv[1:3]))