Source code for escnn.nn.modules.linear


from escnn.gspaces import GSpace0D
from escnn.nn import init
from escnn.nn import FieldType
from escnn.nn import GeometricTensor
from .equivariant_module import EquivariantModule

from escnn.nn.modules.basismanager import BasisManager
from escnn.nn.modules.basismanager import BlocksBasisExpansion

from torch.nn import Parameter
from torch.nn.functional import linear
import torch
import numpy as np

from typing import Tuple


__all__ = ["Linear"]


class Linear(EquivariantModule):

    def __init__(self,
                 in_type: FieldType,
                 out_type: FieldType,
                 bias: bool = True,
                 basisexpansion: str = 'blocks',
                 recompute: bool = False,
                 initialize: bool = True,
                 ):
        r"""
        G-equivariant linear transformation mapping between the input and output :class:`~escnn.nn.FieldType` s
        specified by the parameters ``in_type`` and ``out_type``.
        This operation is equivariant under the action of the :attr:`escnn.nn.FieldType.fibergroup` :math:`G` of
        ``in_type`` and ``out_type``.

        Specifically, let :math:`\rho_\text{in}: G \to \GL{\R^{c_\text{in}}}` and
        :math:`\rho_\text{out}: G \to \GL{\R^{c_\text{out}}}` be the representations specified by the input and
        output field types.
        Then :class:`~escnn.nn.Linear` guarantees an equivariant mapping

        .. math::
            W \rho_\text{in}(g) v = \rho_\text{out}(g) W v \qquad\qquad \forall g \in G, \ v \in \R^{c_\text{in}}

        where :math:`\rho_\text{in}` and :math:`\rho_\text{out}` are the :math:`G`-representations associated with
        ``in_type`` and ``out_type``.

        Equivariance is guaranteed by restricting the space of weight matrices to an equivariant subspace.
        During training, in each forward pass the module expands the basis of G-equivariant matrices with learned
        weights before performing the linear transformation.
        When :meth:`~torch.nn.Module.eval()` is called, the matrix is built with the current trained weights and
        stored for future reuse, such that no overhead of expanding the matrix remains.

        .. warning ::

            When :meth:`~torch.nn.Module.train()` is called, the attributes :attr:`~escnn.nn.Linear.matrix` and
            :attr:`~escnn.nn.Linear.expanded_bias` are discarded to avoid situations of mismatch with the
            learnable expansion coefficients.
            See also :meth:`escnn.nn.Linear.train`.

            This behaviour can cause problems when storing the :meth:`~torch.nn.Module.state_dict` of a model while
            in one mode and later loading it into a model in a different mode, as the attributes of the class change.
            To avoid this issue, we recommend converting the model to eval mode before storing or loading the state
            dictionary.

        .. warning ::

            To ensure compatibility with both :class:`torch.nn.Linear` and :class:`~escnn.nn.GeometricTensor`, this
            module supports only input tensors with two dimensions ``(batch_size, number_features)``.

        The learnable expansion coefficients of this module can be initialized with the methods in
        :mod:`escnn.nn.init`.
        By default, the weights are initialized in the constructor using :func:`~escnn.nn.init.generalized_he_init`.

        .. warning ::

            This initialization procedure can be extremely slow for wide layers.
            In case initializing the model is not required (e.g. before loading the state dict of a pre-trained
            model) or another initialization method is preferred
            (e.g. :func:`~escnn.nn.init.deltaorthonormal_init`), the parameter ``initialize`` can be set to
            ``False`` to avoid unnecessary overhead.

        Args:
            in_type (FieldType): the type of the input field, specifying its transformation law
            out_type (FieldType): the type of the output field, specifying its transformation law
            bias (bool, optional): whether to add a bias to the output (only to fields which contain a
                    trivial irrep) or not. Default ``True``
            basisexpansion (str, optional): the basis expansion algorithm to use. You can ignore this attribute.
            recompute (bool, optional): if ``True``, recomputes a new basis for the equivariant kernels.
                    By default (``False``), it caches the basis built or reuses a cached one, if found.
            initialize (bool, optional): initialize the weights of the model. Default: ``True``

        Attributes:

            ~.weights (torch.Tensor): the learnable parameters which are used to expand the matrix
            ~.matrix (torch.Tensor): the matrix obtained by expanding the parameters in
                    :attr:`~escnn.nn.Linear.weights`
            ~.bias (torch.Tensor): the learnable parameters which are used to expand the bias, if ``bias=True``
            ~.expanded_bias (torch.Tensor): the equivariant bias which is summed to the output, obtained by
                    expanding the parameters in :attr:`~escnn.nn.Linear.bias`

        """

        # only GSpace0D allowed since Linear acts on the last dimension of a tensor
        assert isinstance(in_type.gspace, GSpace0D)
        assert in_type.gspace == out_type.gspace

        super(Linear, self).__init__()

        self.in_type = in_type
        self.out_type = out_type
        self.space = in_type.gspace

        if bias:
            # bias can be applied only to trivial irreps inside the representation
            # to apply bias to a field we learn a bias for each trivial irrep it contains
            # and, then, we transform it with the change of basis matrix to be able to apply it to the whole field
            # this is equivalent to transforming the field to its irreps through the inverse change of basis,
            # summing the bias only to the trivial irreps and then mapping it back with the change of basis

            # count the number of trivial irreps
            trivials = 0
            for r in self.out_type:
                for irr in r.irreps:
                    if self.out_type.fibergroup.irrep(*irr).is_trivial():
                        trivials += 1

            # if there is at least 1 trivial irrep
            if trivials > 0:

                # matrix containing the columns of the change of basis which map from the trivial irreps to the
                # field representations. This matrix allows us to map the bias defined only over the trivial irreps
                # to a bias for the whole field more efficiently
                bias_expansion = torch.zeros(self.out_type.size, trivials)

                p, c = 0, 0
                for r in self.out_type:
                    pi = 0
                    for irr in r.irreps:
                        irr = self.out_type.fibergroup.irrep(*irr)
                        if irr.is_trivial():
                            bias_expansion[p:p + r.size, c] = torch.tensor(r.change_of_basis[:, pi])
                            c += 1
                        pi += irr.size
                    p += r.size

                self.register_buffer("bias_expansion", bias_expansion)
                self.bias = Parameter(torch.zeros(trivials), requires_grad=True)
                self.register_buffer("expanded_bias", torch.zeros(out_type.size))
            else:
                self.bias = None
                self.expanded_bias = None
        else:
            self.bias = None
            self.expanded_bias = None

        # BasisExpansion: submodule which takes care of building the matrix
        self._basisexpansion = None

        if basisexpansion == 'blocks':
            self._basisexpansion = BlocksBasisExpansion(in_type.representations, out_type.representations,
                                                        self.space.build_fiber_intertwiner_basis,
                                                        np.zeros((1, 1)),
                                                        recompute=recompute)
        else:
            raise ValueError('Basis Expansion algorithm "%s" not recognized' % basisexpansion)

        self.weights = Parameter(torch.zeros(self.basisexpansion.dimension()), requires_grad=True)

        filter_size = (out_type.size, in_type.size)
        self.register_buffer("matrix", torch.zeros(*filter_size))

        if initialize:
            # by default, the weights are initialized with a generalized form of He's weight initialization
            init.generalized_he_init(self.weights.data, self.basisexpansion)
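
    # Construction sketch (illustrative comment, not part of the original module): with a
    # 0-dimensional gspace, e.g. `escnn.gspaces.no_base_space(escnn.group.cyclic_group(4))`,
    # a layer mapping three regular fields to two regular fields would be built as
    #
    #   in_type  = FieldType(gspace, [gspace.regular_repr] * 3)
    #   out_type = FieldType(gspace, [gspace.regular_repr] * 2)
    #   layer    = Linear(in_type, out_type, bias=True)
    #
    # Passing `initialize=False` skips the (potentially slow) generalized He initialization,
    # e.g. when a pre-trained state dict is loaded right after construction.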

    def forward(self, input: GeometricTensor):
        r"""
        Apply the linear map to the input, using the expanded matrix and bias.

        Args:
            input (GeometricTensor): input feature field transforming according to ``in_type``

        Returns:
            output feature field transforming according to ``out_type``

        """
        assert input.type == self.in_type

        # only GSpace0D allowed in practice
        assert len(input.shape) == 2

        if not self.training:
            _matrix = self.matrix
            _bias = self.expanded_bias
        else:
            # retrieve the matrix and the bias
            _matrix, _bias = self.expand_parameters()

        output = linear(input.tensor, _matrix, bias=_bias)

        return GeometricTensor(output, self.out_type, input.coords)
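
    # Forward usage sketch (illustrative): the input must be a rank-2 GeometricTensor of
    # type `in_type`; a hypothetical 10-sample batch would be processed as
    #
    #   x = GeometricTensor(torch.randn(10, layer.in_type.size), layer.in_type)
    #   y = layer(x)    # GeometricTensor of type `out_type`, shape (10, out_type.size)
    #
    # In training mode the matrix is re-expanded from `weights` at every call; in eval
    # mode the cached `matrix` and `expanded_bias` buffers are used instead.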

    @property
    def basisexpansion(self) -> BasisManager:
        r"""
        Submodule which takes care of building the matrix.
        It uses the learned ``weights`` to expand the kernel in the G-steerable basis and returns the matrix in the
        shape :math:`(c_\text{out}, c_\text{in})`, i.e. the usual form used by conventional linear modules.

        """
        return self._basisexpansion

    def expand_parameters(self) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Expand the matrix in terms of the :attr:`escnn.nn.Linear.weights` and the
        expanded bias in terms of :attr:`escnn.nn.Linear.bias`.

        Returns:
            the expanded matrix and bias

        """
        _matrix = self.basisexpansion(self.weights)
        _matrix = _matrix.reshape(_matrix.shape[0], _matrix.shape[1])

        if self.bias is None:
            _bias = None
        else:
            _bias = self.bias_expansion @ self.bias

        return _matrix, _bias
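
    # Sketch of the intertwiner property (illustrative): the expanded matrix W satisfies
    # W @ rho_in(g) == rho_out(g) @ W for every group element, which could be checked as
    #
    #   W, b = layer.expand_parameters()
    #   for g in layer.space.testing_elements:
    #       rho_in  = torch.tensor(layer.in_type.representation(g), dtype=W.dtype)
    #       rho_out = torch.tensor(layer.out_type.representation(g), dtype=W.dtype)
    #       assert torch.allclose(W @ rho_in, rho_out @ W, atol=1e-5)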

    def train(self, mode=True):
        r"""
        If ``mode=True``, the method sets the module in training mode and discards the
        :attr:`~escnn.nn.Linear.matrix` and :attr:`~escnn.nn.Linear.expanded_bias` attributes.

        If ``mode=False``, it sets the module in evaluation mode. Moreover, the method builds the matrix and the
        bias using the current values of the trainable parameters and stores them in
        :attr:`~escnn.nn.Linear.matrix` and :attr:`~escnn.nn.Linear.expanded_bias` such that they are not
        recomputed at each forward pass.

        .. warning ::

            This behaviour can cause problems when storing the :meth:`~torch.nn.Module.state_dict` of a model while
            in one mode and later loading it into a model in a different mode, as the attributes of this class
            change.
            To avoid this issue, we recommend converting the model to eval mode before storing or loading the state
            dictionary.

        Args:
            mode (bool, optional): whether to set training mode (``True``) or evaluation mode (``False``).
                                   Default: ``True``.

        """
        if not mode:
            _matrix, _bias = self.expand_parameters()

            self.register_buffer("matrix", _matrix)
            if _bias is not None:
                self.register_buffer("expanded_bias", _bias)
            else:
                self.expanded_bias = None

        else:
            # TODO thoroughly check this is not causing problems
            if hasattr(self, "matrix"):
                del self.matrix
            if hasattr(self, "expanded_bias"):
                del self.expanded_bias

        return super(Linear, self).train(mode)
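
    # Recommended save/load pattern (illustrative), given the mode-dependent buffers:
    #
    #   model.eval()                                    # ensures `matrix` / `expanded_bias` exist
    #   torch.save(model.state_dict(), 'weights.pt')    # hypothetical file name
    #   ...
    #   model.eval()
    #   model.load_state_dict(torch.load('weights.pt'))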

    def evaluate_output_shape(self, input_shape: Tuple) -> Tuple:
        assert len(input_shape) == 2
        assert input_shape[1] == self.in_type.size
        return (input_shape[0], self.out_type.size)

    def export(self) -> torch.nn.Linear:
        r"""
        Export this module to a normal PyTorch :class:`torch.nn.Linear` module and set to "eval" mode.

        """
        # set to eval mode so the matrix and the bias are updated with the current
        # values of the weights
        self.eval()
        _matrix = self.matrix
        _bias = self.expanded_bias

        has_bias = self.bias is not None

        # build the PyTorch module
        linear = torch.nn.Linear(self.in_type.size, self.out_type.size, bias=has_bias)

        # set the weights and the bias
        linear.weight.data = _matrix.data
        if has_bias:
            linear.bias.data = _bias.data

        return linear
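
    # Export sketch (illustrative): the exported layer reproduces the equivariant one
    # exactly, since it reuses the expanded matrix and bias:
    #
    #   torch_layer = layer.export()
    #   x = torch.randn(10, layer.in_type.size)
    #   y_escnn = layer(GeometricTensor(x, layer.in_type)).tensor
    #   y_torch = torch_layer(x)
    #   assert torch.allclose(y_escnn, y_torch, atol=1e-6)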

    def check_equivariance(self, atol: float = 1e-6, rtol: float = 1e-6, assertion: bool = True, verbose: bool = True):

        x = torch.randn(10, self.in_type.size)
        x = GeometricTensor(x, self.in_type)

        errors = []

        for el in self.space.testing_elements:
            out1 = self(x).transform_fibers(el).tensor.detach().numpy()
            out2 = self(x.transform_fibers(el)).tensor.detach().numpy()

            errs = np.abs(out1 - out2)

            esum = np.maximum(np.abs(out1), np.abs(out2))
            esum[esum == 0.0] = 1

            relerr = errs / esum

            if verbose:
                print(el, relerr.max(), relerr.mean(), relerr.var(), errs.max(), errs.mean(), errs.var())

            tol = rtol * esum + atol

            if np.any(errs > tol) and verbose:
                print(out1[errs > tol])
                print(out2[errs > tol])
                print(tol[errs > tol])

            if assertion:
                assert np.all(errs < tol), \
                    'The error found during equivariance check with element "{}" is too high: ' \
                    'max = {}, mean = {}, var = {}'.format(el, errs.max(), errs.mean(), errs.var())

            errors.append((el, errs.mean()))

        return errors
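

if __name__ == '__main__':
    # A minimal, self-contained usage sketch (illustrative; not part of the original module).
    # It assumes the factory helpers `escnn.group.cyclic_group` and
    # `escnn.gspaces.no_base_space`, which build the 0-dimensional base space required here.
    from escnn import gspaces
    from escnn import group

    gspace = gspaces.no_base_space(group.cyclic_group(4))

    # three regular C4 fields in input (12 channels), two in output (8 channels)
    in_type = FieldType(gspace, [gspace.regular_repr] * 3)
    out_type = FieldType(gspace, [gspace.regular_repr] * 2)

    layer = Linear(in_type, out_type, bias=True)

    # forward pass on a batch of 10 samples
    x = GeometricTensor(torch.randn(10, in_type.size), in_type)
    y = layer(x)
    assert y.shape == (10, out_type.size)

    # verify equivariance numerically over all group elements
    layer.check_equivariance()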