mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-16 04:18:35 +00:00
Compare commits
7 Commits
1.2.0
...
feat/add-c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4b8df71ca1 | ||
|
|
cc187a23d7 | ||
|
|
540e328f06 | ||
|
|
87d4c5f092 | ||
|
|
45c16cfa6b | ||
|
|
9232ac3e3f | ||
|
|
aa8640c086 |
@@ -60,10 +60,17 @@ def version(tools):
|
||||
default=5,
|
||||
help="Number of iterations to train the crew",
|
||||
)
|
||||
def train(n_iterations: int):
|
||||
@click.option(
|
||||
"-f",
|
||||
"--filename",
|
||||
type=str,
|
||||
default="trained_agents_data.pkl",
|
||||
help="Path to a custom file for training",
|
||||
)
|
||||
def train(n_iterations: int, filename: str):
|
||||
"""Train the crew."""
|
||||
click.echo(f"Training the crew for {n_iterations} iterations")
|
||||
train_crew(n_iterations)
|
||||
click.echo(f"Training the Crew for {n_iterations} iterations")
|
||||
train_crew(n_iterations, filename)
|
||||
|
||||
|
||||
@crewai.command()
|
||||
|
||||
@@ -25,7 +25,7 @@ def train():
|
||||
"topic": "AI LLMs"
|
||||
}
|
||||
try:
|
||||
{{crew_name}}Crew().crew().train(n_iterations=int(sys.argv[1]), inputs=inputs)
|
||||
{{crew_name}}Crew().crew().train(n_iterations=int(sys.argv[1]), filename=sys.argv[2], inputs=inputs)
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"An error occurred while training the crew: {e}")
|
||||
|
||||
@@ -3,19 +3,22 @@ import subprocess
|
||||
import click
|
||||
|
||||
|
||||
def train_crew(n_iterations: int) -> None:
|
||||
def train_crew(n_iterations: int, filename: str) -> None:
|
||||
"""
|
||||
Train the crew by running a command in the Poetry environment.
|
||||
|
||||
Args:
|
||||
n_iterations (int): The number of iterations to train the crew.
|
||||
"""
|
||||
command = ["poetry", "run", "train", str(n_iterations)]
|
||||
command = ["poetry", "run", "train", str(n_iterations), filename]
|
||||
|
||||
try:
|
||||
if n_iterations <= 0:
|
||||
raise ValueError("The number of iterations must be a positive integer.")
|
||||
|
||||
if not filename.endswith(".pkl"):
|
||||
raise ValueError("The filename must not end with .pkl")
|
||||
|
||||
result = subprocess.run(command, capture_output=False, text=True, check=True)
|
||||
|
||||
if result.stderr:
|
||||
|
||||
@@ -34,7 +34,9 @@ from crewai.telemetry import Telemetry
|
||||
from crewai.tools.agent_tools import AgentTools
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
from crewai.utilities import I18N, FileHandler, Logger, RPMController
|
||||
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
|
||||
from crewai.utilities.constants import (
|
||||
TRAINING_DATA_FILE,
|
||||
)
|
||||
from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
|
||||
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
|
||||
from crewai.utilities.formatter import (
|
||||
@@ -388,7 +390,7 @@ class Crew(BaseModel):
|
||||
del task_config["agent"]
|
||||
return Task(**task_config, agent=task_agent)
|
||||
|
||||
def _setup_for_training(self) -> None:
|
||||
def _setup_for_training(self, filename: str) -> None:
|
||||
"""Sets up the crew for training."""
|
||||
self._train = True
|
||||
|
||||
@@ -399,11 +401,13 @@ class Crew(BaseModel):
|
||||
agent.allow_delegation = False
|
||||
|
||||
CrewTrainingHandler(TRAINING_DATA_FILE).initialize_file()
|
||||
CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).initialize_file()
|
||||
CrewTrainingHandler(filename).initialize_file()
|
||||
|
||||
def train(self, n_iterations: int, inputs: Optional[Dict[str, Any]] = {}) -> None:
|
||||
def train(
|
||||
self, n_iterations: int, filename: str, inputs: Optional[Dict[str, Any]] = {}
|
||||
) -> None:
|
||||
"""Trains the crew for a given number of iterations."""
|
||||
self._setup_for_training()
|
||||
self._setup_for_training(filename)
|
||||
|
||||
for n_iteration in range(n_iterations):
|
||||
self._train_iteration = n_iteration
|
||||
@@ -416,7 +420,7 @@ class Crew(BaseModel):
|
||||
training_data=training_data, agent_id=str(agent.id)
|
||||
)
|
||||
|
||||
CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).save_trained_data(
|
||||
CrewTrainingHandler(filename).save_trained_data(
|
||||
agent_id=str(agent.role), trained_data=result.model_dump()
|
||||
)
|
||||
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import os
|
||||
import pickle
|
||||
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@@ -32,14 +30,16 @@ class PickleHandler:
|
||||
Parameters:
|
||||
- file_name (str): The name of the file for saving and loading data.
|
||||
"""
|
||||
if not file_name.endswith(".pkl"):
|
||||
file_name += ".pkl"
|
||||
|
||||
self.file_path = os.path.join(os.getcwd(), file_name)
|
||||
|
||||
def initialize_file(self) -> None:
|
||||
"""
|
||||
Initialize the file with an empty dictionary if it does not exist or is empty.
|
||||
Initialize the file with an empty dictionary and overwrite any existing data.
|
||||
"""
|
||||
if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
|
||||
self.save({}) # Save an empty dictionary to initialize the file
|
||||
self.save({})
|
||||
|
||||
def save(self, data) -> None:
|
||||
"""
|
||||
|
||||
@@ -15,18 +15,18 @@ def runner():
|
||||
def test_train_default_iterations(train_crew, runner):
|
||||
result = runner.invoke(train)
|
||||
|
||||
train_crew.assert_called_once_with(5)
|
||||
train_crew.assert_called_once_with(5, "trained_agents_data.pkl")
|
||||
assert result.exit_code == 0
|
||||
assert "Training the crew for 5 iterations" in result.output
|
||||
assert "Training the Crew for 5 iterations" in result.output
|
||||
|
||||
|
||||
@mock.patch("crewai.cli.cli.train_crew")
|
||||
def test_train_custom_iterations(train_crew, runner):
|
||||
result = runner.invoke(train, ["--n_iterations", "10"])
|
||||
|
||||
train_crew.assert_called_once_with(10)
|
||||
train_crew.assert_called_once_with(10, "trained_agents_data.pkl")
|
||||
assert result.exit_code == 0
|
||||
assert "Training the crew for 10 iterations" in result.output
|
||||
assert "Training the Crew for 10 iterations" in result.output
|
||||
|
||||
|
||||
@mock.patch("crewai.cli.cli.train_crew")
|
||||
|
||||
@@ -6,7 +6,6 @@ from crewai.cli.train_crew import train_crew
|
||||
|
||||
@mock.patch("crewai.cli.train_crew.subprocess.run")
|
||||
def test_train_crew_positive_iterations(mock_subprocess_run):
|
||||
# Arrange
|
||||
n_iterations = 5
|
||||
mock_subprocess_run.return_value = subprocess.CompletedProcess(
|
||||
args=["poetry", "run", "train", str(n_iterations)],
|
||||
@@ -15,12 +14,10 @@ def test_train_crew_positive_iterations(mock_subprocess_run):
|
||||
stderr="",
|
||||
)
|
||||
|
||||
# Act
|
||||
train_crew(n_iterations)
|
||||
train_crew(n_iterations, "trained_agents_data.pkl")
|
||||
|
||||
# Assert
|
||||
mock_subprocess_run.assert_called_once_with(
|
||||
["poetry", "run", "train", str(n_iterations)],
|
||||
["poetry", "run", "train", str(n_iterations), "trained_agents_data.pkl"],
|
||||
capture_output=False,
|
||||
text=True,
|
||||
check=True,
|
||||
@@ -29,7 +26,7 @@ def test_train_crew_positive_iterations(mock_subprocess_run):
|
||||
|
||||
@mock.patch("crewai.cli.train_crew.click")
|
||||
def test_train_crew_zero_iterations(click):
|
||||
train_crew(0)
|
||||
train_crew(0, "trained_agents_data.pkl")
|
||||
click.echo.assert_called_once_with(
|
||||
"An unexpected error occurred: The number of iterations must be a positive integer.",
|
||||
err=True,
|
||||
@@ -38,7 +35,7 @@ def test_train_crew_zero_iterations(click):
|
||||
|
||||
@mock.patch("crewai.cli.train_crew.click")
|
||||
def test_train_crew_negative_iterations(click):
|
||||
train_crew(-2)
|
||||
train_crew(-2, "trained_agents_data.pkl")
|
||||
click.echo.assert_called_once_with(
|
||||
"An unexpected error occurred: The number of iterations must be a positive integer.",
|
||||
err=True,
|
||||
@@ -55,10 +52,13 @@ def test_train_crew_called_process_error(mock_subprocess_run, click):
|
||||
output="Error",
|
||||
stderr="Some error occurred",
|
||||
)
|
||||
train_crew(n_iterations)
|
||||
train_crew(n_iterations, "trained_agents_data.pkl")
|
||||
|
||||
mock_subprocess_run.assert_called_once_with(
|
||||
["poetry", "run", "train", "5"], capture_output=False, text=True, check=True
|
||||
["poetry", "run", "train", str(n_iterations), "trained_agents_data.pkl"],
|
||||
capture_output=False,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
click.echo.assert_has_calls(
|
||||
[
|
||||
@@ -74,13 +74,15 @@ def test_train_crew_called_process_error(mock_subprocess_run, click):
|
||||
@mock.patch("crewai.cli.train_crew.click")
|
||||
@mock.patch("crewai.cli.train_crew.subprocess.run")
|
||||
def test_train_crew_unexpected_exception(mock_subprocess_run, click):
|
||||
# Arrange
|
||||
n_iterations = 5
|
||||
mock_subprocess_run.side_effect = Exception("Unexpected error")
|
||||
train_crew(n_iterations)
|
||||
train_crew(n_iterations, "trained_agents_data.pkl")
|
||||
|
||||
mock_subprocess_run.assert_called_once_with(
|
||||
["poetry", "run", "train", "5"], capture_output=False, text=True, check=True
|
||||
["poetry", "run", "train", str(n_iterations), "trained_agents_data.pkl"],
|
||||
capture_output=False,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
click.echo.assert_called_once_with(
|
||||
"An unexpected error occurred: Unexpected error", err=True
|
||||
|
||||
@@ -8,6 +8,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pydantic_core
|
||||
import pytest
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.agents.cache import CacheHandler
|
||||
from crewai.crew import Crew
|
||||
@@ -1806,7 +1807,9 @@ def test_crew_train_success(task_evaluator, crew_training_handler, kickoff):
|
||||
agents=[researcher, writer],
|
||||
tasks=[task],
|
||||
)
|
||||
crew.train(n_iterations=2, inputs={"topic": "AI"})
|
||||
crew.train(
|
||||
n_iterations=2, inputs={"topic": "AI"}, filename="trained_agents_data.pkl"
|
||||
)
|
||||
task_evaluator.assert_has_calls(
|
||||
[
|
||||
mock.call(researcher),
|
||||
@@ -1890,7 +1893,7 @@ def test__setup_for_training():
|
||||
for agent in agents:
|
||||
assert agent.allow_delegation is True
|
||||
|
||||
crew._setup_for_training()
|
||||
crew._setup_for_training("trained_agents_data.pkl")
|
||||
|
||||
assert crew._train is True
|
||||
assert task.human_input is True
|
||||
|
||||
Reference in New Issue
Block a user