Validate redirects for scraping URL fetches

This commit is contained in:
Rip&Tear
2026-06-25 11:27:25 +08:00
parent 01fc389d4a
commit cd3fa72ead
10 changed files with 205 additions and 19 deletions

View File

@@ -8,6 +8,7 @@ import requests
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent
from crewai_tools.security.safe_requests import safe_get
class DocsSiteLoader(BaseLoader):
@@ -26,7 +27,7 @@ class DocsSiteLoader(BaseLoader):
docs_url = source.source
try:
response = requests.get(docs_url, timeout=30)
response = safe_get(docs_url, timeout=30)
response.raise_for_status()
except requests.RequestException as e:
raise ValueError(

View File

@@ -2,10 +2,9 @@ import os
import tempfile
from typing import Any
import requests
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent
from crewai_tools.security.safe_requests import safe_get
class DOCXLoader(BaseLoader):
@@ -43,7 +42,7 @@ class DOCXLoader(BaseLoader):
)
try:
response = requests.get(url, headers=headers, timeout=30)
response = safe_get(url, headers=headers, timeout=30)
response.raise_for_status()
# Create temporary file to save the DOCX content

View File

@@ -6,10 +6,9 @@ import tempfile
from typing import Any
from urllib.parse import urlparse
import requests
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent
from crewai_tools.security.safe_requests import safe_get
class PDFLoader(BaseLoader):
@@ -47,7 +46,7 @@ class PDFLoader(BaseLoader):
)
try:
response = requests.get(url, headers=headers, timeout=30)
response = safe_get(url, headers=headers, timeout=30)
response.raise_for_status()
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as temp_file:

View File

