mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-02 13:48:09 +00:00
fix: add SSRF protection to FileUrl in crewai-files
- Add security.py module with DNS-resolving URL validation that blocks private/reserved IP ranges (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, 127.0.0.0/8, 169.254.0.0/16, 0.0.0.0/32) and IPv6 equivalents
- Update FileUrl._validate_url() to use the new validate_url() function
- Disable follow_redirects in read()/aread() to prevent redirect-based SSRF
- Add CREWAI_FILES_ALLOW_UNSAFE_URLS env var escape hatch
- Add comprehensive tests for SSRF protection

Fixes #5843

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
113
lib/crewai-files/src/crewai_files/core/security.py
Normal file
113
lib/crewai-files/src/crewai_files/core/security.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""URL validation utilities for crewai-files.
|
||||
|
||||
Provides SSRF protection by resolving DNS and checking that target IPs
|
||||
are not private or reserved before allowing HTTP requests.
|
||||
|
||||
Set CREWAI_FILES_ALLOW_UNSAFE_URLS=true to bypass validation (not
|
||||
recommended for production).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
from urllib.parse import urlparse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_UNSAFE_URLS_ENV = "CREWAI_FILES_ALLOW_UNSAFE_URLS"
|
||||
|
||||
_BLOCKED_IPV4_NETWORKS = [
|
||||
ipaddress.ip_network("10.0.0.0/8"),
|
||||
ipaddress.ip_network("172.16.0.0/12"),
|
||||
ipaddress.ip_network("192.168.0.0/16"),
|
||||
ipaddress.ip_network("127.0.0.0/8"),
|
||||
ipaddress.ip_network("169.254.0.0/16"), # Link-local / cloud metadata
|
||||
ipaddress.ip_network("0.0.0.0/32"),
|
||||
]
|
||||
|
||||
_BLOCKED_IPV6_NETWORKS = [
|
||||
ipaddress.ip_network("::1/128"),
|
||||
ipaddress.ip_network("::/128"),
|
||||
ipaddress.ip_network("fc00::/7"), # Unique local addresses
|
||||
ipaddress.ip_network("fe80::/10"), # Link-local IPv6
|
||||
]
|
||||
|
||||
|
||||
def _is_escape_hatch_enabled() -> bool:
    """Report whether URL validation has been explicitly switched off.

    Reads the environment variable named by ``_UNSAFE_URLS_ENV`` and treats
    ``true``, ``1`` and ``yes`` (compared case-insensitively) as opting out.
    An unset variable or any other value keeps validation enabled.
    """
    flag = os.environ.get(_UNSAFE_URLS_ENV, "")
    return flag.lower() in {"true", "1", "yes"}
|
||||
|
||||
|
||||
def _is_private_or_reserved(ip_str: str) -> bool:
    """Check if an IP address is private, reserved, or otherwise unsafe.

    Args:
        ip_str: Textual IPv4 or IPv6 address, typically taken from a
            ``socket.getaddrinfo`` result.

    Returns:
        ``True`` when the address must not be contacted: it is in one of the
        module's blocked networks, the standard library does not classify it
        as globally routable, it is multicast, or it cannot be parsed.
        ``False`` only for addresses that look publicly routable.
    """
    try:
        addr = ipaddress.ip_address(ip_str)
    except ValueError:
        return True  # If we can't parse, block it

    # ``::ffff:a.b.c.d`` reaches the embedded IPv4 host, so judge it by the
    # IPv4 rules rather than the IPv6 ones.
    if isinstance(addr, ipaddress.IPv6Address) and addr.ipv4_mapped is not None:
        addr = addr.ipv4_mapped

    networks = (
        _BLOCKED_IPV4_NETWORKS
        if isinstance(addr, ipaddress.IPv4Address)
        else _BLOCKED_IPV6_NETWORKS
    )
    if any(addr in network for network in networks):
        return True

    # Backstop for ranges the explicit lists do not enumerate (shared
    # address space, documentation and benchmarking blocks, multicast, ...).
    # The standard library's IANA-registry-based classification is used so
    # the check fails closed instead of relying on a hand-maintained list.
    return not addr.is_global or addr.is_multicast
