import numpy as np
import math
import torch
import torch.nn.functional as F
import cv2 as cv
import random
from pytracking.features.preprocessing import numpy_to_torch, torch_to_numpy


class Transform:
    """Base data augmentation transform class.

    Subclasses implement ``__call__(image, is_mask=False)``. Tensor results are
    passed through :meth:`crop_to_output`, which pads the image to
    ``output_sz`` and applies ``shift``.

    Args:
        output_sz: Target (height, width) of the tensor output, or None to keep
            the input size.
        shift: (row, col) offset applied when padding; defaults to (0, 0).
    """

    def __init__(self, output_sz=None, shift=None):
        self.output_sz = output_sz
        self.shift = (0, 0) if shift is None else shift
        # NOTE(review): releasing the CUDA caching allocator here and on every
        # crop_to_output call is costly when CUDA is in use; kept to preserve
        # existing behavior.
        torch.cuda.empty_cache()

    def crop_to_output(self, image):
        """Pad ``image`` to ``self.output_sz``, centered and offset by ``self.shift``.

        The last two dimensions are treated as (height, width), matching the
        (left, right, top, bottom) order passed to ``F.pad``. Padding uses
        'replicate' mode. If ``output_sz`` is smaller than the image, the pad
        amounts are negative and are passed to ``F.pad`` unchanged.

        Args:
            image: A torch tensor; presumably (N, C, H, W) — TODO confirm.

        Returns:
            The padded tensor.

        Raises:
            NotImplementedError: If ``image`` is not a torch tensor.
        """
        torch.cuda.empty_cache()
        if isinstance(image, torch.Tensor):
            imsz = image.shape[2:]
            if self.output_sz is None:
                pad_h = 0
                pad_w = 0
            else:
                pad_h = (self.output_sz[0] - imsz[0]) / 2
                pad_w = (self.output_sz[1] - imsz[1]) / 2

            # Odd padding puts the extra pixel on the right/bottom; the shift
            # moves padding from one side to the opposite side.
            pad_left = math.floor(pad_w) + self.shift[1]
            pad_right = math.ceil(pad_w) - self.shift[1]
            pad_top = math.floor(pad_h) + self.shift[0]
            pad_bottom = math.ceil(pad_h) - self.shift[0]

            return F.pad(image, (pad_left, pad_right, pad_top, pad_bottom), 'replicate')
        else:
            raise NotImplementedError


class Identity(Transform):
    """Identity transformation (only applies :meth:`Transform.crop_to_output`)."""

    def __call__(self, image, is_mask=False):
        return self.crop_to_output(image)


class FlipHorizontal(Transform):
    """Mirror the image left-right.

    Tensors are flipped along dim 3 and then padded via ``crop_to_output``.
    Other inputs go through ``np.fliplr`` and are returned without padding.
    """

    def __call__(self, image, is_mask=False):
        if isinstance(image, torch.Tensor):
            return self.crop_to_output(image.flip((3,)))
        else:
            return np.fliplr(image)


class Translation(Transform):
    """Translate by adding ``translation`` (row, col) to the padding shift.

    Only torch tensors are supported.
    """

    def __init__(self, translation, output_sz=None, shift=None):
        super().__init__(output_sz, shift)
        self.shift = (self.shift[0] + translation[0], self.shift[1] + translation[1])

    def __call__(self, image, is_mask=False):
        if isinstance(image, torch.Tensor):
            return self.crop_to_output(image)
        else:
            raise NotImplementedError


class Rotate(Transform):
    """Rotate about the image center by ``angle`` degrees.

    Borders are filled with ``cv.BORDER_REPLICATE``. Tensors are converted to
    numpy, rotated, converted back and padded via ``crop_to_output``.
    """

    def __init__(self, angle, output_sz=None, shift=None):
        super().__init__(output_sz, shift)
        self.angle = math.pi * angle / 180  # degrees -> radians

    def __call__(self, image, is_mask=False):
        if isinstance(image, torch.Tensor):
            # NOTE(review): presumes torch_to_numpy/numpy_to_torch round-trip the
            # tensor layout (e.g. (1, C, H, W) <-> (H, W, C)) — confirm.
            return self.crop_to_output(numpy_to_torch(self(torch_to_numpy(image))))
        else:
            # cv.warpAffine works in (x, y) = (column, row) coordinates, so the
            # rotation center is ((W-1)/2, (H-1)/2). image.shape[1::-1] yields
            # (W, H); using shape[:2] = (H, W) would misplace the center for
            # non-square images.
            c = (np.expand_dims(np.array(image.shape[1::-1]), 1) - 1) / 2
            R = np.array([[math.cos(self.angle), math.sin(self.angle)],
                          [-math.sin(self.angle), math.cos(self.angle)]])
            # Affine map p -> R p + (c - R c), which keeps c fixed.
            H = np.concatenate([R, c - R @ c], 1)
            return cv.warpAffine(image, H, image.shape[1::-1], borderMode=cv.BORDER_REPLICATE)


class Blur(Transform):
    """Separable Gaussian blur with given sigma (can be axis dependent).

    Args:
        sigma: A scalar, or a (sigma_h, sigma_w) pair. Each 1-D kernel has
            radius ``ceil(2 * sigma)``. A sigma of 0 leaves that axis unchanged.
        output_sz, shift: Forwarded to :class:`Transform`.
    """

    def __init__(self, sigma, output_sz=None, shift=None):
        super().__init__(output_sz, shift)
        if isinstance(sigma, (float, int)):
            sigma = (sigma, sigma)
        self.sigma = sigma
        self.filter_size = [math.ceil(2 * s) for s in self.sigma]
        x_coord = [torch.arange(-sz, sz + 1, dtype=torch.float32) for sz in self.filter_size]
        # sigma == 0 would give exp(-0/0) = NaN; use a unit (identity) kernel instead.
        self.filter = [torch.exp(-(x ** 2) / (2 * s ** 2)) if s > 0 else torch.ones_like(x)
                       for x, s in zip(x_coord, self.sigma)]
        # Vertical (height) kernel and horizontal (width) kernel, each normalized to sum 1.
        self.filter[0] = self.filter[0].view(1, 1, -1, 1) / self.filter[0].sum()
        self.filter[1] = self.filter[1].view(1, 1, 1, -1) / self.filter[1].sum()

    def __call__(self, image, is_mask=False):
        if isinstance(image, torch.Tensor):
            sz = image.shape[2:]
            # Kernels are built as CPU float32; match the input so conv2d does
            # not fail on CUDA or non-float32 tensors.
            filt_v = self.filter[0].to(device=image.device, dtype=image.dtype)
            filt_h = self.filter[1].to(device=image.device, dtype=image.dtype)
            # Every channel is blurred independently as a single-channel image.
            # reshape (not view) also accepts non-contiguous inputs.
            im1 = F.conv2d(image.reshape(-1, 1, sz[0], sz[1]), filt_v,
                           padding=(self.filter_size[0], 0))
            im2 = F.conv2d(im1, filt_h, padding=(0, self.filter_size[1]))
            # Restore the input's own shape so batch and channel dims are kept.
            return self.crop_to_output(im2.reshape(image.shape))
        else:
            raise NotImplementedError