# NOTE(review): The following lines are page chrome captured when this file was
# scraped from a web code viewer; they are not part of the module source.
# Commented out so the module remains importable:
#   You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
#   171 lines
#   8.6 KiB
import torch.nn as nn
import torch
import ltr.models.layers.filter as filter_layer
import ltr.models.layers.activation as activation
from ltr.models.layers.distance import DistanceMap
import math
class DiMPSteepestDescentGN(nn.Module):
    """Optimizer module for DiMP.

    It unrolls the steepest descent with Gauss-Newton iterations to optimize the
    target filter. Moreover it learns parameters in the loss itself (label map,
    target mask and spatial weight), as described in the DiMP paper.

    args:
        num_iter: Number of default optimization iterations.
        feat_stride: The stride of the input feature.
        init_step_length: Initial scaling of the step length (which is then learned).
        init_filter_reg: Initial filter regularization weight (which is then learned).
        init_gauss_sigma: The standard deviation to use for the initialization of the label function.
        num_dist_bins: Number of distance bins used for learning the loss label, mask and weight.
        bin_displacement: The displacement of the bins (level of discretization).
        mask_init_factor: Parameter controlling the initialization of the target mask.
        score_act: Type of score activation (target mask computation) to use. The default 'relu' is what is described in the paper.
        act_param: Parameter for the score_act.
        min_filter_reg: Enforce a minimum value on the regularization (helps stability sometimes).
        mask_act: What activation to do on the output of the mask computation ('sigmoid' or 'linear').
        detach_length: Detach the filter every n-th iteration. Default is to never detach, i.e. 'Inf'.
        alpha_eps: Term in the denominator of the steepest descent that stabilizes learning.
    """

    def __init__(self, num_iter=1, feat_stride=16, init_step_length=1.0,
                 init_filter_reg=1e-2, init_gauss_sigma=1.0, num_dist_bins=5, bin_displacement=1.0, mask_init_factor=4.0,
                 score_act='relu', act_param=None, min_filter_reg=1e-3, mask_act='sigmoid',
                 detach_length=float('Inf'), alpha_eps=0):
        super().__init__()
        self.num_iter = num_iter
        self.feat_stride = feat_stride
        # Step length is learned in log space so the effective factor exp(.) stays positive.
        self.log_step_length = nn.Parameter(math.log(init_step_length) * torch.ones(1))
        self.filter_reg = nn.Parameter(init_filter_reg * torch.ones(1))
        self.distance_map = DistanceMap(num_dist_bins, bin_displacement)
        self.min_filter_reg = min_filter_reg
        self.detach_length = detach_length
        self.alpha_eps = alpha_eps

        # Distance coordinates: one value per bin, shaped (1, num_dist_bins, 1, 1)
        # so they can directly initialize the 1x1 conv weights below.
        d = torch.arange(num_dist_bins, dtype=torch.float32).reshape(1,-1,1,1) * bin_displacement
        if init_gauss_sigma == 0:
            # Degenerate sigma: a discrete delta at distance zero.
            init_gauss = torch.zeros_like(d)
            init_gauss[0,0,0,0] = 1
        else:
            init_gauss = torch.exp(-1/2 * (d / init_gauss_sigma)**2)

        # Module that predicts the target label function (y in the paper)
        self.label_map_predictor = nn.Conv2d(num_dist_bins, 1, kernel_size=1, bias=False)
        # Shift so the smallest label value is exactly zero.
        self.label_map_predictor.weight.data = init_gauss - init_gauss.min()

        # Module that predicts the target mask (m in the paper)
        mask_layers = [nn.Conv2d(num_dist_bins, 1, kernel_size=1, bias=False)]
        if mask_act == 'sigmoid':
            mask_layers.append(nn.Sigmoid())
            init_bias = 0.0
        elif mask_act == 'linear':
            init_bias = 0.5
        else:
            raise ValueError('Unknown activation')
        self.target_mask_predictor = nn.Sequential(*mask_layers)
        # tanh(2 - d) is ~ +1 for small distances (near the target) and ~ -1 far away,
        # giving a soft initial transition controlled by mask_init_factor.
        self.target_mask_predictor[0].weight.data = mask_init_factor * torch.tanh(2.0 - d) + init_bias

        # Module that predicts the residual weights (v in the paper)
        self.spatial_weight_predictor = nn.Conv2d(num_dist_bins, 1, kernel_size=1, bias=False)
        self.spatial_weight_predictor.weight.data.fill_(1.0)

        # The score activation and its derivative (the derivative supplies the
        # diagonal Jacobian factor used in the Gauss-Newton steps of forward()).
        if score_act == 'bentpar':
            self.score_activation = activation.BentIdentPar(act_param)
            self.score_activation_deriv = activation.BentIdentParDeriv(act_param)
        elif score_act == 'relu':
            self.score_activation = activation.LeakyReluPar()
            self.score_activation_deriv = activation.LeakyReluParDeriv()
        else:
            raise ValueError('Unknown score activation')

    def forward(self, weights, feat, bb, sample_weight=None, num_iter=None, compute_losses=True):
        """Runs the optimizer module.

        Note that [] denotes an optional dimension.

        args:
            weights: Initial weights. Dims (sequences, feat_dim, wH, wW).
            bb: Target bounding boxes (x, y, w, h) in the image coords. Dims (images_in_sequence, [sequences], 4).
            feat: Input feature maps. Dims (images_in_sequence, [sequences], feat_dim, H, W).
            sample_weight: Optional weight for each sample. Dims: (images_in_sequence, [sequences]).
            num_iter: Number of iterations to run.
            compute_losses: Whether to compute the (train) loss in each iteration.
        returns:
            weights: The final optimized weights.
            weight_iterates: The weights computed in each iteration (including initial input and final output).
            losses: Train losses.
        """
        # Sizes
        num_iter = self.num_iter if num_iter is None else num_iter
        num_images = feat.shape[0]
        num_sequences = feat.shape[1] if feat.dim() == 5 else 1
        filter_sz = (weights.shape[-2], weights.shape[-1])
        # (k + 1) % 2 adds one row/col for even filter sizes, keeping the score map centered.
        output_sz = (feat.shape[-2] + (weights.shape[-2] + 1) % 2, feat.shape[-1] + (weights.shape[-1] + 1) % 2)

        # Get learnable scalars
        step_length_factor = torch.exp(self.log_step_length)
        # Squared so the regularization is non-negative; clamped from below for stability.
        reg_weight = (self.filter_reg*self.filter_reg).clamp(min=self.min_filter_reg**2)

        # Compute distance map
        # Half-pixel shift applied for odd filter sizes (filter_sz % 2 == 1 -> 0.5).
        dmap_offset = (torch.Tensor(filter_sz).to(bb.device) % 2) / 2.0
        # Box center in feature-grid coordinates, flipped from (x, y) to (row, col) order.
        center = ((bb[..., :2] + bb[..., 2:] / 2) / self.feat_stride).reshape(-1, 2).flip((1,)) - dmap_offset
        dist_map = self.distance_map(center, output_sz)

        # Compute label map, mask and spatial weight from the distance map.
        label_map = self.label_map_predictor(dist_map).reshape(num_images, num_sequences, *dist_map.shape[-2:])
        target_mask = self.target_mask_predictor(dist_map).reshape(num_images, num_sequences, *dist_map.shape[-2:])
        spatial_weight = self.spatial_weight_predictor(dist_map).reshape(num_images, num_sequences, *dist_map.shape[-2:])

        # Get total sample weights.
        if sample_weight is None:
            sample_weight = math.sqrt(1.0 / num_images) * spatial_weight
        elif isinstance(sample_weight, torch.Tensor):
            # sqrt since sample_weight multiplies the residual, which is squared in the loss.
            sample_weight = sample_weight.sqrt().reshape(num_images, num_sequences, 1, 1) * spatial_weight

        backprop_through_learning = (self.detach_length > 0)

        weight_iterates = [weights]
        losses = []

        for i in range(num_iter):
            # Periodically detach to truncate backpropagation through the unrolled iterations.
            if not backprop_through_learning or (i > 0 and i % self.detach_length == 0):
                weights = weights.detach()

            # Compute residuals r = v * (act(f * x) - y).
            scores = filter_layer.apply_filter(feat, weights)
            scores_act = self.score_activation(scores, target_mask)
            # Derivative of the score activation: the diagonal Jacobian factor.
            score_mask = self.score_activation_deriv(scores, target_mask)
            residuals = sample_weight * (scores_act - label_map)

            if compute_losses:
                losses.append(((residuals**2).sum() + reg_weight * (weights**2).sum())/num_sequences)

            # Compute gradient g = J^T r + reg * w. The second sample_weight factor
            # belongs to the Jacobian (the first is already inside `residuals`).
            residuals_mapped = score_mask * (sample_weight * residuals)
            weights_grad = filter_layer.apply_feat_transpose(feat, residuals_mapped, filter_sz, training=self.training) + \
                           reg_weight * weights

            # Map the gradient with the Jacobian: J g.
            scores_grad = filter_layer.apply_filter(feat, weights_grad)
            scores_grad = sample_weight * (score_mask * scores_grad)

            # Compute optimal step length: alpha = |g|^2 / (|J g|^2 + (reg + eps) |g|^2),
            # with the denominator clamped to avoid division by zero.
            alpha_num = (weights_grad * weights_grad).sum(dim=(1,2,3))
            alpha_den = ((scores_grad * scores_grad).reshape(num_images, num_sequences, -1).sum(dim=(0,2)) + (reg_weight + self.alpha_eps) * alpha_num).clamp(1e-8)
            alpha = alpha_num / alpha_den

            # Update filter
            weights = weights - (step_length_factor * alpha.reshape(-1, 1, 1, 1)) * weights_grad

            # Add the weight iterate
            weight_iterates.append(weights)

        if compute_losses:
            # Final loss after the last update (so losses has num_iter + 1 entries).
            scores = filter_layer.apply_filter(feat, weights)
            scores = self.score_activation(scores, target_mask)
            losses.append((((sample_weight * (scores - label_map))**2).sum() + reg_weight * (weights**2).sum())/num_sequences)

        return weights, weight_iterates, losses