Compare commits

...

2 Commits

Author SHA1 Message Date
Iris Clawd
42b4f0101e fix(ci): add type annotations to _SSRFSafeAdapter.send and fix test mocks
- Add proper type annotations to _SSRFSafeAdapter.send() to satisfy mypy
- Add 'Any' import from typing
- Update webpage_loader tests to mock safe_get instead of requests.get
  (the loader now uses safe_get for SSRF protection)
2026-05-05 15:47:02 +00:00
Iris Clawd
3dc8c45cc9 fix(security): validate IPs on every redirect hop to prevent SSRF bypass (OSS-51)
Adds a custom HTTPAdapter (_SSRFSafeAdapter) that intercepts every
request — including redirect hops — and validates the resolved IP
against the private/reserved blocklist before the connection proceeds.

New public API:
- safe_request_session(): returns a Session with the adapter mounted
- safe_get(url, **kwargs): drop-in replacement for requests.get() that
  validates the initial URL AND every redirect destination

Updated tools to use safe_get() instead of validate_url() + requests.get():
- ScrapeWebsiteTool
- ScrapeElementFromWebsiteTool
- WebPageLoader (RAG)

Closes OSS-51
2026-05-05 03:57:09 +00:00
6 changed files with 150 additions and 22 deletions

View File

@@ -2,9 +2,8 @@ import re
from typing import Any, Final
from bs4 import BeautifulSoup
import requests
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.security.safe_path import safe_get
from crewai_tools.rag.source_content import SourceContent
@@ -25,7 +24,7 @@ class WebPageLoader(BaseLoader):
)
try:
response = requests.get(url, timeout=15, headers=headers)
response = safe_get(url, timeout=15, headers=headers)
response.encoding = response.apparent_encoding
soup = BeautifulSoup(response.text, "html.parser")

View File

@@ -14,8 +14,12 @@ import ipaddress
import logging
import os
import socket
from typing import Any
from urllib.parse import urlparse
import requests
from requests.adapters import HTTPAdapter
logger = logging.getLogger(__name__)
@@ -203,3 +207,72 @@ def validate_url(url: str) -> str:
)
return url
# ---------------------------------------------------------------------------
# SSRF-safe HTTP requests (validates IPs on every redirect hop)
# ---------------------------------------------------------------------------
class _SSRFSafeAdapter(HTTPAdapter):
    """HTTPAdapter that validates the resolved IP of every request — including
    redirect hops — against the private/reserved blocklist before the
    connection is made."""

    def send(  # type: ignore[override]
        self, request: requests.PreparedRequest, **kwargs: Any
    ) -> requests.Response:
        target = urlparse(request.url)
        # Escape hatch enabled, or no hostname to resolve: fall straight
        # through to the stock adapter.
        if _is_escape_hatch_enabled() or not target.hostname:
            return super().send(request, **kwargs)

        default_port = 443 if target.scheme == "https" else 80
        try:
            infos = socket.getaddrinfo(target.hostname, target.port or default_port)
        except socket.gaierror as exc:
            raise ValueError(
                f"Could not resolve hostname: '{target.hostname}'"
            ) from exc

        # Reject the connection if ANY resolved address is private/reserved.
        for info in infos:
            candidate_ip = str(info[4][0])
            if _is_private_or_reserved(candidate_ip):
                raise ValueError(
                    f"Redirect to '{request.url}' blocked: resolves to "
                    f"private/reserved IP {candidate_ip}. Access to internal "
                    f"networks is not allowed. "
                    f"Set {_UNSAFE_PATHS_ENV}=true to bypass."
                )
        return super().send(request, **kwargs)
def safe_request_session() -> requests.Session:
    """Return a :class:`requests.Session` that validates every connection
    target (including redirect destinations) against the SSRF blocklist."""
    ssrf_adapter = _SSRFSafeAdapter()
    session = requests.Session()
    # One shared adapter handles both schemes.
    for scheme_prefix in ("http://", "https://"):
        session.mount(scheme_prefix, ssrf_adapter)
    return session
