From b016e178da5c4b5535b3164803a503319c6df3be Mon Sep 17 00:00:00 2001 From: Jakub Novak Date: Fri, 29 Mar 2024 21:24:04 -0700 Subject: [PATCH] Add on display data --- python/e2b_code_interpreter/main.py | 4 +++- python/e2b_code_interpreter/messaging.py | 21 ++++++++++++++------- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/python/e2b_code_interpreter/main.py b/python/e2b_code_interpreter/main.py index 4c0c047..e39112e 100644 --- a/python/e2b_code_interpreter/main.py +++ b/python/e2b_code_interpreter/main.py @@ -64,6 +64,7 @@ def exec_cell( kernel_id: Optional[str] = None, on_stdout: Optional[Callable[[ProcessMessage], Any]] = None, on_stderr: Optional[Callable[[ProcessMessage], Any]] = None, + on_display_data: Optional[Callable[[Dict[str, Any]], Any]] = None, timeout: Optional[float] = TIMEOUT, ) -> Execution: """ @@ -73,6 +74,7 @@ def exec_cell( :param kernel_id: The ID of the kernel to execute the code on. If not provided, the default kernel is used. :param on_stdout: A callback function to handle standard output messages from the code execution. :param on_stderr: A callback function to handle standard error messages from the code execution. + :param on_display_data: A callback function to handle display data messages from the code execution. :param timeout: Timeout for the call :return: Result of the execution @@ -89,7 +91,7 @@ def exec_cell( logger.debug(f"Creating new websocket connection to kernel {kernel_id}") ws = self._connect_to_kernel_ws(kernel_id, timeout=timeout) - session_id = ws.send_execution_message(code, on_stdout, on_stderr) + session_id = ws.send_execution_message(code, on_stdout, on_stderr, on_display_data) logger.debug( f"Sent execution message to kernel {kernel_id}, session_id: {session_id}" ) diff --git a/python/e2b_code_interpreter/messaging.py b/python/e2b_code_interpreter/messaging.py index 10fe14a..fa1c840 100644 --- a/python/e2b_code_interpreter/messaging.py +++ b/python/e2b_code_interpreter/messaging.py @@ -14,7 +14,7 @@ from e2b.utils.future import DeferredFuture from pydantic import ConfigDict, PrivateAttr, BaseModel -from e2b_code_interpreter.models import Execution, Data, Error +from e2b_code_interpreter.models import Execution, Data, Error, MIMEType logger = logging.getLogger(__name__) @@ -26,18 +26,21 @@ class CellExecution: """ input_accepted: bool = False - on_stdout: Optional[Callable[[Any], None]] = None - on_stderr: Optional[Callable[[Any], None]] = None + on_stdout: Optional[Callable[[ProcessMessage], None]] = None + on_stderr: Optional[Callable[[ProcessMessage], None]] = None + on_display_data: Optional[Callable[[Dict[MIMEType, str]], None]] = None def __init__( self, - on_stdout: Optional[Callable[[Any], None]] = None, - on_stderr: Optional[Callable[[Any], None]] = None, + on_stdout: Optional[Callable[[ProcessMessage], None]] = None, + on_stderr: Optional[Callable[[ProcessMessage], None]] = None, + on_display_data: Optional[Callable[[Dict[MIMEType, str]], None]] = None, ): self.partial_result = Execution() self.result = Future() self.on_stdout = on_stdout self.on_stderr = on_stderr + self.on_display_data = on_display_data class JupyterKernelWebSocket(BaseModel): @@ -126,8 +129,9 @@ def _get_execute_request(msg_id: str, code: str) -> str: def send_execution_message( self, code: str, - on_stdout: Optional[Callable[[Any], None]] = None, - on_stderr: Optional[Callable[[Any], None]] = None, + on_stdout: Optional[Callable[[ProcessMessage], None]] = None, + on_stderr: Optional[Callable[[ProcessMessage], None]] = None, + on_display_data: Optional[Callable[[Dict[MIMEType, str]], None]] = None, ) -> str: message_id = str(uuid.uuid4()) logger.debug(f"Sending execution message: {message_id}") @@ -135,6 +139,7 @@ def send_execution_message( self._cells[message_id] = CellExecution( on_stdout=on_stdout, on_stderr=on_stderr, + on_display_data=on_display_data, ) request = self._get_execute_request(message_id, code) self._queue_in.put(request) @@ -196,6 +201,8 @@ def _receive_message(self, data: dict): elif data["msg_type"] in "display_data": result.data.append(Data(is_main_result=False, data=data["content"]["data"])) + if cell.on_display_data: + cell.on_display_data(data["content"]["data"]) elif data["msg_type"] == "execute_result": result.data.append(Data(is_main_result=True, data=data["content"]["data"])) elif data["msg_type"] == "status":