# Source code for escnn.group.groups.octa

from __future__ import annotations

from escnn.group import change_basis, directsum
from escnn.group.irrep import generate_irrep_matrices_from_generators
from escnn.group.irrep import restrict_irrep
from escnn.group.utils import cycle_isclose

from .utils import *

from .so3_utils import PARAMETRIZATION as PARAMETRIZATION_SO3
from .so3_utils import PARAMETRIZATIONS
from .so3_utils import IDENTITY, _grid, _combine, _equal, _invert, _change_param, _check_param, _hash

from .so3group import _build_character, _build_irrep

import numpy as np

from typing import Tuple, Callable, Iterable, List, Dict, Any, Union


# Names exported by `from ... import *` on this module.
__all__ = ["Octahedral"]

# Golden ratio (1 + sqrt(5)) / 2.
# NOTE(review): not referenced by any code visible in this module - confirm before relying on / removing it.
_PHI = 0.5 * (1. + np.sqrt(5.))


class Octahedral(Group):
    r"""
    The octahedral group :math:`O`, i.e. the 24 rotations in :math:`SO(3)` which map a cube (equivalently, a regular
    octahedron) onto itself.

    A rotation belongs to this group iff its 3x3 rotation matrix is a signed permutation matrix (see
    :meth:`_is_element`).
    """

    PARAM = PARAMETRIZATION_SO3

    PARAMETRIZATIONS = PARAMETRIZATIONS

    def __init__(self):
        r"""

        Subgroup Structure:

        +-----------------------------------+-----------------------------------+-------------------------------------------------------------------------------------------------------------------+
        |    ``id[0]``                      |    ``id[1]``                      |    subgroup                                                                                                       |
        +===================================+===================================+===================================================================================================================+
        |    'octa'                         |                                   |   The Octahedral :math:`O` group itself                                                                           |
        +-----------------------------------+-----------------------------------+-------------------------------------------------------------------------------------------------------------------+
        |    'tetra'                        |                                   |   Tetrahedral :math:`T` subgroup                                                                                  |
        +-----------------------------------+-----------------------------------+-------------------------------------------------------------------------------------------------------------------+
        |    False                          |     N = 1, 2, 3, 4                |   :math:`C_N` of N discrete planar rotations                                                                      |
        +-----------------------------------+-----------------------------------+-------------------------------------------------------------------------------------------------------------------+
        |    True                           |     N = 2, 3, 4                   |   *dihedral* :math:`D_N` subgroup of N discrete planar rotations and out-of-plane :math:`\pi` rotation            |
        +-----------------------------------+-----------------------------------+-------------------------------------------------------------------------------------------------------------------+
        |    True                           |     1                             |   equivalent to ``(False, 2, adj)``                                                                               |
        +-----------------------------------+-----------------------------------+-------------------------------------------------------------------------------------------------------------------+

        """

        super().__init__("Octahedral", False, False)

        self._identity = self.element(IDENTITY)

        # the 24 rotational symmetries of the cube
        self._elements = [self.element(g) for g in _grid('cube')]
        assert len(self._elements) == 24

        self._generators = [
            self._elements[17],  # Cyclic Group of order 4
            self._elements[11],  # Cyclic Group of order 3
            self._elements[22],  # Cyclic Group of order 2
        ]

        self._build_representations()

    @property
    def generators(self) -> List[GroupElement]:
        return self._generators

    @property
    def identity(self) -> GroupElement:
        return self._identity

    @property
    def elements(self) -> List[GroupElement]:
        return self._elements

    @property
    def _keys(self) -> Dict[str, Any]:
        return dict()

    @property
    def subgroup_trivial_id(self):
        # the trivial subgroup is the cyclic group of order 1 (handled explicitly in `_subgroup`)
        return (False, 1)

    @property
    def subgroup_self_id(self):
        return 'octa'

    ###########################################################################
    # METHODS DEFINING THE GROUP LAW AND THE OPERATIONS ON THE GROUP'S ELEMENTS
    ###########################################################################

    def _inverse(self, element, param=PARAM):
        r"""
        Return the inverse element of the input element
        """
        return _invert(element, param=param)

    def _combine(self, e1, e2, param=PARAM, param1=None, param2=None):
        r"""
        Return the combination (group product) of the two input elements
        """
        # `_combine` here resolves to the module-level SO(3) helper, not to this method
        return _combine(e1, e2, param=param, param1=param1, param2=param2)

    def _equal(self, e1, e2, param=PARAM, param1=None, param2=None) -> bool:
        r"""
        Check if the two input values corresponds to the same element
        """
        return _equal(e1, e2, param=param, param1=param1, param2=param2)

    def _hash_element(self, element, param=PARAM):
        return _hash(element, param)

    def _repr_element(self, element, param=PARAM):
        return element.__repr__()

    def _is_element(self, element, param: str = PARAM, verbose: bool = False) -> bool:
        r"""
        Check whether ``element`` (in parametrization ``param``) is a rotation of the cube, i.e. a rotation whose
        matrix is a signed permutation matrix.
        """
        ATOL = 1e-7
        RTOL = 1e-5

        if not _check_param(element, param):
            if verbose:
                print(f"Element {element} is not a rotation")
            return False

        # convert to matrix representation
        element = self._change_param(element, param, 'MAT')

        # take absolute value of the elements
        # note that we have already ensured that the determinant is positive using `_check_param` above since it
        # checks that it is a rotation
        at = np.abs(element)

        # check if the matrix is a permutation matrix
        ans = (
            np.isclose(at.sum(axis=0), 1., atol=ATOL, rtol=RTOL).all()
            and np.isclose(at.sum(axis=1), 1., atol=ATOL, rtol=RTOL).all()
            and (np.isclose(at, 1., atol=ATOL, rtol=RTOL) | np.isclose(at, 0., atol=ATOL, rtol=RTOL)).all()
        )
        return ans

    def _change_param(self, element, p_from: str, p_to: str):
        assert p_from in self.PARAMETRIZATIONS
        assert p_to in self.PARAMETRIZATIONS
        return _change_param(element, p_from, p_to)

    ###########################################################################

    def sample(self, param: str = PARAM) -> GroupElement:
        # uniformly random element of the group; `param` is currently not used
        return self._elements[np.random.randint(self.order())]

    def testing_elements(self) -> Iterable[GroupElement]:
        r"""
        A finite number of group elements to use for testing.
        """
        return iter(self._elements)

    def __eq__(self, other):
        if not isinstance(other, Octahedral):
            return False
        else:
            return self.name == other.name

    def __hash__(self):
        # a class overriding `__eq__` without `__hash__` gets `__hash__ = None` (unhashable instances);
        # hash consistently with `__eq__`
        return hash(self.name)

    def _process_subgroup_id(self, id):
        r"""
        Normalize a subgroup id to the form ``(*sg_id, adj)`` where ``adj`` is an element of this group
        (the identity if not provided). The id ``(True, 1)`` is rewritten as ``(False, 2, adj')``.
        """
        if not isinstance(id, tuple):
            id = (id,)

        assert isinstance(id[0], bool) or isinstance(id[0], str), id[0]

        if not isinstance(id[-1], GroupElement):
            id = (*id, self.identity)

        assert id[-1].group == self

        if isinstance(id[0], bool):
            assert id[1] in [1, 2, 3, 4]

        if id[0] == True and id[1] == 1:
            # flip subgroup of the O(2) subgroup of SO(3)
            # this is equivalent to the C_2 subgroup of 180 deg rotations out of the plane (around X axis)
            # `change_axis` is the quaternion of a rotation by 2pi/3 around (1, 1, -1)/sqrt(3)
            V = np.array([1., 1., -1.])
            V /= np.linalg.norm(V)
            change_axis = np.zeros(4)
            change_axis[:3] = V * np.sin(np.pi / 3.)
            change_axis[3] = np.cos(np.pi / 3.)
            adj = self.element(change_axis, 'Q') @ id[-1]
            id = (False, 2, adj)

        return id

    def _subgroup(self, id) -> Tuple[
        Group,
        Callable[[GroupElement], GroupElement],
        Callable[[GroupElement], GroupElement]
    ]:
        r"""
        Returns:
            a tuple containing

                - the subgroup,

                - a function which maps an element of the subgroup to its inclusion in the original group and

                - a function which maps an element of the original group to the corresponding element in the subgroup (returns None if the element is not contained in the subgroup)

        Raises:
            NotImplementedError: for the tetrahedral and dihedral subgroups, which are not supported yet
            ValueError: if the id is not recognized

        """

        # `id` is expected to be already normalized by `_process_subgroup_id`
        id, adj = id[:-1], id[-1]

        # rotation axis of each supported cyclic subgroup C_N (before conjugation by `adj`)
        cyclic_axes = {
            2: np.asarray([0., 0., 1.]),
            3: np.asarray([1., 1., 1.]) / np.sqrt(3),
            4: np.asarray([0., 0., 1.]),
        }

        if id == ('octa',):
            sg = self
            parent_map = build_adjoint_map(self, ~adj)
            child_map = build_adjoint_map(self, adj)
        elif id == (False, 1):
            sg = escnn.group.cyclic_group(1)
            parent_map, child_map = build_trivial_subgroup_maps(self)
        elif id in [(False, 2), (False, 3), (False, 4)]:
            N = int(id[1])
            sg = escnn.group.cyclic_group(N)
            axis = cyclic_axes[N]
            parent_map = cn_to_octa(adj, sg, axis=axis)
            child_map = octa_to_cn(adj, sg, axis=axis)
        elif id in [('tetra',), (True, 2), (True, 3), (True, 4)]:
            # the tetrahedral and dihedral subgroups are not implemented yet
            raise NotImplementedError(f'Subgroup id {id} not implemented yet!')
        else:
            raise ValueError(f'Subgroup id {id} not recognized!')

        return sg, parent_map, child_map

    def _restrict_irrep(self, irrep: Tuple, id) -> Tuple[np.matrix, List[Tuple]]:
        r"""
        Restrict the irrep with id ``irrep`` (e.g. ``(l,)``) to the subgroup with (normalized) id ``id``.

        Returns:
            a pair containing the change of basis and the list of irreps of the subgroup which appear in the restricted irrep
        """

        sg_id, adj = id[:-1], id[-1]

        irr = self.irrep(*irrep)

        sg, _, _ = self.subgroup(id)

        if sg_id == ('octa',):
            change_of_basis = irr.change_of_basis
            irreps = irr.irreps
        elif sg_id == (False, 1):
            # the trivial group C_1 only has the trivial irrep, whose id is (0,)
            change_of_basis = np.eye(irr.size)
            irreps = [(0,)] * irr.size
        elif sg.order() > 0:
            # generic (numerical) restriction to the un-conjugated subgroup; conjugation is handled below
            change_of_basis, irreps = restrict_irrep(irr, sg_id)
        else:
            raise NotImplementedError()

        # account for the conjugation of the subgroup by `adj`
        change_of_basis = irr(adj).T @ change_of_basis

        return change_of_basis, irreps

    def _build_representations(self):
        r"""
        Build the irreps for this group

        """

        # Build all the Irreducible Representations

        # add Trivial representation
        self.irrep(0)

        # add other irreducible representations

        # Frequency 1 Wigner D matrix
        self.irrep(1)

        # Frequency 2 Wigner D matrix decomposes as a direct sum of a 2 and a 3 dimensional irrep
        self.irrep(-1)  # 3 dimensional irrep
        self.irrep(2)  # 2 dimensional irrep

        # SO(3)'s freq 3 irrep decomposes in a 1-dimensional irrep and the sum of the two previous 3 dimensional irreps
        self.irrep(3)  # 1 dimensional

        # add all the irreps to the set of representations already built for this group
        self.representations.update(**{irr.name: irr for irr in self.irreps()})

        # build the regular representation
        # N.B.: it represents the LEFT-ACTION of the elements
        self.representations['regular'] = self.regular_representation

    @property
    def trivial_representation(self) -> Representation:
        return self.irrep(0)

    @property
    def standard_representation(self) -> Representation:
        r"""
        Restriction of the standard representation of SO(3) as 3x3 rotation matrices

        """
        name = f'standard'

        if name not in self._representations:
            change_of_basis = np.array([
                [0, 0, 1],
                [1, 0, 0],
                [0, 1, 0]
            ])

            self._representations[name] = change_basis(
                self.irrep(1),
                change_of_basis=change_of_basis,
                name=name,
                supported_nonlinearities=self.irrep(1).supported_nonlinearities,
            )

        return self._representations[name]

    @property
    def cube_vertices_representation(self) -> Representation:
        # action on the 8 vertices of the cube (or faces of the octahedron)
        sgid = (False, 3)
        return self.quotient_representation(sgid, name='cube_vertices')

    @property
    def cube_faces_representation(self) -> Representation:
        # action on the 6 faces of the cube (or vertices of the octahedron)
        sgid = (False, 4)
        return self.quotient_representation(sgid, name='cube_faces')

    @property
    def cube_edges_representation(self) -> Representation:
        # action on the 12 edges of the cube or octahedron
        # NOTE(review): `_process_subgroup_id` maps `(True, 1)` to the C_2 of pi rotations around a coordinate
        # axis (X), while the stabilizer of a cube edge is a C_2 around a face-diagonal (edge-midpoint) axis.
        # These two C_2 subgroups are not conjugate in O, so this 12-dim quotient representation may not be the
        # permutation action on the edges - verify.
        sgid = (True, 1)
        return self.quotient_representation(sgid, name='cube_edges')

    def irrep(self, l: int) -> IrreducibleRepresentation:
        r"""
        Build the irrep of :math:`O` identified by the integer :math:`l`.

        For :math:`l = 0, 1`, the irrep is equivalent to the Wigner D matrix of the same frequency :math:`l`.
        For :math:`l=2`, the 5-dimensional Wigner D matrix is decomposed in a 3-dimensional and a 2-dimensional
        irreps, here identified respectively by :math:`l=-1` and :math:`l=2`.
        For :math:`l=3`, the 7-dimensional Wigner D matrix is decomposed in a 1-dimensional irrep and the two previous
        3-dimensional irreps, here identified respectively by :math:`l=3` and :math:`l=1, -1`.

        Args:
            l (int): identifier of the irrep, in ``{-1, 0, 1, 2, 3}``

        Returns:
            the corresponding irrep

        """

        assert isinstance(l, int)
        assert -1 <= l <= 3

        name = f"irrep_{l}"
        id = (l,)

        if id not in self._irreps:
            if l == 0:
                # Trivial representation
                irrep = build_trivial_irrep()
                character = build_trivial_character()
                supported_nonlinearities = ['pointwise', 'norm', 'gated', 'gate']
                self._irreps[id] = IrreducibleRepresentation(self, id, name, irrep, 1, 'R',
                                                             supported_nonlinearities=supported_nonlinearities,
                                                             character=character,
                                                             )
            elif l == 1:
                # Irreducible Representation equivalent to the frequency 1 Wigner D matrices
                irrep = _build_irrep(l)
                character = _build_character(l)
                supported_nonlinearities = ['norm', 'gated']
                self._irreps[id] = IrreducibleRepresentation(self, id, name, irrep, 3, 'R',
                                                             supported_nonlinearities=supported_nonlinearities,
                                                             character=character)
            elif l == -1 or l == 2:
                irrep = _build_octa_irrep(self, l)
                supported_nonlinearities = ['norm', 'gated']
                self._irreps[id] = IrreducibleRepresentation(self, id, name, irrep, irrep[self.identity].shape[0], 'R',
                                                             supported_nonlinearities=supported_nonlinearities)
            elif l == 3:
                irrep = _build_octa_irrep(self, l)
                supported_nonlinearities = ['norm', 'gated', 'concatenated']
                self._irreps[id] = IrreducibleRepresentation(self, id, name, irrep, irrep[self.identity].shape[0], 'R',
                                                             supported_nonlinearities=supported_nonlinearities)
            else:
                raise ValueError()

        return self._irreps[id]

    _cached_group_instance = None

    @classmethod
    def _generator(cls) -> 'Octahedral':
        # lazily build and reuse a single instance of the group
        if cls._cached_group_instance is None:
            cls._cached_group_instance = Octahedral()

        return cls._cached_group_instance
