Loss

This module provides pre-built loss functions for networks whose outputs are spike trains. Standard PyTorch losses are also compatible.

class lava.lib.dl.slayer.loss.SpikeMax(moving_window=None, mode='probability', reduction='sum')

Spike max (negative log-likelihood, NLL) loss.

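L = \begin{cases}
    -\int_T
        {\bf 1}[\text{label}]^\top
        \log(\boldsymbol p(t))\,\text dt
        &\text{ if moving window}\\
    -{\bf 1}[\text{label}]^\top
    \log(\boldsymbol p) &\text{ otherwise}
\end{cases}

Here p is the vector of per-neuron confidences, computed according to mode, and 1[label] is the one-hot encoding of the label.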
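In the moving-window case the integral over T can be read as a sum, over time steps, of the per-step NLL. The sketch below assumes spikes arrive as a NumPy array of shape (batch, neurons, time), that p(t) is each neuron's spike count over the last moving_window steps normalized by the total count across neurons (our reading of the ‘probability’ mode), and that labels has shape (batch, time). The helpers moving_window_rate and spike_max_moving_window, and the eps guard against log(0), are hypothetical and not part of lava.lib.dl.slayer.

```python
import numpy as np


def moving_window_rate(spikes, window):
    """Causal moving-window spike rate.

    spikes: 0/1 array of shape (batch, neurons, time).
    Entry t of the result is the mean spike count over steps
    max(0, t - window + 1) .. t (fewer steps are available near t = 0).
    """
    csum = np.cumsum(spikes, axis=-1)
    shifted = np.zeros_like(csum)
    shifted[..., window:] = csum[..., :-window]  # cumulative count up to t - window
    counts = csum - shifted
    steps = np.minimum(np.arange(1, spikes.shape[-1] + 1), window)
    return counts / steps


def spike_max_moving_window(spikes, labels, window, eps=1e-8):
    """Hypothetical moving-window SpikeMax loss with 'sum' reduction.

    labels: integer array of shape (batch, time), one class per time step.
    """
    rate = moving_window_rate(spikes, window)
    # Assumed 'probability' mode: normalize rates across neurons at each t.
    p = rate / (rate.sum(axis=1, keepdims=True) + eps)
    batch, _, time = spikes.shape
    # p_label(t) for every sample and time step, shape (batch, time).
    p_label = p[np.arange(batch)[:, None], labels, np.arange(time)[None, :]]
    # The integral over T becomes a sum over discrete time steps.
    return -np.log(p_label + eps).sum()
```

A network that keeps firing only the labelled neuron drives the loss toward zero, while firing a non-labelled neuron is penalized at every step.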

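Note: the input is always collapsed along its spatial dimensions, i.e. all dimensions other than batch and time are flattened into a single neuron dimension.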

Parameters
  • moving_window (int) – size of the moving window. If not None, the label is assumed to be specified at every time step. Defaults to None.

  • mode (str) – confidence mode. One of ‘probability’|‘softmax’. Defaults to ‘probability’.

  • reduction (str) – loss reduction method. One of ‘sum’|‘mean’. Defaults to ‘sum’.
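A minimal NumPy sketch of the non-moving-window case follows. It assumes spikes are given as an array of shape (batch, *spatial, time), that the per-neuron rate is the mean spike count over time, and that ‘probability’ mode normalizes those rates to sum to one while ‘softmax’ applies a softmax to them. The function spike_max and its eps guard against log(0) are hypothetical illustrations, not the library's implementation.

```python
import numpy as np


def spike_max(spikes, labels, mode="probability", reduction="sum", eps=1e-8):
    """Hypothetical non-moving-window SpikeMax (NLL) loss.

    spikes: 0/1 array of shape (batch, *spatial, time).
    labels: integer array of shape (batch,), one class index per sample.
    """
    # Collapse spatial dimensions into one neuron axis: (batch, neurons, time).
    spikes = spikes.reshape(spikes.shape[0], -1, spikes.shape[-1])
    rate = spikes.mean(axis=-1)  # per-neuron spike rate, (batch, neurons)
    if mode == "probability":
        # Assumed: rates normalized to sum to one across neurons.
        p = rate / (rate.sum(axis=1, keepdims=True) + eps)
    elif mode == "softmax":
        z = np.exp(rate - rate.max(axis=1, keepdims=True))  # numerically stable
        p = z / z.sum(axis=1, keepdims=True)
    else:
        raise ValueError(f"unknown mode: {mode!r}")
    # -1[label]^T log(p): keep the log-confidence of the labelled neuron only.
    nll = -np.log(p[np.arange(len(labels)), labels] + eps)
    if reduction == "sum":
        return nll.sum()
    if reduction == "mean":
        return nll.mean()
    raise ValueError(f"unknown reduction: {reduction!r}")
```

With ‘probability’ mode, a sample whose spikes are split evenly between the labelled neuron and one other neuron contributes log 2 ≈ 0.693 to the summed loss, while a sample that fires only the labelled neuron contributes roughly zero.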

forward(input, label)

Forward computation of the loss.

class lava.lib.dl.slayer.loss.SpikeRate(true_rate, false_rate, moving_window=None, reduction='sum')

Spike rate loss.

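\hat{\boldsymbol r} = r_\text{true}\,{\bf 1}[\text{label}] +
    r_\text{false}\,(1 - {\bf 1}[\text{label}])

L = \begin{cases}
\frac{1}{2}\int_T
    ({\boldsymbol r}(t) - \hat{\boldsymbol r}(t))^\top
    ({\boldsymbol r}(t) - \hat{\boldsymbol r}(t))\,\text dt &\text{ if moving window}\\
\frac{1}{2}
    (\boldsymbol r - \hat{\boldsymbol r})^\top
    (\boldsymbol r - \hat{\boldsymbol r}) &\text{ otherwise}
\end{cases}

Here r is the vector of measured per-neuron spike rates and r̂ is the target rate vector built from the label.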

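Note: the input is always collapsed along its spatial dimensions, i.e. all dimensions other than batch and time are flattened into a single neuron dimension.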

Parameters
  • true_rate (float) – target spiking rate for the neuron of the labelled class.

  • false_rate (float) – target spiking rate for all other neurons.

  • moving_window (int) – size of the moving window. If not None, the label is assumed to be specified at every time step. Defaults to None.

  • reduction (str) – loss reduction method. One of ‘sum’|‘mean’. Defaults to ‘sum’.
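The target rate vector r̂ and the non-moving-window loss can be sketched in NumPy as follows. This assumes the rate is the mean spike count over time and that the per-neuron error is squared (a mean-squared-error style loss, which keeps the loss non-negative); spike_rate_loss is a hypothetical helper, not the library's code.

```python
import numpy as np


def spike_rate_loss(spikes, labels, true_rate, false_rate, reduction="sum"):
    """Hypothetical non-moving-window SpikeRate loss.

    spikes: 0/1 array of shape (batch, *spatial, time).
    labels: integer array of shape (batch,).
    """
    spikes = spikes.reshape(spikes.shape[0], -1, spikes.shape[-1])
    rate = spikes.mean(axis=-1)  # r, shape (batch, neurons)
    one_hot = np.eye(rate.shape[1])[labels]  # 1[label]
    # r_hat: true_rate for the labelled neuron, false_rate for every other one.
    target = true_rate * one_hot + false_rate * (1 - one_hot)
    err = 0.5 * (rate - target) ** 2  # assumed squared error per neuron
    if reduction == "sum":
        return err.sum()
    if reduction == "mean":
        return err.mean()
    raise ValueError(f"unknown reduction: {reduction!r}")
```

For example, with true_rate=0.2 and false_rate=0.1, an output that fires the labelled neuron 2 times in 10 steps and stays silent elsewhere matches the labelled target exactly and only pays for the two silent neurons: 0.5 × (0.1² + 0.1²) = 0.01.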

forward(input, label)

Forward computation of the loss.

class lava.lib.dl.slayer.loss.SpikeTime(time_constant=5, length=100, filter_order=1, reduction='sum')

Spike-time based loss, similar to the van Rossum distance between the output and desired spike trains.

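L = \int_0^T \left[\left(\varepsilon * (s - \hat{s})\right)(t)\right]^2\,\text{d}t

Here s is the output spike train, ŝ the desired spike train, ε the low-pass filter kernel, and * denotes convolution in time.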

Parameters
  • time_constant (int) – time constant of the low-pass filter. Defaults to 5.

  • length (int) – length of the low-pass filter kernel. Defaults to 100.

  • filter_order (int) – order of the low-pass filter. Defaults to 1.

  • reduction (str) – reduction applied to the squared error. One of ‘sum’|‘mean’. Defaults to ‘sum’.
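The following NumPy sketch illustrates the idea behind the formula. It assumes an exponential kernel ε[k] = exp(−k / time_constant) truncated to length samples, treats filter_order as that kernel applied repeatedly, and approximates the integral by a sum over time steps. low_pass and spike_time_loss are hypothetical helpers; the library's actual kernel shape and normalization may differ.

```python
import numpy as np


def low_pass(x, time_constant=5, length=100, order=1):
    """Causal low-pass filter along the last (time) axis.

    Assumed kernel: eps[k] = exp(-k / time_constant) for k = 0..length-1,
    applied `order` times in succession.
    """
    x = np.asarray(x, dtype=float)
    kernel = np.exp(-np.arange(length) / time_constant)
    steps = x.shape[-1]
    for _ in range(order):
        # Full convolution truncated to the first `steps` samples = causal filter.
        x = np.apply_along_axis(lambda v: np.convolve(v, kernel)[:steps], -1, x)
    return x


def spike_time_loss(output, desired, time_constant=5, length=100,
                    filter_order=1, reduction="sum"):
    """Hypothetical van Rossum-style loss between two spike trains."""
    diff = np.asarray(output, dtype=float) - np.asarray(desired, dtype=float)
    filtered = low_pass(diff, time_constant, length, filter_order)
    sq = filtered ** 2  # the integral over [0, T] becomes a sum over time steps
    if reduction == "sum":
        return sq.sum()
    if reduction == "mean":
        return sq.mean()
    raise ValueError(f"unknown reduction: {reduction!r}")
```

Because the filter is linear, filtering the difference s − ŝ gives the same result as subtracting the two filtered trains, so only one convolution per order is needed.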

forward(input, desired)

Forward computation of the loss.