N-MNIST Classification

N-MNIST is the neuromorphic version of MNIST digit recognition. The MNIST digits are converted into event based data using a DVS sensor moving in a repatable tri-saccadic motion each about 100 ms long.

The task is to classify each event sequence to it’s corresponding digit.

Drawing

Drawing

Drawing

Drawing

Drawing

NMNIST dataset is freely available here (© CC-4.0).

Orchard, G.; Cohen, G.; Jayawant, A.; and Thakor, N. “Converting Static Image Datasets to Spiking Neuromorphic Datasets Using Saccades”, Frontiers in Neuroscience, vol.9, no.437, Oct. 2015

[1]:
import os, sys
import glob
import zipfile
import h5py
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader

# import slayer from lava-dl
import lava.lib.dl.slayer as slayer

import IPython.display as display
from matplotlib import animation

Create Dataset

The dataset class follows standard torch dataset definition. They are defined in nmnist.py. We will just import the dataset and augmentation routine here.

[2]:
from nmnist import augment, NMNISTDataset

Create Network

A slayer network definition follows standard PyTorch way using torch.nn.Module.

The network can be described with a combination of individual synapse, dendrite, neuron and axon components. For rapid and easy development, slayer provides block interface - slayer.block - which bundles all these individual components into a single unit. These blocks can be cascaded to build a network easily. The block interface provides additional utilities for normalization (weight and neuron), dropout, gradient monitoring and network export.

In the example below, slayer.block.cuba is illustrated.

[3]:
class Network(torch.nn.Module):
    def __init__(self):
        super(Network, self).__init__()

        neuron_params = {
                'threshold'     : 1.25,
                'current_decay' : 0.25,
                'voltage_decay' : 0.03,
                'tau_grad'      : 0.03,
                'scale_grad'    : 3,
                'requires_grad' : True,
            }
        neuron_params_drop = {**neuron_params, 'dropout' : slayer.neuron.Dropout(p=0.05),}

        self.blocks = torch.nn.ModuleList([
                slayer.block.cuba.Dense(neuron_params_drop, 34*34*2, 512, weight_norm=True, delay=True),
                slayer.block.cuba.Dense(neuron_params_drop, 512, 512, weight_norm=True, delay=True),
                slayer.block.cuba.Dense(neuron_params, 512, 10, weight_norm=True),
            ])

    def forward(self, spike):
        for block in self.blocks:
            spike = block(spike)
        return spike

    def grad_flow(self, path):
        # helps monitor the gradient flow
        grad = [b.synapse.grad_norm for b in self.blocks if hasattr(b, 'synapse')]

        plt.figure()
        plt.semilogy(grad)
        plt.savefig(path + 'gradFlow.png')
        plt.close()

        return grad

    def export_hdf5(self, filename):
        # network export to hdf5 format
        h = h5py.File(filename, 'w')
        layer = h.create_group('layer')
        for i, b in enumerate(self.blocks):
            b.export_hdf5(layer.create_group(f'{i}'))

Instantiate Network, Optimizer, DataSet and DataLoader

Running the network in GPU is as simple as selecting torch.device('cuda').

[4]:
trained_folder = 'Trained'
os.makedirs(trained_folder, exist_ok=True)

# device = torch.device('cpu')
device = torch.device('cuda')

net = Network().to(device)

optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

training_set = NMNISTDataset(train=True, transform=augment)
testing_set  = NMNISTDataset(train=False)

train_loader = DataLoader(dataset=training_set, batch_size=32, shuffle=True)
test_loader  = DataLoader(dataset=testing_set , batch_size=32, shuffle=True)

NMNIST dataset is freely available here: https://www.garrickorchard.com/datasets/n-mnist

(c) Creative Commons:
    Orchard, G.; Cohen, G.; Jayawant, A.; and Thakor, N.
    "Converting Static Image Datasets to Spiking Neuromorphic Datasets Using Saccades",
    Frontiers in Neuroscience, vol.9, no.437, Oct. 2015

Visualize the input data

A slayer.io.Event can be visualized by invoking it’s Event.show() routine. Event.anim() instead returns the event visualization animation which can be embedded in notebook or exported as video/gif. Here, we will export gif animation and visualize it.

[5]:
for i in range(5):
    spike_tensor, label = testing_set[np.random.randint(len(testing_set))]
    spike_tensor = spike_tensor.reshape(2, 34, 34, -1)
    event = slayer.io.tensor_to_event(spike_tensor.cpu().data.numpy())
    anim = event.anim(plt.figure(figsize=(5, 5)), frame_rate=240)
    anim.save(f'gifs/input{i}.gif', animation.PillowWriter(fps=24), dpi=300)
[6]:
gif_td = lambda gif: f'<td> <img src="{gif}" alt="Drawing" style="height: 250px;"/> </td>'
header = '<table><tr>'
images = ' '.join([gif_td(f'gifs/input{i}.gif') for i in range(5)])
footer = '</tr></table>'
display.HTML(header + images + footer)
[6]:
Drawing Drawing Drawing Drawing Drawing

Error module

Slayer provides prebuilt loss modules: slayer.loss.{SpikeTime, SpikeRate, SpikeMax}. * SpikeTime: precise spike time based loss when target spike train is known. * SpikeRate: spike rate based loss when desired rate of the output neuron is known. * SpikeMax: negative log likelihood losses for classification without any rate tuning.

Since the target spike train is not known for this problem, we use SpikeRate loss and target high spiking rate for true class and low spiking rate for false class.

