"""Source code for e2cnn.nn.modules.dropout.pointwise: channel-wise (pointwise) dropout for equivariant feature maps."""


from e2cnn.gspaces import *
from e2cnn.nn import FieldType
from e2cnn.nn import GeometricTensor

from ..equivariant_module import EquivariantModule

import torch.nn.functional as F
import torch
from typing import List, Tuple, Any

__all__ = ["PointwiseDropout"]


class PointwiseDropout(EquivariantModule):

    def __init__(self,
                 in_type: FieldType,
                 p: float = 0.5,
                 inplace: bool = False
                 ):
        r"""

        Applies dropout to individual *channels* independently.

        This class is just a wrapper for :func:`torch.nn.functional.dropout` in an
        :class:`~e2cnn.nn.EquivariantModule`.

        Only representations supporting pointwise non-linearities are accepted as input field type.

        Args:
            in_type (FieldType): the input field type
            p (float, optional): dropout probability. Default: ``0.5``
            inplace (bool, optional): can optionally do the operation in-place. Default: ``False``

        Raises:
            ValueError: if ``p`` is not in the closed interval ``[0, 1]``
            AssertionError: if ``in_type.gspace`` is not a :class:`GeneralOnR2` or if one of the
                input representations does not support ``'pointwise'`` non-linearities

        """

        assert isinstance(in_type.gspace, GeneralOnR2)

        if p < 0 or p > 1:
            raise ValueError("dropout probability has to be between 0 and 1, but got {}".format(p))

        super(PointwiseDropout, self).__init__()

        # Zeroing single channels only commutes with the group action when every field transforms
        # channel-wise, i.e. when its representation supports pointwise non-linearities.
        for r in in_type.representations:
            assert 'pointwise' in r.supported_nonlinearities, \
                'Error! Representation "{}" does not support "pointwise" non-linearity'.format(r.name)

        self.space = in_type.gspace
        self.in_type = in_type
        # Dropout does not change the field type.
        self.out_type = in_type

        self.p = p
        self.inplace = inplace

    def forward(self, input: GeometricTensor) -> GeometricTensor:
        r"""

        Args:
            input (GeometricTensor): the input feature map

        Returns:
            the resulting feature map

        """

        assert input.type == self.in_type

        # ``self.training`` toggles dropout on/off, so in eval mode this is the identity.
        output = F.dropout(input.tensor, self.p, self.training, self.inplace)

        # wrap the result in a GeometricTensor
        return GeometricTensor(output, self.out_type)

    def evaluate_output_shape(self, input_shape: Tuple[int, int, int, int]) -> Tuple[int, int, int, int]:
        r"""
        Compute the output shape for a given input shape.

        Args:
            input_shape: a 4-element shape whose second entry must equal ``self.in_type.size``

        Returns:
            ``input_shape`` unchanged, since dropout does not alter the shape

        """
        assert len(input_shape) == 4
        assert input_shape[1] == self.in_type.size

        return input_shape

    def check_equivariance(self, atol: float = 1e-6, rtol: float = 1e-5) -> List[Tuple[Any, float]]:
        r"""
        Check the equivariance of this module.

        In training mode dropout samples a fresh random mask at every call. Two forward passes
        on the same input would therefore not be comparable. The check is run with the module
        temporarily switched to "eval" mode, where dropout is the identity. The previous
        train/eval mode is restored afterwards, even if the check raises.

        Args:
            atol (float, optional): absolute tolerance forwarded to the base-class check
            rtol (float, optional): relative tolerance forwarded to the base-class check

        Returns:
            whatever :meth:`~e2cnn.nn.EquivariantModule.check_equivariance` returns
            (annotated as a list of ``(group element, error)`` pairs)

        """
        was_training = self.training
        self.eval()
        try:
            return super(PointwiseDropout, self).check_equivariance(atol=atol, rtol=rtol)
        finally:
            self.train(was_training)

    def export(self):
        r"""
        Export this module to a normal PyTorch :class:`torch.nn.Dropout` module and set to "eval" mode.

        Note that this also switches *this* module to "eval" mode.

        """
        self.eval()

        return torch.nn.Dropout(self.p, self.inplace).eval()