# Source code for escnn.kernels.r3

from escnn.kernels.basis import KernelBasis, AdjointBasis, UnionBasis, EmptyBasisException
from escnn.kernels.steerable_basis import SteerableKernelBasis
from escnn.kernels.wignereckart_solver import WignerEckartBasis, RestrictedWignerEckartBasis
from escnn.kernels.sparse_basis import SparseOrbitBasis, SparseOrbitBasisWithIcosahedralSymmetry

from escnn.kernels.polar_basis import GaussianRadialProfile
from escnn.kernels.polar_basis import SphericalShellsBasis

from escnn.group import *

import numpy as np

from typing import List, Union, Callable, Dict, Tuple

# Public kernel-basis constructors exported by this module.
__all__ = [
    # kernels equivariant to the full groups SO(3) and O(3)
    "kernels_SO3_act_R3",
    "kernels_O3_act_R3",
    # kernels equivariant to a subgroup of SO(3) / O(3), selected through a subgroup id
    "kernels_SO3_subgroup_act_R3",
    "kernels_O3_subgroup_act_R3",
    # platonic symmetries (rotational and full)
    "kernels_Ico_act_R3",
    "kernels_Octa_act_R3",
    "kernels_Tetra_act_R3",
    "kernels_FullIco_act_R3",
    "kernels_FullOcta_act_R3",
    "kernels_FullTetra_act_R3",
    "kernels_Pyrito_act_R3",
    # planar / axial symmetries, cylinder symmetries, inversion and the trivial group
    "kernels_SO2_act_R3",
    "kernels_CN_act_R3",
    "kernels_O2_conical_act_R3",
    "kernels_DN_conical_act_R3",
    "kernels_O2_dihedral_act_R3",
    "kernels_DN_dihedral_act_R3",
    "kernels_full_cylinder_act_R3",
    "kernels_full_cylinder_discrete_act_R3",
    "kernels_cylinder_act_R3",
    "kernels_cylinder_discrete_act_R3",
    "kernels_Inv_act_R3",
    "kernels_Trivial_act_R3",
    # icosahedral symmetry with aliased (sparse orbit) samples
    'kernels_aliased_Ico_act_R3_dodecahedron',
    'kernels_aliased_Ico_act_R3_icosidodecahedron',
    'kernels_aliased_Ico_act_R3_icosahedron',
]


def kernels_SO3_act_R3(in_repr: Representation,
                       out_repr: Representation,
                       radii: List[float],
                       sigma: Union[List[float], float],
                       maximum_frequency: int = None,
                       adjoint: np.ndarray = None,
                       filter: Callable[[Dict], bool] = None
                       ) -> KernelBasis:
    r"""
    Builds a basis for convolutional kernels equivariant to continuous rotations, modeled by the
    group :math:`SO(3)`.

    ``in_repr`` and ``out_repr`` need to be :class:`~escnn.group.Representation` s of :class:`~escnn.group.SO3`.

    Because the equivariance constraints allow any choice of radial profile, we use a
    :class:`~escnn.kernels.GaussianRadialProfile`.
    ``radii`` specifies the radial distances at which the rings are centered while ``sigma`` contains the width of
    each of these rings (see :class:`~escnn.kernels.GaussianRadialProfile`).

    Args:
        in_repr (Representation): the representation specifying the transformation of the input feature field
        out_repr (Representation): the representation specifying the transformation of the output feature field
        radii (list): radii of the rings defining the basis for the radial profile
        sigma (list or float): widths of the rings defining the basis for the radial profile
        maximum_frequency (int, optional): maximum frequency of the angular part of the basis, forwarded to
            :class:`~escnn.kernels.SphericalShellsBasis`. If ``None``, it defaults to the sum of the largest
            irrep frequency in ``in_repr`` and the largest irrep frequency in ``out_repr``.
        adjoint (~numpy.ndarray, optional): 3x3 orthogonal matrix defining a change of basis on the base space
        filter (callable, optional): function mapping the attributes (a ``dict``) of a basis element to a
            ``bool``; it is forwarded to :class:`~escnn.kernels.SphericalShellsBasis`.

    """
    assert in_repr.group == out_repr.group

    group = in_repr.group

    assert isinstance(group, SO3)

    radial_profile = GaussianRadialProfile(radii, sigma)

    if maximum_frequency is None:
        # each SO(3) irrep id is unpacked as a 1-tuple ``(freq,)``
        max_in_freq = max(freq for freq, in in_repr.irreps)
        max_out_freq = max(freq for freq, in out_repr.irreps)
        maximum_frequency = max_in_freq + max_out_freq

    basis = SteerableKernelBasis(
        SphericalShellsBasis(maximum_frequency, radial_profile, filter=filter),
        in_repr, out_repr,
        RestrictedWignerEckartBasis,
        sg_id='so3'
    )

    if adjoint is not None:
        adjoint = np.asarray(adjoint)
        # validate the shape *before* comparing with the identity: a wrongly-shaped array would
        # otherwise fail inside ``np.allclose`` with an opaque broadcasting error
        assert adjoint.shape == (3, 3)
        if not np.allclose(adjoint, np.eye(3)):
            basis = AdjointBasis(basis, adjoint)

    return basis
