# Source code for torch_rc.nn.esn

import torch
from torch.nn import ModuleList
from typing import Optional, Final


@torch.jit.script
def to_sparse(tensor, density: float, sparse_repr: bool = False) -> torch.Tensor:
    """Randomly zero out entries of ``tensor`` so that roughly ``density`` of them survive.

    Args:
        tensor: the tensor to sparsify; it is never modified in place.
        density: probability with which each entry is kept. Values that are
            not strictly below 1 leave the tensor untouched.
        sparse_repr: when ``True`` the result is converted with ``to_sparse()``;
            otherwise a dense tensor with explicit zeros is returned.

    Returns:
        The input tensor itself when no sparsification is requested, otherwise
        a new tensor in which the dropped entries are zero.
    """
    if not (density < 1):
        return tensor

    # One uniform sample per entry decides whether that entry is kept.
    noise = torch.rand_like(tensor)

    if sparse_repr:
        pruned = tensor.clone()
        pruned[noise > density] = 0
        return pruned.to_sparse()

    keep_mask = (noise <= density).to(dtype=tensor.dtype).to(device=tensor.device)
    return tensor * keep_mask


class ESNCell(torch.nn.Module):
    """Single-step leaky Echo State Network cell with fixed random weights.

    The input matrix ``W_in``, the recurrent matrix ``W_hat`` and the optional
    bias ``b`` are drawn uniformly in ``[-1, 1)``, optionally sparsified,
    rescaled, and stored as buffers (they are not trainable parameters).
    """
    use_bias: Final[bool]
    leaking_rate: Final[float]

    # TODO: density -> sparsity
    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 scale_rec: float = 0.9,
                 density_rec: float = 1.0,
                 scale_in: float = 1.0,
                 density_in: float = 1.0,
                 scale_bias: float = 1.0,
                 leaking_rate: float = 1.0,
                 rec_rescaling_method: str = 'specrad',  # Either "norm" or "specrad"
                 ):
        """Build the random reservoir.

        Args:
            input_size: number of input features.
            output_size: number of reservoir units (size of the hidden state).
            bias: whether a bias vector is created and used.
            scale_rec: target value for the norm / spectral radius of ``W_hat``.
            density_rec: fraction of entries of ``W_hat`` kept non-zero, in ``[0, 1]``.
            scale_in: multiplicative scaling applied to ``W_in``.
            density_in: fraction of entries of ``W_in`` kept non-zero, in ``[0, 1]``.
            scale_bias: multiplicative scaling applied to the bias.
            leaking_rate: leaky-integration coefficient, in ``[0, 1]``.
            rec_rescaling_method: ``'norm'`` or ``'specrad'``, see
                :meth:`rescale_contractivity`.

        Raises:
            AssertionError: if an argument is outside its valid range.
        """
        super().__init__()

        assert rec_rescaling_method in ['norm', 'specrad'], f"Invalid value for {rec_rescaling_method=}"
        assert 0 <= density_rec <= 1, f"Invalid range for {density_rec=}"
        assert 0 <= density_in <= 1, f"Invalid range for {density_in=}"
        assert 0 <= leaking_rate <= 1, f"Invalid range for {leaking_rate=}"

        self.use_bias = bias
        self.leaking_rate = leaking_rate

        # Reservoir: uniform values in [-1, 1)
        W_in = torch.rand((output_size, input_size)) * 2 - 1
        W_hat = torch.rand((output_size, output_size)) * 2 - 1
        b = torch.rand(output_size) * 2 - 1 if self.use_bias else None

        # Sparsity
        W_in = to_sparse(W_in, density_in)
        W_hat = to_sparse(W_hat, density_rec)

        # Scale W_in
        W_in = scale_in * W_in

        # Scale W_hat
        W_hat = self.rescale_contractivity(W_hat, scale_rec, rec_rescaling_method)

        # Assign as buffers so they follow the module across devices and are
        # saved in the state dict without being trained.
        self.register_buffer('W_in', W_in)
        self.register_buffer('W_hat', W_hat)
        if self.use_bias:
            self.register_buffer('b', scale_bias * b)

    def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> torch.Tensor:
        """Advance the reservoir state by one time step.

        Computes ``h = (1 - a) * hidden + a * tanh(W_in x + W_hat hidden [+ b])``
        where ``a`` is the leaking rate.

        input: (batch, input_size)
        hidden: (batch, hidden_size)

        output: (batch, hidden_size)
        """
        # Work in (features, batch) layout so the weight matrices multiply from the left.
        pre_activation = self.W_in @ input.t() + self.W_hat @ hidden.t()
        if self.use_bias:
            pre_activation = pre_activation + self.b.view(-1, 1)
        h_tilde = torch.tanh(pre_activation).t()

        return (1 - self.leaking_rate) * hidden + self.leaking_rate * h_tilde

    @staticmethod
    def rescale_contractivity(W, coeff, rescaling_method):
        """Rescale ``W`` so that its norm or spectral radius equals ``coeff``.

        Args:
            W: square recurrent matrix.
            coeff: target value of the chosen quantity.
            rescaling_method: ``'norm'`` divides by ``W.norm()`` (the Frobenius
                norm); ``'specrad'`` divides by the spectral radius, i.e. the
                largest eigenvalue modulus.

        Returns:
            The rescaled matrix. If the chosen quantity is zero (for example an
            all-zero matrix) ``W`` is returned unchanged, since no finite
            rescaling exists and dividing would only produce NaN/inf entries.

        Raises:
            ValueError: if ``rescaling_method`` is not one of the two supported values.
        """
        if rescaling_method == 'norm':
            current = W.norm()
        elif rescaling_method == 'specrad':
            # eigvals returns complex eigenvalues; abs() is their modulus, so
            # the maximum is the true spectral radius even for complex pairs.
            current = torch.linalg.eigvals(W).abs().max()
        else:
            raise ValueError("Invalid rescaling method used (must be either 'norm' or 'specrad')")

        if current == 0:
            return W
        return W * coeff / current


