diff --git a/js/src/index.ts b/js/src/index.ts index a609724..0610c72 100644 --- a/js/src/index.ts +++ b/js/src/index.ts @@ -1,7 +1,7 @@ export { CodeInterpreter, JupyterExtension } from './code-interpreter' export type { CreateKernelProps } from './code-interpreter' -export type { Error, Data, Result, MIMEType } from './messaging' +export type { ExecutionError, Data, Result, MIMEType } from './messaging' import { CodeInterpreter } from './code-interpreter' export default CodeInterpreter diff --git a/js/src/messaging.ts b/js/src/messaging.ts index b13dece..57d7dc7 100644 --- a/js/src/messaging.ts +++ b/js/src/messaging.ts @@ -9,7 +9,7 @@ import { ProcessMessage } from 'e2b' * @property {string} value - Value of the error. * @property {string} traceback - The traceback of the error. */ -export class Error { +export class ExecutionError { name: string value: string tracebackRaw: string[] @@ -77,6 +77,39 @@ export class Data { isMainResult: boolean raw: RawData + + constructor(data: RawData, isMainResult: boolean) { + this.text = data['text/plain'] + this.html = data['text/html'] + this.markdown = data['text/markdown'] + this.svg = data['image/svg+xml'] + this.png = data['image/png'] + this.jpeg = data['image/jpeg'] + this.pdf = data['application/pdf'] + this.latex = data['text/latex'] + this.json = data['application/json'] + this.javascript = data['application/javascript'] + this.isMainResult = isMainResult + this.raw = data + + this.extra = {} + for (const key in data) { + if (![ + 'text/plain', + 'text/html', + 'text/markdown', + 'image/svg+xml', + 'image/png', + 'image/jpeg', + 'application/pdf', + 'text/latex', + 'application/json', + 'application/javascript' + ].includes(key)) { + this.extra[key] = data[key] + } + } + } } /** @@ -93,16 +126,16 @@ export type Logs = { * Represents the result of a cell execution. * @property {Data} data - List of result of the cell (interactively interpreted last line), display calls, e.g. matplotlib plots. * @property {Logs} logs - "Logs printed to stdout and stderr during execution." - * @property {Error | null} error - An Error object if an error occurred, null otherwise. + * @property {ExecutionError | null} error - An Error object if an error occurred, null otherwise. */ export class Result { constructor( public data: Data[], public logs: Logs, - public error?: Error + public error?: ExecutionError ) {} - public text(): string | undefined { + public get text(): string | undefined { for (const data of this.data) { if (data.isMainResult) { return data.text @@ -183,14 +216,14 @@ export class JupyterKernelWebSocket { const result = cell.result if (message.msg_type == 'error') { - result.error = { - name: message.content.ename, - value: message.content.evalue, - tracebackRaw: message.content.traceback - } + result.error = new ExecutionError( + message.content.ename, + message.content.evalue, + message.content.traceback, + ) } else if (message.msg_type == 'stream') { if (message.content.name == 'stdout') { - result.stdout.push(message.content.text) + result.logs.stdout.push(message.content.text) if (cell?.onStdout) { cell.onStdout( new ProcessMessage( @@ -201,7 +234,7 @@ export class JupyterKernelWebSocket { ) } } else if (message.content.name == 'stderr') { - result.stderr.push(message.content.text) + result.logs.stderr.push(message.content.text) if (cell?.onStderr) { cell.onStderr( new ProcessMessage( @@ -213,29 +246,29 @@ export class JupyterKernelWebSocket { } } } else if (message.msg_type == 'display_data') { - result.displayData.push(message.content.data) + result.data.push(new Data(message.content.data, false)) } else if (message.msg_type == 'execute_result') { - result.result = message.content.data + result.data.push(new Data(message.content.data, true)) } else if (message.msg_type == 'status') { if (message.content.execution_state == 'idle') { if (cell.inputAccepted) { this.idAwaiter[parentMsgId](result) } } else if (message.content.execution_state == 'error') { - result.error = { - name: message.content.ename, - value: message.content.evalue, - tracebackRaw: message.content.traceback - } + result.error = new ExecutionError( + message.content.ename, + message.content.evalue, + message.content.traceback, + ) this.idAwaiter[parentMsgId](result) } } else if (message.msg_type == 'execute_reply') { if (message.content.status == 'error') { - result.error = { - name: message.content.ename, - value: message.content.evalue, - tracebackRaw: message.content.traceback - } + result.error = new ExecutionError( + message.content.ename, + message.content.evalue, + message.content.traceback, + ) } else if (message.content.status == 'ok') { return } @@ -248,7 +281,6 @@ export class JupyterKernelWebSocket { } // communication - /** * Sends code to be executed by Jupyter kernel. * @param code Code to be executed. diff --git a/js/tests/bash.test.ts b/js/tests/bash.test.ts index 4ccdf86..5a2f07c 100644 --- a/js/tests/bash.test.ts +++ b/js/tests/bash.test.ts @@ -5,9 +5,9 @@ import { expect, test } from 'vitest' test('bash', async () => { const sandbox = await CodeInterpreter.create() - const output = await sandbox.notebook.execCell('!pwd') + const result = await sandbox.notebook.execCell('!pwd') - expect(output.stdout.join().trim()).toEqual('/home/user') + expect(result.logs.stdout.join().trim()).toEqual('/home/user') await sandbox.close() }) diff --git a/js/tests/basic.test.ts b/js/tests/basic.test.ts index 886c2f3..90a0110 100644 --- a/js/tests/basic.test.ts +++ b/js/tests/basic.test.ts @@ -5,9 +5,9 @@ import { expect, test } from 'vitest' test('basic', async () => { const sandbox = await CodeInterpreter.create() - const output = await sandbox.notebook.execCell('x =1; x') + const result = await sandbox.notebook.execCell('x =1; x') - expect(output.text).toEqual('1') + expect(result.text).toEqual('1') await sandbox.close() }) diff --git a/js/tests/displayData.test.ts b/js/tests/displayData.test.ts index 0a82f05..9cbf987 100644 --- a/js/tests/displayData.test.ts +++ b/js/tests/displayData.test.ts @@ -17,10 +17,9 @@ test('display data', async () => { plt.show() `) - // there's your image - const image = result.displayData[0] - expect(image['image/png']).toBeDefined() - expect(image['text/plain']).toBeDefined() + const image = result.data[0] + expect(image.png).toBeDefined() + expect(image.text).toBeDefined() await sandbox.close() }) diff --git a/js/tests/statefulness.test.ts b/js/tests/statefulness.test.ts index 4f1ce52..e763529 100644 --- a/js/tests/statefulness.test.ts +++ b/js/tests/statefulness.test.ts @@ -7,9 +7,9 @@ test('statefulness', async () => { await sandbox.notebook.execCell('x = 1') - const output = await sandbox.notebook.execCell('x += 1; x') + const result = await sandbox.notebook.execCell('x += 1; x') - expect(output.text).toEqual('2') + expect(result.text).toEqual('2') await sandbox.close() }) diff --git a/python/e2b_code_interpreter/models.py b/python/e2b_code_interpreter/models.py index f64aa62..ac0c828 100644 --- a/python/e2b_code_interpreter/models.py +++ b/python/e2b_code_interpreter/models.py @@ -1,3 +1,4 @@ +import copy from typing import List, Optional, Iterable, Dict from pydantic import BaseModel @@ -65,18 +66,19 @@ class Data: def __init__(self, is_main_result: bool, data: [MIMEType, str]): self.is_main_result = is_main_result - self.raw = data - - self.text = data["text/plain"] - self.html = data.get("text/html", None) - self.markdown = data.get("text/markdown", None) - self.svg = data.get("image/svg+xml", None) - self.png = data.get("image/png", None) - self.jpeg = data.get("image/jpeg", None) - self.pdf = data.get("application/pdf", None) - self.latex = data.get("text/latex", None) - self.json = data.get("application/json", None) - self.javascript = data.get("application/javascript", None) + self.raw = copy.deepcopy(data) + + self.text = data.pop("text/plain") + self.html = data.pop("text/html", None) + self.markdown = data.pop("text/markdown", None) + self.svg = data.pop("image/svg+xml", None) + self.png = data.pop("image/png", None) + self.jpeg = data.pop("image/jpeg", None) + self.pdf = data.pop("application/pdf", None) + self.latex = data.pop("text/latex", None) + self.json = data.pop("application/json", None) + self.javascript = data.pop("application/javascript", None) + self.extra = data def keys(self) -> Iterable[str]: """ diff --git a/python/tests/test_display_data.py b/python/tests/test_display_data.py index b116ec8..dd3e1d3 100644 --- a/python/tests/test_display_data.py +++ b/python/tests/test_display_data.py @@ -19,5 +19,5 @@ def test_display_data(): # there's your image data = result.data[0] - assert "image/png" in data - assert "text/plain" in data + assert data.png + assert data.text