def kernels_O3_act_R3(in_repr: Representation,
                      out_repr: Representation,
                      radii: List[float],
                      sigma: Union[List[float], float],
                      maximum_frequency: int = None,
                      adjoint: np.ndarray = None,
                      filter: Callable[[Dict], bool] = None
                      ) -> KernelBasis:
    r"""
    Builds a basis for convolutional kernels equivariant to reflections and continuous rotations, modeled by the
    group :math:`O(3)`.

    ``in_repr`` and ``out_repr`` need to be :class:`~escnn.group.Representation` s of :class:`~escnn.group.O3`.

    Because the equivariance constraints allow any choice of radial profile, we use a
    :class:`~escnn.kernels.GaussianRadialProfile`.
    ``radii`` specifies the radial distances at which the rings are centered while ``sigma`` contains the width of
    each of these rings (see :class:`~escnn.kernels.GaussianRadialProfile`).

    Because :math:`O(3)` contains all rotations, the reflection element of the group can be associated to any
    reflection axis.

    Args:
        in_repr (Representation): the representation specifying the transformation of the input feature field
        out_repr (Representation): the representation specifying the transformation of the output feature field
        radii (list): radii of the rings defining the basis for the radial profile
        sigma (list or float): widths of the rings defining the basis for the radial profile
        maximum_frequency (int, optional): maximum frequency of the angular part of the basis, forwarded to
            :class:`~escnn.kernels.SphericalShellsBasis`. If ``None``, it defaults to the sum of the largest
            irrep frequency in ``in_repr`` and the largest irrep frequency in ``out_repr``.
        adjoint (~numpy.ndarray, optional): 3x3 orthogonal matrix defining a change of basis on the base space
        filter (callable, optional): function mapping the attributes (a ``dict``) of a basis element to a
            ``bool``; it is forwarded to :class:`~escnn.kernels.SphericalShellsBasis`.

    """
    assert in_repr.group == out_repr.group

    group = in_repr.group

    assert isinstance(group, O3)

    radial_profile = GaussianRadialProfile(radii, sigma)

    if maximum_frequency is None:
        # each O(3) irrep id is unpacked as a pair; its second entry is the frequency
        max_in_freq = max(freq for _, freq in in_repr.irreps)
        max_out_freq = max(freq for _, freq in out_repr.irreps)
        maximum_frequency = max_in_freq + max_out_freq

    basis = SteerableKernelBasis(
        SphericalShellsBasis(maximum_frequency, radial_profile, filter=filter),
        in_repr, out_repr,
        WignerEckartBasis,
    )

    if adjoint is not None:
        adjoint = np.asarray(adjoint)
        # validate the shape *before* comparing with the identity: a wrongly-shaped array would
        # otherwise fail inside ``np.allclose`` with an opaque broadcasting error
        assert adjoint.shape == (3, 3)
        if not np.allclose(adjoint, np.eye(3)):
            basis = AdjointBasis(basis, adjoint)

    return basis
