Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed blocking message send with asyncio #120

Merged
merged 1 commit into from
Dec 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions zt_backend/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import threading
import traceback
import sys
import asyncio
import trace

class ConnectionManager:
Expand Down Expand Up @@ -88,6 +89,7 @@ def kill(self):
user_states={}
user_timers={}
user_threads={}
user_message_tasks={}
notebook_state=UserState('')
run_mode = settings.run_mode

Expand All @@ -105,6 +107,7 @@ def ws_url():
async def run_code(websocket: WebSocket):
global current_thread
if(run_mode=='dev'):
message_send = asyncio.create_task(websocket_message_sender(notebook_state))
await manager.connect(websocket)
try:
while True:
Expand All @@ -115,6 +118,8 @@ async def run_code(websocket: WebSocket):
current_thread.start()
except WebSocketDisconnect:
manager.disconnect(websocket)
finally:
message_send.cancel()

@router.websocket("/ws/component_run")
async def component_run(websocket: WebSocket):
Expand Down Expand Up @@ -255,6 +260,7 @@ async def load_notebook(websocket: WebSocket):
userId = str(uuid.uuid4())
notebook_start.userId = userId
user_states[userId]=UserState(userId)
user_message_tasks[userId]=asyncio.create_task(websocket_message_sender(user_states[userId]))
timer_set(userId, 1800)
cells = []
components={}
Expand Down Expand Up @@ -318,14 +324,32 @@ async def stop_execution(websocket: WebSocket):
except WebSocketDisconnect:
manager.disconnect(websocket)

@router.on_event('shutdown')
def shutdown():
if current_thread:
current_thread.kill()
for user_id in user_threads:
if user_threads[user_id]:
user_threads[user_id].kill()
for user_id in user_timers:
if user_timers[user_id]:
user_timers[user_id].cancel()
for user_id in user_message_tasks:
if user_message_tasks[user_id]:
user_message_tasks[user_id].cancel()

def remove_user_state(user_id):
try:
if user_id in user_timers:
# Cancel and remove the associated timer
timer = user_timers[user_id]
message_sender = user_message_tasks[user_id]
if timer:
timer.cancel()
del user_timers[user_id]
if message_sender:
message_sender.cancel()
del user_message_tasks[user_id]
if user_id in user_states: del user_states[user_id]
logger.debug("User state removed for user %s", user_id)
except Exception as e:
Expand Down
9 changes: 4 additions & 5 deletions zt_backend/runner/execute_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def execute_request(request: request.Request, state: UserState):

for code_cell_id in downstream_cells:
code_cell = dependency_graph.cells[code_cell_id]
asyncio.run(execution_state.websocket.send_json({"cell_id": code_cell_id, "clear_output": True}))
execution_state.message_queue.put_nowait({"cell_id": code_cell_id, "clear_output": True})
execution_state.io_output = StringIO()
execute_cell(code_cell_id, code_cell, component_globals, dependency_graph, execution_state)
try:
Expand All @@ -101,7 +101,7 @@ def execute_request(request: request.Request, state: UserState):

cell_response = response.CellResponse(id=code_cell_id, layout=layout, components=execution_state.current_cell_components, output=execution_state.io_output.getvalue())
cell_outputs.append(cell_response)
asyncio.run(execution_state.websocket.send_json(cell_response.model_dump_json()))
execution_state.message_queue.put_nowait(cell_response.model_dump_json())
execution_state.current_cell_components.clear()
execution_state.current_cell_layout.clear()
execution_state.cell_outputs_dict['previous_dependecy_graph'] = dependency_graph
Expand All @@ -111,16 +111,15 @@ def execute_request(request: request.Request, state: UserState):
execution_response = response.Response(cells=cell_outputs)
if settings.run_mode=='dev':
globalStateUpdate(run_response=execution_response)
asyncio.run(execution_state.websocket.send_json({"complete": True}))
return execution_response
execution_state.message_queue.put_nowait({"complete": True})

def execute_cell(code_cell_id, code_cell, component_globals, dependency_graph, execution_state: UserContext):
class WebSocketStream:
def write(self, message):
user_state = UserContext.get_state()
if user_state:
user_state.io_output.write(message)
asyncio.run(user_state.websocket.send_json({"cell_id": code_cell_id, "output": message}))
user_state.message_queue.put_nowait({"cell_id": code_cell_id, "output": message})

def flush(self):
pass
Expand Down
2 changes: 2 additions & 0 deletions zt_backend/runner/user_state.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import threading
import asyncio

class UserState:
def __init__(self, user_id):
Expand All @@ -11,6 +12,7 @@ def __init__(self, user_id):
self.cell_outputs_dict = {}
self.websocket = None
self.io_output = None
self.message_queue = asyncio.Queue()

class UserContext:
_state = threading.local()
Expand Down
8 changes: 7 additions & 1 deletion zt_backend/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import OrderedDict
from zt_backend.runner.user_state import UserState
from zt_backend.models import request, notebook, response
from dictdiffer import diff
import logging
Expand Down Expand Up @@ -156,4 +157,9 @@ def save_toml():
def get_code_completions(cell_id:str, code: str, line: int, column: int) -> list:
script = jedi.Script(code)
completions = script.complete(line, column)
return {"cell_id": cell_id, "completions": [{"label": completion.name, "type": completion.type} for completion in completions]}
return {"cell_id": cell_id, "completions": [{"label": completion.name, "type": completion.type} for completion in completions]}

async def websocket_message_sender(execution_state: UserState):
while True:
message = await execution_state.message_queue.get()
await execution_state.websocket.send_json(message)