mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-23 11:22:38 +00:00
Compare commits
12 Commits
devin/1760
...
gl/chore/p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6081809c76 | ||
|
|
30a4b712a3 | ||
|
|
8465350f1d | ||
|
|
2a23bc604c | ||
|
|
714f8a8940 | ||
|
|
e029de2863 | ||
|
|
6492852a0c | ||
|
|
fecf7e9a83 | ||
|
|
6bc8818ae9 | ||
|
|
620df71763 | ||
|
|
7d6324dfa3 | ||
|
|
541eec0639 |
@@ -1,177 +0,0 @@
|
||||
"""
|
||||
FastAPI Streaming Integration Example for CrewAI
|
||||
|
||||
This example demonstrates how to integrate CrewAI with FastAPI to stream
|
||||
crew execution events in real-time using Server-Sent Events (SSE).
|
||||
|
||||
Installation:
|
||||
pip install crewai fastapi uvicorn
|
||||
|
||||
Usage:
|
||||
python fastapi_streaming_example.py
|
||||
|
||||
Then visit:
|
||||
http://localhost:8000/docs for the API documentation
|
||||
http://localhost:8000/stream?topic=AI to see streaming in action
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai import Agent, Crew, Task
|
||||
|
||||
app = FastAPI(title="CrewAI Streaming API")
|
||||
|
||||
|
||||
class ResearchRequest(BaseModel):
|
||||
topic: str
|
||||
|
||||
|
||||
def create_research_crew(topic: str) -> Crew:
|
||||
"""Create a research crew for the given topic."""
|
||||
researcher = Agent(
|
||||
role="Researcher",
|
||||
goal=f"Research and analyze information about {topic}",
|
||||
backstory="You're an expert researcher with deep knowledge in various fields.",
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description=f"Research and provide a comprehensive summary about {topic}",
|
||||
expected_output="A detailed summary with key insights",
|
||||
agent=researcher,
|
||||
)
|
||||
|
||||
return Crew(agents=[researcher], tasks=[task], verbose=True)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Root endpoint with API information."""
|
||||
return {
|
||||
"message": "CrewAI Streaming API",
|
||||
"endpoints": {
|
||||
"/stream": "GET - Stream crew execution events (query param: topic)",
|
||||
"/research": "POST - Execute crew and return final result",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@app.get("/stream")
|
||||
async def stream_crew_execution(topic: str = "artificial intelligence"):
|
||||
"""
|
||||
Stream crew execution events in real-time using Server-Sent Events.
|
||||
|
||||
Args:
|
||||
topic: The research topic (default: "artificial intelligence")
|
||||
|
||||
Returns:
|
||||
StreamingResponse with text/event-stream content type
|
||||
"""
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
"""Generate Server-Sent Events from crew execution."""
|
||||
crew = create_research_crew(topic)
|
||||
|
||||
try:
|
||||
for event in crew.kickoff_stream(inputs={"topic": topic}):
|
||||
event_data = json.dumps(event)
|
||||
yield f"data: {event_data}\n\n"
|
||||
|
||||
yield "data: {\"type\": \"done\"}\n\n"
|
||||
|
||||
except Exception as e:
|
||||
error_event = {"type": "error", "data": {"message": str(e)}}
|
||||
yield f"data: {json.dumps(error_event)}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@app.post("/research")
|
||||
async def research_topic(request: ResearchRequest):
|
||||
"""
|
||||
Execute crew research and return the final result.
|
||||
|
||||
Args:
|
||||
request: ResearchRequest with topic field
|
||||
|
||||
Returns:
|
||||
JSON response with the research result
|
||||
"""
|
||||
crew = create_research_crew(request.topic)
|
||||
|
||||
try:
|
||||
result = crew.kickoff(inputs={"topic": request.topic})
|
||||
return {
|
||||
"success": True,
|
||||
"topic": request.topic,
|
||||
"result": result.raw,
|
||||
"usage_metrics": (
|
||||
result.token_usage.model_dump() if result.token_usage else None
|
||||
),
|
||||
}
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
|
||||
@app.get("/stream-filtered")
|
||||
async def stream_filtered_events(
|
||||
topic: str = "artificial intelligence", event_types: str = "llm_stream_chunk"
|
||||
):
|
||||
"""
|
||||
Stream only specific event types.
|
||||
|
||||
Args:
|
||||
topic: The research topic
|
||||
event_types: Comma-separated list of event types to include
|
||||
|
||||
Returns:
|
||||
StreamingResponse with filtered events
|
||||
"""
|
||||
allowed_types = set(event_types.split(","))
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
crew = create_research_crew(topic)
|
||||
|
||||
try:
|
||||
for event in crew.kickoff_stream(inputs={"topic": topic}):
|
||||
if event["type"] in allowed_types:
|
||||
event_data = json.dumps(event)
|
||||
yield f"data: {event_data}\n\n"
|
||||
|
||||
yield "data: {\"type\": \"done\"}\n\n"
|
||||
|
||||
except Exception as e:
|
||||
error_event = {"type": "error", "data": {"message": str(e)}}
|
||||
yield f"data: {json.dumps(error_event)}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
print("Starting CrewAI Streaming API...")
|
||||
print("Visit http://localhost:8000/docs for API documentation")
|
||||
print("Try: http://localhost:8000/stream?topic=quantum%20computing")
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
@@ -766,118 +766,6 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self._task_output_handler.reset()
|
||||
return results
|
||||
|
||||
def kickoff_stream(self, inputs: dict[str, Any] | None = None):
|
||||
"""
|
||||
Stream crew execution events in real-time.
|
||||
|
||||
This method yields events as they occur during crew execution, making it
|
||||
easy to integrate with streaming frameworks like FastAPI's StreamingResponse.
|
||||
|
||||
Args:
|
||||
inputs: Optional dictionary of inputs for the crew execution
|
||||
|
||||
Yields:
|
||||
dict: Event dictionaries containing event type and data
|
||||
|
||||
Example:
|
||||
```python
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import StreamingResponse
|
||||
import json
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/stream")
|
||||
async def stream_crew():
|
||||
def event_generator():
|
||||
for event in crew.kickoff_stream(inputs={"topic": "AI"}):
|
||||
yield f"data: {json.dumps(event)}\\n\\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream"
|
||||
)
|
||||
```
|
||||
"""
|
||||
import queue
|
||||
import threading
|
||||
from crewai.events.base_events import BaseEvent
|
||||
|
||||
event_queue: queue.Queue = queue.Queue()
|
||||
completion_event = threading.Event()
|
||||
exception_holder = {"exception": None}
|
||||
|
||||
def event_handler(source: Any, event: BaseEvent):
|
||||
event_dict = {
|
||||
"type": event.type,
|
||||
"data": event.model_dump(exclude={"from_task", "from_agent"}),
|
||||
}
|
||||
event_queue.put(event_dict)
|
||||
|
||||
from crewai.events.types.crew_events import (
|
||||
CrewKickoffStartedEvent,
|
||||
CrewKickoffCompletedEvent,
|
||||
CrewKickoffFailedEvent,
|
||||
)
|
||||
from crewai.events.types.task_events import (
|
||||
TaskStartedEvent,
|
||||
TaskCompletedEvent,
|
||||
TaskFailedEvent,
|
||||
)
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionStartedEvent,
|
||||
AgentExecutionCompletedEvent,
|
||||
)
|
||||
from crewai.events.types.llm_events import (
|
||||
LLMStreamChunkEvent,
|
||||
LLMCallStartedEvent,
|
||||
LLMCallCompletedEvent,
|
||||
)
|
||||
from crewai.events.types.tool_usage_events import (
|
||||
ToolUsageStartedEvent,
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageErrorEvent,
|
||||
)
|
||||
|
||||
crewai_event_bus.register_handler(CrewKickoffStartedEvent, event_handler)
|
||||
crewai_event_bus.register_handler(CrewKickoffCompletedEvent, event_handler)
|
||||
crewai_event_bus.register_handler(CrewKickoffFailedEvent, event_handler)
|
||||
crewai_event_bus.register_handler(TaskStartedEvent, event_handler)
|
||||
crewai_event_bus.register_handler(TaskCompletedEvent, event_handler)
|
||||
crewai_event_bus.register_handler(TaskFailedEvent, event_handler)
|
||||
crewai_event_bus.register_handler(AgentExecutionStartedEvent, event_handler)
|
||||
crewai_event_bus.register_handler(AgentExecutionCompletedEvent, event_handler)
|
||||
crewai_event_bus.register_handler(LLMStreamChunkEvent, event_handler)
|
||||
crewai_event_bus.register_handler(LLMCallStartedEvent, event_handler)
|
||||
crewai_event_bus.register_handler(LLMCallCompletedEvent, event_handler)
|
||||
crewai_event_bus.register_handler(ToolUsageStartedEvent, event_handler)
|
||||
crewai_event_bus.register_handler(ToolUsageFinishedEvent, event_handler)
|
||||
crewai_event_bus.register_handler(ToolUsageErrorEvent, event_handler)
|
||||
|
||||
def run_kickoff():
|
||||
try:
|
||||
result = self.kickoff(inputs=inputs)
|
||||
event_queue.put({"type": "final_output", "data": {"output": result.raw}})
|
||||
except Exception as e:
|
||||
exception_holder["exception"] = e
|
||||
finally:
|
||||
completion_event.set()
|
||||
|
||||
thread = threading.Thread(target=run_kickoff, daemon=True)
|
||||
thread.start()
|
||||
|
||||
try:
|
||||
while not completion_event.is_set() or not event_queue.empty():
|
||||
event = event_queue.get(timeout=0.1) if not event_queue.empty() else None
|
||||
if event is not None:
|
||||
yield event
|
||||
|
||||
if exception_holder["exception"]:
|
||||
raise exception_holder["exception"]
|
||||
|
||||
finally:
|
||||
thread.join(timeout=1)
|
||||
|
||||
def _handle_crew_planning(self):
|
||||
"""Handles the Crew planning."""
|
||||
self._logger.log("info", "Planning the crew execution")
|
||||
|
||||
@@ -31,7 +31,7 @@ from crewai.flow.flow_visualizer import plot_flow
|
||||
from crewai.flow.persistence.base import FlowPersistence
|
||||
from crewai.flow.types import FlowExecutionData
|
||||
from crewai.flow.utils import get_possible_return_constants
|
||||
from crewai.utilities.printer import Printer, PrinterColor
|
||||
from crewai.utilities.printer import Printer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -105,7 +105,7 @@ def start(condition: str | dict | Callable | None = None) -> Callable:
|
||||
condition : Optional[Union[str, dict, Callable]], optional
|
||||
Defines when the start method should execute. Can be:
|
||||
- str: Name of a method that triggers this start
|
||||
- dict: Result from or_() or and_(), including nested conditions
|
||||
- dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers)
|
||||
- Callable: A method reference that triggers this start
|
||||
Default is None, meaning unconditional start.
|
||||
|
||||
@@ -140,18 +140,13 @@ def start(condition: str | dict | Callable | None = None) -> Callable:
|
||||
if isinstance(condition, str):
|
||||
func.__trigger_methods__ = [condition]
|
||||
func.__condition_type__ = "OR"
|
||||
elif isinstance(condition, dict) and "type" in condition:
|
||||
if "conditions" in condition:
|
||||
func.__trigger_condition__ = condition
|
||||
func.__trigger_methods__ = _extract_all_methods(condition)
|
||||
func.__condition_type__ = condition["type"]
|
||||
elif "methods" in condition:
|
||||
func.__trigger_methods__ = condition["methods"]
|
||||
func.__condition_type__ = condition["type"]
|
||||
else:
|
||||
raise ValueError(
|
||||
"Condition dict must contain 'conditions' or 'methods'"
|
||||
)
|
||||
elif (
|
||||
isinstance(condition, dict)
|
||||
and "type" in condition
|
||||
and "methods" in condition
|
||||
):
|
||||
func.__trigger_methods__ = condition["methods"]
|
||||
func.__condition_type__ = condition["type"]
|
||||
elif callable(condition) and hasattr(condition, "__name__"):
|
||||
func.__trigger_methods__ = [condition.__name__]
|
||||
func.__condition_type__ = "OR"
|
||||
@@ -177,7 +172,7 @@ def listen(condition: str | dict | Callable) -> Callable:
|
||||
condition : Union[str, dict, Callable]
|
||||
Specifies when the listener should execute. Can be:
|
||||
- str: Name of a method that triggers this listener
|
||||
- dict: Result from or_() or and_(), including nested conditions
|
||||
- dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers)
|
||||
- Callable: A method reference that triggers this listener
|
||||
|
||||
Returns
|
||||
@@ -205,18 +200,13 @@ def listen(condition: str | dict | Callable) -> Callable:
|
||||
if isinstance(condition, str):
|
||||
func.__trigger_methods__ = [condition]
|
||||
func.__condition_type__ = "OR"
|
||||
elif isinstance(condition, dict) and "type" in condition:
|
||||
if "conditions" in condition:
|
||||
func.__trigger_condition__ = condition
|
||||
func.__trigger_methods__ = _extract_all_methods(condition)
|
||||
func.__condition_type__ = condition["type"]
|
||||
elif "methods" in condition:
|
||||
func.__trigger_methods__ = condition["methods"]
|
||||
func.__condition_type__ = condition["type"]
|
||||
else:
|
||||
raise ValueError(
|
||||
"Condition dict must contain 'conditions' or 'methods'"
|
||||
)
|
||||
elif (
|
||||
isinstance(condition, dict)
|
||||
and "type" in condition
|
||||
and "methods" in condition
|
||||
):
|
||||
func.__trigger_methods__ = condition["methods"]
|
||||
func.__condition_type__ = condition["type"]
|
||||
elif callable(condition) and hasattr(condition, "__name__"):
|
||||
func.__trigger_methods__ = [condition.__name__]
|
||||
func.__condition_type__ = "OR"
|
||||
@@ -243,7 +233,7 @@ def router(condition: str | dict | Callable) -> Callable:
|
||||
condition : Union[str, dict, Callable]
|
||||
Specifies when the router should execute. Can be:
|
||||
- str: Name of a method that triggers this router
|
||||
- dict: Result from or_() or and_(), including nested conditions
|
||||
- dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers)
|
||||
- Callable: A method reference that triggers this router
|
||||
|
||||
Returns
|
||||
@@ -276,18 +266,13 @@ def router(condition: str | dict | Callable) -> Callable:
|
||||
if isinstance(condition, str):
|
||||
func.__trigger_methods__ = [condition]
|
||||
func.__condition_type__ = "OR"
|
||||
elif isinstance(condition, dict) and "type" in condition:
|
||||
if "conditions" in condition:
|
||||
func.__trigger_condition__ = condition
|
||||
func.__trigger_methods__ = _extract_all_methods(condition)
|
||||
func.__condition_type__ = condition["type"]
|
||||
elif "methods" in condition:
|
||||
func.__trigger_methods__ = condition["methods"]
|
||||
func.__condition_type__ = condition["type"]
|
||||
else:
|
||||
raise ValueError(
|
||||
"Condition dict must contain 'conditions' or 'methods'"
|
||||
)
|
||||
elif (
|
||||
isinstance(condition, dict)
|
||||
and "type" in condition
|
||||
and "methods" in condition
|
||||
):
|
||||
func.__trigger_methods__ = condition["methods"]
|
||||
func.__condition_type__ = condition["type"]
|
||||
elif callable(condition) and hasattr(condition, "__name__"):
|
||||
func.__trigger_methods__ = [condition.__name__]
|
||||
func.__condition_type__ = "OR"
|
||||
@@ -313,15 +298,14 @@ def or_(*conditions: str | dict | Callable) -> dict:
|
||||
*conditions : Union[str, dict, Callable]
|
||||
Variable number of conditions that can be:
|
||||
- str: Method names
|
||||
- dict: Existing condition dictionaries (nested conditions)
|
||||
- dict: Existing condition dictionaries
|
||||
- Callable: Method references
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
A condition dictionary with format:
|
||||
{"type": "OR", "conditions": list_of_conditions}
|
||||
where each condition can be a string (method name) or a nested dict
|
||||
{"type": "OR", "methods": list_of_method_names}
|
||||
|
||||
Raises
|
||||
------
|
||||
@@ -333,22 +317,18 @@ def or_(*conditions: str | dict | Callable) -> dict:
|
||||
>>> @listen(or_("success", "timeout"))
|
||||
>>> def handle_completion(self):
|
||||
... pass
|
||||
|
||||
>>> @listen(or_(and_("step1", "step2"), "step3"))
|
||||
>>> def handle_nested(self):
|
||||
... pass
|
||||
"""
|
||||
processed_conditions: list[str | dict[str, Any]] = []
|
||||
methods = []
|
||||
for condition in conditions:
|
||||
if isinstance(condition, dict):
|
||||
processed_conditions.append(condition)
|
||||
if isinstance(condition, dict) and "methods" in condition:
|
||||
methods.extend(condition["methods"])
|
||||
elif isinstance(condition, str):
|
||||
processed_conditions.append(condition)
|
||||
methods.append(condition)
|
||||
elif callable(condition):
|
||||
processed_conditions.append(getattr(condition, "__name__", repr(condition)))
|
||||
methods.append(getattr(condition, "__name__", repr(condition)))
|
||||
else:
|
||||
raise ValueError("Invalid condition in or_()")
|
||||
return {"type": "OR", "conditions": processed_conditions}
|
||||
return {"type": "OR", "methods": methods}
|
||||
|
||||
|
||||
def and_(*conditions: str | dict | Callable) -> dict:
|
||||
@@ -364,15 +344,14 @@ def and_(*conditions: str | dict | Callable) -> dict:
|
||||
*conditions : Union[str, dict, Callable]
|
||||
Variable number of conditions that can be:
|
||||
- str: Method names
|
||||
- dict: Existing condition dictionaries (nested conditions)
|
||||
- dict: Existing condition dictionaries
|
||||
- Callable: Method references
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
A condition dictionary with format:
|
||||
{"type": "AND", "conditions": list_of_conditions}
|
||||
where each condition can be a string (method name) or a nested dict
|
||||
{"type": "AND", "methods": list_of_method_names}
|
||||
|
||||
Raises
|
||||
------
|
||||
@@ -384,69 +363,18 @@ def and_(*conditions: str | dict | Callable) -> dict:
|
||||
>>> @listen(and_("validated", "processed"))
|
||||
>>> def handle_complete_data(self):
|
||||
... pass
|
||||
|
||||
>>> @listen(and_(or_("step1", "step2"), "step3"))
|
||||
>>> def handle_nested(self):
|
||||
... pass
|
||||
"""
|
||||
processed_conditions: list[str | dict[str, Any]] = []
|
||||
methods = []
|
||||
for condition in conditions:
|
||||
if isinstance(condition, dict):
|
||||
processed_conditions.append(condition)
|
||||
if isinstance(condition, dict) and "methods" in condition:
|
||||
methods.extend(condition["methods"])
|
||||
elif isinstance(condition, str):
|
||||
processed_conditions.append(condition)
|
||||
methods.append(condition)
|
||||
elif callable(condition):
|
||||
processed_conditions.append(getattr(condition, "__name__", repr(condition)))
|
||||
methods.append(getattr(condition, "__name__", repr(condition)))
|
||||
else:
|
||||
raise ValueError("Invalid condition in and_()")
|
||||
return {"type": "AND", "conditions": processed_conditions}
|
||||
|
||||
|
||||
def _normalize_condition(condition: str | dict | list) -> dict:
|
||||
"""Normalize a condition to standard format with 'conditions' key.
|
||||
|
||||
Args:
|
||||
condition: Can be a string (method name), dict (condition), or list
|
||||
|
||||
Returns:
|
||||
Normalized dict with 'type' and 'conditions' keys
|
||||
"""
|
||||
if isinstance(condition, str):
|
||||
return {"type": "OR", "conditions": [condition]}
|
||||
if isinstance(condition, dict):
|
||||
if "conditions" in condition:
|
||||
return condition
|
||||
if "methods" in condition:
|
||||
return {"type": condition["type"], "conditions": condition["methods"]}
|
||||
return condition
|
||||
if isinstance(condition, list):
|
||||
return {"type": "OR", "conditions": condition}
|
||||
return {"type": "OR", "conditions": [condition]}
|
||||
|
||||
|
||||
def _extract_all_methods(condition: str | dict | list) -> list[str]:
|
||||
"""Extract all method names from a condition (including nested).
|
||||
|
||||
Args:
|
||||
condition: Can be a string, dict, or list
|
||||
|
||||
Returns:
|
||||
List of all method names in the condition tree
|
||||
"""
|
||||
if isinstance(condition, str):
|
||||
return [condition]
|
||||
if isinstance(condition, dict):
|
||||
normalized = _normalize_condition(condition)
|
||||
methods = []
|
||||
for sub_cond in normalized.get("conditions", []):
|
||||
methods.extend(_extract_all_methods(sub_cond))
|
||||
return methods
|
||||
if isinstance(condition, list):
|
||||
methods = []
|
||||
for item in condition:
|
||||
methods.extend(_extract_all_methods(item))
|
||||
return methods
|
||||
return []
|
||||
return {"type": "AND", "methods": methods}
|
||||
|
||||
|
||||
class FlowMeta(type):
|
||||
@@ -474,10 +402,7 @@ class FlowMeta(type):
|
||||
if hasattr(attr_value, "__trigger_methods__"):
|
||||
methods = attr_value.__trigger_methods__
|
||||
condition_type = getattr(attr_value, "__condition_type__", "OR")
|
||||
if hasattr(attr_value, "__trigger_condition__"):
|
||||
listeners[attr_name] = attr_value.__trigger_condition__
|
||||
else:
|
||||
listeners[attr_name] = (condition_type, methods)
|
||||
listeners[attr_name] = (condition_type, methods)
|
||||
|
||||
if (
|
||||
hasattr(attr_value, "__is_router__")
|
||||
@@ -897,7 +822,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
# Clear completed methods and outputs for a fresh start
|
||||
self._completed_methods.clear()
|
||||
self._method_outputs.clear()
|
||||
self._pending_and_listeners.clear()
|
||||
else:
|
||||
# We're restoring from persistence, set the flag
|
||||
self._is_execution_resuming = True
|
||||
@@ -1162,16 +1086,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
for method_name in self._start_methods:
|
||||
# Check if this start method is triggered by the current trigger
|
||||
if method_name in self._listeners:
|
||||
condition_data = self._listeners[method_name]
|
||||
should_trigger = False
|
||||
if isinstance(condition_data, tuple):
|
||||
_, trigger_methods = condition_data
|
||||
should_trigger = current_trigger in trigger_methods
|
||||
elif isinstance(condition_data, dict):
|
||||
all_methods = _extract_all_methods(condition_data)
|
||||
should_trigger = current_trigger in all_methods
|
||||
|
||||
if should_trigger:
|
||||
condition_type, trigger_methods = self._listeners[
|
||||
method_name
|
||||
]
|
||||
if current_trigger in trigger_methods:
|
||||
# Only execute if this is a cycle (method was already completed)
|
||||
if method_name in self._completed_methods:
|
||||
# For router-triggered start methods in cycles, temporarily clear resumption flag
|
||||
@@ -1181,51 +1099,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
await self._execute_start_method(method_name)
|
||||
self._is_execution_resuming = was_resuming
|
||||
|
||||
def _evaluate_condition(
|
||||
self, condition: str | dict, trigger_method: str, listener_name: str
|
||||
) -> bool:
|
||||
"""Recursively evaluate a condition (simple or nested).
|
||||
|
||||
Args:
|
||||
condition: Can be a string (method name) or dict (nested condition)
|
||||
trigger_method: The method that just completed
|
||||
listener_name: Name of the listener being evaluated
|
||||
|
||||
Returns:
|
||||
True if the condition is satisfied, False otherwise
|
||||
"""
|
||||
if isinstance(condition, str):
|
||||
return condition == trigger_method
|
||||
|
||||
if isinstance(condition, dict):
|
||||
normalized = _normalize_condition(condition)
|
||||
cond_type = normalized.get("type", "OR")
|
||||
sub_conditions = normalized.get("conditions", [])
|
||||
|
||||
if cond_type == "OR":
|
||||
return any(
|
||||
self._evaluate_condition(sub_cond, trigger_method, listener_name)
|
||||
for sub_cond in sub_conditions
|
||||
)
|
||||
|
||||
if cond_type == "AND":
|
||||
pending_key = f"{listener_name}:{id(condition)}"
|
||||
|
||||
if pending_key not in self._pending_and_listeners:
|
||||
all_methods = set(_extract_all_methods(condition))
|
||||
self._pending_and_listeners[pending_key] = all_methods
|
||||
|
||||
if trigger_method in self._pending_and_listeners[pending_key]:
|
||||
self._pending_and_listeners[pending_key].discard(trigger_method)
|
||||
|
||||
if not self._pending_and_listeners[pending_key]:
|
||||
self._pending_and_listeners.pop(pending_key, None)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
return False
|
||||
|
||||
def _find_triggered_methods(
|
||||
self, trigger_method: str, router_only: bool
|
||||
) -> list[str]:
|
||||
@@ -1233,7 +1106,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
Finds all methods that should be triggered based on conditions.
|
||||
|
||||
This internal method evaluates both OR and AND conditions to determine
|
||||
which methods should be executed next in the flow. Supports nested conditions.
|
||||
which methods should be executed next in the flow.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -1250,13 +1123,14 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
|
||||
Notes
|
||||
-----
|
||||
- Handles both OR and AND conditions, including nested combinations
|
||||
- Handles both OR and AND conditions:
|
||||
* OR: Triggers if any condition is met
|
||||
* AND: Triggers only when all conditions are met
|
||||
- Maintains state for AND conditions using _pending_and_listeners
|
||||
- Separates router and normal listener evaluation
|
||||
"""
|
||||
triggered = []
|
||||
|
||||
for listener_name, condition_data in self._listeners.items():
|
||||
for listener_name, (condition_type, methods) in self._listeners.items():
|
||||
is_router = listener_name in self._routers
|
||||
|
||||
if router_only != is_router:
|
||||
@@ -1265,29 +1139,23 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
if not router_only and listener_name in self._start_methods:
|
||||
continue
|
||||
|
||||
if isinstance(condition_data, tuple):
|
||||
condition_type, methods = condition_data
|
||||
|
||||
if condition_type == "OR":
|
||||
if trigger_method in methods:
|
||||
triggered.append(listener_name)
|
||||
elif condition_type == "AND":
|
||||
if listener_name not in self._pending_and_listeners:
|
||||
self._pending_and_listeners[listener_name] = set(methods)
|
||||
if trigger_method in self._pending_and_listeners[listener_name]:
|
||||
self._pending_and_listeners[listener_name].discard(
|
||||
trigger_method
|
||||
)
|
||||
|
||||
if not self._pending_and_listeners[listener_name]:
|
||||
triggered.append(listener_name)
|
||||
self._pending_and_listeners.pop(listener_name, None)
|
||||
|
||||
elif isinstance(condition_data, dict):
|
||||
if self._evaluate_condition(
|
||||
condition_data, trigger_method, listener_name
|
||||
):
|
||||
if condition_type == "OR":
|
||||
# If the trigger_method matches any in methods, run this
|
||||
if trigger_method in methods:
|
||||
triggered.append(listener_name)
|
||||
elif condition_type == "AND":
|
||||
# Initialize pending methods for this listener if not already done
|
||||
if listener_name not in self._pending_and_listeners:
|
||||
self._pending_and_listeners[listener_name] = set(methods)
|
||||
# Remove the trigger method from pending methods
|
||||
if trigger_method in self._pending_and_listeners[listener_name]:
|
||||
self._pending_and_listeners[listener_name].discard(trigger_method)
|
||||
|
||||
if not self._pending_and_listeners[listener_name]:
|
||||
# All required methods have been executed
|
||||
triggered.append(listener_name)
|
||||
# Reset pending methods for this listener
|
||||
self._pending_and_listeners.pop(listener_name, None)
|
||||
|
||||
return triggered
|
||||
|
||||
@@ -1350,7 +1218,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
raise
|
||||
|
||||
def _log_flow_event(
|
||||
self, message: str, color: PrinterColor | None = "yellow", level: str = "info"
|
||||
self, message: str, color: str = "yellow", level: str = "info"
|
||||
) -> None:
|
||||
"""Centralized logging method for flow events.
|
||||
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from .annotations import (
|
||||
"""Project package for CrewAI."""
|
||||
|
||||
from crewai.project.annotations import (
|
||||
after_kickoff,
|
||||
agent,
|
||||
before_kickoff,
|
||||
@@ -11,19 +13,19 @@ from .annotations import (
|
||||
task,
|
||||
tool,
|
||||
)
|
||||
from .crew_base import CrewBase
|
||||
from crewai.project.crew_base import CrewBase
|
||||
|
||||
__all__ = [
|
||||
"CrewBase",
|
||||
"after_kickoff",
|
||||
"agent",
|
||||
"before_kickoff",
|
||||
"cache_handler",
|
||||
"callback",
|
||||
"crew",
|
||||
"task",
|
||||
"llm",
|
||||
"output_json",
|
||||
"output_pydantic",
|
||||
"task",
|
||||
"tool",
|
||||
"callback",
|
||||
"CrewBase",
|
||||
"llm",
|
||||
"cache_handler",
|
||||
"before_kickoff",
|
||||
"after_kickoff",
|
||||
]
|
||||
|
||||
@@ -1,97 +1,192 @@
|
||||
from functools import wraps
|
||||
from typing import Callable
|
||||
|
||||
from crewai import Crew
|
||||
from crewai.project.utils import memoize
|
||||
|
||||
"""Decorators for defining crew components and their behaviors."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
def before_kickoff(func):
|
||||
"""Marks a method to execute before crew kickoff."""
|
||||
func.is_before_kickoff = True
|
||||
return func
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Concatenate, ParamSpec, TypeVar
|
||||
|
||||
from crewai.project.utils import memoize
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai import Agent, Crew, Task
|
||||
|
||||
from crewai.project.wrappers import (
|
||||
AfterKickoffMethod,
|
||||
AgentMethod,
|
||||
BeforeKickoffMethod,
|
||||
CacheHandlerMethod,
|
||||
CallbackMethod,
|
||||
CrewInstance,
|
||||
LLMMethod,
|
||||
OutputJsonClass,
|
||||
OutputPydanticClass,
|
||||
TaskMethod,
|
||||
TaskResultT,
|
||||
ToolMethod,
|
||||
)
|
||||
|
||||
P = ParamSpec("P")
|
||||
P2 = ParamSpec("P2")
|
||||
R = TypeVar("R")
|
||||
R2 = TypeVar("R2")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def after_kickoff(func):
|
||||
"""Marks a method to execute after crew kickoff."""
|
||||
func.is_after_kickoff = True
|
||||
return func
|
||||
def before_kickoff(meth: Callable[P, R]) -> BeforeKickoffMethod[P, R]:
|
||||
"""Marks a method to execute before crew kickoff.
|
||||
|
||||
Args:
|
||||
meth: The method to mark.
|
||||
|
||||
Returns:
|
||||
A wrapped method marked for before kickoff execution.
|
||||
"""
|
||||
return BeforeKickoffMethod(meth)
|
||||
|
||||
|
||||
def task(func):
|
||||
"""Marks a method as a crew task."""
|
||||
func.is_task = True
|
||||
def after_kickoff(meth: Callable[P, R]) -> AfterKickoffMethod[P, R]:
|
||||
"""Marks a method to execute after crew kickoff.
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
result = func(*args, **kwargs)
|
||||
if not result.name:
|
||||
result.name = func.__name__
|
||||
return result
|
||||
Args:
|
||||
meth: The method to mark.
|
||||
|
||||
return memoize(wrapper)
|
||||
Returns:
|
||||
A wrapped method marked for after kickoff execution.
|
||||
"""
|
||||
return AfterKickoffMethod(meth)
|
||||
|
||||
|
||||
def agent(func):
|
||||
"""Marks a method as a crew agent."""
|
||||
func.is_agent = True
|
||||
func = memoize(func)
|
||||
return func
|
||||
def task(meth: Callable[P, TaskResultT]) -> TaskMethod[P, TaskResultT]:
|
||||
"""Marks a method as a crew task.
|
||||
|
||||
Args:
|
||||
meth: The method to mark.
|
||||
|
||||
Returns:
|
||||
A wrapped method marked as a task with memoization.
|
||||
"""
|
||||
return TaskMethod(memoize(meth))
|
||||
|
||||
|
||||
def llm(func):
|
||||
"""Marks a method as an LLM provider."""
|
||||
func.is_llm = True
|
||||
func = memoize(func)
|
||||
return func
|
||||
def agent(meth: Callable[P, R]) -> AgentMethod[P, R]:
|
||||
"""Marks a method as a crew agent.
|
||||
|
||||
Args:
|
||||
meth: The method to mark.
|
||||
|
||||
Returns:
|
||||
A wrapped method marked as an agent with memoization.
|
||||
"""
|
||||
return AgentMethod(memoize(meth))
|
||||
|
||||
|
||||
def output_json(cls):
|
||||
"""Marks a class as JSON output format."""
|
||||
cls.is_output_json = True
|
||||
return cls
|
||||
def llm(meth: Callable[P, R]) -> LLMMethod[P, R]:
|
||||
"""Marks a method as an LLM provider.
|
||||
|
||||
Args:
|
||||
meth: The method to mark.
|
||||
|
||||
Returns:
|
||||
A wrapped method marked as an LLM provider with memoization.
|
||||
"""
|
||||
return LLMMethod(memoize(meth))
|
||||
|
||||
|
||||
def output_pydantic(cls):
|
||||
"""Marks a class as Pydantic output format."""
|
||||
cls.is_output_pydantic = True
|
||||
return cls
|
||||
def output_json(cls: type[T]) -> OutputJsonClass[T]:
|
||||
"""Marks a class as JSON output format.
|
||||
|
||||
Args:
|
||||
cls: The class to mark.
|
||||
|
||||
Returns:
|
||||
A wrapped class marked as JSON output format.
|
||||
"""
|
||||
return OutputJsonClass(cls)
|
||||
|
||||
|
||||
def tool(func):
|
||||
"""Marks a method as a crew tool."""
|
||||
func.is_tool = True
|
||||
return memoize(func)
|
||||
def output_pydantic(cls: type[T]) -> OutputPydanticClass[T]:
|
||||
"""Marks a class as Pydantic output format.
|
||||
|
||||
Args:
|
||||
cls: The class to mark.
|
||||
|
||||
Returns:
|
||||
A wrapped class marked as Pydantic output format.
|
||||
"""
|
||||
return OutputPydanticClass(cls)
|
||||
|
||||
|
||||
def callback(func):
|
||||
"""Marks a method as a crew callback."""
|
||||
func.is_callback = True
|
||||
return memoize(func)
|
||||
def tool(meth: Callable[P, R]) -> ToolMethod[P, R]:
|
||||
"""Marks a method as a crew tool.
|
||||
|
||||
Args:
|
||||
meth: The method to mark.
|
||||
|
||||
Returns:
|
||||
A wrapped method marked as a tool with memoization.
|
||||
"""
|
||||
return ToolMethod(memoize(meth))
|
||||
|
||||
|
||||
def cache_handler(func):
|
||||
"""Marks a method as a cache handler."""
|
||||
func.is_cache_handler = True
|
||||
return memoize(func)
|
||||
def callback(meth: Callable[P, R]) -> CallbackMethod[P, R]:
|
||||
"""Marks a method as a crew callback.
|
||||
|
||||
Args:
|
||||
meth: The method to mark.
|
||||
|
||||
Returns:
|
||||
A wrapped method marked as a callback with memoization.
|
||||
"""
|
||||
return CallbackMethod(memoize(meth))
|
||||
|
||||
|
||||
def crew(func) -> Callable[..., Crew]:
|
||||
"""Marks a method as the main crew execution point."""
|
||||
def cache_handler(meth: Callable[P, R]) -> CacheHandlerMethod[P, R]:
|
||||
"""Marks a method as a cache handler.
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs) -> Crew:
|
||||
instantiated_tasks = []
|
||||
instantiated_agents = []
|
||||
agent_roles = set()
|
||||
Args:
|
||||
meth: The method to mark.
|
||||
|
||||
Returns:
|
||||
A wrapped method marked as a cache handler with memoization.
|
||||
"""
|
||||
return CacheHandlerMethod(memoize(meth))
|
||||
|
||||
|
||||
def crew(
|
||||
meth: Callable[Concatenate[CrewInstance, P], Crew],
|
||||
) -> Callable[Concatenate[CrewInstance, P], Crew]:
|
||||
"""Marks a method as the main crew execution point.
|
||||
|
||||
Args:
|
||||
meth: The method to mark as crew execution point.
|
||||
|
||||
Returns:
|
||||
A wrapped method that instantiates tasks and agents before execution.
|
||||
"""
|
||||
|
||||
@wraps(meth)
|
||||
def wrapper(self: CrewInstance, *args: P.args, **kwargs: P.kwargs) -> Crew:
|
||||
"""Wrapper that sets up crew before calling the decorated method.
|
||||
|
||||
Args:
|
||||
self: The crew class instance.
|
||||
*args: Additional positional arguments.
|
||||
**kwargs: Keyword arguments to pass to the method.
|
||||
|
||||
Returns:
|
||||
The configured Crew instance with callbacks attached.
|
||||
"""
|
||||
instantiated_tasks: list[Task] = []
|
||||
instantiated_agents: list[Agent] = []
|
||||
agent_roles: set[str] = set()
|
||||
|
||||
# Use the preserved task and agent information
|
||||
tasks = self._original_tasks.items()
|
||||
agents = self._original_agents.items()
|
||||
tasks = self.__crew_metadata__["original_tasks"].items()
|
||||
agents = self.__crew_metadata__["original_agents"].items()
|
||||
|
||||
# Instantiate tasks in order
|
||||
for task_name, task_method in tasks:
|
||||
for _, task_method in tasks:
|
||||
task_instance = task_method(self)
|
||||
instantiated_tasks.append(task_instance)
|
||||
agent_instance = getattr(task_instance, "agent", None)
|
||||
@@ -100,7 +195,7 @@ def crew(func) -> Callable[..., Crew]:
|
||||
agent_roles.add(agent_instance.role)
|
||||
|
||||
# Instantiate agents not included by tasks
|
||||
for agent_name, agent_method in agents:
|
||||
for _, agent_method in agents:
|
||||
agent_instance = agent_method(self)
|
||||
if agent_instance.role not in agent_roles:
|
||||
instantiated_agents.append(agent_instance)
|
||||
@@ -109,19 +204,44 @@ def crew(func) -> Callable[..., Crew]:
|
||||
self.agents = instantiated_agents
|
||||
self.tasks = instantiated_tasks
|
||||
|
||||
crew = func(self, *args, **kwargs)
|
||||
crew_instance = meth(self, *args, **kwargs)
|
||||
|
||||
def callback_wrapper(callback, instance):
|
||||
def wrapper(*args, **kwargs):
|
||||
return callback(instance, *args, **kwargs)
|
||||
def callback_wrapper(
|
||||
hook: Callable[Concatenate[CrewInstance, P2], R2], instance: CrewInstance
|
||||
) -> Callable[P2, R2]:
|
||||
"""Bind a hook callback to an instance.
|
||||
|
||||
return wrapper
|
||||
Args:
|
||||
hook: The callback hook to bind.
|
||||
instance: The instance to bind to.
|
||||
|
||||
for _, callback in self._before_kickoff.items():
|
||||
crew.before_kickoff_callbacks.append(callback_wrapper(callback, self))
|
||||
for _, callback in self._after_kickoff.items():
|
||||
crew.after_kickoff_callbacks.append(callback_wrapper(callback, self))
|
||||
Returns:
|
||||
A bound callback function.
|
||||
"""
|
||||
|
||||
return crew
|
||||
def bound_callback(*cb_args: P2.args, **cb_kwargs: P2.kwargs) -> R2:
|
||||
"""Execute the bound callback.
|
||||
|
||||
Args:
|
||||
*cb_args: Positional arguments for the callback.
|
||||
**cb_kwargs: Keyword arguments for the callback.
|
||||
|
||||
Returns:
|
||||
The result of the callback execution.
|
||||
"""
|
||||
return hook(instance, *cb_args, **cb_kwargs)
|
||||
|
||||
return bound_callback
|
||||
|
||||
for hook_callback in self.__crew_metadata__["before_kickoff"].values():
|
||||
crew_instance.before_kickoff_callbacks.append(
|
||||
callback_wrapper(hook_callback, self)
|
||||
)
|
||||
for hook_callback in self.__crew_metadata__["after_kickoff"].values():
|
||||
crew_instance.after_kickoff_callbacks.append(
|
||||
callback_wrapper(hook_callback, self)
|
||||
)
|
||||
|
||||
return crew_instance
|
||||
|
||||
return memoize(wrapper)
|
||||
|
||||
@@ -1,298 +1,631 @@
|
||||
"""Base metaclass for creating crew classes with configuration and method management."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeVar, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, TypedDict, TypeGuard, TypeVar, cast
|
||||
|
||||
import yaml
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from crewai.project.wrappers import CrewClass, CrewMetadata
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai import Agent, Task
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.project.wrappers import (
|
||||
CrewInstance,
|
||||
OutputJsonClass,
|
||||
OutputPydanticClass,
|
||||
)
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
|
||||
|
||||
class AgentConfig(TypedDict, total=False):
|
||||
"""Type definition for agent configuration dictionary.
|
||||
|
||||
All fields are optional as they come from YAML configuration files.
|
||||
Fields can be either string references (from YAML) or actual instances (after processing).
|
||||
"""
|
||||
|
||||
# Core agent attributes (from BaseAgent)
|
||||
role: str
|
||||
goal: str
|
||||
backstory: str
|
||||
cache: bool
|
||||
verbose: bool
|
||||
max_rpm: int
|
||||
allow_delegation: bool
|
||||
max_iter: int
|
||||
max_tokens: int
|
||||
callbacks: list[str]
|
||||
|
||||
# LLM configuration
|
||||
llm: str
|
||||
function_calling_llm: str
|
||||
use_system_prompt: bool
|
||||
|
||||
# Template configuration
|
||||
system_template: str
|
||||
prompt_template: str
|
||||
response_template: str
|
||||
|
||||
# Tools and handlers (can be string references or instances)
|
||||
tools: list[str] | list[BaseTool]
|
||||
step_callback: str
|
||||
cache_handler: str | CacheHandler
|
||||
|
||||
# Code execution
|
||||
allow_code_execution: bool
|
||||
code_execution_mode: Literal["safe", "unsafe"]
|
||||
|
||||
# Context and performance
|
||||
respect_context_window: bool
|
||||
max_retry_limit: int
|
||||
|
||||
# Multimodal and reasoning
|
||||
multimodal: bool
|
||||
reasoning: bool
|
||||
max_reasoning_attempts: int
|
||||
|
||||
# Knowledge configuration
|
||||
knowledge_sources: list[str] | list[Any]
|
||||
knowledge_storage: str | Any
|
||||
knowledge_config: dict[str, Any]
|
||||
embedder: dict[str, Any]
|
||||
agent_knowledge_context: str
|
||||
crew_knowledge_context: str
|
||||
knowledge_search_query: str
|
||||
|
||||
# Misc configuration
|
||||
inject_date: bool
|
||||
date_format: str
|
||||
from_repository: str
|
||||
guardrail: Callable[[Any], tuple[bool, Any]] | str
|
||||
guardrail_max_retries: int
|
||||
|
||||
|
||||
class TaskConfig(TypedDict, total=False):
|
||||
"""Type definition for task configuration dictionary.
|
||||
|
||||
All fields are optional as they come from YAML configuration files.
|
||||
Fields can be either string references (from YAML) or actual instances (after processing).
|
||||
"""
|
||||
|
||||
# Core task attributes
|
||||
name: str
|
||||
description: str
|
||||
expected_output: str
|
||||
|
||||
# Agent and context
|
||||
agent: str
|
||||
context: list[str]
|
||||
|
||||
# Tools and callbacks (can be string references or instances)
|
||||
tools: list[str] | list[BaseTool]
|
||||
callback: str
|
||||
callbacks: list[str]
|
||||
|
||||
# Output configuration
|
||||
output_json: str
|
||||
output_pydantic: str
|
||||
output_file: str
|
||||
create_directory: bool
|
||||
|
||||
# Execution configuration
|
||||
async_execution: bool
|
||||
human_input: bool
|
||||
markdown: bool
|
||||
|
||||
# Guardrail configuration
|
||||
guardrail: Callable[[TaskOutput], tuple[bool, Any]] | str
|
||||
guardrail_max_retries: int
|
||||
|
||||
# Misc configuration
|
||||
allow_crewai_trigger_context: bool
|
||||
|
||||
|
||||
load_dotenv()
|
||||
|
||||
T = TypeVar("T", bound=type)
|
||||
|
||||
"""Base decorator for creating crew classes with configuration and function management."""
|
||||
CallableT = TypeVar("CallableT", bound=Callable[..., Any])
|
||||
|
||||
|
||||
def CrewBase(cls: T) -> T: # noqa: N802
|
||||
"""Wraps a class with crew functionality and configuration management."""
|
||||
def _set_base_directory(cls: type[CrewClass]) -> None:
|
||||
"""Set the base directory for the crew class.
|
||||
|
||||
class WrappedClass(cls): # type: ignore
|
||||
is_crew_class: bool = True # type: ignore
|
||||
Args:
|
||||
cls: Crew class to configure.
|
||||
"""
|
||||
try:
|
||||
cls.base_directory = Path(inspect.getfile(cls)).parent
|
||||
except (TypeError, OSError):
|
||||
cls.base_directory = Path.cwd()
|
||||
|
||||
# Get the directory of the class being decorated
|
||||
base_directory = Path(inspect.getfile(cls)).parent
|
||||
|
||||
original_agents_config_path = getattr(
|
||||
cls, "agents_config", "config/agents.yaml"
|
||||
def _set_config_paths(cls: type[CrewClass]) -> None:
|
||||
"""Set the configuration file paths for the crew class.
|
||||
|
||||
Args:
|
||||
cls: Crew class to configure.
|
||||
"""
|
||||
cls.original_agents_config_path = getattr(
|
||||
cls, "agents_config", "config/agents.yaml"
|
||||
)
|
||||
cls.original_tasks_config_path = getattr(cls, "tasks_config", "config/tasks.yaml")
|
||||
|
||||
|
||||
def _set_mcp_params(cls: type[CrewClass]) -> None:
|
||||
"""Set the MCP server parameters for the crew class.
|
||||
|
||||
Args:
|
||||
cls: Crew class to configure.
|
||||
"""
|
||||
cls.mcp_server_params = getattr(cls, "mcp_server_params", None)
|
||||
cls.mcp_connect_timeout = getattr(cls, "mcp_connect_timeout", 30)
|
||||
|
||||
|
||||
def _is_string_list(value: list[str] | list[BaseTool]) -> TypeGuard[list[str]]:
|
||||
"""Type guard to check if list contains strings rather than BaseTool instances.
|
||||
|
||||
Args:
|
||||
value: List that may contain strings or BaseTool instances.
|
||||
|
||||
Returns:
|
||||
True if all elements are strings, False otherwise.
|
||||
"""
|
||||
return all(isinstance(item, str) for item in value)
|
||||
|
||||
|
||||
def _is_string_value(value: str | CacheHandler) -> TypeGuard[str]:
|
||||
"""Type guard to check if value is a string rather than a CacheHandler instance.
|
||||
|
||||
Args:
|
||||
value: Value that may be a string or CacheHandler instance.
|
||||
|
||||
Returns:
|
||||
True if value is a string, False otherwise.
|
||||
"""
|
||||
return isinstance(value, str)
|
||||
|
||||
|
||||
class CrewBaseMeta(type):
|
||||
"""Metaclass that adds crew functionality to classes."""
|
||||
|
||||
def __new__(
|
||||
mcs,
|
||||
name: str,
|
||||
bases: tuple[type, ...],
|
||||
namespace: dict[str, Any],
|
||||
**kwargs: Any,
|
||||
) -> type[CrewClass]:
|
||||
"""Create crew class with configuration and method injection.
|
||||
|
||||
Args:
|
||||
name: Class name.
|
||||
bases: Base classes.
|
||||
namespace: Class namespace dictionary.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
New crew class with injected methods and attributes.
|
||||
"""
|
||||
cls = cast(
|
||||
type[CrewClass], cast(object, super().__new__(mcs, name, bases, namespace))
|
||||
)
|
||||
original_tasks_config_path = getattr(cls, "tasks_config", "config/tasks.yaml")
|
||||
|
||||
mcp_server_params: Any = getattr(cls, "mcp_server_params", None)
|
||||
mcp_connect_timeout: int = getattr(cls, "mcp_connect_timeout", 30)
|
||||
cls.is_crew_class = True
|
||||
cls._crew_name = name
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.load_configurations()
|
||||
self.map_all_agent_variables()
|
||||
self.map_all_task_variables()
|
||||
# Preserve all decorated functions
|
||||
self._original_functions = {
|
||||
name: method
|
||||
for name, method in cls.__dict__.items()
|
||||
if any(
|
||||
hasattr(method, attr)
|
||||
for attr in [
|
||||
"is_task",
|
||||
"is_agent",
|
||||
"is_before_kickoff",
|
||||
"is_after_kickoff",
|
||||
"is_kickoff",
|
||||
]
|
||||
)
|
||||
}
|
||||
# Store specific function types
|
||||
self._original_tasks = self._filter_functions(
|
||||
self._original_functions, "is_task"
|
||||
)
|
||||
self._original_agents = self._filter_functions(
|
||||
self._original_functions, "is_agent"
|
||||
)
|
||||
self._before_kickoff = self._filter_functions(
|
||||
self._original_functions, "is_before_kickoff"
|
||||
)
|
||||
self._after_kickoff = self._filter_functions(
|
||||
self._original_functions, "is_after_kickoff"
|
||||
)
|
||||
self._kickoff = self._filter_functions(
|
||||
self._original_functions, "is_kickoff"
|
||||
)
|
||||
for setup_fn in _CLASS_SETUP_FUNCTIONS:
|
||||
setup_fn(cls)
|
||||
|
||||
# Add close mcp server method to after kickoff
|
||||
bound_method = self._create_close_mcp_server_method()
|
||||
self._after_kickoff['_close_mcp_server'] = bound_method
|
||||
for method in _METHODS_TO_INJECT:
|
||||
setattr(cls, method.__name__, method)
|
||||
|
||||
def _create_close_mcp_server_method(self):
|
||||
def _close_mcp_server(self, instance, outputs):
|
||||
adapter = getattr(self, '_mcp_server_adapter', None)
|
||||
if adapter is not None:
|
||||
try:
|
||||
adapter.stop()
|
||||
except Exception as e:
|
||||
logging.warning(f"Error stopping MCP server: {e}")
|
||||
return outputs
|
||||
return cls
|
||||
|
||||
_close_mcp_server.is_after_kickoff = True
|
||||
def __call__(cls, *args: Any, **kwargs: Any) -> CrewInstance:
|
||||
"""Intercept instance creation to initialize crew functionality.
|
||||
|
||||
import types
|
||||
return types.MethodType(_close_mcp_server, self)
|
||||
Args:
|
||||
*args: Positional arguments for instance creation.
|
||||
**kwargs: Keyword arguments for instance creation.
|
||||
|
||||
def get_mcp_tools(self, *tool_names: list[str]) -> list[BaseTool]:
|
||||
if not self.mcp_server_params:
|
||||
return []
|
||||
Returns:
|
||||
Initialized crew instance.
|
||||
"""
|
||||
instance: CrewInstance = super().__call__(*args, **kwargs)
|
||||
CrewBaseMeta._initialize_crew_instance(instance, cls)
|
||||
return instance
|
||||
|
||||
from crewai_tools import MCPServerAdapter # type: ignore[import-untyped]
|
||||
@staticmethod
|
||||
def _initialize_crew_instance(instance: CrewInstance, cls: type) -> None:
|
||||
"""Initialize crew instance attributes and load configurations.
|
||||
|
||||
adapter = getattr(self, '_mcp_server_adapter', None)
|
||||
if not adapter:
|
||||
self._mcp_server_adapter = MCPServerAdapter(
|
||||
self.mcp_server_params,
|
||||
connect_timeout=self.mcp_connect_timeout
|
||||
)
|
||||
Args:
|
||||
instance: Crew instance to initialize.
|
||||
cls: Crew class type.
|
||||
"""
|
||||
instance._mcp_server_adapter = None
|
||||
instance.load_configurations()
|
||||
instance._all_methods = _get_all_methods(instance)
|
||||
instance.map_all_agent_variables()
|
||||
instance.map_all_task_variables()
|
||||
|
||||
return self._mcp_server_adapter.tools.filter_by_names(tool_names or None)
|
||||
|
||||
|
||||
def load_configurations(self):
|
||||
"""Load agent and task configurations from YAML files."""
|
||||
if isinstance(self.original_agents_config_path, str):
|
||||
agents_config_path = (
|
||||
self.base_directory / self.original_agents_config_path
|
||||
)
|
||||
try:
|
||||
self.agents_config = self.load_yaml(agents_config_path)
|
||||
except FileNotFoundError:
|
||||
logging.warning(
|
||||
f"Agent config file not found at {agents_config_path}. "
|
||||
"Proceeding with empty agent configurations."
|
||||
)
|
||||
self.agents_config = {}
|
||||
else:
|
||||
logging.warning(
|
||||
"No agent configuration path provided. Proceeding with empty agent configurations."
|
||||
)
|
||||
self.agents_config = {}
|
||||
|
||||
if isinstance(self.original_tasks_config_path, str):
|
||||
tasks_config_path = (
|
||||
self.base_directory / self.original_tasks_config_path
|
||||
)
|
||||
try:
|
||||
self.tasks_config = self.load_yaml(tasks_config_path)
|
||||
except FileNotFoundError:
|
||||
logging.warning(
|
||||
f"Task config file not found at {tasks_config_path}. "
|
||||
"Proceeding with empty task configurations."
|
||||
)
|
||||
self.tasks_config = {}
|
||||
else:
|
||||
logging.warning(
|
||||
"No task configuration path provided. Proceeding with empty task configurations."
|
||||
)
|
||||
self.tasks_config = {}
|
||||
|
||||
@staticmethod
|
||||
def load_yaml(config_path: Path):
|
||||
try:
|
||||
with open(config_path, "r", encoding="utf-8") as file:
|
||||
return yaml.safe_load(file)
|
||||
except FileNotFoundError:
|
||||
print(f"File not found: {config_path}")
|
||||
raise
|
||||
|
||||
def _get_all_functions(self):
|
||||
return {
|
||||
name: getattr(self, name)
|
||||
for name in dir(self)
|
||||
if callable(getattr(self, name))
|
||||
}
|
||||
|
||||
def _filter_functions(
|
||||
self, functions: dict[str, Callable], attribute: str
|
||||
) -> dict[str, Callable]:
|
||||
return {
|
||||
name: func
|
||||
for name, func in functions.items()
|
||||
if hasattr(func, attribute)
|
||||
}
|
||||
|
||||
def map_all_agent_variables(self) -> None:
|
||||
all_functions = self._get_all_functions()
|
||||
llms = self._filter_functions(all_functions, "is_llm")
|
||||
tool_functions = self._filter_functions(all_functions, "is_tool")
|
||||
cache_handler_functions = self._filter_functions(
|
||||
all_functions, "is_cache_handler"
|
||||
)
|
||||
callbacks = self._filter_functions(all_functions, "is_callback")
|
||||
|
||||
for agent_name, agent_info in self.agents_config.items():
|
||||
self._map_agent_variables(
|
||||
agent_name,
|
||||
agent_info,
|
||||
llms,
|
||||
tool_functions,
|
||||
cache_handler_functions,
|
||||
callbacks,
|
||||
)
|
||||
|
||||
def _map_agent_variables(
|
||||
self,
|
||||
agent_name: str,
|
||||
agent_info: dict[str, Any],
|
||||
llms: dict[str, Callable],
|
||||
tool_functions: dict[str, Callable],
|
||||
cache_handler_functions: dict[str, Callable],
|
||||
callbacks: dict[str, Callable],
|
||||
) -> None:
|
||||
if llm := agent_info.get("llm"):
|
||||
try:
|
||||
self.agents_config[agent_name]["llm"] = llms[llm]()
|
||||
except KeyError:
|
||||
self.agents_config[agent_name]["llm"] = llm
|
||||
|
||||
if tools := agent_info.get("tools"):
|
||||
self.agents_config[agent_name]["tools"] = [
|
||||
tool_functions[tool]() for tool in tools
|
||||
original_methods = {
|
||||
name: method
|
||||
for name, method in cls.__dict__.items()
|
||||
if any(
|
||||
hasattr(method, attr)
|
||||
for attr in [
|
||||
"is_task",
|
||||
"is_agent",
|
||||
"is_before_kickoff",
|
||||
"is_after_kickoff",
|
||||
"is_kickoff",
|
||||
]
|
||||
|
||||
if function_calling_llm := agent_info.get("function_calling_llm"):
|
||||
try:
|
||||
self.agents_config[agent_name]["function_calling_llm"] = llms[function_calling_llm]()
|
||||
except KeyError:
|
||||
self.agents_config[agent_name]["function_calling_llm"] = function_calling_llm
|
||||
|
||||
if step_callback := agent_info.get("step_callback"):
|
||||
self.agents_config[agent_name]["step_callback"] = callbacks[
|
||||
step_callback
|
||||
]()
|
||||
|
||||
if cache_handler := agent_info.get("cache_handler"):
|
||||
self.agents_config[agent_name]["cache_handler"] = (
|
||||
cache_handler_functions[cache_handler]()
|
||||
)
|
||||
|
||||
def map_all_task_variables(self) -> None:
|
||||
all_functions = self._get_all_functions()
|
||||
agents = self._filter_functions(all_functions, "is_agent")
|
||||
tasks = self._filter_functions(all_functions, "is_task")
|
||||
output_json_functions = self._filter_functions(
|
||||
all_functions, "is_output_json"
|
||||
)
|
||||
tool_functions = self._filter_functions(all_functions, "is_tool")
|
||||
callback_functions = self._filter_functions(all_functions, "is_callback")
|
||||
output_pydantic_functions = self._filter_functions(
|
||||
all_functions, "is_output_pydantic"
|
||||
}
|
||||
|
||||
after_kickoff_callbacks = _filter_methods(original_methods, "is_after_kickoff")
|
||||
after_kickoff_callbacks["close_mcp_server"] = instance.close_mcp_server
|
||||
|
||||
instance.__crew_metadata__ = CrewMetadata(
|
||||
original_methods=original_methods,
|
||||
original_tasks=_filter_methods(original_methods, "is_task"),
|
||||
original_agents=_filter_methods(original_methods, "is_agent"),
|
||||
before_kickoff=_filter_methods(original_methods, "is_before_kickoff"),
|
||||
after_kickoff=after_kickoff_callbacks,
|
||||
kickoff=_filter_methods(original_methods, "is_kickoff"),
|
||||
)
|
||||
|
||||
|
||||
def close_mcp_server(
|
||||
self: CrewInstance, _instance: CrewInstance, outputs: CrewOutput
|
||||
) -> CrewOutput:
|
||||
"""Stop MCP server adapter and return outputs.
|
||||
|
||||
Args:
|
||||
self: Crew instance with MCP server adapter.
|
||||
_instance: Crew instance (unused, required by callback signature).
|
||||
outputs: Crew execution outputs.
|
||||
|
||||
Returns:
|
||||
Unmodified crew outputs.
|
||||
"""
|
||||
if self._mcp_server_adapter is not None:
|
||||
try:
|
||||
self._mcp_server_adapter.stop()
|
||||
except Exception as e:
|
||||
logging.warning(f"Error stopping MCP server: {e}")
|
||||
return outputs
|
||||
|
||||
|
||||
def get_mcp_tools(self: CrewInstance, *tool_names: str) -> list[BaseTool]:
|
||||
"""Get MCP tools filtered by name.
|
||||
|
||||
Args:
|
||||
self: Crew instance with MCP server configuration.
|
||||
*tool_names: Optional tool names to filter by.
|
||||
|
||||
Returns:
|
||||
List of filtered MCP tools, or empty list if no MCP server configured.
|
||||
"""
|
||||
if not self.mcp_server_params:
|
||||
return []
|
||||
|
||||
from crewai_tools import MCPServerAdapter # type: ignore[import-untyped]
|
||||
|
||||
if self._mcp_server_adapter is None:
|
||||
self._mcp_server_adapter = MCPServerAdapter(
|
||||
self.mcp_server_params, connect_timeout=self.mcp_connect_timeout
|
||||
)
|
||||
|
||||
return self._mcp_server_adapter.tools.filter_by_names(tool_names or None)
|
||||
|
||||
|
||||
def _load_config(
|
||||
self: CrewInstance, config_path: str | None, config_type: Literal["agent", "task"]
|
||||
) -> dict[str, Any]:
|
||||
"""Load YAML config file or return empty dict if not found.
|
||||
|
||||
Args:
|
||||
self: Crew instance with base directory and load_yaml method.
|
||||
config_path: Relative path to config file.
|
||||
config_type: Config type for logging, either "agent" or "task".
|
||||
|
||||
Returns:
|
||||
Config dictionary or empty dict.
|
||||
"""
|
||||
if isinstance(config_path, str):
|
||||
full_path = self.base_directory / config_path
|
||||
try:
|
||||
return self.load_yaml(full_path)
|
||||
except FileNotFoundError:
|
||||
logging.warning(
|
||||
f"{config_type.capitalize()} config file not found at {full_path}. "
|
||||
f"Proceeding with empty {config_type} configurations."
|
||||
)
|
||||
return {}
|
||||
else:
|
||||
logging.warning(
|
||||
f"No {config_type} configuration path provided. "
|
||||
f"Proceeding with empty {config_type} configurations."
|
||||
)
|
||||
return {}
|
||||
|
||||
for task_name, task_info in self.tasks_config.items():
|
||||
self._map_task_variables(
|
||||
task_name,
|
||||
task_info,
|
||||
agents,
|
||||
tasks,
|
||||
output_json_functions,
|
||||
tool_functions,
|
||||
callback_functions,
|
||||
output_pydantic_functions,
|
||||
)
|
||||
|
||||
def _map_task_variables(
|
||||
self,
|
||||
task_name: str,
|
||||
task_info: dict[str, Any],
|
||||
agents: dict[str, Callable],
|
||||
tasks: dict[str, Callable],
|
||||
output_json_functions: dict[str, Callable],
|
||||
tool_functions: dict[str, Callable],
|
||||
callback_functions: dict[str, Callable],
|
||||
output_pydantic_functions: dict[str, Callable],
|
||||
) -> None:
|
||||
if context_list := task_info.get("context"):
|
||||
self.tasks_config[task_name]["context"] = [
|
||||
tasks[context_task_name]() for context_task_name in context_list
|
||||
]
|
||||
def load_configurations(self: CrewInstance) -> None:
|
||||
"""Load agent and task YAML configurations.
|
||||
|
||||
if tools := task_info.get("tools"):
|
||||
self.tasks_config[task_name]["tools"] = [
|
||||
tool_functions[tool]() for tool in tools
|
||||
]
|
||||
Args:
|
||||
self: Crew instance with configuration paths.
|
||||
"""
|
||||
self.agents_config = self._load_config(self.original_agents_config_path, "agent")
|
||||
self.tasks_config = self._load_config(self.original_tasks_config_path, "task")
|
||||
|
||||
if agent_name := task_info.get("agent"):
|
||||
self.tasks_config[task_name]["agent"] = agents[agent_name]()
|
||||
|
||||
if output_json := task_info.get("output_json"):
|
||||
self.tasks_config[task_name]["output_json"] = output_json_functions[
|
||||
output_json
|
||||
]
|
||||
def load_yaml(config_path: Path) -> dict[str, Any]:
|
||||
"""Load and parse YAML configuration file.
|
||||
|
||||
if output_pydantic := task_info.get("output_pydantic"):
|
||||
self.tasks_config[task_name]["output_pydantic"] = (
|
||||
output_pydantic_functions[output_pydantic]
|
||||
)
|
||||
Args:
|
||||
config_path: Path to YAML configuration file.
|
||||
|
||||
if callbacks := task_info.get("callbacks"):
|
||||
self.tasks_config[task_name]["callbacks"] = [
|
||||
callback_functions[callback]() for callback in callbacks
|
||||
]
|
||||
Returns:
|
||||
Parsed YAML content as a dictionary. Returns empty dict if file is empty.
|
||||
|
||||
if guardrail := task_info.get("guardrail"):
|
||||
self.tasks_config[task_name]["guardrail"] = guardrail
|
||||
Raises:
|
||||
FileNotFoundError: If config file does not exist.
|
||||
"""
|
||||
try:
|
||||
with open(config_path, encoding="utf-8") as file:
|
||||
content = yaml.safe_load(file)
|
||||
return content if isinstance(content, dict) else {}
|
||||
except FileNotFoundError:
|
||||
logging.warning(f"File not found: {config_path}")
|
||||
raise
|
||||
|
||||
# Include base class (qual)name in the wrapper class (qual)name.
|
||||
WrappedClass.__name__ = CrewBase.__name__ + "(" + cls.__name__ + ")"
|
||||
WrappedClass.__qualname__ = CrewBase.__qualname__ + "(" + cls.__name__ + ")"
|
||||
WrappedClass._crew_name = cls.__name__
|
||||
|
||||
return cast(T, WrappedClass)
|
||||
def _get_all_methods(self: CrewInstance) -> dict[str, Callable[..., Any]]:
|
||||
"""Return all non-dunder callable attributes (methods).
|
||||
|
||||
Args:
|
||||
self: Instance to inspect for callable attributes.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping method names to bound method objects.
|
||||
"""
|
||||
return {
|
||||
name: getattr(self, name)
|
||||
for name in dir(self)
|
||||
if not (name.startswith("__") and name.endswith("__"))
|
||||
and callable(getattr(self, name, None))
|
||||
}
|
||||
|
||||
|
||||
def _filter_methods(
|
||||
methods: dict[str, CallableT], attribute: str
|
||||
) -> dict[str, CallableT]:
|
||||
"""Filter methods by attribute presence, preserving exact callable types.
|
||||
|
||||
Args:
|
||||
methods: Dictionary of methods to filter.
|
||||
attribute: Attribute name to check for.
|
||||
|
||||
Returns:
|
||||
Dictionary containing only methods with the specified attribute.
|
||||
The return type matches the input callable type exactly.
|
||||
"""
|
||||
return {
|
||||
name: method for name, method in methods.items() if hasattr(method, attribute)
|
||||
}
|
||||
|
||||
|
||||
def map_all_agent_variables(self: CrewInstance) -> None:
|
||||
"""Map agent configuration variables to callable instances.
|
||||
|
||||
Args:
|
||||
self: Crew instance with agent configurations to map.
|
||||
"""
|
||||
llms = _filter_methods(self._all_methods, "is_llm")
|
||||
tool_functions = _filter_methods(self._all_methods, "is_tool")
|
||||
cache_handler_functions = _filter_methods(self._all_methods, "is_cache_handler")
|
||||
callbacks = _filter_methods(self._all_methods, "is_callback")
|
||||
|
||||
for agent_name, agent_info in self.agents_config.items():
|
||||
self._map_agent_variables(
|
||||
agent_name=agent_name,
|
||||
agent_info=agent_info,
|
||||
llms=llms,
|
||||
tool_functions=tool_functions,
|
||||
cache_handler_functions=cache_handler_functions,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
|
||||
def _map_agent_variables(
|
||||
self: CrewInstance,
|
||||
agent_name: str,
|
||||
agent_info: AgentConfig,
|
||||
llms: dict[str, Callable[[], Any]],
|
||||
tool_functions: dict[str, Callable[[], BaseTool]],
|
||||
cache_handler_functions: dict[str, Callable[[], Any]],
|
||||
callbacks: dict[str, Callable[..., Any]],
|
||||
) -> None:
|
||||
"""Resolve and map variables for a single agent.
|
||||
|
||||
Args:
|
||||
self: Crew instance with agent configurations.
|
||||
agent_name: Name of agent to configure.
|
||||
agent_info: Agent configuration dictionary with optional fields.
|
||||
llms: Dictionary mapping names to LLM factory functions.
|
||||
tool_functions: Dictionary mapping names to tool factory functions.
|
||||
cache_handler_functions: Dictionary mapping names to cache handler factory functions.
|
||||
callbacks: Dictionary of available callbacks.
|
||||
"""
|
||||
if llm := agent_info.get("llm"):
|
||||
factory = llms.get(llm)
|
||||
self.agents_config[agent_name]["llm"] = factory() if factory else llm
|
||||
|
||||
if tools := agent_info.get("tools"):
|
||||
if _is_string_list(tools):
|
||||
self.agents_config[agent_name]["tools"] = [
|
||||
tool_functions[tool]() for tool in tools
|
||||
]
|
||||
|
||||
if function_calling_llm := agent_info.get("function_calling_llm"):
|
||||
factory = llms.get(function_calling_llm)
|
||||
self.agents_config[agent_name]["function_calling_llm"] = (
|
||||
factory() if factory else function_calling_llm
|
||||
)
|
||||
|
||||
if step_callback := agent_info.get("step_callback"):
|
||||
self.agents_config[agent_name]["step_callback"] = callbacks[step_callback]()
|
||||
|
||||
if cache_handler := agent_info.get("cache_handler"):
|
||||
if _is_string_value(cache_handler):
|
||||
self.agents_config[agent_name]["cache_handler"] = cache_handler_functions[
|
||||
cache_handler
|
||||
]()
|
||||
|
||||
|
||||
def map_all_task_variables(self: CrewInstance) -> None:
    """Resolve every task configuration reference into a concrete object.

    Args:
        self: Crew instance with task configurations to map.
    """
    available = self._all_methods
    agents = _filter_methods(available, "is_agent")
    tasks = _filter_methods(available, "is_task")
    output_json_functions = _filter_methods(available, "is_output_json")
    tool_functions = _filter_methods(available, "is_tool")
    callback_functions = _filter_methods(available, "is_callback")
    output_pydantic_functions = _filter_methods(available, "is_output_pydantic")

    for task_name, task_info in self.tasks_config.items():
        self._map_task_variables(
            task_name=task_name,
            task_info=task_info,
            agents=agents,
            tasks=tasks,
            output_json_functions=output_json_functions,
            tool_functions=tool_functions,
            callback_functions=callback_functions,
            output_pydantic_functions=output_pydantic_functions,
        )
|
||||
|
||||
|
||||
def _map_task_variables(
|
||||
self: CrewInstance,
|
||||
task_name: str,
|
||||
task_info: TaskConfig,
|
||||
agents: dict[str, Callable[[], Agent]],
|
||||
tasks: dict[str, Callable[[], Task]],
|
||||
output_json_functions: dict[str, OutputJsonClass[Any]],
|
||||
tool_functions: dict[str, Callable[[], BaseTool]],
|
||||
callback_functions: dict[str, Callable[..., Any]],
|
||||
output_pydantic_functions: dict[str, OutputPydanticClass[Any]],
|
||||
) -> None:
|
||||
"""Resolve and map variables for a single task.
|
||||
|
||||
Args:
|
||||
self: Crew instance with task configurations.
|
||||
task_name: Name of task to configure.
|
||||
task_info: Task configuration dictionary with optional fields.
|
||||
agents: Dictionary mapping names to agent factory functions.
|
||||
tasks: Dictionary mapping names to task factory functions.
|
||||
output_json_functions: Dictionary of JSON output class wrappers.
|
||||
tool_functions: Dictionary mapping names to tool factory functions.
|
||||
callback_functions: Dictionary of available callbacks.
|
||||
output_pydantic_functions: Dictionary of Pydantic output class wrappers.
|
||||
"""
|
||||
if context_list := task_info.get("context"):
|
||||
self.tasks_config[task_name]["context"] = [
|
||||
tasks[context_task_name]() for context_task_name in context_list
|
||||
]
|
||||
|
||||
if tools := task_info.get("tools"):
|
||||
if _is_string_list(tools):
|
||||
self.tasks_config[task_name]["tools"] = [
|
||||
tool_functions[tool]() for tool in tools
|
||||
]
|
||||
|
||||
if agent_name := task_info.get("agent"):
|
||||
self.tasks_config[task_name]["agent"] = agents[agent_name]()
|
||||
|
||||
if output_json := task_info.get("output_json"):
|
||||
self.tasks_config[task_name]["output_json"] = output_json_functions[output_json]
|
||||
|
||||
if output_pydantic := task_info.get("output_pydantic"):
|
||||
self.tasks_config[task_name]["output_pydantic"] = output_pydantic_functions[
|
||||
output_pydantic
|
||||
]
|
||||
|
||||
if callbacks := task_info.get("callbacks"):
|
||||
self.tasks_config[task_name]["callbacks"] = [
|
||||
callback_functions[callback]() for callback in callbacks
|
||||
]
|
||||
|
||||
if guardrail := task_info.get("guardrail"):
|
||||
self.tasks_config[task_name]["guardrail"] = guardrail
|
||||
|
||||
|
||||
# Class-level setup hooks, each taking the new crew class and configuring it
# in place. NOTE(review): the consumer is not visible in this chunk —
# presumably CrewBaseMeta applies these in order during class creation.
_CLASS_SETUP_FUNCTIONS: tuple[Callable[[type[CrewClass]], None], ...] = (
    _set_base_directory,
    _set_config_paths,
    _set_mcp_params,
)
|
||||
|
||||
# Module-level functions attached as methods to decorated crew classes.
# load_yaml is wrapped in staticmethod so it binds without an instance.
# NOTE(review): the injection site is not visible in this chunk — presumably
# CrewBaseMeta copies these onto the class.
_METHODS_TO_INJECT = (
    close_mcp_server,
    get_mcp_tools,
    _load_config,
    load_configurations,
    staticmethod(load_yaml),
    map_all_agent_variables,
    _map_agent_variables,
    map_all_task_variables,
    _map_task_variables,
)
|
||||
|
||||
|
||||
class _CrewBaseType(type):
    """Metaclass for CrewBase that makes it callable as a decorator."""

    def __call__(self, decorated_cls: type) -> type[CrewClass]:
        """Apply CrewBaseMeta to the decorated class.

        Rebuilds the decorated class from its name, bases, and (filtered)
        namespace so that CrewBaseMeta runs as its metaclass.

        Args:
            decorated_cls: Class to transform with CrewBaseMeta metaclass.

        Returns:
            New class with CrewBaseMeta metaclass applied.
        """
        __name = str(decorated_cls.__name__)
        __bases = tuple(decorated_cls.__bases__)
        # Copy the namespace, dropping __dict__/__weakref__ descriptors that
        # type() will recreate for the new class.
        __dict = {
            key: value
            for key, value in decorated_cls.__dict__.items()
            if key not in ("__dict__", "__weakref__")
        }
        # NOTE(review): per-slot member descriptors are removed here,
        # presumably so the rebuilt class recreates them cleanly — confirm.
        for slot in __dict.get("__slots__", tuple()):
            __dict.pop(slot, None)
        __dict["__metaclass__"] = CrewBaseMeta
        return cast(type[CrewClass], CrewBaseMeta(__name, __bases, __dict))
|
||||
|
||||
|
||||
class CrewBase(metaclass=_CrewBaseType):
    """Class decorator that applies CrewBaseMeta metaclass.

    Applies CrewBaseMeta metaclass to a class via decorator syntax rather than
    explicit metaclass declaration. Use as @CrewBase instead of
    class Foo(metaclass=CrewBaseMeta). Because _CrewBaseType overrides
    __call__, writing @CrewBase above a class invokes the metaclass rebuild
    at decoration time.

    Note:
        Reference: https://stackoverflow.com/questions/11091609/setting-a-class-metaclass-using-a-decorator
    """
|
||||
|
||||
@@ -1,14 +1,38 @@
|
||||
"""Utility functions for the crewai project module."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any, ParamSpec, TypeVar
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def memoize(func):
|
||||
cache = {}
|
||||
def memoize(meth: Callable[P, R]) -> Callable[P, R]:
|
||||
"""Memoize a method by caching its results based on arguments.
|
||||
|
||||
@wraps(func)
|
||||
def memoized_func(*args, **kwargs):
|
||||
Args:
|
||||
meth: The method to memoize.
|
||||
|
||||
Returns:
|
||||
A memoized version of the method that caches results.
|
||||
"""
|
||||
cache: dict[Any, R] = {}
|
||||
|
||||
@wraps(meth)
|
||||
def memoized_func(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""Memoized wrapper method.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments to pass to the method.
|
||||
**kwargs: Keyword arguments to pass to the method.
|
||||
|
||||
Returns:
|
||||
The cached or computed result of the method.
|
||||
"""
|
||||
key = (args, tuple(kwargs.items()))
|
||||
if key not in cache:
|
||||
cache[key] = func(*args, **kwargs)
|
||||
cache[key] = meth(*args, **kwargs)
|
||||
return cache[key]
|
||||
|
||||
return memoized_func
|
||||
|
||||
388
src/crewai/project/wrappers.py
Normal file
388
src/crewai/project/wrappers.py
Normal file
@@ -0,0 +1,388 @@
|
||||
"""Wrapper classes for decorated methods with type-safe metadata."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Generic,
|
||||
Literal,
|
||||
ParamSpec,
|
||||
Protocol,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai import Agent, Task
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
|
||||
class CrewMetadata(TypedDict):
    """Type definition for crew metadata dictionary.

    Stores framework-injected metadata about decorated methods and callbacks.
    NOTE(review): field names suggest the "original_*" entries hold the
    pre-decoration callables and the remaining entries hold kickoff hooks —
    the writer is outside this chunk; confirm at the metaclass.
    """

    original_methods: dict[str, Callable[..., Any]]
    original_tasks: dict[str, Callable[..., Task]]
    original_agents: dict[str, Callable[..., Agent]]
    before_kickoff: dict[str, Callable[..., Any]]
    after_kickoff: dict[str, Callable[..., Any]]
    kickoff: dict[str, Callable[..., Any]]
|
||||
|
||||
|
||||
# Generic type parameters shared by the wrapper classes below.
P = ParamSpec("P")  # parameter spec of a wrapped callable
R = TypeVar("R")  # return type of a wrapped callable
T = TypeVar("T")  # class wrapped by OutputClass
|
||||
|
||||
|
||||
class TaskResult(Protocol):
    """Protocol for task objects that have a name attribute."""

    # Name of the task; may be unset (None) until filled in by the framework.
    name: str | None


# Type variable constrained to objects satisfying TaskResult.
TaskResultT = TypeVar("TaskResultT", bound=TaskResult)
|
||||
|
||||
|
||||
def _copy_method_metadata(wrapper: Any, meth: Callable[..., Any]) -> None:
|
||||
"""Copy method metadata to a wrapper object.
|
||||
|
||||
Args:
|
||||
wrapper: The wrapper object to update.
|
||||
meth: The method to copy metadata from.
|
||||
"""
|
||||
wrapper.__name__ = meth.__name__
|
||||
wrapper.__doc__ = meth.__doc__
|
||||
|
||||
|
||||
class CrewInstance(Protocol):
    """Protocol for crew class instances with required attributes.

    Describes the attributes and methods available on instances of classes
    processed by the crew metaclass, so module-level helpers can be typed
    against instances without importing the concrete class.
    """

    # Framework-injected metadata about decorated methods and callbacks.
    __crew_metadata__: CrewMetadata
    _mcp_server_adapter: Any
    # All non-dunder callables on the instance, keyed by name.
    _all_methods: dict[str, Callable[..., Any]]
    agents: list[Agent]
    tasks: list[Task]
    # Directory used to resolve the relative config paths below.
    base_directory: Path
    original_agents_config_path: str
    original_tasks_config_path: str
    # Parsed YAML configurations, mutated in place during variable mapping.
    agents_config: dict[str, Any]
    tasks_config: dict[str, Any]
    mcp_server_params: Any
    mcp_connect_timeout: int

    def load_configurations(self) -> None: ...
    def map_all_agent_variables(self) -> None: ...
    def map_all_task_variables(self) -> None: ...
    def close_mcp_server(self, instance: Self, outputs: CrewOutput) -> CrewOutput: ...
    def _load_config(
        self, config_path: str | None, config_type: Literal["agent", "task"]
    ) -> dict[str, Any]: ...
    def _map_agent_variables(
        self,
        agent_name: str,
        agent_info: dict[str, Any],
        llms: dict[str, Callable[..., Any]],
        tool_functions: dict[str, Callable[..., Any]],
        cache_handler_functions: dict[str, Callable[..., Any]],
        callbacks: dict[str, Callable[..., Any]],
    ) -> None: ...
    def _map_task_variables(
        self,
        task_name: str,
        task_info: dict[str, Any],
        agents: dict[str, Callable[..., Any]],
        tasks: dict[str, Callable[..., Any]],
        output_json_functions: dict[str, Callable[..., Any]],
        tool_functions: dict[str, Callable[..., Any]],
        callback_functions: dict[str, Callable[..., Any]],
        output_pydantic_functions: dict[str, Callable[..., Any]],
    ) -> None: ...
    def load_yaml(self, config_path: Path) -> dict[str, Any]: ...
|
||||
|
||||
|
||||
class CrewClass(Protocol):
    """Protocol describing class attributes injected by CrewBaseMeta."""

    # Marker identifying classes processed by the crew metaclass.
    is_crew_class: bool
    _crew_name: str
    # Directory used to resolve the relative config paths below.
    base_directory: Path
    original_agents_config_path: str
    original_tasks_config_path: str
    mcp_server_params: Any
    mcp_connect_timeout: int
    # Methods injected onto the class (see _METHODS_TO_INJECT).
    close_mcp_server: Callable[..., Any]
    get_mcp_tools: Callable[..., list[BaseTool]]
    _load_config: Callable[..., dict[str, Any]]
    load_configurations: Callable[..., None]
    load_yaml: staticmethod
    map_all_agent_variables: Callable[..., None]
    _map_agent_variables: Callable[..., None]
    map_all_task_variables: Callable[..., None]
    _map_task_variables: Callable[..., None]
|
||||
|
||||
|
||||
class DecoratedMethod(Generic[P, R]):
|
||||
"""Base wrapper for methods with decorator metadata.
|
||||
|
||||
This class provides a type-safe way to add metadata to methods
|
||||
while preserving their callable signature and attributes.
|
||||
"""
|
||||
|
||||
def __init__(self, meth: Callable[P, R]) -> None:
|
||||
"""Initialize the decorated method wrapper.
|
||||
|
||||
Args:
|
||||
meth: The method to wrap.
|
||||
"""
|
||||
self._meth = meth
|
||||
_copy_method_metadata(self, meth)
|
||||
|
||||
def __get__(
|
||||
self, obj: Any, objtype: type[Any] | None = None
|
||||
) -> Self | Callable[..., R]:
|
||||
"""Support instance methods by implementing the descriptor protocol.
|
||||
|
||||
Args:
|
||||
obj: The instance that the method is accessed through.
|
||||
objtype: The type of the instance.
|
||||
|
||||
Returns:
|
||||
Self when accessed through class, bound method when accessed through instance.
|
||||
"""
|
||||
if obj is None:
|
||||
return self
|
||||
bound = partial(self._meth, obj)
|
||||
for attr in (
|
||||
"is_agent",
|
||||
"is_llm",
|
||||
"is_tool",
|
||||
"is_callback",
|
||||
"is_cache_handler",
|
||||
"is_before_kickoff",
|
||||
"is_after_kickoff",
|
||||
"is_crew",
|
||||
):
|
||||
if hasattr(self, attr):
|
||||
setattr(bound, attr, getattr(self, attr))
|
||||
return bound
|
||||
|
||||
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""Call the wrapped method.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments.
|
||||
**kwargs: Keyword arguments.
|
||||
|
||||
Returns:
|
||||
The result of calling the wrapped method.
|
||||
"""
|
||||
return self._meth(*args, **kwargs)
|
||||
|
||||
def unwrap(self) -> Callable[P, R]:
|
||||
"""Get the original unwrapped method.
|
||||
|
||||
Returns:
|
||||
The original method before decoration.
|
||||
"""
|
||||
return self._meth
|
||||
|
||||
|
||||
class BeforeKickoffMethod(DecoratedMethod[P, R]):
    """Wrapper for methods marked to execute before crew kickoff."""

    # Marker attribute checked (via hasattr) when collecting kickoff hooks.
    is_before_kickoff: bool = True


class AfterKickoffMethod(DecoratedMethod[P, R]):
    """Wrapper for methods marked to execute after crew kickoff."""

    is_after_kickoff: bool = True
|
||||
|
||||
|
||||
class BoundTaskMethod(Generic[TaskResultT]):
    """Bound task method with task marker attribute."""

    # Marker attribute checked when collecting task methods.
    is_task: bool = True

    def __init__(self, task_method: TaskMethod[Any, TaskResultT], obj: Any) -> None:
        """Bind a task method descriptor to an instance.

        Args:
            task_method: The TaskMethod descriptor instance.
            obj: The instance to bind to.
        """
        self._task_method = task_method
        self._obj = obj

    def __call__(self, *args: Any, **kwargs: Any) -> TaskResultT:
        """Run the underlying task method against the bound instance.

        Args:
            *args: Positional arguments.
            **kwargs: Keyword arguments.

        Returns:
            The task result with its name ensured.
        """
        raw = self._task_method.unwrap()(self._obj, *args, **kwargs)
        return self._task_method.ensure_task_name(raw)
|
||||
|
||||
|
||||
class TaskMethod(Generic[P, TaskResultT]):
|
||||
"""Wrapper for methods marked as crew tasks."""
|
||||
|
||||
is_task: bool = True
|
||||
|
||||
def __init__(self, meth: Callable[P, TaskResultT]) -> None:
|
||||
"""Initialize the task method wrapper.
|
||||
|
||||
Args:
|
||||
meth: The method to wrap.
|
||||
"""
|
||||
self._meth = meth
|
||||
_copy_method_metadata(self, meth)
|
||||
|
||||
def ensure_task_name(self, result: TaskResultT) -> TaskResultT:
|
||||
"""Ensure task result has a name set.
|
||||
|
||||
Args:
|
||||
result: The task result to check.
|
||||
|
||||
Returns:
|
||||
The task result with name ensured.
|
||||
"""
|
||||
if not result.name:
|
||||
result.name = self._meth.__name__
|
||||
return result
|
||||
|
||||
def __get__(
|
||||
self, obj: Any, objtype: type[Any] | None = None
|
||||
) -> Self | BoundTaskMethod[TaskResultT]:
|
||||
"""Support instance methods by implementing the descriptor protocol.
|
||||
|
||||
Args:
|
||||
obj: The instance that the method is accessed through.
|
||||
objtype: The type of the instance.
|
||||
|
||||
Returns:
|
||||
Self when accessed through class, bound method when accessed through instance.
|
||||
"""
|
||||
if obj is None:
|
||||
return self
|
||||
return BoundTaskMethod(self, obj)
|
||||
|
||||
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> TaskResultT:
|
||||
"""Call the wrapped method and set task name if not provided.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments.
|
||||
**kwargs: Keyword arguments.
|
||||
|
||||
Returns:
|
||||
The task instance with name set if not already provided.
|
||||
"""
|
||||
return self.ensure_task_name(self._meth(*args, **kwargs))
|
||||
|
||||
def unwrap(self) -> Callable[P, TaskResultT]:
|
||||
"""Get the original unwrapped method.
|
||||
|
||||
Returns:
|
||||
The original method before decoration.
|
||||
"""
|
||||
return self._meth
|
||||
|
||||
|
||||
class AgentMethod(DecoratedMethod[P, R]):
    """Wrapper for methods marked as crew agents."""

    # Marker attribute checked (via hasattr) when filtering methods by role.
    is_agent: bool = True


class LLMMethod(DecoratedMethod[P, R]):
    """Wrapper for methods marked as LLM providers."""

    is_llm: bool = True


class ToolMethod(DecoratedMethod[P, R]):
    """Wrapper for methods marked as crew tools."""

    is_tool: bool = True


class CallbackMethod(DecoratedMethod[P, R]):
    """Wrapper for methods marked as crew callbacks."""

    is_callback: bool = True


class CacheHandlerMethod(DecoratedMethod[P, R]):
    """Wrapper for methods marked as cache handlers."""

    is_cache_handler: bool = True


class CrewMethod(DecoratedMethod[P, R]):
    """Wrapper for methods marked as the main crew execution point."""

    is_crew: bool = True
|
||||
|
||||
|
||||
class OutputClass(Generic[T]):
    """Base wrapper for classes marked as output format."""

    def __init__(self, cls: type[T]) -> None:
        """Wrap an output class, mirroring its identity metadata.

        Args:
            cls: The class to wrap.
        """
        self._cls = cls
        # Mirror identity metadata so the wrapper is interchangeable with the
        # wrapped class in reprs and introspection.
        for meta_attr in ("__name__", "__qualname__", "__module__", "__doc__"):
            setattr(self, meta_attr, getattr(cls, meta_attr))

    def __call__(self, *args: Any, **kwargs: Any) -> T:
        """Instantiate the wrapped class.

        Args:
            *args: Positional arguments for the class constructor.
            **kwargs: Keyword arguments for the class constructor.

        Returns:
            An instance of the wrapped class.
        """
        return self._cls(*args, **kwargs)

    def __getattr__(self, name: str) -> Any:
        """Fall back to the wrapped class for unknown attributes.

        Args:
            name: The attribute name.

        Returns:
            The attribute from the wrapped class.
        """
        return getattr(self._cls, name)
|
||||
|
||||
|
||||
class OutputJsonClass(OutputClass[T]):
    """Wrapper for classes marked as JSON output format."""

    # Marker attribute checked (via hasattr) when filtering output classes.
    is_output_json: bool = True


class OutputPydanticClass(OutputClass[T]):
    """Wrapper for classes marked as Pydantic output format."""

    is_output_pydantic: bool = True
|
||||
@@ -1,11 +1,6 @@
|
||||
"""Utility for colored console output."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Final, Literal, NamedTuple
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _typeshed import SupportsWrite
|
||||
from typing import Final, Literal, NamedTuple
|
||||
|
||||
PrinterColor = Literal[
|
||||
"purple",
|
||||
@@ -59,22 +54,13 @@ class Printer:
|
||||
|
||||
@staticmethod
|
||||
def print(
|
||||
content: str | list[ColoredText],
|
||||
color: PrinterColor | None = None,
|
||||
sep: str | None = " ",
|
||||
end: str | None = "\n",
|
||||
file: SupportsWrite[str] | None = None,
|
||||
flush: Literal[False] = False,
|
||||
content: str | list[ColoredText], color: PrinterColor | None = None
|
||||
) -> None:
|
||||
"""Prints content to the console with optional color formatting.
|
||||
|
||||
Args:
|
||||
content: Either a string or a list of ColoredText objects for multicolor output.
|
||||
color: Optional color for the text when content is a string. Ignored when content is a list.
|
||||
sep: Separator to use between the text and color.
|
||||
end: String appended after the last value.
|
||||
file: A file-like object (stream); defaults to the current sys.stdout.
|
||||
flush: Whether to forcibly flush the stream.
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
content = [ColoredText(content, color)]
|
||||
@@ -82,9 +68,5 @@ class Printer:
|
||||
"".join(
|
||||
f"{_COLOR_CODES[c.color] if c.color else ''}{c.text}{RESET}"
|
||||
for c in content
|
||||
),
|
||||
sep=sep,
|
||||
end=end,
|
||||
file=file,
|
||||
flush=flush,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -4744,81 +4744,3 @@ def test_ensure_exchanged_messages_are_propagated_to_external_memory():
|
||||
assert "Researcher" in messages[0]["content"]
|
||||
assert messages[1]["role"] == "user"
|
||||
assert "Research a topic to teach a kid aged 6 about math" in messages[1]["content"]
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_kickoff_stream(researcher):
    """Test that crew.kickoff_stream() yields events during execution."""
    task = Task(
        description="Research a topic about AI",
        expected_output="A brief summary about AI",
        agent=researcher,
    )

    crew = Crew(agents=[researcher], tasks=[task])

    # Materialize the generator so all events are captured before asserting.
    events = list(crew.kickoff_stream())

    assert len(events) > 0

    # Lifecycle events must bracket the run: a start marker and a final output.
    event_types = [event["type"] for event in events]
    assert "crew_kickoff_started" in event_types
    assert "final_output" in event_types

    # The final output event carries the crew's result as a non-empty string.
    final_output_event = next(e for e in events if e["type"] == "final_output")
    assert "output" in final_output_event["data"]
    assert isinstance(final_output_event["data"]["output"], str)
    assert len(final_output_event["data"]["output"]) > 0
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_kickoff_stream_with_inputs(researcher):
    """Test that crew.kickoff_stream() works with inputs."""
    # {topic} placeholders are interpolated from the inputs mapping below.
    task = Task(
        description="Research about {topic}",
        expected_output="A brief summary about {topic}",
        agent=researcher,
    )

    crew = Crew(agents=[researcher], tasks=[task])

    events = list(crew.kickoff_stream(inputs={"topic": "machine learning"}))

    assert len(events) > 0

    # Lifecycle events must still bracket the run when inputs are supplied.
    event_types = [event["type"] for event in events]
    assert "crew_kickoff_started" in event_types
    assert "final_output" in event_types
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_kickoff_stream_includes_llm_chunks(researcher):
    """Test that crew.kickoff_stream() includes LLM stream chunks."""
    task = Task(
        description="Write a short poem about AI",
        expected_output="A 2-line poem",
        agent=researcher,
    )

    crew = Crew(agents=[researcher], tasks=[task])

    events = list(crew.kickoff_stream())

    event_types = [event["type"] for event in events]

    # NOTE(review): despite the test name, this only asserts execution-level
    # events, not per-token LLM chunk events.
    assert "task_started" in event_types or "agent_execution_started" in event_types
|
||||
|
||||
|
||||
def test_crew_kickoff_stream_handles_errors(researcher):
    """Test that crew.kickoff_stream() properly handles errors."""
    task = Task(
        description="This task will fail",
        expected_output="Should not complete",
        agent=researcher,
    )

    crew = Crew(agents=[researcher], tasks=[task])

    # Force the underlying kickoff to raise; the stream generator must
    # propagate the exception rather than swallow it.
    with patch("crewai.crew.Crew.kickoff", side_effect=Exception("Test error")):
        with pytest.raises(Exception, match="Test error"):
            list(crew.kickoff_stream())
|
||||
|
||||
@@ -6,15 +6,15 @@ from datetime import datetime
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.flow.flow import Flow, and_, listen, or_, router, start
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.flow_events import (
|
||||
FlowFinishedEvent,
|
||||
FlowPlotEvent,
|
||||
FlowStartedEvent,
|
||||
FlowPlotEvent,
|
||||
MethodExecutionFinishedEvent,
|
||||
MethodExecutionStartedEvent,
|
||||
)
|
||||
from crewai.flow.flow import Flow, and_, listen, or_, router, start
|
||||
|
||||
|
||||
def test_simple_sequential_flow():
|
||||
@@ -679,11 +679,11 @@ def test_structured_flow_event_emission():
|
||||
assert isinstance(received_events[3], MethodExecutionStartedEvent)
|
||||
assert received_events[3].method_name == "send_welcome_message"
|
||||
assert received_events[3].params == {}
|
||||
assert received_events[3].state.sent is False
|
||||
assert getattr(received_events[3].state, "sent") is False
|
||||
|
||||
assert isinstance(received_events[4], MethodExecutionFinishedEvent)
|
||||
assert received_events[4].method_name == "send_welcome_message"
|
||||
assert received_events[4].state.sent is True
|
||||
assert getattr(received_events[4].state, "sent") is True
|
||||
assert received_events[4].result == "Welcome, Anakin!"
|
||||
|
||||
assert isinstance(received_events[5], FlowFinishedEvent)
|
||||
@@ -894,75 +894,3 @@ def test_flow_name():
|
||||
|
||||
flow = MyFlow()
|
||||
assert flow.name == "MyFlow"
|
||||
|
||||
|
||||
def test_nested_and_or_conditions():
    """Test nested conditions like or_(and_(A, B), and_(C, D)).

    Reproduces bug from issue #3719 where nested conditions are flattened,
    causing premature execution.
    """
    execution_order = []

    class NestedConditionFlow(Flow):
        @start()
        def method_1(self):
            execution_order.append("method_1")

        @listen(method_1)
        def method_2(self):
            execution_order.append("method_2")

        @router(method_2)
        def method_3(self):
            execution_order.append("method_3")
            # Choose b_condition path
            return "b_condition"

        @listen("b_condition")
        def method_5(self):
            execution_order.append("method_5")

        @listen(method_5)
        async def method_4(self):
            execution_order.append("method_4")

        @listen(or_("a_condition", "b_condition"))
        async def method_6(self):
            execution_order.append("method_6")

        # Nested condition under test: must NOT be flattened into a single
        # or_/and_ over all four triggers.
        @listen(
            or_(
                and_("a_condition", method_6),
                and_(method_6, method_4),
            )
        )
        def method_7(self):
            execution_order.append("method_7")

        @listen(method_7)
        async def method_8(self):
            execution_order.append("method_8")

    flow = NestedConditionFlow()
    flow.kickoff()

    # Verify execution happened
    assert "method_1" in execution_order
    assert "method_2" in execution_order
    assert "method_3" in execution_order
    assert "method_5" in execution_order
    assert "method_4" in execution_order
    assert "method_6" in execution_order
    assert "method_7" in execution_order
    assert "method_8" in execution_order

    # Critical assertion: method_7 should only execute AFTER both method_6 AND method_4
    # Since b_condition was returned, method_6 triggers on b_condition
    # method_7 requires: (a_condition AND method_6) OR (method_6 AND method_4)
    # The second condition (method_6 AND method_4) should be the one that triggers
    assert execution_order.index("method_7") > execution_order.index("method_6")
    assert execution_order.index("method_7") > execution_order.index("method_4")

    # method_8 should execute after method_7
    assert execution_order.index("method_8") > execution_order.index("method_7")
|
||||
|
||||
Reference in New Issue
Block a user