Compare commits

..

1 Commits

Author SHA1 Message Date
Vinicius Brasil
02eeefe5ea Add each composite action to FlowDefinition
Lets a definition loop over an array without writing Python. Each
iteration exposes the current `item` and the prior steps' `outputs`.

```yaml
do:
  call: each
  in: state.rows
  do:
    - normalize:
        call: tool
        ref: my_tools:NormalizeRowTool
        with: { row: "${ item }" }
    - lead_scoring:
        call: agent
        # ...
```
2026-06-14 16:05:25 -07:00
18 changed files with 847 additions and 448 deletions

View File

@@ -1,158 +0,0 @@
"""SSRF-safe HTTP fetching for crewai-tools.
:func:`validate_url` checks the URL it is handed, but it cannot protect a
fetch on its own: ``requests`` re-resolves DNS at connect time and follows
redirects automatically, so a public-looking host that 302-redirects to an
internal address (or that rebinds DNS between validation and connect) reaches
the internal target without ever being re-checked.
This module closes both gaps at the connection layer:
* :class:`SSRFProtectedAdapter` re-runs :func:`validate_url` for every request
it sends. ``requests.Session.send`` invokes the adapter once per redirect
hop, so each ``Location`` target is validated before it is followed.
* The adapter's connections validate the *actual* peer IP immediately after
the socket connects. The IP that was authorised is therefore the IP the
connection uses, removing the DNS time-of-check/time-of-use gap that
:func:`validate_url`'s own ``getaddrinfo`` call leaves open.
Use :func:`safe_get` (or :func:`create_safe_session`) instead of calling
``requests.get`` directly from any tool that fetches a user- or
LLM-controlled URL.
"""
from __future__ import annotations
from typing import Any
import requests
from requests.adapters import DEFAULT_POOLBLOCK, HTTPAdapter
from urllib3.connection import HTTPConnection, HTTPSConnection
from urllib3.connectionpool import HTTPConnectionPool, HTTPSConnectionPool
from urllib3.poolmanager import PoolManager
from crewai_tools.security.safe_path import (
_is_escape_hatch_enabled,
_is_private_or_reserved,
validate_url,
)
def _assert_safe_peer(sock: Any) -> None:
    """Reject a connected socket whose remote end is private/reserved.

    The address inspected is the one the socket itself reports, not a
    separately resolved one, so the IP that is checked is the IP that was
    actually connected to.

    Args:
        sock: Connected socket-like object exposing ``getpeername()``.

    Raises:
        ValueError: If the peer IP is classified as private/reserved.
    """
    # The escape hatch switches the guard off entirely.
    if not _is_escape_hatch_enabled():
        try:
            remote = sock.getpeername()
        except OSError:
            # No peer address is available, so there is nothing to check.
            return
        ip_str = str(remote[0])
        if _is_private_or_reserved(ip_str):
            detail = (
                f"Connection resolved to private/reserved IP {ip_str}. "
                "Access to internal networks is not allowed (possible SSRF via "
                "redirect or DNS rebinding)."
            )
            raise ValueError(detail)
class _SafeHTTPConnection(HTTPConnection):
    """Plain-HTTP connection that runs the peer guard once connected."""

    def connect(self) -> None:
        """Connect via the parent class, then validate the socket's peer."""
        super().connect()
        # Raises ValueError when the connected peer is private/reserved.
        _assert_safe_peer(self.sock)
class _SafeHTTPSConnection(HTTPSConnection):
    """HTTPS connection that runs the peer guard once connected."""

    def connect(self) -> None:
        """Connect via the parent class, then validate the socket's peer."""
        super().connect()
        # Raises ValueError when the connected peer is private/reserved.
        _assert_safe_peer(self.sock)
class _SafeHTTPConnectionPool(HTTPConnectionPool):
    """HTTP connection pool whose connections are peer-checked."""

    # Connection class this pool uses in place of the parent's default.
    ConnectionCls = _SafeHTTPConnection
class _SafeHTTPSConnectionPool(HTTPSConnectionPool):
    """HTTPS connection pool whose connections are peer-checked."""

    # Connection class this pool uses in place of the parent's default.
    ConnectionCls = _SafeHTTPSConnection
# URL scheme -> peer-checking connection pool class. Installed on
# _SafePoolManager as its ``pool_classes_by_scheme`` mapping.
_SAFE_POOL_CLASSES = {
"http": _SafeHTTPConnectionPool,
"https": _SafeHTTPSConnectionPool,
}
class _SafePoolManager(PoolManager):
    """Pool manager that uses the peer-checking pool classes per scheme."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initialise the parent, then install the safe pool-class mapping."""
        super().__init__(*args, **kwargs)
        # Replace the per-scheme pool classes set up by the parent.
        self.pool_classes_by_scheme = _SAFE_POOL_CLASSES
class SSRFProtectedAdapter(HTTPAdapter):
    """Transport adapter that validates each request URL and checks the peer.

    ``validate_url`` is called from :meth:`send` for every request handed to
    the adapter, and the pool manager it installs creates connections that
    reject sockets connected to a private/reserved address.
    """

    def init_poolmanager(
        self,
        connections: int,
        maxsize: int,
        block: bool = DEFAULT_POOLBLOCK,
        **pool_kwargs: Any,
    ) -> None:
        """Install a :class:`_SafePoolManager` as this adapter's pool manager.

        Args:
            connections: Forwarded as ``num_pools``.
            maxsize: Forwarded as ``maxsize``.
            block: Forwarded as ``block``.
            **pool_kwargs: Extra keyword arguments for the pool manager.
        """
        safe_manager = _SafePoolManager(
            num_pools=connections,
            maxsize=maxsize,
            block=block,
            **pool_kwargs,
        )
        self.poolmanager = safe_manager

    def send(self, request: Any, *args: Any, **kwargs: Any) -> Any:
        """Validate ``request.url`` and then delegate to ``HTTPAdapter.send``.

        The validation runs before the parent transport is touched, so a
        rejected URL (``validate_url`` raising) never reaches the network.
        """
        target_url = request.url
        validate_url(target_url)
        response = super().send(request, *args, **kwargs)
        return response
def create_safe_session() -> requests.Session:
    """Build a ``requests.Session`` routed through the SSRF-protected adapter.

    A single :class:`SSRFProtectedAdapter` instance is mounted for both the
    ``http://`` and ``https://`` prefixes.

    Returns:
        The configured session.
    """
    hardened = requests.Session()
    guard = SSRFProtectedAdapter()
    for prefix in ("http://", "https://"):
        hardened.mount(prefix, guard)
    return hardened
def safe_get(url: str, **kwargs: Any) -> requests.Response:
    """Perform a ``GET`` through a freshly created SSRF-protected session.

    The URL is validated up front, then fetched with a session from
    :func:`create_safe_session`; the session is closed before returning.

    Args:
        url: The URL to fetch.
        **kwargs: Forwarded to ``Session.get`` (``headers``, ``cookies``,
            ``timeout``, ...).

    Returns:
        The ``requests.Response``.

    Raises:
        ValueError: If the URL, a redirect target, or the connected peer is
            not allowed.
    """
    # Fail fast on a bad initial URL before any session is created.
    validate_url(url)
    session = create_safe_session()
    try:
        response = session.get(url, **kwargs)
    finally:
        # Same cleanup the session's context manager performs on exit.
        session.close()
    return response

