from torch import nn


class LinearBlock(nn.Module):
    """Fully connected layer with optional batch normalisation and ReLU.

    The input is flattened per sample, passed through ``nn.Linear`` and then,
    when enabled, through ``nn.BatchNorm2d`` and ``nn.ReLU`` in that order.

    Args:
        in_planes: Channel count used to size the linear layer's input.
        out_planes: Number of output features.
        input_sz: Side length used to size the linear layer's input; the
            layer expects ``in_planes * input_sz * input_sz`` features.
        bias: Whether the linear layer has a bias term.
        batch_norm: Whether to apply batch normalisation after the linear
            layer.
        relu: Whether to apply a ReLU at the end.
    """

    def __init__(self, in_planes, out_planes, input_sz, bias=True,
                 batch_norm=True, relu=True):
        super().__init__()

        flat_features = in_planes * input_sz * input_sz
        self.linear = nn.Linear(flat_features, out_planes, bias=bias)

        # Disabled stages are stored as None so forward() can skip them.
        if batch_norm:
            self.bn = nn.BatchNorm2d(out_planes)
        else:
            self.bn = None

        if relu:
            self.relu = nn.ReLU(inplace=True)
        else:
            self.relu = None

    def forward(self, x):
        """Apply the block and return a ``(batch, out_planes)`` tensor.

        ``x`` is flattened to ``(batch, -1)`` before the linear layer, so its
        trailing dimensions must multiply to the linear layer's input size
        (presumably ``(batch, in_planes, input_sz, input_sz)`` -- confirm
        against callers).
        """
        batch = x.shape[0]
        out = self.linear(x.reshape(batch, -1))

        if self.bn is not None:
            # BatchNorm2d needs a 4-D input, so add two singleton spatial dims.
            out = self.bn(out.reshape(batch, out.shape[1], 1, 1))

        if self.relu is not None:
            out = self.relu(out)

        # Drop the singleton dims again (no-op if batch norm was skipped).
        return out.reshape(batch, -1)