Compare commits

..

17 Commits

Author SHA1 Message Date
Devin AI
f0d0511b24 chore: Merge remote changes and remove telemetry dependency
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-21 03:26:18 +00:00
Devin AI
9fad174b74 fix: Remove telemetry dependency from Flow plot method
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-21 03:25:10 +00:00
Devin AI
dfba35e475 fix: Fix thread lock type checking in Flow state serialization
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-21 03:18:20 +00:00
Devin AI
2e01d1029b fix: Fix type error in Flow state serialization
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-21 03:17:58 +00:00
Devin AI
84f770aa5d style: Fix import sorting in tests
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-21 03:16:54 +00:00
Devin AI
2a2c163c3d test: Add comprehensive test for complex nested objects
- Add test for various thread-safe primitives
- Test nested dataclasses with complex state
- Verify serialization of async primitives

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-21 03:16:54 +00:00
Devin AI
3348de8db7 refactor: Improve Flow state serialization with Pydantic core schema
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-21 03:16:54 +00:00
Devin AI
93ec41225b refactor: Improve Flow state serialization
- Add BaseStateEvent class for common state processing
- Add state serialization caching for performance
- Add tests for nested locks and async context
- Improve error handling and validation
- Enhance documentation

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-21 03:16:35 +00:00
Devin AI
92e1877bf0 fix: Handle thread locks in Flow state serialization
- Add state serialization in Flow events to avoid pickling RLock objects
- Update event emission to use serialized state
- Add test case for Flow with thread locks

Fixes #2120

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-21 03:15:10 +00:00
João Moura
9d0a08206f Merge branch 'main' into devin/1739448321-fix-flow-state-pickling 2025-02-13 15:15:37 -03:00
Devin AI
b02e952c32 fix: Fix thread lock type checking in Flow state serialization
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-13 13:38:07 +00:00
Devin AI
ac703bafc8 fix: Fix type error in Flow state serialization
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-13 13:33:31 +00:00
Devin AI
fd70de34cf style: Fix import sorting in tests
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-13 13:20:55 +00:00
Devin AI
0e6689c19c test: Add comprehensive test for complex nested objects
- Add test for various thread-safe primitives
- Test nested dataclasses with complex state
- Verify serialization of async primitives

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-13 13:15:10 +00:00
Devin AI
cf7a26e009 refactor: Improve Flow state serialization with Pydantic core schema
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-13 13:14:31 +00:00
Devin AI
ed877467e1 refactor: Improve Flow state serialization
- Add BaseStateEvent class for common state processing
- Add state serialization caching for performance
- Add tests for nested locks and async context
- Improve error handling and validation
- Enhance documentation

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-13 12:14:59 +00:00
Devin AI
252095a668 fix: Handle thread locks in Flow state serialization
- Add state serialization in Flow events to avoid pickling RLock objects
- Update event emission to use serialized state
- Add test case for Flow with thread locks

Fixes #2120

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-13 12:07:26 +00:00
7 changed files with 705 additions and 306 deletions

View File

@@ -1,9 +1,12 @@
import asyncio
import copy
import dataclasses
import inspect
import logging
import threading
from typing import (
Any,
Awaitable,
Callable,
Dict,
Generic,
@@ -572,15 +575,173 @@ class Flow(Generic[T], metaclass=FlowMeta):
k: v for k, v in model.__dict__.items() if not k.startswith("_")
}
# Create new instance of the same class
# Create new instance of the same class, handling thread locks
model_class = type(model)
return cast(T, model_class(**state_dict))
serialized_dict = self._serialize_value(state_dict)
return cast(T, model_class(**serialized_dict))
raise TypeError(
f"Initial state must be dict or BaseModel, got {type(self.initial_state)}"
)
def _get_thread_safe_primitive_type(self, value: Any) -> Optional[Type[Union[threading.Lock, threading.RLock, threading.Semaphore, threading.Event, threading.Condition, asyncio.Lock, asyncio.Event, asyncio.Condition, asyncio.Semaphore]]]:
"""Get the type of a thread-safe primitive for recreation.
Args:
value: Any Python value to check
Returns:
The type of the thread-safe primitive, or None if not a primitive
"""
if hasattr(value, '_is_owned') and hasattr(value, 'acquire'):
# Get the actual types since some are factory functions
rlock_type = type(threading.RLock())
lock_type = type(threading.Lock())
semaphore_type = type(threading.Semaphore())
event_type = type(threading.Event())
condition_type = type(threading.Condition())
async_lock_type = type(asyncio.Lock())
async_event_type = type(asyncio.Event())
async_condition_type = type(asyncio.Condition())
async_semaphore_type = type(asyncio.Semaphore())
if isinstance(value, rlock_type):
return threading.RLock
elif isinstance(value, lock_type):
return threading.Lock
elif isinstance(value, semaphore_type):
return threading.Semaphore
elif isinstance(value, event_type):
return threading.Event
elif isinstance(value, condition_type):
return threading.Condition
elif isinstance(value, async_lock_type):
return asyncio.Lock
elif isinstance(value, async_event_type):
return asyncio.Event
elif isinstance(value, async_condition_type):
return asyncio.Condition
elif isinstance(value, async_semaphore_type):
return asyncio.Semaphore
return None
def _serialize_dataclass(self, value: Any) -> Union[Dict[str, Any], Any]:
    """Build a lock-free copy of a dataclass instance.

    Args:
        value: A dataclass instance. Dataclass *types* are returned unchanged.

    Returns:
        The result of ``value.__pydantic_validate__()`` when the object
        defines that hook; otherwise a new instance of the same dataclass
        whose thread-safe primitives are recreated and whose other fields
        are passed through ``_serialize_value``.
    """
    # dataclasses.is_dataclass() is also true for the class itself; there is
    # no per-instance state to copy in that case.
    if isinstance(value, type):
        return value

    # Objects may provide their own serialization hook; its result is
    # returned as-is.
    if hasattr(value, '__pydantic_validate__'):
        return value.__pydantic_validate__()

    init_values = {}
    post_init_values = {}
    for field in dataclasses.fields(value):
        field_value = getattr(value, field.name)
        primitive_type = self._get_thread_safe_primitive_type(field_value)
        if primitive_type is not None:
            # Locks and friends cannot be copied; give the copy a fresh one.
            new_value = primitive_type()
        else:
            new_value = self._serialize_value(field_value)
        # Fields declared with init=False are rejected by the generated
        # __init__, so they are restored after construction instead.
        if field.init:
            init_values[field.name] = new_value
        else:
            post_init_values[field.name] = new_value

    new_instance = value.__class__(**init_values)
    for name, new_value in post_init_values.items():
        # object.__setattr__ also works for frozen dataclasses.
        object.__setattr__(new_instance, name, new_value)
    return new_instance
def _copy_state(self) -> T:
return copy.deepcopy(self._state)
"""Create a deep copy of the current state.
Returns:
A deep copy of the current state object
"""
return self._serialize_value(self._state)
def _serialize_value(self, value: Any) -> Any:
"""Recursively serialize a value, handling nested objects and locks.
Args:
value: Any Python value to serialize
Returns:
Serialized version of the value with thread-safe primitives handled
"""
# Handle None
if value is None:
return None
# Handle thread-safe primitives
primitive_type = self._get_thread_safe_primitive_type(value)
if primitive_type is not None:
return None
# Handle Pydantic models
if isinstance(value, BaseModel):
return type(value)(**{
k: self._serialize_value(v)
for k, v in value.model_dump().items()
})
# Handle dataclasses
if dataclasses.is_dataclass(value):
return self._serialize_dataclass(value)
# Handle dictionaries
if isinstance(value, dict):
return {
k: self._serialize_value(v)
for k, v in value.items()
}
# Handle lists, tuples, and sets
if isinstance(value, (list, tuple, set)):
serialized = [self._serialize_value(item) for item in value]
return (
serialized if isinstance(value, list)
else tuple(serialized) if isinstance(value, tuple)
else set(serialized)
)
# Handle other types
return value
def _serialize_state(self) -> Union[Dict[str, Any], BaseModel]:
    """Serialize the current state for event emission.

    Produces a fresh snapshot of ``self._state`` on every call by passing it
    through ``_serialize_value``. No snapshot is cached: a cache keyed on
    ``str(state)`` can go stale when a change is not reflected in the
    string form, and it would hand the same mutable object to every event.

    Returns:
        Union[Dict[str, Any], BaseModel]: The serialized state. An empty dict
        is returned when the state is neither a dict nor a BaseModel, or when
        serialization fails; the error is logged and not propagated.
    """
    try:
        if not isinstance(self._state, (dict, BaseModel)):
            raise ValueError(f"Invalid state type: {type(self._state)}")
        return self._serialize_value(self._state)
    except Exception as e:
        # Deliberate best effort: event emission must not abort the flow.
        logger.error(f"State serialization failed: {str(e)}")
        return cast(Dict[str, Any], {})
