# Source code for escnn.kernels.basis


import numpy as np
from abc import ABC, abstractmethod
from typing import List, Union, Tuple

import torch



class EmptyBasisException(Exception):
    r"""
    Exception raised when a :class:`~escnn.kernels.KernelBasis` with no elements is built.

    The error message is fixed; callers are expected to catch this exception when a basis
    may turn out to be empty.
    """

    def __init__(self):
        super().__init__(
            "The KernelBasis you tried to instantiate is empty (dim = 0). "
            "You should catch this exception."
        )
class KernelBasis(torch.nn.Module, ABC):

    def __init__(self, dim: int, shape: Tuple[int, int]):
        r"""
        Abstract class for implementing the basis of a kernel space.

        A kernel space is the space of functions in the form:

        .. math::
            \mathcal{K} := \{ \kappa: X \to \mathbb{R}^{c_\text{out} \times c_\text{in}} \}

        where :math:`X` is the base space on which the kernel is defined.
        For instance, for planar images :math:`X = \R^2`.

        One can also access the dimensionality ``dim`` of this basis via the ``len()`` method.

        Args:
            dim (int): the dimensionality of the basis :math:`|\mathcal{K}|` (number of elements)
            shape (tuple): a tuple containing :math:`c_\text{out}` and :math:`c_\text{in}`

        Raises:
            EmptyBasisException: if ``dim == 0``

        Attributes:
            ~.dim (int): the dimensionality of the basis :math:`|\mathcal{K}|` (number of elements)
            ~.shape (tuple): a tuple containing :math:`c_\text{out}` and :math:`c_\text{in}`

        """
        assert isinstance(dim, int), (dim, type(dim))
        assert isinstance(shape, tuple) and len(shape) == 2, shape
        assert dim >= 0

        if dim == 0:
            raise EmptyBasisException()

        # PyTorch requires ``Module.__init__`` to run before any attribute is assigned on
        # the module, so the parent is initialized first.
        super(KernelBasis, self).__init__()

        self.dim = dim
        self.shape = shape

    @abstractmethod
    def sample(self, points: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
        r"""
        Sample the continuous basis elements on discrete points in ``points``.
        Optionally, store the resulting multidimensional array in ``out``.

        ``points`` must be an array of shape `(N, D)`, where `D` is the dimensionality of the (parametrization of the)
        base space while `N` is the number of points.

        Args:
            points (~torch.Tensor): points where to evaluate the basis elements
            out (~torch.Tensor, optional): pre-existing array to use to store the output

        Returns:
            the sampled basis

        """
        pass

    def forward(self, points: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
        r"""
        Alias for :meth:`~escnn.kernels.KernelBasis.sample`.
        """
        return self.sample(points, out=out)

    def __len__(self):
        # number of basis elements
        return self.dim

    def __iter__(self):
        # yields the attribute dict of each element, in index order 0 .. dim-1
        for i in range(self.dim):
            yield self[i]

    @abstractmethod
    def __getitem__(self, idx: int) -> dict:
        # must return the attributes (a dict) of the ``idx``-th basis element
        pass

    @abstractmethod
    def __hash__(self):
        pass

    @abstractmethod
    def __eq__(self, other):
        pass
class AdjointBasis(KernelBasis):

    def __init__(self, basis: KernelBasis, adjoint: np.ndarray):
        r"""
        Transform the input ``basis`` by applying a change of basis ``adjoint`` on the points before sampling the basis.

        Args:
            basis (KernelBasis): a kernel basis
            adjoint (~numpy.ndarray): an orthonormal matrix defining the change of basis on the base space

        Raises:
            AssertionError: if ``adjoint`` is not a square orthonormal matrix

        """
        n = adjoint.shape[0]
        assert adjoint.shape == (n, n)
        assert np.allclose(adjoint @ adjoint.T, np.eye(n)), 'Error! The adjunction matrix must be orthonormal'
        assert np.allclose(adjoint.T @ adjoint, np.eye(n)), 'Error! The adjunction matrix must be orthonormal'

        super(AdjointBasis, self).__init__(basis.dim, basis.shape)

        # registered as a submodule, so it follows ``.to()`` / ``state_dict()`` of this module
        self.basis = basis
        # stored as a float32 buffer; cast to the dtype/device of the points in ``sample``
        self.register_buffer('adj', torch.tensor(adjoint, dtype=torch.float32))

    def sample(self, points: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
        r"""
        Sample the continuous basis elements on the discrete set of points.
        Optionally, store the resulting multidimensional array in ``out``.

        ``points`` must be an array of shape `(N, d)`, where `N` is the number of points and `d` their dimensionality.

        Args:
            points (~torch.Tensor): points where to evaluate the basis elements
            out (~torch.Tensor, optional): pre-existing array to use to store the output

        Returns:
            the sampled basis

        """
        assert len(points.shape) == 2
        assert points.shape[1] == self.adj.shape[0]

        # Every row ``p`` of ``points`` is mapped to ``adj @ p``, i.e. ``points @ adj^T``.
        adj = self.adj.to(device=points.device, dtype=points.dtype)
        transformed_points = points @ adj.T
        return self.basis.sample(transformed_points, out)

    def __getitem__(self, r):
        # the change of basis does not alter the attributes of the elements
        return self.basis[r]

    def __eq__(self, other):
        if not isinstance(other, AdjointBasis):
            return False
        # ``torch.allclose`` raises (rather than returning False) on mismatching shapes,
        # devices or dtypes, so guard the shape and align device/dtype first.
        if self.adj.shape != other.adj.shape:
            return False
        other_adj = other.adj.to(device=self.adj.device, dtype=self.adj.dtype)
        return self.basis == other.basis and torch.allclose(self.adj, other_adj)

    def __hash__(self):
        # ``hash(tensor)`` is identity-based, which would give equal bases different hashes;
        # hash the matrix contents instead (normalized to float32 CPU bytes).
        # NOTE: ``__eq__`` uses ``allclose``, so matrices that differ by tiny amounts compare
        # equal but may still hash differently.
        adj_bytes = self.adj.detach().to(device='cpu', dtype=torch.float32).numpy().tobytes()
        return hash(adj_bytes) + 1000 * hash(self.basis)
class UnionBasis(KernelBasis):

    def __init__(self, bases_list: List[KernelBasis]):
        r"""
        Construct the union of a list of bases.

        All bases must have the same ``shape``; the resulting basis has ``dim`` equal to the sum of the
        dimensionalities of the individual bases. Elements are indexed basis by basis, following the
        order of ``bases_list``.

        Args:
            bases_list (list): the bases to join

        Raises:
            EmptyBasisException: if ``bases_list`` is empty or the total dimensionality is zero

        """
        if len(bases_list) == 0:
            raise EmptyBasisException

        shape = bases_list[0].shape
        dim = 0
        for basis in bases_list:
            assert basis.shape == shape, (basis.shape, shape)
            dim += basis.dim

        if dim == 0:
            raise EmptyBasisException

        super(UnionBasis, self).__init__(dim, shape)

        self._bases = torch.nn.ModuleList(bases_list)

    def sample(self, points: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
        r"""
        Sample every basis on ``points`` and stack their elements along the second dimension.

        Args:
            points (~torch.Tensor): 2-dimensional array of points, one point per row
            out (~torch.Tensor, optional): pre-existing array to use to store the output; it is
                    sliced along its second dimension, one slice per basis

        Returns:
            the sampled basis; when allocated here, its shape is `(N, dim, c_out, c_in)`

        """
        assert len(points.shape) == 2
        S = points.shape[0]

        if out is None:
            out = torch.empty(S, self.dim, self.shape[0], self.shape[1], device=points.device, dtype=points.dtype)

        p = 0
        for basis in self._bases:
            # the return value is ignored: each basis is relied upon to write into its slice of ``out``
            basis.sample(points, out=out[:, p:p + basis.dim, ...])
            p += basis.dim

        return out

    def __getitem__(self, idx: int) -> dict:
        # negative indices are rejected: they would otherwise silently select a wrong element
        assert 0 <= idx < self.dim, (idx, self.dim)

        # locate the basis containing the idx-th element; ``p`` is that basis' offset
        p = 0
        for i, basis in enumerate(self._bases):
            if idx < p + basis.dim:
                break
            p += basis.dim

        # copy so that the sub-basis' own attribute dict is never modified in place
        attr = basis[idx - p].copy()
        attr["shape"] = self.shape
        attr['basis_id'] = i
        attr['basis_idx'] = attr['idx']
        attr['idx'] = idx
        return attr

    def __hash__(self):
        # position-weighted, so the order of the bases affects the hash
        return sum(hash(basis) * (i + 1) ** 2 for i, basis in enumerate(self._bases))

    def __eq__(self, other):
        if not isinstance(other, UnionBasis):
            return False
        if self.dim != other.dim or self.shape != other.shape or len(self._bases) != len(other._bases):
            return False
        for mine, theirs in zip(self._bases, other._bases):
            if mine != theirs:
                return False
        return True

    def __iter__(self):
        idx = 0
        for i, basis in enumerate(self._bases):
            for attr in basis:
                # copy so that the sub-basis' own attribute dict is never modified in place
                attr = attr.copy()
                attr["shape"] = self.shape
                attr['basis_id'] = i
                attr['basis_idx'] = attr['idx']
                attr['idx'] = idx
                assert idx < self.dim
                yield attr
                idx += 1