# Source code for e2cnn.nn.init



from e2cnn.nn.modules.r2_conv.basisexpansion import BasisExpansion

from collections import defaultdict

import torch
from scipy import stats
import math

# Public initialization functions exported by this module.
__all__ = ["generalized_he_init", "deltaorthonormal_init"]


def _generalized_he_init_variances(basisexpansion: BasisExpansion):
    r"""

    Compute the per-weight scaling factors used by the generalized He's weight initialization method.

    For the basis element with input irrep ``in_irrep`` and output position ``o``, the factor is
    ``1 / sqrt(N_o * B_(in_irrep, o))`` where ``N_o`` is the number of distinct input irreps which
    appear together with output position ``o`` and ``B_(in_irrep, o)`` is the number of basis elements
    sharing that ``(in_irrep, o)`` pair.

    .. note ::
        Despite the function's name, the returned values are the square roots above; in this module
        they are multiplied directly with standard normal samples.

    Args:
        basisexpansion (BasisExpansion): the basis expansion method

    Returns:
        a 1-dimensional tensor with ``basisexpansion.dimension()`` entries, one per basis element, in the
        order produced by ``basisexpansion.get_basis_info()``

    """

    scales = torch.ones((basisexpansion.dimension(),))

    # materialize the attributes once: they are traversed twice below
    entries = list(basisexpansion.get_basis_info())

    # output position -> set of distinct input irreps feeding it
    irreps_per_output = defaultdict(set)
    # (input irrep, output position) -> number of basis elements
    elements_per_pair = defaultdict(int)

    for entry in entries:
        out_position = entry["out_irreps_position"]
        in_irrep = entry["in_irrep"]
        irreps_per_output[out_position].add(in_irrep)
        elements_per_pair[(in_irrep, out_position)] += 1

    num_irreps = {
        out_position: len(irreps)
        for out_position, irreps in irreps_per_output.items()
    }

    for index, entry in enumerate(entries):
        out_position = entry["out_irreps_position"]
        pair = (entry["in_irrep"], out_position)
        scales[index] = 1. / math.sqrt(num_irreps[out_position] * elements_per_pair[pair])

    return scales


# In-memory (per-process) cache of the tensors returned by `_generalized_he_init_variances`,
# keyed by the basis expansion object; it is filled by `generalized_he_init` when `cache=True`.
cached_he_vars = {}


def generalized_he_init(tensor: torch.Tensor, basisexpansion: BasisExpansion, cache: bool = False):
    r"""

    Initialize the weights of a convolutional layer with a generalized He's weight initialization method.

    Because the computation of the variances can be expensive, to save time on consecutive runs of the same model,
    it is possible to cache the tensor containing the variance of each weight, for a specific ``basisexpansion``.
    This can be useful if a network contains multiple convolution layers of the same kind (same input and output
    types, same kernel size, etc.) or if one needs to train the same network from scratch multiple times (e.g. to
    perform hyper-parameter search over learning rate or to repeat an experiment with different random seeds).

    .. note ::
        The variance tensor is cached in memory and therefore is only available to the current process.

    Args:
        tensor (torch.Tensor): the tensor containing the weights; it is overwritten in place
        basisexpansion (BasisExpansion): the basis expansion method
        cache (bool, optional): cache the variance tensor. By default, ``cache=False``

    """
    # Initialization

    assert tensor.shape == (basisexpansion.dimension(),)

    if cache:
        # compute the scaling factors only the first time this basis expansion is seen
        if basisexpansion not in cached_he_vars:
            cached_he_vars[basisexpansion] = _generalized_he_init_variances(basisexpansion)
        scales = cached_he_vars[basisexpansion]
    else:
        scales = _generalized_he_init_variances(basisexpansion)

    # The scaling factors are allocated with ``torch.ones`` without a device argument, so they may
    # live on a different device than ``tensor``; move them to avoid a device-mismatch error.
    # This is a no-op when both are already on the same device.
    scales = scales.to(device=tensor.device)

    tensor[:] = scales * torch.randn_like(tensor)
def deltaorthonormal_init(tensor: torch.Tensor, basisexpansion: BasisExpansion):
    r"""

    Initialize the weights of a convolutional layer with *delta-orthogonal* initialization.

    Only the basis elements which map an irrep to itself (``in_irrep == out_irrep``) and whose radial
    attribute (``"radius"`` or, if absent, ``"order"``) is zero receive a non-zero weight; every other
    entry of ``tensor`` is set to zero. For each such irrep, the weights are the product of a random
    unit-norm vector per output position and an entry of a random orthogonal matrix mixing input and
    output positions.

    Args:
        tensor (torch.Tensor): the tensor containing the weights; it is overwritten in place
        basisexpansion (BasisExpansion): the basis expansion method

    Raises:
        ValueError: if a basis element has neither a ``"radius"`` nor an ``"order"`` attribute

    """
    # Initialization

    assert tensor.shape == (basisexpansion.dimension(), )

    tensor.fill_(0.)

    # irrep -> (input position, output position) -> indices of the selected basis elements
    counts = defaultdict(lambda: defaultdict(list))

    for p, attr in enumerate(basisexpansion.get_basis_info()):
        in_irrep = attr["in_irrep"]
        out_irrep = attr["out_irrep"]
        in_position = attr["in_irreps_position"]
        out_position = attr["out_irreps_position"]

        if "radius" in attr:
            r = attr["radius"]
        elif "order" in attr:
            r = attr["order"]
        else:
            raise ValueError("Attribute dict has no radial information, needed for deltaorthonormal init")

        if in_irrep == out_irrep and r == 0.:
            counts[in_irrep][(in_position, out_position)].append(p)

    def same_content(values):
        # True if all the elements are equal (vacuously True when empty)
        values = list(values)
        return all(v == values[0] for v in values)

    for count in counts.values():
        # every (input position, output position) pair must own the same number of basis elements
        sizes = [len(indices) for indices in count.values()]
        assert same_content(sizes), sizes

        in_c = defaultdict(int)
        out_c = defaultdict(int)
        for in_position, out_position in count.keys():
            in_c[in_position] += 1
            out_c[out_position] += 1

        # the pairs must form a full grid: each input position is paired with each output position
        assert same_content(in_c.values()), count.keys()
        assert same_content(out_c.values()), count.keys()
        assert list(in_c.values())[0] == len(out_c.keys())
        assert list(out_c.values())[0] == len(in_c.keys())

        s = sizes[0]
        n_in = len(in_c)
        n_out = len(out_c)

        if max(n_out, n_in) > 1:
            # random orthogonal matrix, cropped to (outputs x inputs)
            W = stats.ortho_group.rvs(max(n_in, n_out))[:n_out, :n_in]
        else:
            # 1x1 case: a random sign. The upper bound of ``torch.randint`` is exclusive, so it has
            # to be 2 to sample from {0, 1}; with an upper bound of 1 the sign would always be -1.
            W = 2 * torch.randint(0, 2, size=(1, 1)) - 1

        # one random unit-norm vector per output position
        # (the rescaling by 5 is cancelled by the normalization on the following line)
        w = torch.randn((n_out, s))
        w *= 5.
        w /= (w ** 2).sum(dim=1, keepdim=True).sqrt()

        for i, in_position in enumerate(in_c.keys()):
            for o, out_position in enumerate(out_c.keys()):
                for p, pp in enumerate(count[(in_position, out_position)]):
                    tensor[pp] = w[o, p] * W[o, i]