def _is_axis_aligned(v: np.ndarray, n: int, verbose: bool = False, ATOL=1e-7, RTOL = 1e-5) -> bool: norm = np.linalg.norm(v) v = v / norm if n == 2: # rotation of order 2 # the rotation axis need to be aligned with one of the axes X, Y, Z or to # the bisector of a pair of axes XY, XZ, YZ # There are in total 6 + 12 possible vectors # remove sign ambiguity v = np.abs(v) axes = np.eye(3) bisectors = np.array([ [1., 0., 1.], [1., 1., 0.], [0., 1., 1.], ]) / np.sqrt(2) ans = ( # axes aligned np.allclose(v, axes[0], atol=ATOL, rtol=RTOL) or np.allclose(v, axes[1], atol=ATOL, rtol=RTOL) or np.allclose(v, axes[2], atol=ATOL, rtol=RTOL) # bisectors aligned or np.allclose(v, bisectors[0], atol=ATOL, rtol=RTOL) or np.allclose(v, bisectors[1], atol=ATOL, rtol=RTOL) or np.allclose(v, bisectors[2], atol=ATOL, rtol=RTOL) ) if not ans and verbose: print(f'Rotation by a multiple of 2pi/{n} not aligned with a {n}-fold rotational axis of the Octahedron.') return ans elif n == 4: # rotation of order 4 # the rotation axis need to be aligned with one of the axes X, Y, Z # There are in total 6 possible vectors # remove sign ambiguity v = np.abs(v) axes = np.eye(3) ans = ( np.allclose(v, axes[0], atol=ATOL, rtol=RTOL) or np.allclose(v, axes[1], atol=ATOL, rtol=RTOL) or np.allclose(v, axes[2], atol=ATOL, rtol=RTOL) ) if not ans and verbose: print(f'Rotation by a multiple of 2pi/{n} not aligned with a {n}-fold rotational axis of the Octahedron.') return ans elif n == 3: # rotation or order 3 # the rotation axis need to pass through one of the vertices of the cube # There are in total 8 possible vectors # remove sign ambiguity v = np.abs(v) # since the vector is normalized, `v` should now be `(1, 1, 1)^T * 1/sqrt(3)` ans = np.allclose(v, 1./np.sqrt(3), atol=ATOL, rtol=RTOL) if not ans and verbose: print(f'Rotation by a multiple of 2pi/{n} not aligned with a {n}-fold rotational axis of the Octahedron.') return ans else: raise ValueError('The rotation order must be one of {2, 3, 4}.') 
#############################################
# SUBGROUPS MAPS
#############################################


