# Source code for escnn.nn.modules.nonlinearities.vectorfield

from escnn.gspaces import *
from escnn.group import CyclicGroup
from escnn.nn import FieldType
from escnn.nn import GeometricTensor

from ..equivariant_module import EquivariantModule

import torch

from typing import List, Tuple, Any

import numpy as np

__all__ = ["VectorFieldNonLinearity"]

class VectorFieldNonLinearity(EquivariantModule):

    def __init__(self, in_type: FieldType, **kwargs):
        r"""

        VectorField non-linearities.
        This non-linearity only supports the regular representation of cyclic group :math:`C_N`, i.e. the group of
        :math:`N` discrete rotations.
        For each input field, the output one is built by taking the rotation associated with the highest activation;
        then, a 2-dimensional vector with an angle with respect to the x-axis equal to that rotation and a length
        equal to its activation is set in the output field.

        Args:
            in_type (FieldType): the input field type
            **kwargs: accepted for signature compatibility; not used by this module

        """

        assert isinstance(in_type.gspace, GSpace)
        assert isinstance(in_type.gspace.fibergroup, CyclicGroup)
        assert in_type.gspace.fibergroup.order() > 1

        for r in in_type.representations:
            assert 'vectorfield' in r.supported_nonlinearities,\
                'Error! Representation "{}" does not support "vector-field" non-linearity'.format(r.name)

            # every field must be the regular representation of the cyclic fiber group
            assert r.name == 'regular' and r.size == in_type.gspace.fibergroup.order(), r.name

        super(VectorFieldNonLinearity, self).__init__()

        self.space = in_type.gspace
        self.in_type = in_type

        # build the output representation substituting each input field with a rotation representation with
        # frequency 1
        self.out_type = FieldType(self.space, [self.space.representations['irrep_1']] * len(in_type))

        # the number of rotations associated with the group action
        self._rotations = self.space.fibergroup.order()

    def forward(self, input: GeometricTensor) -> GeometricTensor:
        r"""

        Apply the VectorField non-linearity to the input feature map.

        Args:
            input (GeometricTensor): the input feature map

        Returns:
            the resulting feature map

        """

        assert input.type == self.in_type

        b = input.shape[0]
        spatial_shape = input.shape[2:]

        # split the channel dimension in 2 dimensions, separating fields;
        # ``reshape`` (rather than ``view``) also accepts non-contiguous tensors
        fm = input.tensor.reshape(b, -1, self._rotations, *spatial_shape)

        # evaluate the base rotation associated with the group action
        base_angle = 2 * np.pi / self._rotations

        # for each field, retrieve the maximum activation (and the argmax)
        max_activations, argmaxes = torch.max(fm, 2)

        # negative maxima are clamped to zero, i.e. they produce a null output vector
        max_activations = torch.relu_(max_activations)

        # compute the angles from the index of the maximum activation in the field,
        # staying in the dtype of the activations instead of forcing float32
        max_angles = argmaxes.to(dtype=max_activations.dtype) * base_angle

        # to build the output vectors, take the cosine and the sine of the argmax angle
        # and multiply the 2-dimensional vector by the activation value.
        # Stacking on a new axis right after the field axis and flattening the two interleaves the
        # components as (x_0, y_0, x_1, y_1, ...) along the channel dimension, and the result
        # inherits dtype and device from the input.
        output = torch.stack(
            (torch.cos(max_angles) * max_activations, torch.sin(max_angles) * max_activations),
            dim=2,
        ).reshape(b, self.out_type.size, *spatial_shape)

        # wrap the result in a GeometricTensor
        return GeometricTensor(output, self.out_type, input.coords)

    def evaluate_output_shape(self, input_shape: Tuple[int, ...]) -> Tuple[int, ...]:
        r"""

        Compute the shape of the output tensor produced for an input tensor of shape ``input_shape``.

        Args:
            input_shape (tuple): shape of the input tensor; its second entry must equal ``in_type.size``

        Returns:
            the input shape with the channel entry replaced by ``out_type.size``

        """

        assert len(input_shape) >= 2
        assert input_shape[1] == self.in_type.size

        b = input_shape[0]
        spatial_shape = input_shape[2:]

        return (b, self.out_type.size, *spatial_shape)

    def check_equivariance(self, atol: float = 1e-6, rtol: float = 1e-5) -> List[Tuple[Any, float]]:
        r"""

        Check equivariance of the module on a random input.

        For each testing element of the gspace, transforming the fibers of the output is compared with
        applying the module to the transformed input. The error statistics are printed for each element.

        Args:
            atol (float): absolute tolerance passed to ``torch.allclose``
            rtol (float): relative tolerance passed to ``torch.allclose``

        Returns:
            a list of pairs ``(element, mean absolute error)``

        Raises:
            AssertionError: if the two results differ by more than the tolerances for some element

        """

        c = self.in_type.size

        x = torch.randn(3, c, 10, 10)

        x = GeometricTensor(x, self.in_type)

        errors = []

        for el in self.space.testing_elements:

            out1 = self(x).transform_fibers(el)
            out2 = self(x.transform_fibers(el))

            errs = (out1.tensor - out2.tensor).detach().numpy()
            errs = np.abs(errs).reshape(-1)
            print(el, errs.max(), errs.mean(), errs.var())

            assert torch.allclose(out1.tensor, out2.tensor, atol=atol, rtol=rtol), \
                'The error found during equivariance check with element "{}" is too high: max = {}, mean = {} var ={}' \
                .format(el, errs.max(), errs.mean(), errs.var())

            errors.append((el, errs.mean()))

        return errors