###### Automatic subgroups kernel bases
def kernels_O3_subgroup_act_R3(in_repr: Representation,
                               out_repr: Representation,
                               sg_id: Tuple,
                               radii: List[float],
                               sigma: Union[List[float], float],
                               maximum_frequency: int = 5,
                               adjoint: np.ndarray = None,
                               filter: Callable[[Dict], bool] = None
                               ) -> KernelBasis:
    r"""
    Builds a basis for convolutional kernels equivariant to the subgroup of :math:`O(3)` identified by ``sg_id``.

    ``in_repr`` and ``out_repr`` must both be representations of exactly the group returned by
    ``o3_group(maximum_frequency).subgroup(sg_id)``.

    Args:
        in_repr (Representation): the representation specifying the transformation of the input feature field
        out_repr (Representation): the representation specifying the transformation of the output feature field
        sg_id (tuple): id of the subgroup of :math:`O(3)`
        radii (list): radii of the rings defining the basis for the radial profile
        sigma (list or float): widths of the rings defining the basis for the radial profile
        maximum_frequency (int, optional): used both to build the :math:`O(3)` group and as the maximum frequency
            of the angular part of the basis (:class:`~escnn.kernels.SphericalShellsBasis`)
        adjoint (~numpy.ndarray, optional): 3x3 orthogonal matrix defining a change of basis on the base space
        filter (callable, optional): function mapping the attributes (a ``dict``) of a basis element to a
            ``bool``; it is forwarded to :class:`~escnn.kernels.SphericalShellsBasis`.

    """
    o3 = o3_group(maximum_frequency)
    group, _, _ = o3.subgroup(sg_id)

    assert in_repr.group == group
    assert out_repr.group == group

    radial_profile = GaussianRadialProfile(radii, sigma)

    basis = SteerableKernelBasis(
        SphericalShellsBasis(maximum_frequency, radial_profile, filter=filter),
        in_repr, out_repr,
        RestrictedWignerEckartBasis,
        sg_id=sg_id
    )

    if adjoint is not None:
        adjoint = np.asarray(adjoint)
        # validate the shape *before* comparing with the identity: a wrongly-shaped array would
        # otherwise fail inside ``np.allclose`` with an opaque broadcasting error
        assert adjoint.shape == (3, 3)
        if not np.allclose(adjoint, np.eye(3)):
            basis = AdjointBasis(basis, adjoint)

    return basis
def kernels_SO3_subgroup_act_R3(in_repr: Representation,
                                out_repr: Representation,
                                sg_id: Tuple,
                                radii: List[float],
                                sigma: Union[List[float], float],
                                maximum_frequency: int = 5,
                                adjoint: np.ndarray = None,
                                filter: Callable[[Dict], bool] = None
                                ) -> KernelBasis:
    r"""
    Builds a basis for convolutional kernels equivariant to the subgroup of :math:`SO(3)` identified by ``sg_id``.

    ``in_repr`` and ``out_repr`` must both be representations of exactly the group returned by
    ``so3_group(maximum_frequency).subgroup(sg_id)``.
    Internally, ``sg_id`` is re-expressed as the id of the same subgroup inside :math:`O(3)` (by combining it
    with the ``'so3'`` subgroup of :math:`O(3)`) before building the basis.

    Args:
        in_repr (Representation): the representation specifying the transformation of the input feature field
        out_repr (Representation): the representation specifying the transformation of the output feature field
        sg_id (tuple): id of the subgroup of :math:`SO(3)`
        radii (list): radii of the rings defining the basis for the radial profile
        sigma (list or float): widths of the rings defining the basis for the radial profile
        maximum_frequency (int, optional): used to build the :math:`SO(3)` / :math:`O(3)` groups and as the
            maximum frequency of the angular part of the basis (:class:`~escnn.kernels.SphericalShellsBasis`)
        adjoint (~numpy.ndarray, optional): 3x3 orthogonal matrix defining a change of basis on the base space
        filter (callable, optional): function mapping the attributes (a ``dict``) of a basis element to a
            ``bool``; it is forwarded to :class:`~escnn.kernels.SphericalShellsBasis`.

    """
    so3 = so3_group(maximum_frequency)
    group, _, _ = so3.subgroup(sg_id)

    assert in_repr.group == group
    assert out_repr.group == group

    # the basis below is restricted using a subgroup id of O(3): express the SO(3) subgroup
    # as a subgroup of the 'so3' subgroup of O(3)
    o3 = o3_group(maximum_frequency)
    o3_sg_id = o3._combine_subgroups('so3', sg_id)

    radial_profile = GaussianRadialProfile(radii, sigma)

    basis = SteerableKernelBasis(
        SphericalShellsBasis(maximum_frequency, radial_profile, filter=filter),
        in_repr, out_repr,
        RestrictedWignerEckartBasis,
        sg_id=o3_sg_id
    )

    if adjoint is not None:
        adjoint = np.asarray(adjoint)
        # validate the shape *before* comparing with the identity: a wrongly-shaped array would
        # otherwise fail inside ``np.allclose`` with an opaque broadcasting error
        assert adjoint.shape == (3, 3)
        if not np.allclose(adjoint, np.eye(3)):
            basis = AdjointBasis(basis, adjoint)

    return basis
