Compare commits

...

2 Commits

Author SHA1 Message Date
Lorenze Jay
d79c4a62a1 Merge branch 'main' into devin/1773669058-fix-trained-agents-data-file 2026-04-09 13:19:44 -07:00
João
79013a6dc2 fix: respect custom trained_agents_data_file during inference
Agents always loaded from the hardcoded 'trained_agents_data.pkl' during
inference, ignoring any custom filename supplied at training time via
'crewai train -f <custom>.pkl'.

Changes:
- Add 'trained_agents_data_file' field to Crew (defaults to
  'trained_agents_data.pkl') so users can specify which file to load
  trained agent suggestions from during inference.
- Update Agent._use_trained_data() to accept an optional filename
  parameter instead of always using the hardcoded constant.
- Update apply_training_data() in agent/utils.py to propagate the
  crew's trained_agents_data_file to the agent.
- Add tests for custom filename propagation at agent and crew levels.

Closes #4905

Co-Authored-By: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
2026-03-16 13:58:45 +00:00
5 changed files with 155 additions and 5 deletions

View File

@@ -1161,9 +1161,19 @@ class Agent(BaseAgent):
return task_prompt
def _use_trained_data(self, task_prompt: str) -> str:
"""Use trained data for the agent task prompt to improve output."""
if data := CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).load():
def _use_trained_data(
self, task_prompt: str, trained_agents_data_file: str | None = None
) -> str:
"""Use trained data for the agent task prompt to improve output.
Args:
task_prompt: The task prompt to augment.
trained_agents_data_file: Optional path to the trained agents data
file. Falls back to the default ``TRAINED_AGENTS_DATA_FILE``
when not provided.
"""
filename = trained_agents_data_file or TRAINED_AGENTS_DATA_FILE
if data := CrewTrainingHandler(filename).load():
if trained_data_output := data.get(self.role):
task_prompt += (
"\n\nYou MUST follow these instructions: \n - "

View File

@@ -250,7 +250,13 @@ def apply_training_data(agent: Agent, task_prompt: str) -> str:
"""
if agent.crew and not isinstance(agent.crew, str) and agent.crew._train:
return agent._training_handler(task_prompt=task_prompt)
return agent._use_trained_data(task_prompt=task_prompt)
trained_agents_data_file = (
agent.crew.trained_agents_data_file if agent.crew else None
)
return agent._use_trained_data(
task_prompt=task_prompt,
trained_agents_data_file=trained_agents_data_file,
)
def process_tool_results(agent: Agent, result: Any) -> Any:

View File

@@ -117,7 +117,11 @@ from crewai.tools.base_tool import BaseTool
from crewai.types.callback import SerializableCallable
from crewai.types.streaming import CrewStreamingOutput
from crewai.types.usage_metrics import UsageMetrics
from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE
from crewai.utilities.constants import (
NOT_SPECIFIED,
TRAINED_AGENTS_DATA_FILE,
TRAINING_DATA_FILE,
)
from crewai.utilities.crew.models import CrewContext
from crewai.utilities.env import get_env_context
from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
@@ -361,6 +365,14 @@ class Crew(FlowTrackable, BaseModel):
default=None,
description="Whether to enable tracing for the crew. True=always enable, False=always disable, None=check environment/user settings.",
)
trained_agents_data_file: str = Field(
default=TRAINED_AGENTS_DATA_FILE,
description=(
"Path to the file containing trained agent suggestions. "
"Defaults to 'trained_agents_data.pkl'. Set this to match the "
"custom filename used during training (e.g., via `crewai train -f`)."
),
)
execution_context: ExecutionContext | None = Field(default=None)
checkpoint_inputs: dict[str, Any] | None = Field(default=None)

View File

@@ -1064,6 +1064,59 @@ def test_agent_use_trained_data(crew_training_handler):
)
@patch("crewai.agent.core.CrewTrainingHandler")
def test_agent_use_trained_data_with_custom_filename(crew_training_handler):
    """Check that an explicitly passed file name reaches the training handler.

    ``CrewTrainingHandler`` is replaced by a mock whose ``load()`` result
    stands in for the contents of the custom trained-data file.
    """
    custom_filename = "my_custom_trained.pkl"
    agent = Agent(
        role="researcher",
        goal="test goal",
        backstory="test backstory",
        verbose=True,
    )
    suggestions = [
        "The result of the math operation must be right.",
        "Result must be better than 1.",
    ]
    crew_training_handler.return_value.load.return_value = {
        agent.role: {"suggestions": suggestions}
    }

    augmented_prompt = agent._use_trained_data(
        task_prompt="What is 1 + 1?",
        trained_agents_data_file=custom_filename,
    )

    expected_prompt = (
        "What is 1 + 1?\n\nYou MUST follow these instructions: \n"
        " - The result of the math operation must be right.\n - Result must be better than 1."
    )
    assert augmented_prompt == expected_prompt
    # The handler must be built from the custom name and then asked to load.
    crew_training_handler.assert_has_calls(
        [mock.call(custom_filename), mock.call().load()]
    )
@patch("crewai.agent.core.CrewTrainingHandler")
def test_agent_use_trained_data_defaults_without_custom_filename(crew_training_handler):
    """Check the fallback to the default file when no file name is passed.

    The mocked handler loads an empty mapping, so the prompt is expected to
    come back unchanged.
    """
    original_prompt = "What is 1 + 1?"
    agent = Agent(
        role="researcher",
        goal="test goal",
        backstory="test backstory",
        verbose=True,
    )
    crew_training_handler.return_value.load.return_value = {}

    returned_prompt = agent._use_trained_data(task_prompt=original_prompt)

    assert returned_prompt == original_prompt
    # With no override, the handler is created for the default file name.
    expected_calls = [mock.call("trained_agents_data.pkl"), mock.call().load()]
    crew_training_handler.assert_has_calls(expected_calls)
def test_agent_max_retry_limit():
agent = Agent(
role="test role",

View File

@@ -2971,6 +2971,75 @@ def test__setup_for_training(researcher, writer):
assert agent.allow_delegation is False
def test_crew_trained_agents_data_file_defaults(researcher, writer):
    """A Crew built without the option reports 'trained_agents_data.pkl'."""
    only_task = Task(
        description="Test task",
        expected_output="Test output",
        agent=researcher,
    )

    default_crew = Crew(agents=[researcher, writer], tasks=[only_task])

    assert default_crew.trained_agents_data_file == "trained_agents_data.pkl"
def test_crew_trained_agents_data_file_custom(researcher, writer):
    """A Crew keeps the trained_agents_data_file value it was constructed with."""
    only_task = Task(
        description="Test task",
        expected_output="Test output",
        agent=researcher,
    )
    crew_kwargs = {
        "agents": [researcher, writer],
        "tasks": [only_task],
        "trained_agents_data_file": "my_custom_trained.pkl",
    }

    custom_crew = Crew(**crew_kwargs)

    assert custom_crew.trained_agents_data_file == "my_custom_trained.pkl"
@patch("crewai.agent.core.CrewTrainingHandler")
def test_apply_training_data_uses_crew_custom_filename(mock_handler, researcher):
    """Check that apply_training_data hands the crew's file name to the handler."""
    from crewai.agent.utils import apply_training_data

    only_task = Task(
        description="Test task",
        expected_output="Test output",
        agent=researcher,
    )
    custom_crew = Crew(
        agents=[researcher],
        tasks=[only_task],
        trained_agents_data_file="my_custom_trained.pkl",
    )
    # Attach the crew explicitly so the agent resolves the file name from it.
    researcher.crew = custom_crew
    trained_entry = {"suggestions": ["Be concise."]}
    mock_handler.return_value.load.return_value = {researcher.role: trained_entry}

    augmented_prompt = apply_training_data(researcher, "Do the task")

    mock_handler.assert_called_with("my_custom_trained.pkl")
    assert "Be concise." in augmented_prompt
@patch("crewai.agent.core.CrewTrainingHandler")
def test_apply_training_data_uses_default_when_no_crew(mock_handler, researcher):
    """Check that an agent without a crew uses the default trained-data file."""
    from crewai.agent.utils import apply_training_data

    original_prompt = "Do the task"
    researcher.crew = None
    # An empty mapping from the mocked handler means nothing is appended.
    mock_handler.return_value.load.return_value = {}

    returned_prompt = apply_training_data(researcher, original_prompt)

    mock_handler.assert_called_with("trained_agents_data.pkl")
    assert returned_prompt == "Do the task"
@pytest.mark.vcr()
def test_replay_feature(researcher, writer):
list_ideas = Task(