# Source code for e2cnn.nn.modules.identity_module

from e2cnn.nn import GeometricTensor
from e2cnn.nn import FieldType
from .equivariant_module import EquivariantModule
import torch

from typing import List, Tuple, Union, Any

__all__ = ["IdentityModule"]


class IdentityModule(EquivariantModule):

    def __init__(self, in_type: FieldType):
        r"""

        Pass-through module: it hands its input back without performing any
        operation on it.

        Args:
            in_type (FieldType): field type used both as input and as output type of this module

        """
        super().__init__()

        self.space = in_type.gspace

        # nothing is transformed, so the output type coincides with the input type
        self.in_type = self.out_type = in_type

    def forward(self, input: GeometricTensor) -> GeometricTensor:
        r"""

        Give back the very same tensor that was passed in.

        Args:
            input (GeometricTensor): the GeometricTensor to forward; its type must equal ``in_type``

        Returns:
            the input tensor itself, unchanged

        """
        assert input.type == self.in_type

        return input

    def evaluate_output_shape(self, input_shape: Tuple[int, ...]) -> Tuple[int, ...]:
        r"""

        Compute the shape of the output from the shape of the input.
        Since this module does not modify its input, the two shapes are the same.

        Args:
            input_shape (tuple): shape of the input; it needs at least two entries and
                                 the entry at index ``1`` must equal ``in_type.size``

        Returns:
            ``input_shape`` as it was received

        """
        assert len(input_shape) > 1

        # the entry at index 1 has to agree with the size of the input field type
        channels = input_shape[1]
        assert channels == self.in_type.size

        return input_shape

    def check_equivariance(self, atol: float = 2e-6, rtol: float = 1e-5) -> List[Tuple[Any, float]]:
        r"""

        Run the equivariance check implemented by the parent class, forwarding
        both tolerances to it unchanged.

        Args:
            atol (float, optional): absolute tolerance handed to the parent implementation
            rtol (float, optional): relative tolerance handed to the parent implementation

        Returns:
            whatever the parent implementation returns

        """
        return super().check_equivariance(atol=atol, rtol=rtol)

    def export(self):
        r"""

        Switch this module to "eval" mode and build a plain PyTorch
        :class:`torch.nn.Identity` module which replaces it.

        .. warning ::
            Only working with PyTorch >= 1.2

        Returns:
            a newly created :class:`torch.nn.Identity` instance

        """
        self.eval()

        identity = torch.nn.Identity()
        return identity