Fix circular dependencies and updated PipelineRouter

This commit is contained in:
Brandon Hancock
2024-07-29 11:31:34 -04:00
parent cdfac165e3
commit 53e91a7c78
4 changed files with 104 additions and 90 deletions

View File

@@ -2,14 +2,17 @@ from __future__ import annotations
import asyncio import asyncio
import copy import copy
from typing import Any, Dict, List, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator
from crewai.crew import Crew from crewai.crew import Crew
from crewai.crews.crew_output import CrewOutput from crewai.crews.crew_output import CrewOutput
from crewai.pipeline.pipeline_run_result import PipelineRunResult from crewai.pipeline.pipeline_run_result import PipelineRunResult
from crewai.routers.pipeline_router import PipelineRouter from crewai.types.pipeline_stage import PipelineStage
if TYPE_CHECKING:
from crewai.routers.pipeline_router import PipelineRouter
Trace = Union[Union[str, Dict[str, Any]], List[Union[str, Dict[str, Any]]]] Trace = Union[Union[str, Dict[str, Any]], List[Union[str, Dict[str, Any]]]]
@@ -46,7 +49,7 @@ Multiple runs can be processed concurrently, each following the defined pipeline
class Pipeline(BaseModel): class Pipeline(BaseModel):
stages: List[Union[Crew, "Pipeline", "PipelineRouter"]] = Field( stages: List[PipelineStage] = Field(
..., description="List of crews representing stages to be executed in sequence" ..., description="List of crews representing stages to be executed in sequence"
) )
@@ -105,38 +108,38 @@ class Pipeline(BaseModel):
stage_input = copy.deepcopy(current_input) stage_input = copy.deepcopy(current_input)
if isinstance(stage, PipelineRouter): if isinstance(stage, PipelineRouter):
next_stage = stage.route(stage_input) next_pipeline, route_taken = stage.route(stage_input)
traces.append([f"Routed to {next_stage.__class__.__name__}"]) self.stages = (
stage = next_stage self.stages[: stage_index + 1]
+ list(next_pipeline.stages)
+ self.stages[stage_index + 1 :]
)
traces.append([{"router": stage.name, "route_taken": route_taken}])
stage_index += 1
continue
if isinstance(stage, Crew): stage_outputs, stage_trace = await self._process_stage(stage, stage_input)
stage_outputs, stage_trace = await self._process_crew(
stage, stage_input
)
elif isinstance(stage, Pipeline):
stage_outputs, stage_trace = await self._process_pipeline(
stage, stage_input
)
else:
raise ValueError(f"Unsupported stage type: {type(stage)}")
self._update_metrics_and_input( self._update_metrics_and_input(
usage_metrics, current_input, stage, stage_outputs usage_metrics, current_input, stage, stage_outputs
) )
traces.append(stage_trace) traces.append(stage_trace)
all_stage_outputs.append(stage_outputs) all_stage_outputs.append(stage_outputs)
stage_index += 1 stage_index += 1
return self._build_pipeline_run_results( return self._build_pipeline_run_results(
all_stage_outputs, traces, usage_metrics all_stage_outputs, traces, usage_metrics
) )
async def _process_crew( async def _process_stage(
self, crew: Crew, current_input: Dict[str, Any] self, stage: PipelineStage, current_input: Dict[str, Any]
) -> Tuple[List[CrewOutput], List[Union[str, Dict[str, Any]]]]: ) -> Tuple[List[CrewOutput], List[Union[str, Dict[str, Any]]]]:
output = await crew.kickoff_async(inputs=current_input) if isinstance(stage, Crew):
return [output], [crew.name or str(crew.id)] return await self._process_single_crew(stage, current_input)
elif isinstance(stage, list) and all(isinstance(crew, Crew) for crew in stage):
return await self._process_parallel_crews(stage, current_input)
else:
raise ValueError(f"Unsupported stage type: {type(stage)}")
async def _process_pipeline( async def _process_pipeline(
self, pipeline: "Pipeline", current_input: Dict[str, Any] self, pipeline: "Pipeline", current_input: Dict[str, Any]
@@ -148,14 +151,6 @@ class Pipeline(BaseModel):
] ]
return outputs, traces return outputs, traces
async def _process_stage(
self, stage: Union[Crew, List[Crew]], current_input: Dict[str, Any]
) -> Tuple[List[CrewOutput], List[Union[str, Dict[str, Any]]]]:
if isinstance(stage, Crew):
return await self._process_single_crew(stage, current_input)
else:
return await self._process_parallel_crews(stage, current_input)
async def _process_single_crew( async def _process_single_crew(
self, crew: Crew, current_input: Dict[str, Any] self, crew: Crew, current_input: Dict[str, Any]
) -> Tuple[List[CrewOutput], List[Union[str, Dict[str, Any]]]]: ) -> Tuple[List[CrewOutput], List[Union[str, Dict[str, Any]]]]:
@@ -174,13 +169,18 @@ class Pipeline(BaseModel):
self, self,
usage_metrics: Dict[str, Any], usage_metrics: Dict[str, Any],
current_input: Dict[str, Any], current_input: Dict[str, Any],
stage: Union[Crew, "Pipeline"], stage: PipelineStage,
outputs: List[CrewOutput], outputs: List[CrewOutput],
) -> None: ) -> None:
for output in outputs: if isinstance(stage, Crew):
if isinstance(stage, Crew): usage_metrics[stage.name or str(stage.id)] = outputs[0].token_usage
usage_metrics[stage.name or str(stage.id)] = output.token_usage current_input.update(outputs[0].to_dict())
current_input.update(output.to_dict()) elif isinstance(stage, list) and all(isinstance(crew, Crew) for crew in stage):
for crew, output in zip(stage, outputs):
usage_metrics[crew.name or str(crew.id)] = output.token_usage
current_input.update(output.to_dict())
else:
raise ValueError(f"Unsupported stage type: {type(stage)}")
def _build_pipeline_run_results( def _build_pipeline_run_results(
self, self,
@@ -235,16 +235,12 @@ class Pipeline(BaseModel):
] ]
return [crew_outputs + [output] for output in all_stage_outputs[-1]] return [crew_outputs + [output] for output in all_stage_outputs[-1]]
def __rshift__(self, other: Any) -> "Pipeline": def __rshift__(self, other: PipelineStage) -> Pipeline:
if isinstance(other, (Crew, Pipeline, PipelineRouter)): if isinstance(other, (Crew, Pipeline, PipelineRouter)):
return type(self)(stages=self.stages + [other]) return type(self)(stages=self.stages + [other])
elif isinstance(other, list) and all(isinstance(crew, Crew) for crew in other):
return type(self)(stages=self.stages + [other])
else: else:
raise TypeError( raise TypeError(
f"Unsupported operand type for >>: '{type(self).__name__}' and '{type(other).__name__}'" f"Unsupported operand type for >>: '{type(self).__name__}' and '{type(other).__name__}'"
) )
# TODO: CHECK IF NECESSARY
from crewai.routers.pipeline_router import PipelineRouter
Pipeline.model_rebuild()

