You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

205 lines
11 KiB

import torch
import torch.nn as nn
import os
import argparse
import sys
from collections import OrderedDict
# Make the directory that holds this script importable so the `ltr` package
# (expected to live right next to it) can be resolved below.
workspace_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '.'))
print(f"Workspace root: {workspace_root}")
ltr_path = os.path.join(workspace_root, 'ltr')
if os.path.isdir(ltr_path):
    sys.path.insert(0, workspace_root)
else:
    # Nothing below can work without the package directory; stop immediately.
    print(f"Error: 'ltr' directory not found in {workspace_root}")
    sys.exit(1)

try:
    # resnet50 is the factory the __main__ section uses to build the model.
    from ltr.models.backbone.resnet import resnet50
except ImportError as e:
    # Only a warning is emitted here; `resnet50` is left undefined in that case.
    print(f"Warning: Could not import ResNet from ltr.models.backbone.resnet: {e}")
else:
    print("Successfully imported ResNet from ltr.models.backbone.resnet")
class TensorContainer(nn.Module):
    """Tiny module whose only purpose is to carry one tensor as an attribute.

    Args:
        tensor_to_wrap: The object to attach to the module.
        tensor_name: Attribute name under which it is attached
            (defaults to ``"tensor"``).
    """

    def __init__(self, tensor_to_wrap, tensor_name="tensor"):
        super(TensorContainer, self).__init__()
        # Attached through ordinary attribute assignment rather than
        # register_buffer / register_parameter, which (per the original
        # author's note) place restrictions on the key they accept.
        self.__setattr__(tensor_name, tensor_to_wrap)
def convert_param_name_to_filename(param_name):
    """Map a dotted PyTorch state-dict key to its underscore-style file name.

    Every ``.`` in ``param_name`` becomes ``_`` and the ``.pt`` extension is
    appended, e.g. ``layer1.0.conv1.weight`` -> ``layer1_0_conv1_weight.pt``.
    """
    stem = "_".join(param_name.split("."))
    return f"{stem}.pt"
def _extract_tensor_from_script_module(script_module, filepath, is_buffer):
    """Return the tensor carried by a loaded TorchScript module.

    First looks for an attribute named ``tensor`` (the key used when a tensor
    is wrapped for export); otherwise falls back to the module's single
    parameter or single buffer.

    Raises:
        ValueError: If no ``tensor`` attribute exists and the module does not
            hold exactly one usable parameter/buffer.
    """
    # `.attr("tensor").toTensor()` is the C++ (libtorch) spelling; from Python
    # the attribute is read with plain getattr. AttributeError is what a
    # missing attribute raises; RuntimeError is kept for backward compatibility.
    try:
        wrapped = getattr(script_module, "tensor")
    except (AttributeError, RuntimeError):
        pass
    else:
        print(f" INFO: Extracted tensor via .attr('tensor') from ScriptModule: {filepath}")
        return wrapped

    params = list(script_module.parameters())
    buffers = list(script_module.buffers())
    if is_buffer:
        # Buffer files: prefer the lone buffer, then a lone parameter.
        if len(buffers) == 1 and not params:
            print(f" INFO: Extracted tensor from single buffer of ScriptModule: {filepath}")
            return buffers[0]
        if len(params) == 1 and not buffers:
            print(f" INFO: Extracted tensor from single param of ScriptModule: {filepath}")
            return params[0]
    else:
        # Parameter files: a lone parameter wins even if buffers are present.
        if len(params) == 1:
            print(f" INFO: Extracted tensor from single parameter of ScriptModule: {filepath}")
            return params[0]
        if len(buffers) == 1 and not params:
            print(f" INFO: Extracted tensor from single buffer of ScriptModule: {filepath}")
            return buffers[0]
    raise ValueError(f"ScriptModule at {filepath} doesn't have .attr('tensor') and not single param/buffer.")


