mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-18 13:28:31 +00:00
Compare commits
7 Commits
bugfix/res
...
feat/add-c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4b8df71ca1 | ||
|
|
cc187a23d7 | ||
|
|
540e328f06 | ||
|
|
87d4c5f092 | ||
|
|
45c16cfa6b | ||
|
|
9232ac3e3f | ||
|
|
aa8640c086 |
@@ -60,10 +60,17 @@ def version(tools):
|
|||||||
default=5,
|
default=5,
|
||||||
help="Number of iterations to train the crew",
|
help="Number of iterations to train the crew",
|
||||||
)
|
)
|
||||||
def train(n_iterations: int):
|
@click.option(
|
||||||
|
"-f",
|
||||||
|
"--filename",
|
||||||
|
type=str,
|
||||||
|
default="trained_agents_data.pkl",
|
||||||
|
help="Path to a custom file for training",
|
||||||
|
)
|
||||||
|
def train(n_iterations: int, filename: str):
|
||||||
"""Train the crew."""
|
"""Train the crew."""
|
||||||
click.echo(f"Training the crew for {n_iterations} iterations")
|
click.echo(f"Training the Crew for {n_iterations} iterations")
|
||||||
train_crew(n_iterations)
|
train_crew(n_iterations, filename)
|
||||||
|
|
||||||
|
|
||||||
@crewai.command()
|
@crewai.command()
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ def train():
|
|||||||
"topic": "AI LLMs"
|
"topic": "AI LLMs"
|
||||||
}
|
}
|
||||||
try:
|
try:
|
||||||
{{crew_name}}Crew().crew().train(n_iterations=int(sys.argv[1]), inputs=inputs)
|
{{crew_name}}Crew().crew().train(n_iterations=int(sys.argv[1]), filename=sys.argv[2], inputs=inputs)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f"An error occurred while training the crew: {e}")
|
raise Exception(f"An error occurred while training the crew: {e}")
|
||||||
|
|||||||
@@ -3,19 +3,22 @@ import subprocess
|
|||||||
import click
|
import click
|
||||||
|
|
||||||
|
|
||||||
def train_crew(n_iterations: int) -> None:
|
def train_crew(n_iterations: int, filename: str) -> None:
|
||||||
"""
|
"""
|
||||||
Train the crew by running a command in the Poetry environment.
|
Train the crew by running a command in the Poetry environment.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
n_iterations (int): The number of iterations to train the crew.
|
n_iterations (int): The number of iterations to train the crew.
|
||||||
"""
|
"""
|
||||||
command = ["poetry", "run", "train", str(n_iterations)]
|
command = ["poetry", "run", "train", str(n_iterations), filename]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if n_iterations <= 0:
|
if n_iterations <= 0:
|
||||||
raise ValueError("The number of iterations must be a positive integer.")
|
raise ValueError("The number of iterations must be a positive integer.")
|
||||||
|
|
||||||
|
if not filename.endswith(".pkl"):
|
||||||
|
raise ValueError("The filename must not end with .pkl")
|
||||||
|
|
||||||
result = subprocess.run(command, capture_output=False, text=True, check=True)
|
result = subprocess.run(command, capture_output=False, text=True, check=True)
|
||||||
|
|
||||||
if result.stderr:
|
if result.stderr:
|
||||||
|
|||||||
@@ -34,7 +34,9 @@ from crewai.telemetry import Telemetry
|
|||||||
from crewai.tools.agent_tools import AgentTools
|
from crewai.tools.agent_tools import AgentTools
|
||||||
from crewai.types.usage_metrics import UsageMetrics
|
from crewai.types.usage_metrics import UsageMetrics
|
||||||
from crewai.utilities import I18N, FileHandler, Logger, RPMController
|
from crewai.utilities import I18N, FileHandler, Logger, RPMController
|
||||||
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
|
from crewai.utilities.constants import (
|
||||||
|
TRAINING_DATA_FILE,
|
||||||
|
)
|
||||||
from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
|
from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
|
||||||
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
|
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
|
||||||
from crewai.utilities.formatter import (
|
from crewai.utilities.formatter import (
|
||||||
@@ -388,7 +390,7 @@ class Crew(BaseModel):
|
|||||||
del task_config["agent"]
|
del task_config["agent"]
|
||||||
return Task(**task_config, agent=task_agent)
|
return Task(**task_config, agent=task_agent)
|
||||||
|
|
||||||
def _setup_for_training(self) -> None:
|
def _setup_for_training(self, filename: str) -> None:
|
||||||
"""Sets up the crew for training."""
|
"""Sets up the crew for training."""
|
||||||
self._train = True
|
self._train = True
|
||||||
|
|
||||||
@@ -399,11 +401,13 @@ class Crew(BaseModel):
|
|||||||
agent.allow_delegation = False
|
agent.allow_delegation = False
|
||||||
|
|
||||||
CrewTrainingHandler(TRAINING_DATA_FILE).initialize_file()
|
CrewTrainingHandler(TRAINING_DATA_FILE).initialize_file()
|
||||||
CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).initialize_file()
|
CrewTrainingHandler(filename).initialize_file()
|
||||||
|
|
||||||
def train(self, n_iterations: int, inputs: Optional[Dict[str, Any]] = {}) -> None:
|
def train(
|
||||||
|
self, n_iterations: int, filename: str, inputs: Optional[Dict[str, Any]] = {}
|
||||||
|
) -> None:
|
||||||
"""Trains the crew for a given number of iterations."""
|
"""Trains the crew for a given number of iterations."""
|
||||||
self._setup_for_training()
|
self._setup_for_training(filename)
|
||||||
|
|
||||||
for n_iteration in range(n_iterations):
|
for n_iteration in range(n_iterations):
|
||||||
self._train_iteration = n_iteration
|
self._train_iteration = n_iteration
|
||||||
@@ -416,7 +420,7 @@ class Crew(BaseModel):
|
|||||||
training_data=training_data, agent_id=str(agent.id)
|
training_data=training_data, agent_id=str(agent.id)
|
||||||
)
|
)
|
||||||
|
|
||||||
CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).save_trained_data(
|
CrewTrainingHandler(filename).save_trained_data(
|
||||||
agent_id=str(agent.role), trained_data=result.model_dump()
|
agent_id=str(agent.role), trained_data=result.model_dump()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
@@ -32,14 +30,16 @@ class PickleHandler:
|
|||||||
Parameters:
|
Parameters:
|
||||||
- file_name (str): The name of the file for saving and loading data.
|
- file_name (str): The name of the file for saving and loading data.
|
||||||
"""
|
"""
|
||||||
|
if not file_name.endswith(".pkl"):
|
||||||
|
file_name += ".pkl"
|
||||||
|
|
||||||
self.file_path = os.path.join(os.getcwd(), file_name)
|
self.file_path = os.path.join(os.getcwd(), file_name)
|
||||||
|
|
||||||
def initialize_file(self) -> None:
|
def initialize_file(self) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize the file with an empty dictionary if it does not exist or is empty.
|
Initialize the file with an empty dictionary and overwrite any existing data.
|
||||||
"""
|
"""
|
||||||
if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
|
self.save({})
|
||||||
self.save({}) # Save an empty dictionary to initialize the file
|
|
||||||
|
|
||||||
def save(self, data) -> None:
|
def save(self, data) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -15,18 +15,18 @@ def runner():
|
|||||||
def test_train_default_iterations(train_crew, runner):
|
def test_train_default_iterations(train_crew, runner):
|
||||||
result = runner.invoke(train)
|
result = runner.invoke(train)
|
||||||
|
|
||||||
train_crew.assert_called_once_with(5)
|
train_crew.assert_called_once_with(5, "trained_agents_data.pkl")
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
assert "Training the crew for 5 iterations" in result.output
|
assert "Training the Crew for 5 iterations" in result.output
|
||||||
|
|
||||||
|
|
||||||
@mock.patch("crewai.cli.cli.train_crew")
|
@mock.patch("crewai.cli.cli.train_crew")
|
||||||
def test_train_custom_iterations(train_crew, runner):
|
def test_train_custom_iterations(train_crew, runner):
|
||||||
result = runner.invoke(train, ["--n_iterations", "10"])
|
result = runner.invoke(train, ["--n_iterations", "10"])
|
||||||
|
|
||||||
train_crew.assert_called_once_with(10)
|
train_crew.assert_called_once_with(10, "trained_agents_data.pkl")
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
assert "Training the crew for 10 iterations" in result.output
|
assert "Training the Crew for 10 iterations" in result.output
|
||||||
|
|
||||||
|
|
||||||
@mock.patch("crewai.cli.cli.train_crew")
|
@mock.patch("crewai.cli.cli.train_crew")
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from crewai.cli.train_crew import train_crew
|
|||||||
|
|
||||||
@mock.patch("crewai.cli.train_crew.subprocess.run")
|
@mock.patch("crewai.cli.train_crew.subprocess.run")
|
||||||
def test_train_crew_positive_iterations(mock_subprocess_run):
|
def test_train_crew_positive_iterations(mock_subprocess_run):
|
||||||
# Arrange
|
|
||||||
n_iterations = 5
|
n_iterations = 5
|
||||||
mock_subprocess_run.return_value = subprocess.CompletedProcess(
|
mock_subprocess_run.return_value = subprocess.CompletedProcess(
|
||||||
args=["poetry", "run", "train", str(n_iterations)],
|
args=["poetry", "run", "train", str(n_iterations)],
|
||||||
@@ -15,12 +14,10 @@ def test_train_crew_positive_iterations(mock_subprocess_run):
|
|||||||
stderr="",
|
stderr="",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Act
|
train_crew(n_iterations, "trained_agents_data.pkl")
|
||||||
train_crew(n_iterations)
|
|
||||||
|
|
||||||
# Assert
|
|
||||||
mock_subprocess_run.assert_called_once_with(
|
mock_subprocess_run.assert_called_once_with(
|
||||||
["poetry", "run", "train", str(n_iterations)],
|
["poetry", "run", "train", str(n_iterations), "trained_agents_data.pkl"],
|
||||||
capture_output=False,
|
capture_output=False,
|
||||||
text=True,
|
text=True,
|
||||||
check=True,
|
check=True,
|
||||||
@@ -29,7 +26,7 @@ def test_train_crew_positive_iterations(mock_subprocess_run):
|
|||||||
|
|
||||||
@mock.patch("crewai.cli.train_crew.click")
|
@mock.patch("crewai.cli.train_crew.click")
|
||||||
def test_train_crew_zero_iterations(click):
|
def test_train_crew_zero_iterations(click):
|
||||||
train_crew(0)
|
train_crew(0, "trained_agents_data.pkl")
|
||||||
click.echo.assert_called_once_with(
|
click.echo.assert_called_once_with(
|
||||||
"An unexpected error occurred: The number of iterations must be a positive integer.",
|
"An unexpected error occurred: The number of iterations must be a positive integer.",
|
||||||
err=True,
|
err=True,
|
||||||
@@ -38,7 +35,7 @@ def test_train_crew_zero_iterations(click):
|
|||||||
|
|
||||||
@mock.patch("crewai.cli.train_crew.click")
|
@mock.patch("crewai.cli.train_crew.click")
|
||||||
def test_train_crew_negative_iterations(click):
|
def test_train_crew_negative_iterations(click):
|
||||||
train_crew(-2)
|
train_crew(-2, "trained_agents_data.pkl")
|
||||||
click.echo.assert_called_once_with(
|
click.echo.assert_called_once_with(
|
||||||
"An unexpected error occurred: The number of iterations must be a positive integer.",
|
"An unexpected error occurred: The number of iterations must be a positive integer.",
|
||||||
err=True,
|
err=True,
|
||||||
@@ -55,10 +52,13 @@ def test_train_crew_called_process_error(mock_subprocess_run, click):
|
|||||||
output="Error",
|
output="Error",
|
||||||
stderr="Some error occurred",
|
stderr="Some error occurred",
|
||||||
)
|
)
|
||||||
train_crew(n_iterations)
|
train_crew(n_iterations, "trained_agents_data.pkl")
|
||||||
|
|
||||||
mock_subprocess_run.assert_called_once_with(
|
mock_subprocess_run.assert_called_once_with(
|
||||||
["poetry", "run", "train", "5"], capture_output=False, text=True, check=True
|
["poetry", "run", "train", str(n_iterations), "trained_agents_data.pkl"],
|
||||||
|
capture_output=False,
|
||||||
|
text=True,
|
||||||
|
check=True,
|
||||||
)
|
)
|
||||||
click.echo.assert_has_calls(
|
click.echo.assert_has_calls(
|
||||||
[
|
[
|
||||||
@@ -74,13 +74,15 @@ def test_train_crew_called_process_error(mock_subprocess_run, click):
|
|||||||
@mock.patch("crewai.cli.train_crew.click")
|
@mock.patch("crewai.cli.train_crew.click")
|
||||||
@mock.patch("crewai.cli.train_crew.subprocess.run")
|
@mock.patch("crewai.cli.train_crew.subprocess.run")
|
||||||
def test_train_crew_unexpected_exception(mock_subprocess_run, click):
|
def test_train_crew_unexpected_exception(mock_subprocess_run, click):
|
||||||
# Arrange
|
|
||||||
n_iterations = 5
|
n_iterations = 5
|
||||||
mock_subprocess_run.side_effect = Exception("Unexpected error")
|
mock_subprocess_run.side_effect = Exception("Unexpected error")
|
||||||
train_crew(n_iterations)
|
train_crew(n_iterations, "trained_agents_data.pkl")
|
||||||
|
|
||||||
mock_subprocess_run.assert_called_once_with(
|
mock_subprocess_run.assert_called_once_with(
|
||||||
["poetry", "run", "train", "5"], capture_output=False, text=True, check=True
|
["poetry", "run", "train", str(n_iterations), "trained_agents_data.pkl"],
|
||||||
|
capture_output=False,
|
||||||
|
text=True,
|
||||||
|
check=True,
|
||||||
)
|
)
|
||||||
click.echo.assert_called_once_with(
|
click.echo.assert_called_once_with(
|
||||||
"An unexpected error occurred: Unexpected error", err=True
|
"An unexpected error occurred: Unexpected error", err=True
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from unittest.mock import MagicMock, patch
|
|||||||
|
|
||||||
import pydantic_core
|
import pydantic_core
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from crewai.agent import Agent
|
from crewai.agent import Agent
|
||||||
from crewai.agents.cache import CacheHandler
|
from crewai.agents.cache import CacheHandler
|
||||||
from crewai.crew import Crew
|
from crewai.crew import Crew
|
||||||
@@ -1806,7 +1807,9 @@ def test_crew_train_success(task_evaluator, crew_training_handler, kickoff):
|
|||||||
agents=[researcher, writer],
|
agents=[researcher, writer],
|
||||||
tasks=[task],
|
tasks=[task],
|
||||||
)
|
)
|
||||||
crew.train(n_iterations=2, inputs={"topic": "AI"})
|
crew.train(
|
||||||
|
n_iterations=2, inputs={"topic": "AI"}, filename="trained_agents_data.pkl"
|
||||||
|
)
|
||||||
task_evaluator.assert_has_calls(
|
task_evaluator.assert_has_calls(
|
||||||
[
|
[
|
||||||
mock.call(researcher),
|
mock.call(researcher),
|
||||||
@@ -1890,7 +1893,7 @@ def test__setup_for_training():
|
|||||||
for agent in agents:
|
for agent in agents:
|
||||||
assert agent.allow_delegation is True
|
assert agent.allow_delegation is True
|
||||||
|
|
||||||
crew._setup_for_training()
|
crew._setup_for_training("trained_agents_data.pkl")
|
||||||
|
|
||||||
assert crew._train is True
|
assert crew._train is True
|
||||||
assert task.human_input is True
|
assert task.human_input is True
|
||||||
|
|||||||
Reference in New Issue
Block a user