Skip to content

Commit

Permalink
Add on display data
Browse files Browse the repository at this point in the history
  • Loading branch information
jakubno committed Mar 30, 2024
1 parent 2d9bbf0 commit b016e17
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
4 changes: 3 additions & 1 deletion python/e2b_code_interpreter/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def exec_cell(
kernel_id: Optional[str] = None,
on_stdout: Optional[Callable[[ProcessMessage], Any]] = None,
on_stderr: Optional[Callable[[ProcessMessage], Any]] = None,
on_display_data: Optional[Callable[[Dict[str, Any]], Any]] = None,
timeout: Optional[float] = TIMEOUT,
) -> Execution:
"""
Expand All @@ -73,6 +74,7 @@ def exec_cell(
:param kernel_id: The ID of the kernel to execute the code on. If not provided, the default kernel is used.
:param on_stdout: A callback function to handle standard output messages from the code execution.
:param on_stderr: A callback function to handle standard error messages from the code execution.
:param on_display_data: A callback function to handle display data messages from the code execution.
:param timeout: Timeout for the call
:return: Result of the execution
Expand All @@ -89,7 +91,7 @@ def exec_cell(
logger.debug(f"Creating new websocket connection to kernel {kernel_id}")
ws = self._connect_to_kernel_ws(kernel_id, timeout=timeout)

session_id = ws.send_execution_message(code, on_stdout, on_stderr)
session_id = ws.send_execution_message(code, on_stdout, on_stderr, on_display_data)
logger.debug(
f"Sent execution message to kernel {kernel_id}, session_id: {session_id}"
)
Expand Down
21 changes: 14 additions & 7 deletions python/e2b_code_interpreter/messaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from e2b.utils.future import DeferredFuture
from pydantic import ConfigDict, PrivateAttr, BaseModel

from e2b_code_interpreter.models import Execution, Data, Error
from e2b_code_interpreter.models import Execution, Data, Error, MIMEType

logger = logging.getLogger(__name__)

Expand All @@ -26,18 +26,21 @@ class CellExecution:
"""

input_accepted: bool = False
on_stdout: Optional[Callable[[Any], None]] = None
on_stderr: Optional[Callable[[Any], None]] = None
on_stdout: Optional[Callable[[ProcessMessage], None]] = None
on_stderr: Optional[Callable[[ProcessMessage], None]] = None
on_display_data: Optional[Callable[[Dict[MIMEType, str]], None]] = None

def __init__(
self,
on_stdout: Optional[Callable[[Any], None]] = None,
on_stderr: Optional[Callable[[Any], None]] = None,
on_stdout: Optional[Callable[[ProcessMessage], None]] = None,
on_stderr: Optional[Callable[[ProcessMessage], None]] = None,
on_display_data: Optional[Callable[[Dict[MIMEType, str]], None]] = None,
):
self.partial_result = Execution()
self.result = Future()
self.on_stdout = on_stdout
self.on_stderr = on_stderr
self.on_display_data = on_display_data


class JupyterKernelWebSocket(BaseModel):
Expand Down Expand Up @@ -126,15 +129,17 @@ def _get_execute_request(msg_id: str, code: str) -> str:
def send_execution_message(
self,
code: str,
on_stdout: Optional[Callable[[Any], None]] = None,
on_stderr: Optional[Callable[[Any], None]] = None,
on_stdout: Optional[Callable[[ProcessMessage], None]] = None,
on_stderr: Optional[Callable[[ProcessMessage], None]] = None,
on_display_data: Optional[Callable[[Dict[MIMEType, str]], None]] = None,
) -> str:
message_id = str(uuid.uuid4())
logger.debug(f"Sending execution message: {message_id}")

self._cells[message_id] = CellExecution(
on_stdout=on_stdout,
on_stderr=on_stderr,
on_display_data=on_display_data,
)
request = self._get_execute_request(message_id, code)
self._queue_in.put(request)
Expand Down Expand Up @@ -196,6 +201,8 @@ def _receive_message(self, data: dict):

elif data["msg_type"] in "display_data":
result.data.append(Data(is_main_result=False, data=data["content"]["data"]))
if cell.on_display_data:
cell.on_display_data(data["content"]["data"])
elif data["msg_type"] == "execute_result":
result.data.append(Data(is_main_result=True, data=data["content"]["data"]))
elif data["msg_type"] == "status":
Expand Down

0 comments on commit b016e17

Please sign in to comment.