mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 16:18:30 +00:00
feat: add support for llm message interceptor hooks
This commit is contained in:
1
lib/crewai/tests/llms/hooks/__init__.py
Normal file
1
lib/crewai/tests/llms/hooks/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for LLM interceptor hooks functionality."""
|
||||
311
lib/crewai/tests/llms/hooks/test_anthropic_interceptor.py
Normal file
311
lib/crewai/tests/llms/hooks/test_anthropic_interceptor.py
Normal file
@@ -0,0 +1,311 @@
|
||||
"""Tests for Anthropic provider with interceptor integration."""
|
||||
|
||||
import os
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def setup_anthropic_api_key(monkeypatch):
    """Set dummy Anthropic API key for tests that don't make real API calls."""
    # Only inject the placeholder when no real key is configured, so that
    # record-mode runs against the live API keep using the genuine key.
    if "ANTHROPIC_API_KEY" not in os.environ:
        monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-test-key-dummy")
|
||||
|
||||
|
||||
class AnthropicTestInterceptor(BaseInterceptor[httpx.Request, httpx.Response]):
    """Test interceptor for Anthropic provider."""

    def __init__(self) -> None:
        """Initialize tracking and modification state."""
        # Every request/response seen by the hooks is recorded so tests can
        # inspect exactly what passed through the transport.
        self.outbound_calls: list[httpx.Request] = []
        self.inbound_calls: list[httpx.Response] = []
        # Mutable on purpose: tests override it to distinguish instances.
        self.custom_header_value = "anthropic-test-value"

    def on_outbound(self, message: httpx.Request) -> httpx.Request:
        """Track and modify outbound Anthropic requests.

        Args:
            message: The outbound request.

        Returns:
            Modified request with custom headers.
        """
        self.outbound_calls.append(message)
        message.headers["X-Anthropic-Interceptor"] = self.custom_header_value
        message.headers["X-Request-ID"] = "test-request-456"
        return message

    def on_inbound(self, message: httpx.Response) -> httpx.Response:
        """Track inbound Anthropic responses.

        Args:
            message: The inbound response.

        Returns:
            The response with tracking header.
        """
        self.inbound_calls.append(message)
        message.headers["X-Response-Tracked"] = "true"
        return message
|
||||
|
||||
|
||||
class TestAnthropicInterceptorIntegration:
    """Test suite for Anthropic provider with interceptor."""

    def test_anthropic_llm_accepts_interceptor(self) -> None:
        """Test that Anthropic LLM accepts interceptor parameter."""
        interceptor = AnthropicTestInterceptor()

        llm = LLM(model="anthropic/claude-3-5-sonnet-20241022", interceptor=interceptor)

        assert llm.interceptor is interceptor

    @pytest.mark.vcr(filter_headers=["authorization", "x-api-key"])
    def test_anthropic_call_with_interceptor_tracks_requests(self) -> None:
        """Test that interceptor tracks Anthropic API requests."""
        interceptor = AnthropicTestInterceptor()
        llm = LLM(model="anthropic/claude-3-5-haiku-20241022", interceptor=interceptor)

        # Make a simple completion call
        result = llm.call(
            messages=[{"role": "user", "content": "Say 'Hello World' and nothing else"}]
        )

        # The loops below would pass vacuously on empty lists, so first prove
        # the interceptor was actually invoked for the call.
        assert interceptor.outbound_calls
        assert interceptor.inbound_calls

        # Verify custom headers were added
        for request in interceptor.outbound_calls:
            assert "X-Anthropic-Interceptor" in request.headers
            assert request.headers["X-Anthropic-Interceptor"] == "anthropic-test-value"
            assert "X-Request-ID" in request.headers
            assert request.headers["X-Request-ID"] == "test-request-456"

        # Verify response was tracked
        for response in interceptor.inbound_calls:
            assert "X-Response-Tracked" in response.headers
            assert response.headers["X-Response-Tracked"] == "true"

        # Verify result is valid
        assert result is not None
        assert isinstance(result, str)
        assert len(result) > 0

    def test_anthropic_without_interceptor_works(self) -> None:
        """Test that Anthropic LLM works without interceptor."""
        llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")

        assert llm.interceptor is None

    def test_multiple_anthropic_llms_different_interceptors(self) -> None:
        """Test that multiple Anthropic LLMs can have different interceptors."""
        interceptor1 = AnthropicTestInterceptor()
        interceptor1.custom_header_value = "claude-opus-value"

        interceptor2 = AnthropicTestInterceptor()
        interceptor2.custom_header_value = "claude-sonnet-value"

        llm1 = LLM(model="anthropic/claude-3-opus-20240229", interceptor=interceptor1)
        llm2 = LLM(model="anthropic/claude-3-5-sonnet-20241022", interceptor=interceptor2)

        assert llm1.interceptor is interceptor1
        assert llm2.interceptor is interceptor2
        assert llm1.interceptor.custom_header_value == "claude-opus-value"
        assert llm2.interceptor.custom_header_value == "claude-sonnet-value"
|
||||
|
||||
|
||||
class AnthropicLoggingInterceptor(BaseInterceptor[httpx.Request, httpx.Response]):
    """Interceptor that logs Anthropic request/response details."""

    def __init__(self) -> None:
        """Initialize logging lists."""
        # Parallel lists: entry i of each describes the i-th observed call.
        self.request_urls: list[str] = []
        self.request_methods: list[str] = []
        self.response_status_codes: list[int] = []
        # Captured values of the "anthropic-version" header, when present.
        self.anthropic_version_headers: list[str] = []

    def on_outbound(self, message: httpx.Request) -> httpx.Request:
        """Log outbound request details.

        Args:
            message: The outbound request.

        Returns:
            The request unchanged.
        """
        self.request_urls.append(str(message.url))
        self.request_methods.append(message.method)
        # The Anthropic SDK sets this header; record it only when present.
        if "anthropic-version" in message.headers:
            self.anthropic_version_headers.append(message.headers["anthropic-version"])
        return message

    def on_inbound(self, message: httpx.Response) -> httpx.Response:
        """Log inbound response details.

        Args:
            message: The inbound response.

        Returns:
            The response unchanged.
        """
        self.response_status_codes.append(message.status_code)
        return message
|
||||
|
||||
|
||||
class TestAnthropicLoggingInterceptor:
    """Test suite for logging interceptor with Anthropic."""

    def test_logging_interceptor_instantiation(self) -> None:
        """Test that logging interceptor can be created with Anthropic LLM."""
        interceptor = AnthropicLoggingInterceptor()
        llm = LLM(model="anthropic/claude-3-5-sonnet-20241022", interceptor=interceptor)

        assert llm.interceptor is interceptor
        assert isinstance(llm.interceptor, AnthropicLoggingInterceptor)

    @pytest.mark.vcr(filter_headers=["authorization", "x-api-key"])
    def test_logging_interceptor_tracks_details(self) -> None:
        """Test that logging interceptor tracks request/response details."""
        interceptor = AnthropicLoggingInterceptor()
        llm = LLM(model="anthropic/claude-3-5-haiku-20241022", interceptor=interceptor)

        # Make a completion call
        result = llm.call(messages=[{"role": "user", "content": "Count from 1 to 3"}])

        # Guard against vacuous loops below: the hooks must have fired.
        assert interceptor.request_urls
        assert interceptor.response_status_codes

        # Verify URL points to Anthropic API
        for url in interceptor.request_urls:
            assert "anthropic" in url.lower() or "api" in url.lower()

        # Verify methods are POST (messages endpoint uses POST)
        for method in interceptor.request_methods:
            assert method == "POST"

        # Verify successful status codes
        for status_code in interceptor.response_status_codes:
            assert 200 <= status_code < 300

        # Verify result is valid
        assert result is not None
|
||||
|
||||
|
||||
class AnthropicHeaderInterceptor(BaseInterceptor[httpx.Request, httpx.Response]):
    """Interceptor that adds Anthropic-specific headers."""

    def __init__(self, workspace_id: str, user_id: str) -> None:
        """Store the metadata stamped onto every outbound request.

        Args:
            workspace_id: The workspace ID to inject.
            user_id: The user ID to inject.
        """
        self.workspace_id = workspace_id
        self.user_id = user_id

    def on_outbound(self, message: httpx.Request) -> httpx.Request:
        """Attach the stored metadata as custom headers.

        Args:
            message: The outbound request.

        Returns:
            Request with metadata headers.
        """
        extra_headers = {
            "X-Workspace-ID": self.workspace_id,
            "X-User-ID": self.user_id,
            "X-Custom-Client": "crewai-interceptor",
        }
        message.headers.update(extra_headers)
        return message

    def on_inbound(self, message: httpx.Response) -> httpx.Response:
        """Forward the response untouched; only requests are tagged.

        Args:
            message: The inbound response.

        Returns:
            The response unchanged.
        """
        return message
|
||||
|
||||
|
||||
class TestAnthropicHeaderInterceptor:
    """Test suite for header interceptor with Anthropic."""

    def test_header_interceptor_with_anthropic(self) -> None:
        """Test that header interceptor can be used with Anthropic LLM."""
        interceptor = AnthropicHeaderInterceptor(
            workspace_id="ws-789", user_id="user-012"
        )
        llm = LLM(model="anthropic/claude-3-5-sonnet-20241022", interceptor=interceptor)

        assert llm.interceptor is interceptor
        assert llm.interceptor.workspace_id == "ws-789"
        assert llm.interceptor.user_id == "user-012"

    def test_header_interceptor_adds_headers(self) -> None:
        """Test that header interceptor adds custom headers to requests."""
        # Calls the hook directly on a hand-built request — no network needed.
        interceptor = AnthropicHeaderInterceptor(workspace_id="ws-123", user_id="u-456")
        request = httpx.Request("POST", "https://api.anthropic.com/v1/messages")

        modified_request = interceptor.on_outbound(request)

        assert "X-Workspace-ID" in modified_request.headers
        assert modified_request.headers["X-Workspace-ID"] == "ws-123"
        assert "X-User-ID" in modified_request.headers
        assert modified_request.headers["X-User-ID"] == "u-456"
        assert "X-Custom-Client" in modified_request.headers
        assert modified_request.headers["X-Custom-Client"] == "crewai-interceptor"

    @pytest.mark.vcr(filter_headers=["authorization", "x-api-key"])
    def test_header_interceptor_with_real_call(self) -> None:
        """Test that header interceptor works with real Anthropic API call."""
        interceptor = AnthropicHeaderInterceptor(workspace_id="ws-999", user_id="u-888")
        llm = LLM(model="anthropic/claude-3-5-haiku-20241022", interceptor=interceptor)

        # Make a simple call
        result = llm.call(
            messages=[{"role": "user", "content": "Reply with just the word: SUCCESS"}]
        )

        # Verify the call succeeded
        assert result is not None
        assert len(result) > 0

        # Verify the interceptor was configured
        assert llm.interceptor is interceptor
|
||||
|
||||
|
||||
class TestMixedProviderInterceptors:
    """Test suite for using interceptors with different providers."""

    def test_openai_and_anthropic_different_interceptors(self) -> None:
        """Test that OpenAI and Anthropic LLMs can have different interceptors."""
        # The Anthropic test interceptor is provider-agnostic, so it is
        # deliberately reused here for the OpenAI-backed LLM as well.
        openai_interceptor = AnthropicTestInterceptor()
        openai_interceptor.custom_header_value = "openai-specific"

        anthropic_interceptor = AnthropicTestInterceptor()
        anthropic_interceptor.custom_header_value = "anthropic-specific"

        openai_llm = LLM(model="gpt-4", interceptor=openai_interceptor)
        anthropic_llm = LLM(
            model="anthropic/claude-3-5-sonnet-20241022", interceptor=anthropic_interceptor
        )

        assert openai_llm.interceptor is openai_interceptor
        assert anthropic_llm.interceptor is anthropic_interceptor
        assert openai_llm.interceptor.custom_header_value == "openai-specific"
        assert anthropic_llm.interceptor.custom_header_value == "anthropic-specific"

    def test_same_interceptor_different_providers(self) -> None:
        """Test that same interceptor instance can be used with multiple providers."""
        shared_interceptor = AnthropicTestInterceptor()

        openai_llm = LLM(model="gpt-4", interceptor=shared_interceptor)
        anthropic_llm = LLM(
            model="anthropic/claude-3-5-sonnet-20241022", interceptor=shared_interceptor
        )

        assert openai_llm.interceptor is shared_interceptor
        assert anthropic_llm.interceptor is shared_interceptor
        assert openai_llm.interceptor is anthropic_llm.interceptor
