You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

105 lines
3.4 KiB

import torch
import os
from pathlib import Path
from icecream import ic
class NetWrapper:
    """Used for wrapping networks in pytracking.

    Network modules and functions can be accessed directly as if they were
    members of this class: attribute lookups that fail on the wrapper are
    forwarded to ``self.net``.
    """

    # Checkpoint loaded when ``net_path`` does not resolve to an existing file.
    _DEFAULT_CHECKPOINT = "tracker/pytracking/networks/dimp50.pth"
    # Recursion guard for __getattr__. It is a class-level default so that
    # reading it never triggers __getattr__ itself.
    _rec_iter = 0

    def __init__(self, net_path, initialize=False, **kwargs):
        """
        Args:
            net_path: Checkpoint file to load. It may be a path, or a file name
                inside the directory of ``_DEFAULT_CHECKPOINT``.
            initialize: If True, load the network immediately.
            **kwargs: Extra keyword arguments written into the checkpoint
                constructor's ``kwds`` before the network is built.
        """
        self.net_path = net_path
        self.net = None
        self.net_kwargs = kwargs
        if initialize:
            self.initialize()

    def __getattr__(self, name):
        """Forward attribute lookups that fail on the wrapper to ``self.net``.

        Python only calls this when normal lookup fails. If ``self.net`` is
        itself missing (e.g. an instance created via ``__new__`` without
        ``__init__``), evaluating ``self.net`` re-enters this method. The guard
        answers that nested lookup with ``None``, so the outer ``getattr``
        raises ``AttributeError`` instead of recursing forever.
        """
        if self._rec_iter > 0:
            self._rec_iter = 0
            return None
        self._rec_iter += 1
        try:
            return getattr(self.net, name)
        finally:
            # Reset on success and on failure alike.
            self._rec_iter = 0

    def _resolve_checkpoint_path(self):
        """Return the checkpoint file that ``load_network`` should read.

        ``net_path`` is tried first as given (with ``~`` expanded), then
        relative to the default network directory. If neither is an existing
        file, ``_DEFAULT_CHECKPOINT`` is returned.
        """
        if self.net_path:
            requested = os.path.expanduser(str(self.net_path))
            network_dir = os.path.dirname(self._DEFAULT_CHECKPOINT)
            for candidate in (requested, os.path.join(network_dir, requested)):
                if os.path.isfile(candidate):
                    return candidate
            print(f"Checkpoint '{self.net_path}' not found; "
                  f"falling back to {self._DEFAULT_CHECKPOINT}")
        return os.path.expanduser(self._DEFAULT_CHECKPOINT)

    @staticmethod
    def _load_checkpoint(path):
        """Load a checkpoint dict from ``path`` onto the CPU.

        SECURITY: the checkpoint stores a pickled constructor object, so it
        cannot be loaded with ``weights_only=True`` (the ``torch.load`` default
        since PyTorch 2.6). Full unpickling can execute arbitrary code, so only
        load trusted checkpoint files.
        """
        import inspect  # local: only needed for the feature probe below

        load_kwargs = {'map_location': 'cpu'}
        # Older torch.load versions have no ``weights_only`` parameter.
        if 'weights_only' in inspect.signature(torch.load).parameters:
            load_kwargs['weights_only'] = False
        return torch.load(path, **load_kwargs)

    def load_network(self):
        """Loads a network based on the given path and additional arguments.

        The checkpoint dict must contain 'constructor' (an object exposing a
        ``kwds`` mapping and a ``get()`` method that builds the network) and
        'net' (the state dict). The built network is then moved with
        ``cuda()`` and put in ``eval()`` mode.
        """
        print(f"Loading network from: {self.net_path}")
        # Work on a copy so the kwargs passed to __init__ are not mutated.
        kwargs = dict(self.net_kwargs)
        kwargs['backbone_pretrained'] = False

        checkpoint_path = self._resolve_checkpoint_path()
        print("LOADING FROM", checkpoint_path)
        checkpoint_dict = self._load_checkpoint(checkpoint_path)

        net_constr = checkpoint_dict['constructor']
        # Caller-supplied kwargs override the constructor's stored arguments.
        for arg, val in kwargs.items():
            net_constr.kwds[arg] = val

        self.net = net_constr.get()
        self.net.load_state_dict(checkpoint_dict['net'])
        self.net.constructor = checkpoint_dict['constructor']
        # Both calls are forwarded to self.net through __getattr__.
        self.cuda()
        self.eval()

    def initialize(self):
        """Initializes the network by loading it."""
        self.load_network()
class NetWithBackbone(NetWrapper):
    """Wraps a network with a common backbone.

    Assumes the network has an 'extract_backbone_features(image)' function.

    Args:
        net_path: Checkpoint to load; forwarded to ``NetWrapper``.
        initialize: If True, load the network immediately.
        image_format: 'rgb' or 'bgr' inputs are divided by 255 before
            normalization; 'rgb255' or 'bgr255' inputs are not rescaled. The
            'bgr' variants have their channel order reversed.
        mean: Per-channel mean subtracted during preprocessing.
        std: Per-channel standard deviation divided out during preprocessing.
        **kwargs: Forwarded to ``NetWrapper``.
    """

    def __init__(self, net_path, initialize=False, image_format='rgb',
                 mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), **kwargs):
        super().__init__(net_path, initialize, **kwargs)
        self.image_format = image_format
        # Shaped (1, C, 1, 1) so they broadcast along dim 1 (channels) of the image batch.
        self._mean = torch.Tensor(mean).view(1, -1, 1, 1)
        self._std = torch.Tensor(std).view(1, -1, 1, 1)

    def preprocess_image(self, im: torch.Tensor):
        """Normalize the image with the mean and standard deviation used by the network.

        Args:
            im: 4-D image batch with channels on dim 1.

        Returns:
            A new normalized tensor moved to the GPU with ``.cuda()``. The
            caller's tensor is never modified.
        """
        if self.image_format in ['rgb', 'bgr']:
            im = im / 255
        if self.image_format in ['bgr', 'bgr255']:
            im = im[:, [2, 1, 0], :, :]
        # Out-of-place on purpose: for 'rgb255' no copy has been made yet, so
        # in-place ops would mutate the caller's tensor. The statistics follow
        # the input's device (so already-CUDA inputs work) and keep a floating
        # input's dtype (e.g. half precision).
        dtype = im.dtype if im.is_floating_point() else self._mean.dtype
        mean = self._mean.to(device=im.device, dtype=dtype)
        std = self._std.to(device=im.device, dtype=dtype)
        im = (im - mean) / std
        return im.cuda()

    def extract_backbone(self, im: torch.Tensor):
        """Extract backbone features from the network.

        Expects a float tensor image with pixel range [0, 255].
        """
        im = self.preprocess_image(im)
        return self.net.extract_backbone_features(im)