diff --git a/src/crewai/agent.py b/src/crewai/agent.py index 73fbcd2fb..5598a2800 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -118,6 +118,10 @@ class Agent(BaseAgent): default=None, description="Knowledge context for the agent.", ) + trained_data_file: str = Field( + default=TRAINED_AGENTS_DATA_FILE, + description="Path to the trained data file to use for task prompts.", + ) crew_knowledge_context: Optional[str] = Field( default=None, description="Knowledge context for the crew.", @@ -498,7 +502,7 @@ class Agent(BaseAgent): def _use_trained_data(self, task_prompt: str) -> str: """Use trained data for the agent task prompt to improve output.""" - if data := CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).load(): + if data := CrewTrainingHandler(self.trained_data_file).load(): if trained_data_output := data.get(self.role): task_prompt += ( "\n\nYou MUST follow these instructions: \n - " diff --git a/src/crewai/cli/cli.py b/src/crewai/cli/cli.py index b2d59adbe..66a9b970e 100644 --- a/src/crewai/cli/cli.py +++ b/src/crewai/cli/cli.py @@ -201,9 +201,17 @@ def install(context): @crewai.command() -def run(): +@click.option( + "-f", + "--filename", + type=str, + default="trained_agents_data.pkl", + help="Path to a trained data file to use", +) +def run(filename: str): """Run the Crew.""" - run_crew() + click.echo(f"Running the Crew with trained data from {filename}") + run_crew(trained_data_file=filename) @crewai.command() diff --git a/src/crewai/cli/run_crew.py b/src/crewai/cli/run_crew.py index 62241a4b5..dd0f97ae6 100644 --- a/src/crewai/cli/run_crew.py +++ b/src/crewai/cli/run_crew.py @@ -14,13 +14,16 @@ class CrewType(Enum): FLOW = "flow" -def run_crew() -> None: +def run_crew(trained_data_file: Optional[str] = None) -> None: """ Run the crew or flow by running a command in the UV environment. Starting from version 0.103.0, this command can be used to run both standard crews and flows. For flows, it detects the type from pyproject.toml and automatically runs the appropriate command. + + Args: + trained_data_file: Optional path to a trained data file to use """ crewai_version = get_crewai_version() min_required_version = "0.71.0" @@ -44,17 +47,21 @@ def run_crew() -> None: click.echo(f"Running the {'Flow' if is_flow else 'Crew'}") # Execute the appropriate command - execute_command(crew_type) + execute_command(crew_type, trained_data_file) -def execute_command(crew_type: CrewType) -> None: +def execute_command(crew_type: CrewType, trained_data_file: Optional[str] = None) -> None: """ Execute the appropriate command based on crew type. Args: crew_type: The type of crew to run + trained_data_file: Optional path to a trained data file to use """ command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"] + + if trained_data_file and crew_type == CrewType.STANDARD: + command.extend(["--trained-data-file", trained_data_file]) try: subprocess.run(command, capture_output=False, text=True, check=True) diff --git a/src/crewai/cli/templates/crew/main.py b/src/crewai/cli/templates/crew/main.py index b604d8ceb..9c74dada2 100644 --- a/src/crewai/cli/templates/crew/main.py +++ b/src/crewai/cli/templates/crew/main.py @@ -23,7 +23,8 @@ def run(): } try: - {{crew_name}}().crew().kickoff(inputs=inputs) + filename = sys.argv[1] if len(sys.argv) > 1 else "trained_agents_data.pkl" + {{crew_name}}().crew(trained_data_file=filename).kickoff(inputs=inputs) except Exception as e: raise Exception(f"An error occurred while running the crew: {e}") diff --git a/src/crewai/crew.py b/src/crewai/crew.py index 7c9696f6d..643a268f8 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -122,6 +122,10 @@ class Crew(BaseModel): tasks: List[Task] = Field(default_factory=list) agents: List[BaseAgent] = Field(default_factory=list) process: Process = Field(default=Process.sequential) + trained_data_file: Optional[str] = Field( + default=None, + description="Path to the trained data file to use for agent task prompts.", + ) verbose: bool = Field(default=False) memory: bool = Field( default=False, @@ -1196,7 +1200,12 @@ class Crew(BaseModel): "manager_llm", } - cloned_agents = [agent.copy() for agent in self.agents] + cloned_agents = [] + for agent in self.agents: + cloned_agent = agent.copy() + if self.trained_data_file: + cloned_agent.trained_data_file = self.trained_data_file + cloned_agents.append(cloned_agent) manager_agent = self.manager_agent.copy() if self.manager_agent else None manager_llm = shallow_copy(self.manager_llm) if self.manager_llm else None diff --git a/tests/agent_test.py b/tests/agent_test.py index b3d243a53..bf33fc772 100644 --- a/tests/agent_test.py +++ b/tests/agent_test.py @@ -1196,6 +1196,37 @@ def test_agent_use_trained_data(crew_training_handler): ) +@patch("crewai.agent.CrewTrainingHandler") +def test_agent_use_custom_trained_data_file(crew_training_handler): + task_prompt = "What is 1 + 1?" + custom_file = "custom_trained_data.pkl" + agent = Agent( + role="researcher", + goal="test goal", + backstory="test backstory", + verbose=True, + trained_data_file=custom_file + ) + crew_training_handler().load.return_value = { + agent.role: { + "suggestions": [ + "The result of the math operation must be right.", + "Result must be better than 1.", + ] + } + } + + result = agent._use_trained_data(task_prompt=task_prompt) + + assert ( + result == "What is 1 + 1?\n\nYou MUST follow these instructions: \n" + " - The result of the math operation must be right.\n - Result must be better than 1." + ) + crew_training_handler.assert_has_calls( + [mock.call(), mock.call(custom_file), mock.call().load()] + ) + + def test_agent_max_retry_limit(): agent = Agent( role="test role",