|
||||
287
lib/crewai/tests/llms/hooks/test_base_interceptor.py
Normal file
287
lib/crewai/tests/llms/hooks/test_base_interceptor.py
Normal file
@@ -0,0 +1,287 @@
|
||||
"""Tests for base interceptor functionality."""
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
|
||||
|
||||
class SimpleInterceptor(BaseInterceptor[httpx.Request, httpx.Response]):
    """Simple test interceptor implementation."""

    def __init__(self) -> None:
        """Initialize tracking lists."""
        # Records every message the hooks see, in call order.
        self.outbound_calls: list[httpx.Request] = []
        self.inbound_calls: list[httpx.Response] = []

    def on_outbound(self, message: httpx.Request) -> httpx.Request:
        """Track outbound calls.

        Args:
            message: The outbound request.

        Returns:
            The request unchanged.
        """
        self.outbound_calls.append(message)
        return message

    def on_inbound(self, message: httpx.Response) -> httpx.Response:
        """Track inbound calls.

        Args:
            message: The inbound response.

        Returns:
            The response unchanged.
        """
        self.inbound_calls.append(message)
        return message
|
||||
|
||||
|
||||
class ModifyingInterceptor(BaseInterceptor[httpx.Request, httpx.Response]):
    """Interceptor that stamps marker headers onto requests and responses."""

    def on_outbound(self, message: httpx.Request) -> httpx.Request:
        """Mark the outbound request as intercepted.

        Args:
            message: The outbound request.

        Returns:
            Modified request with custom header.
        """
        markers = (
            ("X-Custom-Header", "test-value"),
            ("X-Intercepted", "true"),
        )
        for header_name, header_value in markers:
            message.headers[header_name] = header_value
        return message

    def on_inbound(self, message: httpx.Response) -> httpx.Response:
        """Mark the inbound response as intercepted.

        Args:
            message: The inbound response.

        Returns:
            Modified response with custom header.
        """
        message.headers["X-Response-Intercepted"] = "true"
        return message
|
||||
|
||||
|
||||
class AsyncInterceptor(BaseInterceptor[httpx.Request, httpx.Response]):
    """Interceptor with async support."""

    def __init__(self) -> None:
        """Initialize tracking lists."""
        # Only the async hooks record calls; the sync hooks are pass-through.
        self.async_outbound_calls: list[httpx.Request] = []
        self.async_inbound_calls: list[httpx.Response] = []

    def on_outbound(self, message: httpx.Request) -> httpx.Request:
        """Handle sync outbound.

        Args:
            message: The outbound request.

        Returns:
            The request unchanged.
        """
        return message

    def on_inbound(self, message: httpx.Response) -> httpx.Response:
        """Handle sync inbound.

        Args:
            message: The inbound response.

        Returns:
            The response unchanged.
        """
        return message

    async def aon_outbound(self, message: httpx.Request) -> httpx.Request:
        """Handle async outbound.

        Args:
            message: The outbound request.

        Returns:
            Modified request with async header.
        """
        self.async_outbound_calls.append(message)
        message.headers["X-Async-Outbound"] = "true"
        return message

    async def aon_inbound(self, message: httpx.Response) -> httpx.Response:
        """Handle async inbound.

        Args:
            message: The inbound response.

        Returns:
            Modified response with async header.
        """
        self.async_inbound_calls.append(message)
        message.headers["X-Async-Inbound"] = "true"
        return message
|
||||
|
||||
|
||||
class TestBaseInterceptor:
    """Test suite for BaseInterceptor class."""

    def test_interceptor_instantiation(self) -> None:
        """Test that interceptor can be instantiated."""
        interceptor = SimpleInterceptor()
        assert interceptor is not None
        assert isinstance(interceptor, BaseInterceptor)

    def test_on_outbound_called(self) -> None:
        """Test that on_outbound is called and tracks requests."""
        interceptor = SimpleInterceptor()
        request = httpx.Request("GET", "https://api.example.com/test")

        result = interceptor.on_outbound(request)

        # Identity checks: the hook must track and return the SAME object.
        assert len(interceptor.outbound_calls) == 1
        assert interceptor.outbound_calls[0] is request
        assert result is request

    def test_on_inbound_called(self) -> None:
        """Test that on_inbound is called and tracks responses."""
        interceptor = SimpleInterceptor()
        response = httpx.Response(200, json={"status": "ok"})

        result = interceptor.on_inbound(response)

        assert len(interceptor.inbound_calls) == 1
        assert interceptor.inbound_calls[0] is response
        assert result is response

    def test_multiple_outbound_calls(self) -> None:
        """Test that interceptor tracks multiple outbound calls."""
        interceptor = SimpleInterceptor()
        requests = [
            httpx.Request("GET", "https://api.example.com/1"),
            httpx.Request("POST", "https://api.example.com/2"),
            httpx.Request("PUT", "https://api.example.com/3"),
        ]

        for req in requests:
            interceptor.on_outbound(req)

        # Order must be preserved as well as count.
        assert len(interceptor.outbound_calls) == 3
        assert interceptor.outbound_calls == requests

    def test_multiple_inbound_calls(self) -> None:
        """Test that interceptor tracks multiple inbound calls."""
        interceptor = SimpleInterceptor()
        responses = [
            httpx.Response(200, json={"id": 1}),
            httpx.Response(201, json={"id": 2}),
            httpx.Response(404, json={"error": "not found"}),
        ]

        for resp in responses:
            interceptor.on_inbound(resp)

        assert len(interceptor.inbound_calls) == 3
        assert interceptor.inbound_calls == responses
