Source code for escnn.nn.modules.conv.r3_ico_transposed_convolution


from escnn.nn import FieldType

from escnn.gspaces import GSpace3D
from escnn.group import Representation, Group, Icosahedral
from escnn.kernels import KernelBasis
from escnn.kernels import kernels_aliased_Ico_act_R3_icosahedron, kernels_aliased_Ico_act_R3_icosidodecahedron, kernels_aliased_Ico_act_R3_dodecahedron

from .r3_transposed_convolution import R3ConvTransposed

from typing import Callable, Union, List


__all__ = ["R3IcoConvTransposed"]


[docs]class R3IcoConvTransposed(R3ConvTransposed): def __init__( self, in_type: FieldType, out_type: FieldType, kernel_size: int, padding: int = 0, output_padding: int = 0, stride: int = 1, dilation: int = 1, groups: int = 1, bias: bool = True, samples: str = 'ico', sigma: Union[List[float], float] = None, rings: List[float] = None, recompute: bool = False, basis_filter: Callable[[dict], bool] = None, initialize: bool = True, ): r""" Icosahedral-steerable volumetric convolution mapping between the input and output :class:`~escnn.nn.FieldType` s specified by the parameters ``in_type`` and ``out_type``. This operation is equivariant under the action of :math:`\R^3\rtimes I` where :math:`I` (the :class:`~escnn.group.Icosahedral` group) is the :attr:`escnn.nn.FieldType.fibergroup` of ``in_type`` and ``out_type``. This class is mostly similar to :class:`~escnn.nn.R3ConvTransposed`, with the only difference that it only supports the group :class:`~escnn.group.Icosahedral` since it uses a kernel basis which is specific for this group. The argument ``frequencies_cutoff`` of :class:`~escnn.nn.R3ConvTransposed` is not supported here since the steerable kernels are not generated from a band-limited set of harmonic functions. Instead, the argument ``samples`` specifies the polyhedron (symmetric with respect to the :class:`~escnn.group.Icosahedral` group) whose vertices are used to define the kernel on :math:`\R^3`. The supported polyhedrons are ``"ico"`` (the 12 vertices of the icosahedron), ``"dodeca"`` (the 20 vertices of the dodecahedron) or ``"icosidodeca"`` (the 30 vertices of the icosidodecahedron, which correspond to the centers of the 30 edges of either the icosahedron or the dodecahedron). For each ring ``r`` in ``rings``, the polyhedron specified in embedded in the sphere of radius ``r``. The analytical kernel, which is only defined on the vertices of this polyhedron, is then "diffused" in the ambient space :math:`\R^3` by means of a small Gaussian kernel with std ``sigma``. .. warning :: Even if the input tensor has a `coords` attribute, the output of this module will not have one. Args: in_type (FieldType): the type of the input field, specifying its transformation law out_type (FieldType): the type of the output field, specifying its transformation law kernel_size (int): the size of the (square) filter padding(int, optional): implicit zero paddings on both sides of the input. Default: ``0`` output_padding(int, optional): implicit zero paddings on both sides of the input. Default: ``0`` stride(int, optional): the stride of the kernel. Default: ``1`` dilation(int, optional): the spacing between kernel elements. Default: ``1`` groups (int, optional): number of blocked connections from input channels to output channels. It allows depthwise convolution. When used, the input and output types need to be divisible in ``groups`` groups, all equal to each other. Default: ``1``. bias (bool, optional): Whether to add a bias to the output (only to fields which contain a trivial irrep) or not. Default ``True`` sigma (list or float, optional): width of each ring where the bases are sampled. If only one scalar is passed, it is used for all rings. rings (list, optional): radii of the rings where to sample the bases recompute (bool, optional): if ``True``, recomputes a new basis for the equivariant kernels. By Default (``False``), it caches the basis built or reuse a cached one, if it is found. basis_filter (callable, optional): function which takes as input a descriptor of a basis element (as a dictionary) and returns a boolean value: whether to preserve (``True``) or discard (``False``) the basis element. By default (``None``), no filtering is applied. initialize (bool, optional): initialize the weights of the model. Default: ``True`` """ assert isinstance(in_type.gspace, GSpace3D) assert isinstance(out_type.gspace, GSpace3D) assert in_type.gspace == out_type.gspace assert isinstance(in_type.gspace.fibergroup, Icosahedral) assert samples in ['ico', 'dodeca', 'icosidodeca'] self._samples = samples super(R3IcoConvTransposed, self).__init__( in_type, out_type, kernel_size, padding, output_padding, stride, dilation, groups, bias, sigma, frequencies_cutoff=5., # to avoid performing any frequency cut-off, set a large upperbound rings=rings, recompute=recompute, basis_filter=basis_filter, initialize=initialize, ) def _build_kernel_basis(self, in_repr: Representation, out_repr: Representation) -> KernelBasis: if self._samples == 'ico': return kernels_aliased_Ico_act_R3_icosahedron(in_repr, out_repr, sigma=self._sigma, radii=self._rings) if self._samples == 'dodeca': return kernels_aliased_Ico_act_R3_dodecahedron(in_repr, out_repr, sigma=self._sigma, radii=self._rings) if self._samples == 'icosidodeca': return kernels_aliased_Ico_act_R3_icosidodecahedron(in_repr, out_repr, sigma=self._sigma, radii=self._rings) else: raise ValueError