From f7ecc0bc47ac2d0cfbaa6d052be69ffd383fbda5 Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Sat, 19 Apr 2025 23:06:52 +0000 Subject: [PATCH] Fix issue #2647: Make planning LLM inherit authentication parameters from agent's LLM Co-Authored-By: Joe Moura --- src/crewai/crew.py | 7 +++- src/crewai/utilities/planning_handler.py | 15 ++++++- .../utilities/test_planning_auth/__init__.py | 0 .../test_planning_auth_inheritance.py | 39 +++++++++++++++++++ 4 files changed, 58 insertions(+), 3 deletions(-) create mode 100644 tests/utilities/test_planning_auth/__init__.py create mode 100644 tests/utilities/test_planning_auth/test_planning_auth_inheritance.py diff --git a/src/crewai/crew.py b/src/crewai/crew.py index 7c9696f6d..3ca6ff9f7 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -720,8 +720,13 @@ class Crew(BaseModel): def _handle_crew_planning(self): """Handles the Crew planning.""" self._logger.log("info", "Planning the crew execution") + + agent_llm = self.agents[0].llm if self.agents and hasattr(self.agents[0], 'llm') else None + result = CrewPlanner( - tasks=self.tasks, planning_agent_llm=self.planning_llm + tasks=self.tasks, + planning_agent_llm=self.planning_llm, + agent_llm=agent_llm )._handle_crew_planning() for task, step_plan in zip(self.tasks, result.list_of_plans_per_task): diff --git a/src/crewai/utilities/planning_handler.py b/src/crewai/utilities/planning_handler.py index 1bd14a0c8..1b990c81c 100644 --- a/src/crewai/utilities/planning_handler.py +++ b/src/crewai/utilities/planning_handler.py @@ -28,11 +28,22 @@ class PlannerTaskPydanticOutput(BaseModel): class CrewPlanner: """Plans and coordinates the execution of crew tasks.""" - def __init__(self, tasks: List[Task], planning_agent_llm: Optional[Any] = None): + def __init__(self, tasks: List[Task], planning_agent_llm: Optional[Any] = None, agent_llm: Optional[Any] = None): self.tasks = tasks if planning_agent_llm is None: - self.planning_agent_llm = "gpt-4o-mini" + if agent_llm is not None and hasattr(agent_llm, "base_url") and agent_llm.base_url is not None: + from crewai.llm import LLM + self.planning_agent_llm = LLM( + model="gpt-4o-mini", + base_url=agent_llm.base_url, + api_key=getattr(agent_llm, "api_key", None), + organization=getattr(agent_llm, "organization", None), + api_version=getattr(agent_llm, "api_version", None), + extra_headers=getattr(agent_llm, "extra_headers", None) + ) + else: + self.planning_agent_llm = "gpt-4o-mini" else: self.planning_agent_llm = planning_agent_llm diff --git a/tests/utilities/test_planning_auth/__init__.py b/tests/utilities/test_planning_auth/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/utilities/test_planning_auth/test_planning_auth_inheritance.py b/tests/utilities/test_planning_auth/test_planning_auth_inheritance.py new file mode 100644 index 000000000..e264c74d3 --- /dev/null +++ b/tests/utilities/test_planning_auth/test_planning_auth_inheritance.py @@ -0,0 +1,39 @@ +from crewai import Agent, Task +from crewai.llm import LLM +from crewai.utilities.planning_handler import CrewPlanner + +def test_planning_llm_inherits_auth_params(): + """Test that planning LLM inherits authentication parameters from agent LLM.""" + custom_llm = LLM( + model="custom-model", + base_url="https://api.custom-provider.com/v1", + api_key="fake-api-key", + api_version="2023-05-15", + organization="custom-org" + ) + + agent = Agent( + role="Test Agent", + goal="Test Goal", + backstory="Test Backstory", + llm=custom_llm + ) + + task = Task( + description="Test Task", + expected_output="Test Output", + agent=agent + ) + + planner = CrewPlanner( + tasks=[task], + planning_agent_llm=None, # This should trigger the inheritance logic + agent_llm=custom_llm + ) + + assert hasattr(planner, 'planning_agent_llm') + assert hasattr(planner.planning_agent_llm, 'base_url') + assert planner.planning_agent_llm.base_url == "https://api.custom-provider.com/v1" + assert planner.planning_agent_llm.api_key == "fake-api-key" + assert planner.planning_agent_llm.api_version == "2023-05-15" + # organization is not directly accessible as an attribute