Add more tests which showed underlying issue with traces

This commit is contained in:
Brandon Hancock
2024-07-30 16:25:03 -04:00
parent 19f87f2b82
commit 1154cef798

View File

@@ -457,7 +457,8 @@ def test_pipeline_invalid_crew(mock_crew_factory):
error_msg = str(exc_info.value)
print(f"Full error message: {error_msg}") # For debugging
assert (
"Expected Crew instance or list of Crews, got <class 'str'>" in error_msg
"Expected Crew instance, Router instance, or list of Crews, got <class 'str'>"
in error_msg
), f"Unexpected error message: {error_msg}"
@@ -540,3 +541,124 @@ async def test_pipeline_with_router(mock_router_factory):
{"route_taken": "default"},
"Crew 3",
]
@pytest.mark.asyncio
async def test_router_with_multiple_inputs(mock_router_factory):
router = mock_router_factory()
pipeline = Pipeline(stages=[router])
inputs = [{"score": 90}, {"score": 60}, {"score": 30}]
results = await pipeline.kickoff(inputs)
print("RESULTS", results)
assert len(results) == 3
assert results[0].json_dict is not None
assert results[0].json_dict["output"] == "crew1"
assert results[1].json_dict is not None
assert results[1].json_dict["output"] == "crew2"
assert results[2].json_dict is not None
assert results[2].json_dict["output"] == "crew3"
assert results[0].trace[1]["route_taken"] == "route1"
assert results[1].trace[1]["route_taken"] == "route2"
assert results[2].trace[1]["route_taken"] == "default"
@pytest.mark.asyncio
async def test_pipeline_with_multiple_routers(mock_router_factory, mock_crew_factory):
router1 = mock_router_factory()
router2 = mock_router_factory()
final_crew = mock_crew_factory(
name="Final Crew", output_json_dict={"output": "final"}
)
pipeline = Pipeline(stages=[router1, router2, final_crew])
result = await pipeline.kickoff([{"score": 75}])
assert len(result) == 1
assert result[0].json_dict is not None
assert result[0].json_dict["output"] == "final"
assert len(result[0].trace) == 4 # Input, Router1, Router2, Final Crew
assert result[0].trace[1]["route_taken"] == "route2"
assert result[0].trace[2]["route_taken"] == "route2"
@pytest.mark.asyncio
async def test_router_default_route(mock_crew_factory):
default_crew = mock_crew_factory(
name="Default Crew", output_json_dict={"output": "default"}
)
router = Router(
routes={
"route1": Route(
condition=lambda x: False,
pipeline=Pipeline(stages=[mock_crew_factory(name="Never Used")]),
),
},
default=Pipeline(stages=[default_crew]),
)
pipeline = Pipeline(stages=[router])
result = await pipeline.kickoff([{"score": 100}])
assert len(result) == 1
assert result[0].json_dict is not None
assert result[0].json_dict["output"] == "default"
assert result[0].trace[1]["route_taken"] == "default"
@pytest.mark.asyncio
async def test_router_with_empty_input(mock_router_factory):
router = mock_router_factory()
pipeline = Pipeline(stages=[router])
result = await pipeline.kickoff([{}])
assert len(result) == 1
assert result[0].json_dict is not None
assert result[0].json_dict["output"] == "crew3" # Default route
assert result[0].trace[1]["route_taken"] == "default"
@pytest.mark.asyncio
async def test_router_add_route(mock_router_factory, mock_crew_factory):
router = mock_router_factory()
new_crew = mock_crew_factory(name="New Crew", output_json_dict={"output": "new"})
router.add_route(
"new_route",
condition=lambda x: x.get("score", 0) > 90,
pipeline=Pipeline(stages=[new_crew]),
)
pipeline = Pipeline(stages=[router])
result = await pipeline.kickoff([{"score": 95}])
assert len(result) == 1
assert result[0].json_dict is not None
assert result[0].json_dict["output"] == "new"
assert result[0].trace[1]["route_taken"] == "new_route"
@pytest.mark.asyncio
async def test_pipeline_result_accumulation_with_routers(
mock_router_factory, mock_crew_factory
):
router = mock_router_factory()
accumulator_crew = mock_crew_factory(
name="Accumulator", output_json_dict={"accumulated": "data"}
)
pipeline = Pipeline(stages=[router, accumulator_crew])
result = await pipeline.kickoff([{"score": 75, "initial": "value"}])
assert len(result) == 1
assert result[0].json_dict is not None
assert "output" in result[0].json_dict
assert "accumulated" in result[0].json_dict
assert len(result[0].trace) == 3
assert result[0].trace[0]["initial"] == "value"
assert result[0].trace[1]["route_taken"] == "route2"