import torch


def max2d(a: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
    """Compute the maximum and its (row, col) position over the last two dims.

    Args:
        a: Tensor of shape ``(*batch, H, W)`` with at least two dimensions.
            The batch dimensions may be empty (size 0); ``torch.max`` raises
            if ``H`` or ``W`` is 0.

    Returns:
        A ``(max_val, argmax)`` pair:

        - ``max_val`` has shape ``batch`` and holds the maximum of each
          ``H x W`` slice.
        - ``argmax`` is an ``int64`` tensor of shape ``(*batch, 2)`` whose last
          dimension is the ``(row, col)`` index of that maximum.

    On ties, ``torch.max`` returns the first maximal index along each
    reduction. The smallest column among the maximal columns therefore wins,
    and within that column the smallest row wins.
    """
    # Per column: the max over rows and the row it came from -> (*batch, W).
    col_max, row_of_col_max = torch.max(a, dim=-2)
    # Best column in each slice -> (*batch,).
    max_val, argmax_col = torch.max(col_max, dim=-1)
    # Select the row index belonging to the winning column. ``gather`` stays
    # on a's device and handles empty batch dims; flattening with
    # ``view(numel, -1)`` cannot infer -1 when numel == 0.
    argmax_row = row_of_col_max.gather(-1, argmax_col.unsqueeze(-1)).squeeze(-1)
    argmax = torch.stack((argmax_row, argmax_col), dim=-1)
    return max_val, argmax