fix(ci): add type annotations to _SSRFSafeAdapter.send and fix test mocks

- Add type annotations to _SSRFSafeAdapter.send() (marked
  `# type: ignore[override]`) and to safe_get()'s **kwargs to satisfy mypy
- Add 'Any' import from typing
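- Update webpage_loader tests to patch
  crewai_tools.rag.loaders.webpage_loader.safe_get instead of requests.get
  (the loader now uses safe_get for SSRF protection, so the tests must patch
  the name the loader actually calls)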
This commit is contained in:
Iris Clawd
2026-05-05 15:47:02 +00:00
parent 3dc8c45cc9
commit 42b4f0101e
2 changed files with 14 additions and 11 deletions

View File

@@ -14,6 +14,7 @@ import ipaddress
import logging import logging
import os import os
import socket import socket
from typing import Any
from urllib.parse import urlparse from urllib.parse import urlparse
import requests import requests
@@ -218,7 +219,9 @@ class _SSRFSafeAdapter(HTTPAdapter):
redirect hops — against the private/reserved blocklist before the redirect hops — against the private/reserved blocklist before the
connection is made.""" connection is made."""
def send(self, request, **kwargs): def send( # type: ignore[override]
self, request: requests.PreparedRequest, **kwargs: Any
) -> requests.Response:
parsed = urlparse(request.url) parsed = urlparse(request.url)
if not _is_escape_hatch_enabled() and parsed.hostname: if not _is_escape_hatch_enabled() and parsed.hostname:
try: try:
@@ -252,7 +255,7 @@ def safe_request_session() -> requests.Session:
return session return session
def safe_get(url: str, **kwargs) -> requests.Response: def safe_get(url: str, **kwargs: Any) -> requests.Response:
"""Drop-in replacement for ``requests.get()`` with SSRF protection. """Drop-in replacement for ``requests.get()`` with SSRF protection.
Validates the initial URL via :func:`validate_url`, then executes the Validates the initial URL via :func:`validate_url`, then executes the

View File

@@ -22,7 +22,7 @@ class TestWebPageLoader:
soup.return_value = script_style_elements or [] soup.return_value = script_style_elements or []
return soup return soup
@patch("requests.get") @patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
@patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup") @patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup")
def test_load_basic_webpage(self, mock_bs, mock_get): def test_load_basic_webpage(self, mock_bs, mock_get):
mock_get.return_value = self.setup_mock_response( mock_get.return_value = self.setup_mock_response(
@@ -37,7 +37,7 @@ class TestWebPageLoader:
assert result.content == "Test content" assert result.content == "Test content"
assert result.metadata["title"] == "Test Page" assert result.metadata["title"] == "Test Page"
@patch("requests.get") @patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
@patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup") @patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup")
def test_load_webpage_with_scripts_and_styles(self, mock_bs, mock_get): def test_load_webpage_with_scripts_and_styles(self, mock_bs, mock_get):
html = """ html = """
@@ -62,7 +62,7 @@ class TestWebPageLoader:
for el in scripts + styles: for el in scripts + styles:
el.decompose.assert_called_once() el.decompose.assert_called_once()
@patch("requests.get") @patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
@patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup") @patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup")
def test_text_cleaning_and_title_handling(self, mock_bs, mock_get): def test_text_cleaning_and_title_handling(self, mock_bs, mock_get):
mock_get.return_value = self.setup_mock_response( mock_get.return_value = self.setup_mock_response(
@@ -77,7 +77,7 @@ class TestWebPageLoader:
assert result.content is not None assert result.content is not None
assert result.metadata["title"] == "" assert result.metadata["title"] == ""
@patch("requests.get") @patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
@patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup") @patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup")
def test_empty_or_missing_title(self, mock_bs, mock_get): def test_empty_or_missing_title(self, mock_bs, mock_get):
for title in [None, ""]: for title in [None, ""]:
@@ -90,7 +90,7 @@ class TestWebPageLoader:
result = loader.load(SourceContent("https://example.com")) result = loader.load(SourceContent("https://example.com"))
assert result.metadata["title"] == "" assert result.metadata["title"] == ""
@patch("requests.get") @patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
def test_custom_and_default_headers(self, mock_get): def test_custom_and_default_headers(self, mock_get):
mock_get.return_value = self.setup_mock_response( mock_get.return_value = self.setup_mock_response(
"<html><body>Test</body></html>" "<html><body>Test</body></html>"
@@ -109,14 +109,14 @@ class TestWebPageLoader:
assert mock_get.call_args[1]["headers"] == custom_headers assert mock_get.call_args[1]["headers"] == custom_headers
@patch("requests.get") @patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
def test_error_handling(self, mock_get): def test_error_handling(self, mock_get):
for error in [Exception("Fail"), ValueError("Bad"), ImportError("Oops")]: for error in [Exception("Fail"), ValueError("Bad"), ImportError("Oops")]:
mock_get.side_effect = error mock_get.side_effect = error
with pytest.raises(ValueError, match="Error loading webpage"): with pytest.raises(ValueError, match="Error loading webpage"):
WebPageLoader().load(SourceContent("https://example.com")) WebPageLoader().load(SourceContent("https://example.com"))
@patch("requests.get") @patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
def test_timeout_and_http_error(self, mock_get): def test_timeout_and_http_error(self, mock_get):
import requests import requests
@@ -131,7 +131,7 @@ class TestWebPageLoader:
with pytest.raises(ValueError): with pytest.raises(ValueError):
WebPageLoader().load(SourceContent("https://example.com/404")) WebPageLoader().load(SourceContent("https://example.com/404"))
@patch("requests.get") @patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
@patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup") @patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup")
def test_doc_id_consistency(self, mock_bs, mock_get): def test_doc_id_consistency(self, mock_bs, mock_get):
mock_get.return_value = self.setup_mock_response( mock_get.return_value = self.setup_mock_response(
@@ -145,7 +145,7 @@ class TestWebPageLoader:
assert result1.doc_id == result2.doc_id assert result1.doc_id == result2.doc_id
@patch("requests.get") @patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
@patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup") @patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup")
def test_status_code_and_content_type(self, mock_bs, mock_get): def test_status_code_and_content_type(self, mock_bs, mock_get):
for status in [200, 201, 301]: for status in [200, 201, 301]: