Address PR review comments: improve validation, error handling, and add tests

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-04-29 21:30:04 +00:00
parent 2f66aa0efc
commit fa39ce9db2
5 changed files with 115 additions and 14 deletions

View File

@@ -118,7 +118,7 @@ class Agent(BaseAgent):
default=None,
description="Knowledge context for the agent.",
)
trained_data_file: str = Field(
trained_data_file: Optional[str] = Field(
default=TRAINED_AGENTS_DATA_FILE,
description="Path to the trained data file to use for task prompts.",
)
@@ -501,13 +501,24 @@ class Agent(BaseAgent):
return task_prompt
def _use_trained_data(self, task_prompt: str) -> str:
"""Use trained data for the agent task prompt to improve output."""
if data := CrewTrainingHandler(self.trained_data_file).load():
"""
Use trained data from a specified file for the agent task prompt.
Uses the 'trained_data_file' attribute as the source of training instructions.
Args:
task_prompt: The original task prompt to enhance.
Returns:
Enhanced task prompt with training instructions if available.
"""
if self.trained_data_file and (data := CrewTrainingHandler(self.trained_data_file).load()):
if trained_data_output := data.get(self.role):
task_prompt += (
"\n\nYou MUST follow these instructions: \n - "
+ "\n - ".join(trained_data_output["suggestions"])
)
if "suggestions" in trained_data_output:
task_prompt += (
"\n\nYou MUST follow these instructions: \n - "
+ "\n - ".join(trained_data_output["suggestions"])
)
return task_prompt
def _render_text_description(self, tools: List[Any]) -> str:

View File

@@ -203,15 +203,15 @@ def install(context):
@crewai.command()
@click.option(
"-f",
"--filename",
"--trained-data-file",
type=str,
default="trained_agents_data.pkl",
help="Path to a trained data file to use",
help="Path to a trained data file to use for agent task prompts",
)
def run(filename: str):
def run(trained_data_file: str):
"""Run the Crew."""
click.echo(f"Running the Crew with trained data from {filename}")
run_crew(trained_data_file=filename)
click.echo(f"Running the Crew with agent training data file: {trained_data_file}")
run_crew(trained_data_file=trained_data_file)
@crewai.command()

View File

@@ -1,3 +1,4 @@
import os
import subprocess
from enum import Enum
from typing import List, Optional
@@ -61,6 +62,20 @@ def execute_command(crew_type: CrewType, trained_data_file: Optional[str] = None
command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"]
if trained_data_file and crew_type == CrewType.STANDARD:
if not trained_data_file.endswith('.pkl'):
click.secho(
f"Error: Trained data file '{trained_data_file}' must have a .pkl extension.",
fg="red",
)
return
if not os.path.exists(trained_data_file):
click.secho(
f"Error: Trained data file '{trained_data_file}' does not exist.",
fg="red",
)
return
command.extend(["--trained-data-file", trained_data_file])
try:

View File

@@ -1,4 +1,5 @@
#!/usr/bin/env python
import argparse
import sys
import warnings
@@ -17,14 +18,23 @@ def run():
"""
Run the crew.
"""
parser = argparse.ArgumentParser(description="Run the crew")
parser.add_argument(
"--trained-data-file",
"-f",
type=str,
default="trained_agents_data.pkl",
help="Path to a trained data file to use for agent task prompts"
)
args, _ = parser.parse_known_args()
inputs = {
'topic': 'AI LLMs',
'current_year': str(datetime.now().year)
}
try:
filename = sys.argv[1] if len(sys.argv) > 1 else "trained_agents_data.pkl"
{{crew_name}}().crew(trained_data_file=filename).kickoff(inputs=inputs)
{{crew_name}}().crew(trained_data_file=args.trained_data_file).kickoff(inputs=inputs)
except Exception as e:
raise Exception(f"An error occurred while running the crew: {e}")

View File

@@ -1227,6 +1227,71 @@ def test_agent_use_custom_trained_data_file(crew_training_handler):
)
@patch("crewai.agent.CrewTrainingHandler")
def test_agent_with_none_trained_data_file(crew_training_handler):
task_prompt = "What is 1 + 1?"
agent = Agent(
role="researcher",
goal="test goal",
backstory="test backstory",
verbose=True,
trained_data_file=None
)
result = agent._use_trained_data(task_prompt=task_prompt)
assert result == task_prompt
crew_training_handler.assert_not_called()
@patch("crewai.agent.CrewTrainingHandler")
def test_agent_with_missing_role_in_trained_data(crew_training_handler):
task_prompt = "What is 1 + 1?"
agent = Agent(
role="researcher",
goal="test goal",
backstory="test backstory",
verbose=True,
trained_data_file="trained_agents_data.pkl"
)
crew_training_handler().load.return_value = {
"other_role": {
"suggestions": ["This should not be used."]
}
}
result = agent._use_trained_data(task_prompt=task_prompt)
assert result == task_prompt
crew_training_handler.assert_has_calls(
[mock.call(), mock.call("trained_agents_data.pkl"), mock.call().load()]
)
@patch("crewai.agent.CrewTrainingHandler")
def test_agent_with_missing_suggestions_in_trained_data(crew_training_handler):
task_prompt = "What is 1 + 1?"
agent = Agent(
role="researcher",
goal="test goal",
backstory="test backstory",
verbose=True,
trained_data_file="trained_agents_data.pkl"
)
crew_training_handler().load.return_value = {
"researcher": {
"other_key": ["This should not be used."]
}
}
result = agent._use_trained_data(task_prompt=task_prompt)
assert result == task_prompt
crew_training_handler.assert_has_calls(
[mock.call(), mock.call("trained_agents_data.pkl"), mock.call().load()]
)
def test_agent_max_retry_limit():
agent = Agent(
role="test role",