View File

@@ -3,8 +3,9 @@ from typing import Any
from crewai.tools import BaseTool
from pydantic import BaseModel, Field
import requests
from crewai_tools.security.safe_requests import safe_get
from crewai_tools.security.safe_path import validate_url
try:
@@ -82,7 +83,8 @@ class ScrapeElementFromWebsiteTool(BaseTool):
if website_url is None or css_element is None:
raise ValueError("Both website_url and css_element must be provided.")
page = safe_get(
website_url = validate_url(website_url)
page = requests.get(
website_url,
headers=self.headers,
cookies=self.cookies if self.cookies else {},

View File

@@ -3,8 +3,9 @@ import re
from typing import Any
from pydantic import Field
import requests
from crewai_tools.security.safe_requests import safe_get
from crewai_tools.security.safe_path import validate_url
try:
@@ -74,7 +75,8 @@ class ScrapeWebsiteTool(BaseTool):
if website_url is None:
raise ValueError("Website URL must be provided.")
page = safe_get(
website_url = validate_url(website_url)
page = requests.get(
website_url,
timeout=15,
headers=self.headers,

View File

@@ -1,124 +0,0 @@
"""Tests for SSRF-safe HTTP fetching (redirect + DNS-rebinding protection)."""
from __future__ import annotations
import http.server
import socketserver
import threading
import pytest
import requests
from crewai_tools.security import safe_requests
from crewai_tools.security.safe_requests import (
SSRFProtectedAdapter,
create_safe_session,
safe_get,
)
INTERNAL_BODY = b"INTERNAL-ONLY-SECRET"
class _InternalHandler(http.server.BaseHTTPRequestHandler):
    """Request handler that answers every GET with ``INTERNAL_BODY``."""

    def do_GET(self):
        """Reply 200 with a plain-text body."""
        self.send_response(200)
        self.send_header("Content-Type", "text/plain")
        self.end_headers()
        self.wfile.write(INTERNAL_BODY)

    def log_message(self, *args):
        """Discard log output so the test run stays quiet."""
        return None
def _serve(handler):
"""Start a localhost server on an ephemeral port; return (server, port)."""
server = socketserver.TCPServer(("127.0.0.1", 0), handler)
port = server.server_address[1]
threading.Thread(target=server.serve_forever, daemon=True).start()
return server, port
class TestRedirectRevalidation:
    """Layer 1: the adapter validates the URL on every ``send``.

    ``requests.Session.send`` calls ``adapter.send`` once per redirect hop, so
    re-validating in ``send`` is what blocks a 302 to an internal target.
    """

    def test_adapter_revalidates_before_any_network_call(self, monkeypatch):
        """``send`` must raise from validation for an internal target URL."""
        seen_urls: list[str] = []

        def recording_validator(url: str) -> str:
            seen_urls.append(url)
            if "internal.target" in url:
                raise ValueError("URL resolves to private/reserved IP")
            return url

        monkeypatch.setattr(safe_requests, "validate_url", recording_validator)
        adapter = SSRFProtectedAdapter()

        # Internal redirect target: send() must reject it before ever calling
        # the real transport (super().send is never reached).
        prepared = requests.Request("GET", "http://internal.target/").prepare()
        with pytest.raises(ValueError, match="private/reserved"):
            adapter.send(prepared)

        assert seen_urls == ["http://internal.target/"]

    def test_session_mounts_protected_adapter(self):
        """Both URL schemes resolve to the protected adapter."""
        session = create_safe_session()
        http_adapter = session.get_adapter("http://x")
        https_adapter = session.get_adapter("https://x")
        assert isinstance(http_adapter, SSRFProtectedAdapter)
        assert isinstance(https_adapter, SSRFProtectedAdapter)
class _FakeSock:
def __init__(self, peer):
self._peer = peer
def getpeername(self):
return self._peer
class TestConnectionPeerGuard:
    """Layer 2: the connection rejects an internal peer IP at connect time.

    The IP the socket actually connected to is the IP that gets checked, so a
    host that resolved public at validation time but connects internal is
    blocked.
    """

    def test_safe_get_blocks_direct_internal(self):
        """A literal loopback URL is refused."""
        # No network: validate_url rejects 127.0.0.1 at the URL layer first.
        with pytest.raises(ValueError, match="private/reserved"):
            safe_get("http://127.0.0.1:9/", timeout=10)

    def test_assert_safe_peer_blocks_private(self):
        """A loopback peer is rejected."""
        loopback = _FakeSock(("127.0.0.1", 80))
        with pytest.raises(ValueError, match="private/reserved"):
            safe_requests._assert_safe_peer(loopback)

    def test_assert_safe_peer_blocks_metadata(self):
        """The link-local metadata address is rejected."""
        metadata = _FakeSock(("169.254.169.254", 80))
        with pytest.raises(ValueError, match="private/reserved"):
            safe_requests._assert_safe_peer(metadata)

    def test_assert_safe_peer_allows_public(self):
        """A public peer passes without raising."""
        public = _FakeSock(("93.184.216.34", 80))
        safe_requests._assert_safe_peer(public)

    def test_assert_safe_peer_respects_escape_hatch(self, monkeypatch):
        """With the escape hatch enabled, a private peer is tolerated."""
        monkeypatch.setenv("CREWAI_TOOLS_ALLOW_UNSAFE_PATHS", "true")
        # No raise even for a private peer when the escape hatch is on.
        safe_requests._assert_safe_peer(_FakeSock(("127.0.0.1", 80)))

    def test_connection_validates_peer_after_connect(self, monkeypatch):
        """_SafeHTTPConnection.connect runs the peer guard after connecting."""
        conn = safe_requests._SafeHTTPConnection("example.com")

        def fake_super_connect(self):
            # Simulate a rebind: we connected to an internal address.
            self.sock = _FakeSock(("127.0.0.1", 80))

        monkeypatch.setattr(
            safe_requests.HTTPConnection, "connect", fake_super_connect
        )

        with pytest.raises(ValueError, match="private/reserved"):
            conn.connect()

View File

@@ -11,9 +11,17 @@ from __future__ import annotations
import json
import logging
import re
from typing import Any, Literal as TypingLiteral
from pydantic import BaseModel, ConfigDict, Field, field_serializer, model_validator
from pydantic import (
BaseModel,
ConfigDict,
Field,
RootModel,
field_serializer,
model_validator,
)
import yaml
from crewai.flow.conversational_definition import (
@@ -25,6 +33,7 @@ from crewai.flow.conversational_definition import (
logger = logging.getLogger(__name__)
FlowDefinitionCondition = str | dict[str, Any]
_STEP_NAME_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
__all__ = [
"FlowActionDefinition",
@@ -35,6 +44,8 @@ __all__ = [
"FlowDefinition",
"FlowDefinitionCondition",
"FlowDefinitionDiagnostic",
"FlowEachActionDefinition",
"FlowEachInnerActionDefinition",
"FlowExpressionActionDefinition",
"FlowHumanFeedbackDefinition",
"FlowMethodDefinition",
@@ -148,10 +159,11 @@ class FlowHumanFeedbackDefinition(BaseModel):
class FlowCodeActionDefinition(BaseModel):
"""A Flow method action that executes importable Python code."""
model_config = ConfigDict(extra="forbid")
model_config = ConfigDict(populate_by_name=True, extra="forbid")
call: TypingLiteral["code"] = "code"
ref: str
with_: dict[str, Any] | None = Field(default=None, alias="with")
class FlowToolActionDefinition(BaseModel):
@@ -173,14 +185,75 @@ class FlowExpressionActionDefinition(BaseModel):
expr: str
FlowActionDefinition = (
FlowInnerActionDefinition = (
FlowCodeActionDefinition | FlowToolActionDefinition | FlowExpressionActionDefinition
)
class FlowEachInnerActionDefinition(RootModel[dict[str, FlowInnerActionDefinition]]):
    """A named inner action of an ``each`` composite, stored as a mapping."""

    @property
    def name(self) -> str:
        """Return the first key of the root mapping."""
        return next(iter(self.root.keys()))

    @property
    def action(self) -> FlowInnerActionDefinition:
        """Return the action stored under the first key of the root mapping."""
        first_key = next(iter(self.root))
        return self.root[first_key]
class FlowEachActionDefinition(BaseModel):
    """Composite action: run a sequential list of inner actions per item."""

    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    call: TypingLiteral["each"]
    # Serialized under the alias ``in``.
    in_: str = Field(alias="in")
    do: list[FlowEachInnerActionDefinition]

    @model_validator(mode="before")
    @classmethod
    def _validate_inner_action_list(cls, data: Any) -> Any:
        """Check the raw ``do`` list before field validation runs.

        Input that is not a dict, or has no ``do`` key, is returned untouched.

        Raises:
            ValueError: If ``do`` is not a non-empty list, an entry is not a
                one-key mapping, or an action name is invalid or repeated.
        """
        if not (isinstance(data, dict) and "do" in data):
            return data

        entries = data["do"]
        if not (isinstance(entries, list) and entries):
            raise ValueError("each.do must contain at least one action")

        used_names: set[str] = set()
        for entry in entries:
            # Entries may be already-built models or raw mappings.
            if isinstance(entry, FlowEachInnerActionDefinition):
                mapping = entry.root
            elif isinstance(entry, dict):
                mapping = entry
            else:
                mapping = None
            if mapping is None or len(mapping) != 1:
                raise ValueError("each.do entries must be one-key mappings")

            (step_name,) = mapping
            _validate_step_name(step_name, field="each.do action names")
            if step_name in used_names:
                raise ValueError(
                    f"each.do action names must be unique: {step_name!r}"
                )
            used_names.add(step_name)
        return data
# Every action model accepted as a Flow method's ``do`` value: the three
# simple actions plus the ``each`` composite.
FlowActionDefinition = (
FlowCodeActionDefinition
| FlowToolActionDefinition
| FlowExpressionActionDefinition
| FlowEachActionDefinition
)
class FlowMethodDefinition(BaseModel):
"""Static definition of one Flow method and its execution roles."""
description: str | None = None
do: FlowActionDefinition
start: bool | FlowDefinitionCondition | None = None
listen: FlowDefinitionCondition | None = None
@@ -227,6 +300,12 @@ class FlowDefinition(BaseModel):
methods: dict[str, FlowMethodDefinition] = Field(default_factory=dict)
diagnostics: list[FlowDefinitionDiagnostic] = Field(default_factory=list)
@model_validator(mode="after")
def _validate_method_names(self) -> FlowDefinition:
    """Check every key of ``methods`` with ``_validate_step_name``.

    Returns:
        The model itself, unchanged.
    """
    for declared_name in self.methods:
        _validate_step_name(declared_name, field="Flow method names")
    return self
def to_dict(self, *, exclude_none: bool = True) -> dict[str, Any]:
    """Dump the definition to a JSON/YAML-ready dictionary.

    Args:
        exclude_none: Passed through to ``model_dump`` as ``exclude_none``.

    Returns:
        The result of ``model_dump`` with aliases enabled in JSON mode.
    """
    dump_options = {
        "by_alias": True,
        "exclude_none": exclude_none,
        "mode": "json",
    }
    return self.model_dump(**dump_options)
@@ -369,6 +448,11 @@ def _deserialize_diagnostics(value: Any) -> list[FlowDefinitionDiagnostic]:
return [FlowDefinitionDiagnostic.model_validate(item) for item in value or []]
def _validate_step_name(name: str, *, field: str) -> None:
    """Raise unless ``name`` is a string fully matching ``_STEP_NAME_PATTERN``.

    Args:
        name: Candidate name.
        field: Label used at the start of the error message.

    Raises:
        ValueError: If ``name`` is not a string or does not match.
    """
    is_valid = (
        isinstance(name, str) and _STEP_NAME_PATTERN.fullmatch(name) is not None
    )
    if is_valid:
        return
    raise ValueError(f"{field} must match {_STEP_NAME_PATTERN.pattern}")
def _merge_diagnostics(
*diagnostic_groups: list[FlowDefinitionDiagnostic],
) -> list[FlowDefinitionDiagnostic]:

View File

@@ -121,11 +121,8 @@ from crewai.flow.human_feedback import (
)
from crewai.flow.input_provider import InputProvider
from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.runtime._resolvers import (
resolve_action,
resolve_instance_ref,
resolve_ref,
)
from crewai.flow.runtime._actions import build_action
from crewai.flow.runtime._refs import resolve_instance_ref, resolve_ref
from crewai.flow.types import (
FlowExecutionData,
FlowMethodName,
@@ -1092,9 +1089,9 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
self._methods.update(methods)
def _action_bound_methods(self) -> dict[FlowMethodName, Callable[..., Any]]:
def resolve(name: str, definition: FlowMethodDefinition) -> Callable[..., Any]:
def build(name: str, definition: FlowMethodDefinition) -> Callable[..., Any]:
try:
return resolve_action(self, definition.do)
return build_action(self, definition.do)
except Exception as e:
unresolved.append(f"{name}: {e}")
return lambda *args, **kwargs: None
@@ -1102,9 +1099,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
methods: dict[FlowMethodName, Callable[..., Any]] = {}
unresolved: list[str] = []
for method_name, method_definition in self._definition.methods.items():
methods[FlowMethodName(method_name)] = resolve(
method_name, method_definition
)
methods[FlowMethodName(method_name)] = build(method_name, method_definition)
if unresolved:
raise ValueError(
f"Cannot build flow {self._definition.name!r} from its definition; "

View File

@@ -0,0 +1,48 @@
"""Build FlowDefinition actions into live runtime callables."""
from __future__ import annotations
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
from crewai.flow.flow_definition import (
FlowActionDefinition,
FlowInnerActionDefinition,
)
from crewai.flow.runtime._actions._base import ActionHandlerRegistry
from crewai.flow.runtime._actions._code import CodeActionHandler
from crewai.flow.runtime._actions._each import EachActionHandler
from crewai.flow.runtime._actions._expression import ExpressionActionHandler
from crewai.flow.runtime._actions._tool import ToolActionHandler
if TYPE_CHECKING:
from crewai.flow.runtime import Flow
__all__ = [
"build_action",
]
# Handlers for the non-composite actions (code, tool, expression).
_SIMPLE_ACTION_HANDLERS = (
CodeActionHandler(),
ToolActionHandler(),
ExpressionActionHandler(),
)
# Registry restricted to the simple handlers; it is handed to
# EachActionHandler below.
_SIMPLE_ACTION_REGISTRY = ActionHandlerRegistry[FlowInnerActionDefinition](
_SIMPLE_ACTION_HANDLERS
)
# Registry used by build_action: the simple handlers plus the ``each``
# composite handler.
_ACTION_REGISTRY = ActionHandlerRegistry[FlowActionDefinition](
(
*_SIMPLE_ACTION_HANDLERS,
EachActionHandler(_SIMPLE_ACTION_REGISTRY),
)
)
def build_action(flow: Flow[Any], action: FlowActionDefinition) -> Callable[..., Any]:
    """Build the runtime callable for one ``do:`` action.

    Delegates to the module-level action registry.
    """
    runner = _ACTION_REGISTRY.build(flow, action)
    return runner

View File

@@ -0,0 +1,39 @@
"""Shared action handler contracts."""
from __future__ import annotations
from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar
from pydantic import BaseModel
if TYPE_CHECKING:
from crewai.flow.runtime import Flow
# Type variable for an action definition model (any pydantic BaseModel).
ActionT = TypeVar("ActionT", bound=BaseModel)
# Alias for the callable a handler's build() returns.
ResolvedAction = Callable[..., Any]
class ActionHandler(Protocol[ActionT]):
    """Structural interface for a handler of one action definition type."""

    # Definition class this handler accepts; the registry matches on it.
    action_type: type[ActionT]

    def build(self, flow: Flow[Any], action: ActionT) -> ResolvedAction:
        """Return the callable the flow executes for ``action``."""
        ...
class ActionHandlerRegistry(Generic[ActionT]):
    """Ordered collection of handlers that builds action callables."""

    def __init__(self, handlers: Iterable[ActionHandler[Any]]) -> None:
        """Store the handlers as a tuple, keeping their order."""
        self._handlers = tuple(handlers)

    def build(self, flow: Flow[Any], action: ActionT) -> ResolvedAction:
        """Build ``action`` with the first handler whose type matches.

        Raises:
            ValueError: If ``action`` is not an instance of any handler's
                ``action_type``.
        """
        matching = next(
            (
                candidate
                for candidate in self._handlers
                if isinstance(action, candidate.action_type)
            ),
            None,
        )
        if matching is None:
            call = getattr(action, "call", None)
            raise ValueError(f"unknown call type {call!r}")
        return matching.build(flow, action)

View File

@@ -0,0 +1,51 @@
"""Handler for ``call: code`` FlowDefinition actions."""
from __future__ import annotations
from collections.abc import Callable
import functools
from typing import TYPE_CHECKING, Any, cast
from crewai.flow.flow_definition import FlowCodeActionDefinition
from crewai.flow.runtime._actions._base import ResolvedAction
from crewai.flow.runtime._actions._runtime import LOCAL_CONTEXT_KWARG
from crewai.flow.runtime._expressions import render_with_block
from crewai.flow.runtime._refs import InvalidRefError, resolve_ref
if TYPE_CHECKING:
from crewai.flow.runtime import Flow
class CodeActionHandler:
    """Handler for ``call: code`` actions: resolve a callable and wrap it."""

    action_type = FlowCodeActionDefinition

    def build(
        self, flow: Flow[Any], action: FlowCodeActionDefinition
    ) -> ResolvedAction:
        """Resolve the action's ref and return a wrapper around it.

        With no ``with`` block the wrapper forwards its own arguments; with
        one, the rendered block supplies the keyword arguments instead.
        """
        handler = _resolve_code_handler(flow, action)

        def run_code(*args: Any, **kwargs: Any) -> Any:
            # Strip the local-context keyword so the handler never sees it.
            local_context = kwargs.pop(LOCAL_CONTEXT_KWARG, None)
            with_block = action.with_
            if with_block is not None:
                rendered = render_with_block(
                    flow, with_block, local_context=local_context
                )
                return handler(**rendered)
            return handler(*args, **kwargs)

        return functools.update_wrapper(run_code, handler)
def _resolve_code_handler(
    flow: Flow[Any], action: FlowCodeActionDefinition
) -> Callable[..., Any]:
    """Resolve ``action.ref`` to a callable bound to ``flow``.

    A target whose ``__self__`` is missing or ``None`` is bound to ``flow``
    through the descriptor protocol; anything else is returned as-is.

    Raises:
        InvalidRefError: If the resolved object is not callable.
    """
    target = resolve_ref(action.ref, field="do")
    if not callable(target):
        raise InvalidRefError(
            f"invalid do ref {action.ref!r}; object is not callable"
        )
    handler = cast(Callable[..., Any], target)
    already_bound = getattr(handler, "__self__", None) is not None
    return handler if already_bound else handler.__get__(flow, type(flow))

View File

@@ -0,0 +1,73 @@
"""Handler for ``call: each`` FlowDefinition actions."""
from __future__ import annotations
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
from crewai.flow.flow_definition import (
FlowEachActionDefinition,
FlowEachInnerActionDefinition,
FlowInnerActionDefinition,
)
from crewai.flow.runtime._actions._base import (
ActionHandlerRegistry,
ResolvedAction,
)
from crewai.flow.runtime._actions._runtime import (
LOCAL_CONTEXT_KWARG,
ensure_array,
invoke_callable,
)
from crewai.flow.runtime._expressions import evaluate_expression
if TYPE_CHECKING:
from crewai.flow.runtime import Flow
class EachActionHandler:
    """Handler for ``call: each``: run the inner actions once per item."""

    action_type = FlowEachActionDefinition

    def __init__(
        self, inner_registry: ActionHandlerRegistry[FlowInnerActionDefinition]
    ) -> None:
        """Keep the registry used to build the inner actions."""
        self._inner_registry = inner_registry

    def build(
        self, flow: Flow[Any], action: FlowEachActionDefinition
    ) -> ResolvedAction:
        """Build an async callable that iterates over ``action.in_``.

        Inner actions are built once up front. For each item they run in
        order, each receiving ``item`` and the ``outputs`` of earlier steps
        for that item. The returned list holds the last step's output per
        item.
        """
        steps = [
            (step.name, self._resolve_inner_action(flow, step))
            for step in action.do
        ]

        async def run_pipeline(item: Any) -> Any:
            # Outputs are scoped to a single item and start empty each time.
            step_outputs: dict[str, Any] = {}
            latest: Any = None
            for step_name, run_step in steps:
                latest = await run_step({"item": item, "outputs": step_outputs})
                step_outputs[step_name] = latest
            return latest

        async def run_each(*_args: Any, **_kwargs: Any) -> list[Any]:
            items = ensure_array(evaluate_expression(flow, action.in_))
            # Awaiting inside the comprehension keeps items strictly sequential.
            return [await run_pipeline(item) for item in items]

        return run_each

    def _resolve_inner_action(
        self, flow: Flow[Any], inner_action: FlowEachInnerActionDefinition
    ) -> Callable[[dict[str, Any]], Any]:
        """Wrap one inner action so it takes the local context as its argument."""
        run_action = self._inner_registry.build(flow, inner_action.action)

        async def run_inner_action(local_context: dict[str, Any]) -> Any:
            # The context travels as a reserved keyword argument.
            context_kwargs = {LOCAL_CONTEXT_KWARG: local_context}
            return await invoke_callable(run_action, **context_kwargs)

        return run_inner_action

View File

@@ -0,0 +1,29 @@
"""Handler for ``call: expression`` FlowDefinition actions."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from crewai.flow.flow_definition import FlowExpressionActionDefinition
from crewai.flow.runtime._actions._base import ResolvedAction
from crewai.flow.runtime._actions._runtime import LOCAL_CONTEXT_KWARG
from crewai.flow.runtime._expressions import evaluate_expression
if TYPE_CHECKING:
from crewai.flow.runtime import Flow
class ExpressionActionHandler:
    """Handler for ``call: expression`` actions."""

    action_type = FlowExpressionActionDefinition

    def build(
        self, flow: Flow[Any], action: FlowExpressionActionDefinition
    ) -> ResolvedAction:
        """Return a callable that evaluates ``action.expr`` when invoked.

        Positional arguments are ignored; only the local-context keyword
        argument is read.
        """

        def run_expression(*_args: Any, **kwargs: Any) -> Any:
            scoped = kwargs.get(LOCAL_CONTEXT_KWARG)
            return evaluate_expression(flow, action.expr, local_context=scoped)

        return run_expression

View File

@@ -0,0 +1,28 @@
"""Runtime helpers shared by action resolvers."""
from __future__ import annotations
from collections.abc import Callable
import inspect
from typing import Any
# Reserved keyword-argument name under which a local expression context is
# passed to action callables.
LOCAL_CONTEXT_KWARG = "__flow_definition_local_context"
async def invoke_callable(
    handler: Callable[..., Any], *args: Any, **kwargs: Any
) -> Any:
    """Call ``handler`` and resolve its result, whether sync or async.

    Args:
        handler: Callable to invoke.
        *args: Positional arguments forwarded to ``handler``.
        **kwargs: Keyword arguments forwarded to ``handler``.

    Returns:
        The handler's result, awaited where it is awaitable.
    """
    outcome = handler(*args, **kwargs)
    if inspect.iscoroutinefunction(handler):
        outcome = await outcome
    # Covers plain callables that hand back an awaitable.
    if inspect.isawaitable(outcome):
        outcome = await outcome
    return outcome
def ensure_array(value: Any) -> list[Any]:
    """Return ``value`` unchanged when it is a list.

    Raises:
        ValueError: If ``value`` is anything other than a list.
    """
    if not isinstance(value, list):
        raise ValueError("each.in must evaluate to an array")
    return value

View File

@@ -0,0 +1,52 @@
"""Handler for ``call: tool`` FlowDefinition actions."""
from __future__ import annotations
from collections.abc import Callable
import inspect
from typing import TYPE_CHECKING, Any, cast
from crewai.flow.flow_definition import FlowToolActionDefinition
from crewai.flow.runtime._actions._base import ResolvedAction
from crewai.flow.runtime._actions._runtime import LOCAL_CONTEXT_KWARG
from crewai.flow.runtime._expressions import render_with_block
from crewai.flow.runtime._refs import InvalidRefError, resolve_ref
if TYPE_CHECKING:
from crewai.flow.runtime import Flow
class ToolActionHandler:
    """Handler for ``call: tool`` actions: instantiate a tool and run it."""

    action_type = FlowToolActionDefinition

    def build(
        self, flow: Flow[Any], action: FlowToolActionDefinition
    ) -> ResolvedAction:
        """Instantiate the referenced tool and return a callable running it.

        The tool is created once, at build time, with no arguments. Each call
        renders the ``with`` block (empty when absent) and passes the result
        to ``tool.run`` as keyword arguments.

        Raises:
            InvalidRefError: If the ref is not a ``BaseTool`` subclass or the
                class cannot be instantiated without arguments.
        """
        target = resolve_ref(action.ref, field="do")

        from crewai.tools import BaseTool

        is_tool_class = inspect.isclass(target) and issubclass(target, BaseTool)
        if not is_tool_class:
            raise InvalidRefError(
                f"invalid tool ref {action.ref!r}; expected a BaseTool class"
            )

        tool_cls = cast(Callable[[], BaseTool], target)
        try:
            tool = tool_cls()
        except Exception as e:
            raise InvalidRefError(
                f"cannot instantiate tool ref {action.ref!r} without arguments: {e}"
            ) from e

        tool_kwargs = action.with_ or {}

        def run_tool(*_args: Any, **kwargs: Any) -> Any:
            scoped = kwargs.get(LOCAL_CONTEXT_KWARG)
            rendered = render_with_block(flow, tool_kwargs, local_context=scoped)
            return tool.run(**rendered)

        return run_tool

View File

@@ -2,7 +2,6 @@
from __future__ import annotations
import copy
import dataclasses
from itertools import pairwise
import json
@@ -25,25 +24,36 @@ class FlowExpressionError(ValueError):
"""A FlowDefinition expression failed to parse or evaluate."""
def render_with_block(flow: Flow[Any], value: Any) -> Any:
def render_with_block(
flow: Flow[Any], value: Any, local_context: dict[str, Any] | None = None
) -> Any:
"""Render CEL expressions inside a FlowDefinition ``with:`` payload."""
context = _expression_context(flow)
context = _expression_context(flow, local_context=local_context)
return _render_value(value, context)
def evaluate_expression(flow: Flow[Any], expression: str) -> Any:
def evaluate_expression(
flow: Flow[Any], expression: str, local_context: dict[str, Any] | None = None
) -> Any:
"""Evaluate a FlowDefinition CEL expression against runtime context."""
expression = expression.strip()
if not expression:
raise FlowExpressionError("empty CEL expression")
return _eval_cel(expression, _expression_context(flow))
return _eval_cel(expression, _expression_context(flow, local_context=local_context))
def _expression_context(flow: Flow[Any]) -> dict[str, Any]:
return {
def _expression_context(
flow: Flow[Any], local_context: dict[str, Any] | None = None
) -> dict[str, Any]:
context = {
"state": flow._copy_and_serialize_state(),
"outputs": _outputs_by_name(flow._method_outputs),
}
if local_context:
context.update(
{key: _to_json_safe(value) for key, value in local_context.items()}
)
return context
def _outputs_by_name(method_outputs: list[Any]) -> dict[str, Any]:
@@ -54,15 +64,24 @@ def _outputs_by_name(method_outputs: list[Any]) -> dict[str, Any]:
if isinstance(entry, dict) and "output" in entry:
method = str(entry.get("method", ""))
output = entry["output"]
output = copy.deepcopy(output)
if isinstance(output, BaseModel):
output = output.model_dump(mode="json")
elif dataclasses.is_dataclass(output) and not isinstance(output, type):
output = dataclasses.asdict(output)
outputs[method] = output
outputs[method] = _to_json_safe(output)
return outputs
def _to_json_safe(value: Any) -> Any:
    """Recursively convert ``value`` into plain JSON-style containers.

    Pydantic models are dumped in JSON mode, dataclass instances become
    dicts, and dict/list/tuple contents are converted element by element
    (tuples come back as lists). Any other value is returned unchanged.

    Args:
        value: Arbitrary runtime value.

    Returns:
        The converted value.
    """
    if isinstance(value, BaseModel):
        return value.model_dump(mode="json")
    if dataclasses.is_dataclass(value) and not isinstance(value, type):
        # asdict() only recurses through dataclasses and builtin containers;
        # other field values (e.g. a pydantic model) come back as copies and
        # tuples stay tuples, so normalise the resulting dict as well.
        return _to_json_safe(dataclasses.asdict(value))
    if isinstance(value, dict):
        return {key: _to_json_safe(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_to_json_safe(item) for item in value]
    return value
def _render_value(value: Any, context: dict[str, Any]) -> Any:
if isinstance(value, str):
return _render_string(value, context)

View File

@@ -0,0 +1,38 @@
"""Resolution of ``module:qualname`` refs into live Python objects."""
from __future__ import annotations
import importlib
import inspect
from operator import attrgetter
from typing import Any
class InvalidRefError(ValueError):
    """Raised when a definition ref cannot be turned into a live object."""
def resolve_ref(ref: str, *, field: str) -> Any:
    """Import the object a definition's ``module:qualname`` ref points to.

    Args:
        ref: Reference of the form ``module:qualname``; ``qualname`` may be
            dotted to reach nested attributes.
        field: Name of the definition field, used in error messages.

    Returns:
        The object found at ``qualname`` inside the imported module.

    Raises:
        InvalidRefError: If the ref is malformed, or the module or attribute
            cannot be resolved.
    """
    module_name, _, qualname = ref.partition(":")
    malformed = (
        "<" in ref
        or not module_name
        or not qualname
        # A leading "." would make import_module() attempt a relative import
        # and raise TypeError (no package given); report it as a bad ref.
        or module_name.startswith(".")
    )
    if malformed:
        raise InvalidRefError(
            f"invalid {field} ref {ref!r}; expected 'module:qualname'"
        )
    try:
        module = importlib.import_module(module_name)
        return attrgetter(qualname)(module)
    except (ImportError, AttributeError) as e:
        raise InvalidRefError(f"unresolvable {field} ref {ref!r}") from e
def resolve_instance_ref(ref: str, *, field: str) -> Any:
    """Resolve ``ref``; if it names a class, return a no-arg instance of it.

    Non-class targets are returned as resolved.

    Raises:
        InvalidRefError: If resolution fails or the class cannot be
            instantiated without arguments.
    """
    target = resolve_ref(ref, field=field)
    if inspect.isclass(target):
        try:
            return target()
        except Exception as e:
            raise InvalidRefError(
                f"cannot instantiate {field} ref {ref!r} without arguments: {e}"
            ) from e
    return target

View File

@@ -1,116 +0,0 @@
"""Resolution of FlowDefinition refs (``module:qualname``) into live objects.
Every ref-shaped value in a definition — ``do`` actions, ``state.ref``,
``config.input_provider``, ``human_feedback.provider`` — resolves through
:func:`resolve_ref`. Failures are loud and name the field and the ref.
"""
from __future__ import annotations
from collections.abc import Callable
import importlib
import inspect
from operator import attrgetter
from typing import TYPE_CHECKING, Any, cast
from crewai.flow.flow_definition import (
FlowActionDefinition,
FlowCodeActionDefinition,
FlowExpressionActionDefinition,
FlowToolActionDefinition,
)
from crewai.flow.runtime._expressions import evaluate_expression, render_with_block
if TYPE_CHECKING:
from crewai.flow.runtime import Flow
class InvalidRefError(ValueError):
    """Raised when a definition ref cannot be turned into a live object."""
def resolve_ref(ref: str, *, field: str) -> Any:
    """Import the object a definition's ``module:qualname`` ref points to.

    Args:
        ref: Reference of the form ``module:qualname``; ``qualname`` may be
            dotted to reach nested attributes.
        field: Name of the definition field, used in error messages.

    Returns:
        The object found at ``qualname`` inside the imported module.

    Raises:
        InvalidRefError: If the ref is malformed, or the module or attribute
            cannot be resolved.
    """
    module_name, _, qualname = ref.partition(":")
    malformed = (
        "<" in ref
        or not module_name
        or not qualname
        # A leading "." would make import_module() attempt a relative import
        # and raise TypeError (no package given); report it as a bad ref.
        or module_name.startswith(".")
    )
    if malformed:
        raise InvalidRefError(
            f"invalid {field} ref {ref!r}; expected 'module:qualname'"
        )
    try:
        module = importlib.import_module(module_name)
        return attrgetter(qualname)(module)
    except (ImportError, AttributeError) as e:
        raise InvalidRefError(f"unresolvable {field} ref {ref!r}") from e
def resolve_instance_ref(ref: str, *, field: str) -> Any:
    """Resolve ``ref``; if it names a class, return a no-arg instance of it.

    Non-class targets are returned as resolved.

    Raises:
        InvalidRefError: If resolution fails or the class cannot be
            instantiated without arguments.
    """
    target = resolve_ref(ref, field=field)
    if inspect.isclass(target):
        try:
            return target()
        except Exception as e:
            raise InvalidRefError(
                f"cannot instantiate {field} ref {ref!r} without arguments: {e}"
            ) from e
    return target
def _resolve_code_action(
    flow: Flow[Any], action: FlowCodeActionDefinition
) -> Callable[..., Any]:
    """Resolve ``action.ref`` to a callable bound to ``flow``.

    A target whose ``__self__`` is missing or ``None`` is bound to ``flow``
    through the descriptor protocol; anything else is returned as-is.

    Raises:
        InvalidRefError: If the resolved object is not callable.
    """
    target = resolve_ref(action.ref, field="do")
    if not callable(target):
        raise InvalidRefError(
            f"invalid do ref {action.ref!r}; object is not callable"
        )
    handler = cast(Callable[..., Any], target)
    already_bound = getattr(handler, "__self__", None) is not None
    return handler if already_bound else handler.__get__(flow, type(flow))
def _resolve_tool_action(
    flow: Flow[Any], action: FlowToolActionDefinition
) -> Callable[..., Any]:
    """Instantiate the referenced tool and return a callable that runs it.

    The tool is created once, with no arguments. Each call renders the
    ``with`` block (empty when absent) and passes it to ``tool.run``.

    Raises:
        InvalidRefError: If the ref is not a ``BaseTool`` subclass or the
            class cannot be instantiated without arguments.
    """
    target = resolve_ref(action.ref, field="do")

    from crewai.tools import BaseTool

    is_tool_class = inspect.isclass(target) and issubclass(target, BaseTool)
    if not is_tool_class:
        raise InvalidRefError(
            f"invalid tool ref {action.ref!r}; expected a BaseTool class"
        )

    tool_cls = cast(Callable[[], BaseTool], target)
    try:
        tool = tool_cls()
    except Exception as e:
        raise InvalidRefError(
            f"cannot instantiate tool ref {action.ref!r} without arguments: {e}"
        ) from e

    tool_kwargs = action.with_ or {}

    def run_tool(*_args: Any, **_kwargs: Any) -> Any:
        # Call-time arguments are ignored; inputs come from the with block.
        rendered = render_with_block(flow, tool_kwargs)
        return tool.run(**rendered)

    return run_tool
def _resolve_expression_action(
    flow: Flow[Any], action: FlowExpressionActionDefinition
) -> Callable[..., Any]:
    """Return a callable that evaluates ``action.expr`` against ``flow``.

    Arguments passed to the returned callable are ignored.
    """

    def run_expression(*_args: Any, **_kwargs: Any) -> Any:
        value = evaluate_expression(flow, action.expr)
        return value

    return run_expression
def resolve_action(flow: Flow[Any], action: FlowActionDefinition) -> Callable[..., Any]:
    """Build the callable the flow runs for a single ``do:`` action.

    Dispatches on ``action.call`` to the matching ``_resolve_*_action``
    helper.

    Raises:
        ValueError: If ``action.call`` is not ``"code"``, ``"tool"`` or
            ``"expression"``.
    """
    resolvers: dict[str, Callable[..., Callable[..., Any]]] = {
        "code": _resolve_code_action,
        "tool": _resolve_tool_action,
        "expression": _resolve_expression_action,
    }
    resolver = resolvers.get(action.call)
    if resolver is None:
        raise ValueError(f"unknown call type {action.call!r}")
    return resolver(flow, action)

View File

@@ -44,6 +44,8 @@ def test_flow_public_exports_are_explicit():
"FlowDefinition",
"FlowDefinitionCondition",
"FlowDefinitionDiagnostic",
"FlowEachActionDefinition",
"FlowEachInnerActionDefinition",
"FlowExpressionActionDefinition",
"FlowHumanFeedbackDefinition",
"FlowMethodDefinition",
@@ -432,6 +434,73 @@ def test_flow_definition_round_trips_json_and_yaml():
assert yaml_round_trip.methods["decide"].listen == "begin"
def test_each_action_round_trips_json_and_yaml():
    """An `each` action survives JSON and YAML serialization unchanged."""
    normalize_step = {
        "normalize": {
            "call": "tool",
            "ref": "my_tools:NormalizeRowTool",
            "with": {"row": "${ item }"},
        }
    }
    save_step = {
        "save": {
            "call": "code",
            "ref": "my_flow:save_row",
            "with": {
                "row": "${ item }",
                "normalized": "${ outputs.normalize }",
            },
        }
    }
    payload = {
        "schema": "crewai.flow/v1",
        "name": "EachFlow",
        "methods": {
            "process_rows": {
                "description": "Process every loaded row.",
                "start": True,
                "do": {
                    "call": "each",
                    "in": "state.rows",
                    "do": [normalize_step, save_step],
                },
            }
        },
    }
    definition = flow_definition.FlowDefinition.from_dict(payload)

    json_round_trip = flow_definition.FlowDefinition.from_json(definition.to_json())
    yaml_round_trip = flow_definition.FlowDefinition.from_yaml(definition.to_yaml())

    # Both serialization formats must reproduce the same dictionary form.
    for round_trip in (json_round_trip, yaml_round_trip):
        assert round_trip.to_dict() == definition.to_dict()

    assert yaml_round_trip.methods["process_rows"].description == (
        "Process every loaded row."
    )
    assert yaml_round_trip.methods["process_rows"].do.call == "each"
def test_flow_definition_rejects_invalid_method_names():
    """A method key containing a hyphen is rejected when the definition loads."""
    payload = {
        "schema": "crewai.flow/v1",
        "name": "InvalidMethodNameFlow",
        "methods": {
            "process-rows": {
                "start": True,
                "do": {
                    "call": "expression",
                    "expr": "'done'",
                },
            }
        },
    }

    with pytest.raises(ValueError, match="Flow method names must match"):
        flow_definition.FlowDefinition.from_dict(payload)
def test_flow_definition_detects_persist_metadata():
@persist(verbose=True)
class PersistedFlow(Flow[dict]):

View File

@@ -67,6 +67,26 @@ class ToolInputFlow(Flow):
return {"query": "ai agents", "suffix": " news"}
class EachActionFlow(Flow):
    """Flow whose methods serve as `code` action targets in the tests below."""

    def normalize_row(self, row: str, prefix: str = "normalized") -> str:
        """Return ``row`` tagged with ``prefix``, separated by a colon."""
        return "{}:{}".format(prefix, row)

    def save_row(self, row: str, normalized: str) -> dict[str, str]:
        """Pair the raw row with its normalized form."""
        return dict(row=row, normalized=normalized)

    def keyword_code(self, name: str, punctuation: str) -> str:
        """Concatenate ``name`` and ``punctuation``."""
        return "{}{}".format(name, punctuation)

    def fail_on_bad_row(self, row: str) -> str:
        """Return ``row`` unchanged, raising when it is the literal ``"bad"``."""
        if row != "bad":
            return row
        raise RuntimeError("bad row")

    def after_each(self) -> str:
        """Bump ``state["after_count"]`` and report the new count."""
        previous_count = self.state.get("after_count", 0)
        self.state["after_count"] = previous_count + 1
        return "after:{}".format(self.state["after_count"])
CHAIN_YAML = f"""
schema: crewai.flow/v1
name: ChainFlow
@@ -727,6 +747,274 @@ methods:
flow.kickoff()
def test_code_action_renders_keyword_inputs():
"""A `code` action's `with:` block is rendered and passed as keyword arguments.

The definition points `greet` at `EachActionFlow.keyword_code` and supplies
`name` from `${state.name}` plus a literal `punctuation`.
"""
# Doubled braces are f-string escapes; the YAML receives `${state.name}`.
yaml_str = f"""
schema: crewai.flow/v1
name: CodeWithFlow
methods:
greet:
do:
call: code
ref: {__name__}:EachActionFlow.keyword_code
with:
name: "${{state.name}}"
punctuation: "!"
start: true
"""
flow = Flow.from_definition(FlowDefinition.from_yaml(yaml_str))
# The kickoff input `name` is joined with the literal punctuation.
assert flow.kickoff(inputs={"name": "hello"}) == "hello!"
def test_each_action_executes_one_nested_code_action():
"""An `each` action with a single inner `code` step yields one result per row.

Each element of `state.rows` is exposed to the inner step as `item`.
"""
# Doubled braces are f-string escapes; the YAML receives `${item}`.
yaml_str = f"""
schema: crewai.flow/v1
name: EachFlow
methods:
process_rows:
do:
call: each
in: state.rows
do:
- normalize:
call: code
ref: {__name__}:EachActionFlow.normalize_row
with:
row: "${{item}}"
start: true
"""
flow = Flow.from_definition(FlowDefinition.from_yaml(yaml_str))
# No `prefix` is supplied, so normalize_row's default "normalized" is used.
assert flow.kickoff(inputs={"rows": ["a", "b"]}) == [
"normalized:a",
"normalized:b",
]
def test_each_action_uses_iteration_outputs_between_nested_actions():
"""A later inner step can read an earlier step's result via `outputs.<name>`.

`save` receives both the current `item` and `outputs.normalize`.
"""
# Doubled braces are f-string escapes; the YAML receives `${item}` etc.
yaml_str = f"""
schema: crewai.flow/v1
name: EachFlow
methods:
process_rows:
do:
call: each
in: state.rows
do:
- normalize:
call: code
ref: {__name__}:EachActionFlow.normalize_row
with:
row: "${{item}}"
prefix: saved
- save:
call: code
ref: {__name__}:EachActionFlow.save_row
with:
row: "${{item}}"
normalized: "${{outputs.normalize}}"
start: true
"""
flow = Flow.from_definition(FlowDefinition.from_yaml(yaml_str))
# Each element of the result is the dict returned by `save` for that row.
assert flow.kickoff(inputs={"rows": ["a", "b"]}) == [
{"row": "a", "normalized": "saved:a"},
{"row": "b", "normalized": "saved:b"},
]
def test_each_action_resets_inner_outputs_between_iterations():
"""Run two inner expression steps per row and check the recorded outputs.

`leak_check` reads `outputs.previous` before `previous` has run in the same
iteration and falls back to 'empty' when it is absent.
"""
yaml_str = """
schema: crewai.flow/v1
name: EachFlow
methods:
process_rows:
do:
call: each
in: state.rows
do:
- leak_check:
call: expression
expr: "has(outputs.previous) ? outputs.previous : 'empty'"
- previous:
call: expression
expr: item
start: true
"""
flow = Flow.from_definition(FlowDefinition.from_yaml(yaml_str))
# NOTE(review): the asserted values appear to come from `previous` (expr: item);
# the value produced by `leak_check` is not part of any assertion here, so a
# leak between iterations may go unnoticed — confirm this is intended.
assert flow.kickoff(inputs={"rows": ["a", "b"]}) == ["a", "b"]
# Exactly one entry is expected, for the outer `process_rows` method.
assert flow._method_outputs == [
{"method": "process_rows", "output": ["a", "b"]}
]
def test_each_action_empty_list_returns_empty_and_listener_runs_once():
"""An `each` over an empty list yields `[]` and still triggers its listener once.

`after_each` listens to `process_rows`; the test counts
MethodExecutionFinishedEvent emissions per method name.
"""
# Doubled braces are f-string escapes; the YAML receives `${item}`.
yaml_str = f"""
schema: crewai.flow/v1
name: EachFlow
methods:
process_rows:
do:
call: each
in: state.rows
do:
- normalize:
call: code
ref: {__name__}:EachActionFlow.normalize_row
with:
row: "${{item}}"
start: true
after_each:
do:
call: code
ref: {__name__}:EachActionFlow.after_each
listen: process_rows
"""
flow = Flow.from_definition(FlowDefinition.from_yaml(yaml_str))
# Names of every method whose finished event is observed during kickoff.
events = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(MethodExecutionFinishedEvent)
def on_finished(source, event):
events.append(event.method_name)
result = flow.kickoff(inputs={"rows": []})
# The final result is the listener's return value.
assert result == "after:1"
# The each action itself produced an empty list.
assert flow.method_outputs == [[], "after:1"]
assert flow.state["after_count"] == 1
# Each method finished exactly once.
assert events.count("process_rows") == 1
assert events.count("after_each") == 1
@pytest.mark.parametrize(
    ("expr", "inputs"),
    [
        ("1", {}),
        ('"rows"', {}),
        ("state.rows", {"rows": {"a": 1}}),
    ],
)
def test_each_action_rejects_non_list_inputs(expr, inputs):
    """`each.in` expressions yielding a number, string or mapping fail at kickoff."""
    each_action = {
        "call": "each",
        "in": expr,
        "do": [{"value": {"call": "expression", "expr": "item"}}],
    }
    payload = {
        "schema": "crewai.flow/v1",
        "name": "EachFlow",
        "methods": {"process_rows": {"start": True, "do": each_action}},
    }
    # Loading the definition and building the flow succeed; only the run fails.
    flow = Flow.from_definition(FlowDefinition.from_dict(payload))

    with pytest.raises(ValueError, match="each.in must evaluate to an array"):
        flow.kickoff(inputs=inputs)
@pytest.mark.parametrize(
    "action_do",
    [
        [],
        [
            {
                "first": {"call": "expression", "expr": "item"},
                "second": {"call": "expression", "expr": "item"},
            }
        ],
        [{"1bad": {"call": "expression", "expr": "item"}}],
        [
            {"same": {"call": "expression", "expr": "item"}},
            {"same": {"call": "expression", "expr": "item"}},
        ],
    ],
)
def test_each_action_validates_inner_action_shape(action_do):
    """Malformed inner `do` lists are rejected when the definition is loaded.

    Cases: an empty list, one entry holding two named actions, an entry named
    `1bad`, and two entries sharing the name `same`.
    """
    each_action = {
        "call": "each",
        "in": "state.rows",
        "do": action_do,
    }
    payload = {
        "schema": "crewai.flow/v1",
        "name": "EachFlow",
        "methods": {"process_rows": {"start": True, "do": each_action}},
    }

    with pytest.raises(ValidationError):
        FlowDefinition.from_dict(payload)
def test_each_action_rejects_nested_each_actions():
    """An `each` action placed inside another `each` fails validation."""
    inner_each = {
        "call": "each",
        "in": "state.children",
        "do": [
            {
                "child": {
                    "call": "expression",
                    "expr": "item",
                }
            }
        ],
    }
    outer_each = {
        "call": "each",
        "in": "state.rows",
        "do": [{"nested": inner_each}],
    }
    payload = {
        "schema": "crewai.flow/v1",
        "name": "EachFlow",
        "methods": {"process_rows": {"start": True, "do": outer_each}},
    }

    with pytest.raises(ValidationError):
        FlowDefinition.from_dict(payload)
def test_each_action_failure_fails_outer_method():
"""An exception raised by an inner step propagates out of `kickoff`.

`fail_on_bad_row` raises RuntimeError("bad row") for the row "bad".
"""
# Doubled braces are f-string escapes; the YAML receives `${item}`.
yaml_str = f"""
schema: crewai.flow/v1
name: EachFlow
methods:
process_rows:
do:
call: each
in: state.rows
do:
- validate:
call: code
ref: {__name__}:EachActionFlow.fail_on_bad_row
with:
row: "${{item}}"
start: true
"""
flow = Flow.from_definition(FlowDefinition.from_yaml(yaml_str))
# The first row passes; the second triggers the failure.
with pytest.raises(RuntimeError, match="bad row"):
flow.kickoff(inputs={"rows": ["ok", "bad"]})
def test_expression_action_round_trips():
definition = FlowDefinition.from_dict(
{
@@ -830,26 +1118,6 @@ def test_tool_action_requires_module_qualname_ref():
Flow.from_definition(definition)
def test_code_action_rejects_tool_inputs():
    """A `code` action carrying a `with` block fails definition validation."""
    code_action = {
        "call": "code",
        "ref": f"{__name__}:ChainFlow.begin",
        "with": {"search_query": "ai agents"},
    }
    payload = {
        "schema": "crewai.flow/v1",
        "name": "InvalidCodeActionFlow",
        "methods": {"begin": {"start": True, "do": code_action}},
    }

    with pytest.raises(ValidationError):
        FlowDefinition.from_dict(payload)
def test_pydantic_state_from_ref_parity():
flow, result = assert_parity(PydanticStateFlow, PYDANTIC_STATE_YAML)
assert result == "count=1"