import numpy as np
import torch
from .steerable_filters_basis import SteerableFiltersBasis
from escnn.group import *
from typing import Type, Union, Tuple, Dict, List, Iterable, Callable, Set
class SparseOrbitBasis(SteerableFiltersBasis):

    def __init__(self,
                 X: HomSpace,
                 action: Representation,
                 root: np.ndarray,
                 sigma: float,
                 harmonics: List = None,
                 change_of_basis: np.ndarray = None,
                 attributes: Dict = None):
        r"""
        Class which implements a steerable basis for scalar functions on a Euclidean space :math:`\R^n` by considering
        a single *orbit* of the symmetry group :math:`G` in :math:`\R^n`.
        Indeed, an orbit :math:`\{g.x | g \in G\}`, with :math:`x \in \R^n`, is isomorphic to
        a homogeneous space ``X`` of :math:`G`.
        A steerable basis for scalar functions over a homogeneous space is given by the harmonic basis in
        :class:`~escnn.group.HomSpace`.

        In this class, a steerable basis over :math:`\R^n` is defined by embedding the finite number of points of
        a *finite* homogeneous space ``X`` into :math:`\R^n` and then "diffusing" the harmonic basis defined over
        these points in the ambient space using a Gaussian kernel.

        This class only supports finite homogeneous spaces generated by a finite symmetry group, i.e. not only
        must :math:`X=G/H` have a finite number of elements, but :math:`G` itself must as well.

        The embedding of ``X`` into :math:`\R^n` is defined via the point ``root``, which defines the embedding of the
        coset :math:`eH` associated with the identity, and the ``action`` of the group :math:`G` on :math:`\R^n`.
        This implies that the action of any element of :math:`H` should keep ``root`` fixed.

        Args:
            X (HomSpace): the homogeneous space isomorphic to the orbit in the Euclidean space
            action (Representation): the action of the group ``G`` on the Euclidean space
            root (~np.ndarray): the embedding of the identity coset; a vector of shape ``(n,)``
            sigma (float): the standard deviation of the Gaussian kernel used to diffuse the harmonic basis in the
                ambient space. Must be strictly positive.
            harmonics (list, optional): optionally, select only a subset of the harmonic functions. The list should
                contain the ids of group's irreps which should be used to construct the
                harmonic basis. See also :meth:`~escnn.group.HomSpace.scalar_basis`.
                Irreps which do not appear in the harmonic basis of ``X`` are dropped.
            change_of_basis (~np.ndarray, optional): optionally, apply a transformation on the embedded points to
                construct a different basis. Equivalently, one can apply this
                change of basis to both ``root`` and ``action``.
                It must be an invertible ``(n, n)`` matrix.
            attributes (dict, optional): additional attributes to describe the generated basis.

        Attributes:
            ~.X (HomSpace): the homogeneous space
            ~.root (~np.ndarray): the ``root`` passed at construction
            ~.action (Representation): the ``action`` passed at construction
            ~.sigma (float): the standard deviation of the Gaussian kernel
            ~.points (~torch.Tensor): buffer of shape ``(|G/H|, n)`` with the embedded orbit points
                (after ``change_of_basis``, if any)
            ~.change_of_basis (~torch.Tensor): the change of basis as a ``float32`` tensor, or ``None``

        """

        self.X: HomSpace = X

        # check that the homogeneous space has a finite number of elements
        # unfortunately, we can not directly check the size of the homogeneous space G/H so we check that G itself has
        # only a finite number of elements
        assert self.X.G.order() > 0

        if harmonics is None:
            harmonics = [psi.id for psi in self.X.G.irreps()]

        # keep only the irreps appearing in the harmonic basis over X, together with their multiplicity
        trivial_H = self.X.H.trivial_representation.id
        js = []
        for j in harmonics:
            multiplicity = self.X.dimension_basis(j, trivial_H)[1]
            if multiplicity > 0:
                js.append((j, multiplicity))

        super(SparseOrbitBasis, self).__init__(self.X.G, action, js)

        assert root.shape == (self.dimensionality,), (root.shape, self.dimensionality)
        self.root = root
        self.action = action

        root = root.reshape(-1, 1)

        # the stabilizer H must leave the embedding of the identity coset unchanged
        for h in self.X.H.elements:
            assert np.allclose(root, action(self.X._inclusion(h)) @ root, atol=1e-5, rtol=1e-3)

        # image of `root` under every group element: one column per element of G
        points = np.concatenate([
            action(g) @ root for g in self.X.G.elements
        ], axis=1)

        # keep one column per distinct point; `idx[k]` is the index (in G.elements) of the first group element
        # mapping `root` to the k-th distinct point. Rounding makes numerically-equal points collapse.
        _, idx = np.unique(
            np.round(points, decimals=4),
            axis=1, return_index=True
        )
        points = points[:, idx].T

        # by Lagrange's theorem |H| divides |G|, so the orbit must contain exactly |G| // |H| distinct points
        assert points.shape == (self.X.G.order() // self.X.H.order(), self.dimensionality), points.shape

        assert sigma > 0., sigma

        for j, m in self.js:
            # harmonic functions of irrep j evaluated at the coset of each orbit point;
            # shape (number of orbit points, multiplicity, irrep size)
            _harmonics_j = torch.tensor(np.stack([
                self.X.scalar_basis(self.X.G.elements[i], j)[..., 0].T
                for i in idx
            ], axis=0), dtype=torch.float32)
            assert _harmonics_j.shape == (idx.shape[0], m, self.X.G.irrep(*j).size), (_harmonics_j.shape, m, j)
            self.register_buffer(f'harmonics_{j}', _harmonics_j)

        self.sigma = sigma

        points = torch.tensor(points, dtype=torch.float32)

        if change_of_basis is not None:
            assert change_of_basis.shape == (self.dimensionality, self.dimensionality)
            # check that the matrix is invertible
            # NOTE(review): a numerically singular matrix may still have a large but finite condition number
            assert np.isfinite(np.linalg.cond(change_of_basis))
            change_of_basis = torch.tensor(change_of_basis, dtype=points.dtype, device=points.device)
            points = points @ change_of_basis.T

        self.register_buffer('points', points)
        self.change_of_basis = change_of_basis

        self._attributes = attributes if attributes is not None else dict()

    def _get_harmonics(self, j):
        # return the buffer of harmonics of irrep `j` registered in `__init__`
        return getattr(self, f'harmonics_{j}')

    def sample(self, points: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
        r"""

        Sample the continuous basis elements on the discrete set of points in ``points``.
        Optionally, store the resulting multidimensional array in ``out``.

        ``points`` must be an array of shape `(N, d)`, where `N` is the number of points and
        `d` is the :meth:`~escnn.kernels.SteerableFiltersBasis.dimensionality` of the Euclidean space.

        Each basis element is evaluated at a point :math:`x` as the sum, over the orbit points :math:`p`, of the
        corresponding harmonic function at :math:`p` weighted by :math:`\exp(-\|p - x\|^2 / (2\sigma^2))`.

        Args:
            points (~torch.Tensor): points where to evaluate the basis elements
            out (~torch.Tensor, optional): pre-existing array to use to store the output

        Returns:
            the sampled basis, a tensor of shape `(N, dim, 1, 1)`

        """
        assert len(points.shape) == 2
        S = points.shape[0]
        assert points.shape[1] == self.dimensionality

        if out is None:
            out = torch.empty((S, self.dim, 1, 1), dtype=self.points.dtype, device=self.points.device)

        assert out.shape == (S, self.dim, 1, 1)

        assert self.points.device == points.device, (self.points.device, points.device)
        assert out.device == points.device, (out.device, points.device)

        # Gaussian weight of every orbit point (rows) at every query point (columns)
        weights = self.points.unsqueeze(1) - points.unsqueeze(0)
        assert weights.shape == (self.points.shape[0], S, self.dimensionality)
        weights = (weights ** 2).sum(axis=2) / self.sigma ** 2
        weights = torch.exp(- 0.5 * weights)

        B = 0
        for j, m in self.js:
            dim_j = self.dim_harmonic(j)
            # einsum indices: i = orbit point, r = multiplicity, m = irrep component, o = query point
            out[:, B:B + dim_j, ...].view(
                S, m, -1
            )[:] = torch.einsum('irm,io->orm', self._get_harmonics(j), weights)
            B += dim_j

        return out

    def __iter__(self) -> Iterable:
        # Yield one attribute dict per basis element, ordered by irrep, multiplicity index `i` and
        # irrep component `m`.
        # NOTE(review): here user `attributes` override the computed 'idx', 'i' and 'm' keys, while
        # `__getitem__` lets the computed values override them; kept as-is for compatibility.
        idx = 0
        for j, M in self.js:
            psi = self.group.irrep(*j)

            attr = {
                'irrep:' + k: v
                for k, v in psi.attributes.items()
            }
            attr.update(**self._attributes)
            attr["j"] = j
            attr["shape"] = (1, 1)

            for i in range(M):
                for m in range(psi.size):
                    _attr = {"idx": idx, "i": i, "m": m}
                    _attr.update(**attr)
                    yield _attr
                    idx += 1

    def __getitem__(self, idx) -> Dict:
        # Return the attribute dict of the `idx`-th basis element.
        assert idx < self.dim, (idx, self.dim)

        j, j_idx = self._get_j_from_idx(idx)
        psi = self.group.irrep(*j)

        attr = {
            'irrep:' + k: v
            for k, v in psi.attributes.items()
        }
        attr.update(**self._attributes)
        attr["idx"] = idx
        attr["j"] = j
        attr["i"] = j_idx // psi.size
        attr["m"] = j_idx % psi.size
        attr["shape"] = (1, 1)

        return attr

    def steerable_attrs_j_iter(self, j: Tuple) -> Iterable:
        # Yield the attributes of each of the `multiplicity(j)` copies of irrep `j` in the basis.
        attr = {
            'irrep:' + k: v
            for k, v in self.X.G.irrep(*j).attributes.items()
        }
        attr['sigma'] = self.sigma
        attr.update(**self._attributes)

        for idx in range(self.multiplicity(j)):
            _attr = {'i': idx}
            _attr.update(**attr)
            yield _attr

    def steerable_attrs_j(self, j: Tuple, idx) -> Dict:
        # Return the attributes of the `idx`-th copy of irrep `j` in the basis.
        assert 0 <= idx < self.multiplicity(j), idx
        attr = {
            'irrep:' + k: v
            for k, v in self.X.G.irrep(*j).attributes.items()
        }
        attr['i'] = idx
        attr['sigma'] = self.sigma
        attr.update(**self._attributes)
        return attr

    def __eq__(self, other):
        # `root`, `sigma` and `change_of_basis` are compared only up to a numerical tolerance
        if not isinstance(other, SparseOrbitBasis):
            return False
        elif self.X != other.X:
            return False
        elif self.action != other.action:
            return False
        elif not np.allclose(self.root, other.root, atol=1e-5):
            return False
        elif self.js != other.js:
            return False
        elif not np.isclose(self.sigma, other.sigma):
            return False
        elif (self.change_of_basis is None) != (other.change_of_basis is None):
            return False
        elif self.change_of_basis is not None:
            return torch.allclose(self.change_of_basis, other.change_of_basis)
        else:
            return True

    def __hash__(self):
        # `__eq__` compares `root`, `sigma` and `change_of_basis` only approximately, so hashing their exact
        # values would let two equal bases hash differently (breaking dict / set lookups).
        # Only the fields compared exactly by `__eq__` contribute to the hash.
        return hash((self.X, self.action, tuple(self.js)))
class SparseOrbitBasisWithIcosahedralSymmetry(SparseOrbitBasis):

    def __init__(self,
                 X: HomSpace,
                 sigma: float,
                 harmonics: List = None,
                 change_of_basis: np.ndarray = None,
                 attributes: Dict = None):
        r"""

        Given a homogeneous space ``X`` of the Icosahedral group, builds the :class:`~escnn.kernels.SparseOrbitBasis`
        with a ``root`` which generates an orbit isomorphic to ``X``.

        Only quotients by the subgroups whose ``sgid`` starts with ``(False, 2)``, ``(False, 3)`` or ``(False, 5)``
        are supported. The orbit is embedded in :math:`\R^3` through the group's standard representation.

        Args:
            X (HomSpace): a homogeneous space of the Icosahedral group
            sigma (float): the standard deviation of the Gaussian kernel (see :class:`~escnn.kernels.SparseOrbitBasis`)
            harmonics (list, optional): see :class:`~escnn.kernels.SparseOrbitBasis`
            change_of_basis (~np.ndarray, optional): see :class:`~escnn.kernels.SparseOrbitBasis`
            attributes (dict, optional): see :class:`~escnn.kernels.SparseOrbitBasis`

        """
        ico = ico_group()
        assert X.G == ico, X.G

        sgid = X.sgid
        assert sgid[:2] in [(False, 2), (False, 3), (False, 5)], sgid

        if len(sgid) > 2:
            assert len(sgid) == 3
            # the optional third entry must equal [0, 0, 0, 1] up to signs
            # (presumably the identity quaternion, i.e. no extra rotation of H -- TODO confirm)
            assert np.allclose(np.fabs(sgid[2].value), np.array([0., 0., 0., 1.]))

        # The vector part of the quaternion of a stabilizer element lies along its rotation axis; that direction is
        # used (normalized) as the point of R^3 representing the coset of the identity element of G.
        # NOTE(review): assumes `X.H.elements[1]` is not the identity, otherwise the axis would be zero.
        h1 = X._inclusion(X.H.elements[1])
        axis = h1.to('Q')[:3]
        root = axis.reshape(-1) / np.linalg.norm(axis)

        super().__init__(
            X,
            ico.standard_representation,
            root,
            sigma,
            harmonics=harmonics,
            change_of_basis=change_of_basis,
            attributes=attributes,
        )