def _load_tensor_group(named_tensors, kind, source_dir, state_dict, missed):
    """Load the files for one group of model tensors; return how many loaded.

    ``kind`` is ``"param"`` or ``"buffer"`` and is used in log messages and to
    pick the ScriptModule fallback order. Loaded tensors are written into
    ``state_dict``; names that could not be loaded are appended to ``missed``.
    """
    is_buffer = kind == "buffer"
    loaded = 0
    for name, template in named_tensors:
        filepath = os.path.join(source_dir, convert_param_name_to_filename(name))
        if not os.path.exists(filepath):
            missed.append(name)
            continue
        try:
            # NOTE(review): torch.load unpickles the file; only point this at
            # weight directories you trust.
            tensor_data = torch.load(filepath, map_location=torch.device('cpu'))
            if isinstance(tensor_data, torch.jit.ScriptModule):
                tensor_data = _extract_tensor_from_script_module(tensor_data, filepath, is_buffer)
            if not isinstance(tensor_data, torch.Tensor):
                raise TypeError(f"Loaded data from {filepath} is not a tensor (type: {type(tensor_data)})")
            expected_shape = template.data.shape
        except Exception as e:
            # Deliberate best-effort: one bad file must not abort the whole run.
            print(f" ERROR loading or processing {filepath} for {kind} {name}: {e}. Skipping.")
            missed.append(name)
            continue
        if tensor_data.shape != expected_shape:
            print(f" WARNING: Shape mismatch for {kind} {name}. Expected {expected_shape}, got {tensor_data.shape} from {filepath}. Skipping.")
            missed.append(name)
            continue
        state_dict[name] = tensor_data
        loaded += 1
    return loaded


def load_weights_from_individual_files(model_to_populate, source_dir):
    """Populate a model from a directory of one-tensor-per-file ``.pt`` files.

    For every parameter and buffer of ``model_to_populate`` the file
    ``<name with '.' replaced by '_'>.pt`` is looked up in ``source_dir``.
    Each file may hold a plain tensor or a TorchScript module wrapping one.
    Tensors whose file is missing, unreadable, or has a mismatching shape are
    skipped and reported; everything else is applied with a non-strict
    ``load_state_dict``.

    Args:
        model_to_populate: Module whose parameters/buffers are filled in place.
        source_dir: Directory containing the underscore-named ``.pt`` files.

    Raises:
        SystemExit: With status 1 when not a single tensor could be loaded.
    """
    print(f"Attempting to load weights from individual files in: {source_dir} using underscore naming convention.")
    new_state_dict = OrderedDict()
    missed_params = []
    missed_buffers = []

    loaded_count = _load_tensor_group(
        model_to_populate.named_parameters(), "param", source_dir, new_state_dict, missed_params
    )
    loaded_count += _load_tensor_group(
        model_to_populate.named_buffers(), "buffer", source_dir, new_state_dict, missed_buffers
    )

    if loaded_count == 0:
        print("ERROR: No weights were found or loaded from individual files. Aborting.")
        sys.exit(1)

    print(f"Attempting to load {loaded_count} found tensors into model state_dict.")
    result = model_to_populate.load_state_dict(new_state_dict, strict=False)
    print("State_dict loading result:")
    if result.missing_keys:
        print(f" Strict load missing_keys ({len(result.missing_keys)}): {result.missing_keys[:20]}...")  # Print first 20
    if result.unexpected_keys:
        print(f" Strict load unexpected_keys ({len(result.unexpected_keys)}): {result.unexpected_keys[:20]}...")

    # Cross check with our own missed lists
    print(f"Manually tracked missed parameters ({len(missed_params)}): {missed_params[:20]}...")
    print(f"Manually tracked missed buffers ({len(missed_buffers)}): {missed_buffers[:20]}...")

    # Any state_dict key of the model that was not supplied keeps its
    # freshly-initialised value, so report it loudly.
    all_model_keys = set(model_to_populate.state_dict().keys())
    loaded_keys = set(new_state_dict.keys())
    truly_missing_from_model = all_model_keys - loaded_keys
    if truly_missing_from_model:
        print(f"CRITICAL: Keys in model NOT found in source_dir ({len(truly_missing_from_model)}): {list(truly_missing_from_model)[:20]}...")

    if not truly_missing_from_model and not result.unexpected_keys:
        print("Successfully loaded weights from individual files into the model.")
    else:
        print("WARNING: Some weights might be missing or unexpected after loading from individual files.")
