mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
Implement Flow state export method (#2134)
This commit implements a method for exporting the state of a flow into a JSON-serializable dictionary. The idea is producing a human-readable version of state that can be inspected or consumed by other systems, hence JSON and not pickling or marshalling. I consider it an export because it's a one-way process, meaning it cannot be loaded back into Python because of complex types.
This commit is contained in:
committed by
Brandon Hancock
parent
f8c74b4fbb
commit
ae19437473
52
src/crewai/flow/state_utils.py
Normal file
52
src/crewai/flow/state_utils.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from datetime import date, datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.flow import Flow
|
||||
|
||||
|
||||
def export_state(flow: Flow) -> dict[str, Any]:
|
||||
"""Exports the Flow's internal state as JSON-compatible data structures.
|
||||
|
||||
Performs a one-way transformation of a Flow's state into basic Python types
|
||||
that can be safely serialized to JSON. To prevent infinite recursion with
|
||||
circular references, the conversion is limited to a depth of 5 levels.
|
||||
|
||||
Args:
|
||||
flow: The Flow object whose state needs to be exported
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: The transformed state using JSON-compatible Python
|
||||
types.
|
||||
"""
|
||||
return _to_serializable(flow._state)
|
||||
|
||||
|
||||
def _to_serializable(obj: Any, max_depth: int = 5, _current_depth: int = 0) -> Any:
|
||||
if _current_depth >= max_depth:
|
||||
return repr(obj)
|
||||
|
||||
if isinstance(obj, (str, int, float, bool, type(None))):
|
||||
return obj
|
||||
elif isinstance(obj, (date, datetime)):
|
||||
return obj.isoformat()
|
||||
elif isinstance(obj, (list, tuple, set)):
|
||||
return [_to_serializable(item, max_depth, _current_depth + 1) for item in obj]
|
||||
elif isinstance(obj, dict):
|
||||
return {
|
||||
_to_serializable_key(key): _to_serializable(
|
||||
value, max_depth, _current_depth + 1
|
||||
)
|
||||
for key, value in obj.items()
|
||||
}
|
||||
elif isinstance(obj, BaseModel):
|
||||
return _to_serializable(obj.model_dump(), max_depth, _current_depth + 1)
|
||||
else:
|
||||
return repr(obj)
|
||||
|
||||
|
||||
def _to_serializable_key(key: Any) -> str:
|
||||
if isinstance(key, (str, int)):
|
||||
return str(key)
|
||||
return f"key_{id(key)}_{repr(key)}"
|
||||
156
tests/flow/test_state_utils.py
Normal file
156
tests/flow/test_state_utils.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from datetime import date, datetime
|
||||
from typing import List
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.flow import Flow
|
||||
from crewai.flow.state_utils import export_state
|
||||
|
||||
|
||||
class Address(BaseModel):
|
||||
street: str
|
||||
city: str
|
||||
country: str
|
||||
|
||||
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
address: Address
|
||||
birthday: date
|
||||
skills: List[str]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_flow():
|
||||
def create_flow(state):
|
||||
flow = Mock(spec=Flow)
|
||||
flow._state = state
|
||||
return flow
|
||||
|
||||
return create_flow
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_input,expected",
|
||||
[
|
||||
({"text": "hello world"}, {"text": "hello world"}),
|
||||
({"number": 42}, {"number": 42}),
|
||||
({"decimal": 3.14}, {"decimal": 3.14}),
|
||||
({"flag": True}, {"flag": True}),
|
||||
({"empty": None}, {"empty": None}),
|
||||
({"list": [1, 2, 3]}, {"list": [1, 2, 3]}),
|
||||
({"tuple": (1, 2, 3)}, {"tuple": [1, 2, 3]}),
|
||||
({"set": {1, 2, 3}}, {"set": [1, 2, 3]}),
|
||||
({"nested": [1, [2, 3], {4, 5}]}, {"nested": [1, [2, 3], [4, 5]]}),
|
||||
],
|
||||
)
|
||||
def test_basic_serialization(mock_flow, test_input, expected):
|
||||
flow = mock_flow(test_input)
|
||||
result = export_state(flow)
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_date,expected",
|
||||
[
|
||||
(date(2024, 1, 1), "2024-01-01"),
|
||||
(datetime(2024, 1, 1, 12, 30), "2024-01-01T12:30:00"),
|
||||
],
|
||||
)
|
||||
def test_temporal_serialization(mock_flow, input_date, expected):
|
||||
flow = mock_flow({"date": input_date})
|
||||
result = export_state(flow)
|
||||
assert result["date"] == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"key,value,expected_key_type",
|
||||
[
|
||||
(("tuple", "key"), "value", str),
|
||||
(None, "value", str),
|
||||
(123, "value", str),
|
||||
("normal", "value", str),
|
||||
],
|
||||
)
|
||||
def test_dictionary_key_serialization(mock_flow, key, value, expected_key_type):
|
||||
flow = mock_flow({key: value})
|
||||
result = export_state(flow)
|
||||
assert len(result) == 1
|
||||
result_key = next(iter(result.keys()))
|
||||
assert isinstance(result_key, expected_key_type)
|
||||
assert result[result_key] == value
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"callable_obj,expected_in_result",
|
||||
[
|
||||
(lambda x: x * 2, "lambda"),
|
||||
(str.upper, "upper"),
|
||||
],
|
||||
)
|
||||
def test_callable_serialization(mock_flow, callable_obj, expected_in_result):
|
||||
flow = mock_flow({"func": callable_obj})
|
||||
result = export_state(flow)
|
||||
assert isinstance(result["func"], str)
|
||||
assert expected_in_result in result["func"].lower()
|
||||
|
||||
|
||||
def test_pydantic_model_serialization(mock_flow):
|
||||
address = Address(street="123 Main St", city="Tech City", country="Pythonia")
|
||||
|
||||
person = Person(
|
||||
name="John Doe",
|
||||
age=30,
|
||||
address=address,
|
||||
birthday=date(1994, 1, 1),
|
||||
skills=["Python", "Testing"],
|
||||
)
|
||||
|
||||
flow = mock_flow(
|
||||
{
|
||||
"single_model": address,
|
||||
"nested_model": person,
|
||||
"model_list": [address, address],
|
||||
"model_dict": {"home": address},
|
||||
}
|
||||
)
|
||||
|
||||
result = export_state(flow)
|
||||
|
||||
assert result["single_model"]["street"] == "123 Main St"
|
||||
|
||||
assert result["nested_model"]["name"] == "John Doe"
|
||||
assert result["nested_model"]["address"]["city"] == "Tech City"
|
||||
assert result["nested_model"]["birthday"] == "1994-01-01"
|
||||
|
||||
assert len(result["model_list"]) == 2
|
||||
assert all(m["street"] == "123 Main St" for m in result["model_list"])
|
||||
assert result["model_dict"]["home"]["city"] == "Tech City"
|
||||
|
||||
|
||||
def test_depth_limit(mock_flow):
|
||||
"""Test max depth handling with a deeply nested structure"""
|
||||
|
||||
def create_nested(depth):
|
||||
if depth == 0:
|
||||
return "value"
|
||||
return {"next": create_nested(depth - 1)}
|
||||
|
||||
deep_structure = create_nested(10)
|
||||
flow = mock_flow(deep_structure)
|
||||
result = export_state(flow)
|
||||
|
||||
assert result == {
|
||||
"next": {
|
||||
"next": {
|
||||
"next": {
|
||||
"next": {
|
||||
"next": "{'next': {'next': {'next': {'next': {'next': 'value'}}}}}"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user