Merge branch 'main' into fix-conditional-tasks-out-of-index

This commit is contained in:
Brandon Hancock (bhancock_ai)
2025-01-21 11:10:50 -05:00
committed by GitHub
13 changed files with 441 additions and 256 deletions

View File

@@ -147,7 +147,36 @@ Some commands may require additional configuration or setup within your project
</Note>
### 9. API Keys
### 9. Chat
Starting in version `0.98.0`, when you run the `crewai chat` command, you start an interactive session with your crew. The AI assistant will guide you by asking for necessary inputs to execute the crew. Once all inputs are provided, the crew will execute its tasks.
After receiving the results, you can continue interacting with the assistant for further instructions or questions.
```shell
crewai chat
```
<Note>
Ensure you execute these commands from your CrewAI project's root directory.
</Note>
<Note>
IMPORTANT: Set the `chat_llm` property in your `crew.py` file to enable this command.
```python
@crew
def crew(self) -> Crew:
return Crew(
agents=self.agents,
tasks=self.tasks,
process=Process.sequential,
verbose=True,
chat_llm="gpt-4o", # LLM for chat orchestration
)
```
</Note>
### 10. API Keys
When running ```crewai create crew``` command, the CLI will first show you the top 5 most common LLM providers and ask you to select one.

View File

@@ -278,7 +278,7 @@ email_summarizer:
Summarize emails into a concise and clear summary
backstory: >
You will create a 5 bullet point summary of the report
llm: mixtal_llm
llm: openai/gpt-4o
```
<Tip>

View File

@@ -1,78 +1,117 @@
---
title: Composio Tool
description: The `ComposioTool` is a wrapper around the composio set of tools and gives your agent access to a wide variety of tools from the Composio SDK.
title: Composio
description: Composio provides 250+ production-ready tools for AI agents with flexible authentication management.
icon: gear-code
---
# `ComposioTool`
# `ComposioToolSet`
## Description
Composio is an integration platform that allows you to connect your AI agents to 250+ tools. Key features include:
This tools is a wrapper around the composio set of tools and gives your agent access to a wide variety of tools from the Composio SDK.
- **Enterprise-Grade Authentication**: Built-in support for OAuth, API Keys, JWT with automatic token refresh
- **Full Observability**: Detailed tool usage logs, execution timestamps, and more
## Installation
To incorporate this tool into your project, follow the installation instructions below:
To incorporate Composio tools into your project, follow the instructions below:
```shell
pip install composio-core
pip install 'crewai[tools]'
pip install composio-crewai
pip install crewai
```
after the installation is complete, either run `composio login` or export your composio API key as `COMPOSIO_API_KEY`.
After the installation is complete, either run `composio login` or export your composio API key as `COMPOSIO_API_KEY`. Get your Composio API key from [here](https://app.composio.dev)
## Example
The following example demonstrates how to initialize the tool and execute a github action:
1. Initialize Composio tools
1. Initialize Composio toolset
```python Code
from composio import App
from crewai_tools import ComposioTool
from crewai import Agent, Task
from composio_crewai import ComposioToolSet, App, Action
from crewai import Agent, Task, Crew
tools = [ComposioTool.from_action(action=Action.GITHUB_ACTIVITY_STAR_REPO_FOR_AUTHENTICATED_USER)]
toolset = ComposioToolSet()
```
If you don't know what action you want to use, use `from_app` and `tags` filter to get relevant actions
2. Connect your GitHub account
<CodeGroup>
```shell CLI
composio add github
```
```python Code
tools = ComposioTool.from_app(App.GITHUB, tags=["important"])
request = toolset.initiate_connection(app=App.GITHUB)
print(f"Open this URL to authenticate: {request.redirectUrl}")
```
</CodeGroup>
or use `use_case` to search relevant actions
3. Get Tools
- Retrieving all the tools from an app (not recommended for production):
```python Code
tools = ComposioTool.from_app(App.GITHUB, use_case="Star a github repository")
tools = toolset.get_tools(apps=[App.GITHUB])
```
2. Define agent
- Filtering tools based on tags:
```python Code
tag = "users"
filtered_action_enums = toolset.find_actions_by_tags(
App.GITHUB,
tags=[tag],
)
tools = toolset.get_tools(actions=filtered_action_enums)
```
- Filtering tools based on use case:
```python Code
use_case = "Star a repository on GitHub"
filtered_action_enums = toolset.find_actions_by_use_case(
App.GITHUB, use_case=use_case, advanced=False
)
tools = toolset.get_tools(actions=filtered_action_enums)
```<Tip>Set `advanced` to True to get actions for complex use cases</Tip>
- Using specific tools:
In this demo, we will use the `GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER` action from the GitHub app.
```python Code
tools = toolset.get_tools(
actions=[Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER]
)
```
Learn more about filtering actions [here](https://docs.composio.dev/patterns/tools/use-tools/use-specific-actions)
4. Define agent
```python Code
crewai_agent = Agent(
role="Github Agent",
goal="You take action on Github using Github APIs",
backstory=(
"You are AI agent that is responsible for taking actions on Github "
"on users behalf. You need to take action on Github using Github APIs"
),
role="GitHub Agent",
goal="You take action on GitHub using GitHub APIs",
backstory="You are AI agent that is responsible for taking actions on GitHub on behalf of users using GitHub APIs",
verbose=True,
tools=tools,
llm= # pass an llm
)
```
3. Execute task
5. Execute task
```python Code
task = Task(
description="Star a repo ComposioHQ/composio on GitHub",
description="Star a repo composiohq/composio on GitHub",
agent=crewai_agent,
expected_output="if the star happened",
expected_output="Status of the operation",
)
task.execute()
crew = Crew(agents=[crewai_agent], tasks=[task])
crew.kickoff()
```
* More detailed list of tools can be found [here](https://app.composio.dev)
* More detailed list of tools can be found [here](https://app.composio.dev)

View File

@@ -37,7 +37,6 @@ from crewai.tasks.task_output import TaskOutput
from crewai.telemetry import Telemetry
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.tools.base_tool import Tool
from crewai.types.crew_chat import ChatInputs
from crewai.types.usage_metrics import UsageMetrics
from crewai.utilities import I18N, FileHandler, Logger, RPMController
from crewai.utilities.constants import TRAINING_DATA_FILE
@@ -84,6 +83,7 @@ class Crew(BaseModel):
step_callback: Callback to be executed after each step for every agents execution.
share_crew: Whether you want to share the complete crew information and execution with crewAI to make the library better, and allow us to train models.
planning: Plan the crew execution and add the plan to the crew.
chat_llm: The language model used for orchestrating chat interactions with the crew.
"""
__hash__ = object.__hash__ # type: ignore

View File

@@ -447,14 +447,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
def __init__(
self,
persistence: Optional[FlowPersistence] = None,
restore_uuid: Optional[str] = None,
**kwargs: Any,
) -> None:
"""Initialize a new Flow instance.
Args:
persistence: Optional persistence backend for storing flow states
restore_uuid: Optional UUID to restore state from persistence
**kwargs: Additional state values to initialize or override
"""
# Initialize basic instance attributes
@@ -464,64 +462,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
self._method_outputs: List[Any] = [] # List to store all method outputs
self._persistence: Optional[FlowPersistence] = persistence
# Validate state model before initialization
if isinstance(self.initial_state, type):
if issubclass(self.initial_state, BaseModel) and not issubclass(
self.initial_state, FlowState
):
# Check if model has id field
model_fields = getattr(self.initial_state, "model_fields", None)
if not model_fields or "id" not in model_fields:
raise ValueError("Flow state model must have an 'id' field")
# Initialize state with initial values
self._state = self._create_initial_state()
# Handle persistence and potential ID conflicts
stored_state = None
if self._persistence is not None:
if (
restore_uuid
and kwargs
and "id" in kwargs
and restore_uuid != kwargs["id"]
):
raise ValueError(
f"Conflicting IDs provided: restore_uuid='{restore_uuid}' "
f"vs kwargs['id']='{kwargs['id']}'. Use only one ID for restoration."
)
# Attempt to load state, prioritizing restore_uuid
if restore_uuid:
self._log_flow_event(f"Loading flow state from memory for UUID: {restore_uuid}", color="bold_yellow")
stored_state = self._persistence.load_state(restore_uuid)
if not stored_state:
raise ValueError(
f"No state found for restore_uuid='{restore_uuid}'"
)
elif kwargs and "id" in kwargs:
self._log_flow_event(f"Loading flow state from memory for ID: {kwargs['id']}", color="bold_yellow")
stored_state = self._persistence.load_state(kwargs["id"])
if not stored_state:
# For kwargs["id"], we allow creating new state if not found
self._state = self._create_initial_state()
if kwargs:
self._initialize_state(kwargs)
return
# Initialize state based on persistence and kwargs
if stored_state:
# Create initial state and restore from persistence
self._state = self._create_initial_state()
self._restore_state(stored_state)
# Apply any additional kwargs to override specific fields
if kwargs:
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "id"}
if filtered_kwargs:
self._initialize_state(filtered_kwargs)
else:
# No stored state, create new state with initial values
self._state = self._create_initial_state()
# Apply any additional kwargs
if kwargs:
self._initialize_state(kwargs)
# Apply any additional kwargs
if kwargs:
self._initialize_state(kwargs)
self._telemetry.flow_creation_span(self.__class__.__name__)
@@ -635,18 +581,18 @@ class Flow(Generic[T], metaclass=FlowMeta):
@property
def flow_id(self) -> str:
"""Returns the unique identifier of this flow instance.
This property provides a consistent way to access the flow's unique identifier
regardless of the underlying state implementation (dict or BaseModel).
Returns:
str: The flow's unique identifier, or an empty string if not found
Note:
This property safely handles both dictionary and BaseModel state types,
returning an empty string if the ID cannot be retrieved rather than raising
an exception.
Example:
```python
flow = MyFlow()
@@ -656,7 +602,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
try:
if not hasattr(self, '_state'):
return ""
if isinstance(self._state, dict):
return str(self._state.get("id", ""))
elif isinstance(self._state, BaseModel):
@@ -731,7 +677,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
"""
# When restoring from persistence, use the stored ID
stored_id = stored_state.get("id")
self._log_flow_event(f"Restoring flow state from memory for ID: {stored_id}", color="bold_yellow")
if not stored_id:
raise ValueError("Stored state must have an 'id' field")
@@ -755,6 +700,36 @@ class Flow(Generic[T], metaclass=FlowMeta):
raise TypeError(f"State must be dict or BaseModel, got {type(self._state)}")
def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
"""Start the flow execution.
Args:
inputs: Optional dictionary containing input values and potentially a state ID to restore
"""
# Handle state restoration if ID is provided in inputs
if inputs and 'id' in inputs and self._persistence is not None:
restore_uuid = inputs['id']
stored_state = self._persistence.load_state(restore_uuid)
# Override the id in the state if it exists in inputs
if 'id' in inputs:
if isinstance(self._state, dict):
self._state['id'] = inputs['id']
elif isinstance(self._state, BaseModel):
setattr(self._state, 'id', inputs['id'])
if stored_state:
self._log_flow_event(f"Loading flow state from memory for UUID: {restore_uuid}", color="yellow")
# Restore the state
self._restore_state(stored_state)
else:
self._log_flow_event(f"No flow state found for UUID: {restore_uuid}", color="red")
# Apply any additional inputs after restoration
filtered_inputs = {k: v for k, v in inputs.items() if k != 'id'}
if filtered_inputs:
self._initialize_state(filtered_inputs)
# Start flow execution
self.event_emitter.send(
self,
event=FlowStartedEvent(
@@ -762,10 +737,11 @@ class Flow(Generic[T], metaclass=FlowMeta):
flow_name=self.__class__.__name__,
),
)
self._log_flow_event(f"Flow started with ID: {self.flow_id}", color="yellow")
self._log_flow_event(f"Flow started with ID: {self.flow_id}", color="bold_magenta")
if inputs is not None:
if inputs is not None and 'id' not in inputs:
self._initialize_state(inputs)
return asyncio.run(self.kickoff_async())
async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
@@ -1010,18 +986,18 @@ class Flow(Generic[T], metaclass=FlowMeta):
def _log_flow_event(self, message: str, color: str = "yellow", level: str = "info") -> None:
"""Centralized logging method for flow events.
This method provides a consistent interface for logging flow-related events,
combining both console output with colors and proper logging levels.
Args:
message: The message to log
color: Color to use for console output (default: yellow)
Available colors: purple, red, bold_green, bold_purple,
bold_blue, yellow, bold_yellow
bold_blue, yellow, yellow
level: Log level to use (default: info)
Supported levels: info, warning
Note:
This method uses the Printer utility for colored console output
and the standard logging module for log level support.
@@ -1031,7 +1007,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
logger.info(message)
elif level == "warning":
logger.warning(message)
def plot(self, filename: str = "crewai_flow") -> None:
self._telemetry.flow_plotting_span(
self.__class__.__name__, list(self._methods.keys())

View File

@@ -54,57 +54,44 @@ LOG_MESSAGES = {
class PersistenceDecorator:
"""Class to handle flow state persistence with consistent logging."""
_printer = Printer() # Class-level printer instance
@classmethod
def persist_state(cls, flow_instance: Any, method_name: str, persistence_instance: FlowPersistence) -> None:
"""Persist flow state with proper error handling and logging.
This method handles the persistence of flow state data, including proper
error handling and colored console output for status updates.
Args:
flow_instance: The flow instance whose state to persist
method_name: Name of the method that triggered persistence
persistence_instance: The persistence backend to use
Raises:
ValueError: If flow has no state or state lacks an ID
RuntimeError: If state persistence fails
AttributeError: If flow instance lacks required state attributes
Note:
Uses bold_yellow color for success messages and red for errors.
All operations are logged at appropriate levels (info/error).
Example:
```python
@persist
def my_flow_method(self):
# Method implementation
pass
# State will be automatically persisted after method execution
```
"""
try:
state = getattr(flow_instance, 'state', None)
if state is None:
raise ValueError("Flow instance has no state")
flow_uuid: Optional[str] = None
if isinstance(state, dict):
flow_uuid = state.get('id')
elif isinstance(state, BaseModel):
flow_uuid = getattr(state, 'id', None)
if not flow_uuid:
raise ValueError("Flow state must have an 'id' field for persistence")
# Log state saving with consistent message
cls._printer.print(LOG_MESSAGES["save_state"].format(flow_uuid), color="bold_yellow")
cls._printer.print(LOG_MESSAGES["save_state"].format(flow_uuid), color="cyan")
logger.info(LOG_MESSAGES["save_state"].format(flow_uuid))
try:
persistence_instance.save_state(
flow_uuid=flow_uuid,
@@ -154,44 +141,79 @@ def persist(persistence: Optional[FlowPersistence] = None):
def begin(self):
pass
"""
def decorator(target: Union[Type, Callable[..., T]]) -> Union[Type, Callable[..., T]]:
"""Decorator that handles both class and method decoration."""
actual_persistence = persistence or SQLiteFlowPersistence()
if isinstance(target, type):
# Class decoration
class_methods = {}
for name, method in target.__dict__.items():
if callable(method) and hasattr(method, "__is_flow_method__"):
# Wrap each flow method with persistence
if asyncio.iscoroutinefunction(method):
@functools.wraps(method)
async def class_async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
method_coro = method(self, *args, **kwargs)
if asyncio.iscoroutine(method_coro):
result = await method_coro
else:
result = method_coro
PersistenceDecorator.persist_state(self, method.__name__, actual_persistence)
return result
class_methods[name] = class_async_wrapper
else:
@functools.wraps(method)
def class_sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
result = method(self, *args, **kwargs)
PersistenceDecorator.persist_state(self, method.__name__, actual_persistence)
return result
class_methods[name] = class_sync_wrapper
original_init = getattr(target, "__init__")
# Preserve flow-specific attributes
@functools.wraps(original_init)
def new_init(self: Any, *args: Any, **kwargs: Any) -> None:
if 'persistence' not in kwargs:
kwargs['persistence'] = actual_persistence
original_init(self, *args, **kwargs)
setattr(target, "__init__", new_init)
# Store original methods to preserve their decorators
original_methods = {}
for name, method in target.__dict__.items():
if callable(method) and (
hasattr(method, "__is_start_method__") or
hasattr(method, "__trigger_methods__") or
hasattr(method, "__condition_type__") or
hasattr(method, "__is_flow_method__") or
hasattr(method, "__is_router__")
):
original_methods[name] = method
# Create wrapped versions of the methods that include persistence
for name, method in original_methods.items():
if asyncio.iscoroutinefunction(method):
# Create a closure to capture the current name and method
def create_async_wrapper(method_name: str, original_method: Callable):
@functools.wraps(original_method)
async def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
result = await original_method(self, *args, **kwargs)
PersistenceDecorator.persist_state(self, method_name, actual_persistence)
return result
return method_wrapper
wrapped = create_async_wrapper(name, method)
# Preserve all original decorators and attributes
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(class_methods[name], attr, getattr(method, attr))
setattr(class_methods[name], "__is_flow_method__", True)
setattr(wrapped, attr, getattr(method, attr))
setattr(wrapped, "__is_flow_method__", True)
# Update the class with the wrapped method
setattr(target, name, wrapped)
else:
# Create a closure to capture the current name and method
def create_sync_wrapper(method_name: str, original_method: Callable):
@functools.wraps(original_method)
def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
result = original_method(self, *args, **kwargs)
PersistenceDecorator.persist_state(self, method_name, actual_persistence)
return result
return method_wrapper
wrapped = create_sync_wrapper(name, method)
# Preserve all original decorators and attributes
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(wrapped, attr, getattr(method, attr))
setattr(wrapped, "__is_flow_method__", True)
# Update the class with the wrapped method
setattr(target, name, wrapped)
# Update class with wrapped methods
for name, method in class_methods.items():
setattr(target, name, method)
return target
else:
# Method decoration
@@ -208,6 +230,7 @@ def persist(persistence: Optional[FlowPersistence] = None):
result = method_coro
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence)
return result
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(method_async_wrapper, attr, getattr(method, attr))
@@ -219,6 +242,7 @@ def persist(persistence: Optional[FlowPersistence] = None):
result = method(flow_instance, *args, **kwargs)
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence)
return result
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(method_sync_wrapper, attr, getattr(method, attr))

View File

@@ -3,10 +3,9 @@ SQLite-based implementation of flow state persistence.
"""
import json
import os
import sqlite3
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional, Union
from pydantic import BaseModel
@@ -16,34 +15,34 @@ from crewai.flow.persistence.base import FlowPersistence
class SQLiteFlowPersistence(FlowPersistence):
"""SQLite-based implementation of flow state persistence.
This class provides a simple, file-based persistence implementation using SQLite.
It's suitable for development and testing, or for production use cases with
moderate performance requirements.
"""
db_path: str # Type annotation for instance variable
def __init__(self, db_path: Optional[str] = None):
"""Initialize SQLite persistence.
Args:
db_path: Path to the SQLite database file. If not provided, uses
db_storage_path() from utilities.paths.
Raises:
ValueError: If db_path is invalid
"""
from crewai.utilities.paths import db_storage_path
# Get path from argument or default location
path = db_path or db_storage_path()
path = db_path or str(Path(db_storage_path()) / "flow_states.db")
if not path:
raise ValueError("Database path must be provided")
self.db_path = path # Now mypy knows this is str
self.init_db()
def init_db(self) -> None:
"""Create the necessary tables if they don't exist."""
with sqlite3.connect(self.db_path) as conn:
@@ -58,10 +57,10 @@ class SQLiteFlowPersistence(FlowPersistence):
""")
# Add index for faster UUID lookups
conn.execute("""
CREATE INDEX IF NOT EXISTS idx_flow_states_uuid
CREATE INDEX IF NOT EXISTS idx_flow_states_uuid
ON flow_states(flow_uuid)
""")
def save_state(
self,
flow_uuid: str,
@@ -69,7 +68,7 @@ class SQLiteFlowPersistence(FlowPersistence):
state_data: Union[Dict[str, Any], BaseModel],
) -> None:
"""Save the current flow state to SQLite.
Args:
flow_uuid: Unique identifier for the flow instance
method_name: Name of the method that just completed
@@ -84,7 +83,7 @@ class SQLiteFlowPersistence(FlowPersistence):
raise ValueError(
f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
)
with sqlite3.connect(self.db_path) as conn:
conn.execute("""
INSERT INTO flow_states (
@@ -99,13 +98,13 @@ class SQLiteFlowPersistence(FlowPersistence):
datetime.utcnow().isoformat(),
json.dumps(state_dict),
))
def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
"""Load the most recent state for a given flow UUID.
Args:
flow_uuid: Unique identifier for the flow instance
Returns:
The most recent state as a dictionary, or None if no state exists
"""
@@ -118,7 +117,7 @@ class SQLiteFlowPersistence(FlowPersistence):
LIMIT 1
""", (flow_uuid,))
row = cursor.fetchone()
if row:
return json.loads(row[0])
return None

View File

@@ -23,7 +23,7 @@ class KickoffTaskOutputsSQLiteStorage:
) -> None:
if db_path is None:
# Get the parent directory of the default db path and create our db file there
db_path = str(Path(db_storage_path()).parent / "latest_kickoff_task_outputs.db")
db_path = str(Path(db_storage_path()) / "latest_kickoff_task_outputs.db")
self.db_path = db_path
self._printer: Printer = Printer()
self._initialize_db()

View File

@@ -17,7 +17,7 @@ class LTMSQLiteStorage:
) -> None:
if db_path is None:
# Get the parent directory of the default db path and create our db file there
db_path = str(Path(db_storage_path()).parent / "long_term_memory_storage.db")
db_path = str(Path(db_storage_path()) / "long_term_memory_storage.db")
self.db_path = db_path
self._printer: Printer = Printer()
# Ensure parent directory exists

View File

@@ -7,7 +7,7 @@ import appdirs
def db_storage_path() -> str:
"""Returns the path for SQLite database storage.
Returns:
str: Full path to the SQLite database file
"""
@@ -16,7 +16,7 @@ def db_storage_path() -> str:
data_dir = Path(appdirs.user_data_dir(app_name, app_author))
data_dir.mkdir(parents=True, exist_ok=True)
return str(data_dir / "crewai_flows.db")
return str(data_dir)
def get_project_directory_name():
@@ -28,4 +28,4 @@ def get_project_directory_name():
else:
cwd = Path.cwd()
project_directory_name = cwd.name
return project_directory_name
return project_directory_name

View File

@@ -21,6 +21,16 @@ class Printer:
self._print_yellow(content)
elif color == "bold_yellow":
self._print_bold_yellow(content)
elif color == "cyan":
self._print_cyan(content)
elif color == "bold_cyan":
self._print_bold_cyan(content)
elif color == "magenta":
self._print_magenta(content)
elif color == "bold_magenta":
self._print_bold_magenta(content)
elif color == "green":
self._print_green(content)
else:
print(content)
@@ -44,3 +54,18 @@ class Printer:
def _print_bold_yellow(self, content):
print("\033[1m\033[93m {}\033[00m".format(content))
def _print_cyan(self, content):
print("\033[96m {}\033[00m".format(content))
def _print_bold_cyan(self, content):
print("\033[1m\033[96m {}\033[00m".format(content))
def _print_magenta(self, content):
print("\033[35m {}\033[00m".format(content))
def _print_bold_magenta(self, content):
print("\033[1m\033[35m {}\033[00m".format(content))
def _print_green(self, content):
print("\033[32m {}\033[00m".format(content))

View File

@@ -0,0 +1,112 @@
"""Test that persisted state properly overrides default values."""
from crewai.flow.flow import Flow, FlowState, listen, start
from crewai.flow.persistence import persist
class PoemState(FlowState):
"""Test state model with default values that should be overridden."""
sentence_count: int = 1000 # Default that should be overridden
has_set_count: bool = False # Track whether we've set the count
poem_type: str = ""
def test_default_value_override():
"""Test that persisted state values override class defaults."""
@persist()
class PoemFlow(Flow[PoemState]):
initial_state = PoemState
@start()
def set_sentence_count(self):
if self.state.has_set_count and self.state.sentence_count == 2:
self.state.sentence_count = 3
elif self.state.has_set_count and self.state.sentence_count == 1000:
self.state.sentence_count = 1000
elif self.state.has_set_count and self.state.sentence_count == 5:
self.state.sentence_count = 5
else:
self.state.sentence_count = 2
self.state.has_set_count = True
# First run - should set sentence_count to 2
flow1 = PoemFlow()
flow1.kickoff()
original_uuid = flow1.state.id
assert flow1.state.sentence_count == 2
# Second run - should load sentence_count=2 instead of default 1000
flow2 = PoemFlow()
flow2.kickoff(inputs={"id": original_uuid})
assert flow2.state.sentence_count == 3 # Should load 2, not default 1000
# Fourth run - explicit override should work
flow3 = PoemFlow()
flow3.kickoff(inputs={
"id": original_uuid,
"has_set_count": True,
"sentence_count": 5, # Override persisted value
})
assert flow3.state.sentence_count == 5 # Should use override value
# Third run - should not load sentence_count=2 instead of default 1000
flow4 = PoemFlow()
flow4.kickoff(inputs={"has_set_count": True})
assert flow4.state.sentence_count == 1000 # Should load 1000, not 2
def test_multi_step_default_override():
"""Test default value override with multiple start methods."""
@persist()
class MultiStepPoemFlow(Flow[PoemState]):
initial_state = PoemState
@start()
def set_sentence_count(self):
print("Setting sentence count")
if not self.state.has_set_count:
self.state.sentence_count = 3
self.state.has_set_count = True
@listen(set_sentence_count)
def set_poem_type(self):
print("Setting poem type")
if self.state.sentence_count == 3:
self.state.poem_type = "haiku"
elif self.state.sentence_count == 5:
self.state.poem_type = "limerick"
else:
self.state.poem_type = "free_verse"
@listen(set_poem_type)
def finished(self):
print("finished")
# First run - should set both sentence count and poem type
flow1 = MultiStepPoemFlow()
flow1.kickoff()
original_uuid = flow1.state.id
assert flow1.state.sentence_count == 3
assert flow1.state.poem_type == "haiku"
# Second run - should load persisted state and update poem type
flow2 = MultiStepPoemFlow()
flow2.kickoff(inputs={
"id": original_uuid,
"sentence_count": 5
})
assert flow2.state.sentence_count == 5
assert flow2.state.poem_type == "limerick"
# Third run - new flow without persisted state should use defaults
flow3 = MultiStepPoemFlow()
flow3.kickoff(inputs={
"id": original_uuid
})
assert flow3.state.sentence_count == 5
assert flow3.state.poem_type == "limerick"

View File

@@ -1,12 +1,12 @@
"""Test flow state persistence functionality."""
import os
from typing import Dict, Optional
from typing import Dict
import pytest
from pydantic import BaseModel
from crewai.flow.flow import Flow, FlowState, start
from crewai.flow.flow import Flow, FlowState, listen, start
from crewai.flow.persistence import persist
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
@@ -73,13 +73,14 @@ def test_flow_state_restoration(tmp_path):
# First flow execution to create initial state
class RestorableFlow(Flow[TestState]):
initial_state = TestState
@start()
@persist(persistence)
def set_message(self):
self.state.message = "Original message"
self.state.counter = 42
if self.state.message == "":
self.state.message = "Original message"
if self.state.counter == 0:
self.state.counter = 42
# Create and persist initial state
flow1 = RestorableFlow(persistence=persistence)
@@ -87,11 +88,11 @@ def test_flow_state_restoration(tmp_path):
original_uuid = flow1.state.id
# Test case 1: Restore using restore_uuid with field override
flow2 = RestorableFlow(
persistence=persistence,
restore_uuid=original_uuid,
counter=43, # Override counter
)
flow2 = RestorableFlow(persistence=persistence)
flow2.kickoff(inputs={
"id": original_uuid,
"counter": 43
})
# Verify state restoration and selective field override
assert flow2.state.id == original_uuid
@@ -99,48 +100,17 @@ def test_flow_state_restoration(tmp_path):
assert flow2.state.counter == 43 # Overridden
# Test case 2: Restore using kwargs['id']
flow3 = RestorableFlow(
persistence=persistence,
id=original_uuid,
message="Updated message", # Override message
)
flow3 = RestorableFlow(persistence=persistence)
flow3.kickoff(inputs={
"id": original_uuid,
"message": "Updated message"
})
# Verify state restoration and selective field override
assert flow3.state.id == original_uuid
assert flow3.state.counter == 42 # Preserved
assert flow3.state.counter == 43 # Preserved
assert flow3.state.message == "Updated message" # Overridden
# Test case 3: Verify error on conflicting IDs
with pytest.raises(ValueError) as exc_info:
RestorableFlow(
persistence=persistence,
restore_uuid=original_uuid,
id="different-id", # Conflict with restore_uuid
)
assert "Conflicting IDs provided" in str(exc_info.value)
# Test case 4: Verify error on non-existent restore_uuid
with pytest.raises(ValueError) as exc_info:
RestorableFlow(
persistence=persistence,
restore_uuid="non-existent-uuid",
)
assert "No state found" in str(exc_info.value)
# Test case 5: Allow new state creation with kwargs['id']
new_uuid = "new-flow-id"
flow4 = RestorableFlow(
persistence=persistence,
id=new_uuid,
message="New message",
counter=100,
)
# Verify new state creation with provided ID
assert flow4.state.id == new_uuid
assert flow4.state.message == "New message"
assert flow4.state.counter == 100
def test_multiple_method_persistence(tmp_path):
"""Test state persistence across multiple method executions."""
@@ -148,48 +118,59 @@ def test_multiple_method_persistence(tmp_path):
persistence = SQLiteFlowPersistence(db_path)
class MultiStepFlow(Flow[TestState]):
initial_state = TestState
@start()
@persist(persistence)
def step_1(self):
self.state.counter = 1
self.state.message = "Step 1"
if self.state.counter == 1:
self.state.counter = 99999
self.state.message = "Step 99999"
else:
self.state.counter = 1
self.state.message = "Step 1"
@start()
@listen(step_1)
@persist(persistence)
def step_2(self):
self.state.counter = 2
self.state.message = "Step 2"
if self.state.counter == 1:
self.state.counter = 2
self.state.message = "Step 2"
flow = MultiStepFlow(persistence=persistence)
flow.kickoff()
flow2 = MultiStepFlow(persistence=persistence)
flow2.kickoff(inputs={"id": flow.state.id})
# Load final state
final_state = persistence.load_state(flow.state.id)
final_state = flow2.state
assert final_state is not None
assert final_state["counter"] == 2
assert final_state["message"] == "Step 2"
def test_persistence_error_handling(tmp_path):
"""Test error handling in persistence operations."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
class InvalidFlow(Flow[TestState]):
# Missing id field in initial state
class InvalidState(BaseModel):
value: str = ""
initial_state = InvalidState
assert final_state.counter == 2
assert final_state.message == "Step 2"
class NoPersistenceMultiStepFlow(Flow[TestState]):
@start()
@persist(persistence)
def will_fail(self):
self.state.value = "test"
def step_1(self):
if self.state.counter == 1:
self.state.counter = 99999
self.state.message = "Step 99999"
else:
self.state.counter = 1
self.state.message = "Step 1"
with pytest.raises(ValueError) as exc_info:
flow = InvalidFlow(persistence=persistence)
@listen(step_1)
def step_2(self):
if self.state.counter == 1:
self.state.counter = 2
self.state.message = "Step 2"
assert "must have an 'id' field" in str(exc_info.value)
flow = NoPersistenceMultiStepFlow(persistence=persistence)
flow.kickoff()
flow2 = NoPersistenceMultiStepFlow(persistence=persistence)
flow2.kickoff(inputs={"id": flow.state.id})
# Load final state
final_state = flow2.state
assert final_state.counter == 99999
assert final_state.message == "Step 99999"