@@ -23,7 +23,7 @@ def load_from_url(
Raises:
ValueError: If there's an error fetching the URL
"""
import requests
from crewai_tools.security.safe_requests import safe_get
headers = kwargs.get(
"headers",
@@ -34,7 +34,7 @@ def load_from_url(
)
try:
response = requests.get(url, headers=headers, timeout=30)
response = safe_get(url, headers=headers, timeout=30)
response.raise_for_status()
return response.text
except Exception as e:

View File

@@ -2,10 +2,10 @@ import re
from typing import Any, Final
from bs4 import BeautifulSoup
import requests
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
from crewai_tools.rag.source_content import SourceContent
from crewai_tools.security.safe_requests import safe_get
_SPACES_PATTERN: Final[re.Pattern[str]] = re.compile(r"[ \t]+")
@@ -25,7 +25,7 @@ class WebPageLoader(BaseLoader):
)
try:
response = requests.get(url, timeout=15, headers=headers)
response = safe_get(url, timeout=15, headers=headers)
response.encoding = response.apparent_encoding
soup = BeautifulSoup(response.text, "html.parser")

View File

@@ -0,0 +1,49 @@
"""HTTP helpers that preserve crewai-tools URL safety checks."""
from __future__ import annotations
from typing import Any
from urllib.parse import urljoin
import requests
from crewai_tools.security.safe_path import validate_url
_REDIRECT_STATUS_CODES = {301, 302, 303, 307, 308}
def safe_get(url: str, *, max_redirects: int = 10, **kwargs: Any) -> requests.Response:
"""GET a URL while validating each redirect target before following it."""
current_url = validate_url(url)
request_kwargs = {**kwargs, "allow_redirects": False}
timeout = request_kwargs.pop("timeout", 30)
history: list[requests.Response] = []
redirects_followed = 0
while True:
response = requests.get(current_url, timeout=timeout, **request_kwargs)
if (
response.status_code not in _REDIRECT_STATUS_CODES
or "Location" not in response.headers
):
response.history = history
return response
if redirects_followed >= max_redirects:
response.close()
raise ValueError(f"Too many redirects while fetching URL: {url}")
location = response.headers.get("Location")
if not location:
response.history = history
return response
try:
redirect_url = validate_url(urljoin(response.url, location))
except ValueError:
response.close()
raise
history.append(response)
current_url = redirect_url
redirects_followed += 1

View File

@@ -3,9 +3,8 @@ from typing import Any
from crewai.tools import BaseTool
from pydantic import BaseModel, Field
import requests
from crewai_tools.security.safe_path import validate_url
from crewai_tools.security.safe_requests import safe_get
try:
@@ -83,8 +82,7 @@ class ScrapeElementFromWebsiteTool(BaseTool):
if website_url is None or css_element is None:
raise ValueError("Both website_url and css_element must be provided.")
website_url = validate_url(website_url)
page = requests.get(
page = safe_get(
website_url,
headers=self.headers,
cookies=self.cookies if self.cookies else {},

View File

@@ -3,9 +3,8 @@ import re
from typing import Any
from pydantic import Field
import requests
from crewai_tools.security.safe_path import validate_url
from crewai_tools.security.safe_requests import safe_get
try:
@@ -75,8 +74,7 @@ class ScrapeWebsiteTool(BaseTool):
if website_url is None:
raise ValueError("Website URL must be provided.")
website_url = validate_url(website_url)
page = requests.get(
page = safe_get(
website_url,
timeout=15,
headers=self.headers,

View File

@@ -0,0 +1,28 @@
from __future__ import annotations
import socket
from typing import Any
import pytest
@pytest.fixture(autouse=True)
def public_example_dns(monkeypatch: pytest.MonkeyPatch) -> None:
original_getaddrinfo = socket.getaddrinfo
def fake_getaddrinfo(
host: str, port: int, *args: Any, **kwargs: Any
) -> list[tuple[Any, ...]]:
if host in {"example.com", "api.example.com"}:
return [
(
socket.AF_INET,
socket.SOCK_STREAM,
6,
"",
("93.184.216.34", port),
)
]
return original_getaddrinfo(host, port, *args, **kwargs)
monkeypatch.setattr(socket, "getaddrinfo", fake_getaddrinfo)

View File

@@ -0,0 +1,114 @@
"""Tests for redirect-aware safe HTTP helpers."""
from __future__ import annotations
import socket
from io import BytesIO
from typing import Any
import pytest
import requests
from crewai_tools.security.safe_requests import safe_get
def _response(url: str, status_code: int, *, location: str | None = None) -> requests.Response:
response = requests.Response()
response.status_code = status_code
response.url = url
response._content = b"ok"
response.raw = BytesIO()
if location is not None:
response.headers["Location"] = location
return response
@pytest.fixture
def public_dns(monkeypatch: pytest.MonkeyPatch) -> None:
original_getaddrinfo = socket.getaddrinfo
def fake_getaddrinfo(
host: str, port: int, *args: Any, **kwargs: Any
) -> list[tuple[Any, ...]]:
if host in {"public.example", "safe.example"}:
return [
(
socket.AF_INET,
socket.SOCK_STREAM,
6,
"",
("93.184.216.34", port),
)
]
return original_getaddrinfo(host, port, *args, **kwargs)
monkeypatch.setattr(socket, "getaddrinfo", fake_getaddrinfo)
def test_safe_get_blocks_direct_internal_url() -> None:
with pytest.raises(ValueError, match="private/reserved IP"):
safe_get("http://127.0.0.1/admin", timeout=15)
def test_safe_get_blocks_redirect_to_internal_url(
monkeypatch: pytest.MonkeyPatch, public_dns: None
) -> None:
requested_urls: list[str] = []
def fake_get(url: str, **kwargs: Any) -> requests.Response:
requested_urls.append(url)
assert kwargs["allow_redirects"] is False
return _response(url, 302, location="http://127.0.0.1/admin")
monkeypatch.setattr(
"crewai_tools.security.safe_requests.requests.get",
fake_get,
)
with pytest.raises(ValueError, match="private/reserved IP"):
safe_get("http://public.example/start", timeout=15)
assert requested_urls == ["http://public.example/start"]
def test_safe_get_follows_safe_relative_redirect(
monkeypatch: pytest.MonkeyPatch, public_dns: None
) -> None:
requested_urls: list[str] = []
def fake_get(url: str, **kwargs: Any) -> requests.Response:
requested_urls.append(url)
assert kwargs["allow_redirects"] is False
if url == "http://public.example/start":
return _response(url, 302, location="/final")
return _response(url, 200)
monkeypatch.setattr(
"crewai_tools.security.safe_requests.requests.get",
fake_get,
)
response = safe_get("http://public.example/start", timeout=15)
assert response.status_code == 200
assert response.url == "http://public.example/final"
assert requested_urls == [
"http://public.example/start",
"http://public.example/final",
]
assert len(response.history) == 1
def test_safe_get_fails_closed_after_too_many_redirects(
monkeypatch: pytest.MonkeyPatch, public_dns: None
) -> None:
def fake_get(url: str, **kwargs: Any) -> requests.Response:
return _response(url, 302, location="http://safe.example/again")
monkeypatch.setattr(
"crewai_tools.security.safe_requests.requests.get",
fake_get,
)
with pytest.raises(ValueError, match="Too many redirects"):
safe_get("http://public.example/start", max_redirects=1, timeout=15)