diff --git a/src/crewai/pipeline/__init__.py b/src/crewai/pipeline/__init__.py index d129ecee1..573154b25 100644 --- a/src/crewai/pipeline/__init__.py +++ b/src/crewai/pipeline/__init__.py @@ -1,5 +1,3 @@ from crewai.pipeline.pipeline import Pipeline +from crewai.pipeline.pipeline_kickoff_result import PipelineKickoffResult from crewai.pipeline.pipeline_output import PipelineOutput -from crewai.pipeline.pipeline_run_result import PipelineRunResult - -__all__ = ["Pipeline", "PipelineOutput", "PipelineRunResult"] diff --git a/src/crewai/pipeline/pipeline.py b/src/crewai/pipeline/pipeline.py index 6ae58174a..7e320bf4a 100644 --- a/src/crewai/pipeline/pipeline.py +++ b/src/crewai/pipeline/pipeline.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import asyncio import copy from typing import Any, Dict, List, Tuple, Union @@ -9,12 +7,11 @@ from pydantic import BaseModel, Field, model_validator from crewai.crew import Crew from crewai.crews.crew_output import CrewOutput from crewai.pipeline.pipeline_kickoff_result import PipelineKickoffResult -from crewai.routers.pipeline_router import PipelineRouter -from crewai.types.pipeline_stage import PipelineStage +from crewai.routers.router import Router from crewai.types.usage_metrics import UsageMetrics Trace = Union[Union[str, Dict[str, Any]], List[Union[str, Dict[str, Any]]]] - +PipelineStage = Union[Crew, List[Crew], Router] """ Developer Notes: @@ -88,15 +85,6 @@ class Pipeline(BaseModel): @model_validator(mode="before") @classmethod def validate_stages(cls, values): - """ - Validates the stages to ensure correct nesting and types. - - Args: - values (dict): Dictionary containing the pipeline stages. - - Returns: - dict: Validated stages. - """ stages = values.get("stages", []) def check_nesting_and_type(item, depth=0): @@ -105,9 +93,9 @@ class Pipeline(BaseModel): if isinstance(item, list): for sub_item in item: check_nesting_and_type(sub_item, depth + 1) - elif not isinstance(item, Crew): + elif not isinstance(item, (Crew, Router)): raise ValueError( - f"Expected Crew instance or list of Crews, got {type(item)}" + f"Expected Crew instance, Router instance, or list of Crews, got {type(item)}" ) for stage in stages: @@ -163,14 +151,16 @@ class Pipeline(BaseModel): stage = self.stages[stage_index] stage_input = copy.deepcopy(current_input) - if isinstance(stage, PipelineRouter): + if isinstance(stage, Router): next_pipeline, route_taken = stage.route(stage_input) self.stages = ( self.stages[: stage_index + 1] + list(next_pipeline.stages) + self.stages[stage_index + 1 :] ) - traces.append([{"router": stage.name, "route_taken": route_taken}]) + traces.append( + [{"router": stage.__class__.__name__, "route_taken": route_taken}] + ) stage_index += 1 continue @@ -210,7 +200,7 @@ class Pipeline(BaseModel): async def _process_pipeline( self, pipeline: "Pipeline", current_input: Dict[str, Any] ) -> Tuple[List[CrewOutput], List[Union[str, Dict[str, Any]]]]: - results = await pipeline.process_single_run(current_input) + results = await pipeline.process_single_kickoff(current_input) outputs = [result.crews_outputs[-1] for result in results] traces: List[Union[str, Dict[str, Any]]] = [ f"Nested Pipeline: {pipeline.__class__.__name__}" @@ -388,7 +378,7 @@ class Pipeline(BaseModel): ] return [crew_outputs + [output] for output in all_stage_outputs[-1]] - def __rshift__(self, other: PipelineStage) -> Pipeline: + def __rshift__(self, other: PipelineStage) -> "Pipeline": """ Implements the >> operator to add another Stage (Crew or List[Crew]) to an existing Pipeline. @@ -398,14 +388,11 @@ class Pipeline(BaseModel): Returns: Pipeline: A new pipeline with the added stage. """ - if isinstance(other, (Crew, Pipeline, PipelineRouter)): - return type(self)(stages=self.stages + [other]) - elif isinstance(other, list) and all(isinstance(crew, Crew) for crew in other): + if isinstance(other, (Crew, Router)) or ( + isinstance(other, list) and all(isinstance(item, Crew) for item in other) + ): return type(self)(stages=self.stages + [other]) else: raise TypeError( f"Unsupported operand type for >>: '{type(self).__name__}' and '{type(other).__name__}'" ) - - -Pipeline.model_rebuild() diff --git a/src/crewai/routers/__init__.py b/src/crewai/routers/__init__.py index a8e0d5f73..b21d76bd2 100644 --- a/src/crewai/routers/__init__.py +++ b/src/crewai/routers/__init__.py @@ -1,3 +1 @@ -from crewai.routers.pipeline_router import PipelineRouter - -__all__ = ["PipelineRouter"] +from crewai.routers.router import Router diff --git a/src/crewai/routers/pipeline_router.py b/src/crewai/routers/pipeline_router.py deleted file mode 100644 index 2dc82375a..000000000 --- a/src/crewai/routers/pipeline_router.py +++ /dev/null @@ -1,81 +0,0 @@ -from typing import Any, Callable, Dict, Tuple, Union - -from pydantic import BaseModel, Field - -from crewai.pipeline.pipeline import Pipeline -from crewai.types.route import Route - - -class PipelineRouter(BaseModel): - routes: Dict[str, Route] = Field( - default_factory=dict, - description="Dictionary of route names to (condition, pipeline) tuples", - ) - default: "Pipeline" = Field( - ..., description="Default pipeline if no conditions are met" - ) - - def __init__(self, *routes: Union[Tuple[str, Route], "Pipeline"], **data): - from crewai.pipeline.pipeline import Pipeline - - routes_dict = {} - default_pipeline = None - - for route in routes: - if isinstance(route, tuple) and len(route) == 2: - name, route_tuple = route - if isinstance(route_tuple, tuple) and len(route_tuple) == 2: - condition, pipeline = route_tuple - routes_dict[name] = (condition, pipeline) - else: - raise ValueError(f"Invalid route tuple structure: {route}") - elif isinstance(route, Pipeline): - if default_pipeline is not None: - raise ValueError("Only one default pipeline can be specified") - default_pipeline = route - else: - raise ValueError(f"Invalid route type: {type(route)}") - - if default_pipeline is None: - raise ValueError("A default pipeline must be specified") - - super().__init__(routes=routes_dict, default=default_pipeline, **data) - - def add_route( - self, - name: str, - condition: Callable[[Dict[str, Any]], bool], - pipeline: "Pipeline", - ) -> "PipelineRouter": - """ - Add a named route with its condition and corresponding pipeline to the router. - - Args: - name: A unique name for this route - condition: A function that takes the input dictionary and returns a boolean - pipeline: The Pipeline to execute if the condition is met - - Returns: - The PipelineRouter instance for method chaining - """ - self.routes[name] = (condition, pipeline) - return self - - def route(self, input_dict: Dict[str, Any]) -> Tuple["Pipeline", str]: - """ - Evaluate the input against the conditions and return the appropriate pipeline. - - Args: - input_dict: The input dictionary to be evaluated - - Returns: - A tuple containing the next Pipeline to be executed and the name of the route taken - """ - for name, (condition, pipeline) in self.routes.items(): - if condition(input_dict): - return pipeline, name - - return self.default, "default" - - -PipelineRouter.model_rebuild() diff --git a/src/crewai/routers/router.py b/src/crewai/routers/router.py new file mode 100644 index 000000000..e11c816f2 --- /dev/null +++ b/src/crewai/routers/router.py @@ -0,0 +1,60 @@ +from dataclasses import dataclass +from typing import Any, Callable, Dict, Generic, Tuple, TypeVar + +from pydantic import BaseModel, Field + +T = TypeVar("T", bound=Dict[str, Any]) +U = TypeVar("U") + + +@dataclass +class Route(Generic[T, U]): + condition: Callable[[T], bool] + pipeline: U + + +class Router(BaseModel, Generic[T, U]): + routes: Dict[str, Route[T, U]] = Field( + default_factory=dict, + description="Dictionary of route names to (condition, pipeline) tuples", + ) + default: U = Field(..., description="Default pipeline if no conditions are met") + + def __init__(self, routes: Dict[str, Route[T, U]], default: U, **data): + super().__init__(routes=routes, default=default, **data) + + def add_route( + self, + name: str, + condition: Callable[[T], bool], + pipeline: U, + ) -> "Router[T, U]": + """ + Add a named route with its condition and corresponding pipeline to the router. + + Args: + name: A unique name for this route + condition: A function that takes a dictionary input and returns a boolean + pipeline: The Pipeline to execute if the condition is met + + Returns: + The Router instance for method chaining + """ + self.routes[name] = Route(condition=condition, pipeline=pipeline) + return self + + def route(self, input_data: T) -> Tuple[U, str]: + """ + Evaluate the input against the conditions and return the appropriate pipeline. + + Args: + input_data: The input dictionary to be evaluated + + Returns: + A tuple containing the next Pipeline to be executed and the name of the route taken + """ + for name, route in self.routes.items(): + if route.condition(input_data): + return route.pipeline, name + + return self.default, "default" diff --git a/src/crewai/types/pipeline_stage.py b/src/crewai/types/pipeline_stage.py deleted file mode 100644 index 23b0d42d5..000000000 --- a/src/crewai/types/pipeline_stage.py +++ /dev/null @@ -1,8 +0,0 @@ -from typing import TYPE_CHECKING, List, Union - -from crewai.crew import Crew - -if TYPE_CHECKING: - from crewai.routers.pipeline_router import PipelineRouter - -PipelineStage = Union[Crew, "PipelineRouter", List[Crew]] diff --git a/src/crewai/types/route.py b/src/crewai/types/route.py deleted file mode 100644 index 9a97acc52..000000000 --- a/src/crewai/types/route.py +++ /dev/null @@ -1,5 +0,0 @@ -from typing import Any, Callable, Dict, Tuple - -from crewai.pipeline.pipeline import Pipeline - -Route = Tuple[Callable[[Dict[str, Any]], bool], Pipeline] diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index 3afb68033..c5141d1ae 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -1,4 +1,5 @@ import json +from typing import Any, Dict from unittest.mock import MagicMock import pytest @@ -8,6 +9,7 @@ from crewai.crews.crew_output import CrewOutput from crewai.pipeline.pipeline import Pipeline from crewai.pipeline.pipeline_kickoff_result import PipelineKickoffResult from crewai.process import Process +from crewai.routers.router import Route, Router from crewai.task import Task from crewai.tasks.task_output import TaskOutput from crewai.types.usage_metrics import UsageMetrics @@ -64,9 +66,29 @@ def mock_crew_factory(): return _create_mock_crew -# @pytest.fixture -# def pipeline_router_factory(): -# return PipelineRouter() +@pytest.fixture +def mock_router_factory(mock_crew_factory): + def _create_mock_router(): + crew1 = mock_crew_factory(name="Crew 1", output_json_dict={"output": "crew1"}) + crew2 = mock_crew_factory(name="Crew 2", output_json_dict={"output": "crew2"}) + crew3 = mock_crew_factory(name="Crew 3", output_json_dict={"output": "crew3"}) + + router = Router[Dict[str, Any], Pipeline]( + routes={ + "route1": Route( + condition=lambda x: x.get("score", 0) > 80, + pipeline=Pipeline(stages=[crew1]), + ), + "route2": Route( + condition=lambda x: x.get("score", 0) > 50, + pipeline=Pipeline(stages=[crew2]), + ), + }, + default=Pipeline(stages=[crew3]), + ) + return router + + return _create_mock_router def test_pipeline_initialization(mock_crew_factory): @@ -479,9 +501,40 @@ async def test_pipeline_data_accumulation(mock_crew_factory): assert final_result.crews_outputs[1].json_dict == {"key2": "value2"} -def test_add_condition(pipeline_router_factory, mock_crew_factory): - pipeline_router = pipeline_router_factory() - crew = mock_crew_factory(name="Test Crew") - pipeline_router.add_condition(lambda x: x.get("score", 0) > 80, crew) - assert len(pipeline_router.conditions) == 1 - assert pipeline_router.conditions[0][1] == crew +@pytest.mark.asyncio +async def test_pipeline_with_router(mock_router_factory): + router = mock_router_factory() + pipeline = Pipeline(stages=[router]) + + # Test high score route + result_high = await pipeline.kickoff([{"score": 90}]) + assert len(result_high) == 1 + assert result_high[0].json_dict is not None + assert result_high[0].json_dict["output"] == "crew1" + assert result_high[0].trace == [ + {"score": 90}, + {"router": "Router", "route_taken": "route1"}, + "Crew 1", + ] + + # Test medium score route + result_medium = await pipeline.kickoff([{"score": 60}]) + assert len(result_medium) == 1 + assert result_medium[0].json_dict is not None + assert result_medium[0].json_dict["output"] == "crew2" + assert result_medium[0].trace == [ + {"score": 60}, + {"router": "Router", "route_taken": "route2"}, + "Crew 2", + ] + + # Test low score (default) route + result_low = await pipeline.kickoff([{"score": 30}]) + assert len(result_low) == 1 + assert result_low[0].json_dict is not None + assert result_low[0].json_dict["output"] == "crew3" + assert result_low[0].trace == [ + {"score": 30}, + {"router": "Router", "route_taken": "default"}, + "Crew 3", + ]