mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
feat: add pydantic validation dunder to BaseInterceptor, improve HTTPTransport typing
This commit is contained in:
@@ -7,7 +7,14 @@ outbound and inbound messages at the transport level.
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Generic, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar
|
||||
|
||||
from pydantic_core import core_schema
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic_core import CoreSchema
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
@@ -25,6 +32,7 @@ class BaseInterceptor(ABC, Generic[T, U]):
|
||||
U: Inbound message type (e.g., httpx.Response)
|
||||
|
||||
Example:
|
||||
>>> import httpx
|
||||
>>> class CustomInterceptor(BaseInterceptor[httpx.Request, httpx.Response]):
|
||||
... def on_outbound(self, message: httpx.Request) -> httpx.Request:
|
||||
... message.headers["X-Custom-Header"] = "value"
|
||||
@@ -80,3 +88,46 @@ class BaseInterceptor(ABC, Generic[T, U]):
|
||||
Modified message object.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, _source_type: Any, _handler: GetCoreSchemaHandler
|
||||
) -> CoreSchema:
|
||||
"""Generate Pydantic core schema for BaseInterceptor.
|
||||
|
||||
This allows the generic BaseInterceptor to be used in Pydantic models
|
||||
without requiring arbitrary_types_allowed=True. The schema validates
|
||||
that the value is an instance of BaseInterceptor.
|
||||
|
||||
Args:
|
||||
_source_type: The source type being validated (unused).
|
||||
_handler: Handler for generating schemas (unused).
|
||||
|
||||
Returns:
|
||||
A Pydantic core schema that validates BaseInterceptor instances.
|
||||
"""
|
||||
return core_schema.no_info_plain_validator_function(
|
||||
_validate_interceptor,
|
||||
serialization=core_schema.plain_serializer_function_ser_schema(
|
||||
lambda x: x, return_schema=core_schema.any_schema()
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _validate_interceptor(value: Any) -> BaseInterceptor[T, U]:
|
||||
"""Validate that the value is a BaseInterceptor instance.
|
||||
|
||||
Args:
|
||||
value: The value to validate.
|
||||
|
||||
Returns:
|
||||
The validated BaseInterceptor instance.
|
||||
|
||||
Raises:
|
||||
ValueError: If the value is not a BaseInterceptor instance.
|
||||
"""
|
||||
if not isinstance(value, BaseInterceptor):
|
||||
raise ValueError(
|
||||
f"Expected BaseInterceptor instance, got {type(value).__name__}"
|
||||
)
|
||||
return value
|
||||
|
||||
@@ -6,16 +6,53 @@ to enable request/response modification at the transport level.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING, TypedDict
|
||||
|
||||
import httpx
|
||||
from httpx import (
|
||||
AsyncHTTPTransport as _AsyncHTTPTransport,
|
||||
HTTPTransport as _HTTPTransport,
|
||||
)
|
||||
from typing_extensions import NotRequired, Unpack
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ssl import SSLContext
|
||||
|
||||
from httpx import Limits, Request, Response
|
||||
from httpx._types import CertTypes, ProxyTypes
|
||||
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
|
||||
|
||||
class HTTPTransport(httpx.HTTPTransport):
|
||||
class HTTPTransportKwargs(TypedDict):
|
||||
"""Typed dictionary for httpx.HTTPTransport initialization parameters.
|
||||
|
||||
These parameters configure the underlying HTTP transport behavior including
|
||||
SSL verification, proxies, connection limits, and low-level socket options.
|
||||
"""
|
||||
|
||||
verify: bool | str | SSLContext
|
||||
cert: NotRequired[CertTypes | None]
|
||||
trust_env: bool
|
||||
http1: bool
|
||||
http2: bool
|
||||
limits: Limits
|
||||
proxy: NotRequired[ProxyTypes | None]
|
||||
uds: NotRequired[str | None]
|
||||
local_address: NotRequired[str | None]
|
||||
retries: int
|
||||
socket_options: NotRequired[
|
||||
Iterable[
|
||||
tuple[int, int, int]
|
||||
| tuple[int, int, bytes | bytearray]
|
||||
| tuple[int, int, None, int]
|
||||
]
|
||||
| None
|
||||
]
|
||||
|
||||
|
||||
class HTTPTransport(_HTTPTransport):
|
||||
"""HTTP transport that uses an interceptor for request/response modification.
|
||||
|
||||
This transport is used internally when a user provides a BaseInterceptor.
|
||||
@@ -25,19 +62,19 @@ class HTTPTransport(httpx.HTTPTransport):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
interceptor: BaseInterceptor[httpx.Request, httpx.Response],
|
||||
**kwargs: Any,
|
||||
interceptor: BaseInterceptor[Request, Response],
|
||||
**kwargs: Unpack[HTTPTransportKwargs],
|
||||
) -> None:
|
||||
"""Initialize transport with interceptor.
|
||||
|
||||
Args:
|
||||
interceptor: HTTP interceptor for modifying raw request/response objects.
|
||||
**kwargs: Additional arguments passed to httpx.HTTPTransport.
|
||||
**kwargs: HTTPTransport configuration parameters (verify, cert, proxy, etc.).
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self.interceptor = interceptor
|
||||
|
||||
def handle_request(self, request: httpx.Request) -> httpx.Response:
|
||||
def handle_request(self, request: Request) -> Response:
|
||||
"""Handle request with interception.
|
||||
|
||||
Args:
|
||||
@@ -51,7 +88,7 @@ class HTTPTransport(httpx.HTTPTransport):
|
||||
return self.interceptor.on_inbound(response)
|
||||
|
||||
|
||||
class AsyncHTTPransport(httpx.AsyncHTTPTransport):
|
||||
class AsyncHTTPTransport(_AsyncHTTPTransport):
|
||||
"""Async HTTP transport that uses an interceptor for request/response modification.
|
||||
|
||||
This transport is used internally when a user provides a BaseInterceptor.
|
||||
@@ -61,19 +98,19 @@ class AsyncHTTPransport(httpx.AsyncHTTPTransport):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
interceptor: BaseInterceptor[httpx.Request, httpx.Response],
|
||||
**kwargs: Any,
|
||||
interceptor: BaseInterceptor[Request, Response],
|
||||
**kwargs: Unpack[HTTPTransportKwargs],
|
||||
) -> None:
|
||||
"""Initialize async transport with interceptor.
|
||||
|
||||
Args:
|
||||
interceptor: HTTP interceptor for modifying raw request/response objects.
|
||||
**kwargs: Additional arguments passed to httpx.AsyncHTTPTransport.
|
||||
**kwargs: HTTPTransport configuration parameters (verify, cert, proxy, etc.).
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self.interceptor = interceptor
|
||||
|
||||
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
|
||||
async def handle_async_request(self, request: Request) -> Response:
|
||||
"""Handle async request with interception.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -6,7 +6,7 @@ import httpx
|
||||
import pytest
|
||||
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
from crewai.llms.hooks.transport import AsyncHTTPransport, HTTPTransport
|
||||
from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport
|
||||
|
||||
|
||||
class TrackingInterceptor(BaseInterceptor[httpx.Request, httpx.Response]):
|
||||
@@ -128,7 +128,7 @@ class TestAsyncHTTPTransport:
|
||||
def test_async_transport_instantiation(self) -> None:
|
||||
"""Test that async transport can be instantiated with interceptor."""
|
||||
interceptor = TrackingInterceptor()
|
||||
transport = AsyncHTTPransport(interceptor=interceptor)
|
||||
transport = AsyncHTTPTransport(interceptor=interceptor)
|
||||
|
||||
assert transport.interceptor is interceptor
|
||||
|
||||
@@ -136,13 +136,13 @@ class TestAsyncHTTPTransport:
|
||||
"""Test that async transport requires interceptor parameter."""
|
||||
# AsyncHTTPransport requires an interceptor parameter
|
||||
with pytest.raises(TypeError):
|
||||
AsyncHTTPransport()
|
||||
AsyncHTTPTransport()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_interceptor_called_on_request(self) -> None:
|
||||
"""Test that async interceptor hooks are called during request handling."""
|
||||
interceptor = TrackingInterceptor()
|
||||
transport = AsyncHTTPransport(interceptor=interceptor)
|
||||
transport = AsyncHTTPTransport(interceptor=interceptor)
|
||||
|
||||
# Create a mock parent transport that returns a response
|
||||
mock_response = httpx.Response(200, json={"success": True})
|
||||
@@ -217,7 +217,7 @@ class TestTransportIntegration:
|
||||
async def test_multiple_async_requests_same_interceptor(self) -> None:
|
||||
"""Test that multiple async requests through same interceptor are tracked."""
|
||||
interceptor = TrackingInterceptor()
|
||||
transport = AsyncHTTPransport(interceptor=interceptor)
|
||||
transport = AsyncHTTPTransport(interceptor=interceptor)
|
||||
|
||||
mock_response = httpx.Response(200)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user