from __future__ import annotations
from e2cnn.group import Group
from e2cnn.group import IrreducibleRepresentation, Representation
from e2cnn.group import utils
import numpy as np
import math
from typing import List, Tuple, Callable, Iterable
# Names exported by ``from ... import *``: only the group class itself is public.
__all__ = ["CyclicGroup"]

# Module-level cache of CyclicGroup instances, keyed by the group order N.
# It is filled lazily by ``CyclicGroup._generator``.
_cached_group_instances = dict()
class CyclicGroup(Group):

    def __init__(self, N: int):
        r"""
        Build an instance of the cyclic group :math:`C_N` which contains :math:`N` discrete planar rotations.

        The group elements are :math:`\{e, r, r^2, r^3, \dots, r^{N-1}\}`, with group law
        :math:`r^a \cdot r^b = r^{\ a + b \!\! \mod \!\! N \ }`.
        The cyclic group :math:`C_N` is isomorphic to the integers *modulo* ``N``.
        For this reason, elements are stored as the integers between :math:`0` and :math:`N-1`, where the :math:`k`-th
        element can also be interpreted as the discrete rotation by :math:`k\frac{2\pi}{N}`.

        Args:
            N (int): order of the group (a positive integer)

        """
        assert (isinstance(N, int) and N > 0)

        # NOTE(review): the two flags are presumably (continuous=False, abelian=True) in Group.__init__ — confirm
        super(CyclicGroup, self).__init__("C%d" % N, False, True)

        # Element k is the integer k, i.e. the rotation r^k
        self.elements = list(range(N))
        self.elements_names = ['e'] + ['r%d' % i for i in range(1, N)]
        self.identity = 0

        self._build_representations()

    def inverse(self, element: int) -> int:
        r"""
        Return the inverse element :math:`r^{-j \mod N}` of the input element :math:`r^j`, specified by the input
        integer :math:`j` (``element``).

        Args:
            element (int): a group element :math:`r^j`

        Returns:
            its opposite :math:`r^{-j \mod N}`

        """
        return (-element) % self.order()

    def combine(self, e1: int, e2: int) -> int:
        r"""
        Return the composition of the two input elements.

        Given two integers :math:`a` and :math:`b` representing the elements :math:`r^a` and :math:`r^b`, the method
        returns the integer :math:`a + b \mod N` representing the element :math:`r^{a + b \mod N}`.

        Args:
            e1 (int): a group element :math:`r^a`
            e2 (int): another group element :math:`r^b`

        Returns:
            their composition :math:`r^{a+b \mod N}`

        """
        return (e1 + e2) % self.order()

    def equal(self, e1: int, e2: int) -> bool:
        r"""
        Check if the two input values correspond to the same element.

        Args:
            e1 (int): an element
            e2 (int): another element

        Returns:
            whether they are the same element

        """
        return e1 == e2

    def is_element(self, element: int) -> bool:
        r"""
        Check whether the input is a valid element of this group, i.e. a Python ``int`` in :math:`[0, N)`.

        Args:
            element: the value to check

        Returns:
            ``True`` if ``element`` is an ``int`` between ``0`` and ``N - 1``, ``False`` otherwise

        """
        if isinstance(element, int):
            return 0 <= element < self.order()
        else:
            return False

    def testing_elements(self) -> Iterable[int]:
        r"""
        A finite number of group elements to use for testing.

        For a finite cyclic group this is simply every element.
        """
        return iter(self.elements)

    def __eq__(self, other):
        if not isinstance(other, CyclicGroup):
            return False
        else:
            return self.name == other.name and self.order() == other.order()

    def __hash__(self):
        # Overriding ``__eq__`` without ``__hash__`` makes Python set ``__hash__`` to None (unhashable instances).
        # Equal groups have equal names, so hashing the name is consistent with ``__eq__``.
        return hash(self.name)

    def subgroup(self, id: int) -> Tuple[Group, Callable, Callable]:
        r"""
        Restrict the current group to the cyclic subgroup :math:`C_M`.

        If the current group is :math:`C_N`, it restricts to the subgroup generated by :math:`r^{(N/M)}`.
        Notice that :math:`M` has to divide the order :math:`N` of the current group.

        The method takes as input the integer :math:`M` identifying the subgroup to build (the order of the subgroup).
        Subgroups are cached, so repeated calls with the same ``id`` return the same tuple.

        Args:
            id (int): the integer :math:`M` identifying the subgroup (a positive divisor of :math:`N`)

        Returns:
            a tuple containing

                - the subgroup,

                - a function which maps an element of the subgroup to its inclusion in the original group and

                - a function which maps an element of the original group to the corresponding element in the subgroup (returns None if the element is not contained in the subgroup)

        """
        # ``id > 0`` guards against a ZeroDivisionError in the divisibility check below
        assert isinstance(id, int) and id > 0

        order = id

        assert self.order() % order == 0, \
            "Error! The subgroups of a cyclic group have an order that divides the order of the supergroup." \
            " %d does not divide %d " % (order, self.order())

        if id not in self._subgroups:
            # the subgroup is generated by "r^ratio"
            ratio = self.order() // order

            sg = CyclicGroup(order)

            # element j of C_M is r^(j * ratio) in C_N
            parent_mapping = lambda e, ratio=ratio: e * ratio
            # only multiples of ``ratio`` belong to the subgroup
            child_mapping = lambda e, ratio=ratio: None if e % ratio != 0 else int(e // ratio)

            self._subgroups[id] = sg, parent_mapping, child_mapping

        return self._subgroups[id]

    def _restrict_irrep(self, irrep: str, id: int) -> Tuple[np.ndarray, List[str]]:
        r"""
        Restrict the input irrep to the subgroup :math:`C_m` with order ``m``.

        If the current group is :math:`C_n`, it restricts to the subgroup generated by :math:`r^{(n/m)}`.
        Notice that :math:`m` has to divide the order :math:`n` of the current group.

        The method takes as input the integer :math:`m` identifying the subgroup to build (the order of the subgroup).

        Args:
            irrep (str): the name/identifier of the irrep to restrict
            id (int): the integer ``m`` identifying the subgroup

        Returns:
            a pair containing the change of basis and the list of irreps of the subgroup which appear in the restricted irrep

        """
        irr = self.irreps[irrep]

        # Build (or fetch the cached) subgroup
        sg, _, _ = self.subgroup(id)

        order = id

        # Element j of C_m acts as r^(j n/m), i.e. with frequency k in C_m, equivalent to k mod m
        f = irr.attributes["frequency"] % order

        if f > order / 2:
            # Frequency f is equivalent to frequency m - f up to a reflection of the second axis.
            # (Only reachable for 2-dimensional irreps: 1-dimensional ones always give f in {0, m/2}.)
            f = order - f
            change_of_basis = np.array([[1., 0.], [0., -1.]])
        else:
            change_of_basis = np.eye(irr.size)

        r = f"irrep_{f}"
        irreps = [r]
        # A 2-dimensional irrep that restricts to a 1-dimensional irrep (f = 0 or f = m/2)
        # acts as a multiple of the identity, i.e. it splits into two copies of that irrep
        if sg.irreps[r].size < irr.size:
            irreps.append(r)

        return change_of_basis, irreps

    def _build_representations(self):
        r"""
        Build the irreps and the regular representation for this group.
        """
        N = self.order()

        # Build all the Irreducible Representations (frequencies 0, ..., floor(N/2))
        for k in range(0, N // 2 + 1):
            self.irrep(k)

        # Build all Representations

        # add all the irreps to the set of representations already built for this group
        self.representations.update(**self.irreps)

        # build the regular representation
        self.representations['regular'] = self.regular_representation
        self.representations['regular'].supported_nonlinearities.add('vectorfield')

    def _build_quotient_representations(self):
        r"""
        Build the quotient representations for the divisors ``n`` of the group order with
        ``2 <= n < ceil(sqrt(N))``.
        """
        for n in range(2, int(math.ceil(math.sqrt(self.order())))):
            if self.order() % n == 0:
                self.quotient_representation(n)

    @property
    def trivial_representation(self) -> Representation:
        r"""
        The trivial (frequency 0) irrep of the group.
        """
        return self.representations['irrep_0']

    def irrep(self, k: int) -> IrreducibleRepresentation:
        r"""
        Build the irrep of frequency ``k`` of the current cyclic group.
        The frequency has to be a non-negative integer in :math:`\{0, \dots, \left \lfloor N/2 \right \rfloor \}`,
        where :math:`N` is the order of the group.

        Irreps are cached: repeated calls with the same ``k`` return the same object.

        Args:
            k (int): the frequency of the representation

        Returns:
            the corresponding irrep

        """
        assert 0 <= k <= self.order() // 2

        name = f"irrep_{k}"

        if name not in self.irreps:

            n = self.order()

            base_angle = 2.0 * np.pi / n

            if k == 0:
                # Trivial representation
                # NOTE: the same 1x1 array object is returned on every call; callers must not modify it in place
                irrep = lambda element, identity=np.eye(1): identity
                supported_nonlinearities = ['pointwise', 'gate', 'norm', 'gated', 'concatenated']
                self.irreps[name] = IrreducibleRepresentation(self, name, irrep, 1, 1,
                                                              supported_nonlinearities=supported_nonlinearities,
                                                              frequency=k)
            elif n % 2 == 0 and k == n // 2:
                # 1 dimensional Irreducible representation (only for even order groups):
                # r^j acts as cos(pi * j) = (-1)^j
                irrep = lambda element, k=k, base_angle=base_angle: np.array([[np.cos(k * element * base_angle)]])
                supported_nonlinearities = ['norm', 'gated', 'concatenated']
                self.irreps[name] = IrreducibleRepresentation(self, name, irrep, 1, 1,
                                                              supported_nonlinearities=supported_nonlinearities,
                                                              frequency=k)
            else:
                # 2 dimensional Irreducible Representations

                # build the rotation matrix with rotation frequency 'k'
                irrep = lambda element, k=k, base_angle=base_angle: utils.psi(element * base_angle, k=k)
                supported_nonlinearities = ['norm', 'gated']
                self.irreps[name] = IrreducibleRepresentation(self, name, irrep, 2, 2,
                                                              supported_nonlinearities=supported_nonlinearities,
                                                              frequency=k)
        return self.irreps[name]

    @staticmethod
    def _generator(N: int) -> 'CyclicGroup':
        r"""
        Return the cached instance of :math:`C_N`, building it on first request.

        Args:
            N (int): order of the group

        Returns:
            the (shared) ``CyclicGroup`` instance of order ``N``

        """
        # The module-level cache is only mutated, never rebound, so no ``global`` statement is needed
        if N not in _cached_group_instances:
            _cached_group_instances[N] = CyclicGroup(N)

        return _cached_group_instances[N]