from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Tuple, Callable, Iterable, List, Any, Dict
import escnn.group
import numpy as np
from scipy import sparse
# Public names exported by ``from <this module> import *``.
__all__ = ["Group", "GroupElement"]
class Group(ABC):

    @property
    @abstractmethod
    def PARAM(self) -> str:
        r"""
        Default parametrization used for storing the elements of the group.
        """
        pass

    @property
    @abstractmethod
    def PARAMETRIZATIONS(self) -> List[str]:
        r"""
        List of all supported parametrizations of the group.
        """
        pass

    def __init__(self, name: str, continuous: bool, abelian: bool):
        r"""
        Abstract class defining the interface of a group.

        A group is a set of *group elements* together with a binary operation satisfying a number of axioms.
        In this library, this is implemented using this class :class:`~escnn.group.Group` and the class
        :class:`~escnn.group.GroupElement`.

        One can retrieve or generate elements of a group by using, for instance, the properties or methods
        :meth:`~escnn.group.Group.identity` , :meth:`~escnn.group.Group.elements` or :meth:`~escnn.group.Group.sample`.
        Each group may also have additional methods to generate its group elements.
        Additionally, one can use the method :meth:`~escnn.group.Group.element` to generate a new group element.

        The group algebra is directly implemented inside :class:`~escnn.group.GroupElement` such that one can combine
        group elements in a way that resembles mathematical expressions.
        In particular, the ``@`` implements the binary product while ``~`` implements the group inverse.
        See :class:`~escnn.group.GroupElement` for more details.

        Args:
            name (str): name identifying the group
            continuous (bool): whether the group is non-finite or finite
            abelian (bool): whether the group is *abelian* (commutative)

        Attributes:
            ~.name (str): Name identifying the group
            ~.continuous (bool): Whether it is a non-finite or a finite group
            ~.abelian (bool): Whether it is an *abelian* group (i.e. if the group law is commutative)

        """
        self.name = name
        self.continuous = continuous
        self.abelian = abelian

        # Lazily populated caches:
        #  - irreducible representations built so far (see `irreps()`),
        #  - representations, keyed by their name (see `representations`),
        #  - subgroups and homogeneous spaces, keyed by the processed subgroup id
        #    (see `subgroup()` and `homspace()`).
        self._irreps = {}
        self._representations = {}
        self._subgroups = {}
        self._homspaces = {}

    def order(self) -> int:
        r"""
        Returns the number of elements in this group if it is a finite group, otherwise -1 is returned

        Returns:
            the size of the group or ``-1`` if it is a continuous group

        """
        # read the (possibly computed) property only once
        elements = self.elements
        return len(elements) if elements is not None else -1

    def element(self, element, param: str = None) -> GroupElement:
        r"""
        Generate the element of the current group parametrized by ``element`` according to
        the parametrization ``param``.

        Each group supports a different set of parametrizations.
        By default, the parametrization :attr:`escnn.group.Group.PARAM` is used.
        The list of all available parametrizations of a group can be accessed through the property
        :attr:`escnn.group.Group.PARAMETRIZATIONS`.

        Args:
            element: values parametrizing a group element.
            param (str): string identifying the parametrization to be used.

        Returns:
            an instance of :class:`~escnn.group.GroupElement`

        """
        if param is None:
            param = self.PARAM
        return GroupElement(element, self, param)

    @property
    @abstractmethod
    def identity(self) -> GroupElement:
        r"""
        The identity element of the group.

        The identity element :math:`e` satisfies the following property
        :math:`\forall\ g \in G,\ g \cdot e = e \cdot g= g` .

        """
        pass

    @property
    @abstractmethod
    def elements(self) -> List[GroupElement]:
        r"""
        If the group is finite (i.e. ``self.continuous = False``), it is a list of all group elements.
        If the group is not finite, this property is `None`.

        Returns:
            a list of :class:`~escnn.group.GroupElement` instances

        """
        pass

    @property
    @abstractmethod
    def subgroup_trivial_id(self):
        r"""
        The subgroup `id` associated with the trivial subgroup containing only the identity element :math:`{e}`.
        The id can be used in the method :meth:`~escnn.group.Group.subgroup` to generate the subgroup.
        """
        pass

    @property
    @abstractmethod
    def subgroup_self_id(self):
        r"""
        The subgroup `id` associated with the group itself.
        The id can be used in the method :meth:`~escnn.group.Group.subgroup` to generate the subgroup.
        """
        pass

    @property
    @abstractmethod
    def _keys(self) -> Dict[str, Any]:
        r"""
        Keyword arguments identifying this group instance; they are forwarded, together with the class name,
        to the Clebsch-Gordan helpers (see :meth:`_clebsh_gordan_coeff` and :meth:`_tensor_product_irreps`).
        """
        pass

    @property
    @abstractmethod
    def generators(self) -> List[GroupElement]:
        r"""
        If the group is finite (``self.continuous = False``), a list of group elements which can generate this group.
        Should raise a `ValueError` if the group is not finite.

        Returns:
            a list of :class:`~escnn.group.GroupElement` instances

        """
        pass

    ###########################################################################
    # METHODS DEFINING THE GROUP LAW AND THE OPERATIONS ON THE GROUP'S ELEMENTS
    ###########################################################################

    @abstractmethod
    def _combine(self, e1, e2,
                 param: str,
                 param1: str = None,
                 param2: str = None
                 ):
        r"""
        Method that returns the combination of two group elements according to the *group law*.

        Args:
            e1: an element of the group
            e2: another element of the group
            param (str): the parametrization of the returned element
            param1 (str, optional): the parametrization of ``e1``
            param2 (str, optional): the parametrization of ``e2``

        Returns:
            the group element :math:`e_1 \cdot e_2`

        """
        pass

    @abstractmethod
    def _inverse(self, element, param: str):
        r"""
        Method that returns the inverse in the group of the element given as input

        Args:
            element: an element of the group
            param (str): the parametrization of ``element``

        Returns:
            its inverse
        """
        pass

    @abstractmethod
    def _is_element(self,
                    element,
                    param: str,
                    verbose: bool = False
                    ) -> bool:
        r"""
        Check whether the input is an element of this group or not.

        Args:
            element: input object to test
            param (str): the parametrization in which ``element`` is expressed
            verbose (bool, optional): whether to report why the check fails

        Returns:
            if the input is an element of the group

        """
        pass

    @abstractmethod
    def _equal(self, e1, e2,
               param: str,
               param1: str = None,
               param2: str = None
               ) -> bool:
        r"""
        Method that checks whether the two inputs are the same element of the group.

        This is especially useful for continuous groups with periodicity; see for instance
        :meth:`escnn.group.SO2.equal`.

        Args:
            e1: an element of the group
            e2: another element of the group
            param (str): the parametrization used for the comparison
            param1 (str, optional): the parametrization of ``e1``
            param2 (str, optional): the parametrization of ``e2``

        Returns:
            if they are equal

        """
        pass

    @abstractmethod
    def _change_param(self, element, p_from: str, p_to: str):
        r"""
        Convert the values ``element`` from the parametrization ``p_from`` to the parametrization ``p_to``.
        """
        pass

    @abstractmethod
    def _hash_element(self, element, param: str):
        r"""
        Method that returns a unique hash for a group element given in input

        Args:
            element: an element of the group
            param (str): the parametrization of ``element``

        Returns:
            a unique hash
        """
        pass

    @abstractmethod
    def _repr_element(self, element, param: str):
        r"""
        Method that returns a representative string for a group element given in input

        Args:
            element: an element of the group
            param (str): the parametrization of ``element``

        Returns:
            a string representing the element
        """
        pass

    ###########################################################################

    def __repr__(self):
        return self.name

    @abstractmethod
    def __eq__(self, other):
        pass

    @abstractmethod
    def sample(self) -> GroupElement:
        r"""
        Sample a random element of the group from a uniform distribution over the group.

        Returns:
            :class:`~escnn.group.GroupElement`: the element sampled

        """
        pass

    def grid(self, *args, **kwargs) -> List[GroupElement]:
        r"""
        Method to generate collections of points over the group.
        Each group should implement its own set of collections.
        Check the individual groups' documentations for details about the supported arguments.

        Returns:
            a list of :class:`~escnn.group.GroupElement` instances

        """
        raise NotImplementedError()

    def _process_subgroup_id(self, id):
        r"""
        Normalize a subgroup id before it is used as a cache key; the default implementation returns it unchanged.
        """
        return id

    def subgroup(self, id) -> Tuple[
        escnn.group.Group,
        Callable[[escnn.group.GroupElement], escnn.group.GroupElement],
        Callable[[escnn.group.GroupElement], escnn.group.GroupElement]
    ]:
        r"""
        Restrict the current group to the subgroup identified by the input ``id``.

        The result is cached, so repeated calls with the same (processed) ``id`` return the same objects.

        Args:
            id: the identifier of the subgroup

        Returns:
            a tuple containing

                - the subgroup,

                - a function which maps an element of the subgroup to its inclusion in the original group and

                - a function which maps an element of the original group to the corresponding element in the subgroup (returns None if the element is not contained in the subgroup)

        """
        id = self._process_subgroup_id(id)

        if id not in self._subgroups:
            subgroup, parent_mapping, child_mapping = self._subgroup(id)
            self._subgroups[id] = subgroup, parent_mapping, child_mapping

        return self._subgroups[id]

    @abstractmethod
    def _subgroup(self, id) -> Tuple[
        escnn.group.Group,
        Callable[[escnn.group.GroupElement], escnn.group.GroupElement],
        Callable[[escnn.group.GroupElement], escnn.group.GroupElement]
    ]:
        r"""
        Restrict the current group to the subgroup identified by the input ``id``.

        Args:
            id: the identifier of the subgroup

        Returns:
            a tuple containing

                - the subgroup,

                - a function which maps an element of the subgroup to its inclusion in the original group and

                - a function which maps an element of the original group to the corresponding element in the subgroup (returns None if the element is not contained in the subgroup)

        """
        pass

    def _combine_subgroups(self, sg_id1, sg_id2):
        raise NotImplementedError

    def irreps(self) -> List[escnn.group.IrreducibleRepresentation]:
        r"""
        List containing all irreducible representations (:class:`~escnn.group.IrreducibleRepresentation`)
        currently instantiated for this group.

        Returns:
            a list containing all irreducible representations built

        """
        return list(self._irreps.values())

    @property
    def representations(self) -> Dict[str, escnn.group.Representation]:
        r"""
        Dictionary containing all representations (:class:`~escnn.group.Representation`)
        instantiated for this group.

        Returns:
            a dictionary containing all representations built

        """
        return self._representations

    @property
    @abstractmethod
    def trivial_representation(self) -> escnn.group.IrreducibleRepresentation:
        r"""
        Builds the trivial representation of the group.
        The trivial representation is a 1-dimensional representation which maps any element to 1,
        i.e. :math:`\forall g \in G,\ \rho(g) = 1`.

        Returns:
            the trivial representation of the group

        """
        pass

    @abstractmethod
    def irrep(self, *id) -> escnn.group.IrreducibleRepresentation:
        r"""
        Builds the irreducible representation (:class:`~escnn.group.IrreducibleRepresentation`) of the group which is
        specified by the input arguments.

        .. seealso ::

            Check the documentation of the specific group subclass used for more information on the valid ``id`` values.

        Args:
            *id: parameters identifying the specific irrep.

        Returns:
            the irrep built

        """
        # TODO implement memoization here and let subclasses define an _irrep(*id) module
        pass

    @property
    def regular_representation(self) -> escnn.group.Representation:
        r"""
        Builds the regular representation of the group if the group has a *finite* number of elements;
        raises a ``ValueError`` otherwise.

        The regular representation of a finite group :math:`G` acts on a vector space :math:`\R^{|G|}` by permuting its
        axes.
        Specifically, associating each axis :math:`e_g` of :math:`\R^{|G|}` to an element :math:`g \in G`, the
        representation of an element :math:`\tilde{g}\in G` is a permutation matrix which maps :math:`e_g` to
        :math:`e_{\tilde{g}g}`.
        For instance, the regular representation of the group :math:`C_4` with elements
        :math:`\{r^k | k=0,\dots,3 \}` is instantiated by:

        +-----------------------------------+------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------+
        |    :math:`g`                      |          :math:`e`                                                                                         |          :math:`r`                                                                                         |        :math:`r^2`                                                                                         |        :math:`r^3`                                                                                         |
        +===================================+============================================================================================================+============================================================================================================+============================================================================================================+============================================================================================================+
        |  :math:`\rho_\text{reg}^{C_4}(g)` | :math:`\begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \\ \end{bmatrix}` | :math:`\begin{bmatrix} 0 & 0 & 0 & 1 \\ 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 1 & 0 \\ \end{bmatrix}` | :math:`\begin{bmatrix} 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \\ 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ \end{bmatrix}` | :math:`\begin{bmatrix} 0 & 1 & 0 & 0 \\ 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \\ 1 & 0 & 0 & 0 \\ \end{bmatrix}` |
        +-----------------------------------+------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------+

        A vector :math:`v=\sum_g v_g e_g` in :math:`\R^{|G|}` can be interpreted as a scalar function
        :math:`v:G \to \R,\, g \mapsto v_g` on :math:`G`.

        Returns:
            the regular representation of the group

        """
        if self.order() < 0:
            raise ValueError(f"Regular representation is supported only for finite groups but "
                             f"the group {self.name} has an infinite number of elements")
        else:
            if "regular" not in self.representations:
                irreps, change_of_basis, change_of_basis_inv = escnn.group.representation.build_regular_representation(self)
                supported_nonlinearities = ['pointwise', 'norm', 'gated', 'concatenated']
                self.representations["regular"] = escnn.group.Representation(self,
                                                                             "regular",
                                                                             [r.id for r in irreps],
                                                                             change_of_basis,
                                                                             supported_nonlinearities,
                                                                             change_of_basis_inv=change_of_basis_inv,
                                                                             )
            return self.representations["regular"]

    def quotient_representation(self,
                                subgroup_id,
                                representatives: List[GroupElement] = None,
                                name: str = None,
                                ) -> escnn.group.Representation:
        r"""
        Builds the quotient representation of the group with respect to the subgroup identified by the
        input ``subgroup_id``.

        Similar to :meth:`~escnn.group.Group.regular_representation`, the quotient representation
        :math:`\rho_\text{quot}^{G/H}` of :math:`G` w.r.t. a subgroup :math:`H` acts on :math:`\R^{|G|/|H|}` by
        permuting its axes.
        Labeling the axes by the cosets :math:`gH` in the quotient space :math:`G/H`, it can be defined via its action
        :math:`\rho_\text{quot}^{G/H}(\tilde{g})e_{gH}=e_{\tilde{g}gH}`.

        Regular and trivial representations are two specific cases of quotient representations obtained by choosing
        :math:`H=\{e\}` or :math:`H=G`, respectively.
        Vectors in the representation space :math:`\R^{|G|/|H|}` can be viewed as scalar functions on the quotient
        space :math:`G/H`.

        The quotient representation :math:`\rho_\text{quot}^{G/H}` can also be defined as the
        :meth:`~escnn.group.Group.induced_representation` from the trivial representation of the subgroup :math:`H`.

        .. todo ::
            docs for `representatives`

        Args:
            subgroup_id: identifier of the subgroup
            representatives (list, optional): forwarded unchanged to the routine building the induced representation
            name (str, optional): optionally, specify a custom name for this representation

        Returns:
            the quotient representation of the group

        """
        if name is None:
            name = f"quotient[{subgroup_id}]"

        if name not in self.representations:
            subgroup, _, _ = self.subgroup(subgroup_id)
            supported_nonlinearities = _induced_nonlinearities(subgroup.trivial_representation)

            irreps, change_of_basis, change_of_basis_inv = self._induced_from_irrep(
                subgroup_id,
                subgroup.trivial_representation,
                representatives
            )

            self.representations[name] = escnn.group.Representation(self,
                                                                    name,
                                                                    [r.id for r in irreps],
                                                                    change_of_basis,
                                                                    supported_nonlinearities,
                                                                    change_of_basis_inv=change_of_basis_inv,
                                                                    )

        return self.representations[name]

    def induced_representation(self,
                               subgroup_id,
                               repr: escnn.group.IrreducibleRepresentation,
                               representatives: List[GroupElement] = None,
                               name: str = None
                               ) -> escnn.group.Representation:
        r"""
        Builds the induced representation from the input representation ``repr`` of the subgroup identified by
        the input ``subgroup_id``.

        .. todo ::
            docs for `representatives`

        Args:
            subgroup_id: identifier of the subgroup
            repr (Representation): the representation of the subgroup (currently it must be irreducible)
            representatives (list, optional): forwarded unchanged to the routine building the induced representation
            name (str, optional): optionally, specify a custom name for this representation

        Returns:
            the induced representation of the group

        """
        assert repr.irreducible, "Induction from general representations is not supported yet"

        if name is None:
            name = f"induced[{subgroup_id}][{repr.name}]"

        if name not in self.representations:
            supported_nonlinearities = _induced_nonlinearities(repr)

            irreps, change_of_basis, change_of_basis_inv = self._induced_from_irrep(
                subgroup_id,
                repr,
                representatives
            )

            self.representations[name] = escnn.group.Representation(self,
                                                                    name,
                                                                    [r.id for r in irreps],
                                                                    change_of_basis,
                                                                    supported_nonlinearities,
                                                                    change_of_basis_inv=change_of_basis_inv,
                                                                    )

        return self.representations[name]

    def _induced_from_irrep(self, subgroup_id: Tuple[float, int],
                            repr: escnn.group.IrreducibleRepresentation,
                            representatives: List[GroupElement] = None,
                            ) -> Tuple[List[escnn.group.IrreducibleRepresentation], np.ndarray, np.ndarray]:
        r"""
        Builds the induced representation from the input *irreducible* representation ``repr`` of the subgroup
        identified by the input ``subgroup_id``.

        .. todo ::
            docs for `representatives`

        Args:
            subgroup_id: identifier of the subgroup
            repr (Representation): the representation of the subgroup
            representatives (list, optional): forwarded unchanged to ``build_induced_representation``

        Returns:
            a tuple containing the list of irreps, the change of basis and the inverse change of basis of
            the induced representation

        """
        assert repr.irreducible

        return escnn.group.representation.build_induced_representation(
            self,
            subgroup_id,
            repr,
            representatives
        )

    def spectral_regular_representation(self, *irreps, name: str = None) -> 'Representation':
        r"""
        Finite dimensional invariant subspace of the regular representation containing only the irreps passed in input.
        The regular representation is expressed in the spectral basis, i.e. as a direct sum of irreps.

        The optional parameter ``name`` is also used for caching purpose.
        Consecutive calls of this method using the same ``name`` will ignore the argument ``irreps``
        and return the same instance of representation.

        .. seealso::
            :meth:`escnn.group.HomSpace.induced_representation`

        """
        if name is None:
            irreps_names = '|'.join(str(i) for i in irreps)
            name = f'regular_[{irreps_names}]'

        return self.spectral_quotient_representation(self.subgroup_trivial_id, *irreps, name=name)

    def spectral_quotient_representation(self, subgroup_id: Tuple, *irreps, name: str = None) -> 'Representation':
        r"""
        Finite dimensional invariant subspace of the quotient representation containing only the irreps passed in input.
        The quotient representation is expressed in the spectral basis, i.e. as a direct sum of irreps.

        The optional parameter ``name`` is also used for caching purpose.
        Consecutive calls of this method using the same ``name`` will ignore the arguments ``subgroup_id`` and ``irreps``
        and return the same instance of representation.

        .. seealso::
            :meth:`escnn.group.HomSpace.induced_representation`

        """
        if name is None:
            irreps_names = '|'.join(str(i) for i in irreps)
            name = f'quotient[{subgroup_id}]_[{irreps_names}]'

        if name not in self._representations:
            homspace = self.homspace(subgroup_id)
            self._representations[name] = homspace.induced_representation(homspace.H.trivial_representation.id, irreps, name)

        return self._representations[name]

    def restrict_representation(self, id, repr: escnn.group.Representation) -> escnn.group.Representation:
        r"""
        Restrict the input :class:`~escnn.group.Representation` to the subgroup identified by ``id``.

        Any representation :math:`\rho : G \to \GL{\R^n}` can be uniquely restricted to a representation
        of a subgroup :math:`H < G` by restricting its domain of definition:

        .. math ::

            \Res{H}{G}(\rho): H \to \GL{{\R}^n},\ h \mapsto \rho\big|_H(h)

        We recommend directly using the method :meth:`escnn.group.Representation.restrict`.

        .. seealso ::
            Check the documentation of the method :meth:`~escnn.group.Group.subgroup()` of the group used to see
            the available subgroups and accepted ids.

        Args:
            id: identifier of the subgroup
            repr (Representation): the representation to restrict

        Returns:
            the restricted representation

        """
        assert repr.group == self

        sg, _, _ = self.subgroup(id)
        id = self._process_subgroup_id(id)

        # First, restrict each irrep in the representation
        irreps_changes_of_basis = []
        irreps = []
        for irr in repr.irreps:
            irrep_cob, reduced_irreps = self._restrict_irrep(irr, id)
            size = self.irrep(*irr).size
            assert irrep_cob.shape == (size, size)

            irreps_changes_of_basis.append(irrep_cob)
            irreps += reduced_irreps

        # concatenate the restricted irreps and merge the representation's change of basis with the
        # restricted irreps' change of basis matrices
        irreps_changes_of_basis = sparse.block_diag(irreps_changes_of_basis, format='csc')
        change_of_basis = repr.change_of_basis @ irreps_changes_of_basis

        name = f"{self.name}:{repr.name}"

        resr = escnn.group.Representation(sg,
                                          name,
                                          irreps,
                                          change_of_basis,
                                          repr.supported_nonlinearities)

        # a trivial representation always supports pointwise non-linearities
        if resr.is_trivial() and 'pointwise' not in repr.supported_nonlinearities:
            resr.supported_nonlinearities.add("pointwise")

        return resr

    def homspace(self, id) -> escnn.group.HomSpace:
        r"""
        If :math:`G` is the current group and ``id`` identifies the subgroup :math:`H`
        (see :meth:`~escnn.group.Group.subgroup`), this method generates the homogeneous space
        :class:`~escnn.group.HomSpace` :math:`X = G / H`.

        .. note ::
            The generated instance of :class:`~escnn.group.HomSpace` is cached inside the instance of the current
            group such that repeated calls of this method using the same ``id`` return the same instance of
            :class:`~escnn.group.HomSpace` and no additional computations are required.

        Returns:
            an instance of :class:`~escnn.group.HomSpace`

        """
        id = self._process_subgroup_id(id)
        if id not in self._homspaces:
            self._homspaces[id] = escnn.group.HomSpace(self, self._process_subgroup_id(id))
        return self._homspaces[id]

    @abstractmethod
    def _restrict_irrep(self, irrep: Tuple, id) -> Tuple[np.matrix, List[Tuple]]:
        r"""
        Restrict the irrep identified by ``irrep`` to the subgroup ``id``.
        Returns a square change of basis matrix (of the irrep's size) and the list of ids of the resulting
        irreps of the subgroup (see :meth:`restrict_representation`).
        """
        pass

    def _clebsh_gordan_coeff(self, m, n, j) -> np.ndarray:
        r"""
        Clebsch-Gordan coefficients projecting the tensor product of the irreps ``m`` and ``n`` onto the irrep ``j``.
        """
        group_keys = self._keys
        m = self.get_irrep_id(m)
        n = self.get_irrep_id(n)
        j = self.get_irrep_id(j)
        return escnn.group._clebsh_gordan._clebsh_gordan_tensor(m, n, j, self.__class__.__name__, **group_keys)

    def _tensor_product_irreps(self, m, n) -> List[Tuple[Tuple, int]]:
        r"""
        Decomposition of the tensor product of the irreps ``m`` and ``n`` as a list of ``(irrep id, multiplicity)``.
        """
        group_keys = self._keys
        m = self.get_irrep_id(m)
        n = self.get_irrep_id(n)
        return escnn.group._clebsh_gordan._find_tensor_decomposition(m, n, self.__class__.__name__, **group_keys)

    def _tensor_product(self, rho1: escnn.group.Representation, rho2: escnn.group.Representation) -> escnn.group.Representation:
        r"""
        Build the tensor product representation of ``rho1`` and ``rho2``, decomposed into irreps via the
        Clebsch-Gordan coefficients of the group.
        """
        assert rho1.group == self
        assert rho2.group == self

        D1 = rho1.size
        D2 = rho2.size
        D = D1 * D2

        change_of_basis = np.zeros((D, D))
        irreps = []

        # `p` is the offset of the current block in the irreps' (spectral) basis
        p = 0
        for irr1 in rho1.irreps:
            irr1 = self.irrep(*irr1)

            # reorders the axes of (irr1 x rho2) into the concatenation of the (irr1 x irr2) blocks
            permutation = np.zeros((irr1.size * rho2.size, irr1.size * rho2.size))

            # `q` is the offset of `irr2` inside `rho2`
            q = 0
            for irr2 in rho2.irreps:
                irr2 = self.irrep(*irr2)

                irr1_tensor_irr2 = self._tensor_product_irreps(irr1.id, irr2.id)

                size = 0
                for irr_id, S in irr1_tensor_irr2:
                    irr = self.irrep(*irr_id)
                    size += irr.size*S
                    irreps += [irr.id]*S

                assert size == irr1.size * irr2.size, (size, irr1.size, irr2.size)

                i = 0
                for irr_j, S in irr1_tensor_irr2:
                    irr_j = self.irrep(*irr_j)
                    change_of_basis[
                        p:p+size,
                        p+i:p+i+irr_j.size*S
                    ] = self._clebsh_gordan_coeff(irr1.id, irr2.id, irr_j.id).reshape(-1, irr_j.size*S)
                    i += irr_j.size * S

                assert i == size

                for i in range(irr1.size):
                    permutation[
                        q + i*rho2.size:q + i*rho2.size + irr2.size,
                        q*irr1.size + i * irr2.size:q*irr1.size + (i+1) * irr2.size
                    ] = np.eye(irr2.size)

                q += irr2.size
                p += size

            assert np.allclose(permutation @ permutation.T, np.eye(permutation.shape[0]))
            assert np.allclose(permutation.T @ permutation, np.eye(permutation.shape[0]))

            change_of_basis[
                p - irr1.size*rho2.size:p,
                p - irr1.size*rho2.size:p
            ] = permutation @ change_of_basis[p - irr1.size*rho2.size:p, p - irr1.size*rho2.size:p]

        change_of_basis = np.kron(rho1.change_of_basis, rho2.change_of_basis) @ change_of_basis

        assert p == sum(self.irrep(*irr).size for irr in irreps), (p, rho1.size, rho2.size)
        assert p == rho1.size * rho2.size, (p, rho1.size, rho2.size)
        assert p == change_of_basis.shape[0]
        assert p == change_of_basis.shape[1]

        supported_nonlinearities = _tensor_nonlinearities(rho1, rho2)
        character = _tensor_product_character(rho1, rho2)

        if len(irreps) > 1:
            return escnn.group.Representation(self,
                                              f'{rho1.name} X {rho2.name}',
                                              irreps,
                                              change_of_basis,
                                              character=character,
                                              supported_nonlinearities=supported_nonlinearities
                                              )
        else:
            return escnn.group.change_basis(
                self.irrep(*irreps[0]),
                change_of_basis,
                name=f'{rho1.name} X {rho2.name}'
            )

    @abstractmethod
    def testing_elements(self) -> Iterable[GroupElement]:
        r"""
        A finite number of group elements to use for testing.
        """
        pass

    def get_irrep_id(self, psi):
        r"""
        Return the id of the irrep described by ``psi``.

        Args:
            psi: either an :class:`~escnn.group.IrreducibleRepresentation` of this group, the name of an irreducible
                 representation already stored in :attr:`representations`, a tuple of irrep parameters or a single
                 irrep parameter.

        Returns:
            the id of the irrep

        """
        if isinstance(psi, escnn.group.IrreducibleRepresentation):
            assert psi.group == self
            return psi.id
        elif isinstance(psi, str):
            psi = self.representations[psi]
            assert isinstance(psi, escnn.group.IrreducibleRepresentation)
            # return the id, consistently with all the other branches
            return psi.id
        elif isinstance(psi, tuple):
            return self.irrep(*psi).id
        else:
            return self.irrep(psi).id

    def _decode_subgroup_id_pickleable(self, id: Tuple) -> Tuple:
        r"""
        Inverse of :meth:`_encode_subgroup_id_pickleable`: recursively replaces every
        ``('GROUPELEMENT', value, param)`` triple in ``id`` with the corresponding group element.
        Non-tuple inputs are returned unchanged.
        """
        if isinstance(id, tuple):
            # the isinstance check avoids ambiguous element-wise comparisons (e.g. numpy arrays)
            # and the length check avoids an IndexError on empty tuples
            if len(id) == 3 and isinstance(id[0], str) and id[0] == 'GROUPELEMENT':
                return self.element(id[1], id[2])
            return tuple(self._decode_subgroup_id_pickleable(x) for x in id)
        return id

    def _encode_subgroup_id_pickleable(self, id: Tuple) -> Tuple:
        r"""
        Recursively replaces every :class:`~escnn.group.GroupElement` in ``id`` with the triple
        ``('GROUPELEMENT', value, param)`` so that the id can be pickled.
        """
        if isinstance(id, GroupElement):
            return 'GROUPELEMENT', id.value, id.param
        if isinstance(id, tuple):
            return tuple(self._encode_subgroup_id_pickleable(x) for x in id)
        return id

    @classmethod
    @abstractmethod
    def _generator(cls, *args, **kwargs) -> 'Group':
        # TODO solve the singleton problem!!!
        pass
