From 294f2cc3a98c7d3f8b401ef2941baf0e2f5524e9 Mon Sep 17 00:00:00 2001
From: "devin-ai-integration[bot]" <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Date: Thu, 16 Jan 2025 10:23:46 -0300
Subject: [PATCH] Add @persist decorator with FlowPersistence interface (#1892)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Add @persist decorator with SQLite persistence

  - Add FlowPersistence abstract base class
  - Implement SQLiteFlowPersistence backend
  - Add @persist decorator for flow state persistence
  - Add tests for flow persistence functionality

  Co-Authored-By: Joe Moura

* Fix remaining merge conflicts in uv.lock

  - Remove stray merge conflict markers
  - Keep main's comprehensive platform-specific resolution markers
  - Preserve all required dependencies for persistence functionality

  Co-Authored-By: Joe Moura

* Fix final CUDA dependency conflicts in uv.lock

  - Resolve NVIDIA CUDA solver dependency conflicts
  - Use main's comprehensive platform checks
  - Ensure all merge conflict markers are removed
  - Preserve persistence-related dependencies

  Co-Authored-By: Joe Moura

* Fix nvidia-cusparse-cu12 dependency conflicts in uv.lock

  - Resolve NVIDIA CUSPARSE dependency conflicts
  - Use main's comprehensive platform checks
  - Complete systematic check of entire uv.lock file
  - Ensure all merge conflict markers are removed

  Co-Authored-By: Joe Moura

* Fix triton filelock dependency conflicts in uv.lock

  - Resolve triton package filelock dependency conflict
  - Use main's comprehensive platform checks
  - Complete final systematic check of entire uv.lock file
  - Ensure TOML file structure is valid

  Co-Authored-By: Joe Moura

* Fix merge conflict in crew_test.py

  - Remove duplicate assertion in test_multimodal_agent_live_image_analysis
  - Clean up conflict markers
  - Preserve test functionality

  Co-Authored-By: Joe Moura

* Clean up trailing merge conflict marker in crew_test.py

  - Remove remaining conflict marker at end of file
  - Preserve test functionality
  - Complete conflict resolution

  Co-Authored-By: Joe Moura

* Improve type safety in persistence implementation and resolve merge conflicts

  Co-Authored-By: Joe Moura

* fix: Add explicit type casting in _create_initial_state method

  Co-Authored-By: Joe Moura

* fix: Improve type safety in flow state handling with proper validation

  Co-Authored-By: Joe Moura

* fix: Improve type system with proper TypeVar scoping and validation

  Co-Authored-By: Joe Moura

* fix: Improve state restoration logic and add comprehensive tests

  Co-Authored-By: Joe Moura

* fix: Initialize FlowState instances without passing id to constructor

  Co-Authored-By: Joe Moura

* feat: Add class-level flow persistence decorator with SQLite default

  - Add class-level @persist decorator support
  - Set SQLiteFlowPersistence as default backend
  - Use db_storage_path for consistent database location
  - Improve async method handling and type safety
  - Add comprehensive docstrings and examples

  Co-Authored-By: Joe Moura

* fix: Sort imports in decorators.py to fix lint error

  Co-Authored-By: Joe Moura

* style: Organize imports according to PEP 8 standard

  Co-Authored-By: Joe Moura

* style: Format typing imports with line breaks for better readability

  Co-Authored-By: Joe Moura

* style: Simplify import organization to fix lint error

  Co-Authored-By: Joe Moura

* style: Fix import sorting using Ruff auto-fix

  Co-Authored-By: Joe Moura

---------

Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Co-authored-by: Joe Moura
Co-authored-by: João Moura
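For reviewers, a minimal usage sketch of the API this patch introduces, pieced together from the decorator docstrings and tests/test_flow_persistence.py in the diff below. The `CounterState`/`CounterFlow` names, the field values, and the "flows.db" path are illustrative assumptions, not part of the patch:

```python
from crewai.flow.flow import Flow, FlowState, start
from crewai.flow.persistence import persist, SQLiteFlowPersistence


class CounterState(FlowState):
    # FlowState supplies the required `id` field used as the persistence key
    counter: int = 0
    message: str = ""


# Explicit path is illustrative; omitting it falls back to db_storage_path().
persistence = SQLiteFlowPersistence("flows.db")


class CounterFlow(Flow[CounterState]):
    initial_state = CounterState

    @start()
    @persist(persistence)  # persist this method's resulting state after it runs
    def bump(self):
        self.state.counter += 1
        self.state.message = f"Count is {self.state.counter}"


flow = CounterFlow(persistence=persistence)
flow.kickoff()

# Restore a new instance from the stored state, optionally overriding fields.
restored = CounterFlow(
    persistence=persistence,
    restore_uuid=flow.state.id,
    counter=99,  # override a single field on restore
)
assert restored.state.message == flow.state.message
assert restored.state.counter == 99
```

Per the decorator docstring, class-level decoration of the Flow subclass is also supported and persists state after every flow method, with SQLiteFlowPersistence as the default backend.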
--- src/crewai/flow/flow.py | 384 +++++++++++++++++--- src/crewai/flow/persistence/__init__.py | 18 + src/crewai/flow/persistence/base.py | 53 +++ src/crewai/flow/persistence/decorators.py | 177 +++++++++ src/crewai/flow/persistence/sqlite.py | 124 +++++++ src/crewai/utilities/paths.py | 10 +- tests/cassettes/test_agent_human_input.yaml | 188 +++++++--- tests/crew_test.py | 1 + tests/test_flow_persistence.py | 195 ++++++++++ uv.lock | 49 +-- 10 files changed, 1061 insertions(+), 138 deletions(-) create mode 100644 src/crewai/flow/persistence/__init__.py create mode 100644 src/crewai/flow/persistence/base.py create mode 100644 src/crewai/flow/persistence/decorators.py create mode 100644 src/crewai/flow/persistence/sqlite.py create mode 100644 tests/test_flow_persistence.py diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index f10626ce4..ef688b9c1 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -1,5 +1,6 @@ import asyncio import inspect +import uuid from typing import ( Any, Callable, @@ -12,6 +13,7 @@ from typing import ( TypeVar, Union, cast, + overload, ) from uuid import uuid4 @@ -25,6 +27,8 @@ from crewai.flow.flow_events import ( MethodExecutionStartedEvent, ) from crewai.flow.flow_visualizer import plot_flow +from crewai.flow.persistence import FlowPersistence +from crewai.flow.persistence.base import FlowPersistence from crewai.flow.utils import get_possible_return_constants from crewai.telemetry import Telemetry @@ -33,7 +37,46 @@ class FlowState(BaseModel): """Base model for all flow states, ensuring each state has a unique ID.""" id: str = Field(default_factory=lambda: str(uuid4()), description="Unique identifier for the flow state") -T = TypeVar("T", bound=Union[FlowState, Dict[str, Any]]) +# Type variables with explicit bounds +T = TypeVar("T", bound=Union[Dict[str, Any], BaseModel]) # Generic flow state type parameter +StateT = TypeVar("StateT", bound=Union[Dict[str, Any], BaseModel]) # State validation type parameter + +def ensure_state_type(state: Any, expected_type: Type[StateT]) -> StateT: + """Ensure state matches expected type with proper validation. + + Args: + state: State instance to validate + expected_type: Expected type for the state + + Returns: + Validated state instance + + Raises: + TypeError: If state doesn't match expected type + ValueError: If state validation fails + """ + """Ensure state matches expected type with proper validation. 
+ + Args: + state: State instance to validate + expected_type: Expected type for the state + + Returns: + Validated state instance + + Raises: + TypeError: If state doesn't match expected type + ValueError: If state validation fails + """ + if expected_type == dict: + if not isinstance(state, dict): + raise TypeError(f"Expected dict, got {type(state).__name__}") + return cast(StateT, state) + if isinstance(expected_type, type) and issubclass(expected_type, BaseModel): + if not isinstance(state, expected_type): + raise TypeError(f"Expected {expected_type.__name__}, got {type(state).__name__}") + return cast(StateT, state) + raise TypeError(f"Invalid expected_type: {expected_type}") def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable: @@ -326,21 +369,27 @@ class FlowMeta(type): routers = set() for attr_name, attr_value in dct.items(): - if hasattr(attr_value, "__is_start_method__"): - start_methods.append(attr_name) + # Check for any flow-related attributes + if (hasattr(attr_value, "__is_flow_method__") or + hasattr(attr_value, "__is_start_method__") or + hasattr(attr_value, "__trigger_methods__") or + hasattr(attr_value, "__is_router__")): + + # Register start methods + if hasattr(attr_value, "__is_start_method__"): + start_methods.append(attr_name) + + # Register listeners and routers if hasattr(attr_value, "__trigger_methods__"): methods = attr_value.__trigger_methods__ condition_type = getattr(attr_value, "__condition_type__", "OR") listeners[attr_name] = (condition_type, methods) - elif hasattr(attr_value, "__trigger_methods__"): - methods = attr_value.__trigger_methods__ - condition_type = getattr(attr_value, "__condition_type__", "OR") - listeners[attr_name] = (condition_type, methods) - if hasattr(attr_value, "__is_router__") and attr_value.__is_router__: - routers.add(attr_name) - possible_returns = get_possible_return_constants(attr_value) - if possible_returns: - router_paths[attr_name] = possible_returns + + if hasattr(attr_value, "__is_router__") and attr_value.__is_router__: + routers.add(attr_name) + possible_returns = get_possible_return_constants(attr_value) + if possible_returns: + router_paths[attr_name] = possible_returns setattr(cls, "_start_methods", start_methods) setattr(cls, "_listeners", listeners) @@ -351,6 +400,9 @@ class FlowMeta(type): class Flow(Generic[T], metaclass=FlowMeta): + """Base class for all flows. + + Type parameter T must be either Dict[str, Any] or a subclass of BaseModel.""" _telemetry = Telemetry() _start_methods: List[str] = [] @@ -367,53 +419,220 @@ class Flow(Generic[T], metaclass=FlowMeta): _FlowGeneric.__name__ = f"{cls.__name__}[{item.__name__}]" return _FlowGeneric - def __init__(self) -> None: + def __init__( + self, + persistence: Optional[FlowPersistence] = None, + restore_uuid: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Initialize a new Flow instance. 
+ + Args: + persistence: Optional persistence backend for storing flow states + restore_uuid: Optional UUID to restore state from persistence + **kwargs: Additional state values to initialize or override + """ + # Initialize basic instance attributes self._methods: Dict[str, Callable] = {} - self._state: T = self._create_initial_state() self._method_execution_counts: Dict[str, int] = {} self._pending_and_listeners: Dict[str, Set[str]] = {} self._method_outputs: List[Any] = [] # List to store all method outputs + self._persistence: Optional[FlowPersistence] = persistence + + # Validate state model before initialization + if isinstance(self.initial_state, type): + if issubclass(self.initial_state, BaseModel) and not issubclass(self.initial_state, FlowState): + # Check if model has id field + model_fields = getattr(self.initial_state, "model_fields", None) + if not model_fields or "id" not in model_fields: + raise ValueError("Flow state model must have an 'id' field") + + # Handle persistence and potential ID conflicts + stored_state = None + if self._persistence is not None: + if restore_uuid and kwargs and "id" in kwargs and restore_uuid != kwargs["id"]: + raise ValueError( + f"Conflicting IDs provided: restore_uuid='{restore_uuid}' " + f"vs kwargs['id']='{kwargs['id']}'. Use only one ID for restoration." + ) + + # Attempt to load state, prioritizing restore_uuid + if restore_uuid: + stored_state = self._persistence.load_state(restore_uuid) + if not stored_state: + raise ValueError(f"No state found for restore_uuid='{restore_uuid}'") + elif kwargs and "id" in kwargs: + stored_state = self._persistence.load_state(kwargs["id"]) + if not stored_state: + # For kwargs["id"], we allow creating new state if not found + self._state = self._create_initial_state() + if kwargs: + self._initialize_state(kwargs) + return + + # Initialize state based on persistence and kwargs + if stored_state: + # Create initial state and restore from persistence + self._state = self._create_initial_state() + self._restore_state(stored_state) + # Apply any additional kwargs to override specific fields + if kwargs: + filtered_kwargs = {k: v for k, v in kwargs.items() if k != "id"} + if filtered_kwargs: + self._initialize_state(filtered_kwargs) + else: + # No stored state, create new state with initial values + self._state = self._create_initial_state() + # Apply any additional kwargs + if kwargs: + self._initialize_state(kwargs) self._telemetry.flow_creation_span(self.__class__.__name__) + # Register all flow-related methods for method_name in dir(self): - if callable(getattr(self, method_name)) and not method_name.startswith( - "__" - ): - self._methods[method_name] = getattr(self, method_name) + if not method_name.startswith("_"): + method = getattr(self, method_name) + # Check for any flow-related attributes + if (hasattr(method, "__is_flow_method__") or + hasattr(method, "__is_start_method__") or + hasattr(method, "__trigger_methods__") or + hasattr(method, "__is_router__")): + # Ensure method is bound to this instance + if not hasattr(method, "__self__"): + method = method.__get__(self, self.__class__) + self._methods[method_name] = method + + def _create_initial_state(self) -> T: + """Create and initialize flow state with UUID and default values. 
+ + Returns: + New state instance with UUID and default values initialized + + Raises: + ValueError: If structured state model lacks 'id' field + TypeError: If state is neither BaseModel nor dictionary + """ # Handle case where initial_state is None but we have a type parameter if self.initial_state is None and hasattr(self, "_initial_state_T"): state_type = getattr(self, "_initial_state_T") if isinstance(state_type, type): if issubclass(state_type, FlowState): - return state_type() # type: ignore + # Create instance without id, then set it + instance = state_type() + if not hasattr(instance, 'id'): + setattr(instance, 'id', str(uuid4())) + return cast(T, instance) elif issubclass(state_type, BaseModel): # Create a new type that includes the ID field class StateWithId(state_type, FlowState): # type: ignore pass - return StateWithId() # type: ignore + instance = StateWithId() + if not hasattr(instance, 'id'): + setattr(instance, 'id', str(uuid4())) + return cast(T, instance) + elif state_type == dict: + return cast(T, {"id": str(uuid4())}) # Minimal dict state + + # Handle case where no initial state is provided + if self.initial_state is None: + return cast(T, {"id": str(uuid4())}) + + # Handle case where initial_state is a type (class) + if isinstance(self.initial_state, type): + if issubclass(self.initial_state, FlowState): + return cast(T, self.initial_state()) # Uses model defaults + elif issubclass(self.initial_state, BaseModel): + # Validate that the model has an id field + model_fields = getattr(self.initial_state, "model_fields", None) + if not model_fields or "id" not in model_fields: + raise ValueError("Flow state model must have an 'id' field") + return cast(T, self.initial_state()) # Uses model defaults + elif self.initial_state == dict: + return cast(T, {"id": str(uuid4())}) + + # Handle dictionary instance case + if isinstance(self.initial_state, dict): + new_state = dict(self.initial_state) # Copy to avoid mutations + if "id" not in new_state: + new_state["id"] = str(uuid4()) + return cast(T, new_state) + + # Handle BaseModel instance case + if isinstance(self.initial_state, BaseModel): + model = cast(BaseModel, self.initial_state) + if not hasattr(model, "id"): + raise ValueError("Flow state model must have an 'id' field") + + # Create new instance with same values to avoid mutations + if hasattr(model, "model_dump"): + # Pydantic v2 + state_dict = model.model_dump() + elif hasattr(model, "dict"): + # Pydantic v1 + state_dict = model.dict() + else: + # Fallback for other BaseModel implementations + state_dict = { + k: v for k, v in model.__dict__.items() + if not k.startswith("_") + } + + # Create new instance of the same class + model_class = type(model) + return cast(T, model_class(**state_dict)) + + raise TypeError( + f"Initial state must be dict or BaseModel, got {type(self.initial_state)}" + ) + # Handle case where initial_state is None but we have a type parameter + if self.initial_state is None and hasattr(self, "_initial_state_T"): + state_type = getattr(self, "_initial_state_T") + if isinstance(state_type, type): + if issubclass(state_type, FlowState): + return cast(T, state_type()) + elif issubclass(state_type, BaseModel): + # Create a new type that includes the ID field + class StateWithId(state_type, FlowState): # type: ignore + pass + return cast(T, StateWithId()) + elif state_type == dict: + return cast(T, {"id": str(uuid4())}) # Handle case where no initial state is provided if self.initial_state is None: - return {"id": str(uuid4())} # type: ignore + return 
cast(T, {"id": str(uuid4())}) # Handle case where initial_state is a type (class) if isinstance(self.initial_state, type): if issubclass(self.initial_state, FlowState): - return self.initial_state() # type: ignore + return cast(T, self.initial_state()) elif issubclass(self.initial_state, BaseModel): - # Create a new type that includes the ID field - class StateWithId(self.initial_state, FlowState): # type: ignore - pass - return StateWithId() # type: ignore + # Validate that the model has an id field + model_fields = getattr(self.initial_state, "model_fields", None) + if not model_fields or "id" not in model_fields: + raise ValueError("Flow state model must have an 'id' field") + return cast(T, self.initial_state()) + elif self.initial_state == dict: + return cast(T, {"id": str(uuid4())}) - # Handle dictionary case - if isinstance(self.initial_state, dict) and "id" not in self.initial_state: - self.initial_state["id"] = str(uuid4()) + # Handle dictionary instance case + if isinstance(self.initial_state, dict): + if "id" not in self.initial_state: + self.initial_state["id"] = str(uuid4()) + return cast(T, dict(self.initial_state)) # Create new dict to avoid mutations - return self.initial_state # type: ignore + # Handle BaseModel instance case + if isinstance(self.initial_state, BaseModel): + if not hasattr(self.initial_state, "id"): + raise ValueError("Flow state model must have an 'id' field") + return cast(T, self.initial_state) + + raise TypeError( + f"Initial state must be dict or BaseModel, got {type(self.initial_state)}" + ) @property def state(self) -> T: @@ -425,50 +644,95 @@ class Flow(Generic[T], metaclass=FlowMeta): return self._method_outputs def _initialize_state(self, inputs: Dict[str, Any]) -> None: + """Initialize or update flow state with new inputs. 
+ + Args: + inputs: Dictionary of state values to set/update + + Raises: + ValueError: If validation fails for structured state + TypeError: If state is neither BaseModel nor dictionary + """ if isinstance(self._state, dict): - # Preserve the ID when updating unstructured state + # For dict states, preserve existing fields unless overridden current_id = self._state.get("id") - self._state.update(inputs) + # Only update specified fields + for k, v in inputs.items(): + self._state[k] = v + # Ensure ID is preserved or generated if current_id: self._state["id"] = current_id elif "id" not in self._state: self._state["id"] = str(uuid4()) elif isinstance(self._state, BaseModel): - # Structured state + # For BaseModel states, preserve existing fields unless overridden try: - def create_model_with_extra_forbid( - base_model: Type[BaseModel], - ) -> Type[BaseModel]: - class ModelWithExtraForbid(base_model): # type: ignore - model_config = base_model.model_config.copy() - model_config["extra"] = "forbid" - - return ModelWithExtraForbid - - # Get current state as dict, preserving the ID if it exists - state_model = cast(BaseModel, self._state) - current_state = ( - state_model.model_dump() - if hasattr(state_model, "model_dump") - else state_model.dict() - if hasattr(state_model, "dict") - else { - k: v - for k, v in state_model.__dict__.items() + model = cast(BaseModel, self._state) + # Get current state as dict + if hasattr(model, "model_dump"): + current_state = model.model_dump() + elif hasattr(model, "dict"): + current_state = model.dict() + else: + current_state = { + k: v for k, v in model.__dict__.items() if not k.startswith("_") } - ) - - ModelWithExtraForbid = create_model_with_extra_forbid( - self._state.__class__ - ) - self._state = cast( - T, ModelWithExtraForbid(**{**current_state, **inputs}) - ) + + # Create new state with preserved fields and updates + new_state = {**current_state, **inputs} + + # Create new instance with merged state + model_class = type(model) + if hasattr(model_class, "model_validate"): + # Pydantic v2 + self._state = cast(T, model_class.model_validate(new_state)) + elif hasattr(model_class, "parse_obj"): + # Pydantic v1 + self._state = cast(T, model_class.parse_obj(new_state)) + else: + # Fallback for other BaseModel implementations + self._state = cast(T, model_class(**new_state)) except ValidationError as e: raise ValueError(f"Invalid inputs for structured state: {e}") from e else: raise TypeError("State must be a BaseModel instance or a dictionary.") + + def _restore_state(self, stored_state: Dict[str, Any]) -> None: + """Restore flow state from persistence. 
+ + Args: + stored_state: Previously stored state to restore + + Raises: + ValueError: If validation fails for structured state + TypeError: If state is neither BaseModel nor dictionary + """ + # When restoring from persistence, use the stored ID + stored_id = stored_state.get("id") + if not stored_id: + raise ValueError("Stored state must have an 'id' field") + + if isinstance(self._state, dict): + # For dict states, update all fields from stored state + self._state.clear() + self._state.update(stored_state) + elif isinstance(self._state, BaseModel): + # For BaseModel states, create new instance with stored values + model = cast(BaseModel, self._state) + if hasattr(model, "model_validate"): + # Pydantic v2 + self._state = cast(T, type(model).model_validate(stored_state)) + elif hasattr(model, "parse_obj"): + # Pydantic v1 + self._state = cast(T, type(model).parse_obj(stored_state)) + else: + # Fallback for other BaseModel implementations + self._state = cast(T, type(model)(**stored_state)) + else: + raise TypeError( + f"State must be dict or BaseModel, got {type(self._state)}" + ) def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any: self.event_emitter.send( diff --git a/src/crewai/flow/persistence/__init__.py b/src/crewai/flow/persistence/__init__.py new file mode 100644 index 000000000..0b673f6bf --- /dev/null +++ b/src/crewai/flow/persistence/__init__.py @@ -0,0 +1,18 @@ +""" +CrewAI Flow Persistence. + +This module provides interfaces and implementations for persisting flow states. +""" + +from typing import Any, Dict, TypeVar, Union + +from pydantic import BaseModel + +from crewai.flow.persistence.base import FlowPersistence +from crewai.flow.persistence.decorators import persist +from crewai.flow.persistence.sqlite import SQLiteFlowPersistence + +__all__ = ["FlowPersistence", "persist", "SQLiteFlowPersistence"] + +StateType = TypeVar('StateType', bound=Union[Dict[str, Any], BaseModel]) +DictStateType = Dict[str, Any] diff --git a/src/crewai/flow/persistence/base.py b/src/crewai/flow/persistence/base.py new file mode 100644 index 000000000..c926f6f34 --- /dev/null +++ b/src/crewai/flow/persistence/base.py @@ -0,0 +1,53 @@ +"""Base class for flow state persistence.""" + +import abc +from typing import Any, Dict, Optional, Union + +from pydantic import BaseModel + + +class FlowPersistence(abc.ABC): + """Abstract base class for flow state persistence. + + This class defines the interface that all persistence implementations must follow. + It supports both structured (Pydantic BaseModel) and unstructured (dict) states. + """ + + @abc.abstractmethod + def init_db(self) -> None: + """Initialize the persistence backend. + + This method should handle any necessary setup, such as: + - Creating tables + - Establishing connections + - Setting up indexes + """ + pass + + @abc.abstractmethod + def save_state( + self, + flow_uuid: str, + method_name: str, + state_data: Union[Dict[str, Any], BaseModel] + ) -> None: + """Persist the flow state after method completion. + + Args: + flow_uuid: Unique identifier for the flow instance + method_name: Name of the method that just completed + state_data: Current state data (either dict or Pydantic model) + """ + pass + + @abc.abstractmethod + def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]: + """Load the most recent state for a given flow UUID. 
+ + Args: + flow_uuid: Unique identifier for the flow instance + + Returns: + The most recent state as a dictionary, or None if no state exists + """ + pass diff --git a/src/crewai/flow/persistence/decorators.py b/src/crewai/flow/persistence/decorators.py new file mode 100644 index 000000000..4906e95d5 --- /dev/null +++ b/src/crewai/flow/persistence/decorators.py @@ -0,0 +1,177 @@ +""" +Decorators for flow state persistence. + +Example: + ```python + from crewai.flow.flow import Flow, start + from crewai.flow.persistence import persist, SQLiteFlowPersistence + + class MyFlow(Flow): + @start() + @persist(SQLiteFlowPersistence()) + def sync_method(self): + # Synchronous method implementation + pass + + @start() + @persist(SQLiteFlowPersistence()) + async def async_method(self): + # Asynchronous method implementation + await some_async_operation() + ``` +""" + +import asyncio +import functools +import inspect +import logging +from typing import ( + Any, + Callable, + Dict, + Optional, + Type, + TypeVar, + Union, + cast, + get_type_hints, +) + +from pydantic import BaseModel + +from crewai.flow.persistence.base import FlowPersistence +from crewai.flow.persistence.sqlite import SQLiteFlowPersistence + +logger = logging.getLogger(__name__) +T = TypeVar("T") + + +def persist(persistence: Optional[FlowPersistence] = None): + """Decorator to persist flow state. + + This decorator can be applied at either the class level or method level. + When applied at the class level, it automatically persists all flow method + states. When applied at the method level, it persists only that method's + state. + + Args: + persistence: Optional FlowPersistence implementation to use. + If not provided, uses SQLiteFlowPersistence. + + Returns: + A decorator that can be applied to either a class or method + + Raises: + ValueError: If the flow state doesn't have an 'id' field + RuntimeError: If state persistence fails + + Example: + @persist # Class-level persistence with default SQLite + class MyFlow(Flow[MyState]): + @start() + def begin(self): + pass + """ + def _persist_state(flow_instance: Any, method_name: str, persistence_instance: FlowPersistence) -> None: + """Helper to persist state with error handling.""" + try: + # Get flow UUID from state + state = getattr(flow_instance, 'state', None) + if state is None: + raise ValueError("Flow instance has no state") + + flow_uuid: Optional[str] = None + if isinstance(state, dict): + flow_uuid = state.get('id') + elif isinstance(state, BaseModel): + flow_uuid = getattr(state, 'id', None) + + if not flow_uuid: + raise ValueError( + "Flow state must have an 'id' field for persistence" + ) + + # Persist the state + persistence_instance.save_state( + flow_uuid=flow_uuid, + method_name=method_name, + state_data=state, + ) + except Exception as e: + logger.error( + f"Failed to persist state for method {method_name}: {str(e)}" + ) + raise RuntimeError(f"State persistence failed: {str(e)}") from e + + def decorator(target: Union[Type, Callable[..., T]]) -> Union[Type, Callable[..., T]]: + """Decorator that handles both class and method decoration.""" + actual_persistence = persistence or SQLiteFlowPersistence() + + if isinstance(target, type): + # Class decoration + class_methods = {} + for name, method in target.__dict__.items(): + if callable(method) and hasattr(method, "__is_flow_method__"): + # Wrap each flow method with persistence + if asyncio.iscoroutinefunction(method): + @functools.wraps(method) + async def class_async_wrapper(self: Any, *args: Any, **kwargs: Any) -> 
Any: + method_coro = method(self, *args, **kwargs) + if asyncio.iscoroutine(method_coro): + result = await method_coro + else: + result = method_coro + _persist_state(self, method.__name__, actual_persistence) + return result + class_methods[name] = class_async_wrapper + else: + @functools.wraps(method) + def class_sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + result = method(self, *args, **kwargs) + _persist_state(self, method.__name__, actual_persistence) + return result + class_methods[name] = class_sync_wrapper + + # Preserve flow-specific attributes + for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]: + if hasattr(method, attr): + setattr(class_methods[name], attr, getattr(method, attr)) + setattr(class_methods[name], "__is_flow_method__", True) + + # Update class with wrapped methods + for name, method in class_methods.items(): + setattr(target, name, method) + return target + else: + # Method decoration + method = target + setattr(method, "__is_flow_method__", True) + + if asyncio.iscoroutinefunction(method): + @functools.wraps(method) + async def method_async_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T: + method_coro = method(flow_instance, *args, **kwargs) + if asyncio.iscoroutine(method_coro): + result = await method_coro + else: + result = method_coro + _persist_state(flow_instance, method.__name__, actual_persistence) + return result + for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]: + if hasattr(method, attr): + setattr(method_async_wrapper, attr, getattr(method, attr)) + setattr(method_async_wrapper, "__is_flow_method__", True) + return cast(Callable[..., T], method_async_wrapper) + else: + @functools.wraps(method) + def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T: + result = method(flow_instance, *args, **kwargs) + _persist_state(flow_instance, method.__name__, actual_persistence) + return result + for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]: + if hasattr(method, attr): + setattr(method_sync_wrapper, attr, getattr(method, attr)) + setattr(method_sync_wrapper, "__is_flow_method__", True) + return cast(Callable[..., T], method_sync_wrapper) + + return decorator diff --git a/src/crewai/flow/persistence/sqlite.py b/src/crewai/flow/persistence/sqlite.py new file mode 100644 index 000000000..bdd091b2b --- /dev/null +++ b/src/crewai/flow/persistence/sqlite.py @@ -0,0 +1,124 @@ +""" +SQLite-based implementation of flow state persistence. +""" + +import json +import os +import sqlite3 +import tempfile +from datetime import datetime +from typing import Any, Dict, Optional, Union + +from pydantic import BaseModel + +from crewai.flow.persistence.base import FlowPersistence + + +class SQLiteFlowPersistence(FlowPersistence): + """SQLite-based implementation of flow state persistence. + + This class provides a simple, file-based persistence implementation using SQLite. + It's suitable for development and testing, or for production use cases with + moderate performance requirements. + """ + + db_path: str # Type annotation for instance variable + + def __init__(self, db_path: Optional[str] = None): + """Initialize SQLite persistence. + + Args: + db_path: Path to the SQLite database file. If not provided, uses + db_storage_path() from utilities.paths. 
+ + Raises: + ValueError: If db_path is invalid + """ + from crewai.utilities.paths import db_storage_path + # Get path from argument or default location + path = db_path or db_storage_path() + + if not path: + raise ValueError("Database path must be provided") + + self.db_path = path # Now mypy knows this is str + self.init_db() + + def init_db(self) -> None: + """Create the necessary tables if they don't exist.""" + with sqlite3.connect(self.db_path) as conn: + conn.execute(""" + CREATE TABLE IF NOT EXISTS flow_states ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + flow_uuid TEXT NOT NULL, + method_name TEXT NOT NULL, + timestamp DATETIME NOT NULL, + state_json TEXT NOT NULL + ) + """) + # Add index for faster UUID lookups + conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_flow_states_uuid + ON flow_states(flow_uuid) + """) + + def save_state( + self, + flow_uuid: str, + method_name: str, + state_data: Union[Dict[str, Any], BaseModel], + ) -> None: + """Save the current flow state to SQLite. + + Args: + flow_uuid: Unique identifier for the flow instance + method_name: Name of the method that just completed + state_data: Current state data (either dict or Pydantic model) + """ + # Convert state_data to dict, handling both Pydantic and dict cases + if isinstance(state_data, BaseModel): + state_dict = dict(state_data) # Use dict() for better type compatibility + elif isinstance(state_data, dict): + state_dict = state_data + else: + raise ValueError( + f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}" + ) + + with sqlite3.connect(self.db_path) as conn: + conn.execute(""" + INSERT INTO flow_states ( + flow_uuid, + method_name, + timestamp, + state_json + ) VALUES (?, ?, ?, ?) + """, ( + flow_uuid, + method_name, + datetime.utcnow().isoformat(), + json.dumps(state_dict), + )) + + def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]: + """Load the most recent state for a given flow UUID. + + Args: + flow_uuid: Unique identifier for the flow instance + + Returns: + The most recent state as a dictionary, or None if no state exists + """ + with sqlite3.connect(self.db_path) as conn: + cursor = conn.execute(""" + SELECT state_json + FROM flow_states + WHERE flow_uuid = ? + ORDER BY id DESC + LIMIT 1 + """, (flow_uuid,)) + row = cursor.fetchone() + + if row: + return json.loads(row[0]) + return None diff --git a/src/crewai/utilities/paths.py b/src/crewai/utilities/paths.py index 9bf167ee6..5d91d1719 100644 --- a/src/crewai/utilities/paths.py +++ b/src/crewai/utilities/paths.py @@ -5,14 +5,18 @@ import appdirs """Path management utilities for CrewAI storage and configuration.""" -def db_storage_path(): - """Returns the path for database storage.""" +def db_storage_path() -> str: + """Returns the path for SQLite database storage. 
+ + Returns: + str: Full path to the SQLite database file + """ app_name = get_project_directory_name() app_author = "CrewAI" data_dir = Path(appdirs.user_data_dir(app_name, app_author)) data_dir.mkdir(parents=True, exist_ok=True) - return data_dir + return str(data_dir / "crewai_flows.db") def get_project_directory_name(): diff --git a/tests/cassettes/test_agent_human_input.yaml b/tests/cassettes/test_agent_human_input.yaml index 16b9ac9a5..8c5fd3a80 100644 --- a/tests/cassettes/test_agent_human_input.yaml +++ b/tests/cassettes/test_agent_human_input.yaml @@ -1,4 +1,87 @@ interactions: +- request: + body: !!binary | + CqcXCiQKIgoMc2VydmljZS5uYW1lEhIKEGNyZXdBSS10ZWxlbWV0cnkS/hYKEgoQY3Jld2FpLnRl + bGVtZXRyeRJ5ChBuJJtOdNaB05mOW/p3915eEgj2tkAd3rZcASoQVG9vbCBVc2FnZSBFcnJvcjAB + OYa7/URvKBUYQUpcFEVvKBUYShoKDmNyZXdhaV92ZXJzaW9uEggKBjAuODYuMEoPCgNsbG0SCAoG + Z3B0LTRvegIYAYUBAAEAABLJBwoQifhX01E5i+5laGdALAlZBBIIBuGM1aN+OPgqDENyZXcgQ3Jl + YXRlZDABORVGruBvKBUYQaipwOBvKBUYShoKDmNyZXdhaV92ZXJzaW9uEggKBjAuODYuMEoaCg5w + eXRob25fdmVyc2lvbhIICgYzLjEyLjdKLgoIY3Jld19rZXkSIgogN2U2NjA4OTg5ODU5YTY3ZWVj + ODhlZWY3ZmNlODUyMjVKMQoHY3Jld19pZBImCiRiOThiNWEwMC01YTI1LTQxMDctYjQwNS1hYmYz + MjBhOGYzYThKHAoMY3Jld19wcm9jZXNzEgwKCnNlcXVlbnRpYWxKEQoLY3Jld19tZW1vcnkSAhAA + ShoKFGNyZXdfbnVtYmVyX29mX3Rhc2tzEgIYAUobChVjcmV3X251bWJlcl9vZl9hZ2VudHMSAhgB + SuQCCgtjcmV3X2FnZW50cxLUAgrRAlt7ImtleSI6ICIyMmFjZDYxMWU0NGVmNWZhYzA1YjUzM2Q3 + NWU4ODkzYiIsICJpZCI6ICJkNWIyMzM1YS0yMmIyLTQyZWEtYmYwNS03OTc3NmU3MmYzOTIiLCAi + cm9sZSI6ICJEYXRhIFNjaWVudGlzdCIsICJ2ZXJib3NlPyI6IGZhbHNlLCAibWF4X2l0ZXIiOiAy + MCwgIm1heF9ycG0iOiBudWxsLCAiZnVuY3Rpb25fY2FsbGluZ19sbG0iOiAiIiwgImxsbSI6ICJn + cHQtNG8tbWluaSIsICJkZWxlZ2F0aW9uX2VuYWJsZWQ/IjogZmFsc2UsICJhbGxvd19jb2RlX2V4 + ZWN1dGlvbj8iOiBmYWxzZSwgIm1heF9yZXRyeV9saW1pdCI6IDIsICJ0b29sc19uYW1lcyI6IFsi + Z2V0IGdyZWV0aW5ncyJdfV1KkgIKCmNyZXdfdGFza3MSgwIKgAJbeyJrZXkiOiAiYTI3N2IzNGIy + YzE0NmYwYzU2YzVlMTM1NmU4ZjhhNTciLCAiaWQiOiAiMjJiZWMyMzEtY2QyMS00YzU4LTgyN2Ut + MDU4MWE4ZjBjMTExIiwgImFzeW5jX2V4ZWN1dGlvbj8iOiBmYWxzZSwgImh1bWFuX2lucHV0PyI6 + IGZhbHNlLCAiYWdlbnRfcm9sZSI6ICJEYXRhIFNjaWVudGlzdCIsICJhZ2VudF9rZXkiOiAiMjJh + Y2Q2MTFlNDRlZjVmYWMwNWI1MzNkNzVlODg5M2IiLCAidG9vbHNfbmFtZXMiOiBbImdldCBncmVl + dGluZ3MiXX1degIYAYUBAAEAABKOAgoQ5WYoxRtTyPjge4BduhL0rRIIv2U6rvWALfwqDFRhc2sg + Q3JlYXRlZDABOX068uBvKBUYQZkv8+BvKBUYSi4KCGNyZXdfa2V5EiIKIDdlNjYwODk4OTg1OWE2 + N2VlYzg4ZWVmN2ZjZTg1MjI1SjEKB2NyZXdfaWQSJgokYjk4YjVhMDAtNWEyNS00MTA3LWI0MDUt + YWJmMzIwYThmM2E4Si4KCHRhc2tfa2V5EiIKIGEyNzdiMzRiMmMxNDZmMGM1NmM1ZTEzNTZlOGY4 + YTU3SjEKB3Rhc2tfaWQSJgokMjJiZWMyMzEtY2QyMS00YzU4LTgyN2UtMDU4MWE4ZjBjMTExegIY + AYUBAAEAABKQAQoQXyeDtJDFnyp2Fjk9YEGTpxIIaNE7gbhPNYcqClRvb2wgVXNhZ2UwATkaXTvj + bygVGEGvx0rjbygVGEoaCg5jcmV3YWlfdmVyc2lvbhIICgYwLjg2LjBKHAoJdG9vbF9uYW1lEg8K + DUdldCBHcmVldGluZ3NKDgoIYXR0ZW1wdHMSAhgBegIYAYUBAAEAABLVBwoQMWfznt0qwauEzl7T + UOQxRBII9q+pUS5EdLAqDENyZXcgQ3JlYXRlZDABORONPORvKBUYQSAoS+RvKBUYShoKDmNyZXdh + aV92ZXJzaW9uEggKBjAuODYuMEoaCg5weXRob25fdmVyc2lvbhIICgYzLjEyLjdKLgoIY3Jld19r + ZXkSIgogYzMwNzYwMDkzMjY3NjE0NDRkNTdjNzFkMWRhM2YyN2NKMQoHY3Jld19pZBImCiQ3OTQw + MTkyNS1iOGU5LTQ3MDgtODUzMC00NDhhZmEzYmY4YjBKHAoMY3Jld19wcm9jZXNzEgwKCnNlcXVl + bnRpYWxKEQoLY3Jld19tZW1vcnkSAhAAShoKFGNyZXdfbnVtYmVyX29mX3Rhc2tzEgIYAUobChVj + cmV3X251bWJlcl9vZl9hZ2VudHMSAhgBSuoCCgtjcmV3X2FnZW50cxLaAgrXAlt7ImtleSI6ICI5 + OGYzYjFkNDdjZTk2OWNmMDU3NzI3Yjc4NDE0MjVjZCIsICJpZCI6ICI5OTJkZjYyZi1kY2FiLTQy + OTUtOTIwNi05MDBkNDExNGIxZTkiLCAicm9sZSI6ICJGcmllbmRseSBOZWlnaGJvciIsICJ2ZXJi + b3NlPyI6IGZhbHNlLCAibWF4X2l0ZXIiOiAyMCwgIm1heF9ycG0iOiBudWxsLCAiZnVuY3Rpb25f + 
Y2FsbGluZ19sbG0iOiAiIiwgImxsbSI6ICJncHQtNG8tbWluaSIsICJkZWxlZ2F0aW9uX2VuYWJs + ZWQ/IjogZmFsc2UsICJhbGxvd19jb2RlX2V4ZWN1dGlvbj8iOiBmYWxzZSwgIm1heF9yZXRyeV9s + aW1pdCI6IDIsICJ0b29sc19uYW1lcyI6IFsiZGVjaWRlIGdyZWV0aW5ncyJdfV1KmAIKCmNyZXdf + dGFza3MSiQIKhgJbeyJrZXkiOiAiODBkN2JjZDQ5MDk5MjkwMDgzODMyZjBlOTgzMzgwZGYiLCAi + aWQiOiAiMmZmNjE5N2UtYmEyNy00YjczLWI0YTctNGZhMDQ4ZTYyYjQ3IiwgImFzeW5jX2V4ZWN1 + dGlvbj8iOiBmYWxzZSwgImh1bWFuX2lucHV0PyI6IGZhbHNlLCAiYWdlbnRfcm9sZSI6ICJGcmll + bmRseSBOZWlnaGJvciIsICJhZ2VudF9rZXkiOiAiOThmM2IxZDQ3Y2U5NjljZjA1NzcyN2I3ODQx + NDI1Y2QiLCAidG9vbHNfbmFtZXMiOiBbImRlY2lkZSBncmVldGluZ3MiXX1degIYAYUBAAEAABKO + AgoQnjTp5boK7/+DQxztYIpqihIIgGnMUkBtzHEqDFRhc2sgQ3JlYXRlZDABOcpYcuRvKBUYQalE + c+RvKBUYSi4KCGNyZXdfa2V5EiIKIGMzMDc2MDA5MzI2NzYxNDQ0ZDU3YzcxZDFkYTNmMjdjSjEK + B2NyZXdfaWQSJgokNzk0MDE5MjUtYjhlOS00NzA4LTg1MzAtNDQ4YWZhM2JmOGIwSi4KCHRhc2tf + a2V5EiIKIDgwZDdiY2Q0OTA5OTI5MDA4MzgzMmYwZTk4MzM4MGRmSjEKB3Rhc2tfaWQSJgokMmZm + NjE5N2UtYmEyNy00YjczLWI0YTctNGZhMDQ4ZTYyYjQ3egIYAYUBAAEAABKTAQoQ26H9pLUgswDN + p9XhJwwL6BIIx3bw7mAvPYwqClRvb2wgVXNhZ2UwATmy7NPlbygVGEEvb+HlbygVGEoaCg5jcmV3 + YWlfdmVyc2lvbhIICgYwLjg2LjBKHwoJdG9vbF9uYW1lEhIKEERlY2lkZSBHcmVldGluZ3NKDgoI + YXR0ZW1wdHMSAhgBegIYAYUBAAEAAA== + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '2986' + Content-Type: + - application/x-protobuf + User-Agent: + - OTel-OTLP-Exporter-Python/1.27.0 + method: POST + uri: https://telemetry.crewai.com:4319/v1/traces + response: + body: + string: "\n\0" + headers: + Content-Length: + - '2' + Content-Type: + - application/x-protobuf + Date: + - Fri, 27 Dec 2024 22:14:53 GMT + status: + code: 200 + message: OK - request: body: '{"messages": [{"role": "system", "content": "You are test role. 
test backstory\nYour personal goal is: test goal\nTo give my best complete final answer to the task @@ -22,18 +105,20 @@ interactions: - '824' content-type: - application/json + cookie: + - _cfuvid=ePJSDFdHag2D8lj21_ijAMWjoA6xfnPNxN4uekvC728-1727226247743-0.0.1.1-604800000 host: - api.openai.com user-agent: - OpenAI/Python 1.52.1 x-stainless-arch: - - arm64 + - x64 x-stainless-async: - 'false' x-stainless-lang: - python x-stainless-os: - - MacOS + - Linux x-stainless-package-version: - 1.52.1 x-stainless-raw-response: @@ -47,8 +132,8 @@ interactions: method: POST uri: https://api.openai.com/v1/chat/completions response: - content: "{\n \"id\": \"chatcmpl-AaqIIsTxhvf75xvuu7gQScIlRSKbW\",\n \"object\": - \"chat.completion\",\n \"created\": 1733344190,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n + content: "{\n \"id\": \"chatcmpl-AjCtZLLrWi8ZASpP9bz6HaCV7xBIn\",\n \"object\": + \"chat.completion\",\n \"created\": 1735337693,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n \ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\": \"assistant\",\n \"content\": \"I now can give a great answer \\nFinal Answer: Hi\",\n \"refusal\": null\n },\n \"logprobs\": null,\n @@ -57,12 +142,12 @@ interactions: {\n \"cached_tokens\": 0,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\": {\n \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\": 0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"system_fingerprint\": - \"fp_0705bf87c0\"\n}\n" + \"fp_0aa8d3e20b\"\n}\n" headers: CF-Cache-Status: - DYNAMIC CF-RAY: - - 8ece8cfc3b1f4532-ATL + - 8f8caa83deca756b-SEA Connection: - keep-alive Content-Encoding: @@ -70,14 +155,14 @@ interactions: Content-Type: - application/json Date: - - Wed, 04 Dec 2024 20:29:50 GMT + - Fri, 27 Dec 2024 22:14:53 GMT Server: - cloudflare Set-Cookie: - - __cf_bm=QJZZjZ6eqnVamqUkw.Bx0mj7oBi3a_vGEH1VODcUxlg-1733344190-1.0.1.1-xyN0ekA9xIrSwEhRBmTiWJ3Pt72UYLU5owKfkz5yihVmMTfsr_Qz.ssGPJ5cuft066v1xVjb4zOSTdFmesMSKg; - path=/; expires=Wed, 04-Dec-24 20:59:50 GMT; domain=.api.openai.com; HttpOnly; + - __cf_bm=wJkq_yLkzE3OdxE0aMJz.G0kce969.9JxRmZ0ratl4c-1735337693-1.0.1.1-OKpUoRrSPFGvWv5Hp5ET1PNZ7iZNHPKEAuakpcQUxxPSeisUIIR3qIOZ31MGmYugqB5.wkvidgbxOAagqJvmnw; + path=/; expires=Fri, 27-Dec-24 22:44:53 GMT; domain=.api.openai.com; HttpOnly; Secure; SameSite=None - - _cfuvid=eCIkP8GVPvpkg19eOhCquWFHm.RTQBQy4yHLGGEAH5c-1733344190334-0.0.1.1-604800000; + - _cfuvid=A_ASCLNAVfQoyucWOAIhecWtEpNotYoZr0bAFihgNxs-1735337693273-0.0.1.1-604800000; path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None Transfer-Encoding: - chunked @@ -90,7 +175,7 @@ interactions: openai-organization: - crewai-iuxna1 openai-processing-ms: - - '313' + - '404' openai-version: - '2020-10-01' strict-transport-security: @@ -108,7 +193,7 @@ interactions: x-ratelimit-reset-tokens: - 0s x-request-id: - - req_9fd9a8ee688045dcf7ac5f6fdf689372 + - req_6ac84634bff9193743c4b0911c09b4a6 http_version: HTTP/1.1 status_code: 200 - request: @@ -131,20 +216,20 @@ interactions: content-type: - application/json cookie: - - __cf_bm=QJZZjZ6eqnVamqUkw.Bx0mj7oBi3a_vGEH1VODcUxlg-1733344190-1.0.1.1-xyN0ekA9xIrSwEhRBmTiWJ3Pt72UYLU5owKfkz5yihVmMTfsr_Qz.ssGPJ5cuft066v1xVjb4zOSTdFmesMSKg; - _cfuvid=eCIkP8GVPvpkg19eOhCquWFHm.RTQBQy4yHLGGEAH5c-1733344190334-0.0.1.1-604800000 + - _cfuvid=A_ASCLNAVfQoyucWOAIhecWtEpNotYoZr0bAFihgNxs-1735337693273-0.0.1.1-604800000; + 
__cf_bm=wJkq_yLkzE3OdxE0aMJz.G0kce969.9JxRmZ0ratl4c-1735337693-1.0.1.1-OKpUoRrSPFGvWv5Hp5ET1PNZ7iZNHPKEAuakpcQUxxPSeisUIIR3qIOZ31MGmYugqB5.wkvidgbxOAagqJvmnw host: - api.openai.com user-agent: - OpenAI/Python 1.52.1 x-stainless-arch: - - arm64 + - x64 x-stainless-async: - 'false' x-stainless-lang: - python x-stainless-os: - - MacOS + - Linux x-stainless-package-version: - 1.52.1 x-stainless-raw-response: @@ -158,8 +243,8 @@ interactions: method: POST uri: https://api.openai.com/v1/chat/completions response: - content: "{\n \"id\": \"chatcmpl-AaqIIaQlLyoyPmk909PvAIfA2TmJL\",\n \"object\": - \"chat.completion\",\n \"created\": 1733344190,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n + content: "{\n \"id\": \"chatcmpl-AjCtZNlWdrrPZhq0MJDqd16sMuQEJ\",\n \"object\": + \"chat.completion\",\n \"created\": 1735337693,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n \ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\": \"assistant\",\n \"content\": \"True\",\n \"refusal\": null\n \ },\n \"logprobs\": null,\n \"finish_reason\": \"stop\"\n }\n @@ -168,12 +253,12 @@ interactions: 0,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\": {\n \ \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\": 0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"system_fingerprint\": - \"fp_0705bf87c0\"\n}\n" + \"fp_0aa8d3e20b\"\n}\n" headers: CF-Cache-Status: - DYNAMIC CF-RAY: - - 8ece8d060b5e4532-ATL + - 8f8caa87094f756b-SEA Connection: - keep-alive Content-Encoding: @@ -181,7 +266,7 @@ interactions: Content-Type: - application/json Date: - - Wed, 04 Dec 2024 20:29:50 GMT + - Fri, 27 Dec 2024 22:14:53 GMT Server: - cloudflare Transfer-Encoding: @@ -195,7 +280,7 @@ interactions: openai-organization: - crewai-iuxna1 openai-processing-ms: - - '375' + - '156' openai-version: - '2020-10-01' strict-transport-security: @@ -213,7 +298,7 @@ interactions: x-ratelimit-reset-tokens: - 0s x-request-id: - - req_be7cb475e0859a82c37ee3f2871ea5ea + - req_ec74bef2a9ef7b2144c03fd7f7bbeab0 http_version: HTTP/1.1 status_code: 200 - request: @@ -242,20 +327,20 @@ interactions: content-type: - application/json cookie: - - __cf_bm=QJZZjZ6eqnVamqUkw.Bx0mj7oBi3a_vGEH1VODcUxlg-1733344190-1.0.1.1-xyN0ekA9xIrSwEhRBmTiWJ3Pt72UYLU5owKfkz5yihVmMTfsr_Qz.ssGPJ5cuft066v1xVjb4zOSTdFmesMSKg; - _cfuvid=eCIkP8GVPvpkg19eOhCquWFHm.RTQBQy4yHLGGEAH5c-1733344190334-0.0.1.1-604800000 + - _cfuvid=A_ASCLNAVfQoyucWOAIhecWtEpNotYoZr0bAFihgNxs-1735337693273-0.0.1.1-604800000; + __cf_bm=wJkq_yLkzE3OdxE0aMJz.G0kce969.9JxRmZ0ratl4c-1735337693-1.0.1.1-OKpUoRrSPFGvWv5Hp5ET1PNZ7iZNHPKEAuakpcQUxxPSeisUIIR3qIOZ31MGmYugqB5.wkvidgbxOAagqJvmnw host: - api.openai.com user-agent: - OpenAI/Python 1.52.1 x-stainless-arch: - - arm64 + - x64 x-stainless-async: - 'false' x-stainless-lang: - python x-stainless-os: - - MacOS + - Linux x-stainless-package-version: - 1.52.1 x-stainless-raw-response: @@ -269,22 +354,23 @@ interactions: method: POST uri: https://api.openai.com/v1/chat/completions response: - content: "{\n \"id\": \"chatcmpl-AaqIJAAxpVfUOdrsgYKHwfRlHv4RS\",\n \"object\": - \"chat.completion\",\n \"created\": 1733344191,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n + content: "{\n \"id\": \"chatcmpl-AjCtZGv4f3h7GDdhyOy9G0sB1lRgC\",\n \"object\": + \"chat.completion\",\n \"created\": 1735337693,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n \ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\": - \"assistant\",\n \"content\": \"Thought: I now can give a great answer - \ \\nFinal Answer: Hello\",\n \"refusal\": null\n },\n 
\"logprobs\": - null,\n \"finish_reason\": \"stop\"\n }\n ],\n \"usage\": {\n \"prompt_tokens\": - 188,\n \"completion_tokens\": 14,\n \"total_tokens\": 202,\n \"prompt_tokens_details\": - {\n \"cached_tokens\": 0,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\": - {\n \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\": + \"assistant\",\n \"content\": \"Thought: I understand the feedback and + will adjust my response accordingly. \\nFinal Answer: Hello\",\n \"refusal\": + null\n },\n \"logprobs\": null,\n \"finish_reason\": \"stop\"\n + \ }\n ],\n \"usage\": {\n \"prompt_tokens\": 188,\n \"completion_tokens\": + 18,\n \"total_tokens\": 206,\n \"prompt_tokens_details\": {\n \"cached_tokens\": + 0,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\": {\n + \ \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\": 0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"system_fingerprint\": - \"fp_0705bf87c0\"\n}\n" + \"fp_0aa8d3e20b\"\n}\n" headers: CF-Cache-Status: - DYNAMIC CF-RAY: - - 8ece8d090fc34532-ATL + - 8f8caa88cac4756b-SEA Connection: - keep-alive Content-Encoding: @@ -292,7 +378,7 @@ interactions: Content-Type: - application/json Date: - - Wed, 04 Dec 2024 20:29:51 GMT + - Fri, 27 Dec 2024 22:14:54 GMT Server: - cloudflare Transfer-Encoding: @@ -306,7 +392,7 @@ interactions: openai-organization: - crewai-iuxna1 openai-processing-ms: - - '484' + - '358' openai-version: - '2020-10-01' strict-transport-security: @@ -324,7 +410,7 @@ interactions: x-ratelimit-reset-tokens: - 0s x-request-id: - - req_5bf4a565ad6c2567a1ed204ecac89134 + - req_ae1ab6b206d28ded6fee3c83ed0c2ab7 http_version: HTTP/1.1 status_code: 200 - request: @@ -346,20 +432,20 @@ interactions: content-type: - application/json cookie: - - __cf_bm=QJZZjZ6eqnVamqUkw.Bx0mj7oBi3a_vGEH1VODcUxlg-1733344190-1.0.1.1-xyN0ekA9xIrSwEhRBmTiWJ3Pt72UYLU5owKfkz5yihVmMTfsr_Qz.ssGPJ5cuft066v1xVjb4zOSTdFmesMSKg; - _cfuvid=eCIkP8GVPvpkg19eOhCquWFHm.RTQBQy4yHLGGEAH5c-1733344190334-0.0.1.1-604800000 + - _cfuvid=A_ASCLNAVfQoyucWOAIhecWtEpNotYoZr0bAFihgNxs-1735337693273-0.0.1.1-604800000; + __cf_bm=wJkq_yLkzE3OdxE0aMJz.G0kce969.9JxRmZ0ratl4c-1735337693-1.0.1.1-OKpUoRrSPFGvWv5Hp5ET1PNZ7iZNHPKEAuakpcQUxxPSeisUIIR3qIOZ31MGmYugqB5.wkvidgbxOAagqJvmnw host: - api.openai.com user-agent: - OpenAI/Python 1.52.1 x-stainless-arch: - - arm64 + - x64 x-stainless-async: - 'false' x-stainless-lang: - python x-stainless-os: - - MacOS + - Linux x-stainless-package-version: - 1.52.1 x-stainless-raw-response: @@ -373,8 +459,8 @@ interactions: method: POST uri: https://api.openai.com/v1/chat/completions response: - content: "{\n \"id\": \"chatcmpl-AaqIJqyG8vl9mxj2qDPZgaxyNLLIq\",\n \"object\": - \"chat.completion\",\n \"created\": 1733344191,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n + content: "{\n \"id\": \"chatcmpl-AjCtaiHL4TY8Dssk0j2miqmjrzquy\",\n \"object\": + \"chat.completion\",\n \"created\": 1735337694,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n \ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\": \"assistant\",\n \"content\": \"False\",\n \"refusal\": null\n \ },\n \"logprobs\": null,\n \"finish_reason\": \"stop\"\n }\n @@ -383,12 +469,12 @@ interactions: 0,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\": {\n \ \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\": 0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"system_fingerprint\": - \"fp_0705bf87c0\"\n}\n" + \"fp_0aa8d3e20b\"\n}\n" headers: CF-Cache-Status: - DYNAMIC CF-RAY: - 
- 8ece8d0cfdeb4532-ATL + - 8f8caa8bdd26756b-SEA Connection: - keep-alive Content-Encoding: @@ -396,7 +482,7 @@ interactions: Content-Type: - application/json Date: - - Wed, 04 Dec 2024 20:29:51 GMT + - Fri, 27 Dec 2024 22:14:54 GMT Server: - cloudflare Transfer-Encoding: @@ -410,7 +496,7 @@ interactions: openai-organization: - crewai-iuxna1 openai-processing-ms: - - '341' + - '184' openai-version: - '2020-10-01' strict-transport-security: @@ -428,7 +514,7 @@ interactions: x-ratelimit-reset-tokens: - 0s x-request-id: - - req_5554bade8ceda00cf364b76a51b708ff + - req_652891f79c1104a7a8436275d78a69f1 http_version: HTTP/1.1 status_code: 200 version: 1 diff --git a/tests/crew_test.py b/tests/crew_test.py index 4f22c0d6e..74a659738 100644 --- a/tests/crew_test.py +++ b/tests/crew_test.py @@ -1228,6 +1228,7 @@ def test_kickoff_for_each_empty_input(): assert results == [] +@pytest.mark.vcr(filter_headers=["authorization"]) def test_kickoff_for_each_invalid_input(): """Tests if kickoff_for_each raises TypeError for invalid input types.""" diff --git a/tests/test_flow_persistence.py b/tests/test_flow_persistence.py new file mode 100644 index 000000000..74971f30d --- /dev/null +++ b/tests/test_flow_persistence.py @@ -0,0 +1,195 @@ +"""Test flow state persistence functionality.""" + +import os +from typing import Dict, Optional + +import pytest +from pydantic import BaseModel + +from crewai.flow.flow import Flow, FlowState, start +from crewai.flow.persistence import FlowPersistence, persist +from crewai.flow.persistence.sqlite import SQLiteFlowPersistence + + +class TestState(FlowState): + """Test state model with required id field.""" + counter: int = 0 + message: str = "" + + +def test_persist_decorator_saves_state(tmp_path): + """Test that @persist decorator saves state in SQLite.""" + db_path = os.path.join(tmp_path, "test_flows.db") + persistence = SQLiteFlowPersistence(db_path) + + class TestFlow(Flow[Dict[str, str]]): + initial_state = dict() # Use dict instance as initial state + + @start() + @persist(persistence) + def init_step(self): + self.state["message"] = "Hello, World!" + self.state["id"] = "test-uuid" # Ensure we have an ID for persistence + + # Run flow and verify state is saved + flow = TestFlow(persistence=persistence) + flow.kickoff() + + # Load state from DB and verify + saved_state = persistence.load_state(flow.state["id"]) + assert saved_state is not None + assert saved_state["message"] == "Hello, World!" 
+
+
+def test_structured_state_persistence(tmp_path):
+    """Test persistence with Pydantic model state."""
+    db_path = os.path.join(tmp_path, "test_flows.db")
+    persistence = SQLiteFlowPersistence(db_path)
+
+    class StructuredFlow(Flow[TestState]):
+        initial_state = TestState
+
+        @start()
+        @persist(persistence)
+        def count_up(self):
+            self.state.counter += 1
+            self.state.message = f"Count is {self.state.counter}"
+
+    # Run flow and verify state changes are saved
+    flow = StructuredFlow(persistence=persistence)
+    flow.kickoff()
+
+    # Load and verify state
+    saved_state = persistence.load_state(flow.state.id)
+    assert saved_state is not None
+    assert saved_state["counter"] == 1
+    assert saved_state["message"] == "Count is 1"
+
+
+def test_flow_state_restoration(tmp_path):
+    """Test restoring flow state from persistence with various restoration methods."""
+    db_path = os.path.join(tmp_path, "test_flows.db")
+    persistence = SQLiteFlowPersistence(db_path)
+
+    # First flow execution to create initial state
+    class RestorableFlow(Flow[TestState]):
+        initial_state = TestState
+
+        @start()
+        @persist(persistence)
+        def set_message(self):
+            self.state.message = "Original message"
+            self.state.counter = 42
+
+    # Create and persist initial state
+    flow1 = RestorableFlow(persistence=persistence)
+    flow1.kickoff()
+    original_uuid = flow1.state.id
+
+    # Test case 1: Restore using restore_uuid with field override
+    flow2 = RestorableFlow(
+        persistence=persistence,
+        restore_uuid=original_uuid,
+        counter=43,  # Override counter
+    )
+
+    # Verify state restoration and selective field override
+    assert flow2.state.id == original_uuid
+    assert flow2.state.message == "Original message"  # Preserved
+    assert flow2.state.counter == 43  # Overridden
+
+    # Test case 2: Restore using kwargs['id']
+    flow3 = RestorableFlow(
+        persistence=persistence,
+        id=original_uuid,
+        message="Updated message",  # Override message
+    )
+
+    # Verify state restoration and selective field override
+    assert flow3.state.id == original_uuid
+    assert flow3.state.counter == 42  # Preserved
+    assert flow3.state.message == "Updated message"  # Overridden
+
+    # Test case 3: Verify error on conflicting IDs
+    with pytest.raises(ValueError) as exc_info:
+        RestorableFlow(
+            persistence=persistence,
+            restore_uuid=original_uuid,
+            id="different-id",  # Conflict with restore_uuid
+        )
+    assert "Conflicting IDs provided" in str(exc_info.value)
+
+    # Test case 4: Verify error on non-existent restore_uuid
+    with pytest.raises(ValueError) as exc_info:
+        RestorableFlow(
+            persistence=persistence,
+            restore_uuid="non-existent-uuid",
+        )
+    assert "No state found" in str(exc_info.value)
+
+    # Test case 5: Allow new state creation with kwargs['id']
+    new_uuid = "new-flow-id"
+    flow4 = RestorableFlow(
+        persistence=persistence,
+        id=new_uuid,
+        message="New message",
+        counter=100,
+    )
+
+    # Verify new state creation with provided ID
+    assert flow4.state.id == new_uuid
+    assert flow4.state.message == "New message"
+    assert flow4.state.counter == 100
+
+
+def test_multiple_method_persistence(tmp_path):
+    """Test state persistence across multiple method executions."""
+    db_path = os.path.join(tmp_path, "test_flows.db")
+    persistence = SQLiteFlowPersistence(db_path)
+
+    class MultiStepFlow(Flow[TestState]):
+        initial_state = TestState
+
+        @start()
+        @persist(persistence)
+        def step_1(self):
+            self.state.counter = 1
+            self.state.message = "Step 1"
+
+        @start()
+        @persist(persistence)
+        def step_2(self):
+            self.state.counter = 2
+            self.state.message = "Step 2"
+
+    flow = MultiStepFlow(persistence=persistence)
+    flow.kickoff()
+
+    # Load final state
+    final_state = persistence.load_state(flow.state.id)
+    assert final_state is not None
+    assert final_state["counter"] == 2
+    assert final_state["message"] == "Step 2"
+
+
+def test_persistence_error_handling(tmp_path):
+    """Test error handling in persistence operations."""
+    db_path = os.path.join(tmp_path, "test_flows.db")
+    persistence = SQLiteFlowPersistence(db_path)
+
+    class InvalidFlow(Flow[TestState]):
+        # Missing id field in initial state
+        class InvalidState(BaseModel):
+            value: str = ""
+
+        initial_state = InvalidState
+
+        @start()
+        @persist(persistence)
+        def will_fail(self):
+            self.state.value = "test"
+
+    with pytest.raises(ValueError) as exc_info:
+        flow = InvalidFlow(persistence=persistence)
+
+    assert "must have an 'id' field" in str(exc_info.value)
diff --git a/uv.lock b/uv.lock
index 56ed758df..11f2e6691 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1,6 +1,7 @@
 version = 1
 requires-python = ">=3.10, <3.13"
 resolution-markers = [
+    "python_full_version < '3.11' and platform_system == 'Darwin' and sys_platform == 'darwin'",
     "python_full_version < '3.11' and platform_machine == 'aarch64' and platform_system == 'Linux' and sys_platform == 'darwin'",
     "(python_full_version < '3.11' and platform_machine != 'aarch64' and platform_system != 'Darwin' and sys_platform == 'darwin') or (python_full_version < '3.11' and platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'darwin')",
@@ -36,7 +37,7 @@ resolution-markers = [
     "python_full_version >= '3.12.4' and platform_machine == 'aarch64' and platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'linux'",
     "(python_full_version >= '3.12.4' and platform_machine != 'aarch64' and platform_system == 'Darwin' and sys_platform != 'darwin') or (python_full_version >= '3.12.4' and platform_system == 'Darwin' and sys_platform != 'darwin' and sys_platform != 'linux')",
     "python_full_version >= '3.12.4' and platform_machine == 'aarch64' and platform_system == 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux'",
-    "(python_full_version >= '3.12.4' and platform_machine != 'aarch64' and platform_system != 'Darwin' and sys_platform != 'darwin') or (python_full_version >= '3.12.4' and platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux')",
+    "(python_full_version >= '3.12.4' and platform_machine != 'aarch64' and platform_system != 'Darwin' and sys_platform != 'darwin') or (python_full_version >= '3.12.4' and platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform != 'darwin' and sys_platform != 'linux')"
 ]
 
 [[package]]
@@ -345,7 +346,7 @@ name = "build"
 version = "1.2.2.post1"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "colorama", marker = "os_name == 'nt'" },
+    { name = "colorama", marker = "(os_name == 'nt' and platform_machine != 'aarch64' and sys_platform == 'linux') or (os_name == 'nt' and sys_platform != 'darwin' and sys_platform != 'linux')" },
     { name = "importlib-metadata", marker = "python_full_version < '3.10.2'" },
     { name = "packaging" },
     { name = "pyproject-hooks" },
@@ -580,7 +581,7 @@ name = "click"
 version = "8.1.7"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "colorama", marker = "platform_system == 'Windows'" },
+    { name = "colorama", marker = "sys_platform == 'win32'" },
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 }
 wheels = [
@@ -2587,7 +2588,7 @@ version = "1.6.1"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "click" },
-    { name = "colorama", marker = "platform_system == 'Windows'" },
+    { name = "colorama", marker = "sys_platform == 'win32'" },
     { name = "ghp-import" },
     { name = "jinja2" },
     { name = "markdown" },
@@ -2768,7 +2769,7 @@ version = "2.10.2"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "pygments" },
-    { name = "pywin32", marker = "platform_system == 'Windows'" },
+    { name = "pywin32", marker = "sys_platform == 'win32'" },
     { name = "tqdm" },
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/3a/93/80ac75c20ce54c785648b4ed363c88f148bf22637e10c9863db4fbe73e74/mpire-2.10.2.tar.gz", hash = "sha256:f66a321e93fadff34585a4bfa05e95bd946cf714b442f51c529038eb45773d97", size = 271270 }
@@ -3015,7 +3016,7 @@ name = "nvidia-cudnn-cu12"
 version = "9.1.0.70"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform != 'linux')" },
+    { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform != 'linux')" }
 ]
 wheels = [
     { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 },
@@ -3044,7 +3045,7 @@ source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform != 'linux')" },
     { name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform != 'linux')" },
-    { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform != 'linux')" },
+    { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform != 'linux')" }
 ]
 wheels = [
     { url = "https://files.pythonhosted.org/packages/bc/1d/8de1e5c67099015c834315e333911273a8c6aaba78923dd1d1e25fc5f217/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd", size = 124161928 },
@@ -3055,7 +3056,7 @@ name = "nvidia-cusparse-cu12"
 version = "12.1.0.106"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform != 'linux')" },
+    { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform != 'linux')" }
 ]
 wheels = [
     { url = "https://files.pythonhosted.org/packages/65/5b/cfaeebf25cd9fdec14338ccb16f6b2c4c7fa9163aefcf057d86b9cc248bb/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c", size = 195958278 },
@@ -3605,7 +3606,7 @@ name = "portalocker"
 version = "2.10.1"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "pywin32", marker = "platform_system == 'Windows'" },
+    { name = "pywin32", marker = "sys_platform == 'win32'" },
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/ed/d3/c6c64067759e87af98cc668c1cc75171347d0f1577fab7ca3749134e3cd4/portalocker-2.10.1.tar.gz", hash = "sha256:ef1bf844e878ab08aee7e40184156e1151f228f103aa5c6bd0724cc330960f8f", size = 40891 }
 wheels = [
@@ -5193,19 +5194,19 @@ dependencies = [
     { name = "fsspec" },
     { name = "jinja2" },
     { name = "networkx" },
-    { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
-    { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
-    { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
-    { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
-    { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
-    { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
-    { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
-    { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
-    { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
-    { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
-    { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
+    { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+    { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+    { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+    { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+    { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+    { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+    { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+    { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+    { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+    { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+    { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
     { name = "sympy" },
-    { name = "triton", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
+    { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
     { name = "typing-extensions" },
 ]
 wheels = [
@@ -5252,7 +5253,7 @@ name = "tqdm"
 version = "4.66.5"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "colorama", marker = "platform_system == 'Windows'" },
+    { name = "colorama", marker = "sys_platform == 'win32'" },
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/58/83/6ba9844a41128c62e810fddddd72473201f3eacde02046066142a2d96cc5/tqdm-4.66.5.tar.gz", hash = "sha256:e1020aef2e5096702d8a025ac7d16b1577279c9d63f8375b63083e9a5f0fcbad", size = 169504 }
 wheels = [
@@ -5295,7 +5296,7 @@ version = "0.27.0"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "attrs" },
-    { name = "cffi", marker = "implementation_name != 'pypy' and os_name == 'nt'" },
+    { name = "cffi", marker = "(implementation_name != 'pypy' and os_name == 'nt' and platform_machine != 'aarch64' and sys_platform == 'linux') or (implementation_name != 'pypy' and os_name == 'nt' and sys_platform != 'darwin' and sys_platform != 'linux')" },
     { name = "exceptiongroup", marker = "python_full_version < '3.11'" },
     { name = "idna" },
     { name = "outcome" },
@@ -5326,7 +5327,7 @@ name = "triton"
 version = "3.0.0"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "filelock", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform != 'linux')" },
+    { name = "filelock", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform != 'linux')" }
 ]
 wheels = [
     { url = "https://files.pythonhosted.org/packages/45/27/14cc3101409b9b4b9241d2ba7deaa93535a217a211c86c4cc7151fb12181/triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a", size = 209376304 },