mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
feat: Add Train feature for Crews (#686)
* feat: add training logic to agent and crew * feat: add training logic to agent executor * feat: add input parameter to cli command * feat: add utilities for the training logic * feat: polish code, logic and add private variables * feat: add docstring and type hinting to executor * feat: add constant file, add constant to code * feat: fix name of training handler function * feat: remove unused var * feat: change file handler file name * feat: Add training handler file, class and change on the code * feat: fix name error from file * fix: change import to adapt to logic * feat: add training handler test * feat: add tests for file and training_handler * feat: add test for task evaluator function * feat: change text to fit in-screen * feat: add test for train function * feat: add test for agent training_handler function * feat: add test for agent._use_trained_data
This commit is contained in:
committed by
GitHub
parent
9e61b8325b
commit
175d5b3dd6
@@ -1,5 +1,6 @@
|
||||
"""Test Agent creation and execution basic functionality."""
|
||||
|
||||
from unittest import mock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@@ -842,3 +843,54 @@ Thought:
|
||||
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
@patch("crewai.agent.CrewTrainingHandler")
|
||||
def test_agent_training_handler(crew_training_handler):
|
||||
task_prompt = "What is 1 + 1?"
|
||||
agent = Agent(
|
||||
role="test role",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
verbose=True,
|
||||
)
|
||||
crew_training_handler().load.return_value = {
|
||||
f"{str(agent.id)}": {"0": {"human_feedback": "good"}}
|
||||
}
|
||||
|
||||
result = agent._training_handler(task_prompt=task_prompt)
|
||||
|
||||
assert result == "What is 1 + 1?You MUST follow these feedbacks: \n good"
|
||||
|
||||
crew_training_handler.assert_has_calls(
|
||||
[mock.call(), mock.call("training_data.pkl"), mock.call().load()]
|
||||
)
|
||||
|
||||
|
||||
@patch("crewai.agent.CrewTrainingHandler")
|
||||
def test_agent_use_trained_data(crew_training_handler):
|
||||
task_prompt = "What is 1 + 1?"
|
||||
agent = Agent(
|
||||
role="researcher",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
verbose=True,
|
||||
)
|
||||
crew_training_handler().load.return_value = {
|
||||
agent.role: {
|
||||
"suggestions": [
|
||||
"The result of the math operatio must be right.",
|
||||
"Result must be better than 1.",
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
result = agent._use_trained_data(task_prompt=task_prompt)
|
||||
|
||||
assert (
|
||||
result == "What is 1 + 1?You MUST follow these feedbacks: \n "
|
||||
"The result of the math operatio must be right.\n - Result must be better than 1."
|
||||
)
|
||||
crew_training_handler.assert_has_calls(
|
||||
[mock.call(), mock.call("trained_agents_data.pkl"), mock.call().load()]
|
||||
)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Test Agent creation and execution basic functionality."""
|
||||
|
||||
import json
|
||||
from unittest import mock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pydantic_core
|
||||
import pytest
|
||||
@@ -1006,7 +1008,10 @@ def test_manager_agent_with_tools_raises_exception():
|
||||
crew.kickoff()
|
||||
|
||||
|
||||
def test_crew_train_success():
|
||||
@patch("crewai.crew.Crew.kickoff")
|
||||
@patch("crewai.crew.CrewTrainingHandler")
|
||||
@patch("crewai.crew.TaskEvaluator")
|
||||
def test_crew_train_success(task_evaluator, crew_training_handler, kickoff):
|
||||
task = Task(
|
||||
description="Come up with a list of 5 interesting ideas to explore for an article, then write one amazing paragraph highlight for each idea that showcases how good an article about this topic could be. Return the list of ideas with their paragraph and your notes.",
|
||||
expected_output="5 bullet points with a paragraph for each idea.",
|
||||
@@ -1016,8 +1021,48 @@ def test_crew_train_success():
|
||||
agents=[researcher, writer],
|
||||
tasks=[task],
|
||||
)
|
||||
crew.train(n_iterations=2, inputs={"topic": "AI"})
|
||||
task_evaluator.assert_has_calls(
|
||||
[
|
||||
mock.call(researcher),
|
||||
mock.call().evaluate_training_data(
|
||||
training_data=crew_training_handler().load(),
|
||||
agent_id=str(researcher.id),
|
||||
),
|
||||
mock.call().evaluate_training_data().model_dump(),
|
||||
mock.call(writer),
|
||||
mock.call().evaluate_training_data(
|
||||
training_data=crew_training_handler().load(),
|
||||
agent_id=str(writer.id),
|
||||
),
|
||||
mock.call().evaluate_training_data().model_dump(),
|
||||
]
|
||||
)
|
||||
|
||||
crew.train(n_iterations=2)
|
||||
crew_training_handler.assert_has_calls(
|
||||
[
|
||||
mock.call("training_data.pkl"),
|
||||
mock.call().load(),
|
||||
mock.call("trained_agents_data.pkl"),
|
||||
mock.call().save_trained_data(
|
||||
agent_id="Researcher",
|
||||
trained_data=task_evaluator().evaluate_training_data().model_dump(),
|
||||
),
|
||||
mock.call("trained_agents_data.pkl"),
|
||||
mock.call().save_trained_data(
|
||||
agent_id="Senior Writer",
|
||||
trained_data=task_evaluator().evaluate_training_data().model_dump(),
|
||||
),
|
||||
mock.call(),
|
||||
mock.call().load(),
|
||||
mock.call(),
|
||||
mock.call().load(),
|
||||
]
|
||||
)
|
||||
|
||||
kickoff.assert_has_calls(
|
||||
[mock.call(inputs={"topic": "AI"}), mock.call(inputs={"topic": "AI"})]
|
||||
)
|
||||
|
||||
|
||||
def test_crew_train_error():
|
||||
@@ -1036,3 +1081,32 @@ def test_crew_train_error():
|
||||
assert "train() missing 1 required positional argument: 'n_iterations'" in str(
|
||||
e
|
||||
)
|
||||
|
||||
|
||||
def test__setup_for_training():
|
||||
researcher.allow_delegation = True
|
||||
writer.allow_delegation = True
|
||||
agents = [researcher, writer]
|
||||
task = Task(
|
||||
description="Come up with a list of 5 interesting ideas to explore for an article",
|
||||
expected_output="5 bullet points with a paragraph for each idea.",
|
||||
)
|
||||
|
||||
crew = Crew(
|
||||
agents=agents,
|
||||
tasks=[task],
|
||||
)
|
||||
|
||||
assert crew._train is False
|
||||
assert task.human_input is False
|
||||
|
||||
for agent in agents:
|
||||
assert agent.allow_delegation is True
|
||||
|
||||
crew._setup_for_training()
|
||||
|
||||
assert crew._train is True
|
||||
assert task.human_input is True
|
||||
|
||||
for agent in agents:
|
||||
assert agent.allow_delegation is False
|
||||
|
||||
64
tests/utilities/evaluators/test_task_evaluator.py
Normal file
64
tests/utilities/evaluators/test_task_evaluator.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from crewai.utilities.evaluators.task_evaluator import (
|
||||
TaskEvaluator,
|
||||
TrainingTaskEvaluation,
|
||||
)
|
||||
|
||||
|
||||
@patch("crewai.utilities.evaluators.task_evaluator.Converter")
|
||||
def test_evaluate_training_data(converter_mock):
|
||||
training_data = {
|
||||
"agent_id": {
|
||||
"data1": {
|
||||
"initial_output": "Initial output 1",
|
||||
"human_feedback": "Human feedback 1",
|
||||
"improved_output": "Improved output 1",
|
||||
},
|
||||
"data2": {
|
||||
"initial_output": "Initial output 2",
|
||||
"human_feedback": "Human feedback 2",
|
||||
"improved_output": "Improved output 2",
|
||||
},
|
||||
}
|
||||
}
|
||||
agent_id = "agent_id"
|
||||
original_agent = MagicMock()
|
||||
function_return_value = TrainingTaskEvaluation(
|
||||
suggestions=[
|
||||
"The initial output was already good, having a detailed explanation. However, the improved output "
|
||||
"gave similar information but in a more professional manner using better vocabulary. For future tasks, "
|
||||
"try to implement more elaborate language and precise terminology from the beginning."
|
||||
],
|
||||
quality=8.0,
|
||||
final_summary="The agent responded well initially. However, the improved output showed that there is room "
|
||||
"for enhancement in terms of language usage, precision, and professionalism. For future tasks, the agent "
|
||||
"should focus more on these points from the start to increase performance.",
|
||||
)
|
||||
converter_mock.return_value.to_pydantic.return_value = function_return_value
|
||||
result = TaskEvaluator(original_agent=original_agent).evaluate_training_data(
|
||||
training_data, agent_id
|
||||
)
|
||||
|
||||
assert result == function_return_value
|
||||
converter_mock.assert_has_calls(
|
||||
[
|
||||
mock.call(
|
||||
llm=original_agent.llm,
|
||||
text="Assess the quality of the training data based on the llm output, human feedback , and llm "
|
||||
"output improved result.\n\nInitial Output:\nInitial output 1\n\nHuman Feedback:\nHuman feedback "
|
||||
"1\n\nImproved Output:\nImproved output 1\n\nInitial Output:\nInitial output 2\n\nHuman "
|
||||
"Feedback:\nHuman feedback 2\n\nImproved Output:\nImproved output 2\n\nPlease provide:\n- "
|
||||
"Based on the Human Feedbacks and the comparison between Initial Outputs and Improved outputs "
|
||||
"provide action items based on human_feedback for future tasks\n- A score from 0 to 10 evaluating "
|
||||
"on completion, quality, and overall performance from the improved output to the initial output "
|
||||
"based on the human feedback\n",
|
||||
model=TrainingTaskEvaluation,
|
||||
instructions="I'm gonna convert this raw text into valid JSON.\n\nThe json should have the "
|
||||
"following structure, with the following keys:\n- suggestions: List[str]\n- "
|
||||
"quality: float\n- final_summary: str",
|
||||
),
|
||||
mock.call().to_pydantic(),
|
||||
]
|
||||
)
|
||||
41
tests/utilities/test_file_handler.py
Normal file
41
tests/utilities/test_file_handler.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.utilities.file_handler import PickleHandler
|
||||
|
||||
|
||||
class TestPickleHandler(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.file_name = "test_data.pkl"
|
||||
self.file_path = os.path.join(os.getcwd(), self.file_name)
|
||||
self.handler = PickleHandler(self.file_name)
|
||||
|
||||
def tearDown(self):
|
||||
if os.path.exists(self.file_path):
|
||||
os.remove(self.file_path)
|
||||
|
||||
def test_initialize_file(self):
|
||||
assert os.path.exists(self.file_path) is True
|
||||
assert os.path.getsize(self.file_path) >= 0
|
||||
|
||||
def test_save_and_load(self):
|
||||
data = {"key": "value"}
|
||||
self.handler.save(data)
|
||||
loaded_data = self.handler.load()
|
||||
assert loaded_data == data
|
||||
|
||||
def test_load_empty_file(self):
|
||||
loaded_data = self.handler.load()
|
||||
assert loaded_data == {}
|
||||
|
||||
def test_load_corrupted_file(self):
|
||||
with open(self.file_path, "wb") as file:
|
||||
file.write(b"corrupted data")
|
||||
|
||||
with pytest.raises(Exception) as exc:
|
||||
self.handler.load()
|
||||
|
||||
assert str(exc.value) == "pickle data was truncated"
|
||||
assert "<class '_pickle.UnpicklingError'>" == str(exc.type)
|
||||
42
tests/utilities/test_training_handler.py
Normal file
42
tests/utilities/test_training_handler.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
|
||||
|
||||
class TestCrewTrainingHandler(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.handler = CrewTrainingHandler("trained_data.pkl")
|
||||
|
||||
def tearDown(self):
|
||||
os.remove("trained_data.pkl")
|
||||
del self.handler
|
||||
|
||||
def test_save_trained_data(self):
|
||||
agent_id = "agent1"
|
||||
trained_data = {"param1": 1, "param2": 2}
|
||||
self.handler.save_trained_data(agent_id, trained_data)
|
||||
|
||||
# Assert that the trained data is saved correctly
|
||||
data = self.handler.load()
|
||||
assert data[agent_id] == trained_data
|
||||
|
||||
def test_append_existing_agent(self):
|
||||
train_iteration = 1
|
||||
agent_id = "agent1"
|
||||
new_data = {"param3": 3, "param4": 4}
|
||||
self.handler.append(train_iteration, agent_id, new_data)
|
||||
|
||||
# Assert that the new data is appended correctly to the existing agent
|
||||
data = self.handler.load()
|
||||
assert data[agent_id][train_iteration] == new_data
|
||||
|
||||
def test_append_new_agent(self):
|
||||
train_iteration = 1
|
||||
agent_id = "agent2"
|
||||
new_data = {"param5": 5, "param6": 6}
|
||||
self.handler.append(train_iteration, agent_id, new_data)
|
||||
|
||||
# Assert that the new agent and data are appended correctly
|
||||
data = self.handler.load()
|
||||
assert data[agent_id][train_iteration] == new_data
|
||||
Reference in New Issue
Block a user