You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

32 lines
914 B

import functools
import torch
import copy
class TensorList(list):
    """Container mainly used for lists of torch tensors. Extends lists with pytorch functionality.

    Any attribute that is not defined on the list itself but exists on
    ``torch.Tensor`` is forwarded element-wise via ``__getattr__``:
    methods become functions that call the method on every element, and
    data attributes (e.g. ``shape``, ``dtype``) return the per-element
    values. Results are wrapped in a new TensorList.
    """

    def __iadd__(self, other):
        """Add ``other`` to the elements in place and return ``self``.

        If ``other`` is a list/TensorList, ``other[i]`` is added to
        ``self[i]``. Otherwise ``other`` (e.g. a scalar or a single tensor)
        is added to every element.

        Raises:
            ValueError: if ``other`` is a list whose length differs from
                ``len(self)``. The check happens before any element is
                modified, so ``self`` is never left partially updated.
        """
        if TensorList._iterable(other):
            # Validate up front: the old behaviour mutated some elements and
            # then failed (longer ``other``) or silently skipped the tail
            # (shorter ``other``).
            if len(other) != len(self):
                raise ValueError(
                    'TensorList length mismatch in +=: {} vs {}'.format(len(self), len(other)))
            for i, e2 in enumerate(other):
                self[i] += e2
        else:
            for i in range(len(self)):
                self[i] += other
        return self

    def copy(self):
        """Return a shallow copy as a TensorList (plain ``list.copy`` would return a list)."""
        return TensorList(super().copy())

    def __getattr__(self, name):
        """Forward attribute ``name`` of ``torch.Tensor`` to every element.

        Only called for names not found through normal lookup. For callable
        attributes (methods) a function is returned; calling it with
        ``*args, **kwargs`` calls the method on each element and collects
        the results in a TensorList. For non-callable attributes (properties
        such as ``shape``, ``dtype``, ``requires_grad``) a TensorList of the
        per-element values is returned directly.

        Raises:
            AttributeError: if ``torch.Tensor`` has no attribute ``name``.
        """
        if not hasattr(torch.Tensor, name):
            raise AttributeError('\'TensorList\' object has no attribute \'{}\''.format(name))

        if not callable(getattr(torch.Tensor, name)):
            # Data attribute: returning a calling wrapper here would fail on
            # call and would always be truthy (e.g. ``tl.requires_grad``).
            return TensorList([getattr(e, name) for e in self])

        def apply_attr(*args, **kwargs):
            return TensorList([getattr(e, name)(*args, **kwargs) for e in self])

        return apply_attr

    @staticmethod
    def _iterable(a):
        """Return True if ``a`` should be combined element-wise (a list or TensorList)."""
        return isinstance(a, (TensorList, list))