def _tensor_product_character(rho1: 'Representation', rho2: 'Representation'):
def character(e: GroupElement, rho1=rho1, rho2=rho2) -> float:
return rho1.character(e) * rho2.character(e)
return character
def _induced_nonlinearities(repr: escnn.group.Representation):
supported_nonlinearities = []
if 'pointwise' in repr.supported_nonlinearities:
supported_nonlinearities.append('pointwise')
if 'concatenated' in repr.supported_nonlinearities:
supported_nonlinearities.append('concatenated')
if 'gated' in repr.supported_nonlinearities:
supported_nonlinearities.append('gated')
for nl in repr.supported_nonlinearities:
if nl.startswith('induced_gated'):
supported_nonlinearities.append(nl)
break
else:
supported_nonlinearities.append(f'induced_gated_{repr.size}')
if 'norm' in repr.supported_nonlinearities:
supported_nonlinearities.append('norm')
for nl in repr.supported_nonlinearities:
if nl.startswith('induced_norm'):
supported_nonlinearities.append(nl)
break
else:
supported_nonlinearities.append(f'induced_norm_{repr.size}')
if 'gate' in repr.supported_nonlinearities or 'induced_gate' in repr.supported_nonlinearities:
supported_nonlinearities.append('induced_gate')
# 'vectorfield' not always supported by the induced representation so they are not added
return supported_nonlinearities
def _tensor_nonlinearities(repr1: escnn.group.Representation, repr2: escnn.group.Representation):
supported_nonlinearities = []
if 'pointwise' in repr1.supported_nonlinearities and 'pointwise' in repr2.supported_nonlinearities:
supported_nonlinearities.append('pointwise')
supported_nonlinearities.append('gated')
supported_nonlinearities.append('norm')
if 'gate' in repr1.supported_nonlinearities and 'gate' in repr2.supported_nonlinearities:
supported_nonlinearities.append('gate')
# TODO - check for induced non-linearities
return supported_nonlinearities
class GroupElement(ABC):

    def __init__(self, g, group: Group, param: str = None):
        r"""
        Class implementing an element of a group.

        Group elements can be combined through the group operations like the *group law* or the *inverse*.
        In particular, one can combine two group elements through the group law using the
        operator ``@`` or compute the inverse of an element using ``~``.

        For example ::

            G = so3_group()
            a = G.sample()
            b = G.sample()
            c = a @ b
            a_ = ~a
            e = G.identity
            assert e == ~e
            assert a == a @ e

        Args:
            g: values parametrizing the group element
            group (Group): the group this element belongs to
            param (str): the type of parametrization of ``g``; defaults to ``group.PARAM``

        Attributes:
            ~.group (Group): the group it belongs to

        Raises:
            ValueError: if ``param`` is not a supported parametrization of ``group`` or if ``g`` is not an
                element of ``group`` under ``param``

        """
        if param is None:
            param = group.PARAM

        # TODO - create "ParametrizationError"?
        if param not in group.PARAMETRIZATIONS:
            raise ValueError(f'Error! {param} not recognized. Expected one of {group.PARAMETRIZATIONS}')

        if not group._is_element(g, param):  # , verbose=True):
            raise ValueError(f'Error! {g} is not an element of {group} under the parametrization [{param}]')

        # Group: the group this element belongs to
        self.group = group

        # TODO: do lazy conversion. Keep always input parametrization.
        # Convert to self.group.PARAM only when performing operations and creating new group elements
        self._element = group._change_param(g, param, group.PARAM)

    def __eq__(self, other: GroupElement):
        if not isinstance(other, GroupElement) or other.group != self.group:
            return False
        # `param` is a required argument of `Group._equal`; both values are stored in `self.param`
        return self.group._equal(
            self._element, other._element,
            param=self.param,
            param1=self.param,
            param2=other.param,
        )

    def __matmul__(self, other: GroupElement):
        if not isinstance(other, GroupElement):
            return NotImplemented

        if other.group != self.group:
            raise NotImplementedError(f'Multiplication of group elements which belong to different groups is not supported.')

        # `param` is a required argument of `Group._combine`; the result is requested in the same
        # parametrization used to store this element
        return GroupElement(
            self.group._combine(
                self._element, other._element,
                param=self.param,
                param1=self.param,
                param2=other.param,
            ),
            self.group,
            self.param
        )

    def __invert__(self):
        return GroupElement(
            self.group._inverse(self._element, param=self.param),
            self.group,
            self.param
        )

    def __hash__(self):
        return self.group._hash_element(self._element, self.param)

    def __repr__(self):
        return self.group._repr_element(self._element, self.param)

    @property
    def value(self):
        r"""
        Returns the values of the internal parametrization of the group element.
        These values parametrize the group element according to the parametrization
        :attr:`escnn.group.GroupElement.param`.
        """
        return self._element

    @property
    def param(self) -> str:
        r"""
        The type parametrization used internally to store this group element.
        """
        return self.group.PARAM

    def to(self, param: str):
        r"""
        Converts the current group element to the input parametrization ``param`` and returns the corresponding
        values.

        .. note ::
            This method does *not* return an instance of :class:`~escnn.group.GroupElement`.
            This method does *not* affect the internal representation of the element, but just converts it to the
            input ``param`` and returns the converted values.

        """
        return self.group._change_param(self._element, self.param, param)