You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
13 lines
552 B
13 lines
552 B
from typing import Tuple

import torch
|
|
|
|
|
|
def max2d(a: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the maximum and its (row, col) position over the last two dims.

    Args:
        a: Tensor of shape ``(*batch, H, W)`` with at least two dimensions.

    Returns:
        A pair ``(max_val, argmax)``:

        - ``max_val`` has shape ``(*batch,)`` and holds the maximum of each
          ``H x W`` slice.
        - ``argmax`` has shape ``(*batch, 2)`` and holds the ``(row, col)``
          index of that maximum inside its slice, as the index tensors
          returned by ``torch.max`` (on ``a``'s device).

    Errors raised by ``torch.max`` propagate unchanged, e.g. when ``a`` has
    fewer than two dimensions or when ``H`` or ``W`` is zero.
    """
    # Reduce over rows first: for every column, the best value and its row.
    col_max, row_of_col_max = torch.max(a, dim=-2)  # both (*batch, W)
    # Then reduce over columns: the overall best value and the column it is in.
    max_val, argmax_col = torch.max(col_max, dim=-1)  # both (*batch,)
    # Pick the row index belonging to the winning column. gather handles any
    # batch shape (including empty batches and plain 2-D input) and keeps all
    # indices on a's device, unlike advanced indexing with a CPU arange.
    argmax_row = row_of_col_max.gather(-1, argmax_col.unsqueeze(-1)).squeeze(-1)
    argmax = torch.stack((argmax_row, argmax_col), dim=-1)  # (*batch, 2)
    return max_val, argmax
|