Mirror of https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 07:38:29 +00:00

chore: continue refactoring llms to base models
@@ -8,8 +8,8 @@ from crewai.crew import Crew
from crewai.crews.crew_output import CrewOutput
from crewai.flow.flow import Flow
from crewai.knowledge.knowledge import Knowledge
from crewai.llm import LLM
from crewai.llm.base_llm import BaseLLM
from crewai.llm.core import LLM
from crewai.process import Process
from crewai.task import Task
from crewai.tasks.llm_guardrail import LLMGuardrail

@@ -14,7 +14,8 @@ import tomli
from crewai.cli.utils import read_toml
from crewai.cli.version import get_crewai_version
from crewai.crew import Crew
from crewai.llm import LLM, BaseLLM
from crewai.llm import LLM
from crewai.llm.base_llm import BaseLLM
from crewai.types.crew_chat import ChatInputField, ChatInputs
from crewai.utilities.llm_utils import create_llm
from crewai.utilities.printer import Printer

@@ -56,8 +56,8 @@ from crewai.events.types.crew_events import (
from crewai.flow.flow_trackable import FlowTrackable
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.llm import LLM
from crewai.llm.base_llm import BaseLLM
from crewai.llm.core import LLM
from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.external.external_memory import ExternalMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory

@@ -89,7 +89,7 @@ from crewai.events.types.tool_usage_events import (
ToolUsageStartedEvent,
)
from crewai.events.utils.console_formatter import ConsoleFormatter
from crewai.llm import LLM
from crewai.llm.core import LLM
from crewai.task import Task
from crewai.telemetry.telemetry import Telemetry
from crewai.utilities import Logger

@@ -7,7 +7,7 @@ from pydantic import BaseModel, Field

from crewai.agent import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.llm import BaseLLM
from crewai.llm.base_llm import BaseLLM
from crewai.task import Task
from crewai.utilities.llm_utils import create_llm

@@ -39,8 +39,8 @@ from crewai.events.types.agent_events import (
from crewai.events.types.logging_events import AgentLogsExecutionEvent
from crewai.flow.flow_trackable import FlowTrackable
from crewai.lite_agent_output import LiteAgentOutput
from crewai.llm import LLM
from crewai.llm.base_llm import BaseLLM
from crewai.llm.core import LLM
from crewai.tools.base_tool import BaseTool
from crewai.tools.structured_tool import CrewStructuredTool
from crewai.utilities.agent_utils import (
@@ -66,7 +66,7 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
"""

model_config: ClassVar[ConfigDict] = ConfigDict(
arbitrary_types_allowed=True, extra="allow", validate_assignment=True
arbitrary_types_allowed=True, extra="allow"
)

# Core fields

@@ -80,7 +80,9 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
default="openai", description="Provider name (openai, anthropic, etc.)"
)
stop: list[str] = Field(
default_factory=list, description="Stop sequences for generation"
default_factory=list,
description="Stop sequences for generation",
validation_alias="stop_sequences",
)

# Internal fields

@@ -112,16 +114,18 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
if not values.get("model"):
raise ValueError("Model name is required and cannot be empty")

# Handle stop sequences
stop = values.get("stop")
stop = values.get("stop") or values.get("stop_sequences")
if stop is None:
values["stop"] = []
elif isinstance(stop, str):
values["stop"] = [stop]
elif not isinstance(stop, list):
elif isinstance(stop, list):
values["stop"] = stop
else:
values["stop"] = []

# Set default provider if not specified
values.pop("stop_sequences", None)

if "provider" not in values or values["provider"] is None:
values["provider"] = "openai"
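
Note: with the validation_alias on stop and the reworked validator above, stop and stop_sequences should land in the same list-valued field. A minimal sketch of the intended normalization, restated outside pydantic for clarity (illustrative only, not part of the diff):

    # Mirrors the branches of the model validator shown above.
    def normalize_stop(values: dict) -> dict:
        stop = values.get("stop") or values.get("stop_sequences")
        if stop is None:
            values["stop"] = []        # nothing supplied -> empty list
        elif isinstance(stop, str):
            values["stop"] = [stop]    # bare string -> single-element list
        elif isinstance(stop, list):
            values["stop"] = stop      # list passes through unchanged
        else:
            values["stop"] = []        # anything else is discarded
        values.pop("stop_sequences", None)  # alias is consumed, not stored twice
        return values

    assert normalize_stop({"stop_sequences": "###"})["stop"] == ["###"]
    assert normalize_stop({"stop": ["\nObservation:"]})["stop"] == ["\nObservation:"]
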
@@ -33,13 +33,12 @@ class LLMMeta(ModelMetaclass):
native provider implementation based on the model parameter.
"""

def __call__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> Any:  # noqa: N805
def __call__(cls, *args: Any, **kwargs: Any) -> Any:  # noqa: N805
"""Route to appropriate provider implementation at instantiation time.

Args:
model: The model identifier (e.g., "gpt-4", "claude-3-opus")
is_litellm: Force use of LiteLLM instead of native provider
**kwargs: Additional parameters for the LLM
*args: Positional arguments (model should be first for LLM class)
**kwargs: Keyword arguments including model, is_litellm, etc.

Returns:
Instance of the appropriate provider class or LLM class

@@ -47,18 +46,18 @@ class LLMMeta(ModelMetaclass):
Raises:
ValueError: If model is not a valid string
"""
if cls.__name__ != "LLM":
return super().__call__(*args, **kwargs)

model = kwargs.get("model") or (args[0] if args else None)
is_litellm = kwargs.get("is_litellm", False)

if not model or not isinstance(model, str):
raise ValueError("Model must be a non-empty string")

# Only perform routing if called on the base LLM class
# Subclasses (OpenAICompletion, etc.) should create normally
from crewai.llm import LLM

if cls is not LLM:
# Direct instantiation of provider class, skip routing
return super().__call__(model=model, **kwargs)

# Extract provider information
if args and not kwargs.get("model"):
kwargs["model"] = args[0]
args = args[1:]
explicit_provider = kwargs.get("provider")

if explicit_provider:

@@ -97,12 +96,10 @@ class LLMMeta(ModelMetaclass):
use_native = True
model_string = model

# Route to native provider if available
native_class = cls._get_native_provider(provider) if use_native else None
if native_class and not is_litellm and provider in SUPPORTED_NATIVE_PROVIDERS:
try:
# Remove 'provider' from kwargs to avoid duplicate keyword argument
kwargs_copy = {k: v for k, v in kwargs.items() if k != "provider"}
kwargs_copy = {k: v for k, v in kwargs.items() if k not in ("provider", "model")}
return native_class(
model=model_string, provider=provider, **kwargs_copy
)

@@ -111,15 +108,14 @@ class LLMMeta(ModelMetaclass):
except Exception as e:
raise ImportError(f"Error importing native provider: {e}") from e

# Fallback to LiteLLM
try:
import litellm  # noqa: F401
except ImportError:
logging.error("LiteLLM is not available, falling back to LiteLLM")
raise ImportError("Fallback to LiteLLM is not available") from None

# Create actual LLM instance with is_litellm=True
return super().__call__(model=model, is_litellm=True, **kwargs)
kwargs_copy = {k: v for k, v in kwargs.items() if k not in ("model", "is_litellm")}
return super().__call__(model=model, is_litellm=True, **kwargs_copy)

@staticmethod
def _validate_model_in_constants(model: str, provider: str) -> bool:
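
Note: after this change the metaclass pulls model and is_litellm out of whatever reaches LLM(...), so positional and keyword forms route the same way. A hedged sketch of the intended dispatch (which concrete class comes back depends on which provider SDKs are installed):

    # Provider inferred from the model string; with the OpenAI SDK installed this
    # is expected to return the native OpenAICompletion rather than a LiteLLM-backed LLM.
    gpt = LLM("gpt-4o")

    # An explicit provider keyword takes precedence over inference.
    claude = LLM(model="claude-3-5-sonnet-20241022", provider="anthropic")

    # Forcing LiteLLM (or using a provider outside SUPPORTED_NATIVE_PROVIDERS,
    # as in the groq test further down) falls through to the LiteLLM-backed LLM class.
    groq = LLM("groq/llama-3.1-70b-versatile", is_litellm=True)
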
@@ -3,9 +3,10 @@ from __future__ import annotations
import json
import logging
import os
from typing import TYPE_CHECKING, Any, ClassVar, cast
from typing import Any, ClassVar, cast

from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
from typing_extensions import Self

from crewai.events.types.llm_events import LLMCallType
from crewai.llm.base_llm import BaseLLM

@@ -19,9 +20,6 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
from crewai.utilities.types import LLMMessage


if TYPE_CHECKING:
from crewai.llm.hooks.base import BaseInterceptor

try:
from anthropic import Anthropic
from anthropic.types import Message

@@ -38,90 +36,67 @@ class AnthropicCompletion(BaseLLM):

This class provides direct integration with the Anthropic Python SDK,
offering native tool use, streaming support, and proper message formatting.

Attributes:
model: Anthropic model name (e.g., 'claude-3-5-sonnet-20241022')
base_url: Custom base URL for Anthropic API
timeout: Request timeout in seconds
max_retries: Maximum number of retries
max_tokens: Maximum tokens in response (required for Anthropic)
top_p: Nucleus sampling parameter
stream: Enable streaming responses
client_params: Additional parameters for the Anthropic client
interceptor: HTTP interceptor for modifying requests/responses at transport level
"""

model_config: ClassVar[ConfigDict] = ConfigDict(ignored_types=(property,))
model_config: ClassVar[ConfigDict] = ConfigDict(
ignored_types=(property,), arbitrary_types_allowed=True
)

def __init__(
self,
model: str = "claude-3-5-sonnet-20241022",
api_key: str | None = None,
base_url: str | None = None,
timeout: float | None = None,
max_retries: int = 2,
temperature: float | None = None,
max_tokens: int = 4096,  # Required for Anthropic
top_p: float | None = None,
stop_sequences: list[str] | None = None,
stream: bool = False,
client_params: dict[str, Any] | None = None,
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
**kwargs: Any,
):
"""Initialize Anthropic chat completion client.
base_url: str | None = Field(
default=None, description="Custom base URL for Anthropic API"
)
timeout: float | None = Field(
default=None, description="Request timeout in seconds"
)
max_retries: int = Field(default=2, description="Maximum number of retries")
max_tokens: int = Field(
default=4096, description="Maximum tokens in response (required for Anthropic)"
)
top_p: float | None = Field(default=None, description="Nucleus sampling parameter")
stream: bool = Field(default=False, description="Enable streaming responses")
client_params: dict[str, Any] | None = Field(
default=None, description="Additional Anthropic client parameters"
)
interceptor: Any = Field(
default=None, description="HTTP interceptor for request/response modification"
)
client: Any = Field(
default=None, exclude=True, description="Anthropic client instance"
)

Args:
model: Anthropic model name (e.g., 'claude-3-5-sonnet-20241022')
api_key: Anthropic API key (defaults to ANTHROPIC_API_KEY env var)
base_url: Custom base URL for Anthropic API
timeout: Request timeout in seconds
max_retries: Maximum number of retries
temperature: Sampling temperature (0-1)
max_tokens: Maximum tokens in response (required for Anthropic)
top_p: Nucleus sampling parameter
stop_sequences: Stop sequences (Anthropic uses stop_sequences, not stop)
stream: Enable streaming responses
client_params: Additional parameters for the Anthropic client
interceptor: HTTP interceptor for modifying requests/responses at transport level.
**kwargs: Additional parameters
"""
super().__init__(
model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
)

# Client params
self.interceptor = interceptor
self.client_params = client_params
self.base_url = base_url
self.timeout = timeout
self.max_retries = max_retries
_is_claude_3: bool = PrivateAttr(default=False)
_supports_tools: bool = PrivateAttr(default=False)

@model_validator(mode="after")
def setup_client(self) -> Self:
"""Initialize the Anthropic client and model-specific settings."""
self.client = Anthropic(**self._get_client_params())

# Store completion parameters
self.max_tokens = max_tokens
self.top_p = top_p
self.stream = stream
self.stop_sequences = stop_sequences or []
self._is_claude_3 = "claude-3" in self.model.lower()
self._supports_tools = self._is_claude_3

# Model-specific settings
self.is_claude_3 = "claude-3" in model.lower()
self.supports_tools = self.is_claude_3  # Claude 3+ supports tool use
return self

#
# @property
# def stop(self) -> list[str]:  # type: ignore[misc]
#     """Get stop sequences sent to the API."""
#     return self.stop_sequences
@property
def is_claude_3(self) -> bool:
"""Check if model is Claude 3."""
return self._is_claude_3

# @stop.setter
# def stop(self, value: list[str] | str | None) -> None:
#     """Set stop sequences.
#
#     Synchronizes stop_sequences to ensure values set by CrewAgentExecutor
#     are properly sent to the Anthropic API.
#
#     Args:
#         value: Stop sequences as a list, single string, or None
#     """
#     if value is None:
#         self.stop_sequences = []
#     elif isinstance(value, str):
#         self.stop_sequences = [value]
#     elif isinstance(value, list):
#         self.stop_sequences = value
#     else:
#         self.stop_sequences = []
@property
def supports_tools(self) -> bool:
"""Check if model supports tools."""
return self._supports_tools

def _get_client_params(self) -> dict[str, Any]:
"""Get client parameters."""

@@ -250,8 +225,8 @@ class AnthropicCompletion(BaseLLM):
params["temperature"] = self.temperature
if self.top_p is not None:
params["top_p"] = self.top_p
if self.stop_sequences:
params["stop_sequences"] = self.stop_sequences
if self.stop:
params["stop_sequences"] = self.stop

# Handle tools for Claude 3+
if tools and self.supports_tools:
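
Note: with the constructor replaced by pydantic fields, configuration that used to flow through __init__ keyword arguments now flows through field validation, and setup_client runs as an after-validator. A hedged usage sketch (assumes the anthropic SDK is installed and ANTHROPIC_API_KEY is set):

    llm = AnthropicCompletion(
        model="claude-3-5-sonnet-20241022",
        max_tokens=2048,
        stop_sequences=["\nObservation:"],  # alias feeds the base-class `stop` field
    )
    # The after-validator has already run, so:
    # llm.client         -> Anthropic(...) instance
    # llm.stop           -> ["\nObservation:"], later sent as stop_sequences (hunk above)
    # llm.supports_tools -> True for claude-3 models
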
@@ -3,9 +3,10 @@ from __future__ import annotations
import json
import logging
import os
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, ClassVar

from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
from typing_extensions import Self

from crewai.llm.core import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
from crewai.llm.providers.utils.common import safe_tool_conversion

@@ -17,7 +18,6 @@ from crewai.utilities.types import LLMMessage


if TYPE_CHECKING:
from crewai.llm.hooks.base import BaseInterceptor
from crewai.tools.base_tool import BaseTool


@@ -51,65 +51,77 @@ class AzureCompletion(BaseLLM):

This class provides direct integration with the Azure AI Inference Python SDK,
offering native function calling, streaming support, and proper Azure authentication.

Attributes:
model: Azure deployment name or model name
endpoint: Azure endpoint URL
api_version: Azure API version
timeout: Request timeout in seconds
max_retries: Maximum number of retries
top_p: Nucleus sampling parameter
frequency_penalty: Frequency penalty (-2 to 2)
presence_penalty: Presence penalty (-2 to 2)
max_tokens: Maximum tokens in response
stream: Enable streaming responses
interceptor: HTTP interceptor (not yet supported for Azure)
"""

def __init__(
self,
model: str,
api_key: str | None = None,
endpoint: str | None = None,
api_version: str | None = None,
timeout: float | None = None,
max_retries: int = 2,
temperature: float | None = None,
top_p: float | None = None,
frequency_penalty: float | None = None,
presence_penalty: float | None = None,
max_tokens: int | None = None,
stop: list[str] | None = None,
stream: bool = False,
interceptor: BaseInterceptor[Any, Any] | None = None,
**kwargs: Any,
):
"""Initialize Azure AI Inference chat completion client.
model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)

Args:
model: Azure deployment name or model name
api_key: Azure API key (defaults to AZURE_API_KEY env var)
endpoint: Azure endpoint URL (defaults to AZURE_ENDPOINT env var)
api_version: Azure API version (defaults to AZURE_API_VERSION env var)
timeout: Request timeout in seconds
max_retries: Maximum number of retries
temperature: Sampling temperature (0-2)
top_p: Nucleus sampling parameter
frequency_penalty: Frequency penalty (-2 to 2)
presence_penalty: Presence penalty (-2 to 2)
max_tokens: Maximum tokens in response
stop: Stop sequences
stream: Enable streaming responses
interceptor: HTTP interceptor (not yet supported for Azure).
**kwargs: Additional parameters
"""
if interceptor is not None:
endpoint: str | None = Field(
default=None,
description="Azure endpoint URL (defaults to AZURE_ENDPOINT env var)",
)
api_version: str = Field(
default="2024-06-01",
description="Azure API version (defaults to AZURE_API_VERSION env var or 2024-06-01)",
)
timeout: float | None = Field(
default=None, description="Request timeout in seconds"
)
max_retries: int = Field(default=2, description="Maximum number of retries")
top_p: float | None = Field(default=None, description="Nucleus sampling parameter")
frequency_penalty: float | None = Field(
default=None, description="Frequency penalty (-2 to 2)"
)
presence_penalty: float | None = Field(
default=None, description="Presence penalty (-2 to 2)"
)
max_tokens: int | None = Field(
default=None, description="Maximum tokens in response"
)
stream: bool = Field(default=False, description="Enable streaming responses")
interceptor: Any = Field(
default=None, description="HTTP interceptor (not yet supported for Azure)"
)
client: Any = Field(default=None, exclude=True, description="Azure client instance")

_is_openai_model: bool = PrivateAttr(default=False)
_is_azure_openai_endpoint: bool = PrivateAttr(default=False)

@model_validator(mode="after")
def setup_client(self) -> Self:
"""Initialize the Azure client and validate configuration."""
if self.interceptor is not None:
raise NotImplementedError(
"HTTP interceptors are not yet supported for Azure AI Inference provider. "
"Interceptors are currently supported for OpenAI and Anthropic providers only."
)

super().__init__(
model=model, temperature=temperature, stop=stop or [], **kwargs
)
if self.api_key is None:
self.api_key = os.getenv("AZURE_API_KEY")

self.api_key = api_key or os.getenv("AZURE_API_KEY")
self.endpoint = (
endpoint
or os.getenv("AZURE_ENDPOINT")
or os.getenv("AZURE_OPENAI_ENDPOINT")
or os.getenv("AZURE_API_BASE")
)
self.api_version = api_version or os.getenv("AZURE_API_VERSION") or "2024-06-01"
self.timeout = timeout
self.max_retries = max_retries
if self.endpoint is None:
self.endpoint = (
os.getenv("AZURE_ENDPOINT")
or os.getenv("AZURE_OPENAI_ENDPOINT")
or os.getenv("AZURE_API_BASE")
)

if self.api_version == "2024-06-01":
env_version = os.getenv("AZURE_API_VERSION")
if env_version:
self.api_version = env_version

if not self.api_key:
raise ValueError(

@@ -120,36 +132,38 @@ class AzureCompletion(BaseLLM):
"Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter."
)

# Validate and potentially fix Azure OpenAI endpoint URL
self.endpoint = self._validate_and_fix_endpoint(self.endpoint, model)
self.endpoint = self._validate_and_fix_endpoint(self.endpoint, self.model)

# Build client kwargs
client_kwargs = {
client_kwargs: dict[str, Any] = {
"endpoint": self.endpoint,
"credential": AzureKeyCredential(self.api_key),
}

# Add api_version if specified (primarily for Azure OpenAI endpoints)
if self.api_version:
client_kwargs["api_version"] = self.api_version

self.client = ChatCompletionsClient(**client_kwargs)  # type: ignore[arg-type]
self.client = ChatCompletionsClient(**client_kwargs)

self.top_p = top_p
self.frequency_penalty = frequency_penalty
self.presence_penalty = presence_penalty
self.max_tokens = max_tokens
self.stream = stream

self.is_openai_model = any(
prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
self._is_openai_model = any(
prefix in self.model.lower() for prefix in ["gpt-", "o1-", "text-"]
)

self.is_azure_openai_endpoint = (
self._is_azure_openai_endpoint = (
"openai.azure.com" in self.endpoint
and "/openai/deployments/" in self.endpoint
)

return self

@property
def is_openai_model(self) -> bool:
"""Check if model is an OpenAI model."""
return self._is_openai_model

@property
def is_azure_openai_endpoint(self) -> bool:
"""Check if endpoint is an Azure OpenAI endpoint."""
return self._is_azure_openai_endpoint

def _validate_and_fix_endpoint(self, endpoint: str, model: str) -> str:
"""Validate and fix Azure endpoint URL format.
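
Note: endpoint, API version and key resolution now happen inside setup_client, falling back to the environment variables named above. A hedged sketch of the minimal configuration (the endpoint value is a placeholder, not a real resource):

    import os

    os.environ["AZURE_API_KEY"] = "<azure-api-key>"  # placeholder
    os.environ["AZURE_ENDPOINT"] = (
        "https://<resource>.openai.azure.com/openai/deployments/<deployment>"
    )

    # Field defaults plus the validator above fill in endpoint, api_version and the
    # ChatCompletionsClient; only the deployment/model name is required here.
    llm = AzureCompletion(model="gpt-4o")
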
@@ -5,8 +5,8 @@ import logging
import os
from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast

from pydantic import BaseModel, ConfigDict
from typing_extensions import Required
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
from typing_extensions import Required, Self

from crewai.events.types.llm_events import LLMCallType
from crewai.llm.base_llm import BaseLLM

@@ -32,8 +32,6 @@ if TYPE_CHECKING:
ToolTypeDef,
)

from crewai.llm.hooks.base import BaseInterceptor


try:
from boto3.session import Session

@@ -143,76 +141,86 @@ class BedrockCompletion(BaseLLM):
- Complete streaming event handling (messageStart, contentBlockStart, etc.)
- Response metadata and trace information capture
- Model-specific conversation format handling (e.g., Cohere requirements)

Attributes:
model: The Bedrock model ID to use
aws_access_key_id: AWS access key (defaults to environment variable)
aws_secret_access_key: AWS secret key (defaults to environment variable)
aws_session_token: AWS session token for temporary credentials
region_name: AWS region name
max_tokens: Maximum tokens to generate
top_p: Nucleus sampling parameter
top_k: Top-k sampling parameter (Claude models only)
stop_sequences: List of sequences that stop generation
stream: Whether to use streaming responses
guardrail_config: Guardrail configuration for content filtering
additional_model_request_fields: Model-specific request parameters
additional_model_response_field_paths: Custom response field paths
interceptor: HTTP interceptor (not yet supported for Bedrock)
"""

model_config: ClassVar[ConfigDict] = ConfigDict(ignored_types=(property,))
model_config: ClassVar[ConfigDict] = ConfigDict(
ignored_types=(property,), arbitrary_types_allowed=True
)

def __init__(
self,
model: str = "anthropic.claude-3-5-sonnet-20241022-v2:0",
aws_access_key_id: str | None = None,
aws_secret_access_key: str | None = None,
aws_session_token: str | None = None,
region_name: str = "us-east-1",
temperature: float | None = None,
max_tokens: int | None = None,
top_p: float | None = None,
top_k: int | None = None,
stop_sequences: Sequence[str] | None = None,
stream: bool = False,
guardrail_config: dict[str, Any] | None = None,
additional_model_request_fields: dict[str, Any] | None = None,
additional_model_response_field_paths: list[str] | None = None,
interceptor: BaseInterceptor[Any, Any] | None = None,
**kwargs: Any,
) -> None:
"""Initialize AWS Bedrock completion client.
aws_access_key_id: str | None = Field(
default=None, description="AWS access key (defaults to environment variable)"
)
aws_secret_access_key: str | None = Field(
default=None, description="AWS secret key (defaults to environment variable)"
)
aws_session_token: str | None = Field(
default=None, description="AWS session token for temporary credentials"
)
region_name: str = Field(default="us-east-1", description="AWS region name")
max_tokens: int | None = Field(
default=None, description="Maximum tokens to generate"
)
top_p: float | None = Field(default=None, description="Nucleus sampling parameter")
top_k: int | None = Field(
default=None, description="Top-k sampling parameter (Claude models only)"
)
stream: bool = Field(
default=False, description="Whether to use streaming responses"
)
guardrail_config: dict[str, Any] | None = Field(
default=None, description="Guardrail configuration for content filtering"
)
additional_model_request_fields: dict[str, Any] | None = Field(
default=None, description="Model-specific request parameters"
)
additional_model_response_field_paths: list[str] | None = Field(
default=None, description="Custom response field paths"
)
interceptor: Any = Field(
default=None, description="HTTP interceptor (not yet supported for Bedrock)"
)
client: Any = Field(
default=None, exclude=True, description="Bedrock client instance"
)

Args:
model: The Bedrock model ID to use
aws_access_key_id: AWS access key (defaults to environment variable)
aws_secret_access_key: AWS secret key (defaults to environment variable)
aws_session_token: AWS session token for temporary credentials
region_name: AWS region name
temperature: Sampling temperature for response generation
max_tokens: Maximum tokens to generate
top_p: Nucleus sampling parameter
top_k: Top-k sampling parameter (Claude models only)
stop_sequences: List of sequences that stop generation
stream: Whether to use streaming responses
guardrail_config: Guardrail configuration for content filtering
additional_model_request_fields: Model-specific request parameters
additional_model_response_field_paths: Custom response field paths
interceptor: HTTP interceptor (not yet supported for Bedrock).
**kwargs: Additional parameters
"""
if interceptor is not None:
_is_claude_model: bool = PrivateAttr(default=False)
_supports_tools: bool = PrivateAttr(default=True)
_supports_streaming: bool = PrivateAttr(default=True)
_model_id: str = PrivateAttr()

@model_validator(mode="after")
def setup_client(self) -> Self:
"""Initialize the Bedrock client and validate configuration."""
if self.interceptor is not None:
raise NotImplementedError(
"HTTP interceptors are not yet supported for AWS Bedrock provider. "
"Interceptors are currently supported for OpenAI and Anthropic providers only."
)

# Extract provider from kwargs to avoid duplicate argument
kwargs.pop("provider", None)

super().__init__(
model=model,
temperature=temperature,
stop=stop_sequences or [],
provider="bedrock",
**kwargs,
)

# Initialize Bedrock client with proper configuration
session = Session(
aws_access_key_id=aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"),
aws_secret_access_key=aws_secret_access_key
aws_access_key_id=self.aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"),
aws_secret_access_key=self.aws_secret_access_key
or os.getenv("AWS_SECRET_ACCESS_KEY"),
aws_session_token=aws_session_token or os.getenv("AWS_SESSION_TOKEN"),
region_name=region_name,
aws_session_token=self.aws_session_token or os.getenv("AWS_SESSION_TOKEN"),
region_name=self.region_name,
)

# Configure client with timeouts and retries following AWS best practices
config = Config(
read_timeout=300,
retries={

@@ -223,53 +231,33 @@ class BedrockCompletion(BaseLLM):
)

self.client = session.client("bedrock-runtime", config=config)
self.region_name = region_name

# Store completion parameters
self.max_tokens = max_tokens
self.top_p = top_p
self.top_k = top_k
self.stream = stream
self.stop_sequences = stop_sequences or []
self._is_claude_model = "claude" in self.model.lower()
self._supports_tools = True
self._supports_streaming = True
self._model_id = self.model

# Store advanced features (optional)
self.guardrail_config = guardrail_config
self.additional_model_request_fields = additional_model_request_fields
self.additional_model_response_field_paths = (
additional_model_response_field_paths
)
return self

# Model-specific settings
self.is_claude_model = "claude" in model.lower()
self.supports_tools = True  # Converse API supports tools for most models
self.supports_streaming = True
@property
def is_claude_model(self) -> bool:
"""Check if model is a Claude model."""
return self._is_claude_model

# Handle inference profiles for newer models
self.model_id = model
@property
def supports_tools(self) -> bool:
"""Check if model supports tools."""
return self._supports_tools

# @property
# def stop(self) -> list[str]:  # type: ignore[misc]
#     """Get stop sequences sent to the API."""
#     return list(self.stop_sequences)
@property
def supports_streaming(self) -> bool:
"""Check if model supports streaming."""
return self._supports_streaming

# @stop.setter
# def stop(self, value: Sequence[str] | str | None) -> None:
#     """Set stop sequences.
#
#     Synchronizes stop_sequences to ensure values set by CrewAgentExecutor
#     are properly sent to the Bedrock API.
#
#     Args:
#         value: Stop sequences as a Sequence, single string, or None
#     """
#     if value is None:
#         self.stop_sequences = []
#     elif isinstance(value, str):
#         self.stop_sequences = [value]
#     elif isinstance(value, Sequence):
#         self.stop_sequences = list(value)
#     else:
#         self.stop_sequences = []
@property
def model_id(self) -> str:
"""Get the model ID."""
return self._model_id

def call(
self,

@@ -559,7 +547,7 @@ class BedrockCompletion(BaseLLM):
"Sequence[MessageTypeDef | MessageOutputTypeDef]",
cast(object, messages),
),
**body,  # type: ignore[arg-type]
**body,
)

stream = response.get("stream")

@@ -821,8 +809,8 @@ class BedrockCompletion(BaseLLM):
config["temperature"] = float(self.temperature)
if self.top_p is not None:
config["topP"] = float(self.top_p)
if self.stop_sequences:
config["stopSequences"] = self.stop_sequences
if self.stop:
config["stopSequences"] = self.stop

if self.is_claude_model and self.top_k is not None:
# top_k is supported by Claude models
|
||||
import os
|
||||
from typing import Any, ClassVar, cast
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
from crewai.llm.base_llm import BaseLLM
|
||||
from crewai.llm.core import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
|
||||
from crewai.llm.hooks.base import BaseInterceptor
|
||||
from crewai.llm.providers.utils.common import safe_tool_conversion
|
||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
@@ -31,108 +31,124 @@ class GeminiCompletion(BaseLLM):
|
||||
|
||||
This class provides direct integration with the Google Gen AI Python SDK,
|
||||
offering native function calling, streaming support, and proper Gemini formatting.
|
||||
|
||||
Attributes:
|
||||
model: Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro')
|
||||
project: Google Cloud project ID (for Vertex AI)
|
||||
location: Google Cloud location (for Vertex AI, defaults to 'us-central1')
|
||||
top_p: Nucleus sampling parameter
|
||||
top_k: Top-k sampling parameter
|
||||
max_output_tokens: Maximum tokens in response
|
||||
stop_sequences: Stop sequences
|
||||
stream: Enable streaming responses
|
||||
safety_settings: Safety filter settings
|
||||
client_params: Additional parameters for Google Gen AI Client constructor
|
||||
interceptor: HTTP interceptor (not yet supported for Gemini)
|
||||
"""
|
||||
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(ignored_types=(property,))
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(
|
||||
ignored_types=(property,), arbitrary_types_allowed=True
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "gemini-2.0-flash-001",
|
||||
api_key: str | None = None,
|
||||
project: str | None = None,
|
||||
location: str | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
top_k: int | None = None,
|
||||
max_output_tokens: int | None = None,
|
||||
stop_sequences: list[str] | None = None,
|
||||
stream: bool = False,
|
||||
safety_settings: dict[str, Any] | None = None,
|
||||
client_params: dict[str, Any] | None = None,
|
||||
interceptor: BaseInterceptor[Any, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize Google Gemini chat completion client.
|
||||
project: str | None = Field(
|
||||
default=None, description="Google Cloud project ID (for Vertex AI)"
|
||||
)
|
||||
location: str = Field(
|
||||
default="us-central1",
|
||||
description="Google Cloud location (for Vertex AI, defaults to 'us-central1')",
|
||||
)
|
||||
top_p: float | None = Field(default=None, description="Nucleus sampling parameter")
|
||||
top_k: int | None = Field(default=None, description="Top-k sampling parameter")
|
||||
max_output_tokens: int | None = Field(
|
||||
default=None, description="Maximum tokens in response"
|
||||
)
|
||||
stream: bool = Field(default=False, description="Enable streaming responses")
|
||||
safety_settings: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Safety filter settings"
|
||||
)
|
||||
client_params: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Additional parameters for Google Gen AI Client constructor",
|
||||
)
|
||||
interceptor: Any = Field(
|
||||
default=None, description="HTTP interceptor (not yet supported for Gemini)"
|
||||
)
|
||||
client: Any = Field(
|
||||
default=None, exclude=True, description="Gemini client instance"
|
||||
)
|
||||
|
||||
_is_gemini_2: bool = PrivateAttr(default=False)
|
||||
_is_gemini_1_5: bool = PrivateAttr(default=False)
|
||||
_supports_tools: bool = PrivateAttr(default=False)
|
||||
|
||||
@property
|
||||
def stop_sequences(self) -> list[str]:
|
||||
"""Get stop sequences as a list.
|
||||
|
||||
This property provides access to stop sequences in Gemini's native format
|
||||
while maintaining synchronization with the base class's stop attribute.
|
||||
"""
|
||||
if self.stop is None:
|
||||
return []
|
||||
if isinstance(self.stop, str):
|
||||
return [self.stop]
|
||||
return self.stop
|
||||
|
||||
@stop_sequences.setter
|
||||
def stop_sequences(self, value: list[str] | str | None) -> None:
|
||||
"""Set stop sequences, synchronizing with the stop attribute.
|
||||
|
||||
Args:
|
||||
model: Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro')
|
||||
api_key: Google API key (defaults to GOOGLE_API_KEY or GEMINI_API_KEY env var)
|
||||
project: Google Cloud project ID (for Vertex AI)
|
||||
location: Google Cloud location (for Vertex AI, defaults to 'us-central1')
|
||||
temperature: Sampling temperature (0-2)
|
||||
top_p: Nucleus sampling parameter
|
||||
top_k: Top-k sampling parameter
|
||||
max_output_tokens: Maximum tokens in response
|
||||
stop_sequences: Stop sequences
|
||||
stream: Enable streaming responses
|
||||
safety_settings: Safety filter settings
|
||||
client_params: Additional parameters to pass to the Google Gen AI Client constructor.
|
||||
Supports parameters like http_options, credentials, debug_config, etc.
|
||||
interceptor: HTTP interceptor (not yet supported for Gemini).
|
||||
**kwargs: Additional parameters
|
||||
value: Stop sequences as a list, string, or None
|
||||
"""
|
||||
if interceptor is not None:
|
||||
self.stop = value
|
||||
|
||||
@model_validator(mode="after")
|
||||
def setup_client(self) -> Self:
|
||||
"""Initialize the Gemini client and validate configuration."""
|
||||
if self.interceptor is not None:
|
||||
raise NotImplementedError(
|
||||
"HTTP interceptors are not yet supported for Google Gemini provider. "
|
||||
"Interceptors are currently supported for OpenAI and Anthropic providers only."
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
|
||||
)
|
||||
if self.api_key is None:
|
||||
self.api_key = os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
|
||||
|
||||
# Store client params for later use
|
||||
self.client_params = client_params or {}
|
||||
if self.project is None:
|
||||
self.project = os.getenv("GOOGLE_CLOUD_PROJECT")
|
||||
|
||||
# Get API configuration with environment variable fallbacks
|
||||
self.api_key = (
|
||||
api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
|
||||
)
|
||||
self.project = project or os.getenv("GOOGLE_CLOUD_PROJECT")
|
||||
self.location = location or os.getenv("GOOGLE_CLOUD_LOCATION") or "us-central1"
|
||||
if self.location == "us-central1":
|
||||
env_location = os.getenv("GOOGLE_CLOUD_LOCATION")
|
||||
if env_location:
|
||||
self.location = env_location
|
||||
|
||||
use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"
|
||||
|
||||
self.client = self._initialize_client(use_vertexai)
|
||||
|
||||
# Store completion parameters
|
||||
self.top_p = top_p
|
||||
self.top_k = top_k
|
||||
self.max_output_tokens = max_output_tokens
|
||||
self.stream = stream
|
||||
self.safety_settings = safety_settings or {}
|
||||
self.stop_sequences = stop_sequences or []
|
||||
self._is_gemini_2 = "gemini-2" in self.model.lower()
|
||||
self._is_gemini_1_5 = "gemini-1.5" in self.model.lower()
|
||||
self._supports_tools = self._is_gemini_1_5 or self._is_gemini_2
|
||||
|
||||
# Model-specific settings
|
||||
self.is_gemini_2 = "gemini-2" in model.lower()
|
||||
self.is_gemini_1_5 = "gemini-1.5" in model.lower()
|
||||
self.supports_tools = self.is_gemini_1_5 or self.is_gemini_2
|
||||
return self
|
||||
|
||||
# @property
|
||||
# def stop(self) -> list[str]: # type: ignore[misc]
|
||||
# """Get stop sequences sent to the API."""
|
||||
# return self.stop_sequences
|
||||
@property
|
||||
def is_gemini_2(self) -> bool:
|
||||
"""Check if model is Gemini 2."""
|
||||
return self._is_gemini_2
|
||||
|
||||
# @stop.setter
|
||||
# def stop(self, value: list[str] | str | None) -> None:
|
||||
# """Set stop sequences.
|
||||
#
|
||||
# Synchronizes stop_sequences to ensure values set by CrewAgentExecutor
|
||||
# are properly sent to the Gemini API.
|
||||
#
|
||||
# Args:
|
||||
# value: Stop sequences as a list, single string, or None
|
||||
# """
|
||||
# if value is None:
|
||||
# self.stop_sequences = []
|
||||
# elif isinstance(value, str):
|
||||
# self.stop_sequences = [value]
|
||||
# elif isinstance(value, list):
|
||||
# self.stop_sequences = value
|
||||
# else:
|
||||
# self.stop_sequences = []
|
||||
@property
|
||||
def is_gemini_1_5(self) -> bool:
|
||||
"""Check if model is Gemini 1.5."""
|
||||
return self._is_gemini_1_5
|
||||
|
||||
def _initialize_client(self, use_vertexai: bool = False) -> genai.Client: # type: ignore[no-any-unimported]
|
||||
@property
|
||||
def supports_tools(self) -> bool:
|
||||
"""Check if model supports tools."""
|
||||
return self._supports_tools
|
||||
|
||||
def _initialize_client(self, use_vertexai: bool = False) -> Any:
|
||||
"""Initialize the Google Gen AI client with proper parameter handling.
|
||||
|
||||
Args:
|
||||
@@ -154,12 +170,9 @@ class GeminiCompletion(BaseLLM):
|
||||
"location": self.location,
|
||||
}
|
||||
)
|
||||
|
||||
client_params.pop("api_key", None)
|
||||
|
||||
elif self.api_key:
|
||||
client_params["api_key"] = self.api_key
|
||||
|
||||
client_params.pop("vertexai", None)
|
||||
client_params.pop("project", None)
|
||||
client_params.pop("location", None)
|
||||
@@ -188,7 +201,6 @@ class GeminiCompletion(BaseLLM):
|
||||
and hasattr(self.client, "vertexai")
|
||||
and self.client.vertexai
|
||||
):
|
||||
# Vertex AI configuration
|
||||
params.update(
|
||||
{
|
||||
"vertexai": True,
|
||||
@@ -300,15 +312,12 @@ class GeminiCompletion(BaseLLM):
|
||||
self.tools = tools
|
||||
config_params = {}
|
||||
|
||||
# Add system instruction if present
|
||||
if system_instruction:
|
||||
# Convert system instruction to Content format
|
||||
system_content = types.Content(
|
||||
role="user", parts=[types.Part.from_text(text=system_instruction)]
|
||||
)
|
||||
config_params["system_instruction"] = system_content
|
||||
|
||||
# Add generation config parameters
|
||||
if self.temperature is not None:
|
||||
config_params["temperature"] = self.temperature
|
||||
if self.top_p is not None:
|
||||
@@ -317,14 +326,13 @@ class GeminiCompletion(BaseLLM):
|
||||
config_params["top_k"] = self.top_k
|
||||
if self.max_output_tokens is not None:
|
||||
config_params["max_output_tokens"] = self.max_output_tokens
|
||||
if self.stop_sequences:
|
||||
config_params["stop_sequences"] = self.stop_sequences
|
||||
if self.stop:
|
||||
config_params["stop_sequences"] = self.stop
|
||||
|
||||
if response_model:
|
||||
config_params["response_mime_type"] = "application/json"
|
||||
config_params["response_schema"] = response_model.model_json_schema()
|
||||
|
||||
# Handle tools for supported models
|
||||
if tools and self.supports_tools:
|
||||
config_params["tools"] = self._convert_tools_for_interference(tools)
|
||||
|
||||
@@ -347,7 +355,6 @@ class GeminiCompletion(BaseLLM):
|
||||
description=description,
|
||||
)
|
||||
|
||||
# Add parameters if present - ensure parameters is a dict
|
||||
if parameters and isinstance(parameters, dict):
|
||||
function_declaration.parameters = parameters
|
||||
|
||||
@@ -383,16 +390,12 @@ class GeminiCompletion(BaseLLM):
|
||||
content = message.get("content", "")
|
||||
|
||||
if role == "system":
|
||||
# Extract system instruction - Gemini handles it separately
|
||||
if system_instruction:
|
||||
system_instruction += f"\n\n{content}"
|
||||
else:
|
||||
system_instruction = cast(str, content)
|
||||
else:
|
||||
# Convert role for Gemini (assistant -> model)
|
||||
gemini_role = "model" if role == "assistant" else "user"
|
||||
|
||||
# Create Content object
|
||||
gemini_content = types.Content(
|
||||
role=gemini_role, parts=[types.Part.from_text(text=content)]
|
||||
)
|
||||
@@ -509,13 +512,11 @@ class GeminiCompletion(BaseLLM):
|
||||
else {},
|
||||
}
|
||||
|
||||
# Handle completed function calls
|
||||
if function_calls and available_functions:
|
||||
for call_data in function_calls.values():
|
||||
function_name = call_data["name"]
|
||||
function_args = call_data["args"]
|
||||
|
||||
# Execute tool
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
@@ -575,13 +576,11 @@ class GeminiCompletion(BaseLLM):
|
||||
"gemma-3-27b": 128000,
|
||||
}
|
||||
|
||||
# Find the best match for the model name
|
||||
for model_prefix, size in context_windows.items():
|
||||
if self.model.startswith(model_prefix):
|
||||
return int(size * CONTEXT_WINDOW_USAGE_RATIO)
|
||||
|
||||
# Default context window size for Gemini models
|
||||
return int(1048576 * CONTEXT_WINDOW_USAGE_RATIO) # 1M tokens
|
||||
return int(1048576 * CONTEXT_WINDOW_USAGE_RATIO)
|
||||
|
||||
def _extract_token_usage(self, response: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Extract token usage from Gemini response."""
|
||||
|
||||
@@ -11,7 +11,8 @@ from openai import APIConnectionError, NotFoundError, OpenAI
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion import Choice
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
from crewai.llm.base_llm import BaseLLM
|
||||
@@ -73,26 +74,18 @@ class OpenAICompletion(BaseLLM):
|
||||
)
|
||||
reasoning_effort: str | None = Field(None, description="Reasoning effort level")
|
||||
|
||||
# Internal state
|
||||
client: OpenAI = Field(
|
||||
default_factory=OpenAI, exclude=True, description="OpenAI client instance"
|
||||
)
|
||||
is_o1_model: bool = Field(False, description="Whether this is an O1 model")
|
||||
is_gpt4_model: bool = Field(False, description="Whether this is a GPT-4 model")
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
"""Initialize OpenAI client after model initialization.
|
||||
|
||||
Args:
|
||||
__context: Pydantic context
|
||||
"""
|
||||
super().model_post_init(__context)
|
||||
|
||||
# Set API key from environment if not provided
|
||||
@model_validator(mode="after")
|
||||
def setup_client(self) -> Self:
|
||||
"""Initialize OpenAI client after model validation."""
|
||||
if self.api_key is None:
|
||||
self.api_key = os.getenv("OPENAI_API_KEY")
|
||||
|
||||
# Initialize client
|
||||
client_config = self._get_client_params()
|
||||
if self.interceptor:
|
||||
transport = HTTPTransport(interceptor=self.interceptor)
|
||||
@@ -101,10 +94,11 @@ class OpenAICompletion(BaseLLM):
|
||||
|
||||
self.client = OpenAI(**client_config)
|
||||
|
||||
# Set model flags
|
||||
self.is_o1_model = "o1" in self.model.lower()
|
||||
self.is_gpt4_model = "gpt-4" in self.model.lower()
|
||||
|
||||
return self
|
||||
|
||||
def _get_client_params(self) -> dict[str, Any]:
|
||||
"""Get OpenAI client parameters."""
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ Classes:
|
||||
|
||||
from typing import Any
|
||||
|
||||
from crewai.llm import LLM
|
||||
from crewai.llm.core import LLM
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.utilities.logger import Logger
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ if TYPE_CHECKING:
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.tools_handler import ToolsHandler
|
||||
from crewai.lite_agent import LiteAgent
|
||||
from crewai.llm import LLM
|
||||
from crewai.llm.core import LLM
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from crewai.events.event_types import (
ToolUsageFinishedEvent,
ToolUsageStartedEvent,
)
from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM
from crewai.llm.core import CONTEXT_WINDOW_USAGE_RATIO, LLM
from crewai.utilities.token_counter_callback import TokenCalcHandler
from pydantic import BaseModel
import pytest

@@ -229,7 +229,7 @@ def test_validate_call_params_supported():
a: int

# Patch supports_response_schema to simulate a supported model.
with patch("crewai.llm.supports_response_schema", return_value=True):
with patch("crewai.llm.core.supports_response_schema", return_value=True):
llm = LLM(
model="openrouter/deepseek/deepseek-chat", response_format=DummyResponse
)

@@ -242,7 +242,7 @@ def test_validate_call_params_not_supported():
a: int

# Patch supports_response_schema to simulate an unsupported model.
with patch("crewai.llm.supports_response_schema", return_value=False):
with patch("crewai.llm.core.supports_response_schema", return_value=False):
llm = LLM(model="gemini/gemini-1.5-pro", response_format=DummyResponse, is_litellm=True)
with pytest.raises(ValueError) as excinfo:
llm._validate_call_params()

@@ -342,7 +342,7 @@ def test_context_window_validation():
# Test invalid window size
with pytest.raises(ValueError) as excinfo:
with patch.dict(
"crewai.llm.LLM_CONTEXT_WINDOW_SIZES",
"crewai.llm.core.LLM_CONTEXT_WINDOW_SIZES",
{"test-model": 500},  # Below minimum
clear=True,
):

@@ -702,8 +702,8 @@ def test_ollama_does_not_modify_when_last_is_user(ollama_llm):

def test_native_provider_raises_error_when_supported_but_fails():
"""Test that when a native provider is in SUPPORTED_NATIVE_PROVIDERS but fails to instantiate, we raise the error."""
with patch("crewai.llm.SUPPORTED_NATIVE_PROVIDERS", ["openai"]):
with patch("crewai.llm.LLM._get_native_provider") as mock_get_native:
with patch("crewai.llm.internal.meta.SUPPORTED_NATIVE_PROVIDERS", ["openai"]):
with patch("crewai.llm.internal.meta.LLMMeta._get_native_provider") as mock_get_native:
# Mock that provider exists but throws an error when instantiated
mock_provider = MagicMock()
mock_provider.side_effect = ValueError("Native provider initialization failed")

@@ -718,7 +718,7 @@ def test_native_provider_raises_error_when_supported_but_fails():

def test_native_provider_falls_back_to_litellm_when_not_in_supported_list():
"""Test that when a provider is not in SUPPORTED_NATIVE_PROVIDERS, we fall back to LiteLLM."""
with patch("crewai.llm.SUPPORTED_NATIVE_PROVIDERS", ["openai", "anthropic"]):
with patch("crewai.llm.internal.meta.SUPPORTED_NATIVE_PROVIDERS", ["openai", "anthropic"]):
# Using a provider not in the supported list
llm = LLM(model="groq/llama-3.1-70b-versatile", is_litellm=False)
logs.txt (20 lines removed)
@@ -1,20 +0,0 @@
lib/crewai/src/crewai/agent/core.py:901: error: Argument 1 has incompatible type "ToolFilterContext"; expected "dict[str, Any]" [arg-type]
lib/crewai/src/crewai/agent/core.py:901: note: Error code "arg-type" not covered by "type: ignore" comment
lib/crewai/src/crewai/agent/core.py:905: error: Argument 1 has incompatible type "dict[str, Any]"; expected "ToolFilterContext" [arg-type]
lib/crewai/src/crewai/agent/core.py:905: note: Error code "arg-type" not covered by "type: ignore" comment
lib/crewai/src/crewai/agent/core.py:996: error: Returning Any from function declared to return "dict[str, dict[str, Any]]" [no-any-return]
lib/crewai/src/crewai/agent/core.py:1157: error: Incompatible types in assignment (expression has type "tuple[UnionType, None]", target has type "tuple[type, Any]") [assignment]
lib/crewai/src/crewai/agent/core.py:1183: error: Argument 1 to "append" of "list" has incompatible type "type"; expected "type[str]" [arg-type]
lib/crewai/src/crewai/agent/core.py:1188: error: Incompatible types in assignment (expression has type "UnionType", variable has type "type[str]") [assignment]
lib/crewai/src/crewai/agent/core.py:1201: error: Argument 1 to "get" of "dict" has incompatible type "Any | None"; expected "str" [arg-type]
Found 7 errors in 1 file (checked 4 source files)
Success: no issues found in 4 source files
lib/crewai/src/crewai/llm/providers/gemini/completion.py:111: error: BaseModel field may only be overridden by another field [misc]
Found 1 error in 1 file (checked 4 source files)
Success: no issues found in 4 source files
lib/crewai/src/crewai/llm/providers/anthropic/completion.py:101: error: BaseModel field may only be overridden by another field [misc]
Found 1 error in 1 file (checked 4 source files)
lib/crewai/src/crewai/llm/providers/bedrock/completion.py:250: error: BaseModel field may only be overridden by another field [misc]
Found 1 error in 1 file (checked 4 source files)

uv-lock..............................................(no files to check)Skipped
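
Note: the three "BaseModel field may only be overridden by another field" errors above are what the PrivateAttr-plus-property pattern in the provider classes works around. A minimal standalone sketch of that pattern (hypothetical Example class, not crewAI code):

    from pydantic import BaseModel, PrivateAttr, model_validator
    from typing_extensions import Self

    class Example(BaseModel):
        model: str
        _is_claude_model: bool = PrivateAttr(default=False)  # not a public field

        @model_validator(mode="after")
        def _derive_flags(self) -> Self:
            self._is_claude_model = "claude" in self.model.lower()
            return self

        @property
        def is_claude_model(self) -> bool:  # read-only view, no field override
            return self._is_claude_model

    assert Example(model="anthropic.claude-3-5-sonnet-20241022-v2:0").is_claude_model
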