# Source code for e2cnn.gspaces.r2.general_r2

from __future__ import annotations

from e2cnn.gspaces import GSpace
from e2cnn import kernels, diffops
from e2cnn.group import Group
from e2cnn.group import Representation
from e2cnn.diffops import DiscretizationArgs

from abc import abstractmethod
from typing import List, Union

from collections import defaultdict

__all__ = ["GeneralOnR2"]


class GeneralOnR2(GSpace):

    def __init__(self, fibergroup: Group, name: str):
        r"""
        Abstract class for the G-spaces which define the symmetries of the plane :math:`\R^2`.

        Args:
            fibergroup (Group): group of origin-preserving symmetries (fiber group)
            name (str): identification name

        """
        super(GeneralOnR2, self).__init__(fibergroup, 2, name)

        # In order to not recompute the basis for the same intertwiner as many times as it appears, we store the
        # basis in these dictionaries the first time we compute it.
        # NOTE: ``defaultdict(dict)`` behaves exactly like ``defaultdict(lambda: dict())``, but unlike a local lambda
        # the ``dict`` factory can be pickled, so the G-space itself stays picklable.

        # Store the computed intertwiners between irreps
        # - key = description of the basis hyper-parameters
        #   (NOTE(review): not read or written by the methods of this class; presumably used by subclasses - confirm)
        # - value = dictionary mapping (input_irrep, output_irrep) pairs to the corresponding basis
        self._irreps_intertwiners_basis_memory = defaultdict(dict)

        # Store the computed bases between general representations (both kernel and PDO bases)
        # - key = hashable tuple built from the hyper-parameters of the basis
        #         (see `build_kernel_basis` and `build_diffop_basis`)
        # - value = dictionary mapping (input_repr.name, output_repr.name) pairs to the corresponding basis
        self._fields_intertwiners_basis_memory = defaultdict(dict)

    def build_kernel_basis(self,
                           in_repr: Representation,
                           out_repr: Representation,
                           sigma: Union[float, List[float]],
                           rings: List[float],
                           **kwargs) -> kernels.KernelBasis:
        r"""
        Builds a basis for the space of the equivariant kernels with respect to the symmetries described by this
        :class:`~e2cnn.gspaces.GSpace`.

        A kernel :math:`\kappa` equivariant to a group :math:`G` needs to satisfy the following equivariance
        constraint:

        .. math::
            \kappa(gx) = \rho_\text{out}(g) \kappa(x) \rho_\text{in}(g)^{-1}  \qquad \forall g \in G, x \in \R^2

        where :math:`\rho_\text{in}` is ``in_repr`` while :math:`\rho_\text{out}` is ``out_repr``.

        Because the equivariance constraints only restrict the angular part of the kernels, any radial profile is
        permitted.
        The basis for the radial profile used here contains rings with different radii (``rings``) associated with
        (possibly different) widths (``sigma``).
        A ring is implemented as a Gaussian function over the radial component, centered at one radius
        (see also :class:`~e2cnn.kernels.GaussianRadialProfile`).

        Bases are memoized per G-space instance: the cache key is made of ``(in_repr.name, out_repr.name)`` together
        with ``sigma``, ``rings`` and ``kwargs``, so repeated calls with the same arguments return the same object.
        For this reason, the values in ``kwargs`` must be hashable.

        .. note ::
            This method is a wrapper for the functions building the bases which are defined in
            :doc:`e2cnn.kernels`:

            - :meth:`e2cnn.kernels.kernels_O2_act_R2`,

            - :meth:`e2cnn.kernels.kernels_SO2_act_R2`,

            - :meth:`e2cnn.kernels.kernels_DN_act_R2`,

            - :meth:`e2cnn.kernels.kernels_CN_act_R2`,

            - :meth:`e2cnn.kernels.kernels_Flip_act_R2`,

            - :meth:`e2cnn.kernels.kernels_Trivial_act_R2`

        Args:
            in_repr (Representation): the input representation
            out_repr (Representation): the output representation
            sigma (list or float): parameters controlling the width of each ring of the radial profile.
                    If only one scalar (``float`` or ``int``) is passed, it is used for all rings
            rings (list): radii of the rings defining the radial profile
            **kwargs: Group-specific keyword arguments for the ``_basis_generator`` method

        Returns:
            the analytical basis

        """
        assert isinstance(in_repr, Representation)
        assert isinstance(out_repr, Representation)
        assert in_repr.group == self.fibergroup, "in_repr does not belong to the fiber group of this G-space"
        assert out_repr.group == self.fibergroup, "out_repr does not belong to the fiber group of this G-space"

        # A single scalar width is shared by all the rings. Integers are broadcast as well: previously only
        # ``float`` was, so e.g. ``sigma=1`` failed when iterated below.
        if isinstance(sigma, (int, float)):
            sigma = [float(sigma)] * len(rings)

        assert all(s > 0. for s in sigma), f"All ring widths must be positive, but got sigma={sigma}"
        assert len(sigma) == len(rings), \
            f"sigma and rings must have the same length, but got {len(sigma)} and {len(rings)}"

        # build the key (sorted, so the keyword arguments' order does not matter)
        key = dict(**kwargs)
        key["sigma"] = tuple(sigma)
        key["rings"] = tuple(rings)
        key = tuple(sorted(key.items()))

        return self._memoized_basis(
            key, in_repr, out_repr,
            lambda: self._basis_generator(in_repr, out_repr, rings, sigma, **kwargs),
        )

    def build_diffop_basis(self,
                           in_repr: Representation,
                           out_repr: Representation,
                           max_power: int,
                           discretization: DiscretizationArgs,
                           **kwargs) -> diffops.DiffopBasis:
        r"""
        Builds a basis for the space of the equivariant PDOs with respect to the symmetries described by this
        :class:`~e2cnn.gspaces.GSpace`.

        A :math:`G`-equivariant PDO :math:`D(P)` for a matrix of polynomials :math:`P`, mapping between an input
        field, transforming under :math:`\rho_\text{in}` (``in_repr``), and an output field, transforming under
        :math:`\rho_\text{out}` (``out_repr``), satisfies the following constraint:

        .. math ::
            P(gx) = \rho_\text{out}(g) P(x) \rho_\text{in}(g)^{-1} \qquad \forall g \in G, \forall x \in \R^2

        for :math:`G \leq O(d)`.

        A complete basis is obtained by combining certain PDOs with powers of the Laplacian operator.
        ``max_power`` describes the maximum power of the Laplacian to use when building the basis.

        Bases are memoized per G-space instance: the cache key is made of ``(in_repr.name, out_repr.name)`` together
        with ``max_power``, ``discretization`` and ``kwargs``, so repeated calls with the same arguments return the
        same object. For this reason, the values in ``kwargs`` must be hashable.

        .. note ::
            This method is a wrapper for the functions building the bases which are defined in
            :doc:`e2cnn.diffops`:

            - :meth:`e2cnn.diffops.diffops_O2_act_R2`,

            - :meth:`e2cnn.diffops.diffops_SO2_act_R2`,

            - :meth:`e2cnn.diffops.diffops_DN_act_R2`,

            - :meth:`e2cnn.diffops.diffops_CN_act_R2`,

            - :meth:`e2cnn.diffops.diffops_Flip_act_R2`,

            - :meth:`e2cnn.diffops.diffops_Trivial_act_R2`

        Args:
            in_repr (Representation): the input representation
            out_repr (Representation): the output representation
            max_power (int): the largest power of the Laplacian that will be used
            discretization (DiscretizationArgs): the parameters specifying a discretization procedure for PDOs
            **kwargs: Group-specific keyword arguments for the ``_diffop_basis_generator`` method

        Returns:
            the analytical basis

        """
        assert isinstance(in_repr, Representation)
        assert isinstance(out_repr, Representation)
        assert in_repr.group == self.fibergroup, "in_repr does not belong to the fiber group of this G-space"
        assert out_repr.group == self.fibergroup, "out_repr does not belong to the fiber group of this G-space"

        # build the key. ``max_power`` and ``discretization`` must be part of it: previously only ``kwargs`` was,
        # so a call with a different ``max_power`` or ``discretization`` returned a stale, cached basis.
        key = dict(**kwargs)
        key["max_power"] = max_power
        key["discretization"] = self._discretization_key(discretization)
        # The leading tag keeps PDO entries disjoint from kernel entries, which share the same memory.
        key = ("diffop",) + tuple(sorted(key.items()))

        return self._memoized_basis(
            key, in_repr, out_repr,
            lambda: self._diffop_basis_generator(in_repr, out_repr, max_power, discretization, **kwargs),
        )

    def _memoized_basis(self, key: tuple, in_repr: Representation, out_repr: Representation, build):
        r"""
        Return the basis stored under ``key`` for the pair ``(in_repr.name, out_repr.name)``, calling ``build()``
        (a zero-argument callable) and storing its result the first time the pair is requested.

        Representations are identified by their ``name`` only.
        """
        memory = self._fields_intertwiners_basis_memory[key]
        pair = (in_repr.name, out_repr.name)
        if pair not in memory:
            # TODO - we could use a flag in the args to choose whether to store it or not
            memory[pair] = build()
        return memory[pair]

    @staticmethod
    def _discretization_key(discretization: DiscretizationArgs):
        r"""
        Return a hashable stand-in for ``discretization`` to be used inside a cache key.

        The object itself is used when it is hashable. Otherwise its ``repr`` is used instead; for instance, a
        non-frozen dataclass has ``__hash__`` set to ``None``, and its ``repr`` lists all of its fields.
        NOTE(review): assumes the discretization arguments are not mutated after being passed in.
        """
        try:
            hash(discretization)
        except TypeError:
            return repr(discretization)
        return discretization

    @abstractmethod
    def _basis_generator(self,
                         in_repr: Representation,
                         out_repr: Representation,
                         rings: List[float],
                         sigma: List[float],
                         **kwargs):
        r"""
        Build the kernel basis between ``in_repr`` and ``out_repr``; implemented by each concrete G-space.

        Called by :meth:`build_kernel_basis` on a cache miss, with ``sigma`` already expanded to one width per ring.
        """
        pass

    @abstractmethod
    def _diffop_basis_generator(self,
                                in_repr: Representation,
                                out_repr: Representation,
                                max_power: int,
                                discretization: DiscretizationArgs,
                                **kwargs):
        r"""
        Build the PDO basis between ``in_repr`` and ``out_repr``; implemented by each concrete G-space.

        Called by :meth:`build_diffop_basis` on a cache miss.
        """
        pass