# Source code for escnn.gspaces.r2

from __future__ import annotations

from escnn import gspaces
from escnn import kernels

from .utils import rotate_array_2d

from escnn.group import *

import numpy as np

from typing import Tuple, Union, Callable, List

# Public API of this module: the generic 2D gspace class, followed by the
# factory functions that build its most common instances.
__all__ = ["GSpace2D", "rot2dOnR2", "flipRot2dOnR2", "flip2dOnR2", "trivialOnR2"]


class GSpace2D(gspaces.GSpace):

    def __init__(self, sg_id: Tuple, maximum_frequency: int = 6):
        r"""
        A ``GSpace`` that describes the set (or subset) of reflectional and rotational symmetries of the
        2D Euclidean Space :math:`\R^2`.
        The subset of symmetries is determined by the subgroup of :math:`\O2` that is specified by `sg_id`
        (check the documentation of :class:`escnn.group.O2`).

        Args:
            sg_id (tuple): The ID of the subgroup within the fiber group :math:`\O2` that determines the
                           reflectional and rotational symmetries to consider. For detailed documentation on the ID
                           of each subgroup, refer to the documentation of :class:`escnn.group.O2`
            maximum_frequency (int): Maximum frequency of the irreps to pre-instantiate, if the symmetry group
                                     (identified by `sg_id`) contains all continuous rotations.

        .. note ::
            A point :math:`\bold{v} \in \R^2` is parametrized using an :math:`(X, Y)` convention,
            i.e. :math:`\bold{v} = (x, y)^T`.
            The representation :attr:`escnn.gspaces.GSpace2D.basespace_action` also assumes this convention.

            However, when working with data on a pixel grid, the usual :math:`(-Y, X)` convention is used.
            That means that, in a 4-dimensional feature tensor of shape ``(B, C, D1, D2)``, the last dimension
            is the X axis while the second last is the (inverted) Y axis.
            Note that this is consistent with 2D images, where a :math:`(-Y, X)` convention is used.

        """
        o2 = o2_group(maximum_frequency=maximum_frequency)
        # normalized version of the user-provided id, as understood by O(2)
        _sg_id = o2._process_subgroup_id(sg_id)

        fibergroup, inclusion, restriction = o2.subgroup(_sg_id)

        # TODO - catch sg_id and build a dictionary of more meaningful names
        # use the input sg_id instead of the processed one to avoid adding the adjoint parameter unless specified
        name = f'{fibergroup}_on_R2[{sg_id}]'

        self._sg_id = _sg_id
        self._inclusion = inclusion
        self._restriction = restriction
        # remembered so that ``restrict`` can build the sub-gspace with the same setting
        self._maximum_frequency = maximum_frequency
        # action of the fiber group on the coordinates of R^2: the O(2) irrep ``(1, 1)`` restricted to the subgroup
        self._base_action = o2.irrep(1, 1).restrict(_sg_id)

        super(GSpace2D, self).__init__(fibergroup, 2, name)

    def restrict(self, id: Tuple) -> Tuple[gspaces.GSpace, Callable, Callable]:
        r"""
        Build the :class:`~escnn.gspaces.GSpace` associated with the subgroup of the current fiber group
        identified by the input ``id``

        Args:
            id (tuple): the id of the subgroup

        Returns:
            a tuple containing

                - **gspace**: the restricted gspace

                - **back_map**: a function mapping an element of the subgroup to itself in the fiber group of the original space

                - **subgroup_map**: a function mapping an element of the fiber group of the original space to itself in the subgroup (returns ``None`` if the element is not in the subgroup)

        """
        o2 = o2_group()
        # ``id`` is relative to the current fiber group; express it as an id of the full O(2)
        sg_id = o2._combine_subgroups(self._sg_id, id)
        # only the maps are needed here: the subgroup itself is rebuilt by the new GSpace2D
        _, inclusion, restriction = self.fibergroup.subgroup(id)
        # getattr keeps instances that lack the attribute (e.g. built via __new__) working with the old default
        maximum_frequency = getattr(self, '_maximum_frequency', 6)
        return GSpace2D(sg_id, maximum_frequency=maximum_frequency), inclusion, restriction

    @property
    def rotations_order(self):
        r"""
        The second entry of the (processed) subgroup id, i.e. ``N`` in ``(axis, N)``: the number of discrete
        rotations, or ``-1`` for continuous rotations.
        """
        return self._sg_id[1]

    @property
    def flips_order(self):
        r"""
        ``1`` if the fiber group contains reflections (the first entry of the subgroup id is not ``None``),
        ``0`` otherwise.
        """
        return 1 if self._sg_id[0] is not None else 0

    def _basis_generator(self,
                         in_repr: Representation,
                         out_repr: Representation,
                         rings: List[float],
                         sigma: List[float],
                         **kwargs,
                         ) -> kernels.KernelBasis:
        r"""
        Method that builds the analytical basis that spans the space of equivariant filters which
        are intertwiners between the representations induced from the representation ``in_repr`` and ``out_repr``.

        `kwargs` can be used to specify `maximum_frequency`

        Args:
            in_repr (Representation): the input representation
            out_repr (Representation): the output representation
            rings (list): radii of the rings where to sample the bases
            sigma (list): parameters controlling the width of each ring where the bases are sampled.

        Returns:
            the basis built

        """
        # TODO - add max_offset for cyclic and dihedral groups!
        maximum_frequency = kwargs.get('maximum_frequency', None)

        if self._sg_id[0] is not None and self._sg_id[1] == -1:
            # reflections + all continuous rotations (O(2)); the id stores twice the reflection axis
            return kernels.kernels_O2_act_R2(in_repr, out_repr, rings, sigma,
                                             axis=self._sg_id[0]/2,
                                             maximum_frequency=maximum_frequency, filter=None)
        elif self._sg_id == (None, -1):
            # all continuous rotations, no reflections (SO(2))
            return kernels.kernels_SO2_act_R2(in_repr, out_repr, rings, sigma,
                                              maximum_frequency=maximum_frequency, filter=None)
        elif self._sg_id[0] is None:
            # rotation-only subgroup of SO(2): only the rotation entry of the id is passed on
            sg_id = self._sg_id[1]
            return kernels.kernels_SO2_subgroup_act_R2(in_repr, out_repr, sg_id, rings, sigma, adjoint=None,
                                                       maximum_frequency=maximum_frequency, filter=None)
        else:
            # any other subgroup of O(2) containing reflections: pass the full id
            return kernels.kernels_O2_subgroup_act_R2(in_repr, out_repr, self._sg_id, rings, sigma,
                                                      axis=0., adjoint=None,
                                                      maximum_frequency=maximum_frequency, filter=None)

    @property
    def basespace_action(self) -> Representation:
        r"""
        The representation of the fiber group acting on the coordinates :math:`(x, y)^T` of :math:`\R^2`.
        """
        return self._base_action

    def __eq__(self, other):
        if isinstance(other, GSpace2D):
            return self._sg_id == other._sg_id
        else:
            return False

    def __hash__(self):
        # Must only depend on what ``__eq__`` compares: ``self.name`` is built from the *unprocessed*
        # input id, so two equal gspaces could otherwise end up with different hashes.
        return hash(self._sg_id)