# C_N #####################################

def octa_to_cn(adj: GroupElement, cn: escnn.group.CyclicGroup, axis: np.ndarray):
    r"""
    Build the map from the octahedral group to the cyclic subgroup ``cn`` of rotations around ``axis``,
    conjugated by ``adj``.

    An element ``e`` is first conjugated as ``adj @ e @ ~adj`` and then checked to be a rotation around ``axis``.

    Args:
        adj: element of the octahedral group used to conjugate the subgroup
        cn: the cyclic group, of order 2, 3 or 4
        axis: unit-norm rotation axis, aligned with a ``cn.order()``-fold symmetry axis of the octahedron

    Returns:
        a function mapping an element of the octahedral group to the corresponding element of ``cn``,
        or ``None`` if it is not contained in the subgroup
    """
    assert isinstance(adj.group, Octahedral)
    assert axis.shape == (3,)
    assert np.isclose(np.linalg.norm(axis), 1.)
    assert cn.order() in [2, 3, 4]
    assert _is_axis_aligned(axis, cn.order())

    def _map(e: GroupElement, cn=cn, adj=adj, axis=axis):
        octa = adj.group
        assert e.group == octa

        e = adj @ e @ (~adj)

        q = e.to('Q')
        v, w = q[:3], q[-1]
        n = np.linalg.norm(v)

        if np.allclose(n, 0.):
            return cn.identity

        direction = v / n
        if np.allclose(direction, axis):
            c = w
        elif np.allclose(direction, -axis):
            # `q` and `-q` encode the same rotation: flip the sign of the quaternion so that its vector part
            # points along `axis` (otherwise these subgroup elements would be wrongly rejected)
            c = -w
        else:
            # not a rotation around `axis`
            return None

        # now the quaternion is (sin(theta/2) * axis, cos(theta/2)) with sin(theta/2) = n >= 0,
        # so theta lies in [0, 2pi]
        theta = 2 * np.arctan2(n, c)
        try:
            return cn.element(theta, 'radians')
        except ValueError:
            # the angle is not a multiple of 2pi / N
            return None

    return _map


