You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

28 lines
998 B

import os
import torch
from pathlib import Path
from ltr.models.backbone import resnet50
# Destination for the re-exported weights: one plain tensor file per
# state_dict entry. Created up front (including parents) so saving never fails
# on a missing directory; an existing directory is reused.
out_dir = Path('exported_weights') / 'backbone_pure_tensors'
out_dir.mkdir(exist_ok=True, parents=True)

# Build the backbone with the same layer outputs the tracker requests. No
# pretrained weights are fetched here; the real values are loaded afterwards
# from the split per-parameter files.
model = resnet50(
    output_layers=[f'layer{i}' for i in range(1, 5)],
    pretrained=False,
)
# Load weights from the split files (original directory)
def load_weights_from_tensors(model, tensor_dir):
    """Fill ``model``'s state_dict from one-tensor-per-file ``.pt`` exports.

    Each state_dict key ``k`` is looked up as
    ``<tensor_dir>/<k with '.' replaced by '_'>.pt``. Every file found is loaded
    onto the CPU and copied into the model via ``load_state_dict``. Keys with no
    matching file keep the value the model already holds.

    Args:
        model: a ``torch.nn.Module``; its parameters/buffers are updated in place.
        tensor_dir: directory (``str`` or path-like) containing the per-key files.

    Returns:
        list[str]: the state_dict keys for which no file was found, in
        state_dict order. When non-empty, a ``UserWarning`` is also emitted,
        because those entries silently keep their pre-existing (e.g. randomly
        initialised) values.
    """
    import warnings

    tensor_dir = Path(tensor_dir)
    sd = model.state_dict()
    missing = []
    for k in sd:
        tensor_path = tensor_dir / (k.replace('.', '_') + '.pt')
        if not tensor_path.exists():
            missing.append(k)
            continue
        # NOTE(review): torch.load unpickles the file; only point this at
        # trusted exports.
        sd[k] = torch.load(tensor_path, map_location='cpu')
    model.load_state_dict(sd)

    if missing:
        # Cap the listing so a wrong directory (every key missing) stays readable.
        preview = ', '.join(missing[:10]) + (', ...' if len(missing) > 10 else '')
        warnings.warn(
            f"No tensor file found in {tensor_dir} for {len(missing)} "
            f"state_dict key(s); they keep their current values: {preview}"
        )
    return missing
load_weights_from_tensors(model, 'exported_weights/backbone')

# Re-save every state_dict entry (parameters and buffers alike) as a bare
# tensor on the CPU, one file per key, with '.' in the key replaced by '_'.
state = model.state_dict()
for key in state:
    target = out_dir / (key.replace('.', '_') + '.pt')
    tensor = state[key].detach().cpu()
    torch.save(tensor, target)
    print(f"[OK] Saved {target}")