mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-06-24 09:38:11 +00:00
Compare commits
7 Commits
1.14.8a3
...
worktree-s
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7a39a3076e | ||
|
|
8593b8990a | ||
|
|
a046e6a50b | ||
|
|
1862ff8f6c | ||
|
|
3f4d1355ec | ||
|
|
83e889b287 | ||
|
|
565592c36e |
@@ -17,7 +17,7 @@ from textual.binding import Binding, BindingType
|
||||
from textual.containers import Horizontal, Vertical, VerticalScroll
|
||||
from textual.css.query import NoMatches
|
||||
from textual.screen import ModalScreen
|
||||
from textual.widgets import Button, Footer, Header, Static
|
||||
from textual.widgets import Button, Footer, Header, Input, Static
|
||||
|
||||
|
||||
_SPINNER = "⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏"
|
||||
@@ -382,6 +382,18 @@ Screen {
|
||||
height: auto;
|
||||
}
|
||||
|
||||
#conversation-input {
|
||||
display: none;
|
||||
height: 3;
|
||||
border-top: hkey #333333;
|
||||
background: #1c1c1c;
|
||||
color: #e0e0e0;
|
||||
}
|
||||
|
||||
#conversation-input:focus {
|
||||
border-top: hkey #1F7982;
|
||||
}
|
||||
|
||||
Header {
|
||||
background: #1c1c1c;
|
||||
color: #FF5A50;
|
||||
@@ -483,6 +495,7 @@ FooterKey .footer-key--key {
|
||||
total_tasks: int = 0,
|
||||
agent_names: list[str] | None = None,
|
||||
task_names: list[str] | None = None,
|
||||
conversational: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.title = f"CrewAI — {crew_name}"
|
||||
@@ -544,6 +557,13 @@ FooterKey .footer-key--key {
|
||||
self._event_handlers: list[tuple[type, Any]] = []
|
||||
|
||||
self._crew: Any = None
|
||||
self._flow: Any = None
|
||||
self._is_conversational = conversational
|
||||
self._conversation_messages: list[tuple[str, str]] = []
|
||||
self._conversation_turns = 0
|
||||
self._conversation_turn_in_progress = False
|
||||
self._conversation_previous_defer_trace_finalization: bool | None = None
|
||||
self._conversation_exit_commands = {"exit", "quit"}
|
||||
self._default_inputs: dict[str, Any] | None = None
|
||||
self._crew_result: Any = None
|
||||
self._crew_json_path: Any = None
|
||||
@@ -566,6 +586,10 @@ FooterKey .footer-key--key {
|
||||
yield Static(id="task-header")
|
||||
with VerticalScroll(id="scroll-area"):
|
||||
yield Static(id="main-content")
|
||||
yield Input(
|
||||
placeholder="Message the flow...",
|
||||
id="conversation-input",
|
||||
)
|
||||
with VerticalScroll(id="log-panel"):
|
||||
yield Static(id="log-content")
|
||||
yield Footer()
|
||||
@@ -574,7 +598,9 @@ FooterKey .footer-key--key {
|
||||
self._start_time = time.time()
|
||||
self._subscribe()
|
||||
self._tick_timer = self.set_interval(1 / 8, self._tick)
|
||||
if self._crew:
|
||||
if self._is_conversational and self._flow:
|
||||
self._start_conversational_session()
|
||||
elif self._crew:
|
||||
self._run_crew_worker()
|
||||
elif self._crew_json_path:
|
||||
self._load_and_run_worker()
|
||||
@@ -725,6 +751,140 @@ FooterKey .footer-key--key {
|
||||
self._tick_timer = self.set_interval(1 / 2, self._tick)
|
||||
self._unsubscribe_if_no_running_memory_save(wait_for_queued=True)
|
||||
|
||||
# ── Conversational flow execution ───────────────────────
|
||||
|
||||
def _start_conversational_session(self) -> None:
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
set_suppress_tracing_messages,
|
||||
set_tui_mode,
|
||||
)
|
||||
|
||||
set_tui_mode(True)
|
||||
set_suppress_tracing_messages(True)
|
||||
with self._lock:
|
||||
self._status = "chatting"
|
||||
self._current_step = None
|
||||
self._elapsed_frozen = None
|
||||
self._conversation_previous_defer_trace_finalization = getattr(
|
||||
self._flow, "defer_trace_finalization", False
|
||||
)
|
||||
self._flow.defer_trace_finalization = True
|
||||
|
||||
try:
|
||||
input_widget = self.query_one("#conversation-input", Input)
|
||||
input_widget.display = True
|
||||
input_widget.focus()
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
def _finalize_conversational_session(self) -> None:
|
||||
if not (self._is_conversational and self._flow):
|
||||
return
|
||||
try:
|
||||
self._flow.finalize_session_traces()
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
previous = self._conversation_previous_defer_trace_finalization
|
||||
if previous is not None:
|
||||
try:
|
||||
self._flow.defer_trace_finalization = previous
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
def on_input_submitted(self, event: Input.Submitted) -> None:
|
||||
if event.input.id != "conversation-input":
|
||||
return
|
||||
if not self._is_conversational:
|
||||
return
|
||||
|
||||
message = event.value.strip()
|
||||
event.input.value = ""
|
||||
if not message:
|
||||
return
|
||||
if message.lower() in self._conversation_exit_commands:
|
||||
self._finalize_conversational_session()
|
||||
self._unsubscribe()
|
||||
self.exit(self._crew_result)
|
||||
return
|
||||
if self._conversation_turn_in_progress:
|
||||
return
|
||||
|
||||
with self._lock:
|
||||
self._conversation_messages.append(("user", message))
|
||||
self._conversation_turn_in_progress = True
|
||||
self._conversation_turns += 1
|
||||
self._status = "working"
|
||||
self._current_step = ("yellow", "Thinking…", "")
|
||||
self._is_streaming = False
|
||||
self._streaming_text = ""
|
||||
self._task_full_output = ""
|
||||
self._current_llm_text = ""
|
||||
|
||||
event.input.disabled = True
|
||||
self._run_conversation_turn_worker(message)
|
||||
|
||||
@work(thread=True, exclusive=True, group="conversation")
|
||||
def _run_conversation_turn_worker(self, message: str) -> None:
|
||||
from crewai.events.listeners.tracing.utils import (
|
||||
set_suppress_tracing_messages,
|
||||
set_tui_mode,
|
||||
)
|
||||
|
||||
set_tui_mode(True)
|
||||
set_suppress_tracing_messages(True)
|
||||
try:
|
||||
result = self._flow.handle_turn(message)
|
||||
if hasattr(result, "get_full_text") and hasattr(result, "result"):
|
||||
for _chunk in result:
|
||||
pass
|
||||
result = result.result
|
||||
self.call_from_thread(self._on_conversation_turn_done, result)
|
||||
except Exception as e:
|
||||
self.call_from_thread(self._on_conversation_turn_failed, str(e))
|
||||
|
||||
def _on_conversation_turn_done(self, result: Any) -> None:
|
||||
with self._lock:
|
||||
output = self._stringify_output(result)
|
||||
self._conversation_messages.append(("assistant", output))
|
||||
self._crew_result = result
|
||||
self._conversation_turn_in_progress = False
|
||||
self._status = "chatting"
|
||||
self._is_streaming = False
|
||||
self._streaming_text = ""
|
||||
self._current_step = None
|
||||
self._enable_conversation_input()
|
||||
self._tick()
|
||||
self._scroll_to_result()
|
||||
|
||||
def _on_conversation_turn_failed(self, error: str) -> None:
|
||||
with self._lock:
|
||||
self._status = "failed"
|
||||
self._error = error
|
||||
self._conversation_turn_in_progress = False
|
||||
self._is_streaming = False
|
||||
self._current_step = None
|
||||
self._enable_conversation_input()
|
||||
self._tick()
|
||||
|
||||
def _enable_conversation_input(self) -> None:
|
||||
try:
|
||||
input_widget = self.query_one("#conversation-input", Input)
|
||||
input_widget.disabled = False
|
||||
input_widget.focus()
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
def _stringify_output(self, result: Any) -> str:
|
||||
raw_result = getattr(result, "raw", result)
|
||||
if raw_result is None:
|
||||
return ""
|
||||
if isinstance(raw_result, str):
|
||||
return raw_result
|
||||
try:
|
||||
return _json.dumps(raw_result, default=str, ensure_ascii=False)
|
||||
except TypeError:
|
||||
return str(raw_result)
|
||||
|
||||
# ── Actions ─────────────────────────────────────────────
|
||||
|
||||
def action_toggle_sidebar(self) -> None:
|
||||
@@ -783,6 +943,7 @@ FooterKey .footer-key--key {
|
||||
self._refresh_log_panel()
|
||||
|
||||
async def action_quit(self) -> None:
|
||||
self._finalize_conversational_session()
|
||||
self._unsubscribe()
|
||||
self.exit(self._crew_result)
|
||||
|
||||
@@ -958,6 +1119,30 @@ FooterKey .footer-key--key {
|
||||
t = Text()
|
||||
sidebar_width = 30
|
||||
|
||||
if self._is_conversational:
|
||||
t.append(" CONVERSATION\n", style=f"bold {_C_PRIMARY}")
|
||||
t.append("\n")
|
||||
if self._conversation_turn_in_progress:
|
||||
t.append(f" {self._spinner()} ", style=_C_PRIMARY)
|
||||
t.append("Working\n", style=f"bold {_C_TEXT}")
|
||||
elif self._status == "failed":
|
||||
t.append(" ✘ Failed\n", style=_C_RED)
|
||||
else:
|
||||
t.append(" ● Ready\n", style=_C_GREEN)
|
||||
t.append(f" Turns {self._conversation_turns}\n", style=_C_DIM)
|
||||
t.append("\n")
|
||||
t.append(" TOKENS\n", style=f"bold {_C_PRIMARY}")
|
||||
t.append("\n")
|
||||
out = self._output_tokens + self._live_out_tokens
|
||||
t.append(f" ↑ {self._input_tokens:,}\n", style=_C_DIM)
|
||||
t.append(f" ↓ {out:,}\n", style=_C_DIM)
|
||||
t.append("\n")
|
||||
t.append(" COMMANDS\n", style=f"bold {_C_PRIMARY}")
|
||||
t.append("\n")
|
||||
t.append(" quit / exit\n", style=_C_DIM)
|
||||
widget.update(t)
|
||||
return
|
||||
|
||||
t.append(" TASKS\n", style=f"bold {_C_PRIMARY}")
|
||||
t.append("\n")
|
||||
|
||||
@@ -1011,6 +1196,22 @@ FooterKey .footer-key--key {
|
||||
widget = self.query_one("#task-header", Static)
|
||||
t = Text()
|
||||
|
||||
if self._is_conversational:
|
||||
if self._status == "failed":
|
||||
t.append("✘ ", style=f"bold {_C_RED}")
|
||||
t.append("Failed", style=f"bold {_C_RED}")
|
||||
if self._error:
|
||||
t.append(f"\n{self._error[:120]}", style=_C_RED)
|
||||
elif self._conversation_turn_in_progress:
|
||||
t.append(f"{self._spinner()} ", style=_C_PRIMARY)
|
||||
t.append("Flow is responding", style=f"bold {_C_PRIMARY}")
|
||||
else:
|
||||
t.append("● ", style=f"bold {_C_GREEN}")
|
||||
t.append("Conversational flow ready", style=f"bold {_C_GREEN}")
|
||||
t.append(" Type a message below", style=_C_DIM)
|
||||
widget.update(t)
|
||||
return
|
||||
|
||||
if self._status == "completed":
|
||||
elapsed = self._elapsed_frozen or (time.time() - self._start_time)
|
||||
t.append("✔ ", style=f"bold {_C_GREEN}")
|
||||
@@ -1062,6 +1263,41 @@ FooterKey .footer-key--key {
|
||||
t = Text()
|
||||
should_scroll = False
|
||||
|
||||
if self._is_conversational:
|
||||
if not self._conversation_messages and not self._is_streaming:
|
||||
t.append(" Start the conversation below.\n", style=_C_MUTED)
|
||||
for role, content in self._conversation_messages:
|
||||
if role == "user":
|
||||
t.append("\n You\n", style=f"bold {_C_TEAL}")
|
||||
else:
|
||||
t.append("\n Assistant\n", style=f"bold {_C_PRIMARY}")
|
||||
rendered = _format_json_in_text(_unescape_text(content))
|
||||
for line in rendered.split("\n"):
|
||||
style = _C_TEXT if role == "assistant" else _C_DIM
|
||||
t.append(f" {line}\n", style=style)
|
||||
|
||||
if self._is_streaming and self._streaming_text:
|
||||
text = _unescape_text(self._filtered_streaming_text())
|
||||
if text.strip():
|
||||
t.append("\n Assistant\n", style=f"bold {_C_PRIMARY}")
|
||||
for line in text.rstrip().split("\n")[-40:]:
|
||||
t.append(f" {line}\n", style=_C_TEXT)
|
||||
should_scroll = True
|
||||
|
||||
if self._status == "failed" and self._error:
|
||||
t.append("\n Error\n", style=f"bold {_C_RED}")
|
||||
t.append(f" {self._error}\n", style=_C_RED)
|
||||
|
||||
widget.update(t)
|
||||
if should_scroll:
|
||||
try:
|
||||
self.query_one("#scroll-area", VerticalScroll).scroll_end(
|
||||
animate=False
|
||||
)
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
return
|
||||
|
||||
# Plan section
|
||||
if self._plan and self._plan.get("steps"):
|
||||
plan_title = self._plan.get("plan", "Plan")
|
||||
|
||||
105
lib/cli/src/crewai_cli/kickoff_flow.py
Normal file
105
lib/cli/src/crewai_cli/kickoff_flow.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
from pathlib import Path
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
|
||||
|
||||
def _project_script_target(script_name: str) -> str | None:
|
||||
try:
|
||||
from crewai_cli.utils import read_toml
|
||||
|
||||
pyproject = read_toml()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
target = pyproject.get("project", {}).get("scripts", {}).get(script_name)
|
||||
return target if isinstance(target, str) else None
|
||||
|
||||
|
||||
def _prepare_project_import_path() -> None:
|
||||
cwd = Path.cwd()
|
||||
for path in (cwd / "src", cwd):
|
||||
path_str = str(path)
|
||||
if path.exists() and path_str not in sys.path:
|
||||
sys.path.insert(0, path_str)
|
||||
|
||||
|
||||
def _load_conversational_flow_from_kickoff_script() -> Any | None:
|
||||
target = _project_script_target("kickoff")
|
||||
if not target or ":" not in target:
|
||||
return None
|
||||
|
||||
module_name, _callable_name = target.split(":", 1)
|
||||
_prepare_project_import_path()
|
||||
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
from crewai.flow.flow import Flow
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
for value in vars(module).values():
|
||||
if (
|
||||
inspect.isclass(value)
|
||||
and value is not Flow
|
||||
and issubclass(value, Flow)
|
||||
and getattr(value, "conversational", False)
|
||||
):
|
||||
return value()
|
||||
|
||||
for value in vars(module).values():
|
||||
if (
|
||||
isinstance(value, Flow)
|
||||
and getattr(value, "conversational", False)
|
||||
and callable(getattr(value, "handle_turn", None))
|
||||
):
|
||||
return value
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _run_conversational_flow_tui(flow: Any) -> Any:
|
||||
from crewai_cli.crew_run_tui import CrewRunApp
|
||||
|
||||
app = CrewRunApp(
|
||||
crew_name=getattr(flow, "name", None) or type(flow).__name__,
|
||||
conversational=True,
|
||||
)
|
||||
app._flow = flow
|
||||
app.run()
|
||||
|
||||
if app._status == "failed":
|
||||
raise SystemExit(1)
|
||||
|
||||
return app._crew_result
|
||||
|
||||
|
||||
def kickoff_flow() -> None:
|
||||
"""
|
||||
Kickoff the flow by running a command in the UV environment.
|
||||
"""
|
||||
flow = _load_conversational_flow_from_kickoff_script()
|
||||
if flow is not None:
|
||||
_run_conversational_flow_tui(flow)
|
||||
return
|
||||
|
||||
command = ["uv", "run", "kickoff"]
|
||||
|
||||
try:
|
||||
result = subprocess.run(command, capture_output=False, text=True, check=True) # noqa: S603
|
||||
|
||||
if result.stderr:
|
||||
click.echo(result.stderr, err=True)
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
click.echo(f"An error occurred while running the flow: {e}", err=True)
|
||||
click.echo(e.output, err=True)
|
||||
|
||||
except Exception as e:
|
||||
click.echo(f"An unexpected error occurred: {e}", err=True)
|
||||
@@ -604,6 +604,16 @@ def _run_flow_project(
|
||||
run_declarative_flow_in_project_env(definition=definition)
|
||||
return
|
||||
|
||||
from crewai_cli.kickoff_flow import (
|
||||
_load_conversational_flow_from_kickoff_script,
|
||||
_run_conversational_flow_tui,
|
||||
)
|
||||
|
||||
flow = _load_conversational_flow_from_kickoff_script()
|
||||
if flow is not None:
|
||||
_run_conversational_flow_tui(flow)
|
||||
return
|
||||
|
||||
_execute_uv_script("kickoff", entity_type="flow")
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from pathlib import Path, PureWindowsPath
|
||||
import subprocess
|
||||
from typing import Any
|
||||
|
||||
@@ -12,7 +12,7 @@ from crewai_cli.utils import build_env_with_all_tool_credentials
|
||||
|
||||
|
||||
def run_declarative_flow_in_project_env(
|
||||
definition: str, inputs: str | None = None
|
||||
definition: str | Path, inputs: str | None = None
|
||||
) -> None:
|
||||
"""Run a declarative flow inside the project's Python environment."""
|
||||
if is_declarative_flow_project_env() or not _has_project_file():
|
||||
@@ -25,7 +25,7 @@ def run_declarative_flow_in_project_env(
|
||||
_execute_declarative_flow_command(["uv", "run", "crewai", "run"])
|
||||
|
||||
|
||||
def plot_declarative_flow_in_project_env(definition: str) -> None:
|
||||
def plot_declarative_flow_in_project_env(definition: str | Path) -> None:
|
||||
"""Plot a declarative flow inside the project's Python environment."""
|
||||
if is_declarative_flow_project_env() or not _has_project_file():
|
||||
plot_declarative_flow(definition=definition)
|
||||
@@ -34,7 +34,7 @@ def plot_declarative_flow_in_project_env(definition: str) -> None:
|
||||
_execute_declarative_flow_command(["uv", "run", "crewai", "flow", "plot"])
|
||||
|
||||
|
||||
def run_declarative_flow(definition: str, inputs: str | None = None) -> None:
|
||||
def run_declarative_flow(definition: str | Path, inputs: str | None = None) -> None:
|
||||
"""Run a declarative flow from a definition path."""
|
||||
parsed_inputs = _parse_inputs(inputs)
|
||||
|
||||
@@ -50,7 +50,7 @@ def run_declarative_flow(definition: str, inputs: str | None = None) -> None:
|
||||
click.echo(_format_result(result))
|
||||
|
||||
|
||||
def plot_declarative_flow(definition: str) -> None:
|
||||
def plot_declarative_flow(definition: str | Path) -> None:
|
||||
"""Plot a declarative flow from a definition path."""
|
||||
try:
|
||||
flow = load_declarative_flow(definition)
|
||||
@@ -62,7 +62,7 @@ def plot_declarative_flow(definition: str) -> None:
|
||||
raise SystemExit(1) from exc
|
||||
|
||||
|
||||
def load_declarative_flow(definition: str) -> Any:
|
||||
def load_declarative_flow(definition: str | Path) -> Any:
|
||||
"""Load a declarative Flow instance from a definition path."""
|
||||
try:
|
||||
from crewai.flow.flow import Flow
|
||||
@@ -102,7 +102,8 @@ def load_declarative_flow(definition: str) -> Any:
|
||||
|
||||
def configured_project_declarative_flow(
|
||||
pyproject_data: dict[str, Any] | None = None,
|
||||
) -> str | None:
|
||||
project_root: Path | None = None,
|
||||
) -> Path | None:
|
||||
"""Return the configured declarative flow source for flow projects."""
|
||||
if pyproject_data is None:
|
||||
try:
|
||||
@@ -118,7 +119,66 @@ def configured_project_declarative_flow(
|
||||
definition = crewai_config.get("definition")
|
||||
if not isinstance(definition, str):
|
||||
return None
|
||||
return definition.strip() or None
|
||||
definition = definition.strip()
|
||||
if not definition:
|
||||
return None
|
||||
|
||||
return _resolve_project_definition_path(
|
||||
definition=definition,
|
||||
project_root=project_root or Path.cwd(),
|
||||
)
|
||||
|
||||
|
||||
def _resolve_project_definition_path(definition: str, project_root: Path) -> Path:
|
||||
definition_path = Path(definition)
|
||||
windows_definition_path = PureWindowsPath(definition)
|
||||
|
||||
if definition.startswith("~"):
|
||||
raise click.UsageError(
|
||||
"[tool.crewai] definition must be a project-local path; "
|
||||
f"got {definition!r}."
|
||||
)
|
||||
|
||||
if definition_path.is_absolute() or windows_definition_path.is_absolute():
|
||||
raise click.UsageError(
|
||||
"[tool.crewai] definition must be relative to the project root; "
|
||||
f"got {definition!r}."
|
||||
)
|
||||
|
||||
try:
|
||||
root = project_root.resolve(strict=True)
|
||||
except OSError as exc:
|
||||
raise click.UsageError(
|
||||
f"Invalid project root for [tool.crewai] definition: {exc}"
|
||||
) from exc
|
||||
|
||||
candidate = root / definition_path
|
||||
try:
|
||||
resolved_candidate = candidate.resolve(strict=False)
|
||||
except OSError as exc:
|
||||
raise click.UsageError(
|
||||
f"Invalid [tool.crewai] definition path {definition!r}: {exc}"
|
||||
) from exc
|
||||
|
||||
if not resolved_candidate.is_relative_to(root):
|
||||
raise click.UsageError(
|
||||
"[tool.crewai] definition must resolve inside the project root; "
|
||||
f"got {definition!r}."
|
||||
)
|
||||
|
||||
if not resolved_candidate.exists():
|
||||
raise click.UsageError(
|
||||
"[tool.crewai] definition must point to an existing file; "
|
||||
f"got {definition!r}."
|
||||
)
|
||||
|
||||
if not resolved_candidate.is_file():
|
||||
raise click.UsageError(
|
||||
"[tool.crewai] definition must point to a regular file; "
|
||||
f"got {definition!r}."
|
||||
)
|
||||
|
||||
return resolved_candidate
|
||||
|
||||
|
||||
def _execute_declarative_flow_command(command: list[str]) -> None:
|
||||
|
||||
@@ -126,6 +126,52 @@ def test_chain_deploy_does_not_login_for_deploy_exit(monkeypatch, capsys) -> Non
|
||||
assert "Deploy failed with exit code 42" in capsys.readouterr().out
|
||||
|
||||
|
||||
def test_conversation_turn_done_records_assistant_message() -> None:
|
||||
class RawResult:
|
||||
raw = "hello from the flow"
|
||||
|
||||
app = CrewRunApp(conversational=True)
|
||||
app._conversation_turn_in_progress = True
|
||||
app._enable_conversation_input = lambda: None # type: ignore[method-assign]
|
||||
app._tick = lambda: None # type: ignore[method-assign]
|
||||
app._scroll_to_result = lambda: None # type: ignore[method-assign]
|
||||
|
||||
app._on_conversation_turn_done(RawResult())
|
||||
|
||||
assert app._conversation_messages == [("assistant", "hello from the flow")]
|
||||
assert app._conversation_turn_in_progress is False
|
||||
assert app._status == "chatting"
|
||||
assert isinstance(app._crew_result, RawResult)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conversation_input_submits_turn() -> None:
|
||||
class FakeFlow:
|
||||
defer_trace_finalization = False
|
||||
|
||||
def handle_turn(self, message: str) -> str:
|
||||
return f"reply: {message}"
|
||||
|
||||
def finalize_session_traces(self) -> None:
|
||||
pass
|
||||
|
||||
app = CrewRunApp(crew_name="Demo", conversational=True)
|
||||
app._flow = FakeFlow()
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
await pilot.click("#conversation-input")
|
||||
await pilot.press("h", "i", "enter")
|
||||
for _ in range(50):
|
||||
await pilot.pause(0.05)
|
||||
if app._conversation_messages[-1:] == [("assistant", "reply: hi")]:
|
||||
break
|
||||
|
||||
assert app._conversation_messages == [
|
||||
("user", "hi"),
|
||||
("assistant", "reply: hi"),
|
||||
]
|
||||
|
||||
|
||||
def test_plan_step_status_updates_only_the_explicit_step() -> None:
|
||||
app = _app_with_plan()
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
from pathlib import Path
|
||||
import subprocess
|
||||
|
||||
import click
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
|
||||
@@ -107,6 +108,8 @@ def test_configured_project_declarative_flow(
|
||||
monkeypatch: pytest.MonkeyPatch, tmp_path: Path
|
||||
) -> None:
|
||||
monkeypatch.chdir(tmp_path)
|
||||
definition_path = tmp_path / "flow.yaml"
|
||||
definition_path.write_text(FLOW_YAML, encoding="utf-8")
|
||||
(tmp_path / "pyproject.toml").write_text(
|
||||
'[tool.crewai]\ntype = "flow"\ndefinition = " flow.yaml "\n',
|
||||
encoding="utf-8",
|
||||
@@ -114,4 +117,132 @@ def test_configured_project_declarative_flow(
|
||||
|
||||
from crewai_cli.run_declarative_flow import configured_project_declarative_flow
|
||||
|
||||
assert configured_project_declarative_flow() == "flow.yaml"
|
||||
assert configured_project_declarative_flow() == definition_path.resolve()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("definition", "expected_error"),
|
||||
[
|
||||
("C:/tmp/flow.yaml", "must be relative to the project root"),
|
||||
("~/flow.yaml", "must be a project-local path"),
|
||||
("../flow.yaml", "must resolve inside the project root"),
|
||||
],
|
||||
)
|
||||
def test_configured_project_declarative_flow_rejects_unsafe_paths(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path: Path,
|
||||
definition: str,
|
||||
expected_error: str,
|
||||
) -> None:
|
||||
monkeypatch.chdir(tmp_path)
|
||||
(tmp_path / "pyproject.toml").write_text(
|
||||
f'[tool.crewai]\ntype = "flow"\ndefinition = "{definition}"\n',
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
from crewai_cli.run_declarative_flow import configured_project_declarative_flow
|
||||
|
||||
with pytest.raises(click.UsageError) as exc_info:
|
||||
configured_project_declarative_flow()
|
||||
|
||||
assert expected_error in exc_info.value.message
|
||||
|
||||
|
||||
def test_configured_project_declarative_flow_allows_normalized_project_path(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
monkeypatch.chdir(tmp_path)
|
||||
definition_path = tmp_path / "flow.yaml"
|
||||
definition_path.write_text(FLOW_YAML, encoding="utf-8")
|
||||
(tmp_path / "src").mkdir()
|
||||
(tmp_path / "pyproject.toml").write_text(
|
||||
'[tool.crewai]\ntype = "flow"\ndefinition = "src/../flow.yaml"\n',
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
from crewai_cli.run_declarative_flow import configured_project_declarative_flow
|
||||
|
||||
assert configured_project_declarative_flow() == definition_path.resolve()
|
||||
|
||||
|
||||
def test_configured_project_declarative_flow_rejects_absolute_path(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
monkeypatch.chdir(tmp_path)
|
||||
definition = tmp_path / "flow.yaml"
|
||||
(tmp_path / "pyproject.toml").write_text(
|
||||
f'[tool.crewai]\ntype = "flow"\ndefinition = "{definition.as_posix()}"\n',
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
from crewai_cli.run_declarative_flow import configured_project_declarative_flow
|
||||
|
||||
with pytest.raises(click.UsageError) as exc_info:
|
||||
configured_project_declarative_flow()
|
||||
|
||||
assert "must be relative to the project root" in exc_info.value.message
|
||||
|
||||
|
||||
def test_configured_project_declarative_flow_rejects_symlink_escape(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
monkeypatch.chdir(tmp_path)
|
||||
outside_definition = tmp_path.parent / "outside-flow.yaml"
|
||||
outside_definition.write_text(FLOW_YAML, encoding="utf-8")
|
||||
link = tmp_path / "flow.yaml"
|
||||
try:
|
||||
link.symlink_to(outside_definition)
|
||||
except (NotImplementedError, OSError) as exc:
|
||||
pytest.skip(f"symlinks unavailable: {exc}")
|
||||
|
||||
(tmp_path / "pyproject.toml").write_text(
|
||||
'[tool.crewai]\ntype = "flow"\ndefinition = "flow.yaml"\n',
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
from crewai_cli.run_declarative_flow import configured_project_declarative_flow
|
||||
|
||||
with pytest.raises(click.UsageError) as exc_info:
|
||||
configured_project_declarative_flow()
|
||||
|
||||
assert "must resolve inside the project root" in exc_info.value.message
|
||||
|
||||
|
||||
def test_configured_project_declarative_flow_rejects_missing_file(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
monkeypatch.chdir(tmp_path)
|
||||
(tmp_path / "pyproject.toml").write_text(
|
||||
'[tool.crewai]\ntype = "flow"\ndefinition = "missing-flow.yaml"\n',
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
from crewai_cli.run_declarative_flow import configured_project_declarative_flow
|
||||
|
||||
with pytest.raises(click.UsageError) as exc_info:
|
||||
configured_project_declarative_flow()
|
||||
|
||||
assert "must point to an existing file" in exc_info.value.message
|
||||
|
||||
|
||||
def test_configured_project_declarative_flow_rejects_directory(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
monkeypatch.chdir(tmp_path)
|
||||
(tmp_path / "flow.yaml").mkdir()
|
||||
(tmp_path / "pyproject.toml").write_text(
|
||||
'[tool.crewai]\ntype = "flow"\ndefinition = "flow.yaml"\n',
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
from crewai_cli.run_declarative_flow import configured_project_declarative_flow
|
||||
|
||||
with pytest.raises(click.UsageError) as exc_info:
|
||||
configured_project_declarative_flow()
|
||||
|
||||
assert "must point to a regular file" in exc_info.value.message
|
||||
|
||||
63
lib/cli/tests/test_kickoff_flow.py
Normal file
63
lib/cli/tests/test_kickoff_flow.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
|
||||
from crewai_cli import kickoff_flow
|
||||
|
||||
|
||||
def test_loads_conversational_flow_from_kickoff_script(tmp_path, monkeypatch) -> None:
|
||||
package_dir = tmp_path / "src" / "demo_chat"
|
||||
package_dir.mkdir(parents=True)
|
||||
(package_dir / "__init__.py").write_text("")
|
||||
(package_dir / "main.py").write_text(
|
||||
"\n".join(
|
||||
[
|
||||
"from crewai.flow import Flow",
|
||||
"",
|
||||
"class DemoChatFlow(Flow):",
|
||||
" conversational = True",
|
||||
]
|
||||
)
|
||||
)
|
||||
(tmp_path / "pyproject.toml").write_text(
|
||||
"\n".join(
|
||||
[
|
||||
"[project]",
|
||||
'name = "demo-chat"',
|
||||
"[project.scripts]",
|
||||
'kickoff = "demo_chat.main:kickoff"',
|
||||
]
|
||||
)
|
||||
)
|
||||
monkeypatch.chdir(tmp_path)
|
||||
sys.modules.pop("demo_chat.main", None)
|
||||
sys.modules.pop("demo_chat", None)
|
||||
|
||||
flow = kickoff_flow._load_conversational_flow_from_kickoff_script()
|
||||
|
||||
assert flow is not None
|
||||
assert type(flow).__name__ == "DemoChatFlow"
|
||||
assert flow.conversational is True
|
||||
|
||||
|
||||
def test_kickoff_flow_falls_back_to_uv_when_no_conversational_flow(
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
calls: list[list[str]] = []
|
||||
|
||||
def fake_run(command, capture_output, text, check):
|
||||
calls.append(command)
|
||||
|
||||
class Result:
|
||||
stderr = ""
|
||||
|
||||
return Result()
|
||||
|
||||
monkeypatch.setattr(
|
||||
kickoff_flow, "_load_conversational_flow_from_kickoff_script", lambda: None
|
||||
)
|
||||
monkeypatch.setattr(kickoff_flow.subprocess, "run", fake_run)
|
||||
|
||||
kickoff_flow.kickoff_flow()
|
||||
|
||||
assert calls == [["uv", "run", "kickoff"]]
|
||||
@@ -645,6 +645,10 @@ def test_run_crew_runs_python_flow_project(monkeypatch, capsys):
|
||||
"_execute_uv_script",
|
||||
lambda script_name, **kwargs: calls.append((script_name, kwargs)),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"crewai_cli.kickoff_flow._load_conversational_flow_from_kickoff_script",
|
||||
lambda: None,
|
||||
)
|
||||
|
||||
run_crew_module.run_crew()
|
||||
|
||||
@@ -652,6 +656,41 @@ def test_run_crew_runs_python_flow_project(monkeypatch, capsys):
|
||||
assert calls == [("kickoff", {"entity_type": "flow"})]
|
||||
|
||||
|
||||
def test_run_crew_runs_conversational_flow_tui(monkeypatch, capsys):
|
||||
class Flow:
|
||||
pass
|
||||
|
||||
flow = Flow()
|
||||
calls = []
|
||||
|
||||
monkeypatch.setattr(run_crew_module, "_has_json_crew", lambda: False)
|
||||
monkeypatch.setattr(
|
||||
run_crew_module,
|
||||
"read_toml",
|
||||
lambda: {"tool": {"crewai": {"type": "flow"}}},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"crewai_cli.kickoff_flow._load_conversational_flow_from_kickoff_script",
|
||||
lambda: flow,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"crewai_cli.kickoff_flow._run_conversational_flow_tui",
|
||||
lambda loaded_flow: calls.append(loaded_flow),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
run_crew_module,
|
||||
"_execute_uv_script",
|
||||
lambda *_args, **_kwargs: pytest.fail(
|
||||
"conversational flows must use the TUI"
|
||||
),
|
||||
)
|
||||
|
||||
run_crew_module.run_crew()
|
||||
|
||||
assert capsys.readouterr().out == ""
|
||||
assert calls == [flow]
|
||||
|
||||
|
||||
def test_run_crew_rejects_filename_for_flow_project(monkeypatch):
|
||||
monkeypatch.setattr(run_crew_module, "_has_json_crew", lambda: False)
|
||||
monkeypatch.setattr(
|
||||
@@ -666,9 +705,14 @@ def test_run_crew_rejects_filename_for_flow_project(monkeypatch):
|
||||
assert "--filename can only be used when running crews" in exc_info.value.message
|
||||
|
||||
|
||||
def test_run_crew_runs_configured_declarative_flow_project(monkeypatch, capsys):
|
||||
def test_run_crew_runs_configured_declarative_flow_project(
|
||||
monkeypatch, tmp_path: Path, capsys
|
||||
):
|
||||
calls = []
|
||||
|
||||
monkeypatch.chdir(tmp_path)
|
||||
definition_path = tmp_path / "flow.yaml"
|
||||
definition_path.write_text("schema: crewai.flow/v1\n", encoding="utf-8")
|
||||
monkeypatch.setattr(run_crew_module, "_has_json_crew", lambda: False)
|
||||
monkeypatch.setattr(
|
||||
run_crew_module,
|
||||
@@ -695,4 +739,4 @@ def test_run_crew_runs_configured_declarative_flow_project(monkeypatch, capsys):
|
||||
run_crew_module.run_crew()
|
||||
|
||||
assert capsys.readouterr().out == ""
|
||||
assert calls == [("flow.yaml", None)]
|
||||
assert calls == [(definition_path.resolve(), None)]
|
||||
|
||||
161
lib/crewai-tools/src/crewai_tools/security/safe_requests.py
Normal file
161
lib/crewai-tools/src/crewai_tools/security/safe_requests.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""SSRF-safe HTTP fetching for crewai-tools.
|
||||
|
||||
:func:`validate_url` checks the URL it is handed, but it cannot protect a
|
||||
fetch on its own: ``requests`` re-resolves DNS at connect time and follows
|
||||
redirects automatically, so a public-looking host that 302-redirects to an
|
||||
internal address (or that rebinds DNS between validation and connect) reaches
|
||||
the internal target without ever being re-checked.
|
||||
|
||||
This module closes both gaps at the connection layer:
|
||||
|
||||
* :class:`SSRFProtectedAdapter` re-runs :func:`validate_url` for every request
|
||||
it sends. ``requests.Session.send`` invokes the adapter once per redirect
|
||||
hop, so each ``Location`` target is validated before it is followed.
|
||||
* The adapter's connections validate the *actual* peer IP immediately after
|
||||
the socket connects. The IP that was authorised is therefore the IP the
|
||||
connection uses, removing the DNS time-of-check/time-of-use gap that
|
||||
:func:`validate_url`'s own ``getaddrinfo`` call leaves open.
|
||||
|
||||
Use :func:`safe_get` (or :func:`create_safe_session`) instead of calling
|
||||
``requests.get`` directly from any tool that fetches a user- or
|
||||
LLM-controlled URL.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from requests.adapters import DEFAULT_POOLBLOCK, HTTPAdapter
|
||||
from urllib3.connection import HTTPConnection, HTTPSConnection
|
||||
from urllib3.connectionpool import HTTPConnectionPool, HTTPSConnectionPool
|
||||
from urllib3.poolmanager import PoolManager
|
||||
|
||||
from crewai_tools.security.safe_path import (
|
||||
_is_escape_hatch_enabled,
|
||||
_is_private_or_reserved,
|
||||
validate_url,
|
||||
)
|
||||
|
||||
|
||||
def _assert_safe_peer(sock: Any) -> None:
|
||||
"""Raise if a connected socket's peer is a private/reserved address.
|
||||
|
||||
Validating the real peer (rather than a separately resolved IP) is what
|
||||
defeats DNS rebinding: the address we connected to is the address we check.
|
||||
"""
|
||||
if _is_escape_hatch_enabled():
|
||||
return
|
||||
try:
|
||||
peer = sock.getpeername()
|
||||
except OSError:
|
||||
return
|
||||
ip_str = str(peer[0])
|
||||
if _is_private_or_reserved(ip_str):
|
||||
raise ValueError(
|
||||
f"Connection resolved to private/reserved IP {ip_str}. "
|
||||
f"Access to internal networks is not allowed (possible SSRF via "
|
||||
f"redirect or DNS rebinding)."
|
||||
)
|
||||
|
||||
|
||||
class _SafeHTTPConnection(HTTPConnection):
|
||||
def connect(self) -> None:
|
||||
super().connect()
|
||||
_assert_safe_peer(self.sock)
|
||||
|
||||
|
||||
class _SafeHTTPSConnection(HTTPSConnection):
|
||||
def connect(self) -> None:
|
||||
super().connect()
|
||||
_assert_safe_peer(self.sock)
|
||||
|
||||
|
||||
class _SafeHTTPConnectionPool(HTTPConnectionPool):
|
||||
ConnectionCls = _SafeHTTPConnection
|
||||
|
||||
|
||||
class _SafeHTTPSConnectionPool(HTTPSConnectionPool):
|
||||
ConnectionCls = _SafeHTTPSConnection
|
||||
|
||||
|
||||
_SAFE_POOL_CLASSES = {
|
||||
"http": _SafeHTTPConnectionPool,
|
||||
"https": _SafeHTTPSConnectionPool,
|
||||
}
|
||||
|
||||
|
||||
class _SafePoolManager(PoolManager):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.pool_classes_by_scheme = _SAFE_POOL_CLASSES
|
||||
|
||||
|
||||
class SSRFProtectedAdapter(HTTPAdapter):
|
||||
"""Transport adapter that re-validates every hop and pins the peer IP.
|
||||
|
||||
``validate_url`` runs on each ``send`` — including every redirect hop
|
||||
``requests`` follows — and the underlying connections reject any socket
|
||||
that ends up connected to a private/reserved address.
|
||||
"""
|
||||
|
||||
def init_poolmanager(
|
||||
self,
|
||||
connections: int,
|
||||
maxsize: int,
|
||||
block: bool = DEFAULT_POOLBLOCK,
|
||||
**pool_kwargs: Any,
|
||||
) -> None:
|
||||
self.poolmanager = _SafePoolManager(
|
||||
num_pools=connections,
|
||||
maxsize=maxsize,
|
||||
block=block,
|
||||
**pool_kwargs,
|
||||
)
|
||||
|
||||
def send(self, request: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
# Re-validate the target of every request the session sends. Because
|
||||
# Session.send calls this once per redirect hop, each Location is
|
||||
# checked before it is followed.
|
||||
validate_url(request.url)
|
||||
return super().send(request, *args, **kwargs)
|
||||
|
||||
|
||||
def create_safe_session() -> requests.Session:
|
||||
"""Return a ``requests.Session`` that is hardened against SSRF.
|
||||
|
||||
The session validates every request (and redirect hop) and pins
|
||||
connections to the validated peer IP.
|
||||
"""
|
||||
session = requests.Session()
|
||||
# Ambient proxy settings bypass the protected pool classes via requests'
|
||||
# proxy manager path, so safe fetches must opt out of environment config.
|
||||
session.trust_env = False
|
||||
adapter = SSRFProtectedAdapter()
|
||||
session.mount("http://", adapter)
|
||||
session.mount("https://", adapter)
|
||||
return session
|
||||
|
||||
|
||||
def safe_get(url: str, **kwargs: Any) -> requests.Response:
|
||||
"""Perform an SSRF-safe ``GET``.
|
||||
|
||||
Drop-in replacement for ``requests.get`` for tools that fetch a
|
||||
user- or LLM-controlled URL. Validates the initial URL and every redirect
|
||||
hop, and rejects connections that land on private/reserved addresses.
|
||||
|
||||
Args:
|
||||
url: The URL to fetch.
|
||||
**kwargs: Forwarded to ``Session.get`` (``headers``, ``cookies``,
|
||||
``timeout``, ...).
|
||||
|
||||
Returns:
|
||||
The ``requests.Response``.
|
||||
|
||||
Raises:
|
||||
ValueError: If the URL, a redirect target, or the connected peer is
|
||||
not allowed.
|
||||
"""
|
||||
validate_url(url)
|
||||
with create_safe_session() as session:
|
||||
return session.get(url, **kwargs)
|
||||
@@ -3,9 +3,8 @@ from typing import Any
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
import requests
|
||||
|
||||
from crewai_tools.security.safe_path import validate_url
|
||||
from crewai_tools.security.safe_requests import safe_get
|
||||
|
||||
|
||||
try:
|
||||
@@ -83,8 +82,7 @@ class ScrapeElementFromWebsiteTool(BaseTool):
|
||||
if website_url is None or css_element is None:
|
||||
raise ValueError("Both website_url and css_element must be provided.")
|
||||
|
||||
website_url = validate_url(website_url)
|
||||
page = requests.get(
|
||||
page = safe_get(
|
||||
website_url,
|
||||
headers=self.headers,
|
||||
cookies=self.cookies if self.cookies else {},
|
||||
|
||||
@@ -3,9 +3,8 @@ import re
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
import requests
|
||||
|
||||
from crewai_tools.security.safe_path import validate_url
|
||||
from crewai_tools.security.safe_requests import safe_get
|
||||
|
||||
|
||||
try:
|
||||
@@ -75,8 +74,7 @@ class ScrapeWebsiteTool(BaseTool):
|
||||
if website_url is None:
|
||||
raise ValueError("Website URL must be provided.")
|
||||
|
||||
website_url = validate_url(website_url)
|
||||
page = requests.get(
|
||||
page = safe_get(
|
||||
website_url,
|
||||
timeout=15,
|
||||
headers=self.headers,
|
||||
|
||||
148
lib/crewai-tools/tests/utilities/test_safe_requests.py
Normal file
148
lib/crewai-tools/tests/utilities/test_safe_requests.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""Tests for SSRF-safe HTTP fetching (redirect + DNS-rebinding protection)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import http.server
|
||||
import socketserver
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from crewai_tools.security import safe_requests
|
||||
from crewai_tools.security.safe_requests import (
|
||||
SSRFProtectedAdapter,
|
||||
create_safe_session,
|
||||
safe_get,
|
||||
)
|
||||
|
||||
|
||||
INTERNAL_BODY = b"INTERNAL-ONLY-SECRET"
|
||||
|
||||
|
||||
class _InternalHandler(http.server.BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "text/plain")
|
||||
self.end_headers()
|
||||
self.wfile.write(INTERNAL_BODY)
|
||||
|
||||
def log_message(self, *args): # silence
|
||||
pass
|
||||
|
||||
|
||||
def _serve(handler):
|
||||
"""Start a localhost server on an ephemeral port; return (server, port)."""
|
||||
server = socketserver.TCPServer(("127.0.0.1", 0), handler)
|
||||
port = server.server_address[1]
|
||||
threading.Thread(target=server.serve_forever, daemon=True).start()
|
||||
return server, port
|
||||
|
||||
|
||||
class TestRedirectRevalidation:
|
||||
"""Layer 1: validate_url runs on every send, including each redirect hop.
|
||||
|
||||
``requests.Session.send`` calls ``adapter.send`` once per redirect hop, so
|
||||
re-validating in ``send`` is what blocks a 302 to an internal target.
|
||||
"""
|
||||
|
||||
def test_adapter_revalidates_before_any_network_call(self, monkeypatch):
|
||||
calls: list[str] = []
|
||||
|
||||
def spy(url: str) -> str:
|
||||
calls.append(url)
|
||||
if "internal.target" in url:
|
||||
raise ValueError("URL resolves to private/reserved IP")
|
||||
return url
|
||||
|
||||
monkeypatch.setattr(safe_requests, "validate_url", spy)
|
||||
|
||||
adapter = SSRFProtectedAdapter()
|
||||
# Internal redirect target: send() must reject it before ever calling
|
||||
# the real transport (super().send is never reached).
|
||||
req = requests.Request("GET", "http://internal.target/").prepare()
|
||||
with pytest.raises(ValueError, match="private/reserved"):
|
||||
adapter.send(req)
|
||||
assert calls == ["http://internal.target/"]
|
||||
|
||||
def test_session_mounts_protected_adapter(self):
|
||||
session = create_safe_session()
|
||||
assert isinstance(session.get_adapter("http://x"), SSRFProtectedAdapter)
|
||||
assert isinstance(session.get_adapter("https://x"), SSRFProtectedAdapter)
|
||||
assert session.trust_env is False
|
||||
|
||||
def test_safe_get_ignores_environment_proxies(self, monkeypatch):
|
||||
"""Environment proxies must not route safe fetches around the safe pool."""
|
||||
monkeypatch.setenv("HTTP_PROXY", "http://127.0.0.1:9999")
|
||||
monkeypatch.setenv("HTTPS_PROXY", "http://127.0.0.1:9999")
|
||||
monkeypatch.setattr(safe_requests, "validate_url", lambda url: url)
|
||||
|
||||
def fail_proxy_manager(self, proxy, **proxy_kwargs):
|
||||
raise AssertionError("safe_get unexpectedly used an environment proxy")
|
||||
|
||||
def fake_send(self, request, **kwargs):
|
||||
assert kwargs["proxies"] == {}
|
||||
response = requests.Response()
|
||||
response.status_code = 200
|
||||
response.url = request.url
|
||||
return response
|
||||
|
||||
monkeypatch.setattr(SSRFProtectedAdapter, "proxy_manager_for", fail_proxy_manager)
|
||||
monkeypatch.setattr(requests.adapters.HTTPAdapter, "send", fake_send)
|
||||
|
||||
response = safe_get("http://example.com/", timeout=10)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
class _FakeSock:
|
||||
def __init__(self, peer):
|
||||
self._peer = peer
|
||||
|
||||
def getpeername(self):
|
||||
return self._peer
|
||||
|
||||
|
||||
class TestConnectionPeerGuard:
|
||||
"""Layer 2: the connection rejects an internal peer IP at connect time.
|
||||
|
||||
This is what closes the validate-then-connect DNS-rebinding gap — the IP
|
||||
the socket actually connected to is the IP that gets checked, so a host
|
||||
that resolved public at validation time but connects internal is blocked.
|
||||
"""
|
||||
|
||||
def test_safe_get_blocks_direct_internal(self):
|
||||
# No network: validate_url rejects 127.0.0.1 at the URL layer first.
|
||||
with pytest.raises(ValueError, match="private/reserved"):
|
||||
safe_get("http://127.0.0.1:9/", timeout=10)
|
||||
|
||||
def test_assert_safe_peer_blocks_private(self):
|
||||
with pytest.raises(ValueError, match="private/reserved"):
|
||||
safe_requests._assert_safe_peer(_FakeSock(("127.0.0.1", 80)))
|
||||
|
||||
def test_assert_safe_peer_blocks_metadata(self):
|
||||
with pytest.raises(ValueError, match="private/reserved"):
|
||||
safe_requests._assert_safe_peer(_FakeSock(("169.254.169.254", 80)))
|
||||
|
||||
def test_assert_safe_peer_allows_public(self):
|
||||
# A public IP must not raise.
|
||||
safe_requests._assert_safe_peer(_FakeSock(("93.184.216.34", 80)))
|
||||
|
||||
def test_assert_safe_peer_respects_escape_hatch(self, monkeypatch):
|
||||
monkeypatch.setenv("CREWAI_TOOLS_ALLOW_UNSAFE_PATHS", "true")
|
||||
# No raise even for a private peer when the escape hatch is on.
|
||||
safe_requests._assert_safe_peer(_FakeSock(("127.0.0.1", 80)))
|
||||
|
||||
def test_connection_validates_peer_after_connect(self, monkeypatch):
|
||||
"""_SafeHTTPConnection.connect runs the peer guard after connecting."""
|
||||
conn = safe_requests._SafeHTTPConnection("example.com")
|
||||
|
||||
def fake_super_connect(self):
|
||||
# Simulate a rebind: we connected to an internal address.
|
||||
self.sock = _FakeSock(("127.0.0.1", 80))
|
||||
|
||||
monkeypatch.setattr(
|
||||
safe_requests.HTTPConnection, "connect", fake_super_connect
|
||||
)
|
||||
with pytest.raises(ValueError, match="private/reserved"):
|
||||
conn.connect()
|
||||
Reference in New Issue
Block a user