import inspect
import os
from pathlib import Path

import torch

try:
    from icecream import ic
except ImportError:  # icecream is only a debugging aid; don't make it a hard dependency.
    def ic(*args):
        """Fallback for ``icecream.ic``: print the arguments and return them unchanged."""
        print(*args)
        if not args:
            return None
        return args[0] if len(args) == 1 else args


class NetWrapper:
    """Used for wrapping networks in pytracking.

    Network modules and functions can be accessed directly as if they were
    members of this class: attribute lookups that fail on the wrapper are
    forwarded to ``self.net``.
    """

    # Re-entrancy guard for ``__getattr__``. It is non-zero while a forwarded
    # lookup is in progress, so that resolving ``self.net`` before it has been
    # set (e.g. on an instance whose ``__init__`` never ran) yields None
    # instead of recursing forever.
    _rec_iter = 0

    # Checkpoint loaded when ``net_path`` does not name an existing file.
    # NOTE(review): this preserves the previously hard-coded location; the
    # pytracking ``env_settings().network_path`` lookup is not used here.
    default_checkpoint_path = "tracker/pytracking/networks/dimp50.pth"

    def __init__(self, net_path, initialize=False, **kwargs):
        """Store the loading configuration and optionally load the network.

        Args:
            net_path: Path of the checkpoint to load. It is used only if it
                names an existing file (``~`` is expanded); otherwise
                ``default_checkpoint_path`` is loaded.
            initialize: If True, load the network immediately.
            **kwargs: Extra keyword arguments written into the checkpoint
                constructor's ``kwds`` before the network is built.
        """
        self.net_path = net_path
        self.net = None
        self.net_kwargs = kwargs
        if initialize:
            self.initialize()

    def __getattr__(self, name):
        """Forward attribute lookups that fail on the wrapper to ``self.net``.

        Raises:
            AttributeError: if the wrapped network (or None, when no network
                is loaded) has no attribute ``name``.
        """
        if self._rec_iter > 0:
            # Re-entered while resolving ``self.net`` itself (it is not set).
            self._rec_iter = 0
            return None
        self._rec_iter += 1
        try:
            return getattr(self.net, name)
        finally:
            self._rec_iter = 0

    def _resolve_checkpoint_path(self):
        """Return the ``~``-expanded path of the checkpoint file to load."""
        candidate = Path(os.path.expanduser(str(self.net_path)))
        if candidate.is_file():
            return str(candidate)
        return os.path.expanduser(str(Path(self.default_checkpoint_path)))

    def load_network(self):
        """Loads a network based on the given path and additional arguments.

        Builds the network from the checkpoint's pickled ``constructor``, loads
        the ``net`` state dict into it, then moves it to CUDA and eval mode.
        """
        print(f"Loading network from: {self.net_path}")
        # Work on a copy so the stored kwargs are not mutated by the override.
        kwargs = dict(self.net_kwargs)
        kwargs['backbone_pretrained'] = False

        checkpoint_path = self._resolve_checkpoint_path()
        print("LOADING FROM", checkpoint_path)

        # The checkpoint holds a pickled constructor object, not just tensors.
        # PyTorch >= 2.6 defaults to weights_only=True, which rejects it, so
        # opt out explicitly wherever the argument exists (torch >= 1.13).
        # SECURITY: unpickling can execute arbitrary code; load trusted files only.
        load_kwargs = {'map_location': 'cpu'}
        if 'weights_only' in inspect.signature(torch.load).parameters:
            load_kwargs['weights_only'] = False
        checkpoint_dict = torch.load(checkpoint_path, **load_kwargs)
        net_constr = checkpoint_dict['constructor']

        # Override the stored constructor arguments with the caller's kwargs.
        for arg, val in kwargs.items():
            net_constr.kwds[arg] = val

        self.net = net_constr.get()
        self.net.load_state_dict(checkpoint_dict['net'])
        self.net.constructor = net_constr
        # Both calls are forwarded to self.net through __getattr__.
        self.cuda()
        self.eval()

    def initialize(self):
        """Initializes the network by loading it."""
        self.load_network()


class NetWithBackbone(NetWrapper):
    """Wraps a network with a common backbone.

    Assumes the network has an 'extract_backbone_features(image)' function.
    """

    def __init__(self, net_path, initialize=False, image_format='rgb',
                 mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), **kwargs):
        """
        Args:
            net_path, initialize, **kwargs: forwarded to ``NetWrapper``.
            image_format: One of 'rgb', 'bgr', 'rgb255', 'bgr255'. The plain
                formats are divided by 255; the 'bgr' ones have channels
                1-axis reordered to RGB.
            mean, std: Per-channel normalization statistics.
        """
        super().__init__(net_path, initialize, **kwargs)
        self.image_format = image_format
        # Shaped (1, C, 1, 1) to broadcast over a batch with channels on dim 1.
        self._mean = torch.Tensor(mean).view(1, -1, 1, 1)
        self._std = torch.Tensor(std).view(1, -1, 1, 1)

    def preprocess_image(self, im: torch.Tensor):
        """Normalize the image with the mean and standard deviation used by the network.

        The input tensor is never modified. Presumably ``im`` is (N, C, H, W);
        the 'bgr' formats index channels on dim 1. Returns a CUDA tensor.
        """
        if self.image_format in ['rgb', 'bgr']:
            im = im / 255
        if self.image_format in ['bgr', 'bgr255']:
            im = im[:, [2, 1, 0], :, :]
        # Match the input's device (and dtype, for float inputs) so CUDA and
        # half-precision images work; out-of-place so callers' data is untouched.
        mean = self._mean.to(im.device)
        std = self._std.to(im.device)
        if im.is_floating_point():
            mean = mean.to(im.dtype)
            std = std.to(im.dtype)
        im = (im - mean) / std
        return im.cuda()

    def extract_backbone(self, im: torch.Tensor):
        """Extract backbone features from the network.

        Expects a float tensor image with pixel range [0, 255]. Returns whatever
        the wrapped network's ``extract_backbone_features`` returns.
        """
        im = self.preprocess_image(im)
        return self.net.extract_backbone_features(im)