Source code for e2cnn.nn.modules.pooling.pointwise_adaptive_max



from e2cnn.gspaces import *
from e2cnn.nn import FieldType
from e2cnn.nn import GeometricTensor

from ..equivariant_module import EquivariantModule

import torch
import torch.nn.functional as F

from typing import List, Tuple, Any, Union

__all__ = ["PointwiseAdaptiveMaxPool"]


class PointwiseAdaptiveMaxPool(EquivariantModule):

    def __init__(self, in_type: FieldType, output_size: Union[int, Tuple[int, int]]):
        r"""
        Module that implements adaptive channel-wise max-pooling: each channel is treated independently.
        This module works exactly as :class:`torch.nn.AdaptiveMaxPool2d`, wrapping it in the
        :class:`~e2cnn.nn.EquivariantModule` interface.

        Notice that not all representations support this kind of pooling. In general, only representations
        which support pointwise non-linearities do.

        Args:
            in_type (FieldType): the input field type
            output_size: the target output size of the image of the form ``H x W``

        """

        assert isinstance(in_type.gspace, GeneralOnR2)

        for r in in_type.representations:
            assert 'pointwise' in r.supported_nonlinearities, \
                f"""Error! Representation "{r}" does not support pointwise non-linearities
                so it is not possible to pool each channel independently"""

        super(PointwiseAdaptiveMaxPool, self).__init__()

        self.space = in_type.gspace
        self.in_type = in_type
        self.out_type = in_type

        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            self.output_size = output_size
    def forward(self, input: GeometricTensor) -> GeometricTensor:
        r"""

        Args:
            input (GeometricTensor): the input feature map

        Returns:
            the resulting feature map

        """

        assert input.type == self.in_type

        # run the common max-pooling
        output = F.adaptive_max_pool2d(input.tensor, self.output_size)

        # wrap the result in a GeometricTensor
        return GeometricTensor(output, self.out_type)
    def evaluate_output_shape(self, input_shape: Tuple[int, int, int, int]) -> Tuple[int, int, int, int]:
        assert len(input_shape) == 4
        assert input_shape[1] == self.in_type.size

        b, c, hi, wi = input_shape

        # self.output_size is always stored as an (H, W) tuple, so unpack it into the output shape
        return b, self.out_type.size, self.output_size[0], self.output_size[1]

    def check_equivariance(self, atol: float = 1e-6, rtol: float = 1e-5) -> List[Tuple[Any, float]]:

        # this kind of pooling is not really equivariant so we can not test equivariance
        pass
    def export(self):
        r"""
        Export this module to a normal PyTorch :class:`torch.nn.AdaptiveMaxPool2d` module and set to "eval" mode.
        """

        self.eval()

        return torch.nn.AdaptiveMaxPool2d(self.output_size).eval()
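

# ---------------------------------------------------------------------------
# Usage sketch (editor's addition, not part of the original e2cnn source): a
# minimal example of how this module is typically constructed, applied and
# exported. It assumes the public `e2cnn.gspaces.Rot2dOnR2`, `e2cnn.nn.FieldType`
# and `e2cnn.nn.GeometricTensor` APIs; the trivial representation supports
# pointwise non-linearities, so it satisfies the assertion in `__init__`.
if __name__ == "__main__":
    from e2cnn import gspaces
    from e2cnn import nn as enn

    # rotations by multiples of 45 degrees acting on the plane
    gspace = gspaces.Rot2dOnR2(N=8)

    # three scalar (trivial) fields: these can be pooled channel-wise
    feat_type = enn.FieldType(gspace, 3 * [gspace.trivial_repr])

    pool = PointwiseAdaptiveMaxPool(feat_type, output_size=(1, 1))

    x = enn.GeometricTensor(torch.randn(4, feat_type.size, 32, 32), feat_type)
    y = pool(x)
    assert y.tensor.shape == (4, feat_type.size, 1, 1)

    # export to a plain PyTorch module for inference / deployment
    pool_pt = pool.export()
    assert isinstance(pool_pt, torch.nn.AdaptiveMaxPool2d)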