import torch
import torch.nn as nn


class LeakyReluPar(nn.Module):
    r"""LeakyReLU parametric activation.

    Computes ``(1 - a)/2 * |x| + (1 + a)/2 * x``, which equals ``x`` for
    ``x >= 0`` and ``a * x`` for ``x < 0``, i.e. a leaky ReLU whose negative
    slope ``a`` is passed in as an argument (and may itself be learned).
    """

    def forward(self, x, a):
        return (1.0 - a) / 2.0 * torch.abs(x) + (1.0 + a) / 2.0 * x


class LeakyReluParDeriv(nn.Module):
    r"""Derivative of the LeakyReLU parametric activation w.r.t. ``x``.

    Computes ``(1 - a)/2 * sign(x) + (1 + a)/2``, which equals ``1`` for
    ``x > 0`` and ``a`` for ``x < 0``. The input is detached inside ``sign``
    so no gradient flows back through the sign of ``x``.
    """

    def forward(self, x, a):
        return (1.0 - a) / 2.0 * torch.sign(x.detach()) + (1.0 + a) / 2.0
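

if __name__ == "__main__":
    # Minimal sanity-check sketch, not part of the original module: for a
    # scalar slope ``a`` the activation should coincide with F.leaky_relu,
    # and the closed-form derivative should match autograd. The test values
    # here (seed, shape, a = 0.1) are arbitrary illustrative choices.
    torch.manual_seed(0)
    x = torch.randn(8, requires_grad=True)
    a = torch.tensor(0.1)

    act = LeakyReluPar()
    deriv = LeakyReluParDeriv()

    # (1 - a)/2 * |x| + (1 + a)/2 * x  ==  leaky_relu(x, negative_slope=a)
    y = act(x, a)
    assert torch.allclose(
        y, torch.nn.functional.leaky_relu(x, negative_slope=0.1)
    )

    # Autograd gradient of y.sum() w.r.t. x is dy/dx elementwise, which
    # should equal the analytic derivative (1 - a)/2 * sign(x) + (1 + a)/2.
    (g,) = torch.autograd.grad(y.sum(), x)
    assert torch.allclose(g, deriv(x, a))
    print("sanity checks passed")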