###### Platonic Symmetries
def kernels_Ico_act_R3(in_repr: Representation,
                       out_repr: Representation,
                       radii: List[float],
                       sigma: Union[List[float], float],
                       maximum_frequency: int = 5,
                       adjoint: np.ndarray = None,
                       filter: Callable[[Dict], bool] = None
                       ) -> KernelBasis:
    r"""
    Kernel basis equivariant to the icosahedral rotation group, i.e. the subgroup ``'ico'`` of :math:`SO(3)`.

    Thin wrapper around :func:`kernels_SO3_subgroup_act_R3`; all remaining arguments are forwarded unchanged.
    """
    return kernels_SO3_subgroup_act_R3(
        in_repr,
        out_repr,
        sg_id='ico',
        radii=radii,
        sigma=sigma,
        maximum_frequency=maximum_frequency,
        adjoint=adjoint,
        filter=filter,
    )
def kernels_Octa_act_R3(in_repr: Representation,
                        out_repr: Representation,
                        radii: List[float],
                        sigma: Union[List[float], float],
                        maximum_frequency: int = 5,
                        adjoint: np.ndarray = None,
                        filter: Callable[[Dict], bool] = None
                        ) -> KernelBasis:
    r"""
    Kernel basis equivariant to the octahedral rotation group, i.e. the subgroup ``'octa'`` of :math:`SO(3)`.

    Thin wrapper around :func:`kernels_SO3_subgroup_act_R3`; all remaining arguments are forwarded unchanged.
    """
    return kernels_SO3_subgroup_act_R3(
        in_repr,
        out_repr,
        sg_id='octa',
        radii=radii,
        sigma=sigma,
        maximum_frequency=maximum_frequency,
        adjoint=adjoint,
        filter=filter,
    )
def kernels_Tetra_act_R3(in_repr: Representation,
                         out_repr: Representation,
                         radii: List[float],
                         sigma: Union[List[float], float],
                         maximum_frequency: int = 5,
                         adjoint: np.ndarray = None,
                         filter: Callable[[Dict], bool] = None
                         ) -> KernelBasis:
    r"""
    Kernel basis equivariant to the tetrahedral rotation group, i.e. the subgroup ``'tetra'`` of :math:`SO(3)`.

    Thin wrapper around :func:`kernels_SO3_subgroup_act_R3`; all remaining arguments are forwarded unchanged.
    """
    return kernels_SO3_subgroup_act_R3(
        in_repr,
        out_repr,
        sg_id='tetra',
        radii=radii,
        sigma=sigma,
        maximum_frequency=maximum_frequency,
        adjoint=adjoint,
        filter=filter,
    )
###### Icosahedral Symmetry with Aliased samples


def _kernels_aliased_Ico_act_R3(in_repr: Representation,
                                out_repr: Representation,
                                radii: List[float],
                                sigma: Union[List[float], float],
                                sgid: Tuple,
                                adjoint: np.ndarray = None,
                                ) -> KernelBasis:
    r"""
    Builds an icosahedral-equivariant kernel basis from sparse orbits of points sampled on concentric spheres.

    For every radius in ``radii`` one ring basis is built:

    - if the radius is (numerically) zero, a :class:`~escnn.kernels.SparseOrbitBasis` on the single orbit
      of the origin;
    - otherwise, a :class:`~escnn.kernels.SparseOrbitBasisWithIcosahedralSymmetry` over the homogeneous space
      ``ico_group().homspace(sgid)``, with change of basis ``np.eye(3) * radius``.

    Each ring basis is made steerable with :class:`~escnn.kernels.WignerEckartBasis`; rings for which an
    :class:`~escnn.kernels.EmptyBasisException` is raised are skipped. The result is the
    :class:`~escnn.kernels.UnionBasis` of all remaining rings.

    Args:
        in_repr (Representation): representation of the input field; its group must be ``ico_group()``
        out_repr (Representation): representation of the output field; its group must be ``ico_group()``
        radii (list): radii of the spheres on which the orbits are sampled
        sigma (list or float): width of each ring; a scalar is used for every radius
        sgid (tuple): id of the subgroup of the icosahedral group defining the orbit (homogeneous space)
        adjoint (~numpy.ndarray, optional): 3x3 orthogonal matrix defining a change of basis on the base space

    """
    group = ico_group()

    assert in_repr.group == group
    assert out_repr.group == group

    # accept any real scalar (not only Python floats, e.g. ``sigma=1`` or numpy scalars)
    if np.isscalar(sigma):
        sigma = [sigma] * len(radii)

    assert len(sigma) == len(radii)

    basis_list = []
    for r, s in zip(radii, sigma):

        attributes = {'radius': r}

        try:
            if np.isclose(r, 0., rtol=1e-7, atol=1e-7):
                # degenerate ring: the orbit of the origin is a single point
                ring_basis = SparseOrbitBasis(
                    X=group.homspace(group.subgroup_self_id),
                    action=group.standard_representation,
                    root=np.zeros(3),
                    sigma=s,
                    attributes=attributes
                )
            else:
                # uniform scaling by the ring radius
                change_of_basis = np.eye(3) * r
                ring_basis = SparseOrbitBasisWithIcosahedralSymmetry(
                    X=group.homspace(sgid),
                    sigma=s,
                    attributes=attributes,
                    change_of_basis=change_of_basis
                )

            basis_list.append(
                SteerableKernelBasis(ring_basis, in_repr, out_repr, WignerEckartBasis)
            )
        except EmptyBasisException:
            # deliberately skip rings which do not support any equivariant kernel
            pass

    basis = UnionBasis(basis_list)

    if adjoint is not None:
        adjoint = np.asarray(adjoint)
        # validate the shape first, then compare with the 3x3 identity (comparing with a 2x2 identity,
        # as done previously, always failed with a broadcasting error for a valid 3x3 adjoint)
        assert adjoint.shape == (3, 3)
        if not np.allclose(adjoint, np.eye(3)):
            basis = AdjointBasis(basis, adjoint)

    return basis


