From 0e6689c19c6bf98460d05604b3aecb4d7ab8a716 Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Thu, 13 Feb 2025 13:15:10 +0000 Subject: [PATCH] test: Add comprehensive test for complex nested objects - Add test for various thread-safe primitives - Test nested dataclasses with complex state - Verify serialization of async primitives Co-Authored-By: Joe Moura --- tests/flow_test.py | 141 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 141 insertions(+) diff --git a/tests/flow_test.py b/tests/flow_test.py index b1ef6ba4b..dfae3b5b0 100644 --- a/tests/flow_test.py +++ b/tests/flow_test.py @@ -518,6 +518,147 @@ async def test_flow_with_async_locks(): assert flow.counter == 2 +def test_flow_with_complex_nested_objects(): + """Test that Flow properly handles complex nested objects.""" + import threading + import asyncio + from dataclasses import dataclass + from typing import Dict, List, Optional, Set, Tuple + + @dataclass + class ThreadSafePrimitives: + thread_lock: threading.Lock + rlock: threading.RLock + semaphore: threading.Semaphore + event: threading.Event + async_lock: asyncio.Lock + async_event: asyncio.Event + + def __post_init__(self): + self.thread_lock = self.thread_lock or threading.Lock() + self.rlock = self.rlock or threading.RLock() + self.semaphore = self.semaphore or threading.Semaphore() + self.event = self.event or threading.Event() + self.async_lock = self.async_lock or asyncio.Lock() + self.async_event = self.async_event or asyncio.Event() + + def __pydantic_validate__(self): + return { + "thread_lock": None, + "rlock": None, + "semaphore": None, + "event": None, + "async_lock": None, + "async_event": None + } + + @classmethod + def __get_pydantic_core_schema__(cls, source_type, handler): + from pydantic_core.core_schema import with_info_plain_validator_function + def validate(value, _): + if isinstance(value, cls): + return value + if isinstance(value, dict): + return cls( + thread_lock=None, + rlock=None, + semaphore=None, + event=None, + async_lock=None, + async_event=None + ) + raise ValueError(f"Invalid value type for {cls.__name__}") + return with_info_plain_validator_function(validate) + + @dataclass + class NestedContainer: + name: str + primitives: ThreadSafePrimitives + items: List[ThreadSafePrimitives] + mapping: Dict[str, ThreadSafePrimitives] + optional: Optional[ThreadSafePrimitives] + + def __post_init__(self): + self.primitives = self.primitives or ThreadSafePrimitives(None, None, None, None, None, None) + self.items = self.items or [] + self.mapping = self.mapping or {} + + def __pydantic_validate__(self): + return { + "name": self.name, + "primitives": self.primitives.__pydantic_validate__(), + "items": [item.__pydantic_validate__() for item in self.items], + "mapping": {k: v.__pydantic_validate__() for k, v in self.mapping.items()}, + "optional": self.optional.__pydantic_validate__() if self.optional else None + } + + @classmethod + def __get_pydantic_core_schema__(cls, source_type, handler): + from pydantic_core.core_schema import with_info_plain_validator_function + def validate(value, _): + if isinstance(value, cls): + return value + if isinstance(value, dict): + return cls( + name=value["name"], + primitives=ThreadSafePrimitives(None, None, None, None, None, None), + items=[], + mapping={}, + optional=None + ) + raise ValueError(f"Invalid value type for {cls.__name__}") + return with_info_plain_validator_function(validate) + + class ComplexState(BaseModel): + id: str = Field(default_factory=lambda: str(uuid4())) + name: str + nested: NestedContainer + items: List[NestedContainer] + mapping: Dict[str, NestedContainer] + optional: Optional[NestedContainer] = None + + class ComplexStateFlow(Flow[ComplexState]): + def __init__(self): + primitives = ThreadSafePrimitives( + thread_lock=threading.Lock(), + rlock=threading.RLock(), + semaphore=threading.Semaphore(), + event=threading.Event(), + async_lock=asyncio.Lock(), + async_event=asyncio.Event() + ) + container = NestedContainer( + name="test", + primitives=primitives, + items=[primitives], + mapping={"key": primitives}, + optional=primitives + ) + self.initial_state = ComplexState( + name="test", + nested=container, + items=[container], + mapping={"key": container}, + optional=container + ) + super().__init__() + + @start() + async def step_1(self): + with self.state.nested.primitives.rlock: + return "step 1" + + @listen(step_1) + async def step_2(self, result): + with self.state.items[0].primitives.rlock: + return result + " -> step 2" + + flow = ComplexStateFlow() + result = flow.kickoff() + + assert result == "step 1 -> step 2" + + def test_router_with_multiple_conditions(): """Test a router that triggers when any of multiple steps complete (OR condition), and another router that triggers only after all specified steps complete (AND condition).