From 6dde7e12df31f2636a498f31ae4ce075ce5e2083 Mon Sep 17 00:00:00 2001
From: Alfredo Gallardo
Date: Fri, 13 Dec 2024 15:38:49 -0300
Subject: [PATCH] - refactor: add pyproject to enable dependency management (#85)

* - feature: refactor python to enable dependency management

* - ci: bump to 095
---
 Cargo.lock                                        |   1 +
 libs/shinkai-tools-runner/Cargo.toml              |   1 +
 .../src/tools/python_runner.rs                    | 302 +++++++++++-------
 .../src/tools/python_runner.test.rs               | 277 ++++++++++++++++
 package-lock.json                                 |   4 +-
 package.json                                      |   2 +-
 6 files changed, 475 insertions(+), 112 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 9d0a675..2cd25bd 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1902,6 +1902,7 @@ dependencies = [
  "tempfile",
  "tokio",
  "tokio-util",
+ "toml_edit",
  "zip",
 ]
 
diff --git a/libs/shinkai-tools-runner/Cargo.toml b/libs/shinkai-tools-runner/Cargo.toml
index f3667f5..cb1c3c7 100644
--- a/libs/shinkai-tools-runner/Cargo.toml
+++ b/libs/shinkai-tools-runner/Cargo.toml
@@ -34,6 +34,7 @@ anyhow = { version = "1.0.93" }
 chrono = { version = "0.4.38" }
 tar = "0.4"
 flate2 = "1.0"
+toml_edit = "0.22.22"
 
 [dev-dependencies]
 rstest = "0.23.0"
diff --git a/libs/shinkai-tools-runner/src/tools/python_runner.rs b/libs/shinkai-tools-runner/src/tools/python_runner.rs
index fbadf9c..b5762fe 100644
--- a/libs/shinkai-tools-runner/src/tools/python_runner.rs
+++ b/libs/shinkai-tools-runner/src/tools/python_runner.rs
@@ -9,6 +9,7 @@ use tokio::{
     io::{AsyncBufReadExt, BufReader},
     sync::Mutex,
 };
+use toml_edit::DocumentMut;
 
 use crate::tools::{
     execution_error::ExecutionError, path_buf_ext::PathBufExt, run_result::RunResult,
@@ -27,6 +28,7 @@ pub struct PythonRunner {
 
 impl PythonRunner {
     pub const MAX_EXECUTION_TIME_MS_INTERNAL_OPS: u64 = 1000;
+    pub const PYPROJECT_TOML_FILE_NAME: &'static str = "pyproject.toml";
 
     pub fn new(
         code_files: CodeFiles,
@@ -83,6 +85,90 @@ impl PythonRunner {
         Ok(python_binary_path)
     }
 
+    pub fn extend_with_pyproject_toml(code_files: CodeFiles) -> anyhow::Result<CodeFiles> {
+        let mut code_files = code_files.clone();
+        let code_entrypoint = match code_files.files.get(&code_files.entrypoint.clone()) {
+            Some(content) => content,
+            None => return Err(anyhow::anyhow!("Code entrypoint file not found")),
+        };
+
+        let pyproject_toml_string = r#"
+[project]
+name = "shinkai-tool"
+version = "0.0.1"
+dependencies = [
+    "jsonpickle~=4.0.0",
+]
+requires-python = ">=3.13"
+    "#;
+        let mut pyproject_toml = pyproject_toml_string
+            .parse::<DocumentMut>()
+            .map_err(anyhow::Error::new)?;
+
+        // Extract the pyproject.toml script section between `# /// script` and `# ///`
+        let mut script_lines = Vec::new();
+        let mut in_script = false;
+        let mut line_start = None;
+        let mut line_end = None;
+        for (line_number, code_line) in code_entrypoint.lines().enumerate() {
+            if code_line.trim() == "# /// script" {
+                line_start = Some(line_number);
+                in_script = true;
+                continue;
+            } else if code_line.trim() == "# ///" {
+                line_end = Some(line_number);
+                break;
+            }
+            if in_script {
+                let line = code_line
+                    .trim()
+                    .to_string()
+                    .replace("#", "")
+                    .trim()
+                    .to_string();
+                script_lines.push(line);
+            }
+        }
+
+        // Remove the script section (lines line_start..=line_end) from the entrypoint
+        if let (Some(line_start), Some(line_end)) = (line_start, line_end) {
+            let mut lines: Vec<&str> = code_entrypoint.lines().collect();
+            lines.drain(line_start..=line_end);
+            let updated_code_entrypoint = lines.join("\n");
+            log::info!("Updated code entrypoint: {}", updated_code_entrypoint);
+            code_files
+                .files
+                .insert(code_files.entrypoint.clone(), updated_code_entrypoint);
+        }
+
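+        // For reference, the script section parsed above follows the PEP 723
+        // inline-metadata convention used throughout the tests, e.g.:
+        //
+        //     # /// script
+        //     # dependencies = [
+        //     #   "requests"
+        //     # ]
+        //     # ///
+        //
+        // Once the leading '#' markers are stripped, the remaining lines form a
+        // small TOML document whose `dependencies` array is merged below.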
+        let pyproject_toml_from_code_endpoint = script_lines
+            .join("\n")
+            .parse::<DocumentMut>()
+            .map_err(anyhow::Error::new)?;
+
+        // If dependencies exist in the entrypoint's script section, merge them into the main toml
+        if let Some(deps) = pyproject_toml_from_code_endpoint.get("dependencies") {
+            if let Some(project) = pyproject_toml.get_mut("project") {
+                if let Some(existing_deps) = project.get_mut("dependencies") {
+                    // Merge the dependencies arrays
+                    if let (Some(existing_arr), Some(new_arr)) =
+                        (existing_deps.as_array_mut(), deps.as_array())
+                    {
+                        existing_arr.extend(new_arr.clone());
+                    }
+                }
+            }
+        }
+        log::info!(
+            "autogenerated pyproject_toml: {}",
+            pyproject_toml.to_string()
+        );
+        code_files.files.insert(
+            Self::PYPROJECT_TOML_FILE_NAME.to_string(),
+            pyproject_toml.to_string(),
+        );
+        Ok(code_files)
+    }
+
     pub async fn check(&self) -> anyhow::Result<Vec<String>> {
         let execution_storage =
             ExecutionStorage::new(self.code.clone(), self.options.context.clone());
@@ -140,46 +226,102 @@ impl PythonRunner {
         log::info!("configurations: {}", self.configurations.to_string());
         log::info!("parameters: {}", parameters.to_string());
 
-        let mut code = self.code.clone();
-        let entrypoint_code = code.files.get(&self.code.entrypoint.clone());
-        if let Some(entrypoint_code) = entrypoint_code {
-            let adapted_entrypoint_code = format!(
-                r#"
+        let entrypoint_code = self.code.files.get(&self.code.entrypoint.clone());
+        if entrypoint_code.is_none() {
+            return Err(ExecutionError::new(
+                format!("no entrypoint found {}", self.code.entrypoint),
+                None,
+            ));
+        }
+
+        let mut code = Self::extend_with_pyproject_toml(self.code.clone()).map_err(|e| {
+            ExecutionError::new(format!("failed to create pyproject.toml: {}", e), None)
+        })?;
+
+        let entrypoint_code = code.files.get(&self.code.entrypoint.clone()).unwrap();
+
+        log::info!(
+            "Extended pyproject.toml {:?}",
+            code.files.get(Self::PYPROJECT_TOML_FILE_NAME).unwrap()
+        );
+        let mut adapted_configurations = self.configurations.clone();
+        if let Some(object) = adapted_configurations.as_object_mut() {
+            object.insert(
+                "py/object".to_string(),
+                Value::String("__main__.CONFIG".to_string()),
+            );
+        }
+
+        let mut adapted_parameters = parameters.clone();
+        if let Some(object) = adapted_parameters.as_object_mut() {
+            object.insert(
+                "py/object".to_string(),
+                Value::String("__main__.INPUTS".to_string()),
+            );
+        }
+
+        let adapted_entrypoint_code = format!(
+            r#"
 {}
-import json
 import asyncio
-configurations = json.loads('{}')
-parameters = json.loads('{}')
+import jsonpickle
+import json
+
+class TrickyJsonEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, (list, tuple)):
+            return [self.default(item) for item in obj]
+        elif isinstance(obj, dict):
+            return {{key: self.default(value) for key, value in obj.items()}}
+        elif isinstance(obj, set):
+            return list(obj)
+        elif isinstance(obj, bytes):
+            return obj.decode('utf-8') # Convert bytes to string
+        elif isinstance(obj, object) and hasattr(obj, '__dict__'):
+            return {{key: self.default(value) for key, value in obj.__dict__.items() if not key.startswith('__')}}
+        elif isinstance(obj, str):
+            return obj # Return string as is
+        elif obj is None:
+            return None
+        elif hasattr(obj,'__iter__'):
+            return list(obj) # Convert list_iterator to a list
+        else:
+            print("warning: trying to serialize an unknown type", type(obj), obj)
+            return str(obj) # Fallback for unknown types
+
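+# tricky_json_dump performs a two-pass serialization: jsonpickle first flattens
+# the object graph into plain JSON data (unpicklable=False drops class markers,
+# make_refs=False inlines repeated references), then the decoded value is
+# re-encoded with TrickyJsonEncoder to handle anything jsonpickle leaves
+# behind (iterators, bytes, sets, ...).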
+def tricky_json_dump(obj):
+    jsonpickle_encoded = jsonpickle.encode(obj, unpicklable=False, make_refs=False, indent=4)
+    jsonpickle_decoded = jsonpickle.decode(jsonpickle_encoded, reset=True)
+    custom_json_dump = json.dumps(jsonpickle_decoded, indent=4, cls=TrickyJsonEncoder)
+    return custom_json_dump
+
+configurations = jsonpickle.decode('{}')
+parameters = jsonpickle.decode('{}')
 result = run(configurations, parameters)
 if asyncio.iscoroutine(result):
     result = asyncio.run(result)
-if hasattr(result, '__dict__'):
-    print("Using __dict__ to serialize object")
-    print(result)
-    print(result.__dict__)
-    serialiable_result = result.__dict__
-else:
-    serialiable_result = result
+
+serialized_result = tricky_json_dump(result)
+
 print("")
-print(json.dumps(serialiable_result))
+print(serialized_result)
 print("")
 "#,
-            &entrypoint_code,
-            serde_json::to_string(&self.configurations)
-                .unwrap()
-                .replace("\\", "\\\\")
-                .replace("'", "\\'")
-                .replace("\"", "\\\""),
-            serde_json::to_string(&parameters)
-                .unwrap()
-                .replace("\\", "\\\\")
-                .replace("'", "\\'")
-                .replace("\"", "\\\"")
-            );
-            code.files
-                .insert(self.code.entrypoint.clone(), adapted_entrypoint_code);
-        }
+            &entrypoint_code,
+            serde_json::to_string(&adapted_configurations)
+                .unwrap()
+                .replace("\\", "\\\\")
+                .replace("'", "\\'")
+                .replace("\"", "\\\""),
+            serde_json::to_string(&adapted_parameters)
+                .unwrap()
+                .replace("\\", "\\\\")
+                .replace("'", "\\'")
+                .replace("\"", "\\\"")
+        );
+        code.files
+            .insert(self.code.entrypoint.clone(), adapted_entrypoint_code);
 
         let result = match self.options.force_runner_type {
             Some(RunnerType::Host) => self.run_in_host(code, envs, max_execution_timeout).await,
@@ -321,20 +463,24 @@ print("")
         let code_entrypoint =
             execution_storage.relative_to_root(execution_storage.code_entrypoint_file_path.clone());
 
+        let mut command = tokio::process::Command::new("docker");
         let mut args = vec!["run", "--rm"];
         args.extend(mount_params.iter().map(|s| s.as_str()));
         args.extend(container_envs.iter().map(|s| s.as_str()));
 
-        let code_folder = Path::new(code_entrypoint.as_str())
-            .parent()
-            .unwrap()
-            .to_string_lossy()
+        let pyproject_toml_path = execution_storage
+            .relative_to_root(
+                execution_storage
+                    .code_folder_path
+                    .clone()
+                    .join(Self::PYPROJECT_TOML_FILE_NAME),
+            )
             .to_string();
+
         let python_start_script = format!(
-            "source /app/cache/python-venv/bin/activate && python -m pipreqs.pipreqs --encoding utf-8 --force {} && uv pip install -r {}/requirements.txt && python {}",
-            code_folder.clone().as_str(),
-            code_folder.clone().as_str(),
+            "uv run --project {} {}",
+            pyproject_toml_path,
             code_entrypoint.clone().as_str(),
         );
 
@@ -346,7 +492,7 @@ print("")
             "-c",
             python_start_script.as_str(),
         ]);
-        // args.extend([code_entrypoint.as_str()]);
+
         let command = command
             .args(args)
             .stdout(std::process::Stdio::piped())
@@ -438,88 +584,26 @@ print("")
         let execution_storage = ExecutionStorage::new(code_files, self.options.context.clone());
         execution_storage.init_for_python(None)?;
 
-        let python_binary_path: String = self
-            .ensure_python_venv(execution_storage.python_cache_folder_path())
-            .await?;
-
-        log::info!(
-            "using python from host at path: {:?}",
-            python_binary_path.clone()
-        );
-
         let uv_binary_path = path::absolute(self.options.uv_binary_path.clone())
             .unwrap()
             .to_str()
            .unwrap()
            .to_string();
 
-        let mut ensure_pip_command = tokio::process::Command::new(&uv_binary_path);
-        ensure_pip_command
-            .args(["pip", "install", "pipreqs"])
-            .env(
-                "VIRTUAL_ENV",
-                execution_storage
-                    .python_cache_folder_path()
-                    .to_str()
-                    .unwrap(),
-            )
-            .stdout(std::process::Stdio::piped())
-            .stderr(std::process::Stdio::piped())
-            .kill_on_drop(true);
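+        // The venv + pipreqs + `uv pip install` pipeline removed here is replaced
+        // by a single `uv run --project <pyproject>` invocation: uv resolves and
+        // installs the dependencies declared in the generated pyproject.toml
+        // before executing the entrypoint.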
log::info!("using uv from host at path: {:?}", uv_binary_path.clone()); - let pip_output = ensure_pip_command.spawn()?.wait_with_output().await?; - if !pip_output.status.success() { - return Err(anyhow::Error::new(std::io::Error::new( - std::io::ErrorKind::Other, - String::from_utf8(pip_output.stderr)?, - ))); - } + let mut command = tokio::process::Command::new(uv_binary_path); - let mut pipreqs_command = tokio::process::Command::new(&python_binary_path); - pipreqs_command + let command = command + .arg("run") .args([ - "-m", - "pipreqs.pipreqs", - "--encoding", - "utf-8", - "--force", - execution_storage.code_folder_path.to_str().unwrap(), - ]) - .stdout(std::process::Stdio::piped()) - .stderr(std::process::Stdio::piped()) - .kill_on_drop(true); - let pipreqs_output = pipreqs_command.spawn()?.wait_with_output().await?; - if !pipreqs_output.status.success() { - return Err(anyhow::Error::new(std::io::Error::new( - std::io::ErrorKind::Other, - String::from_utf8(pipreqs_output.stderr)?, - ))); - } - - let mut pip_install_command = tokio::process::Command::new(&uv_binary_path); - pip_install_command - .args(["pip", "install", "-r", "requirements.txt"]) - .env( - "VIRTUAL_ENV", + "--project", execution_storage - .python_cache_folder_path() + .code_folder_path + .join(Self::PYPROJECT_TOML_FILE_NAME) .to_str() .unwrap(), - ) - .current_dir(execution_storage.code_folder_path.clone()) - .stdout(std::process::Stdio::piped()) - .stderr(std::process::Stdio::piped()) - .kill_on_drop(true); - let pip_install_output = pip_install_command.spawn()?.wait_with_output().await?; - if !pip_install_output.status.success() { - return Err(anyhow::Error::new(std::io::Error::new( - std::io::ErrorKind::Other, - String::from_utf8(pip_install_output.stderr)?, - ))); - } - - let mut command = tokio::process::Command::new(python_binary_path); - let command = command + ]) .arg(execution_storage.code_entrypoint_file_path.clone()) .current_dir(execution_storage.root_folder_path.clone()) .stdout(std::process::Stdio::piped()) diff --git a/libs/shinkai-tools-runner/src/tools/python_runner.test.rs b/libs/shinkai-tools-runner/src/tools/python_runner.test.rs index cdc784c..8bd3e1b 100644 --- a/libs/shinkai-tools-runner/src/tools/python_runner.test.rs +++ b/libs/shinkai-tools-runner/src/tools/python_runner.test.rs @@ -310,6 +310,11 @@ async fn run_with_import_library(#[case] runner_type: RunnerType) { files: HashMap::from([( "main.py".to_string(), r#" +# /// script +# dependencies = [ +# "requests" +# ] +# /// import requests def run(configurations, parameters): response = requests.get('https://jsonplaceholder.typicode.com/todos/1') @@ -351,6 +356,11 @@ async fn shinkai_tool_run_concurrency(#[case] runner_type: RunnerType) { .is_test(true) .try_init(); let js_code1 = r#" +# /// script +# dependencies = [ +# "requests" +# ] +# /// import requests def run(configurations, params): response = requests.get('https://jsonplaceholder.typicode.com/todos/1') @@ -905,3 +915,270 @@ def run(configurations, parameters): assert_eq!(result.data.get("kind").unwrap(), "vegetable"); } + +#[rstest] +#[case::host(RunnerType::Host)] +#[case::docker(RunnerType::Docker)] +#[tokio::test] +async fn run_pip_lib_name_neq_to_import_name(#[case] runner_type: RunnerType) { + let _ = env_logger::builder() + .filter_level(log::LevelFilter::Info) + .is_test(true) + .try_init(); + + let code_files = CodeFiles { + files: HashMap::from([( + "main.py".to_string(), + r#" +# /// script +# dependencies = [ +# "googlesearch-python", +# ] +# /// +from googlesearch import 
+from googlesearch import search, SearchResult
+from typing import List
+from dataclasses import dataclass
+
+class CONFIG:
+    pass
+
+class INPUTS:
+    query: str
+    num_results: int = 10
+
+class OUTPUT:
+    results: List[SearchResult]
+    query: str
+
+async def run(c: CONFIG, p: INPUTS) -> OUTPUT:
+    query = p.query
+    if not query:
+        raise ValueError("No search query provided")
+
+    results = []
+    try:
+        results = search(query, num_results=p.num_results, advanced=True)
+    except Exception as e:
+        raise RuntimeError(f"Search failed: {str(e)}")
+
+    output = OUTPUT()
+    output.results = results
+    output.query = query
+    return output
+    "#
+            .to_string(),
+        )]),
+        entrypoint: "main.py".to_string(),
+    };
+
+    let python_runner = PythonRunner::new(
+        code_files,
+        Value::Null,
+        Some(PythonRunnerOptions {
+            force_runner_type: Some(runner_type),
+            ..Default::default()
+        }),
+    );
+
+    let result = python_runner
+        .run(
+            None,
+            serde_json::json!({ "query": "macbook pro m4", "num_results": 5 }),
+            None,
+        )
+        .await
+        .map_err(|e| {
+            log::error!("Failed to run python code: {}", e);
+            e
+        })
+        .unwrap();
+
+    let results_length = result
+        .data
+        .get("results")
+        .unwrap()
+        .as_array()
+        .unwrap()
+        .len();
+    assert!(
+        results_length > 0 && results_length <= 5,
+        "results should be an array with 0 to 5 elements"
+    );
+    assert!(!result
+        .data
+        .get("query")
+        .unwrap()
+        .as_str()
+        .unwrap()
+        .is_empty());
+}
+
+/*
+    This test exercises the hidden `tricky_json_dump` function that the engine
+    injects into the adapted entrypoint. It validates the engine's serialization
+    of complex objects without requiring extensive setup.
+*/
+#[rstest]
+#[case::host(RunnerType::Host)]
+#[case::docker(RunnerType::Docker)]
+#[tokio::test]
+async fn tricky_json_dump(#[case] runner_type: RunnerType) {
+    let _ = env_logger::builder()
+        .filter_level(log::LevelFilter::Info)
+        .is_test(true)
+        .try_init();
+
+    let code_files = CodeFiles {
+        files: HashMap::from([(
+            "main.py".to_string(),
+            r#"
+# /// script
+# dependencies = [
+#   "googlesearch-python",
+# ]
+# ///
+
+import asyncio
+from googlesearch import SearchResult
+from typing import List
+
+from datetime import datetime
+from typing import Dict, Optional
+
+class CONFIG:
+    pass
+
+class INPUTS:
+    pass
+
+class OUTPUT:
+    pass
+
+class AnyClass1:
+    query: str
+    num_results: int = 10
+    timestamp: datetime
+    def __init__(self, query: str):
+        self.query = query
+        self.timestamp = datetime.now()
+        self.num_results = 10
+
+    def any_method_1(self):
+        return "any_method_1"
+
+class AnyClass2:
+    results: List[SearchResult]
+    query: str
+    status_code: Optional[int] = None
+
+class VeryComplexClass:
+    results: List[SearchResult]
+    query: str
+    input: AnyClass1
+    number: int
+    output: AnyClass2
+    unique_ids: set
+    metadata: Dict[str, str]
+    additional_info: Optional[str]  # Extra information
+    status: str  # Tracks the status of the class
+    creation_date: datetime  # Tracks the creation date
+    slice: slice  # Represents a slice of data
+    complex_number: complex
+    byte_array: bytearray
+    memoryview: memoryview
+    frozen_set: frozenset
+    def cry(self):
+        return "cry"
+
+def create_very_complex_class():
+    search_result = SearchResult(title="potato", url="https://potato.com", description="potato is a vegetable")
+    search_result2 = SearchResult(title="tomato", url="https://tomato.com", description="tomato is a fruit")
+
+    search_results = list[SearchResult]()
+    search_results.append(search_result)
+    search_results.append(search_result2)
+
+    any_class_2 = AnyClass2()
+    any_class_2.results = iter(search_results)
+    any_class_2.query = "something about potatoes and tomatoes"
+    any_class_2.status_code = 200
+
+    very_complex_class = VeryComplexClass()
+    very_complex_class.results = any_class_2.results
+    very_complex_class.query = any_class_2.query
+    very_complex_class.number = 5
+    very_complex_class.input = AnyClass1(query="potato")
+    very_complex_class.output = any_class_2
+    very_complex_class.unique_ids = set()
+    very_complex_class.metadata = {"source": "google", "category": "vegetable"}
+
+    very_complex_class.unique_ids.add("potato_id_1")
+    very_complex_class.unique_ids.add("tomato_id_1")
+    very_complex_class.metadata["potato_name"] = "potato"
+    very_complex_class.metadata["tomato_name"] = "tomato"
+    very_complex_class.metadata["potato_search_results"] = str(search_results)
+    very_complex_class.metadata["additional_info"] = "This class contains search results for vegetables and fruits."
+    very_complex_class.status = "active"
+    very_complex_class.creation_date = datetime.now()
+    very_complex_class.slice = slice(6)
+    very_complex_class.complex_number = 4+3j
+    very_complex_class.byte_array = bytearray(b"Hello, World!")
+    very_complex_class.memoryview = memoryview(b"Hello, World!")
+    very_complex_class.frozen_set = frozenset([1, 2, 3])
+    return very_complex_class
+
+async def run(c: CONFIG, p: INPUTS) -> OUTPUT:
+    very_complex_class = create_very_complex_class()
+    json_dump = tricky_json_dump(very_complex_class)
+    print("json:", json_dump)
+
+    loaded_data = json.loads(json_dump)
+
+    # Assertions to validate the loaded data
+    assert isinstance(loaded_data, dict), "Loaded data should be a dictionary"
+    assert "results" in loaded_data, "'results' key should be present in the loaded data"
+    assert isinstance(loaded_data["results"], list), "'results' should be a list"
+    assert len(loaded_data["results"]) == 2, "There should be two search results"
+    assert all("title" in result for result in loaded_data["results"]), "Each result should have a 'title'"
+    assert all("url" in result for result in loaded_data["results"]), "Each result should have a 'url'"
+    assert all("description" in result for result in loaded_data["results"]), "Each result should have a 'description'"
+
+    # Additional assertions for metadata
+    assert "query" in loaded_data, "'query' key should be present in the loaded data"
+    assert loaded_data["query"] == "something about potatoes and tomatoes", "Query should match the expected value"
+    assert "status_code" in loaded_data.get("output", {}), "'status_code' key should be present in the loaded data"
+    assert loaded_data.get("output", {}).get("status_code") == 200, "Status code should be 200"
+    assert "metadata" in loaded_data, "'metadata' key should be present in the loaded data"
+    assert "source" in loaded_data.get("metadata", {}), "'source' should be present in metadata"
+    assert len(loaded_data.get("output", {}).get("results")) == 2, "output.results should be 2"
+
+    return loaded_data
+    "#
+            .to_string(),
+        )]),
+        entrypoint: "main.py".to_string(),
+    };
+
+    let python_runner = PythonRunner::new(
+        code_files,
+        Value::Null,
+        Some(PythonRunnerOptions {
+            force_runner_type: Some(runner_type),
+            ..Default::default()
+        }),
+    );
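+    // Note: the runner tags these parameters with "py/object": "__main__.INPUTS"
+    // before embedding them, so jsonpickle decodes them into the tool's INPUTS
+    // class rather than a plain dict (see run() in python_runner.rs).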
"macbook pro m4", "num_results": 5 }), + None, + ) + .await + .map_err(|e| { + log::error!("Failed to run python code: {}", e); + e + }); + + assert!(result.is_ok()); +} diff --git a/package-lock.json b/package-lock.json index 2efb2b2..d58f620 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@shinkai_protocol/source", - "version": "0.9.4", + "version": "0.9.5", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@shinkai_protocol/source", - "version": "0.9.4", + "version": "0.9.5", "license": "SEE LICENSE IN LICENSE", "dependencies": { "@coinbase/coinbase-sdk": "^0.0.16", diff --git a/package.json b/package.json index cfe2e4d..06821ab 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@shinkai_protocol/source", - "version": "0.9.4", + "version": "0.9.5", "description": "This repository serves as the ecosystem to execute Shinkai tools, provided by the Shinkai team or third-party developers, in a secure environment. It provides a sandboxed space for executing these tools, ensuring that they run safely and efficiently, while also allowing for seamless integration with Rust code.", "main": "index.js", "author": "",