You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

22 lines
426 B

import torch
import torch.nn as nn
class LeakyReluPar(nn.Module):
    r"""Parametric LeakyReLU activation.

    Evaluates ``(1 - a)/2 * |x| + (1 + a)/2 * x``, which equals ``x`` for
    positive inputs and ``a * x`` for negative inputs, so ``a`` is the
    negative-side slope. ``a`` may be a scalar or a tensor broadcastable
    against ``x``.
    """

    def forward(self, x, a):
        # Same arithmetic as the closed form, split into named coefficients.
        neg_coef = (1.0 - a) / 2.0
        pos_coef = (1.0 + a) / 2.0
        return neg_coef * torch.abs(x) + pos_coef * x
class LeakyReluParDeriv(nn.Module):
    r"""Derivative of the parametric LeakyReLU with respect to ``x``.

    Evaluates ``(1 - a)/2 * sign(x) + (1 + a)/2``, which equals ``1`` for
    positive inputs and ``a`` for negative inputs. ``x`` is detached before
    the sign so no gradient flows through the piecewise-constant factor.
    """

    def forward(self, x, a):
        # Same arithmetic as the closed form, split into named coefficients.
        neg_coef = (1.0 - a) / 2.0
        pos_coef = (1.0 + a) / 2.0
        return neg_coef * torch.sign(x.detach()) + pos_coef