Compare commits

..

3 Commits

Author SHA1 Message Date
Gabe
2f5928e4bb fix: only treat interpolatable placeholders as crew inputs 2026-06-09 13:42:42 -03:00
Vini Brasil
703ffe67ee Migrate @listen/@router runtime to read from FlowDefinition (#6084)
Some checks are pending
CodeQL Advanced / Analyze (actions) (push) Waiting to run
CodeQL Advanced / Analyze (python) (push) Waiting to run
Vulnerability Scan / pip-audit (push) Waiting to run
* Migrate @listen/@router runtime to read from FlowDefinition

The runtime now resolves listener conditions, router status, and emit
values from `FlowMethodDefinition` instead of legacy method metadata and
the `_listeners`/`_routers`/`_router_emit` registries.

* Evaluate AND/OR listener conditions over the definition shape via
  `_evaluate_definition_condition`
* Drop the class registries and the `FlowMeta` extraction that built
  them; stop stamping `__trigger_methods__`, `__is_router__`,
  `__router_emit__`, and friends
* `@human_feedback` emit now lives only on its config

* Simplify conditionals DSL
2026-06-09 09:40:30 -07:00
Matt Aitchison
8919026326 feat(storage): pluggable default backends for memory, knowledge, rag, flow (#6079)
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Vulnerability Scan / pip-audit (push) Has been cancelled
Add opt-in extension seams so an application can route memory, knowledge,
RAG, and flow persistence through a custom backend without subclassing or
threading an explicit instance through every construction site -- mirroring
the existing crewai_core.lock_store.set_lock_backend seam.

- memory:    crewai.memory.storage.factory.set_memory_storage_factory
- knowledge: crewai.knowledge.storage.factory.set_knowledge_storage_factory
- rag:       crewai.rag.factory.register_rag_client_factory (provider registry)
- flow:      crewai.flow.persistence.factory.set_flow_persistence_factory

Each construction site consults the registered factory and falls back to the
built-in default when none is set; an explicit instance always wins. Widen
Knowledge.storage and the knowledge source base classes to BaseKnowledgeStorage
(consistent with BaseAgent.knowledge_storage) so any base-interface backend
plugs in. Runtime-free tests cover each seam.
2026-06-08 21:14:13 -05:00
32 changed files with 1163 additions and 1011 deletions

View File

@@ -7,7 +7,6 @@ from copy import copy as shallow_copy
from hashlib import md5
import json
from pathlib import Path
import re
from typing import (
TYPE_CHECKING,
Annotated,
@@ -142,7 +141,10 @@ from crewai.utilities.streaming import (
signal_end,
signal_error,
)
from crewai.utilities.string_utils import sanitize_tool_name
from crewai.utilities.string_utils import (
extract_template_variables,
sanitize_tool_name,
)
from crewai.utilities.task_output_storage_handler import TaskOutputStorageHandler
from crewai.utilities.training_handler import CrewTrainingHandler
@@ -1960,20 +1962,24 @@ class Crew(FlowTrackable, BaseModel):
Scans each task's 'description' + 'expected_output', and each agent's
'role', 'goal', and 'backstory'.
Only placeholders that interpolation can actually fill are returned;
non-identifier expressions such as ``{x if x else "y"}`` are ignored so
they are not surfaced as required inputs (matching interpolation
behavior, see :func:`extract_template_variables`).
Returns a set of all discovered placeholder names.
"""
placeholder_pattern = re.compile(r"\{(.+?)}")
required_inputs: set[str] = set()
for task in self.tasks:
# description and expected_output might contain e.g. {topic}, {user_name}
text = f"{task.description or ''} {task.expected_output or ''}"
required_inputs.update(placeholder_pattern.findall(text))
required_inputs.update(extract_template_variables(text))
for agent in self.agents:
# role, goal, backstory might have placeholders like {role_detail}, etc.
text = f"{agent.role or ''} {agent.goal or ''} {agent.backstory or ''}"
required_inputs.update(placeholder_pattern.findall(text))
required_inputs.update(extract_template_variables(text))
return required_inputs

View File

@@ -15,10 +15,7 @@ from crewai.flow.dsl._human_feedback import (
from crewai.flow.dsl._listen import listen
from crewai.flow.dsl._router import router
from crewai.flow.dsl._start import start
from crewai.flow.dsl._utils import (
build_flow_definition as build_flow_definition,
extract_flow_definition as extract_flow_definition,
)
from crewai.flow.dsl._utils import build_flow_definition as build_flow_definition
__all__ = [

View File

@@ -1,12 +1,4 @@
"""Flow DSL condition primitives.
Type guards, the public ``or_`` / ``and_`` combinators, and the conversions
between runtime conditions, normalized conditions, and the
``FlowDefinitionCondition`` shape stored on a :class:`FlowDefinition`. These are
the lower layer of the DSL: the decorators and the definition builder
(``_utils``) build on top of them, so this module imports nothing from its
siblings.
"""
"""Flow DSL condition primitives."""
from __future__ import annotations
@@ -20,268 +12,75 @@ from crewai.flow.dsl._types import FlowTrigger
from crewai.flow.flow_definition import FlowDefinitionCondition
from crewai.flow.flow_wrappers import (
FlowCondition,
FlowConditions,
SimpleFlowCondition,
FlowConditionType,
)
from crewai.flow.types import FlowMethodName
def _is_non_string_sequence(value: Any) -> bool:
    """Return True when ``value`` is a ``Sequence`` other than ``str``/``bytes``."""
    # str and bytes are themselves Sequences, so they are ruled out first.
    if isinstance(value, (str, bytes)):
        return False
    return isinstance(value, Sequence)
def is_simple_flow_condition(obj: Any) -> TypeIs[SimpleFlowCondition]:
    """Report whether ``obj`` is a two-item tuple of a string and a list.

    This is the ``(condition_type, methods)`` shape; the contents of the list
    are not inspected.
    """
    if not isinstance(obj, tuple) or len(obj) != 2:
        return False
    condition_type, methods = obj
    return isinstance(condition_type, str) and isinstance(methods, list)
def is_flow_condition_dict(obj: Any) -> TypeIs[FlowCondition]:
    """Check whether ``obj`` has the FlowCondition dictionary structure.

    A match is a dict whose ``"type"`` is ``"AND"`` or ``"OR"`` and whose keys
    are limited to ``type``, ``conditions`` and ``methods``. When present,
    ``conditions`` must be a non-string sequence of strings or nested
    condition dicts, and ``methods`` a non-string sequence of strings.
    """
    if not isinstance(obj, dict) or obj.get("type") not in ("AND", "OR"):
        return False

    if "conditions" in obj:
        conditions = obj["conditions"]
        if not _is_non_string_sequence(conditions):
            return False
        # Each member is either a plain name or a recursively valid sub-dict.
        if any(
            not isinstance(member, str)
            and not (isinstance(member, dict) and is_flow_condition_dict(member))
            for member in conditions
        ):
            return False

    if "methods" in obj:
        methods = obj["methods"]
        if not _is_non_string_sequence(methods):
            return False
        if not all(isinstance(name, str) for name in methods):
            return False

    # No keys beyond the three known ones are tolerated.
    return set(obj).issubset({"type", "conditions", "methods"})
def _method_reference_name(value: Any) -> FlowMethodName | None:
    """Return the ``__name__`` of a callable as a ``FlowMethodName``.

    Returns ``None`` when ``value`` is not callable or has no string
    ``__name__``.
    """
    candidate = getattr(value, "__name__", None)
    if not callable(value) or not isinstance(candidate, str):
        return None
    return FlowMethodName(candidate)
def _normalize_condition(
    condition: FlowConditions | FlowCondition | str,
) -> FlowCondition:
    """Convert a condition into a dict with ``type`` and ``conditions`` keys.

    * A string becomes a single-member OR condition.
    * A condition dict that already has ``conditions`` (or has neither
      ``conditions`` nor ``methods``) is returned unchanged; one with only
      ``methods`` has that list copied under ``conditions``.
    * A non-string sequence of strings / condition dicts becomes an OR
      condition over that sequence.

    Raises:
        ValueError: If ``condition`` matches none of the shapes above.
    """
    if isinstance(condition, str):
        return {"type": OR_CONDITION, "conditions": [FlowMethodName(condition)]}

    if is_flow_condition_dict(condition):
        if "conditions" in condition or "methods" not in condition:
            return condition
        # Legacy ``methods`` key: copy its entries under ``conditions``.
        return {"type": condition["type"], "conditions": list(condition["methods"])}

    if _is_non_string_sequence(condition) and all(
        isinstance(member, str) or is_flow_condition_dict(member)
        for member in condition
    ):
        return {"type": OR_CONDITION, "conditions": condition}

    raise ValueError(f"Cannot normalize condition: {condition}")
def _extract_all_methods_recursive(
    condition: str | FlowCondition | dict[str, Any] | list[Any],
    flow: Any | None = None,
) -> list[FlowMethodName]:
    """Flatten every name referenced anywhere inside ``condition``.

    Args:
        condition: A name, a condition dict, or a list of either.
        flow: When given, string names are kept only if they appear in
            ``flow._methods``.

    Returns:
        Names in traversal order; an empty list for unsupported inputs.
    """
    if isinstance(condition, str):
        if flow is None or condition in flow._methods:
            return [FlowMethodName(condition)]
        return []

    if is_flow_condition_dict(condition):
        children = _normalize_condition(condition).get("conditions", [])
    elif isinstance(condition, list):
        children = condition
    else:
        return []

    collected: list[FlowMethodName] = []
    for child in children:
        collected.extend(_extract_all_methods_recursive(child, flow))
    return collected
def _extract_all_methods(
    condition: str | FlowCondition | dict[str, Any] | list[Any],
) -> list[FlowMethodName]:
    """Collect the names a condition directly requires.

    A string yields itself. An AND condition dict yields its direct string
    members only (nested dicts are not descended into); any other condition
    dict yields nothing. A list yields the concatenated results of its items.
    Unsupported inputs yield an empty list.
    """
    if isinstance(condition, str):
        return [FlowMethodName(condition)]

    if is_flow_condition_dict(condition):
        normalized = _normalize_condition(condition)
        if normalized.get("type", OR_CONDITION) != AND_CONDITION:
            return []
        return [
            FlowMethodName(member)
            for member in normalized.get("conditions", [])
            if isinstance(member, str)
        ]

    if isinstance(condition, list):
        collected: list[FlowMethodName] = []
        for entry in condition:
            collected.extend(_extract_all_methods(entry))
        return collected

    return []
def _condition_trigger(condition: FlowTrigger) -> FlowMethodName | FlowCondition:
    """Turn one trigger into a method name or a condition dict.

    Strings are wrapped as ``FlowMethodName``, condition dicts pass through
    unchanged, and callables are replaced by their ``__name__``.

    Raises:
        ValueError: If ``condition`` is none of those.
    """
    if isinstance(condition, str):
        return FlowMethodName(condition)
    if is_flow_condition_dict(condition):
        return condition

    referenced = _method_reference_name(condition)
    if referenced is None:
        raise ValueError("Invalid condition")
    return referenced
def _condition_triggers(
    conditions: Sequence[FlowTrigger],
    error_message: str,
) -> FlowConditions:
    """Convert each trigger with ``_condition_trigger``.

    Raises:
        ValueError: With ``error_message`` (chained to the original error)
            when any trigger is rejected.
    """
    converted = []
    try:
        for trigger in conditions:
            converted.append(_condition_trigger(trigger))
    except ValueError as exc:
        raise ValueError(error_message) from exc
    return converted
def _definition_condition_from_runtime(condition: Any) -> FlowDefinitionCondition:
    """Convert a runtime condition into the ``FlowDefinition`` condition shape.

    * A string is returned as a plain ``str``.
    * A callable with a ``__name__`` is replaced by that name.
    * A condition dict becomes ``{"and": [...]}`` or ``{"or": [...]}`` with
      every member converted recursively.
    * Any other non-string sequence (list, tuple, ...) is treated as an
      implicit OR group. Previously only ``list`` was recognized here, so a
      tuple of triggers fell through to ``str(condition)`` and was stored as
      one bogus event name, even though ``_normalize_condition`` accepts any
      non-string sequence.
    * Anything else is stringified.
    """
    if isinstance(condition, str):
        return str(condition)

    method_name = _method_reference_name(condition)
    if method_name is not None:
        return str(method_name)

    if is_flow_condition_dict(condition):
        normalized = _normalize_condition(condition)
        key = "and" if normalized.get("type") == AND_CONDITION else "or"
        return {
            key: [
                _definition_condition_from_runtime(sub_condition)
                for sub_condition in normalized.get("conditions", [])
            ]
        }

    if _is_non_string_sequence(condition):
        return {"or": [_definition_condition_from_runtime(item) for item in condition]}

    return str(condition)
_CONDITION_TYPES = (AND_CONDITION, OR_CONDITION)
def or_(*triggers: FlowTrigger) -> FlowCondition:
"""Combine multiple triggers with OR logic for flow control.
Creates a condition that is satisfied when any of the specified triggers
are met. This is used with @start, @listen, or @router decorators to create
complex triggering conditions.
Args:
triggers: Route labels, method references, or existing conditions
returned by or_() / and_().
Returns:
A condition dictionary with format {"type": "OR", "conditions": list_of_triggers}.
Raises:
ValueError: If a trigger format is invalid.
Examples:
>>> @listen(or_("success", "timeout"))
>>> def handle_completion(self):
... pass
>>> @listen(or_(and_("step1", "step2"), "step3"))
>>> def handle_nested(self):
... pass
"""
processed_triggers = _condition_triggers(triggers, "Invalid trigger in or_()")
return {"type": OR_CONDITION, "conditions": processed_triggers}
"""Return a condition that fires when any trigger fires."""
return _condition_tree(OR_CONDITION, triggers)
def and_(*triggers: FlowTrigger) -> FlowCondition:
"""Combine multiple triggers with AND logic for flow control.
Creates a condition that is satisfied only when all specified triggers
are met. This is used with @start, @listen, or @router decorators to create
complex triggering conditions.
Args:
triggers: Route labels, method references, or existing conditions
returned by or_() / and_().
Returns:
A condition dictionary with format {"type": "AND", "conditions": list_of_conditions}
where each condition can be a route label, method name, or nested condition.
Raises:
ValueError: If any trigger is invalid.
Examples:
>>> @listen(and_("validated", "processed"))
>>> def handle_complete_data(self):
... pass
>>> @listen(and_(or_("step1", "step2"), "step3"))
>>> def handle_nested(self):
... pass
"""
processed_triggers = _condition_triggers(triggers, "Invalid trigger in and_()")
return {"type": AND_CONDITION, "conditions": processed_triggers}
"""Return a condition that fires after all triggers fire."""
return _condition_tree(AND_CONDITION, triggers)
def _runtime_condition_from_definition(
condition: FlowDefinitionCondition,
) -> FlowMethodName | FlowCondition:
if isinstance(condition, str):
return FlowMethodName(condition)
if is_flow_condition_dict(condition):
return condition
def _trigger_name(value: Any) -> str | None:
if isinstance(value, str):
return value
if "and" in condition:
return {
"type": AND_CONDITION,
"conditions": [
_runtime_condition_from_definition(item)
for item in condition.get("and", [])
],
}
name = getattr(value, "__name__", None)
if callable(value) and isinstance(name, str):
return name
return None
def _is_condition(value: Any) -> TypeIs[FlowCondition]:
    """Check that ``value`` is a well-formed condition dict.

    It must be a dict with exactly the keys ``type`` and ``conditions``, a
    ``type`` listed in ``_CONDITION_TYPES``, and a ``conditions`` list whose
    members are each a nameable trigger or a nested condition.
    """
    if not isinstance(value, dict) or set(value) != {"type", "conditions"}:
        return False
    if value["type"] not in _CONDITION_TYPES:
        return False

    members = value["conditions"]
    if not isinstance(members, list):
        return False
    return all(
        _trigger_name(member) is not None or _is_condition(member)
        for member in members
    )
def _coerce_trigger(trigger: FlowTrigger) -> str | FlowCondition:
    """Reduce a trigger to its name or, failing that, a validated condition.

    Raises:
        ValueError: If ``trigger`` has no name and is not a valid condition.
    """
    resolved = _trigger_name(trigger)
    if resolved is None and not _is_condition(trigger):
        raise ValueError("Invalid condition")
    return trigger if resolved is None else resolved
def _condition_tree(
condition_type: FlowConditionType,
triggers: Sequence[FlowTrigger],
) -> FlowCondition:
return {
"type": OR_CONDITION,
"conditions": [
_runtime_condition_from_definition(item) for item in condition.get("or", [])
],
"type": condition_type,
"conditions": [_coerce_trigger(trigger) for trigger in triggers],
}
def _runtime_listener_condition_from_definition(
    condition: FlowDefinitionCondition,
) -> SimpleFlowCondition | FlowCondition:
    """Build the listener-registry form of a definition condition.

    The condition is first converted with
    ``_runtime_condition_from_definition``. A bare name is wrapped as an
    ``(OR_CONDITION, [name])`` tuple; anything else is returned as converted.
    """
    converted = _runtime_condition_from_definition(condition)
    if not isinstance(converted, str):
        return converted
    return (OR_CONDITION, [FlowMethodName(str(converted))])
def _to_definition_condition(condition: FlowTrigger) -> FlowDefinitionCondition:
    """Convert a DSL trigger into the ``FlowDefinition`` condition shape.

    Named triggers become their name. Condition dicts become a one-key dict
    keyed by the lower-cased condition type, holding the recursively
    converted members.

    Raises:
        ValueError: Propagated from ``_coerce_trigger`` for invalid triggers.
    """
    coerced = _coerce_trigger(condition)
    if isinstance(coerced, str):
        return coerced

    key = coerced["type"].lower()
    members = [_to_definition_condition(member) for member in coerced["conditions"]]
    return {key: members}

View File

@@ -27,13 +27,8 @@ def _stamp_human_feedback_metadata(
config: HumanFeedbackConfig,
) -> None:
for attr in [
"__trigger_methods__",
"__condition_type__",
"__trigger_condition__",
"__is_flow_method__",
"__flow_persistence_config__",
"__is_router__",
"__router_emit__",
"__flow_method_definition__",
]:
if hasattr(func, attr):
@@ -43,8 +38,6 @@ def _stamp_human_feedback_metadata(
wrapper.__is_flow_method__ = True
if config.emit:
wrapper.__is_router__ = True
wrapper.__router_emit__ = list(config.emit)
fragment = getattr(wrapper, "__flow_method_definition__", None)
if isinstance(fragment, FlowMethodDefinition):
wrapper.__flow_method_definition__ = fragment.model_copy(

View File

@@ -3,13 +3,12 @@ from __future__ import annotations
from collections.abc import Callable
from typing import cast
from crewai.flow.dsl._conditions import _definition_condition_from_runtime
from crewai.flow.dsl._conditions import _to_definition_condition
from crewai.flow.dsl._types import FlowMethodDecorator, FlowTrigger
from crewai.flow.dsl._utils import (
P,
R,
_set_flow_method_definition,
_set_trigger_metadata,
)
from crewai.flow.flow_definition import FlowMethodDefinition
from crewai.flow.flow_wrappers import ListenMethod
@@ -46,10 +45,8 @@ def listen(condition: FlowTrigger) -> FlowMethodDecorator:
wrapper = ListenMethod(func)
_set_flow_method_definition(
wrapper,
FlowMethodDefinition(listen=_definition_condition_from_runtime(condition)),
wrapper, FlowMethodDefinition(listen=_to_definition_condition(condition))
)
_set_trigger_metadata(wrapper, condition)
return wrapper
return cast(FlowMethodDecorator, decorator)

View File

@@ -14,13 +14,12 @@ from typing import (
get_type_hints,
)
from crewai.flow.dsl._conditions import _definition_condition_from_runtime
from crewai.flow.dsl._conditions import _to_definition_condition
from crewai.flow.dsl._types import FlowMethodDecorator, FlowTrigger
from crewai.flow.dsl._utils import (
P,
R,
_set_flow_method_definition,
_set_trigger_metadata,
)
from crewai.flow.flow_definition import FlowMethodDefinition
from crewai.flow.flow_wrappers import RouterMethod
@@ -149,18 +148,11 @@ def router(
_set_flow_method_definition(
wrapper,
FlowMethodDefinition(
listen=_definition_condition_from_runtime(condition),
listen=_to_definition_condition(condition),
router=True,
emit=router_events or None,
),
)
_set_trigger_metadata(wrapper, condition)
if emit is not None:
wrapper.__router_emit__ = router_events
elif router_events:
wrapper.__router_emit__ = router_events
return wrapper
return cast(FlowMethodDecorator, decorator)

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from collections.abc import Callable
from typing import cast
from crewai.flow.dsl._conditions import _definition_condition_from_runtime
from crewai.flow.dsl._conditions import _to_definition_condition
from crewai.flow.dsl._types import FlowMethodDecorator, FlowTrigger
from crewai.flow.dsl._utils import (
P,
@@ -56,9 +56,7 @@ def start(
if condition is not None:
_set_flow_method_definition(
wrapper,
FlowMethodDefinition(
start=_definition_condition_from_runtime(condition)
),
FlowMethodDefinition(start=_to_definition_condition(condition)),
)
else:
_set_flow_method_definition(wrapper, FlowMethodDefinition(start=True))

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
from collections.abc import Sequence
import json
import logging
from typing import Any, ParamSpec, TypeVar
@@ -8,19 +7,9 @@ from typing import Any, ParamSpec, TypeVar
from pydantic import BaseModel
from typing_extensions import TypeIs
from crewai.flow.constants import AND_CONDITION, OR_CONDITION
from crewai.flow.dsl._conditions import (
_definition_condition_from_runtime,
_extract_all_methods,
_method_reference_name,
_runtime_listener_condition_from_definition,
is_flow_condition_dict,
)
from crewai.flow.dsl._types import FlowTrigger
from crewai.flow.flow_definition import (
FlowConfigDefinition,
FlowDefinition,
FlowDefinitionCondition,
FlowDefinitionDiagnostic,
FlowHumanFeedbackDefinition,
FlowMethodDefinition,
@@ -29,10 +18,7 @@ from crewai.flow.flow_definition import (
)
from crewai.flow.flow_wrappers import (
FlowMethod,
ListenMethod,
RouterMethod,
)
from crewai.flow.types import FlowMethodName
P = ParamSpec("P")
@@ -45,11 +31,8 @@ _FLOW_METHOD_DEFINITION_ATTR = "__flow_method_definition__"
def is_flow_method(obj: Any) -> TypeIs[FlowMethod[Any, Any]]:
"""Check if the object carries Flow method wrapper metadata."""
return (
hasattr(obj, "__is_flow_method__")
or hasattr(obj, "__trigger_methods__")
or hasattr(obj, "__is_router__")
or hasattr(obj, _FLOW_METHOD_DEFINITION_ATTR)
return hasattr(obj, "__is_flow_method__") or hasattr(
obj, _FLOW_METHOD_DEFINITION_ATTR
)
@@ -59,42 +42,6 @@ def _should_include_flow_method(flow_class: type, method: Any) -> bool:
return True
def _flow_method_names(values: Sequence[Any]) -> list[FlowMethodName]:
    """Stringify each value and wrap it as a ``FlowMethodName``."""
    return list(map(FlowMethodName, map(str, values)))
def _set_trigger_metadata(
    wrapper: ListenMethod[P, R] | RouterMethod[P, R],
    condition: FlowTrigger,
) -> None:
    """Stamp trigger attributes onto a listener or router wrapper.

    Sets ``__trigger_methods__`` and ``__condition_type__`` in every accepted
    case, and additionally ``__trigger_condition__`` when ``condition`` is a
    dict carrying a ``conditions`` key.

    Raises:
        ValueError: If a condition dict has neither ``conditions`` nor
            ``methods``, or if ``condition`` is not a string, condition dict,
            or named callable.
    """
    if isinstance(condition, str):
        trigger_methods = [FlowMethodName(condition)]
        condition_type = OR_CONDITION
    elif is_flow_condition_dict(condition):
        if "conditions" in condition:
            # Keep the full tree so nested conditions remain available.
            wrapper.__trigger_condition__ = condition
            trigger_methods = _extract_all_methods(condition)
        elif "methods" in condition:
            trigger_methods = _flow_method_names(condition["methods"])
        else:
            raise ValueError("Condition dict must contain 'conditions' or 'methods'")
        condition_type = condition["type"]
    else:
        referenced = _method_reference_name(condition)
        if referenced is None:
            raise ValueError(
                "Condition must be a method, string, or a result of or_() or and_()"
            )
        trigger_methods = [referenced]
        condition_type = OR_CONDITION

    wrapper.__trigger_methods__ = trigger_methods
    wrapper.__condition_type__ = condition_type
def _set_flow_method_definition(
wrapper: FlowMethod[P, R],
definition: FlowMethodDefinition,
@@ -236,48 +183,6 @@ def _build_config_definition(
return FlowConfigDefinition(**values)
def _condition_from_method_metadata(method: Any) -> FlowDefinitionCondition | None:
    """Derive a definition condition from attributes stamped on ``method``.

    ``__trigger_condition__`` wins when set. Otherwise
    ``__trigger_methods__`` is combined with ``__condition_type__``
    (defaulting to OR): AND gives ``{"and": names}``, a single OR name is
    returned bare, and several OR names give ``{"or": names}``.

    Returns:
        The condition, or ``None`` when the method has no trigger methods.
    """
    stamped_condition = getattr(method, "__trigger_condition__", None)
    if stamped_condition is not None:
        return _definition_condition_from_runtime(stamped_condition)

    stamped_methods = getattr(method, "__trigger_methods__", None)
    if stamped_methods is None:
        return None

    condition_type = getattr(method, "__condition_type__", OR_CONDITION)
    names = list(map(str, stamped_methods))
    if condition_type == AND_CONDITION:
        return {"and": names}
    return names[0] if len(names) == 1 else {"or": names}
def _flow_method_definition_from_legacy_metadata(method: Any) -> FlowMethodDefinition:
    """Build a ``FlowMethodDefinition`` from attributes stamped on ``method``.

    ``listen`` comes from ``_condition_from_method_metadata``, ``router`` from
    ``__is_router__``, and ``emit`` from a non-empty ``__router_emit__``
    (each value stringified).
    """
    router_flag = bool(getattr(method, "__is_router__", False))
    listen_condition = _condition_from_method_metadata(method)
    definition = FlowMethodDefinition(listen=listen_condition, router=router_flag)

    emitted = getattr(method, "__router_emit__", None)
    if emitted:
        definition.emit = list(map(str, emitted))
    return definition
def _definition_trigger_condition(
    method_definition: FlowMethodDefinition,
) -> FlowDefinitionCondition | None:
    """Pick the condition that triggers a method.

    ``listen`` is preferred when it is not ``None``; otherwise ``start`` is
    used, but only when it is a string or dict (not, e.g., a boolean).
    """
    listen = method_definition.listen
    if listen is not None:
        return listen
    start = method_definition.start
    return start if isinstance(start, (str, dict)) else None
def _build_human_feedback_definition(
method: Any,
diagnostics: list[FlowDefinitionDiagnostic],
@@ -332,13 +237,10 @@ def _build_method_definition(
) -> FlowMethodDefinition:
fragment = _get_flow_method_definition(method)
if fragment is None:
method_definition = _flow_method_definition_from_legacy_metadata(method)
method_definition = FlowMethodDefinition()
else:
method_definition = fragment.model_copy(deep=True)
if bool(getattr(method, "__is_router__", False)):
method_definition.router = True
human_feedback = _build_human_feedback_definition(
method, diagnostics, f"{path}.human_feedback"
)
@@ -352,11 +254,6 @@ def _build_method_definition(
method, diagnostics, f"{path}.persist"
)
router_emit = getattr(method, "__router_emit__", None)
if router_emit and not (human_feedback and human_feedback.emit):
if not method_definition.emit:
method_definition.emit = [str(value) for value in router_emit]
return method_definition
@@ -431,68 +328,3 @@ def build_flow_definition(
) -> FlowDefinition:
"""Build a FlowDefinition from a Python Flow class."""
return _build_flow_definition_from_class(flow_class, namespace)
def extract_flow_definition(
namespace: dict[str, Any],
) -> tuple[list[str], dict[str, Any], set[str], dict[str, Any]]:
"""Extract the structural flow registries from a Python class namespace.

Iterates ``namespace.items()`` and fills three registries keyed by attribute
name: listener conditions, router names, and router emit values.

Returns:
A ``(start_methods, listeners, routers, router_emit)`` tuple.
"""
# NOTE(review): no visible line in this function adds to start_methods, so it
# appears to always be returned empty -- confirm against callers.
start_methods: list[str] = []
listeners: dict[str, Any] = {}
router_emit: dict[str, Any] = {}
routers: set[str] = set()
for attr_name, attr_value in namespace.items():
if is_flow_method(attr_value):
# Definition-based path: read the FlowMethodDefinition attached to the method.
method_definition = _get_flow_method_definition(attr_value)
if method_definition is not None:
condition = _definition_trigger_condition(method_definition)
# Start methods are not registered as listeners.
if condition is not None and not method_definition.is_start:
listeners[attr_name] = _runtime_listener_condition_from_definition(
condition
)
# Either the definition or the legacy attribute can mark a router.
is_router = method_definition.router or bool(
getattr(attr_value, "__is_router__", False)
)
if is_router:
routers.add(attr_name)
# Emit values: the definition is checked first, then __router_emit__, else [].
if method_definition.emit:
router_emit[attr_name] = [
str(value) for value in method_definition.emit
]
elif (
hasattr(attr_value, "__router_emit__")
and attr_value.__router_emit__
):
router_emit[attr_name] = attr_value.__router_emit__
else:
router_emit[attr_name] = []
# NOTE(review): presumably this skips the legacy attribute checks below once a
# definition was found; exact nesting is not recoverable from this view.
continue
# Legacy path: read the stamped __trigger_*__ / __condition_type__ attributes.
if (
hasattr(attr_value, "__trigger_methods__")
and attr_value.__trigger_methods__ is not None
):
methods = attr_value.__trigger_methods__
condition_type = getattr(attr_value, "__condition_type__", OR_CONDITION)
# A stored condition tree is used as-is; otherwise a (type, methods) tuple.
if (
hasattr(attr_value, "__trigger_condition__")
and attr_value.__trigger_condition__ is not None
):
listeners[attr_name] = attr_value.__trigger_condition__
else:
listeners[attr_name] = (condition_type, methods)
if hasattr(attr_value, "__is_router__") and attr_value.__is_router__:
routers.add(attr_name)
if (
hasattr(attr_value, "__router_emit__")
and attr_value.__router_emit__
):
router_emit[attr_name] = attr_value.__router_emit__
else:
router_emit[attr_name] = []
return start_methods, listeners, routers, router_emit

View File

@@ -16,7 +16,6 @@ P = ParamSpec("P")
R = TypeVar("R")
FlowConditionType: TypeAlias = Literal["OR", "AND"]
SimpleFlowCondition: TypeAlias = tuple[FlowConditionType, list[FlowMethodName]]
__all__ = [
"FlowCondition",
@@ -25,7 +24,6 @@ __all__ = [
"FlowMethod",
"ListenMethod",
"RouterMethod",
"SimpleFlowCondition",
"StartMethod",
]
@@ -38,15 +36,13 @@ class FlowCondition(TypedDict, total=False):
Attributes:
type: The type of the condition.
conditions: A sequence of route labels, method names, or nested conditions.
methods: A legacy sequence of route labels or method names.
"""
type: Required[FlowConditionType]
conditions: Sequence[str | FlowMethodName | FlowCondition]
methods: Sequence[str | FlowMethodName]
conditions: Sequence[str | FlowCondition]
FlowConditions: TypeAlias = Sequence[str | FlowMethodName | FlowCondition]
FlowConditions: TypeAlias = Sequence[str | FlowCondition]
class FlowMethod(Generic[P, R]):
@@ -83,8 +79,6 @@ class FlowMethod(Generic[P, R]):
# Preserve flow-related attributes from wrapped method (e.g., from @human_feedback)
for attr in [
"__is_router__",
"__router_emit__",
"__human_feedback_config__",
"__conversational_only__", # gates registration on Flow.conversational
"__flow_persistence_config__",
@@ -162,16 +156,6 @@ class StartMethod(FlowMethod[P, R]):
class ListenMethod(FlowMethod[P, R]):
"""Wrapper for methods marked as flow listeners."""
__trigger_methods__: list[FlowMethodName] | None = None
__condition_type__: FlowConditionType | None = None
__trigger_condition__: FlowCondition | None = None
class RouterMethod(FlowMethod[P, R]):
"""Wrapper for methods marked as flow routers."""
__is_router__: bool = True
__trigger_methods__: list[FlowMethodName] | None = None
__condition_type__: FlowConditionType | None = None
__trigger_condition__: FlowCondition | None = None
__router_emit__: list[str] | None = None

View File

@@ -187,16 +187,12 @@ class HumanFeedbackMethod(FlowMethod[Any, Any]):
"""Wrapper for methods decorated with @human_feedback.
This wrapper extends FlowMethod to add human feedback specific attributes
that are used by FlowMeta for routing and by visualization tools.
used by the FlowDefinition builder and runtime feedback handling.
Attributes:
__is_router__: True when emit is specified, enabling router behavior.
__router_emit__: List of possible outcomes when acting as a router.
__human_feedback_config__: The HumanFeedbackConfig for this method.
"""
__is_router__: bool = False
__router_emit__: list[str] | None = None
__human_feedback_config__: HumanFeedbackConfig | None = None

View File

@@ -35,7 +35,7 @@ from crewai_core.printer import PRINTER
from pydantic import BaseModel
from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
from crewai.flow.persistence.factory import default_flow_persistence
if TYPE_CHECKING:
@@ -67,11 +67,6 @@ def _stamp_persistence_metadata(
_PRESERVED_FLOW_ATTRS: Final[tuple[str, ...]] = (
"__trigger_methods__",
"__condition_type__",
"__trigger_condition__",
"__is_router__",
"__router_emit__",
"__human_feedback_config__",
"__flow_persistence_config__",
"__flow_method_definition__",
@@ -171,7 +166,9 @@ def persist(
Args:
persistence: Optional FlowPersistence implementation to use.
If not provided, uses SQLiteFlowPersistence.
If not provided, uses ``default_flow_persistence()`` (the
registered factory when present, else the built-in SQLite
fallback).
verbose: Whether to log persistence operations. Defaults to False.
Returns:
@@ -190,7 +187,9 @@ def persist(
"""
def decorator(target: type | Callable[..., T]) -> type | Callable[..., T]:
actual_persistence = persistence or SQLiteFlowPersistence()
actual_persistence = (
persistence if persistence is not None else default_flow_persistence()
)
if isinstance(target, type):
_stamp_persistence_metadata(target, actual_persistence, verbose)
@@ -210,10 +209,7 @@ def persist(
for name, method in target.__dict__.items()
if callable(method)
and (
hasattr(method, "__trigger_methods__")
or hasattr(method, "__condition_type__")
or hasattr(method, "__is_flow_method__")
or hasattr(method, "__is_router__")
hasattr(method, "__is_flow_method__")
or hasattr(method, "__flow_method_definition__")
)
}

View File

@@ -0,0 +1,60 @@
"""Pluggable default persistence backend for flows.
By default, ``@persist`` and the flow runtime persist state with
:class:`~crewai.flow.persistence.sqlite.SQLiteFlowPersistence` when no explicit
``persistence=`` is given. Registering a factory via
:func:`set_flow_persistence_factory` lets an application back flow state with a
custom :class:`~crewai.flow.persistence.base.FlowPersistence` -- a database, a
remote service, an in-memory fake for tests -- without passing a
``persistence=`` instance at every ``@persist`` / kickoff site.
This mirrors :func:`crewai_core.lock_store.set_lock_backend`: a one-time,
process-wide setter intended for application startup. Pass ``None`` to restore
the built-in SQLite default. Call :func:`default_flow_persistence` to build the
default backend (the registered factory if any, else SQLite).
"""
from __future__ import annotations
from collections.abc import Callable
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from crewai.flow.persistence.base import FlowPersistence
FlowPersistenceFactory = Callable[[], "FlowPersistence"]
_factory: FlowPersistenceFactory | None = None
def set_flow_persistence_factory(factory: FlowPersistenceFactory | None) -> None:
    """Replace the process-wide default flow persistence factory.

    Intended for one-time setup at startup. Pass ``None`` to restore the
    built-in ``SQLiteFlowPersistence``. Only affects flows that fall back to
    the default; an explicit ``persistence=`` instance always wins.

    The default is resolved at each fall-back site (``@persist`` and the
    runtime's pause/resume paths), so the factory may be called more than once
    for a single flow. Return instances backed by shared durable state (or a
    singleton) so state saved on one call is visible to the next -- the
    built-in SQLite default satisfies this by sharing one on-disk file.

    Args:
        factory: A zero-argument callable returning a ``FlowPersistence``, or
            ``None`` to clear any registered factory.

    Raises:
        TypeError: If ``factory`` is neither ``None`` nor callable. The
            previously registered factory is left untouched in that case.
    """
    global _factory
    # Fail fast: a non-callable (e.g. a persistence *instance* passed by
    # mistake) would otherwise only blow up later, when the default is built.
    if factory is not None and not callable(factory):
        raise TypeError(
            "factory must be a zero-argument callable returning a "
            f"FlowPersistence, or None; got {type(factory).__name__}"
        )
    _factory = factory
def default_flow_persistence() -> FlowPersistence:
    """Build the default flow persistence backend.

    Returns:
        Whatever the registered factory returns when one is set; otherwise a
        new :class:`~crewai.flow.persistence.sqlite.SQLiteFlowPersistence`.
    """
    # Read the module-level factory once so the check and the call below use
    # the same object.
    registered = _factory
    if registered is None:
        # Imported here rather than at module top, as in the original.
        from crewai.flow.persistence.sqlite import SQLiteFlowPersistence

        return SQLiteFlowPersistence()
    return registered()

View File

@@ -89,27 +89,17 @@ from crewai.experimental.conversational import (
ConversationState,
)
from crewai.experimental.conversational_mixin import _ConversationalMixin
from crewai.flow.constants import AND_CONDITION, OR_CONDITION
from crewai.flow.dsl._conditions import (
_extract_all_methods,
_extract_all_methods_recursive,
_normalize_condition,
_runtime_listener_condition_from_definition,
is_flow_condition_dict,
is_simple_flow_condition,
)
from crewai.flow.dsl._utils import (
build_flow_definition,
extract_flow_definition,
)
from crewai.flow.dsl._utils import build_flow_definition
from crewai.flow.flow_context import current_flow_id, current_flow_request_id
from crewai.flow.flow_definition import FlowDefinition, FlowDefinitionCondition
from crewai.flow.flow_definition import (
FlowDefinition,
FlowDefinitionCondition,
FlowMethodDefinition,
)
from crewai.flow.flow_wrappers import (
FlowCondition,
FlowMethod,
ListenMethod,
RouterMethod,
SimpleFlowCondition,
StartMethod,
)
from crewai.flow.human_feedback import HumanFeedbackResult
@@ -164,6 +154,25 @@ ExecutionContext = Any # type: ignore[assignment,misc]
logger = logging.getLogger(__name__)
def _iter_condition_events(condition: FlowDefinitionCondition) -> Iterator[str]:
    """Yield every leaf event name in ``condition``, depth-first, left to right.

    A string condition is a single leaf. A dict condition is walked through
    its ``"and"`` list when that key is present, otherwise through its
    ``"or"`` list. Duplicates are not removed.
    """
    if not isinstance(condition, str):
        branch_key = "and" if "and" in condition else "or"
        for child in condition[branch_key]:
            yield from _iter_condition_events(child)
    else:
        yield condition
def _is_multi_event_or(
    condition: FlowDefinitionCondition,
) -> bool:
    """Report whether ``condition`` is a top-level ``or`` with several branches.

    Only the direct entries of the ``"or"`` list are counted, not the leaf
    events nested beneath them. A bare string condition is never multi-event.
    """
    return (
        not isinstance(condition, str)
        and "or" in condition
        and len(condition["or"]) > 1
    )
def _resolve_persistence(value: Any) -> Any:
if value is None or isinstance(value, FlowPersistence):
return value
@@ -601,18 +610,10 @@ class FlowMeta(ModelMetaclass):
annotations[attr_name] = ClassVar[type(attr_value)]
namespace["__annotations__"] = annotations
cls = super().__new__(mcs, name, bases, namespace)
_, listeners, routers, router_emit = extract_flow_definition(namespace)
cls._listeners = listeners # type: ignore[attr-defined]
cls._routers = routers # type: ignore[attr-defined]
cls._router_emit = router_emit # type: ignore[attr-defined]
# The static FlowDefinition is built lazily (on first access via
# ``Flow.flow_definition()`` or visualization), not at class-definition
# time, to avoid AST parsing and diagnostic logging on every import.
return cls
return super().__new__(mcs, name, bases, namespace)
class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
@@ -627,9 +628,6 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
)
__hash__ = object.__hash__
_listeners: ClassVar[dict[FlowMethodName, SimpleFlowCondition | FlowCondition]] = {}
_routers: ClassVar[set[FlowMethodName]] = set()
_router_emit: ClassVar[dict[FlowMethodName, list[FlowMethodName]]] = {}
_flow_definition: ClassVar[FlowDefinition | None] = None
# === EXPERIMENTAL: conversational mode ===
@@ -677,7 +675,7 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
return flow_definition
@classmethod
def _definition_start_method_names(cls) -> list[FlowMethodName]:
def _start_method_names(cls) -> list[FlowMethodName]:
return [
FlowMethodName(method_name)
for method_name, method_definition in cls.flow_definition().methods.items()
@@ -685,21 +683,39 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
]
@classmethod
def _definition_start_condition(
def _listener_methods(
    cls,
) -> Iterator[tuple[FlowMethodName, FlowMethodDefinition, FlowDefinitionCondition]]:
    """Yield ``(name, definition, condition)`` for each non-start listening method.

    Router methods are yielded too, because they also listen. A caller that
    wants only plain listeners should filter on ``definition.router``.
    """
    for name, definition in cls.flow_definition().methods.items():
        # Skip methods with no listen condition, and start methods.
        if definition.listen is None or definition.is_start:
            continue
        yield FlowMethodName(name), definition, definition.listen
@classmethod
def _start_condition(
cls, method_name: FlowMethodName
) -> FlowDefinitionCondition | None:
method_definition = cls.flow_definition().methods.get(str(method_name))
if method_definition is None:
return None
method_definition = cls.flow_definition().methods[str(method_name)]
start = method_definition.start
if isinstance(start, (str, dict)):
return start
return None
@classmethod
def _definition_has_start(cls, method_name: FlowMethodName) -> bool:
method_definition = cls.flow_definition().methods.get(str(method_name))
return bool(method_definition and method_definition.is_start)
def _listen_condition(
    cls, method_name: FlowMethodName
) -> FlowDefinitionCondition | None:
    """Return the ``listen`` condition stored for ``method_name``.

    The method is looked up by subscript in the flow definition's
    ``methods`` mapping, so an unknown name raises (``KeyError`` for a plain
    dict) instead of returning ``None``.
    """
    definition = cls.flow_definition().methods[str(method_name)]
    return definition.listen
@classmethod
def _is_router(cls, method_name: FlowMethodName) -> bool:
    """Return the ``router`` flag stored for ``method_name``.

    The method is looked up by subscript in the flow definition's
    ``methods`` mapping, so an unknown name raises (``KeyError`` for a plain
    dict) instead of returning ``False``.
    """
    definition = cls.flow_definition().methods[str(method_name)]
    return definition.router
initial_state: Annotated[ # type: ignore[type-arg]
type[BaseModel] | type[dict] | dict[str, Any] | BaseModel | None,
@@ -848,10 +864,13 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
_method_execution_counts: dict[FlowMethodName, int] = PrivateAttr(
default_factory=dict
)
_pending_and_listeners: dict[PendingListenerKey, set[FlowMethodName]] = PrivateAttr(
_pending_and_listeners: dict[PendingListenerKey, set[int]] = PrivateAttr(
default_factory=dict
)
_fired_or_listeners: set[FlowMethodName] = PrivateAttr(default_factory=set)
_racing_groups_cache: dict[frozenset[FlowMethodName], FlowMethodName] | None = (
PrivateAttr(default=None)
)
_method_outputs: list[Any] = PrivateAttr(default_factory=list)
_state_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
_or_listeners_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
@@ -992,22 +1011,6 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
result: list[str] = self.memory.extract_memories(content)
return result
def _mark_or_listener_fired(self, listener_name: FlowMethodName) -> bool:
"""Mark an OR listener as fired atomically.
Args:
listener_name: The name of the OR listener to mark.
Returns:
True if this call was the first to fire the listener.
False if the listener was already fired.
"""
with self._or_listeners_lock:
if listener_name in self._fired_or_listeners:
return False
self._fired_or_listeners.add(listener_name)
return True
def _clear_or_listeners(self) -> None:
"""Clear fired OR listeners for cyclic flows."""
with self._or_listeners_lock:
@@ -1021,25 +1024,11 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
def _start_condition_triggered_by(
self, method_name: FlowMethodName, trigger: FlowMethodName
) -> bool:
condition = type(self)._definition_start_condition(method_name)
condition = type(self)._start_condition(method_name)
if condition is None:
return False
condition_data = _runtime_listener_condition_from_definition(condition)
if is_simple_flow_condition(condition_data):
condition_type, methods = condition_data
if condition_type == OR_CONDITION:
return trigger in methods
pending_key = PendingListenerKey(method_name)
if pending_key not in self._pending_and_listeners:
self._pending_and_listeners[pending_key] = set(methods)
if trigger in self._pending_and_listeners[pending_key]:
self._pending_and_listeners[pending_key].discard(trigger)
if not self._pending_and_listeners[pending_key]:
self._pending_and_listeners.pop(pending_key, None)
return True
return False
return self._evaluate_condition(
condition_data,
condition,
trigger,
method_name,
pending_key_prefix=f"start:{method_name}",
@@ -1050,18 +1039,9 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
trigger: FlowMethodName,
rearmable: set[FlowMethodName] | None = None,
) -> None:
"""Re-arm fired OR listeners whose condition includes ``trigger``.
Called when a router emits a fresh signal so cyclic flows can re-fire
multi-source ``or_`` listeners. Listeners whose condition does not
reference the trigger are left fired.
Args:
trigger: The signal/method name a router just emitted.
rearmable: Optional set restricting which listeners may be re-armed.
When provided, listeners outside this set are skipped, and any
listener re-armed is removed from it.
"""
# When a router emits a fresh signal, re-arm fired multi-event or_()
# listeners that reference the trigger so cyclic flows can re-fire them.
# A given rearmable set, when passed, bounds which listeners may re-arm.
with self._or_listeners_lock:
if not self._fired_or_listeners:
return
@@ -1075,87 +1055,60 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
trigger_str = str(trigger)
to_discard: list[FlowMethodName] = []
for listener_name in candidates:
condition_data = self._listeners.get(listener_name)
if condition_data is None:
condition = type(self)._listen_condition(listener_name)
if condition is None:
continue
if is_simple_flow_condition(condition_data):
_, methods = condition_data
if trigger in methods or trigger_str in {str(m) for m in methods}:
to_discard.append(listener_name)
elif is_flow_condition_dict(condition_data):
all_methods = _extract_all_methods_recursive(condition_data)
if trigger_str in {str(m) for m in all_methods}:
to_discard.append(listener_name)
if trigger_str in _iter_condition_events(condition):
to_discard.append(listener_name)
for listener_name in to_discard:
self._fired_or_listeners.discard(listener_name)
if rearmable is not None:
rearmable.discard(listener_name)
def _build_racing_groups(self) -> dict[frozenset[FlowMethodName], FlowMethodName]:
"""Identify groups of methods that race for the same OR listener.
Analyzes the flow graph to find listeners with OR conditions that have
multiple trigger methods. These trigger methods form a "racing group"
where only the first to complete should trigger the OR listener.
Only methods that are EXCLUSIVELY sources for the OR listener are included
in the racing group. Methods that are also triggers for other listeners
(e.g., AND conditions) are not cancelled when another racing source wins.
Returns:
Dictionary mapping frozensets of racing method names to their
shared OR listener name.
Example:
If we have `@listen(or_(method_a, method_b))` on `handler`,
and method_a/method_b aren't used elsewhere,
this returns: {frozenset({'method_a', 'method_b'}): 'handler'}
"""
# Events of a multi-event or_() listener race: only the first to fire
# should trigger it. We map {frozenset(racing events): listener}.
# Only events that EXCLUSIVELY feed one OR listener race; an event that
# also feeds another listener (e.g. an AND) is left alone when a sibling
# wins. e.g. @listen(or_(a, b)) on handler -> {frozenset({a, b}): handler}.
racing_groups: dict[frozenset[FlowMethodName], FlowMethodName] = {}
listener_conditions: dict[FlowMethodName, FlowDefinitionCondition] = {
listener_name: condition
for listener_name, method_definition, condition in type(
self
)._listener_methods()
if not method_definition.router
}
method_to_listeners: dict[FlowMethodName, set[FlowMethodName]] = {}
for listener_name, condition_data in self._listeners.items():
if is_simple_flow_condition(condition_data):
_, methods = condition_data
for m in methods:
method_to_listeners.setdefault(m, set()).add(listener_name)
elif is_flow_condition_dict(condition_data):
all_methods = _extract_all_methods_recursive(condition_data)
for m in all_methods:
method_name = FlowMethodName(m) if isinstance(m, str) else m
method_to_listeners.setdefault(method_name, set()).add(
listener_name
)
events_by_listener: dict[FlowMethodName, set[str]] = {
listener_name: set(_iter_condition_events(condition))
for listener_name, condition in listener_conditions.items()
}
for listener_name, condition_data in self._listeners.items():
if listener_name in self._routers:
listeners_by_event: dict[str, set[FlowMethodName]] = {}
for listener_name, events in events_by_listener.items():
for event in events:
listeners_by_event.setdefault(event, set()).add(listener_name)
for listener_name, condition in listener_conditions.items():
if not isinstance(condition, dict):
continue
events = events_by_listener[listener_name]
if "or" not in condition or len(events) <= 1:
continue
trigger_methods: set[FlowMethodName] = set()
if is_simple_flow_condition(condition_data):
condition_type, methods = condition_data
if condition_type == OR_CONDITION and len(methods) > 1:
trigger_methods = set(methods)
elif is_flow_condition_dict(condition_data):
top_level_type = condition_data.get("type", OR_CONDITION)
if top_level_type == OR_CONDITION:
all_methods = _extract_all_methods_recursive(condition_data)
if len(all_methods) > 1:
trigger_methods = set(
FlowMethodName(m) if isinstance(m, str) else m
for m in all_methods
)
if trigger_methods:
exclusive_methods = {
m
for m in trigger_methods
if method_to_listeners.get(m, set()) == {listener_name}
}
if len(exclusive_methods) > 1:
racing_groups[frozenset(exclusive_methods)] = listener_name
exclusive_events = {
event
for event in events
if listeners_by_event.get(event, set()) == {listener_name}
}
if len(exclusive_events) > 1:
# Racing only applies to method-completion events: each member is
# later executed as a method and intersected with the running
# method names, so the leaves re-enter method space here.
racing_groups[
frozenset(FlowMethodName(event) for event in exclusive_events)
] = listener_name
return racing_groups
@@ -1172,16 +1125,15 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
Tuple of (racing_members, or_listener_name) if these listeners race,
None otherwise.
"""
if not hasattr(self, "_racing_groups_cache"):
if self._racing_groups_cache is None:
self._racing_groups_cache = self._build_racing_groups()
listener_set = set(listener_names)
for racing_members, or_listener in self._racing_groups_cache.items():
if racing_members & listener_set:
racing_subset = racing_members & listener_set
if len(racing_subset) > 1:
return (frozenset(racing_subset), or_listener)
racing_subset = racing_members & listener_set
if len(racing_subset) > 1:
return (frozenset(racing_subset), or_listener)
return None
@@ -1252,7 +1204,9 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
Args:
flow_id: The unique identifier of the paused flow (from state.id)
persistence: The persistence backend where the state was saved.
If not provided, defaults to SQLiteFlowPersistence().
If not provided, uses ``default_flow_persistence()`` (the
registered factory when present, else the built-in SQLite
fallback).
**kwargs: Additional keyword arguments passed to the Flow constructor
Returns:
@@ -1274,9 +1228,9 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
```
"""
if persistence is None:
from crewai.flow.persistence import SQLiteFlowPersistence
from crewai.flow.persistence.factory import default_flow_persistence
persistence = SQLiteFlowPersistence()
persistence = default_flow_persistence()
loaded = persistence.load_pending_feedback(flow_id)
if loaded is None:
@@ -1463,7 +1417,7 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
self._pending_feedback_context = None
if self.persistence:
if self.persistence is not None:
self.persistence.clear_pending_feedback(context.flow_id)
crewai_event_bus.emit(
@@ -1505,9 +1459,9 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
self._pending_feedback_context = e.context
if self.persistence is None:
from crewai.flow.persistence import SQLiteFlowPersistence
from crewai.flow.persistence.factory import default_flow_persistence
self.persistence = SQLiteFlowPersistence()
self.persistence = default_flow_persistence()
state_data = (
self._state
@@ -2221,11 +2175,11 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
# Determine which start methods to execute at kickoff
# Conditional start methods are only triggered by their conditions
# UNLESS there are no unconditional starts (then all starts run as entry points)
start_methods = type(self)._definition_start_method_names()
start_methods = type(self)._start_method_names()
unconditional_starts = [
start_method
for start_method in start_methods
if type(self)._definition_start_condition(start_method) is None
if type(self)._start_condition(start_method) is None
]
# If there are unconditional starts, only run those at kickoff
# If there are NO unconditional starts, run all starts (including conditional ones)
@@ -2244,9 +2198,11 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
if isinstance(e, HumanFeedbackPending):
# Auto-save pending feedback (create default persistence if needed)
if self.persistence is None:
from crewai.flow.persistence import SQLiteFlowPersistence
from crewai.flow.persistence.factory import (
default_flow_persistence,
)
self.persistence = SQLiteFlowPersistence()
self.persistence = default_flow_persistence()
state_data = (
self._state
@@ -2448,11 +2404,12 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
)
# If start method is a router, use its result as an additional trigger
if start_method_name in self._routers and result is not None:
if type(self)._is_router(start_method_name) and result is not None:
# Execute listeners for the start method name first
await self._execute_listeners(start_method_name, result, finished_event_id)
# Then execute listeners for the router result (e.g., "approved")
router_result_trigger = FlowMethodName(str(result))
router_result = result.value if isinstance(result, enum.Enum) else result
router_result_trigger = FlowMethodName(str(router_result))
listener_result = (
self.last_human_feedback
if self.last_human_feedback is not None
@@ -2597,9 +2554,9 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
e.context.method_name = method_name
if self.persistence is None:
from crewai.flow.persistence import SQLiteFlowPersistence
from crewai.flow.persistence.factory import default_flow_persistence
self.persistence = SQLiteFlowPersistence()
self.persistence = default_flow_persistence()
# Emit paused event (not failed)
if not self.suppress_flow_events:
@@ -2693,27 +2650,24 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
) = await self._execute_single_listener(
router_name, router_input, current_triggering_event_id
)
if router_result: # Only add non-None results
router_result_str = (
router_result.value
if isinstance(router_result, enum.Enum)
else str(router_result)
)
router_results.append(FlowMethodName(router_result_str))
# If this was a human_feedback router, map the outcome to the feedback
if self.last_human_feedback is not None:
router_result_to_feedback[router_result_str] = (
self.last_human_feedback
)
current_trigger = (
FlowMethodName(
router_result.value
if isinstance(router_result, enum.Enum)
else str(router_result)
)
if router_result is not None
else FlowMethodName("")
if router_result is None:
current_trigger = FlowMethodName("")
continue
router_result = (
router_result.value
if isinstance(router_result, enum.Enum)
else router_result
)
router_result_str = str(router_result)
router_result_event = FlowMethodName(router_result_str)
router_results.append(router_result_event)
if self.last_human_feedback is not None:
router_result_to_feedback[router_result_str] = (
self.last_human_feedback
)
current_trigger = router_result_event
all_triggers = [trigger_method, *router_results]
@@ -2759,7 +2713,7 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
await asyncio.gather(*tasks)
if current_trigger in router_results:
for method_name in type(self)._definition_start_method_names():
for method_name in type(self)._start_method_names():
if self._start_condition_triggered_by(
method_name, current_trigger
):
@@ -2774,165 +2728,86 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
def _evaluate_condition(
self,
condition: str | FlowMethodName | FlowCondition,
condition: FlowDefinitionCondition,
trigger_method: FlowMethodName,
listener_name: FlowMethodName,
pending_key_prefix: str | None = None,
) -> bool:
"""Recursively evaluate a condition (simple or nested).
Args:
condition: Can be a string (method name) or dict (nested condition)
trigger_method: The method that just completed
listener_name: Name of the listener being evaluated
Returns:
True if the condition is satisfied, False otherwise
"""
if isinstance(condition, str):
return condition == trigger_method
return condition == str(trigger_method)
def _sub_prefix(index: int) -> str | None:
if pending_key_prefix is None:
return None
return f"{pending_key_prefix}:{index}"
if is_flow_condition_dict(condition):
normalized = _normalize_condition(condition)
cond_type = normalized.get("type", OR_CONDITION)
sub_conditions = normalized.get("conditions", [])
if "or" in condition:
# Evaluate every sub-condition (no short-circuit): a nested and_()
# branch needs the chance to clear its pending state in
# _pending_and_listeners even when an earlier branch already matched.
any_matched = False
for index, sub_condition in enumerate(condition["or"]):
if self._evaluate_condition(
sub_condition,
trigger_method,
listener_name,
pending_key_prefix=_sub_prefix(index),
):
any_matched = True
return any_matched
if cond_type == OR_CONDITION:
return any(
self._evaluate_condition(
sub_cond,
trigger_method,
listener_name,
pending_key_prefix=_sub_prefix(index),
)
for index, sub_cond in enumerate(sub_conditions)
)
sub_conditions = condition["and"]
pending_key = PendingListenerKey(
pending_key_prefix
if pending_key_prefix is not None
else f"{listener_name}:{id(condition)}"
)
if cond_type == AND_CONDITION:
pending_key = PendingListenerKey(
pending_key_prefix
if pending_key_prefix is not None
else f"{listener_name}:{id(condition)}"
)
if pending_key not in self._pending_and_listeners:
self._pending_and_listeners[pending_key] = set(range(len(sub_conditions)))
if pending_key not in self._pending_and_listeners:
all_methods = set(_extract_all_methods(condition))
self._pending_and_listeners[pending_key] = all_methods
pending_conditions = self._pending_and_listeners[pending_key]
for index, sub_condition in enumerate(sub_conditions):
if index not in pending_conditions:
continue
if self._evaluate_condition(
sub_condition,
trigger_method,
listener_name,
pending_key_prefix=_sub_prefix(index),
):
pending_conditions.discard(index)
if trigger_method in self._pending_and_listeners[pending_key]:
self._pending_and_listeners[pending_key].discard(trigger_method)
direct_methods_satisfied = not self._pending_and_listeners[pending_key]
nested_conditions_satisfied = all(
(
self._evaluate_condition(
sub_cond,
trigger_method,
listener_name,
pending_key_prefix=_sub_prefix(index),
)
if is_flow_condition_dict(sub_cond)
else True
)
for index, sub_cond in enumerate(sub_conditions)
)
if direct_methods_satisfied and nested_conditions_satisfied:
self._pending_and_listeners.pop(pending_key, None)
return True
return False
if not pending_conditions:
self._pending_and_listeners.pop(pending_key, None)
return True
return False
def _find_triggered_methods(
self, trigger_method: FlowMethodName, router_only: bool
) -> list[FlowMethodName]:
"""Finds all methods that should be triggered based on conditions.
This internal method evaluates both OR and AND conditions to determine
which methods should be executed next in the flow. Supports nested conditions.
Args:
trigger_method: The name of the method that just completed execution.
router_only: If True, only consider router methods. If False, only consider non-router methods.
Returns:
Names of methods that should be triggered.
Note:
- Handles both OR and AND conditions, including nested combinations
- Maintains state for AND conditions using _pending_and_listeners
- Separates router and normal listener evaluation
"""
triggered: list[FlowMethodName] = []
for listener_name, condition_data in self._listeners.items():
is_router = listener_name in self._routers
for listener_name, method_definition, condition in type(
self
)._listener_methods():
is_router = method_definition.router
if router_only != is_router:
continue
if not router_only and type(self)._definition_has_start(listener_name):
should_check_fired = _is_multi_event_or(condition) and not is_router
if should_check_fired and listener_name in self._fired_or_listeners:
continue
if is_simple_flow_condition(condition_data):
condition_type, methods = condition_data
if condition_type == OR_CONDITION:
# Only trigger multi-source OR listeners (or_(A, B, C)) once - skip if already fired
# Simple single-method listeners fire every time their trigger occurs
# Routers also fire every time - they're decision points
has_multiple_triggers = len(methods) > 1
should_check_fired = has_multiple_triggers and not is_router
if (
not should_check_fired
or listener_name not in self._fired_or_listeners
):
if trigger_method in methods:
triggered.append(listener_name)
# Only track multi-source OR listeners (not single-method or routers)
if should_check_fired:
self._fired_or_listeners.add(listener_name)
elif condition_type == AND_CONDITION:
pending_key = PendingListenerKey(listener_name)
if pending_key not in self._pending_and_listeners:
self._pending_and_listeners[pending_key] = set(methods)
if trigger_method in self._pending_and_listeners[pending_key]:
self._pending_and_listeners[pending_key].discard(trigger_method)
if not self._pending_and_listeners[pending_key]:
triggered.append(listener_name)
self._pending_and_listeners.pop(pending_key, None)
elif is_flow_condition_dict(condition_data):
# For complex conditions, check if top-level is OR and track accordingly
top_level_type = condition_data.get("type", OR_CONDITION)
is_or_based = top_level_type == OR_CONDITION
# Only track multi-source OR conditions (multiple sub-conditions), not routers
sub_conditions = condition_data.get("conditions", [])
has_multiple_triggers = is_or_based and len(sub_conditions) > 1
should_check_fired = has_multiple_triggers and not is_router
# Skip compound OR-based listeners that have already fired
if should_check_fired and listener_name in self._fired_or_listeners:
continue
if self._evaluate_condition(
condition_data, trigger_method, listener_name
):
triggered.append(listener_name)
# Track compound OR-based listeners so they only fire once
if should_check_fired:
self._fired_or_listeners.add(listener_name)
if self._evaluate_condition(
condition,
trigger_method,
listener_name,
):
triggered.append(listener_name)
if should_check_fired:
self._fired_or_listeners.add(listener_name)
return triggered
@@ -2984,13 +2859,10 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
# For routers, also check if any conditional starts they triggered are completed
# If so, continue their chains
if listener_name in self._routers:
for start_method_name in type(
self
)._definition_start_method_names():
if type(self)._is_router(listener_name):
for start_method_name in type(self)._start_method_names():
if (
type(self)._definition_start_condition(start_method_name)
is not None
type(self)._start_condition(start_method_name) is not None
and start_method_name in self._completed_methods
):
# This conditional start was executed, continue its chain

View File

@@ -5,15 +5,7 @@ the Flow system.
"""
from datetime import datetime
from typing import (
Annotated,
Any,
NewType,
ParamSpec,
Protocol,
TypeVar,
TypedDict,
)
from typing import Annotated, Any, NewType, ParamSpec, Protocol, TypeVar, TypedDict
from typing_extensions import NotRequired, Required

View File

@@ -13,6 +13,7 @@ from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSourc
from crewai.knowledge.source.text_file_knowledge_source import (
TextFileKnowledgeSource,
)
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.types import EmbedderConfig
@@ -89,7 +90,7 @@ class Knowledge(BaseModel):
Knowledge is a collection of sources and setup for the vector store to save and query relevant context.
Args:
sources: list[BaseKnowledgeSource] = Field(default_factory=list)
storage: KnowledgeStorage | None = Field(default=None)
storage: BaseKnowledgeStorage | None = Field(default=None)
embedder: EmbedderConfig | None = None
"""
@@ -98,7 +99,7 @@ class Knowledge(BaseModel):
BeforeValidator(_resolve_knowledge_sources),
] = Field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
storage: KnowledgeStorage | None = Field(default=None)
storage: BaseKnowledgeStorage | None = Field(default=None)
embedder: Annotated[
EmbedderConfig | None,
PlainSerializer(
@@ -112,15 +113,22 @@ class Knowledge(BaseModel):
collection_name: str,
sources: list[BaseKnowledgeSource],
embedder: EmbedderConfig | None = None,
storage: KnowledgeStorage | None = None,
storage: BaseKnowledgeStorage | None = None,
**data: object,
) -> None:
super().__init__(**data)
if storage:
if storage is not None:
self.storage = storage
else:
self.storage = KnowledgeStorage(
embedder=embedder, collection_name=collection_name
from crewai.knowledge.storage.factory import resolve_knowledge_storage
custom = resolve_knowledge_storage(embedder, collection_name)
self.storage = (
custom
if custom is not None
else KnowledgeStorage(
embedder=embedder, collection_name=collection_name
)
)
self.sources = sources
@@ -152,10 +160,9 @@ class Knowledge(BaseModel):
raise e
def reset(self) -> None:
if self.storage:
self.storage.reset()
else:
if self.storage is None:
raise ValueError("Storage is not initialized.")
self.storage.reset()
async def aquery(
self, query: list[str], results_limit: int = 5, score_threshold: float = 0.6
@@ -193,7 +200,6 @@ class Knowledge(BaseModel):
async def areset(self) -> None:
"""Reset the knowledge base asynchronously."""
if self.storage:
await self.storage.areset()
else:
if self.storage is None:
raise ValueError("Storage is not initialized.")
await self.storage.areset()

View File

@@ -5,7 +5,7 @@ from typing import Any
from pydantic import Field, field_validator
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
from crewai.utilities.logger import Logger
@@ -22,7 +22,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
default_factory=list, description="The path to the file"
)
content: dict[Path, str] = Field(init=False, default_factory=dict)
storage: KnowledgeStorage | None = Field(default=None)
storage: BaseKnowledgeStorage | None = Field(default=None)
safe_file_paths: list[Path] = Field(default_factory=list)
@field_validator("file_path", "file_paths", mode="before")
@@ -70,14 +70,14 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
def _save_documents(self) -> None:
"""Save the documents to the storage."""
if self.storage:
if self.storage is not None:
self.storage.save(self.chunks)
else:
raise ValueError("No storage found to save documents.")
async def _asave_documents(self) -> None:
"""Save the documents to the storage asynchronously."""
if self.storage:
if self.storage is not None:
await self.storage.asave(self.chunks)
else:
raise ValueError("No storage found to save documents.")

View File

@@ -4,9 +4,15 @@ from typing import Any
import numpy as np
from pydantic import BaseModel, ConfigDict, Field
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
# ``KnowledgeStorage`` is re-exported for backwards compatibility; the ``storage``
# field below is typed to the base interface so any backend plugs in.
__all__ = ["BaseKnowledgeSource", "KnowledgeStorage"]
class BaseKnowledgeSource(BaseModel, ABC):
"""Abstract base class for knowledge sources."""
@@ -18,7 +24,7 @@ class BaseKnowledgeSource(BaseModel, ABC):
)
model_config = ConfigDict(arbitrary_types_allowed=True)
storage: KnowledgeStorage | None = Field(default=None)
storage: BaseKnowledgeStorage | None = Field(default=None)
metadata: dict[str, Any] = Field(default_factory=dict) # Currently unused
collection_name: str | None = Field(default=None)
@@ -49,7 +55,7 @@ class BaseKnowledgeSource(BaseModel, ABC):
Raises:
ValueError: If no storage is configured.
"""
if self.storage:
if self.storage is not None:
self.storage.save(self.chunks)
else:
raise ValueError("No storage found to save documents.")
@@ -66,7 +72,7 @@ class BaseKnowledgeSource(BaseModel, ABC):
Raises:
ValueError: If no storage is configured.
"""
if self.storage:
if self.storage is not None:
await self.storage.asave(self.chunks)
else:
raise ValueError("No storage found to save documents.")

View File

@@ -0,0 +1,56 @@
"""Pluggable default storage backend for knowledge collections.
By default, :class:`~crewai.knowledge.knowledge.Knowledge` builds a
:class:`~crewai.knowledge.storage.knowledge_storage.KnowledgeStorage` when no
explicit ``storage=`` is given. Registering a factory via
:func:`set_knowledge_storage_factory` lets an application back knowledge with a
custom :class:`~crewai.knowledge.storage.base_knowledge_storage.BaseKnowledgeStorage`
without subclassing ``Knowledge`` or passing a ``storage=`` instance at every
call site.
This mirrors :func:`crewai_core.lock_store.set_lock_backend`: a one-time,
process-wide setter intended for application startup. Pass ``None`` to restore
the built-in default.
"""
from __future__ import annotations
from collections.abc import Callable
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
from crewai.rag.embeddings.types import EmbedderConfig
# Receives the same inputs as the built-in default -- the embedder config and
# collection name -- and returns a storage backend, or ``None`` to defer to the
# built-in ``KnowledgeStorage``.
KnowledgeStorageFactory = Callable[
["EmbedderConfig | None", "str | None"], "BaseKnowledgeStorage | None"
]
_factory: KnowledgeStorageFactory | None = None
def set_knowledge_storage_factory(factory: KnowledgeStorageFactory | None) -> None:
    """Install or clear the process-wide default knowledge storage factory.

    Meant to be called once while the application starts up. The value is
    kept in module state and is what :func:`resolve_knowledge_storage`
    consults afterwards.

    Args:
        factory: Callable that receives the embedder config and collection
            name and returns a storage backend. Pass ``None`` to drop any
            registered factory and restore the built-in ``KnowledgeStorage``.

    Note:
        Only ``Knowledge`` instances constructed afterwards are affected; an
        explicit ``storage=`` instance always wins.
    """
    global _factory
    _factory = factory
def resolve_knowledge_storage(
    embedder: EmbedderConfig | None, collection_name: str | None
) -> BaseKnowledgeStorage | None:
    """Ask the registered factory, if any, for a knowledge storage backend.

    Args:
        embedder: Embedder configuration forwarded to the factory unchanged.
        collection_name: Collection name forwarded to the factory unchanged.

    Returns:
        Whatever the registered factory returns, or ``None`` when no factory
        is registered. A ``None`` result tells the caller to fall back to the
        built-in ``KnowledgeStorage``.
    """
    # Read the module global once so the None check and the call use the
    # same object.
    registered = _factory
    if registered is None:
        return None
    return registered(embedder, collection_name)

View File

@@ -0,0 +1,55 @@
"""Pluggable default storage backend for the unified memory system.
By default, :class:`~crewai.memory.unified_memory.Memory` builds a built-in
vector store from its ``storage`` spec string (LanceDB, or Qdrant for the
``"qdrant-edge"`` spec). Registering a factory via
:func:`set_memory_storage_factory` lets an application route memory through a
custom :class:`~crewai.memory.storage.backend.StorageBackend` -- a different
vector store, a remote service, an in-memory fake for tests -- without
subclassing ``Memory`` or threading an explicit ``storage=`` instance through
every construction site.
This mirrors :func:`crewai_core.lock_store.set_lock_backend`: a one-time,
process-wide setter intended for application startup. Pass ``None`` to restore
the built-in default.
"""
from __future__ import annotations
from collections.abc import Callable
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from crewai.memory.storage.backend import StorageBackend
# Receives the raw ``storage`` spec string and returns a backend to use, or
# ``None`` to defer to the built-in selection for that spec.
MemoryStorageFactory = Callable[[str], "StorageBackend | None"]
_factory: MemoryStorageFactory | None = None
def set_memory_storage_factory(factory: MemoryStorageFactory | None) -> None:
    """Install or clear the process-wide default memory storage factory.

    Meant to be called once while the application starts up.

    Args:
        factory: Callable that receives the raw ``storage`` spec string and
            returns a backend. Pass ``None`` to drop any registered factory
            and restore the built-in LanceDB/Qdrant selection.

    Note:
        Only ``Memory`` instances constructed afterwards are affected; an
        explicit ``storage=`` instance always wins. The factory is consulted
        for every string ``storage`` spec, so it must return ``None`` for
        specs it does not handle; that lets the built-in LanceDB/Qdrant/path
        selection take over.
    """
    global _factory
    _factory = factory
def resolve_memory_storage(spec: str) -> StorageBackend | None:
    """Ask the registered factory, if any, for a backend matching ``spec``.

    Args:
        spec: Raw ``storage`` spec string, forwarded to the factory unchanged.

    Returns:
        Whatever the registered factory returns for ``spec``, or ``None``
        when no factory is registered. A ``None`` result tells the caller to
        fall back to the built-in selection.
    """
    # Read the module global once so the None check and the call use the
    # same object.
    registered = _factory
    if registered is None:
        return None
    return registered(spec)

View File

@@ -204,7 +204,12 @@ class Memory(BaseModel):
)
if isinstance(self.storage, str):
if self.storage == "qdrant-edge":
from crewai.memory.storage.factory import resolve_memory_storage
custom = resolve_memory_storage(self.storage)
if custom is not None:
self._storage = custom
elif self.storage == "qdrant-edge":
from crewai.memory.storage.qdrant_edge_storage import QdrantEdgeStorage
self._storage = QdrantEdgeStorage()

View File

@@ -1,5 +1,6 @@
"""Factory functions for creating RAG clients from configuration."""
from collections.abc import Callable
from typing import cast
from crewai.rag.config.optional_imports.protocols import (
@@ -11,6 +12,32 @@ from crewai.rag.core.base_client import BaseClient
from crewai.utilities.import_utils import require
# RAG uses a provider-keyed registry (rather than the single-default setter
# used by the memory/knowledge/flow seams) because ``create_client`` already
# dispatches on ``config.provider`` -- the natural seam here is per-provider.
# A factory receives the RAG config and returns a client; one registered for a
# built-in provider name overrides the built-in for that provider.
RagClientFactory = Callable[[RagConfigType], BaseClient]
_factories: dict[str, RagClientFactory] = {}
def register_rag_client_factory(provider: str, factory: RagClientFactory) -> None:
    """Associate ``factory`` with a RAG ``provider`` name in the registry.

    Use this to plug in a client for a new provider, or to override a
    built-in one (``"chromadb"`` / ``"qdrant"``), without editing
    :func:`create_client`. A registered factory takes precedence over the
    built-ins, and registering the same provider again replaces the earlier
    entry. Intended for one-time setup at startup.

    Args:
        provider: Provider name matched against ``config.provider``.
        factory: Callable that receives the RAG config and returns a client.
    """
    _factories[provider] = factory
def unregister_rag_client_factory(provider: str) -> None:
    """Drop the factory registered for ``provider``, if there is one.

    Calling this for a provider that was never registered does nothing.

    Args:
        provider: Provider name whose registered factory should be removed.
    """
    # pop() with a default avoids a KeyError for unknown providers.
    _factories.pop(provider, None)
def create_client(config: RagConfigType) -> BaseClient:
"""Create a client from configuration using the appropriate factory.
@@ -24,6 +51,10 @@ def create_client(config: RagConfigType) -> BaseClient:
ValueError: If the configuration provider is not supported.
"""
factory = _factories.get(config.provider)
if factory is not None:
return factory(config)
if config.provider == "chromadb":
chromadb_mod = cast(
ChromaFactoryModule,

View File

@@ -23,6 +23,26 @@ def _duplicate_separator_pattern(separator: str) -> re.Pattern[str]:
return re.compile(f"(?:{re.escape(separator)}){{2,}}")
def extract_template_variables(input_string: str | None) -> list[str]:
"""Return the template variable names referenced in a string.
Only recognizes placeholders that interpolation can actually fill, i.e.
``{name}`` where ``name`` starts with a letter/underscore and contains only
letters, numbers, underscores, and hyphens. Expressions such as
``{x if x else "y"}`` or JSON snippets are intentionally ignored so they are
never treated as required inputs.
Args:
input_string: The string to scan. May be ``None`` or empty.
Returns:
The matched variable names, in order of appearance (with duplicates).
"""
if not input_string:
return []
return _VARIABLE_PATTERN.findall(input_string)
def sanitize_tool_name(name: str, max_length: int = _MAX_TOOL_NAME_LENGTH) -> str:
"""Sanitize tool name for LLM provider compatibility.

View File

@@ -0,0 +1,130 @@
"""Tests for the pluggable knowledge storage factory seam.
We verify our own logic: the set/get round-trip, that a registered factory is
consulted when no explicit ``storage=`` is given (and receives the embedder and
collection name), and that an explicit ``storage=`` instance bypasses it.
"""
from __future__ import annotations
from typing import Any
import pytest
import crewai.knowledge.storage.factory as factory
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
from crewai.rag.types import SearchResult
class _FakeKnowledgeStorage(BaseKnowledgeStorage):
    """No-op storage double covering the abstract storage interface.

    Searches always come back empty; the save and reset methods, sync and
    async alike, accept their arguments and do nothing.
    """

    def search(
        self,
        query: list[str],
        limit: int = 5,
        metadata_filter: dict[str, Any] | None = None,
        score_threshold: float = 0.6,
    ) -> list[SearchResult]:
        """Return no results regardless of the query."""
        return []

    async def asearch(
        self,
        query: list[str],
        limit: int = 5,
        metadata_filter: dict[str, Any] | None = None,
        score_threshold: float = 0.6,
    ) -> list[SearchResult]:
        """Return no results regardless of the query (async variant)."""
        return []

    def save(self, documents: list[str]) -> None:
        """Accept the documents and discard them."""

    async def asave(self, documents: list[str]) -> None:
        """Accept the documents and discard them (async variant)."""

    def reset(self) -> None:
        """Do nothing; there is no state to clear."""

    async def areset(self) -> None:
        """Do nothing; there is no state to clear (async variant)."""
@pytest.fixture(autouse=True)
def reset_factory():
    """Run each test with no knowledge factory, then restore the prior one.

    Whatever factory was registered before the test is put back afterwards,
    so preexisting process-wide state is not clobbered.
    """
    previous = factory._factory
    factory.set_knowledge_storage_factory(None)
    yield
    factory.set_knowledge_storage_factory(previous)
def test_resolve_reflects_registered_factory():
    """Resolution yields ``None`` first, then the registered factory's backend."""
    backend = _FakeKnowledgeStorage()

    # No factory is registered yet.
    unresolved = factory.resolve_knowledge_storage(None, "docs")
    assert unresolved is None

    factory.set_knowledge_storage_factory(lambda embedder, name: backend)
    resolved = factory.resolve_knowledge_storage(None, "docs")
    assert resolved is backend
def test_factory_used_when_no_explicit_storage():
    """``Knowledge`` built without ``storage=`` adopts the factory's backend."""
    backend = _FakeKnowledgeStorage()

    def build(embedder, name):
        return backend

    factory.set_knowledge_storage_factory(build)

    assert Knowledge(collection_name="docs", sources=[]).storage is backend
def test_factory_receives_embedder_and_collection_name():
    """The factory is called once with the embedder and collection name."""
    calls: list[tuple[object, object]] = []

    def record(embedder, collection_name):
        calls.append((embedder, collection_name))
        return _FakeKnowledgeStorage()

    factory.set_knowledge_storage_factory(record)
    Knowledge(collection_name="docs", sources=[])

    # No embedder was supplied, so the factory sees None for it.
    assert calls == [(None, "docs")]
def test_explicit_storage_bypasses_factory():
    """An explicit ``storage=`` instance is used and the factory is never called."""
    invocations: list[tuple[object, object]] = []

    def make(embedder, name):
        invocations.append((embedder, name))
        return _FakeKnowledgeStorage()

    factory.set_knowledge_storage_factory(make)
    explicit = _FakeKnowledgeStorage()

    knowledge = Knowledge(collection_name="docs", sources=[], storage=explicit)

    assert knowledge.storage is explicit
    assert invocations == []
def test_falsy_explicit_storage_is_honored():
    """A backend whose ``bool()`` is ``False`` is still kept and operated on.

    A custom backend may be falsy (it defines ``__bool__``/``__len__``). It
    must not be mistaken for "not initialized" by a truthiness check in
    ``__init__``, ``reset``, or the source save path.
    """
    reset_log: list[bool] = []

    class _FalsyStorage(_FakeKnowledgeStorage):
        def __bool__(self) -> bool:
            return False

        def reset(self) -> None:
            reset_log.append(True)

    backend = _FalsyStorage()
    knowledge = Knowledge(collection_name="docs", sources=[], storage=backend)
    assert knowledge.storage is backend

    # reset must reach the backend rather than raise "Storage is not initialized."
    knowledge.reset()
    assert reset_log == [True]

View File

@@ -0,0 +1,72 @@
"""Tests for the pluggable memory storage factory seam.
We verify our own logic: the set/get round-trip, that a registered factory is
consulted for string ``storage`` specs (and receives the spec), and that an
explicit ``storage=`` instance bypasses the factory entirely.
"""
from __future__ import annotations
import pytest
import crewai.memory.storage.factory as factory
from crewai.memory.unified_memory import Memory
@pytest.fixture(autouse=True)
def reset_factory():
    """Run each test with no memory factory, then restore the prior one.

    Whatever factory was registered before the test is put back afterwards,
    so preexisting process-wide state is not clobbered.
    """
    previous = factory._factory
    factory.set_memory_storage_factory(None)
    yield
    factory.set_memory_storage_factory(previous)
def test_resolve_reflects_registered_factory():
    """Resolution tracks the set / clear round-trip of the factory."""
    backend = object()

    # Nothing registered: resolution declines.
    assert factory.resolve_memory_storage("lancedb") is None

    # Registered: the factory's return value comes back as-is.
    factory.set_memory_storage_factory(lambda spec: backend)
    assert factory.resolve_memory_storage("lancedb") is backend

    # Cleared again: back to declining.
    factory.set_memory_storage_factory(None)
    assert factory.resolve_memory_storage("lancedb") is None
def test_factory_backend_used_for_string_spec():
    """``Memory`` built from a string spec adopts the factory's backend."""
    backend = object()

    def build(spec):
        return backend

    factory.set_memory_storage_factory(build)

    assert Memory(storage="lancedb")._storage is backend
def test_factory_receives_the_raw_spec():
    """The factory is called once with the unmodified ``storage`` string."""
    received: list[str] = []

    def record(spec):
        received.append(spec)
        return object()

    factory.set_memory_storage_factory(record)
    Memory(storage="some/custom/path")

    assert received == ["some/custom/path"]
def test_explicit_storage_instance_bypasses_factory():
    """A non-string ``storage=`` instance is used and the factory is never called."""
    invocations: list[object] = []

    def make(spec):
        invocations.append(spec)
        return object()

    factory.set_memory_storage_factory(make)
    explicit = object()

    mem = Memory(storage=explicit)  # type: ignore[arg-type]

    assert mem._storage is explicit
    assert invocations == []

View File

@@ -0,0 +1,66 @@
"""Tests for the RAG client factory registry seam.
We verify our own logic: a registered factory is used for its provider,
factories override the built-in providers, unregister removes them, and an
unknown provider still raises.
"""
from __future__ import annotations
from types import SimpleNamespace
import pytest
import crewai.rag.factory as factory
@pytest.fixture(autouse=True)
def reset_registry():
    """Run each test against an empty registry, then restore prior entries.

    The registry dict is emptied and refilled in place, never rebound, so the
    module keeps using the same object.
    """
    saved = factory._factories.copy()
    factory._factories.clear()
    yield
    factory._factories.clear()
    factory._factories.update(saved)
def test_registered_factory_is_used_for_its_provider():
    """``create_client`` returns the registered factory's client for its provider."""
    client = object()
    factory.register_rag_client_factory("custom", lambda config: client)

    created = factory.create_client(SimpleNamespace(provider="custom"))

    assert created is client
def test_factory_receives_the_config():
    """The registered factory is called once with the config passed in."""
    received: list[object] = []
    config = SimpleNamespace(provider="custom")

    def build(cfg):
        received.append(cfg)
        return object()

    factory.register_rag_client_factory("custom", build)
    factory.create_client(config)

    assert received == [config]
def test_factory_overrides_builtin_provider():
    """A factory registered under a built-in provider name takes precedence."""
    client = object()
    factory.register_rag_client_factory("chromadb", lambda config: client)

    # Resolves via the registry without importing the built-in chromadb factory.
    created = factory.create_client(SimpleNamespace(provider="chromadb"))

    assert created is client
def test_unregister_removes_factory():
    """After unregistering, the provider is reported as unsupported again."""
    factory.register_rag_client_factory("custom", lambda config: object())
    factory.unregister_rag_client_factory("custom")

    config = SimpleNamespace(provider="custom")
    with pytest.raises(ValueError, match="Unsupported provider: custom"):
        factory.create_client(config)
def test_unregister_unknown_provider_is_noop():
    """Unregistering a provider that was never registered must not raise."""
    # The test passes as long as this call returns without an exception.
    factory.unregister_rag_client_factory("never-registered")
def test_unknown_provider_still_raises():
    """With nothing registered, an unknown provider raises ``ValueError``."""
    config = SimpleNamespace(provider="nope")
    with pytest.raises(ValueError, match="Unsupported provider: nope"):
        factory.create_client(config)

View File

@@ -3895,6 +3895,29 @@ def test_fetch_inputs():
)
def test_fetch_inputs_ignores_non_identifier_placeholders():
    """``fetch_inputs`` reports plain ``{name}`` placeholders and nothing else."""
    writer = Agent(
        role="Report writer",
        goal="Write a report for {company_name}.",
        backstory="Expert reporter.",
    )
    description = (
        'Greet {company_name if company_name else "Individual Client"} '
        "and summarize {search_period}."
    )
    summary_task = Task(
        description=description,
        expected_output="A summary for {company_name}.",
        agent=writer,
    )
    crew = Crew(agents=[writer], tasks=[summary_task])

    required_inputs = crew.fetch_inputs()

    # Only the simple {company_name} and {search_period} placeholders are
    # returned; the inline conditional expression (which interpolation cannot
    # fill) is ignored.
    assert required_inputs == {"company_name", "search_period"}
@pytest.mark.vcr()
def test_task_tools_preserve_code_execution_tools():
"""

View File

@@ -161,6 +161,27 @@ def test_flow_with_or_condition():
)
def test_flow_executes_and_condition_with_single_branch_or():
    """An ``and_`` that nests a single-branch ``or_`` still lets the flow finish.

    The flow must run through to ``event_d``, whose return value is what
    ``kickoff()`` is asserted to return.
    """

    class NestedConditionFlow(Flow):
        @start()
        def event_a(self):
            return "a"

        @listen(event_a)
        def event_b(self):
            return "b"

        # Router whose return value is the "event_c" name used in or_() below.
        @router(event_b)
        def emit_event_c(self):
            return "event_c"

        # Condition under test: two methods plus a one-element or_() branch.
        @listen(and_(event_a, event_b, or_("event_c")))
        def event_d(self):
            return "done"

    assert NestedConditionFlow().kickoff() == "done"
def test_or_listener_fires_once_across_parallel_starts():
"""Parallel ``@start`` paths feeding ``or_`` must not double-fire the listener."""
fire_count = 0
@@ -303,6 +324,90 @@ def test_start_runtime_uses_flow_definition_without_legacy_start_metadata():
assert execution_order == ["begin", "route", "branch", "done"]
def test_listen_runtime_uses_flow_definition_without_legacy_listener_metadata():
    """``@listen`` methods run even though legacy trigger attributes are absent.

    Checks that the decorated methods carry none of ``__trigger_methods__``,
    ``__condition_type__`` or ``__trigger_condition__``, and that all three
    listeners still execute after ``begin`` during ``kickoff()``.
    """
    execution_order = []

    class DefinitionListenFlow(Flow):
        @start()
        def begin(self):
            execution_order.append("begin")

        @listen(begin)
        def by_callable(self):
            execution_order.append("by_callable")

        @listen(and_(begin, by_callable))
        def by_and(self):
            execution_order.append("by_and")

        @listen(or_(and_(begin, by_callable), "fallback"))
        def nested(self):
            execution_order.append("nested")

    # The decorators must not stamp legacy metadata on the raw class attributes.
    for method_name in ("by_callable", "by_and", "nested"):
        method = DefinitionListenFlow.__dict__[method_name]
        assert not hasattr(method, "__trigger_methods__")
        assert not hasattr(method, "__condition_type__")
        assert not hasattr(method, "__trigger_condition__")

    DefinitionListenFlow().kickoff()

    # Only the first entry is pinned; the listeners' relative order is not asserted.
    assert execution_order[0] == "begin"
    assert {"by_callable", "by_and", "nested"}.issubset(execution_order)
def test_router_runtime_uses_flow_definition_without_legacy_router_metadata():
    """``@router`` routes correctly even though legacy router attributes are absent.

    Checks that the decorated router carries none of the legacy dunder
    attributes, and that ``kickoff()`` runs ``begin``, ``decide`` and the
    ``"go_left"`` listener in that order.
    """
    execution_order = []

    class DefinitionRouterFlow(Flow):
        @start()
        def begin(self):
            execution_order.append("begin")
            return "begin"

        @router(begin, emit=["go_left"])
        def decide(self):
            execution_order.append("decide")
            return "go_left"

        @listen("go_left")
        def handle_left(self):
            execution_order.append("handle_left")

    # The decorator must not stamp legacy metadata on the raw class attribute.
    route = DefinitionRouterFlow.__dict__["decide"]
    assert not hasattr(route, "__is_router__")
    assert not hasattr(route, "__router_emit__")
    assert not hasattr(route, "__trigger_methods__")
    assert not hasattr(route, "__condition_type__")
    assert not hasattr(route, "__trigger_condition__")

    DefinitionRouterFlow().kickoff()

    assert execution_order == ["begin", "decide", "handle_left"]
def test_router_falsy_result_emits_runtime_event():
    """A router returning the falsy value ``0`` still triggers a listener.

    The listener is registered under the string ``"0"`` and is expected to
    run after the router.
    """
    execution_order = []

    class FalsyRouterResultFlow(Flow):
        @start()
        def begin(self):
            execution_order.append("begin")

        @router(begin)
        def decide(self):
            execution_order.append("decide")
            # Falsy on purpose: the result must not be dropped by a truthiness check.
            return 0

        @listen("0")
        def handle_zero(self):
            execution_order.append("handle_zero")

    FalsyRouterResultFlow().kickoff()

    assert execution_order == ["begin", "decide", "handle_zero"]
def test_async_flow():
"""Test an asynchronous flow."""
execution_order = []
@@ -1436,6 +1541,43 @@ def test_deeply_nested_conditions():
assert and_ab_satisfied or and_cd_satisfied
def test_or_branch_does_not_leave_stale_and_state():
    """or_() over nested and_() branches must not leave stale pending AND state.

    Regression: evaluating an or_() condition stopped at the first branch that was
    satisfied, so a later and_() branch that the *same* trigger would have completed
    never cleared its pending state. On the next cycle that trigger alone then
    spuriously re-satisfied the whole condition. Both branches share the final
    event ``x`` here, so the shared trigger that completes branch ``(a AND x)`` also
    completes branch ``(c AND x)`` and both must be cleared together.
    """

    class StaleStateFlow(Flow):
        @start()
        def begin(self):
            pass

        @listen(or_(and_("a", "x"), and_("c", "x")))
        def joined(self):
            pass

    # The flow is never kicked off; the condition is evaluated directly below.
    flow = StaleStateFlow()
    condition = type(flow)._listen_condition("joined")

    def fires(trigger):
        # Feed one trigger into the "joined" listener's condition.
        return flow._evaluate_condition(condition, trigger, "joined")

    # First cycle: "a" then "c" arrive, then the shared "x" completes (a AND x).
    assert fires("a") is False
    assert fires("c") is False
    assert fires("x") is True

    # Next cycle: "x" alone must NOT re-satisfy the condition. The "c" from the
    # previous cycle was consumed when "joined" fired, so neither branch is half
    # complete and "x" by itself is insufficient.
    assert fires("x") is False
def test_mixed_sync_async_execution_order():
"""Test that execution order is preserved with mixed sync/async methods."""
execution_order = []

View File

@@ -344,6 +344,7 @@ class TestConversationalFlow:
"end",
}
@conversational_graph_broken
def test_router_infers_custom_routes_without_internal_routes(self) -> None:
class ResearchRoute(BaseModel):
intent: Literal["research", "converse", "end"]
@@ -739,6 +740,7 @@ class TestConversationalFlow:
assert flow.state.messages[-1].content == "fresh research"
assert flow._is_execution_resuming is False
@conversational_graph_broken
def test_route_catalog_combines_docstrings_builtins_and_overrides(self) -> None:
"""Catalog precedence: route_descriptions > built-in > docstring."""
@@ -770,6 +772,7 @@ class TestConversationalFlow:
assert "Ordinary chat" in catalog["converse"]
assert "finished" in catalog["end"]
@conversational_graph_broken
def test_route_catalog_falls_back_to_empty_when_no_docstring(self) -> None:
@ConversationConfig(router=RouterConfig(routes=["BARE"]))
class BareFlow(ConversationalFlow):

View File

@@ -1,6 +1,5 @@
"""Tests for the static Flow Definition contract."""
import ast
from enum import Enum
import importlib
import inspect
@@ -15,7 +14,6 @@ import crewai.flow.dsl as flow_dsl
import crewai.flow.flow_definition as flow_definition
import crewai.flow.visualization.builder as visualization_builder
from crewai.flow import Flow, and_, human_feedback, listen, or_, persist, router, start
from crewai.flow.dsl._conditions import is_flow_condition_dict
def test_flow_public_exports_are_explicit():
@@ -50,79 +48,64 @@ def test_flow_public_exports_are_explicit():
assert "calculate_node_levels" not in flow_visualization.__all__
def test_flow_condition_dict_accepts_non_string_sequences():
condition = {
"type": "OR",
"conditions": (
"approved",
{"type": "AND", "methods": ("validated", "processed")},
),
def test_condition_combinators_return_nested_runtime_tree():
condition = and_("event_a", "event_b", or_("event_c"))
assert condition == {
"type": "AND",
"conditions": [
"event_a",
"event_b",
{"type": "OR", "conditions": ["event_c"]},
],
}
assert is_flow_condition_dict(condition)
assert not is_flow_condition_dict({"type": "OR", "conditions": "approved"})
assert not is_flow_condition_dict({"type": "OR", "methods": b"approved"})
def test_flow_definition_lowers_nested_conditions():
class NestedFlow(Flow):
@start()
def begin(self):
return "begin"
@listen(begin)
def validated(self):
return "validated"
@listen(begin)
def processed(self):
return "processed"
@listen(or_(and_(validated, processed), begin))
def finalize(self):
return "done"
finalize = NestedFlow.flow_definition().methods["finalize"]
assert finalize.listen == {"or": [{"and": ["validated", "processed"]}, "begin"]}
def test_private_flow_helpers_do_not_have_docstrings():
import crewai.flow.flow_wrappers as flow_wrappers
import crewai.flow.human_feedback as human_feedback
import crewai.flow.persistence.decorators as persistence_decorators
import crewai.flow.visualization.types as visualization_types
def test_flow_definition_preserves_single_branch_nested_conditions():
class AmbiguousFlow(Flow):
@start()
def event_a(self):
return "a"
modules = [
flow_dsl,
flow_definition,
flow_wrappers,
human_feedback,
persistence_decorators,
visualization_builder,
visualization_types,
]
violations: list[str] = []
@listen(event_a)
def event_b(self):
return "b"
for module in modules:
source_path = Path(inspect.getsourcefile(module) or "")
tree = ast.parse(source_path.read_text())
stack: list[ast.AST] = []
if getattr(module, "__all__", None) == [] and ast.get_docstring(tree):
violations.append(f"{source_path}:1:<module>")
@listen(and_(event_a, event_b, or_("event_c")))
def event_d(self):
return "d"
class PrivateDocstringVisitor(ast.NodeVisitor):
def visit_ClassDef(self, node: ast.ClassDef) -> None:
self._check_docstring(node)
stack.append(node)
self.generic_visit(node)
stack.pop()
event_d = AmbiguousFlow.flow_definition().methods["event_d"]
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
self._check_docstring(node)
stack.append(node)
self.generic_visit(node)
stack.pop()
assert event_d.listen == {"and": ["event_a", "event_b", {"or": ["event_c"]}]}
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
self._check_docstring(node)
stack.append(node)
self.generic_visit(node)
stack.pop()
def _check_docstring(
self,
node: ast.ClassDef | ast.FunctionDef | ast.AsyncFunctionDef,
) -> None:
is_dunder = node.name.startswith("__") and node.name.endswith("__")
is_private_name = node.name.startswith("_") and not is_dunder
is_nested_function = any(
isinstance(parent, (ast.FunctionDef, ast.AsyncFunctionDef))
for parent in stack
)
if (is_private_name or is_nested_function) and ast.get_docstring(node):
violations.append(f"{source_path}:{node.lineno}:{node.name}")
PrivateDocstringVisitor().visit(tree)
assert violations == []
def test_flow_definition_rejects_invalid_condition():
with pytest.raises(ValueError, match="Invalid condition"):
start(123)(lambda self: None)
def test_flow_definition_contract_is_dsl_agnostic():
@@ -304,81 +287,11 @@ def test_flow_definition_fragments_cover_start_listen_and_condition_sugar():
assert not hasattr(FragmentFlow.__dict__["begin"], "__is_start_method__")
assert not hasattr(FragmentFlow.__dict__["restart"], "__trigger_methods__")
assert "restart" not in FragmentFlow._listeners
assert FragmentFlow._listeners["by_callable"] == ("OR", ["begin"])
assert FragmentFlow._listeners["by_string"] == ("OR", ["manual_event"])
assert FragmentFlow._listeners["by_and"] == {
"type": "AND",
"conditions": ["begin", "by_callable"],
}
assert FragmentFlow._listeners["nested"] == {
"type": "OR",
"conditions": [
{"type": "AND", "conditions": ["manual_event", "by_string"]},
"fallback_event",
],
}
def test_extract_flow_definition_prefers_fragments_over_legacy_metadata():
class RegistryFlow(Flow):
@start()
def begin(self):
return "begin"
@listen(begin)
def handle(self):
return "handle"
@router(handle, emit=["done"])
def decide(self):
return "done"
handle = RegistryFlow.__dict__["handle"]
original_trigger_methods = handle.__trigger_methods__
handle.__trigger_methods__ = ["wrong"]
try:
_, listeners, routers, router_emit = flow_dsl.extract_flow_definition(
{
"begin": RegistryFlow.__dict__["begin"],
"handle": handle,
"decide": RegistryFlow.__dict__["decide"],
}
)
finally:
handle.__trigger_methods__ = original_trigger_methods
assert listeners["handle"] == ("OR", ["begin"])
assert listeners["decide"] == ("OR", ["handle"])
assert routers == {"decide"}
assert router_emit == {"decide": ["done"]}
def test_flow_definition_falls_back_to_legacy_listener_router_metadata_without_fragment():
class LegacyMetadataFlow(Flow):
@start()
def begin(self):
return "begin"
@router(begin, emit=["left"])
def decide(self):
return "left"
@listen("left")
def left(self):
return "left"
for method_name in ("decide", "left"):
method = LegacyMetadataFlow.__dict__[method_name]
delattr(method, "__flow_method_definition__")
definition = flow_dsl.build_flow_definition(LegacyMetadataFlow)
assert definition.methods["begin"].start is True
assert definition.methods["decide"].listen == "begin"
assert definition.methods["decide"].router is True
assert definition.methods["decide"].emit == ["left"]
assert definition.methods["left"].listen == "left"
for method_name in ("by_callable", "by_string", "by_and", "nested"):
method = FragmentFlow.__dict__[method_name]
assert not hasattr(method, "__trigger_methods__")
assert not hasattr(method, "__condition_type__")
assert not hasattr(method, "__trigger_condition__")
def test_human_feedback_emit_overrides_inner_router_emit():
@@ -400,9 +313,6 @@ def test_human_feedback_emit_overrides_inner_router_emit():
def proceed(self):
return "ok"
assert "route" in FeedbackOverRouterFlow._routers
assert FeedbackOverRouterFlow._router_emit["route"] == ["approved", "rejected"]
route = FeedbackOverRouterFlow.flow_definition().methods["route"]
assert route.router is True
assert route.human_feedback is not None

View File

@@ -0,0 +1,68 @@
"""Tests for the pluggable flow persistence factory seam.
We verify our own logic: that ``default_flow_persistence`` returns the
registered factory's result, and that it falls back to the built-in SQLite
persistence when no factory is registered.
"""
from __future__ import annotations
from typing import Any
import pytest
from pydantic import BaseModel
import crewai.flow.persistence.factory as factory
from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.persistence.decorators import persist
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
@pytest.fixture(autouse=True)
def reset_factory():
    """Run each test with no persistence factory, then restore the prior one.

    Whatever factory was registered before the test is put back afterwards,
    so preexisting process-wide state is not clobbered.
    """
    previous = factory._factory
    factory.set_flow_persistence_factory(None)
    yield
    factory.set_flow_persistence_factory(previous)
def test_default_uses_registered_factory():
    """``default_flow_persistence`` returns exactly what the factory returns."""
    persistence = SQLiteFlowPersistence()

    def build():
        return persistence

    factory.set_flow_persistence_factory(build)

    assert factory.default_flow_persistence() is persistence
def test_default_falls_back_to_sqlite():
    """With no factory registered, the default is the SQLite persistence."""
    resolved = factory.default_flow_persistence()
    assert isinstance(resolved, SQLiteFlowPersistence)
def test_persist_decorator_honors_falsy_persistence():
    """``@persist`` keeps an explicit persistence even when it is falsy.

    A truthiness check would swap the falsy backend for the default; the
    decorator must store the exact instance it was given.
    """

    class _FalsyPersistence(FlowPersistence):
        def __bool__(self) -> bool:
            return False

        def init_db(self) -> None:
            return None

        def save_state(
            self,
            flow_uuid: str,
            method_name: str,
            state_data: dict[str, Any] | BaseModel,
        ) -> None:
            return None

        def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
            return None

    backend = _FalsyPersistence()

    @persist(persistence=backend)
    class _DummyFlow:
        pass

    stored = _DummyFlow.__flow_persistence_config__.persistence
    assert stored is backend

View File

@@ -78,8 +78,9 @@ class TestHumanFeedbackValidation:
return "output"
assert hasattr(test_method, "__human_feedback_config__")
assert test_method.__is_router__ is True
assert test_method.__router_emit__ == ["approve", "reject"]
assert test_method.__human_feedback_config__.emit == ["approve", "reject"]
assert not hasattr(test_method, "__is_router__")
assert not hasattr(test_method, "__router_emit__")
def test_valid_configuration_without_routing(self):
"""Test that valid configuration without routing doesn't raise."""
@@ -89,7 +90,7 @@ class TestHumanFeedbackValidation:
return "output"
assert hasattr(test_method, "__human_feedback_config__")
assert not hasattr(test_method, "__is_router__") or not test_method.__is_router__
assert not hasattr(test_method, "__is_router__")
def test_persist_preserves_human_feedback_llm_attribute(self):
"""Test @persist preserves the live LLM stashed by @human_feedback."""
@@ -177,8 +178,8 @@ class TestDecoratorAttributePreservation:
assert fragment is not None
assert fragment.start is True
def test_preserves_listen_method_attributes(self):
"""Test that @human_feedback preserves @listen decorator attributes."""
def test_preserves_listen_method_definition(self):
"""Test that @human_feedback preserves the @listen method definition."""
class TestFlow(Flow):
@start()
@@ -191,12 +192,14 @@ class TestDecoratorAttributePreservation:
return "review output"
flow = TestFlow()
assert "review" in flow._listeners or any(
"review" in str(v) for v in flow._listeners.values()
)
method = flow._methods.get("review")
assert method is not None
fragment = getattr(method, "__flow_method_definition__", None)
assert fragment is not None
assert fragment.listen == "begin"
def test_sets_router_attributes_when_emit_specified(self):
"""Test that router attributes are set when emit is specified."""
def test_emit_is_stored_on_human_feedback_config(self):
"""Test that emit outcomes are stored on human feedback config."""
@human_feedback(
message="Review:",
@@ -206,8 +209,12 @@ class TestDecoratorAttributePreservation:
def review_method(self):
return "output"
assert review_method.__is_router__ is True
assert review_method.__router_emit__ == ["approved", "rejected"]
assert review_method.__human_feedback_config__.emit == [
"approved",
"rejected",
]
assert not hasattr(review_method, "__is_router__")
assert not hasattr(review_method, "__router_emit__")
class TestAsyncSupport:

View File

@@ -1,7 +1,45 @@
from typing import Any, Dict, List, Union
import pytest
from crewai.utilities.string_utils import interpolate_only
from crewai.utilities.string_utils import (
extract_template_variables,
interpolate_only,
)
class TestExtractTemplateVariables:
    """Behavioural checks for ``string_utils.extract_template_variables``."""

    def test_extracts_simple_identifiers(self):
        found = extract_template_variables("Hi {name}, see {topic}.")
        assert found == ["name", "topic"]

    def test_allows_underscores_and_hyphens(self):
        found = extract_template_variables("{user_name} {role-detail}")
        assert found == ["user_name", "role-detail"]

    def test_ignores_inline_expressions(self):
        expression = '{company_name if company_name else "Individual Client"}'
        assert extract_template_variables(expression) == []

    def test_ignores_json_like_braces(self):
        json_snippet = '{"key": "value"}'
        assert extract_template_variables(json_snippet) == []

    def test_matches_what_interpolation_fills(self):
        template = 'Use {topic} and {x if x else "y"}.'

        assert extract_template_variables(template) == ["topic"]

        # interpolation fills exactly the extracted variable and leaves the rest
        rendered = interpolate_only(template, {"topic": "AI"})
        assert rendered == 'Use AI and {x if x else "y"}.'

    @pytest.mark.parametrize("value", [None, ""])
    def test_handles_empty_input(self, value):
        assert extract_template_variables(value) == []
class TestInterpolateOnly: