import torch.nn as nn
import torch
from ltr.external.PreciseRoIPooling.pytorch.prroi_pool import PrRoIPool2D
import math


class FilterPool(nn.Module):
    """Pool the target region in a feature map.
    args:
        filter_size:  Size of the filter.
        feature_stride:  Input feature stride.
        pool_square:  Pool a square region instead of the exact target region."""

    def __init__(self, filter_size=1, feature_stride=16, pool_square=False):
        super().__init__()
        self.prroi_pool = PrRoIPool2D(filter_size, filter_size, 1/feature_stride)
        self.pool_square = pool_square

    def forward(self, feat, bb):
        """Pool the regions in bb.
        args:
            feat:  Input feature maps. Dims (num_samples, feat_dim, H, W).
            bb:  Target bounding boxes (x, y, w, h) in image coordinates. Dims (num_samples, 4).
        returns:
            pooled_feat:  Pooled features. Dims (num_samples, feat_dim, wH, wW)."""

        # Add a batch index to each roi, as expected by PrRoIPool2D
        bb = bb.reshape(-1, 4)
        num_images_total = bb.shape[0]
        batch_index = torch.arange(num_images_total, dtype=torch.float32).reshape(-1, 1).to(bb.device)

        pool_bb = bb.clone()

        if self.pool_square:
            # Replace the box with a square of equal area, centered on the original box
            bb_sz = pool_bb[:, 2:4].prod(dim=1, keepdim=True).sqrt()
            pool_bb[:, :2] += pool_bb[:, 2:]/2 - bb_sz/2
            pool_bb[:, 2:] = bb_sz

        # Input bb is in xywh format; convert it to x0y0x1y1 format
        pool_bb[:, 2:4] = pool_bb[:, 0:2] + pool_bb[:, 2:4]
        roi1 = torch.cat((batch_index, pool_bb), dim=1)

        return self.prroi_pool(feat, roi1)
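

# A minimal usage sketch for FilterPool (illustrative only; the shapes and boxes
# below are assumptions, not part of the module). With filter_size=4 and
# feature_stride=16, each image-space box is pooled to a 4x4 filter. Note that
# the PrRoIPool2D op requires CUDA tensors:
#
#   pool = FilterPool(filter_size=4, feature_stride=16)
#   feat = torch.rand(2, 256, 18, 18).cuda()             # (num_samples, feat_dim, H, W)
#   bb = torch.tensor([[48., 48., 64., 64.],
#                      [32., 16., 96., 48.]]).cuda()     # (x, y, w, h) per sample
#   pooled = pool(feat, bb)                              # -> (2, 256, 4, 4)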


class FilterInitializerLinear(nn.Module):
    """Initializes a target classification filter by applying a linear conv layer and then pooling the target region.
    args:
        filter_size:  Size of the filter.
        feature_dim:  Input feature dimensionality.
        feature_stride:  Input feature stride.
        pool_square:  Pool a square region instead of the exact target region.
        filter_norm:  Normalize the output filter by its size.
        conv_ksz:  Kernel size of the conv layer before pooling.
        init_weights:  Weight initialization of the conv layer: 'default' (He initialization) or 'zero'."""

    def __init__(self, filter_size=1, feature_dim=256, feature_stride=16, pool_square=False, filter_norm=True,
                 conv_ksz=3, init_weights='default'):
        super().__init__()

        self.filter_conv = nn.Conv2d(feature_dim, feature_dim, kernel_size=conv_ksz, padding=conv_ksz // 2)
        self.filter_pool = FilterPool(filter_size=filter_size, feature_stride=feature_stride, pool_square=pool_square)
        self.filter_norm = filter_norm

        # Init weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if init_weights == 'default':
                    # He initialization (zero-mean normal, fan-out)
                    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                    m.weight.data.normal_(0, math.sqrt(2. / n))
                elif init_weights == 'zero':
                    m.weight.data.zero_()
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, feat, bb):
        """Runs the initializer module.
        Note that [] denotes an optional dimension.
        args:
            feat:  Input feature maps. Dims (images_in_sequence, [sequences], feat_dim, H, W).
            bb:  Target bounding boxes (x, y, w, h) in image coordinates. Dims (images_in_sequence, [sequences], 4).
        returns:
            weights:  The output weights. Dims (sequences, feat_dim, wH, wW)."""

        num_images = feat.shape[0]

        feat = self.filter_conv(feat.reshape(-1, feat.shape[-3], feat.shape[-2], feat.shape[-1]))

        weights = self.filter_pool(feat, bb)

        # If multiple input images are given, compute the initial filter as the average filter.
        if num_images > 1:
            weights = torch.mean(weights.reshape(num_images, -1, weights.shape[-3], weights.shape[-2], weights.shape[-1]), dim=0)

        if self.filter_norm:
            weights = weights / (weights.shape[1] * weights.shape[2] * weights.shape[3])

        return weights
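

# A quick smoke test (a sketch under assumptions: random features, arbitrary
# boxes, and a CUDA device, which the PreciseRoIPooling extension requires).
if __name__ == '__main__':
    if torch.cuda.is_available():
        initializer = FilterInitializerLinear(filter_size=4, feature_dim=256, feature_stride=16).cuda()
        # Two images per sequence, three sequences, stride-16 features of 288x288 images.
        feat = torch.rand(2, 3, 256, 18, 18).cuda()
        bb = torch.tensor([48., 48., 64., 64.]).repeat(2, 3, 1).cuda()   # (x, y, w, h)
        weights = initializer(feat, bb)
        print(weights.shape)    # expected: torch.Size([3, 256, 4, 4])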