def cn_to_octa(adj: GroupElement, cn: escnn.group.CyclicGroup, axis: np.ndarray):
    r"""
    Build the inclusion map of the cyclic group ``cn`` into the octahedral group: an element of ``cn`` with angle
    :math:`\theta` is mapped to ``~adj @ R @ adj`` where ``R`` is the rotation by :math:`\theta` around ``axis``.

    Args:
        adj: element of the octahedral group used to conjugate the subgroup
        cn: the cyclic group, of order 2, 3 or 4
        axis: unit-norm rotation axis, aligned with a ``cn.order()``-fold symmetry axis of the octahedron

    Returns:
        a function mapping an element of ``cn`` to the corresponding element of the octahedral group
    """
    assert isinstance(adj.group, Octahedral)
    assert axis.shape == (3,)
    assert np.isclose(np.linalg.norm(axis), 1.)
    assert cn.order() in [2, 3, 4]
    assert _is_axis_aligned(axis, cn.order())

    def _map(e: GroupElement, cn=cn, adj=adj, axis=axis):
        assert e.group == cn

        octa = adj.group

        # quaternion (sin(theta/2) * axis, cos(theta/2)) of the rotation by theta around `axis`
        theta_2 = e.to('radians') / 2.
        q = np.empty(4)
        q[:3] = axis * np.sin(theta_2)
        q[-1] = np.cos(theta_2)

        return (~adj) @ octa.element(q, 'Q') @ adj

    return _map


