Source code for escnn.kernels.sparse_basis


import numpy as np
import torch

from .steerable_filters_basis import SteerableFiltersBasis


from escnn.group import *

from typing import Type, Union, Tuple, Dict, List, Iterable, Callable, Set


class SparseOrbitBasis(SteerableFiltersBasis):

    def __init__(self,
                 X: HomSpace,
                 action: Representation,
                 root: np.ndarray,
                 sigma: float,
                 harmonics: List = None,
                 change_of_basis: np.ndarray = None,
                 attributes: Dict = None,
                 ):
        r"""

        Class which implements a steerable basis for scalar functions on an Euclidean space :math:`\R^n` by
        considering a single *orbit* of the symmetry group :math:`G` in :math:`\R^n`.

        Indeed, an orbit :math:`\{g.x | g \in G\}`, with :math:`x \in \R^n`, is isomorphic to a homogeneous space
        ``X`` of :math:`G`.
        A steerable basis for scalar functions over a homogeneous space is given by the harmonic basis in
        :class:`~escnn.group.HomSpace`.

        In this class, a steerable basis over :math:`\R^n` is defined by embedding the finite number of points of a
        *finite* homogeneous space ``X`` into :math:`\R^n` and then "diffusing" the harmonic basis defined over
        these points in the ambient space using a Gaussian kernel.

        This class only supports finite homogeneous spaces generated by a finite symmetry group, i.e. not only
        :math:`X=G/H` must have a finite number of elements but also :math:`G` itself should.

        The embedding of ``X`` into :math:`\R^n` is defined via the point ``root``, which defines the embedding of
        the coset :math:`eH` associated with the identity, and the ``action`` of the group :math:`G` on
        :math:`\R^n`.
        This implies that the action of any element of :math:`H` should keep ``root`` fixed.

        Args:
            X (HomSpace): the homogeneous space isomorphic to the orbit in the Euclidean space
            action (Representation): the action of the group ``G`` on the Euclidean space
            root (~np.ndarray): the embedding of the identity coset; must have shape ``(dimensionality,)``
            sigma (float): the standard deviation of the Gaussian kernel used to diffuse the harmonic basis in the
                ambient space. Must be strictly positive.
            harmonics (list, optional): optionally, select only a subset of the harmonic functions.
                The list should contain the ids of group's irreps which should be used to construct the harmonic
                basis. See also :meth:`~escnn.group.HomSpace.scalar_basis`.
            change_of_basis (~np.ndarray, optional): optionally, apply a transformation on the embedded points to
                construct a different basis. Equivalently, one can apply this change of basis to both ``root`` and
                ``action``.
            attributes (dict, optional): additional attributes to describe the generated basis.

        Attributes:
            ~.X (HomSpace): the homogeneous space

        """

        self.X: HomSpace = X

        # check that the homogeneous space has a finite number of elements
        # unfortunately, we can not directly check the size of the homogeneous space G/H so we check that G itself has
        # only a finite number of elements
        assert self.X.G.order() > 0

        if harmonics is None:
            harmonics = [psi.id for psi in self.X.G.irreps()]

        # Keep only the irreps which actually contribute harmonics, together with their multiplicity.
        # The multiplicity is queried once per irrep (this assumes `dimension_basis` has no side effects).
        trivial_id = self.X.H.trivial_representation.id
        js = []
        for j in harmonics:
            multiplicity = self.X.dimension_basis(j, trivial_id)[1]
            if multiplicity > 0:
                js.append((j, multiplicity))

        super(SparseOrbitBasis, self).__init__(self.X.G, action, js)

        assert root.shape == (self.dimensionality,), (root.shape, self.dimensionality)

        self.root = root
        self.action = action

        # work with a column vector so that `action(g) @ root` is a matrix-vector product
        root = root.reshape(-1, 1)

        # every element of the stabilizer H must leave the root fixed
        for h in self.X.H.elements:
            assert np.allclose(root, action(self.X._inclusion(h)) @ root, atol=1e-5, rtol=1e-3)

        # the orbit of the root under the whole group; columns are points (with repetitions)
        points = np.concatenate([
            action(g) @ root for g in self.X.G.elements
        ], axis=1)

        # remove the repeated points (up to rounding); `idx` holds, for each distinct point, the index of one
        # group element generating it
        _, idx = np.unique(
            np.round(points, decimals=4),
            axis=1, return_index=True
        )
        points = points[:, idx].T
        assert points.shape == (self.X.G.order() / self.X.H.order(), self.dimensionality)

        assert sigma > 0., sigma

        # for each selected irrep, evaluate the harmonic basis at the group elements representing the orbit points
        # and store the result as a buffer of shape (number of points, multiplicity, irrep size)
        for j, m in self.js:
            _harmonics_j = torch.tensor(np.stack([
                self.X.scalar_basis(self.X.G.elements[i], j)[..., 0].T
                for i in idx
            ], axis=0), dtype=torch.float32)

            assert _harmonics_j.shape == (idx.shape[0], m, self.X.G.irrep(*j).size), (_harmonics_j.shape, m, j)
            self.register_buffer(f'harmonics_{j}', _harmonics_j)

        self.sigma = sigma

        points = torch.tensor(points, dtype=torch.float32)

        if change_of_basis is not None:
            assert change_of_basis.shape == (self.dimensionality, self.dimensionality)
            # check that the matrix is invertible
            assert np.isfinite(np.linalg.cond(change_of_basis))
            change_of_basis = torch.tensor(change_of_basis, dtype=points.dtype, device=points.device)
            points = points @ change_of_basis.T

        self.register_buffer('points', points)
        self.change_of_basis = change_of_basis

        self._attributes = attributes if attributes is not None else dict()

    def _get_harmonics(self, j):
        # Return the buffer registered in `__init__` under the name `harmonics_{j}`.
        return getattr(self, f'harmonics_{j}')

    def sample(self, points: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
        r"""

        Sample the continuous basis elements on the discrete set of points in ``points``.
        Optionally, store the resulting multidimensional array in ``out``.

        ``points`` must be an array of shape `(N, d)`, where `N` is the number of points and `d` is the
        :meth:`~escnn.kernels.SteerableFiltersBasis.dimensionality` of the Euclidean space.

        Args:
            points (~torch.Tensor): points where to evaluate the basis elements
            out (~torch.Tensor, optional): pre-existing array to use to store the output.
                If given, it must have shape `(N, dim, 1, 1)` and live on the same device as ``points``.

        Returns:
            the sampled basis, a tensor of shape `(N, dim, 1, 1)`

        """
        assert len(points.shape) == 2
        S = points.shape[0]
        assert points.shape[1] == self.dimensionality

        if out is None:
            out = torch.empty((S, self.dim, 1, 1), dtype=self.points.dtype, device=self.points.device)

        assert out.shape == (S, self.dim, 1, 1)
        assert self.points.device == points.device, (self.points.device, points.device)
        assert out.device == points.device, (out.device, points.device)

        # pairwise differences between the orbit points and the sampling points
        weights = self.points.unsqueeze(1) - points.unsqueeze(0)
        assert weights.shape == (self.points.shape[0], S, self.dimensionality)

        # Gaussian weights exp(-|x - p|^2 / (2 sigma^2)), of shape (number of orbit points, S)
        weights = (weights**2).sum(axis=2) / self.sigma**2
        weights = torch.exp(- 0.5 * weights)

        B = 0
        for j, m in self.js:
            dim_j = self.dim_harmonic(j)
            # sum over the orbit points `i` the harmonics weighted by the Gaussian centered on each point;
            # the result is written in place in the slice of `out` reserved to the irrep `j`
            out[:, B:B + dim_j, ...].view(
                S, m, -1
            )[:] = torch.einsum('irm,io->orm', self._get_harmonics(j), weights)
            B += dim_j

        return out

    def __iter__(self) -> Iterable:
        # Yield one attribute dictionary per basis element, ordered by irrep `j`, then by multiplicity index `i`,
        # then by the index `m` inside the irrep.
        # NOTE(review): here the entries of `self._attributes` named "idx", "i" or "m" override the generated
        # values, whereas in `__getitem__` the generated values win - confirm which behaviour is intended.
        idx = 0
        for j, M in self.js:
            psi = self.group.irrep(*j)
            attr = {
                'irrep:' + k: v
                for k, v in psi.attributes.items()
            }
            attr.update(**self._attributes)
            attr["j"] = j
            attr["shape"] = (1, 1)

            for i in range(M):
                for m in range(psi.size):
                    _attr = {"idx": idx, "i": i, "m": m}
                    _attr.update(**attr)
                    yield _attr
                    idx += 1

    def __getitem__(self, idx) -> Dict:
        # Return the attribute dictionary of the basis element with index `idx`.
        assert idx < self.dim, (idx, self.dim)

        j, j_idx = self._get_j_from_idx(idx)
        psi = self.group.irrep(*j)

        attr = {
            'irrep:' + k: v
            for k, v in psi.attributes.items()
        }
        attr.update(**self._attributes)
        attr["idx"] = idx
        attr["j"] = j
        # `j_idx` enumerates the elements associated with `j`: split it in multiplicity index and irrep index
        attr["i"] = j_idx // psi.size
        attr["m"] = j_idx % psi.size
        attr["shape"] = (1, 1)

        return attr

    def steerable_attrs_j_iter(self, j: Tuple) -> Iterable:
        # Yield one attribute dictionary for each of the `self.multiplicity(j)` copies of the irrep `j`.
        attr = {
            'irrep:' + k: v
            for k, v in self.X.G.irrep(*j).attributes.items()
        }
        attr['sigma'] = self.sigma
        attr.update(**self._attributes)

        for idx in range(self.multiplicity(j)):
            _attr = {'i': idx}
            _attr.update(**attr)
            yield _attr

    def steerable_attrs_j(self, j: Tuple, idx) -> Dict:
        # Return the attribute dictionary of the `idx`-th copy of the irrep `j`.
        assert 0 <= idx < self.multiplicity(j), idx

        attr = {
            'irrep:' + k: v
            for k, v in self.X.G.irrep(*j).attributes.items()
        }
        attr['i'] = idx
        attr['sigma'] = self.sigma
        attr.update(**self._attributes)

        return attr

    def __eq__(self, other):
        # Two bases are equal if they are built from the same homogeneous space, action and harmonics and if their
        # root, sigma and change of basis agree up to numerical tolerance.
        if not isinstance(other, SparseOrbitBasis):
            return False
        elif self.X != other.X:
            return False
        elif self.action != other.action:
            return False
        elif not np.allclose(self.root, other.root, atol=1e-5):
            return False
        elif self.js != other.js:
            return False
        elif not np.isclose(self.sigma, other.sigma):
            return False
        elif (self.change_of_basis is None) != (other.change_of_basis is None):
            return False
        elif self.change_of_basis is not None:
            return torch.allclose(self.change_of_basis, other.change_of_basis)
        else:
            return True

    def __hash__(self):
        # Objects which compare equal must have the same hash.
        # `__eq__` compares `root`, `sigma` and `change_of_basis` only up to a numerical tolerance, so hashing their
        # exact values would let two equal bases end up with different hashes (breaking dict / set lookups).
        # Therefore, only the components which `__eq__` compares exactly contribute to the hash.
        return 100 * hash(self.X) + sum(hash(x) for x in self.js) + hash(self.action)
class SparseOrbitBasisWithIcosahedralSymmetry(SparseOrbitBasis):

    def __init__(self,
                 X: HomSpace,
                 sigma: float,
                 harmonics: List = None,
                 change_of_basis: np.ndarray = None,
                 attributes: Dict = None,
                 ):
        r"""

        Convenience subclass of :class:`~escnn.kernels.SparseOrbitBasis` for homogeneous spaces ``X`` of the
        Icosahedral group: the ``root`` generating an orbit isomorphic to ``X`` is computed automatically and the
        group's standard representation is used as ``action``.

        The remaining arguments (``sigma``, ``harmonics``, ``change_of_basis``, ``attributes``) are forwarded
        unchanged to :class:`~escnn.kernels.SparseOrbitBasis`.

        """

        ico = ico_group()
        assert X.G == ico, X.G

        # only the stabilizers identified by these subgroup ids are supported
        supported_sgids = [(False, 2), (False, 3), (False, 5)]
        assert X.sgid[:2] in supported_sgids, X.sgid

        has_third_entry = len(X.sgid) > 2
        if has_third_entry:
            assert len(X.sgid) == 3
            # up to sign, the value of the third entry must be [0, 0, 0, 1]
            expected_abs_value = np.array([0., 0., 0., 1.])
            assert np.allclose(np.fabs(X.sgid[2].value), expected_abs_value)

        # The rotation axis of one element of the stabilizer group gives the point in R^3 which represents the
        # coset of the identity element of G.
        # NOTE(review): this presumably relies on `X.H.elements[1]` not being the identity - confirm.
        stabilizer_element = X._inclusion(X.H.elements[1])
        rotation_axis = stabilizer_element.to('Q')[:3]

        # normalize the axis to obtain a unit-length root
        root = rotation_axis.reshape(-1) / np.linalg.norm(rotation_axis)

        super().__init__(
            X,
            ico.standard_representation,
            root,
            sigma,
            harmonics,
            change_of_basis,
            attributes,
        )