from e2cnn.kernels import Basis, EmptyBasisException
from e2cnn.gspaces import *
from e2cnn.group import Representation
from e2cnn.nn import FieldType
from .. import utils
from .basisexpansion import BasisExpansion
from .basisexpansion_singleblock import block_basisexpansion
from collections import defaultdict
from typing import Callable, List, Iterable, Dict, Union
import torch
import numpy as np
__all__ = ["BlocksBasisExpansion"]
class BlocksBasisExpansion(BasisExpansion):

    def __init__(self,
                 in_type: FieldType,
                 out_type: FieldType,
                 basis_generator: Callable[[Representation, Representation], Basis],
                 points: np.ndarray,
                 basis_filter: Callable[[dict], bool] = None,
                 recompute: bool = False,
                 **kwargs
                 ):
        r"""
        Basis expansion of the filter of a convolutional layer, built block-wise.

        With this algorithm, the expansion is done on the intertwiners of the fields' representations pairs in input
        and output.

        Args:
            in_type (FieldType): the input field type
            out_type (FieldType): the output field type
            basis_generator (callable): method that generates the analytical filter basis
            points (~numpy.ndarray): points where the analytical basis should be sampled
            basis_filter (callable, optional): filter for the basis elements. Should take a dictionary containing an
                    element's attributes and return whether to keep it or not.
            recompute (bool, optional): whether to recompute new bases or reuse, if possible, already built tensors.
            **kwargs: keyword arguments to be passed to ```basis_generator```

        Attributes:
            S (int): number of points where the filters are sampled

        """
        assert in_type.gspace == out_type.gspace
        assert isinstance(in_type.gspace, GeneralOnR2)

        super(BlocksBasisExpansion, self).__init__()

        self._in_type = in_type
        self._out_type = out_type
        self._input_size = in_type.size
        self._output_size = out_type.size
        self.points = points

        # int: number of points where the filters are sampled
        self.S = self.points.shape[1]

        # we group the basis vectors by their input and output representations
        _block_expansion_modules = {}

        # iterate through all different pairs of input/output representations
        # and, for each of them, build a basis
        for i_repr in in_type._unique_representations:
            for o_repr in out_type._unique_representations:
                reprs_names = (i_repr.name, o_repr.name)
                try:
                    basis = basis_generator(i_repr, o_repr, **kwargs)
                    block_expansion = block_basisexpansion(basis, points, basis_filter, recompute=recompute)
                    _block_expansion_modules[reprs_names] = block_expansion

                    # register the block expansion as a submodule
                    self.add_module(f"block_expansion_{reprs_names}", block_expansion)
                except EmptyBasisException:
                    # no intertwiner basis exists for this pair of representations: just skip it
                    pass

        if len(_block_expansion_modules) == 0:
            print('WARNING! The basis for the block expansion of the filter is empty!')

        self._n_pairs = len(in_type._unique_representations) * len(out_type._unique_representations)

        # the list of all pairs of input/output representations which don't have an empty basis
        self._representations_pairs = sorted(list(_block_expansion_modules.keys()))

        # retrieve for each representation in both input and output fields:
        # - the number of its occurrences,
        # - the indices where it occurs and
        # - whether its occurrences are contiguous or not
        self._in_count, _in_indices, _in_contiguous = _retrieve_indices(in_type)
        self._out_count, _out_indices, _out_contiguous = _retrieve_indices(out_type)

        # compute an id for each basis element (and, so, for each learnable parameter)
        basis_ids = _compute_attrs_and_ids(in_type, out_type, _block_expansion_modules)

        self._weights_ranges = {}
        last_weight_position = 0
        self._ids_to_basis = {}
        self._basis_to_ids = []
        self._contiguous = {}

        # iterate through the different groups of blocks,
        # i.e., through all input/output pairs
        for io_pair in self._representations_pairs:

            # a pair is "contiguous" when both its input and output fields occupy contiguous fiber positions
            self._contiguous[io_pair] = _in_contiguous[io_pair[0]] and _out_contiguous[io_pair[1]]

            # build the indices tensors
            if self._contiguous[io_pair]:
                # contiguous fields can be addressed with a [start, stop, length] slice description
                in_indices = [
                    _in_indices[io_pair[0]].min(),
                    _in_indices[io_pair[0]].max() + 1,
                    _in_indices[io_pair[0]].max() + 1 - _in_indices[io_pair[0]].min()
                ]
                out_indices = [
                    _out_indices[io_pair[1]].min(),
                    _out_indices[io_pair[1]].max() + 1,
                    _out_indices[io_pair[1]].max() + 1 - _out_indices[io_pair[1]].min()
                ]
                # plain attributes (not buffers): these are Python lists, not tensors
                setattr(self, 'in_indices_{}'.format(io_pair), in_indices)
                setattr(self, 'out_indices_{}'.format(io_pair), out_indices)
            else:
                # scattered fields need an explicit index for every (output, input) entry of the filter
                out_indices, in_indices = torch.meshgrid([_out_indices[io_pair[1]], _in_indices[io_pair[0]]])
                in_indices = in_indices.reshape(-1)
                out_indices = out_indices.reshape(-1)

                # register the indices tensors as (non-trainable) buffers of this module
                self.register_buffer('in_indices_{}'.format(io_pair), in_indices)
                self.register_buffer('out_indices_{}'.format(io_pair), out_indices)

            # count the actual number of parameters
            total_weights = len(basis_ids[io_pair])

            for i, basis_id in enumerate(basis_ids[io_pair]):
                self._ids_to_basis[basis_id] = last_weight_position + i

            self._basis_to_ids += basis_ids[io_pair]

            # evaluate the indices in the global weights tensor to use for the basis belonging to this group
            self._weights_ranges[io_pair] = (last_weight_position, last_weight_position + total_weights)

            # increment the position counter
            last_weight_position += total_weights

    def get_basis_names(self) -> List[str]:
        r"""Return the list of the unique string ids of all basis elements."""
        return self._basis_to_ids

    def get_element_info(self, name: Union[str, int]) -> Dict:
        r"""
        Return the dictionary of attributes of the basis element identified by ``name``,
        which can be either its string id or its integer index.
        """
        if isinstance(name, str):
            idx = self._ids_to_basis[name]
        else:
            idx = name

        reprs_names = None
        relative_idx = None

        # find the input/output representations pair whose weights range contains the index
        for pair, idx_range in self._weights_ranges.items():
            if idx_range[0] <= idx < idx_range[1]:
                reprs_names = pair
                relative_idx = idx - idx_range[0]
                break
        assert reprs_names is not None and relative_idx is not None

        block_expansion = getattr(self, f"block_expansion_{reprs_names}")

        # index of the block within this pair's group and of the element within the block
        block_idx = relative_idx // block_expansion.dimension()
        relative_idx = relative_idx % block_expansion.dimension()

        attr = block_expansion.get_element_info(relative_idx).copy()

        block_count = 0
        out_irreps_count = 0
        for o, o_repr in enumerate(self._out_type.representations):
            in_irreps_count = 0
            for i, i_repr in enumerate(self._in_type.representations):

                if reprs_names == (i_repr.name, o_repr.name):
                    if block_count == block_idx:
                        # retrieve the attributes of each basis element and build a new list of
                        # attributes adding information specific to the current block
                        attr.update({
                            "in_irreps_position": in_irreps_count + attr["in_irrep_idx"],
                            "out_irreps_position": out_irreps_count + attr["out_irrep_idx"],
                            "in_repr": reprs_names[0],
                            "out_repr": reprs_names[1],
                            "in_field_position": i,
                            "out_field_position": o,
                        })

                        # build the id of the basis vector:
                        # add names and indices of the input and output fields
                        elem_id = '({}-{},{}-{})'.format(i_repr.name, i, o_repr.name, o)
                        # add the original id in the block submodule
                        elem_id += "_" + attr["id"]

                        # update with the new id
                        attr["id"] = elem_id
                        attr["idx"] = idx

                        return attr

                    block_count += 1
                in_irreps_count += len(i_repr.irreps)
            out_irreps_count += len(o_repr.irreps)

        raise ValueError(f"Parameter with index {idx} not found!")

    def get_basis_info(self) -> Iterable:
        r"""
        Iterate over the attribute dictionaries of all basis elements, in the order
        matching their position in the global weights tensor.
        """
        # cumulative irreps counts and, for each representation name, the field positions where it occurs
        out_irreps_counts = [0]
        out_block_counts = defaultdict(list)
        for o, o_repr in enumerate(self._out_type.representations):
            out_irreps_counts.append(out_irreps_counts[-1] + len(o_repr.irreps))
            out_block_counts[o_repr.name].append(o)

        in_irreps_counts = [0]
        in_block_counts = defaultdict(list)
        for i, i_repr in enumerate(self._in_type.representations):
            in_irreps_counts.append(in_irreps_counts[-1] + len(i_repr.irreps))
            in_block_counts[i_repr.name].append(i)

        # iterate through the different groups of blocks,
        # i.e., through all input/output pairs
        idx = 0
        for reprs_names in self._representations_pairs:

            block_expansion = getattr(self, f"block_expansion_{reprs_names}")

            for o in out_block_counts[reprs_names[1]]:
                out_irreps_count = out_irreps_counts[o]
                for i in in_block_counts[reprs_names[0]]:
                    in_irreps_count = in_irreps_counts[i]

                    # retrieve the attributes of each basis element and build a new list of
                    # attributes adding information specific to the current block
                    for attr in block_expansion.get_basis_info():
                        attr = attr.copy()
                        attr.update({
                            "in_irreps_position": in_irreps_count + attr["in_irrep_idx"],
                            "out_irreps_position": out_irreps_count + attr["out_irrep_idx"],
                            "in_repr": reprs_names[0],
                            "out_repr": reprs_names[1],
                            "in_field_position": i,
                            "out_field_position": o,
                        })

                        # build the id of the basis vector:
                        # add names and indices of the input and output fields
                        elem_id = '({}-{},{}-{})'.format(reprs_names[0], i, reprs_names[1], o)
                        # add the original id in the block submodule
                        elem_id += "_" + attr["id"]

                        # update with the new id
                        attr["id"] = elem_id
                        attr["idx"] = idx
                        idx += 1

                        yield attr

    def dimension(self) -> int:
        r"""Return the number of basis elements, i.e. the number of learnable weights."""
        return len(self._ids_to_basis)

    def _expand_block(self, weights, io_pair):
        r"""Expand the sub-filter mapping all input fields of type ``io_pair[0]`` to all output fields of type ``io_pair[1]``."""
        # retrieve the basis
        block_expansion = getattr(self, f"block_expansion_{io_pair}")

        # retrieve the linear coefficients for the basis expansion
        coefficients = weights[self._weights_ranges[io_pair][0]:self._weights_ranges[io_pair][1]]

        # reshape coefficients for the batch matrix multiplication
        coefficients = coefficients.view(-1, block_expansion.dimension())

        # expand the current subset of basis vectors and set the result in the appropriate place in the filter
        _filter = block_expansion(coefficients)
        k, o, i, p = _filter.shape

        _filter = _filter.view(
            self._out_count[io_pair[1]],
            self._in_count[io_pair[0]],
            o,
            i,
            self.S,
        )
        _filter = _filter.transpose(1, 2)
        return _filter

    def forward(self, weights: torch.Tensor) -> torch.Tensor:
        """
        Forward step of the Module which expands the basis and returns the filter built

        Args:
            weights (torch.Tensor): the learnable weights used to linearly combine the basis filters

        Returns:
            the filter built

        """
        assert weights.shape[0] == self.dimension()
        assert len(weights.shape) == 1

        if self._n_pairs == 1:
            # if there is only one block (i.e. one type of input field and one type of output field),
            # we can return the expanded block immediately, instead of copying it inside a preallocated empty tensor
            io_pair = self._representations_pairs[0]
            in_indices = getattr(self, f"in_indices_{io_pair}")
            out_indices = getattr(self, f"out_indices_{io_pair}")
            _filter = self._expand_block(weights, io_pair).reshape(out_indices[2], in_indices[2], self.S)
        else:
            # build the tensor which will contain the filter
            _filter = torch.zeros(self._output_size, self._input_size, self.S, device=weights.device)

            # iterate through all input-output field representations pairs
            for io_pair in self._representations_pairs:
                # retrieve the indices
                in_indices = getattr(self, f"in_indices_{io_pair}")
                out_indices = getattr(self, f"out_indices_{io_pair}")

                # expand the current subset of basis vectors and set the result in the appropriate place in the filter
                expanded = self._expand_block(weights, io_pair)

                if self._contiguous[io_pair]:
                    # contiguous fields: a single rectangular slice assignment
                    _filter[
                        out_indices[0]:out_indices[1],
                        in_indices[0]:in_indices[1],
                        :,
                    ] = expanded.reshape(out_indices[2], in_indices[2], self.S)
                else:
                    # scattered fields: advanced indexing with the precomputed index buffers
                    _filter[
                        out_indices,
                        in_indices,
                        :,
                    ] = expanded.reshape(-1, self.S)

        # return the new filter
        return _filter

    def __hash__(self):
        # combine the hashes of the block expansions, weighted by the number of times each block occurs
        _hash = 0
        for io in self._representations_pairs:
            n_pairs = self._in_count[io[0]] * self._out_count[io[1]]
            _hash += hash(getattr(self, f"block_expansion_{io}")) * n_pairs
        return _hash

    def __eq__(self, other):
        if not isinstance(other, BlocksBasisExpansion):
            return False

        if self.dimension() != other.dimension():
            return False

        if self._representations_pairs != other._representations_pairs:
            return False

        for io in self._representations_pairs:
            if self._contiguous[io] != other._contiguous[io]:
                return False

            if self._weights_ranges[io] != other._weights_ranges[io]:
                return False

            if self._contiguous[io]:
                # contiguous indices are stored as plain Python lists: compare directly
                if getattr(self, f"in_indices_{io}") != getattr(other, f"in_indices_{io}"):
                    return False
                if getattr(self, f"out_indices_{io}") != getattr(other, f"out_indices_{io}"):
                    return False
            else:
                # non-contiguous indices are stored as tensors: compare element-wise
                if torch.any(getattr(self, f"in_indices_{io}") != getattr(other, f"in_indices_{io}")):
                    return False
                if torch.any(getattr(self, f"out_indices_{io}") != getattr(other, f"out_indices_{io}")):
                    return False

            if getattr(self, f"block_expansion_{io}") != getattr(other, f"block_expansion_{io}"):
                return False

        return True
