You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

21 lines
731 B

import torch.nn as nn
class Backbone(nn.Module):
    """Base class for backbone networks. Handles freezing layers etc.

    args:
        frozen_layers - Names of layers to freeze. Either a sequence of layer-name strings,
                        or one of the strings 'none' / 'all' (matched case-insensitively).
                        None is treated the same as 'none'. Default: () (nothing frozen).

    raises:
        ValueError - if frozen_layers is a string other than 'none' or 'all'.
    """

    def __init__(self, frozen_layers=()):
        super().__init__()

        if frozen_layers is None:
            # Treat None like 'none' rather than storing a non-iterable value.
            frozen_layers = ()
        elif isinstance(frozen_layers, str):
            option = frozen_layers.lower()
            if option == 'none':
                frozen_layers = ()
            elif option == 'all':
                # Store the canonical spelling: the check above is case-insensitive, so
                # inputs like 'All' must not leak through and fail later '== "all"' tests.
                frozen_layers = 'all'
            else:
                raise ValueError(
                    'Unknown option for frozen layers: "{}". Should be "all", "none" or list of layer names.'.format(
                        frozen_layers))

        # Either the string 'all' or the caller-supplied collection of layer names
        # (an empty tuple means no layers are frozen).
        self.frozen_layers = frozen_layers
        # Initialised to False; not read within this class body. Presumably marks whether
        # frozen layers have had gradients disabled — NOTE(review): confirm against subclasses.
        self._is_frozen_nograd = False