Skip to content

Commit

Permalink
Runtime error handling (lava-nc#135)
Browse files Browse the repository at this point in the history
* - Initial commit on passing exception from ProcessModels up to the runtime

* - Initial commit on passing exception from ProcessModels up to the runtime

* - Updated commit on passing exception from ProcessModels up to the runtime

* - Updated commit on passing exception from ProcessModels up to the runtime
  • Loading branch information
PhilippPlank authored Dec 4, 2021
1 parent fa908f4 commit 31495d6
Show file tree
Hide file tree
Showing 7 changed files with 235 additions and 38 deletions.
63 changes: 35 additions & 28 deletions src/lava/magma/core/model/py/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,34 +125,41 @@ def run(self):
self.process_to_service_ack.send(MGMT_RESPONSE.TERMINATED)
self.join()
return
# Spiking phase - increase time step
if enum_equal(phase, PyLoihiProcessModel.Phase.SPK):
self.current_ts += 1
self.run_spk()
self.process_to_service_ack.send(MGMT_RESPONSE.DONE)
# Pre-management phase
elif enum_equal(phase, PyLoihiProcessModel.Phase.PRE_MGMT):
# Enable via guard method
if self.pre_guard():
self.run_pre_mgmt()
self.process_to_service_ack.send(MGMT_RESPONSE.DONE)
# Learning phase
elif enum_equal(phase, PyLoihiProcessModel.Phase.LRN):
# Enable via guard method
if self.lrn_guard():
self.run_lrn()
self.process_to_service_ack.send(MGMT_RESPONSE.DONE)
# Post-management phase
elif enum_equal(phase, PyLoihiProcessModel.Phase.POST_MGMT):
# Enable via guard method
if self.post_guard():
self.run_post_mgmt()
self.process_to_service_ack.send(MGMT_RESPONSE.DONE)
# Host phase - called at the last time step before STOP
elif enum_equal(phase, PyLoihiProcessModel.Phase.HOST):
pass
else:
raise ValueError(f"Wrong Phase Info Received : {phase}")
try:
# Spiking phase - increase time step
if enum_equal(phase, PyLoihiProcessModel.Phase.SPK):
self.current_ts += 1
self.run_spk()
self.process_to_service_ack.send(MGMT_RESPONSE.DONE)
# Pre-management phase
elif enum_equal(phase, PyLoihiProcessModel.Phase.PRE_MGMT):
# Enable via guard method
if self.pre_guard():
self.run_pre_mgmt()
self.process_to_service_ack.send(MGMT_RESPONSE.DONE)
# Learning phase
elif enum_equal(phase, PyLoihiProcessModel.Phase.LRN):
# Enable via guard method
if self.lrn_guard():
self.run_lrn()
self.process_to_service_ack.send(MGMT_RESPONSE.DONE)
# Post-management phase
elif enum_equal(phase, PyLoihiProcessModel.Phase.POST_MGMT):
# Enable via guard method
if self.post_guard():
self.run_post_mgmt()
self.process_to_service_ack.send(MGMT_RESPONSE.DONE)
# Host phase - called at the last time step before STOP
elif enum_equal(phase, PyLoihiProcessModel.Phase.HOST):
pass
else:
raise ValueError(f"Wrong Phase Info Received : {phase}")
except Exception as inst:
# Inform runtime service about termination
self.process_to_service_ack.send(MGMT_RESPONSE.ERROR)
self.join()
raise inst

elif action == 'req':
# Handle get/set Var requests from runtime service
self._handle_get_set_var()
Expand Down
24 changes: 23 additions & 1 deletion src/lava/magma/runtime/message_infrastructure/multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
from lava.magma.compiler.builder import PyProcessBuilder, \
AbstractRuntimeServiceBuilder

from multiprocessing import Process as SystemProcess
import multiprocessing as mp
from multiprocessing.managers import SharedMemoryManager
import traceback

from lava.magma.compiler.channels.interfaces import ChannelType, Channel
from lava.magma.compiler.channels.pypychannel import PyPyChannel
Expand All @@ -18,6 +19,27 @@
import MessageInfrastructureInterface


class SystemProcess(mp.Process):
    """Process that reports an exception raised by its target to the parent.

    A pipe is created in ``__init__``. ``run`` sends ``None`` through it when
    the target finishes normally, or an ``(exception, traceback_text)`` tuple
    when the target raises. The ``exception`` property reads that message.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``multiprocessing.Process``."""
        mp.Process.__init__(self, *args, **kwargs)
        # _cconn is written by run(); _pconn is read by the exception property
        self._pconn, self._cconn = mp.Pipe()
        self._exception = None

    def run(self):
        """Run the target and send its outcome through the pipe."""
        try:
            mp.Process.run(self)
        except Exception as e:
            tb = traceback.format_exc()
            try:
                self._cconn.send((e, tb))
            except Exception:
                # send() pickles its argument; an exception carrying
                # unpicklable state would otherwise raise here and the
                # error report would be lost. Send a picklable summary
                # with the same traceback text instead.
                summary = RuntimeError(f"{type(e).__name__}: {e}")
                self._cconn.send((summary, tb))
        else:
            self._cconn.send(None)

    @property
    def exception(self):
        """Return the outcome reported by ``run``.

        Returns ``None`` if nothing has been received yet or the target
        finished without raising; otherwise the ``(exception, traceback)``
        tuple sent by ``run``. The last received value is cached.
        """
        if self._pconn.poll():
            self._exception = self._pconn.recv()
        return self._exception


class MultiProcessing(MessageInfrastructureInterface):
"""Implements message passing using shared memory and multiprocessing"""
def __init__(self):
Expand Down
4 changes: 3 additions & 1 deletion src/lava/magma/runtime/mgmt_token_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,5 +65,7 @@ class MGMT_RESPONSE:
"""Signifies Ack or Finished with the Command"""
TERMINATED = enum_to_np(-1)
"""Signifies Termination"""
PAUSED = enum_to_np(-2)
ERROR = enum_to_np(-2)
"""Signifies Error raised"""
PAUSED = enum_to_np(-3)
"""Signifies Execution State to be Paused"""
17 changes: 16 additions & 1 deletion src/lava/magma/runtime/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,22 @@ def _run(self, run_condition):
for recv_port in self.service_to_runtime_ack:
data = recv_port.recv()
if not enum_equal(data, MGMT_RESPONSE.DONE):
raise RuntimeError(f"Runtime Received {data}")
if enum_equal(data, MGMT_RESPONSE.ERROR):
# Receive all errors from the ProcessModels
error_cnt = 0
for actors in \
self._messaging_infrastructure.actors:
actors.join()
if actors.exception:
_, traceback = actors.exception
print(traceback)
error_cnt += 1

raise RuntimeError(
f"{error_cnt} Exception(s) occurred. See "
f"output above for details.")
else:
raise RuntimeError(f"Runtime Received {data}")
if run_condition.blocking:
self.current_ts += self.num_steps
self._is_running = False
Expand Down
16 changes: 12 additions & 4 deletions src/lava/magma/runtime/runtime_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,11 +228,19 @@ def run(self):
# ProcessModels respond with DONE if not HOST phase
if not enum_equal(
phase, LoihiPyRuntimeService.Phase.HOST):
rsps = self._get_pm_resp()
for rsp in rsps:

for rsp in self._get_pm_resp():
if not enum_equal(rsp, MGMT_RESPONSE.DONE):
raise ValueError(
f"Wrong Response Received : {rsp}")
if enum_equal(rsp, MGMT_RESPONSE.ERROR):
# Forward error to runtime
self.service_to_runtime_ack.send(
MGMT_RESPONSE.ERROR)
# stop all other pm
self._send_pm_cmd(MGMT_COMMAND.STOP)
return
else:
raise ValueError(
f"Wrong Response Received : {rsp}")

# If HOST phase (last time step ended) break the loop
if enum_equal(
Expand Down
146 changes: 146 additions & 0 deletions tests/lava/magma/runtime/test_exception_handling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Copyright (C) 2021 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause
# See: https://spdx.org/licenses/

import unittest

from lava.magma.core.decorator import implements, requires, tag
from lava.magma.core.model.py.model import PyLoihiProcessModel
from lava.magma.core.model.py.ports import PyOutPort, PyInPort
from lava.magma.core.model.py.type import LavaPyType
from lava.magma.core.process.ports.ports import OutPort, InPort
from lava.magma.core.process.process import AbstractProcess
from lava.magma.core.resources import CPU
from lava.magma.core.run_configs import Loihi1SimCfg
from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
from lava.magma.core.run_conditions import RunSteps


# A minimal process with an OutPort
class P1(AbstractProcess):
    """Minimal process whose only port is the OutPort ``out`` of shape (2,)."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        port_shape = (2,)
        self.out = OutPort(shape=port_shape)


# A minimal process with an InPort
class P2(AbstractProcess):
    """Minimal process whose only port is the InPort ``inp`` of shape (2,)."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        port_shape = (2,)
        self.inp = InPort(shape=port_shape)


# A second minimal process with an InPort (its ProcessModel never raises)
class P3(AbstractProcess):
    """Minimal process whose only port is the InPort ``inp`` of shape (2,)."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        port_shape = (2,)
        self.inp = InPort(shape=port_shape)


# A minimal PyProcModel implementing P1
@implements(proc=P1, protocol=LoihiProtocol)
@requires(CPU)
@tag('floating_pt')
class PyProcModel1(PyLoihiProcessModel):
    """ProcessModel for P1 that raises an AssertionError after time step 1."""

    out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, int)

    def run_spk(self):
        """Do nothing while ``current_ts`` is at most 1, then raise."""
        if self.current_ts <= 1:
            return
        # Raise exception
        raise AssertionError("All the error info")


# A minimal PyProcModel implementing P2
@implements(proc=P2, protocol=LoihiProtocol)
@requires(CPU)
@tag('floating_pt')
class PyProcModel2(PyLoihiProcessModel):
    """ProcessModel for P2 that raises a TypeError after time step 1."""

    inp: PyInPort = LavaPyType(PyInPort.VEC_DENSE, int)

    def run_spk(self):
        """Do nothing while ``current_ts`` is at most 1, then raise."""
        if self.current_ts <= 1:
            return
        # Raise exception
        raise TypeError("All the error info")


# A minimal PyProcModel implementing P3
@implements(proc=P3, protocol=LoihiProtocol)
@requires(CPU)
@tag('floating_pt')
class PyProcModel3(PyLoihiProcessModel):
    """ProcessModel for P3 whose spiking phase does nothing and never raises."""

    inp: PyInPort = LavaPyType(PyInPort.VEC_DENSE, int)

    def run_spk(self):
        """Intentionally empty spiking phase."""
        return None


class TestExceptionHandling(unittest.TestCase):
    """Checks that exceptions raised in ProcessModels reach the runtime."""

    def _assert_second_step_raises(self, proc, expected_errors):
        """Run ``proc`` for two single steps and check that the second fails.

        The raising ProcessModels in this file only raise once
        ``current_ts > 1``, so the first step must pass and the second must
        raise a RuntimeError reporting ``expected_errors`` exceptions.
        """
        # Run the network for 1 time step -> no exception
        proc.run(condition=RunSteps(num_steps=1), run_cfg=Loihi1SimCfg())

        # Run the network for another time step -> expect exception
        with self.assertRaises(RuntimeError) as context:
            proc.run(condition=RunSteps(num_steps=1), run_cfg=Loihi1SimCfg())

        exception = context.exception
        # assertRaises also accepts subclasses; insist on the exact type
        self.assertEqual(RuntimeError, type(exception))
        # assertIn shows the actual message on failure
        self.assertIn(
            f'{expected_errors} Exception(s) occurred', str(exception))

    def test_one_pm(self):
        """Checks the forwarding of exceptions within a ProcessModel to the
        runtime."""

        # Create an instance of P1
        proc = P1()

        # 1 exception in the ProcessModel expected
        self._assert_second_step_raises(proc, 1)

    def test_two_pm(self):
        """Checks the forwarding of exceptions within two ProcessModels to
        the runtime."""

        # Create a sender instance of P1 and a receiver instance of P2
        sender = P1()
        recv = P2()

        # Connect sender with receiver
        sender.out.connect(recv.inp)

        # 2 Exceptions in the ProcessModels expected
        self._assert_second_step_raises(sender, 2)

    def test_three_pm(self):
        """Checks the forwarding of exceptions within three ProcessModels to
        the runtime."""

        # Create a sender instance of P1 and receiver instances of P2 and P3
        sender = P1()
        recv1 = P2()
        recv2 = P3()

        # Connect sender with both receivers
        sender.out.connect([recv1.inp, recv2.inp])

        # 2 Exceptions in the ProcessModels expected (P3's model never raises)
        self._assert_second_step_raises(sender, 2)


# Run the tests when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main(
        buffer=True,
    )
3 changes: 0 additions & 3 deletions tests/lava/magma/runtime/test_loihi_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,6 @@ def pre_guard(self):
def lrn_guard(self):
return False

def host_guard(self):
return True


class TestProcess(unittest.TestCase):
def test_synchronization_single_process_model(self):
Expand Down

0 comments on commit 31495d6

Please sign in to comment.