Refactor event base classes (#2491)

- Renamed `CrewEvent` to `BaseEvent` across the codebase for consistency
- Created a `CrewBaseEvent` that automatically identifies fingerprints for DRY
- Added a new `to_json()` method for serializing events
This commit is contained in:
Vini Brasil
2025-03-27 15:42:11 -03:00
committed by GitHub
parent fc9da22c38
commit f845fac4da
12 changed files with 155 additions and 219 deletions

View File

@@ -13,7 +13,7 @@ CrewAI provides a powerful event system that allows you to listen for and react
CrewAI uses an event bus architecture to emit events throughout the execution lifecycle. The event system is built on the following components:
1. **CrewAIEventsBus**: A singleton event bus that manages event registration and emission
2. **CrewEvent**: Base class for all events in the system
2. **BaseEvent**: Base class for all events in the system
3. **BaseEventListener**: Abstract base class for creating custom event listeners
When specific actions occur in CrewAI (like a Crew starting execution, an Agent completing a task, or a tool being used), the system emits corresponding events. You can register handlers for these events to execute custom code when they occur.
@@ -234,7 +234,7 @@ Each event handler receives two parameters:
1. **source**: The object that emitted the event
2. **event**: The event instance, containing event-specific data
The structure of the event object depends on the event type, but all events inherit from `CrewEvent` and include:
The structure of the event object depends on the event type, but all events inherit from `BaseEvent` and include:
- **timestamp**: The time when the event was emitted
- **type**: A string identifier for the event type

View File

@@ -4,13 +4,13 @@ from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.tools.base_tool import BaseTool
from crewai.tools.structured_tool import CrewStructuredTool
from .base_events import CrewEvent
from .base_events import BaseEvent
if TYPE_CHECKING:
from crewai.agents.agent_builder.base_agent import BaseAgent
class AgentExecutionStartedEvent(CrewEvent):
class AgentExecutionStartedEvent(BaseEvent):
"""Event emitted when an agent starts executing a task"""
agent: BaseAgent
@@ -24,14 +24,17 @@ class AgentExecutionStartedEvent(CrewEvent):
def __init__(self, **data):
super().__init__(**data)
# Set fingerprint data from the agent
if hasattr(self.agent, 'fingerprint') and self.agent.fingerprint:
if hasattr(self.agent, "fingerprint") and self.agent.fingerprint:
self.source_fingerprint = self.agent.fingerprint.uuid_str
self.source_type = "agent"
if hasattr(self.agent.fingerprint, 'metadata') and self.agent.fingerprint.metadata:
if (
hasattr(self.agent.fingerprint, "metadata")
and self.agent.fingerprint.metadata
):
self.fingerprint_metadata = self.agent.fingerprint.metadata
class AgentExecutionCompletedEvent(CrewEvent):
class AgentExecutionCompletedEvent(BaseEvent):
"""Event emitted when an agent completes executing a task"""
agent: BaseAgent
@@ -42,14 +45,17 @@ class AgentExecutionCompletedEvent(CrewEvent):
def __init__(self, **data):
super().__init__(**data)
# Set fingerprint data from the agent
if hasattr(self.agent, 'fingerprint') and self.agent.fingerprint:
if hasattr(self.agent, "fingerprint") and self.agent.fingerprint:
self.source_fingerprint = self.agent.fingerprint.uuid_str
self.source_type = "agent"
if hasattr(self.agent.fingerprint, 'metadata') and self.agent.fingerprint.metadata:
if (
hasattr(self.agent.fingerprint, "metadata")
and self.agent.fingerprint.metadata
):
self.fingerprint_metadata = self.agent.fingerprint.metadata
class AgentExecutionErrorEvent(CrewEvent):
class AgentExecutionErrorEvent(BaseEvent):
"""Event emitted when an agent encounters an error during execution"""
agent: BaseAgent
@@ -60,8 +66,11 @@ class AgentExecutionErrorEvent(CrewEvent):
def __init__(self, **data):
super().__init__(**data)
# Set fingerprint data from the agent
if hasattr(self.agent, 'fingerprint') and self.agent.fingerprint:
if hasattr(self.agent, "fingerprint") and self.agent.fingerprint:
self.source_fingerprint = self.agent.fingerprint.uuid_str
self.source_type = "agent"
if hasattr(self.agent.fingerprint, 'metadata') and self.agent.fingerprint.metadata:
if (
hasattr(self.agent.fingerprint, "metadata")
and self.agent.fingerprint.metadata
):
self.fingerprint_metadata = self.agent.fingerprint.metadata

View File

@@ -3,12 +3,26 @@ from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
from crewai.utilities.serialization import to_serializable
class CrewEvent(BaseModel):
"""Base class for all crew events"""
class BaseEvent(BaseModel):
"""Base class for all events"""
timestamp: datetime = Field(default_factory=datetime.now)
type: str
source_fingerprint: Optional[str] = None # UUID string of the source entity
source_type: Optional[str] = None # "agent", "task", "crew"
fingerprint_metadata: Optional[Dict[str, Any]] = None # Any relevant metadata
def to_json(self, exclude: set[str] | None = None):
"""
Converts the event to a JSON-serializable dictionary.
Args:
exclude (set[str], optional): Set of keys to exclude from the result. Defaults to None.
Returns:
dict: A JSON-serializable dictionary.
"""
return to_serializable(self, exclude=exclude)

View File

@@ -1,171 +1,102 @@
from typing import Any, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from pydantic import InstanceOf
from crewai.utilities.events.base_events import BaseEvent
from crewai.utilities.events.base_events import CrewEvent
if TYPE_CHECKING:
from crewai.crew import Crew
else:
Crew = Any
class CrewKickoffStartedEvent(CrewEvent):
class CrewBaseEvent(BaseEvent):
"""Base class for crew events with fingerprint handling"""
crew_name: Optional[str]
crew: Optional[Crew] = None
def __init__(self, **data):
super().__init__(**data)
self.set_crew_fingerprint()
def set_crew_fingerprint(self) -> None:
if self.crew and hasattr(self.crew, "fingerprint") and self.crew.fingerprint:
self.source_fingerprint = self.crew.fingerprint.uuid_str
self.source_type = "crew"
if (
hasattr(self.crew.fingerprint, "metadata")
and self.crew.fingerprint.metadata
):
self.fingerprint_metadata = self.crew.fingerprint.metadata
def to_json(self, exclude: set[str] | None = None):
if exclude is None:
exclude = set()
exclude.add("crew")
return super().to_json(exclude=exclude)
class CrewKickoffStartedEvent(CrewBaseEvent):
"""Event emitted when a crew starts execution"""
crew_name: Optional[str]
inputs: Optional[Dict[str, Any]]
type: str = "crew_kickoff_started"
crew: Optional[Any] = None
def __init__(self, **data):
super().__init__(**data)
# Set fingerprint data from the crew
if self.crew and hasattr(self.crew, 'fingerprint') and self.crew.fingerprint:
self.source_fingerprint = self.crew.fingerprint.uuid_str
self.source_type = "crew"
if hasattr(self.crew.fingerprint, 'metadata') and self.crew.fingerprint.metadata:
self.fingerprint_metadata = self.crew.fingerprint.metadata
class CrewKickoffCompletedEvent(CrewEvent):
class CrewKickoffCompletedEvent(CrewBaseEvent):
"""Event emitted when a crew completes execution"""
crew_name: Optional[str]
output: Any
type: str = "crew_kickoff_completed"
crew: Optional[Any] = None
def __init__(self, **data):
super().__init__(**data)
# Set fingerprint data from the crew
if self.crew and hasattr(self.crew, 'fingerprint') and self.crew.fingerprint:
self.source_fingerprint = self.crew.fingerprint.uuid_str
self.source_type = "crew"
if hasattr(self.crew.fingerprint, 'metadata') and self.crew.fingerprint.metadata:
self.fingerprint_metadata = self.crew.fingerprint.metadata
class CrewKickoffFailedEvent(CrewEvent):
class CrewKickoffFailedEvent(CrewBaseEvent):
"""Event emitted when a crew fails to complete execution"""
error: str
crew_name: Optional[str]
type: str = "crew_kickoff_failed"
crew: Optional[Any] = None
def __init__(self, **data):
super().__init__(**data)
# Set fingerprint data from the crew
if self.crew and hasattr(self.crew, 'fingerprint') and self.crew.fingerprint:
self.source_fingerprint = self.crew.fingerprint.uuid_str
self.source_type = "crew"
if hasattr(self.crew.fingerprint, 'metadata') and self.crew.fingerprint.metadata:
self.fingerprint_metadata = self.crew.fingerprint.metadata
class CrewTrainStartedEvent(CrewEvent):
class CrewTrainStartedEvent(CrewBaseEvent):
"""Event emitted when a crew starts training"""
crew_name: Optional[str]
n_iterations: int
filename: str
inputs: Optional[Dict[str, Any]]
type: str = "crew_train_started"
crew: Optional[Any] = None
def __init__(self, **data):
super().__init__(**data)
# Set fingerprint data from the crew
if self.crew and hasattr(self.crew, 'fingerprint') and self.crew.fingerprint:
self.source_fingerprint = self.crew.fingerprint.uuid_str
self.source_type = "crew"
if hasattr(self.crew.fingerprint, 'metadata') and self.crew.fingerprint.metadata:
self.fingerprint_metadata = self.crew.fingerprint.metadata
class CrewTrainCompletedEvent(CrewEvent):
class CrewTrainCompletedEvent(CrewBaseEvent):
"""Event emitted when a crew completes training"""
crew_name: Optional[str]
n_iterations: int
filename: str
type: str = "crew_train_completed"
crew: Optional[Any] = None
def __init__(self, **data):
super().__init__(**data)
# Set fingerprint data from the crew
if self.crew and hasattr(self.crew, 'fingerprint') and self.crew.fingerprint:
self.source_fingerprint = self.crew.fingerprint.uuid_str
self.source_type = "crew"
if hasattr(self.crew.fingerprint, 'metadata') and self.crew.fingerprint.metadata:
self.fingerprint_metadata = self.crew.fingerprint.metadata
class CrewTrainFailedEvent(CrewEvent):
class CrewTrainFailedEvent(CrewBaseEvent):
"""Event emitted when a crew fails to complete training"""
error: str
crew_name: Optional[str]
type: str = "crew_train_failed"
crew: Optional[Any] = None
def __init__(self, **data):
super().__init__(**data)
# Set fingerprint data from the crew
if self.crew and hasattr(self.crew, 'fingerprint') and self.crew.fingerprint:
self.source_fingerprint = self.crew.fingerprint.uuid_str
self.source_type = "crew"
if hasattr(self.crew.fingerprint, 'metadata') and self.crew.fingerprint.metadata:
self.fingerprint_metadata = self.crew.fingerprint.metadata
class CrewTestStartedEvent(CrewEvent):
class CrewTestStartedEvent(CrewBaseEvent):
"""Event emitted when a crew starts testing"""
crew_name: Optional[str]
n_iterations: int
eval_llm: Optional[Union[str, Any]]
inputs: Optional[Dict[str, Any]]
type: str = "crew_test_started"
crew: Optional[Any] = None
def __init__(self, **data):
super().__init__(**data)
# Set fingerprint data from the crew
if self.crew and hasattr(self.crew, 'fingerprint') and self.crew.fingerprint:
self.source_fingerprint = self.crew.fingerprint.uuid_str
self.source_type = "crew"
if hasattr(self.crew.fingerprint, 'metadata') and self.crew.fingerprint.metadata:
self.fingerprint_metadata = self.crew.fingerprint.metadata
class CrewTestCompletedEvent(CrewEvent):
class CrewTestCompletedEvent(CrewBaseEvent):
"""Event emitted when a crew completes testing"""
crew_name: Optional[str]
type: str = "crew_test_completed"
crew: Optional[Any] = None
def __init__(self, **data):
super().__init__(**data)
# Set fingerprint data from the crew
if self.crew and hasattr(self.crew, 'fingerprint') and self.crew.fingerprint:
self.source_fingerprint = self.crew.fingerprint.uuid_str
self.source_type = "crew"
if hasattr(self.crew.fingerprint, 'metadata') and self.crew.fingerprint.metadata:
self.fingerprint_metadata = self.crew.fingerprint.metadata
class CrewTestFailedEvent(CrewEvent):
class CrewTestFailedEvent(CrewBaseEvent):
"""Event emitted when a crew fails to complete testing"""
error: str
crew_name: Optional[str]
type: str = "crew_test_failed"
crew: Optional[Any] = None
def __init__(self, **data):
super().__init__(**data)
# Set fingerprint data from the crew
if self.crew and hasattr(self.crew, 'fingerprint') and self.crew.fingerprint:
self.source_fingerprint = self.crew.fingerprint.uuid_str
self.source_type = "crew"
if hasattr(self.crew.fingerprint, 'metadata') and self.crew.fingerprint.metadata:
self.fingerprint_metadata = self.crew.fingerprint.metadata

View File

@@ -4,10 +4,10 @@ from typing import Any, Callable, Dict, List, Type, TypeVar, cast
from blinker import Signal
from crewai.utilities.events.base_events import CrewEvent
from crewai.utilities.events.base_events import BaseEvent
from crewai.utilities.events.event_types import EventTypes
EventT = TypeVar("EventT", bound=CrewEvent)
EventT = TypeVar("EventT", bound=BaseEvent)
class CrewAIEventsBus:
@@ -30,7 +30,7 @@ class CrewAIEventsBus:
def _initialize(self) -> None:
"""Initialize the event bus internal state"""
self._signal = Signal("crewai_event_bus")
self._handlers: Dict[Type[CrewEvent], List[Callable]] = {}
self._handlers: Dict[Type[BaseEvent], List[Callable]] = {}
def on(
self, event_type: Type[EventT]
@@ -59,7 +59,7 @@ class CrewAIEventsBus:
return decorator
def emit(self, source: Any, event: CrewEvent) -> None:
def emit(self, source: Any, event: BaseEvent) -> None:
"""
Emit an event to all registered handlers

View File

@@ -2,10 +2,10 @@ from typing import Any, Dict, Optional, Union
from pydantic import BaseModel, ConfigDict
from .base_events import CrewEvent
from .base_events import BaseEvent
class FlowEvent(CrewEvent):
class FlowEvent(BaseEvent):
"""Base class for all flow events"""
type: str

View File

@@ -1,7 +1,7 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from crewai.utilities.events.base_events import CrewEvent
from crewai.utilities.events.base_events import BaseEvent
class LLMCallType(Enum):
@@ -11,9 +11,9 @@ class LLMCallType(Enum):
LLM_CALL = "llm_call"
class LLMCallStartedEvent(CrewEvent):
class LLMCallStartedEvent(BaseEvent):
"""Event emitted when a LLM call starts
Attributes:
messages: Content can be either a string or a list of dictionaries that support
multimodal content (text, images, etc.)
@@ -26,7 +26,7 @@ class LLMCallStartedEvent(CrewEvent):
available_functions: Optional[Dict[str, Any]] = None
class LLMCallCompletedEvent(CrewEvent):
class LLMCallCompletedEvent(BaseEvent):
"""Event emitted when a LLM call completes"""
type: str = "llm_call_completed"
@@ -34,14 +34,14 @@ class LLMCallCompletedEvent(CrewEvent):
call_type: LLMCallType
class LLMCallFailedEvent(CrewEvent):
class LLMCallFailedEvent(BaseEvent):
"""Event emitted when a LLM call fails"""
error: str
type: str = "llm_call_failed"
class LLMStreamChunkEvent(CrewEvent):
class LLMStreamChunkEvent(BaseEvent):
"""Event emitted when a streaming chunk is received"""
type: str = "llm_stream_chunk"

View File

@@ -1,10 +1,10 @@
from typing import Any, Optional
from crewai.tasks.task_output import TaskOutput
from crewai.utilities.events.base_events import CrewEvent
from crewai.utilities.events.base_events import BaseEvent
class TaskStartedEvent(CrewEvent):
class TaskStartedEvent(BaseEvent):
"""Event emitted when a task starts"""
type: str = "task_started"
@@ -14,14 +14,17 @@ class TaskStartedEvent(CrewEvent):
def __init__(self, **data):
super().__init__(**data)
# Set fingerprint data from the task
if hasattr(self.task, 'fingerprint') and self.task.fingerprint:
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
self.source_fingerprint = self.task.fingerprint.uuid_str
self.source_type = "task"
if hasattr(self.task.fingerprint, 'metadata') and self.task.fingerprint.metadata:
if (
hasattr(self.task.fingerprint, "metadata")
and self.task.fingerprint.metadata
):
self.fingerprint_metadata = self.task.fingerprint.metadata
class TaskCompletedEvent(CrewEvent):
class TaskCompletedEvent(BaseEvent):
"""Event emitted when a task completes"""
output: TaskOutput
@@ -31,14 +34,17 @@ class TaskCompletedEvent(CrewEvent):
def __init__(self, **data):
super().__init__(**data)
# Set fingerprint data from the task
if hasattr(self.task, 'fingerprint') and self.task.fingerprint:
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
self.source_fingerprint = self.task.fingerprint.uuid_str
self.source_type = "task"
if hasattr(self.task.fingerprint, 'metadata') and self.task.fingerprint.metadata:
if (
hasattr(self.task.fingerprint, "metadata")
and self.task.fingerprint.metadata
):
self.fingerprint_metadata = self.task.fingerprint.metadata
class TaskFailedEvent(CrewEvent):
class TaskFailedEvent(BaseEvent):
"""Event emitted when a task fails"""
error: str
@@ -48,14 +54,17 @@ class TaskFailedEvent(CrewEvent):
def __init__(self, **data):
super().__init__(**data)
# Set fingerprint data from the task
if hasattr(self.task, 'fingerprint') and self.task.fingerprint:
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
self.source_fingerprint = self.task.fingerprint.uuid_str
self.source_type = "task"
if hasattr(self.task.fingerprint, 'metadata') and self.task.fingerprint.metadata:
if (
hasattr(self.task.fingerprint, "metadata")
and self.task.fingerprint.metadata
):
self.fingerprint_metadata = self.task.fingerprint.metadata
class TaskEvaluationEvent(CrewEvent):
class TaskEvaluationEvent(BaseEvent):
"""Event emitted when a task evaluation is completed"""
type: str = "task_evaluation"
@@ -65,8 +74,11 @@ class TaskEvaluationEvent(CrewEvent):
def __init__(self, **data):
super().__init__(**data)
# Set fingerprint data from the task
if hasattr(self.task, 'fingerprint') and self.task.fingerprint:
if hasattr(self.task, "fingerprint") and self.task.fingerprint:
self.source_fingerprint = self.task.fingerprint.uuid_str
self.source_type = "task"
if hasattr(self.task.fingerprint, 'metadata') and self.task.fingerprint.metadata:
if (
hasattr(self.task.fingerprint, "metadata")
and self.task.fingerprint.metadata
):
self.fingerprint_metadata = self.task.fingerprint.metadata

View File

@@ -1,10 +1,10 @@
from datetime import datetime
from typing import Any, Callable, Dict, Optional
from .base_events import CrewEvent
from .base_events import BaseEvent
class ToolUsageEvent(CrewEvent):
class ToolUsageEvent(BaseEvent):
"""Base event for tool usage tracking"""
agent_key: str
@@ -21,10 +21,13 @@ class ToolUsageEvent(CrewEvent):
def __init__(self, **data):
super().__init__(**data)
# Set fingerprint data from the agent
if self.agent and hasattr(self.agent, 'fingerprint') and self.agent.fingerprint:
if self.agent and hasattr(self.agent, "fingerprint") and self.agent.fingerprint:
self.source_fingerprint = self.agent.fingerprint.uuid_str
self.source_type = "agent"
if hasattr(self.agent.fingerprint, 'metadata') and self.agent.fingerprint.metadata:
if (
hasattr(self.agent.fingerprint, "metadata")
and self.agent.fingerprint.metadata
):
self.fingerprint_metadata = self.agent.fingerprint.metadata
@@ -65,7 +68,7 @@ class ToolSelectionErrorEvent(ToolUsageEvent):
type: str = "tool_selection_error"
class ToolExecutionErrorEvent(CrewEvent):
class ToolExecutionErrorEvent(BaseEvent):
"""Event emitted when a tool execution encounters an error"""
error: Any
@@ -78,8 +81,11 @@ class ToolExecutionErrorEvent(CrewEvent):
def __init__(self, **data):
super().__init__(**data)
# Set fingerprint data from the agent
if self.agent and hasattr(self.agent, 'fingerprint') and self.agent.fingerprint:
if self.agent and hasattr(self.agent, "fingerprint") and self.agent.fingerprint:
self.source_fingerprint = self.agent.fingerprint.uuid_str
self.source_type = "agent"
if hasattr(self.agent.fingerprint, 'metadata') and self.agent.fingerprint.metadata:
if (
hasattr(self.agent.fingerprint, "metadata")
and self.agent.fingerprint.metadata
):
self.fingerprint_metadata = self.agent.fingerprint.metadata

View File

@@ -5,35 +5,17 @@ from typing import Any, Dict, List, Union
from pydantic import BaseModel
from crewai.flow import Flow
SerializablePrimitive = Union[str, int, float, bool, None]
Serializable = Union[
SerializablePrimitive, List["Serializable"], Dict[str, "Serializable"]
]
def export_state(flow: Flow) -> dict[str, Serializable]:
"""Exports the Flow's internal state as JSON-compatible data structures.
Performs a one-way transformation of a Flow's state into basic Python types
that can be safely serialized to JSON. To prevent infinite recursion with
circular references, the conversion is limited to a depth of 5 levels.
Args:
flow: The Flow object whose state needs to be exported
Returns:
dict[str, Any]: The transformed state using JSON-compatible Python
types.
"""
result = to_serializable(flow._state)
assert isinstance(result, dict)
return result
def to_serializable(
obj: Any, exclude: set[str] | None = None, max_depth: int = 5, _current_depth: int = 0
obj: Any,
exclude: set[str] | None = None,
max_depth: int = 5,
_current_depth: int = 0,
) -> Serializable:
"""Converts a Python object into a JSON-compatible representation.

View File

@@ -1,10 +1,10 @@
from unittest.mock import Mock
from crewai.utilities.events.base_events import CrewEvent
from crewai.utilities.events.base_events import BaseEvent
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
class TestEvent(CrewEvent):
class TestEvent(BaseEvent):
pass
@@ -24,7 +24,7 @@ def test_specific_event_handler():
def test_wildcard_event_handler():
mock_handler = Mock()
@crewai_event_bus.on(CrewEvent)
@crewai_event_bus.on(BaseEvent)
def handler(source, event):
mock_handler(source, event)

View File

@@ -5,8 +5,7 @@ from unittest.mock import Mock
import pytest
from pydantic import BaseModel
from crewai.flow import Flow
from crewai.flow.state_utils import export_state, to_serializable, to_string
from crewai.utilities.serialization import to_serializable, to_string
class Address(BaseModel):
@@ -23,16 +22,6 @@ class Person(BaseModel):
skills: List[str]
@pytest.fixture
def mock_flow():
def create_flow(state):
flow = Mock(spec=Flow)
flow._state = state
return flow
return create_flow
@pytest.mark.parametrize(
"test_input,expected",
[
@@ -47,9 +36,8 @@ def mock_flow():
({"nested": [1, [2, 3], {4, 5}]}, {"nested": [1, [2, 3], [4, 5]]}),
],
)
def test_basic_serialization(mock_flow, test_input, expected):
flow = mock_flow(test_input)
result = export_state(flow)
def test_basic_serialization(test_input, expected):
result = to_serializable(test_input)
assert result == expected
@@ -60,9 +48,8 @@ def test_basic_serialization(mock_flow, test_input, expected):
(datetime(2024, 1, 1, 12, 30), "2024-01-01T12:30:00"),
],
)
def test_temporal_serialization(mock_flow, input_date, expected):
flow = mock_flow({"date": input_date})
result = export_state(flow)
def test_temporal_serialization(input_date, expected):
result = to_serializable({"date": input_date})
assert result["date"] == expected
@@ -75,9 +62,8 @@ def test_temporal_serialization(mock_flow, input_date, expected):
("normal", "value", str),
],
)
def test_dictionary_key_serialization(mock_flow, key, value, expected_key_type):
flow = mock_flow({key: value})
result = export_state(flow)
def test_dictionary_key_serialization(key, value, expected_key_type):
result = to_serializable({key: value})
assert len(result) == 1
result_key = next(iter(result.keys()))
assert isinstance(result_key, expected_key_type)
@@ -91,14 +77,13 @@ def test_dictionary_key_serialization(mock_flow, key, value, expected_key_type):
(str.upper, "upper"),
],
)
def test_callable_serialization(mock_flow, callable_obj, expected_in_result):
flow = mock_flow({"func": callable_obj})
result = export_state(flow)
def test_callable_serialization(callable_obj, expected_in_result):
result = to_serializable({"func": callable_obj})
assert isinstance(result["func"], str)
assert expected_in_result in result["func"].lower()
def test_pydantic_model_serialization(mock_flow):
def test_pydantic_model_serialization():
address = Address(street="123 Main St", city="Tech City", country="Pythonia")
person = Person(
@@ -109,23 +94,21 @@ def test_pydantic_model_serialization(mock_flow):
skills=["Python", "Testing"],
)
flow = mock_flow(
{
"single_model": address,
"nested_model": person,
"model_list": [address, address],
"model_dict": {"home": address},
}
)
data = {
"single_model": address,
"nested_model": person,
"model_list": [address, address],
"model_dict": {"home": address},
}
result = export_state(flow)
result = to_serializable(data)
assert (
to_string(result)
== '{"single_model": {"street": "123 Main St", "city": "Tech City", "country": "Pythonia"}, "nested_model": {"name": "John Doe", "age": 30, "address": {"street": "123 Main St", "city": "Tech City", "country": "Pythonia"}, "birthday": "1994-01-01", "skills": ["Python", "Testing"]}, "model_list": [{"street": "123 Main St", "city": "Tech City", "country": "Pythonia"}, {"street": "123 Main St", "city": "Tech City", "country": "Pythonia"}], "model_dict": {"home": {"street": "123 Main St", "city": "Tech City", "country": "Pythonia"}}}'
)
def test_depth_limit(mock_flow):
def test_depth_limit():
"""Test max depth handling with a deeply nested structure"""
def create_nested(depth):
@@ -134,8 +117,7 @@ def test_depth_limit(mock_flow):
return {"next": create_nested(depth - 1)}
deep_structure = create_nested(10)
flow = mock_flow(deep_structure)
result = export_state(flow)
result = to_serializable(deep_structure)
assert result == {
"next": {