diff --git a/src/crewai/agent.py b/src/crewai/agent.py index 5598a2800..e8a96c14f 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -118,7 +118,7 @@ class Agent(BaseAgent): default=None, description="Knowledge context for the agent.", ) - trained_data_file: str = Field( + trained_data_file: Optional[str] = Field( default=TRAINED_AGENTS_DATA_FILE, description="Path to the trained data file to use for task prompts.", ) @@ -501,13 +501,24 @@ class Agent(BaseAgent): return task_prompt def _use_trained_data(self, task_prompt: str) -> str: - """Use trained data for the agent task prompt to improve output.""" - if data := CrewTrainingHandler(self.trained_data_file).load(): + """ + Use trained data from a specified file for the agent task prompt. + + Uses the 'trained_data_file' attribute as the source of training instructions. + + Args: + task_prompt: The original task prompt to enhance. + + Returns: + Enhanced task prompt with training instructions if available. + """ + if self.trained_data_file and (data := CrewTrainingHandler(self.trained_data_file).load()): if trained_data_output := data.get(self.role): - task_prompt += ( - "\n\nYou MUST follow these instructions: \n - " - + "\n - ".join(trained_data_output["suggestions"]) - ) + if "suggestions" in trained_data_output: + task_prompt += ( + "\n\nYou MUST follow these instructions: \n - " + + "\n - ".join(trained_data_output["suggestions"]) + ) return task_prompt def _render_text_description(self, tools: List[Any]) -> str: diff --git a/src/crewai/cli/cli.py b/src/crewai/cli/cli.py index 66a9b970e..76a64b1ee 100644 --- a/src/crewai/cli/cli.py +++ b/src/crewai/cli/cli.py @@ -203,15 +203,15 @@ def install(context): @crewai.command() @click.option( "-f", - "--filename", + "--trained-data-file", type=str, default="trained_agents_data.pkl", - help="Path to a trained data file to use", + help="Path to a trained data file to use for agent task prompts", ) -def run(filename: str): +def run(trained_data_file: str): """Run the Crew.""" - click.echo(f"Running the Crew with trained data from {filename}") - run_crew(trained_data_file=filename) + click.echo(f"Running the Crew with agent training data file: {trained_data_file}") + run_crew(trained_data_file=trained_data_file) @crewai.command() diff --git a/src/crewai/cli/run_crew.py b/src/crewai/cli/run_crew.py index dd0f97ae6..5fc4a9f6e 100644 --- a/src/crewai/cli/run_crew.py +++ b/src/crewai/cli/run_crew.py @@ -1,3 +1,4 @@ +import os import subprocess from enum import Enum from typing import List, Optional @@ -61,6 +62,20 @@ def execute_command(crew_type: CrewType, trained_data_file: Optional[str] = None command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"] if trained_data_file and crew_type == CrewType.STANDARD: + if not trained_data_file.endswith('.pkl'): + click.secho( + f"Error: Trained data file '{trained_data_file}' must have a .pkl extension.", + fg="red", + ) + return + + if not os.path.exists(trained_data_file): + click.secho( + f"Error: Trained data file '{trained_data_file}' does not exist.", + fg="red", + ) + return + command.extend(["--trained-data-file", trained_data_file]) try: diff --git a/src/crewai/cli/templates/crew/main.py b/src/crewai/cli/templates/crew/main.py index 9c74dada2..c8b43ced0 100644 --- a/src/crewai/cli/templates/crew/main.py +++ b/src/crewai/cli/templates/crew/main.py @@ -1,4 +1,5 @@ #!/usr/bin/env python +import argparse import sys import warnings @@ -17,14 +18,23 @@ def run(): """ Run the crew. """ + parser = argparse.ArgumentParser(description="Run the crew") + parser.add_argument( + "--trained-data-file", + "-f", + type=str, + default="trained_agents_data.pkl", + help="Path to a trained data file to use for agent task prompts" + ) + args, _ = parser.parse_known_args() + inputs = { 'topic': 'AI LLMs', 'current_year': str(datetime.now().year) } try: - filename = sys.argv[1] if len(sys.argv) > 1 else "trained_agents_data.pkl" - {{crew_name}}().crew(trained_data_file=filename).kickoff(inputs=inputs) + {{crew_name}}().crew(trained_data_file=args.trained_data_file).kickoff(inputs=inputs) except Exception as e: raise Exception(f"An error occurred while running the crew: {e}") diff --git a/tests/agent_test.py b/tests/agent_test.py index bf33fc772..02423cd91 100644 --- a/tests/agent_test.py +++ b/tests/agent_test.py @@ -1227,6 +1227,71 @@ def test_agent_use_custom_trained_data_file(crew_training_handler): ) +@patch("crewai.agent.CrewTrainingHandler") +def test_agent_with_none_trained_data_file(crew_training_handler): + task_prompt = "What is 1 + 1?" + agent = Agent( + role="researcher", + goal="test goal", + backstory="test backstory", + verbose=True, + trained_data_file=None + ) + + result = agent._use_trained_data(task_prompt=task_prompt) + + assert result == task_prompt + crew_training_handler.assert_not_called() + + +@patch("crewai.agent.CrewTrainingHandler") +def test_agent_with_missing_role_in_trained_data(crew_training_handler): + task_prompt = "What is 1 + 1?" + agent = Agent( + role="researcher", + goal="test goal", + backstory="test backstory", + verbose=True, + trained_data_file="trained_agents_data.pkl" + ) + crew_training_handler().load.return_value = { + "other_role": { + "suggestions": ["This should not be used."] + } + } + + result = agent._use_trained_data(task_prompt=task_prompt) + + assert result == task_prompt + crew_training_handler.assert_has_calls( + [mock.call(), mock.call("trained_agents_data.pkl"), mock.call().load()] + ) + + +@patch("crewai.agent.CrewTrainingHandler") +def test_agent_with_missing_suggestions_in_trained_data(crew_training_handler): + task_prompt = "What is 1 + 1?" + agent = Agent( + role="researcher", + goal="test goal", + backstory="test backstory", + verbose=True, + trained_data_file="trained_agents_data.pkl" + ) + crew_training_handler().load.return_value = { + "researcher": { + "other_key": ["This should not be used."] + } + } + + result = agent._use_trained_data(task_prompt=task_prompt) + + assert result == task_prompt + crew_training_handler.assert_has_calls( + [mock.call(), mock.call("trained_agents_data.pkl"), mock.call().load()] + ) + + def test_agent_max_retry_limit(): agent = Agent( role="test role",