You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
103 lines
3.8 KiB
103 lines
3.8 KiB
import numpy as np
|
|
import math
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import cv2 as cv
|
|
import random
|
|
from pytracking.features.preprocessing import numpy_to_torch, torch_to_numpy
|
|
|
|
class Transform:
    """Base data augmentation transform class.

    Holds the target output size and a pixel shift, and provides
    ``crop_to_output`` which subclasses apply to their result.

    Args:
        output_sz: Target (height, width) of the output, or None to keep the
            spatial size of the input.
        shift: Two-element offset; element 0 is added to the top padding and
            element 1 to the left padding (and subtracted from the opposite
            sides). None means (0, 0).
    """

    def __init__(self, output_sz = None, shift = None):
        self.output_sz = output_sz
        self.shift = (0,0) if shift is None else shift
        # Ask PyTorch to release unused cached CUDA memory.
        torch.cuda.empty_cache()

    def __call__(self, image, is_mask=False):
        """Apply the transform to ``image``; subclasses must override this.

        Raises:
            NotImplementedError: Always, on the base class.
        """
        raise NotImplementedError

    def crop_to_output(self, image):
        """Pad ``image`` to ``self.output_sz`` using replicate padding.

        The spatial size is read from ``image.shape[2:]``. The size
        difference is split between both sides (floor on the left/top, ceil
        on the right/bottom) and then offset by ``self.shift``. When the
        target is smaller than the input the resulting negative amounts are
        passed to ``F.pad`` unchanged.

        Args:
            image: A ``torch.Tensor`` whose dimensions 2 and 3 are height and
                width.

        Returns:
            The padded tensor.

        Raises:
            NotImplementedError: If ``image`` is not a ``torch.Tensor``.
        """
        torch.cuda.empty_cache()

        if not isinstance(image, torch.Tensor):
            raise NotImplementedError('crop_to_output only supports torch.Tensor input')

        imsz = image.shape[2:]
        if self.output_sz is None:
            pad_h = 0
            pad_w = 0
        else:
            pad_h = (self.output_sz[0] - imsz[0]) / 2
            pad_w = (self.output_sz[1] - imsz[1]) / 2

        # The shift moves padding from one side to the other, so the output
        # size is unaffected while the content is displaced.
        pad_left = math.floor(pad_w) + self.shift[1]
        pad_right = math.ceil(pad_w) - self.shift[1]
        pad_top = math.floor(pad_h) + self.shift[0]
        pad_bottom = math.ceil(pad_h) - self.shift[0]

        # F.pad takes the last dimension's pair first: (left, right, top, bottom).
        return F.pad(image, (pad_left, pad_right, pad_top, pad_bottom), 'replicate')
|
|
|
|
|
|
class Identity(Transform):
    """Leave the image content unchanged; only fit it to the output size."""

    def __call__(self, image, is_mask=False):
        """Return ``image`` after ``crop_to_output``; ``is_mask`` is not used."""
        fitted = self.crop_to_output(image)
        return fitted
|
|
|
|
class FlipHorizontal(Transform):
    """Mirror the image left-to-right."""

    def __call__(self, image, is_mask=False):
        """Flip ``image``; ``is_mask`` is not used.

        Tensor input is reversed along dimension 3 and then passed through
        ``crop_to_output``. Any other input is handed to ``np.fliplr`` and
        returned as-is, without cropping.
        """
        if not isinstance(image, torch.Tensor):
            return np.fliplr(image)

        mirrored = image.flip((3,))
        return self.crop_to_output(mirrored)
|
|
|
|
class Translation(Transform):
    """Displace the image by a fixed offset."""

    def __init__(self, translation, output_sz = None, shift = None):
        """Fold ``translation`` into the shift stored by the base class.

        Args:
            translation: Two-element offset added element-wise to ``shift``.
            output_sz: Forwarded to ``Transform.__init__``.
            shift: Forwarded to ``Transform.__init__``.
        """
        super().__init__(output_sz, shift)
        base_first = self.shift[0]
        base_second = self.shift[1]
        self.shift = (base_first + translation[0], base_second + translation[1])

    def __call__(self, image, is_mask=False):
        """Apply the combined shift via ``crop_to_output``; tensors only.

        Raises:
            NotImplementedError: If ``image`` is not a ``torch.Tensor``.
        """
        if not isinstance(image, torch.Tensor):
            raise NotImplementedError
        # The translation itself is realised by the shifted padding.
        return self.crop_to_output(image)