target rate: \hat{\boldsymbol r} = r_\text{true}\,{\bf 1}[\text{label}] + r_\text{false}\,(1-{\bf 1}[\text{label}]) where {\bf 1}[\text{label}] is one-hot encoding of label. The loss is:

L = \frac{1}{2} \left(\frac{1}{T}\int_T {\boldsymbol s}(t)\,\text dt - \hat{\boldsymbol r}\right)^\top {\bf 1}

[7]:
error = slayer.loss.SpikeRate(true_rate=0.2, false_rate=0.03, reduction='sum').to(device)

Stats and Assistants

Slayer provides slayer.utils.LearningStats as a simple learning statistics logger for training, validation and testing.

In addtion, slayer.utils.Assistant module wraps common training validation and testing routine which help simplify the training routine.

[8]:
stats = slayer.utils.LearningStats()
assistant = slayer.utils.Assistant(net, error, optimizer, stats, classifier=slayer.classifier.Rate.predict)

Training Loop

Training loop mainly consists of looping over epochs and calling assistant.train and assistant.test utilities over training and testing dataset. The assistant utility takes care of statndard backpropagation procedure internally.

  • stats can be used in print statement to get formatted stats printout.

  • stats.testing.best_accuracy can be used to find out if the current iteration has the best testing accuracy. Here, we use it to save the best model.

  • stats.update() updates the stats collected for the epoch.

  • stats.save saves the stats in files.

[9]:
epochs = 100

for epoch in range(epochs):
    for i, (input, label) in enumerate(train_loader): # training loop
        output = assistant.train(input, label)
    print(f'\r[Epoch {epoch:2d}/{epochs}] {stats}', end='')

    for i, (input, label) in enumerate(test_loader): # training loop
        output = assistant.test(input, label)
    print(f'\r[Epoch {epoch:2d}/{epochs}] {stats}', end='')

    if epoch%20 == 19: # cleanup display
        print('\r', ' '*len(f'\r[Epoch {epoch:2d}/{epochs}] {stats}'))
        stats_str = str(stats).replace("| ", "\n")
        print(f'[Epoch {epoch:2d}/{epochs}]\n{stats_str}')

    if stats.testing.best_accuracy:
        torch.save(net.state_dict(), trained_folder + '/network.pt')
    stats.update()
    stats.save(trained_folder + '/')
    net.grad_flow(trained_folder + '/')


[Epoch 19/100]
Train loss =     0.11669 (min =     0.11887)    accuracy = 0.96173 (max = 0.96188)
Test  loss =     0.07722 (min =     0.07424)    accuracy = 0.97720 (max = 0.97770)

[Epoch 39/100]
Train loss =     0.09383 (min =     0.09434)    accuracy = 0.97182 (max = 0.97180)
Test  loss =     0.06004 (min =     0.06169)    accuracy = 0.98240 (max = 0.98250)

[Epoch 59/100]
Train loss =     0.08660 (min =     0.08739)    accuracy = 0.97570 (max = 0.97692)
Test  loss =     0.05640 (min =     0.05682)    accuracy = 0.98490 (max = 0.98420)

[Epoch 79/100]
Train loss =     0.08141 (min =     0.08102)    accuracy = 0.97768 (max = 0.97808)
Test  loss =     0.05284 (min =     0.05230)    accuracy = 0.98500 (max = 0.98660)

[Epoch 99/100]
Train loss =     0.07434 (min =     0.07396)    accuracy = 0.97933 (max = 0.97993)
Test  loss =     0.05019 (min =     0.04633)    accuracy = 0.98540 (max = 0.98720)

Plot the learning curves

Plotting the learning curves is as easy as calling stats.plot().

[10]:
stats.plot(figsize=(15, 5))
../../../../_images/lava-lib-dl_slayer_notebooks_nmnist_train_18_0.png
../../../../_images/lava-lib-dl_slayer_notebooks_nmnist_train_18_1.png

Export the best model

Load the best model during training and export it as hdf5 network. It is supported by lava.lib.dl.netx to automatically load the network as a lava process.

[11]:
net.load_state_dict(torch.load(trained_folder + '/network.pt'))
net.export_hdf5(trained_folder + '/network.net')

Visualize the network output

Here, we will use slayer.io.tensor_to_event method to convert the torch output spike tensor into slayer.io.Event object and visualize a few input and output event pairs.

[12]:
output = net(input.to(device))
for i in range(5):
    inp_event = slayer.io.tensor_to_event(input[i].cpu().data.numpy().reshape(2, 34, 34, -1))
    out_event = slayer.io.tensor_to_event(output[i].cpu().data.numpy().reshape(1, 10, -1))
    inp_anim = inp_event.anim(plt.figure(figsize=(5, 5)), frame_rate=240)
    out_anim = out_event.anim(plt.figure(figsize=(10, 5)), frame_rate=240)
    inp_anim.save(f'gifs/inp{i}.gif', animation.PillowWriter(fps=24), dpi=300)
    out_anim.save(f'gifs/out{i}.gif', animation.PillowWriter(fps=24), dpi=300)

[13]:
html = '<table>'
html += '<tr><td align="center"><b>Input</b></td><td><b>Output</b></td></tr>'
for i in range(5):
    html += '<tr>'
    html += gif_td(f'gifs/inp{i}.gif')
    html += gif_td(f'gifs/out{i}.gif')
    html += '</tr>'
html += '</tr></table>'
display.HTML(html)
[13]:
InputOutput
Drawing Drawing
Drawing Drawing
Drawing Drawing
Drawing Drawing
Drawing Drawing