Source code for escnn.nn.modules.nonlinearities.leakyrelu

from typing import Any, List, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from escnn.gspaces import GSpace
from escnn.nn import FieldType, GeometricTensor

from ..equivariant_module import EquivariantModule

__all__ = ["LeakyReLU"]


class LeakyReLU(EquivariantModule):

    def __init__(self, in_type: FieldType, negative_slope: float = 0.01, inplace: bool = False):
        r"""

        Module that applies a pointwise leaky ReLU to every channel independently.
        The input representation is preserved by this operation and, therefore, it equals the output
        representation.

        Only representations supporting pointwise non-linearities are accepted as input field type.

        Args:
            in_type (FieldType): the input field type
            negative_slope (float, optional): controls the angle of the negative slope (used for
                    negative input values). Default: ``0.01``
            inplace (bool, optional): can optionally do the operation in-place. Default: ``False``

        """
        assert isinstance(in_type.gspace, GSpace)

        super(LeakyReLU, self).__init__()

        for r in in_type.representations:
            assert 'pointwise' in r.supported_nonlinearities, \
                'Error! Representation "{}" does not support "pointwise" non-linearity'.format(r.name)

        self.space = in_type.gspace
        self.in_type = in_type

        # the input representation is preserved
        self.out_type = in_type

        self._negative_slope = negative_slope
        self._inplace = inplace
    def forward(self, input: GeometricTensor) -> GeometricTensor:
        r"""
        Applies the leaky ReLU function to the input fields.

        Args:
            input (GeometricTensor): the input feature map

        Returns:
            the resulting feature map after the leaky ReLU has been applied

        """
        assert input.type == self.in_type, \
            "Error! The type of the input does not match the input type of this module"

        return GeometricTensor(
            F.leaky_relu(input.tensor, negative_slope=self._negative_slope, inplace=self._inplace),
            self.out_type,
            input.coords,
        )
    def evaluate_output_shape(self, input_shape: Tuple[int, ...]) -> Tuple[int, ...]:
        assert len(input_shape) >= 2
        assert input_shape[1] == self.in_type.size

        b, c = input_shape[:2]
        spatial_shape = input_shape[2:]

        return (b, self.out_type.size, *spatial_shape)

    def check_equivariance(self, atol: float = 1e-6, rtol: float = 1e-5) -> List[Tuple[Any, float]]:

        c = self.in_type.size

        x = torch.randn(3, c, 10, 10)
        x = GeometricTensor(x, self.in_type)

        errors = []

        for el in self.space.testing_elements:
            out1 = self(x).transform_fibers(el)
            out2 = self(x.transform_fibers(el))

            errs = (out1.tensor - out2.tensor).detach().numpy()
            errs = np.abs(errs).reshape(-1)
            print(el, errs.max(), errs.mean(), errs.var())

            assert torch.allclose(out1.tensor, out2.tensor, atol=atol, rtol=rtol), \
                'The error found during equivariance check with element "{}" is too high: ' \
                'max = {}, mean = {}, var = {}'.format(el, errs.max(), errs.mean(), errs.var())

            errors.append((el, errs.mean()))

        return errors

    def extra_repr(self):
        return "inplace={}, type={}".format(self._inplace, self.in_type)
    def export(self):
        r"""
        Export this module to a normal PyTorch :class:`torch.nn.LeakyReLU` module and set it to
        "eval" mode.
        """
        self.eval()

        return torch.nn.LeakyReLU(negative_slope=self._negative_slope, inplace=self._inplace)
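

# ----------------------------------------------------------------------------
# Usage sketch (not part of the library module above): a minimal, hedged
# example of building a LeakyReLU over a regular-representation field type,
# numerically checking the pointwise-equivariance property it relies on, and
# exporting the module. The gspace construction (``escnn.gspaces.rot2dOnR2``)
# and the choice of three regular fields are illustrative assumptions; adapt
# them to your setting.

if __name__ == "__main__":
    from escnn.gspaces import rot2dOnR2

    gspace = rot2dOnR2(N=8)  # planar images with C8 rotational symmetry
    feat_type = FieldType(gspace, 3 * [gspace.regular_repr])  # assumed: 3 regular fields

    relu = LeakyReLU(feat_type, negative_slope=0.01)

    x = GeometricTensor(torch.randn(4, feat_type.size, 10, 10), feat_type)
    y = relu(x)  # the output carries the same FieldType as the input

    # pointwise non-linearities commute with representations that permute the
    # channels (e.g. the regular representation); this is the "pointwise"
    # property asserted in __init__:
    for g in gspace.testing_elements:
        assert torch.allclose(
            relu(x.transform_fibers(g)).tensor,
            relu(x).transform_fibers(g).tensor,
            atol=1e-6,
        )

    # export to a plain torch.nn.LeakyReLU for deployment
    torch_relu = relu.export()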