mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 08:38:30 +00:00
Address PR review comments: improve validation, error handling, and add tests
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -118,7 +118,7 @@ class Agent(BaseAgent):
|
|||||||
default=None,
|
default=None,
|
||||||
description="Knowledge context for the agent.",
|
description="Knowledge context for the agent.",
|
||||||
)
|
)
|
||||||
trained_data_file: str = Field(
|
trained_data_file: Optional[str] = Field(
|
||||||
default=TRAINED_AGENTS_DATA_FILE,
|
default=TRAINED_AGENTS_DATA_FILE,
|
||||||
description="Path to the trained data file to use for task prompts.",
|
description="Path to the trained data file to use for task prompts.",
|
||||||
)
|
)
|
||||||
@@ -501,13 +501,24 @@ class Agent(BaseAgent):
|
|||||||
return task_prompt
|
return task_prompt
|
||||||
|
|
||||||
def _use_trained_data(self, task_prompt: str) -> str:
|
def _use_trained_data(self, task_prompt: str) -> str:
|
||||||
"""Use trained data for the agent task prompt to improve output."""
|
"""
|
||||||
if data := CrewTrainingHandler(self.trained_data_file).load():
|
Use trained data from a specified file for the agent task prompt.
|
||||||
|
|
||||||
|
Uses the 'trained_data_file' attribute as the source of training instructions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_prompt: The original task prompt to enhance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Enhanced task prompt with training instructions if available.
|
||||||
|
"""
|
||||||
|
if self.trained_data_file and (data := CrewTrainingHandler(self.trained_data_file).load()):
|
||||||
if trained_data_output := data.get(self.role):
|
if trained_data_output := data.get(self.role):
|
||||||
task_prompt += (
|
if "suggestions" in trained_data_output:
|
||||||
"\n\nYou MUST follow these instructions: \n - "
|
task_prompt += (
|
||||||
+ "\n - ".join(trained_data_output["suggestions"])
|
"\n\nYou MUST follow these instructions: \n - "
|
||||||
)
|
+ "\n - ".join(trained_data_output["suggestions"])
|
||||||
|
)
|
||||||
return task_prompt
|
return task_prompt
|
||||||
|
|
||||||
def _render_text_description(self, tools: List[Any]) -> str:
|
def _render_text_description(self, tools: List[Any]) -> str:
|
||||||
|
|||||||
@@ -203,15 +203,15 @@ def install(context):
|
|||||||
@crewai.command()
|
@crewai.command()
|
||||||
@click.option(
|
@click.option(
|
||||||
"-f",
|
"-f",
|
||||||
"--filename",
|
"--trained-data-file",
|
||||||
type=str,
|
type=str,
|
||||||
default="trained_agents_data.pkl",
|
default="trained_agents_data.pkl",
|
||||||
help="Path to a trained data file to use",
|
help="Path to a trained data file to use for agent task prompts",
|
||||||
)
|
)
|
||||||
def run(filename: str):
|
def run(trained_data_file: str):
|
||||||
"""Run the Crew."""
|
"""Run the Crew."""
|
||||||
click.echo(f"Running the Crew with trained data from {filename}")
|
click.echo(f"Running the Crew with agent training data file: {trained_data_file}")
|
||||||
run_crew(trained_data_file=filename)
|
run_crew(trained_data_file=trained_data_file)
|
||||||
|
|
||||||
|
|
||||||
@crewai.command()
|
@crewai.command()
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
@@ -61,6 +62,20 @@ def execute_command(crew_type: CrewType, trained_data_file: Optional[str] = None
|
|||||||
command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"]
|
command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"]
|
||||||
|
|
||||||
if trained_data_file and crew_type == CrewType.STANDARD:
|
if trained_data_file and crew_type == CrewType.STANDARD:
|
||||||
|
if not trained_data_file.endswith('.pkl'):
|
||||||
|
click.secho(
|
||||||
|
f"Error: Trained data file '{trained_data_file}' must have a .pkl extension.",
|
||||||
|
fg="red",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if not os.path.exists(trained_data_file):
|
||||||
|
click.secho(
|
||||||
|
f"Error: Trained data file '{trained_data_file}' does not exist.",
|
||||||
|
fg="red",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
command.extend(["--trained-data-file", trained_data_file])
|
command.extend(["--trained-data-file", trained_data_file])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
import argparse
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
@@ -17,14 +18,23 @@ def run():
|
|||||||
"""
|
"""
|
||||||
Run the crew.
|
Run the crew.
|
||||||
"""
|
"""
|
||||||
|
parser = argparse.ArgumentParser(description="Run the crew")
|
||||||
|
parser.add_argument(
|
||||||
|
"--trained-data-file",
|
||||||
|
"-f",
|
||||||
|
type=str,
|
||||||
|
default="trained_agents_data.pkl",
|
||||||
|
help="Path to a trained data file to use for agent task prompts"
|
||||||
|
)
|
||||||
|
args, _ = parser.parse_known_args()
|
||||||
|
|
||||||
inputs = {
|
inputs = {
|
||||||
'topic': 'AI LLMs',
|
'topic': 'AI LLMs',
|
||||||
'current_year': str(datetime.now().year)
|
'current_year': str(datetime.now().year)
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
filename = sys.argv[1] if len(sys.argv) > 1 else "trained_agents_data.pkl"
|
{{crew_name}}().crew(trained_data_file=args.trained_data_file).kickoff(inputs=inputs)
|
||||||
{{crew_name}}().crew(trained_data_file=filename).kickoff(inputs=inputs)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f"An error occurred while running the crew: {e}")
|
raise Exception(f"An error occurred while running the crew: {e}")
|
||||||
|
|
||||||
|
|||||||
@@ -1227,6 +1227,71 @@ def test_agent_use_custom_trained_data_file(crew_training_handler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("crewai.agent.CrewTrainingHandler")
|
||||||
|
def test_agent_with_none_trained_data_file(crew_training_handler):
|
||||||
|
task_prompt = "What is 1 + 1?"
|
||||||
|
agent = Agent(
|
||||||
|
role="researcher",
|
||||||
|
goal="test goal",
|
||||||
|
backstory="test backstory",
|
||||||
|
verbose=True,
|
||||||
|
trained_data_file=None
|
||||||
|
)
|
||||||
|
|
||||||
|
result = agent._use_trained_data(task_prompt=task_prompt)
|
||||||
|
|
||||||
|
assert result == task_prompt
|
||||||
|
crew_training_handler.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@patch("crewai.agent.CrewTrainingHandler")
|
||||||
|
def test_agent_with_missing_role_in_trained_data(crew_training_handler):
|
||||||
|
task_prompt = "What is 1 + 1?"
|
||||||
|
agent = Agent(
|
||||||
|
role="researcher",
|
||||||
|
goal="test goal",
|
||||||
|
backstory="test backstory",
|
||||||
|
verbose=True,
|
||||||
|
trained_data_file="trained_agents_data.pkl"
|
||||||
|
)
|
||||||
|
crew_training_handler().load.return_value = {
|
||||||
|
"other_role": {
|
||||||
|
"suggestions": ["This should not be used."]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result = agent._use_trained_data(task_prompt=task_prompt)
|
||||||
|
|
||||||
|
assert result == task_prompt
|
||||||
|
crew_training_handler.assert_has_calls(
|
||||||
|
[mock.call(), mock.call("trained_agents_data.pkl"), mock.call().load()]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("crewai.agent.CrewTrainingHandler")
|
||||||
|
def test_agent_with_missing_suggestions_in_trained_data(crew_training_handler):
|
||||||
|
task_prompt = "What is 1 + 1?"
|
||||||
|
agent = Agent(
|
||||||
|
role="researcher",
|
||||||
|
goal="test goal",
|
||||||
|
backstory="test backstory",
|
||||||
|
verbose=True,
|
||||||
|
trained_data_file="trained_agents_data.pkl"
|
||||||
|
)
|
||||||
|
crew_training_handler().load.return_value = {
|
||||||
|
"researcher": {
|
||||||
|
"other_key": ["This should not be used."]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result = agent._use_trained_data(task_prompt=task_prompt)
|
||||||
|
|
||||||
|
assert result == task_prompt
|
||||||
|
crew_training_handler.assert_has_calls(
|
||||||
|
[mock.call(), mock.call("trained_agents_data.pkl"), mock.call().load()]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_agent_max_retry_limit():
|
def test_agent_max_retry_limit():
|
||||||
agent = Agent(
|
agent = Agent(
|
||||||
role="test role",
|
role="test role",
|
||||||
|
|||||||
Reference in New Issue
Block a user