diff --git a/src/crewai/crew.py b/src/crewai/crew.py index d1bbde947..509154822 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -6,7 +6,6 @@ from hashlib import md5 from typing import Any, Dict, List, Optional, Tuple, Union from langchain_core.callbacks import BaseCallbackHandler -from langchain_openai import ChatOpenAI from pydantic import ( UUID4, BaseModel, @@ -156,7 +155,7 @@ class Crew(BaseModel): description="Plan the crew execution and add the plan to the crew.", ) planning_llm: Optional[Any] = Field( - default=ChatOpenAI(model="gpt-4o-mini"), + default=None, description="Language model that will run the AgentPlanner if planning is True.", ) task_execution_output_json_files: Optional[List[str]] = Field( diff --git a/src/crewai/utilities/planning_handler.py b/src/crewai/utilities/planning_handler.py index 10d13c385..29b89667e 100644 --- a/src/crewai/utilities/planning_handler.py +++ b/src/crewai/utilities/planning_handler.py @@ -1,5 +1,6 @@ -from typing import Any, List +from typing import Any, List, Optional +from langchain_openai import ChatOpenAI from pydantic import BaseModel from crewai.agent import Agent @@ -11,9 +12,13 @@ class PlannerTaskPydanticOutput(BaseModel): class CrewPlanner: - def __init__(self, tasks: List[Task], planning_agent_llm: Any): + def __init__(self, tasks: List[Task], planning_agent_llm: Optional[Any] = None): self.tasks = tasks - self.planning_agent_llm = planning_agent_llm + + if planning_agent_llm is None: + self.planning_agent_llm = ChatOpenAI(model="gpt-4o-mini") + else: + self.planning_agent_llm = planning_agent_llm def _handle_crew_planning(self) -> PlannerTaskPydanticOutput: """Handles the Crew planning by creating detailed step-by-step plans for each task.""" diff --git a/tests/utilities/test_planning_handler.py b/tests/utilities/test_planning_handler.py index 75bc3e033..d1bee0f50 100644 --- a/tests/utilities/test_planning_handler.py +++ b/tests/utilities/test_planning_handler.py @@ -1,10 +1,10 @@ from unittest.mock import patch -from crewai.tasks.task_output import TaskOutput import pytest from crewai.agent import Agent from crewai.task import Task +from crewai.tasks.task_output import TaskOutput from crewai.utilities.planning_handler import CrewPlanner, PlannerTaskPydanticOutput @@ -28,7 +28,7 @@ class TestCrewPlanner: agent=Agent(role="Agent 3", goal="Goal 3", backstory="Backstory 3"), ), ] - return CrewPlanner(tasks) + return CrewPlanner(tasks, None) def test_handle_crew_planning(self, crew_planner): with patch.object(Task, "execute_sync") as execute: