"""Wrappers that expose tracking networks through a common interface.

``NetWrapper`` / ``NetWithBackbone`` load a pickled checkpoint that carries its
own constructor, while ``DiMPTorchScriptWrapper`` assembles the backbone,
classifier and bounding-box regressor from TorchScript archives or from
directories of individually exported tensors.
"""

import os
from pathlib import Path

import torch

try:
    from icecream import ic
except ImportError:
    # icecream is a debugging aid only; the module must stay importable
    # without it, so provide a stand-in that passes its arguments through.
    def ic(*args):
        """No-op replacement used when the optional ``icecream`` is missing."""
        if not args:
            return None
        return args[0] if len(args) == 1 else args


# Checkpoint used when ``net_path`` does not name an existing file.
# NOTE: the path is relative to the current working directory. Resolution via
# ``pytracking.evaluation.environment.env_settings().network_path`` is
# intentionally not used here.
_DEFAULT_CHECKPOINT_PATH = "pytracking/networks/dimp50.pth"


def _resolve_checkpoint_path(net_path):
    """Return the checkpoint file to load for ``net_path``.

    ``net_path`` is honoured when it points at an existing file (``~`` is
    expanded); otherwise the default checkpoint path is returned.
    """
    if isinstance(net_path, (str, os.PathLike)):
        candidate = os.path.expanduser(str(Path(net_path)))
        if os.path.isfile(candidate):
            return candidate
    return os.path.expanduser(str(Path(_DEFAULT_CHECKPOINT_PATH)))


class NetWrapper:
    """Used for wrapping networks in pytracking.

    Network modules and functions can be accessed directly as if they were
    members of this class: attribute lookups that fail on the wrapper are
    forwarded to ``self.net``.
    """

    # Re-entrancy guard for ``__getattr__``; see that method.
    _rec_iter = 0

    def __init__(self, net_path, initialize=False, **kwargs):
        """Store the configuration and optionally load the network.

        Args:
            net_path: Path of the checkpoint to load.
            initialize: When true, the network is loaded immediately.
            **kwargs: Extra keyword arguments forwarded to the network
                constructor stored in the checkpoint.
        """
        self.net_path = net_path
        self.net = None
        self.net_kwargs = kwargs
        if initialize:
            self.initialize()

    def __getattr__(self, name):
        """Forward unknown attribute lookups to the wrapped network.

        ``__getattr__`` only runs when normal lookup fails. If ``self.net``
        itself is missing (e.g. on an instance created without ``__init__``),
        evaluating ``self.net`` below re-enters this method; the guard then
        returns ``None`` instead of recursing forever.

        Raises:
            AttributeError: If the wrapped network has no such attribute.
        """
        if self._rec_iter > 0:
            self._rec_iter = 0
            return None
        self._rec_iter += 1
        try:
            return getattr(self.net, name)
        finally:
            # Always re-arm the guard, whether the lookup succeeded or raised.
            self._rec_iter = 0

    def load_network(self):
        """Loads a network based on the given path and additional arguments.

        The checkpoint is expected to be a dict with a ``'constructor'`` entry
        (exposing ``kwds`` and ``get()``) and a ``'net'`` state dict. After
        loading, the network is moved to the GPU and put in eval mode.
        """
        print(f"Loading network from: {self.net_path}")

        # Work on a copy so the stored keyword arguments are not modified.
        kwargs = {**self.net_kwargs, 'backbone_pretrained': False}

        checkpoint_path = _resolve_checkpoint_path(self.net_path)
        print("LOADING FROM", checkpoint_path)

        # SECURITY: the checkpoint is unpickled (it contains a constructor
        # object), so only load files from a trusted source.
        # NOTE(review): recent PyTorch releases default ``torch.load`` to
        # ``weights_only=True``, which may reject such checkpoints - confirm
        # against the PyTorch version in use.
        checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
        net_constr = checkpoint_dict['constructor']

        # Override constructor arguments with the caller-supplied ones.
        for arg, val in kwargs.items():
            net_constr.kwds[arg] = val

        self.net = net_constr.get()
        self.net.load_state_dict(checkpoint_dict['net'])
        self.net.constructor = checkpoint_dict['constructor']

        # Both calls are forwarded to the wrapped network via __getattr__.
        self.cuda()
        self.eval()

    def initialize(self):
        """Initializes the network by loading it."""
        self.load_network()


class NetWithBackbone(NetWrapper):
    """Wraps a network with a common backbone.

    Assumes the network has an 'extract_backbone_features(image)' function.
    """

    def __init__(self, net_path, initialize=False, image_format='rgb',
                 mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),
                 **kwargs):
        """Configure image normalization on top of ``NetWrapper``.

        Args:
            net_path: Path of the checkpoint to load.
            initialize: When true, the network is loaded immediately.
            image_format: One of ``'rgb'``, ``'bgr'``, ``'rgb255'``,
                ``'bgr255'``; see ``preprocess_image``.
            mean: Per-channel mean subtracted during preprocessing.
            std: Per-channel standard deviation used as divisor.
            **kwargs: Forwarded to ``NetWrapper``.
        """
        super().__init__(net_path, initialize, **kwargs)
        self.image_format = image_format
        # Shaped (1, C, 1, 1) so they broadcast over batch and spatial dims.
        self._mean = torch.tensor(mean).view(1, -1, 1, 1)
        self._std = torch.tensor(std).view(1, -1, 1, 1)

    def preprocess_image(self, im: torch.Tensor):
        """Normalize the image with the mean and std used by the network.

        For ``'rgb'``/``'bgr'`` the image is first divided by 255; for
        ``'bgr'``/``'bgr255'`` the channel order is reversed. The input tensor
        is never modified in place. The result is moved to the GPU.
        """
        if self.image_format in ('rgb', 'bgr'):
            im = im / 255

        if self.image_format in ('bgr', 'bgr255'):
            im = im[:, [2, 1, 0], :, :]

        # Out-of-place arithmetic: the in-place form would overwrite the
        # caller's tensor for the 'rgb255' format, where no copy was made.
        mean = self._mean.to(im.device)
        std = self._std.to(im.device)
        im = (im - mean) / std

        return im.cuda()

    def extract_backbone(self, im: torch.Tensor):
        """Extract backbone features from the network.

        Expects a float tensor image with pixel range [0, 255].
        """
        im = self.preprocess_image(im)
        return self.net.extract_backbone_features(im)


