from e2cnn.gspaces import *
from e2cnn.nn import FieldType
from e2cnn.nn import GeometricTensor
from ..equivariant_module import EquivariantModule
import torch
from typing import List, Tuple, Any
import numpy as np
__all__ = ["VectorFieldNonLinearity"]
[docs]class VectorFieldNonLinearity(EquivariantModule):
def __init__(self, in_type: FieldType, **kwargs):
r"""
VectorField non-linearities.
This non-linearity only supports the regular representation of cyclic group :math:`C_N`, i.e. the group of
:math:`N` discrete rotations.
For each input field, the output one is built by taking the rotation associated with the highest
activation; then, a 2-dimensional vector with an angle with respect to the x-axis equal to that rotation and a
length equal to its activation is set in the output field.
Args:
in_type (FieldType): the input field type
"""
assert isinstance(in_type.gspace, GeneralOnR2)
assert in_type.gspace.fibergroup.order() > 1
for r in in_type.representations:
assert 'vectorfield' in r.supported_nonlinearities,\
'Error! Representation "{}" does not support "vector-field" non-linearity'.format(r.name)
assert r.name == 'regular' and r.size == in_type.gspace.fibergroup.order(), r.name
super(VectorFieldNonLinearity, self).__init__()
self.space = in_type.gspace
self.in_type = in_type
# build the output representation substituting each input field with a rotation representation with frequency 1
self.out_type = FieldType(self.space, [self.space.representations['irrep_1']] * len(in_type))
# the number of rotations associated with the group action
self._rotations = self.space.fibergroup.order()
[docs] def forward(self, input: GeometricTensor) -> GeometricTensor:
r"""
Apply the VectorField non-linearity to the input feature map.
Args:
input (GeometricTensor): the input feature map
Returns:
the resulting feature map
"""
assert input.type == self.in_type
b, c, h, w = input.tensor.shape
# split the channel dimension in 2 dimensions, separating fields
fm = input.tensor.view(b, -1, self._rotations, h, w)
# evaluate the base rotation associated with the group action
base_angle = 2 * np.pi / self._rotations
# for each field, retrieve the maximum activation (and the argmax)
max_activations, argmaxes = torch.max(fm, 2)
max_activations = torch.relu_(max_activations)
# compute the angles from the index of the maximum activation in the field
max_angles = argmaxes.to(dtype=torch.float) * base_angle
# build the output tensor
output = torch.empty(b, self.out_type.size, h, w, dtype=torch.float, device=input.tensor.device)
# to build the output vectors, take the cosine and the sine of the argmax angle
# and multiply the 2-dimensional vector by the activation value
output[:, ::2, ...] = torch.cos(max_angles) * max_activations
output[:, 1::2, ...] = torch.sin(max_angles) * max_activations
# wrap the result in a GeometricTensor
return GeometricTensor(output, self.out_type)
def evaluate_output_shape(self, input_shape: Tuple[int, int, int, int]) -> Tuple[int, int, int, int]:
assert len(input_shape) == 4
assert input_shape[1] == self.in_type.size
b, c, hi, wi = input_shape
return b, self.out_type.size, hi, wi
def check_equivariance(self, atol: float = 1e-6, rtol: float = 1e-5) -> List[Tuple[Any, float]]:
c = self.in_type.size
x = torch.randn(3, c, 10, 10)
x = GeometricTensor(x, self.in_type)
errors = []
for el in self.space.testing_elements:
out1 = self(x).transform_fibers(el)
out2 = self(x.transform_fibers(el))
errs = (out1.tensor - out2.tensor).detach().numpy()
errs = np.abs(errs).reshape(-1)
print(el, errs.max(), errs.mean(), errs.var())
assert torch.allclose(out1.tensor, out2.tensor, atol=atol, rtol=rtol), \
'The error found during equivariance check with element "{}" is too high: max = {}, mean = {} var ={}' \
.format(el, errs.max(), errs.mean(), errs.var())
errors.append((el, errs.mean()))
return errors