def kernels_aliased_Ico_act_R3_dodecahedron(in_repr: Representation,
                                            out_repr: Representation,
                                            radii: List[float],
                                            sigma: Union[List[float], float],
                                            adjoint: np.ndarray = None,
                                            ) -> KernelBasis:
    r"""
    Aliased icosahedral kernel basis whose orbits are built from the subgroup id ``(False, 5)``.

    See :func:`_kernels_aliased_Ico_act_R3` for the meaning of the arguments.
    """
    sgid = (False, 5)
    return _kernels_aliased_Ico_act_R3(
        in_repr, out_repr, radii, sigma, sgid, adjoint
    )


def kernels_aliased_Ico_act_R3_icosahedron(in_repr: Representation,
                                           out_repr: Representation,
                                           radii: List[float],
                                           sigma: Union[List[float], float],
                                           adjoint: np.ndarray = None,
                                           ) -> KernelBasis:
    r"""
    Aliased icosahedral kernel basis whose orbits are built from the subgroup id ``(False, 3)``.

    See :func:`_kernels_aliased_Ico_act_R3` for the meaning of the arguments.
    """
    sgid = (False, 3)
    return _kernels_aliased_Ico_act_R3(
        in_repr, out_repr, radii, sigma, sgid, adjoint
    )


def kernels_aliased_Ico_act_R3_icosidodecahedron(in_repr: Representation,
                                                 out_repr: Representation,
                                                 radii: List[float],
                                                 sigma: Union[List[float], float],
                                                 adjoint: np.ndarray = None,
                                                 ) -> KernelBasis:
    r"""
    Aliased icosahedral kernel basis whose orbits are built from the subgroup id ``(False, 2)``.

    See :func:`_kernels_aliased_Ico_act_R3` for the meaning of the arguments.
    """
    sgid = (False, 2)
    return _kernels_aliased_Ico_act_R3(
        in_repr, out_repr, radii, sigma, sgid, adjoint
    )


###### Full Platonic Symmetries
def kernels_FullIco_act_R3(in_repr: Representation,
                           out_repr: Representation,
                           radii: List[float],
                           sigma: Union[List[float], float],
                           maximum_frequency: int = 5,
                           adjoint: np.ndarray = None,
                           filter: Callable[[Dict], bool] = None
                           ) -> KernelBasis:
    r"""
    Kernel basis equivariant to the full icosahedral group :math:`I_h = I \times C_2`.

    .. warning::
        Not implemented yet: this function always raises :class:`NotImplementedError`.

    """
    # TODO: implement via kernels_O3_subgroup_act_R3 with sg_id=(True, 'ico')  (I_h = I x C_2)
    raise NotImplementedError(
        "kernels_FullIco_act_R3 (full icosahedral group I_h) is not implemented yet"
    )
def kernels_FullOcta_act_R3(in_repr: Representation,
                            out_repr: Representation,
                            radii: List[float],
                            sigma: Union[List[float], float],
                            maximum_frequency: int = 5,
                            adjoint: np.ndarray = None,
                            filter: Callable[[Dict], bool] = None
                            ) -> KernelBasis:
    r"""
    Kernel basis equivariant to the full octahedral group :math:`O_h = O \times C_2`.

    .. warning::
        Not implemented yet: this function always raises :class:`NotImplementedError`.

    """
    # TODO: implement via kernels_O3_subgroup_act_R3 with sg_id=(True, 'octa')  (O_h = O x C_2)
    raise NotImplementedError(
        "kernels_FullOcta_act_R3 (full octahedral group O_h) is not implemented yet"
    )
