PilotNet SNN Example
The network exchange module is available as lava.lib.dl.netx.{hdf5, blocks, utils}:
* hdf5 implements automatic network generation.
* blocks implements individual layer blocks.
* utils implements hdf5 reading utilities.
[1]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from lava.magma.core.run_configs import Loihi1SimCfg
from lava.magma.core.run_conditions import RunSteps
from lava.proc import io
from lava.magma.core.process.variable import Var
from lava.magma.core.process.ports.ports import RefPort
from lava.lib.dl import netx
from dataset import PilotNetDataset
Create network block
A network block can be created by instantiating netx.hdf5.Network with the path of the desired hdf5 network description file.
* The input layer is accessible as net.in_layer.
* The output layer is accessible as net.out_layer.
* All the constituent layers are accessible as a list: net.layers.
[2]:
net = netx.hdf5.Network(net_config='network.net')
print(net)
| Type | W | H | C | ker | str | pad | dil | grp |delay|
|Input | 200| 66| 3| | | | | |False|
|Conv | 99| 32| 24| 3, 3| 2, 2| 0, 0| 1, 1| 1|False|
|Conv | 49| 15| 36| 3, 3| 2, 2| 0, 0| 1, 1| 1|False|
|Conv | 24| 7| 48| 3, 3| 2, 2| 0, 0| 1, 1| 1|False|
|Conv | 22| 4| 64| 3, 3| 1, 2| 0, 1| 1, 1| 1|False|
|Conv | 20| 2| 64| 3, 3| 1, 1| 0, 0| 1, 1| 1|False|
|Dense | 1| 1| 100| | | | | |False|
|Dense | 1| 1| 50| | | | | |False|
|Dense | 1| 1| 10| | | | | |False|
|Dense | 1| 1| 1| | | | | |False|
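The Conv rows in the table can be checked with the standard convolution output-size formula, out = floor((in + 2*pad - dil*(ker - 1) - 1) / str) + 1, applied separately to width and height. This is a minimal sketch. It assumes that netx follows this usual convention and that the ker/str/pad/dil columns list (W, H) pairs. The function name `conv_out` is hypothetical. The assumption is consistent with every row printed above.

```python
def conv_out(size, kernel, stride, padding, dilation=1):
    """Standard convolution output size along one spatial axis."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

# (kernel, stride_w, stride_h, pad_w, pad_h) copied from the Conv rows of the table
conv_params = [
    (3, 2, 2, 0, 0),
    (3, 2, 2, 0, 0),
    (3, 2, 2, 0, 0),
    (3, 1, 2, 0, 1),
    (3, 1, 1, 0, 0),
]

w, h = 200, 66  # input layer width and height
shapes = []
for k, sw, sh, pw, ph in conv_params:
    w = conv_out(w, k, sw, pw)
    h = conv_out(h, k, sh, ph)
    shapes.append((w, h))

# Number of values feeding the first Dense layer (last Conv has 64 channels)
flat = shapes[-1][0] * shapes[-1][1] * 64

print(shapes)  # [(99, 32), (49, 15), (24, 7), (22, 4), (20, 2)]
print(flat)    # 2560
```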
[3]:
print(f'There are {len(net)} layers in network:')
for l in net.layers:
    print(f'{l.__class__.__name__:5s} : {l.name:10s}, shape : {l.shape}')
There are 10 layers in network:
Input : Process_1 , shape : (200, 66, 3)
Conv : Process_3 , shape : (99, 32, 24)
Conv : Process_6 , shape : (49, 15, 36)
Conv : Process_9 , shape : (24, 7, 48)
Conv : Process_12, shape : (22, 4, 64)
Conv : Process_15, shape : (20, 2, 64)
Dense : Process_18, shape : (100,)
Dense : Process_21, shape : (50,)
Dense : Process_24, shape : (10,)
Dense : Process_27, shape : (1,)
[4]:
num_samples = 201
steps_per_sample = 16
readout_offset = (steps_per_sample - 1) + len(net.layers)
num_steps = num_samples * steps_per_sample
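The timing constants in the cell above follow from simple arithmetic. This sketch reproduces them without Lava, using the 10 layers of the network printed earlier. It assumes that the `len(net.layers)` term accounts for one time step of latency per layer, so a sample's output is read after its last step plus the pipeline depth. That interpretation is an assumption, not something the notebook states.

```python
num_layers = 10          # len(net.layers) for the network printed above
num_samples = 201
steps_per_sample = 16

# Last step of a sample window plus one step of latency per layer (assumed)
readout_offset = (steps_per_sample - 1) + num_layers
# Total simulation length: every sample is presented for steps_per_sample steps
num_steps = num_samples * steps_per_sample

print(readout_offset, num_steps)  # 25 3216
```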
Create Dataset instance
The dataset class is typically written or supplied by the user; here PilotNetDataset is imported from the local dataset module.
[5]:
full_set = PilotNetDataset(
path='../data',
transform=net.in_layer.transform, # input transform
visualize=True, # visualize ensures the images are returned in sequence
sample_offset=10550,
)
train_set = PilotNetDataset(
path='../data',
transform=net.in_layer.transform, # input transform
train=True,
)
test_set = PilotNetDataset(
path='../data',
transform=net.in_layer.transform, # input transform
train=False,
)
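PilotNetDataset's source is not shown here. The sketch below is a hypothetical stand-in that illustrates the kind of interface a user-written dataset might expose. It assumes the dataloader only needs `len(dataset)` and `dataset[i]` returning an `(input, ground_truth)` pair. The class name `ToySteeringDataset`, the random frames, and the lambda transform are all illustrative placeholders, not the real PilotNet data or `net.in_layer.transform`.

```python
import numpy as np

class ToySteeringDataset:
    """Hypothetical stand-in for PilotNetDataset: returns (image, steering angle)."""

    def __init__(self, num_frames=8, width=200, height=66, transform=None, seed=0):
        rng = np.random.default_rng(seed)
        # Random frames in [0, 1), stored as (W, H, C) to match the input layer shape
        self.frames = rng.random((num_frames, width, height, 3), dtype=np.float32)
        # Random steering angles in radians
        self.angles = rng.uniform(-1.0, 1.0, size=num_frames).astype(np.float32)
        self.transform = transform

    def __len__(self):
        return len(self.frames)

    def __getitem__(self, index):
        image = self.frames[index]
        if self.transform is not None:
            image = self.transform(image)  # e.g. the network's input transform
        return image, self.angles[index]

ds = ToySteeringDataset(transform=lambda x: x - 0.5)  # hypothetical transform
image, angle = ds[0]
print(len(ds), image.shape, image.dtype)
```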
Instantiate Dataloader
[6]:
dataloader = io.dataloader.StateDataloader(
dataset=full_set,
interval=steps_per_sample,
)
Sample: 10550
Connect Input and Output
[7]:
gt_logger = io.sink.RingBuffer(shape=(1,), buffer=num_samples)
output_logger = io.sink.Read(
num_samples,
interval=steps_per_sample,
offset=readout_offset
)
# reset the u and v states of every layer except the output layer,
# staggering each layer's reset by its index in the network
for i, l in enumerate(net.layers[:-1]):
    u_resetter = io.reset.Reset(interval=steps_per_sample, offset=i)
    v_resetter = io.reset.Reset(interval=steps_per_sample, offset=i)
    u_resetter.connect_var(l.neuron.u)
    v_resetter.connect_var(l.neuron.v)
dataloader.ground_truth.connect(gt_logger.a_in)
dataloader.connect_var(net.in_layer.neuron.bias)
output_logger.connect_var(net.out_layer.neuron.v)
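The reset loop gives each layer its own offset within the 16-step sample window. This sketch tabulates that schedule for one window. It assumes a `Reset` with a given `interval` and `offset` fires at steps where `t % interval == offset`. The helper `fires` is hypothetical, and the real trigger rule in `lava.proc.io` may differ. Under that assumption, and with one step of latency per layer, layer i is cleared just as the new sample's activity reaches it. The output layer is left out of the loop.

```python
steps_per_sample = 16
num_layers = 10  # the reset loop covers net.layers[:-1], i.e. layers 0..8

def fires(t, interval, offset):
    """Assumed trigger rule: fire when step t is `offset` steps into an interval."""
    return t % interval == offset

# For one sample window, collect which layers are reset at each step
schedule = {}
for t in range(steps_per_sample):
    reset_layers = [i for i in range(num_layers - 1) if fires(t, steps_per_sample, i)]
    if reset_layers:
        schedule[t] = reset_layers

print(schedule)
# {0: [0], 1: [1], 2: [2], 3: [3], 4: [4], 5: [5], 6: [6], 7: [7], 8: [8]}
```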
Run the network
[8]:
run_config = Loihi1SimCfg(select_tag='fixed_pt')
[9]:
net.run(condition=RunSteps(num_steps=num_steps), run_cfg=run_config)
results = output_logger.data.get().flatten()
net.stop()
Sample: 10750
Evaluate Results
[10]:
results = results.flatten()/steps_per_sample/32/64
# the output layer is not reset, so consecutive readings are differenced
results = results[1:] - results[:-1]
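The differencing step can be illustrated on toy numbers. This sketch assumes the output neuron's voltage keeps integrating its input across samples, because the output layer is excluded from the reset loop. Under that assumption, each reading contains the sum of all previous samples, and subtracting consecutive readings isolates one sample's contribution. The drive values are invented for illustration, and the extra `/32/64` scaling from the cell above is left out. Note that differencing drops the first sample, so `results` has one fewer entry than the number of readings.

```python
import numpy as np

steps_per_sample = 16
# Hypothetical constant per-step input to the output neuron for three samples
drive = np.repeat([2.0, 5.0, -1.0], steps_per_sample)

# Without resets, the voltage accumulates across sample boundaries
v = np.cumsum(drive)

# Read the voltage once at the end of each sample window
readings = v[steps_per_sample - 1::steps_per_sample]  # [32, 112, 96]

per_sample = readings / steps_per_sample              # [2, 7, 6]
per_sample = per_sample[1:] - per_sample[:-1]         # [5, -1]
print(per_sample)
```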
gt = np.load('3x3pred.npy')
[12]:
plt.figure(figsize=(15, 10))
plt.plot(gt, linewidth=5, label='Loihi ground truth')
plt.plot(results, label='Lava output')
plt.xlabel(f'Sample frames (+{full_set.sample_offset})')
plt.ylabel('Steering angle (radians)')
plt.legend()
[Figure: steering angle (radians) vs. sample frames (+10550), Loihi ground truth and Lava output]
[13]:
error = np.sum((gt - results)**2)
print(f'{error=}')
error=0.0