Address PR review comments: improve validation, error handling, and add tests

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-04-29 21:30:04 +00:00
parent 2f66aa0efc
commit fa39ce9db2
5 changed files with 115 additions and 14 deletions

View File

@@ -118,7 +118,7 @@ class Agent(BaseAgent):
default=None, default=None,
description="Knowledge context for the agent.", description="Knowledge context for the agent.",
) )
trained_data_file: str = Field( trained_data_file: Optional[str] = Field(
default=TRAINED_AGENTS_DATA_FILE, default=TRAINED_AGENTS_DATA_FILE,
description="Path to the trained data file to use for task prompts.", description="Path to the trained data file to use for task prompts.",
) )
@@ -501,13 +501,24 @@ class Agent(BaseAgent):
return task_prompt return task_prompt
def _use_trained_data(self, task_prompt: str) -> str: def _use_trained_data(self, task_prompt: str) -> str:
"""Use trained data for the agent task prompt to improve output.""" """
if data := CrewTrainingHandler(self.trained_data_file).load(): Use trained data from a specified file for the agent task prompt.
Uses the 'trained_data_file' attribute as the source of training instructions.
Args:
task_prompt: The original task prompt to enhance.
Returns:
Enhanced task prompt with training instructions if available.
"""
if self.trained_data_file and (data := CrewTrainingHandler(self.trained_data_file).load()):
if trained_data_output := data.get(self.role): if trained_data_output := data.get(self.role):
task_prompt += ( if "suggestions" in trained_data_output:
"\n\nYou MUST follow these instructions: \n - " task_prompt += (
+ "\n - ".join(trained_data_output["suggestions"]) "\n\nYou MUST follow these instructions: \n - "
) + "\n - ".join(trained_data_output["suggestions"])
)
return task_prompt return task_prompt
def _render_text_description(self, tools: List[Any]) -> str: def _render_text_description(self, tools: List[Any]) -> str:

View File

@@ -203,15 +203,15 @@ def install(context):
@crewai.command() @crewai.command()
@click.option( @click.option(
"-f", "-f",
"--filename", "--trained-data-file",
type=str, type=str,
default="trained_agents_data.pkl", default="trained_agents_data.pkl",
help="Path to a trained data file to use", help="Path to a trained data file to use for agent task prompts",
) )
def run(filename: str): def run(trained_data_file: str):
"""Run the Crew.""" """Run the Crew."""
click.echo(f"Running the Crew with trained data from {filename}") click.echo(f"Running the Crew with agent training data file: {trained_data_file}")
run_crew(trained_data_file=filename) run_crew(trained_data_file=trained_data_file)
@crewai.command() @crewai.command()

View File

@@ -1,3 +1,4 @@
import os
import subprocess import subprocess
from enum import Enum from enum import Enum
from typing import List, Optional from typing import List, Optional
@@ -61,6 +62,20 @@ def execute_command(crew_type: CrewType, trained_data_file: Optional[str] = None
command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"] command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"]
if trained_data_file and crew_type == CrewType.STANDARD: if trained_data_file and crew_type == CrewType.STANDARD:
if not trained_data_file.endswith('.pkl'):
click.secho(
f"Error: Trained data file '{trained_data_file}' must have a .pkl extension.",
fg="red",
)
return
if not os.path.exists(trained_data_file):
click.secho(
f"Error: Trained data file '{trained_data_file}' does not exist.",
fg="red",
)
return
command.extend(["--trained-data-file", trained_data_file]) command.extend(["--trained-data-file", trained_data_file])
try: try:

View File

@@ -1,4 +1,5 @@
#!/usr/bin/env python #!/usr/bin/env python
import argparse
import sys import sys
import warnings import warnings
@@ -17,14 +18,23 @@ def run():
""" """
Run the crew. Run the crew.
""" """
parser = argparse.ArgumentParser(description="Run the crew")
parser.add_argument(
"--trained-data-file",
"-f",
type=str,
default="trained_agents_data.pkl",
help="Path to a trained data file to use for agent task prompts"
)
args, _ = parser.parse_known_args()
inputs = { inputs = {
'topic': 'AI LLMs', 'topic': 'AI LLMs',
'current_year': str(datetime.now().year) 'current_year': str(datetime.now().year)
} }
try: try:
filename = sys.argv[1] if len(sys.argv) > 1 else "trained_agents_data.pkl" {{crew_name}}().crew(trained_data_file=args.trained_data_file).kickoff(inputs=inputs)
{{crew_name}}().crew(trained_data_file=filename).kickoff(inputs=inputs)
except Exception as e: except Exception as e:
raise Exception(f"An error occurred while running the crew: {e}") raise Exception(f"An error occurred while running the crew: {e}")

View File

@@ -1227,6 +1227,71 @@ def test_agent_use_custom_trained_data_file(crew_training_handler):
) )
@patch("crewai.agent.CrewTrainingHandler")
def test_agent_with_none_trained_data_file(crew_training_handler):
task_prompt = "What is 1 + 1?"
agent = Agent(
role="researcher",
goal="test goal",
backstory="test backstory",
verbose=True,
trained_data_file=None
)
result = agent._use_trained_data(task_prompt=task_prompt)
assert result == task_prompt
crew_training_handler.assert_not_called()
@patch("crewai.agent.CrewTrainingHandler")
def test_agent_with_missing_role_in_trained_data(crew_training_handler):
task_prompt = "What is 1 + 1?"
agent = Agent(
role="researcher",
goal="test goal",
backstory="test backstory",
verbose=True,
trained_data_file="trained_agents_data.pkl"
)
crew_training_handler().load.return_value = {
"other_role": {
"suggestions": ["This should not be used."]
}
}
result = agent._use_trained_data(task_prompt=task_prompt)
assert result == task_prompt
crew_training_handler.assert_has_calls(
[mock.call(), mock.call("trained_agents_data.pkl"), mock.call().load()]
)
@patch("crewai.agent.CrewTrainingHandler")
def test_agent_with_missing_suggestions_in_trained_data(crew_training_handler):
task_prompt = "What is 1 + 1?"
agent = Agent(
role="researcher",
goal="test goal",
backstory="test backstory",
verbose=True,
trained_data_file="trained_agents_data.pkl"
)
crew_training_handler().load.return_value = {
"researcher": {
"other_key": ["This should not be used."]
}
}
result = agent._use_trained_data(task_prompt=task_prompt)
assert result == task_prompt
crew_training_handler.assert_has_calls(
[mock.call(), mock.call("trained_agents_data.pkl"), mock.call().load()]
)
def test_agent_max_retry_limit(): def test_agent_max_retry_limit():
agent = Agent( agent = Agent(
role="test role", role="test role",