def kernels_FullTetra_act_R3(in_repr: Representation,
                             out_repr: Representation,
                             radii: List[float],
                             sigma: Union[List[float], float],
                             maximum_frequency: int = 5,
                             adjoint: np.ndarray = None,
                             filter: Callable[[Dict], bool] = None
                             ) -> KernelBasis:
    r"""
    Kernel basis equivariant to the full tetrahedral group :math:`T_d`
    (n.b. this is different from :math:`T \times C_2`).

    .. warning::
        Not implemented yet: this function always raises :class:`NotImplementedError`.

    """
    # TODO: implement via kernels_O3_subgroup_act_R3 with sg_id='fulltetra'  (group T_d)
    raise NotImplementedError(
        "kernels_FullTetra_act_R3 (full tetrahedral group T_d) is not implemented yet"
    )
def kernels_Pyrito_act_R3(in_repr: Representation,
                          out_repr: Representation,
                          radii: List[float],
                          sigma: Union[List[float], float],
                          maximum_frequency: int = 5,
                          adjoint: np.ndarray = None,
                          filter: Callable[[Dict], bool] = None
                          ) -> KernelBasis:
    r"""
    Kernel basis equivariant to the pyritohedral group :math:`T_h = T \times C_2`
    (n.b. not a symmetry of the tetrahedron).

    .. warning::
        Not implemented yet: this function always raises :class:`NotImplementedError`.

    """
    # TODO: implement via kernels_O3_subgroup_act_R3 with sg_id=(True, 'tetra')  (T_h = T x C_2)
    raise NotImplementedError(
        "kernels_Pyrito_act_R3 (pyritohedral group T_h) is not implemented yet"
    )
###### Planar symmetries
def kernels_SO2_act_R3(in_repr: Representation,
                       out_repr: Representation,
                       radii: List[float],
                       sigma: Union[List[float], float],
                       maximum_frequency: int = 5,
                       adjoint: np.ndarray = None,
                       filter: Callable[[Dict], bool] = None
                       ) -> KernelBasis:
    r"""
    Kernel basis equivariant to :math:`SO(2)`, realized as the subgroup ``(False, -1)`` of :math:`SO(3)`.

    Thin wrapper around :func:`kernels_SO3_subgroup_act_R3`; all remaining arguments are forwarded unchanged.
    """
    return kernels_SO3_subgroup_act_R3(
        in_repr,
        out_repr,
        sg_id=(False, -1),
        radii=radii,
        sigma=sigma,
        maximum_frequency=maximum_frequency,
        adjoint=adjoint,
        filter=filter,
    )
def kernels_CN_act_R3(in_repr: Representation,
                      out_repr: Representation,
                      radii: List[float],
                      sigma: Union[List[float], float],
                      maximum_frequency: int = 5,
                      max_offset: int = None,
                      adjoint: np.ndarray = None,
                      filter: Callable[[Dict], bool] = None
                      ) -> KernelBasis:
    r"""
    Kernel basis equivariant to a cyclic group :math:`C_N`, realized as the subgroup ``(False, N)`` of
    :math:`SO(3)` with ``N = in_repr.group.order()``.

    ``in_repr`` and ``out_repr`` must share the same :class:`~escnn.group.CyclicGroup`.
    ``max_offset`` is accepted but currently unused.
    """
    assert in_repr.group == out_repr.group
    cyclic = in_repr.group
    assert isinstance(cyclic, CyclicGroup)

    # TODO implement max_offset as filter to apply on top later
    return kernels_SO3_subgroup_act_R3(
        in_repr,
        out_repr,
        sg_id=(False, cyclic.order()),
        radii=radii,
        sigma=sigma,
        maximum_frequency=maximum_frequency,
        adjoint=adjoint,
        filter=filter,
    )
def kernels_O2_conical_act_R3(in_repr: Representation,
                              out_repr: Representation,
                              radii: List[float],
                              sigma: Union[List[float], float],
                              maximum_frequency: int = 5,
                              adjoint: np.ndarray = None,
                              filter: Callable[[Dict], bool] = None
                              ) -> KernelBasis:
    r"""
    Kernel basis equivariant to :math:`O(2)` in its "conical" embedding, i.e. the subgroup ``('cone', -1)``
    of :math:`O(3)`.

    Thin wrapper around :func:`kernels_O3_subgroup_act_R3`; all remaining arguments are forwarded unchanged.
    """
    return kernels_O3_subgroup_act_R3(
        in_repr,
        out_repr,
        sg_id=('cone', -1),
        radii=radii,
        sigma=sigma,
        maximum_frequency=maximum_frequency,
        adjoint=adjoint,
        filter=filter,
    )
