mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-06-17 14:18:10 +00:00
Compare commits
2 Commits
joaomdmour
...
flow-scrip
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
386a1650da | ||
|
|
7bb9bc7e1a |
@@ -2275,8 +2275,6 @@ class Crew(FlowTrackable, BaseModel):
|
||||
"""
|
||||
|
||||
def default_reset(memory: Any) -> Any:
|
||||
if isinstance(memory, Memory):
|
||||
return memory.reset_all()
|
||||
return memory.reset()
|
||||
|
||||
def knowledge_reset(memory: Any) -> Any:
|
||||
|
||||
@@ -15,10 +15,13 @@ from crewai.flow.flow_definition import (
|
||||
FlowConversationalRouterDefinition,
|
||||
FlowDefinition,
|
||||
FlowDefinitionDiagnostic,
|
||||
FlowDictStateDefinition,
|
||||
FlowHumanFeedbackDefinition,
|
||||
FlowMethodDefinition,
|
||||
FlowPersistenceDefinition,
|
||||
FlowPydanticStateDefinition,
|
||||
FlowStateDefinition,
|
||||
FlowUnknownStateDefinition,
|
||||
_object_ref,
|
||||
)
|
||||
from crewai.flow.flow_wrappers import (
|
||||
@@ -185,12 +188,11 @@ def _build_state_definition(
|
||||
default = None
|
||||
if isinstance(state_value, dict):
|
||||
default = _serialize_static_value(state_value, diagnostics, "state.default")
|
||||
return FlowStateDefinition(type="dict", default=default)
|
||||
return FlowDictStateDefinition(default=default)
|
||||
if isinstance(state_value, type) and issubclass(state_value, PydanticBaseModel):
|
||||
return FlowStateDefinition(type="pydantic", ref=_state_ref(state_value))
|
||||
return FlowPydanticStateDefinition(ref=_state_ref(state_value))
|
||||
if isinstance(state_value, PydanticBaseModel):
|
||||
return FlowStateDefinition(
|
||||
type="pydantic",
|
||||
return FlowPydanticStateDefinition(
|
||||
ref=_state_ref(state_value),
|
||||
default=_serialize_static_value(state_value, diagnostics, "state.default"),
|
||||
)
|
||||
@@ -201,7 +203,7 @@ def _build_state_definition(
|
||||
message=f"could not serialize state type {_object_ref(state_value)}",
|
||||
)
|
||||
)
|
||||
return FlowStateDefinition(type="unknown", ref=_state_ref(state_value))
|
||||
return FlowUnknownStateDefinition(ref=_state_ref(state_value))
|
||||
|
||||
|
||||
def _build_config_definition(
|
||||
|
||||
@@ -12,7 +12,7 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Literal as TypingLiteral
|
||||
from typing import Annotated, Any, Literal as TypingLiteral, TypeAlias
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
@@ -46,14 +46,19 @@ __all__ = [
|
||||
"FlowDefinition",
|
||||
"FlowDefinitionCondition",
|
||||
"FlowDefinitionDiagnostic",
|
||||
"FlowDictStateDefinition",
|
||||
"FlowEachActionDefinition",
|
||||
"FlowEachInnerActionDefinition",
|
||||
"FlowExpressionActionDefinition",
|
||||
"FlowHumanFeedbackDefinition",
|
||||
"FlowJsonSchemaStateDefinition",
|
||||
"FlowMethodDefinition",
|
||||
"FlowPersistenceDefinition",
|
||||
"FlowPydanticStateDefinition",
|
||||
"FlowScriptActionDefinition",
|
||||
"FlowStateDefinition",
|
||||
"FlowToolActionDefinition",
|
||||
"FlowUnknownStateDefinition",
|
||||
]
|
||||
|
||||
|
||||
@@ -74,13 +79,114 @@ class FlowDefinitionDiagnostic(BaseModel):
|
||||
path: str | None = None
|
||||
|
||||
|
||||
class FlowStateDefinition(BaseModel):
|
||||
"""Static description of a Flow state contract."""
|
||||
class FlowDictStateDefinition(BaseModel):
|
||||
"""Static description of a plain dictionary Flow state contract."""
|
||||
|
||||
type: TypingLiteral["dict", "pydantic", "json_schema", "unknown"] = "dict"
|
||||
ref: str | None = None
|
||||
json_schema: dict[str, Any] | None = None
|
||||
default: dict[str, Any] | None = None
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
type: TypingLiteral["dict"] = Field(
|
||||
default="dict",
|
||||
description="Plain dictionary state with optional default values.",
|
||||
examples=["dict"],
|
||||
)
|
||||
default: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="Default state values applied before kickoff inputs.",
|
||||
examples=[{"topic": "AI agents", "limit": 3}],
|
||||
)
|
||||
|
||||
|
||||
class FlowPydanticStateDefinition(BaseModel):
|
||||
"""Static description of an importable Pydantic Flow state contract."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
type: TypingLiteral["pydantic"] = Field(
|
||||
default="pydantic",
|
||||
description="Importable Pydantic model used as the Flow state type.",
|
||||
examples=["pydantic"],
|
||||
)
|
||||
ref: str | None = Field(
|
||||
default=None,
|
||||
description="Import reference for the state model, formatted as module:qualname.",
|
||||
examples=["my_project.flows:ResearchState"],
|
||||
)
|
||||
json_schema: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Fallback JSON Schema used when the Pydantic state ref is unavailable."
|
||||
),
|
||||
examples=[
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"topic": {"type": "string"}},
|
||||
"required": ["topic"],
|
||||
}
|
||||
],
|
||||
)
|
||||
default: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="Default state values applied before kickoff inputs.",
|
||||
examples=[{"topic": "AI agents", "limit": 3}],
|
||||
)
|
||||
|
||||
|
||||
class FlowJsonSchemaStateDefinition(BaseModel):
|
||||
"""Static description of an inline JSON Schema Flow state contract."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
type: TypingLiteral["json_schema"] = Field(
|
||||
default="json_schema",
|
||||
description="Inline JSON Schema used as the Flow state contract.",
|
||||
examples=["json_schema"],
|
||||
)
|
||||
json_schema: dict[str, Any] = Field(
|
||||
description="JSON Schema used to validate and document flow state.",
|
||||
examples=[
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {"topic": {"type": "string"}},
|
||||
"required": ["topic"],
|
||||
}
|
||||
],
|
||||
)
|
||||
default: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="Default state values applied before kickoff inputs.",
|
||||
examples=[{"topic": "AI agents", "limit": 3}],
|
||||
)
|
||||
|
||||
|
||||
class FlowUnknownStateDefinition(BaseModel):
|
||||
"""Static description of a state contract that could not be serialized."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
type: TypingLiteral["unknown"] = Field(
|
||||
default="unknown",
|
||||
description="Unknown state representation; runtime falls back to dictionary state.",
|
||||
examples=["unknown"],
|
||||
)
|
||||
ref: str | None = Field(
|
||||
default=None,
|
||||
description="Best-effort import reference for the unknown state type.",
|
||||
examples=["my_project.flows:CustomState"],
|
||||
)
|
||||
default: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="Default state values applied before kickoff inputs.",
|
||||
examples=[{"topic": "AI agents", "limit": 3}],
|
||||
)
|
||||
|
||||
|
||||
FlowStateDefinition: TypeAlias = Annotated[
|
||||
FlowDictStateDefinition
|
||||
| FlowPydanticStateDefinition
|
||||
| FlowJsonSchemaStateDefinition
|
||||
| FlowUnknownStateDefinition,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class FlowConfigDefinition(BaseModel):
|
||||
@@ -196,11 +302,39 @@ class FlowExpressionActionDefinition(BaseModel):
|
||||
expr: str
|
||||
|
||||
|
||||
class FlowScriptActionDefinition(BaseModel):
|
||||
"""A Flow method action that executes trusted inline Python."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
call: TypingLiteral["script"] = Field(
|
||||
description="Action discriminator. Use script to execute trusted inline Python.",
|
||||
examples=["script"],
|
||||
)
|
||||
code: str = Field(
|
||||
description=(
|
||||
"Trusted Python source executed as a generated function. Runtime values are "
|
||||
"passed as state, outputs, input, and item; they are not interpolated into "
|
||||
"the source. This is not sandboxed."
|
||||
),
|
||||
examples=[
|
||||
"state['normalized_topic'] = input.strip()\n"
|
||||
"return state['normalized_topic']"
|
||||
],
|
||||
)
|
||||
language: TypingLiteral["python"] = Field(
|
||||
default="python",
|
||||
description="Script language. Only python is currently supported.",
|
||||
examples=["python"],
|
||||
)
|
||||
|
||||
|
||||
FlowInnerActionDefinition = (
|
||||
FlowCodeActionDefinition
|
||||
| FlowToolActionDefinition
|
||||
| FlowCrewActionDefinition
|
||||
| FlowExpressionActionDefinition
|
||||
| FlowScriptActionDefinition
|
||||
)
|
||||
|
||||
|
||||
@@ -252,6 +386,7 @@ FlowActionDefinition = (
|
||||
| FlowToolActionDefinition
|
||||
| FlowCrewActionDefinition
|
||||
| FlowExpressionActionDefinition
|
||||
| FlowScriptActionDefinition
|
||||
| FlowEachActionDefinition
|
||||
)
|
||||
|
||||
|
||||
@@ -193,26 +193,24 @@ def _build_definition_state_model(
|
||||
kwargs = dict(state_definition.default or {})
|
||||
|
||||
model_class: type[BaseModel] | None = None
|
||||
if state_definition.ref:
|
||||
state_ref = getattr(state_definition, "ref", None)
|
||||
if state_ref:
|
||||
try:
|
||||
resolved: Any = resolve_ref(state_definition.ref, field="state")
|
||||
resolved: Any = resolve_ref(state_ref, field="state")
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Could not import state ref %r", state_definition.ref, exc_info=True
|
||||
)
|
||||
logger.warning("Could not import state ref %r", state_ref, exc_info=True)
|
||||
else:
|
||||
if isinstance(resolved, type) and issubclass(resolved, BaseModel):
|
||||
model_class = resolved
|
||||
else:
|
||||
logger.warning(
|
||||
"State ref %r is not a pydantic model", state_definition.ref
|
||||
)
|
||||
logger.warning("State ref %r is not a pydantic model", state_ref)
|
||||
|
||||
if model_class is None and state_definition.json_schema:
|
||||
json_schema = getattr(state_definition, "json_schema", None)
|
||||
if model_class is None and json_schema:
|
||||
from crewai.utilities.pydantic_schema_utils import create_model_from_schema
|
||||
|
||||
try:
|
||||
model_class = create_model_from_schema(state_definition.json_schema)
|
||||
model_class = create_model_from_schema(json_schema)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Could not build a state model from the declared json_schema",
|
||||
@@ -1092,6 +1090,8 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
def build(name: str, definition: FlowMethodDefinition) -> Callable[..., Any]:
|
||||
try:
|
||||
return build_action(self, definition.do)
|
||||
except RuntimeError:
|
||||
raise
|
||||
except Exception as e:
|
||||
unresolved.append(f"{name}: {e}")
|
||||
return lambda *args, **kwargs: None
|
||||
|
||||
@@ -2,10 +2,12 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
import contextvars
|
||||
import inspect
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Protocol, cast
|
||||
|
||||
from crewai.flow.flow_definition import (
|
||||
@@ -15,9 +17,11 @@ from crewai.flow.flow_definition import (
|
||||
FlowEachActionDefinition,
|
||||
FlowEachInnerActionDefinition,
|
||||
FlowExpressionActionDefinition,
|
||||
FlowScriptActionDefinition,
|
||||
FlowToolActionDefinition,
|
||||
)
|
||||
from crewai.flow.runtime._expressions import evaluate_expression, render_with_block
|
||||
from crewai.flow.runtime._outputs import outputs_by_name
|
||||
from crewai.flow.runtime._refs import InvalidRefError, resolve_ref
|
||||
|
||||
|
||||
@@ -29,6 +33,8 @@ __all__ = ["build_action"]
|
||||
|
||||
LocalContext = dict[str, Any]
|
||||
_LOCAL_CONTEXT_KWARG = "__flow_definition_local_context"
|
||||
_ALLOW_SCRIPT_EXECUTION_ENV_VAR = "CREWAI_ALLOW_FLOW_SCRIPT_EXECUTION"
|
||||
_TRUSTED_SCRIPT_EXECUTION_VALUES = frozenset({"1", "true", "yes"})
|
||||
|
||||
|
||||
class _BuiltAction(Protocol):
|
||||
@@ -140,6 +146,62 @@ class ExpressionAction:
|
||||
)
|
||||
|
||||
|
||||
class ScriptAction:
|
||||
definition_type = FlowScriptActionDefinition
|
||||
|
||||
def __init__(self, flow: Flow[Any], definition: FlowScriptActionDefinition) -> None:
|
||||
self.flow = flow
|
||||
self.definition = definition
|
||||
self.handler = self._compile_handler()
|
||||
|
||||
def run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
local_context = _pop_local_context(kwargs)
|
||||
return self.handler(
|
||||
state=self.flow.state,
|
||||
outputs=outputs_by_name(
|
||||
self.flow._method_outputs,
|
||||
local_outputs=local_context.get("outputs") if local_context else None,
|
||||
),
|
||||
input=args[0] if args else None,
|
||||
item=local_context.get("item") if local_context else None,
|
||||
)
|
||||
|
||||
def _compile_handler(self) -> Callable[..., Any]:
|
||||
raw = os.environ.get(_ALLOW_SCRIPT_EXECUTION_ENV_VAR, "")
|
||||
if raw.strip().lower() not in _TRUSTED_SCRIPT_EXECUTION_VALUES:
|
||||
raise RuntimeError(
|
||||
"Flow script execution is disabled by default. "
|
||||
f"Set {_ALLOW_SCRIPT_EXECUTION_ENV_VAR}=1 to enable it only for "
|
||||
"trusted flow definitions."
|
||||
)
|
||||
|
||||
filename = f"crewai.flow.script.{self.flow._definition.name}"
|
||||
module = ast.parse(self.definition.code, filename=filename)
|
||||
function = ast.FunctionDef(
|
||||
name="_flow_script",
|
||||
args=ast.arguments(
|
||||
posonlyargs=[],
|
||||
args=[ast.arg(arg) for arg in ("state", "outputs", "input", "item")],
|
||||
vararg=None,
|
||||
kwonlyargs=[],
|
||||
kw_defaults=[],
|
||||
kwarg=None,
|
||||
defaults=[],
|
||||
),
|
||||
body=module.body or [ast.Pass()],
|
||||
decorator_list=[],
|
||||
returns=None,
|
||||
type_comment=None,
|
||||
type_params=[],
|
||||
)
|
||||
module.body = [function]
|
||||
ast.fix_missing_locations(module)
|
||||
|
||||
namespace: dict[str, Any] = {"__name__": filename}
|
||||
exec(compile(module, filename, "exec"), namespace) # nosec B102 # noqa: S102
|
||||
return cast(Callable[..., Any], namespace["_flow_script"])
|
||||
|
||||
|
||||
class EachAction:
|
||||
definition_type = FlowEachActionDefinition
|
||||
|
||||
@@ -199,6 +261,7 @@ _ACTION_TYPES: tuple[_ActionType, ...] = (
|
||||
ToolAction,
|
||||
CrewAction,
|
||||
ExpressionAction,
|
||||
ScriptAction,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import json
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from crewai.flow.runtime._outputs import outputs_by_name
|
||||
from crewai.utilities.serialization import to_serializable
|
||||
|
||||
|
||||
@@ -44,7 +45,12 @@ def evaluate_expression(
|
||||
def _expression_context(
|
||||
flow: Flow[Any], local_context: dict[str, Any] | None = None
|
||||
) -> dict[str, Any]:
|
||||
outputs = _outputs_by_name(flow._method_outputs)
|
||||
local_outputs = local_context.get("outputs") if local_context else None
|
||||
outputs = outputs_by_name(
|
||||
flow._method_outputs,
|
||||
local_outputs=local_outputs,
|
||||
serialize=True,
|
||||
)
|
||||
context: dict[str, Any] = {
|
||||
"state": flow._copy_and_serialize_state(),
|
||||
"outputs": outputs,
|
||||
@@ -53,29 +59,12 @@ def _expression_context(
|
||||
local_values = {
|
||||
key: to_serializable(value, max_depth=0)
|
||||
for key, value in local_context.items()
|
||||
if key not in {"outputs", "state"}
|
||||
}
|
||||
local_outputs = local_values.pop("outputs", None)
|
||||
local_values.pop("state", None)
|
||||
context.update(local_values)
|
||||
if local_outputs is not None:
|
||||
if not isinstance(local_outputs, dict):
|
||||
raise TypeError("flow definition local outputs must be a mapping")
|
||||
context["outputs"] = {**outputs, **local_outputs}
|
||||
return context
|
||||
|
||||
|
||||
def _outputs_by_name(method_outputs: list[Any]) -> dict[str, Any]:
|
||||
outputs: dict[str, Any] = {}
|
||||
for entry in method_outputs:
|
||||
method = ""
|
||||
output = entry
|
||||
if isinstance(entry, dict) and "output" in entry:
|
||||
method = str(entry.get("method", ""))
|
||||
output = entry["output"]
|
||||
outputs[method] = to_serializable(output, max_depth=0)
|
||||
return outputs
|
||||
|
||||
|
||||
def _render_value(value: Any, context: dict[str, Any]) -> Any:
|
||||
if isinstance(value, str):
|
||||
return _render_string(value, context)
|
||||
|
||||
40
lib/crewai/src/crewai/flow/runtime/_outputs.py
Normal file
40
lib/crewai/src/crewai/flow/runtime/_outputs.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""Shared FlowDefinition runtime output helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from crewai.utilities.serialization import to_serializable
|
||||
|
||||
|
||||
class _MethodOutput(TypedDict):
|
||||
method: str
|
||||
output: Any
|
||||
|
||||
|
||||
def outputs_by_name(
|
||||
method_outputs: list[_MethodOutput],
|
||||
*,
|
||||
local_outputs: Mapping[str, Any] | None = None,
|
||||
serialize: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
outputs: dict[str, Any] = {}
|
||||
for entry in method_outputs:
|
||||
outputs[entry["method"]] = _output_value(entry["output"], serialize=serialize)
|
||||
|
||||
if local_outputs is not None:
|
||||
outputs.update(
|
||||
{
|
||||
key: _output_value(output, serialize=serialize)
|
||||
for key, output in local_outputs.items()
|
||||
}
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def _output_value(value: Any, *, serialize: bool) -> Any:
|
||||
if not serialize:
|
||||
return value
|
||||
return to_serializable(value, max_depth=0)
|
||||
@@ -149,7 +149,6 @@ class Memory(BaseModel):
|
||||
)
|
||||
_pending_saves: list[Future[Any]] = PrivateAttr(default_factory=list)
|
||||
_pending_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
|
||||
_reset_lock: Any = PrivateAttr(default_factory=threading.RLock)
|
||||
|
||||
def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Memory:
|
||||
"""Deepcopy that handles unpickleable private attrs (ThreadPoolExecutor, Lock)."""
|
||||
@@ -169,10 +168,7 @@ class Memory(BaseModel):
|
||||
)
|
||||
private = {}
|
||||
for k, v in (self.__pydantic_private__ or {}).items():
|
||||
if k in {"_save_pool", "_pending_lock", "_reset_lock"}:
|
||||
attr = self.__private_attributes__[k]
|
||||
private[k] = attr.get_default()
|
||||
elif isinstance(v, (ThreadPoolExecutor, threading.Lock)):
|
||||
if isinstance(v, (ThreadPoolExecutor, threading.Lock)):
|
||||
attr = self.__private_attributes__[k]
|
||||
private[k] = attr.get_default()
|
||||
else:
|
||||
@@ -279,25 +275,22 @@ class Memory(BaseModel):
|
||||
If the pool has been shut down (e.g. after ``close()``), the save
|
||||
runs synchronously as a fallback so late saves still succeed.
|
||||
"""
|
||||
with self._reset_lock:
|
||||
ctx = contextvars.copy_context()
|
||||
ctx = contextvars.copy_context()
|
||||
try:
|
||||
future: Future[Any] = self._save_pool.submit(ctx.run, fn, *args, **kwargs)
|
||||
except RuntimeError:
|
||||
# Pool shut down -- run synchronously as fallback
|
||||
future = Future()
|
||||
try:
|
||||
future: Future[Any] = self._save_pool.submit(
|
||||
ctx.run, fn, *args, **kwargs
|
||||
)
|
||||
except RuntimeError:
|
||||
# Pool shut down -- run synchronously as fallback
|
||||
future = Future()
|
||||
try:
|
||||
result = fn(*args, **kwargs)
|
||||
future.set_result(result)
|
||||
except Exception as exc:
|
||||
future.set_exception(exc)
|
||||
return future
|
||||
with self._pending_lock:
|
||||
self._pending_saves.append(future)
|
||||
future.add_done_callback(self._on_save_done)
|
||||
result = fn(*args, **kwargs)
|
||||
future.set_result(result)
|
||||
except Exception as exc:
|
||||
future.set_exception(exc)
|
||||
return future
|
||||
with self._pending_lock:
|
||||
self._pending_saves.append(future)
|
||||
future.add_done_callback(self._on_save_done)
|
||||
return future
|
||||
|
||||
def _on_save_done(self, future: Future[Any]) -> None:
|
||||
"""Remove a completed future from the pending list and emit failure event if needed.
|
||||
@@ -997,20 +990,12 @@ class Memory(BaseModel):
|
||||
scope: Scope to reset. If None and root_scope is set, resets only
|
||||
within root_scope. If None and no root_scope, resets all.
|
||||
"""
|
||||
with self._reset_lock:
|
||||
self.drain_writes()
|
||||
effective_scope = scope
|
||||
if effective_scope is None and self.root_scope:
|
||||
effective_scope = self.root_scope
|
||||
elif effective_scope is not None and self.root_scope:
|
||||
effective_scope = join_scope_paths(self.root_scope, effective_scope)
|
||||
self._storage.reset(scope_prefix=effective_scope)
|
||||
|
||||
def reset_all(self) -> None:
|
||||
"""Reset the entire backing memory store, ignoring ``root_scope``."""
|
||||
with self._reset_lock:
|
||||
self.drain_writes()
|
||||
self._storage.reset(scope_prefix=None)
|
||||
effective_scope = scope
|
||||
if effective_scope is None and self.root_scope:
|
||||
effective_scope = self.root_scope
|
||||
elif effective_scope is not None and self.root_scope:
|
||||
effective_scope = join_scope_paths(self.root_scope, effective_scope)
|
||||
self._storage.reset(scope_prefix=effective_scope)
|
||||
|
||||
async def aextract_memories(self, content: str) -> list[str]:
|
||||
"""Async variant of extract_memories."""
|
||||
|
||||
@@ -6,10 +6,7 @@ from typing import Any
|
||||
import click
|
||||
|
||||
from crewai.flow import Flow
|
||||
from crewai.memory.unified_memory import Memory
|
||||
from crewai.project.crew_loader import load_crew
|
||||
from crewai.project.json_loader import find_crew_json_file
|
||||
from crewai.utilities.project_utils import get_crews, get_flows, read_toml
|
||||
from crewai.utilities.project_utils import get_crews, get_flows
|
||||
|
||||
|
||||
def _reset_flow_memory(flow: Flow[Any]) -> None:
|
||||
@@ -26,9 +23,7 @@ def _reset_flow_memory(flow: Flow[Any]) -> None:
|
||||
if mem is None:
|
||||
return
|
||||
try:
|
||||
if isinstance(mem, Memory):
|
||||
mem.reset_all()
|
||||
elif hasattr(mem, "reset"):
|
||||
if hasattr(mem, "reset"):
|
||||
mem.reset()
|
||||
elif hasattr(mem, "_memory") and mem._memory is not None:
|
||||
mem._memory.reset()
|
||||
@@ -42,38 +37,6 @@ def _reset_flow_memory(flow: Flow[Any]) -> None:
|
||||
click.echo(f"Memory reset skipped: {exc}", err=True)
|
||||
|
||||
|
||||
def _current_project_declares_flow() -> bool:
|
||||
try:
|
||||
pyproject_data = read_toml()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
declared_type: str | None = (
|
||||
pyproject_data.get("tool", {}).get("crewai", {}).get("type")
|
||||
)
|
||||
return declared_type == "flow"
|
||||
|
||||
|
||||
def _get_json_crew() -> Any | None:
|
||||
"""Load a JSON-first crew from the current project, if present."""
|
||||
if _current_project_declares_flow():
|
||||
return None
|
||||
|
||||
crew_path = find_crew_json_file()
|
||||
if crew_path is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
crew, _ = load_crew(crew_path)
|
||||
except Exception as exc:
|
||||
click.echo(
|
||||
f"Skipping JSON crew at {crew_path}: failed to load ({exc}).",
|
||||
err=True,
|
||||
)
|
||||
return None
|
||||
return crew
|
||||
|
||||
|
||||
def reset_memories_command(
|
||||
memory: bool,
|
||||
knowledge: bool,
|
||||
@@ -98,8 +61,6 @@ def reset_memories_command(
|
||||
return
|
||||
|
||||
crews = get_crews()
|
||||
if json_crew := _get_json_crew():
|
||||
crews.append(json_crew)
|
||||
flows = get_flows()
|
||||
|
||||
if not crews and not flows:
|
||||
|
||||
@@ -4,12 +4,10 @@ Non-core CLI tests (train, test, version, deploy, login, flow_add_crew)
|
||||
have moved to lib/cli/tests/test_cli.py.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
from click.testing import CliRunner
|
||||
from crewai.crew import Crew
|
||||
from crewai.memory.unified_memory import Memory
|
||||
from crewai_cli.cli import reset_memories
|
||||
import pytest
|
||||
|
||||
@@ -32,8 +30,6 @@ def mock_get_crews(mock_crew):
|
||||
"crewai.utilities.reset_memories.get_crews", return_value=[mock_crew]
|
||||
) as mock_get_crew, mock.patch(
|
||||
"crewai.utilities.reset_memories.get_flows", return_value=[]
|
||||
), mock.patch(
|
||||
"crewai.utilities.reset_memories._get_json_crew", return_value=None
|
||||
):
|
||||
yield mock_get_crew
|
||||
|
||||
@@ -174,8 +170,6 @@ def mock_get_flows(mock_flow):
|
||||
"crewai.utilities.reset_memories.get_flows", return_value=[mock_flow]
|
||||
) as mock_get_flow, mock.patch(
|
||||
"crewai.utilities.reset_memories.get_crews", return_value=[]
|
||||
), mock.patch(
|
||||
"crewai.utilities.reset_memories._get_json_crew", return_value=None
|
||||
):
|
||||
yield mock_get_flow
|
||||
|
||||
@@ -186,33 +180,6 @@ def test_reset_flow_memory(mock_get_flows, mock_flow, runner):
|
||||
assert "[Flow (TestFlow)] Memory has been reset." in result.output
|
||||
|
||||
|
||||
def test_reset_flow_unified_memory_uses_full_reset(runner, tmp_path):
|
||||
flow = mock.Mock()
|
||||
flow.name = "TestFlow"
|
||||
flow.memory = Memory(
|
||||
storage=str(tmp_path / "db"),
|
||||
llm=mock.Mock(),
|
||||
embedder=lambda texts: [[0.1] * 4 for _ in texts],
|
||||
)
|
||||
|
||||
with mock.patch(
|
||||
"crewai.utilities.reset_memories.get_flows", return_value=[flow]
|
||||
), mock.patch(
|
||||
"crewai.utilities.reset_memories.get_crews", return_value=[]
|
||||
), mock.patch(
|
||||
"crewai.utilities.reset_memories._get_json_crew", return_value=None
|
||||
), mock.patch.object(
|
||||
Memory, "reset_all"
|
||||
) as reset_all, mock.patch.object(
|
||||
Memory, "reset"
|
||||
) as reset:
|
||||
result = runner.invoke(reset_memories, ["-m"])
|
||||
|
||||
reset_all.assert_called_once_with()
|
||||
reset.assert_not_called()
|
||||
assert "[Flow (TestFlow)] Memory has been reset." in result.output
|
||||
|
||||
|
||||
def test_reset_flow_all_memories(mock_get_flows, mock_flow, runner):
|
||||
result = runner.invoke(reset_memories, ["-a"])
|
||||
mock_flow.memory.reset.assert_called_once()
|
||||
@@ -230,83 +197,16 @@ def test_reset_no_crew_or_flow_found(runner):
|
||||
"crewai.utilities.reset_memories.get_crews", return_value=[]
|
||||
), mock.patch(
|
||||
"crewai.utilities.reset_memories.get_flows", return_value=[]
|
||||
), mock.patch(
|
||||
"crewai.utilities.reset_memories._get_json_crew", return_value=None
|
||||
):
|
||||
result = runner.invoke(reset_memories, ["-m"])
|
||||
assert "No crew or flow found." in result.output
|
||||
|
||||
|
||||
def test_reset_json_crew_memory(mock_crew, runner, monkeypatch, tmp_path):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
(tmp_path / "crew.jsonc").write_text("{}")
|
||||
|
||||
with mock.patch(
|
||||
"crewai.utilities.reset_memories.get_crews", return_value=[]
|
||||
), mock.patch(
|
||||
"crewai.utilities.reset_memories.get_flows", return_value=[]
|
||||
), mock.patch(
|
||||
"crewai.utilities.reset_memories.load_crew",
|
||||
return_value=(mock_crew, {}),
|
||||
) as mock_load_crew:
|
||||
result = runner.invoke(reset_memories, ["-m"])
|
||||
|
||||
mock_load_crew.assert_called_once_with(Path("crew.jsonc"))
|
||||
mock_crew.reset_memories.assert_called_once_with(command_type="memory")
|
||||
assert f"[Crew ({mock_crew.name})] Memory has been reset." in result.output
|
||||
|
||||
|
||||
def test_reset_invalid_json_crew_does_not_block_classic_crew(
|
||||
mock_crew, runner, monkeypatch, tmp_path
|
||||
):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
(tmp_path / "crew.jsonc").write_text("{invalid")
|
||||
|
||||
with mock.patch(
|
||||
"crewai.utilities.reset_memories.get_crews", return_value=[mock_crew]
|
||||
), mock.patch(
|
||||
"crewai.utilities.reset_memories.get_flows", return_value=[]
|
||||
), mock.patch(
|
||||
"crewai.utilities.reset_memories.load_crew",
|
||||
side_effect=ValueError("invalid JSON"),
|
||||
) as mock_load_crew:
|
||||
result = runner.invoke(reset_memories, ["-m"])
|
||||
|
||||
mock_load_crew.assert_called_once_with(Path("crew.jsonc"))
|
||||
mock_crew.reset_memories.assert_called_once_with(command_type="memory")
|
||||
assert "Skipping JSON crew at crew.jsonc: failed to load (invalid JSON)." in result.output
|
||||
assert f"[Crew ({mock_crew.name})] Memory has been reset." in result.output
|
||||
|
||||
|
||||
def test_reset_json_crew_skipped_for_declared_flow_project(
|
||||
mock_crew, runner, monkeypatch, tmp_path
|
||||
):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
(tmp_path / "crew.jsonc").write_text("{}")
|
||||
(tmp_path / "pyproject.toml").write_text('[tool.crewai]\ntype = "flow"\n')
|
||||
|
||||
with mock.patch(
|
||||
"crewai.utilities.reset_memories.get_crews", return_value=[]
|
||||
), mock.patch(
|
||||
"crewai.utilities.reset_memories.get_flows", return_value=[]
|
||||
), mock.patch(
|
||||
"crewai.utilities.reset_memories.load_crew",
|
||||
return_value=(mock_crew, {}),
|
||||
) as mock_load_crew:
|
||||
result = runner.invoke(reset_memories, ["-m"])
|
||||
|
||||
mock_load_crew.assert_not_called()
|
||||
mock_crew.reset_memories.assert_not_called()
|
||||
assert "No crew or flow found." in result.output
|
||||
|
||||
|
||||
def test_reset_crew_and_flow_memory(mock_crew, mock_flow, runner):
|
||||
with mock.patch(
|
||||
"crewai.utilities.reset_memories.get_crews", return_value=[mock_crew]
|
||||
), mock.patch(
|
||||
"crewai.utilities.reset_memories.get_flows", return_value=[mock_flow]
|
||||
), mock.patch(
|
||||
"crewai.utilities.reset_memories._get_json_crew", return_value=None
|
||||
):
|
||||
result = runner.invoke(reset_memories, ["-m"])
|
||||
mock_crew.reset_memories.assert_called_once_with(command_type="memory")
|
||||
@@ -323,8 +223,6 @@ def test_reset_flow_memory_none(runner):
|
||||
"crewai.utilities.reset_memories.get_crews", return_value=[]
|
||||
), mock.patch(
|
||||
"crewai.utilities.reset_memories.get_flows", return_value=[mock_flow]
|
||||
), mock.patch(
|
||||
"crewai.utilities.reset_memories._get_json_crew", return_value=None
|
||||
):
|
||||
result = runner.invoke(reset_memories, ["-m"])
|
||||
assert "[Flow (NoMemFlow)] Memory has been reset." in result.output
|
||||
|
||||
@@ -8,7 +8,6 @@ not silently zero-fill vectors or return empty search results.
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -98,33 +97,6 @@ def test_lancedb_reopened_store_detects_mismatch(lancedb_path: Path) -> None:
|
||||
reopened.search([0.1] * 8)
|
||||
|
||||
|
||||
def test_memory_reset_all_rebuilds_reopened_store_with_new_dimension(
|
||||
lancedb_path: Path,
|
||||
) -> None:
|
||||
from crewai.memory.storage.lancedb_storage import LanceDBStorage
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
old = LanceDBStorage(path=str(lancedb_path), vector_dim=4)
|
||||
old.save([_record(4)])
|
||||
|
||||
mem = Memory(
|
||||
storage=str(lancedb_path),
|
||||
llm=MagicMock(),
|
||||
embedder=lambda texts: [[0.1] * 8 for _ in texts],
|
||||
root_scope="/crew/test",
|
||||
)
|
||||
|
||||
mem.reset_all()
|
||||
mem.remember(
|
||||
"new embedder output",
|
||||
scope="/facts",
|
||||
categories=["test"],
|
||||
importance=0.5,
|
||||
)
|
||||
|
||||
assert mem.recall("new embedder output", scope="/facts", depth="shallow")
|
||||
|
||||
|
||||
def test_lancedb_matching_dim_still_works(lancedb_path: Path) -> None:
|
||||
from crewai.memory.storage.lancedb_storage import LanceDBStorage
|
||||
|
||||
|
||||
@@ -954,54 +954,6 @@ def test_remember_many_returns_immediately(tmp_path: Path) -> None:
|
||||
assert mem._storage.count() == 2
|
||||
|
||||
|
||||
def test_reset_all_blocks_new_save_submission_until_reset_completes(
|
||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""A save cannot be submitted between draining writes and resetting storage."""
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
mem = Memory(
|
||||
storage=str(tmp_path / "db"),
|
||||
llm=MagicMock(),
|
||||
embedder=lambda texts: [[0.1] * 4 for _ in texts],
|
||||
)
|
||||
reset_started = threading.Event()
|
||||
release_reset = threading.Event()
|
||||
submission_returned = threading.Event()
|
||||
order: list[str] = []
|
||||
original_reset = mem._storage.reset
|
||||
|
||||
def blocking_reset(scope_prefix: str | None = None) -> None:
|
||||
order.append("reset-start")
|
||||
reset_started.set()
|
||||
assert release_reset.wait(timeout=2)
|
||||
original_reset(scope_prefix=scope_prefix)
|
||||
order.append("reset-end")
|
||||
|
||||
def submit_save() -> None:
|
||||
mem._submit_save(lambda: order.append("save"))
|
||||
order.append("submit-returned")
|
||||
submission_returned.set()
|
||||
|
||||
monkeypatch.setattr(mem._storage, "reset", blocking_reset)
|
||||
|
||||
reset_thread = threading.Thread(target=mem.reset_all)
|
||||
reset_thread.start()
|
||||
assert reset_started.wait(timeout=2)
|
||||
|
||||
submit_thread = threading.Thread(target=submit_save)
|
||||
submit_thread.start()
|
||||
assert not submission_returned.wait(timeout=0.1)
|
||||
|
||||
release_reset.set()
|
||||
reset_thread.join(timeout=2)
|
||||
submit_thread.join(timeout=2)
|
||||
|
||||
assert not reset_thread.is_alive()
|
||||
assert not submit_thread.is_alive()
|
||||
assert order.index("reset-end") < order.index("submit-returned")
|
||||
|
||||
|
||||
def test_recall_drains_pending_writes(tmp_path: Path, mock_embedder: MagicMock) -> None:
|
||||
"""recall() should automatically wait for pending background saves."""
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
@@ -645,14 +645,11 @@ class TestLegacyMethodOutputsRestore:
|
||||
context = _expression_context(restored)
|
||||
assert context["outputs"] == {"": "legacy"}
|
||||
|
||||
def test_raw_legacy_outputs_remain_readable(self) -> None:
|
||||
from crewai.flow.runtime._expressions import _expression_context
|
||||
|
||||
def test_raw_legacy_outputs_property_remains_readable(self) -> None:
|
||||
flow = Flow()
|
||||
flow._method_outputs = ["legacy"]
|
||||
|
||||
assert flow.method_outputs == ["legacy"]
|
||||
assert _expression_context(flow)["outputs"] == {"": "legacy"}
|
||||
|
||||
|
||||
class TestAgentCheckpoint:
|
||||
|
||||
@@ -4584,26 +4584,6 @@ def test_reset_knowledge_with_no_crew_knowledge(researcher, writer):
|
||||
)
|
||||
|
||||
|
||||
def test_reset_memory_uses_full_unified_memory_reset(researcher):
|
||||
crew = Crew(
|
||||
agents=[researcher],
|
||||
process=Process.sequential,
|
||||
tasks=[
|
||||
Task(description="Task 1", expected_output="output", agent=researcher),
|
||||
],
|
||||
memory=True,
|
||||
)
|
||||
|
||||
assert isinstance(crew._memory, Memory)
|
||||
with patch.object(Memory, "reset_all") as reset_all, patch.object(
|
||||
Memory, "reset"
|
||||
) as reset:
|
||||
crew.reset_memories(command_type="memory")
|
||||
|
||||
reset_all.assert_called_once_with()
|
||||
reset.assert_not_called()
|
||||
|
||||
|
||||
def test_reset_knowledge_with_only_crew_knowledge(researcher, writer):
|
||||
mock_ks = MagicMock(spec=Knowledge)
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from pathlib import Path
|
||||
from typing import Annotated, Literal
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
import crewai.flow.dsl as flow_dsl
|
||||
import crewai.flow.flow_definition as flow_definition
|
||||
@@ -45,19 +45,52 @@ def test_flow_public_exports_are_explicit():
|
||||
"FlowDefinition",
|
||||
"FlowDefinitionCondition",
|
||||
"FlowDefinitionDiagnostic",
|
||||
"FlowDictStateDefinition",
|
||||
"FlowEachActionDefinition",
|
||||
"FlowEachInnerActionDefinition",
|
||||
"FlowExpressionActionDefinition",
|
||||
"FlowHumanFeedbackDefinition",
|
||||
"FlowJsonSchemaStateDefinition",
|
||||
"FlowMethodDefinition",
|
||||
"FlowPersistenceDefinition",
|
||||
"FlowPydanticStateDefinition",
|
||||
"FlowScriptActionDefinition",
|
||||
"FlowStateDefinition",
|
||||
"FlowToolActionDefinition",
|
||||
"FlowUnknownStateDefinition",
|
||||
}
|
||||
assert "build_flow_structure" in flow_visualization.__all__
|
||||
assert "calculate_node_levels" not in flow_visualization.__all__
|
||||
|
||||
|
||||
def test_flow_state_definition_uses_discriminated_branches():
|
||||
definition = flow_definition.FlowDefinition.model_validate(
|
||||
{
|
||||
"name": "TypedStateFlow",
|
||||
"state": {
|
||||
"type": "json_schema",
|
||||
"json_schema": {"type": "object"},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
assert isinstance(
|
||||
definition.state,
|
||||
flow_definition.FlowJsonSchemaStateDefinition,
|
||||
)
|
||||
|
||||
with pytest.raises(ValidationError, match="extra_forbidden"):
|
||||
flow_definition.FlowDefinition.model_validate(
|
||||
{
|
||||
"name": "InvalidStateFlow",
|
||||
"state": {
|
||||
"type": "dict",
|
||||
"ref": "my_project.flows:ResearchState",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_condition_combinators_return_nested_runtime_tree():
|
||||
condition = and_("event_a", "event_b", or_("event_c"))
|
||||
|
||||
|
||||
@@ -1145,6 +1145,116 @@ methods:
|
||||
assert flow.kickoff(inputs={"rows": ["a", "b"]}) == ["async:a", "async:b"]
|
||||
|
||||
|
||||
def test_script_action_requires_explicit_opt_in():
|
||||
yaml_str = """
|
||||
schema: crewai.flow/v1
|
||||
name: ScriptFlow
|
||||
methods:
|
||||
normalize:
|
||||
do:
|
||||
call: script
|
||||
code: |
|
||||
return "blocked"
|
||||
start: true
|
||||
"""
|
||||
|
||||
with pytest.raises(
|
||||
RuntimeError, match="CREWAI_ALLOW_FLOW_SCRIPT_EXECUTION=1"
|
||||
) as exc_info:
|
||||
Flow.from_definition(FlowDefinition.from_yaml(yaml_str))
|
||||
assert "methods with unresolvable actions" not in str(exc_info.value)
|
||||
|
||||
|
||||
def test_script_action_runs_python_imports_mutates_state_and_returns_value(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
monkeypatch.setenv("CREWAI_ALLOW_FLOW_SCRIPT_EXECUTION", "1")
|
||||
|
||||
yaml_str = """
|
||||
schema: crewai.flow/v1
|
||||
name: ScriptFlow
|
||||
methods:
|
||||
normalize:
|
||||
do:
|
||||
call: script
|
||||
code: |
|
||||
import math
|
||||
|
||||
state["rounded"] = math.ceil(state["raw_score"])
|
||||
return f"rounded:{state['rounded']}"
|
||||
start: true
|
||||
"""
|
||||
|
||||
flow = Flow.from_definition(FlowDefinition.from_yaml(yaml_str))
|
||||
|
||||
assert flow.kickoff(inputs={"raw_score": 3.2}) == "rounded:4"
|
||||
assert flow.state["rounded"] == 4
|
||||
|
||||
|
||||
def test_script_listener_reads_trigger_input_and_outputs(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
monkeypatch.setenv("CREWAI_ALLOW_FLOW_SCRIPT_EXECUTION", "1")
|
||||
|
||||
yaml_str = """
|
||||
schema: crewai.flow/v1
|
||||
name: ScriptFlow
|
||||
methods:
|
||||
seed:
|
||||
do:
|
||||
call: expression
|
||||
expr: "'alpha'"
|
||||
start: true
|
||||
combine:
|
||||
do:
|
||||
call: script
|
||||
code: |
|
||||
state["input_matches_output"] = input == outputs["seed"]
|
||||
return f"{outputs['seed']}:{input}"
|
||||
listen: seed
|
||||
"""
|
||||
|
||||
flow = Flow.from_definition(FlowDefinition.from_yaml(yaml_str))
|
||||
|
||||
assert flow.kickoff() == "alpha:alpha"
|
||||
assert flow.state["input_matches_output"] is True
|
||||
|
||||
|
||||
def test_script_each_action_reads_item_and_inner_outputs(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
monkeypatch.setenv("CREWAI_ALLOW_FLOW_SCRIPT_EXECUTION", "1")
|
||||
|
||||
yaml_str = """
|
||||
schema: crewai.flow/v1
|
||||
name: ScriptEachFlow
|
||||
methods:
|
||||
seed:
|
||||
do:
|
||||
call: expression
|
||||
expr: "'global'"
|
||||
start: true
|
||||
process_rows:
|
||||
do:
|
||||
call: each
|
||||
in: state.rows
|
||||
do:
|
||||
- clean:
|
||||
call: script
|
||||
code: |
|
||||
return item.strip()
|
||||
- tag:
|
||||
call: script
|
||||
code: |
|
||||
return f"{outputs['seed']}:{outputs['clean']}"
|
||||
listen: seed
|
||||
"""
|
||||
|
||||
flow = Flow.from_definition(FlowDefinition.from_yaml(yaml_str))
|
||||
|
||||
assert flow.kickoff(inputs={"rows": [" a ", " b "]}) == ["global:a", "global:b"]
|
||||
|
||||
|
||||
def test_each_action_uses_iteration_outputs_between_nested_actions():
|
||||
yaml_str = f"""
|
||||
schema: crewai.flow/v1
|
||||
|
||||
Reference in New Issue
Block a user