diff --git a/tests/utilities/test_planning_handler.py b/tests/utilities/test_planning_handler.py index 502398fab..8013e0213 100644 --- a/tests/utilities/test_planning_handler.py +++ b/tests/utilities/test_planning_handler.py @@ -6,7 +6,11 @@ from langchain_openai import ChatOpenAI from crewai.agent import Agent from crewai.task import Task from crewai.tasks.task_output import TaskOutput -from crewai.utilities.planning_handler import CrewPlanner, PlannerTaskPydanticOutput +from crewai.utilities.planning_handler import ( + CrewPlanner, + PlannerTaskPydanticOutput, + PlanPerTask, +) class TestCrewPlanner: @@ -44,12 +48,17 @@ class TestCrewPlanner: return CrewPlanner(tasks, planning_agent_llm) def test_handle_crew_planning(self, crew_planner): + list_of_plans_per_task = [ + PlanPerTask(task="Task1", plan="Plan 1"), + PlanPerTask(task="Task2", plan="Plan 2"), + PlanPerTask(task="Task3", plan="Plan 3"), + ] with patch.object(Task, "execute_sync") as execute: execute.return_value = TaskOutput( description="Description", agent="agent", pydantic=PlannerTaskPydanticOutput( - list_of_plans_per_task=["Plan 1", "Plan 2", "Plan 3"] + list_of_plans_per_task=list_of_plans_per_task ), ) result = crew_planner._handle_crew_planning() @@ -91,7 +100,9 @@ class TestCrewPlanner: execute.return_value = TaskOutput( description="Description", agent="agent", - pydantic=PlannerTaskPydanticOutput(list_of_plans_per_task=["Plan 1"]), + pydantic=PlannerTaskPydanticOutput( + list_of_plans_per_task=[PlanPerTask(task="Task1", plan="Plan 1")] + ), ) result = crew_planner_different_llm._handle_crew_planning()