Source code for escnn.nn.modules.nonlinearities.fourier_quotient


from escnn.group import *
from escnn.gspaces import *
from escnn.nn import FieldType
from escnn.nn import GeometricTensor

from ..equivariant_module import EquivariantModule

import torch
import torch.nn.functional as F

from typing import List, Tuple, Any

import numpy as np

__all__ = ["QuotientFourierPointwise", "QuotientFourierELU"]


def _build_kernel(G: Group, subgroup_id: Tuple, irrep: List[tuple]):
    """Assemble the flat Fourier-domain kernel used to sample features on ``G / H``.

    For every irrep id in ``irrep`` the homogeneous space ``G.homspace(subgroup_id)`` is queried through its
    ``_dirac_kernel_ft`` method (paired with the id of the trivial representation of the subgroup ``H``); each
    result is transposed, flattened and the pieces are concatenated in the order of ``irrep``.

    Args:
        G (Group): the group whose homogeneous space is used
        subgroup_id (tuple): identifier of the subgroup :math:`H` defining the homogeneous space
        irrep (list): ids of the irreps of ``G`` to include, in order

    Returns:
        a 1-dimensional numpy array containing the concatenated, flattened blocks

    """
    # NOTE(review): judging by the method name, each block presumably holds the Fourier coefficients of a
    # Dirac delta on the homogeneous space - confirm against HomSpace._dirac_kernel_ft.
    homspace: HomSpace = G.homspace(subgroup_id)

    flattened_blocks = [
        homspace._dirac_kernel_ft(irrep_id, homspace.H.trivial_representation.id).T.reshape(-1)
        for irrep_id in irrep
    ]

    return np.concatenate(flattened_blocks)
    

