from __future__ import annotations

import numbers
from abc import abstractmethod
from collections import defaultdict
from typing import List, Union

from e2cnn.gspaces import GSpace
from e2cnn import kernels, diffops
from e2cnn.group import Group
from e2cnn.group import Representation
from e2cnn.diffops import DiscretizationArgs
# Public API of this module: only the abstract G-space base class is exported.
__all__ = ["GeneralOnR2"]
class GeneralOnR2(GSpace):

    def __init__(self, fibergroup: Group, name: str):
        r"""

        Abstract class for the G-spaces which define the symmetries of the plane :math:`\R^2`.

        Args:
            fibergroup (Group): group of origin-preserving symmetries (fiber group)
            name (str): identification name

        """
        super().__init__(fibergroup, 2, name)

        # In order to not recompute the basis for the same intertwiner as many times as it is requested,
        # each basis is stored in one of the dictionaries below the first time it is computed.

        # Memory for the intertwiner bases between irreps.
        # NOTE(review): not read or written by this class; presumably populated by subclasses - confirm.
        # - key = hashable tuple describing the hyper-parameters of the basis
        # - value = dictionary mapping (input_irrep, output_irrep) pairs to the corresponding basis
        # ``dict`` is used as factory (instead of ``lambda: dict()``) because a lambda factory cannot be pickled.
        self._irreps_intertwiners_basis_memory = defaultdict(dict)

        # Memory for the intertwiner bases between general representations, shared by kernel and PDO bases.
        # - key = sorted tuple of (name, value) pairs: the group-specific ``**kwargs`` plus either
        #   ``sigma`` and ``rings`` (kernel bases) or ``max_power`` and ``discretization`` (PDO bases)
        # - value = dictionary mapping (in_repr.name, out_repr.name) pairs to the corresponding basis
        self._fields_intertwiners_basis_memory = defaultdict(dict)

    def build_kernel_basis(self,
                           in_repr: Representation,
                           out_repr: Representation,
                           sigma: Union[float, List[float]],
                           rings: List[float],
                           **kwargs) -> kernels.KernelBasis:
        r"""

        Builds a basis for the space of the equivariant kernels with respect to the symmetries described by this
        :class:`~e2cnn.gspaces.GSpace`.

        A kernel :math:`\kappa` equivariant to a group :math:`G` needs to satisfy the following equivariance constraint:

        .. math::
            \kappa(gx) = \rho_\text{out}(g) \kappa(x) \rho_\text{in}(g)^{-1}  \qquad \forall g \in G, x \in \R^2

        where :math:`\rho_\text{in}` is ``in_repr`` while :math:`\rho_\text{out}` is ``out_repr``.

        Because the equivariance constraints only restrict the angular part of the kernels, any radial profile is
        permitted.
        The basis for the radial profile used here contains rings with different radii (``rings``)
        associated with (possibly different) widths (``sigma``).
        A ring is implemented as a Gaussian function over the radial component, centered at one radius
        (see also :class:`~e2cnn.kernels.GaussianRadialProfile`).

        Bases are memoized: a later call with representations of the same names and equal ``sigma``, ``rings``
        and ``kwargs`` returns the previously built basis object instead of recomputing it.

        .. note ::
            This method is a wrapper for the functions building the bases which are defined in :doc:`e2cnn.kernels`:

            - :meth:`e2cnn.kernels.kernels_O2_act_R2`,

            - :meth:`e2cnn.kernels.kernels_SO2_act_R2`,

            - :meth:`e2cnn.kernels.kernels_DN_act_R2`,

            - :meth:`e2cnn.kernels.kernels_CN_act_R2`,

            - :meth:`e2cnn.kernels.kernels_Flip_act_R2`,

            - :meth:`e2cnn.kernels.kernels_Trivial_act_R2`

        Args:
            in_repr (Representation): the input representation
            out_repr (Representation): the output representation
            sigma (list or float): parameters controlling the width of each ring of the radial profile.
                    If only one real scalar (e.g. ``float`` or ``int``) is passed, it is used for all rings
            rings (list): radii of the rings defining the radial profile
            **kwargs: Group-specific keyword arguments for the ``_basis_generator`` method.
                    Their values must be hashable, since they are part of the memoization key

        Returns:
            the analytical basis

        Raises:
            AssertionError: (when assertions are enabled) if a representation is not defined on this space's
                fiber group, if ``sigma`` and ``rings`` have different lengths or if some width is not positive

        """
        assert isinstance(in_repr, Representation), "in_repr must be a Representation"
        assert isinstance(out_repr, Representation), "out_repr must be a Representation"
        assert in_repr.group == self.fibergroup, "in_repr is not a representation of this space's fiber group"
        assert out_repr.group == self.fibergroup, "out_repr is not a representation of this space's fiber group"

        if isinstance(sigma, numbers.Real):
            # a single scalar width (int, float or numpy scalar) is shared by all the rings
            sigma = [float(sigma)] * len(rings)

        assert len(sigma) == len(rings), "sigma and rings must have the same length"
        assert all(s > 0. for s in sigma), "all the widths in sigma must be positive"

        # build the key: group-specific kwargs plus the hyper-parameters of the radial profile
        key = dict(**kwargs)
        key["sigma"] = tuple(sigma)
        key["rings"] = tuple(rings)
        key = tuple(sorted(key.items()))

        memory = self._fields_intertwiners_basis_memory[key]
        reprs_key = (in_repr.name, out_repr.name)
        if reprs_key not in memory:
            # TODO - we could use a flag in the args to choose whether to store it or not
            memory[reprs_key] = self._basis_generator(in_repr, out_repr, rings, sigma, **kwargs)

        # return the basis built (now or in a previous call) for these representations and hyper-parameters
        return memory[reprs_key]

    def build_diffop_basis(self,
                           in_repr: Representation,
                           out_repr: Representation,
                           max_power: int,
                           discretization: DiscretizationArgs,
                           **kwargs) -> diffops.DiffopBasis:
        r"""

        Builds a basis for the space of the equivariant PDOs with respect to the symmetries described by this
        :class:`~e2cnn.gspaces.GSpace`.

        A :math:`G`-equivariant PDO :math:`D(P)` for a matrix of polynomials :math:`P`, mapping between an input field,
        transforming under :math:`\rho_\text{in}` (``in_repr``), and an output field, transforming under
        :math:`\rho_\text{out}` (``out_repr``), satisfies the following constraint:

        .. math ::
            P(gx) = \rho_\text{out}(g) P(x) \rho_\text{in}(g)^{-1} \qquad \forall g \in G, \forall x \in \R^2

        for :math:`G \leq O(2)`.

        A complete basis is obtained by combining certain PDOs with powers of the Laplacian operator.
        ``max_power`` describes the maximum power of the Laplacian to use when building the basis.

        Bases are memoized: a later call with representations of the same names and equal ``max_power``,
        ``discretization`` and ``kwargs`` returns the previously built basis object instead of recomputing it.

        .. note ::
            This method is a wrapper for the functions building the bases which are defined in :doc:`e2cnn.diffops`:

            - :meth:`e2cnn.diffops.diffops_O2_act_R2`,

            - :meth:`e2cnn.diffops.diffops_SO2_act_R2`,

            - :meth:`e2cnn.diffops.diffops_DN_act_R2`,

            - :meth:`e2cnn.diffops.diffops_CN_act_R2`,

            - :meth:`e2cnn.diffops.diffops_Flip_act_R2`,

            - :meth:`e2cnn.diffops.diffops_Trivial_act_R2`

        Args:
            in_repr (Representation): the input representation
            out_repr (Representation): the output representation
            max_power (int): the largest power of the Laplacian that will be used
            discretization (DiscretizationArgs): the parameters specifying a discretization procedure for PDOs
            **kwargs: Group-specific keyword arguments for the ``_diffop_basis_generator`` method.
                    Their values must be hashable, since they are part of the memoization key

        Returns:
            the analytical basis

        Raises:
            AssertionError: (when assertions are enabled) if a representation is not defined on this space's
                fiber group

        """
        assert isinstance(in_repr, Representation), "in_repr must be a Representation"
        assert isinstance(out_repr, Representation), "out_repr must be a Representation"
        assert in_repr.group == self.fibergroup, "in_repr is not a representation of this space's fiber group"
        assert out_repr.group == self.fibergroup, "out_repr is not a representation of this space's fiber group"

        # ``max_power`` and ``discretization`` change the resulting basis, so they must be part of the key:
        # otherwise a basis built for one value would be wrongly returned for a different one.
        try:
            hash(discretization)
        except TypeError:
            # NOTE(review): unhashable discretization object (e.g. a non-frozen dataclass); its repr is used
            # instead, which assumes the repr reflects every field affecting the basis - confirm.
            discretization_key = repr(discretization)
        else:
            discretization_key = discretization

        # build the key: group-specific kwargs plus the hyper-parameters of the PDO basis
        key = dict(**kwargs)
        key["max_power"] = max_power
        key["discretization"] = discretization_key
        key = tuple(sorted(key.items()))

        memory = self._fields_intertwiners_basis_memory[key]
        reprs_key = (in_repr.name, out_repr.name)
        if reprs_key not in memory:
            # TODO - we could use a flag in the args to choose whether to store it or not
            memory[reprs_key] = self._diffop_basis_generator(in_repr, out_repr, max_power, discretization, **kwargs)

        # return the basis built (now or in a previous call) for these representations and hyper-parameters
        return memory[reprs_key]

    @abstractmethod
    def _basis_generator(self,
                         in_repr: Representation,
                         out_repr: Representation,
                         rings: List[float],
                         sigma: List[float],
                         **kwargs) -> kernels.KernelBasis:
        r"""
        Group-specific construction of the kernel basis, called by :meth:`build_kernel_basis` on a cache miss.
        ``sigma`` has already been expanded to one width per ring.
        """

    @abstractmethod
    def _diffop_basis_generator(self,
                                in_repr: Representation,
                                out_repr: Representation,
                                max_power: int,
                                discretization: DiscretizationArgs,
                                **kwargs) -> diffops.DiffopBasis:
        r"""
        Group-specific construction of the PDO basis, called by :meth:`build_diffop_basis` on a cache miss.
        """