Source code for e2cnn.diffops.steerable_basis


import numpy as np

from e2cnn.kernels.basis import EmptyBasisException

from e2cnn.group import Representation, SO2, CyclicGroup

from .basis import DiffopBasis, DiscretizationArgs

from typing import Type, List


class SteerableDiffopBasis(DiffopBasis):
    """Basis of equivariant PDOs between two arbitrary representations."""

    def __init__(self,
                 irreps_basis: Type[DiffopBasis],
                 in_repr: Representation,
                 out_repr: Representation,
                 discretization: DiscretizationArgs = DiscretizationArgs(),
                 **kwargs):
        r"""

        Implements a general basis for the vector space of equivariant PDOs.
        A :math:`G`-equivariant PDO :math:`D(P)` for a matrix of polynomials :math:`P`,
        mapping between an input field, transforming under
        :math:`\rho_\text{in}` (``in_repr``), and an output field, transforming under
        :math:`\rho_\text{out}` (``out_repr``), satisfies the following constraint:

        .. math ::

            P(gx) = \rho_\text{out}(g) P(x) \rho_\text{in}(g)^{-1} \qquad \forall g \in G, \forall x \in X

        for :math:`G \leq \O{d}`.

        As the PDO constraint is a linear constraint, the space of equivariant PDOs is a vector subspace of the
        space of all PDOs. It follows that any equivariant PDO can be expressed in terms of a basis of this space.

        This class solves the PDO constraint for two arbitrary representations by combining the solutions of the
        PDO constraints associated to their :class:`~e2cnn.group.IrreducibleRepresentation` s.
        In order to do so, it relies on ``irreps_basis`` which solves individual irreps constraints.
        ``irreps_basis`` must be a class which builds a basis for equivariant PDOs associated with irreducible
        representations when instantiated.

        The groups :math:`G` which are currently implemented are origin-preserving isometries (what are called
        structure groups, or sometimes gauge groups, in the language of
        `Gauge Equivariant CNNs <https://arxiv.org/abs/1902.04615>`_ ).
        The origin-preserving isometries of :math:`\R^d` are subgroups of :math:`O(d)`, i.e. reflections and
        rotations. Therefore, PDOs may be composed with any rotation and reflection invariant PDO without
        affecting equivariance. This class only implements a basis up to such invariant PDOs, which are given by
        polynomials in the Laplacian operator.
        In order to build a complete basis of PDOs, you should combine this basis with
        :class:`~e2cnn.diffops.LaplaceProfile` through :class:`~e2cnn.diffops.TensorBasis`.

        .. warning ::

            Typically, the user does not need to manually instantiate this class.
            Instead, we suggest to use the interface provided in :doc:`e2cnn.gspaces`.

        Args:
            irreps_basis (class): class defining the irreps basis. This class is instantiated for each pair of
                    irreps to solve all irreps constraints.
            in_repr (Representation): Representation associated with the input feature field
            out_repr (Representation): Representation associated with the output feature field
            discretization (DiscretizationArgs, optional): additional parameters specifying parameters for
                    the discretization procedure. See :class:`~e2cnn.diffops.DiscretizationArgs`.
            **kwargs: additional arguments used when instantiating ``irreps_basis``

        """

        assert in_repr.group == out_repr.group

        self.in_repr = in_repr
        self.out_repr = out_repr
        group = in_repr.group
        self.group = group

        # Work on private copies of the change-of-basis matrices.
        A_inv = np.array(in_repr.change_of_basis_inv, copy=True)
        B = np.array(out_repr.change_of_basis, copy=True)

        # An identity change of basis is stored as None so that the
        # corresponding matrix product can be skipped later on.
        self.A_inv = None if np.allclose(A_inv, np.eye(in_repr.size)) else A_inv
        self.B = None if np.allclose(B, np.eye(out_repr.size)) else B

        # Solve the constraint once per *distinct* pair of irreps.
        self.irreps_bases = {}
        for i_irrep_name in set(in_repr.irreps):
            for o_irrep_name in set(out_repr.irreps):
                try:
                    # retrieve the irrep intertwiner basis
                    basis = irreps_basis(group=group,
                                         in_irrep=i_irrep_name,
                                         out_irrep=o_irrep_name,
                                         discretization=discretization,
                                         **kwargs)
                except EmptyBasisException:
                    # if the basis is empty, skip it
                    continue
                self.irreps_bases[(i_irrep_name, o_irrep_name)] = basis

        # bases[ii][oo] is the basis of the block mapping the ii-th input irrep
        # to the oo-th output irrep, or None when that block admits no basis.
        self.bases = [
            [self.irreps_bases.get((i_irrep_name, o_irrep_name)) for o_irrep_name in out_repr.irreps]
            for i_irrep_name in in_repr.irreps
        ]

        self.in_sizes = [group.irreps[name].size for name in in_repr.irreps]
        self.out_sizes = [group.irreps[name].size for name in out_repr.irreps]

        # would be set later anyway but we need it now
        # (``_direct_sum_coefficients`` reads it before ``super().__init__`` runs)
        self.shape = (out_repr.size, in_repr.size)

        coefficients = self._direct_sum_coefficients()
        if self.A_inv is not None or self.B is not None:
            coefficients = self._change_of_basis(coefficients)

        super().__init__(coefficients, discretization)

    def _direct_sum_coefficients(self) -> List[np.ndarray]:
        """Embed each irrep-block basis element into a full-size zero array.

        Blocks are visited with the input irreps in the outer loop and the
        output irreps in the inner loop. Every coefficient array of a block's
        basis is copied into an array of shape
        ``(self.shape[0], self.shape[1], element.shape[2])`` at the row/column
        offset of that block; all other entries are zero.

        Returns:
            the list of embedded coefficient arrays, in visiting order

        """
        coefficients: List[np.ndarray] = []

        in_position = 0
        for ii, in_size in enumerate(self.in_sizes):
            out_position = 0
            for oo, out_size in enumerate(self.out_sizes):
                block = self.bases[ii][oo]
                if block is not None:
                    for element in block.coefficients:
                        out = np.zeros((self.shape[0], self.shape[1], element.shape[2]))
                        out[
                            out_position:out_position + out_size,
                            in_position:in_position + in_size,
                            :
                        ] = element
                        coefficients.append(out)
                # the offsets advance even when the block is empty
                out_position += out_size
            in_position += in_size

        return coefficients

    def _change_of_basis(self, coefficients: List[np.ndarray]) -> List[np.ndarray]:
        """Express coefficients in the basis of the full representations.

        Each element is multiplied on the left by ``self.B`` and on the right
        by ``self.A_inv``; a matrix stored as ``None`` is skipped.

        Args:
            coefficients: coefficient arrays expressed in the irreps basis

        Returns:
            a new list containing the transformed arrays (the arrays themselves
            are returned unchanged when both matrices are ``None``)

        """
        A_inv, B = self.A_inv, self.B

        # Which matrices are present does not depend on the element, so the
        # case distinction is made once rather than per element.
        if A_inv is not None and B is not None:
            return [np.einsum("no,oib,ij->njb", B, element, A_inv) for element in coefficients]
        if A_inv is not None:
            return [np.einsum("oib,ij->ojb", element, A_inv) for element in coefficients]
        if B is not None:
            return [np.einsum("no,oib->nib", B, element) for element in coefficients]
        return list(coefficients)

    def _element_attributes(self, ii: int, oo: int, rel_idx: int, idx: int) -> dict:
        """Build the attribute dictionary of one basis element.

        Args:
            ii: position of the input irrep in ``self.in_repr.irreps``
            oo: position of the output irrep in ``self.out_repr.irreps``
            rel_idx: index of the element inside the block basis ``self.bases[ii][oo]``
            idx: index of the element inside this basis

        Returns:
            a copy of the block basis' attributes for the element, extended
            with the block shape and the irreps information; the block-local
            index is moved to ``"inner_idx"`` and ``"idx"`` is set to ``idx``

        """
        block = self.bases[ii][oo]
        attr = dict(block[rel_idx])
        attr["shape"] = block.shape
        attr["in_irrep"] = self.in_repr.irreps[ii]
        attr["out_irrep"] = self.out_repr.irreps[oo]
        attr["in_irrep_idx"] = ii
        attr["out_irrep_idx"] = oo
        attr["inner_idx"] = attr["idx"]
        attr["idx"] = idx
        return attr

    def __getitem__(self, idx):
        """Return the attribute dictionary of the ``idx``-th basis element.

        Raises:
            AssertionError: if ``idx`` is not smaller than ``self.dim``
            IndexError: if no element corresponds to ``idx`` (e.g. a negative
                index)

        """
        assert idx < self.dim

        count = 0
        for ii, row in enumerate(self.bases):
            for oo, block in enumerate(row):
                if block is None:
                    continue
                rel_idx = idx - count
                if 0 <= rel_idx < block.dim:
                    return self._element_attributes(ii, oo, rel_idx, idx)
                count += block.dim

        # Falling through used to return None implicitly; fail loudly instead.
        raise IndexError(f"basis index {idx} is out of range")

    def __iter__(self):
        """Yield the attribute dictionary of every basis element, in index order."""
        idx = 0
        for ii, row in enumerate(self.bases):
            for oo, block in enumerate(row):
                if block is None:
                    continue
                for rel_idx in range(len(block)):
                    yield self._element_attributes(ii, oo, rel_idx, idx)
                    idx += 1

    def __eq__(self, other):
        """Two bases are equal if they share representations and irreps bases."""
        if not isinstance(other, SteerableDiffopBasis):
            return False
        if self.in_repr != other.in_repr or self.out_repr != other.out_repr:
            return False
        # dict key views compare like sets, so no sorting is required
        if self.irreps_bases.keys() != other.irreps_bases.keys():
            return False
        for irreps, basis in self.irreps_bases.items():
            if basis != other.irreps_bases[irreps]:
                return False
        return True

    def __hash__(self):
        """Hash the representations together with every (irreps pair, basis) item."""
        h = hash((self.in_repr, self.out_repr))
        # summing keeps the result independent of the dictionary ordering
        for item in self.irreps_bases.items():
            h += hash(item)
        return h