You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
149 lines
5.1 KiB
149 lines
5.1 KiB
import torch
|
|
import torch.nn.functional as F
|
|
import numpy as np
|
|
|
|
|
|
def numpy_to_torch(a: np.ndarray):
    """Convert a three-dimensional numpy array into a batched float tensor.

    args:
        a: Numpy array with three axes (presumably an H x W x C image - confirm with callers).

    returns:
        Float tensor in which the last axis of `a` has been moved to the front and a
        leading singleton dimension has been added, i.e. shape (1, a.shape[2], a.shape[0], a.shape[1]).
    """
    torch.cuda.empty_cache()
    as_float = torch.from_numpy(a).float()
    # Move the trailing axis to the front, then prepend the singleton dimension.
    channels_first = as_float.permute(2, 0, 1)
    return channels_first.unsqueeze(0)
|
|
|
|
|
|
def torch_to_numpy(a: torch.Tensor):
    """Convert a batched tensor into a numpy array with the first remaining axis moved last.

    args:
        a: Tensor whose leading dimension is dropped by `squeeze(0)` (only if it has size 1),
           leaving three dimensions that are reordered with `permute(1, 2, 0)`.

    returns:
        Numpy array holding the same values with the axes reordered as described above.
    """
    torch.cuda.empty_cache()
    # `.numpy()` raises for tensors that require grad or live on a non-CPU device,
    # so detach from the autograd graph and move to host memory first. Both calls
    # are no-ops for a plain CPU tensor.
    host_tensor = a.detach().cpu()
    return host_tensor.squeeze(0).permute(1, 2, 0).numpy()
|
|
|
|
|
|
def sample_patch_transformed(im, pos, scale, image_sz, transforms, is_mask=False):
    """Extract one image patch and return a transformed copy of it per transform.

    args:
        im: Image.
        pos: Center position for extraction.
        scale: Image scale to extract the patch at; the cropped region has size scale * image_sz.
        image_sz: Size to resize the extracted patch to.
        transforms: Iterable of callables, each invoked as T(patch, is_mask=is_mask).
        is_mask: Forwarded both to sample_patch and to every transform.

    returns:
        The outputs of all transforms concatenated with torch.cat (along dimension 0).
    """
    torch.cuda.empty_cache()

    # Crop the base patch once; every transform works on this same patch.
    base_patch, _ = sample_patch(im, pos, scale * image_sz, image_sz, is_mask=is_mask)

    # Run the transforms in order and stack their results.
    transformed = []
    for transform in transforms:
        transformed.append(transform(base_patch, is_mask=is_mask))

    return torch.cat(transformed)
|
|
|
|
def sample_patch_multiscale(im, pos, scales, image_sz, mode: str='replicate', max_scale_change=None):
    """Extract image patches around one position at several scales.

    args:
        im: Image.
        pos: Center position for extraction.
        scales: Image scales to extract patches at; a single int or float is treated as a one-element list.
        image_sz: Size to resize the image samples to.
        mode: how to treat image borders: 'replicate' (default), 'inside' or 'inside_major'
        max_scale_change: maximum allowed scale change when using 'inside' and 'inside_major' mode

    returns:
        (im_patches, patch_coords): the per-scale patches and their coordinates as returned by
        sample_patch, each concatenated with torch.cat (along dimension 0).
    """
    torch.cuda.empty_cache()

    # Accept a bare number as a single scale.
    if isinstance(scales, (int, float)):
        scales = [scales]

    # One (patch, coordinates) pair per requested scale.
    per_scale = [
        sample_patch(im, pos, s * image_sz, image_sz, mode=mode, max_scale_change=max_scale_change)
        for s in scales
    ]
    patch_group, coord_group = zip(*per_scale)

    stacked_patches = torch.cat(list(patch_group))
    stacked_coords = torch.cat(list(coord_group))
    return stacked_patches, stacked_coords
