diff --git a/lib/crewai/src/crewai/llm.py b/lib/crewai/src/crewai/llm.py index 352bec16d..4e7303347 100644 --- a/lib/crewai/src/crewai/llm.py +++ b/lib/crewai/src/crewai/llm.py @@ -20,8 +20,7 @@ from typing import ( ) from dotenv import load_dotenv -import httpx -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from typing_extensions import Self from crewai.events.event_bus import crewai_event_bus @@ -37,7 +36,12 @@ from crewai.events.types.tool_usage_events import ( ToolUsageFinishedEvent, ToolUsageStartedEvent, ) -from crewai.llms.base_llm import BaseLLM, get_current_call_id, llm_call_context +from crewai.llms.base_llm import ( + BaseLLM, + JsonResponseFormat, + get_current_call_id, + llm_call_context, +) from crewai.llms.constants import ( ANTHROPIC_MODELS, AZURE_MODELS, @@ -63,8 +67,6 @@ except ImportError: if TYPE_CHECKING: from crewai.agent.core import Agent - from crewai.llms.hooks.base import BaseInterceptor - from crewai.llms.providers.anthropic.completion import AnthropicThinkingConfig from crewai.task import Task from crewai.tools.base_tool import BaseTool from crewai.utilities.types import LLMMessage @@ -342,6 +344,27 @@ class AccumulatedToolArgs(BaseModel): class LLM(BaseLLM): completion_cost: float | None = None + timeout: float | int | None = None + top_p: float | None = None + n: int | None = None + max_completion_tokens: int | None = None + max_tokens: int | float | None = None + presence_penalty: float | None = None + frequency_penalty: float | None = None + logit_bias: dict[int, float] | None = None + response_format: JsonResponseFormat | type[BaseModel] | None = None + seed: int | None = None + logprobs: int | None = None + top_logprobs: int | None = None + api_base: str | None = None + api_version: str | None = None + callbacks: list[Any] | None = None + reasoning_effort: Literal["none", "low", "medium", "high"] | None = None + stream: bool = False + interceptor: Any = None + thinking: Any = None + context_window_size: int = 0 + is_anthropic: bool = False def __new__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> LLM: """Factory method that routes to native SDK or falls back to LiteLLM. @@ -436,10 +459,7 @@ class LLM(BaseLLM): logger.error(error_msg) raise ImportError(error_msg) from None - instance = object.__new__(cls) - super(LLM, instance).__init__(model=model, is_litellm=True, **kwargs) - instance.is_litellm = True - return instance + return object.__new__(cls) @classmethod def _matches_provider_pattern(cls, model: str, provider: str) -> bool: @@ -624,89 +644,23 @@ class LLM(BaseLLM): return None - def __init__( - self, - model: str, - timeout: float | int | None = None, - temperature: float | None = None, - top_p: float | None = None, - n: int | None = None, - stop: str | list[str] | None = None, - max_completion_tokens: int | None = None, - max_tokens: int | float | None = None, - presence_penalty: float | None = None, - frequency_penalty: float | None = None, - logit_bias: dict[int, float] | None = None, - response_format: type[BaseModel] | None = None, - seed: int | None = None, - logprobs: int | None = None, - top_logprobs: int | None = None, - base_url: str | None = None, - api_base: str | None = None, - api_version: str | None = None, - api_key: str | None = None, - callbacks: list[Any] | None = None, - reasoning_effort: Literal["none", "low", "medium", "high"] | None = None, - stream: bool = False, - interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None, - thinking: AnthropicThinkingConfig | dict[str, Any] | None = None, - prefer_upload: bool = False, - **kwargs: Any, - ) -> None: - """Initialize LLM instance. + @model_validator(mode="before") + @classmethod + def _validate_llm_fields(cls, data: Any) -> Any: + if not isinstance(data, dict): + return data + model = data.get("model", "") + data["is_anthropic"] = cls._is_anthropic_model(model) + return data - Note: This __init__ method is only called for fallback instances. - Native provider instances handle their own initialization in their respective classes. - """ - super().__init__( - model=model, - temperature=temperature, - api_key=api_key, - base_url=base_url, - timeout=timeout, - **kwargs, - ) - self.model = model - self.timeout = timeout - self.temperature = temperature - self.top_p = top_p - self.n = n - self.max_completion_tokens = max_completion_tokens - self.max_tokens = max_tokens - self.presence_penalty = presence_penalty - self.frequency_penalty = frequency_penalty - self.logit_bias = logit_bias - self.response_format = response_format - self.seed = seed - self.logprobs = logprobs - self.top_logprobs = top_logprobs - self.base_url = base_url - self.api_base = api_base - self.api_version = api_version - self.api_key = api_key - self.callbacks = callbacks - self.context_window_size = 0 - self.reasoning_effort = reasoning_effort - self.prefer_upload = prefer_upload - self.additional_params = { - k: v for k, v in kwargs.items() if k not in ("is_litellm", "provider") - } - self.is_anthropic = self._is_anthropic_model(model) - self.stream = stream - self.interceptor = interceptor - - litellm.drop_params = True - - # Normalize self.stop to always be a list[str] - if stop is None: - self.stop: list[str] = [] - elif isinstance(stop, str): - self.stop = [stop] - else: - self.stop = stop - - self.set_callbacks(callbacks or []) - self.set_env_callbacks() + @model_validator(mode="after") + def _init_litellm(self) -> LLM: + self.is_litellm = True + if LITELLM_AVAILABLE: + litellm.drop_params = True + self.set_callbacks(self.callbacks or []) + self.set_env_callbacks() + return self @staticmethod def _is_anthropic_model(model: str) -> bool: @@ -2442,7 +2396,7 @@ class LLM(BaseLLM): **filtered_params, ) - def __deepcopy__(self, memo: dict[int, Any] | None) -> LLM: + def __deepcopy__(self, memo: dict[int, Any] | None = None) -> LLM: """Create a deep copy of the LLM instance.""" import copy diff --git a/lib/crewai/src/crewai/llms/base_llm.py b/lib/crewai/src/crewai/llms/base_llm.py index 6e81271e1..857c2707d 100644 --- a/lib/crewai/src/crewai/llms/base_llm.py +++ b/lib/crewai/src/crewai/llms/base_llm.py @@ -14,10 +14,18 @@ from datetime import datetime import json import logging import re -from typing import TYPE_CHECKING, Any, Final +from typing import TYPE_CHECKING, Any, Final, Literal import uuid -from pydantic import BaseModel +from pydantic import ( + AliasChoices, + BaseModel, + ConfigDict, + Field, + PrivateAttr, + model_validator, +) +from typing_extensions import TypedDict from crewai.events.event_bus import crewai_event_bus from crewai.events.types.llm_events import ( @@ -51,6 +59,12 @@ if TYPE_CHECKING: from crewai.utilities.types import LLMMessage +class JsonResponseFormat(TypedDict): + """Response format requesting raw JSON output (e.g. ``{"type": "json_object"}``).""" + + type: Literal["json_object"] + + DEFAULT_CONTEXT_WINDOW_SIZE: Final[int] = 4096 DEFAULT_SUPPORTS_STOP_WORDS: Final[bool] = True _JSON_EXTRACTION_PATTERN: Final[re.Pattern[str]] = re.compile(r"\{.*}", re.DOTALL) @@ -82,7 +96,7 @@ def get_current_call_id() -> str: return call_id -class BaseLLM(ABC): +class BaseLLM(BaseModel, ABC): """Abstract base class for LLM implementations. This class defines the interface that all LLM implementations must follow. @@ -101,56 +115,100 @@ class BaseLLM(ABC): additional_params: Additional provider-specific parameters. """ + model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True) + + model: str + temperature: float | None = None + api_key: str | None = None + base_url: str | None = None + provider: str = Field(default="openai") + prefer_upload: bool = False is_litellm: bool = False + stop: list[str] = Field( + default_factory=list, + validation_alias=AliasChoices("stop", "stop_sequences"), + ) + additional_params: dict[str, Any] = Field(default_factory=dict) - def __init__( - self, - model: str, - temperature: float | None = None, - api_key: str | None = None, - base_url: str | None = None, - provider: str | None = None, - prefer_upload: bool = False, - **kwargs: Any, - ) -> None: - """Initialize the BaseLLM with default attributes. + def __setattr__(self, name: str, value: Any) -> None: + if name in ("stop", "stop_sequences"): + if value is None: + value = [] + elif isinstance(value, str): + value = [value] + elif not isinstance(value, list): + value = list(value) + name = "stop" + try: + super().__setattr__(name, value) + except ValueError: + if name in self.model_fields: + raise # Re-raise validation errors on declared fields + # Fallback for attributes not declared as fields (e.g. mock patching) + object.__setattr__(self, name, value) + except AttributeError: + object.__setattr__(self, name, value) - Args: - model: The model identifier/name. - temperature: Optional temperature setting for response generation. - stop: Optional list of stop sequences for generation. - prefer_upload: Whether to prefer file upload over inline base64. - **kwargs: Additional provider-specific parameters. + def __delattr__(self, name: str) -> None: + try: + super().__delattr__(name) + except AttributeError: + object.__delattr__(self, name) + + @property + def stop_sequences(self) -> list[str]: + """Alias for ``stop`` — kept for backward compatibility with provider APIs. + + Writes are handled by ``__setattr__``, which normalizes and redirects + ``stop_sequences`` assignments to the ``stop`` field. """ - if not model: - raise ValueError("Model name is required and cannot be empty") + return self.stop - self.model = model - self.temperature = temperature - self.api_key = api_key - self.base_url = base_url - self.prefer_upload = prefer_upload - # Store additional parameters for provider-specific use - self.additional_params = kwargs - self._provider = provider or "openai" - - stop = kwargs.pop("stop", None) - if stop is None: - self.stop: list[str] = [] - elif isinstance(stop, str): - self.stop = [stop] - elif isinstance(stop, list): - self.stop = stop - else: - self.stop = [] - - self._token_usage = { + _token_usage: dict[str, int] = PrivateAttr( + default_factory=lambda: { "total_tokens": 0, "prompt_tokens": 0, "completion_tokens": 0, "successful_requests": 0, "cached_prompt_tokens": 0, } + ) + + @model_validator(mode="before") + @classmethod + def _validate_init_fields(cls, data: Any) -> Any: + if not isinstance(data, dict): + return data + + if not data.get("model"): + raise ValueError("Model name is required and cannot be empty") + + # Normalize stop: accept str, list, or None; also accept stop_sequences alias + stop_seqs = data.pop("stop_sequences", None) + stop = stop_seqs if stop_seqs is not None else data.get("stop") + if stop is None: + data["stop"] = [] + elif isinstance(stop, str): + data["stop"] = [stop] + elif isinstance(stop, list): + data["stop"] = stop + else: + data["stop"] = list(stop) + + # Default provider + if not data.get("provider"): + data["provider"] = "openai" + + # Collect unknown kwargs into additional_params + known_fields = set(cls.model_fields.keys()) + extras = {k: v for k, v in data.items() if k not in known_fields} + for k in extras: + data.pop(k) + existing = data.get("additional_params") or {} + existing.update(extras) + data["additional_params"] = existing + + return data def to_config_dict(self) -> dict[str, Any]: """Serialize this LLM to a dict that can reconstruct it via ``LLM(**config)``. @@ -174,16 +232,6 @@ class BaseLLM(ABC): return config - @property - def provider(self) -> str: - """Get the provider of the LLM.""" - return self._provider - - @provider.setter - def provider(self, value: str) -> None: - """Set the provider of the LLM.""" - self._provider = value - @abstractmethod def call( self, diff --git a/lib/crewai/src/crewai/llms/providers/anthropic/completion.py b/lib/crewai/src/crewai/llms/providers/anthropic/completion.py index 077c31589..1c77d2bc7 100644 --- a/lib/crewai/src/crewai/llms/providers/anthropic/completion.py +++ b/lib/crewai/src/crewai/llms/providers/anthropic/completion.py @@ -3,12 +3,13 @@ from __future__ import annotations import json import logging import os -from typing import TYPE_CHECKING, Any, Final, Literal, TypeGuard, cast +from typing import Any, Final, Literal, TypeGuard, cast -from pydantic import BaseModel +from pydantic import BaseModel, PrivateAttr, model_validator from crewai.events.types.llm_events import LLMCallType -from crewai.llms.base_llm import BaseLLM, llm_call_context +from crewai.llms.base_llm import BaseLLM, JsonResponseFormat, llm_call_context +from crewai.llms.hooks.base import BaseInterceptor from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport from crewai.utilities.agent_utils import is_context_length_exceeded from crewai.utilities.exceptions.context_window_exceeding_exception import ( @@ -17,9 +18,6 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import ( from crewai.utilities.types import LLMMessage -if TYPE_CHECKING: - from crewai.llms.hooks.base import BaseInterceptor - try: from anthropic import Anthropic, AsyncAnthropic, transform_schema from anthropic.types import ( @@ -150,60 +148,47 @@ class AnthropicCompletion(BaseLLM): offering native tool use, streaming support, and proper message formatting. """ - def __init__( - self, - model: str = "claude-3-5-sonnet-20241022", - api_key: str | None = None, - base_url: str | None = None, - timeout: float | None = None, - max_retries: int = 2, - temperature: float | None = None, - max_tokens: int = 4096, # Required for Anthropic - top_p: float | None = None, - stop_sequences: list[str] | None = None, - stream: bool = False, - client_params: dict[str, Any] | None = None, - interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None, - thinking: AnthropicThinkingConfig | None = None, - response_format: type[BaseModel] | None = None, - tool_search: AnthropicToolSearchConfig | bool | None = None, - **kwargs: Any, - ): - """Initialize Anthropic chat completion client. + model: str = "claude-3-5-sonnet-20241022" + timeout: float | None = None + max_retries: int = 2 + max_tokens: int = 4096 + top_p: float | None = None + stream: bool = False + client_params: dict[str, Any] | None = None + interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None + thinking: AnthropicThinkingConfig | None = None + response_format: JsonResponseFormat | type[BaseModel] | None = None + tool_search: AnthropicToolSearchConfig | None = None + is_claude_3: bool = False + supports_tools: bool = True - Args: - model: Anthropic model name (e.g., 'claude-3-5-sonnet-20241022') - api_key: Anthropic API key (defaults to ANTHROPIC_API_KEY env var) - base_url: Custom base URL for Anthropic API - timeout: Request timeout in seconds - max_retries: Maximum number of retries - temperature: Sampling temperature (0-1) - max_tokens: Maximum tokens in response (required for Anthropic) - top_p: Nucleus sampling parameter - stop_sequences: Stop sequences (Anthropic uses stop_sequences, not stop) - stream: Enable streaming responses - client_params: Additional parameters for the Anthropic client - interceptor: HTTP interceptor for modifying requests/responses at transport level. - response_format: Pydantic model for structured output. When provided, responses - will be validated against this model schema. - tool_search: Enable Anthropic's server-side tool search. When True, uses "bm25" - variant by default. Pass an AnthropicToolSearchConfig to choose "regex" or - "bm25". When enabled, tools are automatically marked with defer_loading=True - and a tool search tool is injected into the tools list. - **kwargs: Additional parameters - """ - super().__init__( - model=model, temperature=temperature, stop=stop_sequences or [], **kwargs - ) + _client: Any = PrivateAttr(default=None) + _async_client: Any = PrivateAttr(default=None) + _previous_thinking_blocks: list[Any] = PrivateAttr(default_factory=list) - # Client params - self.interceptor = interceptor - self.client_params = client_params - self.base_url = base_url - self.timeout = timeout - self.max_retries = max_retries + @model_validator(mode="before") + @classmethod + def _normalize_anthropic_fields(cls, data: Any) -> Any: + if not isinstance(data, dict): + return data + # Anthropic uses stop_sequences; normalize from stop kwarg + popped = data.pop("stop_sequences", None) + seqs = popped if popped is not None else (data.get("stop") or []) + if isinstance(seqs, str): + seqs = [seqs] + data["stop"] = seqs + data["is_claude_3"] = "claude-3" in data.get("model", "").lower() + # Normalize tool_search + ts = data.get("tool_search") + if ts is True: + data["tool_search"] = AnthropicToolSearchConfig() + elif ts is not None and not isinstance(ts, AnthropicToolSearchConfig): + data["tool_search"] = None + return data - self.client = Anthropic(**self._get_client_params()) + @model_validator(mode="after") + def _init_clients(self) -> AnthropicCompletion: + self._client = Anthropic(**self._get_client_params()) async_client_params = self._get_client_params() if self.interceptor: @@ -211,51 +196,8 @@ class AnthropicCompletion(BaseLLM): async_http_client = httpx.AsyncClient(transport=async_transport) async_client_params["http_client"] = async_http_client - self.async_client = AsyncAnthropic(**async_client_params) - - # Store completion parameters - self.max_tokens = max_tokens - self.top_p = top_p - self.stream = stream - self.stop_sequences = stop_sequences or [] - self.thinking = thinking - self.previous_thinking_blocks: list[ThinkingBlock] = [] - self.response_format = response_format - # Tool search config - self.tool_search: AnthropicToolSearchConfig | None - if tool_search is True: - self.tool_search = AnthropicToolSearchConfig() - elif isinstance(tool_search, AnthropicToolSearchConfig): - self.tool_search = tool_search - else: - self.tool_search = None - # Model-specific settings - self.is_claude_3 = "claude-3" in model.lower() - self.supports_tools = True - - @property - def stop(self) -> list[str]: - """Get stop sequences sent to the API.""" - return self.stop_sequences - - @stop.setter - def stop(self, value: list[str] | str | None) -> None: - """Set stop sequences. - - Synchronizes stop_sequences to ensure values set by CrewAgentExecutor - are properly sent to the Anthropic API. - - Args: - value: Stop sequences as a list, single string, or None - """ - if value is None: - self.stop_sequences = [] - elif isinstance(value, str): - self.stop_sequences = [value] - elif isinstance(value, list): - self.stop_sequences = value - else: - self.stop_sequences = [] + self._async_client = AsyncAnthropic(**async_client_params) + return self def to_config_dict(self) -> dict[str, Any]: """Extend base config with Anthropic-specific fields.""" @@ -751,11 +693,11 @@ class AnthropicCompletion(BaseLLM): ) elif isinstance(content, list): formatted_messages.append({"role": "assistant", "content": content}) - elif self.thinking and self.previous_thinking_blocks: + elif self.thinking and self._previous_thinking_blocks: structured_content = cast( list[dict[str, Any]], [ - *self.previous_thinking_blocks, + *self._previous_thinking_blocks, {"type": "text", "text": content if content else ""}, ], ) @@ -809,7 +751,7 @@ class AnthropicCompletion(BaseLLM): available_functions: dict[str, Any] | None = None, from_task: Any | None = None, from_agent: Any | None = None, - response_model: type[BaseModel] | None = None, + response_model: JsonResponseFormat | type[BaseModel] | None = None, ) -> str | Any: """Handle non-streaming message completion.""" uses_file_api = _contains_file_id_reference(params.get("messages", [])) @@ -843,11 +785,11 @@ class AnthropicCompletion(BaseLLM): try: if betas: params["betas"] = betas - response = self.client.beta.messages.create( + response = self._client.beta.messages.create( **params, extra_body=extra_body ) else: - response = self.client.messages.create(**params) + response = self._client.messages.create(**params) except Exception as e: if is_context_length_exceeded(e): @@ -928,7 +870,7 @@ class AnthropicCompletion(BaseLLM): thinking_blocks.append(cast(ThinkingBlock, thinking_block)) if thinking_blocks: - self.previous_thinking_blocks = thinking_blocks + self._previous_thinking_blocks = thinking_blocks content = self._apply_stop_words(content) self._emit_call_completed_event( @@ -952,7 +894,7 @@ class AnthropicCompletion(BaseLLM): available_functions: dict[str, Any] | None = None, from_task: Any | None = None, from_agent: Any | None = None, - response_model: type[BaseModel] | None = None, + response_model: JsonResponseFormat | type[BaseModel] | None = None, ) -> str | Any: """Handle streaming message completion.""" betas: list[str] = [] @@ -991,9 +933,9 @@ class AnthropicCompletion(BaseLLM): current_tool_calls: dict[int, dict[str, Any]] = {} stream_context = ( - self.client.beta.messages.stream(**stream_params, extra_body=extra_body) + self._client.beta.messages.stream(**stream_params, extra_body=extra_body) if betas - else self.client.messages.stream(**stream_params) + else self._client.messages.stream(**stream_params) ) with stream_context as stream: response_id = None @@ -1072,7 +1014,7 @@ class AnthropicCompletion(BaseLLM): thinking_blocks.append(cast(ThinkingBlock, thinking_block)) if thinking_blocks: - self.previous_thinking_blocks = thinking_blocks + self._previous_thinking_blocks = thinking_blocks usage = self._extract_anthropic_token_usage(final_message) self._track_token_usage_internal(usage) @@ -1269,7 +1211,7 @@ class AnthropicCompletion(BaseLLM): try: # Send tool results back to Claude for final response - final_response: Message = self.client.messages.create(**follow_up_params) + final_response: Message = self._client.messages.create(**follow_up_params) # Track token usage for follow-up call follow_up_usage = self._extract_anthropic_token_usage(final_response) @@ -1288,7 +1230,7 @@ class AnthropicCompletion(BaseLLM): thinking_blocks.append(cast(ThinkingBlock, thinking_block)) if thinking_blocks: - self.previous_thinking_blocks = thinking_blocks + self._previous_thinking_blocks = thinking_blocks final_content = self._apply_stop_words(final_content) @@ -1330,7 +1272,7 @@ class AnthropicCompletion(BaseLLM): available_functions: dict[str, Any] | None = None, from_task: Any | None = None, from_agent: Any | None = None, - response_model: type[BaseModel] | None = None, + response_model: JsonResponseFormat | type[BaseModel] | None = None, ) -> str | Any: """Handle non-streaming async message completion.""" uses_file_api = _contains_file_id_reference(params.get("messages", [])) @@ -1364,11 +1306,11 @@ class AnthropicCompletion(BaseLLM): try: if betas: params["betas"] = betas - response = await self.async_client.beta.messages.create( + response = await self._async_client.beta.messages.create( **params, extra_body=extra_body ) else: - response = await self.async_client.messages.create(**params) + response = await self._async_client.messages.create(**params) except Exception as e: if is_context_length_exceeded(e): @@ -1461,7 +1403,7 @@ class AnthropicCompletion(BaseLLM): available_functions: dict[str, Any] | None = None, from_task: Any | None = None, from_agent: Any | None = None, - response_model: type[BaseModel] | None = None, + response_model: JsonResponseFormat | type[BaseModel] | None = None, ) -> str | Any: """Handle async streaming message completion.""" betas: list[str] = [] @@ -1498,11 +1440,11 @@ class AnthropicCompletion(BaseLLM): current_tool_calls: dict[int, dict[str, Any]] = {} stream_context = ( - self.async_client.beta.messages.stream( + self._async_client.beta.messages.stream( **stream_params, extra_body=extra_body ) if betas - else self.async_client.messages.stream(**stream_params) + else self._async_client.messages.stream(**stream_params) ) async with stream_context as stream: response_id = None @@ -1664,7 +1606,7 @@ class AnthropicCompletion(BaseLLM): ] try: - final_response: Message = await self.async_client.messages.create( + final_response: Message = await self._async_client.messages.create( **follow_up_params ) @@ -1786,8 +1728,8 @@ class AnthropicCompletion(BaseLLM): from crewai_files.uploaders.anthropic import AnthropicFileUploader return AnthropicFileUploader( - client=self.client, - async_client=self.async_client, + client=self._client, + async_client=self._async_client, ) except ImportError: return None diff --git a/lib/crewai/src/crewai/llms/providers/azure/completion.py b/lib/crewai/src/crewai/llms/providers/azure/completion.py index accaf5b8e..cae50d0c6 100644 --- a/lib/crewai/src/crewai/llms/providers/azure/completion.py +++ b/lib/crewai/src/crewai/llms/providers/azure/completion.py @@ -3,11 +3,13 @@ from __future__ import annotations import json import logging import os -from typing import TYPE_CHECKING, Any, TypedDict +from typing import Any, TypedDict +from urllib.parse import urlparse -from pydantic import BaseModel +from pydantic import BaseModel, PrivateAttr, model_validator from typing_extensions import Self +from crewai.llms.hooks.base import BaseInterceptor from crewai.utilities.agent_utils import is_context_length_exceeded from crewai.utilities.exceptions.context_window_exceeding_exception import ( LLMContextLengthExceededError, @@ -16,10 +18,6 @@ from crewai.utilities.pydantic_schema_utils import generate_model_description from crewai.utilities.types import LLMMessage -if TYPE_CHECKING: - from crewai.llms.hooks.base import BaseInterceptor - - try: from azure.ai.inference import ( ChatCompletionsClient, @@ -76,109 +74,84 @@ class AzureCompletion(BaseLLM): offering native function calling, streaming support, and proper Azure authentication. """ - def __init__( - self, - model: str, - api_key: str | None = None, - endpoint: str | None = None, - api_version: str | None = None, - timeout: float | None = None, - max_retries: int = 2, - temperature: float | None = None, - top_p: float | None = None, - frequency_penalty: float | None = None, - presence_penalty: float | None = None, - max_tokens: int | None = None, - stop: list[str] | None = None, - stream: bool = False, - interceptor: BaseInterceptor[Any, Any] | None = None, - response_format: type[BaseModel] | None = None, - **kwargs: Any, - ): - """Initialize Azure AI Inference chat completion client. + endpoint: str | None = None + api_version: str | None = None + timeout: float | None = None + max_retries: int = 2 + top_p: float | None = None + frequency_penalty: float | None = None + presence_penalty: float | None = None + max_tokens: int | None = None + stream: bool = False + interceptor: BaseInterceptor[Any, Any] | None = None + response_format: type[BaseModel] | None = None + is_openai_model: bool = False + is_azure_openai_endpoint: bool = False - Args: - model: Azure deployment name or model name - api_key: Azure API key (defaults to AZURE_API_KEY env var) - endpoint: Azure endpoint URL (defaults to AZURE_ENDPOINT env var) - api_version: Azure API version (defaults to AZURE_API_VERSION env var) - timeout: Request timeout in seconds - max_retries: Maximum number of retries - temperature: Sampling temperature (0-2) - top_p: Nucleus sampling parameter - frequency_penalty: Frequency penalty (-2 to 2) - presence_penalty: Presence penalty (-2 to 2) - max_tokens: Maximum tokens in response - stop: Stop sequences - stream: Enable streaming responses - interceptor: HTTP interceptor (not yet supported for Azure). - response_format: Pydantic model for structured output. Used as default when - response_model is not passed to call()/acall() methods. - Only works with OpenAI models deployed on Azure. - **kwargs: Additional parameters - """ - if interceptor is not None: + _client: Any = PrivateAttr(default=None) + _async_client: Any = PrivateAttr(default=None) + + @model_validator(mode="before") + @classmethod + def _normalize_azure_fields(cls, data: Any) -> Any: + if not isinstance(data, dict): + return data + + if data.get("interceptor") is not None: raise NotImplementedError( "HTTP interceptors are not yet supported for Azure AI Inference provider. " "Interceptors are currently supported for OpenAI and Anthropic providers only." ) - super().__init__( - model=model, temperature=temperature, stop=stop or [], **kwargs - ) - - self.api_key = api_key or os.getenv("AZURE_API_KEY") - self.endpoint = ( - endpoint + # Resolve env vars + data["api_key"] = data.get("api_key") or os.getenv("AZURE_API_KEY") + data["endpoint"] = ( + data.get("endpoint") or os.getenv("AZURE_ENDPOINT") or os.getenv("AZURE_OPENAI_ENDPOINT") or os.getenv("AZURE_API_BASE") ) - self.api_version = api_version or os.getenv("AZURE_API_VERSION") or "2024-06-01" - self.timeout = timeout - self.max_retries = max_retries + data["api_version"] = ( + data.get("api_version") or os.getenv("AZURE_API_VERSION") or "2024-06-01" + ) - if not self.api_key: + if not data["api_key"]: raise ValueError( "Azure API key is required. Set AZURE_API_KEY environment variable or pass api_key parameter." ) - if not self.endpoint: + if not data["endpoint"]: raise ValueError( "Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter." ) - # Validate and potentially fix Azure OpenAI endpoint URL - self.endpoint = self._validate_and_fix_endpoint(self.endpoint, model) + model = data.get("model", "") + data["endpoint"] = AzureCompletion._validate_and_fix_endpoint( + data["endpoint"], model + ) + data["is_openai_model"] = any( + prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"] + ) + parsed = urlparse(data["endpoint"]) + hostname = parsed.hostname or "" + data["is_azure_openai_endpoint"] = ( + hostname == "openai.azure.com" or hostname.endswith(".openai.azure.com") + ) and "/openai/deployments/" in data["endpoint"] + return data - # Build client kwargs - client_kwargs = { + @model_validator(mode="after") + def _init_clients(self) -> AzureCompletion: + if not self.api_key: + raise ValueError("Azure API key is required.") + client_kwargs: dict[str, Any] = { "endpoint": self.endpoint, "credential": AzureKeyCredential(self.api_key), } - - # Add api_version if specified (primarily for Azure OpenAI endpoints) if self.api_version: client_kwargs["api_version"] = self.api_version - self.client = ChatCompletionsClient(**client_kwargs) # type: ignore[arg-type] - - self.async_client = AsyncChatCompletionsClient(**client_kwargs) # type: ignore[arg-type] - - self.top_p = top_p - self.frequency_penalty = frequency_penalty - self.presence_penalty = presence_penalty - self.max_tokens = max_tokens - self.stream = stream - self.response_format = response_format - - self.is_openai_model = any( - prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"] - ) - - self.is_azure_openai_endpoint = ( - "openai.azure.com" in self.endpoint - and "/openai/deployments/" in self.endpoint - ) + self._client = ChatCompletionsClient(**client_kwargs) + self._async_client = AsyncChatCompletionsClient(**client_kwargs) + return self def to_config_dict(self) -> dict[str, Any]: """Extend base config with Azure-specific fields.""" @@ -215,7 +188,11 @@ class AzureCompletion(BaseLLM): Returns: Validated and potentially corrected endpoint URL """ - if "openai.azure.com" in endpoint and "/openai/deployments/" not in endpoint: + ep_host = urlparse(endpoint).hostname or "" + is_azure_openai = ep_host == "openai.azure.com" or ep_host.endswith( + ".openai.azure.com" + ) + if is_azure_openai and "/openai/deployments/" not in endpoint: endpoint = endpoint.rstrip("/") if not endpoint.endswith("/openai/deployments"): @@ -731,7 +708,7 @@ class AzureCompletion(BaseLLM): """Handle non-streaming chat completion.""" try: # Cast params to Any to avoid type checking issues with TypedDict unpacking - response: ChatCompletions = self.client.complete(**params) # type: ignore[assignment,arg-type] + response: ChatCompletions = self._client.complete(**params) return self._process_completion_response( response=response, params=params, @@ -926,7 +903,7 @@ class AzureCompletion(BaseLLM): tool_calls: dict[int, dict[str, Any]] = {} usage_data = {"total_tokens": 0} - for update in self.client.complete(**params): # type: ignore[arg-type] + for update in self._client.complete(**params): if isinstance(update, StreamingChatCompletionsUpdate): if update.usage: usage = update.usage @@ -967,7 +944,7 @@ class AzureCompletion(BaseLLM): """Handle non-streaming chat completion asynchronously.""" try: # Cast params to Any to avoid type checking issues with TypedDict unpacking - response: ChatCompletions = await self.async_client.complete(**params) # type: ignore[assignment,arg-type] + response: ChatCompletions = await self._async_client.complete(**params) return self._process_completion_response( response=response, params=params, @@ -993,8 +970,8 @@ class AzureCompletion(BaseLLM): usage_data = {"total_tokens": 0} - stream = await self.async_client.complete(**params) # type: ignore[arg-type] - async for update in stream: # type: ignore[union-attr] + stream = await self._async_client.complete(**params) + async for update in stream: if isinstance(update, StreamingChatCompletionsUpdate): if hasattr(update, "usage") and update.usage: usage = update.usage @@ -1110,8 +1087,8 @@ class AzureCompletion(BaseLLM): This ensures proper cleanup of the underlying aiohttp session to avoid unclosed connector warnings. """ - if hasattr(self.async_client, "close"): - await self.async_client.close() + if hasattr(self._async_client, "close"): + await self._async_client.close() async def __aenter__(self) -> Self: """Async context manager entry.""" diff --git a/lib/crewai/src/crewai/llms/providers/bedrock/completion.py b/lib/crewai/src/crewai/llms/providers/bedrock/completion.py index b17c98874..510c84cc7 100644 --- a/lib/crewai/src/crewai/llms/providers/bedrock/completion.py +++ b/lib/crewai/src/crewai/llms/providers/bedrock/completion.py @@ -7,7 +7,7 @@ import logging import os from typing import TYPE_CHECKING, Any, TypedDict, cast -from pydantic import BaseModel +from pydantic import BaseModel, PrivateAttr, model_validator from typing_extensions import Required from crewai.events.types.llm_events import LLMCallType @@ -33,7 +33,7 @@ if TYPE_CHECKING: ToolTypeDef, ) - from crewai.llms.hooks.base import BaseInterceptor +from crewai.llms.hooks.base import BaseInterceptor try: @@ -228,129 +228,97 @@ class BedrockCompletion(BaseLLM): - Model-specific conversation format handling (e.g., Cohere requirements) """ - def __init__( - self, - model: str = "anthropic.claude-3-5-sonnet-20241022-v2:0", - aws_access_key_id: str | None = None, - aws_secret_access_key: str | None = None, - aws_session_token: str | None = None, - region_name: str | None = None, - temperature: float | None = None, - max_tokens: int | None = None, - top_p: float | None = None, - top_k: int | None = None, - stop_sequences: Sequence[str] | None = None, - stream: bool = False, - guardrail_config: dict[str, Any] | None = None, - additional_model_request_fields: dict[str, Any] | None = None, - additional_model_response_field_paths: list[str] | None = None, - interceptor: BaseInterceptor[Any, Any] | None = None, - response_format: type[BaseModel] | None = None, - **kwargs: Any, - ) -> None: - """Initialize AWS Bedrock completion client. + model: str = "anthropic.claude-3-5-sonnet-20241022-v2:0" + aws_access_key_id: str | None = None + aws_secret_access_key: str | None = None + aws_session_token: str | None = None + region_name: str | None = None + max_tokens: int | None = None + top_p: float | None = None + top_k: int | None = None + stream: bool = False + guardrail_config: dict[str, Any] | None = None + additional_model_request_fields: dict[str, Any] | None = None + additional_model_response_field_paths: list[str] | None = None + interceptor: BaseInterceptor[Any, Any] | None = None + response_format: type[BaseModel] | None = None + is_claude_model: bool = False + supports_tools: bool = True + supports_streaming: bool = True + model_id: str = "" - Args: - model: The Bedrock model ID to use - aws_access_key_id: AWS access key (defaults to environment variable) - aws_secret_access_key: AWS secret key (defaults to environment variable) - aws_session_token: AWS session token for temporary credentials - region_name: AWS region name - temperature: Sampling temperature for response generation - max_tokens: Maximum tokens to generate - top_p: Nucleus sampling parameter - top_k: Top-k sampling parameter (Claude models only) - stop_sequences: List of sequences that stop generation - stream: Whether to use streaming responses - guardrail_config: Guardrail configuration for content filtering - additional_model_request_fields: Model-specific request parameters - additional_model_response_field_paths: Custom response field paths - interceptor: HTTP interceptor (not yet supported for Bedrock). - response_format: Pydantic model for structured output. Used as default when - response_model is not passed to call()/acall() methods. - **kwargs: Additional parameters - """ - if interceptor is not None: + _client: Any = PrivateAttr(default=None) + _async_exit_stack: Any = PrivateAttr(default=None) + _async_client_initialized: bool = PrivateAttr(default=False) + _async_client: Any = PrivateAttr(default=None) + + @model_validator(mode="before") + @classmethod + def _normalize_bedrock_fields(cls, data: Any) -> Any: + if not isinstance(data, dict): + return data + + if data.get("interceptor") is not None: raise NotImplementedError( "HTTP interceptors are not yet supported for AWS Bedrock provider. " "Interceptors are currently supported for OpenAI and Anthropic providers only." ) - # Extract provider from kwargs to avoid duplicate argument - kwargs.pop("provider", None) + # Force provider to bedrock + data.pop("provider", None) + data["provider"] = "bedrock" - super().__init__( - model=model, - temperature=temperature, - stop=stop_sequences or [], - provider="bedrock", - **kwargs, + # Normalize stop_sequences from stop kwarg + popped = data.pop("stop_sequences", None) + seqs = popped if popped is not None else (data.get("stop") or []) + if isinstance(seqs, str): + seqs = [seqs] + elif isinstance(seqs, Sequence) and not isinstance(seqs, list): + seqs = list(seqs) + data["stop"] = seqs + + # Resolve env vars + data["aws_access_key_id"] = data.get("aws_access_key_id") or os.getenv( + "AWS_ACCESS_KEY_ID" ) - - # Configure client with timeouts and retries following AWS best practices - config = Config( - read_timeout=300, - retries={ - "max_attempts": 3, - "mode": "adaptive", - }, - tcp_keepalive=True, + data["aws_secret_access_key"] = data.get("aws_secret_access_key") or os.getenv( + "AWS_SECRET_ACCESS_KEY" ) - - self.region_name = ( - region_name + data["aws_session_token"] = data.get("aws_session_token") or os.getenv( + "AWS_SESSION_TOKEN" + ) + data["region_name"] = ( + data.get("region_name") or os.getenv("AWS_DEFAULT_REGION") or os.getenv("AWS_REGION_NAME") or "us-east-1" ) - self.aws_access_key_id = aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID") - self.aws_secret_access_key = aws_secret_access_key or os.getenv( - "AWS_SECRET_ACCESS_KEY" - ) - self.aws_session_token = aws_session_token or os.getenv("AWS_SESSION_TOKEN") + model = data.get("model", "anthropic.claude-3-5-sonnet-20241022-v2:0") + data["is_claude_model"] = "claude" in model.lower() + data["model_id"] = model + return data - # Initialize Bedrock client with proper configuration + @model_validator(mode="after") + def _init_clients(self) -> BedrockCompletion: + config = Config( + read_timeout=300, + retries={"max_attempts": 3, "mode": "adaptive"}, + tcp_keepalive=True, + ) session = Session( aws_access_key_id=self.aws_access_key_id, aws_secret_access_key=self.aws_secret_access_key, aws_session_token=self.aws_session_token, region_name=self.region_name, ) - - self.client = session.client("bedrock-runtime", config=config) - + self._client = session.client("bedrock-runtime", config=config) self._async_exit_stack = AsyncExitStack() if AIOBOTOCORE_AVAILABLE else None - self._async_client_initialized = False - - # Store completion parameters - self.max_tokens = max_tokens - self.top_p = top_p - self.top_k = top_k - self.stream = stream - self.stop_sequences = stop_sequences - self.response_format = response_format - - # Store advanced features (optional) - self.guardrail_config = guardrail_config - self.additional_model_request_fields = additional_model_request_fields - self.additional_model_response_field_paths = ( - additional_model_response_field_paths - ) - - # Model-specific settings - self.is_claude_model = "claude" in model.lower() - self.supports_tools = True # Converse API supports tools for most models - self.supports_streaming = True - - # Handle inference profiles for newer models - self.model_id = model + return self def to_config_dict(self) -> dict[str, Any]: """Extend base config with Bedrock-specific fields.""" config = super().to_config_dict() - # NOTE: AWS credentials (access_key, secret_key, session_token) are - # intentionally excluded — they must come from env on resume. if self.region_name and self.region_name != "us-east-1": config["region_name"] = self.region_name if self.max_tokens is not None: @@ -363,30 +331,6 @@ class BedrockCompletion(BaseLLM): config["guardrail_config"] = self.guardrail_config return config - @property - def stop(self) -> list[str]: - """Get stop sequences sent to the API.""" - return [] if self.stop_sequences is None else list(self.stop_sequences) - - @stop.setter - def stop(self, value: Sequence[str] | str | None) -> None: - """Set stop sequences. - - Synchronizes stop_sequences to ensure values set by CrewAgentExecutor - are properly sent to the Bedrock API. - - Args: - value: Stop sequences as a Sequence, single string, or None - """ - if value is None: - self.stop_sequences = [] - elif isinstance(value, str): - self.stop_sequences = [value] - elif isinstance(value, Sequence): - self.stop_sequences = list(value) - else: - self.stop_sequences = [] - def call( self, messages: str | list[LLMMessage], @@ -710,7 +654,7 @@ class BedrockCompletion(BaseLLM): raise ValueError(f"Invalid message format at index {i}") # Call Bedrock Converse API with proper error handling - response = self.client.converse( + response = self._client.converse( modelId=self.model_id, messages=cast( "Sequence[MessageTypeDef | MessageOutputTypeDef]", @@ -994,13 +938,13 @@ class BedrockCompletion(BaseLLM): accumulated_tool_input = "" try: - response = self.client.converse_stream( + response = self._client.converse_stream( modelId=self.model_id, messages=cast( "Sequence[MessageTypeDef | MessageOutputTypeDef]", cast(object, messages), ), - **body, # type: ignore[arg-type] + **body, ) stream = response.get("stream") diff --git a/lib/crewai/src/crewai/llms/providers/gemini/completion.py b/lib/crewai/src/crewai/llms/providers/gemini/completion.py index f332bbc54..827df750c 100644 --- a/lib/crewai/src/crewai/llms/providers/gemini/completion.py +++ b/lib/crewai/src/crewai/llms/providers/gemini/completion.py @@ -5,12 +5,13 @@ import json import logging import os import re -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import Any, Literal, cast -from pydantic import BaseModel +from pydantic import BaseModel, Field, PrivateAttr, model_validator from crewai.events.types.llm_events import LLMCallType from crewai.llms.base_llm import BaseLLM, llm_call_context +from crewai.llms.hooks.base import BaseInterceptor from crewai.utilities.agent_utils import is_context_length_exceeded from crewai.utilities.exceptions.context_window_exceeding_exception import ( LLMContextLengthExceededError, @@ -19,10 +20,6 @@ from crewai.utilities.pydantic_schema_utils import generate_model_description from crewai.utilities.types import LLMMessage -if TYPE_CHECKING: - from crewai.llms.hooks.base import BaseInterceptor - - try: from google import genai from google.genai import types @@ -44,137 +41,84 @@ class GeminiCompletion(BaseLLM): offering native function calling, streaming support, and proper Gemini formatting. """ - def __init__( - self, - model: str = "gemini-2.0-flash-001", - api_key: str | None = None, - project: str | None = None, - location: str | None = None, - temperature: float | None = None, - top_p: float | None = None, - top_k: int | None = None, - max_output_tokens: int | None = None, - stop_sequences: list[str] | None = None, - stream: bool = False, - safety_settings: dict[str, Any] | None = None, - client_params: dict[str, Any] | None = None, - interceptor: BaseInterceptor[Any, Any] | None = None, - use_vertexai: bool | None = None, - response_format: type[BaseModel] | None = None, - thinking_config: types.ThinkingConfig | None = None, - **kwargs: Any, - ): - """Initialize Google Gemini chat completion client. + model: str = "gemini-2.0-flash-001" + project: str | None = None + location: str | None = None + top_p: float | None = None + top_k: int | None = None + max_output_tokens: int | None = None + stream: bool = False + safety_settings: dict[str, Any] = Field(default_factory=dict) + client_params: dict[str, Any] = Field(default_factory=dict) + interceptor: BaseInterceptor[Any, Any] | None = None + use_vertexai: bool = False + response_format: type[BaseModel] | None = None + thinking_config: Any = None + tools: list[dict[str, Any]] | None = None + supports_tools: bool = False + is_gemini_2_0: bool = False - Args: - model: Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro') - api_key: Google API key for Gemini API authentication. - Defaults to GOOGLE_API_KEY or GEMINI_API_KEY env var. - NOTE: Cannot be used with Vertex AI (project parameter). Use Gemini API instead. - project: Google Cloud project ID for Vertex AI with ADC authentication. - Requires Application Default Credentials (gcloud auth application-default login). - NOTE: Vertex AI does NOT support API keys, only OAuth2/ADC. - If both api_key and project are set, api_key takes precedence. - location: Google Cloud location (for Vertex AI with ADC, defaults to 'us-central1') - temperature: Sampling temperature (0-2) - top_p: Nucleus sampling parameter - top_k: Top-k sampling parameter - max_output_tokens: Maximum tokens in response - stop_sequences: Stop sequences - stream: Enable streaming responses - safety_settings: Safety filter settings - client_params: Additional parameters to pass to the Google Gen AI Client constructor. - Supports parameters like http_options, credentials, debug_config, etc. - interceptor: HTTP interceptor (not yet supported for Gemini). - use_vertexai: Whether to use Vertex AI instead of Gemini API. - - True: Use Vertex AI (with ADC or Express mode with API key) - - False: Use Gemini API (explicitly override env var) - - None (default): Check GOOGLE_GENAI_USE_VERTEXAI env var - When using Vertex AI with API key (Express mode), http_options with - api_version="v1" is automatically configured. - response_format: Pydantic model for structured output. Used as default when - response_model is not passed to call()/acall() methods. - thinking_config: ThinkingConfig for thinking models (gemini-2.5+, gemini-3+). - Controls thought output via include_thoughts, thinking_budget, - and thinking_level. When None, thinking models automatically - get include_thoughts=True so thought content is surfaced. - **kwargs: Additional parameters - """ - if interceptor is not None: + _client: Any = PrivateAttr(default=None) + + @model_validator(mode="before") + @classmethod + def _normalize_gemini_fields(cls, data: Any) -> Any: + if not isinstance(data, dict): + return data + + if data.get("interceptor") is not None: raise NotImplementedError( "HTTP interceptors are not yet supported for Google Gemini provider. " "Interceptors are currently supported for OpenAI and Anthropic providers only." ) - super().__init__( - model=model, temperature=temperature, stop=stop_sequences or [], **kwargs + # Normalize stop_sequences from stop kwarg + popped = data.pop("stop_sequences", None) + seqs = popped if popped is not None else (data.get("stop") or []) + if isinstance(seqs, str): + seqs = [seqs] + data["stop"] = seqs + + # Resolve env vars + data["api_key"] = ( + data.get("api_key") + or os.getenv("GOOGLE_API_KEY") + or os.getenv("GEMINI_API_KEY") + ) + data["project"] = data.get("project") or os.getenv("GOOGLE_CLOUD_PROJECT") + data["location"] = ( + data.get("location") or os.getenv("GOOGLE_CLOUD_LOCATION") or "us-central1" ) - # Store client params for later use - self.client_params = client_params or {} - - # Get API configuration with environment variable fallbacks - self.api_key = ( - api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY") - ) - self.project = project or os.getenv("GOOGLE_CLOUD_PROJECT") - self.location = location or os.getenv("GOOGLE_CLOUD_LOCATION") or "us-central1" - - if use_vertexai is None: - use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true" - - self.client = self._initialize_client(use_vertexai) - - # Store completion parameters - self.top_p = top_p - self.top_k = top_k - self.max_output_tokens = max_output_tokens - self.stream = stream - self.safety_settings = safety_settings or {} - self.stop_sequences = stop_sequences or [] - self.tools: list[dict[str, Any]] | None = None - self.response_format = response_format + use_vx = data.get("use_vertexai") + if use_vx is None: + use_vx = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true" + data["use_vertexai"] = use_vx # Model-specific settings + model = data.get("model", "gemini-2.0-flash-001") version_match = re.search(r"gemini-(\d+(?:\.\d+)?)", model.lower()) - self.supports_tools = bool( + data["supports_tools"] = bool( version_match and float(version_match.group(1)) >= 1.5 ) - self.is_gemini_2_0 = bool( + data["is_gemini_2_0"] = bool( version_match and float(version_match.group(1)) >= 2.0 ) - self.thinking_config = thinking_config + # Auto-enable thinking for gemini-2.5+ if ( - self.thinking_config is None + data.get("thinking_config") is None and version_match and float(version_match.group(1)) >= 2.5 ): - self.thinking_config = types.ThinkingConfig(include_thoughts=True) + data["thinking_config"] = types.ThinkingConfig(include_thoughts=True) - @property - def stop(self) -> list[str]: - """Get stop sequences sent to the API.""" - return self.stop_sequences + return data - @stop.setter - def stop(self, value: list[str] | str | None) -> None: - """Set stop sequences. - - Synchronizes stop_sequences to ensure values set by CrewAgentExecutor - are properly sent to the Gemini API. - - Args: - value: Stop sequences as a list, single string, or None - """ - if value is None: - self.stop_sequences = [] - elif isinstance(value, str): - self.stop_sequences = [value] - elif isinstance(value, list): - self.stop_sequences = value - else: - self.stop_sequences = [] + @model_validator(mode="after") + def _init_client(self) -> GeminiCompletion: + self._client = self._initialize_client(self.use_vertexai) + return self def to_config_dict(self) -> dict[str, Any]: """Extend base config with Gemini/Vertex-specific fields.""" @@ -283,8 +227,8 @@ class GeminiCompletion(BaseLLM): if ( hasattr(self, "client") - and hasattr(self.client, "vertexai") - and self.client.vertexai + and hasattr(self._client, "vertexai") + and self._client.vertexai ): # Vertex AI configuration params.update( @@ -1152,7 +1096,7 @@ class GeminiCompletion(BaseLLM): try: # The API accepts list[Content] but mypy is overly strict about variance contents_for_api: Any = contents - response = self.client.models.generate_content( + response = self._client.models.generate_content( model=self.model, contents=contents_for_api, config=config, @@ -1192,7 +1136,7 @@ class GeminiCompletion(BaseLLM): # The API accepts list[Content] but mypy is overly strict about variance contents_for_api: Any = contents - for chunk in self.client.models.generate_content_stream( + for chunk in self._client.models.generate_content_stream( model=self.model, contents=contents_for_api, config=config, @@ -1230,7 +1174,7 @@ class GeminiCompletion(BaseLLM): try: # The API accepts list[Content] but mypy is overly strict about variance contents_for_api: Any = contents - response = await self.client.aio.models.generate_content( + response = await self._client.aio.models.generate_content( model=self.model, contents=contents_for_api, config=config, @@ -1270,7 +1214,7 @@ class GeminiCompletion(BaseLLM): # The API accepts list[Content] but mypy is overly strict about variance contents_for_api: Any = contents - stream = await self.client.aio.models.generate_content_stream( + stream = await self._client.aio.models.generate_content_stream( model=self.model, contents=contents_for_api, config=config, @@ -1474,6 +1418,6 @@ class GeminiCompletion(BaseLLM): try: from crewai_files.uploaders.gemini import GeminiFileUploader - return GeminiFileUploader(client=self.client) + return GeminiFileUploader(client=self._client) except ImportError: return None diff --git a/lib/crewai/src/crewai/llms/providers/openai/completion.py b/lib/crewai/src/crewai/llms/providers/openai/completion.py index 803fd98cf..8870fcd85 100644 --- a/lib/crewai/src/crewai/llms/providers/openai/completion.py +++ b/lib/crewai/src/crewai/llms/providers/openai/completion.py @@ -14,10 +14,11 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai.types.chat.chat_completion import Choice from openai.types.chat.chat_completion_chunk import ChoiceDelta from openai.types.responses import Response -from pydantic import BaseModel +from pydantic import BaseModel, PrivateAttr, model_validator from crewai.events.types.llm_events import LLMCallType -from crewai.llms.base_llm import BaseLLM, llm_call_context +from crewai.llms.base_llm import BaseLLM, JsonResponseFormat, llm_call_context +from crewai.llms.hooks.base import BaseInterceptor from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport from crewai.utilities.agent_utils import is_context_length_exceeded from crewai.utilities.exceptions.context_window_exceeding_exception import ( @@ -29,7 +30,6 @@ from crewai.utilities.types import LLMMessage if TYPE_CHECKING: from crewai.agent.core import Agent - from crewai.llms.hooks.base import BaseInterceptor from crewai.task import Task from crewai.tools.base_tool import BaseTool @@ -183,77 +183,69 @@ class OpenAICompletion(BaseLLM): "computer_use": "computer_use_preview", } - def __init__( - self, - model: str = "gpt-4o", - api_key: str | None = None, - base_url: str | None = None, - organization: str | None = None, - project: str | None = None, - timeout: float | None = None, - max_retries: int = 2, - default_headers: dict[str, str] | None = None, - default_query: dict[str, Any] | None = None, - client_params: dict[str, Any] | None = None, - temperature: float | None = None, - top_p: float | None = None, - frequency_penalty: float | None = None, - presence_penalty: float | None = None, - max_tokens: int | None = None, - max_completion_tokens: int | None = None, - seed: int | None = None, - stream: bool = False, - response_format: dict[str, Any] | type[BaseModel] | None = None, - logprobs: bool | None = None, - top_logprobs: int | None = None, - reasoning_effort: str | None = None, - provider: str | None = None, - interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None, - api: Literal["completions", "responses"] = "completions", - instructions: str | None = None, - store: bool | None = None, - previous_response_id: str | None = None, - include: list[str] | None = None, - builtin_tools: list[str] | None = None, - parse_tool_outputs: bool = False, - auto_chain: bool = False, - auto_chain_reasoning: bool = False, - **kwargs: Any, - ) -> None: - """Initialize OpenAI completion client.""" + model: str = "gpt-4o" + organization: str | None = None + project: str | None = None + timeout: float | None = None + max_retries: int = 2 + default_headers: dict[str, str] | None = None + default_query: dict[str, Any] | None = None + client_params: dict[str, Any] | None = None + top_p: float | None = None + frequency_penalty: float | None = None + presence_penalty: float | None = None + max_tokens: int | None = None + max_completion_tokens: int | None = None + seed: int | None = None + stream: bool = False + response_format: JsonResponseFormat | type[BaseModel] | None = None + logprobs: bool | None = None + top_logprobs: int | None = None + reasoning_effort: str | None = None + interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None + api: Literal["completions", "responses"] = "completions" + instructions: str | None = None + store: bool | None = None + previous_response_id: str | None = None + include: list[str] | None = None + builtin_tools: list[str] | None = None + parse_tool_outputs: bool = False + auto_chain: bool = False + auto_chain_reasoning: bool = False + api_base: str | None = None + is_o1_model: bool = False + is_gpt4_model: bool = False - if provider is None: - provider = kwargs.pop("provider", "openai") + _client: Any = PrivateAttr(default=None) + _async_client: Any = PrivateAttr(default=None) + _last_response_id: str | None = PrivateAttr(default=None) + _last_reasoning_items: list[Any] | None = PrivateAttr(default=None) - self.interceptor = interceptor - # Client configuration attributes - self.organization = organization - self.project = project - self.max_retries = max_retries - self.default_headers = default_headers - self.default_query = default_query - self.client_params = client_params - self.timeout = timeout - self.base_url = base_url - self.api_base = kwargs.pop("api_base", None) - - super().__init__( - model=model, - temperature=temperature, - api_key=api_key or os.getenv("OPENAI_API_KEY"), - base_url=base_url, - timeout=timeout, - provider=provider, - **kwargs, - ) + @model_validator(mode="before") + @classmethod + def _normalize_openai_fields(cls, data: Any) -> Any: + if not isinstance(data, dict): + return data + if not data.get("provider"): + data["provider"] = "openai" + data["api_key"] = data.get("api_key") or os.getenv("OPENAI_API_KEY") + # Extract api_base from kwargs if present + if "api_base" not in data: + data["api_base"] = None + model = data.get("model", "gpt-4o") + data["is_o1_model"] = "o1" in model.lower() + data["is_gpt4_model"] = "gpt-4" in model.lower() + return data + @model_validator(mode="after") + def _init_clients(self) -> OpenAICompletion: client_config = self._get_client_params() if self.interceptor: transport = HTTPTransport(interceptor=self.interceptor) http_client = httpx.Client(transport=transport) client_config["http_client"] = http_client - self.client = OpenAI(**client_config) + self._client = OpenAI(**client_config) async_client_config = self._get_client_params() if self.interceptor: @@ -261,35 +253,8 @@ class OpenAICompletion(BaseLLM): async_http_client = httpx.AsyncClient(transport=async_transport) async_client_config["http_client"] = async_http_client - self.async_client = AsyncOpenAI(**async_client_config) - - # Completion parameters - self.top_p = top_p - self.frequency_penalty = frequency_penalty - self.presence_penalty = presence_penalty - self.max_tokens = max_tokens - self.max_completion_tokens = max_completion_tokens - self.seed = seed - self.stream = stream - self.response_format = response_format - self.logprobs = logprobs - self.top_logprobs = top_logprobs - self.reasoning_effort = reasoning_effort - self.is_o1_model = "o1" in model.lower() - self.is_gpt4_model = "gpt-4" in model.lower() - - # API selection and Responses API parameters - self.api = api - self.instructions = instructions - self.store = store - self.previous_response_id = previous_response_id - self.include = include - self.builtin_tools = builtin_tools - self.parse_tool_outputs = parse_tool_outputs - self.auto_chain = auto_chain - self.auto_chain_reasoning = auto_chain_reasoning - self._last_response_id: str | None = None - self._last_reasoning_items: list[Any] | None = None + self._async_client = AsyncOpenAI(**async_client_config) + return self @property def last_response_id(self) -> str | None: @@ -818,7 +783,7 @@ class OpenAICompletion(BaseLLM): ) -> str | ResponsesAPIResult | Any: """Handle non-streaming Responses API call.""" try: - response: Response = self.client.responses.create(**params) + response: Response = self._client.responses.create(**params) # Track response ID for auto-chaining if self.auto_chain and response.id: @@ -950,7 +915,7 @@ class OpenAICompletion(BaseLLM): ) -> str | ResponsesAPIResult | Any: """Handle async non-streaming Responses API call.""" try: - response: Response = await self.async_client.responses.create(**params) + response: Response = await self._async_client.responses.create(**params) # Track response ID for auto-chaining if self.auto_chain and response.id: @@ -1081,7 +1046,7 @@ class OpenAICompletion(BaseLLM): function_calls: list[dict[str, Any]] = [] final_response: Response | None = None - stream = self.client.responses.create(**params) + stream = self._client.responses.create(**params) response_id_stream = None for event in stream: @@ -1205,7 +1170,7 @@ class OpenAICompletion(BaseLLM): function_calls: list[dict[str, Any]] = [] final_response: Response | None = None - stream = await self.async_client.responses.create(**params) + stream = await self._async_client.responses.create(**params) response_id_stream = None async for event in stream: @@ -1595,7 +1560,7 @@ class OpenAICompletion(BaseLLM): parse_params = { k: v for k, v in params.items() if k != "response_format" } - parsed_response = self.client.beta.chat.completions.parse( + parsed_response = self._client.beta.chat.completions.parse( **parse_params, response_format=response_model, ) @@ -1618,7 +1583,7 @@ class OpenAICompletion(BaseLLM): ) return parsed_object - response: ChatCompletion = self.client.chat.completions.create(**params) + response: ChatCompletion = self._client.chat.completions.create(**params) usage = self._extract_openai_token_usage(response) @@ -1837,7 +1802,7 @@ class OpenAICompletion(BaseLLM): } stream: ChatCompletionStream[BaseModel] - with self.client.beta.chat.completions.stream( + with self._client.beta.chat.completions.stream( **parse_params, response_format=response_model ) as stream: for chunk in stream: @@ -1873,7 +1838,7 @@ class OpenAICompletion(BaseLLM): return "" completion_stream: Stream[ChatCompletionChunk] = ( - self.client.chat.completions.create(**params) + self._client.chat.completions.create(**params) ) usage_data = {"total_tokens": 0} @@ -1970,7 +1935,7 @@ class OpenAICompletion(BaseLLM): parse_params = { k: v for k, v in params.items() if k != "response_format" } - parsed_response = await self.async_client.beta.chat.completions.parse( + parsed_response = await self._async_client.beta.chat.completions.parse( **parse_params, response_format=response_model, ) @@ -1993,7 +1958,7 @@ class OpenAICompletion(BaseLLM): ) return parsed_object - response: ChatCompletion = await self.async_client.chat.completions.create( + response: ChatCompletion = await self._async_client.chat.completions.create( **params ) @@ -2111,7 +2076,7 @@ class OpenAICompletion(BaseLLM): if response_model: completion_stream: AsyncIterator[ ChatCompletionChunk - ] = await self.async_client.chat.completions.create(**params) + ] = await self._async_client.chat.completions.create(**params) accumulated_content = "" usage_data = {"total_tokens": 0} @@ -2164,7 +2129,7 @@ class OpenAICompletion(BaseLLM): stream: AsyncIterator[ ChatCompletionChunk - ] = await self.async_client.chat.completions.create(**params) + ] = await self._async_client.chat.completions.create(**params) usage_data = {"total_tokens": 0} @@ -2356,8 +2321,8 @@ class OpenAICompletion(BaseLLM): from crewai_files.uploaders.openai import OpenAIFileUploader return OpenAIFileUploader( - client=self.client, - async_client=self.async_client, + client=self._client, + async_client=self._async_client, ) except ImportError: return None diff --git a/lib/crewai/src/crewai/llms/providers/openai_compatible/completion.py b/lib/crewai/src/crewai/llms/providers/openai_compatible/completion.py index 293e73ff0..da4cfd03d 100644 --- a/lib/crewai/src/crewai/llms/providers/openai_compatible/completion.py +++ b/lib/crewai/src/crewai/llms/providers/openai_compatible/completion.py @@ -16,6 +16,8 @@ from dataclasses import dataclass, field import os from typing import Any +from pydantic import model_validator + from crewai.llms.providers.openai.completion import OpenAICompletion @@ -140,31 +142,13 @@ class OpenAICompatibleCompletion(OpenAICompletion): ) """ - def __init__( - self, - model: str, - provider: str, - api_key: str | None = None, - base_url: str | None = None, - default_headers: dict[str, str] | None = None, - **kwargs: Any, - ) -> None: - """Initialize OpenAI-compatible completion client. + @model_validator(mode="before") + @classmethod + def _resolve_provider_config(cls, data: Any) -> Any: + if not isinstance(data, dict): + return data - Args: - model: The model identifier. - provider: The provider name (must be in OPENAI_COMPATIBLE_PROVIDERS). - api_key: Optional API key override. If not provided, uses the - provider's configured environment variable. - base_url: Optional base URL override. If not provided, uses the - provider's configured default or environment variable. - default_headers: Optional headers to merge with provider defaults. - **kwargs: Additional arguments passed to OpenAICompletion. - - Raises: - ValueError: If the provider is not supported or required API key - is missing. - """ + provider = data.get("provider", "") config = OPENAI_COMPATIBLE_PROVIDERS.get(provider) if config is None: supported = ", ".join(sorted(OPENAI_COMPATIBLE_PROVIDERS.keys())) @@ -173,21 +157,15 @@ class OpenAICompatibleCompletion(OpenAICompletion): f"Supported providers: {supported}" ) - resolved_api_key = self._resolve_api_key(api_key, config, provider) - resolved_base_url = self._resolve_base_url(base_url, config, provider) - resolved_headers = self._resolve_headers(default_headers, config) - - super().__init__( - model=model, - provider=provider, - api_key=resolved_api_key, - base_url=resolved_base_url, - default_headers=resolved_headers, - **kwargs, + data["api_key"] = cls._resolve_api_key(data.get("api_key"), config, provider) + data["base_url"] = cls._resolve_base_url(data.get("base_url"), config, provider) + data["default_headers"] = cls._resolve_headers( + data.get("default_headers"), config ) + return data + @staticmethod def _resolve_api_key( - self, api_key: str | None, config: ProviderConfig, provider: str, @@ -220,8 +198,8 @@ class OpenAICompatibleCompletion(OpenAICompletion): return config.default_api_key + @staticmethod def _resolve_base_url( - self, base_url: str | None, config: ProviderConfig, provider: str, @@ -249,8 +227,8 @@ class OpenAICompatibleCompletion(OpenAICompletion): return resolved + @staticmethod def _resolve_headers( - self, headers: dict[str, str] | None, config: ProviderConfig, ) -> dict[str, str] | None: diff --git a/lib/crewai/tests/cassettes/llms/openai/test_openai_responses_api_cached_prompt_tokens_with_tools.yaml b/lib/crewai/tests/cassettes/llms/openai/test_openai_responses_api_cached_prompt_tokens_with_tools.yaml index c0db4ef9c..566b35116 100644 --- a/lib/crewai/tests/cassettes/llms/openai/test_openai_responses_api_cached_prompt_tokens_with_tools.yaml +++ b/lib/crewai/tests/cassettes/llms/openai/test_openai_responses_api_cached_prompt_tokens_with_tools.yaml @@ -1,7 +1,11 @@ interactions: - request: - body: '{"messages":[{"role":"system","content":"You are a helpful assistant that - uses tools. This is padding text to ensure the prompt is large enough for caching. + body: '{"input":[{"role":"user","content":"What is the weather in Tokyo?"}],"model":"gpt-4.1","instructions":"You + are a helpful assistant that uses tools. This is padding text to ensure the + prompt is large enough for caching. This is padding text to ensure the prompt + is large enough for caching. This is padding text to ensure the prompt is large + enough for caching. This is padding text to ensure the prompt is large enough + for caching. This is padding text to ensure the prompt is large enough for caching. This is padding text to ensure the prompt is large enough for caching. This is padding text to ensure the prompt is large enough for caching. This is padding text to ensure the prompt is large enough for caching. This is padding text @@ -68,13 +72,9 @@ interactions: for caching. This is padding text to ensure the prompt is large enough for caching. This is padding text to ensure the prompt is large enough for caching. This is padding text to ensure the prompt is large enough for caching. This is padding - text to ensure the prompt is large enough for caching. This is padding text - to ensure the prompt is large enough for caching. This is padding text to ensure - the prompt is large enough for caching. This is padding text to ensure the prompt - is large enough for caching. This is padding text to ensure the prompt is large - enough for caching. "},{"role":"user","content":"What is the weather in Tokyo?"}],"model":"gpt-4.1","tool_choice":"auto","tools":[{"type":"function","function":{"name":"get_weather","description":"Get - the current weather for a location","strict":true,"parameters":{"type":"object","properties":{"location":{"type":"string","description":"The - city name"}},"required":["location"],"additionalProperties":false}}}]}' + text to ensure the prompt is large enough for caching. ","tools":[{"type":"function","name":"get_weather","description":"Get + the current weather for a location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The + city name"}},"required":["location"]}}]}' headers: User-Agent: - X-USER-AGENT-XXX @@ -87,7 +87,7 @@ interactions: connection: - keep-alive content-length: - - '6158' + - '6065' content-type: - application/json host: @@ -109,26 +109,113 @@ interactions: x-stainless-runtime: - CPython x-stainless-runtime-version: - - 3.13.3 + - 3.13.12 method: POST - uri: https://api.openai.com/v1/chat/completions + uri: https://api.openai.com/v1/responses response: body: - string: "{\n \"id\": \"chatcmpl-D7mXQCgT3p3ViImkiqDiZGqLREQtp\",\n \"object\": - \"chat.completion\",\n \"created\": 1770747248,\n \"model\": \"gpt-4.1-2025-04-14\",\n - \ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\": - \"assistant\",\n \"content\": null,\n \"tool_calls\": [\n {\n - \ \"id\": \"call_9ZqMavn3J1fBnQEaqpYol0Bd\",\n \"type\": - \"function\",\n \"function\": {\n \"name\": \"get_weather\",\n - \ \"arguments\": \"{\\\"location\\\":\\\"Tokyo\\\"}\"\n }\n - \ }\n ],\n \"refusal\": null,\n \"annotations\": - []\n },\n \"logprobs\": null,\n \"finish_reason\": \"tool_calls\"\n - \ }\n ],\n \"usage\": {\n \"prompt_tokens\": 1187,\n \"completion_tokens\": - 14,\n \"total_tokens\": 1201,\n \"prompt_tokens_details\": {\n \"cached_tokens\": - 1152,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\": - {\n \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\": - 0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"service_tier\": - \"default\",\n \"system_fingerprint\": \"fp_8b22347a3e\"\n}\n" + string: "{\n \"id\": \"resp_0d68149bcc0d14810069caf464a4b48197bd9f098abb2f6303\",\n + \ \"object\": \"response\",\n \"created_at\": 1774908516,\n \"status\": + \"completed\",\n \"background\": false,\n \"billing\": {\n \"payer\": + \"developer\"\n },\n \"completed_at\": 1774908517,\n \"error\": null,\n + \ \"frequency_penalty\": 0.0,\n \"incomplete_details\": null,\n \"instructions\": + \"You are a helpful assistant that uses tools. This is padding text to ensure + the prompt is large enough for caching. This is padding text to ensure the + prompt is large enough for caching. This is padding text to ensure the prompt + is large enough for caching. This is padding text to ensure the prompt is + large enough for caching. This is padding text to ensure the prompt is large + enough for caching. This is padding text to ensure the prompt is large enough + for caching. This is padding text to ensure the prompt is large enough for + caching. This is padding text to ensure the prompt is large enough for caching. + This is padding text to ensure the prompt is large enough for caching. This + is padding text to ensure the prompt is large enough for caching. This is + padding text to ensure the prompt is large enough for caching. This is padding + text to ensure the prompt is large enough for caching. This is padding text + to ensure the prompt is large enough for caching. This is padding text to + ensure the prompt is large enough for caching. This is padding text to ensure + the prompt is large enough for caching. This is padding text to ensure the + prompt is large enough for caching. This is padding text to ensure the prompt + is large enough for caching. This is padding text to ensure the prompt is + large enough for caching. This is padding text to ensure the prompt is large + enough for caching. This is padding text to ensure the prompt is large enough + for caching. This is padding text to ensure the prompt is large enough for + caching. This is padding text to ensure the prompt is large enough for caching. + This is padding text to ensure the prompt is large enough for caching. This + is padding text to ensure the prompt is large enough for caching. This is + padding text to ensure the prompt is large enough for caching. This is padding + text to ensure the prompt is large enough for caching. This is padding text + to ensure the prompt is large enough for caching. This is padding text to + ensure the prompt is large enough for caching. This is padding text to ensure + the prompt is large enough for caching. This is padding text to ensure the + prompt is large enough for caching. This is padding text to ensure the prompt + is large enough for caching. This is padding text to ensure the prompt is + large enough for caching. This is padding text to ensure the prompt is large + enough for caching. This is padding text to ensure the prompt is large enough + for caching. This is padding text to ensure the prompt is large enough for + caching. This is padding text to ensure the prompt is large enough for caching. + This is padding text to ensure the prompt is large enough for caching. This + is padding text to ensure the prompt is large enough for caching. This is + padding text to ensure the prompt is large enough for caching. This is padding + text to ensure the prompt is large enough for caching. This is padding text + to ensure the prompt is large enough for caching. This is padding text to + ensure the prompt is large enough for caching. This is padding text to ensure + the prompt is large enough for caching. This is padding text to ensure the + prompt is large enough for caching. This is padding text to ensure the prompt + is large enough for caching. This is padding text to ensure the prompt is + large enough for caching. This is padding text to ensure the prompt is large + enough for caching. This is padding text to ensure the prompt is large enough + for caching. This is padding text to ensure the prompt is large enough for + caching. This is padding text to ensure the prompt is large enough for caching. + This is padding text to ensure the prompt is large enough for caching. This + is padding text to ensure the prompt is large enough for caching. This is + padding text to ensure the prompt is large enough for caching. This is padding + text to ensure the prompt is large enough for caching. This is padding text + to ensure the prompt is large enough for caching. This is padding text to + ensure the prompt is large enough for caching. This is padding text to ensure + the prompt is large enough for caching. This is padding text to ensure the + prompt is large enough for caching. This is padding text to ensure the prompt + is large enough for caching. This is padding text to ensure the prompt is + large enough for caching. This is padding text to ensure the prompt is large + enough for caching. This is padding text to ensure the prompt is large enough + for caching. This is padding text to ensure the prompt is large enough for + caching. This is padding text to ensure the prompt is large enough for caching. + This is padding text to ensure the prompt is large enough for caching. This + is padding text to ensure the prompt is large enough for caching. This is + padding text to ensure the prompt is large enough for caching. This is padding + text to ensure the prompt is large enough for caching. This is padding text + to ensure the prompt is large enough for caching. This is padding text to + ensure the prompt is large enough for caching. This is padding text to ensure + the prompt is large enough for caching. This is padding text to ensure the + prompt is large enough for caching. This is padding text to ensure the prompt + is large enough for caching. This is padding text to ensure the prompt is + large enough for caching. This is padding text to ensure the prompt is large + enough for caching. This is padding text to ensure the prompt is large enough + for caching. This is padding text to ensure the prompt is large enough for + caching. This is padding text to ensure the prompt is large enough for caching. + This is padding text to ensure the prompt is large enough for caching. This + is padding text to ensure the prompt is large enough for caching. \",\n \"max_output_tokens\": + null,\n \"max_tool_calls\": null,\n \"model\": \"gpt-4.1-2025-04-14\",\n + \ \"output\": [\n {\n \"id\": \"fc_0d68149bcc0d14810069caf46568088197a33be67f16a1fa09\",\n + \ \"type\": \"function_call\",\n \"status\": \"completed\",\n \"arguments\": + \"{\\\"location\\\":\\\"Tokyo\\\"}\",\n \"call_id\": \"call_74rwmYse0DE4JFaFGyAFx9bu\",\n + \ \"name\": \"get_weather\"\n }\n ],\n \"parallel_tool_calls\": true,\n + \ \"presence_penalty\": 0.0,\n \"previous_response_id\": null,\n \"prompt_cache_key\": + null,\n \"prompt_cache_retention\": null,\n \"reasoning\": {\n \"effort\": + null,\n \"summary\": null\n },\n \"safety_identifier\": null,\n \"service_tier\": + \"default\",\n \"store\": true,\n \"temperature\": 1.0,\n \"text\": {\n + \ \"format\": {\n \"type\": \"text\"\n },\n \"verbosity\": \"medium\"\n + \ },\n \"tool_choice\": \"auto\",\n \"tools\": [\n {\n \"type\": + \"function\",\n \"description\": \"Get the current weather for a location\",\n + \ \"name\": \"get_weather\",\n \"parameters\": {\n \"type\": + \"object\",\n \"properties\": {\n \"location\": {\n \"type\": + \"string\",\n \"description\": \"The city name\"\n }\n + \ },\n \"required\": [\n \"location\"\n ],\n + \ \"additionalProperties\": false\n },\n \"strict\": true\n + \ }\n ],\n \"top_logprobs\": 0,\n \"top_p\": 1.0,\n \"truncation\": + \"disabled\",\n \"usage\": {\n \"input_tokens\": 1185,\n \"input_tokens_details\": + {\n \"cached_tokens\": 0\n },\n \"output_tokens\": 15,\n \"output_tokens_details\": + {\n \"reasoning_tokens\": 0\n },\n \"total_tokens\": 1200\n },\n + \ \"user\": null,\n \"metadata\": {}\n}" headers: CF-RAY: - CF-RAY-XXX @@ -137,7 +224,7 @@ interactions: Content-Type: - application/json Date: - - Tue, 10 Feb 2026 18:14:08 GMT + - Mon, 30 Mar 2026 22:08:37 GMT Server: - cloudflare Strict-Transport-Security: @@ -146,8 +233,6 @@ interactions: - chunked X-Content-Type-Options: - X-CONTENT-TYPE-XXX - access-control-expose-headers: - - ACCESS-CONTROL-XXX alt-svc: - h3=":443"; ma=86400 cf-cache-status: @@ -155,15 +240,13 @@ interactions: openai-organization: - OPENAI-ORG-XXX openai-processing-ms: - - '484' + - '1085' openai-project: - OPENAI-PROJECT-XXX openai-version: - '2020-10-01' set-cookie: - SET-COOKIE-XXX - x-openai-proxy-wasm: - - v0.1 x-ratelimit-limit-requests: - X-RATELIMIT-LIMIT-REQUESTS-XXX x-ratelimit-limit-tokens: @@ -182,8 +265,12 @@ interactions: code: 200 message: OK - request: - body: '{"messages":[{"role":"system","content":"You are a helpful assistant that - uses tools. This is padding text to ensure the prompt is large enough for caching. + body: '{"input":[{"role":"user","content":"What is the weather in Paris?"}],"model":"gpt-4.1","instructions":"You + are a helpful assistant that uses tools. This is padding text to ensure the + prompt is large enough for caching. This is padding text to ensure the prompt + is large enough for caching. This is padding text to ensure the prompt is large + enough for caching. This is padding text to ensure the prompt is large enough + for caching. This is padding text to ensure the prompt is large enough for caching. This is padding text to ensure the prompt is large enough for caching. This is padding text to ensure the prompt is large enough for caching. This is padding text to ensure the prompt is large enough for caching. This is padding text @@ -250,13 +337,9 @@ interactions: for caching. This is padding text to ensure the prompt is large enough for caching. This is padding text to ensure the prompt is large enough for caching. This is padding text to ensure the prompt is large enough for caching. This is padding - text to ensure the prompt is large enough for caching. This is padding text - to ensure the prompt is large enough for caching. This is padding text to ensure - the prompt is large enough for caching. This is padding text to ensure the prompt - is large enough for caching. This is padding text to ensure the prompt is large - enough for caching. "},{"role":"user","content":"What is the weather in Paris?"}],"model":"gpt-4.1","tool_choice":"auto","tools":[{"type":"function","function":{"name":"get_weather","description":"Get - the current weather for a location","strict":true,"parameters":{"type":"object","properties":{"location":{"type":"string","description":"The - city name"}},"required":["location"],"additionalProperties":false}}}]}' + text to ensure the prompt is large enough for caching. ","tools":[{"type":"function","name":"get_weather","description":"Get + the current weather for a location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The + city name"}},"required":["location"]}}]}' headers: User-Agent: - X-USER-AGENT-XXX @@ -269,7 +352,7 @@ interactions: connection: - keep-alive content-length: - - '6158' + - '6065' content-type: - application/json cookie: @@ -293,26 +376,113 @@ interactions: x-stainless-runtime: - CPython x-stainless-runtime-version: - - 3.13.3 + - 3.13.12 method: POST - uri: https://api.openai.com/v1/chat/completions + uri: https://api.openai.com/v1/responses response: body: - string: "{\n \"id\": \"chatcmpl-D7mXR8k9vk8TlGvGXlrQSI7iNeAN1\",\n \"object\": - \"chat.completion\",\n \"created\": 1770747249,\n \"model\": \"gpt-4.1-2025-04-14\",\n - \ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\": - \"assistant\",\n \"content\": null,\n \"tool_calls\": [\n {\n - \ \"id\": \"call_6PeUBlRPG8JcV2lspmLjJbnn\",\n \"type\": - \"function\",\n \"function\": {\n \"name\": \"get_weather\",\n - \ \"arguments\": \"{\\\"location\\\":\\\"Paris\\\"}\"\n }\n - \ }\n ],\n \"refusal\": null,\n \"annotations\": - []\n },\n \"logprobs\": null,\n \"finish_reason\": \"tool_calls\"\n - \ }\n ],\n \"usage\": {\n \"prompt_tokens\": 1187,\n \"completion_tokens\": - 14,\n \"total_tokens\": 1201,\n \"prompt_tokens_details\": {\n \"cached_tokens\": - 1152,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\": - {\n \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\": - 0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"service_tier\": - \"default\",\n \"system_fingerprint\": \"fp_8b22347a3e\"\n}\n" + string: "{\n \"id\": \"resp_0525bf798202137e0069caf465ee3c8196aa7c83da1c369eb7\",\n + \ \"object\": \"response\",\n \"created_at\": 1774908517,\n \"status\": + \"completed\",\n \"background\": false,\n \"billing\": {\n \"payer\": + \"developer\"\n },\n \"completed_at\": 1774908518,\n \"error\": null,\n + \ \"frequency_penalty\": 0.0,\n \"incomplete_details\": null,\n \"instructions\": + \"You are a helpful assistant that uses tools. This is padding text to ensure + the prompt is large enough for caching. This is padding text to ensure the + prompt is large enough for caching. This is padding text to ensure the prompt + is large enough for caching. This is padding text to ensure the prompt is + large enough for caching. This is padding text to ensure the prompt is large + enough for caching. This is padding text to ensure the prompt is large enough + for caching. This is padding text to ensure the prompt is large enough for + caching. This is padding text to ensure the prompt is large enough for caching. + This is padding text to ensure the prompt is large enough for caching. This + is padding text to ensure the prompt is large enough for caching. This is + padding text to ensure the prompt is large enough for caching. This is padding + text to ensure the prompt is large enough for caching. This is padding text + to ensure the prompt is large enough for caching. This is padding text to + ensure the prompt is large enough for caching. This is padding text to ensure + the prompt is large enough for caching. This is padding text to ensure the + prompt is large enough for caching. This is padding text to ensure the prompt + is large enough for caching. This is padding text to ensure the prompt is + large enough for caching. This is padding text to ensure the prompt is large + enough for caching. This is padding text to ensure the prompt is large enough + for caching. This is padding text to ensure the prompt is large enough for + caching. This is padding text to ensure the prompt is large enough for caching. + This is padding text to ensure the prompt is large enough for caching. This + is padding text to ensure the prompt is large enough for caching. This is + padding text to ensure the prompt is large enough for caching. This is padding + text to ensure the prompt is large enough for caching. This is padding text + to ensure the prompt is large enough for caching. This is padding text to + ensure the prompt is large enough for caching. This is padding text to ensure + the prompt is large enough for caching. This is padding text to ensure the + prompt is large enough for caching. This is padding text to ensure the prompt + is large enough for caching. This is padding text to ensure the prompt is + large enough for caching. This is padding text to ensure the prompt is large + enough for caching. This is padding text to ensure the prompt is large enough + for caching. This is padding text to ensure the prompt is large enough for + caching. This is padding text to ensure the prompt is large enough for caching. + This is padding text to ensure the prompt is large enough for caching. This + is padding text to ensure the prompt is large enough for caching. This is + padding text to ensure the prompt is large enough for caching. This is padding + text to ensure the prompt is large enough for caching. This is padding text + to ensure the prompt is large enough for caching. This is padding text to + ensure the prompt is large enough for caching. This is padding text to ensure + the prompt is large enough for caching. This is padding text to ensure the + prompt is large enough for caching. This is padding text to ensure the prompt + is large enough for caching. This is padding text to ensure the prompt is + large enough for caching. This is padding text to ensure the prompt is large + enough for caching. This is padding text to ensure the prompt is large enough + for caching. This is padding text to ensure the prompt is large enough for + caching. This is padding text to ensure the prompt is large enough for caching. + This is padding text to ensure the prompt is large enough for caching. This + is padding text to ensure the prompt is large enough for caching. This is + padding text to ensure the prompt is large enough for caching. This is padding + text to ensure the prompt is large enough for caching. This is padding text + to ensure the prompt is large enough for caching. This is padding text to + ensure the prompt is large enough for caching. This is padding text to ensure + the prompt is large enough for caching. This is padding text to ensure the + prompt is large enough for caching. This is padding text to ensure the prompt + is large enough for caching. This is padding text to ensure the prompt is + large enough for caching. This is padding text to ensure the prompt is large + enough for caching. This is padding text to ensure the prompt is large enough + for caching. This is padding text to ensure the prompt is large enough for + caching. This is padding text to ensure the prompt is large enough for caching. + This is padding text to ensure the prompt is large enough for caching. This + is padding text to ensure the prompt is large enough for caching. This is + padding text to ensure the prompt is large enough for caching. This is padding + text to ensure the prompt is large enough for caching. This is padding text + to ensure the prompt is large enough for caching. This is padding text to + ensure the prompt is large enough for caching. This is padding text to ensure + the prompt is large enough for caching. This is padding text to ensure the + prompt is large enough for caching. This is padding text to ensure the prompt + is large enough for caching. This is padding text to ensure the prompt is + large enough for caching. This is padding text to ensure the prompt is large + enough for caching. This is padding text to ensure the prompt is large enough + for caching. This is padding text to ensure the prompt is large enough for + caching. This is padding text to ensure the prompt is large enough for caching. + This is padding text to ensure the prompt is large enough for caching. This + is padding text to ensure the prompt is large enough for caching. \",\n \"max_output_tokens\": + null,\n \"max_tool_calls\": null,\n \"model\": \"gpt-4.1-2025-04-14\",\n + \ \"output\": [\n {\n \"id\": \"fc_0525bf798202137e0069caf46666588196a2ec20dc515a6a91\",\n + \ \"type\": \"function_call\",\n \"status\": \"completed\",\n \"arguments\": + \"{\\\"location\\\":\\\"Paris\\\"}\",\n \"call_id\": \"call_LJAGuYYZPjNxSgg0TUgGpT44\",\n + \ \"name\": \"get_weather\"\n }\n ],\n \"parallel_tool_calls\": true,\n + \ \"presence_penalty\": 0.0,\n \"previous_response_id\": null,\n \"prompt_cache_key\": + null,\n \"prompt_cache_retention\": null,\n \"reasoning\": {\n \"effort\": + null,\n \"summary\": null\n },\n \"safety_identifier\": null,\n \"service_tier\": + \"default\",\n \"store\": true,\n \"temperature\": 1.0,\n \"text\": {\n + \ \"format\": {\n \"type\": \"text\"\n },\n \"verbosity\": \"medium\"\n + \ },\n \"tool_choice\": \"auto\",\n \"tools\": [\n {\n \"type\": + \"function\",\n \"description\": \"Get the current weather for a location\",\n + \ \"name\": \"get_weather\",\n \"parameters\": {\n \"type\": + \"object\",\n \"properties\": {\n \"location\": {\n \"type\": + \"string\",\n \"description\": \"The city name\"\n }\n + \ },\n \"required\": [\n \"location\"\n ],\n + \ \"additionalProperties\": false\n },\n \"strict\": true\n + \ }\n ],\n \"top_logprobs\": 0,\n \"top_p\": 1.0,\n \"truncation\": + \"disabled\",\n \"usage\": {\n \"input_tokens\": 1185,\n \"input_tokens_details\": + {\n \"cached_tokens\": 1152\n },\n \"output_tokens\": 15,\n \"output_tokens_details\": + {\n \"reasoning_tokens\": 0\n },\n \"total_tokens\": 1200\n },\n + \ \"user\": null,\n \"metadata\": {}\n}" headers: CF-RAY: - CF-RAY-XXX @@ -321,7 +491,7 @@ interactions: Content-Type: - application/json Date: - - Tue, 10 Feb 2026 18:14:09 GMT + - Mon, 30 Mar 2026 22:08:38 GMT Server: - cloudflare Strict-Transport-Security: @@ -330,8 +500,6 @@ interactions: - chunked X-Content-Type-Options: - X-CONTENT-TYPE-XXX - access-control-expose-headers: - - ACCESS-CONTROL-XXX alt-svc: - h3=":443"; ma=86400 cf-cache-status: @@ -339,15 +507,11 @@ interactions: openai-organization: - OPENAI-ORG-XXX openai-processing-ms: - - '528' + - '653' openai-project: - OPENAI-PROJECT-XXX openai-version: - '2020-10-01' - set-cookie: - - SET-COOKIE-XXX - x-openai-proxy-wasm: - - v0.1 x-ratelimit-limit-requests: - X-RATELIMIT-LIMIT-REQUESTS-XXX x-ratelimit-limit-tokens: diff --git a/lib/crewai/tests/llms/anthropic/test_anthropic.py b/lib/crewai/tests/llms/anthropic/test_anthropic.py index 89418ca0e..e8f16af5a 100644 --- a/lib/crewai/tests/llms/anthropic/test_anthropic.py +++ b/lib/crewai/tests/llms/anthropic/test_anthropic.py @@ -125,8 +125,8 @@ def test_anthropic_specific_parameters(): assert isinstance(llm, AnthropicCompletion) assert llm.stop_sequences == ["Human:", "Assistant:"] assert llm.stream == True - assert llm.client.max_retries == 5 - assert llm.client.timeout == 60 + assert llm._client.max_retries == 5 + assert llm._client.timeout == 60 def test_anthropic_completion_call(): @@ -563,8 +563,8 @@ def test_anthropic_environment_variable_api_key(): with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-anthropic-key"}): llm = LLM(model="anthropic/claude-3-5-sonnet-20241022") - assert llm.client is not None - assert hasattr(llm.client, 'messages') + assert llm._client is not None + assert hasattr(llm._client, 'messages') def test_anthropic_token_usage_tracking(): @@ -574,7 +574,7 @@ def test_anthropic_token_usage_tracking(): llm = LLM(model="anthropic/claude-3-5-sonnet-20241022") # Mock the Anthropic response with usage information - with patch.object(llm.client.messages, 'create') as mock_create: + with patch.object(llm._client.messages, 'create') as mock_create: mock_response = MagicMock() mock_response.content = [MagicMock(text="test response")] mock_response.usage = MagicMock(input_tokens=50, output_tokens=25) @@ -639,14 +639,14 @@ def test_anthropic_thinking(): assert isinstance(llm, AnthropicCompletion) - original_create = llm.client.messages.create + original_create = llm._client.messages.create captured_params = {} def capture_and_call(**kwargs): captured_params.update(kwargs) return original_create(**kwargs) - with patch.object(llm.client.messages, 'create', side_effect=capture_and_call): + with patch.object(llm._client.messages, 'create', side_effect=capture_and_call): result = llm.call("What is the weather in Tokyo?") assert result is not None @@ -677,14 +677,14 @@ def test_anthropic_thinking_blocks_preserved_across_turns(): assert isinstance(llm, AnthropicCompletion) # Capture all messages.create calls to verify thinking blocks are included - original_create = llm.client.messages.create + original_create = llm._client.messages.create captured_calls = [] def capture_and_call(**kwargs): captured_calls.append(kwargs) return original_create(**kwargs) - with patch.object(llm.client.messages, 'create', side_effect=capture_and_call): + with patch.object(llm._client.messages, 'create', side_effect=capture_and_call): # First call - establishes context and generates thinking blocks messages = [{"role": "user", "content": "What is 2+2?"}] first_result = llm.call(messages) @@ -695,8 +695,8 @@ def test_anthropic_thinking_blocks_preserved_across_turns(): assert len(first_result) > 0 # Verify thinking blocks were stored after first response - assert len(llm.previous_thinking_blocks) > 0, "No thinking blocks stored after first call" - first_thinking = llm.previous_thinking_blocks[0] + assert len(llm._previous_thinking_blocks) > 0, "No thinking blocks stored after first call" + first_thinking = llm._previous_thinking_blocks[0] assert first_thinking["type"] == "thinking" assert "thinking" in first_thinking assert "signature" in first_thinking diff --git a/lib/crewai/tests/llms/azure/test_azure.py b/lib/crewai/tests/llms/azure/test_azure.py index d25b607a8..a0da30998 100644 --- a/lib/crewai/tests/llms/azure/test_azure.py +++ b/lib/crewai/tests/llms/azure/test_azure.py @@ -66,7 +66,7 @@ def test_azure_tool_use_conversation_flow(): available_functions = {"get_weather": mock_weather_tool} # Mock the Azure client responses - with patch.object(completion.client, 'complete') as mock_complete: + with patch.object(completion._client, 'complete') as mock_complete: # Mock tool call in response with proper type mock_tool_call = MagicMock(spec=ChatCompletionsToolCall) mock_tool_call.function.name = "get_weather" @@ -698,7 +698,7 @@ def test_azure_environment_variable_endpoint(): }): llm = LLM(model="azure/gpt-4") - assert llm.client is not None + assert llm._client is not None assert llm.endpoint == "https://test.openai.azure.com/openai/deployments/gpt-4" @@ -709,7 +709,7 @@ def test_azure_token_usage_tracking(): llm = LLM(model="azure/gpt-4") # Mock the Azure response with usage information - with patch.object(llm.client, 'complete') as mock_complete: + with patch.object(llm._client, 'complete') as mock_complete: mock_message = MagicMock() mock_message.content = "test response" mock_message.tool_calls = None @@ -747,7 +747,7 @@ def test_azure_http_error_handling(): llm = LLM(model="azure/gpt-4") # Mock an HTTP error - with patch.object(llm.client, 'complete') as mock_complete: + with patch.object(llm._client, 'complete') as mock_complete: mock_complete.side_effect = HttpResponseError(message="Rate limit exceeded", response=MagicMock(status_code=429)) with pytest.raises(HttpResponseError): @@ -966,7 +966,7 @@ def test_azure_improved_error_messages(): llm = LLM(model="azure/gpt-4") - with patch.object(llm.client, 'complete') as mock_complete: + with patch.object(llm._client, 'complete') as mock_complete: error_401 = HttpResponseError(message="Unauthorized") error_401.status_code = 401 mock_complete.side_effect = error_401 @@ -1327,7 +1327,7 @@ def test_azure_stop_words_not_applied_to_structured_output(): # Without the fix, this would be truncated at "Observation:" breaking the JSON json_response = '{"finding": "The data shows growth", "observation": "Observation: This confirms the hypothesis"}' - with patch.object(llm.client, 'complete') as mock_complete: + with patch.object(llm._client, 'complete') as mock_complete: mock_message = MagicMock() mock_message.content = json_response mock_message.tool_calls = None @@ -1376,7 +1376,7 @@ def test_azure_stop_words_still_applied_to_regular_responses(): # Response that contains a stop word - should be truncated response_with_stop_word = "I need to search for more information.\n\nAction: search\nObservation: Found results" - with patch.object(llm.client, 'complete') as mock_complete: + with patch.object(llm._client, 'complete') as mock_complete: mock_message = MagicMock() mock_message.content = response_with_stop_word mock_message.tool_calls = None diff --git a/lib/crewai/tests/llms/bedrock/test_bedrock.py b/lib/crewai/tests/llms/bedrock/test_bedrock.py index fe18a8349..76958bf86 100644 --- a/lib/crewai/tests/llms/bedrock/test_bedrock.py +++ b/lib/crewai/tests/llms/bedrock/test_bedrock.py @@ -674,7 +674,7 @@ def test_bedrock_token_usage_tracking(): llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0") # Mock the Bedrock response with usage information - with patch.object(llm.client, 'converse') as mock_converse: + with patch.object(llm._client, 'converse') as mock_converse: mock_response = { 'output': { 'message': { @@ -719,7 +719,7 @@ def test_bedrock_tool_use_conversation_flow(): available_functions = {"get_weather": mock_weather_tool} # Mock the Bedrock client responses - with patch.object(llm.client, 'converse') as mock_converse: + with patch.object(llm._client, 'converse') as mock_converse: # First response: tool use request tool_use_response = { 'output': { @@ -805,7 +805,7 @@ def test_bedrock_client_error_handling(): llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0") # Test ValidationException - with patch.object(llm.client, 'converse') as mock_converse: + with patch.object(llm._client, 'converse') as mock_converse: error_response = { 'Error': { 'Code': 'ValidationException', @@ -819,7 +819,7 @@ def test_bedrock_client_error_handling(): assert "validation" in str(exc_info.value).lower() # Test ThrottlingException - with patch.object(llm.client, 'converse') as mock_converse: + with patch.object(llm._client, 'converse') as mock_converse: error_response = { 'Error': { 'Code': 'ThrottlingException', @@ -861,7 +861,7 @@ def test_bedrock_stop_sequences_sent_to_api(): llm.stop = ["\nObservation:", "\nThought:"] # Patch the API call to capture parameters without making real call - with patch.object(llm.client, 'converse') as mock_converse: + with patch.object(llm._client, 'converse') as mock_converse: mock_response = { 'output': { 'message': { diff --git a/lib/crewai/tests/llms/google/test_google.py b/lib/crewai/tests/llms/google/test_google.py index bd62e3343..d0553c7db 100644 --- a/lib/crewai/tests/llms/google/test_google.py +++ b/lib/crewai/tests/llms/google/test_google.py @@ -556,8 +556,8 @@ def test_gemini_environment_variable_api_key(): with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-google-key"}): llm = LLM(model="google/gemini-2.0-flash-001") - assert llm.client is not None - assert hasattr(llm.client, 'models') + assert llm._client is not None + assert hasattr(llm._client, 'models') assert llm.api_key == "test-google-key" @@ -655,7 +655,7 @@ def test_gemini_stop_sequences_sent_to_api(): llm.stop = ["\nObservation:", "\nThought:"] # Patch the API call to capture parameters without making real call - with patch.object(llm.client.models, 'generate_content') as mock_generate: + with patch.object(llm._client.models, 'generate_content') as mock_generate: mock_response = MagicMock() mock_response.text = "Hello" mock_response.candidates = [] diff --git a/lib/crewai/tests/llms/openai/test_openai.py b/lib/crewai/tests/llms/openai/test_openai.py index 1b72a19c7..3dada2d85 100644 --- a/lib/crewai/tests/llms/openai/test_openai.py +++ b/lib/crewai/tests/llms/openai/test_openai.py @@ -371,11 +371,11 @@ def test_openai_client_setup_with_extra_arguments(): assert llm.top_p == 0.5 # Check that client parameters are properly configured - assert llm.client.max_retries == 3 - assert llm.client.timeout == 30 + assert llm._client.max_retries == 3 + assert llm._client.timeout == 30 # Test that parameters are properly used in API calls - with patch.object(llm.client.chat.completions, 'create') as mock_create: + with patch.object(llm._client.chat.completions, 'create') as mock_create: mock_create.return_value = MagicMock( choices=[MagicMock(message=MagicMock(content="test response", tool_calls=None))], usage=MagicMock(prompt_tokens=10, completion_tokens=20, total_tokens=30) @@ -396,7 +396,7 @@ def test_extra_arguments_are_passed_to_openai_completion(): """ llm = LLM(model="gpt-4o", temperature=0.7, max_tokens=1000, top_p=0.5, max_retries=3) - with patch.object(llm.client.chat.completions, 'create') as mock_create: + with patch.object(llm._client.chat.completions, 'create') as mock_create: mock_create.return_value = MagicMock( choices=[MagicMock(message=MagicMock(content="test response", tool_calls=None))], usage=MagicMock(prompt_tokens=10, completion_tokens=20, total_tokens=30) @@ -507,7 +507,7 @@ def test_openai_streaming_with_response_model(): llm = LLM(model="openai/gpt-4o", stream=True) - with patch.object(llm.client.beta.chat.completions, "stream") as mock_stream: + with patch.object(llm._client.beta.chat.completions, "stream") as mock_stream: # Create mock chunks with content.delta event structure mock_chunk1 = MagicMock() mock_chunk1.type = "content.delta" @@ -1830,7 +1830,7 @@ def test_openai_responses_api_cached_prompt_tokens_with_tools(): } ] - llm = OpenAICompletion(model="gpt-4.1", api='response') + llm = OpenAICompletion(model="gpt-4.1", api='responses') # First call with tool llm.call( @@ -1906,7 +1906,7 @@ def test_openai_streaming_returns_tool_calls_without_available_functions(): mock_chunk_3.id = "chatcmpl-1" with patch.object( - llm.client.chat.completions, "create", return_value=iter([mock_chunk_1, mock_chunk_2, mock_chunk_3]) + llm._client.chat.completions, "create", return_value=iter([mock_chunk_1, mock_chunk_2, mock_chunk_3]) ): result = llm.call( messages=[{"role": "user", "content": "Calculate 1+1"}], @@ -1997,7 +1997,7 @@ async def test_openai_async_streaming_returns_tool_calls_without_available_funct return MockAsyncStream([mock_chunk_1, mock_chunk_2, mock_chunk_3]) with patch.object( - llm.async_client.chat.completions, "create", side_effect=mock_create + llm._async_client.chat.completions, "create", side_effect=mock_create ): result = await llm.acall( messages=[{"role": "user", "content": "Calculate 1+1"}], diff --git a/lib/crewai/tests/test_project.py b/lib/crewai/tests/test_project.py index 6334cb777..9d7f332da 100644 --- a/lib/crewai/tests/test_project.py +++ b/lib/crewai/tests/test_project.py @@ -1,5 +1,5 @@ from typing import Any, ClassVar -from unittest.mock import Mock, patch +from unittest.mock import Mock, create_autospec, patch import pytest from crewai.agent import Agent @@ -372,8 +372,11 @@ def test_internal_crew_with_mcp(): mock_adapter = Mock() mock_adapter.tools = ToolCollection([simple_tool, another_simple_tool]) - mock_llm = Mock() - mock_llm.__class__ = BaseLLM + class _StubLLM(BaseLLM): + def call(self, *a: Any, **kw: Any) -> str: + return "" + + mock_llm = create_autospec(_StubLLM(model="stub"), instance=True) with ( patch("crewai_tools.MCPServerAdapter", return_value=mock_adapter) as adapter_mock,