{ "cells": [ { "cell_type": "markdown", "id": "8e62b37e-35dc-45b2-a025-31cf9ee971c5", "metadata": {}, "source": [ "# N-MNIST Classification\n", "\n", "__N-MNIST__ is the neuromorphic version of MNIST digit recognition. The MNIST digits are converted into event-based data using a DVS sensor moving in a repeatable tri-saccadic motion, with each saccade about 100 ms long.\n", "\n", "The task is to classify each event sequence as its corresponding digit.\n", "\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", "\n", "
\"Drawing\" \"Drawing\" \"Drawing\" \"Drawing\" \"Drawing\"
\n", "\n", "The N-MNIST dataset is freely available [here](https://www.garrickorchard.com/datasets/n-mnist) (© CC-4.0).\n", "\n", "> Orchard, G.; Cohen, G.; Jayawant, A.; and Thakor, N. _\"Converting Static Image Datasets to Spiking Neuromorphic Datasets Using Saccades\"_,\n", "Frontiers in Neuroscience, vol.9, no.437, Oct. 2015\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "afc7d708-0431-4b60-91f0-9b30edbedac0", "metadata": {}, "outputs": [], "source": [ "import os, sys\n", "import glob\n", "import zipfile\n", "import h5py\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import torch\n", "from torch.utils.data import Dataset, DataLoader\n", "\n", "# import slayer from lava-dl\n", "import lava.lib.dl.slayer as slayer\n", "\n", "import IPython.display as display\n", "from matplotlib import animation" ] }, { "cell_type": "markdown", "id": "670955ec-b45f-4ce5-a0aa-acce71a07370", "metadata": {}, "source": [ "# Create Dataset\n", "\n", "The dataset class follows the standard PyTorch `Dataset` definition. Both the dataset class and the augmentation routine are defined in `nmnist.py`, so we simply import them here." ] }, { "cell_type": "code", "execution_count": 2, "id": "36ceee15-929a-43e4-a63e-57772019c1c4", "metadata": {}, "outputs": [], "source": [ "from nmnist import augment, NMNISTDataset" ] }, { "cell_type": "markdown", "id": "a4ff3649-9c82-4fd4-bdbe-e4672f726dc7", "metadata": {}, "source": [ "# Create Network\n", "\n", "A slayer network definition follows the standard PyTorch approach of subclassing `torch.nn.Module`.\n", "\n", "The network can be described with a combination of individual `synapse`, `dendrite`, `neuron` and `axon` components. For rapid and easy development, slayer provides a __block interface__ - `slayer.block` - which bundles all these individual components into a single unit. These blocks can be cascaded to build a network easily. 
The block interface provides additional utilities for normalization (weight and neuron), dropout, gradient monitoring and network export.\n", "\n", "The example below illustrates `slayer.block.cuba`." ] }, { "cell_type": "code", "execution_count": 3, "id": "4172d38f-7d39-475f-bac8-7985fb1baa53", "metadata": {}, "outputs": [], "source": [ "class Network(torch.nn.Module):\n", "    def __init__(self):\n", "        super(Network, self).__init__()\n", "\n", "        # neuron parameters shared by all layers\n", "        neuron_params = {\n", "                'threshold'     : 1.25,\n", "                'current_decay' : 0.25,\n", "                'voltage_decay' : 0.03,\n", "                'tau_grad'      : 0.03,\n", "                'scale_grad'    : 3,\n", "                'requires_grad' : True,\n", "            }\n", "        # same parameters, with dropout added for the hidden layers\n", "        neuron_params_drop = {**neuron_params, 'dropout' : slayer.neuron.Dropout(p=0.05),}\n", "\n", "        self.blocks = torch.nn.ModuleList([\n", "                slayer.block.cuba.Dense(neuron_params_drop, 34*34*2, 512, weight_norm=True, delay=True),\n", "                slayer.block.cuba.Dense(neuron_params_drop, 512, 512, weight_norm=True, delay=True),\n", "                slayer.block.cuba.Dense(neuron_params, 512, 10, weight_norm=True),\n", "            ])\n", "\n", "    def forward(self, spike):\n", "        for block in self.blocks:\n", "            spike = block(spike)\n", "        return spike\n", "\n", "    def grad_flow(self, path):\n", "        # helps monitor the gradient flow\n", "        grad = [b.synapse.grad_norm for b in self.blocks if hasattr(b, 'synapse')]\n", "\n", "        plt.figure()\n", "        plt.semilogy(grad)\n", "        plt.savefig(path + 'gradFlow.png')\n", "        plt.close()\n", "\n", "        return grad\n", "\n", "    def export_hdf5(self, filename):\n", "        # network export to hdf5 format; the context manager closes the file\n", "        with h5py.File(filename, 'w') as h:\n", "            layer = h.create_group('layer')\n", "            for i, b in enumerate(self.blocks):\n", "                b.export_hdf5(layer.create_group(f'{i}'))" ] }, { "cell_type": "markdown", "id": "7617ab25-f112-42c6-8bdd-28d0ded7ffb1", "metadata": {}, "source": [ "# Instantiate Network, Optimizer, DataSet and DataLoader\n", "\n", "Running the network on a GPU is as simple as selecting `torch.device('cuda')`."
] }, { "cell_type": "code", "execution_count": 4, "id": "47d40cfa-7c30-4192-910c-1b5a90e08c8e", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "NMNIST dataset is freely available here: https://www.garrickorchard.com/datasets/n-mnist\n", "\n", "(c) Creative Commons:\n", " Orchard, G.; Cohen, G.; Jayawant, A.; and Thakor, N.\n", " \"Converting Static Image Datasets to Spiking Neuromorphic Datasets Using Saccades\",\n", " Frontiers in Neuroscience, vol.9, no.437, Oct. 2015\n", "\n" ] } ], "source": [ "trained_folder = 'Trained'\n", "os.makedirs(trained_folder, exist_ok=True)\n", "\n", "# device = torch.device('cpu')\n", "device = torch.device('cuda')\n", "\n", "net = Network().to(device)\n", "\n", "optimizer = torch.optim.Adam(net.parameters(), lr=0.001)\n", "\n", "training_set = NMNISTDataset(train=True, transform=augment)\n", "testing_set = NMNISTDataset(train=False)\n", "\n", "train_loader = DataLoader(dataset=training_set, batch_size=32, shuffle=True)\n", "test_loader = DataLoader(dataset=testing_set, batch_size=32, shuffle=True)" ] }, { "cell_type": "markdown", "id": "87f1551e-ce69-46ef-b1bc-d73be3c97794", "metadata": {}, "source": [ "# Visualize the input data\n", "\n", "A `slayer.io.Event` can be visualized by invoking its `Event.show()` routine. `Event.anim()` instead returns the event visualization as an animation, which can be embedded in a notebook or exported as a video/gif. Here, we export gif animations and display them."
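, "\n", "\n", "As a quick consistency check (a sketch, assuming each sample covers the full 34 x 34 sensor with 2 polarity channels, as the `reshape(2, 34, 34, -1)` call in the next cell suggests), the number of input channels per time step matches the `34*34*2` input size of the first `Dense` block:\n", "\n", "$$2 \\times 34 \\times 34 = 2312$$"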
] }, { "cell_type": "code", "execution_count": 5, "id": "be0b1c3b-77ec-4d8d-9fb3-22ae2b6dd742", "metadata": {}, "outputs": [], "source": [ "os.makedirs('gifs', exist_ok=True)  # make sure the output folder exists\n", "\n", "for i in range(5):\n", "    # pick a random test sample and restore its (polarity, height, width, time) shape\n", "    spike_tensor, label = testing_set[np.random.randint(len(testing_set))]\n", "    spike_tensor = spike_tensor.reshape(2, 34, 34, -1)\n", "    event = slayer.io.tensor_to_event(spike_tensor.cpu().data.numpy())\n", "    anim = event.anim(plt.figure(figsize=(5, 5)), frame_rate=240)\n", "    anim.save(f'gifs/input{i}.gif', animation.PillowWriter(fps=24), dpi=300)" ] }, { "cell_type": "code", "execution_count": 6, "id": "89122bc4-6441-4963-bb43-2313432f0530", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\"Drawing\" \"Drawing\" \"Drawing\" \"Drawing\" \"Drawing\"
" ], "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gif_td = lambda gif: f' \"Drawing\" '\n", "header = ''\n", "images = ' '.join([gif_td(f'gifs/input{i}.gif') for i in range(5)])\n", "footer = '
'\n", "display.HTML(header + images + footer)" ] }, { "cell_type": "markdown", "id": "d5a3fc61-6560-40fa-a222-5c051dc2ede7", "metadata": {}, "source": [ "# Error module\n", "\n", "Slayer provides prebuilt loss modules: `slayer.loss.{SpikeTime, SpikeRate, SpikeMax}`.\n", "* `SpikeTime`: precise spike-time-based loss for when the target spike train is known.\n", "* `SpikeRate`: spike-rate-based loss for when the desired rate of the output neuron is known.\n", "* `SpikeMax`: negative log likelihood loss for classification without any rate tuning.\n", "\n", "Since the target spike train is not known for this problem, we use the `SpikeRate` loss and target a high spiking rate for the true class and a low spiking rate for the false classes.\n", "\n", "The target rate is $\\hat{\\boldsymbol r} = r_\\text{true}\\,{\\bf 1}[\\text{label}] + r_\\text{false}\\,(1-{\\bf 1}[\\text{label}])$, where ${\\bf 1}[\\text{label}]$ is the one-hot encoding of the label. The loss is:\n", "\n", "$$L = \\frac{1}{2} \\left(\\frac{1}{T}\\int_T {\\boldsymbol s}(t)\\,\\text dt - \\hat{\\boldsymbol r}\\right)^\\top {\\bf 1}$$" ] }, { "cell_type": "code", "execution_count": 7, "id": "3c4b9b92-925a-4da8-b146-923dc4b6ad5f", "metadata": {}, "outputs": [], "source": [ "error = slayer.loss.SpikeRate(true_rate=0.2, false_rate=0.03, reduction='sum').to(device)" ] }, { "cell_type": "markdown", "id": "6cb1c953-78e7-47b7-9ce5-ab747ec23eb6", "metadata": {}, "source": [ "# Stats and Assistants\n", "\n", "Slayer provides `slayer.utils.LearningStats` as a simple learning statistics logger for training, validation and testing.\n", "\n", "In addition, the `slayer.utils.Assistant` module wraps the common training, validation and testing routines, which helps simplify the training loop."
] }, { "cell_type": "code", "execution_count": 8, "id": "473884dd-d2fa-4e6a-b44d-6a1c303dc950", "metadata": {}, "outputs": [], "source": [ "stats = slayer.utils.LearningStats()\n", "assistant = slayer.utils.Assistant(net, error, optimizer, stats, classifier=slayer.classifier.Rate.predict)" ] }, { "cell_type": "markdown", "id": "05eb7069-d384-4368-8235-f5b782c5eeae", "metadata": {}, "source": [ "# Training Loop\n", "\n", "The training loop mainly consists of looping over epochs and calling the `assistant.train` and `assistant.test` utilities over the training and testing datasets. The `assistant` utility takes care of the standard backpropagation procedure internally.\n", "\n", "* `stats` can be used in a print statement to get a formatted stats printout.\n", "* `stats.testing.best_accuracy` can be used to find out if the current iteration has the best testing accuracy. Here, we use it to save the best model.\n", "* `stats.update()` updates the stats collected for the epoch.\n", "* `stats.save` saves the stats to files."
] }, { "cell_type": "code", "execution_count": 9, "id": "dbcbc25a-41bb-49f0-bdb7-26be901626c9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", " \n", "[Epoch 19/100]\n", "Train loss = 0.11669 (min = 0.11887) accuracy = 0.96173 (max = 0.96188) \n", "Test loss = 0.07722 (min = 0.07424) accuracy = 0.97720 (max = 0.97770)\n", " \n", "[Epoch 39/100]\n", "Train loss = 0.09383 (min = 0.09434) accuracy = 0.97182 (max = 0.97180) \n", "Test loss = 0.06004 (min = 0.06169) accuracy = 0.98240 (max = 0.98250)\n", " \n", "[Epoch 59/100]\n", "Train loss = 0.08660 (min = 0.08739) accuracy = 0.97570 (max = 0.97692) \n", "Test loss = 0.05640 (min = 0.05682) accuracy = 0.98490 (max = 0.98420)\n", " \n", "[Epoch 79/100]\n", "Train loss = 0.08141 (min = 0.08102) accuracy = 0.97768 (max = 0.97808) \n", "Test loss = 0.05284 (min = 0.05230) accuracy = 0.98500 (max = 0.98660)\n", " \n", "[Epoch 99/100]\n", "Train loss = 0.07434 (min = 0.07396) accuracy = 0.97933 (max = 0.97993) \n", "Test loss = 0.05019 (min = 0.04633) accuracy = 0.98540 (max = 0.98720)\n" ] } ], "source": [ "epochs = 100\n", "\n", "for epoch in range(epochs):\n", "    for i, (input, label) in enumerate(train_loader): # training loop\n", "        output = assistant.train(input, label)\n", "    print(f'\\r[Epoch {epoch:2d}/{epochs}] {stats}', end='')\n", "\n", "    for i, (input, label) in enumerate(test_loader): # testing loop\n", "        output = assistant.test(input, label)\n", "    print(f'\\r[Epoch {epoch:2d}/{epochs}] {stats}', end='')\n", "\n", "    if epoch%20 == 19: # cleanup display\n", "        print('\\r', ' '*len(f'\\r[Epoch {epoch:2d}/{epochs}] {stats}'))\n", "        stats_str = str(stats).replace(\"| \", \"\\n\")\n", "        print(f'[Epoch {epoch:2d}/{epochs}]\\n{stats_str}')\n", "\n", "    # save the model whenever the testing accuracy improves\n", "    if stats.testing.best_accuracy:\n", "        torch.save(net.state_dict(), trained_folder + '/network.pt')\n", "    stats.update()\n", "    stats.save(trained_folder + '/')\n", "    net.grad_flow(trained_folder + '/')" ] }, { "cell_type": 
"markdown", "id": "854daacd-5ea4-46fd-befd-b0918b26bda3", "metadata": {}, "source": [ "# Plot the learning curves\n", "\n", "Plotting the learning curves is as easy as calling `stats.plot()`." ] }, { "cell_type": "code", "execution_count": 10, "id": "be59928f-4da2-4055-8802-71216003bf0c", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "stats.plot(figsize=(15, 5))" ] }, { "cell_type": "markdown", "id": "d7785921-b165-425a-a3c8-600b1b822378", "metadata": {}, "source": [ "# Export the best model\n", "\n", "Load the best model during training and export it as hdf5 network. It is supported by `lava.lib.dl.netx` to automatically load the network as a lava process." ] }, { "cell_type": "code", "execution_count": 11, "id": "fa9a9efa-da94-45b9-8598-c669e522a5c7", "metadata": {}, "outputs": [], "source": [ "net.load_state_dict(torch.load(trained_folder + '/network.pt'))\n", "net.export_hdf5(trained_folder + '/network.net')" ] }, { "cell_type": "markdown", "id": "a9a89584-2b71-49d9-b9e5-582e657676bc", "metadata": {}, "source": [ "# Visualize the network output\n", "\n", "Here, we will use `slayer.io.tensor_to_event` method to convert the torch output spike tensor into `slayer.io.Event` object and visualize a few input and output event pairs." ] }, { "cell_type": "code", "execution_count": 12, "id": "2a37918a-c45d-40a5-b727-e2fe610b7c80", "metadata": {}, "outputs": [], "source": [ "output = net(input.to(device))\n", "for i in range(5):\n", " inp_event = slayer.io.tensor_to_event(input[i].cpu().data.numpy().reshape(2, 34, 34, -1))\n", " out_event = slayer.io.tensor_to_event(output[i].cpu().data.numpy().reshape(1, 10, -1))\n", " inp_anim = inp_event.anim(plt.figure(figsize=(5, 5)), frame_rate=240)\n", " out_anim = out_event.anim(plt.figure(figsize=(10, 5)), frame_rate=240)\n", " inp_anim.save(f'gifs/inp{i}.gif', animation.PillowWriter(fps=24), dpi=300)\n", " out_anim.save(f'gifs/out{i}.gif', animation.PillowWriter(fps=24), dpi=300)\n" ] }, { "cell_type": "code", "execution_count": 13, "id": "d933ed73-cdc1-43b8-8045-57cc9c499fba", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
InputOutput
\"Drawing\" \"Drawing\"
\"Drawing\" \"Drawing\"
\"Drawing\" \"Drawing\"
\"Drawing\" \"Drawing\"
\"Drawing\" \"Drawing\"
" ], "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "html = ''\n", "html += ''\n", "for i in range(5):\n", " html += ''\n", " html += gif_td(f'gifs/inp{i}.gif')\n", " html += gif_td(f'gifs/out{i}.gif')\n", " html += ''\n", "html += '
InputOutput
'\n", "display.HTML(html)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 5 }