From 79013a6dc218dedb3ecb5b2a63335385c9f2206f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o?= Date: Mon, 16 Mar 2026 13:58:45 +0000 Subject: [PATCH] fix: respect custom trained_agents_data_file during inference Agents always loaded from the hardcoded 'trained_agents_data.pkl' during inference, ignoring any custom filename supplied at training time via 'crewai train -f FILENAME.pkl'. Changes: - Add 'trained_agents_data_file' field to Crew (defaults to 'trained_agents_data.pkl') so users can specify which file to load trained agent suggestions from during inference. - Update Agent._use_trained_data() to accept an optional filename parameter instead of always using the hardcoded constant. - Update apply_training_data() in agent/utils.py to propagate the crew's trained_agents_data_file to the agent. - Add tests for custom filename propagation at agent and crew levels. Closes #4905 Co-Authored-By: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> --- lib/crewai/src/crewai/agent/core.py | 16 +++++-- lib/crewai/src/crewai/agent/utils.py | 8 +++- lib/crewai/src/crewai/crew.py | 14 +++++- lib/crewai/tests/agents/test_agent.py | 53 ++++++++++++++++++++ lib/crewai/tests/test_crew.py | 69 +++++++++++++++++++++++++++ 5 files changed, 155 insertions(+), 5 deletions(-) diff --git a/lib/crewai/src/crewai/agent/core.py b/lib/crewai/src/crewai/agent/core.py index 8f3c80107..b99785d7a 100644 --- a/lib/crewai/src/crewai/agent/core.py +++ b/lib/crewai/src/crewai/agent/core.py @@ -1010,9 +1010,19 @@ class Agent(BaseAgent): return task_prompt - def _use_trained_data(self, task_prompt: str) -> str: - """Use trained data for the agent task prompt to improve output.""" - if data := CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).load(): + def _use_trained_data( + self, task_prompt: str, trained_agents_data_file: str | None = None + ) -> str: + """Use trained data for the agent task prompt to improve output. + + Args: + task_prompt: The task prompt to augment. 
+ trained_agents_data_file: Optional path to the trained agents data + file. Falls back to the default ``TRAINED_AGENTS_DATA_FILE`` + when not provided. + """ + filename = trained_agents_data_file or TRAINED_AGENTS_DATA_FILE + if data := CrewTrainingHandler(filename).load(): if trained_data_output := data.get(self.role): task_prompt += ( "\n\nYou MUST follow these instructions: \n - " diff --git a/lib/crewai/src/crewai/agent/utils.py b/lib/crewai/src/crewai/agent/utils.py index fc74db433..7d0b025cf 100644 --- a/lib/crewai/src/crewai/agent/utils.py +++ b/lib/crewai/src/crewai/agent/utils.py @@ -222,7 +222,13 @@ def apply_training_data(agent: Agent, task_prompt: str) -> str: """ if agent.crew and agent.crew._train: return agent._training_handler(task_prompt=task_prompt) - return agent._use_trained_data(task_prompt=task_prompt) + trained_agents_data_file = ( + agent.crew.trained_agents_data_file if agent.crew else None + ) + return agent._use_trained_data( + task_prompt=task_prompt, + trained_agents_data_file=trained_agents_data_file, + ) def process_tool_results(agent: Agent, result: Any) -> Any: diff --git a/lib/crewai/src/crewai/crew.py b/lib/crewai/src/crewai/crew.py index cdd371cbc..f86884006 100644 --- a/lib/crewai/src/crewai/crew.py +++ b/lib/crewai/src/crewai/crew.py @@ -96,7 +96,11 @@ from crewai.tools.agent_tools.read_file_tool import ReadFileTool from crewai.tools.base_tool import BaseTool from crewai.types.streaming import CrewStreamingOutput from crewai.types.usage_metrics import UsageMetrics -from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE +from crewai.utilities.constants import ( + NOT_SPECIFIED, + TRAINED_AGENTS_DATA_FILE, + TRAINING_DATA_FILE, +) from crewai.utilities.crew.models import CrewContext from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator from crewai.utilities.evaluators.task_evaluator import TaskEvaluator @@ -303,6 +307,14 @@ class Crew(FlowTrackable, BaseModel): default=None, 
description="Whether to enable tracing for the crew. True=always enable, False=always disable, None=check environment/user settings.", ) + trained_agents_data_file: str = Field( + default=TRAINED_AGENTS_DATA_FILE, + description=( + "Path to the file containing trained agent suggestions. " + "Defaults to 'trained_agents_data.pkl'. Set this to match the " + "custom filename used during training (e.g., via `crewai train -f`)." + ), + ) @field_validator("id", mode="before") @classmethod diff --git a/lib/crewai/tests/agents/test_agent.py b/lib/crewai/tests/agents/test_agent.py index a3aab28d6..78bfe0fb0 100644 --- a/lib/crewai/tests/agents/test_agent.py +++ b/lib/crewai/tests/agents/test_agent.py @@ -1064,6 +1064,59 @@ def test_agent_use_trained_data(crew_training_handler): ) +@patch("crewai.agent.core.CrewTrainingHandler") +def test_agent_use_trained_data_with_custom_filename(crew_training_handler): + """Test that _use_trained_data respects a custom filename when provided.""" + task_prompt = "What is 1 + 1?" + agent = Agent( + role="researcher", + goal="test goal", + backstory="test backstory", + verbose=True, + ) + crew_training_handler.return_value.load.return_value = { + agent.role: { + "suggestions": [ + "The result of the math operation must be right.", + "Result must be better than 1.", + ] + } + } + + custom_filename = "my_custom_trained.pkl" + result = agent._use_trained_data( + task_prompt=task_prompt, trained_agents_data_file=custom_filename + ) + + assert ( + result == "What is 1 + 1?\n\nYou MUST follow these instructions: \n" + " - The result of the math operation must be right.\n - Result must be better than 1." 
+ ) + crew_training_handler.assert_has_calls( + [mock.call(custom_filename), mock.call().load()] + ) + + +@patch("crewai.agent.core.CrewTrainingHandler") +def test_agent_use_trained_data_defaults_without_custom_filename(crew_training_handler): + """Test that _use_trained_data falls back to the default file when no custom filename is given.""" + task_prompt = "What is 1 + 1?" + agent = Agent( + role="researcher", + goal="test goal", + backstory="test backstory", + verbose=True, + ) + crew_training_handler.return_value.load.return_value = {} + + result = agent._use_trained_data(task_prompt=task_prompt) + + assert result == task_prompt + crew_training_handler.assert_has_calls( + [mock.call("trained_agents_data.pkl"), mock.call().load()] + ) + + def test_agent_max_retry_limit(): agent = Agent( role="test role", diff --git a/lib/crewai/tests/test_crew.py b/lib/crewai/tests/test_crew.py index f941a7965..b7c5bc692 100644 --- a/lib/crewai/tests/test_crew.py +++ b/lib/crewai/tests/test_crew.py @@ -2974,6 +2974,75 @@ def test__setup_for_training(researcher, writer): assert agent.allow_delegation is False +def test_crew_trained_agents_data_file_defaults(researcher, writer): + """Test that Crew.trained_agents_data_file defaults to 'trained_agents_data.pkl'.""" + task = Task( + description="Test task", + expected_output="Test output", + agent=researcher, + ) + crew = Crew(agents=[researcher, writer], tasks=[task]) + assert crew.trained_agents_data_file == "trained_agents_data.pkl" + + +def test_crew_trained_agents_data_file_custom(researcher, writer): + """Test that Crew.trained_agents_data_file can be set to a custom value.""" + task = Task( + description="Test task", + expected_output="Test output", + agent=researcher, + ) + crew = Crew( + agents=[researcher, writer], + tasks=[task], + trained_agents_data_file="my_custom_trained.pkl", + ) + assert crew.trained_agents_data_file == "my_custom_trained.pkl" + + +@patch("crewai.agent.core.CrewTrainingHandler") +def 
test_apply_training_data_uses_crew_custom_filename(mock_handler, researcher): + """Test that apply_training_data propagates the crew's trained_agents_data_file.""" + from crewai.agent.utils import apply_training_data + + task = Task( + description="Test task", + expected_output="Test output", + agent=researcher, + ) + crew = Crew( + agents=[researcher], + tasks=[task], + trained_agents_data_file="my_custom_trained.pkl", + ) + researcher.crew = crew + + mock_handler.return_value.load.return_value = { + researcher.role: { + "suggestions": ["Be concise."] + } + } + + result = apply_training_data(researcher, "Do the task") + + mock_handler.assert_called_with("my_custom_trained.pkl") + assert "Be concise." in result + + +@patch("crewai.agent.core.CrewTrainingHandler") +def test_apply_training_data_uses_default_when_no_crew(mock_handler, researcher): + """Test that apply_training_data falls back to the default file when agent has no crew.""" + from crewai.agent.utils import apply_training_data + + researcher.crew = None + mock_handler.return_value.load.return_value = {} + + result = apply_training_data(researcher, "Do the task") + + mock_handler.assert_called_with("trained_agents_data.pkl") + assert result == "Do the task" + + @pytest.mark.vcr() def test_replay_feature(researcher, writer): list_ideas = Task(