From 965aa48ea113c878202e3168c5d9f00dbb7880bd Mon Sep 17 00:00:00 2001 From: Greyson Lalonde Date: Thu, 6 Nov 2025 10:42:31 -0500 Subject: [PATCH] feat: add pydantic validation dunder to BaseInterceptor, improve HTTPTransport typing --- lib/crewai/src/crewai/llms/hooks/base.py | 53 +++++++++++++++- lib/crewai/src/crewai/llms/hooks/transport.py | 61 +++++++++++++++---- lib/crewai/tests/llms/hooks/test_transport.py | 10 +-- 3 files changed, 106 insertions(+), 18 deletions(-) diff --git a/lib/crewai/src/crewai/llms/hooks/base.py b/lib/crewai/src/crewai/llms/hooks/base.py index d476e1acb..2de545fd8 100644 --- a/lib/crewai/src/crewai/llms/hooks/base.py +++ b/lib/crewai/src/crewai/llms/hooks/base.py @@ -7,7 +7,14 @@ outbound and inbound messages at the transport level. from __future__ import annotations from abc import ABC, abstractmethod -from typing import Generic, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +from pydantic_core import core_schema + + +if TYPE_CHECKING: + from pydantic import GetCoreSchemaHandler + from pydantic_core import CoreSchema T = TypeVar("T") @@ -25,6 +32,7 @@ class BaseInterceptor(ABC, Generic[T, U]): U: Inbound message type (e.g., httpx.Response) Example: + >>> import httpx >>> class CustomInterceptor(BaseInterceptor[httpx.Request, httpx.Response]): ... def on_outbound(self, message: httpx.Request) -> httpx.Request: ... message.headers["X-Custom-Header"] = "value" @@ -80,3 +88,46 @@ class BaseInterceptor(ABC, Generic[T, U]): Modified message object. """ raise NotImplementedError + + @classmethod + def __get_pydantic_core_schema__( + cls, _source_type: Any, _handler: GetCoreSchemaHandler + ) -> CoreSchema: + """Generate Pydantic core schema for BaseInterceptor. + + This allows the generic BaseInterceptor to be used in Pydantic models + without requiring arbitrary_types_allowed=True. The schema validates + that the value is an instance of BaseInterceptor. + + Args: + _source_type: The source type being validated (unused). + _handler: Handler for generating schemas (unused). + + Returns: + A Pydantic core schema that validates BaseInterceptor instances. + """ + return core_schema.no_info_plain_validator_function( + _validate_interceptor, + serialization=core_schema.plain_serializer_function_ser_schema( + lambda x: x, return_schema=core_schema.any_schema() + ), + ) + + +def _validate_interceptor(value: Any) -> BaseInterceptor[T, U]: + """Validate that the value is a BaseInterceptor instance. + + Args: + value: The value to validate. + + Returns: + The validated BaseInterceptor instance. + + Raises: + ValueError: If the value is not a BaseInterceptor instance. + """ + if not isinstance(value, BaseInterceptor): + raise ValueError( + f"Expected BaseInterceptor instance, got {type(value).__name__}" + ) + return value diff --git a/lib/crewai/src/crewai/llms/hooks/transport.py b/lib/crewai/src/crewai/llms/hooks/transport.py index baea1a528..ee3f9224c 100644 --- a/lib/crewai/src/crewai/llms/hooks/transport.py +++ b/lib/crewai/src/crewai/llms/hooks/transport.py @@ -6,16 +6,53 @@ to enable request/response modification at the transport level. from __future__ import annotations -from typing import TYPE_CHECKING, Any +from collections.abc import Iterable +from typing import TYPE_CHECKING, TypedDict -import httpx +from httpx import ( + AsyncHTTPTransport as _AsyncHTTPTransport, + HTTPTransport as _HTTPTransport, +) +from typing_extensions import NotRequired, Unpack if TYPE_CHECKING: + from ssl import SSLContext + + from httpx import Limits, Request, Response + from httpx._types import CertTypes, ProxyTypes + from crewai.llms.hooks.base import BaseInterceptor -class HTTPTransport(httpx.HTTPTransport): +class HTTPTransportKwargs(TypedDict): + """Typed dictionary for httpx.HTTPTransport initialization parameters. + + These parameters configure the underlying HTTP transport behavior including + SSL verification, proxies, connection limits, and low-level socket options. + """ + + verify: bool | str | SSLContext + cert: NotRequired[CertTypes | None] + trust_env: bool + http1: bool + http2: bool + limits: Limits + proxy: NotRequired[ProxyTypes | None] + uds: NotRequired[str | None] + local_address: NotRequired[str | None] + retries: int + socket_options: NotRequired[ + Iterable[ + tuple[int, int, int] + | tuple[int, int, bytes | bytearray] + | tuple[int, int, None, int] + ] + | None + ] + + +class HTTPTransport(_HTTPTransport): """HTTP transport that uses an interceptor for request/response modification. This transport is used internally when a user provides a BaseInterceptor. @@ -25,19 +62,19 @@ class HTTPTransport(httpx.HTTPTransport): def __init__( self, - interceptor: BaseInterceptor[httpx.Request, httpx.Response], - **kwargs: Any, + interceptor: BaseInterceptor[Request, Response], + **kwargs: Unpack[HTTPTransportKwargs], ) -> None: """Initialize transport with interceptor. Args: interceptor: HTTP interceptor for modifying raw request/response objects. - **kwargs: Additional arguments passed to httpx.HTTPTransport. + **kwargs: HTTPTransport configuration parameters (verify, cert, proxy, etc.). """ super().__init__(**kwargs) self.interceptor = interceptor - def handle_request(self, request: httpx.Request) -> httpx.Response: + def handle_request(self, request: Request) -> Response: """Handle request with interception. Args: @@ -51,7 +88,7 @@ class HTTPTransport(httpx.HTTPTransport): return self.interceptor.on_inbound(response) -class AsyncHTTPransport(httpx.AsyncHTTPTransport): +class AsyncHTTPTransport(_AsyncHTTPTransport): """Async HTTP transport that uses an interceptor for request/response modification. This transport is used internally when a user provides a BaseInterceptor. @@ -61,19 +98,19 @@ class AsyncHTTPransport(httpx.AsyncHTTPTransport): def __init__( self, - interceptor: BaseInterceptor[httpx.Request, httpx.Response], - **kwargs: Any, + interceptor: BaseInterceptor[Request, Response], + **kwargs: Unpack[HTTPTransportKwargs], ) -> None: """Initialize async transport with interceptor. Args: interceptor: HTTP interceptor for modifying raw request/response objects. - **kwargs: Additional arguments passed to httpx.AsyncHTTPTransport. + **kwargs: HTTPTransport configuration parameters (verify, cert, proxy, etc.). """ super().__init__(**kwargs) self.interceptor = interceptor - async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + async def handle_async_request(self, request: Request) -> Response: """Handle async request with interception. Args: diff --git a/lib/crewai/tests/llms/hooks/test_transport.py b/lib/crewai/tests/llms/hooks/test_transport.py index 5299fa871..5ff5162bd 100644 --- a/lib/crewai/tests/llms/hooks/test_transport.py +++ b/lib/crewai/tests/llms/hooks/test_transport.py @@ -6,7 +6,7 @@ import httpx import pytest from crewai.llms.hooks.base import BaseInterceptor -from crewai.llms.hooks.transport import AsyncHTTPransport, HTTPTransport +from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport class TrackingInterceptor(BaseInterceptor[httpx.Request, httpx.Response]): @@ -128,7 +128,7 @@ class TestAsyncHTTPTransport: def test_async_transport_instantiation(self) -> None: """Test that async transport can be instantiated with interceptor.""" interceptor = TrackingInterceptor() - transport = AsyncHTTPransport(interceptor=interceptor) + transport = AsyncHTTPTransport(interceptor=interceptor) assert transport.interceptor is interceptor @@ -136,13 +136,13 @@ class TestAsyncHTTPTransport: """Test that async transport requires interceptor parameter.""" # AsyncHTTPransport requires an interceptor parameter with pytest.raises(TypeError): - AsyncHTTPransport() + AsyncHTTPTransport() @pytest.mark.asyncio async def test_async_interceptor_called_on_request(self) -> None: """Test that async interceptor hooks are called during request handling.""" interceptor = TrackingInterceptor() - transport = AsyncHTTPransport(interceptor=interceptor) + transport = AsyncHTTPTransport(interceptor=interceptor) # Create a mock parent transport that returns a response mock_response = httpx.Response(200, json={"success": True}) @@ -217,7 +217,7 @@ class TestTransportIntegration: async def test_multiple_async_requests_same_interceptor(self) -> None: """Test that multiple async requests through same interceptor are tracked.""" interceptor = TrackingInterceptor() - transport = AsyncHTTPransport(interceptor=interceptor) + transport = AsyncHTTPTransport(interceptor=interceptor) mock_response = httpx.Response(200)