"""Re-export ResNet-50 weights as JIT-wrapped per-tensor files.

Loads per-tensor ``.pt`` files that follow the underscore naming convention
(e.g. ``layer1_0_conv1_weight.pt``) into a freshly built ``ltr`` ResNet-50,
then writes every parameter and buffer back out as a scripted
``TensorContainer`` using the dot naming convention
(e.g. ``layer1.0.conv1.weight.pt``) for loading from C++.
"""
import argparse
import os
import sys
from collections import OrderedDict

import torch
import torch.nn as nn

# Add ltr to path to import ResNet
workspace_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '.'))
print(f"Workspace root: {workspace_root}")
ltr_path = os.path.join(workspace_root, 'ltr')
if not os.path.isdir(ltr_path):
    print(f"Error: 'ltr' directory not found in {workspace_root}")
    sys.exit(1)
sys.path.insert(0, workspace_root)

try:
    from ltr.models.backbone.resnet import resnet50
    print("Successfully imported ResNet from ltr.models.backbone.resnet")
except ImportError as e:
    # Not fatal at import time, but the __main__ entry point needs resnet50 to
    # build the model; it checks for None and aborts with a clear message.
    resnet50 = None
    print(f"Warning: Could not import ResNet from ltr.models.backbone.resnet: {e}")


class TensorContainer(nn.Module):
    """Minimal module that holds a single tensor as a plain attribute.

    Scripting an instance with ``torch.jit.script`` produces an archive whose
    tensor can be read back by attribute name (``module.<tensor_name>`` in
    Python, ``module.attr("<tensor_name>")`` in C++).
    """

    def __init__(self, tensor_to_wrap, tensor_name="tensor"):
        super().__init__()
        # A plain (non-Parameter) tensor assigned via setattr is stored as an
        # ordinary module attribute, not as a registered parameter or buffer.
        setattr(self, tensor_name, tensor_to_wrap)


def convert_param_name_to_filename(param_name):
    """Converts a PyTorch parameter name (e.g., layer1.0.conv1.weight) to the
    underscore-separated filename convention (e.g., layer1_0_conv1_weight.pt).
    """
    return param_name.replace('.', '_') + '.pt'


def _load_tensor_file(filepath):
    """Load ``filepath`` onto the CPU, falling back to ``torch.jit.load``.

    ``torch.load`` is tried first. If it raises (for example, newer PyTorch
    releases refuse TorchScript archives while ``weights_only=True`` is the
    default), ``torch.jit.load`` is tried. If that also fails, the original
    ``torch.load`` error is re-raised, so the broad catch swallows nothing.
    """
    cpu = torch.device('cpu')
    try:
        return torch.load(filepath, map_location=cpu)
    except Exception as load_err:
        try:
            return torch.jit.load(filepath, map_location=cpu)
        except Exception:
            raise load_err


def _extract_tensor(loaded, filepath, prefer_buffer):
    """Return the tensor contained in ``loaded``.

    ``loaded`` is either already a tensor or a ``torch.jit.ScriptModule``.
    For a ScriptModule, the ``tensor`` attribute (as written by
    ``TensorContainer``) is used when present. Otherwise a module holding
    exactly one parameter or exactly one buffer is accepted.
    ``prefer_buffer`` selects which of the two is tried first, and how
    strictly the other one is checked.

    Raises:
        ValueError: the ScriptModule holds no identifiable single tensor.
        TypeError: the resulting object is not a ``torch.Tensor``.
    """
    if isinstance(loaded, torch.jit.ScriptModule):
        # Python equivalent of C++ ``module.attr("tensor").toTensor()``;
        # Python ScriptModules do not have an ``.attr`` method.
        wrapped = getattr(loaded, "tensor", None)
        if isinstance(wrapped, torch.Tensor):
            print(f"  INFO: Extracted tensor via 'tensor' attribute from ScriptModule: {filepath}")
            loaded = wrapped
        else:
            params = list(loaded.parameters())
            buffers = list(loaded.buffers())
            lone_param = len(params) == 1
            lone_buffer = len(buffers) == 1 and not params
            if prefer_buffer:
                # For buffer targets, a lone parameter is only accepted when
                # the module has no buffers at all.
                lone_param = lone_param and not buffers
                order = (("buffer", lone_buffer, buffers), ("parameter", lone_param, params))
            else:
                order = (("parameter", lone_param, params), ("buffer", lone_buffer, buffers))
            for label, usable, pool in order:
                if usable:
                    loaded = pool[0]
                    print(f"  INFO: Extracted tensor from single {label} of ScriptModule: {filepath}")
                    break
            else:
                raise ValueError(
                    f"ScriptModule at {filepath} doesn't have a 'tensor' attribute and not single param/buffer."
                )
    if not isinstance(loaded, torch.Tensor):
        raise TypeError(f"Loaded data from {filepath} is not a tensor (type: {type(loaded)})")
    return loaded