|
||||
|
||||
|
||||
class TestModifyingInterceptor:
    """Test suite for interceptor that modifies messages."""

    def test_outbound_header_modification(self) -> None:
        """Test that interceptor can add headers to outbound requests."""
        interceptor = ModifyingInterceptor()
        request = httpx.Request("GET", "https://api.example.com/test")

        result = interceptor.on_outbound(request)

        # Mutation happens in place — same object, new headers.
        assert result is request
        assert "X-Custom-Header" in result.headers
        assert result.headers["X-Custom-Header"] == "test-value"
        assert "X-Intercepted" in result.headers
        assert result.headers["X-Intercepted"] == "true"

    def test_inbound_header_modification(self) -> None:
        """Test that interceptor can add headers to inbound responses."""
        interceptor = ModifyingInterceptor()
        response = httpx.Response(200, json={"status": "ok"})

        result = interceptor.on_inbound(response)

        assert result is response
        assert "X-Response-Intercepted" in result.headers
        assert result.headers["X-Response-Intercepted"] == "true"

    def test_preserves_existing_headers(self) -> None:
        """Test that interceptor preserves existing headers."""
        interceptor = ModifyingInterceptor()
        request = httpx.Request(
            "GET",
            "https://api.example.com/test",
            headers={"Authorization": "Bearer token123", "Content-Type": "application/json"},
        )

        result = interceptor.on_outbound(request)

        assert result.headers["Authorization"] == "Bearer token123"
        assert result.headers["Content-Type"] == "application/json"
        assert result.headers["X-Custom-Header"] == "test-value"
|
||||
|
||||
|
||||
class TestAsyncInterceptor:
    """Test suite for async interceptor functionality."""

    def test_sync_methods_work(self) -> None:
        """Test that sync methods still work on async interceptor."""
        interceptor = AsyncInterceptor()
        request = httpx.Request("GET", "https://api.example.com/test")
        response = httpx.Response(200)

        req_result = interceptor.on_outbound(request)
        resp_result = interceptor.on_inbound(response)

        assert req_result is request
        assert resp_result is response

    @pytest.mark.asyncio
    async def test_async_outbound(self) -> None:
        """Test async outbound hook."""
        interceptor = AsyncInterceptor()
        request = httpx.Request("GET", "https://api.example.com/test")

        result = await interceptor.aon_outbound(request)

        assert result is request
        assert len(interceptor.async_outbound_calls) == 1
        assert interceptor.async_outbound_calls[0] is request
        assert "X-Async-Outbound" in result.headers
        assert result.headers["X-Async-Outbound"] == "true"

    @pytest.mark.asyncio
    async def test_async_inbound(self) -> None:
        """Test async inbound hook."""
        interceptor = AsyncInterceptor()
        response = httpx.Response(200, json={"status": "ok"})

        result = await interceptor.aon_inbound(response)

        assert result is response
        assert len(interceptor.async_inbound_calls) == 1
        assert interceptor.async_inbound_calls[0] is response
        assert "X-Async-Inbound" in result.headers
        assert result.headers["X-Async-Inbound"] == "true"

    @pytest.mark.asyncio
    async def test_default_async_not_implemented(self) -> None:
        """Test that default async methods raise NotImplementedError."""
        # SimpleInterceptor overrides only the sync hooks, so the base-class
        # async defaults must raise rather than silently pass through.
        interceptor = SimpleInterceptor()
        request = httpx.Request("GET", "https://api.example.com/test")
        response = httpx.Response(200)

        with pytest.raises(NotImplementedError):
            await interceptor.aon_outbound(request)

        with pytest.raises(NotImplementedError):
            await interceptor.aon_inbound(response)
|
||||
262
lib/crewai/tests/llms/hooks/test_openai_interceptor.py
Normal file
262
lib/crewai/tests/llms/hooks/test_openai_interceptor.py
Normal file
@@ -0,0 +1,262 @@
|
||||
"""Tests for OpenAI provider with interceptor integration."""
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
|
||||
|
||||
class OpenAITestInterceptor(BaseInterceptor[httpx.Request, httpx.Response]):
    """Test interceptor for OpenAI provider."""

    def __init__(self) -> None:
        """Initialize tracking and modification state."""
        # Every request/response seen by the hooks is recorded so tests can
        # inspect exactly what passed through the transport.
        self.outbound_calls: list[httpx.Request] = []
        self.inbound_calls: list[httpx.Response] = []
        # Mutable on purpose: tests override it to distinguish instances.
        self.custom_header_value = "openai-test-value"

    def on_outbound(self, message: httpx.Request) -> httpx.Request:
        """Track and modify outbound OpenAI requests.

        Args:
            message: The outbound request.

        Returns:
            Modified request with custom headers.
        """
        self.outbound_calls.append(message)
        message.headers["X-OpenAI-Interceptor"] = self.custom_header_value
        message.headers["X-Request-ID"] = "test-request-123"
        return message

    def on_inbound(self, message: httpx.Response) -> httpx.Response:
        """Track inbound OpenAI responses.

        Args:
            message: The inbound response.

        Returns:
            The response with tracking header.
        """
        self.inbound_calls.append(message)
        message.headers["X-Response-Tracked"] = "true"
        return message
|
||||
|
||||
|
||||
class TestOpenAIInterceptorIntegration:
    """Test suite for OpenAI provider with interceptor."""

    def test_openai_llm_accepts_interceptor(self) -> None:
        """Test that OpenAI LLM accepts interceptor parameter."""
        interceptor = OpenAITestInterceptor()

        llm = LLM(model="gpt-4", interceptor=interceptor)

        assert llm.interceptor is interceptor

    @pytest.mark.vcr(filter_headers=["authorization"])
    def test_openai_call_with_interceptor_tracks_requests(self) -> None:
        """Test that interceptor tracks OpenAI API requests."""
        interceptor = OpenAITestInterceptor()
        llm = LLM(model="gpt-4o-mini", interceptor=interceptor)

        # Make a simple completion call
        result = llm.call(
            messages=[{"role": "user", "content": "Say 'Hello World' and nothing else"}]
        )

        # The loops below would pass vacuously on empty lists, so first prove
        # the interceptor was actually invoked for the call.
        assert interceptor.outbound_calls
        assert interceptor.inbound_calls

        # Verify custom headers were added
        for request in interceptor.outbound_calls:
            assert "X-OpenAI-Interceptor" in request.headers
            assert request.headers["X-OpenAI-Interceptor"] == "openai-test-value"
            assert "X-Request-ID" in request.headers
            assert request.headers["X-Request-ID"] == "test-request-123"

        # Verify response was tracked
        for response in interceptor.inbound_calls:
            assert "X-Response-Tracked" in response.headers
            assert response.headers["X-Response-Tracked"] == "true"

        # Verify result is valid
        assert result is not None
        assert isinstance(result, str)
        assert len(result) > 0

    def test_openai_without_interceptor_works(self) -> None:
        """Test that OpenAI LLM works without interceptor."""
        llm = LLM(model="gpt-4")

        assert llm.interceptor is None

    def test_multiple_openai_llms_different_interceptors(self) -> None:
        """Test that multiple OpenAI LLMs can have different interceptors."""
        interceptor1 = OpenAITestInterceptor()
        interceptor1.custom_header_value = "llm1-value"

        interceptor2 = OpenAITestInterceptor()
        interceptor2.custom_header_value = "llm2-value"

        llm1 = LLM(model="gpt-4", interceptor=interceptor1)
        llm2 = LLM(model="gpt-3.5-turbo", interceptor=interceptor2)

        assert llm1.interceptor is interceptor1
        assert llm2.interceptor is interceptor2
        assert llm1.interceptor.custom_header_value == "llm1-value"
        assert llm2.interceptor.custom_header_value == "llm2-value"
