You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

20 lines
670 B

from torch import nn
class LinearBlock(nn.Module):
    """Fully-connected block: flatten -> Linear -> optional BatchNorm -> optional ReLU.

    The input (any shape with leading batch dim) is flattened per sample and fed
    to a Linear layer sized for ``in_planes * input_sz * input_sz`` features.
    Batch norm, when enabled, is applied through BatchNorm2d on a (N, C, 1, 1)
    view of the linear output. The result is always returned as a 2-D tensor
    of shape (N, out_planes).

    Args:
        in_planes: number of input channels.
        out_planes: number of output features.
        input_sz: spatial side length of the (square) input feature map.
        bias: whether the Linear layer has a bias term.
        batch_norm: whether to apply BatchNorm2d after the Linear layer.
        relu: whether to apply an in-place ReLU at the end.
    """

    def __init__(self, in_planes, out_planes, input_sz, bias=True, batch_norm=True, relu=True):
        super().__init__()
        num_features = in_planes * input_sz * input_sz
        self.linear = nn.Linear(num_features, out_planes, bias=bias)
        # Optional stages are stored as None when disabled; forward() checks for that.
        self.bn = nn.BatchNorm2d(out_planes) if batch_norm else None
        self.relu = nn.ReLU(inplace=True) if relu else None

    def forward(self, x):
        """Run the block; returns a (N, out_planes) tensor."""
        batch = x.shape[0]
        out = self.linear(x.reshape(batch, -1))
        if self.bn is not None:
            # BatchNorm2d expects 4-D input, so view the features as (N, C, 1, 1).
            out = self.bn(out.reshape(batch, out.shape[1], 1, 1))
        if self.relu is not None:
            out = self.relu(out)
        return out.reshape(batch, -1)