|
|
|
|
def sample_patch(im: torch.Tensor, pos: torch.Tensor, sample_sz: torch.Tensor, output_sz: torch.Tensor = None,
                 mode: str = 'replicate', max_scale_change=None, is_mask=False):
    """Sample an image patch.

    args:
        im: Image tensor; the crop is taken over dimensions 2 and 3.
        pos: center position of crop (two elements, ordered like dimensions 2 and 3 of im)
        sample_sz: size to crop
        output_sz: size to resize to; None returns the crop without resampling
        mode: how to treat image borders: 'replicate' (default), 'inside' or 'inside_major'
        max_scale_change: maximum allowed scale change when using 'inside' and 'inside_major' mode
        is_mask: if True, pad with F.pad's default mode instead of pad_mode and resample
            with 'nearest' instead of 'bilinear'

    returns:
        (im_patch, patch_coord): the sampled patch and a (1, 4) integer tensor laid out as
        [tl[0], tl[1], br[0], br[1]], scaled by the pre-downsampling factor so that it
        refers to pixels of the input image.
    """
    torch.cuda.empty_cache()

    # copy and convert
    posl = pos.long().clone()

    pad_mode = mode
    keep_inside = mode in ('inside', 'inside_major')

    # Get new sample size if forced inside the image
    if keep_inside:
        pad_mode = 'replicate'
        im_sz = torch.Tensor([im.shape[2], im.shape[3]])
        shrink_factor = (sample_sz.float() / im_sz)
        if mode == 'inside':
            shrink_factor = shrink_factor.max()
        else:
            shrink_factor = shrink_factor.min()
        shrink_factor.clamp_(min=1, max=max_scale_change)
        sample_sz = (sample_sz.float() / shrink_factor).long()

    # Compute pre-downsampling factor
    if output_sz is not None:
        resize_factor = torch.min(sample_sz.float() / output_sz.float()).item()
        df = int(max(int(resize_factor - 0.1), 1))
    else:
        df = 1

    sz = sample_sz.float() / df     # new size

    # Do downsampling
    if df > 1:
        offset = posl % df
        # Integer division keeps the position an exact pixel index; `/` on integer
        # tensors is true division and would silently turn it into a float tensor.
        posl = (posl - offset) // df
        im2 = im[..., offset[0].item()::df, offset[1].item()::df]   # strided subsampling
    else:
        im2 = im

    # compute size to crop
    szl = torch.max(sz.round(), torch.Tensor([2])).long()

    # Extract top and bottom coordinates. Floor division is required here: with true
    # division an even crop size gives a fractional top-left corner, which made the
    # crop one pixel too large once truncated to an integer.
    tl = posl - (szl - 1) // 2
    br = posl + szl // 2 + 1

    # Shift the crop to inside
    if keep_inside:
        im2_sz = torch.LongTensor([im2.shape[2], im2.shape[3]])
        shift = (-tl).clamp(0) - (br - im2_sz).clamp(0)
        tl += shift
        br += shift

        # If the crop still sticks out, split the overhang between both sides.
        outside = ((-tl).clamp(0) + (br - im2_sz).clamp(0)) // 2
        shift = (-tl - outside) * (outside > 0).long()
        tl += shift
        br += shift

    # Get image patch. Positive entries pad the image, negative entries crop it.
    # F.pad expects the last dimension first: (left, right, top, bottom).
    pad = (int(-tl[1].item()), int(br[1].item()) - im2.shape[3],
           int(-tl[0].item()), int(br[0].item()) - im2.shape[2])

    if not is_mask:
        im_patch = F.pad(im2, pad, pad_mode)
    else:
        im_patch = F.pad(im2, pad)

    # Get image coordinates
    patch_coord = df * torch.cat((tl, br)).view(1, 4)

    if output_sz is None or (im_patch.shape[-2] == output_sz[0] and im_patch.shape[-1] == output_sz[1]):
        return im_patch.clone(), patch_coord

    # Resample
    if not is_mask:
        im_patch = F.interpolate(im_patch, output_sz.long().tolist(), mode='bilinear')
    else:
        im_patch = F.interpolate(im_patch, output_sz.long().tolist(), mode='nearest')

    return im_patch, patch_coord
|