test: Add comprehensive test for complex nested objects

- Add test for various thread-safe primitives
- Test nested dataclasses with complex state
- Verify serialization of async primitives

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-02-13 13:15:10 +00:00
parent 3348de8db7
commit 2a2c163c3d

View File

@@ -520,6 +520,147 @@ async def test_flow_with_async_locks():
assert flow.counter == 2
def test_flow_with_complex_nested_objects():
"""Test that Flow properly handles complex nested objects."""
import threading
import asyncio
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple
@dataclass
class ThreadSafePrimitives:
thread_lock: threading.Lock
rlock: threading.RLock
semaphore: threading.Semaphore
event: threading.Event
async_lock: asyncio.Lock
async_event: asyncio.Event
def __post_init__(self):
self.thread_lock = self.thread_lock or threading.Lock()
self.rlock = self.rlock or threading.RLock()
self.semaphore = self.semaphore or threading.Semaphore()
self.event = self.event or threading.Event()
self.async_lock = self.async_lock or asyncio.Lock()
self.async_event = self.async_event or asyncio.Event()
def __pydantic_validate__(self):
return {
"thread_lock": None,
"rlock": None,
"semaphore": None,
"event": None,
"async_lock": None,
"async_event": None
}
@classmethod
def __get_pydantic_core_schema__(cls, source_type, handler):
from pydantic_core.core_schema import with_info_plain_validator_function
def validate(value, _):
if isinstance(value, cls):
return value
if isinstance(value, dict):
return cls(
thread_lock=None,
rlock=None,
semaphore=None,
event=None,
async_lock=None,
async_event=None
)
raise ValueError(f"Invalid value type for {cls.__name__}")
return with_info_plain_validator_function(validate)
@dataclass
class NestedContainer:
name: str
primitives: ThreadSafePrimitives
items: List[ThreadSafePrimitives]
mapping: Dict[str, ThreadSafePrimitives]
optional: Optional[ThreadSafePrimitives]
def __post_init__(self):
self.primitives = self.primitives or ThreadSafePrimitives(None, None, None, None, None, None)
self.items = self.items or []
self.mapping = self.mapping or {}
def __pydantic_validate__(self):
return {
"name": self.name,
"primitives": self.primitives.__pydantic_validate__(),
"items": [item.__pydantic_validate__() for item in self.items],
"mapping": {k: v.__pydantic_validate__() for k, v in self.mapping.items()},
"optional": self.optional.__pydantic_validate__() if self.optional else None
}
@classmethod
def __get_pydantic_core_schema__(cls, source_type, handler):
from pydantic_core.core_schema import with_info_plain_validator_function
def validate(value, _):
if isinstance(value, cls):
return value
if isinstance(value, dict):
return cls(
name=value["name"],
primitives=ThreadSafePrimitives(None, None, None, None, None, None),
items=[],
mapping={},
optional=None
)
raise ValueError(f"Invalid value type for {cls.__name__}")
return with_info_plain_validator_function(validate)
class ComplexState(BaseModel):
id: str = Field(default_factory=lambda: str(uuid4()))
name: str
nested: NestedContainer
items: List[NestedContainer]
mapping: Dict[str, NestedContainer]
optional: Optional[NestedContainer] = None
class ComplexStateFlow(Flow[ComplexState]):
def __init__(self):
primitives = ThreadSafePrimitives(
thread_lock=threading.Lock(),
rlock=threading.RLock(),
semaphore=threading.Semaphore(),
event=threading.Event(),
async_lock=asyncio.Lock(),
async_event=asyncio.Event()
)
container = NestedContainer(
name="test",
primitives=primitives,
items=[primitives],
mapping={"key": primitives},
optional=primitives
)
self.initial_state = ComplexState(
name="test",
nested=container,
items=[container],
mapping={"key": container},
optional=container
)
super().__init__()
@start()
async def step_1(self):
with self.state.nested.primitives.rlock:
return "step 1"
@listen(step_1)
async def step_2(self, result):
with self.state.items[0].primitives.rlock:
return result + " -> step 2"
flow = ComplexStateFlow()
result = flow.kickoff()
assert result == "step 1 -> step 2"
def test_router_with_multiple_conditions():
"""Test a router that triggers when any of multiple steps complete (OR condition),
and another router that triggers only after all specified steps complete (AND condition).