def rot2dOnR2(N: int = -1, maximum_frequency: int = 6) -> GSpace2D:
    r"""
    Describes rotation symmetries of the plane :math:`\R^2`.

    If ``N > 1``, the gspace models *discrete* rotations by angles which are multiple of :math:`\frac{2\pi}{N}`
    (:class:`~escnn.group.CyclicGroup`).
    Otherwise, if ``N=-1``, the gspace models *continuous* planar rotations (:class:`~escnn.group.SO2`).
    In that case the parameter ``maximum_frequency`` specifies the maximum frequency of the irreps of
    :class:`~escnn.group.SO2` to pre-instantiate (see its documentation for more details).
    ``N = 1`` is also accepted and yields the trivial group of rotations (no rotational symmetry).

    Args:
        N (int): number of discrete rotations (positive integer) or ``-1`` for continuous rotations
        maximum_frequency (int): maximum frequency of :class:`~escnn.group.SO2`'s irreps if ``N = -1``

    """
    assert isinstance(N, int), f"N must be an integer, but {N!r} was given"
    assert N == -1 or N > 0, f"N must be -1 or a positive integer, but {N!r} was given"
    # no reflections (``None``), rotations of order N (or continuous if -1)
    sg_id = None, N
    return GSpace2D(sg_id, maximum_frequency=maximum_frequency)