View File

@@ -1,64 +1,74 @@
from __future__ import annotations from typing import Any, Callable, Dict, List, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from pydantic import BaseModel, Field
from pydantic import BaseModel
from crewai.crew import Crew
from crewai.pipeline.pipeline import Pipeline from crewai.pipeline.pipeline import Pipeline
RouteType = Tuple[Callable[[Dict[str, Any]], bool], Pipeline]
class PipelineRouter(BaseModel): class PipelineRouter(BaseModel):
conditions: List[Tuple[Callable[[Dict[str, Any]], bool], Pipeline]] = [] routes: Dict[str, RouteType] = Field(
default: Optional[Pipeline] = None default_factory=dict,
description="Dictionary of route names to (condition, pipeline) tuples",
)
default: Pipeline = Field(
..., description="Default pipeline if no conditions are met"
)
def add_condition( def __init__(self, *routes: Union[Tuple[str, RouteType], Pipeline], **data):
self, condition: Callable[[Dict[str, Any]], bool], next_stage: Pipeline routes_dict = {}
): default_pipeline = None
for route in routes:
if isinstance(route, tuple) and len(route) == 2:
name, route_tuple = route
if isinstance(route_tuple, tuple) and len(route_tuple) == 2:
condition, pipeline = route_tuple
routes_dict[name] = (condition, pipeline)
else:
raise ValueError(f"Invalid route tuple structure: {route}")
elif isinstance(route, Pipeline):
if default_pipeline is not None:
raise ValueError("Only one default pipeline can be specified")
default_pipeline = route
else:
raise ValueError(f"Invalid route type: {type(route)}")
if default_pipeline is None:
raise ValueError("A default pipeline must be specified")
super().__init__(routes=routes_dict, default=default_pipeline, **data)
def add_route(
self, name: str, condition: Callable[[Dict[str, Any]], bool], pipeline: Pipeline
) -> "PipelineRouter":
""" """
Add a condition and its corresponding next stage to the router. Add a named route with its condition and corresponding pipeline to the router.
Args: Args:
condition: A function that takes the input dictionary and returns a boolean. name: A unique name for this route
next_stage: The Crew or Pipeline to execute if the condition is met. condition: A function that takes the input dictionary and returns a boolean
""" pipeline: The Pipeline to execute if the condition is met
self.conditions.append((condition, next_stage))
def set_default(self, default_stage: Union[Crew, "Pipeline"]):
"""Set the default stage to be executed if no conditions are met."""
self.default = default_stage
def route(self, input_dict: Dict[str, Any]) -> Union[Crew, "Pipeline"]:
"""
Evaluate the input against the conditions and return the appropriate next stage.
Args:
input_dict: The input dictionary to be evaluated.
Returns: Returns:
The next Crew or Pipeline to be executed. The PipelineRouter instance for method chaining
Raises:
ValueError: If no conditions are met and no default stage was set.
""" """
for condition, next_stage in self.conditions: self.routes[name] = (condition, pipeline)
return self
def route(self, input_dict: Dict[str, Any]) -> Tuple[Pipeline, str]:
"""
Evaluate the input against the conditions and return the appropriate pipeline.
Args:
input_dict: The input dictionary to be evaluated
Returns:
A tuple containing the next Pipeline to be executed and the name of the route taken
"""
for name, (condition, pipeline) in self.routes.items():
if condition(input_dict): if condition(input_dict):
self._update_trace(input_dict, next_stage) return pipeline, name
return next_stage
if self.default is not None: return self.default, "default"
self._update_trace(input_dict, self.default)
return self.default
raise ValueError("No conditions were met and no default stage was set.")
def _update_trace(self, input_dict: Dict[str, Any], next_stage: Pipeline):
"""Update the trace to show that the input went through the router."""
if "trace" not in input_dict:
input_dict["trace"] = []
input_dict["trace"].append(
{
"router": self.__class__.__name__,
"next_stage": next_stage.__class__.__name__,
}
)

View File

View File

@@ -0,0 +1,8 @@
from typing import TYPE_CHECKING, List, Union
from crewai.crew import Crew
if TYPE_CHECKING:
from crewai.routers.pipeline_router import PipelineRouter
PipelineStage = Union[Crew, "PipelineRouter", List[Crew]]