mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-06-15 13:18:09 +00:00
Compare commits
3 Commits
flow-itera
...
devin/1781
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a8bf69e05b | ||
|
|
29a39cfeef | ||
|
|
7575d9b64a |
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import hashlib
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
import uuid
|
||||
@@ -54,6 +55,7 @@ class CrewAIRagAdapter(Adapter):
|
||||
similarity_threshold: float = 0.6
|
||||
limit: int = 5
|
||||
config: RagConfigType | None = None
|
||||
content_filter: Callable[[list[str]], list[str]] | None = None
|
||||
_client: BaseClient | None = PrivateAttr(default=None)
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
@@ -348,6 +350,15 @@ class CrewAIRagAdapter(Adapter):
|
||||
)
|
||||
|
||||
if documents:
|
||||
if self.content_filter is not None:
|
||||
filtered_contents = set(
|
||||
self.content_filter([doc["content"] for doc in documents])
|
||||
)
|
||||
documents = [
|
||||
doc for doc in documents if doc["content"] in filtered_contents
|
||||
]
|
||||
if not documents:
|
||||
return
|
||||
if self._client is None:
|
||||
raise ValueError("Client is not initialized")
|
||||
self._client.add_documents(
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Callable, Iterator
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
@@ -246,6 +246,26 @@ class NL2SQLTool(BaseTool):
|
||||
"write operations."
|
||||
),
|
||||
)
|
||||
require_approval: bool = Field(
|
||||
default=False,
|
||||
title="Require Approval",
|
||||
description=(
|
||||
"When True, every query is shown to a human for approval "
|
||||
"before execution. The approval_handler callable is invoked "
|
||||
"with the SQL string and must return True to proceed. "
|
||||
"Defaults to an interactive terminal prompt."
|
||||
),
|
||||
)
|
||||
approval_handler: Callable[[str], bool] | None = Field(
|
||||
default=None,
|
||||
exclude=True,
|
||||
description=(
|
||||
"Custom callable invoked when require_approval is True. "
|
||||
"Receives the SQL query string and must return True to "
|
||||
"allow execution or False to reject it. When None, a "
|
||||
"built-in interactive terminal prompt is used."
|
||||
),
|
||||
)
|
||||
tables: list[dict[str, Any]] = Field(default_factory=list)
|
||||
columns: dict[str, list[dict[str, Any]] | str] = Field(default_factory=dict)
|
||||
args_schema: type[BaseModel] = NL2SQLToolInput
|
||||
@@ -420,9 +440,31 @@ class NL2SQLTool(BaseTool):
|
||||
|
||||
# Core execution
|
||||
|
||||
def _request_approval(self, sql_query: str) -> bool:
|
||||
"""Ask for human approval before executing the query.
|
||||
|
||||
Uses ``approval_handler`` if provided, otherwise falls back to an
|
||||
interactive terminal prompt via ``input()``.
|
||||
"""
|
||||
if self.approval_handler is not None:
|
||||
return self.approval_handler(sql_query)
|
||||
try:
|
||||
answer = input(
|
||||
f"\n[NL2SQLTool] The following query requires approval "
|
||||
f"before execution:\n\n {sql_query}\n\n"
|
||||
f"Execute this query? (y/n): "
|
||||
)
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
return False
|
||||
return answer.strip().lower() in ("y", "yes")
|
||||
|
||||
def _run(self, sql_query: str) -> list[dict[str, Any]] | str:
|
||||
try:
|
||||
self._validate_query(sql_query)
|
||||
if self.require_approval and not self._request_approval(sql_query):
|
||||
return (
|
||||
f"Query execution was rejected by the human reviewer: {sql_query}"
|
||||
)
|
||||
data = self.execute_sql(sql_query)
|
||||
except ValueError:
|
||||
raise
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
"""Tests for CrewAIRagAdapter.content_filter."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
|
||||
|
||||
|
||||
def _make_adapter(
    content_filter=None,
    collection_name: str = "test_collection",
) -> CrewAIRagAdapter:
    """Create a CrewAIRagAdapter while ``get_rag_client`` is patched to return a MagicMock."""
    client_patch = patch(
        "crewai_tools.adapters.crewai_rag_adapter.get_rag_client",
        return_value=MagicMock(),
    )
    # The patch only needs to be active while the adapter is constructed.
    with client_patch:
        return CrewAIRagAdapter(
            collection_name=collection_name,
            content_filter=content_filter,
        )
|
||||
|
||||
|
||||
class TestContentFilterOnAdd:
    """Behaviour of ``CrewAIRagAdapter.add`` when ``content_filter`` is configured."""

    def test_filter_removes_documents(self) -> None:
        """Documents whose content is rejected by the filter are not indexed."""

        def drop_secrets(contents: list[str]) -> list[str]:
            return [c for c in contents if "SECRET" not in c]

        adapter = _make_adapter(content_filter=drop_secrets)
        mock_client = adapter._client
        assert mock_client is not None

        # Accepted content must actually reach the client; asserting the call
        # unconditionally keeps the test from passing vacuously.
        adapter.add("safe text", data_type="text")
        mock_client.add_documents.assert_called_once()
        docs = mock_client.add_documents.call_args.kwargs["documents"]
        assert docs
        assert all("SECRET" not in doc["content"] for doc in docs)

        # Rejected content must not trigger a second write.
        # NOTE(review): assumes a short text yields chunks that all contain
        # the marker -- confirm against the adapter's chunking.
        adapter.add("SECRET token", data_type="text")
        mock_client.add_documents.assert_called_once()

    def test_filter_drops_all_skips_add(self) -> None:
        """When the filter removes every document, add_documents is not called."""
        adapter = _make_adapter(content_filter=lambda contents: [])
        mock_client = adapter._client
        assert mock_client is not None

        adapter.add("anything", data_type="text")

        mock_client.add_documents.assert_not_called()

    def test_filter_exception_propagates(self) -> None:
        """An exception from content_filter aborts the add before anything is written."""

        def exploding_filter(contents: list[str]) -> list[str]:
            raise ValueError("Policy violation")

        adapter = _make_adapter(content_filter=exploding_filter)
        mock_client = adapter._client
        assert mock_client is not None

        with pytest.raises(ValueError, match="Policy violation"):
            adapter.add("content", data_type="text")

        mock_client.add_documents.assert_not_called()

    def test_no_filter_is_noop(self) -> None:
        """When content_filter is None, documents are persisted normally."""
        adapter = _make_adapter(content_filter=None)
        assert adapter.content_filter is None
        mock_client = adapter._client
        assert mock_client is not None

        adapter.add("hello world", data_type="text")

        mock_client.add_documents.assert_called_once()
        docs = mock_client.add_documents.call_args.kwargs["documents"]
        assert len(docs) >= 1

    def test_filter_receives_all_content_strings(self) -> None:
        """The filter callable receives the full list of content strings."""
        received: list[list[str]] = []

        def capturing_filter(contents: list[str]) -> list[str]:
            received.append(contents)
            return contents

        adapter = _make_adapter(content_filter=capturing_filter)

        adapter.add("some text content", data_type="text")

        assert len(received) == 1
        # ``all`` is True for an empty list, so require content first.
        assert received[0]
        assert all(isinstance(c, str) for c in received[0])
|
||||
@@ -598,3 +598,85 @@ class TestCTEUnknownCommand:
|
||||
tool = _make_tool(allow_dml=False)
|
||||
with pytest.raises(ValueError, match="unrecognised"):
|
||||
tool._validate_query("WITH cte AS (SELECT 1) FOOBAR")
|
||||
|
||||
|
||||
# --- require_approval tests ---
|
||||
|
||||
|
||||
class TestRequireApproval:
    """Exercises the human-approval gate in ``NL2SQLTool._run``."""

    def test_approval_granted_executes_query(self):
        """An approving handler lets the query run and return its rows."""

        def approve(sql):
            return True

        tool = _make_tool(require_approval=True, approval_handler=approve)
        assert tool._run("SELECT 1 AS val") == [{"val": 1}]

    def test_approval_rejected_blocks_query(self):
        """A rejecting handler yields the rejection message instead of rows."""

        def reject(sql):
            return False

        tool = _make_tool(require_approval=True, approval_handler=reject)
        outcome = tool._run("SELECT 1 AS val")
        assert "rejected" in outcome.lower()

    def test_approval_handler_receives_sql_string(self):
        """The handler is called with the exact SQL text passed to ``_run``."""
        calls: list[str] = []

        def spy(sql: str) -> bool:
            calls.append(sql)
            return True

        query = "SELECT 42 AS answer"
        _make_tool(require_approval=True, approval_handler=spy)._run(query)
        assert calls == ["SELECT 42 AS answer"]

    def test_no_approval_when_flag_is_false(self):
        """With require_approval=False the handler is never consulted."""
        handler = MagicMock(return_value=True)
        _make_tool(require_approval=False, approval_handler=handler)._run("SELECT 1")
        handler.assert_not_called()

    def test_default_prompt_on_eof(self):
        """EOFError from the built-in prompt counts as a rejection."""
        tool = _make_tool(require_approval=True)
        with patch("builtins.input", side_effect=EOFError):
            outcome = tool._run("SELECT 1")
        assert "rejected" in outcome.lower()

    def test_default_prompt_yes(self):
        """Typing 'y' at the built-in prompt allows execution."""
        tool = _make_tool(require_approval=True)
        with patch("builtins.input", return_value="y"):
            rows = tool._run("SELECT 1 AS val")
        assert rows == [{"val": 1}]

    def test_default_prompt_no(self):
        """Typing 'n' at the built-in prompt blocks execution."""
        tool = _make_tool(require_approval=True)
        with patch("builtins.input", return_value="n"):
            outcome = tool._run("SELECT 1")
        assert "rejected" in outcome.lower()

    def test_approval_checked_after_validation(self):
        """Queries rejected by validation never reach the approval handler."""
        handler = MagicMock(return_value=True)
        tool = _make_tool(
            allow_dml=False,
            require_approval=True,
            approval_handler=handler,
        )
        with pytest.raises(ValueError, match="read-only mode"):
            tool._run("DROP TABLE users")
        handler.assert_not_called()

    def test_approval_with_keyboard_interrupt(self):
        """KeyboardInterrupt at the built-in prompt counts as a rejection."""
        tool = _make_tool(require_approval=True)
        with patch("builtins.input", side_effect=KeyboardInterrupt):
            outcome = tool._run("SELECT 1")
        assert "rejected" in outcome.lower()
|
||||
|
||||
@@ -15870,6 +15870,12 @@
|
||||
"title": "Database URI",
|
||||
"type": "string"
|
||||
},
|
||||
"require_approval": {
|
||||
"default": false,
|
||||
"description": "When True, every query is shown to a human for approval before execution. The approval_handler callable is invoked with the SQL string and must return True to proceed. Defaults to an interactive terminal prompt.",
|
||||
"title": "Require Approval",
|
||||
"type": "boolean"
|
||||
},
|
||||
"tables": {
|
||||
"items": {
|
||||
"additionalProperties": true,
|
||||
|
||||
@@ -11,17 +11,9 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Literal as TypingLiteral
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
RootModel,
|
||||
field_serializer,
|
||||
model_validator,
|
||||
)
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_serializer, model_validator
|
||||
import yaml
|
||||
|
||||
from crewai.flow.conversational_definition import (
|
||||
@@ -33,7 +25,6 @@ from crewai.flow.conversational_definition import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
FlowDefinitionCondition = str | dict[str, Any]
|
||||
_STEP_NAME_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
|
||||
|
||||
__all__ = [
|
||||
"FlowActionDefinition",
|
||||
@@ -44,8 +35,6 @@ __all__ = [
|
||||
"FlowDefinition",
|
||||
"FlowDefinitionCondition",
|
||||
"FlowDefinitionDiagnostic",
|
||||
"FlowEachActionDefinition",
|
||||
"FlowEachInnerActionDefinition",
|
||||
"FlowExpressionActionDefinition",
|
||||
"FlowHumanFeedbackDefinition",
|
||||
"FlowMethodDefinition",
|
||||
@@ -159,11 +148,10 @@ class FlowHumanFeedbackDefinition(BaseModel):
|
||||
class FlowCodeActionDefinition(BaseModel):
|
||||
"""A Flow method action that executes importable Python code."""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
call: TypingLiteral["code"] = "code"
|
||||
ref: str
|
||||
with_: dict[str, Any] | None = Field(default=None, alias="with")
|
||||
|
||||
|
||||
class FlowToolActionDefinition(BaseModel):
|
||||
@@ -185,75 +173,14 @@ class FlowExpressionActionDefinition(BaseModel):
|
||||
expr: str
|
||||
|
||||
|
||||
FlowInnerActionDefinition = (
|
||||
FlowCodeActionDefinition | FlowToolActionDefinition | FlowExpressionActionDefinition
|
||||
)
|
||||
|
||||
|
||||
class FlowEachInnerActionDefinition(RootModel[dict[str, FlowInnerActionDefinition]]):
    """A single ``{name: action}`` entry of an ``each`` action's ``do`` list."""

    @property
    def name(self) -> str:
        """First key of the root mapping."""
        keys = iter(self.root)
        return next(keys)

    @property
    def action(self) -> FlowInnerActionDefinition:
        """First value of the root mapping."""
        values = iter(self.root.values())
        return next(values)
