mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 16:48:30 +00:00
feat: Add Train feature for Crews (#686)
* feat: add training logic to agent and crew * feat: add training logic to agent executor * feat: add input parameter to cli command * feat: add utilities for the training logic * feat: polish code, logic and add private variables * feat: add docstring and type hinting to executor * feat: add constant file, add constant to code * feat: fix name of training handler function * feat: remove unused var * feat: change file handler file name * feat: Add training handler file, class and change on the code * feat: fix name error from file * fix: change import to adapt to logic * feat: add training handler test * feat: add tests for file and training_handler * feat: add test for task evaluator function * feat: change text to fit in-screen * feat: add test for train function * feat: add test for agent training_handler function * feat: add test for agent._use_trained_data
This commit is contained in:
committed by
GitHub
parent
0594a7f9d8
commit
3573a61568
@@ -1,6 +1,8 @@
|
||||
"""Test Agent creation and execution basic functionality."""
|
||||
|
||||
import json
|
||||
from unittest import mock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pydantic_core
|
||||
import pytest
|
||||
@@ -1006,7 +1008,10 @@ def test_manager_agent_with_tools_raises_exception():
|
||||
crew.kickoff()
|
||||
|
||||
|
||||
def test_crew_train_success():
|
||||
@patch("crewai.crew.Crew.kickoff")
|
||||
@patch("crewai.crew.CrewTrainingHandler")
|
||||
@patch("crewai.crew.TaskEvaluator")
|
||||
def test_crew_train_success(task_evaluator, crew_training_handler, kickoff):
|
||||
task = Task(
|
||||
description="Come up with a list of 5 interesting ideas to explore for an article, then write one amazing paragraph highlight for each idea that showcases how good an article about this topic could be. Return the list of ideas with their paragraph and your notes.",
|
||||
expected_output="5 bullet points with a paragraph for each idea.",
|
||||
@@ -1016,8 +1021,48 @@ def test_crew_train_success():
|
||||
agents=[researcher, writer],
|
||||
tasks=[task],
|
||||
)
|
||||
crew.train(n_iterations=2, inputs={"topic": "AI"})
|
||||
task_evaluator.assert_has_calls(
|
||||
[
|
||||
mock.call(researcher),
|
||||
mock.call().evaluate_training_data(
|
||||
training_data=crew_training_handler().load(),
|
||||
agent_id=str(researcher.id),
|
||||
),
|
||||
mock.call().evaluate_training_data().model_dump(),
|
||||
mock.call(writer),
|
||||
mock.call().evaluate_training_data(
|
||||
training_data=crew_training_handler().load(),
|
||||
agent_id=str(writer.id),
|
||||
),
|
||||
mock.call().evaluate_training_data().model_dump(),
|
||||
]
|
||||
)
|
||||
|
||||
crew.train(n_iterations=2)
|
||||
crew_training_handler.assert_has_calls(
|
||||
[
|
||||
mock.call("training_data.pkl"),
|
||||
mock.call().load(),
|
||||
mock.call("trained_agents_data.pkl"),
|
||||
mock.call().save_trained_data(
|
||||
agent_id="Researcher",
|
||||
trained_data=task_evaluator().evaluate_training_data().model_dump(),
|
||||
),
|
||||
mock.call("trained_agents_data.pkl"),
|
||||
mock.call().save_trained_data(
|
||||
agent_id="Senior Writer",
|
||||
trained_data=task_evaluator().evaluate_training_data().model_dump(),
|
||||
),
|
||||
mock.call(),
|
||||
mock.call().load(),
|
||||
mock.call(),
|
||||
mock.call().load(),
|
||||
]
|
||||
)
|
||||
|
||||
kickoff.assert_has_calls(
|
||||
[mock.call(inputs={"topic": "AI"}), mock.call(inputs={"topic": "AI"})]
|
||||
)
|
||||
|
||||
|
||||
def test_crew_train_error():
|
||||
@@ -1036,3 +1081,32 @@ def test_crew_train_error():
|
||||
assert "train() missing 1 required positional argument: 'n_iterations'" in str(
|
||||
e
|
||||
)
|
||||
|
||||
|
||||
def test__setup_for_training():
|
||||
researcher.allow_delegation = True
|
||||
writer.allow_delegation = True
|
||||
agents = [researcher, writer]
|
||||
task = Task(
|
||||
description="Come up with a list of 5 interesting ideas to explore for an article",
|
||||
expected_output="5 bullet points with a paragraph for each idea.",
|
||||
)
|
||||
|
||||
crew = Crew(
|
||||
agents=agents,
|
||||
tasks=[task],
|
||||
)
|
||||
|
||||
assert crew._train is False
|
||||
assert task.human_input is False
|
||||
|
||||
for agent in agents:
|
||||
assert agent.allow_delegation is True
|
||||
|
||||
crew._setup_for_training()
|
||||
|
||||
assert crew._train is True
|
||||
assert task.human_input is True
|
||||
|
||||
for agent in agents:
|
||||
assert agent.allow_delegation is False
|
||||
|
||||
Reference in New Issue
Block a user