class ESNMultiringCell(torch.jit.ScriptModule):
    """Leaky reservoir cell whose recurrent connection is a scaled permutation of the state.

    Instead of a dense recurrent matrix, the previous state is re-indexed by
    the ``ring`` buffer and multiplied by ``scale_rec``. The input matrix
    ``W_in`` and the optional bias ``b`` are random and stored as buffers.
    """
    scale_rec: Final[float]
    use_bias: Final[bool]
    leaking_rate: Final[float]

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 scale_rec: float = 0.9,
                 scale_in: float = 1.0,
                 density_in: float = 1.0,
                 scale_bias: float = 1.0,
                 leaking_rate: float = 1.0,
                 topology: str = 'multiring'):  # Either "ring" or "multiring"
        """Build the cell.

        Args:
            input_size: number of input features.
            output_size: number of reservoir units.
            bias: whether a bias vector is created and used.
            scale_rec: factor multiplying the permuted previous state.
            scale_in: multiplicative scaling applied to ``W_in``.
            density_in: fraction of entries of ``W_in`` kept non-zero, in ``[0, 1]``.
            scale_bias: multiplicative scaling applied to the bias.
            leaking_rate: leaky-integration coefficient, in ``[0, 1]``.
            topology: ``'multiring'`` uses a random permutation of the units,
                ``'ring'`` uses the cyclic shift ``[1, 2, ..., n - 1, 0]``.
        """
        super(ESNMultiringCell, self).__init__()

        assert topology in ['ring', 'multiring'], f"Invalid value for {topology=}"
        assert 0 <= density_in <= 1, f"Invalid range for {density_in=}"
        assert 0 <= leaking_rate <= 1, f"Invalid range for {leaking_rate=}"

        self.use_bias = bias
        self.leaking_rate = leaking_rate

        # Random draws are made in a fixed order (input weights, bias, sparsity
        # mask, permutation) so that seeding gives reproducible cells.
        in_weights = torch.rand((output_size, input_size)) * 2 - 1
        bias_vec = torch.rand(output_size) * 2 - 1 if self.use_bias else None
        in_weights = to_sparse(in_weights, density_in)

        ring: torch.Tensor
        if topology == 'ring':
            shifted = torch.arange(1, output_size)
            wrap_around = torch.tensor([0])
            ring = torch.cat([shifted, wrap_around])
        elif topology == 'multiring':
            ring = torch.randperm(output_size)
        else:
            assert False

        # No need to compute the spectral radius for ring or multiring topologies since
        # the spectral radius is equal to the value of the nonzero elements, so the
        # recurrent scaling is simply stored and applied in forward().
        self.scale_rec = scale_rec

        self.register_buffer('W_in', scale_in * in_weights)
        self.register_buffer('ring', ring)
        if self.use_bias:
            self.register_buffer('b', scale_bias * bias_vec)

    @torch.jit.script_method
    def forward(self, input, hidden):
        """
        Advance the reservoir state by one time step.

        input: (batch, input_size)
        hidden: (batch, hidden_size)

        output: (batch, hidden_size)
        """
        # (features, batch) layout: indexing the transposed state with `ring`
        # permutes the reservoir units.
        pre_activation = self.W_in @ input.t() + self.scale_rec * hidden.t()[self.ring]
        if self.use_bias:
            pre_activation = pre_activation + self.b.view(-1, 1)
        h_tilde = torch.tanh(pre_activation).t()

        return (1 - self.leaking_rate) * hidden + self.leaking_rate * h_tilde


