You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
28 lines
998 B
28 lines
998 B
import os
|
|
import torch
|
|
from pathlib import Path
|
|
from ltr.models.backbone import resnet50
|
|
|
|
# Destination for the re-exported weights: one plain-tensor ``.pt`` file per
# state_dict entry. Missing parent directories are created, and an existing
# directory is reused rather than treated as an error.
out_dir = Path('exported_weights/backbone_pure_tensors')
out_dir.mkdir(exist_ok=True, parents=True)

# Build the backbone the same way the tracker does. It is constructed with
# pretrained=False; its weights are filled in afterwards from the per-tensor
# files by load_weights_from_tensors.
OUTPUT_LAYERS = ['layer1', 'layer2', 'layer3', 'layer4']
model = resnet50(pretrained=False, output_layers=OUTPUT_LAYERS)
|
|
|
|
def load_weights_from_tensors(model, tensor_dir, strict=False):
    """Load a module's state_dict entries from one-file-per-tensor ``.pt`` files.

    Each state_dict key ``k`` is looked up as
    ``<tensor_dir>/<k with '.' replaced by '_'>.pt``. For example,
    ``layer1.0.conv1.weight`` maps to ``layer1_0_conv1_weight.pt``.
    Entries without a matching file keep the module's current values.
    Their keys are reported, and ``strict=True`` rejects them entirely.

    Args:
        model: Module whose state is updated in place via ``load_state_dict``.
        tensor_dir: Directory (``str`` or ``Path``) holding the tensor files.
        strict: If True, raise instead of skipping entries that have no file.
            In that case the model is left unmodified.

    Returns:
        List of state_dict keys for which no file was found. The list is
        empty when every entry was loaded.

    Raises:
        FileNotFoundError: If ``strict`` is True and any file is missing.
    """
    tensor_dir = Path(tensor_dir)
    state = model.state_dict()
    missing = []
    for key in state:
        tensor_path = tensor_dir / (key.replace('.', '_') + '.pt')
        if not tensor_path.exists():
            missing.append(key)
            continue
        # NOTE(review): torch.load may unpickle arbitrary objects depending
        # on the torch version. Only point this at trusted files, or pass
        # weights_only=True on torch >= 1.13.
        state[key] = torch.load(tensor_path, map_location='cpu')

    if missing:
        if strict:
            # Raise before load_state_dict so the model is left untouched.
            raise FileNotFoundError(
                f"{len(missing)} tensor file(s) missing in {tensor_dir}: {missing}"
            )
        # Make skipped entries visible. Otherwise the model silently keeps its
        # current (e.g. freshly initialised) values, and those would be
        # exported as if they were real weights.
        print(
            f"[WARN] {len(missing)} state_dict entries have no tensor file in "
            f"{tensor_dir} and keep their current values: {missing}"
        )

    model.load_state_dict(state)
    return missing
|
|
|
|
# Fill the backbone from the original one-file-per-tensor export directory.
load_weights_from_tensors(model, tensor_dir='exported_weights/backbone')

# Re-export each state_dict entry as a detached CPU tensor in its own file
# under out_dir. The file is named after the key with '.' replaced by '_'.
for name, param in model.state_dict().items():
    file_stem = name.replace('.', '_')
    out_path = out_dir / f'{file_stem}.pt'
    plain = param.detach().cpu()
    torch.save(plain, out_path)
    print(f"[OK] Saved {out_path}")