from e2cnn.nn import GeometricTensor
from e2cnn.nn import FieldType
from torch.nn import Module
import torch
import numpy as np
from abc import ABC, abstractmethod
from typing import List, Tuple, Any
__all__ = ["EquivariantModule"]
class EquivariantModule(Module, ABC):

    def __init__(self):
        r"""
        Abstract base class for all equivariant modules.

        An :class:`~EquivariantModule` is a subclass of :class:`torch.nn.Module`.
        It follows that any subclass of :class:`~EquivariantModule` needs to implement the
        :meth:`~e2cnn.nn.EquivariantModule.forward` method.
        With respect to a general :class:`torch.nn.Module`, an *equivariant module* implements a *typed* function as
        both its input and its output are associated with specific :class:`~e2cnn.nn.FieldType` s.
        Therefore, usually, the inputs and the outputs of an *equivariant module* are not just instances of
        :class:`torch.Tensor` but :class:`~e2cnn.nn.GeometricTensor` s.

        As a subclass of :class:`torch.nn.Module`, it supports most of the commonly used methods (e.g.
        :meth:`torch.nn.Module.to`, :meth:`torch.nn.Module.cuda`, :meth:`torch.nn.Module.train` or
        :meth:`torch.nn.Module.eval`).

        Many equivariant modules implement a :meth:`~e2cnn.nn.EquivariantModule.export` method which converts the module
        to *eval* mode and returns a pure PyTorch implementation of it.
        This can be used after training to efficiently deploy the model without, for instance, the overhead of the
        automatic type checking performed by all the modules in this library.

        .. warning ::

            Not all modules implement this feature yet.
            If the :meth:`~e2cnn.nn.EquivariantModule.export` method is called in a module which does not implement it
            yet, a :class:`NotImplementedError` is raised.
            Check the documentation of each individual module to understand if the method is implemented.

        Attributes:
            ~.in_type (FieldType): type of the :class:`~e2cnn.nn.GeometricTensor` expected as input
            ~.out_type (FieldType): type of the :class:`~e2cnn.nn.GeometricTensor` returned as output

        """
        super().__init__()

        # FieldType: type of the :class:`~e2cnn.nn.GeometricTensor` expected as input.
        # Initialised to None here; it must be assigned (e.g. by a subclass) before
        # `check_equivariance` is used, since that method reads `self.in_type.size`.
        self.in_type = None

        # FieldType: type of the :class:`~e2cnn.nn.GeometricTensor` returned as output.
        # Initialised to None here; `check_equivariance` reads `self.out_type.testing_elements`.
        self.out_type = None

    @abstractmethod
    def forward(self, *input):
        r"""
        Apply the module to its input(s). Every concrete subclass must implement this method.
        """
        pass

    @abstractmethod
    def evaluate_output_shape(self, input_shape: Tuple[int, ...]) -> Tuple[int, ...]:
        r"""
        Compute the shape of the output tensor which would be generated by this module when a tensor with shape
        ``input_shape`` is provided as input.

        Args:
            input_shape (tuple): shape of the input tensor

        Returns:
            shape of the output tensor

        """
        pass

    def check_equivariance(self, atol: float = 1e-7, rtol: float = 1e-5) -> List[Tuple[Any, float]]:
        r"""
        Method that automatically tests the equivariance of the current module.

        The default implementation of this method relies on :meth:`e2cnn.nn.GeometricTensor.transform` and uses the
        group elements in :attr:`~e2cnn.nn.FieldType.testing_elements` (of :attr:`out_type`).
        A random input of shape ``(3, in_type.size, 10, 10)`` is fed to the module and, for each testing element,
        the transformed output is compared with the output computed on the transformed input.

        This method can be overwritten for custom tests.

        Args:
            atol (float): absolute tolerance passed to :func:`numpy.allclose`
            rtol (float): relative tolerance passed to :func:`numpy.allclose`

        Returns:
            a list containing, for each testing element, a pair with that element and the corresponding
            equivariance error (mean absolute difference between the two outputs)

        Raises:
            AssertionError: if, for some testing element, the two outputs are not close within ``atol`` / ``rtol``

        """
        c = self.in_type.size

        x = torch.randn(3, c, 10, 10)
        x = GeometricTensor(x, self.in_type)

        # The output on the untransformed input does not depend on the group element,
        # so compute it once instead of once per testing element.
        y = self(x)

        errors = []

        for el in self.out_type.testing_elements:
            out1 = y.transform(el).tensor.detach().numpy()
            out2 = self(x.transform(el)).tensor.detach().numpy()

            errs = np.abs(out1 - out2).reshape(-1)
            max_err, mean_err, var_err = errs.max(), errs.mean(), errs.var()
            print(el, max_err, mean_err, var_err)

            # Raise explicitly instead of using `assert`: assert statements are removed
            # when Python runs with -O, which would silently disable this check.
            if not np.allclose(out1, out2, atol=atol, rtol=rtol):
                raise AssertionError(
                    'The error found during equivariance check with element "{}" is too high: max = {}, mean = {} var ={}'
                    .format(el, max_err, mean_err, var_err)
                )

            errors.append((el, mean_err))

        return errors

    def export(self):
        r"""
        Export recursively each submodule to a normal PyTorch module and set to "eval" mode.

        .. warning ::

            Not all modules implement this feature yet.
            If the :meth:`~e2cnn.nn.EquivariantModule.export` method is called in a module which does not implement it
            yet, a :class:`NotImplementedError` is raised.
            Check the documentation of each individual module to understand if the method is implemented.

        Raises:
            NotImplementedError: always, in this base implementation

        """
        raise NotImplementedError(
            'Conversion of equivariant module {} into PyTorch module is not supported yet'.format(self.__class__)
        )