You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
21 lines
731 B
21 lines
731 B
import torch.nn as nn
|
|
|
|
|
|
class Backbone(nn.Module):
    """Base class for backbone networks. Handles freezing layers etc.

    args:
        frozen_layers - Name of layers to freeze. Either list of strings, 'none' or 'all'. Default: 'none'.
    """

    def __init__(self, frozen_layers=()):
        """Initialize the backbone and record which layers should be frozen.

        args:
            frozen_layers - Layers to freeze: a sequence of layer names, or the
                case-insensitive strings 'none' (freeze nothing) or 'all'
                (freeze every layer). Default: () (freeze nothing).

        raises:
            ValueError - If frozen_layers is a string other than 'none'/'all'.
        """
        super().__init__()

        if isinstance(frozen_layers, str):
            if frozen_layers.lower() == 'none':
                # 'none' means no frozen layers: store the empty tuple so
                # downstream code can treat frozen_layers uniformly as a
                # (possibly empty) collection of names.
                frozen_layers = ()
            elif frozen_layers.lower() == 'all':
                # Normalize the accepted string to lowercase. The original
                # code accepted 'All'/'ALL' here but stored the raw-cased
                # string, which would break later equality checks against
                # the canonical 'all'.
                frozen_layers = 'all'
            else:
                raise ValueError('Unknown option for frozen layers: \"{}\". Should be \"all\", \"none\" or list of layer names.'.format(frozen_layers))

        self.frozen_layers = frozen_layers
        # Tracks whether requires_grad has already been disabled on the
        # frozen layers, so the (presumed) freezing pass elsewhere in this
        # class only needs to run once.
        self._is_frozen_nograd = False
|
|
|