mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-06-09 18:28:10 +00:00
Compare commits
3 Commits
1.14.7a3
...
fix/interp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2f5928e4bb | ||
|
|
703ffe67ee | ||
|
|
8919026326 |
@@ -7,7 +7,6 @@ from copy import copy as shallow_copy
|
||||
from hashlib import md5
|
||||
import json
|
||||
from pathlib import Path
|
||||
import re
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Annotated,
|
||||
@@ -142,7 +141,10 @@ from crewai.utilities.streaming import (
|
||||
signal_end,
|
||||
signal_error,
|
||||
)
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
from crewai.utilities.string_utils import (
|
||||
extract_template_variables,
|
||||
sanitize_tool_name,
|
||||
)
|
||||
from crewai.utilities.task_output_storage_handler import TaskOutputStorageHandler
|
||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
|
||||
@@ -1960,20 +1962,24 @@ class Crew(FlowTrackable, BaseModel):
|
||||
Scans each task's 'description' + 'expected_output', and each agent's
|
||||
'role', 'goal', and 'backstory'.
|
||||
|
||||
Only placeholders that interpolation can actually fill are returned;
|
||||
non-identifier expressions such as ``{x if x else "y"}`` are ignored so
|
||||
they are not surfaced as required inputs (matching interpolation
|
||||
behavior, see :func:`extract_template_variables`).
|
||||
|
||||
Returns a set of all discovered placeholder names.
|
||||
"""
|
||||
placeholder_pattern = re.compile(r"\{(.+?)}")
|
||||
required_inputs: set[str] = set()
|
||||
|
||||
for task in self.tasks:
|
||||
# description and expected_output might contain e.g. {topic}, {user_name}
|
||||
text = f"{task.description or ''} {task.expected_output or ''}"
|
||||
required_inputs.update(placeholder_pattern.findall(text))
|
||||
required_inputs.update(extract_template_variables(text))
|
||||
|
||||
for agent in self.agents:
|
||||
# role, goal, backstory might have placeholders like {role_detail}, etc.
|
||||
text = f"{agent.role or ''} {agent.goal or ''} {agent.backstory or ''}"
|
||||
required_inputs.update(placeholder_pattern.findall(text))
|
||||
required_inputs.update(extract_template_variables(text))
|
||||
|
||||
return required_inputs
|
||||
|
||||
|
||||
@@ -15,10 +15,7 @@ from crewai.flow.dsl._human_feedback import (
|
||||
from crewai.flow.dsl._listen import listen
|
||||
from crewai.flow.dsl._router import router
|
||||
from crewai.flow.dsl._start import start
|
||||
from crewai.flow.dsl._utils import (
|
||||
build_flow_definition as build_flow_definition,
|
||||
extract_flow_definition as extract_flow_definition,
|
||||
)
|
||||
from crewai.flow.dsl._utils import build_flow_definition as build_flow_definition
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
||||
@@ -1,12 +1,4 @@
|
||||
"""Flow DSL condition primitives.
|
||||
|
||||
Type guards, the public ``or_`` / ``and_`` combinators, and the conversions
|
||||
between runtime conditions, normalized conditions, and the
|
||||
``FlowDefinitionCondition`` shape stored on a :class:`FlowDefinition`. These are
|
||||
the lower layer of the DSL: the decorators and the definition builder
|
||||
(``_utils``) build on top of them, so this module imports nothing from its
|
||||
siblings.
|
||||
"""
|
||||
"""Flow DSL condition primitives."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -20,268 +12,75 @@ from crewai.flow.dsl._types import FlowTrigger
|
||||
from crewai.flow.flow_definition import FlowDefinitionCondition
|
||||
from crewai.flow.flow_wrappers import (
|
||||
FlowCondition,
|
||||
FlowConditions,
|
||||
SimpleFlowCondition,
|
||||
FlowConditionType,
|
||||
)
|
||||
from crewai.flow.types import FlowMethodName
|
||||
|
||||
|
||||
def _is_non_string_sequence(value: Any) -> bool:
|
||||
return isinstance(value, Sequence) and not isinstance(value, (str, bytes))
|
||||
|
||||
|
||||
def is_simple_flow_condition(obj: Any) -> TypeIs[SimpleFlowCondition]:
|
||||
"""Check if the object is a ``(condition_type, methods)`` tuple."""
|
||||
return (
|
||||
isinstance(obj, tuple)
|
||||
and len(obj) == 2
|
||||
and isinstance(obj[0], str)
|
||||
and isinstance(obj[1], list)
|
||||
)
|
||||
|
||||
|
||||
def is_flow_condition_dict(obj: Any) -> TypeIs[FlowCondition]:
|
||||
"""Check if the object matches the FlowCondition structure."""
|
||||
if not isinstance(obj, dict):
|
||||
return False
|
||||
|
||||
type_value = obj.get("type")
|
||||
if type_value not in ("AND", "OR"):
|
||||
return False
|
||||
|
||||
if "conditions" in obj:
|
||||
conditions = obj["conditions"]
|
||||
if not _is_non_string_sequence(conditions):
|
||||
return False
|
||||
for cond in conditions:
|
||||
if not (
|
||||
isinstance(cond, str)
|
||||
or (isinstance(cond, dict) and is_flow_condition_dict(cond))
|
||||
):
|
||||
return False
|
||||
|
||||
if "methods" in obj:
|
||||
methods = obj["methods"]
|
||||
if not (
|
||||
_is_non_string_sequence(methods)
|
||||
and all(isinstance(m, str) for m in methods)
|
||||
):
|
||||
return False
|
||||
|
||||
allowed_keys = {"type", "conditions", "methods"}
|
||||
if not set(obj).issubset(allowed_keys):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _method_reference_name(value: Any) -> FlowMethodName | None:
|
||||
name = getattr(value, "__name__", None)
|
||||
if callable(value) and isinstance(name, str):
|
||||
return FlowMethodName(name)
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_condition(
|
||||
condition: FlowConditions | FlowCondition | str,
|
||||
) -> FlowCondition:
|
||||
if isinstance(condition, str):
|
||||
return {"type": OR_CONDITION, "conditions": [FlowMethodName(condition)]}
|
||||
if is_flow_condition_dict(condition):
|
||||
if "conditions" in condition:
|
||||
return condition
|
||||
if "methods" in condition:
|
||||
normalized_methods: list[str | FlowMethodName | FlowCondition] = list(
|
||||
condition["methods"]
|
||||
)
|
||||
return {"type": condition["type"], "conditions": normalized_methods}
|
||||
return condition
|
||||
if _is_non_string_sequence(condition) and all(
|
||||
isinstance(item, str) or is_flow_condition_dict(item) for item in condition
|
||||
):
|
||||
return {"type": OR_CONDITION, "conditions": condition}
|
||||
|
||||
raise ValueError(f"Cannot normalize condition: {condition}")
|
||||
|
||||
|
||||
def _extract_all_methods_recursive(
|
||||
condition: str | FlowCondition | dict[str, Any] | list[Any],
|
||||
flow: Any | None = None,
|
||||
) -> list[FlowMethodName]:
|
||||
if isinstance(condition, str):
|
||||
if flow is not None:
|
||||
if condition in flow._methods:
|
||||
return [FlowMethodName(condition)]
|
||||
return []
|
||||
return [FlowMethodName(condition)]
|
||||
if is_flow_condition_dict(condition):
|
||||
normalized = _normalize_condition(condition)
|
||||
methods = []
|
||||
for sub_cond in normalized.get("conditions", []):
|
||||
methods.extend(_extract_all_methods_recursive(sub_cond, flow))
|
||||
return methods
|
||||
if isinstance(condition, list):
|
||||
methods = []
|
||||
for item in condition:
|
||||
methods.extend(_extract_all_methods_recursive(item, flow))
|
||||
return methods
|
||||
return []
|
||||
|
||||
|
||||
def _extract_all_methods(
|
||||
condition: str | FlowCondition | dict[str, Any] | list[Any],
|
||||
) -> list[FlowMethodName]:
|
||||
if isinstance(condition, str):
|
||||
return [FlowMethodName(condition)]
|
||||
if is_flow_condition_dict(condition):
|
||||
normalized = _normalize_condition(condition)
|
||||
cond_type = normalized.get("type", OR_CONDITION)
|
||||
|
||||
if cond_type == AND_CONDITION:
|
||||
return [
|
||||
FlowMethodName(sub_cond)
|
||||
for sub_cond in normalized.get("conditions", [])
|
||||
if isinstance(sub_cond, str)
|
||||
]
|
||||
return []
|
||||
if isinstance(condition, list):
|
||||
methods = []
|
||||
for item in condition:
|
||||
methods.extend(_extract_all_methods(item))
|
||||
return methods
|
||||
return []
|
||||
|
||||
|
||||
def _condition_trigger(condition: FlowTrigger) -> FlowMethodName | FlowCondition:
|
||||
if isinstance(condition, str):
|
||||
return FlowMethodName(condition)
|
||||
if is_flow_condition_dict(condition):
|
||||
return condition
|
||||
method_name = _method_reference_name(condition)
|
||||
if method_name is not None:
|
||||
return method_name
|
||||
raise ValueError("Invalid condition")
|
||||
|
||||
|
||||
def _condition_triggers(
|
||||
conditions: Sequence[FlowTrigger],
|
||||
error_message: str,
|
||||
) -> FlowConditions:
|
||||
try:
|
||||
return [_condition_trigger(condition) for condition in conditions]
|
||||
except ValueError as exc:
|
||||
raise ValueError(error_message) from exc
|
||||
|
||||
|
||||
def _definition_condition_from_runtime(condition: Any) -> FlowDefinitionCondition:
|
||||
if isinstance(condition, str):
|
||||
return str(condition)
|
||||
method_name = _method_reference_name(condition)
|
||||
if method_name is not None:
|
||||
return str(method_name)
|
||||
if is_flow_condition_dict(condition):
|
||||
normalized = _normalize_condition(condition)
|
||||
key = "and" if normalized.get("type") == AND_CONDITION else "or"
|
||||
return {
|
||||
key: [
|
||||
_definition_condition_from_runtime(sub_condition)
|
||||
for sub_condition in normalized.get("conditions", [])
|
||||
]
|
||||
}
|
||||
if isinstance(condition, list):
|
||||
return {"or": [_definition_condition_from_runtime(item) for item in condition]}
|
||||
return str(condition)
|
||||
_CONDITION_TYPES = (AND_CONDITION, OR_CONDITION)
|
||||
|
||||
|
||||
def or_(*triggers: FlowTrigger) -> FlowCondition:
|
||||
"""Combine multiple triggers with OR logic for flow control.
|
||||
|
||||
Creates a condition that is satisfied when any of the specified triggers
|
||||
are met. This is used with @start, @listen, or @router decorators to create
|
||||
complex triggering conditions.
|
||||
|
||||
Args:
|
||||
triggers: Route labels, method references, or existing conditions
|
||||
returned by or_() / and_().
|
||||
|
||||
Returns:
|
||||
A condition dictionary with format {"type": "OR", "conditions": list_of_triggers}.
|
||||
|
||||
Raises:
|
||||
ValueError: If a trigger format is invalid.
|
||||
|
||||
Examples:
|
||||
>>> @listen(or_("success", "timeout"))
|
||||
>>> def handle_completion(self):
|
||||
... pass
|
||||
|
||||
>>> @listen(or_(and_("step1", "step2"), "step3"))
|
||||
>>> def handle_nested(self):
|
||||
... pass
|
||||
"""
|
||||
processed_triggers = _condition_triggers(triggers, "Invalid trigger in or_()")
|
||||
return {"type": OR_CONDITION, "conditions": processed_triggers}
|
||||
"""Return a condition that fires when any trigger fires."""
|
||||
return _condition_tree(OR_CONDITION, triggers)
|
||||
|
||||
|
||||
def and_(*triggers: FlowTrigger) -> FlowCondition:
|
||||
"""Combine multiple triggers with AND logic for flow control.
|
||||
|
||||
Creates a condition that is satisfied only when all specified triggers
|
||||
are met. This is used with @start, @listen, or @router decorators to create
|
||||
complex triggering conditions.
|
||||
|
||||
Args:
|
||||
triggers: Route labels, method references, or existing conditions
|
||||
returned by or_() / and_().
|
||||
|
||||
Returns:
|
||||
A condition dictionary with format {"type": "AND", "conditions": list_of_conditions}
|
||||
where each condition can be a route label, method name, or nested condition.
|
||||
|
||||
Raises:
|
||||
ValueError: If any trigger is invalid.
|
||||
|
||||
Examples:
|
||||
>>> @listen(and_("validated", "processed"))
|
||||
>>> def handle_complete_data(self):
|
||||
... pass
|
||||
|
||||
>>> @listen(and_(or_("step1", "step2"), "step3"))
|
||||
>>> def handle_nested(self):
|
||||
... pass
|
||||
"""
|
||||
processed_triggers = _condition_triggers(triggers, "Invalid trigger in and_()")
|
||||
return {"type": AND_CONDITION, "conditions": processed_triggers}
|
||||
"""Return a condition that fires after all triggers fire."""
|
||||
return _condition_tree(AND_CONDITION, triggers)
|
||||
|
||||
|
||||
def _runtime_condition_from_definition(
|
||||
condition: FlowDefinitionCondition,
|
||||
) -> FlowMethodName | FlowCondition:
|
||||
if isinstance(condition, str):
|
||||
return FlowMethodName(condition)
|
||||
if is_flow_condition_dict(condition):
|
||||
return condition
|
||||
def _trigger_name(value: Any) -> str | None:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
|
||||
if "and" in condition:
|
||||
return {
|
||||
"type": AND_CONDITION,
|
||||
"conditions": [
|
||||
_runtime_condition_from_definition(item)
|
||||
for item in condition.get("and", [])
|
||||
],
|
||||
}
|
||||
name = getattr(value, "__name__", None)
|
||||
if callable(value) and isinstance(name, str):
|
||||
return name
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _is_condition(value: Any) -> TypeIs[FlowCondition]:
|
||||
return (
|
||||
isinstance(value, dict)
|
||||
and set(value) == {"type", "conditions"}
|
||||
and value["type"] in _CONDITION_TYPES
|
||||
and isinstance(value["conditions"], list)
|
||||
and all(
|
||||
_trigger_name(condition) is not None or _is_condition(condition)
|
||||
for condition in value["conditions"]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _coerce_trigger(trigger: FlowTrigger) -> str | FlowCondition:
|
||||
name = _trigger_name(trigger)
|
||||
if name is not None:
|
||||
return name
|
||||
if _is_condition(trigger):
|
||||
return trigger
|
||||
raise ValueError("Invalid condition")
|
||||
|
||||
|
||||
def _condition_tree(
|
||||
condition_type: FlowConditionType,
|
||||
triggers: Sequence[FlowTrigger],
|
||||
) -> FlowCondition:
|
||||
return {
|
||||
"type": OR_CONDITION,
|
||||
"conditions": [
|
||||
_runtime_condition_from_definition(item) for item in condition.get("or", [])
|
||||
],
|
||||
"type": condition_type,
|
||||
"conditions": [_coerce_trigger(trigger) for trigger in triggers],
|
||||
}
|
||||
|
||||
|
||||
def _runtime_listener_condition_from_definition(
|
||||
condition: FlowDefinitionCondition,
|
||||
) -> SimpleFlowCondition | FlowCondition:
|
||||
runtime_condition = _runtime_condition_from_definition(condition)
|
||||
if isinstance(runtime_condition, str):
|
||||
return (OR_CONDITION, [FlowMethodName(str(runtime_condition))])
|
||||
return runtime_condition
|
||||
def _to_definition_condition(condition: FlowTrigger) -> FlowDefinitionCondition:
|
||||
trigger = _coerce_trigger(condition)
|
||||
if isinstance(trigger, str):
|
||||
return trigger
|
||||
|
||||
key = trigger["type"].lower()
|
||||
return {
|
||||
key: [
|
||||
_to_definition_condition(sub_condition)
|
||||
for sub_condition in trigger["conditions"]
|
||||
]
|
||||
}
|
||||
|
||||
@@ -27,13 +27,8 @@ def _stamp_human_feedback_metadata(
|
||||
config: HumanFeedbackConfig,
|
||||
) -> None:
|
||||
for attr in [
|
||||
"__trigger_methods__",
|
||||
"__condition_type__",
|
||||
"__trigger_condition__",
|
||||
"__is_flow_method__",
|
||||
"__flow_persistence_config__",
|
||||
"__is_router__",
|
||||
"__router_emit__",
|
||||
"__flow_method_definition__",
|
||||
]:
|
||||
if hasattr(func, attr):
|
||||
@@ -43,8 +38,6 @@ def _stamp_human_feedback_metadata(
|
||||
wrapper.__is_flow_method__ = True
|
||||
|
||||
if config.emit:
|
||||
wrapper.__is_router__ = True
|
||||
wrapper.__router_emit__ = list(config.emit)
|
||||
fragment = getattr(wrapper, "__flow_method_definition__", None)
|
||||
if isinstance(fragment, FlowMethodDefinition):
|
||||
wrapper.__flow_method_definition__ = fragment.model_copy(
|
||||
|
||||
@@ -3,13 +3,12 @@ from __future__ import annotations
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from crewai.flow.dsl._conditions import _definition_condition_from_runtime
|
||||
from crewai.flow.dsl._conditions import _to_definition_condition
|
||||
from crewai.flow.dsl._types import FlowMethodDecorator, FlowTrigger
|
||||
from crewai.flow.dsl._utils import (
|
||||
P,
|
||||
R,
|
||||
_set_flow_method_definition,
|
||||
_set_trigger_metadata,
|
||||
)
|
||||
from crewai.flow.flow_definition import FlowMethodDefinition
|
||||
from crewai.flow.flow_wrappers import ListenMethod
|
||||
@@ -46,10 +45,8 @@ def listen(condition: FlowTrigger) -> FlowMethodDecorator:
|
||||
wrapper = ListenMethod(func)
|
||||
|
||||
_set_flow_method_definition(
|
||||
wrapper,
|
||||
FlowMethodDefinition(listen=_definition_condition_from_runtime(condition)),
|
||||
wrapper, FlowMethodDefinition(listen=_to_definition_condition(condition))
|
||||
)
|
||||
_set_trigger_metadata(wrapper, condition)
|
||||
return wrapper
|
||||
|
||||
return cast(FlowMethodDecorator, decorator)
|
||||
|
||||
@@ -14,13 +14,12 @@ from typing import (
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
from crewai.flow.dsl._conditions import _definition_condition_from_runtime
|
||||
from crewai.flow.dsl._conditions import _to_definition_condition
|
||||
from crewai.flow.dsl._types import FlowMethodDecorator, FlowTrigger
|
||||
from crewai.flow.dsl._utils import (
|
||||
P,
|
||||
R,
|
||||
_set_flow_method_definition,
|
||||
_set_trigger_metadata,
|
||||
)
|
||||
from crewai.flow.flow_definition import FlowMethodDefinition
|
||||
from crewai.flow.flow_wrappers import RouterMethod
|
||||
@@ -149,18 +148,11 @@ def router(
|
||||
_set_flow_method_definition(
|
||||
wrapper,
|
||||
FlowMethodDefinition(
|
||||
listen=_definition_condition_from_runtime(condition),
|
||||
listen=_to_definition_condition(condition),
|
||||
router=True,
|
||||
emit=router_events or None,
|
||||
),
|
||||
)
|
||||
|
||||
_set_trigger_metadata(wrapper, condition)
|
||||
|
||||
if emit is not None:
|
||||
wrapper.__router_emit__ = router_events
|
||||
elif router_events:
|
||||
wrapper.__router_emit__ = router_events
|
||||
return wrapper
|
||||
|
||||
return cast(FlowMethodDecorator, decorator)
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from crewai.flow.dsl._conditions import _definition_condition_from_runtime
|
||||
from crewai.flow.dsl._conditions import _to_definition_condition
|
||||
from crewai.flow.dsl._types import FlowMethodDecorator, FlowTrigger
|
||||
from crewai.flow.dsl._utils import (
|
||||
P,
|
||||
@@ -56,9 +56,7 @@ def start(
|
||||
if condition is not None:
|
||||
_set_flow_method_definition(
|
||||
wrapper,
|
||||
FlowMethodDefinition(
|
||||
start=_definition_condition_from_runtime(condition)
|
||||
),
|
||||
FlowMethodDefinition(start=_to_definition_condition(condition)),
|
||||
)
|
||||
else:
|
||||
_set_flow_method_definition(wrapper, FlowMethodDefinition(start=True))
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, ParamSpec, TypeVar
|
||||
@@ -8,19 +7,9 @@ from typing import Any, ParamSpec, TypeVar
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from crewai.flow.constants import AND_CONDITION, OR_CONDITION
|
||||
from crewai.flow.dsl._conditions import (
|
||||
_definition_condition_from_runtime,
|
||||
_extract_all_methods,
|
||||
_method_reference_name,
|
||||
_runtime_listener_condition_from_definition,
|
||||
is_flow_condition_dict,
|
||||
)
|
||||
from crewai.flow.dsl._types import FlowTrigger
|
||||
from crewai.flow.flow_definition import (
|
||||
FlowConfigDefinition,
|
||||
FlowDefinition,
|
||||
FlowDefinitionCondition,
|
||||
FlowDefinitionDiagnostic,
|
||||
FlowHumanFeedbackDefinition,
|
||||
FlowMethodDefinition,
|
||||
@@ -29,10 +18,7 @@ from crewai.flow.flow_definition import (
|
||||
)
|
||||
from crewai.flow.flow_wrappers import (
|
||||
FlowMethod,
|
||||
ListenMethod,
|
||||
RouterMethod,
|
||||
)
|
||||
from crewai.flow.types import FlowMethodName
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
@@ -45,11 +31,8 @@ _FLOW_METHOD_DEFINITION_ATTR = "__flow_method_definition__"
|
||||
|
||||
def is_flow_method(obj: Any) -> TypeIs[FlowMethod[Any, Any]]:
|
||||
"""Check if the object carries Flow method wrapper metadata."""
|
||||
return (
|
||||
hasattr(obj, "__is_flow_method__")
|
||||
or hasattr(obj, "__trigger_methods__")
|
||||
or hasattr(obj, "__is_router__")
|
||||
or hasattr(obj, _FLOW_METHOD_DEFINITION_ATTR)
|
||||
return hasattr(obj, "__is_flow_method__") or hasattr(
|
||||
obj, _FLOW_METHOD_DEFINITION_ATTR
|
||||
)
|
||||
|
||||
|
||||
@@ -59,42 +42,6 @@ def _should_include_flow_method(flow_class: type, method: Any) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def _flow_method_names(values: Sequence[Any]) -> list[FlowMethodName]:
|
||||
return [FlowMethodName(str(value)) for value in values]
|
||||
|
||||
|
||||
def _set_trigger_metadata(
|
||||
wrapper: ListenMethod[P, R] | RouterMethod[P, R],
|
||||
condition: FlowTrigger,
|
||||
) -> None:
|
||||
if isinstance(condition, str):
|
||||
wrapper.__trigger_methods__ = [FlowMethodName(condition)]
|
||||
wrapper.__condition_type__ = OR_CONDITION
|
||||
return
|
||||
|
||||
if is_flow_condition_dict(condition):
|
||||
if "conditions" in condition:
|
||||
wrapper.__trigger_condition__ = condition
|
||||
wrapper.__trigger_methods__ = _extract_all_methods(condition)
|
||||
wrapper.__condition_type__ = condition["type"]
|
||||
return
|
||||
if "methods" in condition:
|
||||
wrapper.__trigger_methods__ = _flow_method_names(condition["methods"])
|
||||
wrapper.__condition_type__ = condition["type"]
|
||||
return
|
||||
raise ValueError("Condition dict must contain 'conditions' or 'methods'")
|
||||
|
||||
method_name = _method_reference_name(condition)
|
||||
if method_name is not None:
|
||||
wrapper.__trigger_methods__ = [method_name]
|
||||
wrapper.__condition_type__ = OR_CONDITION
|
||||
return
|
||||
|
||||
raise ValueError(
|
||||
"Condition must be a method, string, or a result of or_() or and_()"
|
||||
)
|
||||
|
||||
|
||||
def _set_flow_method_definition(
|
||||
wrapper: FlowMethod[P, R],
|
||||
definition: FlowMethodDefinition,
|
||||
@@ -236,48 +183,6 @@ def _build_config_definition(
|
||||
return FlowConfigDefinition(**values)
|
||||
|
||||
|
||||
def _condition_from_method_metadata(method: Any) -> FlowDefinitionCondition | None:
|
||||
trigger_condition = getattr(method, "__trigger_condition__", None)
|
||||
if trigger_condition is not None:
|
||||
return _definition_condition_from_runtime(trigger_condition)
|
||||
|
||||
trigger_methods = getattr(method, "__trigger_methods__", None)
|
||||
if trigger_methods is None:
|
||||
return None
|
||||
condition_type = getattr(method, "__condition_type__", OR_CONDITION)
|
||||
method_names = [str(method_name) for method_name in trigger_methods]
|
||||
if condition_type == AND_CONDITION:
|
||||
return {"and": method_names}
|
||||
if len(method_names) == 1:
|
||||
return method_names[0]
|
||||
return {"or": method_names}
|
||||
|
||||
|
||||
def _flow_method_definition_from_legacy_metadata(method: Any) -> FlowMethodDefinition:
|
||||
is_router = bool(getattr(method, "__is_router__", False))
|
||||
condition = _condition_from_method_metadata(method)
|
||||
|
||||
definition = FlowMethodDefinition(
|
||||
listen=condition,
|
||||
router=is_router,
|
||||
)
|
||||
|
||||
router_emit = getattr(method, "__router_emit__", None)
|
||||
if router_emit:
|
||||
definition.emit = [str(value) for value in router_emit]
|
||||
return definition
|
||||
|
||||
|
||||
def _definition_trigger_condition(
|
||||
method_definition: FlowMethodDefinition,
|
||||
) -> FlowDefinitionCondition | None:
|
||||
if method_definition.listen is not None:
|
||||
return method_definition.listen
|
||||
if isinstance(method_definition.start, (str, dict)):
|
||||
return method_definition.start
|
||||
return None
|
||||
|
||||
|
||||
def _build_human_feedback_definition(
|
||||
method: Any,
|
||||
diagnostics: list[FlowDefinitionDiagnostic],
|
||||
@@ -332,13 +237,10 @@ def _build_method_definition(
|
||||
) -> FlowMethodDefinition:
|
||||
fragment = _get_flow_method_definition(method)
|
||||
if fragment is None:
|
||||
method_definition = _flow_method_definition_from_legacy_metadata(method)
|
||||
method_definition = FlowMethodDefinition()
|
||||
else:
|
||||
method_definition = fragment.model_copy(deep=True)
|
||||
|
||||
if bool(getattr(method, "__is_router__", False)):
|
||||
method_definition.router = True
|
||||
|
||||
human_feedback = _build_human_feedback_definition(
|
||||
method, diagnostics, f"{path}.human_feedback"
|
||||
)
|
||||
@@ -352,11 +254,6 @@ def _build_method_definition(
|
||||
method, diagnostics, f"{path}.persist"
|
||||
)
|
||||
|
||||
router_emit = getattr(method, "__router_emit__", None)
|
||||
if router_emit and not (human_feedback and human_feedback.emit):
|
||||
if not method_definition.emit:
|
||||
method_definition.emit = [str(value) for value in router_emit]
|
||||
|
||||
return method_definition
|
||||
|
||||
|
||||
@@ -431,68 +328,3 @@ def build_flow_definition(
|
||||
) -> FlowDefinition:
|
||||
"""Build a FlowDefinition from a Python Flow class."""
|
||||
return _build_flow_definition_from_class(flow_class, namespace)
|
||||
|
||||
|
||||
def extract_flow_definition(
|
||||
namespace: dict[str, Any],
|
||||
) -> tuple[list[str], dict[str, Any], set[str], dict[str, Any]]:
|
||||
"""Extract the structural flow registries from a Python class namespace."""
|
||||
start_methods: list[str] = []
|
||||
listeners: dict[str, Any] = {}
|
||||
router_emit: dict[str, Any] = {}
|
||||
routers: set[str] = set()
|
||||
|
||||
for attr_name, attr_value in namespace.items():
|
||||
if is_flow_method(attr_value):
|
||||
method_definition = _get_flow_method_definition(attr_value)
|
||||
if method_definition is not None:
|
||||
condition = _definition_trigger_condition(method_definition)
|
||||
if condition is not None and not method_definition.is_start:
|
||||
listeners[attr_name] = _runtime_listener_condition_from_definition(
|
||||
condition
|
||||
)
|
||||
|
||||
is_router = method_definition.router or bool(
|
||||
getattr(attr_value, "__is_router__", False)
|
||||
)
|
||||
if is_router:
|
||||
routers.add(attr_name)
|
||||
if method_definition.emit:
|
||||
router_emit[attr_name] = [
|
||||
str(value) for value in method_definition.emit
|
||||
]
|
||||
elif (
|
||||
hasattr(attr_value, "__router_emit__")
|
||||
and attr_value.__router_emit__
|
||||
):
|
||||
router_emit[attr_name] = attr_value.__router_emit__
|
||||
else:
|
||||
router_emit[attr_name] = []
|
||||
continue
|
||||
|
||||
if (
|
||||
hasattr(attr_value, "__trigger_methods__")
|
||||
and attr_value.__trigger_methods__ is not None
|
||||
):
|
||||
methods = attr_value.__trigger_methods__
|
||||
condition_type = getattr(attr_value, "__condition_type__", OR_CONDITION)
|
||||
|
||||
if (
|
||||
hasattr(attr_value, "__trigger_condition__")
|
||||
and attr_value.__trigger_condition__ is not None
|
||||
):
|
||||
listeners[attr_name] = attr_value.__trigger_condition__
|
||||
else:
|
||||
listeners[attr_name] = (condition_type, methods)
|
||||
|
||||
if hasattr(attr_value, "__is_router__") and attr_value.__is_router__:
|
||||
routers.add(attr_name)
|
||||
if (
|
||||
hasattr(attr_value, "__router_emit__")
|
||||
and attr_value.__router_emit__
|
||||
):
|
||||
router_emit[attr_name] = attr_value.__router_emit__
|
||||
else:
|
||||
router_emit[attr_name] = []
|
||||
|
||||
return start_methods, listeners, routers, router_emit
|
||||
|
||||
@@ -16,7 +16,6 @@ P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
FlowConditionType: TypeAlias = Literal["OR", "AND"]
|
||||
SimpleFlowCondition: TypeAlias = tuple[FlowConditionType, list[FlowMethodName]]
|
||||
|
||||
__all__ = [
|
||||
"FlowCondition",
|
||||
@@ -25,7 +24,6 @@ __all__ = [
|
||||
"FlowMethod",
|
||||
"ListenMethod",
|
||||
"RouterMethod",
|
||||
"SimpleFlowCondition",
|
||||
"StartMethod",
|
||||
]
|
||||
|
||||
@@ -38,15 +36,13 @@ class FlowCondition(TypedDict, total=False):
|
||||
Attributes:
|
||||
type: The type of the condition.
|
||||
conditions: A sequence of route labels, method names, or nested conditions.
|
||||
methods: A legacy sequence of route labels or method names.
|
||||
"""
|
||||
|
||||
type: Required[FlowConditionType]
|
||||
conditions: Sequence[str | FlowMethodName | FlowCondition]
|
||||
methods: Sequence[str | FlowMethodName]
|
||||
conditions: Sequence[str | FlowCondition]
|
||||
|
||||
|
||||
FlowConditions: TypeAlias = Sequence[str | FlowMethodName | FlowCondition]
|
||||
FlowConditions: TypeAlias = Sequence[str | FlowCondition]
|
||||
|
||||
|
||||
class FlowMethod(Generic[P, R]):
|
||||
@@ -83,8 +79,6 @@ class FlowMethod(Generic[P, R]):
|
||||
|
||||
# Preserve flow-related attributes from wrapped method (e.g., from @human_feedback)
|
||||
for attr in [
|
||||
"__is_router__",
|
||||
"__router_emit__",
|
||||
"__human_feedback_config__",
|
||||
"__conversational_only__", # gates registration on Flow.conversational
|
||||
"__flow_persistence_config__",
|
||||
@@ -162,16 +156,6 @@ class StartMethod(FlowMethod[P, R]):
|
||||
class ListenMethod(FlowMethod[P, R]):
|
||||
"""Wrapper for methods marked as flow listeners."""
|
||||
|
||||
__trigger_methods__: list[FlowMethodName] | None = None
|
||||
__condition_type__: FlowConditionType | None = None
|
||||
__trigger_condition__: FlowCondition | None = None
|
||||
|
||||
|
||||
class RouterMethod(FlowMethod[P, R]):
|
||||
"""Wrapper for methods marked as flow routers."""
|
||||
|
||||
__is_router__: bool = True
|
||||
__trigger_methods__: list[FlowMethodName] | None = None
|
||||
__condition_type__: FlowConditionType | None = None
|
||||
__trigger_condition__: FlowCondition | None = None
|
||||
__router_emit__: list[str] | None = None
|
||||
|
||||
@@ -187,16 +187,12 @@ class HumanFeedbackMethod(FlowMethod[Any, Any]):
|
||||
"""Wrapper for methods decorated with @human_feedback.
|
||||
|
||||
This wrapper extends FlowMethod to add human feedback specific attributes
|
||||
that are used by FlowMeta for routing and by visualization tools.
|
||||
used by the FlowDefinition builder and runtime feedback handling.
|
||||
|
||||
Attributes:
|
||||
__is_router__: True when emit is specified, enabling router behavior.
|
||||
__router_emit__: List of possible outcomes when acting as a router.
|
||||
__human_feedback_config__: The HumanFeedbackConfig for this method.
|
||||
"""
|
||||
|
||||
__is_router__: bool = False
|
||||
__router_emit__: list[str] | None = None
|
||||
__human_feedback_config__: HumanFeedbackConfig | None = None
|
||||
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ from crewai_core.printer import PRINTER
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.flow.persistence.base import FlowPersistence
|
||||
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
|
||||
from crewai.flow.persistence.factory import default_flow_persistence
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -67,11 +67,6 @@ def _stamp_persistence_metadata(
|
||||
|
||||
|
||||
_PRESERVED_FLOW_ATTRS: Final[tuple[str, ...]] = (
|
||||
"__trigger_methods__",
|
||||
"__condition_type__",
|
||||
"__trigger_condition__",
|
||||
"__is_router__",
|
||||
"__router_emit__",
|
||||
"__human_feedback_config__",
|
||||
"__flow_persistence_config__",
|
||||
"__flow_method_definition__",
|
||||
@@ -171,7 +166,9 @@ def persist(
|
||||
|
||||
Args:
|
||||
persistence: Optional FlowPersistence implementation to use.
|
||||
If not provided, uses SQLiteFlowPersistence.
|
||||
If not provided, uses ``default_flow_persistence()`` (the
|
||||
registered factory when present, else the built-in SQLite
|
||||
fallback).
|
||||
verbose: Whether to log persistence operations. Defaults to False.
|
||||
|
||||
Returns:
|
||||
@@ -190,7 +187,9 @@ def persist(
|
||||
"""
|
||||
|
||||
def decorator(target: type | Callable[..., T]) -> type | Callable[..., T]:
|
||||
actual_persistence = persistence or SQLiteFlowPersistence()
|
||||
actual_persistence = (
|
||||
persistence if persistence is not None else default_flow_persistence()
|
||||
)
|
||||
|
||||
if isinstance(target, type):
|
||||
_stamp_persistence_metadata(target, actual_persistence, verbose)
|
||||
@@ -210,10 +209,7 @@ def persist(
|
||||
for name, method in target.__dict__.items()
|
||||
if callable(method)
|
||||
and (
|
||||
hasattr(method, "__trigger_methods__")
|
||||
or hasattr(method, "__condition_type__")
|
||||
or hasattr(method, "__is_flow_method__")
|
||||
or hasattr(method, "__is_router__")
|
||||
hasattr(method, "__is_flow_method__")
|
||||
or hasattr(method, "__flow_method_definition__")
|
||||
)
|
||||
}
|
||||
|
||||
60
lib/crewai/src/crewai/flow/persistence/factory.py
Normal file
60
lib/crewai/src/crewai/flow/persistence/factory.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""Pluggable default persistence backend for flows.
|
||||
|
||||
By default, ``@persist`` and the flow runtime persist state with
|
||||
:class:`~crewai.flow.persistence.sqlite.SQLiteFlowPersistence` when no explicit
|
||||
``persistence=`` is given. Registering a factory via
|
||||
:func:`set_flow_persistence_factory` lets an application back flow state with a
|
||||
custom :class:`~crewai.flow.persistence.base.FlowPersistence` -- a database, a
|
||||
remote service, an in-memory fake for tests -- without passing a
|
||||
``persistence=`` instance at every ``@persist`` / kickoff site.
|
||||
|
||||
This mirrors :func:`crewai_core.lock_store.set_lock_backend`: a one-time,
|
||||
process-wide setter intended for application startup. Pass ``None`` to restore
|
||||
the built-in SQLite default. Call :func:`default_flow_persistence` to build the
|
||||
default backend (the registered factory if any, else SQLite).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.persistence.base import FlowPersistence
|
||||
|
||||
FlowPersistenceFactory = Callable[[], "FlowPersistence"]
|
||||
|
||||
_factory: FlowPersistenceFactory | None = None
|
||||
|
||||
|
||||
def set_flow_persistence_factory(factory: FlowPersistenceFactory | None) -> None:
|
||||
"""Replace the process-wide default flow persistence factory.
|
||||
|
||||
Intended for one-time setup at startup. Pass ``None`` to restore the
|
||||
built-in ``SQLiteFlowPersistence``. Only affects flows that fall back to
|
||||
the default; an explicit ``persistence=`` instance always wins.
|
||||
|
||||
The default is resolved at each fall-back site (``@persist`` and the
|
||||
runtime's pause/resume paths), so the factory may be called more than once
|
||||
for a single flow. Return instances backed by shared durable state (or a
|
||||
singleton) so state saved on one call is visible to the next -- the
|
||||
built-in SQLite default satisfies this by sharing one on-disk file.
|
||||
"""
|
||||
global _factory
|
||||
_factory = factory
|
||||
|
||||
|
||||
def default_flow_persistence() -> FlowPersistence:
|
||||
"""Build the default flow persistence backend.
|
||||
|
||||
Returns the result of the registered factory if one is set, otherwise a
|
||||
built-in :class:`~crewai.flow.persistence.sqlite.SQLiteFlowPersistence`.
|
||||
"""
|
||||
factory = _factory
|
||||
if factory is not None:
|
||||
return factory()
|
||||
|
||||
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
|
||||
|
||||
return SQLiteFlowPersistence()
|
||||
@@ -89,27 +89,17 @@ from crewai.experimental.conversational import (
|
||||
ConversationState,
|
||||
)
|
||||
from crewai.experimental.conversational_mixin import _ConversationalMixin
|
||||
from crewai.flow.constants import AND_CONDITION, OR_CONDITION
|
||||
from crewai.flow.dsl._conditions import (
|
||||
_extract_all_methods,
|
||||
_extract_all_methods_recursive,
|
||||
_normalize_condition,
|
||||
_runtime_listener_condition_from_definition,
|
||||
is_flow_condition_dict,
|
||||
is_simple_flow_condition,
|
||||
)
|
||||
from crewai.flow.dsl._utils import (
|
||||
build_flow_definition,
|
||||
extract_flow_definition,
|
||||
)
|
||||
from crewai.flow.dsl._utils import build_flow_definition
|
||||
from crewai.flow.flow_context import current_flow_id, current_flow_request_id
|
||||
from crewai.flow.flow_definition import FlowDefinition, FlowDefinitionCondition
|
||||
from crewai.flow.flow_definition import (
|
||||
FlowDefinition,
|
||||
FlowDefinitionCondition,
|
||||
FlowMethodDefinition,
|
||||
)
|
||||
from crewai.flow.flow_wrappers import (
|
||||
FlowCondition,
|
||||
FlowMethod,
|
||||
ListenMethod,
|
||||
RouterMethod,
|
||||
SimpleFlowCondition,
|
||||
StartMethod,
|
||||
)
|
||||
from crewai.flow.human_feedback import HumanFeedbackResult
|
||||
@@ -164,6 +154,25 @@ ExecutionContext = Any # type: ignore[assignment,misc]
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _iter_condition_events(condition: FlowDefinitionCondition) -> Iterator[str]:
|
||||
if isinstance(condition, str):
|
||||
yield condition
|
||||
return
|
||||
|
||||
sub_conditions = condition["and"] if "and" in condition else condition["or"]
|
||||
for sub_condition in sub_conditions:
|
||||
yield from _iter_condition_events(sub_condition)
|
||||
|
||||
|
||||
def _is_multi_event_or(
|
||||
condition: FlowDefinitionCondition,
|
||||
) -> bool:
|
||||
if isinstance(condition, str):
|
||||
return False
|
||||
|
||||
return "or" in condition and len(condition["or"]) > 1
|
||||
|
||||
|
||||
def _resolve_persistence(value: Any) -> Any:
|
||||
if value is None or isinstance(value, FlowPersistence):
|
||||
return value
|
||||
@@ -601,18 +610,10 @@ class FlowMeta(ModelMetaclass):
|
||||
annotations[attr_name] = ClassVar[type(attr_value)]
|
||||
namespace["__annotations__"] = annotations
|
||||
|
||||
cls = super().__new__(mcs, name, bases, namespace)
|
||||
|
||||
_, listeners, routers, router_emit = extract_flow_definition(namespace)
|
||||
|
||||
cls._listeners = listeners # type: ignore[attr-defined]
|
||||
cls._routers = routers # type: ignore[attr-defined]
|
||||
cls._router_emit = router_emit # type: ignore[attr-defined]
|
||||
# The static FlowDefinition is built lazily (on first access via
|
||||
# ``Flow.flow_definition()`` or visualization), not at class-definition
|
||||
# time, to avoid AST parsing and diagnostic logging on every import.
|
||||
|
||||
return cls
|
||||
return super().__new__(mcs, name, bases, namespace)
|
||||
|
||||
|
||||
class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
@@ -627,9 +628,6 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
)
|
||||
__hash__ = object.__hash__
|
||||
|
||||
_listeners: ClassVar[dict[FlowMethodName, SimpleFlowCondition | FlowCondition]] = {}
|
||||
_routers: ClassVar[set[FlowMethodName]] = set()
|
||||
_router_emit: ClassVar[dict[FlowMethodName, list[FlowMethodName]]] = {}
|
||||
_flow_definition: ClassVar[FlowDefinition | None] = None
|
||||
|
||||
# === EXPERIMENTAL: conversational mode ===
|
||||
@@ -677,7 +675,7 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
return flow_definition
|
||||
|
||||
@classmethod
|
||||
def _definition_start_method_names(cls) -> list[FlowMethodName]:
|
||||
def _start_method_names(cls) -> list[FlowMethodName]:
|
||||
return [
|
||||
FlowMethodName(method_name)
|
||||
for method_name, method_definition in cls.flow_definition().methods.items()
|
||||
@@ -685,21 +683,39 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _definition_start_condition(
|
||||
def _listener_methods(
|
||||
cls,
|
||||
) -> Iterator[tuple[FlowMethodName, FlowMethodDefinition, FlowDefinitionCondition]]:
|
||||
# (name, definition, condition) for every non-start method that listens.
|
||||
# Routers are included (they listen too); callers wanting only plain
|
||||
# listeners filter on definition.router.
|
||||
for method_name, method_definition in cls.flow_definition().methods.items():
|
||||
if method_definition.listen is not None and not method_definition.is_start:
|
||||
yield (
|
||||
FlowMethodName(method_name),
|
||||
method_definition,
|
||||
method_definition.listen,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _start_condition(
|
||||
cls, method_name: FlowMethodName
|
||||
) -> FlowDefinitionCondition | None:
|
||||
method_definition = cls.flow_definition().methods.get(str(method_name))
|
||||
if method_definition is None:
|
||||
return None
|
||||
method_definition = cls.flow_definition().methods[str(method_name)]
|
||||
start = method_definition.start
|
||||
if isinstance(start, (str, dict)):
|
||||
return start
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _definition_has_start(cls, method_name: FlowMethodName) -> bool:
|
||||
method_definition = cls.flow_definition().methods.get(str(method_name))
|
||||
return bool(method_definition and method_definition.is_start)
|
||||
def _listen_condition(
|
||||
cls, method_name: FlowMethodName
|
||||
) -> FlowDefinitionCondition | None:
|
||||
return cls.flow_definition().methods[str(method_name)].listen
|
||||
|
||||
@classmethod
|
||||
def _is_router(cls, method_name: FlowMethodName) -> bool:
|
||||
return cls.flow_definition().methods[str(method_name)].router
|
||||
|
||||
initial_state: Annotated[ # type: ignore[type-arg]
|
||||
type[BaseModel] | type[dict] | dict[str, Any] | BaseModel | None,
|
||||
@@ -848,10 +864,13 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
_method_execution_counts: dict[FlowMethodName, int] = PrivateAttr(
|
||||
default_factory=dict
|
||||
)
|
||||
_pending_and_listeners: dict[PendingListenerKey, set[FlowMethodName]] = PrivateAttr(
|
||||
_pending_and_listeners: dict[PendingListenerKey, set[int]] = PrivateAttr(
|
||||
default_factory=dict
|
||||
)
|
||||
_fired_or_listeners: set[FlowMethodName] = PrivateAttr(default_factory=set)
|
||||
_racing_groups_cache: dict[frozenset[FlowMethodName], FlowMethodName] | None = (
|
||||
PrivateAttr(default=None)
|
||||
)
|
||||
_method_outputs: list[Any] = PrivateAttr(default_factory=list)
|
||||
_state_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
|
||||
_or_listeners_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
|
||||
@@ -992,22 +1011,6 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
result: list[str] = self.memory.extract_memories(content)
|
||||
return result
|
||||
|
||||
def _mark_or_listener_fired(self, listener_name: FlowMethodName) -> bool:
|
||||
"""Mark an OR listener as fired atomically.
|
||||
|
||||
Args:
|
||||
listener_name: The name of the OR listener to mark.
|
||||
|
||||
Returns:
|
||||
True if this call was the first to fire the listener.
|
||||
False if the listener was already fired.
|
||||
"""
|
||||
with self._or_listeners_lock:
|
||||
if listener_name in self._fired_or_listeners:
|
||||
return False
|
||||
self._fired_or_listeners.add(listener_name)
|
||||
return True
|
||||
|
||||
def _clear_or_listeners(self) -> None:
|
||||
"""Clear fired OR listeners for cyclic flows."""
|
||||
with self._or_listeners_lock:
|
||||
@@ -1021,25 +1024,11 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
def _start_condition_triggered_by(
|
||||
self, method_name: FlowMethodName, trigger: FlowMethodName
|
||||
) -> bool:
|
||||
condition = type(self)._definition_start_condition(method_name)
|
||||
condition = type(self)._start_condition(method_name)
|
||||
if condition is None:
|
||||
return False
|
||||
condition_data = _runtime_listener_condition_from_definition(condition)
|
||||
if is_simple_flow_condition(condition_data):
|
||||
condition_type, methods = condition_data
|
||||
if condition_type == OR_CONDITION:
|
||||
return trigger in methods
|
||||
pending_key = PendingListenerKey(method_name)
|
||||
if pending_key not in self._pending_and_listeners:
|
||||
self._pending_and_listeners[pending_key] = set(methods)
|
||||
if trigger in self._pending_and_listeners[pending_key]:
|
||||
self._pending_and_listeners[pending_key].discard(trigger)
|
||||
if not self._pending_and_listeners[pending_key]:
|
||||
self._pending_and_listeners.pop(pending_key, None)
|
||||
return True
|
||||
return False
|
||||
return self._evaluate_condition(
|
||||
condition_data,
|
||||
condition,
|
||||
trigger,
|
||||
method_name,
|
||||
pending_key_prefix=f"start:{method_name}",
|
||||
@@ -1050,18 +1039,9 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
trigger: FlowMethodName,
|
||||
rearmable: set[FlowMethodName] | None = None,
|
||||
) -> None:
|
||||
"""Re-arm fired OR listeners whose condition includes ``trigger``.
|
||||
|
||||
Called when a router emits a fresh signal so cyclic flows can re-fire
|
||||
multi-source ``or_`` listeners. Listeners whose condition does not
|
||||
reference the trigger are left fired.
|
||||
|
||||
Args:
|
||||
trigger: The signal/method name a router just emitted.
|
||||
rearmable: Optional set restricting which listeners may be re-armed.
|
||||
When provided, listeners outside this set are skipped, and any
|
||||
listener re-armed is removed from it.
|
||||
"""
|
||||
# When a router emits a fresh signal, re-arm fired multi-event or_()
|
||||
# listeners that reference the trigger so cyclic flows can re-fire them.
|
||||
# A given rearmable set, when passed, bounds which listeners may re-arm.
|
||||
with self._or_listeners_lock:
|
||||
if not self._fired_or_listeners:
|
||||
return
|
||||
@@ -1075,87 +1055,60 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
trigger_str = str(trigger)
|
||||
to_discard: list[FlowMethodName] = []
|
||||
for listener_name in candidates:
|
||||
condition_data = self._listeners.get(listener_name)
|
||||
if condition_data is None:
|
||||
condition = type(self)._listen_condition(listener_name)
|
||||
if condition is None:
|
||||
continue
|
||||
if is_simple_flow_condition(condition_data):
|
||||
_, methods = condition_data
|
||||
if trigger in methods or trigger_str in {str(m) for m in methods}:
|
||||
to_discard.append(listener_name)
|
||||
elif is_flow_condition_dict(condition_data):
|
||||
all_methods = _extract_all_methods_recursive(condition_data)
|
||||
if trigger_str in {str(m) for m in all_methods}:
|
||||
to_discard.append(listener_name)
|
||||
if trigger_str in _iter_condition_events(condition):
|
||||
to_discard.append(listener_name)
|
||||
for listener_name in to_discard:
|
||||
self._fired_or_listeners.discard(listener_name)
|
||||
if rearmable is not None:
|
||||
rearmable.discard(listener_name)
|
||||
|
||||
def _build_racing_groups(self) -> dict[frozenset[FlowMethodName], FlowMethodName]:
|
||||
"""Identify groups of methods that race for the same OR listener.
|
||||
|
||||
Analyzes the flow graph to find listeners with OR conditions that have
|
||||
multiple trigger methods. These trigger methods form a "racing group"
|
||||
where only the first to complete should trigger the OR listener.
|
||||
|
||||
Only methods that are EXCLUSIVELY sources for the OR listener are included
|
||||
in the racing group. Methods that are also triggers for other listeners
|
||||
(e.g., AND conditions) are not cancelled when another racing source wins.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping frozensets of racing method names to their
|
||||
shared OR listener name.
|
||||
|
||||
Example:
|
||||
If we have `@listen(or_(method_a, method_b))` on `handler`,
|
||||
and method_a/method_b aren't used elsewhere,
|
||||
this returns: {frozenset({'method_a', 'method_b'}): 'handler'}
|
||||
"""
|
||||
# Events of a multi-event or_() listener race: only the first to fire
|
||||
# should trigger it. We map {frozenset(racing events): listener}.
|
||||
# Only events that EXCLUSIVELY feed one OR listener race; an event that
|
||||
# also feeds another listener (e.g. an AND) is left alone when a sibling
|
||||
# wins. e.g. @listen(or_(a, b)) on handler -> {frozenset({a, b}): handler}.
|
||||
racing_groups: dict[frozenset[FlowMethodName], FlowMethodName] = {}
|
||||
listener_conditions: dict[FlowMethodName, FlowDefinitionCondition] = {
|
||||
listener_name: condition
|
||||
for listener_name, method_definition, condition in type(
|
||||
self
|
||||
)._listener_methods()
|
||||
if not method_definition.router
|
||||
}
|
||||
|
||||
method_to_listeners: dict[FlowMethodName, set[FlowMethodName]] = {}
|
||||
for listener_name, condition_data in self._listeners.items():
|
||||
if is_simple_flow_condition(condition_data):
|
||||
_, methods = condition_data
|
||||
for m in methods:
|
||||
method_to_listeners.setdefault(m, set()).add(listener_name)
|
||||
elif is_flow_condition_dict(condition_data):
|
||||
all_methods = _extract_all_methods_recursive(condition_data)
|
||||
for m in all_methods:
|
||||
method_name = FlowMethodName(m) if isinstance(m, str) else m
|
||||
method_to_listeners.setdefault(method_name, set()).add(
|
||||
listener_name
|
||||
)
|
||||
events_by_listener: dict[FlowMethodName, set[str]] = {
|
||||
listener_name: set(_iter_condition_events(condition))
|
||||
for listener_name, condition in listener_conditions.items()
|
||||
}
|
||||
|
||||
for listener_name, condition_data in self._listeners.items():
|
||||
if listener_name in self._routers:
|
||||
listeners_by_event: dict[str, set[FlowMethodName]] = {}
|
||||
for listener_name, events in events_by_listener.items():
|
||||
for event in events:
|
||||
listeners_by_event.setdefault(event, set()).add(listener_name)
|
||||
|
||||
for listener_name, condition in listener_conditions.items():
|
||||
if not isinstance(condition, dict):
|
||||
continue
|
||||
events = events_by_listener[listener_name]
|
||||
if "or" not in condition or len(events) <= 1:
|
||||
continue
|
||||
|
||||
trigger_methods: set[FlowMethodName] = set()
|
||||
|
||||
if is_simple_flow_condition(condition_data):
|
||||
condition_type, methods = condition_data
|
||||
if condition_type == OR_CONDITION and len(methods) > 1:
|
||||
trigger_methods = set(methods)
|
||||
|
||||
elif is_flow_condition_dict(condition_data):
|
||||
top_level_type = condition_data.get("type", OR_CONDITION)
|
||||
if top_level_type == OR_CONDITION:
|
||||
all_methods = _extract_all_methods_recursive(condition_data)
|
||||
if len(all_methods) > 1:
|
||||
trigger_methods = set(
|
||||
FlowMethodName(m) if isinstance(m, str) else m
|
||||
for m in all_methods
|
||||
)
|
||||
|
||||
if trigger_methods:
|
||||
exclusive_methods = {
|
||||
m
|
||||
for m in trigger_methods
|
||||
if method_to_listeners.get(m, set()) == {listener_name}
|
||||
}
|
||||
if len(exclusive_methods) > 1:
|
||||
racing_groups[frozenset(exclusive_methods)] = listener_name
|
||||
exclusive_events = {
|
||||
event
|
||||
for event in events
|
||||
if listeners_by_event.get(event, set()) == {listener_name}
|
||||
}
|
||||
if len(exclusive_events) > 1:
|
||||
# Racing only applies to method-completion events: each member is
|
||||
# later executed as a method and intersected with the running
|
||||
# method names, so the leaves re-enter method space here.
|
||||
racing_groups[
|
||||
frozenset(FlowMethodName(event) for event in exclusive_events)
|
||||
] = listener_name
|
||||
|
||||
return racing_groups
|
||||
|
||||
@@ -1172,16 +1125,15 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
Tuple of (racing_members, or_listener_name) if these listeners race,
|
||||
None otherwise.
|
||||
"""
|
||||
if not hasattr(self, "_racing_groups_cache"):
|
||||
if self._racing_groups_cache is None:
|
||||
self._racing_groups_cache = self._build_racing_groups()
|
||||
|
||||
listener_set = set(listener_names)
|
||||
|
||||
for racing_members, or_listener in self._racing_groups_cache.items():
|
||||
if racing_members & listener_set:
|
||||
racing_subset = racing_members & listener_set
|
||||
if len(racing_subset) > 1:
|
||||
return (frozenset(racing_subset), or_listener)
|
||||
racing_subset = racing_members & listener_set
|
||||
if len(racing_subset) > 1:
|
||||
return (frozenset(racing_subset), or_listener)
|
||||
|
||||
return None
|
||||
|
||||
@@ -1252,7 +1204,9 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
Args:
|
||||
flow_id: The unique identifier of the paused flow (from state.id)
|
||||
persistence: The persistence backend where the state was saved.
|
||||
If not provided, defaults to SQLiteFlowPersistence().
|
||||
If not provided, uses ``default_flow_persistence()`` (the
|
||||
registered factory when present, else the built-in SQLite
|
||||
fallback).
|
||||
**kwargs: Additional keyword arguments passed to the Flow constructor
|
||||
|
||||
Returns:
|
||||
@@ -1274,9 +1228,9 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
```
|
||||
"""
|
||||
if persistence is None:
|
||||
from crewai.flow.persistence import SQLiteFlowPersistence
|
||||
from crewai.flow.persistence.factory import default_flow_persistence
|
||||
|
||||
persistence = SQLiteFlowPersistence()
|
||||
persistence = default_flow_persistence()
|
||||
|
||||
loaded = persistence.load_pending_feedback(flow_id)
|
||||
if loaded is None:
|
||||
@@ -1463,7 +1417,7 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
|
||||
self._pending_feedback_context = None
|
||||
|
||||
if self.persistence:
|
||||
if self.persistence is not None:
|
||||
self.persistence.clear_pending_feedback(context.flow_id)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
@@ -1505,9 +1459,9 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
self._pending_feedback_context = e.context
|
||||
|
||||
if self.persistence is None:
|
||||
from crewai.flow.persistence import SQLiteFlowPersistence
|
||||
from crewai.flow.persistence.factory import default_flow_persistence
|
||||
|
||||
self.persistence = SQLiteFlowPersistence()
|
||||
self.persistence = default_flow_persistence()
|
||||
|
||||
state_data = (
|
||||
self._state
|
||||
@@ -2221,11 +2175,11 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
# Determine which start methods to execute at kickoff
|
||||
# Conditional start methods are only triggered by their conditions
|
||||
# UNLESS there are no unconditional starts (then all starts run as entry points)
|
||||
start_methods = type(self)._definition_start_method_names()
|
||||
start_methods = type(self)._start_method_names()
|
||||
unconditional_starts = [
|
||||
start_method
|
||||
for start_method in start_methods
|
||||
if type(self)._definition_start_condition(start_method) is None
|
||||
if type(self)._start_condition(start_method) is None
|
||||
]
|
||||
# If there are unconditional starts, only run those at kickoff
|
||||
# If there are NO unconditional starts, run all starts (including conditional ones)
|
||||
@@ -2244,9 +2198,11 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
if isinstance(e, HumanFeedbackPending):
|
||||
# Auto-save pending feedback (create default persistence if needed)
|
||||
if self.persistence is None:
|
||||
from crewai.flow.persistence import SQLiteFlowPersistence
|
||||
from crewai.flow.persistence.factory import (
|
||||
default_flow_persistence,
|
||||
)
|
||||
|
||||
self.persistence = SQLiteFlowPersistence()
|
||||
self.persistence = default_flow_persistence()
|
||||
|
||||
state_data = (
|
||||
self._state
|
||||
@@ -2448,11 +2404,12 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
)
|
||||
|
||||
# If start method is a router, use its result as an additional trigger
|
||||
if start_method_name in self._routers and result is not None:
|
||||
if type(self)._is_router(start_method_name) and result is not None:
|
||||
# Execute listeners for the start method name first
|
||||
await self._execute_listeners(start_method_name, result, finished_event_id)
|
||||
# Then execute listeners for the router result (e.g., "approved")
|
||||
router_result_trigger = FlowMethodName(str(result))
|
||||
router_result = result.value if isinstance(result, enum.Enum) else result
|
||||
router_result_trigger = FlowMethodName(str(router_result))
|
||||
listener_result = (
|
||||
self.last_human_feedback
|
||||
if self.last_human_feedback is not None
|
||||
@@ -2597,9 +2554,9 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
e.context.method_name = method_name
|
||||
|
||||
if self.persistence is None:
|
||||
from crewai.flow.persistence import SQLiteFlowPersistence
|
||||
from crewai.flow.persistence.factory import default_flow_persistence
|
||||
|
||||
self.persistence = SQLiteFlowPersistence()
|
||||
self.persistence = default_flow_persistence()
|
||||
|
||||
# Emit paused event (not failed)
|
||||
if not self.suppress_flow_events:
|
||||
@@ -2693,27 +2650,24 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
) = await self._execute_single_listener(
|
||||
router_name, router_input, current_triggering_event_id
|
||||
)
|
||||
if router_result: # Only add non-None results
|
||||
router_result_str = (
|
||||
router_result.value
|
||||
if isinstance(router_result, enum.Enum)
|
||||
else str(router_result)
|
||||
)
|
||||
router_results.append(FlowMethodName(router_result_str))
|
||||
# If this was a human_feedback router, map the outcome to the feedback
|
||||
if self.last_human_feedback is not None:
|
||||
router_result_to_feedback[router_result_str] = (
|
||||
self.last_human_feedback
|
||||
)
|
||||
current_trigger = (
|
||||
FlowMethodName(
|
||||
router_result.value
|
||||
if isinstance(router_result, enum.Enum)
|
||||
else str(router_result)
|
||||
)
|
||||
if router_result is not None
|
||||
else FlowMethodName("")
|
||||
if router_result is None:
|
||||
current_trigger = FlowMethodName("")
|
||||
continue
|
||||
|
||||
router_result = (
|
||||
router_result.value
|
||||
if isinstance(router_result, enum.Enum)
|
||||
else router_result
|
||||
)
|
||||
router_result_str = str(router_result)
|
||||
router_result_event = FlowMethodName(router_result_str)
|
||||
router_results.append(router_result_event)
|
||||
|
||||
if self.last_human_feedback is not None:
|
||||
router_result_to_feedback[router_result_str] = (
|
||||
self.last_human_feedback
|
||||
)
|
||||
current_trigger = router_result_event
|
||||
|
||||
all_triggers = [trigger_method, *router_results]
|
||||
|
||||
@@ -2759,7 +2713,7 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
if current_trigger in router_results:
|
||||
for method_name in type(self)._definition_start_method_names():
|
||||
for method_name in type(self)._start_method_names():
|
||||
if self._start_condition_triggered_by(
|
||||
method_name, current_trigger
|
||||
):
|
||||
@@ -2774,165 +2728,86 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
|
||||
def _evaluate_condition(
|
||||
self,
|
||||
condition: str | FlowMethodName | FlowCondition,
|
||||
condition: FlowDefinitionCondition,
|
||||
trigger_method: FlowMethodName,
|
||||
listener_name: FlowMethodName,
|
||||
pending_key_prefix: str | None = None,
|
||||
) -> bool:
|
||||
"""Recursively evaluate a condition (simple or nested).
|
||||
|
||||
Args:
|
||||
condition: Can be a string (method name) or dict (nested condition)
|
||||
trigger_method: The method that just completed
|
||||
listener_name: Name of the listener being evaluated
|
||||
|
||||
Returns:
|
||||
True if the condition is satisfied, False otherwise
|
||||
"""
|
||||
if isinstance(condition, str):
|
||||
return condition == trigger_method
|
||||
return condition == str(trigger_method)
|
||||
|
||||
def _sub_prefix(index: int) -> str | None:
|
||||
if pending_key_prefix is None:
|
||||
return None
|
||||
return f"{pending_key_prefix}:{index}"
|
||||
|
||||
if is_flow_condition_dict(condition):
|
||||
normalized = _normalize_condition(condition)
|
||||
cond_type = normalized.get("type", OR_CONDITION)
|
||||
sub_conditions = normalized.get("conditions", [])
|
||||
if "or" in condition:
|
||||
# Evaluate every sub-condition (no short-circuit): a nested and_()
|
||||
# branch needs the chance to clear its pending state in
|
||||
# _pending_and_listeners even when an earlier branch already matched.
|
||||
any_matched = False
|
||||
for index, sub_condition in enumerate(condition["or"]):
|
||||
if self._evaluate_condition(
|
||||
sub_condition,
|
||||
trigger_method,
|
||||
listener_name,
|
||||
pending_key_prefix=_sub_prefix(index),
|
||||
):
|
||||
any_matched = True
|
||||
return any_matched
|
||||
|
||||
if cond_type == OR_CONDITION:
|
||||
return any(
|
||||
self._evaluate_condition(
|
||||
sub_cond,
|
||||
trigger_method,
|
||||
listener_name,
|
||||
pending_key_prefix=_sub_prefix(index),
|
||||
)
|
||||
for index, sub_cond in enumerate(sub_conditions)
|
||||
)
|
||||
sub_conditions = condition["and"]
|
||||
pending_key = PendingListenerKey(
|
||||
pending_key_prefix
|
||||
if pending_key_prefix is not None
|
||||
else f"{listener_name}:{id(condition)}"
|
||||
)
|
||||
|
||||
if cond_type == AND_CONDITION:
|
||||
pending_key = PendingListenerKey(
|
||||
pending_key_prefix
|
||||
if pending_key_prefix is not None
|
||||
else f"{listener_name}:{id(condition)}"
|
||||
)
|
||||
if pending_key not in self._pending_and_listeners:
|
||||
self._pending_and_listeners[pending_key] = set(range(len(sub_conditions)))
|
||||
|
||||
if pending_key not in self._pending_and_listeners:
|
||||
all_methods = set(_extract_all_methods(condition))
|
||||
self._pending_and_listeners[pending_key] = all_methods
|
||||
pending_conditions = self._pending_and_listeners[pending_key]
|
||||
for index, sub_condition in enumerate(sub_conditions):
|
||||
if index not in pending_conditions:
|
||||
continue
|
||||
if self._evaluate_condition(
|
||||
sub_condition,
|
||||
trigger_method,
|
||||
listener_name,
|
||||
pending_key_prefix=_sub_prefix(index),
|
||||
):
|
||||
pending_conditions.discard(index)
|
||||
|
||||
if trigger_method in self._pending_and_listeners[pending_key]:
|
||||
self._pending_and_listeners[pending_key].discard(trigger_method)
|
||||
|
||||
direct_methods_satisfied = not self._pending_and_listeners[pending_key]
|
||||
|
||||
nested_conditions_satisfied = all(
|
||||
(
|
||||
self._evaluate_condition(
|
||||
sub_cond,
|
||||
trigger_method,
|
||||
listener_name,
|
||||
pending_key_prefix=_sub_prefix(index),
|
||||
)
|
||||
if is_flow_condition_dict(sub_cond)
|
||||
else True
|
||||
)
|
||||
for index, sub_cond in enumerate(sub_conditions)
|
||||
)
|
||||
|
||||
if direct_methods_satisfied and nested_conditions_satisfied:
|
||||
self._pending_and_listeners.pop(pending_key, None)
|
||||
return True
|
||||
|
||||
return False
|
||||
if not pending_conditions:
|
||||
self._pending_and_listeners.pop(pending_key, None)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _find_triggered_methods(
|
||||
self, trigger_method: FlowMethodName, router_only: bool
|
||||
) -> list[FlowMethodName]:
|
||||
"""Finds all methods that should be triggered based on conditions.
|
||||
|
||||
This internal method evaluates both OR and AND conditions to determine
|
||||
which methods should be executed next in the flow. Supports nested conditions.
|
||||
|
||||
Args:
|
||||
trigger_method: The name of the method that just completed execution.
|
||||
router_only: If True, only consider router methods. If False, only consider non-router methods.
|
||||
|
||||
Returns:
|
||||
Names of methods that should be triggered.
|
||||
|
||||
Note:
|
||||
- Handles both OR and AND conditions, including nested combinations
|
||||
- Maintains state for AND conditions using _pending_and_listeners
|
||||
- Separates router and normal listener evaluation
|
||||
"""
|
||||
triggered: list[FlowMethodName] = []
|
||||
|
||||
for listener_name, condition_data in self._listeners.items():
|
||||
is_router = listener_name in self._routers
|
||||
|
||||
for listener_name, method_definition, condition in type(
|
||||
self
|
||||
)._listener_methods():
|
||||
is_router = method_definition.router
|
||||
if router_only != is_router:
|
||||
continue
|
||||
|
||||
if not router_only and type(self)._definition_has_start(listener_name):
|
||||
should_check_fired = _is_multi_event_or(condition) and not is_router
|
||||
if should_check_fired and listener_name in self._fired_or_listeners:
|
||||
continue
|
||||
|
||||
if is_simple_flow_condition(condition_data):
|
||||
condition_type, methods = condition_data
|
||||
|
||||
if condition_type == OR_CONDITION:
|
||||
# Only trigger multi-source OR listeners (or_(A, B, C)) once - skip if already fired
|
||||
# Simple single-method listeners fire every time their trigger occurs
|
||||
# Routers also fire every time - they're decision points
|
||||
has_multiple_triggers = len(methods) > 1
|
||||
should_check_fired = has_multiple_triggers and not is_router
|
||||
|
||||
if (
|
||||
not should_check_fired
|
||||
or listener_name not in self._fired_or_listeners
|
||||
):
|
||||
if trigger_method in methods:
|
||||
triggered.append(listener_name)
|
||||
# Only track multi-source OR listeners (not single-method or routers)
|
||||
if should_check_fired:
|
||||
self._fired_or_listeners.add(listener_name)
|
||||
elif condition_type == AND_CONDITION:
|
||||
pending_key = PendingListenerKey(listener_name)
|
||||
if pending_key not in self._pending_and_listeners:
|
||||
self._pending_and_listeners[pending_key] = set(methods)
|
||||
if trigger_method in self._pending_and_listeners[pending_key]:
|
||||
self._pending_and_listeners[pending_key].discard(trigger_method)
|
||||
|
||||
if not self._pending_and_listeners[pending_key]:
|
||||
triggered.append(listener_name)
|
||||
self._pending_and_listeners.pop(pending_key, None)
|
||||
|
||||
elif is_flow_condition_dict(condition_data):
|
||||
# For complex conditions, check if top-level is OR and track accordingly
|
||||
top_level_type = condition_data.get("type", OR_CONDITION)
|
||||
is_or_based = top_level_type == OR_CONDITION
|
||||
|
||||
# Only track multi-source OR conditions (multiple sub-conditions), not routers
|
||||
sub_conditions = condition_data.get("conditions", [])
|
||||
has_multiple_triggers = is_or_based and len(sub_conditions) > 1
|
||||
should_check_fired = has_multiple_triggers and not is_router
|
||||
|
||||
# Skip compound OR-based listeners that have already fired
|
||||
if should_check_fired and listener_name in self._fired_or_listeners:
|
||||
continue
|
||||
|
||||
if self._evaluate_condition(
|
||||
condition_data, trigger_method, listener_name
|
||||
):
|
||||
triggered.append(listener_name)
|
||||
# Track compound OR-based listeners so they only fire once
|
||||
if should_check_fired:
|
||||
self._fired_or_listeners.add(listener_name)
|
||||
if self._evaluate_condition(
|
||||
condition,
|
||||
trigger_method,
|
||||
listener_name,
|
||||
):
|
||||
triggered.append(listener_name)
|
||||
if should_check_fired:
|
||||
self._fired_or_listeners.add(listener_name)
|
||||
|
||||
return triggered
|
||||
|
||||
@@ -2984,13 +2859,10 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
|
||||
# For routers, also check if any conditional starts they triggered are completed
|
||||
# If so, continue their chains
|
||||
if listener_name in self._routers:
|
||||
for start_method_name in type(
|
||||
self
|
||||
)._definition_start_method_names():
|
||||
if type(self)._is_router(listener_name):
|
||||
for start_method_name in type(self)._start_method_names():
|
||||
if (
|
||||
type(self)._definition_start_condition(start_method_name)
|
||||
is not None
|
||||
type(self)._start_condition(start_method_name) is not None
|
||||
and start_method_name in self._completed_methods
|
||||
):
|
||||
# This conditional start was executed, continue its chain
|
||||
|
||||
@@ -5,15 +5,7 @@ the Flow system.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
NewType,
|
||||
ParamSpec,
|
||||
Protocol,
|
||||
TypeVar,
|
||||
TypedDict,
|
||||
)
|
||||
from typing import Annotated, Any, NewType, ParamSpec, Protocol, TypeVar, TypedDict
|
||||
|
||||
from typing_extensions import NotRequired, Required
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSourc
|
||||
from crewai.knowledge.source.text_file_knowledge_source import (
|
||||
TextFileKnowledgeSource,
|
||||
)
|
||||
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
||||
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
@@ -89,7 +90,7 @@ class Knowledge(BaseModel):
|
||||
Knowledge is a collection of sources and setup for the vector store to save and query relevant context.
|
||||
Args:
|
||||
sources: list[BaseKnowledgeSource] = Field(default_factory=list)
|
||||
storage: KnowledgeStorage | None = Field(default=None)
|
||||
storage: BaseKnowledgeStorage | None = Field(default=None)
|
||||
embedder: EmbedderConfig | None = None
|
||||
"""
|
||||
|
||||
@@ -98,7 +99,7 @@ class Knowledge(BaseModel):
|
||||
BeforeValidator(_resolve_knowledge_sources),
|
||||
] = Field(default_factory=list)
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
storage: KnowledgeStorage | None = Field(default=None)
|
||||
storage: BaseKnowledgeStorage | None = Field(default=None)
|
||||
embedder: Annotated[
|
||||
EmbedderConfig | None,
|
||||
PlainSerializer(
|
||||
@@ -112,15 +113,22 @@ class Knowledge(BaseModel):
|
||||
collection_name: str,
|
||||
sources: list[BaseKnowledgeSource],
|
||||
embedder: EmbedderConfig | None = None,
|
||||
storage: KnowledgeStorage | None = None,
|
||||
storage: BaseKnowledgeStorage | None = None,
|
||||
**data: object,
|
||||
) -> None:
|
||||
super().__init__(**data)
|
||||
if storage:
|
||||
if storage is not None:
|
||||
self.storage = storage
|
||||
else:
|
||||
self.storage = KnowledgeStorage(
|
||||
embedder=embedder, collection_name=collection_name
|
||||
from crewai.knowledge.storage.factory import resolve_knowledge_storage
|
||||
|
||||
custom = resolve_knowledge_storage(embedder, collection_name)
|
||||
self.storage = (
|
||||
custom
|
||||
if custom is not None
|
||||
else KnowledgeStorage(
|
||||
embedder=embedder, collection_name=collection_name
|
||||
)
|
||||
)
|
||||
self.sources = sources
|
||||
|
||||
@@ -152,10 +160,9 @@ class Knowledge(BaseModel):
|
||||
raise e
|
||||
|
||||
def reset(self) -> None:
|
||||
if self.storage:
|
||||
self.storage.reset()
|
||||
else:
|
||||
if self.storage is None:
|
||||
raise ValueError("Storage is not initialized.")
|
||||
self.storage.reset()
|
||||
|
||||
async def aquery(
|
||||
self, query: list[str], results_limit: int = 5, score_threshold: float = 0.6
|
||||
@@ -193,7 +200,6 @@ class Knowledge(BaseModel):
|
||||
|
||||
async def areset(self) -> None:
|
||||
"""Reset the knowledge base asynchronously."""
|
||||
if self.storage:
|
||||
await self.storage.areset()
|
||||
else:
|
||||
if self.storage is None:
|
||||
raise ValueError("Storage is not initialized.")
|
||||
await self.storage.areset()
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
|
||||
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
||||
from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
|
||||
from crewai.utilities.logger import Logger
|
||||
|
||||
@@ -22,7 +22,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
default_factory=list, description="The path to the file"
|
||||
)
|
||||
content: dict[Path, str] = Field(init=False, default_factory=dict)
|
||||
storage: KnowledgeStorage | None = Field(default=None)
|
||||
storage: BaseKnowledgeStorage | None = Field(default=None)
|
||||
safe_file_paths: list[Path] = Field(default_factory=list)
|
||||
|
||||
@field_validator("file_path", "file_paths", mode="before")
|
||||
@@ -70,14 +70,14 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
|
||||
def _save_documents(self) -> None:
|
||||
"""Save the documents to the storage."""
|
||||
if self.storage:
|
||||
if self.storage is not None:
|
||||
self.storage.save(self.chunks)
|
||||
else:
|
||||
raise ValueError("No storage found to save documents.")
|
||||
|
||||
async def _asave_documents(self) -> None:
|
||||
"""Save the documents to the storage asynchronously."""
|
||||
if self.storage:
|
||||
if self.storage is not None:
|
||||
await self.storage.asave(self.chunks)
|
||||
else:
|
||||
raise ValueError("No storage found to save documents.")
|
||||
|
||||
@@ -4,9 +4,15 @@ from typing import Any
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
||||
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
|
||||
|
||||
|
||||
# ``KnowledgeStorage`` is re-exported for backwards compatibility; the ``storage``
|
||||
# field below is typed to the base interface so any backend plugs in.
|
||||
__all__ = ["BaseKnowledgeSource", "KnowledgeStorage"]
|
||||
|
||||
|
||||
class BaseKnowledgeSource(BaseModel, ABC):
|
||||
"""Abstract base class for knowledge sources."""
|
||||
|
||||
@@ -18,7 +24,7 @@ class BaseKnowledgeSource(BaseModel, ABC):
|
||||
)
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
storage: KnowledgeStorage | None = Field(default=None)
|
||||
storage: BaseKnowledgeStorage | None = Field(default=None)
|
||||
metadata: dict[str, Any] = Field(default_factory=dict) # Currently unused
|
||||
collection_name: str | None = Field(default=None)
|
||||
|
||||
@@ -49,7 +55,7 @@ class BaseKnowledgeSource(BaseModel, ABC):
|
||||
Raises:
|
||||
ValueError: If no storage is configured.
|
||||
"""
|
||||
if self.storage:
|
||||
if self.storage is not None:
|
||||
self.storage.save(self.chunks)
|
||||
else:
|
||||
raise ValueError("No storage found to save documents.")
|
||||
@@ -66,7 +72,7 @@ class BaseKnowledgeSource(BaseModel, ABC):
|
||||
Raises:
|
||||
ValueError: If no storage is configured.
|
||||
"""
|
||||
if self.storage:
|
||||
if self.storage is not None:
|
||||
await self.storage.asave(self.chunks)
|
||||
else:
|
||||
raise ValueError("No storage found to save documents.")
|
||||
|
||||
56
lib/crewai/src/crewai/knowledge/storage/factory.py
Normal file
56
lib/crewai/src/crewai/knowledge/storage/factory.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Pluggable default storage backend for knowledge collections.
|
||||
|
||||
By default, :class:`~crewai.knowledge.knowledge.Knowledge` builds a
|
||||
:class:`~crewai.knowledge.storage.knowledge_storage.KnowledgeStorage` when no
|
||||
explicit ``storage=`` is given. Registering a factory via
|
||||
:func:`set_knowledge_storage_factory` lets an application back knowledge with a
|
||||
custom :class:`~crewai.knowledge.storage.base_knowledge_storage.BaseKnowledgeStorage`
|
||||
without subclassing ``Knowledge`` or passing a ``storage=`` instance at every
|
||||
call site.
|
||||
|
||||
This mirrors :func:`crewai_core.lock_store.set_lock_backend`: a one-time,
|
||||
process-wide setter intended for application startup. Pass ``None`` to restore
|
||||
the built-in default.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
|
||||
# Receives the same inputs as the built-in default -- the embedder config and
|
||||
# collection name -- and returns a storage backend, or ``None`` to defer to the
|
||||
# built-in ``KnowledgeStorage``.
|
||||
KnowledgeStorageFactory = Callable[
|
||||
["EmbedderConfig | None", "str | None"], "BaseKnowledgeStorage | None"
|
||||
]
|
||||
|
||||
_factory: KnowledgeStorageFactory | None = None
|
||||
|
||||
|
||||
def set_knowledge_storage_factory(factory: KnowledgeStorageFactory | None) -> None:
|
||||
"""Replace the process-wide default knowledge storage factory.
|
||||
|
||||
Intended for one-time setup at startup. Pass ``None`` to restore the
|
||||
built-in ``KnowledgeStorage``. Only affects ``Knowledge`` instances
|
||||
constructed afterwards; an explicit ``storage=`` instance always wins.
|
||||
"""
|
||||
global _factory
|
||||
_factory = factory
|
||||
|
||||
|
||||
def resolve_knowledge_storage(
|
||||
embedder: EmbedderConfig | None, collection_name: str | None
|
||||
) -> BaseKnowledgeStorage | None:
|
||||
"""Return the registered factory's backend, or ``None`` for the built-in.
|
||||
|
||||
``None`` means no factory is registered or it declined; the caller then
|
||||
falls back to the built-in ``KnowledgeStorage``.
|
||||
"""
|
||||
factory = _factory
|
||||
return factory(embedder, collection_name) if factory is not None else None
|
||||
55
lib/crewai/src/crewai/memory/storage/factory.py
Normal file
55
lib/crewai/src/crewai/memory/storage/factory.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""Pluggable default storage backend for the unified memory system.
|
||||
|
||||
By default, :class:`~crewai.memory.unified_memory.Memory` builds a built-in
|
||||
vector store from its ``storage`` spec string (LanceDB, or Qdrant for the
|
||||
``"qdrant-edge"`` spec). Registering a factory via
|
||||
:func:`set_memory_storage_factory` lets an application route memory through a
|
||||
custom :class:`~crewai.memory.storage.backend.StorageBackend` -- a different
|
||||
vector store, a remote service, an in-memory fake for tests -- without
|
||||
subclassing ``Memory`` or threading an explicit ``storage=`` instance through
|
||||
every construction site.
|
||||
|
||||
This mirrors :func:`crewai_core.lock_store.set_lock_backend`: a one-time,
|
||||
process-wide setter intended for application startup. Pass ``None`` to restore
|
||||
the built-in default.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.memory.storage.backend import StorageBackend
|
||||
|
||||
# Receives the raw ``storage`` spec string and returns a backend to use, or
|
||||
# ``None`` to defer to the built-in selection for that spec.
|
||||
MemoryStorageFactory = Callable[[str], "StorageBackend | None"]
|
||||
|
||||
_factory: MemoryStorageFactory | None = None
|
||||
|
||||
|
||||
def set_memory_storage_factory(factory: MemoryStorageFactory | None) -> None:
|
||||
"""Replace the process-wide default memory storage factory.
|
||||
|
||||
Intended for one-time setup at startup. Pass ``None`` to restore the
|
||||
built-in LanceDB/Qdrant selection. Only affects ``Memory`` instances
|
||||
constructed afterwards; an explicit ``storage=`` instance always wins.
|
||||
|
||||
The factory is consulted for every string ``storage`` spec, so it must
|
||||
return ``None`` for specs it does not handle to let the built-in
|
||||
LanceDB/Qdrant/path selection take over.
|
||||
"""
|
||||
global _factory
|
||||
_factory = factory
|
||||
|
||||
|
||||
def resolve_memory_storage(spec: str) -> StorageBackend | None:
|
||||
"""Return the registered factory's backend for ``spec``, or ``None``.
|
||||
|
||||
``None`` means no factory is registered or it declined this spec; the
|
||||
caller then falls back to the built-in selection.
|
||||
"""
|
||||
factory = _factory
|
||||
return factory(spec) if factory is not None else None
|
||||
@@ -204,7 +204,12 @@ class Memory(BaseModel):
|
||||
)
|
||||
|
||||
if isinstance(self.storage, str):
|
||||
if self.storage == "qdrant-edge":
|
||||
from crewai.memory.storage.factory import resolve_memory_storage
|
||||
|
||||
custom = resolve_memory_storage(self.storage)
|
||||
if custom is not None:
|
||||
self._storage = custom
|
||||
elif self.storage == "qdrant-edge":
|
||||
from crewai.memory.storage.qdrant_edge_storage import QdrantEdgeStorage
|
||||
|
||||
self._storage = QdrantEdgeStorage()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Factory functions for creating RAG clients from configuration."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from crewai.rag.config.optional_imports.protocols import (
|
||||
@@ -11,6 +12,32 @@ from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.utilities.import_utils import require
|
||||
|
||||
|
||||
# RAG uses a provider-keyed registry (rather than the single-default setter
|
||||
# used by the memory/knowledge/flow seams) because ``create_client`` already
|
||||
# dispatches on ``config.provider`` -- the natural seam here is per-provider.
|
||||
# A factory receives the RAG config and returns a client; one registered for a
|
||||
# built-in provider name overrides the built-in for that provider.
|
||||
RagClientFactory = Callable[[RagConfigType], BaseClient]
|
||||
|
||||
_factories: dict[str, RagClientFactory] = {}
|
||||
|
||||
|
||||
def register_rag_client_factory(provider: str, factory: RagClientFactory) -> None:
|
||||
"""Register a client factory for a RAG ``provider`` name.
|
||||
|
||||
Lets an application plug in a client for a new provider, or override a
|
||||
built-in provider (``"chromadb"`` / ``"qdrant"``), without modifying
|
||||
:func:`create_client`. Registered factories take precedence over the
|
||||
built-ins. Intended for one-time setup at startup.
|
||||
"""
|
||||
_factories[provider] = factory
|
||||
|
||||
|
||||
def unregister_rag_client_factory(provider: str) -> None:
|
||||
"""Remove a previously registered factory; a no-op if none is registered."""
|
||||
_factories.pop(provider, None)
|
||||
|
||||
|
||||
def create_client(config: RagConfigType) -> BaseClient:
|
||||
"""Create a client from configuration using the appropriate factory.
|
||||
|
||||
@@ -24,6 +51,10 @@ def create_client(config: RagConfigType) -> BaseClient:
|
||||
ValueError: If the configuration provider is not supported.
|
||||
"""
|
||||
|
||||
factory = _factories.get(config.provider)
|
||||
if factory is not None:
|
||||
return factory(config)
|
||||
|
||||
if config.provider == "chromadb":
|
||||
chromadb_mod = cast(
|
||||
ChromaFactoryModule,
|
||||
|
||||
@@ -23,6 +23,26 @@ def _duplicate_separator_pattern(separator: str) -> re.Pattern[str]:
|
||||
return re.compile(f"(?:{re.escape(separator)}){{2,}}")
|
||||
|
||||
|
||||
def extract_template_variables(input_string: str | None) -> list[str]:
|
||||
"""Return the template variable names referenced in a string.
|
||||
|
||||
Only recognizes placeholders that interpolation can actually fill, i.e.
|
||||
``{name}`` where ``name`` starts with a letter/underscore and contains only
|
||||
letters, numbers, underscores, and hyphens. Expressions such as
|
||||
``{x if x else "y"}`` or JSON snippets are intentionally ignored so they are
|
||||
never treated as required inputs.
|
||||
|
||||
Args:
|
||||
input_string: The string to scan. May be ``None`` or empty.
|
||||
|
||||
Returns:
|
||||
The matched variable names, in order of appearance (with duplicates).
|
||||
"""
|
||||
if not input_string:
|
||||
return []
|
||||
return _VARIABLE_PATTERN.findall(input_string)
|
||||
|
||||
|
||||
def sanitize_tool_name(name: str, max_length: int = _MAX_TOOL_NAME_LENGTH) -> str:
|
||||
"""Sanitize tool name for LLM provider compatibility.
|
||||
|
||||
|
||||
130
lib/crewai/tests/knowledge/test_storage_factory.py
Normal file
130
lib/crewai/tests/knowledge/test_storage_factory.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""Tests for the pluggable knowledge storage factory seam.
|
||||
|
||||
We verify our own logic: the set/get round-trip, that a registered factory is
|
||||
consulted when no explicit ``storage=`` is given (and receives the embedder and
|
||||
collection name), and that an explicit ``storage=`` instance bypasses it.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
import crewai.knowledge.storage.factory as factory
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
||||
from crewai.rag.types import SearchResult
|
||||
|
||||
|
||||
class _FakeKnowledgeStorage(BaseKnowledgeStorage):
|
||||
"""Minimal stand-in implementing the abstract interface."""
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: list[str],
|
||||
limit: int = 5,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[SearchResult]:
|
||||
return []
|
||||
|
||||
async def asearch(
|
||||
self,
|
||||
query: list[str],
|
||||
limit: int = 5,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[SearchResult]:
|
||||
return []
|
||||
|
||||
def save(self, documents: list[str]) -> None:
|
||||
return None
|
||||
|
||||
async def asave(self, documents: list[str]) -> None:
|
||||
return None
|
||||
|
||||
def reset(self) -> None:
|
||||
return None
|
||||
|
||||
async def areset(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_factory():
|
||||
"""Reset the factory around each test without clobbering preexisting state."""
|
||||
original = factory._factory
|
||||
factory.set_knowledge_storage_factory(None)
|
||||
yield
|
||||
factory.set_knowledge_storage_factory(original)
|
||||
|
||||
|
||||
def test_resolve_reflects_registered_factory():
|
||||
fake = _FakeKnowledgeStorage()
|
||||
assert factory.resolve_knowledge_storage(None, "docs") is None
|
||||
|
||||
factory.set_knowledge_storage_factory(lambda embedder, name: fake)
|
||||
assert factory.resolve_knowledge_storage(None, "docs") is fake
|
||||
|
||||
|
||||
def test_factory_used_when_no_explicit_storage():
|
||||
fake = _FakeKnowledgeStorage()
|
||||
factory.set_knowledge_storage_factory(lambda embedder, name: fake)
|
||||
|
||||
knowledge = Knowledge(collection_name="docs", sources=[])
|
||||
|
||||
assert knowledge.storage is fake
|
||||
|
||||
|
||||
def test_factory_receives_embedder_and_collection_name():
|
||||
seen: list[tuple[object, object]] = []
|
||||
|
||||
def make(embedder, collection_name):
|
||||
seen.append((embedder, collection_name))
|
||||
return _FakeKnowledgeStorage()
|
||||
|
||||
factory.set_knowledge_storage_factory(make)
|
||||
Knowledge(collection_name="docs", sources=[])
|
||||
|
||||
assert seen == [(None, "docs")]
|
||||
|
||||
|
||||
def test_explicit_storage_bypasses_factory():
|
||||
factory_called = False
|
||||
|
||||
def make(embedder, name):
|
||||
nonlocal factory_called
|
||||
factory_called = True
|
||||
return _FakeKnowledgeStorage()
|
||||
|
||||
factory.set_knowledge_storage_factory(make)
|
||||
|
||||
explicit = _FakeKnowledgeStorage()
|
||||
knowledge = Knowledge(collection_name="docs", sources=[], storage=explicit)
|
||||
|
||||
assert knowledge.storage is explicit
|
||||
assert factory_called is False
|
||||
|
||||
|
||||
def test_falsy_explicit_storage_is_honored():
|
||||
# A custom backend that is falsy (defines __bool__/__len__) must still be
|
||||
# used and operated on, not silently treated as "not initialized" by a
|
||||
# truthiness check in __init__, reset, or the source save path.
|
||||
reset_calls: list[bool] = []
|
||||
|
||||
class _FalsyStorage(_FakeKnowledgeStorage):
|
||||
def __bool__(self) -> bool:
|
||||
return False
|
||||
|
||||
def reset(self) -> None:
|
||||
reset_calls.append(True)
|
||||
|
||||
explicit = _FalsyStorage()
|
||||
knowledge = Knowledge(collection_name="docs", sources=[], storage=explicit)
|
||||
|
||||
assert knowledge.storage is explicit
|
||||
|
||||
# reset must call the backend, not raise "Storage is not initialized."
|
||||
knowledge.reset()
|
||||
assert reset_calls == [True]
|
||||
72
lib/crewai/tests/memory/test_storage_factory.py
Normal file
72
lib/crewai/tests/memory/test_storage_factory.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""Tests for the pluggable memory storage factory seam.
|
||||
|
||||
We verify our own logic: the set/get round-trip, that a registered factory is
|
||||
consulted for string ``storage`` specs (and receives the spec), and that an
|
||||
explicit ``storage=`` instance bypasses the factory entirely.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
import crewai.memory.storage.factory as factory
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_factory():
|
||||
"""Reset the factory around each test without clobbering preexisting state."""
|
||||
original = factory._factory
|
||||
factory.set_memory_storage_factory(None)
|
||||
yield
|
||||
factory.set_memory_storage_factory(original)
|
||||
|
||||
|
||||
def test_resolve_reflects_registered_factory():
|
||||
sentinel = object()
|
||||
assert factory.resolve_memory_storage("lancedb") is None
|
||||
|
||||
factory.set_memory_storage_factory(lambda spec: sentinel)
|
||||
assert factory.resolve_memory_storage("lancedb") is sentinel
|
||||
|
||||
factory.set_memory_storage_factory(None)
|
||||
assert factory.resolve_memory_storage("lancedb") is None
|
||||
|
||||
|
||||
def test_factory_backend_used_for_string_spec():
|
||||
sentinel = object()
|
||||
factory.set_memory_storage_factory(lambda spec: sentinel)
|
||||
|
||||
mem = Memory(storage="lancedb")
|
||||
|
||||
assert mem._storage is sentinel
|
||||
|
||||
|
||||
def test_factory_receives_the_raw_spec():
|
||||
seen: list[str] = []
|
||||
|
||||
def make(spec):
|
||||
seen.append(spec)
|
||||
return object()
|
||||
|
||||
factory.set_memory_storage_factory(make)
|
||||
Memory(storage="some/custom/path")
|
||||
|
||||
assert seen == ["some/custom/path"]
|
||||
|
||||
|
||||
def test_explicit_storage_instance_bypasses_factory():
|
||||
factory_called = False
|
||||
|
||||
def make(spec):
|
||||
nonlocal factory_called
|
||||
factory_called = True
|
||||
return object()
|
||||
|
||||
factory.set_memory_storage_factory(make)
|
||||
|
||||
explicit = object()
|
||||
mem = Memory(storage=explicit) # type: ignore[arg-type]
|
||||
|
||||
assert mem._storage is explicit
|
||||
assert factory_called is False
|
||||
66
lib/crewai/tests/rag/test_client_factory_registry.py
Normal file
66
lib/crewai/tests/rag/test_client_factory_registry.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""Tests for the RAG client factory registry seam.
|
||||
|
||||
We verify our own logic: a registered factory is used for its provider,
|
||||
factories override the built-in providers, unregister removes them, and an
|
||||
unknown provider still raises.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
import crewai.rag.factory as factory
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_registry():
|
||||
"""Reset the registry around each test without clobbering preexisting state."""
|
||||
original = dict(factory._factories)
|
||||
factory._factories.clear()
|
||||
yield
|
||||
factory._factories.clear()
|
||||
factory._factories.update(original)
|
||||
|
||||
|
||||
def test_registered_factory_is_used_for_its_provider():
|
||||
sentinel = object()
|
||||
factory.register_rag_client_factory("custom", lambda config: sentinel)
|
||||
|
||||
assert factory.create_client(SimpleNamespace(provider="custom")) is sentinel
|
||||
|
||||
|
||||
def test_factory_receives_the_config():
|
||||
seen: list[object] = []
|
||||
config = SimpleNamespace(provider="custom")
|
||||
factory.register_rag_client_factory("custom", lambda cfg: seen.append(cfg) or object())
|
||||
|
||||
factory.create_client(config)
|
||||
|
||||
assert seen == [config]
|
||||
|
||||
|
||||
def test_factory_overrides_builtin_provider():
|
||||
sentinel = object()
|
||||
factory.register_rag_client_factory("chromadb", lambda config: sentinel)
|
||||
|
||||
# Resolves via the registry without importing the built-in chromadb factory.
|
||||
assert factory.create_client(SimpleNamespace(provider="chromadb")) is sentinel
|
||||
|
||||
|
||||
def test_unregister_removes_factory():
|
||||
factory.register_rag_client_factory("custom", lambda config: object())
|
||||
factory.unregister_rag_client_factory("custom")
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported provider: custom"):
|
||||
factory.create_client(SimpleNamespace(provider="custom"))
|
||||
|
||||
|
||||
def test_unregister_unknown_provider_is_noop():
|
||||
factory.unregister_rag_client_factory("never-registered")
|
||||
|
||||
|
||||
def test_unknown_provider_still_raises():
|
||||
with pytest.raises(ValueError, match="Unsupported provider: nope"):
|
||||
factory.create_client(SimpleNamespace(provider="nope"))
|
||||
@@ -3895,6 +3895,29 @@ def test_fetch_inputs():
|
||||
)
|
||||
|
||||
|
||||
def test_fetch_inputs_ignores_non_identifier_placeholders():
|
||||
agent = Agent(
|
||||
role="Report writer",
|
||||
goal="Write a report for {company_name}.",
|
||||
backstory="Expert reporter.",
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description=(
|
||||
'Greet {company_name if company_name else "Individual Client"} '
|
||||
"and summarize {search_period}."
|
||||
),
|
||||
expected_output="A summary for {company_name}.",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
|
||||
# Only the simple {company_name} placeholders are returned; the inline conditional
|
||||
# expression (which interpolation cannot fill) is ignored.
|
||||
assert crew.fetch_inputs() == {"company_name", "search_period"}
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_task_tools_preserve_code_execution_tools():
|
||||
"""
|
||||
|
||||
@@ -161,6 +161,27 @@ def test_flow_with_or_condition():
|
||||
)
|
||||
|
||||
|
||||
def test_flow_executes_and_condition_with_single_branch_or():
|
||||
class NestedConditionFlow(Flow):
|
||||
@start()
|
||||
def event_a(self):
|
||||
return "a"
|
||||
|
||||
@listen(event_a)
|
||||
def event_b(self):
|
||||
return "b"
|
||||
|
||||
@router(event_b)
|
||||
def emit_event_c(self):
|
||||
return "event_c"
|
||||
|
||||
@listen(and_(event_a, event_b, or_("event_c")))
|
||||
def event_d(self):
|
||||
return "done"
|
||||
|
||||
assert NestedConditionFlow().kickoff() == "done"
|
||||
|
||||
|
||||
def test_or_listener_fires_once_across_parallel_starts():
|
||||
"""Parallel ``@start`` paths feeding ``or_`` must not double-fire the listener."""
|
||||
fire_count = 0
|
||||
@@ -303,6 +324,90 @@ def test_start_runtime_uses_flow_definition_without_legacy_start_metadata():
|
||||
assert execution_order == ["begin", "route", "branch", "done"]
|
||||
|
||||
|
||||
def test_listen_runtime_uses_flow_definition_without_legacy_listener_metadata():
|
||||
execution_order = []
|
||||
|
||||
class DefinitionListenFlow(Flow):
|
||||
@start()
|
||||
def begin(self):
|
||||
execution_order.append("begin")
|
||||
|
||||
@listen(begin)
|
||||
def by_callable(self):
|
||||
execution_order.append("by_callable")
|
||||
|
||||
@listen(and_(begin, by_callable))
|
||||
def by_and(self):
|
||||
execution_order.append("by_and")
|
||||
|
||||
@listen(or_(and_(begin, by_callable), "fallback"))
|
||||
def nested(self):
|
||||
execution_order.append("nested")
|
||||
|
||||
for method_name in ("by_callable", "by_and", "nested"):
|
||||
method = DefinitionListenFlow.__dict__[method_name]
|
||||
assert not hasattr(method, "__trigger_methods__")
|
||||
assert not hasattr(method, "__condition_type__")
|
||||
assert not hasattr(method, "__trigger_condition__")
|
||||
|
||||
DefinitionListenFlow().kickoff()
|
||||
|
||||
assert execution_order[0] == "begin"
|
||||
assert {"by_callable", "by_and", "nested"}.issubset(execution_order)
|
||||
|
||||
|
||||
def test_router_runtime_uses_flow_definition_without_legacy_router_metadata():
|
||||
execution_order = []
|
||||
|
||||
class DefinitionRouterFlow(Flow):
|
||||
@start()
|
||||
def begin(self):
|
||||
execution_order.append("begin")
|
||||
return "begin"
|
||||
|
||||
@router(begin, emit=["go_left"])
|
||||
def decide(self):
|
||||
execution_order.append("decide")
|
||||
return "go_left"
|
||||
|
||||
@listen("go_left")
|
||||
def handle_left(self):
|
||||
execution_order.append("handle_left")
|
||||
|
||||
route = DefinitionRouterFlow.__dict__["decide"]
|
||||
assert not hasattr(route, "__is_router__")
|
||||
assert not hasattr(route, "__router_emit__")
|
||||
assert not hasattr(route, "__trigger_methods__")
|
||||
assert not hasattr(route, "__condition_type__")
|
||||
assert not hasattr(route, "__trigger_condition__")
|
||||
|
||||
DefinitionRouterFlow().kickoff()
|
||||
|
||||
assert execution_order == ["begin", "decide", "handle_left"]
|
||||
|
||||
|
||||
def test_router_falsy_result_emits_runtime_event():
|
||||
execution_order = []
|
||||
|
||||
class FalsyRouterResultFlow(Flow):
|
||||
@start()
|
||||
def begin(self):
|
||||
execution_order.append("begin")
|
||||
|
||||
@router(begin)
|
||||
def decide(self):
|
||||
execution_order.append("decide")
|
||||
return 0
|
||||
|
||||
@listen("0")
|
||||
def handle_zero(self):
|
||||
execution_order.append("handle_zero")
|
||||
|
||||
FalsyRouterResultFlow().kickoff()
|
||||
|
||||
assert execution_order == ["begin", "decide", "handle_zero"]
|
||||
|
||||
|
||||
def test_async_flow():
|
||||
"""Test an asynchronous flow."""
|
||||
execution_order = []
|
||||
@@ -1436,6 +1541,43 @@ def test_deeply_nested_conditions():
|
||||
assert and_ab_satisfied or and_cd_satisfied
|
||||
|
||||
|
||||
def test_or_branch_does_not_leave_stale_and_state():
|
||||
"""or_() over nested and_() branches must not leave stale pending AND state.
|
||||
|
||||
Regression: evaluating an or_() condition stopped at the first branch that was
|
||||
satisfied, so a later and_() branch that the *same* trigger would have completed
|
||||
never cleared its pending state. On the next cycle that trigger alone then
|
||||
spuriously re-satisfied the whole condition. Both branches share the final
|
||||
event ``x`` here, so the shared trigger that completes branch ``(a AND x)`` also
|
||||
completes branch ``(c AND x)`` and both must be cleared together.
|
||||
"""
|
||||
|
||||
class StaleStateFlow(Flow):
|
||||
@start()
|
||||
def begin(self):
|
||||
pass
|
||||
|
||||
@listen(or_(and_("a", "x"), and_("c", "x")))
|
||||
def joined(self):
|
||||
pass
|
||||
|
||||
flow = StaleStateFlow()
|
||||
condition = type(flow)._listen_condition("joined")
|
||||
|
||||
def fires(trigger):
|
||||
return flow._evaluate_condition(condition, trigger, "joined")
|
||||
|
||||
# First cycle: "a" then "c" arrive, then the shared "x" completes (a AND x).
|
||||
assert fires("a") is False
|
||||
assert fires("c") is False
|
||||
assert fires("x") is True
|
||||
|
||||
# Next cycle: "x" alone must NOT re-satisfy the condition. The "c" from the
|
||||
# previous cycle was consumed when "joined" fired, so neither branch is half
|
||||
# complete and "x" by itself is insufficient.
|
||||
assert fires("x") is False
|
||||
|
||||
|
||||
def test_mixed_sync_async_execution_order():
|
||||
"""Test that execution order is preserved with mixed sync/async methods."""
|
||||
execution_order = []
|
||||
|
||||
@@ -344,6 +344,7 @@ class TestConversationalFlow:
|
||||
"end",
|
||||
}
|
||||
|
||||
@conversational_graph_broken
|
||||
def test_router_infers_custom_routes_without_internal_routes(self) -> None:
|
||||
class ResearchRoute(BaseModel):
|
||||
intent: Literal["research", "converse", "end"]
|
||||
@@ -739,6 +740,7 @@ class TestConversationalFlow:
|
||||
assert flow.state.messages[-1].content == "fresh research"
|
||||
assert flow._is_execution_resuming is False
|
||||
|
||||
@conversational_graph_broken
|
||||
def test_route_catalog_combines_docstrings_builtins_and_overrides(self) -> None:
|
||||
"""Catalog precedence: route_descriptions > built-in > docstring."""
|
||||
|
||||
@@ -770,6 +772,7 @@ class TestConversationalFlow:
|
||||
assert "Ordinary chat" in catalog["converse"]
|
||||
assert "finished" in catalog["end"]
|
||||
|
||||
@conversational_graph_broken
|
||||
def test_route_catalog_falls_back_to_empty_when_no_docstring(self) -> None:
|
||||
@ConversationConfig(router=RouterConfig(routes=["BARE"]))
|
||||
class BareFlow(ConversationalFlow):
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Tests for the static Flow Definition contract."""
|
||||
|
||||
import ast
|
||||
from enum import Enum
|
||||
import importlib
|
||||
import inspect
|
||||
@@ -15,7 +14,6 @@ import crewai.flow.dsl as flow_dsl
|
||||
import crewai.flow.flow_definition as flow_definition
|
||||
import crewai.flow.visualization.builder as visualization_builder
|
||||
from crewai.flow import Flow, and_, human_feedback, listen, or_, persist, router, start
|
||||
from crewai.flow.dsl._conditions import is_flow_condition_dict
|
||||
|
||||
|
||||
def test_flow_public_exports_are_explicit():
|
||||
@@ -50,79 +48,64 @@ def test_flow_public_exports_are_explicit():
|
||||
assert "calculate_node_levels" not in flow_visualization.__all__
|
||||
|
||||
|
||||
def test_flow_condition_dict_accepts_non_string_sequences():
|
||||
condition = {
|
||||
"type": "OR",
|
||||
"conditions": (
|
||||
"approved",
|
||||
{"type": "AND", "methods": ("validated", "processed")},
|
||||
),
|
||||
def test_condition_combinators_return_nested_runtime_tree():
|
||||
condition = and_("event_a", "event_b", or_("event_c"))
|
||||
|
||||
assert condition == {
|
||||
"type": "AND",
|
||||
"conditions": [
|
||||
"event_a",
|
||||
"event_b",
|
||||
{"type": "OR", "conditions": ["event_c"]},
|
||||
],
|
||||
}
|
||||
|
||||
assert is_flow_condition_dict(condition)
|
||||
assert not is_flow_condition_dict({"type": "OR", "conditions": "approved"})
|
||||
assert not is_flow_condition_dict({"type": "OR", "methods": b"approved"})
|
||||
|
||||
def test_flow_definition_lowers_nested_conditions():
|
||||
class NestedFlow(Flow):
|
||||
@start()
|
||||
def begin(self):
|
||||
return "begin"
|
||||
|
||||
@listen(begin)
|
||||
def validated(self):
|
||||
return "validated"
|
||||
|
||||
@listen(begin)
|
||||
def processed(self):
|
||||
return "processed"
|
||||
|
||||
@listen(or_(and_(validated, processed), begin))
|
||||
def finalize(self):
|
||||
return "done"
|
||||
|
||||
finalize = NestedFlow.flow_definition().methods["finalize"]
|
||||
|
||||
assert finalize.listen == {"or": [{"and": ["validated", "processed"]}, "begin"]}
|
||||
|
||||
|
||||
def test_private_flow_helpers_do_not_have_docstrings():
|
||||
import crewai.flow.flow_wrappers as flow_wrappers
|
||||
import crewai.flow.human_feedback as human_feedback
|
||||
import crewai.flow.persistence.decorators as persistence_decorators
|
||||
import crewai.flow.visualization.types as visualization_types
|
||||
def test_flow_definition_preserves_single_branch_nested_conditions():
|
||||
class AmbiguousFlow(Flow):
|
||||
@start()
|
||||
def event_a(self):
|
||||
return "a"
|
||||
|
||||
modules = [
|
||||
flow_dsl,
|
||||
flow_definition,
|
||||
flow_wrappers,
|
||||
human_feedback,
|
||||
persistence_decorators,
|
||||
visualization_builder,
|
||||
visualization_types,
|
||||
]
|
||||
violations: list[str] = []
|
||||
@listen(event_a)
|
||||
def event_b(self):
|
||||
return "b"
|
||||
|
||||
for module in modules:
|
||||
source_path = Path(inspect.getsourcefile(module) or "")
|
||||
tree = ast.parse(source_path.read_text())
|
||||
stack: list[ast.AST] = []
|
||||
if getattr(module, "__all__", None) == [] and ast.get_docstring(tree):
|
||||
violations.append(f"{source_path}:1:<module>")
|
||||
@listen(and_(event_a, event_b, or_("event_c")))
|
||||
def event_d(self):
|
||||
return "d"
|
||||
|
||||
class PrivateDocstringVisitor(ast.NodeVisitor):
|
||||
def visit_ClassDef(self, node: ast.ClassDef) -> None:
|
||||
self._check_docstring(node)
|
||||
stack.append(node)
|
||||
self.generic_visit(node)
|
||||
stack.pop()
|
||||
event_d = AmbiguousFlow.flow_definition().methods["event_d"]
|
||||
|
||||
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
|
||||
self._check_docstring(node)
|
||||
stack.append(node)
|
||||
self.generic_visit(node)
|
||||
stack.pop()
|
||||
assert event_d.listen == {"and": ["event_a", "event_b", {"or": ["event_c"]}]}
|
||||
|
||||
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
|
||||
self._check_docstring(node)
|
||||
stack.append(node)
|
||||
self.generic_visit(node)
|
||||
stack.pop()
|
||||
|
||||
def _check_docstring(
|
||||
self,
|
||||
node: ast.ClassDef | ast.FunctionDef | ast.AsyncFunctionDef,
|
||||
) -> None:
|
||||
is_dunder = node.name.startswith("__") and node.name.endswith("__")
|
||||
is_private_name = node.name.startswith("_") and not is_dunder
|
||||
is_nested_function = any(
|
||||
isinstance(parent, (ast.FunctionDef, ast.AsyncFunctionDef))
|
||||
for parent in stack
|
||||
)
|
||||
if (is_private_name or is_nested_function) and ast.get_docstring(node):
|
||||
violations.append(f"{source_path}:{node.lineno}:{node.name}")
|
||||
|
||||
PrivateDocstringVisitor().visit(tree)
|
||||
|
||||
assert violations == []
|
||||
def test_flow_definition_rejects_invalid_condition():
|
||||
with pytest.raises(ValueError, match="Invalid condition"):
|
||||
start(123)(lambda self: None)
|
||||
|
||||
|
||||
def test_flow_definition_contract_is_dsl_agnostic():
|
||||
@@ -304,81 +287,11 @@ def test_flow_definition_fragments_cover_start_listen_and_condition_sugar():
|
||||
|
||||
assert not hasattr(FragmentFlow.__dict__["begin"], "__is_start_method__")
|
||||
assert not hasattr(FragmentFlow.__dict__["restart"], "__trigger_methods__")
|
||||
assert "restart" not in FragmentFlow._listeners
|
||||
assert FragmentFlow._listeners["by_callable"] == ("OR", ["begin"])
|
||||
assert FragmentFlow._listeners["by_string"] == ("OR", ["manual_event"])
|
||||
assert FragmentFlow._listeners["by_and"] == {
|
||||
"type": "AND",
|
||||
"conditions": ["begin", "by_callable"],
|
||||
}
|
||||
assert FragmentFlow._listeners["nested"] == {
|
||||
"type": "OR",
|
||||
"conditions": [
|
||||
{"type": "AND", "conditions": ["manual_event", "by_string"]},
|
||||
"fallback_event",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def test_extract_flow_definition_prefers_fragments_over_legacy_metadata():
|
||||
class RegistryFlow(Flow):
|
||||
@start()
|
||||
def begin(self):
|
||||
return "begin"
|
||||
|
||||
@listen(begin)
|
||||
def handle(self):
|
||||
return "handle"
|
||||
|
||||
@router(handle, emit=["done"])
|
||||
def decide(self):
|
||||
return "done"
|
||||
|
||||
handle = RegistryFlow.__dict__["handle"]
|
||||
original_trigger_methods = handle.__trigger_methods__
|
||||
handle.__trigger_methods__ = ["wrong"]
|
||||
try:
|
||||
_, listeners, routers, router_emit = flow_dsl.extract_flow_definition(
|
||||
{
|
||||
"begin": RegistryFlow.__dict__["begin"],
|
||||
"handle": handle,
|
||||
"decide": RegistryFlow.__dict__["decide"],
|
||||
}
|
||||
)
|
||||
finally:
|
||||
handle.__trigger_methods__ = original_trigger_methods
|
||||
|
||||
assert listeners["handle"] == ("OR", ["begin"])
|
||||
assert listeners["decide"] == ("OR", ["handle"])
|
||||
assert routers == {"decide"}
|
||||
assert router_emit == {"decide": ["done"]}
|
||||
|
||||
|
||||
def test_flow_definition_falls_back_to_legacy_listener_router_metadata_without_fragment():
|
||||
class LegacyMetadataFlow(Flow):
|
||||
@start()
|
||||
def begin(self):
|
||||
return "begin"
|
||||
|
||||
@router(begin, emit=["left"])
|
||||
def decide(self):
|
||||
return "left"
|
||||
|
||||
@listen("left")
|
||||
def left(self):
|
||||
return "left"
|
||||
|
||||
for method_name in ("decide", "left"):
|
||||
method = LegacyMetadataFlow.__dict__[method_name]
|
||||
delattr(method, "__flow_method_definition__")
|
||||
|
||||
definition = flow_dsl.build_flow_definition(LegacyMetadataFlow)
|
||||
|
||||
assert definition.methods["begin"].start is True
|
||||
assert definition.methods["decide"].listen == "begin"
|
||||
assert definition.methods["decide"].router is True
|
||||
assert definition.methods["decide"].emit == ["left"]
|
||||
assert definition.methods["left"].listen == "left"
|
||||
for method_name in ("by_callable", "by_string", "by_and", "nested"):
|
||||
method = FragmentFlow.__dict__[method_name]
|
||||
assert not hasattr(method, "__trigger_methods__")
|
||||
assert not hasattr(method, "__condition_type__")
|
||||
assert not hasattr(method, "__trigger_condition__")
|
||||
|
||||
|
||||
def test_human_feedback_emit_overrides_inner_router_emit():
|
||||
@@ -400,9 +313,6 @@ def test_human_feedback_emit_overrides_inner_router_emit():
|
||||
def proceed(self):
|
||||
return "ok"
|
||||
|
||||
assert "route" in FeedbackOverRouterFlow._routers
|
||||
assert FeedbackOverRouterFlow._router_emit["route"] == ["approved", "rejected"]
|
||||
|
||||
route = FeedbackOverRouterFlow.flow_definition().methods["route"]
|
||||
assert route.router is True
|
||||
assert route.human_feedback is not None
|
||||
|
||||
68
lib/crewai/tests/test_flow_persistence_factory.py
Normal file
68
lib/crewai/tests/test_flow_persistence_factory.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""Tests for the pluggable flow persistence factory seam.
|
||||
|
||||
We verify our own logic: that ``default_flow_persistence`` returns the
|
||||
registered factory's result, and that it falls back to the built-in SQLite
|
||||
persistence when no factory is registered.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
import crewai.flow.persistence.factory as factory
|
||||
from crewai.flow.persistence.base import FlowPersistence
|
||||
from crewai.flow.persistence.decorators import persist
|
||||
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_factory():
|
||||
"""Reset the factory around each test without clobbering preexisting state."""
|
||||
original = factory._factory
|
||||
factory.set_flow_persistence_factory(None)
|
||||
yield
|
||||
factory.set_flow_persistence_factory(original)
|
||||
|
||||
|
||||
def test_default_uses_registered_factory():
|
||||
sentinel = SQLiteFlowPersistence()
|
||||
factory.set_flow_persistence_factory(lambda: sentinel)
|
||||
|
||||
assert factory.default_flow_persistence() is sentinel
|
||||
|
||||
|
||||
def test_default_falls_back_to_sqlite():
|
||||
assert isinstance(factory.default_flow_persistence(), SQLiteFlowPersistence)
|
||||
|
||||
|
||||
def test_persist_decorator_honors_falsy_persistence():
|
||||
# @persist with an explicit but falsy FlowPersistence must keep it, not
|
||||
# replace it with the default via a truthiness check.
|
||||
class _FalsyPersistence(FlowPersistence):
|
||||
def __bool__(self) -> bool:
|
||||
return False
|
||||
|
||||
def init_db(self) -> None:
|
||||
pass
|
||||
|
||||
def save_state(
|
||||
self,
|
||||
flow_uuid: str,
|
||||
method_name: str,
|
||||
state_data: dict[str, Any] | BaseModel,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
|
||||
return None
|
||||
|
||||
falsy = _FalsyPersistence()
|
||||
|
||||
@persist(persistence=falsy)
|
||||
class _DummyFlow:
|
||||
pass
|
||||
|
||||
assert _DummyFlow.__flow_persistence_config__.persistence is falsy
|
||||
@@ -78,8 +78,9 @@ class TestHumanFeedbackValidation:
|
||||
return "output"
|
||||
|
||||
assert hasattr(test_method, "__human_feedback_config__")
|
||||
assert test_method.__is_router__ is True
|
||||
assert test_method.__router_emit__ == ["approve", "reject"]
|
||||
assert test_method.__human_feedback_config__.emit == ["approve", "reject"]
|
||||
assert not hasattr(test_method, "__is_router__")
|
||||
assert not hasattr(test_method, "__router_emit__")
|
||||
|
||||
def test_valid_configuration_without_routing(self):
|
||||
"""Test that valid configuration without routing doesn't raise."""
|
||||
@@ -89,7 +90,7 @@ class TestHumanFeedbackValidation:
|
||||
return "output"
|
||||
|
||||
assert hasattr(test_method, "__human_feedback_config__")
|
||||
assert not hasattr(test_method, "__is_router__") or not test_method.__is_router__
|
||||
assert not hasattr(test_method, "__is_router__")
|
||||
|
||||
def test_persist_preserves_human_feedback_llm_attribute(self):
|
||||
"""Test @persist preserves the live LLM stashed by @human_feedback."""
|
||||
@@ -177,8 +178,8 @@ class TestDecoratorAttributePreservation:
|
||||
assert fragment is not None
|
||||
assert fragment.start is True
|
||||
|
||||
def test_preserves_listen_method_attributes(self):
|
||||
"""Test that @human_feedback preserves @listen decorator attributes."""
|
||||
def test_preserves_listen_method_definition(self):
|
||||
"""Test that @human_feedback preserves the @listen method definition."""
|
||||
|
||||
class TestFlow(Flow):
|
||||
@start()
|
||||
@@ -191,12 +192,14 @@ class TestDecoratorAttributePreservation:
|
||||
return "review output"
|
||||
|
||||
flow = TestFlow()
|
||||
assert "review" in flow._listeners or any(
|
||||
"review" in str(v) for v in flow._listeners.values()
|
||||
)
|
||||
method = flow._methods.get("review")
|
||||
assert method is not None
|
||||
fragment = getattr(method, "__flow_method_definition__", None)
|
||||
assert fragment is not None
|
||||
assert fragment.listen == "begin"
|
||||
|
||||
def test_sets_router_attributes_when_emit_specified(self):
|
||||
"""Test that router attributes are set when emit is specified."""
|
||||
def test_emit_is_stored_on_human_feedback_config(self):
|
||||
"""Test that emit outcomes are stored on human feedback config."""
|
||||
|
||||
@human_feedback(
|
||||
message="Review:",
|
||||
@@ -206,8 +209,12 @@ class TestDecoratorAttributePreservation:
|
||||
def review_method(self):
|
||||
return "output"
|
||||
|
||||
assert review_method.__is_router__ is True
|
||||
assert review_method.__router_emit__ == ["approved", "rejected"]
|
||||
assert review_method.__human_feedback_config__.emit == [
|
||||
"approved",
|
||||
"rejected",
|
||||
]
|
||||
assert not hasattr(review_method, "__is_router__")
|
||||
assert not hasattr(review_method, "__router_emit__")
|
||||
|
||||
|
||||
class TestAsyncSupport:
|
||||
|
||||
@@ -1,7 +1,45 @@
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
import pytest
|
||||
from crewai.utilities.string_utils import interpolate_only
|
||||
from crewai.utilities.string_utils import (
|
||||
extract_template_variables,
|
||||
interpolate_only,
|
||||
)
|
||||
|
||||
|
||||
class TestExtractTemplateVariables:
|
||||
"""Tests for extract_template_variables in string_utils.py."""
|
||||
|
||||
def test_extracts_simple_identifiers(self):
|
||||
assert extract_template_variables("Hi {name}, see {topic}.") == [
|
||||
"name",
|
||||
"topic",
|
||||
]
|
||||
|
||||
def test_allows_underscores_and_hyphens(self):
|
||||
assert extract_template_variables("{user_name} {role-detail}") == [
|
||||
"user_name",
|
||||
"role-detail",
|
||||
]
|
||||
|
||||
def test_ignores_inline_expressions(self):
|
||||
text = '{company_name if company_name else "Individual Client"}'
|
||||
assert extract_template_variables(text) == []
|
||||
|
||||
def test_ignores_json_like_braces(self):
|
||||
assert extract_template_variables('{"key": "value"}') == []
|
||||
|
||||
def test_matches_what_interpolation_fills(self):
|
||||
text = 'Use {topic} and {x if x else "y"}.'
|
||||
variables = extract_template_variables(text)
|
||||
assert variables == ["topic"]
|
||||
# interpolation fills exactly the extracted variable and leaves the rest
|
||||
result = interpolate_only(text, {"topic": "AI"})
|
||||
assert result == 'Use AI and {x if x else "y"}.'
|
||||
|
||||
@pytest.mark.parametrize("value", [None, ""])
|
||||
def test_handles_empty_input(self, value):
|
||||
assert extract_template_variables(value) == []
|
||||
|
||||
|
||||
class TestInterpolateOnly:
|
||||
|
||||
Reference in New Issue
Block a user