mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-02 05:38:12 +00:00
fix: respect custom trained_agents_data_file during inference
Agents always loaded from the hardcoded 'trained_agents_data.pkl' during inference, ignoring any custom filename supplied at training time via 'crewai train -f <custom>.pkl'. Changes: - Add 'trained_agents_data_file' field to Crew (defaults to 'trained_agents_data.pkl') so users can specify which file to load trained agent suggestions from during inference. - Update Agent._use_trained_data() to accept an optional filename parameter instead of always using the hardcoded constant. - Update apply_training_data() in agent/utils.py to propagate the crew's trained_agents_data_file to the agent. - Add tests for custom filename propagation at agent and crew levels. Closes #4905 Co-Authored-By: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1010,9 +1010,19 @@ class Agent(BaseAgent):
|
||||
|
||||
return task_prompt
|
||||
|
||||
def _use_trained_data(self, task_prompt: str) -> str:
|
||||
"""Use trained data for the agent task prompt to improve output."""
|
||||
if data := CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).load():
|
||||
def _use_trained_data(
|
||||
self, task_prompt: str, trained_agents_data_file: str | None = None
|
||||
) -> str:
|
||||
"""Use trained data for the agent task prompt to improve output.
|
||||
|
||||
Args:
|
||||
task_prompt: The task prompt to augment.
|
||||
trained_agents_data_file: Optional path to the trained agents data
|
||||
file. Falls back to the default ``TRAINED_AGENTS_DATA_FILE``
|
||||
when not provided.
|
||||
"""
|
||||
filename = trained_agents_data_file or TRAINED_AGENTS_DATA_FILE
|
||||
if data := CrewTrainingHandler(filename).load():
|
||||
if trained_data_output := data.get(self.role):
|
||||
task_prompt += (
|
||||
"\n\nYou MUST follow these instructions: \n - "
|
||||
|
||||
@@ -222,7 +222,13 @@ def apply_training_data(agent: Agent, task_prompt: str) -> str:
|
||||
"""
|
||||
if agent.crew and agent.crew._train:
|
||||
return agent._training_handler(task_prompt=task_prompt)
|
||||
return agent._use_trained_data(task_prompt=task_prompt)
|
||||
trained_agents_data_file = (
|
||||
agent.crew.trained_agents_data_file if agent.crew else None
|
||||
)
|
||||
return agent._use_trained_data(
|
||||
task_prompt=task_prompt,
|
||||
trained_agents_data_file=trained_agents_data_file,
|
||||
)
|
||||
|
||||
|
||||
def process_tool_results(agent: Agent, result: Any) -> Any:
|
||||
|
||||
@@ -96,7 +96,11 @@ from crewai.tools.agent_tools.read_file_tool import ReadFileTool
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.types.streaming import CrewStreamingOutput
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE
|
||||
from crewai.utilities.constants import (
|
||||
NOT_SPECIFIED,
|
||||
TRAINED_AGENTS_DATA_FILE,
|
||||
TRAINING_DATA_FILE,
|
||||
)
|
||||
from crewai.utilities.crew.models import CrewContext
|
||||
from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
|
||||
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
|
||||
@@ -303,6 +307,14 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default=None,
|
||||
description="Whether to enable tracing for the crew. True=always enable, False=always disable, None=check environment/user settings.",
|
||||
)
|
||||
trained_agents_data_file: str = Field(
|
||||
default=TRAINED_AGENTS_DATA_FILE,
|
||||
description=(
|
||||
"Path to the file containing trained agent suggestions. "
|
||||
"Defaults to 'trained_agents_data.pkl'. Set this to match the "
|
||||
"custom filename used during training (e.g., via `crewai train -f`)."
|
||||
),
|
||||
)
|
||||
|
||||
@field_validator("id", mode="before")
|
||||
@classmethod
|
||||
|
||||
@@ -1064,6 +1064,59 @@ def test_agent_use_trained_data(crew_training_handler):
|
||||
)
|
||||
|
||||
|
||||
@patch("crewai.agent.core.CrewTrainingHandler")
|
||||
def test_agent_use_trained_data_with_custom_filename(crew_training_handler):
|
||||
"""Test that _use_trained_data respects a custom filename when provided."""
|
||||
task_prompt = "What is 1 + 1?"
|
||||
agent = Agent(
|
||||
role="researcher",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
verbose=True,
|
||||
)
|
||||
crew_training_handler.return_value.load.return_value = {
|
||||
agent.role: {
|
||||
"suggestions": [
|
||||
"The result of the math operation must be right.",
|
||||
"Result must be better than 1.",
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
custom_filename = "my_custom_trained.pkl"
|
||||
result = agent._use_trained_data(
|
||||
task_prompt=task_prompt, trained_agents_data_file=custom_filename
|
||||
)
|
||||
|
||||
assert (
|
||||
result == "What is 1 + 1?\n\nYou MUST follow these instructions: \n"
|
||||
" - The result of the math operation must be right.\n - Result must be better than 1."
|
||||
)
|
||||
crew_training_handler.assert_has_calls(
|
||||
[mock.call(custom_filename), mock.call().load()]
|
||||
)
|
||||
|
||||
|
||||
@patch("crewai.agent.core.CrewTrainingHandler")
|
||||
def test_agent_use_trained_data_defaults_without_custom_filename(crew_training_handler):
|
||||
"""Test that _use_trained_data falls back to the default file when no custom filename is given."""
|
||||
task_prompt = "What is 1 + 1?"
|
||||
agent = Agent(
|
||||
role="researcher",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
verbose=True,
|
||||
)
|
||||
crew_training_handler.return_value.load.return_value = {}
|
||||
|
||||
result = agent._use_trained_data(task_prompt=task_prompt)
|
||||
|
||||
assert result == task_prompt
|
||||
crew_training_handler.assert_has_calls(
|
||||
[mock.call("trained_agents_data.pkl"), mock.call().load()]
|
||||
)
|
||||
|
||||
|
||||
def test_agent_max_retry_limit():
|
||||
agent = Agent(
|
||||
role="test role",
|
||||
|
||||
@@ -2974,6 +2974,75 @@ def test__setup_for_training(researcher, writer):
|
||||
assert agent.allow_delegation is False
|
||||
|
||||
|
||||
def test_crew_trained_agents_data_file_defaults(researcher, writer):
|
||||
"""Test that Crew.trained_agents_data_file defaults to 'trained_agents_data.pkl'."""
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
agent=researcher,
|
||||
)
|
||||
crew = Crew(agents=[researcher, writer], tasks=[task])
|
||||
assert crew.trained_agents_data_file == "trained_agents_data.pkl"
|
||||
|
||||
|
||||
def test_crew_trained_agents_data_file_custom(researcher, writer):
|
||||
"""Test that Crew.trained_agents_data_file can be set to a custom value."""
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
agent=researcher,
|
||||
)
|
||||
crew = Crew(
|
||||
agents=[researcher, writer],
|
||||
tasks=[task],
|
||||
trained_agents_data_file="my_custom_trained.pkl",
|
||||
)
|
||||
assert crew.trained_agents_data_file == "my_custom_trained.pkl"
|
||||
|
||||
|
||||
@patch("crewai.agent.core.CrewTrainingHandler")
|
||||
def test_apply_training_data_uses_crew_custom_filename(mock_handler, researcher):
|
||||
"""Test that apply_training_data propagates the crew's trained_agents_data_file."""
|
||||
from crewai.agent.utils import apply_training_data
|
||||
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
agent=researcher,
|
||||
)
|
||||
crew = Crew(
|
||||
agents=[researcher],
|
||||
tasks=[task],
|
||||
trained_agents_data_file="my_custom_trained.pkl",
|
||||
)
|
||||
researcher.crew = crew
|
||||
|
||||
mock_handler.return_value.load.return_value = {
|
||||
researcher.role: {
|
||||
"suggestions": ["Be concise."]
|
||||
}
|
||||
}
|
||||
|
||||
result = apply_training_data(researcher, "Do the task")
|
||||
|
||||
mock_handler.assert_called_with("my_custom_trained.pkl")
|
||||
assert "Be concise." in result
|
||||
|
||||
|
||||
@patch("crewai.agent.core.CrewTrainingHandler")
|
||||
def test_apply_training_data_uses_default_when_no_crew(mock_handler, researcher):
|
||||
"""Test that apply_training_data falls back to the default file when agent has no crew."""
|
||||
from crewai.agent.utils import apply_training_data
|
||||
|
||||
researcher.crew = None
|
||||
mock_handler.return_value.load.return_value = {}
|
||||
|
||||
result = apply_training_data(researcher, "Do the task")
|
||||
|
||||
mock_handler.assert_called_with("trained_agents_data.pkl")
|
||||
assert result == "Do the task"
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_replay_feature(researcher, writer):
|
||||
list_ideas = Task(
|
||||
|
||||
Reference in New Issue
Block a user