|
|
|
|
class Rotate(Transform):
    """Rotate with given angle."""

    def __init__(self, angle, output_sz = None, shift = None):
        """Store the rotation angle.

        Args:
            angle: Rotation angle in degrees (converted to radians here).
            output_sz: Forwarded to ``Transform.__init__``.
            shift: Forwarded to ``Transform.__init__``.
        """
        super().__init__(output_sz, shift)
        self.angle = math.pi * angle/180

    def __call__(self, image, is_mask=False):
        """Rotate ``image`` about its centre; ``is_mask`` is not used.

        Tensor input is converted with ``torch_to_numpy``, rotated through the
        non-tensor branch, converted back with ``numpy_to_torch`` and passed
        through ``crop_to_output``. Other input is warped directly with
        ``cv.warpAffine`` using replicated borders and is not cropped.
        """
        if isinstance(image, torch.Tensor):
            return self.crop_to_output(numpy_to_torch(self(torch_to_numpy(image))))

        # Centre as an (x, y) column vector. The affine matrix given to
        # cv.warpAffine works in (x, y) = (column, row) coordinates, so the
        # centre must be built from (width, height); using (height, width)
        # rotates non-square images about the wrong point.
        c = (np.expand_dims(np.array(image.shape[1::-1]), 1) - 1) / 2
        R = np.array([[math.cos(self.angle), math.sin(self.angle)],
                      [-math.sin(self.angle), math.cos(self.angle)]])
        # p -> R (p - c) + c, written as the 2x3 matrix [R | c - R c].
        H = np.concatenate([R, c - R @ c], 1)
        return cv.warpAffine(image, H, image.shape[1::-1], borderMode=cv.BORDER_REPLICATE)
|
|
|
|
class Blur(Transform):
    """Blur with given sigma (can be axis dependent)."""

    def __init__(self, sigma, output_sz = None, shift = None):
        """Precompute the two 1-D Gaussian kernels.

        Args:
            sigma: Standard deviation of the Gaussian. A single int/float is
                used for both axes; otherwise a two-element sequence whose
                element 0 builds the kernel applied along dimension 2 and
                element 1 the kernel applied along dimension 3.
            output_sz: Forwarded to ``Transform.__init__``.
            shift: Forwarded to ``Transform.__init__``.
        """
        super().__init__(output_sz, shift)
        if isinstance(sigma, (float, int)):
            sigma = (sigma, sigma)
        self.sigma = sigma
        # Kernel half-width of ceil(2 * sigma) taps on each side.
        self.filter_size = [math.ceil(2*s) for s in self.sigma]
        x_coord = [torch.arange(-sz, sz+1, dtype=torch.float32) for sz in self.filter_size]
        # sigma == 0 would evaluate 0/0 and give a NaN kernel; use a
        # single-tap identity kernel for that axis instead.
        self.filter = [torch.exp(-(x**2)/(2*s**2)) if s != 0 else torch.ones_like(x)
                       for x, s in zip(x_coord, self.sigma)]
        # Normalise each kernel to sum to one and shape it for conv2d.
        self.filter[0] = self.filter[0].view(1,1,-1,1) / self.filter[0].sum()
        self.filter[1] = self.filter[1].view(1,1,1,-1) / self.filter[1].sum()

    def __call__(self, image, is_mask=False):
        """Blur a tensor image with two separable passes; ``is_mask`` is not used.

        Every channel is filtered independently by folding all leading
        dimensions into the batch dimension. The result is reshaped to
        ``(1, -1, H, W)`` and passed through ``crop_to_output``.

        Raises:
            NotImplementedError: If ``image`` is not a ``torch.Tensor``.
        """
        if not isinstance(image, torch.Tensor):
            raise NotImplementedError

        sz = image.shape[2:]
        # Match the kernels to the image's device (and floating dtype) so
        # CUDA or float64 inputs do not fail inside conv2d. Non-floating
        # images keep the float kernels and behave as before.
        target_dtype = image.dtype if image.is_floating_point() else self.filter[0].dtype
        filt_rows = self.filter[0].to(device=image.device, dtype=target_dtype)
        filt_cols = self.filter[1].to(device=image.device, dtype=target_dtype)

        # reshape (not view) so non-contiguous inputs are accepted.
        im1 = F.conv2d(image.reshape(-1,1,sz[0],sz[1]), filt_rows, padding=(self.filter_size[0],0))
        im2 = F.conv2d(im1, filt_cols, padding=(0,self.filter_size[1]))
        return self.crop_to_output(im2.reshape(1,-1,sz[0],sz[1]))
|