You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							149 lines
						
					
					
						
							5.1 KiB
						
					
					
				
			
		
		
		
			
			
			
				
					
				
				
					
				
			
		
		
	
	
							149 lines
						
					
					
						
							5.1 KiB
						
					
					
				| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| 
 | |
| 
 | |
def numpy_to_torch(a: np.ndarray):
    """Convert a 3-D numpy array into a 4-D float tensor.

    The array's last axis is moved to the front and a leading singleton
    dimension is added, i.e. an input of shape (d0, d1, d2) yields a tensor of
    shape (1, d2, d0, d1).
    args:
        a: Array with exactly three dimensions (presumably an H x W x C image -
           confirm against callers).
    """
    torch.cuda.empty_cache()
    as_float = torch.from_numpy(a).float()
    last_axis_first = as_float.permute(2, 0, 1)
    return last_axis_first.unsqueeze(0)
| 
 | |
| 
 | |
def torch_to_numpy(a: torch.Tensor):
    """Convert a tensor into a numpy array with its first remaining axis moved last.

    Dimension 0 is squeezed away (only if it has size 1), after which the
    remaining three axes are reordered from (0, 1, 2) to (1, 2, 0).
    args:
        a: Tensor to convert; `.numpy()` is called on it directly, so it has to
           be convertible as-is.
    """
    torch.cuda.empty_cache()
    without_batch = a.squeeze(0)
    first_axis_last = without_batch.permute(1, 2, 0)
    return first_axis_last.numpy()
| 
 | |
| 
 | |
def sample_patch_transformed(im, pos, scale, image_sz, transforms, is_mask=False):
    """Extract one image patch and return a transformed copy of it per transform.

    args:
        im: Image.
        pos: Center position for extraction.
        scale: Image scale to extract features from; the cropped region has
               size scale*image_sz.
        image_sz: Size to resize the image sample to before the transforms run.
        transforms: Iterable of callables, each invoked as T(patch, is_mask=is_mask).
        is_mask: Forwarded unchanged to sample_patch and to every transform.
    """
    torch.cuda.empty_cache()

    # Crop a single patch around pos; the patch coordinates are not needed here
    base_patch, _ = sample_patch(im, pos, scale * image_sz, image_sz, is_mask=is_mask)

    # Run every transform on that same patch and stack the results along dim 0
    transformed = [transform(base_patch, is_mask=is_mask) for transform in transforms]
    return torch.cat(transformed)
| 
 | |
def sample_patch_multiscale(im, pos, scales, image_sz, mode: str='replicate', max_scale_change=None):
    """Extract image patches at multiple scales.

    args:
        im: Image.
        pos: Center position for extraction.
        scales: Image scales to extract image patches from. A single int or
                float is treated as a one-element list.
        image_sz: Size to resize the image samples to
        mode: how to treat image borders: 'replicate' (default), 'inside' or 'inside_major'
        max_scale_change: maximum allowed scale change when using 'inside' and 'inside_major' mode
    returns:
        The patches and their coordinates, each concatenated along dim 0 in
        the order of `scales`.
    """
    torch.cuda.empty_cache()

    # Accept a bare number as shorthand for a single scale
    if isinstance(scales, (int, float)):
        scales = [scales]

    # One (patch, coords) pair per requested scale
    samples = [
        sample_patch(im, pos, s * image_sz, image_sz, mode=mode, max_scale_change=max_scale_change)
        for s in scales
    ]
    patches, coords = zip(*samples)

    return torch.cat(list(patches)), torch.cat(list(coords))
| 
 | |