def flipRot2dOnR2(N: int = -1, maximum_frequency: int = 6, axis: float = np.pi / 2.) -> GSpace2D:
    r"""
    Describes reflectional and rotational symmetries of the plane :math:`\R^2`.

    Reflections are applied with respect to the line through the origin forming an angle of ``axis`` radians
    with the *X*-axis.

    If ``N > 1``, this gspace models reflections and *discrete* rotations by angles multiple of
    :math:`\frac{2\pi}{N}` (:class:`~escnn.group.DihedralGroup`).
    Otherwise, if ``N=-1`` (by default), the class models reflections and *continuous* planar rotations
    (:class:`~escnn.group.O2`).
    In that case, the parameter ``maximum_frequency`` specifies the maximum frequency of the irreps of
    :class:`~escnn.group.O2` to pre-instantiate (see its documentation for more details).

    .. note ::

        All axes obtained from the axis defined by ``axis`` with a rotation in the symmetry group are equivalent.
        For instance, if ``N = 4``, an axis :math:`\beta` is equivalent to the axis :math:`\beta + \pi/2`.
        It follows that for ``N = -1``, i.e. in case the symmetry group contains all continuous rotations, any
        reflection axis is theoretically equivalent.
        In practice, though, a basis for equivariant convolutional filter sampled on a grid is affected by the
        specific choice of the axis. In general, choosing an axis aligned with the grid (an horizontal or a
        vertical axis, i.e. :math:`0` or :math:`\pi/2`) is suggested.

    Args:
        N (int): number of discrete rotations (positive integer) or -1 for continuous rotations
        maximum_frequency (int): maximum frequency of :class:`~escnn.group.O2` 's irreps if ``N = -1``
        axis (float, optional): the angle of the axis of the reflection with the *X*-axis (in radians)

    """
    assert isinstance(N, int), f"N must be an integer, but {N!r} was given"
    assert N == -1 or N > 0, f"N must be -1 or a positive integer, but {N!r} was given"
    # the subgroup id stores twice the reflection axis angle
    sg_id = 2*axis, N
    return GSpace2D(sg_id, maximum_frequency=maximum_frequency)
def flip2dOnR2(axis: float = np.pi / 2) -> GSpace2D:
    r"""
    Describes reflectional symmetries of the plane :math:`\R^2`.

    Reflections are applied with respect to the line through the origin forming an angle of ``axis`` radians
    with the *X*-axis.

    Args:
        axis (float, optional): the angle of the axis of the reflection with the *X*-axis (in radians).
                                By default, the vertical axis is used (:math:`\pi/2`).

    """
    # the subgroup id stores twice the reflection axis angle; ``1`` means no rotations besides the identity
    sg_id = 2*axis, 1
    return GSpace2D(sg_id, maximum_frequency=1)
def trivialOnR2() -> GSpace2D:
    r"""
    Describes the plane :math:`\R^2` with no origin-preserving symmetry at all.

    The fiber group is the trivial group :math:`\{e\}`, i.e. the subgroup of :math:`\O2` with no reflections
    and a single (identity) rotation.

    .. note ::
        This models the symmetries of conventional *Convolutional Neural Networks*, which are not equivariant
        to origin-preserving transformations such as rotations and reflections.

    """
    # (no reflection axis, rotation order 1) -> only the identity element
    return GSpace2D((None, 1), maximum_frequency=1)