# Source code for e2cnn.nn.modules.masking_module


from e2cnn.nn import GeometricTensor
from e2cnn.nn import FieldType
from .equivariant_module import EquivariantModule

from typing import Tuple

import torch

import math

__all__ = ["MaskModule"]


def build_mask(s, margin=2, dtype=torch.float32):
    """Build a smooth circular mask of shape ``1 x 1 x s x s``.

    Pixels whose squared distance from the grid centre ``(s-1)/2`` does not
    exceed the squared radius ``((s-1)/2 * (1 - margin/100))**2`` are set to
    ``1``. Outside that disc the value decays as
    ``exp((radius**2 - dist**2) / sigma**2)`` with ``sigma = 2``.

    Args:
        s: side length of the square mask.
        margin: amount, in percent of the half-size ``(s-1)/2``, by which
            the radius of the disc is shrunk (default ``2``).
        dtype: dtype of the returned tensor (default ``torch.float32``).

    Returns:
        A tensor of shape ``(1, 1, s, s)`` holding the mask values.
    """
    center = (s - 1) / 2
    radius_sq = (center - margin / 100. * center) ** 2
    sigma = 2.

    # The squared offset from the centre is the same along both axes, so it
    # is computed once per index and reused for rows and columns.
    offsets_sq = [(i - center) ** 2 for i in range(s)]

    def _value(dist_sq):
        # Inside (or on) the disc the mask is exactly one; outside it falls
        # off smoothly with the squared distance beyond the radius.
        if dist_sq > radius_sq:
            return math.exp((radius_sq - dist_sq) / sigma ** 2)
        return 1.

    rows = [[_value(dx + dy) for dy in offsets_sq] for dx in offsets_sq]
    return torch.tensor(rows, dtype=dtype).reshape(1, 1, s, s)


class MaskModule(EquivariantModule):

    def __init__(self, in_type: FieldType, S: int, margin: float = 0.):
        r"""

        Performs an element-wise multiplication of the input with a *mask* of shape ``S x S``.

        The mask has value :math:`1` in all pixels whose distance from the center of the mask is not larger than
        ``(S-1)/2 * (1 - margin/100)`` and decays smoothly (exponentially in the squared distance) towards
        :math:`0` outside this radius.

        This operation is useful to remove from an input image or feature map all the part of the signal defined on
        the pixels which lay outside the circle inscribed in the grid.
        Because a rotation would move these pixels outside the grid, this information would anyways be
        discarded when rotating an image. However, allowing a model to use this information might break the guaranteed
        equivariance as rotated and non-rotated inputs have different information content.

        .. note::
            In order to perform the masking, the module expects an input with the same spatial dimensions as the mask.
            Then, input tensors must have shape ``B x C x S x S``.

        Args:
            in_type (FieldType): input field type
            S (int): the shape of the mask and the expected inputs
            margin (float, optional): margin around the mask in percentage with respect to the radius of the mask
                                      (default ``0.``)

        """
        super(MaskModule, self).__init__()

        self.margin = margin

        # The mask is stored as a non-trainable parameter of shape 1 x 1 x S x S,
        # so it broadcasts over the batch and channel dimensions in ``forward``.
        self.mask = torch.nn.Parameter(build_mask(S, margin=margin), requires_grad=False)

        # Masking acts pixel-wise, so the field type is left untouched.
        self.in_type = self.out_type = in_type

    def forward(self, input: GeometricTensor) -> GeometricTensor:
        r"""
        Multiply the input element-wise with the mask.

        Args:
            input (GeometricTensor): input tensor, of type ``in_type`` and with the same spatial size as the mask

        Returns:
            the masked tensor, wrapped with the module's output type

        """
        assert input.type == self.in_type, "input type does not match the module's in_type"
        assert input.tensor.shape[2:] == self.mask.shape[2:], (
            "spatial shape of the input does not match the shape of the mask"
        )

        out = input.tensor * self.mask
        return GeometricTensor(out, self.out_type)

    def evaluate_output_shape(self, input_shape: Tuple[int, ...]) -> Tuple[int, ...]:
        r"""
        Return the shape of the output given the shape of the input.

        Masking is element-wise, so the output shape is the input shape.

        """
        return input_shape