|
||||
|
||||
|
||||
def validate_url(url: str) -> str:
    """Validate that a URL is safe to fetch (SSRF protection).

    Blocks every scheme other than ``http``/``https`` (so ``file://`` is
    rejected entirely). For ``http``/``https``, resolves DNS and checks that
    none of the resolved addresses is private or reserved.

    Note:
        The hostname is resolved here only to vet it; the URL string is
        returned unchanged, so whoever fetches it resolves the name again.
        A hostname whose DNS answer changes between validation and fetch is
        therefore not covered by this check alone.

    Args:
        url: The URL to validate.

    Returns:
        The validated URL string, unchanged.

    Raises:
        ValueError: If the URL uses a blocked scheme, has no hostname, cannot
            be resolved, or resolves to a private/reserved IP address.
    """
    if _is_escape_hatch_enabled():
        logger.warning(
            "%s is enabled - skipping URL validation for: %s",
            _UNSAFE_URLS_ENV,
            url,
        )
        return url

    parsed = urlparse(url)

    if parsed.scheme not in ("http", "https"):
        raise ValueError(
            f"Invalid URL scheme: {url}. Only http and https are allowed."
        )

    hostname = parsed.hostname
    if not hostname:
        raise ValueError(f"URL has no hostname: '{url}'")

    port = parsed.port or (443 if parsed.scheme == "https" else 80)

    try:
        # Restricting the lookup to stream sockets avoids receiving the same
        # address once per socket type.
        addrinfos = socket.getaddrinfo(hostname, port, type=socket.SOCK_STREAM)
    except socket.gaierror as exc:
        raise ValueError(f"Could not resolve hostname: '{hostname}'") from exc

    # Fail closed: with no addresses there is nothing to vet, so the URL must
    # not be reported as safe.
    if not addrinfos:
        raise ValueError(f"Could not resolve hostname: '{hostname}'")

    # Every candidate address must be public; a single internal answer is
    # enough to reject the URL.
    for *_, sockaddr in addrinfos:
        ip_str = str(sockaddr[0])
        if _is_private_or_reserved(ip_str):
            raise ValueError(
                f"URL '{url}' resolves to private/reserved IP {ip_str}. "
                f"Access to internal networks is not allowed. "
                f"Set {_UNSAFE_URLS_ENV}=true to bypass."
            )

    return url
