# Source code for escnn.gspaces.r0


from __future__ import annotations

import escnn.kernels
import escnn.group
from escnn.group import Group
from .gspace import GSpace

from typing import Tuple, Callable

import numpy as np


__all__ = [
    "GSpace0D",
    "no_base_space",
]


class GSpace0D(GSpace):
    r"""
    A :class:`~escnn.gspaces.GSpace` of the group ``G`` over a zero-dimensional base
    space, i.e. a single point: ``G`` has no base-space action and acts only on the
    features.

    Args:
        G (Group): the group of this gspace

    """

    def __init__(self, G: escnn.group.Group):
        # TODO: is GSpace0D a good name? Maybe RepresentationSpace is better?
        # Base-space dimensionality 0; the group's name is reused as the third GSpace
        # argument (presumably the gspace name — confirm against GSpace.__init__).
        super().__init__(G, 0, G.name)

    def restrict(self, id) -> Tuple[GSpace, Callable, Callable]:
        r"""
        Restrict this gspace to a subgroup of its fiber group.

        Args:
            id: identifier of the subgroup, forwarded as-is to ``self.fibergroup.subgroup``

        Returns:
            a 3-tuple: the :class:`GSpace0D` built on the subgroup, followed by the two
            callables returned by ``self.fibergroup.subgroup(id)``, passed through unchanged

        """
        sub_group, parent_map, child_map = self.fibergroup.subgroup(id)
        restricted = GSpace0D(sub_group)
        return restricted, parent_map, child_map

    @property
    def basespace_action(self) -> escnn.group.Representation:
        r"""No action on the (single point) base space is defined: always ``None``."""
        return None

    def _interpolate_transform_basespace(
            self,
            input: np.ndarray,
            element: escnn.group.GroupElement,
            order: int = 2,
    ) -> np.ndarray:
        r"""
        Transform ``input`` over the base space by ``element``.

        The base space is a single point, so nothing moves: after checking that
        ``element`` belongs to this gspace's fiber group, ``input`` itself is returned
        (same object, no copy). ``order`` is accepted but not used.
        """
        assert element.group == self.fibergroup
        return input

    def build_kernel_basis(self,
                           in_repr: escnn.group.Representation,
                           out_repr: escnn.group.Representation,
                           **kwargs) -> escnn.kernels.KernelBasis:
        r"""
        Build (or fetch from the cache) the kernel basis mapping ``in_repr`` to ``out_repr``.

        Args:
            in_repr (Representation): the input representation
            out_repr (Representation): the output representation
            **kwargs: accepted for interface compatibility; not used by this gspace and
                      not forwarded to ``_basis_generator``

        Returns:
            the analytical basis

        """
        # Both representations must belong to this gspace's fiber group.
        assert isinstance(in_repr, escnn.group.Representation)
        assert isinstance(out_repr, escnn.group.Representation)
        assert in_repr.group == self.fibergroup
        assert out_repr.group == self.fibergroup

        # kwargs play no role here, so every basis is stored under the empty key ``()``.
        # NOTE(review): assumes ``_fields_intertwiners_basis_memory`` is a mapping set up
        # by the GSpace base class that creates the ``()`` entry on demand (e.g. a
        # defaultdict(dict)) — confirm.
        cache = self._fields_intertwiners_basis_memory[()]
        pair = (in_repr.name, out_repr.name)

        if pair not in cache:
            # Bases are memoised by the pair of representation names.
            cache[pair] = self._basis_generator(in_repr, out_repr)

        return cache[pair]

    def _basis_generator(self,
                         in_repr: escnn.group.Representation,
                         out_repr: escnn.group.Representation,
                         **kwargs):
        r"""Delegate to ``escnn.kernels.kernels_on_point``; ``kwargs`` are ignored."""
        return escnn.kernels.kernels_on_point(in_repr, out_repr)

    def __eq__(self, other):
        # Two GSpace0D are equal exactly when their fiber groups are equal;
        # any other type compares unequal.
        if not isinstance(other, GSpace0D):
            return False
        return self.fibergroup == other.fibergroup

    def __hash__(self):
        # Consistent with __eq__: hash only the fiber group.
        return hash(self.fibergroup)
def no_base_space(G: Group) -> GSpace0D:
    r"""
    Build the :class:`~escnn.gspaces.GSpace` of the input group ``G`` acting on a
    single point space.

    This simple gspace can be useful to describe the features of a
    ``G``-equivariant MLP.

    Args:
        G (Group): a group

    Returns:
        a new :class:`GSpace0D` wrapping ``G``

    """
    point_gspace = GSpace0D(G)
    return point_gspace