fix: address CodeQL and review comments on RAG path/URL validation

- Replace insecure tempfile.mktemp() with inline symlink target in test
- Remove unused 'target' variable and unused tempfile import
- Narrow the broad `except Exception: pass` so that it only catches urlparse
  errors; a ValueError raised by validate_url now propagates instead of being
  silently swallowed
- Fix ruff B904 (raise-without-from-inside-except) in safe_path.py
- Fix ruff B007 (unused loop variable 'family') in safe_path.py
- Use validate_directory_path in DirectorySearchTool.add() so the
  public utility is exercised in production code

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Alex
2026-04-06 23:12:37 -07:00
parent d435619aba
commit 57e23b4dca
4 changed files with 20 additions and 15 deletions

View File

@@ -4,6 +4,7 @@ from pydantic import BaseModel, Field
from crewai_tools.rag.data_types import DataType
from crewai_tools.tools.rag.rag_tool import RagTool
from crewai_tools.utilities.safe_path import validate_directory_path
class FixedDirectorySearchToolSchema(BaseModel):
@@ -37,6 +38,7 @@ class DirectorySearchTool(RagTool):
self._generate_description()
# Add a directory to this tool's RAG sources.
# The path is passed through validate_directory_path() first, then handed to
# the parent class's add() tagged as DataType.DIRECTORY.
# NOTE(review): presumably validate_directory_path raises on an unsafe path,
# so the parent add() is never reached for rejected input — confirm in
# crewai_tools.utilities.safe_path.
def add(self, directory: str) -> None: # type: ignore[override]
# The return value of the validator is discarded; the original `directory`
# string is what gets forwarded below.
validate_directory_path(directory)
super().add(directory, data_type=DataType.DIRECTORY)
def _run( # type: ignore[override]

View File

@@ -1,5 +1,5 @@
import os
from abc import ABC, abstractmethod
import os
from typing import Any, Literal, cast
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
@@ -249,9 +249,10 @@ class RagTool(BaseTool):
"""
# Validate file paths and URLs before adding to prevent
# unauthorized file reads and SSRF.
from crewai_tools.utilities.safe_path import validate_file_path, validate_url
from urllib.parse import urlparse
from crewai_tools.utilities.safe_path import validate_file_path, validate_url
def _check_url(value: str, label: str) -> None:
try:
validate_url(value)
@@ -268,17 +269,20 @@ class RagTool(BaseTool):
for arg in args:
source_ref = str(arg.get("source", arg.get("content", ""))) if isinstance(arg, dict) else str(arg)
# Check if it's a URL
# Check if it's a URL — only catch urlparse-specific errors here;
# validate_url's ValueError must propagate so it is never silently bypassed.
try:
parsed = urlparse(source_ref)
if parsed.scheme in ("http", "https", "file"):
except (ValueError, AttributeError):
parsed = None # type: ignore[assignment]
if parsed is not None and parsed.scheme in ("http", "https", "file"):
try:
validate_url(source_ref)
validated_args.append(arg)
continue
except ValueError as e:
raise ValueError(f"Blocked unsafe URL: {e}") from e
except Exception:
pass
except ValueError as e:
raise ValueError(f"Blocked unsafe URL: {e}") from e
validated_args.append(arg)
continue
# Check if it looks like a file path (not a plain text string)
if os.path.sep in source_ref or source_ref.startswith(".") or os.path.isabs(source_ref):

View File

@@ -16,6 +16,7 @@ import os
import socket
from urllib.parse import urlparse
logger = logging.getLogger(__name__)
_UNSAFE_PATHS_ENV = "CREWAI_TOOLS_ALLOW_UNSAFE_PATHS"
@@ -172,10 +173,10 @@ def validate_url(url: str) -> str:
addrinfos = socket.getaddrinfo(
parsed.hostname, parsed.port or (443 if parsed.scheme == "https" else 80)
)
except socket.gaierror:
raise ValueError(f"Could not resolve hostname: '{parsed.hostname}'")
except socket.gaierror as exc:
raise ValueError(f"Could not resolve hostname: '{parsed.hostname}'") from exc
for family, _, _, _, sockaddr in addrinfos:
for _family, _, _, _, sockaddr in addrinfos:
ip_str = sockaddr[0]
if _is_private_or_reserved(ip_str):
raise ValueError(

View File

@@ -3,7 +3,6 @@
from __future__ import annotations
import os
import tempfile
import pytest
@@ -52,7 +51,6 @@ class TestValidateFilePath:
def test_rejects_symlink_escape(self, tmp_path):
"""Reject symlinks that point outside base_dir."""
target = tempfile.mktemp() # path that doesn't exist
link = tmp_path / "sneaky_link"
# Create a symlink pointing to /etc/passwd
os.symlink("/etc/passwd", str(link))