class QuotientFourierPointwise(EquivariantModule):

    def __init__(self,
                 gspace: GSpace,
                 subgroup_id: Tuple,
                 channels: int,
                 irreps: List,
                 *grid_args,
                 grid: List[GroupElement] = None,
                 function: str = 'p_relu',
                 inplace: bool = True,
                 out_irreps: List = None,
                 normalize: bool = True,
                 **grid_kwargs
                 ):
        r"""

        Applies an Inverse Fourier Transform to sample the input features on a *quotient space* :math:`X`, applies
        the pointwise non-linearity in the spatial domain (Dirac-delta basis) and, finally, computes the Fourier
        Transform to obtain irreps coefficients.
        The quotient space used is isomorphic to :math:`X \cong G / H` where :math:`G` is ``gspace.fibergroup``
        while :math:`H` is the subgroup of :math:`G` identified by ``subgroup_id``; see
        :meth:`~escnn.group.Group.subgroup` and :meth:`~escnn.group.Group.homspace`

        .. warning::
            This operation is only *approximately* equivariant and its equivariance depends on the sampling grid
            and the non-linear activation used, as well as the original band-limitation of the input features.

        The same function is applied to every channel independently.
        By default, the input representation is preserved by this operation and, therefore, it equals the output
        representation.
        Optionally, the output can have a different band-limit by using the argument ``out_irreps``.

        The class first constructs a band-limited quotient representation of ``gspace.fibergroup`` using
        :meth:`escnn.group.Group.spectral_quotient_representation`.
        The band-limitation of this representation is specified by ``irreps`` which should be a list containing
        a list of ids identifying irreps of ``gspace.fibergroup``
        (see :attr:`escnn.group.IrreducibleRepresentation.id`).
        This representation is used to define the input and output field types, each containing ``channels``
        copies of a feature field transforming according to this representation.
        A feature vector transforming according to such representation is interpreted as a vector of coefficients
        parameterizing a function over the group using a band-limited Fourier basis.

        .. note::
            Instead of building the list ``irreps`` manually, most groups implement a method ``bl_irreps()``
            which can be used to generate this list through a simpler interface. Check each group's
            documentation.

        To approximate the Fourier transform, this module uses a finite number of samples from the group.
        The set of samples to be used can be specified through the parameter ``grid`` or by the ``grid_args``
        and ``grid_kwargs`` which will then be passed to the method :meth:`~escnn.group.Group.grid`.

        .. warning::
            By definition, a homogeneous space is invariant under a right action of the subgroup :math:`H`.
            That means that a feature representing a function over a homogeneous space :math:`X \cong G/H`,
            when interpreted as a function over :math:`G` (as we do here when sampling), will be constant along
            each coset, i.e. :math:`f(gh) = f(g)` if :math:`g \in G, h\in H`.
            An approximately uniform sampling grid over :math:`G` creates an approximately uniform grid over
            :math:`G/H` through projection but might contain redundant elements (if the grid contains
            :math:`g \in G`, any element :math:`gh` in the grid will be redundant).
            It is therefore advised to create a grid directly in the quotient space, e.g. using
            :meth:`escnn.group.SO3.sphere_grid`, :meth:`escnn.group.O3.sphere_grid`.
            We do not support yet a general method and interface to generate grids over any homogeneous space
            for any group, so you should check each group's methods.

        The sampling (IFT) matrix and the regularized least-squares (FT) matrix are computed once at
        construction time and stored in the buffers ``A`` and ``Ainv``.

        Args:
            gspace (GSpace):  the gspace describing the symmetries of the data. The Fourier transform is
                              performed over the group ``gspace.fibergroup``
            subgroup_id (tuple): identifier of the subgroup :math:`H` to construct the quotient space
            channels (int): number of independent fields in the input `FieldType`
            irreps (list): list of irreps' ids to construct the band-limited representation
            *grid_args: parameters used to construct the discretization grid
            grid (list, optional): list containing the elements of the group to use for sampling.
                                   Optional (default ``None``).
            function (str): the identifier of the non-linearity. It is used to specify which function to apply.
                    One of ``'p_relu'``, ``'p_elu'``, ``'p_sigmoid'`` or ``'p_tanh'``.
                    By default (``'p_relu'``), ReLU is used.
            inplace (bool): applies the non-linear activation in-place. Default: `True`
            out_irreps (list, optional): optionally, one can specify a different band-limiting in output
            normalize (bool, optional): if ``True``, the rows of the IFT matrix (and the columns of the FT
                    matrix) are normalized. Default: ``True``
            **grid_kwargs: keyword parameters used to construct the discretization grid

        Raises:
            ValueError: if ``function`` is not one of the recognized identifiers

        """

        assert isinstance(gspace, GSpace)

        super(QuotientFourierPointwise, self).__init__()

        self.space = gspace

        G: Group = gspace.fibergroup

        self.rho = G.spectral_quotient_representation(subgroup_id, *irreps, name=None)

        self.in_type = FieldType(self.space, [self.rho] * channels)

        if out_irreps is None:
            # the representation in input is preserved
            self.out_type = self.in_type
            self.rho_out = self.rho
        else:
            self.rho_out = G.spectral_quotient_representation(subgroup_id, *out_irreps, name=None)
            self.out_type = FieldType(self.space, [self.rho_out] * channels)

        # retrieve the activation function to apply
        if function == 'p_relu':
            self._function = F.relu_ if inplace else F.relu
        elif function == 'p_elu':
            self._function = F.elu_ if inplace else F.elu
        elif function == 'p_sigmoid':
            # torch.sigmoid / torch.tanh replace the legacy torch.nn.functional aliases and mirror the
            # in-place variants used on the other side of the conditional
            self._function = torch.sigmoid_ if inplace else torch.sigmoid
        elif function == 'p_tanh':
            self._function = torch.tanh_ if inplace else torch.tanh
        else:
            raise ValueError('Function "{}" not recognized!'.format(function))

        kernel = self._dirac_kernel(G, subgroup_id, irreps, self.rho.size, normalize)

        if grid is None:
            grid = G.grid(*grid_args, **grid_kwargs)
        # the grid may be traversed twice below: materialize it so that one-shot iterables are supported too
        grid = list(grid)

        # IFT matrix: one row per grid element, one column per Fourier coefficient of the input
        A = self._sampling_matrix(self.rho, kernel, grid)

        if out_irreps is not None:
            # The FT is fitted on the output irreps *plus* the input irreps which are missing from the output;
            # the coefficients of the latter are discarded after the inversion below.
            # The order of ``irreps`` is preserved to make the construction deterministic.
            requested_out = set(out_irreps)
            _missing_input_irreps = [irr for irr in dict.fromkeys(irreps) if irr not in requested_out]
            extended_irreps = list(out_irreps) + _missing_input_irreps

            rho_out_extended = G.spectral_quotient_representation(subgroup_id, *extended_irreps, name=None)
            kernel_out = self._dirac_kernel(G, subgroup_id, extended_irreps, rho_out_extended.size, normalize)
            A_out = self._sampling_matrix(rho_out_extended, kernel_out, grid)
        else:
            A_out = A
            rho_out_extended = self.rho_out

        # FT matrix: regularized least-squares inverse of the (extended) output sampling matrix
        eps = 1e-8
        Ainv = np.linalg.inv(A_out.T @ A_out + eps * np.eye(rho_out_extended.size)) @ A_out.T

        if out_irreps is not None:
            # keep only the rows producing the coefficients of the requested output irreps
            Ainv = Ainv[:self.rho_out.size, :]

        self.register_buffer('A', torch.tensor(A, dtype=torch.get_default_dtype()))
        self.register_buffer('Ainv', torch.tensor(Ainv, dtype=torch.get_default_dtype()))

    @staticmethod
    def _dirac_kernel(G, subgroup_id: Tuple, irreps: List, expected_size: int, normalize: bool) -> np.ndarray:
        # Build the Fourier-domain kernel as a column vector, optionally rescaled to unit norm.
        kernel = _build_kernel(G, subgroup_id, irreps)
        assert kernel.shape[0] == expected_size

        if normalize:
            kernel = kernel / np.linalg.norm(kernel)

        return kernel.reshape(-1, 1)

    @staticmethod
    def _sampling_matrix(rho, kernel: np.ndarray, grid: List) -> np.ndarray:
        # Stack ``rho(g) @ kernel`` for every grid element ``g`` as the rows of a matrix.
        return np.concatenate(
            [
                rho(g) @ kernel
                for g in grid
            ], axis=1
        ).T

    def forward(self, input: GeometricTensor) -> GeometricTensor:
        r"""

        Applies the pointwise activation function on the input fields

        Args:
            input (GeometricTensor): the input feature map

        Returns:
            the resulting feature map after the non-linearities have been applied

        """

        assert input.type == self.in_type

        shape = input.shape
        # split the channel axis into (fields, Fourier coefficients); ``reshape`` (rather than ``view``) also
        # accepts non-contiguous input tensors
        x_hat = input.tensor.reshape(shape[0], len(self.in_type), self.rho.size, *shape[2:])

        # inverse Fourier transform: coefficients -> samples on the grid
        x = torch.einsum('bcf...,gf->bcg...', x_hat, self.A)

        y = self._function(x)

        # Fourier transform: samples on the grid -> coefficients of the output representation
        y_hat = torch.einsum('bcg...,fg->bcf...', y, self.Ainv)

        y_hat = y_hat.reshape(shape[0], self.out_type.size, *shape[2:])

        return GeometricTensor(y_hat, self.out_type, input.coords)

    def evaluate_output_shape(self, input_shape: Tuple[int, ...]) -> Tuple[int, ...]:
        r"""

        Compute the shape of the output tensor from the shape of the input tensor: the channel dimension
        becomes ``self.out_type.size`` while all other dimensions are preserved.

        Args:
            input_shape (tuple): shape of the input tensor; it needs at least the batch and channel dimensions

        Returns:
            shape of the output tensor

        """

        assert len(input_shape) >= 2
        assert input_shape[1] == self.in_type.size

        batch_size = input_shape[0]
        spatial_shape = input_shape[2:]

        return (batch_size, self.out_type.size, *spatial_shape)

    def check_equivariance(self, atol: float = 1e-5, rtol: float = 2e-2, assert_raise: bool = True) -> np.ndarray:
        r"""

        Empirically check the equivariance of the module on a random input, using 100 elements sampled from
        ``self.space.fibergroup``.
        For each element, the output computed from the transformed input is compared with the transformed output
        of the original input.

        Args:
            atol (float, optional): absolute errors below this threshold are treated as zero
            rtol (float, optional): threshold on ``mean + std`` of the relative errors for each element
            assert_raise (bool, optional): whether to assert that the errors are below ``rtol``

        Returns:
            a 1-dimensional numpy array containing the relative errors measured for all the sampled elements

        """

        c = self.in_type.size
        B = 64
        spatial_dims = [3] * self.space.dimensionality

        x = torch.randn(B, c, *spatial_dims)

        # since we mostly use non-linearities like relu or elu, we make sure the average value of the features is
        # positive, such that, when we test inputs with only frequency 0 (or only low frequencies), the output is
        # not zero everywhere
        x = x.view(B, len(self.in_type), self.rho.size, *spatial_dims)
        p = 0
        for irr in self.rho.irreps:
            irr = self.space.irrep(*irr)
            if irr.is_trivial():
                x[:, :, p] = x[:, :, p].abs()
            p += irr.size

        x = x.view(B, self.in_type.size, *spatial_dims)

        errors = []

        for _ in range(100):

            el = self.space.fibergroup.sample()

            x1 = GeometricTensor(x.clone(), self.in_type)
            x2 = GeometricTensor(x.clone(), self.in_type).transform_fibers(el)

            out1 = self(x1).transform_fibers(el)
            out2 = self(x2)

            out1 = out1.tensor.reshape(
                B, len(self.out_type), self.rho_out.size, *out1.shape[2:]
            ).detach().numpy()
            out2 = out2.tensor.reshape(
                B, len(self.out_type), self.rho_out.size, *out2.shape[2:]
            ).detach().numpy()

            errs = np.linalg.norm(out1 - out2, axis=2).reshape(-1)
            errs[errs < atol] = 0.
            norm = np.sqrt(np.linalg.norm(out1, axis=2).reshape(-1) * np.linalg.norm(out2, axis=2).reshape(-1))

            # an exactly-zero absolute error is a zero relative error, even if the outputs themselves are zero:
            # dividing 0 by 0 would produce NaNs and make the check below fail spuriously
            relerr = np.zeros_like(errs)
            nonzero = errs > 0
            relerr[nonzero] = errs[nonzero] / norm[nonzero]

            if assert_raise:
                assert relerr.mean() + relerr.std() < rtol, \
                    'The error found during equivariance check with element "{}" is too high: max = {}, mean = {}, std ={}, maxerr={}, xmean={}, xstd={}' \
                        .format(el, relerr.max(), relerr.mean(), relerr.std(), errs[np.argmax(relerr)], out1.mean(), out1.std())

            errors.append(relerr)

        return np.concatenate(errors)
