This commit is contained in:
Brandon Hancock
2024-07-26 17:24:09 -04:00
parent 31ff979a4b
commit cdfac165e3

View File

@@ -1,22 +1,19 @@
from __future__ import annotations
from typing import Any, Callable, Dict, List, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel
from crewai.crew import Crew
from crewai.pipeline.pipeline import Pipeline
class PipelineRouter(BaseModel):
conditions: List[
Tuple[Callable[[Dict[str, Any]], bool], Union[Crew, "Pipeline"]]
] = []
default: Union[Crew, "Pipeline", None] = None
conditions: List[Tuple[Callable[[Dict[str, Any]], bool], Pipeline]] = []
default: Optional[Pipeline] = None
def add_condition(
self,
condition: Callable[[Dict[str, Any]], bool],
next_stage: Union[Crew, "Pipeline"],
self, condition: Callable[[Dict[str, Any]], bool], next_stage: Pipeline
):
"""
Add a condition and its corresponding next stage to the router.
@@ -55,9 +52,7 @@ class PipelineRouter(BaseModel):
raise ValueError("No conditions were met and no default stage was set.")
def _update_trace(
self, input_dict: Dict[str, Any], next_stage: Union[Crew, "Pipeline"]
):
def _update_trace(self, input_dict: Dict[str, Any], next_stage: Pipeline):
"""Update the trace to show that the input went through the router."""
if "trace" not in input_dict:
input_dict["trace"] = []
@@ -67,10 +62,3 @@ class PipelineRouter(BaseModel):
"next_stage": next_stage.__class__.__name__,
}
)
# TODO: See if this is necessary
from crewai.pipeline.pipeline import Pipeline
# This line should be at the end of the file
PipelineRouter.model_rebuild()