# Source code for e2cnn.nn.modules.r2_conv.basisexpansion


from abc import ABC, abstractmethod

from torch.nn import Module
import torch
import numpy as np

from typing import List, Iterable, Dict, Union


# Public API of this module: only the abstract ``BasisExpansion`` interface is exported.
__all__ = ["BasisExpansion"]


class BasisExpansion(ABC, Module):

    def __init__(self):
        r"""
        Abstract class defining the interface for the different basis expansion algorithms.

        A concrete subclass must implement every abstract method of this class:
        :meth:`forward`, :meth:`get_basis_names`, :meth:`get_element_info`,
        :meth:`get_basis_info`, :meth:`dimension`, ``__hash__`` and ``__eq__``.
        Since the class derives from :class:`abc.ABC`, instantiating a subclass that
        leaves any of them unimplemented raises :class:`TypeError`.

        The default implementations of the abstract methods all raise
        :class:`NotImplementedError`, so an accidental ``super()`` call from a
        subclass fails loudly instead of silently returning ``None``.
        """
        # Initialise torch.nn.Module's internal state so that subclasses can
        # register parameters, buffers and submodules.
        super().__init__()

    @abstractmethod
    def forward(self, weights: torch.Tensor) -> torch.Tensor:
        """
        Forward step of the module which expands the basis and returns the filter built.

        Args:
            weights (torch.Tensor): the learnable weights used to linearly combine the basis elements

        Returns:
            the filter built

        Raises:
            NotImplementedError: always, in this abstract base implementation

        """
        raise NotImplementedError()

    @abstractmethod
    def get_basis_names(self) -> List[str]:
        """
        Method that returns the list of identification names of the basis elements.

        Returns:
            list of names

        Raises:
            NotImplementedError: always, in this abstract base implementation

        """
        raise NotImplementedError()

    @abstractmethod
    def get_element_info(self, name: Union[str, int]) -> Dict:
        """
        Method that returns the information associated to a basis element.

        Parameters:
            name (str or int): identifier of the basis element or its index

        Returns:
            dictionary containing the information

        Raises:
            NotImplementedError: always, in this abstract base implementation

        """
        raise NotImplementedError()

    @abstractmethod
    def get_basis_info(self) -> Iterable:
        """
        Method that returns an iterable over all basis elements' attributes.

        Returns:
            an iterable over all the basis elements' attributes

        Raises:
            NotImplementedError: always, in this abstract base implementation

        """
        raise NotImplementedError()

    @abstractmethod
    def dimension(self) -> int:
        r"""
        The dimensionality of the basis and, so, the number of weights needed to expand it.

        Returns:
            the dimensionality of the basis

        Raises:
            NotImplementedError: always, in this abstract base implementation

        """
        raise NotImplementedError()

    # ``__hash__`` and ``__eq__`` are required together: in Python, a class body
    # that defines ``__eq__`` without ``__hash__`` gets ``__hash__ = None`` and its
    # instances become unhashable, so implementations should override both.
    @abstractmethod
    def __hash__(self) -> int:
        raise NotImplementedError()

    @abstractmethod
    def __eq__(self, other: object) -> bool:
        raise NotImplementedError()