You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
36 lines
1.4 KiB
36 lines
1.4 KiB
import torch
|
|
import os
|
|
|
|
# Directory whose .pt files are treated as the Python-side export.
PYTHON_DIR = "exported_weights/backbone_pure_tensors/"
|
|
# Directory whose .pt files are treated as the C++-side export.
CPP_DIR = "exported_weights/backbone_regenerated/"
|
|
|
|
def compare_tensors(a, b, label):
    """Print similarity metrics between two tensors.

    Both inputs are converted to float32 on the CPU and flattened to 1-D
    before comparison, so only the element count and element order matter;
    the original dimensions are not compared.

    Args:
        a: First tensor.
        b: Second tensor.
        label: Prefix for the printed line (e.g. a file name).

    Returns:
        None. Exactly one line is printed: a shape-mismatch notice, an
        empty-tensor notice, or the cosine similarity, mean absolute error
        and maximum absolute error.
    """
    a = a.float().cpu().contiguous().view(-1)
    b = b.float().cpu().contiguous().view(-1)

    if a.shape != b.shape:
        print(f"{label}: Shape mismatch: {a.shape} vs {b.shape}")
        return

    if a.numel() == 0:
        # torch.max() raises RuntimeError on a zero-element tensor and the
        # mean would be NaN, so report the empty case explicitly instead.
        print(f"{label}: Empty tensors, nothing to compare")
        return

    cos_sim = torch.nn.functional.cosine_similarity(a, b, dim=0).item()

    # Compute the element-wise absolute difference once for both metrics.
    abs_diff = torch.abs(a - b)
    mae = torch.mean(abs_diff).item()
    max_abs = torch.max(abs_diff).item()

    print(f"{label}: cos_sim={cos_sim:.8f}, MAE={mae:.8e}, max_abs={max_abs:.8e}")
|
|
|
|
def main():
    """Compare every ``.pt`` file present in both export directories.

    Lists PYTHON_DIR and CPP_DIR, reports file names that exist in only one
    of them, then loads each shared file onto the CPU and prints its metrics
    through compare_tensors, in sorted file-name order.
    """
    python_names = set()
    for entry in os.listdir(PYTHON_DIR):
        if entry.endswith('.pt'):
            python_names.add(entry)

    cpp_names = set()
    for entry in os.listdir(CPP_DIR):
        if entry.endswith('.pt'):
            cpp_names.add(entry)

    only_in_python = sorted(python_names.difference(cpp_names))
    only_in_cpp = sorted(cpp_names.difference(python_names))

    if only_in_python:
        print("Files missing in C++ export:", only_in_python)
    if only_in_cpp:
        print("Files missing in Python export:", only_in_cpp)

    # NOTE: weights_only=False allows torch.load to unpickle arbitrary
    # objects; only run this against export directories you trust.
    load_options = {"map_location": 'cpu', "weights_only": False}

    for shared_name in sorted(python_names.intersection(cpp_names)):
        python_path = os.path.join(PYTHON_DIR, shared_name)
        cpp_path = os.path.join(CPP_DIR, shared_name)
        python_side = torch.load(python_path, **load_options)
        cpp_side = torch.load(cpp_path, **load_options)
        compare_tensors(python_side, cpp_side, shared_name)
|
|
|
|
# Script entry point: run the comparison only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
|
|
main()
|