Compare commits

..

5 Commits

Author SHA1 Message Date
Devin AI
2ff47f4bd7 fix: remove trailing whitespace from docstring
Co-Authored-By: João <joao@crewai.com>
2025-10-20 12:26:37 +00:00
Devin AI
0af7e04cde fix: address lint issues in streaming implementation
- Remove whitespace from blank lines
- Refactor try-except out of loop for better performance
- Use list() instead of append in loop for better performance

Co-Authored-By: João <joao@crewai.com>
2025-10-20 12:24:36 +00:00
Devin AI
b664637afa feat: add kickoff_stream method for FastAPI streaming integration
- Add kickoff_stream() method to Crew class that yields events in real-time
- Enables easy integration with FastAPI StreamingResponse
- Add comprehensive tests for streaming functionality
- Include FastAPI example demonstrating Server-Sent Events (SSE)

Resolves #3739

Co-Authored-By: João <joao@crewai.com>
2025-10-20 12:21:07 +00:00
Greyson LaLonde
42f2b4d551 fix: preserve nested condition structure in Flow decorators
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Update Test Durations / update-durations (3.10) (push) Has been cancelled
Update Test Durations / update-durations (3.11) (push) Has been cancelled
Update Test Durations / update-durations (3.12) (push) Has been cancelled
Update Test Durations / update-durations (3.13) (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
Build uv cache / build-cache (3.10) (push) Has been cancelled
Build uv cache / build-cache (3.11) (push) Has been cancelled
Build uv cache / build-cache (3.12) (push) Has been cancelled
Build uv cache / build-cache (3.13) (push) Has been cancelled
Fixes nested boolean conditions being flattened in @listen, @start, and @router decorators. The or_() and and_() combinators now preserve their nested structure using a "conditions" key instead of flattening to a list. Added recursive evaluation logic to properly handle complex patterns like or_(and_(A, B), and_(C, D)).
2025-10-17 17:06:19 -04:00
Greyson LaLonde
0229390ad1 fix: add standard print parameters to Printer.print method
- Adds sep, end, file, and flush parameters to match Python's built-in print function signature.
2025-10-17 15:27:22 -04:00
10 changed files with 742 additions and 600 deletions

View File

@@ -0,0 +1,177 @@
"""
FastAPI Streaming Integration Example for CrewAI
This example demonstrates how to integrate CrewAI with FastAPI to stream
crew execution events in real-time using Server-Sent Events (SSE).
Installation:
pip install crewai fastapi uvicorn
Usage:
python fastapi_streaming_example.py
Then visit:
http://localhost:8000/docs for the API documentation
http://localhost:8000/stream?topic=AI to see streaming in action
"""
import asyncio
import json
from typing import AsyncGenerator

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from crewai import Agent, Crew, Task
app = FastAPI(title="CrewAI Streaming API")
class ResearchRequest(BaseModel):
    """Request body for the POST /research endpoint."""

    # Research topic the crew should investigate.
    topic: str
def create_research_crew(topic: str) -> Crew:
    """Build a single-agent, single-task research crew for the given topic."""
    research_agent = Agent(
        role="Researcher",
        goal=f"Research and analyze information about {topic}",
        backstory="You're an expert researcher with deep knowledge in various fields.",
        verbose=True,
    )
    research_task = Task(
        description=f"Research and provide a comprehensive summary about {topic}",
        expected_output="A detailed summary with key insights",
        agent=research_agent,
    )
    return Crew(agents=[research_agent], tasks=[research_task], verbose=True)
@app.get("/")
async def root():
    """Root endpoint describing the available API routes."""
    endpoints = {
        "/stream": "GET - Stream crew execution events (query param: topic)",
        "/research": "POST - Execute crew and return final result",
    }
    return {"message": "CrewAI Streaming API", "endpoints": endpoints}
@app.get("/stream")
async def stream_crew_execution(topic: str = "artificial intelligence"):
    """
    Stream crew execution events in real-time using Server-Sent Events.

    Args:
        topic: The research topic (default: "artificial intelligence")

    Returns:
        StreamingResponse with text/event-stream content type
    """

    async def event_generator() -> AsyncGenerator[str, None]:
        """Generate Server-Sent Events from crew execution."""
        crew = create_research_crew(topic)
        try:
            stream = crew.kickoff_stream(inputs={"topic": topic})
            while True:
                # kickoff_stream is a blocking (thread+queue backed) generator;
                # advancing it on a worker thread keeps the event loop responsive
                # instead of stalling every other request while we wait.
                event = await asyncio.to_thread(next, stream, None)
                if event is None:
                    break
                event_data = json.dumps(event)
                yield f"data: {event_data}\n\n"
            yield "data: {\"type\": \"done\"}\n\n"
        except Exception as e:
            error_event = {"type": "error", "data": {"message": str(e)}}
            yield f"data: {json.dumps(error_event)}\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Disable proxy (nginx) buffering so events flush immediately.
            "X-Accel-Buffering": "no",
        },
    )
@app.post("/research")
async def research_topic(request: ResearchRequest):
    """
    Execute crew research and return the final result.

    Args:
        request: ResearchRequest with topic field

    Returns:
        JSON response with the research result
    """
    crew = create_research_crew(request.topic)
    try:
        output = crew.kickoff(inputs={"topic": request.topic})
        usage = output.token_usage.model_dump() if output.token_usage else None
        return {
            "success": True,
            "topic": request.topic,
            "result": output.raw,
            "usage_metrics": usage,
        }
    except Exception as e:
        return {"success": False, "error": str(e)}
@app.get("/stream-filtered")
async def stream_filtered_events(
    topic: str = "artificial intelligence", event_types: str = "llm_stream_chunk"
):
    """
    Stream only specific event types.

    Args:
        topic: The research topic
        event_types: Comma-separated list of event types to include

    Returns:
        StreamingResponse with filtered events
    """
    allowed_types = set(event_types.split(","))

    async def event_generator() -> AsyncGenerator[str, None]:
        crew = create_research_crew(topic)
        try:
            stream = crew.kickoff_stream(inputs={"topic": topic})
            while True:
                # Advance the blocking generator off the event loop thread so
                # concurrent requests are not stalled while waiting for events.
                event = await asyncio.to_thread(next, stream, None)
                if event is None:
                    break
                if event["type"] in allowed_types:
                    event_data = json.dumps(event)
                    yield f"data: {event_data}\n\n"
            yield "data: {\"type\": \"done\"}\n\n"
        except Exception as e:
            error_event = {"type": "error", "data": {"message": str(e)}}
            yield f"data: {json.dumps(error_event)}\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        },
    )
if __name__ == "__main__":
    import uvicorn

    # Print quick-start hints, then run the development server.
    for hint in (
        "Starting CrewAI Streaming API...",
        "Visit http://localhost:8000/docs for API documentation",
        "Try: http://localhost:8000/stream?topic=quantum%20computing",
    ):
        print(hint)
    uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@@ -864,7 +864,6 @@ class Agent(BaseAgent):
i18n=self.i18n,
original_agent=self,
guardrail=self.guardrail,
guardrail_max_retries=self.guardrail_max_retries,
)
return await lite_agent.kickoff_async(messages)

View File