def _collect_tensors(named_templates, source_dir, kind, new_state_dict):
    """Load each ``(name, template)`` pair's file from ``source_dir``.

    Files are looked up via ``convert_param_name_to_filename``. Tensors whose
    shape matches the template are stored in ``new_state_dict`` under
    ``name``. ``kind`` ("param" or "buffer") is used in log messages and
    selects the ScriptModule fallback order.

    Returns:
        ``(loaded_count, missed_names)``.
    """
    missed = []
    loaded_count = 0
    for name, template in named_templates:
        filepath = os.path.join(source_dir, convert_param_name_to_filename(name))
        if not os.path.exists(filepath):
            missed.append(name)
            continue
        try:
            tensor_data = _extract_tensor(_load_tensor_file(filepath), filepath, prefer_buffer=(kind == "buffer"))
        except Exception as e:
            # Best-effort: report the bad file and keep going with the rest.
            print(f"  ERROR loading or processing {filepath} for {kind} {name}: {e}. Skipping.")
            missed.append(name)
            continue
        if tensor_data.shape != template.shape:
            print(
                f"  WARNING: Shape mismatch for {kind} {name}. Expected {template.shape}, "
                f"got {tensor_data.shape} from {filepath}. Skipping."
            )
            missed.append(name)
            continue
        new_state_dict[name] = tensor_data
        loaded_count += 1
    return loaded_count, missed


def load_weights_from_individual_files(model_to_populate, source_dir):
    """Populate ``model_to_populate`` from per-tensor files in ``source_dir``.

    Every parameter and buffer name is mapped to an underscore-named ``.pt``
    file. The tensors that are found are loaded with
    ``load_state_dict(strict=False)``, and missing or unexpected keys are
    reported. Calls ``sys.exit(1)`` when no tensor could be loaded at all.
    """
    print(f"Attempting to load weights from individual files in: {source_dir} using underscore naming convention.")
    new_state_dict = OrderedDict()
    param_count, missed_params = _collect_tensors(
        model_to_populate.named_parameters(), source_dir, "param", new_state_dict
    )
    buffer_count, missed_buffers = _collect_tensors(
        model_to_populate.named_buffers(), source_dir, "buffer", new_state_dict
    )
    loaded_count = param_count + buffer_count

    if loaded_count == 0:
        print("ERROR: No weights were found or loaded from individual files. Aborting.")
        sys.exit(1)

    print(f"Attempting to load {loaded_count} found tensors into model state_dict.")
    result = model_to_populate.load_state_dict(new_state_dict, strict=False)
    print("State_dict loading result:")
    if result.missing_keys:
        print(f"  Non-strict load missing_keys ({len(result.missing_keys)}): {result.missing_keys[:20]}...")
    if result.unexpected_keys:
        print(f"  Non-strict load unexpected_keys ({len(result.unexpected_keys)}): {result.unexpected_keys[:20]}...")

    # Cross check with our own missed lists
    print(f"Manually tracked missed parameters ({len(missed_params)}): {missed_params[:20]}...")
    print(f"Manually tracked missed buffers ({len(missed_buffers)}): {missed_buffers[:20]}...")

    # Every key the model expects but which was not found on disk.
    truly_missing_from_model = set(model_to_populate.state_dict().keys()) - set(new_state_dict.keys())
    if truly_missing_from_model:
        print(
            f"CRITICAL: Keys in model NOT found in source_dir ({len(truly_missing_from_model)}): "
            f"{sorted(truly_missing_from_model)[:20]}..."
        )

    if not truly_missing_from_model and not result.unexpected_keys:
        print("Successfully loaded weights from individual files into the model.")
    else:
        print("WARNING: Some weights might be missing or unexpected after loading from individual files.")


def export_jit_wrapped_tensors(model, output_dir):
    """Save every parameter and buffer of ``model`` as a scripted TensorContainer.

    Each tensor is written to ``<output_dir>/<dotted.name>.pt`` and stored
    under the container attribute ``"tensor"``. ``output_dir`` is created if
    it does not exist.
    """
    TENSOR_KEY_IN_CONTAINER = "tensor"  # The key used in TensorContainer and for C++ loading
    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
        print(f"Created output directory: {output_dir}")

    def _export_one(kind, name, tensor):
        # Dot naming: the file name is the state_dict key itself.
        filename = name + '.pt'
        filepath = os.path.join(output_dir, filename)
        print(f"Exporting JIT-wrapped {kind}: {name} (as {filename}) to {filepath} with shape {tensor.shape}")
        container = TensorContainer(tensor.detach().cpu().clone(), TENSOR_KEY_IN_CONTAINER)
        torch.jit.save(torch.jit.script(container), filepath)

    for name, param in model.named_parameters():
        _export_one("parameter", name, param)
    for name, buf in model.named_buffers():
        _export_one("buffer", name, buf)

    print(
        f"All params/buffers exported as JIT-wrapped tensors to {output_dir} "
        f"(using dot naming, key '{TENSOR_KEY_IN_CONTAINER}')."
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Load ResNet-50 weights from a directory of individual underscore_named .pt files, then re-export them as JIT-wrapped (TensorContainer) dot_named .pt files for C++ loading.")
    parser.add_argument('--source_individual_weights_dir', type=str, required=True, help="Directory containing the source underscore_named .pt files (e.g., 'exported_weights/backbone/').")
    parser.add_argument('--output_jit_wrapped_tensors_dir', type=str, required=True, help="Directory to save the re-exported JIT-wrapped dot_named .pt files (e.g., 'exported_weights/raw_backbone/').")
    args = parser.parse_args()

    if resnet50 is None:
        print("ERROR: resnet50 could not be imported from ltr.models.backbone.resnet; cannot build the model. Aborting.")
        sys.exit(1)

    print("Instantiating a new ResNet-50 model (will be populated from source dir)...")
    model = resnet50(output_layers=['layer4'], pretrained=False)
    print("ResNet-50 model instantiated.")

    load_weights_from_individual_files(model, args.source_individual_weights_dir)
    export_jit_wrapped_tensors(model, args.output_jit_wrapped_tensors_dir)

    print("Process complete. Weights loaded from source (underscore_named) and re-exported as JIT-wrapped tensors (dot_named).")