import copy
import functools

import torch


class TensorList(list):
    """Container mainly used for lists of torch tensors.

    Extends ``list`` with pytorch functionality: in-place addition is applied
    elementwise, and any attribute that exists on ``torch.Tensor`` is forwarded
    to every element, with the results collected into a new ``TensorList``.
    """

    def __iadd__(self, other):
        """Add ``other`` to the elements of this list in place.

        If ``other`` is a ``list``/``TensorList``, element ``i`` of ``other`` is
        added to element ``i`` of ``self``. Elements of ``self`` beyond
        ``len(other)`` are left unchanged, and an ``other`` longer than
        ``self`` raises ``IndexError``. Otherwise ``other`` (presumably a
        scalar or broadcastable tensor) is added to every element.

        Returns:
            ``self``, as required by the augmented-assignment protocol.
        """
        if TensorList._iterable(other):
            for i, e2 in enumerate(other):
                self[i] += e2
        else:
            for i in range(len(self)):
                # Index assignment (not iterating over elements) so the
                # result of ``+=`` is stored back into the list slot.
                self[i] += other
        return self

    def copy(self):
        """Return a shallow copy as a ``TensorList``; the elements are shared, not cloned."""
        return TensorList(super().copy())

    def __getattr__(self, name):
        """Forward attribute ``name`` from ``torch.Tensor`` to every element.

        This runs only when normal attribute lookup on the list fails. The
        returned function calls ``getattr(e, name)(*args, **kwargs)`` on each
        element ``e`` and wraps the results in a new ``TensorList``. The
        attribute is therefore always treated as a method. Non-callable tensor
        attributes will fail when the returned function is invoked.

        Raises:
            AttributeError: if ``torch.Tensor`` has no attribute ``name``.
        """
        if not hasattr(torch.Tensor, name):
            raise AttributeError(f"'TensorList' object has no attribute '{name}'")

        def apply_attr(*args, **kwargs):
            return TensorList([getattr(e, name)(*args, **kwargs) for e in self])

        return apply_attr

    @staticmethod
    def _iterable(a):
        """Return True if ``a`` is a ``list`` (including ``TensorList``) to be paired elementwise."""
        return isinstance(a, (TensorList, list))