Fix tests

This commit is contained in:
Brandon Hancock
2024-07-31 09:46:19 -04:00
parent 1154cef798
commit 1147a1c93e
3 changed files with 50 additions and 65 deletions

View File

@@ -142,21 +142,22 @@ class Pipeline(BaseModel):
"""
initial_input = copy.deepcopy(kickoff_input)
current_input = copy.deepcopy(kickoff_input)
stages = copy.deepcopy(self.stages)
pipeline_usage_metrics: Dict[str, UsageMetrics] = {}
all_stage_outputs: List[List[CrewOutput]] = []
traces: List[List[Union[str, Dict[str, Any]]]] = [[initial_input]]
stage_index = 0
while stage_index < len(self.stages):
stage = self.stages[stage_index]
while stage_index < len(stages):
stage = stages[stage_index]
stage_input = copy.deepcopy(current_input)
if isinstance(stage, Router):
next_pipeline, route_taken = stage.route(stage_input)
self.stages = (
self.stages[: stage_index + 1]
stages = (
stages[: stage_index + 1]
+ list(next_pipeline.stages)
+ self.stages[stage_index + 1 :]
+ stages[stage_index + 1 :]
)
traces.append([{"route_taken": route_taken}])
stage_index += 1

View File

@@ -1,4 +1,3 @@
import logging
from dataclasses import dataclass
from typing import Any, Callable, Dict, Generic, Tuple, TypeVar