# Source code for e2cnn.group.groups.cyclicgroup

from __future__ import annotations


from e2cnn.group import Group
from e2cnn.group import IrreducibleRepresentation, Representation
from e2cnn.group import utils

import numpy as np
import math

from typing import List, Tuple, Callable, Iterable


# Names exported by `from <this module> import *`: only the CyclicGroup class.
__all__ = ['CyclicGroup']


# Module-level cache read and filled by CyclicGroup._generator: it maps an order N
# to the CyclicGroup(N) instance returned for that order on every later request.
_cached_group_instances = dict()


class CyclicGroup(Group):

    def __init__(self, N: int):
        r"""
        Build an instance of the cyclic group :math:`C_N` which contains :math:`N` discrete planar rotations.

        The group elements are :math:`\{e, r, r^2, r^3, \dots, r^{N-1}\}`, with group law
        :math:`r^a \cdot r^b = r^{\ a + b \!\! \mod \!\! N \ }`.
        The cyclic group :math:`C_N` is isomorphic to the integers *modulo* ``N``.
        For this reason, elements are stored as the integers between :math:`0` and :math:`N-1`, where the
        :math:`k`-th element can also be interpreted as the discrete rotation by :math:`k\frac{2\pi}{N}`.

        Args:
            N (int): order of the group

        """
        assert (isinstance(N, int) and N > 0)

        # NOTE(review): the two flags are presumably (continuous=False, abelian=True);
        # confirm against Group.__init__
        super(CyclicGroup, self).__init__("C%d" % N, False, True)

        # the element r^k is stored as the integer k
        self.elements = list(range(N))
        self.elements_names = ['e'] + ['r%d' % i for i in range(1, N)]
        self.identity = 0

        self._build_representations()

    def inverse(self, element: int) -> int:
        r"""
        Return the inverse element :math:`r^{-j \mod N}` of the input element :math:`r^j`, specified by the input
        integer :math:`j` (``element``)

        Args:
            element (int): a group element :math:`r^j`

        Returns:
            its opposite :math:`r^{-j \mod N}`

        """
        return (-element) % self.order()

    def combine(self, e1: int, e2: int) -> int:
        r"""
        Return the composition of the two input elements.
        Given two integers :math:`a` and :math:`b` representing the elements :math:`r^a` and :math:`r^b`, the method
        returns the integer :math:`a + b \mod N` representing the element :math:`r^{a + b \mod N}`.

        Args:
            e1 (int): a group element :math:`r^a`
            e2 (int): another group element :math:`r^b`

        Returns:
            their composition :math:`r^{a+b \mod N}`

        """
        return (e1 + e2) % self.order()

    def equal(self, e1: int, e2: int) -> bool:
        r"""
        Check if the two input values correspond to the same element.

        Args:
            e1 (int): an element
            e2 (int): another element

        Returns:
            whether they are the same element

        """
        return e1 == e2

    def is_element(self, element: int) -> bool:
        r"""
        Check whether ``element`` is a valid element of this group, i.e. an ``int`` in :math:`[0, N)`.

        Only Python ``int`` instances (``bool`` included, being a subclass of ``int``) are accepted;
        any other type returns ``False``.
        """
        if isinstance(element, int):
            return 0 <= element < self.order()
        else:
            return False

    def testing_elements(self) -> Iterable[int]:
        r"""
        A finite number of group elements to use for testing.
        """
        return iter(self.elements)

    def __eq__(self, other):
        if not isinstance(other, CyclicGroup):
            return False
        else:
            return self.name == other.name and self.order() == other.order()

    def __hash__(self):
        # Overriding __eq__ alone sets __hash__ to None, which would make every CyclicGroup
        # unhashable. Equal groups share the same name, so hashing the name is consistent with __eq__.
        return hash(self.name)

    def subgroup(self, id: int) -> Tuple[Group, Callable, Callable]:
        r"""
        Restrict the current group to the cyclic subgroup :math:`C_M`.
        If the current group is :math:`C_N`, it restricts to the subgroup generated by :math:`r^{(N/M)}`.
        Notice that :math:`M` has to divide the order :math:`N` of the current group.

        The method takes as input the integer :math:`M` identifying the subgroup to build (the order of the subgroup).

        Args:
            id (int): the integer :math:`M` identifying the subgroup (a positive divisor of :math:`N`)

        Returns:
            a tuple containing

                - the subgroup,

                - a function which maps an element of the subgroup to its inclusion in the original group and

                - a function which maps an element of the original group to the corresponding element in the subgroup
                  (returns None if the element is not contained in the subgroup)

        The result is cached per ``id``, so repeated calls return the same tuple.

        """
        assert isinstance(id, int)

        order = id

        # Without this check, id == 0 would crash with a ZeroDivisionError in the divisibility test
        # and a negative id would only fail later, inside CyclicGroup(order).
        assert order > 0, "Error! The order of a subgroup must be a positive integer, got %d" % order

        assert self.order() % order == 0, \
            "Error! The subgroups of a cyclic group have an order that divides the order of the supergroup." \
            " %d does not divide %d " % (order, self.order())

        if id not in self._subgroups:
            # C_M is generated by r^ratio, with ratio = N / M
            ratio = self.order() // order

            sg = CyclicGroup(order)
            # r^e in C_M  ->  r^(e * ratio) in C_N
            parent_mapping = lambda e, ratio=ratio: e * ratio
            # r^e in C_N  ->  r^(e / ratio) in C_M, or None when r^e is not generated by r^ratio
            child_mapping = lambda e, ratio=ratio: None if e % ratio != 0 else int(e // ratio)

            self._subgroups[id] = sg, parent_mapping, child_mapping

        return self._subgroups[id]

    def _restrict_irrep(self, irrep: str, id: int) -> Tuple[np.ndarray, List[str]]:
        r"""
        Restrict the input irrep to the subgroup :math:`C_m` with order ``m``.
        If the current group is :math:`C_n`, it restricts to the subgroup generated by :math:`r^{(n/m)}`.
        Notice that :math:`m` has to divide the order :math:`n` of the current group.

        The method takes as input the integer :math:`m` identifying the subgroup to build (the order of the subgroup).

        Args:
            irrep (str): the name/identifier of the irrep to restrict
            id (int): the integer ``m`` identifying the subgroup

        Returns:
            a pair containing the change of basis and the list of irreps of the subgroup which appear in the
            restricted irrep

        """
        irr = self.irreps[irrep]

        sg, _, _ = self.subgroup(id)
        order = id

        irreps = []

        # An element r^e of C_m acts through r^(e * n/m) of C_n, i.e. with angle e * 2pi/m,
        # so the restricted irrep has frequency (k mod m) in C_m.
        f = irr.attributes["frequency"] % order

        if f > order / 2:
            # Frequencies above m/2 are folded back to m - f. Assuming utils.psi(theta, k) is the
            # 2x2 rotation by k*theta, the rotation by -f*theta is conjugate to the rotation by
            # f*theta via the reflection diag(1, -1).
            f = order - f
            change_of_basis = np.array([[1., 0.], [0., -1.]])
        else:
            change_of_basis = np.eye(irr.size)

        r = f"irrep_{f}"
        irreps.append(r)

        # A 2-dimensional irrep whose folded frequency is 0 or m/2 splits into two copies of the
        # same 1-dimensional irrep of C_m.
        if sg.irreps[r].size < irr.size:
            irreps.append(r)

        return change_of_basis, irreps

    def _build_representations(self):
        r"""
        Build the irreps and the regular representation for this group

        """
        N = self.order()

        # Build all the Irreducible Representations (frequencies 0, ..., floor(N/2))
        for k in range(N // 2 + 1):
            self.irrep(k)

        # Build all Representations

        # add all the irreps to the set of representations already built for this group
        self.representations.update(**self.irreps)

        # build the regular representation
        self.representations['regular'] = self.regular_representation
        self.representations['regular'].supported_nonlinearities.add('vectorfield')

    def _build_quotient_representations(self):
        r"""
        Build quotient representations for this group.

        Only the divisors ``n`` of :math:`N` with ``2 <= n < ceil(sqrt(N))`` are visited; larger divisors (and
        ``sqrt(N)`` itself when :math:`N` is a perfect square) are skipped.
        NOTE(review): confirm whether this bound is intentional or whether every divisor should be visited.

        """
        for n in range(2, int(math.ceil(math.sqrt(self.order())))):
            if self.order() % n == 0:
                self.quotient_representation(n)

    @property
    def trivial_representation(self) -> Representation:
        r"""
        The trivial representation of the group, stored as the frequency-0 irrep ``'irrep_0'``.
        """
        return self.representations['irrep_0']

    def irrep(self, k: int) -> IrreducibleRepresentation:
        r"""
        Build the irrep of frequency ``k`` of the current cyclic group.
        The frequency has to be a non-negative integer in :math:`\{0, \dots, \left \lfloor N/2 \right \rfloor \}`,
        where :math:`N` is the order of the group.

        Irreps are built lazily and cached in ``self.irreps`` under the name ``"irrep_{k}"``.

        Args:
            k (int): the frequency of the representation

        Returns:
            the corresponding irrep

        """
        assert 0 <= k <= self.order() // 2

        name = f"irrep_{k}"

        if name not in self.irreps:

            n = self.order()

            base_angle = 2.0 * np.pi / n

            if k == 0:
                # Trivial representation.
                # A fresh array is returned on every call: sharing one default-argument array would let a
                # caller that mutates the result corrupt the representation for everyone.
                irrep = lambda element: np.eye(1)
                supported_nonlinearities = ['pointwise', 'gate', 'norm', 'gated', 'concatenated']
                self.irreps[name] = IrreducibleRepresentation(self, name, irrep, 1, 1,
                                                              supported_nonlinearities=supported_nonlinearities,
                                                              frequency=k)
            elif n % 2 == 0 and k == n // 2:
                # 1 dimensional Irreducible representation (only for even order groups):
                # r^e acts as cos(pi * e) = (-1)^e
                irrep = lambda element, k=k, base_angle=base_angle: np.array([[np.cos(k * element * base_angle)]])
                supported_nonlinearities = ['norm', 'gated', 'concatenated']
                self.irreps[name] = IrreducibleRepresentation(self, name, irrep, 1, 1,
                                                              supported_nonlinearities=supported_nonlinearities,
                                                              frequency=k)
            else:
                # 2 dimensional Irreducible Representations

                # build the rotation matrix with rotation frequency 'k'
                irrep = lambda element, k=k, base_angle=base_angle: utils.psi(element * base_angle, k=k)
                supported_nonlinearities = ['norm', 'gated']
                self.irreps[name] = IrreducibleRepresentation(self, name, irrep, 2, 2,
                                                              supported_nonlinearities=supported_nonlinearities,
                                                              frequency=k)
        return self.irreps[name]

    @staticmethod
    def _generator(N: int) -> 'CyclicGroup':
        r"""
        Return the cached instance of :math:`C_N`, creating and caching it on first request.
        """
        # The module-level dict is only mutated, never rebound, so no `global` statement is needed.
        if N not in _cached_group_instances:
            _cached_group_instances[N] = CyclicGroup(N)

        return _cached_group_instances[N]