# Source code for e2cnn.nn.modules.r2upsampling



from e2cnn.gspaces import *
from e2cnn.nn import FieldType
from e2cnn.nn import GeometricTensor

from .equivariant_module import EquivariantModule

from typing import Tuple, Union

import torch
import numpy as np

import math

from torch.nn.functional import interpolate

__all__ = ["R2Upsampling"]


class R2Upsampling(EquivariantModule):

    def __init__(self,
                 in_type: FieldType,
                 scale_factor: int = None,
                 size: Union[int, Tuple[int, int]] = None,
                 mode: str = "bilinear",
                 align_corners: bool = False
                 ):
        r"""

        Wrapper for :func:`torch.nn.functional.interpolate`. Check its documentation for further details.

        Only ``"bilinear"`` and ``"nearest"`` methods are supported.
        However, ``"nearest"`` is not equivariant; using this method may result in broken equivariance.
        For this reason, we suggest using ``"bilinear"`` (default value).

        .. warning ::
            The module supports a ``size`` parameter as an alternative to ``scale_factor``.
            However, the use of ``scale_factor`` should be *preferred*, since it guarantees both axes are scaled
            uniformly, which preserves rotation equivariance.
            A misuse of the parameter ``size`` can break the overall equivariance, since it might scale the two axes
            by two different factors.

        Exactly one of ``scale_factor`` and ``size`` must be provided.

        Args:
            in_type (FieldType): the input field type
            scale_factor (optional, int): multiplier for spatial size
            size (optional, int or tuple): output spatial size. An ``int`` is used for both axes; a 2-element
                    tuple (or list) gives ``(height, width)``.
            mode (str): algorithm used for upsampling: ``nearest`` | ``bilinear``. Default: ``bilinear``
            align_corners (bool): if ``True``, the corner pixels of the input and output tensors are aligned, and thus
                    preserving the values at those pixels. This only has effect when mode is ``bilinear``.
                    Default: ``False``

        Raises:
            ValueError: if ``mode`` is not supported, or if neither ``scale_factor`` nor ``size`` is set.

        """

        assert isinstance(in_type.gspace, GeneralOnR2)

        assert size is None or scale_factor is None, \
            f'Only one of "size" and "scale_factor" can be set, but found scale_factor={scale_factor} and size={size}'

        # Without either argument, interpolate() and evaluate_output_shape() would both fail later;
        # reject the configuration up front instead.
        if size is None and scale_factor is None:
            raise ValueError('One of "size" and "scale_factor" must be set, but both are None')

        if mode not in ["nearest", "bilinear"]:
            raise ValueError(f'Error Upsampling mode {mode} not recognized! Mode should be `nearest` or `bilinear`.')

        # Normalize ``size`` to a (height, width) tuple; lists are accepted as well.
        if isinstance(size, int):
            size = (size, size)
        elif isinstance(size, list):
            size = tuple(size)
        assert size is None or (isinstance(size, tuple) and len(size) == 2), size

        # nn.Module.__init__ must run before any attribute is assigned on self.
        super(R2Upsampling, self).__init__()

        self.space = in_type.gspace
        self.in_type = in_type
        # upsampling acts only on the spatial axes, so the field type is unchanged
        self.out_type = in_type

        self._size = size
        self._scale_factor = scale_factor
        self._mode = mode
        # interpolate() rejects a non-None align_corners for "nearest"; None is its default value
        self._align_corners = align_corners if mode != "nearest" else None

    def forward(self, input: GeometricTensor) -> GeometricTensor:
        r"""

        Upsample the spatial dimensions of the input feature map.

        Args:
            input (GeometricTensor): input feature map; its type must be ``self.in_type``

        Returns:
            the upsampled feature map, with field type ``self.out_type``

        """

        assert input.type == self.in_type

        # Exactly one of size / scale_factor is set (enforced in __init__). ``align_corners=None`` is the
        # default of interpolate(), so a single call covers both the "nearest" and the "bilinear" case.
        output = interpolate(input.tensor,
                             scale_factor=self._scale_factor,
                             size=self._size,
                             mode=self._mode,
                             align_corners=self._align_corners)

        return GeometricTensor(output, self.out_type)

    def evaluate_output_shape(self, input_shape: Tuple[int, int, int, int]) -> Tuple[int, int, int, int]:
        r"""
        Compute the shape of the output produced for an input of shape ``input_shape``.

        Args:
            input_shape: ``(batch, channels, height, width)``; ``channels`` must equal ``self.in_type.size``

        Returns:
            ``(batch, self.out_type.size, out_height, out_width)``

        """
        assert len(input_shape) == 4
        assert input_shape[1] == self.in_type.size

        b, c, hi, wi = input_shape

        if self._size is None:
            # interpolate() floors the scaled spatial size
            ho = math.floor(hi * self._scale_factor)
            wo = math.floor(wi * self._scale_factor)
        else:
            ho, wo = self._size

        return b, self.out_type.size, ho, wo

    def check_equivariance(self, atol: float = 0.1, rtol: float = 0.1):
        r"""
        Numerically check the equivariance of the module on a test image.

        The image is read from ``'../group/testimage.jpeg'`` (a path relative to the current working directory),
        resized to 55x55 and tiled along the channel axis to match ``self.in_type.size``. This requires
        ``matplotlib`` and ``scikit-image``. Only pixels inside a centred disk are compared, since the corners are
        not preserved by rotations of the grid.

        Args:
            atol (float): absolute tolerance
            rtol (float): relative tolerance

        Returns:
            a list of ``(group element, mean absolute error)`` pairs, one per testing element

        Raises:
            AssertionError: if the error for some group element exceeds the tolerance

        """

        initial_size = 55

        c = self.in_type.size

        import matplotlib.image as mpimg
        from skimage.transform import resize

        # keep (at most) the first c channels of the image, in (1, C, H, W) layout
        x = mpimg.imread('../group/testimage.jpeg').transpose((2, 0, 1))[np.newaxis, 0:c, :, :]
        x = x / 255.0
        x = resize(
            x,
            (x.shape[0], x.shape[1], initial_size, initial_size),
            anti_aliasing=True
        )

        # if the image has fewer channels than the field type, repeat them until c channels are available
        if x.shape[1] < c:
            to_stack = [x for _ in range(c // x.shape[1])]
            if c % x.shape[1] > 0:
                to_stack += [x[:, :(c % x.shape[1]), ...]]

            x = np.concatenate(to_stack, axis=1)

        x = GeometricTensor(torch.FloatTensor(x), self.in_type)

        errors = []

        for el in self.space.testing_elements:

            out1 = self(x).transform(el).tensor.detach().numpy()
            out2 = self(x.transform(el)).tensor.detach().numpy()

            _, _, h, w = out2.shape

            # boolean (h, w) mask selecting a centred disk of radius 0.4 * h
            ys = np.arange(h).reshape(-1, 1) - h / 2
            xs = np.arange(w).reshape(1, -1) - w / 2
            center_mask = ys ** 2 + xs ** 2 < (h * 0.4) ** 2

            out1 = out1[..., center_mask].reshape(-1)
            out2 = out2[..., center_mask].reshape(-1)

            errs = np.abs(out1 - out2)

            esum = np.maximum(np.abs(out1), np.abs(out2))
            # avoid division by zero where both outputs vanish
            esum[esum == 0.0] = 1

            relerr = errs / esum

            tol = rtol * esum + atol

            if np.any(errs > tol):
                print(el, relerr.max(), relerr.mean(), relerr.var(), errs.max(), errs.mean(), errs.var())
                print(out1[errs > tol])
                print(out2[errs > tol])
                print(tol[errs > tol])

            assert np.all(errs < tol), 'The error found during equivariance check with element "{}" is too high: max = {}, mean = {} var ={}'.format(el, errs.max(), errs.mean(), errs.var())

            errors.append((el, errs.mean()))

        return errors

    def export(self):
        r"""
        Export this module to a normal PyTorch :class:`torch.nn.Upsample` module and set to "eval" mode.

        The exported module uses the same ``size`` or ``scale_factor``, ``mode`` and ``align_corners`` as this module.

        """
        self.eval()

        # Forward both size and scale_factor: exactly one of them is set (enforced in __init__), so a module
        # configured through ``size`` also exports correctly. align_corners is None for "nearest", which is
        # torch.nn.Upsample's default.
        upsample = torch.nn.Upsample(
            size=self._size,
            scale_factor=self._scale_factor,
            mode=self._mode,
            align_corners=self._align_corners,
        )

        return upsample.eval()