Compare commits

..

12 Commits

Author SHA1 Message Date
Greyson Lalonde
6081809c76 chore: add comprehensive typing to AgentConfig and TaskConfig 2025-10-16 19:53:19 -04:00
Greyson Lalonde
30a4b712a3 refactor(project): improve type safety and consolidate crew metadata 2025-10-16 19:35:34 -04:00
Greyson Lalonde
8465350f1d refactor: convert CrewBase decorator to metaclass pattern 2025-10-15 01:28:59 -04:00
Greyson LaLonde
2a23bc604c fix: resolve NameError in project module with deferred type annotations 2025-10-14 15:50:49 -04:00
Greyson Lalonde
714f8a8940 Merge branch 'main' into gl/chore/project-linting-typing 2025-10-14 13:23:57 -04:00
Greyson Lalonde
e029de2863 chore: refactor type hints for Agent and Task in annotations and wrappers 2025-10-13 15:35:39 -04:00
Greyson Lalonde
6492852a0c chore: refactor type annotations and protocols in wrappers.py 2025-10-13 14:38:45 -04:00
Greyson Lalonde
fecf7e9a83 feat: add protocols for Agent, Task, and Crew instances 2025-10-13 13:52:29 -04:00
Greyson Lalonde
6bc8818ae9 chore: refactor and type memoize utility function 2025-10-13 13:09:10 -04:00
Greyson Lalonde
620df71763 chore: refactor decorators and add output class wrappers
Refactored task, agent, llm, tool, callback, and cache_handler decorators to directly memoize the method before wrapping, simplifying their implementation. Introduced OutputJsonClass and OutputPydanticClass wrappers to mark classes for JSON and Pydantic output formats, replacing the previous attribute-based approach.
2025-10-13 13:01:14 -04:00
Greyson Lalonde
7d6324dfa3 chore: refactor annotation decorators with type-safe wrappers 2025-10-13 12:44:02 -04:00
Greyson Lalonde
541eec0639 chore: refactor project decorators and imports for clarity 2025-10-13 12:29:57 -04:00
11 changed files with 1288 additions and 1010 deletions

View File

@@ -1,177 +0,0 @@
"""
FastAPI Streaming Integration Example for CrewAI
This example demonstrates how to integrate CrewAI with FastAPI to stream
crew execution events in real-time using Server-Sent Events (SSE).
Installation:
pip install crewai fastapi uvicorn
Usage:
python fastapi_streaming_example.py
Then visit:
http://localhost:8000/docs for the API documentation
http://localhost:8000/stream?topic=AI to see streaming in action
"""
import json
from typing import AsyncGenerator
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from crewai import Agent, Crew, Task
app = FastAPI(title="CrewAI Streaming API")
class ResearchRequest(BaseModel):
topic: str
def create_research_crew(topic: str) -> Crew:
"""Create a research crew for the given topic."""
researcher = Agent(
role="Researcher",
goal=f"Research and analyze information about {topic}",
backstory="You're an expert researcher with deep knowledge in various fields.",
verbose=True,
)
task = Task(
description=f"Research and provide a comprehensive summary about {topic}",
expected_output="A detailed summary with key insights",
agent=researcher,
)
return Crew(agents=[researcher], tasks=[task], verbose=True)
@app.get("/")
async def root():
"""Root endpoint with API information."""
return {
"message": "CrewAI Streaming API",
"endpoints": {
"/stream": "GET - Stream crew execution events (query param: topic)",
"/research": "POST - Execute crew and return final result",
},
}
@app.get("/stream")
async def stream_crew_execution(topic: str = "artificial intelligence"):
"""
Stream crew execution events in real-time using Server-Sent Events.
Args:
topic: The research topic (default: "artificial intelligence")
Returns:
StreamingResponse with text/event-stream content type
"""
async def event_generator() -> AsyncGenerator[str, None]:
"""Generate Server-Sent Events from crew execution."""
crew = create_research_crew(topic)
try:
for event in crew.kickoff_stream(inputs={"topic": topic}):
event_data = json.dumps(event)
yield f"data: {event_data}\n\n"
yield "data: {\"type\": \"done\"}\n\n"
except Exception as e:
error_event = {"type": "error", "data": {"message": str(e)}}
yield f"data: {json.dumps(error_event)}\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
@app.post("/research")
async def research_topic(request: ResearchRequest):
"""
Execute crew research and return the final result.
Args:
request: ResearchRequest with topic field
Returns:
JSON response with the research result
"""
crew = create_research_crew(request.topic)
try:
result = crew.kickoff(inputs={"topic": request.topic})
return {
"success": True,
"topic": request.topic,
"result": result.raw,
"usage_metrics": (
result.token_usage.model_dump() if result.token_usage else None
),
}
except Exception as e:
return {"success": False, "error": str(e)}
@app.get("/stream-filtered")
async def stream_filtered_events(
topic: str = "artificial intelligence", event_types: str = "llm_stream_chunk"
):
"""
Stream only specific event types.
Args:
topic: The research topic
event_types: Comma-separated list of event types to include
Returns:
StreamingResponse with filtered events
"""
allowed_types = set(event_types.split(","))
async def event_generator() -> AsyncGenerator[str, None]:
crew = create_research_crew(topic)
try:
for event in crew.kickoff_stream(inputs={"topic": topic}):
if event["type"] in allowed_types:
event_data = json.dumps(event)
yield f"data: {event_data}\n\n"
yield "data: {\"type\": \"done\"}\n\n"
except Exception as e:
error_event = {"type": "error", "data": {"message": str(e)}}
yield f"data: {json.dumps(error_event)}\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
},
)
if __name__ == "__main__":
import uvicorn
print("Starting CrewAI Streaming API...")
print("Visit http://localhost:8000/docs for API documentation")
print("Try: http://localhost:8000/stream?topic=quantum%20computing")
uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@@ -766,118 +766,6 @@ class Crew(FlowTrackable, BaseModel):
self._task_output_handler.reset()
return results
def kickoff_stream(self, inputs: dict[str, Any] | None = None):
"""
Stream crew execution events in real-time.
This method yields events as they occur during crew execution, making it
easy to integrate with streaming frameworks like FastAPI's StreamingResponse.
Args:
inputs: Optional dictionary of inputs for the crew execution
Yields:
dict: Event dictionaries containing event type and data
Example:
```python
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
import json
app = FastAPI()
@app.get("/stream")
async def stream_crew():
def event_generator():
for event in crew.kickoff_stream(inputs={"topic": "AI"}):
yield f"data: {json.dumps(event)}\\n\\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream"
)
```
"""
import queue
import threading
from crewai.events.base_events import BaseEvent
event_queue: queue.Queue = queue.Queue()
completion_event = threading.Event()
exception_holder = {"exception": None}
def event_handler(source: Any, event: BaseEvent):
event_dict = {
"type": event.type,
"data": event.model_dump(exclude={"from_task", "from_agent"}),
}
event_queue.put(event_dict)
from crewai.events.types.crew_events import (
CrewKickoffStartedEvent,
CrewKickoffCompletedEvent,
CrewKickoffFailedEvent,
)
from crewai.events.types.task_events import (
TaskStartedEvent,
TaskCompletedEvent,
TaskFailedEvent,
)
from crewai.events.types.agent_events import (
AgentExecutionStartedEvent,
AgentExecutionCompletedEvent,
)
from crewai.events.types.llm_events import (
LLMStreamChunkEvent,
LLMCallStartedEvent,
LLMCallCompletedEvent,
)
from crewai.events.types.tool_usage_events import (
ToolUsageStartedEvent,
ToolUsageFinishedEvent,
ToolUsageErrorEvent,
)
crewai_event_bus.register_handler(CrewKickoffStartedEvent, event_handler)
crewai_event_bus.register_handler(CrewKickoffCompletedEvent, event_handler)
crewai_event_bus.register_handler(CrewKickoffFailedEvent, event_handler)
crewai_event_bus.register_handler(TaskStartedEvent, event_handler)
crewai_event_bus.register_handler(TaskCompletedEvent, event_handler)
crewai_event_bus.register_handler(TaskFailedEvent, event_handler)
crewai_event_bus.register_handler(AgentExecutionStartedEvent, event_handler)
crewai_event_bus.register_handler(AgentExecutionCompletedEvent, event_handler)
crewai_event_bus.register_handler(LLMStreamChunkEvent, event_handler)
crewai_event_bus.register_handler(LLMCallStartedEvent, event_handler)
crewai_event_bus.register_handler(LLMCallCompletedEvent, event_handler)
crewai_event_bus.register_handler(ToolUsageStartedEvent, event_handler)
crewai_event_bus.register_handler(ToolUsageFinishedEvent, event_handler)
crewai_event_bus.register_handler(ToolUsageErrorEvent, event_handler)
def run_kickoff():
try:
result = self.kickoff(inputs=inputs)
event_queue.put({"type": "final_output", "data": {"output": result.raw}})
except Exception as e:
exception_holder["exception"] = e
finally:
completion_event.set()
thread = threading.Thread(target=run_kickoff, daemon=True)
thread.start()
try:
while not completion_event.is_set() or not event_queue.empty():
event = event_queue.get(timeout=0.1) if not event_queue.empty() else None
if event is not None:
yield event
if exception_holder["exception"]:
raise exception_holder["exception"]
finally:
thread.join(timeout=1)
def _handle_crew_planning(self):
"""Handles the Crew planning."""
self._logger.log("info", "Planning the crew execution")

