Source code for e2cnn.nn.modules.module_list


from .equivariant_module import EquivariantModule

import torch
# Major and minor components of the installed PyTorch version, as integers.
TORCH_MAJOR, TORCH_MINOR = map(int, torch.__version__.split('.')[:2])

# Select where the abstract container classes come from: PyTorch 1.x releases up to
# and including 1.8 use ``torch._six``; every other version uses the standard library.
# NOTE(review): presumably ``torch._six.container_abcs`` is not available in later
# PyTorch releases -- confirm.
if TORCH_MAJOR == 1 and TORCH_MINOR <= 8:
    from torch._six import container_abcs
else:
    import collections.abc as container_abcs

from typing import List, Iterable


__all__ = ["ModuleList"]


class ModuleList(torch.nn.ModuleList):

    def __init__(self, modules: Iterable[EquivariantModule] = None):
        r"""

        Module similar to :class:`~torch.nn.ModuleList` containing a list of
        :class:`~e2cnn.nn.EquivariantModule` s.

        This class works like :class:`~torch.nn.ModuleList` except for the fact it only accepts
        instances of :class:`~e2cnn.nn.EquivariantModule`.
        Additionally, this class provides a `.export()` method.
        This method calls the :meth:`~e2cnn.nn.EquivariantModule.export` method of each module
        contained in this :class:`~e2cnn.nn.ModuleList` and returns a
        :class:`~torch.nn.ModuleList` containing the exported modules.

        Args:
            modules (iterable, optional): an iterable of equivariant modules to add

        """
        super().__init__(modules)

    @staticmethod
    def _check_equivariant(module) -> None:
        r"""
        Verify that ``module`` is an :class:`~e2cnn.nn.EquivariantModule`.

        Args:
            module: the object to validate

        Raises:
            AssertionError: if ``module`` is not an :class:`~e2cnn.nn.EquivariantModule`

        """
        # Raised explicitly rather than through ``assert`` so the check is not stripped
        # when Python runs with ``-O``. ``AssertionError`` is kept as the exception type
        # for compatibility with callers that already catch it.
        if not isinstance(module, EquivariantModule):
            raise AssertionError(
                "ModuleList only accepts EquivariantModule instances, but found "
                + type(module).__name__
            )

    def __setitem__(self, idx: int, module: EquivariantModule):
        r"""
        Replace the module stored at position ``idx``.

        Args:
            idx (int): index of the entry to replace
            module (EquivariantModule): equivariant module to store

        Raises:
            AssertionError: if ``module`` is not an :class:`~e2cnn.nn.EquivariantModule`

        """
        self._check_equivariant(module)
        super().__setitem__(idx, module)

    def insert(self, index: int, module: EquivariantModule) -> None:
        r"""
        Insert an :class:`~e2cnn.nn.EquivariantModule` before position ``index``.

        Args:
            index (int): index to insert at
            module (EquivariantModule): equivariant module to insert

        Raises:
            AssertionError: if ``module`` is not an :class:`~e2cnn.nn.EquivariantModule`

        """
        self._check_equivariant(module)
        super().insert(index, module)

    def append(self, module: EquivariantModule) -> 'ModuleList':
        r"""
        Appends an :class:`~e2cnn.nn.EquivariantModule` to the end of the list.

        Args:
            module (EquivariantModule): equivariant module to append

        Returns:
            the value returned by the parent class' ``append``

        Raises:
            AssertionError: if ``module`` is not an :class:`~e2cnn.nn.EquivariantModule`

        """
        self._check_equivariant(module)
        return super().append(module)

    def extend(self, modules: Iterable[EquivariantModule]) -> 'ModuleList':
        r"""
        Appends multiple :class:`~e2cnn.nn.EquivariantModule` instances from a Python iterable
        to the end of the list.

        All the elements are validated before any of them is appended, so a failed call
        leaves the list unchanged.

        Args:
            modules (iterable): iterable of equivariant modules to append

        Returns:
            this :class:`~e2cnn.nn.ModuleList`

        Raises:
            TypeError: if ``modules`` is not an iterable
            AssertionError: if any element is not an :class:`~e2cnn.nn.EquivariantModule`

        """
        if not isinstance(modules, container_abcs.Iterable):
            raise TypeError("ModuleList.extend expects an iterable object, but found " +
                            type(modules).__name__)

        # Materialize once: one-shot iterables (e.g. generators) can then be fully
        # validated before the list is modified.
        modules = list(modules)

        for module in modules:
            self._check_equivariant(module)

        for module in modules:
            self.append(module)

        return self

    def export(self) -> torch.nn.ModuleList:
        r"""
        Export this module to a normal PyTorch :class:`torch.nn.ModuleList` module and set to
        "eval" mode.

        Note that this method also switches this :class:`~e2cnn.nn.ModuleList` itself to
        "eval" mode before exporting its content.

        Returns:
            a :class:`torch.nn.ModuleList`, in "eval" mode, containing the result of calling
            ``export()`` on each contained module, in the same order

        """
        self.eval()

        # convert all the submodules
        exported = torch.nn.ModuleList([module.export() for module in self])
        exported.eval()

        return exported

    def forward(self):
        r"""
        Not supported: always raises :class:`NotImplementedError`.
        """
        raise NotImplementedError()