mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 16:18:30 +00:00
feat: add support for llm message interceptor hooks
This commit is contained in:
287
lib/crewai/tests/llms/hooks/test_base_interceptor.py
Normal file
287
lib/crewai/tests/llms/hooks/test_base_interceptor.py
Normal file
@@ -0,0 +1,287 @@
|
||||
"""Tests for base interceptor functionality."""
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
|
||||
|
||||
class SimpleInterceptor(BaseInterceptor[httpx.Request, httpx.Response]):
|
||||
"""Simple test interceptor implementation."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize tracking lists."""
|
||||
self.outbound_calls: list[httpx.Request] = []
|
||||
self.inbound_calls: list[httpx.Response] = []
|
||||
|
||||
def on_outbound(self, message: httpx.Request) -> httpx.Request:
|
||||
"""Track outbound calls.
|
||||
|
||||
Args:
|
||||
message: The outbound request.
|
||||
|
||||
Returns:
|
||||
The request unchanged.
|
||||
"""
|
||||
self.outbound_calls.append(message)
|
||||
return message
|
||||
|
||||
def on_inbound(self, message: httpx.Response) -> httpx.Response:
|
||||
"""Track inbound calls.
|
||||
|
||||
Args:
|
||||
message: The inbound response.
|
||||
|
||||
Returns:
|
||||
The response unchanged.
|
||||
"""
|
||||
self.inbound_calls.append(message)
|
||||
return message
|
||||
|
||||
|
||||
class ModifyingInterceptor(BaseInterceptor[httpx.Request, httpx.Response]):
|
||||
"""Interceptor that modifies requests and responses."""
|
||||
|
||||
def on_outbound(self, message: httpx.Request) -> httpx.Request:
|
||||
"""Add custom header to outbound request.
|
||||
|
||||
Args:
|
||||
message: The outbound request.
|
||||
|
||||
Returns:
|
||||
Modified request with custom header.
|
||||
"""
|
||||
message.headers["X-Custom-Header"] = "test-value"
|
||||
message.headers["X-Intercepted"] = "true"
|
||||
return message
|
||||
|
||||
def on_inbound(self, message: httpx.Response) -> httpx.Response:
|
||||
"""Add custom header to inbound response.
|
||||
|
||||
Args:
|
||||
message: The inbound response.
|
||||
|
||||
Returns:
|
||||
Modified response with custom header.
|
||||
"""
|
||||
message.headers["X-Response-Intercepted"] = "true"
|
||||
return message
|
||||
|
||||
|
||||
class AsyncInterceptor(BaseInterceptor[httpx.Request, httpx.Response]):
|
||||
"""Interceptor with async support."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize tracking lists."""
|
||||
self.async_outbound_calls: list[httpx.Request] = []
|
||||
self.async_inbound_calls: list[httpx.Response] = []
|
||||
|
||||
def on_outbound(self, message: httpx.Request) -> httpx.Request:
|
||||
"""Handle sync outbound.
|
||||
|
||||
Args:
|
||||
message: The outbound request.
|
||||
|
||||
Returns:
|
||||
The request unchanged.
|
||||
"""
|
||||
return message
|
||||
|
||||
def on_inbound(self, message: httpx.Response) -> httpx.Response:
|
||||
"""Handle sync inbound.
|
||||
|
||||
Args:
|
||||
message: The inbound response.
|
||||
|
||||
Returns:
|
||||
The response unchanged.
|
||||
"""
|
||||
return message
|
||||
|
||||
async def aon_outbound(self, message: httpx.Request) -> httpx.Request:
|
||||
"""Handle async outbound.
|
||||
|
||||
Args:
|
||||
message: The outbound request.
|
||||
|
||||
Returns:
|
||||
Modified request with async header.
|
||||
"""
|
||||
self.async_outbound_calls.append(message)
|
||||
message.headers["X-Async-Outbound"] = "true"
|
||||
return message
|
||||
|
||||
async def aon_inbound(self, message: httpx.Response) -> httpx.Response:
|
||||
"""Handle async inbound.
|
||||
|
||||
Args:
|
||||
message: The inbound response.
|
||||
|
||||
Returns:
|
||||
Modified response with async header.
|
||||
"""
|
||||
self.async_inbound_calls.append(message)
|
||||
message.headers["X-Async-Inbound"] = "true"
|
||||
return message
|
||||
|
||||
|
||||
class TestBaseInterceptor:
|
||||
"""Test suite for BaseInterceptor class."""
|
||||
|
||||
def test_interceptor_instantiation(self) -> None:
|
||||
"""Test that interceptor can be instantiated."""
|
||||
interceptor = SimpleInterceptor()
|
||||
assert interceptor is not None
|
||||
assert isinstance(interceptor, BaseInterceptor)
|
||||
|
||||
def test_on_outbound_called(self) -> None:
|
||||
"""Test that on_outbound is called and tracks requests."""
|
||||
interceptor = SimpleInterceptor()
|
||||
request = httpx.Request("GET", "https://api.example.com/test")
|
||||
|
||||
result = interceptor.on_outbound(request)
|
||||
|
||||
assert len(interceptor.outbound_calls) == 1
|
||||
assert interceptor.outbound_calls[0] is request
|
||||
assert result is request
|
||||
|
||||
def test_on_inbound_called(self) -> None:
|
||||
"""Test that on_inbound is called and tracks responses."""
|
||||
interceptor = SimpleInterceptor()
|
||||
response = httpx.Response(200, json={"status": "ok"})
|
||||
|
||||
result = interceptor.on_inbound(response)
|
||||
|
||||
assert len(interceptor.inbound_calls) == 1
|
||||
assert interceptor.inbound_calls[0] is response
|
||||
assert result is response
|
||||
|
||||
def test_multiple_outbound_calls(self) -> None:
|
||||
"""Test that interceptor tracks multiple outbound calls."""
|
||||
interceptor = SimpleInterceptor()
|
||||
requests = [
|
||||
httpx.Request("GET", "https://api.example.com/1"),
|
||||
httpx.Request("POST", "https://api.example.com/2"),
|
||||
httpx.Request("PUT", "https://api.example.com/3"),
|
||||
]
|
||||
|
||||
for req in requests:
|
||||
interceptor.on_outbound(req)
|
||||
|
||||
assert len(interceptor.outbound_calls) == 3
|
||||
assert interceptor.outbound_calls == requests
|
||||
|
||||
def test_multiple_inbound_calls(self) -> None:
|
||||
"""Test that interceptor tracks multiple inbound calls."""
|
||||
interceptor = SimpleInterceptor()
|
||||
responses = [
|
||||
httpx.Response(200, json={"id": 1}),
|
||||
httpx.Response(201, json={"id": 2}),
|
||||
httpx.Response(404, json={"error": "not found"}),
|
||||
]
|
||||
|
||||
for resp in responses:
|
||||
interceptor.on_inbound(resp)
|
||||
|
||||
assert len(interceptor.inbound_calls) == 3
|
||||
assert interceptor.inbound_calls == responses
|
||||
|
||||
|
||||
class TestModifyingInterceptor:
|
||||
"""Test suite for interceptor that modifies messages."""
|
||||
|
||||
def test_outbound_header_modification(self) -> None:
|
||||
"""Test that interceptor can add headers to outbound requests."""
|
||||
interceptor = ModifyingInterceptor()
|
||||
request = httpx.Request("GET", "https://api.example.com/test")
|
||||
|
||||
result = interceptor.on_outbound(request)
|
||||
|
||||
assert result is request
|
||||
assert "X-Custom-Header" in result.headers
|
||||
assert result.headers["X-Custom-Header"] == "test-value"
|
||||
assert "X-Intercepted" in result.headers
|
||||
assert result.headers["X-Intercepted"] == "true"
|
||||
|
||||
def test_inbound_header_modification(self) -> None:
|
||||
"""Test that interceptor can add headers to inbound responses."""
|
||||
interceptor = ModifyingInterceptor()
|
||||
response = httpx.Response(200, json={"status": "ok"})
|
||||
|
||||
result = interceptor.on_inbound(response)
|
||||
|
||||
assert result is response
|
||||
assert "X-Response-Intercepted" in result.headers
|
||||
assert result.headers["X-Response-Intercepted"] == "true"
|
||||
|
||||
def test_preserves_existing_headers(self) -> None:
|
||||
"""Test that interceptor preserves existing headers."""
|
||||
interceptor = ModifyingInterceptor()
|
||||
request = httpx.Request(
|
||||
"GET",
|
||||
"https://api.example.com/test",
|
||||
headers={"Authorization": "Bearer token123", "Content-Type": "application/json"},
|
||||
)
|
||||
|
||||
result = interceptor.on_outbound(request)
|
||||
|
||||
assert result.headers["Authorization"] == "Bearer token123"
|
||||
assert result.headers["Content-Type"] == "application/json"
|
||||
assert result.headers["X-Custom-Header"] == "test-value"
|
||||
|
||||
|
||||
class TestAsyncInterceptor:
|
||||
"""Test suite for async interceptor functionality."""
|
||||
|
||||
def test_sync_methods_work(self) -> None:
|
||||
"""Test that sync methods still work on async interceptor."""
|
||||
interceptor = AsyncInterceptor()
|
||||
request = httpx.Request("GET", "https://api.example.com/test")
|
||||
response = httpx.Response(200)
|
||||
|
||||
req_result = interceptor.on_outbound(request)
|
||||
resp_result = interceptor.on_inbound(response)
|
||||
|
||||
assert req_result is request
|
||||
assert resp_result is response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_outbound(self) -> None:
|
||||
"""Test async outbound hook."""
|
||||
interceptor = AsyncInterceptor()
|
||||
request = httpx.Request("GET", "https://api.example.com/test")
|
||||
|
||||
result = await interceptor.aon_outbound(request)
|
||||
|
||||
assert result is request
|
||||
assert len(interceptor.async_outbound_calls) == 1
|
||||
assert interceptor.async_outbound_calls[0] is request
|
||||
assert "X-Async-Outbound" in result.headers
|
||||
assert result.headers["X-Async-Outbound"] == "true"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_inbound(self) -> None:
|
||||
"""Test async inbound hook."""
|
||||
interceptor = AsyncInterceptor()
|
||||
response = httpx.Response(200, json={"status": "ok"})
|
||||
|
||||
result = await interceptor.aon_inbound(response)
|
||||
|
||||
assert result is response
|
||||
assert len(interceptor.async_inbound_calls) == 1
|
||||
assert interceptor.async_inbound_calls[0] is response
|
||||
assert "X-Async-Inbound" in result.headers
|
||||
assert result.headers["X-Async-Inbound"] == "true"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_async_not_implemented(self) -> None:
|
||||
"""Test that default async methods raise NotImplementedError."""
|
||||
interceptor = SimpleInterceptor()
|
||||
request = httpx.Request("GET", "https://api.example.com/test")
|
||||
response = httpx.Response(200)
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
await interceptor.aon_outbound(request)
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
await interceptor.aon_inbound(response)
|
||||
Reference in New Issue
Block a user