feat: Add Train feature for Crews (#686)

* feat: add training logic to agent and crew

* feat: add training logic to agent executor

* feat: add input parameter  to cli command

* feat: add utilities for the training logic

* feat: polish code, logic and add private variables

* feat: add docstring and type hinting to executor

* feat: add constant file, add constant to code

* feat: fix name of training handler function

* feat: remove unused var

* feat: change file handler file name

* feat: Add training handler file, class and change on the code

* feat: fix name error from file

* fix: change import to adapt to logic

* feat: add training handler test

* feat: add tests for file and training_handler

* feat: add test for task evaluator function

* feat: change text to fit in-screen

* feat: add test for train function

* feat: add test for agent training_handler function

* feat: add test for agent._use_trained_data
This commit is contained in:
Eduardo Chiarotti
2024-06-27 02:22:34 -03:00
committed by GitHub
parent 0594a7f9d8
commit 3573a61568
15 changed files with 564 additions and 45 deletions

View File

@@ -1,6 +1,8 @@
"""Test Agent creation and execution basic functionality."""
import json
from unittest import mock
from unittest.mock import patch
import pydantic_core
import pytest
@@ -1006,7 +1008,10 @@ def test_manager_agent_with_tools_raises_exception():
crew.kickoff()
def test_crew_train_success():
@patch("crewai.crew.Crew.kickoff")
@patch("crewai.crew.CrewTrainingHandler")
@patch("crewai.crew.TaskEvaluator")
def test_crew_train_success(task_evaluator, crew_training_handler, kickoff):
task = Task(
description="Come up with a list of 5 interesting ideas to explore for an article, then write one amazing paragraph highlight for each idea that showcases how good an article about this topic could be. Return the list of ideas with their paragraph and your notes.",
expected_output="5 bullet points with a paragraph for each idea.",
@@ -1016,8 +1021,48 @@ def test_crew_train_success():
agents=[researcher, writer],
tasks=[task],
)
crew.train(n_iterations=2, inputs={"topic": "AI"})
task_evaluator.assert_has_calls(
[
mock.call(researcher),
mock.call().evaluate_training_data(
training_data=crew_training_handler().load(),
agent_id=str(researcher.id),
),
mock.call().evaluate_training_data().model_dump(),
mock.call(writer),
mock.call().evaluate_training_data(
training_data=crew_training_handler().load(),
agent_id=str(writer.id),
),
mock.call().evaluate_training_data().model_dump(),
]
)
crew.train(n_iterations=2)
crew_training_handler.assert_has_calls(
[
mock.call("training_data.pkl"),
mock.call().load(),
mock.call("trained_agents_data.pkl"),
mock.call().save_trained_data(
agent_id="Researcher",
trained_data=task_evaluator().evaluate_training_data().model_dump(),
),
mock.call("trained_agents_data.pkl"),
mock.call().save_trained_data(
agent_id="Senior Writer",
trained_data=task_evaluator().evaluate_training_data().model_dump(),
),
mock.call(),
mock.call().load(),
mock.call(),
mock.call().load(),
]
)
kickoff.assert_has_calls(
[mock.call(inputs={"topic": "AI"}), mock.call(inputs={"topic": "AI"})]
)
def test_crew_train_error():
@@ -1036,3 +1081,32 @@ def test_crew_train_error():
assert "train() missing 1 required positional argument: 'n_iterations'" in str(
e
)
def test__setup_for_training():
researcher.allow_delegation = True
writer.allow_delegation = True
agents = [researcher, writer]
task = Task(
description="Come up with a list of 5 interesting ideas to explore for an article",
expected_output="5 bullet points with a paragraph for each idea.",
)
crew = Crew(
agents=agents,
tasks=[task],
)
assert crew._train is False
assert task.human_input is False
for agent in agents:
assert agent.allow_delegation is True
crew._setup_for_training()
assert crew._train is True
assert task.human_input is True
for agent in agents:
assert agent.allow_delegation is False