Source code for e2cnn.nn.modules.nonlinearities.concatenated


from e2cnn.gspaces import *
from e2cnn.nn import FieldType
from e2cnn.nn import GeometricTensor
from e2cnn.group import Representation
from e2cnn.group.representation import build_from_discrete_group_representation

from ..equivariant_module import EquivariantModule

import torch

from typing import List, Tuple, Any

import numpy as np
import math

__all__ = ["ConcatenatedNonLinearity"]


class ConcatenatedNonLinearity(EquivariantModule):
    
    def __init__(self, in_type: FieldType, function: str = 'c_relu'):
        r"""
        
        Concatenated non-linearities.
        For each input channel, the module applies the specified activation function both to its value and to its
        opposite (the value multiplied by -1).
        The number of channels is, therefore, doubled.
        
        Notice that not all representations support this kind of non-linearity. Indeed, only representations
        whose matrices have the pattern of permutation matrices and contain only values in :math:`\{0, 1, -1\}`
        support it.
        
        Args:
            in_type (FieldType): the input field type
            function (str): the identifier of the non-linearity. It is used to specify which function to apply.
                    By default (``'c_relu'``), ReLU is used.
            
        """
        
        assert isinstance(in_type.gspace, GeneralOnR2)
        
        for r in in_type.representations:
            assert 'concatenated' in r.supported_nonlinearities,\
                'Error! Representation "{}" does not support "concatenated" non-linearity'.format(r.name)
        
        super(ConcatenatedNonLinearity, self).__init__()
        
        self.space = in_type.gspace
        self.in_type = in_type
        
        # compute the output representation given the input one
        self.out_type = ConcatenatedNonLinearity._transform_fiber_representation(in_type)
        
        # retrieve the activation function to apply
        if function == 'c_relu':
            self._function = torch.relu
        elif function == 'c_sigmoid':
            self._function = torch.sigmoid
        elif function == 'c_tanh':
            self._function = torch.tanh
        else:
            raise ValueError('Function "{}" not recognized!'.format(function))
    
    def forward(self, input: GeometricTensor) -> GeometricTensor:
        
        assert input.type == self.in_type
        
        b, c, w, h = input.tensor.shape
        
        # build the output tensor
        output = torch.empty(b, 2 * c, w, h, dtype=torch.float, device=input.tensor.device)
        
        # each channel is transformed into 2 channels:
        # first, apply the non-linearity to its value
        output[:, ::2, ...] = self._function(input.tensor)
        
        # then, apply the non-linearity to its value with the sign inverted
        output[:, 1::2, ...] = self._function(-1 * input.tensor)
        
        # wrap the result in a GeometricTensor
        return GeometricTensor(output, self.out_type)
    
    def evaluate_output_shape(self, input_shape: Tuple[int, int, int, int]) -> Tuple[int, int, int, int]:
        assert len(input_shape) == 4
        assert input_shape[1] == self.in_type.size
        
        b, c, hi, wi = input_shape
        
        return b, self.out_type.size, hi, wi
    
    def check_equivariance(self, atol: float = 1e-6, rtol: float = 1e-5) -> List[Tuple[Any, float]]:
        
        c = self.in_type.size
        
        x = torch.randn(3, c, 10, 10)
        x = GeometricTensor(x, self.in_type)
        
        errors = []
        
        for el in self.space.testing_elements:
            out1 = self(x).transform_fibers(el)
            out2 = self(x.transform_fibers(el))
            
            errs = (out1.tensor - out2.tensor).detach().numpy()
            errs = np.abs(errs).reshape(-1)
            print(el, errs.max(), errs.mean(), errs.var())
            
            assert torch.allclose(out1.tensor, out2.tensor, atol=atol, rtol=rtol), \
                'The error found during the equivariance check with element "{}" is too high: max = {}, mean = {}, var = {}' \
                    .format(el, errs.max(), errs.mean(), errs.var())
            
            errors.append((el, errs.mean()))
        
        return errors
    
    @staticmethod
    def _transform_fiber_representation(in_type: FieldType) -> FieldType:
        r"""
        
        Compute the output field type obtained from the input one after applying the concatenated non-linearity.
        
        Args:
            in_type (FieldType): the input field type
            
        Returns:
            (FieldType): the new output field type
            
        """
        transformed = {}
        
        # transform each different input Representation
        for repr in in_type._unique_representations:
            transformed[repr] = ConcatenatedNonLinearity._transform_representation(repr)
        
        new_representations = []
        
        # concatenate the new representations
        for repr in in_type.representations:
            new_representations.append(transformed[repr])
        
        return FieldType(in_type.gspace, new_representations)
    
    @staticmethod
    def _transform_representation(representation: Representation) -> Representation:
        r"""
        
        Transform an input :class:`~e2cnn.group.Representation` according to the concatenated non-linearity.
        
        The matrices of the input representation need to have the pattern of permutation matrices, with non-zero
        entries equal to 1 or -1.
        
        The output representation has double the size of the input one and is built by substituting each ``1``
        with a 2x2 identity matrix and each ``-1`` with a 2x2 antidiagonal matrix containing ``1`` s.
        
        Args:
            representation (Representation): the input representation
            
        Returns:
            (Representation): the new output representation
            
        """
        group = representation.group
        
        assert not group.continuous
        
        # the name of the new representation
        name = "concatenated_{}".format(representation.name)
        
        if name in group.representations:
            # if the representation has already been built, return it
            r = group.representations[name]
        else:
            # otherwise, build the new representation
            
            s = representation.size
            
            rep = {}
            # build the representation for each element
            for element in group.elements:
                
                # retrieve the input representation of the current element
                r = representation(element)
                
                # build the matrix for the output representation of the current element
                rep[element] = np.zeros((2 * s, 2 * s))
                
                # check if the input matrix has the pattern of a permutation matrix
                e = [-1] * s
                for i in range(s):
                    for j in range(s):
                        if not math.isclose(r[i, j], 0, abs_tol=1e-9):
                            if e[i] < 0:
                                e[i] = j
                            else:
                                raise ValueError(
                                    '''Error! The representation should have the pattern of a permutation matrix,
                                    but 2 non-zero values have been found in a row for element "{}"'''
                                    .format(element)
                                )
                if len(set(e)) != len(e):
                    raise ValueError(
                        '''Error! The representation should have the pattern of a permutation matrix,
                        but 2 non-zero values have been found in a column for element "{}"'''
                        .format(element)
                    )
                
                # parse the input representation matrix and fill the output representation accordingly
                for i in range(s):
                    for j in range(s):
                        if math.isclose(r[i, j], 1, abs_tol=1e-9):
                            # if the current cell contains 1, fill the output submatrix with the 2x2 identity
                            rep[element][2 * i:2 * i + 2, 2 * j:2 * j + 2] = np.eye(2)
                        elif math.isclose(r[i, j], -1, abs_tol=1e-9):
                            # if the current cell contains -1, fill the output submatrix with the 2x2 antidiagonal matrix
                            rep[element][2 * i:2 * i + 2, 2 * j:2 * j + 2] = np.flipud(np.eye(2))
                        elif not math.isclose(r[i, j], 0, abs_tol=1e-9):
                            # otherwise the cell has to contain a 0
                            raise ValueError(
                                '''Error! The representation should be a signed permutation matrix and, therefore,
                                contain only -1, 1 or 0 values, but {} was found in position ({}, {}) for element "{}"'''
                                .format(r[i, j], i, j, element)
                            )
            
            # the resulting representation is a quotient representation and, therefore,
            # it also supports pointwise non-linearities
            nonlinearities = representation.supported_nonlinearities.union(['pointwise'])
            
            # build the output representation
            r = build_from_discrete_group_representation(rep, name, group, supported_nonlinearities=nonlinearities)
        
        return r
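The snippet below is a worked illustration, not part of the module above, of the block-substitution rule used by ``_transform_representation``: it applies the rule to a hypothetical 2x2 signed permutation matrix chosen only to make the "1 becomes a 2x2 identity, -1 becomes a 2x2 antidiagonal block" substitution concrete.

import math
import numpy as np

# a 2x2 signed permutation matrix, e.g. the standard representation of a 90 degree rotation
r = np.array([[0., -1.],
              [1.,  0.]])

s = r.shape[0]
out = np.zeros((2 * s, 2 * s))
for i in range(s):
    for j in range(s):
        if math.isclose(r[i, j], 1, abs_tol=1e-9):
            # a 1 is replaced by the 2x2 identity block
            out[2 * i:2 * i + 2, 2 * j:2 * j + 2] = np.eye(2)
        elif math.isclose(r[i, j], -1, abs_tol=1e-9):
            # a -1 is replaced by the 2x2 antidiagonal block of 1s
            out[2 * i:2 * i + 2, 2 * j:2 * j + 2] = np.flipud(np.eye(2))

print(out)
# [[0. 0. 0. 1.]
#  [0. 0. 1. 0.]
#  [1. 0. 0. 0.]
#  [0. 1. 0. 0.]]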
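The following is a minimal usage sketch of the module itself. It assumes that the regular representation of a small discrete rotation group is tagged as supporting the ``'concatenated'`` non-linearity (its matrices are 0/1 permutation matrices); if that assumption does not hold in a given version of the library, the assertion in the constructor will fail.

import torch
from e2cnn.gspaces import Rot2dOnR2
from e2cnn.nn import FieldType, GeometricTensor, ConcatenatedNonLinearity

gspace = Rot2dOnR2(N=4)                                  # the group of 90 degree rotations acting on the plane
in_type = FieldType(gspace, [gspace.regular_repr] * 3)   # 3 regular fields -> 12 channels

nonlin = ConcatenatedNonLinearity(in_type, function='c_relu')

x = GeometricTensor(torch.randn(8, in_type.size, 16, 16), in_type)
y = nonlin(x)

print(x.tensor.shape)         # torch.Size([8, 12, 16, 16])
print(y.tensor.shape)         # torch.Size([8, 24, 16, 16]): the number of channels is doubled
print(nonlin.out_type.size)   # 24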