You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

13 lines
552 B

from typing import Tuple

import torch
def max2d(a: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the maximum and its (row, column) position over the last two dims.

    Args:
        a: Tensor with at least two dimensions, shaped ``(*batch, H, W)``.
            Leading dimensions are treated as independent batch entries.

    Returns:
        A ``(max_val, argmax)`` pair:

        * ``max_val`` has shape ``batch`` and holds the maximum of each
          trailing ``H x W`` slice.
        * ``argmax`` has shape ``(*batch, 2)`` and holds indices from
          ``torch.max``: ``argmax[..., 0]`` is the row index and
          ``argmax[..., 1]`` is the column index of that maximum.

        Ties are resolved by whichever index ``torch.max`` reports. The column
        is picked first, from the per-column maxima. The row is then the one
        ``torch.max`` reported for that column.

    Raises:
        Errors from ``torch.max`` propagate unchanged. For example, an
        ``IndexError`` is raised when ``a`` has fewer than two dimensions.
    """
    # Reduce over rows: for every column, its largest value and the row it
    # sits in. Both results have shape (*batch, W).
    col_max, row_of_col_max = torch.max(a, dim=-2)
    # Reduce over columns to find which column holds the overall maximum.
    # Both results have shape (*batch,).
    max_val, argmax_col = torch.max(col_max, dim=-1)
    # Pick the row index belonging to the winning column of each batch entry.
    # gather needs no flatten/reshape, so it also handles empty batches and
    # the 0-d case of a plain 2-D input. Its result stays on a's device.
    argmax_row = torch.gather(
        row_of_col_max, -1, argmax_col.unsqueeze(-1)
    ).squeeze(-1)
    argmax = torch.stack((argmax_row, argmax_col), dim=-1)
    return max_val, argmax