Fix circular dependencies and update PipelineRouter

This commit is contained in:
Brandon Hancock
2024-07-29 11:31:34 -04:00
parent cdfac165e3
commit 53e91a7c78
4 changed files with 104 additions and 90 deletions

View File

@@ -2,14 +2,17 @@ from __future__ import annotations
import asyncio
import copy
from typing import Any, Dict, List, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
from pydantic import BaseModel, Field, model_validator
from crewai.crew import Crew
from crewai.crews.crew_output import CrewOutput
from crewai.pipeline.pipeline_run_result import PipelineRunResult
from crewai.routers.pipeline_router import PipelineRouter
from crewai.types.pipeline_stage import PipelineStage
if TYPE_CHECKING:
from crewai.routers.pipeline_router import PipelineRouter
Trace = Union[Union[str, Dict[str, Any]], List[Union[str, Dict[str, Any]]]]
@@ -46,7 +49,7 @@ Multiple runs can be processed concurrently, each following the defined pipeline
class Pipeline(BaseModel):
stages: List[Union[Crew, "Pipeline", "PipelineRouter"]] = Field(
stages: List[PipelineStage] = Field(
..., description="List of crews representing stages to be executed in sequence"
)
@@ -105,38 +108,38 @@ class Pipeline(BaseModel):
stage_input = copy.deepcopy(current_input)
if isinstance(stage, PipelineRouter):
next_stage = stage.route(stage_input)
traces.append([f"Routed to {next_stage.__class__.__name__}"])
stage = next_stage
next_pipeline, route_taken = stage.route(stage_input)
self.stages = (
self.stages[: stage_index + 1]
+ list(next_pipeline.stages)
+ self.stages[stage_index + 1 :]
)
traces.append([{"router": stage.name, "route_taken": route_taken}])
stage_index += 1
continue
if isinstance(stage, Crew):
stage_outputs, stage_trace = await self._process_crew(
stage, stage_input
)
elif isinstance(stage, Pipeline):
stage_outputs, stage_trace = await self._process_pipeline(
stage, stage_input
)
else:
raise ValueError(f"Unsupported stage type: {type(stage)}")
stage_outputs, stage_trace = await self._process_stage(stage, stage_input)
self._update_metrics_and_input(
usage_metrics, current_input, stage, stage_outputs
)
traces.append(stage_trace)
all_stage_outputs.append(stage_outputs)
stage_index += 1
return self._build_pipeline_run_results(
all_stage_outputs, traces, usage_metrics
)
async def _process_crew(
self, crew: Crew, current_input: Dict[str, Any]
async def _process_stage(
self, stage: PipelineStage, current_input: Dict[str, Any]
) -> Tuple[List[CrewOutput], List[Union[str, Dict[str, Any]]]]:
output = await crew.kickoff_async(inputs=current_input)
return [output], [crew.name or str(crew.id)]
if isinstance(stage, Crew):
return await self._process_single_crew(stage, current_input)
elif isinstance(stage, list) and all(isinstance(crew, Crew) for crew in stage):
return await self._process_parallel_crews(stage, current_input)
else:
raise ValueError(f"Unsupported stage type: {type(stage)}")
async def _process_pipeline(
self, pipeline: "Pipeline", current_input: Dict[str, Any]
@@ -148,14 +151,6 @@ class Pipeline(BaseModel):
]
return outputs, traces
async def _process_stage(
self, stage: Union[Crew, List[Crew]], current_input: Dict[str, Any]
) -> Tuple[List[CrewOutput], List[Union[str, Dict[str, Any]]]]:
if isinstance(stage, Crew):
return await self._process_single_crew(stage, current_input)
else:
return await self._process_parallel_crews(stage, current_input)
async def _process_single_crew(
self, crew: Crew, current_input: Dict[str, Any]
) -> Tuple[List[CrewOutput], List[Union[str, Dict[str, Any]]]]:
@@ -174,13 +169,18 @@ class Pipeline(BaseModel):
self,
usage_metrics: Dict[str, Any],
current_input: Dict[str, Any],
stage: Union[Crew, "Pipeline"],
stage: PipelineStage,
outputs: List[CrewOutput],
) -> None:
for output in outputs:
if isinstance(stage, Crew):
usage_metrics[stage.name or str(stage.id)] = output.token_usage
current_input.update(output.to_dict())
if isinstance(stage, Crew):
usage_metrics[stage.name or str(stage.id)] = outputs[0].token_usage
current_input.update(outputs[0].to_dict())
elif isinstance(stage, list) and all(isinstance(crew, Crew) for crew in stage):
for crew, output in zip(stage, outputs):
usage_metrics[crew.name or str(crew.id)] = output.token_usage
current_input.update(output.to_dict())
else:
raise ValueError(f"Unsupported stage type: {type(stage)}")
def _build_pipeline_run_results(
self,
@@ -235,16 +235,12 @@ class Pipeline(BaseModel):
]
return [crew_outputs + [output] for output in all_stage_outputs[-1]]
def __rshift__(self, other: Any) -> "Pipeline":
def __rshift__(self, other: PipelineStage) -> Pipeline:
if isinstance(other, (Crew, Pipeline, PipelineRouter)):
return type(self)(stages=self.stages + [other])
elif isinstance(other, list) and all(isinstance(crew, Crew) for crew in other):
return type(self)(stages=self.stages + [other])
else:
raise TypeError(
f"Unsupported operand type for >>: '{type(self).__name__}' and '{type(other).__name__}'"
)
# TODO: check whether this late import and the model_rebuild() call below are still necessary
from crewai.routers.pipeline_router import PipelineRouter
Pipeline.model_rebuild()

View File

@@ -1,64 +1,74 @@
from __future__ import annotations
from typing import Any, Callable, Dict, List, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel, Field
from pydantic import BaseModel
from crewai.crew import Crew
from crewai.pipeline.pipeline import Pipeline
RouteType = Tuple[Callable[[Dict[str, Any]], bool], Pipeline]
class PipelineRouter(BaseModel):
conditions: List[Tuple[Callable[[Dict[str, Any]], bool], Pipeline]] = []
default: Optional[Pipeline] = None
routes: Dict[str, RouteType] = Field(
default_factory=dict,
description="Dictionary of route names to (condition, pipeline) tuples",
)
default: Pipeline = Field(
..., description="Default pipeline if no conditions are met"
)
def add_condition(
self, condition: Callable[[Dict[str, Any]], bool], next_stage: Pipeline
):
def __init__(self, *routes: Union[Tuple[str, RouteType], Pipeline], **data):
    """Build a router from named routes and exactly one default pipeline.

    Args:
        *routes: Each positional argument is either a
            ``(name, (condition, pipeline))`` tuple describing a named
            route, or a bare ``Pipeline`` that becomes the default.
        **data: Extra keyword arguments forwarded to the pydantic
            constructor. ``routes`` and ``default`` may also be given
            here; keyword routes are merged with the positional ones
            (positional entries win on a name clash).

    Raises:
        ValueError: If a route tuple is malformed, a positional argument
            is neither a tuple nor a ``Pipeline``, more than one default
            pipeline is supplied, or no default pipeline is supplied.
    """
    # Take these out of ``data`` first: they are passed explicitly to
    # super().__init__ below, and leaving them in would pass them twice.
    keyword_routes = data.pop("routes", None)
    routes_dict = dict(keyword_routes) if keyword_routes else {}
    default_pipeline = data.pop("default", None)

    for route in routes:
        if isinstance(route, Pipeline):
            if default_pipeline is not None:
                raise ValueError("Only one default pipeline can be specified")
            default_pipeline = route
        elif isinstance(route, tuple):
            # Any tuple is meant as a route; report a structure problem
            # rather than an "invalid type" when its shape is wrong.
            if len(route) != 2:
                raise ValueError(f"Invalid route tuple structure: {route}")
            name, route_tuple = route
            if not (isinstance(route_tuple, tuple) and len(route_tuple) == 2):
                raise ValueError(f"Invalid route tuple structure: {route}")
            condition, pipeline = route_tuple
            routes_dict[name] = (condition, pipeline)
        else:
            raise ValueError(f"Invalid route type: {type(route)}")

    if default_pipeline is None:
        raise ValueError("A default pipeline must be specified")

    super().__init__(routes=routes_dict, default=default_pipeline, **data)
def add_route(
self, name: str, condition: Callable[[Dict[str, Any]], bool], pipeline: Pipeline
) -> "PipelineRouter":
"""
Add a condition and its corresponding next stage to the router.
Add a named route with its condition and corresponding pipeline to the router.
Args:
condition: A function that takes the input dictionary and returns a boolean.
next_stage: The Crew or Pipeline to execute if the condition is met.
"""
self.conditions.append((condition, next_stage))
def set_default(self, default_stage: Union[Crew, "Pipeline"]):
"""Set the default stage to be executed if no conditions are met."""
self.default = default_stage
def route(self, input_dict: Dict[str, Any]) -> Union[Crew, "Pipeline"]:
"""
Evaluate the input against the conditions and return the appropriate next stage.
Args:
input_dict: The input dictionary to be evaluated.
name: A unique name for this route
condition: A function that takes the input dictionary and returns a boolean
pipeline: The Pipeline to execute if the condition is met
Returns:
The next Crew or Pipeline to be executed.
Raises:
ValueError: If no conditions are met and no default stage was set.
The PipelineRouter instance for method chaining
"""
for condition, next_stage in self.conditions:
self.routes[name] = (condition, pipeline)
return self
def route(self, input_dict: Dict[str, Any]) -> Tuple[Pipeline, str]:
"""
Evaluate the input against the conditions and return the appropriate pipeline.
Args:
input_dict: The input dictionary to be evaluated
Returns:
A tuple containing the next Pipeline to be executed and the name of the route taken
"""
for name, (condition, pipeline) in self.routes.items():
if condition(input_dict):
self._update_trace(input_dict, next_stage)
return next_stage
return pipeline, name
if self.default is not None:
self._update_trace(input_dict, self.default)
return self.default
raise ValueError("No conditions were met and no default stage was set.")
def _update_trace(self, input_dict: Dict[str, Any], next_stage: Pipeline):
"""Update the trace to show that the input went through the router."""
if "trace" not in input_dict:
input_dict["trace"] = []
input_dict["trace"].append(
{
"router": self.__class__.__name__,
"next_stage": next_stage.__class__.__name__,
}
)
return self.default, "default"

View File

View File

@@ -0,0 +1,8 @@
from typing import TYPE_CHECKING, List, Union
from crewai.crew import Crew
if TYPE_CHECKING:
from crewai.routers.pipeline_router import PipelineRouter
# A single pipeline stage: one Crew, a PipelineRouter, or a list of Crews.
# "PipelineRouter" is a string forward reference because the class is only
# imported under TYPE_CHECKING above, not at runtime.
PipelineStage = Union[Crew, "PipelineRouter", List[Crew]]