mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-07 19:48:13 +00:00
fix: address CodeQL and review comments on RAG path/URL validation
- Replace insecure tempfile.mktemp() with inline symlink target in test - Remove unused 'target' variable and unused tempfile import - Narrow broad except Exception: pass to only catch urlparse errors; validate_url ValueError now propagates instead of being silently swallowed - Fix ruff B904 (raise-without-from-inside-except) in safe_path.py - Fix ruff B007 (unused loop variable 'family') in safe_path.py - Use validate_directory_path in DirectorySearchTool.add() so the public utility is exercised in production code Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -4,6 +4,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||
from crewai_tools.utilities.safe_path import validate_directory_path
|
||||
|
||||
|
||||
class FixedDirectorySearchToolSchema(BaseModel):
|
||||
@@ -37,6 +38,7 @@ class DirectorySearchTool(RagTool):
|
||||
self._generate_description()
|
||||
|
||||
def add(self, directory: str) -> None: # type: ignore[override]
|
||||
validate_directory_path(directory)
|
||||
super().add(directory, data_type=DataType.DIRECTORY)
|
||||
|
||||
def _run( # type: ignore[override]
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
import os
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||
@@ -249,9 +249,10 @@ class RagTool(BaseTool):
|
||||
"""
|
||||
# Validate file paths and URLs before adding to prevent
|
||||
# unauthorized file reads and SSRF.
|
||||
from crewai_tools.utilities.safe_path import validate_file_path, validate_url
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from crewai_tools.utilities.safe_path import validate_file_path, validate_url
|
||||
|
||||
def _check_url(value: str, label: str) -> None:
|
||||
try:
|
||||
validate_url(value)
|
||||
@@ -268,17 +269,20 @@ class RagTool(BaseTool):
|
||||
for arg in args:
|
||||
source_ref = str(arg.get("source", arg.get("content", ""))) if isinstance(arg, dict) else str(arg)
|
||||
|
||||
# Check if it's a URL
|
||||
# Check if it's a URL — only catch urlparse-specific errors here;
|
||||
# validate_url's ValueError must propagate so it is never silently bypassed.
|
||||
try:
|
||||
parsed = urlparse(source_ref)
|
||||
if parsed.scheme in ("http", "https", "file"):
|
||||
except (ValueError, AttributeError):
|
||||
parsed = None # type: ignore[assignment]
|
||||
|
||||
if parsed is not None and parsed.scheme in ("http", "https", "file"):
|
||||
try:
|
||||
validate_url(source_ref)
|
||||
validated_args.append(arg)
|
||||
continue
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Blocked unsafe URL: {e}") from e
|
||||
except Exception:
|
||||
pass
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Blocked unsafe URL: {e}") from e
|
||||
validated_args.append(arg)
|
||||
continue
|
||||
|
||||
# Check if it looks like a file path (not a plain text string)
|
||||
if os.path.sep in source_ref or source_ref.startswith(".") or os.path.isabs(source_ref):
|
||||
|
||||
@@ -16,6 +16,7 @@ import os
|
||||
import socket
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_UNSAFE_PATHS_ENV = "CREWAI_TOOLS_ALLOW_UNSAFE_PATHS"
|
||||
@@ -172,10 +173,10 @@ def validate_url(url: str) -> str:
|
||||
addrinfos = socket.getaddrinfo(
|
||||
parsed.hostname, parsed.port or (443 if parsed.scheme == "https" else 80)
|
||||
)
|
||||
except socket.gaierror:
|
||||
raise ValueError(f"Could not resolve hostname: '{parsed.hostname}'")
|
||||
except socket.gaierror as exc:
|
||||
raise ValueError(f"Could not resolve hostname: '{parsed.hostname}'") from exc
|
||||
|
||||
for family, _, _, _, sockaddr in addrinfos:
|
||||
for _family, _, _, _, sockaddr in addrinfos:
|
||||
ip_str = sockaddr[0]
|
||||
if _is_private_or_reserved(ip_str):
|
||||
raise ValueError(
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -52,7 +51,6 @@ class TestValidateFilePath:
|
||||
|
||||
def test_rejects_symlink_escape(self, tmp_path):
|
||||
"""Reject symlinks that point outside base_dir."""
|
||||
target = tempfile.mktemp() # path that doesn't exist
|
||||
link = tmp_path / "sneaky_link"
|
||||
# Create a symlink pointing to /etc/passwd
|
||||
os.symlink("/etc/passwd", str(link))
|
||||
|
||||
Reference in New Issue
Block a user