|
||||
|
||||
|
||||
class FlowEachActionDefinition(BaseModel):
    """Composite ``call: each`` action: a list of named inner actions run per item."""

    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    call: TypingLiteral["each"]
    in_: str = Field(alias="in")
    do: list[FlowEachInnerActionDefinition]

    @model_validator(mode="before")
    @classmethod
    def _validate_inner_action_list(cls, data: Any) -> Any:
        """Check the raw ``do`` list before field validation.

        Requires a non-empty list of one-key mappings whose keys are valid,
        unique step names. Input that is not a dict, or has no ``do`` key,
        is returned untouched.
        """
        if not (isinstance(data, dict) and "do" in data):
            return data

        entries = data["do"]
        if not (isinstance(entries, list) and entries):
            raise ValueError("each.do must contain at least one action")

        names_so_far: set[str] = set()
        for entry in entries:
            if isinstance(entry, FlowEachInnerActionDefinition):
                mapping = entry.root
            elif isinstance(entry, dict):
                mapping = entry
            else:
                raise ValueError("each.do entries must be one-key mappings")

            if len(mapping) != 1:
                raise ValueError("each.do entries must be one-key mappings")

            # Exactly one key is guaranteed here, so unpacking is safe.
            (step_name,) = mapping
            _validate_step_name(step_name, field="each.do action names")
            if step_name in names_so_far:
                raise ValueError(f"each.do action names must be unique: {step_name!r}")
            names_so_far.add(step_name)

        return data
|
||||
|
||||
|
||||
FlowActionDefinition = (
|
||||
FlowCodeActionDefinition
|
||||
| FlowToolActionDefinition
|
||||
| FlowExpressionActionDefinition
|
||||
| FlowEachActionDefinition
|
||||
FlowCodeActionDefinition | FlowToolActionDefinition | FlowExpressionActionDefinition
|
||||
)
|
||||
|
||||
|
||||
class FlowMethodDefinition(BaseModel):
|
||||
"""Static definition of one Flow method and its execution roles."""
|
||||
|
||||
description: str | None = None
|
||||
do: FlowActionDefinition
|
||||
start: bool | FlowDefinitionCondition | None = None
|
||||
listen: FlowDefinitionCondition | None = None
|
||||
@@ -300,12 +227,6 @@ class FlowDefinition(BaseModel):
|
||||
methods: dict[str, FlowMethodDefinition] = Field(default_factory=dict)
|
||||
diagnostics: list[FlowDefinitionDiagnostic] = Field(default_factory=list)
|
||||
|
||||
@model_validator(mode="after")
def _validate_method_names(self) -> FlowDefinition:
    """Run ``_validate_step_name`` on every key of ``self.methods``."""
    for candidate in self.methods.keys():
        _validate_step_name(candidate, field="Flow method names")
    return self
