from __future__ import annotations
from e2cnn import gspaces
from e2cnn import kernels
from e2cnn import diffops
from .general_r2 import GeneralOnR2
from .utils import rotate_array
from typing import Union, Tuple, Callable, List
from import Representation
from import Group
from import CyclicGroup
from import SO2
from import cyclic_group
from import so2_group
import numpy as np
from e2cnn.diffops import DiscretizationArgs
__all__ = ["Rot2dOnR2"]
[docs]class Rot2dOnR2(GeneralOnR2):
def __init__(self, N: int = None, maximum_frequency: int = None, fibergroup: Group = None):
    r"""
    Describes rotation symmetries of the plane :math:`\R^2`.

    If ``N > 1``, the class models *discrete* rotations by angles which are multiples of
    :math:`\frac{2\pi}{N}` (:class:`~e2cnn.group.CyclicGroup`).
    Otherwise, if ``N = -1``, the class models *continuous* planar rotations
    (:class:`~e2cnn.group.SO2`). In that case the parameter ``maximum_frequency`` is required to
    specify the maximum frequency of the irreps of :class:`~e2cnn.group.SO2`
    (see its documentation for more details).

    Args:
        N (int): number of discrete rotations (integer greater than 1) or ``-1`` for continuous rotations
        maximum_frequency (int): maximum frequency of :class:`~e2cnn.group.SO2`'s irreps if ``N = -1``
        fibergroup (Group, optional): use an already existing instance of the symmetry group.
            In that case, the other parameters should not be provided.

    Raises:
        ValueError: if ``N`` is neither greater than 1 nor equal to ``-1``, or if ``N = -1`` is
            requested without a ``maximum_frequency`` and without an existing ``fibergroup``.

    """
    assert N is not None or fibergroup is not None, \
        "Error! Either use the parameter `N` or the parameter `fibergroup`!"

    if fibergroup is not None:
        assert isinstance(fibergroup, (CyclicGroup, SO2))
        assert maximum_frequency is None, "Maximum Frequency can't be set when the group is already provided in input"
        # An explicit group wins: N is taken from the group's own order, whatever was passed in.
        N = fibergroup.order()

    assert isinstance(N, int)

    if N > 1:
        assert maximum_frequency is None, "Maximum Frequency can't be set for finite cyclic groups"
        name = '{}-Rotations'.format(N)
    elif N == -1:
        name = 'Continuous-Rotations'
    else:
        raise ValueError(f'Error! "N" has to be an integer greater than 1 or -1, but got {N}')

    if fibergroup is None:
        # N has already been validated above, so it is either > 1 or exactly -1 here.
        if N > 1:
            fibergroup = cyclic_group(N)
        else:
            if maximum_frequency is None:
                # Report the missing argument here rather than letting the group factory
                # fail on a None frequency.
                raise ValueError(
                    'Error! "maximum_frequency" is required to build continuous rotations ("N" = -1)'
                )
            fibergroup = so2_group(maximum_frequency)

    super(Rot2dOnR2, self).__init__(fibergroup, name)
def restrict(self, id: int) -> Tuple[gspaces.GSpace, Callable, Callable]:
    r"""
    Build the :class:`~e2cnn.gspaces.GSpace` associated with the subgroup of the current fiber group
    identified by the input ``id``.

    ``id`` is a positive integer :math:`M` indicating the number of rotations in the subgroup.
    If the current fiber group is :math:`C_N` (:class:`~e2cnn.group.CyclicGroup`), then :math:`M`
    needs to divide :math:`N`. Otherwise, :math:`M` can be any positive integer.

    Args:
        id (int): the number :math:`M` of rotations in the subgroup

    Returns:
        a tuple containing

            - **gspace**: the restricted gspace

            - **back_map**: a function mapping an element of the subgroup to itself in the fiber
              group of the original space

            - **subgroup_map**: a function mapping an element of the fiber group of the original
              space to itself in the subgroup (returns ``None`` if the element is not in the subgroup)

    Raises:
        ValueError: if ``id`` is not a positive number of rotations.

    """
    # Validate the id before asking the fiber group for a subgroup, so that a bad id always
    # produces this method's own error and no subgroup is built for it.
    if id > 1:
        space_cls = gspaces.Rot2dOnR2
    elif id == 1:
        # A single rotation is the trivial group, which has a dedicated gspace type.
        space_cls = gspaces.TrivialOnR2
    else:
        raise ValueError(f"id {id} not recognized!")

    subgroup, mapping, child = self.fibergroup.subgroup(id)
    return space_cls(fibergroup=subgroup), mapping, child
