import os
import sys
from collections import OrderedDict

import torch
import torchvision.models as models


def export_weights(model, output_dir, doc_filename):
    """
    Exports model weights as individual .pt files and creates a documentation file.
    Each tensor in the model's state_dict is saved.
    Filename convention: replaces '.' in state_dict keys with '_', appends '.pt'.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        print(f"Created directory: {output_dir}")

    doc_lines = ["# Auto-generated weights documentation\n"]
    state_dict = model.state_dict()
    print(f"Exporting {len(state_dict)} tensors to {output_dir}...")

    for i, (key, tensor_data) in enumerate(state_dict.items(), start=1):
        # Underscore naming convention for filenames, matching what
        # DiMPTorchScriptWrapper expects when loading the tensors back.
        file_name = key.replace('.', '_') + '.pt'
        file_path = os.path.join(output_dir, file_name)

        # Save the tensor as a detached CPU copy.
        torch.save(tensor_data.clone().detach().cpu(), file_path)

        # Add an entry to the documentation file.
        doc_lines.append(f"## {key}\n")
        doc_lines.append(f"Shape: {list(tensor_data.shape)}\n")
        doc_lines.append(f"Dtype: {tensor_data.dtype}\n")
        doc_lines.append(f"File: {file_name}\n\n")

        # Print progress every 50 tensors.
        if i % 50 == 0:
            print(f"  Processed {i} tensors...")

    doc_file_path = os.path.join(output_dir, doc_filename)
    with open(doc_file_path, 'w') as f:
        f.writelines(doc_lines)

    print(f"Successfully exported {len(state_dict)} tensors.")
    print(f"Documentation file created: {doc_file_path}")
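

# Companion sketch to export_weights: shows how a consumer such as
# DiMPTorchScriptWrapper could map a state_dict key back to its exported file.
# The function name and signature are illustrative, not part of the original
# script; it is used only by the optional spot check at the end of __main__.
def load_exported_tensor(weights_dir, key):
    """Load a single tensor exported by export_weights, given its state_dict key."""
    file_name = key.replace('.', '_') + '.pt'
    return torch.load(os.path.join(weights_dir, file_name), map_location='cpu')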
print("Instantiating LTR ResNet-50 (for structure and param names)...") # Output_layers doesn't strictly matter here as we only need its state_dict keys, # but use a common setting. ltr_model = ltr_resnet50(output_layers=['layer1','layer2','layer3','layer4'], pretrained=False) ltr_model_state_dict_keys = ltr_model.state_dict().keys() print(f"LTR ResNet-50 instantiated. It has {len(ltr_model_state_dict_keys)} parameters/buffers.") # 2. Load the actual pretrained weights from torchvision. print("Loading pretrained ResNet-50 weights from torchvision...") torchvision_model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1) torchvision_state_dict = torchvision_model.state_dict() print("Torchvision ResNet-50 pretrained weights loaded.") # 3. Create a new state_dict that only contains keys present in *both* # the LTR model and the torchvision model, using torchvision's weights. # This handles potential mismatches like torchvision having a full 'fc' layer # that the LTR ResNet variant might not use or name identically. # Also, ltr.models.backbone.resnet.py applies its own normalization to conv layers # and fills batchnorm weights/biases. The torchvision pretrained=True model already has these. # So, we directly use the torchvision weights for matching keys. aligned_state_dict = OrderedDict() copied_keys = 0 torchvision_only_keys = [] ltr_only_keys_not_in_torchvision = [] for key in ltr_model_state_dict_keys: if key in torchvision_state_dict: if ltr_model.state_dict()[key].shape == torchvision_state_dict[key].shape: aligned_state_dict[key] = torchvision_state_dict[key].clone() copied_keys += 1 else: print(f" Shape mismatch for key '{key}': LTR shape {ltr_model.state_dict()[key].shape}, Torchvision shape {torchvision_state_dict[key].shape}. Skipping.") ltr_only_keys_not_in_torchvision.append(key + " (shape mismatch)") else: # If a key from LTR model is not in torchvision, it might be an architectural difference # or a buffer that torchvision doesn't save in its state_dict explicitly (e.g. num_batches_tracked for BN). # The LTR model initializes these, so we can take them from the un-trained ltr_model instance. # This is important for BN running_mean, running_var, and num_batches_tracked if not in torchvision sd. print(f" Key '{key}' in LTR model but not in Torchvision state_dict. Using LTR model's initial value for this key.") aligned_state_dict[key] = ltr_model.state_dict()[key].clone() # Use the initial value from ltr_model ltr_only_keys_not_in_torchvision.append(key + " (taken from LTR init)") copied_keys +=1 # Counting this as copied for completeness print(f"Matched and copied {copied_keys} Tensors from Torchvision to LTR structure.") if ltr_only_keys_not_in_torchvision: print(f" Keys in LTR model structure but not found in Torchvision pretrained state_dict (or shape mismatch): {len(ltr_only_keys_not_in_torchvision)}") for k_info in ltr_only_keys_not_in_torchvision[:10]: # Print first 10 print(f" - {k_info}") if len(ltr_only_keys_not_in_torchvision) > 10: print(" ...") for key in torchvision_state_dict.keys(): if key not in ltr_model_state_dict_keys: torchvision_only_keys.append(key) if torchvision_only_keys: print(f" Keys in Torchvision pretrained state_dict but not in LTR model structure: {len(torchvision_only_keys)}") for k in torchvision_only_keys[:10]: # Print first 10 print(f" - {k}") if len(torchvision_only_keys) > 10: print(" ...") # 4. Populate the LTR model instance with these aligned weights. # This isn't strictly necessary for saving, but it's a good check. 
print("Loading aligned state_dict into LTR model instance...") missing_keys, unexpected_keys = ltr_model.load_state_dict(aligned_state_dict, strict=False) # Use strict=False due to potential key differences if missing_keys: print(f" Warning: Missing keys when loading aligned_state_dict into LTR model: {missing_keys}") if unexpected_keys: print(f" Warning: Unexpected keys when loading aligned_state_dict into LTR model: {unexpected_keys}") if not missing_keys and not unexpected_keys: print(" Successfully loaded aligned state_dict into LTR model instance.") # 5. Now, use this populated ltr_model (or rather, its aligned_state_dict which has the correct keys and torchvision weights) # for the export_weights function. The export_weights function expects a model, but we can # give it an object that has a .state_dict() method returning our aligned_state_dict. class TempModelWrapper: def __init__(self, state_dict_to_serve): self._state_dict = state_dict_to_serve def state_dict(self): return self._state_dict model_to_export_from = TempModelWrapper(aligned_state_dict) output_directory = "exported_weights/backbone_regenerated" doc_file = "backbone_regenerated_weights_doc.txt" print(f"\nStarting export process to '{output_directory}'...") export_weights(model_to_export_from, output_directory, doc_file) print("\nScript finished.") print(f"Please check the '{output_directory}' for the .pt files and '{doc_file}'.") print("Next steps:") print("1. Update DiMPTorchScriptWrapper in pytracking/features/net_wrappers.py to use this new directory and doc file for ResNet.") print("2. Update C++ ResNet loading in cimp/resnet/resnet.cpp (and test_models.cpp) to use this new directory.") print("3. Re-run all tests (build.sh, then test/run_tests.sh).")