chore: remove duplication in azure client

Author: Greyson LaLonde
Date: 2025-11-11 18:01:27 -05:00
parent 93f1fbd75e
commit 8b83bf3e54
22 changed files with 370 additions and 209 deletions

View File

@@ -12,10 +12,11 @@ import json
import logging
import os
import re
from typing import TYPE_CHECKING, Any, ClassVar, Final
from typing import TYPE_CHECKING, Any, Final
from dotenv import load_dotenv
import httpx
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from pydantic import BaseModel, Field, field_validator
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.llm_events import (
@@ -42,6 +43,8 @@ if TYPE_CHECKING:
from crewai.utilities.types import LLMMessage
load_dotenv()
DEFAULT_CONTEXT_WINDOW_SIZE: Final[int] = 4096
DEFAULT_SUPPORTS_STOP_WORDS: Final[bool] = True
_JSON_EXTRACTION_PATTERN: Final[re.Pattern[str]] = re.compile(r"\{.*}", re.DOTALL)
@@ -65,10 +68,6 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
stop: A list of stop sequences that the LLM should use to stop generation.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(
extra="allow", populate_by_name=True
)
# Core fields
model: str = Field(..., description="The model identifier/name")
temperature: float | None = Field(
@@ -100,7 +99,7 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
"cached_prompt_tokens": 0,
}
@field_validator("api_key", mode="before")
@field_validator("api_key", mode="after")
@classmethod
def _validate_api_key(cls, value: str | None) -> str | None:
"""Validate API key for authentication.
@@ -137,37 +136,6 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
return value
return []
@model_validator(mode="before")
@classmethod
def _extract_stop_and_validate(cls, values: dict[str, Any]) -> dict[str, Any]:
"""Extract and normalize stop sequences before model initialization.
Args:
values: Input values dictionary
Returns:
Processed values dictionary
"""
if not values.get("model"):
raise ValueError("Model name is required and cannot be empty")
stop = values.get("stop") or values.get("stop_sequences")
if stop is None:
values["stop"] = []
elif isinstance(stop, str):
values["stop"] = [stop]
elif isinstance(stop, list):
values["stop"] = stop
else:
values["stop"] = []
values.pop("stop_sequences", None)
if "provider" not in values or values["provider"] is None:
values["provider"] = "openai"
return values
@property
def additional_params(self) -> dict[str, Any]:
"""Get additional parameters stored as extra fields.
@@ -190,20 +158,6 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
self.__pydantic_extra__ = {}
self.__pydantic_extra__.update(value)
def model_post_init(self, __context: Any) -> None:
"""Initialize token usage tracking after model initialization.
Args:
__context: Pydantic context (unused)
"""
self._token_usage = {
"total_tokens": 0,
"prompt_tokens": 0,
"completion_tokens": 0,
"successful_requests": 0,
"cached_prompt_tokens": 0,
}
@abstractmethod
def call(
self,

View File

@@ -0,0 +1,14 @@
from crewai.llm.constants import SupportedNativeProviders
PROVIDER_MAPPING: dict[str, SupportedNativeProviders] = {
"openai": "openai",
"anthropic": "anthropic",
"claude": "anthropic",
"azure": "azure",
"azure_openai": "azure",
"google": "gemini",
"gemini": "gemini",
"bedrock": "bedrock",
"aws": "bedrock",
}
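
For reference, this is the same table the metaclass consults when it sees a prefixed model string; a minimal sketch of that lookup, using "azure/gpt-4" purely as an illustrative input:

from crewai.llm.internal.constants import PROVIDER_MAPPING

prefix, _, model_part = "azure/gpt-4".partition("/")
canonical = PROVIDER_MAPPING.get(prefix.lower())
# canonical == "azure"; an unrecognized prefix simply returns None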

View File

@@ -9,6 +9,7 @@ from __future__ import annotations
import logging
from typing import Any, cast
from pydantic import ConfigDict
from pydantic._internal._model_construction import ModelMetaclass
from crewai.llm.constants import (
@@ -21,6 +22,7 @@ from crewai.llm.constants import (
SupportedModels,
SupportedNativeProviders,
)
from crewai.llm.internal.constants import PROVIDER_MAPPING
class LLMMeta(ModelMetaclass):
@@ -30,6 +32,41 @@ class LLMMeta(ModelMetaclass):
native provider implementation based on the model parameter.
"""
def __new__(
mcs,
name: str,
bases: tuple[type, ...],
namespace: dict[str, Any],
**kwargs: Any,
) -> type:
"""Create new LLM class with proper model_config for custom LLMs.
Args:
name: Class name
bases: Base classes
namespace: Class namespace
**kwargs: Additional arguments
Returns:
New class
"""
if name != "BaseLLM" and any(
base.__name__ in ("BaseLLM", "LLM") for base in bases
):
if "model_config" not in namespace:
namespace["model_config"] = ConfigDict(
extra="allow", populate_by_name=True
)
elif isinstance(namespace["model_config"], dict):
config_dict = cast(
ConfigDict, cast(object, dict(namespace["model_config"]))
)
config_dict.setdefault("extra", "allow")
config_dict.setdefault("populate_by_name", True)
namespace["model_config"] = ConfigDict(**config_dict)
return super().__new__(mcs, name, bases, namespace)
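
A rough sketch of what the injected config means for user-defined subclasses, assuming BaseLLM is imported from crewai.llm.base_llm as the migration notes below describe; MyCustomLLM is purely illustrative:

from crewai.llm.base_llm import BaseLLM

class MyCustomLLM(BaseLLM):
    """Declares no model_config; the metaclass injects one at class creation."""

assert MyCustomLLM.model_config["extra"] == "allow"
assert MyCustomLLM.model_config["populate_by_name"] is True
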
def __call__(cls, *args: Any, **kwargs: Any) -> Any: # noqa: N805
"""Route to appropriate provider implementation at instantiation time.
@@ -57,7 +94,7 @@ class LLMMeta(ModelMetaclass):
if args and not kwargs.get("model"):
kwargs["model"] = cast(SupportedModels, args[0])
args = args[1:]
_ = args[1:]
explicit_provider = cast(SupportedNativeProviders, kwargs.get("provider"))
if explicit_provider:
@@ -70,19 +107,7 @@ class LLMMeta(ModelMetaclass):
model.partition("/"),
)
provider_mapping: dict[str, SupportedNativeProviders] = {
"openai": "openai",
"anthropic": "anthropic",
"claude": "anthropic",
"azure": "azure",
"azure_openai": "azure",
"google": "gemini",
"gemini": "gemini",
"bedrock": "bedrock",
"aws": "bedrock",
}
canonical_provider = provider_mapping.get(prefix.lower())
canonical_provider = PROVIDER_MAPPING.get(prefix.lower())
if canonical_provider and cls._validate_model_in_constants(
model_part, canonical_provider
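
End to end, the prefixed form keeps routing as before; per the provider tests further down, something along these lines holds, assuming the relevant credentials (e.g. AZURE_API_KEY and AZURE_ENDPOINT) are present in the environment:

from crewai.llm import LLM

llm = LLM(model="azure/gpt-4")  # "azure" prefix resolves through PROVIDER_MAPPING
assert type(llm).__name__ == "AzureCompletion"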

View File

@@ -4,6 +4,7 @@ import json
import logging
from typing import TYPE_CHECKING, Any, cast
from dotenv import load_dotenv
import httpx
from pydantic import BaseModel, Field, PrivateAttr, model_validator
from typing_extensions import Self
@@ -21,13 +22,14 @@ from crewai.utilities.types import LLMMessage
if TYPE_CHECKING:
from anthropic.types import Message
from crewai.agent.core import Agent
from crewai.task import Task
try:
from anthropic import Anthropic
from anthropic.types import Message
from anthropic.types.tool_use_block import ToolUseBlock
except ImportError:
raise ImportError(
@@ -35,6 +37,9 @@ except ImportError:
) from None
load_dotenv()
class AnthropicCompletion(BaseLLM):
"""Anthropic native completion implementation.
@@ -66,11 +71,9 @@ class AnthropicCompletion(BaseLLM):
top_p: float | None = Field(default=None, description="Nucleus sampling parameter")
stream: bool = Field(default=False, description="Enable streaming responses")
client_params: dict[str, Any] | None = Field(
default=None, description="Additional Anthropic client parameters"
)
client: Anthropic = Field(
default_factory=Anthropic, exclude=True, description="Anthropic client instance"
default_factory=dict, description="Additional Anthropic client parameters"
)
_client: Anthropic = PrivateAttr(default=None) # type: ignore[assignment]
_is_claude_3: bool = PrivateAttr(default=False)
_supports_tools: bool = PrivateAttr(default=False)
@@ -78,7 +81,7 @@ class AnthropicCompletion(BaseLLM):
@model_validator(mode="after")
def setup_client(self) -> Self:
"""Initialize the Anthropic client and model-specific settings."""
self.client = Anthropic(**self._get_client_params())
self._client = Anthropic(**self._get_client_params())
self._is_claude_3 = "claude-3" in self.model.lower()
self._supports_tools = self._is_claude_3
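
The same move repeats across the providers in this commit: the SDK client comes off the public field surface, becomes a PrivateAttr, and is constructed in an after-validator. A minimal, self-contained sketch of the pattern, where FakeClient stands in for the real Anthropic/Azure/OpenAI clients:

from pydantic import BaseModel, PrivateAttr, model_validator
from typing_extensions import Self

class FakeClient:
    def __init__(self, api_key: str | None = None) -> None:
        self.api_key = api_key

class ProviderLLM(BaseModel):
    api_key: str | None = None
    _client: FakeClient = PrivateAttr(default=None)  # type: ignore[assignment]

    @model_validator(mode="after")
    def setup_client(self) -> Self:
        # build the client once validation has produced the final field values
        self._client = FakeClient(api_key=self.api_key)
        return self

llm = ProviderLLM(api_key="sk-test")
assert llm._client.api_key == "sk-test"
assert "_client" not in llm.model_dump()  # private attrs never serialize
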
@@ -98,9 +101,6 @@ class AnthropicCompletion(BaseLLM):
def _get_client_params(self) -> dict[str, Any]:
"""Get client parameters."""
if self.api_key is None:
raise ValueError("ANTHROPIC_API_KEY is required")
client_params = {
"api_key": self.api_key,
"base_url": self.base_url,
@@ -330,7 +330,7 @@ class AnthropicCompletion(BaseLLM):
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
try:
response: Message = self.client.messages.create(**params)
response: Message = self._client.messages.create(**params)
except Exception as e:
if is_context_length_exceeded(e):
@@ -424,7 +424,7 @@ class AnthropicCompletion(BaseLLM):
stream_params = {k: v for k, v in params.items() if k != "stream"}
# Make streaming API call
with self.client.messages.stream(**stream_params) as stream:
with self._client.messages.stream(**stream_params) as stream:
for event in stream:
if hasattr(event, "delta") and hasattr(event.delta, "text"):
text_delta = event.delta.text
@@ -552,7 +552,7 @@ class AnthropicCompletion(BaseLLM):
try:
# Send tool results back to Claude for final response
final_response: Message = self.client.messages.create(**follow_up_params)
final_response: Message = self._client.messages.create(**follow_up_params)
# Track token usage for follow-up call
follow_up_usage = self._extract_anthropic_token_usage(final_response)

View File

@@ -5,6 +5,7 @@ import logging
import os
from typing import TYPE_CHECKING, Any
from dotenv import load_dotenv
from pydantic import BaseModel, Field, PrivateAttr, model_validator
from typing_extensions import Self
@@ -48,6 +49,9 @@ except ImportError:
) from None
load_dotenv()
class AzureCompletion(BaseLLM):
"""Azure AI Inference native completion implementation.
@@ -68,12 +72,14 @@ class AzureCompletion(BaseLLM):
interceptor: HTTP interceptor (not yet supported for Azure)
"""
endpoint: str | None = Field(
default=None,
endpoint: str = Field( # type: ignore[assignment]
default_factory=lambda: os.getenv("AZURE_ENDPOINT")
or os.getenv("AZURE_OPENAI_ENDPOINT")
or os.getenv("AZURE_API_BASE"),
description="Azure endpoint URL (defaults to AZURE_ENDPOINT env var)",
)
api_version: str = Field(
default="2024-06-01",
default_factory=lambda: os.getenv("AZURE_API_VERSION", "2024-06-01"),
description="Azure API version (defaults to AZURE_API_VERSION env var or 2024-06-01)",
)
timeout: float | None = Field(
@@ -82,18 +88,16 @@ class AzureCompletion(BaseLLM):
max_retries: int = Field(default=2, description="Maximum number of retries")
top_p: float | None = Field(default=None, description="Nucleus sampling parameter")
frequency_penalty: float | None = Field(
default=None, description="Frequency penalty (-2 to 2)"
default=None, le=2.0, ge=-2.0, description="Frequency penalty (-2 to 2)"
)
presence_penalty: float | None = Field(
default=None, description="Presence penalty (-2 to 2)"
default=None, le=2.0, ge=-2.0, description="Presence penalty (-2 to 2)"
)
max_tokens: int | None = Field(
default=None, description="Maximum tokens in response"
)
stream: bool = Field(default=False, description="Enable streaming responses")
_client: ChatCompletionsClient = PrivateAttr(
default_factory=ChatCompletionsClient, # type: ignore[arg-type]
)
_client: ChatCompletionsClient = PrivateAttr(default=None) # type: ignore[assignment]
_is_openai_model: bool = PrivateAttr(default=False)
_is_azure_openai_endpoint: bool = PrivateAttr(default=False)
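
Moving the env lookups into default_factory (as above) means they run on every instantiation instead of inside setup_client; a quick sketch of that behavior, reusing the AZURE_API_VERSION fallback shown here:

import os
from pydantic import BaseModel, Field

class AzureConfig(BaseModel):
    api_version: str = Field(
        default_factory=lambda: os.getenv("AZURE_API_VERSION", "2024-06-01")
    )

os.environ["AZURE_API_VERSION"] = "2024-10-21"      # illustrative value
assert AzureConfig().api_version == "2024-10-21"    # env var read at instantiation
os.environ.pop("AZURE_API_VERSION")
assert AzureConfig().api_version == "2024-06-01"    # falls back to the literal default
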
@@ -107,26 +111,13 @@ class AzureCompletion(BaseLLM):
"Interceptors are currently supported for OpenAI and Anthropic providers only."
)
if self.endpoint is None:
self.endpoint = (
os.getenv("AZURE_ENDPOINT")
or os.getenv("AZURE_OPENAI_ENDPOINT")
or os.getenv("AZURE_API_BASE")
)
if self.api_version == "2024-06-01":
env_version = os.getenv("AZURE_API_VERSION")
if env_version:
self.api_version = env_version
if not self.api_key:
self.api_key = os.getenv("AZURE_API_KEY")
if not self.api_key:
raise ValueError(
"Azure API key is required. Set AZURE_API_KEY environment variable or pass api_key parameter."
)
if not self.endpoint:
raise ValueError(
"Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter."
)
self.endpoint = self._validate_and_fix_endpoint(self.endpoint, self.model)
@@ -138,7 +129,7 @@ class AzureCompletion(BaseLLM):
if self.api_version:
client_kwargs["api_version"] = self.api_version
self.client = ChatCompletionsClient(**client_kwargs)
self._client = ChatCompletionsClient(**client_kwargs)
self._is_openai_model = any(
prefix in self.model.lower() for prefix in ["gpt-", "o1-", "text-"]
@@ -160,7 +151,7 @@ class AzureCompletion(BaseLLM):
"""Check if endpoint is an Azure OpenAI endpoint."""
return self._is_azure_openai_endpoint
def _validate_and_fix_endpoint(self, endpoint: str, model: str) -> str:
def _validate_and_fix_endpoint(self, endpoint: str | None, model: str) -> str:
"""Validate and fix Azure endpoint URL format.
Azure OpenAI endpoints should be in the format:
@@ -172,7 +163,15 @@ class AzureCompletion(BaseLLM):
Returns:
Validated and potentially corrected endpoint URL
Raises:
ValueError: If endpoint is None or empty
"""
if not endpoint:
raise ValueError(
"Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter."
)
if "openai.azure.com" in endpoint and "/openai/deployments/" not in endpoint:
endpoint = endpoint.rstrip("/")
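
Judging from the truncated hunk above and the endpoint assertion in the Azure tests below, the normalization amounts to appending the deployments path when it is missing; roughly, with a hypothetical input URL:

endpoint = "https://test.openai.azure.com/"
model = "gpt-4"
if "openai.azure.com" in endpoint and "/openai/deployments/" not in endpoint:
    endpoint = endpoint.rstrip("/") + f"/openai/deployments/{model}"
# endpoint == "https://test.openai.azure.com/openai/deployments/gpt-4"
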
@@ -388,7 +387,7 @@ class AzureCompletion(BaseLLM):
"""Handle non-streaming chat completion."""
# Make API call
try:
response: ChatCompletions = self.client.complete(**params)
response: ChatCompletions = self._client.complete(**params)
if not response.choices:
raise ValueError("No choices returned from Azure API")
@@ -486,7 +485,7 @@ class AzureCompletion(BaseLLM):
tool_calls = {}
# Make streaming API call
for update in self.client.complete(**params):
for update in self._client.complete(**params):
if isinstance(update, StreamingChatCompletionsUpdate):
if update.choices:
choice = update.choices[0]

View File

@@ -5,6 +5,8 @@ import logging
import os
from typing import TYPE_CHECKING, Any, TypedDict, cast
from dotenv import load_dotenv
from mypy_boto3_bedrock_runtime.client import BedrockRuntimeClient
from pydantic import BaseModel, Field, PrivateAttr, model_validator
from typing_extensions import Required, Self
@@ -75,6 +77,9 @@ else:
topK: int
load_dotenv()
class ToolInputSchema(TypedDict):
"""Type definition for tool input schema in Converse API."""
@@ -161,16 +166,22 @@ class BedrockCompletion(BaseLLM):
interceptor: HTTP interceptor (not yet supported for Bedrock)
"""
aws_access_key_id: str | None = Field(
default=None, description="AWS access key (defaults to environment variable)"
aws_access_key_id: str = Field( # type: ignore[assignment]
default_factory=lambda: os.getenv("AWS_ACCESS_KEY_ID"),
description="AWS access key (defaults to environment variable)",
)
aws_secret_access_key: str | None = Field(
default=None, description="AWS secret key (defaults to environment variable)"
aws_secret_access_key: str = Field( # type: ignore[assignment]
default_factory=lambda: os.getenv("AWS_SECRET_ACCESS_KEY"),
description="AWS secret key (defaults to environment variable)",
)
aws_session_token: str | None = Field(
default=None, description="AWS session token for temporary credentials"
aws_session_token: str = Field( # type: ignore[assignment]
default_factory=lambda: os.getenv("AWS_SESSION_TOKEN"),
description="AWS session token for temporary credentials",
)
region_name: str = Field(
default_factory=lambda: os.getenv("AWS_REGION", "us-east-1"),
description="AWS region name",
)
region_name: str = Field(default="us-east-1", description="AWS region name")
max_tokens: int | None = Field(
default=None, description="Maximum tokens to generate"
)
@@ -181,17 +192,18 @@ class BedrockCompletion(BaseLLM):
stream: bool = Field(
default=False, description="Whether to use streaming responses"
)
guardrail_config: dict[str, Any] | None = Field(
default=None, description="Guardrail configuration for content filtering"
guardrail_config: dict[str, Any] = Field(
default_factory=dict,
description="Guardrail configuration for content filtering",
)
additional_model_request_fields: dict[str, Any] | None = Field(
default=None, description="Model-specific request parameters"
additional_model_request_fields: dict[str, Any] = Field(
default_factory=dict, description="Model-specific request parameters"
)
additional_model_response_field_paths: list[str] | None = Field(
default=None, description="Custom response field paths"
additional_model_response_field_paths: list[str] = Field(
default_factory=list, description="Custom response field paths"
)
client: Any = Field(
default=None, exclude=True, description="Bedrock client instance"
_client: BedrockRuntimeClient = PrivateAttr( # type: ignore[assignment]
default_factory=lambda: Session().client,
)
_is_claude_model: bool = PrivateAttr(default=False)
@@ -209,10 +221,9 @@ class BedrockCompletion(BaseLLM):
)
session = Session(
aws_access_key_id=self.aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"),
aws_secret_access_key=self.aws_secret_access_key
or os.getenv("AWS_SECRET_ACCESS_KEY"),
aws_session_token=self.aws_session_token or os.getenv("AWS_SESSION_TOKEN"),
aws_access_key_id=self.aws_access_key_id,
aws_secret_access_key=self.aws_secret_access_key,
aws_session_token=self.aws_session_token,
region_name=self.region_name,
)
@@ -225,7 +236,7 @@ class BedrockCompletion(BaseLLM):
tcp_keepalive=True,
)
self.client = session.client("bedrock-runtime", config=config)
self._client = session.client("bedrock-runtime", config=config)
self._is_claude_model = "claude" in self.model.lower()
self._supports_tools = True
@@ -365,7 +376,7 @@ class BedrockCompletion(BaseLLM):
raise ValueError(f"Invalid message format at index {i}")
# Call Bedrock Converse API with proper error handling
response = self.client.converse(
response = self._client.converse(
modelId=self.model_id,
messages=cast(
"Sequence[MessageTypeDef | MessageOutputTypeDef]",
@@ -536,13 +547,13 @@ class BedrockCompletion(BaseLLM):
tool_use_id = None
try:
response = self.client.converse_stream(
response = self._client.converse_stream(
modelId=self.model_id,
messages=cast(
"Sequence[MessageTypeDef | MessageOutputTypeDef]",
cast(object, messages),
),
**body,
**body, # type: ignore[arg-type]
)
stream = response.get("stream")

View File

@@ -1,8 +1,11 @@
from __future__ import annotations
import logging
import os
from typing import TYPE_CHECKING, Any, ClassVar, cast
from typing import TYPE_CHECKING, Any, cast
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
from dotenv import load_dotenv
from pydantic import BaseModel, Field, PrivateAttr, model_validator
from typing_extensions import Self
from crewai.events.types.llm_events import LLMCallType
@@ -31,6 +34,9 @@ except ImportError:
) from None
load_dotenv()
class GeminiCompletion(BaseLLM):
"""Google Gemini native completion implementation.
@@ -50,15 +56,12 @@ class GeminiCompletion(BaseLLM):
interceptor: HTTP interceptor (not yet supported for Gemini)
"""
model_config: ClassVar[ConfigDict] = ConfigDict(
ignored_types=(property,), arbitrary_types_allowed=True
)
project: str | None = Field(
default=None, description="Google Cloud project ID (for Vertex AI)"
default_factory=lambda: os.getenv("GOOGLE_CLOUD_PROJECT"),
description="Google Cloud project ID (for Vertex AI)",
)
location: str = Field(
default="us-central1",
default_factory=lambda: os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1"),
description="Google Cloud location (for Vertex AI, defaults to 'us-central1')",
)
top_p: float | None = Field(default=None, description="Nucleus sampling parameter")
@@ -74,9 +77,7 @@ class GeminiCompletion(BaseLLM):
default_factory=dict,
description="Additional parameters for Google Gen AI Client constructor",
)
client: Any = Field(
default=None, exclude=True, description="Gemini client instance"
)
_client: Any = PrivateAttr(default=None)
_is_gemini_2: bool = PrivateAttr(default=False)
_is_gemini_1_5: bool = PrivateAttr(default=False)
@@ -94,17 +95,9 @@ class GeminiCompletion(BaseLLM):
if self.api_key is None:
self.api_key = os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
if self.project is None:
self.project = os.getenv("GOOGLE_CLOUD_PROJECT")
if self.location == "us-central1":
env_location = os.getenv("GOOGLE_CLOUD_LOCATION")
if env_location:
self.location = env_location
use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"
self.client = self._initialize_client(use_vertexai)
self._client = self._initialize_client(use_vertexai)
self._is_gemini_2 = "gemini-2" in self.model.lower()
self._is_gemini_1_5 = "gemini-1.5" in self.model.lower()
@@ -176,9 +169,9 @@ class GeminiCompletion(BaseLLM):
params = {}
if (
hasattr(self, "client")
and hasattr(self.client, "vertexai")
and self.client.vertexai
hasattr(self, "_client")
and hasattr(self._client, "vertexai")
and self._client.vertexai
):
params.update(
{
@@ -400,7 +393,7 @@ class GeminiCompletion(BaseLLM):
}
try:
response = self.client.models.generate_content(**api_params)
response = self._client.models.generate_content(**api_params)
usage = self._extract_token_usage(response)
except Exception as e:
@@ -468,7 +461,7 @@ class GeminiCompletion(BaseLLM):
"config": config,
}
for chunk in self.client.models.generate_content_stream(**api_params):
for chunk in self._client.models.generate_content_stream(**api_params):
if hasattr(chunk, "text") and chunk.text:
full_response += chunk.text
self._emit_stream_chunk_event(

View File

@@ -6,6 +6,7 @@ import logging
import os
from typing import TYPE_CHECKING, Any
from dotenv import load_dotenv
import httpx
from openai import APIConnectionError, NotFoundError, OpenAI
from openai.types.chat import ChatCompletion, ChatCompletionChunk
@@ -32,6 +33,9 @@ if TYPE_CHECKING:
from crewai.tools.base_tool import BaseTool
load_dotenv()
class OpenAICompletion(BaseLLM):
"""OpenAI native completion implementation.
@@ -40,43 +44,51 @@ class OpenAICompletion(BaseLLM):
"""
# Client configuration fields
organization: str | None = Field(None, description="OpenAI organization ID")
project: str | None = Field(None, description="OpenAI project ID")
max_retries: int = Field(2, description="Maximum number of retries")
default_headers: dict[str, str] | None = Field(
None, description="Default headers for requests"
organization: str | None = Field(default=None, description="OpenAI organization ID")
project: str | None = Field(default=None, description="OpenAI project ID")
max_retries: int = Field(default=2, description="Maximum number of retries")
default_headers: dict[str, str] = Field(
default_factory=dict, description="Default headers for requests"
)
default_query: dict[str, Any] | None = Field(
None, description="Default query parameters"
default_query: dict[str, Any] = Field(
default_factory=dict, description="Default query parameters"
)
client_params: dict[str, Any] | None = Field(
None, description="Additional client parameters"
client_params: dict[str, Any] = Field(
default_factory=dict, description="Additional client parameters"
)
timeout: float | None = Field(default=None, description="Request timeout")
api_base: str | None = Field(
default=None, description="API base URL", deprecated=True
)
timeout: float | None = Field(None, description="Request timeout")
api_base: str | None = Field(None, description="API base URL (deprecated)")
# Completion parameters
top_p: float | None = Field(None, description="Top-p sampling parameter")
frequency_penalty: float | None = Field(None, description="Frequency penalty")
presence_penalty: float | None = Field(None, description="Presence penalty")
max_tokens: int | None = Field(None, description="Maximum tokens")
top_p: float | None = Field(default=None, description="Top-p sampling parameter")
frequency_penalty: float | None = Field(
default=None, description="Frequency penalty"
)
presence_penalty: float | None = Field(default=None, description="Presence penalty")
max_tokens: int | None = Field(default=None, description="Maximum tokens")
max_completion_tokens: int | None = Field(
None, description="Maximum completion tokens"
)
seed: int | None = Field(None, description="Random seed")
stream: bool = Field(False, description="Enable streaming")
seed: int | None = Field(default=None, description="Random seed")
stream: bool = Field(default=False, description="Enable streaming")
response_format: dict[str, Any] | type[BaseModel] | None = Field(
None, description="Response format"
default=None, description="Response format"
)
logprobs: bool | None = Field(None, description="Return log probabilities")
logprobs: bool | None = Field(default=None, description="Return log probabilities")
top_logprobs: int | None = Field(
None, description="Number of top log probabilities"
default=None, description="Number of top log probabilities"
)
reasoning_effort: str | None = Field(
default=None, description="Reasoning effort level"
)
reasoning_effort: str | None = Field(None, description="Reasoning effort level")
_client: OpenAI = PrivateAttr(default_factory=OpenAI)
is_o1_model: bool = Field(False, description="Whether this is an O1 model")
is_gpt4_model: bool = Field(False, description="Whether this is a GPT-4 model")
_client: OpenAI = PrivateAttr(default=None) # type: ignore[assignment]
is_o1_model: bool = Field(default=False, description="Whether this is an O1 model")
is_gpt4_model: bool = Field(
default=False, description="Whether this is a GPT-4 model"
)
@model_validator(mode="after")
def setup_client(self) -> Self:
@@ -97,10 +109,6 @@ class OpenAICompletion(BaseLLM):
def _get_client_params(self) -> dict[str, Any]:
"""Get OpenAI client parameters."""
if self.api_key is None:
raise ValueError("OPENAI_API_KEY is required")
base_params = {
"api_key": self.api_key,
"organization": self.organization,

View File

@@ -1 +1,38 @@
"""LLM implementations for crewAI."""
"""LLM implementations for crewAI.
.. deprecated:: 1.4.0
The `crewai.llms` package is deprecated. Use `crewai.llm` instead.
This package was reorganized from `crewai.llms.*` to `crewai.llm.*`.
All submodules are redirected to their new locations in `crewai.llm.*`.
Migration guide:
Old: from crewai.llms.base_llm import BaseLLM
New: from crewai.llm.base_llm import BaseLLM
Old: from crewai.llms.hooks.base import BaseInterceptor
New: from crewai.llm.hooks.base import BaseInterceptor
Old: from crewai.llms.constants import OPENAI_MODELS
New: from crewai.llm.constants import OPENAI_MODELS
Or use top-level imports:
from crewai import LLM, BaseLLM
"""
import warnings
from crewai.llm import LLM
from crewai.llm.base_llm import BaseLLM
# Issue deprecation warning when this module is imported
warnings.warn(
"The 'crewai.llms' package is deprecated and will be removed in a future version. "
"Please use 'crewai.llm' (singular) instead. "
"All submodules have been reorganized from 'crewai.llms.*' to 'crewai.llm.*'.",
DeprecationWarning,
stacklevel=2,
)
__all__ = ["LLM", "BaseLLM"]
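
A quick way to confirm the shim fires, assuming crewai.llms has not already been imported in the process (module imports are cached, so a re-import will not warn again):

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    from crewai.llms import BaseLLM  # old path: still works, but warns

assert any(issubclass(w.category, DeprecationWarning) for w in caught)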

View File

@@ -0,0 +1,15 @@
"""Deprecated: Use crewai.llm.base_llm instead.
.. deprecated:: 1.4.0
"""
import warnings
warnings.warn(
"crewai.llms.base_llm is deprecated. Use crewai.llm.base_llm instead.",
DeprecationWarning,
stacklevel=2,
)
from crewai.llm.base_llm import * # noqa: E402, F403

View File

@@ -0,0 +1,15 @@
"""Deprecated: Use crewai.llm.constants instead.
.. deprecated:: 1.4.0
"""
import warnings
warnings.warn(
"crewai.llms.constants is deprecated. Use crewai.llm.constants instead.",
DeprecationWarning,
stacklevel=2,
)
from crewai.llm.constants import * # noqa: E402, F403

View File

@@ -0,0 +1,15 @@
"""Deprecated: Use crewai.llm.hooks instead.
.. deprecated:: 1.4.0
"""
import warnings
warnings.warn(
"crewai.llms.hooks is deprecated. Use crewai.llm.hooks instead.",
DeprecationWarning,
stacklevel=2,
)
from crewai.llm.hooks import * # noqa: E402, F403

View File

@@ -0,0 +1,15 @@
"""Deprecated: Use crewai.llm.hooks.base instead.
.. deprecated:: 1.4.0
"""
import warnings
warnings.warn(
"crewai.llms.hooks.base is deprecated. Use crewai.llm.hooks.base instead.",
DeprecationWarning,
stacklevel=2,
)
from crewai.llm.hooks.base import * # noqa: E402, F403

View File

@@ -0,0 +1,15 @@
"""Deprecated: Use crewai.llm.hooks.transport instead.
.. deprecated:: 1.4.0
"""
import warnings
warnings.warn(
"crewai.llms.hooks.transport is deprecated. Use crewai.llm.hooks.transport instead.",
DeprecationWarning,
stacklevel=2,
)
from crewai.llm.hooks.transport import * # noqa: E402, F403

View File

@@ -0,0 +1,15 @@
"""Deprecated: Use crewai.llm.internal instead.
.. deprecated:: 1.4.0
"""
import warnings
warnings.warn(
"crewai.llms.internal is deprecated. Use crewai.llm.internal instead.",
DeprecationWarning,
stacklevel=2,
)
from crewai.llm.internal import * # noqa: E402, F403

View File

@@ -0,0 +1,15 @@
"""Deprecated: Use crewai.llm.internal.constants instead.
.. deprecated:: 1.4.0
"""
import warnings
warnings.warn(
"crewai.llms.internal.constants is deprecated. Use crewai.llm.internal.constants instead.",
DeprecationWarning,
stacklevel=2,
)
from crewai.llm.internal.constants import * # noqa: E402, F403

View File

@@ -0,0 +1,15 @@
"""Deprecated: Use crewai.llm.providers instead.
.. deprecated:: 1.4.0
"""
import warnings
warnings.warn(
"crewai.llms.providers is deprecated. Use crewai.llm.providers instead.",
DeprecationWarning,
stacklevel=2,
)
from crewai.llm.providers import * # noqa: E402, F403

View File

@@ -60,7 +60,7 @@ def test_anthropic_tool_use_conversation_flow():
available_functions = {"get_weather": mock_weather_tool}
# Mock the Anthropic client responses
with patch.object(completion.client.messages, 'create') as mock_create:
with patch.object(completion._client.messages, 'create') as mock_create:
# Mock initial response with tool use - need to properly mock ToolUseBlock
mock_tool_use = Mock(spec=ToolUseBlock)
mock_tool_use.id = "tool_123"
@@ -199,8 +199,8 @@ def test_anthropic_specific_parameters():
assert isinstance(llm, AnthropicCompletion)
assert llm.stop == ["Human:", "Assistant:"]
assert llm.stream == True
assert llm.client.max_retries == 5
assert llm.client.timeout == 60
assert llm._client.max_retries == 5
assert llm._client.timeout == 60
def test_anthropic_completion_call():
@@ -637,8 +637,8 @@ def test_anthropic_environment_variable_api_key():
with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-anthropic-key"}):
llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
assert llm.client is not None
assert hasattr(llm.client, 'messages')
assert llm._client is not None
assert hasattr(llm._client, 'messages')
def test_anthropic_token_usage_tracking():
@@ -648,7 +648,7 @@ def test_anthropic_token_usage_tracking():
llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
# Mock the Anthropic response with usage information
with patch.object(llm.client.messages, 'create') as mock_create:
with patch.object(llm._client.messages, 'create') as mock_create:
mock_response = MagicMock()
mock_response.content = [MagicMock(text="test response")]
mock_response.usage = MagicMock(input_tokens=50, output_tokens=25)

View File

@@ -64,7 +64,7 @@ def test_azure_tool_use_conversation_flow():
available_functions = {"get_weather": mock_weather_tool}
# Mock the Azure client responses
with patch.object(completion.client, 'complete') as mock_complete:
with patch.object(completion._client, 'complete') as mock_complete:
# Mock tool call in response with proper type
mock_tool_call = MagicMock(spec=ChatCompletionsToolCall)
mock_tool_call.function.name = "get_weather"
@@ -602,7 +602,7 @@ def test_azure_environment_variable_endpoint():
}):
llm = LLM(model="azure/gpt-4")
assert llm.client is not None
assert llm._client is not None
assert llm.endpoint == "https://test.openai.azure.com/openai/deployments/gpt-4"
@@ -613,7 +613,7 @@ def test_azure_token_usage_tracking():
llm = LLM(model="azure/gpt-4")
# Mock the Azure response with usage information
with patch.object(llm.client, 'complete') as mock_complete:
with patch.object(llm._client, 'complete') as mock_complete:
mock_message = MagicMock()
mock_message.content = "test response"
mock_message.tool_calls = None
@@ -651,7 +651,7 @@ def test_azure_http_error_handling():
llm = LLM(model="azure/gpt-4")
# Mock an HTTP error
with patch.object(llm.client, 'complete') as mock_complete:
with patch.object(llm._client, 'complete') as mock_complete:
mock_complete.side_effect = HttpResponseError(message="Rate limit exceeded", response=MagicMock(status_code=429))
with pytest.raises(HttpResponseError):
@@ -668,7 +668,7 @@ def test_azure_streaming_completion():
llm = LLM(model="azure/gpt-4", stream=True)
# Mock streaming response
with patch.object(llm.client, 'complete') as mock_complete:
with patch.object(llm._client, 'complete') as mock_complete:
# Create mock streaming updates with proper type
mock_updates = []
for chunk in ["Hello", " ", "world", "!"]:
@@ -891,7 +891,7 @@ def test_azure_improved_error_messages():
llm = LLM(model="azure/gpt-4")
with patch.object(llm.client, 'complete') as mock_complete:
with patch.object(llm._client, 'complete') as mock_complete:
error_401 = HttpResponseError(message="Unauthorized")
error_401.status_code = 401
mock_complete.side_effect = error_401

View File

@@ -579,7 +579,7 @@ def test_bedrock_token_usage_tracking():
llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")
# Mock the Bedrock response with usage information
with patch.object(llm.client, 'converse') as mock_converse:
with patch.object(llm._client, 'converse') as mock_converse:
mock_response = {
'output': {
'message': {
@@ -624,7 +624,7 @@ def test_bedrock_tool_use_conversation_flow():
available_functions = {"get_weather": mock_weather_tool}
# Mock the Bedrock client responses
with patch.object(llm.client, 'converse') as mock_converse:
with patch.object(llm._client, 'converse') as mock_converse:
# First response: tool use request
tool_use_response = {
'output': {
@@ -710,7 +710,7 @@ def test_bedrock_client_error_handling():
llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")
# Test ValidationException
with patch.object(llm.client, 'converse') as mock_converse:
with patch.object(llm._client, 'converse') as mock_converse:
error_response = {
'Error': {
'Code': 'ValidationException',
@@ -724,7 +724,7 @@ def test_bedrock_client_error_handling():
assert "validation" in str(exc_info.value).lower()
# Test ThrottlingException
with patch.object(llm.client, 'converse') as mock_converse:
with patch.object(llm._client, 'converse') as mock_converse:
error_response = {
'Error': {
'Code': 'ThrottlingException',
@@ -762,7 +762,7 @@ def test_bedrock_stop_sequences_sent_to_api():
llm.stop = ["\nObservation:", "\nThought:"]
# Patch the API call to capture parameters without making real call
with patch.object(llm.client, 'converse') as mock_converse:
with patch.object(llm._client, 'converse') as mock_converse:
mock_response = {
'output': {
'message': {

View File

@@ -59,7 +59,7 @@ def test_gemini_tool_use_conversation_flow():
available_functions = {"get_weather": mock_weather_tool}
# Mock the Google Gemini client responses
with patch.object(completion.client.models, 'generate_content') as mock_generate:
with patch.object(completion._client.models, 'generate_content') as mock_generate:
# Mock function call in response
mock_function_call = Mock()
mock_function_call.name = "get_weather"
@@ -614,8 +614,8 @@ def test_gemini_environment_variable_api_key():
with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-google-key"}):
llm = LLM(model="google/gemini-2.0-flash-001")
assert llm.client is not None
assert hasattr(llm.client, 'models')
assert llm._client is not None
assert hasattr(llm._client, 'models')
assert llm.api_key == "test-google-key"
@@ -626,7 +626,7 @@ def test_gemini_token_usage_tracking():
llm = LLM(model="google/gemini-2.0-flash-001")
# Mock the Gemini response with usage information
with patch.object(llm.client.models, 'generate_content') as mock_generate:
with patch.object(llm._client.models, 'generate_content') as mock_generate:
mock_response = MagicMock()
mock_response.text = "test response"
mock_response.candidates = []
@@ -675,7 +675,7 @@ def test_gemini_stop_sequences_sent_to_api():
llm.stop = ["\nObservation:", "\nThought:"]
# Patch the API call to capture parameters without making real call
with patch.object(llm.client.models, 'generate_content') as mock_generate:
with patch.object(llm._client.models, 'generate_content') as mock_generate:
mock_response = MagicMock()
mock_response.text = "Hello"
mock_response.candidates = []

View File

@@ -369,11 +369,11 @@ def test_openai_client_setup_with_extra_arguments():
assert llm.top_p == 0.5
# Check that client parameters are properly configured
assert llm.client.max_retries == 3
assert llm.client.timeout == 30
assert llm._client.max_retries == 3
assert llm._client.timeout == 30
# Test that parameters are properly used in API calls
with patch.object(llm.client.chat.completions, 'create') as mock_create:
with patch.object(llm._client.chat.completions, 'create') as mock_create:
mock_create.return_value = MagicMock(
choices=[MagicMock(message=MagicMock(content="test response", tool_calls=None))],
usage=MagicMock(prompt_tokens=10, completion_tokens=20, total_tokens=30)
@@ -394,7 +394,7 @@ def test_extra_arguments_are_passed_to_openai_completion():
"""
llm = LLM(model="gpt-4o", temperature=0.7, max_tokens=1000, top_p=0.5, max_retries=3)
with patch.object(llm.client.chat.completions, 'create') as mock_create:
with patch.object(llm._client.chat.completions, 'create') as mock_create:
mock_create.return_value = MagicMock(
choices=[MagicMock(message=MagicMock(content="test response", tool_calls=None))],
usage=MagicMock(prompt_tokens=10, completion_tokens=20, total_tokens=30)
@@ -501,7 +501,7 @@ def test_openai_streaming_with_response_model():
llm = LLM(model="openai/gpt-4o", stream=True)
with patch.object(llm.client.chat.completions, "create") as mock_create:
with patch.object(llm._client.chat.completions, "create") as mock_create:
mock_chunk1 = MagicMock()
mock_chunk1.choices = [
MagicMock(delta=MagicMock(content='{"answer": "test", ', tool_calls=None))