refactor: Improve Flow state serialization with Pydantic core schema

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-02-13 13:14:31 +00:00
parent ed877467e1
commit cf7a26e009
2 changed files with 220 additions and 10 deletions

View File

@@ -4,7 +4,8 @@ import asyncio
from datetime import datetime
import pytest
from pydantic import BaseModel
from uuid import uuid4
from pydantic import BaseModel, Field
from crewai.flow.flow import Flow, and_, listen, or_, router, start
from crewai.flow.flow_events import (
@@ -377,6 +378,80 @@ def test_flow_with_thread_lock():
assert flow.counter == 2
def test_flow_with_nested_objects_and_locks():
"""Test that Flow properly handles nested objects containing locks."""
import threading
from dataclasses import dataclass
from typing import Dict, List, Optional
@dataclass
class NestedState:
value: str
lock: threading.RLock = None
def __post_init__(self):
if self.lock is None:
self.lock = threading.RLock()
def __pydantic_validate__(self):
return {"value": self.value, "lock": threading.RLock()}
@classmethod
def __get_pydantic_core_schema__(cls, source_type, handler):
from pydantic_core.core_schema import (
with_info_plain_validator_function,
str_schema,
)
def validate(value, _):
if isinstance(value, cls):
return value
if isinstance(value, dict):
return cls(value["value"])
raise ValueError(f"Invalid value type for {cls.__name__}")
return with_info_plain_validator_function(validate)
class ComplexState(BaseModel):
id: str = Field(default_factory=lambda: str(uuid4()))
name: str
nested: NestedState
items: List[NestedState]
mapping: Dict[str, NestedState]
optional: Optional[NestedState] = None
class ComplexStateFlow(Flow[ComplexState]):
def __init__(self):
self.initial_state = ComplexState(
name="test",
nested=NestedState("nested", threading.RLock()),
items=[
NestedState("item1", threading.RLock()),
NestedState("item2", threading.RLock())
],
mapping={
"key1": NestedState("map1", threading.RLock()),
"key2": NestedState("map2", threading.RLock())
},
optional=NestedState("optional", threading.RLock())
)
super().__init__()
@start()
async def step_1(self):
with self.state.nested.lock:
return "step 1"
@listen(step_1)
async def step_2(self, result):
with self.state.items[0].lock:
with self.state.mapping["key1"].lock:
with self.state.optional.lock:
return result + " -> step 2"
flow = ComplexStateFlow()
result = flow.kickoff()
assert result == "step 1 -> step 2"
def test_flow_with_nested_locks():
"""Test that Flow properly handles nested thread locks."""
import threading