rshift working

This commit is contained in:
Brandon Hancock
2024-07-12 16:48:19 -04:00
parent f7680d6157
commit c5002eedd9
3 changed files with 70 additions and 12 deletions

View File

@@ -2,19 +2,19 @@ import asyncio
import json
import uuid
from concurrent.futures import Future
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from langchain_core.callbacks import BaseCallbackHandler
from pydantic import (
UUID4,
BaseModel,
ConfigDict,
Field,
InstanceOf,
Json,
PrivateAttr,
field_validator,
model_validator,
UUID4,
BaseModel,
ConfigDict,
Field,
InstanceOf,
Json,
PrivateAttr,
field_validator,
model_validator,
)
from pydantic_core import PydanticCustomError
@@ -34,8 +34,8 @@ from crewai.utilities import I18N, FileHandler, Logger, RPMController
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
from crewai.utilities.formatter import (
aggregate_raw_outputs_from_task_outputs,
aggregate_raw_outputs_from_tasks,
aggregate_raw_outputs_from_task_outputs,
aggregate_raw_outputs_from_tasks,
)
from crewai.utilities.training_handler import CrewTrainingHandler
@@ -44,6 +44,9 @@ try:
except ImportError:
agentops = None
if TYPE_CHECKING:
from crewai.procedure.procedure import Procedure
class Crew(BaseModel):
"""
@@ -767,5 +770,17 @@ class Crew(BaseModel):
return total_usage_metrics
def __rshift__(self, other: "Crew") -> "Procedure":
"""
Implements the >> operator to add another Crew to an existing Procedure.
"""
from crewai.procedure.procedure import Procedure
if not isinstance(other, Crew):
raise TypeError(
f"Unsupported operand type for >>: '{type(self).__name__}' and '{type(other).__name__}'"
)
return Procedure(crews=[self, other])
def __repr__(self):
return f"Crew(id={self.id}, process={self.process}, number_of_agents={len(self.agents)}, number_of_tasks={len(self.tasks)})"

View File

@@ -36,3 +36,13 @@ class Procedure(BaseModel):
outputs = await asyncio.gather(*crew_kickoffs)
return outputs
def __rshift__(self, other: Crew) -> "Procedure":
"""
Implements the >> operator to add another Crew to an existing Procedure.
"""
if not isinstance(other, Crew):
raise TypeError(
f"Unsupported operand type for >>: '{type(self).__name__}' and '{type(other).__name__}'"
)
return type(self)(crews=self.crews + [other])

View File

@@ -180,3 +180,36 @@ async def test_procedure_chaining(mock_crew_factory):
"completion_tokens": 75,
}
assert result[0].json_dict == {"key2": "value2"}
def test_crew_rshift_operator():
"""
Test that the >> operator correctly creates a Procedure from two Crews.
"""
# Create minimal Crew instances
agent = Agent(role="Test Agent", goal="Test Goal", backstory="Test Backstory")
task = Task(agent=agent, description="Test Task", expected_output="Test Output")
crew1 = Crew(agents=[agent], tasks=[task])
crew2 = Crew(agents=[agent], tasks=[task])
crew3 = Crew(agents=[agent], tasks=[task])
# Test the >> operator
procedure = crew1 >> crew2
assert isinstance(procedure, Procedure)
assert len(procedure.crews) == 2
assert procedure.crews[0] == crew1
assert procedure.crews[1] == crew2
# Test chaining multiple crews
procedure = crew1 >> crew2 >> crew3
assert isinstance(procedure, Procedure)
assert len(procedure.crews) == 3
assert procedure.crews[0] == crew1
assert procedure.crews[1] == crew2
assert procedure.crews[2] == crew3
# Test error case: trying to shift with non-Crew object
with pytest.raises(TypeError):
crew1 >> "not a crew"