mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
test: Add comprehensive test for complex nested objects
- Add test for various thread-safe primitives - Test nested dataclasses with complex state - Verify serialization of async primitives Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -518,6 +518,147 @@ async def test_flow_with_async_locks():
|
|||||||
assert flow.counter == 2
|
assert flow.counter == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_with_complex_nested_objects():
|
||||||
|
"""Test that Flow properly handles complex nested objects."""
|
||||||
|
import threading
|
||||||
|
import asyncio
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, List, Optional, Set, Tuple
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ThreadSafePrimitives:
|
||||||
|
thread_lock: threading.Lock
|
||||||
|
rlock: threading.RLock
|
||||||
|
semaphore: threading.Semaphore
|
||||||
|
event: threading.Event
|
||||||
|
async_lock: asyncio.Lock
|
||||||
|
async_event: asyncio.Event
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
self.thread_lock = self.thread_lock or threading.Lock()
|
||||||
|
self.rlock = self.rlock or threading.RLock()
|
||||||
|
self.semaphore = self.semaphore or threading.Semaphore()
|
||||||
|
self.event = self.event or threading.Event()
|
||||||
|
self.async_lock = self.async_lock or asyncio.Lock()
|
||||||
|
self.async_event = self.async_event or asyncio.Event()
|
||||||
|
|
||||||
|
def __pydantic_validate__(self):
|
||||||
|
return {
|
||||||
|
"thread_lock": None,
|
||||||
|
"rlock": None,
|
||||||
|
"semaphore": None,
|
||||||
|
"event": None,
|
||||||
|
"async_lock": None,
|
||||||
|
"async_event": None
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_core_schema__(cls, source_type, handler):
|
||||||
|
from pydantic_core.core_schema import with_info_plain_validator_function
|
||||||
|
def validate(value, _):
|
||||||
|
if isinstance(value, cls):
|
||||||
|
return value
|
||||||
|
if isinstance(value, dict):
|
||||||
|
return cls(
|
||||||
|
thread_lock=None,
|
||||||
|
rlock=None,
|
||||||
|
semaphore=None,
|
||||||
|
event=None,
|
||||||
|
async_lock=None,
|
||||||
|
async_event=None
|
||||||
|
)
|
||||||
|
raise ValueError(f"Invalid value type for {cls.__name__}")
|
||||||
|
return with_info_plain_validator_function(validate)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class NestedContainer:
|
||||||
|
name: str
|
||||||
|
primitives: ThreadSafePrimitives
|
||||||
|
items: List[ThreadSafePrimitives]
|
||||||
|
mapping: Dict[str, ThreadSafePrimitives]
|
||||||
|
optional: Optional[ThreadSafePrimitives]
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
self.primitives = self.primitives or ThreadSafePrimitives(None, None, None, None, None, None)
|
||||||
|
self.items = self.items or []
|
||||||
|
self.mapping = self.mapping or {}
|
||||||
|
|
||||||
|
def __pydantic_validate__(self):
|
||||||
|
return {
|
||||||
|
"name": self.name,
|
||||||
|
"primitives": self.primitives.__pydantic_validate__(),
|
||||||
|
"items": [item.__pydantic_validate__() for item in self.items],
|
||||||
|
"mapping": {k: v.__pydantic_validate__() for k, v in self.mapping.items()},
|
||||||
|
"optional": self.optional.__pydantic_validate__() if self.optional else None
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_core_schema__(cls, source_type, handler):
|
||||||
|
from pydantic_core.core_schema import with_info_plain_validator_function
|
||||||
|
def validate(value, _):
|
||||||
|
if isinstance(value, cls):
|
||||||
|
return value
|
||||||
|
if isinstance(value, dict):
|
||||||
|
return cls(
|
||||||
|
name=value["name"],
|
||||||
|
primitives=ThreadSafePrimitives(None, None, None, None, None, None),
|
||||||
|
items=[],
|
||||||
|
mapping={},
|
||||||
|
optional=None
|
||||||
|
)
|
||||||
|
raise ValueError(f"Invalid value type for {cls.__name__}")
|
||||||
|
return with_info_plain_validator_function(validate)
|
||||||
|
|
||||||
|
class ComplexState(BaseModel):
|
||||||
|
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||||
|
name: str
|
||||||
|
nested: NestedContainer
|
||||||
|
items: List[NestedContainer]
|
||||||
|
mapping: Dict[str, NestedContainer]
|
||||||
|
optional: Optional[NestedContainer] = None
|
||||||
|
|
||||||
|
class ComplexStateFlow(Flow[ComplexState]):
|
||||||
|
def __init__(self):
|
||||||
|
primitives = ThreadSafePrimitives(
|
||||||
|
thread_lock=threading.Lock(),
|
||||||
|
rlock=threading.RLock(),
|
||||||
|
semaphore=threading.Semaphore(),
|
||||||
|
event=threading.Event(),
|
||||||
|
async_lock=asyncio.Lock(),
|
||||||
|
async_event=asyncio.Event()
|
||||||
|
)
|
||||||
|
container = NestedContainer(
|
||||||
|
name="test",
|
||||||
|
primitives=primitives,
|
||||||
|
items=[primitives],
|
||||||
|
mapping={"key": primitives},
|
||||||
|
optional=primitives
|
||||||
|
)
|
||||||
|
self.initial_state = ComplexState(
|
||||||
|
name="test",
|
||||||
|
nested=container,
|
||||||
|
items=[container],
|
||||||
|
mapping={"key": container},
|
||||||
|
optional=container
|
||||||
|
)
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@start()
|
||||||
|
async def step_1(self):
|
||||||
|
with self.state.nested.primitives.rlock:
|
||||||
|
return "step 1"
|
||||||
|
|
||||||
|
@listen(step_1)
|
||||||
|
async def step_2(self, result):
|
||||||
|
with self.state.items[0].primitives.rlock:
|
||||||
|
return result + " -> step 2"
|
||||||
|
|
||||||
|
flow = ComplexStateFlow()
|
||||||
|
result = flow.kickoff()
|
||||||
|
|
||||||
|
assert result == "step 1 -> step 2"
|
||||||
|
|
||||||
|
|
||||||
def test_router_with_multiple_conditions():
|
def test_router_with_multiple_conditions():
|
||||||
"""Test a router that triggers when any of multiple steps complete (OR condition),
|
"""Test a router that triggers when any of multiple steps complete (OR condition),
|
||||||
and another router that triggers only after all specified steps complete (AND condition).
|
and another router that triggers only after all specified steps complete (AND condition).
|
||||||
|
|||||||
Reference in New Issue
Block a user