feat: add pydantic validation dunder to BaseInterceptor
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled

This commit is contained in:
Greyson LaLonde
2025-11-06 15:27:07 -05:00
committed by GitHub
parent fc521839e4
commit 9e5906c52f
8 changed files with 1761 additions and 2226 deletions

View File

@@ -7,7 +7,14 @@ outbound and inbound messages at the transport level.
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Generic, TypeVar
from typing import TYPE_CHECKING, Any, Generic, TypeVar
from pydantic_core import core_schema
if TYPE_CHECKING:
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema
T = TypeVar("T")
@@ -25,6 +32,7 @@ class BaseInterceptor(ABC, Generic[T, U]):
U: Inbound message type (e.g., httpx.Response)
Example:
>>> import httpx
>>> class CustomInterceptor(BaseInterceptor[httpx.Request, httpx.Response]):
... def on_outbound(self, message: httpx.Request) -> httpx.Request:
... message.headers["X-Custom-Header"] = "value"
@@ -80,3 +88,46 @@ class BaseInterceptor(ABC, Generic[T, U]):
Modified message object.
"""
raise NotImplementedError
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: GetCoreSchemaHandler
) -> CoreSchema:
"""Generate Pydantic core schema for BaseInterceptor.
This allows the generic BaseInterceptor to be used in Pydantic models
without requiring arbitrary_types_allowed=True. The schema validates
that the value is an instance of BaseInterceptor.
Args:
_source_type: The source type being validated (unused).
_handler: Handler for generating schemas (unused).
Returns:
A Pydantic core schema that validates BaseInterceptor instances.
"""
return core_schema.no_info_plain_validator_function(
_validate_interceptor,
serialization=core_schema.plain_serializer_function_ser_schema(
lambda x: x, return_schema=core_schema.any_schema()
),
)
def _validate_interceptor(value: Any) -> BaseInterceptor[T, U]:
"""Validate that the value is a BaseInterceptor instance.
Args:
value: The value to validate.
Returns:
The validated BaseInterceptor instance.
Raises:
ValueError: If the value is not a BaseInterceptor instance.
"""
if not isinstance(value, BaseInterceptor):
raise ValueError(
f"Expected BaseInterceptor instance, got {type(value).__name__}"
)
return value

View File

@@ -6,16 +6,53 @@ to enable request/response modification at the transport level.
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from collections.abc import Iterable
from typing import TYPE_CHECKING, TypedDict
import httpx
from httpx import (
AsyncHTTPTransport as _AsyncHTTPTransport,
HTTPTransport as _HTTPTransport,
)
from typing_extensions import NotRequired, Unpack
if TYPE_CHECKING:
from ssl import SSLContext
from httpx import Limits, Request, Response
from httpx._types import CertTypes, ProxyTypes
from crewai.llms.hooks.base import BaseInterceptor
class HTTPTransport(httpx.HTTPTransport):
class HTTPTransportKwargs(TypedDict):
"""Typed dictionary for httpx.HTTPTransport initialization parameters.
These parameters configure the underlying HTTP transport behavior including
SSL verification, proxies, connection limits, and low-level socket options.
"""
verify: bool | str | SSLContext
cert: NotRequired[CertTypes | None]
trust_env: bool
http1: bool
http2: bool
limits: Limits
proxy: NotRequired[ProxyTypes | None]
uds: NotRequired[str | None]
local_address: NotRequired[str | None]
retries: int
socket_options: NotRequired[
Iterable[
tuple[int, int, int]
| tuple[int, int, bytes | bytearray]
| tuple[int, int, None, int]
]
| None
]
class HTTPTransport(_HTTPTransport):
"""HTTP transport that uses an interceptor for request/response modification.
This transport is used internally when a user provides a BaseInterceptor.
@@ -25,19 +62,19 @@ class HTTPTransport(httpx.HTTPTransport):
def __init__(
self,
interceptor: BaseInterceptor[httpx.Request, httpx.Response],
**kwargs: Any,
interceptor: BaseInterceptor[Request, Response],
**kwargs: Unpack[HTTPTransportKwargs],
) -> None:
"""Initialize transport with interceptor.
Args:
interceptor: HTTP interceptor for modifying raw request/response objects.
**kwargs: Additional arguments passed to httpx.HTTPTransport.
**kwargs: HTTPTransport configuration parameters (verify, cert, proxy, etc.).
"""
super().__init__(**kwargs)
self.interceptor = interceptor
def handle_request(self, request: httpx.Request) -> httpx.Response:
def handle_request(self, request: Request) -> Response:
"""Handle request with interception.
Args:
@@ -51,7 +88,7 @@ class HTTPTransport(httpx.HTTPTransport):
return self.interceptor.on_inbound(response)
class AsyncHTTPransport(httpx.AsyncHTTPTransport):
class AsyncHTTPTransport(_AsyncHTTPTransport):
"""Async HTTP transport that uses an interceptor for request/response modification.
This transport is used internally when a user provides a BaseInterceptor.
@@ -61,19 +98,19 @@ class AsyncHTTPransport(httpx.AsyncHTTPTransport):
def __init__(
self,
interceptor: BaseInterceptor[httpx.Request, httpx.Response],
**kwargs: Any,
interceptor: BaseInterceptor[Request, Response],
**kwargs: Unpack[HTTPTransportKwargs],
) -> None:
"""Initialize async transport with interceptor.
Args:
interceptor: HTTP interceptor for modifying raw request/response objects.
**kwargs: Additional arguments passed to httpx.AsyncHTTPTransport.
**kwargs: HTTPTransport configuration parameters (verify, cert, proxy, etc.).
"""
super().__init__(**kwargs)
self.interceptor = interceptor
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
async def handle_async_request(self, request: Request) -> Response:
"""Handle async request with interception.
Args: