diff --git a/src/crewai/crew.py b/src/crewai/crew.py index a7a7b1fed..a8235f9bf 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -2,19 +2,19 @@ import asyncio import json import uuid from concurrent.futures import Future -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from langchain_core.callbacks import BaseCallbackHandler from pydantic import ( - UUID4, - BaseModel, - ConfigDict, - Field, - InstanceOf, - Json, - PrivateAttr, - field_validator, - model_validator, + UUID4, + BaseModel, + ConfigDict, + Field, + InstanceOf, + Json, + PrivateAttr, + field_validator, + model_validator, ) from pydantic_core import PydanticCustomError @@ -34,8 +34,8 @@ from crewai.utilities import I18N, FileHandler, Logger, RPMController from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE from crewai.utilities.evaluators.task_evaluator import TaskEvaluator from crewai.utilities.formatter import ( - aggregate_raw_outputs_from_task_outputs, - aggregate_raw_outputs_from_tasks, + aggregate_raw_outputs_from_task_outputs, + aggregate_raw_outputs_from_tasks, ) from crewai.utilities.training_handler import CrewTrainingHandler @@ -44,6 +44,9 @@ try: except ImportError: agentops = None +if TYPE_CHECKING: + from crewai.procedure.procedure import Procedure + class Crew(BaseModel): """ @@ -767,5 +770,17 @@ class Crew(BaseModel): return total_usage_metrics + def __rshift__(self, other: "Crew") -> "Procedure": + """ + Implements the >> operator to add another Crew to an existing Procedure. + """ + from crewai.procedure.procedure import Procedure + + if not isinstance(other, Crew): + raise TypeError( + f"Unsupported operand type for >>: '{type(self).__name__}' and '{type(other).__name__}'" + ) + return Procedure(crews=[self, other]) + def __repr__(self): return f"Crew(id={self.id}, process={self.process}, number_of_agents={len(self.agents)}, number_of_tasks={len(self.tasks)})" diff --git a/src/crewai/procedure/procedure.py b/src/crewai/procedure/procedure.py index 26316937b..98cde607d 100644 --- a/src/crewai/procedure/procedure.py +++ b/src/crewai/procedure/procedure.py @@ -36,3 +36,13 @@ class Procedure(BaseModel): outputs = await asyncio.gather(*crew_kickoffs) return outputs + + def __rshift__(self, other: Crew) -> "Procedure": + """ + Implements the >> operator to add another Crew to an existing Procedure. + """ + if not isinstance(other, Crew): + raise TypeError( + f"Unsupported operand type for >>: '{type(self).__name__}' and '{type(other).__name__}'" + ) + return type(self)(crews=self.crews + [other]) diff --git a/tests/procedure/test_procedure.py b/tests/procedure/test_procedure.py index d7987e41e..696166973 100644 --- a/tests/procedure/test_procedure.py +++ b/tests/procedure/test_procedure.py @@ -180,3 +180,36 @@ async def test_procedure_chaining(mock_crew_factory): "completion_tokens": 75, } assert result[0].json_dict == {"key2": "value2"} + + +def test_crew_rshift_operator(): + """ + Test that the >> operator correctly creates a Procedure from two Crews. + """ + # Create minimal Crew instances + agent = Agent(role="Test Agent", goal="Test Goal", backstory="Test Backstory") + task = Task(agent=agent, description="Test Task", expected_output="Test Output") + crew1 = Crew(agents=[agent], tasks=[task]) + crew2 = Crew(agents=[agent], tasks=[task]) + crew3 = Crew(agents=[agent], tasks=[task]) + + # Test the >> operator + procedure = crew1 >> crew2 + + assert isinstance(procedure, Procedure) + assert len(procedure.crews) == 2 + assert procedure.crews[0] == crew1 + assert procedure.crews[1] == crew2 + + # Test chaining multiple crews + procedure = crew1 >> crew2 >> crew3 + + assert isinstance(procedure, Procedure) + assert len(procedure.crews) == 3 + assert procedure.crews[0] == crew1 + assert procedure.crews[1] == crew2 + assert procedure.crews[2] == crew3 + + # Test error case: trying to shift with non-Crew object + with pytest.raises(TypeError): + crew1 >> "not a crew"