Skip to content

Commit

Permalink
feat: add ability to train on custom file (#1161)
Browse files Browse the repository at this point in the history
* feat: add ability to train on custom file

* feat: add pkl file validation

* feat: fix tests

* feat: fix tests

* feat: fix tests
  • Loading branch information
pythonbyte authored Aug 9, 2024
1 parent 62f5b2f commit 51ee483
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 35 deletions.
13 changes: 10 additions & 3 deletions src/crewai/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,17 @@ def version(tools):
default=5,
help="Number of iterations to train the crew",
)
def train(n_iterations: int):
@click.option(
"-f",
"--filename",
type=str,
default="trained_agents_data.pkl",
help="Path to a custom file for training",
)
def train(n_iterations: int, filename: str):
"""Train the crew."""
click.echo(f"Training the crew for {n_iterations} iterations")
train_crew(n_iterations)
click.echo(f"Training the Crew for {n_iterations} iterations")
train_crew(n_iterations, filename)


@crewai.command()
Expand Down
2 changes: 1 addition & 1 deletion src/crewai/cli/templates/crew/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def train():
"topic": "AI LLMs"
}
try:
{{crew_name}}Crew().crew().train(n_iterations=int(sys.argv[1]), inputs=inputs)
{{crew_name}}Crew().crew().train(n_iterations=int(sys.argv[1]), filename=sys.argv[2], inputs=inputs)

except Exception as e:
raise Exception(f"An error occurred while training the crew: {e}")
Expand Down
7 changes: 5 additions & 2 deletions src/crewai/cli/train_crew.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,22 @@
import click


def train_crew(n_iterations: int) -> None:
def train_crew(n_iterations: int, filename: str) -> None:
"""
Train the crew by running a command in the Poetry environment.
Args:
n_iterations (int): The number of iterations to train the crew.
"""
command = ["poetry", "run", "train", str(n_iterations)]
command = ["poetry", "run", "train", str(n_iterations), filename]

try:
if n_iterations <= 0:
raise ValueError("The number of iterations must be a positive integer.")

# Reject any filename that does not end in ".pkl" before the subprocess is
# launched; the message states the requirement the condition enforces.
if not filename.endswith(".pkl"):
    raise ValueError("The filename must end with .pkl")

result = subprocess.run(command, capture_output=False, text=True, check=True)

if result.stderr:
Expand Down
16 changes: 10 additions & 6 deletions src/crewai/crew.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
from crewai.tools.agent_tools import AgentTools
from crewai.types.usage_metrics import UsageMetrics
from crewai.utilities import I18N, FileHandler, Logger, RPMController
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.constants import (
TRAINING_DATA_FILE,
)
from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
from crewai.utilities.formatter import (
Expand Down Expand Up @@ -388,7 +390,7 @@ def _create_task(self, task_config: Dict[str, Any]) -> Task:
del task_config["agent"]
return Task(**task_config, agent=task_agent)

def _setup_for_training(self) -> None:
def _setup_for_training(self, filename: str) -> None:
"""Sets up the crew for training."""
self._train = True

Expand All @@ -399,11 +401,13 @@ def _setup_for_training(self) -> None:
agent.allow_delegation = False

CrewTrainingHandler(TRAINING_DATA_FILE).initialize_file()
CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).initialize_file()
CrewTrainingHandler(filename).initialize_file()

def train(self, n_iterations: int, inputs: Optional[Dict[str, Any]] = {}) -> None:
def train(
self, n_iterations: int, filename: str, inputs: Optional[Dict[str, Any]] = {}
) -> None:
"""Trains the crew for a given number of iterations."""
self._setup_for_training()
self._setup_for_training(filename)

for n_iteration in range(n_iterations):
self._train_iteration = n_iteration
Expand All @@ -416,7 +420,7 @@ def train(self, n_iterations: int, inputs: Optional[Dict[str, Any]] = {}) -> Non
training_data=training_data, agent_id=str(agent.id)
)

CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).save_trained_data(
CrewTrainingHandler(filename).save_trained_data(
agent_id=str(agent.role), trained_data=result.model_dump()
)

Expand Down
10 changes: 5 additions & 5 deletions src/crewai/utilities/file_handler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import os
import pickle


from datetime import datetime


Expand Down Expand Up @@ -32,14 +30,16 @@ def __init__(self, file_name: str) -> None:
Parameters:
- file_name (str): The name of the file for saving and loading data.
"""
if not file_name.endswith(".pkl"):
file_name += ".pkl"

self.file_path = os.path.join(os.getcwd(), file_name)

def initialize_file(self) -> None:
"""
Initialize the file with an empty dictionary if it does not exist or is empty.
Initialize the file with an empty dictionary and overwrite any existing data.
"""
if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
self.save({}) # Save an empty dictionary to initialize the file
self.save({})

def save(self, data) -> None:
"""
Expand Down
8 changes: 4 additions & 4 deletions tests/cli/cli_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,18 @@ def runner():
def test_train_default_iterations(train_crew, runner):
result = runner.invoke(train)

train_crew.assert_called_once_with(5)
train_crew.assert_called_once_with(5, "trained_agents_data.pkl")
assert result.exit_code == 0
assert "Training the crew for 5 iterations" in result.output
assert "Training the Crew for 5 iterations" in result.output


@mock.patch("crewai.cli.cli.train_crew")
def test_train_custom_iterations(train_crew, runner):
result = runner.invoke(train, ["--n_iterations", "10"])

train_crew.assert_called_once_with(10)
train_crew.assert_called_once_with(10, "trained_agents_data.pkl")
assert result.exit_code == 0
assert "Training the crew for 10 iterations" in result.output
assert "Training the Crew for 10 iterations" in result.output


@mock.patch("crewai.cli.cli.train_crew")
Expand Down
26 changes: 14 additions & 12 deletions tests/cli/train_crew_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

@mock.patch("crewai.cli.train_crew.subprocess.run")
def test_train_crew_positive_iterations(mock_subprocess_run):
# Arrange
n_iterations = 5
mock_subprocess_run.return_value = subprocess.CompletedProcess(
args=["poetry", "run", "train", str(n_iterations)],
Expand All @@ -15,12 +14,10 @@ def test_train_crew_positive_iterations(mock_subprocess_run):
stderr="",
)

# Act
train_crew(n_iterations)
train_crew(n_iterations, "trained_agents_data.pkl")

# Assert
mock_subprocess_run.assert_called_once_with(
["poetry", "run", "train", str(n_iterations)],
["poetry", "run", "train", str(n_iterations), "trained_agents_data.pkl"],
capture_output=False,
text=True,
check=True,
Expand All @@ -29,7 +26,7 @@ def test_train_crew_positive_iterations(mock_subprocess_run):

@mock.patch("crewai.cli.train_crew.click")
def test_train_crew_zero_iterations(click):
train_crew(0)
train_crew(0, "trained_agents_data.pkl")
click.echo.assert_called_once_with(
"An unexpected error occurred: The number of iterations must be a positive integer.",
err=True,
Expand All @@ -38,7 +35,7 @@ def test_train_crew_zero_iterations(click):

@mock.patch("crewai.cli.train_crew.click")
def test_train_crew_negative_iterations(click):
train_crew(-2)
train_crew(-2, "trained_agents_data.pkl")
click.echo.assert_called_once_with(
"An unexpected error occurred: The number of iterations must be a positive integer.",
err=True,
Expand All @@ -55,10 +52,13 @@ def test_train_crew_called_process_error(mock_subprocess_run, click):
output="Error",
stderr="Some error occurred",
)
train_crew(n_iterations)
train_crew(n_iterations, "trained_agents_data.pkl")

mock_subprocess_run.assert_called_once_with(
["poetry", "run", "train", "5"], capture_output=False, text=True, check=True
["poetry", "run", "train", str(n_iterations), "trained_agents_data.pkl"],
capture_output=False,
text=True,
check=True,
)
click.echo.assert_has_calls(
[
Expand All @@ -74,13 +74,15 @@ def test_train_crew_called_process_error(mock_subprocess_run, click):
@mock.patch("crewai.cli.train_crew.click")
@mock.patch("crewai.cli.train_crew.subprocess.run")
def test_train_crew_unexpected_exception(mock_subprocess_run, click):
# Arrange
n_iterations = 5
mock_subprocess_run.side_effect = Exception("Unexpected error")
train_crew(n_iterations)
train_crew(n_iterations, "trained_agents_data.pkl")

mock_subprocess_run.assert_called_once_with(
["poetry", "run", "train", "5"], capture_output=False, text=True, check=True
["poetry", "run", "train", str(n_iterations), "trained_agents_data.pkl"],
capture_output=False,
text=True,
check=True,
)
click.echo.assert_called_once_with(
"An unexpected error occurred: Unexpected error", err=True
Expand Down
7 changes: 5 additions & 2 deletions tests/crew_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import pydantic_core
import pytest

from crewai.agent import Agent
from crewai.agents.cache import CacheHandler
from crewai.crew import Crew
Expand Down Expand Up @@ -1806,7 +1807,9 @@ def test_crew_train_success(task_evaluator, crew_training_handler, kickoff):
agents=[researcher, writer],
tasks=[task],
)
crew.train(n_iterations=2, inputs={"topic": "AI"})
crew.train(
n_iterations=2, inputs={"topic": "AI"}, filename="trained_agents_data.pkl"
)
task_evaluator.assert_has_calls(
[
mock.call(researcher),
Expand Down Expand Up @@ -1890,7 +1893,7 @@ def test__setup_for_training():
for agent in agents:
assert agent.allow_delegation is True

crew._setup_for_training()
crew._setup_for_training("trained_agents_data.pkl")

assert crew._train is True
assert task.human_input is True
Expand Down

0 comments on commit 51ee483

Please sign in to comment.