diff --git a/src/crewai/crew.py b/src/crewai/crew.py index e286b47f9..d1bbde947 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -6,6 +6,7 @@ from hashlib import md5 from typing import Any, Dict, List, Optional, Tuple, Union from langchain_core.callbacks import BaseCallbackHandler +from langchain_openai import ChatOpenAI from pydantic import ( UUID4, BaseModel, @@ -154,6 +155,10 @@ class Crew(BaseModel): default=False, description="Plan the crew execution and add the plan to the crew.", ) + planning_llm: Optional[Any] = Field( + default=ChatOpenAI(model="gpt-4o-mini"), + description="Language model that will run the AgentPlanner if planning is True.", + ) task_execution_output_json_files: Optional[List[str]] = Field( default=None, description="List of file paths for task execution JSON files.", @@ -559,15 +564,12 @@ class Crew(BaseModel): def _handle_crew_planning(self): """Handles the Crew planning.""" self._logger.log("info", "Planning the crew execution") - result = CrewPlanner(self.tasks)._handle_crew_planning() + result = CrewPlanner( + tasks=self.tasks, planning_agent_llm=self.planning_llm + )._handle_crew_planning() - if result is not None and hasattr(result, "list_of_plans_per_task"): - for task, step_plan in zip(self.tasks, result.list_of_plans_per_task): - task.description += step_plan - else: - self._logger.log( - "info", "Something went wrong with the planning process of the Crew" - ) + for task, step_plan in zip(self.tasks, result.list_of_plans_per_task): + task.description += step_plan def _store_execution_log( self, diff --git a/src/crewai/utilities/planning_handler.py b/src/crewai/utilities/planning_handler.py index cba1727b9..10d13c385 100644 --- a/src/crewai/utilities/planning_handler.py +++ b/src/crewai/utilities/planning_handler.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Any, List from pydantic import BaseModel @@ -11,17 +11,23 @@ class PlannerTaskPydanticOutput(BaseModel): class CrewPlanner: - def __init__(self, tasks: List[Task]): + def __init__(self, tasks: List[Task], planning_agent_llm: Any): self.tasks = tasks + self.planning_agent_llm = planning_agent_llm - def _handle_crew_planning(self) -> Optional[BaseModel]: + def _handle_crew_planning(self) -> PlannerTaskPydanticOutput: """Handles the Crew planning by creating detailed step-by-step plans for each task.""" planning_agent = self._create_planning_agent() tasks_summary = self._create_tasks_summary() planner_task = self._create_planner_task(planning_agent, tasks_summary) - return planner_task.execute_sync().pydantic + result = planner_task.execute_sync() + + if isinstance(result.pydantic, PlannerTaskPydanticOutput): + return result.pydantic + + raise ValueError("Failed to get the Planning output") def _create_planning_agent(self) -> Agent: """Creates the planning agent for the crew planning.""" @@ -32,6 +38,7 @@ class CrewPlanner: "available to each agent so that they can perform the tasks in an exemplary manner" ), backstory="Planner agent for crew planning", + llm=self.planning_agent_llm, ) def _create_planner_task(self, planning_agent: Agent, tasks_summary: str) -> Task: