Loss

This module provides pre-built loss functions for use with spike trains. Standard PyTorch losses are also compatible.

class lava.lib.dl.slayer.loss.SparsityEnforcer(max_rate=0.01, lam=1.0)

Event sparsity enforcement module. Penalizes event rates higher than a specified value.

Parameters:
  • max_rate (float, optional) – Rate above which events are penalized, by default 0.01.

  • lam (float, optional) – Scaling factor applied to the event rate loss, by default 1.0.

append(x)

Appends the sparsity loss of the input tensor to the accumulated loss.

Parameters:

x (torch.tensor) – Input tensor.

Return type:

None

clear()

Clears all accumulated sparsity loss.

Return type:

None

property loss: tensor

Accumulated sparsity loss.
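
A minimal usage sketch: the enforcer is cleared at the start of a batch, fed the spike output of each layer, and its accumulated loss is added to the task loss. The spike tensors, the placeholder task loss, and the (batch, neurons, time) shape below are illustrative assumptions.

    import torch
    import lava.lib.dl.slayer as slayer

    sparsity = slayer.loss.SparsityEnforcer(max_rate=0.01, lam=1.0)

    # Placeholder spike tensors standing in for two layer outputs,
    # assumed here to be shaped (batch, neurons, time).
    spikes_a = (torch.rand(8, 128, 100) < 0.05).float()
    spikes_b = (torch.rand(8, 64, 100) < 0.05).float()

    sparsity.clear()           # reset the accumulated loss for a new batch
    sparsity.append(spikes_a)  # gather sparsity loss from each layer output
    sparsity.append(spikes_b)

    task_loss = torch.tensor(0.0)           # placeholder for the task loss
    total_loss = task_loss + sparsity.loss  # add the event-rate penalty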

class lava.lib.dl.slayer.loss.SpikeMax(moving_window=None, mode='probability', reduction='sum')

Spike max (NLL) loss.

L = \begin{cases} -\int_T {\bf 1}[\text{label}]^\top \log(\boldsymbol{p}(t))\,\text{d}t &\text{ if moving window}\\ -{\bf 1}[\text{label}]^\top \log(\boldsymbol{p}) &\text{ otherwise} \end{cases}

Note: input is always collapsed in spatial dimension.

Parameters:
  • moving_window (int) – size of moving window. If not None, assumes label to be specified at every time step. Defaults to None.

  • mode (str) – confidence mode. One of ‘probability’|’softmax’. Defaults to ‘probability’.

  • reduction (str) – loss reduction method. One of ‘sum’|’mean’. Defaults to ‘sum’.

forward(input, label)

Forward computation of loss.
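
A minimal usage sketch, assuming the network output is a spike tensor of shape (batch, classes, time) and the label is an integer class index per sample; these shapes and the random tensors are assumptions for illustration only.

    import torch
    import lava.lib.dl.slayer as slayer

    criterion = slayer.loss.SpikeMax(mode='probability', reduction='sum')

    # Assumed spike output of shape (batch, classes, time) and
    # integer class labels of shape (batch,).
    output = (torch.rand(4, 10, 100) < 0.1).float()
    label = torch.randint(0, 10, (4,))

    loss = criterion(output, label)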

class lava.lib.dl.slayer.loss.SpikeMoid(moving_window=None, reduction='sum', alpha=1, theta=0)

SpikeMoid (BCE) loss.

\text{if moving window:} \quad p(t) = \sigma\left(\frac{r(t) - \theta}{\alpha}\right) \\ \text{otherwise:} \quad p = \sigma\left(\frac{r - \theta}{\alpha}\right)

where r denotes the spike rate computed over the time dimension.

L = \begin{cases} -\int_T \left[ \hat{y}(t) \cdot \log{p(t)} + (1 - \hat{y}(t)) \cdot \log{(1 - p(t))} \right]\,\text{d}t &\text{ if moving window}\\ -\left(\hat{y} \cdot \log{p} + (1 - \hat{y}) \cdot \log{(1 - p)}\right) &\text{ otherwise} \end{cases}

Note: input is always collapsed in the spatial dimension.

Parameters:
  • moving_window (int) – size of moving window. If not None, assumes label to be specified at every time step. Defaults to None.

  • reduction (str) – loss reduction method. One of ‘sum’|’mean’. Defaults to ‘sum’.

  • alpha (int) – Sigmoid temperature parameter. Defaults to 1.

  • theta (int) – Bias term for the logits. Defaults to 0.

forward(input, label)

Forward computation of loss.
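
A minimal sketch for a multi-label setting, assuming the output is a spike tensor of shape (batch, classes, time) and the label is a binary {0, 1} target per class; the shapes and tensors are illustrative assumptions, not guarantees of the API.

    import torch
    import lava.lib.dl.slayer as slayer

    criterion = slayer.loss.SpikeMoid(reduction='mean', alpha=1, theta=0)

    # Assumed spike output of shape (batch, classes, time) and a
    # binary multi-label target of shape (batch, classes).
    output = (torch.rand(4, 6, 100) < 0.1).float()
    target = torch.randint(0, 2, (4, 6)).float()

    loss = criterion(output, target)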

class lava.lib.dl.slayer.loss.SpikeRate(true_rate, false_rate, moving_window=None, reduction='sum')

Spike rate loss.

\hat{\boldsymbol r} = r_\text{true}\,{\bf 1}[\text{label}] + r_\text{false}\,(1 - {\bf 1}[\text{label}])

L = \begin{cases} \frac{1}{2}\int_T \left( {\boldsymbol r}(t) - \hat{\boldsymbol r}(t) \right)^\top {\bf 1}\,\text{d}t &\text{ if moving window}\\ \frac{1}{2}\left( \boldsymbol r - \hat{\boldsymbol r} \right)^\top {\bf 1} &\text{ otherwise} \end{cases}

Note: input is always collapsed in spatial dimension.

Parameters:
  • true_rate (float) – target spiking rate of the neuron corresponding to the true (labeled) class.

  • false_rate (float) – target spiking rate of the neurons corresponding to the false classes.

  • moving_window (int) – size of moving window. If not None, assumes label to be specified at every time step. Defaults to None.

  • reduction (str) – loss reduction method. One of ‘sum’|’mean’. Defaults to ‘sum’.

forward(input, label)

Forward computation of loss.
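
A minimal sketch of the typical rate-based setup: the neuron of the labeled class is driven toward a high spike rate and all other neurons toward a low one. The specific rates, shapes, and random tensors below are illustrative assumptions.

    import torch
    import lava.lib.dl.slayer as slayer

    # Target roughly 20% spiking for the correct class, 3% for the rest.
    criterion = slayer.loss.SpikeRate(true_rate=0.2, false_rate=0.03,
                                      reduction='sum')

    # Assumed spike output of shape (batch, classes, time) and
    # integer class labels of shape (batch,).
    output = torch.rand(4, 10, 100, requires_grad=True)
    label = torch.randint(0, 10, (4,))

    loss = criterion(output, label)
    loss.backward()  # gradients flow back to the (placeholder) output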

class lava.lib.dl.slayer.loss.SpikeTime(time_constant=5, length=100, filter_order=1, reduction='sum')

Spike-time based loss, similar to the van Rossum distance between the output and the desired spike train.

L = \int_0^T \left( \varepsilon * (s - \hat{s}) \right)(t)^2\, \text{d}t

Parameters:
  • time_constant (int) – time constant of low pass filter. Defaults to 5.

  • length (int) – length of low pass filter. Defaults to 100.

  • filter_order (int) – order of low pass filter. Defaults to 1.

  • reduction (str) – loss reduction method. One of ‘sum’|’mean’. Defaults to ‘sum’.

forward(input, desired)

Forward computation of loss.
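
A minimal sketch, assuming the network output and the desired spike train share the same (batch, neurons, time) shape; the shapes and random tensors are illustrative assumptions.

    import torch
    import lava.lib.dl.slayer as slayer

    criterion = slayer.loss.SpikeTime(time_constant=5, length=100,
                                      filter_order=1, reduction='sum')

    # Assumed output and desired spike trains of shape (batch, neurons, time).
    output = torch.rand(2, 32, 200, requires_grad=True)
    desired = (torch.rand(2, 32, 200) < 0.05).float()

    loss = criterion(output, desired)  # van Rossum style distance
    loss.backward()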