from __future__ import annotations
from escnn import gspaces
from escnn import kernels
from escnn.group import *
from .utils import linear_transform_array_3d
import numpy as np
from typing import Tuple, Union, Callable, List
__all__ = [
"GSpace3D",
###########
"flipRot3dOnR3",
"rot3dOnR3",
"fullIcoOnR3",
"icoOnR3",
"fullOctaOnR3",
"octaOnR3",
"dihedralOnR3",
"rot2dOnR3",
"conicalOnR3",
"fullCylindricalOnR3",
"cylindricalOnR3",
"mirOnR3",
"invOnR3",
"trivialOnR3"
]
[docs]class GSpace3D(gspaces.GSpace):
def __init__(self, sg_id: Tuple, maximum_frequency: int = 2):
r"""
A ``GSpace`` tha describes the set (or subset) of reflectional and rotational symmetries of the
3D Euclidean Space :math:`\R^3`.
The subset of symmetries is determined by the subgroup of :math:`\O3` that is specified by `sg_id`
(check the documentation of :class:`escnn.group.O3`).
Args:
sg_id (tuple): The ID of the subgroup within the fiber group :math:`\O3` that determines the reflectional and rotational symmetries to consider. For detailed documentation on the ID of each subgroup, refer to the documentation of :class:`escnn.group.O3`
maximum_frequency (int): Maximum frequency of the irreps to pre-instantiate, if the symmetry group (identified by `sg_id`) contains continuous rotations.
.. note ::
A point :math:`\bold{v} \in \R^3` is parametrized using an :math:`(X, Y, Z)` convention,
i.e. :math:`\bold{v} = (x, y, z)^T`.
The representation :attr:`escnn.gspaces.GSpace3D.basespace_action` also assumes this convention.
However, when working with voxel data, the :math:`(-Z, -Y, X)` convention is used.
That means that, in a 5-dimensional feature tensor of shape ``(B, C, D1, D2, D3)``, the last dimension
is the X axis, the second last the (inverted) Y axis and then the (inverted) Z axis.
Note that this is consistent with 2D images, where a :math:`(-Y, X)` convention is used.
This is especially relevant when transforming a :class:`~escnn.nn.GeometricTensor` or when building
convolutional filters in :class:`~escnn.nn.R3Conv` which should be equivariant to subgroups of :math:`\O3`
(e.g. when choosing the rotation axis for :func:`~escnn.gspaces.rot2dOnR3`).
"""
o3 = o3_group(maximum_frequency=maximum_frequency)
_sg_id = o3._process_subgroup_id(sg_id)
fibergroup, inclusion, restriction = o3.subgroup(_sg_id)
# TODO - catch sg_id and build a dictionary of more meaningful names
# use the input sg_id instead of the processed one to avoid adding the adjoint parameter unless specified
name = f'{fibergroup}_on_R3[{sg_id}]'
self._sg_id = _sg_id
self._inclusion = inclusion
self._restriction = restriction
self._base_action = o3.standard_representation().restrict(_sg_id)
super(GSpace3D, self).__init__(fibergroup, 3, name)
[docs] def restrict(self, id: Tuple) -> Tuple[gspaces.GSpace, Callable, Callable]:
r"""
Build the :class:`~escnn.group.GSpace` associated with the subgroup of the current fiber group identified by
the input ``id``
Args:
id (tuple): the id of the subgroup
Returns:
a tuple containing
- **gspace**: the restricted gspace
- **back_map**: a function mapping an element of the subgroup to itself in the fiber group of the original space
- **subgroup_map**: a function mapping an element of the fiber group of the original space to itself in the subgroup (returns ``None`` if the element is not in the subgroup)
"""
o3 = o3_group()
sg_id = o3._combine_subgroups(self._sg_id, id)
sg, inclusion, restriction = self.fibergroup.subgroup(id)
return GSpace3D(sg_id), inclusion, restriction
def _basis_generator(self,
in_repr: Representation,
out_repr: Representation,
rings: List[float],
sigma: List[float],
**kwargs,
) -> kernels.KernelBasis:
r"""
Method that builds the analytical basis that spans the space of equivariant filters which
are intertwiners between the representations induced from the representation ``in_repr`` and ``out_repr``.
`kwargs` can be used to specify `maximum_frequency`
Args:
in_repr (Representation): the input representation
out_repr (Representation): the output representation
rings (list): radii of the rings where to sample the bases
sigma (list): parameters controlling the width of each ring where the bases are sampled.
Returns:
the basis built
"""
# TODO - add max_offset for cyclic and dihedral groups!
if 'maximum_frequency' in kwargs:
maximum_frequency = kwargs['maximum_frequency']
else:
maximum_frequency = None
if self._sg_id == (True, 'so3'):
return kernels.kernels_O3_act_R3(in_repr, out_repr, rings, sigma, maximum_frequency=maximum_frequency, adjoint=None)
elif self._sg_id == (False, 'so3'):
return kernels.kernels_SO3_act_R3(in_repr, out_repr, rings, sigma, maximum_frequency=maximum_frequency, adjoint=None)
elif self._sg_id[0] == False:
sg_id = self._sg_id[1:]
if isinstance(sg_id[-1], GroupElement):
# the adjoint is an O(3) group element
# convert it to an SO(3) element
# not that even if the adjoint contains the 3d inversion, we can ignore it
# (since O(3) is a direct product, the inversion commutes with any 3D rotation)
adj = sg_id[-1]
so3 = so3_group()
adj = so3.element(adj.value[1], adj.param)
sg_id = sg_id[:-1] + (adj,)
return kernels.kernels_SO3_subgroup_act_R3(in_repr, out_repr, sg_id, rings, sigma, maximum_frequency=maximum_frequency, adjoint=None)
else:
return kernels.kernels_O3_subgroup_act_R3(in_repr, out_repr, self._sg_id, rings, sigma, maximum_frequency=maximum_frequency, adjoint=None)
@property
def basespace_action(self) -> Representation:
return self._base_action
def __eq__(self, other):
if isinstance(other, GSpace3D):
return self._sg_id == other._sg_id
else:
return False
def __hash__(self):
return 1000 * hash(self.name) + hash(self._sg_id)
########################################################################################################################
[docs]def flipRot3dOnR3(maximum_frequency: int = 2) -> GSpace3D:
r"""
Describes 3D rotation and inversion symmetries in the space :math:`\R^3`.
.. todo ::
rename to invRot3dOnR3?
Args:
maximum_frequency (int): maximum frequency of :class:`~e2cnn.group.O3`'s irreps
"""
sg_id = 'o3'
return GSpace3D(sg_id, maximum_frequency=maximum_frequency)
[docs]def rot3dOnR3(maximum_frequency: int = 2) -> GSpace3D:
r"""
Describes 3D rotation symmetries in the space :math:`\R^3`.
Args:
maximum_frequency (int): maximum frequency of :class:`~e2cnn.group.SO3`'s irreps
"""
sg_id = 'so3'
return GSpace3D(sg_id, maximum_frequency=maximum_frequency)
[docs]def fullIcoOnR3() -> GSpace3D:
sg_id = True, 'ico'
return GSpace3D(sg_id, maximum_frequency=4)
[docs]def icoOnR3() -> GSpace3D:
r"""
Describes 3D rotation symmetries of a Icosahedron (or Dodecahedron) in the space :math:`\R^3`
"""
sg_id = False, 'ico'
return GSpace3D(sg_id, maximum_frequency=4)
[docs]def fullOctaOnR3() -> GSpace3D:
sg_id = True, 'octa'
return GSpace3D(sg_id, maximum_frequency=3)
[docs]def octaOnR3() -> GSpace3D:
r"""
Describes 3D rotation symmetries of an Octahedron (or Cube) in the space :math:`\R^3`
"""
sg_id = False, 'octa'
return GSpace3D(sg_id, maximum_frequency=3)
[docs]def dihedralOnR3(n: int = -1, axis: float = np.pi / 2, adjoint: GroupElement = None, maximum_frequency: int = 2) -> GSpace3D:
r"""
Describes 2D rotation symmetries along the :math:`Z` axis in the space :math:`\R^3` and :math:`\pi` rotations
along the ``axis`` in the :math:`XY` plane, i.e. the rotations inside the plane :math:`XY` and reflections around
the ``axis``.
The ``adjoint`` parameter can be a :class:`~escnn.group.GroupElement` of :class:`~escnn.group.O3`.
If not ``None`` (which is equivalent to the identity), this specifies another :math:`\SO2` subgroup of :math:`\O3`
which is adjoint to the :math:`\SO2` subgroup of rotations around the :math:`Z` axis.
If ``adjoint`` is the group element :math:`A \in \O3`, the new subgroup would then represent rotations around the
axis :math:`A^{-1} \cdot (0, 0, 1)^T`.
If ``N > 1``, the gspace models *discrete* rotations by angles which are multiple of :math:`\frac{2\pi}{N}`
(:class:`~e2cnn.group.CyclicGroup`).
Otherwise, if ``N=-1``, the gspace models *continuous* planar rotations (:class:`~e2cnn.group.SO2`).
In that case the parameter ``maximum_frequency`` is required to specify the maximum frequency of the irreps of
:class:`~e2cnn.group.SO2` (see its documentation for more details)
Args:
N (int): number of discrete rotations (integer greater than 1) or ``-1`` for continuous rotations
adjoint (GroupElement, optional): an element of :math:`\O3`
maximum_frequency (int): maximum frequency of :class:`~e2cnn.group.SO2`'s irreps if ``N = -1``
"""
assert isinstance(n, int)
assert n == -1 or n > 0
sg_id = False, 2*axis, n
if adjoint is not None:
sg_id += (adjoint,)
return GSpace3D(sg_id, maximum_frequency=maximum_frequency)
[docs]def rot2dOnR3(n: int = -1, adjoint: GroupElement = None, maximum_frequency: int = 2) -> GSpace3D:
r"""
Describes 2D rotation symmetries along the :math:`Z` axis in the space :math:`\R^3`, i.e. the rotations inside the
plane :math:`XY`.
``adjoint`` is a :class:`~escnn.group.GroupElement` of :class:`~escnn.group.O3`.
If not ``None`` (which is equivalent to the identity), this specifies another :math:`\SO2` subgroup of :math:`\O3`
which is adjoint to the :math:`\SO2` subgroup of rotations around the :math:`Z` axis.
If ``adjoint`` is the group element :math:`A \in \O3`, the new subgroup would then represent rotations around the
axis :math:`A^{-1} \cdot (0, 0, 1)^T`.
If ``N > 1``, the gspace models *discrete* rotations by angles which are multiple of :math:`\frac{2\pi}{N}`
(:class:`~e2cnn.group.CyclicGroup`).
Otherwise, if ``N=-1``, the gspace models *continuous* planar rotations (:class:`~e2cnn.group.SO2`).
In that case the parameter ``maximum_frequency`` is required to specify the maximum frequency of the irreps of
:class:`~e2cnn.group.SO2` (see its documentation for more details)
Args:
N (int): number of discrete rotations (integer greater than 1) or ``-1`` for continuous rotations
adjoint (GroupElement, optional): an element of :math:`\O3`
maximum_frequency (int): maximum frequency of :class:`~e2cnn.group.SO2`'s irreps if ``N = -1``
"""
assert isinstance(n, int)
assert n == -1 or n > 0
sg_id = False, False, n
if adjoint is not None:
sg_id += (adjoint,)
return GSpace3D(sg_id, maximum_frequency=maximum_frequency)
[docs]def conicalOnR3(n: int = -1, axis: float = np.pi / 2., adjoint: GroupElement = None, maximum_frequency: int = 2) -> GSpace3D:
assert isinstance(n, int)
assert n == -1 or n > 0
sg_id = 'cone', 2*axis, n
if adjoint is not None:
sg_id += (adjoint,)
return GSpace3D(sg_id, maximum_frequency=maximum_frequency)
[docs]def mirOnR3(axis: float = np.pi / 2, adjoint: GroupElement = None) -> GSpace3D:
r"""
Describes mirroring with respect to a plane in the space :math:`\R^3`.
.. todo ::
Document what ``axis`` and ``adjoint`` describe or change parameters, just getting a :math:`\bold{v} \in \R^3`
vector in input which specifies the mirroring axis.
"""
sg_id = 'cone', 2*axis, 1
if adjoint is not None:
sg_id += (adjoint,)
return GSpace3D(sg_id, maximum_frequency=1)
[docs]def fullCylindricalOnR3(n: int = -1, axis: float = np.pi / 2, adjoint: GroupElement = None, maximum_frequency: int = 2) -> GSpace3D:
assert isinstance(n, int)
assert n == -1 or n > 0
sg_id = True, axis, n
if adjoint is not None:
sg_id += (adjoint,)
return GSpace3D(sg_id, maximum_frequency=maximum_frequency)
[docs]def cylindricalOnR3(n: int = -1, adjoint: GroupElement = None, maximum_frequency: int = 2) -> GSpace3D:
assert isinstance(n, int)
assert n == -1 or n > 0
sg_id = True, False, n
if adjoint is not None:
sg_id += (adjoint,)
return GSpace3D(sg_id, maximum_frequency=maximum_frequency)
[docs]def invOnR3() -> GSpace3D:
r"""
Describes the inversion symmetry of the space :math:`\R^3`.
An inversion flips the sign of all coordinates, mapping a vector :math:`\bold{v} \in \R^3` to :math:`-\bold{v}`.
"""
sg_id = True, False, 1
return GSpace3D(sg_id, maximum_frequency=1)
[docs]def trivialOnR3() -> GSpace3D:
r"""
Describes the space :math:`\R^3` without considering any origin-preserving symmetry.
This is modeled by choosing trivial fiber group :math:`\{e\}`.
.. note ::
This models the symmetries of conventional *Convolutional Neural Networks* which are not equivariant to
origin preserving transformations such as rotations and reflections.
"""
sg_id = False, False, 1
return GSpace3D(sg_id, maximum_frequency=1)