Source code for escnn.group.irrep

from __future__ import annotations

import escnn.group
from escnn.group import Representation, GroupElement, Group
from escnn.group._numerical import decompose_representation_finitegroup
from escnn.group._numerical import decompose_representation_general

from typing import Callable, Any, List, Union, Dict, Tuple, Type

import numpy as np

__all__ = [
    "IrreducibleRepresentation",
    "build_irrep_from_generators",
    "generate_irrep_matrices_from_generators",
    "restrict_irrep"
]


[docs]class IrreducibleRepresentation(Representation): def __init__(self, group: escnn.group.Group, id: Tuple, name: str, representation: Union[Dict[escnn.group.GroupElement, np.ndarray], Callable[[Any], np.ndarray]], size: int, type: str, supported_nonlinearities: List[str], character: Union[Dict[escnn.group.GroupElement, float], Callable[[Any], float]] = None, **kwargs ): """ Describes an "*irreducible representation*" (*irrep*). It is a subclass of a :class:`~escnn.group.Representation`. Irreducible representations are the building blocks into which any other representation decomposes under a change of basis. Indeed, any :class:`~escnn.group.Representation` is internally decomposed into a direct sum of irreps. Args: group (Group): the group which is being represented id (tuple): args to generate this irrep using ``group.irrep(*id)`` name (str): an identification name for this representation representation (dict or callable): a callable implementing this representation or a dict mapping each group element to its representation. size (int): the size of the vector space where this representation is defined (i.e. the size of the matrices) type (str): type of the irrep. It needs to be one of `R`, `C` or `H`, which represent respectively real, complex and quaternionic types. NOTE: this parameter substitutes the old `sum_of_squares_constituents` from *e2cnn*. supported_nonlinearities (list): list of nonlinearitiy types supported by this representation. character (callable or dict, optional): a callable returning the character of this representation for an input element or a dict mapping each group element to its character. **kwargs: custom attributes the user can set and, then, access from the dictionary in :attr:`escnn.group.Representation.attributes` Attributes: ~.id (tuple): tuple which identifies this irrep; it can be used to generate this irrep as ``group.irrep(*id)`` ~.sum_of_squares_constituents (int): the sum of the squares of the multiplicities of pairwise distinct irreducible constituents of the character of this representation over a non-splitting field (see `Character Orthogonality Theorem <https://groupprops.subwiki.org/wiki/Character_orthogonality_theorem#Statement_over_general_fields_in_terms_of_inner_product_of_class_functions>`_ over general fields). This attribute is fully determined by the irrep's `type` as: +----------+---------------------------------+ | `type` | `sum_of_squares_constituents` | +==========+=================================+ | 'R' | `1` | +----------+---------------------------------+ | 'C' | `2` | +----------+---------------------------------+ | 'H' | `4` | +----------+---------------------------------+ """ assert type in {'R', 'C', 'H'} if type == 'C': assert size % 2 == 0 elif type == 'H': assert size % 4 == 0 super(IrreducibleRepresentation, self).__init__(group, name, [id], np.eye(size), supported_nonlinearities, representation=representation, character=character, **kwargs) assert isinstance(id, tuple) self.id: tuple = id self.irreducible = True self.type = type if self.type == 'R': self.sum_of_squares_constituents = 1 elif self.type == 'C': self.sum_of_squares_constituents = 2 elif self.type == 'H': self.sum_of_squares_constituents = 4 else: raise ValueError() def endomorphism_basis(self) -> np.ndarray: if self.type == 'R': return np.eye(self.size).reshape(1, self.size, self.size) elif self.type == 'C': basis = np.stack([ np.eye(2), np.diag([1., -1.])[::-1] ], axis=0) return np.kron(basis, np.eye(self.size // 2)) elif self.type == 'H': basis = np.stack([ np.eye(4), np.diag([1., -1., 1., -1.])[::-1], np.array([ [ 0., 0., -1., 0.], [ 0., 0., 0., -1.], [ 1., 0., 0., 0.], [ 0., 1., 0., 0.], ]), np.array([ [0., -1., 0., 0.], [1., 0., 0., 0.], [0., 0., 0., 1.], [0., 0., -1., 0.], ]), ], axis=0) return np.kron(basis, np.eye(self.size // 4)) else: raise ValueError()
def build_irrep_from_generators( group: escnn.group.Group, generators: List[Tuple[escnn.group.GroupElement, np.ndarray]], id: Tuple, name: str, type: str, supported_nonlinearities: List[str], **kwargs ) -> IrreducibleRepresentation: rho = generate_irrep_matrices_from_generators(group, generators) d = generators[0][1].shape[0] return IrreducibleRepresentation( group, id, name, rho, d, type, supported_nonlinearities, **kwargs ) def generate_irrep_matrices_from_generators( group: escnn.group.Group, generators: List[Tuple[escnn.group.GroupElement, np.ndarray]], ) -> List[np.ndarray]: assert group.order() > 0 d = generators[0][1].shape[0] for g, rho_g in generators: assert group == g.group assert rho_g.shape == (d, d) elements = set() added = set() identity = group.identity added.add(identity) elements.add(identity) rho = { g: rho_g for g, rho_g in generators } rho[identity] = np.eye(d) generators = [g for g, _ in generators] while len(added) > 0: new = set() for g in generators: for e in added: if g @ e not in rho: rho[g @ e] = rho[g] @ rho[e] if ~g @ e not in rho: rho[~g @ e] = rho[g].T @ rho[e] new |= {g @ e, ~g @ e} added = new - elements elements |= added assert len(elements) == group.order(), 'Error! The set of generators passed does not generate the whole group' for a in elements: assert ~a in elements assert np.allclose( rho[~a], rho[a].T ) for b in elements: assert a @ b in elements assert np.allclose( rho[a] @ rho[b], rho[a @ b] ) return rho from joblib import Memory # import os # cache = Memory(os.path.join(os.path.dirname(__file__), '_jl_restricted_irreps'), verbose=2) from escnn.group import __cache_path__ cache = Memory(__cache_path__, verbose=0) @cache.cache def _restrict_irrep(irrep_id: Tuple, id, group_class: str, **group_keys) -> Tuple[np.matrix, List[Tuple[Tuple, int]]]: group = escnn.group.groups_dict[group_class]._generator(**group_keys) irrep = group.irrep(*irrep_id) id = irrep.group._decode_subgroup_id_pickleable(id) subgroup, parent, child = group.subgroup(id) if subgroup.order() == 1: # if the subgroup is the trivial group, just return the identity cob and the list of trivial reprs return np.eye(irrep.size), [subgroup.trivial_representation.id]*irrep.size if subgroup.order() > 1: # if it is a finite group, we can sample all the element and use the precise method based on Character theory representation = { g: irrep(parent(g)) for g in subgroup.elements } # to solve the Sylvester equation and find the change of basis matrix, it is sufficient to sample # the generators of the subgroup change_of_basis, multiplicities = decompose_representation_finitegroup( representation, subgroup, ) else: # if the group is not finite, we rely on the numerical method which is based on some samples of the group representation = lambda g: irrep(parent(g)) change_of_basis, multiplicities = decompose_representation_general( representation, subgroup, ) irreps = [] for irr, m in multiplicities: irreps += [irr]*m return change_of_basis, irreps def restrict_irrep(irrep: IrreducibleRepresentation, id) -> Tuple[np.matrix, List[Tuple[str, int]]]: r""" Restrict the input `irrep` to the subgroup identified by `id`. """ group_keys = irrep.group._keys id = irrep.group._encode_subgroup_id_pickleable(id) return _restrict_irrep(irrep.id, id, irrep.group.__class__.__name__, **group_keys)