You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
51 lines
2.3 KiB
51 lines
2.3 KiB
import torch
|
|
import os
|
|
import numpy as np
|
|
from torch.nn.functional import cosine_similarity
|
|
|
|
def load_tensor(path):
    """Load ``path`` with ``torch.load`` and return the tensor stored in it.

    The loaded object is searched in this order:
      * a plain ``torch.Tensor`` is returned as-is;
      * an object exposing a tensor as ``.tensor`` or ``.data``;
      * a module-like object (e.g. a TorchScript module): its first tensor from
        ``named_parameters()``, otherwise its first tensor from ``named_buffers()``;
      * a ``dict`` / ``list`` / ``tuple``: the first tensor found among its values,
        searched recursively in iteration order.

    Args:
        path: file path handed to ``torch.load``.

    Returns:
        The extracted ``torch.Tensor``.

    Raises:
        RuntimeError: if no tensor can be extracted from the loaded object.
    """
    # weights_only=False unpickles arbitrary Python objects, which can execute
    # code: only point this at trusted files.
    obj = torch.load(path, map_location='cpu', weights_only=False)
    tensor = _find_tensor(obj)
    if tensor is None:
        raise RuntimeError(f"Could not extract tensor from {path}")
    return tensor


def _find_tensor(obj):
    """Return the first tensor reachable from ``obj`` (see ``load_tensor``), or None."""
    if isinstance(obj, torch.Tensor):
        return obj

    # Wrapper objects (e.g. a JIT module produced from a C++ save) that expose
    # the tensor as an attribute.
    for attr in ('tensor', 'data'):
        value = getattr(obj, attr, None)
        if isinstance(value, torch.Tensor):
            return value

    # Module-like objects: parameters take precedence over buffers.
    for accessor in ('named_parameters', 'named_buffers'):
        named = getattr(obj, accessor, None)
        if callable(named):
            for _, value in named():
                if isinstance(value, torch.Tensor):
                    return value

    # Plain containers (e.g. torch.save({"out": t}) or torch.save([t])).
    if isinstance(obj, dict):
        children = obj.values()
    elif isinstance(obj, (list, tuple)):
        children = obj
    else:
        return None
    for child in children:
        found = _find_tensor(child)
        if found is not None:
            return found
    return None
|
|
|
|
def compare_tensors(tensor_py, tensor_cpp, label):
    """Print cosine similarity and absolute-error metrics between two tensors.

    Both tensors are detached, cast to float64, moved to CPU and flattened before
    the metrics are computed. Tensors whose shapes differ but whose element counts
    match are still compared element-wise; a note is printed in that case.

    Args:
        tensor_py: tensor from the Python side.
        tensor_cpp: tensor from the C++ side.
        label: name used to prefix every printed line.

    Returns:
        ``{"cos_sim": float, "mae": float, "max_abs": float}``, or ``None`` when the
        element counts differ or the tensors are empty (nothing to compare).
    """
    if tensor_py.numel() != tensor_cpp.numel():
        # Report the original shapes; the flattened sizes alone hide the layout.
        print(f"{label}: Shape mismatch: {tensor_py.shape} vs {tensor_cpp.shape}")
        return None
    if tensor_py.numel() == 0:
        # torch.max() raises on an empty tensor, so bail out explicitly.
        print(f"{label}: Empty tensors, nothing to compare")
        return None
    if tensor_py.shape != tensor_cpp.shape:
        print(
            f"{label}: Note: shapes differ ({tensor_py.shape} vs {tensor_cpp.shape}), "
            "comparing flattened values"
        )

    # float64 keeps double inputs intact and reduces accumulation error in the
    # mean; detach() avoids building an autograd graph for parameter tensors.
    a = tensor_py.detach().double().cpu().reshape(-1)
    b = tensor_cpp.detach().double().cpu().reshape(-1)

    diff = (a - b).abs()
    cos_sim = cosine_similarity(a, b, dim=0).item()
    mae = diff.mean().item()
    max_abs = diff.max().item()

    print(f"{label}: cos_sim={cos_sim:.8f}, MAE={mae:.8e}, max_abs={max_abs:.8e}")
    return {"cos_sim": cos_sim, "mae": mae, "max_abs": max_abs}
|
|
|
|
def main():
    """Compare the Python and C++ bb_regressor debug dumps layer by layer.

    For each layer label, the Python dump
    ``test/output_py/bb_regressor/sample_0_debug_<label>_py.pt`` is compared with
    the C++ dump ``test/output/bb_regressor/sample_0_debug_<label>.pt``. Paths are
    relative to the current working directory. Missing files and files that fail
    to load are reported and skipped, so the remaining layers are still compared.
    """
    py_dir = "test/output_py/bb_regressor"
    cpp_dir = "test/output/bb_regressor"
    labels = ("conv3_1t", "conv3_2t", "conv4_1t", "conv4_2t")

    for label in labels:
        py_path = f"{py_dir}/sample_0_debug_{label}_py.pt"
        cpp_path = f"{cpp_dir}/sample_0_debug_{label}.pt"

        missing = [p for p in (py_path, cpp_path) if not os.path.exists(p)]
        if missing:
            print(f"{label}: Missing file(s): {' '.join(missing)}")
            continue

        try:
            tensor_py = load_tensor(py_path)
            tensor_cpp = load_tensor(cpp_path)
        except Exception as exc:  # script boundary: report and move on to the next layer
            print(f"{label}: Failed to load tensors: {exc}")
            continue

        compare_tensors(tensor_py, tensor_cpp, label)
|
|
|
|
if __name__ == "__main__":
    # Run the layer-by-layer comparison only when executed as a script,
    # not when this module is imported.
    main()
|