import numpy as np
from .steerable_basis import IrrepBasis
from .steerable_filters_basis import SteerableFiltersBasis
from escnn.group import *
import torch
from typing import Union, Tuple, Dict, Iterable, List
from collections import defaultdict
from itertools import chain
# Public API of this module: the two kernel-basis solvers defined below.
__all__ = [
"WignerEckartBasis",
"RestrictedWignerEckartBasis"
]
class WignerEckartBasis(IrrepBasis):

    def __init__(self,
                 basis: SteerableFiltersBasis,
                 in_irrep: Union[str, IrreducibleRepresentation, int],
                 out_irrep: Union[str, IrreducibleRepresentation, int],
                 ):
        r"""
        Solves the kernel constraint for a pair of input and output :math:`G`-irreps by using the Wigner-Eckart theorem
        described in Theorem 2.1 of
        `A Program to Build E(N)-Equivariant Steerable CNNs <https://openreview.net/forum?id=WE4qe9xlnQw>`_
        (see also
        `A Wigner-Eckart Theorem for Group Equivariant Convolution Kernels <https://arxiv.org/abs/2010.10952>`_
        ).
        The method relies on a :math:`G`-steerable basis of scalar functions over the base space.

        Args:
            basis (SteerableFiltersBasis): a :math:`G`-steerable basis for scalar functions over the base space
            in_irrep (str or IrreducibleRepresentation or int): the input irrep, or any identifier accepted by
                    ``basis.group.get_irrep_id``
            out_irrep (str or IrreducibleRepresentation or int): the output irrep, or any identifier accepted by
                    ``basis.group.get_irrep_id``

        """
        group = basis.group

        # normalize names / ids / irrep objects into IrreducibleRepresentation instances of `group`
        in_irrep = group.irrep(*group.get_irrep_id(in_irrep))
        out_irrep = group.irrep(*group.get_irrep_id(out_irrep))

        assert in_irrep.group == group
        assert out_irrep.group == group
        assert in_irrep.group == out_irrep.group

        self.m = in_irrep.id
        self.n = out_irrep.id

        # (j, jJl) pairs from the decomposition of the tensor product of the two irreps. jJl must match the `s` axis
        # of the coefficient tensor built below (asserted there). Irreps j that `basis` does not contain cannot
        # contribute any basis element, so they are dropped.
        _js = [
            (j, jJl)
            for j, jJl in group._tensor_product_irreps(self.m, self.n)
            if basis.multiplicity(j) > 0
        ]

        dim = 0
        self._dim_harmonics = {}
        self._jJl = {}
        for j, jJl in _js:
            # one basis element per (k, s, i) triple: k < sum_of_squares_constituents, s < jJl,
            # i < basis.multiplicity(j) -- the same enumeration used by `attrs_j_iter`
            self._dim_harmonics[j] = basis.multiplicity(j) * jJl * group.irrep(*j).sum_of_squares_constituents
            self._jJl[j] = jJl
            dim += self._dim_harmonics[j]

        super(WignerEckartBasis, self).__init__(basis, in_irrep, out_irrep, dim, harmonics=[_j for _j, _ in _js])

        # SteerableFiltersBasis: a `G`-steerable basis for scalar functions over the base space
        self.basis = basis

        # For every harmonic j, contract the Clebsch-Gordan coefficients with the endomorphism basis of j.
        # Result axes: (k, s, out, in, o) -- the layout `sample_harmonics` relies on, where `o` is contracted
        # against the last axis of the sampled harmonics.
        _coeff = [
            torch.einsum(
                'mnsi,koi->ksmno',
                torch.tensor(group._clebsh_gordan_coeff(self.n, self.m, j), dtype=torch.float32),
                torch.tensor(group.irrep(*j).endomorphism_basis(), dtype=torch.float32),
            ) for j in self.js
        ]
        for b, j in enumerate(self.js):
            coeff = _coeff[b]
            assert self._jJl[j] == coeff.shape[1]
            # stored as a module buffer, retrieved through `self.coeff(b)`
            self.register_buffer(f'coeff_{b}', coeff)

    def coeff(self, idx: int) -> torch.Tensor:
        r"""Return the coefficient buffer of the ``idx``-th harmonic in ``self.js``."""
        return getattr(self, f'coeff_{idx}')

    def sample_harmonics(self, points: Dict[Tuple, torch.Tensor], out: Dict[Tuple, torch.Tensor] = None) -> Dict[Tuple, torch.Tensor]:
        r"""
        Combine sampled harmonics of ``self.basis`` into samples of the elements of this kernel basis.

        Args:
            points (dict): maps each harmonic ``j`` to a tensor of shape ``(N, M_j, D)``; ``M_j`` must equal
                    ``self.basis.multiplicity(j)`` and ``D`` the last axis of ``self.coeff(b)`` for that ``j``.
            out (dict, optional): pre-allocated outputs, ``out[j]`` of shape
                    ``(N, self.dim_harmonic(j), self.shape[0], self.shape[1])``. Harmonics not present in ``out``
                    are skipped. If ``None``, zero tensors are allocated for every ``j`` in ``self.js``.

        Returns:
            the dictionary ``out``, filled in place

        """
        if out is None:
            out = {
                j: torch.zeros(
                    (points[j].shape[0], self.dim_harmonic(j), self.shape[0], self.shape[1]),
                    device=points[j].device, dtype=points[j].dtype
                )
                for j in self.js
            }

        for j in self.js:
            if j in out:
                assert out[j].shape == (points[j].shape[0], self.dim_harmonic(j), self.shape[0], self.shape[1]), (
                    out[j].shape, points[j].shape[0], self.dim_harmonic(j), self.shape[0], self.shape[1]
                )

        for b, j in enumerate(self.js):
            if j not in out:
                continue
            coeff = self.coeff(b)
            jJl = coeff.shape[1]
            Ys = points[j]
            # write through a view so the flat harmonic axis of out[j] is ordered (k, s, i)
            out[j].view((
                Ys.shape[0],
                self.group.irrep(*j).sum_of_squares_constituents, jJl,
                Ys.shape[1],
                self.out_irrep.size, self.in_irrep.size,
            ))[:] = torch.einsum(
                'kspnm,qim->qksipn',
                coeff, Ys,
            )

        return out

    def dim_harmonic(self, j: Tuple) -> int:
        r"""Number of basis elements associated with harmonic ``j`` (``0`` if ``j`` is not used)."""
        return self._dim_harmonics.get(j, 0)

    def attrs_j_iter(self, j: Tuple) -> Iterable:
        r"""
        Yield the attribute dictionaries of the basis elements associated with harmonic ``j``.

        Elements are enumerated in ``(k, s, i)`` order, with ``i`` (the index of the steerable basis element of
        ``self.basis``) varying fastest; this matches the layout written by :meth:`sample_harmonics`.
        Nothing is yielded if ``j`` is not used by this basis.
        """
        if self.dim_harmonic(j) == 0:
            return

        idx = self._start_index[j]

        j_attr = {
            'irrep:'+k: v
            for k, v in self.group.irrep(*j).attributes.items()
        }
        steerable_basis_j_attr = list(self.basis.steerable_attrs_j_iter(j))

        for k in range(self.group.irrep(*j).sum_of_squares_constituents):
            for s in range(self._jJl[j]):
                for i, attr_i in enumerate(steerable_basis_j_attr):
                    attr = j_attr.copy()
                    attr.update(**attr_i)
                    attr["idx"] = idx
                    attr["j"] = j
                    attr["i"] = i
                    attr["s"] = s
                    attr["k"] = k
                    idx += 1
                    yield attr

    def attrs_j(self, j: Tuple, idx) -> Dict:
        r"""
        Return the attributes of the ``idx``-th basis element associated with harmonic ``j``.

        ``idx`` is local to ``j`` and is decoded with the same ``(k, s, i)`` ordering used by :meth:`attrs_j_iter`.
        """
        assert 0 <= idx < self.dim_harmonic(j)

        full_idx = self._start_index[j] + idx
        dim = self.basis.multiplicity(j)

        attr = {
            'irrep:'+k: v
            for k, v in self.group.irrep(*j).attributes.items()
        }

        i = idx % dim
        attr_i = self.basis.steerable_attrs_j(j, i)
        attr.update(**attr_i)
        attr["idx"] = full_idx
        attr["j"] = j
        attr["i"] = i
        attr["s"] = (idx // dim) % self._jJl[j]
        attr["k"] = idx // (dim * self._jJl[j])
        return attr

    def __getitem__(self, idx):
        r"""Return the attributes of the ``idx``-th basis element (global index over all harmonics)."""
        assert 0 <= idx < self.dim
        i = idx
        # locate the harmonic containing the global index, converting it into a local one
        for j in self.js:
            dim = self.dim_harmonic(j)
            if i < dim:
                break
            else:
                i -= dim
        return self.attrs_j(j, i)

    def __iter__(self):
        r"""Iterate over the attributes of all basis elements, harmonic by harmonic in ``self.js`` order."""
        for j in self.js:
            yield from self.attrs_j_iter(j)

    def __eq__(self, other):
        r"""Bases are equal if basis, irreps, harmonics and (numerically close) coefficients all match."""
        if not isinstance(other, WignerEckartBasis):
            return False
        elif self.basis != other.basis or self.in_irrep != other.in_irrep or self.out_irrep != other.out_irrep:
            # TODO check isomorphism too!
            return False
        elif len(self.js) != len(other.js):
            return False
        else:
            for b, (j, i) in enumerate(zip(self.js, other.js)):
                if j != i or not torch.allclose(self.coeff(b), other.coeff(b)):
                    return False
            return True

    def __hash__(self):
        # built only from fields compared in __eq__, so equal bases hash equally
        return hash(self.basis) + hash(self.in_irrep) + hash(self.out_irrep) + hash(tuple(self.js))

    # class-level cache used by `_generator`, keyed by (basis, input irrep id, output irrep id)
    _cached_instances = {}

    @classmethod
    def _generator(cls,
                   basis: SteerableFiltersBasis,
                   psi_in: Union[IrreducibleRepresentation, Tuple],
                   psi_out: Union[IrreducibleRepresentation, Tuple],
                   **kwargs) -> 'IrrepBasis':
        r"""
        Return the cached :class:`WignerEckartBasis` for ``(basis, psi_in, psi_out)``, building it on first request.

        No extra keyword arguments are accepted.
        """
        assert len(kwargs) == 0
        psi_in = basis.group.irrep(*basis.group.get_irrep_id(psi_in))
        psi_out = basis.group.irrep(*basis.group.get_irrep_id(psi_out))

        key = (basis, psi_in.id, psi_out.id)
        if key not in cls._cached_instances:
            cls._cached_instances[key] = WignerEckartBasis(basis, in_irrep=psi_in, out_irrep=psi_out)

        return cls._cached_instances[key]
class RestrictedWignerEckartBasis(IrrepBasis):

    def __init__(self,
                 basis: SteerableFiltersBasis,
                 sg_id: Tuple,
                 in_irrep: Union[str, IrreducibleRepresentation, int],
                 out_irrep: Union[str, IrreducibleRepresentation, int],
                 ):
        r"""
        Solves the kernel constraint for a pair of input and output :math:`G`-irreps by using the Wigner-Eckart theorem
        described in Theorem 2.1 of
        `A Program to Build E(N)-Equivariant Steerable CNNs <https://openreview.net/forum?id=WE4qe9xlnQw>`_.
        This method implicitly constructs the required :math:`G`-steerable basis for scalar functions on the base space
        from a :math:`G'`-steerable basis, with :math:`G' > G` a larger group, according to Equation 5 from the same
        paper.

        The equivariance group :math:`G < G'` is identified by the input id ``sg_id``.

        .. warning::
            Note that the group :math:`G'` associated with ``basis`` is generally not the same as the group :math:`G`
            associated with ``in_irrep`` and ``out_irrep`` and which the resulting kernel basis is equivariant to.

        Args:
            basis (SteerableFiltersBasis): :math:`G'`-steerable basis for scalar filters
            sg_id (tuple): id of :math:`G` as a subgroup of :math:`G'`.
            in_irrep (str or IrreducibleRepresentation or int): the input :math:`G`-irrep, or any identifier accepted
                    by ``G.get_irrep_id``
            out_irrep (str or IrreducibleRepresentation or int): the output :math:`G`-irrep, or any identifier
                    accepted by ``G.get_irrep_id``

        """
        # the larger group G'
        _G = basis.group
        G = _G.subgroup(sg_id)[0]
        # Group: the smaller equivariance group G
        self.group = G
        self.sg_id = sg_id

        in_irrep = G.irrep(*G.get_irrep_id(in_irrep))
        out_irrep = G.irrep(*G.get_irrep_id(out_irrep))

        assert in_irrep.group == G
        assert out_irrep.group == G
        assert in_irrep.group == out_irrep.group

        self.m = in_irrep.id
        self.n = out_irrep.id

        # irreps of G in the decomposition of the tensor product of in_irrep and out_irrep
        _js_G = [
            j for j, _ in
            G._tensor_product_irreps(self.m, self.n)
        ]

        # G'-harmonics j' whose restriction to G contains at least one of the irreps in `_js_G`
        _js = set()
        # j' -> list of (G-irrep j, id_coeff); id_coeff stacks along its last axis the rows of the restriction's
        # `change_of_basis_inv` that correspond to each copy of j
        _js_restriction = defaultdict(list)

        # for each harmonic j' to consider (`basis.js` is iterated as pairs whose first entry is a G'-irrep id)
        for _j in set(_j for _j, _ in basis.js):
            if basis.multiplicity(_j) == 0:
                continue

            # restrict the corresponding G' irrep j' to G
            _j_G = _G.irrep(*_j).restrict(sg_id)

            # for each G-irrep j in the tensor product decomposition of in_irrep and out_irrep
            for j in _js_G:
                # irrep-decomposition coefficients of j in j'
                id_coeff = []
                # row offset of the current irrep inside the restricted representation
                p = 0
                # for each G-irrep i in the restriction of j' to G
                for i in _j_G.irreps:
                    size = G.irrep(*i).size
                    # if the restricted irrep contains one of the irreps in the tensor product
                    if i == j:
                        id_coeff.append(
                            _j_G.change_of_basis_inv[p:p+size, :]
                        )
                    p += size

                # if the G irrep j appears in the restriction of the G'-irrep j',
                # store its irrep-decomposition coefficients
                if len(id_coeff) > 0:
                    id_coeff = np.stack(id_coeff, axis=-1)
                    _js.add(_j)
                    _js_restriction[_j].append((j, id_coeff))

        # sorted so that the harmonic ordering does not depend on set iteration order
        _js = sorted(list(_js))

        # j' -> list of (G-irrep j, number of copies of j in the restriction of j')
        self._js_restriction = {}
        # j' -> number of coefficient rows; `dim_harmonic` multiplies it by basis.multiplicity(j')
        self._dim_harmonics = {}
        _coeffs = {}
        dim = 0
        for _j in _js:
            Y_size = _G.irrep(*_j).size
            # one block per G-irrep j, axes (k, s, t, out, in, y) flattened to (k*s*t, out, in, y) in (k, s, t) order
            coeff = [
                torch.einsum(
                    'nmsi,kji,jyt->kstnmy',
                    torch.tensor(G._clebsh_gordan_coeff(self.n, self.m, j), dtype=torch.float32),
                    torch.tensor(G.irrep(*j).endomorphism_basis(), dtype=torch.float32),
                    torch.tensor(id_coeff, dtype=torch.float32),
                ).reshape((-1, out_irrep.size, in_irrep.size, Y_size))
                for j, id_coeff in _js_restriction[_j]
            ]
            _coeffs[_j] = torch.cat(coeff, dim=0)
            self._js_restriction[_j] = [(j, id_coeff.shape[2]) for j, id_coeff in _js_restriction[_j]]
            self._dim_harmonics[_j] = _coeffs[_j].shape[0]
            dim += self._dim_harmonics[_j] * basis.multiplicity(_j)

        super(RestrictedWignerEckartBasis, self).__init__(basis, in_irrep, out_irrep, dim, harmonics=_js)

        # SteerableFiltersBasis: a `G'`-steerable basis for scalar functions over the base space, for the larger
        # group `G' > G`
        self.basis = basis

        for b, _j in enumerate(self.js):
            # stored as a module buffer, retrieved through `self.coeff(b)`
            self.register_buffer(f'coeff_{b}', _coeffs[_j])

    def coeff(self, idx: int) -> torch.Tensor:
        r"""Return the coefficient buffer of the ``idx``-th harmonic in ``self.js``."""
        return getattr(self, f'coeff_{idx}')

    def sample_harmonics(self, points: Dict[Tuple, torch.Tensor], out: Dict[Tuple, torch.Tensor] = None) -> Dict[Tuple, torch.Tensor]:
        r"""
        Combine sampled :math:`G'`-harmonics of ``self.basis`` into samples of the elements of this kernel basis.

        Args:
            points (dict): maps each harmonic ``j`` to a tensor of shape ``(N, M_j, Y)``; ``M_j`` must equal
                    ``self.basis.multiplicity(j)`` and ``Y`` the last axis of ``self.coeff(b)`` for that ``j``.
            out (dict, optional): pre-allocated outputs, ``out[j]`` of shape
                    ``(N, self.dim_harmonic(j), self.shape[0], self.shape[1])``. Harmonics not present in ``out``
                    are skipped. If ``None``, zero tensors are allocated for every ``j`` in ``self.js``.

        Returns:
            the dictionary ``out``, filled in place

        """
        if out is None:
            out = {
                j: torch.zeros(
                    (points[j].shape[0], self.dim_harmonic(j), self.shape[0], self.shape[1]),
                    device=points[j].device, dtype=points[j].dtype
                )
                for j in self.js
            }

        for j in self.js:
            if j in out:
                assert out[j].shape == (points[j].shape[0], self.dim_harmonic(j), self.shape[0], self.shape[1]), (
                    out[j].shape, points[j].shape[0], self.dim_harmonic(j), self.shape[0], self.shape[1]
                )

        for b, j in enumerate(self.js):
            if j not in out:
                continue
            coeff = self.coeff(b)
            Ys = points[j]
            # write through a view so the flat harmonic axis of out[j] is ordered (coefficient row, i)
            out[j].view((
                Ys.shape[0], coeff.shape[0], Ys.shape[1],
                self.out_irrep.size, self.in_irrep.size,
            ))[:] = torch.einsum(
                'dpnm,sim->sdipn',
                coeff, Ys,
            )

        return out

    def dim_harmonic(self, j: Tuple) -> int:
        r"""Number of basis elements associated with the :math:`G'`-harmonic ``j`` (``0`` if ``j`` is not used)."""
        if j in self._dim_harmonics:
            return self.basis.multiplicity(j) * self._dim_harmonics[j]
        else:
            return 0

    def attrs_j_iter(self, j: Tuple) -> Iterable:
        r"""
        Yield the attribute dictionaries of the basis elements associated with the :math:`G'`-harmonic ``j``.

        Elements are enumerated by :math:`G`-irrep ``_j`` (in the order of ``self._js_restriction[j]``), then
        ``k``, ``s``, ``t`` and finally ``i``, the index of the steerable basis element of ``self.basis``.
        This matches the layout written by :meth:`sample_harmonics`. Nothing is yielded if ``j`` is not used.
        """
        if self.dim_harmonic(j) == 0:
            return

        idx = self._start_index[j]

        steerable_basis_j_attr = list(self.basis.steerable_attrs_j_iter(j))

        # attributes of the G'-irrep j
        j_attr = {
            'irrep:'+k: v
            for k, v in self.basis.group.irrep(*j).attributes.items()
        }

        count = 0
        # here `_j` is a G-irrep in the restriction of j and `_jj` its number of copies there
        for _j, _jj in self._js_restriction[j]:
            _jJl = self.group._clebsh_gordan_coeff(self.n, self.m, _j).shape[2]
            K = self.group.irrep(*_j).sum_of_squares_constituents
            for k in range(K):
                for s in range(_jJl):
                    for t in range(_jj):
                        for i, attr_i in enumerate(steerable_basis_j_attr):
                            attr = j_attr.copy()
                            attr.update(**attr_i)
                            attr["idx"] = idx
                            attr["j"] = j
                            attr["_j"] = _j
                            attr["i"] = i
                            attr["t"] = t
                            attr["s"] = s
                            attr["k"] = k
                            assert idx < self.dim
                            assert count < self.dim_harmonic(j), (count, self.dim_harmonic(j))
                            idx += 1
                            count += 1
                            yield attr
        assert count == self.dim_harmonic(j), (count, self.dim_harmonic(j))

    def attrs_j(self, j: Tuple, idx) -> Dict:
        r"""
        Return the attributes of the ``idx``-th basis element associated with the :math:`G'`-harmonic ``j``.

        ``idx`` is local to ``j`` and is decoded with the same ``(_j, k, s, t, i)`` ordering used by
        :meth:`attrs_j_iter`.
        """
        assert 0 <= idx < self.dim_harmonic(j)

        full_idx = self._start_index[j] + idx
        dim = self.basis.multiplicity(j)

        # skip the G-irreps `_j` whose elements come before `idx`
        for _j, _jj in self._js_restriction[j]:
            _jJl = self.group._clebsh_gordan_coeff(self.n, self.m, _j).shape[2]
            K = self.group.irrep(*_j).sum_of_squares_constituents
            d = _jj * _jJl * K * dim
            if idx >= d:
                idx -= d
            else:
                break

        i = idx % dim
        attr_i = self.basis.steerable_attrs_j(j, i)

        attr = {
            'irrep:'+k: v
            for k, v in self.basis.group.irrep(*j).attributes.items()
        }
        attr.update(**attr_i)
        attr["idx"] = full_idx
        attr["j"] = j
        attr["_j"] = _j
        attr["i"] = i
        attr["t"] = (idx // dim) % _jj
        attr["s"] = (idx // (dim * _jj)) % _jJl
        attr["k"] = idx // (dim * _jj * _jJl)
        return attr

    def __getitem__(self, idx):
        r"""Return the attributes of the ``idx``-th basis element (global index over all harmonics)."""
        assert 0 <= idx < self.dim
        i = idx
        # locate the harmonic containing the global index, converting it into a local one
        for j in self.js:
            dim = self.dim_harmonic(j)
            if i < dim:
                break
            else:
                i -= dim
        return self.attrs_j(j, i)

    def __iter__(self) -> Iterable:
        r"""Iterate over the attributes of all basis elements, harmonic by harmonic in ``self.js`` order."""
        for j in self.js:
            yield from self.attrs_j_iter(j)

    def __eq__(self, other):
        r"""Bases are equal if basis, subgroup, irreps, harmonics and (numerically close) coefficients all match."""
        if not isinstance(other, RestrictedWignerEckartBasis):
            return False
        elif self.basis != other.basis or self.sg_id != other.sg_id or self.in_irrep != other.in_irrep or self.out_irrep != other.out_irrep:
            # TODO check isomorphism too!
            return False
        elif len(self.js) != len(other.js):
            return False
        else:
            for b, (j, i) in enumerate(zip(self.js, other.js)):
                if j != i or not torch.allclose(self.coeff(b), other.coeff(b)):
                    return False
            return True

    def __hash__(self):
        # built only from fields compared in __eq__, so equal bases hash equally
        return hash(self.basis) + hash(self.sg_id) + hash(self.in_irrep) + hash(self.out_irrep) + hash(tuple(self.js))

    # class-level cache used by `_generator`, keyed by (basis, input irrep id, output irrep id, sg_id)
    _cached_instances = {}

    @classmethod
    def _generator(cls,
                   basis: SteerableFiltersBasis,
                   psi_in: Union[IrreducibleRepresentation, Tuple],
                   psi_out: Union[IrreducibleRepresentation, Tuple],
                   **kwargs) -> 'IrrepBasis':
        r"""
        Return the cached :class:`RestrictedWignerEckartBasis` for ``(basis, psi_in, psi_out, sg_id)``, building it
        on first request.

        Exactly one keyword argument, ``sg_id``, is required; ``psi_in`` and ``psi_out`` are interpreted as irreps
        of the subgroup :math:`G` it identifies.
        """
        assert len(kwargs) == 1
        assert 'sg_id' in kwargs

        G, _, _ = basis.group.subgroup(kwargs['sg_id'])
        psi_in = G.irrep(*G.get_irrep_id(psi_in))
        psi_out = G.irrep(*G.get_irrep_id(psi_out))

        key = (
            basis, psi_in.id, psi_out.id,
            kwargs['sg_id']
        )

        if key not in cls._cached_instances:
            cls._cached_instances[key] = RestrictedWignerEckartBasis(
                basis,
                sg_id=kwargs['sg_id'],
                in_irrep=psi_in,
                out_irrep=psi_out,
            )

        return cls._cached_instances[key]