Brandon/new release cleanup (#1918)

* WIP

* fixes to match enterprise changes
This commit is contained in:
Brandon Hancock (bhancock_ai)
2025-01-18 13:46:41 -05:00
committed by GitHub
parent 4a44245de9
commit 627bb3f5f6
6 changed files with 119 additions and 281 deletions

View File

@@ -1,6 +1,5 @@
import asyncio
import inspect
import uuid
from typing import (
Any,
Callable,
@@ -13,7 +12,6 @@ from typing import (
TypeVar,
Union,
cast,
overload,
)
from uuid import uuid4
@@ -27,7 +25,6 @@ from crewai.flow.flow_events import (
MethodExecutionStartedEvent,
)
from crewai.flow.flow_visualizer import plot_flow
from crewai.flow.persistence import FlowPersistence
from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.utils import get_possible_return_constants
from crewai.telemetry import Telemetry
@@ -35,22 +32,32 @@ from crewai.telemetry import Telemetry
class FlowState(BaseModel):
"""Base model for all flow states, ensuring each state has a unique ID."""
id: str = Field(default_factory=lambda: str(uuid4()), description="Unique identifier for the flow state")
id: str = Field(
default_factory=lambda: str(uuid4()),
description="Unique identifier for the flow state",
)
# Type variables with explicit bounds
T = TypeVar("T", bound=Union[Dict[str, Any], BaseModel]) # Generic flow state type parameter
StateT = TypeVar("StateT", bound=Union[Dict[str, Any], BaseModel]) # State validation type parameter
T = TypeVar(
"T", bound=Union[Dict[str, Any], BaseModel]
) # Generic flow state type parameter
StateT = TypeVar(
"StateT", bound=Union[Dict[str, Any], BaseModel]
) # State validation type parameter
def ensure_state_type(state: Any, expected_type: Type[StateT]) -> StateT:
"""Ensure state matches expected type with proper validation.
Args:
state: State instance to validate
expected_type: Expected type for the state
Returns:
Validated state instance
Raises:
TypeError: If state doesn't match expected type
ValueError: If state validation fails
@@ -68,13 +75,15 @@ def ensure_state_type(state: Any, expected_type: Type[StateT]) -> StateT:
TypeError: If state doesn't match expected type
ValueError: If state validation fails
"""
if expected_type == dict:
if expected_type is dict:
if not isinstance(state, dict):
raise TypeError(f"Expected dict, got {type(state).__name__}")
return cast(StateT, state)
if isinstance(expected_type, type) and issubclass(expected_type, BaseModel):
if not isinstance(state, expected_type):
raise TypeError(f"Expected {expected_type.__name__}, got {type(state).__name__}")
raise TypeError(
f"Expected {expected_type.__name__}, got {type(state).__name__}"
)
return cast(StateT, state)
raise TypeError(f"Invalid expected_type: {expected_type}")
@@ -120,6 +129,7 @@ def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable:
>>> def complex_start(self):
... pass
"""
def decorator(func):
func.__is_start_method__ = True
if condition is not None:
@@ -144,6 +154,7 @@ def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable:
return decorator
def listen(condition: Union[str, dict, Callable]) -> Callable:
"""
Creates a listener that executes when specified conditions are met.
@@ -180,6 +191,7 @@ def listen(condition: Union[str, dict, Callable]) -> Callable:
>>> def handle_completion(self):
... pass
"""
def decorator(func):
if isinstance(condition, str):
func.__trigger_methods__ = [condition]
@@ -244,6 +256,7 @@ def router(condition: Union[str, dict, Callable]) -> Callable:
... return CONTINUE
... return STOP
"""
def decorator(func):
func.__is_router__ = True
if isinstance(condition, str):
@@ -267,6 +280,7 @@ def router(condition: Union[str, dict, Callable]) -> Callable:
return decorator
def or_(*conditions: Union[str, dict, Callable]) -> dict:
"""
Combines multiple conditions with OR logic for flow control.
@@ -370,22 +384,27 @@ class FlowMeta(type):
for attr_name, attr_value in dct.items():
# Check for any flow-related attributes
if (hasattr(attr_value, "__is_flow_method__") or
hasattr(attr_value, "__is_start_method__") or
hasattr(attr_value, "__trigger_methods__") or
hasattr(attr_value, "__is_router__")):
if (
hasattr(attr_value, "__is_flow_method__")
or hasattr(attr_value, "__is_start_method__")
or hasattr(attr_value, "__trigger_methods__")
or hasattr(attr_value, "__is_router__")
):
# Register start methods
if hasattr(attr_value, "__is_start_method__"):
start_methods.append(attr_name)
# Register listeners and routers
if hasattr(attr_value, "__trigger_methods__"):
methods = attr_value.__trigger_methods__
condition_type = getattr(attr_value, "__condition_type__", "OR")
listeners[attr_name] = (condition_type, methods)
if hasattr(attr_value, "__is_router__") and attr_value.__is_router__:
if (
hasattr(attr_value, "__is_router__")
and attr_value.__is_router__
):
routers.add(attr_name)
possible_returns = get_possible_return_constants(attr_value)
if possible_returns:
@@ -401,8 +420,9 @@ class FlowMeta(type):
class Flow(Generic[T], metaclass=FlowMeta):
"""Base class for all flows.
Type parameter T must be either Dict[str, Any] or a subclass of BaseModel."""
_telemetry = Telemetry()
_start_methods: List[str] = []
@@ -426,7 +446,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
**kwargs: Any,
) -> None:
"""Initialize a new Flow instance.
Args:
persistence: Optional persistence backend for storing flow states
restore_uuid: Optional UUID to restore state from persistence
@@ -438,29 +458,38 @@ class Flow(Generic[T], metaclass=FlowMeta):
self._pending_and_listeners: Dict[str, Set[str]] = {}
self._method_outputs: List[Any] = [] # List to store all method outputs
self._persistence: Optional[FlowPersistence] = persistence
# Validate state model before initialization
if isinstance(self.initial_state, type):
if issubclass(self.initial_state, BaseModel) and not issubclass(self.initial_state, FlowState):
if issubclass(self.initial_state, BaseModel) and not issubclass(
self.initial_state, FlowState
):
# Check if model has id field
model_fields = getattr(self.initial_state, "model_fields", None)
if not model_fields or "id" not in model_fields:
raise ValueError("Flow state model must have an 'id' field")
# Handle persistence and potential ID conflicts
stored_state = None
if self._persistence is not None:
if restore_uuid and kwargs and "id" in kwargs and restore_uuid != kwargs["id"]:
if (
restore_uuid
and kwargs
and "id" in kwargs
and restore_uuid != kwargs["id"]
):
raise ValueError(
f"Conflicting IDs provided: restore_uuid='{restore_uuid}' "
f"vs kwargs['id']='{kwargs['id']}'. Use only one ID for restoration."
)
# Attempt to load state, prioritizing restore_uuid
if restore_uuid:
stored_state = self._persistence.load_state(restore_uuid)
if not stored_state:
raise ValueError(f"No state found for restore_uuid='{restore_uuid}'")
raise ValueError(
f"No state found for restore_uuid='{restore_uuid}'"
)
elif kwargs and "id" in kwargs:
stored_state = self._persistence.load_state(kwargs["id"])
if not stored_state:
@@ -469,7 +498,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
if kwargs:
self._initialize_state(kwargs)
return
# Initialize state based on persistence and kwargs
if stored_state:
# Create initial state and restore from persistence
@@ -494,23 +523,23 @@ class Flow(Generic[T], metaclass=FlowMeta):
if not method_name.startswith("_"):
method = getattr(self, method_name)
# Check for any flow-related attributes
if (hasattr(method, "__is_flow_method__") or
hasattr(method, "__is_start_method__") or
hasattr(method, "__trigger_methods__") or
hasattr(method, "__is_router__")):
if (
hasattr(method, "__is_flow_method__")
or hasattr(method, "__is_start_method__")
or hasattr(method, "__trigger_methods__")
or hasattr(method, "__is_router__")
):
# Ensure method is bound to this instance
if not hasattr(method, "__self__"):
method = method.__get__(self, self.__class__)
self._methods[method_name] = method
def _create_initial_state(self) -> T:
"""Create and initialize flow state with UUID and default values.
Returns:
New state instance with UUID and default values initialized
Raises:
ValueError: If structured state model lacks 'id' field
TypeError: If state is neither BaseModel nor dictionary
@@ -522,24 +551,25 @@ class Flow(Generic[T], metaclass=FlowMeta):
if issubclass(state_type, FlowState):
# Create instance without id, then set it
instance = state_type()
if not hasattr(instance, 'id'):
setattr(instance, 'id', str(uuid4()))
if not hasattr(instance, "id"):
setattr(instance, "id", str(uuid4()))
return cast(T, instance)
elif issubclass(state_type, BaseModel):
# Create a new type that includes the ID field
class StateWithId(state_type, FlowState): # type: ignore
pass
instance = StateWithId()
if not hasattr(instance, 'id'):
setattr(instance, 'id', str(uuid4()))
if not hasattr(instance, "id"):
setattr(instance, "id", str(uuid4()))
return cast(T, instance)
elif state_type == dict:
return cast(T, {"id": str(uuid4())}) # Minimal dict state
elif state_type is dict:
return cast(T, {"id": str(uuid4())})
# Handle case where no initial state is provided
if self.initial_state is None:
return cast(T, {"id": str(uuid4())})
# Handle case where initial_state is a type (class)
if isinstance(self.initial_state, type):
if issubclass(self.initial_state, FlowState):
@@ -550,22 +580,22 @@ class Flow(Generic[T], metaclass=FlowMeta):
if not model_fields or "id" not in model_fields:
raise ValueError("Flow state model must have an 'id' field")
return cast(T, self.initial_state()) # Uses model defaults
elif self.initial_state == dict:
elif self.initial_state is dict:
return cast(T, {"id": str(uuid4())})
# Handle dictionary instance case
if isinstance(self.initial_state, dict):
new_state = dict(self.initial_state) # Copy to avoid mutations
if "id" not in new_state:
new_state["id"] = str(uuid4())
return cast(T, new_state)
# Handle BaseModel instance case
if isinstance(self.initial_state, BaseModel):
model = cast(BaseModel, self.initial_state)
if not hasattr(model, "id"):
raise ValueError("Flow state model must have an 'id' field")
# Create new instance with same values to avoid mutations
if hasattr(model, "model_dump"):
# Pydantic v2
@@ -576,60 +606,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
else:
# Fallback for other BaseModel implementations
state_dict = {
k: v for k, v in model.__dict__.items()
if not k.startswith("_")
k: v for k, v in model.__dict__.items() if not k.startswith("_")
}
# Create new instance of the same class
model_class = type(model)
return cast(T, model_class(**state_dict))
raise TypeError(
f"Initial state must be dict or BaseModel, got {type(self.initial_state)}"
)
# Handle case where initial_state is None but we have a type parameter
if self.initial_state is None and hasattr(self, "_initial_state_T"):
state_type = getattr(self, "_initial_state_T")
if isinstance(state_type, type):
if issubclass(state_type, FlowState):
return cast(T, state_type())
elif issubclass(state_type, BaseModel):
# Create a new type that includes the ID field
class StateWithId(state_type, FlowState): # type: ignore
pass
return cast(T, StateWithId())
elif state_type == dict:
return cast(T, {"id": str(uuid4())})
# Handle case where no initial state is provided
if self.initial_state is None:
return cast(T, {"id": str(uuid4())})
# Handle case where initial_state is a type (class)
if isinstance(self.initial_state, type):
if issubclass(self.initial_state, FlowState):
return cast(T, self.initial_state())
elif issubclass(self.initial_state, BaseModel):
# Validate that the model has an id field
model_fields = getattr(self.initial_state, "model_fields", None)
if not model_fields or "id" not in model_fields:
raise ValueError("Flow state model must have an 'id' field")
return cast(T, self.initial_state())
elif self.initial_state == dict:
return cast(T, {"id": str(uuid4())})
# Handle dictionary instance case
if isinstance(self.initial_state, dict):
if "id" not in self.initial_state:
self.initial_state["id"] = str(uuid4())
return cast(T, dict(self.initial_state)) # Create new dict to avoid mutations
# Handle BaseModel instance case
if isinstance(self.initial_state, BaseModel):
if not hasattr(self.initial_state, "id"):
raise ValueError("Flow state model must have an 'id' field")
return cast(T, self.initial_state)
raise TypeError(
f"Initial state must be dict or BaseModel, got {type(self.initial_state)}"
)
@@ -645,10 +628,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
def _initialize_state(self, inputs: Dict[str, Any]) -> None:
"""Initialize or update flow state with new inputs.
Args:
inputs: Dictionary of state values to set/update
Raises:
ValueError: If validation fails for structured state
TypeError: If state is neither BaseModel nor dictionary
@@ -675,13 +658,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
current_state = model.dict()
else:
current_state = {
k: v for k, v in model.__dict__.items()
if not k.startswith("_")
k: v for k, v in model.__dict__.items() if not k.startswith("_")
}
# Create new state with preserved fields and updates
new_state = {**current_state, **inputs}
# Create new instance with merged state
model_class = type(model)
if hasattr(model_class, "model_validate"):
@@ -697,13 +679,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
raise ValueError(f"Invalid inputs for structured state: {e}") from e
else:
raise TypeError("State must be a BaseModel instance or a dictionary.")
def _restore_state(self, stored_state: Dict[str, Any]) -> None:
"""Restore flow state from persistence.
Args:
stored_state: Previously stored state to restore
Raises:
ValueError: If validation fails for structured state
TypeError: If state is neither BaseModel nor dictionary
@@ -712,7 +694,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
stored_id = stored_state.get("id")
if not stored_id:
raise ValueError("Stored state must have an 'id' field")
if isinstance(self._state, dict):
# For dict states, update all fields from stored state
self._state.clear()
@@ -730,9 +712,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
# Fallback for other BaseModel implementations
self._state = cast(T, type(model)(**stored_state))
else:
raise TypeError(
f"State must be dict or BaseModel, got {type(self._state)}"
)
raise TypeError(f"State must be dict or BaseModel, got {type(self._state)}")
def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
self.event_emitter.send(