|
||||
@@ -449,9 +449,10 @@ class FileUrl(BaseModel):
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_url(self) -> FileUrl:
|
||||
"""Validate URL format."""
|
||||
if not self.url.startswith(("http://", "https://")):
|
||||
raise ValueError(f"Invalid URL scheme: {self.url}")
|
||||
"""Validate URL format and ensure it does not target internal networks."""
|
||||
from crewai_files.core.security import validate_url
|
||||
|
||||
validate_url(self.url)
|
||||
return self
|
||||
|
||||
@property
|
||||
@@ -475,7 +476,7 @@ class FileUrl(BaseModel):
|
||||
if self._content is None:
|
||||
import httpx
|
||||
|
||||
response = httpx.get(self.url, follow_redirects=True)
|
||||
response = httpx.get(self.url, follow_redirects=False)
|
||||
response.raise_for_status()
|
||||
self._content = response.content
|
||||
if "content-type" in response.headers:
|
||||
@@ -488,7 +489,7 @@ class FileUrl(BaseModel):
|
||||
import httpx
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(self.url, follow_redirects=True)
|
||||
response = await client.get(self.url, follow_redirects=False)
|
||||
response.raise_for_status()
|
||||
self._content = response.content
|
||||
if "content-type" in response.headers:
|
||||
|
||||
@@ -95,7 +95,7 @@ class TestFileUrl:
|
||||
content = url.read()
|
||||
|
||||
mock_get.assert_called_once_with(
|
||||
"https://example.com/image.png", follow_redirects=True
|
||||
"https://example.com/image.png", follow_redirects=False
|
||||
)
|
||||
assert content == b"fake image content"
|
||||
|
||||
|
||||
233
lib/crewai-files/tests/test_ssrf_protection.py
Normal file
233
lib/crewai-files/tests/test_ssrf_protection.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""Tests for SSRF protection in FileUrl and the security module."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from crewai_files import FileUrl
|
||||
from crewai_files.core.security import (
|
||||
_is_private_or_reserved,
|
||||
validate_url,
|
||||
)
|
||||
import pytest
|
||||
|
||||
|
||||
class TestIsPrivateOrReserved:
    """Unit tests for the ``_is_private_or_reserved`` address classifier."""

    @pytest.mark.parametrize(
        "ip",
        [
            # Loopback
            "127.0.0.1",
            "127.0.0.2",
            # RFC 1918 private blocks
            "10.0.0.1",
            "10.255.255.255",
            "172.16.0.1",
            "172.31.255.255",
            "192.168.0.1",
            "192.168.1.100",
            # Link-local, where cloud metadata services live
            "169.254.169.254",  # AWS metadata
            # Unspecified address
            "0.0.0.0",
        ],
    )
    def test_private_ipv4_addresses_blocked(self, ip):
        """Internal IPv4 addresses are reported as unsafe."""
        verdict = _is_private_or_reserved(ip)
        assert verdict is True

    @pytest.mark.parametrize(
        "ip",
        [
            "::1",
            "::",
            "fc00::1",
            "fd00::1",
            "fe80::1",
        ],
    )
    def test_private_ipv6_addresses_blocked(self, ip):
        """Internal IPv6 addresses are reported as unsafe."""
        verdict = _is_private_or_reserved(ip)
        assert verdict is True

    def test_ipv4_mapped_ipv6_loopback_blocked(self):
        """An IPv4-mapped IPv6 form of loopback is still blocked."""
        verdict = _is_private_or_reserved("::ffff:127.0.0.1")
        assert verdict is True

    def test_ipv4_mapped_ipv6_metadata_blocked(self):
        """An IPv4-mapped IPv6 form of the metadata address is still blocked."""
        verdict = _is_private_or_reserved("::ffff:169.254.169.254")
        assert verdict is True

    @pytest.mark.parametrize(
        "ip",
        [
            "8.8.8.8",
            "1.1.1.1",
            "93.184.216.34",  # example.com
            "2606:2800:220:1:248:1893:25c8:1946",
        ],
    )
    def test_public_addresses_allowed(self, ip):
        """Publicly routable addresses are reported as safe."""
        verdict = _is_private_or_reserved(ip)
        assert verdict is False

    def test_unparseable_ip_blocked(self):
        """Input that is not an IP address at all is treated as unsafe."""
        verdict = _is_private_or_reserved("not-an-ip")
        assert verdict is True
