chore: continue refactoring llms to base models

Greyson LaLonde
2025-11-10 16:05:23 -05:00
parent 46785adf58
commit d8fe83f76c
17 changed files with 379 additions and 429 deletions

View File

@@ -8,8 +8,8 @@ from crewai.crew import Crew
 from crewai.crews.crew_output import CrewOutput
 from crewai.flow.flow import Flow
 from crewai.knowledge.knowledge import Knowledge
-from crewai.llm import LLM
 from crewai.llm.base_llm import BaseLLM
+from crewai.llm.core import LLM
 from crewai.process import Process
 from crewai.task import Task
 from crewai.tasks.llm_guardrail import LLMGuardrail

View File

@@ -14,7 +14,8 @@ import tomli
 from crewai.cli.utils import read_toml
 from crewai.cli.version import get_crewai_version
 from crewai.crew import Crew
-from crewai.llm import LLM, BaseLLM
+from crewai.llm import LLM
+from crewai.llm.base_llm import BaseLLM
 from crewai.types.crew_chat import ChatInputField, ChatInputs
 from crewai.utilities.llm_utils import create_llm
 from crewai.utilities.printer import Printer

View File

@@ -56,8 +56,8 @@ from crewai.events.types.crew_events import (
 from crewai.flow.flow_trackable import FlowTrackable
 from crewai.knowledge.knowledge import Knowledge
 from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
-from crewai.llm import LLM
 from crewai.llm.base_llm import BaseLLM
+from crewai.llm.core import LLM
 from crewai.memory.entity.entity_memory import EntityMemory
 from crewai.memory.external.external_memory import ExternalMemory
 from crewai.memory.long_term.long_term_memory import LongTermMemory

View File

@@ -89,7 +89,7 @@ from crewai.events.types.tool_usage_events import (
     ToolUsageStartedEvent,
 )
 from crewai.events.utils.console_formatter import ConsoleFormatter
-from crewai.llm import LLM
+from crewai.llm.core import LLM
 from crewai.task import Task
 from crewai.telemetry.telemetry import Telemetry
 from crewai.utilities import Logger

View File

@@ -7,7 +7,7 @@ from pydantic import BaseModel, Field
 from crewai.agent import Agent
 from crewai.agents.agent_builder.base_agent import BaseAgent
-from crewai.llm import BaseLLM
+from crewai.llm.base_llm import BaseLLM
 from crewai.task import Task
 from crewai.utilities.llm_utils import create_llm

View File

@@ -39,8 +39,8 @@ from crewai.events.types.agent_events import (
 from crewai.events.types.logging_events import AgentLogsExecutionEvent
 from crewai.flow.flow_trackable import FlowTrackable
 from crewai.lite_agent_output import LiteAgentOutput
-from crewai.llm import LLM
 from crewai.llm.base_llm import BaseLLM
+from crewai.llm.core import LLM
 from crewai.tools.base_tool import BaseTool
 from crewai.tools.structured_tool import CrewStructuredTool
 from crewai.utilities.agent_utils import (

View File

@@ -66,7 +66,7 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
""" """
model_config: ClassVar[ConfigDict] = ConfigDict( model_config: ClassVar[ConfigDict] = ConfigDict(
arbitrary_types_allowed=True, extra="allow", validate_assignment=True arbitrary_types_allowed=True, extra="allow"
) )
# Core fields # Core fields
@@ -80,7 +80,9 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
default="openai", description="Provider name (openai, anthropic, etc.)" default="openai", description="Provider name (openai, anthropic, etc.)"
) )
stop: list[str] = Field( stop: list[str] = Field(
default_factory=list, description="Stop sequences for generation" default_factory=list,
description="Stop sequences for generation",
validation_alias="stop_sequences",
) )
# Internal fields # Internal fields
@@ -112,16 +114,18 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
if not values.get("model"): if not values.get("model"):
raise ValueError("Model name is required and cannot be empty") raise ValueError("Model name is required and cannot be empty")
# Handle stop sequences stop = values.get("stop") or values.get("stop_sequences")
stop = values.get("stop")
if stop is None: if stop is None:
values["stop"] = [] values["stop"] = []
elif isinstance(stop, str): elif isinstance(stop, str):
values["stop"] = [stop] values["stop"] = [stop]
elif not isinstance(stop, list): elif isinstance(stop, list):
values["stop"] = stop
else:
values["stop"] = [] values["stop"] = []
# Set default provider if not specified values.pop("stop_sequences", None)
if "provider" not in values or values["provider"] is None: if "provider" not in values or values["provider"] is None:
values["provider"] = "openai" values["provider"] = "openai"

View File

@@ -33,13 +33,12 @@ class LLMMeta(ModelMetaclass):
     native provider implementation based on the model parameter.
     """

-    def __call__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> Any:  # noqa: N805
+    def __call__(cls, *args: Any, **kwargs: Any) -> Any:  # noqa: N805
        """Route to appropriate provider implementation at instantiation time.

         Args:
-            model: The model identifier (e.g., "gpt-4", "claude-3-opus")
-            is_litellm: Force use of LiteLLM instead of native provider
-            **kwargs: Additional parameters for the LLM
+            *args: Positional arguments (model should be first for LLM class)
+            **kwargs: Keyword arguments including model, is_litellm, etc.

         Returns:
             Instance of the appropriate provider class or LLM class
@@ -47,18 +46,18 @@ class LLMMeta(ModelMetaclass):
         Raises:
             ValueError: If model is not a valid string
         """
+        if cls.__name__ != "LLM":
+            return super().__call__(*args, **kwargs)
+
+        model = kwargs.get("model") or (args[0] if args else None)
+        is_litellm = kwargs.get("is_litellm", False)
+
         if not model or not isinstance(model, str):
             raise ValueError("Model must be a non-empty string")

-        # Only perform routing if called on the base LLM class
-        # Subclasses (OpenAICompletion, etc.) should create normally
-        from crewai.llm import LLM
-
-        if cls is not LLM:
-            # Direct instantiation of provider class, skip routing
-            return super().__call__(model=model, **kwargs)
-
-        # Extract provider information
+        if args and not kwargs.get("model"):
+            kwargs["model"] = args[0]
+            args = args[1:]
+
         explicit_provider = kwargs.get("provider")
         if explicit_provider:
@@ -97,12 +96,10 @@ class LLMMeta(ModelMetaclass):
             use_native = True
             model_string = model

-        # Route to native provider if available
         native_class = cls._get_native_provider(provider) if use_native else None

         if native_class and not is_litellm and provider in SUPPORTED_NATIVE_PROVIDERS:
             try:
-                # Remove 'provider' from kwargs to avoid duplicate keyword argument
-                kwargs_copy = {k: v for k, v in kwargs.items() if k != "provider"}
+                kwargs_copy = {k: v for k, v in kwargs.items() if k not in ("provider", "model")}
                 return native_class(
                     model=model_string, provider=provider, **kwargs_copy
                 )
@@ -111,14 +108,12 @@ class LLMMeta(ModelMetaclass):
             except Exception as e:
                 raise ImportError(f"Error importing native provider: {e}") from e

-        # Fallback to LiteLLM
         try:
             import litellm  # noqa: F401
         except ImportError:
             logging.error("LiteLLM is not available, falling back to LiteLLM")
             raise ImportError("Fallback to LiteLLM is not available") from None

-        # Create actual LLM instance with is_litellm=True
         return super().__call__(model=model, is_litellm=True, **kwargs)

     @staticmethod
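
Note: a rough sketch of how the routing in LLMMeta.__call__ is expected to behave after this change. The provider import path is inferred from this diff, and provider availability plus credentials are assumptions:

    from crewai.llm.core import LLM
    from crewai.llm.providers.openai.completion import OpenAICompletion  # import path assumed

    # Calling the base LLM class routes to the native provider inferred from the
    # model string; a positional model is moved into kwargs["model"] first.
    llm = LLM("gpt-4o")
    # expected: isinstance(llm, OpenAICompletion)

    # Subclasses bypass routing (cls.__name__ != "LLM") and construct normally.
    native = OpenAICompletion(model="gpt-4o")

    # Forcing LiteLLM skips native routing and returns a plain LLM instance.
    lite = LLM("openrouter/deepseek/deepseek-chat", is_litellm=True)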

View File

@@ -3,9 +3,10 @@ from __future__ import annotations
 import json
 import logging
 import os
-from typing import TYPE_CHECKING, Any, ClassVar, cast
+from typing import Any, ClassVar, cast

-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
+from typing_extensions import Self

 from crewai.events.types.llm_events import LLMCallType
 from crewai.llm.base_llm import BaseLLM
@@ -19,9 +20,6 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
 from crewai.utilities.types import LLMMessage

-if TYPE_CHECKING:
-    from crewai.llm.hooks.base import BaseInterceptor
-
 try:
     from anthropic import Anthropic
     from anthropic.types import Message
@@ -38,90 +36,67 @@ class AnthropicCompletion(BaseLLM):
     This class provides direct integration with the Anthropic Python SDK,
     offering native tool use, streaming support, and proper message formatting.
-
-    Attributes:
-        model: Anthropic model name (e.g., 'claude-3-5-sonnet-20241022')
-        base_url: Custom base URL for Anthropic API
-        timeout: Request timeout in seconds
-        max_retries: Maximum number of retries
-        max_tokens: Maximum tokens in response (required for Anthropic)
-        top_p: Nucleus sampling parameter
-        stream: Enable streaming responses
-        client_params: Additional parameters for the Anthropic client
-        interceptor: HTTP interceptor for modifying requests/responses at transport level
     """

-    model_config: ClassVar[ConfigDict] = ConfigDict(ignored_types=(property,))
+    model_config: ClassVar[ConfigDict] = ConfigDict(
+        ignored_types=(property,), arbitrary_types_allowed=True
+    )

-    def __init__(
-        self,
-        model: str = "claude-3-5-sonnet-20241022",
-        api_key: str | None = None,
-        base_url: str | None = None,
-        timeout: float | None = None,
-        max_retries: int = 2,
-        temperature: float | None = None,
-        max_tokens: int = 4096,  # Required for Anthropic
-        top_p: float | None = None,
-        stop_sequences: list[str] | None = None,
-        stream: bool = False,
-        client_params: dict[str, Any] | None = None,
-        interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
-        **kwargs: Any,
-    ):
-        """Initialize Anthropic chat completion client.
-
-        Args:
-            model: Anthropic model name (e.g., 'claude-3-5-sonnet-20241022')
-            api_key: Anthropic API key (defaults to ANTHROPIC_API_KEY env var)
-            base_url: Custom base URL for Anthropic API
-            timeout: Request timeout in seconds
-            max_retries: Maximum number of retries
-            temperature: Sampling temperature (0-1)
-            max_tokens: Maximum tokens in response (required for Anthropic)
-            top_p: Nucleus sampling parameter
-            stop_sequences: Stop sequences (Anthropic uses stop_sequences, not stop)
-            stream: Enable streaming responses
-            client_params: Additional parameters for the Anthropic client
-            interceptor: HTTP interceptor for modifying requests/responses at transport level.
-            **kwargs: Additional parameters
-        """
-        super().__init__(
-            model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
-        )
-
-        # Client params
-        self.interceptor = interceptor
-        self.client_params = client_params
-        self.base_url = base_url
-        self.timeout = timeout
-        self.max_retries = max_retries
+    base_url: str | None = Field(
+        default=None, description="Custom base URL for Anthropic API"
+    )
+    timeout: float | None = Field(
+        default=None, description="Request timeout in seconds"
+    )
+    max_retries: int = Field(default=2, description="Maximum number of retries")
+    max_tokens: int = Field(
+        default=4096, description="Maximum tokens in response (required for Anthropic)"
+    )
+    top_p: float | None = Field(default=None, description="Nucleus sampling parameter")
+    stream: bool = Field(default=False, description="Enable streaming responses")
+    client_params: dict[str, Any] | None = Field(
+        default=None, description="Additional Anthropic client parameters"
+    )
+    interceptor: Any = Field(
+        default=None, description="HTTP interceptor for request/response modification"
+    )
+    client: Any = Field(
+        default=None, exclude=True, description="Anthropic client instance"
+    )
+
+    _is_claude_3: bool = PrivateAttr(default=False)
+    _supports_tools: bool = PrivateAttr(default=False)

+    @model_validator(mode="after")
+    def setup_client(self) -> Self:
+        """Initialize the Anthropic client and model-specific settings."""
         self.client = Anthropic(**self._get_client_params())

-        # Store completion parameters
-        self.max_tokens = max_tokens
-        self.top_p = top_p
-        self.stream = stream
-        self.stop_sequences = stop_sequences or []
-
-        # Model-specific settings
-        self.is_claude_3 = "claude-3" in model.lower()
-        self.supports_tools = self.is_claude_3  # Claude 3+ supports tool use
-
-    # @property
-    # def stop(self) -> list[str]:  # type: ignore[misc]
-    #     """Get stop sequences sent to the API."""
-    #     return self.stop_sequences
-    #
-    # @stop.setter
-    # def stop(self, value: list[str] | str | None) -> None:
-    #     """Set stop sequences.
-    #
-    #     Synchronizes stop_sequences to ensure values set by CrewAgentExecutor
-    #     are properly sent to the Anthropic API.
-    #
-    #     Args:
-    #         value: Stop sequences as a list, single string, or None
-    #     """
-    #     if value is None:
-    #         self.stop_sequences = []
-    #     elif isinstance(value, str):
-    #         self.stop_sequences = [value]
-    #     elif isinstance(value, list):
-    #         self.stop_sequences = value
-    #     else:
-    #         self.stop_sequences = []
+        self._is_claude_3 = "claude-3" in self.model.lower()
+        self._supports_tools = self._is_claude_3
+
+        return self
+
+    @property
+    def is_claude_3(self) -> bool:
+        """Check if model is Claude 3."""
+        return self._is_claude_3
+
+    @property
+    def supports_tools(self) -> bool:
+        """Check if model supports tools."""
+        return self._supports_tools

     def _get_client_params(self) -> dict[str, Any]:
         """Get client parameters."""
@@ -250,8 +225,8 @@ class AnthropicCompletion(BaseLLM):
             params["temperature"] = self.temperature
         if self.top_p is not None:
             params["top_p"] = self.top_p
-        if self.stop_sequences:
-            params["stop_sequences"] = self.stop_sequences
+        if self.stop:
+            params["stop_sequences"] = self.stop

         # Handle tools for Claude 3+
         if tools and self.supports_tools:
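
Note: the provider classes in this commit all follow the same refactor pattern — constructor arguments become pydantic fields, client setup moves into a model_validator(mode="after") hook, and derived flags become PrivateAttr values exposed through read-only properties. A minimal standalone sketch of that pattern (illustrative only, not crewai code; the class and client object are invented for the example):

    from typing import Any

    from pydantic import BaseModel, Field, PrivateAttr, model_validator
    from typing_extensions import Self


    class ExampleCompletion(BaseModel):
        model: str = Field(description="Model name")
        max_tokens: int = Field(default=4096, description="Maximum tokens in response")
        client: Any = Field(default=None, exclude=True, description="SDK client instance")

        _supports_tools: bool = PrivateAttr(default=False)

        @model_validator(mode="after")
        def setup_client(self) -> Self:
            # Runs once after field validation, replacing the old __init__ body.
            self.client = object()  # stand-in for the real SDK client
            self._supports_tools = "claude-3" in self.model.lower()
            return self

        @property
        def supports_tools(self) -> bool:
            """Read-only view of the derived private attribute."""
            return self._supports_tools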

View File

@@ -3,9 +3,10 @@ from __future__ import annotations
 import json
 import logging
 import os
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, ClassVar

-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
+from typing_extensions import Self

 from crewai.llm.core import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
 from crewai.llm.providers.utils.common import safe_tool_conversion
@@ -17,7 +18,6 @@ from crewai.utilities.types import LLMMessage
 if TYPE_CHECKING:
-    from crewai.llm.hooks.base import BaseInterceptor
     from crewai.tools.base_tool import BaseTool
@@ -51,65 +51,77 @@ class AzureCompletion(BaseLLM):
     This class provides direct integration with the Azure AI Inference Python SDK,
     offering native function calling, streaming support, and proper Azure authentication.
-
-    Attributes:
-        model: Azure deployment name or model name
-        endpoint: Azure endpoint URL
-        api_version: Azure API version
-        timeout: Request timeout in seconds
-        max_retries: Maximum number of retries
-        top_p: Nucleus sampling parameter
-        frequency_penalty: Frequency penalty (-2 to 2)
-        presence_penalty: Presence penalty (-2 to 2)
-        max_tokens: Maximum tokens in response
-        stream: Enable streaming responses
-        interceptor: HTTP interceptor (not yet supported for Azure)
     """

-    def __init__(
-        self,
-        model: str,
-        api_key: str | None = None,
-        endpoint: str | None = None,
-        api_version: str | None = None,
-        timeout: float | None = None,
-        max_retries: int = 2,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        frequency_penalty: float | None = None,
-        presence_penalty: float | None = None,
-        max_tokens: int | None = None,
-        stop: list[str] | None = None,
-        stream: bool = False,
-        interceptor: BaseInterceptor[Any, Any] | None = None,
-        **kwargs: Any,
-    ):
-        """Initialize Azure AI Inference chat completion client.
-
-        Args:
-            model: Azure deployment name or model name
-            api_key: Azure API key (defaults to AZURE_API_KEY env var)
-            endpoint: Azure endpoint URL (defaults to AZURE_ENDPOINT env var)
-            api_version: Azure API version (defaults to AZURE_API_VERSION env var)
-            timeout: Request timeout in seconds
-            max_retries: Maximum number of retries
-            temperature: Sampling temperature (0-2)
-            top_p: Nucleus sampling parameter
-            frequency_penalty: Frequency penalty (-2 to 2)
-            presence_penalty: Presence penalty (-2 to 2)
-            max_tokens: Maximum tokens in response
-            stop: Stop sequences
-            stream: Enable streaming responses
-            interceptor: HTTP interceptor (not yet supported for Azure).
-            **kwargs: Additional parameters
-        """
-        if interceptor is not None:
+    model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
+
+    endpoint: str | None = Field(
+        default=None,
+        description="Azure endpoint URL (defaults to AZURE_ENDPOINT env var)",
+    )
+    api_version: str = Field(
+        default="2024-06-01",
+        description="Azure API version (defaults to AZURE_API_VERSION env var or 2024-06-01)",
+    )
+    timeout: float | None = Field(
+        default=None, description="Request timeout in seconds"
+    )
+    max_retries: int = Field(default=2, description="Maximum number of retries")
+    top_p: float | None = Field(default=None, description="Nucleus sampling parameter")
+    frequency_penalty: float | None = Field(
+        default=None, description="Frequency penalty (-2 to 2)"
+    )
+    presence_penalty: float | None = Field(
+        default=None, description="Presence penalty (-2 to 2)"
+    )
+    max_tokens: int | None = Field(
+        default=None, description="Maximum tokens in response"
+    )
+    stream: bool = Field(default=False, description="Enable streaming responses")
+    interceptor: Any = Field(
+        default=None, description="HTTP interceptor (not yet supported for Azure)"
+    )
+    client: Any = Field(default=None, exclude=True, description="Azure client instance")
+
+    _is_openai_model: bool = PrivateAttr(default=False)
+    _is_azure_openai_endpoint: bool = PrivateAttr(default=False)
+
+    @model_validator(mode="after")
+    def setup_client(self) -> Self:
+        """Initialize the Azure client and validate configuration."""
+        if self.interceptor is not None:
             raise NotImplementedError(
                 "HTTP interceptors are not yet supported for Azure AI Inference provider. "
                 "Interceptors are currently supported for OpenAI and Anthropic providers only."
             )

-        super().__init__(
-            model=model, temperature=temperature, stop=stop or [], **kwargs
-        )
-
-        self.api_key = api_key or os.getenv("AZURE_API_KEY")
-        self.endpoint = (
-            endpoint
-            or os.getenv("AZURE_ENDPOINT")
-            or os.getenv("AZURE_OPENAI_ENDPOINT")
-            or os.getenv("AZURE_API_BASE")
-        )
-        self.api_version = api_version or os.getenv("AZURE_API_VERSION") or "2024-06-01"
-        self.timeout = timeout
-        self.max_retries = max_retries
+        if self.api_key is None:
+            self.api_key = os.getenv("AZURE_API_KEY")
+
+        if self.endpoint is None:
+            self.endpoint = (
+                os.getenv("AZURE_ENDPOINT")
+                or os.getenv("AZURE_OPENAI_ENDPOINT")
+                or os.getenv("AZURE_API_BASE")
+            )
+
+        if self.api_version == "2024-06-01":
+            env_version = os.getenv("AZURE_API_VERSION")
+            if env_version:
+                self.api_version = env_version

         if not self.api_key:
             raise ValueError(
@@ -120,36 +132,38 @@ class AzureCompletion(BaseLLM):
                 "Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter."
             )

-        # Validate and potentially fix Azure OpenAI endpoint URL
-        self.endpoint = self._validate_and_fix_endpoint(self.endpoint, model)
+        self.endpoint = self._validate_and_fix_endpoint(self.endpoint, self.model)

-        # Build client kwargs
-        client_kwargs = {
+        client_kwargs: dict[str, Any] = {
             "endpoint": self.endpoint,
             "credential": AzureKeyCredential(self.api_key),
         }

-        # Add api_version if specified (primarily for Azure OpenAI endpoints)
         if self.api_version:
             client_kwargs["api_version"] = self.api_version

-        self.client = ChatCompletionsClient(**client_kwargs)  # type: ignore[arg-type]
+        self.client = ChatCompletionsClient(**client_kwargs)

-        self.top_p = top_p
-        self.frequency_penalty = frequency_penalty
-        self.presence_penalty = presence_penalty
-        self.max_tokens = max_tokens
-        self.stream = stream
-
-        self.is_openai_model = any(
-            prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
+        self._is_openai_model = any(
+            prefix in self.model.lower() for prefix in ["gpt-", "o1-", "text-"]
         )
-
-        self.is_azure_openai_endpoint = (
+        self._is_azure_openai_endpoint = (
             "openai.azure.com" in self.endpoint
             and "/openai/deployments/" in self.endpoint
         )

+        return self
+
+    @property
+    def is_openai_model(self) -> bool:
+        """Check if model is an OpenAI model."""
+        return self._is_openai_model
+
+    @property
+    def is_azure_openai_endpoint(self) -> bool:
+        """Check if endpoint is an Azure OpenAI endpoint."""
+        return self._is_azure_openai_endpoint
+
     def _validate_and_fix_endpoint(self, endpoint: str, model: str) -> str:
         """Validate and fix Azure endpoint URL format.

View File

@@ -5,8 +5,8 @@ import logging
 import os
 from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast

-from pydantic import BaseModel, ConfigDict
-from typing_extensions import Required
+from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
+from typing_extensions import Required, Self

 from crewai.events.types.llm_events import LLMCallType
 from crewai.llm.base_llm import BaseLLM
@@ -32,8 +32,6 @@ if TYPE_CHECKING:
         ToolTypeDef,
     )
-    from crewai.llm.hooks.base import BaseInterceptor
-
 try:
     from boto3.session import Session
@@ -143,76 +141,86 @@ class BedrockCompletion(BaseLLM):
     - Complete streaming event handling (messageStart, contentBlockStart, etc.)
     - Response metadata and trace information capture
     - Model-specific conversation format handling (e.g., Cohere requirements)
-
-    Attributes:
-        model: The Bedrock model ID to use
-        aws_access_key_id: AWS access key (defaults to environment variable)
-        aws_secret_access_key: AWS secret key (defaults to environment variable)
-        aws_session_token: AWS session token for temporary credentials
-        region_name: AWS region name
-        max_tokens: Maximum tokens to generate
-        top_p: Nucleus sampling parameter
-        top_k: Top-k sampling parameter (Claude models only)
-        stop_sequences: List of sequences that stop generation
-        stream: Whether to use streaming responses
-        guardrail_config: Guardrail configuration for content filtering
-        additional_model_request_fields: Model-specific request parameters
-        additional_model_response_field_paths: Custom response field paths
-        interceptor: HTTP interceptor (not yet supported for Bedrock)
     """

-    model_config: ClassVar[ConfigDict] = ConfigDict(ignored_types=(property,))
+    model_config: ClassVar[ConfigDict] = ConfigDict(
+        ignored_types=(property,), arbitrary_types_allowed=True
+    )

-    def __init__(
-        self,
-        model: str = "anthropic.claude-3-5-sonnet-20241022-v2:0",
-        aws_access_key_id: str | None = None,
-        aws_secret_access_key: str | None = None,
-        aws_session_token: str | None = None,
-        region_name: str = "us-east-1",
-        temperature: float | None = None,
-        max_tokens: int | None = None,
-        top_p: float | None = None,
-        top_k: int | None = None,
-        stop_sequences: Sequence[str] | None = None,
-        stream: bool = False,
-        guardrail_config: dict[str, Any] | None = None,
-        additional_model_request_fields: dict[str, Any] | None = None,
-        additional_model_response_field_paths: list[str] | None = None,
-        interceptor: BaseInterceptor[Any, Any] | None = None,
-        **kwargs: Any,
-    ) -> None:
-        """Initialize AWS Bedrock completion client.
-
-        Args:
-            model: The Bedrock model ID to use
-            aws_access_key_id: AWS access key (defaults to environment variable)
-            aws_secret_access_key: AWS secret key (defaults to environment variable)
-            aws_session_token: AWS session token for temporary credentials
-            region_name: AWS region name
-            temperature: Sampling temperature for response generation
-            max_tokens: Maximum tokens to generate
-            top_p: Nucleus sampling parameter
-            top_k: Top-k sampling parameter (Claude models only)
-            stop_sequences: List of sequences that stop generation
-            stream: Whether to use streaming responses
-            guardrail_config: Guardrail configuration for content filtering
-            additional_model_request_fields: Model-specific request parameters
-            additional_model_response_field_paths: Custom response field paths
-            interceptor: HTTP interceptor (not yet supported for Bedrock).
-            **kwargs: Additional parameters
-        """
-        if interceptor is not None:
+    aws_access_key_id: str | None = Field(
+        default=None, description="AWS access key (defaults to environment variable)"
+    )
+    aws_secret_access_key: str | None = Field(
+        default=None, description="AWS secret key (defaults to environment variable)"
+    )
+    aws_session_token: str | None = Field(
+        default=None, description="AWS session token for temporary credentials"
+    )
+    region_name: str = Field(default="us-east-1", description="AWS region name")
+    max_tokens: int | None = Field(
+        default=None, description="Maximum tokens to generate"
+    )
+    top_p: float | None = Field(default=None, description="Nucleus sampling parameter")
+    top_k: int | None = Field(
+        default=None, description="Top-k sampling parameter (Claude models only)"
+    )
+    stream: bool = Field(
+        default=False, description="Whether to use streaming responses"
+    )
+    guardrail_config: dict[str, Any] | None = Field(
+        default=None, description="Guardrail configuration for content filtering"
+    )
+    additional_model_request_fields: dict[str, Any] | None = Field(
+        default=None, description="Model-specific request parameters"
+    )
+    additional_model_response_field_paths: list[str] | None = Field(
+        default=None, description="Custom response field paths"
+    )
+    interceptor: Any = Field(
+        default=None, description="HTTP interceptor (not yet supported for Bedrock)"
+    )
+    client: Any = Field(
+        default=None, exclude=True, description="Bedrock client instance"
+    )
+
+    _is_claude_model: bool = PrivateAttr(default=False)
+    _supports_tools: bool = PrivateAttr(default=True)
+    _supports_streaming: bool = PrivateAttr(default=True)
+    _model_id: str = PrivateAttr()
+
+    @model_validator(mode="after")
+    def setup_client(self) -> Self:
+        """Initialize the Bedrock client and validate configuration."""
+        if self.interceptor is not None:
             raise NotImplementedError(
                 "HTTP interceptors are not yet supported for AWS Bedrock provider. "
                 "Interceptors are currently supported for OpenAI and Anthropic providers only."
             )

-        # Extract provider from kwargs to avoid duplicate argument
-        kwargs.pop("provider", None)
-        super().__init__(
-            model=model,
-            temperature=temperature,
-            stop=stop_sequences or [],
-            provider="bedrock",
-            **kwargs,
-        )
-
-        # Initialize Bedrock client with proper configuration
         session = Session(
-            aws_access_key_id=aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"),
-            aws_secret_access_key=aws_secret_access_key
+            aws_access_key_id=self.aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"),
+            aws_secret_access_key=self.aws_secret_access_key
             or os.getenv("AWS_SECRET_ACCESS_KEY"),
-            aws_session_token=aws_session_token or os.getenv("AWS_SESSION_TOKEN"),
-            region_name=region_name,
+            aws_session_token=self.aws_session_token or os.getenv("AWS_SESSION_TOKEN"),
+            region_name=self.region_name,
         )

-        # Configure client with timeouts and retries following AWS best practices
         config = Config(
             read_timeout=300,
             retries={
@@ -223,53 +231,33 @@ class BedrockCompletion(BaseLLM):
         )

         self.client = session.client("bedrock-runtime", config=config)
-        self.region_name = region_name

-        # Store completion parameters
-        self.max_tokens = max_tokens
-        self.top_p = top_p
-        self.top_k = top_k
-        self.stream = stream
-        self.stop_sequences = stop_sequences or []
-
-        # Store advanced features (optional)
-        self.guardrail_config = guardrail_config
-        self.additional_model_request_fields = additional_model_request_fields
-        self.additional_model_response_field_paths = (
-            additional_model_response_field_paths
-        )
-
-        # Model-specific settings
-        self.is_claude_model = "claude" in model.lower()
-        self.supports_tools = True  # Converse API supports tools for most models
-        self.supports_streaming = True
-
-        # Handle inference profiles for newer models
-        self.model_id = model
+        self._is_claude_model = "claude" in self.model.lower()
+        self._supports_tools = True
+        self._supports_streaming = True
+        self._model_id = self.model

-    # @property
-    # def stop(self) -> list[str]:  # type: ignore[misc]
-    #     """Get stop sequences sent to the API."""
-    #     return list(self.stop_sequences)
-    #
-    # @stop.setter
-    # def stop(self, value: Sequence[str] | str | None) -> None:
-    #     """Set stop sequences.
-    #
-    #     Synchronizes stop_sequences to ensure values set by CrewAgentExecutor
-    #     are properly sent to the Bedrock API.
-    #
-    #     Args:
-    #         value: Stop sequences as a Sequence, single string, or None
-    #     """
-    #     if value is None:
-    #         self.stop_sequences = []
-    #     elif isinstance(value, str):
-    #         self.stop_sequences = [value]
-    #     elif isinstance(value, Sequence):
-    #         self.stop_sequences = list(value)
-    #     else:
-    #         self.stop_sequences = []
+        return self
+
+    @property
+    def is_claude_model(self) -> bool:
+        """Check if model is a Claude model."""
+        return self._is_claude_model
+
+    @property
+    def supports_tools(self) -> bool:
+        """Check if model supports tools."""
+        return self._supports_tools
+
+    @property
+    def supports_streaming(self) -> bool:
+        """Check if model supports streaming."""
+        return self._supports_streaming
+
+    @property
+    def model_id(self) -> str:
+        """Get the model ID."""
+        return self._model_id

     def call(
         self,
@@ -559,7 +547,7 @@ class BedrockCompletion(BaseLLM):
                 "Sequence[MessageTypeDef | MessageOutputTypeDef]",
                 cast(object, messages),
             ),
-            **body,  # type: ignore[arg-type]
+            **body,
         )

         stream = response.get("stream")
@@ -821,8 +809,8 @@ class BedrockCompletion(BaseLLM):
             config["temperature"] = float(self.temperature)
         if self.top_p is not None:
             config["topP"] = float(self.top_p)
-        if self.stop_sequences:
-            config["stopSequences"] = self.stop_sequences
+        if self.stop:
+            config["stopSequences"] = self.stop

         if self.is_claude_model and self.top_k is not None:
             # top_k is supported by Claude models

View File

@@ -2,12 +2,12 @@ import logging
 import os
 from typing import Any, ClassVar, cast

-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
+from typing_extensions import Self

 from crewai.events.types.llm_events import LLMCallType
 from crewai.llm.base_llm import BaseLLM
 from crewai.llm.core import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
-from crewai.llm.hooks.base import BaseInterceptor
 from crewai.llm.providers.utils.common import safe_tool_conversion
 from crewai.utilities.agent_utils import is_context_length_exceeded
 from crewai.utilities.exceptions.context_window_exceeding_exception import (
@@ -31,108 +31,124 @@ class GeminiCompletion(BaseLLM):
     This class provides direct integration with the Google Gen AI Python SDK,
     offering native function calling, streaming support, and proper Gemini formatting.
-
-    Attributes:
-        model: Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro')
-        project: Google Cloud project ID (for Vertex AI)
-        location: Google Cloud location (for Vertex AI, defaults to 'us-central1')
-        top_p: Nucleus sampling parameter
-        top_k: Top-k sampling parameter
-        max_output_tokens: Maximum tokens in response
-        stop_sequences: Stop sequences
-        stream: Enable streaming responses
-        safety_settings: Safety filter settings
-        client_params: Additional parameters for Google Gen AI Client constructor
-        interceptor: HTTP interceptor (not yet supported for Gemini)
     """

-    model_config: ClassVar[ConfigDict] = ConfigDict(ignored_types=(property,))
+    model_config: ClassVar[ConfigDict] = ConfigDict(
+        ignored_types=(property,), arbitrary_types_allowed=True
+    )

-    def __init__(
-        self,
-        model: str = "gemini-2.0-flash-001",
-        api_key: str | None = None,
-        project: str | None = None,
-        location: str | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        top_k: int | None = None,
-        max_output_tokens: int | None = None,
-        stop_sequences: list[str] | None = None,
-        stream: bool = False,
-        safety_settings: dict[str, Any] | None = None,
-        client_params: dict[str, Any] | None = None,
-        interceptor: BaseInterceptor[Any, Any] | None = None,
-        **kwargs: Any,
-    ):
-        """Initialize Google Gemini chat completion client.
-
-        Args:
-            model: Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro')
-            api_key: Google API key (defaults to GOOGLE_API_KEY or GEMINI_API_KEY env var)
-            project: Google Cloud project ID (for Vertex AI)
-            location: Google Cloud location (for Vertex AI, defaults to 'us-central1')
-            temperature: Sampling temperature (0-2)
-            top_p: Nucleus sampling parameter
-            top_k: Top-k sampling parameter
-            max_output_tokens: Maximum tokens in response
-            stop_sequences: Stop sequences
-            stream: Enable streaming responses
-            safety_settings: Safety filter settings
-            client_params: Additional parameters to pass to the Google Gen AI Client constructor.
-                Supports parameters like http_options, credentials, debug_config, etc.
-            interceptor: HTTP interceptor (not yet supported for Gemini).
-            **kwargs: Additional parameters
-        """
-        if interceptor is not None:
+    project: str | None = Field(
+        default=None, description="Google Cloud project ID (for Vertex AI)"
+    )
+    location: str = Field(
+        default="us-central1",
+        description="Google Cloud location (for Vertex AI, defaults to 'us-central1')",
+    )
+    top_p: float | None = Field(default=None, description="Nucleus sampling parameter")
+    top_k: int | None = Field(default=None, description="Top-k sampling parameter")
+    max_output_tokens: int | None = Field(
+        default=None, description="Maximum tokens in response"
+    )
+    stream: bool = Field(default=False, description="Enable streaming responses")
+    safety_settings: dict[str, Any] = Field(
+        default_factory=dict, description="Safety filter settings"
+    )
+    client_params: dict[str, Any] = Field(
+        default_factory=dict,
+        description="Additional parameters for Google Gen AI Client constructor",
+    )
+    interceptor: Any = Field(
+        default=None, description="HTTP interceptor (not yet supported for Gemini)"
+    )
+    client: Any = Field(
+        default=None, exclude=True, description="Gemini client instance"
+    )
+
+    _is_gemini_2: bool = PrivateAttr(default=False)
+    _is_gemini_1_5: bool = PrivateAttr(default=False)
+    _supports_tools: bool = PrivateAttr(default=False)
+
+    @property
+    def stop_sequences(self) -> list[str]:
+        """Get stop sequences as a list.
+
+        This property provides access to stop sequences in Gemini's native format
+        while maintaining synchronization with the base class's stop attribute.
+        """
+        if self.stop is None:
+            return []
+        if isinstance(self.stop, str):
+            return [self.stop]
+        return self.stop
+
+    @stop_sequences.setter
+    def stop_sequences(self, value: list[str] | str | None) -> None:
+        """Set stop sequences, synchronizing with the stop attribute.
+
+        Args:
+            value: Stop sequences as a list, string, or None
+        """
+        self.stop = value
+
+    @model_validator(mode="after")
+    def setup_client(self) -> Self:
+        """Initialize the Gemini client and validate configuration."""
+        if self.interceptor is not None:
             raise NotImplementedError(
                 "HTTP interceptors are not yet supported for Google Gemini provider. "
                 "Interceptors are currently supported for OpenAI and Anthropic providers only."
             )

-        super().__init__(
-            model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
-        )
-
-        # Store client params for later use
-        self.client_params = client_params or {}
-
-        # Get API configuration with environment variable fallbacks
-        self.api_key = (
-            api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
-        )
-        self.project = project or os.getenv("GOOGLE_CLOUD_PROJECT")
-        self.location = location or os.getenv("GOOGLE_CLOUD_LOCATION") or "us-central1"
+        if self.api_key is None:
+            self.api_key = os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
+
+        if self.project is None:
+            self.project = os.getenv("GOOGLE_CLOUD_PROJECT")
+
+        if self.location == "us-central1":
+            env_location = os.getenv("GOOGLE_CLOUD_LOCATION")
+            if env_location:
+                self.location = env_location

         use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"
         self.client = self._initialize_client(use_vertexai)

-        # Store completion parameters
-        self.top_p = top_p
-        self.top_k = top_k
-        self.max_output_tokens = max_output_tokens
-        self.stream = stream
-        self.safety_settings = safety_settings or {}
-        self.stop_sequences = stop_sequences or []
-
-        # Model-specific settings
-        self.is_gemini_2 = "gemini-2" in model.lower()
-        self.is_gemini_1_5 = "gemini-1.5" in model.lower()
-        self.supports_tools = self.is_gemini_1_5 or self.is_gemini_2
+        self._is_gemini_2 = "gemini-2" in self.model.lower()
+        self._is_gemini_1_5 = "gemini-1.5" in self.model.lower()
+        self._supports_tools = self._is_gemini_1_5 or self._is_gemini_2
+
+        return self

-    # @property
-    # def stop(self) -> list[str]:  # type: ignore[misc]
-    #     """Get stop sequences sent to the API."""
-    #     return self.stop_sequences
-    #
-    # @stop.setter
-    # def stop(self, value: list[str] | str | None) -> None:
-    #     """Set stop sequences.
-    #
-    #     Synchronizes stop_sequences to ensure values set by CrewAgentExecutor
-    #     are properly sent to the Gemini API.
-    #
-    #     Args:
-    #         value: Stop sequences as a list, single string, or None
-    #     """
-    #     if value is None:
-    #         self.stop_sequences = []
-    #     elif isinstance(value, str):
-    #         self.stop_sequences = [value]
-    #     elif isinstance(value, list):
-    #         self.stop_sequences = value
-    #     else:
-    #         self.stop_sequences = []
+    @property
+    def is_gemini_2(self) -> bool:
+        """Check if model is Gemini 2."""
+        return self._is_gemini_2
+
+    @property
+    def is_gemini_1_5(self) -> bool:
+        """Check if model is Gemini 1.5."""
+        return self._is_gemini_1_5
+
+    @property
+    def supports_tools(self) -> bool:
+        """Check if model supports tools."""
+        return self._supports_tools

-    def _initialize_client(self, use_vertexai: bool = False) -> genai.Client:  # type: ignore[no-any-unimported]
+    def _initialize_client(self, use_vertexai: bool = False) -> Any:
         """Initialize the Google Gen AI client with proper parameter handling.

         Args:
@@ -154,12 +170,9 @@ class GeminiCompletion(BaseLLM):
                     "location": self.location,
                 }
             )
             client_params.pop("api_key", None)
         elif self.api_key:
             client_params["api_key"] = self.api_key
             client_params.pop("vertexai", None)
             client_params.pop("project", None)
             client_params.pop("location", None)
@@ -188,7 +201,6 @@ class GeminiCompletion(BaseLLM):
             and hasattr(self.client, "vertexai")
             and self.client.vertexai
         ):
-            # Vertex AI configuration
             params.update(
                 {
                     "vertexai": True,
@@ -300,15 +312,12 @@ class GeminiCompletion(BaseLLM):
             self.tools = tools

         config_params = {}
-        # Add system instruction if present
         if system_instruction:
-            # Convert system instruction to Content format
             system_content = types.Content(
                 role="user", parts=[types.Part.from_text(text=system_instruction)]
             )
             config_params["system_instruction"] = system_content

-        # Add generation config parameters
         if self.temperature is not None:
             config_params["temperature"] = self.temperature
         if self.top_p is not None:
@@ -317,14 +326,13 @@ class GeminiCompletion(BaseLLM):
             config_params["top_k"] = self.top_k
         if self.max_output_tokens is not None:
             config_params["max_output_tokens"] = self.max_output_tokens
-        if self.stop_sequences:
-            config_params["stop_sequences"] = self.stop_sequences
+        if self.stop:
+            config_params["stop_sequences"] = self.stop

         if response_model:
             config_params["response_mime_type"] = "application/json"
             config_params["response_schema"] = response_model.model_json_schema()

-        # Handle tools for supported models
         if tools and self.supports_tools:
             config_params["tools"] = self._convert_tools_for_interference(tools)
@@ -347,7 +355,6 @@ class GeminiCompletion(BaseLLM):
                 description=description,
             )

-            # Add parameters if present - ensure parameters is a dict
             if parameters and isinstance(parameters, dict):
                 function_declaration.parameters = parameters
@@ -383,16 +390,12 @@ class GeminiCompletion(BaseLLM):
             content = message.get("content", "")

             if role == "system":
-                # Extract system instruction - Gemini handles it separately
                 if system_instruction:
                     system_instruction += f"\n\n{content}"
                 else:
                     system_instruction = cast(str, content)
             else:
-                # Convert role for Gemini (assistant -> model)
                 gemini_role = "model" if role == "assistant" else "user"
-
-                # Create Content object
                 gemini_content = types.Content(
                     role=gemini_role, parts=[types.Part.from_text(text=content)]
                 )
@@ -509,13 +512,11 @@ class GeminiCompletion(BaseLLM):
             else {},
         }

-        # Handle completed function calls
         if function_calls and available_functions:
             for call_data in function_calls.values():
                 function_name = call_data["name"]
                 function_args = call_data["args"]

-                # Execute tool
                 result = self._handle_tool_execution(
                     function_name=function_name,
                     function_args=function_args,
@@ -575,13 +576,11 @@ class GeminiCompletion(BaseLLM):
             "gemma-3-27b": 128000,
         }

-        # Find the best match for the model name
         for model_prefix, size in context_windows.items():
             if self.model.startswith(model_prefix):
                 return int(size * CONTEXT_WINDOW_USAGE_RATIO)

-        # Default context window size for Gemini models
-        return int(1048576 * CONTEXT_WINDOW_USAGE_RATIO)  # 1M tokens
+        return int(1048576 * CONTEXT_WINDOW_USAGE_RATIO)

     def _extract_token_usage(self, response: dict[str, Any]) -> dict[str, Any]:
         """Extract token usage from Gemini response."""

View File

@@ -11,7 +11,8 @@ from openai import APIConnectionError, NotFoundError, OpenAI
 from openai.types.chat import ChatCompletion, ChatCompletionChunk
 from openai.types.chat.chat_completion import Choice
 from openai.types.chat.chat_completion_chunk import ChoiceDelta
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, model_validator
+from typing_extensions import Self

 from crewai.events.types.llm_events import LLMCallType
 from crewai.llm.base_llm import BaseLLM
@@ -73,26 +74,18 @@ class OpenAICompletion(BaseLLM):
     )
     reasoning_effort: str | None = Field(None, description="Reasoning effort level")

-    # Internal state
     client: OpenAI = Field(
         default_factory=OpenAI, exclude=True, description="OpenAI client instance"
     )
     is_o1_model: bool = Field(False, description="Whether this is an O1 model")
     is_gpt4_model: bool = Field(False, description="Whether this is a GPT-4 model")

-    def model_post_init(self, __context: Any) -> None:
-        """Initialize OpenAI client after model initialization.
-
-        Args:
-            __context: Pydantic context
-        """
-        super().model_post_init(__context)
-
-        # Set API key from environment if not provided
+    @model_validator(mode="after")
+    def setup_client(self) -> Self:
+        """Initialize OpenAI client after model validation."""
         if self.api_key is None:
             self.api_key = os.getenv("OPENAI_API_KEY")

-        # Initialize client
         client_config = self._get_client_params()
         if self.interceptor:
             transport = HTTPTransport(interceptor=self.interceptor)
@@ -101,10 +94,11 @@ class OpenAICompletion(BaseLLM):
         self.client = OpenAI(**client_config)

-        # Set model flags
         self.is_o1_model = "o1" in self.model.lower()
         self.is_gpt4_model = "gpt-4" in self.model.lower()

+        return self
+
     def _get_client_params(self) -> dict[str, Any]:
         """Get OpenAI client parameters."""

View File

@@ -8,7 +8,7 @@ Classes:
 from typing import Any

-from crewai.llm import LLM
+from crewai.llm.core import LLM
 from crewai.tasks.task_output import TaskOutput
 from crewai.utilities.logger import Logger

View File

@@ -36,7 +36,7 @@ if TYPE_CHECKING:
     from crewai.agents.agent_builder.base_agent import BaseAgent
     from crewai.agents.tools_handler import ToolsHandler
     from crewai.lite_agent import LiteAgent
-    from crewai.llm import LLM
+    from crewai.llm.core import LLM
     from crewai.task import Task

View File

@@ -11,7 +11,7 @@ from crewai.events.event_types import (
     ToolUsageFinishedEvent,
     ToolUsageStartedEvent,
 )
-from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM
+from crewai.llm.core import CONTEXT_WINDOW_USAGE_RATIO, LLM
 from crewai.utilities.token_counter_callback import TokenCalcHandler
 from pydantic import BaseModel
 import pytest
@@ -229,7 +229,7 @@ def test_validate_call_params_supported():
         a: int

     # Patch supports_response_schema to simulate a supported model.
-    with patch("crewai.llm.supports_response_schema", return_value=True):
+    with patch("crewai.llm.core.supports_response_schema", return_value=True):
         llm = LLM(
             model="openrouter/deepseek/deepseek-chat", response_format=DummyResponse
         )
@@ -242,7 +242,7 @@ def test_validate_call_params_not_supported():
         a: int

     # Patch supports_response_schema to simulate an unsupported model.
-    with patch("crewai.llm.supports_response_schema", return_value=False):
+    with patch("crewai.llm.core.supports_response_schema", return_value=False):
         llm = LLM(model="gemini/gemini-1.5-pro", response_format=DummyResponse, is_litellm=True)
         with pytest.raises(ValueError) as excinfo:
             llm._validate_call_params()
@@ -342,7 +342,7 @@ def test_context_window_validation():
     # Test invalid window size
     with pytest.raises(ValueError) as excinfo:
         with patch.dict(
-            "crewai.llm.LLM_CONTEXT_WINDOW_SIZES",
+            "crewai.llm.core.LLM_CONTEXT_WINDOW_SIZES",
             {"test-model": 500},  # Below minimum
             clear=True,
         ):
@@ -702,8 +702,8 @@ def test_ollama_does_not_modify_when_last_is_user(ollama_llm):
 def test_native_provider_raises_error_when_supported_but_fails():
     """Test that when a native provider is in SUPPORTED_NATIVE_PROVIDERS but fails to instantiate, we raise the error."""
-    with patch("crewai.llm.SUPPORTED_NATIVE_PROVIDERS", ["openai"]):
-        with patch("crewai.llm.LLM._get_native_provider") as mock_get_native:
+    with patch("crewai.llm.internal.meta.SUPPORTED_NATIVE_PROVIDERS", ["openai"]):
+        with patch("crewai.llm.internal.meta.LLMMeta._get_native_provider") as mock_get_native:
             # Mock that provider exists but throws an error when instantiated
             mock_provider = MagicMock()
             mock_provider.side_effect = ValueError("Native provider initialization failed")
@@ -718,7 +718,7 @@ def test_native_provider_raises_error_when_supported_but_fails():
 def test_native_provider_falls_back_to_litellm_when_not_in_supported_list():
     """Test that when a provider is not in SUPPORTED_NATIVE_PROVIDERS, we fall back to LiteLLM."""
-    with patch("crewai.llm.SUPPORTED_NATIVE_PROVIDERS", ["openai", "anthropic"]):
+    with patch("crewai.llm.internal.meta.SUPPORTED_NATIVE_PROVIDERS", ["openai", "anthropic"]):
         # Using a provider not in the supported list
         llm = LLM(model="groq/llama-3.1-70b-versatile", is_litellm=False)

View File

@@ -1,20 +0,0 @@
-lib/crewai/src/crewai/agent/core.py:901: error: Argument 1 has incompatible type "ToolFilterContext"; expected "dict[str, Any]" [arg-type]
-lib/crewai/src/crewai/agent/core.py:901: note: Error code "arg-type" not covered by "type: ignore" comment
-lib/crewai/src/crewai/agent/core.py:905: error: Argument 1 has incompatible type "dict[str, Any]"; expected "ToolFilterContext" [arg-type]
-lib/crewai/src/crewai/agent/core.py:905: note: Error code "arg-type" not covered by "type: ignore" comment
-lib/crewai/src/crewai/agent/core.py:996: error: Returning Any from function declared to return "dict[str, dict[str, Any]]" [no-any-return]
-lib/crewai/src/crewai/agent/core.py:1157: error: Incompatible types in assignment (expression has type "tuple[UnionType, None]", target has type "tuple[type, Any]") [assignment]
-lib/crewai/src/crewai/agent/core.py:1183: error: Argument 1 to "append" of "list" has incompatible type "type"; expected "type[str]" [arg-type]
-lib/crewai/src/crewai/agent/core.py:1188: error: Incompatible types in assignment (expression has type "UnionType", variable has type "type[str]") [assignment]
-lib/crewai/src/crewai/agent/core.py:1201: error: Argument 1 to "get" of "dict" has incompatible type "Any | None"; expected "str" [arg-type]
-Found 7 errors in 1 file (checked 4 source files)
-Success: no issues found in 4 source files
-lib/crewai/src/crewai/llm/providers/gemini/completion.py:111: error: BaseModel field may only be overridden by another field [misc]
-Found 1 error in 1 file (checked 4 source files)
-Success: no issues found in 4 source files
-lib/crewai/src/crewai/llm/providers/anthropic/completion.py:101: error: BaseModel field may only be overridden by another field [misc]
-Found 1 error in 1 file (checked 4 source files)
-lib/crewai/src/crewai/llm/providers/bedrock/completion.py:250: error: BaseModel field may only be overridden by another field [misc]
-Found 1 error in 1 file (checked 4 source files)
-uv-lock..............................................(no files to check)Skipped