import warnings
from e2cnn.diffops.utils import (
load_cache,
store_cache,
required_points,
largest_possible_order,
)
from torch.nn.functional import conv2d, pad
from e2cnn.nn import init
from e2cnn.nn import FieldType
from e2cnn.nn import GeometricTensor
from e2cnn.gspaces import *
from e2cnn.diffops import DiscretizationArgs
from ..equivariant_module import EquivariantModule
from .basisexpansion import BasisExpansion
from .basisexpansion_blocks import BlocksBasisExpansion
from typing import Callable, Union, Tuple, List
import torch
from torch.nn import Parameter
import numpy as np
import math
__all__ = ["R2Diffop"]
[docs]class R2Diffop(EquivariantModule):
def __init__(self,
             in_type: FieldType,
             out_type: FieldType,
             kernel_size: int = None,
             accuracy: int = None,
             padding: Union[int, Tuple[int, int]] = 0,
             stride: int = 1,
             dilation: int = 1,
             padding_mode: str = 'zeros',
             groups: int = 1,
             bias: bool = True,
             basisexpansion: str = 'blocks',
             maximum_order: int = None,
             maximum_power: int = None,
             maximum_offset: int = None,
             recompute: bool = False,
             angle_offset: float = None,
             basis_filter: Callable[[dict], bool] = None,
             initialize: bool = True,
             cache: Union[bool, str] = False,
             rbffd: bool = False,
             radial_basis_function: str = "ga",
             smoothing: float = None,
             ):
    r"""
    G-steerable planar partial differential operator mapping between the
    input and output :class:`~e2cnn.nn.FieldType` s specified by
    the parameters ``in_type`` and ``out_type``.
    This operation is equivariant under the action of :math:`\R^2\rtimes G` where :math:`G` is the
    :attr:`e2cnn.nn.FieldType.fibergroup` of ``in_type`` and ``out_type``.

    Specifically, let :math:`\rho_\text{in}: G \to \GL{\R^{c_\text{in}}}` and
    :math:`\rho_\text{out}: G \to \GL{\R^{c_\text{out}}}` be the representations specified by the input and output
    field types.
    Then :class:`~e2cnn.nn.R2Diffop` guarantees an equivariant mapping

    .. math::
        D [\mathcal{T}^\text{in}_{g,u} . f] = \mathcal{T}^\text{out}_{g,u} . [Df] \qquad\qquad \forall g \in G, u \in \R^2

    where the transformation of the input and output fields are given by

    .. math::
        [\mathcal{T}^\text{in}_{g,u} . f](x) &= \rho_\text{in}(g)f(g^{-1} (x - u)) \\
        [\mathcal{T}^\text{out}_{g,u} . f](x) &= \rho_\text{out}(g)f(g^{-1} (x - u)) \\

    The equivariance of G-steerable PDOs is guaranteed by restricting the space of PDOs to an
    equivariant subspace.

    During training, in each forward pass the module expands the basis of G-steerable PDOs with learned weights
    before calling :func:`torch.nn.functional.conv2d`.
    When :meth:`~torch.nn.Module.eval()` is called, the filter is built with the current trained weights and stored
    for future reuse such that no overhead of expanding the PDO remains.

    .. warning ::
        When :meth:`~torch.nn.Module.train()` is called, the attributes :attr:`~e2cnn.nn.R2Diffop.filter` and
        :attr:`~e2cnn.nn.R2Diffop.expanded_bias` are discarded to avoid situations of mismatch with the
        learnable expansion coefficients.
        See also :meth:`e2cnn.nn.R2Diffop.train`.

        This behaviour can cause problems when storing the :meth:`~torch.nn.Module.state_dict` of a model while in
        a mode and later loading it in a model with a different mode, as the attributes of the class change.
        To avoid this issue, we recommend converting the model to eval mode before storing or loading the state
        dictionary.

    The learnable expansion coefficients of this module can be initialized with the methods in
    :mod:`e2cnn.nn.init`.
    By default, the weights are initialized in the constructors using :func:`~e2cnn.nn.init.generalized_he_init`.

    .. warning ::
        This initialization procedure can be extremely slow for wide layers.
        In case initializing the model is not required (e.g. before loading the state dict of a pre-trained model)
        or another initialization method is preferred (e.g. :func:`~e2cnn.nn.init.deltaorthonormal_init`), the
        parameter ``initialize`` can be set to ``False`` to avoid unnecessary overhead.

    A reasonable default is to only set the ``kernel_size`` and leave all other options
    on their defaults. However, you might get considerable performance improvements
    by setting ``smoothing`` to something other than ``None`` (``kernel_size / 4`` is
    a sane default, see below for details).

    If you want to modify ``accuracy`` or ``maximum_order``, you will need to take
    into account how they are related to ``kernel_size``: it is possible to set any two
    of ``kernel_size``, ``accuracy`` and ``maximum_order``, in which case the third
    one will be determined automatically. Alternatively, you can set either ``kernel_size``
    or ``maximum_order``, in which case a sane default will be used for ``accuracy``.
    The relation between the three is approximately :math:`\text{kernel size} \approx \text{accuracy} + \text{order}`,
    though this formula is off by one in some cases.

    A larger maximum order will lead to more basis filters and thus more parameters.
    A larger accuracy (i.e. larger kernel size at constant order)
    might lead to lower equivariance errors, though whether this actually happens may
    depend on your exact setup.

    The parameters ``basisexpansion``, ``maximum_power``, and ``maximum_offset`` are
    optional parameters used to control how the basis for the PDOs is built, how it is sampled on the filter
    grid and how it is expanded to build the filter. We suggest to keep these default values.

    .. warning::
        The discretization of the differential operators relies on two external packages: `sympy <https://docs.sympy.org/>`_ and
        `rbf <https://rbf.readthedocs.io>`_. If they are not available, an error is raised.

    Args:
        in_type (FieldType): the type of the input field, specifying its transformation law
        out_type (FieldType): the type of the output field, specifying its transformation law
        kernel_size (int, optional): the size of the (square) filter. This can be chosen automatically,
                see above for details.
        accuracy (int, optional): the desired asymptotic accuracy for the PDO discretization,
                affects the ``kernel_size``. See above for details.
        padding (int or tuple, optional): implicit zero paddings on both sides of the input, either a single
                integer or a tuple of two integers. Default: ``0``
        padding_mode(str, optional): ``zeros``, ``reflect``, ``replicate`` or ``circular``. Default: ``zeros``
        stride (int, optional): the stride of the kernel. Default: ``1``
        dilation (int, optional): the spacing between kernel elements. Default: ``1``
        groups (int, optional): number of blocked connections from input channels to output channels.
                It allows depthwise convolution. When used, the input and output types need to be
                divisible in ``groups`` groups, all equal to each other.
                Default: ``1``.
        bias (bool, optional): Whether to add a bias to the output (only to fields which contain a
                trivial irrep) or not. Default ``True``
        basisexpansion (str, optional): the basis expansion algorithm to use
        maximum_order (int, optional): the largest derivative order to allow
                as part of the basis. Larger maximum orders require larger kernel sizes,
                see above for details.
        maximum_power (int, optional): the maximum power of the Laplacian that will be used
                for constructing the basis. If this is not ``None``, it places a restriction on
                the basis elements, *in addition to* the restriction given by ``maximum_order``.
                We suggest to leave this setting on its default unless you have a good reason
                to change it.
        maximum_offset (int, optional): number of additional (aliased) frequencies in the intertwiners for finite
                groups. By default (``None``), all additional frequencies allowed by the frequencies cut-off
                are used.
        recompute (bool, optional): if ``True``, recomputes a new basis for the equivariant PDOs.
                By Default (``False``), it caches the basis built or reuse a cached one, if it is found.
        angle_offset (float, optional): forwarded unchanged to :class:`~e2cnn.diffops.DiscretizationArgs`.
                Default: ``None``
        basis_filter (callable, optional): function which takes as input a descriptor of a basis element
                (as a dictionary) and returns a boolean value: whether to preserve (``True``) or discard (``False``)
                the basis element. By default (``None``), no filtering is applied.
        initialize (bool, optional): initialize the weights of the model. Default: ``True``
        cache (bool or str, optional): Discretizing the PDOs can take a bit longer than
                for kernels, so we provide the option to cache PDOs on disk. Our suggestion is
                to keep the cache off (default) and only activate it if discretizing the PDOs
                is in fact a bottleneck for your setup (it often is not). Setting ``cache`` to
                ``True`` will load an existing cache before instantiating the layer and will
                write to the cache afterwards. You can also set ``cache`` to ``load`` or ``store``
                to only do one of these.
                All :class:`~e2cnn.nn.R2Diffop` layers share the PDO cache in memory.
                If you have several :class:`~e2cnn.nn.R2Diffop` layers inside your model,
                we therefore recommend to leave ``cache`` to ``False`` and instead call
                :func:`e2cnn.diffops.load_cache` before instantiating the model, and :func:`e2cnn.diffops.store_cache`
                afterwards to save the PDOs for the next run of the program.
                This will avoid unnecessary reads/writes from/to disk.
        rbffd (bool, optional): if set to ``True``, use RBF-FD discretization instead of
                finite differences (the default). We suggest leaving this to ``False`` unless
                you have a specific reason for wanting to use RBF-FD.
        radial_basis_function (str, optional): which RBF to use (only relevant for RBF-FD).
                Can be any of the abbreviations in `this list <https://rbf.readthedocs.io/en/latest/basis.html>`_.
                The default is to use Gaussian RBFs because this always avoids singularity issues.
                But other RBFs, such as polyharmonic splines, may work better if they are applicable.
        smoothing (float, optional): if not ``None``, discretization will be performed
                with derivatives of Gaussians as stencils. This is similar to smoothing
                with a Gaussian before applying the PDO, though there are slight technical
                differences. ``smoothing`` is the standard deviation (in pixels) of the Gaussian,
                meaning that larger values correspond to stronger smoothing.
                A reasonable value would be about ``kernel_size / 4`` but you might want to experiment
                a bit with this parameter.

    Attributes:
        ~.weights (torch.Tensor): the learnable parameters which are used to expand the PDO
        ~.filter (torch.Tensor): the convolutional stencil obtained by expanding the parameters
                in :attr:`~e2cnn.nn.R2Diffop.weights`
        ~.bias (torch.Tensor): the learnable parameters which are used to expand the bias, if ``bias=True``
        ~.expanded_bias (torch.Tensor): the equivariant bias which is summed to the output, obtained by expanding
                the parameters in :attr:`~e2cnn.nn.R2Diffop.bias`

    """

    assert in_type.gspace == out_type.gspace, "in_type and out_type must be defined on the same gspace"
    assert isinstance(in_type.gspace, GeneralOnR2), "R2Diffop requires a planar gspace (GeneralOnR2)"

    super(R2Diffop, self).__init__()

    self.space = in_type.gspace
    self.in_type = in_type
    self.out_type = out_type

    if cache and cache != "store":
        # Load the cached lambdas for RBFs if they exist
        load_cache()

    # out of kernel_size, accuracy and maximum_order, exactly two must be known,
    # the third one can then be determined automatically.
    # To provide sane defaults, we will also allow only kernel_size or maximum_order
    # to be set, in that case accuracy defaults to 2 (or 1 in the degenerate cases
    # maximum_order == 0 / kernel_size == 1).
    if kernel_size is None:
        assert maximum_order is not None, "at least one of kernel_size and maximum_order must be set"
        if accuracy is None:
            accuracy = 2 if (maximum_order > 0) else 1
        # TODO: Ideally, we should look at the basis, maybe the maximum_order isn't
        # reached (e.g. if it is odd but all basis diffops are even). In that case,
        # we could perhaps get away with a smaller kernel
        kernel_size = required_points(maximum_order, accuracy)
    elif maximum_order is None:
        assert kernel_size is not None
        if accuracy is None:
            accuracy = 2 if (kernel_size > 1) else 1
        maximum_order = largest_possible_order(kernel_size, accuracy)
        if maximum_order < 2:
            warnings.warn(f"Maximum order is only {maximum_order} for kernel size "
                          f"{kernel_size} and desired accuracy {accuracy}. This may "
                          "lead to a small basis. If this is unintentional, consider "
                          "increasing the kernel size.")
    elif accuracy is None:
        if kernel_size < required_points(maximum_order, 2):
            warnings.warn(f"Small kernel size: {kernel_size} x {kernel_size} kernel "
                          f"is used for differential operators of order up to {maximum_order}. "
                          "This may lead to bad approximations, consider using a larger kernel "
                          "or setting the desired accuracy instead of the kernel size.")
    else:
        # all three are set
        raise ValueError("At most two of kernel size, maximum order and accuracy can be set, "
                         "see documentation for details.")

    self.kernel_size = kernel_size
    self.stride = stride
    self.dilation = dilation
    self.padding = padding
    self.padding_mode = padding_mode
    self.groups = groups

    if isinstance(padding, tuple) and len(padding) == 2:
        _padding = padding
    elif isinstance(padding, int):
        _padding = (padding, padding)
    else:
        raise ValueError('padding needs to be either an integer or a tuple containing two integers but {} found'.format(padding))

    padding_modes = {'zeros', 'reflect', 'replicate', 'circular'}
    if padding_mode not in padding_modes:
        raise ValueError("padding_mode must be one of [{}], but got padding_mode='{}'".format(padding_modes, padding_mode))
    # `torch.nn.functional.pad` expects the padding of the last dimension first, each side listed separately
    self._reversed_padding_repeated_twice = tuple(x for x in reversed(_padding) for _ in range(2))

    if groups > 1:
        # Check the input and output classes can be split in `groups` groups, all equal to each other
        # first, check that the number of fields is divisible by `groups`
        assert len(in_type) % groups == 0
        assert len(out_type) % groups == 0
        in_size = len(in_type) // groups
        out_size = len(out_type) // groups

        # then, check that all groups are equal to each other, i.e. have the same types in the same order
        assert all(in_type.representations[i] == in_type.representations[i % in_size] for i in range(len(in_type)))
        assert all(out_type.representations[i] == out_type.representations[i % out_size] for i in range(len(out_type)))

        # finally, retrieve the type associated to a single group in input.
        # this type will be used to build a smaller kernel basis and a smaller filter
        # as in PyTorch, to build a filter for grouped convolution, we build a filter which maps from one input
        # group to all output groups. Then, PyTorch's standard convolution routine interpret this filter as `groups`
        # different filters, each mapping an input group to an output group.
        in_type = in_type.index_select(list(range(in_size)))

    if bias:
        # bias can be applied only to trivial irreps inside the representation
        # to apply bias to a field we learn a bias for each trivial irreps it contains
        # and, then, we transform it with the change of basis matrix to be able to apply it to the whole field
        # this is equivalent to transform the field to its irreps through the inverse change of basis,
        # sum the bias only to the trivial irrep and then map it back with the change of basis

        # count the number of trivial irreps
        trivials = 0
        for r in self.out_type:
            for irr in r.irreps:
                if self.out_type.fibergroup.irreps[irr].is_trivial():
                    trivials += 1

        # if there is at least 1 trivial irrep
        if trivials > 0:

            # matrix containing the columns of the change of basis which map from the trivial irreps to the
            # field representations. This matrix allows us to map the bias defined only over the trivial irreps
            # to a bias for the whole field more efficiently
            bias_expansion = torch.zeros(self.out_type.size, trivials)

            # p: first channel of the current field, c: index of the current trivial irrep
            p, c = 0, 0
            for r in self.out_type:
                # pi: column offset of the current irrep inside the field's change of basis
                pi = 0
                for irr in r.irreps:
                    irr = self.out_type.fibergroup.irreps[irr]
                    if irr.is_trivial():
                        bias_expansion[p:p+r.size, c] = torch.tensor(r.change_of_basis[:, pi])
                        c += 1
                    pi += irr.size
                p += r.size

            self.register_buffer("bias_expansion", bias_expansion)
            self.bias = Parameter(torch.zeros(trivials), requires_grad=True)
            self.register_buffer("expanded_bias", torch.zeros(out_type.size))
        else:
            self.bias = None
            self.expanded_bias = None
    else:
        self.bias = None
        self.expanded_bias = None

    grid, basis_filter, params = compute_basis_params(kernel_size,
                                                      dilation,
                                                      basis_filter,
                                                      maximum_order,
                                                      maximum_power,
                                                      rbffd,
                                                      radial_basis_function,
                                                      smoothing,
                                                      angle_offset)

    # BasisExpansion: submodule which takes care of building the filter
    self._basisexpansion = None

    # notice that `in_type` is used instead of `self.in_type` such that it works also when `groups > 1`
    if basisexpansion == 'blocks':
        self._basisexpansion = BlocksBasisExpansion(in_type, out_type,
                                                    self.space.build_diffop_basis,
                                                    points=grid,
                                                    maximum_offset=maximum_offset,
                                                    **params,
                                                    basis_filter=basis_filter,
                                                    recompute=recompute)
    else:
        raise ValueError('Basis Expansion algorithm "%s" not recognized' % basisexpansion)

    if self.basisexpansion.dimension() == 0:
        # only mention parameters this constructor actually accepts
        raise ValueError(
            "The basis for the steerable PDO is empty! "
            "Tune the `kernel_size`, `maximum_order`, `accuracy`, `maximum_power` or `basis_filter` "
            "parameters to allow for a larger basis."
        )

    self.weights = Parameter(torch.zeros(self.basisexpansion.dimension()), requires_grad=True)
    self.register_buffer("filter", torch.zeros(out_type.size, in_type.size, kernel_size, kernel_size))

    if initialize:
        # by default, the weights are initialized with a generalized form of He's weight initialization
        init.generalized_he_init(self.weights.data, self.basisexpansion)

    if cache and cache != "load":
        store_cache()
