Source code for e2cnn.nn.modules.nonlinearities.gated2



from typing import List, Tuple, Any

import numpy as np

from collections import defaultdict

from e2cnn.gspaces import *
from e2cnn.nn import FieldType
from e2cnn.nn import GeometricTensor

from ..equivariant_module import EquivariantModule

import torch

from torch.nn import Parameter

from .gated1 import GATED_ID

__all__ = ["GatedNonLinearity2"]


class GatedNonLinearity2(EquivariantModule):
    
    def __init__(self,
                 in_type: Tuple[FieldType, FieldType],
                 **kwargs
                 ):
        r"""
        Gated non-linearities.
        This module applies a bias and a sigmoid function to the gate fields and then multiplies each gated field
        by one of the gates, i.e. each gated field :math:`x_i` is scaled by :math:`\sigma(g_i - b_i)`, where
        :math:`g_i` is the corresponding gate field and :math:`b_i` a learned bias.
        
        The input representation of the gated fields is preserved by this operation, while the gate fields are
        discarded.
        
        The gates and the gated fields are provided in two different tensors: ``in_type`` is a tuple containing two
        representations, the representation of the tensor containing only the gates (which have to be trivial
        fields) and the representation of the tensor containing only the gated fields.
        Accordingly, two tensors need to be provided as input to the ``forward`` method: the first contains the
        gates and the second the gated fields.
        Finally, notice that the number of gates and the number of gated fields have to match, i.e. the two
        representations need to have the same number of fields.
        
        .. todo ::
            This module has two input tensors and, therefore, two input field types, but
            :class:`~e2cnn.nn.EquivariantModule` only supports one input. Fix this.
        
        Args:
            in_type (Tuple): a pair containing, in order, the field type of the gates and the field type of the
                gated fields
            
        """
        assert len(in_type) == 2
        assert isinstance(in_type[0], FieldType)
        assert isinstance(in_type[1], FieldType)
        
        assert isinstance(in_type[0].gspace, GeneralOnR2)
        assert in_type[0].gspace == in_type[1].gspace
        
        # the first field type describes the gates, which have to be trivial fields
        for r in in_type[0].representations:
            assert r.is_trivial(), \
                'Error! Representation "{}" can\'t be a "gate"'.format(r.name)
        
        # the second field type describes the gated fields
        for r in in_type[1].representations:
            assert GATED_ID in r.supported_nonlinearities, \
                'Error! Representation "{}" does not support "gated" non-linearity'.format(r.name)
        
        assert len(in_type[0]) == len(in_type[1]), \
            'Error! Number of gates ({}) does not match the number of gated fields ({})' \
                .format(len(in_type[0]), len(in_type[1]))
        
        super(GatedNonLinearity2, self).__init__()
        
        self.space = in_type[0].gspace
        
        self._contiguous = {}
        _input_indices = defaultdict(lambda: [])
        
        self._nfields = defaultdict(int)
        
        self.branching = None
        
        self.in_type = in_type
        self.out_type = in_type[1]
        
        ngates = len(in_type[0])
        
        # group the gated fields by their size, check whether fields of the same size are consecutive
        # and retrieve the channel indices of each group
        last_position = 0
        last_size = None
        for r in in_type[1].representations:
            
            if r.size != last_size:
                if r.size not in self._contiguous:
                    self._contiguous[r.size] = True
                else:
                    self._contiguous[r.size] = False
            
            last_size = r.size
            
            self._nfields[r.size] += 1
            _input_indices[r.size] += list(range(last_position, last_position + r.size))
            last_position += r.size
        
        _input_indices = dict(_input_indices)
        
        for s, contiguous in self._contiguous.items():
            if contiguous:
                # if the fields are contiguous, only the first and the last indices are kept
                _input_indices[s] = torch.LongTensor([min(_input_indices[s]), max(_input_indices[s]) + 1])
            else:
                # otherwise, transform the list of indices into a tensor
                _input_indices[s] = torch.LongTensor(_input_indices[s])
            
            # register the indices tensors as buffers of this module
            self.register_buffer('indices_{}'.format(s), _input_indices[s])
        
        # the gated fields are preserved, so the output representation is the same as the
        # gated fields' input representation
        
        # build a sorted list of the field sizes, so that the groups are always iterated in the same order
        self._order = sorted(_input_indices.keys())
        
        # the bias applied to the gates before the sigmoid
        self.bias = Parameter(torch.randn(1, ngates, 1, 1, dtype=torch.float), requires_grad=True)
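
    # Illustration of the bookkeeping above (hypothetical field sizes, not taken
    # from the library): for gated representations of sizes [2, 1, 2], the two
    # size-2 fields are separated by the size-1 field, so self._contiguous is
    # {2: False, 1: True}; the buffer 'indices_2' stores the explicit channel
    # indices [0, 1, 3, 4], while 'indices_1' stores only the slice bounds [2, 3].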

    def forward(self, gates: GeometricTensor, input: GeometricTensor) -> GeometricTensor:
        r"""
        Apply the gated non-linearity to the input feature map.
        
        Args:
            gates (GeometricTensor): feature map corresponding to the gates
            input (GeometricTensor): input feature map corresponding to the gated fields
            
        Returns:
            the resulting feature map
            
        """
        assert isinstance(input, GeometricTensor)
        assert isinstance(gates, GeometricTensor)
        assert gates.type == self.in_type[0]
        assert input.type == self.in_type[1]
        
        gates = gates.tensor
        input = input.tensor
        
        # transform the gates
        gates = torch.sigmoid(gates - self.bias)
        
        b, c, h, w = input.shape
        
        # build the output tensor
        output = torch.empty(b, self.out_type.size, h, w, dtype=torch.float, device=self.bias.device)
        
        next_gate = 0
        
        # for each field size
        for size in self._order:
            
            indices = getattr(self, f"indices_{size}")
            
            # retrieve the needed gates
            g = gates[:, next_gate:next_gate + self._nfields[size], ...].view(b, -1, 1, h, w)
            
            if self._contiguous[size]:
                # if the fields were contiguous, we can use slicing
                output[:, indices[0]:indices[1], ...] = \
                    (
                        input[:, indices[0]:indices[1], ...]
                        .view(b, -1, size, h, w)
                        * g
                    ).view(b, -1, h, w)
            else:
                # otherwise we have to use indexing
                output[:, indices, :, :] = \
                    (
                        input[:, indices, ...]
                        .view(b, -1, size, h, w)
                        * g
                    ).view(b, -1, h, w)
            
            # shift the position on the gates fiber
            next_gate += self._nfields[size]
        
        # wrap the result in a GeometricTensor
        return GeometricTensor(output, self.out_type)
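
    # Shape walk-through for ``forward`` (hypothetical sizes, for illustration):
    # with three gated fields of size 4, ``input`` is (b, 12, h, w) and the
    # matching slice of ``gates`` is (b, 3, h, w); the gates are viewed as
    # (b, 3, 1, h, w) and broadcast against the input viewed as (b, 3, 4, h, w),
    # so all 4 channels of a field are scaled by that field's single gate.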

    def evaluate_output_shape(self, *input_shapes: Tuple[int, int, int, int]) -> Tuple[int, int, int, int]:
        
        assert len(input_shapes) == 2
        
        gates_shape, input_shape = input_shapes
        
        assert len(gates_shape) == 4
        assert gates_shape[1] == self.in_type[0].size
        
        assert len(input_shape) == 4
        assert input_shape[1] == self.in_type[1].size
        
        assert gates_shape[0] == input_shape[0]
        assert gates_shape[2:] == input_shape[2:]
        
        b, c, hi, wi = input_shape
        
        return b, self.out_type.size, hi, wi
    
    def check_equivariance(self, atol: float = 1e-6, rtol: float = 1e-5) -> List[Tuple[Any, float]]:
        
        c = [self.in_type[i].size for i in range(2)]
        
        x = [torch.randn(3, c[i], 10, 10) for i in range(2)]
        x = [GeometricTensor(x[i], self.in_type[i]) for i in range(2)]
        
        errors = []
        
        for el in self.space.testing_elements:
            out1 = self(*x).transform_fibers(el)
            out2 = self(*[x[i].transform_fibers(el) for i in range(2)])
            
            errs = (out1.tensor - out2.tensor).detach().numpy()
            errs = np.abs(errs).reshape(-1)
            print(el, errs.max(), errs.mean(), errs.var())
            
            assert torch.allclose(out1.tensor, out2.tensor, atol=atol, rtol=rtol), \
                'The error found during equivariance check with element "{}" is too high: max = {}, mean = {}, var = {}' \
                    .format(el, errs.max(), errs.mean(), errs.var())
            
            errors.append((el, errs.mean()))
        
        return errors
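
A minimal usage sketch follows (not part of the library source): it relies only on the e2cnn API already imported above, but the choice of group, field multiplicities, and tensor sizes is illustrative, not prescribed by the library.

if __name__ == "__main__":
    # hypothetical example: 8 discrete rotations; three regular gated fields,
    # each paired with one trivial gate field
    gspace = Rot2dOnR2(N=8)
    gated_type = FieldType(gspace, [gspace.regular_repr] * 3)
    gates_type = FieldType(gspace, [gspace.trivial_repr] * 3)
    
    nonlinearity = GatedNonLinearity2((gates_type, gated_type))
    
    gates = GeometricTensor(torch.randn(4, gates_type.size, 16, 16), gates_type)
    x = GeometricTensor(torch.randn(4, gated_type.size, 16, 16), gated_type)
    
    # the gates are consumed, while the gated fields' type is preserved
    y = nonlinearity(gates, x)
    assert y.type == gated_type
    assert y.shape == (4, gated_type.size, 16, 16)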