diff --git a/lib/crewai/src/crewai/llm/base_llm.py b/lib/crewai/src/crewai/llm/base_llm.py index 9c26d59fc..3f46ad327 100644 --- a/lib/crewai/src/crewai/llm/base_llm.py +++ b/lib/crewai/src/crewai/llm/base_llm.py @@ -12,10 +12,11 @@ import json import logging import os import re -from typing import TYPE_CHECKING, Any, ClassVar, Final +from typing import TYPE_CHECKING, Any, Final +from dotenv import load_dotenv import httpx -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from pydantic import BaseModel, Field, field_validator from crewai.events.event_bus import crewai_event_bus from crewai.events.types.llm_events import ( @@ -42,6 +43,8 @@ if TYPE_CHECKING: from crewai.utilities.types import LLMMessage +load_dotenv() + DEFAULT_CONTEXT_WINDOW_SIZE: Final[int] = 4096 DEFAULT_SUPPORTS_STOP_WORDS: Final[bool] = True _JSON_EXTRACTION_PATTERN: Final[re.Pattern[str]] = re.compile(r"\{.*}", re.DOTALL) @@ -65,10 +68,6 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta): stop: A list of stop sequences that the LLM should use to stop generation. """ - model_config: ClassVar[ConfigDict] = ConfigDict( - extra="allow", populate_by_name=True - ) - # Core fields model: str = Field(..., description="The model identifier/name") temperature: float | None = Field( @@ -100,7 +99,7 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta): "cached_prompt_tokens": 0, } - @field_validator("api_key", mode="before") + @field_validator("api_key", mode="after") @classmethod def _validate_api_key(cls, value: str | None) -> str | None: """Validate API key for authentication. @@ -137,37 +136,6 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta): return value return [] - @model_validator(mode="before") - @classmethod - def _extract_stop_and_validate(cls, values: dict[str, Any]) -> dict[str, Any]: - """Extract and normalize stop sequences before model initialization. - - Args: - values: Input values dictionary - - Returns: - Processed values dictionary - """ - if not values.get("model"): - raise ValueError("Model name is required and cannot be empty") - - stop = values.get("stop") or values.get("stop_sequences") - if stop is None: - values["stop"] = [] - elif isinstance(stop, str): - values["stop"] = [stop] - elif isinstance(stop, list): - values["stop"] = stop - else: - values["stop"] = [] - - values.pop("stop_sequences", None) - - if "provider" not in values or values["provider"] is None: - values["provider"] = "openai" - - return values - @property def additional_params(self) -> dict[str, Any]: """Get additional parameters stored as extra fields. @@ -190,20 +158,6 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta): self.__pydantic_extra__ = {} self.__pydantic_extra__.update(value) - def model_post_init(self, __context: Any) -> None: - """Initialize token usage tracking after model initialization. 
- - Args: - __context: Pydantic context (unused) - """ - self._token_usage = { - "total_tokens": 0, - "prompt_tokens": 0, - "completion_tokens": 0, - "successful_requests": 0, - "cached_prompt_tokens": 0, - } - @abstractmethod def call( self, diff --git a/lib/crewai/src/crewai/llm/internal/constants.py b/lib/crewai/src/crewai/llm/internal/constants.py new file mode 100644 index 000000000..1d24c4682 --- /dev/null +++ b/lib/crewai/src/crewai/llm/internal/constants.py @@ -0,0 +1,14 @@ +from crewai.llm.constants import SupportedNativeProviders + + +PROVIDER_MAPPING: dict[str, SupportedNativeProviders] = { + "openai": "openai", + "anthropic": "anthropic", + "claude": "anthropic", + "azure": "azure", + "azure_openai": "azure", + "google": "gemini", + "gemini": "gemini", + "bedrock": "bedrock", + "aws": "bedrock", +} diff --git a/lib/crewai/src/crewai/llm/internal/meta.py b/lib/crewai/src/crewai/llm/internal/meta.py index f210ab742..8bfb74c24 100644 --- a/lib/crewai/src/crewai/llm/internal/meta.py +++ b/lib/crewai/src/crewai/llm/internal/meta.py @@ -9,6 +9,7 @@ from __future__ import annotations import logging from typing import Any, cast +from pydantic import ConfigDict from pydantic._internal._model_construction import ModelMetaclass from crewai.llm.constants import ( @@ -21,6 +22,7 @@ from crewai.llm.constants import ( SupportedModels, SupportedNativeProviders, ) +from crewai.llm.internal.constants import PROVIDER_MAPPING class LLMMeta(ModelMetaclass): @@ -30,6 +32,41 @@ class LLMMeta(ModelMetaclass): native provider implementation based on the model parameter. """ + def __new__( + mcs, + name: str, + bases: tuple[type, ...], + namespace: dict[str, Any], + **kwargs: Any, + ) -> type: + """Create new LLM class with proper model_config for custom LLMs. + + Args: + name: Class name + bases: Base classes + namespace: Class namespace + **kwargs: Additional arguments + + Returns: + New class + """ + if name != "BaseLLM" and any( + base.__name__ in ("BaseLLM", "LLM") for base in bases + ): + if "model_config" not in namespace: + namespace["model_config"] = ConfigDict( + extra="allow", populate_by_name=True + ) + elif isinstance(namespace["model_config"], dict): + config_dict = cast( + ConfigDict, cast(object, dict(namespace["model_config"])) + ) + config_dict.setdefault("extra", "allow") + config_dict.setdefault("populate_by_name", True) + namespace["model_config"] = ConfigDict(**config_dict) + + return super().__new__(mcs, name, bases, namespace) + def __call__(cls, *args: Any, **kwargs: Any) -> Any: # noqa: N805 """Route to appropriate provider implementation at instantiation time. 
@@ -57,7 +94,7 @@ class LLMMeta(ModelMetaclass): if args and not kwargs.get("model"): kwargs["model"] = cast(SupportedModels, args[0]) - args = args[1:] + _ = args[1:] explicit_provider = cast(SupportedNativeProviders, kwargs.get("provider")) if explicit_provider: @@ -70,19 +107,7 @@ class LLMMeta(ModelMetaclass): model.partition("/"), ) - provider_mapping: dict[str, SupportedNativeProviders] = { - "openai": "openai", - "anthropic": "anthropic", - "claude": "anthropic", - "azure": "azure", - "azure_openai": "azure", - "google": "gemini", - "gemini": "gemini", - "bedrock": "bedrock", - "aws": "bedrock", - } - - canonical_provider = provider_mapping.get(prefix.lower()) + canonical_provider = PROVIDER_MAPPING.get(prefix.lower()) if canonical_provider and cls._validate_model_in_constants( model_part, canonical_provider diff --git a/lib/crewai/src/crewai/llm/providers/anthropic/completion.py b/lib/crewai/src/crewai/llm/providers/anthropic/completion.py index b0e4b5c87..6303f4e4c 100644 --- a/lib/crewai/src/crewai/llm/providers/anthropic/completion.py +++ b/lib/crewai/src/crewai/llm/providers/anthropic/completion.py @@ -4,6 +4,7 @@ import json import logging from typing import TYPE_CHECKING, Any, cast +from dotenv import load_dotenv import httpx from pydantic import BaseModel, Field, PrivateAttr, model_validator from typing_extensions import Self @@ -21,13 +22,14 @@ from crewai.utilities.types import LLMMessage if TYPE_CHECKING: + from anthropic.types import Message + from crewai.agent.core import Agent from crewai.task import Task try: from anthropic import Anthropic - from anthropic.types import Message from anthropic.types.tool_use_block import ToolUseBlock except ImportError: raise ImportError( @@ -35,6 +37,9 @@ except ImportError: ) from None +load_dotenv() + + class AnthropicCompletion(BaseLLM): """Anthropic native completion implementation. 
@@ -66,11 +71,9 @@ class AnthropicCompletion(BaseLLM): top_p: float | None = Field(default=None, description="Nucleus sampling parameter") stream: bool = Field(default=False, description="Enable streaming responses") client_params: dict[str, Any] | None = Field( - default=None, description="Additional Anthropic client parameters" - ) - client: Anthropic = Field( - default_factory=Anthropic, exclude=True, description="Anthropic client instance" + default_factory=dict, description="Additional Anthropic client parameters" ) + _client: Anthropic = PrivateAttr(default=None) # type: ignore[assignment] _is_claude_3: bool = PrivateAttr(default=False) _supports_tools: bool = PrivateAttr(default=False) @@ -78,7 +81,7 @@ class AnthropicCompletion(BaseLLM): @model_validator(mode="after") def setup_client(self) -> Self: """Initialize the Anthropic client and model-specific settings.""" - self.client = Anthropic(**self._get_client_params()) + self._client = Anthropic(**self._get_client_params()) self._is_claude_3 = "claude-3" in self.model.lower() self._supports_tools = self._is_claude_3 @@ -98,9 +101,6 @@ class AnthropicCompletion(BaseLLM): def _get_client_params(self) -> dict[str, Any]: """Get client parameters.""" - if self.api_key is None: - raise ValueError("ANTHROPIC_API_KEY is required") - client_params = { "api_key": self.api_key, "base_url": self.base_url, @@ -330,7 +330,7 @@ class AnthropicCompletion(BaseLLM): params["tool_choice"] = {"type": "tool", "name": "structured_output"} try: - response: Message = self.client.messages.create(**params) + response: Message = self._client.messages.create(**params) except Exception as e: if is_context_length_exceeded(e): @@ -424,7 +424,7 @@ class AnthropicCompletion(BaseLLM): stream_params = {k: v for k, v in params.items() if k != "stream"} # Make streaming API call - with self.client.messages.stream(**stream_params) as stream: + with self._client.messages.stream(**stream_params) as stream: for event in stream: if hasattr(event, "delta") and hasattr(event.delta, "text"): text_delta = event.delta.text @@ -552,7 +552,7 @@ class AnthropicCompletion(BaseLLM): try: # Send tool results back to Claude for final response - final_response: Message = self.client.messages.create(**follow_up_params) + final_response: Message = self._client.messages.create(**follow_up_params) # Track token usage for follow-up call follow_up_usage = self._extract_anthropic_token_usage(final_response) diff --git a/lib/crewai/src/crewai/llm/providers/azure/completion.py b/lib/crewai/src/crewai/llm/providers/azure/completion.py index b30e9a2ba..9963fee6f 100644 --- a/lib/crewai/src/crewai/llm/providers/azure/completion.py +++ b/lib/crewai/src/crewai/llm/providers/azure/completion.py @@ -5,6 +5,7 @@ import logging import os from typing import TYPE_CHECKING, Any +from dotenv import load_dotenv from pydantic import BaseModel, Field, PrivateAttr, model_validator from typing_extensions import Self @@ -48,6 +49,9 @@ except ImportError: ) from None +load_dotenv() + + class AzureCompletion(BaseLLM): """Azure AI Inference native completion implementation. 
@@ -68,12 +72,14 @@ class AzureCompletion(BaseLLM): interceptor: HTTP interceptor (not yet supported for Azure) """ - endpoint: str | None = Field( - default=None, + endpoint: str = Field( # type: ignore[assignment] + default_factory=lambda: os.getenv("AZURE_ENDPOINT") + or os.getenv("AZURE_OPENAI_ENDPOINT") + or os.getenv("AZURE_API_BASE"), description="Azure endpoint URL (defaults to AZURE_ENDPOINT env var)", ) api_version: str = Field( - default="2024-06-01", + default_factory=lambda: os.getenv("AZURE_API_VERSION", "2024-06-01"), description="Azure API version (defaults to AZURE_API_VERSION env var or 2024-06-01)", ) timeout: float | None = Field( @@ -82,18 +88,16 @@ class AzureCompletion(BaseLLM): max_retries: int = Field(default=2, description="Maximum number of retries") top_p: float | None = Field(default=None, description="Nucleus sampling parameter") frequency_penalty: float | None = Field( - default=None, description="Frequency penalty (-2 to 2)" + default=None, le=2.0, ge=-2.0, description="Frequency penalty (-2 to 2)" ) presence_penalty: float | None = Field( - default=None, description="Presence penalty (-2 to 2)" + default=None, le=2.0, ge=-2.0, description="Presence penalty (-2 to 2)" ) max_tokens: int | None = Field( default=None, description="Maximum tokens in response" ) stream: bool = Field(default=False, description="Enable streaming responses") - _client: ChatCompletionsClient = PrivateAttr( - default_factory=ChatCompletionsClient, # type: ignore[arg-type] - ) + _client: ChatCompletionsClient = PrivateAttr(default=None) # type: ignore[assignment] _is_openai_model: bool = PrivateAttr(default=False) _is_azure_openai_endpoint: bool = PrivateAttr(default=False) @@ -107,26 +111,13 @@ class AzureCompletion(BaseLLM): "Interceptors are currently supported for OpenAI and Anthropic providers only." ) - if self.endpoint is None: - self.endpoint = ( - os.getenv("AZURE_ENDPOINT") - or os.getenv("AZURE_OPENAI_ENDPOINT") - or os.getenv("AZURE_API_BASE") - ) - - if self.api_version == "2024-06-01": - env_version = os.getenv("AZURE_API_VERSION") - if env_version: - self.api_version = env_version + if not self.api_key: + self.api_key = os.getenv("AZURE_API_KEY") if not self.api_key: raise ValueError( "Azure API key is required. Set AZURE_API_KEY environment variable or pass api_key parameter." ) - if not self.endpoint: - raise ValueError( - "Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter." - ) self.endpoint = self._validate_and_fix_endpoint(self.endpoint, self.model) @@ -138,7 +129,7 @@ class AzureCompletion(BaseLLM): if self.api_version: client_kwargs["api_version"] = self.api_version - self.client = ChatCompletionsClient(**client_kwargs) + self._client = ChatCompletionsClient(**client_kwargs) self._is_openai_model = any( prefix in self.model.lower() for prefix in ["gpt-", "o1-", "text-"] @@ -160,7 +151,7 @@ class AzureCompletion(BaseLLM): """Check if endpoint is an Azure OpenAI endpoint.""" return self._is_azure_openai_endpoint - def _validate_and_fix_endpoint(self, endpoint: str, model: str) -> str: + def _validate_and_fix_endpoint(self, endpoint: str | None, model: str) -> str: """Validate and fix Azure endpoint URL format. Azure OpenAI endpoints should be in the format: @@ -172,7 +163,15 @@ class AzureCompletion(BaseLLM): Returns: Validated and potentially corrected endpoint URL + + Raises: + ValueError: If endpoint is None or empty """ + if not endpoint: + raise ValueError( + "Azure endpoint is required. 
Set AZURE_ENDPOINT environment variable or pass endpoint parameter." + ) + if "openai.azure.com" in endpoint and "/openai/deployments/" not in endpoint: endpoint = endpoint.rstrip("/") @@ -388,7 +387,7 @@ class AzureCompletion(BaseLLM): """Handle non-streaming chat completion.""" # Make API call try: - response: ChatCompletions = self.client.complete(**params) + response: ChatCompletions = self._client.complete(**params) if not response.choices: raise ValueError("No choices returned from Azure API") @@ -486,7 +485,7 @@ class AzureCompletion(BaseLLM): tool_calls = {} # Make streaming API call - for update in self.client.complete(**params): + for update in self._client.complete(**params): if isinstance(update, StreamingChatCompletionsUpdate): if update.choices: choice = update.choices[0] diff --git a/lib/crewai/src/crewai/llm/providers/bedrock/completion.py b/lib/crewai/src/crewai/llm/providers/bedrock/completion.py index 282ea840d..62afae103 100644 --- a/lib/crewai/src/crewai/llm/providers/bedrock/completion.py +++ b/lib/crewai/src/crewai/llm/providers/bedrock/completion.py @@ -5,6 +5,8 @@ import logging import os from typing import TYPE_CHECKING, Any, TypedDict, cast +from dotenv import load_dotenv +from mypy_boto3_bedrock_runtime.client import BedrockRuntimeClient from pydantic import BaseModel, Field, PrivateAttr, model_validator from typing_extensions import Required, Self @@ -75,6 +77,9 @@ else: topK: int +load_dotenv() + + class ToolInputSchema(TypedDict): """Type definition for tool input schema in Converse API.""" @@ -161,16 +166,22 @@ class BedrockCompletion(BaseLLM): interceptor: HTTP interceptor (not yet supported for Bedrock) """ - aws_access_key_id: str | None = Field( - default=None, description="AWS access key (defaults to environment variable)" + aws_access_key_id: str = Field( # type: ignore[assignment] + default_factory=lambda: os.getenv("AWS_ACCESS_KEY_ID"), + description="AWS access key (defaults to environment variable)", ) - aws_secret_access_key: str | None = Field( - default=None, description="AWS secret key (defaults to environment variable)" + aws_secret_access_key: str = Field( # type: ignore[assignment] + default_factory=lambda: os.getenv("AWS_SECRET_ACCESS_KEY"), + description="AWS secret key (defaults to environment variable)", ) - aws_session_token: str | None = Field( - default=None, description="AWS session token for temporary credentials" + aws_session_token: str = Field( # type: ignore[assignment] + default_factory=lambda: os.getenv("AWS_SESSION_TOKEN"), + description="AWS session token for temporary credentials", + ) + region_name: str = Field( + default_factory=lambda: os.getenv("AWS_REGION", "us-east-1"), + description="AWS region name", ) - region_name: str = Field(default="us-east-1", description="AWS region name") max_tokens: int | None = Field( default=None, description="Maximum tokens to generate" ) @@ -181,17 +192,18 @@ class BedrockCompletion(BaseLLM): stream: bool = Field( default=False, description="Whether to use streaming responses" ) - guardrail_config: dict[str, Any] | None = Field( - default=None, description="Guardrail configuration for content filtering" + guardrail_config: dict[str, Any] = Field( + default_factory=dict, + description="Guardrail configuration for content filtering", ) - additional_model_request_fields: dict[str, Any] | None = Field( - default=None, description="Model-specific request parameters" + additional_model_request_fields: dict[str, Any] = Field( + default_factory=dict, description="Model-specific request 
parameters" ) - additional_model_response_field_paths: list[str] | None = Field( - default=None, description="Custom response field paths" + additional_model_response_field_paths: list[str] = Field( + default_factory=list, description="Custom response field paths" ) - client: Any = Field( - default=None, exclude=True, description="Bedrock client instance" + _client: BedrockRuntimeClient = PrivateAttr( # type: ignore[assignment] + default_factory=lambda: Session().client, ) _is_claude_model: bool = PrivateAttr(default=False) @@ -209,10 +221,9 @@ class BedrockCompletion(BaseLLM): ) session = Session( - aws_access_key_id=self.aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"), - aws_secret_access_key=self.aws_secret_access_key - or os.getenv("AWS_SECRET_ACCESS_KEY"), - aws_session_token=self.aws_session_token or os.getenv("AWS_SESSION_TOKEN"), + aws_access_key_id=self.aws_access_key_id, + aws_secret_access_key=self.aws_secret_access_key, + aws_session_token=self.aws_session_token, region_name=self.region_name, ) @@ -225,7 +236,7 @@ class BedrockCompletion(BaseLLM): tcp_keepalive=True, ) - self.client = session.client("bedrock-runtime", config=config) + self._client = session.client("bedrock-runtime", config=config) self._is_claude_model = "claude" in self.model.lower() self._supports_tools = True @@ -365,7 +376,7 @@ class BedrockCompletion(BaseLLM): raise ValueError(f"Invalid message format at index {i}") # Call Bedrock Converse API with proper error handling - response = self.client.converse( + response = self._client.converse( modelId=self.model_id, messages=cast( "Sequence[MessageTypeDef | MessageOutputTypeDef]", @@ -536,13 +547,13 @@ class BedrockCompletion(BaseLLM): tool_use_id = None try: - response = self.client.converse_stream( + response = self._client.converse_stream( modelId=self.model_id, messages=cast( "Sequence[MessageTypeDef | MessageOutputTypeDef]", cast(object, messages), ), - **body, + **body, # type: ignore[arg-type] ) stream = response.get("stream") diff --git a/lib/crewai/src/crewai/llm/providers/gemini/completion.py b/lib/crewai/src/crewai/llm/providers/gemini/completion.py index 5adfa3863..38321d053 100644 --- a/lib/crewai/src/crewai/llm/providers/gemini/completion.py +++ b/lib/crewai/src/crewai/llm/providers/gemini/completion.py @@ -1,8 +1,11 @@ +from __future__ import annotations + import logging import os -from typing import TYPE_CHECKING, Any, ClassVar, cast +from typing import TYPE_CHECKING, Any, cast -from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator +from dotenv import load_dotenv +from pydantic import BaseModel, Field, PrivateAttr, model_validator from typing_extensions import Self from crewai.events.types.llm_events import LLMCallType @@ -31,6 +34,9 @@ except ImportError: ) from None +load_dotenv() + + class GeminiCompletion(BaseLLM): """Google Gemini native completion implementation. 
@@ -50,15 +56,12 @@ class GeminiCompletion(BaseLLM): interceptor: HTTP interceptor (not yet supported for Gemini) """ - model_config: ClassVar[ConfigDict] = ConfigDict( - ignored_types=(property,), arbitrary_types_allowed=True - ) - project: str | None = Field( - default=None, description="Google Cloud project ID (for Vertex AI)" + default_factory=lambda: os.getenv("GOOGLE_CLOUD_PROJECT"), + description="Google Cloud project ID (for Vertex AI)", ) location: str = Field( - default="us-central1", + default_factory=lambda: os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1"), description="Google Cloud location (for Vertex AI, defaults to 'us-central1')", ) top_p: float | None = Field(default=None, description="Nucleus sampling parameter") @@ -74,9 +77,7 @@ class GeminiCompletion(BaseLLM): default_factory=dict, description="Additional parameters for Google Gen AI Client constructor", ) - client: Any = Field( - default=None, exclude=True, description="Gemini client instance" - ) + _client: Any = PrivateAttr(default=None) _is_gemini_2: bool = PrivateAttr(default=False) _is_gemini_1_5: bool = PrivateAttr(default=False) @@ -94,17 +95,9 @@ class GeminiCompletion(BaseLLM): if self.api_key is None: self.api_key = os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY") - if self.project is None: - self.project = os.getenv("GOOGLE_CLOUD_PROJECT") - - if self.location == "us-central1": - env_location = os.getenv("GOOGLE_CLOUD_LOCATION") - if env_location: - self.location = env_location - use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true" - self.client = self._initialize_client(use_vertexai) + self._client = self._initialize_client(use_vertexai) self._is_gemini_2 = "gemini-2" in self.model.lower() self._is_gemini_1_5 = "gemini-1.5" in self.model.lower() @@ -176,9 +169,9 @@ class GeminiCompletion(BaseLLM): params = {} if ( - hasattr(self, "client") - and hasattr(self.client, "vertexai") - and self.client.vertexai + hasattr(self, "_client") + and hasattr(self._client, "vertexai") + and self._client.vertexai ): params.update( { @@ -400,7 +393,7 @@ class GeminiCompletion(BaseLLM): } try: - response = self.client.models.generate_content(**api_params) + response = self._client.models.generate_content(**api_params) usage = self._extract_token_usage(response) except Exception as e: @@ -468,7 +461,7 @@ class GeminiCompletion(BaseLLM): "config": config, } - for chunk in self.client.models.generate_content_stream(**api_params): + for chunk in self._client.models.generate_content_stream(**api_params): if hasattr(chunk, "text") and chunk.text: full_response += chunk.text self._emit_stream_chunk_event( diff --git a/lib/crewai/src/crewai/llm/providers/openai/completion.py b/lib/crewai/src/crewai/llm/providers/openai/completion.py index 8a0143da6..7d5dd1ec7 100644 --- a/lib/crewai/src/crewai/llm/providers/openai/completion.py +++ b/lib/crewai/src/crewai/llm/providers/openai/completion.py @@ -6,6 +6,7 @@ import logging import os from typing import TYPE_CHECKING, Any +from dotenv import load_dotenv import httpx from openai import APIConnectionError, NotFoundError, OpenAI from openai.types.chat import ChatCompletion, ChatCompletionChunk @@ -32,6 +33,9 @@ if TYPE_CHECKING: from crewai.tools.base_tool import BaseTool +load_dotenv() + + class OpenAICompletion(BaseLLM): """OpenAI native completion implementation. 
@@ -40,43 +44,51 @@ class OpenAICompletion(BaseLLM): """ # Client configuration fields - organization: str | None = Field(None, description="OpenAI organization ID") - project: str | None = Field(None, description="OpenAI project ID") - max_retries: int = Field(2, description="Maximum number of retries") - default_headers: dict[str, str] | None = Field( - None, description="Default headers for requests" + organization: str | None = Field(default=None, description="OpenAI organization ID") + project: str | None = Field(default=None, description="OpenAI project ID") + max_retries: int = Field(default=2, description="Maximum number of retries") + default_headers: dict[str, str] = Field( + default_factory=dict, description="Default headers for requests" ) - default_query: dict[str, Any] | None = Field( - None, description="Default query parameters" + default_query: dict[str, Any] = Field( + default_factory=dict, description="Default query parameters" ) - client_params: dict[str, Any] | None = Field( - None, description="Additional client parameters" + client_params: dict[str, Any] = Field( + default_factory=dict, description="Additional client parameters" + ) + timeout: float | None = Field(default=None, description="Request timeout") + api_base: str | None = Field( + default=None, description="API base URL", deprecated=True ) - timeout: float | None = Field(None, description="Request timeout") - api_base: str | None = Field(None, description="API base URL (deprecated)") # Completion parameters - top_p: float | None = Field(None, description="Top-p sampling parameter") - frequency_penalty: float | None = Field(None, description="Frequency penalty") - presence_penalty: float | None = Field(None, description="Presence penalty") - max_tokens: int | None = Field(None, description="Maximum tokens") + top_p: float | None = Field(default=None, description="Top-p sampling parameter") + frequency_penalty: float | None = Field( + default=None, description="Frequency penalty" + ) + presence_penalty: float | None = Field(default=None, description="Presence penalty") + max_tokens: int | None = Field(default=None, description="Maximum tokens") max_completion_tokens: int | None = Field( None, description="Maximum completion tokens" ) - seed: int | None = Field(None, description="Random seed") - stream: bool = Field(False, description="Enable streaming") + seed: int | None = Field(default=None, description="Random seed") + stream: bool = Field(default=False, description="Enable streaming") response_format: dict[str, Any] | type[BaseModel] | None = Field( - None, description="Response format" + default=None, description="Response format" ) - logprobs: bool | None = Field(None, description="Return log probabilities") + logprobs: bool | None = Field(default=None, description="Return log probabilities") top_logprobs: int | None = Field( - None, description="Number of top log probabilities" + default=None, description="Number of top log probabilities" + ) + reasoning_effort: str | None = Field( + default=None, description="Reasoning effort level" ) - reasoning_effort: str | None = Field(None, description="Reasoning effort level") - _client: OpenAI = PrivateAttr(default_factory=OpenAI) - is_o1_model: bool = Field(False, description="Whether this is an O1 model") - is_gpt4_model: bool = Field(False, description="Whether this is a GPT-4 model") + _client: OpenAI = PrivateAttr(default=None) # type: ignore[assignment] + is_o1_model: bool = Field(default=False, description="Whether this is an O1 model") + is_gpt4_model: 
bool = Field( + default=False, description="Whether this is a GPT-4 model" + ) @model_validator(mode="after") def setup_client(self) -> Self: @@ -97,10 +109,6 @@ class OpenAICompletion(BaseLLM): def _get_client_params(self) -> dict[str, Any]: """Get OpenAI client parameters.""" - - if self.api_key is None: - raise ValueError("OPENAI_API_KEY is required") - base_params = { "api_key": self.api_key, "organization": self.organization, diff --git a/lib/crewai/src/crewai/llms/__init__.py b/lib/crewai/src/crewai/llms/__init__.py index fda1e6a3b..9ffac98a5 100644 --- a/lib/crewai/src/crewai/llms/__init__.py +++ b/lib/crewai/src/crewai/llms/__init__.py @@ -1 +1,38 @@ -"""LLM implementations for crewAI.""" +"""LLM implementations for crewAI. + +.. deprecated:: 1.4.0 + The `crewai.llms` package is deprecated. Use `crewai.llm` instead. + + This package was reorganized from `crewai.llms.*` to `crewai.llm.*`. + All submodules are redirected to their new locations in `crewai.llm.*`. + + Migration guide: + Old: from crewai.llms.base_llm import BaseLLM + New: from crewai.llm.base_llm import BaseLLM + + Old: from crewai.llms.hooks.base import BaseInterceptor + New: from crewai.llm.hooks.base import BaseInterceptor + + Old: from crewai.llms.constants import OPENAI_MODELS + New: from crewai.llm.constants import OPENAI_MODELS + + Or use top-level imports: + from crewai import LLM, BaseLLM +""" + +import warnings + +from crewai.llm import LLM +from crewai.llm.base_llm import BaseLLM + + +# Issue deprecation warning when this module is imported +warnings.warn( + "The 'crewai.llms' package is deprecated and will be removed in a future version. " + "Please use 'crewai.llm' (singular) instead. " + "All submodules have been reorganized from 'crewai.llms.*' to 'crewai.llm.*'.", + DeprecationWarning, + stacklevel=2, +) + +__all__ = ["LLM", "BaseLLM"] diff --git a/lib/crewai/src/crewai/llms/base_llm.py b/lib/crewai/src/crewai/llms/base_llm.py new file mode 100644 index 000000000..0eef37033 --- /dev/null +++ b/lib/crewai/src/crewai/llms/base_llm.py @@ -0,0 +1,15 @@ +"""Deprecated: Use crewai.llm.base_llm instead. + +.. deprecated:: 1.4.0 +""" + +import warnings + + +warnings.warn( + "crewai.llms.base_llm is deprecated. Use crewai.llm.base_llm instead.", + DeprecationWarning, + stacklevel=2, +) + +from crewai.llm.base_llm import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/llms/constants.py b/lib/crewai/src/crewai/llms/constants.py new file mode 100644 index 000000000..8dc310b0a --- /dev/null +++ b/lib/crewai/src/crewai/llms/constants.py @@ -0,0 +1,15 @@ +"""Deprecated: Use crewai.llm.constants instead. + +.. deprecated:: 1.4.0 +""" + +import warnings + + +warnings.warn( + "crewai.llms.constants is deprecated. Use crewai.llm.constants instead.", + DeprecationWarning, + stacklevel=2, +) + +from crewai.llm.constants import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/llms/hooks/__init__.py b/lib/crewai/src/crewai/llms/hooks/__init__.py new file mode 100644 index 000000000..c63684cd7 --- /dev/null +++ b/lib/crewai/src/crewai/llms/hooks/__init__.py @@ -0,0 +1,15 @@ +"""Deprecated: Use crewai.llm.hooks instead. + +.. deprecated:: 1.4.0 +""" + +import warnings + + +warnings.warn( + "crewai.llms.hooks is deprecated. 
Use crewai.llm.hooks instead.", + DeprecationWarning, + stacklevel=2, +) + +from crewai.llm.hooks import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/llms/hooks/base.py b/lib/crewai/src/crewai/llms/hooks/base.py new file mode 100644 index 000000000..7149e70f7 --- /dev/null +++ b/lib/crewai/src/crewai/llms/hooks/base.py @@ -0,0 +1,15 @@ +"""Deprecated: Use crewai.llm.hooks.base instead. + +.. deprecated:: 1.4.0 +""" + +import warnings + + +warnings.warn( + "crewai.llms.hooks.base is deprecated. Use crewai.llm.hooks.base instead.", + DeprecationWarning, + stacklevel=2, +) + +from crewai.llm.hooks.base import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/llms/hooks/transport.py b/lib/crewai/src/crewai/llms/hooks/transport.py new file mode 100644 index 000000000..8ec3bc65e --- /dev/null +++ b/lib/crewai/src/crewai/llms/hooks/transport.py @@ -0,0 +1,15 @@ +"""Deprecated: Use crewai.llm.hooks.transport instead. + +.. deprecated:: 1.4.0 +""" + +import warnings + + +warnings.warn( + "crewai.llms.hooks.transport is deprecated. Use crewai.llm.hooks.transport instead.", + DeprecationWarning, + stacklevel=2, +) + +from crewai.llm.hooks.transport import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/llms/internal/__init__.py b/lib/crewai/src/crewai/llms/internal/__init__.py new file mode 100644 index 000000000..cca43464b --- /dev/null +++ b/lib/crewai/src/crewai/llms/internal/__init__.py @@ -0,0 +1,15 @@ +"""Deprecated: Use crewai.llm.internal instead. + +.. deprecated:: 1.4.0 +""" + +import warnings + + +warnings.warn( + "crewai.llms.internal is deprecated. Use crewai.llm.internal instead.", + DeprecationWarning, + stacklevel=2, +) + +from crewai.llm.internal import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/llms/internal/constants.py b/lib/crewai/src/crewai/llms/internal/constants.py new file mode 100644 index 000000000..5fe8c439c --- /dev/null +++ b/lib/crewai/src/crewai/llms/internal/constants.py @@ -0,0 +1,15 @@ +"""Deprecated: Use crewai.llm.internal.constants instead. + +.. deprecated:: 1.4.0 +""" + +import warnings + + +warnings.warn( + "crewai.llms.internal.constants is deprecated. Use crewai.llm.internal.constants instead.", + DeprecationWarning, + stacklevel=2, +) + +from crewai.llm.internal.constants import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/llms/providers/__init__.py b/lib/crewai/src/crewai/llms/providers/__init__.py new file mode 100644 index 000000000..95bc6d448 --- /dev/null +++ b/lib/crewai/src/crewai/llms/providers/__init__.py @@ -0,0 +1,15 @@ +"""Deprecated: Use crewai.llm.providers instead. + +.. deprecated:: 1.4.0 +""" + +import warnings + + +warnings.warn( + "crewai.llms.providers is deprecated. 
Use crewai.llm.providers instead.", + DeprecationWarning, + stacklevel=2, +) + +from crewai.llm.providers import * # noqa: E402, F403 diff --git a/lib/crewai/tests/llms/anthropic/test_anthropic.py b/lib/crewai/tests/llms/anthropic/test_anthropic.py index 5a91b2e1e..2cdbe1d49 100644 --- a/lib/crewai/tests/llms/anthropic/test_anthropic.py +++ b/lib/crewai/tests/llms/anthropic/test_anthropic.py @@ -60,7 +60,7 @@ def test_anthropic_tool_use_conversation_flow(): available_functions = {"get_weather": mock_weather_tool} # Mock the Anthropic client responses - with patch.object(completion.client.messages, 'create') as mock_create: + with patch.object(completion._client.messages, 'create') as mock_create: # Mock initial response with tool use - need to properly mock ToolUseBlock mock_tool_use = Mock(spec=ToolUseBlock) mock_tool_use.id = "tool_123" @@ -199,8 +199,8 @@ def test_anthropic_specific_parameters(): assert isinstance(llm, AnthropicCompletion) assert llm.stop == ["Human:", "Assistant:"] assert llm.stream == True - assert llm.client.max_retries == 5 - assert llm.client.timeout == 60 + assert llm._client.max_retries == 5 + assert llm._client.timeout == 60 def test_anthropic_completion_call(): @@ -637,8 +637,8 @@ def test_anthropic_environment_variable_api_key(): with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-anthropic-key"}): llm = LLM(model="anthropic/claude-3-5-sonnet-20241022") - assert llm.client is not None - assert hasattr(llm.client, 'messages') + assert llm._client is not None + assert hasattr(llm._client, 'messages') def test_anthropic_token_usage_tracking(): @@ -648,7 +648,7 @@ def test_anthropic_token_usage_tracking(): llm = LLM(model="anthropic/claude-3-5-sonnet-20241022") # Mock the Anthropic response with usage information - with patch.object(llm.client.messages, 'create') as mock_create: + with patch.object(llm._client.messages, 'create') as mock_create: mock_response = MagicMock() mock_response.content = [MagicMock(text="test response")] mock_response.usage = MagicMock(input_tokens=50, output_tokens=25) diff --git a/lib/crewai/tests/llms/azure/test_azure.py b/lib/crewai/tests/llms/azure/test_azure.py index 6cc4eb463..055c9d499 100644 --- a/lib/crewai/tests/llms/azure/test_azure.py +++ b/lib/crewai/tests/llms/azure/test_azure.py @@ -64,7 +64,7 @@ def test_azure_tool_use_conversation_flow(): available_functions = {"get_weather": mock_weather_tool} # Mock the Azure client responses - with patch.object(completion.client, 'complete') as mock_complete: + with patch.object(completion._client, 'complete') as mock_complete: # Mock tool call in response with proper type mock_tool_call = MagicMock(spec=ChatCompletionsToolCall) mock_tool_call.function.name = "get_weather" @@ -602,7 +602,7 @@ def test_azure_environment_variable_endpoint(): }): llm = LLM(model="azure/gpt-4") - assert llm.client is not None + assert llm._client is not None assert llm.endpoint == "https://test.openai.azure.com/openai/deployments/gpt-4" @@ -613,7 +613,7 @@ def test_azure_token_usage_tracking(): llm = LLM(model="azure/gpt-4") # Mock the Azure response with usage information - with patch.object(llm.client, 'complete') as mock_complete: + with patch.object(llm._client, 'complete') as mock_complete: mock_message = MagicMock() mock_message.content = "test response" mock_message.tool_calls = None @@ -651,7 +651,7 @@ def test_azure_http_error_handling(): llm = LLM(model="azure/gpt-4") # Mock an HTTP error - with patch.object(llm.client, 'complete') as mock_complete: + with patch.object(llm._client, 
'complete') as mock_complete: mock_complete.side_effect = HttpResponseError(message="Rate limit exceeded", response=MagicMock(status_code=429)) with pytest.raises(HttpResponseError): @@ -668,7 +668,7 @@ def test_azure_streaming_completion(): llm = LLM(model="azure/gpt-4", stream=True) # Mock streaming response - with patch.object(llm.client, 'complete') as mock_complete: + with patch.object(llm._client, 'complete') as mock_complete: # Create mock streaming updates with proper type mock_updates = [] for chunk in ["Hello", " ", "world", "!"]: @@ -891,7 +891,7 @@ def test_azure_improved_error_messages(): llm = LLM(model="azure/gpt-4") - with patch.object(llm.client, 'complete') as mock_complete: + with patch.object(llm._client, 'complete') as mock_complete: error_401 = HttpResponseError(message="Unauthorized") error_401.status_code = 401 mock_complete.side_effect = error_401 diff --git a/lib/crewai/tests/llms/bedrock/test_bedrock.py b/lib/crewai/tests/llms/bedrock/test_bedrock.py index b3c12cdc2..130e49890 100644 --- a/lib/crewai/tests/llms/bedrock/test_bedrock.py +++ b/lib/crewai/tests/llms/bedrock/test_bedrock.py @@ -579,7 +579,7 @@ def test_bedrock_token_usage_tracking(): llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0") # Mock the Bedrock response with usage information - with patch.object(llm.client, 'converse') as mock_converse: + with patch.object(llm._client, 'converse') as mock_converse: mock_response = { 'output': { 'message': { @@ -624,7 +624,7 @@ def test_bedrock_tool_use_conversation_flow(): available_functions = {"get_weather": mock_weather_tool} # Mock the Bedrock client responses - with patch.object(llm.client, 'converse') as mock_converse: + with patch.object(llm._client, 'converse') as mock_converse: # First response: tool use request tool_use_response = { 'output': { @@ -710,7 +710,7 @@ def test_bedrock_client_error_handling(): llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0") # Test ValidationException - with patch.object(llm.client, 'converse') as mock_converse: + with patch.object(llm._client, 'converse') as mock_converse: error_response = { 'Error': { 'Code': 'ValidationException', @@ -724,7 +724,7 @@ def test_bedrock_client_error_handling(): assert "validation" in str(exc_info.value).lower() # Test ThrottlingException - with patch.object(llm.client, 'converse') as mock_converse: + with patch.object(llm._client, 'converse') as mock_converse: error_response = { 'Error': { 'Code': 'ThrottlingException', @@ -762,7 +762,7 @@ def test_bedrock_stop_sequences_sent_to_api(): llm.stop = ["\nObservation:", "\nThought:"] # Patch the API call to capture parameters without making real call - with patch.object(llm.client, 'converse') as mock_converse: + with patch.object(llm._client, 'converse') as mock_converse: mock_response = { 'output': { 'message': { diff --git a/lib/crewai/tests/llms/google/test_google.py b/lib/crewai/tests/llms/google/test_google.py index 11af3e83b..ffcc2a978 100644 --- a/lib/crewai/tests/llms/google/test_google.py +++ b/lib/crewai/tests/llms/google/test_google.py @@ -59,7 +59,7 @@ def test_gemini_tool_use_conversation_flow(): available_functions = {"get_weather": mock_weather_tool} # Mock the Google Gemini client responses - with patch.object(completion.client.models, 'generate_content') as mock_generate: + with patch.object(completion._client.models, 'generate_content') as mock_generate: # Mock function call in response mock_function_call = Mock() mock_function_call.name = "get_weather" @@ -614,8 +614,8 @@ def 
test_gemini_environment_variable_api_key(): with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-google-key"}): llm = LLM(model="google/gemini-2.0-flash-001") - assert llm.client is not None - assert hasattr(llm.client, 'models') + assert llm._client is not None + assert hasattr(llm._client, 'models') assert llm.api_key == "test-google-key" @@ -626,7 +626,7 @@ def test_gemini_token_usage_tracking(): llm = LLM(model="google/gemini-2.0-flash-001") # Mock the Gemini response with usage information - with patch.object(llm.client.models, 'generate_content') as mock_generate: + with patch.object(llm._client.models, 'generate_content') as mock_generate: mock_response = MagicMock() mock_response.text = "test response" mock_response.candidates = [] @@ -675,7 +675,7 @@ def test_gemini_stop_sequences_sent_to_api(): llm.stop = ["\nObservation:", "\nThought:"] # Patch the API call to capture parameters without making real call - with patch.object(llm.client.models, 'generate_content') as mock_generate: + with patch.object(llm._client.models, 'generate_content') as mock_generate: mock_response = MagicMock() mock_response.text = "Hello" mock_response.candidates = [] diff --git a/lib/crewai/tests/llms/openai/test_openai.py b/lib/crewai/tests/llms/openai/test_openai.py index aee167ab5..d115eb6eb 100644 --- a/lib/crewai/tests/llms/openai/test_openai.py +++ b/lib/crewai/tests/llms/openai/test_openai.py @@ -369,11 +369,11 @@ def test_openai_client_setup_with_extra_arguments(): assert llm.top_p == 0.5 # Check that client parameters are properly configured - assert llm.client.max_retries == 3 - assert llm.client.timeout == 30 + assert llm._client.max_retries == 3 + assert llm._client.timeout == 30 # Test that parameters are properly used in API calls - with patch.object(llm.client.chat.completions, 'create') as mock_create: + with patch.object(llm._client.chat.completions, 'create') as mock_create: mock_create.return_value = MagicMock( choices=[MagicMock(message=MagicMock(content="test response", tool_calls=None))], usage=MagicMock(prompt_tokens=10, completion_tokens=20, total_tokens=30) @@ -394,7 +394,7 @@ def test_extra_arguments_are_passed_to_openai_completion(): """ llm = LLM(model="gpt-4o", temperature=0.7, max_tokens=1000, top_p=0.5, max_retries=3) - with patch.object(llm.client.chat.completions, 'create') as mock_create: + with patch.object(llm._client.chat.completions, 'create') as mock_create: mock_create.return_value = MagicMock( choices=[MagicMock(message=MagicMock(content="test response", tool_calls=None))], usage=MagicMock(prompt_tokens=10, completion_tokens=20, total_tokens=30) @@ -501,7 +501,7 @@ def test_openai_streaming_with_response_model(): llm = LLM(model="openai/gpt-4o", stream=True) - with patch.object(llm.client.chat.completions, "create") as mock_create: + with patch.object(llm._client.chat.completions, "create") as mock_create: mock_chunk1 = MagicMock() mock_chunk1.choices = [ MagicMock(delta=MagicMock(content='{"answer": "test", ', tool_calls=None))
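Illustrative usage sketch (not part of the patch; a minimal example based on names introduced in this diff, assuming custom subclasses are instantiated directly rather than routed to a native provider and that `call` is their only abstract member):

    import warnings

    # New canonical locations introduced by this PR.
    from crewai.llm.base_llm import BaseLLM
    from crewai.llm.internal.constants import PROVIDER_MAPPING

    # Prefix-based provider routing is now a shared constant.
    assert PROVIDER_MAPPING["claude"] == "anthropic"

    # The legacy package still resolves but emits a DeprecationWarning
    # (on first import of the shim module).
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        import crewai.llms  # noqa: F401
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)

    # With model_config removed from BaseLLM, LLMMeta.__new__ now injects
    # extra="allow" / populate_by_name=True into custom subclasses, so unknown
    # kwargs are kept as extra fields and exposed via additional_params.
    class EchoLLM(BaseLLM):
        def call(self, messages, **kwargs):  # hypothetical stub for the sketch
            return "ok"

    llm = EchoLLM(model="echo-1", custom_flag=True)  # "echo-1" is a made-up model name
    assert llm.additional_params == {"custom_flag": True}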