from __future__ import annotations
from e2cnn import gspaces
from e2cnn import kernels
from e2cnn import diffops
from .general_r2 import GeneralOnR2
from typing import Union, Tuple, Callable, List
from e2cnn.group import Representation
from e2cnn.group import Group
from e2cnn.group import CyclicGroup
from e2cnn.group import cyclic_group
from e2cnn.diffops import DiscretizationArgs
import numpy as np
# Names exported by ``from <this module> import *``.
__all__ = ["TrivialOnR2"]
class TrivialOnR2(GeneralOnR2):

    def __init__(self, fibergroup: Group = None):
        r"""

        Describes the plane :math:`\R^2` without considering any origin-preserving symmetry.
        This is modeled by choosing a trivial fiber group :math:`\{e\}`.

        .. note ::
            This models the symmetries of conventional *Convolutional Neural Networks* which are not equivariant to
            origin preserving transformations such as rotations and reflections.

        Args:
            fibergroup (Group, optional): use an already existing instance of the symmetry group.
                                          By default, it builds a new instance of the trivial group.
                                          If given, it must be a :class:`~e2cnn.group.CyclicGroup` of order 1.

        """
        if fibergroup is None:
            # The trivial group is represented as the cyclic group with a single element.
            fibergroup = cyclic_group(1)
        else:
            assert isinstance(fibergroup, CyclicGroup) and fibergroup.order() == 1

        name = "Trivial"

        super().__init__(fibergroup, name)

    def restrict(self, id: int) -> Tuple[gspaces.GSpace, Callable, Callable]:
        r"""

        Build the :class:`~e2cnn.gspaces.GSpace` associated with the subgroup of the current fiber group identified by
        the input ``id``.

        As the trivial group contains only one element, there are no other subgroups.
        The only accepted input value is ``id=1`` and returns this same group.
        This functionality is implemented only for consistency with the other G-spaces.

        Args:
            id (int): the order of the subgroup

        Returns:
            a tuple containing

                - **gspace**: the restricted gspace

                - **back_map**: a function mapping an element of the subgroup to itself in the fiber group of the original space

                - **subgroup_map**: a function mapping an element of the fiber group of the original space to itself in the subgroup (returns ``None`` if the element is not in the subgroup)

        """
        # The fiber group validates ``id`` and provides the two element maps.
        group, mapping, child = self.fibergroup.subgroup(id)
        return gspaces.TrivialOnR2(fibergroup=group), mapping, child

    @staticmethod
    def _frequency_limits(kwargs: dict) -> Tuple[Union[int, None], Union[int, None]]:
        r"""

        Extract and validate the ``maximum_frequency`` and ``maximum_offset`` entries of ``kwargs``.

        A missing key and a key explicitly set to ``None`` are treated in the same way.
        This logic is shared by :meth:`_basis_generator` and :meth:`_diffop_basis_generator`.

        Args:
            kwargs (dict): the keyword arguments received by a basis generator

        Returns:
            the pair ``(maximum_frequency, maximum_offset)``; at most one of the two is ``None``

        Raises:
            AssertionError: if a provided value is not a non-negative ``int`` or if both values are unset

        """
        maximum_frequency = kwargs.get('maximum_frequency')
        if maximum_frequency is not None:
            assert isinstance(maximum_frequency, int) and maximum_frequency >= 0

        maximum_offset = kwargs.get('maximum_offset')
        if maximum_offset is not None:
            assert isinstance(maximum_offset, int) and maximum_offset >= 0

        assert (maximum_frequency is not None or maximum_offset is not None), \
            'Error! Either the maximum frequency or the maximum offset for the frequencies must be set'

        return maximum_frequency, maximum_offset

    def _basis_generator(self,
                         in_repr: Representation,
                         out_repr: Representation,
                         rings: List[float],
                         sigma: List[float],
                         **kwargs,
                         ) -> kernels.KernelBasis:
        r"""
        Method that builds the analytical basis that spans the space of equivariant filters which
        are intertwiners between the representations induced from the representation ``in_repr`` and ``out_repr``.

        Either ``maximum_frequency`` or ``maximum_offset`` must be set in the keywords arguments.

        Args:
            in_repr: the input representation
            out_repr: the output representation
            rings: radii of the rings where to sample the bases
            sigma: parameters controlling the width of each ring where the bases are sampled.

        Keyword Args:
            maximum_frequency (int): the maximum frequency allowed in the basis vectors
            maximum_offset (int): the maximum frequencies offset for each basis vector with respect to its base ones
                                  (sum and difference of the frequencies of the input and the output representations)

        Returns:
            the basis built

        """
        maximum_frequency, maximum_offset = self._frequency_limits(kwargs)

        return kernels.kernels_Trivial_act_R2(in_repr, out_repr, rings, sigma,
                                              maximum_frequency,
                                              max_offset=maximum_offset)

    def _diffop_basis_generator(self,
                                in_repr: Representation,
                                out_repr: Representation,
                                max_power: int,
                                discretization: DiscretizationArgs,
                                **kwargs,
                                ) -> diffops.DiffopBasis:
        r"""
        Method that builds the analytical basis that spans the space of equivariant PDOs which
        are intertwiners between the representations induced from the representation ``in_repr`` and ``out_repr``.

        Either ``maximum_frequency`` or ``maximum_offset`` must be set in the keywords arguments.

        Args:
            in_repr: the input representation
            out_repr: the output representation
            max_power (int): the maximum power of Laplacians to use
            discretization (DiscretizationArgs): the parameters specifying a discretization procedure for PDOs

        Keyword Args:
            maximum_frequency (int): the maximum frequency allowed in the basis vectors
            maximum_offset (int): the maximum frequencies offset for each basis vector with respect to its base ones
                                  (sum and difference of the frequencies of the input and the output representations)

        Returns:
            the basis built

        """
        maximum_frequency, maximum_offset = self._frequency_limits(kwargs)

        return diffops.diffops_Trivial_act_R2(in_repr, out_repr, max_power,
                                              maximum_frequency,
                                              max_offset=maximum_offset, discretization=discretization)

    def _basespace_action(self, input: np.ndarray, element: Union[float, int]) -> np.ndarray:
        r"""
        Apply ``element`` of the fiber group to ``input``.

        The input is never transformed: after checking that ``element`` belongs to the fiber group,
        a copy of ``input`` is returned, so the caller's array is not aliased.

        Args:
            input (~numpy.ndarray): the array to transform
            element: an element of the fiber group

        Returns:
            a copy of ``input``

        """
        assert self.fibergroup.is_element(element)
        return input.copy()

    def __eq__(self, other: object) -> bool:
        # Two trivial G-spaces are equal when their fiber groups are equal;
        # any object of another type is never equal.
        return isinstance(other, TrivialOnR2) and self.fibergroup == other.fibergroup

    def __hash__(self) -> int:
        # ``name`` is the constant "Trivial" passed to the parent constructor, so
        # instances that compare equal always share the same hash.
        return hash(self.name)