Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-09 16:18:30 +00:00)

chore: remove duplication in azure client
@@ -12,10 +12,11 @@ import json
import logging
import os
import re
from typing import TYPE_CHECKING, Any, ClassVar, Final
from typing import TYPE_CHECKING, Any, Final

from dotenv import load_dotenv
import httpx
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from pydantic import BaseModel, Field, field_validator

from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.llm_events import (
@@ -42,6 +43,8 @@ if TYPE_CHECKING:
    from crewai.utilities.types import LLMMessage

load_dotenv()

DEFAULT_CONTEXT_WINDOW_SIZE: Final[int] = 4096
DEFAULT_SUPPORTS_STOP_WORDS: Final[bool] = True
_JSON_EXTRACTION_PATTERN: Final[re.Pattern[str]] = re.compile(r"\{.*}", re.DOTALL)
@@ -65,10 +68,6 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
        stop: A list of stop sequences that the LLM should use to stop generation.
    """

    model_config: ClassVar[ConfigDict] = ConfigDict(
        extra="allow", populate_by_name=True
    )

    # Core fields
    model: str = Field(..., description="The model identifier/name")
    temperature: float | None = Field(
@@ -100,7 +99,7 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
        "cached_prompt_tokens": 0,
    }

    @field_validator("api_key", mode="before")
    @field_validator("api_key", mode="after")
    @classmethod
    def _validate_api_key(cls, value: str | None) -> str | None:
        """Validate API key for authentication.
@@ -137,37 +136,6 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
        return value
        return []

    @model_validator(mode="before")
    @classmethod
    def _extract_stop_and_validate(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Extract and normalize stop sequences before model initialization.

        Args:
            values: Input values dictionary

        Returns:
            Processed values dictionary
        """
        if not values.get("model"):
            raise ValueError("Model name is required and cannot be empty")

        stop = values.get("stop") or values.get("stop_sequences")
        if stop is None:
            values["stop"] = []
        elif isinstance(stop, str):
            values["stop"] = [stop]
        elif isinstance(stop, list):
            values["stop"] = stop
        else:
            values["stop"] = []

        values.pop("stop_sequences", None)

        if "provider" not in values or values["provider"] is None:
            values["provider"] = "openai"

        return values

    @property
    def additional_params(self) -> dict[str, Any]:
        """Get additional parameters stored as extra fields.
@@ -190,20 +158,6 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
            self.__pydantic_extra__ = {}
        self.__pydantic_extra__.update(value)

    def model_post_init(self, __context: Any) -> None:
        """Initialize token usage tracking after model initialization.

        Args:
            __context: Pydantic context (unused)
        """
        self._token_usage = {
            "total_tokens": 0,
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "successful_requests": 0,
            "cached_prompt_tokens": 0,
        }

    @abstractmethod
    def call(
        self,
14  lib/crewai/src/crewai/llm/internal/constants.py  Normal file
@@ -0,0 +1,14 @@
from crewai.llm.constants import SupportedNativeProviders


PROVIDER_MAPPING: dict[str, SupportedNativeProviders] = {
    "openai": "openai",
    "anthropic": "anthropic",
    "claude": "anthropic",
    "azure": "azure",
    "azure_openai": "azure",
    "google": "gemini",
    "gemini": "gemini",
    "bedrock": "bedrock",
    "aws": "bedrock",
}
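The shared PROVIDER_MAPPING above is what LLMMeta (further down in this diff) uses to resolve a model string's prefix to a canonical native provider. A minimal sketch of that lookup, assuming the same "provider/model" prefix convention; the helper name here is hypothetical, the real resolution happens inside LLMMeta.__call__:

    from crewai.llm.internal.constants import PROVIDER_MAPPING

    def resolve_provider(model: str) -> str | None:
        # Hypothetical helper: split off the prefix and map aliases to a canonical provider.
        prefix, _, _ = model.partition("/")
        return PROVIDER_MAPPING.get(prefix.lower())

    assert resolve_provider("azure_openai/gpt-4o") == "azure"
    assert resolve_provider("claude/claude-3-5-sonnet-20241022") == "anthropic"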
@@ -9,6 +9,7 @@ from __future__ import annotations

import logging
from typing import Any, cast

from pydantic import ConfigDict
from pydantic._internal._model_construction import ModelMetaclass

from crewai.llm.constants import (
@@ -21,6 +22,7 @@ from crewai.llm.constants import (
    SupportedModels,
    SupportedNativeProviders,
)
from crewai.llm.internal.constants import PROVIDER_MAPPING

class LLMMeta(ModelMetaclass):
@@ -30,6 +32,41 @@ class LLMMeta(ModelMetaclass):
    native provider implementation based on the model parameter.
    """

    def __new__(
        mcs,
        name: str,
        bases: tuple[type, ...],
        namespace: dict[str, Any],
        **kwargs: Any,
    ) -> type:
        """Create new LLM class with proper model_config for custom LLMs.

        Args:
            name: Class name
            bases: Base classes
            namespace: Class namespace
            **kwargs: Additional arguments

        Returns:
            New class
        """
        if name != "BaseLLM" and any(
            base.__name__ in ("BaseLLM", "LLM") for base in bases
        ):
            if "model_config" not in namespace:
                namespace["model_config"] = ConfigDict(
                    extra="allow", populate_by_name=True
                )
            elif isinstance(namespace["model_config"], dict):
                config_dict = cast(
                    ConfigDict, cast(object, dict(namespace["model_config"]))
                )
                config_dict.setdefault("extra", "allow")
                config_dict.setdefault("populate_by_name", True)
                namespace["model_config"] = ConfigDict(**config_dict)

        return super().__new__(mcs, name, bases, namespace)

    def __call__(cls, *args: Any, **kwargs: Any) -> Any:  # noqa: N805
        """Route to appropriate provider implementation at instantiation time.

@@ -57,7 +94,7 @@ class LLMMeta(ModelMetaclass):

        if args and not kwargs.get("model"):
            kwargs["model"] = cast(SupportedModels, args[0])
            args = args[1:]
            _ = args[1:]
        explicit_provider = cast(SupportedNativeProviders, kwargs.get("provider"))

        if explicit_provider:
@@ -70,19 +107,7 @@ class LLMMeta(ModelMetaclass):
            model.partition("/"),
        )

        provider_mapping: dict[str, SupportedNativeProviders] = {
            "openai": "openai",
            "anthropic": "anthropic",
            "claude": "anthropic",
            "azure": "azure",
            "azure_openai": "azure",
            "google": "gemini",
            "gemini": "gemini",
            "bedrock": "bedrock",
            "aws": "bedrock",
        }

        canonical_provider = provider_mapping.get(prefix.lower())
        canonical_provider = PROVIDER_MAPPING.get(prefix.lower())

        if canonical_provider and cls._validate_model_in_constants(
            model_part, canonical_provider
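The __new__ branch above merges a subclass's dict-style model_config with the defaults instead of overwriting it. A small sketch of that merge behavior, assuming a hypothetical custom LLM subclass that only sets arbitrary_types_allowed:

    from pydantic import ConfigDict

    # Hypothetical dict-style config supplied by a custom LLM subclass.
    user_config = {"arbitrary_types_allowed": True}

    merged = dict(user_config)
    merged.setdefault("extra", "allow")            # added only because the user did not set it
    merged.setdefault("populate_by_name", True)    # an explicit user value would be preserved
    model_config = ConfigDict(**merged)

    # model_config -> {'arbitrary_types_allowed': True, 'extra': 'allow', 'populate_by_name': True}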
@@ -4,6 +4,7 @@ import json
import logging
from typing import TYPE_CHECKING, Any, cast

from dotenv import load_dotenv
import httpx
from pydantic import BaseModel, Field, PrivateAttr, model_validator
from typing_extensions import Self
@@ -21,13 +22,14 @@ from crewai.utilities.types import LLMMessage

if TYPE_CHECKING:
    from anthropic.types import Message

    from crewai.agent.core import Agent
    from crewai.task import Task

try:
    from anthropic import Anthropic
    from anthropic.types import Message
    from anthropic.types.tool_use_block import ToolUseBlock
except ImportError:
    raise ImportError(
@@ -35,6 +37,9 @@ except ImportError:
    ) from None

load_dotenv()

class AnthropicCompletion(BaseLLM):
    """Anthropic native completion implementation.

@@ -66,11 +71,9 @@ class AnthropicCompletion(BaseLLM):
    top_p: float | None = Field(default=None, description="Nucleus sampling parameter")
    stream: bool = Field(default=False, description="Enable streaming responses")
    client_params: dict[str, Any] | None = Field(
        default=None, description="Additional Anthropic client parameters"
    )
    client: Anthropic = Field(
        default_factory=Anthropic, exclude=True, description="Anthropic client instance"
        default_factory=dict, description="Additional Anthropic client parameters"
    )
    _client: Anthropic = PrivateAttr(default=None)  # type: ignore[assignment]

    _is_claude_3: bool = PrivateAttr(default=False)
    _supports_tools: bool = PrivateAttr(default=False)
@@ -78,7 +81,7 @@ class AnthropicCompletion(BaseLLM):
    @model_validator(mode="after")
    def setup_client(self) -> Self:
        """Initialize the Anthropic client and model-specific settings."""
        self.client = Anthropic(**self._get_client_params())
        self._client = Anthropic(**self._get_client_params())

        self._is_claude_3 = "claude-3" in self.model.lower()
        self._supports_tools = self._is_claude_3
@@ -98,9 +101,6 @@ class AnthropicCompletion(BaseLLM):
    def _get_client_params(self) -> dict[str, Any]:
        """Get client parameters."""

        if self.api_key is None:
            raise ValueError("ANTHROPIC_API_KEY is required")

        client_params = {
            "api_key": self.api_key,
            "base_url": self.base_url,
@@ -330,7 +330,7 @@ class AnthropicCompletion(BaseLLM):
            params["tool_choice"] = {"type": "tool", "name": "structured_output"}

        try:
            response: Message = self.client.messages.create(**params)
            response: Message = self._client.messages.create(**params)

        except Exception as e:
            if is_context_length_exceeded(e):
@@ -424,7 +424,7 @@ class AnthropicCompletion(BaseLLM):
        stream_params = {k: v for k, v in params.items() if k != "stream"}

        # Make streaming API call
        with self.client.messages.stream(**stream_params) as stream:
        with self._client.messages.stream(**stream_params) as stream:
            for event in stream:
                if hasattr(event, "delta") and hasattr(event.delta, "text"):
                    text_delta = event.delta.text
@@ -552,7 +552,7 @@ class AnthropicCompletion(BaseLLM):

        try:
            # Send tool results back to Claude for final response
            final_response: Message = self.client.messages.create(**follow_up_params)
            final_response: Message = self._client.messages.create(**follow_up_params)

            # Track token usage for follow-up call
            follow_up_usage = self._extract_anthropic_token_usage(final_response)
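Across the provider classes in this commit, the public client Field is replaced by a private _client attribute that is created in a mode="after" model validator. A minimal sketch of the pattern with placeholder names (FakeClient stands in for an SDK client such as Anthropic or ChatCompletionsClient):

    from typing import Any

    from pydantic import BaseModel, PrivateAttr, model_validator
    from typing_extensions import Self

    class FakeClient:
        """Hypothetical stand-in for a provider SDK client."""

        def __init__(self, **kwargs: Any) -> None:
            self.kwargs = kwargs

    class ProviderCompletion(BaseModel):
        model: str
        # Private attribute: not a validated/serialized field, just internal state.
        _client: FakeClient = PrivateAttr(default=None)  # type: ignore[assignment]

        @model_validator(mode="after")
        def setup_client(self) -> Self:
            # Build the SDK client once all declared fields are validated.
            self._client = FakeClient(model=self.model)
            return self

    llm = ProviderCompletion(model="claude-3-5-sonnet")
    assert isinstance(llm._client, FakeClient)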
@@ -5,6 +5,7 @@ import logging
import os
from typing import TYPE_CHECKING, Any

from dotenv import load_dotenv
from pydantic import BaseModel, Field, PrivateAttr, model_validator
from typing_extensions import Self

@@ -48,6 +49,9 @@ except ImportError:
    ) from None

load_dotenv()

class AzureCompletion(BaseLLM):
    """Azure AI Inference native completion implementation.

@@ -68,12 +72,14 @@ class AzureCompletion(BaseLLM):
        interceptor: HTTP interceptor (not yet supported for Azure)
    """

    endpoint: str | None = Field(
        default=None,
    endpoint: str = Field(  # type: ignore[assignment]
        default_factory=lambda: os.getenv("AZURE_ENDPOINT")
        or os.getenv("AZURE_OPENAI_ENDPOINT")
        or os.getenv("AZURE_API_BASE"),
        description="Azure endpoint URL (defaults to AZURE_ENDPOINT env var)",
    )
    api_version: str = Field(
        default="2024-06-01",
        default_factory=lambda: os.getenv("AZURE_API_VERSION", "2024-06-01"),
        description="Azure API version (defaults to AZURE_API_VERSION env var or 2024-06-01)",
    )
    timeout: float | None = Field(
@@ -82,18 +88,16 @@ class AzureCompletion(BaseLLM):
    max_retries: int = Field(default=2, description="Maximum number of retries")
    top_p: float | None = Field(default=None, description="Nucleus sampling parameter")
    frequency_penalty: float | None = Field(
        default=None, description="Frequency penalty (-2 to 2)"
        default=None, le=2.0, ge=-2.0, description="Frequency penalty (-2 to 2)"
    )
    presence_penalty: float | None = Field(
        default=None, description="Presence penalty (-2 to 2)"
        default=None, le=2.0, ge=-2.0, description="Presence penalty (-2 to 2)"
    )
    max_tokens: int | None = Field(
        default=None, description="Maximum tokens in response"
    )
    stream: bool = Field(default=False, description="Enable streaming responses")
    _client: ChatCompletionsClient = PrivateAttr(
        default_factory=ChatCompletionsClient,  # type: ignore[arg-type]
    )
    _client: ChatCompletionsClient = PrivateAttr(default=None)  # type: ignore[assignment]

    _is_openai_model: bool = PrivateAttr(default=False)
    _is_azure_openai_endpoint: bool = PrivateAttr(default=False)
@@ -107,26 +111,13 @@ class AzureCompletion(BaseLLM):
                "Interceptors are currently supported for OpenAI and Anthropic providers only."
            )

        if self.endpoint is None:
            self.endpoint = (
                os.getenv("AZURE_ENDPOINT")
                or os.getenv("AZURE_OPENAI_ENDPOINT")
                or os.getenv("AZURE_API_BASE")
            )

        if self.api_version == "2024-06-01":
            env_version = os.getenv("AZURE_API_VERSION")
            if env_version:
                self.api_version = env_version
        if not self.api_key:
            self.api_key = os.getenv("AZURE_API_KEY")

        if not self.api_key:
            raise ValueError(
                "Azure API key is required. Set AZURE_API_KEY environment variable or pass api_key parameter."
            )
        if not self.endpoint:
            raise ValueError(
                "Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter."
            )

        self.endpoint = self._validate_and_fix_endpoint(self.endpoint, self.model)

@@ -138,7 +129,7 @@ class AzureCompletion(BaseLLM):
        if self.api_version:
            client_kwargs["api_version"] = self.api_version

        self.client = ChatCompletionsClient(**client_kwargs)
        self._client = ChatCompletionsClient(**client_kwargs)

        self._is_openai_model = any(
            prefix in self.model.lower() for prefix in ["gpt-", "o1-", "text-"]
@@ -160,7 +151,7 @@ class AzureCompletion(BaseLLM):
        """Check if endpoint is an Azure OpenAI endpoint."""
        return self._is_azure_openai_endpoint

    def _validate_and_fix_endpoint(self, endpoint: str, model: str) -> str:
    def _validate_and_fix_endpoint(self, endpoint: str | None, model: str) -> str:
        """Validate and fix Azure endpoint URL format.

        Azure OpenAI endpoints should be in the format:
@@ -172,7 +163,15 @@ class AzureCompletion(BaseLLM):

        Returns:
            Validated and potentially corrected endpoint URL

        Raises:
            ValueError: If endpoint is None or empty
        """
        if not endpoint:
            raise ValueError(
                "Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter."
            )

        if "openai.azure.com" in endpoint and "/openai/deployments/" not in endpoint:
            endpoint = endpoint.rstrip("/")

@@ -388,7 +387,7 @@ class AzureCompletion(BaseLLM):
        """Handle non-streaming chat completion."""
        # Make API call
        try:
            response: ChatCompletions = self.client.complete(**params)
            response: ChatCompletions = self._client.complete(**params)

            if not response.choices:
                raise ValueError("No choices returned from Azure API")
@@ -486,7 +485,7 @@ class AzureCompletion(BaseLLM):
        tool_calls = {}

        # Make streaming API call
        for update in self.client.complete(**params):
        for update in self._client.complete(**params):
            if isinstance(update, StreamingChatCompletionsUpdate):
                if update.choices:
                    choice = update.choices[0]
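The core of this commit: the Azure endpoint, api_version and api_key now fall back to environment variables through Field(default_factory=...) instead of being re-resolved inside setup_client, which removes the duplicated os.getenv logic. A minimal sketch of the pattern, assuming the same environment variable names; the class name is hypothetical:

    import os

    from pydantic import BaseModel, Field

    class AzureSettingsSketch(BaseModel):
        # Resolved from the environment whenever an instance is created without an explicit value.
        endpoint: str | None = Field(
            default_factory=lambda: os.getenv("AZURE_ENDPOINT")
            or os.getenv("AZURE_OPENAI_ENDPOINT")
            or os.getenv("AZURE_API_BASE"),
        )
        api_version: str = Field(
            default_factory=lambda: os.getenv("AZURE_API_VERSION", "2024-06-01"),
        )

    os.environ["AZURE_ENDPOINT"] = "https://example.openai.azure.com"
    print(AzureSettingsSketch().endpoint)     # https://example.openai.azure.com
    print(AzureSettingsSketch().api_version)  # 2024-06-01 unless AZURE_API_VERSION is set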
@@ -5,6 +5,8 @@ import logging
import os
from typing import TYPE_CHECKING, Any, TypedDict, cast

from dotenv import load_dotenv
from mypy_boto3_bedrock_runtime.client import BedrockRuntimeClient
from pydantic import BaseModel, Field, PrivateAttr, model_validator
from typing_extensions import Required, Self

@@ -75,6 +77,9 @@ else:
    topK: int

load_dotenv()

class ToolInputSchema(TypedDict):
    """Type definition for tool input schema in Converse API."""

@@ -161,16 +166,22 @@ class BedrockCompletion(BaseLLM):
        interceptor: HTTP interceptor (not yet supported for Bedrock)
    """

    aws_access_key_id: str | None = Field(
        default=None, description="AWS access key (defaults to environment variable)"
    aws_access_key_id: str = Field(  # type: ignore[assignment]
        default_factory=lambda: os.getenv("AWS_ACCESS_KEY_ID"),
        description="AWS access key (defaults to environment variable)",
    )
    aws_secret_access_key: str | None = Field(
        default=None, description="AWS secret key (defaults to environment variable)"
    aws_secret_access_key: str = Field(  # type: ignore[assignment]
        default_factory=lambda: os.getenv("AWS_SECRET_ACCESS_KEY"),
        description="AWS secret key (defaults to environment variable)",
    )
    aws_session_token: str | None = Field(
        default=None, description="AWS session token for temporary credentials"
    aws_session_token: str = Field(  # type: ignore[assignment]
        default_factory=lambda: os.getenv("AWS_SESSION_TOKEN"),
        description="AWS session token for temporary credentials",
    )
    region_name: str = Field(
        default_factory=lambda: os.getenv("AWS_REGION", "us-east-1"),
        description="AWS region name",
    )
    region_name: str = Field(default="us-east-1", description="AWS region name")
    max_tokens: int | None = Field(
        default=None, description="Maximum tokens to generate"
    )
@@ -181,17 +192,18 @@ class BedrockCompletion(BaseLLM):
    stream: bool = Field(
        default=False, description="Whether to use streaming responses"
    )
    guardrail_config: dict[str, Any] | None = Field(
        default=None, description="Guardrail configuration for content filtering"
    guardrail_config: dict[str, Any] = Field(
        default_factory=dict,
        description="Guardrail configuration for content filtering",
    )
    additional_model_request_fields: dict[str, Any] | None = Field(
        default=None, description="Model-specific request parameters"
    additional_model_request_fields: dict[str, Any] = Field(
        default_factory=dict, description="Model-specific request parameters"
    )
    additional_model_response_field_paths: list[str] | None = Field(
        default=None, description="Custom response field paths"
    additional_model_response_field_paths: list[str] = Field(
        default_factory=list, description="Custom response field paths"
    )
    client: Any = Field(
        default=None, exclude=True, description="Bedrock client instance"
    _client: BedrockRuntimeClient = PrivateAttr(  # type: ignore[assignment]
        default_factory=lambda: Session().client,
    )

    _is_claude_model: bool = PrivateAttr(default=False)
@@ -209,10 +221,9 @@ class BedrockCompletion(BaseLLM):
            )

        session = Session(
            aws_access_key_id=self.aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"),
            aws_secret_access_key=self.aws_secret_access_key
            or os.getenv("AWS_SECRET_ACCESS_KEY"),
            aws_session_token=self.aws_session_token or os.getenv("AWS_SESSION_TOKEN"),
            aws_access_key_id=self.aws_access_key_id,
            aws_secret_access_key=self.aws_secret_access_key,
            aws_session_token=self.aws_session_token,
            region_name=self.region_name,
        )

@@ -225,7 +236,7 @@ class BedrockCompletion(BaseLLM):
            tcp_keepalive=True,
        )

        self.client = session.client("bedrock-runtime", config=config)
        self._client = session.client("bedrock-runtime", config=config)

        self._is_claude_model = "claude" in self.model.lower()
        self._supports_tools = True
@@ -365,7 +376,7 @@ class BedrockCompletion(BaseLLM):
            raise ValueError(f"Invalid message format at index {i}")

        # Call Bedrock Converse API with proper error handling
        response = self.client.converse(
        response = self._client.converse(
            modelId=self.model_id,
            messages=cast(
                "Sequence[MessageTypeDef | MessageOutputTypeDef]",
@@ -536,13 +547,13 @@ class BedrockCompletion(BaseLLM):
        tool_use_id = None

        try:
            response = self.client.converse_stream(
            response = self._client.converse_stream(
                modelId=self.model_id,
                messages=cast(
                    "Sequence[MessageTypeDef | MessageOutputTypeDef]",
                    cast(object, messages),
                ),
                **body,
                **body,  # type: ignore[arg-type]
            )

            stream = response.get("stream")
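A side effect of the Bedrock field changes above: guardrail_config, additional_model_request_fields and additional_model_response_field_paths switch from Optional with default=None to concrete containers built by default_factory, so each instance gets its own empty container and call sites no longer need None checks. A tiny sketch of the idea, with a hypothetical class and key name:

    from typing import Any

    from pydantic import BaseModel, Field

    class GuardrailSketch(BaseModel):
        # default_factory=dict: typed as dict[str, Any] (never None), fresh dict per instance.
        guardrail_config: dict[str, Any] = Field(default_factory=dict)

    a, b = GuardrailSketch(), GuardrailSketch()
    a.guardrail_config["guardrailIdentifier"] = "example"  # hypothetical key
    assert b.guardrail_config == {}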
@@ -1,8 +1,11 @@
from __future__ import annotations

import logging
import os
from typing import TYPE_CHECKING, Any, ClassVar, cast
from typing import TYPE_CHECKING, Any, cast

from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
from dotenv import load_dotenv
from pydantic import BaseModel, Field, PrivateAttr, model_validator
from typing_extensions import Self

from crewai.events.types.llm_events import LLMCallType
@@ -31,6 +34,9 @@ except ImportError:
    ) from None

load_dotenv()

class GeminiCompletion(BaseLLM):
    """Google Gemini native completion implementation.

@@ -50,15 +56,12 @@ class GeminiCompletion(BaseLLM):
        interceptor: HTTP interceptor (not yet supported for Gemini)
    """

    model_config: ClassVar[ConfigDict] = ConfigDict(
        ignored_types=(property,), arbitrary_types_allowed=True
    )

    project: str | None = Field(
        default=None, description="Google Cloud project ID (for Vertex AI)"
        default_factory=lambda: os.getenv("GOOGLE_CLOUD_PROJECT"),
        description="Google Cloud project ID (for Vertex AI)",
    )
    location: str = Field(
        default="us-central1",
        default_factory=lambda: os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1"),
        description="Google Cloud location (for Vertex AI, defaults to 'us-central1')",
    )
    top_p: float | None = Field(default=None, description="Nucleus sampling parameter")
@@ -74,9 +77,7 @@ class GeminiCompletion(BaseLLM):
        default_factory=dict,
        description="Additional parameters for Google Gen AI Client constructor",
    )
    client: Any = Field(
        default=None, exclude=True, description="Gemini client instance"
    )
    _client: Any = PrivateAttr(default=None)

    _is_gemini_2: bool = PrivateAttr(default=False)
    _is_gemini_1_5: bool = PrivateAttr(default=False)
@@ -94,17 +95,9 @@ class GeminiCompletion(BaseLLM):
        if self.api_key is None:
            self.api_key = os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")

        if self.project is None:
            self.project = os.getenv("GOOGLE_CLOUD_PROJECT")

        if self.location == "us-central1":
            env_location = os.getenv("GOOGLE_CLOUD_LOCATION")
            if env_location:
                self.location = env_location

        use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"

        self.client = self._initialize_client(use_vertexai)
        self._client = self._initialize_client(use_vertexai)

        self._is_gemini_2 = "gemini-2" in self.model.lower()
        self._is_gemini_1_5 = "gemini-1.5" in self.model.lower()
@@ -176,9 +169,9 @@ class GeminiCompletion(BaseLLM):
        params = {}

        if (
            hasattr(self, "client")
            and hasattr(self.client, "vertexai")
            and self.client.vertexai
            hasattr(self, "_client")
            and hasattr(self._client, "vertexai")
            and self._client.vertexai
        ):
            params.update(
                {
@@ -400,7 +393,7 @@ class GeminiCompletion(BaseLLM):
        }

        try:
            response = self.client.models.generate_content(**api_params)
            response = self._client.models.generate_content(**api_params)

            usage = self._extract_token_usage(response)
        except Exception as e:
@@ -468,7 +461,7 @@ class GeminiCompletion(BaseLLM):
            "config": config,
        }

        for chunk in self.client.models.generate_content_stream(**api_params):
        for chunk in self._client.models.generate_content_stream(**api_params):
            if hasattr(chunk, "text") and chunk.text:
                full_response += chunk.text
                self._emit_stream_chunk_event(
@@ -6,6 +6,7 @@ import logging
import os
from typing import TYPE_CHECKING, Any

from dotenv import load_dotenv
import httpx
from openai import APIConnectionError, NotFoundError, OpenAI
from openai.types.chat import ChatCompletion, ChatCompletionChunk
@@ -32,6 +33,9 @@ if TYPE_CHECKING:
    from crewai.tools.base_tool import BaseTool

load_dotenv()

class OpenAICompletion(BaseLLM):
    """OpenAI native completion implementation.

@@ -40,43 +44,51 @@ class OpenAICompletion(BaseLLM):
    """

    # Client configuration fields
    organization: str | None = Field(None, description="OpenAI organization ID")
    project: str | None = Field(None, description="OpenAI project ID")
    max_retries: int = Field(2, description="Maximum number of retries")
    default_headers: dict[str, str] | None = Field(
        None, description="Default headers for requests"
    organization: str | None = Field(default=None, description="OpenAI organization ID")
    project: str | None = Field(default=None, description="OpenAI project ID")
    max_retries: int = Field(default=2, description="Maximum number of retries")
    default_headers: dict[str, str] = Field(
        default_factory=dict, description="Default headers for requests"
    )
    default_query: dict[str, Any] | None = Field(
        None, description="Default query parameters"
    default_query: dict[str, Any] = Field(
        default_factory=dict, description="Default query parameters"
    )
    client_params: dict[str, Any] | None = Field(
        None, description="Additional client parameters"
    client_params: dict[str, Any] = Field(
        default_factory=dict, description="Additional client parameters"
    )
    timeout: float | None = Field(default=None, description="Request timeout")
    api_base: str | None = Field(
        default=None, description="API base URL", deprecated=True
    )
    timeout: float | None = Field(None, description="Request timeout")
    api_base: str | None = Field(None, description="API base URL (deprecated)")

    # Completion parameters
    top_p: float | None = Field(None, description="Top-p sampling parameter")
    frequency_penalty: float | None = Field(None, description="Frequency penalty")
    presence_penalty: float | None = Field(None, description="Presence penalty")
    max_tokens: int | None = Field(None, description="Maximum tokens")
    top_p: float | None = Field(default=None, description="Top-p sampling parameter")
    frequency_penalty: float | None = Field(
        default=None, description="Frequency penalty"
    )
    presence_penalty: float | None = Field(default=None, description="Presence penalty")
    max_tokens: int | None = Field(default=None, description="Maximum tokens")
    max_completion_tokens: int | None = Field(
        None, description="Maximum completion tokens"
    )
    seed: int | None = Field(None, description="Random seed")
    stream: bool = Field(False, description="Enable streaming")
    seed: int | None = Field(default=None, description="Random seed")
    stream: bool = Field(default=False, description="Enable streaming")
    response_format: dict[str, Any] | type[BaseModel] | None = Field(
        None, description="Response format"
        default=None, description="Response format"
    )
    logprobs: bool | None = Field(None, description="Return log probabilities")
    logprobs: bool | None = Field(default=None, description="Return log probabilities")
    top_logprobs: int | None = Field(
        None, description="Number of top log probabilities"
        default=None, description="Number of top log probabilities"
    )
    reasoning_effort: str | None = Field(
        default=None, description="Reasoning effort level"
    )
    reasoning_effort: str | None = Field(None, description="Reasoning effort level")

    _client: OpenAI = PrivateAttr(default_factory=OpenAI)
    is_o1_model: bool = Field(False, description="Whether this is an O1 model")
    is_gpt4_model: bool = Field(False, description="Whether this is a GPT-4 model")
    _client: OpenAI = PrivateAttr(default=None)  # type: ignore[assignment]
    is_o1_model: bool = Field(default=False, description="Whether this is an O1 model")
    is_gpt4_model: bool = Field(
        default=False, description="Whether this is a GPT-4 model"
    )

    @model_validator(mode="after")
    def setup_client(self) -> Self:
@@ -97,10 +109,6 @@ class OpenAICompletion(BaseLLM):

    def _get_client_params(self) -> dict[str, Any]:
        """Get OpenAI client parameters."""

        if self.api_key is None:
            raise ValueError("OPENAI_API_KEY is required")

        base_params = {
            "api_key": self.api_key,
            "organization": self.organization,
@@ -1 +1,38 @@
"""LLM implementations for crewAI."""
"""LLM implementations for crewAI.

.. deprecated:: 1.4.0
    The `crewai.llms` package is deprecated. Use `crewai.llm` instead.

This package was reorganized from `crewai.llms.*` to `crewai.llm.*`.
All submodules are redirected to their new locations in `crewai.llm.*`.

Migration guide:
    Old: from crewai.llms.base_llm import BaseLLM
    New: from crewai.llm.base_llm import BaseLLM

    Old: from crewai.llms.hooks.base import BaseInterceptor
    New: from crewai.llm.hooks.base import BaseInterceptor

    Old: from crewai.llms.constants import OPENAI_MODELS
    New: from crewai.llm.constants import OPENAI_MODELS

Or use top-level imports:
    from crewai import LLM, BaseLLM
"""

import warnings

from crewai.llm import LLM
from crewai.llm.base_llm import BaseLLM

# Issue deprecation warning when this module is imported
warnings.warn(
    "The 'crewai.llms' package is deprecated and will be removed in a future version. "
    "Please use 'crewai.llm' (singular) instead. "
    "All submodules have been reorganized from 'crewai.llms.*' to 'crewai.llm.*'.",
    DeprecationWarning,
    stacklevel=2,
)

__all__ = ["LLM", "BaseLLM"]
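A quick sketch of how the shim behaves at import time, assuming crewai.llms has not already been imported in the process (old imports keep working but emit a DeprecationWarning):

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        from crewai.llms import BaseLLM  # old path, redirected by the shim

    assert any(issubclass(w.category, DeprecationWarning) for w in caught)

    from crewai.llm.base_llm import BaseLLM  # new, preferred path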
15  lib/crewai/src/crewai/llms/base_llm.py  Normal file
@@ -0,0 +1,15 @@
"""Deprecated: Use crewai.llm.base_llm instead.

.. deprecated:: 1.4.0
"""

import warnings

warnings.warn(
    "crewai.llms.base_llm is deprecated. Use crewai.llm.base_llm instead.",
    DeprecationWarning,
    stacklevel=2,
)

from crewai.llm.base_llm import *  # noqa: E402, F403

15  lib/crewai/src/crewai/llms/constants.py  Normal file
@@ -0,0 +1,15 @@
"""Deprecated: Use crewai.llm.constants instead.

.. deprecated:: 1.4.0
"""

import warnings

warnings.warn(
    "crewai.llms.constants is deprecated. Use crewai.llm.constants instead.",
    DeprecationWarning,
    stacklevel=2,
)

from crewai.llm.constants import *  # noqa: E402, F403

15  lib/crewai/src/crewai/llms/hooks/__init__.py  Normal file
@@ -0,0 +1,15 @@
"""Deprecated: Use crewai.llm.hooks instead.

.. deprecated:: 1.4.0
"""

import warnings

warnings.warn(
    "crewai.llms.hooks is deprecated. Use crewai.llm.hooks instead.",
    DeprecationWarning,
    stacklevel=2,
)

from crewai.llm.hooks import *  # noqa: E402, F403

15  lib/crewai/src/crewai/llms/hooks/base.py  Normal file
@@ -0,0 +1,15 @@
"""Deprecated: Use crewai.llm.hooks.base instead.

.. deprecated:: 1.4.0
"""

import warnings

warnings.warn(
    "crewai.llms.hooks.base is deprecated. Use crewai.llm.hooks.base instead.",
    DeprecationWarning,
    stacklevel=2,
)

from crewai.llm.hooks.base import *  # noqa: E402, F403

15  lib/crewai/src/crewai/llms/hooks/transport.py  Normal file
@@ -0,0 +1,15 @@
"""Deprecated: Use crewai.llm.hooks.transport instead.

.. deprecated:: 1.4.0
"""

import warnings

warnings.warn(
    "crewai.llms.hooks.transport is deprecated. Use crewai.llm.hooks.transport instead.",
    DeprecationWarning,
    stacklevel=2,
)

from crewai.llm.hooks.transport import *  # noqa: E402, F403

15  lib/crewai/src/crewai/llms/internal/__init__.py  Normal file
@@ -0,0 +1,15 @@
"""Deprecated: Use crewai.llm.internal instead.

.. deprecated:: 1.4.0
"""

import warnings

warnings.warn(
    "crewai.llms.internal is deprecated. Use crewai.llm.internal instead.",
    DeprecationWarning,
    stacklevel=2,
)

from crewai.llm.internal import *  # noqa: E402, F403

15  lib/crewai/src/crewai/llms/internal/constants.py  Normal file
@@ -0,0 +1,15 @@
"""Deprecated: Use crewai.llm.internal.constants instead.

.. deprecated:: 1.4.0
"""

import warnings

warnings.warn(
    "crewai.llms.internal.constants is deprecated. Use crewai.llm.internal.constants instead.",
    DeprecationWarning,
    stacklevel=2,
)

from crewai.llm.internal.constants import *  # noqa: E402, F403

15  lib/crewai/src/crewai/llms/providers/__init__.py  Normal file
@@ -0,0 +1,15 @@
"""Deprecated: Use crewai.llm.providers instead.

.. deprecated:: 1.4.0
"""

import warnings

warnings.warn(
    "crewai.llms.providers is deprecated. Use crewai.llm.providers instead.",
    DeprecationWarning,
    stacklevel=2,
)

from crewai.llm.providers import *  # noqa: E402, F403
@@ -60,7 +60,7 @@ def test_anthropic_tool_use_conversation_flow():
    available_functions = {"get_weather": mock_weather_tool}

    # Mock the Anthropic client responses
    with patch.object(completion.client.messages, 'create') as mock_create:
    with patch.object(completion._client.messages, 'create') as mock_create:
        # Mock initial response with tool use - need to properly mock ToolUseBlock
        mock_tool_use = Mock(spec=ToolUseBlock)
        mock_tool_use.id = "tool_123"
@@ -199,8 +199,8 @@ def test_anthropic_specific_parameters():
    assert isinstance(llm, AnthropicCompletion)
    assert llm.stop == ["Human:", "Assistant:"]
    assert llm.stream == True
    assert llm.client.max_retries == 5
    assert llm.client.timeout == 60
    assert llm._client.max_retries == 5
    assert llm._client.timeout == 60

def test_anthropic_completion_call():
@@ -637,8 +637,8 @@ def test_anthropic_environment_variable_api_key():
    with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-anthropic-key"}):
        llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")

        assert llm.client is not None
        assert hasattr(llm.client, 'messages')
        assert llm._client is not None
        assert hasattr(llm._client, 'messages')

def test_anthropic_token_usage_tracking():
@@ -648,7 +648,7 @@ def test_anthropic_token_usage_tracking():
    llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")

    # Mock the Anthropic response with usage information
    with patch.object(llm.client.messages, 'create') as mock_create:
    with patch.object(llm._client.messages, 'create') as mock_create:
        mock_response = MagicMock()
        mock_response.content = [MagicMock(text="test response")]
        mock_response.usage = MagicMock(input_tokens=50, output_tokens=25)
@@ -64,7 +64,7 @@ def test_azure_tool_use_conversation_flow():
    available_functions = {"get_weather": mock_weather_tool}

    # Mock the Azure client responses
    with patch.object(completion.client, 'complete') as mock_complete:
    with patch.object(completion._client, 'complete') as mock_complete:
        # Mock tool call in response with proper type
        mock_tool_call = MagicMock(spec=ChatCompletionsToolCall)
        mock_tool_call.function.name = "get_weather"
@@ -602,7 +602,7 @@ def test_azure_environment_variable_endpoint():
    }):
        llm = LLM(model="azure/gpt-4")

        assert llm.client is not None
        assert llm._client is not None
        assert llm.endpoint == "https://test.openai.azure.com/openai/deployments/gpt-4"

@@ -613,7 +613,7 @@ def test_azure_token_usage_tracking():
    llm = LLM(model="azure/gpt-4")

    # Mock the Azure response with usage information
    with patch.object(llm.client, 'complete') as mock_complete:
    with patch.object(llm._client, 'complete') as mock_complete:
        mock_message = MagicMock()
        mock_message.content = "test response"
        mock_message.tool_calls = None
@@ -651,7 +651,7 @@ def test_azure_http_error_handling():
    llm = LLM(model="azure/gpt-4")

    # Mock an HTTP error
    with patch.object(llm.client, 'complete') as mock_complete:
    with patch.object(llm._client, 'complete') as mock_complete:
        mock_complete.side_effect = HttpResponseError(message="Rate limit exceeded", response=MagicMock(status_code=429))

        with pytest.raises(HttpResponseError):
@@ -668,7 +668,7 @@ def test_azure_streaming_completion():
    llm = LLM(model="azure/gpt-4", stream=True)

    # Mock streaming response
    with patch.object(llm.client, 'complete') as mock_complete:
    with patch.object(llm._client, 'complete') as mock_complete:
        # Create mock streaming updates with proper type
        mock_updates = []
        for chunk in ["Hello", " ", "world", "!"]:
@@ -891,7 +891,7 @@ def test_azure_improved_error_messages():

    llm = LLM(model="azure/gpt-4")

    with patch.object(llm.client, 'complete') as mock_complete:
    with patch.object(llm._client, 'complete') as mock_complete:
        error_401 = HttpResponseError(message="Unauthorized")
        error_401.status_code = 401
        mock_complete.side_effect = error_401
@@ -579,7 +579,7 @@ def test_bedrock_token_usage_tracking():
    llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    # Mock the Bedrock response with usage information
    with patch.object(llm.client, 'converse') as mock_converse:
    with patch.object(llm._client, 'converse') as mock_converse:
        mock_response = {
            'output': {
                'message': {
@@ -624,7 +624,7 @@ def test_bedrock_tool_use_conversation_flow():
    available_functions = {"get_weather": mock_weather_tool}

    # Mock the Bedrock client responses
    with patch.object(llm.client, 'converse') as mock_converse:
    with patch.object(llm._client, 'converse') as mock_converse:
        # First response: tool use request
        tool_use_response = {
            'output': {
@@ -710,7 +710,7 @@ def test_bedrock_client_error_handling():
    llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

    # Test ValidationException
    with patch.object(llm.client, 'converse') as mock_converse:
    with patch.object(llm._client, 'converse') as mock_converse:
        error_response = {
            'Error': {
                'Code': 'ValidationException',
@@ -724,7 +724,7 @@ def test_bedrock_client_error_handling():
    assert "validation" in str(exc_info.value).lower()

    # Test ThrottlingException
    with patch.object(llm.client, 'converse') as mock_converse:
    with patch.object(llm._client, 'converse') as mock_converse:
        error_response = {
            'Error': {
                'Code': 'ThrottlingException',
@@ -762,7 +762,7 @@ def test_bedrock_stop_sequences_sent_to_api():
    llm.stop = ["\nObservation:", "\nThought:"]

    # Patch the API call to capture parameters without making real call
    with patch.object(llm.client, 'converse') as mock_converse:
    with patch.object(llm._client, 'converse') as mock_converse:
        mock_response = {
            'output': {
                'message': {
@@ -59,7 +59,7 @@ def test_gemini_tool_use_conversation_flow():
    available_functions = {"get_weather": mock_weather_tool}

    # Mock the Google Gemini client responses
    with patch.object(completion.client.models, 'generate_content') as mock_generate:
    with patch.object(completion._client.models, 'generate_content') as mock_generate:
        # Mock function call in response
        mock_function_call = Mock()
        mock_function_call.name = "get_weather"
@@ -614,8 +614,8 @@ def test_gemini_environment_variable_api_key():
    with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-google-key"}):
        llm = LLM(model="google/gemini-2.0-flash-001")

        assert llm.client is not None
        assert hasattr(llm.client, 'models')
        assert llm._client is not None
        assert hasattr(llm._client, 'models')
        assert llm.api_key == "test-google-key"

@@ -626,7 +626,7 @@ def test_gemini_token_usage_tracking():
    llm = LLM(model="google/gemini-2.0-flash-001")

    # Mock the Gemini response with usage information
    with patch.object(llm.client.models, 'generate_content') as mock_generate:
    with patch.object(llm._client.models, 'generate_content') as mock_generate:
        mock_response = MagicMock()
        mock_response.text = "test response"
        mock_response.candidates = []
@@ -675,7 +675,7 @@ def test_gemini_stop_sequences_sent_to_api():
    llm.stop = ["\nObservation:", "\nThought:"]

    # Patch the API call to capture parameters without making real call
    with patch.object(llm.client.models, 'generate_content') as mock_generate:
    with patch.object(llm._client.models, 'generate_content') as mock_generate:
        mock_response = MagicMock()
        mock_response.text = "Hello"
        mock_response.candidates = []
@@ -369,11 +369,11 @@ def test_openai_client_setup_with_extra_arguments():
    assert llm.top_p == 0.5

    # Check that client parameters are properly configured
    assert llm.client.max_retries == 3
    assert llm.client.timeout == 30
    assert llm._client.max_retries == 3
    assert llm._client.timeout == 30

    # Test that parameters are properly used in API calls
    with patch.object(llm.client.chat.completions, 'create') as mock_create:
    with patch.object(llm._client.chat.completions, 'create') as mock_create:
        mock_create.return_value = MagicMock(
            choices=[MagicMock(message=MagicMock(content="test response", tool_calls=None))],
            usage=MagicMock(prompt_tokens=10, completion_tokens=20, total_tokens=30)
@@ -394,7 +394,7 @@ def test_extra_arguments_are_passed_to_openai_completion():
    """
    llm = LLM(model="gpt-4o", temperature=0.7, max_tokens=1000, top_p=0.5, max_retries=3)

    with patch.object(llm.client.chat.completions, 'create') as mock_create:
    with patch.object(llm._client.chat.completions, 'create') as mock_create:
        mock_create.return_value = MagicMock(
            choices=[MagicMock(message=MagicMock(content="test response", tool_calls=None))],
            usage=MagicMock(prompt_tokens=10, completion_tokens=20, total_tokens=30)
@@ -501,7 +501,7 @@ def test_openai_streaming_with_response_model():

    llm = LLM(model="openai/gpt-4o", stream=True)

    with patch.object(llm.client.chat.completions, "create") as mock_create:
    with patch.object(llm._client.chat.completions, "create") as mock_create:
        mock_chunk1 = MagicMock()
        mock_chunk1.choices = [
            MagicMock(delta=MagicMock(content='{"answer": "test", ', tool_calls=None))