class QuotientFourierELU(QuotientFourierPointwise):

    def __init__(self,
                 gspace: GSpace,
                 subgroup_id: Tuple,
                 channels: int,
                 irreps: List,
                 *grid_args,
                 grid: List[GroupElement] = None,
                 inplace: bool = True,
                 out_irreps: List = None,
                 normalize: bool = True,
                 **grid_kwargs
                 ):
        r"""

        ELU non-linearity for features over a quotient space: samples the input features on the quotient space
        via an Inverse Fourier Transform, applies ELU point-wise in the spatial domain (Dirac-delta basis) and
        computes the Fourier Transform to recover the irreps coefficients.

        All arguments are forwarded to :class:`~escnn.nn.QuotientFourierPointwise` with ``function='p_elu'``;
        see that class for more details.

        Args:
            gspace (GSpace):  the gspace describing the symmetries of the data. The Fourier transform is
                              performed over the group ``gspace.fibergroup``
            subgroup_id (tuple): identifier of the subgroup :math:`H` to construct the quotient space
            channels (int): number of independent fields in the input `FieldType`
            irreps (list): list of irreps' ids to construct the band-limited representation
            *grid_args: parameters used to construct the discretization grid
            grid (list, optional): list containing the elements of the group to use for sampling.
                                   Optional (default ``None``).
            inplace (bool): applies the non-linear activation in-place. Default: ``True``
            out_irreps (list, optional): optionally, one can specify a different band-limiting in output
            normalize (bool, optional): if ``True``, the rows of the IFT matrix (and the columns of the FT
                    matrix) are normalized. Default: ``True``
            **grid_kwargs: keyword parameters used to construct the discretization grid

        """

        super().__init__(
            gspace,
            subgroup_id,
            channels,
            irreps,
            *grid_args,
            grid=grid,
            function='p_elu',
            inplace=inplace,
            out_irreps=out_irreps,
            normalize=normalize,
            **grid_kwargs
        )