Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-05-07 10:12:38 +00:00)
fix(ci): add type annotations to _SSRFSafeAdapter.send and fix test mocks
- Add proper type annotations to _SSRFSafeAdapter.send() to satisfy mypy
- Add 'Any' import from typing
- Update webpage_loader tests to mock safe_get instead of requests.get (the loader now uses safe_get for SSRF protection)
This commit is contained in:
@@ -14,6 +14,7 @@ import ipaddress
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
|
from typing import Any
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
@@ -218,7 +219,9 @@ class _SSRFSafeAdapter(HTTPAdapter):
|
|||||||
redirect hops — against the private/reserved blocklist before the
|
redirect hops — against the private/reserved blocklist before the
|
||||||
connection is made."""
|
connection is made."""
|
||||||
|
|
||||||
def send(self, request, **kwargs):
|
def send( # type: ignore[override]
|
||||||
|
self, request: requests.PreparedRequest, **kwargs: Any
|
||||||
|
) -> requests.Response:
|
||||||
parsed = urlparse(request.url)
|
parsed = urlparse(request.url)
|
||||||
if not _is_escape_hatch_enabled() and parsed.hostname:
|
if not _is_escape_hatch_enabled() and parsed.hostname:
|
||||||
try:
|
try:
|
||||||
@@ -252,7 +255,7 @@ def safe_request_session() -> requests.Session:
|
|||||||
return session
|
return session
|
||||||
|
|
||||||
|
|
||||||
def safe_get(url: str, **kwargs) -> requests.Response:
|
def safe_get(url: str, **kwargs: Any) -> requests.Response:
|
||||||
"""Drop-in replacement for ``requests.get()`` with SSRF protection.
|
"""Drop-in replacement for ``requests.get()`` with SSRF protection.
|
||||||
|
|
||||||
Validates the initial URL via :func:`validate_url`, then executes the
|
Validates the initial URL via :func:`validate_url`, then executes the
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ class TestWebPageLoader:
|
|||||||
soup.return_value = script_style_elements or []
|
soup.return_value = script_style_elements or []
|
||||||
return soup
|
return soup
|
||||||
|
|
||||||
@patch("requests.get")
|
@patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
|
||||||
@patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup")
|
@patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup")
|
||||||
def test_load_basic_webpage(self, mock_bs, mock_get):
|
def test_load_basic_webpage(self, mock_bs, mock_get):
|
||||||
mock_get.return_value = self.setup_mock_response(
|
mock_get.return_value = self.setup_mock_response(
|
||||||
@@ -37,7 +37,7 @@ class TestWebPageLoader:
|
|||||||
assert result.content == "Test content"
|
assert result.content == "Test content"
|
||||||
assert result.metadata["title"] == "Test Page"
|
assert result.metadata["title"] == "Test Page"
|
||||||
|
|
||||||
@patch("requests.get")
|
@patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
|
||||||
@patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup")
|
@patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup")
|
||||||
def test_load_webpage_with_scripts_and_styles(self, mock_bs, mock_get):
|
def test_load_webpage_with_scripts_and_styles(self, mock_bs, mock_get):
|
||||||
html = """
|
html = """
|
||||||
@@ -62,7 +62,7 @@ class TestWebPageLoader:
|
|||||||
for el in scripts + styles:
|
for el in scripts + styles:
|
||||||
el.decompose.assert_called_once()
|
el.decompose.assert_called_once()
|
||||||
|
|
||||||
@patch("requests.get")
|
@patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
|
||||||
@patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup")
|
@patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup")
|
||||||
def test_text_cleaning_and_title_handling(self, mock_bs, mock_get):
|
def test_text_cleaning_and_title_handling(self, mock_bs, mock_get):
|
||||||
mock_get.return_value = self.setup_mock_response(
|
mock_get.return_value = self.setup_mock_response(
|
||||||
@@ -77,7 +77,7 @@ class TestWebPageLoader:
|
|||||||
assert result.content is not None
|
assert result.content is not None
|
||||||
assert result.metadata["title"] == ""
|
assert result.metadata["title"] == ""
|
||||||
|
|
||||||
@patch("requests.get")
|
@patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
|
||||||
@patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup")
|
@patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup")
|
||||||
def test_empty_or_missing_title(self, mock_bs, mock_get):
|
def test_empty_or_missing_title(self, mock_bs, mock_get):
|
||||||
for title in [None, ""]:
|
for title in [None, ""]:
|
||||||
@@ -90,7 +90,7 @@ class TestWebPageLoader:
|
|||||||
result = loader.load(SourceContent("https://example.com"))
|
result = loader.load(SourceContent("https://example.com"))
|
||||||
assert result.metadata["title"] == ""
|
assert result.metadata["title"] == ""
|
||||||
|
|
||||||
@patch("requests.get")
|
@patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
|
||||||
def test_custom_and_default_headers(self, mock_get):
|
def test_custom_and_default_headers(self, mock_get):
|
||||||
mock_get.return_value = self.setup_mock_response(
|
mock_get.return_value = self.setup_mock_response(
|
||||||
"<html><body>Test</body></html>"
|
"<html><body>Test</body></html>"
|
||||||
@@ -109,14 +109,14 @@ class TestWebPageLoader:
|
|||||||
|
|
||||||
assert mock_get.call_args[1]["headers"] == custom_headers
|
assert mock_get.call_args[1]["headers"] == custom_headers
|
||||||
|
|
||||||
@patch("requests.get")
|
@patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
|
||||||
def test_error_handling(self, mock_get):
|
def test_error_handling(self, mock_get):
|
||||||
for error in [Exception("Fail"), ValueError("Bad"), ImportError("Oops")]:
|
for error in [Exception("Fail"), ValueError("Bad"), ImportError("Oops")]:
|
||||||
mock_get.side_effect = error
|
mock_get.side_effect = error
|
||||||
with pytest.raises(ValueError, match="Error loading webpage"):
|
with pytest.raises(ValueError, match="Error loading webpage"):
|
||||||
WebPageLoader().load(SourceContent("https://example.com"))
|
WebPageLoader().load(SourceContent("https://example.com"))
|
||||||
|
|
||||||
@patch("requests.get")
|
@patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
|
||||||
def test_timeout_and_http_error(self, mock_get):
|
def test_timeout_and_http_error(self, mock_get):
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
@@ -131,7 +131,7 @@ class TestWebPageLoader:
|
|||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
WebPageLoader().load(SourceContent("https://example.com/404"))
|
WebPageLoader().load(SourceContent("https://example.com/404"))
|
||||||
|
|
||||||
@patch("requests.get")
|
@patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
|
||||||
@patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup")
|
@patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup")
|
||||||
def test_doc_id_consistency(self, mock_bs, mock_get):
|
def test_doc_id_consistency(self, mock_bs, mock_get):
|
||||||
mock_get.return_value = self.setup_mock_response(
|
mock_get.return_value = self.setup_mock_response(
|
||||||
@@ -145,7 +145,7 @@ class TestWebPageLoader:
|
|||||||
|
|
||||||
assert result1.doc_id == result2.doc_id
|
assert result1.doc_id == result2.doc_id
|
||||||
|
|
||||||
@patch("requests.get")
|
@patch("crewai_tools.rag.loaders.webpage_loader.safe_get")
|
||||||
@patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup")
|
@patch("crewai_tools.rag.loaders.webpage_loader.BeautifulSoup")
|
||||||
def test_status_code_and_content_type(self, mock_bs, mock_get):
|
def test_status_code_and_content_type(self, mock_bs, mock_get):
|
||||||
for status in [200, 201, 301]:
|
for status in [200, 201, 301]:
|
||||||
|
|||||||
Reference in New Issue
Block a user