You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
32 lines
914 B
32 lines
914 B
import functools
|
|
import torch
|
|
import copy
|
|
|
|
|
|
class TensorList(list):
    """Container mainly used for lists of torch tensors. Extends lists with pytorch functionality.

    Any attribute that is not found on the list itself but exists on
    ``torch.Tensor`` is forwarded to every element: ``tl.sum(dim=0)`` returns a
    new ``TensorList`` holding ``e.sum(dim=0)`` for each element ``e``.
    """

    def __iadd__(self, other):
        """Add ``other`` to the elements in place and return ``self``.

        If ``other`` is a ``list``/``TensorList``, ``other[i]`` is added to
        ``self[i]`` pairwise. When ``other`` is shorter, only the first
        ``len(other)`` elements are updated. When it is longer, an
        ``IndexError`` is raised after the overlapping elements have already
        been updated. Any other value is added to every element.
        """
        if TensorList._iterable(other):
            for i, e2 in enumerate(other):
                self[i] += e2
        else:
            # Index-based loop: ``self[i] += ...`` rebinds the list slot, so
            # immutable elements (e.g. plain numbers) are updated as well.
            for i in range(len(self)):
                self[i] += other
        return self

    def copy(self):
        """Return a shallow copy as a ``TensorList`` (elements are shared, not cloned)."""
        return TensorList(super().copy())

    def __deepcopy__(self, memo=None):
        """Return a deep copy as a ``TensorList``.

        Defined explicitly because otherwise ``copy.deepcopy`` would find
        ``__deepcopy__`` through ``__getattr__`` (``torch.Tensor`` defines it)
        and call it on every element. That fails for elements that are not
        tensors, e.g. ``copy.deepcopy(TensorList([1]))``.
        """
        if memo is None:
            memo = {}
        result = TensorList()
        # Register before recursing so that self-references resolve to the copy.
        memo[id(self)] = result
        result.extend(copy.deepcopy(e, memo) for e in self)
        return result

    def __getattr__(self, name):
        """Forward ``name`` to every element if ``torch.Tensor`` has such an attribute.

        Python only calls this when normal attribute lookup fails. The returned
        function calls ``getattr(e, name)(*args, **kwargs)`` for each element
        and collects the results in a new ``TensorList``. Because the forwarded
        attribute is always *called*, non-callable tensor attributes cannot be
        used this way.

        Raises:
            AttributeError: if ``torch.Tensor`` has no attribute ``name``.
        """
        if not hasattr(torch.Tensor, name):
            raise AttributeError("'TensorList' object has no attribute '{}'".format(name))

        def apply_attr(*args, **kwargs):
            return TensorList([getattr(e, name)(*args, **kwargs) for e in self])

        return apply_attr

    @staticmethod
    def _iterable(a):
        """Return True if ``a`` is combined element-wise, i.e. it is a ``list``.

        ``TensorList`` is a ``list`` subclass, so it is covered as well.
        """
        return isinstance(a, list)
|
|
|