def _basis_generator(self,
                     in_repr: Representation,
                     out_repr: Representation,
                     rings: List[float],
                     sigma: List[float],
                     **kwargs,
                     ) -> kernels.KernelBasis:
    r"""
    Method that builds the analytical basis that spans the space of equivariant filters which
    are intertwiners between the representations induced from the representation ``in_repr`` and ``out_repr``.

    If this :class:`~e2cnn.gspaces.GSpace` includes only a discrete number of rotations (``N > 1``),
    either ``maximum_frequency`` or ``maximum_offset`` must be set in the keywords arguments.

    Args:
        in_repr (Representation): the input representation
        out_repr (Representation): the output representation
        rings (list): radii of the rings where to sample the bases
        sigma (list): parameters controlling the width of each ring where the bases are sampled.

    Keyword Args:
        maximum_frequency (int): the maximum frequency allowed in the basis vectors
        maximum_offset (int): the maximum frequencies offset for each basis vector with respect to its
            base ones (sum and difference of the frequencies of the input and the output representations)

    Returns:
        the basis built

    """
    # A positive order selects the discrete branch; any other order is handled as the
    # continuous case below.
    if self.fibergroup.order() > 0:
        # A missing key and an explicit None are treated the same way: "not set".
        maximum_frequency = kwargs.get('maximum_frequency')
        if maximum_frequency is not None:
            assert isinstance(maximum_frequency, int) and maximum_frequency >= 0

        maximum_offset = kwargs.get('maximum_offset')
        if maximum_offset is not None:
            assert isinstance(maximum_offset, int) and maximum_offset >= 0

        assert (maximum_frequency is not None or maximum_offset is not None), \
            'Error! Either the maximum frequency or the maximum offset for the frequencies must be set'

        # NOTE(review): keyword names mirror the `max_offset=` spelling used by the diffops
        # counterpart of this method - confirm against `kernels_CN_act_R2`.
        return kernels.kernels_CN_act_R2(in_repr, out_repr, rings, sigma,
                                         max_frequency=maximum_frequency,
                                         max_offset=maximum_offset)

    return kernels.kernels_SO2_act_R2(in_repr, out_repr, rings, sigma)
def _diffop_basis_generator(self,
                            in_repr: Representation,
                            out_repr: Representation,
                            max_power: int,
                            discretization: DiscretizationArgs,
                            **kwargs,
                            ) -> diffops.DiffopBasis:
    r"""
    Method that builds the analytical basis that spans the space of equivariant PDOs which
    are intertwiners between the representations induced from the representation ``in_repr`` and ``out_repr``.

    If this :class:`~e2cnn.gspaces.GSpace` includes only a discrete number of rotations (``n > 1``),
    either ``maximum_frequency`` or ``maximum_offset`` must be set in the keywords arguments.

    Args:
        in_repr (Representation): the input representation
        out_repr (Representation): the output representation
        max_power (int): the maximum power of Laplacians to use
        discretization (DiscretizationArgs): the parameters specifying a discretization procedure for PDOs

    Keyword Args:
        maximum_frequency (int): the maximum frequency allowed in the basis vectors
        maximum_offset (int): the maximum frequencies offset for each basis vector with respect to its
            base ones (sum and difference of the frequencies of the input and the output representations)

    Returns:
        the basis built

    """
    # A positive order selects the discrete branch; any other order is handled as the
    # continuous case below.
    if self.fibergroup.order() > 0:
        # A missing key and an explicit None are treated the same way: "not set".
        maximum_frequency = kwargs.get('maximum_frequency')
        if maximum_frequency is not None:
            assert isinstance(maximum_frequency, int) and maximum_frequency >= 0

        maximum_offset = kwargs.get('maximum_offset')
        if maximum_offset is not None:
            assert isinstance(maximum_offset, int) and maximum_offset >= 0

        assert (maximum_frequency is not None or maximum_offset is not None), \
            'Error! Either the maximum frequency or the maximum offset for the frequencies must be set'

        # NOTE(review): `max_frequency=` is spelled by analogy with `max_offset=` -
        # confirm against `diffops_CN_act_R2`.
        return diffops.diffops_CN_act_R2(in_repr, out_repr, max_power,
                                         max_frequency=maximum_frequency,
                                         max_offset=maximum_offset,
                                         discretization=discretization)

    return diffops.diffops_SO2_act_R2(in_repr, out_repr, max_power, discretization=discretization)
def _basespace_action(self, input: np.ndarray, element: Union[float, int]) -> np.ndarray:
    """
    Apply the fiber group element ``element`` to ``input`` by rotating the array.

    For a group with more than one discrete rotation, ``element`` is scaled to the angle
    ``element * 2 * pi / order``. Otherwise ``element`` is passed on unchanged
    (presumably it is already an angle - confirm against the group's element convention).

    Args:
        input (np.ndarray): the array to transform
        element (float or int): an element of the fiber group

    Returns:
        the array returned by ``rotate_array`` for the computed rotation

    """
    assert self.fibergroup.is_element(element)

    order = self.fibergroup.order()
    if order > 1:
        rotation = element * 2.0 * np.pi / order
    else:
        # Only the non-discrete case uses the element as-is; it must not overwrite the
        # angle computed for discrete groups.
        rotation = element

    return rotate_array(input, rotation)
def __eq__(self, other):
    """Compare by fiber group; anything that is not a ``Rot2dOnR2`` is never equal."""
    return isinstance(other, Rot2dOnR2) and self.fibergroup == other.fibergroup
def __hash__(self):
return hash(