Source code for escnn.kernels.polar_basis

import numpy as np
import torch

from escnn.group import *

from abc import ABC, abstractmethod
from typing import List, Union, Tuple, Callable, Dict, Iterable

from .basis import KernelBasis, EmptyBasisException
from .steerable_filters_basis import SteerableFiltersBasis

from .harmonic_polynomial_r3 import HarmonicPolynomialR3Generator


__all__ = [
    'GaussianRadialProfile',
    'SphericalShellsBasis'
]


class GaussianRadialProfile(KernelBasis):

    def __init__(self, radii: List[float], sigma: Union[List[float], float]):
        r"""

        Basis for kernels defined over a radius in :math:`\R^+_0`.

        Each basis element is defined as a Gaussian function.
        Different basis elements are centered at different radii (``rings``) and can possibly be associated with
        different widths (``sigma``).

        More precisely, the following basis is implemented:

        .. math::

            \mathcal{B} = \left\{ b_i (r) := \exp \left( - \frac{ \left( r - r_i \right)^2}{2 \sigma_i^2} \right) \right\}_i

        In order to build a complete basis of kernels, you should combine this basis with a basis which defines the
        angular profile, see for example :class:`~escnn.kernels.SphericalShellsBasis` or
        :class:`~escnn.kernels.CircularShellsBasis`.

        Args:
            radii (list): centers of each basis element. They should be different and spread to cover all
                domain of interest
            sigma (list or float): widths of each element. Can potentially be different.
                A single scalar is shared by all the elements.

        """

        # a single scalar width (python int or float) is replicated for every ring
        if isinstance(sigma, (int, float)):
            sigma = [sigma] * len(radii)

        assert len(radii) == len(sigma)
        assert isinstance(radii, list)

        for r in radii:
            assert r >= 0.

        for s in sigma:
            assert s > 0.

        super(GaussianRadialProfile, self).__init__(len(radii), (1, 1))

        # both buffers are stored with shape (1, len(radii), 1, 1) so that they broadcast against the
        # (N, 1, 1, 1) radii built in :meth:`sample`
        self.register_buffer('radii', torch.tensor(radii, dtype=torch.float32).reshape(1, -1, 1, 1))
        self.register_buffer('sigma', torch.tensor(sigma, dtype=torch.float32).reshape(1, -1, 1, 1))

    def sample(self, radii: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
        r"""

        Sample the continuous basis elements on the discrete set of radii in ``radii``.
        Optionally, store the resulting multidimensional array in ``out``.

        ``radii`` must be an array of shape `(N, 1)`, where `N` is the number of points.

        Args:
            radii (~torch.Tensor): radii where to evaluate the basis elements
            out (~torch.Tensor, optional): pre-existing array to use to store the output

        Returns:
            the sampled basis, with shape `(N, self.dim, 1, 1)`

        """
        assert len(radii.shape) == 2
        assert radii.shape[1] == 1
        S = radii.shape[0]

        if out is None:
            out = torch.empty((S, self.dim, self.shape[0], self.shape[1]), device=radii.device, dtype=radii.dtype)

        assert out.shape == (S, self.dim, self.shape[0], self.shape[1])

        radii = radii.reshape(-1, 1, 1, 1)

        assert not torch.isnan(radii).any()

        # squared distance of every sampled radius from every ring center
        d = (self.radii - radii) ** 2

        if radii.requires_grad:
            # the ``out=`` argument is not used when gradients need to flow through ``radii``:
            # the result is copied in ``out`` instead
            out[:] = torch.exp(-0.5 * d / self.sigma ** 2)
        else:
            out = torch.exp(-0.5 * d / self.sigma ** 2, out=out)

        return out

    def __getitem__(self, r):
        # attributes describing the ``r``-th basis element (i.e. the ``r``-th ring)
        assert r < self.dim
        return {"radius": self.radii[0, r, 0, 0].item(), "sigma": self.sigma[0, r, 0, 0].item(), "idx": r}

    def __eq__(self, other):
        if isinstance(other, GaussianRadialProfile):
            return (
                # the buffers have shape (1, n, 1, 1): without this check, ``torch.allclose`` would broadcast
                # a profile with a single ring against one with many rings
                self.radii.shape == other.radii.shape
                and torch.allclose(self.radii, other.radii.to(self.radii.device))
                and torch.allclose(self.sigma, other.sigma.to(self.sigma.device))
            )
        else:
            return False

    def __hash__(self):
        return hash(self.radii.cpu().numpy().tobytes()) + hash(self.sigma.cpu().numpy().tobytes())
def circular_harmonics(points: torch.Tensor, L: int, phase: float = 0.):
    r"""
        Compute the circular harmonics up to frequency ``L``.

        ``points`` must have shape `(S, 2)`. The returned tensor has shape `(S, 2*L + 1)`: the first column is
        the constant `1`, followed, for each frequency `k = 1, ..., L`, by the pair
        `cos(k * (theta - phase))`, `sin(k * (theta - phase))`, where `theta = atan2(y, x)`.
    """

    assert len(points.shape) == 2
    assert points.shape[1] == 2

    n_points = points.shape[0]

    xs = points[:, 0]
    ys = points[:, 1]

    # polar angle of each point, measured from the reference direction ``phase``
    theta = torch.atan2(ys, xs).view(n_points, 1) - phase

    orders = torch.arange(1, L+1, device=points.device, dtype=points.dtype).view(1, L)

    # (S, L) table containing k * theta for every point and every frequency k
    arguments = orders * theta

    harmonics = torch.empty((n_points, 2 * L + 1), dtype=points.dtype, device=points.device)
    harmonics[:, 0] = 1.
    # odd columns store the cosines, even columns (from the second) the sines
    harmonics[:, 1::2] = torch.cos(arguments)
    harmonics[:, 2::2] = torch.sin(arguments)

    return harmonics
class SphericalShellsBasis(SteerableFiltersBasis):

    def __init__(self,
                 L: int,
                 radial: GaussianRadialProfile,
                 filter: Callable[[Dict], bool] = None
                 ):
        r"""

        Build the tensor product basis of a radial profile basis and a spherical harmonics basis for kernels over
        the Euclidean space :math:`\R^3`.

        The kernel space is spanned by an independent basis for each shell.
        The kernel space over each shell is spanned by the spherical harmonics of frequency up to `L`
        (an independent copy of each for each shell).

        Given the bases :math:`A = \{a_j\}_j` for the spherical shells and
        :math:`D = \{d_r\}_r` for the radial component (indexed by :math:`r \geq 0`, the radius of each ring),
        this basis is defined as

        .. math::
            C = \left\{c_{r,j}(\bold{p}) := d_r(||\bold{p}||) a_j(\hat{\bold{p}}) \right\}_{r, j}

        where :math:`(||\bold{p}||, \hat{\bold{p}})` are the polar coordinates of the point
        :math:`\bold{p} \in \R^3`.

        The radial component is parametrized using :class:`~escnn.kernels.GaussianRadialProfile`.

        Args:
            L (int): the maximum spherical frequency
            radial (GaussianRadialProfile): the basis for the radial profile
            filter (callable, optional): function used to filter out some basis elements. It takes as input a dict
                describing a basis element and should return a boolean value indicating whether to keep (`True`) or
                discard (`False`) the element. By default (`None`), all basis elements are kept.

        Attributes:
            ~.radial (GaussianRadialProfile): the radial basis
            ~.L (int): the maximum spherical frequency

        """

        self.L: int = L

        assert isinstance(radial, GaussianRadialProfile)

        self._angular_dim = (L+1)**2

        # number of invariant subspaces
        self._num_inv_spaces = 0

        G = o3_group(L)

        if filter is not None:
            # boolean mask over the unfiltered basis, which is ordered by frequency j first, then by ring and
            # finally by the 2j+1 components of the harmonic
            _filter = torch.zeros(self._angular_dim * len(radial), dtype=torch.bool)

            js = []
            _idx_map = []
            _steerable_idx_map = []
            i = 0
            steerable_i = 0
            for j in range(self.L+1):
                j_id = (j % 2, j)  # the id of the O(3) irrep
                attr2 = {
                    'irrep:' + k: v
                    for k, v in G.irrep(*j_id).attributes.items()
                }
                attr2['j'] = j_id

                dim = 2 * j + 1
                multiplicity = 0

                for attr1 in radial:
                    attr = dict()
                    attr.update(attr1)
                    attr.update(attr2)

                    if filter(attr):
                        multiplicity += 1
                        _filter[i:i+dim] = 1
                        _idx_map += list(range(i, i+dim))
                        _steerable_idx_map.append(steerable_i)

                    # the counters run over the unfiltered basis, so they advance for discarded elements too
                    i += dim
                    steerable_i += 1

                js.append(
                    (
                        j_id,  # the O(3) irrep ID
                        multiplicity
                    )
                )
                self._num_inv_spaces += multiplicity

            self._idx_map = np.array(_idx_map, dtype=int)
            self._steerable_idx_map = np.array(_steerable_idx_map, dtype=int)
        else:
            _filter = None
            self._idx_map = None
            self._steerable_idx_map = None
            js = [
                (
                    (j % 2, j),  # the O(3) irrep ID
                    len(radial)
                )
                for j in range(L+1)
            ]
            # without a filter, every (frequency, ring) pair spans one invariant subspace
            self._num_inv_spaces = (L + 1) * len(radial)

        super(SphericalShellsBasis, self).__init__(G, G.standard_representation(), js)

        self.radial = radial

        self.harmonics_generator = HarmonicPolynomialR3Generator(self.L)

        if _filter is not None:
            self.register_buffer('_filter', _filter)
        else:
            self._filter = None

    def sample(self, points: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
        r"""

        Sample the continuous basis elements on a discrete set of ``points`` in the space :math:`\R^3`.
        Optionally, store the resulting multidimensional array in ``out``.

        ``points`` must be an array of shape `(N, 3)` containing `N` points in the space.
        Note that the points are specified in cartesian coordinates :math:`(x, y, z)`.

        Args:
            points (~torch.Tensor): points in the 3-dimensional Euclidean space where to evaluate the basis elements
            out (~torch.Tensor, optional): pre-existing array to use to store the output

        Returns:
            the sampled basis, with shape `(N, self.dim, 1, 1)`

        """
        assert len(points.shape) == 2
        assert points.shape[1] == self.dimensionality, (points.shape, self.dimensionality)

        S = points.shape[0]

        assert not torch.isnan(points).any()

        radii = torch.norm(points, dim=1, keepdim=True)

        non_origin_mask = (radii > 1e-9).reshape(-1)
        any_origin = not non_origin_mask.all()

        if any_origin:
            # points at the origin can not be projected on the sphere and are left untouched
            sphere = torch.empty_like(points)
            sphere[non_origin_mask, :] = torch.nn.functional.normalize(points[non_origin_mask], dim=1)
            sphere[~non_origin_mask, :] = points[~non_origin_mask]
        else:
            sphere = torch.nn.functional.normalize(points, dim=1)

        if out is None:
            out = torch.empty(S, self.dim, 1, 1, device=points.device, dtype=points.dtype)

        assert out.shape == (S, self.dim, 1, 1)

        # sample the radial basis
        radial = self.radial.sample(radii)
        assert radial.shape == (S, len(self.radial), 1, 1)
        radial = radial[..., 0, 0]

        assert not torch.isnan(radial).any()

        # sample the angular basis
        spherical = self.harmonics_generator(sphere)

        # only frequency 0 is sampled at the origin. Other frequencies are expected to be (numerically) 0
        if any_origin:
            freq0_err = (spherical[~non_origin_mask, :1] - 1.).abs().max().item()
            assert freq0_err < 1e-7, freq0_err
            if spherical.shape[1] > 1:
                higher_freqs = (spherical[~non_origin_mask, 1:]).abs().max().item()
                assert higher_freqs < 1e-7, higher_freqs

        assert not torch.isnan(spherical).any()

        # tensor_product[p, a, b] = radial[p, a] * spherical[p, b]
        tensor_product = torch.einsum("pa,pb->pab", radial, spherical)

        n_radii = len(self.radial)

        if self._filter is None:
            tmp_out = out
        else:
            # sample the full (unfiltered) basis in a temporary buffer and select the kept elements at the end
            tmp_out = torch.empty(S, self._angular_dim*n_radii, 1, 1, device=points.device, dtype=points.dtype)

        for j in range(self.L+1):
            # the harmonics of frequency j occupy the columns [j^2, (j+1)^2) of the angular basis
            first, last = j**2, (j+1)**2
            tmp_out[:, first * n_radii:last * n_radii, 0, 0].view(
                S, n_radii, 2*j+1
            )[:] = tensor_product[:, :, first:last]

        if self._filter is not None:
            out[:] = tmp_out[:, self._filter, ...]

        return out

    def steerable_attrs_iter(self):
        """Iterate over the attributes of each invariant subspace of the basis.

        These attributes don't describe a single basis element but a group of basis elements which span an
        invariant subspace. This is needed to generate the attributes of the SteerableKernelBasis.
        """
        idx = 0
        i = 0

        # since this method returns iterables of attributes built on the fly, load all attributes first and then
        # iterate on these lists
        radial_attrs = list(self.radial)

        for j in range(self.L + 1):
            j_id = (j % 2, j)  # the id of the O(3) irrep
            attr1 = {
                'irrep:' + k: v
                for k, v in self.group.irrep(*j_id).attributes.items()
            }

            for radial_idx, attr2 in enumerate(radial_attrs):
                if self._filter is None or (self._filter[i:i+2*j+1] == 1).all():
                    assert attr2["idx"] == radial_idx

                    attr = dict()
                    attr.update(attr1)
                    attr.update(attr2)
                    attr["idx"] = idx
                    attr["radial_idx"] = radial_idx
                    attr["j"] = j_id
                    attr["shape"] = (1, 1)

                    yield attr
                    idx += 1

                # ``i`` indexes the unfiltered basis
                i += 2*j+1

    def steerable_attrs(self, idx):
        """Return the attributes of the ``idx``-th invariant subspace of the basis.

        These attributes don't describe a single basis element but a group of basis elements which span an
        invariant subspace. This is needed to generate the attributes of the SteerableKernelBasis.
        """
        assert idx < self._num_inv_spaces, (idx, self._num_inv_spaces)

        if self._steerable_idx_map is None:
            _idx = idx
        else:
            # position of the subspace in the unfiltered basis
            _idx = int(self._steerable_idx_map[idx])

        j, radial_idx = divmod(_idx, len(self.radial))
        j_id = (j % 2, j)  # the id of the O(3) irrep

        attr1 = {
            'irrep:' + k: v
            for k, v in self.group.irrep(*j_id).attributes.items()
        }

        attr2 = self.radial[radial_idx]
        assert attr2["idx"] == radial_idx

        # merge the attributes of the irrep with those of the ring (radius and sigma)
        attr = dict()
        attr.update(attr1)
        attr.update(attr2)

        attr["j"] = j_id
        attr["idx"] = idx
        attr["radial_idx"] = radial_idx
        attr["shape"] = (1, 1)

        return attr

    def steerable_attrs_j_iter(self, j: Tuple) -> Iterable:
        """Iterate over the attributes of the invariant subspaces associated with the irrep ``j``.

        These attributes don't describe a single basis element but a group of basis elements which span an
        invariant subspace. This is needed to generate the attributes of the SteerableKernelBasis.
        Nothing is generated if ``j`` is not one of the ids ``(j % 2, j)`` used by this basis.
        """
        j_id = j
        f, j = j_id

        if f != j % 2:
            return

        idx = sum(self.multiplicity((_j % 2, _j)) for _j in range(j))

        # since this method returns iterables of attributes built on the fly, load all attributes first and then
        # iterate on these lists
        radial_attrs = list(self.radial)

        # offset, in the unfiltered basis, of the first element of frequency j:
        # sum_{k<j} (2k+1) * n_radii = j^2 * n_radii
        i = j**2 * len(radial_attrs)

        attr1 = {
            'irrep:' + k: v
            for k, v in self.group.irrep(*j_id).attributes.items()
        }

        for radial_idx, attr2 in enumerate(radial_attrs):
            if self._filter is None or (self._filter[i:i+2*j+1] == 1).all():
                assert attr2["idx"] == radial_idx

                attr = dict()
                attr.update(attr1)
                attr.update(attr2)
                attr["idx"] = idx
                attr["radial_idx"] = radial_idx
                attr["j"] = j_id
                attr["shape"] = (1, 1)

                yield attr
                idx += 1

            i += 2*j+1

    def steerable_attrs_j(self, j: Tuple, idx) -> Dict:
        """Return the attributes of the ``idx``-th invariant subspace associated with the irrep ``j``.

        These attributes don't describe a single basis element but a group of basis elements which span an
        invariant subspace. This is needed to generate the attributes of the SteerableKernelBasis.
        ``None`` is returned if ``j`` is not one of the ids ``(j % 2, j)`` used by this basis.
        """
        j_id = j
        f, j = j_id

        if f != j % 2:
            return None

        assert idx < self.multiplicity(j_id), (idx, self.multiplicity(j_id))

        # convert the index within the irrep j in an index over all invariant subspaces
        idx += sum(self.multiplicity((_j % 2, _j)) for _j in range(j))

        if self._steerable_idx_map is None:
            _idx = idx
        else:
            _idx = int(self._steerable_idx_map[idx])

        _j, radial_idx = divmod(_idx, len(self.radial))
        assert _j == j, (j, _j)

        attr1 = {
            'irrep:' + k: v
            for k, v in self.group.irrep(*j_id).attributes.items()
        }

        attr2 = self.radial[radial_idx]
        assert attr2["idx"] == radial_idx

        attr = dict()
        attr.update(attr1)
        attr.update(attr2)
        attr["idx"] = idx
        attr["radial_idx"] = radial_idx
        attr["j"] = j_id
        attr["shape"] = (1, 1)

        return attr

    def __getitem__(self, idx):
        assert idx < self.dim, (idx, self.dim)

        if self._idx_map is None:
            _idx = idx
        else:
            # position of the element in the unfiltered basis
            _idx = int(self._idx_map[idx])

        n_radii = len(self.radial)

        # the elements of frequency j occupy the range [j^2 * n_radii, (j+1)^2 * n_radii)
        j = int(np.floor(np.sqrt(_idx // n_radii)))
        assert j**2 * n_radii <= _idx < (j+1)**2 * n_radii, (_idx, j, self.L, n_radii)

        j_idx = _idx - j**2 * n_radii
        radial_idx, m = divmod(j_idx, 2*j+1)

        j_id = (j % 2, j)  # the id of the O(3) irrep
        attr1 = {
            'irrep:' + k: v
            for k, v in self.group.irrep(*j_id).attributes.items()
        }

        attr2 = self.radial[radial_idx]
        assert attr2["idx"] == radial_idx

        attr = dict()
        attr.update(attr1)
        attr.update(attr2)
        attr["idx"] = idx
        attr["radial_idx"] = radial_idx
        attr["j"] = j_id
        attr["m"] = m
        attr["shape"] = (1, 1)

        return attr

    def __iter__(self):
        idx = 0
        i = 0

        # since this method returns iterables of attributes built on the fly, load all attributes first and then
        # iterate on these lists
        radial_attrs = list(self.radial)

        for j in range(self.L+1):
            j_id = (j % 2, j)  # the id of the O(3) irrep
            attr1 = {
                'irrep:' + k: v
                for k, v in self.group.irrep(*j_id).attributes.items()
            }

            for radial_idx, attr2 in enumerate(radial_attrs):
                for m in range(2*j+1):

                    if self._filter is None or self._filter[i] == 1:
                        assert attr2["idx"] == radial_idx

                        attr = dict()
                        attr.update(attr1)
                        attr.update(attr2)
                        attr["idx"] = idx
                        attr["radial_idx"] = radial_idx
                        attr["j"] = j_id
                        attr["m"] = m
                        attr["shape"] = (1, 1)

                        yield attr
                        idx += 1

                    i += 1

    def _filter_bytes(self):
        # Device-independent, hashable summary of the filter mask (``None`` if no filter is used).
        if self._filter is None:
            return None
        return self._filter.cpu().numpy().tobytes()

    def __eq__(self, other):
        if isinstance(other, SphericalShellsBasis):
            return (
                self.radial == other.radial
                and self.L == other.L
                # the masks are compared by value: ``==`` between tensors is elementwise and has no
                # well defined truth value
                and self._filter_bytes() == other._filter_bytes()
            )
        else:
            return False

    def __hash__(self):
        # hash the content of the mask such that equal bases share the same hash
        return self.L + hash(self.radial) + hash(self._filter_bytes())
class CircularShellsBasis(SteerableFiltersBasis):

    def __init__(self,
                 L: int,
                 radial: GaussianRadialProfile,
                 filter: Callable[[Dict], bool] = None,
                 axis: float = np.pi/2,
                 ):
        r"""

        Build the tensor product basis of a radial profile basis and a circular harmonics basis for kernels over
        the Euclidean space :math:`\R^2`.

        The kernel space is spanned by an independent basis for each shell.
        The kernel space over each shell is spanned by the circular harmonics of frequency up to `L`
        (an independent copy of each for each shell).

        Given the bases :math:`A = \{a_j\}_j` for the circular shells and
        :math:`D = \{d_r\}_r` for the radial component (indexed by :math:`r \geq 0`, the radius of each ring),
        this basis is defined as

        .. math::
            C = \left\{c_{r,j}(\bold{p}) := d_r(||\bold{p}||) a_j(\hat{\bold{p}}) \right\}_{r, j}

        where :math:`(||\bold{p}||, \hat{\bold{p}})` are the polar coordinates of the point
        :math:`\bold{p} \in \R^2`.

        The radial component is parametrized using :class:`~escnn.kernels.GaussianRadialProfile`.

        Args:
            L (int): the maximum circular frequency
            radial (GaussianRadialProfile): the basis for the radial profile
            filter (callable, optional): function used to filter out some basis elements. It takes as input a dict
                describing a basis element and should return a boolean value indicating whether to keep (`True`) or
                discard (`False`) the element. By default (`None`), all basis elements are kept.
            axis (float, optional): angle (in radians) used as the phase of the circular harmonics and to change
                the basis of the standard action of the group. By default, it is :math:`\pi/2`.

        Attributes:
            ~.radial (GaussianRadialProfile): the radial basis
            ~.L (int): the maximum circular frequency

        """

        self.L: int = L

        assert isinstance(radial, GaussianRadialProfile)

        self._angular_dim = 2*L+1

        # number of invariant subspaces
        self._num_inv_spaces = 0

        G = o2_group(L)

        if filter is not None:
            # boolean mask over the unfiltered basis, which is ordered by frequency j first, then by ring and
            # finally by the components of the harmonic (1 for j = 0, 2 otherwise)
            _filter = torch.zeros(self._angular_dim * len(radial), dtype=torch.bool)

            js = []
            _idx_map = []
            _steerable_idx_map = []
            i = 0
            steerable_i = 0
            for j in range(self.L + 1):
                j_id = (int(j > 0), j)  # the id of the O(2) irrep
                attr2 = {
                    'irrep:' + k: v
                    for k, v in G.irrep(*j_id).attributes.items()
                }
                attr2['j'] = j_id

                dim = 2 if j > 0 else 1
                multiplicity = 0

                for attr1 in radial:
                    attr = dict()
                    attr.update(attr1)
                    attr.update(attr2)

                    if filter(attr):
                        multiplicity += 1
                        _filter[i:i + dim] = 1
                        _idx_map += list(range(i, i + dim))
                        _steerable_idx_map.append(steerable_i)

                    # the counters run over the unfiltered basis, so they advance for discarded elements too
                    i += dim
                    steerable_i += 1

                js.append(
                    (
                        j_id,  # the O(2) irrep ID
                        multiplicity
                    )
                )
                self._num_inv_spaces += multiplicity

            self._idx_map = np.array(_idx_map, dtype=int)
            self._steerable_idx_map = np.array(_steerable_idx_map, dtype=int)
        else:
            _filter = None
            self._idx_map = None
            self._steerable_idx_map = None
            js = [
                (
                    (int(j > 0), j),  # the O(2) irrep ID
                    len(radial)
                )
                for j in range(L + 1)
            ]
            # without a filter, every (frequency, ring) pair spans one invariant subspace
            self._num_inv_spaces = (L + 1) * len(radial)

        self.axis = axis

        action = G.standard_representation()
        action = change_basis(action, action(G.element((0, axis), 'radians')), name=f'StandardAction|axis=[{axis}]')

        super(CircularShellsBasis, self).__init__(G, action, js)

        self.radial = radial

        if _filter is None:
            self._filter = None
        else:
            self.register_buffer('_filter', _filter)

    def sample(self, points: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
        r"""

        Sample the continuous basis elements on a discrete set of ``points`` in the space :math:`\R^2`.
        Optionally, store the resulting multidimensional array in ``out``.

        ``points`` must be an array of shape `(N, 2)` containing `N` points in the space.
        Note that the points are specified in cartesian coordinates :math:`(x, y)`.

        Args:
            points (~torch.Tensor): points in the 2-dimensional Euclidean space where to evaluate the basis elements
            out (~torch.Tensor, optional): pre-existing array to use to store the output

        Returns:
            the sampled basis, with shape `(N, self.dim, 1, 1)`

        """
        assert len(points.shape) == 2
        assert points.shape[1] == self.dimensionality, (points.shape, self.dimensionality)

        S = points.shape[0]

        radii = torch.norm(points, dim=1, keepdim=True)

        non_origin_mask = (radii > 1e-9).reshape(-1)

        # only the points away from the origin can be projected on the circle
        sphere = points[non_origin_mask, :] / radii[non_origin_mask, :]

        if out is None:
            out = torch.empty(S, self.dim, 1, 1, device=points.device, dtype=points.dtype)

        assert out.shape == (S, self.dim, 1, 1)

        # sample the radial basis
        radial = self.radial.sample(radii)
        assert radial.shape[-2:] == (1, 1)
        radial = radial[..., 0, 0]

        assert not torch.isnan(radial).any()

        # sample the angular basis
        circular = torch.empty(S, self._angular_dim, device=points.device, dtype=points.dtype)

        # where r>0, we sample all frequencies
        circular[non_origin_mask, :] = circular_harmonics(sphere, self.L, phase=self.axis)

        # only frequency 0 is sampled at the origin. Other frequencies are set to 0
        # NOTE: as a consequence, the frequencies > 0 do not provide meaningful gradients with respect to the
        # points at the origin
        circular[~non_origin_mask, :1] = 1.
        circular[~non_origin_mask, 1:] = 0.

        # tensor_product[p, a, b] = radial[p, a] * circular[p, b]
        tensor_product = torch.einsum("pa,pb->pab", radial, circular)

        n_radii = len(self.radial)

        if self._filter is None:
            tmp_out = out
        else:
            # sample the full (unfiltered) basis in a temporary buffer and select the kept elements at the end
            tmp_out = torch.empty(S, self._angular_dim*n_radii, 1, 1, device=points.device, dtype=points.dtype)

        for j in range(self.L+1):
            # frequency 0 occupies the column 0 of the angular basis, frequency j > 0 the columns [2j-1, 2j+1)
            dim = 2 if j > 0 else 1
            last = 2*j+1
            first = last - dim
            tmp_out[:, first * n_radii:last * n_radii, 0, 0].view(
                S, n_radii, dim
            )[:] = tensor_product[:, :, first:last]

        if self._filter is not None:
            out[:] = tmp_out[:, self._filter, ...]

        return out

    def steerable_attrs_iter(self):
        """Iterate over the attributes of each invariant subspace of the basis.

        These attributes don't describe a single basis element but a group of basis elements which span an
        invariant subspace. This is needed to generate the attributes of the SteerableKernelBasis.
        """
        idx = 0
        i = 0

        # since this method returns iterables of attributes built on the fly, load all attributes first and then
        # iterate on these lists
        radial_attrs = list(self.radial)

        for j in range(self.L + 1):
            dim = 2 if j > 0 else 1
            j_id = (int(j > 0), j)  # the id of the O(2) irrep
            attr1 = {
                'irrep:' + k: v
                for k, v in self.group.irrep(*j_id).attributes.items()
            }

            for radial_idx, attr2 in enumerate(radial_attrs):
                if self._filter is None or (self._filter[i:i + dim] == 1).all():
                    assert attr2["idx"] == radial_idx

                    attr = dict()
                    attr.update(attr1)
                    attr.update(attr2)
                    attr["idx"] = idx
                    attr["radial_idx"] = radial_idx
                    attr["j"] = j_id
                    attr["shape"] = (1, 1)

                    yield attr
                    idx += 1

                # ``i`` indexes the unfiltered basis
                i += dim

    def steerable_attrs(self, idx):
        """Return the attributes of the ``idx``-th invariant subspace of the basis.

        These attributes don't describe a single basis element but a group of basis elements which span an
        invariant subspace. This is needed to generate the attributes of the SteerableKernelBasis.
        """
        assert idx < self._num_inv_spaces, (idx, self._num_inv_spaces)

        if self._steerable_idx_map is None:
            _idx = idx
        else:
            # position of the subspace in the unfiltered basis
            _idx = int(self._steerable_idx_map[idx])

        j, radial_idx = divmod(_idx, len(self.radial))

        j_id = (int(j > 0), j)  # the id of the O(2) irrep
        attr1 = {
            'irrep:' + k: v
            for k, v in self.group.irrep(*j_id).attributes.items()
        }

        attr2 = self.radial[radial_idx]
        assert attr2["idx"] == radial_idx

        attr = dict()
        attr.update(attr1)
        attr.update(attr2)
        attr["idx"] = idx
        attr["radial_idx"] = radial_idx
        attr["j"] = j_id
        attr["shape"] = (1, 1)

        return attr

    def steerable_attrs_j_iter(self, j: Tuple) -> Iterable:
        """Iterate over the attributes of the invariant subspaces associated with the irrep ``j``.

        These attributes don't describe a single basis element but a group of basis elements which span an
        invariant subspace. This is needed to generate the attributes of the SteerableKernelBasis.
        Nothing is generated if ``j`` is not one of the ids ``(int(j > 0), j)`` used by this basis.
        """
        j_id = j
        f, j = j_id

        if f != int(j > 0):
            return

        idx = sum(self.multiplicity((int(_j > 0), _j)) for _j in range(j))

        dim = 2 if j > 0 else 1

        # since this method returns iterables of attributes built on the fly, load all attributes first and then
        # iterate on these lists
        radial_attrs = list(self.radial)

        # offset, in the unfiltered basis, of the first element of frequency j:
        # frequency 0 contributes n_radii elements, each frequency 0 < k < j contributes 2 * n_radii elements
        i = (2*j - 1) * len(radial_attrs) if j > 0 else 0

        attr1 = {
            'irrep:' + k: v
            for k, v in self.group.irrep(*j_id).attributes.items()
        }

        for radial_idx, attr2 in enumerate(radial_attrs):
            if self._filter is None or (self._filter[i:i + dim] == 1).all():
                assert attr2["idx"] == radial_idx

                attr = dict()
                attr.update(attr1)
                attr.update(attr2)
                attr["idx"] = idx
                attr["radial_idx"] = radial_idx
                attr["j"] = j_id
                attr["shape"] = (1, 1)

                yield attr
                idx += 1

            i += dim

    def steerable_attrs_j(self, j: Tuple, idx) -> Dict:
        """Return the attributes of the ``idx``-th invariant subspace associated with the irrep ``j``.

        These attributes don't describe a single basis element but a group of basis elements which span an
        invariant subspace. This is needed to generate the attributes of the SteerableKernelBasis.
        ``None`` is returned if ``j`` is not one of the ids ``(int(j > 0), j)`` used by this basis.
        """
        j_id = j
        f, j = j_id

        if f != int(j > 0):
            return None

        assert idx < self.multiplicity(j_id), (idx, self.multiplicity(j_id))

        # convert the index within the irrep j in an index over all invariant subspaces
        idx += sum(self.multiplicity((int(_j > 0), _j)) for _j in range(j))

        if self._steerable_idx_map is None:
            _idx = idx
        else:
            _idx = int(self._steerable_idx_map[idx])

        _j, radial_idx = divmod(_idx, len(self.radial))
        assert _j == j, (j, _j)

        attr1 = {
            'irrep:' + k: v
            for k, v in self.group.irrep(*j_id).attributes.items()
        }

        attr2 = self.radial[radial_idx]
        assert attr2["idx"] == radial_idx

        attr = dict()
        attr.update(attr1)
        attr.update(attr2)
        attr["idx"] = idx
        attr["radial_idx"] = radial_idx
        attr["j"] = j_id
        attr["shape"] = (1, 1)

        return attr

    def __getitem__(self, idx):
        assert idx < self.dim, (idx, self.dim)

        if self._idx_map is None:
            _idx = idx
        else:
            # position of the element in the unfiltered basis
            _idx = int(self._idx_map[idx])

        n_radii = len(self.radial)

        # frequency 0 occupies the range [0, n_radii) of the unfiltered basis, while
        # frequency j > 0 occupies the range [(2j-1) * n_radii, (2j+1) * n_radii)
        j = (_idx // n_radii + 1) // 2
        first = (2*j - 1) * n_radii if j > 0 else 0
        last = (2*j + 1) * n_radii
        assert first <= _idx < last, (_idx, j, self.L, n_radii)

        j_idx = _idx - first
        radial_idx, m = divmod(j_idx, 2 if j > 0 else 1)

        j_id = (int(j > 0), j)  # the id of the O(2) irrep
        attr1 = {
            'irrep:' + k: v
            for k, v in self.group.irrep(*j_id).attributes.items()
        }

        attr2 = self.radial[radial_idx]
        assert attr2["idx"] == radial_idx

        attr = dict()
        attr.update(attr1)
        attr.update(attr2)
        attr["idx"] = idx
        attr["radial_idx"] = radial_idx
        attr["j"] = j_id
        attr["m"] = m
        attr["shape"] = (1, 1)

        return attr

    def __iter__(self):
        idx = 0
        i = 0

        # since this method returns iterables of attributes built on the fly, load all attributes first and then
        # iterate on these lists
        radial_attrs = list(self.radial)

        for j in range(self.L + 1):
            dim = 2 if j > 0 else 1
            j_id = (int(j > 0), j)  # the id of the O(2) irrep
            attr1 = {
                'irrep:' + k: v
                for k, v in self.group.irrep(*j_id).attributes.items()
            }

            for radial_idx, attr2 in enumerate(radial_attrs):
                for m in range(dim):

                    if self._filter is None or self._filter[i] == 1:
                        assert attr2["idx"] == radial_idx

                        attr = dict()
                        attr.update(attr1)
                        attr.update(attr2)
                        attr["idx"] = idx
                        attr["radial_idx"] = radial_idx
                        attr["j"] = j_id
                        attr["m"] = m
                        attr["shape"] = (1, 1)

                        yield attr
                        idx += 1

                    i += 1

    def _filter_bytes(self):
        # Device-independent, hashable summary of the filter mask (``None`` if no filter is used).
        if self._filter is None:
            return None
        return self._filter.cpu().numpy().tobytes()

    def __eq__(self, other):
        if isinstance(other, CircularShellsBasis):
            return (
                self.radial == other.radial
                and self.L == other.L
                # ``axis`` is the phase used to sample the harmonics, so it is part of the basis' identity
                and self.axis == other.axis
                # the masks are compared by value: ``==`` between tensors is elementwise and has no
                # well defined truth value
                and self._filter_bytes() == other._filter_bytes()
            )
        else:
            return False

    def __hash__(self):
        # hash the content of the mask such that equal bases share the same hash
        return self.L + hash(self.radial) + hash(self.axis) + hash(self._filter_bytes())