from torch import nn
from torchvision.models.resnet import Bottleneck
from ltr.models.layers.normalization import InstanceL2Norm


def _bottleneck(inplanes, planes):
    """Build a stride-1 ResNet Bottleneck, adding a projection shortcut only when needed.

    torchvision's ``Bottleneck`` adds its input directly onto its output unless a
    ``downsample`` module is supplied. That add only works when ``inplanes`` equals
    ``planes * Bottleneck.expansion``. When the widths differ, a 1x1 conv +
    BatchNorm projection (the shortcut torchvision's ResNet uses) maps the input to
    the output width. When they match, no extra module is created.
    """
    out_planes = planes * Bottleneck.expansion
    downsample = None
    if inplanes != out_planes:
        downsample = nn.Sequential(
            nn.Conv2d(inplanes, out_planes, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_planes),
        )
    return Bottleneck(inplanes, planes, downsample=downsample)


def residual_bottleneck(feature_dim=256, num_blocks=1, l2norm=True, final_conv=False, norm_scale=1.0,
                        out_dim=None, interp_cat=False, final_relu=False, final_pool=False, input_dim=None,
                        final_stride=1):
    """Construct a network block based on the Bottleneck block used in ResNet.

    Layers, in order: ``num_blocks`` Bottleneck blocks, then (each optional) a 3x3
    conv, an in-place ReLU, a 3x3/stride-2 max-pool and an ``InstanceL2Norm``.

    Args:
        feature_dim: Bottleneck ``planes`` for every block except, when
            ``final_conv`` is False, the last one. Such blocks output
            ``Bottleneck.expansion * feature_dim`` channels.
        num_blocks: Number of Bottleneck blocks.
        l2norm: If True, append ``InstanceL2Norm(scale=norm_scale)``.
        final_conv: If True, append a bias-free 3x3 conv (padding 1) producing
            ``out_dim`` channels. If False, the last Bottleneck produces the output
            itself with ``planes = out_dim // 4``. That gives
            ``4 * (out_dim // 4)`` channels, so ``out_dim`` should be a multiple of 4.
        norm_scale: ``scale`` argument passed to ``InstanceL2Norm``.
        out_dim: Output channels; defaults to ``feature_dim``.
        interp_cat: Accepted for signature compatibility; currently ignored.
        final_relu: If True, append ``nn.ReLU(inplace=True)``.
        final_pool: If True, append ``nn.MaxPool2d(kernel_size=3, stride=2, padding=1)``.
        input_dim: Input channels; defaults to ``4 * feature_dim``.
        final_stride: Stride of the final conv (only used when ``final_conv``).

    Returns:
        nn.Sequential containing the layers listed above.
    """
    if out_dim is None:
        out_dim = feature_dim
    if input_dim is None:
        input_dim = 4 * feature_dim

    # NOTE(review): ``interp_cat`` is not used by this function.
    dim = input_dim
    feat_layers = []
    for i in range(num_blocks):
        is_last = i == num_blocks - 1
        # Without a final conv, the last block must produce the output width itself.
        planes = out_dim // 4 if (is_last and not final_conv) else feature_dim
        # _bottleneck adds a projection shortcut when ``dim`` differs from the block's
        # output width; the bare Bottleneck would fail on the residual add.
        feat_layers.append(_bottleneck(dim, planes))
        # Track the actual output width of the block just appended.
        dim = planes * Bottleneck.expansion

    if final_conv:
        feat_layers.append(nn.Conv2d(dim, out_dim, kernel_size=3, padding=1, bias=False, stride=final_stride))
    if final_relu:
        feat_layers.append(nn.ReLU(inplace=True))
    if final_pool:
        feat_layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
    if l2norm:
        feat_layers.append(InstanceL2Norm(scale=norm_scale))
    return nn.Sequential(*feat_layers)