"""Compare pairs of saved tensors and print similarity metrics.

For every ``(py_path, cpp_path, label)`` pair the script loads one tensor
from each file and prints the cosine similarity, mean absolute error and
maximum absolute error between them. Judging by the variable names and
default paths, the two sides are presumably dumps from a Python and a C++
implementation of the same model -- confirm against the producing code.
"""
import os

import torch

# Default (py_path, cpp_path, label) triples compared when main() is called
# without arguments.
_DEFAULT_PAIRS = (
    (
        "test/output_py/resnet_py/sample_0/layer2.pt",
        "test/output/resnet/sample_0_layer2.pt",
        "layer2",
    ),
    (
        "test/output_py/resnet_py/sample_0/layer3.pt",
        "test/output/resnet/sample_0_layer3.pt",
        "layer3",
    ),
)


def _first_tensor(items):
    """Return the first ``torch.Tensor`` in *items*, or ``None`` if there is none."""
    for item in items:
        if isinstance(item, torch.Tensor):
            return item
    return None


def load_tensor(path):
    """Load the object stored at *path* and extract a single tensor from it.

    The lookup order is:

    1. the loaded object itself, if it is a tensor;
    2. its ``tensor`` attribute, then its ``data`` attribute;
    3. the first entry yielded by ``named_parameters()``, then by
       ``named_buffers()``, when the object exposes those methods;
    4. the first tensor among the values of a ``dict`` or the elements of a
       ``list``/``tuple`` (e.g. a saved state_dict).

    Args:
        path: File path handed to ``torch.load``.

    Returns:
        The extracted ``torch.Tensor`` (loaded with ``map_location='cpu'``).

    Raises:
        RuntimeError: If no tensor can be found in the loaded object.
    """
    # SECURITY: weights_only=False allows arbitrary pickled objects to be
    # executed on load. Only point this script at files you trust.
    obj = torch.load(path, map_location='cpu', weights_only=False)

    if isinstance(obj, torch.Tensor):
        return obj

    for attr in ('tensor', 'data'):
        candidate = getattr(obj, attr, None)
        if isinstance(candidate, torch.Tensor):
            return candidate

    for accessor in ('named_parameters', 'named_buffers'):
        named = getattr(obj, accessor, None)
        # Skip a same-named attribute that is not a method instead of
        # crashing with TypeError when calling it.
        if callable(named):
            found = _first_tensor(value for _, value in named())
            if found is not None:
                return found

    # Plain containers previously fell through to the error below even when
    # they held tensors.
    if isinstance(obj, dict):
        found = _first_tensor(obj.values())
    elif isinstance(obj, (list, tuple)):
        found = _first_tensor(obj)
    else:
        found = None
    if found is not None:
        return found

    raise RuntimeError(f"Could not extract tensor from {path}")


def compare_tensors(tensor_py, tensor_cpp, label):
    """Print similarity metrics between two tensors.

    Both tensors are detached, converted to float32 on the CPU and flattened
    to 1-D, so only the element count (not the original layout) has to match.

    Args:
        tensor_py: First tensor to compare.
        tensor_cpp: Second tensor to compare.
        label: Prefix used for every printed line.

    Returns:
        None. The result is printed: either a shape-mismatch notice, an
        empty-input notice, or one line with ``cos_sim``, ``MAE`` and
        ``max_abs``.
    """
    a = tensor_py.detach().float().cpu().reshape(-1)
    b = tensor_cpp.detach().float().cpu().reshape(-1)

    if a.shape != b.shape:
        print(f"{label}: Shape mismatch: {a.shape} vs {b.shape}")
        return

    if a.numel() == 0:
        # torch.max raises on an empty tensor; report instead of crashing.
        print(f"{label}: Both tensors are empty, nothing to compare")
        return

    abs_diff = torch.abs(a - b)  # computed once, shared by MAE and max_abs
    cos_sim = torch.nn.functional.cosine_similarity(a, b, dim=0).item()
    mae = torch.mean(abs_diff).item()
    max_abs = torch.max(abs_diff).item()
    print(f"{label}: cos_sim={cos_sim:.8f}, MAE={mae:.8e}, max_abs={max_abs:.8e}")


def main(pairs=None):
    """Compare every tensor pair and print one result line per pair.

    Args:
        pairs: Optional iterable of ``(py_path, cpp_path, label)`` triples.
            When omitted, the built-in default pairs are used.

    A pair with a missing file is reported and skipped; the remaining pairs
    are still processed.
    """
    if pairs is None:
        pairs = _DEFAULT_PAIRS

    for py_path, cpp_path, label in pairs:
        missing = [p for p in (py_path, cpp_path) if not os.path.exists(p)]
        if missing:
            # List only the paths that are actually absent.
            print(f"{label}: Missing file(s): {' '.join(missing)}")
            continue

        tensor_py = load_tensor(py_path)
        tensor_cpp = load_tensor(cpp_path)
        compare_tensors(tensor_py, tensor_cpp, label)


if __name__ == "__main__":
    main()