def safe_get(url: str, **kwargs: Any) -> requests.Response:
    """Drop-in replacement for ``requests.get()`` with SSRF protection.

    Validates the initial URL via :func:`validate_url`, then executes the
    request through a session whose adapter re-checks every redirect hop.

    Args:
        url: The URL to fetch.
        **kwargs: Passed through to ``session.get()`` (headers, cookies,
            timeout, etc.).

    Returns:
        The :class:`requests.Response`.

    Raises:
        ValueError: If the initial URL or any redirect target resolves to
            a private/reserved IP.
    """
    # Use the validated URL rather than discarding the return value —
    # consistent with the tool call sites (`url = validate_url(url)`).
    url = validate_url(url)
    # Close the throwaway session so its pooled connections are released;
    # requests.get() itself wraps its one-shot Session the same way.
    with safe_request_session() as session:
        return session.get(url, **kwargs)

View File

@@ -3,9 +3,7 @@ from typing import Any
from crewai.tools import BaseTool
from pydantic import BaseModel, Field
import requests
from crewai_tools.security.safe_path import validate_url
from crewai_tools.security.safe_path import safe_get
try:
@@ -83,8 +81,7 @@ class ScrapeElementFromWebsiteTool(BaseTool):
if website_url is None or css_element is None:
raise ValueError("Both website_url and css_element must be provided.")
website_url = validate_url(website_url)
page = requests.get(
page = safe_get(
website_url,
headers=self.headers,
cookies=self.cookies if self.cookies else {},

View File

@@ -3,9 +3,7 @@ import re
from typing import Any
from pydantic import Field
import requests
from crewai_tools.security.safe_path import validate_url
from crewai_tools.security.safe_path import safe_get
try:
@@ -75,8 +73,7 @@ class ScrapeWebsiteTool(BaseTool):
if website_url is None:
raise ValueError("Website URL must be provided.")
website_url = validate_url(website_url)
page = requests.get(
page = safe_get(
website_url,
timeout=15,
headers=self.headers,

View File

@@ -22,7 +22,7 @@ class TestWebPageLoader:
soup.return_value = script_style_elements or []
return soup
@patch("requests.get")
@patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
@patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup")
def test_load_basic_webpage(self, mock_bs, mock_get):
mock_get.return_value = self.setup_mock_response(
@@ -37,7 +37,7 @@ class TestWebPageLoader:
assert result.content == "Test content"
assert result.metadata["title"] == "Test Page"
@patch("requests.get")
@patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
@patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup")
def test_load_webpage_with_scripts_and_styles(self, mock_bs, mock_get):
html = """
@@ -62,7 +62,7 @@ class TestWebPageLoader:
for el in scripts + styles:
el.decompose.assert_called_once()
@patch("requests.get")
@patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
@patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup")
def test_text_cleaning_and_title_handling(self, mock_bs, mock_get):
mock_get.return_value = self.setup_mock_response(
@@ -77,7 +77,7 @@ class TestWebPageLoader:
assert result.content is not None
assert result.metadata["title"] == ""
@patch("requests.get")
@patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
@patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup")
def test_empty_or_missing_title(self, mock_bs, mock_get):
for title in [None, ""]:
@@ -90,7 +90,7 @@ class TestWebPageLoader:
result = loader.load(SourceContent("https://example.com"))
assert result.metadata["title"] == ""
@patch("requests.get")
@patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
def test_custom_and_default_headers(self, mock_get):
mock_get.return_value = self.setup_mock_response(
"<html><body>Test</body></html>"
@@ -109,14 +109,14 @@ class TestWebPageLoader:
assert mock_get.call_args[1]["headers"] == custom_headers
@patch("requests.get")
@patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
def test_error_handling(self, mock_get):
for error in [Exception("Fail"), ValueError("Bad"), ImportError("Oops")]:
mock_get.side_effect = error
with pytest.raises(ValueError, match="Error loading webpage"):
WebPageLoader().load(SourceContent("https://example.com"))
@patch("requests.get")
@patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
def test_timeout_and_http_error(self, mock_get):
import requests
@@ -131,7 +131,7 @@ class TestWebPageLoader:
with pytest.raises(ValueError):
WebPageLoader().load(SourceContent("https://example.com/404"))
@patch("requests.get")
@patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
@patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup")
def test_doc_id_consistency(self, mock_bs, mock_get):
mock_get.return_value = self.setup_mock_response(
@@ -145,7 +145,7 @@ class TestWebPageLoader:
assert result1.doc_id == result2.doc_id
@patch("requests.get")
@patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
@patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup")
def test_status_code_and_content_type(self, mock_bs, mock_get):
for status in [200, 201, 301]:

View File

@@ -6,7 +6,10 @@ import os
import pytest
from unittest.mock import MagicMock, patch
from crewai_tools.security.safe_path import (
safe_get,
validate_directory_path,
validate_file_path,
validate_url,
@@ -168,3 +171,62 @@ class TestValidateUrl:
# file:// would normally be blocked
result = validate_url("file:///etc/passwd")
assert result == "file:///etc/passwd"
# ---------------------------------------------------------------------------
# safe_get — redirect-aware SSRF protection
# ---------------------------------------------------------------------------
def _fake_getaddrinfo_factory(ip: str):
    """Return a getaddrinfo replacement that always resolves to *ip*."""

    def _resolve(host, port, *args, **kwargs):
        # Shape mirrors socket.getaddrinfo:
        # (family, type, proto, canonname, sockaddr) — AF_INET/TCP values.
        return [(2, 1, 6, "", (ip, port or 80))]

    return _resolve
class TestSafeGet:
    """Tests for safe_get (validates IPs on every redirect hop)."""

    @staticmethod
    def _ok_response():
        # Minimal non-redirect 200 response double for HTTPAdapter.send.
        response = MagicMock()
        response.status_code = 200
        response.is_redirect = False
        response.headers = {}
        return response

    @patch("crewai_tools.security.safe_path.socket.getaddrinfo",
           side_effect=_fake_getaddrinfo_factory("93.184.216.34"))
    @patch("requests.adapters.HTTPAdapter.send")
    def test_allows_public_url(self, mock_send, mock_dns):
        mock_send.return_value = self._ok_response()
        resp = safe_get("https://example.com/page")
        assert resp.status_code == 200

    @patch("crewai_tools.security.safe_path.socket.getaddrinfo",
           side_effect=_fake_getaddrinfo_factory("127.0.0.1"))
    def test_blocks_redirect_to_localhost(self, mock_dns):
        with pytest.raises(ValueError, match="private/reserved IP"):
            safe_get("http://evil.com/redirect")

    @patch("crewai_tools.security.safe_path.socket.getaddrinfo",
           side_effect=_fake_getaddrinfo_factory("169.254.169.254"))
    def test_blocks_redirect_to_metadata(self, mock_dns):
        with pytest.raises(ValueError, match="private/reserved IP"):
            safe_get("http://evil.com/metadata")

    @patch("crewai_tools.security.safe_path.socket.getaddrinfo",
           side_effect=_fake_getaddrinfo_factory("10.0.0.1"))
    def test_blocks_redirect_to_private_range(self, mock_dns):
        with pytest.raises(ValueError, match="private/reserved IP"):
            safe_get("http://evil.com/internal")

    @patch("crewai_tools.security.safe_path.socket.getaddrinfo",
           side_effect=_fake_getaddrinfo_factory("169.254.169.254"))
    @patch("requests.adapters.HTTPAdapter.send")
    def test_escape_hatch_bypasses_redirect_check(self, mock_send, mock_dns, monkeypatch):
        monkeypatch.setenv("CREWAI_TOOLS_ALLOW_UNSAFE_PATHS", "true")
        mock_send.return_value = self._ok_response()
        resp = safe_get("http://evil.com/metadata")
        assert resp.status_code == 200