|
||||
|
||||
|
||||
class LoggingInterceptor(BaseInterceptor[httpx.Request, httpx.Response]):
    """Interceptor that logs request/response details for testing."""

    def __init__(self) -> None:
        """Initialize logging lists."""
        # Parallel lists: entry i of each describes the i-th observed call.
        self.request_urls: list[str] = []
        self.request_methods: list[str] = []
        self.response_status_codes: list[int] = []

    def on_outbound(self, message: httpx.Request) -> httpx.Request:
        """Log outbound request details.

        Args:
            message: The outbound request.

        Returns:
            The request unchanged.
        """
        self.request_urls.append(str(message.url))
        self.request_methods.append(message.method)
        return message

    def on_inbound(self, message: httpx.Response) -> httpx.Response:
        """Log inbound response details.

        Args:
            message: The inbound response.

        Returns:
            The response unchanged.
        """
        self.response_status_codes.append(message.status_code)
        return message
|
||||
|
||||
|
||||
class TestOpenAILoggingInterceptor:
    """Test suite for logging interceptor with OpenAI."""

    def test_logging_interceptor_instantiation(self) -> None:
        """Test that logging interceptor can be created with OpenAI LLM."""
        interceptor = LoggingInterceptor()
        llm = LLM(model="gpt-4", interceptor=interceptor)

        assert llm.interceptor is interceptor
        assert isinstance(llm.interceptor, LoggingInterceptor)

    @pytest.mark.vcr(filter_headers=["authorization"])
    def test_logging_interceptor_tracks_details(self) -> None:
        """Test that logging interceptor tracks request/response details."""
        interceptor = LoggingInterceptor()
        llm = LLM(model="gpt-4o-mini", interceptor=interceptor)

        # Make a completion call
        result = llm.call(
            messages=[{"role": "user", "content": "Count from 1 to 3"}]
        )

        # Guard against vacuous loops below: the hooks must have fired.
        assert interceptor.request_urls
        assert interceptor.response_status_codes

        # Verify URL points to OpenAI API
        for url in interceptor.request_urls:
            assert "openai" in url.lower() or "api" in url.lower()

        # Verify methods are POST (chat completions use POST)
        for method in interceptor.request_methods:
            assert method == "POST"

        # Verify successful status codes
        for status_code in interceptor.response_status_codes:
            assert 200 <= status_code < 300

        # Verify result is valid
        assert result is not None
|
||||
|
||||
|
||||
class AuthInterceptor(BaseInterceptor[httpx.Request, httpx.Response]):
    """Interceptor that adds authentication headers."""

    def __init__(self, api_key: str, org_id: str) -> None:
        """Remember the credentials injected into each outbound request.

        Args:
            api_key: The API key to inject.
            org_id: The organization ID to inject.
        """
        self.api_key = api_key
        self.org_id = org_id

    def on_outbound(self, message: httpx.Request) -> httpx.Request:
        """Stamp the request with the stored credentials.

        Args:
            message: The outbound request.

        Returns:
            Request with auth headers.
        """
        message.headers.update(
            {
                "X-Custom-API-Key": self.api_key,
                "X-Organization-ID": self.org_id,
            }
        )
        return message

    def on_inbound(self, message: httpx.Response) -> httpx.Response:
        """Forward the response without modification.

        Args:
            message: The inbound response.

        Returns:
            The response unchanged.
        """
        return message
|
||||
|
||||
|
||||
class TestOpenAIAuthInterceptor:
    """Test suite for authentication interceptor with OpenAI."""

    def test_auth_interceptor_with_openai(self) -> None:
        """Test that auth interceptor can be used with OpenAI LLM."""
        interceptor = AuthInterceptor(api_key="custom-key-123", org_id="org-456")
        llm = LLM(model="gpt-4", interceptor=interceptor)

        assert llm.interceptor is interceptor
        assert llm.interceptor.api_key == "custom-key-123"
        assert llm.interceptor.org_id == "org-456"

    def test_auth_interceptor_adds_headers(self) -> None:
        """Test that auth interceptor adds custom headers to requests."""
        # Calls the hook directly on a hand-built request — no network needed.
        interceptor = AuthInterceptor(api_key="test-key", org_id="test-org")
        request = httpx.Request("POST", "https://api.openai.com/v1/chat/completions")

        modified_request = interceptor.on_outbound(request)

        assert "X-Custom-API-Key" in modified_request.headers
        assert modified_request.headers["X-Custom-API-Key"] == "test-key"
        assert "X-Organization-ID" in modified_request.headers
        assert modified_request.headers["X-Organization-ID"] == "test-org"

    @pytest.mark.vcr(filter_headers=["authorization"])
    def test_auth_interceptor_with_real_call(self) -> None:
        """Test that auth interceptor works with real OpenAI API call."""
        interceptor = AuthInterceptor(api_key="custom-123", org_id="org-789")
        llm = LLM(model="gpt-4o-mini", interceptor=interceptor)

        # Make a simple call
        result = llm.call(
            messages=[{"role": "user", "content": "Reply with just the word: SUCCESS"}]
        )

        # Verify the call succeeded
        assert result is not None
        assert len(result) > 0

        # Verify headers were added to outbound requests
        # (We can't directly inspect the request sent to OpenAI in this test,
        # but we verify the interceptor was configured and the call succeeded)
        assert llm.interceptor is interceptor