|
||||
|
||||
def to_dict(self, *, exclude_none: bool = True) -> dict[str, Any]:
    """Dump the model by alias in JSON mode; ``exclude_none`` is passed through."""
    dump_options: dict[str, Any] = {
        "by_alias": True,
        "exclude_none": exclude_none,
        "mode": "json",
    }
    return self.model_dump(**dump_options)
|
||||
@@ -448,11 +369,6 @@ def _deserialize_diagnostics(value: Any) -> list[FlowDefinitionDiagnostic]:
|
||||
return [FlowDefinitionDiagnostic.model_validate(item) for item in value or []]
|
||||
|
||||
|
||||
def _validate_step_name(name: str, *, field: str) -> None:
    """Raise ValueError unless ``name`` is a str fully matching ``_STEP_NAME_PATTERN``."""
    if isinstance(name, str) and _STEP_NAME_PATTERN.fullmatch(name):
        return
    raise ValueError(f"{field} must match {_STEP_NAME_PATTERN.pattern}")
|
||||
|
||||
|
||||
def _merge_diagnostics(
|
||||
*diagnostic_groups: list[FlowDefinitionDiagnostic],
|
||||
) -> list[FlowDefinitionDiagnostic]:
|
||||
|
||||
@@ -121,8 +121,11 @@ from crewai.flow.human_feedback import (
|
||||
)
|
||||
from crewai.flow.input_provider import InputProvider
|
||||
from crewai.flow.persistence.base import FlowPersistence
|
||||
from crewai.flow.runtime._actions import build_action
|
||||
from crewai.flow.runtime._refs import resolve_instance_ref, resolve_ref
|
||||
from crewai.flow.runtime._resolvers import (
|
||||
resolve_action,
|
||||
resolve_instance_ref,
|
||||
resolve_ref,
|
||||
)
|
||||
from crewai.flow.types import (
|
||||
FlowExecutionData,
|
||||
FlowMethodName,
|
||||
@@ -1089,9 +1092,9 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
self._methods.update(methods)
|
||||
|
||||
def _action_bound_methods(self) -> dict[FlowMethodName, Callable[..., Any]]:
|
||||
def build(name: str, definition: FlowMethodDefinition) -> Callable[..., Any]:
|
||||
def resolve(name: str, definition: FlowMethodDefinition) -> Callable[..., Any]:
|
||||
try:
|
||||
return build_action(self, definition.do)
|
||||
return resolve_action(self, definition.do)
|
||||
except Exception as e:
|
||||
unresolved.append(f"{name}: {e}")
|
||||
return lambda *args, **kwargs: None
|
||||
@@ -1099,7 +1102,9 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
methods: dict[FlowMethodName, Callable[..., Any]] = {}
|
||||
unresolved: list[str] = []
|
||||
for method_name, method_definition in self._definition.methods.items():
|
||||
methods[FlowMethodName(method_name)] = build(method_name, method_definition)
|
||||
methods[FlowMethodName(method_name)] = resolve(
|
||||
method_name, method_definition
|
||||
)
|
||||
if unresolved:
|
||||
raise ValueError(
|
||||
f"Cannot build flow {self._definition.name!r} from its definition; "
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
"""Build FlowDefinition actions into live runtime callables."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.flow.flow_definition import (
|
||||
FlowActionDefinition,
|
||||
FlowInnerActionDefinition,
|
||||
)
|
||||
from crewai.flow.runtime._actions._base import ActionHandlerRegistry
|
||||
from crewai.flow.runtime._actions._code import CodeActionHandler
|
||||
from crewai.flow.runtime._actions._each import EachActionHandler
|
||||
from crewai.flow.runtime._actions._expression import ExpressionActionHandler
|
||||
from crewai.flow.runtime._actions._tool import ToolActionHandler
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.runtime import Flow
|
||||
|
||||
|
||||
__all__ = [
|
||||
"build_action",
|
||||
]
|
||||
|
||||
|
||||
# Handlers for the non-composite action kinds.
_SIMPLE_ACTION_HANDLERS = (
    CodeActionHandler(),
    ToolActionHandler(),
    ExpressionActionHandler(),
)

# Registry handed to the ``each`` handler for its inner actions.
_SIMPLE_ACTION_REGISTRY = ActionHandlerRegistry[FlowInnerActionDefinition](
    _SIMPLE_ACTION_HANDLERS
)

# Top-level registry: the simple handlers followed by the ``each`` handler.
_ACTION_REGISTRY = ActionHandlerRegistry[FlowActionDefinition](
    _SIMPLE_ACTION_HANDLERS + (EachActionHandler(_SIMPLE_ACTION_REGISTRY),)
)
|
||||
|
||||
|
||||
def build_action(flow: Flow[Any], action: FlowActionDefinition) -> Callable[..., Any]:
    """Build the runtime callable for one ``do:`` action via ``_ACTION_REGISTRY``."""
    runner = _ACTION_REGISTRY.build(flow, action)
    return runner
|
||||
@@ -1,39 +0,0 @@
|
||||
"""Shared action handler contracts."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.runtime import Flow
|
||||
|
||||
|
||||
ActionT = TypeVar("ActionT", bound=BaseModel)
|
||||
ResolvedAction = Callable[..., Any]
|
||||
|
||||
|
||||
class ActionHandler(Protocol[ActionT]):
    """Structural interface of a handler bound to one action definition type."""

    # Definition class this handler accepts; the registry matches on it with isinstance.
    action_type: type[ActionT]

    def build(self, flow: Flow[Any], action: ActionT) -> ResolvedAction:
        """Return the callable the flow executes for ``action``."""
        ...
|
||||
|
||||
|
||||
class ActionHandlerRegistry(Generic[ActionT]):
    """Dispatch an action to the first registered handler whose type matches."""

    def __init__(self, handlers: Iterable[ActionHandler[Any]]) -> None:
        # Frozen into a tuple so the lookup order stays fixed.
        self._handlers = tuple(handlers)

    def build(self, flow: Flow[Any], action: ActionT) -> ResolvedAction:
        """Delegate to the first handler accepting ``action``.

        Raises:
            ValueError: No handler's ``action_type`` matches; the message
                carries the action's ``call`` attribute (None if absent).
        """
        chosen = next(
            (
                candidate
                for candidate in self._handlers
                if isinstance(action, candidate.action_type)
            ),
            None,
        )
        if chosen is None:
            call = getattr(action, "call", None)
            raise ValueError(f"unknown call type {call!r}")
        return chosen.build(flow, action)
|
||||
@@ -1,51 +0,0 @@
|
||||
"""Handler for ``call: code`` FlowDefinition actions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import functools
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from crewai.flow.flow_definition import FlowCodeActionDefinition
|
||||
from crewai.flow.runtime._actions._base import ResolvedAction
|
||||
from crewai.flow.runtime._actions._runtime import LOCAL_CONTEXT_KWARG
|
||||
from crewai.flow.runtime._expressions import render_with_block
|
||||
from crewai.flow.runtime._refs import InvalidRefError, resolve_ref
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.runtime import Flow
|
||||
|
||||
|
||||
class CodeActionHandler:
    """Handler for code actions: resolves the referenced callable and wraps it."""

    action_type = FlowCodeActionDefinition

    def build(
        self, flow: Flow[Any], action: FlowCodeActionDefinition
    ) -> ResolvedAction:
        """Return a wrapper that calls the resolved callable.

        With no ``with`` block the caller's arguments are forwarded. With one,
        the callable receives only the rendered block as keyword arguments.
        """
        target = _resolve_code_handler(flow, action)

        def run_code(*args: Any, **kwargs: Any) -> Any:
            # The local-context kwarg is internal; strip it before calling the target.
            scoped_context = kwargs.pop(LOCAL_CONTEXT_KWARG, None)
            with_block = action.with_
            if with_block is not None:
                rendered = render_with_block(
                    flow, with_block, local_context=scoped_context
                )
                return target(**rendered)
            return target(*args, **kwargs)

        return functools.update_wrapper(run_code, target)
|
||||
|
||||
|
||||
def _resolve_code_handler(
    flow: Flow[Any], action: FlowCodeActionDefinition
) -> Callable[..., Any]:
    """Resolve ``action.ref`` to a callable, binding it to ``flow`` if it is unbound.

    Raises:
        InvalidRefError: The resolved object is not callable.
    """
    ref = action.ref
    resolved = resolve_ref(ref, field="do")
    if not callable(resolved):
        raise InvalidRefError(f"invalid do ref {ref!r}; object is not callable")

    handler = cast(Callable[..., Any], resolved)
    already_bound = getattr(handler, "__self__", None) is not None
    if already_bound:
        return handler
    # Bind through the descriptor protocol so ``flow`` becomes the first argument.
    return handler.__get__(flow, type(flow))
|
||||
@@ -1,73 +0,0 @@
|
||||
"""Handler for ``call: each`` FlowDefinition actions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.flow.flow_definition import (
|
||||
FlowEachActionDefinition,
|
||||
FlowEachInnerActionDefinition,
|
||||
FlowInnerActionDefinition,
|
||||
)
|
||||
from crewai.flow.runtime._actions._base import (
|
||||
ActionHandlerRegistry,
|
||||
ResolvedAction,
|
||||
)
|
||||
from crewai.flow.runtime._actions._runtime import (
|
||||
LOCAL_CONTEXT_KWARG,
|
||||
ensure_array,
|
||||
invoke_callable,
|
||||
)
|
||||
from crewai.flow.runtime._expressions import evaluate_expression
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.runtime import Flow
|
||||
|
||||
|
||||
class EachActionHandler:
    """Handler for ``each`` actions: runs the inner actions in order for every item."""

    action_type = FlowEachActionDefinition

    def __init__(
        self, inner_registry: ActionHandlerRegistry[FlowInnerActionDefinition]
    ) -> None:
        self._inner_registry = inner_registry

    def build(
        self, flow: Flow[Any], action: FlowEachActionDefinition
    ) -> ResolvedAction:
        """Return an async callable yielding the last inner output of each item."""
        steps: list[tuple[str, Callable[[dict[str, Any]], Any]]] = []
        for step in action.do:
            steps.append((step.name, self._resolve_inner_action(flow, step)))

        async def run_each(*_args: Any, **_kwargs: Any) -> list[Any]:
            collected: list[Any] = []
            for element in ensure_array(evaluate_expression(flow, action.in_)):
                # Outputs are per item; later steps see earlier results of the same item.
                step_outputs: dict[str, Any] = {}
                latest: Any = None
                for step_name, run_step in steps:
                    latest = await run_step(
                        {"item": element, "outputs": step_outputs}
                    )
                    step_outputs[step_name] = latest
                collected.append(latest)
            return collected

        return run_each

    def _resolve_inner_action(
        self, flow: Flow[Any], inner_action: FlowEachInnerActionDefinition
    ) -> Callable[[dict[str, Any]], Any]:
        """Build one inner action and wrap it to take the local context."""
        built = self._inner_registry.build(flow, inner_action.action)

        async def run_inner_action(local_context: dict[str, Any]) -> Any:
            call_kwargs = {LOCAL_CONTEXT_KWARG: local_context}
            return await invoke_callable(built, **call_kwargs)

        return run_inner_action
|
||||
@@ -1,29 +0,0 @@
|
||||
"""Handler for ``call: expression`` FlowDefinition actions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.flow.flow_definition import FlowExpressionActionDefinition
|
||||
from crewai.flow.runtime._actions._base import ResolvedAction
|
||||
from crewai.flow.runtime._actions._runtime import LOCAL_CONTEXT_KWARG
|
||||
from crewai.flow.runtime._expressions import evaluate_expression
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.runtime import Flow
|
||||
|
||||
|
||||
class ExpressionActionHandler:
    """Handler for expression actions: evaluates ``action.expr`` when called."""

    action_type = FlowExpressionActionDefinition

    def build(
        self, flow: Flow[Any], action: FlowExpressionActionDefinition
    ) -> ResolvedAction:
        """Return a callable that evaluates the expression against ``flow``."""

        def run_expression(*_args: Any, **kwargs: Any) -> Any:
            # Positional arguments are ignored; only the local context is used.
            scoped_context = kwargs.pop(LOCAL_CONTEXT_KWARG, None)
            result = evaluate_expression(
                flow, action.expr, local_context=scoped_context
            )
            return result

        return run_expression
|
||||
@@ -1,28 +0,0 @@
|
||||
"""Runtime helpers shared by action resolvers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import inspect
|
||||
from typing import Any
|
||||
|
||||
|
||||
LOCAL_CONTEXT_KWARG = "__flow_definition_local_context"
|
||||
|
||||
|
||||
async def invoke_callable(
    handler: Callable[..., Any], *args: Any, **kwargs: Any
) -> Any:
    """Call ``handler`` and await the result if it is awaitable."""
    if inspect.iscoroutinefunction(handler):
        return await handler(*args, **kwargs)

    outcome = handler(*args, **kwargs)
    # A plain callable may still hand back an awaitable.
    if not inspect.isawaitable(outcome):
        return outcome
    return await outcome
|
||||
|
||||
|
||||
def ensure_array(value: Any) -> list[Any]:
    """Return ``value`` unchanged when it is a list; otherwise raise ValueError."""
    if not isinstance(value, list):
        raise ValueError("each.in must evaluate to an array")
    return value
|
||||
@@ -1,52 +0,0 @@
|
||||
"""Handler for ``call: tool`` FlowDefinition actions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from crewai.flow.flow_definition import FlowToolActionDefinition
|
||||
from crewai.flow.runtime._actions._base import ResolvedAction
|
||||
from crewai.flow.runtime._actions._runtime import LOCAL_CONTEXT_KWARG
|
||||
from crewai.flow.runtime._expressions import render_with_block
|
||||
from crewai.flow.runtime._refs import InvalidRefError, resolve_ref
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.runtime import Flow
|
||||
|
||||
|
||||
class ToolActionHandler:
    """Handler for tool actions: instantiates the referenced tool class and runs it."""

    action_type = FlowToolActionDefinition

    def build(
        self, flow: Flow[Any], action: FlowToolActionDefinition
    ) -> ResolvedAction:
        """Instantiate the tool once and return a callable that runs it.

        Raises:
            InvalidRefError: The ref is not a ``BaseTool`` subclass, or the
                class cannot be instantiated without arguments.
        """
        ref = action.ref
        target = resolve_ref(ref, field="do")
        from crewai.tools import BaseTool

        is_tool_class = inspect.isclass(target) and issubclass(target, BaseTool)
        if not is_tool_class:
            raise InvalidRefError(
                f"invalid tool ref {ref!r}; expected a BaseTool class"
            )

        factory = cast(Callable[[], BaseTool], target)
        try:
            tool = factory()
        except Exception as e:
            raise InvalidRefError(
                f"cannot instantiate tool ref {ref!r} without arguments: {e}"
            ) from e

        tool_kwargs = action.with_ or {}

        def run_tool(*_args: Any, **kwargs: Any) -> Any:
            # Positional arguments are ignored; inputs come from the rendered block.
            scoped_context = kwargs.pop(LOCAL_CONTEXT_KWARG, None)
            rendered = render_with_block(
                flow, tool_kwargs, local_context=scoped_context
            )
            return tool.run(**rendered)

        return run_tool
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import dataclasses
|
||||
from itertools import pairwise
|
||||
import json
|
||||
@@ -24,36 +25,25 @@ class FlowExpressionError(ValueError):
|
||||
"""A FlowDefinition expression failed to parse or evaluate."""
|
||||
|
||||
|
||||
def render_with_block(
|
||||
flow: Flow[Any], value: Any, local_context: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
def render_with_block(flow: Flow[Any], value: Any) -> Any:
|
||||
"""Render CEL expressions inside a FlowDefinition ``with:`` payload."""
|
||||
context = _expression_context(flow, local_context=local_context)
|
||||
context = _expression_context(flow)
|
||||
return _render_value(value, context)
|
||||
|
||||
|
||||
def evaluate_expression(
|
||||
flow: Flow[Any], expression: str, local_context: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
def evaluate_expression(flow: Flow[Any], expression: str) -> Any:
|
||||
"""Evaluate a FlowDefinition CEL expression against runtime context."""
|
||||
expression = expression.strip()
|
||||
if not expression:
|
||||
raise FlowExpressionError("empty CEL expression")
|
||||
return _eval_cel(expression, _expression_context(flow, local_context=local_context))
|
||||
return _eval_cel(expression, _expression_context(flow))
|
||||
|
||||
|
||||
def _expression_context(
|
||||
flow: Flow[Any], local_context: dict[str, Any] | None = None
|
||||
) -> dict[str, Any]:
|
||||
context = {
|
||||
def _expression_context(flow: Flow[Any]) -> dict[str, Any]:
|
||||
return {
|
||||
"state": flow._copy_and_serialize_state(),
|
||||
"outputs": _outputs_by_name(flow._method_outputs),
|
||||
}
|
||||
if local_context:
|
||||
context.update(
|
||||
{key: _to_json_safe(value) for key, value in local_context.items()}
|
||||
)
|
||||
return context
|
||||
|
||||
|
||||
def _outputs_by_name(method_outputs: list[Any]) -> dict[str, Any]:
|
||||
@@ -64,24 +54,15 @@ def _outputs_by_name(method_outputs: list[Any]) -> dict[str, Any]:
|
||||
if isinstance(entry, dict) and "output" in entry:
|
||||
method = str(entry.get("method", ""))
|
||||
output = entry["output"]
|
||||
outputs[method] = _to_json_safe(output)
|
||||
output = copy.deepcopy(output)
|
||||
if isinstance(output, BaseModel):
|
||||
output = output.model_dump(mode="json")
|
||||
elif dataclasses.is_dataclass(output) and not isinstance(output, type):
|
||||
output = dataclasses.asdict(output)
|
||||
outputs[method] = output
|
||||
return outputs
|
||||
|
||||
|
||||
def _to_json_safe(value: Any) -> Any:
|
||||
if isinstance(value, BaseModel):
|
||||
return value.model_dump(mode="json")
|
||||
if dataclasses.is_dataclass(value) and not isinstance(value, type):
|
||||
return dataclasses.asdict(value)
|
||||
if isinstance(value, dict):
|
||||
return {key: _to_json_safe(item) for key, item in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [_to_json_safe(item) for item in value]
|
||||
if isinstance(value, tuple):
|
||||
return [_to_json_safe(item) for item in value]
|
||||
return value
|
||||
|
||||
|
||||
def _render_value(value: Any, context: dict[str, Any]) -> Any:
|
||||
if isinstance(value, str):
|
||||
return _render_string(value, context)
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
"""Resolution of ``module:qualname`` refs into live Python objects."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
from operator import attrgetter
|
||||
from typing import Any
|
||||
|
||||
|
||||
class InvalidRefError(ValueError):
|
||||
"""A definition ref that cannot be resolved to a live object."""
|
||||
|
||||
|
||||
def resolve_ref(ref: str, *, field: str) -> Any:
|
||||
"""Import the object a definition's `module:qualname` ref points to."""
|
||||
module_name, _, qualname = ref.partition(":")
|
||||
if "<" in ref or not module_name or not qualname:
|
||||
raise InvalidRefError(
|
||||
f"invalid {field} ref {ref!r}; expected 'module:qualname'"
|
||||
)
|
||||
try:
|
||||
return attrgetter(qualname)(importlib.import_module(module_name))
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise InvalidRefError(f"unresolvable {field} ref {ref!r}") from e
|
||||
|
||||
|
||||
def resolve_instance_ref(ref: str, *, field: str) -> Any:
|
||||
"""Resolve a ref, auto-instantiating a no-arg class into an instance."""
|
||||
target = resolve_ref(ref, field=field)
|
||||
if not inspect.isclass(target):
|
||||
return target
|
||||
try:
|
||||
return target()
|
||||
except Exception as e:
|
||||
raise InvalidRefError(
|
||||
f"cannot instantiate {field} ref {ref!r} without arguments: {e}"
|
||||
) from e
|
||||
116
lib/crewai/src/crewai/flow/runtime/_resolvers.py
Normal file
116
lib/crewai/src/crewai/flow/runtime/_resolvers.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""Resolution of FlowDefinition refs (``module:qualname``) into live objects.
|
||||
|
||||
Every ref-shaped value in a definition — ``do`` actions, ``state.ref``,
|
||||
``config.input_provider``, ``human_feedback.provider`` — resolves through
|
||||
:func:`resolve_ref`. Failures are loud and name the field and the ref.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import importlib
|
||||
import inspect
|
||||
from operator import attrgetter
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from crewai.flow.flow_definition import (
|
||||
FlowActionDefinition,
|
||||
FlowCodeActionDefinition,
|
||||
FlowExpressionActionDefinition,
|
||||
FlowToolActionDefinition,
|
||||
)
|
||||
from crewai.flow.runtime._expressions import evaluate_expression, render_with_block
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.runtime import Flow
|
||||
|
||||
|
||||
class InvalidRefError(ValueError):
    """A definition ref that cannot be resolved to a live object."""


def resolve_ref(ref: str, *, field: str) -> Any:
    """Import the object a definition's ``module:qualname`` ref points to.

    Args:
        ref: Reference of the form ``module:qualname``. The qualname may be
            dotted to reach nested attributes (e.g. ``pkg.mod:Class.method``).
        field: Name of the definition field that holds the ref; it is only
            used to build error messages.

    Returns:
        The attribute found at ``qualname`` on the imported module.

    Raises:
        InvalidRefError: If the ref is malformed, the module cannot be
            imported, or the attribute path does not exist.
    """
    module_name, _, qualname = ref.partition(":")
    # "<" rejects qualnames such as ``f.<locals>.g`` or ``<lambda>``, which
    # cannot be reached through attribute access. A leading "." is a relative
    # module name: importlib.import_module() raises TypeError for it when no
    # package anchor is given, which would bypass the handler below.
    malformed = (
        "<" in ref
        or not module_name
        or not qualname
        or module_name.startswith(".")
    )
    if malformed:
        raise InvalidRefError(
            f"invalid {field} ref {ref!r}; expected 'module:qualname'"
        )
    try:
        module = importlib.import_module(module_name)
        return attrgetter(qualname)(module)
    except (ImportError, AttributeError) as e:
        raise InvalidRefError(f"unresolvable {field} ref {ref!r}") from e
|
||||
|
||||
|
||||
def resolve_instance_ref(ref: str, *, field: str) -> Any:
    """Resolve ``ref`` and, when it names a class, return a fresh instance.

    Non-class targets are returned exactly as ``resolve_ref`` produced them.
    Classes are called with no arguments; any exception raised by that call
    is re-raised as ``InvalidRefError`` naming the field and the ref.
    """
    resolved = resolve_ref(ref, field=field)
    if inspect.isclass(resolved):
        try:
            resolved = resolved()
        except Exception as exc:
            raise InvalidRefError(
                f"cannot instantiate {field} ref {ref!r} without arguments: {exc}"
            ) from exc
    return resolved
|
||||
|
||||
|
||||
def _resolve_code_action(
    flow: Flow[Any], action: FlowCodeActionDefinition
) -> Callable[..., Any]:
    """Resolve a ``call: code`` action into a callable bound to ``flow``.

    The ref is imported through ``resolve_ref`` and must be callable. A
    target whose ``__self__`` is missing or ``None`` is bound to ``flow`` via
    the descriptor protocol; a target that already carries a ``__self__`` is
    returned untouched.

    Raises:
        InvalidRefError: If the ref cannot be resolved or is not callable.
    """
    code_ref = action.ref
    candidate = resolve_ref(code_ref, field="do")
    if not callable(candidate):
        raise InvalidRefError(f"invalid do ref {code_ref!r}; object is not callable")

    bound = cast(Callable[..., Any], candidate)
    if getattr(bound, "__self__", None) is not None:
        return bound
    # NOTE(review): callables without a descriptor ``__get__`` (for example a
    # class object) would raise AttributeError here - confirm whether such
    # refs are expected to reach this point.
    return bound.__get__(flow, type(flow))
|
||||
|
||||
|
||||
def _resolve_tool_action(
    flow: Flow[Any], action: FlowToolActionDefinition
) -> Callable[..., Any]:
    """Resolve a ``call: tool`` action into a callable that runs the tool.

    The ref must point at a ``BaseTool`` subclass that can be instantiated
    without arguments. The tool is instantiated once, at resolution time. The
    returned callable ignores its own arguments, renders the action's
    ``with`` block against ``flow`` on every call, and passes the result to
    ``tool.run`` as keyword arguments.

    Raises:
        InvalidRefError: If the ref is unresolvable, is not a ``BaseTool``
            class, or the class cannot be instantiated without arguments.
    """
    tool_ref = action.ref
    resolved = resolve_ref(tool_ref, field="do")
    # Imported inside the function rather than at module level; presumably
    # this avoids an import cycle - confirm before hoisting.
    from crewai.tools import BaseTool

    is_tool_class = inspect.isclass(resolved) and issubclass(resolved, BaseTool)
    if not is_tool_class:
        raise InvalidRefError(
            f"invalid tool ref {tool_ref!r}; expected a BaseTool class"
        )

    try:
        tool = cast(Callable[[], BaseTool], resolved)()
    except Exception as exc:
        raise InvalidRefError(
            f"cannot instantiate tool ref {tool_ref!r} without arguments: {exc}"
        ) from exc

    with_block = action.with_ or {}

    def run_tool(*_args: Any, **_kwargs: Any) -> Any:
        # Rendered on every invocation, so expressions are evaluated against
        # the flow as it is when the tool actually runs.
        rendered = render_with_block(flow, with_block)
        return tool.run(**rendered)

    return run_tool
|
||||
|
||||
|
||||
def _resolve_expression_action(
    flow: Flow[Any], action: FlowExpressionActionDefinition
) -> Callable[..., Any]:
    """Wrap a ``call: expression`` action in a callable.

    The returned callable ignores any arguments it receives. It reads
    ``action.expr`` and evaluates it against ``flow`` each time it is called.
    """

    def run_expression(*_args: Any, **_kwargs: Any) -> Any:
        # ``action.expr`` is looked up at call time, not at resolution time.
        result = evaluate_expression(flow, action.expr)
        return result

    return run_expression
|
||||
|
||||
|
||||
def resolve_action(flow: Flow[Any], action: FlowActionDefinition) -> Callable[..., Any]:
    """Turn one `do:` action into the callable the flow runs for that node.

    Dispatches on ``action.call``: ``"code"``, ``"tool"`` and
    ``"expression"`` are each handed to their dedicated resolver.

    Raises:
        ValueError: If ``action.call`` is none of the known call types.
    """
    resolvers: dict[str, Callable[[Flow[Any], Any], Callable[..., Any]]] = {
        "code": _resolve_code_action,
        "tool": _resolve_tool_action,
        "expression": _resolve_expression_action,
    }
    resolver = resolvers.get(action.call)
    if resolver is None:
        raise ValueError(f"unknown call type {action.call!r}")
    return resolver(flow, action)
|
||||
@@ -1,3 +1,4 @@
|
||||
from collections.abc import Callable
|
||||
import logging
|
||||
import traceback
|
||||
from typing import Any, cast
|
||||
@@ -32,6 +33,16 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
| type[BaseEmbeddingsProvider[Any]]
|
||||
| None
|
||||
) = Field(default=None, exclude=True)
|
||||
content_filter: Callable[[list[str]], list[str]] | None = Field(
|
||||
default=None,
|
||||
exclude=True,
|
||||
description=(
|
||||
"Optional callable that inspects and filters documents before "
|
||||
"they are indexed. Receives the full document list and must "
|
||||
"return the (possibly filtered) list to persist. Raise an "
|
||||
"exception inside the callable to abort the save entirely."
|
||||
),
|
||||
)
|
||||
_client: BaseClient | None = PrivateAttr(default=None)
|
||||
|
||||
@model_validator(mode="after")
|
||||
@@ -106,6 +117,11 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
if not documents:
|
||||
return
|
||||
|
||||
if self.content_filter is not None:
|
||||
documents = self.content_filter(documents)
|
||||
if not documents:
|
||||
return
|
||||
|
||||
try:
|
||||
client = self._get_client()
|
||||
collection_name = (
|
||||
@@ -187,6 +203,11 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
if not documents:
|
||||
return
|
||||
|
||||
if self.content_filter is not None:
|
||||
documents = self.content_filter(documents)
|
||||
if not documents:
|
||||
return
|
||||
|
||||
try:
|
||||
client = self._get_client()
|
||||
collection_name = (
|
||||
|
||||
@@ -193,3 +193,118 @@ def test_dimension_mismatch_error_handling(mock_get_client: MagicMock) -> None:
|
||||
|
||||
with pytest.raises(ValueError, match="Embedding dimension mismatch"):
|
||||
storage.save(["test document"])
|
||||
|
||||
|
||||
# --- content_filter tests ---
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_content_filter_removes_documents(mock_get_client: MagicMock) -> None:
    """Documents rejected by content_filter never reach add_documents."""
    rag_client = MagicMock()
    mock_get_client.return_value = rag_client

    def reject_secrets(docs: list[str]) -> list[str]:
        # Keep only the documents that do not mention the marker word.
        return [text for text in docs if "SECRET" not in text]

    knowledge = KnowledgeStorage(
        collection_name="filter_test", content_filter=reject_secrets
    )
    knowledge.save(["safe content", "contains SECRET key", "also safe"])

    rag_client.add_documents.assert_called_once()
    call_kwargs = rag_client.add_documents.call_args.kwargs
    stored_texts = [record["content"] for record in call_kwargs["documents"]]
    assert stored_texts == ["safe content", "also safe"]
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_content_filter_returns_empty_skips_save(mock_get_client: MagicMock) -> None:
    """A filter that returns no documents makes save() a no-op on the client."""
    rag_client = MagicMock()
    mock_get_client.return_value = rag_client

    def drop_everything(docs: list[str]) -> list[str]:
        return []

    knowledge = KnowledgeStorage(
        collection_name="empty_filter", content_filter=drop_everything
    )
    knowledge.save(["doc1", "doc2"])

    # Neither the collection nor any document may be touched.
    rag_client.add_documents.assert_not_called()
    rag_client.get_or_create_collection.assert_not_called()
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_content_filter_exception_propagates(mock_get_client: MagicMock) -> None:
    """An exception raised by content_filter surfaces from save() unchanged."""
    rag_client = MagicMock()
    mock_get_client.return_value = rag_client

    def strict_filter(docs: list[str]) -> list[str]:
        raise ValueError("Blocked by policy")

    knowledge = KnowledgeStorage(
        collection_name="strict_test", content_filter=strict_filter
    )

    with pytest.raises(ValueError, match="Blocked by policy"):
        knowledge.save(["some content"])

    # The failed save must not have written anything.
    rag_client.add_documents.assert_not_called()
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_content_filter_none_is_noop(mock_get_client: MagicMock) -> None:
    """Without a content_filter (the default), every document is stored."""
    rag_client = MagicMock()
    mock_get_client.return_value = rag_client

    knowledge = KnowledgeStorage(collection_name="noop_test")
    assert knowledge.content_filter is None

    knowledge.save(["doc1", "doc2"])

    rag_client.add_documents.assert_called_once()
    stored = rag_client.add_documents.call_args.kwargs["documents"]
    assert len(stored) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
async def test_content_filter_async_save(mock_get_client: MagicMock) -> None:
    """asave() applies content_filter before handing documents to the client."""
    from unittest.mock import AsyncMock

    # The async client methods are awaited, so they need AsyncMock stand-ins.
    rag_client = MagicMock(
        aget_or_create_collection=AsyncMock(),
        aadd_documents=AsyncMock(),
    )
    mock_get_client.return_value = rag_client

    def only_short(docs: list[str]) -> list[str]:
        return [text for text in docs if len(text) < 20]

    knowledge = KnowledgeStorage(
        collection_name="async_filter", content_filter=only_short
    )
    await knowledge.asave(["short", "this is a much longer document string"])

    rag_client.aadd_documents.assert_called_once()
    stored = rag_client.aadd_documents.call_args.kwargs["documents"]
    assert len(stored) == 1
    assert stored[0]["content"] == "short"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
async def test_content_filter_async_all_filtered(mock_get_client: MagicMock) -> None:
    """asave() skips persistence when content_filter removes everything."""
    from unittest.mock import AsyncMock

    mock_client = MagicMock()
    # The async client methods are awaited, so they need AsyncMock stand-ins.
    mock_client.aget_or_create_collection = AsyncMock()
    mock_client.aadd_documents = AsyncMock()
    mock_get_client.return_value = mock_client

    storage = KnowledgeStorage(
        collection_name="async_empty", content_filter=lambda docs: []
    )
    await storage.asave(["doc1"])

    mock_client.aadd_documents.assert_not_called()
    # Mirror the synchronous test: an empty filter result must short-circuit
    # before the collection is created, not only before documents are added.
    mock_client.aget_or_create_collection.assert_not_called()
|
||||
|
||||
@@ -44,8 +44,6 @@ def test_flow_public_exports_are_explicit():
|
||||
"FlowDefinition",
|
||||
"FlowDefinitionCondition",
|
||||
"FlowDefinitionDiagnostic",
|
||||
"FlowEachActionDefinition",
|
||||
"FlowEachInnerActionDefinition",
|
||||
"FlowExpressionActionDefinition",
|
||||
"FlowHumanFeedbackDefinition",
|
||||
"FlowMethodDefinition",
|
||||
@@ -434,73 +432,6 @@ def test_flow_definition_round_trips_json_and_yaml():
|
||||
assert yaml_round_trip.methods["decide"].listen == "begin"
|
||||
|
||||
|
||||
def test_each_action_round_trips_json_and_yaml():
|
||||
definition = flow_definition.FlowDefinition.from_dict(
|
||||
{
|
||||
"schema": "crewai.flow/v1",
|
||||
"name": "EachFlow",
|
||||
"methods": {
|
||||
"process_rows": {
|
||||
"description": "Process every loaded row.",
|
||||
"start": True,
|
||||
"do": {
|
||||
"call": "each",
|
||||
"in": "state.rows",
|
||||
"do": [
|
||||
{
|
||||
"normalize": {
|
||||
"call": "tool",
|
||||
"ref": "my_tools:NormalizeRowTool",
|
||||
"with": {"row": "${ item }"},
|
||||
}
|
||||
},
|
||||
{
|
||||
"save": {
|
||||
"call": "code",
|
||||
"ref": "my_flow:save_row",
|
||||
"with": {
|
||||
"row": "${ item }",
|
||||
"normalized": "${ outputs.normalize }",
|
||||
},
|
||||
}
|
||||
},
|
||||
],
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
json_round_trip = flow_definition.FlowDefinition.from_json(definition.to_json())
|
||||
yaml_round_trip = flow_definition.FlowDefinition.from_yaml(definition.to_yaml())
|
||||
|
||||
assert json_round_trip.to_dict() == definition.to_dict()
|
||||
assert yaml_round_trip.to_dict() == definition.to_dict()
|
||||
assert yaml_round_trip.methods["process_rows"].description == (
|
||||
"Process every loaded row."
|
||||
)
|
||||
assert yaml_round_trip.methods["process_rows"].do.call == "each"
|
||||
|
||||
|
||||
def test_flow_definition_rejects_invalid_method_names():
|
||||
with pytest.raises(ValueError, match="Flow method names must match"):
|
||||
flow_definition.FlowDefinition.from_dict(
|
||||
{
|
||||
"schema": "crewai.flow/v1",
|
||||
"name": "InvalidMethodNameFlow",
|
||||
"methods": {
|
||||
"process-rows": {
|
||||
"start": True,
|
||||
"do": {
|
||||
"call": "expression",
|
||||
"expr": "'done'",
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_flow_definition_detects_persist_metadata():
|
||||
@persist(verbose=True)
|
||||
class PersistedFlow(Flow[dict]):
|
||||
|
||||
@@ -67,26 +67,6 @@ class ToolInputFlow(Flow):
|
||||
return {"query": "ai agents", "suffix": " news"}
|
||||
|
||||
|
||||
class EachActionFlow(Flow):
|
||||
def normalize_row(self, row: str, prefix: str = "normalized") -> str:
|
||||
return f"{prefix}:{row}"
|
||||
|
||||
def save_row(self, row: str, normalized: str) -> dict[str, str]:
|
||||
return {"row": row, "normalized": normalized}
|
||||
|
||||
def keyword_code(self, name: str, punctuation: str) -> str:
|
||||
return f"{name}{punctuation}"
|
||||
|
||||
def fail_on_bad_row(self, row: str) -> str:
|
||||
if row == "bad":
|
||||
raise RuntimeError("bad row")
|
||||
return row
|
||||
|
||||
def after_each(self) -> str:
|
||||
self.state["after_count"] = self.state.get("after_count", 0) + 1
|
||||
return f"after:{self.state['after_count']}"
|
||||
|
||||
|
||||
CHAIN_YAML = f"""
|
||||
schema: crewai.flow/v1
|
||||
name: ChainFlow
|
||||
@@ -747,274 +727,6 @@ methods:
|
||||
flow.kickoff()
|
||||
|
||||
|
||||
def test_code_action_renders_keyword_inputs():
|
||||
yaml_str = f"""
|
||||
schema: crewai.flow/v1
|
||||
name: CodeWithFlow
|
||||
methods:
|
||||
greet:
|
||||
do:
|
||||
call: code
|
||||
ref: {__name__}:EachActionFlow.keyword_code
|
||||
with:
|
||||
name: "${{state.name}}"
|
||||
punctuation: "!"
|
||||
start: true
|
||||
"""
|
||||
|
||||
flow = Flow.from_definition(FlowDefinition.from_yaml(yaml_str))
|
||||
|
||||
assert flow.kickoff(inputs={"name": "hello"}) == "hello!"
|
||||
|
||||
|
||||
def test_each_action_executes_one_nested_code_action():
|
||||
yaml_str = f"""
|
||||
schema: crewai.flow/v1
|
||||
name: EachFlow
|
||||
methods:
|
||||
process_rows:
|
||||
do:
|
||||
call: each
|
||||
in: state.rows
|
||||
do:
|
||||
- normalize:
|
||||
call: code
|
||||
ref: {__name__}:EachActionFlow.normalize_row
|
||||
with:
|
||||
row: "${{item}}"
|
||||
start: true
|
||||
"""
|
||||
|
||||
flow = Flow.from_definition(FlowDefinition.from_yaml(yaml_str))
|
||||
|
||||
assert flow.kickoff(inputs={"rows": ["a", "b"]}) == [
|
||||
"normalized:a",
|
||||
"normalized:b",
|
||||
]
|
||||
|
||||
|
||||
def test_each_action_uses_iteration_outputs_between_nested_actions():
|
||||
yaml_str = f"""
|
||||
schema: crewai.flow/v1
|
||||
name: EachFlow
|
||||
methods:
|
||||
process_rows:
|
||||
do:
|
||||
call: each
|
||||
in: state.rows
|
||||
do:
|
||||
- normalize:
|
||||
call: code
|
||||
ref: {__name__}:EachActionFlow.normalize_row
|
||||
with:
|
||||
row: "${{item}}"
|
||||
prefix: saved
|
||||
- save:
|
||||
call: code
|
||||
ref: {__name__}:EachActionFlow.save_row
|
||||
with:
|
||||
row: "${{item}}"
|
||||
normalized: "${{outputs.normalize}}"
|
||||
start: true
|
||||
"""
|
||||
|
||||
flow = Flow.from_definition(FlowDefinition.from_yaml(yaml_str))
|
||||
|
||||
assert flow.kickoff(inputs={"rows": ["a", "b"]}) == [
|
||||
{"row": "a", "normalized": "saved:a"},
|
||||
{"row": "b", "normalized": "saved:b"},
|
||||
]
|
||||
|
||||
|
||||
def test_each_action_resets_inner_outputs_between_iterations():
|
||||
yaml_str = """
|
||||
schema: crewai.flow/v1
|
||||
name: EachFlow
|
||||
methods:
|
||||
process_rows:
|
||||
do:
|
||||
call: each
|
||||
in: state.rows
|
||||
do:
|
||||
- leak_check:
|
||||
call: expression
|
||||
expr: "has(outputs.previous) ? outputs.previous : 'empty'"
|
||||
- previous:
|
||||
call: expression
|
||||
expr: item
|
||||
start: true
|
||||
"""
|
||||
|
||||
flow = Flow.from_definition(FlowDefinition.from_yaml(yaml_str))
|
||||
|
||||
assert flow.kickoff(inputs={"rows": ["a", "b"]}) == ["a", "b"]
|
||||
assert flow._method_outputs == [
|
||||
{"method": "process_rows", "output": ["a", "b"]}
|
||||
]
|
||||
|
||||
|
||||
def test_each_action_empty_list_returns_empty_and_listener_runs_once():
|
||||
yaml_str = f"""
|
||||
schema: crewai.flow/v1
|
||||
name: EachFlow
|
||||
methods:
|
||||
process_rows:
|
||||
do:
|
||||
call: each
|
||||
in: state.rows
|
||||
do:
|
||||
- normalize:
|
||||
call: code
|
||||
ref: {__name__}:EachActionFlow.normalize_row
|
||||
with:
|
||||
row: "${{item}}"
|
||||
start: true
|
||||
after_each:
|
||||
do:
|
||||
call: code
|
||||
ref: {__name__}:EachActionFlow.after_each
|
||||
listen: process_rows
|
||||
"""
|
||||
|
||||
flow = Flow.from_definition(FlowDefinition.from_yaml(yaml_str))
|
||||
events = []
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionFinishedEvent)
|
||||
def on_finished(source, event):
|
||||
events.append(event.method_name)
|
||||
|
||||
result = flow.kickoff(inputs={"rows": []})
|
||||
|
||||
assert result == "after:1"
|
||||
assert flow.method_outputs == [[], "after:1"]
|
||||
assert flow.state["after_count"] == 1
|
||||
assert events.count("process_rows") == 1
|
||||
assert events.count("after_each") == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("expr", "inputs"),
|
||||
[
|
||||
("1", {}),
|
||||
('"rows"', {}),
|
||||
("state.rows", {"rows": {"a": 1}}),
|
||||
],
|
||||
)
|
||||
def test_each_action_rejects_non_list_inputs(expr, inputs):
|
||||
definition = FlowDefinition.from_dict(
|
||||
{
|
||||
"schema": "crewai.flow/v1",
|
||||
"name": "EachFlow",
|
||||
"methods": {
|
||||
"process_rows": {
|
||||
"start": True,
|
||||
"do": {
|
||||
"call": "each",
|
||||
"in": expr,
|
||||
"do": [{"value": {"call": "expression", "expr": "item"}}],
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
flow = Flow.from_definition(definition)
|
||||
|
||||
with pytest.raises(ValueError, match="each.in must evaluate to an array"):
|
||||
flow.kickoff(inputs=inputs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"action_do",
|
||||
[
|
||||
[],
|
||||
[{"first": {"call": "expression", "expr": "item"}, "second": {"call": "expression", "expr": "item"}}],
|
||||
[{"1bad": {"call": "expression", "expr": "item"}}],
|
||||
[
|
||||
{"same": {"call": "expression", "expr": "item"}},
|
||||
{"same": {"call": "expression", "expr": "item"}},
|
||||
],
|
||||
],
|
||||
)
|
||||
def test_each_action_validates_inner_action_shape(action_do):
|
||||
with pytest.raises(ValidationError):
|
||||
FlowDefinition.from_dict(
|
||||
{
|
||||
"schema": "crewai.flow/v1",
|
||||
"name": "EachFlow",
|
||||
"methods": {
|
||||
"process_rows": {
|
||||
"start": True,
|
||||
"do": {
|
||||
"call": "each",
|
||||
"in": "state.rows",
|
||||
"do": action_do,
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_each_action_rejects_nested_each_actions():
|
||||
with pytest.raises(ValidationError):
|
||||
FlowDefinition.from_dict(
|
||||
{
|
||||
"schema": "crewai.flow/v1",
|
||||
"name": "EachFlow",
|
||||
"methods": {
|
||||
"process_rows": {
|
||||
"start": True,
|
||||
"do": {
|
||||
"call": "each",
|
||||
"in": "state.rows",
|
||||
"do": [
|
||||
{
|
||||
"nested": {
|
||||
"call": "each",
|
||||
"in": "state.children",
|
||||
"do": [
|
||||
{
|
||||
"child": {
|
||||
"call": "expression",
|
||||
"expr": "item",
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_each_action_failure_fails_outer_method():
|
||||
yaml_str = f"""
|
||||
schema: crewai.flow/v1
|
||||
name: EachFlow
|
||||
methods:
|
||||
process_rows:
|
||||
do:
|
||||
call: each
|
||||
in: state.rows
|
||||
do:
|
||||
- validate:
|
||||
call: code
|
||||
ref: {__name__}:EachActionFlow.fail_on_bad_row
|
||||
with:
|
||||
row: "${{item}}"
|
||||
start: true
|
||||
"""
|
||||
|
||||
flow = Flow.from_definition(FlowDefinition.from_yaml(yaml_str))
|
||||
|
||||
with pytest.raises(RuntimeError, match="bad row"):
|
||||
flow.kickoff(inputs={"rows": ["ok", "bad"]})
|
||||
|
||||
|
||||
def test_expression_action_round_trips():
|
||||
definition = FlowDefinition.from_dict(
|
||||
{
|
||||
@@ -1118,6 +830,26 @@ def test_tool_action_requires_module_qualname_ref():
|
||||
Flow.from_definition(definition)
|
||||
|
||||
|
||||
def test_code_action_rejects_tool_inputs():
    """A ``call: code`` action that carries a ``with`` block fails validation."""
    code_action = {
        "call": "code",
        "ref": f"{__name__}:ChainFlow.begin",
        "with": {"search_query": "ai agents"},
    }
    payload = {
        "schema": "crewai.flow/v1",
        "name": "InvalidCodeActionFlow",
        "methods": {
            "begin": {
                "start": True,
                "do": code_action,
            }
        },
    }

    with pytest.raises(ValidationError):
        FlowDefinition.from_dict(payload)
|
||||
|
||||
|
||||
def test_pydantic_state_from_ref_parity():
|
||||
flow, result = assert_parity(PydanticStateFlow, PYDANTIC_STATE_YAML)
|
||||
assert result == "count=1"
|
||||
|
||||
Reference in New Issue
Block a user