mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-30 23:02:50 +00:00
fix: address CodeQL and review comments on RAG path/URL validation
- Replace insecure `tempfile.mktemp()` with inline symlink target in test
- Remove unused `target` variable and unused `tempfile` import
- Narrow broad `except Exception: pass` to only catch `urlparse` errors; `validate_url` `ValueError` now propagates instead of being silently swallowed
- Fix ruff B904 (raise-without-from-inside-except) in safe_path.py
- Fix ruff B007 (unused loop variable `family`) in safe_path.py
- Use `validate_directory_path` in `DirectorySearchTool.add()` so the public utility is exercised in production code

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -4,6 +4,7 @@ from pydantic import BaseModel, Field
|
|||||||
|
|
||||||
from crewai_tools.rag.data_types import DataType
|
from crewai_tools.rag.data_types import DataType
|
||||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||||
|
from crewai_tools.utilities.safe_path import validate_directory_path
|
||||||
|
|
||||||
|
|
||||||
class FixedDirectorySearchToolSchema(BaseModel):
|
class FixedDirectorySearchToolSchema(BaseModel):
|
||||||
@@ -37,6 +38,7 @@ class DirectorySearchTool(RagTool):
|
|||||||
self._generate_description()
|
self._generate_description()
|
||||||
|
|
||||||
def add(self, directory: str) -> None: # type: ignore[override]
|
def add(self, directory: str) -> None: # type: ignore[override]
|
||||||
|
validate_directory_path(directory)
|
||||||
super().add(directory, data_type=DataType.DIRECTORY)
|
super().add(directory, data_type=DataType.DIRECTORY)
|
||||||
|
|
||||||
def _run( # type: ignore[override]
|
def _run( # type: ignore[override]
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import os
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
import os
|
||||||
from typing import Any, Literal, cast
|
from typing import Any, Literal, cast
|
||||||
|
|
||||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||||
@@ -249,9 +249,10 @@ class RagTool(BaseTool):
|
|||||||
"""
|
"""
|
||||||
# Validate file paths and URLs before adding to prevent
|
# Validate file paths and URLs before adding to prevent
|
||||||
# unauthorized file reads and SSRF.
|
# unauthorized file reads and SSRF.
|
||||||
from crewai_tools.utilities.safe_path import validate_file_path, validate_url
|
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from crewai_tools.utilities.safe_path import validate_file_path, validate_url
|
||||||
|
|
||||||
def _check_url(value: str, label: str) -> None:
|
def _check_url(value: str, label: str) -> None:
|
||||||
try:
|
try:
|
||||||
validate_url(value)
|
validate_url(value)
|
||||||
@@ -268,17 +269,20 @@ class RagTool(BaseTool):
|
|||||||
for arg in args:
|
for arg in args:
|
||||||
source_ref = str(arg.get("source", arg.get("content", ""))) if isinstance(arg, dict) else str(arg)
|
source_ref = str(arg.get("source", arg.get("content", ""))) if isinstance(arg, dict) else str(arg)
|
||||||
|
|
||||||
# Check if it's a URL
|
# Check if it's a URL — only catch urlparse-specific errors here;
|
||||||
|
# validate_url's ValueError must propagate so it is never silently bypassed.
|
||||||
try:
|
try:
|
||||||
parsed = urlparse(source_ref)
|
parsed = urlparse(source_ref)
|
||||||
if parsed.scheme in ("http", "https", "file"):
|
except (ValueError, AttributeError):
|
||||||
|
parsed = None # type: ignore[assignment]
|
||||||
|
|
||||||
|
if parsed is not None and parsed.scheme in ("http", "https", "file"):
|
||||||
|
try:
|
||||||
validate_url(source_ref)
|
validate_url(source_ref)
|
||||||
validated_args.append(arg)
|
except ValueError as e:
|
||||||
continue
|
raise ValueError(f"Blocked unsafe URL: {e}") from e
|
||||||
except ValueError as e:
|
validated_args.append(arg)
|
||||||
raise ValueError(f"Blocked unsafe URL: {e}") from e
|
continue
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Check if it looks like a file path (not a plain text string)
|
# Check if it looks like a file path (not a plain text string)
|
||||||
if os.path.sep in source_ref or source_ref.startswith(".") or os.path.isabs(source_ref):
|
if os.path.sep in source_ref or source_ref.startswith(".") or os.path.isabs(source_ref):
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import os
|
|||||||
import socket
|
import socket
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_UNSAFE_PATHS_ENV = "CREWAI_TOOLS_ALLOW_UNSAFE_PATHS"
|
_UNSAFE_PATHS_ENV = "CREWAI_TOOLS_ALLOW_UNSAFE_PATHS"
|
||||||
@@ -172,10 +173,10 @@ def validate_url(url: str) -> str:
|
|||||||
addrinfos = socket.getaddrinfo(
|
addrinfos = socket.getaddrinfo(
|
||||||
parsed.hostname, parsed.port or (443 if parsed.scheme == "https" else 80)
|
parsed.hostname, parsed.port or (443 if parsed.scheme == "https" else 80)
|
||||||
)
|
)
|
||||||
except socket.gaierror:
|
except socket.gaierror as exc:
|
||||||
raise ValueError(f"Could not resolve hostname: '{parsed.hostname}'")
|
raise ValueError(f"Could not resolve hostname: '{parsed.hostname}'") from exc
|
||||||
|
|
||||||
for family, _, _, _, sockaddr in addrinfos:
|
for _family, _, _, _, sockaddr in addrinfos:
|
||||||
ip_str = sockaddr[0]
|
ip_str = sockaddr[0]
|
||||||
if _is_private_or_reserved(ip_str):
|
if _is_private_or_reserved(ip_str):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import tempfile
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -52,7 +51,6 @@ class TestValidateFilePath:
|
|||||||
|
|
||||||
def test_rejects_symlink_escape(self, tmp_path):
|
def test_rejects_symlink_escape(self, tmp_path):
|
||||||
"""Reject symlinks that point outside base_dir."""
|
"""Reject symlinks that point outside base_dir."""
|
||||||
target = tempfile.mktemp() # path that doesn't exist
|
|
||||||
link = tmp_path / "sneaky_link"
|
link = tmp_path / "sneaky_link"
|
||||||
# Create a symlink pointing to /etc/passwd
|
# Create a symlink pointing to /etc/passwd
|
||||||
os.symlink("/etc/passwd", str(link))
|
os.symlink("/etc/passwd", str(link))
|
||||||
|
|||||||
Reference in New Issue
Block a user