def export_jit_wrapped_tensors(model, output_dir):
    """Write every parameter and buffer of ``model`` as its own TorchScript file.

    Each tensor is copied to CPU, wrapped in a ``TensorContainer`` under the
    attribute name ``"tensor"``, scripted, and saved to
    ``<output_dir>/<dotted name>.pt``. ``output_dir`` is created if absent.

    Args:
        model: Module whose named parameters and buffers are exported.
        output_dir: Destination directory for the ``.pt`` files.
    """
    TENSOR_KEY_IN_CONTAINER = "tensor"  # The key used in TensorContainer and for C++ loading

    def _save_wrapped(tensor, destination):
        # Detached CPU copy -> container module -> TorchScript -> file on disk.
        cpu_copy = tensor.data.clone().detach().cpu()
        wrapped = TensorContainer(cpu_copy, TENSOR_KEY_IN_CONTAINER)
        torch.jit.save(torch.jit.script(wrapped), destination)

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        print(f"Created output directory: {output_dir}")

    for name, param in model.named_parameters():
        filename = name + '.pt'
        filepath = os.path.join(output_dir, filename)
        print(f"Exporting JIT-wrapped parameter: {name} (as {filename}) to {filepath} with shape {param.data.shape}")
        _save_wrapped(param, filepath)

    for name, buf in model.named_buffers():
        filename = name + '.pt'
        filepath = os.path.join(output_dir, filename)
        print(f"Exporting JIT-wrapped buffer: {name} (as {filename}) to {filepath} with shape {buf.data.shape}")
        _save_wrapped(buf, filepath)

    print(f"All params/buffers exported as JIT-wrapped tensors to {output_dir} (using dot naming, key '{TENSOR_KEY_IN_CONTAINER}').")
# Script entry point: parse the two directory arguments, build an uninitialised
# ResNet-50, fill it from the source directory, then re-export every tensor.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Load ResNet-50 weights from a directory of individual underscore_named .pt files, then re-export them as JIT-wrapped (TensorContainer) dot_named .pt files for C++ loading.")
    parser.add_argument('--source_individual_weights_dir', type=str, required=True,
                        help="Directory containing the source underscore_named .pt files (e.g., 'exported_weights/backbone/').")
    parser.add_argument('--output_jit_wrapped_tensors_dir', type=str, required=True,
                        help="Directory to save the re-exported JIT-wrapped dot_named .pt files (e.g., 'exported_weights/raw_backbone/').")
    args = parser.parse_args()

    # The ResNet import at the top of this file only prints a warning when it
    # fails, which would otherwise surface here as a bare NameError. Resolve
    # the factory explicitly and stop with a readable message instead.
    resnet50_factory = globals().get("resnet50")
    if resnet50_factory is None:
        print("Error: resnet50 could not be imported from ltr.models.backbone.resnet; cannot build the model. Aborting.")
        sys.exit(1)

    print("Instantiating a new ResNet-50 model (will be populated from source dir)...")
    model = resnet50_factory(output_layers=['layer4'], pretrained=False)
    print("ResNet-50 model instantiated.")

    load_weights_from_individual_files(model, args.source_individual_weights_dir)
    export_jit_wrapped_tensors(model, args.output_jit_wrapped_tensors_dir)
    print("Process complete. Weights loaded from source (underscore_named) and re-exported as JIT-wrapped tensors (dot_named).")