def _retrieve_indices(type: FieldType):
    r"""
    For each distinct representation in ``type``, retrieve:

    - the number of its occurrences,
    - the fiber positions (channel indices) where it occurs (as a :class:`torch.LongTensor`) and
    - whether those positions are contiguous or not.

    Returns:
        a tuple ``(count, indices, contiguous)`` of dictionaries indexed by representation name

    """
    fiber_position = 0
    _indices = defaultdict(list)
    _count = defaultdict(int)
    _contiguous = {}

    # scan the fields in order, accumulating the channel positions occupied by each representation
    # (note: "rep" avoids shadowing the builtin ``repr``)
    for rep in type.representations:
        _indices[rep.name] += list(range(fiber_position, fiber_position + rep.size))
        fiber_position += rep.size
        _count[rep.name] += 1

    for name, indices in _indices.items():
        # a representation's occurrences are contiguous iff its positions form a run of consecutive numbers
        _contiguous[name] = utils.check_consecutive_numbers(indices)
        _indices[name] = torch.LongTensor(indices)

    return _count, _indices, _contiguous
def _compute_attrs_and_ids(in_type, out_type, block_submodules):
basis_ids = defaultdict(lambda: [])
# iterate over all blocks
# each block is associated to an input/output representations pair
out_fiber_position = 0
out_irreps_count = 0
for o, o_repr in enumerate(out_type.representations):
in_fiber_position = 0
in_irreps_count = 0
for i, i_repr in enumerate(in_type.representations):
reprs_names = (i_repr.name, o_repr.name)
# if a basis for the space of kernels between the current pair of representations exists
if reprs_names in block_submodules:
# retrieve the attributes of each basis element and build a new list of
# attributes adding information specific to the current block
ids = []
for attr in block_submodules[reprs_names].get_basis_info():
# build the ids of the basis vectors
# add names and indices of the input and output fields
id = '({}-{},{}-{})'.format(i_repr.name, i, o_repr.name, o)
# add the original id in the block submodule
id += "_" + attr["id"]
ids.append(id)
# append the ids of the basis vectors
basis_ids[reprs_names] += ids
in_fiber_position += i_repr.size
in_irreps_count += len(i_repr.irreps)
out_fiber_position += o_repr.size
out_irreps_count += len(o_repr.irreps)
# return attributes, basis_ids
return basis_ids