class DiMPTorchScriptWrapper:
    """Wraps DiMP components loaded from TorchScript modules or weights.

    Each component (backbone, classifier, bb_regressor) is taken from a
    TorchScript archive when a ``*_ts`` file name is given; otherwise, when the
    matching ``*_sd`` argument is truthy, the module is built in Python and its
    weights are read from individual tensor files under ``model_dir``.
    """

    # Normalization statistics applied by ``preprocess_image``.
    _MEAN = (0.485, 0.456, 0.406)
    _STD = (0.229, 0.224, 0.225)

    def __init__(self, model_dir, device='cuda', backbone_ts=None,
                 backbone_sd=None, classifier_ts=None, classifier_sd=None,
                 bbregressor_ts=None, bbregressor_sd=None):
        """Load all three components, set them to eval mode, move to device.

        Args:
            model_dir: Directory holding the TorchScript files and/or the
                ``backbone``, ``classifier`` and ``bb_regressor`` tensor
                directories.
            device: Device the components and loaded tensors are placed on.
            backbone_ts, classifier_ts, bbregressor_ts: TorchScript file names
                relative to ``model_dir``.
            backbone_sd, classifier_sd, bbregressor_sd: Truthy flags selecting
                the build-and-load-tensors path; the value itself is not used
                as a path.

        Raises:
            RuntimeError: If neither option is given for a component.
        """
        self.device = device
        self.model_dir = model_dir

        self.backbone = self._build_backbone(backbone_ts, backbone_sd)
        self.backbone.eval()
        self.backbone.to(device)

        self.classifier = self._build_classifier(classifier_ts, classifier_sd)
        self.classifier.eval()
        self.classifier.to(device)

        self.bb_regressor = self._build_bb_regressor(bbregressor_ts,
                                                     bbregressor_sd)
        self.bb_regressor.eval()
        self.bb_regressor.to(device)

    def _load_torchscript(self, file_name):
        """Load a TorchScript archive located in ``model_dir``."""
        return torch.jit.load(os.path.join(self.model_dir, file_name),
                              map_location=self.device)

    def _build_backbone(self, backbone_ts, backbone_sd):
        """Return the backbone from TorchScript or from exported tensors."""
        if backbone_ts:
            return self._load_torchscript(backbone_ts)
        if not backbone_sd:
            raise RuntimeError('No backbone provided')

        from ltr.models.backbone import resnet50

        backbone = resnet50(output_layers=['layer2', 'layer3'],
                            pretrained=False)
        # Weights come from individual tensor files, not a state_dict.
        self.load_weights_from_tensors(
            backbone, os.path.join(self.model_dir, 'backbone'))
        return backbone

    def _build_classifier(self, classifier_ts, classifier_sd):
        """Return the classifier from TorchScript or from exported tensors."""
        if classifier_ts:
            return self._load_torchscript(classifier_ts)
        if not classifier_sd:
            raise RuntimeError('No classifier provided')

        from ltr.models.target_classifier.linear_filter import LinearFilter
        from ltr.models.target_classifier.features import residual_bottleneck
        from ltr.models.target_classifier.initializer import (
            FilterInitializerLinear,
        )
        from ltr.models.target_classifier.optimizer import (
            DiMPSteepestDescentGN,
        )

        # Use parameters from exported info files.
        feat_extractor = residual_bottleneck(
            feature_dim=256,
            num_blocks=0,
            l2norm=True,
            final_conv=True,
            norm_scale=0.011048543456039804,
            out_dim=512,
        )
        filter_initializer = FilterInitializerLinear(
            filter_size=4,
            filter_norm=False,
            feature_dim=512,
        )
        filter_optimizer = DiMPSteepestDescentGN(
            num_iter=5,
            feat_stride=16,
            init_step_length=0.9,
            init_filter_reg=0.1,
            init_gauss_sigma=0.9,
            num_dist_bins=100,
            bin_displacement=0.1,
            mask_init_factor=3.0,
        )
        classifier = LinearFilter(
            filter_size=4,
            filter_initializer=filter_initializer,
            filter_optimizer=filter_optimizer,
            feature_extractor=feat_extractor,
        )
        # Weights come from individual tensor files, not a state_dict.
        self.load_weights_from_tensors(
            classifier, os.path.join(self.model_dir, 'classifier'))
        return classifier

    def _build_bb_regressor(self, bbregressor_ts, bbregressor_sd):
        """Return the bb_regressor from TorchScript or exported tensors."""
        if bbregressor_ts:
            return self._load_torchscript(bbregressor_ts)
        if not bbregressor_sd:
            raise RuntimeError('No bb_regressor provided')

        from ltr.models.bbreg.atom_iou_net import AtomIoUNet

        bb_regressor = AtomIoUNet(
            input_dim=(512, 1024),
            pred_input_dim=(256, 256),
            pred_inter_dim=(256, 256),
        )
        # Weights come from individual tensor files, not a state_dict.
        self.load_weights_from_tensors(
            bb_regressor, os.path.join(self.model_dir, 'bb_regressor'))
        return bb_regressor

    @staticmethod
    def _iter_documented_tensors(lines):
        """Yield ``(key, file_name)`` for every ``## key`` section.

        ``file_name`` is taken from the first ``File:`` line inside the
        section and is ``None`` when the section has none. The search never
        crosses into the next section.
        """
        key = None
        file_name = None
        for line in lines:
            if line.startswith('## '):
                if key is not None:
                    yield key, file_name
                key = line.strip()[3:]
                file_name = None
            elif key is not None and file_name is None and 'File:' in line:
                file_name = line.split('File:')[1].strip()
        if key is not None:
            yield key, file_name

    @staticmethod
    def _find_owner(model, key):
        """Return ``(module, leaf_name)`` owning ``key``, or ``(None, leaf)``.

        ``key`` is a dotted path such as ``layer1.0.conv1.weight``; all parts
        but the last are resolved with ``getattr`` starting from ``model``.
        """
        *parents, leaf = key.split('.')
        module = model
        for part in parents:
            module = getattr(module, part, None)
            if module is None:
                return None, leaf
        return module, leaf

    def load_weights_from_tensors(self, model, tensor_dir):
        """Load weights from individual tensor files into ``model``.

        ``tensor_dir`` should be the directory containing the ``.pt`` files
        and the documentation file ``<dirname>_weights_doc.txt``. That file is
        scanned for ``## <dotted.key>`` headers, each followed by a
        ``File: <name>`` line naming the tensor file for that key.

        Problems with individual entries (no ``File:`` line, missing tensor
        file, key not present in the model) are reported with a printed
        warning and the entry is skipped.

        Raises:
            FileNotFoundError: If the documentation file does not exist.
        """
        # normpath drops a trailing separator, which would otherwise make
        # basename() return an empty string.
        dir_name = os.path.basename(os.path.normpath(tensor_dir))
        doc_file = os.path.join(tensor_dir, dir_name + '_weights_doc.txt')
        if not os.path.exists(doc_file):
            raise FileNotFoundError(
                f"Documentation file not found: {doc_file}")

        with open(doc_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        for key, file_name in self._iter_documented_tensors(lines):
            if file_name is None:
                print(f"Warning: no 'File:' entry documented for {key}.")
                continue

            tensor_path = os.path.join(tensor_dir, file_name)
            if not os.path.exists(tensor_path):
                print(f"Warning: tensor file not found for {key}: "
                      f"{tensor_path}")
                continue

            tensor = torch.load(tensor_path, map_location=self.device)

            owner, leaf = self._find_owner(model, key)
            if owner is not None and leaf in owner._parameters:
                # Keep the entry a real nn.Parameter so parameters(),
                # state_dict() and .to() keep treating it as one.
                if not isinstance(tensor, torch.nn.Parameter):
                    tensor = torch.nn.Parameter(
                        tensor, requires_grad=tensor.requires_grad)
                owner._parameters[leaf] = tensor
            elif owner is not None and leaf in owner._buffers:
                owner._buffers[leaf] = tensor
            else:
                print(f"Warning: {key} not found in model parameters "
                      f"or buffers.")

    def preprocess_image(self, im: torch.Tensor):
        """Move ``im`` to the device, scale by 1/255 and normalize.

        Normalization uses the fixed per-channel ``_MEAN`` and ``_STD``; the
        input tensor is not modified.
        """
        im = im.to(self.device)
        mean = torch.tensor(self._MEAN, device=self.device).view(1, -1, 1, 1)
        std = torch.tensor(self._STD, device=self.device).view(1, -1, 1, 1)
        return (im / 255 - mean) / std

    def extract_backbone(self, im: torch.Tensor):
        """Preprocess ``im`` and run the backbone without gradients."""
        im = self.preprocess_image(im)
        with torch.no_grad():
            return self.backbone(im, output_layers=['layer2', 'layer3'])

    def extract_classification_feat(self, backbone_feat):
        """Compute classification features from backbone features.

        If ``backbone_feat`` is a dict, its ``'layer3'`` entry is used.
        """
        with torch.no_grad():
            if isinstance(backbone_feat, dict):
                backbone_feat = backbone_feat['layer3']
            return self.classifier.extract_classification_feat(backbone_feat)

    def get_backbone_bbreg_feat(self, backbone_feat):
        """Return ``[layer2, layer3]`` from the ``backbone_feat`` mapping."""
        return [backbone_feat['layer2'], backbone_feat['layer3']]

    def initialize(self):
        """No-op kept for interface compatibility; returns ``self``."""
        return self

    # Add any other methods needed to match the old interface