import torch.nn as nn
import torch
from ltr.external.PreciseRoIPooling.pytorch.prroi_pool import PrRoIPool2D
import math


class FilterPool(nn.Module):
    """Pool the target region in a feature map.
    args:
        filter_size:  Size of the filter.
        feature_stride:  Input feature stride.
        pool_square:  Do a square pooling instead of pooling the exact target region."""

    def __init__(self, filter_size=1, feature_stride=16, pool_square=False):
        super().__init__()
        self.prroi_pool = PrRoIPool2D(filter_size, filter_size, 1/feature_stride)
        self.pool_square = pool_square

    def forward(self, feat, bb):
        """Pool the regions in bb.
        args:
            feat:  Input feature maps. Dims (num_samples, feat_dim, H, W).
            bb:  Target bounding boxes (x, y, w, h) in the image coords. Dims (num_samples, 4).
        returns:
            pooled_feat:  Pooled features. Dims (num_samples, feat_dim, wH, wW)."""

        # Add batch_index to rois
        bb = bb.reshape(-1, 4)
        num_images_total = bb.shape[0]
        batch_index = torch.arange(num_images_total, dtype=torch.float32).reshape(-1, 1).to(bb.device)

        pool_bb = bb.clone()

        # Optionally replace the target box with a square box of equal area, centered at the same point.
        if self.pool_square:
            bb_sz = pool_bb[:, 2:4].prod(dim=1, keepdim=True).sqrt()
            pool_bb[:, :2] += pool_bb[:, 2:]/2 - bb_sz/2
            pool_bb[:, 2:] = bb_sz

        # The input bb is in (x, y, w, h) format; convert it to the (x0, y0, x1, y1) format expected by PrRoIPool2D.
        pool_bb[:, 2:4] = pool_bb[:, 0:2] + pool_bb[:, 2:4]
        roi1 = torch.cat((batch_index, pool_bb), dim=1)

        return self.prroi_pool(feat, roi1)
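
# A hypothetical usage sketch of FilterPool (shapes and box values are made up
# for illustration). The reference PrRoIPool2D implementation only provides
# CUDA kernels, so the inputs are assumed to live on a GPU:
#
#     pool = FilterPool(filter_size=4, feature_stride=16)
#     feat = torch.rand(2, 256, 18, 18).cuda()            # (num_samples, feat_dim, H, W)
#     bb = torch.tensor([[64., 64., 128., 128.],
#                        [32., 48.,  96.,  64.]]).cuda()  # (x, y, w, h) per sample
#     pooled = pool(feat, bb)                             # -> (2, 256, 4, 4)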


class FilterInitializerLinear(nn.Module):
    """Initializes a target classification filter by applying a linear conv layer and then pooling the target region.
    args:
        filter_size:  Size of the filter.
        feature_dim:  Input feature dimensionality.
        feature_stride:  Input feature stride.
        pool_square:  Do a square pooling instead of pooling the exact target region.
        filter_norm:  Normalize the output filter with its size in the end.
        conv_ksz:  Kernel size of the conv layer before pooling.
        init_weights:  Weight initialization for the conv layer: 'default' (zero-mean Gaussian) or 'zero'."""

    def __init__(self, filter_size=1, feature_dim=256, feature_stride=16, pool_square=False, filter_norm=True,
                 conv_ksz=3, init_weights='default'):
        super().__init__()
        self.filter_conv = nn.Conv2d(feature_dim, feature_dim, kernel_size=conv_ksz, padding=conv_ksz // 2)
        self.filter_pool = FilterPool(filter_size=filter_size, feature_stride=feature_stride, pool_square=pool_square)
        self.filter_norm = filter_norm

        # Init weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if init_weights == 'default':
                    # He-style initialization: zero-mean Gaussian with variance 2 / fan_out.
                    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                    m.weight.data.normal_(0, math.sqrt(2. / n))
                elif init_weights == 'zero':
                    m.weight.data.zero_()
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, feat, bb):
        """Runs the initializer module.
        Note that [] denotes an optional dimension.
        args:
            feat:  Input feature maps. Dims (images_in_sequence, [sequences], feat_dim, H, W).
            bb:  Target bounding boxes (x, y, w, h) in the image coords. Dims (images_in_sequence, [sequences], 4).
        returns:
            weights:  The output weights. Dims (sequences, feat_dim, wH, wW)."""

        num_images = feat.shape[0]

        feat = self.filter_conv(feat.reshape(-1, feat.shape[-3], feat.shape[-2], feat.shape[-1]))

        weights = self.filter_pool(feat, bb)

        # If multiple input images, compute the initial filter as the average filter.
        if num_images > 1:
            weights = torch.mean(weights.reshape(num_images, -1, weights.shape[-3], weights.shape[-2], weights.shape[-1]), dim=0)

        if self.filter_norm:
            weights = weights / (weights.shape[1] * weights.shape[2] * weights.shape[3])

        return weights
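

if __name__ == '__main__':
    # Minimal smoke test, assuming a CUDA device is available (the reference
    # PrRoIPool2D implementation only ships CUDA kernels). All shapes and box
    # values below are hypothetical, chosen only for illustration.
    if torch.cuda.is_available():
        initializer = FilterInitializerLinear(filter_size=4, feature_dim=256,
                                              feature_stride=16).cuda()
        # 3 images per sequence, 2 sequences, 18x18 feature maps (stride 16).
        feat = torch.rand(3, 2, 256, 18, 18).cuda()
        # One (x, y, w, h) box per image and sequence, in image coordinates.
        bb = torch.tensor([[32., 32., 96., 96.]]).repeat(3, 2, 1).cuda()
        weights = initializer(feat, bb)
        print(weights.shape)  # expected: torch.Size([2, 256, 4, 4])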