Add crew trained agents file support (#6012)

* Add crew trained agents file support

* Add crew trained agents file support
This commit is contained in:
Lorenze Jay
2026-06-02 09:38:34 -07:00
committed by GitHub
parent 383ae66b55
commit a9cb7867bb
5 changed files with 92 additions and 3 deletions

View File

@@ -187,7 +187,7 @@ flowchart TD
- **Filename Requirement:** Ensure that the filename ends with `.pkl`. The code will raise a `ValueError` if this condition is not met.
- **Error Handling:** The code handles subprocess errors and unexpected exceptions, providing error messages to the user.
- Trained guidance is applied at prompt time; it does not modify your Python/YAML agent configuration.
- Agents automatically load trained suggestions from a file named `trained_agents_data.pkl` located in the current working directory. If you trained to a different filename, either rename it to `trained_agents_data.pkl` before running, or adjust the loader in code.
- Agents automatically load trained suggestions from a file named `trained_agents_data.pkl` located in the current working directory. If you trained to a different filename, pass that path with `Crew(trained_agents_file="my_custom_trained.pkl")`, set `CREWAI_TRAINED_AGENTS_FILE`, or use `crewai run -f my_custom_trained.pkl`.
- You can change the output filename when calling `crewai train` with `-f/--filename`. Absolute paths are supported if you want to save outside the CWD.
It is important to note that the training process may take some time, depending on the complexity of your agents and will also require your feedback on each iteration.

View File

@@ -1219,9 +1219,17 @@ class Agent(BaseAgent):
def _use_trained_data(self, task_prompt: str) -> str:
"""Use trained data for the agent task prompt to improve output."""
trained_file = os.getenv(
CREWAI_TRAINED_AGENTS_FILE_ENV, TRAINED_AGENTS_DATA_FILE
crew_trained_agents_file = (
getattr(self.crew, "trained_agents_file", None)
if self.crew and not isinstance(self.crew, str)
else None
)
trained_file = (
os.fspath(crew_trained_agents_file)
if crew_trained_agents_file
else os.getenv(CREWAI_TRAINED_AGENTS_FILE_ENV, TRAINED_AGENTS_DATA_FILE)
)
if data := CrewTrainingHandler(trained_file).load():
if trained_data_output := data.get(self.role):
task_prompt += (

View File

@@ -179,6 +179,7 @@ class Crew(FlowTrackable, BaseModel):
max_rpm: Maximum number of requests per minute for the crew execution to
be respected.
prompt_file: Path to the prompt json file to be used for the crew.
trained_agents_file: Path to trained agent suggestions loaded during inference.
id: A unique identifier for the crew instance.
task_callback: Callback to be executed after each task for every agents
execution.
@@ -303,6 +304,13 @@ class Crew(FlowTrackable, BaseModel):
default=None,
description="Path to the prompt json file to be used for the crew.",
)
trained_agents_file: str | Path | None = Field(
default=None,
description=(
"Path to a trained-agents pickle produced by train(). "
"When set, agents load suggestions from this file during inference."
),
)
output_log_file: bool | str | None = Field(
default=None,
description="Path to the log file to be saved",

View File

@@ -1067,6 +1067,62 @@ def test_agent_use_trained_data_honors_env_var(crew_training_handler, monkeypatc
)
@patch("crewai.agent.core.CrewTrainingHandler")
def test_agent_use_trained_data_prefers_crew_trained_agents_file(
crew_training_handler, monkeypatch
):
monkeypatch.setenv("CREWAI_TRAINED_AGENTS_FILE", "env_trained.pkl")
agent = Agent(
role="researcher",
goal="test goal",
backstory="test backstory",
)
task = Task(
description="Research the topic",
expected_output="A short report",
agent=agent,
)
crew = Crew(agents=[agent], tasks=[task], trained_agents_file="crew_trained.pkl")
agent.crew = crew
crew_training_handler.return_value.load.return_value = {}
agent._use_trained_data(task_prompt="What is 1 + 1?")
crew_training_handler.assert_has_calls(
[mock.call("crew_trained.pkl"), mock.call().load()]
)
@patch("crewai.agent.core.CrewTrainingHandler")
def test_agent_use_trained_data_accepts_crew_trained_agents_file_path(
crew_training_handler, tmp_path
):
agent = Agent(
role="researcher",
goal="test goal",
backstory="test backstory",
)
task = Task(
description="Research the topic",
expected_output="A short report",
agent=agent,
)
trained_agents_file = tmp_path / "crew_trained.pkl"
crew = Crew(
agents=[agent],
tasks=[task],
trained_agents_file=trained_agents_file,
)
agent.crew = crew
crew_training_handler.return_value.load.return_value = {}
agent._use_trained_data(task_prompt="What is 1 + 1?")
crew_training_handler.assert_has_calls(
[mock.call(str(trained_agents_file)), mock.call().load()]
)
def test_agent_use_trained_data_skips_load_when_file_missing(tmp_path, monkeypatch):
monkeypatch.setenv(
"CREWAI_TRAINED_AGENTS_FILE", str(tmp_path / "does_not_exist.pkl")

View File

@@ -3010,6 +3010,23 @@ def test__setup_for_training(researcher, writer):
assert agent.allow_delegation is False
def test_crew_trained_agents_file_is_preserved_on_copy(researcher):
task = Task(
description="Come up with a list of 5 interesting ideas to explore for an article",
expected_output="5 bullet points with a paragraph for each idea.",
agent=researcher,
)
crew = Crew(
agents=[researcher],
tasks=[task],
trained_agents_file="custom_trained_agents.pkl",
)
cloned_crew = crew.copy()
assert cloned_crew.trained_agents_file == "custom_trained_agents.pkl"
@pytest.mark.vcr()
def test_replay_feature(researcher, writer):
list_ideas = Task(