You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
178 lines
9.4 KiB
178 lines
9.4 KiB
import torch
|
|
import torchvision.models as models
|
|
import os
|
|
from collections import OrderedDict
|
|
|
|
def export_weights(model, output_dir, doc_filename, *, progress_every=50):
    """
    Export every tensor of ``model.state_dict()`` as its own .pt file and write a
    plain-text documentation file describing them.

    Filename convention: '.' in each state_dict key is replaced with '_' and
    '.pt' is appended (e.g. ``layer1.0.conv1.weight`` -> ``layer1_0_conv1_weight.pt``).

    Args:
        model: Any object exposing ``state_dict()`` returning a mapping of string
            keys to tensors (an ``nn.Module`` or a thin wrapper around a dict).
        output_dir: Directory to write into; created (including parents) if missing.
        doc_filename: Name of the documentation file, written inside ``output_dir``.
        progress_every: Print a progress line after every this many tensors.
            0 or a negative value disables progress output.

    Raises:
        ValueError: If two different keys map to the same file name. Without
            this check, one tensor would silently overwrite the other. The check
            runs before anything is written or created.
    """
    state_dict = model.state_dict()

    # Resolve all file names up front so a collision is reported before any
    # side effect (e.g. "a.b" and "a_b" both become "a_b.pt").
    file_names = {key: key.replace('.', '_') + '.pt' for key in state_dict}
    owner_of_file = {}
    for key, file_name in file_names.items():
        if file_name in owner_of_file:
            raise ValueError(
                f"State dict keys {owner_of_file[file_name]!r} and {key!r} "
                f"both map to file name {file_name!r}"
            )
        owner_of_file[file_name] = key

    if not os.path.isdir(output_dir):
        os.makedirs(output_dir, exist_ok=True)
        print(f"Created directory: {output_dir}")

    doc_lines = ["# Auto-generated weights documentation\n"]

    print(f"Exporting {len(state_dict)} tensors to {output_dir}...")

    for index, (key, tensor_data) in enumerate(state_dict.items(), start=1):
        # Use underscore naming convention for filename, matching DiMPTorchScriptWrapper expectations
        file_name = file_names[key]
        file_path = os.path.join(output_dir, file_name)

        # detach() first so the copy is not tracked by autograd; clone() gives the
        # saved tensor its own storage instead of serializing a shared one.
        torch.save(tensor_data.detach().clone().cpu(), file_path)

        # Add entry to documentation file
        doc_lines.append(f"## {key}\n")
        doc_lines.append(f"Shape: {list(tensor_data.shape)}\n")
        doc_lines.append(f"Dtype: {tensor_data.dtype}\n")
        doc_lines.append(f"File: {file_name}\n\n")

        # Count tensors directly; len(doc_lines) is always 1 + 4*index and
        # therefore never a multiple of an even interval.
        if progress_every > 0 and index % progress_every == 0:
            print(f" Processed {index} tensors...")

    doc_file_path = os.path.join(output_dir, doc_filename)
    with open(doc_file_path, 'w', encoding='utf-8') as f:
        f.writelines(doc_lines)

    print(f"Successfully exported {len(state_dict)} tensors.")
    print(f"Documentation file created: {doc_file_path}")
|
|
|
|
def _print_key_sample(entries, limit=10):
    """Print at most ``limit`` entries as a bulleted list, then ' ...' if truncated."""
    for entry in entries[:limit]:
        print(f" - {entry}")
    if len(entries) > limit:
        print(" ...")


def _align_to_reference(reference_state_dict, pretrained_state_dict):
    """
    Build a state_dict with the reference (LTR) keys, filled from the pretrained one.

    For every key of ``reference_state_dict``, in its order:
      * present in ``pretrained_state_dict`` with the same shape -> a clone of the
        pretrained tensor is used;
      * present but with a different shape -> reported and left OUT of the result;
      * absent from ``pretrained_state_dict`` -> a clone of the reference's own
        (initial, not pretrained) value is used, and this is reported.

    Returns:
        ``(aligned_state_dict, copied_count, filled_count, ltr_only_info)``.
        ``copied_count`` counts tensors taken from the pretrained dict.
        ``filled_count`` counts tensors taken from the reference's initial values.
        ``ltr_only_info`` is a list of "<key> (<reason>)" strings, one per
        reference key that was not copied from the pretrained dict.
    """
    aligned_state_dict = OrderedDict()
    copied_count = 0
    filled_count = 0
    ltr_only_info = []

    for key, ref_tensor in reference_state_dict.items():
        if key not in pretrained_state_dict:
            # Likely an architectural difference between the two ResNet variants.
            # NOTE(review): these values come from the untrained LTR instance,
            # so they are NOT pretrained weights.
            print(f" Key '{key}' in LTR model but not in Torchvision state_dict. Using LTR model's initial value for this key.")
            aligned_state_dict[key] = ref_tensor.clone()
            ltr_only_info.append(key + " (taken from LTR init)")
            filled_count += 1
            continue

        pretrained_tensor = pretrained_state_dict[key]
        if ref_tensor.shape == pretrained_tensor.shape:
            aligned_state_dict[key] = pretrained_tensor.clone()
            copied_count += 1
        else:
            print(f" Shape mismatch for key '{key}': LTR shape {ref_tensor.shape}, Torchvision shape {pretrained_tensor.shape}. Skipping.")
            ltr_only_info.append(key + " (shape mismatch)")

    return aligned_state_dict, copied_count, filled_count, ltr_only_info


