import torch
from torch import Tensor
from escnn.group import GroupElement
from .field_type import FieldType
from typing import List, Union
import itertools
from collections.abc import Iterable
__all__ = ["GeometricTensor", "tensor_directsum"]
class GeometricTensor:
def __init__(self, tensor: Tensor, type: FieldType, coords: Tensor = None):
r"""
A GeometricTensor can be interpreted as a *typed* tensor.
It wraps a common :class:`torch.Tensor` and endows it with a (compatible) :class:`~escnn.nn.FieldType` as
*transformation law*.
The :class:`~escnn.nn.FieldType` describes the action of a group :math:`G` on the tensor.
This action includes both a transformation of the base space and a transformation of the channels according to
a :math:`G`-representation :math:`\rho`.
All *escnn* neural network operations have :class:`~escnn.nn.GeometricTensor` s as inputs and outputs.
They perform dynamic type checking, ensuring that the transformation laws of the data and the operation match.
See also :class:`~escnn.nn.EquivariantModule`.
As usual, the first dimension of the tensor is interpreted as the batch dimension. The second is the fiber
(or channel) dimension, which is associated with a group representation by the field type. The following
dimensions are the spatial dimensions (like in a conventional CNN).
In addition, this class accepts an optional argument ``coords``.
If the argument is not passed, the input ``tensor`` is assumed to have shape
``(batchsize, channels, *spatial_grid_shape)`` and to represent features sampled on a grid of shape
``spatial_grid_shape``; in this case, the action on the base space resamples the transformed features on this
grid (using interpolation, if necessary).
If ``coords`` is not ``None``, ``tensor`` is assumed to be a ``(#points, channels)`` tensor containing an unstructured
set of points living on the base space; then, ``coords`` should contain the coordinates of these points.
The tensor ``coords`` must have shape ``(#points, type.gspace.dimensionality)``, where ``#points = tensor.shape[0]``.
The base space action will then transform these coordinates (no interpolation required).
See also the method :meth:`escnn.nn.FieldType.transform` for more details.
The operations of **addition** and **scalar multiplication** are supported.
For example::
gs = escnn.gspaces.rot2dOnR2(8)
type = escnn.nn.FieldType(gs, [gs.regular_repr]*3)
t1 = escnn.nn.GeometricTensor(torch.randn(1, 24, 3, 3), type)
t2 = escnn.nn.GeometricTensor(torch.randn(1, 24, 3, 3), type)
# addition
t3 = t1 + t2
# scalar multiplication
t3 = t1 * 3.
# scalar multiplication also supports tensors containing only one scalar
t3 = t1 * torch.tensor(3.)
# inplace operations are also supported
t1 += t2
t2 *= 3.
.. warning ::
The multiplication of a PyTorch tensor containing only a scalar with a GeometricTensor is only supported
when using PyTorch 1.4 or higher (see this `issue <https://github.com/pytorch/pytorch/issues/26333>`_ )
.. warning ::
These operations are only supported for *compatible* tensors, i.e. tensors which share the same ``type`` and,
if set, ``coords``.
A GeometricTensor can also be multiplied from the left by a :class:`~escnn.group.GroupElement` with the ``@``
operator, i.e. ``g @ tensor``.
The group element must belong to the :attr:`~escnn.nn.FieldType.fibergroup` of the :class:`~escnn.nn.FieldType`
``type`` of this tensor.
The multiplication through ``@`` *only transforms the channels* using the group representation
``type.representation`` but *does not transform the base space*.
This operation is equivalent to :meth:`~escnn.nn.GeometricTensor.transform_fibers`.
Check its documentation for more details.
A GeometricTensor supports **slicing** in a similar way to PyTorch's :class:`torch.Tensor`.
More precisely, slicing along the batch (1st) and the spatial (3rd, 4th, ...) dimensions works as usual.
However, slicing the fiber (2nd) dimension would break equivariance when splitting channels belonging to
the same field.
To prevent this, slicing on the second dimension is defined over *fields* instead of channels.
Note that, if ``coords`` is not ``None``, slicing over the first dimension also slices the ``coords`` tensor over its
first dimension.
Moreover, a GeometricTensor also partially supports **advanced indexing** (see NumPy's
documentation about
`indexing <https://numpy.org/doc/stable/user/basics.indexing.html#indexing-on-ndarrays>`_
for more details).
.. warning ::
In contrast to NumPy and PyTorch, an index containing a single integer value **does not** reduce
the dimensionality of the tensor.
In this way, the resulting tensor can always be interpreted as a GeometricTensor.
We give a few examples to illustrate this behavior::
# Example of GeometricTensor slicing
space = escnn.gspaces.rot2dOnR2(4)
type = escnn.nn.FieldType(space, [
# field type # index # size
space.regular_repr, # 0 # 4
space.regular_repr, # 1 # 4
space.irrep(1), # 2 # 2
space.irrep(1), # 3 # 2
space.trivial_repr, # 4 # 1
space.trivial_repr, # 5 # 1
space.trivial_repr, # 6 # 1
]) # sum = 15
# this FieldType contains 7 fields
len(type)
>> 7
# the size of this FieldType is equal to the sum of the sizes of each of its fields
type.size
>> 15
geom_tensor = escnn.nn.GeometricTensor(torch.randn(10, type.size, 9, 9), type)
# or, equivalently:
geom_tensor = type(torch.randn(10, type.size, 9, 9))
geom_tensor.shape
>> torch.Size([10, 15, 9, 9])
geom_tensor[1:3, :, 2:5, 2:5].shape
>> torch.Size([2, 15, 3, 3])
geom_tensor[..., 2:5].shape
>> torch.Size([10, 15, 9, 3])
# the tensor contains the fields 1:4, i.e. 1, 2 and 3
# these fields have size, respectively, 4, 2 and 2
# so the resulting tensor has 8 channels
geom_tensor[:, 1:4, ...].shape
>> torch.Size([10, 8, 9, 9])
# the tensor contains the fields 0:6:2, i.e. 0, 2 and 4
# these fields have size, respectively, 4, 2 and 1
# so the resulting tensor has 7 channels
geom_tensor[:, 0:6:2].shape
>> torch.Size([10, 7, 9, 9])
# the tensor contains only the field 2, which has size 2
# note, also, that even though a single index is used for the batch dimension, the resulting tensor
# still has 4 dimensions
geom_tensor[3, 2].shape
>> torch.Size([1, 2, 9, 9])
# we can use a boolean tensor to perform advanced indexing over the fields of a GeometricTensor.
# the tensor contains only the two fields of size 2, i.e. the third and the fourth.
# moreover, we also slice the batch dimension.
idx = torch.tensor([type.representations[i].size == 2 for i in range(len(type))])
geom_tensor[0:4, idx].shape
>> torch.Size([4, 4, 9, 9])
# it is also possible to use an integer tensor to perform advanced indexing.
# Here, we select the second and the sixth fields.
idx = torch.tensor([1, 5])
geom_tensor[0:4, idx].shape
>> torch.Size([4, 5, 9, 9])
# advanced indexing over the other dimensions works as usual.
idx = torch.tensor([1, 2, 2, 8, 4])
geom_tensor[idx, 1:3].shape
>> torch.Size([5, 6, 9, 9])
.. warning ::
*Slicing* over the fiber (2nd) dimension with ``step > 1`` or with a negative step is converted
into *indexing* over the channels.
This means that, in these cases, slicing behaves like *advanced indexing* in PyTorch and NumPy
**returning a copy instead of a view**.
For more details, see the *note* `here <https://pytorch.org/docs/stable/tensor_view.html>`_ and
*NumPy*'s `docs <https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html>`_ .
.. note ::
Slicing is not supported for setting values inside the tensor
(i.e. :meth:`~object.__setitem__` is not implemented).
Indeed, depending on the values which are assigned, this operation can break the symmetry of the tensor
which may not transform anymore according to its transformation law (specified by ``type``).
In case this feature is necessary, one can directly access the underlying :class:`torch.Tensor`, e.g.
``geom_tensor.tensor[:3, :, 2:5, 2:5] = torch.randn(3, geom_tensor.type.size, 3, 3)``, although this is not recommended.
.. warning ::
Using advanced indexing over multiple axes simultaneously drops the corresponding dimensions (except for the
first one). This means that the resulting tensor will likely not have a shape compatible with a Geometric
Tensor (e.g. the spatial dimensions are dropped), causing an error.
Args:
tensor (torch.Tensor): the tensor data
type (FieldType): the type of the tensor, modeling its transformation law
coords (torch.Tensor, optional): a tensor containing the coordinates of the datapoints
Attributes:
~.tensor (torch.Tensor)
~.type (FieldType)
~.coords (torch.Tensor)
"""
assert isinstance(tensor, torch.Tensor)
assert isinstance(type, FieldType)
if coords is None:
assert len(tensor.shape) == 2 + type.gspace.dimensionality, (tensor.shape, type.gspace.dimensionality)
else:
assert len(coords.shape) == 2, coords.shape
assert coords.shape[1] == type.gspace.dimensionality, \
f"Error! `coords` tensor with shape {coords.shape} does not match the dimensionality of the field type {type.gspace.dimensionality}."
assert len(tensor.shape) == 2, \
f"Error! If `coords` is not `None`, the tensor must be a 2D tensor of shape (#points, fieldtype size) but the tensor has shape {tensor.shape}."
assert tensor.shape[0] == coords.shape[0],\
f"Error! Points in the `tensor` and `coords` tensors do not match: {tensor.shape[0]} != {coords.shape[0]}."
assert tensor.shape[1] == type.size, \
f"Error! The size of the tensor {tensor.shape} does not match the size of the field type {type.size}."
# torch.Tensor: PyTorch tensor containing the data
self.tensor = tensor
# torch.Tensor: PyTorch tensor containing the coordinates of the data points
self.coords = coords
# FieldType: field type of the signal
self.type = type
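The shape rules enforced by the assertions above can be summarized in a small standalone sketch. It works on plain shape tuples rather than tensors; the helper name ``check_shapes`` and its parameters are hypothetical and not part of escnn.

```python
from typing import Optional, Tuple


def check_shapes(
    tensor_shape: Tuple[int, ...],
    type_size: int,
    dimensionality: int,
    coords_shape: Optional[Tuple[int, ...]] = None,
) -> None:
    """Raise ValueError if the shapes break the rules checked in ``__init__``.

    Hypothetical helper: ``type_size`` stands in for ``type.size`` and
    ``dimensionality`` for ``type.gspace.dimensionality``.
    """
    if coords_shape is None:
        # grid data: (batchsize, channels, *spatial_grid_shape)
        if len(tensor_shape) != 2 + dimensionality:
            raise ValueError(
                f"expected {2 + dimensionality} axes, got {len(tensor_shape)}"
            )
    else:
        # point data: features are (#points, channels),
        # coordinates are (#points, dimensionality)
        if len(coords_shape) != 2 or coords_shape[1] != dimensionality:
            raise ValueError(f"bad coords shape {coords_shape}")
        if len(tensor_shape) != 2:
            raise ValueError(f"features must be 2D, got shape {tensor_shape}")
        if tensor_shape[0] != coords_shape[0]:
            raise ValueError("number of points in features and coords differ")
    # in both cases the channel axis must match the size of the field type
    if tensor_shape[1] != type_size:
        raise ValueError(f"expected {type_size} channels, got {tensor_shape[1]}")


# features sampled on a 9x9 grid over a 2-dimensional base space
check_shapes((10, 15, 9, 9), type_size=15, dimensionality=2)
# 100 unstructured points with their 2-dimensional coordinates
check_shapes((100, 15), type_size=15, dimensionality=2, coords_shape=(100, 2))
```

Both calls return silently; a mismatch in the number of points or channels raises ``ValueError`` in this sketch, whereas the constructor itself uses ``assert``.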
def restrict(self, id) -> 'GeometricTensor':
r"""
Restrict the field type of this tensor.
The method returns a new :class:`~escnn.nn.GeometricTensor` whose :attr:`~escnn.nn.GeometricTensor.type`
is equal to this tensor's :attr:`~escnn.nn.GeometricTensor.type`
restricted to a subgroup :math:`H<G` (see :meth:`escnn.nn.FieldType.restrict`).
The restricted :attr:`~escnn.nn.GeometricTensor.type` is associated with the restricted representation
:math:`\Res{H}{G}\rho` of the :math:`G`-representation :math:`\rho` associated to this tensor's
:attr:`~escnn.nn.GeometricTensor.type`.
The input ``id`` specifies the subgroup :math:`H < G`.
Notice that the underlying :attr:`~escnn.nn.GeometricTensor.tensor` instance will be shared between
the current tensor and the returned one.
.. warning ::
The method builds the new representation on the fly; hence, if this operation is needed at run time,
we suggest using :class:`escnn.nn.RestrictionModule` which pre-computes the new representation offline.
.. seealso ::
Check the documentation of the :meth:`~escnn.gspaces.GSpace.restrict` method in the
:class:`~escnn.gspaces.GSpace` instance used for a description of the parameter ``id``.
Args:
id: the id identifying the subgroup :math:`H` the representations are restricted to
Returns:
the geometric tensor with the restricted representations
"""
new_class = self.type.restrict(id)
return GeometricTensor(self.tensor, new_class, self.coords)
def split(self, breaks: List[int]):
r"""
Split this tensor on the channel dimension into a list of smaller tensors.
The original tensor is split at the *fields* specified by the index list ``breaks``.
If the tensor is associated with the list of fields :math:`\{\rho_i\}_i`
(see :attr:`escnn.nn.FieldType.representations`), the :math:`j`-th output tensor will contain the fields
:math:`\rho_{\text{breaks}[j-1]}, \dots, \rho_{\text{breaks}[j]-1}` of the original tensor.
If `breaks = None`, the tensor is split at each field.
This is equivalent to using `breaks = list(range(len(self.type)))`.
Example ::
space = escnn.gspaces.rot2dOnR2(4)
type = escnn.nn.FieldType(space, [
space.regular_repr, # size = 4
space.regular_repr, # size = 4
space.irrep(1), # size = 2
space.irrep(1), # size = 2
space.trivial_repr, # size = 1
space.trivial_repr, # size = 1
space.trivial_repr, # size = 1
]) # sum = 15
type.size
>> 15
geom_tensor = escnn.nn.GeometricTensor(torch.randn(10, type.size, 7, 7), type)
geom_tensor.shape
>> torch.Size([10, 15, 7, 7])
# split the tensor in 3 parts
len(geom_tensor.split([0, 4, 6]))
>> 3
# the first contains
# - the first 2 regular fields (2*4 = 8 channels)
# - 2 vector fields (irrep(1)) (2*2 = 4 channels)
# and, therefore, contains 12 channels
geom_tensor.split([0, 4, 6])[0].shape
>> torch.Size([10, 12, 7, 7])
# the second contains only 2 scalar (trivial) fields (2*1 = 2 channels)
geom_tensor.split([0, 4, 6])[1].shape
>> torch.Size([10, 2, 7, 7])
# the last contains only 1 scalar (trivial) field (1*1 = 1 channel)
geom_tensor.split([0, 4, 6])[2].shape
>> torch.Size([10, 1, 7, 7])
Args:
breaks (list): indices of the fields where to split the tensor
Returns:
list of :class:`~escnn.nn.GeometricTensor` s into which the original tensor is chunked
"""
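if breaks is None:
breaks = list(range(len(self.type)))
# copy the list (so the caller's argument is not mutated), drop an optional leading 0
# and close the last chunk at the end of the field list
breaks = list(breaks[1:] if breaks and breaks[0] == 0 else breaks) + [len(self.type)]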
# final list of tensors
tensors = []
# list containing the index of the channels separating consecutive fields in this tensor
positions = []
last = 0
for repr in self.type.representations:
positions.append(last)
last += repr.size
positions.append(last)
last_field = 0
# for each break point
for b in breaks:
assert b > last_field, 'Error! "breaks" must be an increasing list of positive indexes'
# compute the sub-class of the new sub-tensor
repr = FieldType(self.type.gspace, self.type.representations[last_field: b])
# retrieve the sub-tensor
data = self.tensor[:, positions[last_field]:positions[b], ...]
tensors.append(GeometricTensor(data, repr, self.coords))
last_field = b
return tensors
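The bookkeeping that turns field break points into channel ranges can be sketched without escnn or PyTorch. The function below is a hypothetical illustration: ``field_sizes`` stands in for ``[r.size for r in self.type.representations]``, and this sketch expects strictly positive, increasing break points.

```python
from itertools import accumulate
from typing import List, Tuple


def channel_ranges(field_sizes: List[int], breaks: List[int]) -> List[Tuple[int, int]]:
    """Return the (start, stop) channel range of each chunk (hypothetical helper)."""
    # positions[i] is the first channel of field i; positions[-1] is the total size
    positions = [0] + list(accumulate(field_sizes))
    # close the last chunk at the end of the field list, without mutating the input
    bounds = list(breaks) + [len(field_sizes)]
    ranges = []
    last_field = 0
    for b in bounds:
        if b <= last_field:
            raise ValueError("breaks must be a strictly increasing list of positive indices")
        ranges.append((positions[last_field], positions[b]))
        last_field = b
    return ranges


# two regular fields (size 4), two irrep(1) fields (size 2), three trivial fields (size 1)
sizes = [4, 4, 2, 2, 1, 1, 1]
print(channel_ranges(sizes, [4, 6]))  # [(0, 12), (12, 14), (14, 15)]
```

The three ranges have widths 12, 2 and 1, matching the channel counts of the three chunks in the docstring example.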
@property
def shape(self):
r"""
Alias for ``self.tensor.shape``
"""
return self.tensor.shape
def size(self):
r"""
Alias for ``self.tensor.size()``
.. seealso ::
:meth:`torch.Tensor.size`
"""
return self.tensor.size()
def to(self, *args, **kwargs):
r"""
Alias for ``self.tensor.to(*args, **kwargs)``.
Applies :meth:`torch.Tensor.to` to the underlying tensor and wraps the resulting tensor in a new
:class:`~escnn.nn.GeometricTensor` with the same type.
.. warning ::
This method does not affect ``self.coords``.
"""
tensor = self.tensor.to(*args, **kwargs)
return GeometricTensor(tensor, self.type, self.coords)
def __getitem__(self, slices) -> 'GeometricTensor':
r'''
A GeometricTensor supports *slicing* in a similar way to PyTorch's :class:`torch.Tensor`.
More precisely, slicing along the batch (1st) and the spatial (3rd, 4th, ...) dimensions works as usual.
However, slicing along the channel dimension could break equivariance by splitting the channels belonging to the
same field.
For this reason, slicing on the second dimension is not defined over the channels but over fields.
When a continuous (step=1) slice is used over the fields/channels dimension (the 2nd axis), it is converted
into a continuous slice over the channels.
This is not possible when the step is greater than 1 or negative.
In such cases, the slice over the fields needs to be converted into an index over the channels.
Moreover, when a single integer is used to index an axis, that axis is not discarded as in PyTorch but is
preserved with size 1.
If slicing is performed over the batch dimension and `coords` is not `None`, also the `coords` tensor is sliced.
Slicing is not supported for setting values inside the tensor (i.e. :meth:`object.__setitem__`
is not implemented).
'''
# Indexing over the 2nd dimension acts on whole fields, never on individual channels.
if isinstance(slices, tuple):
if len(slices) > len(self.tensor.shape):
raise TypeError(
f'''
Error! Too many slicing indices for GeometricTensor.
At most {len(self.tensor.shape)} indices expected but {len(slices)} indices passed.
'''
)
else:
slices = (slices,)
for i, idx in enumerate(slices):
if isinstance(idx, (slice, int)) or idx is Ellipsis:
continue
elif isinstance(idx, torch.Tensor) and idx.dtype == torch.bool:
continue
elif isinstance(idx, torch.Tensor) and idx.dtype == torch.long:
continue
else:
raise TypeError(f'''
Error! Advanced Indexing over a GeometricTensor is not fully supported yet.
Currently, only indexing with boolean or long tensors and basic slicing are supported.
''')
naxes = len(self.tensor.shape)
# count the number of indexes passed
indexed_axes = 0
for idx in slices:
indexed_axes += 1 - (idx is Ellipsis)
# number of axes which are missing an index
missing_axes = naxes - indexed_axes
# expand the first ellipsis with a number of full slices (i.e. [::]) equal to the number
# of axes not indexed. Discard all other ellipses
expanded_idxs = []
expanded_ellipsis = False
for s in slices:
if s is Ellipsis:
# expand only the first ellipsis
if not expanded_ellipsis:
expanded_idxs += [slice(None)]*missing_axes
expanded_ellipsis = True
else:
# other indices are preserved
expanded_idxs.append(s)
# maximum index per dimension
idx_max = list(self.tensor.shape)
idx_max[1] = len(self.type)
# If an index containing a single integer is passed, it is converted into a slice
# which starts at that index and ends at the following one.
# In this way, when passing a single integer to index a dimension, the resulting tensor will still have that
# dimension with size 1
for i in range(len(expanded_idxs)):
if isinstance(expanded_idxs[i], int):
idx = expanded_idxs[i]
if idx < 0:
# convert a negative index into a positive index
idx = idx_max[i] + idx
expanded_idxs[i] = slice(idx, idx+1, 1)
if len(expanded_idxs) == 1:
# if only the first dimension is indexed, there is no need to do anything
# the resulting tensor will have the same type of the original as the indexing does not affect the
# channels/fields dimension
type = self.type
elif isinstance(expanded_idxs[1], slice) and (expanded_idxs[1].step is None or expanded_idxs[1].step == 1):
# If the index over the fields is a slice and it is contiguous, we can convert it into a
# contiguous slice over the channels.
# The slice will start from the first channel of the first field and will stop at the last channel
# of the last field
start = expanded_idxs[1].start if expanded_idxs[1].start is not None else 0
stop = expanded_idxs[1].stop if expanded_idxs[1].stop is not None else len(self.type)
channel_idxs = slice(
self.type.fields_start[start],
self.type.fields_end[stop-1],
1
)
if start == 0 and stop == len(self.type):
# if all the fields are retrieved by this index, the resulting tensor has the same field
# types of the original one
type = self.type
else:
# otherwise, only a subset of the fields are preserved
type = FieldType(self.type.gspace, self.type.representations[expanded_idxs[1]])
expanded_idxs[1] = channel_idxs
else:
# If the index over the fields is not a slice or it is not a contiguous slice, we need to convert it
# into an index over the channels. We first use the index provided to retrieve the list of fields
# and then add the index of their channels in a list of indexes
idxs = []
# convert the indices into iterable and retrieve the subset of field representations
if isinstance(expanded_idxs[1], slice):
fields = range(len(self.type))[expanded_idxs[1]]
representations = self.type.representations[expanded_idxs[1]]
elif isinstance(expanded_idxs[1], int):
fields = [expanded_idxs[1]]
representations = [self.type.representations[expanded_idxs[1]]]
elif isinstance(expanded_idxs[1], torch.Tensor) and expanded_idxs[1].dtype == torch.bool:
fields = [i for i in range(len(self.type)) if expanded_idxs[1][i]]
representations = [self.type.representations[f] for f in fields]
elif isinstance(expanded_idxs[1], torch.Tensor) and expanded_idxs[1].dtype == torch.long:
fields = expanded_idxs[1].tolist()
representations = [self.type.representations[f] for f in fields]
elif isinstance(expanded_idxs[1], Iterable):
fields = expanded_idxs[1]
representations = [self.type.representations[f] for f in fields]
else:
raise ValueError('Index over the fiber (2nd) dimension not recognized.')
# iterate over all fields indexed by the user
for field in fields:
# append the indexes of the channels in the field
idxs.append(list(
range(
self.type.fields_start[field],
self.type.fields_end[field],
1
)
))
# only a subset of the fields are preserved by this index
type = FieldType(self.type.gspace, representations)
# concatenate all the channel indexes
channel_idxs = list(itertools.chain(*idxs))
expanded_idxs[1] = channel_idxs
idxs = tuple(expanded_idxs)
sliced_tensor = self.tensor[idxs]
if self.coords is not None:
coords = self.coords[idxs[0], :]
else:
coords = None
return GeometricTensor(sliced_tensor, type, coords)
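The conversion of a non-contiguous slice over fields into an explicit list of channel indices can be illustrated in isolation. This is a minimal sketch in plain Python: ``fields_to_channels`` and ``field_sizes`` are hypothetical names, and the start and end offsets are recomputed here instead of being read from ``FieldType.fields_start`` and ``FieldType.fields_end``.

```python
from itertools import accumulate, chain
from typing import List


def fields_to_channels(field_sizes: List[int], field_idx: slice) -> List[int]:
    """Expand a slice over fields into the list of channels it covers."""
    ends = list(accumulate(field_sizes))   # one-past-the-last channel of each field
    starts = [0] + ends[:-1]               # first channel of each field
    # resolve the slice against the number of fields, as done in the method body
    fields = range(len(field_sizes))[field_idx]
    # gather all channels of every selected field, keeping each field intact
    return list(chain.from_iterable(range(starts[f], ends[f]) for f in fields))


sizes = [4, 4, 2, 2, 1, 1, 1]
print(fields_to_channels(sizes, slice(0, 6, 2)))  # [0, 1, 2, 3, 8, 9, 12]
```

Fields 0, 2 and 4 contribute 4, 2 and 1 channels, so the result has 7 entries. Indexing a tensor with such a list is advanced indexing and therefore yields a copy rather than a view.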
def has_same_coords(self, other: 'GeometricTensor') -> bool:
if self.coords is other.coords:
# this includes also the case both are None
return True
if (self.coords is None) != (other.coords is None):
return False
elif self.coords is not None:
return torch.allclose(self.coords, other.coords)
return True
def is_compatible(self, other: 'GeometricTensor') -> bool:
if self.type != other.type:
return False
return self.has_same_coords(other)
def __add__(self, other: 'GeometricTensor') -> 'GeometricTensor':
r"""
Add two compatible :class:`~escnn.nn.GeometricTensor` s using pointwise addition.
The two tensors need to have the same shape and be associated with the same field type.
Args:
other (GeometricTensor): the other geometric tensor
Returns:
the sum
"""
if isinstance(other, GeometricTensor):
assert self.is_compatible(other), 'The two geometric tensors must have the same `type` and `coords`'
return GeometricTensor(self.tensor + other.tensor, self.type, self.coords)
else:
return NotImplemented
def __sub__(self, other: 'GeometricTensor') -> 'GeometricTensor':
r"""
Subtract two compatible :class:`~escnn.nn.GeometricTensor` s using pointwise subtraction.
The two tensors need to have the same shape and be associated with the same field type.
Args:
other (GeometricTensor): the other geometric tensor
Returns:
their difference
"""
if isinstance(other, GeometricTensor):
assert self.is_compatible(other), 'The two geometric tensors must have the same `type` and `coords`'
return GeometricTensor(self.tensor - other.tensor, self.type, self.coords)
else:
return NotImplemented
def __iadd__(self, other: 'GeometricTensor') -> 'GeometricTensor':
r"""
Add a compatible :class:`~escnn.nn.GeometricTensor` to this tensor inplace.
The two tensors need to have the same shape and be associated with the same field type.
Args:
other (GeometricTensor): the other geometric tensor
Returns:
this tensor
"""
if isinstance(other, GeometricTensor):
assert self.is_compatible(other), 'The two geometric tensors must have the same `type` and `coords`'
self.tensor += other.tensor
return self
else:
return NotImplemented
def __isub__(self, other: 'GeometricTensor') -> 'GeometricTensor':
r"""
Subtract a compatible :class:`~escnn.nn.GeometricTensor` from this tensor inplace.
The two tensors need to have the same shape and be associated with the same field type.
Args:
other (GeometricTensor): the other geometric tensor
Returns:
this tensor
"""
if isinstance(other, GeometricTensor):
assert self.is_compatible(other), 'The two geometric tensors must have the same `type` and `coords`'
self.tensor -= other.tensor
return self
else:
return NotImplemented
def __mul__(self, other: Union[float, torch.Tensor]) -> 'GeometricTensor':
r"""
Multiplication of this :class:`~escnn.nn.GeometricTensor` by a scalar.
The scalar can be a float number or a :class:`torch.Tensor` containing only
one scalar (i.e. :func:`torch.numel` should return ``1``).
Args:
other : a scalar
Returns:
the rescaled tensor
"""
if isinstance(other, float) or (isinstance(other, torch.Tensor) and other.numel() == 1):
# Only multiplication with a scalar is allowed
return GeometricTensor(self.tensor * other, self.type, self.coords)
else:
return NotImplemented
__rmul__ = __mul__
def __imul__(self, other: Union[float, torch.Tensor]) -> 'GeometricTensor':
r"""
Multiplication of this :class:`~escnn.nn.GeometricTensor` by a scalar.
The operation is done inplace.
The scalar can be a float number or a :class:`torch.Tensor` containing only
one scalar (i.e. :func:`torch.numel` should return ``1``).
Args:
other : a scalar
Returns:
this tensor
"""
if isinstance(other, float) or (isinstance(other, torch.Tensor) and other.numel() == 1):
self.tensor *= other
return self
else:
return NotImplemented
def __rmatmul__(self, other: GroupElement):
r"""
Equivalent to :meth:`escnn.nn.GeometricTensor.transform_fibers`.
.. warning::
This only transforms the fibers, not the basespace.
.. todo ::
should this be like `.transform(element)` ?
"""
if isinstance(other, GroupElement):
assert other.group == self.type.fibergroup
return GeometricTensor(
self.type.transform_fibers(self.tensor, other),
self.type,
self.coords
)
else:
return NotImplemented
def __repr__(self):
t = repr(self.tensor)[:-1]
t = t.replace('\n', '\n ')
r = 'g_' + t + ', ' + repr(self.type) + ')'
return r
def tensor_directsum(tensors: List['GeometricTensor']) -> 'GeometricTensor':
r"""
Concatenate a list of :class:`~escnn.nn.GeometricTensor` s on the channels dimension (``dim=1``).
The input tensors have to be compatible: they need to have the same shape except for the channels
dimension (``dim=1``); additionally, if their `coords` attribute is not `None`, they also need to share the same
value.
In the resulting :class:`~escnn.nn.GeometricTensor`, the channels dimension will be associated with the direct sum
representation of the representations of the input tensors.
.. seealso::
:func:`escnn.group.directsum`
Args:
tensors (list): a list of :class:`~escnn.nn.GeometricTensor` s
Returns:
the direct sum of the inputs
"""
# assert len(tensors) > 1
for i in range(1, len(tensors)):
assert tensors[0].type.gspace == tensors[i].type.gspace
assert tensors[0].tensor.ndimension() == tensors[i].tensor.ndimension()
assert tensors[0].tensor.shape[0] == tensors[i].tensor.shape[0]
assert tensors[0].tensor.shape[2:] == tensors[i].tensor.shape[2:]
assert tensors[0].has_same_coords(tensors[i])
# concatenate all representations from all field types
reprs = []
for t in tensors:
reprs += t.type.representations
# build the new field type
cls = FieldType(tensors[0].type.gspace, reprs)
# concatenate the underlying tensors
data = torch.cat([t.tensor for t in tensors], dim=1)
# build the new Geometric Tensor
return GeometricTensor(data, cls, tensors[0].coords)