mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
refactor: Improve Flow state serialization with Pydantic core schema
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -4,7 +4,8 @@ import asyncio
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from uuid import uuid4
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.flow.flow import Flow, and_, listen, or_, router, start
|
||||
from crewai.flow.flow_events import (
|
||||
@@ -377,6 +378,80 @@ def test_flow_with_thread_lock():
|
||||
assert flow.counter == 2
|
||||
|
||||
|
||||
def test_flow_with_nested_objects_and_locks():
|
||||
"""Test that Flow properly handles nested objects containing locks."""
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
@dataclass
|
||||
class NestedState:
|
||||
value: str
|
||||
lock: threading.RLock = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.lock is None:
|
||||
self.lock = threading.RLock()
|
||||
|
||||
def __pydantic_validate__(self):
|
||||
return {"value": self.value, "lock": threading.RLock()}
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type, handler):
|
||||
from pydantic_core.core_schema import (
|
||||
with_info_plain_validator_function,
|
||||
str_schema,
|
||||
)
|
||||
def validate(value, _):
|
||||
if isinstance(value, cls):
|
||||
return value
|
||||
if isinstance(value, dict):
|
||||
return cls(value["value"])
|
||||
raise ValueError(f"Invalid value type for {cls.__name__}")
|
||||
return with_info_plain_validator_function(validate)
|
||||
|
||||
class ComplexState(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
name: str
|
||||
nested: NestedState
|
||||
items: List[NestedState]
|
||||
mapping: Dict[str, NestedState]
|
||||
optional: Optional[NestedState] = None
|
||||
|
||||
class ComplexStateFlow(Flow[ComplexState]):
|
||||
def __init__(self):
|
||||
self.initial_state = ComplexState(
|
||||
name="test",
|
||||
nested=NestedState("nested", threading.RLock()),
|
||||
items=[
|
||||
NestedState("item1", threading.RLock()),
|
||||
NestedState("item2", threading.RLock())
|
||||
],
|
||||
mapping={
|
||||
"key1": NestedState("map1", threading.RLock()),
|
||||
"key2": NestedState("map2", threading.RLock())
|
||||
},
|
||||
optional=NestedState("optional", threading.RLock())
|
||||
)
|
||||
super().__init__()
|
||||
|
||||
@start()
|
||||
async def step_1(self):
|
||||
with self.state.nested.lock:
|
||||
return "step 1"
|
||||
|
||||
@listen(step_1)
|
||||
async def step_2(self, result):
|
||||
with self.state.items[0].lock:
|
||||
with self.state.mapping["key1"].lock:
|
||||
with self.state.optional.lock:
|
||||
return result + " -> step 2"
|
||||
|
||||
flow = ComplexStateFlow()
|
||||
result = flow.kickoff()
|
||||
|
||||
assert result == "step 1 -> step 2"
|
||||
|
||||
def test_flow_with_nested_locks():
|
||||
"""Test that Flow properly handles nested thread locks."""
|
||||
import threading
|
||||
|
||||
Reference in New Issue
Block a user