import torch.nn as nn


class Backbone(nn.Module):
    """Base class for backbone networks. Handles freezing layers etc.

    args:
        frozen_layers - Names of layers to freeze. Either a list of strings, 'none' or 'all'.
                        The two string options are matched case-insensitively and stored in
                        normalized form: 'none' becomes an empty tuple, 'all' becomes 'all'.
                        Default: () (equivalent to 'none').

    raises:
        ValueError - if frozen_layers is a string other than 'none' or 'all'.
    """

    def __init__(self, frozen_layers=()):
        super().__init__()

        if isinstance(frozen_layers, str):
            option = frozen_layers.lower()
            if option == 'none':
                # "Freeze nothing" is represented by an empty tuple.
                frozen_layers = ()
            elif option == 'all':
                # Store the canonical spelling so that later comparisons against 'all'
                # match regardless of the case the caller used (e.g. 'ALL').
                frozen_layers = 'all'
            else:
                raise ValueError(
                    ('Unknown option for frozen layers: "{}". Should be "all", "none" '
                     'or list of layer names.').format(frozen_layers))

        # Either 'all' or a collection of layer names (an empty tuple means none).
        self.frozen_layers = frozen_layers
        # NOTE(review): presumably marks whether gradients have already been disabled for
        # the frozen layers; nothing in this view updates it — confirm against the rest of the class.
        self._is_frozen_nograd = False