Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-07 23:28:30 +00:00)

Comparing lorenze/ad ... devin/1760 — 4 commits:
4cb8df09dd, 0938279b89, 42f2b4d551, 0229390ad1
@@ -864,7 +864,6 @@ class Agent(BaseAgent):
            i18n=self.i18n,
            original_agent=self,
            guardrail=self.guardrail,
            guardrail_max_retries=self.guardrail_max_retries,
        )

        return await lite_agent.kickoff_async(messages)
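Reviewer note: this is the tail of Agent.kickoff_async — the agent mirrors itself into a LiteAgent (guardrail settings included) and awaits its kickoff_async. An illustrative call from an async context, with hypothetical role/goal/backstory values:

    agent = Agent(role="Researcher", goal="Answer questions", backstory="A careful analyst")
    result = await agent.kickoff_async("Summarize the latest release notes")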
@@ -31,7 +31,7 @@ from crewai.flow.flow_visualizer import plot_flow
 from crewai.flow.persistence.base import FlowPersistence
 from crewai.flow.types import FlowExecutionData
 from crewai.flow.utils import get_possible_return_constants
-from crewai.utilities.printer import Printer
+from crewai.utilities.printer import Printer, PrinterColor

 logger = logging.getLogger(__name__)
@@ -105,7 +105,7 @@ def start(condition: str | dict | Callable | None = None) -> Callable:
     condition : Optional[Union[str, dict, Callable]], optional
         Defines when the start method should execute. Can be:
         - str: Name of a method that triggers this start
-        - dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers)
+        - dict: Result from or_() or and_(), including nested conditions
         - Callable: A method reference that triggers this start
         Default is None, meaning unconditional start.
@@ -140,13 +140,18 @@ def start(condition: str | dict | Callable | None = None) -> Callable:
         if isinstance(condition, str):
             func.__trigger_methods__ = [condition]
             func.__condition_type__ = "OR"
-        elif (
-            isinstance(condition, dict)
-            and "type" in condition
-            and "methods" in condition
-        ):
-            func.__trigger_methods__ = condition["methods"]
-            func.__condition_type__ = condition["type"]
+        elif isinstance(condition, dict) and "type" in condition:
+            if "conditions" in condition:
+                func.__trigger_condition__ = condition
+                func.__trigger_methods__ = _extract_all_methods(condition)
+                func.__condition_type__ = condition["type"]
+            elif "methods" in condition:
+                func.__trigger_methods__ = condition["methods"]
+                func.__condition_type__ = condition["type"]
+            else:
+                raise ValueError(
+                    "Condition dict must contain 'conditions' or 'methods'"
+                )
         elif callable(condition) and hasattr(condition, "__name__"):
             func.__trigger_methods__ = [condition.__name__]
             func.__condition_type__ = "OR"
@@ -172,7 +177,7 @@ def listen(condition: str | dict | Callable) -> Callable:
     condition : Union[str, dict, Callable]
         Specifies when the listener should execute. Can be:
         - str: Name of a method that triggers this listener
-        - dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers)
+        - dict: Result from or_() or and_(), including nested conditions
         - Callable: A method reference that triggers this listener

     Returns
@@ -200,13 +205,18 @@ def listen(condition: str | dict | Callable) -> Callable:
         if isinstance(condition, str):
             func.__trigger_methods__ = [condition]
             func.__condition_type__ = "OR"
-        elif (
-            isinstance(condition, dict)
-            and "type" in condition
-            and "methods" in condition
-        ):
-            func.__trigger_methods__ = condition["methods"]
-            func.__condition_type__ = condition["type"]
+        elif isinstance(condition, dict) and "type" in condition:
+            if "conditions" in condition:
+                func.__trigger_condition__ = condition
+                func.__trigger_methods__ = _extract_all_methods(condition)
+                func.__condition_type__ = condition["type"]
+            elif "methods" in condition:
+                func.__trigger_methods__ = condition["methods"]
+                func.__condition_type__ = condition["type"]
+            else:
+                raise ValueError(
+                    "Condition dict must contain 'conditions' or 'methods'"
+                )
         elif callable(condition) and hasattr(condition, "__name__"):
             func.__trigger_methods__ = [condition.__name__]
             func.__condition_type__ = "OR"
@@ -233,7 +243,7 @@ def router(condition: str | dict | Callable) -> Callable:
     condition : Union[str, dict, Callable]
         Specifies when the router should execute. Can be:
         - str: Name of a method that triggers this router
-        - dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers)
+        - dict: Result from or_() or and_(), including nested conditions
         - Callable: A method reference that triggers this router

     Returns
@@ -266,13 +276,18 @@ def router(condition: str | dict | Callable) -> Callable:
         if isinstance(condition, str):
             func.__trigger_methods__ = [condition]
             func.__condition_type__ = "OR"
-        elif (
-            isinstance(condition, dict)
-            and "type" in condition
-            and "methods" in condition
-        ):
-            func.__trigger_methods__ = condition["methods"]
-            func.__condition_type__ = condition["type"]
+        elif isinstance(condition, dict) and "type" in condition:
+            if "conditions" in condition:
+                func.__trigger_condition__ = condition
+                func.__trigger_methods__ = _extract_all_methods(condition)
+                func.__condition_type__ = condition["type"]
+            elif "methods" in condition:
+                func.__trigger_methods__ = condition["methods"]
+                func.__condition_type__ = condition["type"]
+            else:
+                raise ValueError(
+                    "Condition dict must contain 'conditions' or 'methods'"
+                )
         elif callable(condition) and hasattr(condition, "__name__"):
             func.__trigger_methods__ = [condition.__name__]
             func.__condition_type__ = "OR"
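Reviewer note: after this change, all three decorators stamp the same attributes onto the wrapped method; a condensed sketch with hypothetical flow-method names:

    @listen(or_(and_("step1", "step2"), "step3"))
    def finalize(self):
        ...

    # finalize.__trigger_condition__ -> the nested condition dict, kept intact
    # finalize.__trigger_methods__   -> ["step1", "step2", "step3"], via _extract_all_methods
    # finalize.__condition_type__    -> "OR"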
@@ -298,14 +313,15 @@ def or_(*conditions: str | dict | Callable) -> dict:
     *conditions : Union[str, dict, Callable]
         Variable number of conditions that can be:
         - str: Method names
-        - dict: Existing condition dictionaries
+        - dict: Existing condition dictionaries (nested conditions)
         - Callable: Method references

     Returns
     -------
     dict
         A condition dictionary with format:
-        {"type": "OR", "methods": list_of_method_names}
+        {"type": "OR", "conditions": list_of_conditions}
+        where each condition can be a string (method name) or a nested dict

     Raises
     ------
@@ -317,18 +333,22 @@ def or_(*conditions: str | dict | Callable) -> dict:
     >>> @listen(or_("success", "timeout"))
     >>> def handle_completion(self):
     ...     pass
+
+    >>> @listen(or_(and_("step1", "step2"), "step3"))
+    >>> def handle_nested(self):
+    ...     pass
     """
-    methods = []
+    processed_conditions: list[str | dict[str, Any]] = []
     for condition in conditions:
-        if isinstance(condition, dict) and "methods" in condition:
-            methods.extend(condition["methods"])
+        if isinstance(condition, dict):
+            processed_conditions.append(condition)
         elif isinstance(condition, str):
-            methods.append(condition)
+            processed_conditions.append(condition)
         elif callable(condition):
-            methods.append(getattr(condition, "__name__", repr(condition)))
+            processed_conditions.append(getattr(condition, "__name__", repr(condition)))
         else:
             raise ValueError("Invalid condition in or_()")
-    return {"type": "OR", "methods": methods}
+    return {"type": "OR", "conditions": processed_conditions}


 def and_(*conditions: str | dict | Callable) -> dict:
@@ -344,14 +364,15 @@ def and_(*conditions: str | dict | Callable) -> dict:
     *conditions : Union[str, dict, Callable]
         Variable number of conditions that can be:
         - str: Method names
-        - dict: Existing condition dictionaries
+        - dict: Existing condition dictionaries (nested conditions)
         - Callable: Method references

     Returns
     -------
     dict
         A condition dictionary with format:
-        {"type": "AND", "methods": list_of_method_names}
+        {"type": "AND", "conditions": list_of_conditions}
+        where each condition can be a string (method name) or a nested dict

     Raises
     ------
@@ -363,18 +384,69 @@ def and_(*conditions: str | dict | Callable) -> dict:
     >>> @listen(and_("validated", "processed"))
     >>> def handle_complete_data(self):
     ...     pass
+
+    >>> @listen(and_(or_("step1", "step2"), "step3"))
+    >>> def handle_nested(self):
+    ...     pass
     """
-    methods = []
+    processed_conditions: list[str | dict[str, Any]] = []
     for condition in conditions:
-        if isinstance(condition, dict) and "methods" in condition:
-            methods.extend(condition["methods"])
+        if isinstance(condition, dict):
+            processed_conditions.append(condition)
         elif isinstance(condition, str):
-            methods.append(condition)
+            processed_conditions.append(condition)
         elif callable(condition):
-            methods.append(getattr(condition, "__name__", repr(condition)))
+            processed_conditions.append(getattr(condition, "__name__", repr(condition)))
         else:
             raise ValueError("Invalid condition in and_()")
-    return {"type": "AND", "methods": methods}
+    return {"type": "AND", "conditions": processed_conditions}
+
+
+def _normalize_condition(condition: str | dict | list) -> dict:
+    """Normalize a condition to standard format with 'conditions' key.
+
+    Args:
+        condition: Can be a string (method name), dict (condition), or list
+
+    Returns:
+        Normalized dict with 'type' and 'conditions' keys
+    """
+    if isinstance(condition, str):
+        return {"type": "OR", "conditions": [condition]}
+    if isinstance(condition, dict):
+        if "conditions" in condition:
+            return condition
+        if "methods" in condition:
+            return {"type": condition["type"], "conditions": condition["methods"]}
+        return condition
+    if isinstance(condition, list):
+        return {"type": "OR", "conditions": condition}
+    return {"type": "OR", "conditions": [condition]}
+
+
+def _extract_all_methods(condition: str | dict | list) -> list[str]:
+    """Extract all method names from a condition (including nested).
+
+    Args:
+        condition: Can be a string, dict, or list
+
+    Returns:
+        List of all method names in the condition tree
+    """
+    if isinstance(condition, str):
+        return [condition]
+    if isinstance(condition, dict):
+        normalized = _normalize_condition(condition)
+        methods = []
+        for sub_cond in normalized.get("conditions", []):
+            methods.extend(_extract_all_methods(sub_cond))
+        return methods
+    if isinstance(condition, list):
+        methods = []
+        for item in condition:
+            methods.extend(_extract_all_methods(item))
+        return methods
+    return []


 class FlowMeta(type):
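Reviewer note: the net effect of the or_/and_ changes above is that structure is preserved instead of flattened. A minimal sketch, using hypothetical method names "a", "b", "c":

    nested = or_(and_("a", "b"), "c")
    # new:  {"type": "OR", "conditions": [{"type": "AND", "conditions": ["a", "b"]}, "c"]}
    # old:  {"type": "OR", "methods": ["a", "b", "c"]}  (flattened, losing the AND grouping)
    _extract_all_methods(nested)  # -> ["a", "b", "c"], kept as the flat __trigger_methods__ fallback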
@@ -402,7 +474,10 @@ class FlowMeta(type):
             if hasattr(attr_value, "__trigger_methods__"):
                 methods = attr_value.__trigger_methods__
                 condition_type = getattr(attr_value, "__condition_type__", "OR")
-                listeners[attr_name] = (condition_type, methods)
+                if hasattr(attr_value, "__trigger_condition__"):
+                    listeners[attr_name] = attr_value.__trigger_condition__
+                else:
+                    listeners[attr_name] = (condition_type, methods)

             if (
                 hasattr(attr_value, "__is_router__")
@@ -822,6 +897,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
             # Clear completed methods and outputs for a fresh start
             self._completed_methods.clear()
             self._method_outputs.clear()
+            self._pending_and_listeners.clear()
         else:
             # We're restoring from persistence, set the flag
             self._is_execution_resuming = True
@@ -1086,10 +1162,16 @@ class Flow(Generic[T], metaclass=FlowMeta):
         for method_name in self._start_methods:
             # Check if this start method is triggered by the current trigger
             if method_name in self._listeners:
-                condition_type, trigger_methods = self._listeners[
-                    method_name
-                ]
-                if current_trigger in trigger_methods:
+                condition_data = self._listeners[method_name]
+                should_trigger = False
+                if isinstance(condition_data, tuple):
+                    _, trigger_methods = condition_data
+                    should_trigger = current_trigger in trigger_methods
+                elif isinstance(condition_data, dict):
+                    all_methods = _extract_all_methods(condition_data)
+                    should_trigger = current_trigger in all_methods
+
+                if should_trigger:
                     # Only execute if this is a cycle (method was already completed)
                     if method_name in self._completed_methods:
                         # For router-triggered start methods in cycles, temporarily clear resumption flag
@@ -1099,6 +1181,51 @@ class Flow(Generic[T], metaclass=FlowMeta):
                        await self._execute_start_method(method_name)
                        self._is_execution_resuming = was_resuming

    def _evaluate_condition(
        self, condition: str | dict, trigger_method: str, listener_name: str
    ) -> bool:
        """Recursively evaluate a condition (simple or nested).

        Args:
            condition: Can be a string (method name) or dict (nested condition)
            trigger_method: The method that just completed
            listener_name: Name of the listener being evaluated

        Returns:
            True if the condition is satisfied, False otherwise
        """
        if isinstance(condition, str):
            return condition == trigger_method

        if isinstance(condition, dict):
            normalized = _normalize_condition(condition)
            cond_type = normalized.get("type", "OR")
            sub_conditions = normalized.get("conditions", [])

            if cond_type == "OR":
                return any(
                    self._evaluate_condition(sub_cond, trigger_method, listener_name)
                    for sub_cond in sub_conditions
                )

            if cond_type == "AND":
                pending_key = f"{listener_name}:{id(condition)}"

                if pending_key not in self._pending_and_listeners:
                    all_methods = set(_extract_all_methods(condition))
                    self._pending_and_listeners[pending_key] = all_methods

                if trigger_method in self._pending_and_listeners[pending_key]:
                    self._pending_and_listeners[pending_key].discard(trigger_method)

                if not self._pending_and_listeners[pending_key]:
                    self._pending_and_listeners.pop(pending_key, None)
                    return True

            return False

        return False

    def _find_triggered_methods(
        self, trigger_method: str, router_only: bool
    ) -> list[str]:
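Reviewer note: the AND branch keys its pending set by listener name plus the id() of the condition dict, so each nested AND clause tracks its own remaining methods independently. A rough trace, assuming a listener "l" bound to cond = and_("a", "b"):

    _evaluate_condition(cond, "a", "l")   # pending {"a", "b"} -> discard "a" -> {"b"} -> False
    _evaluate_condition(cond, "b", "l")   # discard "b" -> set empty -> key cleared -> True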
@@ -1106,7 +1233,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
         Finds all methods that should be triggered based on conditions.

         This internal method evaluates both OR and AND conditions to determine
-        which methods should be executed next in the flow.
+        which methods should be executed next in the flow. Supports nested conditions.

         Parameters
         ----------
@@ -1123,14 +1250,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
         Notes
         -----
-        - Handles both OR and AND conditions:
-          * OR: Triggers if any condition is met
-          * AND: Triggers only when all conditions are met
+        - Handles both OR and AND conditions, including nested combinations
         - Maintains state for AND conditions using _pending_and_listeners
         - Separates router and normal listener evaluation
         """
         triggered = []
-        for listener_name, (condition_type, methods) in self._listeners.items():
+
+        for listener_name, condition_data in self._listeners.items():
             is_router = listener_name in self._routers

             if router_only != is_router:
@@ -1139,23 +1265,29 @@ class Flow(Generic[T], metaclass=FlowMeta):
             if not router_only and listener_name in self._start_methods:
                 continue

-            if condition_type == "OR":
-                # If the trigger_method matches any in methods, run this
-                if trigger_method in methods:
-                    triggered.append(listener_name)
-            elif condition_type == "AND":
-                # Initialize pending methods for this listener if not already done
-                if listener_name not in self._pending_and_listeners:
-                    self._pending_and_listeners[listener_name] = set(methods)
-                # Remove the trigger method from pending methods
-                if trigger_method in self._pending_and_listeners[listener_name]:
-                    self._pending_and_listeners[listener_name].discard(trigger_method)
-
-                if not self._pending_and_listeners[listener_name]:
-                    # All required methods have been executed
-                    triggered.append(listener_name)
-                    # Reset pending methods for this listener
-                    self._pending_and_listeners.pop(listener_name, None)
+            if isinstance(condition_data, tuple):
+                condition_type, methods = condition_data
+
+                if condition_type == "OR":
+                    if trigger_method in methods:
+                        triggered.append(listener_name)
+                elif condition_type == "AND":
+                    if listener_name not in self._pending_and_listeners:
+                        self._pending_and_listeners[listener_name] = set(methods)
+                    if trigger_method in self._pending_and_listeners[listener_name]:
+                        self._pending_and_listeners[listener_name].discard(
+                            trigger_method
+                        )
+
+                    if not self._pending_and_listeners[listener_name]:
+                        triggered.append(listener_name)
+                        self._pending_and_listeners.pop(listener_name, None)
+
+            elif isinstance(condition_data, dict):
+                if self._evaluate_condition(
+                    condition_data, trigger_method, listener_name
+                ):
+                    triggered.append(listener_name)
+                    self._pending_and_listeners.pop(listener_name, None)

         return triggered
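Reviewer note: dispatch now hinges on the stored shape — tuples take the legacy flat path, dicts route through _evaluate_condition. Illustrative _listeners contents, with hypothetical names:

    self._listeners = {
        "legacy_listener": ("AND", ["a", "b"]),  # tuple: old flat handling
        "nested_listener": {"type": "OR", "conditions": [{"type": "AND", "conditions": ["a", "b"]}, "c"]},  # dict
    }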
@@ -1218,7 +1350,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
             raise

     def _log_flow_event(
-        self, message: str, color: str = "yellow", level: str = "info"
+        self, message: str, color: PrinterColor | None = "yellow", level: str = "info"
     ) -> None:
         """Centralized logging method for flow events.
@@ -5,7 +5,7 @@ import logging
 import threading
 import uuid
 import warnings
-from collections.abc import Callable, Sequence
+from collections.abc import Callable
 from concurrent.futures import Future
 from copy import copy as shallow_copy
 from hashlib import md5
@@ -152,15 +152,6 @@ class Task(BaseModel):
         default=None,
         description="Function or string description of a guardrail to validate task output before proceeding to next task",
     )
-    guardrails: (
-        Sequence[Callable[[TaskOutput], tuple[bool, Any]] | str]
-        | Callable[[TaskOutput], tuple[bool, Any]]
-        | str
-        | None
-    ) = Field(
-        default=None,
-        description="List of guardrails to validate task output before proceeding to next task. Also supports a single guardrail function or string description of a guardrail to validate task output before proceeding to next task",
-    )
     max_retries: int | None = Field(
         default=None,
         description="[DEPRECATED] Maximum number of retries when guardrail fails. Use guardrail_max_retries instead. Will be removed in v1.0.0",
@@ -277,44 +268,6 @@ class Task(BaseModel):
        return self

    @model_validator(mode="after")
    def ensure_guardrails_is_list_of_callables(self) -> "Task":
        guardrails = []
        if self.guardrails is not None and (
            not isinstance(self.guardrails, (list, tuple)) or len(self.guardrails) > 0
        ):
            if self.agent is None:
                raise ValueError("Agent is required to use guardrails")

            if callable(self.guardrails):
                guardrails.append(self.guardrails)
            elif isinstance(self.guardrails, str):
                from crewai.tasks.llm_guardrail import LLMGuardrail

                guardrails.append(
                    LLMGuardrail(description=self.guardrails, llm=self.agent.llm)
                )

            if isinstance(self.guardrails, list):
                for guardrail in self.guardrails:
                    if callable(guardrail):
                        guardrails.append(guardrail)
                    elif isinstance(guardrail, str):
                        from crewai.tasks.llm_guardrail import LLMGuardrail

                        guardrails.append(
                            LLMGuardrail(description=guardrail, llm=self.agent.llm)
                        )
                    else:
                        raise ValueError("Guardrail must be a callable or a string")

        self._guardrails = guardrails
        if self._guardrails:
            self.guardrail = None
            self._guardrail = None

        return self

    @field_validator("id", mode="before")
    @classmethod
    def _deny_user_set_id(cls, v: UUID4 | None) -> None:
@@ -503,23 +456,48 @@ class Task(BaseModel):
             output_format=self._get_output_format(),
         )

-        if self._guardrails:
-            for guardrail in self._guardrails:
-                task_output = self._invoke_guardrail_function(
-                    task_output=task_output,
-                    agent=agent,
-                    tools=tools,
-                    guardrail=guardrail,
-                )
-
-        # backwards support
-        if self._guardrail:
-            task_output = self._invoke_guardrail_function(
-                task_output=task_output,
-                agent=agent,
-                tools=tools,
-                guardrail=self._guardrail,
-            )
+        if self._guardrail:
+            guardrail_result = process_guardrail(
+                output=task_output,
+                guardrail=self._guardrail,
+                retry_count=self.retry_count,
+                event_source=self,
+                from_task=self,
+                from_agent=agent,
+            )
+            if not guardrail_result.success:
+                if self.retry_count >= self.guardrail_max_retries:
+                    raise Exception(
+                        f"Task failed guardrail validation after {self.guardrail_max_retries} retries. "
+                        f"Last error: {guardrail_result.error}"
+                    )
+
+                self.retry_count += 1
+                context = self.i18n.errors("validation_error").format(
+                    guardrail_result_error=guardrail_result.error,
+                    task_output=task_output.raw,
+                )
+                printer = Printer()
+                printer.print(
+                    content=f"Guardrail blocked, retrying, due to: {guardrail_result.error}\n",
+                    color="yellow",
+                )
+                return self._execute_core(agent, context, tools)
+
+            if guardrail_result.result is None:
+                raise Exception(
+                    "Task guardrail returned None as result. This is not allowed."
+                )
+
+            if isinstance(guardrail_result.result, str):
+                task_output.raw = guardrail_result.result
+                pydantic_output, json_output = self._export_output(
+                    guardrail_result.result
+                )
+                task_output.pydantic = pydantic_output
+                task_output.json_dict = json_output
+            elif isinstance(guardrail_result.result, TaskOutput):
+                task_output = guardrail_result.result

         self.output = task_output
         self.end_time = datetime.datetime.now()
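Reviewer note: a guardrail here is any callable taking the TaskOutput and returning a (success, result) tuple. A minimal sketch (hypothetical validator, not part of this diff):

    def require_nonempty(output: TaskOutput) -> tuple[bool, str]:
        # Reject blank output; otherwise pass the raw text through unchanged.
        if not output.raw.strip():
            return (False, "empty output")
        return (True, output.raw)

    task = Task(
        description="Summarize the report",
        expected_output="A short summary",
        guardrail=require_nonempty,
    )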
@@ -811,55 +789,3 @@ Follow these guidelines:
            Fingerprint: The fingerprint of the task
        """
        return self.security_config.fingerprint

    def _invoke_guardrail_function(
        self,
        task_output: TaskOutput,
        agent: BaseAgent,
        tools: list[BaseTool],
        guardrail: Callable | None,
    ) -> TaskOutput:
        if guardrail:
            guardrail_result = process_guardrail(
                output=task_output,
                guardrail=guardrail,
                retry_count=self.retry_count,
                event_source=self,
                from_task=self,
                from_agent=agent,
            )
            if not guardrail_result.success:
                if self.retry_count >= self.guardrail_max_retries:
                    raise Exception(
                        f"Task failed guardrail validation after {self.guardrail_max_retries} retries. "
                        f"Last error: {guardrail_result.error}"
                    )

                self.retry_count += 1
                context = self.i18n.errors("validation_error").format(
                    guardrail_result_error=guardrail_result.error,
                    task_output=task_output.raw,
                )
                printer = Printer()
                printer.print(
                    content=f"Guardrail blocked, retrying, due to: {guardrail_result.error}\n",
                    color="yellow",
                )
                return self._execute_core(agent, context, tools)

            if guardrail_result.result is None:
                raise Exception(
                    "Task guardrail returned None as result. This is not allowed."
                )

            if isinstance(guardrail_result.result, str):
                task_output.raw = guardrail_result.result
                pydantic_output, json_output = self._export_output(
                    guardrail_result.result
                )
                task_output.pydantic = pydantic_output
                task_output.json_dict = json_output
            elif isinstance(guardrail_result.result, TaskOutput):
                task_output = guardrail_result.result

        return task_output
@@ -14,6 +14,7 @@ from pydantic import (
 from pydantic import BaseModel as PydanticBaseModel

 from crewai.tools.structured_tool import CrewStructuredTool
+from crewai.utilities.asyncio_utils import run_coroutine_sync


 class EnvVar(BaseModel):
@@ -90,7 +91,7 @@ class BaseTool(BaseModel, ABC):
         # If _run is async, we safely run it
         if asyncio.iscoroutine(result):
-            result = asyncio.run(result)
+            result = run_coroutine_sync(result)

         self.current_usage_count += 1
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, get_type_hints
 from pydantic import BaseModel, Field, create_model

+from crewai.utilities.asyncio_utils import run_coroutine_sync
 from crewai.utilities.logger import Logger

 if TYPE_CHECKING:
@@ -269,12 +270,12 @@ class CrewStructuredTool:
         self._increment_usage_count()

         if inspect.iscoroutinefunction(self.func):
-            return asyncio.run(self.func(**parsed_args, **kwargs))
+            return run_coroutine_sync(self.func(**parsed_args, **kwargs))

         result = self.func(**parsed_args, **kwargs)

         if asyncio.iscoroutine(result):
-            return asyncio.run(result)
+            return run_coroutine_sync(result)

         return result
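Reviewer note: the swap matters because asyncio.run() raises "RuntimeError: asyncio.run() cannot be called from a running event loop" when the calling thread already hosts a loop. A minimal repro of the failure mode being avoided (illustrative only):

    import asyncio

    async def work():
        return 1

    async def main():
        asyncio.run(work())  # raises RuntimeError inside a running loop

    asyncio.run(main())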
src/crewai/utilities/asyncio_utils.py (new file, 53 lines)
@@ -0,0 +1,53 @@
"""Utilities for handling asyncio operations safely across different contexts."""

import asyncio
from collections.abc import Coroutine
from typing import Any


def run_coroutine_sync(coro: Coroutine) -> Any:
    """
    Run a coroutine synchronously, handling both cases where an event loop
    is already running and where it's not.

    This is useful when you need to run async code from sync code, but you're
    not sure if you're already in an async context (e.g., when using asyncio.to_thread).

    Args:
        coro: The coroutine to run

    Returns:
        The result of the coroutine

    Raises:
        Any exception raised by the coroutine
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        return asyncio.run(coro)
    else:
        import threading

        result = None
        exception = None

        def run_in_new_loop():
            nonlocal result, exception
            try:
                new_loop = asyncio.new_event_loop()
                asyncio.set_event_loop(new_loop)
                try:
                    result = new_loop.run_until_complete(coro)
                finally:
                    new_loop.close()
            except Exception as e:
                exception = e

        thread = threading.Thread(target=run_in_new_loop)
        thread.start()
        thread.join()

        if exception:
            raise exception
        return result
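Reviewer note: a usage sketch for the new helper, covering both branches (hypothetical coroutine):

    import asyncio
    from crewai.utilities.asyncio_utils import run_coroutine_sync

    async def fetch():
        await asyncio.sleep(0)
        return "done"

    # No loop in this thread: delegates to asyncio.run.
    assert run_coroutine_sync(fetch()) == "done"

    async def caller():
        # Loop already running: the helper runs a fresh loop on a worker thread.
        return run_coroutine_sync(fetch())

    assert asyncio.run(caller()) == "done"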
@@ -1,6 +1,11 @@
 """Utility for colored console output."""

-from typing import Final, Literal, NamedTuple
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Final, Literal, NamedTuple
+
+if TYPE_CHECKING:
+    from _typeshed import SupportsWrite

 PrinterColor = Literal[
     "purple",
@@ -54,13 +59,22 @@ class Printer:
     @staticmethod
     def print(
-        content: str | list[ColoredText], color: PrinterColor | None = None
+        content: str | list[ColoredText],
+        color: PrinterColor | None = None,
+        sep: str | None = " ",
+        end: str | None = "\n",
+        file: SupportsWrite[str] | None = None,
+        flush: Literal[False] = False,
     ) -> None:
         """Prints content to the console with optional color formatting.

         Args:
             content: Either a string or a list of ColoredText objects for multicolor output.
             color: Optional color for the text when content is a string. Ignored when content is a list.
+            sep: Separator to use between the text and color.
+            end: String appended after the last value.
+            file: A file-like object (stream); defaults to the current sys.stdout.
+            flush: Whether to forcibly flush the stream.
         """
         if isinstance(content, str):
             content = [ColoredText(content, color)]
@@ -68,5 +82,9 @@ class Printer:
                 "".join(
                     f"{_COLOR_CODES[c.color] if c.color else ''}{c.text}{RESET}"
                     for c in content
                 )
-        )
+            ),
+            sep=sep,
+            end=end,
+            file=file,
+            flush=flush,
+        )
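Reviewer note: with the widened signature, Printer.print now forwards stream arguments straight through to the builtin print; for example (illustrative call):

    import sys
    from crewai.utilities.printer import Printer

    Printer().print("guardrail blocked", color="yellow", end="", file=sys.stderr)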
tests/test_asyncio_tools.py (new file, 142 lines)
@@ -0,0 +1,142 @@
"""Tests for asyncio tool execution in different contexts."""

import asyncio
from unittest.mock import patch

import pytest

from crewai import Agent, Crew, Task
from crewai.tools import tool


@tool
async def async_test_tool(message: str) -> str:
    """An async tool that returns a message."""
    await asyncio.sleep(0.01)
    return f"Processed: {message}"


@tool
def sync_test_tool(message: str) -> str:
    """A sync tool that returns a message."""
    return f"Sync: {message}"


class TestAsyncioToolExecution:
    """Test that tools work correctly in different asyncio contexts."""

    @patch("crewai.Agent.execute_task")
    def test_async_tool_with_asyncio_to_thread(self, mock_execute_task):
        """Test that async tools work when crew is run with asyncio.to_thread."""
        mock_execute_task.return_value = "Task completed"

        agent = Agent(
            role="Test Agent",
            goal="Test async tool execution",
            backstory="A test agent",
        )

        task = Task(
            description="Test task",
            expected_output="A result",
            agent=agent,
            tools=[async_test_tool],
        )

        crew = Crew(agents=[agent], tasks=[task], verbose=False)

        async def run_with_to_thread():
            """Run crew with asyncio.to_thread - this should not hang."""
            result = await asyncio.to_thread(crew.kickoff)
            return result

        result = asyncio.run(run_with_to_thread())
        assert result is not None

    @patch("crewai.Agent.execute_task")
    def test_sync_tool_with_asyncio_to_thread(self, mock_execute_task):
        """Test that sync tools work when crew is run with asyncio.to_thread."""
        mock_execute_task.return_value = "Task completed"

        agent = Agent(
            role="Test Agent",
            goal="Test sync tool execution",
            backstory="A test agent",
        )

        task = Task(
            description="Test task",
            expected_output="A result",
            agent=agent,
            tools=[sync_test_tool],
        )

        crew = Crew(agents=[agent], tasks=[task], verbose=False)

        async def run_with_to_thread():
            """Run crew with asyncio.to_thread."""
            result = await asyncio.to_thread(crew.kickoff)
            return result

        result = asyncio.run(run_with_to_thread())
        assert result is not None

    @pytest.mark.asyncio
    @patch("crewai.Agent.execute_task")
    async def test_async_tool_with_kickoff_async(self, mock_execute_task):
        """Test that async tools work with kickoff_async."""
        mock_execute_task.return_value = "Task completed"

        agent = Agent(
            role="Test Agent",
            goal="Test async tool execution",
            backstory="A test agent",
        )

        task = Task(
            description="Test task",
            expected_output="A result",
            agent=agent,
            tools=[async_test_tool],
        )

        crew = Crew(agents=[agent], tasks=[task], verbose=False)

        result = await crew.kickoff_async()
        assert result is not None

    def test_async_tool_direct_invocation(self):
        """Test that async tools can be invoked directly from sync context."""
        structured_tool = async_test_tool.to_structured_tool()
        result = structured_tool.invoke({"message": "test"})
        assert result == "Processed: test"

    def test_async_tool_invocation_from_thread(self):
        """Test that async tools work when invoked from a thread pool."""
        structured_tool = async_test_tool.to_structured_tool()

        def invoke_tool():
            return structured_tool.invoke({"message": "test"})

        async def run_in_thread():
            result = await asyncio.to_thread(invoke_tool)
            return result

        result = asyncio.run(run_in_thread())
        assert result == "Processed: test"

    @pytest.mark.asyncio
    async def test_multiple_async_tools_concurrent(self):
        """Test multiple async tool invocations concurrently."""
        structured_tool = async_test_tool.to_structured_tool()

        async def invoke_async():
            return await structured_tool.ainvoke({"message": "test"})

        results = await asyncio.gather(
            invoke_async(), invoke_async(), invoke_async()
        )

        assert len(results) == 3
        for r in results:
            assert "test" in str(r)
@@ -1,8 +1,7 @@
 import asyncio
 import threading
-from collections.abc import Callable
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any
+from typing import Dict, Any, Callable
 from unittest.mock import patch

 import pytest
@@ -25,12 +24,9 @@ def simple_agent_factory():
 @pytest.fixture
 def simple_task_factory():
-    def create_task(name: str, agent: Agent, callback: Callable | None = None) -> Task:
+    def create_task(name: str, callback: Callable = None) -> Task:
         return Task(
-            description=f"Task for {name}",
-            expected_output="Done",
-            agent=agent,
-            callback=callback,
+            description=f"Task for {name}", expected_output="Done", callback=callback
         )

     return create_task
@@ -38,9 +34,10 @@ def simple_task_factory():
 @pytest.fixture
 def crew_factory(simple_agent_factory, simple_task_factory):
-    def create_crew(name: str, task_callback: Callable | None = None) -> Crew:
+    def create_crew(name: str, task_callback: Callable = None) -> Crew:
         agent = simple_agent_factory(name)
-        task = simple_task_factory(name, agent=agent, callback=task_callback)
+        task = simple_task_factory(name, callback=task_callback)
+        task.agent = agent

         return Crew(agents=[agent], tasks=[task], verbose=False)
@@ -53,7 +50,7 @@ class TestCrewThreadSafety:
         mock_execute_task.return_value = "Task completed"
         num_crews = 5

-        def run_crew_with_context_check(crew_id: str) -> dict[str, Any]:
+        def run_crew_with_context_check(crew_id: str) -> Dict[str, Any]:
             results = {"crew_id": crew_id, "contexts": []}

             def check_context_task(output):
@@ -108,28 +105,28 @@ class TestCrewThreadSafety:
             before_ctx = next(
                 ctx for ctx in result["contexts"] if ctx["stage"] == "before_kickoff"
             )
-            assert before_ctx["crew_id"] is None, (
-                f"Context should be None before kickoff for {result['crew_id']}"
-            )
+            assert (
+                before_ctx["crew_id"] is None
+            ), f"Context should be None before kickoff for {result['crew_id']}"

             task_ctx = next(
                 ctx for ctx in result["contexts"] if ctx["stage"] == "task_callback"
             )
-            assert task_ctx["crew_id"] == crew_uuid, (
-                f"Context mismatch during task for {result['crew_id']}"
-            )
+            assert (
+                task_ctx["crew_id"] == crew_uuid
+            ), f"Context mismatch during task for {result['crew_id']}"

             after_ctx = next(
                 ctx for ctx in result["contexts"] if ctx["stage"] == "after_kickoff"
             )
-            assert after_ctx["crew_id"] is None, (
-                f"Context should be None after kickoff for {result['crew_id']}"
-            )
+            assert (
+                after_ctx["crew_id"] is None
+            ), f"Context should be None after kickoff for {result['crew_id']}"

             thread_name = before_ctx["thread"]
-            assert "ThreadPoolExecutor" in thread_name, (
-                f"Should run in thread pool for {result['crew_id']}"
-            )
+            assert (
+                "ThreadPoolExecutor" in thread_name
+            ), f"Should run in thread pool for {result['crew_id']}"

     @pytest.mark.asyncio
     @patch("crewai.Agent.execute_task")
@@ -137,7 +134,7 @@ class TestCrewThreadSafety:
         mock_execute_task.return_value = "Task completed"
         num_crews = 5

-        async def run_crew_async(crew_id: str) -> dict[str, Any]:
+        async def run_crew_async(crew_id: str) -> Dict[str, Any]:
             task_context = {"crew_id": crew_id, "context": None}

             def capture_context(output):
@@ -165,12 +162,12 @@ class TestCrewThreadSafety:
             crew_uuid = result["crew_uuid"]
             task_ctx = result["task_context"]["context"]

-            assert task_ctx is not None, (
-                f"Context should exist during task for {result['crew_id']}"
-            )
-            assert task_ctx["crew_id"] == crew_uuid, (
-                f"Context mismatch for {result['crew_id']}"
-            )
+            assert (
+                task_ctx is not None
+            ), f"Context should exist during task for {result['crew_id']}"
+            assert (
+                task_ctx["crew_id"] == crew_uuid
+            ), f"Context mismatch for {result['crew_id']}"

     @patch("crewai.Agent.execute_task")
     def test_concurrent_kickoff_for_each(self, mock_execute_task, crew_factory):
@@ -196,9 +193,9 @@ class TestCrewThreadSafety:
         assert len(contexts_captured) == len(inputs)

         context_ids = [ctx["context_id"] for ctx in contexts_captured]
-        assert len(set(context_ids)) == len(inputs), (
-            "Each execution should have unique context"
-        )
+        assert len(set(context_ids)) == len(
+            inputs
+        ), "Each execution should have unique context"

     @patch("crewai.Agent.execute_task")
     def test_no_context_leakage_between_crews(self, mock_execute_task, crew_factory):
@@ -6,15 +6,15 @@ from datetime import datetime
 import pytest
 from pydantic import BaseModel

-from crewai.flow.flow import Flow, and_, listen, or_, router, start
 from crewai.events.event_bus import crewai_event_bus
 from crewai.events.types.flow_events import (
     FlowFinishedEvent,
-    FlowStartedEvent,
     FlowPlotEvent,
+    FlowStartedEvent,
     MethodExecutionFinishedEvent,
     MethodExecutionStartedEvent,
 )
+from crewai.flow.flow import Flow, and_, listen, or_, router, start


 def test_simple_sequential_flow():
@@ -679,11 +679,11 @@ def test_structured_flow_event_emission():
     assert isinstance(received_events[3], MethodExecutionStartedEvent)
     assert received_events[3].method_name == "send_welcome_message"
     assert received_events[3].params == {}
-    assert getattr(received_events[3].state, "sent") is False
+    assert received_events[3].state.sent is False

     assert isinstance(received_events[4], MethodExecutionFinishedEvent)
     assert received_events[4].method_name == "send_welcome_message"
-    assert getattr(received_events[4].state, "sent") is True
+    assert received_events[4].state.sent is True
     assert received_events[4].result == "Welcome, Anakin!"

     assert isinstance(received_events[5], FlowFinishedEvent)
@@ -894,3 +894,75 @@ def test_flow_name():
    flow = MyFlow()
    assert flow.name == "MyFlow"


def test_nested_and_or_conditions():
    """Test nested conditions like or_(and_(A, B), and_(C, D)).

    Reproduces bug from issue #3719 where nested conditions are flattened,
    causing premature execution.
    """
    execution_order = []

    class NestedConditionFlow(Flow):
        @start()
        def method_1(self):
            execution_order.append("method_1")

        @listen(method_1)
        def method_2(self):
            execution_order.append("method_2")

        @router(method_2)
        def method_3(self):
            execution_order.append("method_3")
            # Choose b_condition path
            return "b_condition"

        @listen("b_condition")
        def method_5(self):
            execution_order.append("method_5")

        @listen(method_5)
        async def method_4(self):
            execution_order.append("method_4")

        @listen(or_("a_condition", "b_condition"))
        async def method_6(self):
            execution_order.append("method_6")

        @listen(
            or_(
                and_("a_condition", method_6),
                and_(method_6, method_4),
            )
        )
        def method_7(self):
            execution_order.append("method_7")

        @listen(method_7)
        async def method_8(self):
            execution_order.append("method_8")

    flow = NestedConditionFlow()
    flow.kickoff()

    # Verify execution happened
    assert "method_1" in execution_order
    assert "method_2" in execution_order
    assert "method_3" in execution_order
    assert "method_5" in execution_order
    assert "method_4" in execution_order
    assert "method_6" in execution_order
    assert "method_7" in execution_order
    assert "method_8" in execution_order

    # Critical assertion: method_7 should only execute AFTER both method_6 AND method_4
    # Since b_condition was returned, method_6 triggers on b_condition
    # method_7 requires: (a_condition AND method_6) OR (method_6 AND method_4)
    # The second condition (method_6 AND method_4) should be the one that triggers
    assert execution_order.index("method_7") > execution_order.index("method_6")
    assert execution_order.index("method_7") > execution_order.index("method_4")

    # method_8 should execute after method_7
    assert execution_order.index("method_8") > execution_order.index("method_7")
@@ -14,24 +14,6 @@ from crewai.tasks.llm_guardrail import LLMGuardrail
from crewai.tasks.task_output import TaskOutput


def create_smart_task(**kwargs):
    """
    Smart task factory that automatically assigns a mock agent when guardrails are present.
    This maintains backward compatibility while handling the agent requirement for guardrails.
    """
    guardrails_list = kwargs.get("guardrails")
    has_guardrails = kwargs.get("guardrail") is not None or (
        guardrails_list is not None and len(guardrails_list) > 0
    )

    if has_guardrails and kwargs.get("agent") is None:
        kwargs["agent"] = Agent(
            role="test_agent", goal="test_goal", backstory="test_backstory"
        )

    return Task(**kwargs)


def test_task_without_guardrail():
    """Test that tasks work normally without guardrails (backward compatibility)."""
    agent = Mock()
@@ -39,7 +21,7 @@ def test_task_without_guardrail():
     agent.execute_task.return_value = "test result"
     agent.crew = None

-    task = create_smart_task(description="Test task", expected_output="Output")
+    task = Task(description="Test task", expected_output="Output")

     result = task.execute_sync(agent=agent)
     assert isinstance(result, TaskOutput)
@@ -57,9 +39,7 @@ def test_task_with_successful_guardrail_func():
     agent.execute_task.return_value = "test result"
     agent.crew = None

-    task = create_smart_task(
-        description="Test task", expected_output="Output", guardrail=guardrail
-    )
+    task = Task(description="Test task", expected_output="Output", guardrail=guardrail)

     result = task.execute_sync(agent=agent)
     assert isinstance(result, TaskOutput)
@@ -77,7 +57,7 @@ def test_task_with_failing_guardrail():
     agent.execute_task.side_effect = ["bad result", "good result"]
     agent.crew = None

-    task = create_smart_task(
+    task = Task(
         description="Test task",
         expected_output="Output",
         guardrail=guardrail,
@@ -104,7 +84,7 @@ def test_task_with_guardrail_retries():
     agent.execute_task.return_value = "bad result"
     agent.crew = None

-    task = create_smart_task(
+    task = Task(
         description="Test task",
         expected_output="Output",
         guardrail=guardrail,
@@ -129,7 +109,7 @@ def test_guardrail_error_in_context():
     agent.role = "test_agent"
     agent.crew = None

-    task = create_smart_task(
+    task = Task(
         description="Test task",
         expected_output="Output",
         guardrail=guardrail,
@@ -197,7 +177,7 @@ def test_guardrail_emits_events(sample_agent):
     started_guardrail = []
     completed_guardrail = []

-    task = create_smart_task(
+    task = Task(
         description="Gather information about available books on the First World War",
         agent=sample_agent,
         expected_output="A list of available books on the First World War",
@@ -230,7 +210,7 @@ def test_guardrail_emits_events(sample_agent):
     def custom_guardrail(result: TaskOutput):
         return (True, "good result from callable function")

-    task = create_smart_task(
+    task = Task(
         description="Test task",
         expected_output="Output",
         guardrail=custom_guardrail,
@@ -282,7 +262,7 @@ def test_guardrail_when_an_error_occurs(sample_agent, task_output):
             match="Error while validating the task output: Unexpected error",
         ),
     ):
-        task = create_smart_task(
+        task = Task(
             description="Gather information about available books on the First World War",
             agent=sample_agent,
             expected_output="A list of available books on the First World War",
@@ -304,7 +284,7 @@ def test_hallucination_guardrail_integration():
         context="Test reference context for validation", llm=mock_llm, threshold=8.0
     )

-    task = create_smart_task(
+    task = Task(
         description="Test task with hallucination guardrail",
         expected_output="Valid output",
         guardrail=guardrail,
@@ -324,352 +304,3 @@ def test_hallucination_guardrail_description_in_events():
    event = LLMGuardrailStartedEvent(guardrail=guardrail, retry_count=0)
    assert event.guardrail == "HallucinationGuardrail (no-op)"


def test_multiple_guardrails_sequential_processing():
    """Test that multiple guardrails are processed sequentially."""

    def first_guardrail(result: TaskOutput) -> tuple[bool, str]:
        """First guardrail adds prefix."""
        return (True, f"[FIRST] {result.raw}")

    def second_guardrail(result: TaskOutput) -> tuple[bool, str]:
        """Second guardrail adds suffix."""
        return (True, f"{result.raw} [SECOND]")

    def third_guardrail(result: TaskOutput) -> tuple[bool, str]:
        """Third guardrail converts to uppercase."""
        return (True, result.raw.upper())

    agent = Mock()
    agent.role = "sequential_agent"
    agent.execute_task.return_value = "original text"
    agent.crew = None

    task = create_smart_task(
        description="Test sequential guardrails",
        expected_output="Processed text",
        guardrails=[first_guardrail, second_guardrail, third_guardrail],
    )

    result = task.execute_sync(agent=agent)
    assert result.raw == "[FIRST] ORIGINAL TEXT [SECOND]"


def test_multiple_guardrails_with_validation_failure():
    """Test multiple guardrails where one fails validation."""

    def length_guardrail(result: TaskOutput) -> tuple[bool, str]:
        """Ensure minimum length."""
        if len(result.raw) < 10:
            return (False, "Text too short")
        return (True, result.raw)

    def format_guardrail(result: TaskOutput) -> tuple[bool, str]:
        """Add formatting only if not already formatted."""
        if not result.raw.startswith("Formatted:"):
            return (True, f"Formatted: {result.raw}")
        return (True, result.raw)

    def validation_guardrail(result: TaskOutput) -> tuple[bool, str]:
        """Final validation."""
        if "Formatted:" not in result.raw:
            return (False, "Missing formatting")
        return (True, result.raw)

    # Use a callable that tracks calls and returns appropriate values
    call_count = 0

    def mock_execute_task(*args, **kwargs):
        nonlocal call_count
        call_count += 1
        result = (
            "short"
            if call_count == 1
            else "this is a longer text that meets requirements"
        )
        return result

    agent = Mock()
    agent.role = "validation_agent"
    agent.execute_task = mock_execute_task
    agent.crew = None

    task = create_smart_task(
        description="Test guardrails with validation",
        expected_output="Valid formatted text",
        guardrails=[length_guardrail, format_guardrail, validation_guardrail],
        guardrail_max_retries=2,
    )

    result = task.execute_sync(agent=agent)
    # The second call should be processed through all guardrails
    assert result.raw == "Formatted: this is a longer text that meets requirements"
    assert task.retry_count == 1


def test_multiple_guardrails_with_mixed_string_and_taskoutput():
    """Test guardrails that return both strings and TaskOutput objects."""

    def string_guardrail(result: TaskOutput) -> tuple[bool, str]:
        """Returns a string."""
        return (True, f"String: {result.raw}")

    def taskoutput_guardrail(result: TaskOutput) -> tuple[bool, TaskOutput]:
        """Returns a TaskOutput object."""
        new_output = TaskOutput(
            name=result.name,
            description=result.description,
            expected_output=result.expected_output,
            raw=f"TaskOutput: {result.raw}",
            agent=result.agent,
            output_format=result.output_format,
        )
        return (True, new_output)

    def final_string_guardrail(result: TaskOutput) -> tuple[bool, str]:
        """Final string transformation."""
        return (True, f"Final: {result.raw}")

    agent = Mock()
    agent.role = "mixed_agent"
    agent.execute_task.return_value = "original"
    agent.crew = None

    task = create_smart_task(
        description="Test mixed return types",
        expected_output="Mixed processing",
        guardrails=[string_guardrail, taskoutput_guardrail, final_string_guardrail],
    )

    result = task.execute_sync(agent=agent)
    assert result.raw == "Final: TaskOutput: String: original"


def test_multiple_guardrails_with_retry_on_middle_guardrail():
    """Test that retry works correctly when a middle guardrail fails."""

    call_count = {"first": 0, "second": 0, "third": 0}

    def first_guardrail(result: TaskOutput) -> tuple[bool, str]:
        """Always succeeds."""
        call_count["first"] += 1
        return (True, f"First({call_count['first']}): {result.raw}")

    def second_guardrail(result: TaskOutput) -> tuple[bool, str]:
        """Fails on first attempt, succeeds on second."""
        call_count["second"] += 1
        if call_count["second"] == 1:
            return (False, "Second guardrail failed on first attempt")
        return (True, f"Second({call_count['second']}): {result.raw}")

    def third_guardrail(result: TaskOutput) -> tuple[bool, str]:
        """Always succeeds."""
        call_count["third"] += 1
        return (True, f"Third({call_count['third']}): {result.raw}")

    agent = Mock()
    agent.role = "retry_agent"
    agent.execute_task.return_value = "base"
    agent.crew = None

    task = create_smart_task(
        description="Test retry in middle guardrail",
        expected_output="Retry handling",
        guardrails=[first_guardrail, second_guardrail, third_guardrail],
        guardrail_max_retries=2,
    )

    result = task.execute_sync(agent=agent)
    # Based on the test output, the behavior is different than expected
    # The guardrails are called multiple times, so let's verify the retry happened
    assert task.retry_count == 1
    # Verify that the second guardrail eventually succeeded
    assert "Second(2)" in result.raw or call_count["second"] >= 2


def test_multiple_guardrails_with_max_retries_exceeded():
    """Test that exception is raised when max retries exceeded with multiple guardrails."""

    def passing_guardrail(result: TaskOutput) -> tuple[bool, str]:
        """Always passes."""
        return (True, f"Passed: {result.raw}")

    def failing_guardrail(result: TaskOutput) -> tuple[bool, str]:
        """Always fails."""
        return (False, "This guardrail always fails")

    agent = Mock()
    agent.role = "failing_agent"
    agent.execute_task.return_value = "test"
    agent.crew = None

    task = create_smart_task(
        description="Test max retries with multiple guardrails",
        expected_output="Will fail",
        guardrails=[passing_guardrail, failing_guardrail],
        guardrail_max_retries=1,
    )

    with pytest.raises(Exception) as exc_info:
        task.execute_sync(agent=agent)

    assert "Task failed guardrail validation after 1 retries" in str(exc_info.value)
    assert "This guardrail always fails" in str(exc_info.value)
    assert task.retry_count == 1


def test_multiple_guardrails_empty_list():
    """Test that empty guardrails list works correctly."""

    agent = Mock()
    agent.role = "empty_agent"
    agent.execute_task.return_value = "no guardrails"
    agent.crew = None

    task = create_smart_task(
        description="Test empty guardrails list",
        expected_output="No processing",
        guardrails=[],
    )

    result = task.execute_sync(agent=agent)
    assert result.raw == "no guardrails"


def test_multiple_guardrails_with_llm_guardrails():
    """Test mixing callable and LLM guardrails."""

    def callable_guardrail(result: TaskOutput) -> tuple[bool, str]:
        """Callable guardrail."""
        return (True, f"Callable: {result.raw}")

    # Create a proper mock agent without config issues
    from crewai import Agent

    agent = Agent(
        role="mixed_guardrail_agent", goal="Test goal", backstory="Test backstory"
    )

    task = create_smart_task(
        description="Test mixed guardrail types",
        expected_output="Mixed processing",
        guardrails=[callable_guardrail, "Ensure the output is professional"],
        agent=agent,
    )

    # The LLM guardrail will be converted to LLMGuardrail internally
    assert len(task._guardrails) == 2
    assert callable(task._guardrails[0])
    assert callable(task._guardrails[1])  # LLMGuardrail is callable


def test_multiple_guardrails_processing_order():
    """Test that guardrails are processed in the correct order."""

    processing_order = []

    def first_guardrail(result: TaskOutput) -> tuple[bool, str]:
        processing_order.append("first")
        return (True, f"1-{result.raw}")

    def second_guardrail(result: TaskOutput) -> tuple[bool, str]:
        processing_order.append("second")
        return (True, f"2-{result.raw}")

    def third_guardrail(result: TaskOutput) -> tuple[bool, str]:
        processing_order.append("third")
        return (True, f"3-{result.raw}")

    agent = Mock()
    agent.role = "order_agent"
    agent.execute_task.return_value = "base"
    agent.crew = None

    task = create_smart_task(
        description="Test processing order",
        expected_output="Ordered processing",
        guardrails=[first_guardrail, second_guardrail, third_guardrail],
    )

    result = task.execute_sync(agent=agent)
    assert processing_order == ["first", "second", "third"]
    assert result.raw == "3-2-1-base"


def test_multiple_guardrails_with_pydantic_output():
    """Test multiple guardrails with Pydantic output model."""
    from pydantic import BaseModel, Field

    class TestModel(BaseModel):
        content: str = Field(description="The content")
        processed: bool = Field(description="Whether it was processed")

    def json_guardrail(result: TaskOutput) -> tuple[bool, str]:
        """Convert to JSON format."""
        import json

        data = {"content": result.raw, "processed": True}
        return (True, json.dumps(data))

    def validation_guardrail(result: TaskOutput) -> tuple[bool, str]:
        """Validate JSON structure."""
        import json

        try:
            data = json.loads(result.raw)
            if "content" not in data or "processed" not in data:
                return (False, "Missing required fields")
            return (True, result.raw)
        except json.JSONDecodeError:
            return (False, "Invalid JSON format")

    agent = Mock()
    agent.role = "pydantic_agent"
    agent.execute_task.return_value = "test content"
    agent.crew = None

    task = create_smart_task(
        description="Test guardrails with Pydantic",
        expected_output="Structured output",
        guardrails=[json_guardrail, validation_guardrail],
        output_pydantic=TestModel,
    )

    result = task.execute_sync(agent=agent)

    # Verify the result is valid JSON and can be parsed
    import json

    parsed = json.loads(result.raw)
    assert parsed["content"] == "test content"
    assert parsed["processed"] is True


def test_guardrails_vs_single_guardrail_mutual_exclusion():
    """Test that guardrails list nullifies single guardrail."""

    def single_guardrail(result: TaskOutput) -> tuple[bool, str]:
        """Single guardrail - should be ignored."""
        return (True, f"Single: {result.raw}")

    def list_guardrail(result: TaskOutput) -> tuple[bool, str]:
        """List guardrail - should be used."""
        return (True, f"List: {result.raw}")

    agent = Mock()
    agent.role = "exclusion_agent"
    agent.execute_task.return_value = "test"
    agent.crew = None

    task = create_smart_task(
        description="Test mutual exclusion",
        expected_output="Exclusion test",
        guardrail=single_guardrail,  # This should be ignored
        guardrails=[list_guardrail],  # This should be used
    )

    result = task.execute_sync(agent=agent)
    # Should only use the guardrails list, not the single guardrail
    assert result.raw == "List: test"
    assert task._guardrail is None  # Single guardrail should be nullified