feat: add pydantic validation dunder to BaseInterceptor, improve HTTPTransport typing

This commit is contained in:
Greyson Lalonde
2025-11-06 10:42:31 -05:00
parent e4cc9a664c
commit 965aa48ea1
3 changed files with 106 additions and 18 deletions

View File

@@ -7,7 +7,14 @@ outbound and inbound messages at the transport level.
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Generic, TypeVar
from typing import TYPE_CHECKING, Any, Generic, TypeVar
from pydantic_core import core_schema
if TYPE_CHECKING:
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema
T = TypeVar("T")
@@ -25,6 +32,7 @@ class BaseInterceptor(ABC, Generic[T, U]):
U: Inbound message type (e.g., httpx.Response)
Example:
>>> import httpx
>>> class CustomInterceptor(BaseInterceptor[httpx.Request, httpx.Response]):
... def on_outbound(self, message: httpx.Request) -> httpx.Request:
... message.headers["X-Custom-Header"] = "value"
@@ -80,3 +88,46 @@ class BaseInterceptor(ABC, Generic[T, U]):
Modified message object.
"""
raise NotImplementedError
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: GetCoreSchemaHandler
) -> CoreSchema:
"""Generate Pydantic core schema for BaseInterceptor.
This allows the generic BaseInterceptor to be used in Pydantic models
without requiring arbitrary_types_allowed=True. The schema validates
that the value is an instance of BaseInterceptor.
Args:
_source_type: The source type being validated (unused).
_handler: Handler for generating schemas (unused).
Returns:
A Pydantic core schema that validates BaseInterceptor instances.
"""
return core_schema.no_info_plain_validator_function(
_validate_interceptor,
serialization=core_schema.plain_serializer_function_ser_schema(
lambda x: x, return_schema=core_schema.any_schema()
),
)
def _validate_interceptor(value: Any) -> BaseInterceptor[T, U]:
"""Validate that the value is a BaseInterceptor instance.
Args:
value: The value to validate.
Returns:
The validated BaseInterceptor instance.
Raises:
ValueError: If the value is not a BaseInterceptor instance.
"""
if not isinstance(value, BaseInterceptor):
raise ValueError(
f"Expected BaseInterceptor instance, got {type(value).__name__}"
)
return value

View File

@@ -6,16 +6,53 @@ to enable request/response modification at the transport level.
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from collections.abc import Iterable
from typing import TYPE_CHECKING, TypedDict
import httpx
from httpx import (
AsyncHTTPTransport as _AsyncHTTPTransport,
HTTPTransport as _HTTPTransport,
)
from typing_extensions import NotRequired, Unpack
if TYPE_CHECKING:
from ssl import SSLContext
from httpx import Limits, Request, Response
from httpx._types import CertTypes, ProxyTypes
from crewai.llms.hooks.base import BaseInterceptor
class HTTPTransport(httpx.HTTPTransport):
class HTTPTransportKwargs(TypedDict):
"""Typed dictionary for httpx.HTTPTransport initialization parameters.
These parameters configure the underlying HTTP transport behavior including
SSL verification, proxies, connection limits, and low-level socket options.
"""
verify: bool | str | SSLContext
cert: NotRequired[CertTypes | None]
trust_env: bool
http1: bool
http2: bool
limits: Limits
proxy: NotRequired[ProxyTypes | None]
uds: NotRequired[str | None]
local_address: NotRequired[str | None]
retries: int
socket_options: NotRequired[
Iterable[
tuple[int, int, int]
| tuple[int, int, bytes | bytearray]
| tuple[int, int, None, int]
]
| None
]
class HTTPTransport(_HTTPTransport):
"""HTTP transport that uses an interceptor for request/response modification.
This transport is used internally when a user provides a BaseInterceptor.
@@ -25,19 +62,19 @@ class HTTPTransport(httpx.HTTPTransport):
def __init__(
self,
interceptor: BaseInterceptor[httpx.Request, httpx.Response],
**kwargs: Any,
interceptor: BaseInterceptor[Request, Response],
**kwargs: Unpack[HTTPTransportKwargs],
) -> None:
"""Initialize transport with interceptor.
Args:
interceptor: HTTP interceptor for modifying raw request/response objects.
**kwargs: Additional arguments passed to httpx.HTTPTransport.
**kwargs: HTTPTransport configuration parameters (verify, cert, proxy, etc.).
"""
super().__init__(**kwargs)
self.interceptor = interceptor
def handle_request(self, request: httpx.Request) -> httpx.Response:
def handle_request(self, request: Request) -> Response:
"""Handle request with interception.
Args:
@@ -51,7 +88,7 @@ class HTTPTransport(httpx.HTTPTransport):
return self.interceptor.on_inbound(response)
class AsyncHTTPransport(httpx.AsyncHTTPTransport):
class AsyncHTTPTransport(_AsyncHTTPTransport):
"""Async HTTP transport that uses an interceptor for request/response modification.
This transport is used internally when a user provides a BaseInterceptor.
@@ -61,19 +98,19 @@ class AsyncHTTPransport(httpx.AsyncHTTPTransport):
def __init__(
self,
interceptor: BaseInterceptor[httpx.Request, httpx.Response],
**kwargs: Any,
interceptor: BaseInterceptor[Request, Response],
**kwargs: Unpack[HTTPTransportKwargs],
) -> None:
"""Initialize async transport with interceptor.
Args:
interceptor: HTTP interceptor for modifying raw request/response objects.
**kwargs: Additional arguments passed to httpx.AsyncHTTPTransport.
**kwargs: HTTPTransport configuration parameters (verify, cert, proxy, etc.).
"""
super().__init__(**kwargs)
self.interceptor = interceptor
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
async def handle_async_request(self, request: Request) -> Response:
"""Handle async request with interception.
Args:

View File

@@ -6,7 +6,7 @@ import httpx
import pytest
from crewai.llms.hooks.base import BaseInterceptor
from crewai.llms.hooks.transport import AsyncHTTPransport, HTTPTransport
from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport
class TrackingInterceptor(BaseInterceptor[httpx.Request, httpx.Response]):
@@ -128,7 +128,7 @@ class TestAsyncHTTPTransport:
def test_async_transport_instantiation(self) -> None:
"""Test that async transport can be instantiated with interceptor."""
interceptor = TrackingInterceptor()
transport = AsyncHTTPransport(interceptor=interceptor)
transport = AsyncHTTPTransport(interceptor=interceptor)
assert transport.interceptor is interceptor
@@ -136,13 +136,13 @@ class TestAsyncHTTPTransport:
"""Test that async transport requires interceptor parameter."""
# AsyncHTTPransport requires an interceptor parameter
with pytest.raises(TypeError):
AsyncHTTPransport()
AsyncHTTPTransport()
@pytest.mark.asyncio
async def test_async_interceptor_called_on_request(self) -> None:
"""Test that async interceptor hooks are called during request handling."""
interceptor = TrackingInterceptor()
transport = AsyncHTTPransport(interceptor=interceptor)
transport = AsyncHTTPTransport(interceptor=interceptor)
# Create a mock parent transport that returns a response
mock_response = httpx.Response(200, json={"success": True})
@@ -217,7 +217,7 @@ class TestTransportIntegration:
async def test_multiple_async_requests_same_interceptor(self) -> None:
"""Test that multiple async requests through same interceptor are tracked."""
interceptor = TrackingInterceptor()
transport = AsyncHTTPransport(interceptor=interceptor)
transport = AsyncHTTPTransport(interceptor=interceptor)
mock_response = httpx.Response(200)