mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-30 23:02:50 +00:00
fix: add path and URL validation to RAG tools (#5310)
* fix: add path and URL validation to RAG tools Add validation utilities to prevent unauthorized file reads and SSRF when RAG tools accept LLM-controlled paths/URLs at runtime. Changes: - New crewai_tools.utilities.safe_path module with validate_file_path(), validate_directory_path(), and validate_url() - File paths validated against base directory (defaults to cwd). Resolves symlinks and ../ traversal. Rejects escape attempts. - URLs validated: file:// blocked entirely. HTTP/HTTPS resolves DNS and blocks private/reserved IPs (10.x, 172.16-31.x, 192.168.x, 127.x, 169.254.x, 0.0.0.0, ::1, fc00::/7). - Validation applied in RagTool.add() — catches all RAG search tools (JSON, CSV, PDF, TXT, DOCX, MDX, Directory, etc.) - Removed file:// scheme support from DataTypes.from_content() - CREWAI_TOOLS_ALLOW_UNSAFE_PATHS=true env var for backward compat - 27 tests covering traversal, symlinks, private IPs, cloud metadata, IPv6, escape hatch, and valid paths/URLs * fix: validate path/URL keyword args in RagTool.add() The original patch validated positional *args but left all keyword arguments (path=, file_path=, directory_path=, url=, website=, github_url=, youtube_url=) unvalidated, providing a trivial bypass for both path-traversal and SSRF checks. Applies validate_file_path() to path/file_path/directory_path kwargs and validate_url() to url/website/github_url/youtube_url kwargs before they reach the adapter. Adds a regression-test file covering all eight kwarg vectors plus the two existing positional-arg checks. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix: address CodeQL and review comments on RAG path/URL validation - Replace insecure tempfile.mktemp() with inline symlink target in test - Remove unused 'target' variable and unused tempfile import - Narrow broad except Exception: pass to only catch urlparse errors; validate_url ValueError now propagates instead of being silently swallowed - Fix ruff B904 (raise-without-from-inside-except) in safe_path.py - Fix ruff B007 (unused loop variable 'family') in safe_path.py - Use validate_directory_path in DirectorySearchTool.add() so the public utility is exercised in production code Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * style: fix ruff format + remaining lint issues * fix: resolve mypy type errors in RAG path/URL validation - Cast sockaddr[0] to str() to satisfy mypy (socket.getaddrinfo returns sockaddr where [0] is str but typed as str | int) - Remove now-unnecessary `type: ignore[assignment]` and `type: ignore[literal-required]` comments in rag_tool.py Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix: unroll dynamic TypedDict key loops to satisfy mypy literal-required Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * test: allow tmp paths in RAG data-type tests via CREWAI_TOOLS_ALLOW_UNSAFE_PATHS TemporaryDirectory creates files under /tmp/ which is outside CWD and is correctly blocked by the new path validation. These tests exercise data-type handling, not security, so add an autouse fixture that sets CREWAI_TOOLS_ALLOW_UNSAFE_PATHS=true for the whole file. Path/URL security is covered by test_rag_tool_path_validation.py. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * test: allow tmp paths in search-tool and rag_tool tests via CREWAI_TOOLS_ALLOW_UNSAFE_PATHS test_search_tools.py has tests for TXTSearchTool, CSVSearchTool, MDXSearchTool, JSONSearchTool, and DirectorySearchTool that create files under /tmp/ via tempfile, which is outside CWD and correctly blocked by the new path validation. rag_tool_test.py has one test that calls tool.add() with a TemporaryDirectory path. Add the same autouse allow_tmp_paths fixture used in test_rag_tool_add_data_type.py. Security is covered separately by test_rag_tool_path_validation.py. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * chore: update tool specifications * docs: document CodeInterpreterTool removal and RAG path/URL validation Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix: address three review comments on path/URL validation - safe_path._is_private_or_reserved: after unwrapping IPv4-mapped IPv6 to IPv4, only check against IPv4 networks to avoid TypeError when comparing an IPv4Address against IPv6Network objects. - safe_path.validate_file_path: handle filesystem-root base_dir ('/') by not appending os.sep when the base already ends with a separator, preventing the '//'-prefix bug. - rag_tool.add: path-detection heuristic now checks for both '/' and os.sep so forward-slash paths are caught on Windows as well as Unix. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix: remove unused _BLOCKED_NETWORKS variable after IPv4/IPv6 split * chore: update tool specifications --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
@@ -109,7 +109,7 @@ class DataTypes:
|
||||
if isinstance(content, str):
|
||||
try:
|
||||
url = urlparse(content)
|
||||
is_url = bool(url.scheme and url.netloc) or url.scheme == "file"
|
||||
is_url = bool(url.scheme in ("http", "https") and url.netloc)
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||
from crewai_tools.utilities.safe_path import validate_directory_path
|
||||
|
||||
|
||||
class FixedDirectorySearchToolSchema(BaseModel):
|
||||
@@ -37,6 +38,7 @@ class DirectorySearchTool(RagTool):
|
||||
self._generate_description()
|
||||
|
||||
def add(self, directory: str) -> None: # type: ignore[override]
|
||||
validate_directory_path(directory)
|
||||
super().add(directory, data_type=DataType.DIRECTORY)
|
||||
|
||||
def _run( # type: ignore[override]
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import os
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||
@@ -246,7 +247,83 @@ class RagTool(BaseTool):
|
||||
# Auto-detect type from extension
|
||||
rag_tool.add("path/to/document.pdf") # auto-detects PDF
|
||||
"""
|
||||
self.adapter.add(*args, **kwargs)
|
||||
# Validate file paths and URLs before adding to prevent
|
||||
# unauthorized file reads and SSRF.
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from crewai_tools.utilities.safe_path import validate_file_path, validate_url
|
||||
|
||||
def _check_url(value: str, label: str) -> None:
|
||||
try:
|
||||
validate_url(value)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Blocked unsafe {label}: {e}") from e
|
||||
|
||||
def _check_path(value: str, label: str) -> None:
|
||||
try:
|
||||
validate_file_path(value)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Blocked unsafe {label}: {e}") from e
|
||||
|
||||
validated_args: list[ContentItem] = []
|
||||
for arg in args:
|
||||
source_ref = (
|
||||
str(arg.get("source", arg.get("content", "")))
|
||||
if isinstance(arg, dict)
|
||||
else str(arg)
|
||||
)
|
||||
|
||||
# Check if it's a URL — only catch urlparse-specific errors here;
|
||||
# validate_url's ValueError must propagate so it is never silently bypassed.
|
||||
try:
|
||||
parsed = urlparse(source_ref)
|
||||
except (ValueError, AttributeError):
|
||||
parsed = None
|
||||
|
||||
if parsed is not None and parsed.scheme in ("http", "https", "file"):
|
||||
try:
|
||||
validate_url(source_ref)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Blocked unsafe URL: {e}") from e
|
||||
validated_args.append(arg)
|
||||
continue
|
||||
|
||||
# Check if it looks like a file path (not a plain text string).
|
||||
# Check both os.sep (backslash on Windows) and "/" so that
|
||||
# forward-slash paths like "sub/file.txt" are caught on all platforms.
|
||||
if (
|
||||
os.path.sep in source_ref
|
||||
or "/" in source_ref
|
||||
or source_ref.startswith(".")
|
||||
or os.path.isabs(source_ref)
|
||||
):
|
||||
try:
|
||||
validate_file_path(source_ref)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Blocked unsafe file path: {e}") from e
|
||||
|
||||
validated_args.append(arg)
|
||||
|
||||
# Validate keyword path/URL arguments — these are equally user-controlled
|
||||
# and must not bypass the checks applied to positional args.
|
||||
if "path" in kwargs and kwargs.get("path") is not None:
|
||||
_check_path(str(kwargs["path"]), "path")
|
||||
if "file_path" in kwargs and kwargs.get("file_path") is not None:
|
||||
_check_path(str(kwargs["file_path"]), "file_path")
|
||||
|
||||
if "directory_path" in kwargs and kwargs.get("directory_path") is not None:
|
||||
_check_path(str(kwargs["directory_path"]), "directory_path")
|
||||
|
||||
if "url" in kwargs and kwargs.get("url") is not None:
|
||||
_check_url(str(kwargs["url"]), "url")
|
||||
if "website" in kwargs and kwargs.get("website") is not None:
|
||||
_check_url(str(kwargs["website"]), "website")
|
||||
if "github_url" in kwargs and kwargs.get("github_url") is not None:
|
||||
_check_url(str(kwargs["github_url"]), "github_url")
|
||||
if "youtube_url" in kwargs and kwargs.get("youtube_url") is not None:
|
||||
_check_url(str(kwargs["youtube_url"]), "youtube_url")
|
||||
|
||||
self.adapter.add(*validated_args, **kwargs)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
|
||||
205
lib/crewai-tools/src/crewai_tools/utilities/safe_path.py
Normal file
205
lib/crewai-tools/src/crewai_tools/utilities/safe_path.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""Path and URL validation utilities for crewai-tools.
|
||||
|
||||
Provides validation for file paths and URLs to prevent unauthorized
|
||||
file access and server-side request forgery (SSRF) when tools accept
|
||||
user-controlled or LLM-controlled inputs at runtime.
|
||||
|
||||
Set CREWAI_TOOLS_ALLOW_UNSAFE_PATHS=true to bypass validation (not
|
||||
recommended for production).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_UNSAFE_PATHS_ENV = "CREWAI_TOOLS_ALLOW_UNSAFE_PATHS"
|
||||
|
||||
|
||||
def _is_escape_hatch_enabled() -> bool:
|
||||
"""Check if the unsafe paths escape hatch is enabled."""
|
||||
return os.environ.get(_UNSAFE_PATHS_ENV, "").lower() in ("true", "1", "yes")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# File path validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def validate_file_path(path: str, base_dir: str | None = None) -> str:
|
||||
"""Validate that a file path is safe to read.
|
||||
|
||||
Resolves symlinks and ``..`` components, then checks that the resolved
|
||||
path falls within *base_dir* (defaults to the current working directory).
|
||||
|
||||
Args:
|
||||
path: The file path to validate.
|
||||
base_dir: Allowed root directory. Defaults to ``os.getcwd()``.
|
||||
|
||||
Returns:
|
||||
The resolved, validated absolute path.
|
||||
|
||||
Raises:
|
||||
ValueError: If the path escapes the allowed directory.
|
||||
"""
|
||||
if _is_escape_hatch_enabled():
|
||||
logger.warning(
|
||||
"%s is enabled — skipping file path validation for: %s",
|
||||
_UNSAFE_PATHS_ENV,
|
||||
path,
|
||||
)
|
||||
return os.path.realpath(path)
|
||||
|
||||
if base_dir is None:
|
||||
base_dir = os.getcwd()
|
||||
|
||||
resolved_base = os.path.realpath(base_dir)
|
||||
resolved_path = os.path.realpath(
|
||||
os.path.join(resolved_base, path) if not os.path.isabs(path) else path
|
||||
)
|
||||
|
||||
# Ensure the resolved path is within the base directory.
|
||||
# When resolved_base already ends with a separator (e.g. the filesystem
|
||||
# root "/"), appending os.sep would double it ("//"), so use the base
|
||||
# as-is in that case.
|
||||
prefix = resolved_base if resolved_base.endswith(os.sep) else resolved_base + os.sep
|
||||
if not resolved_path.startswith(prefix) and resolved_path != resolved_base:
|
||||
raise ValueError(
|
||||
f"Path '{path}' resolves to '{resolved_path}' which is outside "
|
||||
f"the allowed directory '{resolved_base}'. "
|
||||
f"Set {_UNSAFE_PATHS_ENV}=true to bypass this check."
|
||||
)
|
||||
|
||||
return resolved_path
|
||||
|
||||
|
||||
def validate_directory_path(path: str, base_dir: str | None = None) -> str:
|
||||
"""Validate that a directory path is safe to read.
|
||||
|
||||
Same as :func:`validate_file_path` but also checks that the path
|
||||
is an existing directory.
|
||||
|
||||
Args:
|
||||
path: The directory path to validate.
|
||||
base_dir: Allowed root directory. Defaults to ``os.getcwd()``.
|
||||
|
||||
Returns:
|
||||
The resolved, validated absolute path.
|
||||
|
||||
Raises:
|
||||
ValueError: If the path escapes the allowed directory or is not a directory.
|
||||
"""
|
||||
validated = validate_file_path(path, base_dir)
|
||||
if not os.path.isdir(validated):
|
||||
raise ValueError(f"Path '{validated}' is not a directory.")
|
||||
return validated
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# URL validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Private and reserved IP ranges that should not be accessed
|
||||
_BLOCKED_IPV4_NETWORKS = [
|
||||
ipaddress.ip_network("10.0.0.0/8"),
|
||||
ipaddress.ip_network("172.16.0.0/12"),
|
||||
ipaddress.ip_network("192.168.0.0/16"),
|
||||
ipaddress.ip_network("127.0.0.0/8"),
|
||||
ipaddress.ip_network("169.254.0.0/16"), # Link-local / cloud metadata
|
||||
ipaddress.ip_network("0.0.0.0/32"),
|
||||
]
|
||||
|
||||
_BLOCKED_IPV6_NETWORKS = [
|
||||
ipaddress.ip_network("::1/128"),
|
||||
ipaddress.ip_network("::/128"),
|
||||
ipaddress.ip_network("fc00::/7"), # Unique local addresses
|
||||
ipaddress.ip_network("fe80::/10"), # Link-local IPv6
|
||||
]
|
||||
|
||||
|
||||
def _is_private_or_reserved(ip_str: str) -> bool:
|
||||
"""Check if an IP address is private, reserved, or otherwise unsafe."""
|
||||
try:
|
||||
addr = ipaddress.ip_address(ip_str)
|
||||
# Unwrap IPv4-mapped IPv6 addresses (e.g., ::ffff:127.0.0.1) to IPv4
|
||||
# so they are only checked against IPv4 networks (avoids TypeError when
|
||||
# an IPv4Address is compared against an IPv6Network).
|
||||
if isinstance(addr, ipaddress.IPv6Address) and addr.ipv4_mapped:
|
||||
addr = addr.ipv4_mapped
|
||||
networks = (
|
||||
_BLOCKED_IPV4_NETWORKS
|
||||
if isinstance(addr, ipaddress.IPv4Address)
|
||||
else _BLOCKED_IPV6_NETWORKS
|
||||
)
|
||||
return any(addr in network for network in networks)
|
||||
except ValueError:
|
||||
return True # If we can't parse, block it
|
||||
|
||||
|
||||
def validate_url(url: str) -> str:
|
||||
"""Validate that a URL is safe to fetch.
|
||||
|
||||
Blocks ``file://`` scheme entirely. For ``http``/``https``, resolves
|
||||
DNS and checks that the target IP is not private or reserved (prevents
|
||||
SSRF to internal services and cloud metadata endpoints).
|
||||
|
||||
Args:
|
||||
url: The URL to validate.
|
||||
|
||||
Returns:
|
||||
The validated URL string.
|
||||
|
||||
Raises:
|
||||
ValueError: If the URL uses a blocked scheme or resolves to a
|
||||
private/reserved IP address.
|
||||
"""
|
||||
if _is_escape_hatch_enabled():
|
||||
logger.warning(
|
||||
"%s is enabled — skipping URL validation for: %s",
|
||||
_UNSAFE_PATHS_ENV,
|
||||
url,
|
||||
)
|
||||
return url
|
||||
|
||||
parsed = urlparse(url)
|
||||
|
||||
# Block file:// scheme
|
||||
if parsed.scheme == "file":
|
||||
raise ValueError(
|
||||
f"file:// URLs are not allowed: '{url}'. "
|
||||
f"Use a file path instead, or set {_UNSAFE_PATHS_ENV}=true to bypass."
|
||||
)
|
||||
|
||||
# Only allow http and https
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise ValueError(
|
||||
f"URL scheme '{parsed.scheme}' is not allowed. Only http and https are supported."
|
||||
)
|
||||
|
||||
if not parsed.hostname:
|
||||
raise ValueError(f"URL has no hostname: '{url}'")
|
||||
|
||||
# Resolve DNS and check IPs
|
||||
try:
|
||||
addrinfos = socket.getaddrinfo(
|
||||
parsed.hostname, parsed.port or (443 if parsed.scheme == "https" else 80)
|
||||
)
|
||||
except socket.gaierror as exc:
|
||||
raise ValueError(f"Could not resolve hostname: '{parsed.hostname}'") from exc
|
||||
|
||||
for _family, _, _, _, sockaddr in addrinfos:
|
||||
ip_str = str(sockaddr[0])
|
||||
if _is_private_or_reserved(ip_str):
|
||||
raise ValueError(
|
||||
f"URL '{url}' resolves to private/reserved IP {ip_str}. "
|
||||
f"Access to internal networks is not allowed. "
|
||||
f"Set {_UNSAFE_PATHS_ENV}=true to bypass."
|
||||
)
|
||||
|
||||
return url
|
||||
@@ -3,10 +3,21 @@ from tempfile import TemporaryDirectory
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
|
||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def allow_tmp_paths(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Allow absolute paths outside CWD (e.g. /tmp/) for these RagTool tests.
|
||||
|
||||
Path validation is tested separately in test_rag_tool_path_validation.py.
|
||||
"""
|
||||
monkeypatch.setenv("CREWAI_TOOLS_ALLOW_UNSAFE_PATHS", "true")
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.get_rag_client")
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
|
||||
def test_rag_tool_initialization(
|
||||
|
||||
@@ -10,6 +10,15 @@ from crewai_tools.rag.data_types import DataType
|
||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def allow_tmp_paths(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Allow absolute paths outside CWD (e.g. /tmp/) for these data-type tests.
|
||||
|
||||
Path validation is tested separately in test_rag_tool_path_validation.py.
|
||||
"""
|
||||
monkeypatch.setenv("CREWAI_TOOLS_ALLOW_UNSAFE_PATHS", "true")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_rag_client() -> MagicMock:
|
||||
"""Create a mock RAG client for testing."""
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
"""Tests for path and URL validation in RagTool.add() — both positional and keyword args."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_rag_client() -> MagicMock:
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||
mock_client.add_documents = MagicMock(return_value=None)
|
||||
mock_client.search = MagicMock(return_value=[])
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(mock_rag_client: MagicMock) -> RagTool:
|
||||
with (
|
||||
patch("crewai_tools.adapters.crewai_rag_adapter.get_rag_client", return_value=mock_rag_client),
|
||||
patch("crewai_tools.adapters.crewai_rag_adapter.create_client", return_value=mock_rag_client),
|
||||
):
|
||||
return RagTool()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Positional arg validation (existing behaviour, regression guard)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPositionalArgValidation:
|
||||
def test_blocks_traversal_in_positional_arg(self, tool):
|
||||
with pytest.raises(ValueError, match="Blocked unsafe"):
|
||||
tool.add("../../etc/passwd")
|
||||
|
||||
def test_blocks_file_url_in_positional_arg(self, tool):
|
||||
with pytest.raises(ValueError, match="Blocked unsafe"):
|
||||
tool.add("file:///etc/passwd")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Keyword arg validation (the newly fixed gap)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestKwargPathValidation:
|
||||
def test_blocks_traversal_via_path_kwarg(self, tool):
|
||||
with pytest.raises(ValueError, match="Blocked unsafe path"):
|
||||
tool.add(path="../../etc/passwd")
|
||||
|
||||
def test_blocks_traversal_via_file_path_kwarg(self, tool):
|
||||
with pytest.raises(ValueError, match="Blocked unsafe file_path"):
|
||||
tool.add(file_path="/etc/passwd")
|
||||
|
||||
def test_blocks_traversal_via_directory_path_kwarg(self, tool):
|
||||
with pytest.raises(ValueError, match="Blocked unsafe directory_path"):
|
||||
tool.add(directory_path="../../sensitive_dir")
|
||||
|
||||
def test_blocks_file_url_via_url_kwarg(self, tool):
|
||||
with pytest.raises(ValueError, match="Blocked unsafe url"):
|
||||
tool.add(url="file:///etc/passwd")
|
||||
|
||||
def test_blocks_private_ip_via_url_kwarg(self, tool):
|
||||
with pytest.raises(ValueError, match="Blocked unsafe url"):
|
||||
tool.add(url="http://169.254.169.254/latest/meta-data/")
|
||||
|
||||
def test_blocks_private_ip_via_website_kwarg(self, tool):
|
||||
with pytest.raises(ValueError, match="Blocked unsafe website"):
|
||||
tool.add(website="http://192.168.1.1/")
|
||||
|
||||
def test_blocks_file_url_via_github_url_kwarg(self, tool):
|
||||
with pytest.raises(ValueError, match="Blocked unsafe github_url"):
|
||||
tool.add(github_url="file:///etc/passwd")
|
||||
|
||||
def test_blocks_file_url_via_youtube_url_kwarg(self, tool):
|
||||
with pytest.raises(ValueError, match="Blocked unsafe youtube_url"):
|
||||
tool.add(youtube_url="file:///etc/passwd")
|
||||
|
||||
@@ -23,6 +23,15 @@ from crewai_tools.tools.rag.rag_tool import Adapter
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def allow_tmp_paths(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Allow absolute paths outside CWD (e.g. /tmp/) for these search-tool tests.
|
||||
|
||||
Path validation is tested separately in test_rag_tool_path_validation.py.
|
||||
"""
|
||||
monkeypatch.setenv("CREWAI_TOOLS_ALLOW_UNSAFE_PATHS", "true")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_adapter():
|
||||
mock_adapter = MagicMock(spec=Adapter)
|
||||
|
||||
0
lib/crewai-tools/tests/utilities/__init__.py
Normal file
0
lib/crewai-tools/tests/utilities/__init__.py
Normal file
170
lib/crewai-tools/tests/utilities/test_safe_path.py
Normal file
170
lib/crewai-tools/tests/utilities/test_safe_path.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""Tests for path and URL validation utilities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai_tools.utilities.safe_path import (
|
||||
validate_directory_path,
|
||||
validate_file_path,
|
||||
validate_url,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# File path validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestValidateFilePath:
|
||||
"""Tests for validate_file_path."""
|
||||
|
||||
def test_valid_relative_path(self, tmp_path):
|
||||
"""Normal relative path within the base directory."""
|
||||
(tmp_path / "data.json").touch()
|
||||
result = validate_file_path("data.json", str(tmp_path))
|
||||
assert result == str(tmp_path / "data.json")
|
||||
|
||||
def test_valid_nested_path(self, tmp_path):
|
||||
"""Nested path within base directory."""
|
||||
(tmp_path / "sub").mkdir()
|
||||
(tmp_path / "sub" / "file.txt").touch()
|
||||
result = validate_file_path("sub/file.txt", str(tmp_path))
|
||||
assert result == str(tmp_path / "sub" / "file.txt")
|
||||
|
||||
def test_rejects_dotdot_traversal(self, tmp_path):
|
||||
"""Reject ../ traversal that escapes base_dir."""
|
||||
with pytest.raises(ValueError, match="outside the allowed directory"):
|
||||
validate_file_path("../../etc/passwd", str(tmp_path))
|
||||
|
||||
def test_rejects_absolute_path_outside_base(self, tmp_path):
|
||||
"""Reject absolute path outside base_dir."""
|
||||
with pytest.raises(ValueError, match="outside the allowed directory"):
|
||||
validate_file_path("/etc/passwd", str(tmp_path))
|
||||
|
||||
def test_allows_absolute_path_inside_base(self, tmp_path):
|
||||
"""Allow absolute path that's inside base_dir."""
|
||||
(tmp_path / "ok.txt").touch()
|
||||
result = validate_file_path(str(tmp_path / "ok.txt"), str(tmp_path))
|
||||
assert result == str(tmp_path / "ok.txt")
|
||||
|
||||
def test_rejects_symlink_escape(self, tmp_path):
|
||||
"""Reject symlinks that point outside base_dir."""
|
||||
link = tmp_path / "sneaky_link"
|
||||
# Create a symlink pointing to /etc/passwd
|
||||
os.symlink("/etc/passwd", str(link))
|
||||
with pytest.raises(ValueError, match="outside the allowed directory"):
|
||||
validate_file_path("sneaky_link", str(tmp_path))
|
||||
|
||||
def test_defaults_to_cwd(self):
|
||||
"""When no base_dir is given, use cwd."""
|
||||
cwd = os.getcwd()
|
||||
# A file in cwd should be valid
|
||||
result = validate_file_path(".", None)
|
||||
assert result == os.path.realpath(cwd)
|
||||
|
||||
def test_escape_hatch(self, tmp_path, monkeypatch):
|
||||
"""CREWAI_TOOLS_ALLOW_UNSAFE_PATHS=true bypasses validation."""
|
||||
monkeypatch.setenv("CREWAI_TOOLS_ALLOW_UNSAFE_PATHS", "true")
|
||||
# This would normally be rejected
|
||||
result = validate_file_path("/etc/passwd", str(tmp_path))
|
||||
assert result == os.path.realpath("/etc/passwd")
|
||||
|
||||
|
||||
class TestValidateDirectoryPath:
|
||||
"""Tests for validate_directory_path."""
|
||||
|
||||
def test_valid_directory(self, tmp_path):
|
||||
(tmp_path / "subdir").mkdir()
|
||||
result = validate_directory_path("subdir", str(tmp_path))
|
||||
assert result == str(tmp_path / "subdir")
|
||||
|
||||
def test_rejects_file_as_directory(self, tmp_path):
|
||||
(tmp_path / "file.txt").touch()
|
||||
with pytest.raises(ValueError, match="not a directory"):
|
||||
validate_directory_path("file.txt", str(tmp_path))
|
||||
|
||||
def test_rejects_traversal(self, tmp_path):
|
||||
with pytest.raises(ValueError, match="outside the allowed directory"):
|
||||
validate_directory_path("../../", str(tmp_path))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# URL validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestValidateUrl:
|
||||
"""Tests for validate_url."""
|
||||
|
||||
def test_valid_https_url(self):
|
||||
"""Normal HTTPS URL should pass."""
|
||||
result = validate_url("https://example.com/data.json")
|
||||
assert result == "https://example.com/data.json"
|
||||
|
||||
def test_valid_http_url(self):
|
||||
"""Normal HTTP URL should pass."""
|
||||
result = validate_url("http://example.com/api")
|
||||
assert result == "http://example.com/api"
|
||||
|
||||
def test_blocks_file_scheme(self):
|
||||
"""file:// URLs must be blocked."""
|
||||
with pytest.raises(ValueError, match="file:// URLs are not allowed"):
|
||||
validate_url("file:///etc/passwd")
|
||||
|
||||
def test_blocks_file_scheme_with_host(self):
|
||||
with pytest.raises(ValueError, match="file:// URLs are not allowed"):
|
||||
validate_url("file://localhost/etc/shadow")
|
||||
|
||||
def test_blocks_localhost(self):
|
||||
"""localhost must be blocked (resolves to 127.0.0.1)."""
|
||||
with pytest.raises(ValueError, match="private/reserved IP"):
|
||||
validate_url("http://localhost/admin")
|
||||
|
||||
def test_blocks_127_0_0_1(self):
|
||||
with pytest.raises(ValueError, match="private/reserved IP"):
|
||||
validate_url("http://127.0.0.1/admin")
|
||||
|
||||
def test_blocks_cloud_metadata(self):
|
||||
"""AWS/GCP/Azure metadata endpoint must be blocked."""
|
||||
with pytest.raises(ValueError, match="private/reserved IP"):
|
||||
validate_url("http://169.254.169.254/latest/meta-data/")
|
||||
|
||||
def test_blocks_private_10_range(self):
|
||||
with pytest.raises(ValueError, match="private/reserved IP"):
|
||||
validate_url("http://10.0.0.1/internal")
|
||||
|
||||
def test_blocks_private_172_range(self):
|
||||
with pytest.raises(ValueError, match="private/reserved IP"):
|
||||
validate_url("http://172.16.0.1/internal")
|
||||
|
||||
def test_blocks_private_192_range(self):
|
||||
with pytest.raises(ValueError, match="private/reserved IP"):
|
||||
validate_url("http://192.168.1.1/router")
|
||||
|
||||
def test_blocks_zero_address(self):
|
||||
with pytest.raises(ValueError, match="private/reserved IP"):
|
||||
validate_url("http://0.0.0.0/")
|
||||
|
||||
def test_blocks_ipv6_localhost(self):
|
||||
with pytest.raises(ValueError, match="private/reserved IP"):
|
||||
validate_url("http://[::1]/admin")
|
||||
|
||||
def test_blocks_ftp_scheme(self):
|
||||
with pytest.raises(ValueError, match="not allowed"):
|
||||
validate_url("ftp://example.com/file")
|
||||
|
||||
def test_blocks_empty_hostname(self):
|
||||
with pytest.raises(ValueError, match="no hostname"):
|
||||
validate_url("http:///path")
|
||||
|
||||
def test_blocks_unresolvable_host(self):
|
||||
with pytest.raises(ValueError, match="Could not resolve"):
|
||||
validate_url("http://this-host-definitely-does-not-exist-abc123.com/")
|
||||
|
||||
def test_escape_hatch(self, monkeypatch):
|
||||
"""CREWAI_TOOLS_ALLOW_UNSAFE_PATHS=true bypasses URL validation."""
|
||||
monkeypatch.setenv("CREWAI_TOOLS_ALLOW_UNSAFE_PATHS", "true")
|
||||
# file:// would normally be blocked
|
||||
result = validate_url("file:///etc/passwd")
|
||||
assert result == "file:///etc/passwd"
|
||||
@@ -609,7 +609,6 @@ def env() -> None:
|
||||
@env.command("view")
|
||||
def env_view() -> None:
|
||||
"""View tracing-related environment variables."""
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from rich.console import Console
|
||||
@@ -738,7 +737,6 @@ def traces_disable() -> None:
|
||||
@traces.command("status")
|
||||
def traces_status() -> None:
|
||||
"""Show current trace collection status."""
|
||||
import os
|
||||
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
from collections.abc import Coroutine
|
||||
import concurrent.futures
|
||||
import contextvars
|
||||
import inspect
|
||||
from typing import Any
|
||||
|
||||
Reference in New Issue
Block a user