chore: continue refactoring llms to base models

Greyson LaLonde
2025-11-10 16:05:23 -05:00
parent 46785adf58
commit 722d316824
17 changed files with 381 additions and 430 deletions

View File

@@ -8,8 +8,8 @@ from crewai.crew import Crew
from crewai.crews.crew_output import CrewOutput
from crewai.flow.flow import Flow
from crewai.knowledge.knowledge import Knowledge
from crewai.llm import LLM
from crewai.llm.base_llm import BaseLLM
from crewai.llm.core import LLM
from crewai.process import Process
from crewai.task import Task
from crewai.tasks.llm_guardrail import LLMGuardrail

View File

@@ -14,7 +14,8 @@ import tomli
from crewai.cli.utils import read_toml
from crewai.cli.version import get_crewai_version
from crewai.crew import Crew
from crewai.llm import LLM, BaseLLM
from crewai.llm import LLM
from crewai.llm.base_llm import BaseLLM
from crewai.types.crew_chat import ChatInputField, ChatInputs
from crewai.utilities.llm_utils import create_llm
from crewai.utilities.printer import Printer

View File

@@ -56,8 +56,8 @@ from crewai.events.types.crew_events import (
from crewai.flow.flow_trackable import FlowTrackable
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.llm import LLM
from crewai.llm.base_llm import BaseLLM
from crewai.llm.core import LLM
from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.external.external_memory import ExternalMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory

View File

@@ -89,7 +89,7 @@ from crewai.events.types.tool_usage_events import (
ToolUsageStartedEvent,
)
from crewai.events.utils.console_formatter import ConsoleFormatter
from crewai.llm import LLM
from crewai.llm.core import LLM
from crewai.task import Task
from crewai.telemetry.telemetry import Telemetry
from crewai.utilities import Logger

View File

@@ -7,7 +7,7 @@ from pydantic import BaseModel, Field
from crewai.agent import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.llm import BaseLLM
from crewai.llm.base_llm import BaseLLM
from crewai.task import Task
from crewai.utilities.llm_utils import create_llm

View File

@@ -39,8 +39,8 @@ from crewai.events.types.agent_events import (
from crewai.events.types.logging_events import AgentLogsExecutionEvent
from crewai.flow.flow_trackable import FlowTrackable
from crewai.lite_agent_output import LiteAgentOutput
from crewai.llm import LLM
from crewai.llm.base_llm import BaseLLM
from crewai.llm.core import LLM
from crewai.tools.base_tool import BaseTool
from crewai.tools.structured_tool import CrewStructuredTool
from crewai.utilities.agent_utils import (

View File

@@ -66,7 +66,7 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
"""
model_config: ClassVar[ConfigDict] = ConfigDict(
arbitrary_types_allowed=True, extra="allow", validate_assignment=True
arbitrary_types_allowed=True, extra="allow"
)
# Core fields
@@ -80,7 +80,9 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
default="openai", description="Provider name (openai, anthropic, etc.)"
)
stop: list[str] = Field(
default_factory=list, description="Stop sequences for generation"
default_factory=list,
description="Stop sequences for generation",
validation_alias="stop_sequences",
)
# Internal fields
@@ -112,16 +114,18 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
if not values.get("model"):
raise ValueError("Model name is required and cannot be empty")
# Handle stop sequences
stop = values.get("stop")
stop = values.get("stop") or values.get("stop_sequences")
if stop is None:
values["stop"] = []
elif isinstance(stop, str):
values["stop"] = [stop]
elif not isinstance(stop, list):
elif isinstance(stop, list):
values["stop"] = stop
else:
values["stop"] = []
# Set default provider if not specified
values.pop("stop_sequences", None)
if "provider" not in values or values["provider"] is None:
values["provider"] = "openai"

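The BaseLLM hunk above folds `stop_sequences` into the single `stop` field, via both a validation alias and normalization in the before-validator. A minimal standalone sketch of that normalization pattern (illustrative class name, not the crewai BaseLLM):

from typing import Any

from pydantic import BaseModel, Field, model_validator


class StopNormalizingModel(BaseModel):
    """Illustrative model, not the crewai BaseLLM."""

    stop: list[str] = Field(default_factory=list)

    @model_validator(mode="before")
    @classmethod
    def normalize_stop(cls, values: dict[str, Any]) -> dict[str, Any]:
        # Accept either key, coerce a bare string to a one-element list,
        # and drop the provider-specific spelling after normalizing.
        stop = values.get("stop") or values.get("stop_sequences")
        if stop is None:
            values["stop"] = []
        elif isinstance(stop, str):
            values["stop"] = [stop]
        elif isinstance(stop, list):
            values["stop"] = stop
        else:
            values["stop"] = []
        values.pop("stop_sequences", None)
        return values


print(StopNormalizingModel(stop="END").stop)                 # ['END']
print(StopNormalizingModel(stop_sequences=["a", "b"]).stop)  # ['a', 'b']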
View File

@@ -33,13 +33,12 @@ class LLMMeta(ModelMetaclass):
native provider implementation based on the model parameter.
"""
def __call__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> Any: # noqa: N805
def __call__(cls, *args: Any, **kwargs: Any) -> Any: # noqa: N805
"""Route to appropriate provider implementation at instantiation time.
Args:
model: The model identifier (e.g., "gpt-4", "claude-3-opus")
is_litellm: Force use of LiteLLM instead of native provider
**kwargs: Additional parameters for the LLM
*args: Positional arguments (model should be first for LLM class)
**kwargs: Keyword arguments including model, is_litellm, etc.
Returns:
Instance of the appropriate provider class or LLM class
@@ -47,18 +46,18 @@ class LLMMeta(ModelMetaclass):
Raises:
ValueError: If model is not a valid string
"""
if cls.__name__ != "LLM":
return super().__call__(*args, **kwargs)
model = kwargs.get("model") or (args[0] if args else None)
is_litellm = kwargs.get("is_litellm", False)
if not model or not isinstance(model, str):
raise ValueError("Model must be a non-empty string")
# Only perform routing if called on the base LLM class
# Subclasses (OpenAICompletion, etc.) should create normally
from crewai.llm import LLM
if cls is not LLM:
# Direct instantiation of provider class, skip routing
return super().__call__(model=model, **kwargs)
# Extract provider information
if args and not kwargs.get("model"):
kwargs["model"] = args[0]
args = args[1:]
explicit_provider = kwargs.get("provider")
if explicit_provider:
@@ -97,12 +96,10 @@ class LLMMeta(ModelMetaclass):
use_native = True
model_string = model
# Route to native provider if available
native_class = cls._get_native_provider(provider) if use_native else None
if native_class and not is_litellm and provider in SUPPORTED_NATIVE_PROVIDERS:
try:
# Remove 'provider' from kwargs to avoid duplicate keyword argument
kwargs_copy = {k: v for k, v in kwargs.items() if k != "provider"}
kwargs_copy = {k: v for k, v in kwargs.items() if k not in ("provider", "model")}
return native_class(
model=model_string, provider=provider, **kwargs_copy
)
@@ -111,15 +108,14 @@ class LLMMeta(ModelMetaclass):
except Exception as e:
raise ImportError(f"Error importing native provider: {e}") from e
# Fallback to LiteLLM
try:
import litellm # noqa: F401
except ImportError:
logging.error("LiteLLM is not available; cannot fall back to LiteLLM")
raise ImportError("Fallback to LiteLLM is not available") from None
# Create actual LLM instance with is_litellm=True
return super().__call__(model=model, is_litellm=True, **kwargs)
kwargs_copy = {k: v for k, v in kwargs.items() if k not in ("model", "is_litellm")}
return super().__call__(model=model, is_litellm=True, **kwargs_copy)
@staticmethod
def _validate_model_in_constants(model: str, provider: str) -> bool:

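The LLMMeta change above makes routing work whether the model is passed positionally or as a keyword, and only routes when the call targets the base class. A standalone sketch of that metaclass routing idea, using a toy registry in place of crewai's native-provider lookup:

from typing import Any


class RoutingMeta(type):
    """Illustrative metaclass, not the crewai LLMMeta."""

    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
        # Subclasses instantiate normally; only the base class routes.
        if cls.__name__ != "BaseClient":
            return super().__call__(*args, **kwargs)
        model = kwargs.get("model") or (args[0] if args else None)
        if not model or not isinstance(model, str):
            raise ValueError("Model must be a non-empty string")
        if args and not kwargs.get("model"):
            kwargs["model"] = args[0]
            args = args[1:]
        provider = kwargs.pop("provider", None) or model.split("/", 1)[0]
        target = cls._registry.get(provider, cls)
        if target is cls:
            return super().__call__(*args, **kwargs)
        return target(*args, **kwargs)


class BaseClient(metaclass=RoutingMeta):
    _registry: dict[str, type] = {}

    def __init__(self, model: str, **kwargs: Any) -> None:
        self.model = model


class OpenAIClient(BaseClient):
    pass


BaseClient._registry["openai"] = OpenAIClient

print(type(BaseClient("openai/gpt-4")).__name__)  # OpenAIClient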
View File

@@ -3,9 +3,10 @@ from __future__ import annotations
import json
import logging
import os
from typing import TYPE_CHECKING, Any, ClassVar, cast
from typing import Any, ClassVar, cast
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
from typing_extensions import Self
from crewai.events.types.llm_events import LLMCallType
from crewai.llm.base_llm import BaseLLM
@@ -19,9 +20,6 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
from crewai.utilities.types import LLMMessage
if TYPE_CHECKING:
from crewai.llm.hooks.base import BaseInterceptor
try:
from anthropic import Anthropic
from anthropic.types import Message
@@ -38,90 +36,67 @@ class AnthropicCompletion(BaseLLM):
This class provides direct integration with the Anthropic Python SDK,
offering native tool use, streaming support, and proper message formatting.
Attributes:
model: Anthropic model name (e.g., 'claude-3-5-sonnet-20241022')
base_url: Custom base URL for Anthropic API
timeout: Request timeout in seconds
max_retries: Maximum number of retries
max_tokens: Maximum tokens in response (required for Anthropic)
top_p: Nucleus sampling parameter
stream: Enable streaming responses
client_params: Additional parameters for the Anthropic client
interceptor: HTTP interceptor for modifying requests/responses at transport level
"""
model_config: ClassVar[ConfigDict] = ConfigDict(ignored_types=(property,))
model_config: ClassVar[ConfigDict] = ConfigDict(
ignored_types=(property,), arbitrary_types_allowed=True
)
def __init__(
self,
model: str = "claude-3-5-sonnet-20241022",
api_key: str | None = None,
base_url: str | None = None,
timeout: float | None = None,
max_retries: int = 2,
temperature: float | None = None,
max_tokens: int = 4096, # Required for Anthropic
top_p: float | None = None,
stop_sequences: list[str] | None = None,
stream: bool = False,
client_params: dict[str, Any] | None = None,
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
**kwargs: Any,
):
"""Initialize Anthropic chat completion client.
base_url: str | None = Field(
default=None, description="Custom base URL for Anthropic API"
)
timeout: float | None = Field(
default=None, description="Request timeout in seconds"
)
max_retries: int = Field(default=2, description="Maximum number of retries")
max_tokens: int = Field(
default=4096, description="Maximum tokens in response (required for Anthropic)"
)
top_p: float | None = Field(default=None, description="Nucleus sampling parameter")
stream: bool = Field(default=False, description="Enable streaming responses")
client_params: dict[str, Any] | None = Field(
default=None, description="Additional Anthropic client parameters"
)
interceptor: Any = Field(
default=None, description="HTTP interceptor for request/response modification"
)
client: Any = Field(
default=None, exclude=True, description="Anthropic client instance"
)
Args:
model: Anthropic model name (e.g., 'claude-3-5-sonnet-20241022')
api_key: Anthropic API key (defaults to ANTHROPIC_API_KEY env var)
base_url: Custom base URL for Anthropic API
timeout: Request timeout in seconds
max_retries: Maximum number of retries
temperature: Sampling temperature (0-1)
max_tokens: Maximum tokens in response (required for Anthropic)
top_p: Nucleus sampling parameter
stop_sequences: Stop sequences (Anthropic uses stop_sequences, not stop)
stream: Enable streaming responses
client_params: Additional parameters for the Anthropic client
interceptor: HTTP interceptor for modifying requests/responses at transport level.
**kwargs: Additional parameters
"""
super().__init__(
model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
)
# Client params
self.interceptor = interceptor
self.client_params = client_params
self.base_url = base_url
self.timeout = timeout
self.max_retries = max_retries
_is_claude_3: bool = PrivateAttr(default=False)
_supports_tools: bool = PrivateAttr(default=False)
@model_validator(mode="after")
def setup_client(self) -> Self:
"""Initialize the Anthropic client and model-specific settings."""
self.client = Anthropic(**self._get_client_params())
# Store completion parameters
self.max_tokens = max_tokens
self.top_p = top_p
self.stream = stream
self.stop_sequences = stop_sequences or []
self._is_claude_3 = "claude-3" in self.model.lower()
self._supports_tools = self._is_claude_3
# Model-specific settings
self.is_claude_3 = "claude-3" in model.lower()
self.supports_tools = self.is_claude_3 # Claude 3+ supports tool use
return self
#
# @property
# def stop(self) -> list[str]: # type: ignore[misc]
# """Get stop sequences sent to the API."""
# return self.stop_sequences
@property
def is_claude_3(self) -> bool:
"""Check if model is Claude 3."""
return self._is_claude_3
# @stop.setter
# def stop(self, value: list[str] | str | None) -> None:
# """Set stop sequences.
#
# Synchronizes stop_sequences to ensure values set by CrewAgentExecutor
# are properly sent to the Anthropic API.
#
# Args:
# value: Stop sequences as a list, single string, or None
# """
# if value is None:
# self.stop_sequences = []
# elif isinstance(value, str):
# self.stop_sequences = [value]
# elif isinstance(value, list):
# self.stop_sequences = value
# else:
# self.stop_sequences = []
@property
def supports_tools(self) -> bool:
"""Check if model supports tools."""
return self._supports_tools
def _get_client_params(self) -> dict[str, Any]:
"""Get client parameters."""
@@ -250,8 +225,8 @@ class AnthropicCompletion(BaseLLM):
params["temperature"] = self.temperature
if self.top_p is not None:
params["top_p"] = self.top_p
if self.stop_sequences:
params["stop_sequences"] = self.stop_sequences
if self.stop:
params["stop_sequences"] = self.stop
# Handle tools for Claude 3+
if tools and self.supports_tools:

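The Anthropic provider now declares its configuration as pydantic fields, keeps derived flags in private attributes exposed through read-only properties, and builds the SDK client in an after-validator. A standalone sketch of that layout with a stand-in client (not the anthropic SDK):

from typing import Any, ClassVar

from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
from typing_extensions import Self


class FakeClient:
    """Stand-in for the real SDK client."""

    def __init__(self, **params: Any) -> None:
        self.params = params


class ProviderCompletion(BaseModel):
    model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)

    model: str = "claude-3-5-sonnet-20241022"
    max_tokens: int = Field(default=4096, description="Maximum tokens in response")
    stream: bool = Field(default=False, description="Enable streaming responses")
    client: Any = Field(default=None, exclude=True)

    _supports_tools: bool = PrivateAttr(default=False)

    @model_validator(mode="after")
    def setup_client(self) -> Self:
        # Runs after field validation, replacing the old __init__ wiring.
        self.client = FakeClient(max_retries=2)
        self._supports_tools = "claude-3" in self.model.lower()
        return self

    @property
    def supports_tools(self) -> bool:
        return self._supports_tools


print(ProviderCompletion().supports_tools)  # True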
View File

@@ -3,9 +3,10 @@ from __future__ import annotations
import json
import logging
import os
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, ClassVar
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
from typing_extensions import Self
from crewai.llm.core import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
from crewai.llm.providers.utils.common import safe_tool_conversion
@@ -17,7 +18,6 @@ from crewai.utilities.types import LLMMessage
if TYPE_CHECKING:
from crewai.llm.hooks.base import BaseInterceptor
from crewai.tools.base_tool import BaseTool
@@ -51,65 +51,77 @@ class AzureCompletion(BaseLLM):
This class provides direct integration with the Azure AI Inference Python SDK,
offering native function calling, streaming support, and proper Azure authentication.
Attributes:
model: Azure deployment name or model name
endpoint: Azure endpoint URL
api_version: Azure API version
timeout: Request timeout in seconds
max_retries: Maximum number of retries
top_p: Nucleus sampling parameter
frequency_penalty: Frequency penalty (-2 to 2)
presence_penalty: Presence penalty (-2 to 2)
max_tokens: Maximum tokens in response
stream: Enable streaming responses
interceptor: HTTP interceptor (not yet supported for Azure)
"""
def __init__(
self,
model: str,
api_key: str | None = None,
endpoint: str | None = None,
api_version: str | None = None,
timeout: float | None = None,
max_retries: int = 2,
temperature: float | None = None,
top_p: float | None = None,
frequency_penalty: float | None = None,
presence_penalty: float | None = None,
max_tokens: int | None = None,
stop: list[str] | None = None,
stream: bool = False,
interceptor: BaseInterceptor[Any, Any] | None = None,
**kwargs: Any,
):
"""Initialize Azure AI Inference chat completion client.
model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
Args:
model: Azure deployment name or model name
api_key: Azure API key (defaults to AZURE_API_KEY env var)
endpoint: Azure endpoint URL (defaults to AZURE_ENDPOINT env var)
api_version: Azure API version (defaults to AZURE_API_VERSION env var)
timeout: Request timeout in seconds
max_retries: Maximum number of retries
temperature: Sampling temperature (0-2)
top_p: Nucleus sampling parameter
frequency_penalty: Frequency penalty (-2 to 2)
presence_penalty: Presence penalty (-2 to 2)
max_tokens: Maximum tokens in response
stop: Stop sequences
stream: Enable streaming responses
interceptor: HTTP interceptor (not yet supported for Azure).
**kwargs: Additional parameters
"""
if interceptor is not None:
endpoint: str | None = Field(
default=None,
description="Azure endpoint URL (defaults to AZURE_ENDPOINT env var)",
)
api_version: str = Field(
default="2024-06-01",
description="Azure API version (defaults to AZURE_API_VERSION env var or 2024-06-01)",
)
timeout: float | None = Field(
default=None, description="Request timeout in seconds"
)
max_retries: int = Field(default=2, description="Maximum number of retries")
top_p: float | None = Field(default=None, description="Nucleus sampling parameter")
frequency_penalty: float | None = Field(
default=None, description="Frequency penalty (-2 to 2)"
)
presence_penalty: float | None = Field(
default=None, description="Presence penalty (-2 to 2)"
)
max_tokens: int | None = Field(
default=None, description="Maximum tokens in response"
)
stream: bool = Field(default=False, description="Enable streaming responses")
interceptor: Any = Field(
default=None, description="HTTP interceptor (not yet supported for Azure)"
)
client: Any = Field(default=None, exclude=True, description="Azure client instance")
_is_openai_model: bool = PrivateAttr(default=False)
_is_azure_openai_endpoint: bool = PrivateAttr(default=False)
@model_validator(mode="after")
def setup_client(self) -> Self:
"""Initialize the Azure client and validate configuration."""
if self.interceptor is not None:
raise NotImplementedError(
"HTTP interceptors are not yet supported for Azure AI Inference provider. "
"Interceptors are currently supported for OpenAI and Anthropic providers only."
)
super().__init__(
model=model, temperature=temperature, stop=stop or [], **kwargs
)
if self.api_key is None:
self.api_key = os.getenv("AZURE_API_KEY")
self.api_key = api_key or os.getenv("AZURE_API_KEY")
self.endpoint = (
endpoint
or os.getenv("AZURE_ENDPOINT")
or os.getenv("AZURE_OPENAI_ENDPOINT")
or os.getenv("AZURE_API_BASE")
)
self.api_version = api_version or os.getenv("AZURE_API_VERSION") or "2024-06-01"
self.timeout = timeout
self.max_retries = max_retries
if self.endpoint is None:
self.endpoint = (
os.getenv("AZURE_ENDPOINT")
or os.getenv("AZURE_OPENAI_ENDPOINT")
or os.getenv("AZURE_API_BASE")
)
if self.api_version == "2024-06-01":
env_version = os.getenv("AZURE_API_VERSION")
if env_version:
self.api_version = env_version
if not self.api_key:
raise ValueError(
@@ -120,36 +132,38 @@ class AzureCompletion(BaseLLM):
"Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter."
)
# Validate and potentially fix Azure OpenAI endpoint URL
self.endpoint = self._validate_and_fix_endpoint(self.endpoint, model)
self.endpoint = self._validate_and_fix_endpoint(self.endpoint, self.model)
# Build client kwargs
client_kwargs = {
client_kwargs: dict[str, Any] = {
"endpoint": self.endpoint,
"credential": AzureKeyCredential(self.api_key),
}
# Add api_version if specified (primarily for Azure OpenAI endpoints)
if self.api_version:
client_kwargs["api_version"] = self.api_version
self.client = ChatCompletionsClient(**client_kwargs) # type: ignore[arg-type]
self.client = ChatCompletionsClient(**client_kwargs)
self.top_p = top_p
self.frequency_penalty = frequency_penalty
self.presence_penalty = presence_penalty
self.max_tokens = max_tokens
self.stream = stream
self.is_openai_model = any(
prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
self._is_openai_model = any(
prefix in self.model.lower() for prefix in ["gpt-", "o1-", "text-"]
)
self.is_azure_openai_endpoint = (
self._is_azure_openai_endpoint = (
"openai.azure.com" in self.endpoint
and "/openai/deployments/" in self.endpoint
)
return self
@property
def is_openai_model(self) -> bool:
"""Check if model is an OpenAI model."""
return self._is_openai_model
@property
def is_azure_openai_endpoint(self) -> bool:
"""Check if endpoint is an Azure OpenAI endpoint."""
return self._is_azure_openai_endpoint
def _validate_and_fix_endpoint(self, endpoint: str, model: str) -> str:
"""Validate and fix Azure endpoint URL format.

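The Azure provider resolves unset values from environment variables inside the after-validator, and only replaces the hard-coded `api_version` default when the caller did not pass one explicitly. A standalone sketch of that fallback order (illustrative names, not the crewai AzureCompletion):

import os

from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self


class AzureLikeConfig(BaseModel):
    """Illustrative config, not the crewai AzureCompletion."""

    endpoint: str | None = Field(default=None)
    api_version: str = Field(default="2024-06-01")

    @model_validator(mode="after")
    def apply_env_defaults(self) -> Self:
        if self.endpoint is None:
            self.endpoint = (
                os.getenv("AZURE_ENDPOINT")
                or os.getenv("AZURE_OPENAI_ENDPOINT")
                or os.getenv("AZURE_API_BASE")
            )
        # Only override the hard-coded default; an explicit argument wins
        # over the environment variable.
        if self.api_version == "2024-06-01":
            env_version = os.getenv("AZURE_API_VERSION")
            if env_version:
                self.api_version = env_version
        if not self.endpoint:
            raise ValueError("Azure endpoint is required (set AZURE_ENDPOINT or pass endpoint).")
        return self


os.environ["AZURE_ENDPOINT"] = "https://example.openai.azure.com"
print(AzureLikeConfig().endpoint)  # https://example.openai.azure.com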
View File

@@ -5,8 +5,8 @@ import logging
import os
from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast
from pydantic import BaseModel, ConfigDict
from typing_extensions import Required
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
from typing_extensions import Required, Self
from crewai.events.types.llm_events import LLMCallType
from crewai.llm.base_llm import BaseLLM
@@ -32,8 +32,6 @@ if TYPE_CHECKING:
ToolTypeDef,
)
from crewai.llm.hooks.base import BaseInterceptor
try:
from boto3.session import Session
@@ -143,76 +141,86 @@ class BedrockCompletion(BaseLLM):
- Complete streaming event handling (messageStart, contentBlockStart, etc.)
- Response metadata and trace information capture
- Model-specific conversation format handling (e.g., Cohere requirements)
Attributes:
model: The Bedrock model ID to use
aws_access_key_id: AWS access key (defaults to environment variable)
aws_secret_access_key: AWS secret key (defaults to environment variable)
aws_session_token: AWS session token for temporary credentials
region_name: AWS region name
max_tokens: Maximum tokens to generate
top_p: Nucleus sampling parameter
top_k: Top-k sampling parameter (Claude models only)
stop_sequences: List of sequences that stop generation
stream: Whether to use streaming responses
guardrail_config: Guardrail configuration for content filtering
additional_model_request_fields: Model-specific request parameters
additional_model_response_field_paths: Custom response field paths
interceptor: HTTP interceptor (not yet supported for Bedrock)
"""
model_config: ClassVar[ConfigDict] = ConfigDict(ignored_types=(property,))
model_config: ClassVar[ConfigDict] = ConfigDict(
ignored_types=(property,), arbitrary_types_allowed=True
)
def __init__(
self,
model: str = "anthropic.claude-3-5-sonnet-20241022-v2:0",
aws_access_key_id: str | None = None,
aws_secret_access_key: str | None = None,
aws_session_token: str | None = None,
region_name: str = "us-east-1",
temperature: float | None = None,
max_tokens: int | None = None,
top_p: float | None = None,
top_k: int | None = None,
stop_sequences: Sequence[str] | None = None,
stream: bool = False,
guardrail_config: dict[str, Any] | None = None,
additional_model_request_fields: dict[str, Any] | None = None,
additional_model_response_field_paths: list[str] | None = None,
interceptor: BaseInterceptor[Any, Any] | None = None,
**kwargs: Any,
) -> None:
"""Initialize AWS Bedrock completion client.
aws_access_key_id: str | None = Field(
default=None, description="AWS access key (defaults to environment variable)"
)
aws_secret_access_key: str | None = Field(
default=None, description="AWS secret key (defaults to environment variable)"
)
aws_session_token: str | None = Field(
default=None, description="AWS session token for temporary credentials"
)
region_name: str = Field(default="us-east-1", description="AWS region name")
max_tokens: int | None = Field(
default=None, description="Maximum tokens to generate"
)
top_p: float | None = Field(default=None, description="Nucleus sampling parameter")
top_k: int | None = Field(
default=None, description="Top-k sampling parameter (Claude models only)"
)
stream: bool = Field(
default=False, description="Whether to use streaming responses"
)
guardrail_config: dict[str, Any] | None = Field(
default=None, description="Guardrail configuration for content filtering"
)
additional_model_request_fields: dict[str, Any] | None = Field(
default=None, description="Model-specific request parameters"
)
additional_model_response_field_paths: list[str] | None = Field(
default=None, description="Custom response field paths"
)
interceptor: Any = Field(
default=None, description="HTTP interceptor (not yet supported for Bedrock)"
)
client: Any = Field(
default=None, exclude=True, description="Bedrock client instance"
)
Args:
model: The Bedrock model ID to use
aws_access_key_id: AWS access key (defaults to environment variable)
aws_secret_access_key: AWS secret key (defaults to environment variable)
aws_session_token: AWS session token for temporary credentials
region_name: AWS region name
temperature: Sampling temperature for response generation
max_tokens: Maximum tokens to generate
top_p: Nucleus sampling parameter
top_k: Top-k sampling parameter (Claude models only)
stop_sequences: List of sequences that stop generation
stream: Whether to use streaming responses
guardrail_config: Guardrail configuration for content filtering
additional_model_request_fields: Model-specific request parameters
additional_model_response_field_paths: Custom response field paths
interceptor: HTTP interceptor (not yet supported for Bedrock).
**kwargs: Additional parameters
"""
if interceptor is not None:
_is_claude_model: bool = PrivateAttr(default=False)
_supports_tools: bool = PrivateAttr(default=True)
_supports_streaming: bool = PrivateAttr(default=True)
_model_id: str = PrivateAttr()
@model_validator(mode="after")
def setup_client(self) -> Self:
"""Initialize the Bedrock client and validate configuration."""
if self.interceptor is not None:
raise NotImplementedError(
"HTTP interceptors are not yet supported for AWS Bedrock provider. "
"Interceptors are currently supported for OpenAI and Anthropic providers only."
)
# Extract provider from kwargs to avoid duplicate argument
kwargs.pop("provider", None)
super().__init__(
model=model,
temperature=temperature,
stop=stop_sequences or [],
provider="bedrock",
**kwargs,
)
# Initialize Bedrock client with proper configuration
session = Session(
aws_access_key_id=aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"),
aws_secret_access_key=aws_secret_access_key
aws_access_key_id=self.aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"),
aws_secret_access_key=self.aws_secret_access_key
or os.getenv("AWS_SECRET_ACCESS_KEY"),
aws_session_token=aws_session_token or os.getenv("AWS_SESSION_TOKEN"),
region_name=region_name,
aws_session_token=self.aws_session_token or os.getenv("AWS_SESSION_TOKEN"),
region_name=self.region_name,
)
# Configure client with timeouts and retries following AWS best practices
config = Config(
read_timeout=300,
retries={
@@ -223,53 +231,33 @@ class BedrockCompletion(BaseLLM):
)
self.client = session.client("bedrock-runtime", config=config)
self.region_name = region_name
# Store completion parameters
self.max_tokens = max_tokens
self.top_p = top_p
self.top_k = top_k
self.stream = stream
self.stop_sequences = stop_sequences or []
self._is_claude_model = "claude" in self.model.lower()
self._supports_tools = True
self._supports_streaming = True
self._model_id = self.model
# Store advanced features (optional)
self.guardrail_config = guardrail_config
self.additional_model_request_fields = additional_model_request_fields
self.additional_model_response_field_paths = (
additional_model_response_field_paths
)
return self
# Model-specific settings
self.is_claude_model = "claude" in model.lower()
self.supports_tools = True # Converse API supports tools for most models
self.supports_streaming = True
@property
def is_claude_model(self) -> bool:
"""Check if model is a Claude model."""
return self._is_claude_model
# Handle inference profiles for newer models
self.model_id = model
@property
def supports_tools(self) -> bool:
"""Check if model supports tools."""
return self._supports_tools
# @property
# def stop(self) -> list[str]: # type: ignore[misc]
# """Get stop sequences sent to the API."""
# return list(self.stop_sequences)
@property
def supports_streaming(self) -> bool:
"""Check if model supports streaming."""
return self._supports_streaming
# @stop.setter
# def stop(self, value: Sequence[str] | str | None) -> None:
# """Set stop sequences.
#
# Synchronizes stop_sequences to ensure values set by CrewAgentExecutor
# are properly sent to the Bedrock API.
#
# Args:
# value: Stop sequences as a Sequence, single string, or None
# """
# if value is None:
# self.stop_sequences = []
# elif isinstance(value, str):
# self.stop_sequences = [value]
# elif isinstance(value, Sequence):
# self.stop_sequences = list(value)
# else:
# self.stop_sequences = []
@property
def model_id(self) -> str:
"""Get the model ID."""
return self._model_id
def call(
self,
@@ -559,7 +547,7 @@ class BedrockCompletion(BaseLLM):
"Sequence[MessageTypeDef | MessageOutputTypeDef]",
cast(object, messages),
),
**body, # type: ignore[arg-type]
**body,
)
stream = response.get("stream")
@@ -821,8 +809,8 @@ class BedrockCompletion(BaseLLM):
config["temperature"] = float(self.temperature)
if self.top_p is not None:
config["topP"] = float(self.top_p)
if self.stop_sequences:
config["stopSequences"] = self.stop_sequences
if self.stop:
config["stopSequences"] = self.stop
if self.is_claude_model and self.top_k is not None:
# top_k is supported by Claude models

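For the Bedrock client wiring above, timeouts and retries are set on the botocore client config rather than per call. A minimal sketch of that session/config pattern, assuming boto3 and botocore are installed; the retry values here are illustrative, not taken from the commit:

import os

from boto3.session import Session
from botocore.config import Config

# Credentials fall back to the standard AWS environment variables.
session = Session(
    aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
    aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
    aws_session_token=os.getenv("AWS_SESSION_TOKEN"),
    region_name="us-east-1",
)

# Timeouts and retry behaviour live on the client config (illustrative values).
config = Config(
    read_timeout=300,
    retries={"max_attempts": 3, "mode": "adaptive"},
)

client = session.client("bedrock-runtime", config=config)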
View File

@@ -2,12 +2,12 @@ import logging
import os
from typing import Any, ClassVar, cast
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
from typing_extensions import Self
from crewai.events.types.llm_events import LLMCallType
from crewai.llm.base_llm import BaseLLM
from crewai.llm.core import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
from crewai.llm.hooks.base import BaseInterceptor
from crewai.llm.providers.utils.common import safe_tool_conversion
from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.exceptions.context_window_exceeding_exception import (
@@ -31,108 +31,124 @@ class GeminiCompletion(BaseLLM):
This class provides direct integration with the Google Gen AI Python SDK,
offering native function calling, streaming support, and proper Gemini formatting.
Attributes:
model: Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro')
project: Google Cloud project ID (for Vertex AI)
location: Google Cloud location (for Vertex AI, defaults to 'us-central1')
top_p: Nucleus sampling parameter
top_k: Top-k sampling parameter
max_output_tokens: Maximum tokens in response
stop_sequences: Stop sequences
stream: Enable streaming responses
safety_settings: Safety filter settings
client_params: Additional parameters for Google Gen AI Client constructor
interceptor: HTTP interceptor (not yet supported for Gemini)
"""
model_config: ClassVar[ConfigDict] = ConfigDict(ignored_types=(property,))
model_config: ClassVar[ConfigDict] = ConfigDict(
ignored_types=(property,), arbitrary_types_allowed=True
)
def __init__(
self,
model: str = "gemini-2.0-flash-001",
api_key: str | None = None,
project: str | None = None,
location: str | None = None,
temperature: float | None = None,
top_p: float | None = None,
top_k: int | None = None,
max_output_tokens: int | None = None,
stop_sequences: list[str] | None = None,
stream: bool = False,
safety_settings: dict[str, Any] | None = None,
client_params: dict[str, Any] | None = None,
interceptor: BaseInterceptor[Any, Any] | None = None,
**kwargs: Any,
):
"""Initialize Google Gemini chat completion client.
project: str | None = Field(
default=None, description="Google Cloud project ID (for Vertex AI)"
)
location: str = Field(
default="us-central1",
description="Google Cloud location (for Vertex AI, defaults to 'us-central1')",
)
top_p: float | None = Field(default=None, description="Nucleus sampling parameter")
top_k: int | None = Field(default=None, description="Top-k sampling parameter")
max_output_tokens: int | None = Field(
default=None, description="Maximum tokens in response"
)
stream: bool = Field(default=False, description="Enable streaming responses")
safety_settings: dict[str, Any] = Field(
default_factory=dict, description="Safety filter settings"
)
client_params: dict[str, Any] = Field(
default_factory=dict,
description="Additional parameters for Google Gen AI Client constructor",
)
interceptor: Any = Field(
default=None, description="HTTP interceptor (not yet supported for Gemini)"
)
client: Any = Field(
default=None, exclude=True, description="Gemini client instance"
)
_is_gemini_2: bool = PrivateAttr(default=False)
_is_gemini_1_5: bool = PrivateAttr(default=False)
_supports_tools: bool = PrivateAttr(default=False)
@property
def stop_sequences(self) -> list[str]:
"""Get stop sequences as a list.
This property provides access to stop sequences in Gemini's native format
while maintaining synchronization with the base class's stop attribute.
"""
if self.stop is None:
return []
if isinstance(self.stop, str):
return [self.stop]
return self.stop
@stop_sequences.setter
def stop_sequences(self, value: list[str] | str | None) -> None:
"""Set stop sequences, synchronizing with the stop attribute.
Args:
model: Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro')
api_key: Google API key (defaults to GOOGLE_API_KEY or GEMINI_API_KEY env var)
project: Google Cloud project ID (for Vertex AI)
location: Google Cloud location (for Vertex AI, defaults to 'us-central1')
temperature: Sampling temperature (0-2)
top_p: Nucleus sampling parameter
top_k: Top-k sampling parameter
max_output_tokens: Maximum tokens in response
stop_sequences: Stop sequences
stream: Enable streaming responses
safety_settings: Safety filter settings
client_params: Additional parameters to pass to the Google Gen AI Client constructor.
Supports parameters like http_options, credentials, debug_config, etc.
interceptor: HTTP interceptor (not yet supported for Gemini).
**kwargs: Additional parameters
value: Stop sequences as a list, string, or None
"""
if interceptor is not None:
self.stop = value
@model_validator(mode="after")
def setup_client(self) -> Self:
"""Initialize the Gemini client and validate configuration."""
if self.interceptor is not None:
raise NotImplementedError(
"HTTP interceptors are not yet supported for Google Gemini provider. "
"Interceptors are currently supported for OpenAI and Anthropic providers only."
)
super().__init__(
model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
)
if self.api_key is None:
self.api_key = os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
# Store client params for later use
self.client_params = client_params or {}
if self.project is None:
self.project = os.getenv("GOOGLE_CLOUD_PROJECT")
# Get API configuration with environment variable fallbacks
self.api_key = (
api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
)
self.project = project or os.getenv("GOOGLE_CLOUD_PROJECT")
self.location = location or os.getenv("GOOGLE_CLOUD_LOCATION") or "us-central1"
if self.location == "us-central1":
env_location = os.getenv("GOOGLE_CLOUD_LOCATION")
if env_location:
self.location = env_location
use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"
self.client = self._initialize_client(use_vertexai)
# Store completion parameters
self.top_p = top_p
self.top_k = top_k
self.max_output_tokens = max_output_tokens
self.stream = stream
self.safety_settings = safety_settings or {}
self.stop_sequences = stop_sequences or []
self._is_gemini_2 = "gemini-2" in self.model.lower()
self._is_gemini_1_5 = "gemini-1.5" in self.model.lower()
self._supports_tools = self._is_gemini_1_5 or self._is_gemini_2
# Model-specific settings
self.is_gemini_2 = "gemini-2" in model.lower()
self.is_gemini_1_5 = "gemini-1.5" in model.lower()
self.supports_tools = self.is_gemini_1_5 or self.is_gemini_2
return self
# @property
# def stop(self) -> list[str]: # type: ignore[misc]
# """Get stop sequences sent to the API."""
# return self.stop_sequences
@property
def is_gemini_2(self) -> bool:
"""Check if model is Gemini 2."""
return self._is_gemini_2
# @stop.setter
# def stop(self, value: list[str] | str | None) -> None:
# """Set stop sequences.
#
# Synchronizes stop_sequences to ensure values set by CrewAgentExecutor
# are properly sent to the Gemini API.
#
# Args:
# value: Stop sequences as a list, single string, or None
# """
# if value is None:
# self.stop_sequences = []
# elif isinstance(value, str):
# self.stop_sequences = [value]
# elif isinstance(value, list):
# self.stop_sequences = value
# else:
# self.stop_sequences = []
@property
def is_gemini_1_5(self) -> bool:
"""Check if model is Gemini 1.5."""
return self._is_gemini_1_5
def _initialize_client(self, use_vertexai: bool = False) -> genai.Client: # type: ignore[no-any-unimported]
@property
def supports_tools(self) -> bool:
"""Check if model supports tools."""
return self._supports_tools
def _initialize_client(self, use_vertexai: bool = False) -> Any:
"""Initialize the Google Gen AI client with proper parameter handling.
Args:
@@ -154,12 +170,9 @@ class GeminiCompletion(BaseLLM):
"location": self.location,
}
)
client_params.pop("api_key", None)
elif self.api_key:
client_params["api_key"] = self.api_key
client_params.pop("vertexai", None)
client_params.pop("project", None)
client_params.pop("location", None)
@@ -188,7 +201,6 @@ class GeminiCompletion(BaseLLM):
and hasattr(self.client, "vertexai")
and self.client.vertexai
):
# Vertex AI configuration
params.update(
{
"vertexai": True,
@@ -300,15 +312,12 @@ class GeminiCompletion(BaseLLM):
self.tools = tools
config_params = {}
# Add system instruction if present
if system_instruction:
# Convert system instruction to Content format
system_content = types.Content(
role="user", parts=[types.Part.from_text(text=system_instruction)]
)
config_params["system_instruction"] = system_content
# Add generation config parameters
if self.temperature is not None:
config_params["temperature"] = self.temperature
if self.top_p is not None:
@@ -317,14 +326,13 @@ class GeminiCompletion(BaseLLM):
config_params["top_k"] = self.top_k
if self.max_output_tokens is not None:
config_params["max_output_tokens"] = self.max_output_tokens
if self.stop_sequences:
config_params["stop_sequences"] = self.stop_sequences
if self.stop:
config_params["stop_sequences"] = self.stop
if response_model:
config_params["response_mime_type"] = "application/json"
config_params["response_schema"] = response_model.model_json_schema()
# Handle tools for supported models
if tools and self.supports_tools:
config_params["tools"] = self._convert_tools_for_interference(tools)
@@ -347,7 +355,6 @@ class GeminiCompletion(BaseLLM):
description=description,
)
# Add parameters if present - ensure parameters is a dict
if parameters and isinstance(parameters, dict):
function_declaration.parameters = parameters
@@ -383,16 +390,12 @@ class GeminiCompletion(BaseLLM):
content = message.get("content", "")
if role == "system":
# Extract system instruction - Gemini handles it separately
if system_instruction:
system_instruction += f"\n\n{content}"
else:
system_instruction = cast(str, content)
else:
# Convert role for Gemini (assistant -> model)
gemini_role = "model" if role == "assistant" else "user"
# Create Content object
gemini_content = types.Content(
role=gemini_role, parts=[types.Part.from_text(text=content)]
)
@@ -509,13 +512,11 @@ class GeminiCompletion(BaseLLM):
else {},
}
# Handle completed function calls
if function_calls and available_functions:
for call_data in function_calls.values():
function_name = call_data["name"]
function_args = call_data["args"]
# Execute tool
result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
@@ -575,13 +576,11 @@ class GeminiCompletion(BaseLLM):
"gemma-3-27b": 128000,
}
# Find the best match for the model name
for model_prefix, size in context_windows.items():
if self.model.startswith(model_prefix):
return int(size * CONTEXT_WINDOW_USAGE_RATIO)
# Default context window size for Gemini models
return int(1048576 * CONTEXT_WINDOW_USAGE_RATIO) # 1M tokens
return int(1048576 * CONTEXT_WINDOW_USAGE_RATIO)
def _extract_token_usage(self, response: dict[str, Any]) -> dict[str, Any]:
"""Extract token usage from Gemini response."""

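The Gemini provider keeps `stop` as the single stored field and exposes the provider-specific `stop_sequences` name as a property with a setter, which pydantic v2 routes assignments through. A standalone sketch of that aliasing (illustrative class name, not the crewai GeminiCompletion):

from pydantic import BaseModel, Field


class GeminiLikeConfig(BaseModel):
    """Illustrative config, not the crewai GeminiCompletion."""

    stop: list[str] | str | None = Field(default_factory=list)

    @property
    def stop_sequences(self) -> list[str]:
        # Always hand the API a list, whatever shape `stop` was given in.
        if self.stop is None:
            return []
        if isinstance(self.stop, str):
            return [self.stop]
        return self.stop

    @stop_sequences.setter
    def stop_sequences(self, value: list[str] | str | None) -> None:
        # Keep `stop` as the single source of truth.
        self.stop = value


cfg = GeminiLikeConfig()
cfg.stop_sequences = "STOP"
print(cfg.stop_sequences)  # ['STOP']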
View File

@@ -11,7 +11,8 @@ from openai import APIConnectionError, NotFoundError, OpenAI
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self
from crewai.events.types.llm_events import LLMCallType
from crewai.llm.base_llm import BaseLLM
@@ -73,26 +74,18 @@ class OpenAICompletion(BaseLLM):
)
reasoning_effort: str | None = Field(None, description="Reasoning effort level")
# Internal state
client: OpenAI = Field(
default_factory=OpenAI, exclude=True, description="OpenAI client instance"
)
is_o1_model: bool = Field(False, description="Whether this is an O1 model")
is_gpt4_model: bool = Field(False, description="Whether this is a GPT-4 model")
def model_post_init(self, __context: Any) -> None:
"""Initialize OpenAI client after model initialization.
Args:
__context: Pydantic context
"""
super().model_post_init(__context)
# Set API key from environment if not provided
@model_validator(mode="after")
def setup_client(self) -> Self:
"""Initialize OpenAI client after model validation."""
if self.api_key is None:
self.api_key = os.getenv("OPENAI_API_KEY")
# Initialize client
client_config = self._get_client_params()
if self.interceptor:
transport = HTTPTransport(interceptor=self.interceptor)
@@ -101,10 +94,11 @@ class OpenAICompletion(BaseLLM):
self.client = OpenAI(**client_config)
# Set model flags
self.is_o1_model = "o1" in self.model.lower()
self.is_gpt4_model = "gpt-4" in self.model.lower()
return self
def _get_client_params(self) -> dict[str, Any]:
"""Get OpenAI client parameters."""

View File

@@ -8,7 +8,7 @@ Classes:
from typing import Any
from crewai.llm import LLM
from crewai.llm.core import LLM
from crewai.tasks.task_output import TaskOutput
from crewai.utilities.logger import Logger

View File

@@ -36,7 +36,7 @@ if TYPE_CHECKING:
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.tools_handler import ToolsHandler
from crewai.lite_agent import LiteAgent
from crewai.llm import LLM
from crewai.llm.core import LLM
from crewai.task import Task

View File

@@ -11,7 +11,7 @@ from crewai.events.event_types import (
ToolUsageFinishedEvent,
ToolUsageStartedEvent,
)
from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM
from crewai.llm.core import CONTEXT_WINDOW_USAGE_RATIO, LLM
from crewai.utilities.token_counter_callback import TokenCalcHandler
from pydantic import BaseModel
import pytest
@@ -229,7 +229,7 @@ def test_validate_call_params_supported():
a: int
# Patch supports_response_schema to simulate a supported model.
with patch("crewai.llm.supports_response_schema", return_value=True):
with patch("crewai.llm.core.supports_response_schema", return_value=True):
llm = LLM(
model="openrouter/deepseek/deepseek-chat", response_format=DummyResponse
)
@@ -242,7 +242,7 @@ def test_validate_call_params_not_supported():
a: int
# Patch supports_response_schema to simulate an unsupported model.
with patch("crewai.llm.supports_response_schema", return_value=False):
with patch("crewai.llm.core.supports_response_schema", return_value=False):
llm = LLM(model="gemini/gemini-1.5-pro", response_format=DummyResponse, is_litellm=True)
with pytest.raises(ValueError) as excinfo:
llm._validate_call_params()
@@ -342,7 +342,7 @@ def test_context_window_validation():
# Test invalid window size
with pytest.raises(ValueError) as excinfo:
with patch.dict(
"crewai.llm.LLM_CONTEXT_WINDOW_SIZES",
"crewai.llm.core.LLM_CONTEXT_WINDOW_SIZES",
{"test-model": 500}, # Below minimum
clear=True,
):
@@ -702,8 +702,8 @@ def test_ollama_does_not_modify_when_last_is_user(ollama_llm):
def test_native_provider_raises_error_when_supported_but_fails():
"""Test that when a native provider is in SUPPORTED_NATIVE_PROVIDERS but fails to instantiate, we raise the error."""
with patch("crewai.llm.SUPPORTED_NATIVE_PROVIDERS", ["openai"]):
with patch("crewai.llm.LLM._get_native_provider") as mock_get_native:
with patch("crewai.llm.internal.meta.SUPPORTED_NATIVE_PROVIDERS", ["openai"]):
with patch("crewai.llm.internal.meta.LLMMeta._get_native_provider") as mock_get_native:
# Mock that provider exists but throws an error when instantiated
mock_provider = MagicMock()
mock_provider.side_effect = ValueError("Native provider initialization failed")
@@ -718,7 +718,7 @@ def test_native_provider_raises_error_when_supported_but_fails():
def test_native_provider_falls_back_to_litellm_when_not_in_supported_list():
"""Test that when a provider is not in SUPPORTED_NATIVE_PROVIDERS, we fall back to LiteLLM."""
with patch("crewai.llm.SUPPORTED_NATIVE_PROVIDERS", ["openai", "anthropic"]):
with patch("crewai.llm.internal.meta.SUPPORTED_NATIVE_PROVIDERS", ["openai", "anthropic"]):
# Using a provider not in the supported list
llm = LLM(model="groq/llama-3.1-70b-versatile", is_litellm=False)

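The test changes above retarget patches from `crewai.llm` to the new module paths, since `mock.patch` has to address the module where the name now lives. A small sketch of that, assuming the crewai package from this commit is installed; the test name is illustrative:

from unittest.mock import patch


def test_patches_new_module_path():
    # After the move, the context-window table lives in crewai.llm.core,
    # so that is the path patch.dict must target.
    with patch.dict(
        "crewai.llm.core.LLM_CONTEXT_WINDOW_SIZES", {"test-model": 500}, clear=True
    ):
        from crewai.llm.core import LLM_CONTEXT_WINDOW_SIZES

        assert LLM_CONTEXT_WINDOW_SIZES == {"test-model": 500}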
View File

@@ -1,20 +0,0 @@
lib/crewai/src/crewai/agent/core.py:901: error: Argument 1 has incompatible type "ToolFilterContext"; expected "dict[str, Any]" [arg-type]
lib/crewai/src/crewai/agent/core.py:901: note: Error code "arg-type" not covered by "type: ignore" comment
lib/crewai/src/crewai/agent/core.py:905: error: Argument 1 has incompatible type "dict[str, Any]"; expected "ToolFilterContext" [arg-type]
lib/crewai/src/crewai/agent/core.py:905: note: Error code "arg-type" not covered by "type: ignore" comment
lib/crewai/src/crewai/agent/core.py:996: error: Returning Any from function declared to return "dict[str, dict[str, Any]]" [no-any-return]
lib/crewai/src/crewai/agent/core.py:1157: error: Incompatible types in assignment (expression has type "tuple[UnionType, None]", target has type "tuple[type, Any]") [assignment]
lib/crewai/src/crewai/agent/core.py:1183: error: Argument 1 to "append" of "list" has incompatible type "type"; expected "type[str]" [arg-type]
lib/crewai/src/crewai/agent/core.py:1188: error: Incompatible types in assignment (expression has type "UnionType", variable has type "type[str]") [assignment]
lib/crewai/src/crewai/agent/core.py:1201: error: Argument 1 to "get" of "dict" has incompatible type "Any | None"; expected "str" [arg-type]
Found 7 errors in 1 file (checked 4 source files)
Success: no issues found in 4 source files
lib/crewai/src/crewai/llm/providers/gemini/completion.py:111: error: BaseModel field may only be overridden by another field [misc]
Found 1 error in 1 file (checked 4 source files)
Success: no issues found in 4 source files
lib/crewai/src/crewai/llm/providers/anthropic/completion.py:101: error: BaseModel field may only be overridden by another field [misc]
Found 1 error in 1 file (checked 4 source files)
lib/crewai/src/crewai/llm/providers/bedrock/completion.py:250: error: BaseModel field may only be overridden by another field [misc]
Found 1 error in 1 file (checked 4 source files)
uv-lock..............................................(no files to check)Skipped