mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 00:58:30 +00:00
feat: add pydantic validation dunder to BaseInterceptor, improve HTTPTransport typing
This commit is contained in:
@@ -7,7 +7,14 @@ outbound and inbound messages at the transport level.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Generic, TypeVar
|
from typing import TYPE_CHECKING, Any, Generic, TypeVar
|
||||||
|
|
||||||
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from pydantic import GetCoreSchemaHandler
|
||||||
|
from pydantic_core import CoreSchema
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
@@ -25,6 +32,7 @@ class BaseInterceptor(ABC, Generic[T, U]):
|
|||||||
U: Inbound message type (e.g., httpx.Response)
|
U: Inbound message type (e.g., httpx.Response)
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
>>> import httpx
|
||||||
>>> class CustomInterceptor(BaseInterceptor[httpx.Request, httpx.Response]):
|
>>> class CustomInterceptor(BaseInterceptor[httpx.Request, httpx.Response]):
|
||||||
... def on_outbound(self, message: httpx.Request) -> httpx.Request:
|
... def on_outbound(self, message: httpx.Request) -> httpx.Request:
|
||||||
... message.headers["X-Custom-Header"] = "value"
|
... message.headers["X-Custom-Header"] = "value"
|
||||||
@@ -80,3 +88,46 @@ class BaseInterceptor(ABC, Generic[T, U]):
|
|||||||
Modified message object.
|
Modified message object.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_core_schema__(
|
||||||
|
cls, _source_type: Any, _handler: GetCoreSchemaHandler
|
||||||
|
) -> CoreSchema:
|
||||||
|
"""Generate Pydantic core schema for BaseInterceptor.
|
||||||
|
|
||||||
|
This allows the generic BaseInterceptor to be used in Pydantic models
|
||||||
|
without requiring arbitrary_types_allowed=True. The schema validates
|
||||||
|
that the value is an instance of BaseInterceptor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
_source_type: The source type being validated (unused).
|
||||||
|
_handler: Handler for generating schemas (unused).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Pydantic core schema that validates BaseInterceptor instances.
|
||||||
|
"""
|
||||||
|
return core_schema.no_info_plain_validator_function(
|
||||||
|
_validate_interceptor,
|
||||||
|
serialization=core_schema.plain_serializer_function_ser_schema(
|
||||||
|
lambda x: x, return_schema=core_schema.any_schema()
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_interceptor(value: Any) -> BaseInterceptor[T, U]:
|
||||||
|
"""Validate that the value is a BaseInterceptor instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: The value to validate.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The validated BaseInterceptor instance.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the value is not a BaseInterceptor instance.
|
||||||
|
"""
|
||||||
|
if not isinstance(value, BaseInterceptor):
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected BaseInterceptor instance, got {type(value).__name__}"
|
||||||
|
)
|
||||||
|
return value
|
||||||
|
|||||||
@@ -6,16 +6,53 @@ to enable request/response modification at the transport level.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Any
|
from collections.abc import Iterable
|
||||||
|
from typing import TYPE_CHECKING, TypedDict
|
||||||
|
|
||||||
import httpx
|
from httpx import (
|
||||||
|
AsyncHTTPTransport as _AsyncHTTPTransport,
|
||||||
|
HTTPTransport as _HTTPTransport,
|
||||||
|
)
|
||||||
|
from typing_extensions import NotRequired, Unpack
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from ssl import SSLContext
|
||||||
|
|
||||||
|
from httpx import Limits, Request, Response
|
||||||
|
from httpx._types import CertTypes, ProxyTypes
|
||||||
|
|
||||||
from crewai.llms.hooks.base import BaseInterceptor
|
from crewai.llms.hooks.base import BaseInterceptor
|
||||||
|
|
||||||
|
|
||||||
class HTTPTransport(httpx.HTTPTransport):
|
class HTTPTransportKwargs(TypedDict):
|
||||||
|
"""Typed dictionary for httpx.HTTPTransport initialization parameters.
|
||||||
|
|
||||||
|
These parameters configure the underlying HTTP transport behavior including
|
||||||
|
SSL verification, proxies, connection limits, and low-level socket options.
|
||||||
|
"""
|
||||||
|
|
||||||
|
verify: bool | str | SSLContext
|
||||||
|
cert: NotRequired[CertTypes | None]
|
||||||
|
trust_env: bool
|
||||||
|
http1: bool
|
||||||
|
http2: bool
|
||||||
|
limits: Limits
|
||||||
|
proxy: NotRequired[ProxyTypes | None]
|
||||||
|
uds: NotRequired[str | None]
|
||||||
|
local_address: NotRequired[str | None]
|
||||||
|
retries: int
|
||||||
|
socket_options: NotRequired[
|
||||||
|
Iterable[
|
||||||
|
tuple[int, int, int]
|
||||||
|
| tuple[int, int, bytes | bytearray]
|
||||||
|
| tuple[int, int, None, int]
|
||||||
|
]
|
||||||
|
| None
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class HTTPTransport(_HTTPTransport):
|
||||||
"""HTTP transport that uses an interceptor for request/response modification.
|
"""HTTP transport that uses an interceptor for request/response modification.
|
||||||
|
|
||||||
This transport is used internally when a user provides a BaseInterceptor.
|
This transport is used internally when a user provides a BaseInterceptor.
|
||||||
@@ -25,19 +62,19 @@ class HTTPTransport(httpx.HTTPTransport):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
interceptor: BaseInterceptor[httpx.Request, httpx.Response],
|
interceptor: BaseInterceptor[Request, Response],
|
||||||
**kwargs: Any,
|
**kwargs: Unpack[HTTPTransportKwargs],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize transport with interceptor.
|
"""Initialize transport with interceptor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
interceptor: HTTP interceptor for modifying raw request/response objects.
|
interceptor: HTTP interceptor for modifying raw request/response objects.
|
||||||
**kwargs: Additional arguments passed to httpx.HTTPTransport.
|
**kwargs: HTTPTransport configuration parameters (verify, cert, proxy, etc.).
|
||||||
"""
|
"""
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.interceptor = interceptor
|
self.interceptor = interceptor
|
||||||
|
|
||||||
def handle_request(self, request: httpx.Request) -> httpx.Response:
|
def handle_request(self, request: Request) -> Response:
|
||||||
"""Handle request with interception.
|
"""Handle request with interception.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -51,7 +88,7 @@ class HTTPTransport(httpx.HTTPTransport):
|
|||||||
return self.interceptor.on_inbound(response)
|
return self.interceptor.on_inbound(response)
|
||||||
|
|
||||||
|
|
||||||
class AsyncHTTPransport(httpx.AsyncHTTPTransport):
|
class AsyncHTTPTransport(_AsyncHTTPTransport):
|
||||||
"""Async HTTP transport that uses an interceptor for request/response modification.
|
"""Async HTTP transport that uses an interceptor for request/response modification.
|
||||||
|
|
||||||
This transport is used internally when a user provides a BaseInterceptor.
|
This transport is used internally when a user provides a BaseInterceptor.
|
||||||
@@ -61,19 +98,19 @@ class AsyncHTTPransport(httpx.AsyncHTTPTransport):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
interceptor: BaseInterceptor[httpx.Request, httpx.Response],
|
interceptor: BaseInterceptor[Request, Response],
|
||||||
**kwargs: Any,
|
**kwargs: Unpack[HTTPTransportKwargs],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize async transport with interceptor.
|
"""Initialize async transport with interceptor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
interceptor: HTTP interceptor for modifying raw request/response objects.
|
interceptor: HTTP interceptor for modifying raw request/response objects.
|
||||||
**kwargs: Additional arguments passed to httpx.AsyncHTTPTransport.
|
**kwargs: HTTPTransport configuration parameters (verify, cert, proxy, etc.).
|
||||||
"""
|
"""
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.interceptor = interceptor
|
self.interceptor = interceptor
|
||||||
|
|
||||||
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
|
async def handle_async_request(self, request: Request) -> Response:
|
||||||
"""Handle async request with interception.
|
"""Handle async request with interception.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import httpx
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from crewai.llms.hooks.base import BaseInterceptor
|
from crewai.llms.hooks.base import BaseInterceptor
|
||||||
from crewai.llms.hooks.transport import AsyncHTTPransport, HTTPTransport
|
from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport
|
||||||
|
|
||||||
|
|
||||||
class TrackingInterceptor(BaseInterceptor[httpx.Request, httpx.Response]):
|
class TrackingInterceptor(BaseInterceptor[httpx.Request, httpx.Response]):
|
||||||
@@ -128,7 +128,7 @@ class TestAsyncHTTPTransport:
|
|||||||
def test_async_transport_instantiation(self) -> None:
|
def test_async_transport_instantiation(self) -> None:
|
||||||
"""Test that async transport can be instantiated with interceptor."""
|
"""Test that async transport can be instantiated with interceptor."""
|
||||||
interceptor = TrackingInterceptor()
|
interceptor = TrackingInterceptor()
|
||||||
transport = AsyncHTTPransport(interceptor=interceptor)
|
transport = AsyncHTTPTransport(interceptor=interceptor)
|
||||||
|
|
||||||
assert transport.interceptor is interceptor
|
assert transport.interceptor is interceptor
|
||||||
|
|
||||||
@@ -136,13 +136,13 @@ class TestAsyncHTTPTransport:
|
|||||||
"""Test that async transport requires interceptor parameter."""
|
"""Test that async transport requires interceptor parameter."""
|
||||||
# AsyncHTTPransport requires an interceptor parameter
|
# AsyncHTTPransport requires an interceptor parameter
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
AsyncHTTPransport()
|
AsyncHTTPTransport()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_async_interceptor_called_on_request(self) -> None:
|
async def test_async_interceptor_called_on_request(self) -> None:
|
||||||
"""Test that async interceptor hooks are called during request handling."""
|
"""Test that async interceptor hooks are called during request handling."""
|
||||||
interceptor = TrackingInterceptor()
|
interceptor = TrackingInterceptor()
|
||||||
transport = AsyncHTTPransport(interceptor=interceptor)
|
transport = AsyncHTTPTransport(interceptor=interceptor)
|
||||||
|
|
||||||
# Create a mock parent transport that returns a response
|
# Create a mock parent transport that returns a response
|
||||||
mock_response = httpx.Response(200, json={"success": True})
|
mock_response = httpx.Response(200, json={"success": True})
|
||||||
@@ -217,7 +217,7 @@ class TestTransportIntegration:
|
|||||||
async def test_multiple_async_requests_same_interceptor(self) -> None:
|
async def test_multiple_async_requests_same_interceptor(self) -> None:
|
||||||
"""Test that multiple async requests through same interceptor are tracked."""
|
"""Test that multiple async requests through same interceptor are tracked."""
|
||||||
interceptor = TrackingInterceptor()
|
interceptor = TrackingInterceptor()
|
||||||
transport = AsyncHTTPransport(interceptor=interceptor)
|
transport = AsyncHTTPTransport(interceptor=interceptor)
|
||||||
|
|
||||||
mock_response = httpx.Response(200)
|
mock_response = httpx.Response(200)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user