def sample_patch(im: torch.Tensor, pos: torch.Tensor, sample_sz: torch.Tensor, output_sz: torch.Tensor = None,
                 mode: str = 'replicate', max_scale_change=None, is_mask=False):
    """Sample an image patch.

    args:
        im: Image; indexed as a 4-D tensor whose dims 2 and 3 are the spatial dims.
        pos: center position of crop
        sample_sz: size to crop
        output_sz: size to resize to; when None the crop is returned without resizing
        mode: how to treat image borders: 'replicate' (default), 'inside' or 'inside_major'
        max_scale_change: maximum allowed scale change when using 'inside' and 'inside_major' mode
        is_mask: when True the crop is padded with F.pad's default mode instead of
                 `mode`, and resized with 'nearest' instead of 'bilinear' interpolation
    returns:
        im_patch: the cropped (and, if needed, resized) patch.
        patch_coord: integer tensor of shape (1, 4) holding the crop's top-left and
                     bottom-right corners, multiplied by the pre-downsampling factor.
    """
    torch.cuda.empty_cache()

    # copy and convert
    posl = pos.long().clone()

    pad_mode = mode

    # Get new sample size if forced inside the image
    if mode == 'inside' or mode == 'inside_major':
        pad_mode = 'replicate'
        im_sz = torch.Tensor([im.shape[2], im.shape[3]])
        shrink_factor = (sample_sz.float() / im_sz)
        if mode == 'inside':
            shrink_factor = shrink_factor.max()
        elif mode == 'inside_major':
            shrink_factor = shrink_factor.min()
        shrink_factor.clamp_(min=1, max=max_scale_change)
        sample_sz = (sample_sz.float() / shrink_factor).long()

    # Compute pre-downsampling factor
    if output_sz is not None:
        resize_factor = torch.min(sample_sz.float() / output_sz.float()).item()
        df = int(max(int(resize_factor - 0.1), 1))
    else:
        df = int(1)

    sz = sample_sz.float() / df     # new size

    # Do downsampling
    if df > 1:
        offset = posl % df                  # sub-sampling phase that keeps pos on the grid
        # Floor division keeps the position an integer tensor; `/` on integer
        # tensors is true division and would silently turn it into floats.
        posl = (posl - offset) // df        # new position
        im2 = im[..., offset[0].item()::df, offset[1].item()::df]   # downsample
    else:
        im2 = im

    # compute size to crop (never smaller than 2 pixels per side)
    szl = torch.max(sz.round(), torch.Tensor([2])).long()

    # Extract top and bottom coordinates. Integer division is required here:
    # with true division an even szl gives a fractional tl, which is then
    # truncated when building the padding and makes the crop one pixel too big.
    tl = posl - (szl - 1) // 2
    br = posl + szl // 2 + 1

    # Shift the crop to inside
    if mode == 'inside' or mode == 'inside_major':
        im2_sz = torch.LongTensor([im2.shape[2], im2.shape[3]])
        shift = (-tl).clamp(0) - (br - im2_sz).clamp(0)
        tl += shift
        br += shift

        # If the crop still sticks out (it is larger than the image along that
        # dim), split the overhang evenly between both sides.
        outside = ((-tl).clamp(0) + (br - im2_sz).clamp(0)) // 2
        shift = (-tl - outside) * (outside > 0).long()
        tl += shift
        br += shift

    # Get image patch: negative pad amounts crop, positive amounts extend the border
    pad = (-tl[1].int().item(), br[1].int().item() - im2.shape[3],
           -tl[0].int().item(), br[0].int().item() - im2.shape[2])

    if not is_mask:
        im_patch = F.pad(im2, pad, pad_mode)
    else:
        im_patch = F.pad(im2, pad)

    # Get image coordinates
    patch_coord = df * torch.cat((tl, br)).view(1, 4)

    if output_sz is None or (im_patch.shape[-2] == output_sz[0] and im_patch.shape[-1] == output_sz[1]):
        return im_patch.clone(), patch_coord

    # Resample
    if not is_mask:
        im_patch = F.interpolate(im_patch, output_sz.long().tolist(), mode='bilinear')
    else:
        im_patch = F.interpolate(im_patch, output_sz.long().tolist(), mode='nearest')

    return im_patch, patch_coord
 |