|
||||
|
||||
|
||||
class TestValidateUrl:
    """Tests for validate_url function.

    Cases that use a real hostname patch ``socket.getaddrinfo`` so the suite
    does not depend on live DNS or outbound network access. IP-literal URLs
    are left unpatched.
    """

    # getaddrinfo-shaped entry (family, type, proto, canonname, sockaddr)
    # pointing at a public address.
    _PUBLIC_ADDRINFO = [(2, 1, 6, "", ("93.184.216.34", 443))]

    def test_blocks_file_scheme(self):
        with pytest.raises(ValueError, match="Invalid URL scheme"):
            validate_url("file:///etc/passwd")

    def test_blocks_ftp_scheme(self):
        with pytest.raises(ValueError, match="Invalid URL scheme"):
            validate_url("ftp://example.com/file")

    def test_blocks_data_scheme(self):
        with pytest.raises(ValueError, match="Invalid URL scheme"):
            validate_url("data:text/plain;base64,SGVsbG8=")

    def test_blocks_no_hostname(self):
        with pytest.raises(ValueError, match="URL has no hostname"):
            validate_url("http://")

    def test_blocks_localhost(self):
        with pytest.raises(ValueError, match="private/reserved IP"):
            validate_url("http://localhost/secret")

    def test_blocks_127_0_0_1(self):
        with pytest.raises(ValueError, match="private/reserved IP"):
            validate_url("http://127.0.0.1/secret")

    def test_blocks_metadata_endpoint(self):
        """Block AWS/GCP/Azure cloud metadata endpoint."""
        with pytest.raises(ValueError, match="private/reserved IP"):
            validate_url("http://169.254.169.254/latest/meta-data/")

    def test_blocks_private_10_network(self):
        with pytest.raises(ValueError, match="private/reserved IP"):
            validate_url("http://10.0.0.1/internal")

    def test_blocks_private_172_network(self):
        with pytest.raises(ValueError, match="private/reserved IP"):
            validate_url("http://172.16.0.1/internal")

    def test_blocks_private_192_168_network(self):
        with pytest.raises(ValueError, match="private/reserved IP"):
            validate_url("http://192.168.1.1/internal")

    def test_blocks_0_0_0_0(self):
        with pytest.raises(ValueError, match="private/reserved IP"):
            validate_url("http://0.0.0.0/")

    def test_allows_public_https_url(self):
        with patch("socket.getaddrinfo", return_value=self._PUBLIC_ADDRINFO):
            result = validate_url("https://example.com/image.png")
        assert result == "https://example.com/image.png"

    def test_allows_public_http_url(self):
        with patch("socket.getaddrinfo", return_value=self._PUBLIC_ADDRINFO):
            result = validate_url("http://example.com/file.pdf")
        assert result == "http://example.com/file.pdf"

    def test_blocks_unresolvable_hostname(self):
        import socket

        # Simulate the resolver failure instead of waiting on a real lookup.
        lookup_failure = socket.gaierror("Name or service not known")
        with patch("socket.getaddrinfo", side_effect=lookup_failure):
            with pytest.raises(ValueError, match="Could not resolve hostname"):
                validate_url(
                    "https://this-domain-definitely-does-not-exist-12345.invalid/"
                )

    def test_blocks_dns_resolving_to_private_ip(self):
        """Simulate a public-looking hostname whose DNS answer is a private IP."""
        fake_addrinfo = [(2, 1, 6, "", ("10.0.0.1", 443))]
        with patch("socket.getaddrinfo", return_value=fake_addrinfo):
            with pytest.raises(ValueError, match="private/reserved IP"):
                validate_url("https://evil.example.com/steal")

    def test_escape_hatch_allows_private_urls(self):
        """CREWAI_FILES_ALLOW_UNSAFE_URLS=true bypasses validation."""
        with patch.dict("os.environ", {"CREWAI_FILES_ALLOW_UNSAFE_URLS": "true"}):
            result = validate_url("http://127.0.0.1/secret")
            assert result == "http://127.0.0.1/secret"