View File

@@ -31,7 +31,7 @@ from crewai.flow.flow_visualizer import plot_flow
from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.types import FlowExecutionData
from crewai.flow.utils import get_possible_return_constants
from crewai.utilities.printer import Printer, PrinterColor
from crewai.utilities.printer import Printer
logger = logging.getLogger(__name__)
@@ -105,7 +105,7 @@ def start(condition: str | dict | Callable | None = None) -> Callable:
condition : Optional[Union[str, dict, Callable]], optional
Defines when the start method should execute. Can be:
- str: Name of a method that triggers this start
- dict: Result from or_() or and_(), including nested conditions
- dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers)
- Callable: A method reference that triggers this start
Default is None, meaning unconditional start.
@@ -140,18 +140,13 @@ def start(condition: str | dict | Callable | None = None) -> Callable:
if isinstance(condition, str):
func.__trigger_methods__ = [condition]
func.__condition_type__ = "OR"
elif isinstance(condition, dict) and "type" in condition:
if "conditions" in condition:
func.__trigger_condition__ = condition
func.__trigger_methods__ = _extract_all_methods(condition)
func.__condition_type__ = condition["type"]
elif "methods" in condition:
func.__trigger_methods__ = condition["methods"]
func.__condition_type__ = condition["type"]
else:
raise ValueError(
"Condition dict must contain 'conditions' or 'methods'"
)
elif (
isinstance(condition, dict)
and "type" in condition
and "methods" in condition
):
func.__trigger_methods__ = condition["methods"]
func.__condition_type__ = condition["type"]
elif callable(condition) and hasattr(condition, "__name__"):
func.__trigger_methods__ = [condition.__name__]
func.__condition_type__ = "OR"
@@ -177,7 +172,7 @@ def listen(condition: str | dict | Callable) -> Callable:
condition : Union[str, dict, Callable]
Specifies when the listener should execute. Can be:
- str: Name of a method that triggers this listener
- dict: Result from or_() or and_(), including nested conditions
- dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers)
- Callable: A method reference that triggers this listener
Returns
@@ -205,18 +200,13 @@ def listen(condition: str | dict | Callable) -> Callable:
if isinstance(condition, str):
func.__trigger_methods__ = [condition]
func.__condition_type__ = "OR"
elif isinstance(condition, dict) and "type" in condition:
if "conditions" in condition:
func.__trigger_condition__ = condition
func.__trigger_methods__ = _extract_all_methods(condition)
func.__condition_type__ = condition["type"]
elif "methods" in condition:
func.__trigger_methods__ = condition["methods"]
func.__condition_type__ = condition["type"]
else:
raise ValueError(
"Condition dict must contain 'conditions' or 'methods'"
)
elif (
isinstance(condition, dict)
and "type" in condition
and "methods" in condition
):
func.__trigger_methods__ = condition["methods"]
func.__condition_type__ = condition["type"]
elif callable(condition) and hasattr(condition, "__name__"):
func.__trigger_methods__ = [condition.__name__]
func.__condition_type__ = "OR"
@@ -243,7 +233,7 @@ def router(condition: str | dict | Callable) -> Callable:
condition : Union[str, dict, Callable]
Specifies when the router should execute. Can be:
- str: Name of a method that triggers this router
- dict: Result from or_() or and_(), including nested conditions
- dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers)
- Callable: A method reference that triggers this router
Returns
@@ -276,18 +266,13 @@ def router(condition: str | dict | Callable) -> Callable:
if isinstance(condition, str):
func.__trigger_methods__ = [condition]
func.__condition_type__ = "OR"
elif isinstance(condition, dict) and "type" in condition:
if "conditions" in condition:
func.__trigger_condition__ = condition
func.__trigger_methods__ = _extract_all_methods(condition)
func.__condition_type__ = condition["type"]
elif "methods" in condition:
func.__trigger_methods__ = condition["methods"]
func.__condition_type__ = condition["type"]
else:
raise ValueError(
"Condition dict must contain 'conditions' or 'methods'"
)
elif (
isinstance(condition, dict)
and "type" in condition
and "methods" in condition
):
func.__trigger_methods__ = condition["methods"]
func.__condition_type__ = condition["type"]
elif callable(condition) and hasattr(condition, "__name__"):
func.__trigger_methods__ = [condition.__name__]
func.__condition_type__ = "OR"
@@ -313,15 +298,14 @@ def or_(*conditions: str | dict | Callable) -> dict:
*conditions : Union[str, dict, Callable]
Variable number of conditions that can be:
- str: Method names
- dict: Existing condition dictionaries (nested conditions)
- dict: Existing condition dictionaries
- Callable: Method references
Returns
-------
dict
A condition dictionary with format:
{"type": "OR", "conditions": list_of_conditions}
where each condition can be a string (method name) or a nested dict
{"type": "OR", "methods": list_of_method_names}
Raises
------
@@ -333,22 +317,18 @@ def or_(*conditions: str | dict | Callable) -> dict:
>>> @listen(or_("success", "timeout"))
>>> def handle_completion(self):
... pass
>>> @listen(or_(and_("step1", "step2"), "step3"))
>>> def handle_nested(self):
... pass
"""
processed_conditions: list[str | dict[str, Any]] = []
methods = []
for condition in conditions:
if isinstance(condition, dict):
processed_conditions.append(condition)
if isinstance(condition, dict) and "methods" in condition:
methods.extend(condition["methods"])
elif isinstance(condition, str):
processed_conditions.append(condition)
methods.append(condition)
elif callable(condition):
processed_conditions.append(getattr(condition, "__name__", repr(condition)))
methods.append(getattr(condition, "__name__", repr(condition)))
else:
raise ValueError("Invalid condition in or_()")
return {"type": "OR", "conditions": processed_conditions}
return {"type": "OR", "methods": methods}
def and_(*conditions: str | dict | Callable) -> dict:
@@ -364,15 +344,14 @@ def and_(*conditions: str | dict | Callable) -> dict:
*conditions : Union[str, dict, Callable]
Variable number of conditions that can be:
- str: Method names
- dict: Existing condition dictionaries (nested conditions)
- dict: Existing condition dictionaries
- Callable: Method references
Returns
-------
dict
A condition dictionary with format:
{"type": "AND", "conditions": list_of_conditions}
where each condition can be a string (method name) or a nested dict
{"type": "AND", "methods": list_of_method_names}
Raises
------
@@ -384,69 +363,18 @@ def and_(*conditions: str | dict | Callable) -> dict:
>>> @listen(and_("validated", "processed"))
>>> def handle_complete_data(self):
... pass
>>> @listen(and_(or_("step1", "step2"), "step3"))
>>> def handle_nested(self):
... pass
"""
processed_conditions: list[str | dict[str, Any]] = []
methods = []
for condition in conditions:
if isinstance(condition, dict):
processed_conditions.append(condition)
if isinstance(condition, dict) and "methods" in condition:
methods.extend(condition["methods"])
elif isinstance(condition, str):
processed_conditions.append(condition)
methods.append(condition)
elif callable(condition):
processed_conditions.append(getattr(condition, "__name__", repr(condition)))
methods.append(getattr(condition, "__name__", repr(condition)))
else:
raise ValueError("Invalid condition in and_()")
return {"type": "AND", "conditions": processed_conditions}
def _normalize_condition(condition: str | dict | list) -> dict:
"""Normalize a condition to standard format with 'conditions' key.
Args:
condition: Can be a string (method name), dict (condition), or list
Returns:
Normalized dict with 'type' and 'conditions' keys
"""
if isinstance(condition, str):
return {"type": "OR", "conditions": [condition]}
if isinstance(condition, dict):
if "conditions" in condition:
return condition
if "methods" in condition:
return {"type": condition["type"], "conditions": condition["methods"]}
return condition
if isinstance(condition, list):
return {"type": "OR", "conditions": condition}
return {"type": "OR", "conditions": [condition]}
def _extract_all_methods(condition: str | dict | list) -> list[str]:
"""Extract all method names from a condition (including nested).
Args:
condition: Can be a string, dict, or list
Returns:
List of all method names in the condition tree
"""
if isinstance(condition, str):
return [condition]
if isinstance(condition, dict):
normalized = _normalize_condition(condition)
methods = []
for sub_cond in normalized.get("conditions", []):
methods.extend(_extract_all_methods(sub_cond))
return methods
if isinstance(condition, list):
methods = []
for item in condition:
methods.extend(_extract_all_methods(item))
return methods
return []
return {"type": "AND", "methods": methods}
class FlowMeta(type):
@@ -474,10 +402,7 @@ class FlowMeta(type):
if hasattr(attr_value, "__trigger_methods__"):
methods = attr_value.__trigger_methods__
condition_type = getattr(attr_value, "__condition_type__", "OR")
if hasattr(attr_value, "__trigger_condition__"):
listeners[attr_name] = attr_value.__trigger_condition__
else:
listeners[attr_name] = (condition_type, methods)
listeners[attr_name] = (condition_type, methods)
if (
hasattr(attr_value, "__is_router__")
@@ -897,7 +822,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
# Clear completed methods and outputs for a fresh start
self._completed_methods.clear()
self._method_outputs.clear()
self._pending_and_listeners.clear()
else:
# We're restoring from persistence, set the flag
self._is_execution_resuming = True
@@ -1162,16 +1086,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
for method_name in self._start_methods:
# Check if this start method is triggered by the current trigger
if method_name in self._listeners:
condition_data = self._listeners[method_name]
should_trigger = False
if isinstance(condition_data, tuple):
_, trigger_methods = condition_data
should_trigger = current_trigger in trigger_methods
elif isinstance(condition_data, dict):
all_methods = _extract_all_methods(condition_data)
should_trigger = current_trigger in all_methods
if should_trigger:
condition_type, trigger_methods = self._listeners[
method_name
]
if current_trigger in trigger_methods:
# Only execute if this is a cycle (method was already completed)
if method_name in self._completed_methods:
# For router-triggered start methods in cycles, temporarily clear resumption flag
@@ -1181,51 +1099,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
await self._execute_start_method(method_name)
self._is_execution_resuming = was_resuming
def _evaluate_condition(
self, condition: str | dict, trigger_method: str, listener_name: str
) -> bool:
"""Recursively evaluate a condition (simple or nested).
Args:
condition: Can be a string (method name) or dict (nested condition)
trigger_method: The method that just completed
listener_name: Name of the listener being evaluated
Returns:
True if the condition is satisfied, False otherwise
"""
if isinstance(condition, str):
return condition == trigger_method
if isinstance(condition, dict):
normalized = _normalize_condition(condition)
cond_type = normalized.get("type", "OR")
sub_conditions = normalized.get("conditions", [])
if cond_type == "OR":
return any(
self._evaluate_condition(sub_cond, trigger_method, listener_name)
for sub_cond in sub_conditions
)
if cond_type == "AND":
pending_key = f"{listener_name}:{id(condition)}"
if pending_key not in self._pending_and_listeners:
all_methods = set(_extract_all_methods(condition))
self._pending_and_listeners[pending_key] = all_methods
if trigger_method in self._pending_and_listeners[pending_key]:
self._pending_and_listeners[pending_key].discard(trigger_method)
if not self._pending_and_listeners[pending_key]:
self._pending_and_listeners.pop(pending_key, None)
return True
return False
return False
def _find_triggered_methods(
self, trigger_method: str, router_only: bool
) -> list[str]:
@@ -1233,7 +1106,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
Finds all methods that should be triggered based on conditions.
This internal method evaluates both OR and AND conditions to determine
which methods should be executed next in the flow. Supports nested conditions.
which methods should be executed next in the flow.
Parameters
----------
@@ -1250,13 +1123,14 @@ class Flow(Generic[T], metaclass=FlowMeta):
Notes
-----
- Handles both OR and AND conditions, including nested combinations
- Handles both OR and AND conditions:
* OR: Triggers if any condition is met
* AND: Triggers only when all conditions are met
- Maintains state for AND conditions using _pending_and_listeners
- Separates router and normal listener evaluation
"""
triggered = []
for listener_name, condition_data in self._listeners.items():
for listener_name, (condition_type, methods) in self._listeners.items():
is_router = listener_name in self._routers
if router_only != is_router:
@@ -1265,29 +1139,23 @@ class Flow(Generic[T], metaclass=FlowMeta):
if not router_only and listener_name in self._start_methods:
continue
if isinstance(condition_data, tuple):
condition_type, methods = condition_data
if condition_type == "OR":
if trigger_method in methods:
triggered.append(listener_name)
elif condition_type == "AND":
if listener_name not in self._pending_and_listeners:
self._pending_and_listeners[listener_name] = set(methods)
if trigger_method in self._pending_and_listeners[listener_name]:
self._pending_and_listeners[listener_name].discard(
trigger_method
)
if not self._pending_and_listeners[listener_name]:
triggered.append(listener_name)
self._pending_and_listeners.pop(listener_name, None)
elif isinstance(condition_data, dict):
if self._evaluate_condition(
condition_data, trigger_method, listener_name
):
if condition_type == "OR":
# If the trigger_method matches any in methods, run this
if trigger_method in methods:
triggered.append(listener_name)
elif condition_type == "AND":
# Initialize pending methods for this listener if not already done
if listener_name not in self._pending_and_listeners:
self._pending_and_listeners[listener_name] = set(methods)
# Remove the trigger method from pending methods
if trigger_method in self._pending_and_listeners[listener_name]:
self._pending_and_listeners[listener_name].discard(trigger_method)
if not self._pending_and_listeners[listener_name]:
# All required methods have been executed
triggered.append(listener_name)
# Reset pending methods for this listener
self._pending_and_listeners.pop(listener_name, None)
return triggered
@@ -1350,7 +1218,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
raise
def _log_flow_event(
self, message: str, color: PrinterColor | None = "yellow", level: str = "info"
self, message: str, color: str = "yellow", level: str = "info"
) -> None:
"""Centralized logging method for flow events.

View File

@@ -1,4 +1,6 @@
from .annotations import (
"""Project package for CrewAI."""
from crewai.project.annotations import (
after_kickoff,
agent,
before_kickoff,
@@ -11,19 +13,19 @@ from .annotations import (
task,
tool,
)
from .crew_base import CrewBase
from crewai.project.crew_base import CrewBase
__all__ = [
"CrewBase",
"after_kickoff",
"agent",
"before_kickoff",
"cache_handler",
"callback",
"crew",
"task",
"llm",
"output_json",
"output_pydantic",
"task",
"tool",
"callback",
"CrewBase",
"llm",
"cache_handler",
"before_kickoff",
"after_kickoff",
]

View File

@@ -1,97 +1,192 @@
from functools import wraps
from typing import Callable
from crewai import Crew
from crewai.project.utils import memoize
"""Decorators for defining crew components and their behaviors."""
from __future__ import annotations
def before_kickoff(func):
"""Marks a method to execute before crew kickoff."""
func.is_before_kickoff = True
return func
from collections.abc import Callable
from functools import wraps
from typing import TYPE_CHECKING, Concatenate, ParamSpec, TypeVar
from crewai.project.utils import memoize
if TYPE_CHECKING:
from crewai import Agent, Crew, Task
from crewai.project.wrappers import (
AfterKickoffMethod,
AgentMethod,
BeforeKickoffMethod,
CacheHandlerMethod,
CallbackMethod,
CrewInstance,
LLMMethod,
OutputJsonClass,
OutputPydanticClass,
TaskMethod,
TaskResultT,
ToolMethod,
)
P = ParamSpec("P")
P2 = ParamSpec("P2")
R = TypeVar("R")
R2 = TypeVar("R2")
T = TypeVar("T")
def after_kickoff(func):
"""Marks a method to execute after crew kickoff."""
func.is_after_kickoff = True
return func
def before_kickoff(meth: Callable[P, R]) -> BeforeKickoffMethod[P, R]:
"""Marks a method to execute before crew kickoff.
Args:
meth: The method to mark.
Returns:
A wrapped method marked for before kickoff execution.
"""
return BeforeKickoffMethod(meth)
def task(func):
"""Marks a method as a crew task."""
func.is_task = True
def after_kickoff(meth: Callable[P, R]) -> AfterKickoffMethod[P, R]:
"""Marks a method to execute after crew kickoff.
@wraps(func)
def wrapper(*args, **kwargs):
result = func(*args, **kwargs)
if not result.name:
result.name = func.__name__
return result
Args:
meth: The method to mark.
return memoize(wrapper)
Returns:
A wrapped method marked for after kickoff execution.
"""
return AfterKickoffMethod(meth)
def agent(func):
"""Marks a method as a crew agent."""
func.is_agent = True
func = memoize(func)
return func
def task(meth: Callable[P, TaskResultT]) -> TaskMethod[P, TaskResultT]:
"""Marks a method as a crew task.
Args:
meth: The method to mark.
Returns:
A wrapped method marked as a task with memoization.
"""
return TaskMethod(memoize(meth))
def llm(func):
"""Marks a method as an LLM provider."""
func.is_llm = True
func = memoize(func)
return func
def agent(meth: Callable[P, R]) -> AgentMethod[P, R]:
"""Marks a method as a crew agent.
Args:
meth: The method to mark.
Returns:
A wrapped method marked as an agent with memoization.
"""
return AgentMethod(memoize(meth))
def output_json(cls):
"""Marks a class as JSON output format."""
cls.is_output_json = True
return cls
def llm(meth: Callable[P, R]) -> LLMMethod[P, R]:
"""Marks a method as an LLM provider.
Args:
meth: The method to mark.
Returns:
A wrapped method marked as an LLM provider with memoization.
"""
return LLMMethod(memoize(meth))
def output_pydantic(cls):
"""Marks a class as Pydantic output format."""
cls.is_output_pydantic = True
return cls
def output_json(cls: type[T]) -> OutputJsonClass[T]:
"""Marks a class as JSON output format.
Args:
cls: The class to mark.
Returns:
A wrapped class marked as JSON output format.
"""
return OutputJsonClass(cls)
def tool(func):
"""Marks a method as a crew tool."""
func.is_tool = True
return memoize(func)
def output_pydantic(cls: type[T]) -> OutputPydanticClass[T]:
"""Marks a class as Pydantic output format.
Args:
cls: The class to mark.
Returns:
A wrapped class marked as Pydantic output format.
"""
return OutputPydanticClass(cls)
def callback(func):
"""Marks a method as a crew callback."""
func.is_callback = True
return memoize(func)
def tool(meth: Callable[P, R]) -> ToolMethod[P, R]:
"""Marks a method as a crew tool.
Args:
meth: The method to mark.
Returns:
A wrapped method marked as a tool with memoization.
"""
return ToolMethod(memoize(meth))
def cache_handler(func):
"""Marks a method as a cache handler."""
func.is_cache_handler = True
return memoize(func)
def callback(meth: Callable[P, R]) -> CallbackMethod[P, R]:
"""Marks a method as a crew callback.
Args:
meth: The method to mark.
Returns:
A wrapped method marked as a callback with memoization.
"""
return CallbackMethod(memoize(meth))
def crew(func) -> Callable[..., Crew]:
"""Marks a method as the main crew execution point."""
def cache_handler(meth: Callable[P, R]) -> CacheHandlerMethod[P, R]:
"""Marks a method as a cache handler.
@wraps(func)
def wrapper(self, *args, **kwargs) -> Crew:
instantiated_tasks = []
instantiated_agents = []
agent_roles = set()
Args:
meth: The method to mark.
Returns:
A wrapped method marked as a cache handler with memoization.
"""
return CacheHandlerMethod(memoize(meth))
def crew(
meth: Callable[Concatenate[CrewInstance, P], Crew],
) -> Callable[Concatenate[CrewInstance, P], Crew]:
"""Marks a method as the main crew execution point.
Args:
meth: The method to mark as crew execution point.
Returns:
A wrapped method that instantiates tasks and agents before execution.
"""
@wraps(meth)
def wrapper(self: CrewInstance, *args: P.args, **kwargs: P.kwargs) -> Crew:
"""Wrapper that sets up crew before calling the decorated method.
Args:
self: The crew class instance.
*args: Additional positional arguments.
**kwargs: Keyword arguments to pass to the method.
Returns:
The configured Crew instance with callbacks attached.
"""
instantiated_tasks: list[Task] = []
instantiated_agents: list[Agent] = []
agent_roles: set[str] = set()
# Use the preserved task and agent information
tasks = self._original_tasks.items()
agents = self._original_agents.items()
tasks = self.__crew_metadata__["original_tasks"].items()
agents = self.__crew_metadata__["original_agents"].items()
# Instantiate tasks in order
for task_name, task_method in tasks:
for _, task_method in tasks:
task_instance = task_method(self)
instantiated_tasks.append(task_instance)
agent_instance = getattr(task_instance, "agent", None)
@@ -100,7 +195,7 @@ def crew(func) -> Callable[..., Crew]:
agent_roles.add(agent_instance.role)
# Instantiate agents not included by tasks
for agent_name, agent_method in agents:
for _, agent_method in agents:
agent_instance = agent_method(self)
if agent_instance.role not in agent_roles:
instantiated_agents.append(agent_instance)
@@ -109,19 +204,44 @@ def crew(func) -> Callable[..., Crew]:
self.agents = instantiated_agents
self.tasks = instantiated_tasks
crew = func(self, *args, **kwargs)
crew_instance = meth(self, *args, **kwargs)
def callback_wrapper(callback, instance):
def wrapper(*args, **kwargs):
return callback(instance, *args, **kwargs)
def callback_wrapper(
hook: Callable[Concatenate[CrewInstance, P2], R2], instance: CrewInstance
) -> Callable[P2, R2]:
"""Bind a hook callback to an instance.
return wrapper
Args:
hook: The callback hook to bind.
instance: The instance to bind to.
for _, callback in self._before_kickoff.items():
crew.before_kickoff_callbacks.append(callback_wrapper(callback, self))
for _, callback in self._after_kickoff.items():
crew.after_kickoff_callbacks.append(callback_wrapper(callback, self))
Returns:
A bound callback function.
"""
return crew
def bound_callback(*cb_args: P2.args, **cb_kwargs: P2.kwargs) -> R2:
"""Execute the bound callback.
Args:
*cb_args: Positional arguments for the callback.
**cb_kwargs: Keyword arguments for the callback.
Returns:
The result of the callback execution.
"""
return hook(instance, *cb_args, **cb_kwargs)
return bound_callback
for hook_callback in self.__crew_metadata__["before_kickoff"].values():
crew_instance.before_kickoff_callbacks.append(
callback_wrapper(hook_callback, self)
)
for hook_callback in self.__crew_metadata__["after_kickoff"].values():
crew_instance.after_kickoff_callbacks.append(
callback_wrapper(hook_callback, self)
)
return crew_instance
return memoize(wrapper)

View File

@@ -1,298 +1,631 @@
"""Base metaclass for creating crew classes with configuration and method management."""
from __future__ import annotations
import inspect
import logging
from collections.abc import Callable
from pathlib import Path
from typing import Any, TypeVar, cast
from typing import TYPE_CHECKING, Any, Literal, TypedDict, TypeGuard, TypeVar, cast
import yaml
from dotenv import load_dotenv
from crewai.project.wrappers import CrewClass, CrewMetadata
from crewai.tools import BaseTool
if TYPE_CHECKING:
from crewai import Agent, Task
from crewai.agents.cache.cache_handler import CacheHandler
from crewai.crews.crew_output import CrewOutput
from crewai.project.wrappers import (
CrewInstance,
OutputJsonClass,
OutputPydanticClass,
)
from crewai.tasks.task_output import TaskOutput
class AgentConfig(TypedDict, total=False):
"""Type definition for agent configuration dictionary.
All fields are optional as they come from YAML configuration files.
Fields can be either string references (from YAML) or actual instances (after processing).
"""
# Core agent attributes (from BaseAgent)
role: str
goal: str
backstory: str
cache: bool
verbose: bool
max_rpm: int
allow_delegation: bool
max_iter: int
max_tokens: int
callbacks: list[str]
# LLM configuration
llm: str
function_calling_llm: str
use_system_prompt: bool
# Template configuration
system_template: str
prompt_template: str
response_template: str
# Tools and handlers (can be string references or instances)
tools: list[str] | list[BaseTool]
step_callback: str
cache_handler: str | CacheHandler
# Code execution
allow_code_execution: bool
code_execution_mode: Literal["safe", "unsafe"]
# Context and performance
respect_context_window: bool
max_retry_limit: int
# Multimodal and reasoning
multimodal: bool
reasoning: bool
max_reasoning_attempts: int
# Knowledge configuration
knowledge_sources: list[str] | list[Any]
knowledge_storage: str | Any
knowledge_config: dict[str, Any]
embedder: dict[str, Any]
agent_knowledge_context: str
crew_knowledge_context: str
knowledge_search_query: str
# Misc configuration
inject_date: bool
date_format: str
from_repository: str
guardrail: Callable[[Any], tuple[bool, Any]] | str
guardrail_max_retries: int
class TaskConfig(TypedDict, total=False):
"""Type definition for task configuration dictionary.
All fields are optional as they come from YAML configuration files.
Fields can be either string references (from YAML) or actual instances (after processing).
"""
# Core task attributes
name: str
description: str
expected_output: str
# Agent and context
agent: str
context: list[str]
# Tools and callbacks (can be string references or instances)
tools: list[str] | list[BaseTool]
callback: str
callbacks: list[str]
# Output configuration
output_json: str
output_pydantic: str
output_file: str
create_directory: bool
# Execution configuration
async_execution: bool
human_input: bool
markdown: bool
# Guardrail configuration
guardrail: Callable[[TaskOutput], tuple[bool, Any]] | str
guardrail_max_retries: int
# Misc configuration
allow_crewai_trigger_context: bool
load_dotenv()
T = TypeVar("T", bound=type)
"""Base decorator for creating crew classes with configuration and function management."""
CallableT = TypeVar("CallableT", bound=Callable[..., Any])
def CrewBase(cls: T) -> T: # noqa: N802
"""Wraps a class with crew functionality and configuration management."""
def _set_base_directory(cls: type[CrewClass]) -> None:
"""Set the base directory for the crew class.
class WrappedClass(cls): # type: ignore
is_crew_class: bool = True # type: ignore
Args:
cls: Crew class to configure.
"""
try:
cls.base_directory = Path(inspect.getfile(cls)).parent
except (TypeError, OSError):
cls.base_directory = Path.cwd()
# Get the directory of the class being decorated
base_directory = Path(inspect.getfile(cls)).parent
original_agents_config_path = getattr(
cls, "agents_config", "config/agents.yaml"
def _set_config_paths(cls: type[CrewClass]) -> None:
"""Set the configuration file paths for the crew class.
Args:
cls: Crew class to configure.
"""
cls.original_agents_config_path = getattr(
cls, "agents_config", "config/agents.yaml"
)
cls.original_tasks_config_path = getattr(cls, "tasks_config", "config/tasks.yaml")
def _set_mcp_params(cls: type[CrewClass]) -> None:
"""Set the MCP server parameters for the crew class.
Args:
cls: Crew class to configure.
"""
cls.mcp_server_params = getattr(cls, "mcp_server_params", None)
cls.mcp_connect_timeout = getattr(cls, "mcp_connect_timeout", 30)
def _is_string_list(value: list[str] | list[BaseTool]) -> TypeGuard[list[str]]:
"""Type guard to check if list contains strings rather than BaseTool instances.
Args:
value: List that may contain strings or BaseTool instances.
Returns:
True if all elements are strings, False otherwise.
"""
return all(isinstance(item, str) for item in value)
def _is_string_value(value: str | CacheHandler) -> TypeGuard[str]:
"""Type guard to check if value is a string rather than a CacheHandler instance.
Args:
value: Value that may be a string or CacheHandler instance.
Returns:
True if value is a string, False otherwise.
"""
return isinstance(value, str)
class CrewBaseMeta(type):
"""Metaclass that adds crew functionality to classes."""
def __new__(
mcs,
name: str,
bases: tuple[type, ...],
namespace: dict[str, Any],
**kwargs: Any,
) -> type[CrewClass]:
"""Create crew class with configuration and method injection.
Args:
name: Class name.
bases: Base classes.
namespace: Class namespace dictionary.
**kwargs: Additional keyword arguments.
Returns:
New crew class with injected methods and attributes.
"""
cls = cast(
type[CrewClass], cast(object, super().__new__(mcs, name, bases, namespace))
)
original_tasks_config_path = getattr(cls, "tasks_config", "config/tasks.yaml")
mcp_server_params: Any = getattr(cls, "mcp_server_params", None)
mcp_connect_timeout: int = getattr(cls, "mcp_connect_timeout", 30)
cls.is_crew_class = True
cls._crew_name = name
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.load_configurations()
self.map_all_agent_variables()
self.map_all_task_variables()
# Preserve all decorated functions
self._original_functions = {
name: method
for name, method in cls.__dict__.items()
if any(
hasattr(method, attr)
for attr in [
"is_task",
"is_agent",
"is_before_kickoff",
"is_after_kickoff",
"is_kickoff",
]
)
}
# Store specific function types
self._original_tasks = self._filter_functions(
self._original_functions, "is_task"
)
self._original_agents = self._filter_functions(
self._original_functions, "is_agent"
)
self._before_kickoff = self._filter_functions(
self._original_functions, "is_before_kickoff"
)
self._after_kickoff = self._filter_functions(
self._original_functions, "is_after_kickoff"
)
self._kickoff = self._filter_functions(
self._original_functions, "is_kickoff"
)
for setup_fn in _CLASS_SETUP_FUNCTIONS:
setup_fn(cls)
# Add close mcp server method to after kickoff
bound_method = self._create_close_mcp_server_method()
self._after_kickoff['_close_mcp_server'] = bound_method
for method in _METHODS_TO_INJECT:
setattr(cls, method.__name__, method)
def _create_close_mcp_server_method(self):
def _close_mcp_server(self, instance, outputs):
adapter = getattr(self, '_mcp_server_adapter', None)
if adapter is not None:
try:
adapter.stop()
except Exception as e:
logging.warning(f"Error stopping MCP server: {e}")
return outputs
return cls
_close_mcp_server.is_after_kickoff = True
def __call__(cls, *args: Any, **kwargs: Any) -> CrewInstance:
"""Intercept instance creation to initialize crew functionality.
import types
return types.MethodType(_close_mcp_server, self)
Args:
*args: Positional arguments for instance creation.
**kwargs: Keyword arguments for instance creation.
def get_mcp_tools(self, *tool_names: list[str]) -> list[BaseTool]:
if not self.mcp_server_params:
return []
Returns:
Initialized crew instance.
"""
instance: CrewInstance = super().__call__(*args, **kwargs)
CrewBaseMeta._initialize_crew_instance(instance, cls)
return instance
from crewai_tools import MCPServerAdapter # type: ignore[import-untyped]
@staticmethod
def _initialize_crew_instance(instance: CrewInstance, cls: type) -> None:
"""Initialize crew instance attributes and load configurations.
adapter = getattr(self, '_mcp_server_adapter', None)
if not adapter:
self._mcp_server_adapter = MCPServerAdapter(
self.mcp_server_params,
connect_timeout=self.mcp_connect_timeout
)
Args:
instance: Crew instance to initialize.
cls: Crew class type.
"""
instance._mcp_server_adapter = None
instance.load_configurations()
instance._all_methods = _get_all_methods(instance)
instance.map_all_agent_variables()
instance.map_all_task_variables()
return self._mcp_server_adapter.tools.filter_by_names(tool_names or None)
def load_configurations(self):
"""Load agent and task configurations from YAML files."""
if isinstance(self.original_agents_config_path, str):
agents_config_path = (
self.base_directory / self.original_agents_config_path
)
try:
self.agents_config = self.load_yaml(agents_config_path)
except FileNotFoundError:
logging.warning(
f"Agent config file not found at {agents_config_path}. "
"Proceeding with empty agent configurations."
)
self.agents_config = {}
else:
logging.warning(
"No agent configuration path provided. Proceeding with empty agent configurations."
)
self.agents_config = {}
if isinstance(self.original_tasks_config_path, str):
tasks_config_path = (
self.base_directory / self.original_tasks_config_path
)
try:
self.tasks_config = self.load_yaml(tasks_config_path)
except FileNotFoundError:
logging.warning(
f"Task config file not found at {tasks_config_path}. "
"Proceeding with empty task configurations."
)
self.tasks_config = {}
else:
logging.warning(
"No task configuration path provided. Proceeding with empty task configurations."
)
self.tasks_config = {}
@staticmethod
def load_yaml(config_path: Path):
try:
with open(config_path, "r", encoding="utf-8") as file:
return yaml.safe_load(file)
except FileNotFoundError:
print(f"File not found: {config_path}")
raise
def _get_all_functions(self):
return {
name: getattr(self, name)
for name in dir(self)
if callable(getattr(self, name))
}
def _filter_functions(
self, functions: dict[str, Callable], attribute: str
) -> dict[str, Callable]:
return {
name: func
for name, func in functions.items()
if hasattr(func, attribute)
}
def map_all_agent_variables(self) -> None:
all_functions = self._get_all_functions()
llms = self._filter_functions(all_functions, "is_llm")
tool_functions = self._filter_functions(all_functions, "is_tool")
cache_handler_functions = self._filter_functions(
all_functions, "is_cache_handler"
)
callbacks = self._filter_functions(all_functions, "is_callback")
for agent_name, agent_info in self.agents_config.items():
self._map_agent_variables(
agent_name,
agent_info,
llms,
tool_functions,
cache_handler_functions,
callbacks,
)
def _map_agent_variables(
self,
agent_name: str,
agent_info: dict[str, Any],
llms: dict[str, Callable],
tool_functions: dict[str, Callable],
cache_handler_functions: dict[str, Callable],
callbacks: dict[str, Callable],
) -> None:
if llm := agent_info.get("llm"):
try:
self.agents_config[agent_name]["llm"] = llms[llm]()
except KeyError:
self.agents_config[agent_name]["llm"] = llm
if tools := agent_info.get("tools"):
self.agents_config[agent_name]["tools"] = [
tool_functions[tool]() for tool in tools
original_methods = {
name: method
for name, method in cls.__dict__.items()
if any(
hasattr(method, attr)
for attr in [
"is_task",
"is_agent",
"is_before_kickoff",
"is_after_kickoff",
"is_kickoff",
]
if function_calling_llm := agent_info.get("function_calling_llm"):
try:
self.agents_config[agent_name]["function_calling_llm"] = llms[function_calling_llm]()
except KeyError:
self.agents_config[agent_name]["function_calling_llm"] = function_calling_llm
if step_callback := agent_info.get("step_callback"):
self.agents_config[agent_name]["step_callback"] = callbacks[
step_callback
]()
if cache_handler := agent_info.get("cache_handler"):
self.agents_config[agent_name]["cache_handler"] = (
cache_handler_functions[cache_handler]()
)
def map_all_task_variables(self) -> None:
all_functions = self._get_all_functions()
agents = self._filter_functions(all_functions, "is_agent")
tasks = self._filter_functions(all_functions, "is_task")
output_json_functions = self._filter_functions(
all_functions, "is_output_json"
)
tool_functions = self._filter_functions(all_functions, "is_tool")
callback_functions = self._filter_functions(all_functions, "is_callback")
output_pydantic_functions = self._filter_functions(
all_functions, "is_output_pydantic"
}
after_kickoff_callbacks = _filter_methods(original_methods, "is_after_kickoff")
after_kickoff_callbacks["close_mcp_server"] = instance.close_mcp_server
instance.__crew_metadata__ = CrewMetadata(
original_methods=original_methods,
original_tasks=_filter_methods(original_methods, "is_task"),
original_agents=_filter_methods(original_methods, "is_agent"),
before_kickoff=_filter_methods(original_methods, "is_before_kickoff"),
after_kickoff=after_kickoff_callbacks,
kickoff=_filter_methods(original_methods, "is_kickoff"),
)
def close_mcp_server(
self: CrewInstance, _instance: CrewInstance, outputs: CrewOutput
) -> CrewOutput:
"""Stop MCP server adapter and return outputs.
Args:
self: Crew instance with MCP server adapter.
_instance: Crew instance (unused, required by callback signature).
outputs: Crew execution outputs.
Returns:
Unmodified crew outputs.
"""
if self._mcp_server_adapter is not None:
try:
self._mcp_server_adapter.stop()
except Exception as e:
logging.warning(f"Error stopping MCP server: {e}")
return outputs
def get_mcp_tools(self: CrewInstance, *tool_names: str) -> list[BaseTool]:
"""Get MCP tools filtered by name.
Args:
self: Crew instance with MCP server configuration.
*tool_names: Optional tool names to filter by.
Returns:
List of filtered MCP tools, or empty list if no MCP server configured.
"""
if not self.mcp_server_params:
return []
from crewai_tools import MCPServerAdapter # type: ignore[import-untyped]
if self._mcp_server_adapter is None:
self._mcp_server_adapter = MCPServerAdapter(
self.mcp_server_params, connect_timeout=self.mcp_connect_timeout
)
return self._mcp_server_adapter.tools.filter_by_names(tool_names or None)
def _load_config(
self: CrewInstance, config_path: str | None, config_type: Literal["agent", "task"]
) -> dict[str, Any]:
"""Load YAML config file or return empty dict if not found.
Args:
self: Crew instance with base directory and load_yaml method.
config_path: Relative path to config file.
config_type: Config type for logging, either "agent" or "task".
Returns:
Config dictionary or empty dict.
"""
if isinstance(config_path, str):
full_path = self.base_directory / config_path
try:
return self.load_yaml(full_path)
except FileNotFoundError:
logging.warning(
f"{config_type.capitalize()} config file not found at {full_path}. "
f"Proceeding with empty {config_type} configurations."
)
return {}
else:
logging.warning(
f"No {config_type} configuration path provided. "
f"Proceeding with empty {config_type} configurations."
)
return {}
for task_name, task_info in self.tasks_config.items():
self._map_task_variables(
task_name,
task_info,
agents,
tasks,
output_json_functions,
tool_functions,
callback_functions,
output_pydantic_functions,
)
def _map_task_variables(
self,
task_name: str,
task_info: dict[str, Any],
agents: dict[str, Callable],
tasks: dict[str, Callable],
output_json_functions: dict[str, Callable],
tool_functions: dict[str, Callable],
callback_functions: dict[str, Callable],
output_pydantic_functions: dict[str, Callable],
) -> None:
if context_list := task_info.get("context"):
self.tasks_config[task_name]["context"] = [
tasks[context_task_name]() for context_task_name in context_list
]
def load_configurations(self: CrewInstance) -> None:
"""Load agent and task YAML configurations.
if tools := task_info.get("tools"):
self.tasks_config[task_name]["tools"] = [
tool_functions[tool]() for tool in tools
]
Args:
self: Crew instance with configuration paths.
"""
self.agents_config = self._load_config(self.original_agents_config_path, "agent")
self.tasks_config = self._load_config(self.original_tasks_config_path, "task")
if agent_name := task_info.get("agent"):
self.tasks_config[task_name]["agent"] = agents[agent_name]()
if output_json := task_info.get("output_json"):
self.tasks_config[task_name]["output_json"] = output_json_functions[
output_json
]
def load_yaml(config_path: Path) -> dict[str, Any]:
"""Load and parse YAML configuration file.
if output_pydantic := task_info.get("output_pydantic"):
self.tasks_config[task_name]["output_pydantic"] = (
output_pydantic_functions[output_pydantic]
)
Args:
config_path: Path to YAML configuration file.
if callbacks := task_info.get("callbacks"):
self.tasks_config[task_name]["callbacks"] = [
callback_functions[callback]() for callback in callbacks
]
Returns:
Parsed YAML content as a dictionary. Returns empty dict if file is empty.
if guardrail := task_info.get("guardrail"):
self.tasks_config[task_name]["guardrail"] = guardrail
Raises:
FileNotFoundError: If config file does not exist.
"""
try:
with open(config_path, encoding="utf-8") as file:
content = yaml.safe_load(file)
return content if isinstance(content, dict) else {}
except FileNotFoundError:
logging.warning(f"File not found: {config_path}")
raise
# Include base class (qual)name in the wrapper class (qual)name.
WrappedClass.__name__ = CrewBase.__name__ + "(" + cls.__name__ + ")"
WrappedClass.__qualname__ = CrewBase.__qualname__ + "(" + cls.__name__ + ")"
WrappedClass._crew_name = cls.__name__
return cast(T, WrappedClass)
def _get_all_methods(self: CrewInstance) -> dict[str, Callable[..., Any]]:
"""Return all non-dunder callable attributes (methods).
Args:
self: Instance to inspect for callable attributes.
Returns:
Dictionary mapping method names to bound method objects.
"""
return {
name: getattr(self, name)
for name in dir(self)
if not (name.startswith("__") and name.endswith("__"))
and callable(getattr(self, name, None))
}
def _filter_methods(
methods: dict[str, CallableT], attribute: str
) -> dict[str, CallableT]:
"""Filter methods by attribute presence, preserving exact callable types.
Args:
methods: Dictionary of methods to filter.
attribute: Attribute name to check for.
Returns:
Dictionary containing only methods with the specified attribute.
The return type matches the input callable type exactly.
"""
return {
name: method for name, method in methods.items() if hasattr(method, attribute)
}
def map_all_agent_variables(self: CrewInstance) -> None:
"""Map agent configuration variables to callable instances.
Args:
self: Crew instance with agent configurations to map.
"""
llms = _filter_methods(self._all_methods, "is_llm")
tool_functions = _filter_methods(self._all_methods, "is_tool")
cache_handler_functions = _filter_methods(self._all_methods, "is_cache_handler")
callbacks = _filter_methods(self._all_methods, "is_callback")
for agent_name, agent_info in self.agents_config.items():
self._map_agent_variables(
agent_name=agent_name,
agent_info=agent_info,
llms=llms,
tool_functions=tool_functions,
cache_handler_functions=cache_handler_functions,
callbacks=callbacks,
)
def _map_agent_variables(
self: CrewInstance,
agent_name: str,
agent_info: AgentConfig,
llms: dict[str, Callable[[], Any]],
tool_functions: dict[str, Callable[[], BaseTool]],
cache_handler_functions: dict[str, Callable[[], Any]],
callbacks: dict[str, Callable[..., Any]],
) -> None:
"""Resolve and map variables for a single agent.
Args:
self: Crew instance with agent configurations.
agent_name: Name of agent to configure.
agent_info: Agent configuration dictionary with optional fields.
llms: Dictionary mapping names to LLM factory functions.
tool_functions: Dictionary mapping names to tool factory functions.
cache_handler_functions: Dictionary mapping names to cache handler factory functions.
callbacks: Dictionary of available callbacks.
"""
if llm := agent_info.get("llm"):
factory = llms.get(llm)
self.agents_config[agent_name]["llm"] = factory() if factory else llm
if tools := agent_info.get("tools"):
if _is_string_list(tools):
self.agents_config[agent_name]["tools"] = [
tool_functions[tool]() for tool in tools
]
if function_calling_llm := agent_info.get("function_calling_llm"):
factory = llms.get(function_calling_llm)
self.agents_config[agent_name]["function_calling_llm"] = (
factory() if factory else function_calling_llm
)
if step_callback := agent_info.get("step_callback"):
self.agents_config[agent_name]["step_callback"] = callbacks[step_callback]()
if cache_handler := agent_info.get("cache_handler"):
if _is_string_value(cache_handler):
self.agents_config[agent_name]["cache_handler"] = cache_handler_functions[
cache_handler
]()
def map_all_task_variables(self: CrewInstance) -> None:
"""Map task configuration variables to callable instances.
Args:
self: Crew instance with task configurations to map.
"""
agents = _filter_methods(self._all_methods, "is_agent")
tasks = _filter_methods(self._all_methods, "is_task")
output_json_functions = _filter_methods(self._all_methods, "is_output_json")
tool_functions = _filter_methods(self._all_methods, "is_tool")
callback_functions = _filter_methods(self._all_methods, "is_callback")
output_pydantic_functions = _filter_methods(self._all_methods, "is_output_pydantic")
for task_name, task_info in self.tasks_config.items():
self._map_task_variables(
task_name=task_name,
task_info=task_info,
agents=agents,
tasks=tasks,
output_json_functions=output_json_functions,
tool_functions=tool_functions,
callback_functions=callback_functions,
output_pydantic_functions=output_pydantic_functions,
)
def _map_task_variables(
self: CrewInstance,
task_name: str,
task_info: TaskConfig,
agents: dict[str, Callable[[], Agent]],
tasks: dict[str, Callable[[], Task]],
output_json_functions: dict[str, OutputJsonClass[Any]],
tool_functions: dict[str, Callable[[], BaseTool]],
callback_functions: dict[str, Callable[..., Any]],
output_pydantic_functions: dict[str, OutputPydanticClass[Any]],
) -> None:
"""Resolve and map variables for a single task.
Args:
self: Crew instance with task configurations.
task_name: Name of task to configure.
task_info: Task configuration dictionary with optional fields.
agents: Dictionary mapping names to agent factory functions.
tasks: Dictionary mapping names to task factory functions.
output_json_functions: Dictionary of JSON output class wrappers.
tool_functions: Dictionary mapping names to tool factory functions.
callback_functions: Dictionary of available callbacks.
output_pydantic_functions: Dictionary of Pydantic output class wrappers.
"""
if context_list := task_info.get("context"):
self.tasks_config[task_name]["context"] = [
tasks[context_task_name]() for context_task_name in context_list
]
if tools := task_info.get("tools"):
if _is_string_list(tools):
self.tasks_config[task_name]["tools"] = [
tool_functions[tool]() for tool in tools
]
if agent_name := task_info.get("agent"):
self.tasks_config[task_name]["agent"] = agents[agent_name]()
if output_json := task_info.get("output_json"):
self.tasks_config[task_name]["output_json"] = output_json_functions[output_json]
if output_pydantic := task_info.get("output_pydantic"):
self.tasks_config[task_name]["output_pydantic"] = output_pydantic_functions[
output_pydantic
]
if callbacks := task_info.get("callbacks"):
self.tasks_config[task_name]["callbacks"] = [
callback_functions[callback]() for callback in callbacks
]
if guardrail := task_info.get("guardrail"):
self.tasks_config[task_name]["guardrail"] = guardrail
_CLASS_SETUP_FUNCTIONS: tuple[Callable[[type[CrewClass]], None], ...] = (
_set_base_directory,
_set_config_paths,
_set_mcp_params,
)
_METHODS_TO_INJECT = (
close_mcp_server,
get_mcp_tools,
_load_config,
load_configurations,
staticmethod(load_yaml),
map_all_agent_variables,
_map_agent_variables,
map_all_task_variables,
_map_task_variables,
)
class _CrewBaseType(type):
"""Metaclass for CrewBase that makes it callable as a decorator."""
def __call__(cls, decorated_cls: type) -> type[CrewClass]:
"""Apply CrewBaseMeta to the decorated class.
Args:
decorated_cls: Class to transform with CrewBaseMeta metaclass.
Returns:
New class with CrewBaseMeta metaclass applied.
"""
__name = str(decorated_cls.__name__)
__bases = tuple(decorated_cls.__bases__)
__dict = {
key: value
for key, value in decorated_cls.__dict__.items()
if key not in ("__dict__", "__weakref__")
}
for slot in __dict.get("__slots__", tuple()):
__dict.pop(slot, None)
__dict["__metaclass__"] = CrewBaseMeta
return cast(type[CrewClass], CrewBaseMeta(__name, __bases, __dict))
class CrewBase(metaclass=_CrewBaseType):
"""Class decorator that applies CrewBaseMeta metaclass.
Applies CrewBaseMeta metaclass to a class via decorator syntax rather than
explicit metaclass declaration. Use as @CrewBase instead of
class Foo(metaclass=CrewBaseMeta).
Note:
Reference: https://stackoverflow.com/questions/11091609/setting-a-class-metaclass-using-a-decorator
"""

View File

@@ -1,14 +1,38 @@
"""Utility functions for the crewai project module."""
from collections.abc import Callable
from functools import wraps
from typing import Any, ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")
def memoize(func):
cache = {}
def memoize(meth: Callable[P, R]) -> Callable[P, R]:
"""Memoize a method by caching its results based on arguments.
@wraps(func)
def memoized_func(*args, **kwargs):
Args:
meth: The method to memoize.
Returns:
A memoized version of the method that caches results.
"""
cache: dict[Any, R] = {}
@wraps(meth)
def memoized_func(*args: P.args, **kwargs: P.kwargs) -> R:
"""Memoized wrapper method.
Args:
*args: Positional arguments to pass to the method.
**kwargs: Keyword arguments to pass to the method.
Returns:
The cached or computed result of the method.
"""
key = (args, tuple(kwargs.items()))
if key not in cache:
cache[key] = func(*args, **kwargs)
cache[key] = meth(*args, **kwargs)
return cache[key]
return memoized_func

View File

@@ -0,0 +1,388 @@
"""Wrapper classes for decorated methods with type-safe metadata."""
from __future__ import annotations
from collections.abc import Callable
from functools import partial
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Generic,
Literal,
ParamSpec,
Protocol,
TypedDict,
TypeVar,
)
from typing_extensions import Self
if TYPE_CHECKING:
from crewai import Agent, Task
from crewai.crews.crew_output import CrewOutput
from crewai.tools import BaseTool
class CrewMetadata(TypedDict):
"""Type definition for crew metadata dictionary.
Stores framework-injected metadata about decorated methods and callbacks.
"""
original_methods: dict[str, Callable[..., Any]]
original_tasks: dict[str, Callable[..., Task]]
original_agents: dict[str, Callable[..., Agent]]
before_kickoff: dict[str, Callable[..., Any]]
after_kickoff: dict[str, Callable[..., Any]]
kickoff: dict[str, Callable[..., Any]]
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
class TaskResult(Protocol):
"""Protocol for task objects that have a name attribute."""
name: str | None
TaskResultT = TypeVar("TaskResultT", bound=TaskResult)
def _copy_method_metadata(wrapper: Any, meth: Callable[..., Any]) -> None:
"""Copy method metadata to a wrapper object.
Args:
wrapper: The wrapper object to update.
meth: The method to copy metadata from.
"""
wrapper.__name__ = meth.__name__
wrapper.__doc__ = meth.__doc__
class CrewInstance(Protocol):
"""Protocol for crew class instances with required attributes."""
__crew_metadata__: CrewMetadata
_mcp_server_adapter: Any
_all_methods: dict[str, Callable[..., Any]]
agents: list[Agent]
tasks: list[Task]
base_directory: Path
original_agents_config_path: str
original_tasks_config_path: str
agents_config: dict[str, Any]
tasks_config: dict[str, Any]
mcp_server_params: Any
mcp_connect_timeout: int
def load_configurations(self) -> None: ...
def map_all_agent_variables(self) -> None: ...
def map_all_task_variables(self) -> None: ...
def close_mcp_server(self, instance: Self, outputs: CrewOutput) -> CrewOutput: ...
def _load_config(
self, config_path: str | None, config_type: Literal["agent", "task"]
) -> dict[str, Any]: ...
def _map_agent_variables(
self,
agent_name: str,
agent_info: dict[str, Any],
llms: dict[str, Callable[..., Any]],
tool_functions: dict[str, Callable[..., Any]],
cache_handler_functions: dict[str, Callable[..., Any]],
callbacks: dict[str, Callable[..., Any]],
) -> None: ...
def _map_task_variables(
self,
task_name: str,
task_info: dict[str, Any],
agents: dict[str, Callable[..., Any]],
tasks: dict[str, Callable[..., Any]],
output_json_functions: dict[str, Callable[..., Any]],
tool_functions: dict[str, Callable[..., Any]],
callback_functions: dict[str, Callable[..., Any]],
output_pydantic_functions: dict[str, Callable[..., Any]],
) -> None: ...
def load_yaml(self, config_path: Path) -> dict[str, Any]: ...
class CrewClass(Protocol):
"""Protocol describing class attributes injected by CrewBaseMeta."""
is_crew_class: bool
_crew_name: str
base_directory: Path
original_agents_config_path: str
original_tasks_config_path: str
mcp_server_params: Any
mcp_connect_timeout: int
close_mcp_server: Callable[..., Any]
get_mcp_tools: Callable[..., list[BaseTool]]
_load_config: Callable[..., dict[str, Any]]
load_configurations: Callable[..., None]
load_yaml: staticmethod
map_all_agent_variables: Callable[..., None]
_map_agent_variables: Callable[..., None]
map_all_task_variables: Callable[..., None]
_map_task_variables: Callable[..., None]
class DecoratedMethod(Generic[P, R]):
"""Base wrapper for methods with decorator metadata.
This class provides a type-safe way to add metadata to methods
while preserving their callable signature and attributes.
"""
def __init__(self, meth: Callable[P, R]) -> None:
"""Initialize the decorated method wrapper.
Args:
meth: The method to wrap.
"""
self._meth = meth
_copy_method_metadata(self, meth)
def __get__(
self, obj: Any, objtype: type[Any] | None = None
) -> Self | Callable[..., R]:
"""Support instance methods by implementing the descriptor protocol.
Args:
obj: The instance that the method is accessed through.
objtype: The type of the instance.
Returns:
Self when accessed through class, bound method when accessed through instance.
"""
if obj is None:
return self
bound = partial(self._meth, obj)
for attr in (
"is_agent",
"is_llm",
"is_tool",
"is_callback",
"is_cache_handler",
"is_before_kickoff",
"is_after_kickoff",
"is_crew",
):
if hasattr(self, attr):
setattr(bound, attr, getattr(self, attr))
return bound
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
"""Call the wrapped method.
Args:
*args: Positional arguments.
**kwargs: Keyword arguments.
Returns:
The result of calling the wrapped method.
"""
return self._meth(*args, **kwargs)
def unwrap(self) -> Callable[P, R]:
"""Get the original unwrapped method.
Returns:
The original method before decoration.
"""
return self._meth
class BeforeKickoffMethod(DecoratedMethod[P, R]):
"""Wrapper for methods marked to execute before crew kickoff."""
is_before_kickoff: bool = True
class AfterKickoffMethod(DecoratedMethod[P, R]):
"""Wrapper for methods marked to execute after crew kickoff."""
is_after_kickoff: bool = True
class BoundTaskMethod(Generic[TaskResultT]):
"""Bound task method with task marker attribute."""
is_task: bool = True
def __init__(self, task_method: TaskMethod[Any, TaskResultT], obj: Any) -> None:
"""Initialize the bound task method.
Args:
task_method: The TaskMethod descriptor instance.
obj: The instance to bind to.
"""
self._task_method = task_method
self._obj = obj
def __call__(self, *args: Any, **kwargs: Any) -> TaskResultT:
"""Execute the bound task method.
Args:
*args: Positional arguments.
**kwargs: Keyword arguments.
Returns:
The task result with name ensured.
"""
result = self._task_method.unwrap()(self._obj, *args, **kwargs)
return self._task_method.ensure_task_name(result)
class TaskMethod(Generic[P, TaskResultT]):
"""Wrapper for methods marked as crew tasks."""
is_task: bool = True
def __init__(self, meth: Callable[P, TaskResultT]) -> None:
"""Initialize the task method wrapper.
Args:
meth: The method to wrap.
"""
self._meth = meth
_copy_method_metadata(self, meth)
def ensure_task_name(self, result: TaskResultT) -> TaskResultT:
"""Ensure task result has a name set.
Args:
result: The task result to check.
Returns:
The task result with name ensured.
"""
if not result.name:
result.name = self._meth.__name__
return result
def __get__(
self, obj: Any, objtype: type[Any] | None = None
) -> Self | BoundTaskMethod[TaskResultT]:
"""Support instance methods by implementing the descriptor protocol.
Args:
obj: The instance that the method is accessed through.
objtype: The type of the instance.
Returns:
Self when accessed through class, bound method when accessed through instance.
"""
if obj is None:
return self
return BoundTaskMethod(self, obj)
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> TaskResultT:
"""Call the wrapped method and set task name if not provided.
Args:
*args: Positional arguments.
**kwargs: Keyword arguments.
Returns:
The task instance with name set if not already provided.
"""
return self.ensure_task_name(self._meth(*args, **kwargs))
def unwrap(self) -> Callable[P, TaskResultT]:
"""Get the original unwrapped method.
Returns:
The original method before decoration.
"""
return self._meth
class AgentMethod(DecoratedMethod[P, R]):
"""Wrapper for methods marked as crew agents."""
is_agent: bool = True
class LLMMethod(DecoratedMethod[P, R]):
"""Wrapper for methods marked as LLM providers."""
is_llm: bool = True
class ToolMethod(DecoratedMethod[P, R]):
"""Wrapper for methods marked as crew tools."""
is_tool: bool = True
class CallbackMethod(DecoratedMethod[P, R]):
"""Wrapper for methods marked as crew callbacks."""
is_callback: bool = True
class CacheHandlerMethod(DecoratedMethod[P, R]):
"""Wrapper for methods marked as cache handlers."""
is_cache_handler: bool = True
class CrewMethod(DecoratedMethod[P, R]):
"""Wrapper for methods marked as the main crew execution point."""
is_crew: bool = True
class OutputClass(Generic[T]):
"""Base wrapper for classes marked as output format."""
def __init__(self, cls: type[T]) -> None:
"""Initialize the output class wrapper.
Args:
cls: The class to wrap.
"""
self._cls = cls
self.__name__ = cls.__name__
self.__qualname__ = cls.__qualname__
self.__module__ = cls.__module__
self.__doc__ = cls.__doc__
def __call__(self, *args: Any, **kwargs: Any) -> T:
"""Create an instance of the wrapped class.
Args:
*args: Positional arguments for the class constructor.
**kwargs: Keyword arguments for the class constructor.
Returns:
An instance of the wrapped class.
"""
return self._cls(*args, **kwargs)
def __getattr__(self, name: str) -> Any:
"""Delegate attribute access to the wrapped class.
Args:
name: The attribute name.
Returns:
The attribute from the wrapped class.
"""
return getattr(self._cls, name)
class OutputJsonClass(OutputClass[T]):
"""Wrapper for classes marked as JSON output format."""
is_output_json: bool = True
class OutputPydanticClass(OutputClass[T]):
"""Wrapper for classes marked as Pydantic output format."""
is_output_pydantic: bool = True

View File

@@ -1,11 +1,6 @@
"""Utility for colored console output."""
from __future__ import annotations
from typing import TYPE_CHECKING, Final, Literal, NamedTuple
if TYPE_CHECKING:
from _typeshed import SupportsWrite
from typing import Final, Literal, NamedTuple
PrinterColor = Literal[
"purple",
@@ -59,22 +54,13 @@ class Printer:
@staticmethod
def print(
content: str | list[ColoredText],
color: PrinterColor | None = None,
sep: str | None = " ",
end: str | None = "\n",
file: SupportsWrite[str] | None = None,
flush: Literal[False] = False,
content: str | list[ColoredText], color: PrinterColor | None = None
) -> None:
"""Prints content to the console with optional color formatting.
Args:
content: Either a string or a list of ColoredText objects for multicolor output.
color: Optional color for the text when content is a string. Ignored when content is a list.
sep: Separator to use between the text and color.
end: String appended after the last value.
file: A file-like object (stream); defaults to the current sys.stdout.
flush: Whether to forcibly flush the stream.
"""
if isinstance(content, str):
content = [ColoredText(content, color)]
@@ -82,9 +68,5 @@ class Printer:
"".join(
f"{_COLOR_CODES[c.color] if c.color else ''}{c.text}{RESET}"
for c in content
),
sep=sep,
end=end,
file=file,
flush=flush,
)
)

View File

@@ -4744,81 +4744,3 @@ def test_ensure_exchanged_messages_are_propagated_to_external_memory():
assert "Researcher" in messages[0]["content"]
assert messages[1]["role"] == "user"
assert "Research a topic to teach a kid aged 6 about math" in messages[1]["content"]
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_kickoff_stream(researcher):
"""Test that crew.kickoff_stream() yields events during execution."""
task = Task(
description="Research a topic about AI",
expected_output="A brief summary about AI",
agent=researcher,
)
crew = Crew(agents=[researcher], tasks=[task])
events = list(crew.kickoff_stream())
assert len(events) > 0
event_types = [event["type"] for event in events]
assert "crew_kickoff_started" in event_types
assert "final_output" in event_types
final_output_event = next(e for e in events if e["type"] == "final_output")
assert "output" in final_output_event["data"]
assert isinstance(final_output_event["data"]["output"], str)
assert len(final_output_event["data"]["output"]) > 0
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_kickoff_stream_with_inputs(researcher):
"""Test that crew.kickoff_stream() works with inputs."""
task = Task(
description="Research about {topic}",
expected_output="A brief summary about {topic}",
agent=researcher,
)
crew = Crew(agents=[researcher], tasks=[task])
events = list(crew.kickoff_stream(inputs={"topic": "machine learning"}))
assert len(events) > 0
event_types = [event["type"] for event in events]
assert "crew_kickoff_started" in event_types
assert "final_output" in event_types
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_kickoff_stream_includes_llm_chunks(researcher):
"""Test that crew.kickoff_stream() includes LLM stream chunks."""
task = Task(
description="Write a short poem about AI",
expected_output="A 2-line poem",
agent=researcher,
)
crew = Crew(agents=[researcher], tasks=[task])
events = list(crew.kickoff_stream())
event_types = [event["type"] for event in events]
assert "task_started" in event_types or "agent_execution_started" in event_types
def test_crew_kickoff_stream_handles_errors(researcher):
"""Test that crew.kickoff_stream() properly handles errors."""
task = Task(
description="This task will fail",
expected_output="Should not complete",
agent=researcher,
)
crew = Crew(agents=[researcher], tasks=[task])
with patch("crewai.crew.Crew.kickoff", side_effect=Exception("Test error")):
with pytest.raises(Exception, match="Test error"):
list(crew.kickoff_stream())

View File

@@ -6,15 +6,15 @@ from datetime import datetime
import pytest
from pydantic import BaseModel
from crewai.flow.flow import Flow, and_, listen, or_, router, start
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.flow_events import (
FlowFinishedEvent,
FlowPlotEvent,
FlowStartedEvent,
FlowPlotEvent,
MethodExecutionFinishedEvent,
MethodExecutionStartedEvent,
)
from crewai.flow.flow import Flow, and_, listen, or_, router, start
def test_simple_sequential_flow():
@@ -679,11 +679,11 @@ def test_structured_flow_event_emission():
assert isinstance(received_events[3], MethodExecutionStartedEvent)
assert received_events[3].method_name == "send_welcome_message"
assert received_events[3].params == {}
assert received_events[3].state.sent is False
assert getattr(received_events[3].state, "sent") is False
assert isinstance(received_events[4], MethodExecutionFinishedEvent)
assert received_events[4].method_name == "send_welcome_message"
assert received_events[4].state.sent is True
assert getattr(received_events[4].state, "sent") is True
assert received_events[4].result == "Welcome, Anakin!"
assert isinstance(received_events[5], FlowFinishedEvent)
@@ -894,75 +894,3 @@ def test_flow_name():
flow = MyFlow()
assert flow.name == "MyFlow"
def test_nested_and_or_conditions():
"""Test nested conditions like or_(and_(A, B), and_(C, D)).
Reproduces bug from issue #3719 where nested conditions are flattened,
causing premature execution.
"""
execution_order = []
class NestedConditionFlow(Flow):
@start()
def method_1(self):
execution_order.append("method_1")
@listen(method_1)
def method_2(self):
execution_order.append("method_2")
@router(method_2)
def method_3(self):
execution_order.append("method_3")
# Choose b_condition path
return "b_condition"
@listen("b_condition")
def method_5(self):
execution_order.append("method_5")
@listen(method_5)
async def method_4(self):
execution_order.append("method_4")
@listen(or_("a_condition", "b_condition"))
async def method_6(self):
execution_order.append("method_6")
@listen(
or_(
and_("a_condition", method_6),
and_(method_6, method_4),
)
)
def method_7(self):
execution_order.append("method_7")
@listen(method_7)
async def method_8(self):
execution_order.append("method_8")
flow = NestedConditionFlow()
flow.kickoff()
# Verify execution happened
assert "method_1" in execution_order
assert "method_2" in execution_order
assert "method_3" in execution_order
assert "method_5" in execution_order
assert "method_4" in execution_order
assert "method_6" in execution_order
assert "method_7" in execution_order
assert "method_8" in execution_order
# Critical assertion: method_7 should only execute AFTER both method_6 AND method_4
# Since b_condition was returned, method_6 triggers on b_condition
# method_7 requires: (a_condition AND method_6) OR (method_6 AND method_4)
# The second condition (method_6 AND method_4) should be the one that triggers
assert execution_order.index("method_7") > execution_order.index("method_6")
assert execution_order.index("method_7") > execution_order.index("method_4")
# method_8 should execute after method_7
assert execution_order.index("method_8") > execution_order.index("method_7")