Source code for escnn.nn.modules.basismanager.basissampler_blocks


from escnn.group import Representation
from escnn.kernels import KernelBasis, EmptyBasisException


from escnn.nn.modules.basismanager import retrieve_indices
from .basismanager import BasisManager

from escnn.nn.modules.basismanager.basissampler_singleblock import block_basissampler

from typing import Callable, Tuple, Dict, List, Iterable, Union
from collections import defaultdict

import torch
import numpy as np
import math


__all__ = ["BlocksBasisSampler"]


class BlocksBasisSampler(torch.nn.Module, BasisManager):

    def __init__(self,
                 in_reprs: List[Representation],
                 out_reprs: List[Representation],
                 basis_generator: Callable[[Representation, Representation], KernelBasis],
                 basis_filter: Callable[[dict], bool] = None,
                 recompute: bool = False,
                 ):
        r"""
        Module which performs the expansion of an analytical filter basis and samples it on arbitrary input points.

        Args:
            in_reprs (list): the input field type
            out_reprs (list): the output field type
            basis_generator (callable): method that generates the analytical filter basis
            basis_filter (callable, optional): filter for the basis elements. Should take a dictionary containing an
                    element's attributes and return whether to keep it or not.
            recompute (bool, optional): whether to recompute new bases or reuse, if possible, already built tensors.

        """
        super(BlocksBasisSampler, self).__init__()

        self._in_reprs = in_reprs
        self._out_reprs = out_reprs
        self._input_size = sum(r.size for r in in_reprs)
        self._output_size = sum(r.size for r in out_reprs)

        self._in_sizes = {
            r.name: r.size for r in in_reprs
        }

        uniform_input = all(r == in_reprs[0] for r in in_reprs)
        uniform_output = all(r == out_reprs[0] for r in out_reprs)
        _uniform = uniform_input and uniform_output

        # we group the bases by their input and output representations
        _block_sampler_modules = {}

        # iterate through all different pairs of input/output representations
        # and, for each of them, build a basis
        for i_repr in set(in_reprs):
            for o_repr in set(out_reprs):
                reprs_names = (i_repr.name, o_repr.name)
                try:
                    basis = basis_generator(i_repr, o_repr)

                    # BasisSamplerSingleBlock: sampler for block with input i_repr and output o_repr
                    block_sampler = block_basissampler(basis, basis_filter=basis_filter, recompute=recompute)
                    _block_sampler_modules[reprs_names] = block_sampler

                    # register the block sampler as a submodule
                    self.add_module(f"block_sampler_{reprs_names}", block_sampler)

                except EmptyBasisException:
                    # print(f"Empty basis at {reprs_names}")
                    pass

        if len(_block_sampler_modules) == 0:
            print('WARNING! The basis for the block sampler of the filter is empty!')

        # the list of all pairs of input/output representations which don't have an empty basis
        self._representations_pairs = sorted(list(_block_sampler_modules.keys()))

        self._n_pairs = len(set(in_reprs)) * len(set(out_reprs))
        if _uniform:
            assert self._n_pairs <= 1
        self._uniform = _uniform and self._n_pairs == 1

        # retrieve for each representation in both input and output fields:
        # - the number of its occurrences,
        # - the indices where it occurs and
        # - whether its occurrences are contiguous or not
        self._in_count, _in_indices, _in_contiguous = retrieve_indices(in_reprs)
        self._out_count, _out_indices, _out_contiguous = retrieve_indices(out_reprs)

        self._weights_ranges = {}

        last_weight_position = 0

        self._contiguous = {}

        # iterate through the different groups of blocks
        # i.e., through all input/output pairs
        for io_pair in self._representations_pairs:

            self._contiguous[io_pair] = _in_contiguous[io_pair[0]] and _out_contiguous[io_pair[1]]

            # build the indices tensors
            if self._contiguous[io_pair]:
                in_indices = [
                    _in_indices[io_pair[0]].min(),
                    _in_indices[io_pair[0]].max() + 1,
                    _in_indices[io_pair[0]].max() + 1 - _in_indices[io_pair[0]].min()
                ]
                out_indices = [
                    _out_indices[io_pair[1]].min(),
                    _out_indices[io_pair[1]].max() + 1,
                    _out_indices[io_pair[1]].max() + 1 - _out_indices[io_pair[1]].min()
                ]

                setattr(self, 'in_indices_{}'.format(io_pair), in_indices)
                setattr(self, 'out_indices_{}'.format(io_pair), out_indices)

            else:
                out_indices, in_indices = _out_indices[io_pair[1]], _in_indices[io_pair[0]]
                # out_indices, in_indices = torch.meshgrid([_out_indices[io_pair[1]], _in_indices[io_pair[0]]])
                # in_indices = in_indices.reshape(-1)
                # out_indices = out_indices.reshape(-1)

                # register the indices tensors and the bases tensors as parameters of this module
                self.register_buffer('in_indices_{}'.format(io_pair), in_indices)
                self.register_buffer('out_indices_{}'.format(io_pair), out_indices)

            # number of occurrences of the input/output pair `io_pair`
            n_pairs = self._in_count[io_pair[0]] * self._out_count[io_pair[1]]

            # count the actual number of parameters
            total_weights = _block_sampler_modules[io_pair].dimension() * n_pairs

            # evaluate the indices in the global weights tensor to use for the basis belonging to this group
            self._weights_ranges[io_pair] = (last_weight_position, last_weight_position + total_weights)

            # increment the position counter
            last_weight_position += total_weights

        self._dim = last_weight_position

    def get_element_info(self, idx: int) -> Dict:
        assert 0 <= idx < self._dim, idx

        reprs_names = None
        relative_idx = None
        for pair, idx_range in self._weights_ranges.items():
            if idx_range[0] <= idx < idx_range[1]:
                reprs_names = pair
                relative_idx = idx - idx_range[0]
                break
        assert reprs_names is not None and relative_idx is not None

        block_sampler = getattr(self, f"block_sampler_{reprs_names}")

        block_idx = relative_idx // block_sampler.dimension()
        relative_idx = relative_idx % block_sampler.dimension()

        attr = block_sampler.get_element_info(relative_idx).copy()

        block_count = 0
        out_irreps_count = 0
        for o, o_repr in enumerate(self._out_reprs):
            in_irreps_count = 0
            for i, i_repr in enumerate(self._in_reprs):

                if reprs_names == (i_repr.name, o_repr.name):

                    if block_count == block_idx:
                        # retrieve the attributes of each basis element and build a new list of
                        # attributes adding information specific to the current block
                        attr.update({
                            "in_irreps_position": in_irreps_count + attr["in_irrep_idx"],
                            "out_irreps_position": out_irreps_count + attr["out_irrep_idx"],
                            "in_repr": reprs_names[0],
                            "out_repr": reprs_names[1],
                            "in_field_position": i,
                            "out_field_position": o,
                        })

                        attr['block_id'] = attr['id']
                        attr['id'] = idx

                        return attr

                    block_count += 1

                in_irreps_count += len(i_repr.irreps)
            out_irreps_count += len(o_repr.irreps)

        raise ValueError(f"Parameter with index {idx} not found!")

    def get_basis_info(self) -> Iterable[Dict]:

        out_irreps_counts = [0]
        out_block_counts = defaultdict(list)
        for o, o_repr in enumerate(self._out_reprs):
            out_irreps_counts.append(out_irreps_counts[-1] + len(o_repr.irreps))
            out_block_counts[o_repr.name].append(o)

        in_irreps_counts = [0]
        in_block_counts = defaultdict(list)
        for i, i_repr in enumerate(self._in_reprs):
            in_irreps_counts.append(in_irreps_counts[-1] + len(i_repr.irreps))
            in_block_counts[i_repr.name].append(i)

        # iterate through the different groups of blocks
        # i.e., through all input/output pairs
        idx = 0
        for reprs_names in self._representations_pairs:
            block_sampler = getattr(self, f"block_sampler_{reprs_names}")

            # since this method returns an iterable of attributes built on the fly, load all attributes first and then
            # iterate on this list
            attrs = list(block_sampler.get_basis_info())

            for o in out_block_counts[reprs_names[1]]:
                out_irreps_count = out_irreps_counts[o]
                for i in in_block_counts[reprs_names[0]]:
                    in_irreps_count = in_irreps_counts[i]

                    # retrieve the attributes of each basis element and build a new list of
                    # attributes adding information specific to the current block
                    for attr in attrs:
                        attr = attr.copy()
                        attr.update({
                            "in_irreps_position": in_irreps_count + attr["in_irrep_idx"],
                            "out_irreps_position": out_irreps_count + attr["out_irrep_idx"],
                            "in_repr": reprs_names[0],
                            "out_repr": reprs_names[1],
                            "in_field_position": i,
                            "out_field_position": o,
                        })

                        # build the ids of the basis vectors
                        # add names and indices of the input and output fields
                        # id = '({}-{},{}-{})'.format(reprs_names[0], i, reprs_names[1], o)
                        # # add the original id in the block submodule
                        # id += "_" + attr["id"]
                        #
                        # # update with the new id
                        # attr["id"] = id
                        #
                        # attr["idx"] = idx

                        attr['block_id'] = attr['id']
                        attr['id'] = idx

                        assert idx < self._dim

                        idx += 1

                        yield attr

    def dimension(self) -> int:
        return self._dim

    def _compute_out_block(self, weights: torch.Tensor, input: torch.Tensor, points: torch.Tensor, io_pair) -> torch.Tensor:

        groups = input.shape[1]

        # retrieve the basis
        block_sampler = getattr(self, f"block_sampler_{io_pair}")

        # retrieve the linear coefficients for the basis sampler
        coefficients = weights[self._weights_ranges[io_pair][0]:self._weights_ranges[io_pair][1]]

        # reshape coefficients for the batch matrix multiplication
        coefficients = coefficients.view(
            groups,
            self._out_count[io_pair[1]] // groups,  # u
            self._in_count[io_pair[0]],             # j
            block_sampler.dimension(),              # k
        )

        # expand the current subset of basis vectors and set the result in the appropriate place in the filter
        _basis = block_sampler(points)
        # p, o, i, k = _basis.shape

        # TODO: torch.einsum does not optimize the order of operations. Need to do this manually!
        _out = torch.einsum(
            'gujk,poik,pgji->pguo',
            coefficients,
            _basis,
            input,
        )
        return _out

    def _contract_basis_block(self, weights: torch.Tensor, points: torch.Tensor, io_pair) -> torch.Tensor:

        # retrieve the basis
        block_sampler = getattr(self, f"block_sampler_{io_pair}")

        # retrieve the linear coefficients for the basis sampler
        coefficients = weights[self._weights_ranges[io_pair][0]:self._weights_ranges[io_pair][1]]

        # reshape coefficients for the batch matrix multiplication
        coefficients = coefficients.view(
            self._out_count[io_pair[1]],  # u
            self._in_count[io_pair[0]],   # j
            block_sampler.dimension(),    # k
        )

        # expand the current subset of basis vectors and set the result in the appropriate place in the filter
        _basis = block_sampler(points)
        # p, o, i, k = _basis.shape

        _filter = torch.einsum(
            'ujk,poik->pujoi',
            coefficients,
            _basis,
        ).permute(0, 1, 3, 2, 4)
        return _filter
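    # Editorial note (not part of the original source): the two einsum contractions
    # above share one index convention: p indexes sample points, g groups, u output
    # fields, j input fields, o/i channels within the output/input representation,
    # and k basis elements. In `_contract_basis_block`, coefficients of shape
    # (u, j, k) contract with the sampled basis of shape (p, o, i, k) over k,
    # giving filter blocks of shape (p, u, j, o, i) before the final permutation
    # to (p, u, o, j, i).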
    def forward(self, weights: torch.Tensor, points: torch.Tensor) -> torch.Tensor:
        """
        Forward step of the Module which expands the basis, samples it on the input `points` and returns
        the filter built.

        Args:
            weights (torch.Tensor): the learnable weights used to linearly combine the basis filters
            points (torch.Tensor): the points where the filter should be sampled

        Returns:
            the filter built

        """

        assert weights.shape[0] == self.dimension()
        assert len(weights.shape) == 1

        S = points.shape[0]

        if self._uniform:
            # if there is only one block (i.e. one type of input field and one type of output field),
            # we can return the computed block immediately, instead of copying it inside a preallocated empty tensor
            io_pair = self._representations_pairs[0]
            _filter = self._contract_basis_block(weights, points, io_pair)
            _filter = _filter.reshape(S, self._output_size, self._input_size)

        else:

            # to support Automatic Mixed Precision (AMP), we cannot preallocate the output tensor with a specific
            # dtype. Instead, we check the dtype of the first `expanded` block. For this reason, we postpone the
            # allocation of the full _filter tensor
            _filter = None

            # iterate through all input-output field representations pairs
            for io_pair in self._representations_pairs:

                # retrieve the indices
                in_indices = getattr(self, f"in_indices_{io_pair}")
                out_indices = getattr(self, f"out_indices_{io_pair}")

                # expand the current subset of basis vectors and set the result in the appropriate place in the filter
                expanded = self._contract_basis_block(weights, points, io_pair)

                if _filter is None:
                    # build the tensor which will contain the filter
                    # this lazy strategy allows us to use expanded.dtype, which is dynamically chosen by PyTorch's AMP
                    _filter = torch.zeros(
                        S, self._output_size, self._input_size,
                        device=weights.device, dtype=expanded.dtype,
                    )

                if self._contiguous[io_pair]:
                    _filter[
                        :,
                        out_indices[0]:out_indices[1],
                        in_indices[0]:in_indices[1],
                    ] = expanded.reshape(S, out_indices[2], in_indices[2])
                else:
                    out_indices, in_indices = torch.meshgrid([out_indices, in_indices], indexing='ij')
                    in_indices = in_indices.reshape(-1)
                    out_indices = out_indices.reshape(-1)
                    _filter[
                        :,
                        out_indices,
                        in_indices,
                    ] = expanded.reshape(S, -1)

            if _filter is None:
                # just in case
                _filter = torch.zeros(
                    S, self._output_size, self._input_size,
                    device=weights.device, dtype=weights.dtype,
                )

        # return the new filter
        return _filter
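    # Usage sketch (editorial; `sampler`, the point count and the 2D coordinates are
    # hypothetical, not from the original source):
    #
    #   weights = torch.randn(sampler.dimension())
    #   points = torch.randn(100, 2)          # 100 sample points in R^2
    #   filt = sampler(weights, points)       # shape: (100, output_size, input_size)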
    def _expand_filter_then_compute(self, weights: torch.Tensor, input: torch.Tensor, points: torch.Tensor, groups: int = 1) -> torch.Tensor:
        _filter = self(weights, points)

        S = input.shape[0]
        assert S > 0, S

        input = input.view(S, groups, self._input_size)
        _filter = _filter.view(S, groups, self._output_size // groups, self._input_size)

        return torch.einsum(
            'pgoi,pgi->pgo',
            _filter,
            input
        ).view(S, self._output_size)

    def _compute_then_expand_filter(
            self,
            weights: torch.Tensor,
            input: torch.Tensor,
            points: torch.Tensor,
            groups: int = 1
    ) -> torch.Tensor:

        S = input.shape[0]
        assert S > 0, S

        input = input.view(S, groups, self._input_size)

        if self._uniform:
            # if there is only one block (i.e. one type of input field and one type of output field),
            # we can return the computed block immediately, instead of copying it inside a preallocated empty tensor
            io_pair = self._representations_pairs[0]
            in_repr = io_pair[0]
            _input = input.view(S, groups, self._in_count[in_repr], self._in_sizes[in_repr])
            return self._compute_out_block(weights, _input, points, io_pair).reshape(S, self._output_size)

        else:
            # build the tensor which will contain the output
            _out = torch.zeros(S, self._output_size, device=input.device, dtype=input.dtype)

            # iterate through all input-output field representations pairs
            for io_pair in self._representations_pairs:

                # retrieve the indices
                in_indices = getattr(self, f"in_indices_{io_pair}")
                out_indices = getattr(self, f"out_indices_{io_pair}")

                if self._contiguous[io_pair]:
                    _input = input[:, :, in_indices[0]:in_indices[1]]
                else:
                    _input = input[:, :, in_indices]

                in_repr = io_pair[0]
                _input = _input.view(S, groups, self._in_count[in_repr], self._in_sizes[in_repr])

                # expand the current subset of basis vectors and set the result in the appropriate place in the filter
                block_out = self._compute_out_block(weights, _input, points, io_pair)

                if self._contiguous[io_pair]:
                    _out[:, out_indices[0]:out_indices[1]] += block_out.reshape(S, groups*out_indices[2])
                else:
                    _out[:, out_indices] += block_out.reshape(S, -1)

            return _out
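    # Editorial note (not part of the original source): the two helpers above realise
    # the same linear map in different orders. `_expand_filter_then_compute` first
    # materialises the full (points, output_size, input_size) filter via `forward`
    # and then applies it, while `_compute_then_expand_filter` contracts the input
    # with the sampled basis and the coefficients block by block, avoiding the
    # explicit filter tensor. `compute_messages` below selects between them through
    # its `conv_first` flag.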
    def compute_messages(self,
                         weights: torch.Tensor,
                         input: torch.Tensor,
                         points: torch.Tensor,
                         conv_first: bool = True,
                         groups: int = 1,
                         ) -> torch.Tensor:
        """
        Expands the basis with the learnable weights to generate the filter and uses it to compute the messages
        along the edges.

        Each point in `points` corresponds to an edge in a graph. Each point is associated with a row of `input`.
        This row is a feature associated to the source node of the edge which needs to be propagated to the target
        node of the edge.

        This method also supports grouped convolution via the argument ``groups``.
        When used, the ``input`` tensor should contain ``groups`` blocks, each transforming under ``self._in_reprs``.
        Moreover, the output size ``self._out_size`` should be divisible by ``groups``.

        .. warning::
            Unlike the convolution layers, this method does not check that ``self._out_repr`` splits into ``groups``
            blocks containing the same representations. Hence, this operation can break equivariance if ``groups`` is
            not properly set and ``self._out_repr`` contains a heterogeneous list of representations.
            We recommend using the :class:`~escnn.nn.R2PointConv` or :class:`~escnn.nn.R3PointConv` modules directly,
            which implement a number of checks to ensure the convolution is done in an equivariant way.

        Args:
            weights (torch.Tensor): the learnable weights used to linearly combine the basis filters
            input (torch.Tensor): the input features associated with each point
            points (torch.Tensor): the points where the filter should be sampled
            conv_first (bool, optional): perform convolution with the basis filters and, then, combine the responses
                with the learnable weights. This generally has computational benefits. (Default: ``True``)
            groups (int, optional): number of blocked connections from input channels to output channels.
                It allows depthwise convolution. Default: ``1``.

        Returns:
            the messages computed

        """
        assert weights.shape[0] == self.dimension()
        assert len(weights.shape) == 1

        assert len(input.shape) == 2
        assert input.shape[1] == self._input_size * groups, (input.shape, self._input_size, groups)
        assert input.shape[0] == points.shape[0]

        assert self._output_size % groups == 0, (self._output_size, groups)

        if conv_first:
            return self._compute_then_expand_filter(weights, input, points, groups=groups)
        else:
            return self._expand_filter_then_compute(weights, input, points, groups=groups)
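    # Usage sketch for `compute_messages` (editorial; `sampler`, `n_edges` and the
    # 2D coordinates are hypothetical, not from the original source):
    #
    #   features = torch.randn(n_edges, sampler._input_size)   # source-node feature per edge
    #   rel_pos = torch.randn(n_edges, 2)                      # relative position of each edge
    #   messages = sampler.compute_messages(weights, features, rel_pos)
    #   # messages.shape == (n_edges, sampler._output_size)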
    def __hash__(self):
        _hash = 0
        for io in self._representations_pairs:
            n_pairs = self._in_count[io[0]] * self._out_count[io[1]]
            _hash += hash(getattr(self, f"block_sampler_{io}")) * n_pairs
        return _hash

    def __eq__(self, other):
        if not isinstance(other, BlocksBasisSampler):
            return False

        if self._dim != other._dim:
            return False

        if self._representations_pairs != other._representations_pairs:
            return False

        for io in self._representations_pairs:
            if self._contiguous[io] != other._contiguous[io]:
                return False
            if self._weights_ranges[io] != other._weights_ranges[io]:
                return False

            if self._contiguous[io]:
                if getattr(self, f"in_indices_{io}") != getattr(other, f"in_indices_{io}"):
                    return False
                if getattr(self, f"out_indices_{io}") != getattr(other, f"out_indices_{io}"):
                    return False
            else:
                if torch.any(getattr(self, f"in_indices_{io}") != getattr(other, f"in_indices_{io}")):
                    return False
                if torch.any(getattr(self, f"out_indices_{io}") != getattr(other, f"out_indices_{io}")):
                    return False

            if getattr(self, f"block_sampler_{io}") != getattr(other, f"block_sampler_{io}"):
                return False

        return True
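# Construction sketch (editorial, not part of the original source): in practice this
# module is typically instantiated internally by point-convolution layers such as
# escnn.nn.R2PointConv, which supply a suitable `basis_generator`. Built by hand it
# would look roughly like the following, where `in_type`/`out_type` are assumed to be
# escnn.nn.FieldType instances and `basis_gen` is a user-supplied callable mapping a
# pair of representations to a KernelBasis:
#
#   sampler = BlocksBasisSampler(in_type.representations,
#                                out_type.representations,
#                                basis_gen)
#   weights = torch.randn(sampler.dimension())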