Dynamics, Neurons, and Spikes
Dynamics are the fundamental building blocks of neurons in lava.lib.dl.slayer. In this tutorial, we will go through some fundamental neuron dynamics that are built into SLAYER and illustrate how they can be combined to build a variety of neuron models. These dynamics are custom CUDA accelerated, fixed precision compatible, and PyTorch autograd compatible, with learnable decay(s) and persistent state(s).
Dynamics scaling: the internal dynamics computations are done in a fixed precision range. However, the parameters and states can be interpreted in a scaled representation. This has two advantages.
First, the dynamics are scaled so that the backpropagation gradients are usually in a suitable range for good gradient flow. This eliminates the need for unnatural scaling of surrogate gradients.
Second, the states and decays are in an intuitive range rather than in an abstract, scaled fixed point range.
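The scaling idea can be made concrete with a toy NumPy sketch. It is not SLAYER's CUDA kernel. It runs the leaky integrator y[t] = (1-\alpha)\,y[t-1] + x[t] twice: once in floating point, and once with integer-only arithmetic. In the integer version the decay and state are stored in units of 1/w_scale, with w_scale = 1<<12 as in the parameter setup below. The helper names `li_float` and `li_fixed` are hypothetical, and SLAYER's exact rounding and operation order may differ.

```python
import numpy as np

W_SCALE = 1 << 12  # fixed precision scale factor (4096), as in the tutorial


def li_float(x, alpha):
    """Reference leaky integrator in floating point."""
    y, out = 0.0, []
    for xt in x:
        y = (1.0 - alpha) * y + xt
        out.append(y)
    return np.array(out)


def li_fixed(x, alpha, w_scale=W_SCALE):
    """Same dynamics using only integer arithmetic on scaled values.

    Hypothetical illustration: decay and state are integers in units of
    1/w_scale, and each reported value is that integer divided by w_scale.
    """
    decay_int = int(round(alpha * w_scale))  # e.g. 0.1 -> 410
    y_int, out = 0, []
    for xt in x:
        x_int = int(round(xt * w_scale))
        # integer multiply, then integer division by the scale
        y_int = ((w_scale - decay_int) * y_int) // w_scale + x_int
        out.append(y_int / w_scale)  # scaled (human-readable) interpretation
    return np.array(out)


x = np.zeros(50)
x[[1, 10, 20]] = [1.0, 1.8, -3.0]
y_ref = li_float(x, alpha=0.1)
y_fix = li_fixed(x, alpha=0.1)
print(np.max(np.abs(y_ref - y_fix)))  # small quantization error
```

The integer path never touches floating point. Only the final division by w_scale turns the state back into the "intuitive range" described above.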
Notations:
The following notation is used for variables in this notebook.
Notation | Variable
---|---
x[t] | input
y[t] | state variable
\vartheta[t] | threshold
r[t] | refractory state
s[t] | spike flag
\alpha | leak parameter
\phi | phase shift
NOTE: this is a deep dive tutorial. It introduces
* neuron dynamics in SLAYER,
* how some of the available neurons are implemented,
* how one can use these dynamics to build a custom neuron model,
* real and complex spike mechanisms in SLAYER.
[1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import lava.lib.dl.slayer as slayer
Parameter setup
[2]:
device = torch.device('cpu')
# device = torch.device('cuda')
[3]:
time = np.arange(1000)
t = torch.FloatTensor(time).to(device)
input = torch.zeros_like(t)
for tt, ww in [[10, 1], [97, 1.8], [100, 1.6], [270, -3], [500, 0.5]]:
    input[tt] = ww  # weighted input spike of weight ww at time step tt
[4]:
scale = 1<<12 # scale factor for integer simulation
decay = torch.FloatTensor([0.1 * scale]).to(device)
initial_state = torch.FloatTensor([0]).to(device)
threshold = 1.5
Dynamics
1. Leaky Integrator
The leaky integrator is the basic first order neuron dynamics, represented by the following discrete system:
state dynamics: y[t] = (1-\alpha)\,y[t-1] + x[t]
spike dynamics: s[t] = y[t] \geq \vartheta
reset dynamics: y[t] = 0 \text{ if } s[t] = 1
Leaky integrator dynamics can be cascaded with other dynamics to form second order neurons, such as the CUBA neuron, and other higher order neurons.
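A plain NumPy reference for the three equations above may help when reading the SLAYER calls that follow. It is a sketch, not the library implementation, and `leaky_integrator` is a hypothetical helper. It assumes the recorded output y[t] is the state before the reset. That is consistent with the tutorial detecting spikes afterwards by comparing y against the threshold.

```python
import numpy as np


def leaky_integrator(x, alpha, threshold=None, y0=0.0):
    """Hypothetical reference: y[t] = (1 - alpha) * y[t-1] + x[t].

    If a threshold is given, s[t] = y[t] >= threshold and the state is reset
    to 0 after the spike. The returned y[t] is the pre-reset value.
    """
    y = np.zeros(len(x))
    s = np.zeros(len(x), dtype=bool)
    state = y0
    for t, xt in enumerate(x):
        state = (1.0 - alpha) * state + xt  # state dynamics
        y[t] = state
        if threshold is not None and state >= threshold:  # spike dynamics
            s[t] = True
            state = 0.0  # reset dynamics
    return y, s


x = np.zeros(10)
x[1], x[4] = 1.0, 1.8
y, s = leaky_integrator(x, alpha=0.1, threshold=1.5)
print(np.flatnonzero(s))  # [4]
```

The first input (1.0) decays below threshold on its own. The second input arrives on top of the residual 0.729 and pushes the state to 2.529, so the neuron spikes at step 4 and the state is zero at step 5.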
[5]:
y = slayer.neuron.dynamics.leaky_integrator.dynamics(input, decay=decay, state=initial_state, w_scale=scale, threshold=threshold)
A fully backpropagable spike mechanism is available as slayer.spike.Spike. It supports binary as well as graded spikes.
[6]:
sp = slayer.spike.Spike.apply(
    y,
    threshold,
    1,      # tau_rho: gradient relaxation constant
    1,      # scale_rho: gradient scale constant
    False,  # graded_spike: graded or binary spike
    0,      # voltage_last: voltage at t=-1
    1,      # scale: graded spike scale
)
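Conceptually, the forward pass of such a spike function is a Heaviside step. The backward pass replaces the step's derivative, which is zero almost everywhere, with a smooth surrogate. The NumPy sketch below uses an exponentially decaying surrogate, in the spirit of the exponential probability density used in the SLAYER paper. Here `tau_rho` sets its width and `scale_rho` its height.

The exact normalization inside `slayer.spike.Spike` is an assumption and may differ. `spike_forward` and `spike_surrogate_grad` are hypothetical names, and only the binary (non-graded) case is shown.

```python
import numpy as np


def spike_forward(v, threshold):
    """Binary spike: 1 where the state reaches the threshold, else 0."""
    return (v >= threshold).astype(float)


def spike_surrogate_grad(v, threshold, tau_rho=1.0, scale_rho=1.0):
    """Assumed surrogate d(spike)/dv: peaks at the threshold, decays away from it."""
    return scale_rho * np.exp(-np.abs(v - threshold) / tau_rho)


v = np.array([0.0, 1.0, 1.5, 2.0, 4.0])
print(spike_forward(v, 1.5))         # [0. 0. 1. 1. 1.]
print(spike_surrogate_grad(v, 1.5))  # largest (1.0) at v == threshold

# Chain rule: an upstream gradient is gated by the surrogate, not by the step.
upstream = np.ones_like(v)
grad_v = upstream * spike_surrogate_grad(v, 1.5, tau_rho=0.5)
```

States close to the threshold receive the largest gradient. That is what lets learning nudge a neuron that almost spiked.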
1.1 CUrrent BAsed (CUBA) leaky integrate and fire (LIF) neuron
A CUBA-LIF neuron is simply leaky integrator dynamics applied first to the current and then to the voltage. For easy usage, the CUBA neuron is available as ``slayer.neuron.cuba``.
[7]:
second_order_th = threshold * 5
current = slayer.neuron.dynamics.leaky_integrator.dynamics(input, decay=decay, state=initial_state, w_scale=scale)
voltage = slayer.neuron.dynamics.leaky_integrator.dynamics(current, decay=decay, state=initial_state, w_scale=scale, threshold=second_order_th)
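The impulse response shows why cascading two leaky integrators gives a second order (CUBA) response. A single leaky integrator has h[n] = (1-\alpha)^n. Convolving two of them gives h_2[n] = (n+1)(1-\alpha)^n, which rises before it decays instead of jumping instantly. The NumPy check below uses a minimal hypothetical `leaky_integrator` helper (no threshold) rather than the SLAYER kernel.

```python
import numpy as np


def leaky_integrator(x, alpha):
    """Hypothetical reference: y[t] = (1 - alpha) * y[t-1] + x[t], no spiking."""
    y, state = np.zeros(len(x)), 0.0
    for t, xt in enumerate(x):
        state = (1.0 - alpha) * state + xt
        y[t] = state
    return y


alpha = 0.1
impulse = np.zeros(60)
impulse[0] = 1.0

current = leaky_integrator(impulse, alpha)   # first order: (1 - alpha)^n
voltage = leaky_integrator(current, alpha)   # second order: (n + 1)(1 - alpha)^n

n = np.arange(60)
assert np.allclose(current, (1 - alpha) ** n)
assert np.allclose(voltage, (n + 1) * (1 - alpha) ** n)
print(int(np.argmax(voltage)))  # the voltage peaks several steps after the input
```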
1.2 Plot results
[8]:
fig,ax = plt.subplots(3, 1, figsize=(15, 7))
ax[0].plot(time, input.cpu(), label='weighted spikes')
ax[0].legend(loc='upper right')
ax[1].plot(time, y.cpu(), label='Leaky integrator dynamics')
ax[1].plot(time, threshold * np.ones_like(time), alpha=0.5, label='threshold')
ax[1].plot(time[sp>0], y[sp>0], '*', label='spike')
ax[1].legend(loc='upper right')
ax[2].plot(time, current.cpu(), label='CUBA current')
ax[2].plot(time, voltage.cpu(), label='CUBA voltage')
ax[2].plot(time, second_order_th * np.ones_like(time), alpha=0.5, label='threshold')
ax[2].plot(time[voltage>second_order_th], voltage[voltage>second_order_th].cpu(), '*', label='spike')
ax[2].legend(loc='upper right')
ax[-1].set_xlabel('time')
2. Adaptive Threshold Dynamics
Adaptive threshold dynamics provides first order threshold adaptation and refractory state adaptation dynamics.
threshold dynamics: \vartheta[t] = (1-\alpha_{\vartheta})\,(\vartheta[t-1] - \vartheta_0) + \vartheta_0
refractory dynamics: r[t] = (1-\alpha_r)\,r[t-1]
spike dynamics: s[t] = (x[t] - r[t]) \geq \vartheta[t]
post spike dynamics: r[t] = r[t] + 2\,\vartheta[t] and \vartheta[t] = \vartheta[t] + \vartheta_{\text{step}}
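The update order matters for these equations. The NumPy sketch below assumes one plausible ordering per time step:
1. decay the threshold toward \vartheta_0 and the refractory state toward zero,
2. check for a spike,
3. apply the post-spike increments.

It records the \vartheta[t] and r[t] used in the spike decision, before the increments. This matches how the plot further down compares (voltage - ref) against th. `adaptive_threshold` is a hypothetical helper, not the SLAYER function.

```python
import numpy as np


def adaptive_threshold(x, th0, alpha_th, alpha_r, th_step):
    """Hypothetical reference for threshold/refractory adaptation.

    Returns the spike flags plus the threshold and refractory states that were
    compared against x[t] at each step (values before the post-spike update).
    """
    th, r = th0, 0.0
    th_rec, r_rec = np.zeros(len(x)), np.zeros(len(x))
    s = np.zeros(len(x), dtype=bool)
    for t, xt in enumerate(x):
        th = (1.0 - alpha_th) * (th - th0) + th0  # threshold relaxes to th0
        r = (1.0 - alpha_r) * r                   # refractory state decays
        th_rec[t], r_rec[t] = th, r
        if xt - r >= th:                          # spike dynamics
            s[t] = True
            r = r + 2.0 * th                      # post-spike refractory kick
            th = th + th_step                     # post-spike threshold step
    return s, th_rec, r_rec


x = np.array([3.0, 3.0, 0.0, 3.0])
s, th, r = adaptive_threshold(x, th0=1.5, alpha_th=0.5, alpha_r=0.5, th_step=1.0)
print(np.flatnonzero(s))  # [0 3]
```

The identical input at step 1 does not fire. The refractory state (1.5) and the raised threshold (2.0) left by the spike at step 0 block it.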
2.1 Adaptive Leaky Integrate and Fire (ALIF) Neuron
When coupled with a second order leaky integrator (a cascade of two leaky integrators, as in the CUBA neuron), it results in a second order adaptive leaky integrate and fire neuron. For easy usage, the ALIF neuron is available as ``slayer.neuron.alif``.
[9]:
current = slayer.neuron.dynamics.leaky_integrator.dynamics(input, decay=decay, state=initial_state, w_scale=scale)
voltage = slayer.neuron.dynamics.leaky_integrator.dynamics(current, decay=decay, state=initial_state, w_scale=scale)
th, ref = slayer.neuron.dynamics.adaptive_threshold.dynamics(
    voltage,                                   # dynamics state
    ref_state=initial_state,                   # previous refractory state
    ref_decay=0.5*decay,                       # refractory decay
    th_state=initial_state + second_order_th,  # previous threshold state
    th_decay=decay,                            # threshold decay
    th_scale=0.5*second_order_th,              # threshold step
    th0=second_order_th,                       # threshold stable state
    w_scale=scale                              # fixed precision scaling
)
2.2 Plot results
[10]:
fig,ax = plt.subplots(2, 1, figsize=(15, 4.5))
ax[0].plot(time, input.cpu(), label='weighted spikes')
ax[0].legend(loc='upper right')
ax[1].plot(time, current.cpu(), label='ALIF current')
ax[1].plot(time, voltage.cpu(), label='ALIF voltage')
ax[1].plot(time, ref.cpu(), label='refractory dynamics')
ax[1].plot(time, th.cpu(), alpha=0.5, label='threshold')
ax[1].plot(time[(voltage-ref)>th], voltage[(voltage-ref)>th], '*', label='spike')
ax[1].legend(loc='upper right')
ax[-1].set_xlabel('time')
3. Resonator
The resonator implements first order complex leaky dynamics. The leak is, in general, complex and gives rise to oscillatory dynamics. The resonator dynamics are described by
\frac{\text dz}{\text dt} = (-\lambda + i\omega)\,z + \zeta
where \zeta \in \mathbb{C}^n is the complex input to the system and \lambda, \omega \in \mathbb{R}^+.
Discretization
The impulse response of resonator is
h(t) = e^{-\lambda t}\,e^{i\omega t}\,\mathcal H(t)
The corresponding discrete system has the impulse response
h[n] = e^{-\lambda n \Delta t}\,e^{i\omega n \Delta t}\,\mathcal H[n] = (e^{-\lambda \Delta t}\,e^{i\omega \Delta t})^n\,\mathcal H[n] = e^{-\lambda \Delta t}\,e^{i\omega \Delta t} h[n-1],\ h[0]=1
The equivalent discrete system is therefore
z[n] = e^{-\lambda \Delta t}\,e^{i\omega \Delta t}\,z[n-1] + \zeta[n]
The complex decay can be decoupled as
magnitude leak: \alpha = 1 - e^{-\lambda \Delta t}
phase shift: \phi = \omega\,\Delta t
decay matrix (acting on [\mathfrak{Re}(z), \mathfrak{Im}(z)]^\top): (1-\alpha)\begin{bmatrix} \cos\phi &-\sin\phi \\ \sin\phi &\cos\phi\end{bmatrix}
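The NumPy snippet below is a sanity check of the decoupling. Multiplying the state by (1-\alpha)e^{i\phi} should be the same as applying the real 2x2 decay matrix to [\mathfrak{Re}(z), \mathfrak{Im}(z)]. It uses the tutorial's own parameters \alpha = 0.03 and \phi = 2\pi/25, so one oscillation period lasts 25 time steps. It also recovers \lambda\,\Delta t from the magnitude leak. This is an illustration only and does not call SLAYER.

```python
import numpy as np

alpha = 0.03                 # magnitude leak, as in the RF example below
phi = 2 * np.pi / 25         # phase shift per step, i.e. omega * dt

# Complex form: z[n] = (1 - alpha) * exp(i*phi) * z[n-1] + zeta[n]
decay_c = (1 - alpha) * np.exp(1j * phi)

# Real form: the same decay as a scaled rotation matrix acting on [Re z, Im z]
decay_m = (1 - alpha) * np.array([[np.cos(phi), -np.sin(phi)],
                                  [np.sin(phi),  np.cos(phi)]])

z = 0.7 - 1.2j
z_next_c = decay_c * z
z_next_m = decay_m @ np.array([z.real, z.imag])
assert np.allclose([z_next_c.real, z_next_c.imag], z_next_m)

lambda_dt = -np.log(1 - alpha)      # from alpha = 1 - exp(-lambda * dt)
period_steps = 2 * np.pi / phi      # time steps per oscillation
print(round(lambda_dt, 4), round(period_steps, 1))  # 0.0305 25.0
```

This is why the RF code below only needs the two real numbers (scale-alpha)*sin(phi) and (scale-alpha)*cos(phi).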
3.1 Resonate and Fire (RF) neuron
A resonate and fire neuron spikes when its internal state crosses the positive real axis at a point beyond the neuron's threshold, \vartheta. Formally, it can be stated as follows.
f_s(z) = \mathcal{H}(\mathfrak{Re}(z) - \vartheta)\,\delta(\mathfrak{Im}(z))
or equivalently
f_s(z) = \mathcal{H}(|z| - \vartheta)\,\delta(\arg(z))
For easy usage, the RF neuron is available as ``slayer.neuron.rf``.
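In discrete time the state rarely lands exactly on the real axis, so a phase-crossing spike has to be detected as a sign change. The NumPy sketch below assumes one plausible discretization. A spike is emitted at step t when both of these hold:
* \mathfrak{Im}(z) goes from negative at t-1 to non-negative at t (a counter-clockwise crossing of the positive real axis),
* \mathfrak{Re}(z[t]) \geq \vartheta.

`resonator` and `phase_spikes` are hypothetical helpers. SLAYER's `slayer.spike.complex.Spike` may implement the crossing test differently.

```python
import numpy as np


def resonator(re_in, im_in, alpha, phi):
    """Hypothetical reference: z[t] = (1 - alpha) e^{i phi} z[t-1] + zeta[t]."""
    decay = (1 - alpha) * np.exp(1j * phi)
    z, state = np.zeros(len(re_in), dtype=complex), 0j
    for t in range(len(re_in)):
        state = decay * state + (re_in[t] + 1j * im_in[t])
        z[t] = state
    return z


def phase_spikes(z, threshold, im_last=0.0):
    """Spike where the state crosses zero phase with Re(z) >= threshold."""
    im_prev = np.concatenate([[im_last], z.imag[:-1]])
    crossed = (im_prev < 0) & (z.imag >= 0)
    return crossed & (z.real >= threshold)


re_in = np.zeros(12)
im_in = np.zeros(12)
re_in[0], im_in[0] = 2.0, -0.5    # start slightly below the real axis
z = resonator(re_in, im_in, alpha=0.0, phi=2 * np.pi / 8)
print(np.flatnonzero(phase_spikes(z, threshold=1.0)))  # [1 9]
```

With no leak and \phi = 2\pi/8, the state revisits the positive real axis every 8 steps. It therefore fires at steps 1 and 9, and never fires once the threshold exceeds the real value at the crossing (about 1.77).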
[11]:
re_input = input
im_input = 2*torch.randn_like(re_input) * (re_input > 0)
alpha = torch.FloatTensor([0.03 * scale]).to(device)
phi = 2 * np.pi /25
sin_decay = (scale-alpha) * np.sin(phi)
cos_decay = (scale-alpha) * np.cos(phi)
re, im = slayer.neuron.dynamics.resonator.dynamics(
    re_input, im_input,
    sin_decay, cos_decay,
    real_state=initial_state,
    imag_state=initial_state,
    w_scale=scale,
)
A fully backpropagable complex spike mechanism, supporting the phase spiking mechanism of the RF neuron, is available as slayer.spike.complex.Spike. It supports binary as well as graded spikes.
spike dynamics: |z[t]| \geq \vartheta and \arg(z[t]) = 0
[12]:
sp = slayer.spike.complex.Spike.apply(
    re, im,
    threshold,
    1,      # tau_rho: gradient relaxation constant
    1,      # scale_rho: gradient scale constant
    False,  # graded_spike: graded or binary spike
    0,      # voltage_last: voltage at t=-1
    1,      # scale: graded spike scale
)
3.2 Izhikevich Resonate and Fire (RF-Iz) neuron
RF-Izhikevich [1] neuron dynamics are the same as those of the basic RF neuron; however, the firing and reset mechanisms are different. The neuron fires when the imaginary state is above the threshold, and the real state is reset to zero after the spike.
spike dynamics: \mathfrak{Im}(z[t]) \geq \vartheta
post spike dynamics: \mathfrak{Re}(z[t]) = 0
For easy usage, the RF-Izhikevich neuron is available as ``slayer.neuron.rf_iz``.
[1] Eugene M. Izhikevich, "Resonate-and-fire neurons," Neural Networks, vol. 14, no. 6-7, pp. 883-894, 2001.
[13]:
iz_re, iz_im = slayer.neuron.dynamics.resonator.dynamics(
    re_input, im_input,
    sin_decay, cos_decay,
    real_state=initial_state,
    imag_state=initial_state,
    w_scale=scale,
    threshold=threshold,
)
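For comparison with the phase-crossing rule, the sketch below implements the Izhikevich-style rule as stated above: spike when \mathfrak{Im}(z[t]) \geq \vartheta, then set \mathfrak{Re}(z) to zero. It is a hypothetical NumPy reference (the function name `rf_iz` is ours), and it records the state before the reset. With \phi = \pi/3 and no leak, the reset removes part of the state, so the magnitude drops from 2 to \sqrt{3} after the first spike.

```python
import numpy as np


def rf_iz(re_in, im_in, alpha, phi, threshold):
    """Hypothetical RF-Izhikevich reference.

    Spike when Im(z) >= threshold; after a spike, Re(z) is reset to zero.
    Returned states are the values before the reset.
    """
    decay = (1 - alpha) * np.exp(1j * phi)
    z = np.zeros(len(re_in), dtype=complex)
    s = np.zeros(len(re_in), dtype=bool)
    state = 0j
    for t in range(len(re_in)):
        state = decay * state + (re_in[t] + 1j * im_in[t])
        z[t] = state
        if state.imag >= threshold:
            s[t] = True
            state = 1j * state.imag  # post-spike: Re(z) = 0
    return z, s


re_in, im_in = np.zeros(10), np.zeros(10)
re_in[0] = 2.0
z, s = rf_iz(re_in, im_in, alpha=0.0, phi=np.pi / 3, threshold=1.5)
print(np.flatnonzero(s))  # [1 7]
```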
3.3 Second Order RF neuron
Two resonator dynamics can be cascaded to produce Gammatone-like second order resonator dynamics. In principle, they could also be combined with a leaky integrator for even more exotic neuron models.
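The Gammatone-like character of the cascade shows up in the impulse response. Let d = (1-\alpha)e^{i\phi}. One resonator has h[n] = d^n, and two identical ones in series have h_2[n] = (n+1)\,d^n. That is a carrier at \phi modulated by the envelope (n+1)(1-\alpha)^n, which rises before it decays. The NumPy check below uses a hypothetical `resonator` helper with the tutorial's \alpha = 0.03 and \phi = 2\pi/25. It is a sketch, not the SLAYER kernel.

```python
import numpy as np


def resonator(zeta, alpha, phi):
    """Hypothetical reference: z[n] = (1 - alpha) e^{i phi} z[n-1] + zeta[n]."""
    decay = (1 - alpha) * np.exp(1j * phi)
    z, state = np.zeros(len(zeta), dtype=complex), 0j
    for n, zn in enumerate(zeta):
        state = decay * state + zn
        z[n] = state
    return z


alpha, phi = 0.03, 2 * np.pi / 25
impulse = np.zeros(200, dtype=complex)
impulse[0] = 1.0

z1 = resonator(impulse, alpha, phi)   # first order: d^n
z2 = resonator(z1, alpha, phi)        # second order: (n + 1) d^n

n = np.arange(200)
d = (1 - alpha) * np.exp(1j * phi)
assert np.allclose(z1, d ** n)
assert np.allclose(z2, (n + 1) * d ** n)

envelope = np.abs(z2)                 # (n + 1)(1 - alpha)^n
print(int(np.argmax(envelope)))       # 32
```

The envelope keeps growing while (n+2)(1-\alpha) > n+1, i.e. up to n = 32 for \alpha = 0.03. This delayed peak is also why the cell below uses a larger threshold for the second order neuron.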
[14]:
second_order_th = threshold * 15
re_0, im_0 = slayer.neuron.dynamics.resonator.dynamics(
    re_input, im_input,
    sin_decay, cos_decay,
    real_state=initial_state,
    imag_state=initial_state,
    w_scale=scale,
)
re_1, im_1 = slayer.neuron.dynamics.resonator.dynamics(
    re_0, im_0,
    sin_decay, cos_decay,
    real_state=initial_state,
    imag_state=initial_state,
    w_scale=scale,
)
sp_1 = slayer.spike.complex.Spike.apply(re_1, im_1, second_order_th, 1, 1, False, 0, 1)
3.4 Plot results
[15]:
fig,ax = plt.subplots(4, 1, figsize=(15, 9))
ax[0].plot(time, re_input.cpu(), label='real weighted spikes')
ax[0].plot(time, im_input.cpu(), label='imag weighted spikes')
ax[0].legend(loc='upper right')
ax[1].plot(time, re.cpu(), label='RF real state')
ax[1].plot(time, im.cpu(), label='RF imag state')
ax[1].plot(time, threshold * np.ones_like(time), alpha=0.5, label='threshold')
ax[1].plot(time[sp>0], re[sp>0], '*', label='spike')
ax[1].legend(loc='upper right')
ax[2].plot(time, iz_re.cpu(), label='RF Izhikevich real state')
ax[2].plot(time, iz_im.cpu(), label='RF Izhikevich imag state')
ax[2].plot(time, threshold * np.ones_like(time), alpha=0.5, label='threshold')
ax[2].plot(time[iz_im > threshold], iz_im[iz_im > threshold], '*', label='spike')
ax[2].legend(loc='upper right')
ax[3].plot(time, re_1.cpu(), label='RF-Or2 real state')
ax[3].plot(time, im_1.cpu(), label='RF-Or2 imag state')
ax[3].plot(time, second_order_th * np.ones_like(time), alpha=0.5, label='threshold')
ax[3].plot(time[sp_1>0], re_1[sp_1>0], '*', label='spike')
ax[3].legend(loc='upper right')
ax[-1].set_xlabel('time')
[16]:
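def plot_phase_region(ax, threshold, sin_decay):
    # Shade the wedge just above the positive real axis, beyond the threshold,
    # roughly where the state lands on the step it crosses zero phase (RF spike region).
    xlims = ax.get_xlim()
    ylims = ax.get_ylim()
    xx = np.array([threshold, threshold+500])
    yy = xx * sin_decay
    ax.fill_between(xx, yy, color='green', alpha=0.1)
    ax.set_xlim(xlims)
    ax.set_ylim(ylims)

def plot_iz_region(ax, threshold):
    # Shade the band Im(z) >= threshold where an Izhikevich-style spike fires.
    ylims = ax.get_ylim()
    ax.axhspan(threshold, threshold+50, color='green', alpha=0.1)
    ax.set_ylim(ylims)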
[17]:
fig, ax = plt.subplots(1, 3, figsize=(10, 3))
ax[0].plot(re.cpu(), im.cpu())
ax[0].plot(re[sp>0], im[sp>0], '*', label='spike')
plot_phase_region(ax[0], threshold, sin_decay.item()/scale)
ax[0].set_xlabel(r'$\mathfrak{Re}(z)$')
ax[0].set_ylabel(r'$\mathfrak{Im}(z)$')
ax[0].legend(loc='lower left')
ax[0].set_title('RF')
ax[1].plot(iz_re.cpu(), iz_im.cpu())
ax[1].plot(iz_re[iz_im > threshold], iz_im[iz_im > threshold], '*', label='spike')
plot_iz_region(ax[1], threshold)
ax[1].set_xlabel(r'$\mathfrak{Re}(z)$')
ax[1].legend(loc='lower left')
ax[1].set_title('RF Izhikevich')
ax[2].plot(re_1.cpu(), im_1.cpu())
ax[2].plot(re_1[sp_1>0], im_1[sp_1>0], '*', label='spike')
plot_phase_region(ax[2], second_order_th, sin_decay.item()/scale)
ax[2].set_xlabel(r'$\mathfrak{Re}(z)$')
ax[2].legend(loc='lower left')
ax[2].set_title('RF-Or2')
4. Adaptive Resonator
The adaptive resonator adds adaptive threshold and refractory dynamics on top of the resonator. Two flavors of adaptive resonator dynamics are available in SLAYER: slayer.neuron.dynamics.adaptive_phase_th and slayer.neuron.dynamics.adaptive_resonator, corresponding to the phase spiking and Izhikevich spiking mechanisms respectively. Both follow the same post spike dynamics.
post spike dynamics: r[t] = r[t] + 2\,\vartheta[t] and \vartheta[t] = \vartheta[t] + \vartheta_{\text{step}}
4.1 Adaptive Resonate and Fire (AdRF) neuron
The AdRF neuron spikes when its state crosses zero phase with a magnitude greater than the combined threshold and refractory states.
spike dynamics: |z[t]| \geq (\vartheta[t] + r[t]) and \arg(z[t]) = 0
For easy usage, the AdRF neuron is available as ``slayer.neuron.adrf``.
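The sketch below combines the phase-crossing test with the adaptive threshold and refractory updates. It is a hypothetical NumPy reference (`adrf` is our name, not SLAYER's). It assumes the same per-step order as before: decay \vartheta and r, test |z[t]| \geq \vartheta[t] + r[t] at a zero-phase crossing, then apply the post-spike increments. A freely rotating state revisits zero phase every 8 steps, but only some of those crossings produce a spike.

```python
import numpy as np


def adrf(re_in, im_in, alpha, phi, th0, th_step, alpha_th, alpha_r):
    """Hypothetical AdRF reference (phase-crossing spikes with adaptation).

    Spike when Im(z) crosses from negative to non-negative and
    |z| >= threshold + refractory. After a spike: r += 2 * th, th += th_step.
    """
    decay = (1 - alpha) * np.exp(1j * phi)
    state, im_prev, th, r = 0j, 0.0, th0, 0.0
    s = np.zeros(len(re_in), dtype=bool)
    for t in range(len(re_in)):
        state = decay * state + (re_in[t] + 1j * im_in[t])
        th = (1 - alpha_th) * (th - th0) + th0   # threshold relaxes to th0
        r = (1 - alpha_r) * r                    # refractory state decays
        crossed = im_prev < 0 <= state.imag      # zero-phase crossing
        if crossed and abs(state) >= th + r:
            s[t] = True
            r += 2 * th
            th += th_step
        im_prev = state.imag
    return s


re_in, im_in = np.zeros(30), np.zeros(30)
re_in[0], im_in[0] = 2.0, -0.5
s = adrf(re_in, im_in, alpha=0.0, phi=2 * np.pi / 8,
         th0=1.0, th_step=0.5, alpha_th=0.05, alpha_r=0.05)
print(np.flatnonzero(s).tolist())  # [1, 25]: crossings at 9 and 17 are blocked
```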
[18]:
adrf_re, adrf_im = slayer.neuron.dynamics.resonator.dynamics(
    re_input, im_input,
    sin_decay, cos_decay,
    real_state=initial_state,
    imag_state=initial_state,
    w_scale=scale,
)
adrf_th, adrf_ref = slayer.neuron.dynamics.adaptive_phase_th.dynamics(
    adrf_re, adrf_im,
    im_state=initial_state,  # only the imaginary state is needed to determine the first phase crossing
    ref_state=initial_state, ref_decay=0.5*decay,  # refractory state and decay
    th_state=initial_state + threshold, th_decay=decay,  # threshold state and decay
    th_scale=0.5 * threshold,  # threshold step
    th0=threshold,  # threshold stable state
    w_scale=scale,
)
adrf_sp = slayer.spike.complex.Spike.apply(
    adrf_re, adrf_im, adrf_th + adrf_ref,
    1,      # tau_rho: gradient relaxation constant
    1,      # scale_rho: gradient scale constant
    False,  # graded_spike: graded or binary spike
    0,      # voltage_last: voltage at t=-1
    1,      # scale: graded spike scale
)
4.2 Adaptive Resonate and Fire Izhikevich (AdRF-Iz) neuron
The AdRF-Iz neuron fires when the imaginary state exceeds the sum of the threshold and refractory states. There is no hard reset in this model.
spike dynamics: \mathfrak{Im}(z[t]) \geq (\vartheta[t] + r[t])
For easy usage, the AdRF-Iz neuron is available as ``slayer.neuron.adrf_iz``.
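Because there is no hard reset, the state keeps rotating after a spike, and only the adaptation prevents immediate re-firing. The hypothetical NumPy reference below (`adrf_iz` is our name) shows this. It fires on \mathfrak{Im}(z[t]) \geq \vartheta[t] + r[t], with the same assumed per-step order as the other sketches. At step 2 the imaginary state is still above \vartheta_0, yet no spike occurs.

```python
import numpy as np


def adrf_iz(re_in, im_in, alpha, phi, th0, th_step, alpha_th, alpha_r):
    """Hypothetical AdRF-Iz reference: spike when Im(z) >= th + r.

    There is no hard reset of z; only the threshold and refractory states
    are bumped after a spike (r += 2 * th, th += th_step).
    """
    decay = (1 - alpha) * np.exp(1j * phi)
    state, th, r = 0j, th0, 0.0
    z = np.zeros(len(re_in), dtype=complex)
    s = np.zeros(len(re_in), dtype=bool)
    for t in range(len(re_in)):
        state = decay * state + (re_in[t] + 1j * im_in[t])
        th = (1 - alpha_th) * (th - th0) + th0
        r = (1 - alpha_r) * r
        z[t] = state
        if state.imag >= th + r:
            s[t] = True
            r += 2 * th
            th += th_step
    return z, s


re_in, im_in = np.zeros(20), np.zeros(20)
re_in[0] = 2.0
z, s = adrf_iz(re_in, im_in, alpha=0.0, phi=np.pi / 3,
               th0=1.5, th_step=0.5, alpha_th=0.05, alpha_r=0.05)
print(np.flatnonzero(s).tolist())  # [1]
print(z[2].imag >= 1.5)            # True: above th0, yet refractory blocks a spike
```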
[19]:
adrf_iz_re, adrf_iz_im, adrf_iz_th, adrf_iz_ref = slayer.neuron.dynamics.adaptive_resonator.dynamics(
    re_input, im_input,
    sin_decay, cos_decay, ref_decay=0.5*decay, th_decay=decay,
    real_state=initial_state,
    imag_state=initial_state,
    ref_state=initial_state,
    th_state=initial_state + threshold,
    th_scale=0.5 * threshold,  # threshold step
    th0=threshold,  # threshold stable state
    w_scale=scale,
)
adrf_iz_sp = adrf_iz_im > (adrf_iz_th + adrf_iz_ref)
4.3 Plot results
[20]:
fig,ax = plt.subplots(3, 1, figsize=(15, 6.6))
ax[0].plot(time, re_input.cpu(), label='real weighted spikes')
ax[0].plot(time, im_input.cpu(), label='imag weighted spikes')
ax[0].legend(loc='upper right')
ax[1].plot(time, adrf_re.cpu(), label='AdRF real state')
ax[1].plot(time, adrf_im.cpu(), label='AdRF imag state')
ax[1].plot(time, adrf_ref.cpu(), label='refractory dynamics')
ax[1].plot(time, adrf_th.cpu(), alpha=0.5, label='threshold')
ax[1].plot(time[adrf_sp>0], adrf_re[adrf_sp>0], '*', label='spike')
ax[1].legend(loc='upper right')
ax[2].plot(time, adrf_iz_re.cpu(), label='AdRF-Iz real state')
ax[2].plot(time, adrf_iz_im.cpu(), label='AdRF-Iz imag state')
ax[2].plot(time, adrf_iz_ref.cpu(), label='refractory dynamics')
ax[2].plot(time, adrf_iz_th.cpu(), alpha=0.5, label='threshold')
ax[2].plot(time[adrf_iz_sp], adrf_iz_im[adrf_iz_sp], '*', label='spike')
ax[2].legend(loc='upper right')
ax[-1].set_xlabel('time')
[21]:
fig, ax = plt.subplots(1, 2, figsize=(6.5, 3), sharey=True)
ax[0].plot(adrf_re.cpu(), adrf_im.cpu())
ax[0].plot(adrf_re[adrf_sp>0], adrf_im[adrf_sp>0], '*', label='spike')  # AdRF spikes (not the RF neuron's sp)
plot_phase_region(ax[0], threshold, sin_decay.item()/scale)
ax[0].set_xlabel(r'$\mathfrak{Re}(z)$')
ax[0].set_ylabel(r'$\mathfrak{Im}(z)$')
ax[0].legend(loc='lower left')
ax[0].set_title('AdRF')
ax[1].plot(adrf_iz_re.cpu(), adrf_iz_im.cpu())
ax[1].plot(adrf_iz_re[adrf_iz_sp], adrf_iz_im[adrf_iz_sp], '*', label='spike')
plot_iz_region(ax[1], threshold)
ax[1].set_xlabel(r'$\mathfrak{Re}(z)$')
ax[1].legend(loc='lower left')
ax[1].set_title('AdRF Izhikevich')