You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

45 lines
1.9 KiB

import torch
import os
def load_tensor(path):
    """Load a ``.pt`` file and return the single tensor it contains.

    The loaded object is probed in this order:

    1. a bare ``torch.Tensor``;
    2. an object with a tensor-valued ``.tensor`` attribute;
    3. an object with a tensor-valued ``.data`` attribute;
    4. the first tensor yielded by ``named_parameters()`` (module-like objects);
    5. the first tensor yielded by ``named_buffers()``;
    6. a dict / list / tuple that contains exactly one tensor value.

    Args:
        path: Filesystem path of the file to load. Everything is mapped to CPU.

    Returns:
        The extracted ``torch.Tensor``.

    Raises:
        RuntimeError: If no tensor can be extracted unambiguously.
    """
    # SECURITY NOTE: weights_only=False unpickles arbitrary objects (needed for
    # module-like files), which can execute code. Only use on trusted files.
    obj = torch.load(path, map_location='cpu', weights_only=False)
    if isinstance(obj, torch.Tensor):
        return obj
    if hasattr(obj, 'tensor') and isinstance(obj.tensor, torch.Tensor):
        return obj.tensor
    if hasattr(obj, 'data') and isinstance(obj.data, torch.Tensor):
        return obj.data
    # Module-like objects: take the first parameter, then the first buffer.
    # NOTE(review): presumably how C++-saved tensors surface after loading — confirm.
    for _, param in getattr(obj, 'named_parameters', lambda: [])():
        if isinstance(param, torch.Tensor):
            return param
    for _, buf in getattr(obj, 'named_buffers', lambda: [])():
        if isinstance(buf, torch.Tensor):
            return buf
    # Plain containers: unwrap only when the choice is unambiguous, so that a
    # multi-tensor file never silently yields the wrong tensor.
    if isinstance(obj, dict):
        candidates = [v for v in obj.values() if isinstance(v, torch.Tensor)]
    elif isinstance(obj, (list, tuple)):
        candidates = [v for v in obj if isinstance(v, torch.Tensor)]
    else:
        candidates = []
    if len(candidates) == 1:
        return candidates[0]
    raise RuntimeError(
        f"Could not extract tensor from {path} "
        f"(loaded object of type {type(obj).__name__})"
    )
def compare_tensors(tensor_py, tensor_cpp, label):
    """Print similarity metrics between two tensors compared element-wise.

    Both tensors are flattened before comparison, so only the total element
    count must match. If the original shapes differ but the counts agree
    (e.g. (2, 3) vs (3, 2)), a note is printed and the flattened data are
    still compared.

    Args:
        tensor_py: Reference tensor (the Python-side output in ``main``).
        tensor_cpp: Tensor compared against the reference.
        label: Prefix for every printed line.

    Returns:
        ``{"cos_sim": float, "mae": float, "max_abs": float}`` when the
        comparison ran; ``None`` on element-count mismatch or empty inputs.
    """
    # Compute in float64: the metrics are printed to 8 significant digits, and
    # float32 accumulation over large tensors would pollute those digits.
    a = tensor_py.detach().cpu().double().reshape(-1)
    b = tensor_cpp.detach().cpu().double().reshape(-1)
    shape_py = tuple(tensor_py.shape)
    shape_cpp = tuple(tensor_cpp.shape)
    if a.numel() != b.numel():
        print(f"{label}: Shape mismatch: {shape_py} vs {shape_cpp}")
        return None
    if a.numel() == 0:
        # torch.max on an empty tensor raises, and the mean would be NaN.
        print(f"{label}: both tensors are empty, nothing to compare")
        return None
    if shape_py != shape_cpp:
        print(f"{label}: note: shapes differ ({shape_py} vs {shape_cpp}), "
              f"comparing flattened data")
    abs_diff = (a - b).abs()
    cos_sim = torch.nn.functional.cosine_similarity(a, b, dim=0).item()
    mae = abs_diff.mean().item()
    max_abs = abs_diff.max().item()
    print(f"{label}: cos_sim={cos_sim:.8f}, MAE={mae:.8e}, max_abs={max_abs:.8e}")
    return {"cos_sim": cos_sim, "mae": mae, "max_abs": max_abs}
def main(pairs=None):
    """Compare saved Python and C++ layer outputs, one pair at a time.

    Args:
        pairs: Optional iterable of ``(py_path, cpp_path, label)`` tuples.
            Defaults to the ``layer2`` / ``layer3`` outputs of ``sample_0``.
            Relative paths are resolved against the current working directory.
    """
    if pairs is None:
        pairs = [
            ("test/output_py/resnet_py/sample_0/layer2.pt", "test/output/resnet/sample_0_layer2.pt", "layer2"),
            ("test/output_py/resnet_py/sample_0/layer3.pt", "test/output/resnet/sample_0_layer3.pt", "layer3"),
        ]
    for py_path, cpp_path, label in pairs:
        # Check each path once and report only the ones actually absent.
        missing = [p for p in (py_path, cpp_path) if not os.path.exists(p)]
        if missing:
            print(f"{label}: Missing file(s): {', '.join(missing)}")
            continue
        tensor_py = load_tensor(py_path)
        tensor_cpp = load_tensor(cpp_path)
        compare_tensors(tensor_py, tensor_cpp, label)


if __name__ == "__main__":
    main()