|
||||
|
||||
|
||||
class TestFileUrlSSRF:
    """Tests for SSRF protection integrated into FileUrl.

    Constructing a ``FileUrl`` triggers URL validation, which resolves the
    hostname. Cases that use a real hostname patch ``socket.getaddrinfo`` so
    the suite does not depend on live DNS or outbound network access.
    """

    # getaddrinfo-shaped entry (family, type, proto, canonname, sockaddr)
    # pointing at a public address.
    _PUBLIC_ADDRINFO = [(2, 1, 6, "", ("93.184.216.34", 443))]

    def test_rejects_localhost_url(self):
        with pytest.raises(ValueError, match="private/reserved IP"):
            FileUrl(url="http://localhost/secret")

    def test_rejects_127_0_0_1(self):
        with pytest.raises(ValueError, match="private/reserved IP"):
            FileUrl(url="http://127.0.0.1:8080/admin")

    def test_rejects_metadata_endpoint(self):
        with pytest.raises(ValueError, match="private/reserved IP"):
            FileUrl(url="http://169.254.169.254/latest/meta-data/")

    def test_rejects_private_network_10(self):
        with pytest.raises(ValueError, match="private/reserved IP"):
            FileUrl(url="http://10.0.0.1/internal-service")

    def test_rejects_private_network_172(self):
        with pytest.raises(ValueError, match="private/reserved IP"):
            FileUrl(url="http://172.16.5.10/api")

    def test_rejects_private_network_192_168(self):
        with pytest.raises(ValueError, match="private/reserved IP"):
            FileUrl(url="http://192.168.1.1/router")

    def test_rejects_file_scheme(self):
        with pytest.raises(ValueError, match="Invalid URL scheme"):
            FileUrl(url="file:///etc/passwd")

    def test_rejects_ftp_scheme(self):
        with pytest.raises(ValueError, match="Invalid URL scheme"):
            FileUrl(url="ftp://example.com/file")

    def test_accepts_public_https_url(self):
        with patch("socket.getaddrinfo", return_value=self._PUBLIC_ADDRINFO):
            url = FileUrl(url="https://example.com/image.png")
        assert url.url == "https://example.com/image.png"

    def test_accepts_public_http_url(self):
        with patch("socket.getaddrinfo", return_value=self._PUBLIC_ADDRINFO):
            url = FileUrl(url="http://example.com/file.pdf")
        assert url.url == "http://example.com/file.pdf"

    def test_rejects_dns_rebinding_to_private_ip(self):
        """A hostname that resolves to an internal IP should be blocked."""
        fake_addrinfo = [(2, 1, 6, "", ("169.254.169.254", 80))]
        with patch("socket.getaddrinfo", return_value=fake_addrinfo):
            with pytest.raises(ValueError, match="private/reserved IP"):
                FileUrl(url="http://metadata.evil.com/steal-creds")

    def test_escape_hatch_bypasses_ssrf_check(self):
        with patch.dict("os.environ", {"CREWAI_FILES_ALLOW_UNSAFE_URLS": "true"}):
            url = FileUrl(url="http://127.0.0.1/local")
            assert url.url == "http://127.0.0.1/local"

    def test_read_does_not_follow_redirects(self):
        """Verify read() uses follow_redirects=False to prevent redirect-based SSRF."""
        from unittest.mock import MagicMock

        with patch("socket.getaddrinfo", return_value=self._PUBLIC_ADDRINFO):
            url = FileUrl(url="https://example.com/image.png")

        mock_response = MagicMock()
        mock_response.content = b"data"
        mock_response.headers = {}

        with patch("httpx.get", return_value=mock_response) as mock_get:
            url.read()
            mock_get.assert_called_once_with(
                "https://example.com/image.png", follow_redirects=False
            )

    @pytest.mark.asyncio
    async def test_aread_does_not_follow_redirects(self):
        """Verify aread() uses follow_redirects=False."""
        from unittest.mock import AsyncMock, MagicMock

        with patch("socket.getaddrinfo", return_value=self._PUBLIC_ADDRINFO):
            url = FileUrl(url="https://example.com/image.png")

        mock_response = MagicMock()
        mock_response.content = b"data"
        mock_response.headers = {}
        mock_response.raise_for_status = MagicMock()

        # Stand-in for httpx.AsyncClient used as an async context manager.
        mock_client = MagicMock()
        mock_client.get = AsyncMock(return_value=mock_response)
        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client.__aexit__ = AsyncMock(return_value=None)

        with patch("httpx.AsyncClient", return_value=mock_client):
            await url.aread()
            mock_client.get.assert_called_once_with(
                "https://example.com/image.png", follow_redirects=False
            )
|
||||
Reference in New Issue
Block a user