You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
20 lines
670 B
20 lines
670 B
from torch import nn
|
|
|
|
|
|
|
|
|
|
class LinearBlock(nn.Module):
    """Fully connected layer followed by optional batch norm and ReLU.

    Each sample is flattened before the linear layer, so the block maps an
    input with ``in_planes * height * width`` elements per sample (presumably
    a ``(N, in_planes, height, width)`` feature map -- confirm with callers)
    to an output of shape ``(N, out_planes)``.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output features.
        input_sz: Spatial size of the input. An int means a square
            ``input_sz x input_sz`` map (the original behavior); a
            ``(height, width)`` pair allows non-square inputs.
        bias: Whether the linear layer has a bias term.
        batch_norm: Whether to apply batch normalization after the linear layer.
        relu: Whether to apply an in-place ReLU at the end.
    """

    def __init__(self, in_planes, out_planes, input_sz, bias=True, batch_norm=True, relu=True):
        super().__init__()
        # Accept either a scalar (square input) or a (height, width) pair.
        # EAFP keeps numpy/tensor integer scalars working as before.
        try:
            height, width = input_sz
        except TypeError:
            height = width = input_sz
        self.linear = nn.Linear(in_planes * height * width, out_planes, bias=bias)
        # BatchNorm2d is applied to a (N, C, 1, 1) view in forward(); kept as 2d
        # so the module type and checkpoint layout stay unchanged.
        self.bn = nn.BatchNorm2d(out_planes) if batch_norm else None
        self.relu = nn.ReLU(inplace=True) if relu else None

    def forward(self, x):
        """Flatten ``x`` per sample and apply linear -> [bn] -> [relu].

        Returns:
            Tensor of shape ``(N, out_planes)``.
        """
        x = self.linear(x.reshape(x.shape[0], -1))
        if self.bn is not None:
            # Give BatchNorm2d the 4-D (N, C, 1, 1) input it expects.
            x = self.bn(x.reshape(x.shape[0], x.shape[1], 1, 1))
        if self.relu is not None:
            x = self.relu(x)
        # Collapse the singleton spatial dims (if bn ran) back to (N, C).
        return x.reshape(x.shape[0], -1)
|