from typing import List
import math

import torch
import torch.nn.functional as F

from pytracking.libs import dcf
from pytracking.libs.tensorlist import TensorList
from pytracking.features.preprocessing import numpy_to_torch
from pytracking.features.preprocessing import sample_patch_multiscale, sample_patch_transformed
from pytracking.features import augmentation


class DiMP:
    def __init__(self, params):
        self.params = params

        # Initialize network
        self.params.net.initialize()

        # The DiMP network
        self.net = self.params.net

    def initialize(self, image=None, bbox: List = None):
        if bbox is None or image is None:
            return None

        # Initialize some stuff
        self.params.device = 'cuda'

        # Convert image
        im = numpy_to_torch(image)

        # Get target position and size from the [x, y, w, h] box
        self.pos = torch.Tensor([bbox[1] + (bbox[3] - 1) / 2, bbox[0] + (bbox[2] - 1) / 2])
        self.target_sz = torch.Tensor([bbox[3], bbox[2]])

        # Set sizes
        self.image_sz = torch.Tensor([im.shape[2], im.shape[3]])
        sz = self.params.image_sample_size
        sz = torch.Tensor([sz, sz] if isinstance(sz, int) else sz)
        self.img_sample_sz = sz
        self.img_support_sz = self.img_sample_sz

        # Set search area
        search_area = torch.prod(self.target_sz * self.params.search_area_scale).item()
        self.target_scale = math.sqrt(search_area) / self.img_sample_sz.prod().sqrt()

        # Target size in base scale
        self.base_target_sz = self.target_sz / self.target_scale

        # Setup scale factors
        self.params.scale_factors = torch.ones(1)

        # Extract and transform sample
        init_backbone_feat = self.generate_init_samples(im)

        # Initialize classifier
        self.init_classifier(init_backbone_feat)

        # Initialize IoUNet
        self.init_iou_net(init_backbone_feat)

    def track(self, image) -> dict:
        torch.cuda.empty_cache()

        # Convert image
        im = numpy_to_torch(image)

        # ------- LOCALIZATION ------- #

        # Extract backbone features
        backbone_feat, sample_coords = self.extract_backbone_features(
            im, self.get_centered_sample_pos(),
            self.target_scale * self.params.scale_factors, self.img_sample_sz)

        # Extract classification features
        test_x = self.get_classification_features(backbone_feat)

        # Location of sample
        sample_pos, sample_scales = self.get_sample_location(sample_coords)

        # Compute classification scores
        scores_raw = self.classify_target(test_x)

        # Localize the target
        translation_vec, scale_ind, s, flag = self.localize_target(scores_raw, sample_pos, sample_scales)
        new_pos = sample_pos[scale_ind, :] + translation_vec

        # Update position and scale
        if flag != 'not_found':
            update_scale_flag = self.params.get('update_scale_when_uncertain', True) or flag != 'uncertain'
            self.update_state(new_pos)
            self.refine_target_box(backbone_feat, sample_pos[scale_ind, :],
                                   sample_scales[scale_ind], scale_ind, update_scale_flag)

        # ------- UPDATE ------- #

        # Set the pos of the tracker to the IoUNet pos
        if self.params.get('use_iou_net', True) and flag != 'not_found' and hasattr(self, 'pos_iounet'):
            self.pos = self.pos_iounet.clone()

        # Visualize and set debug info
        self.search_area_box = torch.cat(
            (sample_coords[scale_ind, [1, 0]],
             sample_coords[scale_ind, [3, 2]] - sample_coords[scale_ind, [1, 0]] - 1))

        # Compute output bounding box
        new_state = torch.cat((self.pos[[1, 0]] - (self.target_sz[[1, 0]] - 1) / 2, self.target_sz[[1, 0]]))

        out = {'target_bbox': new_state.tolist(), 'success': flag != 'not_found'}
        return out

    def get_sample_location(self, sample_coord):
        """Get the location of the extracted sample."""
        sample_coord = sample_coord.float()
        sample_pos = 0.5 * (sample_coord[:, :2] + sample_coord[:, 2:] - 1)
        sample_scales = ((sample_coord[:, 2:] - sample_coord[:, :2]) / self.img_sample_sz).prod(dim=1).sqrt()
        return sample_pos, sample_scales

    def get_centered_sample_pos(self):
        """Get the center position for the new sample. Make sure the target is correctly centered."""
        return self.pos + ((self.feature_sz + self.kernel_size) % 2) * self.target_scale * \
               self.img_support_sz / (2 * self.feature_sz)

    def classify_target(self, sample_x: TensorList):
        """Classify target by applying the DiMP filter."""
        with torch.no_grad():
            scores = self.net.classifier.classify(self.target_filter, sample_x)
        return scores

    def localize_target(self, scores, sample_pos, sample_scales):
        """Run the target localization."""
        scores = scores.squeeze(1)
        return self.localize_advanced(scores, sample_pos, sample_scales)

    def localize_advanced(self, scores, sample_pos, sample_scales):
        """Run the advanced target localization (as in ATOM)."""
        sz = scores.shape[-2:]
        score_sz = torch.Tensor(list(sz))
        output_sz = score_sz - (self.kernel_size + 1) % 2
        score_center = (score_sz - 1) / 2

        scores_hn = scores

        max_score1, max_disp1 = dcf.max2d(scores)
        _, scale_ind = torch.max(max_score1, dim=0)
        sample_scale = sample_scales[scale_ind]
        max_score1 = max_score1[scale_ind]
        max_disp1 = max_disp1[scale_ind, ...].float().cpu().view(-1)
        target_disp1 = max_disp1 - score_center
        translation_vec1 = target_disp1 * (self.img_support_sz / output_sz) * sample_scale

        if max_score1.item() < self.params.target_not_found_threshold:
            return translation_vec1, scale_ind, scores_hn, 'not_found'

        # Mask out the target neighborhood (the masked map is not analyzed further in this version)
        target_neigh_sz = self.params.target_neighborhood_scale * (self.target_sz / sample_scale) * (
            output_sz / self.img_support_sz)
        tneigh_top = max(round(max_disp1[0].item() - target_neigh_sz[0].item() / 2), 0)
        tneigh_bottom = min(round(max_disp1[0].item() + target_neigh_sz[0].item() / 2 + 1), sz[0])
        tneigh_left = max(round(max_disp1[1].item() - target_neigh_sz[1].item() / 2), 0)
        tneigh_right = min(round(max_disp1[1].item() + target_neigh_sz[1].item() / 2 + 1), sz[1])
        scores_masked = scores_hn[scale_ind:scale_ind + 1, ...].clone()
        scores_masked[..., tneigh_top:tneigh_bottom, tneigh_left:tneigh_right] = 0

        return translation_vec1, scale_ind, scores_hn, 'normal'

    def extract_backbone_features(self, im: torch.Tensor, pos: torch.Tensor, scales, sz: torch.Tensor):
        im_patches, patch_coords = sample_patch_multiscale(
            im, pos, scales, sz,
            mode=self.params.get('border_mode', 'replicate'),
            max_scale_change=self.params.get('patch_max_scale_change', None))
        with torch.no_grad():
            backbone_feat = self.net.extract_backbone(im_patches)
        return backbone_feat, patch_coords

    def get_classification_features(self, backbone_feat):
        with torch.no_grad():
            return self.net.extract_classification_feat(backbone_feat)

    def get_iou_backbone_features(self, backbone_feat):
        return self.net.get_backbone_bbreg_feat(backbone_feat)

    def get_iou_features(self, backbone_feat):
        with torch.no_grad():
            return self.net.bb_regressor.get_iou_feat(self.get_iou_backbone_features(backbone_feat))

    def get_iou_modulation(self, iou_backbone_feat, target_boxes):
        with torch.no_grad():
            return self.net.bb_regressor.get_modulation(iou_backbone_feat, target_boxes)

    def generate_init_samples(self, im: torch.Tensor) -> TensorList:
        """Perform data augmentation to generate initial training samples."""
        self.init_sample_scale = self.target_scale
        global_shift = torch.zeros(2)
        self.init_sample_pos = self.pos.round()

        # Compute augmentation size (default to no expansion if the factor is not set)
        aug_expansion_factor = self.params.get('augmentation_expansion_factor', 1)
        aug_expansion_sz = (self.img_sample_sz * aug_expansion_factor).long()
        aug_expansion_sz += (aug_expansion_sz - self.img_sample_sz.long()) % 2
        aug_expansion_sz = aug_expansion_sz.float()
        aug_output_sz = self.img_sample_sz.long().tolist()

        random_shift_factor = self.params.get('random_shift_factor', 0)
        get_rand_shift = lambda: (
            (torch.rand(2) - 0.5) * self.img_sample_sz * random_shift_factor + global_shift).long().tolist()

        # Always put the identity transformation first, since it is the unaugmented sample that is always used
        self.transforms = [augmentation.Identity(aug_output_sz, global_shift.long().tolist())]

        augs = self.params.augmentation if self.params.get('use_augmentation', True) else {}

        # Add all augmentations
        get_absolute = lambda shift: (torch.Tensor(shift) * self.img_sample_sz / 2).long().tolist()
        self.transforms.extend(
            [augmentation.Translation(get_absolute(shift), aug_output_sz, global_shift.long().tolist())
             for shift in augs.get('relativeshift', [])])
        self.transforms.append(augmentation.FlipHorizontal(aug_output_sz, get_rand_shift()))
        self.transforms.extend(
            [augmentation.Blur(sigma, aug_output_sz, get_rand_shift()) for sigma in augs.get('blur', [])])
        self.transforms.extend(
            [augmentation.Rotate(angle, aug_output_sz, get_rand_shift()) for angle in augs.get('rotate', [])])

        # Extract augmented image patches
        im_patches = sample_patch_transformed(im, self.init_sample_pos, self.init_sample_scale,
                                              aug_expansion_sz, self.transforms)

        # Extract initial backbone features
        with torch.no_grad():
            init_backbone_feat = self.net.extract_backbone(im_patches)

        return init_backbone_feat

    def init_target_boxes(self):
        """Get the target bounding boxes for the initial augmented samples."""
        self.classifier_target_box = self.get_iounet_box(self.pos, self.target_sz,
                                                         self.init_sample_pos, self.init_sample_scale)
        init_target_boxes = TensorList()
        for T in self.transforms:
            init_target_boxes.append(self.classifier_target_box + torch.Tensor([T.shift[1], T.shift[0], 0, 0]))
        init_target_boxes = torch.cat(init_target_boxes.view(1, 4), 0).to(self.params.device)
        self.target_boxes = init_target_boxes.new_zeros(self.params.sample_memory_size, 4)
        self.target_boxes[:init_target_boxes.shape[0], :] = init_target_boxes
        return init_target_boxes

    def update_state(self, new_pos):
        # Update pos, keeping at least a fraction of the target inside the image
        inside_ratio = self.params.get('target_inside_ratio', 0.2)
        inside_offset = (inside_ratio - 0.5) * self.target_sz
        self.pos = torch.max(torch.min(new_pos, self.image_sz - inside_offset), inside_offset)

    def get_iounet_box(self, pos, sz, sample_pos, sample_scale):
        """All inputs in original image coordinates.
        Generates a box in the cropped image sample reference frame, in the format used by the IoUNet."""
        box_center = (pos - sample_pos) / sample_scale + (self.img_sample_sz - 1) / 2
        box_sz = sz / sample_scale
        target_ul = box_center - (box_sz - 1) / 2
        return torch.cat([target_ul.flip((0,)), box_sz.flip((0,))])

    def init_iou_net(self, backbone_feat):
        # Setup IoU net and objective
        for p in self.net.bb_regressor.parameters():
            p.requires_grad = False

        # Get target boxes for the different augmentations
        self.classifier_target_box = self.get_iounet_box(self.pos, self.target_sz,
                                                         self.init_sample_pos, self.init_sample_scale)
        target_boxes = TensorList()
        target_boxes.append(self.classifier_target_box + torch.Tensor(
            [self.transforms[0].shift[1], self.transforms[0].shift[0], 0, 0]))
        target_boxes = torch.cat(target_boxes.view(1, 4), 0).to(self.params.device)

        # Get IoU features
        iou_backbone_feat = self.get_iou_backbone_features(backbone_feat)

        # Remove other augmentations such as rotation
        iou_backbone_feat = TensorList([x[:target_boxes.shape[0], ...] for x in iou_backbone_feat])

        # Get modulation vector
        self.iou_modulation = self.get_iou_modulation(iou_backbone_feat, target_boxes)
        self.iou_modulation = TensorList([x.detach().mean(0) for x in self.iou_modulation])

    def init_classifier(self, init_backbone_feat):
        # Get classification features
        x = self.get_classification_features(init_backbone_feat)

        # Add the dropout augmentation here, since it requires extraction of the classification features
        if 'dropout' in self.params.augmentation and self.params.get('use_augmentation', True):
            num, prob = self.params.augmentation['dropout']
            self.transforms.extend(self.transforms[:1] * num)
            x = torch.cat([x, F.dropout2d(x[0:1, ...].expand(num, -1, -1, -1), p=prob, training=True)])

        # Set feature size and other related sizes
        self.feature_sz = torch.Tensor(list(x.shape[-2:]))
        ksz = self.net.classifier.filter_size
        self.kernel_size = torch.Tensor([ksz, ksz] if isinstance(ksz, (int, float)) else ksz)
        self.output_sz = self.feature_sz + (self.kernel_size + 1) % 2

        # Get target boxes for the different augmentations
        target_boxes = self.init_target_boxes()

        # Set number of iterations
        num_iter = self.params.get('net_opt_iter', None)

        # Get target filter by running the discriminative model prediction module
        with torch.no_grad():
            self.target_filter, _, losses = self.net.classifier.get_filter(x, target_boxes, num_iter=num_iter,
                                                                           compute_losses=False)

    def refine_target_box(self, backbone_feat, sample_pos, sample_scale, scale_ind, update_scale=True):
        """Run the ATOM IoUNet to refine the target bounding box."""
        # Initial box for refinement
        init_box = self.get_iounet_box(self.pos, self.target_sz, sample_pos, sample_scale)

        # Extract features from the relevant scale
        iou_features = self.get_iou_features(backbone_feat)
        iou_features = TensorList([x[scale_ind:scale_ind + 1, ...] for x in iou_features])

        # Set up the initial box proposal (jitter factors are computed but random boxes are not sampled here)
        init_boxes = init_box.view(1, 4).clone()
        square_box_sz = init_box[2:].prod().sqrt()
        rand_factor = square_box_sz * torch.cat(
            [self.params.box_jitter_pos * torch.ones(2), self.params.box_jitter_sz * torch.ones(2)])

        # Optimize the boxes
        output_boxes, output_iou = self.optimize_boxes(iou_features, init_boxes)

        # Remove weird boxes
        output_boxes[:, 2:].clamp_(1)
        aspect_ratio = output_boxes[:, 2] / output_boxes[:, 3]
        keep_ind = (aspect_ratio < self.params.maximal_aspect_ratio) * (
            aspect_ratio > 1 / self.params.maximal_aspect_ratio)
        output_boxes = output_boxes[keep_ind, :]
        output_iou = output_iou[keep_ind]

        # If no box found
        if output_boxes.shape[0] == 0:
            return

        # Predict box by averaging the top-k proposals
        k = self.params.get('iounet_k', 5)
        topk = min(k, output_boxes.shape[0])
        _, inds = torch.topk(output_iou, topk)
        predicted_box = output_boxes[inds, :].mean(0)

        # Get new position and size
        new_pos = predicted_box[:2] + predicted_box[2:] / 2
        new_pos = (new_pos.flip((0,)) - (self.img_sample_sz - 1) / 2) * sample_scale + sample_pos
        new_target_sz = predicted_box[2:].flip((0,)) * sample_scale
        new_scale = torch.sqrt(new_target_sz.prod() / self.base_target_sz.prod())

        self.pos_iounet = new_pos.clone()
        self.pos = new_pos.clone()
        self.target_sz = new_target_sz
        if update_scale:
            self.target_scale = new_scale

    def optimize_boxes(self, iou_features, init_boxes):
        return self.optimize_boxes_default(iou_features, init_boxes)

    def optimize_boxes_default(self, iou_features, init_boxes):
        """Optimize IoUNet boxes with the default parametrization."""
        output_boxes = init_boxes.view(1, -1, 4).to(self.params.device)
        step_length = self.params.box_refinement_step_length

        for i_ in range(self.params.box_refinement_iter):
            # Forward pass
            bb_init = output_boxes.clone().detach()
            bb_init.requires_grad = True

            outputs = self.net.bb_regressor.predict_iou(self.iou_modulation, iou_features, bb_init)
            outputs.backward(gradient=torch.ones_like(outputs))

            # Update proposal
            output_boxes = bb_init + step_length * bb_init.grad * bb_init[:, :, 2:].repeat(1, 1, 2)
            output_boxes.detach_()

            step_length *= self.params.box_refinement_step_decay

        return output_boxes.view(-1, 4).cpu(), outputs.detach().view(-1).cpu()
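

# -----------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original tracker code).
# Assumptions: the stock pytracking parameter module `dimp50` exposes a
# `parameters()` function whose result carries a loadable DiMP network in
# `params.net` plus the settings this class reads (image_sample_size,
# search_area_scale, augmentation, box refinement settings, ...). Frame loading
# with OpenCV, the file names, and the [x, y, w, h] box values are placeholders;
# check whether your network expects RGB rather than OpenCV's BGR ordering.
if __name__ == '__main__':
    import cv2
    from pytracking.parameter.dimp import dimp50  # assumed parameter file

    params = dimp50.parameters()
    tracker = DiMP(params)

    # Initialize on the first frame with an [x, y, w, h] box around the target.
    first_frame = cv2.imread('frame_0001.jpg')
    tracker.initialize(first_frame, bbox=[100, 150, 60, 80])

    # Track the target in a subsequent frame.
    next_frame = cv2.imread('frame_0002.jpg')
    out = tracker.track(next_frame)
    print(out['target_bbox'], out['success'])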