{ "cells": [ { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "*Copyright (C) 2021 Intel Corporation*<br>\n", "*SPDX-License-Identifier: BSD-3-Clause*<br>\n", "*See: https://spdx.org/licenses/*\n", "\n", "---\n", "\n", "# Spike-timing Dependent Plasticity (STDP)\n", "\n", "_**Motivation**: In this tutorial, we demonstrate how to use a software model of Loihi's learning engine, exposed in Lava. This involves the LearningRule object, which encapsulates a learning rule together with other learning-related information, and the LearningDense Lava Process, which models learning-enabled connections._\n", "\n", "#### This tutorial assumes that you:\n", "- have the [Lava framework installed](../../in_depth/tutorial01_installing_lava.ipynb \"Tutorial on Installing Lava\")\n", "- are familiar with the [Process concept in Lava](../../in_depth/tutorial02_processes.ipynb \"Tutorial on Processes\")\n", "- are familiar with the [ProcessModel concept in Lava](../../in_depth/tutorial02_process_models.ipynb \"Tutorial on ProcessModels\")\n", "- are familiar with how to [connect Lava Processes](../../in_depth/tutorial05_connect_processes.ipynb \"Tutorial on connecting Processes\")\n", "\n", "This tutorial gives a bird's-eye view of how to use the learning rules available in Lava's Process Library. For this purpose, we create a network of LIF and Dense Processes with one plastic connection and drive it with frozen (seeded, pre-generated) random spike patterns. A single tag lets us choose between a floating-point simulation of the learning engine and a fixed-point simulation, which approximates the behavior of the Loihi neuromorphic hardware. We also create monitors to observe the evolution of the weights, the spiking activity of the neurons, and the traces of the learning rule."
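] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The two model tags differ in their arithmetic: the `floating_pt` models keep real-valued state, whereas the `fixed_pt` models keep integer state to approximate Loihi. As a rough, hypothetical illustration of what this means for the LIF current decay constant `du`, the cell below compares a floating-point decay with a 12-bit integer decay. It assumes that the bit-accurate LIF model computes the decayed current as `sign(x) * (|x| >> 12)` with `x = u * (4096 - (du + 1))`; this formula and the helper names `decay_float` and `decay_fixed` are assumptions for illustration, not Lava API, so consult the source of `lava.proc.lif.models` for the exact arithmetic." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "# Hypothetical helpers for illustration only; they are not part of the Lava API.\n", "def decay_float(u, du):\n", "    # Floating-point decay: du is the fraction of u removed per time step.\n", "    return u * (1.0 - du)\n", "\n", "def decay_fixed(u, du, shift=12, offset=1):\n", "    # Assumed 12-bit integer decay: scale by (2**shift - (du + offset)),\n", "    # then right-shift the magnitude by `shift` bits (rounds toward zero).\n", "    scaled = u * ((1 << shift) - (du + offset))\n", "    return np.sign(scaled) * (np.abs(scaled) >> shift)\n", "\n", "u = np.array([250, -250, 1000], dtype=np.int64)\n", "print(decay_float(u.astype(float), 1.0))  # du = 1: current fully decays\n", "print(decay_fixed(u, 4095))               # du = 4095: current fully decays\n", "print(decay_fixed(u, 2047))               # du = 2047: current is halved" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Under this assumption, `du = 4095` removes the entire current in a single step, just as `du = 1` does in the floating-point model, which is consistent with the pairing of values used for the two tags in the parameter cell below. Intermediate constants keep a fraction of the current, truncated toward zero."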
] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## STDP from Lava's Process Library\n", "\n", "Let's first generate the random, frozen input and define all parameters of the network.\n", "\n", "### Parameters" ] },
{ "cell_type": "code", "execution_count": 1, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "import numpy as np\n", "\n", "# Set this tag to \"fixed_pt\" or \"floating_pt\" to choose the corresponding models.\n", "SELECT_TAG = \"floating_pt\"\n", "\n", "# LIF parameters\n", "if SELECT_TAG == \"fixed_pt\":\n", "    du = 4095\n", "    dv = 4095\n", "elif SELECT_TAG == \"floating_pt\":\n", "    du = 1\n", "    dv = 1\n", "else:\n", "    raise ValueError(f\"Unknown SELECT_TAG: {SELECT_TAG!r}\")\n", "vth = 240\n", "\n", "# Number of neurons per layer\n", "num_neurons = 1\n", "shape_lif = (num_neurons, )\n", "shape_conn = (num_neurons, num_neurons)\n", "\n", "# Connection parameters\n", "\n", "# RingBuffer -> LIF connection weight\n", "wgt_inp = np.eye(num_neurons) * 250\n", "\n", "# LIF -> LIF connection initial weight (learning-enabled)\n", "wgt_plast_conn = np.full(shape_conn, 50)\n", "\n", "# Number of simulation time steps\n", "num_steps = 200\n", "time = list(range(1, num_steps + 1))\n", "\n", "# Spike probability per neuron and time step\n", "spike_prob = 0.03\n", "\n", "# Create frozen (seeded) random spike rasters\n", "np.random.seed(123)\n", "spike_raster_pre = np.zeros((num_neurons, num_steps))\n", "np.place(spike_raster_pre, np.random.rand(num_neurons, num_steps) < spike_prob, 1)\n", "\n", "spike_raster_post = np.zeros((num_neurons, num_steps))\n", "np.place(spike_raster_post, np.random.rand(num_neurons, num_steps) < spike_prob, 1)" ] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### Define STDP learning rule\n", "\n", "Next, let's instantiate the STDP learning rule from the Lava Process Library. The STDPLoihi learning rule is parameterized as described in Gerstner et al., 1996 (see also http://www.scholarpedia.org/article/Spike-timing_dependent_plasticity)."
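] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To build intuition for the shape of an STDP learning window, the cell below evaluates the textbook pair-based exponential window with plain NumPy: $\\Delta w = a_{pos}\\, e^{-\\Delta t/\\tau_{pos}}$ for $\\Delta t > 0$ and $\\Delta w = a_{neg}\\, e^{\\Delta t/\\tau_{neg}}$ for $\\Delta t < 0$, where $\\Delta t = t_{post} - t_{pre}$.\n", "\n", "This is a sketch of the classic rule, not of the `STDPLoihi` implementation: the helper `stdp_window` and its parameter names `a_pos`, `a_neg`, `tau_pos` and `tau_neg` are hypothetical, and how `STDPLoihi` maps `A_plus`, `A_minus`, `tau_plus` and `tau_minus` onto the two branches is not shown here. The learning engine simulated in this tutorial additionally works with traces, applies weight updates once per learning epoch (`t_epoch`) and, in the `fixed_pt` models, quantizes its values." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "# Hypothetical helper: textbook pair-based exponential STDP window.\n", "# dt = t_post - t_pre; a_pos/tau_pos shape the dt > 0 branch,\n", "# a_neg/tau_neg shape the dt < 0 branch, and dt == 0 gives no change.\n", "def stdp_window(dt, a_pos=1.0, a_neg=-1.0, tau_pos=10.0, tau_neg=10.0):\n", "    dt = np.atleast_1d(np.asarray(dt, dtype=float))\n", "    dw = np.zeros_like(dt)\n", "    pos = dt > 0\n", "    neg = dt < 0\n", "    dw[pos] = a_pos * np.exp(-dt[pos] / tau_pos)\n", "    dw[neg] = a_neg * np.exp(dt[neg] / tau_neg)\n", "    return dw\n", "\n", "dts = np.arange(-30, 31, 10)\n", "print(dict(zip(dts.tolist(), np.round(stdp_window(dts), 3).tolist())))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "For example, `stdp_window([10, -10])` returns approximately `[0.368, -0.368]`: the magnitude of the change drops to $e^{-1}$ of its peak when $|\\Delta t|$ equals the time constant. The last plot of this tutorial draws a window of the same exponential form next to the weight changes measured in the simulation."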
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "from lava.proc.learning_rules.stdp_learning_rule import STDPLoihi" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "stdp = STDPLoihi(learning_rate=1,\n", "                 A_plus=-1,\n", "                 A_minus=1,\n", "                 tau_plus=10,\n", "                 tau_minus=10,\n", "                 t_epoch=4)" ] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### Create Network\n", "The following diagram depicts the Lava Process architecture used in this tutorial. It consists of:\n", "- 2 constant pattern generators (_RingBuffer_ Processes) that inject spike trains into the LIF neurons, each through a non-plastic _Dense_ connection.\n", "- 2 _LIF_ Processes representing the pre- and post-synaptic Leaky Integrate-and-Fire neurons.\n", "- 1 _LearningDense_ Process representing the learning-enabled connection between the LIF neurons.\n", "\n", ">**Note:** \n", "All neuronal populations (spike generators, LIF) consist of only 1 neuron in this tutorial." ] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "![Lava Process architecture](Architecture.svg)" ] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "#### The plastic connection Process\n", "We now instantiate our plastic connection as a _LearningDense_ Process. It provides the following Vars and Ports relevant for plasticity:\n", "\n", "| Component | Name | Description |\n", "| :- | :- | :- |\n", "| **InPort** | `s_in_bap` | Receives spikes from post-synaptic neurons (back-propagating action potentials). |\n", "| **Var** | `tag_2` | Delay synaptic variable. |\n", "| | `tag_1` | Tag synaptic variable. |\n", "| | `x0` | State of $x_0$ dependency. |\n", "| | `tx` | Within-epoch spike times of pre-synaptic neurons. |\n", "| | `x1` | State of $x_1$ trace. |\n", "| | `x2` | State of $x_2$ trace. |\n", "| | `y0` | State of $y_0$ dependency. |\n", "| | `ty` | Within-epoch spike times of post-synaptic neurons. |\n", "| | `y1` | State of $y_1$ trace. |\n", "| | `y2` | State of $y_2$ trace. |\n", "| | `y3` | State of $y_3$ trace. |\n", "\n" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": { "pycharm": { "name": "#%%\n" }, "tags": [] }, "outputs": [], "source": [ "from lava.proc.lif.process import LIF\n", "from lava.proc.io.source import RingBuffer\n", "from lava.proc.dense.process import LearningDense, Dense" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# Create input devices\n", "pattern_pre = RingBuffer(data=spike_raster_pre.astype(int))\n", "pattern_post = RingBuffer(data=spike_raster_post.astype(int))\n", "\n", "# Create input connectivity\n", "conn_inp_pre = Dense(weights=wgt_inp)\n", "conn_inp_post = Dense(weights=wgt_inp)\n", "\n", "# Create pre-synaptic neurons\n", "lif_pre = LIF(u=0,\n", "              v=0,\n", "              du=du,\n", "              dv=dv,\n", "              bias_mant=0,\n", "              bias_exp=0,\n", "              vth=vth,\n", "              shape=shape_lif,\n", "              name='lif_pre')\n", "\n", "# Create plastic connection\n", "plast_conn = LearningDense(weights=wgt_plast_conn,\n", "                           learning_rule=stdp,\n", "                           name='plastic_dense')\n", "\n", "# Create post-synaptic neurons\n", "lif_post = LIF(u=0,\n", "               v=0,\n", "               du=du,\n", "               dv=dv,\n", "               bias_mant=0,\n", "               bias_exp=0,\n", "               vth=vth,\n", "               shape=shape_lif,\n", "               name='lif_post')\n", "\n", "# Connect network\n", "pattern_pre.s_out.connect(conn_inp_pre.s_in)\n", "conn_inp_pre.a_out.connect(lif_pre.a_in)\n", "\n", "pattern_post.s_out.connect(conn_inp_post.s_in)\n", "conn_inp_post.a_out.connect(lif_post.a_in)\n", "\n", "lif_pre.s_out.connect(plast_conn.s_in)\n", "plast_conn.a_out.connect(lif_post.a_in)\n", "\n", "# Connect back-propagating action potential (BAP)\n", "lif_post.s_out.connect(plast_conn.s_in_bap)" ] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### Create monitors to observe traces, spikes and weights" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": { "pycharm": { "name": "#%%\n" }, "tags": [] }, "outputs": [], "source": [ "from lava.proc.monitor.process import Monitor" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": { "pycharm": { "name": "#%%\n" }, "tags": [] }, "outputs": [], "source": [ "# Create monitors\n", "mon_pre_trace = Monitor()\n", "mon_post_trace = Monitor()\n", "mon_pre_spikes = Monitor()\n", "mon_post_spikes = Monitor()\n", "mon_weight = Monitor()\n", "\n", "# Connect monitors\n", "mon_pre_trace.probe(plast_conn.x1, num_steps)\n", "mon_post_trace.probe(plast_conn.y1, num_steps)\n", "mon_pre_spikes.probe(lif_pre.s_out, num_steps)\n", "mon_post_spikes.probe(lif_post.s_out, num_steps)\n", "mon_weight.probe(plast_conn.weights, num_steps)" ] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" }, "tags": [] }, "source": [ "### Running" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "from lava.magma.core.run_conditions import RunSteps\n", "from lava.magma.core.run_configs import Loihi2SimCfg" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# Running\n", "pattern_pre.run(condition=RunSteps(num_steps=num_steps), run_cfg=Loihi2SimCfg(select_tag=SELECT_TAG))" ] },
{ "cell_type": "code", "execution_count": 
10, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# Get data from monitors\n", "pre_trace = mon_pre_trace.get_data()['plastic_dense']['x1']\n", "post_trace = mon_post_trace.get_data()['plastic_dense']['y1']\n", "pre_spikes = mon_pre_spikes.get_data()['lif_pre']['s_out']\n", "post_spikes = mon_post_spikes.get_data()['lif_post']['s_out']\n", "weights = mon_weight.get_data()['plastic_dense']['weights'][:, :, 0]" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# Stopping\n", "pattern_pre.stop()" ] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "### Results\n", "\n", "Now we can take a look at the results of the simulation." ] },
{ "cell_type": "code", "execution_count": 12, "metadata": { "pycharm": { "name": "#%%\n" }, "tags": [] }, "outputs": [], "source": [ "import matplotlib.pyplot as plt" ] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "#### Plot spike trains" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# Plot the spike arrival times of the pre- and post-synaptic neurons\n", "def plot_spikes(spikes, legend, colors):\n", "    offsets = list(range(1, len(spikes) + 1))\n", "\n", "    plt.figure(figsize=(10, 3))\n", "\n", "    plt.eventplot(positions=spikes,\n", "                  lineoffsets=offsets,\n", "                  linelength=0.9,\n", "                  colors=colors)\n", "\n", "    plt.title(\"Spike arrival\")\n", "    plt.xlabel(\"Time steps\")\n", "    plt.ylabel(\"Neurons\")\n", "    plt.yticks(ticks=offsets, labels=legend)\n", "\n", "    plt.show()\n", "\n", "# Plot spikes\n", "plot_spikes(spikes=[np.where(post_spikes[:, 0])[0], np.where(pre_spikes[:, 0])[0]],\n", "            legend=['Post', 'Pre'],\n", "            colors=['#370665', '#f14a16'])" ] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "#### Plot traces" ] },
{ "cell_type": "code", "execution_count": 14, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "# Plot the dynamics of a time series (trace or weight)\n", "def plot_time_series(time, time_series, ylabel, title):\n", "    plt.figure(figsize=(10, 1))\n", "\n", "    plt.step(time, time_series)\n", "\n", "    plt.title(title)\n", "    plt.xlabel(\"Time steps\")\n", "    plt.ylabel(ylabel)\n", "\n", "    plt.show()\n", "\n", "# Plot pre-synaptic trace dynamics\n", "plot_time_series(time=time, time_series=pre_trace, ylabel=\"Trace value\", title=\"Pre trace\")\n", "# Plot post-synaptic trace dynamics\n", "plot_time_series(time=time, time_series=post_trace, ylabel=\"Trace value\", title=\"Post trace\")\n", "# Plot weight dynamics\n", "plot_time_series(time=time, time_series=weights, ylabel=\"Weight value\", title=\"Weight dynamics\")" ] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "#### Plot STDP learning window and weight changes" ] },
{ "cell_type": "code", "execution_count": 15, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "def extract_stdp_weight_changes(time, spikes_pre, spikes_post, wgt):\n", "    # Compute the weight change at every time step (zero where nothing changed)\n", "    w_diff = np.zeros(wgt.shape)\n", "    w_diff[1:] = np.diff(wgt)\n", "\n", "    w_diff_non_zero = np.where(w_diff != 0)\n", "    dw = w_diff[w_diff_non_zero].tolist()\n", "\n", "    # Find the absolute time of every weight change event\n", "    time = np.array(time)\n", "    t_non_zero = time[w_diff_non_zero]\n", "\n", "    # For every weight change event, compute the difference between the most recent\n", "    # post- and pre-synaptic spike times\n", "    spikes_pre = np.array(spikes_pre)\n", "    spikes_post = np.array(spikes_post)\n", "    dt = []\n", "    for i in range(0, len(dw)):\n", "        time_stamp = t_non_zero[i]\n", "        t_post = (spikes_post[np.where(spikes_post <= time_stamp)])[-1]\n", "        t_pre = (spikes_pre[np.where(spikes_pre <= time_stamp)])[-1]\n", "        dt.append(t_post - t_pre)\n", "\n", "    return np.array(dt), np.array(dw)\n", "\n", "def plot_stdp(time, spikes_pre, spikes_post, wgt,\n", "              on_pre_stdp, y1_impulse, y1_tau,\n", "              on_post_stdp, x1_impulse, x1_tau):\n", "    # Derive weight changes as a function of time differences\n", "    diff_t, diff_w = extract_stdp_weight_changes(time, spikes_pre, spikes_post, wgt)\n", "\n", "    # Derive learning rule coefficients\n", "    on_pre_stdp = eval(str(on_pre_stdp).replace(\"^\", \"**\"))\n", "    a_neg = on_pre_stdp * y1_impulse\n", "    on_post_stdp = eval(str(on_post_stdp).replace(\"^\", \"**\"))\n", "    a_pos = on_post_stdp * x1_impulse\n", "\n", "    # Derive x-axis limit (absolute value)\n", "    max_abs_dt = np.maximum(np.abs(np.max(diff_t)), np.abs(np.min(diff_t)))\n", "\n", "    # Derive x-axis for learning window computation (negative part)\n", "    x_neg = np.linspace(-max_abs_dt, 0, 1000)\n", "    # Derive learning window (negative part)\n", "    w_neg = a_neg * np.exp(x_neg / y1_tau)\n", "\n", "    # Derive x-axis for learning window computation (positive part)\n", "    x_pos = np.linspace(0, max_abs_dt, 1000)\n", "    # Derive learning window (positive part)\n", "    w_pos = a_pos * np.exp(- x_pos / x1_tau)\n", "\n", "    plt.figure(figsize=(10, 5))\n", "\n", "    plt.scatter(diff_t, diff_w, label=\"Weight changes\", color=\"b\")\n", "\n", "    plt.plot(x_neg, w_neg, label=\"W-\", color=\"r\")\n", "    plt.plot(x_pos, w_pos, label=\"W+\", color=\"g\")\n", "\n", "    plt.title(\"STDP weight changes - Learning window\")\n", "    plt.xlabel('t_post - t_pre')\n", "    plt.ylabel('Weight change')\n", "    plt.legend()\n", "    plt.grid()\n", "\n", "    plt.show()\n", "\n", "# Plot STDP window\n", "plot_stdp(time, np.where(pre_spikes[:, 0]), np.where(post_spikes[:, 0]), weights[:, 0],\n", "          stdp.A_plus, stdp.y1_impulse, stdp.tau_plus,\n", "          stdp.A_minus, stdp.x1_impulse, stdp.tau_minus)" ] },
{ "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "As can be seen, the actual weight changes follow the defined STDP learning window, with some amount of noise. If the tag is set to `fixed_pt`, the weight changes are more strongly quantized but still follow the same trend." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" }, "vscode": { "interpreter": { "hash": "adaac2fd6fcd86ccecf37a646988a7a33da53e5b6c7446bb18fec9222bbb1862" } } }, "nbformat": 4, "nbformat_minor": 4 }