Skip to content

Commit

Permalink
- refactor: add pyproject to enable dependency management (#85)
Browse files Browse the repository at this point in the history
* - feature: refactor python to enable dependency management

* - ci: bump to 095
  • Loading branch information
agallardol authored Dec 13, 2024
1 parent 92de380 commit 6dde7e1
Show file tree
Hide file tree
Showing 6 changed files with 475 additions and 112 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions libs/shinkai-tools-runner/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ anyhow = { version = "1.0.93" }
chrono = { version = "0.4.38" }
tar = "0.4"
flate2 = "1.0"
toml_edit = "0.22.22"

[dev-dependencies]
rstest = "0.23.0"
Expand Down
302 changes: 193 additions & 109 deletions libs/shinkai-tools-runner/src/tools/python_runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use tokio::{
io::{AsyncBufReadExt, BufReader},
sync::Mutex,
};
use toml_edit::DocumentMut;

use crate::tools::{
execution_error::ExecutionError, path_buf_ext::PathBufExt, run_result::RunResult,
Expand All @@ -27,6 +28,7 @@ pub struct PythonRunner {

impl PythonRunner {
pub const MAX_EXECUTION_TIME_MS_INTERNAL_OPS: u64 = 1000;
pub const PYPROJECT_TOML_FILE_NAME: &'static str = "pyproject.toml";

pub fn new(
code_files: CodeFiles,
Expand Down Expand Up @@ -83,6 +85,90 @@ impl PythonRunner {
Ok(python_binary_path)
}

pub fn extend_with_pyproject_toml(code_files: CodeFiles) -> anyhow::Result<CodeFiles> {
let mut code_files = code_files.clone();
let code_entrypoint = match code_files.files.get(&code_files.entrypoint.clone()) {
Some(content) => content,
None => return Err(anyhow::anyhow!("Code entrypoint file is empty")),
};

let pyproject_toml_string = r#"
[project]
name = "shinkai-tool"
version = "0.0.1"
dependencies = [
"jsonpickle~=4.0.0",
]
requires-python = ">=3.13"
"#;
let mut pyproject_toml = pyproject_toml_string
.parse::<DocumentMut>()
.map_err(anyhow::Error::new)?;

// Extract pyproject.toml script section between #///script and #///
let mut script_lines = Vec::new();
let mut in_script = false;
let mut line_start = None;
let mut line_end = None;
for (line_number, code_line) in code_entrypoint.lines().enumerate() {
if code_line.trim() == "# /// script" {
line_start = Some(line_number);
in_script = true;
continue;
} else if code_line.trim() == "# ///" {
line_end = Some(line_number);
break;
}
if in_script {
let line = code_line
.trim()
.to_string()
.replace("#", "")
.trim()
.to_string();
script_lines.push(line);
}
}

// Remove lines between line_start and line_end
if let (Some(line_start), Some(line_end)) = (line_start, line_end) {
let mut lines: Vec<&str> = code_entrypoint.lines().collect();
lines.drain(line_start..=line_end);
let updated_code_entrypoint = lines.join("\n");
log::info!("Updated code entrypoint: {}", updated_code_entrypoint);
code_files
.files
.insert(code_files.entrypoint.clone(), updated_code_entrypoint);
}

let pyproject_toml_from_code_endpoint = script_lines
.join("\n")
.parse::<DocumentMut>()
.map_err(anyhow::Error::new)?;

// If dependencies exist in code endpoint toml, merge them into main toml
if let Some(deps) = pyproject_toml_from_code_endpoint.get("dependencies") {
if let Some(project) = pyproject_toml.get_mut("project") {
if let Some(existing_deps) = project.get_mut("dependencies") {
// Merge the dependencies arrays
if let (Some(existing_arr), Some(new_arr)) =
(existing_deps.as_array_mut(), deps.as_array())
{
existing_arr.extend(new_arr.clone());
}
}
}
}
log::info!(
"autogenerated pyproject_toml: {}",
pyproject_toml.to_string()
);
code_files
.files
.insert(Self::PYPROJECT_TOML_FILE_NAME.to_string(), pyproject_toml.to_string());
Ok(code_files)
}

pub async fn check(&self) -> anyhow::Result<Vec<String>> {
let execution_storage =
ExecutionStorage::new(self.code.clone(), self.options.context.clone());
Expand Down Expand Up @@ -140,46 +226,102 @@ impl PythonRunner {
log::info!("configurations: {}", self.configurations.to_string());
log::info!("parameters: {}", parameters.to_string());

let mut code = self.code.clone();
let entrypoint_code = code.files.get(&self.code.entrypoint.clone());
if let Some(entrypoint_code) = entrypoint_code {
let adapted_entrypoint_code = format!(
r#"
let entrypoint_code = self.code.files.get(&self.code.entrypoint.clone());
if entrypoint_code.is_none() {
return Err(ExecutionError::new(
format!("no entrypoint found {}", self.code.entrypoint),
None,
));
}

let mut code = Self::extend_with_pyproject_toml(self.code.clone()).map_err(|e| {
ExecutionError::new(format!("failed to create pyproject.toml: {}", e), None)
})?;

let entrypoint_code = code.files.get(&self.code.entrypoint.clone()).unwrap();

log::info!(
"Extended pyproject.toml {:?}",
code.files.get(Self::PYPROJECT_TOML_FILE_NAME).unwrap()
);
let mut adapted_configurations = self.configurations.clone();
if let Some(object) = adapted_configurations.as_object_mut() {
object.insert(
"py/object".to_string(),
Value::String("__main__.CONFIG".to_string()),
);
}

let mut adapted_parameters = parameters.clone();
if let Some(object) = adapted_parameters.as_object_mut() {
object.insert(
"py/object".to_string(),
Value::String("__main__.INPUTS".to_string()),
);
}

let adapted_entrypoint_code = format!(
r#"
{}
import json
import asyncio
configurations = json.loads('{}')
parameters = json.loads('{}')
import jsonpickle
import json
class TrickyJsonEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, (list, tuple)):
return [self.default(item) for item in obj]
elif isinstance(obj, dict):
return {{key: self.default(value) for key, value in obj.items()}}
elif isinstance(obj, set):
return list(obj)
elif isinstance(obj, bytes):
return obj.decode('utf-8') # Convert bytes to string
elif isinstance(obj, object) and hasattr(obj, '__dict__'):
return {{key: self.default(value) for key, value in obj.__dict__.items() if not key.startswith('__')}}
elif isinstance(obj, str):
return obj # Return string as is
elif obj is None:
return None
elif hasattr(obj,'__iter__'):
return list(obj) # Convert list_iterator to a list
else:
print("warning: trying to serialize an unknown type", type(obj), obj)
return str(obj) # Fallback for unknown types
def tricky_json_dump(obj):
jsonpickle_encoded = jsonpickle.encode(obj, unpicklable=False, make_refs=False, indent=4)
jsonpickle_decoded = jsonpickle.decode(jsonpickle_encoded, reset=True)
custom_json_dump = json.dumps(jsonpickle_decoded, indent=4, cls=TrickyJsonEncoder)
return custom_json_dump
configurations = jsonpickle.decode('{}')
parameters = jsonpickle.decode('{}')
result = run(configurations, parameters)
if asyncio.iscoroutine(result):
result = asyncio.run(result)
if hasattr(result, '__dict__'):
print("Using __dict__ to serialize object")
print(result)
print(result.__dict__)
serialiable_result = result.__dict__
else:
serialiable_result = result
serialized_result = tricky_json_dump(result)
print("<shinkai-code-result>")
print(json.dumps(serialiable_result))
print(serialized_result)
print("</shinkai-code-result>")
"#,
&entrypoint_code,
serde_json::to_string(&self.configurations)
.unwrap()
.replace("\\", "\\\\")
.replace("'", "\\'")
.replace("\"", "\\\""),
serde_json::to_string(&parameters)
.unwrap()
.replace("\\", "\\\\")
.replace("'", "\\'")
.replace("\"", "\\\"")
);
code.files
.insert(self.code.entrypoint.clone(), adapted_entrypoint_code);
}
&entrypoint_code,
serde_json::to_string(&adapted_configurations)
.unwrap()
.replace("\\", "\\\\")
.replace("'", "\\'")
.replace("\"", "\\\""),
serde_json::to_string(&adapted_parameters)
.unwrap()
.replace("\\", "\\\\")
.replace("'", "\\'")
.replace("\"", "\\\"")
);
code.files
.insert(self.code.entrypoint.clone(), adapted_entrypoint_code);

let result = match self.options.force_runner_type {
Some(RunnerType::Host) => self.run_in_host(code, envs, max_execution_timeout).await,
Expand Down Expand Up @@ -321,20 +463,24 @@ print("</shinkai-code-result>")

let code_entrypoint =
execution_storage.relative_to_root(execution_storage.code_entrypoint_file_path.clone());

let mut command = tokio::process::Command::new("docker");
let mut args = vec!["run", "--rm"];
args.extend(mount_params.iter().map(|s| s.as_str()));
args.extend(container_envs.iter().map(|s| s.as_str()));

let code_folder = Path::new(code_entrypoint.as_str())
.parent()
.unwrap()
.to_string_lossy()
let pyproject_toml_path = execution_storage
.relative_to_root(
execution_storage
.code_folder_path
.clone()
.join(Self::PYPROJECT_TOML_FILE_NAME),
)
.to_string();

let python_start_script = format!(
"source /app/cache/python-venv/bin/activate && python -m pipreqs.pipreqs --encoding utf-8 --force {} && uv pip install -r {}/requirements.txt && python {}",
code_folder.clone().as_str(),
code_folder.clone().as_str(),
"uv run --project {} {}",
pyproject_toml_path,
code_entrypoint.clone().as_str(),
);

Expand All @@ -346,7 +492,7 @@ print("</shinkai-code-result>")
"-c",
python_start_script.as_str(),
]);
// args.extend([code_entrypoint.as_str()]);

let command = command
.args(args)
.stdout(std::process::Stdio::piped())
Expand Down Expand Up @@ -438,88 +584,26 @@ print("</shinkai-code-result>")
let execution_storage = ExecutionStorage::new(code_files, self.options.context.clone());
execution_storage.init_for_python(None)?;

let python_binary_path: String = self
.ensure_python_venv(execution_storage.python_cache_folder_path())
.await?;

log::info!(
"using python from host at path: {:?}",
python_binary_path.clone()
);

let uv_binary_path = path::absolute(self.options.uv_binary_path.clone())
.unwrap()
.to_str()
.unwrap()
.to_string();

let mut ensure_pip_command = tokio::process::Command::new(&uv_binary_path);
ensure_pip_command
.args(["pip", "install", "pipreqs"])
.env(
"VIRTUAL_ENV",
execution_storage
.python_cache_folder_path()
.to_str()
.unwrap(),
)
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.kill_on_drop(true);
log::info!("using uv from host at path: {:?}", uv_binary_path.clone());

let pip_output = ensure_pip_command.spawn()?.wait_with_output().await?;
if !pip_output.status.success() {
return Err(anyhow::Error::new(std::io::Error::new(
std::io::ErrorKind::Other,
String::from_utf8(pip_output.stderr)?,
)));
}
let mut command = tokio::process::Command::new(uv_binary_path);

let mut pipreqs_command = tokio::process::Command::new(&python_binary_path);
pipreqs_command
let command = command
.arg("run")
.args([
"-m",
"pipreqs.pipreqs",
"--encoding",
"utf-8",
"--force",
execution_storage.code_folder_path.to_str().unwrap(),
])
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.kill_on_drop(true);
let pipreqs_output = pipreqs_command.spawn()?.wait_with_output().await?;
if !pipreqs_output.status.success() {
return Err(anyhow::Error::new(std::io::Error::new(
std::io::ErrorKind::Other,
String::from_utf8(pipreqs_output.stderr)?,
)));
}

let mut pip_install_command = tokio::process::Command::new(&uv_binary_path);
pip_install_command
.args(["pip", "install", "-r", "requirements.txt"])
.env(
"VIRTUAL_ENV",
"--project",
execution_storage
.python_cache_folder_path()
.code_folder_path
.join(Self::PYPROJECT_TOML_FILE_NAME)
.to_str()
.unwrap(),
)
.current_dir(execution_storage.code_folder_path.clone())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped())
.kill_on_drop(true);
let pip_install_output = pip_install_command.spawn()?.wait_with_output().await?;
if !pip_install_output.status.success() {
return Err(anyhow::Error::new(std::io::Error::new(
std::io::ErrorKind::Other,
String::from_utf8(pip_install_output.stderr)?,
)));
}

let mut command = tokio::process::Command::new(python_binary_path);
let command = command
])
.arg(execution_storage.code_entrypoint_file_path.clone())
.current_dir(execution_storage.root_folder_path.clone())
.stdout(std::process::Stdio::piped())
Expand Down
Loading

0 comments on commit 6dde7e1

Please sign in to comment.