e2cnn.nn
This subpackage provides implementations of equivariant neural network modules.
In an equivariant network, features are associated with a transformation law under actions of a symmetry group.
The transformation law of a feature field is implemented by its FieldType, which can be interpreted as a data type.
A GeometricTensor wraps a torch.Tensor to endow it with a FieldType.
Geometric tensors are processed by EquivariantModules, which are torch.nn.Modules that guarantee the specified behavior of their output fields given a transformation of their input fields.
This subpackage depends on e2cnn.group and e2cnn.gspaces.
To enable efficient deployment of equivariant networks, many EquivariantModules implement an export() method which converts a trained equivariant module into a pure PyTorch module, with few or no dependencies on e2cnn.
Not all modules support this feature yet, so read each module’s documentation to check whether it implements this method or not.
We provide a simple example:
# build a simple equivariant model using a SequentialModule
s = e2cnn.gspaces.Rot2dOnR2(8)
c_in = e2cnn.nn.FieldType(s, [s.trivial_repr]*3)
c_hid = e2cnn.nn.FieldType(s, [s.regular_repr]*3)
c_out = e2cnn.nn.FieldType(s, [s.regular_repr]*1)
net = SequentialModule(
R2Conv(c_in, c_hid, 5, bias=False),
InnerBatchNorm(c_hid),
ReLU(c_hid, inplace=True),
PointwiseMaxPool(c_hid, kernel_size=3, stride=2, padding=1),
R2Conv(c_hid, c_out, 3, bias=False),
InnerBatchNorm(c_out),
ELU(c_out, inplace=True),
GroupPooling(c_out)
)
# train the model
# ...
# export the model
net.eval()
net_exported = net.export()
print(net)
> SequentialModule(
> (0): R2Conv([8-Rotations: {irrep_0, irrep_0, irrep_0}], [8-Rotations: {regular, regular, regular}], kernel_size=5, stride=1, bias=False)
> (1): InnerBatchNorm([8-Rotations: {regular, regular, regular}], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
> (2): ReLU(inplace=True, type=[8-Rotations: {regular, regular, regular}])
> (3): PointwiseMaxPool()
> (4): R2Conv([8-Rotations: {regular, regular, regular}], [8-Rotations: {regular}], kernel_size=3, stride=1, bias=False)
> (5): InnerBatchNorm([8-Rotations: {regular}], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
> (6): ELU(alpha=1.0, inplace=True, type=[8-Rotations: {regular}])
> (7): GroupPooling([8-Rotations: {regular}])
> )
print(net_exported)
> Sequential(
> (0): Conv2d(3, 24, kernel_size=(5, 5), stride=(1, 1), bias=False)
> (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
> (2): ReLU(inplace=True)
> (3): MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), dilation=(1, 1), ceil_mode=False)
> (4): Conv2d(24, 8, kernel_size=(3, 3), stride=(1, 1), bias=False)
> (5): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
> (6): ELU(alpha=1.0, inplace=True)
> (7): MaxPoolChannels(kernel_size=8)
> )
# check that the two models are equivalent
x = torch.randn(10, c_in.size, 31, 31)
x = GeometricTensor(x, c_in)
y1 = net(x).tensor
y2 = net_exported(x.tensor)
assert torch.allclose(y1, y2)
Field Type
- class FieldType(gspace, representations)[source]
A FieldType can be interpreted as the data type of a feature space. It describes:
the base space on which the feature fields live and the symmetries considered
the transformation law of the feature fields under the action of the fiber group
The former is formalized by a choice of gspace, while the latter is determined by a choice of group representations (representations), passed as a list of Representation instances. Each single representation in this list corresponds to one independent feature field contained in the feature space. The input representations need to belong to gspace’s fiber group (e2cnn.gspaces.GSpace.fibergroup).
Note
Mathematically, this class describes a (trivial) vector bundle, associated to the symmetry group \((\R^D, +) \rtimes G\).
Given a principal bundle \(\pi: (\R^D, +) \rtimes G \to \R^D, tg \mapsto tG\) with fiber group \(G\), an associated vector bundle has the same base space \(\R^D\) but its fibers are vector spaces like \(\mathbb{R}^c\). Moreover, these vector spaces are associated to a \(c\)-dimensional representation \(\rho\) of the fiber group \(G\) and transform accordingly.
The representation \(\rho\) is defined as the direct sum of the representations \(\{\rho_i\}_i\) in representations. See also directsum().
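For example, one can build a field type by combining a few scalar (trivial) fields and regular fields (a minimal sketch, analogous to the example at the top of this page):

s = e2cnn.gspaces.Rot2dOnR2(8)

# 2 scalar fields and 3 regular fields of the group of 8 rotations
type = e2cnn.nn.FieldType(s, [s.trivial_repr]*2 + [s.regular_repr]*3)

# number of independent fields
len(type)
>> 5

# total number of channels: 2*1 + 3*8
type.size
>> 26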
- Parameters
gspace (GSpace) – the space where the feature fields live and its symmetries
representations (list) – a list of Representations of the gspace’s fiber group, determining the transformation laws of the feature fields
- Variables
- property fibergroup: e2cnn.group.group.Group
The fiber group of gspace.
- Returns
the fiber group
- property representation: e2cnn.group.representation.Representation
The (combined) representations of this field type. They describe how the feature vectors transform under the fiber group action, that is, how the channels mix.
It is the direct sum (directsum()) of the representations in e2cnn.nn.FieldType.representations.
Because a feature space can contain a very large number of feature fields, computing this representation as the direct sum of many small representations can be expensive. Hence, this representation is only built the first time it is explicitly used, in order to avoid unnecessary overhead when it is not needed.
- Returns
the Representation describing the whole feature space
- property irreps
Ordered list of irreps contained in the representation of the field type. It is the concatenation of the irreps in each representation in e2cnn.nn.FieldType.representations.
- Returns
list of irreps
- property change_of_basis: scipy.sparse.coo.coo_matrix
The change of basis matrix which decomposes the field type’s representation into irreps, given as a sparse (block-diagonal) matrix (scipy.sparse.coo_matrix).
It is the direct sum of the change of basis matrices of each representation in e2cnn.nn.FieldType.representations.
See also
e2cnn.group.Representation.change_of_basis
- Returns
the change of basis
- property change_of_basis_inv
Inverse of the (sparse) change of basis matrix. See e2cnn.nn.FieldType.change_of_basis for more details.
- Returns
the inverted change of basis
- get_dense_change_of_basis()[source]
The method returns a dense torch.Tensor containing a copy of the change-of-basis matrix.
See also
See e2cnn.nn.FieldType.change_of_basis for more details.
- get_dense_change_of_basis_inv()[source]
The method returns a dense torch.Tensor containing a copy of the inverse of the change-of-basis matrix.
See also
See e2cnn.nn.FieldType.change_of_basis for more details.
- transform(input, element)[source]
The method takes a PyTorch tensor compatible with this type (i.e. whose spatial dimensions are supported by the base space and whose number of channels equals the e2cnn.nn.FieldType.size of this type) and an element of the fiber group of this type.
It transforms the input tensor according to the group representation associated with the input element and its (induced) action on the base space.
Warning
This method is internally implemented using numpy. This means that the input tensor is detached (and moved to CPU) before the transformation; therefore, no gradient is propagated back through this operation.
See also
See e2cnn.nn.GeometricTensor.transform_fibers() to transform only the fibers, i.e. without transforming the base space. See e2cnn.gspaces.GSpace.featurefield_action() for more details.
- Parameters
input (torch.Tensor) – input tensor
element – element of the fiber group
- Returns
transformed tensor
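For instance, one can apply all the elements in testing_elements to a tensor (a minimal sketch; with Rot2dOnR2(8) the fiber group is the group of 8 discrete rotations and its elements are integers):

s = e2cnn.gspaces.Rot2dOnR2(8)
type = e2cnn.nn.FieldType(s, [s.regular_repr])

x = torch.randn(1, type.size, 11, 11)

# rotate both the base space and the channels with each group element
for g in type.testing_elements:
    x_g = type.transform(x, g)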
- restrict(id)[source]
Reduce the symmetries modeled by the FieldType by choosing a subgroup of its fiber group as specified by id. This implies a restriction of each representation in e2cnn.nn.FieldType.representations to this subgroup.
See also
Check the documentation of the restrict() method in the subclass of GSpace used for a description of the parameter id.
- Parameters
id – identifier of the subgroup to which the FieldType and its e2cnn.nn.FieldType.representations should be restricted
- Returns
the restricted type
- sorted()[source]
Return a new field type containing the fields of the current one, sorted by their dimensionalities. It is built by sorting the e2cnn.nn.FieldType.representations of this field type.
- Returns
the sorted field type
- __add__(other)[source]
Return a field type associated with the direct sum \(\rho = \rho_1 \oplus \rho_2\) of the representations \(\rho_1\) and \(\rho_2\) of two field types.
In practice, the method builds a new FieldType using the concatenation of the lists e2cnn.nn.FieldType.representations of the two field types.
The two field types need to be associated with the same GSpace.
- Parameters
other (FieldType) – the other addend
- Returns
the direct sum
- __len__()[source]
Return the number of feature fields in this FieldType, i.e. the length of e2cnn.nn.FieldType.representations.
Note
This is in general different from e2cnn.nn.FieldType.size.
- Returns
the number of fields in this type
- fields_names()[source]
Return an ordered list containing the names of the representation associated with each field.
- Returns
the list of fields’ representations’ names
- index_select(index)[source]
Build a new FieldType from the current one by taking the Representations selected by the input index.
- Parameters
index (list) – a list of integers in the range {0, ..., N-1}, where N is the number of representations in the current field type
- Returns
the new field type
- property fields_end: numpy.ndarray
Array containing the index of the first channel following each field. More precisely, the integer in the \(i\)-th position is equal to the index of the last channel of the \(i\)-th field plus \(1\).
- property fields_start: numpy.ndarray
Array containing the index of the first channel of each field. More precisely, the integer in the \(i\)-th position is equal to the index of the first channel of the \(i\)-th field.
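For instance (a minimal sketch):

s = e2cnn.gspaces.Rot2dOnR2(8)
type = e2cnn.nn.FieldType(s, [s.regular_repr, s.trivial_repr])

# the regular field occupies channels 0 to 7, the trivial field occupies channel 8,
# so fields_start contains [0, 8] and fields_end contains [8, 9]
starts = type.fields_start
ends = type.fields_end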
- group_by_labels(labels)[source]
Associate a label to each feature field (or representation in e2cnn.nn.FieldType.representations) and group them accordingly into new FieldTypes.
- Parameters
labels (list) – a list of strings with length equal to the number of representations in e2cnn.nn.FieldType.representations
- Returns
a dictionary mapping each different input label to a new field type
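For instance (a minimal sketch; the labels are arbitrary strings):

s = e2cnn.gspaces.Rot2dOnR2(4)
type = e2cnn.nn.FieldType(s, [s.trivial_repr, s.regular_repr, s.trivial_repr])

# group the two scalar fields and the regular field into two new field types
grouped = type.group_by_labels(['scalar', 'regular', 'scalar'])

# grouped['scalar'] contains the two trivial fields, grouped['regular'] the regular one
grouped['scalar'].size
>> 2
grouped['regular'].size
>> 4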
- __iter__()[source]
It is possible to iterate over all representations in a field type by using FieldType as an iterable object.
- property testing_elements
Alias for self.gspace.testing_elements.
Geometric Tensor
- class GeometricTensor(tensor, type)[source]
A GeometricTensor can be interpreted as a typed tensor. It wraps a common torch.Tensor and endows it with a (compatible) FieldType as its transformation law.
The FieldType describes the action of a group \(G\) on the tensor. This action includes both a transformation of the base space and a transformation of the channels according to a \(G\)-representation \(\rho\).
All e2cnn neural network operations take GeometricTensors as inputs and outputs. They perform dynamic type checking, ensuring that the transformation laws of the data and of the operation match. See also EquivariantModule.
As usual, the first dimension of the tensor is interpreted as the batch dimension. The second is the fiber (or channel) dimension, which is associated with a group representation by the field type. The following dimensions are the spatial dimensions (like in a conventional CNN).
The operations of addition and scalar multiplication are supported. For example:
gs = e2cnn.gspaces.Rot2dOnR2(8)
type = e2cnn.nn.FieldType(gs, [gs.regular_repr]*3)

t1 = e2cnn.nn.GeometricTensor(torch.randn(1, 24, 3, 3), type)
t2 = e2cnn.nn.GeometricTensor(torch.randn(1, 24, 3, 3), type)

# addition
t3 = t1 + t2

# scalar product
t3 = t1 * 3.

# scalar product also supports tensors containing only one scalar
t3 = t1 * torch.tensor(3.)

# inplace operations are also supported
t1 += t2
t2 *= 3.
Warning
The multiplication of a PyTorch tensor containing only a scalar with a GeometricTensor is only supported when using PyTorch 1.4 or higher (see this issue )
A GeometricTensor supports slicing in a way similar to PyTorch’s torch.Tensor. More precisely, slicing along the batch (1st) and the spatial (3rd, 4th, …) dimensions works as usual. However, slicing the fiber (2nd) dimension would break equivariance when splitting channels belonging to the same field. To prevent this, slicing on the second dimension is defined over fields instead of channels.
Warning
GeometricTensor only supports basic slicing but it does not support advanced indexing (see NumPy’s documentation about indexing for more details). Moreover, in contrast to NumPy and PyTorch, an index containing a single integer value does not reduce the dimensionality of the tensor. In this way, the resulting tensor can always be interpreted as a GeometricTensor.
We give a few examples to illustrate this behavior:
# Example of GeometricTensor slicing
space = e2cnn.gspaces.Rot2dOnR2(4)
type = e2cnn.nn.FieldType(space, [
    # field type          # index # size
    space.regular_repr,   # 0     # 4
    space.regular_repr,   # 1     # 4
    space.irrep(1),       # 2     # 2
    space.irrep(1),       # 3     # 2
    space.trivial_repr,   # 4     # 1
    space.trivial_repr,   # 5     # 1
    space.trivial_repr,   # 6     # 1
])                        #         sum = 15

# this FieldType contains 7 fields
len(type)
>> 7

# the size of this FieldType is equal to the sum of the sizes of each of its fields
type.size
>> 15

geom_tensor = e2cnn.nn.GeometricTensor(torch.randn(10, type.size, 9, 9), type)

geom_tensor.shape
>> torch.Size([10, 15, 9, 9])

geom_tensor[1:3, :, 2:5, 2:5].shape
>> torch.Size([2, 15, 3, 3])

geom_tensor[..., 2:5].shape
>> torch.Size([10, 15, 9, 3])

# the tensor contains the fields 1:4, i.e. 1, 2 and 3
# these fields have size, respectively, 4, 2 and 2
# so the resulting tensor has 8 channels
geom_tensor[:, 1:4, ...].shape
>> torch.Size([10, 8, 9, 9])

# the tensor contains the fields 0:6:2, i.e. 0, 2 and 4
# these fields have size, respectively, 4, 2 and 1
# so the resulting tensor has 7 channels
geom_tensor[:, 0:6:2].shape
>> torch.Size([10, 7, 9, 9])

# the tensor contains only the field 2, which has size 2
# note, also, that even though a single index is used for the batch dimension, the resulting tensor
# still has 4 dimensions
geom_tensor[3, 2].shape
>> torch.Size([1, 2, 9, 9])
Warning
Slicing over the fiber (2nd) dimension with step > 1 or with a negative step is converted into indexing over the channels. This means that, in these cases, slicing behaves like advanced indexing in PyTorch and NumPy, returning a copy instead of a view. For more details, see the note here and NumPy’s docs.
Note
Slicing is not supported for setting values inside the tensor (i.e. __setitem__() is not implemented). Indeed, depending on the values which are assigned, this operation could break the symmetry of the tensor, which may not transform anymore according to its transformation law (specified by type). In case this feature is necessary, one can directly access the underlying torch.Tensor, e.g. geom_tensor.tensor[:3, :, 2:5, 2:5] = torch.randn(3, 4, 3, 3), although this is not recommended.
- Parameters
tensor (torch.Tensor) – the tensor data
type (FieldType) – the type of the tensor, modeling its transformation law
- Variables
~.tensor (torch.Tensor) – the underlying tensor data
~.type (FieldType) – the field type describing the transformation law of the tensor
- restrict(id)[source]
Restrict the field type of this tensor.
The method returns a new GeometricTensor whose type is equal to this tensor’s type restricted to a subgroup \(H < G\) (see e2cnn.nn.FieldType.restrict()). The restricted type is associated with the restricted representation \(\Res{H}{G}\rho\) of the \(G\)-representation \(\rho\) associated to this tensor’s type. The input id specifies the subgroup \(H < G\).
Notice that the underlying tensor instance will be shared between the current tensor and the returned one.
Warning
The method builds the new representation on the fly; hence, if this operation is needed at run time, we suggest using e2cnn.nn.RestrictionModule, which pre-computes the new representation offline.
See also
Check the documentation of the restrict() method in the GSpace instance used for a description of the parameter id.
- Parameters
id – the id identifying the subgroup \(H\) the representations are restricted to
- Returns
the geometric tensor with the restricted representations
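For instance (a minimal sketch; for Rot2dOnR2 a subgroup of rotations is identified by an integer id):

s = e2cnn.gspaces.Rot2dOnR2(8)
type = e2cnn.nn.FieldType(s, [s.regular_repr])
t = e2cnn.nn.GeometricTensor(torch.randn(1, type.size, 5, 5), type)

# restrict the symmetries to the subgroup of 4 rotations
t_sub = t.restrict(4)

# the underlying data is shared; only the transformation law changes
assert t_sub.tensor is t.tensor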
- split(breaks)[source]
Split this tensor along the channel dimension into a list of smaller tensors. The original tensor is split at the fields specified by the index list breaks.
If the tensor is associated with the list of fields \(\{\rho_i\}_{i=0}^{L-1}\) (where \(L\) equals len(self.type); see also e2cnn.nn.FieldType.representations), the \(j\)-th output tensor (\(j>0\)) will contain the fields \(\rho_{\text{breaks}[j-1]}, \dots, \rho_{\text{breaks}[j]-1}\) of the original tensor. The \(j=0\)-th tensor contains the fields \(\rho_{0}, \dots, \rho_{\text{breaks}[0]-1}\), while the last tensor (\(j = \text{len(breaks)}\)) contains the last fields \(\rho_{\text{breaks}[-1]}, \dots, \rho_{L-1}\).
Note
breaks must be a sorted list of integers greater than 0 and smaller than len(self.type) - 1.
If breaks = None, the tensor is split at each field. This is equivalent to using breaks = list(range(1, len(self.type))).
Example
space = e2cnn.gspaces.Rot2dOnR2(4)
type = e2cnn.nn.FieldType(space, [
    space.regular_repr,  # size = 4
    space.regular_repr,  # size = 4
    space.irrep(1),      # size = 2
    space.irrep(1),      # size = 2
    space.trivial_repr,  # size = 1
    space.trivial_repr,  # size = 1
    space.trivial_repr,  # size = 1
])                       # sum = 15

type.size
>> 15

geom_tensor = e2cnn.nn.GeometricTensor(torch.randn(10, type.size, 7, 7), type)

geom_tensor.shape
>> torch.Size([10, 15, 7, 7])

# split the tensor in 3 parts
len(geom_tensor.split([4, 6]))
>> 3

# the first contains
# - the first 2 regular fields (2*4 = 8 channels)
# - 2 vector fields (irrep(1)) (2*2 = 4 channels)
# and, therefore, contains 12 channels
geom_tensor.split([4, 6])[0].shape
>> torch.Size([10, 12, 7, 7])

# the second contains only 2 scalar (trivial) fields (2*1 = 2 channels)
geom_tensor.split([4, 6])[1].shape
>> torch.Size([10, 2, 7, 7])

# the last contains only 1 scalar (trivial) field (1*1 = 1 channel)
geom_tensor.split([4, 6])[2].shape
>> torch.Size([10, 1, 7, 7])
- Parameters
breaks (list) – indices of the fields at which to split the tensor
- Returns
list of GeometricTensors into which the original tensor is chunked
- transform(element)[source]
Transform the current tensor according to the group representation associated with the input element and its induced action on the base space.
Warning
The input tensor is detached before the transformation; therefore, no gradient is backpropagated through this operation.
See e2cnn.nn.GeometricTensor.transform_fibers() to transform only the fibers, i.e. without transforming the base space.
- Parameters
element – an element of the group of symmetries of the fiber.
- Returns
the transformed tensor
- transform_fibers(element)[source]
Transform the feature vectors of the underlying tensor according to the group representation associated to the input element.
Interpreting the tensor as a vector-valued signal \(f: X \to \mathbb{R}^c\) over a base space \(X\) (where \(c\) is the number of channels of the tensor), given the input element \(g \in G\) (with \(G\) the fiber group), the method returns the new signal \(f'\):
\[f'(x) := \rho(g) f(x)\]
for each point \(x \in X\) of the base space, where \(\rho\) is the representation of \(G\) in the field type of this tensor.
Notice that the input element has to be an element of the fiber group of this tensor’s field type.
See also
See e2cnn.nn.GeometricTensor.transform() to transform the whole tensor.
- Parameters
element – an element of the group of symmetries of the fiber.
- Returns
the transformed tensor
- property shape
Alias for self.tensor.shape.
- to(*args, **kwargs)[source]
Alias for self.tensor.to(*args, **kwargs).
Applies torch.Tensor.to() to the underlying tensor and wraps the resulting tensor in a new GeometricTensor with the same type.
- __getitem__(slices)[source]
A GeometricTensor supports slicing in a way similar to PyTorch’s torch.Tensor. More precisely, slicing along the batch (1st) and the spatial (3rd, 4th, …) dimensions works as usual. However, slicing along the channel dimension could break equivariance by splitting the channels belonging to the same field. For this reason, slicing on the second dimension is not defined over the channels but over fields.
When a contiguous (step=1) slice is used over the fields/channels dimension (the 2nd axis), it is converted into a contiguous slice over the channels. This is not possible when the step is greater than 1 or negative. In such cases, the slice over the fields needs to be converted into an index over the channels.
Moreover, when a single integer is used to index an axis, that axis is not discarded as in PyTorch but is preserved with size 1.
Slicing is not supported for setting values inside the tensor (i.e. object.__setitem__() is not implemented).
- __add__(other)[source]
Add two compatible GeometricTensors using pointwise addition. The two tensors need to have the same shape and be associated with the same field type.
- Parameters
other (GeometricTensor) – the other geometric tensor
- Returns
the sum
- __sub__(other)[source]
Subtract two compatible GeometricTensors using pointwise subtraction. The two tensors need to have the same shape and be associated with the same field type.
- Parameters
other (GeometricTensor) – the other geometric tensor
- Returns
their difference
- __iadd__(other)[source]
Add a compatible GeometricTensor to this tensor in place. The two tensors need to have the same shape and be associated with the same field type.
- Parameters
other (GeometricTensor) – the other geometric tensor
- Returns
this tensor
- __isub__(other)[source]
Subtract a compatible GeometricTensor from this tensor in place. The two tensors need to have the same shape and be associated with the same field type.
- Parameters
other (GeometricTensor) – the other geometric tensor
- Returns
this tensor
- __mul__(other)[source]
Scalar product of this GeometricTensor with a scalar.
The scalar can be a float number or a torch.Tensor containing only one scalar (i.e. torch.numel() should return 1).
- __rmul__(other)
Scalar product of this GeometricTensor with a scalar.
The scalar can be a float number or a torch.Tensor containing only one scalar (i.e. torch.numel() should return 1).
- __imul__(other)[source]
Scalar product of this GeometricTensor with a scalar. The operation is done in place.
The scalar can be a float number or a torch.Tensor containing only one scalar (i.e. torch.numel() should return 1).
Equivariant Module
- class EquivariantModule[source]
Abstract base class for all equivariant modules.
An EquivariantModule is a subclass of torch.nn.Module. It follows that any subclass of EquivariantModule needs to implement the forward() method. With respect to a general torch.nn.Module, an equivariant module implements a typed function, as both its input and its output are associated with specific FieldTypes. Therefore, usually, the inputs and the outputs of an equivariant module are not just instances of torch.Tensor but GeometricTensors.
As a subclass of torch.nn.Module, it supports most of the commonly used methods (e.g. torch.nn.Module.to(), torch.nn.Module.cuda(), torch.nn.Module.train() or torch.nn.Module.eval()).
Many equivariant modules implement an export() method which converts the module to eval mode and returns a pure PyTorch implementation of it. This can be used after training to efficiently deploy the model without, for instance, the overhead of the automatic type checking performed by all the modules in this library.
Warning
Not all modules implement this feature yet. If the export() method is called on a module which does not implement it yet, a NotImplementedError is raised. Check the documentation of each individual module to understand if the method is implemented.
- Variables
~.in_type (FieldType) – type of the GeometricTensor expected as input
~.out_type (FieldType) – type of the GeometricTensor returned as output
- abstract evaluate_output_shape(input_shape)[source]
Compute the shape of the output tensor which would be generated by this module when a tensor with shape input_shape is provided as input.
- Parameters
input_shape (tuple) – shape of the input tensor
- Returns
shape of the output tensor
- check_equivariance(atol=1e-07, rtol=1e-05)[source]
Method that automatically tests the equivariance of the current module. The default implementation of this method relies on e2cnn.nn.GeometricTensor.transform() and uses the group elements in testing_elements.
This method can be overwritten for custom tests.
- Returns
a list containing, for each testing element, a pair with that element and the corresponding equivariance error
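For example, a newly built layer can be tested right away (a minimal sketch):

s = e2cnn.gspaces.Rot2dOnR2(8)
in_type = e2cnn.nn.FieldType(s, [s.trivial_repr])
out_type = e2cnn.nn.FieldType(s, [s.regular_repr])

conv = e2cnn.nn.R2Conv(in_type, out_type, kernel_size=3)

# each entry pairs a testing element with the measured equivariance error
errors = conv.check_equivariance()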
- export()[source]
Export recursively each submodule to a normal PyTorch module and set to “eval” mode.
Warning
Not all modules implement this feature yet. If the export() method is called on a module which does not implement it yet, a NotImplementedError is raised. Check the documentation of each individual module to understand if the method is implemented.
Utils
direct sum
- tensor_directsum(tensors)[source]
Concatenate a list of GeometricTensors along the channels dimension (dim=1). The input tensors have to be compatible: they need to have the same shape except for the channels dimension (dim=1). In the resulting GeometricTensor, the channels dimension will be associated with the direct sum representation of the representations of the input tensors.
See also
directsum()
- Parameters
tensors (list) – a list of GeometricTensors
- Returns
the direct sum of the inputs
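For instance (a minimal sketch):

s = e2cnn.gspaces.Rot2dOnR2(8)
t1 = e2cnn.nn.GeometricTensor(torch.randn(2, 8, 7, 7), e2cnn.nn.FieldType(s, [s.regular_repr]))
t2 = e2cnn.nn.GeometricTensor(torch.randn(2, 1, 7, 7), e2cnn.nn.FieldType(s, [s.trivial_repr]))

# concatenate along the channel dimension: the result has 8 + 1 = 9 channels
t = e2cnn.nn.tensor_directsum([t1, t2])
t.shape
>> torch.Size([2, 9, 7, 7])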
Planar Convolution and Differential Operators
R2Conv
- class R2Conv(in_type, out_type, kernel_size, padding=0, stride=1, dilation=1, padding_mode='zeros', groups=1, bias=True, basisexpansion='blocks', sigma=None, frequencies_cutoff=None, rings=None, maximum_offset=None, recompute=False, basis_filter=None, initialize=True)[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
G-steerable planar convolution mapping between the input and output FieldTypes specified by the parameters in_type and out_type. This operation is equivariant under the action of \(\R^2\rtimes G\), where \(G\) is the e2cnn.nn.FieldType.fibergroup of in_type and out_type.
Specifically, let \(\rho_\text{in}: G \to \GL{\R^{c_\text{in}}}\) and \(\rho_\text{out}: G \to \GL{\R^{c_\text{out}}}\) be the representations specified by the input and output field types. Then R2Conv guarantees an equivariant mapping
\[\kappa \star [\mathcal{T}^\text{in}_{g,u} . f] = \mathcal{T}^\text{out}_{g,u} . [\kappa \star f] \qquad\qquad \forall g \in G, u \in \R^2\]
where the transformations of the input and output fields are given by
\[\begin{split}[\mathcal{T}^\text{in}_{g,u} . f](x) &= \rho_\text{in}(g)f(g^{-1} (x - u)) \\ [\mathcal{T}^\text{out}_{g,u} . f](x) &= \rho_\text{out}(g)f(g^{-1} (x - u)) \\\end{split}\]
The equivariance of G-steerable convolutions is guaranteed by restricting the space of convolution kernels to an equivariant subspace. As proven in 3D Steerable CNNs, this parametrizes the most general equivariant convolutional map between the input and output fields. For feature fields on \(\R^2\) (e.g. images), the complete G-steerable kernel space for \(G \leq \O2\) is derived in General E(2)-Equivariant Steerable CNNs.
During training, in each forward pass the module expands the basis of G-steerable kernels with learned weights before calling torch.nn.functional.conv2d(). When eval() is called, the filter is built with the current trained weights and stored for future reuse, such that no overhead of expanding the kernel remains.
Warning
When train() is called, the attributes filter and expanded_bias are discarded to avoid mismatches with the learnable expansion coefficients. See also e2cnn.nn.R2Conv.train().
This behaviour can cause problems when storing the state_dict() of a model while in one mode and later loading it into a model in a different mode, as the attributes of the class change. To avoid this issue, we recommend converting the model to eval mode before storing or loading the state dictionary.
The learnable expansion coefficients of this module can be initialized with the methods in e2cnn.nn.init. By default, the weights are initialized in the constructor using generalized_he_init().
Warning
This initialization procedure can be extremely slow for wide layers. In case initializing the model is not required (e.g. before loading the state dict of a pre-trained model) or another initialization method is preferred (e.g. deltaorthonormal_init()), the parameter initialize can be set to False to avoid unnecessary overhead.
The parameters basisexpansion, sigma, frequencies_cutoff, rings and maximum_offset are optional parameters used to control how the basis for the filters is built, how it is sampled on the filter grid and how it is expanded to build the filter. We suggest keeping their default values.
- Parameters
in_type (FieldType) – the type of the input field, specifying its transformation law
out_type (FieldType) – the type of the output field, specifying its transformation law
kernel_size (int) – the size of the (square) filter
padding (int, optional) – implicit zero paddings on both sides of the input. Default: 0
padding_mode (str, optional) – zeros, reflect, replicate or circular. Default: zeros
stride (int, optional) – the stride of the kernel. Default: 1
dilation (int, optional) – the spacing between kernel elements. Default: 1
groups (int, optional) – number of blocked connections from input channels to output channels. It allows depthwise convolution. When used, the input and output types need to be divisible into groups groups, all equal to each other. Default: 1.
bias (bool, optional) – whether to add a bias to the output (only to fields which contain a trivial irrep) or not. Default: True
basisexpansion (str, optional) – the basis expansion algorithm to use
sigma (list or float, optional) – width of each ring where the bases are sampled. If only one scalar is passed, it is used for all rings.
frequencies_cutoff (callable or float, optional) – function mapping the radii of the basis elements to the maximum frequency accepted. If a float value is passed, the maximum frequency is equal to the radius times this factor. By default (None), a more complex policy is used.
rings (list, optional) – radii of the rings where to sample the bases
maximum_offset (int, optional) – number of additional (aliased) frequencies in the intertwiners for finite groups. By default (None), all additional frequencies allowed by the frequencies cut-off are used.
recompute (bool, optional) – if True, recomputes a new basis for the equivariant kernels. By default (False), it caches the basis built or reuses a cached one, if found.
basis_filter (callable, optional) – function which takes as input a descriptor of a basis element (as a dictionary) and returns a boolean value: whether to preserve (True) or discard (False) the basis element. By default (None), no filtering is applied.
initialize (bool, optional) – initialize the weights of the model. Default: True
- Variables
~.weights (torch.Tensor) – the learnable parameters which are used to expand the kernel
~.filter (torch.Tensor) – the convolutional kernel obtained by expanding the parameters in weights
~.bias (torch.Tensor) – the learnable parameters which are used to expand the bias, if bias=True
~.expanded_bias (torch.Tensor) – the equivariant bias which is summed to the output, obtained by expanding the parameters in bias
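A minimal usage sketch:

s = e2cnn.gspaces.Rot2dOnR2(8)
in_type = e2cnn.nn.FieldType(s, [s.trivial_repr]*3)    # e.g. an RGB image
out_type = e2cnn.nn.FieldType(s, [s.regular_repr]*8)

conv = e2cnn.nn.R2Conv(in_type, out_type, kernel_size=5, padding=2)

x = e2cnn.nn.GeometricTensor(torch.randn(1, in_type.size, 32, 32), in_type)
y = conv(x)

# the output transforms according to out_type (8 regular fields = 64 channels)
y.shape
>> torch.Size([1, 64, 32, 32])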
- property basisexpansion: e2cnn.nn.modules.r2_conv.basisexpansion.BasisExpansion
Submodule which takes care of building the filter.
It uses the learned weights to expand the kernel in the G-steerable basis and returns the filter in the usual form used by conventional convolutional modules, with shape \((c_\text{out}, c_\text{in}, s^2)\), where \(s\) is the kernel_size.
- expand_parameters()[source]
Expand the filter in terms of the e2cnn.nn.R2Conv.weights and the expanded bias in terms of e2cnn.nn.R2Conv.bias.
- Returns
the expanded filter and bias
- forward(input)[source]
Convolve the input with the expanded filter and bias.
- Parameters
input (GeometricTensor) – input feature field transforming according to in_type
- Returns
output feature field transforming according to out_type
- train(mode=True)[source]
If mode=True, the method sets the module in training mode and discards the filter and expanded_bias attributes.
If mode=False, it sets the module in evaluation mode. Moreover, the method builds the filter and the bias using the current values of the trainable parameters and stores them in filter and expanded_bias, such that they are not recomputed at each forward pass.
Warning
This behaviour can cause problems when storing the state_dict() of a model while in one mode and later loading it into a model in a different mode, as the attributes of this class change. To avoid this issue, we recommend converting the model to eval mode before storing or loading the state dictionary.
- Parameters
mode (bool, optional) – whether to set training mode (True) or evaluation mode (False). Default: True.
- export()[source]
Export this module to a normal PyTorch torch.nn.Conv2d module and set to “eval” mode.
R2ConvTransposed
- class R2ConvTransposed(in_type, out_type, kernel_size, padding=0, output_padding=0, stride=1, dilation=1, groups=1, bias=True, basisexpansion='blocks', sigma=None, frequencies_cutoff=None, rings=None, maximum_offset=None, recompute=False, basis_filter=None, initialize=True)[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
Transposed G-steerable planar convolution layer.
Warning
Transposed convolution can produce artifacts which can harm the overall equivariance of the model. We suggest using R2Upsampling combined with R2Conv to perform upsampling.
See also
For additional information about the parameters and the methods of this class, see e2cnn.nn.R2Conv. The two modules are essentially the same, except for the type of convolution used.
- Parameters
in_type (FieldType) – the type of the input field
out_type (FieldType) – the type of the output field
kernel_size (int) – the size of the filter
padding (int, optional) – implicit zero paddings on both sides of the input. Default: 0
output_padding (int, optional) – additional size added to one side of the output. Default: 0
stride (int, optional) – the stride of the convolving kernel. Default: 1
dilation (int, optional) – the spacing between kernel elements. Default: 1
groups (int, optional) – number of blocked connections from input channels to output channels. Default: 1.
bias (bool, optional) – whether to add a bias to the output (only to fields which contain a trivial irrep) or not. Default: True
initialize (bool, optional) – initialize the weights of the model. Default: True
- export()[source]
Export this module to a normal PyTorch torch.nn.ConvTranspose2d module and set to “eval” mode.
R2Diffop
- class R2Diffop(in_type, out_type, kernel_size=None, accuracy=None, padding=0, stride=1, dilation=1, padding_mode='zeros', groups=1, bias=True, basisexpansion='blocks', maximum_order=None, maximum_power=None, maximum_offset=None, recompute=False, angle_offset=None, basis_filter=None, initialize=True, cache=False, rbffd=False, radial_basis_function='ga', smoothing=None)[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
G-steerable planar partial differential operator mapping between the input and output FieldTypes specified by the parameters in_type and out_type. This operation is equivariant under the action of \(\R^2\rtimes G\), where \(G\) is the e2cnn.nn.FieldType.fibergroup of in_type and out_type.
Specifically, let \(\rho_\text{in}: G \to \GL{\R^{c_\text{in}}}\) and \(\rho_\text{out}: G \to \GL{\R^{c_\text{out}}}\) be the representations specified by the input and output field types. Then R2Diffop guarantees an equivariant mapping
\[D [\mathcal{T}^\text{in}_{g,u} . f] = \mathcal{T}^\text{out}_{g,u} . [Df] \qquad\qquad \forall g \in G, u \in \R^2\]
where the transformations of the input and output fields are given by
\[\begin{split}[\mathcal{T}^\text{in}_{g,u} . f](x) &= \rho_\text{in}(g)f(g^{-1} (x - u)) \\ [\mathcal{T}^\text{out}_{g,u} . f](x) &= \rho_\text{out}(g)f(g^{-1} (x - u)) \\\end{split}\]
The equivariance of G-steerable PDOs is guaranteed by restricting the space of PDOs to an equivariant subspace.
During training, in each forward pass the module expands the basis of G-steerable PDOs with learned weights before calling torch.nn.functional.conv2d(). When eval() is called, the filter is built with the current trained weights and stored for future reuse, such that no overhead of expanding the PDO remains.
Warning
When train() is called, the attributes filter and expanded_bias are discarded to avoid mismatches with the learnable expansion coefficients. See also e2cnn.nn.R2Diffop.train().
This behaviour can cause problems when storing the state_dict() of a model while in one mode and later loading it into a model in a different mode, as the attributes of the class change. To avoid this issue, we recommend converting the model to eval mode before storing or loading the state dictionary.
The learnable expansion coefficients of this module can be initialized with the methods in e2cnn.nn.init. By default, the weights are initialized in the constructor using generalized_he_init().
Warning
This initialization procedure can be extremely slow for wide layers. In case initializing the model is not required (e.g. before loading the state dict of a pre-trained model) or another initialization method is preferred (e.g. deltaorthonormal_init()), the parameter initialize can be set to False to avoid unnecessary overhead.
A reasonable default is to only set kernel_size and leave all other options at their defaults. However, you might get considerable performance improvements by setting smoothing to something other than None (kernel_size / 4 is a sane default, see below for details).
If you want to modify accuracy or maximum_order, you will need to take into account how they are related to kernel_size: it is possible to set any two of kernel_size, accuracy and maximum_order, in which case the third one will be determined automatically. Alternatively, you can set either kernel_size or maximum_order, in which case a sane default will be used for accuracy. The relation between the three is approximately \(\text{kernel size} \approx \text{accuracy} + \text{order}\), though this formula is off by one in some cases. A larger maximum order will lead to more basis filters and thus more parameters. A larger accuracy (i.e. larger kernel size at constant order) might lead to lower equivariance errors, though whether this actually happens may depend on your exact setup.
The parameters basisexpansion, maximum_power, and maximum_offset are optional parameters used to control how the basis for the PDOs is built, how it is sampled on the filter grid and how it is expanded to build the filter. We suggest keeping their default values.
Warning
The discretization of the differential operators relies on two external packages: sympy and rbf. If they are not available, an error is raised.
- Parameters
in_type (FieldType) – the type of the input field, specifying its transformation law
out_type (FieldType) – the type of the output field, specifying its transformation law
kernel_size (int, optional) – the size of the (square) filter. This can be chosen automatically, see above for details.
accuracy (int, optional) – the desired asymptotic accuracy for the PDO discretization, affects the kernel_size. See above for details.
padding (int, optional) – implicit zero paddings on both sides of the input. Default: 0
padding_mode (str, optional) – zeros, reflect, replicate or circular. Default: zeros
stride (int, optional) – the stride of the kernel. Default: 1
dilation (int, optional) – the spacing between kernel elements. Default: 1
groups (int, optional) – number of blocked connections from input channels to output channels. It allows depthwise convolution. When used, the input and output types need to be divisible into groups groups, all equal to each other. Default: 1.
bias (bool, optional) – whether to add a bias to the output (only to fields which contain a trivial irrep) or not. Default: True
basisexpansion (str, optional) – the basis expansion algorithm to use
maximum_order (int, optional) – the largest derivative order to allow as part of the basis. Larger maximum orders require larger kernel sizes, see above for details.
maximum_power (int, optional) – the maximum power of the Laplacian that will be used for constructing the basis. If this is not None, it places a restriction on the basis elements, in addition to the restriction given by maximum_order. We suggest leaving this setting at its default unless you have a good reason to change it.
maximum_offset (int, optional) – number of additional (aliased) frequencies in the intertwiners for finite groups. By default (None), all additional frequencies allowed by the frequencies cut-off are used.
recompute (bool, optional) – if True, recomputes a new basis for the equivariant PDOs. By default (False), it caches the basis built or reuses a cached one, if found.
basis_filter (callable, optional) – function which takes as input a descriptor of a basis element (as a dictionary) and returns a boolean value: whether to preserve (True) or discard (False) the basis element. By default (None), no filtering is applied.
initialize (bool, optional) – initialize the weights of the model. Default: True
cache (bool or str, optional) – Discretizing the PDOs can take a bit longer than for kernels, so we provide the option to cache PDOs on disk. Our suggestion is to keep the cache off (default) and only activate it if discretizing the PDOs is in fact a bottleneck for your setup (it often is not). Setting cache to True will load an existing cache before instantiating the layer and will write to the cache afterwards. You can also set cache to load or store to only do one of these. All R2Diffop layers share the PDO cache in memory. If you have several R2Diffop layers inside your model, we therefore recommend leaving cache set to False and instead calling e2cnn.diffops.load_cache() before instantiating the model and e2cnn.diffops.store_cache() afterwards to save the PDOs for the next run of the program. This will avoid unnecessary reads/writes from/to disk.
rbffd (bool, optional) – if set to True, use RBF-FD discretization instead of finite differences (the default). We suggest leaving this set to False unless you have a specific reason for wanting to use RBF-FD.
radial_basis_function (str, optional) – which RBF to use (only relevant for RBF-FD). Can be any of the abbreviations in this list. The default is to use Gaussian RBFs because this always avoids singularity issues. But other RBFs, such as polyharmonic splines, may work better if they are applicable.
smoothing (float, optional) – if not None, discretization will be performed with derivatives of Gaussians as stencils. This is similar to smoothing with a Gaussian before applying the PDO, though there are slight technical differences. smoothing is the standard deviation (in pixels) of the Gaussian, meaning that larger values correspond to stronger smoothing. A reasonable value would be about kernel_size / 4, but you might want to experiment a bit with this parameter.
- Variables
~.weights (torch.Tensor) – the learnable parameters which are used to expand the PDO
~.filter (torch.Tensor) – the convolutional stencil obtained by expanding the parameters in weights
~.bias (torch.Tensor) – the learnable parameters which are used to expand the bias, if bias=True
~.expanded_bias (torch.Tensor) – the equivariant bias which is summed to the output, obtained by expanding the parameters in bias
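A minimal usage sketch, analogous to R2Conv (note that discretizing the PDOs requires the external packages mentioned in the warning above):

s = e2cnn.gspaces.Rot2dOnR2(8)
in_type = e2cnn.nn.FieldType(s, [s.trivial_repr])
out_type = e2cnn.nn.FieldType(s, [s.regular_repr]*4)

# only kernel_size is set; accuracy and maximum_order are derived automatically
diffop = e2cnn.nn.R2Diffop(in_type, out_type, kernel_size=5, padding=2)

x = e2cnn.nn.GeometricTensor(torch.randn(1, in_type.size, 17, 17), in_type)
y = diffop(x)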
- property basisexpansion: e2cnn.nn.modules.r2_conv.basisexpansion.BasisExpansion
Submodule which takes care of building the filter.
It uses the learned weights to expand the PDO in the G-steerable basis and returns the filter in the usual form used by conventional convolutional modules, with shape \((c_\text{out}, c_\text{in}, s^2)\), where \(s\) is the kernel_size.
- expand_parameters()[source]
Expand the filter in terms of the e2cnn.nn.R2Diffop.weights and the expanded bias in terms of e2cnn.nn.R2Diffop.bias.
- Returns
the expanded filter and bias
- forward(input)[source]
Convolve the input with the expanded filter and bias.
- Parameters
input (GeometricTensor) – input feature field transforming according to in_type
- Returns
output feature field transforming according to out_type
- train(mode=True)[source]
If mode=True, the method sets the module in training mode and discards the filter and expanded_bias attributes.
If mode=False, it sets the module in evaluation mode. Moreover, the method builds the filter and the bias using the current values of the trainable parameters and stores them in filter and expanded_bias, such that they are not recomputed at each forward pass.
Warning
This behaviour can cause problems when storing the state_dict() of a model while in one mode and later loading it into a model in a different mode, as the attributes of this class change. To avoid this issue, we recommend converting the model to eval mode before storing or loading the state dictionary.
- Parameters
mode (bool, optional) – whether to set training mode (True) or evaluation mode (False). Default: True.
- export()[source]
Export this module to a normal PyTorch torch.nn.Conv2d module and set to “eval” mode.
Basis Expansion Modules
Basis Expansion
- class BasisExpansion[source]
Bases:
abc.ABC, torch.nn.modules.module.Module
Abstract class defining the interface for the different basis expansion algorithms.
- abstract forward(weights)[source]
Forward step of the module which expands the basis and returns the filter built
- Parameters
weights (torch.Tensor) – the learnable weights used to linearly combine the basis elements
- Returns
the filter built
- abstract get_basis_names()[source]
Method that returns the list of identification names of the basis elements
- Returns
list of names
- abstract get_element_info(name)[source]
Method that returns the information associated to a basis element
BlocksBasisExpansion
- class BlocksBasisExpansion(in_type, out_type, basis_generator, points, basis_filter=None, recompute=False, **kwargs)[source]
Bases:
e2cnn.nn.modules.r2_conv.basisexpansion.BasisExpansion
With this algorithm, the expansion is performed on the intertwiners of the pairs of input and output field representations.
- Parameters
in_type (FieldType) – the input field type
out_type (FieldType) – the output field type
basis_generator (callable) – method that generates the analytical filter basis
points (ndarray) – points where the analytical basis should be sampled
basis_filter (callable, optional) – filter for the basis elements. Should take a dictionary containing an element’s attributes and return whether to keep it or not.
recompute (bool, optional) – whether to recompute new bases or reuse, if possible, already built tensors.
**kwargs – keyword arguments to be passed to basis_generator
- Variables
~BlocksBasisExpansion.S (int) – number of points where the filters are sampled
- forward(weights)[source]
Forward step of the Module which expands the basis and returns the filter built
- Parameters
weights (torch.Tensor) – the learnable weights used to linearly combine the basis filters
- Returns
the filter built
SingleBlockBasisExpansion
- class SingleBlockBasisExpansion(basis, points, basis_filter=None)[source]
Bases:
e2cnn.nn.modules.r2_conv.basisexpansion.BasisExpansion
Basis expansion method for a single contiguous block, i.e. for kernels/PDOs whose input type and output type contain only fields of one type.
This class should be instantiated through the factory method block_basisexpansion() to enable caching.
- Parameters
basis (Basis) – analytical basis to sample
points (ndarray) – points where the analytical basis should be sampled
basis_filter (callable, optional) – filter for the basis elements. Should take a dictionary containing an element’s attributes and return whether to keep it or not.
block_basisexpansion
- block_basisexpansion(basis, points, basis_filter=None, recompute=False)[source]
Return an instance of SingleBlockBasisExpansion.
This function supports caching through the argument recompute.
- Parameters
basis (Basis) – basis defining the space of kernels
points (ndarray) – points where the analytical basis should be sampled
basis_filter (callable, optional) – filter for the basis elements. Should take a dictionary containing an element’s attributes and return whether to keep it or not.
recompute (bool, optional) – whether to recompute new bases (True) or reuse, if possible, already built tensors (False, default).
Non Linearities
ConcatenatedNonLinearity
- class ConcatenatedNonLinearity(in_type, function='c_relu')[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
Concatenated non-linearities. For each input channel, the module applies the specified activation function both to its value and its opposite (the value multiplied by -1). The number of channels is, therefore, doubled.
Notice that not all the representations support this kind of non-linearity. Indeed, only representations with the same pattern of permutation matrices and containing only values in \(\{0, 1, -1\}\) support it.
ELU
- class ELU(in_type, alpha=1.0, inplace=False)[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
Module that implements a pointwise ELU to every channel independently. The input representation is preserved by this operation and, therefore, it equals the output representation.
Only representations supporting pointwise non-linearities are accepted as input field type.
- Parameters
- forward(input)[source]
Applies ELU function on the input fields
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map after elu has been applied
- export()[source]
Export this module to a normal PyTorch torch.nn.ELU module and set to “eval” mode.
GatedNonLinearity1
- class GatedNonLinearity1(in_type, gates=None, drop_gates=True, **kwargs)[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
Gated non-linearities. This module applies a bias and a sigmoid function to the gate fields and then multiplies each gated field by one of the gates.
The input representation of the gated fields is preserved by this operation while the gate fields are discarded.
The gates and the gated fields are provided in one unique input tensor and, therefore, in_repr should be the representation of the fiber containing both gates and gated fields. Moreover, the parameter gates needs to be set to a list as long as the total number of fields, containing in position i the string "gate" if the i-th field is a gate or the string "gated" if the i-th field is a gated field. No other strings are allowed. By default (gates = None), the first half of the fields is assumed to contain the gates (and, so, these fields have to be trivial fields) while the second half is assumed to contain the gated fields.
In any case, the number of gates and the number of gated fields have to match (therefore, the number of fields has to be an even number).
- Parameters
in_type (FieldType) – the input field type
gates (list, optional) – list of strings specifying which field in input is a gate and which is a gated field
drop_gates (bool, optional) – if True (default), drop the trivial fields after using them to compute the gates. If False, the gates are stacked with the gated fields in the output
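A minimal sketch of how the input type can be assembled (here the gate fields, which must be trivial, are listed before the gated fields):

s = e2cnn.gspaces.Rot2dOnR2(8)

gated = e2cnn.nn.FieldType(s, [s.regular_repr]*2)
gates = e2cnn.nn.FieldType(s, [s.trivial_repr]*2)

# gates and gated fields are passed together in a single tensor
in_type = gates + gated
nonlinearity = e2cnn.nn.GatedNonLinearity1(in_type, gates=['gate']*2 + ['gated']*2)

x = e2cnn.nn.GeometricTensor(torch.randn(1, in_type.size, 7, 7), in_type)
y = nonlinearity(x)

# with drop_gates=True (default) only the gated fields are kept: 2 regular fields = 16 channels
y.shape
>> torch.Size([1, 16, 7, 7])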
- forward(input)[source]
Apply the gated non-linearity to the input feature map.
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
GatedNonLinearity2
- class GatedNonLinearity2(in_type, **kwargs)[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
Gated non-linearities. This module applies a bias and a sigmoid function to the gate fields and then multiplies each gated field by one of the gates.
The input representation of the gated fields is preserved by this operation while the gate fields are discarded.
The gates and the gated fields are provided in two different tensors:
in_repr is a tuple containing two representations: the representation of the tensor containing only the gates (which have to be trivial fields) and the representation of the tensor containing only the gated fields. Therefore, two tensors need to be provided as input to the forward method: the first contains the gates and the second the gated fields. Finally, notice that the number of gates and the number of gated fields have to match, i.e. the two representations need to have the same number of fields.
Todo
This module has 2 input tensors and, so, two input field types. EquivariantModule only supports one input though. Fix this.
- Parameters
in_type (Tuple) – a pair containing, in order, the field type of the gates and the field type of the gated fields
- forward(gates, input)[source]
Apply the gated non-linearity to the input feature map.
- Parameters
gates (GeometricTensor) – feature map corresponding to the gates
input (GeometricTensor) – input feature map corresponding to the gated fields
- Returns
the resulting feature map
InducedGatedNonLinearity1
- class InducedGatedNonLinearity1(in_type, gates=None, drop_gates=True, **kwargs)[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
Induced Gated non-linearities.
Todo
complete documentation!
Note
Make sure all induced gates and gated fields have the same subgroup.
- Parameters
in_type (FieldType) – the input field type
gates (list, optional) – list of strings specifying which field in input is a gate and which is a gated field
drop_gates (bool, optional) – if True (default), drop the trivial fields after using them to compute the gates. If False, the gates are stacked with the gated fields in the output
- forward(input)[source]
Apply the gated non-linearity to the input feature map.
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
InducedNormNonLinearity
- class InducedNormNonLinearity(in_type, function='n_relu', bias=True)[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
Induced Norm non-linearities. This module requires the input fields to be associated to an induced representation from a representation which supports ‘norm’ non-linearities. This module applies a bias and an activation function over the norm of each sub-field of an induced field. The bias is shared among all sub-fields of the same induced field.
The input representation of the fields is preserved by this operation.
- Parameters
- forward(input)[source]
Apply norm non-linearities to the input feature map
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
NormNonLinearity
- class NormNonLinearity(in_type, function='n_relu', bias=True)[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
Norm non-linearities. This module applies a bias and an activation function over the norm of each field.
The input representation of the fields is preserved by this operation.
Note
If ‘squash’ non-linearity (function) is chosen, no bias is allowed
- Parameters
- forward(input)[source]
Apply norm non-linearities to the input feature map
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
PointwiseNonLinearity
- class PointwiseNonLinearity(in_type, function='p_relu')[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
Pointwise non-linearities. The same scalar function is applied to every channel independently. The input representation is preserved by this operation and, therefore, it equals the output representation.
Only representations supporting pointwise non-linearities are accepted as input field type.
- Parameters
- forward(input)[source]
Applies the pointwise activation function on the input fields
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map after the non-linearities have been applied
ReLU
- class ReLU(in_type, inplace=False)[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
Module that implements a pointwise ReLU to every channel independently. The input representation is preserved by this operation and, therefore, it equals the output representation.
Only representations supporting pointwise non-linearities are accepted as input field type.
- Parameters
- forward(input)[source]
Applies ReLU function on the input fields
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map after relu has been applied
- export()[source]
Export this module to a normal PyTorch
torch.nn.ReLU
module and set to “eval” mode.
VectorFieldNonLinearity
- class VectorFieldNonLinearity(in_type, **kwargs)[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
VectorField non-linearities. This non-linearity only supports the regular representation of the cyclic group \(C_N\), i.e. the group of \(N\) discrete rotations. For each input field, the rotation associated with the highest activation is selected; the corresponding output field is a 2-dimensional vector whose angle with respect to the x-axis equals that rotation and whose length equals that activation.
- Parameters
in_type (FieldType) – the input field type
- forward(input)[source]
Apply the VectorField non-linearity to the input feature map.
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
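A minimal usage sketch, assuming regular fields of the cyclic group \(C_8\); each 8-dimensional regular field is mapped to a 2-dimensional vector field:
import torch
import e2cnn

s = e2cnn.gspaces.Rot2dOnR2(8)
in_type = e2cnn.nn.FieldType(s, [s.regular_repr]*3)

vf = e2cnn.nn.VectorFieldNonLinearity(in_type)

x = e2cnn.nn.GeometricTensor(torch.randn(2, in_type.size, 16, 16), in_type)
y = vf(x)

# 3 regular fields (8 channels each) become 3 two-dimensional vector fields
print(in_type.size, '->', y.type.size)   # 24 -> 6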
Invariant Maps
GroupPooling
- class GroupPooling(in_type, **kwargs)[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
Module that implements group pooling. This module only supports permutation representations such as regular representation, quotient representation or trivial representation (though, in the last case, this module acts as identity). For each input field, an output field is built by taking the maximum activation within that field; as a result, the output field transforms according to a trivial representation.
- Parameters
in_type (FieldType) – the input field type
- forward(input)[source]
Apply Group Pooling to the input feature map.
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
- export()[source]
Export this module to the pure PyTorch module MaxPoolChannels and set to “eval” mode.
Warning
Currently, this method only supports group pooling with feature types containing only representations of the same size.
Note
Because there is no native PyTorch module performing this operation, it is not possible to export this module without any dependency on this library: the resulting module still depends on it through the class MaxPoolChannels. Should PyTorch introduce a similar module in a future release, we will update this method to remove this dependency.
Nevertheless, the MaxPoolChannels module is slightly lighter than GroupPooling, as it does not perform any automatic type checking and does not wrap each tensor in a GeometricTensor. Furthermore, the MaxPoolChannels class is very simple and one can easily reimplement it to remove any dependency on this library after training the model.
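A minimal usage sketch, assuming regular fields of the 8-rotations group; the pooled features are invariant and the module can be exported after training:
import torch
import e2cnn

s = e2cnn.gspaces.Rot2dOnR2(8)
in_type = e2cnn.nn.FieldType(s, [s.regular_repr]*4)

gpool = e2cnn.nn.GroupPooling(in_type)

x = e2cnn.nn.GeometricTensor(torch.randn(2, in_type.size, 16, 16), in_type)
y = gpool(x)

# 4 regular fields (8 channels each) are pooled into 4 invariant scalar fields
print(in_type.size, '->', y.type.size)   # 32 -> 4

# export to a pure PyTorch module (a MaxPoolChannels instance)
gpool.eval()
gpool_pt = gpool.export()
assert torch.allclose(y.tensor, gpool_pt(x.tensor))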
NormPool
- class NormPool(in_type, **kwargs)[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
Module that implements Norm Pooling. For each input field, an output one is built by taking the norm of that field; as a result, the output field transforms according to a trivial representation.
- Parameters
in_type (FieldType) – the input field type
- forward(input)[source]
Apply the Norm Pooling to the input feature map.
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
InducedNormPool
- class InducedNormPool(in_type, **kwargs)[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
Module that implements Induced Norm Pooling. This module requires the input fields to be associated to an induced representation from a representation which supports ‘norm’ non-linearities.
For each input field, an output field is built by taking the maximum norm over all of its sub-fields.
- Parameters
in_type (FieldType) – the input field type
- forward(input)[source]
Apply the Norm Pooling to the input feature map.
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
Pooling
NormMaxPool
- class NormMaxPool(in_type, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False)[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
Max-pooling based on the fields’ norms. In a given window of shape kernel_size, for each group of channels belonging to the same field, the field with the highest norm (i.e. the length of the vector) is preserved.
Except for in_type, the other parameters correspond to those of torch.nn.MaxPool2d.
- Parameters
in_type (FieldType) – the input field type
kernel_size (Union[int, Tuple[int, int]]) – the size of the window to take a max over
stride (Union[int, Tuple[int, int], None]) – the stride of the window. Default value is kernel_size
padding (Union[int, Tuple[int, int]]) – implicit zero padding to be added on both sides
dilation (Union[int, Tuple[int, int]]) – a parameter that controls the stride of elements in the window
ceil_mode (bool) – when True, will use ceil instead of floor to compute the output shape
- forward(input)[source]
Run the norm-based max-pooling on the input tensor
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
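A minimal usage sketch, assuming 2-dimensional vector fields (frequency-1 irreps) of the 8-rotations group:
import torch
import e2cnn

s = e2cnn.gspaces.Rot2dOnR2(8)
in_type = e2cnn.nn.FieldType(s, [s.irrep(1)]*2)

pool = e2cnn.nn.NormMaxPool(in_type, kernel_size=3, stride=2, padding=1)

x = e2cnn.nn.GeometricTensor(torch.randn(2, in_type.size, 16, 16), in_type)
y = pool(x)

# the spatial resolution is halved while the field type is unchanged
print(x.shape, '->', y.shape)   # (2, 4, 16, 16) -> (2, 4, 8, 8)
assert y.type == in_type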
PointwiseMaxPool
- class PointwiseMaxPool(in_type, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False)[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
Channel-wise max-pooling: each channel is treated independently. This module works exactly as torch.nn.MaxPool2d, wrapping it in the EquivariantModule interface.
Notice that not all representations support this kind of pooling. In general, only representations which support pointwise non-linearities do.
- Parameters
in_type (FieldType) – the input field type
kernel_size (Union[int, Tuple[int, int]]) – the size of the window to take a max over
stride (Union[int, Tuple[int, int], None]) – the stride of the window. Default value is kernel_size
padding (Union[int, Tuple[int, int]]) – implicit zero padding to be added on both sides
dilation (Union[int, Tuple[int, int]]) – a parameter that controls the stride of elements in the window
ceil_mode (bool) – when True, will use ceil instead of floor to compute the output shape
- forward(input)[source]
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
- export()[source]
Export this module to a normal PyTorch
torch.nn.MaxPool2d
module and set to “eval” mode.
PointwiseMaxPoolAntialiased
- class PointwiseMaxPoolAntialiased(in_type, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, sigma=0.6)[source]
Bases:
e2cnn.nn.modules.pooling.pointwise_max.PointwiseMaxPool
Anti-aliased version of channel-wise max-pooling (each channel is treated independently).
The max over a neighborhood is performed pointwise without downsampling. Then, convolution with a Gaussian blurring filter is performed before downsampling the feature map.
Based on Making Convolutional Networks Shift-Invariant Again.
Notice that not all representations support this kind of pooling. In general, only representations which support pointwise non-linearities do.
- Parameters
in_type (FieldType) – the input field type
kernel_size (Union[int, Tuple[int, int]]) – the size of the window to take a max over
stride (Union[int, Tuple[int, int], None]) – the stride of the window. Default value is kernel_size
padding (Union[int, Tuple[int, int]]) – implicit zero padding to be added on both sides
dilation (Union[int, Tuple[int, int]]) – a parameter that controls the stride of elements in the window
ceil_mode (bool) – when True, will use ceil instead of floor to compute the output shape
sigma (float) – standard deviation for the Gaussian blur filter
- forward(input)[source]
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
- export()[source]
Export this module to a normal PyTorch
torch.nn.MaxPool2d
module and set to “eval” mode.
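A minimal usage sketch, assuming regular fields; compared to PointwiseMaxPool, the blur before downsampling makes the result more stable under small translations:
import torch
import e2cnn

s = e2cnn.gspaces.Rot2dOnR2(8)
in_type = e2cnn.nn.FieldType(s, [s.regular_repr]*2)

pool = e2cnn.nn.PointwiseMaxPoolAntialiased(in_type, kernel_size=2, stride=2, sigma=0.6)

x = e2cnn.nn.GeometricTensor(torch.randn(2, in_type.size, 32, 32), in_type)
y = pool(x)

# the spatial size is reduced (roughly) by the stride factor
print(x.shape, '->', y.shape)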
PointwiseAvgPool
- class PointwiseAvgPool(in_type, kernel_size, stride=None, padding=0, ceil_mode=False)[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
Channel-wise average-pooling: each channel is treated independently. This module works exactly as torch.nn.AvgPool2d, wrapping it in the EquivariantModule interface.
- Parameters
in_type (FieldType) – the input field type
kernel_size (Union[int, Tuple[int, int]]) – the size of the window to take an average over
stride (Union[int, Tuple[int, int], None]) – the stride of the window. Default value is kernel_size
padding (Union[int, Tuple[int, int]]) – implicit zero padding to be added on both sides
ceil_mode (bool) – when True, will use ceil instead of floor to compute the output shape
- forward(input)[source]
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
PointwiseAvgPoolAntialiased
- class PointwiseAvgPoolAntialiased(in_type, sigma, stride, padding=None)[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
Antialiased channel-wise average-pooling: each channel is treated independently. It performs strided convolution with a Gaussian blur filter.
The size of the filter is computed as 3 standard deviations of the Gaussian curve. By default, padding is added such that input size is preserved if stride is 1.
Based on Making Convolutional Networks Shift-Invariant Again.
- Parameters
- forward(input)[source]
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
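A minimal usage sketch; the constructor takes the standard deviation of the blur filter and the stride directly:
import torch
import e2cnn

s = e2cnn.gspaces.Rot2dOnR2(8)
in_type = e2cnn.nn.FieldType(s, [s.regular_repr]*2)

pool = e2cnn.nn.PointwiseAvgPoolAntialiased(in_type, sigma=0.66, stride=2)

x = e2cnn.nn.GeometricTensor(torch.randn(2, in_type.size, 32, 32), in_type)
y = pool(x)

# with the default padding, the spatial size is (roughly) halved by the stride
print(x.shape, '->', y.shape)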
PointwiseAdaptiveAvgPool
- class PointwiseAdaptiveAvgPool(in_type, output_size)[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
Adaptive channel-wise average-pooling: each channel is treated independently. This module works exactly as torch.nn.AdaptiveAvgPool2d, wrapping it in the EquivariantModule interface.
Notice that not all representations support this kind of pooling. In general, only representations which support pointwise non-linearities do.
- Parameters
- forward(input)[source]
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
- export()[source]
Export this module to a normal PyTorch
torch.nn.AdaptiveAvgPool2d
module and set to “eval” mode.
PointwiseAdaptiveMaxPool
- class PointwiseAdaptiveMaxPool(in_type, output_size)[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
Module that implements adaptive channel-wise max-pooling: each channel is treated independently. This module works exactly as torch.nn.AdaptiveMaxPool2d, wrapping it in the EquivariantModule interface.
Notice that not all representations support this kind of pooling. In general, only representations which support pointwise non-linearities do.
- Parameters
- forward(input)[source]
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
- export()[source]
Export this module to a normal PyTorch torch.nn.AdaptiveMaxPool2d module and set to “eval” mode.
Normalization
InnerBatchNorm
- class InnerBatchNorm(in_type, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
Batch normalization for representations with permutation matrices.
Statistics are computed both over the batch and the spatial dimensions and over the channels within the same field (which are permuted by the representation).
Only representations supporting pointwise non-linearities are accepted as input field type.
- Parameters
in_type (FieldType) – the input field type
eps (float, optional) – a value added to the denominator for numerical stability. Default: 1e-5
momentum (float, optional) – the value used for the running_mean and running_var computation. Can be set to None for cumulative moving average (i.e. simple average). Default: 0.1
affine (bool, optional) – if True, this module has learnable affine parameters. Default: True
track_running_stats (bool, optional) – when set to True, the module tracks the running mean and variance; when set to False, it does not track such statistics but uses batch statistics in both training and eval modes. Default: True
- forward(input)[source]
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
- export()[source]
Export this module to a normal PyTorch
torch.nn.BatchNorm2d
module and set to “eval” mode.
NormBatchNorm
- class NormBatchNorm(in_type, eps=1e-05, momentum=0.1, affine=True)[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
Batch normalization for isometric (i.e. norm-preserving) non-trivial representations.
The module assumes the mean of the vectors is always zero so no running mean is computed and no bias is added. This is guaranteed as long as the representations do not include a trivial representation.
Indeed, if \(\rho\) does not include a trivial representation, it holds:
\[\forall \bold{v} \in \mathbb{R}^n, \ \ \frac{1}{|G|} \sum_{g \in G} \rho(g) \bold{v} = \bold{0}\]
Hence, only the standard deviation is normalized.
Only representations which do not contain the trivial representation are allowed. You can check if a representation contains the trivial representation using contains_trivial(). To check if a trivial irrep is present in a representation in a FieldType, you can use:
for r in field_type:
    if r.contains_trivial():
        print(f"field type contains a trivial irrep")
- Parameters
in_type (FieldType) – the input field type
eps (float, optional) – a value added to the denominator for numerical stability. Default: 1e-5
momentum (float, optional) – the value used for the running_mean and running_var computation. Can be set to None for cumulative moving average (i.e. simple average). Default: 0.1
affine (bool, optional) – if True, this module has learnable scale parameters. Default: True
- forward(input)[source]
Apply batch normalization to the input feature map
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
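A minimal usage sketch, assuming a field type built only from non-trivial irreps (so that no representation contains the trivial irrep):
import torch
import e2cnn

s = e2cnn.gspaces.Rot2dOnR2(8)
feat_type = e2cnn.nn.FieldType(s, [s.irrep(1)]*4 + [s.irrep(2)]*2)

bn = e2cnn.nn.NormBatchNorm(feat_type)

x = e2cnn.nn.GeometricTensor(torch.randn(8, feat_type.size, 16, 16), feat_type)
y = bn(x)

# the field type is preserved
assert y.type == feat_type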
InducedNormBatchNorm
- class InducedNormBatchNorm(in_type, eps=1e-05, momentum=0.1, affine=True)[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
Batch normalization for induced isometric representations. This module requires the input fields to be associated to an induced representation from an isometric (i.e. norm-preserving) non-trivial representation which supports ‘norm’ non-linearities.
The module assumes the mean of the vectors is always zero so no running mean is computed and no bias is added. This is guaranteed as long as the representations do not include a trivial representation.
Indeed, if \(\rho\) does not include a trivial representation, it holds:
\[\forall \bold{v} \in \mathbb{R}^n, \ \ \frac{1}{|G|} \sum_{g \in G} \rho(g) \bold{v} = \bold{0}\]
Hence, only the standard deviation is normalized. The same standard deviation, however, is shared by all the sub-fields of the same induced field.
The input representation of the fields is preserved by this operation.
Only representations which do not contain the trivial representation are allowed. You can check if a representation contains the trivial representation using contains_trivial(). To check if a trivial irrep is present in a representation in a FieldType, you can use:
for r in field_type:
    if r.contains_trivial():
        print(f"field type contains a trivial irrep")
- Parameters
in_type (FieldType) – the input field type
eps (float, optional) – a value added to the denominator for numerical stability. Default: 1e-5
momentum (float, optional) – the value used for the running_mean and running_var computation. Can be set to None for cumulative moving average (i.e. simple average). Default: 0.1
affine (bool, optional) – if True, this module has learnable scale parameters. Default: True
- forward(input)[source]
Apply batch normalization to the input feature map
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
GNormBatchNorm
- class GNormBatchNorm(in_type, eps=1e-05, momentum=0.1, affine=True)[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
Batch normalization for generic representations.
Todo
Add more details about how stats are computed and how affine transformation is done.
- Parameters
in_type (FieldType) – the input field type
eps (float, optional) – a value added to the denominator for numerical stability. Default: 1e-5
momentum (float, optional) – the value used for the running_mean and running_var computation. Can be set to None for cumulative moving average (i.e. simple average). Default: 0.1
affine (bool, optional) – if True, this module has learnable affine parameters. Default: True
- forward(input)[source]
Apply batch normalization to the input feature map
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
Dropout
FieldDropout
- class FieldDropout(in_type, p=0.5, inplace=False)[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
Applies dropout to individual fields independently.
Notice that, with respect to PointwiseDropout, this module acts on a whole field instead of single channels.
- Parameters
- forward(input)[source]
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
PointwiseDropout
- class PointwiseDropout(in_type, p=0.5, inplace=False)[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
Applies dropout to individual channels independently.
This class is just a wrapper for torch.nn.functional.dropout() in an EquivariantModule.
Only representations supporting pointwise non-linearities are accepted as input field type.
- Parameters
- forward(input)[source]
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the resulting feature map
- export()[source]
Export this module to a normal PyTorch
torch.nn.Dropout
module and set to “eval” mode.
Other Modules
Sequential
- class SequentialModule(*args)[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
A sequential container similar to torch.nn.Sequential.
The constructor accepts either a list or an ordered dict of EquivariantModule instances.
Example:
# Example of SequentialModule
s = e2cnn.gspaces.Rot2dOnR2(8)
c_in = e2cnn.nn.FieldType(s, [s.trivial_repr]*3)
c_out = e2cnn.nn.FieldType(s, [s.regular_repr]*16)
model = e2cnn.nn.SequentialModule(
    e2cnn.nn.R2Conv(c_in, c_out, 5),
    e2cnn.nn.InnerBatchNorm(c_out),
    e2cnn.nn.ReLU(c_out),
)

# Example with OrderedDict
from collections import OrderedDict
s = e2cnn.gspaces.Rot2dOnR2(8)
c_in = e2cnn.nn.FieldType(s, [s.trivial_repr]*3)
c_out = e2cnn.nn.FieldType(s, [s.regular_repr]*16)
model = e2cnn.nn.SequentialModule(OrderedDict([
    ('conv', e2cnn.nn.R2Conv(c_in, c_out, 5)),
    ('bn', e2cnn.nn.InnerBatchNorm(c_out)),
    ('relu', e2cnn.nn.ReLU(c_out)),
]))
- forward(input)[source]
- Parameters
input (GeometricTensor) – the input GeometricTensor
- Returns
the output tensor
- add_module(name, module)[source]
Append
module
to the sequence of modules applied in the forward pass.
- export()[source]
Export this module to a normal PyTorch
torch.nn.Sequential
module and set to “eval” mode.
ModuleList
- class ModuleList(modules=None)[source]
Bases:
torch.nn.modules.container.ModuleList
Module similar to torch.nn.ModuleList containing a list of EquivariantModules.
This class works like torch.nn.ModuleList, except that it only accepts instances of EquivariantModule.
Additionally, this class provides an export() method. This method calls the export() method of each module contained in this ModuleList and returns a torch.nn.ModuleList containing the exported modules.
- Parameters
modules (iterable, optional) – an iterable of equivariant modules to add
- append(module)[source]
Appends an EquivariantModule to the end of the list.
- Parameters
module (EquivariantModule) – equivariant module to append
- extend(modules)[source]
Appends multiple EquivariantModule instances from a Python iterable to the end of the list.
- Parameters
modules (iterable) – iterable of equivariant modules to append
- export()[source]
Export this module to a normal PyTorch
torch.nn.ModuleList
module and set to “eval” mode.
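A minimal usage sketch; the equivariant layers collected in the list can be exported together after training:
import torch
import e2cnn

s = e2cnn.gspaces.Rot2dOnR2(8)
c = e2cnn.nn.FieldType(s, [s.regular_repr]*2)

layers = e2cnn.nn.ModuleList([
    e2cnn.nn.InnerBatchNorm(c),
    e2cnn.nn.ReLU(c),
])
layers.append(e2cnn.nn.GroupPooling(c))

# ... train ...

layers.eval()
layers_pt = layers.export()   # a torch.nn.ModuleList of exported PyTorch modules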
Restriction
- class RestrictionModule(in_type, id)[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
Restricts the type of the input to the subgroup identified by id.
It computes the output type in the constructor and, during the forward pass, wraps the underlying tensor (torch.Tensor) of the input with the output type.
This module only acts as a wrapper for e2cnn.nn.FieldType.restrict() (or e2cnn.nn.GeometricTensor.restrict()). The accepted values of id depend on the underlying gspace in the input type in_type; check the documentation of the method e2cnn.gspaces.GSpace.restrict() of the gspace used for further information.
See also
e2cnn.nn.FieldType.restrict(), e2cnn.nn.GeometricTensor.restrict(), e2cnn.gspaces.GSpace.restrict()
- Parameters
in_type (FieldType) – the input field type
id – a valid id for a subgroup of the space associated with the input type
- export()[source]
Export this module to a normal PyTorch torch.nn.Identity module and set to “eval” mode.
Warning
Only works with PyTorch >= 1.2
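A minimal usage sketch, assuming the integer id convention of Rot2dOnR2 (the id selects the number of rotations in the subgroup):
import torch
import e2cnn

s = e2cnn.gspaces.Rot2dOnR2(8)
in_type = e2cnn.nn.FieldType(s, [s.regular_repr]*2)

# restrict features equivariant to 8 rotations to the subgroup of 4 rotations
restrict = e2cnn.nn.RestrictionModule(in_type, 4)

x = e2cnn.nn.GeometricTensor(torch.randn(2, in_type.size, 16, 16), in_type)
y = restrict(x)

# the underlying tensor is unchanged; only its type is re-interpreted over the subgroup
assert torch.allclose(x.tensor, y.tensor)
print(y.type)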
Disentangle
- class DisentangleModule(in_type)[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
Disentangles the representations in the field type of the input.
This module only acts as a wrapper for e2cnn.group.disentangle(). In the constructor, it disentangles each representation in the input type to build the output type and pre-computes the change of basis matrices needed to transform each input field.
During the forward pass, each field in the input tensor is transformed with the change of basis corresponding to its representation.
- Parameters
in_type (FieldType) – the input field type
Upsampling
- class R2Upsampling(in_type, scale_factor=None, size=None, mode='bilinear', align_corners=False)[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
Wrapper for torch.nn.functional.interpolate(). Check its documentation for further details.
Only the "bilinear" and "nearest" methods are supported. However, "nearest" is not equivariant; using this method may result in broken equivariance. For this reason, we suggest using "bilinear" (the default value).
Warning
The module supports a size parameter as an alternative to scale_factor. However, the use of scale_factor should be preferred, since it guarantees both axes are scaled uniformly, which preserves rotation equivariance. A misuse of the parameter size can break the overall equivariance, since it might scale the two axes by two different factors.
- Parameters
in_type (FieldType) – the input field type
scale_factor (optional, int) – multiplier for spatial size
mode (str) – algorithm used for upsampling: nearest | bilinear. Default: bilinear
align_corners (bool) – if True, the corner pixels of the input and output tensors are aligned, thus preserving the values at those pixels. This only has effect when mode is bilinear. Default: False
- forward(input)[source]
- Parameters
input (GeometricTensor) – the input feature map
- Returns
the upsampled feature map
- export()[source]
Export this module to a normal PyTorch
torch.nn.Upsample
module and set to “eval” mode.
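A minimal usage sketch using the (equivariant) bilinear mode with a uniform scale factor:
import torch
import e2cnn

s = e2cnn.gspaces.Rot2dOnR2(8)
in_type = e2cnn.nn.FieldType(s, [s.regular_repr]*2)

up = e2cnn.nn.R2Upsampling(in_type, scale_factor=2, mode='bilinear')

x = e2cnn.nn.GeometricTensor(torch.randn(2, in_type.size, 16, 16), in_type)
y = up(x)

print(x.shape, '->', y.shape)   # (2, 16, 16, 16) -> (2, 16, 32, 32)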
Multiple
- class MultipleModule(in_type, labels, modules, reshuffle=0)[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
Split the input tensor into multiple branches identified by the input labels and apply to each of them the corresponding module in modules.
A label is associated with each field in the input type, while modules assigns a module to apply to each label (or set of labels). modules should be a list of pairs, each containing an EquivariantModule and a label (or a list of labels).
During forward, fields are grouped by their labels and the input tensor is split accordingly. Then, each sub-tensor is passed to the corresponding module in modules.
If reshuffle is set to a positive integer, a copy of the input tensor is first built, sorting the fields according to the value set:
1: fields are sorted by their labels
2: fields are sorted by their labels and, then, by their size
3: fields are sorted by their labels, by their size and, then, by their type
In this way, fields that need to be retrieved together are contiguous and it is possible to exploit slicing to split the tensor. By default, reshuffle = 0, which means that no sorting is performed and, so, if the input fields are not contiguous, this layer will use indexing to retrieve the sub-tensors.
This module wraps a BranchingModule followed by a MergeModule.
- Parameters
- forward(input)[source]
Split the input tensor according to the labels, apply each module to the corresponding input sub-tensors and stack the results.
- Parameters
input (GeometricTensor) – the input GeometricTensor
- Returns
the concatenation of the output of each module
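A minimal sketch of the assumed usage (one label per field, and one (module, label) pair per branch):
import torch
import e2cnn

s = e2cnn.gspaces.Rot2dOnR2(8)
in_type = e2cnn.nn.FieldType(s, [s.regular_repr]*2 + [s.irrep(1)]*2)

# one label per field in the input type
labels = ['regular']*2 + ['vector']*2

# one module per label; each module's input type must match the fields carrying that label
modules = [
    (e2cnn.nn.ReLU(e2cnn.nn.FieldType(s, [s.regular_repr]*2)), 'regular'),
    (e2cnn.nn.NormNonLinearity(e2cnn.nn.FieldType(s, [s.irrep(1)]*2)), 'vector'),
]

multiple = e2cnn.nn.MultipleModule(in_type, labels, modules)

x = e2cnn.nn.GeometricTensor(torch.randn(2, in_type.size, 16, 16), in_type)
y = multiple(x)
print(y.type)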
Reshuffle
- class ReshuffleModule(in_type, permutation)[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
Permutes the fields of the input tensor according to the input permutation.
The parameter permutation should be a list containing a permutation of the integers {0, 1, ..., n-1}, where n is the number of fields of in_type (see e2cnn.nn.FieldType.__len__()).
Mask
- class MaskModule(in_type, S, margin=0.0)[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
Performs an element-wise multiplication of the input with a mask of shape S x S.
The mask has value \(1\) in all pixels with distance smaller than (S-1)/2 * (1 - margin/100) from the center of the mask and \(0\) elsewhere. Values change smoothly between the two regions.
This operation is useful to remove from an input image or feature map the part of the signal defined on the pixels which lie outside the circle inscribed in the grid. Because a rotation would move these pixels outside the grid, this information would anyway be discarded when rotating an image. However, allowing a model to use this information might break the guaranteed equivariance, as rotated and non-rotated inputs have different information content.
Note
In order to perform the masking, the module expects an input with the same spatial dimensions as the mask. Thus, input tensors must have shape B x C x S x S.
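A minimal usage sketch; the module is typically applied to the network input before the first convolution:
import torch
import e2cnn

s = e2cnn.gspaces.Rot2dOnR2(8)
in_type = e2cnn.nn.FieldType(s, [s.trivial_repr]*3)

# mask for 29x29 inputs with a 1% margin
mask = e2cnn.nn.MaskModule(in_type, 29, margin=1.0)

x = e2cnn.nn.GeometricTensor(torch.randn(4, in_type.size, 29, 29), in_type)
y = mask(x)

# pixels near the corners (outside the inscribed circle) are smoothly suppressed towards zero
print(y.tensor[..., 0, 0].abs().max())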
Identity
- class IdentityModule(in_type)[source]
Bases:
e2cnn.nn.modules.equivariant_module.EquivariantModule
Simple module which does not perform any operation on the input tensor.
- Parameters
in_type (FieldType) – input (and output) type of this module
- forward(input)[source]
Returns the input tensor.
- Parameters
input (GeometricTensor) – the input GeometricTensor
- Returns
the output tensor
- export()[source]
Export this module to a normal PyTorch torch.nn.Identity module and set to “eval” mode.
Warning
Only works with PyTorch >= 1.2
Weight Initialization
- generalized_he_init(tensor, basisexpansion, cache=False)[source]
Initialize the weights of a convolutional layer with a generalized He weight-initialization method.
Because the computation of the variances can be expensive, to save time on consecutive runs of the same model it is possible to cache the tensor containing the variance of each weight, for a specific basisexpansion. This can be useful if a network contains multiple convolution layers of the same kind (same input and output types, same kernel size, etc.) or if one needs to train the same network from scratch multiple times (e.g. to perform hyper-parameter search over the learning rate or to repeat an experiment with different random seeds).
Note
The variance tensor is cached in memory and therefore is only available to the current process.
- Parameters
tensor (torch.Tensor) – the tensor containing the weights
basisexpansion (BasisExpansion) – the basis expansion method
cache (bool, optional) – cache the variance tensor. By default, cache=False
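A minimal usage sketch, assuming an R2Conv layer exposing its expansion coefficients as weights and its basis expansion as basisexpansion (these attribute names are an assumption of this example):
import torch
import e2cnn

s = e2cnn.gspaces.Rot2dOnR2(8)
c_in = e2cnn.nn.FieldType(s, [s.trivial_repr]*3)
c_out = e2cnn.nn.FieldType(s, [s.regular_repr]*4)

conv = e2cnn.nn.R2Conv(c_in, c_out, kernel_size=5)

# re-initialize the expansion coefficients; cache=True stores the per-weight variances
# so that other layers with identical input/output types and kernel size can reuse them
e2cnn.nn.init.generalized_he_init(conv.weights.data, conv.basisexpansion, cache=True)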
- deltaorthonormal_init(tensor, basisexpansion)[source]
Initialize the weights of a convolutional layer with delta-orthogonal initialization.
- Parameters
tensor (torch.Tensor) – the tensor containing the weights
basisexpansion (BasisExpansion) – the basis expansion method