rshift working

This commit is contained in:
Brandon Hancock
2024-07-12 16:48:19 -04:00
parent f7680d6157
commit c5002eedd9
3 changed files with 70 additions and 12 deletions

View File

@@ -2,19 +2,19 @@ import asyncio
import json import json
import uuid import uuid
from concurrent.futures import Future from concurrent.futures import Future
from typing import Any, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from langchain_core.callbacks import BaseCallbackHandler from langchain_core.callbacks import BaseCallbackHandler
from pydantic import ( from pydantic import (
UUID4, UUID4,
BaseModel, BaseModel,
ConfigDict, ConfigDict,
Field, Field,
InstanceOf, InstanceOf,
Json, Json,
PrivateAttr, PrivateAttr,
field_validator, field_validator,
model_validator, model_validator,
) )
from pydantic_core import PydanticCustomError from pydantic_core import PydanticCustomError
@@ -34,8 +34,8 @@ from crewai.utilities import I18N, FileHandler, Logger, RPMController
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
from crewai.utilities.formatter import ( from crewai.utilities.formatter import (
aggregate_raw_outputs_from_task_outputs, aggregate_raw_outputs_from_task_outputs,
aggregate_raw_outputs_from_tasks, aggregate_raw_outputs_from_tasks,
) )
from crewai.utilities.training_handler import CrewTrainingHandler from crewai.utilities.training_handler import CrewTrainingHandler
@@ -44,6 +44,9 @@ try:
except ImportError: except ImportError:
agentops = None agentops = None
if TYPE_CHECKING:
from crewai.procedure.procedure import Procedure
class Crew(BaseModel): class Crew(BaseModel):
""" """
@@ -767,5 +770,17 @@ class Crew(BaseModel):
return total_usage_metrics return total_usage_metrics
def __rshift__(self, other: "Crew") -> "Procedure":
"""
Implements the >> operator to add another Crew to an existing Procedure.
"""
from crewai.procedure.procedure import Procedure
if not isinstance(other, Crew):
raise TypeError(
f"Unsupported operand type for >>: '{type(self).__name__}' and '{type(other).__name__}'"
)
return Procedure(crews=[self, other])
def __repr__(self): def __repr__(self):
return f"Crew(id={self.id}, process={self.process}, number_of_agents={len(self.agents)}, number_of_tasks={len(self.tasks)})" return f"Crew(id={self.id}, process={self.process}, number_of_agents={len(self.agents)}, number_of_tasks={len(self.tasks)})"

View File

@@ -36,3 +36,13 @@ class Procedure(BaseModel):
outputs = await asyncio.gather(*crew_kickoffs) outputs = await asyncio.gather(*crew_kickoffs)
return outputs return outputs
def __rshift__(self, other: Crew) -> "Procedure":
"""
Implements the >> operator to add another Crew to an existing Procedure.
"""
if not isinstance(other, Crew):
raise TypeError(
f"Unsupported operand type for >>: '{type(self).__name__}' and '{type(other).__name__}'"
)
return type(self)(crews=self.crews + [other])

View File

@@ -180,3 +180,36 @@ async def test_procedure_chaining(mock_crew_factory):
"completion_tokens": 75, "completion_tokens": 75,
} }
assert result[0].json_dict == {"key2": "value2"} assert result[0].json_dict == {"key2": "value2"}
def test_crew_rshift_operator():
"""
Test that the >> operator correctly creates a Procedure from two Crews.
"""
# Create minimal Crew instances
agent = Agent(role="Test Agent", goal="Test Goal", backstory="Test Backstory")
task = Task(agent=agent, description="Test Task", expected_output="Test Output")
crew1 = Crew(agents=[agent], tasks=[task])
crew2 = Crew(agents=[agent], tasks=[task])
crew3 = Crew(agents=[agent], tasks=[task])
# Test the >> operator
procedure = crew1 >> crew2
assert isinstance(procedure, Procedure)
assert len(procedure.crews) == 2
assert procedure.crews[0] == crew1
assert procedure.crews[1] == crew2
# Test chaining multiple crews
procedure = crew1 >> crew2 >> crew3
assert isinstance(procedure, Procedure)
assert len(procedure.crews) == 3
assert procedure.crews[0] == crew1
assert procedure.crews[1] == crew2
assert procedure.crews[2] == crew3
# Test error case: trying to shift with non-Crew object
with pytest.raises(TypeError):
crew1 >> "not a crew"