mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-07 23:28:30 +00:00
288 lines
9.3 KiB
Python
288 lines
9.3 KiB
Python
"""Tests for base interceptor functionality."""
|
|
|
|
import httpx
|
|
import pytest
|
|
|
|
from crewai.llms.hooks.base import BaseInterceptor
|
|
|
|
|
|
class SimpleInterceptor(BaseInterceptor[httpx.Request, httpx.Response]):
|
|
"""Simple test interceptor implementation."""
|
|
|
|
def __init__(self) -> None:
|
|
"""Initialize tracking lists."""
|
|
self.outbound_calls: list[httpx.Request] = []
|
|
self.inbound_calls: list[httpx.Response] = []
|
|
|
|
def on_outbound(self, message: httpx.Request) -> httpx.Request:
|
|
"""Track outbound calls.
|
|
|
|
Args:
|
|
message: The outbound request.
|
|
|
|
Returns:
|
|
The request unchanged.
|
|
"""
|
|
self.outbound_calls.append(message)
|
|
return message
|
|
|
|
def on_inbound(self, message: httpx.Response) -> httpx.Response:
|
|
"""Track inbound calls.
|
|
|
|
Args:
|
|
message: The inbound response.
|
|
|
|
Returns:
|
|
The response unchanged.
|
|
"""
|
|
self.inbound_calls.append(message)
|
|
return message
|
|
|
|
|
|
class ModifyingInterceptor(BaseInterceptor[httpx.Request, httpx.Response]):
|
|
"""Interceptor that modifies requests and responses."""
|
|
|
|
def on_outbound(self, message: httpx.Request) -> httpx.Request:
|
|
"""Add custom header to outbound request.
|
|
|
|
Args:
|
|
message: The outbound request.
|
|
|
|
Returns:
|
|
Modified request with custom header.
|
|
"""
|
|
message.headers["X-Custom-Header"] = "test-value"
|
|
message.headers["X-Intercepted"] = "true"
|
|
return message
|
|
|
|
def on_inbound(self, message: httpx.Response) -> httpx.Response:
|
|
"""Add custom header to inbound response.
|
|
|
|
Args:
|
|
message: The inbound response.
|
|
|
|
Returns:
|
|
Modified response with custom header.
|
|
"""
|
|
message.headers["X-Response-Intercepted"] = "true"
|
|
return message
|
|
|
|
|
|
class AsyncInterceptor(BaseInterceptor[httpx.Request, httpx.Response]):
|
|
"""Interceptor with async support."""
|
|
|
|
def __init__(self) -> None:
|
|
"""Initialize tracking lists."""
|
|
self.async_outbound_calls: list[httpx.Request] = []
|
|
self.async_inbound_calls: list[httpx.Response] = []
|
|
|
|
def on_outbound(self, message: httpx.Request) -> httpx.Request:
|
|
"""Handle sync outbound.
|
|
|
|
Args:
|
|
message: The outbound request.
|
|
|
|
Returns:
|
|
The request unchanged.
|
|
"""
|
|
return message
|
|
|
|
def on_inbound(self, message: httpx.Response) -> httpx.Response:
|
|
"""Handle sync inbound.
|
|
|
|
Args:
|
|
message: The inbound response.
|
|
|
|
Returns:
|
|
The response unchanged.
|
|
"""
|
|
return message
|
|
|
|
async def aon_outbound(self, message: httpx.Request) -> httpx.Request:
|
|
"""Handle async outbound.
|
|
|
|
Args:
|
|
message: The outbound request.
|
|
|
|
Returns:
|
|
Modified request with async header.
|
|
"""
|
|
self.async_outbound_calls.append(message)
|
|
message.headers["X-Async-Outbound"] = "true"
|
|
return message
|
|
|
|
async def aon_inbound(self, message: httpx.Response) -> httpx.Response:
|
|
"""Handle async inbound.
|
|
|
|
Args:
|
|
message: The inbound response.
|
|
|
|
Returns:
|
|
Modified response with async header.
|
|
"""
|
|
self.async_inbound_calls.append(message)
|
|
message.headers["X-Async-Inbound"] = "true"
|
|
return message
|
|
|
|
|
|
class TestBaseInterceptor:
|
|
"""Test suite for BaseInterceptor class."""
|
|
|
|
def test_interceptor_instantiation(self) -> None:
|
|
"""Test that interceptor can be instantiated."""
|
|
interceptor = SimpleInterceptor()
|
|
assert interceptor is not None
|
|
assert isinstance(interceptor, BaseInterceptor)
|
|
|
|
def test_on_outbound_called(self) -> None:
|
|
"""Test that on_outbound is called and tracks requests."""
|
|
interceptor = SimpleInterceptor()
|
|
request = httpx.Request("GET", "https://api.example.com/test")
|
|
|
|
result = interceptor.on_outbound(request)
|
|
|
|
assert len(interceptor.outbound_calls) == 1
|
|
assert interceptor.outbound_calls[0] is request
|
|
assert result is request
|
|
|
|
def test_on_inbound_called(self) -> None:
|
|
"""Test that on_inbound is called and tracks responses."""
|
|
interceptor = SimpleInterceptor()
|
|
response = httpx.Response(200, json={"status": "ok"})
|
|
|
|
result = interceptor.on_inbound(response)
|
|
|
|
assert len(interceptor.inbound_calls) == 1
|
|
assert interceptor.inbound_calls[0] is response
|
|
assert result is response
|
|
|
|
def test_multiple_outbound_calls(self) -> None:
|
|
"""Test that interceptor tracks multiple outbound calls."""
|
|
interceptor = SimpleInterceptor()
|
|
requests = [
|
|
httpx.Request("GET", "https://api.example.com/1"),
|
|
httpx.Request("POST", "https://api.example.com/2"),
|
|
httpx.Request("PUT", "https://api.example.com/3"),
|
|
]
|
|
|
|
for req in requests:
|
|
interceptor.on_outbound(req)
|
|
|
|
assert len(interceptor.outbound_calls) == 3
|
|
assert interceptor.outbound_calls == requests
|
|
|
|
def test_multiple_inbound_calls(self) -> None:
|
|
"""Test that interceptor tracks multiple inbound calls."""
|
|
interceptor = SimpleInterceptor()
|
|
responses = [
|
|
httpx.Response(200, json={"id": 1}),
|
|
httpx.Response(201, json={"id": 2}),
|
|
httpx.Response(404, json={"error": "not found"}),
|
|
]
|
|
|
|
for resp in responses:
|
|
interceptor.on_inbound(resp)
|
|
|
|
assert len(interceptor.inbound_calls) == 3
|
|
assert interceptor.inbound_calls == responses
|
|
|
|
|
|
class TestModifyingInterceptor:
|
|
"""Test suite for interceptor that modifies messages."""
|
|
|
|
def test_outbound_header_modification(self) -> None:
|
|
"""Test that interceptor can add headers to outbound requests."""
|
|
interceptor = ModifyingInterceptor()
|
|
request = httpx.Request("GET", "https://api.example.com/test")
|
|
|
|
result = interceptor.on_outbound(request)
|
|
|
|
assert result is request
|
|
assert "X-Custom-Header" in result.headers
|
|
assert result.headers["X-Custom-Header"] == "test-value"
|
|
assert "X-Intercepted" in result.headers
|
|
assert result.headers["X-Intercepted"] == "true"
|
|
|
|
def test_inbound_header_modification(self) -> None:
|
|
"""Test that interceptor can add headers to inbound responses."""
|
|
interceptor = ModifyingInterceptor()
|
|
response = httpx.Response(200, json={"status": "ok"})
|
|
|
|
result = interceptor.on_inbound(response)
|
|
|
|
assert result is response
|
|
assert "X-Response-Intercepted" in result.headers
|
|
assert result.headers["X-Response-Intercepted"] == "true"
|
|
|
|
def test_preserves_existing_headers(self) -> None:
|
|
"""Test that interceptor preserves existing headers."""
|
|
interceptor = ModifyingInterceptor()
|
|
request = httpx.Request(
|
|
"GET",
|
|
"https://api.example.com/test",
|
|
headers={"Authorization": "Bearer token123", "Content-Type": "application/json"},
|
|
)
|
|
|
|
result = interceptor.on_outbound(request)
|
|
|
|
assert result.headers["Authorization"] == "Bearer token123"
|
|
assert result.headers["Content-Type"] == "application/json"
|
|
assert result.headers["X-Custom-Header"] == "test-value"
|
|
|
|
|
|
class TestAsyncInterceptor:
|
|
"""Test suite for async interceptor functionality."""
|
|
|
|
def test_sync_methods_work(self) -> None:
|
|
"""Test that sync methods still work on async interceptor."""
|
|
interceptor = AsyncInterceptor()
|
|
request = httpx.Request("GET", "https://api.example.com/test")
|
|
response = httpx.Response(200)
|
|
|
|
req_result = interceptor.on_outbound(request)
|
|
resp_result = interceptor.on_inbound(response)
|
|
|
|
assert req_result is request
|
|
assert resp_result is response
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_async_outbound(self) -> None:
|
|
"""Test async outbound hook."""
|
|
interceptor = AsyncInterceptor()
|
|
request = httpx.Request("GET", "https://api.example.com/test")
|
|
|
|
result = await interceptor.aon_outbound(request)
|
|
|
|
assert result is request
|
|
assert len(interceptor.async_outbound_calls) == 1
|
|
assert interceptor.async_outbound_calls[0] is request
|
|
assert "X-Async-Outbound" in result.headers
|
|
assert result.headers["X-Async-Outbound"] == "true"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_async_inbound(self) -> None:
|
|
"""Test async inbound hook."""
|
|
interceptor = AsyncInterceptor()
|
|
response = httpx.Response(200, json={"status": "ok"})
|
|
|
|
result = await interceptor.aon_inbound(response)
|
|
|
|
assert result is response
|
|
assert len(interceptor.async_inbound_calls) == 1
|
|
assert interceptor.async_inbound_calls[0] is response
|
|
assert "X-Async-Inbound" in result.headers
|
|
assert result.headers["X-Async-Inbound"] == "true"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_default_async_not_implemented(self) -> None:
|
|
"""Test that default async methods raise NotImplementedError."""
|
|
interceptor = SimpleInterceptor()
|
|
request = httpx.Request("GET", "https://api.example.com/test")
|
|
response = httpx.Response(200)
|
|
|
|
with pytest.raises(NotImplementedError):
|
|
await interceptor.aon_outbound(request)
|
|
|
|
with pytest.raises(NotImplementedError):
|
|
await interceptor.aon_inbound(response)
|