# e2cnn.nn

This subpackage provides implementations of equivariant neural network modules.

In an equivariant network, features are associated with a transformation law under actions of a symmetry group. The transformation law of a feature field is implemented by its FieldType, which can be interpreted as a data type. A GeometricTensor wraps a torch.Tensor to endow it with a FieldType. Geometric tensors are processed by EquivariantModules, which are torch.nn.Modules that guarantee the specified behavior of their output fields given a transformation of their input fields.

This subpackage depends on e2cnn.group and e2cnn.gspaces.

To enable efficient deployment of equivariant networks, many EquivariantModules implement an export() method, which converts a trained equivariant module into a pure PyTorch module with few or no dependencies on e2cnn. Not all modules support this feature yet, so check each module’s documentation to see whether it implements this method. We provide a simple example:

```python
import torch
import e2cnn
from e2cnn.nn import (SequentialModule, R2Conv, InnerBatchNorm, ReLU,
                      PointwiseMaxPool, ELU, GroupPooling, GeometricTensor)

# build a simple equivariant model using a SequentialModule
s = e2cnn.gspaces.Rot2dOnR2(8)
c_in = e2cnn.nn.FieldType(s, [s.trivial_repr]*3)
c_hid = e2cnn.nn.FieldType(s, [s.regular_repr]*3)
c_out = e2cnn.nn.FieldType(s, [s.regular_repr]*1)

net = SequentialModule(
    R2Conv(c_in, c_hid, 5, bias=False),
    InnerBatchNorm(c_hid),
    ReLU(c_hid, inplace=True),
    PointwiseMaxPool(c_hid, kernel_size=3, stride=2, padding=1),
    R2Conv(c_hid, c_out, 3, bias=False),
    InnerBatchNorm(c_out),
    ELU(c_out, inplace=True),
    GroupPooling(c_out)
)

# train the model
# ...

# export the model
net.eval()
net_exported = net.export()

print(net)
# SequentialModule(
#   (0): R2Conv([8-Rotations: {irrep_0, irrep_0, irrep_0}], [8-Rotations: {regular, regular, regular}], kernel_size=5, stride=1, bias=False)
#   (1): InnerBatchNorm([8-Rotations: {regular, regular, regular}], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#   (2): ReLU(inplace=True, type=[8-Rotations: {regular, regular, regular}])
#   (3): PointwiseMaxPool()
#   (4): R2Conv([8-Rotations: {regular, regular, regular}], [8-Rotations: {regular}], kernel_size=3, stride=1, bias=False)
#   (5): InnerBatchNorm([8-Rotations: {regular}], eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#   (6): ELU(alpha=1.0, inplace=True, type=[8-Rotations: {regular}])
#   (7): GroupPooling([8-Rotations: {regular}])
# )

print(net_exported)
# Sequential(
#   (0): Conv2d(3, 24, kernel_size=(5, 5), stride=(1, 1), bias=False)
#   (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#   (2): ReLU(inplace=True)
#   (3): MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), dilation=(1, 1), ceil_mode=False)
#   (4): Conv2d(24, 8, kernel_size=(3, 3), stride=(1, 1), bias=False)
#   (5): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#   (6): ELU(alpha=1.0, inplace=True)
#   (7): MaxPoolChannels(kernel_size=8)
# )

# check that the two models are equivalent
x = torch.randn(10, c_in.size, 31, 31)
x = GeometricTensor(x, c_in)

y1 = net(x).tensor
y2 = net_exported(x.tensor)

assert torch.allclose(y1, y2)
```


## Field Type

class FieldType(gspace, representations)

A FieldType can be interpreted as the data type of a feature space. It describes:

• the base space on which a feature field lives and the symmetries of that space which are considered

• the transformation law of feature fields under the action of the fiber group

The former is formalized by a choice of gspace, while the latter is determined by a choice of group representations (representations), passed as a list of Representation instances. Each representation in this list corresponds to one independent feature field contained in the feature space. The input representations need to belong to gspace’s fiber group (e2cnn.gspaces.GSpace.fibergroup).

Note

Mathematically, this class describes a (trivial) vector bundle, associated to the symmetry group $$(\R^D, +) \rtimes G$$.

Given a principal bundle $$\pi: (\R^D, +) \rtimes G \to \R^D, tg \mapsto tG$$ with fiber group $$G$$, an associated vector bundle has the same base space $$\R^D$$ but its fibers are vector spaces like $$\mathbb{R}^c$$. Moreover, these vector spaces are associated to a $$c$$-dimensional representation $$\rho$$ of the fiber group $$G$$ and transform accordingly.

The representation $$\rho$$ is defined as the direct sum of the representations $$\{\rho_i\}_i$$ in representations. See also directsum().
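For a single group element $$g$$, the direct sum acts as a block-diagonal matrix with the blocks $$\rho_i(g)$$ on its diagonal. The plain-Python sketch below illustrates only this linear-algebra step on explicit matrices; it is not how e2cnn implements directsum(), and the helper name direct_sum is hypothetical.

```python
from typing import List

Matrix = List[List[float]]


def direct_sum(blocks: List[Matrix]) -> Matrix:
    """Place square matrices along the diagonal of a larger zero matrix."""
    size = sum(len(block) for block in blocks)
    out = [[0.0] * size for _ in range(size)]
    offset = 0
    for block in blocks:
        d = len(block)
        for i in range(d):
            for j in range(d):
                out[offset + i][offset + j] = block[i][j]
        offset += d
    return out


# rho_1(g): trivial representation (1x1 identity) of a 90-degree rotation g
# rho_2(g): the same rotation acting on R^2 (frequency-1 irrep of C4)
trivial = [[1.0]]
rot90 = [[0.0, -1.0], [1.0, 0.0]]
rho = direct_sum([trivial, rot90])
print(rho)  # [[1.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]]
```

The size of $$\rho$$ is the sum of the sizes of the $$\rho_i$$, which is exactly the size of the field type.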

Parameters
• gspace (GSpace) – the space where the feature fields live and its symmetries

• representations (list) – a list of Representation s of the gspace’s fiber group, determining the transformation laws of the feature fields

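Variables
• gspace (GSpace) – the space where the feature fields live and its symmetries

• representations (list) – the list of Representations of the fiber group, one per feature field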

• size (int) – dimensionality of the feature space described by the FieldType. It corresponds to the sum of the dimensionalities of the individual feature fields or group representations (e2cnn.group.Representation.size).

property fibergroup

The fiber group of gspace.

Returns

the fiber group

property representation

The (combined) representation of this field type. It describes how the feature vectors transform under the fiber group action, that is, how the channels mix.

It is the direct sum (directsum()) of the representations in e2cnn.nn.FieldType.representations.

Because a feature space can contain a very large number of feature fields, computing this representation as the direct sum of many small representations can be expensive. Hence, this representation is only built the first time it is explicitly used, in order to avoid unnecessary overhead when not needed.

Returns

the Representation describing the whole feature space

property irreps

Ordered list of irreps contained in the representation of the field type. It is the concatenation of the irreps in each representation in e2cnn.nn.FieldType.representations.

Returns

list of irreps

property change_of_basis

The change of basis matrix which decomposes the field type’s representation into irreps, given as a sparse (block diagonal) matrix (scipy.sparse.coo_matrix).

It is the direct sum of the change of basis matrices of each representation in e2cnn.nn.FieldType.representations.

See also e2cnn.group.Representation.change_of_basis.

Returns

the change of basis

property change_of_basis_inv

Inverse of the (sparse) change of basis matrix. See e2cnn.nn.FieldType.change_of_basis for more details.

Returns

the inverted change of basis

get_dense_change_of_basis()

The method returns a dense torch.Tensor containing a copy of the change-of-basis matrix.

See e2cnn.nn.FieldType.change_of_basis for more details.

get_dense_change_of_basis_inv()

The method returns a dense torch.Tensor containing a copy of the inverse of the change-of-basis matrix.

See e2cnn.nn.FieldType.change_of_basis for more details.

transform(input, element)

The method takes a PyTorch tensor compatible with this type (i.e. whose spatial dimensions are supported by the base space and whose number of channels equals the e2cnn.nn.FieldType.size of this type) and an element of the fiber group of this type.

Transform the input tensor according to the group representation associated with the input element and its (induced) action on the base space.

Warning

This method is internally implemented using numpy. This means that the input tensor is detached (and moved to CPU) before the transformation, therefore no gradient is propagated back through this operation.

See e2cnn.nn.GeometricTensor.transform_fibers() to transform only the fibers, i.e. without transforming the base space.

See e2cnn.gspaces.GSpace.featurefield_action() for more details.

Parameters
• input (torch.Tensor) – input tensor

• element – element of the fiber group

Returns

transformed tensor
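As an illustration of the two parts of this action, here is a toy sketch in plain Python for a single regular field of the rotation group C4: the spatial part rotates every channel’s grid, and the fiber part cyclically shifts the four channels, as the regular representation does. The helper names are hypothetical, the rotation follows the numpy.rot90 convention, and e2cnn’s own orientation conventions may differ.

```python
from typing import List

Grid = List[List[float]]


def rot90_ccw(grid: Grid) -> Grid:
    """Rotate a 2D grid by 90 degrees counterclockwise (like numpy.rot90(grid, k=1))."""
    return [list(row) for row in zip(*grid)][::-1]


def transform_regular_c4(field: List[Grid], k: int) -> List[Grid]:
    """Act with the rotation r^k (k * 90 degrees) on one C4-regular field.

    field holds 4 channels, one per group element. Every channel grid is rotated
    spatially, then the channels are cyclically shifted, which is how the regular
    representation acts on a feature vector: (rho(r^k) v)_i = v_{(i - k) mod 4}.
    """
    k %= 4
    rotated = []
    for grid in field:
        for _ in range(k):
            grid = rot90_ccw(grid)
        rotated.append(grid)
    return [rotated[(i - k) % 4] for i in range(4)]


field = [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]], [[13, 14], [15, 16]]]
out = transform_regular_c4(field, 1)
print(out[0])  # [[14, 16], [13, 15]]: channel 3, rotated, moved to position 0
```

Applying the rotation twice gives the same result as applying r^2 once, as expected from a group action.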

restrict(id)

Reduce the symmetries modeled by the FieldType by choosing a subgroup of its fiber group as specified by id. This implies a restriction of each representation in e2cnn.nn.FieldType.representations to this subgroup.

Check the documentation of the restrict() method in the subclass of GSpace used for a description of the parameter id.

Parameters

id – identifier of the subgroup to which the FieldType and its e2cnn.nn.FieldType.representations should be restricted

Returns

the restricted type

sorted()

Return a new field type containing the fields of the current one, sorted by their dimensionalities. It is built by sorting the e2cnn.nn.FieldType.representations of this field type.

Returns

the sorted field type

__add__(other)

Returns a field type associated with the direct sum $$\rho = \rho_1 \oplus \rho_2$$ of the representations $$\rho_1$$ and $$\rho_2$$ of two field types.

In practice, the method builds a new FieldType using the concatenation of the lists e2cnn.nn.FieldType.representations of the two field types.

The two field types need to be associated with the same GSpace.

Parameters

other (FieldType) – the other addend

Returns

the direct sum

__len__()

Return the number of feature fields in this FieldType, i.e. the length of e2cnn.nn.FieldType.representations.

Note

This is in general different from e2cnn.nn.FieldType.size.

Returns

the number of fields in this type

fields_names()

Return an ordered list containing the names of the representations associated with each field.

Returns

the list of fields’ representations’ names

index_select(index)

Build a new FieldType from the current one by taking the Representations selected by the input index.

Parameters

index (list) – a list of integers in the range {0, ..., N-1}, where N is the number of representations in the current field type

Returns

the new field type
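Since a field type is essentially an ordered list of representations, index_select(), __add__() and sorted() reduce to list operations. The sketch below models each field as a hypothetical (name, size) pair rather than an e2cnn Representation; the stable tie-breaking in the sorted variant is an assumption, not documented behavior.

```python
from typing import List, Tuple

# Hypothetical stand-in for a field type: one (name, size) pair per feature field.
Field = Tuple[str, int]


def index_select(fields: List[Field], index: List[int]) -> List[Field]:
    """Keep the fields at the given positions, in the given order."""
    return [fields[i] for i in index]


def add_field_types(a: List[Field], b: List[Field]) -> List[Field]:
    """The sum of two field types concatenates their lists of fields."""
    return a + b


def sorted_by_size(fields: List[Field]) -> List[Field]:
    """Sort by field size; Python's sort is stable, so ties keep their order."""
    return sorted(fields, key=lambda field: field[1])


fields = [("regular", 4), ("irrep_1", 2), ("irrep_0", 1)]
print(index_select(fields, [2, 0]))  # [('irrep_0', 1), ('regular', 4)]
print(sorted_by_size(fields))        # [('irrep_0', 1), ('irrep_1', 2), ('regular', 4)]
```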

property fields_end

Array containing the index of the first channel following each field. More precisely, the integer in the $$i$$-th position is equal to the index of the last channel of the $$i$$-th field plus $$1$$.

property fields_start

Array containing the index of the first channel of each field. More precisely, the integer in the $$i$$-th position is equal to the index of the first channel of the $$i$$-th field.
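Both arrays follow from a running sum of the field sizes. This plain-Python sketch (the function name field_offsets is hypothetical) computes them for two C4-regular fields, two irrep(1) fields and three trivial fields, and also shows why len() and size differ.

```python
from itertools import accumulate
from typing import List, Tuple


def field_offsets(sizes: List[int]) -> Tuple[List[int], List[int]]:
    """Return (fields_start, fields_end) for fields with the given sizes."""
    ends = list(accumulate(sizes))                           # last channel + 1 of each field
    starts = [end - size for end, size in zip(ends, sizes)]  # first channel of each field
    return starts, ends


sizes = [4, 4, 2, 2, 1, 1, 1]
starts, ends = field_offsets(sizes)
print(starts)  # [0, 4, 8, 10, 12, 13, 14]
print(ends)    # [4, 8, 10, 12, 13, 14, 15]
# number of fields (like len(field_type)) vs number of channels (like field_type.size)
print(len(sizes), ends[-1])  # 7 15
```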

group_by_labels(labels)

Associate a label with each feature field (or representation in e2cnn.nn.FieldType.representations) and group the fields accordingly into new FieldTypes.

Parameters

labels (list) – a list of strings with length equal to the number of representations in e2cnn.nn.FieldType.representations

Returns

a dictionary mapping each different input label to a new field type
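The grouping can be sketched with a dictionary of lists. In this hypothetical helper, fields are represented only by their names, and they keep their original relative order inside each group; that ordering is an assumption about e2cnn’s behavior.

```python
from typing import Dict, List


def group_by_labels(fields: List[str], labels: List[str]) -> Dict[str, List[str]]:
    """Group fields by label; each label maps to the fields carrying it."""
    if len(fields) != len(labels):
        raise ValueError("expected exactly one label per field")
    groups: Dict[str, List[str]] = {}
    for field, label in zip(fields, labels):
        groups.setdefault(label, []).append(field)
    return groups


groups = group_by_labels(["regular", "irrep_1", "regular", "irrep_0"], ["a", "b", "a", "b"])
print(groups)  # {'a': ['regular', 'regular'], 'b': ['irrep_1', 'irrep_0']}
```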

__iter__()

It is possible to iterate over all representations in a field type by using FieldType as an iterable object.

property testing_elements

Alias for self.gspace.testing_elements.

## Geometric Tensor

class GeometricTensor(tensor, type)

A GeometricTensor can be interpreted as a typed tensor. It wraps a standard torch.Tensor and endows it with a (compatible) FieldType as its transformation law.

The FieldType describes the action of a group $$G$$ on the tensor. This action includes both a transformation of the base space and a transformation of the channels according to a $$G$$-representation $$\rho$$.

All e2cnn neural network operations have GeometricTensors as inputs and outputs. They perform dynamic type checking, ensuring that the transformation laws of the data and the operation match. See also EquivariantModule.

As usual, the first dimension of the tensor is interpreted as the batch dimension. The second is the fiber (or channel) dimension, which is associated with a group representation by the field type. The following dimensions are the spatial dimensions (like in a conventional CNN).

The operations of addition and scalar multiplication are supported. For example:

```python
import torch
import e2cnn

gs = e2cnn.gspaces.Rot2dOnR2(8)
field_type = e2cnn.nn.FieldType(gs, [gs.regular_repr]*3)  # 3 regular fields of C8: 24 channels
t1 = e2cnn.nn.GeometricTensor(torch.randn(1, 24, 3, 3), field_type)
t2 = e2cnn.nn.GeometricTensor(torch.randn(1, 24, 3, 3), field_type)

# addition of two tensors with the same type
t3 = t1 + t2

# scalar product
t3 = t1 * 3.

# scalar product also supports tensors containing only one scalar
t3 = t1 * torch.tensor(3.)

# inplace operations are also supported
t1 += t2
t2 *= 3.
```


Warning

The multiplication of a PyTorch tensor containing only a scalar with a GeometricTensor is only supported when using PyTorch 1.4 or higher (see the corresponding PyTorch issue).

A GeometricTensor supports slicing in a similar way to PyTorch’s torch.Tensor. More precisely, slicing along the batch (1st) and the spatial (3rd, 4th, …) dimensions works as usual. However, slicing the fiber (2nd) dimension would break equivariance when splitting channels belonging to the same field. To prevent this, slicing on the second dimension is defined over fields instead of channels.

Warning

GeometricTensor only supports basic slicing but it does not support advanced indexing (see NumPy’s documentation about indexing for more details). Moreover, in contrast to NumPy and PyTorch, an index containing a single integer value does not reduce the dimensionality of the tensor. In this way, the resulting tensor can always be interpreted as a GeometricTensor.

We give a few examples to illustrate this behavior:

```python
import torch
import e2cnn

# Example of GeometricTensor slicing
space = e2cnn.gspaces.Rot2dOnR2(4)
field_type = e2cnn.nn.FieldType(space, [
    # field type          # index  # size
    space.regular_repr,   #   0    #  4
    space.regular_repr,   #   1    #  4
    space.irrep(1),       #   2    #  2
    space.irrep(1),       #   3    #  2
    space.trivial_repr,   #   4    #  1
    space.trivial_repr,   #   5    #  1
    space.trivial_repr,   #   6    #  1
])                        #   sum = 15

# this FieldType contains 7 fields
len(field_type)
# >> 7

# the size of this FieldType is equal to the sum of the sizes of each of its fields
field_type.size
# >> 15

geom_tensor = e2cnn.nn.GeometricTensor(torch.randn(10, field_type.size, 9, 9), field_type)

geom_tensor.shape
# >> torch.Size([10, 15, 9, 9])

geom_tensor[1:3, :, 2:5, 2:5].shape
# >> torch.Size([2, 15, 3, 3])

geom_tensor[..., 2:5].shape
# >> torch.Size([10, 15, 9, 3])

# the tensor contains the fields 1:4, i.e. 1, 2 and 3
# these fields have size, respectively, 4, 2 and 2
# so the resulting tensor has 8 channels
geom_tensor[:, 1:4, ...].shape
# >> torch.Size([10, 8, 9, 9])

# the tensor contains the fields 0:6:2, i.e. 0, 2 and 4
# these fields have size, respectively, 4, 2 and 1
# so the resulting tensor has 7 channels
geom_tensor[:, 0:6:2].shape
# >> torch.Size([10, 7, 9, 9])

# the tensor contains only the field 2, which has size 2
# note, also, that even though a single index is used for the batch dimension,
# the resulting tensor still has 4 dimensions
geom_tensor[3, 2].shape
# >> torch.Size([1, 2, 9, 9])
```


Warning

Slicing over the fiber (2nd) dimension with step > 1 or with a negative step is converted into indexing over the channels. This means that, in these cases, slicing behaves like advanced indexing in PyTorch and NumPy, returning a copy instead of a view. For more details, see the notes on tensor views in PyTorch’s documentation and NumPy’s indexing documentation.

Note

Slicing is not supported for setting values inside the tensor (i.e. __setitem__() is not implemented). Indeed, depending on the values which are assigned, this operation can break the symmetry of the tensor, which may then no longer transform according to its transformation law (specified by type). If this feature is necessary, one can directly access the underlying torch.Tensor, e.g. geom_tensor.tensor[:3, :, 2:5, 2:5] = torch.randn(3, 4, 3, 3), although this is not recommended.

Parameters
• tensor (torch.Tensor) – the tensor data

• type (FieldType) – the type of the tensor, modeling its transformation law

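Variables
• tensor (torch.Tensor) – the tensor data

• type (FieldType) – the type of the tensor, modeling its transformation law

restrict(id)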

Restrict the field type of this tensor.

The method returns a new GeometricTensor whose type is equal to this tensor’s type restricted to a subgroup $$H<G$$ (see e2cnn.nn.FieldType.restrict()). The restricted type is associated with the restricted representation $$\Res{H}{G}\rho$$ of the $$G$$-representation $$\rho$$ associated to this tensor’s type. The input id specifies the subgroup $$H < G$$.

Notice that the underlying tensor instance will be shared between the current tensor and the returned one.

Warning

The method builds the new representation on the fly; hence, if this operation is needed at run time, we suggest using e2cnn.nn.RestrictionModule, which pre-computes the new representation offline.

Check the documentation of the restrict() method in the GSpace instance used for a description of the parameter id.

Parameters

id – the id identifying the subgroup $$H$$ the representations are restricted to

Returns

the geometric tensor with the restricted representations

split(breaks)

Split this tensor along the channel dimension into a list of smaller tensors. The original tensor is split at the fields specified by the index list breaks.

If the tensor is associated with the list of fields $$\{\rho_i\}_{i=0}^{L-1}$$ (where $$L$$ equals len(self.type); see also e2cnn.nn.FieldType.representations), the $$j$$-th output tensor ($$0 < j < \text{len}(\text{breaks})$$) will contain the fields $$\rho_{\text{breaks}[j-1]}, \dots, \rho_{\text{breaks}[j]-1}$$ of the original tensor. The first tensor ($$j=0$$) contains the fields $$\rho_{0}, \dots, \rho_{\text{breaks}[0]-1}$$, while the last tensor ($$j = \text{len}(\text{breaks})$$) contains the last fields $$\rho_{\text{breaks}[-1]}, \dots, \rho_{L-1}$$.

Note

breaks must be a sorted list of integers greater than 0 and smaller than len(self.type).

If breaks = None, the tensor is split at each field. This is equivalent to using breaks = list(range(1, len(self.type))).

Example

```python
import torch
import e2cnn

space = e2cnn.gspaces.Rot2dOnR2(4)
field_type = e2cnn.nn.FieldType(space, [
    space.regular_repr,   # size = 4
    space.regular_repr,   # size = 4
    space.irrep(1),       # size = 2
    space.irrep(1),       # size = 2
    space.trivial_repr,   # size = 1
    space.trivial_repr,   # size = 1
    space.trivial_repr,   # size = 1
])                        # sum = 15

field_type.size
# >> 15

geom_tensor = e2cnn.nn.GeometricTensor(torch.randn(10, field_type.size, 7, 7), field_type)

geom_tensor.shape
# >> torch.Size([10, 15, 7, 7])

# split the tensor in 3 parts
len(geom_tensor.split([4, 6]))
# >> 3

# the first contains
# - the first 2 regular fields (2*4 = 8 channels)
# - 2 vector fields (irrep(1)) (2*2 = 4 channels)
# and, therefore, contains 12 channels
geom_tensor.split([4, 6])[0].shape
# >> torch.Size([10, 12, 7, 7])

# the second contains only 2 scalar (trivial) fields (2*1 = 2 channels)
geom_tensor.split([4, 6])[1].shape
# >> torch.Size([10, 2, 7, 7])

# the last contains only 1 scalar (trivial) field (1*1 = 1 channel)
geom_tensor.split([4, 6])[2].shape
# >> torch.Size([10, 1, 7, 7])
```

Parameters

breaks (list) – indices of the fields where to split the tensor

Returns

list of GeometricTensors into which the original tensor is chunked
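The field-level breaks translate into channel ranges through the running sum of the field sizes. This sketch (split_channel_ranges is a hypothetical name) reproduces the channel counts of the example above, returning half-open ranges [start, end). It also rejects unsorted, repeated or out-of-range breaks, which is an assumption about what counts as valid input.

```python
from typing import List, Optional, Tuple


def split_channel_ranges(sizes: List[int],
                         breaks: Optional[List[int]] = None) -> List[Tuple[int, int]]:
    """Convert field-level breaks into half-open channel ranges [start, end)."""
    n = len(sizes)
    if breaks is None:
        breaks = list(range(1, n))  # split at every field
    increasing = all(a < b for a, b in zip(breaks, breaks[1:]))
    if not increasing or any(b <= 0 or b >= n for b in breaks):
        raise ValueError("breaks must be strictly increasing integers in [1, len(sizes) - 1]")
    offsets = [0]  # offsets[i] is the first channel of field i; offsets[n] is the total size
    for size in sizes:
        offsets.append(offsets[-1] + size)
    bounds = [0] + breaks + [n]
    return [(offsets[a], offsets[b]) for a, b in zip(bounds[:-1], bounds[1:])]


ranges = split_channel_ranges([4, 4, 2, 2, 1, 1, 1], [4, 6])
print(ranges)                                  # [(0, 12), (12, 14), (14, 15)]
print([end - start for start, end in ranges])  # [12, 2, 1]
```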

transform(element)

Transform the current tensor according to the group representation associated with the input element and its induced action on the base space.

Warning

The input tensor is detached before the transformation; therefore, no gradient is backpropagated through this operation.

See e2cnn.nn.GeometricTensor.transform_fibers() to transform only the fibers, i.e. without transforming the base space.

Parameters

element – an element of the group of symmetries of the fiber.

Returns

the transformed tensor

transform_fibers(element)

Transform the feature vectors of the underlying tensor according to the group representation associated to the input element.

Interpreting the tensor as a vector-valued signal $$f: X \to \mathbb{R}^c$$ over a base space $$X$$ (where $$c$$ is the number of channels of the tensor), given the input element $$g \in G$$ (where $$G$$ is the fiber group), the method returns the new signal $$f'$$:

$f'(x) := \rho(g) f(x)$

for $$x \in X$$ point in the base space and $$\rho$$ the representation of $$G$$ in the field type of this tensor.

Notice that the input element has to be an element of the fiber group of this tensor’s field type.
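In coordinates, transform_fibers() applies the same matrix $$\rho(g)$$ to the feature vector at every point while leaving the points themselves in place. A plain-Python sketch with a hypothetical helper, using a channel-first nested list signal[channel][row][col] in place of a torch.Tensor:

```python
from typing import List

Matrix = List[List[float]]
Signal = List[List[List[float]]]  # signal[channel][row][col], channel-first like a (C, H, W) tensor


def transform_fibers(signal: Signal, rho_g: Matrix) -> Signal:
    """Return f' with f'(x) = rho(g) f(x) at every grid point x; points do not move."""
    c = len(signal)
    h, w = len(signal[0]), len(signal[0][0])
    return [[[sum(rho_g[i][j] * signal[j][y][x] for j in range(c)) for x in range(w)]
             for y in range(h)]
            for i in range(c)]


rot90 = [[0, -1], [1, 0]]      # rho(g) for g = 90-degree rotation, acting on a vector field
signal = [[[1, 2]], [[3, 4]]]  # 2 channels on a 1 x 2 grid: vectors (1, 3) and (2, 4)
print(transform_fibers(signal, rot90))  # [[[-3, -4]], [[1, 2]]]
```

Each vector is rotated in place: (1, 3) becomes (-3, 1) and (2, 4) becomes (-4, 2).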

See e2cnn.nn.GeometricTensor.transform() to transform the whole tensor.

Parameters

element – an element of the group of symmetries of the fiber.

Returns

the transformed tensor

property shape

Alias for self.tensor.shape

size()

Alias for self.tensor.size()

to(*args, **kwargs)

Alias for self.tensor.to(*args, **kwargs).

Applies torch.Tensor.to() to the underlying tensor and wraps the resulting tensor in a new GeometricTensor with the same type.

__getitem__(slices)

A GeometricTensor supports slicing in a similar way to PyTorch’s torch.Tensor. More precisely, slicing along the batch (1st) and the spatial (3rd, 4th, …) dimensions works as usual. However, slicing along the channel dimension could break equivariance by splitting the channels belonging to the same field. For this reason, slicing on the second dimension is not defined over the channels but over fields.

When a continuous (step=1) slice is used over the fields/channels dimension (the 2nd axis), it is converted into a continuous slice over the channels. This is not possible when the step is greater than 1 or negative. In such cases, the slice over the fields needs to be converted into an index over the channels.

Moreover, when a single integer is used to index an axis, that axis is not discarded as in PyTorch but is preserved with size 1.

Slicing is not supported for setting values inside the tensor (i.e. object.__setitem__() is not implemented).
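The field-to-channel conversion described above can be sketched in plain Python. The function below (a hypothetical helper, not e2cnn’s code) returns a channel slice when the field step is 1 and a list of channel indices otherwise; keeping channels in ascending order inside each selected field is an assumption.

```python
from typing import List, Union


def fields_to_channels(sizes: List[int], field_slice: slice) -> Union[slice, List[int]]:
    """Translate a slice over fields into a slice (step 1) or index list over channels."""
    offsets = [0]  # offsets[i] is the first channel of field i
    for size in sizes:
        offsets.append(offsets[-1] + size)
    fields = range(len(sizes))[field_slice]  # resolves None and negative bounds as Python does
    if field_slice.step in (None, 1):
        if len(fields) == 0:
            return slice(0, 0)
        # contiguous fields map to contiguous channels, so a plain slice (a view) suffices
        return slice(offsets[fields[0]], offsets[fields[-1] + 1])
    # step > 1 or negative: collect the channels of each selected field (a copy in PyTorch)
    return [ch for f in fields for ch in range(offsets[f], offsets[f + 1])]


sizes = [4, 4, 2, 2, 1, 1, 1]
print(fields_to_channels(sizes, slice(1, 4)))     # slice(4, 12, None)
print(fields_to_channels(sizes, slice(0, 6, 2)))  # [0, 1, 2, 3, 8, 9, 12]
```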

__add__(other)

Add two compatible GeometricTensors using pointwise addition. The two tensors need to have the same shape and be associated with the same field type.

Parameters

other (GeometricTensor) – the other geometric tensor

Returns

the sum

__sub__(other)

Subtract two compatible GeometricTensors using pointwise subtraction. The two tensors need to have the same shape and be associated with the same field type.

Parameters

other (GeometricTensor) – the other geometric tensor

Returns

their difference

__iadd__(other)

Add a compatible GeometricTensor to this tensor in place. The two tensors need to have the same shape and be associated with the same field type.

Parameters

other (GeometricTensor) – the other geometric tensor

Returns

this tensor

__isub__(other)

Subtract a compatible GeometricTensor from this tensor in place. The two tensors need to have the same shape and be associated with the same field type.

Parameters

other (GeometricTensor) – the other geometric tensor

Returns

this tensor

__mul__(other)

Scalar product of this GeometricTensor with a scalar. The operation is not done in place: it returns a new GeometricTensor (use __imul__() for the in-place version).

The scalar can be a float number or a torch.Tensor containing only one scalar (i.e. torch.numel() should return 1).

Parameters

other (Union[float, Tensor]) – a scalar

Returns

the scalar product

__rmul__(other)

Scalar product of this GeometricTensor with a scalar. The operation is not done in place: it returns a new GeometricTensor.

The scalar can be a float number or a torch.Tensor containing only one scalar (i.e. torch.numel() should return 1).

Parameters

other (Union[float, Tensor]) – a scalar

Returns

the scalar product

__imul__(other)

Scalar product of this GeometricTensor with a scalar. The operation is done in place.

The scalar can be a float number or a torch.Tensor containing only one scalar (i.e. torch.numel() should return 1).

Parameters

other (Union[float, Tensor]) – a scalar

Returns

the scalar product

## Equivariant Module

class EquivariantModule

Abstract base class for all equivariant modules.

An EquivariantModule is a subclass of torch.nn.Module. It follows that any subclass of EquivariantModule needs to implement the forward() method. Compared to a general torch.nn.Module, an equivariant module implements a typed function, as both its input and its output are associated with specific FieldTypes. Therefore, the inputs and outputs of an equivariant module are usually not just instances of torch.Tensor but GeometricTensors.

As a subclass of torch.nn.Module, it supports most of the commonly used methods (e.g. torch.nn.Module.to(), torch.nn.Module.cuda(), torch.nn.Module.train() or torch.nn.Module.eval()).

Many equivariant modules implement an export() method, which converts the module to eval mode and returns a pure PyTorch implementation of it. This can be used after training to efficiently deploy the model without, for instance, the overhead of the automatic type checking performed by all the modules in this library.

Warning

Not all modules implement this feature yet. If the export() method is called in a module which does not implement it yet, a NotImplementedError is raised. Check the documentation of each individual module to understand if the method is implemented.

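Variables
• in_type (FieldType) – the type of the input tensor

• out_type (FieldType) – the type of the output tensor

abstract evaluate_output_shape(input_shape)

Compute the shape of the output tensor which would be generated by this module when a tensor with shape input_shape is provided as input.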

Parameters

input_shape (tuple) – shape of the input tensor

Returns

shape of the output tensor
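For a module that behaves like a 2D convolution, the spatial part of the output shape follows the formula documented for torch.nn.Conv2d. A hypothetical sketch of such a shape computation, with the number of output channels passed explicitly (for an equivariant module this would be out_type.size):

```python
from typing import Tuple


def conv2d_output_shape(input_shape: Tuple[int, int, int, int], out_channels: int,
                        kernel_size: int, padding: int = 0, stride: int = 1,
                        dilation: int = 1) -> Tuple[int, int, int, int]:
    """Output shape of a square-kernel 2D convolution, as documented for torch.nn.Conv2d."""
    batch, _, height, width = input_shape

    def out_dim(n: int) -> int:
        return (n + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

    return (batch, out_channels, out_dim(height), out_dim(width))


# e.g. the first R2Conv of the export example: 3 input channels, 24 output channels, 5x5 filter
print(conv2d_output_shape((10, 3, 31, 31), 24, 5))  # (10, 24, 27, 27)
```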

check_equivariance(atol=1e-07, rtol=1e-05)

Method that automatically tests the equivariance of the current module. The default implementation of this method relies on e2cnn.nn.GeometricTensor.transform() and uses the group elements in testing_elements.

This method can be overridden for custom tests.

Returns

a list containing, for each testing element, a pair with that element and the corresponding equivariance error
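The idea behind the test is to compare "transform, then apply the module" with "apply the module, then transform" on random inputs. The sketch below uses a toy setting (the cyclic group C_N acting on N-dimensional vectors by cyclic shifts, i.e. the regular representation without a base space) instead of GeometricTensor.transform(); all names are hypothetical and the error is measured as a maximum absolute difference.

```python
import random
from typing import Callable, Iterable, List, Tuple

Vector = List[float]


def shift(v: Vector, k: int) -> Vector:
    """Regular representation of C_N on R^N: cyclic shift of the coordinates by k."""
    n = len(v)
    return [v[(i - k) % n] for i in range(n)]


def check_equivariance(module: Callable[[Vector], Vector], n: int,
                       testing_elements: Iterable[int],
                       trials: int = 5) -> List[Tuple[int, float]]:
    """For each element k, return (k, worst |module(g.x) - g.module(x)|) over random x."""
    errors = []
    for k in testing_elements:
        worst = 0.0
        for _ in range(trials):
            x = [random.gauss(0.0, 1.0) for _ in range(n)]
            lhs = module(shift(x, k))  # transform the input, then apply the module
            rhs = shift(module(x), k)  # apply the module, then transform the output
            worst = max(worst, max(abs(a - b) for a, b in zip(lhs, rhs)))
        errors.append((k, worst))
    return errors


def relu(v: Vector) -> Vector:
    return [max(a, 0.0) for a in v]  # pointwise, hence equivariant


def keep_first(v: Vector) -> Vector:
    return [v[0]] + [0.0] * (len(v) - 1)  # singles out channel 0: not equivariant


print(check_equivariance(relu, 8, range(8)))   # every error is 0.0
print(check_equivariance(keep_first, 8, [1]))  # a clearly non-zero error
```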

export()

Recursively export each submodule to a normal PyTorch module and set it to “eval” mode.

Warning

Not all modules implement this feature yet. If the export() method is called in a module which does not implement it yet, a NotImplementedError is raised. Check the documentation of each individual module to understand if the method is implemented.

## Utils

### direct sum

tensor_directsum(tensors)

Concatenate a list of GeometricTensors along the channel dimension (dim=1). The input tensors have to be compatible: they need to have the same shape except for the channel dimension (dim=1). In the resulting GeometricTensor, the channel dimension is associated with the direct sum of the representations of the input tensors.

Parameters

tensors (list) – a list of GeometricTensors

Returns

the direct sum of the inputs
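At the level of data, the direct sum of tensors is a concatenation along dim=1, and the resulting list of fields is the concatenation of the input lists. A plain-Python sketch with nested lists and hypothetical helper names; the real function works on GeometricTensors and their FieldTypes:

```python
from typing import List, Tuple

Tensor4D = List[List[List[List[float]]]]  # nested lists with shape (batch, channels, H, W)


def shape(t: Tensor4D) -> Tuple[int, int, int, int]:
    return (len(t), len(t[0]), len(t[0][0]), len(t[0][0][0]))


def zeros(b: int, c: int, h: int, w: int) -> Tensor4D:
    return [[[[0.0] * w for _ in range(h)] for _ in range(c)] for _ in range(b)]


def tensor_directsum(tensors: List[Tuple[Tensor4D, List[str]]]) -> Tuple[Tensor4D, List[str]]:
    """Concatenate (data, field_names) pairs along dim=1 and concatenate their fields."""
    batch, _, height, width = shape(tensors[0][0])
    for t_data, _ in tensors:
        b, _, h, w = shape(t_data)
        if (b, h, w) != (batch, height, width):
            raise ValueError("tensors must agree on every dimension except dim=1")
    data = [[channel for t, _ in tensors for channel in t[i]] for i in range(batch)]
    fields = [name for _, names in tensors for name in names]
    return data, fields


out, fields = tensor_directsum([(zeros(2, 4, 3, 3), ["regular"]), (zeros(2, 2, 3, 3), ["irrep_1"])])
print(shape(out), fields)  # (2, 6, 3, 3) ['regular', 'irrep_1']
```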

## Planar Convolution and Differential Operators

### R2Conv

class R2Conv(in_type, out_type, kernel_size, padding=0, stride=1, dilation=1, padding_mode='zeros', groups=1, bias=True, basisexpansion='blocks', sigma=None, frequencies_cutoff=None, rings=None, maximum_offset=None, recompute=False, basis_filter=None, initialize=True)

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

G-steerable planar convolution mapping between the input and output FieldTypes specified by the parameters in_type and out_type. This operation is equivariant under the action of $$\R^2\rtimes G$$ where $$G$$ is the e2cnn.nn.FieldType.fibergroup of in_type and out_type.

Specifically, let $$\rho_\text{in}: G \to \GL{\R^{c_\text{in}}}$$ and $$\rho_\text{out}: G \to \GL{\R^{c_\text{out}}}$$ be the representations specified by the input and output field types. Then R2Conv guarantees an equivariant mapping

$\kappa \star [\mathcal{T}^\text{in}_{g,u} . f] = \mathcal{T}^\text{out}_{g,u} . [\kappa \star f] \qquad\qquad \forall g \in G, u \in \R^2$

where the transformation of the input and output fields are given by

$\begin{split}[\mathcal{T}^\text{in}_{g,u} . f](x) &= \rho_\text{in}(g)f(g^{-1} (x - u)) \\ [\mathcal{T}^\text{out}_{g,u} . f](x) &= \rho_\text{out}(g)f(g^{-1} (x - u)) \\\end{split}$

The equivariance of G-steerable convolutions is guaranteed by restricting the space of convolution kernels to an equivariant subspace. As proven in 3D Steerable CNNs, this parametrizes the most general equivariant convolutional map between the input and output fields. For feature fields on $$\R^2$$ (e.g. images), the complete G-steerable kernel spaces for $$G \leq \O2$$ are derived in General E(2)-Equivariant Steerable CNNs.

During training, in each forward pass the module expands the basis of G-steerable kernels with learned weights before calling torch.nn.functional.conv2d(). When eval() is called, the filter is built with the current trained weights and stored for future reuse such that no overhead of expanding the kernel remains.
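The expansion step is a linear combination: the filter is $$\sum_i w_i \kappa_i$$ for learned weights $$w_i$$ and basis kernels $$\kappa_i$$ of the equivariant subspace, and any such combination stays in that subspace. The sketch below shows only this arithmetic with made-up basis values in the $$(c_\text{out}, c_\text{in}, s^2)$$ layout; it does not compute a real G-steerable basis, and expand_filter is a hypothetical name.

```python
from typing import List

Kernel = List[List[List[float]]]  # shape (c_out, c_in, s * s)


def expand_filter(weights: List[float], basis: List[Kernel]) -> Kernel:
    """Return sum_i weights[i] * basis[i], computed entry by entry."""
    if len(weights) != len(basis):
        raise ValueError("expected one weight per basis element")
    c_out, c_in, n = len(basis[0]), len(basis[0][0]), len(basis[0][0][0])
    return [[[sum(w * b[o][i][p] for w, b in zip(weights, basis)) for p in range(n)]
             for i in range(c_in)]
            for o in range(c_out)]


# made-up basis: 2 elements for c_out = c_in = 1 and a 2x2 filter (s^2 = 4 entries)
basis = [[[[1, 0, 0, 1]]], [[[0, 1, 1, 0]]]]
print(expand_filter([2, 3], basis))  # [[[2, 3, 3, 2]]]
```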

Warning

When train() is called, the attributes filter and expanded_bias are discarded to avoid situations of mismatch with the learnable expansion coefficients. See also e2cnn.nn.R2Conv.train().

This behaviour can cause problems when storing the state_dict() of a model while in one mode and later loading it into a model in a different mode, as the attributes of the class change. To avoid this issue, we recommend converting the model to eval mode before storing or loading the state dictionary.

The learnable expansion coefficients of this module can be initialized with the methods in e2cnn.nn.init. By default, the weights are initialized in the constructor using generalized_he_init().

Warning

This initialization procedure can be extremely slow for wide layers. In case initializing the model is not required (e.g. before loading the state dict of a pre-trained model) or another initialization method is preferred (e.g. deltaorthonormal_init()), the parameter initialize can be set to False to avoid unnecessary overhead.

The parameters basisexpansion, sigma, frequencies_cutoff, rings and maximum_offset are optional parameters used to control how the basis for the filters is built, how it is sampled on the filter grid and how it is expanded to build the filter. We suggest keeping these default values.

Parameters
• in_type (FieldType) – the type of the input field, specifying its transformation law

• out_type (FieldType) – the type of the output field, specifying its transformation law

• kernel_size (int) – the size of the (square) filter

• padding (int, optional) – implicit zero paddings on both sides of the input. Default: 0

• padding_mode (str, optional) – zeros, reflect, replicate or circular. Default: zeros

• stride (int, optional) – the stride of the kernel. Default: 1

• dilation (int, optional) – the spacing between kernel elements. Default: 1

• groups (int, optional) – number of blocked connections from input channels to output channels. It allows depthwise convolution. When groups > 1, the input and output types need to be divisible into groups groups, all equal to each other. Default: 1.

• bias (bool, optional) – whether to add a bias to the output (only to fields which contain a trivial irrep). Default: True

• basisexpansion (str, optional) – the basis expansion algorithm to use

• sigma (list or float, optional) – width of each ring where the bases are sampled. If only one scalar is passed, it is used for all rings.

• frequencies_cutoff (callable or float, optional) – function mapping the radii of the basis elements to the maximum frequency accepted. If a float value is passed, the maximum frequency is equal to the radius times this factor. By default (None), a more complex policy is used.

• rings (list, optional) – radii of the rings where to sample the bases

• maximum_offset (int, optional) – number of additional (aliased) frequencies in the intertwiners for finite groups. By default (None), all additional frequencies allowed by the frequencies cut-off are used.

• recompute (bool, optional) – if True, recomputes a new basis for the equivariant kernels. By default (False), it caches the built basis or reuses a cached one, if one is found.

• basis_filter (callable, optional) – function which takes as input a descriptor of a basis element (as a dictionary) and returns a boolean value: whether to preserve (True) or discard (False) the basis element. By default (None), no filtering is applied.

• initialize (bool, optional) – initialize the weights of the model. Default: True

Variables
• weights (torch.Tensor) – the learnable parameters which are used to expand the kernel

• filter (torch.Tensor) – the convolutional kernel obtained by expanding the parameters in weights

• bias (torch.Tensor) – the learnable parameters which are used to expand the bias, if bias=True

• expanded_bias (torch.Tensor) – the equivariant bias which is summed to the output, obtained by expanding the parameters in bias

property basisexpansion

Submodule which takes care of building the filter.

It uses the learned weights to expand the kernel in the G-steerable basis and returns the resulting filter with shape $$(c_\text{out}, c_\text{in}, s^2)$$, where $$s$$ is the kernel_size, i.e. the usual filter layout of conventional convolutional modules with the two spatial dimensions flattened.

expand_parameters()[source]

Expand the filter in terms of the e2cnn.nn.R2Conv.weights and the expanded bias in terms of e2cnn.nn.R2Conv.bias.

Returns

the expanded filter and bias

forward(input)[source]

Convolve the input with the expanded filter and bias.

Parameters

input (GeometricTensor) – input feature field transforming according to in_type

Returns

output feature field transforming according to out_type

train(mode=True)[source]

If mode=True, the method sets the module in training mode and discards the filter and expanded_bias attributes.

If mode=False, it sets the module in evaluation mode. Moreover, the method builds the filter and the bias using the current values of the trainable parameters and stores them in filter and expanded_bias such that they are not recomputed at each forward pass.

Warning

This behaviour can cause problems when storing the state_dict() of a model in one mode and later loading it into a model in a different mode, as the attributes of this class change. To avoid this issue, we recommend converting the model to eval mode before storing or loading the state dictionary.

Parameters

mode (bool, optional) – whether to set training mode (True) or evaluation mode (False). Default: True.
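To make the caching behaviour of train() concrete, here is a framework-free toy sketch. It is not e2cnn's implementation: the class `ToyExpandedConv`, its list-based basis and the `expansions` counter are hypothetical, and only the roles of `filter`, `train()` and `expand_parameters()` loosely mirror the attributes documented above.

```python
class ToyExpandedConv:
    """Hypothetical stand-in for a layer whose filter is a linear combination
    of fixed basis filters; it is not the real e2cnn.nn.R2Conv."""

    def __init__(self, basis, weights):
        self.basis = basis        # list of basis filters, each a flat list of floats
        self.weights = weights    # one learnable coefficient per basis filter
        self.training = True
        self.filter = None        # cached expanded filter, only set in eval mode
        self.expansions = 0       # counts how many times the basis was expanded

    def expand_parameters(self):
        # filter[k] = sum_i weights[i] * basis[i][k]
        self.expansions += 1
        size = len(self.basis[0])
        return [sum(w * b[k] for w, b in zip(self.weights, self.basis))
                for k in range(size)]

    def train(self, mode=True):
        self.training = mode
        # training mode drops the cache; eval mode builds the filter once
        self.filter = None if mode else self.expand_parameters()
        return self

    def eval(self):
        return self.train(False)

    def current_filter(self):
        # rebuilt on every call while training, reused from the cache in eval mode
        return self.expand_parameters() if self.training else self.filter


conv = ToyExpandedConv(basis=[[1.0, 0.0], [0.0, 1.0]], weights=[2.0, 3.0])
conv.current_filter()
conv.current_filter()        # training mode: two expansions so far
conv.eval()                  # one more expansion, result stored in conv.filter
conv.current_filter()
conv.current_filter()        # eval mode: no further expansions
print(conv.expansions)       # 3
```

Because `filter` only holds a value in eval mode, the set of stored attributes differs between the two modes, which is the kind of mismatch the warning above refers to.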

export()[source]

Export this module to a normal PyTorch torch.nn.Conv2d module and set it to “eval” mode.

### R2ConvTransposed¶

class R2ConvTransposed(in_type, out_type, kernel_size, padding=0, output_padding=0, stride=1, dilation=1, groups=1, bias=True, basisexpansion='blocks', sigma=None, frequencies_cutoff=None, rings=None, maximum_offset=None, recompute=False, basis_filter=None, initialize=True)[source]

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

Transposed G-steerable planar convolution layer.

Warning

Transposed convolution can produce artifacts which can harm the overall equivariance of the model. We suggest using R2Upsampling combined with R2Conv to perform upsampling.

For additional information about the parameters and the methods of this class, see e2cnn.nn.R2Conv. The two modules are essentially the same, except for the type of convolution used.

Parameters
• in_type (FieldType) – the type of the input field

• out_type (FieldType) – the type of the output field

• kernel_size (int) – the size of the filter

• padding (int, optional) – implicit zero paddings on both sides of the input. Default: 0

• output_padding (int, optional) – additional size added to one side of each dimension in the output shape. Default: 0

• stride (int, optional) – the stride of the convolving kernel. Default: 1

• dilation (int, optional) – the spacing between kernel elements. Default: 1

• groups (int, optional) – number of blocked connections from input channels to output channels. Default: 1.

• bias (bool, optional) – Whether to add a bias to the output (only to fields which contain a trivial irrep) or not. Default: True

• initialize (bool, optional) – initialize the weights of the model. Default: True
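The spatial output size of a transposed convolution depends on all of the arguments above. The sketch below computes it with the formula documented for torch.nn.ConvTranspose2d, assuming (not verified here) that R2ConvTransposed forwards kernel_size, stride, padding, output_padding and dilation unchanged to PyTorch's transposed convolution. The helper name `conv_transpose_out_size` is hypothetical.

```python
def conv_transpose_out_size(size_in, kernel_size, stride=1, padding=0,
                            output_padding=0, dilation=1):
    """Spatial output size of a transposed convolution along one dimension,
    following the formula documented for torch.nn.ConvTranspose2d."""
    if not 0 <= output_padding < max(stride, dilation):
        raise ValueError("output_padding must be smaller than stride or dilation")
    return ((size_in - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)


print(conv_transpose_out_size(16, 3, stride=2, padding=1))                    # 31
print(conv_transpose_out_size(16, 3, stride=2, padding=1, output_padding=1))  # 32
```

With stride=2, several input sizes map to the same strided output, so output_padding=1 is what selects 32 instead of 31 and yields an exact 2x upsampling.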

export()[source]

Export this module to a normal PyTorch torch.nn.ConvTranspose2d module and set it to “eval” mode.

### R2Diffop¶

class R2Diffop(in_type, out_type, kernel_size=None, accuracy=None, padding=0, stride=1, dilation=1, padding_mode='zeros', groups=1, bias=True, basisexpansion='blocks', maximum_order=None, maximum_power=None, maximum_offset=None, recompute=False, angle_offset=None, basis_filter=None, initialize=True, cache=False, rbffd=False, radial_basis_function='ga', smoothing=None)[source]

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

G-steerable planar partial differential operator mapping between the input and output FieldType s specified by the parameters in_type and out_type. This operation is equivariant under the action of $$\R^2\rtimes G$$ where $$G$$ is the e2cnn.nn.FieldType.fibergroup of in_type and out_type.

Specifically, let $$\rho_\text{in}: G \to \GL{\R^{c_\text{in}}}$$ and $$\rho_\text{out}: G \to \GL{\R^{c_\text{out}}}$$ be the representations specified by the input and output field types. Then R2Diffop guarantees an equivariant mapping

$D [\mathcal{T}^\text{in}_{g,u} . f] = \mathcal{T}^\text{out}_{g,u} . [Df] \qquad\qquad \forall g \in G, u \in \R^2$

where the transformation of the input and output fields are given by

$\begin{split}[\mathcal{T}^\text{in}_{g,u} . f](x) &= \rho_\text{in}(g)f(g^{-1} (x - u)) \\ [\mathcal{T}^\text{out}_{g,u} . f](x) &= \rho_\text{out}(g)f(g^{-1} (x - u)) \\\end{split}$

The equivariance of G-steerable PDOs is guaranteed by restricting the space of PDOs to an equivariant subspace.

During training, in each forward pass the module expands the basis of G-steerable PDOs with learned weights before calling torch.nn.functional.conv2d(). When eval() is called, the filter is built with the current trained weights and stored for future reuse such that no overhead of expanding the PDO remains.

Warning

When train() is called, the attributes filter and expanded_bias are discarded to avoid situations of mismatch with the learnable expansion coefficients. See also e2cnn.nn.R2Diffop.train().

This behaviour can cause problems when storing the state_dict() of a model in one mode and later loading it into a model in a different mode, as the attributes of the class change. To avoid this issue, we recommend converting the model to eval mode before storing or loading the state dictionary.

The learnable expansion coefficients of this module can be initialized with the methods in e2cnn.nn.init. By default, the weights are initialized in the constructor using generalized_he_init().

Warning

This initialization procedure can be extremely slow for wide layers. In case initializing the model is not required (e.g. before loading the state dict of a pre-trained model) or another initialization method is preferred (e.g. deltaorthonormal_init()), the parameter initialize can be set to False to avoid unnecessary overhead.

A reasonable default is to only set the kernel_size and leave all other options on their defaults. However, you might get considerable performance improvements by setting smoothing to something other than None (kernel_size / 4 is a sane default, see below for details).

If you want to modify accuracy or maximum_order, you will need to take into account how they are related to kernel_size: it is possible to set any two of kernel_size, accuracy and maximum_order, in which case the third one will be determined automatically. Alternatively, you can set either kernel_size or maximum_order, in which case a sane default will be used for accuracy. The relation between the three is approximately $$\text{kernel size} \approx \text{accuracy} + \text{order}$$, though this formula is off by one in some cases. A larger maximum order will lead to more basis filters and thus more parameters. A larger accuracy (i.e. a larger kernel size at constant order) might lead to lower equivariance errors, though whether this actually happens may depend on your exact setup.
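One way to see where such an off-by-one can come from: for a one-dimensional central finite-difference stencil, the textbook number of points needed for a derivative of order $$d$$ with accuracy order $$p$$ ($$p$$ even) is $$2\lfloor (d+1)/2 \rfloor - 1 + p$$, which equals $$d + p$$ for odd $$d$$ and $$d + p - 1$$ for even $$d$$. This is only an illustration under that assumption; R2Diffop's own rule for choosing kernel_size may differ, and `central_stencil_size` is a hypothetical helper.

```python
def central_stencil_size(order, accuracy):
    """Points in a 1D central finite-difference stencil for the `order`-th
    derivative with truncation error O(h**accuracy); accuracy must be even."""
    if order < 1 or accuracy < 2 or accuracy % 2:
        raise ValueError("need order >= 1 and a positive even accuracy")
    return 2 * ((order + 1) // 2) - 1 + accuracy


for order in range(1, 5):
    # odd orders: exactly order + accuracy points; even orders: one fewer
    print(order, central_stencil_size(order, accuracy=2), order + 2)
# 1 3 3
# 2 3 4
# 3 5 5
# 4 5 6
```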

The parameters basisexpansion, maximum_power, and maximum_offset are optional parameters used to control how the basis for the PDOs is built, how it is sampled on the filter grid and how it is expanded to build the filter. We suggest keeping their default values.

Warning

The discretization of the differential operators relies on two external packages: sympy and rbf. If they are not available, an error is raised.

Parameters
• in_type (FieldType) – the type of the input field, specifying its transformation law

• out_type (FieldType) – the type of the output field, specifying its transformation law

• kernel_size (int, optional) – the size of the (square) filter. This can be chosen automatically, see above for details.

• accuracy (int, optional) – the desired asymptotic accuracy for the PDO discretization, affects the kernel_size. See above for details.

• padding (int, optional) – implicit zero paddings on both sides of the input. Default: 0

• padding_mode (str, optional) – zeros, reflect, replicate or circular. Default: zeros

• stride (int, optional) – the stride of the kernel. Default: 1

• dilation (int, optional) – the spacing between kernel elements. Default: 1

• groups (int, optional) – number of blocked connections from input channels to output channels; this enables depthwise convolution. When groups > 1, both the input and the output types must be divisible into groups identical groups of fields. Default: 1.

• bias (bool, optional) – Whether to add a bias to the output (only to fields which contain a trivial irrep) or not. Default: True

• basisexpansion (str, optional) – the basis expansion algorithm to use

• maximum_order (int, optional) – the largest derivative order to allow as part of the basis. Larger maximum orders require larger kernel sizes, see above for details.

• maximum_power (int, optional) – the maximum power of the Laplacian that will be used for constructing the basis. If this is not None, it places a restriction on the basis elements, in addition to the restriction given by maximum_order. We suggest leaving this setting at its default unless you have a good reason to change it.

• maximum_offset (int, optional) – number of additional (aliased) frequencies in the intertwiners for finite groups. By default (None), all additional frequencies allowed by the frequencies cut-off are used.

• recompute (bool, optional) – if True, recomputes a new basis for the equivariant PDOs. By default (False), it caches the basis built or reuses a cached one, if found.

• basis_filter (callable, optional) – function which takes as input a descriptor of a basis element (as a dictionary) and returns a boolean value: whether to preserve (True) or discard (False) the basis element. By default (None), no filtering is applied.

• initialize (bool, optional) – initialize the weights of the model. Default: True

• cache (bool or str, optional) –

Discretizing the PDOs can take a bit longer than for kernels, so we provide the option to cache PDOs on disk. Our suggestion is to keep the cache off (default) and only activate it if discretizing the PDOs is in fact a bottleneck for your setup (it often is not). Setting cache to True will load an existing cache before instantiating the layer and will write to the cache afterwards. You can also set cache to load or store to only do one of these.

All R2Diffop layers share the PDO cache in memory. If you have several R2Diffop layers inside your model, we therefore recommend leaving cache set to False and instead calling e2cnn.diffops.load_cache() before instantiating the model and e2cnn.diffops.store_cache() afterwards, to save the PDOs for the next run of the program. This avoids unnecessary reads from and writes to disk.

• rbffd (bool, optional) – if set to True, use RBF-FD discretization instead of finite differences (the default). We suggest leaving this to False unless you have a specific reason for wanting to use RBF-FD.

• radial_basis_function (str, optional) – which RBF to use (only relevant for RBF-FD). Can be any of the RBF abbreviations supported by the rbf package. The default is to use Gaussian RBFs because they always avoid singularity issues. However, other RBFs, such as polyharmonic splines, may work better where they are applicable.

• smoothing (float, optional) – if not None, discretization will be performed with derivatives of Gaussians as stencils. This is similar to smoothing with a Gaussian before applying the PDO, though there are slight technical differences. smoothing is the standard deviation (in pixels) of the Gaussian, meaning that larger values correspond to stronger smoothing. A reasonable value would be about kernel_size / 4 but you might want to experiment a bit with this parameter.
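To give an idea of what the smoothing option refers to, the sketch below samples the first (or zeroth or second) derivative of a one-dimensional Gaussian on a kernel_size-point grid, with the suggested σ = kernel_size / 4. It only illustrates derivative-of-Gaussian stencils; it is not the discretization routine used by R2Diffop, and the function name is hypothetical.

```python
import math


def gaussian_derivative_stencil(kernel_size, sigma, order=1):
    """Sample the `order`-th derivative (0, 1 or 2) of a normalized 1D
    Gaussian with standard deviation `sigma` on an odd-sized grid centred at 0."""
    if kernel_size % 2 == 0:
        raise ValueError("kernel_size must be odd")
    half = kernel_size // 2
    xs = list(range(-half, half + 1))
    g = [math.exp(-x * x / (2 * sigma ** 2)) / (math.sqrt(2 * math.pi) * sigma)
         for x in xs]
    if order == 0:
        return g
    if order == 1:
        # d/dx G(x) = -x / sigma^2 * G(x)
        return [-x / sigma ** 2 * gx for x, gx in zip(xs, g)]
    if order == 2:
        # d^2/dx^2 G(x) = (x^2 / sigma^4 - 1 / sigma^2) * G(x)
        return [(x * x / sigma ** 4 - 1 / sigma ** 2) * gx for x, gx in zip(xs, g)]
    raise ValueError("this sketch only implements orders 0, 1 and 2")


kernel_size = 5
stencil = gaussian_derivative_stencil(kernel_size, sigma=kernel_size / 4)
print([round(v, 4) for v in stencil])
```

The first-derivative stencil is odd-symmetric, so it sums to zero and annihilates constant inputs, as any first-derivative stencil should; a larger sigma spreads its weight over more pixels, which corresponds to stronger smoothing.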

Variables
• weights (torch.Tensor) – the learnable parameters which are used to expand the PDO

• filter (torch.Tensor) – the convolutional stencil obtained by expanding the parameters in weights

• bias (torch.Tensor) – the learnable parameters which are used to expand the bias, if bias=True

• expanded_bias (torch.Tensor) – the equivariant bias which is summed to the output, obtained by expanding the parameters in bias

property basisexpansion

Submodule which takes care of building the filter.

It uses the learned weights to expand the PDO in the G-steerable basis and returns the resulting filter with shape $$(c_\text{out}, c_\text{in}, s^2)$$, where $$s$$ is the kernel_size, i.e. the usual filter layout of conventional convolutional modules with the two spatial dimensions flattened.

expand_parameters()[source]

Expand the filter in terms of the e2cnn.nn.R2Diffop.weights and the expanded bias in terms of e2cnn.nn.R2Diffop.bias.

Returns

the expanded filter and bias

forward(input)[source]

Convolve the input with the expanded filter and bias.

Parameters

input (GeometricTensor) – input feature field transforming according to in_type

Returns

output feature field transforming according to out_type

train(mode=True)[source]

If mode=True, the method sets the module in training mode and discards the filter and expanded_bias attributes.

If mode=False, it sets the module in evaluation mode. Moreover, the method builds the filter and the bias using the current values of the trainable parameters and stores them in filter and expanded_bias such that they are not recomputed at each forward pass.

Warning

This behaviour can cause problems when storing the state_dict() of a model in one mode and later loading it into a model in a different mode, as the attributes of this class change. To avoid this issue, we recommend converting the model to eval mode before storing or loading the state dictionary.

Parameters

mode (bool, optional) – whether to set training mode (True) or evaluation mode (False). Default: True.

export()[source]

Export this module to a normal PyTorch torch.nn.Conv2d module and set it to “eval” mode.

### Basis Expansion Modules¶

#### Basis Expansion¶

class BasisExpansion[source]

Bases: abc.ABC, torch.nn.modules.module.Module

Abstract class defining the interface for the different basis expansion algorithms.

abstract forward(weights)[source]

Forward step of the module, which expands the basis and returns the resulting filter.

Parameters

weights (torch.Tensor) – the learnable weights used to linearly combine the basis elements

Returns

the resulting filter

abstract get_basis_names()[source]

Method that returns the list of identification names of the basis elements

Returns

list of names

abstract get_element_info(name)[source]

Method that returns the information associated to a basis element

Parameters

name (str or int) – identifier of the basis element or its index

Returns

dictionary containing the information

abstract get_basis_info()[source]

Method that returns an iterable over all basis elements’ attributes.

Returns

an iterable over all the basis elements’ attributes

abstract dimension()[source]

The dimensionality of the basis and, so, the number of weights needed to expand it.

Returns

the dimensionality of the basis
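The interface above can be mirrored with Python's abc module. In this framework-free sketch, the abstract class leaves out torch.nn.Module, and `ListBasisExpansion` with its info dictionaries (keys 'name' and 'idx') is a hypothetical concrete expansion. It only shows how forward(), dimension() and the info methods fit together, with the filter computed as a weighted sum of basis filters.

```python
from abc import ABC, abstractmethod


class ToyBasisExpansion(ABC):
    """Framework-free mirror of the abstract interface listed above
    (the documented class also subclasses torch.nn.Module)."""

    @abstractmethod
    def forward(self, weights): ...

    @abstractmethod
    def get_basis_names(self): ...

    @abstractmethod
    def get_element_info(self, name): ...

    @abstractmethod
    def get_basis_info(self): ...

    @abstractmethod
    def dimension(self): ...


class ListBasisExpansion(ToyBasisExpansion):
    """Hypothetical concrete expansion over a fixed list of named basis filters."""

    def __init__(self, named_filters):
        # named_filters: list of (name, flat filter given as a list of floats)
        self._names = [name for name, _ in named_filters]
        self._filters = [f for _, f in named_filters]

    def forward(self, weights):
        if len(weights) != self.dimension():
            raise ValueError("expected one weight per basis element")
        size = len(self._filters[0])
        # linear combination of the basis filters
        return [sum(w * f[k] for w, f in zip(weights, self._filters))
                for k in range(size)]

    def get_basis_names(self):
        return list(self._names)

    def get_element_info(self, name):
        # accept either the identifier or the index of the element
        idx = name if isinstance(name, int) else self._names.index(name)
        return {"name": self._names[idx], "idx": idx}

    def get_basis_info(self):
        return (self.get_element_info(i) for i in range(self.dimension()))

    def dimension(self):
        return len(self._filters)


basis = ListBasisExpansion([("dx", [-1.0, 0.0, 1.0]), ("avg", [1.0, 1.0, 1.0])])
print(basis.dimension())            # 2
print(basis.forward([0.5, 2.0]))    # [1.5, 2.0, 2.5]
```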

#### BlocksBasisExpansion¶

class BlocksBasisExpansion(in_type, out_type, basis_generator, points, basis_filter=None, recompute=False, **kwargs)[source]

Bases: e2cnn.nn.modules.r2_conv.basisexpansion.BasisExpansion

With this algorithm, the expansion is done separately on the intertwiners of each pair of input and output field representations.

Parameters
• in_type (FieldType) – the input field type

• out_type (FieldType) – the output field type

• basis_generator (callable) – method that generates the analytical filter basis

• points (ndarray) – points where the analytical basis should be sampled

• basis_filter (callable, optional) – filter for the basis elements. Should take a dictionary containing an element’s attributes and return whether to keep it or not.

• recompute (bool, optional) – whether to recompute new bases or reuse, if possible, already built tensors.

• **kwargs – keyword arguments to be passed to basis_generator

Variables

S (int) – number of points where the filters are sampled

forward(weights)[source]

Forward step of the module, which expands the basis and returns the resulting filter.

Parameters

weights (torch.Tensor) – the learnable weights used to linearly combine the basis filters

Returns

the resulting filter

#### SingleBlockBasisExpansion¶

class SingleBlockBasisExpansion(basis, points, basis_filter=None)[source]

Bases: e2cnn.nn.modules.r2_conv.basisexpansion.BasisExpansion

Basis expansion method for a single contiguous block, i.e. for kernels/PDOs whose input type and output type contain only fields of one type.

This class should be instantiated through the factory method block_basisexpansion() to enable caching.

Parameters
• basis (Basis) – analytical basis to sample

• points (ndarray) – points where the analytical basis should be sampled

• basis_filter (callable, optional) – filter for the basis elements. Should take a dictionary containing an element’s attributes and return whether to keep it or not.

#### block_basisexpansion¶

block_basisexpansion(basis, points, basis_filter=None, recompute=False)[source]

Return an instance of SingleBlockBasisExpansion.

This function supports caching through the argument recompute.

Parameters
• basis (Basis) – basis defining the space of kernels

• points (ndarray) – points where the analytical basis should be sampled

• basis_filter (callable, optional) – filter for the basis elements. Should take a dictionary containing an element’s attributes and return whether to keep it or not.

• recompute (bool, optional) – whether to recompute new bases (True) or reuse, if possible, already built tensors (False, default).
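The caching controlled by recompute follows a familiar memoization pattern. The sketch below is a hypothetical, simplified version: objects are keyed on a hashable description of the request (in practice the basis, the sampling points and the filter would have to be turned into such a key), and in this sketch recompute=True builds a fresh object without touching the cache. All names are illustrative.

```python
_cache = {}


def cached_build(key, build, recompute=False):
    """Return build()'s result for `key`, reusing a cached object unless
    `recompute` is True (then a fresh object is built and not cached)."""
    if recompute:
        return build()
    if key not in _cache:
        _cache[key] = build()
    return _cache[key]


builds = []


def make_expansion():
    builds.append(1)       # record every actual construction
    return object()        # placeholder for an expensive expansion object


a = cached_build(("basis", 5), make_expansion)
b = cached_build(("basis", 5), make_expansion)
c = cached_build(("basis", 5), make_expansion, recompute=True)
print(a is b, a is c, len(builds))   # True False 2
```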

## Non Linearities¶

### ConcatenatedNonLinearity¶

class ConcatenatedNonLinearity(in_type, function='c_relu')[source]

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

Concatenated non-linearities. For each input channel, the module applies the specified activation function both to its value and its opposite (the value multiplied by -1). The number of channels is, therefore, doubled.

Notice that not all the representations support this kind of non-linearity. Indeed, only representations with the same pattern of permutation matrices and containing only values in $$\{0, 1, -1\}$$ support it.

Parameters
• in_type (FieldType) – the input field type

• function (str) – the identifier of the non-linearity. It is used to specify which function to apply. By default ('c_relu'), ReLU is used.
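The concatenated ReLU on a single field can be sketched as follows in plain Python. The [positive part, negative part] layout of the doubled channels is an assumption of this sketch, not necessarily the channel order used by ConcatenatedNonLinearity. The second check shows why signed representations are admissible: flipping the sign of the input only swaps the two halves of the output, i.e. it acts on the output as a permutation.

```python
def relu(x):
    return x if x > 0 else 0.0


def concatenated_relu(field):
    """Apply ReLU to every channel and to its opposite; the channels double."""
    positive = [relu(x) for x in field]
    negative = [relu(-x) for x in field]
    return positive + negative


x = [1.5, -2.0, 0.0]
print(concatenated_relu(x))                    # [1.5, 0.0, 0.0, 0.0, 2.0, 0.0]

# a sign flip of the input swaps the two halves of the output
flipped = concatenated_relu([-v for v in x])
print(flipped == concatenated_relu(x)[3:] + concatenated_relu(x)[:3])   # True
```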

### ELU¶

class ELU(in_type, alpha=1.0, inplace=False)[source]

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

Module that applies a pointwise ELU to every channel independently. The input representation is preserved by this operation and, therefore, it equals the output representation.

Only representations supporting pointwise non-linearities are accepted as input field type.

Parameters
• in_type (FieldType) – the input field type

• alpha (float) – the $$\alpha$$ value for the ELU formulation. Default: 1.0

• inplace (bool, optional) – can optionally do the operation in-place. Default: False

forward(input)[source]

Applies ELU function on the input fields

Parameters

input (GeometricTensor) – the input feature map

Returns

the resulting feature map after elu has been applied

export()[source]

Export this module to a normal PyTorch torch.nn.ELU module and set it to “eval” mode.

### GatedNonLinearity1¶

class GatedNonLinearity1(in_type, gates=None, drop_gates=True, **kwargs)[source]

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

Gated non-linearities. This module applies a bias and a sigmoid function to the gate fields and then multiplies each gated field by one of the gates.

The input representation of the gated fields is preserved by this operation while the gate fields are discarded.

The gates and the gated fields are provided in one single input tensor and, therefore, in_repr should be the representation of the fiber containing both gates and gated fields. Moreover, the parameter gates needs to be set to a list as long as the total number of fields, containing in position i the string "gate" if the i-th field is a gate or the string "gated" if the i-th field is a gated field. No other strings are allowed. By default (gates = None), the first half of the fields is assumed to contain the gates (and, so, these fields have to be trivial fields) while the second half is assumed to contain the gated fields.

In any case, the number of gates and the number of gated fields have to match (therefore, the number of fields has to be an even number).

Parameters
• in_type (FieldType) – the input field type

• gates (list, optional) – list of strings specifying which field in input is a gate and which is a gated field

• drop_gates (bool, optional) – if True (default), drop the trivial fields after using them to compute the gates. If False, the gates are stacked with the gated fields in the output
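A plain-Python sketch of the gating rule described above, with each field given as a list of channel values and each gate a one-channel (trivial) field. The bias sign convention (sigmoid(gate + bias)) and the helper names are assumptions, and only the behaviour with drop_gates=True is modelled. Since each gate is an invariant scalar, rotating a gated field before or after gating gives the same result.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def gated_nonlinearity(fields, gates=None, biases=None):
    """fields: list of fields, each a list of channel values.
    gates: one 'gate'/'gated' label per field; by default the first half of
    the fields are the gates. Gates must be one-channel (trivial) fields."""
    if gates is None:
        if len(fields) % 2:
            raise ValueError("the number of fields must be even")
        half = len(fields) // 2
        gates = ["gate"] * half + ["gated"] * half
    if len(gates) != len(fields) or set(gates) - {"gate", "gated"}:
        raise ValueError("need one 'gate' or 'gated' label per field")
    gate_fields = [f for f, g in zip(fields, gates) if g == "gate"]
    gated_fields = [f for f, g in zip(fields, gates) if g == "gated"]
    if len(gate_fields) != len(gated_fields):
        raise ValueError("the numbers of gates and gated fields must match")
    if any(len(g) != 1 for g in gate_fields):
        raise ValueError("gates must be one-channel (trivial) fields")
    if biases is None:
        biases = [0.0] * len(gate_fields)   # hypothetical learnable biases
    # the i-th gate scales every channel of the i-th gated field
    return [[sigmoid(g[0] + b) * x for x in f]
            for g, b, f in zip(gate_fields, biases, gated_fields)]


def rotate90(v):
    return [-v[1], v[0]]


print(gated_nonlinearity([[0.0], [3.0, 4.0]]))   # [[1.5, 2.0]]
```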

forward(input)[source]

Apply the gated non-linearity to the input feature map.

Parameters

input (GeometricTensor) – the input feature map

Returns

the resulting feature map

### GatedNonLinearity2¶

class GatedNonLinearity2(in_type, **kwargs)[source]

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

Gated non-linearities. This module applies a bias and a sigmoid function to the gate fields and then multiplies each gated field by one of the gates.

The input representation of the gated fields is preserved by this operation while the gate fields are discarded.

The gates and the gated fields are provided in two different tensors: in_repr is a tuple containing two representations: the representation of the tensor containing only the gates (which have to be trivial fields) and the representation of the tensor containing only the gated fields. Therefore, two tensors need to be provided as input to the forward method: the first contains the gates and the second the gated fields. Finally, notice that the number of gates and the number of gated fields have to match, i.e. the two representations need to have the same number of fields.

Todo

This module has 2 input tensors and, so, two input field types. EquivariantModule only supports one input though. Fix this.

Parameters

in_type (Tuple) – a pair containing, in order, the field type of the gates and the field type of the gated fields

forward(gates, input)[source]

Apply the gated non-linearity to the input feature map.

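Parameters
• gates (GeometricTensor) – feature map containing the gates

• input (GeometricTensor) – feature map containing the gated fields

Returns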

the resulting feature map

### InducedGatedNonLinearity1¶

class InducedGatedNonLinearity1(in_type, gates=None, drop_gates=True, **kwargs)[source]

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

Induced Gated non-linearities.

Todo

complete documentation!

Note

Make sure that all induced gates and gated fields are induced from the same subgroup.

Parameters
• in_type (FieldType) – the input field type

• gates (list, optional) – list of strings specifying which field in input is a gate and which is a gated field

• drop_gates (bool, optional) – if True (default), drop the trivial fields after using them to compute the gates. If False, the gates are stacked with the gated fields in the output

forward(input)[source]

Apply the gated non-linearity to the input feature map.

Parameters

input (GeometricTensor) – the input feature map

Returns

the resulting feature map

### InducedNormNonLinearity¶

class InducedNormNonLinearity(in_type, function='n_relu', bias=True)[source]

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

Induced Norm non-linearities. This module requires the input fields to be associated with an induced representation from a representation which supports ‘norm’ non-linearities. This module applies a bias and an activation function over the norm of each sub-field of an induced field. The bias is shared among all sub-fields of the same induced field.

The input representation of the fields is preserved by this operation.

Parameters
• in_type (FieldType) – the input field type

• function (str, optional) – the identifier of the non-linearity. It is used to specify which function to apply. By default ('n_relu'), ReLU is used.

• bias (bool, optional) – add bias to norm of fields before computing the non-linearity. Default: True

forward(input)[source]

Apply norm non-linearities to the input feature map

Parameters

input (GeometricTensor) – the input feature map

Returns

the resulting feature map

### NormNonLinearity¶

class NormNonLinearity(in_type, function='n_relu', bias=True)[source]

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

Norm non-linearities. This module applies a bias and an activation function over the norm of each field.

The input representation of the fields is preserved by this operation.

Note

If the ‘squash’ non-linearity is chosen as function, no bias is allowed.

Parameters
• in_type (FieldType) – the input field type

• function (str, optional) – the identifier of the non-linearity. It is used to specify which function to apply. By default ('n_relu'), ReLU is used.

• bias (bool, optional) – add bias to norm of fields before computing the non-linearity. Default: True
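The 'n_relu' option can be sketched as rescaling each field by relu(‖f‖ + b) / ‖f‖, which changes the norm but keeps the direction. The '+ b' sign convention, the `eps` guard and the helper names are assumptions of this sketch. Because the output is a scalar multiple of the input and the scalar depends only on the norm, the operation commutes with any orthogonal representation, as the rotation check shows.

```python
import math


def norm_relu(field, bias=0.0, eps=1e-12):
    """Rescale `field` so that its norm becomes relu(norm + bias), keeping its
    direction. The '+ bias' sign convention is an assumption of this sketch."""
    norm = math.sqrt(sum(x * x for x in field))
    scale = max(norm + bias, 0.0) / max(norm, eps)
    return [scale * x for x in field]


def rotate(v, theta):
    c, s = math.cos(theta), math.sin(theta)
    return [c * v[0] - s * v[1], s * v[0] + c * v[1]]


v = [3.0, 4.0]
print(norm_relu(v, bias=-2.0))   # norm 5 becomes 3: approximately [1.8, 2.4]
lhs = norm_relu(rotate(v, 0.7), bias=-2.0)
rhs = rotate(norm_relu(v, bias=-2.0), 0.7)
print(all(math.isclose(a, b) for a, b in zip(lhs, rhs)))   # True
```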

forward(input)[source]

Apply norm non-linearities to the input feature map

Parameters

input (GeometricTensor) – the input feature map

Returns

the resulting feature map

### PointwiseNonLinearity¶

class PointwiseNonLinearity(in_type, function='p_relu')[source]

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

Pointwise non-linearities. The same scalar function is applied to every channel independently. The input representation is preserved by this operation and, therefore, it equals the output representation.

Only representations supporting pointwise non-linearities are accepted as input field type.

Parameters
• in_type (FieldType) – the input field type

• function (str) – the identifier of the non-linearity. It is used to specify which function to apply. By default ('p_relu'), ReLU is used.
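The restriction to representations supporting pointwise non-linearities can be checked numerically. In this plain-Python sketch, a cyclic shift of the channels (the regular representation of $$C_N$$) commutes with an elementwise ReLU, while a 90-degree rotation of a 2-dimensional vector field does not. The helpers are hypothetical illustrations, not library functions.

```python
def relu(v):
    return [x if x > 0 else 0.0 for x in v]


def cyclic_shift(v, k=1):
    """Regular representation of C_N: move every channel k positions forward."""
    return v[-k:] + v[:-k]


def rotate90(v):
    """Standard 2D representation of a rotation by 90 degrees."""
    return [-v[1], v[0]]


v = [0.5, -1.0, 2.0, -0.3]
print(all(relu(cyclic_shift(v, k)) == cyclic_shift(relu(v), k) for k in range(4)))  # True

w = [1.0, -2.0]
print(relu(rotate90(w)), rotate90(relu(w)))   # the two results differ
```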

forward(input)[source]

Applies the pointwise activation function on the input fields

Parameters

input (GeometricTensor) – the input feature map

Returns

the resulting feature map after the non-linearities have been applied

### ReLU¶

class ReLU(in_type, inplace=False)[source]

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

Module that applies a pointwise ReLU to every channel independently. The input representation is preserved by this operation and, therefore, it equals the output representation.

Only representations supporting pointwise non-linearities are accepted as input field type.

Parameters
• in_type (FieldType) – the input field type

• inplace (bool, optional) – can optionally do the operation in-place. Default: False

forward(input)[source]

Applies ReLU function on the input fields

Parameters

input (GeometricTensor) – the input feature map

Returns

the resulting feature map after relu has been applied

export()[source]

Export this module to a normal PyTorch torch.nn.ReLU module and set it to “eval” mode.

### VectorFieldNonLinearity¶

class VectorFieldNonLinearity(in_type, **kwargs)[source]

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

VectorField non-linearities. This non-linearity only supports the regular representation of cyclic group $$C_N$$, i.e. the group of $$N$$ discrete rotations. For each input field, the output one is built by taking the rotation associated with the highest activation; then, a 2-dimensional vector with an angle with respect to the x-axis equal to that rotation and a length equal to its activation is set in the output field.

Parameters

in_type (FieldType) – the input field type
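A plain-Python sketch of this construction for a single $$C_N$$ regular field, assuming channel i corresponds to the rotation by $$2\pi i / N$$ and that ties are broken by the first maximum. Cyclically shifting the input channels (the action of the generator of $$C_N$$ on a regular field) rotates the output vector by $$2\pi / N$$. The function name is hypothetical.

```python
import math


def vector_field_nonlinearity(field):
    """field: N activations of a C_N regular field, channel i <-> rotation 2*pi*i/N.
    Returns a 2D vector pointing at the strongest rotation, scaled by its activation."""
    n = len(field)
    i = max(range(n), key=lambda k: field[k])   # index of the highest activation
    angle = 2 * math.pi * i / n
    length = field[i]
    return [length * math.cos(angle), length * math.sin(angle)]


field = [0.1, 0.9, 0.3, 0.2]              # C_4 regular field
rotated_field = field[-1:] + field[:-1]   # generator of C_4 acting on the field
print(vector_field_nonlinearity(field))          # approximately [0.0, 0.9]
print(vector_field_nonlinearity(rotated_field))  # approximately [-0.9, 0.0]
```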

forward(input)[source]

Apply the VectorField non-linearity to the input feature map.

Parameters

input (GeometricTensor) – the input feature map

Returns

the resulting feature map

## Invariant Maps¶

### GroupPooling¶

class GroupPooling(in_type, **kwargs)[source]

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

Module that implements group pooling. This module only supports permutation representations such as regular representation, quotient representation or trivial representation (though, in the last case, this module acts as identity). For each input field, an output field is built by taking the maximum activation within that field; as a result, the output field transforms according to a trivial representation.

Parameters

in_type (FieldType) – the input field type
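Group pooling on a flat vector of channels can be sketched as a maximum over each field's slice. The function name and the flat-list layout with explicit field sizes are assumptions for illustration. Any permutation of the channels within a field leaves the pooled values unchanged, which is why the output transforms according to the trivial representation.

```python
def group_pooling(channels, field_sizes):
    """Max over the channels of each field. `channels` is a flat list and
    `field_sizes` gives the number of channels of each consecutive field."""
    if sum(field_sizes) != len(channels):
        raise ValueError("field sizes do not match the number of channels")
    out, start = [], 0
    for size in field_sizes:
        out.append(max(channels[start:start + size]))
        start += size
    return out


x = [1.0, 5.0, 2.0, 0.0, -1.0, -3.0, -2.0, -4.0]   # two regular C_4 fields
print(group_pooling(x, [4, 4]))                      # [5.0, -1.0]
```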

forward(input)[source]

Apply Group Pooling to the input feature map.

Parameters

input (GeometricTensor) – the input feature map

Returns

the resulting feature map

export()[source]

Export this module to the pure PyTorch module MaxPoolChannels and set it to “eval” mode.

Warning

Currently, this method only supports group pooling with feature types containing only representations of the same size.

Note

Because there is no native PyTorch module performing this operation, it is not possible to export this module without any dependency on this library. Indeed, the resulting module depends on this library through the class MaxPoolChannels. If PyTorch introduces a similar module in a future release, we will update this method to remove this dependency.

Nevertheless, the MaxPoolChannels module is slightly lighter than GroupPooling as it does not perform any automatic type checking and does not wrap each tensor in a GeometricTensor. Furthermore, the MaxPoolChannels class is very simple, and one can easily reimplement it to remove any dependency on this library after training the model.

### NormPool¶

class NormPool(in_type, **kwargs)[source]

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

Module that implements Norm Pooling. For each input field, an output one is built by taking the norm of that field; as a result, the output field transforms according to a trivial representation.

Parameters

in_type (FieldType) – the input field type

forward(input)[source]

Apply the Norm Pooling to the input feature map.

Parameters

input (GeometricTensor) – the input feature map

Returns

the resulting feature map

### InducedNormPool¶

class InducedNormPool(in_type, **kwargs)[source]

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

Module that implements Induced Norm Pooling. This module requires the input fields to be associated with an induced representation from a representation which supports ‘norm’ non-linearities.

For each input field, an output field is built by taking the maximum norm of all its sub-fields; as a result, the output field transforms according to a trivial representation.

Parameters

in_type (FieldType) – the input field type
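The two norm-based pooling rules (NormPool and InducedNormPool) can be sketched in plain Python. For InducedNormPool, this sketch assumes that the sub-fields of an induced field are stored contiguously and have equal size; the helper names are hypothetical. Permuting the sub-fields or applying an orthogonal transformation to one of them does not change the pooled value.

```python
import math


def norm(v):
    return math.sqrt(sum(x * x for x in v))


def norm_pool(field):
    """NormPool: one invariant scalar per field, its Euclidean norm."""
    return norm(field)


def induced_norm_pool(field, subfield_size):
    """InducedNormPool: largest norm among the consecutive sub-fields."""
    if len(field) % subfield_size:
        raise ValueError("field size must be a multiple of subfield_size")
    return max(norm(field[i:i + subfield_size])
               for i in range(0, len(field), subfield_size))


field = [3.0, 4.0, 0.0, 1.0, 1.0, 1.0]   # three 2-dimensional sub-fields
print(induced_norm_pool(field, 2))      # 5.0
```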

forward(input)[source]

Apply the Norm Pooling to the input feature map.

Parameters

input (GeometricTensor) – the input feature map

Returns

the resulting feature map

## Pooling¶

### NormMaxPool¶

class NormMaxPool(in_type, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False)[source]

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

Max-pooling based on the fields’ norms. In a given window of shape kernel_size, for each group of channels belonging to the same field, the field with the highest norm (i.e. the length of its vector of channels) is preserved.

Except in_type, the other parameters correspond to the ones of torch.nn.MaxPool2d.

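Parameters
• in_type (FieldType) – the input field type

• kernel_size (int or tuple) – the size of the window to take a max over

• stride (int or tuple, optional) – the stride of the window. Default: kernel_size

• padding (int or tuple, optional) – implicit padding to be added on both sides. Default: 0

• dilation (int or tuple, optional) – a parameter that controls the stride of elements in the window. Default: 1

• ceil_mode (bool, optional) – when True, will use ceil instead of floor to compute the output shape. Default: False

forward(input)[source]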

Run the norm-based max-pooling on the input tensor

Parameters

input (GeometricTensor) – the input feature map

Returns

the resulting feature map
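The selection rule can be sketched in pure Python for a single field, assuming stride equal to kernel_size and no padding or dilation (simplifications of this sketch; `norm_max_pool` is a hypothetical helper). Note that the whole vector with the largest norm is returned, rather than a channel-wise maximum.

```python
import math

def norm_max_pool(grid, kernel_size):
    """Norm-based max-pooling of one field, with stride = kernel_size and no padding.

    grid: H x W nested list; each entry is the vector of the field at that pixel.
    Returns an (H // k) x (W // k) grid holding, for every window, the field
    vector with the largest Euclidean norm.
    """
    k = kernel_size
    h, w = len(grid), len(grid[0])
    out = []
    for i in range(0, h - k + 1, k):
        row = []
        for j in range(0, w - k + 1, k):
            window = [grid[i + di][j + dj] for di in range(k) for dj in range(k)]
            # Keep the whole vector with the largest norm.
            row.append(max(window, key=lambda v: math.sqrt(sum(x * x for x in v))))
        out.append(row)
    return out

grid = [
    [[1.0, 0.0], [0.0, -3.0]],
    [[2.0, 2.0], [0.5, 0.5]],
]
print(norm_max_pool(grid, 2))  # [[[0.0, -3.0]]]
```

A channel-wise maximum over the same window would return [2.0, 2.0], a vector that does not appear in the window and does not transform like the input field.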

### PointwiseMaxPool¶

class PointwiseMaxPool(in_type, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False)[source]

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

Channel-wise max-pooling: each channel is treated independently. This module works exactly as torch.nn.MaxPool2d, wrapping it in the EquivariantModule interface.

Notice that not all representations support this kind of pooling. In general, only representations which support pointwise non-linearities do.

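Parameters
• in_type (FieldType) – the input field type

• kernel_size – the size of the window to take a max over

• stride – the stride of the window. Default value is kernel_size

• padding – implicit padding to be added on both sides

• dilation – a parameter that controls the stride of elements in the window

• ceil_mode – when True, will use ceil instead of floor to compute the output shape

forward(input)[source]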
Parameters

input (GeometricTensor) – the input feature map

Returns

the resulting feature map

export()[source]

Export this module to a normal PyTorch torch.nn.MaxPool2d module and set to “eval” mode.
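Why pointwise pooling needs permutation representations can be checked with a toy example: if a group element only permutes the channels of a field (as the regular representation does), pooling each channel independently commutes with that permutation. The sketch below uses 1D, non-overlapping windows and models only the action on the channels, not the spatial transformation; all names are hypothetical.

```python
def max_pool_1d(channel, k):
    """Non-overlapping 1D max-pooling of one channel (stride = k)."""
    return [max(channel[i:i + k]) for i in range(0, len(channel) - k + 1, k)]

def pointwise_max_pool(channels, k):
    """Pool every channel independently of the others."""
    return [max_pool_1d(c, k) for c in channels]

def permute(channels, perm):
    """Reorder channels so that output channel i is input channel perm[i]."""
    return [channels[p] for p in perm]

# Three channels of a hypothetical regular field of the cyclic group C3.
x = [
    [1, 5, 2, 0],
    [4, 3, 7, 8],
    [0, 9, 6, 1],
]
shift = [1, 2, 0]  # cyclic permutation of the channels by a generator of C3

print(pointwise_max_pool(x, 2))  # [[5, 2], [4, 8], [9, 6]]
# Pooling then permuting equals permuting then pooling.
assert pointwise_max_pool(permute(x, shift), 2) == permute(pointwise_max_pool(x, 2), shift)
```

A representation that mixes channels linearly, such as a rotation matrix, does not commute with a channel-wise maximum in general.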

### PointwiseMaxPoolAntialiased¶

class PointwiseMaxPoolAntialiased(in_type, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, sigma=0.6)[source]

Bases: e2cnn.nn.modules.pooling.pointwise_max.PointwiseMaxPool

Anti-aliased version of channel-wise max-pooling (each channel is treated independently).

The max over a neighborhood is computed pointwise, without downsampling. Then, the feature map is convolved with a Gaussian blurring filter before being downsampled.

Notice that not all representations support this kind of pooling. In general, only representations which support pointwise non-linearities do.

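Parameters
• in_type (FieldType) – the input field type

• kernel_size – the size of the window to take a max over

• stride – the stride of the window. Default value is kernel_size

• padding – implicit padding to be added on both sides

• dilation – a parameter that controls the stride of elements in the window

• ceil_mode – when True, will use ceil instead of floor to compute the output shape

• sigma (float) – standard deviation of the Gaussian blur filter. Default: 0.6

forward(input)[source]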
Parameters

input (GeometricTensor) – the input feature map

Returns

the resulting feature map

export()[source]

Export this module to a normal PyTorch torch.nn.MaxPool2d module and set to “eval” mode.
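A 1D sketch of the three steps described for this module: a stride-1 max over each window, a Gaussian blur, and subsampling. The filter radius of about 3 standard deviations and the zero padding used for the blur are assumptions of this sketch, and the helper names are hypothetical; e2cnn applies the same idea to 2D feature maps.

```python
import math

def gaussian_taps(sigma, radius):
    """Normalized 1D Gaussian filter with 2 * radius + 1 taps."""
    w = [math.exp(-(x * x) / (2.0 * sigma * sigma)) for x in range(-radius, radius + 1)]
    total = sum(w)
    return [v / total for v in w]

def antialiased_max_pool_1d(signal, kernel_size, stride, sigma=0.6):
    # 1) max over every window of size kernel_size with stride 1 (no downsampling)
    dense = [max(signal[i:i + kernel_size]) for i in range(len(signal) - kernel_size + 1)]
    # 2) Gaussian blur; the radius of about 3 sigma and zero padding are assumptions
    radius = max(1, math.ceil(3 * sigma))
    taps = gaussian_taps(sigma, radius)
    padded = [0.0] * radius + dense + [0.0] * radius
    blurred = [sum(t * padded[i + j] for j, t in enumerate(taps)) for i in range(len(dense))]
    # 3) downsample only after blurring
    return blurred[::stride]

out = antialiased_max_pool_1d([0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0], kernel_size=2, stride=2)
```

Blurring before subsampling spreads each local maximum over neighboring positions, so small shifts of the input change the downsampled output less than plain strided max-pooling would.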

### PointwiseAvgPool¶

class PointwiseAvgPool(in_type, kernel_size, stride=None, padding=0, ceil_mode=False)[source]

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

Channel-wise average-pooling: each channel is treated independently. This module works exactly as torch.nn.AvgPool2d, wrapping it in the EquivariantModule interface.

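Parameters
• in_type (FieldType) – the input field type

• kernel_size – the size of the window

• stride – the stride of the window. Default value is kernel_size

• padding – implicit zero padding to be added on both sides

• ceil_mode – when True, will use ceil instead of floor to compute the output shape

forward(input)[source]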
Parameters

input (GeometricTensor) – the input feature map

Returns

the resulting feature map

### PointwiseAvgPoolAntialiased¶

class PointwiseAvgPoolAntialiased(in_type, sigma, stride, padding=None)[source]

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

Antialiased channel-wise average-pooling: each channel is treated independently. It performs strided convolution with a Gaussian blur filter.

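The size of the filter is computed as 3 standard deviations of the Gaussian curve. By default, padding is added such that the input size is preserved if stride is 1.

Parameters
• in_type (FieldType) – the input field type

• sigma (float) – standard deviation of the Gaussian blur filter

• stride (int) – the stride of the convolution

• padding (int, optional) – implicit zero padding to be added on both sides. Default: None

forward(input)[source]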
Parameters

input (GeometricTensor) – the input feature map

Returns

the resulting feature map
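The strided Gaussian convolution can be sketched in 1D as follows. Reading "3 standard deviations" as the filter radius, and using zero padding, are assumptions of this sketch; `antialiased_avg_pool_1d` is a hypothetical helper. With the default padding, a stride of 1 keeps the signal length unchanged.

```python
import math

def antialiased_avg_pool_1d(signal, sigma, stride, padding=None):
    """Strided convolution of a 1D signal with a normalized Gaussian filter."""
    # Assumption: the filter extends 3 standard deviations on each side of its center.
    radius = max(1, math.ceil(3 * sigma))
    taps = [math.exp(-(x * x) / (2.0 * sigma * sigma)) for x in range(-radius, radius + 1)]
    total = sum(taps)
    taps = [t / total for t in taps]
    if padding is None:
        padding = radius  # with stride 1, the output length equals the input length
    padded = [0.0] * padding + list(signal) + [0.0] * padding
    n_out = (len(padded) - len(taps)) // stride + 1
    return [sum(t * padded[i * stride + j] for j, t in enumerate(taps)) for i in range(n_out)]

x = [1.0] * 8
same_size = antialiased_avg_pool_1d(x, sigma=1.0, stride=1)  # 8 values
halved = antialiased_avg_pool_1d(x, sigma=1.0, stride=2)     # 4 values
```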

### PointwiseAdaptiveAvgPool¶

class PointwiseAdaptiveAvgPool(in_type, output_size)[source]

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

Adaptive channel-wise average-pooling: each channel is treated independently. This module works exactly as torch.nn.AdaptiveAvgPool2d, wrapping it in the EquivariantModule interface.

Notice that not all representations support this kind of pooling. In general, only representations which support pointwise non-linearities do.

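Parameters
• in_type (FieldType) – the input field type

• output_size – the target output size of the form H x W

forward(input)[source]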
Parameters

input (GeometricTensor) – the input feature map

Returns

the resulting feature map

export()[source]

Export this module to a normal PyTorch torch.nn.AdaptiveAvgPool2d module and set to “eval” mode.
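Adaptive pooling derives its windows from the input and output sizes. The sketch below follows the rule used, to our understanding, by PyTorch's adaptive pooling layers (`start = floor(i * in / out)`, `end = ceil((i + 1) * in / out)`); windows may overlap and may differ in size. The helper names are hypothetical, and the max-pooling variant uses the same windows with a maximum in place of the average.

```python
def adaptive_bins(in_size, out_size):
    """(start, end) indices of each adaptive pooling window, using integer arithmetic."""
    return [
        ((i * in_size) // out_size, -((-(i + 1) * in_size) // out_size))  # floor, ceil
        for i in range(out_size)
    ]

def adaptive_avg_pool_1d(channel, out_size):
    """Average of each adaptive window of a single channel."""
    return [sum(channel[s:e]) / (e - s) for s, e in adaptive_bins(len(channel), out_size)]

print(adaptive_bins(5, 3))                       # [(0, 2), (1, 4), (3, 5)]
print(adaptive_avg_pool_1d([1, 2, 3, 4, 5], 3))  # [1.5, 3.0, 4.5]
```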

### PointwiseAdaptiveMaxPool¶

class PointwiseAdaptiveMaxPool(in_type, output_size)[source]

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

Module that implements adaptive channel-wise max-pooling: each channel is treated independently. This module works exactly as torch.nn.AdaptiveMaxPool2d, wrapping it in the EquivariantModule interface.

Notice that not all representations support this kind of pooling. In general, only representations which support pointwise non-linearities do.

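Parameters
• in_type (FieldType) – the input field type

• output_size – the target output size of the form H x W

forward(input)[source]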
Parameters

input (GeometricTensor) – the input feature map

Returns

the resulting feature map

export()[source]

Export this module to a normal PyTorch torch.nn.AdaptiveMaxPool2d module and set to “eval” mode.

## Normalization¶

### InnerBatchNorm¶

class InnerBatchNorm(in_type, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)[source]

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

Batch normalization for representations with permutation matrices.

Statistics are computed both over the batch and the spatial dimensions and over the channels within the same field (which are permuted by the representation).

Only representations supporting pointwise non-linearities are accepted as input field type.

Parameters
• in_type (FieldType) – the input field type

• eps (float, optional) – a value added to the denominator for numerical stability. Default: 1e-5

• momentum (float, optional) – the value used for the running_mean and running_var computation. Can be set to None for cumulative moving average (i.e. simple average). Default: 0.1

• affine (bool, optional) – if True, this module has learnable affine parameters. Default: True

• track_running_stats (bool, optional) – when set to True, the module tracks the running mean and variance; when set to False, it does not track such statistics but uses batch statistics in both training and eval modes. Default: True

forward(input)[source]
Parameters

input (GeometricTensor) – the input feature map

Returns

the resulting feature map

export()[source]

Export this module to a normal PyTorch torch.nn.BatchNorm2d module and set to “eval” mode.
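The pooling of statistics can be sketched as follows, assuming fields of equal size stored as consecutive channels (`inner_batch_norm_stats` is a hypothetical helper, and the variance is the biased estimate). Sharing one mean and one variance among the channels of a field is what makes the normalization commute with the permutations applied by the representation: swapping two channels of the same field leaves the statistics unchanged.

```python
def inner_batch_norm_stats(batch, field_size):
    """Per-field mean and biased variance, pooled over batch, space and field channels.

    batch: list of samples; each sample is a list of channels; each channel is a
    flat list of spatial values. Fields are consecutive groups of `field_size` channels.
    """
    n_channels = len(batch[0])
    assert n_channels % field_size == 0
    stats = []
    for start in range(0, n_channels, field_size):
        values = [v for sample in batch
                  for c in range(start, start + field_size)
                  for v in sample[c]]
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)
        stats.append((mean, var))
    return stats

# One sample, two fields of two channels each, two pixels per channel.
batch = [[[1.0, 3.0], [5.0, 7.0], [0.0, 0.0], [2.0, 2.0]]]
print(inner_batch_norm_stats(batch, 2))  # [(4.0, 5.0), (1.0, 1.0)]

# Permuting the channels inside the first field does not change its statistics.
swapped = [[s[1], s[0], s[2], s[3]] for s in batch]
```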

### NormBatchNorm¶

class NormBatchNorm(in_type, eps=1e-05, momentum=0.1, affine=True)[source]

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

Batch normalization for isometric (i.e. norm-preserving) non-trivial representations.

The module assumes the mean of the vectors is always zero so no running mean is computed and no bias is added. This is guaranteed as long as the representations do not include a trivial representation.

Indeed, if $\rho$ does not include a trivial representation, the following holds:

$$\forall \mathbf{v} \in \mathbb{R}^n, \quad \frac{1}{|G|} \sum_{g \in G} \rho(g) \mathbf{v} = \mathbf{0}$$

Hence, only the standard deviation is normalized.
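The zero-mean identity can be checked numerically. In this sketch (helper names are hypothetical), the cyclic group C4 acts on R^2 by rotations of 90 degrees, a representation without a trivial component, so the group average of any vector vanishes. By contrast, the regular representation of C2, which swaps two entries and contains the trivial irrep, has a non-zero average in general.

```python
import math

def rotation(theta):
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s], [s, c]]

def apply(m, v):
    return [sum(m[i][j] * v[j] for j in range(len(v))) for i in range(len(m))]

def group_average(matrices, v):
    """(1 / |G|) * sum over g of rho(g) v, given the list of matrices rho(g)."""
    images = [apply(m, v) for m in matrices]
    return [sum(img[i] for img in images) / len(matrices) for i in range(len(v))]

# C4 acting on R^2 by rotations: no trivial component.
c4 = [rotation(2 * math.pi * k / 4) for k in range(4)]
print(group_average(c4, [0.3, -1.2]))  # approximately [0.0, 0.0]

# Regular representation of C2 (identity and swap): contains the trivial irrep.
c2_regular = [[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]]
print(group_average(c2_regular, [1.0, 0.0]))  # [0.5, 0.5]
```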

Only representations which do not contain the trivial representation are allowed. You can check if a representation contains the trivial representation using contains_trivial(). To check if a trivial irrep is present in a representation in a FieldType, you can use:

for r in field_type:
    if r.contains_trivial():
        print("field type contains a trivial irrep")

Parameters
• in_type (FieldType) – the input field type

• eps (float, optional) – a value added to the denominator for numerical stability. Default: 1e-5

• momentum (float, optional) – the value used for the running_mean and running_var computation. Can be set to None for cumulative moving average (i.e. simple average). Default: 0.1

• affine (bool, optional) – if True, this module has learnable scale parameters. Default: True

forward(input)[source]

Apply the norm batch normalization to the input feature map.

Parameters

input (GeometricTensor) – the input feature map

Returns

the resulting feature map

### InducedNormBatchNorm¶

class InducedNormBatchNorm(in_type, eps=1e-05, momentum=0.1, affine=True)[source]

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

Batch normalization for induced isometric representations. This module requires the input fields to be associated with an induced representation from an isometric (i.e. norm-preserving) non-trivial representation which supports ‘norm’ non-linearities.

The module assumes the mean of the vectors is always zero so no running mean is computed and no bias is added. This is guaranteed as long as the representations do not include a trivial representation.

Indeed, if $\rho$ does not include a trivial representation, the following holds:

$$\forall \mathbf{v} \in \mathbb{R}^n, \quad \frac{1}{|G|} \sum_{g \in G} \rho(g) \mathbf{v} = \mathbf{0}$$

Hence, only the standard deviation is normalized. The same standard deviation, however, is shared by all the sub-fields of the same induced field.

The input representation of the fields is preserved by this operation.
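A minimal sketch of the shared normalization for one induced field, under stated assumptions: each entry of `samples` holds all the sub-fields of the field at one (sample, pixel) location, stored contiguously; no mean is subtracted; and the variance is estimated as the mean squared entry over the batch and all sub-fields. The exact estimator used by e2cnn may differ, and both helpers are hypothetical. Because every sub-field is divided by the same positive scalar, the ratios between sub-field norms, and hence the field's representation, are preserved.

```python
import math

def induced_norm_batch_norm(samples, eps=1e-5):
    """Scale all sub-fields of an induced field by one shared factor."""
    entries = [x for s in samples for x in s]
    var = sum(x * x for x in entries) / len(entries)  # assumed zero mean
    scale = 1.0 / math.sqrt(var + eps)
    return [[x * scale for x in s] for s in samples]

def subfield_norms(vector, subfield_size):
    return [math.sqrt(sum(x * x for x in vector[i:i + subfield_size]))
            for i in range(0, len(vector), subfield_size)]

# Two locations of one induced field made of two 2-dimensional sub-fields.
samples = [[3.0, 4.0, 0.0, 0.0], [0.0, 0.0, 6.0, 8.0]]
out = induced_norm_batch_norm(samples)
```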

Only representations which do not contain the trivial representation are allowed. You can check if a representation contains the trivial representation using contains_trivial(). To check if a trivial irrep is present in a representation in a FieldType, you can use:

for r in field_type:
if r.contains_trivial():
print(f"field type contains a trivial irrep")

Parameters
• in_type (FieldType) – the input field type

• eps (float, optional) – a value added to the denominator for numerical stability. Default: 1e-5

• momentum (float, optional) – the value used for the running_mean and running_var computation. Can be set to None for cumulative moving average (i.e. simple average). Default: 0.1

• affine (bool, optional) – if True, this module has learnable scale parameters. Default: True

forward(input)[source]

Apply the induced norm batch normalization to the input feature map.

Parameters

input (GeometricTensor) – the input feature map

Returns

the resulting feature map

### GNormBatchNorm¶

class GNormBatchNorm(in_type, eps=1e-05, momentum=0.1, affine=True)[source]

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

Batch normalization for generic representations.

Todo

Add more details about how stats are computed and how affine transformation is done.

Parameters
• in_type (FieldType) – the input field type

• eps (float, optional) – a value added to the denominator for numerical stability. Default: 1e-5

• momentum (float, optional) – the value used for the running_mean and running_var computation. Can be set to None for cumulative moving average (i.e. simple average). Default: 0.1

• affine (bool, optional) – if True, this module has learnable affine parameters. Default: True

forward(input)[source]

Apply the batch normalization to the input feature map.

Parameters

input (GeometricTensor) – the input feature map

Returns

the resulting feature map

## Dropout¶

### FieldDropout¶

class FieldDropout(in_type, p=0.5, inplace=False)[source]

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

Applies dropout to individual fields independently.

Note that, unlike PointwiseDropout, this module drops whole fields rather than single channels.

Parameters
• in_type (FieldType) – the input field type

• p (float, optional) – dropout probability. Default: 0.5

• inplace (bool, optional) – can optionally do the operation in-place. Default: False

forward(input)[source]
Parameters

input (GeometricTensor) – the input feature map

Returns

the resulting feature map
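Dropping a whole field multiplies it by a scalar, which commutes with any representation, whereas zeroing single channels does not. The sketch below illustrates this on a flat list of channels; the `1 / (1 - p)` rescaling of the kept fields mirrors torch.nn.functional.dropout and is an assumption here, as is the helper name `field_dropout`.

```python
import random

def field_dropout(channels, field_sizes, p, rng):
    """Zero whole fields with probability p and rescale the kept ones by 1 / (1 - p).

    channels: flat list of channel values at one pixel (hypothetical layout)
    field_sizes: sizes of the consecutive fields stored in `channels`
    """
    assert sum(field_sizes) == len(channels)
    out, start = [], 0
    for size in field_sizes:
        keep = rng.random() >= p  # one random draw per field, not per channel
        factor = 1.0 / (1.0 - p) if keep else 0.0
        out.extend(x * factor for x in channels[start:start + size])
        start += size
    return out

x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
y = field_dropout(x, [2, 2, 2], p=0.5, rng=random.Random(0))
# Each 2-channel field of y is either [0.0, 0.0] or twice the input field.
```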

### PointwiseDropout¶

class PointwiseDropout(in_type, p=0.5, inplace=False)[source]

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

Applies dropout to individual channels independently.

This class is just a wrapper for torch.nn.functional.dropout() in an EquivariantModule.

Only representations supporting pointwise non-linearities are accepted as input field type.

Parameters
• in_type (FieldType) – the input field type

• p (float, optional) – dropout probability. Default: 0.5

• inplace (bool, optional) – can optionally do the operation in-place. Default: False

forward(input)[source]
Parameters

input (GeometricTensor) – the input feature map

Returns

the resulting feature map

export()[source]

Export this module to a normal PyTorch torch.nn.Dropout module and set to “eval” mode.

## Other Modules¶

### Sequential¶

class SequentialModule(*args)[source]

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

A sequential container similar to torch.nn.Sequential.

The constructor accepts either a sequence of EquivariantModule instances or an OrderedDict of them.

Example:

import e2cnn
from collections import OrderedDict

# Example of SequentialModule
s = e2cnn.gspaces.Rot2dOnR2(8)
c_in = e2cnn.nn.FieldType(s, [s.trivial_repr]*3)
c_out = e2cnn.nn.FieldType(s, [s.regular_repr]*16)
model = e2cnn.nn.SequentialModule(
    e2cnn.nn.R2Conv(c_in, c_out, 5),
    e2cnn.nn.InnerBatchNorm(c_out),
    e2cnn.nn.ReLU(c_out),
)

# Example with OrderedDict
s = e2cnn.gspaces.Rot2dOnR2(8)
c_in = e2cnn.nn.FieldType(s, [s.trivial_repr]*3)
c_out = e2cnn.nn.FieldType(s, [s.regular_repr]*16)
model = e2cnn.nn.SequentialModule(OrderedDict([
    ('conv', e2cnn.nn.R2Conv(c_in, c_out, 5)),
    ('bn', e2cnn.nn.InnerBatchNorm(c_out)),
    ('relu', e2cnn.nn.ReLU(c_out)),
]))

forward(input)[source]
Parameters

input (GeometricTensor) – the input GeometricTensor

Returns

the output tensor

add_module(name, module)[source]

Append module to the sequence of modules applied in the forward pass.

export()[source]

Export this module to a normal PyTorch torch.nn.Sequential module and set to “eval” mode.

### ModuleList¶

class ModuleList(modules=None)[source]

Bases: torch.nn.modules.container.ModuleList

Module similar to torch.nn.ModuleList, containing a list of EquivariantModule instances.

This class works like torch.nn.ModuleList, except that it only accepts instances of EquivariantModule.

Additionally, this class provides an export() method, which calls the export() method of each module contained in this ModuleList and returns a torch.nn.ModuleList containing the exported modules.

Parameters

modules (iterable, optional) – an iterable of equivariant modules to add

append(module)[source]

Appends an EquivariantModule to the end of the list.

Parameters

module (EquivariantModule) – equivariant module to append

extend(modules)[source]

Appends multiple EquivariantModule instances from a Python iterable to the end of the list.

Parameters

modules (iterable) – iterable of equivariant modules to append

export()[source]

Export this module to a normal PyTorch torch.nn.ModuleList module and set to “eval” mode.

### Restriction¶

class RestrictionModule(in_type, id)[source]

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

Restricts the type of the input to the subgroup identified by id.

It computes the output type in the constructor and wraps the underlying tensor (torch.Tensor) in input with the output type during the forward pass.

This module only acts as a wrapper for e2cnn.nn.FieldType.restrict() (or e2cnn.nn.GeometricTensor.restrict()). The accepted values of id depend on the underlying gspace in the input type in_type; check the documentation of the method e2cnn.gspaces.GSpace.restrict() of the gspace used for further information.

Parameters
• in_type (FieldType) – the input field type

• id – a valid id for a subgroup of the space associated with the input type

export()[source]

Export this module to a normal PyTorch torch.nn.Identity module and set to “eval” mode.

Warning

Only working with PyTorch >= 1.2
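To illustrate what restricting to a subgroup means, this sketch takes the regular representation of C8, under which a rotation by k steps cyclically shifts the 8 channels, and keeps only the rotations of the subgroup C4 (even k). The channels then split into two orbits of size 4, reflecting the fact that the restricted representation is equivalent to two copies of the regular representation of C4. The helpers are hypothetical and unrelated to the id values accepted by GSpace.restrict().

```python
def regular_action(n, k):
    """Permutation of the n channels of a regular C_n field induced by rotation k."""
    return [(i + k) % n for i in range(n)]

def orbits(n, generators):
    """Split channels 0..n-1 into orbits under the given permutations."""
    seen, result = set(), []
    for start in range(n):
        if start in seen:
            continue
        orbit, frontier = {start}, [start]
        while frontier:
            i = frontier.pop()
            for perm in generators:
                j = perm[i]
                if j not in orbit:
                    orbit.add(j)
                    frontier.append(j)
        seen |= orbit
        result.append(sorted(orbit))
    return result

# C4 is the subgroup of C8 generated by the rotation k = 2.
print(orbits(8, [regular_action(8, 2)]))  # [[0, 2, 4, 6], [1, 3, 5, 7]]
# The full group C8 (generator k = 1) mixes all channels.
print(orbits(8, [regular_action(8, 1)]))  # [[0, 1, 2, 3, 4, 5, 6, 7]]
```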

### Disentangle¶

class DisentangleModule(in_type)[source]

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

Disentangles the representations in the field type of the input.

This module only acts as a wrapper for e2cnn.group.disentangle(). In the constructor, it disentangles each representation in the input type to build the output type and precomputes the change of basis matrices needed to transform each input field.

During the forward pass, each field in the input tensor is transformed with the change of basis corresponding to its representation.

Parameters

in_type (FieldType) – the input field type
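The idea of a change of basis that splits a representation into independent blocks can be shown on the smallest example, the regular representation of C2, whose non-identity element swaps two channels. The orthogonal matrix Q below is a hand-picked assumption rather than what e2cnn.group.disentangle() computes; it maps the swap to the direct sum of the trivial and the sign irreps. The helper functions are hypothetical.

```python
import math

def matmul(a, b):
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(a):
    return [list(row) for row in zip(*a)]

def apply(m, v):
    return [sum(m[i][j] * v[j] for j in range(len(v))) for i in range(len(m))]

# Regular representation of C2: the non-identity element swaps the two channels.
P = [[0.0, 1.0], [1.0, 0.0]]

# Hand-picked orthogonal change of basis (an assumption for this example).
r = 1.0 / math.sqrt(2.0)
Q = [[r, r], [r, -r]]

# In the new basis the swap becomes diag(1, -1): trivial irrep (+1) and sign irrep (-1).
D = matmul(transpose(Q), matmul(P, Q))

# Transforming a field into the new basis: the first coordinate is invariant under
# the swap, while the second one changes sign.
v = [3.0, 1.0]
w = apply(transpose(Q), v)
w_swapped = apply(transpose(Q), apply(P, v))
```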

### Upsampling¶

class R2Upsampling(in_type, scale_factor, mode='bilinear', align_corners=False)[source]

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

Wrapper for torch.nn.functional.interpolate(). Check its documentation for further details.

Only "bilinear" and "nearest" methods are supported. However, "nearest" is not equivariant; using this method may result in broken equivariance. For this reason, we suggest using "bilinear" (default value).

Parameters
• in_type (FieldType) – the input field type

• scale_factor (int) – multiplier for spatial size

• mode (str) – algorithm used for upsampling: nearest | bilinear. Default: bilinear

• align_corners (bool) – if True, the corner pixels of the input and output tensors are aligned, thus preserving the values at those pixels. This only has effect when mode is bilinear. Default: False

forward(input)[source]
Parameters

input (GeometricTensor) – the input feature map

Returns

the upsampled feature map

export()[source]

Export this module to a normal PyTorch torch.nn.Upsample module and set to “eval” mode.
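Bilinear upsampling applies linear interpolation separately along each spatial axis. The 1D sketch below reproduces, to our understanding, the source-coordinate conventions of torch.nn.functional.interpolate for align_corners=False and align_corners=True with an integer scale_factor; `linear_upsample_1d` is a hypothetical helper.

```python
def linear_upsample_1d(signal, scale_factor, align_corners=False):
    """1D linear interpolation; bilinear mode applies this along each spatial axis."""
    n_in = len(signal)
    n_out = n_in * scale_factor
    out = []
    for i in range(n_out):
        if align_corners:
            # First and last samples of input and output coincide.
            src = i * (n_in - 1) / (n_out - 1) if n_out > 1 else 0.0
        else:
            # Pixel centers are aligned; negative coordinates are clamped to 0.
            src = max((i + 0.5) / scale_factor - 0.5, 0.0)
        x0 = int(src)
        x1 = min(x0 + 1, n_in - 1)
        lam = src - x0
        out.append((1 - lam) * signal[x0] + lam * signal[x1])
    return out

print(linear_upsample_1d([0.0, 1.0], 2))                      # [0.0, 0.25, 0.75, 1.0]
print(linear_upsample_1d([0.0, 1.0], 2, align_corners=True))  # approximately [0.0, 0.333, 0.667, 1.0]
```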

### Multiple¶

class MultipleModule(in_type, labels, modules, reshuffle=0)[source]

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

Splits the input tensor into multiple branches identified by the input labels and applies to each of them the corresponding module in modules.

A label is associated with each field in the input type, while modules assigns a module to apply to each label (or set of labels). modules should be a list of pairs, each containing an EquivariantModule and a label (or a list of labels).

During forward, fields are grouped by the labels and the input tensor is split accordingly. Then, each subtensor is passed to the corresponding module in modules.

If reshuffle is set to a positive integer, a copy of the input tensor is first built by sorting the fields according to the chosen value:

• 1: fields are sorted by their labels

• 2: fields are sorted by their labels and, then, by their size

• 3: fields are sorted by their labels, by their size and, then, by their type

In this way, fields that need to be retrieved together are contiguous and it is possible to exploit slicing to split the tensor. By default, reshuffle = 0, which means that no sorting is performed; hence, if the input fields are not contiguous, this layer uses indexing to retrieve the sub-tensors.

This module wraps a BranchingModule followed by a MergeModule.

Parameters
• in_type (FieldType) – the input field type

• labels (list) – the list of labels to group the fields

• modules (list) – list of modules to apply to the labeled fields

• reshuffle (int, optional) – set how to reshuffle the input fields before splitting the tensor. By default (0) no reshuffling is done

forward(input)[source]

Split the input tensor according to the labels, apply each module to the corresponding input sub-tensors and stack the results.

Parameters

input (GeometricTensor) – the input GeometricTensor

Returns

the concatenation of the output of each module
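How fields are grouped by label, and why reshuffling allows slicing instead of indexing, can be sketched with a hypothetical helper that only looks at field sizes and labels. With reshuffle=True, this sketch sorts the fields stably by label, in the spirit of reshuffle = 1; the stable ordering is an assumption.

```python
def group_channels(field_sizes, labels, reshuffle=False):
    """Map each label to the channel indices of the fields carrying it.

    With reshuffle=True, fields are first sorted (stably) by label, so the indices
    refer to the reordered copy of the tensor.
    """
    order = list(range(len(field_sizes)))
    if reshuffle:
        order.sort(key=lambda f: labels[f])  # list.sort is stable
    groups, offset = {}, 0
    for f in order:
        groups.setdefault(labels[f], []).extend(range(offset, offset + field_sizes[f]))
        offset += field_sizes[f]
    return groups

def is_contiguous(indices):
    """True if the indices form a single slice, so no indexing is needed."""
    return indices == list(range(indices[0], indices[0] + len(indices)))

sizes, labels = [2, 1, 2, 1], ["a", "b", "a", "b"]
print(group_channels(sizes, labels))                  # {'a': [0, 1, 3, 4], 'b': [2, 5]}
print(group_channels(sizes, labels, reshuffle=True))  # {'a': [0, 1, 2, 3], 'b': [4, 5]}
```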

### Reshuffle¶

class ReshuffleModule(in_type, permutation)[source]

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

Permutes the fields of the input tensor according to the input permutation.

The parameter permutation should be a list containing a permutation of the integers {0, 1, ..., n-1}, where n is the number of fields of in_type (see e2cnn.nn.FieldType.__len__()).

Parameters
• in_type (FieldType) – input field type

• permutation (list) – permutation to apply
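Permuting fields amounts to permuting blocks of channels. The sketch below converts a field permutation into channel indices, assuming the convention that output field i is input field permutation[i] (the opposite convention would use the inverse permutation); `channel_permutation` is a hypothetical helper.

```python
def channel_permutation(field_sizes, permutation):
    """Channel indices to gather so that output field i is input field permutation[i]."""
    assert sorted(permutation) == list(range(len(field_sizes)))
    starts, offset = [], 0
    for size in field_sizes:
        starts.append(offset)
        offset += size
    return [c for f in permutation for c in range(starts[f], starts[f] + field_sizes[f])]

idx = channel_permutation([2, 1, 3], [2, 0, 1])
print(idx)  # [3, 4, 5, 0, 1, 2]

channels = ["a0", "a1", "b0", "c0", "c1", "c2"]  # fields a (2), b (1), c (3)
reshuffled = [channels[c] for c in idx]          # fields c, a, b
```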

### Mask¶

class MaskModule(in_type, S, margin=0.0)[source]

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

Performs an element-wise multiplication of the input with a mask of shape S x S.

The mask has value $1$ in all pixels with distance smaller than (S-1)/2 * (1 - margin/100) from the center of the mask and $0$ elsewhere. Values change smoothly between the two regions.

This operation is useful to remove from an input image or feature map all the parts of the signal defined on the pixels which lie outside the circle inscribed in the grid. Because a rotation would move these pixels outside the grid, this information would be discarded anyway when rotating an image. However, allowing a model to use this information might break the guaranteed equivariance, as rotated and non-rotated inputs have different information content.

Note

In order to perform the masking, the module expects an input with the same spatial dimensions as the mask. Hence, input tensors must have shape B x C x S x S.

Parameters
• in_type (FieldType) – input field type

• S (int) – the shape of the mask and the expected inputs

• margin (float, optional) – margin around the mask in percentage with respect to the radius of the mask (default 0.)
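A sketch of the mask construction, using a disk of radius (S - 1) / 2 * (1 - margin / 100) around the center of the grid. The exponential decay outside the disk, with a hypothetical width parameter `sigma`, is an assumption standing in for the smooth transition; `build_mask` and `apply_mask` are illustrative helpers, not e2cnn functions. The resulting mask is unchanged by 90-degree rotations of the grid.

```python
import math

def build_mask(S, margin=0.0, sigma=2.0):
    """S x S mask equal to 1 inside the disk and decaying smoothly outside it.

    Assumption: outside the disk the value is exp((t - r2) / sigma**2), where r2 is
    the squared distance from the center and t the squared radius of the disk.
    """
    c = (S - 1) / 2
    t = (c * (1 - margin / 100)) ** 2
    mask = []
    for x in range(S):
        row = []
        for y in range(S):
            r2 = (x - c) ** 2 + (y - c) ** 2
            row.append(1.0 if r2 <= t else math.exp((t - r2) / sigma ** 2))
        mask.append(row)
    return mask

def apply_mask(feature_map, mask):
    """Element-wise product of one S x S channel with the mask."""
    return [[v * m for v, m in zip(row, mrow)] for row, mrow in zip(feature_map, mask)]

m = build_mask(9)
masked = apply_mask([[2.0] * 9 for _ in range(9)], m)
```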

### Identity¶

class IdentityModule(in_type)[source]

Bases: e2cnn.nn.modules.equivariant_module.EquivariantModule

Simple module which does not perform any operation on the input tensor.

Parameters

in_type (FieldType) – input (and output) type of this module

forward(input)[source]

Returns the input tensor.

Parameters

input (GeometricTensor) – the input GeometricTensor

Returns

the output tensor

export()[source]

Export this module to a normal PyTorch torch.nn.Identity module and set to “eval” mode.

Warning

Only working with PyTorch >= 1.2

## Weight Initialization¶

generalized_he_init(tensor, basisexpansion, cache=False)[source]

Initialize the weights of a convolutional layer with a generalized He’s weight initialization method.

Because the computation of the variances can be expensive, to save time on consecutive runs of the same model, it is possible to cache the tensor containing the variance of each weight, for a specific basisexpansion. This can be useful if a network contains multiple convolution layers of the same kind (same input and output types, same kernel size, etc.) or if one needs to train the same network from scratch multiple times (e.g. to perform hyper-parameter search over learning rate or to repeat an experiment with different random seeds).

Note

The variance tensor is cached in memory and therefore is only available to the current process.

Parameters
• tensor (torch.Tensor) – the tensor containing the weights

• basisexpansion (BasisExpansion) – the basis expansion method

• cache (bool, optional) – cache the variance tensor. By default, cache=False

deltaorthonormal_init(tensor, basisexpansion)[source]

Initialize the weights of a convolutional layer with delta-orthogonal initialization.

Parameters