class ESNBase(torch.nn.Module):
    """Stack of reservoir cells unrolled over a time-major input sequence.

    Each layer owns one forward cell and, for bidirectional networks, one
    backward cell. The per-step states of a layer form the input sequence of
    the next layer.
    """
    bidirectional: Final[bool]

    def __init__(self, input_size, reservoir_size, cells: ModuleList, cells_bw: Optional[ModuleList], num_layers=1):
        """

        Args:
            input_size: number of features of each input step.
            reservoir_size: size of the hidden state of every cell; used to
                build the default (zero) initial state.
            cells: forward-direction cells, one per layer, ordered from the
                layer closest to the input to the last one.
            cells_bw: backward-direction cells in the same layer order, or
                ``None`` for a unidirectional network.
            num_layers: number of stacked layers; used to size the default
                initial state.
        """
        super().__init__()

        self.input_size = input_size
        self.reservoir_size = reservoir_size
        self.num_layers = num_layers
        self.bidirectional = cells_bw is not None
        self.num_directions = 2 if self.bidirectional else 1
        self._cells = cells
        self._cells_bw = cells_bw

    def forward(self, input, h_0: Optional[torch.Tensor] = None):
        r"""Forwards the given sequence through the network.

        Args:
            input: tensor of shape `(seq_len, batch, input_size)`: tensor containing the features
              of the input sequence.
            h_0: tensor of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
              containing the initial hidden state for each element in the batch.
              Defaults to zero if not provided. If the ESN is bidirectional,
              num_directions should be 2, else it should be 1.

        Returns:
            output, h_n

            - **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor
              containing the output features (`h_t`) from the last layer of the RNN,
              for each `t`.

              The directions can be separated
              using ``output.view(seq_len, batch, num_directions, hidden_size)``,
              with forward and backward being direction `0` and `1` respectively.
            - **h_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
              containing the final hidden state of every layer and direction
              (after the last step for the forward direction, after the first
              step for the backward direction).

              Like *output*, the layers can be separated using
              ``h_n.view(num_layers, num_directions, batch, hidden_size)``.

        """

        if self.bidirectional:
            return self._forward_bi(input, h_0)
        else:
            return self._forward_mono(input, h_0)

    @staticmethod
    def _unroll(cell, steps, h):
        """Feed ``steps`` to ``cell`` one at a time; return the visited states and the last one."""
        states = []  # list of (batch, hidden_size)
        for x_t in steps:
            # Call the module (not .forward) so registered hooks are honoured.
            h = cell(x_t, h)
            states.append(h)
        return states, h

    def _forward_mono(self, input, h_0: Optional[torch.Tensor] = None):
        """Unidirectional pass; see :meth:`forward` for shapes."""
        batch = input.shape[1]

        if h_0 is None:
            h_0 = input.new_zeros((self.num_layers * self.num_directions, batch, self.reservoir_size))

        # Separate the layers to avoid in-place operations
        h_l = list(h_0.unbind())

        next_layer_input = input
        for lyr, cell in enumerate(self._cells):
            layer_outputs, last_h = self._unroll(cell, next_layer_input, h_l[lyr])
            h_l[lyr] = last_h
            next_layer_input = torch.stack(layer_outputs)

        h_n = torch.stack(h_l)

        return next_layer_input, h_n

    def _forward_bi(self, input, h_0: Optional[torch.Tensor] = None):
        """Bidirectional pass; see :meth:`forward` for shapes."""
        # Implementation of Deep-Bi type 2
        # https://github.com/pytorch/pytorch/issues/4930#issuecomment-361851298

        batch = input.shape[1]

        if h_0 is None:
            h_0 = input.new_zeros((self.num_layers * self.num_directions, batch, self.reservoir_size))

        # Separate the layers to avoid in-place operations
        # h_l = [l0d0, l0d1, l1d0, l1d1, ...]
        h_l = list(h_0.unbind())

        next_layer_input = input
        for lyr, (cell_fw, cell_bw) in enumerate(zip(self._cells, self._cells_bw)):
            fw_slot = 2 * lyr
            bw_slot = 2 * lyr + 1

            outputs_fw, last_fw = self._unroll(cell_fw, next_layer_input, h_l[fw_slot])
            h_l[fw_slot] = last_fw

            outputs_bw, last_bw = self._unroll(cell_bw, reversed(next_layer_input), h_l[bw_slot])
            h_l[bw_slot] = last_bw
            # The backward states were produced from the last step to the first;
            # put them back in chronological order so that position t of the
            # output pairs the forward and backward states of the same step.
            outputs_bw.reverse()

            # (seq_len, batch, self.num_directions * self.reservoir_size)
            next_layer_input = torch.cat((
                torch.stack(outputs_fw),
                torch.stack(outputs_bw)
            ), dim=2)

        h_n = torch.stack(h_l)

        return next_layer_input, h_n