@@ -766,6 +766,118 @@ class Crew(FlowTrackable, BaseModel):
self._task_output_handler.reset()
return results
def kickoff_stream(self, inputs: dict[str, Any] | None = None):
    """
    Stream crew execution events in real-time.

    Runs ``kickoff`` on a daemon thread while event-bus handlers feed a
    queue that this generator drains, making it easy to integrate with
    streaming frameworks like FastAPI's StreamingResponse.

    Args:
        inputs: Optional dictionary of inputs for the crew execution

    Yields:
        dict: Event dictionaries containing event type and data; a final
        ``{"type": "final_output", ...}`` event carries the crew result.

    Raises:
        Exception: Re-raises any exception raised by ``kickoff`` after the
        queued events have been drained.

    Example:
        ```python
        from fastapi import FastAPI
        from fastapi.responses import StreamingResponse
        import json
        app = FastAPI()
        @app.get("/stream")
        async def stream_crew():
            def event_generator():
                for event in crew.kickoff_stream(inputs={"topic": "AI"}):
                    yield f"data: {json.dumps(event)}\\n\\n"
            return StreamingResponse(
                event_generator(),
                media_type="text/event-stream"
            )
        ```
    """
    import queue
    import threading

    from crewai.events.base_events import BaseEvent
    from crewai.events.types.agent_events import (
        AgentExecutionCompletedEvent,
        AgentExecutionStartedEvent,
    )
    from crewai.events.types.crew_events import (
        CrewKickoffCompletedEvent,
        CrewKickoffFailedEvent,
        CrewKickoffStartedEvent,
    )
    from crewai.events.types.llm_events import (
        LLMCallCompletedEvent,
        LLMCallStartedEvent,
        LLMStreamChunkEvent,
    )
    from crewai.events.types.task_events import (
        TaskCompletedEvent,
        TaskFailedEvent,
        TaskStartedEvent,
    )
    from crewai.events.types.tool_usage_events import (
        ToolUsageErrorEvent,
        ToolUsageFinishedEvent,
        ToolUsageStartedEvent,
    )

    event_queue: queue.Queue = queue.Queue()
    completion_event = threading.Event()
    exception_holder: dict[str, Exception | None] = {"exception": None}

    def event_handler(source: Any, event: BaseEvent):
        # Serialize each bus event into a JSON-friendly dict for consumers.
        event_queue.put(
            {
                "type": event.type,
                "data": event.model_dump(exclude={"from_task", "from_agent"}),
            }
        )

    # NOTE(review): handlers registered here are never unregistered, so
    # repeated kickoff_stream calls accumulate handlers on the global bus.
    # If the event bus exposes an unregister/scoped-handler API, prefer it
    # here — TODO confirm against crewai_event_bus.
    for streamed_event_type in (
        CrewKickoffStartedEvent,
        CrewKickoffCompletedEvent,
        CrewKickoffFailedEvent,
        TaskStartedEvent,
        TaskCompletedEvent,
        TaskFailedEvent,
        AgentExecutionStartedEvent,
        AgentExecutionCompletedEvent,
        LLMStreamChunkEvent,
        LLMCallStartedEvent,
        LLMCallCompletedEvent,
        ToolUsageStartedEvent,
        ToolUsageFinishedEvent,
        ToolUsageErrorEvent,
    ):
        crewai_event_bus.register_handler(streamed_event_type, event_handler)

    def run_kickoff():
        try:
            result = self.kickoff(inputs=inputs)
            event_queue.put({"type": "final_output", "data": {"output": result.raw}})
        except Exception as e:
            exception_holder["exception"] = e
        finally:
            completion_event.set()

    thread = threading.Thread(target=run_kickoff, daemon=True)
    thread.start()

    try:
        while True:
            try:
                # Block briefly on the queue instead of busy-polling
                # `empty()` — the previous implementation spun a tight
                # CPU loop whenever no event was ready.
                event = event_queue.get(timeout=0.1)
            except queue.Empty:
                if completion_event.is_set() and event_queue.empty():
                    break
                continue
            yield event
        # Re-raise only after draining: checking inside the loop meant a
        # kickoff failure that produced no events was silently swallowed
        # (the loop condition went false before the check ever ran).
        if exception_holder["exception"] is not None:
            raise exception_holder["exception"]
    finally:
        thread.join(timeout=1)
def _handle_crew_planning(self):
"""Handles the Crew planning."""
self._logger.log("info", "Planning the crew execution")

View File

