mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-06 01:32:36 +00:00
fix(ci): add type annotations to _SSRFSafeAdapter.send and fix test mocks
- Add proper type annotations to _SSRFSafeAdapter.send() to satisfy mypy - Add 'Any' import from typing - Update webpage_loader tests to mock safe_get instead of requests.get (the loader now uses safe_get for SSRF protection)
This commit is contained in:
@@ -14,6 +14,7 @@ import ipaddress
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
@@ -218,7 +219,9 @@ class _SSRFSafeAdapter(HTTPAdapter):
|
||||
redirect hops — against the private/reserved blocklist before the
|
||||
connection is made."""
|
||||
|
||||
def send(self, request, **kwargs):
|
||||
def send( # type: ignore[override]
|
||||
self, request: requests.PreparedRequest, **kwargs: Any
|
||||
) -> requests.Response:
|
||||
parsed = urlparse(request.url)
|
||||
if not _is_escape_hatch_enabled() and parsed.hostname:
|
||||
try:
|
||||
@@ -252,7 +255,7 @@ def safe_request_session() -> requests.Session:
|
||||
return session
|
||||
|
||||
|
||||
def safe_get(url: str, **kwargs) -> requests.Response:
|
||||
def safe_get(url: str, **kwargs: Any) -> requests.Response:
|
||||
"""Drop-in replacement for ``requests.get()`` with SSRF protection.
|
||||
|
||||
Validates the initial URL via :func:`validate_url`, then executes the
|
||||
|
||||
@@ -22,7 +22,7 @@ class TestWebPageLoader:
|
||||
soup.return_value = script_style_elements or []
|
||||
return soup
|
||||
|
||||
@patch("requests.get")
|
||||
@patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
|
||||
@patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup")
|
||||
def test_load_basic_webpage(self, mock_bs, mock_get):
|
||||
mock_get.return_value = self.setup_mock_response(
|
||||
@@ -37,7 +37,7 @@ class TestWebPageLoader:
|
||||
assert result.content == "Test content"
|
||||
assert result.metadata["title"] == "Test Page"
|
||||
|
||||
@patch("requests.get")
|
||||
@patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
|
||||
@patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup")
|
||||
def test_load_webpage_with_scripts_and_styles(self, mock_bs, mock_get):
|
||||
html = """
|
||||
@@ -62,7 +62,7 @@ class TestWebPageLoader:
|
||||
for el in scripts + styles:
|
||||
el.decompose.assert_called_once()
|
||||
|
||||
@patch("requests.get")
|
||||
@patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
|
||||
@patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup")
|
||||
def test_text_cleaning_and_title_handling(self, mock_bs, mock_get):
|
||||
mock_get.return_value = self.setup_mock_response(
|
||||
@@ -77,7 +77,7 @@ class TestWebPageLoader:
|
||||
assert result.content is not None
|
||||
assert result.metadata["title"] == ""
|
||||
|
||||
@patch("requests.get")
|
||||
@patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
|
||||
@patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup")
|
||||
def test_empty_or_missing_title(self, mock_bs, mock_get):
|
||||
for title in [None, ""]:
|
||||
@@ -90,7 +90,7 @@ class TestWebPageLoader:
|
||||
result = loader.load(SourceContent("https://example.com"))
|
||||
assert result.metadata["title"] == ""
|
||||
|
||||
@patch("requests.get")
|
||||
@patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
|
||||
def test_custom_and_default_headers(self, mock_get):
|
||||
mock_get.return_value = self.setup_mock_response(
|
||||
"<html><body>Test</body></html>"
|
||||
@@ -109,14 +109,14 @@ class TestWebPageLoader:
|
||||
|
||||
assert mock_get.call_args[1]["headers"] == custom_headers
|
||||
|
||||
@patch("requests.get")
|
||||
@patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
|
||||
def test_error_handling(self, mock_get):
|
||||
for error in [Exception("Fail"), ValueError("Bad"), ImportError("Oops")]:
|
||||
mock_get.side_effect = error
|
||||
with pytest.raises(ValueError, match="Error loading webpage"):
|
||||
WebPageLoader().load(SourceContent("https://example.com"))
|
||||
|
||||
@patch("requests.get")
|
||||
@patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
|
||||
def test_timeout_and_http_error(self, mock_get):
|
||||
import requests
|
||||
|
||||
@@ -131,7 +131,7 @@ class TestWebPageLoader:
|
||||
with pytest.raises(ValueError):
|
||||
WebPageLoader().load(SourceContent("https://example.com/404"))
|
||||
|
||||
@patch("requests.get")
|
||||
@patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
|
||||
@patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup")
|
||||
def test_doc_id_consistency(self, mock_bs, mock_get):
|
||||
mock_get.return_value = self.setup_mock_response(
|
||||
@@ -145,7 +145,7 @@ class TestWebPageLoader:
|
||||
|
||||
assert result1.doc_id == result2.doc_id
|
||||
|
||||
@patch("requests.get")
|
||||
@patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
|
||||
@patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup")
|
||||
def test_status_code_and_content_type(self, mock_bs, mock_get):
|
||||
for status in [200, 201, 301]:
|
||||
|
||||
Reference in New Issue
Block a user