from __future__ import annotations
from escnn.group import change_basis, directsum
from escnn.group.irrep import generate_irrep_matrices_from_generators
from escnn.group.irrep import restrict_irrep
from escnn.group.utils import cycle_isclose
from .utils import *
from .so3_utils import PARAMETRIZATION as PARAMETRIZATION_SO3
from .so3_utils import PARAMETRIZATIONS
from .so3_utils import IDENTITY, _grid, _combine, _equal, _invert, _change_param, _check_param, _hash
from .so3group import _build_character, _build_irrep
import numpy as np
from typing import Tuple, Callable, Iterable, List, Dict, Any, Union
__all__ = ["Icosahedral"]
_PHI = (1. + np.sqrt(5)) / 2
[docs]class Icosahedral(Group):
PARAM = PARAMETRIZATION_SO3
PARAMETRIZATIONS = PARAMETRIZATIONS
def __init__(self):
r"""
Subgroup Structure:
+-----------------------------------+-----------------------------------+-------------------------------------------------------------------------------------------------------------------+
| ``id[0]`` | ``id[1]`` | subgroup |
+===================================+===================================+===================================================================================================================+
| 'ico' | | The Icosahedral :math:`I` group itself |
+-----------------------------------+-----------------------------------+-------------------------------------------------------------------------------------------------------------------+
| 'tetra' | | Tetrahedral :math:`T` subgroup |
+-----------------------------------+-----------------------------------+-------------------------------------------------------------------------------------------------------------------+
| False | N = 1, 2, 3, 5 | :math:`C_N` of N discrete planar rotations |
+-----------------------------------+-----------------------------------+-------------------------------------------------------------------------------------------------------------------+
| True | N = 2, 3, 5 | *dihedral* :math:`D_N` subgroup of N discrete planar rotations and out-of-plane :math:`\pi` rotation |
+-----------------------------------+-----------------------------------+-------------------------------------------------------------------------------------------------------------------+
| True | 1 | equivalent to ``(False, 2, adj)`` |
+-----------------------------------+-----------------------------------+-------------------------------------------------------------------------------------------------------------------+
"""
super(Icosahedral, self).__init__("Icosahedral", False, False)
self._identity = self.element(IDENTITY)
self._elements = [self.element(g) for g in _grid('ico')]
assert len(self._elements) == 60
# self._identity = self._elements[3]
self._generators = [
self._elements[21], # Cyclic Group of order 5
self._elements[0], # Cyclic Group of order 2
# self._elements[38] # Cyclic Group of order 3
# self._elements[0] @ self._elements[21] # Cyclic Group of order 3
]
self._build_representations()
@property
def generators(self) -> List[GroupElement]:
return self._generators
@property
def identity(self) -> GroupElement:
return self._identity
@property
def elements(self) -> List[GroupElement]:
return self._elements
@property
def _keys(self) -> Dict[str, Any]:
return dict()
@property
def subgroup_trivial_id(self):
return (False, 1)
@property
def subgroup_self_id(self):
return 'ico'
###########################################################################
# METHODS DEFINING THE GROUP LAW AND THE OPERATIONS ON THE GROUP'S ELEMENTS
###########################################################################
def _inverse(self, element, param=PARAM):
r"""
Return the inverse element of the input element
"""
return _invert(element, param=param)
def _combine(self, e1, e2,
param=PARAM,
param1=None,
param2=None
):
r"""
Return the sum of the two input elements
"""
return _combine(e1, e2, param=param, param1=param1, param2=param2)
def _equal(self, e1, e2,
param=PARAM,
param1=None,
param2=None,
) -> bool:
r"""
Check if the two input values corresponds to the same element
"""
return _equal(e1, e2, param=param, param1=param1, param2=param2)
def _hash_element(self, element, param=PARAM):
return _hash(element, param)
def _repr_element(self, element, param=PARAM):
return element.__repr__()
def _is_element(self, element,
param: str = PARAM,
verbose: bool = False,
) -> bool:
ATOL = 1e-7
RTOL = 1e-5
angle_ATOL = 1e-6
if not _check_param(element, param):
if verbose:
print(f"Element {element} is not a rotation")
return False
element = self._change_param(element, param, 'Q')
v = element[:3]
theta = 2*np.arccos(np.clip(element[3], -1., 1.))
if cycle_isclose(theta, 0., 2*np.pi, atol=angle_ATOL, rtol=0.):
return True
v = v / np.sin(theta/2.)
if cycle_isclose(theta, 0., 2 * np.pi / 2, atol=angle_ATOL, rtol=0.):
# rotation of order 2
return _is_axis_aligned(v, 2, verbose=verbose, ATOL=ATOL, RTOL=RTOL)
elif cycle_isclose(theta, 0., 2 * np.pi / 3, atol=angle_ATOL, rtol=0.):
# rotation or order 3
return _is_axis_aligned(v, 3, verbose=verbose, ATOL=ATOL, RTOL=RTOL)
elif cycle_isclose(theta, 0., 2 * np.pi / 5, atol=angle_ATOL, rtol=0.):
# rotation or order 5
return _is_axis_aligned(v, 5, verbose=verbose, ATOL=ATOL, RTOL=RTOL)
else:
if verbose:
print(f'Group element is neither a 2-fold, a 3-fold nor a 5-fold rotation.')
return False
def _change_param(self, element, p_from: str, p_to: str):
assert p_from in self.PARAMETRIZATIONS
assert p_to in self.PARAMETRIZATIONS
return _change_param(element, p_from, p_to)
###########################################################################
def sample(self, param: str = PARAM) -> GroupElement:
return self._elements[
np.random.randint(self.order())
]
[docs] def testing_elements(self) -> Iterable[GroupElement]:
r"""
A finite number of group elements to use for testing.
"""
return iter(self._elements)
def change_param(self, element, p_from, p_to):
return _change_param(element, p_from, p_to)
def __eq__(self, other):
if not isinstance(other, Icosahedral):
return False
else:
return self.name == other.name
def _process_subgroup_id(self, id):
if not isinstance(id, tuple):
id = (id,)
assert isinstance(id[0], bool) or isinstance(id[0], str), id[0]
if not isinstance(id[-1], GroupElement):
id = (*id, self.identity)
assert id[-1].group == self
if isinstance(id[0], bool):
assert id[1] in [1, 2, 3, 5]
if id[0] == True and id[1] == 1:
# flip subgroup of the O(2) subgroup of SO(3)
# this is equivalent to the C_2 subgroup of 180 deg rotations out of the plane (around X axis)
# V = np.asarray([0., -_PHI, 1 / _PHI])
V = np.array([1., 1., -1.])
V /= np.linalg.norm(V)
change_axis = np.zeros(4)
change_axis[:3] = V * np.sin(np.pi/3.)
change_axis[3] = np.cos(np.pi/3.)
adj = self.element(change_axis, 'Q') @ id[-1]
id = (False, 2, adj)
return id
def _subgroup(self, id) -> Tuple[
Group,
Callable[[GroupElement], GroupElement],
Callable[[GroupElement], GroupElement]
]:
r"""
Returns:
a tuple containing
- the subgroup,
- a function which maps an element of the subgroup to its inclusion in the original group and
- a function which maps an element of the original group to the corresponding element in the subgroup (returns None if the element is not contained in the subgroup)
"""
# TODO : implement this!
sg = None
parent_map = None
child_map = None
id, adj = id[:-1], id[-1]
if id == ('ico',):
sg = self
parent_map = build_adjoint_map(self, ~adj)
child_map = build_adjoint_map(self, adj)
elif id == ('tetra',):
raise NotImplementedError()
elif id == (False, 1):
sg = escnn.group.cyclic_group(1)
parent_map, child_map = build_trivial_subgroup_maps(self)
elif id == (False, 2):
sg = escnn.group.cyclic_group(2)
axis = np.asarray([0., 0., 1.])
parent_map = cn_to_ico(adj, sg, axis=axis)
child_map = ico_to_cn(adj, sg, axis=axis)
elif id == (False, 3):
sg = escnn.group.cyclic_group(3)
axis = np.asarray([1., 1., 1.]) / np.sqrt(3)
parent_map = cn_to_ico(adj, sg, axis=axis)
child_map = ico_to_cn(adj, sg, axis=axis)
elif id == (False, 5):
sg = escnn.group.cyclic_group(5)
axis = np.asarray([_PHI, 0., 1.])
axis /= np.linalg.norm(axis)
parent_map = cn_to_ico(adj, sg, axis=axis)
child_map = ico_to_cn(adj, sg, axis=axis)
elif id == (True, 2):
sg = escnn.group.dihedral_group(2)
parent_map, child_map = None, None
raise NotImplementedError()
elif id == (True, 3):
sg = escnn.group.dihedral_group(3)
parent_map, child_map = None, None
raise NotImplementedError()
elif id == (True, 5):
sg = escnn.group.dihedral_group(5)
parent_map, child_map = None, None
raise NotImplementedError()
else:
raise ValueError(f'Subgroup id {id} not recognized!')
return sg, parent_map, child_map
def _restrict_irrep(self, irrep: str, id) -> Tuple[np.matrix, List[str]]:
r"""
Returns:
a pair containing the change of basis and the list of irreps of the subgroup which appear in the restricted irrep
"""
sg_id, adj = id[:-1], id[-1]
irr = self.irrep(*irrep)
sg, _, _ = self.subgroup(id)
irreps = []
change_of_basis = None
try:
if sg_id == ('ico', ):
change_of_basis = irr.change_of_basis
irreps = irr.irreps
elif sg_id == (False, 1):
change_of_basis = np.eye(irr.size)
irreps = [(0,)]*irr.size
else:
raise NotImplementedError()
except NotImplementedError:
if sg.order() > 0:
change_of_basis, irreps = restrict_irrep(irr, sg_id)
else:
raise
change_of_basis = self.irrep(*irrep)(adj).T @ change_of_basis
return change_of_basis, irreps
def _build_representations(self):
r"""
Build the irreps for this group
"""
# Build all the Irreducible Representations
# add Trivial representation
self.irrep(0)
# add other irreducible representations
self.irrep(1)
self.irrep(2)
# SO(3)'s freq 3 irrep decomposes in another 3 dimensional irrep and a 4 dimensional one
self.irrep(3)
self.irrep(4)
# add all the irreps to the set of representations already built for this group
self.representations.update(**{irr.name: irr for irr in self.irreps()})
# build the regular representation
# N.B.: it represents the LEFT-ACTION of the elements
self.representations['regular'] = self.regular_representation
@property
def trivial_representation(self) -> Representation:
return self.irrep(0)
@property
def standard_representation(self) -> Representation:
r"""
Restriction of the standard representation of SO(3) as 3x3 rotation matrices
"""
name = f'standard'
if name not in self._representations:
change_of_basis = np.array([
[0, 0, 1],
[1, 0, 0],
[0, 1, 0]
])
self._representations[name] = change_basis(
self.irrep(1),
change_of_basis=change_of_basis,
name=name,
supported_nonlinearities=self.irrep(1).supported_nonlinearities,
)
return self._representations[name]
@property
def ico_vertices_representation(self) -> Representation:
# action on the 12 vertices of the icosahedron (or faces of the dodecahedron)
# quotient repr wrt C_5 subgroup?
return self.quotient_representation((False, 5), name='ico_vertices')
@property
def ico_faces_representation(self) -> Representation:
# action on the 20 faces of the icosahedron (or vertices of the dodecahedron)
# quotient repr wrt C_3 subgroup?
return self.quotient_representation((False, 3), name='ico_faces')
@property
def ico_edges_representation(self) -> Representation:
# action on the 30 edges of the icosahedron or dodecahedron
# quotient repr wrt C_2 subgroup
# n.b.: C_2 is the symmetry group of an edge
return self.quotient_representation((False, 2), name='ico_edges')
[docs] def bl_irreps(self, L: int) -> List[Tuple]:
r"""
Returns a list containing the id of all irreps of frequency smaller or equal to ``L``.
This method is useful to easily specify the irreps to be used to instantiate certain objects, e.g. the
Fourier based non-linearity :class:`~escnn.nn.FourierPointwise`.
"""
assert 0 <= L <= 4, (L)
return [(l,) for l in range(L+1)]
[docs] def bl_regular_representation(self, L: int) -> Representation:
r"""
Band-Limited regular representation up to frequency ``L`` (included).
Args:
L(int): max frequency
"""
assert isinstance(L, int)
assert 0 <= L <= 4
name = f'regular_{L}'
if name not in self._representations:
irreps = []
for l in range(L + 1):
irr = self.irrep(l)
irreps += [irr] * irr.size
self._representations[name] = directsum(irreps, name=name)
return self._representations[name]
[docs] def irrep(self, l: int) -> IrreducibleRepresentation:
r"""
Build the irrep of :math:`I` identified by the non-negative integer :math:`l`.
For :math:`l = 0, 1, 2`, the irrep is equivalent to the Wigner D matrix of the same frequency :math:`l`.
For :math:`l=3`, the 7-dimensional Wigner D matrix is decomposed in a 3-dimensional and a 4-dimensional irrep,
here identified respectively by :math:`l=3` and :math:`l=4`.
Args:
l (int): identifier of the irrep
Returns:
the corresponding irrep
"""
assert isinstance(l, int)
assert 0 <= l <= 4
name = f"irrep_{l}"
id = (l,)
if id not in self._irreps:
if l == 0:
# Trivial representation
irrep = build_trivial_irrep()
character = build_trivial_character()
supported_nonlinearities = ['pointwise', 'norm', 'gated', 'gate']
self._irreps[id] = IrreducibleRepresentation(self, id, name, irrep, 1, 'R',
supported_nonlinearities=supported_nonlinearities,
character=character,
frequency=0
)
elif l <= 2:
# other Irreducible Representations which are equivalent to Wigner D matrices
# irrep = lambda element, l=l: _wigner_d_matrix(element.to(element.param), l=l, param=element.param)
# character = lambda element, l=l: _character(element.to(element.param), l=l, param=element.param)
irrep = _build_irrep(l)
character = _build_character(l)
supported_nonlinearities = ['norm', 'gated']
self._irreps[id] = IrreducibleRepresentation(self, id, name, irrep, 2*l+1, 'R',
supported_nonlinearities=supported_nonlinearities,
character=character,
frequency=l)
elif l == 3 or l == 4:
irrep = _build_ico_irrep(self, l)
supported_nonlinearities = ['norm', 'gated']
self._irreps[id] = IrreducibleRepresentation(self, id, name, irrep, irrep[self.identity].shape[0], 'R',
supported_nonlinearities=supported_nonlinearities,
frequency=l)
else:
raise ValueError()
return self._irreps[id]
_cached_group_instance = None
@classmethod
def _generator(cls) -> 'Icosahedral':
if cls._cached_group_instance is None:
cls._cached_group_instance = Icosahedral()
return cls._cached_group_instance
def _is_axis_aligned(v: np.ndarray, n: int, verbose: bool = False, ATOL=1e-7, RTOL = 1e-5) -> bool:
norm = np.linalg.norm(v)
v = v / norm
if n == 2:
# rotation of order 2
# the rotation axis need to pass through the center of an edge of the icosahedron or the dodecahedron
# These 30 points can be found as any cyclic permutation or change of sign of each element of
# these 2 vectors:
p = np.array([1., 0., 0.]) # 6 combinations
q = np.array([_PHI + 1, _PHI, 1.]) # 24 combinations
q /= np.linalg.norm(q)
# remove sign ambiguity
v = np.abs(v)
# fix a permutation making the highest value first
m = np.argmax(v)
v = np.roll(v, -m)
ans = np.allclose(v, p, atol=ATOL, rtol=RTOL) or np.allclose(v, q, atol=ATOL, rtol=RTOL)
if not ans and verbose:
print(f'Rotation by a multiple of pi/2 not aligned with a 2-fold rotational axis of the Icosahedron.')
return ans
elif n == 3:
# rotation or order 3
# the rotation axis need to pass through a vertex of the dodecahedron
# These 20 points can be found as any cyclic permutation or change of sign of each element of
# these 2 vectors:
p = np.array([1., 1., 1.]) # 8 combinations
q = np.array([_PHI, 1 / _PHI, 0.]) # 12 combinations
p /= np.linalg.norm(p)
q /= np.linalg.norm(q)
# remove sign ambiguity
v = np.abs(v)
# fix a permutation making the highest value first
m = np.argmax(v)
v = np.roll(v, -m)
ans = np.allclose(v, p, atol=ATOL, rtol=RTOL) or np.allclose(v, q, atol=ATOL, rtol=RTOL)
if not ans and verbose:
print(f'Rotation by a multiple of 2pi/3 not aligned with a 3-fold rotational axis of the Icosahedron.')
return ans
elif n == 5:
# rotation or order 5
# the rotation axis need to pass through a vertex of the icosahedron
# These 12 points can be found as any cyclic permutation or change of sign of each element of this vector:
p = np.array([_PHI, 0., 1.]) # 12 combinations
p /= np.linalg.norm(p)
# remove sign ambiguity
v = np.abs(v)
# fix a permutation making the highest value first
m = np.argmax(v)
v = np.roll(v, -m)
ans = np.allclose(v, p, atol=ATOL, rtol=RTOL)
if not ans and verbose:
print(f'Rotation by a multiple of 2pi/5 not aligned with a 5-fold rotational axis of the Icosahedron.')
return ans
else:
raise ValueError('The rotation order must be one of {2, 3, 5}.')
#############################################
# SUBGROUPS MAPS
#############################################
# C_N #####################################
def ico_to_cn(adj: GroupElement, cn: escnn.group.CyclicGroup, axis: np.ndarray):
assert isinstance(adj.group, Icosahedral)
assert axis.shape == (3,)
assert np.isclose(np.linalg.norm(axis), 1.)
assert cn.order() in [2, 3, 5]
assert _is_axis_aligned(axis, cn.order())
def _map(e: GroupElement, cn=cn, adj=adj, axis=axis):
ico = adj.group
assert e.group == ico
e = adj @ e @ (~adj)
e = e.to('Q')
v = e[:3]
n = np.linalg.norm(v)
if np.allclose(n, 0.):
return cn.identity
elif np.allclose(v / n, axis):
# if the rotation is along the axis
s, c = n, e[-1]
theta = 2 * np.arctan2(s, c)
try:
return cn.element(theta, 'radians')
except ValueError:
return None
else:
return None
return _map
def cn_to_ico(adj: GroupElement, cn: escnn.group.CyclicGroup, axis: np.ndarray):
assert isinstance(adj.group, Icosahedral)
assert axis.shape == (3,)
assert np.isclose(np.linalg.norm(axis), 1.)
assert cn.order() in [2, 3, 5]
assert _is_axis_aligned(axis, cn.order())
def _map(e: GroupElement, cn=cn, adj=adj, axis=axis):
assert e.group == cn
ico = adj.group
theta_2 = e.to('radians') / 2.
q = np.empty(4)
q[:3] = axis * np.sin(theta_2)
q[-1] = np.cos(theta_2)
return (~adj) @ ico.element(q, 'Q') @ adj
return _map
#############################################
# Generate irreps
#############################################
from joblib import Memory
from escnn.group import __cache_path__
cache = Memory(__cache_path__, verbose=2)
def _build_ico_irrep(ico: Icosahedral, l: int):
# To enable caching, the output of _build_ico_irrep_picklable needs to be picklable so it can not return a
# dictionary with group elements as keys. In this method, we retrieved the cached results and wrap the keys into
# group elements again
irreps = _build_ico_irrep_picklable(ico, l)
return {
ico.element(g, param): v
for g, param, v in irreps
}
@cache.cache(ignore=['ico'])
def _build_ico_irrep_picklable(ico: Icosahedral, l: int) -> List[Tuple]:
# To enable caching, the output of this method needs to be picklable so we can not return a dictionary with
# group elements as keys
if l == 3:
# Representation of the generator of the cyclic subgroup of order 5
rho_p = np.zeros((3, 3))
rho_p[0, 0] = rho_p[1, 1] = np.cos(144 / 180. * np.pi)
rho_p[1, 0] = np.sin(144 / 180. * np.pi)
rho_p[0, 1] = -np.sin(144 / 180. * np.pi)
rho_p[2, 2] = 1.
# Representation of the generator of the cyclic subgroup of order 2
rho_q = np.zeros((3, 3))
rho_q[0, 0] = 1. / np.sqrt(5)
rho_q[0, 2] = - 2. / np.sqrt(5)
rho_q[1, 1] = - 1
rho_q[2, 0] = - 2. / np.sqrt(5)
rho_q[2, 2] = - 1. / np.sqrt(5)
elif l == 4:
# Representation of the generator of the cyclic subgroup of order 5
rho_p = np.zeros((4, 4))
rho_p[0, 0] = rho_p[1, 1] = np.cos(72 / 180. * np.pi)
rho_p[1, 0] = np.sin(72 / 180. * np.pi)
rho_p[0, 1] = -np.sin(72 / 180. * np.pi)
rho_p[2, 2] = rho_p[3, 3] = np.cos(144 / 180. * np.pi)
rho_p[3, 2] = np.sin(144 / 180. * np.pi)
rho_p[2, 3] = -np.sin(144 / 180. * np.pi)
# Representation of the generator of the cyclic subgroup of order 2
rho_q = np.zeros((4, 4))
rho_q[0, 2] = -1
rho_q[1, 1] = 2. / np.sqrt(5)
rho_q[1, 3] = 1. / np.sqrt(5)
rho_q[2, 0] = -1
rho_q[3, 1] = 1. / np.sqrt(5)
rho_q[3, 3] = -2. / np.sqrt(5)
else:
raise ValueError()
generators = [
(ico._generators[0], rho_p),
(ico._generators[1], rho_q),
]
irreps = generate_irrep_matrices_from_generators(ico, generators)
return [
(k.value, k.param, v)
for k, v in irreps.items()
]