class LeakyESN(ESNBase):

    def __init__(self, input_size: int, output_size: int, num_layers=1, bias=True, bidirectional=False,
                 scale_rec=0.9, density=1.0, scale_in=1.0, density_in=1.0, scale_bias=1.0, leaking_rate=1.0,
                 rescaling_method='specrad'):
        r"""A multi-layer Leaky Echo State Network.

        Args:
            input_size: The number of expected features in the input `x`
            output_size: The number of features in the hidden state `h`
            num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
                would mean stacking two ESNs together to form a `stacked ESN`,
                with the second ESN taking in outputs of the first ESN and
                computing the final results. Default: 1
            bias: If ``False``, then the layer does not use bias weights. Default: ``True``
            bidirectional: If ``True``, becomes a bidirectional ESN. Default: ``False``
            scale_rec: Scaling of the recurrent connection matrix. The actual rescaling
                value depends on the ``rescaling_method`` parameter. Default: 0.9
            density: Density of the recurrent connection matrix. Default: 1.0
            scale_in: Scaling of the input connections. Default: 1.0
            density_in: Density of the input connection matrix. Default: 1.0
            scale_bias: Scaling of the bias values. Default: 1.0
            leaking_rate: Leaking rate. Default: 1.0
            rescaling_method: The method for rescaling the recurrent matrix. It can be either
                ``'norm'`` or ``'specrad'``. If ``'norm'``, then the Frobenius norm is rescaled,
                :math:`\left\|W\right\|_F = \text{scale\_rec}`. If ``'specrad'``, then
                :math:`\rho(W) = \text{scale\_rec}`. Default: ``'specrad'``

        Shape:
            Input1: :math:`(L, N, H_{in})` tensor containing input features where
                :math:`H_{in}=\text{input\_size}` and `L` represents a sequence length.
            Input2: :math:`(S, N, H_{out})` tensor containing the initial hidden state
                for each element in the batch, where :math:`H_{out}=\text{output\_size}`
                and :math:`S=\text{num\_layers} * \text{num\_directions}`.
                Defaults to zero if not provided.
                If the ESN is bidirectional, num_directions should be 2, else it should be 1.
            Output1: :math:`(L, N, H_{all})` where
                :math:`H_{all}=\text{num\_directions} * \text{output\_size}`
            Output2: :math:`(S, N, H_{out})` tensor containing the next hidden state
                for each element in the batch
        """
        num_directions = 2 if bidirectional else 1

        def build_stack() -> ModuleList:
            # Layers after the first one read the concatenated outputs of all
            # directions of the previous layer.
            return ModuleList([
                ESNCell(input_size=input_size if lyr == 0 else num_directions * output_size,
                        output_size=output_size,
                        bias=bias,
                        scale_rec=scale_rec,
                        density_rec=density,
                        scale_in=scale_in,
                        density_in=density_in,
                        scale_bias=scale_bias,
                        leaking_rate=leaking_rate,
                        rec_rescaling_method=rescaling_method)
                for lyr in range(num_layers)
            ])

        # Forward cells are created before backward ones so that seeded runs
        # draw their random weights in a fixed order.
        cells = build_stack()
        cells_bw = build_stack() if bidirectional else None

        super().__init__(input_size, output_size, cells, cells_bw, num_layers)
