You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
31 lines
1.2 KiB
31 lines
1.2 KiB
from torch import nn
|
|
from torchvision.models.resnet import Bottleneck
|
|
from ltr.models.layers.normalization import InstanceL2Norm
|
|
|
|
|
|
|
|
|
|
def residual_bottleneck(feature_dim=256, num_blocks=1, l2norm=True, final_conv=False, norm_scale=1.0, out_dim=None,
                        interp_cat=False, final_relu=False, final_pool=False, input_dim=None, final_stride=1):
    """Construct a network block based on the Bottleneck block used in ResNet.

    The result is a stack of ``num_blocks`` torchvision ``Bottleneck`` modules,
    optionally followed by a 3x3 output convolution (with optional ReLU and
    max-pooling on top of it) and an ``InstanceL2Norm`` layer.

    Args:
        feature_dim: Bottleneck width (``planes``); each such block outputs
            ``4 * feature_dim`` channels.
        num_blocks: Number of Bottleneck modules to stack (0 is allowed).
        l2norm: If True, append ``InstanceL2Norm(scale=norm_scale)`` at the end.
        final_conv: If True, append a bias-free 3x3 conv mapping to ``out_dim``
            channels. If False, the last Bottleneck is narrowed to
            ``out_dim // 4`` planes, so the stack outputs ``4 * (out_dim // 4)``
            channels (equal to ``out_dim`` only when it is divisible by 4).
        norm_scale: Scale passed to ``InstanceL2Norm``.
        out_dim: Number of output channels; defaults to ``feature_dim``.
        interp_cat: Accepted for signature compatibility; ignored by this builder.
        final_relu: Append a ReLU after the final conv (only when ``final_conv``).
        final_pool: Append a 3x3, stride-2 max-pool after the final conv
            (only when ``final_conv``).
        input_dim: Channels of the incoming feature map; defaults to
            ``4 * feature_dim``.
        final_stride: Stride of the final 3x3 conv.

    Returns:
        nn.Sequential containing the constructed layers.
    """
    expansion = 4  # channel expansion factor of torchvision's Bottleneck (Bottleneck.expansion)

    if out_dim is None:
        out_dim = feature_dim
    if input_dim is None:
        input_dim = expansion * feature_dim

    dim = input_dim
    feat_layers = []
    for i in range(num_blocks):
        # With a final conv every block keeps the feature_dim width; without one,
        # the last block is narrowed so the stack itself produces the output width.
        narrow_last = (i == num_blocks - 1) and not final_conv
        planes = out_dim // expansion if narrow_last else feature_dim
        block_out = planes * expansion

        # Bottleneck adds its input to its output. When the channel counts differ,
        # the identity path needs a projection or that addition fails at forward
        # time. The projection mirrors torchvision's ResNet downsample (1x1 conv + BN).
        downsample = None
        if dim != block_out:
            downsample = nn.Sequential(
                nn.Conv2d(dim, block_out, kernel_size=1, bias=False),
                nn.BatchNorm2d(block_out),
            )
        feat_layers.append(Bottleneck(dim, planes, downsample=downsample))
        dim = block_out

    if final_conv:
        feat_layers.append(nn.Conv2d(dim, out_dim, kernel_size=3, padding=1, bias=False, stride=final_stride))
        # Activation and pooling options only apply on top of the final conv.
        if final_relu:
            feat_layers.append(nn.ReLU(inplace=True))
        if final_pool:
            feat_layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

    if l2norm:
        feat_layers.append(InstanceL2Norm(scale=norm_scale))

    return nn.Sequential(*feat_layers)
|