def kernels_DN_conical_act_R3(in_repr: Representation,
                              out_repr: Representation,
                              radii: List[float],
                              sigma: Union[List[float], float],
                              maximum_frequency: int = 5,
                              max_offset: int = None,
                              adjoint: np.ndarray = None,
                              filter: Callable[[Dict], bool] = None
                              ) -> KernelBasis:
    r"""
    Kernel basis equivariant to a dihedral group :math:`D_N` in its "conical" embedding, i.e. the subgroup
    ``('cone', N)`` of :math:`O(3)` with ``N = in_repr.group.rotation_order``.

    ``in_repr`` and ``out_repr`` must share the same :class:`~escnn.group.DihedralGroup`.
    ``max_offset`` is accepted but currently unused.
    """
    assert in_repr.group == out_repr.group
    dihedral = in_repr.group
    assert isinstance(dihedral, DihedralGroup)

    # TODO implement max_offset as filter to apply on top later
    return kernels_O3_subgroup_act_R3(
        in_repr,
        out_repr,
        sg_id=('cone', dihedral.rotation_order),
        radii=radii,
        sigma=sigma,
        maximum_frequency=maximum_frequency,
        adjoint=adjoint,
        filter=filter,
    )
def kernels_O2_dihedral_act_R3(in_repr: Representation,
                               out_repr: Representation,
                               radii: List[float],
                               sigma: Union[List[float], float],
                               maximum_frequency: int = 5,
                               adjoint: np.ndarray = None,
                               filter: Callable[[Dict], bool] = None
                               ) -> KernelBasis:
    r"""
    Kernel basis equivariant to :math:`O(2)` in its "dihedral" embedding, i.e. the subgroup ``(True, -1)``
    of :math:`SO(3)`.

    Thin wrapper around :func:`kernels_SO3_subgroup_act_R3`; all remaining arguments are forwarded unchanged.
    """
    return kernels_SO3_subgroup_act_R3(
        in_repr,
        out_repr,
        sg_id=(True, -1),
        radii=radii,
        sigma=sigma,
        maximum_frequency=maximum_frequency,
        adjoint=adjoint,
        filter=filter,
    )
def kernels_DN_dihedral_act_R3(in_repr: Representation,
                               out_repr: Representation,
                               radii: List[float],
                               sigma: Union[List[float], float],
                               maximum_frequency: int = 5,
                               max_offset: int = None,
                               adjoint: np.ndarray = None,
                               filter: Callable[[Dict], bool] = None
                               ) -> KernelBasis:
    r"""
    Kernel basis equivariant to a dihedral group :math:`D_N` in its "dihedral" embedding, i.e. the subgroup
    ``(True, N)`` of :math:`SO(3)` with ``N = in_repr.group.rotation_order``.

    ``in_repr`` and ``out_repr`` must share the same :class:`~escnn.group.DihedralGroup`.
    ``max_offset`` is accepted but currently unused.
    """
    assert in_repr.group == out_repr.group

    # TODO implement max_offset as filter to apply on top later
    dihedral = in_repr.group
    assert isinstance(dihedral, DihedralGroup)

    return kernels_SO3_subgroup_act_R3(
        in_repr,
        out_repr,
        sg_id=(True, dihedral.rotation_order),
        radii=radii,
        sigma=sigma,
        maximum_frequency=maximum_frequency,
        adjoint=adjoint,
        filter=filter,
    )
def kernels_full_cylinder_act_R3(in_repr: Representation,
                                 out_repr: Representation,
                                 radii: List[float],
                                 sigma: Union[List[float], float],
                                 maximum_frequency: int = 5,
                                 adjoint: np.ndarray = None,
                                 filter: Callable[[Dict], bool] = None
                                 ) -> KernelBasis:
    r"""
    Kernel basis equivariant to the continuous "full cylinder" subgroup ``(True, True, -1)`` of :math:`O(3)`.

    Thin wrapper around :func:`kernels_O3_subgroup_act_R3`; all remaining arguments are forwarded unchanged.
    """
    return kernels_O3_subgroup_act_R3(
        in_repr,
        out_repr,
        sg_id=(True, True, -1),
        radii=radii,
        sigma=sigma,
        maximum_frequency=maximum_frequency,
        adjoint=adjoint,
        filter=filter,
    )