def main():
    """
    Regenerate ResNet-50 backbone weights for DiMPTorchScriptWrapper / the C++ tracker.

    Takes parameter names from ltr.models.backbone.resnet.resnet50 and values from
    torchvision's ImageNet-pretrained resnet50 where names and shapes match. The
    result is written with export_weights() to exported_weights/backbone_regenerated.
    """
    # --- Configuration ---
    # For ResNet-50, the original DiMP seems to use a ResNet variant that might not
    # exactly match torchvision's default ResNet-50 in terms of all parameter names
    # or structure, especially if it was modified for specific feature extraction.
    # The ltr.models.backbone.resnet.resnet50 is the one used by DiMPTorchScriptWrapper.
    # We need to ensure the keys from this model are used for saving, so that
    # DiMPTorchScriptWrapper can load them correctly.

    print("Loading reference ResNet-50 model structure (ltr.models.backbone.resnet)...")
    import sys

    # This import assumes 'ltr' is importable: run from the cpp_tracker root (the
    # script's own directory is on sys.path when run directly) or put ltr on
    # PYTHONPATH. If this script lives in a subdirectory such as cpp_tracker/test/,
    # insert the parent directory into sys.path before this import, e.g.:
    #   sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    try:
        from ltr.models.backbone.resnet import resnet50 as ltr_resnet50
    except ImportError as e:
        print(f"Error importing ltr.models.backbone.resnet: {e}")
        print("Please ensure that the 'ltr' module is in your PYTHONPATH, or adjust sys.path in this script.")
        print("You might need to run this script from the root of the cpp_tracker workspace or ensure correct setup.")
        sys.exit(1)

    # 1. Create an instance of the LTR ResNet-50 to get the correct parameter names and structure.
    # This model defines the *target* state_dict keys we want to save.
    print("Instantiating LTR ResNet-50 (for structure and param names)...")
    # output_layers doesn't strictly matter here as only its state_dict is needed.
    ltr_model = ltr_resnet50(output_layers=['layer1', 'layer2', 'layer3', 'layer4'], pretrained=False)
    # Snapshot once: every state_dict() call rebuilds the whole mapping.
    ltr_state_dict = ltr_model.state_dict()
    print(f"LTR ResNet-50 instantiated. It has {len(ltr_state_dict)} parameters/buffers.")

    # 2. Load the actual pretrained weights from torchvision.
    print("Loading pretrained ResNet-50 weights from torchvision...")
    torchvision_model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
    torchvision_state_dict = torchvision_model.state_dict()
    print("Torchvision ResNet-50 pretrained weights loaded.")

    # 3. Build a state_dict with the LTR keys, using torchvision's weights wherever
    # the key and shape match. Whatever initialisation the LTR constructor applies
    # only matters for keys torchvision does not provide.
    (aligned_state_dict, copied_keys, filled_keys,
     ltr_only_keys_not_in_torchvision) = _align_to_reference(ltr_state_dict, torchvision_state_dict)

    print(f"Matched and copied {copied_keys} Tensors from Torchvision to LTR structure.")
    if filled_keys:
        print(f"Filled {filled_keys} Tensors from the LTR model's initial values (not pretrained).")
    if ltr_only_keys_not_in_torchvision:
        print(f" Keys in LTR model structure but not found in Torchvision pretrained state_dict (or shape mismatch): {len(ltr_only_keys_not_in_torchvision)}")
        _print_key_sample(ltr_only_keys_not_in_torchvision)

    # Typically torchvision's classifier head ('fc.*'), which the backbone does not use.
    torchvision_only_keys = [key for key in torchvision_state_dict if key not in ltr_state_dict]
    if torchvision_only_keys:
        print(f" Keys in Torchvision pretrained state_dict but not in LTR model structure: {len(torchvision_only_keys)}")
        _print_key_sample(torchvision_only_keys)

    # 4. Populate the LTR model instance with these aligned weights.
    # This isn't strictly necessary for saving, but it's a good consistency check.
    print("Loading aligned state_dict into LTR model instance...")
    # strict=False: shape-mismatched keys were left out of aligned_state_dict.
    missing_keys, unexpected_keys = ltr_model.load_state_dict(aligned_state_dict, strict=False)

    if missing_keys:
        print(f" Warning: Missing keys when loading aligned_state_dict into LTR model: {missing_keys}")
    if unexpected_keys:
        print(f" Warning: Unexpected keys when loading aligned_state_dict into LTR model: {unexpected_keys}")
    if not missing_keys and not unexpected_keys:
        print(" Successfully loaded aligned state_dict into LTR model instance.")

    # 5. export_weights() only needs an object with a .state_dict() method; serve the
    # aligned dict (LTR keys + torchvision weights) through a minimal wrapper.
    class TempModelWrapper:
        def __init__(self, state_dict_to_serve):
            self._state_dict = state_dict_to_serve

        def state_dict(self):
            return self._state_dict

    model_to_export_from = TempModelWrapper(aligned_state_dict)

    output_directory = "exported_weights/backbone_regenerated"
    doc_file = "backbone_regenerated_weights_doc.txt"

    print(f"\nStarting export process to '{output_directory}'...")
    export_weights(model_to_export_from, output_directory, doc_file)

    print("\nScript finished.")
    print(f"Please check the '{output_directory}' for the .pt files and '{doc_file}'.")
    print("Next steps:")
    print("1. Update DiMPTorchScriptWrapper in pytracking/features/net_wrappers.py to use this new directory and doc file for ResNet.")
    print("2. Update C++ ResNet loading in cimp/resnet/resnet.cpp (and test_models.cpp) to use this new directory.")
    print("3. Re-run all tests (build.sh, then test/run_tests.sh).")


if __name__ == "__main__":
    main()
|