diff --git a/tests/utilities/test_planning_handler.py b/tests/utilities/test_planning_handler.py index d1bee0f50..502398fab 100644 --- a/tests/utilities/test_planning_handler.py +++ b/tests/utilities/test_planning_handler.py @@ -1,6 +1,7 @@ from unittest.mock import patch import pytest +from langchain_openai import ChatOpenAI from crewai.agent import Agent from crewai.task import Task @@ -30,6 +31,18 @@ class TestCrewPlanner: ] return CrewPlanner(tasks, None) + @pytest.fixture + def crew_planner_different_llm(self): + tasks = [ + Task( + description="Task 1", + expected_output="Output 1", + agent=Agent(role="Agent 1", goal="Goal 1", backstory="Backstory 1"), + ) + ] + planning_agent_llm = ChatOpenAI(model="gpt-3.5-turbo") + return CrewPlanner(tasks, planning_agent_llm) + def test_handle_crew_planning(self, crew_planner): with patch.object(Task, "execute_sync") as execute: execute.return_value = TaskOutput( @@ -40,7 +53,7 @@ class TestCrewPlanner: ), ) result = crew_planner._handle_crew_planning() - + assert crew_planner.planning_agent_llm.model_name == "gpt-4o-mini" assert isinstance(result, PlannerTaskPydanticOutput) assert len(result.list_of_plans_per_task) == len(crew_planner.tasks) execute.assert_called_once() @@ -72,3 +85,22 @@ class TestCrewPlanner: assert isinstance(tasks_summary, str) assert tasks_summary.startswith("\n Task Number 1 - Task 1") assert tasks_summary.endswith('"agent_tools": []\n ') + + def test_handle_crew_planning_different_llm(self, crew_planner_different_llm): + with patch.object(Task, "execute_sync") as execute: + execute.return_value = TaskOutput( + description="Description", + agent="agent", + pydantic=PlannerTaskPydanticOutput(list_of_plans_per_task=["Plan 1"]), + ) + result = crew_planner_different_llm._handle_crew_planning() + + assert ( + crew_planner_different_llm.planning_agent_llm.model_name + == "gpt-3.5-turbo" + ) + assert isinstance(result, PlannerTaskPydanticOutput) + assert len(result.list_of_plans_per_task) == len( + crew_planner_different_llm.tasks + ) + execute.assert_called_once()