def kernels_full_cylinder_discrete_act_R3(in_repr: Representation,
                                          out_repr: Representation,
                                          radii: List[float],
                                          sigma: Union[List[float], float],
                                          maximum_frequency: int = 5,
                                          max_offset: int = None,
                                          adjoint: np.ndarray = None,
                                          filter: Callable[[Dict], bool] = None
                                          ) -> KernelBasis:
    r"""
    Kernel basis equivariant to the discrete "full cylinder" subgroup ``(True, True, N)`` of :math:`O(3)`,
    where ``N = in_repr.group.order() // 4``.

    ``max_offset`` is accepted but currently unused.
    """
    # TODO implement max_offset as filter to apply on top later
    n = in_repr.group.order() // 4
    return kernels_O3_subgroup_act_R3(
        in_repr,
        out_repr,
        sg_id=(True, True, n),
        radii=radii,
        sigma=sigma,
        maximum_frequency=maximum_frequency,
        adjoint=adjoint,
        filter=filter,
    )


def kernels_cylinder_act_R3(in_repr: Representation,
                            out_repr: Representation,
                            radii: List[float],
                            sigma: Union[List[float], float],
                            maximum_frequency: int = 5,
                            adjoint: np.ndarray = None,
                            filter: Callable[[Dict], bool] = None
                            ) -> KernelBasis:
    r"""
    Kernel basis equivariant to the continuous "cylinder" subgroup ``(True, False, -1)`` of :math:`O(3)`.

    Thin wrapper around :func:`kernels_O3_subgroup_act_R3`; all remaining arguments are forwarded unchanged.
    """
    return kernels_O3_subgroup_act_R3(
        in_repr,
        out_repr,
        sg_id=(True, False, -1),
        radii=radii,
        sigma=sigma,
        maximum_frequency=maximum_frequency,
        adjoint=adjoint,
        filter=filter,
    )


def kernels_cylinder_discrete_act_R3(in_repr: Representation,
                                     out_repr: Representation,
                                     radii: List[float],
                                     sigma: Union[List[float], float],
                                     maximum_frequency: int = 5,
                                     max_offset: int = None,
                                     adjoint: np.ndarray = None,
                                     filter: Callable[[Dict], bool] = None
                                     ) -> KernelBasis:
    r"""
    Kernel basis equivariant to the discrete "cylinder" subgroup ``(True, False, N)`` of :math:`O(3)`,
    where ``N = in_repr.group.order() // 2``.

    ``max_offset`` is accepted but currently unused.
    """
    # TODO implement max_offset as filter to apply on top later
    n = in_repr.group.order() // 2
    return kernels_O3_subgroup_act_R3(
        in_repr,
        out_repr,
        sg_id=(True, False, n),
        radii=radii,
        sigma=sigma,
        maximum_frequency=maximum_frequency,
        adjoint=adjoint,
        filter=filter,
    )
def kernels_Inv_act_R3(in_repr: Representation,
                       out_repr: Representation,
                       radii: List[float],
                       sigma: Union[List[float], float],
                       maximum_frequency: int = 5,
                       filter: Callable[[Dict], bool] = None
                       ) -> KernelBasis:
    r"""
    Kernel basis equivariant to the subgroup ``(True, False, 1)`` of :math:`O(3)`.

    Thin wrapper around :func:`kernels_O3_subgroup_act_R3` (no ``adjoint`` is forwarded).
    """
    return kernels_O3_subgroup_act_R3(
        in_repr,
        out_repr,
        sg_id=(True, False, 1),
        radii=radii,
        sigma=sigma,
        maximum_frequency=maximum_frequency,
        filter=filter,
    )
def kernels_Trivial_act_R3(in_repr: Representation,
                           out_repr: Representation,
                           radii: List[float],
                           sigma: Union[List[float], float],
                           maximum_frequency: int = 5,
                           filter: Callable[[Dict], bool] = None
                           ) -> KernelBasis:
    r"""
    Kernel basis for the trivial subgroup of :math:`SO(3)` (id ``so3_group().subgroup_trivial_id``).

    Thin wrapper around :func:`kernels_SO3_subgroup_act_R3` (no ``adjoint`` is forwarded).
    """
    trivial_id = so3_group().subgroup_trivial_id
    return kernels_SO3_subgroup_act_R3(
        in_repr,
        out_repr,
        sg_id=trivial_id,
        radii=radii,
        sigma=sigma,
        maximum_frequency=maximum_frequency,
        filter=filter,
    )