import os
from pathlib import Path

import torch


class NetWrapper:
    """Wraps networks in pytracking.

    Network modules and functions can be accessed directly as if they were
    members of this class."""
    _rec_iter = 0

    def __init__(self, net_path, initialize=False, **kwargs):
        self.net_path = net_path
        self.net = None
        self.net_kwargs = kwargs
        if initialize:
            self.initialize()

    def __getattr__(self, name):
        # Guard against infinite recursion when the wrapped net does not
        # provide the requested attribute.
        if self._rec_iter > 0:
            self._rec_iter = 0
            return None
        self._rec_iter += 1
        try:
            ret_val = getattr(self.net, name)
        except Exception as e:
            self._rec_iter = 0
            raise e
        self._rec_iter = 0
        return ret_val

    def load_network(self):
        """Loads a network based on the given path and additional arguments."""
        print(f"Loading network from: {self.net_path}")
        kwargs = self.net_kwargs
        kwargs['backbone_pretrained'] = False

        # Construct the full network path. The environment-based lookup is
        # kept below for reference:
        # from pytracking.evaluation.environment import env_settings
        # path_full = os.path.join(env_settings().network_path, self.net_path)
        path_full = "pytracking/networks/dimp50.pth"

        checkpoint_path = str(Path(path_full))
        print("Loading from:", os.path.expanduser(checkpoint_path))

        # The checkpoint stores a pickled constructor object alongside the
        # weights.
        checkpoint_dict = torch.load(os.path.expanduser(checkpoint_path), map_location='cpu')
        net_constr = checkpoint_dict['constructor']

        # Update the constructor with additional keyword arguments.
        for arg, val in kwargs.items():
            net_constr.kwds[arg] = val

        # Build the network and load its weights.
        self.net = net_constr.get()
        self.net.load_state_dict(checkpoint_dict['net'])
        self.net.constructor = checkpoint_dict['constructor']

        self.cuda()
        self.eval()

    def initialize(self):
        """Initializes the network by loading it."""
        self.load_network()
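
# A minimal usage sketch of NetWrapper's attribute delegation (hedged: the
# checkpoint name and the wrapped module's methods are assumptions, not
# guaranteed by this module):
#
#   net = NetWrapper('dimp50.pth', initialize=True)
#   net.eval()                               # forwarded to the wrapped module
#   out = net.extract_backbone_features(im)  # resolves to net.net.<attr>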


class NetWithBackbone(NetWrapper):
    """Wraps a network with a common backbone.

    Assumes the network has an 'extract_backbone_features(image)' function."""

    def __init__(self, net_path, initialize=False, image_format='rgb',
                 mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), **kwargs):
        super().__init__(net_path, initialize, **kwargs)
        self.image_format = image_format
        self._mean = torch.Tensor(mean).view(1, -1, 1, 1)
        self._std = torch.Tensor(std).view(1, -1, 1, 1)

    def preprocess_image(self, im: torch.Tensor):
        """Normalize the image with the mean and standard deviation used by the network."""
        # Scale to [0, 1] unless the format indicates the network expects a
        # [0, 255] range.
        if self.image_format in ['rgb', 'bgr']:
            im = im / 255

        # Reorder channels for BGR formats.
        if self.image_format in ['bgr', 'bgr255']:
            im = im[:, [2, 1, 0], :, :]

        im -= self._mean
        im /= self._std
        im = im.cuda()
        return im

    def extract_backbone(self, im: torch.Tensor):
        """Extract backbone features from the network.

        Expects a float tensor image with pixel range [0, 255]."""
        im = self.preprocess_image(im)
        return self.net.extract_backbone_features(im)
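
# A minimal usage sketch (hedged: the checkpoint name and input shape are
# assumptions):
#
#   net = NetWithBackbone('dimp50.pth', initialize=True, image_format='rgb')
#   im = torch.rand(1, 3, 288, 288) * 255   # float tensor with range [0, 255]
#   feat = net.extract_backbone(im)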


class DiMPTorchScriptWrapper:
    """Wraps DiMP components loaded from TorchScript modules or raw weight tensors."""

    def __init__(self, model_dir, device='cuda',
                 backbone_ts=None, backbone_sd=None,
                 classifier_ts=None, classifier_sd=None,
                 bbregressor_ts=None, bbregressor_sd=None):
        self.device = device
        self.model_dir = model_dir

        # Load backbone
        if backbone_ts:
            self.backbone = torch.jit.load(os.path.join(model_dir, backbone_ts), map_location=device)
        elif backbone_sd:
            from ltr.models.backbone import resnet50
            self.backbone = resnet50(output_layers=['layer2', 'layer3'], pretrained=False)
            # Load weights from individual tensor files instead of a state_dict.
            self.load_weights_from_tensors(self.backbone, os.path.join(model_dir, 'backbone'))
        else:
            raise RuntimeError('No backbone provided')
        self.backbone.eval()
        self.backbone.to(device)

        # Load classifier
        if classifier_ts:
            self.classifier = torch.jit.load(os.path.join(model_dir, classifier_ts), map_location=device)
        elif classifier_sd:
            from ltr.models.target_classifier.linear_filter import LinearFilter
            from ltr.models.target_classifier.features import residual_bottleneck
            from ltr.models.target_classifier.initializer import FilterInitializerLinear
            from ltr.models.target_classifier.optimizer import DiMPSteepestDescentGN

            # Parameters below are taken from the exported info files.
            feat_extractor = residual_bottleneck(
                feature_dim=256,
                num_blocks=0,
                l2norm=True,
                final_conv=True,
                norm_scale=0.011048543456039804,
                out_dim=512
            )
            filter_initializer = FilterInitializerLinear(
                filter_size=4,
                filter_norm=False,
                feature_dim=512
            )
            filter_optimizer = DiMPSteepestDescentGN(
                num_iter=5,
                feat_stride=16,
                init_step_length=0.9,
                init_filter_reg=0.1,
                init_gauss_sigma=0.9,
                num_dist_bins=100,
                bin_displacement=0.1,
                mask_init_factor=3.0
            )
            self.classifier = LinearFilter(
                filter_size=4,
                filter_initializer=filter_initializer,
                filter_optimizer=filter_optimizer,
                feature_extractor=feat_extractor
            )
            # Load weights from individual tensor files instead of a state_dict.
            self.load_weights_from_tensors(self.classifier, os.path.join(model_dir, 'classifier'))
        else:
            raise RuntimeError('No classifier provided')
        self.classifier.eval()
        self.classifier.to(device)

        # Load bb_regressor
        if bbregressor_ts:
            self.bb_regressor = torch.jit.load(os.path.join(model_dir, bbregressor_ts), map_location=device)
        elif bbregressor_sd:
            from ltr.models.bbreg.atom_iou_net import AtomIoUNet
            self.bb_regressor = AtomIoUNet(
                input_dim=(512, 1024),
                pred_input_dim=(256, 256),
                pred_inter_dim=(256, 256)
            )
            # Load weights from individual tensor files instead of a state_dict.
            self.load_weights_from_tensors(self.bb_regressor, os.path.join(model_dir, 'bb_regressor'))
        else:
            raise RuntimeError('No bb_regressor provided')
        self.bb_regressor.eval()
        self.bb_regressor.to(device)

    def load_weights_from_tensors(self, model, tensor_dir):
        """Load weights from individual tensor files and assign them to the model.

        tensor_dir should be the directory containing the .pt files and the
        documentation file that maps parameter names to tensor files.
        """
        doc_file = os.path.join(tensor_dir, os.path.basename(tensor_dir) + '_weights_doc.txt')
        if not os.path.exists(doc_file):
            raise FileNotFoundError(f"Documentation file not found: {doc_file}")

        with open(doc_file, 'r') as f:
            lines = f.readlines()

        i = 0
        while i < len(lines):
            line = lines[i]
            if line.startswith('## '):
                key = line.strip()[3:]
                # Look ahead for the 'File:' line.
                j = i + 1
                while j < len(lines) and 'File:' not in lines[j]:
                    j += 1
                if j < len(lines) and 'File:' in lines[j]:
                    file_name = lines[j].split('File:')[1].strip()
                    tensor_path = os.path.join(tensor_dir, file_name)
                    if os.path.exists(tensor_path):
                        tensor = torch.load(tensor_path, map_location=self.device)
                        # Walk the module tree to the parent of the target
                        # parameter or buffer.
                        parts = key.split('.')
                        module = model
                        for part in parts[:-1]:
                            module = getattr(module, part)
                        if parts[-1] in module._parameters:
                            # Wrap in nn.Parameter so the module still exposes
                            # it through .parameters(); these models are used
                            # for inference only, so no gradient is needed.
                            module._parameters[parts[-1]] = torch.nn.Parameter(tensor, requires_grad=False)
                        elif parts[-1] in module._buffers:
                            module._buffers[parts[-1]] = tensor
                        else:
                            print(f"Warning: {key} not found in model parameters or buffers.")
                i = j
            i += 1
    def preprocess_image(self, im: torch.Tensor):
        """Normalize with the same ImageNet mean/std used by NetWithBackbone."""
        im = im.to(self.device)
        mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1).to(self.device)
        std = torch.Tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1).to(self.device)
        im = im / 255
        im -= mean
        im /= std
        return im

    def extract_backbone(self, im: torch.Tensor):
        im = self.preprocess_image(im)
        with torch.no_grad():
            return self.backbone(im, output_layers=['layer2', 'layer3'])

    def extract_classification_feat(self, backbone_feat):
        with torch.no_grad():
            # If backbone_feat is a dict, extract the layer the classifier expects.
            if isinstance(backbone_feat, dict):
                backbone_feat = backbone_feat['layer3']
            return self.classifier.extract_classification_feat(backbone_feat)

    def get_backbone_bbreg_feat(self, backbone_feat):
        # Assumes backbone_feat is a dict with 'layer2' and 'layer3'.
        return [backbone_feat['layer2'], backbone_feat['layer3']]

    def initialize(self):
        # No-op kept for compatibility with the old interface.
        return self

    # Add any other methods needed to match the old interface.
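

# A minimal end-to-end sketch (hedged: the model directory layout and shapes
# are assumptions; adjust them to the exported artifacts and hardware):
if __name__ == '__main__':
    net = DiMPTorchScriptWrapper(
        model_dir='exported_weights',  # hypothetical directory
        device='cpu',
        backbone_sd=True, classifier_sd=True, bbregressor_sd=True)
    im = torch.rand(1, 3, 288, 288) * 255  # float image with range [0, 255]
    feat = net.extract_backbone(im)
    clf_feat = net.extract_classification_feat(feat)
    bbreg_feat = net.get_backbone_bbreg_feat(feat)
    print(clf_feat.shape, [f.shape for f in bbreg_feat])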