Source code for e2cnn.kernels.steerable_basis


import numpy as np

from .basis import KernelBasis, EmptyBasisException
from .irreps_basis import IrrepBasis

from e2cnn.group import Representation

from typing import Type


class SteerableKernelBasis(KernelBasis):

    def __init__(self,
                 irreps_basis: Type[IrrepBasis],
                 in_repr: Representation,
                 out_repr: Representation,
                 **kwargs):
        r"""

        Implements a general basis for the vector space of equivariant kernels.
        A :math:`G`-equivariant kernel :math:`\kappa`, mapping between an input field, transforming under
        :math:`\rho_\text{in}` (``in_repr``), and an output field, transforming under  :math:`\rho_\text{out}`
        (``out_repr``), satisfies the following constraint:

        .. math ::

            \kappa(gx) = \rho_\text{out}(g) \kappa(x) \rho_\text{in}(g)^{-1} \qquad \forall g \in G, \forall x \in X

        As the kernel constraint is a linear constraint, the space of equivariant kernels is a vector subspace of the
        space of all convolutional kernels. It follows that any equivariant kernel can be expressed in terms of a basis
        of this space.

        This class solves the kernel constraint for two arbitrary representations by combining the solutions of the
        kernel constraints associated with their :class:`~e2cnn.group.IrreducibleRepresentation` s.
        In order to do so, it relies on ``irreps_basis`` which solves individual irreps constraints. ``irreps_basis``
        must be a class (subclass of :class:`~e2cnn.kernels.IrrepBasis`) which builds a basis for equivariant kernels
        associated with irreducible representations when instantiated.

        The groups :math:`G` which are currently implemented are origin-preserving isometries (what are called
        structure groups, or sometimes gauge groups, in the language of
        `Gauge Equivariant CNNs <https://arxiv.org/abs/1902.04615>`_ ).
        The origin-preserving isometries of :math:`\mathbb{R}^d` are subgroups of :math:`O(d)`, i.e. reflections and
        rotations.
        Therefore, equivariance does not enforce any constraint on the radial component of the kernels.
        Hence, this class only implements a basis for the angular part of the kernels.

        In order to build a complete basis of kernels, you should combine this basis with a basis which defines the
        radial profile (such as :class:`~e2cnn.kernels.GaussianRadialProfile`) through
        :class:`~e2cnn.kernels.PolarBasis`.

        .. math::

            \mathcal{B} = \left\{ b_i (r) :=  \exp \left( - \frac{ \left( r - r_i \right)^2}{2 \sigma_i^2} \right) \right\}_i

        .. warning ::

            Typically, the user does not need to manually instantiate this class.
            Instead, we suggest using the interface provided in :doc:`e2cnn.gspaces`.

        Args:
            irreps_basis (class): class defining the irreps basis. This class is instantiated for each pair of irreps to
                                  solve all irreps constraints.
            in_repr (Representation): Representation associated with the input feature field
            out_repr (Representation): Representation associated with the output feature field
            **kwargs: additional arguments used when instantiating ``irreps_basis``

        """

        assert in_repr.group == out_repr.group

        self.in_repr = in_repr
        self.out_repr = out_repr
        group = in_repr.group
        self.group = group

        A_inv = np.array(in_repr.change_of_basis_inv, copy=True)
        B = np.array(out_repr.change_of_basis, copy=True)

        # Identity change-of-basis matrices are stored as None so that ``sample`` can skip the einsum
        # in ``_change_of_basis`` and write the direct-sum samples straight into the output buffer.
        self.A_inv = None if np.allclose(A_inv, np.eye(in_repr.size)) else A_inv
        self.B = None if np.allclose(B, np.eye(out_repr.size)) else B

        # Build one irreps basis per *distinct* (input irrep, output irrep) pair; repeated irreps share it.
        self.irreps_bases = {}
        for i_irrep_name in set(in_repr.irreps):
            for o_irrep_name in set(out_repr.irreps):
                try:
                    basis = irreps_basis(group=group, in_irrep=i_irrep_name, out_irrep=o_irrep_name, **kwargs)
                except EmptyBasisException:
                    # no non-zero equivariant kernel maps this input irrep to this output irrep: skip the pair
                    continue
                self.irreps_bases[(i_irrep_name, o_irrep_name)] = basis

        # sizes of the irreps appearing (with multiplicity, in order) in the input and output representations
        self.in_sizes = [group.irreps[name].size for name in in_repr.irreps]
        self.out_sizes = [group.irreps[name].size for name in out_repr.irreps]

        # self.bases[ii][oo] is the basis of the block mapping the ii-th input irrep to the oo-th output irrep,
        # or None when that block is forced to be zero.
        self.bases = [[None for _ in range(len(out_repr.irreps))] for _ in range(len(in_repr.irreps))]

        dim = 0
        for ii, i_irrep_name in enumerate(in_repr.irreps):
            for oo, o_irrep_name in enumerate(out_repr.irreps):
                basis = self.irreps_bases.get((i_irrep_name, o_irrep_name))
                if basis is not None:
                    self.bases[ii][oo] = basis
                    dim += basis.dim

        super(SteerableKernelBasis, self).__init__(dim, (out_repr.size, in_repr.size))

    def sample(self, angles: np.ndarray, out: np.ndarray = None) -> np.ndarray:
        r"""

        Sample the continuous basis elements on the discrete set of angles in ``angles``.
        Optionally, store the resulting multidimensional array in ``out``.

        A value of ``nan`` is interpreted as the angle of a point placed on the origin of the axes.

        ``angles`` must be an array of shape `(1, N)`, where `N` is the number of points.

        Args:
            angles (~numpy.ndarray): angles where to evaluate the basis elements
            out (~numpy.ndarray, optional): pre-existing array to use to store the output; it must have shape
                    ``(self.shape[0], self.shape[1], self.dim, N)`` and is overwritten

        Returns:
            the sampled basis, an array of shape ``(self.shape[0], self.shape[1], self.dim, N)``

        """
        assert len(angles.shape) == 2
        assert angles.shape[0] == 1

        expected_shape = (self.shape[0], self.shape[1], self.dim, angles.shape[1])
        if out is None:
            out = np.zeros(expected_shape)
        else:
            # validate before clearing, so a wrongly-shaped buffer is not clobbered
            assert out.shape == expected_shape
            # blocks between non-intertwining irreps are never written by _sample_direct_sum: they must be zero
            out.fill(0)

        if self.A_inv is None and self.B is None:
            out = self._sample_direct_sum(angles, out=out)
        else:
            samples = self._sample_direct_sum(angles)
            out = self._change_of_basis(samples, out=out)

        return out

    def _sample_direct_sum(self, angles: np.ndarray, out: np.ndarray = None) -> np.ndarray:
        # Sample the kernel in the irreps basis, i.e. before applying the representations' change of basis.
        # Each non-empty irreps block is sampled in place into its slice of ``out``; other entries are left untouched.
        assert len(angles.shape) == 2

        if out is None:
            out = np.zeros((self.shape[0], self.shape[1], self.dim, angles.shape[1]))

        assert out.shape == (self.shape[0], self.shape[1], self.dim, angles.shape[1])

        basis_count = 0
        in_position = 0
        for ii, in_size in enumerate(self.in_sizes):
            out_position = 0
            for oo, out_size in enumerate(self.out_sizes):
                if self.bases[ii][oo] is not None:
                    dim = self.bases[ii][oo].dim
                    # basic slicing returns a view, so the irreps basis writes directly into ``out``
                    block = out[
                        out_position:out_position + out_size,
                        in_position:in_position + in_size,
                        basis_count:basis_count + dim,
                        ...
                    ]
                    self.bases[ii][oo].sample(angles, out=block)
                    basis_count += dim
                out_position += out_size
            in_position += in_size

        return out

    def _change_of_basis(self, samples: np.ndarray, out: np.ndarray = None) -> np.ndarray:
        # Multiply by the change-of-basis matrices to map the irreps-basis samples into the basis of the full
        # representations: B @ samples @ A_inv, applied over the first two axes.
        if self.A_inv is not None and self.B is not None:
            out = np.einsum("no,oibp,ij->njbp", self.B, samples, self.A_inv, out=out)
        elif self.A_inv is not None:
            out = np.einsum("oibp,ij->ojbp", samples, self.A_inv, out=out)
        elif self.B is not None:
            out = np.einsum("no,oibp->nibp", self.B, samples, out=out)
        else:
            out[...] = samples

        return out

    def _nonzero_blocks(self):
        # Yield (ii, oo, basis) for every non-empty irreps block, in the same order used by _sample_direct_sum.
        for ii in range(len(self.in_sizes)):
            for oo in range(len(self.out_sizes)):
                basis = self.bases[ii][oo]
                if basis is not None:
                    yield ii, oo, basis

    def _block_attributes(self, ii, oo, basis, inner_attr, idx):
        # Extend the attributes of one element of an irreps basis with its position in this basis.
        attr = dict(inner_attr)
        attr["shape"] = basis.shape
        attr["in_irrep"] = self.in_repr.irreps[ii]
        attr["out_irrep"] = self.out_repr.irreps[oo]
        attr["in_irrep_idx"] = ii
        attr["out_irrep_idx"] = oo
        attr["inner_idx"] = attr["idx"]
        attr["idx"] = idx
        return attr

    def __getitem__(self, idx):
        # Negative indices count from the end, as for Python sequences.
        if idx < 0:
            idx += self.dim
        assert 0 <= idx < self.dim, "basis index out of range"

        count = 0
        for ii, oo, basis in self._nonzero_blocks():
            if idx < count + basis.dim:
                return self._block_attributes(ii, oo, basis, basis[idx - count], idx)
            count += basis.dim

    def __iter__(self):
        idx = 0
        for ii, oo, basis in self._nonzero_blocks():
            for rel_idx in range(len(basis)):
                yield self._block_attributes(ii, oo, basis, basis[rel_idx], idx)
                idx += 1

    def __eq__(self, other):
        if not isinstance(other, SteerableKernelBasis):
            return False
        elif self.in_repr != other.in_repr or self.out_repr != other.out_repr:
            return False
        else:
            sbk1 = sorted(self.irreps_bases.keys())
            sbk2 = sorted(other.irreps_bases.keys())
            if sbk1 != sbk2:
                return False

            for irreps, basis in self.irreps_bases.items():
                if basis != other.irreps_bases[irreps]:
                    return False

            return True

    def __hash__(self):
        # Order-independent sum over the ((in_irrep, out_irrep), basis) items, consistent with __eq__.
        h = hash(self.in_repr) + hash(self.out_repr)
        for item in self.irreps_bases.items():
            h += hash(item)
        return h