Merge branch 'main' into main

This commit is contained in:
fzowl
2025-01-14 13:13:10 +01:00
committed by GitHub
3 changed files with 165 additions and 17 deletions

View File

@@ -13,9 +13,10 @@ from typing import (
Union,
cast,
)
from uuid import uuid4
from blinker import Signal
from pydantic import BaseModel, ValidationError
from pydantic import BaseModel, Field, ValidationError
from crewai.flow.flow_events import (
FlowFinishedEvent,
@@ -27,7 +28,12 @@ from crewai.flow.flow_visualizer import plot_flow
from crewai.flow.utils import get_possible_return_constants
from crewai.telemetry import Telemetry
T = TypeVar("T", bound=Union[BaseModel, Dict[str, Any]])
class FlowState(BaseModel):
"""Base model for all flow states, ensuring each state has a unique ID."""
id: str = Field(default_factory=lambda: str(uuid4()), description="Unique identifier for the flow state")
T = TypeVar("T", bound=Union[FlowState, Dict[str, Any]])
def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable:
@@ -377,14 +383,37 @@ class Flow(Generic[T], metaclass=FlowMeta):
self._methods[method_name] = getattr(self, method_name)
def _create_initial_state(self) -> T:
# Handle case where initial_state is None but we have a type parameter
if self.initial_state is None and hasattr(self, "_initial_state_T"):
return self._initial_state_T() # type: ignore
state_type = getattr(self, "_initial_state_T")
if isinstance(state_type, type):
if issubclass(state_type, FlowState):
return state_type() # type: ignore
elif issubclass(state_type, BaseModel):
# Create a new type that includes the ID field
class StateWithId(state_type, FlowState): # type: ignore
pass
return StateWithId() # type: ignore
# Handle case where no initial state is provided
if self.initial_state is None:
return {} # type: ignore
elif isinstance(self.initial_state, type):
return self.initial_state()
else:
return self.initial_state
return {"id": str(uuid4())} # type: ignore
# Handle case where initial_state is a type (class)
if isinstance(self.initial_state, type):
if issubclass(self.initial_state, FlowState):
return self.initial_state() # type: ignore
elif issubclass(self.initial_state, BaseModel):
# Create a new type that includes the ID field
class StateWithId(self.initial_state, FlowState): # type: ignore
pass
return StateWithId() # type: ignore
# Handle dictionary case
if isinstance(self.initial_state, dict) and "id" not in self.initial_state:
self.initial_state["id"] = str(uuid4())
return self.initial_state # type: ignore
@property
def state(self) -> T:
@@ -396,10 +425,17 @@ class Flow(Generic[T], metaclass=FlowMeta):
return self._method_outputs
def _initialize_state(self, inputs: Dict[str, Any]) -> None:
if isinstance(self._state, BaseModel):
if isinstance(self._state, dict):
# Preserve the ID when updating unstructured state
current_id = self._state.get("id")
self._state.update(inputs)
if current_id:
self._state["id"] = current_id
elif "id" not in self._state:
self._state["id"] = str(uuid4())
elif isinstance(self._state, BaseModel):
# Structured state
try:
def create_model_with_extra_forbid(
base_model: Type[BaseModel],
) -> Type[BaseModel]:
@@ -409,16 +445,28 @@ class Flow(Generic[T], metaclass=FlowMeta):
return ModelWithExtraForbid
# Get current state as dict, preserving the ID if it exists
state_model = cast(BaseModel, self._state)
current_state = (
state_model.model_dump()
if hasattr(state_model, "model_dump")
else state_model.dict()
if hasattr(state_model, "dict")
else {
k: v
for k, v in state_model.__dict__.items()
if not k.startswith("_")
}
)
ModelWithExtraForbid = create_model_with_extra_forbid(
self._state.__class__
)
self._state = cast(
T, ModelWithExtraForbid(**{**self._state.model_dump(), **inputs})
T, ModelWithExtraForbid(**{**current_state, **inputs})
)
except ValidationError as e:
raise ValueError(f"Invalid inputs for structured state: {e}") from e
elif isinstance(self._state, dict):
self._state.update(inputs)
else:
raise TypeError("State must be a BaseModel instance or a dictionary.")