You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

31 lines
1.2 KiB

from torch import nn
from torchvision.models.resnet import Bottleneck
from ltr.models.layers.normalization import InstanceL2Norm
def residual_bottleneck(feature_dim=256, num_blocks=1, l2norm=True, final_conv=False, norm_scale=1.0, out_dim=None,
                        interp_cat=False, final_relu=False, final_pool=False, input_dim=None, final_stride=1):
    """Construct a network block based on the Bottleneck block used in ResNet.

    The block is a stack of ``num_blocks`` torchvision ``Bottleneck`` layers, optionally followed by a
    3x3 convolution (with optional ReLU and max-pooling) and an ``InstanceL2Norm`` layer.

    Args:
        feature_dim: Bottleneck width (``planes``) of the intermediate blocks; each such block outputs
            ``Bottleneck.expansion * feature_dim`` channels.
        num_blocks: Number of Bottleneck blocks (may be 0).
        l2norm: If True, append ``InstanceL2Norm(scale=norm_scale)`` as the last layer.
        final_conv: If True, append a 3x3 ``nn.Conv2d`` mapping to ``out_dim`` channels after the
            Bottleneck stack. If False, the last Bottleneck block is instead sized to produce ``out_dim``
            channels (rounded down to a multiple of ``Bottleneck.expansion``).
        norm_scale: ``scale`` argument passed to ``InstanceL2Norm``.
        out_dim: Number of output channels. Defaults to ``feature_dim``.
        interp_cat: Not implemented by this function; accepted for signature compatibility and
            ignored (a warning is emitted when it is set).
        final_relu: If True and ``final_conv`` is True, append a ReLU after the final convolution.
        final_pool: If True and ``final_conv`` is True, append a 3x3, stride-2 max-pool.
        input_dim: Number of input channels. Defaults to ``Bottleneck.expansion * feature_dim``.
        final_stride: Stride of the final 3x3 convolution.

    Returns:
        An ``nn.Sequential`` containing the constructed layers.
    """
    if interp_cat:
        # This implementation has no interpolate-and-concatenate layer; make that visible to the caller
        # instead of silently dropping the option.
        import warnings
        warnings.warn('residual_bottleneck: interp_cat is not supported and is ignored.', stacklevel=2)

    expansion = Bottleneck.expansion
    if out_dim is None:
        out_dim = feature_dim
    if input_dim is None:
        input_dim = expansion * feature_dim

    dim = input_dim
    feat_layers = []
    for i in range(num_blocks):
        # All blocks use feature_dim, except the last one when no final conv follows: that block is
        # sized so that it produces (approximately) out_dim channels itself.
        planes = feature_dim if i < num_blocks - 1 + int(final_conv) else out_dim // expansion
        block_out = planes * expansion
        downsample = None
        if dim != block_out:
            # Bottleneck adds its input to its output. Without a projection shortcut a channel mismatch
            # only surfaces as a shape error in forward(), so mirror torchvision's ResNet._make_layer.
            downsample = nn.Sequential(
                nn.Conv2d(dim, block_out, kernel_size=1, bias=False),
                nn.BatchNorm2d(block_out),
            )
        feat_layers.append(Bottleneck(dim, planes, downsample=downsample))
        dim = block_out

    if final_conv:
        feat_layers.append(nn.Conv2d(dim, out_dim, kernel_size=3, padding=1, bias=False, stride=final_stride))
        if final_relu:
            feat_layers.append(nn.ReLU(inplace=True))
        if final_pool:
            feat_layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
    if l2norm:
        feat_layers.append(InstanceL2Norm(scale=norm_scale))
    return nn.Sequential(*feat_layers)