@@ -31,7 +31,7 @@ from crewai.flow.flow_visualizer import plot_flow
from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.types import FlowExecutionData
from crewai.flow.utils import get_possible_return_constants
from crewai.utilities.printer import Printer
from crewai.utilities.printer import Printer, PrinterColor
logger = logging.getLogger(__name__)
@@ -105,7 +105,7 @@ def start(condition: str | dict | Callable | None = None) -> Callable:
condition : Optional[Union[str, dict, Callable]], optional
Defines when the start method should execute. Can be:
- str: Name of a method that triggers this start
- dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers)
- dict: Result from or_() or and_(), including nested conditions
- Callable: A method reference that triggers this start
Default is None, meaning unconditional start.
@@ -140,13 +140,18 @@ def start(condition: str | dict | Callable | None = None) -> Callable:
if isinstance(condition, str):
func.__trigger_methods__ = [condition]
func.__condition_type__ = "OR"
elif (
isinstance(condition, dict)
and "type" in condition
and "methods" in condition
):
func.__trigger_methods__ = condition["methods"]
func.__condition_type__ = condition["type"]
elif isinstance(condition, dict) and "type" in condition:
if "conditions" in condition:
func.__trigger_condition__ = condition
func.__trigger_methods__ = _extract_all_methods(condition)
func.__condition_type__ = condition["type"]
elif "methods" in condition:
func.__trigger_methods__ = condition["methods"]
func.__condition_type__ = condition["type"]
else:
raise ValueError(
"Condition dict must contain 'conditions' or 'methods'"
)
elif callable(condition) and hasattr(condition, "__name__"):
func.__trigger_methods__ = [condition.__name__]
func.__condition_type__ = "OR"
@@ -172,7 +177,7 @@ def listen(condition: str | dict | Callable) -> Callable:
condition : Union[str, dict, Callable]
Specifies when the listener should execute. Can be:
- str: Name of a method that triggers this listener
- dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers)
- dict: Result from or_() or and_(), including nested conditions
- Callable: A method reference that triggers this listener
Returns
@@ -200,13 +205,18 @@ def listen(condition: str | dict | Callable) -> Callable:
if isinstance(condition, str):
func.__trigger_methods__ = [condition]
func.__condition_type__ = "OR"
elif (
isinstance(condition, dict)
and "type" in condition
and "methods" in condition
):
func.__trigger_methods__ = condition["methods"]
func.__condition_type__ = condition["type"]
elif isinstance(condition, dict) and "type" in condition:
if "conditions" in condition:
func.__trigger_condition__ = condition
func.__trigger_methods__ = _extract_all_methods(condition)
func.__condition_type__ = condition["type"]
elif "methods" in condition:
func.__trigger_methods__ = condition["methods"]
func.__condition_type__ = condition["type"]
else:
raise ValueError(
"Condition dict must contain 'conditions' or 'methods'"
)
elif callable(condition) and hasattr(condition, "__name__"):
func.__trigger_methods__ = [condition.__name__]
func.__condition_type__ = "OR"
@@ -233,7 +243,7 @@ def router(condition: str | dict | Callable) -> Callable:
condition : Union[str, dict, Callable]
Specifies when the router should execute. Can be:
- str: Name of a method that triggers this router
- dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers)
- dict: Result from or_() or and_(), including nested conditions
- Callable: A method reference that triggers this router
Returns
@@ -266,13 +276,18 @@ def router(condition: str | dict | Callable) -> Callable:
if isinstance(condition, str):
func.__trigger_methods__ = [condition]
func.__condition_type__ = "OR"
elif (
isinstance(condition, dict)
and "type" in condition
and "methods" in condition
):
func.__trigger_methods__ = condition["methods"]
func.__condition_type__ = condition["type"]
elif isinstance(condition, dict) and "type" in condition:
if "conditions" in condition:
func.__trigger_condition__ = condition
func.__trigger_methods__ = _extract_all_methods(condition)
func.__condition_type__ = condition["type"]
elif "methods" in condition:
func.__trigger_methods__ = condition["methods"]
func.__condition_type__ = condition["type"]
else:
raise ValueError(
"Condition dict must contain 'conditions' or 'methods'"
)
elif callable(condition) and hasattr(condition, "__name__"):
func.__trigger_methods__ = [condition.__name__]
func.__condition_type__ = "OR"
@@ -298,14 +313,15 @@ def or_(*conditions: str | dict | Callable) -> dict:
*conditions : Union[str, dict, Callable]
Variable number of conditions that can be:
- str: Method names
- dict: Existing condition dictionaries
- dict: Existing condition dictionaries (nested conditions)
- Callable: Method references
Returns
-------
dict
A condition dictionary with format:
{"type": "OR", "methods": list_of_method_names}
{"type": "OR", "conditions": list_of_conditions}
where each condition can be a string (method name) or a nested dict
Raises
------
@@ -317,18 +333,22 @@ def or_(*conditions: str | dict | Callable) -> dict:
>>> @listen(or_("success", "timeout"))
>>> def handle_completion(self):
... pass
>>> @listen(or_(and_("step1", "step2"), "step3"))
>>> def handle_nested(self):
... pass
"""
methods = []
processed_conditions: list[str | dict[str, Any]] = []
for condition in conditions:
if isinstance(condition, dict) and "methods" in condition:
methods.extend(condition["methods"])
if isinstance(condition, dict):
processed_conditions.append(condition)
elif isinstance(condition, str):
methods.append(condition)
processed_conditions.append(condition)
elif callable(condition):
methods.append(getattr(condition, "__name__", repr(condition)))
processed_conditions.append(getattr(condition, "__name__", repr(condition)))
else:
raise ValueError("Invalid condition in or_()")
return {"type": "OR", "methods": methods}
return {"type": "OR", "conditions": processed_conditions}
def and_(*conditions: str | dict | Callable) -> dict:
@@ -344,14 +364,15 @@ def and_(*conditions: str | dict | Callable) -> dict:
*conditions : Union[str, dict, Callable]
Variable number of conditions that can be:
- str: Method names
- dict: Existing condition dictionaries
- dict: Existing condition dictionaries (nested conditions)
- Callable: Method references
Returns
-------
dict
A condition dictionary with format:
{"type": "AND", "methods": list_of_method_names}
{"type": "AND", "conditions": list_of_conditions}
where each condition can be a string (method name) or a nested dict
Raises
------
@@ -363,18 +384,69 @@ def and_(*conditions: str | dict | Callable) -> dict:
>>> @listen(and_("validated", "processed"))
>>> def handle_complete_data(self):
... pass
>>> @listen(and_(or_("step1", "step2"), "step3"))
>>> def handle_nested(self):
... pass
"""
methods = []
processed_conditions: list[str | dict[str, Any]] = []
for condition in conditions:
if isinstance(condition, dict) and "methods" in condition:
methods.extend(condition["methods"])
if isinstance(condition, dict):
processed_conditions.append(condition)
elif isinstance(condition, str):
methods.append(condition)
processed_conditions.append(condition)
elif callable(condition):
methods.append(getattr(condition, "__name__", repr(condition)))
processed_conditions.append(getattr(condition, "__name__", repr(condition)))
else:
raise ValueError("Invalid condition in and_()")
return {"type": "AND", "methods": methods}
return {"type": "AND", "conditions": processed_conditions}
def _normalize_condition(condition: str | dict | list) -> dict:
"""Normalize a condition to standard format with 'conditions' key.
Args:
condition: Can be a string (method name), dict (condition), or list
Returns:
Normalized dict with 'type' and 'conditions' keys
"""
if isinstance(condition, str):
return {"type": "OR", "conditions": [condition]}
if isinstance(condition, dict):
if "conditions" in condition:
return condition
if "methods" in condition:
return {"type": condition["type"], "conditions": condition["methods"]}
return condition
if isinstance(condition, list):
return {"type": "OR", "conditions": condition}
return {"type": "OR", "conditions": [condition]}
def _extract_all_methods(condition: str | dict | list) -> list[str]:
"""Extract all method names from a condition (including nested).
Args:
condition: Can be a string, dict, or list
Returns:
List of all method names in the condition tree
"""
if isinstance(condition, str):
return [condition]
if isinstance(condition, dict):
normalized = _normalize_condition(condition)
methods = []
for sub_cond in normalized.get("conditions", []):
methods.extend(_extract_all_methods(sub_cond))
return methods
if isinstance(condition, list):
methods = []
for item in condition:
methods.extend(_extract_all_methods(item))
return methods
return []
class FlowMeta(type):
@@ -402,7 +474,10 @@ class FlowMeta(type):
if hasattr(attr_value, "__trigger_methods__"):
methods = attr_value.__trigger_methods__
condition_type = getattr(attr_value, "__condition_type__", "OR")
listeners[attr_name] = (condition_type, methods)
if hasattr(attr_value, "__trigger_condition__"):
listeners[attr_name] = attr_value.__trigger_condition__
else:
listeners[attr_name] = (condition_type, methods)
if (
hasattr(attr_value, "__is_router__")
@@ -822,6 +897,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
# Clear completed methods and outputs for a fresh start
self._completed_methods.clear()
self._method_outputs.clear()
self._pending_and_listeners.clear()
else:
# We're restoring from persistence, set the flag
self._is_execution_resuming = True
@@ -1086,10 +1162,16 @@ class Flow(Generic[T], metaclass=FlowMeta):
for method_name in self._start_methods:
# Check if this start method is triggered by the current trigger
if method_name in self._listeners:
condition_type, trigger_methods = self._listeners[
method_name
]
if current_trigger in trigger_methods:
condition_data = self._listeners[method_name]
should_trigger = False
if isinstance(condition_data, tuple):
_, trigger_methods = condition_data
should_trigger = current_trigger in trigger_methods
elif isinstance(condition_data, dict):
all_methods = _extract_all_methods(condition_data)
should_trigger = current_trigger in all_methods
if should_trigger:
# Only execute if this is a cycle (method was already completed)
if method_name in self._completed_methods:
# For router-triggered start methods in cycles, temporarily clear resumption flag
@@ -1099,6 +1181,51 @@ class Flow(Generic[T], metaclass=FlowMeta):
await self._execute_start_method(method_name)
self._is_execution_resuming = was_resuming
def _evaluate_condition(
    self, condition: str | dict, trigger_method: str, listener_name: str
) -> bool:
    """Recursively evaluate a condition (simple or nested).

    Args:
        condition: Can be a string (method name) or dict (nested condition)
        trigger_method: The method that just completed
        listener_name: Name of the listener being evaluated

    Returns:
        True if the condition is satisfied, False otherwise
    """
    if isinstance(condition, str):
        # Leaf condition: satisfied only by the exact method that just ran.
        return condition == trigger_method
    if isinstance(condition, dict):
        normalized = _normalize_condition(condition)
        cond_type = normalized.get("type", "OR")
        sub_conditions = normalized.get("conditions", [])
        if cond_type == "OR":
            # OR fires as soon as any sub-condition is satisfied.
            return any(
                self._evaluate_condition(sub_cond, trigger_method, listener_name)
                for sub_cond in sub_conditions
            )
        if cond_type == "AND":
            # Per-(listener, condition) set of methods still outstanding.
            # NOTE(review): the key embeds id(condition); CPython can reuse
            # ids after garbage collection — presumably condition dicts are
            # decorator-attached and long-lived, so this is safe. Confirm.
            pending_key = f"{listener_name}:{id(condition)}"
            if pending_key not in self._pending_and_listeners:
                # NOTE(review): AND flattens nested sub-conditions via
                # _extract_all_methods, so and_(or_(A, B), C) waits for
                # A, B *and* C rather than (A or B) and C — confirm this
                # is the intended semantics for OR nested inside AND.
                all_methods = set(_extract_all_methods(condition))
                self._pending_and_listeners[pending_key] = all_methods
            if trigger_method in self._pending_and_listeners[pending_key]:
                self._pending_and_listeners[pending_key].discard(trigger_method)
            if not self._pending_and_listeners[pending_key]:
                # Every required method has completed; reset for the next cycle.
                self._pending_and_listeners.pop(pending_key, None)
                return True
            return False
    return False
def _find_triggered_methods(
self, trigger_method: str, router_only: bool
) -> list[str]:
@@ -1106,7 +1233,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
Finds all methods that should be triggered based on conditions.
This internal method evaluates both OR and AND conditions to determine
which methods should be executed next in the flow.
which methods should be executed next in the flow. Supports nested conditions.
Parameters
----------
@@ -1123,14 +1250,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
Notes
-----
- Handles both OR and AND conditions:
* OR: Triggers if any condition is met
* AND: Triggers only when all conditions are met
- Handles both OR and AND conditions, including nested combinations
- Maintains state for AND conditions using _pending_and_listeners
- Separates router and normal listener evaluation
"""
triggered = []
for listener_name, (condition_type, methods) in self._listeners.items():
for listener_name, condition_data in self._listeners.items():
is_router = listener_name in self._routers
if router_only != is_router:
@@ -1139,23 +1265,29 @@ class Flow(Generic[T], metaclass=FlowMeta):
if not router_only and listener_name in self._start_methods:
continue
if condition_type == "OR":
# If the trigger_method matches any in methods, run this
if trigger_method in methods:
triggered.append(listener_name)
elif condition_type == "AND":
# Initialize pending methods for this listener if not already done
if listener_name not in self._pending_and_listeners:
self._pending_and_listeners[listener_name] = set(methods)
# Remove the trigger method from pending methods
if trigger_method in self._pending_and_listeners[listener_name]:
self._pending_and_listeners[listener_name].discard(trigger_method)
if isinstance(condition_data, tuple):
condition_type, methods = condition_data
if not self._pending_and_listeners[listener_name]:
# All required methods have been executed
if condition_type == "OR":
if trigger_method in methods:
triggered.append(listener_name)
elif condition_type == "AND":
if listener_name not in self._pending_and_listeners:
self._pending_and_listeners[listener_name] = set(methods)
if trigger_method in self._pending_and_listeners[listener_name]:
self._pending_and_listeners[listener_name].discard(
trigger_method
)
if not self._pending_and_listeners[listener_name]:
triggered.append(listener_name)
self._pending_and_listeners.pop(listener_name, None)
elif isinstance(condition_data, dict):
if self._evaluate_condition(
condition_data, trigger_method, listener_name
):
triggered.append(listener_name)
# Reset pending methods for this listener
self._pending_and_listeners.pop(listener_name, None)
return triggered
@@ -1218,7 +1350,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
raise
def _log_flow_event(
self, message: str, color: str = "yellow", level: str = "info"
self, message: str, color: PrinterColor | None = "yellow", level: str = "info"
) -> None:
"""Centralized logging method for flow events.

View File

@@ -5,7 +5,7 @@ import logging
import threading
import uuid
import warnings
from collections.abc import Callable, Sequence
from collections.abc import Callable
from concurrent.futures import Future
from copy import copy as shallow_copy
from hashlib import md5
@@ -152,15 +152,6 @@ class Task(BaseModel):
default=None,
description="Function or string description of a guardrail to validate task output before proceeding to next task",
)
guardrails: (
Sequence[Callable[[TaskOutput], tuple[bool, Any]] | str]
| Callable[[TaskOutput], tuple[bool, Any]]
| str
| None
) = Field(
default=None,
description="List of guardrails to validate task output before proceeding to next task. Also supports a single guardrail function or string description of a guardrail to validate task output before proceeding to next task",
)
max_retries: int | None = Field(
default=None,
description="[DEPRECATED] Maximum number of retries when guardrail fails. Use guardrail_max_retries instead. Will be removed in v1.0.0",
@@ -277,44 +268,6 @@ class Task(BaseModel):
return self
@model_validator(mode="after")
def ensure_guardrails_is_list_of_callables(self) -> "Task":
guardrails = []
if self.guardrails is not None and (
not isinstance(self.guardrails, (list, tuple)) or len(self.guardrails) > 0
):
if self.agent is None:
raise ValueError("Agent is required to use guardrails")
if callable(self.guardrails):
guardrails.append(self.guardrails)
elif isinstance(self.guardrails, str):
from crewai.tasks.llm_guardrail import LLMGuardrail
guardrails.append(
LLMGuardrail(description=self.guardrails, llm=self.agent.llm)
)
if isinstance(self.guardrails, list):
for guardrail in self.guardrails:
if callable(guardrail):
guardrails.append(guardrail)
elif isinstance(guardrail, str):
from crewai.tasks.llm_guardrail import LLMGuardrail
guardrails.append(
LLMGuardrail(description=guardrail, llm=self.agent.llm)
)
else:
raise ValueError("Guardrail must be a callable or a string")
self._guardrails = guardrails
if self._guardrails:
self.guardrail = None
self._guardrail = None
return self
@field_validator("id", mode="before")
@classmethod
def _deny_user_set_id(cls, v: UUID4 | None) -> None:
@@ -503,23 +456,48 @@ class Task(BaseModel):
output_format=self._get_output_format(),
)
if self._guardrails:
for guardrail in self._guardrails:
task_output = self._invoke_guardrail_function(
task_output=task_output,
agent=agent,
tools=tools,
guardrail=guardrail,
if self._guardrail:
guardrail_result = process_guardrail(
output=task_output,
guardrail=self._guardrail,
retry_count=self.retry_count,
event_source=self,
from_task=self,
from_agent=agent,
)
if not guardrail_result.success:
if self.retry_count >= self.guardrail_max_retries:
raise Exception(
f"Task failed guardrail validation after {self.guardrail_max_retries} retries. "
f"Last error: {guardrail_result.error}"
)
self.retry_count += 1
context = self.i18n.errors("validation_error").format(
guardrail_result_error=guardrail_result.error,
task_output=task_output.raw,
)
printer = Printer()
printer.print(
content=f"Guardrail blocked, retrying, due to: {guardrail_result.error}\n",
color="yellow",
)
return self._execute_core(agent, context, tools)
if guardrail_result.result is None:
raise Exception(
"Task guardrail returned None as result. This is not allowed."
)
# backwards support
if self._guardrail:
task_output = self._invoke_guardrail_function(
task_output=task_output,
agent=agent,
tools=tools,
guardrail=self._guardrail,
)
if isinstance(guardrail_result.result, str):
task_output.raw = guardrail_result.result
pydantic_output, json_output = self._export_output(
guardrail_result.result
)
task_output.pydantic = pydantic_output
task_output.json_dict = json_output
elif isinstance(guardrail_result.result, TaskOutput):
task_output = guardrail_result.result
self.output = task_output
self.end_time = datetime.datetime.now()
@@ -811,55 +789,3 @@ Follow these guidelines:
Fingerprint: The fingerprint of the task
"""
return self.security_config.fingerprint
def _invoke_guardrail_function(
self,
task_output: TaskOutput,
agent: BaseAgent,
tools: list[BaseTool],
guardrail: Callable | None,
) -> TaskOutput:
if guardrail:
guardrail_result = process_guardrail(
output=task_output,
guardrail=guardrail,
retry_count=self.retry_count,
event_source=self,
from_task=self,
from_agent=agent,
)
if not guardrail_result.success:
if self.retry_count >= self.guardrail_max_retries:
raise Exception(
f"Task failed guardrail validation after {self.guardrail_max_retries} retries. "
f"Last error: {guardrail_result.error}"
)
self.retry_count += 1
context = self.i18n.errors("validation_error").format(
guardrail_result_error=guardrail_result.error,
task_output=task_output.raw,
)
printer = Printer()
printer.print(
content=f"Guardrail blocked, retrying, due to: {guardrail_result.error}\n",
color="yellow",
)
return self._execute_core(agent, context, tools)
if guardrail_result.result is None:
raise Exception(
"Task guardrail returned None as result. This is not allowed."
)
if isinstance(guardrail_result.result, str):
task_output.raw = guardrail_result.result
pydantic_output, json_output = self._export_output(
guardrail_result.result
)
task_output.pydantic = pydantic_output
task_output.json_dict = json_output
elif isinstance(guardrail_result.result, TaskOutput):
task_output = guardrail_result.result
return task_output

View File

@@ -1,6 +1,11 @@
"""Utility for colored console output."""
from typing import Final, Literal, NamedTuple
from __future__ import annotations
from typing import TYPE_CHECKING, Final, Literal, NamedTuple
if TYPE_CHECKING:
from _typeshed import SupportsWrite
PrinterColor = Literal[
"purple",
@@ -54,13 +59,22 @@ class Printer:
@staticmethod
def print(
content: str | list[ColoredText], color: PrinterColor | None = None
content: str | list[ColoredText],
color: PrinterColor | None = None,
sep: str | None = " ",
end: str | None = "\n",
file: SupportsWrite[str] | None = None,
flush: Literal[False] = False,
) -> None:
"""Prints content to the console with optional color formatting.
Args:
content: Either a string or a list of ColoredText objects for multicolor output.
color: Optional color for the text when content is a string. Ignored when content is a list.
sep: Separator to use between the text and color.
end: String appended after the last value.
file: A file-like object (stream); defaults to the current sys.stdout.
flush: Whether to forcibly flush the stream.
"""
if isinstance(content, str):
content = [ColoredText(content, color)]
@@ -68,5 +82,9 @@ class Printer:
"".join(
f"{_COLOR_CODES[c.color] if c.color else ''}{c.text}{RESET}"
for c in content
)
),
sep=sep,
end=end,
file=file,
flush=flush,
)

View File

@@ -4744,3 +4744,81 @@ def test_ensure_exchanged_messages_are_propagated_to_external_memory():
assert "Researcher" in messages[0]["content"]
assert messages[1]["role"] == "user"
assert "Research a topic to teach a kid aged 6 about math" in messages[1]["content"]
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_kickoff_stream(researcher):
    """Streaming kickoff must yield lifecycle events and a final output."""
    task = Task(
        description="Research a topic about AI",
        expected_output="A brief summary about AI",
        agent=researcher,
    )
    crew = Crew(agents=[researcher], tasks=[task])

    # Drain the generator; at least one event must be produced.
    collected = [event for event in crew.kickoff_stream()]
    assert collected

    seen_types = {event["type"] for event in collected}
    assert "crew_kickoff_started" in seen_types
    assert "final_output" in seen_types

    # The final_output event must carry a non-empty string result.
    final_event = [e for e in collected if e["type"] == "final_output"][0]
    assert "output" in final_event["data"]
    output = final_event["data"]["output"]
    assert isinstance(output, str)
    assert output
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_kickoff_stream_with_inputs(researcher):
    """Streaming kickoff must accept interpolation inputs like kickoff()."""
    task = Task(
        description="Research about {topic}",
        expected_output="A brief summary about {topic}",
        agent=researcher,
    )
    crew = Crew(agents=[researcher], tasks=[task])

    # Inputs are interpolated into the task templates before streaming starts.
    streamed = [e for e in crew.kickoff_stream(inputs={"topic": "machine learning"})]
    assert streamed

    kinds = {e["type"] for e in streamed}
    assert "crew_kickoff_started" in kinds
    assert "final_output" in kinds
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_kickoff_stream_includes_llm_chunks(researcher):
    """Test that crew.kickoff_stream() includes LLM stream chunks."""
    # Single-task crew; the VCR cassette supplies the recorded LLM responses.
    task = Task(
        description="Write a short poem about AI",
        expected_output="A 2-line poem",
        agent=researcher,
    )
    crew = Crew(agents=[researcher], tasks=[task])
    # Drain the stream and collect every yielded event's type tag.
    events = list(crew.kickoff_stream())
    event_types = [event["type"] for event in events]
    # NOTE(review): despite the test name and docstring, no LLM *chunk* event
    # type is asserted here -- only task/agent start events. Consider asserting
    # the actual chunk event type emitted by kickoff_stream (confirm its name).
    assert "task_started" in event_types or "agent_execution_started" in event_types
def test_crew_kickoff_stream_handles_errors(researcher):
    """Errors raised during kickoff must propagate out of the stream."""
    task = Task(
        description="This task will fail",
        expected_output="Should not complete",
        agent=researcher,
    )
    crew = Crew(agents=[researcher], tasks=[task])

    # Force the underlying kickoff to blow up; consuming the generator
    # must surface the same exception to the caller.
    with patch("crewai.crew.Crew.kickoff", side_effect=Exception("Test error")):
        with pytest.raises(Exception, match="Test error"):
            _ = list(crew.kickoff_stream())

View File

@@ -1,8 +1,7 @@
import asyncio
import threading
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from typing import Dict, Any, Callable
from unittest.mock import patch
import pytest
@@ -25,12 +24,9 @@ def simple_agent_factory():
@pytest.fixture
def simple_task_factory():
def create_task(name: str, agent: Agent, callback: Callable | None = None) -> Task:
def create_task(name: str, callback: Callable = None) -> Task:
return Task(
description=f"Task for {name}",
expected_output="Done",
agent=agent,
callback=callback,
description=f"Task for {name}", expected_output="Done", callback=callback
)
return create_task
@@ -38,9 +34,10 @@ def simple_task_factory():
@pytest.fixture
def crew_factory(simple_agent_factory, simple_task_factory):
def create_crew(name: str, task_callback: Callable | None = None) -> Crew:
def create_crew(name: str, task_callback: Callable = None) -> Crew:
agent = simple_agent_factory(name)
task = simple_task_factory(name, agent=agent, callback=task_callback)
task = simple_task_factory(name, callback=task_callback)
task.agent = agent
return Crew(agents=[agent], tasks=[task], verbose=False)
@@ -53,7 +50,7 @@ class TestCrewThreadSafety:
mock_execute_task.return_value = "Task completed"
num_crews = 5
def run_crew_with_context_check(crew_id: str) -> dict[str, Any]:
def run_crew_with_context_check(crew_id: str) -> Dict[str, Any]:
results = {"crew_id": crew_id, "contexts": []}
def check_context_task(output):
@@ -108,28 +105,28 @@ class TestCrewThreadSafety:
before_ctx = next(
ctx for ctx in result["contexts"] if ctx["stage"] == "before_kickoff"
)
assert before_ctx["crew_id"] is None, (
f"Context should be None before kickoff for {result['crew_id']}"
)
assert (
before_ctx["crew_id"] is None
), f"Context should be None before kickoff for {result['crew_id']}"
task_ctx = next(
ctx for ctx in result["contexts"] if ctx["stage"] == "task_callback"
)
assert task_ctx["crew_id"] == crew_uuid, (
f"Context mismatch during task for {result['crew_id']}"
)
assert (
task_ctx["crew_id"] == crew_uuid
), f"Context mismatch during task for {result['crew_id']}"
after_ctx = next(
ctx for ctx in result["contexts"] if ctx["stage"] == "after_kickoff"
)
assert after_ctx["crew_id"] is None, (
f"Context should be None after kickoff for {result['crew_id']}"
)
assert (
after_ctx["crew_id"] is None
), f"Context should be None after kickoff for {result['crew_id']}"
thread_name = before_ctx["thread"]
assert "ThreadPoolExecutor" in thread_name, (
f"Should run in thread pool for {result['crew_id']}"
)
assert (
"ThreadPoolExecutor" in thread_name
), f"Should run in thread pool for {result['crew_id']}"
@pytest.mark.asyncio
@patch("crewai.Agent.execute_task")
@@ -137,7 +134,7 @@ class TestCrewThreadSafety:
mock_execute_task.return_value = "Task completed"
num_crews = 5
async def run_crew_async(crew_id: str) -> dict[str, Any]:
async def run_crew_async(crew_id: str) -> Dict[str, Any]:
task_context = {"crew_id": crew_id, "context": None}
def capture_context(output):
@@ -165,12 +162,12 @@ class TestCrewThreadSafety:
crew_uuid = result["crew_uuid"]
task_ctx = result["task_context"]["context"]
assert task_ctx is not None, (
f"Context should exist during task for {result['crew_id']}"
)
assert task_ctx["crew_id"] == crew_uuid, (
f"Context mismatch for {result['crew_id']}"
)
assert (
task_ctx is not None
), f"Context should exist during task for {result['crew_id']}"
assert (
task_ctx["crew_id"] == crew_uuid
), f"Context mismatch for {result['crew_id']}"
@patch("crewai.Agent.execute_task")
def test_concurrent_kickoff_for_each(self, mock_execute_task, crew_factory):
@@ -196,9 +193,9 @@ class TestCrewThreadSafety:
assert len(contexts_captured) == len(inputs)
context_ids = [ctx["context_id"] for ctx in contexts_captured]
assert len(set(context_ids)) == len(inputs), (
"Each execution should have unique context"
)
assert len(set(context_ids)) == len(
inputs
), "Each execution should have unique context"
@patch("crewai.Agent.execute_task")
def test_no_context_leakage_between_crews(self, mock_execute_task, crew_factory):

View File

@@ -6,15 +6,15 @@ from datetime import datetime
import pytest
from pydantic import BaseModel
from crewai.flow.flow import Flow, and_, listen, or_, router, start
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.flow_events import (
FlowFinishedEvent,
FlowStartedEvent,
FlowPlotEvent,
FlowStartedEvent,
MethodExecutionFinishedEvent,
MethodExecutionStartedEvent,
)
from crewai.flow.flow import Flow, and_, listen, or_, router, start
def test_simple_sequential_flow():
@@ -679,11 +679,11 @@ def test_structured_flow_event_emission():
assert isinstance(received_events[3], MethodExecutionStartedEvent)
assert received_events[3].method_name == "send_welcome_message"
assert received_events[3].params == {}
assert getattr(received_events[3].state, "sent") is False
assert received_events[3].state.sent is False
assert isinstance(received_events[4], MethodExecutionFinishedEvent)
assert received_events[4].method_name == "send_welcome_message"
assert getattr(received_events[4].state, "sent") is True
assert received_events[4].state.sent is True
assert received_events[4].result == "Welcome, Anakin!"
assert isinstance(received_events[5], FlowFinishedEvent)
@@ -894,3 +894,75 @@ def test_flow_name():
flow = MyFlow()
assert flow.name == "MyFlow"
def test_nested_and_or_conditions():
    """Test nested conditions like or_(and_(A, B), and_(C, D)).
    Reproduces bug from issue #3719 where nested conditions are flattened,
    causing premature execution.
    """
    execution_order = []

    class NestedConditionFlow(Flow):
        @start()
        def method_1(self):
            execution_order.append("method_1")

        @listen(method_1)
        def method_2(self):
            execution_order.append("method_2")

        @router(method_2)
        def method_3(self):
            execution_order.append("method_3")
            # Choose b_condition path
            return "b_condition"

        @listen("b_condition")
        def method_5(self):
            execution_order.append("method_5")

        @listen(method_5)
        async def method_4(self):
            execution_order.append("method_4")

        # Fires on either router label; only "b_condition" is emitted above.
        @listen(or_("a_condition", "b_condition"))
        async def method_6(self):
            execution_order.append("method_6")

        # Nested condition under test: must NOT be flattened into a single
        # or_-list, otherwise method_7 fires as soon as any one leaf fires.
        @listen(
            or_(
                and_("a_condition", method_6),
                and_(method_6, method_4),
            )
        )
        def method_7(self):
            execution_order.append("method_7")

        @listen(method_7)
        async def method_8(self):
            execution_order.append("method_8")

    flow = NestedConditionFlow()
    flow.kickoff()
    # Verify execution happened
    assert "method_1" in execution_order
    assert "method_2" in execution_order
    assert "method_3" in execution_order
    assert "method_5" in execution_order
    assert "method_4" in execution_order
    assert "method_6" in execution_order
    assert "method_7" in execution_order
    assert "method_8" in execution_order
    # Critical assertion: method_7 should only execute AFTER both method_6 AND method_4
    # Since b_condition was returned, method_6 triggers on b_condition
    # method_7 requires: (a_condition AND method_6) OR (method_6 AND method_4)
    # The second condition (method_6 AND method_4) should be the one that triggers
    assert execution_order.index("method_7") > execution_order.index("method_6")
    assert execution_order.index("method_7") > execution_order.index("method_4")
    # method_8 should execute after method_7
    assert execution_order.index("method_8") > execution_order.index("method_7")

View File

@@ -14,24 +14,6 @@ from crewai.tasks.llm_guardrail import LLMGuardrail
from crewai.tasks.task_output import TaskOutput
def create_smart_task(**kwargs):
    """Build a Task, injecting a stub agent when guardrails require one.

    Guardrail-bearing tasks need an agent; older call sites may omit it,
    so this factory supplies a default test agent to keep them working.
    """
    has_single = kwargs.get("guardrail") is not None
    has_many = bool(kwargs.get("guardrails"))
    if (has_single or has_many) and kwargs.get("agent") is None:
        kwargs["agent"] = Agent(
            role="test_agent", goal="test_goal", backstory="test_backstory"
        )
    return Task(**kwargs)
def test_task_without_guardrail():
"""Test that tasks work normally without guardrails (backward compatibility)."""
agent = Mock()
@@ -39,7 +21,7 @@ def test_task_without_guardrail():
agent.execute_task.return_value = "test result"
agent.crew = None
task = create_smart_task(description="Test task", expected_output="Output")
task = Task(description="Test task", expected_output="Output")
result = task.execute_sync(agent=agent)
assert isinstance(result, TaskOutput)
@@ -57,9 +39,7 @@ def test_task_with_successful_guardrail_func():
agent.execute_task.return_value = "test result"
agent.crew = None
task = create_smart_task(
description="Test task", expected_output="Output", guardrail=guardrail
)
task = Task(description="Test task", expected_output="Output", guardrail=guardrail)
result = task.execute_sync(agent=agent)
assert isinstance(result, TaskOutput)
@@ -77,7 +57,7 @@ def test_task_with_failing_guardrail():
agent.execute_task.side_effect = ["bad result", "good result"]
agent.crew = None
task = create_smart_task(
task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail,
@@ -104,7 +84,7 @@ def test_task_with_guardrail_retries():
agent.execute_task.return_value = "bad result"
agent.crew = None
task = create_smart_task(
task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail,
@@ -129,7 +109,7 @@ def test_guardrail_error_in_context():
agent.role = "test_agent"
agent.crew = None
task = create_smart_task(
task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail,
@@ -197,7 +177,7 @@ def test_guardrail_emits_events(sample_agent):
started_guardrail = []
completed_guardrail = []
task = create_smart_task(
task = Task(
description="Gather information about available books on the First World War",
agent=sample_agent,
expected_output="A list of available books on the First World War",
@@ -230,7 +210,7 @@ def test_guardrail_emits_events(sample_agent):
def custom_guardrail(result: TaskOutput):
return (True, "good result from callable function")
task = create_smart_task(
task = Task(
description="Test task",
expected_output="Output",
guardrail=custom_guardrail,
@@ -282,7 +262,7 @@ def test_guardrail_when_an_error_occurs(sample_agent, task_output):
match="Error while validating the task output: Unexpected error",
),
):
task = create_smart_task(
task = Task(
description="Gather information about available books on the First World War",
agent=sample_agent,
expected_output="A list of available books on the First World War",
@@ -304,7 +284,7 @@ def test_hallucination_guardrail_integration():
context="Test reference context for validation", llm=mock_llm, threshold=8.0
)
task = create_smart_task(
task = Task(
description="Test task with hallucination guardrail",
expected_output="Valid output",
guardrail=guardrail,
@@ -324,352 +304,3 @@ def test_hallucination_guardrail_description_in_events():
event = LLMGuardrailStartedEvent(guardrail=guardrail, retry_count=0)
assert event.guardrail == "HallucinationGuardrail (no-op)"
def test_multiple_guardrails_sequential_processing():
    """A guardrail chain runs left-to-right, each stage transforming the
    output of the previous one."""

    def add_prefix(result: TaskOutput) -> tuple[bool, str]:
        """Stage 1: prepend a marker."""
        return (True, f"[FIRST] {result.raw}")

    def add_suffix(result: TaskOutput) -> tuple[bool, str]:
        """Stage 2: append a marker."""
        return (True, f"{result.raw} [SECOND]")

    def upper_case(result: TaskOutput) -> tuple[bool, str]:
        """Stage 3: uppercase the accumulated text."""
        return (True, result.raw.upper())

    agent = Mock()
    agent.role = "sequential_agent"
    agent.execute_task.return_value = "original text"
    agent.crew = None

    task = create_smart_task(
        description="Test sequential guardrails",
        expected_output="Processed text",
        guardrails=[add_prefix, add_suffix, upper_case],
    )
    outcome = task.execute_sync(agent=agent)

    # prefix -> suffix -> uppercase, applied in declaration order.
    assert outcome.raw == "[FIRST] ORIGINAL TEXT [SECOND]"
def test_multiple_guardrails_with_validation_failure():
    """Test multiple guardrails where one fails validation."""

    def length_guardrail(result: TaskOutput) -> tuple[bool, str]:
        """Ensure minimum length."""
        # Rejects outputs shorter than 10 characters, forcing a task retry.
        if len(result.raw) < 10:
            return (False, "Text too short")
        return (True, result.raw)

    def format_guardrail(result: TaskOutput) -> tuple[bool, str]:
        """Add formatting only if not already formatted."""
        if not result.raw.startswith("Formatted:"):
            return (True, f"Formatted: {result.raw}")
        return (True, result.raw)

    def validation_guardrail(result: TaskOutput) -> tuple[bool, str]:
        """Final validation."""
        if "Formatted:" not in result.raw:
            return (False, "Missing formatting")
        return (True, result.raw)

    # Use a callable that tracks calls and returns appropriate values
    call_count = 0

    def mock_execute_task(*args, **kwargs):
        # First call yields a too-short output (fails length_guardrail);
        # the retry yields a long-enough one that passes the whole chain.
        nonlocal call_count
        call_count += 1
        result = (
            "short"
            if call_count == 1
            else "this is a longer text that meets requirements"
        )
        return result

    agent = Mock()
    agent.role = "validation_agent"
    agent.execute_task = mock_execute_task
    agent.crew = None

    task = create_smart_task(
        description="Test guardrails with validation",
        expected_output="Valid formatted text",
        guardrails=[length_guardrail, format_guardrail, validation_guardrail],
        guardrail_max_retries=2,
    )
    result = task.execute_sync(agent=agent)

    # The second call should be processed through all guardrails
    assert result.raw == "Formatted: this is a longer text that meets requirements"
    assert task.retry_count == 1
def test_multiple_guardrails_with_mixed_string_and_taskoutput():
    """Test guardrails that return both strings and TaskOutput objects."""

    def string_guardrail(result: TaskOutput) -> tuple[bool, str]:
        """Returns a string."""
        return (True, f"String: {result.raw}")

    def taskoutput_guardrail(result: TaskOutput) -> tuple[bool, TaskOutput]:
        """Returns a TaskOutput object."""
        # Rebuilds the output object, copying metadata while replacing raw.
        new_output = TaskOutput(
            name=result.name,
            description=result.description,
            expected_output=result.expected_output,
            raw=f"TaskOutput: {result.raw}",
            agent=result.agent,
            output_format=result.output_format,
        )
        return (True, new_output)

    def final_string_guardrail(result: TaskOutput) -> tuple[bool, str]:
        """Final string transformation."""
        return (True, f"Final: {result.raw}")

    agent = Mock()
    agent.role = "mixed_agent"
    agent.execute_task.return_value = "original"
    agent.crew = None

    task = create_smart_task(
        description="Test mixed return types",
        expected_output="Mixed processing",
        guardrails=[string_guardrail, taskoutput_guardrail, final_string_guardrail],
    )
    result = task.execute_sync(agent=agent)

    # Each stage wraps the previous stage's raw text, regardless of whether
    # the stage returned a plain string or a full TaskOutput.
    assert result.raw == "Final: TaskOutput: String: original"
def test_multiple_guardrails_with_retry_on_middle_guardrail():
    """Test that retry works correctly when a middle guardrail fails."""
    # Per-guardrail invocation counters, shared with the closures below.
    call_count = {"first": 0, "second": 0, "third": 0}

    def first_guardrail(result: TaskOutput) -> tuple[bool, str]:
        """Always succeeds."""
        call_count["first"] += 1
        return (True, f"First({call_count['first']}): {result.raw}")

    def second_guardrail(result: TaskOutput) -> tuple[bool, str]:
        """Fails on first attempt, succeeds on second."""
        call_count["second"] += 1
        if call_count["second"] == 1:
            return (False, "Second guardrail failed on first attempt")
        return (True, f"Second({call_count['second']}): {result.raw}")

    def third_guardrail(result: TaskOutput) -> tuple[bool, str]:
        """Always succeeds."""
        call_count["third"] += 1
        return (True, f"Third({call_count['third']}): {result.raw}")

    agent = Mock()
    agent.role = "retry_agent"
    agent.execute_task.return_value = "base"
    agent.crew = None

    task = create_smart_task(
        description="Test retry in middle guardrail",
        expected_output="Retry handling",
        guardrails=[first_guardrail, second_guardrail, third_guardrail],
        guardrail_max_retries=2,
    )
    result = task.execute_sync(agent=agent)

    # Based on the test output, the behavior is different than expected
    # The guardrails are called multiple times, so let's verify the retry happened
    # NOTE(review): the disjunctive assertion below deliberately tolerates two
    # possible retry semantics (restart whole chain vs resume) -- confirm which
    # one the Task implementation guarantees and tighten if possible.
    assert task.retry_count == 1
    # Verify that the second guardrail eventually succeeded
    assert "Second(2)" in result.raw or call_count["second"] >= 2
def test_multiple_guardrails_with_max_retries_exceeded():
    """A chain containing an always-failing guardrail must raise once the
    retry budget is exhausted, and the error text must surface the cause."""

    def always_pass(result: TaskOutput) -> tuple[bool, str]:
        """Never blocks the output."""
        return (True, f"Passed: {result.raw}")

    def always_fail(result: TaskOutput) -> tuple[bool, str]:
        """Blocks the output unconditionally."""
        return (False, "This guardrail always fails")

    agent = Mock()
    agent.role = "failing_agent"
    agent.execute_task.return_value = "test"
    agent.crew = None

    task = create_smart_task(
        description="Test max retries with multiple guardrails",
        expected_output="Will fail",
        guardrails=[always_pass, always_fail],
        guardrail_max_retries=1,
    )

    with pytest.raises(Exception) as exc_info:
        task.execute_sync(agent=agent)

    # The raised message names both the retry budget and the failing guardrail.
    message = str(exc_info.value)
    assert "Task failed guardrail validation after 1 retries" in message
    assert "This guardrail always fails" in message
    assert task.retry_count == 1
def test_multiple_guardrails_empty_list():
    """An empty guardrails list must leave the agent output untouched."""
    agent = Mock()
    agent.role = "empty_agent"
    agent.execute_task.return_value = "no guardrails"
    agent.crew = None

    task = create_smart_task(
        description="Test empty guardrails list",
        expected_output="No processing",
        guardrails=[],
    )
    outcome = task.execute_sync(agent=agent)

    # Nothing in the (empty) chain ran, so raw passes through verbatim.
    assert outcome.raw == "no guardrails"
def test_multiple_guardrails_with_llm_guardrails():
    """Test mixing callable and LLM guardrails."""

    def callable_guardrail(result: TaskOutput) -> tuple[bool, str]:
        """Callable guardrail."""
        return (True, f"Callable: {result.raw}")

    # Create a proper mock agent without config issues
    from crewai import Agent

    agent = Agent(
        role="mixed_guardrail_agent", goal="Test goal", backstory="Test backstory"
    )

    # A bare string in the guardrails list is treated as an LLM instruction.
    task = create_smart_task(
        description="Test mixed guardrail types",
        expected_output="Mixed processing",
        guardrails=[callable_guardrail, "Ensure the output is professional"],
        agent=agent,
    )

    # The LLM guardrail will be converted to LLMGuardrail internally
    # NOTE(review): asserts against the private _guardrails attribute; this
    # couples the test to Task internals -- confirm it is the intended hook.
    assert len(task._guardrails) == 2
    assert callable(task._guardrails[0])
    assert callable(task._guardrails[1])  # LLMGuardrail is callable
def test_multiple_guardrails_processing_order():
    """Guardrails must run in declaration order, each wrapping the last."""
    seen = []

    def stage_one(result: TaskOutput) -> tuple[bool, str]:
        """Records itself, then tags the text with '1-'."""
        seen.append("first")
        return (True, f"1-{result.raw}")

    def stage_two(result: TaskOutput) -> tuple[bool, str]:
        """Records itself, then tags the text with '2-'."""
        seen.append("second")
        return (True, f"2-{result.raw}")

    def stage_three(result: TaskOutput) -> tuple[bool, str]:
        """Records itself, then tags the text with '3-'."""
        seen.append("third")
        return (True, f"3-{result.raw}")

    agent = Mock()
    agent.role = "order_agent"
    agent.execute_task.return_value = "base"
    agent.crew = None

    task = create_smart_task(
        description="Test processing order",
        expected_output="Ordered processing",
        guardrails=[stage_one, stage_two, stage_three],
    )
    outcome = task.execute_sync(agent=agent)

    # Both the call sequence and the nesting of the tags prove the order.
    assert seen == ["first", "second", "third"]
    assert outcome.raw == "3-2-1-base"
def test_multiple_guardrails_with_pydantic_output():
    """Test multiple guardrails with Pydantic output model."""
    from pydantic import BaseModel, Field

    class TestModel(BaseModel):
        # The expected schema of the final JSON output.
        content: str = Field(description="The content")
        processed: bool = Field(description="Whether it was processed")

    def json_guardrail(result: TaskOutput) -> tuple[bool, str]:
        """Convert to JSON format."""
        import json

        data = {"content": result.raw, "processed": True}
        return (True, json.dumps(data))

    def validation_guardrail(result: TaskOutput) -> tuple[bool, str]:
        """Validate JSON structure."""
        import json

        try:
            data = json.loads(result.raw)
            if "content" not in data or "processed" not in data:
                return (False, "Missing required fields")
            return (True, result.raw)
        except json.JSONDecodeError:
            return (False, "Invalid JSON format")

    agent = Mock()
    agent.role = "pydantic_agent"
    agent.execute_task.return_value = "test content"
    agent.crew = None

    task = create_smart_task(
        description="Test guardrails with Pydantic",
        expected_output="Structured output",
        guardrails=[json_guardrail, validation_guardrail],
        output_pydantic=TestModel,
    )
    result = task.execute_sync(agent=agent)

    # Verify the result is valid JSON and can be parsed
    import json

    parsed = json.loads(result.raw)
    assert parsed["content"] == "test content"
    assert parsed["processed"] is True
def test_guardrails_vs_single_guardrail_mutual_exclusion():
    """Supplying a guardrails list must override a single guardrail."""

    def ignored_guardrail(result: TaskOutput) -> tuple[bool, str]:
        """Passed via guardrail=; must never run."""
        return (True, f"Single: {result.raw}")

    def active_guardrail(result: TaskOutput) -> tuple[bool, str]:
        """Passed via guardrails=; must be the one applied."""
        return (True, f"List: {result.raw}")

    agent = Mock()
    agent.role = "exclusion_agent"
    agent.execute_task.return_value = "test"
    agent.crew = None

    task = create_smart_task(
        description="Test mutual exclusion",
        expected_output="Exclusion test",
        guardrail=ignored_guardrail,  # This should be ignored
        guardrails=[active_guardrail],  # This should be used
    )
    outcome = task.execute_sync(agent=agent)

    # Only the list-based guardrail ran, and the single one was nullified.
    assert outcome.raw == "List: test"
    assert task._guardrail is None  # Single guardrail should be nullified