@property
def state(self) -> T:
@@ -712,7 +873,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
else:
raise TypeError(f"State must be dict or BaseModel, got {type(self._state)}")
def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Union[Any, None]:
"""Start the flow execution.
Args:
@@ -816,12 +977,27 @@ class Flow(Generic[T], metaclass=FlowMeta):
@trace_flow_step
async def _execute_method(
self, method_name: str, method: Callable, *args: Any, **kwargs: Any
self, method_name: str, method: Union[Callable[..., Any], Callable[..., Awaitable[Any]]], *args: Any, **kwargs: Any
) -> Any:
"""Execute a flow method with proper event handling and state management.
Args:
method_name: Name of the method to execute
method: The method to execute
*args: Positional arguments for the method
**kwargs: Keyword arguments for the method
Returns:
The result of the method execution
Raises:
Any exception that occurs during method execution
"""
try:
dumped_params = {f"_{i}": arg for i, arg in enumerate(args)} | (
kwargs or {}
)
# Serialize state before event emission to avoid pickling issues
state_copy = self._serialize_state()
dumped_params = {f"_{i}": arg for i, arg in enumerate(args)} | (kwargs or {})
crewai_event_bus.emit(
self,
MethodExecutionStartedEvent(
@@ -829,7 +1005,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
method_name=method_name,
flow_name=self.__class__.__name__,
params=dumped_params,
state=self._copy_state(),
state=state_copy,
),
)
@@ -844,13 +1020,16 @@ class Flow(Generic[T], metaclass=FlowMeta):
self._method_execution_counts.get(method_name, 0) + 1
)
# Serialize state after execution
state_copy = self._serialize_state()
crewai_event_bus.emit(
self,
MethodExecutionFinishedEvent(
type="method_execution_finished",
method_name=method_name,
flow_name=self.__class__.__name__,
state=self._copy_state(),
state=state_copy,
result=result,
),
)
@@ -918,7 +1097,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
await asyncio.gather(*tasks)
def _find_triggered_methods(
self, trigger_method: str, router_only: bool
self, trigger_method: str, router_only: bool = False
) -> List[str]:
"""
Finds all methods that should be triggered based on conditions.
@@ -1028,7 +1207,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
traceback.print_exc()
def _log_flow_event(
self, message: str, color: str = "yellow", level: str = "info"
self, message: str, color: Optional[str] = "yellow", level: Optional[str] = "info"
) -> None:
"""Centralized logging method for flow events.
@@ -1053,7 +1232,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
elif level == "warning":
logger.warning(message)
def plot(self, filename: str = "crewai_flow") -> None:
def plot(self, filename: Optional[str] = "crewai_flow") -> None:
"""Plot the flow graph visualization.
Args:
filename: Optional name for the output file (default: "crewai_flow")
"""
crewai_event_bus.emit(
self,
FlowPlotEvent(

View File

@@ -0,0 +1,78 @@
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, Optional, Union
from pydantic import BaseModel
@dataclass
class Event:
"""Base record for flow events: an event type, the flow name and a timestamp."""
# Event type label supplied by the emitter.
type: str
# Name of the flow that produced the event.
flow_name: str
# Not accepted by __init__ (init=False); filled in by __post_init__.
timestamp: datetime = field(init=False)
def __post_init__(self):
# Stamp the event with the local (naive) time at construction.
self.timestamp = datetime.now()
@dataclass
class BaseStateEvent(Event):
"""Base class for events containing state data.
Handles common state serialization and validation logic to ensure thread-safe
state handling and proper type validation.
Raises:
ValueError: If state has invalid type
"""
# State snapshot carried by the event: a plain dict or a Pydantic model.
state: Union[Dict[str, Any], BaseModel]
def __post_init__(self):
# Set the timestamp first, then validate and normalize the state.
super().__post_init__()
self._process_state()
def _process_state(self):
"""Process and validate state data.
Ensures state is of valid type and creates a new instance of BaseModel
states to avoid thread lock serialization issues.
Raises:
ValueError: If state has invalid type
"""
if not isinstance(self.state, (dict, BaseModel)):
raise ValueError(f"Invalid state type: {type(self.state)}")
if isinstance(self.state, BaseModel):
# Rebuild the model from its dumped fields so the event holds its own
# instance; dict states are kept as the same object.
self.state = type(self.state)(**self.state.model_dump())
@dataclass
class FlowStartedEvent(Event):
"""Event carrying the optional inputs associated with a flow start."""
# Inputs attached to the event; None when not provided.
inputs: Optional[Dict[str, Any]] = None
@dataclass
class MethodExecutionStartedEvent(BaseStateEvent):
    """Event describing the start of a flow method execution.

    State validation and normalization are performed once by the inherited
    ``BaseStateEvent.__post_init__``. No override is defined here, because
    calling ``_process_state()`` again would rebuild a BaseModel state a
    second time for no benefit.

    Raises:
        ValueError: If ``state`` is neither a dict nor a BaseModel.
    """

    # Name of the flow method being executed.
    method_name: str
    # Re-declared for readability; the field keeps its inherited position.
    state: Union[Dict[str, Any], BaseModel]
    # Arguments associated with the method call, if any.
    params: Optional[Dict[str, Any]] = None
@dataclass
class MethodExecutionFinishedEvent(BaseStateEvent):
    """Event describing the completion of a flow method execution.

    State validation and normalization are performed once by the inherited
    ``BaseStateEvent.__post_init__``. No override is defined here, because
    calling ``_process_state()`` again would rebuild a BaseModel state a
    second time for no benefit.

    Raises:
        ValueError: If ``state`` is neither a dict nor a BaseModel.
    """

    # Name of the flow method that finished.
    method_name: str
    # Re-declared for readability; the field keeps its inherited position.
    state: Union[Dict[str, Any], BaseModel]
    # Value associated with the finished method, if any.
    result: Any = None
@dataclass
class FlowFinishedEvent(Event):
"""Event carrying the optional result associated with a flow finish."""
# Result attached to the event; None when not provided.
result: Optional[Any] = None

View File

@@ -3,20 +3,17 @@ import inspect
import json
import logging
import threading
import typing
import uuid
from concurrent.futures import Future
from copy import copy
from hashlib import md5
from pathlib import Path
from typing import (
AbstractSet,
Any,
Callable,
ClassVar,
Dict,
List,
Mapping,
Optional,
Set,
Tuple,
@@ -35,7 +32,6 @@ from pydantic import (
from pydantic_core import PydanticCustomError
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.tasks.exceptions import GuardrailValidationError
from crewai.tasks.guardrail_result import GuardrailResult
from crewai.tasks.output_format import OutputFormat
from crewai.tasks.task_output import TaskOutput
@@ -117,7 +113,7 @@ class Task(BaseModel):
description="Task output, it's final result after being executed", default=None
)
tools: Optional[List[BaseTool]] = Field(
default_factory=list[BaseTool],
default_factory=list,
description="Tools the agent is limited to use for this task.",
)
id: UUID4 = Field(
@@ -133,7 +129,7 @@ class Task(BaseModel):
description="A converter class used to export structured output",
default=None,
)
processed_by_agents: Set[str] = Field(default_factory=set[str])
processed_by_agents: Set[str] = Field(default_factory=set)
guardrail: Optional[Callable[[TaskOutput], Tuple[bool, Any]]] = Field(
default=None,
description="Function to validate task output before proceeding to next task",
@@ -155,8 +151,8 @@ class Task(BaseModel):
"""Validate that the guardrail function has the correct signature and behavior.
While type hints provide static checking, this validator ensures runtime safety by:
1. Verifying the function accepts exactly one required positional parameter (the TaskOutput)
2. Checking return type annotations match tuple[bool, Any] or specific types like tuple[bool, str]
1. Verifying the function accepts exactly one parameter (the TaskOutput)
2. Checking return type annotations match Tuple[bool, Any] if present
3. Providing clear, immediate error messages for debugging
This runtime validation is crucial because:
@@ -164,24 +160,6 @@ class Task(BaseModel):
- Function signatures need immediate validation before task execution
- Clear error messages help users debug guardrail implementation issues
Examples:
Simple validation with new style annotation:
>>> def validate_output(result: TaskOutput) -> tuple[bool, str]:
... return (True, result.raw.upper())
Validation with optional parameters:
>>> def validate_with_options(result: TaskOutput, strict: bool = True) -> tuple[bool, str]:
... if strict and not result.raw.isupper():
... return (False, "Text must be uppercase")
... return (True, result.raw)
Validation with specific return type:
>>> def validate_task_output(result: TaskOutput) -> tuple[bool, TaskOutput]:
... if not result.raw:
... return (False, result)
... result.raw = result.raw.strip()
... return (True, result)
Args:
v: The guardrail function to validate
@@ -190,57 +168,22 @@ class Task(BaseModel):
Raises:
ValueError: If the function signature is invalid or return annotation
doesn't match tuple[bool, Any] or specific allowed types
doesn't match Tuple[bool, Any]
"""
if v is not None:
sig = inspect.signature(v)
# Get required positional parameters (excluding those with defaults)
required_params = [
param for param in sig.parameters.values()
if param.default == inspect.Parameter.empty
and param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
]
keyword_only_params = [
param for param in sig.parameters.values()
if param.kind == inspect.Parameter.KEYWORD_ONLY
]
if len(required_params) != 1 or (len(keyword_only_params) > 0 and any(p.default == inspect.Parameter.empty for p in keyword_only_params)):
raise GuardrailValidationError(
"Guardrail function must accept exactly one required positional parameter and no required keyword-only parameters",
{"params": [str(p) for p in sig.parameters.values()]}
)
if len(sig.parameters) != 1:
raise ValueError("Guardrail function must accept exactly one parameter")
# Check return annotation if present, but don't require it
type_hints = typing.get_type_hints(v)
return_annotation = type_hints.get('return')
if return_annotation:
# Convert annotation to string for comparison
annotation_str = str(return_annotation).lower().replace(' ', '')
# Normalize type strings
normalized_annotation = (
annotation_str.replace('typing.', '')
.replace('dict[str,typing.any]', 'dict[str,any]')
.replace('dict[str, any]', 'dict[str,any]')
)
VALID_RETURN_TYPES = {
'tuple[bool,any]',
'tuple[bool,str]',
'tuple[bool,dict[str,any]]',
'tuple[bool,taskoutput]'
}
# Check if the normalized annotation matches any valid pattern
is_valid = normalized_annotation == 'tuple[bool,any]'
if not is_valid:
is_valid = normalized_annotation in VALID_RETURN_TYPES
if not is_valid:
raise GuardrailValidationError(
f"Invalid return type annotation. Expected one of: "
f"{', '.join(VALID_RETURN_TYPES)}",
{"got": annotation_str}
return_annotation = sig.return_annotation
if return_annotation != inspect.Signature.empty:
if not (
return_annotation == Tuple[bool, Any]
or str(return_annotation) == "Tuple[bool, Any]"
):
raise ValueError(
"If return type is annotated, it must be Tuple[bool, Any]"
)
return v
@@ -468,7 +411,6 @@ class Task(BaseModel):
"Task guardrail returned None as result. This is not allowed."
)
# Handle different result types
if isinstance(guardrail_result.result, str):
task_output.raw = guardrail_result.result
pydantic_output, json_output = self._export_output(
@@ -478,13 +420,6 @@ class Task(BaseModel):
task_output.json_dict = json_output
elif isinstance(guardrail_result.result, TaskOutput):
task_output = guardrail_result.result
elif isinstance(guardrail_result.result, dict):
task_output.raw = guardrail_result.result
task_output.json_dict = guardrail_result.result
pydantic_output, _ = self._export_output(
json.dumps(guardrail_result.result)
)
task_output.pydantic = pydantic_output
self.output = task_output
self.end_time = datetime.datetime.now()
@@ -675,74 +610,40 @@ class Task(BaseModel):
self.delegations += 1
def copy(
self,
agents: List["BaseAgent"] | None = None,
task_mapping: Dict[str, "Task"] | None = None,
*,
include: AbstractSet[int] | AbstractSet[str] | Mapping[int, Any] | Mapping[str, Any] | None = None,
exclude: AbstractSet[int] | AbstractSet[str] | Mapping[int, Any] | Mapping[str, Any] | None = None,
update: dict[str, Any] | None = None,
deep: bool = False,
self, agents: List["BaseAgent"], task_mapping: Dict[str, "Task"]
) -> "Task":
"""Create a deep copy of the Task.
Args:
agents: Optional list of agents to copy agent references
task_mapping: Optional mapping of task keys to tasks for context
include: Fields to include in the copy
exclude: Fields to exclude from the copy
update: Fields to update in the copy
deep: Whether to perform a deep copy
"""
if agents is None and task_mapping is None:
# New style copy using BaseModel
copied = super().copy(
include=include,
exclude=exclude,
update=update,
deep=deep,
)
# Copy mutable fields
if self.tools:
copied.tools = copy(self.tools)
if self.context:
copied.context = copy(self.context)
return copied
# Legacy copy behavior
exclude_fields = {
"""Create a deep copy of the Task."""
exclude = {
"id",
"agent",
"context",
"tools",
}
copied_data = self.model_dump(exclude=exclude_fields)
copied_data = self.model_dump(exclude=exclude)
copied_data = {k: v for k, v in copied_data.items() if v is not None}
cloned_context = (
[task_mapping[context_task.key] for context_task in self.context]
if self.context and task_mapping
if self.context
else None
)
def get_agent_by_role(role: str) -> Union["BaseAgent", None]:
if not agents:
return None
return next((agent for agent in agents if agent.role == role), None)
cloned_agent = get_agent_by_role(self.agent.role) if self.agent else None
cloned_tools = copy(self.tools) if self.tools else []
return Task(
copied_task = Task(
**copied_data,
context=cloned_context,
agent=cloned_agent,
tools=cloned_tools,
)
return copied_task
def _export_output(
self, result: str
) -> Tuple[Optional[BaseModel], Optional[Dict[str, Any]]]:

View File

@@ -1,25 +0,0 @@
"""
Module for task-related exceptions.
This module provides custom exceptions used throughout the task system
to provide more specific error handling and context.
"""
from typing import Any, Dict, Optional
class GuardrailValidationError(Exception):
    """Exception raised for guardrail validation errors.

    Carries a human-readable message plus an optional dictionary of extra
    context describing the failed validation.

    Attributes:
        message: A clear description of the validation error.
        context: Dictionary with additional error context; empty when none
            was supplied.
    """

    def __init__(self, message: str, context: Optional[Dict[str, Any]] = None):
        """Store the message and a private copy of the context.

        Args:
            message: Description of the validation error.
            context: Optional extra details about the failure.
        """
        self.message = message
        # Shallow-copy so later mutations of the caller's dict do not change
        # the context recorded on this error.
        self.context = dict(context) if context else {}
        super().__init__(self.message)

View File

@@ -1,5 +1,5 @@
import json
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, model_validator
@@ -15,7 +15,7 @@ class TaskOutput(BaseModel):
description="Expected output of the task", default=None
)
summary: Optional[str] = Field(description="Summary of the task", default=None)
raw: Any = Field(description="Raw output of the task", default="")
raw: str = Field(description="Raw output of the task", default="")
pydantic: Optional[BaseModel] = Field(
description="Pydantic output of task", default=None
)

View File

@@ -2,9 +2,11 @@
import asyncio
from datetime import datetime
from typing import Dict, List, Optional, Set, Tuple
from uuid import uuid4
import pytest
from pydantic import BaseModel
from pydantic import BaseModel, Field
from crewai.flow.flow import Flow, and_, listen, or_, router, start
from crewai.utilities.events import (
@@ -350,6 +352,315 @@ def test_flow_uuid_structured():
assert flow.state.message == "final"
def test_flow_with_thread_lock():
"""Test that Flow properly handles thread locks in state."""
import threading
# Flow whose RLock and counter are attributes of the flow instance itself.
class LockFlow(Flow):
def __init__(self):
super().__init__()
self.lock = threading.RLock()
self.counter = 0
@start()
async def step_1(self):
# Increment the counter while holding the lock.
with self.lock:
self.counter += 1
return "step 1"
@listen(step_1)
async def step_2(self, result):
# Receives step_1's return value and appends to it.
with self.lock:
self.counter += 1
return result + " -> step 2"
flow = LockFlow()
result = flow.kickoff()
# Both steps must have run, in order, each incrementing the counter once.
assert result == "step 1 -> step 2"
assert flow.counter == 2
def test_flow_with_nested_objects_and_locks():
"""Test that Flow properly handles nested objects containing locks."""
import threading
from dataclasses import dataclass
from typing import Dict, List, Optional
# Dataclass holding a value and an RLock; a lock is created when none is given.
@dataclass
class NestedState:
value: str
lock: threading.RLock = None
def __post_init__(self):
if self.lock is None:
self.lock = threading.RLock()
# Serialization hook: returns a dict with the value and a brand-new lock.
def __pydantic_validate__(self):
return {"value": self.value, "lock": threading.RLock()}
# Lets Pydantic accept either a NestedState instance or a dict with "value".
@classmethod
def __get_pydantic_core_schema__(cls, source_type, handler):
from pydantic_core.core_schema import (
str_schema,
with_info_plain_validator_function,
)
def validate(value, _):
if isinstance(value, cls):
return value
if isinstance(value, dict):
# Rebuild from the dict; __post_init__ supplies a new lock.
return cls(value["value"])
raise ValueError(f"Invalid value type for {cls.__name__}")
return with_info_plain_validator_function(validate)
# State model nesting NestedState directly, in a list, in a dict and as optional.
class ComplexState(BaseModel):
id: str = Field(default_factory=lambda: str(uuid4()))
name: str
nested: NestedState
items: List[NestedState]
mapping: Dict[str, NestedState]
optional: Optional[NestedState] = None
class ComplexStateFlow(Flow[ComplexState]):
def __init__(self):
# initial_state is assigned before super().__init__() is called.
self.initial_state = ComplexState(
name="test",
nested=NestedState("nested", threading.RLock()),
items=[
NestedState("item1", threading.RLock()),
NestedState("item2", threading.RLock())
],
mapping={
"key1": NestedState("map1", threading.RLock()),
"key2": NestedState("map2", threading.RLock())
},
optional=NestedState("optional", threading.RLock())
)
super().__init__()
@start()
async def step_1(self):
# The nested lock must be usable as a context manager.
with self.state.nested.lock:
return "step 1"
@listen(step_1)
async def step_2(self, result):
# Locks reached through the list, the dict and the optional field.
with self.state.items[0].lock:
with self.state.mapping["key1"].lock:
with self.state.optional.lock:
return result + " -> step 2"
flow = ComplexStateFlow()
result = flow.kickoff()
assert result == "step 1 -> step 2"
def test_flow_with_nested_locks():
"""Test that Flow properly handles nested thread locks."""
import threading
# Flow holding two RLocks on the instance, acquired one inside the other.
class NestedLockFlow(Flow):
def __init__(self):
super().__init__()
self.outer_lock = threading.RLock()
self.inner_lock = threading.RLock()
self.counter = 0
@start()
async def step_1(self):
# Acquire outer then inner lock before touching the counter.
with self.outer_lock:
with self.inner_lock:
self.counter += 1
return "step 1"
@listen(step_1)
async def step_2(self, result):
# Same acquisition order as step_1.
with self.outer_lock:
with self.inner_lock:
self.counter += 1
return result + " -> step 2"
flow = NestedLockFlow()
result = flow.kickoff()
# Both steps ran in order and each incremented the counter once.
assert result == "step 1 -> step 2"
assert flow.counter == 2
@pytest.mark.asyncio
async def test_flow_with_async_locks():
"""Test that Flow properly handles locks in async context."""
import asyncio
import threading
# Flow holding both a threading.RLock and an asyncio.Lock on the instance.
class AsyncLockFlow(Flow):
def __init__(self):
super().__init__()
self.lock = threading.RLock()
self.async_lock = asyncio.Lock()
self.counter = 0
@start()
async def step_1(self):
# Hold the asyncio lock, then the thread lock, while incrementing.
async with self.async_lock:
with self.lock:
self.counter += 1
return "step 1"
@listen(step_1)
async def step_2(self, result):
async with self.async_lock:
with self.lock:
self.counter += 1
return result + " -> step 2"
flow = AsyncLockFlow()
# Uses the async entry point since the test itself is a coroutine.
result = await flow.kickoff_async()
assert result == "step 1 -> step 2"
assert flow.counter == 2
def test_flow_with_complex_nested_objects():
"""Test that Flow properly handles complex nested objects."""
import asyncio
import threading
from dataclasses import dataclass
# Dataclass bundling several threading and asyncio primitives.
@dataclass
class ThreadSafePrimitives:
thread_lock: threading.Lock
rlock: threading.RLock
semaphore: threading.Semaphore
event: threading.Event
async_lock: asyncio.Lock
async_event: asyncio.Event
def __post_init__(self):
# Any falsy field (e.g. None) is replaced with a fresh primitive.
self.thread_lock = self.thread_lock or threading.Lock()
self.rlock = self.rlock or threading.RLock()
self.semaphore = self.semaphore or threading.Semaphore()
self.event = self.event or threading.Event()
self.async_lock = self.async_lock or asyncio.Lock()
self.async_event = self.async_event or asyncio.Event()
# Serialization hook: every primitive is reported as None.
def __pydantic_validate__(self):
return {
"thread_lock": None,
"rlock": None,
"semaphore": None,
"event": None,
"async_lock": None,
"async_event": None
}
# Lets Pydantic accept an instance as-is or rebuild one from a dict.
@classmethod
def __get_pydantic_core_schema__(cls, source_type, handler):
from pydantic_core.core_schema import with_info_plain_validator_function
def validate(value, _):
if isinstance(value, cls):
return value
if isinstance(value, dict):
# All-None arguments make __post_init__ create fresh primitives.
return cls(
thread_lock=None,
rlock=None,
semaphore=None,
event=None,
async_lock=None,
async_event=None
)
raise ValueError(f"Invalid value type for {cls.__name__}")
return with_info_plain_validator_function(validate)
# Dataclass nesting ThreadSafePrimitives directly, in a list, in a dict and as optional.
@dataclass
class NestedContainer:
name: str
primitives: ThreadSafePrimitives
items: List[ThreadSafePrimitives]
mapping: Dict[str, ThreadSafePrimitives]
optional: Optional[ThreadSafePrimitives]
def __post_init__(self):
# Falsy fields fall back to a fresh primitives bundle / empty containers.
self.primitives = self.primitives or ThreadSafePrimitives(None, None, None, None, None, None)
self.items = self.items or []
self.mapping = self.mapping or {}
# Serialization hook: delegates to the nested objects' hooks.
def __pydantic_validate__(self):
return {
"name": self.name,
"primitives": self.primitives.__pydantic_validate__(),
"items": [item.__pydantic_validate__() for item in self.items],
"mapping": {k: v.__pydantic_validate__() for k, v in self.mapping.items()},
"optional": self.optional.__pydantic_validate__() if self.optional else None
}
@classmethod
def __get_pydantic_core_schema__(cls, source_type, handler):
from pydantic_core.core_schema import with_info_plain_validator_function
def validate(value, _):
if isinstance(value, cls):
return value
if isinstance(value, dict):
# Only the name is taken from the dict; the rest is rebuilt empty.
return cls(
name=value["name"],
primitives=ThreadSafePrimitives(None, None, None, None, None, None),
items=[],
mapping={},
optional=None
)
raise ValueError(f"Invalid value type for {cls.__name__}")
return with_info_plain_validator_function(validate)
# State model nesting NestedContainer directly, in a list, in a dict and as optional.
class ComplexState(BaseModel):
id: str = Field(default_factory=lambda: str(uuid4()))
name: str
nested: NestedContainer
items: List[NestedContainer]
mapping: Dict[str, NestedContainer]
optional: Optional[NestedContainer] = None
class ComplexStateFlow(Flow[ComplexState]):
def __init__(self):
# One primitives bundle and one container are shared across all slots.
primitives = ThreadSafePrimitives(
thread_lock=threading.Lock(),
rlock=threading.RLock(),
semaphore=threading.Semaphore(),
event=threading.Event(),
async_lock=asyncio.Lock(),
async_event=asyncio.Event()
)
container = NestedContainer(
name="test",
primitives=primitives,
items=[primitives],
mapping={"key": primitives},
optional=primitives
)
# initial_state is assigned before super().__init__() is called.
self.initial_state = ComplexState(
name="test",
nested=container,
items=[container],
mapping={"key": container},
optional=container
)
super().__init__()
@start()
async def step_1(self):
# The deeply nested RLock must be usable as a context manager.
with self.state.nested.primitives.rlock:
return "step 1"
@listen(step_1)
async def step_2(self, result):
# Same check through the list-held container.
with self.state.items[0].primitives.rlock:
return result + " -> step 2"
flow = ComplexStateFlow()
result = flow.kickoff()
assert result == "step 1 -> step 2"
def test_router_with_multiple_conditions():
"""Test a router that triggers when any of multiple steps complete (OR condition),
and another router that triggers only after all specified steps complete (AND condition).

