mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-07 15:18:29 +00:00
test: Add comprehensive test for complex nested objects
- Add test for various thread-safe primitives - Test nested dataclasses with complex state - Verify serialization of async primitives Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -520,6 +520,147 @@ async def test_flow_with_async_locks():
|
||||
assert flow.counter == 2
|
||||
|
||||
|
||||
def test_flow_with_complex_nested_objects():
|
||||
"""Test that Flow properly handles complex nested objects."""
|
||||
import threading
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
@dataclass
|
||||
class ThreadSafePrimitives:
|
||||
thread_lock: threading.Lock
|
||||
rlock: threading.RLock
|
||||
semaphore: threading.Semaphore
|
||||
event: threading.Event
|
||||
async_lock: asyncio.Lock
|
||||
async_event: asyncio.Event
|
||||
|
||||
def __post_init__(self):
|
||||
self.thread_lock = self.thread_lock or threading.Lock()
|
||||
self.rlock = self.rlock or threading.RLock()
|
||||
self.semaphore = self.semaphore or threading.Semaphore()
|
||||
self.event = self.event or threading.Event()
|
||||
self.async_lock = self.async_lock or asyncio.Lock()
|
||||
self.async_event = self.async_event or asyncio.Event()
|
||||
|
||||
def __pydantic_validate__(self):
|
||||
return {
|
||||
"thread_lock": None,
|
||||
"rlock": None,
|
||||
"semaphore": None,
|
||||
"event": None,
|
||||
"async_lock": None,
|
||||
"async_event": None
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type, handler):
|
||||
from pydantic_core.core_schema import with_info_plain_validator_function
|
||||
def validate(value, _):
|
||||
if isinstance(value, cls):
|
||||
return value
|
||||
if isinstance(value, dict):
|
||||
return cls(
|
||||
thread_lock=None,
|
||||
rlock=None,
|
||||
semaphore=None,
|
||||
event=None,
|
||||
async_lock=None,
|
||||
async_event=None
|
||||
)
|
||||
raise ValueError(f"Invalid value type for {cls.__name__}")
|
||||
return with_info_plain_validator_function(validate)
|
||||
|
||||
@dataclass
|
||||
class NestedContainer:
|
||||
name: str
|
||||
primitives: ThreadSafePrimitives
|
||||
items: List[ThreadSafePrimitives]
|
||||
mapping: Dict[str, ThreadSafePrimitives]
|
||||
optional: Optional[ThreadSafePrimitives]
|
||||
|
||||
def __post_init__(self):
|
||||
self.primitives = self.primitives or ThreadSafePrimitives(None, None, None, None, None, None)
|
||||
self.items = self.items or []
|
||||
self.mapping = self.mapping or {}
|
||||
|
||||
def __pydantic_validate__(self):
|
||||
return {
|
||||
"name": self.name,
|
||||
"primitives": self.primitives.__pydantic_validate__(),
|
||||
"items": [item.__pydantic_validate__() for item in self.items],
|
||||
"mapping": {k: v.__pydantic_validate__() for k, v in self.mapping.items()},
|
||||
"optional": self.optional.__pydantic_validate__() if self.optional else None
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type, handler):
|
||||
from pydantic_core.core_schema import with_info_plain_validator_function
|
||||
def validate(value, _):
|
||||
if isinstance(value, cls):
|
||||
return value
|
||||
if isinstance(value, dict):
|
||||
return cls(
|
||||
name=value["name"],
|
||||
primitives=ThreadSafePrimitives(None, None, None, None, None, None),
|
||||
items=[],
|
||||
mapping={},
|
||||
optional=None
|
||||
)
|
||||
raise ValueError(f"Invalid value type for {cls.__name__}")
|
||||
return with_info_plain_validator_function(validate)
|
||||
|
||||
class ComplexState(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
name: str
|
||||
nested: NestedContainer
|
||||
items: List[NestedContainer]
|
||||
mapping: Dict[str, NestedContainer]
|
||||
optional: Optional[NestedContainer] = None
|
||||
|
||||
class ComplexStateFlow(Flow[ComplexState]):
|
||||
def __init__(self):
|
||||
primitives = ThreadSafePrimitives(
|
||||
thread_lock=threading.Lock(),
|
||||
rlock=threading.RLock(),
|
||||
semaphore=threading.Semaphore(),
|
||||
event=threading.Event(),
|
||||
async_lock=asyncio.Lock(),
|
||||
async_event=asyncio.Event()
|
||||
)
|
||||
container = NestedContainer(
|
||||
name="test",
|
||||
primitives=primitives,
|
||||
items=[primitives],
|
||||
mapping={"key": primitives},
|
||||
optional=primitives
|
||||
)
|
||||
self.initial_state = ComplexState(
|
||||
name="test",
|
||||
nested=container,
|
||||
items=[container],
|
||||
mapping={"key": container},
|
||||
optional=container
|
||||
)
|
||||
super().__init__()
|
||||
|
||||
@start()
|
||||
async def step_1(self):
|
||||
with self.state.nested.primitives.rlock:
|
||||
return "step 1"
|
||||
|
||||
@listen(step_1)
|
||||
async def step_2(self, result):
|
||||
with self.state.items[0].primitives.rlock:
|
||||
return result + " -> step 2"
|
||||
|
||||
flow = ComplexStateFlow()
|
||||
result = flow.kickoff()
|
||||
|
||||
assert result == "step 1 -> step 2"
|
||||
|
||||
|
||||
def test_router_with_multiple_conditions():
|
||||
"""Test a router that triggers when any of multiple steps complete (OR condition),
|
||||
and another router that triggers only after all specified steps complete (AND condition).
|
||||
|
||||
Reference in New Issue
Block a user