# Source code for e2cnn.gspaces.r2.rot2d_on_r2


from __future__ import annotations

from e2cnn import gspaces
from e2cnn import kernels
from e2cnn import diffops

from .general_r2 import GeneralOnR2
from .utils import rotate_array

from typing import Union, Tuple, Callable, List

from e2cnn.group import Representation
from e2cnn.group import Group
from e2cnn.group import CyclicGroup
from e2cnn.group import SO2
from e2cnn.group import cyclic_group
from e2cnn.group import so2_group

import numpy as np
from e2cnn.diffops import DiscretizationArgs


# Public API of this module: only the gspace class defined below is exported.
__all__ = ["Rot2dOnR2"]


class Rot2dOnR2(GeneralOnR2):

    def __init__(self, N: int = None, maximum_frequency: int = None, fibergroup: Group = None):
        r"""

        Describes rotation symmetries of the plane :math:`\R^2`.

        If ``N > 1``, the class models *discrete* rotations by angles which are multiple of :math:`\frac{2\pi}{N}`
        (:class:`~e2cnn.group.CyclicGroup`).
        Otherwise, if ``N=-1``, the class models *continuous* planar rotations (:class:`~e2cnn.group.SO2`).
        In that case the parameter ``maximum_frequency`` is required to specify the maximum frequency of the irreps of
        :class:`~e2cnn.group.SO2` (see its documentation for more details)

        Args:
            N (int): number of discrete rotations (integer greater than 1) or ``-1`` for continuous rotations
            maximum_frequency (int): maximum frequency of :class:`~e2cnn.group.SO2`'s irreps if ``N = -1``
            fibergroup (Group, optional): use an already existing instance of the symmetry group.
                   In that case, the other parameters should not be provided.

        Raises:
            AssertionError: if neither ``N`` nor ``fibergroup`` is given, if ``fibergroup`` is not a
                :class:`~e2cnn.group.CyclicGroup` or :class:`~e2cnn.group.SO2`, or if ``maximum_frequency`` is
                incompatible with the other arguments
            ValueError: if ``N`` is neither an integer greater than 1 nor ``-1``

        """

        assert N is not None or fibergroup is not None, \
            "Error! Either use the parameter `N` or the parameter `fibergroup`!"

        if fibergroup is not None:
            assert isinstance(fibergroup, (CyclicGroup, SO2))
            assert maximum_frequency is None, \
                "Maximum Frequency can't be set when the group is already provided in input"
            # The order of the provided group takes precedence over any value passed as ``N``.
            N = fibergroup.order()

        assert isinstance(N, int)

        if N > 1:
            assert maximum_frequency is None, "Maximum Frequency can't be set for finite cyclic groups"
            name = '{}-Rotations'.format(N)
        elif N == -1:
            name = 'Continuous-Rotations'
        else:
            raise ValueError(f'Error! "N" has to be an integer greater than 1 or -1, but got {N}')

        if fibergroup is None:
            if N > 1:
                fibergroup = cyclic_group(N)
            elif N == -1:
                # ``maximum_frequency`` is documented as required for continuous rotations:
                # fail here with an explicit message instead of forwarding ``None``.
                assert maximum_frequency is not None, \
                    "Maximum Frequency must be set for continuous rotations (`N = -1`)"
                fibergroup = so2_group(maximum_frequency)

        super(Rot2dOnR2, self).__init__(fibergroup, name)

    def restrict(self, id: int) -> Tuple[gspaces.GSpace, Callable, Callable]:
        r"""

        Build the :class:`~e2cnn.gspaces.GSpace` associated with the subgroup of the current fiber group identified by
        the input ``id``.

        ``id`` is a positive integer :math:`M` indicating the number of rotations in the subgroup.
        If the current fiber group is :math:`C_N` (:class:`~e2cnn.group.CyclicGroup`), then :math:`M` needs to divide
        :math:`N`. Otherwise, :math:`M` can be any positive integer.

        Args:
            id (int): the number :math:`M` of rotations in the subgroup

        Returns:
            a tuple containing

                - **gspace**: the restricted gspace

                - **back_map**: a function mapping an element of the subgroup to itself in the fiber group of the original space

                - **subgroup_map**: a function mapping an element of the fiber group of the original space to itself in the subgroup (returns ``None`` if the element is not in the subgroup)

        Raises:
            ValueError: if ``id`` is smaller than 1 (provided the fiber group's ``subgroup`` call did not
                already reject it)

        """

        subgroup, mapping, child = self.fibergroup.subgroup(id)

        if id > 1:
            return gspaces.Rot2dOnR2(fibergroup=subgroup), mapping, child
        elif id == 1:
            # A single rotation is the identity only: the restricted space has trivial symmetry.
            return gspaces.TrivialOnR2(fibergroup=subgroup), mapping, child
        else:
            raise ValueError(f"id {id} not recognized!")

    @staticmethod
    def _parse_frequency_limits(kwargs: dict) -> Tuple[Union[int, None], Union[int, None]]:
        r"""

        Extract and validate the band-limiting options shared by the kernel and the PDO basis generators.

        Args:
            kwargs (dict): the keyword arguments received by the basis generator

        Returns:
            the pair ``(maximum_frequency, maximum_offset)``; a missing or ``None`` entry is returned as ``None``

        Raises:
            AssertionError: if a provided value is not a non-negative integer or if both values are missing

        """
        maximum_frequency = kwargs.get('maximum_frequency')
        if maximum_frequency is not None:
            assert isinstance(maximum_frequency, int) and maximum_frequency >= 0

        maximum_offset = kwargs.get('maximum_offset')
        if maximum_offset is not None:
            assert isinstance(maximum_offset, int) and maximum_offset >= 0

        assert (maximum_frequency is not None or maximum_offset is not None), \
            'Error! Either the maximum frequency or the maximum offset for the frequencies must be set'

        return maximum_frequency, maximum_offset

    def _basis_generator(self,
                         in_repr: Representation,
                         out_repr: Representation,
                         rings: List[float],
                         sigma: List[float],
                         **kwargs,
                         ) -> kernels.KernelBasis:
        r"""
        Method that builds the analytical basis that spans the space of equivariant filters which
        are intertwiners between the representations induced from the representation ``in_repr`` and ``out_repr``.

        If this :class:`~e2cnn.gspaces.GSpace` includes only a discrete number of rotations (``N > 1``), either
        ``maximum_frequency`` or ``maximum_offset`` must be set in the keywords arguments.

        Args:
            in_repr (Representation): the input representation
            out_repr (Representation): the output representation
            rings (list): radii of the rings where to sample the bases
            sigma (list): parameters controlling the width of each ring where the bases are sampled.

        Keyword Args:
            maximum_frequency (int): the maximum frequency allowed in the basis vectors
            maximum_offset (int): the maximum frequencies offset for each basis vector with respect to its base ones
                                  (sum and difference of the frequencies of the input and the output representations)

        Returns:
            the basis built

        """

        # A positive order identifies a finite fiber group; any other value selects the SO(2) basis.
        if self.fibergroup.order() > 0:
            maximum_frequency, maximum_offset = self._parse_frequency_limits(kwargs)

            return kernels.kernels_CN_act_R2(in_repr, out_repr, rings, sigma,
                                             maximum_frequency,
                                             max_offset=maximum_offset)
        else:
            return kernels.kernels_SO2_act_R2(in_repr, out_repr, rings, sigma)

    def _diffop_basis_generator(self,
                                in_repr: Representation,
                                out_repr: Representation,
                                max_power: int,
                                discretization: DiscretizationArgs,
                                **kwargs,
                                ) -> diffops.DiffopBasis:
        r"""
        Method that builds the analytical basis that spans the space of equivariant PDOs which
        are intertwiners between the representations induced from the representation ``in_repr`` and ``out_repr``.

        If this :class:`~e2cnn.gspaces.GSpace` includes only a discrete number of rotations (``N > 1``), either
        ``maximum_frequency`` or ``maximum_offset`` must be set in the keywords arguments.

        Args:
            in_repr (Representation): the input representation
            out_repr (Representation): the output representation
            max_power (int): the maximum power of Laplacians to use
            discretization (DiscretizationArgs): the parameters specifying a discretization procedure for PDOs

        Keyword Args:
            maximum_frequency (int): the maximum frequency allowed in the basis vectors
            maximum_offset (int): the maximum frequencies offset for each basis vector with respect to its base ones
                                  (sum and difference of the frequencies of the input and the output representations)

        Returns:
            the basis built

        """

        # A positive order identifies a finite fiber group; any other value selects the SO(2) basis.
        if self.fibergroup.order() > 0:
            maximum_frequency, maximum_offset = self._parse_frequency_limits(kwargs)

            return diffops.diffops_CN_act_R2(in_repr, out_repr, max_power,
                                             maximum_frequency,
                                             max_offset=maximum_offset,
                                             discretization=discretization)
        else:
            return diffops.diffops_SO2_act_R2(in_repr, out_repr, max_power,
                                              discretization=discretization)

    def _basespace_action(self, input: np.ndarray, element: Union[float, int]) -> np.ndarray:
        r"""
        Apply the rotation associated with ``element`` to ``input`` through ``rotate_array``.

        Args:
            input (~numpy.ndarray): the array to transform
            element (float or int): an element of the fiber group

        Returns:
            the array returned by ``rotate_array``

        """

        assert self.fibergroup.is_element(element)

        if self.fibergroup.order() > 1:
            # Discrete case: ``element`` counts steps of 2*pi/n, so convert it to the matching angle.
            n = self.fibergroup.order()
            rotation = element * 2.0 * np.pi / n
        else:
            # Otherwise ``element`` is forwarded unchanged (presumably it already is an angle).
            rotation = element

        output = rotate_array(input, rotation)

        return output

    def __eq__(self, other):
        # Two rotation gspaces are equal exactly when their fiber groups are equal.
        if isinstance(other, Rot2dOnR2):
            return self.fibergroup == other.fibergroup
        else:
            return False

    def __hash__(self):
        # The name is derived from the order of the fiber group in ``__init__``.
        # NOTE(review): assumes equal fiber groups always yield equal names, which keeps this
        # consistent with ``__eq__``.
        return hash(self.name)