mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-06-15 05:08:11 +00:00
Compare commits
1 Commits
worktree-s
...
flow-itera
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
02eeefe5ea |
@@ -1,158 +0,0 @@
|
||||
"""SSRF-safe HTTP fetching for crewai-tools.
|
||||
|
||||
:func:`validate_url` checks the URL it is handed, but it cannot protect a
|
||||
fetch on its own: ``requests`` re-resolves DNS at connect time and follows
|
||||
redirects automatically, so a public-looking host that 302-redirects to an
|
||||
internal address (or that rebinds DNS between validation and connect) reaches
|
||||
the internal target without ever being re-checked.
|
||||
|
||||
This module closes both gaps at the connection layer:
|
||||
|
||||
* :class:`SSRFProtectedAdapter` re-runs :func:`validate_url` for every request
|
||||
it sends. ``requests.Session.send`` invokes the adapter once per redirect
|
||||
hop, so each ``Location`` target is validated before it is followed.
|
||||
* The adapter's connections validate the *actual* peer IP immediately after
|
||||
the socket connects. The IP that was authorised is therefore the IP the
|
||||
connection uses, removing the DNS time-of-check/time-of-use gap that
|
||||
:func:`validate_url`'s own ``getaddrinfo`` call leaves open.
|
||||
|
||||
Use :func:`safe_get` (or :func:`create_safe_session`) instead of calling
|
||||
``requests.get`` directly from any tool that fetches a user- or
|
||||
LLM-controlled URL.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from requests.adapters import DEFAULT_POOLBLOCK, HTTPAdapter
|
||||
from urllib3.connection import HTTPConnection, HTTPSConnection
|
||||
from urllib3.connectionpool import HTTPConnectionPool, HTTPSConnectionPool
|
||||
from urllib3.poolmanager import PoolManager
|
||||
|
||||
from crewai_tools.security.safe_path import (
|
||||
_is_escape_hatch_enabled,
|
||||
_is_private_or_reserved,
|
||||
validate_url,
|
||||
)
|
||||
|
||||
|
||||
def _assert_safe_peer(sock: Any) -> None:
|
||||
"""Raise if a connected socket's peer is a private/reserved address.
|
||||
|
||||
Validating the real peer (rather than a separately resolved IP) is what
|
||||
defeats DNS rebinding: the address we connected to is the address we check.
|
||||
"""
|
||||
if _is_escape_hatch_enabled():
|
||||
return
|
||||
try:
|
||||
peer = sock.getpeername()
|
||||
except OSError:
|
||||
return
|
||||
ip_str = str(peer[0])
|
||||
if _is_private_or_reserved(ip_str):
|
||||
raise ValueError(
|
||||
f"Connection resolved to private/reserved IP {ip_str}. "
|
||||
f"Access to internal networks is not allowed (possible SSRF via "
|
||||
f"redirect or DNS rebinding)."
|
||||
)
|
||||
|
||||
|
||||
class _SafeHTTPConnection(HTTPConnection):
|
||||
def connect(self) -> None:
|
||||
super().connect()
|
||||
_assert_safe_peer(self.sock)
|
||||
|
||||
|
||||
class _SafeHTTPSConnection(HTTPSConnection):
|
||||
def connect(self) -> None:
|
||||
super().connect()
|
||||
_assert_safe_peer(self.sock)
|
||||
|
||||
|
||||
class _SafeHTTPConnectionPool(HTTPConnectionPool):
|
||||
ConnectionCls = _SafeHTTPConnection
|
||||
|
||||
|
||||
class _SafeHTTPSConnectionPool(HTTPSConnectionPool):
|
||||
ConnectionCls = _SafeHTTPSConnection
|
||||
|
||||
|
||||
_SAFE_POOL_CLASSES = {
|
||||
"http": _SafeHTTPConnectionPool,
|
||||
"https": _SafeHTTPSConnectionPool,
|
||||
}
|
||||
|
||||
|
||||
class _SafePoolManager(PoolManager):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.pool_classes_by_scheme = _SAFE_POOL_CLASSES
|
||||
|
||||
|
||||
class SSRFProtectedAdapter(HTTPAdapter):
|
||||
"""Transport adapter that re-validates every hop and pins the peer IP.
|
||||
|
||||
``validate_url`` runs on each ``send`` — including every redirect hop
|
||||
``requests`` follows — and the underlying connections reject any socket
|
||||
that ends up connected to a private/reserved address.
|
||||
"""
|
||||
|
||||
def init_poolmanager(
|
||||
self,
|
||||
connections: int,
|
||||
maxsize: int,
|
||||
block: bool = DEFAULT_POOLBLOCK,
|
||||
**pool_kwargs: Any,
|
||||
) -> None:
|
||||
self.poolmanager = _SafePoolManager(
|
||||
num_pools=connections,
|
||||
maxsize=maxsize,
|
||||
block=block,
|
||||
**pool_kwargs,
|
||||
)
|
||||
|
||||
def send(self, request: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
# Re-validate the target of every request the session sends. Because
|
||||
# Session.send calls this once per redirect hop, each Location is
|
||||
# checked before it is followed.
|
||||
validate_url(request.url)
|
||||
return super().send(request, *args, **kwargs)
|
||||
|
||||
|
||||
def create_safe_session() -> requests.Session:
|
||||
"""Return a ``requests.Session`` that is hardened against SSRF.
|
||||
|
||||
The session validates every request (and redirect hop) and pins
|
||||
connections to the validated peer IP.
|
||||
"""
|
||||
session = requests.Session()
|
||||
adapter = SSRFProtectedAdapter()
|
||||
session.mount("http://", adapter)
|
||||
session.mount("https://", adapter)
|
||||
return session
|
||||
|
||||
|
||||
def safe_get(url: str, **kwargs: Any) -> requests.Response:
|
||||
"""Perform an SSRF-safe ``GET``.
|
||||
|
||||
Drop-in replacement for ``requests.get`` for tools that fetch a
|
||||
user- or LLM-controlled URL. Validates the initial URL and every redirect
|
||||
hop, and rejects connections that land on private/reserved addresses.
|
||||
|
||||
Args:
|
||||
url: The URL to fetch.
|
||||
**kwargs: Forwarded to ``Session.get`` (``headers``, ``cookies``,
|
||||
``timeout``, ...).
|
||||
|
||||
Returns:
|
||||
The ``requests.Response``.
|
||||
|
||||
Raises:
|
||||
ValueError: If the URL, a redirect target, or the connected peer is
|
||||
not allowed.
|
||||
"""
|
||||
validate_url(url)
|
||||
with create_safe_session() as session:
|
||||
return session.get(url, **kwargs)
|
||||
@@ -3,8 +3,9 @@ from typing import Any
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
import requests
|
||||
|
||||
from crewai_tools.security.safe_requests import safe_get
|
||||
from crewai_tools.security.safe_path import validate_url
|
||||
|
||||
|
||||
try:
|
||||
@@ -82,7 +83,8 @@ class ScrapeElementFromWebsiteTool(BaseTool):
|
||||
if website_url is None or css_element is None:
|
||||
raise ValueError("Both website_url and css_element must be provided.")
|
||||
|
||||
page = safe_get(
|
||||
website_url = validate_url(website_url)
|
||||
page = requests.get(
|
||||
website_url,
|
||||
headers=self.headers,
|
||||
cookies=self.cookies if self.cookies else {},
|
||||
|
||||
@@ -3,8 +3,9 @@ import re
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
import requests
|
||||
|
||||
from crewai_tools.security.safe_requests import safe_get
|
||||
from crewai_tools.security.safe_path import validate_url
|
||||
|
||||
|
||||
try:
|
||||
@@ -74,7 +75,8 @@ class ScrapeWebsiteTool(BaseTool):
|
||||
if website_url is None:
|
||||
raise ValueError("Website URL must be provided.")
|
||||
|
||||
page = safe_get(
|
||||
website_url = validate_url(website_url)
|
||||
page = requests.get(
|
||||
website_url,
|
||||
timeout=15,
|
||||
headers=self.headers,
|
||||
|
||||
@@ -1,124 +0,0 @@
|
||||
"""Tests for SSRF-safe HTTP fetching (redirect + DNS-rebinding protection)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import http.server
|
||||
import socketserver
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from crewai_tools.security import safe_requests
|
||||
from crewai_tools.security.safe_requests import (
|
||||
SSRFProtectedAdapter,
|
||||
create_safe_session,
|
||||
safe_get,
|
||||
)
|
||||
|
||||
|
||||
INTERNAL_BODY = b"INTERNAL-ONLY-SECRET"
|
||||
|
||||
|
||||
class _InternalHandler(http.server.BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "text/plain")
|
||||
self.end_headers()
|
||||
self.wfile.write(INTERNAL_BODY)
|
||||
|
||||
def log_message(self, *args): # silence
|
||||
pass
|
||||
|
||||
|
||||
def _serve(handler):
|
||||
"""Start a localhost server on an ephemeral port; return (server, port)."""
|
||||
server = socketserver.TCPServer(("127.0.0.1", 0), handler)
|
||||
port = server.server_address[1]
|
||||
threading.Thread(target=server.serve_forever, daemon=True).start()
|
||||
return server, port
|
||||
|
||||
|
||||
class TestRedirectRevalidation:
|
||||
"""Layer 1: validate_url runs on every send, including each redirect hop.
|
||||
|
||||
``requests.Session.send`` calls ``adapter.send`` once per redirect hop, so
|
||||
re-validating in ``send`` is what blocks a 302 to an internal target.
|
||||
"""
|
||||
|
||||
def test_adapter_revalidates_before_any_network_call(self, monkeypatch):
|
||||
calls: list[str] = []
|
||||
|
||||
def spy(url: str) -> str:
|
||||
calls.append(url)
|
||||
if "internal.target" in url:
|
||||
raise ValueError("URL resolves to private/reserved IP")
|
||||
return url
|
||||
|
||||
monkeypatch.setattr(safe_requests, "validate_url", spy)
|
||||
|
||||
adapter = SSRFProtectedAdapter()
|
||||
# Internal redirect target: send() must reject it before ever calling
|
||||
# the real transport (super().send is never reached).
|
||||
req = requests.Request("GET", "http://internal.target/").prepare()
|
||||
with pytest.raises(ValueError, match="private/reserved"):
|
||||
adapter.send(req)
|
||||
assert calls == ["http://internal.target/"]
|
||||
|
||||
def test_session_mounts_protected_adapter(self):
|
||||
session = create_safe_session()
|
||||
assert isinstance(session.get_adapter("http://x"), SSRFProtectedAdapter)
|
||||
assert isinstance(session.get_adapter("https://x"), SSRFProtectedAdapter)
|
||||
|
||||
|
||||
class _FakeSock:
|
||||
def __init__(self, peer):
|
||||
self._peer = peer
|
||||
|
||||
def getpeername(self):
|
||||
return self._peer
|
||||
|
||||
|
||||
class TestConnectionPeerGuard:
|
||||
"""Layer 2: the connection rejects an internal peer IP at connect time.
|
||||
|
||||
This is what closes the validate-then-connect DNS-rebinding gap — the IP
|
||||
the socket actually connected to is the IP that gets checked, so a host
|
||||
that resolved public at validation time but connects internal is blocked.
|
||||
"""
|
||||
|
||||
def test_safe_get_blocks_direct_internal(self):
|
||||
# No network: validate_url rejects 127.0.0.1 at the URL layer first.
|
||||
with pytest.raises(ValueError, match="private/reserved"):
|
||||
safe_get("http://127.0.0.1:9/", timeout=10)
|
||||
|
||||
def test_assert_safe_peer_blocks_private(self):
|
||||
with pytest.raises(ValueError, match="private/reserved"):
|
||||
safe_requests._assert_safe_peer(_FakeSock(("127.0.0.1", 80)))
|
||||
|
||||
def test_assert_safe_peer_blocks_metadata(self):
|
||||
with pytest.raises(ValueError, match="private/reserved"):
|
||||
safe_requests._assert_safe_peer(_FakeSock(("169.254.169.254", 80)))
|
||||
|
||||
def test_assert_safe_peer_allows_public(self):
|
||||
# A public IP must not raise.
|
||||
safe_requests._assert_safe_peer(_FakeSock(("93.184.216.34", 80)))
|
||||
|
||||
def test_assert_safe_peer_respects_escape_hatch(self, monkeypatch):
|
||||
monkeypatch.setenv("CREWAI_TOOLS_ALLOW_UNSAFE_PATHS", "true")
|
||||
# No raise even for a private peer when the escape hatch is on.
|
||||
safe_requests._assert_safe_peer(_FakeSock(("127.0.0.1", 80)))
|
||||
|
||||
def test_connection_validates_peer_after_connect(self, monkeypatch):
|
||||
"""_SafeHTTPConnection.connect runs the peer guard after connecting."""
|
||||
conn = safe_requests._SafeHTTPConnection("example.com")
|
||||
|
||||
def fake_super_connect(self):
|
||||
# Simulate a rebind: we connected to an internal address.
|
||||
self.sock = _FakeSock(("127.0.0.1", 80))
|
||||
|
||||
monkeypatch.setattr(
|
||||
safe_requests.HTTPConnection, "connect", fake_super_connect
|
||||
)
|
||||
with pytest.raises(ValueError, match="private/reserved"):
|
||||
conn.connect()
|
||||
@@ -11,9 +11,17 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Literal as TypingLiteral
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_serializer, model_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
RootModel,
|
||||
field_serializer,
|
||||
model_validator,
|
||||
)
|
||||
import yaml
|
||||
|
||||
from crewai.flow.conversational_definition import (
|
||||
@@ -25,6 +33,7 @@ from crewai.flow.conversational_definition import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
FlowDefinitionCondition = str | dict[str, Any]
|
||||
_STEP_NAME_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
|
||||
|
||||
__all__ = [
|
||||
"FlowActionDefinition",
|
||||
@@ -35,6 +44,8 @@ __all__ = [
|
||||
"FlowDefinition",
|
||||
"FlowDefinitionCondition",
|
||||
"FlowDefinitionDiagnostic",
|
||||
"FlowEachActionDefinition",
|
||||
"FlowEachInnerActionDefinition",
|
||||
"FlowExpressionActionDefinition",
|
||||
"FlowHumanFeedbackDefinition",
|
||||
"FlowMethodDefinition",
|
||||
@@ -148,10 +159,11 @@ class FlowHumanFeedbackDefinition(BaseModel):
|
||||
class FlowCodeActionDefinition(BaseModel):
|
||||
"""A Flow method action that executes importable Python code."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
call: TypingLiteral["code"] = "code"
|
||||
ref: str
|
||||
with_: dict[str, Any] | None = Field(default=None, alias="with")
|
||||
|
||||
|
||||
class FlowToolActionDefinition(BaseModel):
|
||||
@@ -173,14 +185,75 @@ class FlowExpressionActionDefinition(BaseModel):
|
||||
expr: str
|
||||
|
||||
|
||||
FlowActionDefinition = (
|
||||
FlowInnerActionDefinition = (
|
||||
FlowCodeActionDefinition | FlowToolActionDefinition | FlowExpressionActionDefinition
|
||||
)
|
||||
|
||||
|
||||
class FlowEachInnerActionDefinition(RootModel[dict[str, FlowInnerActionDefinition]]):
|
||||
"""One named action inside an ``each`` composite action."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return next(iter(self.root))
|
||||
|
||||
@property
|
||||
def action(self) -> FlowInnerActionDefinition:
|
||||
return next(iter(self.root.values()))
|
||||
|
||||
|
||||
class FlowEachActionDefinition(BaseModel):
|
||||
"""A composite action that runs a sequential mini-pipeline for each item."""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
call: TypingLiteral["each"]
|
||||
in_: str = Field(alias="in")
|
||||
do: list[FlowEachInnerActionDefinition]
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _validate_inner_action_list(cls, data: Any) -> Any:
|
||||
if not isinstance(data, dict) or "do" not in data:
|
||||
return data
|
||||
|
||||
inner_actions = data["do"]
|
||||
if not isinstance(inner_actions, list) or not inner_actions:
|
||||
raise ValueError("each.do must contain at least one action")
|
||||
|
||||
seen: set[str] = set()
|
||||
for inner_action in inner_actions:
|
||||
if isinstance(inner_action, FlowEachInnerActionDefinition):
|
||||
action_mapping = inner_action.root
|
||||
elif isinstance(inner_action, dict):
|
||||
action_mapping = inner_action
|
||||
else:
|
||||
raise ValueError("each.do entries must be one-key mappings")
|
||||
|
||||
if len(action_mapping) != 1:
|
||||
raise ValueError("each.do entries must be one-key mappings")
|
||||
|
||||
name = next(iter(action_mapping))
|
||||
_validate_step_name(name, field="each.do action names")
|
||||
if name in seen:
|
||||
raise ValueError(f"each.do action names must be unique: {name!r}")
|
||||
seen.add(name)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
FlowActionDefinition = (
|
||||
FlowCodeActionDefinition
|
||||
| FlowToolActionDefinition
|
||||
| FlowExpressionActionDefinition
|
||||
| FlowEachActionDefinition
|
||||
)
|
||||
|
||||
|
||||
class FlowMethodDefinition(BaseModel):
|
||||
"""Static definition of one Flow method and its execution roles."""
|
||||
|
||||
description: str | None = None
|
||||
do: FlowActionDefinition
|
||||
start: bool | FlowDefinitionCondition | None = None
|
||||
listen: FlowDefinitionCondition | None = None
|
||||
@@ -227,6 +300,12 @@ class FlowDefinition(BaseModel):
|
||||
methods: dict[str, FlowMethodDefinition] = Field(default_factory=dict)
|
||||
diagnostics: list[FlowDefinitionDiagnostic] = Field(default_factory=list)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_method_names(self) -> FlowDefinition:
|
||||
for method_name in self.methods:
|
||||
_validate_step_name(method_name, field="Flow method names")
|
||||
return self
|
||||
|
||||
def to_dict(self, *, exclude_none: bool = True) -> dict[str, Any]:
|
||||
"""Serialize the definition to a JSON/YAML-ready dictionary."""
|
||||
return self.model_dump(by_alias=True, exclude_none=exclude_none, mode="json")
|
||||
@@ -369,6 +448,11 @@ def _deserialize_diagnostics(value: Any) -> list[FlowDefinitionDiagnostic]:
|
||||
return [FlowDefinitionDiagnostic.model_validate(item) for item in value or []]
|
||||
|
||||
|
||||
def _validate_step_name(name: str, *, field: str) -> None:
|
||||
if not isinstance(name, str) or not _STEP_NAME_PATTERN.fullmatch(name):
|
||||
raise ValueError(f"{field} must match {_STEP_NAME_PATTERN.pattern}")
|
||||
|
||||
|
||||
def _merge_diagnostics(
|
||||
*diagnostic_groups: list[FlowDefinitionDiagnostic],
|
||||
) -> list[FlowDefinitionDiagnostic]:
|
||||
|
||||
@@ -121,11 +121,8 @@ from crewai.flow.human_feedback import (
|
||||
)
|
||||
from crewai.flow.input_provider import InputProvider
|
||||
from crewai.flow.persistence.base import FlowPersistence
|
||||
from crewai.flow.runtime._resolvers import (
|
||||
resolve_action,
|
||||
resolve_instance_ref,
|
||||
resolve_ref,
|
||||
)
|
||||
from crewai.flow.runtime._actions import build_action
|
||||
from crewai.flow.runtime._refs import resolve_instance_ref, resolve_ref
|
||||
from crewai.flow.types import (
|
||||
FlowExecutionData,
|
||||
FlowMethodName,
|
||||
@@ -1092,9 +1089,9 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
self._methods.update(methods)
|
||||
|
||||
def _action_bound_methods(self) -> dict[FlowMethodName, Callable[..., Any]]:
|
||||
def resolve(name: str, definition: FlowMethodDefinition) -> Callable[..., Any]:
|
||||
def build(name: str, definition: FlowMethodDefinition) -> Callable[..., Any]:
|
||||
try:
|
||||
return resolve_action(self, definition.do)
|
||||
return build_action(self, definition.do)
|
||||
except Exception as e:
|
||||
unresolved.append(f"{name}: {e}")
|
||||
return lambda *args, **kwargs: None
|
||||
@@ -1102,9 +1099,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
methods: dict[FlowMethodName, Callable[..., Any]] = {}
|
||||
unresolved: list[str] = []
|
||||
for method_name, method_definition in self._definition.methods.items():
|
||||
methods[FlowMethodName(method_name)] = resolve(
|
||||
method_name, method_definition
|
||||
)
|
||||
methods[FlowMethodName(method_name)] = build(method_name, method_definition)
|
||||
if unresolved:
|
||||
raise ValueError(
|
||||
f"Cannot build flow {self._definition.name!r} from its definition; "
|
||||
|
||||
48
lib/crewai/src/crewai/flow/runtime/_actions/__init__.py
Normal file
48
lib/crewai/src/crewai/flow/runtime/_actions/__init__.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""Build FlowDefinition actions into live runtime callables."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.flow.flow_definition import (
|
||||
FlowActionDefinition,
|
||||
FlowInnerActionDefinition,
|
||||
)
|
||||
from crewai.flow.runtime._actions._base import ActionHandlerRegistry
|
||||
from crewai.flow.runtime._actions._code import CodeActionHandler
|
||||
from crewai.flow.runtime._actions._each import EachActionHandler
|
||||
from crewai.flow.runtime._actions._expression import ExpressionActionHandler
|
||||
from crewai.flow.runtime._actions._tool import ToolActionHandler
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.runtime import Flow
|
||||
|
||||
|
||||
__all__ = [
|
||||
"build_action",
|
||||
]
|
||||
|
||||
|
||||
_SIMPLE_ACTION_HANDLERS = (
|
||||
CodeActionHandler(),
|
||||
ToolActionHandler(),
|
||||
ExpressionActionHandler(),
|
||||
)
|
||||
|
||||
_SIMPLE_ACTION_REGISTRY = ActionHandlerRegistry[FlowInnerActionDefinition](
|
||||
_SIMPLE_ACTION_HANDLERS
|
||||
)
|
||||
|
||||
_ACTION_REGISTRY = ActionHandlerRegistry[FlowActionDefinition](
|
||||
(
|
||||
*_SIMPLE_ACTION_HANDLERS,
|
||||
EachActionHandler(_SIMPLE_ACTION_REGISTRY),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def build_action(flow: Flow[Any], action: FlowActionDefinition) -> Callable[..., Any]:
|
||||
"""Turn one `do:` action into the callable the flow runs for that node."""
|
||||
return _ACTION_REGISTRY.build(flow, action)
|
||||
39
lib/crewai/src/crewai/flow/runtime/_actions/_base.py
Normal file
39
lib/crewai/src/crewai/flow/runtime/_actions/_base.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""Shared action handler contracts."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.runtime import Flow
|
||||
|
||||
|
||||
ActionT = TypeVar("ActionT", bound=BaseModel)
|
||||
ResolvedAction = Callable[..., Any]
|
||||
|
||||
|
||||
class ActionHandler(Protocol[ActionT]):
|
||||
"""Handler for one concrete FlowDefinition action type."""
|
||||
|
||||
action_type: type[ActionT]
|
||||
|
||||
def build(self, flow: Flow[Any], action: ActionT) -> ResolvedAction:
|
||||
"""Build the callable executed by the flow."""
|
||||
|
||||
|
||||
class ActionHandlerRegistry(Generic[ActionT]):
|
||||
"""Build action callables with an ordered set of typed handlers."""
|
||||
|
||||
def __init__(self, handlers: Iterable[ActionHandler[Any]]) -> None:
|
||||
self._handlers = tuple(handlers)
|
||||
|
||||
def build(self, flow: Flow[Any], action: ActionT) -> ResolvedAction:
|
||||
for handler in self._handlers:
|
||||
if isinstance(action, handler.action_type):
|
||||
return handler.build(flow, action)
|
||||
call = getattr(action, "call", None)
|
||||
raise ValueError(f"unknown call type {call!r}")
|
||||
51
lib/crewai/src/crewai/flow/runtime/_actions/_code.py
Normal file
51
lib/crewai/src/crewai/flow/runtime/_actions/_code.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""Handler for ``call: code`` FlowDefinition actions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import functools
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from crewai.flow.flow_definition import FlowCodeActionDefinition
|
||||
from crewai.flow.runtime._actions._base import ResolvedAction
|
||||
from crewai.flow.runtime._actions._runtime import LOCAL_CONTEXT_KWARG
|
||||
from crewai.flow.runtime._expressions import render_with_block
|
||||
from crewai.flow.runtime._refs import InvalidRefError, resolve_ref
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.runtime import Flow
|
||||
|
||||
|
||||
class CodeActionHandler:
|
||||
"""Build importable Python callables and bind them to the running flow."""
|
||||
|
||||
action_type = FlowCodeActionDefinition
|
||||
|
||||
def build(
|
||||
self, flow: Flow[Any], action: FlowCodeActionDefinition
|
||||
) -> ResolvedAction:
|
||||
handler = _resolve_code_handler(flow, action)
|
||||
|
||||
def run_code(*args: Any, **kwargs: Any) -> Any:
|
||||
local_context = kwargs.pop(LOCAL_CONTEXT_KWARG, None)
|
||||
if action.with_ is None:
|
||||
return handler(*args, **kwargs)
|
||||
return handler(
|
||||
**render_with_block(flow, action.with_, local_context=local_context)
|
||||
)
|
||||
|
||||
return functools.update_wrapper(run_code, handler)
|
||||
|
||||
|
||||
def _resolve_code_handler(
|
||||
flow: Flow[Any], action: FlowCodeActionDefinition
|
||||
) -> Callable[..., Any]:
|
||||
ref = action.ref
|
||||
target = resolve_ref(ref, field="do")
|
||||
if not callable(target):
|
||||
raise InvalidRefError(f"invalid do ref {ref!r}; object is not callable")
|
||||
handler = cast(Callable[..., Any], target)
|
||||
if getattr(handler, "__self__", None) is None:
|
||||
handler = handler.__get__(flow, type(flow))
|
||||
return handler
|
||||
73
lib/crewai/src/crewai/flow/runtime/_actions/_each.py
Normal file
73
lib/crewai/src/crewai/flow/runtime/_actions/_each.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""Handler for ``call: each`` FlowDefinition actions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.flow.flow_definition import (
|
||||
FlowEachActionDefinition,
|
||||
FlowEachInnerActionDefinition,
|
||||
FlowInnerActionDefinition,
|
||||
)
|
||||
from crewai.flow.runtime._actions._base import (
|
||||
ActionHandlerRegistry,
|
||||
ResolvedAction,
|
||||
)
|
||||
from crewai.flow.runtime._actions._runtime import (
|
||||
LOCAL_CONTEXT_KWARG,
|
||||
ensure_array,
|
||||
invoke_callable,
|
||||
)
|
||||
from crewai.flow.runtime._expressions import evaluate_expression
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.runtime import Flow
|
||||
|
||||
|
||||
class EachActionHandler:
|
||||
"""Build a sequential mini-pipeline for every item in an array."""
|
||||
|
||||
action_type = FlowEachActionDefinition
|
||||
|
||||
def __init__(
|
||||
self, inner_registry: ActionHandlerRegistry[FlowInnerActionDefinition]
|
||||
) -> None:
|
||||
self._inner_registry = inner_registry
|
||||
|
||||
def build(
|
||||
self, flow: Flow[Any], action: FlowEachActionDefinition
|
||||
) -> ResolvedAction:
|
||||
inner_actions = [
|
||||
(inner_action.name, self._resolve_inner_action(flow, inner_action))
|
||||
for inner_action in action.do
|
||||
]
|
||||
|
||||
async def run_each(*_args: Any, **_kwargs: Any) -> list[Any]:
|
||||
items = ensure_array(evaluate_expression(flow, action.in_))
|
||||
results: list[Any] = []
|
||||
for item in items:
|
||||
local_outputs: dict[str, Any] = {}
|
||||
last_output: Any = None
|
||||
for name, run_inner_action in inner_actions:
|
||||
last_output = await run_inner_action(
|
||||
{"item": item, "outputs": local_outputs}
|
||||
)
|
||||
local_outputs[name] = last_output
|
||||
results.append(last_output)
|
||||
return results
|
||||
|
||||
return run_each
|
||||
|
||||
def _resolve_inner_action(
|
||||
self, flow: Flow[Any], inner_action: FlowEachInnerActionDefinition
|
||||
) -> Callable[[dict[str, Any]], Any]:
|
||||
run_action = self._inner_registry.build(flow, inner_action.action)
|
||||
|
||||
async def run_inner_action(local_context: dict[str, Any]) -> Any:
|
||||
return await invoke_callable(
|
||||
run_action, **{LOCAL_CONTEXT_KWARG: local_context}
|
||||
)
|
||||
|
||||
return run_inner_action
|
||||
29
lib/crewai/src/crewai/flow/runtime/_actions/_expression.py
Normal file
29
lib/crewai/src/crewai/flow/runtime/_actions/_expression.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""Handler for ``call: expression`` FlowDefinition actions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.flow.flow_definition import FlowExpressionActionDefinition
|
||||
from crewai.flow.runtime._actions._base import ResolvedAction
|
||||
from crewai.flow.runtime._actions._runtime import LOCAL_CONTEXT_KWARG
|
||||
from crewai.flow.runtime._expressions import evaluate_expression
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.runtime import Flow
|
||||
|
||||
|
||||
class ExpressionActionHandler:
|
||||
"""Build CEL expression actions."""
|
||||
|
||||
action_type = FlowExpressionActionDefinition
|
||||
|
||||
def build(
|
||||
self, flow: Flow[Any], action: FlowExpressionActionDefinition
|
||||
) -> ResolvedAction:
|
||||
def run_expression(*_args: Any, **kwargs: Any) -> Any:
|
||||
local_context = kwargs.pop(LOCAL_CONTEXT_KWARG, None)
|
||||
return evaluate_expression(flow, action.expr, local_context=local_context)
|
||||
|
||||
return run_expression
|
||||
28
lib/crewai/src/crewai/flow/runtime/_actions/_runtime.py
Normal file
28
lib/crewai/src/crewai/flow/runtime/_actions/_runtime.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Runtime helpers shared by action resolvers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import inspect
|
||||
from typing import Any
|
||||
|
||||
|
||||
LOCAL_CONTEXT_KWARG = "__flow_definition_local_context"
|
||||
|
||||
|
||||
async def invoke_callable(
|
||||
handler: Callable[..., Any], *args: Any, **kwargs: Any
|
||||
) -> Any:
|
||||
if inspect.iscoroutinefunction(handler):
|
||||
result = await handler(*args, **kwargs)
|
||||
else:
|
||||
result = handler(*args, **kwargs)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
return result
|
||||
|
||||
|
||||
def ensure_array(value: Any) -> list[Any]:
|
||||
if isinstance(value, list):
|
||||
return value
|
||||
raise ValueError("each.in must evaluate to an array")
|
||||
52
lib/crewai/src/crewai/flow/runtime/_actions/_tool.py
Normal file
52
lib/crewai/src/crewai/flow/runtime/_actions/_tool.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Handler for ``call: tool`` FlowDefinition actions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from crewai.flow.flow_definition import FlowToolActionDefinition
|
||||
from crewai.flow.runtime._actions._base import ResolvedAction
|
||||
from crewai.flow.runtime._actions._runtime import LOCAL_CONTEXT_KWARG
|
||||
from crewai.flow.runtime._expressions import render_with_block
|
||||
from crewai.flow.runtime._refs import InvalidRefError, resolve_ref
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.runtime import Flow
|
||||
|
||||
|
||||
class ToolActionHandler:
|
||||
"""Build and instantiate CrewAI tool actions."""
|
||||
|
||||
action_type = FlowToolActionDefinition
|
||||
|
||||
def build(
|
||||
self, flow: Flow[Any], action: FlowToolActionDefinition
|
||||
) -> ResolvedAction:
|
||||
target = resolve_ref(action.ref, field="do")
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
if not (inspect.isclass(target) and issubclass(target, BaseTool)):
|
||||
raise InvalidRefError(
|
||||
f"invalid tool ref {action.ref!r}; expected a BaseTool class"
|
||||
)
|
||||
|
||||
try:
|
||||
tool_cls = cast(Callable[[], BaseTool], target)
|
||||
tool = tool_cls()
|
||||
except Exception as e:
|
||||
raise InvalidRefError(
|
||||
f"cannot instantiate tool ref {action.ref!r} without arguments: {e}"
|
||||
) from e
|
||||
|
||||
tool_kwargs = action.with_ or {}
|
||||
|
||||
def run_tool(*_args: Any, **kwargs: Any) -> Any:
|
||||
local_context = kwargs.pop(LOCAL_CONTEXT_KWARG, None)
|
||||
return tool.run(
|
||||
**render_with_block(flow, tool_kwargs, local_context=local_context)
|
||||
)
|
||||
|
||||
return run_tool
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import dataclasses
|
||||
from itertools import pairwise
|
||||
import json
|
||||
@@ -25,25 +24,36 @@ class FlowExpressionError(ValueError):
|
||||
"""A FlowDefinition expression failed to parse or evaluate."""
|
||||
|
||||
|
||||
def render_with_block(flow: Flow[Any], value: Any) -> Any:
|
||||
def render_with_block(
|
||||
flow: Flow[Any], value: Any, local_context: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
"""Render CEL expressions inside a FlowDefinition ``with:`` payload."""
|
||||
context = _expression_context(flow)
|
||||
context = _expression_context(flow, local_context=local_context)
|
||||
return _render_value(value, context)
|
||||
|
||||
|
||||
def evaluate_expression(flow: Flow[Any], expression: str) -> Any:
|
||||
def evaluate_expression(
|
||||
flow: Flow[Any], expression: str, local_context: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
"""Evaluate a FlowDefinition CEL expression against runtime context."""
|
||||
expression = expression.strip()
|
||||
if not expression:
|
||||
raise FlowExpressionError("empty CEL expression")
|
||||
return _eval_cel(expression, _expression_context(flow))
|
||||
return _eval_cel(expression, _expression_context(flow, local_context=local_context))
|
||||
|
||||
|
||||
def _expression_context(flow: Flow[Any]) -> dict[str, Any]:
|
||||
return {
|
||||
def _expression_context(
|
||||
flow: Flow[Any], local_context: dict[str, Any] | None = None
|
||||
) -> dict[str, Any]:
|
||||
context = {
|
||||
"state": flow._copy_and_serialize_state(),
|
||||
"outputs": _outputs_by_name(flow._method_outputs),
|
||||
}
|
||||
if local_context:
|
||||
context.update(
|
||||
{key: _to_json_safe(value) for key, value in local_context.items()}
|
||||
)
|
||||
return context
|
||||
|
||||
|
||||
def _outputs_by_name(method_outputs: list[Any]) -> dict[str, Any]:
|
||||
@@ -54,15 +64,24 @@ def _outputs_by_name(method_outputs: list[Any]) -> dict[str, Any]:
|
||||
if isinstance(entry, dict) and "output" in entry:
|
||||
method = str(entry.get("method", ""))
|
||||
output = entry["output"]
|
||||
output = copy.deepcopy(output)
|
||||
if isinstance(output, BaseModel):
|
||||
output = output.model_dump(mode="json")
|
||||
elif dataclasses.is_dataclass(output) and not isinstance(output, type):
|
||||
output = dataclasses.asdict(output)
|
||||
outputs[method] = output
|
||||
outputs[method] = _to_json_safe(output)
|
||||
return outputs
|
||||
|
||||
|
||||
def _to_json_safe(value: Any) -> Any:
|
||||
if isinstance(value, BaseModel):
|
||||
return value.model_dump(mode="json")
|
||||
if dataclasses.is_dataclass(value) and not isinstance(value, type):
|
||||
return dataclasses.asdict(value)
|
||||
if isinstance(value, dict):
|
||||
return {key: _to_json_safe(item) for key, item in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [_to_json_safe(item) for item in value]
|
||||
if isinstance(value, tuple):
|
||||
return [_to_json_safe(item) for item in value]
|
||||
return value
|
||||
|
||||
|
||||
def _render_value(value: Any, context: dict[str, Any]) -> Any:
|
||||
if isinstance(value, str):
|
||||
return _render_string(value, context)
|
||||
|
||||
38
lib/crewai/src/crewai/flow/runtime/_refs.py
Normal file
38
lib/crewai/src/crewai/flow/runtime/_refs.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Resolution of ``module:qualname`` refs into live Python objects."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
from operator import attrgetter
|
||||
from typing import Any
|
||||
|
||||
|
||||
class InvalidRefError(ValueError):
|
||||
"""A definition ref that cannot be resolved to a live object."""
|
||||
|
||||
|
||||
def resolve_ref(ref: str, *, field: str) -> Any:
|
||||
"""Import the object a definition's `module:qualname` ref points to."""
|
||||
module_name, _, qualname = ref.partition(":")
|
||||
if "<" in ref or not module_name or not qualname:
|
||||
raise InvalidRefError(
|
||||
f"invalid {field} ref {ref!r}; expected 'module:qualname'"
|
||||
)
|
||||
try:
|
||||
return attrgetter(qualname)(importlib.import_module(module_name))
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise InvalidRefError(f"unresolvable {field} ref {ref!r}") from e
|
||||
|
||||
|
||||
def resolve_instance_ref(ref: str, *, field: str) -> Any:
|
||||
"""Resolve a ref, auto-instantiating a no-arg class into an instance."""
|
||||
target = resolve_ref(ref, field=field)
|
||||
if not inspect.isclass(target):
|
||||
return target
|
||||
try:
|
||||
return target()
|
||||
except Exception as e:
|
||||
raise InvalidRefError(
|
||||
f"cannot instantiate {field} ref {ref!r} without arguments: {e}"
|
||||
) from e
|
||||
@@ -1,116 +0,0 @@
|
||||
"""Resolution of FlowDefinition refs (``module:qualname``) into live objects.
|
||||
|
||||
Every ref-shaped value in a definition — ``do`` actions, ``state.ref``,
|
||||
``config.input_provider``, ``human_feedback.provider`` — resolves through
|
||||
:func:`resolve_ref`. Failures are loud and name the field and the ref.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import importlib
|
||||
import inspect
|
||||
from operator import attrgetter
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from crewai.flow.flow_definition import (
|
||||
FlowActionDefinition,
|
||||
FlowCodeActionDefinition,
|
||||
FlowExpressionActionDefinition,
|
||||
FlowToolActionDefinition,
|
||||
)
|
||||
from crewai.flow.runtime._expressions import evaluate_expression, render_with_block
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.runtime import Flow
|
||||
|
||||
|
||||
class InvalidRefError(ValueError):
|
||||
"""A definition ref that cannot be resolved to a live object."""
|
||||
|
||||
|
||||
def resolve_ref(ref: str, *, field: str) -> Any:
|
||||
"""Import the object a definition's `module:qualname` ref points to."""
|
||||
module_name, _, qualname = ref.partition(":")
|
||||
if "<" in ref or not module_name or not qualname:
|
||||
raise InvalidRefError(
|
||||
f"invalid {field} ref {ref!r}; expected 'module:qualname'"
|
||||
)
|
||||
try:
|
||||
return attrgetter(qualname)(importlib.import_module(module_name))
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise InvalidRefError(f"unresolvable {field} ref {ref!r}") from e
|
||||
|
||||
|
||||
def resolve_instance_ref(ref: str, *, field: str) -> Any:
|
||||
"""Resolve a ref, auto-instantiating a no-arg class into an instance."""
|
||||
target = resolve_ref(ref, field=field)
|
||||
if not inspect.isclass(target):
|
||||
return target
|
||||
try:
|
||||
return target()
|
||||
except Exception as e:
|
||||
raise InvalidRefError(
|
||||
f"cannot instantiate {field} ref {ref!r} without arguments: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
def _resolve_code_action(
|
||||
flow: Flow[Any], action: FlowCodeActionDefinition
|
||||
) -> Callable[..., Any]:
|
||||
ref = action.ref
|
||||
target = resolve_ref(ref, field="do")
|
||||
if not callable(target):
|
||||
raise InvalidRefError(f"invalid do ref {ref!r}; object is not callable")
|
||||
handler = cast(Callable[..., Any], target)
|
||||
if getattr(handler, "__self__", None) is None:
|
||||
handler = handler.__get__(flow, type(flow))
|
||||
return handler
|
||||
|
||||
|
||||
def _resolve_tool_action(
|
||||
flow: Flow[Any], action: FlowToolActionDefinition
|
||||
) -> Callable[..., Any]:
|
||||
target = resolve_ref(action.ref, field="do")
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
if not (inspect.isclass(target) and issubclass(target, BaseTool)):
|
||||
raise InvalidRefError(
|
||||
f"invalid tool ref {action.ref!r}; expected a BaseTool class"
|
||||
)
|
||||
|
||||
try:
|
||||
tool_cls = cast(Callable[[], BaseTool], target)
|
||||
tool = tool_cls()
|
||||
except Exception as e:
|
||||
raise InvalidRefError(
|
||||
f"cannot instantiate tool ref {action.ref!r} without arguments: {e}"
|
||||
) from e
|
||||
|
||||
tool_kwargs = action.with_ or {}
|
||||
|
||||
def run_tool(*_args: Any, **_kwargs: Any) -> Any:
|
||||
return tool.run(**render_with_block(flow, tool_kwargs))
|
||||
|
||||
return run_tool
|
||||
|
||||
|
||||
def _resolve_expression_action(
|
||||
flow: Flow[Any], action: FlowExpressionActionDefinition
|
||||
) -> Callable[..., Any]:
|
||||
def run_expression(*_args: Any, **_kwargs: Any) -> Any:
|
||||
return evaluate_expression(flow, action.expr)
|
||||
|
||||
return run_expression
|
||||
|
||||
|
||||
def resolve_action(flow: Flow[Any], action: FlowActionDefinition) -> Callable[..., Any]:
|
||||
"""Turn one `do:` action into the callable the flow runs for that node."""
|
||||
if action.call == "code":
|
||||
return _resolve_code_action(flow, action)
|
||||
if action.call == "tool":
|
||||
return _resolve_tool_action(flow, action)
|
||||
if action.call == "expression":
|
||||
return _resolve_expression_action(flow, action)
|
||||
raise ValueError(f"unknown call type {action.call!r}")
|
||||
@@ -44,6 +44,8 @@ def test_flow_public_exports_are_explicit():
|
||||
"FlowDefinition",
|
||||
"FlowDefinitionCondition",
|
||||
"FlowDefinitionDiagnostic",
|
||||
"FlowEachActionDefinition",
|
||||
"FlowEachInnerActionDefinition",
|
||||
"FlowExpressionActionDefinition",
|
||||
"FlowHumanFeedbackDefinition",
|
||||
"FlowMethodDefinition",
|
||||
@@ -432,6 +434,73 @@ def test_flow_definition_round_trips_json_and_yaml():
|
||||
assert yaml_round_trip.methods["decide"].listen == "begin"
|
||||
|
||||
|
||||
def test_each_action_round_trips_json_and_yaml():
|
||||
definition = flow_definition.FlowDefinition.from_dict(
|
||||
{
|
||||
"schema": "crewai.flow/v1",
|
||||
"name": "EachFlow",
|
||||
"methods": {
|
||||
"process_rows": {
|
||||
"description": "Process every loaded row.",
|
||||
"start": True,
|
||||
"do": {
|
||||
"call": "each",
|
||||
"in": "state.rows",
|
||||
"do": [
|
||||
{
|
||||
"normalize": {
|
||||
"call": "tool",
|
||||
"ref": "my_tools:NormalizeRowTool",
|
||||
"with": {"row": "${ item }"},
|
||||
}
|
||||
},
|
||||
{
|
||||
"save": {
|
||||
"call": "code",
|
||||
"ref": "my_flow:save_row",
|
||||
"with": {
|
||||
"row": "${ item }",
|
||||
"normalized": "${ outputs.normalize }",
|
||||
},
|
||||
}
|
||||
},
|
||||
],
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
json_round_trip = flow_definition.FlowDefinition.from_json(definition.to_json())
|
||||
yaml_round_trip = flow_definition.FlowDefinition.from_yaml(definition.to_yaml())
|
||||
|
||||
assert json_round_trip.to_dict() == definition.to_dict()
|
||||
assert yaml_round_trip.to_dict() == definition.to_dict()
|
||||
assert yaml_round_trip.methods["process_rows"].description == (
|
||||
"Process every loaded row."
|
||||
)
|
||||
assert yaml_round_trip.methods["process_rows"].do.call == "each"
|
||||
|
||||
|
||||
def test_flow_definition_rejects_invalid_method_names():
|
||||
with pytest.raises(ValueError, match="Flow method names must match"):
|
||||
flow_definition.FlowDefinition.from_dict(
|
||||
{
|
||||
"schema": "crewai.flow/v1",
|
||||
"name": "InvalidMethodNameFlow",
|
||||
"methods": {
|
||||
"process-rows": {
|
||||
"start": True,
|
||||
"do": {
|
||||
"call": "expression",
|
||||
"expr": "'done'",
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_flow_definition_detects_persist_metadata():
|
||||
@persist(verbose=True)
|
||||
class PersistedFlow(Flow[dict]):
|
||||
|
||||
@@ -67,6 +67,26 @@ class ToolInputFlow(Flow):
|
||||
return {"query": "ai agents", "suffix": " news"}
|
||||
|
||||
|
||||
class EachActionFlow(Flow):
|
||||
def normalize_row(self, row: str, prefix: str = "normalized") -> str:
|
||||
return f"{prefix}:{row}"
|
||||
|
||||
def save_row(self, row: str, normalized: str) -> dict[str, str]:
|
||||
return {"row": row, "normalized": normalized}
|
||||
|
||||
def keyword_code(self, name: str, punctuation: str) -> str:
|
||||
return f"{name}{punctuation}"
|
||||
|
||||
def fail_on_bad_row(self, row: str) -> str:
|
||||
if row == "bad":
|
||||
raise RuntimeError("bad row")
|
||||
return row
|
||||
|
||||
def after_each(self) -> str:
|
||||
self.state["after_count"] = self.state.get("after_count", 0) + 1
|
||||
return f"after:{self.state['after_count']}"
|
||||
|
||||
|
||||
CHAIN_YAML = f"""
|
||||
schema: crewai.flow/v1
|
||||
name: ChainFlow
|
||||
@@ -727,6 +747,274 @@ methods:
|
||||
flow.kickoff()
|
||||
|
||||
|
||||
def test_code_action_renders_keyword_inputs():
|
||||
yaml_str = f"""
|
||||
schema: crewai.flow/v1
|
||||
name: CodeWithFlow
|
||||
methods:
|
||||
greet:
|
||||
do:
|
||||
call: code
|
||||
ref: {__name__}:EachActionFlow.keyword_code
|
||||
with:
|
||||
name: "${{state.name}}"
|
||||
punctuation: "!"
|
||||
start: true
|
||||
"""
|
||||
|
||||
flow = Flow.from_definition(FlowDefinition.from_yaml(yaml_str))
|
||||
|
||||
assert flow.kickoff(inputs={"name": "hello"}) == "hello!"
|
||||
|
||||
|
||||
def test_each_action_executes_one_nested_code_action():
|
||||
yaml_str = f"""
|
||||
schema: crewai.flow/v1
|
||||
name: EachFlow
|
||||
methods:
|
||||
process_rows:
|
||||
do:
|
||||
call: each
|
||||
in: state.rows
|
||||
do:
|
||||
- normalize:
|
||||
call: code
|
||||
ref: {__name__}:EachActionFlow.normalize_row
|
||||
with:
|
||||
row: "${{item}}"
|
||||
start: true
|
||||
"""
|
||||
|
||||
flow = Flow.from_definition(FlowDefinition.from_yaml(yaml_str))
|
||||
|
||||
assert flow.kickoff(inputs={"rows": ["a", "b"]}) == [
|
||||
"normalized:a",
|
||||
"normalized:b",
|
||||
]
|
||||
|
||||
|
||||
def test_each_action_uses_iteration_outputs_between_nested_actions():
|
||||
yaml_str = f"""
|
||||
schema: crewai.flow/v1
|
||||
name: EachFlow
|
||||
methods:
|
||||
process_rows:
|
||||
do:
|
||||
call: each
|
||||
in: state.rows
|
||||
do:
|
||||
- normalize:
|
||||
call: code
|
||||
ref: {__name__}:EachActionFlow.normalize_row
|
||||
with:
|
||||
row: "${{item}}"
|
||||
prefix: saved
|
||||
- save:
|
||||
call: code
|
||||
ref: {__name__}:EachActionFlow.save_row
|
||||
with:
|
||||
row: "${{item}}"
|
||||
normalized: "${{outputs.normalize}}"
|
||||
start: true
|
||||
"""
|
||||
|
||||
flow = Flow.from_definition(FlowDefinition.from_yaml(yaml_str))
|
||||
|
||||
assert flow.kickoff(inputs={"rows": ["a", "b"]}) == [
|
||||
{"row": "a", "normalized": "saved:a"},
|
||||
{"row": "b", "normalized": "saved:b"},
|
||||
]
|
||||
|
||||
|
||||
def test_each_action_resets_inner_outputs_between_iterations():
|
||||
yaml_str = """
|
||||
schema: crewai.flow/v1
|
||||
name: EachFlow
|
||||
methods:
|
||||
process_rows:
|
||||
do:
|
||||
call: each
|
||||
in: state.rows
|
||||
do:
|
||||
- leak_check:
|
||||
call: expression
|
||||
expr: "has(outputs.previous) ? outputs.previous : 'empty'"
|
||||
- previous:
|
||||
call: expression
|
||||
expr: item
|
||||
start: true
|
||||
"""
|
||||
|
||||
flow = Flow.from_definition(FlowDefinition.from_yaml(yaml_str))
|
||||
|
||||
assert flow.kickoff(inputs={"rows": ["a", "b"]}) == ["a", "b"]
|
||||
assert flow._method_outputs == [
|
||||
{"method": "process_rows", "output": ["a", "b"]}
|
||||
]
|
||||
|
||||
|
||||
def test_each_action_empty_list_returns_empty_and_listener_runs_once():
|
||||
yaml_str = f"""
|
||||
schema: crewai.flow/v1
|
||||
name: EachFlow
|
||||
methods:
|
||||
process_rows:
|
||||
do:
|
||||
call: each
|
||||
in: state.rows
|
||||
do:
|
||||
- normalize:
|
||||
call: code
|
||||
ref: {__name__}:EachActionFlow.normalize_row
|
||||
with:
|
||||
row: "${{item}}"
|
||||
start: true
|
||||
after_each:
|
||||
do:
|
||||
call: code
|
||||
ref: {__name__}:EachActionFlow.after_each
|
||||
listen: process_rows
|
||||
"""
|
||||
|
||||
flow = Flow.from_definition(FlowDefinition.from_yaml(yaml_str))
|
||||
events = []
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionFinishedEvent)
|
||||
def on_finished(source, event):
|
||||
events.append(event.method_name)
|
||||
|
||||
result = flow.kickoff(inputs={"rows": []})
|
||||
|
||||
assert result == "after:1"
|
||||
assert flow.method_outputs == [[], "after:1"]
|
||||
assert flow.state["after_count"] == 1
|
||||
assert events.count("process_rows") == 1
|
||||
assert events.count("after_each") == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("expr", "inputs"),
|
||||
[
|
||||
("1", {}),
|
||||
('"rows"', {}),
|
||||
("state.rows", {"rows": {"a": 1}}),
|
||||
],
|
||||
)
|
||||
def test_each_action_rejects_non_list_inputs(expr, inputs):
|
||||
definition = FlowDefinition.from_dict(
|
||||
{
|
||||
"schema": "crewai.flow/v1",
|
||||
"name": "EachFlow",
|
||||
"methods": {
|
||||
"process_rows": {
|
||||
"start": True,
|
||||
"do": {
|
||||
"call": "each",
|
||||
"in": expr,
|
||||
"do": [{"value": {"call": "expression", "expr": "item"}}],
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
flow = Flow.from_definition(definition)
|
||||
|
||||
with pytest.raises(ValueError, match="each.in must evaluate to an array"):
|
||||
flow.kickoff(inputs=inputs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"action_do",
|
||||
[
|
||||
[],
|
||||
[{"first": {"call": "expression", "expr": "item"}, "second": {"call": "expression", "expr": "item"}}],
|
||||
[{"1bad": {"call": "expression", "expr": "item"}}],
|
||||
[
|
||||
{"same": {"call": "expression", "expr": "item"}},
|
||||
{"same": {"call": "expression", "expr": "item"}},
|
||||
],
|
||||
],
|
||||
)
|
||||
def test_each_action_validates_inner_action_shape(action_do):
|
||||
with pytest.raises(ValidationError):
|
||||
FlowDefinition.from_dict(
|
||||
{
|
||||
"schema": "crewai.flow/v1",
|
||||
"name": "EachFlow",
|
||||
"methods": {
|
||||
"process_rows": {
|
||||
"start": True,
|
||||
"do": {
|
||||
"call": "each",
|
||||
"in": "state.rows",
|
||||
"do": action_do,
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_each_action_rejects_nested_each_actions():
|
||||
with pytest.raises(ValidationError):
|
||||
FlowDefinition.from_dict(
|
||||
{
|
||||
"schema": "crewai.flow/v1",
|
||||
"name": "EachFlow",
|
||||
"methods": {
|
||||
"process_rows": {
|
||||
"start": True,
|
||||
"do": {
|
||||
"call": "each",
|
||||
"in": "state.rows",
|
||||
"do": [
|
||||
{
|
||||
"nested": {
|
||||
"call": "each",
|
||||
"in": "state.children",
|
||||
"do": [
|
||||
{
|
||||
"child": {
|
||||
"call": "expression",
|
||||
"expr": "item",
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_each_action_failure_fails_outer_method():
|
||||
yaml_str = f"""
|
||||
schema: crewai.flow/v1
|
||||
name: EachFlow
|
||||
methods:
|
||||
process_rows:
|
||||
do:
|
||||
call: each
|
||||
in: state.rows
|
||||
do:
|
||||
- validate:
|
||||
call: code
|
||||
ref: {__name__}:EachActionFlow.fail_on_bad_row
|
||||
with:
|
||||
row: "${{item}}"
|
||||
start: true
|
||||
"""
|
||||
|
||||
flow = Flow.from_definition(FlowDefinition.from_yaml(yaml_str))
|
||||
|
||||
with pytest.raises(RuntimeError, match="bad row"):
|
||||
flow.kickoff(inputs={"rows": ["ok", "bad"]})
|
||||
|
||||
|
||||
def test_expression_action_round_trips():
|
||||
definition = FlowDefinition.from_dict(
|
||||
{
|
||||
@@ -830,26 +1118,6 @@ def test_tool_action_requires_module_qualname_ref():
|
||||
Flow.from_definition(definition)
|
||||
|
||||
|
||||
def test_code_action_rejects_tool_inputs():
|
||||
with pytest.raises(ValidationError):
|
||||
FlowDefinition.from_dict(
|
||||
{
|
||||
"schema": "crewai.flow/v1",
|
||||
"name": "InvalidCodeActionFlow",
|
||||
"methods": {
|
||||
"begin": {
|
||||
"start": True,
|
||||
"do": {
|
||||
"call": "code",
|
||||
"ref": f"{__name__}:ChainFlow.begin",
|
||||
"with": {"search_query": "ai agents"},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_pydantic_state_from_ref_parity():
|
||||
flow, result = assert_parity(PydanticStateFlow, PYDANTIC_STATE_YAML)
|
||||
assert result == "count=1"
|
||||
|
||||
Reference in New Issue
Block a user