You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
205 lines
11 KiB
205 lines
11 KiB
import torch
|
|
import torch.nn as nn
|
|
import os
|
|
import argparse
|
|
import sys
|
|
from collections import OrderedDict
|
|
|
|
# Put the directory that contains this script on sys.path so the sibling
# `ltr` package can be imported below.
workspace_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '.'))
print(f"Workspace root: {workspace_root}")

# Fail fast when the expected `ltr` package directory is not next to the script.
ltr_path = os.path.join(workspace_root, 'ltr')
if not os.path.isdir(ltr_path):
    print(f"Error: 'ltr' directory not found in {workspace_root}")
    sys.exit(1)
sys.path.insert(0, workspace_root)

# The import is best-effort: a failure only produces a warning here, so the
# name `resnet50` may be undefined for the rest of the module.
try:
    from ltr.models.backbone.resnet import resnet50
except ImportError as e:
    print(f"Warning: Could not import ResNet from ltr.models.backbone.resnet: {e}")
else:
    print("Successfully imported ResNet from ltr.models.backbone.resnet")
|
|
|
|
class TensorContainer(nn.Module):
    """Minimal module that carries a single tensor under a chosen attribute name.

    The tensor is attached with ``setattr`` as a plain attribute; it is not
    registered as a buffer or a parameter, so it does not show up in
    ``parameters()`` or ``buffers()``.

    Args:
        tensor_to_wrap: The tensor to store on the module.
        tensor_name: Attribute name under which the tensor is stored.
    """

    def __init__(self, tensor_to_wrap, tensor_name="tensor"):
        super().__init__()
        # Plain attribute assignment: the wrapped tensor is reachable as
        # ``getattr(container, tensor_name)``.
        setattr(self, tensor_name, tensor_to_wrap)
|
|
|
|
def convert_param_name_to_filename(param_name):
    """Map a dotted PyTorch state-dict key to its underscore-style file name.

    Every ``.`` in ``param_name`` becomes ``_`` and the ``.pt`` extension is
    appended, e.g. ``layer1.0.conv1.weight`` -> ``layer1_0_conv1_weight.pt``.
    """
    underscored = "_".join(param_name.split("."))
    return f"{underscored}.pt"
|
|
|
|
def _unwrap_script_module(script_module, filepath, kind):
    """Extract the single tensor carried by a loaded TorchScript module.

    The attribute named ``tensor`` is tried first. If it is absent, the module
    must hold exactly one parameter or one buffer; ``kind`` ("param" or
    "buffer") selects which of the two is checked first.

    Raises:
        ValueError: If no unambiguous tensor can be found in the module.
    """
    # Python ScriptModules expose their attributes through normal attribute
    # access; `.attr(...).toTensor()` belongs to the C++ API and raised
    # AttributeError here, which made every JIT-wrapped file get skipped.
    if hasattr(script_module, "tensor"):
        print(f" INFO: Extracted tensor via .attr('tensor') from ScriptModule: {filepath}")
        return script_module.tensor

    params = list(script_module.parameters())
    buffers = list(script_module.buffers())
    if kind == "param":
        if len(params) == 1:
            print(f" INFO: Extracted tensor from single parameter of ScriptModule: {filepath}")
            return params[0]
        if len(buffers) == 1 and not params:
            print(f" INFO: Extracted tensor from single buffer of ScriptModule: {filepath}")
            return buffers[0]
    else:
        if len(buffers) == 1 and not params:
            print(f" INFO: Extracted tensor from single buffer of ScriptModule: {filepath}")
            return buffers[0]
        if len(params) == 1 and not buffers:
            print(f" INFO: Extracted tensor from single param of ScriptModule: {filepath}")
            return params[0]
    raise ValueError(f"ScriptModule at {filepath} doesn't have .attr('tensor') and not single param/buffer.")


def _load_tensor_file(filepath, kind):
    """Load one ``.pt`` file onto the CPU and return the tensor it contains.

    Raises:
        TypeError: If the file's content is neither a tensor nor a
            ScriptModule wrapping one.
        ValueError: If a ScriptModule does not hold an unambiguous tensor.
    """
    # NOTE(review): whether torch.load accepts TorchScript archives depends on
    # the installed PyTorch version and its `weights_only` default - confirm.
    loaded = torch.load(filepath, map_location=torch.device('cpu'))
    if isinstance(loaded, torch.jit.ScriptModule):
        loaded = _unwrap_script_module(loaded, filepath, kind)
    if not isinstance(loaded, torch.Tensor):
        raise TypeError(f"Loaded data from {filepath} is not a tensor (type: {type(loaded)})")
    return loaded


def _collect_tensors(named_templates, source_dir, kind, state_dict, missed):
    """Fill ``state_dict`` with tensors found on disk for ``named_templates``.

    Names whose file is absent, unreadable, or of the wrong shape are appended
    to ``missed``. Returns the number of tensors added to ``state_dict``.
    """
    loaded = 0
    for name, template in named_templates:
        filepath = os.path.join(source_dir, convert_param_name_to_filename(name))
        if not os.path.exists(filepath):
            missed.append(name)
            continue
        try:
            tensor_data = _load_tensor_file(filepath, kind)
        except Exception as e:
            # Deliberately best-effort: one bad file must not abort the run.
            print(f" ERROR loading or processing {filepath} for {kind} {name}: {e}. Skipping.")
            missed.append(name)
            continue
        if tensor_data.shape != template.data.shape:
            print(f" WARNING: Shape mismatch for {kind} {name}. Expected {template.data.shape}, got {tensor_data.shape} from {filepath}. Skipping.")
            missed.append(name)
            continue
        state_dict[name] = tensor_data
        loaded += 1
    return loaded