|
||||
248 lib/crewai/tests/llms/hooks/test_transport.py (new file)
@@ -0,0 +1,248 @@
|
||||
"""Tests for transport layer with interceptor integration."""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
from crewai.llms.hooks.transport import AsyncHTTPransport, HTTPTransport
|
||||
|
||||
|
||||
class TrackingInterceptor(BaseInterceptor[httpx.Request, httpx.Response]):
    """Interceptor that logs every hook invocation so tests can inspect it."""

    def __init__(self) -> None:
        """Create empty call logs for the sync and async hook pairs."""
        self.outbound_calls: list[httpx.Request] = []
        self.inbound_calls: list[httpx.Response] = []
        self.async_outbound_calls: list[httpx.Request] = []
        self.async_inbound_calls: list[httpx.Response] = []

    def on_outbound(self, message: httpx.Request) -> httpx.Request:
        """Record the sync outbound request and tag it with a header.

        Args:
            message: Request about to be sent.

        Returns:
            The same request carrying the sync tracking header.
        """
        message.headers["X-Intercepted-Sync"] = "true"
        self.outbound_calls.append(message)
        return message

    def on_inbound(self, message: httpx.Response) -> httpx.Response:
        """Record the sync inbound response and tag it with a header.

        Args:
            message: Response received from the server.

        Returns:
            The same response carrying the sync tracking header.
        """
        message.headers["X-Response-Intercepted-Sync"] = "true"
        self.inbound_calls.append(message)
        return message

    async def aon_outbound(self, message: httpx.Request) -> httpx.Request:
        """Record the async outbound request and tag it with a header.

        Args:
            message: Request about to be sent.

        Returns:
            The same request carrying the async tracking header.
        """
        message.headers["X-Intercepted-Async"] = "true"
        self.async_outbound_calls.append(message)
        return message

    async def aon_inbound(self, message: httpx.Response) -> httpx.Response:
        """Record the async inbound response and tag it with a header.

        Args:
            message: Response received from the server.

        Returns:
            The same response carrying the async tracking header.
        """
        message.headers["X-Response-Intercepted-Async"] = "true"
        self.async_inbound_calls.append(message)
        return message
|
||||
|
||||
|
||||
class TestHTTPTransport:
    """Test suite for sync HTTPTransport with interceptor."""

    def test_transport_instantiation(self) -> None:
        """Test that transport can be instantiated with interceptor."""
        interceptor = TrackingInterceptor()
        transport = HTTPTransport(interceptor=interceptor)

        assert transport.interceptor is interceptor

    def test_transport_requires_interceptor(self) -> None:
        """Test that transport requires interceptor parameter."""
        # HTTPTransport requires an interceptor parameter.
        with pytest.raises(TypeError):
            HTTPTransport()

    def test_interceptor_called_on_request(self) -> None:
        """Test that interceptor hooks are called during request handling."""
        from unittest.mock import patch

        interceptor = TrackingInterceptor()
        transport = HTTPTransport(interceptor=interceptor)

        mock_response = httpx.Response(200, json={"success": True})

        # patch.object restores httpx.HTTPTransport.handle_request even if an
        # assertion fails mid-block, unlike manual save/assign/restore.
        with patch.object(
            httpx.HTTPTransport, "handle_request", return_value=mock_response
        ):
            request = httpx.Request("GET", "https://api.example.com/test")
            response = transport.handle_request(request)

            # Verify interceptor was called exactly once per direction,
            # with the very objects that flowed through the transport.
            assert len(interceptor.outbound_calls) == 1
            assert len(interceptor.inbound_calls) == 1
            assert interceptor.outbound_calls[0] is request
            assert interceptor.inbound_calls[0] is response

            # Verify the tracking headers were added on both messages.
            assert "X-Intercepted-Sync" in request.headers
            assert request.headers["X-Intercepted-Sync"] == "true"
            assert "X-Response-Intercepted-Sync" in response.headers
            assert response.headers["X-Response-Intercepted-Sync"] == "true"
|
||||
|
||||
|
||||
|
||||
class TestAsyncHTTPTransport:
    """Test suite for async AsyncHTTPransport with interceptor."""

    def test_async_transport_instantiation(self) -> None:
        """Test that async transport can be instantiated with interceptor."""
        interceptor = TrackingInterceptor()
        transport = AsyncHTTPransport(interceptor=interceptor)

        assert transport.interceptor is interceptor

    def test_async_transport_requires_interceptor(self) -> None:
        """Test that async transport requires interceptor parameter."""
        # AsyncHTTPransport requires an interceptor parameter.
        with pytest.raises(TypeError):
            AsyncHTTPransport()

    @pytest.mark.asyncio
    async def test_async_interceptor_called_on_request(self) -> None:
        """Test that async interceptor hooks are called during request handling."""
        from unittest.mock import AsyncMock, patch

        interceptor = TrackingInterceptor()
        transport = AsyncHTTPransport(interceptor=interceptor)

        mock_response = httpx.Response(200, json={"success": True})

        # AsyncMock yields an awaitable directly (no hand-rolled coroutine
        # shim), and patch.object guarantees handle_async_request is restored
        # even if an assertion fails.
        with patch.object(
            httpx.AsyncHTTPTransport,
            "handle_async_request",
            AsyncMock(return_value=mock_response),
        ):
            request = httpx.Request("GET", "https://api.example.com/test")
            response = await transport.handle_async_request(request)

            # Verify the async hooks fired once with the exact objects.
            assert len(interceptor.async_outbound_calls) == 1
            assert len(interceptor.async_inbound_calls) == 1
            assert interceptor.async_outbound_calls[0] is request
            assert interceptor.async_inbound_calls[0] is response

            # Verify the sync hooks were NOT triggered on the async path.
            assert len(interceptor.outbound_calls) == 0
            assert len(interceptor.inbound_calls) == 0

            # Verify the async tracking headers were added on both messages.
            assert "X-Intercepted-Async" in request.headers
            assert request.headers["X-Intercepted-Async"] == "true"
            assert "X-Response-Intercepted-Async" in response.headers
            assert response.headers["X-Response-Intercepted-Async"] == "true"
