# Source code for escnn.kernels.steerable_basis


from .basis import KernelBasis, EmptyBasisException
from .steerable_filters_basis import SteerableFiltersBasis

from escnn.group import Group
from escnn.group import IrreducibleRepresentation
from escnn.group import Representation

import torch

from typing import Type, Union, Tuple, Dict, List, Iterable, Callable, Set
from abc import ABC, abstractmethod
from collections import defaultdict


class IrrepBasis(KernelBasis):

    def __init__(self,
                 basis: SteerableFiltersBasis,
                 in_irrep: Union[IrreducibleRepresentation, Tuple],
                 out_irrep: Union[IrreducibleRepresentation, Tuple],
                 dim: int,
                 harmonics: List[Tuple] = None
                 ):
        r"""

        Abstract class for bases implementing the kernel constraint solutions associated to irreducible input and
        output representations.

        .. note ::
            The steerable *filter* ``basis`` is not necessarily associated with the same group as ``in_irrep`` and
            ``out_irrep``. For instance, :class:`~escnn.kernels.RestrictedWignerEckartBasis` uses a larger group to
            define ``basis``.
            The attribute ``IrrepBasis.group``, instead, refers to the equivariance group of this steerable *kernel*
            basis and is the same group of ``in_irrep`` and ``out_irrep``.
            The irreps in the list ``harmonics`` refer to the group in the steerable filter ``basis``, and not to
            ``IrrepBasis.group``.

        Args:
            basis (SteerableFiltersBasis): the steerable basis used to parameterize scalar filters and generate the
                                           kernel solutions
            in_irrep (IrreducibleRepresentation): the input irrep
            out_irrep (IrreducibleRepresentation): the output irrep
            dim (int): the number of elements in the basis
            harmonics (list, optional): optionally, use only a subset of the steerable filters in ``basis``. This list
                                        defines a subset of the group's irreps and is used to select only the steerable
                                        basis filters which transform according to these irreps.
                                        If ``None`` (default), all the harmonics listed in ``basis.js`` are used.

        Attributes:
            ~.group (Group): the equivariance group
            ~.in_irrep (IrreducibleRepresentation): the input irrep
            ~.out_irrep (IrreducibleRepresentation): the output irrep
            ~.basis (SteerableFiltersBasis): the steerable basis used to parameterize scalar filters

        """
        # NOTE(review): the annotations also allow an irrep id (tuple), but ``.size`` and ``.group`` are read
        # directly from the arguments here - confirm that callers always pass irrep objects.
        super(IrrepBasis, self).__init__(dim, (out_irrep.size, in_irrep.size))

        assert in_irrep.group == out_irrep.group

        self.group: Group = in_irrep.group
        # re-fetch the irreps from the group through their id
        self.in_irrep: IrreducibleRepresentation = self.group.irrep(*self.group.get_irrep_id(in_irrep))
        self.out_irrep: IrreducibleRepresentation = self.group.irrep(*self.group.get_irrep_id(out_irrep))

        self.basis: SteerableFiltersBasis = basis

        # ``harmonics=None`` means "no filtering": only build the lookup set when a subset was actually requested
        # (calling ``set(None)`` would raise a TypeError)
        allowed = None if harmonics is None else set(harmonics)

        # ordered list of the (unique) harmonics used by this basis, following the order in ``basis.js``
        self.js = []
        for j, _ in basis.js:
            if j not in self.js and (allowed is None or j in allowed):
                self.js.append(j)

        # offset of the first basis element associated with each harmonic ``j`` along the basis dimension
        self._start_index = {}
        idx = 0
        for _j in self.js:
            self._start_index[_j] = idx
            idx += self.dim_harmonic(_j)

    def sample(self, points: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
        r"""

        Sample the continuous basis elements on the discrete set of points in ``points``.
        Optionally, store the resulting multidimensional array in ``out``.

        ``points`` must be an array of shape `(N, d)`, where `N` is the number of points and `d` is the
        dimensionality of the Euclidean space where filters are defined.

        Args:
            points (~torch.Tensor): points where to evaluate the basis elements
            out (~torch.Tensor, optional): pre-existing array to use to store the output

        Returns:
            the sampled basis

        """
        assert len(points.shape) == 2
        S = points.shape[0]

        if out is None:
            out = torch.empty(S, self.dim, self.shape[0], self.shape[1], device=points.device, dtype=points.dtype)

        assert out.shape == (S, self.dim, self.shape[0], self.shape[1])

        steerable_basis = self.basis.sample_as_dict(points)

        # split ``out`` along the basis dimension in one view per harmonic, such that ``sample_harmonics`` writes
        # its results directly inside ``out``
        B = 0
        outs = {}
        for j in self.js:
            dim_j = self.dim_harmonic(j)
            outs[j] = out[:, B:B + dim_j, ...]
            B += dim_j

        self.sample_harmonics(steerable_basis, outs)
        return out

    @abstractmethod
    def sample_harmonics(self, points: Dict[Tuple, torch.Tensor], out: Dict[Tuple, torch.Tensor] = None) -> Dict[Tuple, torch.Tensor]:
        r"""

        Sample the continuous basis elements on the discrete set of points.
        Rather than using the points' coordinates, the method directly takes in input the steerable basis elements
        sampled on these points using the method :meth:`escnn.kernels.SteerableFiltersBasis.sample_as_dict` of
        ``self.basis``.

        Similarly, rather than returning a single tensor containing all sampled basis elements, it groups basis
        elements by the ``G``-irrep acting on them.
        The method returns a dictionary mapping each irrep's ``id`` to a tensor of shape `(N, m, o, i)`, where
        `N` is the number of points, `m` is the multiplicity of the irrep (see :meth:`dim_harmonic`) and `o, i` is
        the number of output and input channels (see the ``shape`` attribute).

        Optionally, store the resulting tensors in ``out``, rather than allocating new memory.

        Args:
            points (dict): the sampled steerable basis elements, indexed by the id of the irrep acting on them
            out (dict, optional): pre-existing tensors to use to store the output, indexed in the same way

        Returns:
            the sampled basis

        """
        raise NotImplementedError()

    @abstractmethod
    def dim_harmonic(self, j: Tuple) -> int:
        r'''
            Number of kernel basis elements generated from elements of the steerable filter basis (``self.basis``)
            which transform according to the ``self.basis.group``-irrep identified by ``j``.
        '''
        raise NotImplementedError()

    @abstractmethod
    def attrs_j_iter(self, j: Tuple) -> Iterable:
        # Iterate over the attributes of the basis elements associated with the harmonic ``j``.
        # NOTE(review): :class:`SteerableKernelBasis` expects each item to be a mutable dict with an "idx" entry.
        raise NotImplementedError()

    @abstractmethod
    def attrs_j(self, j: Tuple, idx) -> Dict:
        # Attributes of the ``idx``-th basis element among those associated with the harmonic ``j``.
        raise NotImplementedError()

    @classmethod
    @abstractmethod
    def _generator(cls,
                   basis: SteerableFiltersBasis,
                   psi_in: Union[IrreducibleRepresentation, Tuple],
                   psi_out: Union[IrreducibleRepresentation, Tuple],
                   **kwargs
                   ) -> 'IrrepBasis':
        # Factory used by :class:`SteerableKernelBasis` to build the basis for a pair of irreps.
        # :class:`SteerableKernelBasis` skips the pair when this raises ``EmptyBasisException``.
        raise NotImplementedError()
class SteerableKernelBasis(KernelBasis):

    def __init__(self,
                 basis: SteerableFiltersBasis,
                 in_repr: Representation,
                 out_repr: Representation,
                 irreps_basis: Type[IrrepBasis],
                 **kwargs):
        r"""

        Implements a general basis for the vector space of equivariant kernels over an Euclidean space
        :math:`X=\R^n`.
        A :math:`G`-equivariant kernel :math:`\kappa`, mapping between an input field, transforming under
        :math:`\rho_\text{in}` (``in_repr``), and an output field, transforming under :math:`\rho_\text{out}`
        (``out_repr``), satisfies the following constraint:

        .. math ::

            \kappa(gx) = \rho_\text{out}(g) \kappa(x) \rho_\text{in}(g)^{-1} \qquad \forall g \in G, \forall x \in X

        As the kernel constraint is a linear constraint, the space of equivariant kernels is a vector subspace of the
        space of all convolutional kernels. It follows that any equivariant kernel can be expressed in terms of a basis
        of this space.

        This class solves the kernel constraint for two arbitrary representations by combining the solutions of the
        kernel constraints associated to their :class:`~escnn.group.IrreducibleRepresentation` s.
        In order to do so, it relies on ``irreps_basis`` which solves individual irreps constraints.
        ``irreps_basis`` must be a class (subclass of :class:`~escnn.kernels.IrrepBasis`) which builds a basis for
        equivariant kernels associated with irreducible representations when instantiated.

        The groups :math:`G` which are currently implemented are origin-preserving isometries (what are called
        structure groups, or sometimes gauge groups, in the language of
        `Gauge Equivariant CNNs <https://arxiv.org/abs/1902.04615>`_ ).
        The origin-preserving isometries of :math:`\R^d` are subgroups of :math:`O(d)`, i.e. reflections and rotations.
        Therefore, equivariance does not enforce any constraint on the radial component of the kernels.
        Hence, this class only implements a basis for the angular part of the kernels.

        .. warning ::

            Typically, the user does not need to manually instantiate this class.
            Instead, we suggest to use the interface provided in :doc:`escnn.gspaces`.

        Args:
            basis (SteerableFiltersBasis): a steerable basis for scalar filters over the base space
            in_repr (Representation): Representation associated with the input feature field
            out_repr (Representation): Representation associated with the output feature field
            irreps_basis (class): class defining the irreps basis. This class is instantiated for each pair of irreps
                                  to solve all irreps constraints.
            **kwargs: additional arguments used when instantiating ``irreps_basis``

        Attributes:
            ~.group (Group): the equivariance group ``G``.
            ~.in_repr (Representation): the input representation
            ~.out_repr (Representation): the output representation

        """

        assert in_repr.group == out_repr.group

        self.in_repr: Representation = in_repr
        self.out_repr: Representation = out_repr
        group = in_repr.group
        self.group: Group = group

        self._irrep_basis = irreps_basis
        self._irrep__basis_kwargs = kwargs

        ################

        # Dict[Tuple, IrrepBasis]: one basis per *distinct* pair of (input irrep id, output irrep id)
        self.irreps_bases = {}

        js = set()

        # loop over all input irreps
        for i_irrep_id in set(in_repr.irreps):
            # loop over all output irreps
            for o_irrep_id in set(out_repr.irreps):
                try:
                    # retrieve the irrep intertwiner basis
                    intertwiner_basis = irreps_basis._generator(basis, i_irrep_id, o_irrep_id, **kwargs)
                except EmptyBasisException:
                    # if the basis is empty, skip it
                    continue

                assert intertwiner_basis.group == self.group

                self.irreps_bases[(i_irrep_id, o_irrep_id)] = intertwiner_basis

                # compute the set of all harmonics needed by all intertwiners bases
                # in this way, we can precompute the embedding of the points with the harmonics and reuse the same
                # embeddings for all irreps bases
                js.update(intertwiner_basis.js)

        # total number of basis elements associated with each harmonic ``j``
        self._dim_harmonics = defaultdict(int)

        # ``self.bases[ii][oo]`` is the basis for the ii-th input irrep and the oo-th output irrep (or None if empty)
        self.bases = [[None for _ in range(len(out_repr.irreps))] for _ in range(len(in_repr.irreps))]

        dim = 0
        # loop over all input irreps
        for ii, i_irrep_id in enumerate(in_repr.irreps):
            # loop over all output irreps
            for oo, o_irrep_id in enumerate(out_repr.irreps):
                pair_basis = self.irreps_bases.get((i_irrep_id, o_irrep_id))
                if pair_basis is None:
                    continue

                self.bases[ii][oo] = pair_basis
                dim += pair_basis.dim
                for j in pair_basis.js:
                    self._dim_harmonics[j] += pair_basis.dim_harmonic(j)

        ################

        # before registering tensors as buffers and sub-modules, we need to call torch.nn.Module.__init__()
        super(SteerableKernelBasis, self).__init__(dim, (out_repr.size, in_repr.size))

        self.basis = basis

        for io_pair, intertwiner_basis in self.irreps_bases.items():
            self.add_module(f'basis_{io_pair}', intertwiner_basis)

        ################

        A_inv = torch.tensor(in_repr.change_of_basis_inv, dtype=torch.float32).clone()
        B = torch.tensor(out_repr.change_of_basis, dtype=torch.float32).clone()

        # the change of basis matrices are stored only if they are not the identity, such that the corresponding
        # matrix multiplications can be skipped when sampling
        if not torch.allclose(A_inv, torch.eye(in_repr.size)):
            self.register_buffer('A_inv', A_inv)
        else:
            self.A_inv = None

        if not torch.allclose(B, torch.eye(out_repr.size)):
            self.register_buffer('B', B)
        else:
            self.B = None

        # ordered list of the (unique) harmonics used by at least one irreps basis, following the order in
        # ``basis.js``; duplicates are dropped since ``dim_harmonic`` already accounts for all the elements of a
        # harmonic
        self.js = []
        for j, _ in self.basis.js:
            if j in js and j not in self.js:
                self.js.append(j)

        trivial_id = self.basis.group.trivial_representation.id
        if trivial_id in self.js:
            # make sure that the harmonic corresponding to the trivial representation is the first in the list.
            self.js.remove(trivial_id)
            self.js = [trivial_id] + self.js

        # sizes of the irreps in the input and in the output representations
        self.in_sizes = [group.irrep(*i_irrep_id).size for i_irrep_id in in_repr.irreps]
        self.out_sizes = [group.irrep(*o_irrep_id).size for o_irrep_id in out_repr.irreps]

        # ``self._slices[(ii, oo)][j]`` stores the block occupied by the irreps basis ``self.bases[ii][oo]`` inside
        # the tensor associated with the harmonic ``j``, as the tuple
        #       (out_start, out_end, in_start, in_end, basis_start, basis_end)
        self._slices = defaultdict(dict)
        basis_count = defaultdict(int)
        in_position = 0
        for ii, in_size in enumerate(self.in_sizes):
            out_position = 0
            for oo, out_size in enumerate(self.out_sizes):
                pair_basis = self.bases[ii][oo]
                if pair_basis is not None:
                    for j in pair_basis.js:
                        dim_j = pair_basis.dim_harmonic(j)
                        self._slices[(ii, oo)][j] = (
                            out_position,
                            out_position + out_size,
                            in_position,
                            in_position + in_size,
                            basis_count[j],
                            basis_count[j] + dim_j
                        )
                        basis_count[j] += dim_j

                out_position += out_size
            in_position += in_size

    def dim_harmonic(self, j: Tuple) -> int:
        r'''
            Number of kernel basis elements generated from elements of the steerable filter basis (``self.basis``)
            which transform according to the ``self.basis.group``-irrep identified by ``j``.
        '''
        return self._dim_harmonics[j]

    def compute_harmonics(self, points: torch.Tensor) -> Dict[Tuple, torch.Tensor]:
        r"""
            Pre-compute the sampled steerable filter basis over a set of points.
            This is an alias for ``self.basis.sample_as_dict(points)``.

            .. seealso ::
                :meth:`escnn.kernels.SteerableFiltersBasis.sample_as_dict`.

        """
        return self.basis.sample_as_dict(points)

    def sample(self, points: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
        r"""

        Sample the continuous basis elements on the discrete set of points in ``points``.
        Optionally, store the resulting multidimensional array in ``out``.

        ``points`` must be an array of shape `(N, d)`, where `N` is the number of points.

        Args:
            points (~torch.Tensor): points where to evaluate the basis elements
            out (~torch.Tensor, optional): pre-existing array to use to store the output

        Returns:
            the sampled basis

        """
        assert len(points.shape) == 2
        S = points.shape[0]

        expected_shape = (S, self.dim, self.shape[0], self.shape[1])

        if out is None:
            out = torch.zeros(expected_shape, device=points.device, dtype=points.dtype)
        else:
            # validate the buffer *before* overwriting its content
            assert out.shape == expected_shape
            # the entries not covered by any irreps basis are never written, so they need to be zero
            out[:] = 0.

        # split ``out`` along the basis dimension in one view per harmonic, such that ``sample_harmonics`` writes
        # its results directly inside ``out``
        outs = {}
        B = 0
        for j in self.js:
            dim_j = self.dim_harmonic(j)
            outs[j] = out[:, B:B + dim_j, ...]
            B += dim_j

        steerable_basis = self.compute_harmonics(points)
        self.sample_harmonics(steerable_basis, outs)
        return out

    def sample_harmonics(self, points: Dict[Tuple, torch.Tensor], out: Dict[Tuple, torch.Tensor] = None) -> Dict[Tuple, torch.Tensor]:
        r"""

        Sample the continuous basis elements on the discrete set of points.
        Rather than using the points' coordinates, the method directly takes in input the steerable basis elements
        sampled on these points using the method :meth:`escnn.kernels.SteerableKernelBasis.compute_harmonics`.

        Similarly, rather than returning a single tensor containing all sampled basis elements, it groups basis
        elements by the ``G``-irrep acting on them.
        The method returns a dictionary mapping each irrep's ``id`` to a tensor of shape `(N, m, o, i)`, where
        `N` is the number of points, `m` is the multiplicity of the irrep
        (see :meth:`~escnn.kernels.SteerableKernelBasis.dim_harmonic`) and `o, i` is the number of output and input
        channels (see the ``shape`` attribute).

        Optionally, store the resulting tensors in ``out``, rather than allocating new memory.
        When no change of basis is needed, only the blocks covered by an irreps basis are written in ``out``: the
        remaining entries are left untouched, so the tensors passed should be zero-initialized.

        Args:
            points (dict): the sampled steerable basis elements, indexed by the id of the irrep acting on them
            out (dict, optional): pre-existing tensors to use to store the output, indexed in the same way

        Returns:
            the sampled basis

        """
        if out is None:
            out = self._allocate_harmonics(points)

        self._check_harmonics(points, out)

        if self.A_inv is None and self.B is None:
            # no change of basis: the irreps bases can write directly in the output tensors
            out = self._sample_direct_sum(points, out=out)
        else:
            samples = self._sample_direct_sum(points)
            out = self._change_of_basis(samples, out=out)

        return out

    def _harmonic_shape(self, points: Dict[Tuple, torch.Tensor], j: Tuple) -> Tuple:
        # Shape of the tensor storing the basis elements associated with the harmonic ``j``.
        return (points[j].shape[0], self.dim_harmonic(j), self.shape[0], self.shape[1])

    def _allocate_harmonics(self, points: Dict[Tuple, torch.Tensor]) -> Dict[Tuple, torch.Tensor]:
        # Allocate one zero tensor per harmonic, matching the device and the dtype of the sampled filters.
        return {
            j: torch.zeros(self._harmonic_shape(points, j), device=points[j].device, dtype=points[j].dtype)
            for j in self.js
        }

    def _check_harmonics(self, points: Dict[Tuple, torch.Tensor], out: Dict[Tuple, torch.Tensor]) -> None:
        # Check the shape of all the output tensors provided.
        for j in self.js:
            if j in out:
                assert out[j].shape == self._harmonic_shape(points, j)

    def _sample_direct_sum(self,
                           points: Dict[Tuple, torch.Tensor],
                           out: Dict[Tuple, torch.Tensor] = None
                           ) -> Dict[Tuple, torch.Tensor]:
        # Sample the basis expressed with respect to the irreps decomposition of the input and output
        # representations, i.e. before applying the change of basis matrices.
        # Each irreps basis fills its own block of the tensors in ``out``; entries outside these blocks are not
        # written.

        if out is None:
            out = self._allocate_harmonics(points)

        self._check_harmonics(points, out)

        for ii in range(len(self.in_sizes)):
            for oo in range(len(self.out_sizes)):
                pair_basis = self.bases[ii][oo]
                if pair_basis is None:
                    continue

                # views on the blocks of the output tensors reserved to this pair of irreps
                blocks = {
                    j: out[j][:, b_s:b_e, o_s:o_e, i_s:i_e]
                    for j, (o_s, o_e, i_s, i_e, b_s, b_e) in self._slices[(ii, oo)].items()
                }
                pair_basis.sample_harmonics(points, out=blocks)

        return out

    def _change_of_basis(self,
                         samples: Dict[Tuple, torch.Tensor],
                         out: Dict[Tuple, torch.Tensor] = None
                         ) -> Dict[Tuple, torch.Tensor]:
        # multiply by the change of basis matrices to transform the irreps basis in the full representations basis
        #
        # If ``out`` is given, the results are written in place inside its tensors.
        # Otherwise, a new dictionary is returned; when no change of basis is stored, its values are the very same
        # tensors contained in ``samples``.

        allocate = out is None
        if allocate:
            out = {}

        for j, sample in samples.items():
            if self.A_inv is not None and self.B is not None:
                transformed = torch.einsum(
                    "no,pboi,ij->pbnj", self.B.to(sample.dtype), sample, self.A_inv.to(sample.dtype)
                )
            elif self.A_inv is not None:
                transformed = torch.einsum("pboi,ij->pboj", sample, self.A_inv.to(sample.dtype))
            elif self.B is not None:
                transformed = torch.einsum("no,pboi->pbni", self.B.to(sample.dtype), sample)
            else:
                transformed = sample

            if allocate:
                out[j] = transformed
            else:
                out[j][...] = transformed

        return out

    def __getitem__(self, idx):
        # Return the dictionary of attributes describing the ``idx``-th basis element.
        assert 0 <= idx < self.dim, idx

        # find the harmonic ``j`` containing the element and the index ``j_idx`` of the element within it
        j_idx = idx
        for j in self.js:
            dim = self.dim_harmonic(j)
            if j_idx >= dim:
                j_idx -= dim
            else:
                break

        assert j_idx < self.dim_harmonic(j), (j_idx, self.dim_harmonic(j))

        # find the pair of irreps whose basis generated the element
        count = 0
        for ii in range(len(self.in_sizes)):
            for oo in range(len(self.out_sizes)):
                pair_basis = self.bases[ii][oo]
                if pair_basis is None:
                    continue

                dim = pair_basis.dim_harmonic(j)
                rel_idx = j_idx - count
                if 0 <= rel_idx < dim:
                    # copy the attributes to avoid modifying the dictionary owned by the irreps basis
                    attr = dict(pair_basis.attrs_j(j, rel_idx))
                    attr["shape"] = pair_basis.shape
                    attr["irreps_basis_idx"] = attr["idx"]
                    attr["idx"] = idx
                    attr["j"] = j
                    attr["j_idx"] = j_idx
                    attr["in_irrep"] = self.in_repr.irreps[ii]
                    attr["out_irrep"] = self.out_repr.irreps[oo]
                    attr["in_irrep_idx"] = ii
                    attr["out_irrep_idx"] = oo
                    return attr

                count += dim

    def __iter__(self):
        # Iterate over the attributes of all basis elements, grouped by harmonic and, within each harmonic, by
        # pair of input and output irreps (same order used to store the sampled basis).
        # NOTE: the dictionaries produced by ``attrs_j_iter`` are updated in place before being yielded.
        idx = 0
        for j in self.js:
            j_idx = 0
            for ii in range(len(self.in_sizes)):
                for oo in range(len(self.out_sizes)):
                    basis = self.bases[ii][oo]
                    if basis is None:
                        continue

                    for attr in basis.attrs_j_iter(j):
                        attr["shape"] = basis.shape
                        attr["in_irrep"] = self.in_repr.irreps[ii]
                        attr["out_irrep"] = self.out_repr.irreps[oo]
                        attr["in_irrep_idx"] = ii
                        attr["out_irrep_idx"] = oo
                        attr["j"] = j
                        attr["j_idx"] = j_idx
                        attr["irreps_basis_idx"] = attr["idx"]
                        attr["idx"] = idx

                        assert idx < self.dim

                        yield attr

                        idx += 1
                        j_idx += 1

    def __eq__(self, other):
        # Two bases are equal if they share the filter basis, the representations and all the irreps bases.
        if not isinstance(other, SteerableKernelBasis):
            return False

        if self.basis != other.basis or self.in_repr != other.in_repr or self.out_repr != other.out_repr:
            return False

        # the two bases need to be defined over the same pairs of irreps
        if self.irreps_bases.keys() != other.irreps_bases.keys():
            return False

        for irreps, basis in self.irreps_bases.items():
            if basis != other.irreps_bases[irreps]:
                return False

        return True

    def __hash__(self):
        # Order-independent combination of the hashes of all the components compared in ``__eq__``.
        h = hash(self.in_repr) + hash(self.out_repr) + hash(self.basis)
        for item in self.irreps_bases.items():
            h += hash(item)
        return h