View File

@@ -1,179 +1,129 @@
"""Tests for task guardrails functionality."""
from typing import Any, Dict
from unittest.mock import Mock
import pytest
from crewai.task import Task
from crewai.tasks.exceptions import GuardrailValidationError
from crewai.tasks.task_output import TaskOutput
class TestTaskGuardrails:
"""Test suite for task guardrail functionality."""
def test_task_without_guardrail():
"""Test that tasks work normally without guardrails (backward compatibility)."""
agent = Mock()
agent.role = "test_agent"
agent.execute_task.return_value = "test result"
agent.crew = None
@pytest.fixture
def mock_agent(self):
"""Fixture providing a mock agent for testing."""
agent = Mock()
agent.role = "test_agent"
agent.crew = None
return agent
task = Task(description="Test task", expected_output="Output")
def test_task_without_guardrail(self, mock_agent):
    """Check that a task built without a guardrail yields the agent's output as ``raw``."""
    expected_raw = "test result"
    mock_agent.execute_task.return_value = expected_raw

    plain_task = Task(description="Test task", expected_output="Output")
    output = plain_task.execute_sync(agent=mock_agent)

    assert isinstance(output, TaskOutput)
    assert output.raw == expected_raw
result = task.execute_sync(agent=agent)
assert isinstance(result, TaskOutput)
assert result.raw == "test result"
def test_task_with_successful_guardrail(self, mock_agent):
"""Test that successful guardrail validation passes transformed result."""
def guardrail(result: TaskOutput):
return (True, result.raw.upper())
def test_task_with_successful_guardrail():
"""Test that successful guardrail validation passes transformed result."""
mock_agent.execute_task.return_value = "test result"
task = Task(description="Test task", expected_output="Output", guardrail=guardrail)
def guardrail(result: TaskOutput):
return (True, result.raw.upper())
result = task.execute_sync(agent=mock_agent)
assert isinstance(result, TaskOutput)
assert result.raw == "TEST RESULT"
agent = Mock()
agent.role = "test_agent"
agent.execute_task.return_value = "test result"
agent.crew = None
task = Task(description="Test task", expected_output="Output", guardrail=guardrail)
result = task.execute_sync(agent=agent)
assert isinstance(result, TaskOutput)
assert result.raw == "TEST RESULT"
def test_task_with_failing_guardrail(self, mock_agent):
"""Test that failing guardrail triggers retry with error context."""
def guardrail(result: TaskOutput):
return (False, "Invalid format")
def test_task_with_failing_guardrail():
"""Test that failing guardrail triggers retry with error context."""
mock_agent.execute_task.side_effect = ["bad result", "good result"]
task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail,
max_retries=1,
)
def guardrail(result: TaskOutput):
return (False, "Invalid format")
# First execution fails guardrail, second succeeds
mock_agent.execute_task.side_effect = ["bad result", "good result"]
with pytest.raises(Exception) as exc_info:
task.execute_sync(agent=mock_agent)
agent = Mock()
agent.role = "test_agent"
agent.execute_task.side_effect = ["bad result", "good result"]
agent.crew = None
assert "Task failed guardrail validation" in str(exc_info.value)
assert task.retry_count == 1
task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail,
max_retries=1,
)
# First execution fails guardrail, second succeeds
agent.execute_task.side_effect = ["bad result", "good result"]
with pytest.raises(Exception) as exc_info:
task.execute_sync(agent=agent)
assert "Task failed guardrail validation" in str(exc_info.value)
assert task.retry_count == 1
def test_task_with_guardrail_retries(self, mock_agent):
"""Test that guardrail respects max_retries configuration."""
def guardrail(result: TaskOutput):
return (False, "Invalid format")
def test_task_with_guardrail_retries():
"""Test that guardrail respects max_retries configuration."""
mock_agent.execute_task.return_value = "bad result"
task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail,
max_retries=2,
)
def guardrail(result: TaskOutput):
return (False, "Invalid format")
with pytest.raises(Exception) as exc_info:
task.execute_sync(agent=mock_agent)
agent = Mock()
agent.role = "test_agent"
agent.execute_task.return_value = "bad result"
agent.crew = None
assert task.retry_count == 2
assert "Task failed guardrail validation after 2 retries" in str(exc_info.value)
assert "Invalid format" in str(exc_info.value)
task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail,
max_retries=2,
)
with pytest.raises(Exception) as exc_info:
task.execute_sync(agent=agent)
assert task.retry_count == 2
assert "Task failed guardrail validation after 2 retries" in str(exc_info.value)
assert "Invalid format" in str(exc_info.value)
def test_guardrail_error_in_context(self, mock_agent):
"""Test that guardrail error is passed in context for retry."""
def guardrail(result: TaskOutput):
return (False, "Expected JSON, got string")
def test_guardrail_error_in_context():
"""Test that guardrail error is passed in context for retry."""
task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail,
max_retries=1,
)
def guardrail(result: TaskOutput):
return (False, "Expected JSON, got string")
# Mock execute_task to succeed on second attempt
first_call = True
def execute_task(task, context, tools):
nonlocal first_call
if first_call:
first_call = False
return "invalid"
return '{"valid": "json"}'
agent = Mock()
agent.role = "test_agent"
agent.crew = None
mock_agent.execute_task.side_effect = execute_task
task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail,
max_retries=1,
)
with pytest.raises(Exception) as exc_info:
task.execute_sync(agent=mock_agent)
# Mock execute_task to succeed on second attempt
first_call = True
assert "Task failed guardrail validation" in str(exc_info.value)
assert "Expected JSON, got string" in str(exc_info.value)
def execute_task(task, context, tools):
nonlocal first_call
if first_call:
first_call = False
return "invalid"
return '{"valid": "json"}'
agent.execute_task.side_effect = execute_task
def test_guardrail_with_new_style_annotation(self, mock_agent):
"""Test guardrail with new style tuple annotation."""
def guardrail(result: TaskOutput) -> tuple[bool, str]:
return (True, result.raw.upper())
mock_agent.execute_task.return_value = "test result"
task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail
)
with pytest.raises(Exception) as exc_info:
task.execute_sync(agent=agent)
result = task.execute_sync(agent=mock_agent)
assert isinstance(result, TaskOutput)
assert result.raw == "TEST RESULT"
def test_guardrail_with_optional_params(self, mock_agent):
    """Check that a guardrail with an extra defaulted parameter is accepted and applied."""

    def guardrail(result: TaskOutput, optional_param: str = "default") -> tuple[bool, str]:
        # Join the raw output and the (defaulted) extra argument with a dash.
        suffixed = f"{result.raw}-{optional_param}"
        return (True, suffixed)

    mock_agent.execute_task.return_value = "test"
    guarded_task = Task(
        description="Test task",
        expected_output="Output",
        guardrail=guardrail,
    )

    output = guarded_task.execute_sync(agent=mock_agent)

    assert isinstance(output, TaskOutput)
    # The default value of ``optional_param`` is expected in the final output.
    assert output.raw == "test-default"