@property
def basisexpansion(self) -> BasisExpansion:
    r"""
    Submodule in charge of building the filter.

    Given the learned ``weights``, it expands the PDO in the G-steerable basis and returns it in the form
    used by conventional convolutional modules, i.e. with shape
    :math:`(c_\text{out}, c_\text{in}, s^2)`, where :math:`s` is the ``kernel_size``.

    """
    expansion = self._basisexpansion
    return expansion
def expand_parameters(self) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Build the filter from :attr:`e2cnn.nn.R2Diffop.weights` and the expanded bias
    from :attr:`e2cnn.nn.R2Diffop.bias`.

    Returns:
        a pair ``(filter, bias)``; ``bias`` is ``None`` when the module has no bias

    """
    flat = self.basisexpansion(self.weights)

    # unfold the flattened spatial axis into a square (kernel_size x kernel_size) stencil
    side = self.kernel_size
    stencil = flat.reshape(flat.shape[0], flat.shape[1], side, side)

    # the learnable bias only lives on the trivial irreps; map it to all output channels
    offset = None if self.bias is None else self.bias_expansion @ self.bias

    return stencil, offset
def forward(self, input: GeometricTensor):
    r"""
    Apply the expanded filter (and bias, if any) to the input via a 2D convolution.

    Args:
        input (GeometricTensor): input feature field transforming according to ``in_type``

    Returns:
        output feature field transforming according to ``out_type``

    """
    assert input.type == self.in_type

    if self.training:
        # in training mode the filter and the bias are re-expanded from the current weights
        _filter, _bias = self.expand_parameters()
    else:
        # in eval mode reuse the stencil and bias stored by `train(False)`
        _filter, _bias = self.filter, self.expanded_bias

    if self.padding_mode == 'zeros':
        # zero padding is handled by conv2d itself
        data = input.tensor
        conv_padding = self.padding
    else:
        # any other mode: pad explicitly first, then convolve without further padding
        data = pad(input.tensor, self._reversed_padding_repeated_twice, self.padding_mode)
        conv_padding = (0, 0)

    output = conv2d(data, _filter,
                    bias=_bias,
                    stride=self.stride,
                    padding=conv_padding,
                    dilation=self.dilation,
                    groups=self.groups)

    return GeometricTensor(output, self.out_type)
def train(self, mode=True):
    r"""
    Switch between training and evaluation mode.

    With ``mode=True`` the module enters training mode and the :attr:`~e2cnn.nn.R2Diffop.filter`
    and :attr:`~e2cnn.nn.R2Diffop.expanded_bias` attributes are discarded.

    With ``mode=False`` the module enters evaluation mode: if it was previously in training mode, the filter
    and the bias are built from the current values of the trainable parameters and stored in
    :attr:`~e2cnn.nn.R2Diffop.filter` and :attr:`~e2cnn.nn.R2Diffop.expanded_bias`, so that they are not
    recomputed at each forward pass.

    .. warning ::
        This behaviour can cause problems when storing the :meth:`~torch.nn.Module.state_dict` of a model while in
        a mode and later loading it in a model with a different mode, as the attributes of this class change.
        To avoid this issue, we recommend converting the model to eval mode before storing or loading the state
        dictionary.

    Args:
        mode (bool, optional): whether to set training mode (``True``) or evaluation mode (``False``).
                               Default: ``True``.

    """
    if mode:
        # TODO thoroughly check this is not causing problems
        for cached in ("filter", "expanded_bias"):
            if hasattr(self, cached):
                delattr(self, cached)
    elif self.training:
        # only expand when leaving training mode, so that consecutive `.eval()` calls do no extra work
        stencil, offset = self.expand_parameters()
        self.register_buffer("filter", stencil)
        if offset is None:
            self.expanded_bias = None
        else:
            self.register_buffer("expanded_bias", offset)

    return super(R2Diffop, self).train(mode)
def evaluate_output_shape(self, input_shape: Tuple[int, int, int, int]) -> Tuple[int, int, int, int]:
    r"""
    Compute the shape of the output tensor produced for an input of shape ``input_shape``.

    ``padding`` (which the constructor accepts as an int or as a 2-tuple), ``stride`` and ``dilation``
    may each be either a single int, applied to both spatial axes, or a pair ``(vertical, horizontal)``.

    Args:
        input_shape (tuple): shape ``(batch, channels, height, width)`` of the input tensor

    Returns:
        the shape ``(batch, out_type.size, out_height, out_width)`` of the output tensor

    """
    assert len(input_shape) == 4
    assert input_shape[1] == self.in_type.size

    def _pair(value):
        # a scalar applies to both spatial axes; a 2-sequence is taken as (vertical, horizontal)
        return tuple(value) if isinstance(value, (tuple, list)) else (value, value)

    b, c, hi, wi = input_shape

    pad_h, pad_w = _pair(self.padding)
    stride_h, stride_w = _pair(self.stride)
    dil_h, dil_w = _pair(self.dilation)

    ho = math.floor((hi + 2 * pad_h - dil_h * (self.kernel_size - 1) - 1) / stride_h + 1)
    wo = math.floor((wi + 2 * pad_w - dil_w * (self.kernel_size - 1) - 1) / stride_w + 1)

    return b, self.out_type.size, ho, wo
def check_equivariance(self, atol: float = 0.1, rtol: float = 0.1, assertion: bool = True, verbose: bool = True):
    r"""
    Empirically test the equivariance of this layer on a test image.

    For every element of ``self.space.testing_elements``, the output obtained by applying the layer and then
    transforming the result is compared with the output obtained by transforming the input and then applying
    the layer. Both outputs are averaged over 5 x 5 blocks and only a central disk is compared.

    The image is read from the relative path ``'../group/testimage.jpeg'`` and the method imports
    ``matplotlib`` and ``skimage`` locally, so both the file and the packages must be available.

    Args:
        atol (float, optional): absolute tolerance on the error. Default: ``0.1``
        rtol (float, optional): tolerance relative to the larger absolute value of the two outputs.
                                Default: ``0.1``
        assertion (bool, optional): if ``True``, an ``AssertionError`` is raised as soon as an element
                                    exceeds the tolerance. Default: ``True``
        verbose (bool, optional): if ``True``, print error statistics for each element. Default: ``True``

    Returns:
        a list of pairs ``(element, mean absolute error)``, one per tested element

    """

    # np.set_printoptions(precision=5, threshold=30 *self.in_type.size**2, suppress=False, linewidth=30 *self.in_type.size**2)

    feature_map_size = 33
    last_downsampling = 5
    first_downsampling = 5

    # side of the high resolution image: shrinking it by `first_downsampling` and applying a
    # `kernel_size` stencil without padding leaves `feature_map_size * last_downsampling` pixels per side
    initial_size = (feature_map_size * last_downsampling - 1 + self.kernel_size) * first_downsampling

    c = self.in_type.size

    import matplotlib.image as mpimg
    from skimage.measure import block_reduce
    from skimage.transform import resize

    # channels-first layout with a leading batch axis; keep at most `c` channels
    x = mpimg.imread('../group/testimage.jpeg').transpose((2, 0, 1))[np.newaxis, 0:c, :, :]

    x = resize(
        x,
        (x.shape[0], x.shape[1], initial_size, initial_size),
        anti_aliasing=True
    )

    x = x / 255.0 - 0.5

    # if the image has fewer channels than the input type, repeat its channels until `c` are available
    if x.shape[1] < c:
        to_stack = [x for i in range(c // x.shape[1])]
        if c % x.shape[1] > 0:
            to_stack += [x[:, :(c % x.shape[1]), ...]]

        x = np.concatenate(to_stack, axis=1)

    x = GeometricTensor(torch.FloatTensor(x), self.in_type)

    def shrink(t: GeometricTensor, s) -> GeometricTensor:
        # downsample by averaging over blocks of shape `s`, preserving the field type
        return GeometricTensor(torch.FloatTensor(block_reduce(t.tensor.detach().numpy(), s, func=np.mean)), t.type)

    errors = []

    # NOTE: the block shapes below hard-code 5, the value of both `first_downsampling` and `last_downsampling`
    for el in self.space.testing_elements:

        # apply the layer first and transform afterwards ...
        out1 = self(shrink(x, (1, 1, 5, 5))).transform(el).tensor.detach().numpy()
        # ... versus transform first and apply the layer afterwards
        out2 = self(shrink(x.transform(el), (1, 1, 5, 5))).tensor.detach().numpy()

        out1 = block_reduce(out1, (1, 1, 5, 5), func=np.mean)
        out2 = block_reduce(out2, (1, 1, 5, 5), func=np.mean)

        # note: this rebinds `c`, which is not used as the input size anymore from here on
        b, c, h, w = out2.shape

        # boolean mask selecting the pixels within distance h / 4 from the centre of the feature map
        center_mask = np.zeros((2, h, w))
        center_mask[1, :, :] = np.arange(0, w) - w / 2
        center_mask[0, :, :] = np.arange(0, h) - h / 2
        center_mask[0, :, :] = center_mask[0, :, :].T
        center_mask = center_mask[0, :, :] ** 2 + center_mask[1, :, :] ** 2 < (h / 4) ** 2

        out1 = out1[..., center_mask]
        out2 = out2[..., center_mask]

        out1 = out1.reshape(-1)
        out2 = out2.reshape(-1)

        errs = np.abs(out1 - out2)

        # normalise by the larger magnitude of the two outputs; zeros are replaced to avoid dividing by zero
        esum = np.maximum(np.abs(out1), np.abs(out2))
        esum[esum == 0.0] = 1

        relerr = errs / esum

        if verbose:
            print(el, relerr.max(), relerr.mean(), relerr.var(), errs.max(), errs.mean(), errs.var())

        tol = rtol * esum + atol

        if np.any(errs > tol) and verbose:
            print(out1[errs > tol])
            print(out2[errs > tol])
            print(tol[errs > tol])

        if assertion:
            assert np.all(errs < tol), 'The error found during equivariance check with element "{}" is too high: max = {}, mean = {} var ={}'.format(el, errs.max(), errs.mean(), errs.var())

        errors.append((el, errs.mean()))

    return errors

    # init.deltaorthonormal_init(self.weights.data, self.basisexpansion)
    # filter = self.basisexpansion()
    # center = self.s // 2
    # filter = filter[..., center, center]
    # assert torch.allclose(torch.eye(filter.shape[1]), filter.t() @ filter, atol=3e-7)
def export(self):
    r"""
    Export this module to a normal PyTorch :class:`torch.nn.Conv2d` module and set to "eval" mode.

    Returns:
        a :class:`torch.nn.Conv2d` whose weight and bias (if any) hold the expanded filter and bias

    Raises:
        ImportError: if a padding mode other than ``zeros`` is used with a `torch` version older than 1.5

    """

    # set to eval mode so the filter and the bias are updated with the current
    # values of the weights
    self.eval()
    _filter = self.filter
    _bias = self.expanded_bias

    if self.padding_mode not in ['zeros']:
        major, minor = torch.__version__.split('.')[:2]
        # compare (major, minor) as a pair: testing the two components independently
        # would wrongly reject releases such as 2.0, whose minor number is below 5
        if (int(major), int(minor)) < (1, 5):
            if self.padding_mode == 'circular':
                raise ImportError(
                    "'{}' padding mode had some issues in old `torch` versions. Therefore, we only support conversion from version 1.5 but only version {} is installed.".format(
                        self.padding_mode, torch.__version__
                    )
                )
            else:
                raise ImportError(
                    "`torch` supports '{}' padding mode only from version 1.5 but only version {} is installed.".format(
                        self.padding_mode, torch.__version__
                    )
                )

    # build the PyTorch Conv2d module
    has_bias = self.bias is not None
    conv = torch.nn.Conv2d(self.in_type.size,
                           self.out_type.size,
                           self.kernel_size,
                           padding=self.padding,
                           padding_mode=self.padding_mode,
                           stride=self.stride,
                           dilation=self.dilation,
                           groups=self.groups,
                           bias=has_bias)

    # set the filter and the bias
    conv.weight.data = _filter.data
    if has_bias:
        conv.bias.data = _bias.data

    return conv
def __repr__(self):
extra_lines = []
extra_repr = self.extra_repr()
if extra_repr:
extra_lines = extra_repr.split('\n')
main_str = self._get_name() + '('
if len(extra_lines) == 1:
main_str += extra_lines[0]
else:
main_str += '\n ' + '\n '.join(extra_lines) + '\n'
main_str += ')'
return main_str
def extra_repr(self):
    # Summary of the layer's configuration; settings left at their default value are omitted.
    template = '{in_type}, {out_type}, kernel_size={kernel_size}, stride={stride}'

    optional = []
    if not (self.padding == 0 or self.padding == (0, 0)):
        optional.append(', padding={padding}')
    if not (self.dilation == 1 or self.dilation == (1, 1)):
        optional.append(', dilation={dilation}')
    if self.groups != 1:
        optional.append(', groups={groups}')
    if self.bias is None:
        optional.append(', bias=False')

    # the placeholders are filled from the instance attributes of the same name
    return (template + ''.join(optional)).format(**self.__dict__)
def get_grid_coords(kernel_size: int, dilation: int = 1):
    """Return the coordinates of the points of a square, dilated sampling grid centred on the origin.

    Points are listed row by row, starting from the top-left one; the vertical coordinate
    decreases from one row to the next.

    Args:
        kernel_size: number of grid points along each side.
        dilation: spacing between neighbouring grid points.

    Returns:
        An array of shape ``(2, kernel_size ** 2)``: first row horizontal, second row vertical coordinates.
    """
    # extent covered by the dilated grid and the offset which centres it on the origin
    span = dilation * (kernel_size - 1) + 1
    center = span / 2 - 0.5

    steps = [i * dilation for i in range(kernel_size)]
    coords = np.array([(col - center, -row + center) for row in steps for col in steps])

    assert coords.shape == (kernel_size ** 2, 2), coords.shape
    return coords.T
def compute_basis_params(kernel_size: int,
                         dilation: int,
                         custom_basis_filter: Callable[[dict], bool],
                         maximum_order: int,
                         maximum_power: int,
                         rbffd: bool,
                         radial_basis_function: str,
                         smoothing: float,
                         angle_offset: float,
                         ):
    """Prepare the sampling grid, the basis filter and the keyword arguments used to build the PDO basis.

    Args:
        kernel_size: side of the square stencil.
        dilation: spacing between the stencil points.
        custom_basis_filter: optional additional predicate on the basis element descriptors.
        maximum_order: largest order a basis element may have.
        maximum_power: optional upper bound on the power of the Laplacian; it is always
            capped at ``maximum_order // 2``.
        rbffd: whether to discretize with RBF-FD.
        radial_basis_function: passed to ``DiscretizationArgs`` as ``phi``.
        smoothing: if not ``None``, selects the ``"gauss"`` discretization method.
        angle_offset: passed to ``DiscretizationArgs``.

    Returns:
        A tuple ``(grid, basis_filter, params)``.

    Raises:
        ValueError: if both ``smoothing`` and ``rbffd`` are requested.
    """
    # compute the coordinates of the centers of the cells in the grid where the filter is sampled
    grid = get_grid_coords(kernel_size, dilation)

    # basis elements above the maximum order are always discarded;
    # a user supplied filter can only restrict the basis further
    within_order = order_filter(maximum_order)
    if custom_basis_filter is None:
        basis_filter = within_order
    else:
        def basis_filter(d):
            return custom_basis_filter(d) and within_order(d)

    # to guarantee that all relevant tensor products
    # are generated, we need Laplacian powers up to
    # half the maximum order. Anything higher would be
    # discarded anyways by the basis_filter
    half_order = maximum_order // 2
    maximum_power = half_order if maximum_power is None else min(maximum_power, half_order)

    use_gauss = smoothing is not None
    if use_gauss and rbffd:
        raise ValueError("You can't use smoothing and RBF-FD at the same time.")
    method = "gauss" if use_gauss else ("rbffd" if rbffd else "fd")

    disc = DiscretizationArgs(
        method=method,
        smoothing=smoothing,
        angle_offset=angle_offset,
        phi=radial_basis_function,
    )

    params = {
        "max_power": maximum_power,
        # frequencies higher than the maximum order will be discarded anyway
        "maximum_frequency": maximum_order,
        "discretization": disc,
    }

    return grid, basis_filter, params
def order_filter(maximum_order: int) -> Callable[[dict], bool]:
    """Return a predicate which accepts the basis descriptors whose ``"order"`` is at most ``maximum_order``."""
    def accept(attr):
        return attr["order"] <= maximum_order

    return accept