def load_weights_from_individual_files(model_to_populate, source_dir):
    """Populate a model from a directory of one-tensor-per-file ``.pt`` files.

    For every parameter and buffer of ``model_to_populate`` the file
    ``<name with dots replaced by underscores>.pt`` is looked up in
    ``source_dir``. Files may contain a raw tensor or a TorchScript module
    wrapping one. Tensors with a matching shape are loaded into the model
    with ``load_state_dict(strict=False)`` and a summary is printed.

    Args:
        model_to_populate: Module whose parameters and buffers are filled in
            place.
        source_dir: Directory containing the underscore-named ``.pt`` files.

    Raises:
        SystemExit: With exit code 1 if no tensor could be loaded at all.
    """
    print(f"Attempting to load weights from individual files in: {source_dir} using underscore naming convention.")
    new_state_dict = OrderedDict()
    missed_params = []
    missed_buffers = []

    loaded_count = _collect_tensors(
        model_to_populate.named_parameters(), source_dir, "param",
        new_state_dict, missed_params,
    )
    loaded_count += _collect_tensors(
        model_to_populate.named_buffers(), source_dir, "buffer",
        new_state_dict, missed_buffers,
    )

    if loaded_count == 0:
        print("ERROR: No weights were found or loaded from individual files. Aborting.")
        sys.exit(1)

    print(f"Attempting to load {loaded_count} found tensors into model state_dict.")
    result = model_to_populate.load_state_dict(new_state_dict, strict=False)
    print("State_dict loading result:")
    if result.missing_keys:
        print(f" Strict load missing_keys ({len(result.missing_keys)}): {result.missing_keys[:20]}...")  # Print first 20
    if result.unexpected_keys:
        print(f" Strict load unexpected_keys ({len(result.unexpected_keys)}): {result.unexpected_keys[:20]}...")

    # Cross check with our own missed lists
    print(f"Manually tracked missed parameters ({len(missed_params)}): {missed_params[:20]}...")
    print(f"Manually tracked missed buffers ({len(missed_buffers)}): {missed_buffers[:20]}...")

    # Check if all expected params/buffers in the model were loaded
    all_model_keys = set(model_to_populate.state_dict().keys())
    loaded_keys = set(new_state_dict.keys())
    truly_missing_from_model = all_model_keys - loaded_keys
    if truly_missing_from_model:
        print(f"CRITICAL: Keys in model NOT found in source_dir ({len(truly_missing_from_model)}): {list(truly_missing_from_model)[:20]}...")

    if not truly_missing_from_model and not result.unexpected_keys:
        print("Successfully loaded weights from individual files into the model.")
    else:
        print("WARNING: Some weights might be missing or unexpected after loading from individual files.")
|
|
|
|
def export_jit_wrapped_tensors(model, output_dir):
    """Write every parameter and buffer of ``model`` as its own TorchScript file.

    Each tensor is cloned, detached, moved to the CPU, wrapped in a
    ``TensorContainer`` under the attribute name ``tensor``, scripted with
    ``torch.jit.script`` and saved to ``<output_dir>/<dotted name>.pt``.
    ``output_dir`` is created when it does not exist yet.

    Args:
        model: Module whose named parameters and buffers are exported.
        output_dir: Destination directory for the ``.pt`` files.
    """
    tensor_key = "tensor"  # The key used in TensorContainer and for C++ loading

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        print(f"Created output directory: {output_dir}")

    # Parameters are exported first, then buffers; both follow the same steps
    # and differ only in the label used in the progress message.
    groups = (
        ("parameter", model.named_parameters()),
        ("buffer", model.named_buffers()),
    )
    for label, named_tensors in groups:
        for name, tensor in named_tensors:
            filename = name + '.pt'
            filepath = os.path.join(output_dir, filename)
            print(f"Exporting JIT-wrapped {label}: {name} (as {filename}) to {filepath} with shape {tensor.data.shape}")
            cpu_copy = tensor.data.clone().detach().cpu()
            scripted_container = torch.jit.script(TensorContainer(cpu_copy, tensor_key))
            torch.jit.save(scripted_container, filepath)

    print(f"All params/buffers exported as JIT-wrapped tensors to {output_dir} (using dot naming, key '{tensor_key}').")
|
|
|
|
# Script entry point: parse the two directory arguments, build a fresh
# ResNet-50, fill it from the per-tensor source files, and re-export every
# tensor as a JIT-wrapped file.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Load ResNet-50 weights from a directory of individual underscore_named .pt files, then re-export them as JIT-wrapped (TensorContainer) dot_named .pt files for C++ loading.")
    parser.add_argument('--source_individual_weights_dir', type=str, required=True,
                        help="Directory containing the source underscore_named .pt files (e.g., 'exported_weights/backbone/').")
    parser.add_argument('--output_jit_wrapped_tensors_dir', type=str, required=True,
                        help="Directory to save the re-exported JIT-wrapped dot_named .pt files (e.g., 'exported_weights/raw_backbone/').")

    args = parser.parse_args()

    # The module-level import of resnet50 only warns on failure, so the name
    # may be undefined here. Report that clearly instead of crashing with a
    # NameError. The check runs after parse_args so `--help` keeps working.
    if "resnet50" not in globals():
        print("Error: resnet50 is unavailable because importing ltr.models.backbone.resnet failed. Aborting.")
        sys.exit(1)

    print("Instantiating a new ResNet-50 model (will be populated from source dir)...")
    model = resnet50(output_layers=['layer4'], pretrained=False)
    print("ResNet-50 model instantiated.")

    load_weights_from_individual_files(model, args.source_individual_weights_dir)

    export_jit_wrapped_tensors(model, args.output_jit_wrapped_tensors_dir)

    print("Process complete. Weights loaded from source (underscore_named) and re-exported as JIT-wrapped tensors (dot_named).")
|