escnn.nn
This subpackage provides implementations of equivariant neural network modules.
In an equivariant network, features are associated with a transformation law under actions of a symmetry group.
The transformation law of a feature field is implemented by its FieldType, which can be interpreted as a data type.
A GeometricTensor wraps a torch.Tensor to endow it with a FieldType.
Geometric tensors are processed by EquivariantModules, which are torch.nn.Modules that guarantee the specified transformation behavior of their output fields given a transformation of their input fields.
This subpackage depends on escnn.group and escnn.gspaces.
To enable efficient deployment of equivariant networks, many EquivariantModules implement an export() method which converts a trained equivariant module into a pure PyTorch module, with few or no dependencies on escnn.
Not all modules support this feature yet, so check each module's documentation to see whether it implements this method.
We provide a simple example:
import torch
import escnn.gspaces
import escnn.nn
from escnn.nn import (SequentialModule, R2Conv, InnerBatchNorm, ReLU,
                      PointwiseMaxPool, ELU, GroupPooling, GeometricTensor)

# build a simple equivariant model using a SequentialModule
s = escnn.gspaces.rot2dOnR2(8)
c_in = escnn.nn.FieldType(s, [s.trivial_repr]*3)
c_hid = escnn.nn.FieldType(s, [s.regular_repr]*3)
c_out = escnn.nn.FieldType(s, [s.regular_repr]*1)
net = SequentialModule(
R2Conv(c_in, c_hid, 5, bias=False),
InnerBatchNorm(c_hid),
ReLU(c_hid, inplace=True),
PointwiseMaxPool(c_hid, kernel_size=3, stride=2, padding=1),
R2Conv(c_hid, c_out, 3, bias=False),
InnerBatchNorm(c_out),
ELU(c_out, inplace=True),
GroupPooling(c_out)
)
# train the model
# ...
# export the model
net.eval()
net_exported = net.export()
print(net)
> SequentialModule(
> (0): R2Conv([8-Rotations: {irrep_0, irrep_0, irrep_0}], [8-Rotations: {regular, regular, regular}], kernel_size=5, stride=1, bias=False)
> (1): InnerBatchNorm([8-Rotations: {regular, regular, regular}], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
> (2): ReLU(inplace=True, type=[8-Rotations: {regular, regular, regular}])
> (3): PointwiseMaxPool()
> (4): R2Conv([8-Rotations: {regular, regular, regular}], [8-Rotations: {regular}], kernel_size=3, stride=1, bias=False)
> (5): InnerBatchNorm([8-Rotations: {regular}], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
> (6): ELU(alpha=1.0, inplace=True, type=[8-Rotations: {regular}])
> (7): GroupPooling([8-Rotations: {regular}])
> )
print(net_exported)
> Sequential(
> (0): Conv2d(3, 24, kernel_size=(5, 5), stride=(1, 1), bias=False)
> (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
> (2): ReLU(inplace=True)
> (3): MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), dilation=(1, 1), ceil_mode=False)
> (4): Conv2d(24, 8, kernel_size=(3, 3), stride=(1, 1), bias=False)
> (5): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
> (6): ELU(alpha=1.0, inplace=True)
> (7): MaxPoolChannels(kernel_size=8)
> )
# check that the two models are equivalent
x = torch.randn(10, c_in.size, 31, 31)
x = GeometricTensor(x, c_in)
y1 = net(x).tensor
y2 = net_exported(x.tensor)
assert torch.allclose(y1, y2)
Contents
Field Type
- class FieldType(gspace, representations)[source]
A FieldType can be interpreted as the data type of a feature space. It describes:
- the base space on which a feature field lives and the symmetries considered on it
- the transformation law of feature fields under the action of the fiber group
The former is formalized by a choice of gspace, while the latter is determined by a choice of group representations (representations), passed as a list of Representation instances. Each representation in this list corresponds to one independent feature field contained in the feature space. The input representations need to belong to gspace's fiber group (escnn.gspaces.GSpace.fibergroup).
Note
Mathematically, this class describes a (trivial) vector bundle, associated to the symmetry group \((\R^D, +) \rtimes G\).
Given a principal bundle \(\pi: (\R^D, +) \rtimes G \to \R^D, tg \mapsto tG\) with fiber group \(G\), an associated vector bundle has the same base space \(\R^D\) but its fibers are vector spaces like \(\mathbb{R}^c\). Moreover, these vector spaces are associated to a \(c\)-dimensional representation \(\rho\) of the fiber group \(G\) and transform accordingly.
The representation \(\rho\) is defined as the direct sum of the representations \(\{\rho_i\}_i\) in representations. See also directsum().
- Parameters
gspace (GSpace) – the space where the feature fields live and its symmetries
representations (tuple, list) – a list or tuple of
Representation
s of thegspace
’s fiber group, determining the transformation laws of the feature fields
- property fibergroup: escnn.group.group.Group
The fiber group of
gspace
.- Returns
the fiber group
- property representation: escnn.group.representation.Representation
The (combined) representations of this field type. They describe how the feature vectors transform under the fiber group action, that is, how the channels mix.
It is the direct sum (directsum()) of the representations in escnn.nn.FieldType.representations.
Because a feature space can contain a very large number of feature fields, computing this representation as the direct sum of many small representations can be expensive. Hence, this representation is only built the first time it is explicitly used, in order to avoid unnecessary overhead when not needed.
- Returns
the
Representation
describing the whole feature space
- property irreps: List[Tuple]
Ordered list of irreps contained in the representation of the field type. It is the concatenation of the irreps in each representation in escnn.nn.FieldType.representations
.- Returns
list of irreps
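For example, a small sketch (reusing the rotational gspace and the representation attributes used elsewhere on this page) of how the number of fields, the total size and the combined representation relate to each other:

import escnn.gspaces
import escnn.nn

gs = escnn.gspaces.rot2dOnR2(4)
# one regular field (4 channels), one vector field (2 channels), one scalar field (1 channel)
ft = escnn.nn.FieldType(gs, [gs.regular_repr, gs.irrep(1), gs.trivial_repr])

print(len(ft))                  # 3 fields
print(ft.size)                  # 4 + 2 + 1 = 7 channels
print(ft.representation.size)   # 7: the direct sum of the three representations
print(len(ft.irreps))           # concatenation of the irreps of each field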
- property change_of_basis: scipy.sparse.coo.coo_matrix
The change of basis matrix which decomposes the field type's representation into irreps, given as a sparse (block diagonal) matrix (scipy.sparse.coo_matrix).
It is the direct sum of the change of basis matrices of each representation in escnn.nn.FieldType.representations.
See also
escnn.group.Representation.change_of_basis
- Returns
the change of basis
- property change_of_basis_inv: scipy.sparse.coo.coo_matrix
Inverse of the (sparse) change of basis matrix. See
escnn.nn.FieldType.change_of_basis
for more details.- Returns
the inverted change of basis
- get_dense_change_of_basis()[source]
The method returns a dense
torch.Tensor
containing a copy of the change-of-basis matrix.See also
See
escnn.nn.FieldType.change_of_basis
for more details.
- get_dense_change_of_basis_inv()[source]
The method returns a dense
torch.Tensor
containing a copy of the inverse of the change-of-basis matrix.See also
See
escnn.nn.FieldType.change_of_basis
for more details.
- transform_fibers(input, element)[source]
Transform the feature vectors of the input tensor according to the group representation associated to the input element.
Interpreting the tensor as a vector-valued signal \(f: X \to \mathbb{R}^c\) over a base space \(X\) (where \(c\) is the number of channels of the tensor), given the input
element
\(g \in G\) (\(G\) fiber group) the method returns the new signal \(f'\):\[f'(x) := \rho(g) f(x)\]for \(x \in X\) point in the base space and \(\rho\) the representation of \(G\) in the field type of this tensor.
Notice that the input element has to be an element of the fiber group of this tensor’s field type.
See also
See
escnn.nn.FieldType.transform()
to transform the whole tensor.- Parameters
input (torch.Tensor) – the tensor to transform
element (GroupElement) – an element of the group of symmetries of the fiber.
- Returns
the transformed tensor
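A small usage sketch; it assumes that the fiber group's sample() method returns a random GroupElement:

import torch
import escnn.gspaces
import escnn.nn

gs = escnn.gspaces.rot2dOnR2(8)
ft = escnn.nn.FieldType(gs, [gs.regular_repr] * 2)

x = torch.randn(10, ft.size, 5, 5)
g = gs.fibergroup.sample()   # assumed: a random element of the fiber group C8

# only the channels are mixed by rho(g); the spatial grid is left untouched
y = ft.transform_fibers(x, g)
assert y.shape == x.shape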
- transform(input, element, coords=None, order=2)[source]
The method takes a PyTorch tensor compatible with this type (i.e. whose spatial dimensions are supported by the base space and whose number of channels equals the escnn.nn.FieldType.size of this type) and an element of the fiber group of this type, and transforms the input tensor according to the group representation associated with the input element and its (induced) action on the base space.
This transformation includes both an action over the base space (e.g. a rotation of the points on the plane) and a transformation of the channels obtained by left-multiplying them with a representation of the fiber group.
The method takes as input a tensor (input) and an element of the fiber group. The tensor input is the feature field to be transformed and needs to be compatible with the G-space and the representation (i.e. its number of channels equals the size of that representation). element needs to belong to the fiber group: check escnn.group.GroupElement.group(). The method returns the tensor transformed by the action of element.
In addition, the method accepts an optional coords tensor. If the argument is not passed, the input tensor is assumed to have shape (batchsize, channels, *spatial_grid_shape) and to represent features sampled on a grid of shape spatial_grid_shape; in this case, the action on the base space resamples the transformed features on this grid (using interpolation, if necessary). If coords is not None, input is assumed to be a (#points, channels) tensor containing an unstructured set of points living on the base space; then, coords should contain the coordinates of these points. The base space action will then transform these coordinates (no interpolation required). In that case, the method returns a pair containing both the transformed features (according to the action on the fibers) and the transformed coordinates (according to the action on the base space).
More precisely, given an input tensor, interpreted as a \(c\)-dimensional signal \(f: \R^D \to \mathbb{R}^c\) defined over the base space \(\R^D\), a representation \(\rho: G \to \mathbb{R}^{c \times c}\) of \(G\) and an element \(g \in G\) of the fiber group, the method returns the transformed signal \(f'\) defined as:
\[f'(x) := \rho(g) f(g^{-1} x)\]Note
Mathematically, this method transforms the input with the induced representation from the input
repr
(\(\rho\)) of the symmetry group (\(G\)) to the total space (\(P\)), i.e. with \(Ind_{G}^{P} \rho\). For more details on this, see General E(2)-Equivariant Steerable CNNs or A General Theory of Equivariant CNNs On Homogeneous Spaces.Warning
In case coords is not passed and, therefore, the resampling of the grid is performed, the input tensor is detached before the transformation, therefore no gradient is propagated back through this operation.
See also
See
escnn.nn.GeometricTensor.transform_fibers()
to transform only the fibers, i.e. not transform the base space.See
escnn.gspaces.GSpace._interpolate_transform_basespace()
for more details on the action on the base space.- Parameters
input (torch.Tensor) – input tensor
element (GroupElement) – element of the fiber group
coords (torch.Tensor, optional) – coordinates of the points in input. If None (by default), it assumes the points input are arranged in a grid and it transforms the grid by interpolation. Otherwise, it transforms the coordinates in coords using self.gspace.basespace_action(). In the last case, the method returns a tuple (transformed_input, transformed_coords).
- Returns
transformed tensor and, optionally, the transformed coordinates
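A short sketch of both modes (grid input and point cloud with explicit coords); it assumes the fiber group's sample() method returns a random element, and uses the gspace's dimensionality attribute mentioned above:

import torch
import escnn.gspaces
import escnn.nn

gs = escnn.gspaces.rot2dOnR2(4)
ft = escnn.nn.FieldType(gs, [gs.trivial_repr] * 3)
g = gs.fibergroup.sample()   # assumed: a random group element

# grid mode: the features are resampled on the same grid (interpolation may be used)
x = torch.randn(1, ft.size, 9, 9)
y = ft.transform(x, g)
assert y.shape == x.shape

# point-cloud mode: the coordinates are transformed explicitly, no interpolation
coords = torch.randn(20, gs.dimensionality)
p = torch.randn(20, ft.size)
p_new, coords_new = ft.transform(p, g, coords=coords)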
- restrict(id)[source]
Reduce the symmetries modeled by the FieldType by choosing a subgroup of its fiber group as specified by id. This implies a restriction of each representation in escnn.nn.FieldType.representations to this subgroup.
See also
Check the documentation of the
restrict()
method in the subclass ofGSpace
used for a description of the parameterid
.- Parameters
id – identifier of the subgroup to which the
FieldType
and itsescnn.nn.FieldType.representations
should be restricted- Returns
the restricted type
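For example, assuming the rotational GSpace accepts an integer id identifying the subgroup of M rotations (check its restrict() documentation for the exact convention):

import escnn.gspaces
import escnn.nn

gs = escnn.gspaces.rot2dOnR2(8)
ft = escnn.nn.FieldType(gs, [gs.regular_repr] * 2)

# restrict from the group of 8 rotations to the subgroup of 4 rotations
ft_c4 = ft.restrict(4)
print(ft_c4.size)   # still 16 channels, but now transforming under C4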
- sorted()[source]
Return a new field type containing the fields of the current one sorted by their dimensionalities. It is built by sorting the escnn.nn.FieldType.representations of this field type.
- Returns
the sorted field type
- __add__(other)[source]
Returns a field type associated with the direct sum \(\rho = \rho_1 \oplus \rho_2\) of the representations \(\rho_1\) and \(\rho_2\) of two field types.
In practice, the method builds a new
FieldType
using the concatenation of the listsescnn.nn.FieldType.representations
of the two field types.The two field types need to be associated with the same
GSpace
.- Parameters
other (FieldType) – the other addend
- Returns
the direct sum
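For instance:

import escnn.gspaces
import escnn.nn

gs = escnn.gspaces.rot2dOnR2(4)
c1 = escnn.nn.FieldType(gs, [gs.trivial_repr] * 2)
c2 = escnn.nn.FieldType(gs, [gs.regular_repr])

# direct sum: concatenation of the two lists of representations
c3 = c1 + c2
print(len(c3))    # 3 fields
print(c3.size)    # 2 + 4 = 6 channels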
- __len__()[source]
Return the number of feature fields in this
FieldType
, i.e. the length ofescnn.nn.FieldType.representations
.Note
This is in general different from
escnn.nn.FieldType.size
.- Returns
the number of fields in this type
- fields_names()[source]
Return an ordered list containing the names of the representations associated with each field.
- Returns
the list of fields’ representations’ names
- index_select(index)[source]
Build a new
FieldType
from the current one by taking theRepresentation
s selected by the inputindex
.- Parameters
index (list) – a list of integers in the range
{0, ..., N-1}
, whereN
is the number of representations in the current field type- Returns
the new field type
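For instance:

import escnn.gspaces
import escnn.nn

gs = escnn.gspaces.rot2dOnR2(4)
ft = escnn.nn.FieldType(gs, [gs.regular_repr, gs.irrep(1), gs.trivial_repr])

# keep only the first and the last field
sub = ft.index_select([0, 2])
print(len(sub))   # 2 fields
print(sub.size)   # 4 + 1 = 5 channels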
- property fields_end: numpy.ndarray
Array containing the index of the first channel following each field. More precisely, the integer in the \(i\)-th position is equal to the index of the last channel of the \(i\)-th field plus \(1\).
- property fields_start: numpy.ndarray
Array containing the index of the first channel of each field. More precisely, the integer in the \(i\)-th position is equal to the index of the first channel of the \(i\)-th field.
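For instance, reusing the mixed field type from the examples above:

import escnn.gspaces
import escnn.nn

gs = escnn.gspaces.rot2dOnR2(4)
ft = escnn.nn.FieldType(gs, [gs.regular_repr, gs.irrep(1), gs.trivial_repr])

print(ft.fields_start)   # [0 4 6]
print(ft.fields_end)     # [4 6 7]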
- group_by_labels(labels)[source]
Associate a label to each feature field (or representation in
escnn.nn.FieldType.representations
) and group them accordingly into newFieldType
s.- Parameters
labels (list) – a list of strings with length equal to the number of representations in
escnn.nn.FieldType.representations
- Returns
a dictionary mapping each different input label to a new field type
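A small sketch, using arbitrary labels:

import escnn.gspaces
import escnn.nn

gs = escnn.gspaces.rot2dOnR2(4)
ft = escnn.nn.FieldType(gs, [gs.regular_repr, gs.irrep(1), gs.regular_repr])

# one new FieldType per distinct label
grouped = ft.group_by_labels(['reg', 'vec', 'reg'])
print(len(grouped['reg']))   # 2 regular fields
print(len(grouped['vec']))   # 1 vector field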
- property uniform: bool
Whether this FieldType contains only copies of the same representation, i.e. if all the elements of
representations
are the sameescnn.group.Representation
.
- __iter__()[source]
It is possible to iterate over all
representations
in a field type by usingFieldType
as an iterable object.
- property testing_elements
Alias for
self.gspace.testing_elements
.
Geometric Tensor
- class GeometricTensor(tensor, type, coords=None)[source]
A GeometricTensor can be interpreted as a typed tensor. It wraps a common torch.Tensor and endows it with a (compatible) FieldType as its transformation law.
The
FieldType
describes the action of a group \(G\) on the tensor. This action includes both a transformation of the base space and a transformation of the channels according to a \(G\)-representation \(\rho\).All escnn neural network operations have
GeometricTensor
s as inputs and outputs. They perform a dynamic typechecking, ensuring that the transformation laws of the data and the operation match. See alsoEquivariantModule
.As usual, the first dimension of the tensor is interpreted as the batch dimension. The second is the fiber (or channel) dimension, which is associated with a group representation by the field type. The following dimensions are the spatial dimensions (like in a conventional CNN).
In addition, this class accepts an optional argument coords. If the argument is not passed, the input tensor is assumed to have shape (batchsize, channels, *spatial_grid_shape) and to represent features sampled on a grid of shape spatial_grid_shape; in this case, the action on the base space resamples the transformed features on this grid (using interpolation, if necessary). If coords is not None, tensor is assumed to be a (#points, channels) tensor containing an unstructured set of points living on the base space; then, coords should contain the coordinates of these points. The tensor coords must have shape (#points, type.gspace.dimensionality), where #points = tensor.shape[0]. The base space action will then transform these coordinates (no interpolation required). See also the method escnn.nn.FieldType.transform() for more details.
The operations of addition and scalar multiplication are supported. For example:
gs = escnn.gspaces.rot2dOnR2(8)
type = escnn.nn.FieldType(gs, [gs.regular_repr]*3)

t1 = escnn.nn.GeometricTensor(torch.randn(1, 24, 3, 3), type)
t2 = escnn.nn.GeometricTensor(torch.randn(1, 24, 3, 3), type)

# addition
t3 = t1 + t2

# scalar product
t3 = t1 * 3.

# scalar product also supports tensors containing only one scalar
t3 = t1 * torch.tensor(3.)

# inplace operations are also supported
t1 += t2
t2 *= 3.
Warning
The multiplication of a PyTorch tensor containing only a scalar with a GeometricTensor is only supported when using PyTorch 1.4 or higher (see this issue )
Warning
These operations are only supported for compatible tensors, i.e. tensors which share the same type and, if set, coords.
A GeometricTensor can also be right-multiplied with a GroupElement using the @ operator. The group element must belong to the fibergroup of the FieldType type of this tensor. The right-multiplication through @ only transforms the channels using the group representation type.representation but does not transform the base space. This operation is equivalent to transform_fibers(). Check its documentation for more details.
A GeometricTensor supports slicing in a similar way to PyTorch's torch.Tensor. More precisely, slicing along the batch (1st) and the spatial (3rd, 4th, …) dimensions works as usual. However, slicing the fiber (2nd) dimension would break equivariance when splitting channels belonging to the same field. To prevent this, slicing on the second dimension is defined over fields instead of channels. Note that, if coords is not None, slicing over the first dimension also slices the coords tensor over its first dimension. Moreover, a GeometricTensor also partially supports advanced indexing (see NumPy's documentation about indexing for more details).
Warning
In contrast to NumPy and PyTorch, an index containing a single integer value does not reduce the dimensionality of the tensor. In this way, the resulting tensor can always be interpreted as a GeometricTensor.
We give a few examples to illustrate this behavior:
# Example of GeometricTensor slicing
space = escnn.gspaces.rot2dOnR2(4)
type = escnn.nn.FieldType(space, [
    # field type          # index # size
    space.regular_repr,   # 0     # 4
    space.regular_repr,   # 1     # 4
    space.irrep(1),       # 2     # 2
    space.irrep(1),       # 3     # 2
    space.trivial_repr,   # 4     # 1
    space.trivial_repr,   # 5     # 1
    space.trivial_repr,   # 6     # 1
])                                # sum = 15

# this FieldType contains 7 fields
len(type)
>> 7

# the size of this FieldType is equal to the sum of the sizes of each of its fields
type.size
>> 15

geom_tensor = escnn.nn.GeometricTensor(torch.randn(10, type.size, 9, 9), type)

# or, equivalently:
geom_tensor = type(torch.randn(10, type.size, 9, 9))

geom_tensor.shape
>> torch.Size([10, 15, 9, 9])

geom_tensor[1:3, :, 2:5, 2:5].shape
>> torch.Size([2, 15, 3, 3])

geom_tensor[..., 2:5].shape
>> torch.Size([10, 15, 9, 3])

# the tensor contains the fields 1:4, i.e. 1, 2 and 3
# these fields have size, respectively, 4, 2 and 2
# so the resulting tensor has 8 channels
geom_tensor[:, 1:4, ...].shape
>> torch.Size([10, 8, 9, 9])

# the tensor contains the fields 0:6:2, i.e. 0, 2 and 4
# these fields have size, respectively, 4, 2 and 1
# so the resulting tensor has 7 channels
geom_tensor[:, 0:6:2].shape
>> torch.Size([10, 7, 9, 9])

# the tensor contains only the field 2, which has size 2
# note, also, that even though a single index is used for the batch dimension, the resulting tensor
# still has 4 dimensions
geom_tensor[3, 2].shape
>> torch.Size([1, 2, 9, 9])

# we can use a boolean tensor to perform advanced indexing over the fields of a Geometric Tensor.
# the tensor contains only the two fields of size 2, i.e. the third and the fourth.
# moreover, we also slice the batch dimension.
idx = torch.tensor([type.representations[i].size == 2 for i in range(len(type))])
geom_tensor[0:4, idx].shape
>> torch.Size([4, 4, 9, 9])

# it is also possible to use an integer tensor to perform advanced indexing.
# Here, we select the second and the sixth fields.
idx = torch.tensor([1, 5])
geom_tensor[0:4, idx].shape
>> torch.Size([4, 5, 9, 9])

# advanced indexing over the other dimensions works as usual.
idx = torch.tensor([1, 2, 2, 8, 4])
geom_tensor[idx, 1:3].shape
>> torch.Size([5, 6, 9, 9])
Warning
Slicing over the fiber (2nd) dimension with
step > 1
or with a negative step is converted into indexing over the channels. This means that, in these cases, slicing behaves like advanced indexing in PyTorch and NumPy returning a copy instead of a view. For more details, see the note here and NumPy’s docs .Note
Slicing is not supported for setting values inside the tensor (i.e.
__setitem__()
is not implemented). Indeed, depending on the values which are assigned, this operation can break the symmetry of the tensor which may not transform anymore according to its transformation law (specified bytype
). In case this feature is necessary, one can directly access the underlyingtorch.Tensor
, e.g.geom_tensor.tensor[:3, :, 2:5, 2:5] = torch.randn(3, 4, 3, 3)
, although this is not recommended.Warning
Using advanced indexing over multiple axes simultaneously drops the corresponding dimensions (except for the first one). This means that the resulting tensor will likely not have a shape compatible with a Geometric Tensor (e.g. the spatial dimensions are dropped), causing an error.
- Parameters
tensor (torch.Tensor) – the tensor data
type (FieldType) – the type of the tensor, modeling its transformation law
coords (torch.Tensor, optional) – a tensor containing the coordinates of the datapoints
- Variables
~.tensor (torch.Tensor) –
~.type (FieldType) –
~.coords (torch.Tensor) –
- restrict(id)[source]
Restrict the field type of this tensor.
The method returns a new
GeometricTensor
whosetype
is equal to this tensor’stype
restricted to a subgroup \(H<G\) (seeescnn.nn.FieldType.restrict()
). The restrictedtype
is associated with the restricted representation \(\Res{H}{G}\rho\) of the \(G\)-representation \(\rho\) associated to this tensor’stype
. The inputid
specifies the subgroup \(H < G\).Notice that the underlying
tensor
instance will be shared between the current tensor and the returned one.Warning
The method builds the new representation on the fly; hence, if this operation is needed at run time, we suggest to use
escnn.nn.RestrictionModule
which pre-computes the new representation offline.See also
Check the documentation of the
restrict()
method in theGSpace
instance used for a description of the parameterid
.- Parameters
id – the id identifying the subgroup \(H\) the representations are restricted to
- Returns
the geometric tensor with the restricted representations
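A brief sketch, again assuming the integer subgroup id convention of the rotational GSpace:

import torch
import escnn.gspaces
import escnn.nn

gs = escnn.gspaces.rot2dOnR2(8)
ft = escnn.nn.FieldType(gs, [gs.regular_repr])

t = ft(torch.randn(4, ft.size, 7, 7))

# restrict to the subgroup of 4 rotations
t_c4 = t.restrict(4)
print(t_c4.type.size)             # still 8 channels
assert t_c4.tensor is t.tensor    # the underlying torch.Tensor instance is shared (see above)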
- split(breaks)[source]
Split this tensor on the channel dimension into a list of smaller tensors. The original tensor is split at the fields specified by the index list
breaks
.If the tensor is associated with the list of fields \(\{\rho_i\}_i\) (see
escnn.nn.FieldType.representations
), the \(j\)-th output tensor will contain the fields \(\rho_{\text{breaks}[j-1]}, \dots, \rho_{\text{breaks}[j]-1}\) of the original tensor.If breaks = None, the tensor is split at each field. This is equivalent to using breaks = list(range(len(self.type))).
Example
space = escnn.gspaces.rot2dOnR2(4)
type = escnn.nn.FieldType(space, [
    space.regular_repr,  # size = 4
    space.regular_repr,  # size = 4
    space.irrep(1),      # size = 2
    space.irrep(1),      # size = 2
    space.trivial_repr,  # size = 1
    space.trivial_repr,  # size = 1
    space.trivial_repr,  # size = 1
])                       # sum = 15

type.size
>> 15

geom_tensor = escnn.nn.GeometricTensor(torch.randn(10, type.size, 7, 7), type)

geom_tensor.shape
>> torch.Size([10, 15, 7, 7])

# split the tensor in 3 parts
len(geom_tensor.split([0, 4, 6]))
>> 3

# the first part contains
# - the first 2 regular fields (2*4 = 8 channels)
# - 2 vector fields (irrep(1)) (2*2 = 4 channels)
# and, therefore, contains 12 channels
geom_tensor.split([0, 4, 6])[0].shape
>> torch.Size([10, 12, 7, 7])

# the second part contains only 2 scalar (trivial) fields (2*1 = 2 channels)
geom_tensor.split([0, 4, 6])[1].shape
>> torch.Size([10, 2, 7, 7])

# the last part contains only 1 scalar (trivial) field (1*1 = 1 channel)
geom_tensor.split([0, 4, 6])[2].shape
>> torch.Size([10, 1, 7, 7])
- Parameters
breaks (list) – indices of the fields where to split the tensor
- Returns
list of
GeometricTensor
s into which the original tensor is chunked
- transform(element)[source]
Transform the current tensor according to the group representation associated with the input element and its induced action on the base space.
Warning
The input tensor is detached before the transformation, therefore no gradient is backpropagated through this operation.
See
escnn.nn.GeometricTensor.transform_fibers()
to transform only the fibers, i.e. not transform the base space.- Parameters
element (GroupElement) – an element of the group of symmetries of the fiber.
- Returns
the transformed tensor
- transform_fibers(element)[source]
Transform the feature vectors of the underlying tensor according to the group representation associated to the input element.
Interpreting the tensor as a vector-valued signal \(f: X \to \mathbb{R}^c\) over a base space \(X\) (where \(c\) is the number of channels of the tensor), given the input
element
\(g \in G\) (\(G\) fiber group) the method returns the new signal \(f'\):\[f'(x) := \rho(g) f(x)\]for \(x \in X\) point in the base space and \(\rho\) the representation of \(G\) in the field type of this tensor.
Notice that the input element has to be an element of the fiber group of this tensor’s field type.
See also
See
escnn.nn.GeometricTensor.transform()
to transform the whole tensor.- Parameters
element (GroupElement) – an element of the group of symmetries of the fiber.
- Returns
the transformed tensor
- property shape
Alias for
self.tensor.shape
- to(*args, **kwargs)[source]
Alias for
self.tensor.to(*args, **kwargs)
.Applies
torch.Tensor.to()
to the underlying tensor and wraps the resulting tensor in a newGeometricTensor
with the same type.Warning
This method does not affect
self.coords
.
- __getitem__(slices)[source]
A GeometricTensor supports slicing in a similar way to PyTorch’s
torch.Tensor
. More precisely, slicing along the batch (1st) and the spatial (3rd, 4th, …) dimensions works as usual. However, slicing along the channel dimension could break equivariance by splitting the channels belonging to the same field. For this reason, slicing on the second dimension is not defined over the channels but over fields.When a continuous (step=1) slice is used over the fields/channels dimension (the 2nd axis), it is converted into a continuous slice over the channels. This is not possible when the step is greater than 1 or negative. In such cases, the slice over the fields needs to be converted into an index over the channels.
Moreover, when a single integer is used to index an axis, that axis is not discarded as in PyTorch but is preserved with size 1.
If slicing is performed over the batch dimension and coords is not None, also the coords tensor is sliced.
Slicing is not supported for setting values inside the tensor (i.e.
object.__setitem__()
is not implemented).
- __add__(other)[source]
Add two compatible
GeometricTensors using pointwise addition. The two tensors need to have the same shape and be associated with the same field type.
- Parameters
other (GeometricTensor) – the other geometric tensor
- Returns
the sum
- __sub__(other)[source]
Subtract two compatible
GeometricTensors using pointwise subtraction. The two tensors need to have the same shape and be associated with the same field type.
- Parameters
other (GeometricTensor) – the other geometric tensor
- Returns
their difference
- __iadd__(other)[source]
Add a compatible
GeometricTensor to this tensor in place. The two tensors need to have the same shape and be associated with the same field type.
- Parameters
other (GeometricTensor) – the other geometric tensor
- Returns
this tensor
- __isub__(other)[source]
Subtract a compatible
GeometricTensor from this tensor in place. The two tensors need to have the same shape and be associated with the same field type.
- Parameters
other (GeometricTensor) – the other geometric tensor
- Returns
this tensor
- __mul__(other)[source]
Scalar product of this
GeometricTensor
with a scalar.The scalar can be a float number or a
torch.Tensor
containing only one scalar (i.e.torch.numel()
should return 1).
- __rmul__(other)
Scalar product of this
GeometricTensor
with a scalar.The scalar can be a float number or a
torch.Tensor
containing only one scalar (i.e.torch.numel()
should return 1).
- __imul__(other)[source]
Scalar product of this
GeometricTensor
with a scalar. The operation is done in place.
The scalar can be a float number or a
torch.Tensor
containing only one scalar (i.e.torch.numel()
should return 1).
- __rmatmul__(other)[source]
Equivalent to
escnn.nn.GeometricTensor.transform_fibers()
.Warning
This only transforms the fibers, not the basespace.
Todo
should this be like .transform(element) ?
Equivariant Module
- class EquivariantModule[source]
Abstract base class for all equivariant modules.
An EquivariantModule is a subclass of torch.nn.Module. It follows that any subclass of EquivariantModule needs to implement the forward() method. With respect to a general torch.nn.Module, an equivariant module implements a typed function, as both its input and its output are associated with specific FieldTypes. Therefore, usually, the inputs and the outputs of an equivariant module are not just instances of torch.Tensor but GeometricTensors.
As a subclass of torch.nn.Module, it supports most of the commonly used methods (e.g. torch.nn.Module.to(), torch.nn.Module.cuda(), torch.nn.Module.train() or torch.nn.Module.eval())
export()
method which converts the module to eval mode and returns a pure PyTorch implementation of it. This can be used after training to efficiently deploy the model without, for instance, the overhead of the automatic type checking performed by all the modules in this library.Warning
Not all modules implement this feature yet. If the
export()
method is called in a module which does not implement it yet, aNotImplementedError
is raised. Check the documentation of each individual module to understand if the method is implemented.- Variables
~.in_type (FieldType) – type of the
GeometricTensor
expected as input~.out_type (FieldType) – type of the
GeometricTensor
returned as output
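As a minimal, hypothetical sketch of a custom module (not part of the library), the following ScaleModule multiplies every channel by a fixed scalar; scalar multiplication commutes with any representation, so the transformation law is preserved. It implements the abstract methods documented below:

import torch
from escnn.nn import EquivariantModule, FieldType, GeometricTensor

class ScaleModule(EquivariantModule):

    def __init__(self, in_type: FieldType, factor: float = 2.):
        super().__init__()
        self.in_type = in_type
        self.out_type = in_type   # the transformation law is unchanged
        self.factor = factor

    def forward(self, x: GeometricTensor) -> GeometricTensor:
        assert x.type == self.in_type
        # scalar multiplication commutes with rho(g), so equivariance is preserved
        return GeometricTensor(x.tensor * self.factor, self.out_type, x.coords)

    def evaluate_output_shape(self, input_shape):
        # this module does not change the shape of its input
        return input_shape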
- abstract evaluate_output_shape(input_shape)[source]
Compute the shape of the output tensor which would be generated by this module when a tensor with shape
input_shape
is provided as input.- Parameters
input_shape (tuple) – shape of the input tensor
- Returns
shape of the output tensor
- check_equivariance(atol=1e-07, rtol=1e-05)[source]
Method that automatically tests the equivariance of the current module. The default implementation of this method relies on
escnn.nn.GeometricTensor.transform()
and uses the group elements in testing_elements
.This method can be overwritten for custom tests.
- Returns
a list containing, for each testing element, a pair with that element and the corresponding equivariance error
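For instance, one can sketch an equivariance check of a single layer as follows (R2Conv is documented later on this page; the printed errors should be small, up to discretization effects):

import escnn.gspaces
import escnn.nn

gs = escnn.gspaces.rot2dOnR2(8)
in_type = escnn.nn.FieldType(gs, [gs.trivial_repr])
out_type = escnn.nn.FieldType(gs, [gs.regular_repr])

conv = escnn.nn.R2Conv(in_type, out_type, kernel_size=3)

# returns a list of (group element, equivariance error) pairs
for g, err in conv.check_equivariance():
    print(g, err)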
- export()[source]
Recursively export each submodule to a normal PyTorch module and set it to “eval” mode.
Warning
Not all modules implement this feature yet. If the
export()
method is called in a module which does not implement it yet, aNotImplementedError
is raised. Check the documentation of each individual module to understand if the method is implemented.Warning
Since most modules do not use the coords attribute of the input GeometricTensor, once converted, they will only expect a tensor (but no coords) as input. There is no standard behavior for modules that explicitly use coords, so check their specific documentation.
Utils
direct sum
- tensor_directsum(tensors)[source]
Concatenate a list of
GeometricTensor
s on the channels dimension (dim=1
). The input tensors have to be compatible: they need to have the same shape except for the channels dimension (dim=1
); additionally, if their coords attribute is not None, they also need to share the same value.In the resulting
GeometricTensor
, the channels dimension will be associated with the direct sum representation of the representations of the input tensors.
- Parameters
tensors (list) – a list of
GeometricTensor
s- Returns
the direct sum of the inputs
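A short sketch:

import torch
import escnn.gspaces
from escnn.nn import FieldType, tensor_directsum

gs = escnn.gspaces.rot2dOnR2(4)
t1 = FieldType(gs, [gs.regular_repr])(torch.randn(3, 4, 7, 7))
t2 = FieldType(gs, [gs.trivial_repr] * 2)(torch.randn(3, 2, 7, 7))

# concatenation along the channels dimension
t = tensor_directsum([t1, t2])
print(t.shape)        # torch.Size([3, 6, 7, 7])
print(len(t.type))    # 3 fields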
Linear Layers
Linear
- class Linear(in_type, out_type, bias=True, basisexpansion='blocks', recompute=False, initialize=True)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
G-equivariant linear transformation mapping between the input and output
FieldType
s specified by the parametersin_type
andout_type
. This operation is equivariant under the action of theescnn.nn.FieldType.fibergroup
\(G\) ofin_type
andout_type
.Specifically, let \(\rho_\text{in}: G \to \GL{\R^{c_\text{in}}}\) and \(\rho_\text{out}: G \to \GL{\R^{c_\text{out}}}\) be the representations specified by the input and output field types. Then
Linear
guarantees an equivariant mapping\[W \rho_\text{in}(g) v = \rho_\text{out}(g) W v \qquad\qquad \forall g \in G, v \in \R^{c_\text{in}}\]where \(\rho_\text{in}\) and \(\rho_\text{out}\) are the \(G\)-representations associated with
in_type
andout_type
.The equivariance of a G-equivariant linear layer is guaranteed by restricting the space of weight matrices to an equivariant subspace.
During training, in each forward pass the module expands the basis of G-equivariant matrices with learned weights before performing the linear transformation. When
eval()
is called, the matrix is built with the current trained weights and stored for future reuse such that no overhead of expanding the matrix remains.Warning
When
train()
is called, the attributesmatrix
andexpanded_bias
are discarded to avoid situations of mismatch with the learnable expansion coefficients. See alsoescnn.nn.Linear.train()
.This behaviour can cause problems when storing the
state_dict()
of a model while in one mode and later loading it into a model in a different mode, as the attributes of the class change. To avoid this issue, we recommend converting the model to eval mode before storing or loading the state dictionary.
Warning
To ensure compatibility with both
torch.nn.Linear
andGeometricTensor
, this module supports only input tensors with two dimensions(batch_size, number_features)
.
The learnable expansion coefficients of this module can be initialized with the methods in
escnn.nn.init
. By default, the weights are initialized in the constructors usinggeneralized_he_init()
.Warning
This initialization procedure can be extremely slow for wide layers. In case initializing the model is not required (e.g. before loading the state dict of a pre-trained model) or another initialization method is preferred (e.g.
deltaorthonormal_init()
), the parameterinitialize
can be set toFalse
to avoid unnecessary overhead.- Parameters
in_type (FieldType) – the type of the input field, specifying its transformation law
out_type (FieldType) – the type of the output field, specifying its transformation law
bias (bool, optional) – Whether to add a bias to the output (only to fields which contain a trivial irrep) or not. Default
True
basisexpansion (str, optional) – the basis expansion algorithm to use. You can ignore this attribute.
recompute (bool, optional) – if
True
, recomputes a new basis for the equivariant kernels. By Default (False
), it caches the basis built or reuse a cached one, if it is found.initialize (bool, optional) – initialize the weights of the model. Default:
True
- Variables
~.weights (torch.Tensor) – the learnable parameters which are used to expand the matrix
~.matrix (torch.Tensor) – the matrix obtained by expanding the parameters in
weights
~.bias (torch.Tensor) – the learnable parameters which are used to expand the bias, if
bias=True
~.expanded_bias (torch.Tensor) – the equivariant bias which is summed to the output, obtained by expanding the parameters in
bias
- forward(input)[source]
Apply the expanded matrix and bias to the input.
- Parameters
input (GeometricTensor) – input feature field transforming according to
in_type
- Returns
output feature field transforming according to
out_type
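A usage sketch; it assumes escnn.gspaces.no_base_space() builds a GSpace without a base space (the natural setting for point-like features) and escnn.group.cyclic_group() builds the fiber group:

import torch
import escnn.group
import escnn.gspaces
import escnn.nn

g = escnn.group.cyclic_group(8)          # assumed constructor of the C8 group
gs = escnn.gspaces.no_base_space(g)      # assumed: GSpace with no base space

in_type = escnn.nn.FieldType(gs, [gs.trivial_repr] * 4)
out_type = escnn.nn.FieldType(gs, [gs.regular_repr] * 2)

layer = escnn.nn.Linear(in_type, out_type)

x = in_type(torch.randn(10, in_type.size))   # shape (batch_size, number_features)
y = layer(x)
print(y.shape)   # torch.Size([10, 16])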
- property basisexpansion: escnn.nn.modules.basismanager.basismanager.BasisManager
Submodule which takes care of building the matrix.
It uses the learned weights to expand the matrix in the G-equivariant basis and returns it in the usual form used by conventional linear modules, i.e. with shape \((c_\text{out}, c_\text{in})\).
- expand_parameters()[source]
Expand the matrix in terms of the
escnn.nn.Linear.weights
and the expanded bias in terms ofescnn.nn.Linear.bias
.- Returns
the expanded matrix and bias
- train(mode=True)[source]
If
mode=True
, the method sets the module in training mode and discards thematrix
andexpanded_bias
attributes.If
mode=False
, it sets the module in evaluation mode. Moreover, the method builds the matrix and the bias using the current values of the trainable parameters and stores them in matrix and expanded_bias such that they are not recomputed at each forward pass.
Warning
This behaviour can cause problems when storing the
state_dict()
of a model while in one mode and later loading it into a model in a different mode, as the attributes of this class change. To avoid this issue, we recommend converting the model to eval mode before storing or loading the state dictionary.
- Parameters
mode (bool, optional) – whether to set training mode (
True
) or evaluation mode (False
). Default:True
.
- export()[source]
Export this module to a normal PyTorch
torch.nn.Linear
module and set to “eval” mode.
Steerable Dense Convolution
The following modules implement discretized convolution operators over discrete grids. This means that equivariance to continuous symmetries is not perfect. In practice, by using sufficiently band-limited filters, the equivariance error introduced by the discretization of the filters and the features is contained, but some design choices may have a negative effect on the overall equivariance of the architecture.
We also provide some practical notes on using these discretized convolution modules.
RdConv
- class _RdConv(in_type, out_type, d, kernel_size, padding=0, stride=1, dilation=1, padding_mode='zeros', groups=1, bias=True, basis_filter=None, recompute=False)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
,abc.ABC
Abstract class which implements a general G-steerable convolution, mapping between the input and output
FieldType
s specified by the parametersin_type
andout_type
. This operation is equivariant under the action of \(\R^d\rtimes G\) where \(G\) is theescnn.nn.FieldType.fibergroup
ofin_type
andout_type
.Specifically, let \(\rho_\text{in}: G \to \GL{\R^{c_\text{in}}}\) and \(\rho_\text{out}: G \to \GL{\R^{c_\text{out}}}\) be the representations specified by the input and output field types. Then
_RdConv
guarantees an equivariant mapping\[\kappa \star [\mathcal{T}^\text{in}_{g,u} . f] = \mathcal{T}^\text{out}_{g,u} . [\kappa \star f] \qquad\qquad \forall g \in G, u \in \R^d\]where the transformation of the input and output fields are given by
\[\begin{split}[\mathcal{T}^\text{in}_{g,u} . f](x) &= \rho_\text{in}(g)f(g^{-1} (x - u)) \\ [\mathcal{T}^\text{out}_{g,u} . f](x) &= \rho_\text{out}(g)f(g^{-1} (x - u)) \\\end{split}\]The equivariance of G-steerable convolutions is guaranteed by restricting the space of convolution kernels to an equivariant subspace. As proven in 3D Steerable CNNs, this parametrizes the most general equivariant convolutional map between the input and output fields.
Warning
This class implements a discretized convolution operator over a discrete grid. This means that equivariance to continuous symmetries is not perfect. In practice, by using sufficiently band-limited filters, the equivariance error introduced by the discretization of the filters and the features is contained, but some design choices may have a negative effect on the overall equivariance of the architecture.
We provide some practical notes on using this discretized convolution module.
During training, in each forward pass the module expands the basis of G-steerable kernels with learned weights before performing the convolution. When
eval()
is called, the filter is built with the current trained weights and stored for future reuse such that no overhead of expanding the kernel remains.Warning
When
train()
is called, the attributesfilter
andexpanded_bias
are discarded to avoid situations of mismatch with the learnable expansion coefficients. See alsoescnn.nn._RdConv.train()
.This behaviour can cause problems when storing the
state_dict()
of a model while in one mode and later loading it into a model in a different mode, as the attributes of the class change. To avoid this issue, we recommend converting the model to eval mode before storing or loading the state dictionary.
- Parameters
in_type (FieldType) – the type of the input field, specifying its transformation law
out_type (FieldType) – the type of the output field, specifying its transformation law
d (int) – dimensionality of the base space (2 for images, 3 for volumes)
kernel_size (int) – the size of the (square) filter
padding (int, optional) – implicit zero paddings on both sides of the input. Default:
0
padding_mode (str, optional) –
zeros
,reflect
,replicate
orcircular
. Default:zeros
stride (int, optional) – the stride of the kernel. Default:
1
dilation (int, optional) – the spacing between kernel elements. Default:
1
groups (int, optional) – number of blocked connections from input channels to output channels. It allows depthwise convolution. When used, the input and output types need to be divisible in
groups
groups, all equal to each other. Default:1
.bias (bool, optional) – Whether to add a bias to the output (only to fields which contain a trivial irrep) or not. Default
True
basis_filter (callable, optional) – filter for the basis elements. Should take a dictionary containing an element’s attributes and return whether to keep it or not.
recompute (bool, optional) – if
True
, recomputes a new basis for the equivariant kernels. By Default (False
), it caches the basis built or reuse a cached one, if it is found.
- Variables
~.weights (torch.Tensor) – the learnable parameters which are used to expand the kernel
~.filter (torch.Tensor) – the convolutional kernel obtained by expanding the parameters in
weights
~.bias (torch.Tensor) – the learnable parameters which are used to expand the bias, if
bias=True
~.expanded_bias (torch.Tensor) – the equivariant bias which is summed to the output, obtained by expanding the parameters in
bias
- property basisexpansion: escnn.nn.modules.basismanager.basisexpansion_blocks.BlocksBasisExpansion
Submodule which takes care of building the filter.
It uses the learned weights to expand the kernel in the G-steerable basis and returns the filter in the usual form used by conventional convolutional modules, i.e. with shape \((c_\text{out}, c_\text{in}, s^d)\), where \(s\) is the kernel_size and \(d\) is the dimensionality of the base space.
- expand_parameters()[source]
Expand the filter in terms of the
weights
and the expanded bias in terms ofbias
.- Returns
the expanded filter and bias
- abstract forward(input)[source]
Convolve the input with the expanded filter and bias.
- Parameters
input (GeometricTensor) – input feature field transforming according to
in_type
- Returns
output feature field transforming according to
out_type
- train(mode=True)[source]
If
mode=True
, the method sets the module in training mode and discards thefilter
andexpanded_bias
attributes.If
mode=False
, it sets the module in evaluation mode. Moreover, the method builds the filter and the bias using the current values of the trainable parameters and stores them in filter and expanded_bias such that they are not recomputed at each forward pass.
Warning
This behaviour can cause problems when storing the
state_dict()
of a model while in one mode and later loading it into a model in a different mode, as the attributes of this class change. To avoid this issue, we recommend converting the model to eval mode before storing or loading the state dictionary.
- Parameters
mode (bool, optional) – whether to set training mode (
True
) or evaluation mode (False
). Default:True
.
R2Conv
- class R2Conv(in_type, out_type, kernel_size, padding=0, stride=1, dilation=1, padding_mode='zeros', groups=1, bias=True, sigma=None, frequencies_cutoff=None, rings=None, maximum_offset=None, recompute=False, basis_filter=None, initialize=True)[source]
Bases:
escnn.nn.modules.conv.rd_convolution._RdConv
G-steerable planar convolution mapping between the input and output
FieldType
s specified by the parametersin_type
andout_type
. This operation is equivariant under the action of \(\R^2\rtimes G\) where \(G\) is theescnn.nn.FieldType.fibergroup
ofin_type
andout_type
.Specifically, let \(\rho_\text{in}: G \to \GL{\R^{c_\text{in}}}\) and \(\rho_\text{out}: G \to \GL{\R^{c_\text{out}}}\) be the representations specified by the input and output field types. Then
R2Conv
guarantees an equivariant mapping\[\kappa \star [\mathcal{T}^\text{in}_{g,u} . f] = \mathcal{T}^\text{out}_{g,u} . [\kappa \star f] \qquad\qquad \forall g \in G, u \in \R^2\]where the transformation of the input and output fields are given by
\[\begin{split}[\mathcal{T}^\text{in}_{g,u} . f](x) &= \rho_\text{in}(g)f(g^{-1} (x - u)) \\ [\mathcal{T}^\text{out}_{g,u} . f](x) &= \rho_\text{out}(g)f(g^{-1} (x - u)) \\\end{split}\]The equivariance of G-steerable convolutions is guaranteed by restricting the space of convolution kernels to an equivariant subspace. As proven in 3D Steerable CNNs, this parametrizes the most general equivariant convolutional map between the input and output fields. For feature fields on \(\R^2\) (e.g. images), the complete G-steerable kernel spaces for \(G \leq \O2\) are derived in General E(2)-Equivariant Steerable CNNs.
Warning
This class implements a discretized convolution operator over a discrete grid. This means that equivariance to continuous symmetries is not perfect. In practice, by using sufficiently band-limited filters, the equivariance error introduced by the discretization of the filters and the features is contained, but some design choices may have a negative effect on the overall equivariance of the architecture.
We provide some practical notes on using this discretized convolution module.
During training, in each forward pass the module expands the basis of G-steerable kernels with learned weights before calling
torch.nn.functional.conv2d()
. Wheneval()
is called, the filter is built with the current trained weights and stored for future reuse such that no overhead of expanding the kernel remains.Warning
When
train()
is called, the attributesfilter
andexpanded_bias
are discarded to avoid situations of mismatch with the learnable expansion coefficients. See alsoescnn.nn.R2Conv.train()
.This behaviour can cause problems when storing the
state_dict()
of a model while in one mode and later loading it into a model in a different mode, as the attributes of the class change. To avoid this issue, we recommend converting the model to eval mode before storing or loading the state dictionary.
The learnable expansion coefficients of this module can be initialized with the methods in
escnn.nn.init
. By default, the weights are initialized in the constructors usinggeneralized_he_init()
.Warning
This initialization procedure can be extremely slow for wide layers. In case initializing the model is not required (e.g. before loading the state dict of a pre-trained model) or another initialization method is preferred (e.g.
deltaorthonormal_init()
), the parameterinitialize
can be set toFalse
to avoid unnecessary overhead. See also this issueThe parameters
basisexpansion
,sigma
,frequencies_cutoff
,rings
andmaximum_offset
are optional parameters used to control how the basis for the filters is built, how it is sampled on the filter grid and how it is expanded to build the filter. We suggest to keep these default values.Warning
Even if the input tensor has a coords attribute, the output of this module will not have one.
- Parameters
in_type (FieldType) – the type of the input field, specifying its transformation law
out_type (FieldType) – the type of the output field, specifying its transformation law
kernel_size (int) – the size of the (square) filter
padding (int, optional) – implicit zero paddings on both sides of the input. Default:
0
stride (int, optional) – the stride of the kernel. Default:
1
dilation (int, optional) – the spacing between kernel elements. Default:
1
padding_mode (str, optional) –
zeros
,reflect
,replicate
orcircular
. Default:zeros
groups (int, optional) – number of blocked connections from input channels to output channels. It allows depthwise convolution. When used, the input and output types need to be divisible in
groups
groups, all equal to each other. Default:1
.bias (bool, optional) – Whether to add a bias to the output (only to fields which contain a trivial irrep) or not. Default
True
sigma (list or float, optional) – width of each ring where the bases are sampled. If only one scalar is passed, it is used for all rings.
frequencies_cutoff (callable or float, optional) – function mapping the radii of the basis elements to the maximum frequency accepted. If a float value is passed, the maximum frequency is equal to the radius times this factor. By default (
None
), a more complex policy is used.rings (list, optional) – radii of the rings where to sample the bases
maximum_offset (int, optional) – number of additional (aliased) frequencies in the intertwiners for finite groups. By default (
None
), all additional frequencies allowed by the frequencies cut-off are used.recompute (bool, optional) – if
True
, recomputes a new basis for the equivariant kernels. By Default (False
), it caches the basis built or reuse a cached one, if it is found.basis_filter (callable, optional) – function which takes as input a descriptor of a basis element (as a dictionary) and returns a boolean value: whether to preserve (
True
) or discard (False
) the basis element. By default (None
), no filtering is applied.initialize (bool, optional) – initialize the weights of the model. Default:
True
- Variables
~.weights (torch.Tensor) – the learnable parameters which are used to expand the kernel
~.filter (torch.Tensor) – the convolutional kernel obtained by expanding the parameters in
weights
~.bias (torch.Tensor) – the learnable parameters which are used to expand the bias, if
bias=True
~.expanded_bias (torch.Tensor) – the equivariant bias which is summed to the output, obtained by expanding the parameters in
bias
- forward(input)[source]
Convolve the input with the expanded filter and bias.
- Parameters
input (GeometricTensor) – input feature field transforming according to
in_type
- Returns
output feature field transforming according to
out_type
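A brief usage sketch, mapping a scalar field (e.g. a gray-scale image) to a single vector field built from irrep(1) of the 8 rotations:

import torch
import escnn.gspaces
import escnn.nn

gs = escnn.gspaces.rot2dOnR2(8)

in_type = escnn.nn.FieldType(gs, [gs.trivial_repr])   # 1 scalar field
out_type = escnn.nn.FieldType(gs, [gs.irrep(1)])      # 1 vector field (2 channels)

conv = escnn.nn.R2Conv(in_type, out_type, kernel_size=5, padding=2)

x = in_type(torch.randn(1, 1, 32, 32))
y = conv(x)
print(y.shape)   # torch.Size([1, 2, 32, 32])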
- export()[source]
Export this module to a normal PyTorch
torch.nn.Conv2d
module and set to “eval” mode.
R3Conv
- class R3Conv(in_type, out_type, kernel_size, padding=0, stride=1, dilation=1, padding_mode='zeros', groups=1, bias=True, sigma=None, frequencies_cutoff=None, rings=None, recompute=False, basis_filter=None, initialize=True)[source]
Bases:
escnn.nn.modules.conv.rd_convolution._RdConv
G-steerable planar convolution mapping between the input and output
FieldType
s specified by the parametersin_type
andout_type
. This operation is equivariant under the action of \(\R^3\rtimes G\) where \(G\) is theescnn.nn.FieldType.fibergroup
ofin_type
andout_type
.Specifically, let \(\rho_\text{in}: G \to \GL{\R^{c_\text{in}}}\) and \(\rho_\text{out}: G \to \GL{\R^{c_\text{out}}}\) be the representations specified by the input and output field types. Then
R3Conv
guarantees an equivariant mapping\[\kappa \star [\mathcal{T}^\text{in}_{g,u} . f] = \mathcal{T}^\text{out}_{g,u} . [\kappa \star f] \qquad\qquad \forall g \in G, u \in \R^3\]where the transformation of the input and output fields are given by
\[\begin{split}[\mathcal{T}^\text{in}_{g,u} . f](x) &= \rho_\text{in}(g)f(g^{-1} (x - u)) \\ [\mathcal{T}^\text{out}_{g,u} . f](x) &= \rho_\text{out}(g)f(g^{-1} (x - u)) \\\end{split}\]The equivariance of G-steerable convolutions is guaranteed by restricting the space of convolution kernels to an equivariant subspace. As proven in 3D Steerable CNNs, this parametrizes the most general equivariant convolutional map between the input and output fields.
Warning
This class implements a discretized convolution operator over a discrete grid. This means that equivariance to continuous symmetries is not perfect. In practice, by using sufficiently band-limited filters, the equivariance error introduced by the discretization of the filters and the features is contained, but some design choices may have a negative effect on the overall equivariance of the architecture.
We provide some practical notes on using this discretized convolution module.
During training, in each forward pass the module expands the basis of G-steerable kernels with learned weights before calling
torch.nn.functional.conv3d()
. Wheneval()
is called, the filter is built with the current trained weights and stored for future reuse such that no overhead of expanding the kernel remains.Warning
When
train()
is called, the attributesfilter
andexpanded_bias
are discarded to avoid situations of mismatch with the learnable expansion coefficients. See alsoescnn.nn.R3Conv.train()
.This behaviour can cause problems when storing the
state_dict()
of a model while in one mode and later loading it into a model in a different mode, as the attributes of the class change. To avoid this issue, we recommend converting the model to eval mode before storing or loading the state dictionary.
The learnable expansion coefficients of this module can be initialized with the methods in
escnn.nn.init
. By default, the weights are initialized in the constructors usinggeneralized_he_init()
.Warning
This initialization procedure can be extremely slow for wide layers. In case initializing the model is not required (e.g. before loading the state dict of a pre-trained model) or another initialization method is preferred (e.g.
deltaorthonormal_init()
), the parameterinitialize
can be set toFalse
to avoid unnecessary overhead. See also this issueThe parameters
sigma
,frequencies_cutoff
andrings
are optional parameters used to control how the basis for the filters is built, how it is sampled on the filter grid and how it is expanded to build the filter. We suggest to keep these default values.Warning
Even if the input tensor has a coords attribute, the output of this module will not have one.
- Parameters
in_type (FieldType) – the type of the input field, specifying its transformation law
out_type (FieldType) – the type of the output field, specifying its transformation law
kernel_size (int) – the size of the (square) filter
padding (int, optional) – implicit zero paddings on both sides of the input. Default:
0
stride (int, optional) – the stride of the kernel. Default:
1
dilation (int, optional) – the spacing between kernel elements. Default:
1
padding_mode (str, optional) –
zeros
,reflect
,replicate
orcircular
. Default:zeros
groups (int, optional) – number of blocked connections from input channels to output channels. It allows depthwise convolution. When used, the input and output types need to be divisible in
groups
groups, all equal to each other. Default:1
.bias (bool, optional) – Whether to add a bias to the output (only to fields which contain a trivial irrep) or not. Default
True
sigma (list or float, optional) – width of each ring where the bases are sampled. If only one scalar is passed, it is used for all rings.
frequencies_cutoff (callable or float, optional) – function mapping the radii of the basis elements to the maximum frequency accepted. If a float value is passed, the maximum frequency is equal to the radius times this factor. By default (
None
), a more complex policy is used.rings (list, optional) – radii of the rings where to sample the bases
recompute (bool, optional) – if
True
, recomputes a new basis for the equivariant kernels. By default (False
), it caches the basis built or reuses a cached one, if it is found.basis_filter (callable, optional) – function which takes as input a descriptor of a basis element (as a dictionary) and returns a boolean value: whether to preserve (
True
) or discard (False
) the basis element. By default (None
), no filtering is applied.initialize (bool, optional) – initialize the weights of the model. Default:
True
- Variables
~.weights (torch.Tensor) – the learnable parameters which are used to expand the kernel
~.filter (torch.Tensor) – the convolutional kernel obtained by expanding the parameters in
weights
~.bias (torch.Tensor) – the learnable parameters which are used to expand the bias, if
bias=True
~.expanded_bias (torch.Tensor) – the equivariant bias which is summed to the output, obtained by expanding the parameters in
bias
- forward(input)[source]
Convolve the input with the expanded filter and bias.
- Parameters
input (GeometricTensor) – input feature field transforming according to
in_type
- Returns
output feature field transforming according to
out_type
- export()[source]
Export this module to a normal PyTorch
torch.nn.Conv3d
module and set to “eval” mode.
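For illustration, a minimal hedged sketch (the gspace, field types and sizes below are arbitrary choices, not requirements of this module) building an R3Conv, running it on a GeometricTensor and exporting it after switching to eval mode:
import torch
import escnn

# SO(3)-equivariant convolution on R^3; the field types are purely illustrative
s = escnn.gspaces.rot3dOnR3()
c_in = escnn.nn.FieldType(s, [s.trivial_repr] * 3)
c_out = escnn.nn.FieldType(s, [s.irrep(1)] * 8)      # eight frequency-1 (vector) fields

conv = escnn.nn.R3Conv(c_in, c_out, kernel_size=5, padding=2)

x = escnn.nn.GeometricTensor(torch.randn(4, c_in.size, 17, 17, 17), c_in)
y = conv(x)

# after training, switch to eval mode before exporting or (de)serializing the state dict
conv.eval()
conv_exported = conv.export()                        # a plain torch.nn.Conv3d
assert torch.allclose(conv(x).tensor, conv_exported(x.tensor), atol=1e-5)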
R3IcoConv
- class R3IcoConv(in_type, out_type, kernel_size, padding=0, stride=1, dilation=1, padding_mode='zeros', groups=1, bias=True, samples='ico', sigma=None, rings=None, recompute=False, basis_filter=None, initialize=True)[source]
Bases:
escnn.nn.modules.conv.r3convolution.R3Conv
Icosahedral-steerable volumetric convolution mapping between the input and output
FieldType
s specified by the parametersin_type
andout_type
. This operation is equivariant under the action of \(\R^3\rtimes I\) where \(I\) (theIcosahedral
group) is theescnn.nn.FieldType.fibergroup
ofin_type
andout_type
.This class is mostly similar to
R3Conv
, with the only difference that it only supports the groupIcosahedral
since it uses a kernel basis which is specific for this group.Warning
With respect to
R3Conv
, this convolution layer uses a different steerable basis to parameterize its filters. This basis does not rely on band-limited spherical harmonics but on a finite number of orbits of the Icosahedral group in \(\R^3\). This basis is less robust to discretization (see Figure 5 in our paper ). For this reason, using theR3Conv
class is often preferable.The argument
frequencies_cutoff
ofR3Conv
is not supported here since the steerable kernels are not generated from a band-limited set of harmonic functions.Instead, the argument
samples
specifies the polyhedron (symmetric with respect to theIcosahedral
group) whose vertices are used to define the kernel on \(\R^3\). The supported polyhedrons are"ico"
(the 12 vertices of the icosahedron),"dodeca"
(the 20 vertices of the dodecahedron) or"icosidodeca"
(the 30 vertices of the icosidodecahedron, which correspond to the centers of the 30 edges of either the icosahedron or the dodecahedron).For each ring
r
inrings
, the specified polyhedron is embedded in the sphere of radius r
. The analytical kernel, which is only defined on the vertices of this polyhedron, is then “diffused” in the ambient space \(\R^3\) by means of a small Gaussian kernel with stdsigma
.Warning
Even if the input tensor has a coords attribute, the output of this module will not have one.
- Parameters
in_type (FieldType) – the type of the input field, specifying its transformation law
out_type (FieldType) – the type of the output field, specifying its transformation law
kernel_size (int) – the size of the (square) filter
padding (int, optional) – implicit zero paddings on both sides of the input. Default:
0
stride (int, optional) – the stride of the kernel. Default:
1
dilation (int, optional) – the spacing between kernel elements. Default:
1
padding_mode (str, optional) –
zeros
,reflect
,replicate
orcircular
. Default:zeros
groups (int, optional) – number of blocked connections from input channels to output channels. It allows depthwise convolution. When used, the input and output types need to be divisible in
groups
groups, all equal to each other. Default:1
.bias (bool, optional) – Whether to add a bias to the output (only to fields which contain a trivial irrep) or not. Default
True
sigma (list or float, optional) – width of each ring where the bases are sampled. If only one scalar is passed, it is used for all rings.
rings (list, optional) – radii of the rings where to sample the bases
recompute (bool, optional) – if
True
, recomputes a new basis for the equivariant kernels. By default (False
), it caches the basis built or reuses a cached one, if it is found.basis_filter (callable, optional) – function which takes as input a descriptor of a basis element (as a dictionary) and returns a boolean value: whether to preserve (
True
) or discard (False
) the basis element. By default (None
), no filtering is applied.initialize (bool, optional) – initialize the weights of the model. Default:
True
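As a hedged sketch (assuming the escnn.gspaces.icoOnR3() factory and using arbitrary field types and sizes), a layer sampling its kernel basis on the vertices of the dodecahedron can be built as follows:
import torch
import escnn

s = escnn.gspaces.icoOnR3()                          # icosahedral symmetry on R^3
c_in = escnn.nn.FieldType(s, [s.trivial_repr] * 2)
c_out = escnn.nn.FieldType(s, [s.regular_repr])      # one 60-dimensional regular field

# sample the analytical kernel on the 20 vertices of the dodecahedron
conv = escnn.nn.R3IcoConv(c_in, c_out, kernel_size=5, padding=2, samples='dodeca')

x = escnn.nn.GeometricTensor(torch.randn(1, c_in.size, 13, 13, 13), c_in)
y = conv(x)
assert y.type == c_out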
R2ConvTransposed
- class R2ConvTransposed(in_type, out_type, kernel_size, padding=0, output_padding=0, stride=1, dilation=1, groups=1, bias=True, sigma=None, frequencies_cutoff=None, rings=None, maximum_offset=None, recompute=False, basis_filter=None, initialize=True)[source]
Bases:
escnn.nn.modules.conv.rd_transposed_convolution._RdConvTransposed
Transposed G-steerable planar convolution layer.
Warning
This class implements a discretized convolution operator over a discrete grid. This means that equivariance to continuous symmetries is not perfect. In practice, by using sufficiently band-limited filters, the equivariance error introduced by the discretization of the filters and the features is contained, but some design choices may have a negative effect on the overall equivariance of the architecture.
We provide some practical notes on using this discretized convolution module.
Warning
Transposed convolution can produce artifacts which can harm the overall equivariance of the model. We suggest using
R2Upsampling
combined withR2Conv
to perform upsampling.See also
For additional information about the parameters and the methods of this class, see
escnn.nn.R2Conv
. The two modules are essentially the same, except for the type of convolution used.Warning
Even if the input tensor has a coords attribute, the output of this module will not have one.
- Parameters
in_type (FieldType) – the type of the input field
out_type (FieldType) – the type of the output field
kernel_size (int) – the size of the filter
padding (int, optional) – implicit zero paddings on both sides of the input. Default:
0
output_padding (int, optional) – implicit zero paddings on both sides of the input. Default:
0
stride (int, optional) – the stride of the convolving kernel. Default:
1
dilation (int, optional) – the spacing between kernel elements. Default:
1
groups (int, optional) – number of blocked connections from input channels to output channels. Default:
1
.bias (bool, optional) – Whether to add a bias to the output (only to fields which contain a trivial irrep) or not. Default
True
initialize (bool, optional) – initialize the weights of the model. Default:
True
- export()[source]
Export this module to a normal PyTorch
torch.nn.ConvTranspose2d
module and set to “eval” mode.
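Following the warning above, a minimal sketch of the suggested alternative, i.e. R2Upsampling followed by R2Conv (the field types and sizes are arbitrary illustrative choices):
import escnn

s = escnn.gspaces.rot2dOnR2(8)
c = escnn.nn.FieldType(s, [s.regular_repr] * 4)

# upsample first, then apply an ordinary steerable convolution
upsample_block = escnn.nn.SequentialModule(
    escnn.nn.R2Upsampling(c, scale_factor=2),
    escnn.nn.R2Conv(c, c, kernel_size=3, padding=1),
)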
R3ConvTransposed
- class R3ConvTransposed(in_type, out_type, kernel_size, padding=0, output_padding=0, stride=1, dilation=1, groups=1, bias=True, sigma=None, frequencies_cutoff=None, rings=None, maximum_offset=None, recompute=False, basis_filter=None, initialize=True)[source]
Bases:
escnn.nn.modules.conv.rd_transposed_convolution._RdConvTransposed
Transposed G-steerable 3D convolution layer.
Warning
This class implements a discretized convolution operator over a discrete grid. This means that equivariance to continuous symmetries is not perfect. In practice, by using sufficiently band-limited filters, the equivariance error introduced by the discretization of the filters and the features is contained, but some design choices may have a negative effect on the overall equivariance of the architecture.
We provide some practical notes on using this discretized convolution module.
Warning
Transposed convolution can produce artifacts which can harm the overall equivariance of the model. We suggest using
R3Upsampling
combined withR3Conv
to perform upsampling.See also
For additional information about the parameters and the methods of this class, see
escnn.nn.R3Conv
. The two modules are essentially the same, except for the type of convolution used.Warning
Even if the input tensor has a coords attribute, the output of this module will not have one.
- Parameters
in_type (FieldType) – the type of the input field
out_type (FieldType) – the type of the output field
kernel_size (int) – the size of the filter
padding (int, optional) – implicit zero paddings on both sides of the input. Default:
0
output_padding (int, optional) – implicit zero paddings on both sides of the input. Default:
0
stride (int, optional) – the stride of the convolving kernel. Default:
1
dilation (int, optional) – the spacing between kernel elements. Default:
1
groups (int, optional) – number of blocked connections from input channels to output channels. Default:
1
.bias (bool, optional) – Whether to add a bias to the output (only to fields which contain a trivial irrep) or not. Default
True
initialize (bool, optional) – initialize the weights of the model. Default:
True
- export()[source]
Export this module to a normal PyTorch
torch.nn.ConvTranspose3d
module and set to “eval” mode.
R3IcoConvTransposed
- class R3IcoConvTransposed(in_type, out_type, kernel_size, padding=0, output_padding=0, stride=1, dilation=1, groups=1, bias=True, samples='ico', sigma=None, rings=None, recompute=False, basis_filter=None, initialize=True)[source]
Bases:
escnn.nn.modules.conv.r3_transposed_convolution.R3ConvTransposed
Icosahedral-steerable volumetric convolution mapping between the input and output
FieldType
s specified by the parametersin_type
andout_type
. This operation is equivariant under the action of \(\R^3\rtimes I\) where \(I\) (theIcosahedral
group) is theescnn.nn.FieldType.fibergroup
ofin_type
andout_type
.This class is mostly similar to
R3ConvTransposed
, with the only difference that it only supports the groupIcosahedral
since it uses a kernel basis which is specific for this group.The argument
frequencies_cutoff
ofR3ConvTransposed
is not supported here since the steerable kernels are not generated from a band-limited set of harmonic functions.Instead, the argument
samples
specifies the polyhedron (symmetric with respect to theIcosahedral
group) whose vertices are used to define the kernel on \(\R^3\). The supported polyhedrons are"ico"
(the 12 vertices of the icosahedron),"dodeca"
(the 20 vertices of the dodecahedron) or"icosidodeca"
(the 30 vertices of the icosidodecahedron, which correspond to the centers of the 30 edges of either the icosahedron or the dodecahedron).For each ring
r
inrings
, the specified polyhedron is embedded in the sphere of radius r
. The analytical kernel, which is only defined on the vertices of this polyhedron, is then “diffused” in the ambient space \(\R^3\) by means of a small Gaussian kernel with stdsigma
.Warning
Even if the input tensor has a coords attribute, the output of this module will not have one.
- Parameters
in_type (FieldType) – the type of the input field, specifying its transformation law
out_type (FieldType) – the type of the output field, specifying its transformation law
kernel_size (int) – the size of the (square) filter
padding (int, optional) – implicit zero paddings on both sides of the input. Default:
0
output_padding (int, optional) – implicit zero paddings on both sides of the input. Default:
0
stride (int, optional) – the stride of the kernel. Default:
1
dilation (int, optional) – the spacing between kernel elements. Default:
1
groups (int, optional) – number of blocked connections from input channels to output channels. It allows depthwise convolution. When used, the input and output types need to be divisible in
groups
groups, all equal to each other. Default:1
.bias (bool, optional) – Whether to add a bias to the output (only to fields which contain a trivial irrep) or not. Default
True
sigma (list or float, optional) – width of each ring where the bases are sampled. If only one scalar is passed, it is used for all rings.
rings (list, optional) – radii of the rings where to sample the bases
recompute (bool, optional) – if
True
, recomputes a new basis for the equivariant kernels. By default (False
), it caches the basis built or reuses a cached one, if it is found.basis_filter (callable, optional) – function which takes as input a descriptor of a basis element (as a dictionary) and returns a boolean value: whether to preserve (
True
) or discard (False
) the basis element. By default (None
), no filtering is applied.initialize (bool, optional) – initialize the weights of the model. Default:
True
Steerable Point Convolution
RdPointConv
- class _RdPointConv(in_type, out_type, d, groups=1, bias=True, basis_filter=None, recompute=False)[source]
Bases:
torch_geometric.nn.conv.message_passing.MessagePassing
,escnn.nn.modules.equivariant_module.EquivariantModule
,abc.ABC
Abstract class which implements a general G-steerable convolution, mapping between the input and output
FieldType
s specified by the parametersin_type
andout_type
. This operation is equivariant under the action of \(\R^d\rtimes G\) where \(G\) is theescnn.nn.FieldType.fibergroup
ofin_type
andout_type
.This class implements convolution with steerable filters over sparse geometric graphs. Instead,
_RdConv
implements an equivalent convolution layer over a pixel/voxel grid. See the documentation of_RdConv
for more details about equivariance and steerable convolution.The input of this module is a geometric graph, i.e. a graph whose nodes are associated with
d
-dimensional coordinates in \(\R^d\). The nodes’ coordinates should be stored in thecoords
attribute of the inputGeometricTensor
. The adjacency of the graph should be passed as a second input tensoredge_index
, like commonly done inMessagePassing
. Seeforward()
.In each forward pass, the module computes the relative coordinates of the points on the edges and samples each filter in the basis of G-steerable kernels at these relative locations. The basis filters are expanded using the learnable weights and used to perform convolution over the graph in the message passing framework. Optionally, the relative coordinates can be pre-computed and passed in the input
edge_delta
tensor.Note
In practice, we first apply the basis filters on the input features and then combine the responses via the learnable weights. See also
compute_messages()
.Warning
When eval() is called, the bias is built with the current trained weights and stored for future reuse, so that no overhead of expanding the bias remains.

When train() is called, the attribute expanded_bias is discarded to avoid mismatches with the learnable expansion coefficients. See also escnn.nn.modules.pointconv._RdPointConv.train().

This behaviour can cause problems when storing the state_dict() of a model in one mode and later loading it into a model in a different mode, as the attributes of the class change. To avoid this issue, we recommend converting the model to eval mode before storing or loading the state dictionary.

Warning
We don’t support
groups > 1
yet. We include this parameter for future compatibility.- Parameters
in_type (FieldType) – the type of the input field, specifying its transformation law
out_type (FieldType) – the type of the output field, specifying its transformation law
d (int) – dimensionality of the base space (2 for images, 3 for volumes)
groups (int, optional) – number of blocked connections from input channels to output channels. It allows depthwise convolution. When used, the input and output types need to be divisible in
groups
groups, all equal to each other. Default:1
.bias (bool, optional) – Whether to add a bias to the output (only to fields which contain a trivial irrep) or not. Default
True
basis_filter (callable, optional) – function which takes as input a descriptor of a basis element (as a dictionary) and returns a boolean value: whether to preserve (
True
) or discard (False
) the basis element. By default (None
), no filtering is applied.recompute (bool, optional) – if
True
, recomputes a new basis for the equivariant kernels. By default (False
), it caches the basis built or reuses a cached one, if it is found.
- Variables
~.weights (torch.Tensor) – the learnable parameters which are used to expand the kernel
~.bias (torch.Tensor) – the learnable parameters which are used to expand the bias, if
bias=True
~.expanded_bias (torch.Tensor) – the equivariant bias which is summed to the output, obtained by expanding the parameters in
bias
- property basissampler: escnn.nn.modules.basismanager.basissampler_blocks.BlocksBasisSampler
Submodule which takes care of sampling the steerable filters.
It is used to sample the G-steerable basis on the relative coordinates along the edges of a geometric graph and, then, expand the kernel in the sampled basis using the learned
weights
. See alsoforward()
.In practice, this submodule is also used to directly compute the messages via
compute_messages()
: first, the basis filters are applied on the input features and, then, the responses are combined using the learnable weights.
- expand_filter(points)[source]
Expand the filter in terms of
weights
.- Returns
the expanded filter sampled on the input points
- expand_parameters(points)[source]
Expand the filter in terms of the
weights
and the expanded bias in terms ofbias
.- Returns
the expanded filter and bias
- forward(x, edge_index, edge_delta=None)[source]
Convolve the input with the expanded filter and bias.
This method is based on PyTorch Geometric’s
MessagePassing
, i.e. it usespropagate()
to send the messages computed inmessage()
.The input tensor
input
represents a feature field over the nodes of a geometric graph. Hence, thecoords
attribute ofinput
should contain thed
-dimensional coordinates of each node (seeGeometricTensor
).The tensor
edge_index
must be atorch.LongTensor
of shape(2, m)
, representingm
edges.Mini-batches containing multiple graphs can be constructed as in Pytorch Geometric by merging the graphs in a unique, disconnected, graph.
- Parameters
input (GeometricTensor) – input feature field transforming according to
in_type
.edge_index (torch.Tensor) – tensor representing the connectivity of the graph.
edge_delta (torch.Tensor, optional) – the relative coordinates of the nodes on each edge. If not passed, it is automatically computed using
input.coords
andedge_index
.
- Returns
output feature field transforming according to
out_type
- message(x_j, edge_delta=None)[source]
This method computes the message from the input node
j
to the output nodei
of each edge inedge_index
.The message is equal to the product of the filter evaluated on the relative coordinate along an edge with the feature vector on the input node of the edge.
- train(mode=True)[source]
If
mode=True
, the method sets the module in training mode and discards theexpanded_bias
attribute.If
mode=False
, it sets the module in evaluation mode. Moreover, the method builds the bias using the current values of the trainable parameters and stores it in expanded_bias
such that it is not recomputed at each forward pass.Warning
This behaviour can cause problems when storing the
state_dict()
of a model in one mode and later loading it into a model in a different mode, as the attributes of this class change. To avoid this issue, we recommend converting the model to eval mode before storing or loading the state dictionary.- Parameters
mode (bool, optional) – whether to set training mode (
True
) or evaluation mode (False
). Default:True
.
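In line with the recommendation above, a hedged sketch of storing and restoring the weights with both models in eval mode (the R2PointConv architecture and the file name below are purely illustrative):
import torch
import escnn

s = escnn.gspaces.rot2dOnR2(-1)
in_type = escnn.nn.FieldType(s, [s.trivial_repr] * 3)
out_type = escnn.nn.FieldType(s, [s.irrep(1)] * 4)

net = escnn.nn.R2PointConv(in_type, out_type, width=1.0, n_rings=3, sigma=0.4)
# ... training ...
net.eval()                                      # cache the expanded attributes consistently
torch.save(net.state_dict(), 'pointconv.pt')    # 'pointconv.pt' is an arbitrary file name

# re-create the same architecture; initialize=False skips the expensive weight initialization
net2 = escnn.nn.R2PointConv(in_type, out_type, width=1.0, n_rings=3, sigma=0.4,
                            initialize=False)
net2.eval()                                     # switch to eval mode *before* loading
net2.load_state_dict(torch.load('pointconv.pt'))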
R2PointConv
- class R2PointConv(in_type, out_type, groups=1, bias=True, sigma=None, width=None, n_rings=None, frequencies_cutoff=None, rings=None, basis_filter=None, recompute=False, initialize=True)[source]
Bases:
escnn.nn.modules.pointconv.rd_point_convolution._RdPointConv
G-steerable planar convolution mapping between the input and output
FieldType
s specified by the parametersin_type
andout_type
. This operation is equivariant under the action of \(\R^2\rtimes G\) where \(G\) is theescnn.nn.FieldType.fibergroup
ofin_type
andout_type
.This class implements convolution with steerable filters over sparse planar geometric graphs. Instead,
R2Conv
implements an equivalent convolution layer over a pixel grid. See the documentation ofR2Conv
for more details about equivariance and steerable convolution.The input of this module is a planar geometric graph, i.e. a graph whose nodes are associated with 2D coordinates in \(\R^2\). The nodes’ coordinates should be stored in the
coords
attribute of the inputGeometricTensor
. The adjacency of the graph should be passed as a second input tensoredge_index
, like commonly done inMessagePassing
. Seeforward()
.In each forward pass, the module computes the relative coordinates of the points on the edges and samples each filter in the basis of G-steerable kernels at these relative locations. The basis filters are expanded using the learnable weights and used to perform convolution over the graph in the message passing framework. Optionally, the relative coordinates can be pre-computed and passed in the input
edge_delta
tensor.Note
In practice, we first apply the basis filters on the input features and then combine the responses via the learnable weights. See also
compute_messages()
.Warning
When eval() is called, the bias is built with the current trained weights and stored for future reuse, so that no overhead of expanding the bias remains.

When train() is called, the attribute expanded_bias is discarded to avoid mismatches with the learnable expansion coefficients. See also escnn.nn.R2PointConv.train().

This behaviour can cause problems when storing the state_dict() of a model in one mode and later loading it into a model in a different mode, as the attributes of the class change. To avoid this issue, we recommend converting the model to eval mode before storing or loading the state dictionary.

The learnable expansion coefficients of this module can be initialized with the methods in
escnn.nn.init
. By default, the weights are initialized in the constructors usinggeneralized_he_init()
.Warning
This initialization procedure can be extremely slow for wide layers. In case initializing the model is not required (e.g. before loading the state dict of a pre-trained model) or another initialization method is preferred (e.g.
deltaorthonormal_init()
), the parameterinitialize
can be set toFalse
to avoid unnecessary overhead.The parameters
width
,n_rings
,sigma
,frequencies_cutoff
,rings
are optional parameters used to control how the basis for the filters is built, how it is sampled and how it is expanded to build the filter. In practice, steerable filters are parameterized independently on a number of concentric rings by using circular harmonics. These rings can be specified by i) either using the listrings
, which defines the radii of each ring, or by ii) indicating the maximum radiuswidth
and the numbern_rings
of rings to include.sigma
defines the “thickness” of each ring as the standard deviation of a Gaussian bell along the radial direction.frequencies_cutoff
is a function defining the maximum frequency of the circular harmonics allowed at each radius. If a float value is passed, the maximum frequency is equal to the radius times this factor.Note
These parameters should be carefully tuned depending on the typical connectivity and the scale of the geometric graphs of interest.
Warning
We don’t support
groups > 1
yet. We include this parameter for future compatibility.- Parameters
in_type (FieldType) – the type of the input field, specifying its transformation law
out_type (FieldType) – the type of the output field, specifying its transformation law
groups (int, optional) – number of blocked connections from input channels to output channels. It allows depthwise convolution. When used, the input and output types need to be divisible in
groups
groups, all equal to each other. Default:1
.bias (bool, optional) – Whether to add a bias to the output (only to fields which contain a trivial irrep) or not. Default
True
sigma (list or float, optional) – width of each ring where the bases are sampled. If only one scalar is passed, it is used for all rings.
width (float, optional) – radius of the support of the learnable filters. Setting
n_rings
andwidth
is an alternative to userings
.n_rings (int, optional) – number of (equally spaced) rings the support of the filters is split into. Setting
n_rings
andwidth
is an alternative to userings
frequencies_cutoff (callable or float, optional) – function mapping the radii of the basis elements to the maximum frequency accepted. If a float value is passed, the maximum frequency is equal to the radius times this factor. Default:
3.
.rings (list, optional) – radii of the rings where to sample the bases
basis_filter (callable, optional) – function which takes as input a descriptor of a basis element (as a dictionary) and returns a boolean value: whether to preserve (
True
) or discard (False
) the basis element. By default (None
), no filtering is applied.recompute (bool, optional) – if
True
, recomputes a new basis for the equivariant kernels. By default (False
), it caches the basis built or reuses a cached one, if it is found.initialize (bool, optional) – initialize the weights of the model. Default:
True
- Variables
~.weights (torch.Tensor) – the learnable parameters which are used to expand the kernel
~.bias (torch.Tensor) – the learnable parameters which are used to expand the bias, if
bias=True
~.expanded_bias (torch.Tensor) – the equivariant bias which is summed to the output, obtained by expanding the parameters in
bias
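A hedged sketch of a forward pass over a small, fully-connected geometric graph (the SO(2) symmetry, field types, filter parameters and graph below are arbitrary illustrative choices):
import torch
import escnn

s = escnn.gspaces.rot2dOnR2(-1)                   # continuous rotations SO(2)
in_type = escnn.nn.FieldType(s, [s.trivial_repr] * 3)
out_type = escnn.nn.FieldType(s, [s.irrep(1)] * 4)

conv = escnn.nn.R2PointConv(
    in_type, out_type,
    width=1.0, n_rings=3, sigma=0.4,              # ring-based parameterization of the filters
    frequencies_cutoff=3.0,
)

n = 16
coords = torch.randn(n, 2)                        # 2D coordinates of the n nodes
x = escnn.nn.GeometricTensor(torch.randn(n, in_type.size), in_type, coords=coords)

# fully-connected toy graph: all ordered pairs of distinct nodes, shape (2, m)
idx = torch.arange(n)
row, col = idx.repeat_interleave(n), idx.repeat(n)
mask = row != col
edge_index = torch.stack([row[mask], col[mask]])

y = conv(x, edge_index)                           # output feature field over the same nodes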
R3PointConv
- class R3PointConv(in_type, out_type, groups=1, bias=True, sigma=None, width=None, n_rings=None, frequencies_cutoff=None, rings=None, basis_filter=None, recompute=False, initialize=True)[source]
Bases:
escnn.nn.modules.pointconv.rd_point_convolution._RdPointConv
G-steerable volumetric convolution mapping between the input and output
FieldType
s specified by the parametersin_type
andout_type
. This operation is equivariant under the action of \(\R^3\rtimes G\) where \(G\) is theescnn.nn.FieldType.fibergroup
ofin_type
andout_type
.This class implements convolution with steerable filters over sparse geometric graphs. Instead,
R3Conv
implements an equivalent convolution layer over a voxel grid. See the documentation of R3Conv
for more details about equivariance and steerable convolution.The input of this module is a geometric graph, i.e. a graph whose nodes are associated with 3D coordinates in \(\R^3\). The nodes’ coordinates should be stored in the
coords
attribute of the inputGeometricTensor
. The adjacency of the graph should be passed as a second input tensoredge_index
, like commonly done inMessagePassing
. Seeforward()
.In each forward pass, the module computes the relative coordinates of the points on the edges and samples each filter in the basis of G-steerable kernels at these relative locations. The basis filters are expanded using the learnable weights and used to perform convolution over the graph in the message passing framework. Optionally, the relative coordinates can be pre-computed and passed in the input
edge_delta
tensor.Note
In practice, we first apply the basis filters on the input features and then combine the responses via the learnable weights. See also
compute_messages()
.Warning
When eval() is called, the bias is built with the current trained weights and stored for future reuse, so that no overhead of expanding the bias remains.

When train() is called, the attribute expanded_bias is discarded to avoid mismatches with the learnable expansion coefficients. See also escnn.nn.R3PointConv.train().

This behaviour can cause problems when storing the state_dict() of a model in one mode and later loading it into a model in a different mode, as the attributes of the class change. To avoid this issue, we recommend converting the model to eval mode before storing or loading the state dictionary.

The learnable expansion coefficients of this module can be initialized with the methods in
escnn.nn.init
. By default, the weights are initialized in the constructors usinggeneralized_he_init()
.Warning
This initialization procedure can be extremely slow for wide layers. In case initializing the model is not required (e.g. before loading the state dict of a pre-trained model) or another initialization method is preferred (e.g.
deltaorthonormal_init()
), the parameterinitialize
can be set toFalse
to avoid unnecessary overhead.The parameters
width
,n_rings
,sigma
,frequencies_cutoff
,rings
are optional parameters used to control how the basis for the filters is built, how it is sampled and how it is expanded to build the filter. In practice, steerable filters are parameterized independently on a number of concentric spherical shells by using spherical harmonics. These shells can be specified by i) either using the listrings
, which defines the radii of each shell, or by ii) indicating the maximum radiuswidth
and the numbern_rings
of shells to include.sigma
defines the “thickness” of each shell as the standard deviation of a Gaussian bell along the radial direction.frequencies_cutoff
is a function defining the maximum frequency of the spherical harmonics allowed at each radius. If a float value is passed, the maximum frequency is equal to the radius times this factor.Note
These parameters should be carefully tuned depending on the typical connectivity and the scale of the geometric graphs of interest.
Warning
We don’t support
groups > 1
yet. We include this parameter for future compatibility.- Parameters
in_type (FieldType) – the type of the input field, specifying its transformation law
out_type (FieldType) – the type of the output field, specifying its transformation law
groups (int, optional) – number of blocked connections from input channels to output channels. It allows depthwise convolution. When used, the input and output types need to be divisible in
groups
groups, all equal to each other. Default:1
.bias (bool, optional) – Whether to add a bias to the output (only to fields which contain a trivial irrep) or not. Default
True
sigma (list or float, optional) – width of each ring where the bases are sampled. If only one scalar is passed, it is used for all rings.
width (float, optional) – radius of the support of the learnable filters. Setting
n_rings
andwidth
is an alternative to userings
.n_rings (int, optional) – number of (equally spaced) rings the support of the filters is split into. Setting
n_rings
andwidth
is an alternative to userings
frequencies_cutoff (callable or float, optional) – function mapping the radii of the basis elements to the maximum frequency accepted. If a float value is passed, the maximum frequency is equal to the radius times this factor. Default:
3.
.rings (list, optional) – radii of the rings where to sample the bases
basis_filter (callable, optional) – function which takes as input a descriptor of a basis element (as a dictionary) and returns a boolean value: whether to preserve (
True
) or discard (False
) the basis element. By default (None
), no filtering is applied.recompute (bool, optional) – if
True
, recomputes a new basis for the equivariant kernels. By default (False
), it caches the basis built or reuses a cached one, if it is found.initialize (bool, optional) – initialize the weights of the model. Default:
True
- Variables
~.weights (torch.Tensor) – the learnable parameters which are used to expand the kernel
~.bias (torch.Tensor) – the learnable parameters which are used to expand the bias, if
bias=True
~.expanded_bias (torch.Tensor) – the equivariant bias which is summed to the output, obtained by expanding the parameters in
bias
BasisManager
- class BasisManager[source]
Bases:
abc.ABC
Abstract class defining the interface for the different modules which deal with the filter basis. It provides a few methods which can be used to retrieve information about the basis and each of its elements.
- abstract get_element_info(name)[source]
Method that returns the information associated to a basis element
- Parameters
name (int) – index of the basis element
- Returns
dictionary containing the information
- class BlocksBasisExpansion(in_reprs, out_reprs, basis_generator, points, basis_filter=None, recompute=False)[source]
Bases:
torch.nn.modules.module.Module
,escnn.nn.modules.basismanager.basismanager.BasisManager
Module that performs the expansion of an (already sampled) filter basis.
- Parameters
in_reprs (list) – the input field type
out_reprs (list) – the output field type
basis_generator (callable) – method that generates the analytical filter basis
points (ndarray) – points where the analytical basis should be sampled
basis_filter (callable, optional) – filter for the basis elements. Should take a dictionary containing an element’s attributes and return whether to keep it or not.
recompute (bool, optional) – whether to recompute new bases or reuse, if possible, already built tensors.
- Variables
~.S (int) – number of points where the filters are sampled
- forward(weights)[source]
Forward step of the Module which expands the basis and returns the filter built
- Parameters
weights (torch.Tensor) – the learnable weights used to linearly combine the basis filters
- Returns
the filter built
- class BlocksBasisSampler(in_reprs, out_reprs, basis_generator, basis_filter=None, recompute=False)[source]
Bases:
torch.nn.modules.module.Module
,escnn.nn.modules.basismanager.basismanager.BasisManager
Module which performs the expansion of an analytical filter basis and samples it on arbitrary input points.
- Parameters
in_reprs (list) – the input field type
out_reprs (list) – the output field type
basis_generator (callable) – method that generates the analytical filter basis
basis_filter (callable, optional) – filter for the basis elements. Should take a dictionary containing an element’s attributes and return whether to keep it or not.
recompute (bool, optional) – whether to recompute new bases or reuse, if possible, already built tensors.
- forward(weights, points)[source]
Forward step of the Module which expands the basis, samples it on the input points and returns the filter built.
- Parameters
weights (torch.Tensor) – the learnable weights used to linearly combine the basis filters
points (torch.Tensor) – the points where the filter should be sampled
- Returns
the filter built
- compute_messages(weights, input, points, conv_first=True, groups=1)[source]
Expands the basis with the learnable weights to generate the filter and uses it to compute the messages along the edges.
Each point in points corresponds to an edge in a graph. Each point is associated with a row of input. This row is a feature associated to the source node of the edge which needs to be propagated to the target node of the edge.
This method also allows grouped-convolution via the argument
groups
. When used, theinput
tensor should containgroups
blocks, each transforming underself._in_reprs
. Moreover, the output sizeself._out_size
should be divisible bygroups
.Warning
With respect to convolution layers, this method does not check that
self._out_repr
splits ingroups
blocks containing the same representations. Hence, this operation can break equivariance ifgroups
is not properly set andself._out_repr
contains a heterogeneous list of representations. We recommend directly using the R2PointConv
orR3PointConv
modules instead, which implement a number of checks to ensure the convolution is done in an equivariant way.- Parameters
weights (torch.Tensor) – the learnable weights used to linearly combine the basis filters
input (torch.Tensor) – the input features associated with each point
points (torch.Tensor) – the points where the filter should be sampled
conv_first (bool, optional) – perform convolution with the basis filters and, then, combine the responses with the learnable weights. This generally has computational benefits. (Default
True
).groups (int, optional) – number of blocked connections from input channels to output channels. It allows depthwise convolution. Default:
1
.
- Returns
the messages computed
Non Linearities
PointwiseNonLinearity
- class PointwiseNonLinearity(in_type, function='p_relu')[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Pointwise non-linearities. The same scalar function is applied to every channel independently. The input representation is preserved by this operation and, therefore, it equals the output representation.
Only representations supporting pointwise non-linearities are accepted as input field type.
- Parameters
- forward(input)[source]
Applies the pointwise activation function on the input fields
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map after the non-linearities have been applied
ReLU
- class ReLU(in_type, inplace=False)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Module that applies a pointwise ReLU to every channel independently. The input representation is preserved by this operation and, therefore, it equals the output representation.
Only representations supporting pointwise non-linearities are accepted as input field type.
- Parameters
- forward(input)[source]
Applies ReLU function on the input fields
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map after relu has been applied
- export()[source]
Export this module to a normal PyTorch
torch.nn.ReLU
module and set to “eval” mode.
ELU
- class ELU(in_type, alpha=1.0, inplace=False)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Module that applies a pointwise ELU to every channel independently. The input representation is preserved by this operation and, therefore, it equals the output representation.
Only representations supporting pointwise non-linearities are accepted as input field type.
- Parameters
- forward(input)[source]
Applies ELU function on the input fields
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map after elu has been applied
- export()[source]
Export this module to a normal PyTorch
torch.nn.ELU
module and set to “eval” mode.
LeakyReLU
- class LeakyReLU(in_type, negative_slope=0.01, inplace=False)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Module that applies a pointwise LeakyReLU to every channel independently. The input representation is preserved by this operation and, therefore, it equals the output representation.
Only representations supporting pointwise non-linearities are accepted as input field type.
- Parameters
- forward(input)[source]
Applies leaky-ReLU function on the input fields
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map after relu has been applied
- export()[source]
Export this module to a normal PyTorch
torch.nn.LeakyReLU
module and set to “eval” mode.
FourierPointwise
- class FourierPointwise(gspace, channels, irreps, *grid_args, function='p_relu', inplace=True, out_irreps=None, normalize=True, **grid_kwargs)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Applies an Inverse Fourier Transform to sample the input features, applies the pointwise non-linearity in the group domain (Dirac-delta basis) and, finally, computes the Fourier Transform to obtain the irreps' coefficients.
Warning
This operation is only approximately equivariant and its equivariance depends on the sampling grid and the non-linear activation used, as well as the original band-limitation of the input features.
The same function is applied to every channel independently. By default, the input representation is preserved by this operation and, therefore, it equals the output representation. Optionally, the output can have a different band-limit by using the argument
out_irreps
.The class first constructs a band-limited regular representation of
`gspace.fibergroup`
usingescnn.group.Group.spectral_regular_representation()
. The band-limitation of this representation is specified by`irreps`
which should be a list containing a list of ids identifying irreps of`gspace.fibergroup`
(seeescnn.group.IrreducibleRepresentation.id
). This representation is used to define the input and output field types, each containing`channels`
copies of a feature field transforming according to this representation. A feature vector transforming according to such representation is interpreted as a vector of coefficients parameterizing a function over the group using a band-limited Fourier basis.Note
Instead of building the list
irreps
manually, most groups implement a methodbl_irreps()
which can be used to generate this list through a simpler interface. Check each group's documentation.To approximate the Fourier transform, this module uses a finite number of samples from the group. The set of samples used is specified by the
`grid_args`
and`grid_kwargs`
which are forwarded to the methodgrid()
.- Parameters
gspace (GSpace) – the gspace describing the symmetries of the data. The Fourier transform is performed over the group
`gspace.fibergroup`
channels (int) – number of independent fields in the input FieldType
irreps (list) – list of irreps’ ids to construct the band-limited representation
function (str) – the identifier of the non-linearity. It is used to specify which function to apply. By default (
'p_relu'
), ReLU is used.*grid_args – parameters used to construct the discretization grid
inplace (bool) – applies the non-linear activation in-place. Default: True
out_irreps (list, optional) – optionally, one can specify a different band-limiting in output
normalize (bool, optional) – if
True
, the rows of the IFT matrix (and the columns of the FT matrix) are normalized. Default:True
**grid_kwargs – keyword parameters used to construct the discretization grid
- forward(input)[source]
Applies the pointwise activation function on the input fields
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map after the non-linearities have been applied
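A hedged sketch of this module for SO(2) features; the grid arguments (type='regular', N=16) are assumptions about the signature of the group's grid() method and should be checked against the documentation of the group in use:
import torch
import escnn

s = escnn.gspaces.rot2dOnR2(-1)
G = s.fibergroup                                  # SO(2)

nonlin = escnn.nn.FourierPointwise(
    s, channels=8, irreps=G.bl_irreps(3),         # band-limit to frequencies 0..3
    function='p_relu',
    type='regular', N=16,                         # grid arguments forwarded to G.grid() (assumed)
)

x = escnn.nn.GeometricTensor(torch.randn(2, nonlin.in_type.size, 21, 21), nonlin.in_type)
y = nonlin(x)
assert y.type == nonlin.out_type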
FourierELU
- class FourierELU(gspace, channels, irreps, *grid_args, inplace=True, out_irreps=None, normalize=True, **grid_kwargs)[source]
Bases:
escnn.nn.modules.nonlinearities.fourier.FourierPointwise
Applies an Inverse Fourier Transform to sample the input features, applies ELU point-wise in the group domain (Dirac-delta basis) and, finally, computes the Fourier Transform to obtain the irreps' coefficients. See
FourierPointwise
for more details.- Parameters
gspace (GSpace) – the gspace describing the symmetries of the data. The Fourier transform is performed over the group
`gspace.fibergroup`
channels (int) – number of independent fields in the input FieldType
irreps (list) – list of irreps’ ids to construct the band-limited representation
*grid_args – parameters used to construct the discretization grid
inplace (bool) – applies the non-linear activation in-place. Default: True
out_irreps (list, optional) – optionally, one can specify a different band-limiting in output
normalize (bool, optional) – if
True
, the rows of the IFT matrix (and the columns of the FT matrix) are normalized. Default:True
**grid_kwargs – keyword parameters used to construct the discretization grid
QuotientFourierPointwise
- class QuotientFourierPointwise(gspace, subgroup_id, channels, irreps, *grid_args, grid=None, function='p_relu', inplace=True, out_irreps=None, normalize=True, **grid_kwargs)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Applies an Inverse Fourier Transform to sample the input features on a quotient space \(X\), applies the pointwise non-linearity in the spatial domain (Dirac-delta basis) and, finally, computes the Fourier Transform to obtain the irreps' coefficients. The quotient space used is isomorphic to \(X \cong G / H\) where \(G\) is
`gspace.fibergroup`
while \(H\) is the subgroup of \(G\) identified by `subgroup_id`
; seesubgroup()
andhomspace()
Warning
This operation is only approximately equivariant and its equivariance depends on the sampling grid and the non-linear activation used, as well as the original band-limitation of the input features.
The same function is applied to every channel independently. By default, the input representation is preserved by this operation and, therefore, it equals the output representation. Optionally, the output can have a different band-limit by using the argument
out_irreps
.The class first constructs a band-limited quotient representation of
`gspace.fibergroup`
usingescnn.group.Group.spectral_quotient_representation()
. The band-limitation of this representation is specified by`irreps`
which should be a list containing a list of ids identifying irreps of`gspace.fibergroup`
(seeescnn.group.IrreducibleRepresentation.id
). This representation is used to define the input and output field types, each containing`channels`
copies of a feature field transforming according to this representation. A feature vector transforming according to such representation is interpreted as a vector of coefficients parameterizing a function over the group using a band-limited Fourier basis.Note
Instead of building the list
irreps
manually, most groups implement a methodbl_irreps()
which can be used to generate this list through a simpler interface. Check each group's documentation.To approximate the Fourier transform, this module uses a finite number of samples from the group. The set of samples to be used can be specified through the parameter
`grid`
or by the`grid_args`
and`grid_kwargs`
which will then be passed to the methodgrid()
.Warning
By definition, a homogeneous space is invariant under the right action of the subgroup \(H\). That means that a feature representing a function over a homogeneous space \(X \cong G/H\), when interpreted as a function over \(G\) (as we do here when sampling), is constant along each coset, i.e. \(f(gh) = f(g)\) if \(g \in G, h\in H\). An approximately uniform sampling grid over \(G\) creates an approximately uniform grid over \(G/H\) through projection but might contain redundant elements (if the grid contains \(g \in G\), any element \(gh\) in the grid will be redundant). It is therefore advised to create a grid directly in the quotient space, e.g. using
escnn.group.SO3.sphere_grid()
,escnn.group.O3.sphere_grid()
. We do not yet support a general method and interface to generate grids over any homogeneous space for any group, so you should check each group's methods.- Parameters
gspace (GSpace) – the gspace describing the symmetries of the data. The Fourier transform is performed over the group
`gspace.fibergroup`
subgroup_id (tuple) – identifier of the subgroup \(H\) to construct the quotient space
channels (int) – number of independent fields in the input FieldType
irreps (list) – list of irreps’ ids to construct the band-limited representation
*grid_args – parameters used to construct the discretization grid
grid (list, optional) – list containing the elements of the group to use for sampling. Optional (default
None
).function (str) – the identifier of the non-linearity. It is used to specify which function to apply. By default (
'p_relu'
), ReLU is used.inplace (bool) – applies the non-linear activation in-place. Default: True
out_irreps (list, optional) – optionally, one can specify a different band-limiting in output
normalize (bool, optional) – if
True
, the rows of the IFT matrix (and the columns of the FT matrix) are normalized. Default:True
**grid_kwargs – keyword parameters used to construct the discretization grid
- forward(input)[source]
Applies the pointwise activation function on the input fields
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map after the non-linearities have been applied
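A hedged sketch for SO(3) features sampled on the sphere; the subgroup id (False, -1) for the SO(2) subgroup and the sphere_grid() arguments are assumptions to be checked against the documentation of SO3:
import torch
import escnn

s = escnn.gspaces.rot3dOnR3()
G = s.fibergroup                                   # SO(3)

nonlin = escnn.nn.QuotientFourierELU(
    s,
    subgroup_id=(False, -1),                       # SO(2) subgroup (assumed id): quotient is S^2
    channels=4,
    irreps=G.bl_irreps(2),                         # band-limit to frequencies 0..2
    grid=G.sphere_grid(type='thomson', N=32),      # grid built directly on the sphere (assumed args)
)

x = escnn.nn.GeometricTensor(torch.randn(1, nonlin.in_type.size, 9, 9, 9), nonlin.in_type)
y = nonlin(x)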
QuotientFourierELU
- class QuotientFourierELU(gspace, subgroup_id, channels, irreps, *grid_args, grid=None, inplace=True, out_irreps=None, normalize=True, **grid_kwargs)[source]
Bases:
escnn.nn.modules.nonlinearities.fourier_quotient.QuotientFourierPointwise
Applies an Inverse Fourier Transform to sample the input features on a quotient space, applies ELU point-wise in the spatial domain (Dirac-delta basis) and, finally, computes the Fourier Transform to obtain the irreps' coefficients. See
QuotientFourierPointwise
for more details.- Parameters
gspace (GSpace) – the gspace describing the symmetries of the data. The Fourier transform is performed over the group
`gspace.fibergroup`
subgroup_id (tuple) – identifier of the subgroup \(H\) to construct the quotient space
channels (int) – number of independent fields in the input FieldType
irreps (list) – list of irreps’ ids to construct the band-limited representation
*grid_args – parameters used to construct the discretization grid
grid (list, optional) – list containing the elements of the group to use for sampling. Optional (default
None
).inplace (bool) – applies the non-linear activation in-place. Default:
True
out_irreps (list, optional) – optionally, one can specify a different band-limiting in output
normalize (bool, optional) – if
True
, the rows of the IFT matrix (and the columns of the FT matrix) are normalized. Default:True
**grid_kwargs – keyword parameters used to construct the discretization grid
TensorProductModule
- class TensorProductModule(in_type, out_type, initialize=True)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Applies a (learnable) quadratic non-linearity to the input features.
The module requires its input and output types to be uniform, i.e. contain multiple copies of the same representation; see also
uniform()
. Moreover, the input and output field types must have the same number of fields, i.e.`len(in_type) == len(out_type)`
.The module computes the tensor product of each field with itself to generate an intermediate feature map. Note that this feature map will have size
`len(in_type) * in_type.representations[0].size**2`
. To prevent the exponential growth of the model’s width at each layer, the module includes also a learnable linear projection of each`in_type.representations[0].size**2`
-dimensional output field to a corresponding`out_type.representations[0].size`
output field. Note that this layer applies an independent linear projection to each field individually but does not mix them.

Warning

A model employing only this kind of non-linearities will effectively be a polynomial function. Moreover, the degree of the polynomial grows exponentially with the depth of the network. This may result in some instabilities during training.
- Parameters
- Variables
~.weights (torch.Tensor) – the learnable parameters which are used to expand the projection matrix
~.matrix (torch.Tensor) – the matrix obtained by expanding the parameters in
weights
- property basisexpansion: escnn.nn.modules.basismanager.basismanager.BasisManager
Submodule which takes care of building the matrix.
It uses the learnt weights to expand the basis and returns the resulting projection matrix in the usual dense form used by conventional PyTorch modules.
- expand_parameters()[source]
Expand the matrix in terms of the
escnn.nn.TensorProductModule.weights
.- Returns
the expanded projection matrix
- train(mode=True)[source]
If
mode=True
, the method sets the module in training mode and discards thematrix
attribute.If
mode=False
, it sets the module in evaluation mode. Moreover, the method builds the matrix using the current values of the trainable parameters and stores it in matrix
such that it is not recomputed at each forward pass.Warning
This behaviour can cause problems when storing the
state_dict()
of a model in one mode and later loading it into a model in a different mode, as the attributes of this class change. To avoid this issue, we recommend converting the model to eval mode before storing or loading the state dictionary.- Parameters
mode (bool, optional) – whether to set training mode (
True
) or evaluation mode (False
). Default:True
.
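A hedged sketch of this module with uniform input and output field types built from a band-limited spectral regular representation (all sizes and the choice of SO(2) are arbitrary):
import torch
import escnn

s = escnn.gspaces.rot2dOnR2(-1)
G = s.fibergroup

# a band-limited "regular-like" representation containing frequencies 0..2
rho = G.spectral_regular_representation(*G.bl_irreps(2))

in_type = escnn.nn.FieldType(s, [rho] * 8)        # uniform: 8 copies of the same representation
out_type = escnn.nn.FieldType(s, [rho] * 8)       # same number of fields as the input

tp = escnn.nn.TensorProductModule(in_type, out_type)

x = escnn.nn.GeometricTensor(torch.randn(2, in_type.size, 11, 11), in_type)
y = tp(x)
assert y.type == out_type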
GatedNonLinearity1
- class GatedNonLinearity1(in_type, gates=None, drop_gates=True, **kwargs)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Gated non-linearities. This module applies a bias and a sigmoid function to the gate fields and, then, multiplies each gated field by one of the gates.
The input representation of the gated fields is preserved by this operation while the gate fields are discarded.
The gates and the gated fields are provided in one unique input tensor and, therefore,
in_repr
should be the representation of the fiber containing both gates and gated fields. Moreover, the parametergates
needs to be set with a list as long as the total number of fields, containing in position i
the string"gate"
if thei
-th field is a gate or the string"gated"
if thei
-th field is a gated field. No other strings are allowed. By default (gates = None
), the first half of the fields is assumed to contain the gates (and, so, these fields have to be trivial fields) while the second one is assumed to contain the gated fields.In any case, the number of gates and the number of gated fields have to match (therefore, the number of fields has to be an even number).
- Parameters
in_type (FieldType) – the input field type
gates (list, optional) – list of strings specifying which field in input is a gate and which is a gated field
drop_gates (bool, optional) – if
True
(default), drop the trivial fields after using them to compute the gates. IfFalse
, the gates are stacked with the gated fields in the output
- forward(input)[source]
Apply the gated non-linearity to the input feature map.
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
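A hedged sketch with four trivial gate fields gating four frequency-1 fields (the representations and sizes are arbitrary illustrative choices):
import torch
import escnn

s = escnn.gspaces.rot2dOnR2(-1)

gate_fields = [s.trivial_repr] * 4               # one scalar gate per gated field
gated_fields = [s.irrep(1)] * 4                  # the fields to be gated
in_type = escnn.nn.FieldType(s, gate_fields + gated_fields)

# the i-th entry of the list declares whether the i-th field is a gate or a gated field
nonlin = escnn.nn.GatedNonLinearity1(in_type, gates=['gate'] * 4 + ['gated'] * 4)

x = escnn.nn.GeometricTensor(torch.randn(1, in_type.size, 13, 13), in_type)
y = nonlin(x)                                    # the 4 gate fields are consumed and dropped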
GatedNonLinearity2
- class GatedNonLinearity2(in_type, **kwargs)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Gated non-linearities. This module applies a bias and a sigmoid function to the gate fields and, then, multiplies each gated field by one of the gates.
The input representation of the gated fields is preserved by this operation while the gate fields are discarded.
The gates and the gated fields are provided in two different tensors:
in_repr
is a tuple containing two representations: the representation of the tensor containing only the gates (which have to be trivial fields) and the representation of the tensor containing only the gated fields. Therefore, two tensors need to be provided as input to theforward
method: the first contains the gates and the second the gated fields. Finally, notice that the number of gates and the number of gated fields have to match, i.e. the two representations need to have the same number of fields.Todo
This module has 2 input tensors and, so, two input field types. EquivariantModule only supports one input though. Fix this.
- Parameters
in_type (Tuple) – a pair containing, in order, the field type of the gates and the field type of the gated fields
- forward(gates, input)[source]
Apply the gated non-linearity to the input feature map.
- Parameters
gates (GeometricTensor) – feature map corresponding to the gates
input (GeometricTensor) – input feature map corresponding to the gated fields
- Returns
the resulting feature map
GatedNonLinearityUniform
- class GatedNonLinearityUniform(in_type, gate=<built-in method sigmoid of type object>)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Gated non-linearities. This module applies a bias and a sigmoid function to the gate fields and, then, multiplies each gated field by one of the gates.
The input representation of the gated features is preserved by this operation while the gate fields are discarded.
This module is a less general-purpose version of the other Gated-Non-Linearity modules, optimized for uniform field types. This module assumes that the input type contains only copies of the same field (same representation) and that such field internally contains a trivial representation for each other irrep in it.
This means that the number of irreps in the representation must be even and that the first half of them need to be trivial representations.
The input representation is also assumed to have no change of basis, i.e. its change-of-basis must be equal to the identity matrix.
Note
The documentation of this method is still work in progress.
- Parameters
in_type (FieldType) – the input field type
gate (optional, Callable) – the gate function to apply. By default, it is the sigmoid function.
- forward(input)[source]
Apply the gated non-linearity to the input feature map.
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
InducedGatedNonLinearity1
- class InducedGatedNonLinearity1(in_type, gates=None, drop_gates=True, **kwargs)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Induced Gated non-linearities.
Todo
complete documentation!
Note
Make sure all induced gates and gated fields have the same subgroup
- Parameters
in_type (FieldType) – the input field type
gates (list, optional) – list of strings specifying which field in input is a gate and which is a gated field
drop_gates (bool, optional) – if
True
(default), drop the trivial fields after using them to compute the gates. IfFalse
, the gates are stacked with the gated fields in the output
- forward(input)[source]
Apply the gated non-linearity to the input feature map.
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
ConcatenatedNonLinearity
- class ConcatenatedNonLinearity(in_type, function='c_relu')[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Concatenated non-linearities. For each input channel, the module applies the specified activation function both to its value and its opposite (the value multiplied by -1). The number of channels is, therefore, doubled.
Notice that not all the representations support this kind of non-linearity. Indeed, only representations whose matrices have the pattern of permutation matrices and contain only values in \(\{0, 1, -1\}\) (i.e. signed permutation matrices) support it.
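A short usage sketch, assuming regular fields of the 8-rotations group (which satisfy the condition above):
import torch
import escnn.gspaces
import escnn.nn

s = escnn.gspaces.rot2dOnR2(8)
c = escnn.nn.FieldType(s, [s.regular_repr] * 2)
nonlinearity = escnn.nn.ConcatenatedNonLinearity(c, function='c_relu')

x = escnn.nn.GeometricTensor(torch.randn(4, c.size, 16, 16), c)
y = nonlinearity(x)
# the activation is applied to each channel and its negation, doubling the channels
assert y.shape[1] == 2 * c.size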
NormNonLinearity
- class NormNonLinearity(in_type, function='n_relu', bias=True)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Norm non-linearities. This module applies a bias and an activation function over the norm of each field.
The input representation of the fields is preserved by this operation.
Note
If ‘squash’ non-linearity (function) is chosen, no bias is allowed
- Parameters
- forward(input)[source]
Apply norm non-linearities to the input feature map
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
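An illustrative sketch with fields transforming under the 2-dimensional frequency-1 irrep of the 8-rotations group (any representation supporting norm non-linearities would do):
import torch
import escnn.gspaces
import escnn.nn

s = escnn.gspaces.rot2dOnR2(8)
# four copies of the 2-dimensional irrep of frequency 1
c = escnn.nn.FieldType(s, [s.irrep(1)] * 4)
nonlinearity = escnn.nn.NormNonLinearity(c, function='n_relu', bias=True)

x = escnn.nn.GeometricTensor(torch.randn(4, c.size, 16, 16), c)
y = nonlinearity(x)   # the field type (and, so, the shape) is preserved
assert y.type == c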
InducedNormNonLinearity
- class InducedNormNonLinearity(in_type, function='n_relu', bias=True)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Induced Norm non-linearities. This module requires the input fields to be associated to an induced representation from a representation which supports ‘norm’ non-linearities. This module applies a bias and an activation function over the norm of each sub-field of an induced field. The bias is shared among all sub-field of the same induced field.
The input representation of the fields is preserved by this operation.
- Parameters
- forward(input)[source]
Apply norm non-linearities to the input feature map
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
VectorFieldNonLinearity
- class VectorFieldNonLinearity(in_type, **kwargs)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
VectorField non-linearities. This non-linearity only supports the regular representation of the cyclic group \(C_N\), i.e. the group of \(N\) discrete rotations. For each input field, the output field is built by taking the rotation associated with the highest activation; the output field is then set to a 2-dimensional vector whose angle with respect to the x-axis equals that rotation and whose length equals the corresponding activation.
- Parameters
in_type (FieldType) – the input field type
- forward(input)[source]
Apply the VectorField non-linearity to the input feature map.
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
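For illustration, a sketch where each regular field of \(C_8\) is mapped to a 2-dimensional vector field:
import torch
import escnn.gspaces
import escnn.nn

s = escnn.gspaces.rot2dOnR2(8)
c = escnn.nn.FieldType(s, [s.regular_repr] * 3)
nonlinearity = escnn.nn.VectorFieldNonLinearity(c)

x = escnn.nn.GeometricTensor(torch.randn(4, c.size, 16, 16), c)
y = nonlinearity(x)
# each 8-dimensional regular field becomes a 2-dimensional rotation-equivariant vector field
assert y.shape[1] == 2 * len(c)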
Invariant Maps
GroupPooling
- class GroupPooling(in_type, **kwargs)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Module that implements group pooling. This module only supports permutation representations, such as the regular representation, quotient representations or the trivial representation (though, in the last case, this module acts as the identity). For each input field, an output field is built by taking the maximum activation within that field; as a result, the output field transforms according to a trivial representation.
- Parameters
in_type (FieldType) – the input field type
- forward(input)[source]
Apply Group Pooling to the input feature map.
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
- export()[source]
Export this module to the pure PyTorch module MaxPoolChannels and set it to “eval” mode.
Warning
Currently, this method only supports group pooling with feature types containing only representations of the same size.
Note
Because there is no native PyTorch module performing this operation, it is not possible to export this module without any dependency on this library. Indeed, the resulting module depends on this library through the class MaxPoolChannels. In case PyTorch introduces a similar module in a future release, we will update this method to remove this dependency.
Nevertheless, the MaxPoolChannels module is slightly lighter than GroupPooling as it does not perform any automatic type checking and does not wrap each tensor in a GeometricTensor. Furthermore, the MaxPoolChannels class is very simple and one can easily reimplement it to remove any dependency on this library after training the model.
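A usage sketch (the field type below is arbitrary): each regular field is pooled to a single invariant channel.
import torch
import escnn.gspaces
import escnn.nn

s = escnn.gspaces.rot2dOnR2(8)
c = escnn.nn.FieldType(s, [s.regular_repr] * 4)
pool = escnn.nn.GroupPooling(c)

x = escnn.nn.GeometricTensor(torch.randn(2, c.size, 16, 16), c)
y = pool(x)
# each 8-dimensional regular field is reduced to one rotation-invariant channel
assert y.shape[1] == len(c)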
NormPool
- class NormPool(in_type, **kwargs)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Module that implements Norm Pooling. For each input field, an output one is built by taking the norm of that field; as a result, the output field transforms according to a trivial representation.
- Parameters
in_type (FieldType) – the input field type
- forward(input)[source]
Apply the Norm Pooling to the input feature map.
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
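A short sketch with 2-dimensional irrep fields (an arbitrary choice):
import torch
import escnn.gspaces
import escnn.nn

s = escnn.gspaces.rot2dOnR2(8)
c = escnn.nn.FieldType(s, [s.irrep(1)] * 4)
pool = escnn.nn.NormPool(c)

x = escnn.nn.GeometricTensor(torch.randn(2, c.size, 16, 16), c)
y = pool(x)   # each 2-dimensional field is replaced by its (invariant) norm
assert y.shape[1] == len(c)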
InducedNormPool
- class InducedNormPool(in_type, **kwargs)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Module that implements Induced Norm Pooling. This module requires the input fields to be associated to an induced representation from a representation which supports ‘norm’ non-linearities.
First, for each input field, an output one is built by taking the maximum norm of all its sub-fields.
- Parameters
in_type (FieldType) – the input field type
- forward(input)[source]
Apply the Norm Pooling to the input feature map.
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
Pooling
NormMaxPool
- class NormMaxPool(in_type, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Max-pooling based on the fields’ norms. In a given window of shape kernel_size, for each group of channels belonging to the same field, the field with the highest norm (i.e. the length of the vector) is preserved.
Except for in_type, the other parameters correspond to the ones of torch.nn.MaxPool2d.
Warning
Even if the input tensor has a coords attribute, the output of this module will not have one.
- Parameters
in_type (FieldType) – the input field type
kernel_size (Union[int, Tuple[int, int]]) – the size of the window to take a max over
stride (Union[int, Tuple[int, int], None]) – the stride of the window. Default value is kernel_size
padding (Union[int, Tuple[int, int]]) – implicit zero padding to be added on both sides
dilation (Union[int, Tuple[int, int]]) – a parameter that controls the stride of elements in the window
ceil_mode (bool) – when True, will use ceil instead of floor to compute the output shape
- forward(input)[source]
Run the norm-based max-pooling on the input tensor
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
PointwiseMaxPool2D
- class PointwiseMaxPool2D(in_type, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False)[source]
Bases:
escnn.nn.modules.pooling.pointwise_max._PointwiseMaxPoolND
Channel-wise max-pooling: each channel is treated independently. This module works exactly as torch.nn.MaxPool2d, wrapping it in the EquivariantModule interface.
Notice that not all representations support this kind of pooling. In general, only representations which support pointwise non-linearities do.
Warning
Even if the input tensor has a coords attribute, the output of this module will not have one.
- Parameters
in_type (FieldType) – the input field type
kernel_size (Union[int, Tuple[int, int]]) – the size of the window to take a max over
stride (Union[int, Tuple[int, int], None]) – the stride of the window. Default value is kernel_size
padding (Union[int, Tuple[int, int]]) – implicit zero padding to be added on both sides
dilation (Union[int, Tuple[int, int]]) – a parameter that controls the stride of elements in the window
ceil_mode (bool) – when True, will use ceil instead of floor to compute the output shape
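A usage sketch; the field type and the pooling parameters are example values only:
import torch
import escnn.gspaces
import escnn.nn

s = escnn.gspaces.rot2dOnR2(8)
c = escnn.nn.FieldType(s, [s.regular_repr] * 2)
pool = escnn.nn.PointwiseMaxPool2D(c, kernel_size=3, stride=2, padding=1)

x = escnn.nn.GeometricTensor(torch.randn(2, c.size, 32, 32), c)
y = pool(x)   # the spatial resolution is halved: 32x32 -> 16x16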
PointwiseMaxPoolAntialiased2D
- class PointwiseMaxPoolAntialiased2D(in_type, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, sigma=0.6)[source]
Bases:
escnn.nn.modules.pooling.pointwise_max._PointwiseMaxPoolAntialiasedND
Anti-aliased version of channel-wise max-pooling (each channel is treated independently).
The max over a neighborhood is performed pointwise without downsampling. Then, a convolution with a Gaussian blurring filter is performed before downsampling the feature map.
Based on Making Convolutional Networks Shift-Invariant Again.
Notice that not all representations support this kind of pooling. In general, only representations which support pointwise non-linearities do.
Warning
Even if the input tensor has a coords attribute, the output of this module will not have one.
- Parameters
in_type (FieldType) – the input field type
kernel_size (Union[int, Tuple[int, int]]) – the size of the window to take a max over
stride (Union[int, Tuple[int, int], None]) – the stride of the window. Default value is kernel_size
padding (Union[int, Tuple[int, int]]) – implicit zero padding to be added on both sides
dilation (Union[int, Tuple[int, int]]) – a parameter that controls the stride of elements in the window
ceil_mode (bool) – when True, will use ceil instead of floor to compute the output shape
sigma (float) – standard deviation for the Gaussian blur filter
PointwiseMaxPool3D
- class PointwiseMaxPool3D(in_type, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False)[source]
Bases:
escnn.nn.modules.pooling.pointwise_max._PointwiseMaxPoolND
Channel-wise max-pooling: each channel is treated independently. This module works exactly as torch.nn.MaxPool3d, wrapping it in the EquivariantModule interface.
Notice that not all representations support this kind of pooling. In general, only representations which support pointwise non-linearities do.
Warning
Even if the input tensor has a coords attribute, the output of this module will not have one.
- Parameters
in_type (FieldType) – the input field type
kernel_size (Union[int, Tuple[int, int]]) – the size of the window to take a max over
stride (Union[int, Tuple[int, int], None]) – the stride of the window. Default value is kernel_size
padding (Union[int, Tuple[int, int]]) – implicit zero padding to be added on both sides
dilation (Union[int, Tuple[int, int]]) – a parameter that controls the stride of elements in the window
ceil_mode (bool) – when True, will use ceil instead of floor to compute the output shape
PointwiseMaxPoolAntialiased3D
- class PointwiseMaxPoolAntialiased3D(in_type, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, sigma=0.6)[source]
Bases:
escnn.nn.modules.pooling.pointwise_max._PointwiseMaxPoolAntialiasedND
Anti-aliased version of channel-wise max-pooling (each channel is treated independently).
The max over a neighborhood is performed pointwise without downsampling. Then, a convolution with a Gaussian blurring filter is performed before downsampling the feature map.
Based on Making Convolutional Networks Shift-Invariant Again.
Notice that not all representations support this kind of pooling. In general, only representations which support pointwise non-linearities do.
Warning
Even if the input tensor has a coords attribute, the output of this module will not have one.
- Parameters
in_type (FieldType) – the input field type
kernel_size (Union[int, Tuple[int, int]]) – the size of the window to take a max over
stride (Union[int, Tuple[int, int], None]) – the stride of the window. Default value is kernel_size
padding (Union[int, Tuple[int, int]]) – implicit zero padding to be added on both sides
dilation (Union[int, Tuple[int, int]]) – a parameter that controls the stride of elements in the window
ceil_mode (bool) – when True, will use ceil instead of floor to compute the output shape
sigma (float) – standard deviation for the Gaussian blur filter
PointwiseAvgPool2D
- class PointwiseAvgPool2D(in_type, kernel_size, stride=None, padding=0, ceil_mode=False)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Channel-wise average-pooling: each channel is treated independently. This module works exactly as torch.nn.AvgPool2d, wrapping it in the EquivariantModule interface.
Warning
Even if the input tensor has a coords attribute, the output of this module will not have one.
- Parameters
in_type (FieldType) – the input field type
kernel_size (Union[int, Tuple[int, int]]) – the size of the window to take an average over
stride (Union[int, Tuple[int, int], None]) – the stride of the window. Default value is kernel_size
padding (Union[int, Tuple[int, int]]) – implicit zero padding to be added on both sides
ceil_mode (bool) – when True, will use ceil instead of floor to compute the output shape
- forward(input)[source]
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
PointwiseAvgPoolAntialiased2D
- class PointwiseAvgPoolAntialiased2D(in_type, sigma, stride, padding=None)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Antialiased channel-wise average-pooling: each channel is treated independently. It performs strided convolution with a Gaussian blur filter.
The size of the filter is computed as 3 standard deviations of the Gaussian curve. By default, padding is added such that the input size is preserved if the stride is 1.
Based on Making Convolutional Networks Shift-Invariant Again.
Warning
Even if the input tensor has a coords attribute, the output of this module will not have one.
- Parameters
- forward(input)[source]
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
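An illustrative sketch; the sigma and stride values below are example choices, not recommendations:
import torch
import escnn.gspaces
import escnn.nn

s = escnn.gspaces.rot2dOnR2(8)
c = escnn.nn.FieldType(s, [s.regular_repr] * 2)
# blur with a Gaussian filter (sigma=0.66) and then downsample with stride 2
pool = escnn.nn.PointwiseAvgPoolAntialiased2D(c, sigma=0.66, stride=2)

x = escnn.nn.GeometricTensor(torch.randn(2, c.size, 32, 32), c)
y = pool(x)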
PointwiseAvgPool3D
- class PointwiseAvgPool3D(in_type, kernel_size, stride=None, padding=0, ceil_mode=False)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Channel-wise average-pooling: each channel is treated independently. This module works exactly as torch.nn.AvgPool3d, wrapping it in the EquivariantModule interface.
Warning
Even if the input tensor has a coords attribute, the output of this module will not have one.
- Parameters
in_type (FieldType) – the input field type
kernel_size (Union[int, Tuple[int, int]]) – the size of the window to take an average over
stride (Union[int, Tuple[int, int], None]) – the stride of the window. Default value is kernel_size
padding (Union[int, Tuple[int, int]]) – implicit zero padding to be added on both sides
ceil_mode (bool) – when True, will use ceil instead of floor to compute the output shape
- forward(input)[source]
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
PointwiseAvgPoolAntialiased3D
- class PointwiseAvgPoolAntialiased3D(in_type, sigma, stride, padding=None)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Antialiased channel-wise average-pooling: each channel is treated independently. It performs strided convolution with a Gaussian blur filter.
The size of the filter is computed as 3 standard deviations of the Gaussian curve. By default, padding is added such that the input size is preserved if the stride is 1.
Inspired by Making Convolutional Networks Shift-Invariant Again.
Warning
Even if the input tensor has a coords attribute, the output of this module will not have one.
- Parameters
- forward(input)[source]
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
PointwiseAdaptiveAvgPool2D
- class PointwiseAdaptiveAvgPool2D(in_type, output_size)[source]
Bases:
escnn.nn.modules.pooling.pointwise_adaptive_avg._PointwiseAdaptiveAvgPoolND
Adaptive channel-wise average-pooling: each channel is treated independently. This module works exactly as torch.nn.AdaptiveAvgPool2d, wrapping it in the EquivariantModule interface.
Notice that not all representations support this kind of pooling. In general, only representations which support pointwise non-linearities do.
Warning
Even if the input tensor has a coords attribute, the output of this module will not have one.
PointwiseAdaptiveAvgPool3D
- class PointwiseAdaptiveAvgPool3D(in_type, output_size)[source]
Bases:
escnn.nn.modules.pooling.pointwise_adaptive_avg._PointwiseAdaptiveAvgPoolND
Adaptive channel-wise average-pooling: each channel is treated independently. This module works exactly as torch.nn.AdaptiveAvgPool3d, wrapping it in the EquivariantModule interface.
Notice that not all representations support this kind of pooling. In general, only representations which support pointwise non-linearities do.
Warning
Even if the input tensor has a coords attribute, the output of this module will not have one.
PointwiseAdaptiveMaxPool2D
- class PointwiseAdaptiveMaxPool2D(in_type, output_size)[source]
Bases:
escnn.nn.modules.pooling.pointwise_adaptive_max._PointwiseAdaptiveMaxPoolND
Module that implements adaptive channel-wise max-pooling: each channel is treated independently. This module works exactly as torch.nn.AdaptiveMaxPool2d, wrapping it in the EquivariantModule interface.
Notice that not all representations support this kind of pooling. In general, only representations which support pointwise non-linearities do.
Warning
Even if the input tensor has a coords attribute, the output of this module will not have one.
PointwiseAdaptiveMaxPool3D
- class PointwiseAdaptiveMaxPool3D(in_type, output_size)[source]
Bases:
escnn.nn.modules.pooling.pointwise_adaptive_max._PointwiseAdaptiveMaxPoolND
Module that implements adaptive channel-wise max-pooling: each channel is treated independently. This module works exactly as torch.nn.AdaptiveMaxPool3d, wrapping it in the EquivariantModule interface.
Notice that not all representations support this kind of pooling. In general, only representations which support pointwise non-linearities do.
Warning
Even if the input tensor has a coords attribute, the output of this module will not have one.
Normalization
IIDBatchNorm1d
- class IIDBatchNorm1d(in_type, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)[source]
Bases:
escnn.nn.modules.batchnormalization.iid._IIDBatchNorm
Batch normalization for generic representations for 1D or 0D data (i.e. 3D or 2D inputs).
This batch normalization assumes that all dimensions within the same field have the same variance, i.e. that the covariance matrix of each field in in_type is a scalar multiple of the identity. Moreover, the mean is only computed over the trivial irreps occurring in the input representations (the input representation does not need to be decomposed into a direct sum of irreps since this module can deal with the change of basis).
Similarly, if affine = True, a single scale is learnt per input field and the bias is applied only to the trivial irreps.
This assumption is equivalent to the usual Batch Normalization in a Group Convolution NN (GCNN), where statistics are shared over the group dimension. See Chapter 4.2 at https://gabri95.github.io/Thesis/thesis.pdf .
- Parameters
in_type (FieldType) – the input field type
eps (float, optional) – a value added to the denominator for numerical stability. Default: 1e-5
momentum (float, optional) – the value used for the running_mean and running_var computation. Can be set to None for cumulative moving average (i.e. simple average). Default: 0.1
affine (bool, optional) – if True, this module has learnable affine parameters. Default: True
track_running_stats (bool, optional) – when set to True, the module tracks the running mean and variance; when set to False, it does not track such statistics but uses batch statistics in both training and eval modes. Default: True
- export()[source]
Export this module to a normal PyTorch torch.nn.BatchNorm1d module and set to “eval” mode.
IIDBatchNorm2d
- class IIDBatchNorm2d(in_type, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)[source]
Bases:
escnn.nn.modules.batchnormalization.iid._IIDBatchNorm
Batch normalization for generic representations for 2D data (i.e. 4D inputs).
This batch normalization assumes that all dimensions within the same field have the same variance, i.e. that the covariance matrix of each field in in_type is a scalar multiple of the identity. Moreover, the mean is only computed over the trivial irreps occurring in the input representations (the input representation does not need to be decomposed into a direct sum of irreps since this module can deal with the change of basis).
Similarly, if affine = True, a single scale is learnt per input field and the bias is applied only to the trivial irreps.
This assumption is equivalent to the usual Batch Normalization in a Group Convolution NN (GCNN), where statistics are shared over the group dimension. See Chapter 4.2 at https://gabri95.github.io/Thesis/thesis.pdf .
- Parameters
in_type (FieldType) – the input field type
eps (float, optional) – a value added to the denominator for numerical stability. Default: 1e-5
momentum (float, optional) – the value used for the running_mean and running_var computation. Can be set to None for cumulative moving average (i.e. simple average). Default: 0.1
affine (bool, optional) – if True, this module has learnable affine parameters. Default: True
track_running_stats (bool, optional) – when set to True, the module tracks the running mean and variance; when set to False, it does not track such statistics but uses batch statistics in both training and eval modes. Default: True
- export()[source]
Export this module to a normal PyTorch
torch.nn.BatchNorm2d
module and set to “eval” mode.
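A sketch of the usual train-then-export workflow (the field type is an arbitrary example):
import torch
import escnn.gspaces
import escnn.nn

s = escnn.gspaces.rot2dOnR2(8)
c = escnn.nn.FieldType(s, [s.regular_repr] * 3)
bn = escnn.nn.IIDBatchNorm2d(c)

x = escnn.nn.GeometricTensor(torch.randn(8, c.size, 16, 16), c)
y = bn(x)             # training-time forward pass updates the running statistics

bn.eval()
bn_pt = bn.export()   # a plain PyTorch batch-norm module acting on c.size channels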
IIDBatchNorm3d
- class IIDBatchNorm3d(in_type, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)[source]
Bases:
escnn.nn.modules.batchnormalization.iid._IIDBatchNorm
Batch normalization for generic representations for 3D data (i.e. 5D inputs).
This batch normalization assumes that all dimensions within the same field have the same variance, i.e. that the covariance matrix of each field in in_type is a scalar multiple of the identity. Moreover, the mean is only computed over the trivial irreps occurring in the input representations (the input representation does not need to be decomposed into a direct sum of irreps since this module can deal with the change of basis).
Similarly, if affine = True, a single scale is learnt per input field and the bias is applied only to the trivial irreps.
This assumption is equivalent to the usual Batch Normalization in a Group Convolution NN (GCNN), where statistics are shared over the group dimension. See Chapter 4.2 at https://gabri95.github.io/Thesis/thesis.pdf .
- Parameters
in_type (FieldType) – the input field type
eps (float, optional) – a value added to the denominator for numerical stability. Default: 1e-5
momentum (float, optional) – the value used for the running_mean and running_var computation. Can be set to None for cumulative moving average (i.e. simple average). Default: 0.1
affine (bool, optional) – if True, this module has learnable affine parameters. Default: True
track_running_stats (bool, optional) – when set to True, the module tracks the running mean and variance; when set to False, it does not track such statistics but uses batch statistics in both training and eval modes. Default: True
- export()[source]
Export this module to a normal PyTorch torch.nn.BatchNorm3d module and set to “eval” mode.
FieldNorm
- class FieldNorm(in_type, eps=1e-05, affine=True)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Normalization module which normalizes each field individually. The statistics are only computed over the channels within a single field (not over the batch dimension or the spatial dimensions). Moreover, this layer does not track running statistics and uses only the current input, so it behaves similarly at train and eval time.
For each individual field, the mean is given by the projection on the subspaces transforming under the trivial representation while the variance is the squared norm of the field, after the mean has been subtracted.
If affine = True, a single scale is learnt per input field and the bias is applied only to the trivial irreps (this scale and bias are shared over the spatial dimensions in order to preserve equivariance).
Warning
If a field contains only trivial irreps, this layer will just set its values to zero and, possibly, replace them with a learnable bias if affine = True.
- Parameters
- forward(input)[source]
Normalize the input feature map
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
InnerBatchNorm
- class InnerBatchNorm(in_type, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Batch normalization for representations with permutation matrices.
Statistics are computed both over the batch and the spatial dimensions and over the channels within the same field (which are permuted by the representation).
Only representations supporting pointwise non-linearities are accepted as input field type.
- Parameters
in_type (FieldType) – the input field type
eps (float, optional) – a value added to the denominator for numerical stability. Default: 1e-5
momentum (float, optional) – the value used for the running_mean and running_var computation. Can be set to None for cumulative moving average (i.e. simple average). Default: 0.1
affine (bool, optional) – if True, this module has learnable affine parameters. Default: True
track_running_stats (bool, optional) – when set to True, the module tracks the running mean and variance; when set to False, it does not track such statistics but uses batch statistics in both training and eval modes. Default: True
- forward(input)[source]
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
- export()[source]
Export this module to a normal PyTorch
torch.nn.BatchNorm2d
module and set to “eval” mode.
NormBatchNorm
- class NormBatchNorm(in_type, eps=1e-05, momentum=0.1, affine=True)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Batch normalization for isometric (i.e. which preserves the norm) non-trivial representations.
The module assumes the mean of the vectors is always zero so no running mean is computed and no bias is added. This is guaranteed as long as the representations do not include a trivial representation.
Indeed, if \(\rho\) does not include a trivial representation, it holds:
\[\forall \bold{v} \in \mathbb{R}^n, \ \ \frac{1}{|G|} \sum_{g \in G} \rho(g) \bold{v} = \bold{0}\]
Hence, only the standard deviation is normalized.
Only representations which do not contain the trivial representation are allowed. You can check if a representation contains the trivial representation using contains_trivial(). To check if a trivial irrep is present in a representation in a FieldType, you can use:
for r in field_type:
    if r.contains_trivial():
        print(f"field type contains a trivial irrep")
- Parameters
in_type (FieldType) – the input field type
eps (float, optional) – a value added to the denominator for numerical stability. Default: 1e-5
momentum (float, optional) – the value used for the running_mean and running_var computation. Can be set to None for cumulative moving average (i.e. simple average). Default: 0.1
affine (bool, optional) – if True, this module has learnable scale parameters. Default: True
- forward(input)[source]
Apply batch normalization to the input feature map
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
InducedNormBatchNorm
- class InducedNormBatchNorm(in_type, eps=1e-05, momentum=0.1, affine=True)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Batch normalization for induced isometric representations. This module requires the input fields to be associated to an induced representation from an isometric (i.e. which preserves the norm) non-trivial representation which supports ‘norm’ non-linearities.
The module assumes the mean of the vectors is always zero so no running mean is computed and no bias is added. This is guaranteed as long as the representations do not include a trivial representation.
Indeed, if \(\rho\) does not include a trivial representation, it holds:
\[\forall \bold{v} \in \mathbb{R}^n, \ \ \frac{1}{|G|} \sum_{g \in G} \rho(g) \bold{v} = \bold{0}\]
Hence, only the standard deviation is normalized. The same standard deviation, however, is shared by all the sub-fields of the same induced field.
The input representation of the fields is preserved by this operation.
Only representations which do not contain the trivial representation are allowed. You can check if a representation contains the trivial representation using contains_trivial(). To check if a trivial irrep is present in a representation in a FieldType, you can use:
for r in field_type:
    if r.contains_trivial():
        print(f"field type contains a trivial irrep")
- Parameters
in_type (FieldType) – the input field type
eps (float, optional) – a value added to the denominator for numerical stability. Default: 1e-5
momentum (float, optional) – the value used for the running_mean and running_var computation. Can be set to None for cumulative moving average (i.e. simple average). Default: 0.1
affine (bool, optional) – if True, this module has learnable scale parameters. Default: True
- forward(input)[source]
Apply batch normalization to the input feature map
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
GNormBatchNorm
- class GNormBatchNorm(in_type, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Batch normalization for generic representations.
This batch normalization assumes that the covariance matrix of a subset of channels in in_type transforming under an irreducible representation is a scalar multiple of the identity. Moreover, the mean is only computed over the trivial irreps occurring in the input representations. These assumptions are necessary and sufficient conditions for the equivariance in expectation of this module, see Chapter 4.2 at https://gabri95.github.io/Thesis/thesis.pdf .
Similarly, if affine = True, a single scale is learnt per input irrep and the bias is applied only to the trivial irreps.
Note that the representations in the input field type do not need to be already decomposed into direct sums of irreps since this module can deal with changes of basis.
Warning
However, because the irreps in the input representations rarely appear in a contiguous way, this module might internally use advanced indexing, leading to some computational overhead. Modules like IIDBatchNorm2d or IIDBatchNorm3d, instead, share the same variance among all channels within the same field (and, therefore, over multiple irreps). This can be more efficient if the input field type contains multiple copies of a larger, reducible representation.
- Parameters
in_type (FieldType) – the input field type
eps (float, optional) – a value added to the denominator for numerical stability. Default: 1e-5
momentum (float, optional) – the value used for the running_mean and running_var computation. Can be set to None for cumulative moving average (i.e. simple average). Default: 0.1
affine (bool, optional) – if True, this module has learnable affine parameters. Default: True
track_running_stats (bool, optional) – when set to True, the module tracks the running mean and variance; when set to False, it does not track such statistics but uses batch statistics in both training and eval modes. Default: True
- forward(input)[source]
Apply batch normalization to the input feature map
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
Dropout
FieldDropout
- class FieldDropout(in_type, p=0.5, inplace=False)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Applies dropout to individual fields independently.
Notice that, unlike PointwiseDropout, this module acts on whole fields instead of single channels.
- Parameters
- forward(input)[source]
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
PointwiseDropout
- class PointwiseDropout(in_type, p=0.5, inplace=False)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Applies dropout to individual channels independently.
This class is just a wrapper for torch.nn.functional.dropout() in an EquivariantModule.
Only representations supporting pointwise non-linearities are accepted as input field type.
- Parameters
- forward(input)[source]
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
- export()[source]
Export this module to a normal PyTorch
torch.nn.Dropout
module and set to “eval” mode.
Other Modules
Sequential
- class SequentialModule(*args)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
A sequential container similar to torch.nn.Sequential.
The constructor accepts either a list or an ordered dict of EquivariantModule instances.
The module also supports indexing, slicing and iteration. If slicing with a step different from 1 is used, one should ensure that adjacent modules in the new sequence are compatible.
Example:
# Example of SequentialModule
s = escnn.gspaces.rot2dOnR2(8)
c_in = escnn.nn.FieldType(s, [s.trivial_repr]*3)
c_out = escnn.nn.FieldType(s, [s.regular_repr]*16)
model = escnn.nn.SequentialModule(
    escnn.nn.R2Conv(c_in, c_out, 5),
    escnn.nn.InnerBatchNorm(c_out),
    escnn.nn.ReLU(c_out),
)
len(model)   # returns 3
model[:2]    # returns another SequentialModule containing the first two modules

# Example with OrderedDict
s = escnn.gspaces.rot2dOnR2(8)
c_in = escnn.nn.FieldType(s, [s.trivial_repr]*3)
c_out = escnn.nn.FieldType(s, [s.regular_repr]*16)
model = escnn.nn.SequentialModule(OrderedDict([
    ('conv', escnn.nn.R2Conv(c_in, c_out, 5)),
    ('bn', escnn.nn.InnerBatchNorm(c_out)),
    ('relu', escnn.nn.ReLU(c_out)),
]))
- forward(input)[source]
- Parameters
input (GeometricTensor) – the input GeometricTensor
- Returns
the output tensor
- add_module(name, module)[source]
Append
module
to the sequence of modules applied in the forward pass.
- export()[source]
Export this module to a normal PyTorch
torch.nn.Sequential
module and set to “eval” mode.
Restriction
- class RestrictionModule(in_type, id)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Restricts the type of the input to the subgroup identified by id.
It computes the output type in the constructor and wraps the underlying tensor (torch.Tensor) of the input with the output type during the forward pass.
This module only acts as a wrapper for escnn.nn.FieldType.restrict() (or escnn.nn.GeometricTensor.restrict()). The accepted values of id depend on the underlying gspace in the input type in_type; check the documentation of the method escnn.gspaces.GSpace.restrict() of the gspace used for further information.
See also
escnn.nn.FieldType.restrict(), escnn.nn.GeometricTensor.restrict(), escnn.gspaces.GSpace.restrict()
- Parameters
in_type (FieldType) – the input field type
id – a valid id for a subgroup of the space associated with the input type
- export()[source]
Export this module to a normal PyTorch torch.nn.Identity module and set to “eval” mode.
Warning
Only works with PyTorch >= 1.2
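A hedged sketch restricting \(C_8\) features to the \(C_4\) subgroup; the id 4 below is an assumption valid for rot2dOnR2 gspaces, see escnn.gspaces.GSpace.restrict() for the ids accepted by other gspaces:
import torch
import escnn.gspaces
import escnn.nn

s = escnn.gspaces.rot2dOnR2(8)
c = escnn.nn.FieldType(s, [s.regular_repr] * 2)
restrict = escnn.nn.RestrictionModule(c, 4)   # assumed id of the C4 subgroup

x = escnn.nn.GeometricTensor(torch.randn(2, c.size, 16, 16), c)
y = restrict(x)   # same underlying tensor, re-typed for the restricted gspace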
Disentangle
- class DisentangleModule(in_type)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Disentangles the representations in the field type of the input.
This module only acts as a wrapper for escnn.group.disentangle(). In the constructor, it disentangles each representation in the input type to build the output type and pre-computes the change-of-basis matrices needed to transform each input field.
During the forward pass, each field in the input tensor is transformed with the change of basis corresponding to its representation.
- Parameters
in_type (FieldType) – the input field type
Upsampling
- class R2Upsampling(in_type, scale_factor=None, size=None, mode='bilinear', align_corners=False)[source]
Bases:
escnn.nn.modules.rdupsampling._RdUpsampling
Wrapper for torch.nn.functional.interpolate(). Check its documentation for further details.
Only "bilinear" and "nearest" methods are supported. However, "nearest" is not equivariant; using this method may result in broken equivariance. For this reason, we suggest using "bilinear" (the default value).
Warning
The module supports a size parameter as an alternative to scale_factor. However, the use of scale_factor should be preferred, since it guarantees both axes are scaled uniformly, which preserves rotation equivariance. A misuse of the parameter size can break the overall equivariance, since it might scale the two axes by different factors.
Warning
Even if the input tensor has a coords attribute, the output of this module will not have one.
- Parameters
in_type (FieldType) – the input field type
scale_factor (optional, int) – multiplier for spatial size
mode (str) – algorithm used for upsampling: nearest | bilinear. Default: bilinear
align_corners (bool) – if True, the corner pixels of the input and output tensors are aligned, thus preserving the values at those pixels. This only has effect when mode is bilinear. Default: False
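A usage sketch (field type and scale factor are example values):
import torch
import escnn.gspaces
import escnn.nn

s = escnn.gspaces.rot2dOnR2(8)
c = escnn.nn.FieldType(s, [s.regular_repr] * 2)
upsample = escnn.nn.R2Upsampling(c, scale_factor=2, mode='bilinear')

x = escnn.nn.GeometricTensor(torch.randn(2, c.size, 16, 16), c)
y = upsample(x)   # the spatial resolution is doubled: 16x16 -> 32x32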
- class R3Upsampling(in_type, scale_factor=None, size=None, mode='trilinear', align_corners=False)[source]
Bases:
escnn.nn.modules.rdupsampling._RdUpsampling
Wrapper for torch.nn.functional.interpolate(). Check its documentation for further details.
Only "trilinear" and "nearest" methods are supported. However, "nearest" is not equivariant; using this method may result in broken equivariance. For this reason, we suggest using "trilinear" (the default value).
Warning
The module supports a size parameter as an alternative to scale_factor. However, the use of scale_factor should be preferred, since it guarantees all axes are scaled uniformly, which preserves rotation equivariance. A misuse of the parameter size can break the overall equivariance, since it might scale the axes by different factors.
Warning
Even if the input tensor has a coords attribute, the output of this module will not have one.
- Parameters
in_type (FieldType) – the input field type
scale_factor (optional, int) – multiplier for spatial size
mode (str) – algorithm used for upsampling: nearest | trilinear. Default: trilinear
align_corners (bool) – if True, the corner pixels of the input and output tensors are aligned, thus preserving the values at those pixels. This only has effect when mode is trilinear. Default: False
Multiple
- class MultipleModule(in_type, labels, modules, reshuffle=0)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Splits the input tensor into multiple branches identified by the input labels and applies to each of them the corresponding module in modules.
A label is associated with each field in the input type, while modules assigns a module to apply to each label (or set of labels). modules should be a list of pairs, each containing an EquivariantModule and a label (or a list of labels).
During the forward pass, fields are grouped by their labels and the input tensor is split accordingly. Then, each sub-tensor is passed to the corresponding module in modules.
If reshuffle is set to a positive integer, a copy of the input tensor is first built by sorting the fields according to the value set:
1: fields are sorted by their labels
2: fields are sorted by their labels and, then, by their size
3: fields are sorted by their labels, by their size and, then, by their type
In this way, fields that need to be retrieved together are contiguous and it is possible to exploit slicing to split the tensor. By default, reshuffle = 0, which means that no sorting is performed and, so, if input fields are not contiguous, this layer will use indexing to retrieve sub-tensors.
This module wraps a BranchingModule followed by a MergeModule.
- Parameters
- forward(input)[source]
Split the input tensor according to the labels, apply each module to the corresponding input sub-tensors and stack the results.
- Parameters
input (GeometricTensor) – the input GeometricTensor
- Returns
the concatenation of the output of each module
Reshuffle
- class ReshuffleModule(in_type, permutation)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Permutes the fields of the input tensor according to the input permutation.
The parameter permutation should be a list containing a permutation of the integers {0, 1, ..., n-1}, where n is the number of fields of in_type (see escnn.nn.FieldType.__len__()).
Mask
- class MaskModule(in_type, S, margin=0.0, sigma=2.0)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Performs an element-wise multiplication of the input with a mask of shape \(S^n\), where \(n\) is the dimensionality of the underlying space.
The mask has value \(1\) in all pixels with distance smaller than \(\frac{S - 1}{2} \times (1 - \frac{\mathrm{margin}}{100})\) from the center of the mask and \(0\) elsewhere. Values change smoothly between the two regions.
This operation is useful to remove, from an input image or feature map, the part of the signal defined on the pixels which lie outside the circle inscribed in the grid. Because a rotation would move these pixels outside the grid, this information would in any case be discarded when rotating an image. However, allowing a model to use this information might break the guaranteed equivariance, as rotated and non-rotated inputs have different information content.
Note
The input tensors provided to this module must have the following dimensions: \(B \times C \times S^n\), where \(B\) is the minibatch dimension, \(C\) is the channels dimension, and \(S^n\) are the \(n\) spatial dimensions (corresponding to the Euclidean base space \(\R^n\)) associated with the given input field type, i.e. in_type.gspace.dimensionality. Each Euclidean dimension must be of size \(S\).
For example, if \(S=10\) and in_type.gspace.dimensionality=2, then the input tensors should be of size \(B \times C \times 10 \times 10\). If in_type.gspace.dimensionality=3 instead, then the input tensors should be of size \(B \times C \times 10 \times 10 \times 10\).
- Parameters
in_type (FieldType) – input field type
S (int) – the shape of the mask and the expected inputs
margin (float, optional) – margin around the mask in percentage with respect to the radius of the mask
sigma (float, optional) – how quickly masked pixels should approach 0. This can be thought of as a standard deviation in units of pixels/voxels. For example, the default value of 2 means that only 5% of the original signal will remain 4 px into the masked region.
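A usage sketch (the grid size \(S=33\) and the margin below are example values):
import torch
import escnn.gspaces
import escnn.nn

s = escnn.gspaces.rot2dOnR2(8)
c = escnn.nn.FieldType(s, [s.trivial_repr] * 3)
mask = escnn.nn.MaskModule(c, 33, margin=1.0)

x = escnn.nn.GeometricTensor(torch.randn(4, c.size, 33, 33), c)
y = mask(x)   # pixels outside the disk inscribed in the 33x33 grid are smoothly zeroed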
Identity
- class IdentityModule(in_type)[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Simple module which does not perform any operation on the input tensor.
- Parameters
in_type (FieldType) – input (and output) type of this module
- forward(input)[source]
Returns the input tensor.
- Parameters
input (GeometricTensor) – the input GeometricTensor
- Returns
the output tensor
- export()[source]
Export this module to a normal PyTorch torch.nn.Identity module and set to “eval” mode.
Warning
Only works with PyTorch >= 1.2
HarmonicPolynomialR3
- class HarmonicPolynomialR3(L, group='so3')[source]
Bases:
escnn.nn.modules.equivariant_module.EquivariantModule
Module which computes the harmonic polynomials in \(\R^3\) up to order L.
The argument group can be a string (“so3” or “o3”) or a group (an instance of SO3 or O3).
This equivariant module takes a set of 3-dimensional points transforming according to the standard_representation() of \(SO(3)\) (or the standard_representation() of \(O(3)\)) and outputs \((L+1)^2\)-dimensional feature vectors transforming like spherical harmonics according to the bl_sphere_representation() of \(SO(3)\) (or the bl_sphere_representation() of \(O(3)\)) with L=L.
Harmonic polynomials are related to the spherical harmonics. Check the Wikipedia page about them.
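An illustrative sketch, assuming the module exposes its input type as in_type (like other EquivariantModule instances) and accepts a batch of points wrapped in a GeometricTensor:
import torch
import escnn.nn

# harmonic polynomials up to order L=3 -> (3+1)**2 = 16 output features per point
hp = escnn.nn.HarmonicPolynomialR3(3, group='so3')

points = torch.randn(100, 3)                      # a batch of 100 points in R^3
x = escnn.nn.GeometricTensor(points, hp.in_type)
y = hp(x)
assert y.shape == (100, 16)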
Weight Initialization
- generalized_he_init(tensor, basismanager, cache=False)[source]
Initialize the weights of a convolutional layer with a generalized He weight initialization method.
Because the computation of the variances can be expensive, to save time on consecutive runs of the same model, it is possible to cache the tensor containing the variance of each weight for a specific basismanager. This can be useful if a network contains multiple convolution layers of the same kind (same input and output types, same kernel size, etc.) or if one needs to train the same network from scratch multiple times (e.g. during a hyper-parameter search over the learning rate).
Note
The variance tensor is cached in memory and therefore is only available to the current process.
- Parameters
tensor (torch.Tensor) – the tensor containing the weights
basismanager (BasisManager) – the basis expansion method
cache (bool, optional) – cache the variance tensor. By default, cache=False
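A hedged sketch of manually re-initializing an R2Conv; the attribute names weights and basisexpansion are assumptions about the layer's internals (R2Conv already applies this initialization by default):
import escnn.gspaces
import escnn.nn
from escnn.nn import init

s = escnn.gspaces.rot2dOnR2(8)
c_in = escnn.nn.FieldType(s, [s.trivial_repr] * 3)
c_out = escnn.nn.FieldType(s, [s.regular_repr] * 4)
conv = escnn.nn.R2Conv(c_in, c_out, kernel_size=5)

# cache=True re-uses the pre-computed variances for other layers with the same basis
init.generalized_he_init(conv.weights.data, conv.basisexpansion, cache=True)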
- deltaorthonormal_init(tensor, basismanager)[source]
Initialize the weights of a convolutional layer with delta-orthogonal initialization.
If no ‘radius’ attribute is present in basismanager.get_basis_info(), it is assumed that the parameters parametrize a Linear layer.
- Parameters
tensor (torch.Tensor) – the tensor containing the weights
basismanager (BasisManager) – the basis expansion method