Fix critical circular dependency issues. Now needing to fix trace issue.

This commit is contained in:
Brandon Hancock
2024-07-30 13:29:01 -04:00
parent 4251494c55
commit a79efefe7b
8 changed files with 137 additions and 135 deletions

View File

@@ -1,4 +1,5 @@
import json
from typing import Any, Dict
from unittest.mock import MagicMock
import pytest
@@ -8,6 +9,7 @@ from crewai.crews.crew_output import CrewOutput
from crewai.pipeline.pipeline import Pipeline
from crewai.pipeline.pipeline_kickoff_result import PipelineKickoffResult
from crewai.process import Process
from crewai.routers.router import Route, Router
from crewai.task import Task
from crewai.tasks.task_output import TaskOutput
from crewai.types.usage_metrics import UsageMetrics
@@ -64,9 +66,29 @@ def mock_crew_factory():
return _create_mock_crew
# @pytest.fixture
# def pipeline_router_factory():
# return PipelineRouter()
@pytest.fixture
def mock_router_factory(mock_crew_factory):
def _create_mock_router():
crew1 = mock_crew_factory(name="Crew 1", output_json_dict={"output": "crew1"})
crew2 = mock_crew_factory(name="Crew 2", output_json_dict={"output": "crew2"})
crew3 = mock_crew_factory(name="Crew 3", output_json_dict={"output": "crew3"})
router = Router[Dict[str, Any], Pipeline](
routes={
"route1": Route(
condition=lambda x: x.get("score", 0) > 80,
pipeline=Pipeline(stages=[crew1]),
),
"route2": Route(
condition=lambda x: x.get("score", 0) > 50,
pipeline=Pipeline(stages=[crew2]),
),
},
default=Pipeline(stages=[crew3]),
)
return router
return _create_mock_router
def test_pipeline_initialization(mock_crew_factory):
@@ -479,9 +501,40 @@ async def test_pipeline_data_accumulation(mock_crew_factory):
assert final_result.crews_outputs[1].json_dict == {"key2": "value2"}
def test_add_condition(pipeline_router_factory, mock_crew_factory):
pipeline_router = pipeline_router_factory()
crew = mock_crew_factory(name="Test Crew")
pipeline_router.add_condition(lambda x: x.get("score", 0) > 80, crew)
assert len(pipeline_router.conditions) == 1
assert pipeline_router.conditions[0][1] == crew
@pytest.mark.asyncio
async def test_pipeline_with_router(mock_router_factory):
router = mock_router_factory()
pipeline = Pipeline(stages=[router])
# Test high score route
result_high = await pipeline.kickoff([{"score": 90}])
assert len(result_high) == 1
assert result_high[0].json_dict is not None
assert result_high[0].json_dict["output"] == "crew1"
assert result_high[0].trace == [
{"score": 90},
{"router": "Router", "route_taken": "route1"},
"Crew 1",
]
# Test medium score route
result_medium = await pipeline.kickoff([{"score": 60}])
assert len(result_medium) == 1
assert result_medium[0].json_dict is not None
assert result_medium[0].json_dict["output"] == "crew2"
assert result_medium[0].trace == [
{"score": 60},
{"router": "Router", "route_taken": "route2"},
"Crew 2",
]
# Test low score (default) route
result_low = await pipeline.kickoff([{"score": 30}])
assert len(result_low) == 1
assert result_low[0].json_dict is not None
assert result_low[0].json_dict["output"] == "crew3"
assert result_low[0].trace == [
{"score": 30},
{"router": "Router", "route_taken": "default"},
"Crew 3",
]