# Source code for e2cnn.gspaces.r2.fliprot2d_on_r2

from __future__ import annotations

from e2cnn import gspaces
from e2cnn import kernels
from e2cnn import diffops

from .general_r2 import GeneralOnR2
from .utils import rotate_array

from e2cnn.group import Representation
from e2cnn.group import Group
from e2cnn.group import DihedralGroup
from e2cnn.group import O2
from e2cnn.group import dihedral_group
from e2cnn.group import o2_group

from e2cnn.diffops import DiscretizationArgs

import numpy as np


from typing import Tuple, Union, Callable, List

__all__ = ["FlipRot2dOnR2"]


class FlipRot2dOnR2(GeneralOnR2):

    def __init__(self, N: int = None, maximum_frequency: int = None, axis: float = np.pi / 2, fibergroup: Group = None):
        r"""

        Describes reflectional and rotational symmetries of the plane :math:`\R^2`.

        Reflections are applied with respect to the line through the origin forming an angle ``axis`` (in radians)
        with the *X*-axis.

        If ``N > 1``, the class models reflections and *discrete* rotations by angles multiple of
        :math:`\frac{2\pi}{N}` (:class:`~e2cnn.group.DihedralGroup`).
        Otherwise, if ``N=-1``, the class models reflections and *continuous* planar rotations
        (:class:`~e2cnn.group.O2`).
        In that case the parameter ``maximum_frequency`` is required to specify the maximum frequency of the irreps
        of :class:`~e2cnn.group.O2` (see its documentation for more details).

        .. note ::

            All axes obtained from the axis defined by ``axis`` with a rotation in the symmetry group are equivalent.
            For instance, if ``N = 4``, an axis :math:`\beta` is equivalent to the axis :math:`\beta + \pi/2`.
            It follows that for ``N = -1``, i.e. in case the symmetry group contains all continuous rotations, any
            reflection axis is theoretically equivalent.
            In practice, though, a basis of equivariant convolutional filters sampled on a grid is affected by the
            specific choice of the axis. In general, choosing an axis aligned with the grid (a horizontal or a
            vertical axis, i.e. :math:`0` or :math:`\pi/2`) is suggested.

        Args:
            N (int): number of discrete rotations (integer greater than 1) or -1 for continuous rotations
            maximum_frequency (int): maximum frequency of :class:`~e2cnn.group.O2` 's irreps if ``N = -1``
            axis (float, optional): the angle (in radians) between the flip axis and the *X*-axis
            fibergroup (Group, optional): use an already existing instance of the symmetry group.
                   In that case only the parameter ``axis`` should be used.

        Attributes:
            ~.axis (float):  Angle with respect to the horizontal axis which defines the reflection axis.

        """
        assert N is not None or fibergroup is not None, "Error! Either use the parameter `N` or the parameter `group`!"

        if fibergroup is not None:
            assert isinstance(fibergroup, (DihedralGroup, O2))
            assert maximum_frequency is None, "Maximum Frequency can't be set when the group is already provided in input"
            # the number of rotations is dictated by the group passed in; any ``N`` argument is overridden
            N = fibergroup.rotation_order

        assert isinstance(N, int)

        self.axis = axis

        if N > 1:
            assert maximum_frequency is None, "Maximum Frequency can't be set for finite cyclic groups"
            name = 'Flip_{}-Rotations(f={:.5f})'.format(N, self.axis)
        elif N == -1:
            name = 'Flip_Continuous-Rotations(f={:.5f})'.format(self.axis)
        else:
            raise ValueError(f'Error! "N" has to be an integer greater than 1 or -1, but got {N}')

        if fibergroup is None:
            if N > 1:
                fibergroup = dihedral_group(N)
            else:
                # only N == -1 can reach this point (other values raised above)
                fibergroup = o2_group(maximum_frequency)

        super().__init__(fibergroup, name)

    def _rotation_index_to_angle(self, k: Union[int, float]) -> float:
        r"""
        Convert the rotation parameter ``k`` of an element of the fiber group into an angle in radians.

        For a finite fiber group (``order() > 1``) ``k`` counts multiples of :math:`\frac{2\pi}{N}`, with ``N`` the
        group's ``rotation_order``; otherwise ``k`` is returned unchanged as it is already an angle.
        """
        if self.fibergroup.order() > 1:
            return k * 2.0 * np.pi / self.fibergroup.rotation_order
        return k

    def restrict(self, id: Tuple[Union[None, float, int], int]) -> Tuple[gspaces.GSpace, Callable, Callable]:
        r"""

        Build the :class:`~e2cnn.gspaces.GSpace` associated with the subgroup of the current fiber group identified
        by the input ``id``, which is a tuple :math:`(k, M)`.

        Here, :math:`M` is a positive integer indicating the number of discrete rotations in the subgroup while
        :math:`k` is either ``None`` (no reflections) or identifies the axis of reflection (see below).
        If the current fiber group is :math:`D_N` (:class:`~e2cnn.group.DihedralGroup`), then :math:`M` needs to
        divide :math:`N` and :math:`k` needs to be an integer in :math:`\{0, \dots, \frac{N}{M}-1\}`.
        Otherwise, :math:`M` can be any positive integer while :math:`k` needs to be a real number in
        :math:`[0, \frac{2\pi}{M}]`.

        Valid combinations are:

        - (``None``, :math:`1`): restrict to no reflection and rotation symmetries

        - (``None``, :math:`M`): restrict to only the :math:`M` rotations generated by :math:`r_{2\pi/M}`.

        - (:math:`0`, :math:`1`): restrict to only reflections :math:`\langle f \rangle` around the same axis as in
          the current group

        - (:math:`0`, :math:`M`): restrict to reflections and :math:`M` rotations generated by :math:`r_{2\pi/M}`
          and :math:`f`

        If the current fiber group is :math:`D_N` (an instance of :class:`~e2cnn.group.DihedralGroup`):

        - (:math:`k`, :math:`M`): restrict to reflections :math:`\langle r_{k\frac{2\pi}{N}} f \rangle` around the
          axis of the current G-space rotated by :math:`k\frac{\pi}{N}` and :math:`M` rotations generated by
          :math:`r_{2\pi/M}`

        If the current fiber group is :math:`O(2)` (an instance of :class:`~e2cnn.group.O2`):

        - (:math:`\theta`, :math:`M`): restrict to reflections :math:`\langle r_{\theta} f \rangle` around the axis
          of the current G-space rotated by :math:`\frac{\theta}{2}` and :math:`M` rotations generated by
          :math:`r_{2\pi/M}`

        - (``None``, :math:`-1`): restrict to all (continuous) rotations

        Args:
            id (tuple): the id of the subgroup

        Returns:
            a tuple containing

                - **gspace**: the restricted gspace

                - **back_map**: a function mapping an element of the subgroup to itself in the fiber group of the
                  original space

                - **subgroup_map**: a function mapping an element of the fiber group of the original space to itself
                  in the subgroup (returns ``None`` if the element is not in the subgroup)

        Raises:
            ValueError: if ``id`` does not match any of the combinations above

        """
        subgroup, mapping, child = self.fibergroup.subgroup(id)

        if id[0] is not None:
            # The new flip axis is the previous one rotated by the chosen angle.
            # The group element generating the subgroup's reflections is r_{rotation} f, whose reflection
            # axis is rotated by only *half* of ``rotation`` with respect to the axis of f.
            rotation = self._rotation_index_to_angle(id[0])
            new_axis = (self.axis + 0.5 * rotation) % (2 * np.pi)

        if id[0] is None and id[1] == 1:
            return gspaces.TrivialOnR2(fibergroup=subgroup), mapping, child
        elif id[0] is None and (id[1] > 1 or id[1] == -1):
            return gspaces.Rot2dOnR2(fibergroup=subgroup), mapping, child
        elif id[0] is not None and id[1] == 1:
            return gspaces.Flip2dOnR2(fibergroup=subgroup, axis=new_axis), mapping, child
        elif id[0] is not None:
            return gspaces.FlipRot2dOnR2(fibergroup=subgroup, axis=new_axis), mapping, child
        else:
            raise ValueError(f"id {id} not recognized!")

    @staticmethod
    def _frequency_limits(kwargs: dict) -> tuple:
        r"""
        Read and validate the ``maximum_frequency`` / ``maximum_offset`` keyword arguments.

        A keyword that is missing or explicitly ``None`` is treated as unset.

        Returns:
            the pair ``(maximum_frequency, maximum_offset)``; at least one of them is not ``None``

        Raises:
            AssertionError: if a set value is not a non-negative ``int`` or if neither value is set

        """
        maximum_frequency = kwargs.get('maximum_frequency')
        if maximum_frequency is not None:
            assert isinstance(maximum_frequency, int) and maximum_frequency >= 0

        maximum_offset = kwargs.get('maximum_offset')
        if maximum_offset is not None:
            assert isinstance(maximum_offset, int) and maximum_offset >= 0

        assert (maximum_frequency is not None or maximum_offset is not None), \
            'Error! Either the maximum frequency or the maximum offset for the frequencies must be set'

        return maximum_frequency, maximum_offset

    def _basis_generator(self,
                         in_repr: Representation,
                         out_repr: Representation,
                         rings: List[float],
                         sigma: List[float],
                         **kwargs,
                         ) -> kernels.KernelBasis:
        r"""
        Method that builds the analytical basis that spans the space of equivariant filters which
        are intertwiners between the representations induced from the representation ``in_repr`` and ``out_repr``.

        If this :class:`~e2cnn.gspaces.GSpace` includes only a discrete number of rotations (``n > 1``), either
        ``maximum_frequency`` or ``maximum_offset`` must be set in the keyword arguments.

        Args:
            in_repr: the input representation
            out_repr: the output representation
            rings: radii of the rings where to sample the bases
            sigma: parameters controlling the width of each ring where the bases are sampled.

        Keyword Args:
            maximum_frequency (int): the maximum frequency allowed in the basis vectors
            maximum_offset (int): the maximum frequencies offset for each basis vector with respect to its base ones
                                  (sum and difference of the frequencies of the input and the output representations)

        Returns:
            the basis built

        """
        if self.fibergroup.order() > 0:
            maximum_frequency, maximum_offset = self._frequency_limits(kwargs)
            return kernels.kernels_DN_act_R2(in_repr, out_repr, rings, sigma,
                                             axis=self.axis,
                                             max_frequency=maximum_frequency,
                                             max_offset=maximum_offset)
        else:
            return kernels.kernels_O2_act_R2(in_repr, out_repr, rings, sigma, axis=self.axis)

    def _diffop_basis_generator(self,
                                in_repr: Representation,
                                out_repr: Representation,
                                max_power: int,
                                discretization: DiscretizationArgs,
                                **kwargs,
                                ) -> diffops.DiffopBasis:
        r"""
        Method that builds the analytical basis that spans the space of equivariant PDOs which
        are intertwiners between the representations induced from the representation ``in_repr`` and ``out_repr``.

        If this :class:`~e2cnn.gspaces.GSpace` includes only a discrete number of rotations (``n > 1``), either
        ``maximum_frequency`` or ``maximum_offset`` must be set in the keyword arguments.

        Args:
            in_repr: the input representation
            out_repr: the output representation
            max_power (int): the maximum power of Laplacians to use
            discretization (DiscretizationArgs): the parameters specifying a discretization procedure for PDOs

        Keyword Args:
            maximum_frequency (int): the maximum frequency allowed in the basis vectors
            maximum_offset (int): the maximum frequencies offset for each basis vector with respect to its base ones
                                  (sum and difference of the frequencies of the input and the output representations)

        Returns:
            the basis built

        """
        if self.fibergroup.order() > 0:
            maximum_frequency, maximum_offset = self._frequency_limits(kwargs)
            return diffops.diffops_DN_act_R2(in_repr, out_repr, max_power,
                                             axis=self.axis,
                                             max_frequency=maximum_frequency,
                                             max_offset=maximum_offset,
                                             discretization=discretization)
        else:
            return diffops.diffops_O2_act_R2(in_repr, out_repr, max_power,
                                             axis=self.axis,
                                             discretization=discretization)

    def _basespace_action(self, input: np.ndarray, element: Tuple[int, Union[float, int]]) -> np.ndarray:
        r"""
        Apply the fiber group ``element`` to an array sampled on the plane.

        ``element`` is a pair ``(flip, k)``. A truthy ``flip`` reverses the second-to-last axis of ``input``.
        ``k`` is converted to an angle and applied with ``rotate_array``. Presumably the last two axes of
        ``input`` are the spatial grid; confirm against ``rotate_array``.

        Returns:
            a new array; it is a plain copy when the total rotation is zero

        """
        assert self.fibergroup.is_element(element)

        rotation = self._rotation_index_to_angle(element[1])

        output = input
        if element[0]:
            output = output[..., ::-1, :]
            # composing the array flip with a rotation by 2 * axis turns it into the reflection whose axis is
            # rotated by ``axis`` relative to the one implemented by the plain array flip
            rotation += 2 * self.axis

        if rotation != 0.:
            output = rotate_array(output, rotation)
        else:
            output = output.copy()

        return output

    def __eq__(self, other):
        # ``axis`` changes the action on the base space (see ``_basespace_action``). It is also embedded in the
        # name passed to the base class in ``__init__``, which ``__hash__`` uses.
        # Comparing it here keeps equal objects hashing equally.
        if isinstance(other, FlipRot2dOnR2):
            return self.fibergroup == other.fibergroup and self.axis == other.axis
        else:
            return False

    def __hash__(self):
        return hash(self.name)