class MultiringESN(ESNBase):

    def __init__(self, input_size: int, output_size: int, num_layers=1, bias=True, bidirectional=False,
                 scale_rec=1.0, scale_in=1.0, density_in=1.0, scale_bias=1.0, leaking_rate=1.0,
                 topology='multiring'):
        r"""A multi-layer Multiring Echo State Network.

        Args:
            input_size: The number of expected features in the input `x`
            output_size: The number of features in the hidden state `h`
            num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
                would mean stacking two ESNs together to form a `stacked ESN`,
                with the second ESN taking in outputs of the first ESN and
                computing the final results. Default: 1
            bias: If ``False``, then the layer does not use bias weights. Default: ``True``
            bidirectional: If ``True``, becomes a bidirectional ESN. Default: ``False``
            scale_rec: Scaling of the recurrent connections. The spectral radius of the
                implicit recurrent matrices is equal to this parameter, i.e.,
                :math:`\rho(W)=\text{scale\_rec}`. Default: 1.0.
            scale_in: Scaling of the input connections. Default: 1.0
            density_in: Density of the input connection matrix. Default: 1.0
            scale_bias: Scaling of the bias values. Default: 1.0
            leaking_rate: Leaking rate. Default: 1.0
            topology: The topology of the reservoir. It can be either ``'ring'`` or
                ``'multiring'``, where `ring` corresponds to a shift operator on the state
                and `multiring` to a permutation of the state. Default: ``'multiring'``

        Shape:
            Input1: :math:`(L, N, H_{in})` tensor containing input features where
                :math:`H_{in}=\text{input\_size}` and `L` represents a sequence length.
            Input2: :math:`(S, N, H_{out})` tensor containing the initial hidden state
                for each element in the batch, where :math:`H_{out}=\text{output\_size}`
                and :math:`S=\text{num\_layers} * \text{num\_directions}`.
                Defaults to zero if not provided.
                If the ESN is bidirectional, num_directions should be 2, else it should be 1.
            Output1: :math:`(L, N, H_{all})` where
                :math:`H_{all}=\text{num\_directions} * \text{output\_size}`
            Output2: :math:`(S, N, H_{out})` tensor containing the next hidden state
                for each element in the batch
        """
        num_directions = 2 if bidirectional else 1

        def build_stack() -> ModuleList:
            # Layers after the first one read the concatenated outputs of all
            # directions of the previous layer.
            return ModuleList([
                ESNMultiringCell(input_size=input_size if lyr == 0 else num_directions * output_size,
                                 output_size=output_size,
                                 bias=bias,
                                 scale_rec=scale_rec,
                                 scale_in=scale_in,
                                 density_in=density_in,
                                 scale_bias=scale_bias,
                                 leaking_rate=leaking_rate,
                                 topology=topology)
                for lyr in range(num_layers)
            ])

        # Forward cells are created before backward ones so that seeded runs
        # draw their random weights in a fixed order.
        cells = build_stack()
        cells_bw = build_stack() if bidirectional else None

        super().__init__(input_size, output_size, cells, cells_bw, num_layers)