|
||||
|
||||
|
||||
|
||||
class TestTransportIntegration:
    """Test suite for transport integration scenarios."""

    def test_multiple_requests_same_interceptor(self) -> None:
        """Test that multiple requests through same interceptor are tracked."""
        from unittest.mock import patch

        interceptor = TrackingInterceptor()
        transport = HTTPTransport(interceptor=interceptor)

        mock_response = httpx.Response(200)

        # patch.object restores the original handle_request automatically,
        # replacing the fragile manual save/assign/restore pattern.
        with patch.object(
            httpx.HTTPTransport, "handle_request", return_value=mock_response
        ):
            # Make multiple requests through the same transport.
            requests = [
                httpx.Request("GET", "https://api.example.com/1"),
                httpx.Request("POST", "https://api.example.com/2"),
                httpx.Request("PUT", "https://api.example.com/3"),
            ]
            for outgoing in requests:
                transport.handle_request(outgoing)

            # Every request was seen by both hooks, in order.
            assert len(interceptor.outbound_calls) == 3
            assert len(interceptor.inbound_calls) == 3
            assert interceptor.outbound_calls == requests

    @pytest.mark.asyncio
    async def test_multiple_async_requests_same_interceptor(self) -> None:
        """Test that multiple async requests through same interceptor are tracked."""
        from unittest.mock import AsyncMock, patch

        interceptor = TrackingInterceptor()
        transport = AsyncHTTPransport(interceptor=interceptor)

        mock_response = httpx.Response(200)

        # AsyncMock returns an awaitable per call; patch.object handles
        # restoration of handle_async_request on exit.
        with patch.object(
            httpx.AsyncHTTPTransport,
            "handle_async_request",
            AsyncMock(return_value=mock_response),
        ):
            # Make multiple async requests through the same transport.
            requests = [
                httpx.Request("GET", "https://api.example.com/1"),
                httpx.Request("POST", "https://api.example.com/2"),
                httpx.Request("DELETE", "https://api.example.com/3"),
            ]
            for outgoing in requests:
                await transport.handle_async_request(outgoing)

            # Every request was seen by both async hooks, in order.
            assert len(interceptor.async_outbound_calls) == 3
            assert len(interceptor.async_inbound_calls) == 3
            assert interceptor.async_outbound_calls == requests
|
||||
319 lib/crewai/tests/llms/hooks/test_unsupported_providers.py (new file)
@@ -0,0 +1,319 @@
|
||||
"""Tests for interceptor behavior with unsupported providers."""
|
||||
|
||||
import os
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def setup_provider_api_keys(monkeypatch):
    """Ensure dummy credentials exist for every provider these tests touch."""
    # Only fill in keys that the environment does not already provide, so a
    # developer's real credentials are never overridden.
    defaults = {
        "OPENAI_API_KEY": "sk-test-key-dummy",
        "ANTHROPIC_API_KEY": "sk-ant-test-key-dummy",
        "GOOGLE_API_KEY": "test-google-key-dummy",
    }
    for variable, dummy_value in defaults.items():
        if variable not in os.environ:
            monkeypatch.setenv(variable, dummy_value)
|
||||
|
||||
|
||||
class DummyInterceptor(BaseInterceptor[httpx.Request, httpx.Response]):
    """Minimal interceptor used to exercise provider integration paths."""

    def on_outbound(self, message: httpx.Request) -> httpx.Request:
        """Tag the outgoing request; otherwise pass it through.

        Args:
            message: Request about to be sent.

        Returns:
            The request with a marker header added.
        """
        message.headers["X-Dummy"] = "true"
        return message

    def on_inbound(self, message: httpx.Response) -> httpx.Response:
        """Return the inbound response untouched.

        Args:
            message: Response received from the provider.

        Returns:
            The response, exactly as received.
        """
        return message
|
||||
|
||||
|
||||
class TestAzureProviderInterceptor:
    """Test suite for Azure provider with interceptor (unsupported)."""

    def test_azure_llm_accepts_interceptor_parameter(self) -> None:
        """Test that Azure LLM raises NotImplementedError with interceptor."""
        hook = DummyInterceptor()

        # Passing an interceptor to the Azure provider is rejected outright.
        with pytest.raises(NotImplementedError) as exc_info:
            LLM(
                model="azure/gpt-4",
                interceptor=hook,
                api_key="test-key",
                endpoint="https://test.openai.azure.com/openai/deployments/gpt-4",
            )

        assert "interceptor" in str(exc_info.value).lower()

    def test_azure_raises_not_implemented_on_initialization(self) -> None:
        """Test that Azure raises NotImplementedError when interceptor is used."""
        hook = DummyInterceptor()

        with pytest.raises(NotImplementedError) as exc_info:
            LLM(
                model="azure/gpt-4",
                interceptor=hook,
                api_key="test-key",
                endpoint="https://test.openai.azure.com/openai/deployments/gpt-4",
            )

        lowered = str(exc_info.value).lower()
        assert "interceptor" in lowered
        assert "azure" in lowered

    def test_azure_without_interceptor_works(self) -> None:
        """Test that Azure LLM works without interceptor."""
        llm = LLM(
            model="azure/gpt-4",
            api_key="test-key",
            endpoint="https://test.openai.azure.com/openai/deployments/gpt-4",
        )

        # Azure instances either lack the attribute entirely or leave it unset.
        assert getattr(llm, "interceptor", None) is None
|
||||
|
||||
|
||||
class TestBedrockProviderInterceptor:
    """Test suite for Bedrock provider with interceptor (unsupported)."""

    def test_bedrock_llm_accepts_interceptor_parameter(self) -> None:
        """Test that Bedrock LLM raises NotImplementedError with interceptor."""
        hook = DummyInterceptor()

        # Passing an interceptor to the Bedrock provider is rejected outright.
        with pytest.raises(NotImplementedError) as exc_info:
            LLM(
                model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
                interceptor=hook,
                aws_access_key_id="test-access-key",
                aws_secret_access_key="test-secret-key",
                aws_region_name="us-east-1",
            )

        lowered = str(exc_info.value).lower()
        assert "interceptor" in lowered
        assert "bedrock" in lowered

    def test_bedrock_raises_not_implemented_on_initialization(self) -> None:
        """Test that Bedrock raises NotImplementedError when interceptor is used."""
        hook = DummyInterceptor()

        with pytest.raises(NotImplementedError) as exc_info:
            LLM(
                model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
                interceptor=hook,
                aws_access_key_id="test-access-key",
                aws_secret_access_key="test-secret-key",
                aws_region_name="us-east-1",
            )

        lowered = str(exc_info.value).lower()
        assert "interceptor" in lowered
        assert "bedrock" in lowered

    def test_bedrock_without_interceptor_works(self) -> None:
        """Test that Bedrock LLM works without interceptor."""
        llm = LLM(
            model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
            aws_access_key_id="test-access-key",
            aws_secret_access_key="test-secret-key",
            aws_region_name="us-east-1",
        )

        # Bedrock instances either lack the attribute entirely or leave it unset.
        assert getattr(llm, "interceptor", None) is None
