mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-14 15:02:37 +00:00
Compare commits
2 Commits
1.14.2a3
...
devin/1773
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d79c4a62a1 | ||
|
|
79013a6dc2 |
@@ -1161,9 +1161,19 @@ class Agent(BaseAgent):
|
||||
|
||||
return task_prompt
|
||||
|
||||
def _use_trained_data(self, task_prompt: str) -> str:
|
||||
"""Use trained data for the agent task prompt to improve output."""
|
||||
if data := CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).load():
|
||||
def _use_trained_data(
|
||||
self, task_prompt: str, trained_agents_data_file: str | None = None
|
||||
) -> str:
|
||||
"""Use trained data for the agent task prompt to improve output.
|
||||
|
||||
Args:
|
||||
task_prompt: The task prompt to augment.
|
||||
trained_agents_data_file: Optional path to the trained agents data
|
||||
file. Falls back to the default ``TRAINED_AGENTS_DATA_FILE``
|
||||
when not provided.
|
||||
"""
|
||||
filename = trained_agents_data_file or TRAINED_AGENTS_DATA_FILE
|
||||
if data := CrewTrainingHandler(filename).load():
|
||||
if trained_data_output := data.get(self.role):
|
||||
task_prompt += (
|
||||
"\n\nYou MUST follow these instructions: \n - "
|
||||
|
||||
@@ -250,7 +250,13 @@ def apply_training_data(agent: Agent, task_prompt: str) -> str:
|
||||
"""
|
||||
if agent.crew and not isinstance(agent.crew, str) and agent.crew._train:
|
||||
return agent._training_handler(task_prompt=task_prompt)
|
||||
return agent._use_trained_data(task_prompt=task_prompt)
|
||||
trained_agents_data_file = (
|
||||
agent.crew.trained_agents_data_file if agent.crew else None
|
||||
)
|
||||
return agent._use_trained_data(
|
||||
task_prompt=task_prompt,
|
||||
trained_agents_data_file=trained_agents_data_file,
|
||||
)
|
||||
|
||||
|
||||
def process_tool_results(agent: Agent, result: Any) -> Any:
|
||||
|
||||
@@ -117,7 +117,11 @@ from crewai.tools.base_tool import BaseTool
|
||||
from crewai.types.callback import SerializableCallable
|
||||
from crewai.types.streaming import CrewStreamingOutput
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE
|
||||
from crewai.utilities.constants import (
|
||||
NOT_SPECIFIED,
|
||||
TRAINED_AGENTS_DATA_FILE,
|
||||
TRAINING_DATA_FILE,
|
||||
)
|
||||
from crewai.utilities.crew.models import CrewContext
|
||||
from crewai.utilities.env import get_env_context
|
||||
from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
|
||||
@@ -361,6 +365,14 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default=None,
|
||||
description="Whether to enable tracing for the crew. True=always enable, False=always disable, None=check environment/user settings.",
|
||||
)
|
||||
trained_agents_data_file: str = Field(
|
||||
default=TRAINED_AGENTS_DATA_FILE,
|
||||
description=(
|
||||
"Path to the file containing trained agent suggestions. "
|
||||
"Defaults to 'trained_agents_data.pkl'. Set this to match the "
|
||||
"custom filename used during training (e.g., via `crewai train -f`)."
|
||||
),
|
||||
)
|
||||
|
||||
execution_context: ExecutionContext | None = Field(default=None)
|
||||
checkpoint_inputs: dict[str, Any] | None = Field(default=None)
|
||||
|
||||
@@ -1064,6 +1064,59 @@ def test_agent_use_trained_data(crew_training_handler):
|
||||
)
|
||||
|
||||
|
||||
@patch("crewai.agent.core.CrewTrainingHandler")
|
||||
def test_agent_use_trained_data_with_custom_filename(crew_training_handler):
|
||||
"""Test that _use_trained_data respects a custom filename when provided."""
|
||||
task_prompt = "What is 1 + 1?"
|
||||
agent = Agent(
|
||||
role="researcher",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
verbose=True,
|
||||
)
|
||||
crew_training_handler.return_value.load.return_value = {
|
||||
agent.role: {
|
||||
"suggestions": [
|
||||
"The result of the math operation must be right.",
|
||||
"Result must be better than 1.",
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
custom_filename = "my_custom_trained.pkl"
|
||||
result = agent._use_trained_data(
|
||||
task_prompt=task_prompt, trained_agents_data_file=custom_filename
|
||||
)
|
||||
|
||||
assert (
|
||||
result == "What is 1 + 1?\n\nYou MUST follow these instructions: \n"
|
||||
" - The result of the math operation must be right.\n - Result must be better than 1."
|
||||
)
|
||||
crew_training_handler.assert_has_calls(
|
||||
[mock.call(custom_filename), mock.call().load()]
|
||||
)
|
||||
|
||||
|
||||
@patch("crewai.agent.core.CrewTrainingHandler")
|
||||
def test_agent_use_trained_data_defaults_without_custom_filename(crew_training_handler):
|
||||
"""Test that _use_trained_data falls back to the default file when no custom filename is given."""
|
||||
task_prompt = "What is 1 + 1?"
|
||||
agent = Agent(
|
||||
role="researcher",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
verbose=True,
|
||||
)
|
||||
crew_training_handler.return_value.load.return_value = {}
|
||||
|
||||
result = agent._use_trained_data(task_prompt=task_prompt)
|
||||
|
||||
assert result == task_prompt
|
||||
crew_training_handler.assert_has_calls(
|
||||
[mock.call("trained_agents_data.pkl"), mock.call().load()]
|
||||
)
|
||||
|
||||
|
||||
def test_agent_max_retry_limit():
|
||||
agent = Agent(
|
||||
role="test role",
|
||||
|
||||
@@ -2971,6 +2971,75 @@ def test__setup_for_training(researcher, writer):
|
||||
assert agent.allow_delegation is False
|
||||
|
||||
|
||||
def test_crew_trained_agents_data_file_defaults(researcher, writer):
|
||||
"""Test that Crew.trained_agents_data_file defaults to 'trained_agents_data.pkl'."""
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
agent=researcher,
|
||||
)
|
||||
crew = Crew(agents=[researcher, writer], tasks=[task])
|
||||
assert crew.trained_agents_data_file == "trained_agents_data.pkl"
|
||||
|
||||
|
||||
def test_crew_trained_agents_data_file_custom(researcher, writer):
|
||||
"""Test that Crew.trained_agents_data_file can be set to a custom value."""
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
agent=researcher,
|
||||
)
|
||||
crew = Crew(
|
||||
agents=[researcher, writer],
|
||||
tasks=[task],
|
||||
trained_agents_data_file="my_custom_trained.pkl",
|
||||
)
|
||||
assert crew.trained_agents_data_file == "my_custom_trained.pkl"
|
||||
|
||||
|
||||
@patch("crewai.agent.core.CrewTrainingHandler")
|
||||
def test_apply_training_data_uses_crew_custom_filename(mock_handler, researcher):
|
||||
"""Test that apply_training_data propagates the crew's trained_agents_data_file."""
|
||||
from crewai.agent.utils import apply_training_data
|
||||
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
agent=researcher,
|
||||
)
|
||||
crew = Crew(
|
||||
agents=[researcher],
|
||||
tasks=[task],
|
||||
trained_agents_data_file="my_custom_trained.pkl",
|
||||
)
|
||||
researcher.crew = crew
|
||||
|
||||
mock_handler.return_value.load.return_value = {
|
||||
researcher.role: {
|
||||
"suggestions": ["Be concise."]
|
||||
}
|
||||
}
|
||||
|
||||
result = apply_training_data(researcher, "Do the task")
|
||||
|
||||
mock_handler.assert_called_with("my_custom_trained.pkl")
|
||||
assert "Be concise." in result
|
||||
|
||||
|
||||
@patch("crewai.agent.core.CrewTrainingHandler")
|
||||
def test_apply_training_data_uses_default_when_no_crew(mock_handler, researcher):
|
||||
"""Test that apply_training_data falls back to the default file when agent has no crew."""
|
||||
from crewai.agent.utils import apply_training_data
|
||||
|
||||
researcher.crew = None
|
||||
mock_handler.return_value.load.return_value = {}
|
||||
|
||||
result = apply_training_data(researcher, "Do the task")
|
||||
|
||||
mock_handler.assert_called_with("trained_agents_data.pkl")
|
||||
assert result == "Do the task"
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_replay_feature(researcher, writer):
|
||||
list_ideas = Task(
|
||||
|
||||
Reference in New Issue
Block a user