from typing import List
import math

import torch
import torch.nn.functional as F

from pytracking.libs import dcf
from pytracking.libs.tensorlist import TensorList
from pytracking.features.preprocessing import numpy_to_torch
from pytracking.features.preprocessing import sample_patch_multiscale, sample_patch_transformed
from pytracking.features import augmentation


class DiMP():
    def __init__(self, params):
        self.params = params

        # Initialize network
        self.params.net.initialize()

        # The DiMP network
        self.net = self.params.net

    def initialize(self, image=None, bbox: List = None) -> dict:
        if bbox is None or image is None:
            return None

        # Run the tracker on the GPU
        self.params.device = 'cuda'

        # Convert image
        im = numpy_to_torch(image)

        # Get target position and size (bbox is given as [x, y, w, h])
        self.pos = torch.Tensor([bbox[1] + (bbox[3] - 1) / 2, bbox[0] + (bbox[2] - 1) / 2])
        self.target_sz = torch.Tensor([bbox[3], bbox[2]])

        # Set sizes
        self.image_sz = torch.Tensor([im.shape[2], im.shape[3]])
        sz = self.params.image_sample_size
        sz = torch.Tensor([sz, sz] if isinstance(sz, int) else sz)

        self.img_sample_sz = sz
        self.img_support_sz = self.img_sample_sz

        # Set search area
        search_area = torch.prod(self.target_sz * self.params.search_area_scale).item()
        self.target_scale = math.sqrt(search_area) / self.img_sample_sz.prod().sqrt()

        # Target size in base scale
        self.base_target_sz = self.target_sz / self.target_scale

        # Setup scale factors
        self.params.scale_factors = torch.ones(1)

        # Extract and transform sample
        init_backbone_feat = self.generate_init_samples(im)

        # Initialize classifier
        self.init_classifier(init_backbone_feat)

        # Initialize IoUNet
        self.init_iou_net(init_backbone_feat)

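    # Worked example (illustrative only, not from the original code): initialize() takes bbox in
    # [x, y, w, h] order and stores the internal (row, col) state. For bbox = [100, 80, 60, 40]:
    #   self.pos       = [80 + (40 - 1) / 2, 100 + (60 - 1) / 2] = [99.5, 129.5]  (center, [row, col])
    #   self.target_sz = [40, 60]  (size, [height, width])
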
    def track(self, image) -> dict:
        torch.cuda.empty_cache()

        # Convert image
        im = numpy_to_torch(image)

        # ------- LOCALIZATION ------- #

        # Extract backbone features
        backbone_feat, sample_coords = self.extract_backbone_features(im, self.get_centered_sample_pos(),
                                                                      self.target_scale * self.params.scale_factors,
                                                                      self.img_sample_sz)
        # Extract classification features
        test_x = self.get_classification_features(backbone_feat)

        # Location of sample
        sample_pos, sample_scales = self.get_sample_location(sample_coords)

        # Compute classification scores
        scores_raw = self.classify_target(test_x)

        # Localize the target
        translation_vec, scale_ind, s, flag = self.localize_target(scores_raw, sample_pos, sample_scales)
        new_pos = sample_pos[scale_ind, :] + translation_vec

        # Update position and scale
        if flag != 'not_found':
            update_scale_flag = self.params.get('update_scale_when_uncertain', True) or flag != 'uncertain'
            self.update_state(new_pos)
            self.refine_target_box(backbone_feat, sample_pos[scale_ind, :], sample_scales[scale_ind], scale_ind,
                                   update_scale_flag)

        # ------- UPDATE ------- #

        # Set the tracker position to the IoUNet-refined position
        if self.params.get('use_iou_net', True) and flag != 'not_found' and hasattr(self, 'pos_iounet'):
            self.pos = self.pos_iounet.clone()

        # Save the search area box ([x, y, w, h]) for visualization and debugging
        self.search_area_box = torch.cat(
            (sample_coords[scale_ind, [1, 0]], sample_coords[scale_ind, [3, 2]] - sample_coords[scale_ind, [1, 0]] - 1))

        # Compute the output bounding box in [x, y, w, h] format
        new_state = torch.cat((self.pos[[1, 0]] - (self.target_sz[[1, 0]] - 1) / 2, self.target_sz[[1, 0]]))

        out = {'target_bbox': new_state.tolist(), 'success': flag != 'not_found'}
        return out

    def get_sample_location(self, sample_coord):
        """Get the location of the extracted sample."""
        sample_coord = sample_coord.float()
        sample_pos = 0.5 * (sample_coord[:, :2] + sample_coord[:, 2:] - 1)
        sample_scales = ((sample_coord[:, 2:] - sample_coord[:, :2]) / self.img_sample_sz).prod(dim=1).sqrt()
        return sample_pos, sample_scales

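    # Illustrative example (added for clarity, not in the original code): for one sample with
    # crop coordinates sample_coord = [y1, x1, y2, x2] = [0, 0, 576, 576] and
    # img_sample_sz = [288, 288], this gives sample_pos = [287.5, 287.5] (crop center) and
    # sample_scales = [2.0], i.e. the 576 px crop was resized by a factor 2 to the 288 px sample.
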
    def get_centered_sample_pos(self):
        """Get the center position for the new sample. Make sure the target is correctly centered."""
        return self.pos + ((self.feature_sz + self.kernel_size) % 2) * self.target_scale * \
               self.img_support_sz / (2 * self.feature_sz)

    def classify_target(self, sample_x: TensorList):
        """Classify target by applying the DiMP filter."""
        with torch.no_grad():
            scores = self.net.classifier.classify(self.target_filter, sample_x)
        return scores

    def localize_target(self, scores, sample_pos, sample_scales):
        """Run the target localization."""
        scores = scores.squeeze(1)
        return self.localize_advanced(scores, sample_pos, sample_scales)

    def localize_advanced(self, scores, sample_pos, sample_scales):
        """Run the target advanced localization (as in ATOM)."""
        sz = scores.shape[-2:]
        score_sz = torch.Tensor(list(sz))
        output_sz = score_sz - (self.kernel_size + 1) % 2
        score_center = (score_sz - 1) / 2

        scores_hn = scores

        max_score1, max_disp1 = dcf.max2d(scores)
        _, scale_ind = torch.max(max_score1, dim=0)
        sample_scale = sample_scales[scale_ind]
        max_score1 = max_score1[scale_ind]
        max_disp1 = max_disp1[scale_ind, ...].float().cpu().view(-1)
        target_disp1 = max_disp1 - score_center

        # Convert the peak displacement on the score map to a displacement in image pixels
        translation_vec1 = target_disp1 * (self.img_support_sz / output_sz) * sample_scale

        if max_score1.item() < self.params.target_not_found_threshold:
            return translation_vec1, scale_ind, scores_hn, 'not_found'

        # Mask out the target neighborhood
        target_neigh_sz = self.params.target_neighborhood_scale * (self.target_sz / sample_scale) * (
                output_sz / self.img_support_sz)

        tneigh_top = max(round(max_disp1[0].item() - target_neigh_sz[0].item() / 2), 0)
        tneigh_bottom = min(round(max_disp1[0].item() + target_neigh_sz[0].item() / 2 + 1), sz[0])
        tneigh_left = max(round(max_disp1[1].item() - target_neigh_sz[1].item() / 2), 0)
        tneigh_right = min(round(max_disp1[1].item() + target_neigh_sz[1].item() / 2 + 1), sz[1])
        scores_masked = scores_hn[scale_ind:scale_ind + 1, ...].clone()
        scores_masked[..., tneigh_top:tneigh_bottom, tneigh_left:tneigh_right] = 0

        # Find the second-highest peak outside the target neighborhood
        max_score2, max_disp2 = dcf.max2d(scores_masked)
        max_disp2 = max_disp2.float().cpu().view(-1)
        target_disp2 = max_disp2 - score_center
        translation_vec2 = target_disp2 * (self.img_support_sz / output_sz) * sample_scale

        # Handle distractors: a comparable second peak makes the localization uncertain
        if max_score2 > self.params.get('distractor_threshold', 0.8) * max_score1:
            disp_norm1 = torch.sqrt(torch.sum(target_disp1 ** 2))
            disp_norm2 = torch.sqrt(torch.sum(target_disp2 ** 2))
            disp_threshold = self.params.get('displacement_scale', 0.8) * math.sqrt(sz[0] * sz[1]) / 2

            if disp_norm2 > disp_threshold and disp_norm1 < disp_threshold:
                return translation_vec1, scale_ind, scores_hn, 'hard_negative'
            if disp_norm2 < disp_threshold and disp_norm1 > disp_threshold:
                return translation_vec2, scale_ind, scores_hn, 'hard_negative'
            # Both peaks are far from the center: the localization is uncertain
            return translation_vec1, scale_ind, scores_hn, 'uncertain'

        if max_score2 > self.params.get('hard_negative_threshold', 0.5) * max_score1 and \
                max_score2 > self.params.target_not_found_threshold:
            return translation_vec1, scale_ind, scores_hn, 'hard_negative'

        return translation_vec1, scale_ind, scores_hn, 'normal'

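    # Scale note (an assumption based on common DiMP-50 settings, not taken from this file):
    # with image_sample_size = 288, a backbone stride of 16 and a filter size of 4, the score
    # map is 19x19, so one cell of displacement corresponds to roughly 288 / 19 ≈ 15 pixels in
    # the resized sample, scaled further by sample_scale to obtain image pixels.
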
    def extract_backbone_features(self, im: torch.Tensor, pos: torch.Tensor, scales, sz: torch.Tensor):
        im_patches, patch_coords = sample_patch_multiscale(im, pos, scales, sz,
                                                           mode=self.params.get('border_mode', 'replicate'),
                                                           max_scale_change=self.params.get('patch_max_scale_change',
                                                                                            None))
        with torch.no_grad():
            backbone_feat = self.net.extract_backbone(im_patches)
        return backbone_feat, patch_coords

    def get_classification_features(self, backbone_feat):
        with torch.no_grad():
            return self.net.extract_classification_feat(backbone_feat)

    def get_iou_backbone_features(self, backbone_feat):
        return self.net.get_backbone_bbreg_feat(backbone_feat)

    def get_iou_features(self, backbone_feat):
        with torch.no_grad():
            return self.net.bb_regressor.get_iou_feat(self.get_iou_backbone_features(backbone_feat))

    def get_iou_modulation(self, iou_backbone_feat, target_boxes):
        with torch.no_grad():
            return self.net.bb_regressor.get_modulation(iou_backbone_feat, target_boxes)

    def generate_init_samples(self, im: torch.Tensor) -> TensorList:
        """Perform data augmentation to generate initial training samples."""

        self.init_sample_scale = self.target_scale
        global_shift = torch.zeros(2)

        self.init_sample_pos = self.pos.round()

        # Compute augmentation size
        aug_expansion_factor = self.params.get('augmentation_expansion_factor', None)
        aug_expansion_sz = self.img_sample_sz.clone()
        aug_output_sz = None
        if aug_expansion_factor is not None and aug_expansion_factor != 1:
            aug_expansion_sz = (self.img_sample_sz * aug_expansion_factor).long()
            aug_expansion_sz += (aug_expansion_sz - self.img_sample_sz.long()) % 2
            aug_expansion_sz = aug_expansion_sz.float()
            aug_output_sz = self.img_sample_sz.long().tolist()

        # Random shift for the augmented samples
        random_shift_factor = self.params.get('random_shift_factor', 0)
        get_rand_shift = lambda: (
                (torch.rand(2) - 0.5) * self.img_sample_sz * random_shift_factor + global_shift).long().tolist()

        # Always put identity transformation first, since it is the unaugmented sample that is always used
        self.transforms = [augmentation.Identity(aug_output_sz, global_shift.long().tolist())]

        augs = self.params.augmentation if self.params.get('use_augmentation', True) else {}

        # Add all augmentations
        if 'relativeshift' in augs:
            get_absolute = lambda shift: (torch.Tensor(shift) * self.img_sample_sz / 2).long().tolist()
            self.transforms.extend(
                [augmentation.Translation(get_absolute(shift), aug_output_sz, global_shift.long().tolist())
                 for shift in augs['relativeshift']])
        if 'fliplr' in augs and augs['fliplr']:
            self.transforms.append(augmentation.FlipHorizontal(aug_output_sz, get_rand_shift()))
        if 'blur' in augs:
            self.transforms.extend(
                [augmentation.Blur(sigma, aug_output_sz, get_rand_shift()) for sigma in augs['blur']])
        if 'rotate' in augs:
            self.transforms.extend(
                [augmentation.Rotate(angle, aug_output_sz, get_rand_shift()) for angle in augs['rotate']])

        # Extract augmented image patches
        im_patches = sample_patch_transformed(im, self.init_sample_pos, self.init_sample_scale, aug_expansion_sz,
                                              self.transforms)

        # Extract initial backbone features
        with torch.no_grad():
            init_backbone_feat = self.net.extract_backbone(im_patches)

        return init_backbone_feat

    def init_target_boxes(self):
        """Get the target bounding boxes for the initial augmented samples."""
        self.classifier_target_box = self.get_iounet_box(self.pos, self.target_sz, self.init_sample_pos,
                                                         self.init_sample_scale)
        init_target_boxes = TensorList()
        for T in self.transforms:
            init_target_boxes.append(self.classifier_target_box + torch.Tensor([T.shift[1], T.shift[0], 0, 0]))
        init_target_boxes = torch.cat(init_target_boxes.view(1, 4), 0).to(self.params.device)
        self.target_boxes = init_target_boxes.new_zeros(self.params.sample_memory_size, 4)
        self.target_boxes[:init_target_boxes.shape[0], :] = init_target_boxes
        return init_target_boxes

    def update_state(self, new_pos):
        # Update the position, keeping at least a fraction `target_inside_ratio` of the target inside the image
        inside_ratio = self.params.get('target_inside_ratio', 0.2)
        inside_offset = (inside_ratio - 0.5) * self.target_sz
        self.pos = torch.max(torch.min(new_pos, self.image_sz - inside_offset), inside_offset)

    def get_iounet_box(self, pos, sz, sample_pos, sample_scale):
        """All inputs in original image coordinates.
        Generates a box in the cropped image sample reference frame, in the format used by the IoUNet."""
        box_center = (pos - sample_pos) / sample_scale + (self.img_sample_sz - 1) / 2
        box_sz = sz / sample_scale
        target_ul = box_center - (box_sz - 1) / 2
        return torch.cat([target_ul.flip((0,)), box_sz.flip((0,))])

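    # Worked example (illustrative only): for a target centered in the crop (pos == sample_pos),
    # sample_scale = 2, img_sample_sz = [288, 288] and sz = [40, 60] (height, width), the call
    # returns [129.0, 134.0, 30.0, 20.0], i.e. [x, y, w, h] in crop-image coordinates.
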
    def init_iou_net(self, backbone_feat):
        # Setup IoU net and objective
        for p in self.net.bb_regressor.parameters():
            p.requires_grad = False

        # Get the target box for the first (un-augmented) sample
        self.classifier_target_box = self.get_iounet_box(self.pos, self.target_sz, self.init_sample_pos,
                                                         self.init_sample_scale)
        target_boxes = TensorList()
        target_boxes.append(self.classifier_target_box + torch.Tensor(
            [self.transforms[0].shift[1], self.transforms[0].shift[0], 0, 0]))
        target_boxes = torch.cat(target_boxes.view(1, 4), 0).to(self.params.device)

        # Get IoU features
        iou_backbone_feat = self.get_iou_backbone_features(backbone_feat)

        # Keep only the features of the un-augmented sample, dropping augmentations such as rotation
        iou_backbone_feat = TensorList([x[:target_boxes.shape[0], ...] for x in iou_backbone_feat])

        # Get modulation vector
        self.iou_modulation = self.get_iou_modulation(iou_backbone_feat, target_boxes)

        self.iou_modulation = TensorList([x.detach().mean(0) for x in self.iou_modulation])

    def init_classifier(self, init_backbone_feat):
        # Get classification features
        x = self.get_classification_features(init_backbone_feat)

        # Add the dropout augmentation here, since it requires extraction of the classification features
        if 'dropout' in self.params.augmentation and self.params.get('use_augmentation', True):
            num, prob = self.params.augmentation['dropout']
            self.transforms.extend(self.transforms[:1] * num)
            x = torch.cat([x, F.dropout2d(x[0:1, ...].expand(num, -1, -1, -1), p=prob, training=True)])

        # Set feature size and other related sizes
        self.feature_sz = torch.Tensor(list(x.shape[-2:]))
        ksz = self.net.classifier.filter_size
        self.kernel_size = torch.Tensor([ksz, ksz] if isinstance(ksz, (int, float)) else ksz)
        self.output_sz = self.feature_sz + (self.kernel_size + 1) % 2

        # Get target boxes for the different augmentations
        target_boxes = self.init_target_boxes()

        # Set number of iterations
        num_iter = self.params.get('net_opt_iter', None)

        # Get the target filter by running the discriminative model prediction module
        with torch.no_grad():
            self.target_filter, _, _ = self.net.classifier.get_filter(x, target_boxes, num_iter=num_iter,
                                                                      compute_losses=False)

    def refine_target_box(self, backbone_feat, sample_pos, sample_scale, scale_ind, update_scale=True):
        """Run the ATOM IoUNet to refine the target bounding box."""

        # Initial box for refinement
        init_box = self.get_iounet_box(self.pos, self.target_sz, sample_pos, sample_scale)

        # Extract features from the relevant scale
        iou_features = self.get_iou_features(backbone_feat)
        iou_features = TensorList([x[scale_ind:scale_ind + 1, ...] for x in iou_features])

        # Generate random initial boxes around the initial box
        init_boxes = init_box.view(1, 4).clone()
        if self.params.get('num_init_random_boxes', 0) > 0:
            square_box_sz = init_box[2:].prod().sqrt()
            rand_factor = square_box_sz * torch.cat(
                [self.params.box_jitter_pos * torch.ones(2), self.params.box_jitter_sz * torch.ones(2)])

            minimal_edge_size = init_box[2:].min() / 3
            rand_bb = (torch.rand(self.params.num_init_random_boxes, 4) - 0.5) * rand_factor
            new_sz = (init_box[2:] + rand_bb[:, 2:]).clamp(minimal_edge_size)
            new_center = (init_box[:2] + init_box[2:] / 2) + rand_bb[:, :2]
            init_boxes = torch.cat([new_center - new_sz / 2, new_sz], 1)
            init_boxes = torch.cat([init_box.view(1, 4), init_boxes])

        # Optimize the boxes
        output_boxes, output_iou = self.optimize_boxes(iou_features, init_boxes)

        # Remove boxes with extreme aspect ratios
        output_boxes[:, 2:].clamp_(1)
        aspect_ratio = output_boxes[:, 2] / output_boxes[:, 3]
        keep_ind = (aspect_ratio < self.params.maximal_aspect_ratio) * (
                aspect_ratio > 1 / self.params.maximal_aspect_ratio)
        output_boxes = output_boxes[keep_ind, :]
        output_iou = output_iou[keep_ind]

        # If no box found
        if output_boxes.shape[0] == 0:
            return

        # Predict the box as the mean of the top-k boxes with the highest predicted IoU
        k = self.params.get('iounet_k', 5)
        topk = min(k, output_boxes.shape[0])
        _, inds = torch.topk(output_iou, topk)
        predicted_box = output_boxes[inds, :].mean(0)

        # Get the new position and size in image coordinates
        new_pos = predicted_box[:2] + predicted_box[2:] / 2
        new_pos = (new_pos.flip((0,)) - (self.img_sample_sz - 1) / 2) * sample_scale + sample_pos
        new_target_sz = predicted_box[2:].flip((0,)) * sample_scale
        new_scale = torch.sqrt(new_target_sz.prod() / self.base_target_sz.prod())

        self.pos_iounet = new_pos.clone()
        self.pos = new_pos.clone()
        self.target_sz = new_target_sz

        if update_scale:
            self.target_scale = new_scale

    def optimize_boxes(self, iou_features, init_boxes):
        return self.optimize_boxes_default(iou_features, init_boxes)

    def optimize_boxes_default(self, iou_features, init_boxes):
        """Optimize IoUNet boxes with the default parametrization."""
        output_boxes = init_boxes.view(1, -1, 4).to(self.params.device)
        step_length = self.params.box_refinement_step_length

        for i_ in range(self.params.box_refinement_iter):
            # Forward pass
            bb_init = output_boxes.clone().detach()
            bb_init.requires_grad = True

            outputs = self.net.bb_regressor.predict_iou(self.iou_modulation, iou_features, bb_init)
            outputs.backward(gradient=torch.ones_like(outputs))

            # Update proposal via gradient ascent on the predicted IoU, scaled by the box size
            output_boxes = bb_init + step_length * bb_init.grad * bb_init[:, :, 2:].repeat(1, 1, 2)
            output_boxes.detach_()

            step_length *= self.params.box_refinement_step_decay

        return output_boxes.view(-1, 4).cpu(), outputs.detach().view(-1).cpu()
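

# Minimal usage sketch (added for illustration; it is not part of the original file). It assumes
# pytracking is installed with its 'dimp.dimp50' parameter settings and pretrained weights
# available, a CUDA GPU is present, and frames are HxWx3 uint8 numpy arrays. The zero-filled
# frames below are placeholders only.
if __name__ == '__main__':
    import importlib
    import numpy as np

    param_module = importlib.import_module('pytracking.parameter.dimp.dimp50')
    params = param_module.parameters()

    tracker = DiMP(params)

    first_frame = np.zeros((360, 640, 3), dtype=np.uint8)
    tracker.initialize(first_frame, bbox=[100, 80, 60, 40])  # [x, y, w, h]

    next_frame = np.zeros((360, 640, 3), dtype=np.uint8)
    out = tracker.track(next_frame)
    print(out['target_bbox'], out['success'])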