#############################################
# Generate irreps
#############################################

from joblib import Memory
from escnn.group import __cache_path__

cache = Memory(__cache_path__, verbose=2)


def _build_octa_irrep(octa: Octahedral, l: int):
    r"""
    Build the matrices of the irrep ``l`` (one of -1, 2, 3) as a dict mapping each element of ``octa`` to its matrix.

    The heavy computation is done by the joblib-cached ``_build_octa_irrep_picklable``, which returns plain
    ``(value, parametrization, matrix)`` tuples; here they are converted back into elements of ``octa``.
    """
    irreps = _build_octa_irrep_picklable(octa, l)
    return {
        octa.element(g, param): v
        for g, param, v in irreps
    }


@cache.cache(ignore=['octa'])
def _build_octa_irrep_picklable(octa: Octahedral, l: int):
    r"""
    Compute the matrices of the irrep ``l`` (one of -1, 2, 3) of the octahedral group from the images of a set of
    generators, returning a list of ``(element value, element parametrization, matrix)`` tuples.

    The result is cached on disk by joblib; ``octa`` is excluded from the cache key.

    Raises:
        ValueError: if ``l`` is not one of -1, 2, 3
    """
    if l not in (-1, 2, 3):
        raise ValueError()

    # matrix coefficients from https://arxiv.org/pdf/1110.6376.pdf
    # the matrix coefficients there are expressed wrt a different set of generators
    # we first build this set of generators
    r3 = octa.generators[0]
    r = r3 @ r3 @ r3
    k = octa.elements[0]
    s = octa.generators[1]
    t = ~s @ k @ s @ r

    if l == -1:
        # Representation of `t`
        rho_t = np.array([
            [1., 0., 0.],
            [0., 0., 1.],
            [0., 1., 0.],
        ])
        # Representation of `k`
        rho_k = np.array([
            [1., 0., 0.],
            [0., -1., 0.],
            [0., 0., -1.],
        ])
        # Representation of `s`
        rho_s = np.array([
            [0., 1., 0.],
            [0., 0., 1.],
            [1., 0., 0.],
        ])

        # https://arxiv.org/pdf/1110.6376.pdf defines the irrep `l = 1` (denoted by 3 there) as our
        # `standard_representation`, which is expressed on a different basis than the Wigner D matrix with l=1.
        # Since `l=-1` (their 3') is defined as the tensor product between `l=1` and `l=3` (their 1')
        # we apply the inverse change of basis used in `standard_representation` to ensure that
        # `-1 = 1 \tensor 3` for us as well
        change_of_basis = np.array([
            [0, 0, 1],
            [1, 0, 0],
            [0, 1, 0]
        ])

        rho_t = change_of_basis.T @ rho_t @ change_of_basis
        rho_k = change_of_basis.T @ rho_k @ change_of_basis
        rho_s = change_of_basis.T @ rho_s @ change_of_basis

    elif l == 2:
        # Representation of `t`
        rho_t = np.array([
            [0., 1.],
            [1., 0.],
        ])
        # Representation of `k`
        rho_k = np.array([
            [1., 0.],
            [0., 1.],
        ])
        # Representation of `s`
        rho_s = 0.5 * np.array([
            [-1., -np.sqrt(3)],
            [np.sqrt(3), -1.],
        ])

    else:  # l == 3
        # Representation of `t`
        rho_t = np.array([[-1.]])
        # Representation of `k`
        rho_k = np.array([[1.]])
        # Representation of `s`
        rho_s = np.array([[1.]])

    generators = [
        (t, rho_t),
        (k, rho_k),
        (s, rho_s),
    ]

    irreps = generate_irrep_matrices_from_generators(octa, generators)

    # return picklable data only (see `_build_octa_irrep`)
    return [
        (g.value, g.param, v) for g, v in irreps.items()
    ]