{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "*Copyright (C) 2021 Intel Corporation*
\n", "*SPDX-License-Identifier: BSD-3-Clause*
\n", "*See: https://spdx.org/licenses/*\n", "\n", "---\n", "\n", "# MNIST Digit Classification with Lava" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "_**Motivation**: In this tutorial, we will build a Lava Process for an MNIST\n", "classifier, using the Lava Processes for LIF neurons and Dense connectivity.\n", "The tutorial is useful to get started with Lava in a few minutes._" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### This tutorial assumes that you:\n", "- have the [Lava framework installed](../in_depth/tutorial01_installing_lava.ipynb \"Tutorial on Installing Lava\")\n", "- are familiar with the [Process concept in Lava](../in_depth/tutorial02_processes.ipynb \"Tutorial on Processes\")\n", "\n", "#### This tutorial gives a bird's-eye view of\n", "- how Lava Process(es) can perform the MNIST digit classification task using\n", "[Leaky Integrate-and-Fire (LIF)](https://github.com/lava-nc/lava/tree/main/src/lava/proc/lif \"Lava's LIF neuron\") neurons and [Dense\n", "(fully connected)](https://github.com/lava-nc/lava/tree/main/src/lava/proc/dense \"Lava's Dense Connectivity\") connectivity.\n", "- how to create a Process \n", "- how to create Python ProcessModels \n", "- how to connect Processes\n", "- how to execute them" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Our MNIST Classifier\n", "In this tutorial, we will build a multi-layer feed-forward classifier without\n", " any convolutional layers. The architecture is shown below.\n", "\n", "> **Important Note**:\n", ">\n", "> The classifier is a simple feed-forward model using pre-trained network \n", "> parameters (weights and biases). It illustrates how to build, compile and \n", "> run a functional model in Lava. Please refer to \n", "> [Lava-DL](https://github.com/lava-nc/lava-dl) to understand how to train \n", "> deep networks with Lava." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\"MNIST
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The 3 Processes shown above are:\n", " - *SpikeInput* Process - generates spikes via integrate and fire dynamics,\n", " using the image input\n", " - *ImageClassifier* Process - encapsulates feed-forward architecture of\n", " Dense connectivity and LIF neurons\n", " - *Output* Process - accumulates output spikes from the feed-forward process\n", "and infers the class label" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### General Imports" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "import numpy as np\n", "import typing as ty" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Lava Processes\n", "\n", "Below we create the Lava Process classes. We need to define only the structure of the process here. The details about how the Process will be executed are specified in the [ProcessModels](../in_depth/tutorial03_process_models.ipynb \"Tutorial on ProcessModels\") below.\n", "\n", "As mentioned above, we define Processes for \n", "- converting input images to binary spikes from those biases (_SpikeInput_),\n", "- the 3-layer fully connected feed-forward network (_MnistClassifier_)\n", "- accumulating the output spikes and inferring the class for an input image\n", "(_OutputProcess_)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Import Process level primitives\n", "from lava.magma.core.process.process import AbstractProcess\n", "from lava.magma.core.process.variable import Var\n", "from lava.magma.core.process.ports.ports import InPort, OutPort" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "class SpikeInput(AbstractProcess):\n", " \"\"\"Reads image data from the MNIST dataset and converts it to spikes.\n", " The resulting spike rate is proportional to the pixel value.\"\"\"\n", "\n", " def __init__(self,\n", " vth: 
int,\n", " num_images: ty.Optional[int] = 25,\n", " num_steps_per_image: ty.Optional[int] = 128):\n", " super().__init__()\n", " shape = (784,)\n", " self.spikes_out = OutPort(shape=shape) # Input spikes to the classifier\n", " self.label_out = OutPort(shape=(1,)) # Ground truth labels to OutputProc\n", " self.num_images = Var(shape=(1,), init=num_images)\n", " self.num_steps_per_image = Var(shape=(1,), init=num_steps_per_image)\n", " self.input_img = Var(shape=shape)\n", " self.ground_truth_label = Var(shape=(1,))\n", " self.v = Var(shape=shape, init=0)\n", " self.vth = Var(shape=(1,), init=vth)\n", " \n", " \n", "class ImageClassifier(AbstractProcess):\n", " \"\"\"A 3 layer feed-forward network with LIF and Dense Processes.\"\"\"\n", "\n", " def __init__(self, trained_weights_path: str):\n", " super().__init__()\n", " \n", " # Using pre-trained weights and biases\n", " real_path_trained_wgts = os.path.realpath(trained_weights_path)\n", "\n", " wb_list = np.load(real_path_trained_wgts, encoding='latin1', allow_pickle=True)\n", " w0 = wb_list[0].transpose().astype(np.int32)\n", " w1 = wb_list[2].transpose().astype(np.int32)\n", " w2 = wb_list[4].transpose().astype(np.int32)\n", " b1 = wb_list[1].astype(np.int32)\n", " b2 = wb_list[3].astype(np.int32)\n", " b3 = wb_list[5].astype(np.int32)\n", "\n", " self.spikes_in = InPort(shape=(w0.shape[1],))\n", " self.spikes_out = OutPort(shape=(w2.shape[0],))\n", " self.w_dense0 = Var(shape=w0.shape, init=w0)\n", " self.b_lif1 = Var(shape=(w0.shape[0],), init=b1)\n", " self.w_dense1 = Var(shape=w1.shape, init=w1)\n", " self.b_lif2 = Var(shape=(w1.shape[0],), init=b2)\n", " self.w_dense2 = Var(shape=w2.shape, init=w2)\n", " self.b_output_lif = Var(shape=(w2.shape[0],), init=b3)\n", " \n", " # Up-level currents and voltages of LIF Processes\n", " # for resetting (see at the end of the tutorial)\n", " self.lif1_u = Var(shape=(w0.shape[0],), init=0)\n", " self.lif1_v = Var(shape=(w0.shape[0],), init=0)\n", " self.lif2_u = 
Var(shape=(w1.shape[0],), init=0)\n", " self.lif2_v = Var(shape=(w1.shape[0],), init=0)\n", " self.oplif_u = Var(shape=(w2.shape[0],), init=0)\n", " self.oplif_v = Var(shape=(w2.shape[0],), init=0)\n", " \n", " \n", "class OutputProcess(AbstractProcess):\n", " \"\"\"Process to gather spikes from 10 output LIF neurons and interpret the\n", " highest spiking rate as the classifier output\"\"\"\n", "\n", " def __init__(self, **kwargs):\n", " super().__init__()\n", " shape = (10,)\n", " n_img = kwargs.pop('num_images', 25)\n", " self.num_images = Var(shape=(1,), init=n_img)\n", " self.spikes_in = InPort(shape=shape)\n", " self.label_in = InPort(shape=(1,))\n", " self.spikes_accum = Var(shape=shape) # Accumulated spikes for classification\n", " self.num_steps_per_image = Var(shape=(1,), init=128)\n", " self.pred_labels = Var(shape=(n_img,))\n", " self.gt_labels = Var(shape=(n_img,))" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "#### ProcessModels for Python execution\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Import parent classes for ProcessModels\n", "from lava.magma.core.model.sub.model import AbstractSubProcessModel\n", "from lava.magma.core.model.py.model import PyLoihiProcessModel\n", "\n", "# Import ProcessModel ports, data-types\n", "from lava.magma.core.model.py.ports import PyInPort, PyOutPort\n", "from lava.magma.core.model.py.type import LavaPyType\n", "\n", "# Import execution protocol and hardware resources\n", "from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol\n", "from lava.magma.core.resources import CPU\n", "\n", "# Import decorators\n", "from lava.magma.core.decorator import implements, requires\n", "\n", "# Import MNIST dataset\n", "from lava.utils.dataloader.mnist import MnistDataset\n", "np.set_printoptions(linewidth=np.inf)" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "**Decorators for 
ProcessModels**: \n", "- `@implements`: associates a ProcessModel with a Process through \n", "the argument `proc`. Using the `protocol` argument, we specify the \n", "synchronization protocol used by the ProcessModel. In this tutorial, \n", "all ProcessModels execute \n", "according to the `LoihiProtocol`. This means that, similar to the Loihi \n", "chip, each time-step is divided into _spiking_, _pre-management_, \n", "_post-management_, and _learning_ phases. It is necessary to specify the \n", "behaviour of a ProcessModel during the spiking phase using the `run_spk` \n", "function. The other phases are optional.\n", "- `@requires`: specifies the hardware resource on which a ProcessModel \n", "will be executed. In this tutorial, we will execute all ProcessModels \n", "on a CPU." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**SpikeInput ProcessModel**" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "@implements(proc=SpikeInput, protocol=LoihiProtocol)\n", "@requires(CPU)\n", "class PySpikeInputModel(PyLoihiProcessModel):\n", " num_images: int = LavaPyType(int, int, precision=32)\n", " spikes_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, bool, precision=1)\n", " label_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, np.int32,\n", " precision=32)\n", " num_steps_per_image: int = LavaPyType(int, int, precision=32)\n", " input_img: np.ndarray = LavaPyType(np.ndarray, int, precision=32)\n", " ground_truth_label: int = LavaPyType(int, int, precision=32)\n", " v: np.ndarray = LavaPyType(np.ndarray, int, precision=32)\n", " vth: int = LavaPyType(int, int, precision=32)\n", " \n", " def __init__(self, proc_params):\n", " super().__init__(proc_params=proc_params)\n", " self.mnist_dataset = MnistDataset()\n", " self.curr_img_id = 0\n", "\n", " def post_guard(self):\n", " \"\"\"Guard function for PostManagement phase.\n", " 
\"\"\"\n", " if self.time_step % self.num_steps_per_image == 1:\n", " return True\n", " return False\n", "\n", " def run_post_mgmt(self):\n", " \"\"\"Post-Management phase: executed only when guard function above \n", " returns True.\n", " \"\"\"\n", " img = self.mnist_dataset.images[self.curr_img_id]\n", " self.ground_truth_label = self.mnist_dataset.labels[self.curr_img_id]\n", " self.input_img = img.astype(np.int32) - 127\n", " self.v = np.zeros(self.v.shape)\n", " self.label_out.send(np.array([self.ground_truth_label]))\n", " self.curr_img_id += 1\n", "\n", " def run_spk(self):\n", " \"\"\"Spiking phase: executed unconditionally at every time-step\n", " \"\"\"\n", " self.v[:] = self.v + self.input_img\n", " s_out = self.v > self.vth\n", " self.v[s_out] = 0 # reset voltage to 0 after a spike\n", " self.spikes_out.send(s_out)" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "**ImageClassifier ProcessModel**\n", "\n", "Notice that the following process model is further decomposed into\n", "sub-Processes, which implement LIF neural dynamics and Dense connectivity. We\n", " will not go into the details of how these are implemented in this tutorial.\n", " \n", "Also notice that a *SubProcessModel* does not actually contain any concrete \n", "execution. This is handled by the ProcessModels of the constituent Processes." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "from lava.proc.lif.process import LIF\n", "from lava.proc.dense.process import Dense \n", "\n", "@implements(ImageClassifier)\n", "@requires(CPU)\n", "class PyImageClassifierModel(AbstractSubProcessModel):\n", " def __init__(self, proc):\n", " self.dense0 = Dense(weights=proc.w_dense0.init)\n", " self.lif1 = LIF(shape=(64,), bias_mant=proc.b_lif1.init, vth=400,\n", " dv=0, du=4095)\n", " self.dense1 = Dense(weights=proc.w_dense1.init)\n", " self.lif2 = LIF(shape=(64,), bias_mant=proc.b_lif2.init, vth=350,\n", " dv=0, du=4095)\n", " self.dense2 = Dense(weights=proc.w_dense2.init)\n", " self.output_lif = LIF(shape=(10,), bias_mant=proc.b_output_lif.init,\n", " vth=1, dv=0, du=4095)\n", "\n", " proc.spikes_in.connect(self.dense0.s_in)\n", " self.dense0.a_out.connect(self.lif1.a_in)\n", " self.lif1.s_out.connect(self.dense1.s_in)\n", " self.dense1.a_out.connect(self.lif2.a_in)\n", " self.lif2.s_out.connect(self.dense2.s_in)\n", " self.dense2.a_out.connect(self.output_lif.a_in)\n", " self.output_lif.s_out.connect(proc.spikes_out)\n", " \n", " # Create aliases of SubProcess variables\n", " proc.lif1_u.alias(self.lif1.u)\n", " proc.lif1_v.alias(self.lif1.v)\n", " proc.lif2_u.alias(self.lif2.u)\n", " proc.lif2_v.alias(self.lif2.v)\n", " proc.oplif_u.alias(self.output_lif.u)\n", " proc.oplif_v.alias(self.output_lif.v)" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "**OutputProcess ProcessModel**" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "@implements(proc=OutputProcess, protocol=LoihiProtocol)\n", "@requires(CPU)\n", "class PyOutputProcessModel(PyLoihiProcessModel):\n", " label_in: PyInPort = 
LavaPyType(PyInPort.VEC_DENSE, int, precision=32)\n", " spikes_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, bool, precision=1)\n", " num_images: int = LavaPyType(int, int, precision=32)\n", " spikes_accum: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=32)\n", " num_steps_per_image: int = LavaPyType(int, int, precision=32)\n", " pred_labels: np.ndarray = LavaPyType(np.ndarray, int, precision=32)\n", " gt_labels: np.ndarray = LavaPyType(np.ndarray, int, precision=32)\n", " \n", " def __init__(self, proc_params):\n", " super().__init__(proc_params=proc_params)\n", " self.current_img_id = 0\n", "\n", " def post_guard(self):\n", " \"\"\"Guard function for PostManagement phase.\n", " \"\"\"\n", " if self.time_step % self.num_steps_per_image == 0 and \\\n", " self.time_step > 1:\n", " return True\n", " return False\n", "\n", " def run_post_mgmt(self):\n", " \"\"\"Post-Management phase: executed only when guard function above \n", " returns True.\n", " \"\"\"\n", " gt_label = self.label_in.recv()\n", " pred_label = np.argmax(self.spikes_accum)\n", " self.gt_labels[self.current_img_id] = gt_label\n", " self.pred_labels[self.current_img_id] = pred_label\n", " self.current_img_id += 1\n", " self.spikes_accum = np.zeros_like(self.spikes_accum)\n", "\n", " def run_spk(self):\n", " \"\"\"Spiking phase: executed unconditionally at every time-step\n", " \"\"\"\n", " spk_in = self.spikes_in.recv()\n", " self.spikes_accum = self.spikes_accum + spk_in" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Connecting Processes" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "num_images = 25\n", "num_steps_per_image = 128\n", "\n", "# Create Process instances\n", "spike_input = SpikeInput(vth=1,\n", " num_images=num_images,\n", " num_steps_per_image=num_steps_per_image)\n", "mnist_clf = ImageClassifier(\n", " trained_weights_path=os.path.join('.', 'mnist_pretrained.npy'))\n", "output_proc = 
OutputProcess(num_images=num_images)\n", "\n", "# Connect Processes\n", "spike_input.spikes_out.connect(mnist_clf.spikes_in)\n", "mnist_clf.spikes_out.connect(output_proc.spikes_in)\n", "# Connect Input directly to Output for ground truth labels\n", "spike_input.label_out.connect(output_proc.label_in)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you receive an ``UnpicklingError`` when instantiating the ``ImageClassifier``, make sure to download the pretrained weights from GitHub LFS into the current directory using:\n", "```bash\n", " git lfs fetch\n", " git lfs pull\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Execution and results\n", "Below, we will run the classifier process in a loop of `num_images`\n", "iterations. Each iteration will run the Process for \n", "`num_steps_per_image` time-steps. \n", "\n", "We take this approach to clear the neural states of all three LIF \n", "layers inside the classifier after every image. We need to clear \n", "the neural states, because the network parameters were trained \n", "assuming clear neural states for each inference.\n", "\n", "> Note:\n", "Below we use the `Var.set()` function to set the values\n", "of internal state variables. The same behaviour can be \n", "achieved by using `RefPorts`. See the \n", "[RefPorts tutorial](../in_depth/tutorial07_remote_memory_access.ipynb)\n", "to learn more about how to use `RefPorts` to access internal \n", "state variables of Lava Processes." 
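, "\n", "\n", "As a plain-Python illustration of the timing this loop relies on (no Lava required), the sketch below evaluates the two `post_guard` conditions defined earlier over the first two image windows. It assumes that `time_step` starts at 1 and keeps counting across successive `run()` calls; the helper names `input_guard` and `output_guard` are hypothetical and only mirror the guard expressions above.

```python
num_steps_per_image = 128


def input_guard(time_step):
    # Mirrors PySpikeInputModel.post_guard: true on the first step
    # of every image window, when a new image is loaded.
    return time_step % num_steps_per_image == 1


def output_guard(time_step):
    # Mirrors PyOutputProcessModel.post_guard: true on the last step
    # of every image window, when the prediction is read out.
    return time_step % num_steps_per_image == 0 and time_step > 1


# Scan the time-steps covered by the first two images.
total_steps = 2 * num_steps_per_image
input_steps = [t for t in range(1, total_steps + 1) if input_guard(t)]
output_steps = [t for t in range(1, total_steps + 1) if output_guard(t)]

print('new image loaded at steps:', input_steps)
print('prediction read out at steps:', output_steps)
```

Under these assumptions each image occupies one window of `num_steps_per_image` steps, with the readout on the last step of the window, which is why the loop below can reset the LIF states between consecutive `run()` calls."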
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Current image: 25\n", "Ground truth: [5 0 4 1 9 2 1 3 1 4 3 5 3 6 1 7 2 8 6 9 4 0 9 1 1]\n", "Predictions : [3 0 4 1 4 2 1 3 1 4 3 5 3 6 1 7 2 8 5 9 4 0 9 1 1]\n", "Accuracy : 88.0\n" ] } ], "source": [ "from lava.magma.core.run_conditions import RunSteps\n", "from lava.magma.core.run_configs import Loihi1SimCfg\n", "\n", "# Loop over all images\n", "for img_id in range(num_images):\n", " print(f\"\\rCurrent image: {img_id+1}\", end=\"\")\n", " \n", " # Run each image-inference for fixed number of steps\n", " mnist_clf.run(\n", " condition=RunSteps(num_steps=num_steps_per_image),\n", " run_cfg=Loihi1SimCfg(select_sub_proc_model=True,\n", " select_tag='fixed_pt'))\n", " \n", " # Reset internal neural state of LIF neurons\n", " mnist_clf.lif1_u.set(np.zeros((64,), dtype=np.int32))\n", " mnist_clf.lif1_v.set(np.zeros((64,), dtype=np.int32))\n", " mnist_clf.lif2_u.set(np.zeros((64,), dtype=np.int32))\n", " mnist_clf.lif2_v.set(np.zeros((64,), dtype=np.int32))\n", " mnist_clf.oplif_u.set(np.zeros((10,), dtype=np.int32))\n", " mnist_clf.oplif_v.set(np.zeros((10,), dtype=np.int32))\n", "\n", "# Gather ground truth and predictions before stopping exec\n", "ground_truth = output_proc.gt_labels.get().astype(np.int32)\n", "predictions = output_proc.pred_labels.get().astype(np.int32)\n", "\n", "# Stop the execution\n", "mnist_clf.stop()\n", "\n", "accuracy = np.sum(ground_truth==predictions)/ground_truth.size * 100\n", "\n", "print(f\"\\nGround truth: {ground_truth}\\n\"\n", " f\"Predictions : {predictions}\\n\"\n", " f\"Accuracy : {accuracy}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> **Important Note**:\n", ">\n", "> The classifier is a simple feed-forward model using pre-trained network \n", "> parameters (weights and biases). It illustrates how to build, compile and \n", "> run a functional model in Lava. 
Please refer to \n", "> [Lava-DL](https://github.com/lava-nc/lava-dl) to understand how to train \n", "> deep networks with Lava." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## How to learn more?\n", "\n", "#### Follow the links below for deep-dive tutorials on the concepts in this tutorial:\n", "- [Processes](../in_depth/tutorial02_processes.ipynb \"Tutorial on Processes\")\n", "- [ProcessModel](../in_depth/tutorial03_process_models.ipynb \"Tutorial on ProcessModels\")\n", "- [Execution](../in_depth/tutorial04_execution.ipynb \"Tutorial on Executing Processes\")\n", "- [SubProcessModels](../in_depth/tutorial06_hierarchical_processes.ipynb) or [Hierarchical Processes](../in_depth/tutorial06_hierarchical_processes.ipynb)\n", "- [RefPorts](../in_depth/tutorial07_remote_memory_access.ipynb)\n", "\n", "If you want to find out more about Lava, have a look at the [Lava documentation](https://lava-nc.org/ \"Lava Documentation\") or dive into the [source code](https://github.com/lava-nc/lava/ \"Lava Source Code\").\n", "\n", "To receive regular updates on the latest developments and releases of the Lava Software Framework please subscribe to the [INRC newsletter](http://eepurl.com/hJCyhb \"INRC Newsletter\")." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.8.10 ('lava_nx_env')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" }, "vscode": { "interpreter": { "hash": "7ebb4c32c029abbab1fd16ef4d8ac43152261b56d4033e55d2744ce843ecba08" } } }, "nbformat": 4, "nbformat_minor": 4 }