|
||||
|
||||
|
||||
class TestGeminiProviderInterceptor:
    """Test suite for Gemini provider with interceptor (unsupported)."""

    def test_gemini_llm_accepts_interceptor_parameter(self) -> None:
        """Test that Gemini LLM raises NotImplementedError with interceptor."""
        hook = DummyInterceptor()

        # Passing an interceptor to the Gemini provider is rejected outright.
        with pytest.raises(NotImplementedError) as exc_info:
            LLM(
                model="gemini/gemini-pro",
                interceptor=hook,
                api_key="test-gemini-key",
            )

        lowered = str(exc_info.value).lower()
        assert "interceptor" in lowered
        assert "gemini" in lowered

    def test_gemini_raises_not_implemented_on_initialization(self) -> None:
        """Test that Gemini raises NotImplementedError when interceptor is used."""
        hook = DummyInterceptor()

        with pytest.raises(NotImplementedError) as exc_info:
            LLM(
                model="gemini/gemini-pro",
                interceptor=hook,
                api_key="test-gemini-key",
            )

        lowered = str(exc_info.value).lower()
        assert "interceptor" in lowered
        assert "gemini" in lowered

    def test_gemini_without_interceptor_works(self) -> None:
        """Test that Gemini LLM works without interceptor."""
        llm = LLM(
            model="gemini/gemini-pro",
            api_key="test-gemini-key",
        )

        # Gemini instances either lack the attribute entirely or leave it unset.
        assert getattr(llm, "interceptor", None) is None
|
||||
|
||||
|
||||
class TestUnsupportedProviderMessages:
    """Test suite for error messages from unsupported providers."""

    def test_azure_error_message_is_clear(self) -> None:
        """Test that Azure error message clearly states lack of support."""
        hook = DummyInterceptor()

        with pytest.raises(NotImplementedError) as exc_info:
            LLM(
                model="azure/gpt-4",
                interceptor=hook,
                api_key="test-key",
                endpoint="https://test.openai.azure.com/openai/deployments/gpt-4",
            )

        # The message must name both the provider and the unsupported feature.
        lowered = str(exc_info.value).lower()
        assert "azure" in lowered
        assert "interceptor" in lowered

    def test_bedrock_error_message_is_clear(self) -> None:
        """Test that Bedrock error message clearly states lack of support."""
        hook = DummyInterceptor()

        with pytest.raises(NotImplementedError) as exc_info:
            LLM(
                model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
                interceptor=hook,
                aws_access_key_id="test-access-key",
                aws_secret_access_key="test-secret-key",
                aws_region_name="us-east-1",
            )

        # The message must name both the provider and the unsupported feature.
        lowered = str(exc_info.value).lower()
        assert "bedrock" in lowered
        assert "interceptor" in lowered

    def test_gemini_error_message_is_clear(self) -> None:
        """Test that Gemini error message clearly states lack of support."""
        hook = DummyInterceptor()

        with pytest.raises(NotImplementedError) as exc_info:
            LLM(
                model="gemini/gemini-pro",
                interceptor=hook,
                api_key="test-gemini-key",
            )

        # The message must name both the provider and the unsupported feature.
        lowered = str(exc_info.value).lower()
        assert "gemini" in lowered
        assert "interceptor" in lowered
|
||||
|
||||
|
||||
class TestProviderSupportMatrix:
    """Test suite to document which providers support interceptors."""

    def test_supported_providers_accept_interceptor(self) -> None:
        """Test that supported providers accept and use interceptors."""
        hook = DummyInterceptor()

        # OpenAI - SUPPORTED: the instance retains the interceptor.
        openai_llm = LLM(model="gpt-4", interceptor=hook)
        assert openai_llm.interceptor is hook

        # Anthropic - SUPPORTED: the instance retains the interceptor.
        anthropic_llm = LLM(
            model="anthropic/claude-3-opus-20240229", interceptor=hook
        )
        assert anthropic_llm.interceptor is hook

    def test_unsupported_providers_raise_error(self) -> None:
        """Test that unsupported providers raise NotImplementedError."""
        hook = DummyInterceptor()

        # Azure, Bedrock, and Gemini all reject interceptors at construction.
        unsupported_configs = [
            {
                "model": "azure/gpt-4",
                "api_key": "test",
                "endpoint": "https://test.openai.azure.com/openai/deployments/gpt-4",
            },
            {
                "model": "bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
                "aws_access_key_id": "test",
                "aws_secret_access_key": "test",
                "aws_region_name": "us-east-1",
            },
            {
                "model": "gemini/gemini-pro",
                "api_key": "test",
            },
        ]
        for config in unsupported_configs:
            with pytest.raises(NotImplementedError):
                LLM(interceptor=hook, **config)

    def test_all_providers_work_without_interceptor(self) -> None:
        """Test that all providers work normally without interceptor."""
        # OpenAI exposes the attribute and defaults it to None.
        openai_llm = LLM(model="gpt-4")
        assert openai_llm.interceptor is None

        # Anthropic exposes the attribute and defaults it to None.
        anthropic_llm = LLM(model="anthropic/claude-3-opus-20240229")
        assert anthropic_llm.interceptor is None

        # Azure either lacks the attribute or leaves it unset.
        azure_llm = LLM(
            model="azure/gpt-4",
            api_key="test",
            endpoint="https://test.openai.azure.com/openai/deployments/gpt-4",
        )
        assert getattr(azure_llm, "interceptor", None) is None

        # Bedrock either lacks the attribute or leaves it unset.
        bedrock_llm = LLM(
            model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
            aws_access_key_id="test",
            aws_secret_access_key="test",
            aws_region_name="us-east-1",
        )
        assert getattr(bedrock_llm, "interceptor", None) is None

        # Gemini either lacks the attribute or leaves it unset.
        gemini_llm = LLM(model="gemini/gemini-pro", api_key="test")
        assert getattr(gemini_llm, "interceptor", None) is None
|
||||
Reference in New Issue
Block a user