def test_guardrail_with_invalid_optional_params(self, mock_agent):
    """Check that constructing a Task with a guardrail requiring a keyword-only argument fails."""

    def guardrail(result: TaskOutput, *, required_kwonly: str) -> tuple[bool, str]:
        return (True, result.raw)

    task_kwargs = {
        "description": "Test task",
        "expected_output": "Output",
        "guardrail": guardrail,
    }

    # The error is expected at construction time; the task is never executed.
    with pytest.raises(GuardrailValidationError) as exc_info:
        Task(**task_kwargs)

    message = str(exc_info.value)
    assert "exactly one required positional parameter" in message
def test_guardrail_with_dict_return_type(self, mock_agent):
    """Check that a guardrail whose second tuple element is a dict sets ``raw`` to that dict."""

    def guardrail(result: TaskOutput) -> tuple[bool, dict[str, Any]]:
        payload = {"processed": result.raw.upper()}
        return (True, payload)

    mock_agent.execute_task.return_value = "test"
    guarded_task = Task(
        description="Test task",
        expected_output="Output",
        guardrail=guardrail,
    )

    output = guarded_task.execute_sync(agent=mock_agent)

    assert isinstance(output, TaskOutput)
    assert output.raw == {"processed": "TEST"}
assert "Task failed guardrail validation" in str(exc_info.value)
assert "Expected JSON, got string" in str(exc_info.value)