Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow custom name for functions module #2241

Merged
merged 4 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions autogen/coding/local_commandline_code_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@

class LocalCommandLineCodeExecutor(CodeExecutor):
SUPPORTED_LANGUAGES: ClassVar[List[str]] = ["bash", "shell", "sh", "pwsh", "powershell", "ps1", "python"]
FUNCTIONS_MODULE: ClassVar[str] = "functions"
FUNCTIONS_FILENAME: ClassVar[str] = "functions.py"
FUNCTION_PROMPT_TEMPLATE: ClassVar[
str
] = """You have access to the following user defined functions. They can be accessed from the module called `$module_name` by their function names.
Expand All @@ -45,6 +43,7 @@ def __init__(
timeout: int = 60,
work_dir: Union[Path, str] = Path("."),
functions: List[Union[FunctionWithRequirements[Any, A], Callable[..., Any], FunctionWithRequirementsStr]] = [],
functions_module: str = "functions",
):
"""(Experimental) A code executor class that executes code through a local command line
environment.
Expand Down Expand Up @@ -76,6 +75,11 @@ def __init__(
if isinstance(work_dir, str):
work_dir = Path(work_dir)

if not functions_module.isidentifier():
raise ValueError("Module name must be a valid Python identifier")

self._functions_module = functions_module

work_dir.mkdir(exist_ok=True)

self._timeout = timeout
Expand Down Expand Up @@ -104,10 +108,15 @@ def format_functions_for_prompt(self, prompt_template: str = FUNCTION_PROMPT_TEM

template = Template(prompt_template)
return template.substitute(
module_name=self.FUNCTIONS_MODULE,
module_name=self._functions_module,
functions="\n\n".join([to_stub(func) for func in self._functions]),
)

@property
def functions_module(self) -> str:
"""(Experimental) The module name for the functions."""
return self._functions_module

@property
def functions(
self,
Expand Down Expand Up @@ -154,7 +163,7 @@ def sanitize_command(lang: str, code: str) -> None:

def _setup_functions(self) -> None:
func_file_content = _build_python_functions_file(self._functions)
func_file = self._work_dir / self.FUNCTIONS_FILENAME
func_file = self._work_dir / f"{self._functions_module}.py"
func_file.write_text(func_file_content)

# Collect requirements
Expand Down
17 changes: 8 additions & 9 deletions test/coding/test_user_defined_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def function_missing_reqs() -> "pandas.DataFrame":
def test_can_load_function_with_reqs(cls) -> None:
with tempfile.TemporaryDirectory() as temp_dir:
executor = cls(work_dir=temp_dir, functions=[load_data])
code = f"""from {cls.FUNCTIONS_MODULE} import load_data
code = f"""from {executor.functions_module} import load_data
import pandas

# Get first row's name
Expand All @@ -74,7 +74,7 @@ def test_can_load_function_with_reqs(cls) -> None:
def test_can_load_function(cls) -> None:
with tempfile.TemporaryDirectory() as temp_dir:
executor = cls(work_dir=temp_dir, functions=[add_two_numbers])
code = f"""from {cls.FUNCTIONS_MODULE} import add_two_numbers
code = f"""from {executor.functions_module} import add_two_numbers
print(add_two_numbers(1, 2))"""

result = executor.execute_code_blocks(
Expand All @@ -93,7 +93,7 @@ def test_can_load_function(cls) -> None:
# def test_fails_for_missing_reqs(cls) -> None:
# with tempfile.TemporaryDirectory() as temp_dir:
# executor = cls(work_dir=temp_dir, functions=[function_missing_reqs])
# code = f"""from {cls.FUNCTIONS_MODULE} import function_missing_reqs
# code = f"""from {executor.functions_module} import function_missing_reqs
# function_missing_reqs()"""

# with pytest.raises(ValueError):
Expand All @@ -109,7 +109,7 @@ def test_can_load_function(cls) -> None:
def test_fails_for_function_incorrect_import(cls) -> None:
with tempfile.TemporaryDirectory() as temp_dir:
executor = cls(work_dir=temp_dir, functions=[function_incorrect_import])
code = f"""from {cls.FUNCTIONS_MODULE} import function_incorrect_import
code = f"""from {executor.functions_module} import function_incorrect_import
function_incorrect_import()"""

with pytest.raises(ValueError):
Expand All @@ -125,7 +125,7 @@ def test_fails_for_function_incorrect_import(cls) -> None:
def test_fails_for_function_incorrect_dep(cls) -> None:
with tempfile.TemporaryDirectory() as temp_dir:
executor = cls(work_dir=temp_dir, functions=[function_incorrect_dep])
code = f"""from {cls.FUNCTIONS_MODULE} import function_incorrect_dep
code = f"""from {executor.functions_module} import function_incorrect_dep
function_incorrect_dep()"""

with pytest.raises(ValueError):
Expand Down Expand Up @@ -183,7 +183,7 @@ def add_two_numbers(a: int, b: int) -> int:
)

executor = cls(work_dir=temp_dir, functions=[func])
code = f"""from {cls.FUNCTIONS_MODULE} import add_two_numbers
code = f"""from {executor.functions_module} import add_two_numbers
print(add_two_numbers(1, 2))"""

result = executor.execute_code_blocks(
Expand Down Expand Up @@ -219,10 +219,9 @@ def add_two_numbers(a: int, b: int) -> int:
'''
)

code = f"""from {cls.FUNCTIONS_MODULE} import add_two_numbers
print(add_two_numbers(object(), False))"""

executor = cls(work_dir=temp_dir, functions=[func])
code = f"""from {executor.functions_module} import add_two_numbers
print(add_two_numbers(object(), False))"""

result = executor.execute_code_blocks(
code_blocks=[
Expand Down
Loading