From 7404d8f1983854efd4cab6bd53b42e03c29fb7a6 Mon Sep 17 00:00:00 2001 From: Greyson Lalonde Date: Thu, 6 Nov 2025 18:40:28 -0500 Subject: [PATCH] feat: restructure llms to pydantic --- lib/crewai/src/crewai/agent/core.py | 32 +- .../crewai/agents/agent_builder/base_agent.py | 2 +- .../src/crewai/agents/crew_agent_executor.py | 17 +- lib/crewai/src/crewai/llm.py | 305 +++++++--------- lib/crewai/src/crewai/llms/base_llm.py | 169 +++++---- lib/crewai/src/crewai/llms/hooks/transport.py | 2 +- .../llms/providers/anthropic/completion.py | 223 +++++------- .../llms/providers/gemini/completion.py | 320 +++++++++------- .../llms/providers/openai/completion.py | 342 +++++++++--------- 9 files changed, 719 insertions(+), 693 deletions(-) diff --git a/lib/crewai/src/crewai/agent/core.py b/lib/crewai/src/crewai/agent/core.py index 3e925cef6..02fad524f 100644 --- a/lib/crewai/src/crewai/agent/core.py +++ b/lib/crewai/src/crewai/agent/core.py @@ -604,22 +604,22 @@ class Agent(BaseAgent): response_template=self.response_template, ).task_execution() - stop_words = [self.i18n.slice("observation")] + stop_sequences = [self.i18n.slice("observation")] if self.response_template: - stop_words.append( + stop_sequences.append( self.response_template.split("{{ .Response }}")[1].strip() ) self.agent_executor = CrewAgentExecutor( - llm=self.llm, + llm=self.llm, # type: ignore[arg-type] task=task, # type: ignore[arg-type] agent=self, crew=self.crew, tools=parsed_tools, prompt=prompt, original_tools=raw_tools, - stop_words=stop_words, + stop_sequences=stop_sequences, max_iter=self.max_iter, tools_handler=self.tools_handler, tools_names=get_tool_names(parsed_tools), @@ -762,7 +762,9 @@ class Agent(BaseAgent): path = parsed.path.replace("/", "_").strip("_") return f"{domain}_{path}" if path else domain - def _get_mcp_tool_schemas(self, server_params: dict) -> dict[str, dict]: + def _get_mcp_tool_schemas( + self, server_params: dict[str, Any] + ) -> dict[str, dict[str, Any]] | Any: """Get tool schemas from MCP server for wrapper creation with caching.""" server_url = server_params["url"] @@ -794,7 +796,7 @@ class Agent(BaseAgent): async def _get_mcp_tool_schemas_async( self, server_params: dict[str, Any] - ) -> dict[str, dict]: + ) -> dict[str, dict[str, Any]]: """Async implementation of MCP tool schema retrieval with timeouts and retries.""" server_url = server_params["url"] return await self._retry_mcp_discovery( @@ -802,7 +804,7 @@ class Agent(BaseAgent): ) async def _retry_mcp_discovery( - self, operation_func, server_url: str + self, operation_func: Any, server_url: str ) -> dict[str, dict[str, Any]]: """Retry MCP discovery operation with exponential backoff, avoiding try-except in loop.""" last_error = None @@ -833,7 +835,7 @@ class Agent(BaseAgent): @staticmethod async def _attempt_mcp_discovery( - operation_func, server_url: str + operation_func: Any, server_url: str ) -> tuple[dict[str, dict[str, Any]] | None, str, bool]: """Attempt single MCP discovery operation and return (result, error_message, should_retry).""" try: @@ -937,13 +939,13 @@ class Agent(BaseAgent): Field(..., description=field_description), ) else: - field_definitions[field_name] = ( + field_definitions[field_name] = ( # type: ignore[assignment] field_type | None, Field(default=None, description=field_description), ) model_name = f"{tool_name.replace('-', '_').replace(' ', '_')}Schema" - return create_model(model_name, **field_definitions) + return create_model(model_name, **field_definitions) # type: ignore[no-any-return,call-overload] def 
_json_type_to_python(self, field_schema: dict[str, Any]) -> type: """Convert JSON Schema type to Python type. @@ -963,12 +965,12 @@ class Agent(BaseAgent): if "const" in option: types.append(str) else: - types.append(self._json_type_to_python(option)) + types.append(self._json_type_to_python(option)) # type: ignore[arg-type] unique_types = list(set(types)) if len(unique_types) > 1: result = unique_types[0] for t in unique_types[1:]: - result = result | t + result = result | t # type: ignore[assignment] return result return unique_types[0] @@ -981,10 +983,10 @@ class Agent(BaseAgent): "object": dict, } - return type_mapping.get(json_type, Any) + return type_mapping.get(json_type, Any) # type: ignore[arg-type] @staticmethod - def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict]: + def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict[str, Any]]: """Fetch MCP server configurations from CrewAI AMP API.""" # TODO: Implement AMP API call to "integrations/mcps" endpoint # Should return list of server configs with URLs @@ -1223,7 +1225,7 @@ class Agent(BaseAgent): goal=self.goal, backstory=self.backstory, llm=self.llm, - tools=self.tools or [], + tools=self.tools, max_iterations=self.max_iter, max_execution_time=self.max_execution_time, respect_context_window=self.respect_context_window, diff --git a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py index b26c24515..c6dfd9e38 100644 --- a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py +++ b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py @@ -136,7 +136,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): default=False, description="Enable agent to delegate and ask questions among each other.", ) - tools: list[BaseTool] | None = Field( + tools: list[BaseTool] = Field( default_factory=list, description="Tools at agents' disposal" ) max_iter: int = Field( diff --git a/lib/crewai/src/crewai/agents/crew_agent_executor.py b/lib/crewai/src/crewai/agents/crew_agent_executor.py index 8c1eb2c0e..76f10f9be 100644 --- a/lib/crewai/src/crewai/agents/crew_agent_executor.py +++ b/lib/crewai/src/crewai/agents/crew_agent_executor.py @@ -73,7 +73,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): max_iter: int, tools: list[CrewStructuredTool], tools_names: str, - stop_words: list[str], + stop_sequences: list[str], tools_description: str, tools_handler: ToolsHandler, step_callback: Any = None, @@ -95,7 +95,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): max_iter: Maximum iterations. tools: Available tools. tools_names: Tool names string. - stop_words: Stop word list. + stop_sequences: Stop sequences list for halting generation. tools_description: Tool descriptions. tools_handler: Tool handler instance. step_callback: Optional step callback. 
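Note on the rename: callers that previously passed `stop_words` to `CrewAgentExecutor` now pass `stop_sequences`, and the `BaseLLM` changes later in this patch keep accepting `stop` as an alias for the same field. A minimal standalone sketch of that aliasing and normalization behavior, assuming Pydantic v2 (the `StopConfig` class below is illustrative, not the actual crewai field definition):

from pydantic import AliasChoices, BaseModel, Field, field_validator


class StopConfig(BaseModel):
    # Illustrative stand-in for the new BaseLLM field: "stop" and "stop_sequences"
    # both populate the same list, mirroring the AliasChoices setup in this patch.
    stop_sequences: list[str] = Field(
        default_factory=list,
        validation_alias=AliasChoices("stop_sequences", "stop"),
    )

    @field_validator("stop_sequences", mode="before")
    @classmethod
    def _normalize(cls, v: str | list[str] | None) -> list[str]:
        # None becomes an empty list, a bare string becomes a one-element list.
        if v is None:
            return []
        return [v] if isinstance(v, str) else list(v)


print(StopConfig(stop="\nObservation:").stop_sequences)             # ['\nObservation:']
print(StopConfig(stop_sequences=["Final Answer:"]).stop_sequences)  # ['Final Answer:']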
@@ -114,7 +114,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): self.prompt = prompt self.tools = tools self.tools_names = tools_names - self.stop = stop_words self.max_iter = max_iter self.callbacks = callbacks or [] self._printer: Printer = Printer() @@ -131,15 +130,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): self.iterations = 0 self.log_error_after = 3 if self.llm: - # This may be mutating the shared llm object and needs further evaluation - existing_stop = getattr(self.llm, "stop", []) - self.llm.stop = list( - set( - existing_stop + self.stop - if isinstance(existing_stop, list) - else self.stop - ) - ) + self.llm.stop_sequences.extend(stop_sequences) @property def use_stop_words(self) -> bool: @@ -148,7 +139,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): Returns: bool: True if tool should be used or not. """ - return self.llm.supports_stop_words() if self.llm else False + return self.llm.supports_stop_words if self.llm else False def invoke(self, inputs: dict[str, Any]) -> dict[str, Any]: """Execute the agent with given inputs. diff --git a/lib/crewai/src/crewai/llm.py b/lib/crewai/src/crewai/llm.py index 2e2684ebe..7016ec2e0 100644 --- a/lib/crewai/src/crewai/llm.py +++ b/lib/crewai/src/crewai/llm.py @@ -20,8 +20,7 @@ from typing import ( ) from dotenv import load_dotenv -import httpx -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from typing_extensions import Self from crewai.events.event_bus import crewai_event_bus @@ -54,7 +53,6 @@ if TYPE_CHECKING: from litellm.utils import supports_response_schema from crewai.agent.core import Agent - from crewai.llms.hooks.base import BaseInterceptor from crewai.task import Task from crewai.tools.base_tool import BaseTool from crewai.utilities.types import LLMMessage @@ -320,7 +318,138 @@ class AccumulatedToolArgs(BaseModel): class LLM(BaseLLM): - completion_cost: float | None = None + completion_cost: float | None = Field( + default=None, description="The completion cost of the LLM." + ) + top_p: float | None = Field( + default=None, description="Sampling probability threshold." + ) + n: int | None = Field( + default=None, description="Number of completions to generate." + ) + max_completion_tokens: int | None = Field( + default=None, + description="Maximum number of tokens to generate in the completion.", + ) + max_tokens: int | None = Field( + default=None, + description="Maximum number of tokens allowed in the prompt + completion.", + ) + presence_penalty: float | None = Field( + default=None, description="Penalty on the presence penalty." + ) + frequency_penalty: float | None = Field( + default=None, description="Penalty on the frequency penalty." 
+ ) + logit_bias: dict[int, float] | None = Field( + default=None, + description="Modifies the likelihood of specified tokens appearing in the completion.", + ) + response_format: type[BaseModel] | None = Field( + default=None, + description="Pydantic model class for structured response parsing.", + ) + seed: int | None = Field( + default=None, + description="Random seed for reproducibility.", + ) + logprobs: int | None = Field( + default=None, + description="Number of top logprobs to return.", + ) + top_logprobs: int | None = Field( + default=None, + description="Number of top logprobs to return.", + ) + api_base: str | None = Field( + default=None, + description="Base URL for the API endpoint.", + ) + api_version: str | None = Field( + default=None, + description="API version to use.", + ) + callbacks: list[Any] = Field( + default_factory=list, + description="List of callback handlers for LLM events.", + ) + reasoning_effort: Literal["none", "low", "medium", "high"] | None = Field( + default=None, + description="Level of reasoning effort for the LLM.", + ) + context_window_size: int = Field( + default=0, + description="The context window size of the LLM.", + ) + is_anthropic: bool = Field( + default=False, + description="Indicates if the model is from Anthropic provider.", + ) + supports_function_calling: bool = Field( + default=False, + description="Indicates if the model supports function calling.", + ) + supports_stop_words: bool = Field( + default=False, + description="Indicates if the model supports stop words.", + ) + + @model_validator(mode="after") + def initialize_client(self) -> Self: + self.is_anthropic = any( + prefix in self.model.lower() for prefix in ANTHROPIC_PREFIXES + ) + try: + provider = self._get_custom_llm_provider() + self.supports_function_calling = litellm.utils.supports_function_calling( + self.model, custom_llm_provider=provider + ) + except Exception as e: + logging.error(f"Failed to check function calling support: {e!s}") + self.supports_function_calling = False + try: + params = get_supported_openai_params(model=self.model) + self.supports_stop_words = params is not None and "stop" in params + except Exception as e: + logging.error(f"Failed to get supported params: {e!s}") + self.supports_stop_words = False + + with suppress_warnings(): + callback_types = [type(callback) for callback in self.callbacks] + for callback in litellm.success_callback[:]: + if type(callback) in callback_types: + litellm.success_callback.remove(callback) + + for callback in litellm._async_success_callback[:]: + if type(callback) in callback_types: + litellm._async_success_callback.remove(callback) + + litellm.callbacks = self.callbacks + + with suppress_warnings(): + success_callbacks_str = os.environ.get("LITELLM_SUCCESS_CALLBACKS", "") + success_callbacks: list[str | Callable[..., Any] | CustomLogger] = [] + if success_callbacks_str: + success_callbacks = [ + cb.strip() for cb in success_callbacks_str.split(",") if cb.strip() + ] + + failure_callbacks_str = os.environ.get("LITELLM_FAILURE_CALLBACKS", "") + if failure_callbacks_str: + failure_callbacks: list[str | Callable[..., Any] | CustomLogger] = [ + cb.strip() for cb in failure_callbacks_str.split(",") if cb.strip() + ] + + litellm.success_callback = success_callbacks + litellm.failure_callback = failure_callbacks + return self + + # @computed_field + # @property + # def is_anthropic(self) -> bool: + # """Determine if the model is from Anthropic provider.""" + # anthropic_prefixes = ("anthropic/", "claude-", "claude/") + # 
return any(prefix in self.model.lower() for prefix in anthropic_prefixes) def __new__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> LLM: """Factory method that routes to native SDK or falls back to LiteLLM.""" @@ -383,98 +512,6 @@ class LLM(BaseLLM): return None - def __init__( - self, - model: str, - timeout: float | int | None = None, - temperature: float | None = None, - top_p: float | None = None, - n: int | None = None, - stop: str | list[str] | None = None, - max_completion_tokens: int | None = None, - max_tokens: int | float | None = None, - presence_penalty: float | None = None, - frequency_penalty: float | None = None, - logit_bias: dict[int, float] | None = None, - response_format: type[BaseModel] | None = None, - seed: int | None = None, - logprobs: int | None = None, - top_logprobs: int | None = None, - base_url: str | None = None, - api_base: str | None = None, - api_version: str | None = None, - api_key: str | None = None, - callbacks: list[Any] | None = None, - reasoning_effort: Literal["none", "low", "medium", "high"] | None = None, - stream: bool = False, - interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None, - **kwargs: Any, - ) -> None: - """Initialize LLM instance. - - Note: This __init__ method is only called for fallback instances. - Native provider instances handle their own initialization in their respective classes. - """ - super().__init__( - model=model, - temperature=temperature, - api_key=api_key, - base_url=base_url, - timeout=timeout, - **kwargs, - ) - self.model = model - self.timeout = timeout - self.temperature = temperature - self.top_p = top_p - self.n = n - self.max_completion_tokens = max_completion_tokens - self.max_tokens = max_tokens - self.presence_penalty = presence_penalty - self.frequency_penalty = frequency_penalty - self.logit_bias = logit_bias - self.response_format = response_format - self.seed = seed - self.logprobs = logprobs - self.top_logprobs = top_logprobs - self.base_url = base_url - self.api_base = api_base - self.api_version = api_version - self.api_key = api_key - self.callbacks = callbacks - self.context_window_size = 0 - self.reasoning_effort = reasoning_effort - self.additional_params = kwargs - self.is_anthropic = self._is_anthropic_model(model) - self.stream = stream - self.interceptor = interceptor - - litellm.drop_params = True - - # Normalize self.stop to always be a list[str] - if stop is None: - self.stop: list[str] = [] - elif isinstance(stop, str): - self.stop = [stop] - else: - self.stop = stop - - self.set_callbacks(callbacks or []) - self.set_env_callbacks() - - @staticmethod - def _is_anthropic_model(model: str) -> bool: - """Determine if the model is from Anthropic provider. - - Args: - model: The model identifier string. - - Returns: - bool: True if the model is from Anthropic, False otherwise. - """ - anthropic_prefixes = ("anthropic/", "claude-", "claude/") - return any(prefix in model.lower() for prefix in anthropic_prefixes) - def _prepare_completion_params( self, messages: str | list[LLMMessage], @@ -1188,8 +1225,6 @@ class LLM(BaseLLM): message["role"] = msg_role # --- 5) Set up callbacks if provided with suppress_warnings(): - if callbacks and len(callbacks) > 0: - self.set_callbacks(callbacks) try: # --- 6) Prepare parameters for the completion call params = self._prepare_completion_params(messages, tools) @@ -1378,24 +1413,6 @@ class LLM(BaseLLM): "Please remove response_format or use a supported model." 
) - def supports_function_calling(self) -> bool: - try: - provider = self._get_custom_llm_provider() - return litellm.utils.supports_function_calling( - self.model, custom_llm_provider=provider - ) - except Exception as e: - logging.error(f"Failed to check function calling support: {e!s}") - return False - - def supports_stop_words(self) -> bool: - try: - params = get_supported_openai_params(model=self.model) - return params is not None and "stop" in params - except Exception as e: - logging.error(f"Failed to get supported params: {e!s}") - return False - def get_context_window_size(self) -> int: """ Returns the context window size, using 75% of the maximum to avoid @@ -1425,60 +1442,6 @@ class LLM(BaseLLM): self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO) return self.context_window_size - @staticmethod - def set_callbacks(callbacks: list[Any]) -> None: - """ - Attempt to keep a single set of callbacks in litellm by removing old - duplicates and adding new ones. - """ - with suppress_warnings(): - callback_types = [type(callback) for callback in callbacks] - for callback in litellm.success_callback[:]: - if type(callback) in callback_types: - litellm.success_callback.remove(callback) - - for callback in litellm._async_success_callback[:]: - if type(callback) in callback_types: - litellm._async_success_callback.remove(callback) - - litellm.callbacks = callbacks - - @staticmethod - def set_env_callbacks() -> None: - """Sets the success and failure callbacks for the LiteLLM library from environment variables. - - This method reads the `LITELLM_SUCCESS_CALLBACKS` and `LITELLM_FAILURE_CALLBACKS` - environment variables, which should contain comma-separated lists of callback names. - It then assigns these lists to `litellm.success_callback` and `litellm.failure_callback`, - respectively. - - If the environment variables are not set or are empty, the corresponding callback lists - will be set to empty lists. - - Examples: - LITELLM_SUCCESS_CALLBACKS="langfuse,langsmith" - LITELLM_FAILURE_CALLBACKS="langfuse" - - This will set `litellm.success_callback` to ["langfuse", "langsmith"] and - `litellm.failure_callback` to ["langfuse"]. 
- """ - with suppress_warnings(): - success_callbacks_str = os.environ.get("LITELLM_SUCCESS_CALLBACKS", "") - success_callbacks: list[str | Callable[..., Any] | CustomLogger] = [] - if success_callbacks_str: - success_callbacks = [ - cb.strip() for cb in success_callbacks_str.split(",") if cb.strip() - ] - - failure_callbacks_str = os.environ.get("LITELLM_FAILURE_CALLBACKS", "") - if failure_callbacks_str: - failure_callbacks: list[str | Callable[..., Any] | CustomLogger] = [ - cb.strip() for cb in failure_callbacks_str.split(",") if cb.strip() - ] - - litellm.success_callback = success_callbacks - litellm.failure_callback = failure_callbacks - def __copy__(self) -> LLM: """Create a shallow copy of the LLM instance.""" # Filter out parameters that are already explicitly passed to avoid conflicts @@ -1539,7 +1502,7 @@ class LLM(BaseLLM): **filtered_params, ) - def __deepcopy__(self, memo: dict[int, Any] | None) -> LLM: + def __deepcopy__(self, memo: dict[int, Any] | None) -> LLM: # type: ignore[override] """Create a deep copy of the LLM instance.""" import copy diff --git a/lib/crewai/src/crewai/llms/base_llm.py b/lib/crewai/src/crewai/llms/base_llm.py index a7026c5c5..116227fa9 100644 --- a/lib/crewai/src/crewai/llms/base_llm.py +++ b/lib/crewai/src/crewai/llms/base_llm.py @@ -13,8 +13,9 @@ import logging import re from typing import TYPE_CHECKING, Any, Final -from pydantic import BaseModel +from pydantic import AliasChoices, BaseModel, Field, PrivateAttr, field_validator +from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess from crewai.events.event_bus import crewai_event_bus from crewai.events.types.llm_events import ( LLMCallCompletedEvent, @@ -28,6 +29,7 @@ from crewai.events.types.tool_usage_events import ( ToolUsageFinishedEvent, ToolUsageStartedEvent, ) +from crewai.llms.hooks import BaseInterceptor from crewai.types.usage_metrics import UsageMetrics @@ -43,7 +45,7 @@ DEFAULT_SUPPORTS_STOP_WORDS: Final[bool] = True _JSON_EXTRACTION_PATTERN: Final[re.Pattern[str]] = re.compile(r"\{.*}", re.DOTALL) -class BaseLLM(ABC): +class BaseLLM(BaseModel, ABC): """Abstract base class for LLM implementations. This class defines the interface that all LLM implementations must follow. @@ -55,70 +57,105 @@ class BaseLLM(ABC): implement proper validation for input parameters and provide clear error messages when things go wrong. + Attributes: model: The model identifier/name. temperature: Optional temperature setting for response generation. - stop: A list of stop sequences that the LLM should use to stop generation. - additional_params: Additional provider-specific parameters. """ - is_litellm: bool = False + provider: str | re.Pattern[str] = Field( + default="openai", description="The provider of the LLM." + ) + model: str = Field(description="The model identifier/name.") + temperature: float | None = Field( + default=None, ge=0, le=2, description="Temperature for response generation." + ) + api_key: str | None = Field(default=None, description="API key for authentication.") + base_url: str | None = Field(default=None, description="Base URL for API calls.") + timeout: float | None = Field(default=None, description="Timeout for API calls.") + max_retries: int = Field( + default=2, description="Maximum number of API requests to make." + ) + max_tokens: int | None = Field( + default=None, description="Maximum tokens for response generation." 
+ ) + stream: bool | None = Field(default=False, description="Stream the API requests.") + client: Any = Field(description="Underlying LLM client instance.") + interceptor: BaseInterceptor[Any, Any] | None = Field( + default=None, + description="An optional HTTPX interceptor for modifying requests/responses.", + ) + client_params: dict[str, Any] = Field( + default_factory=dict, + description="Additional parameters for the underlying LLM client.", + ) + supports_stop_words: bool = Field( + default=DEFAULT_SUPPORTS_STOP_WORDS, + description="Whether or not to support stop words.", + ) + stop_sequences: list[str] = Field( + default_factory=list, + validation_alias=AliasChoices("stop_sequences", "stop"), + description="Stop sequences for generation (synchronized with stop).", + ) + is_litellm: bool = Field( + default=False, description="Is this LLM implementation in litellm?" + ) + additional_params: dict[str, Any] = Field( + default_factory=dict, + description="Additional parameters for LLM calls.", + ) + _token_usage: TokenProcess = PrivateAttr(default_factory=TokenProcess) - def __init__( - self, - model: str, - temperature: float | None = None, - api_key: str | None = None, - base_url: str | None = None, - provider: str | None = None, - **kwargs: Any, - ) -> None: - """Initialize the BaseLLM with default attributes. + @field_validator("provider", mode="before") + @classmethod + def extract_provider_from_model( + cls, v: str | re.Pattern[str] | None, info: Any + ) -> str | re.Pattern[str]: + """Extract provider from model string if not explicitly provided. Args: - model: The model identifier/name. - temperature: Optional temperature setting for response generation. - stop: Optional list of stop sequences for generation. - **kwargs: Additional provider-specific parameters. + v: Provided provider value (can be str, Pattern, or None) + info: Validation info containing other field values + + Returns: + Provider name (str) or Pattern """ - if not model: - raise ValueError("Model name is required and cannot be empty") + # If provider explicitly provided, validate and return it + if v is not None: + if not isinstance(v, (str, re.Pattern)): + raise ValueError(f"Provider must be str or Pattern, got {type(v)}") + return v - self.model = model - self.temperature = temperature - self.api_key = api_key - self.base_url = base_url - # Store additional parameters for provider-specific use - self.additional_params = kwargs - self._provider = provider or "openai" + model: str = info.data.get("model", "") + if "/" in model: + return model.partition("/")[0] + return "openai" - stop = kwargs.pop("stop", None) - if stop is None: - self.stop: list[str] = [] - elif isinstance(stop, str): - self.stop = [stop] - elif isinstance(stop, list): - self.stop = stop - else: - self.stop = [] + @field_validator("stop_sequences", mode="before") + @classmethod + def normalize_stop_sequences( + cls, v: str | list[str] | set[str] | None + ) -> list[str]: + """Validate and normalize stop sequences. - self._token_usage = { - "total_tokens": 0, - "prompt_tokens": 0, - "completion_tokens": 0, - "successful_requests": 0, - "cached_prompt_tokens": 0, - } + Converts string to list and handles None values. + AliasChoices handles accepting both 'stop' and 'stop_sequences' parameter names. 
+ """ + if v is None: + return [] + if isinstance(v, str): + return [v] + if isinstance(v, set): + return list(v) + if isinstance(v, list): + return v + return [] @property - def provider(self) -> str: - """Get the provider of the LLM.""" - return self._provider - - @provider.setter - def provider(self, value: str) -> None: - """Set the provider of the LLM.""" - self._provider = value + def stop(self) -> list[str]: + """Alias for stop_sequences to maintain backward compatibility.""" + return self.stop_sequences @abstractmethod def call( @@ -171,14 +208,6 @@ class BaseLLM(ABC): """ return tools - def supports_stop_words(self) -> bool: - """Check if the LLM supports stop words. - - Returns: - True if the LLM supports stop words, False otherwise. - """ - return DEFAULT_SUPPORTS_STOP_WORDS - def _supports_stop_words_implementation(self) -> bool: """Check if stop words are configured for this LLM instance. @@ -506,7 +535,7 @@ class BaseLLM(ABC): """ if "/" in model: return model.partition("/")[0] - return "openai" # Default provider + return "openai" def _track_token_usage_internal(self, usage_data: dict[str, Any]) -> None: """Track token usage internally in the LLM instance. @@ -535,11 +564,11 @@ class BaseLLM(ABC): or 0 ) - self._token_usage["prompt_tokens"] += prompt_tokens - self._token_usage["completion_tokens"] += completion_tokens - self._token_usage["total_tokens"] += prompt_tokens + completion_tokens - self._token_usage["successful_requests"] += 1 - self._token_usage["cached_prompt_tokens"] += cached_tokens + self._token_usage.prompt_tokens += prompt_tokens + self._token_usage.completion_tokens += completion_tokens + self._token_usage.total_tokens += prompt_tokens + completion_tokens + self._token_usage.successful_requests += 1 + self._token_usage.cached_prompt_tokens += cached_tokens def get_token_usage_summary(self) -> UsageMetrics: """Get summary of token usage for this LLM instance. @@ -547,4 +576,10 @@ class BaseLLM(ABC): Returns: Dictionary with token usage totals """ - return UsageMetrics(**self._token_usage) + return UsageMetrics( + prompt_tokens=self._token_usage.prompt_tokens, + completion_tokens=self._token_usage.completion_tokens, + total_tokens=self._token_usage.total_tokens, + successful_requests=self._token_usage.successful_requests, + cached_prompt_tokens=self._token_usage.cached_prompt_tokens, + ) diff --git a/lib/crewai/src/crewai/llms/hooks/transport.py b/lib/crewai/src/crewai/llms/hooks/transport.py index ee3f9224c..3595ae999 100644 --- a/lib/crewai/src/crewai/llms/hooks/transport.py +++ b/lib/crewai/src/crewai/llms/hooks/transport.py @@ -25,7 +25,7 @@ if TYPE_CHECKING: from crewai.llms.hooks.base import BaseInterceptor -class HTTPTransportKwargs(TypedDict): +class HTTPTransportKwargs(TypedDict, total=False): """Typed dictionary for httpx.HTTPTransport initialization parameters. 
These parameters configure the underlying HTTP transport behavior including diff --git a/lib/crewai/src/crewai/llms/providers/anthropic/completion.py b/lib/crewai/src/crewai/llms/providers/anthropic/completion.py index 50298eb77..d0435b5a0 100644 --- a/lib/crewai/src/crewai/llms/providers/anthropic/completion.py +++ b/lib/crewai/src/crewai/llms/providers/anthropic/completion.py @@ -5,11 +5,14 @@ import logging import os from typing import TYPE_CHECKING, Any, cast -from pydantic import BaseModel +from pydantic import BaseModel, Field, PrivateAttr, computed_field, model_validator +from typing_extensions import Self from crewai.events.types.llm_events import LLMCallType +from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO from crewai.llms.base_llm import BaseLLM from crewai.llms.hooks.transport import HTTPTransport +from crewai.llms.providers.utils.common import safe_tool_conversion from crewai.utilities.agent_utils import is_context_length_exceeded from crewai.utilities.exceptions.context_window_exceeding_exception import ( LLMContextLengthExceededError, @@ -18,7 +21,8 @@ from crewai.utilities.types import LLMMessage if TYPE_CHECKING: - from crewai.llms.hooks.base import BaseInterceptor + from crewai.agent import Agent + from crewai.task import Task try: from anthropic import Anthropic @@ -31,6 +35,19 @@ except ImportError: ) from None +ANTHROPIC_CONTEXT_WINDOWS: dict[str, int] = { + "claude-3-5-sonnet": 200000, + "claude-3-5-haiku": 200000, + "claude-3-opus": 200000, + "claude-3-sonnet": 200000, + "claude-3-haiku": 200000, + "claude-3-7-sonnet": 200000, + "claude-2.1": 200000, + "claude-2": 100000, + "claude-instant": 100000, +} + + class AnthropicCompletion(BaseLLM): """Anthropic native completion implementation. @@ -38,86 +55,69 @@ class AnthropicCompletion(BaseLLM): offering native tool use, streaming support, and proper message formatting. """ - def __init__( - self, - model: str = "claude-3-5-sonnet-20241022", - api_key: str | None = None, - base_url: str | None = None, - timeout: float | None = None, - max_retries: int = 2, - temperature: float | None = None, - max_tokens: int = 4096, # Required for Anthropic - top_p: float | None = None, - stop_sequences: list[str] | None = None, - stream: bool = False, - client_params: dict[str, Any] | None = None, - interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None, - **kwargs: Any, - ): - """Initialize Anthropic chat completion client. + model: str = Field( + default="claude-3-5-sonnet-20241022", + description="Anthropic model name (e.g., 'claude-3-5-sonnet-20241022')", + ) + max_tokens: int = Field( + default=4096, + description="Maximum number of allowed tokens in response.", + ) + top_p: float | None = Field( + default=None, + description="Nucleus sampling parameter.", + ) + _client: Anthropic = PrivateAttr( + default_factory=Anthropic, + ) - Args: - model: Anthropic model name (e.g., 'claude-3-5-sonnet-20241022') - api_key: Anthropic API key (defaults to ANTHROPIC_API_KEY env var) - base_url: Custom base URL for Anthropic API - timeout: Request timeout in seconds - max_retries: Maximum number of retries - temperature: Sampling temperature (0-1) - max_tokens: Maximum tokens in response (required for Anthropic) - top_p: Nucleus sampling parameter - stop_sequences: Stop sequences (Anthropic uses stop_sequences, not stop) - stream: Enable streaming responses - client_params: Additional parameters for the Anthropic client - interceptor: HTTP interceptor for modifying requests/responses at transport level. 
- **kwargs: Additional parameters + @model_validator(mode="after") + def initialize_client(self) -> Self: + """Initialize the Anthropic client after Pydantic validation. + + This runs after all field validation is complete, ensuring that: + - All BaseLLM fields are set (model, temperature, stop_sequences, etc.) + - Field validators have run (stop_sequences is normalized to set[str]) + - API key and other configuration is ready """ - super().__init__( - model=model, temperature=temperature, stop=stop_sequences or [], **kwargs - ) - - # Client params - self.interceptor = interceptor - self.client_params = client_params - self.base_url = base_url - self.timeout = timeout - self.max_retries = max_retries - - self.client = Anthropic(**self._get_client_params()) - - # Store completion parameters - self.max_tokens = max_tokens - self.top_p = top_p - self.stream = stream - self.stop_sequences = stop_sequences or [] - - # Model-specific settings - self.is_claude_3 = "claude-3" in model.lower() - self.supports_tools = self.is_claude_3 # Claude 3+ supports tool use - - def _get_client_params(self) -> dict[str, Any]: - """Get client parameters.""" - if self.api_key is None: self.api_key = os.getenv("ANTHROPIC_API_KEY") if self.api_key is None: raise ValueError("ANTHROPIC_API_KEY is required") - client_params = { - "api_key": self.api_key, - "base_url": self.base_url, - "timeout": self.timeout, - "max_retries": self.max_retries, - } + params = self.model_dump( + include={"api_key", "base_url", "timeout", "max_retries"}, + exclude_none=True, + ) if self.interceptor: transport = HTTPTransport(interceptor=self.interceptor) http_client = httpx.Client(transport=transport) - client_params["http_client"] = http_client # type: ignore[assignment] + params["http_client"] = http_client if self.client_params: - client_params.update(self.client_params) + params.update(self.client_params) - return client_params + self._client = Anthropic(**params) + return self + + @computed_field # type: ignore[prop-decorator] + @property + def is_claude_3(self) -> bool: + """Check if the model is Claude 3 or higher.""" + return "claude-3" in self.model.lower() + + @computed_field # type: ignore[prop-decorator] + @property + def supports_tools(self) -> bool: + """Check if the model supports tool use.""" + return self.is_claude_3 + + @computed_field # type: ignore[prop-decorator] + @property + def supports_function_calling(self) -> bool: + """Check if the model supports function calling.""" + return self.supports_tools def call( self, @@ -125,8 +125,8 @@ class AnthropicCompletion(BaseLLM): tools: list[dict[str, Any]] | None = None, callbacks: list[Any] | None = None, available_functions: dict[str, Any] | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, response_model: type[BaseModel] | None = None, ) -> str | Any: """Call Anthropic messages API. 
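The removed constructor logic above is replaced by a `model_validator(mode="after")` hook that builds the SDK client once every field has been validated. A minimal self-contained sketch of that pattern, assuming Pydantic v2 (`MiniProviderLLM` and its placeholder client are illustrative, not the crewai implementation):

from typing import Any

from pydantic import BaseModel, PrivateAttr, model_validator


class MiniProviderLLM(BaseModel):
    # Illustrative config-only model; the real class also normalizes stop
    # sequences, wires an optional httpx transport, and constructs Anthropic(...).
    model: str = "claude-3-5-sonnet-20241022"
    api_key: str | None = None
    max_retries: int = 2
    _client: Any = PrivateAttr(default=None)

    @model_validator(mode="after")
    def initialize_client(self) -> "MiniProviderLLM":
        # Runs after field validation, so every field above is already populated.
        params = self.model_dump(include={"api_key", "max_retries"}, exclude_none=True)
        self._client = ("client-would-be-created-with", params)
        return self


llm = MiniProviderLLM(api_key="sk-test")
print(llm._client)  # ('client-would-be-created-with', {'api_key': 'sk-test', 'max_retries': 2})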
@@ -205,25 +205,20 @@ class AnthropicCompletion(BaseLLM): Returns: Parameters dictionary for Anthropic API """ - params = { - "model": self.model, - "messages": messages, - "max_tokens": self.max_tokens, - "stream": self.stream, - } - + params = self.model_dump( + include={ + "model", + "max_tokens", + "stream", + "temperature", + "top_p", + "stop_sequences", + }, + exclude_none=True, + ) + params["messages"] = messages # Add system message if present if system_message: params["system"] = system_message - # Add optional parameters if set - if self.temperature is not None: - params["temperature"] = self.temperature - if self.top_p is not None: - params["top_p"] = self.top_p - if self.stop_sequences: - params["stop_sequences"] = self.stop_sequences - # Handle tools for Claude 3+ if tools and self.supports_tools: params["tools"] = self._convert_tools_for_interference(tools) @@ -242,8 +237,6 @@ class AnthropicCompletion(BaseLLM): continue try: - from crewai.llms.providers.utils.common import safe_tool_conversion - name, description, parameters = safe_tool_conversion(tool, "Anthropic") except (ImportError, KeyError, ValueError) as e: logging.error(f"Error converting tool to Anthropic format: {e}") @@ -317,8 +310,8 @@ class AnthropicCompletion(BaseLLM): self, params: dict[str, Any], available_functions: dict[str, Any] | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, response_model: type[BaseModel] | None = None, ) -> str | Any: """Handle non-streaming message completion.""" @@ -333,7 +326,7 @@ class AnthropicCompletion(BaseLLM): params["tool_choice"] = {"type": "tool", "name": "structured_output"} try: - response: Message = self.client.messages.create(**params) + response: Message = self._client.messages.create(**params) except Exception as e: if is_context_length_exceeded(e): @@ -405,8 +398,8 @@ class AnthropicCompletion(BaseLLM): self, params: dict[str, Any], available_functions: dict[str, Any] | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, response_model: type[BaseModel] | None = None, ) -> str: """Handle streaming message completion.""" @@ -427,7 +420,7 @@ class AnthropicCompletion(BaseLLM): stream_params = {k: v for k, v in params.items() if k != "stream"} # Make streaming API call - with self.client.messages.stream(**stream_params) as stream: + with self._client.messages.stream(**stream_params) as stream: for event in stream: if hasattr(event, "delta") and hasattr(event.delta, "text"): text_delta = event.delta.text @@ -501,8 +494,8 @@ class AnthropicCompletion(BaseLLM): tool_uses: list[ToolUseBlock], params: dict[str, Any], available_functions: dict[str, Any], - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, ) -> str: """Handle the complete tool use conversation flow. 
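For reference, the `model_dump(include=..., exclude_none=True)` call now used to assemble request parameters behaves as below, assuming Pydantic v2 (field names are illustrative): `include=` selects fields by name and `exclude_none=True` drops optionals left at `None`, so unset sampling parameters never reach the provider call.

from pydantic import BaseModel


class RequestParams(BaseModel):
    model: str = "claude-3-5-sonnet-20241022"
    max_tokens: int = 4096
    temperature: float | None = None
    top_p: float | None = None


p = RequestParams(temperature=0.2)
# top_p stays None and is therefore excluded from the dumped dict.
print(p.model_dump(include={"model", "max_tokens", "temperature", "top_p"}, exclude_none=True))
# {'model': 'claude-3-5-sonnet-20241022', 'max_tokens': 4096, 'temperature': 0.2}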
@@ -555,7 +548,7 @@ class AnthropicCompletion(BaseLLM): try: # Send tool results back to Claude for final response - final_response: Message = self.client.messages.create(**follow_up_params) + final_response: Message = self._client.messages.create(**follow_up_params) # Track token usage for follow-up call follow_up_usage = self._extract_anthropic_token_usage(final_response) @@ -602,48 +595,24 @@ class AnthropicCompletion(BaseLLM): return tool_results[0]["content"] raise e - def supports_function_calling(self) -> bool: - """Check if the model supports function calling.""" - return self.supports_tools - - def supports_stop_words(self) -> bool: - """Check if the model supports stop words.""" - return True # All Claude models support stop sequences - def get_context_window_size(self) -> int: """Get the context window size for the model.""" - from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO - - # Context window sizes for Anthropic models - context_windows = { - "claude-3-5-sonnet": 200000, - "claude-3-5-haiku": 200000, - "claude-3-opus": 200000, - "claude-3-sonnet": 200000, - "claude-3-haiku": 200000, - "claude-3-7-sonnet": 200000, - "claude-2.1": 200000, - "claude-2": 100000, - "claude-instant": 100000, - } - # Find the best match for the model name - for model_prefix, size in context_windows.items(): + for model_prefix, size in ANTHROPIC_CONTEXT_WINDOWS.items(): if self.model.startswith(model_prefix): return int(size * CONTEXT_WINDOW_USAGE_RATIO) # Default context window size for Claude models return int(200000 * CONTEXT_WINDOW_USAGE_RATIO) - def _extract_anthropic_token_usage(self, response: Message) -> dict[str, Any]: + @staticmethod + def _extract_anthropic_token_usage(response: Message) -> dict[str, Any]: """Extract token usage from Anthropic response.""" - if hasattr(response, "usage") and response.usage: + if response.usage: usage = response.usage - input_tokens = getattr(usage, "input_tokens", 0) - output_tokens = getattr(usage, "output_tokens", 0) return { - "input_tokens": input_tokens, - "output_tokens": output_tokens, - "total_tokens": input_tokens + output_tokens, + "input_tokens": usage.input_tokens, + "output_tokens": usage.output_tokens, + "total_tokens": usage.input_tokens + usage.output_tokens, } return {"total_tokens": 0} diff --git a/lib/crewai/src/crewai/llms/providers/gemini/completion.py b/lib/crewai/src/crewai/llms/providers/gemini/completion.py index 45b603c19..b0545018d 100644 --- a/lib/crewai/src/crewai/llms/providers/gemini/completion.py +++ b/lib/crewai/src/crewai/llms/providers/gemini/completion.py @@ -1,12 +1,14 @@ -import logging -import os -from typing import Any, cast +from __future__ import annotations -from pydantic import BaseModel +import logging +from typing import TYPE_CHECKING, Any, cast + +from pydantic import BaseModel, Field, PrivateAttr, computed_field, model_validator +from typing_extensions import Self from crewai.events.types.llm_events import LLMCallType +from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES from crewai.llms.base_llm import BaseLLM -from crewai.llms.hooks.base import BaseInterceptor from crewai.utilities.agent_utils import is_context_length_exceeded from crewai.utilities.exceptions.context_window_exceeding_exception import ( LLMContextLengthExceededError, @@ -14,6 +16,11 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import ( from crewai.utilities.types import LLMMessage +if TYPE_CHECKING: + from crewai.agent import Agent + from crewai.task import Task + + try: from google import 
genai # type: ignore[import-untyped] from google.genai import types # type: ignore[import-untyped] @@ -24,6 +31,27 @@ except ImportError: ) from None +GEMINI_CONTEXT_WINDOWS: dict[str, int] = { + "gemini-2.0-flash": 1048576, # 1M tokens + "gemini-2.0-flash-thinking": 32768, + "gemini-2.0-flash-lite": 1048576, + "gemini-2.5-flash": 1048576, + "gemini-2.5-pro": 1048576, + "gemini-1.5-pro": 2097152, # 2M tokens + "gemini-1.5-flash": 1048576, + "gemini-1.5-flash-8b": 1048576, + "gemini-1.0-pro": 32768, + "gemma-3-1b": 32000, + "gemma-3-4b": 128000, + "gemma-3-12b": 128000, + "gemma-3-27b": 128000, +} + +# Context window validation constraints +MIN_CONTEXT_WINDOW: int = 1024 +MAX_CONTEXT_WINDOW: int = 2097152 + + class GeminiCompletion(BaseLLM): """Google Gemini native completion implementation. @@ -31,78 +59,140 @@ class GeminiCompletion(BaseLLM): offering native function calling, streaming support, and proper Gemini formatting. """ - def __init__( - self, - model: str = "gemini-2.0-flash-001", - api_key: str | None = None, - project: str | None = None, - location: str | None = None, - temperature: float | None = None, - top_p: float | None = None, - top_k: int | None = None, - max_output_tokens: int | None = None, - stop_sequences: list[str] | None = None, - stream: bool = False, - safety_settings: dict[str, Any] | None = None, - client_params: dict[str, Any] | None = None, - interceptor: BaseInterceptor[Any, Any] | None = None, - **kwargs: Any, - ): - """Initialize Google Gemini chat completion client. + model: str = Field( + default="gemini-2.0-flash-001", + description="Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro')", + ) + project: str | None = Field( + default=None, + description="Google Cloud project ID (for Vertex AI)", + ) + location: str = Field( + default="us-central1", + description="Google Cloud location (for Vertex AI)", + ) + top_p: float | None = Field( + default=None, + description="Nucleus sampling parameter", + ) + top_k: int | None = Field( + default=None, + description="Top-k sampling parameter", + ) + max_output_tokens: int | None = Field( + default=None, + description="Maximum tokens in response", + ) + safety_settings: dict[str, Any] | None = Field( + default=None, + description="Safety filter settings", + ) + _client: genai.Client = PrivateAttr( # type: ignore[no-any-unimported] + default_factory=genai.Client, + ) - Args: - model: Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro') - api_key: Google API key (defaults to GOOGLE_API_KEY or GEMINI_API_KEY env var) - project: Google Cloud project ID (for Vertex AI) - location: Google Cloud location (for Vertex AI, defaults to 'us-central1') - temperature: Sampling temperature (0-2) - top_p: Nucleus sampling parameter - top_k: Top-k sampling parameter - max_output_tokens: Maximum tokens in response - stop_sequences: Stop sequences - stream: Enable streaming responses - safety_settings: Safety filter settings - client_params: Additional parameters to pass to the Google Gen AI Client constructor. - Supports parameters like http_options, credentials, debug_config, etc. - interceptor: HTTP interceptor (not yet supported for Gemini). - **kwargs: Additional parameters + @model_validator(mode="after") + def initialize_client(self) -> Self: + """Initialize the Anthropic client after Pydantic validation. + + This runs after all field validation is complete, ensuring that: + - All BaseLLM fields are set (model, temperature, stop_sequences, etc.) 
+ - Field validators have run (stop_sequences is normalized to set[str]) + - API key and other configuration is ready """ - if interceptor is not None: - raise NotImplementedError( - "HTTP interceptors are not yet supported for Google Gemini provider. " - "Interceptors are currently supported for OpenAI and Anthropic providers only." - ) + self._client = genai.Client(**self._get_client_params()) + return self - super().__init__( - model=model, temperature=temperature, stop=stop_sequences or [], **kwargs - ) + # def __init__( + # self, + # model: str = "gemini-2.0-flash-001", + # api_key: str | None = None, + # project: str | None = None, + # location: str | None = None, + # temperature: float | None = None, + # top_p: float | None = None, + # top_k: int | None = None, + # max_output_tokens: int | None = None, + # stop_sequences: list[str] | None = None, + # stream: bool = False, + # safety_settings: dict[str, Any] | None = None, + # client_params: dict[str, Any] | None = None, + # interceptor: BaseInterceptor[Any, Any] | None = None, + # **kwargs: Any, + # # ): + # """Initialize Google Gemini chat completion client. + # + # Args: + # model: Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro') + # api_key: Google API key (defaults to GOOGLE_API_KEY or GEMINI_API_KEY env var) + # project: Google Cloud project ID (for Vertex AI) + # location: Google Cloud location (for Vertex AI, defaults to 'us-central1') + # temperature: Sampling temperature (0-2) + # top_p: Nucleus sampling parameter + # top_k: Top-k sampling parameter + # max_output_tokens: Maximum tokens in response + # stop_sequences: Stop sequences + # stream: Enable streaming responses + # safety_settings: Safety filter settings + # client_params: Additional parameters to pass to the Google Gen AI Client constructor. + # Supports parameters like http_options, credentials, debug_config, etc. + # interceptor: HTTP interceptor (not yet supported for Gemini). + # **kwargs: Additional parameters + # """ + # if interceptor is not None: + # raise NotImplementedError( + # "HTTP interceptors are not yet supported for Google Gemini provider. " + # "Interceptors are currently supported for OpenAI and Anthropic providers only." 
+ # ) + # + # super().__init__( + # model=model, temperature=temperature, stop=stop_sequences or [], **kwargs + # ) + # + # # Store client params for later use + # self.client_params = client_params or {} + # + # # Get API configuration with environment variable fallbacks + # self.api_key = ( + # api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY") + # ) + # self.project = project or os.getenv("GOOGLE_CLOUD_PROJECT") + # self.location = location or os.getenv("GOOGLE_CLOUD_LOCATION") or "us-central1" + # + # use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true" + # + # self.client = self._initialize_client(use_vertexai) + # + # # Store completion parameters + # self.top_p = top_p + # self.top_k = top_k + # self.max_output_tokens = max_output_tokens + # self.stream = stream + # self.safety_settings = safety_settings or {} + # self.stop_sequences = stop_sequences or [] + # + # # Model-specific settings + # self.is_gemini_2 = "gemini-2" in model.lower() + # self.is_gemini_1_5 = "gemini-1.5" in model.lower() + # self.supports_tools = self.is_gemini_1_5 or self.is_gemini_2 - # Store client params for later use - self.client_params = client_params or {} + @computed_field # type: ignore[prop-decorator] + @property + def is_gemini_2(self) -> bool: + """Check if the model is Gemini 2.x.""" + return "gemini-2" in self.model.lower() - # Get API configuration with environment variable fallbacks - self.api_key = ( - api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY") - ) - self.project = project or os.getenv("GOOGLE_CLOUD_PROJECT") - self.location = location or os.getenv("GOOGLE_CLOUD_LOCATION") or "us-central1" + @computed_field # type: ignore[prop-decorator] + @property + def is_gemini_1_5(self) -> bool: + """Check if the model is Gemini 1.5.x.""" + return "gemini-1.5" in self.model.lower() - use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true" - - self.client = self._initialize_client(use_vertexai) - - # Store completion parameters - self.top_p = top_p - self.top_k = top_k - self.max_output_tokens = max_output_tokens - self.stream = stream - self.safety_settings = safety_settings or {} - self.stop_sequences = stop_sequences or [] - - # Model-specific settings - self.is_gemini_2 = "gemini-2" in model.lower() - self.is_gemini_1_5 = "gemini-1.5" in model.lower() - self.supports_tools = self.is_gemini_1_5 or self.is_gemini_2 + @computed_field # type: ignore[prop-decorator] + @property + def supports_tools(self) -> bool: + """Check if the model supports tool/function calling.""" + return self.is_gemini_1_5 or self.is_gemini_2 def _initialize_client(self, use_vertexai: bool = False) -> genai.Client: # type: ignore[no-any-unimported] """Initialize the Google Gen AI client with proper parameter handling. @@ -118,6 +208,12 @@ class GeminiCompletion(BaseLLM): if self.client_params: client_params.update(self.client_params) + if self.interceptor: + raise NotImplementedError( + "HTTP interceptors are not yet supported for Google Gemini provider. " + "Interceptors are currently supported for OpenAI and Anthropic providers only." 
+ ) + if use_vertexai or self.project: client_params.update( { @@ -157,7 +253,7 @@ class GeminiCompletion(BaseLLM): if ( hasattr(self, "client") - and hasattr(self.client, "vertexai") + and hasattr(self._client, "vertexai") and self.client.vertexai ): # Vertex AI configuration @@ -182,8 +278,8 @@ class GeminiCompletion(BaseLLM): tools: list[dict[str, Any]] | None = None, callbacks: list[Any] | None = None, available_functions: dict[str, Any] | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, response_model: type[BaseModel] | None = None, ) -> str | Any: """Call Google Gemini generate content API. @@ -270,7 +366,16 @@ class GeminiCompletion(BaseLLM): GenerateContentConfig object for Gemini API """ self.tools = tools - config_params = {} + config_params = self.model_dump( + include={ + "temperature", + "top_p", + "top_k", + "max_output_tokens", + "stop_sequences", + "safety_settings", + } + ) # Add system instruction if present if system_instruction: @@ -280,18 +385,6 @@ class GeminiCompletion(BaseLLM): ) config_params["system_instruction"] = system_content - # Add generation config parameters - if self.temperature is not None: - config_params["temperature"] = self.temperature - if self.top_p is not None: - config_params["top_p"] = self.top_p - if self.top_k is not None: - config_params["top_k"] = self.top_k - if self.max_output_tokens is not None: - config_params["max_output_tokens"] = self.max_output_tokens - if self.stop_sequences: - config_params["stop_sequences"] = self.stop_sequences - if response_model: config_params["response_mime_type"] = "application/json" config_params["response_schema"] = response_model.model_json_schema() @@ -300,9 +393,6 @@ class GeminiCompletion(BaseLLM): if tools and self.supports_tools: config_params["tools"] = self._convert_tools_for_interference(tools) - if self.safety_settings: - config_params["safety_settings"] = self.safety_settings - return types.GenerateContentConfig(**config_params) def _convert_tools_for_interference( # type: ignore[no-any-unimported] @@ -380,8 +470,8 @@ class GeminiCompletion(BaseLLM): system_instruction: str | None, config: types.GenerateContentConfig, available_functions: dict[str, Any] | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, response_model: type[BaseModel] | None = None, ) -> str | Any: """Handle non-streaming content generation.""" @@ -392,7 +482,7 @@ class GeminiCompletion(BaseLLM): } try: - response = self.client.models.generate_content(**api_params) + response = self._client.models.generate_content(**api_params) usage = self._extract_token_usage(response) except Exception as e: @@ -446,8 +536,8 @@ class GeminiCompletion(BaseLLM): contents: list[types.Content], config: types.GenerateContentConfig, available_functions: dict[str, Any] | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, response_model: type[BaseModel] | None = None, ) -> str: """Handle streaming content generation.""" @@ -460,7 +550,7 @@ class GeminiCompletion(BaseLLM): "config": config, } - for chunk in self.client.models.generate_content_stream(**api_params): + for chunk in self._client.models.generate_content_stream(**api_params): if hasattr(chunk, "text") and chunk.text: full_response += chunk.text self._emit_stream_chunk_event( @@ -513,52 +603,30 @@ class 
GeminiCompletion(BaseLLM):
         return full_response
 
+    @computed_field  # type: ignore[prop-decorator]
+    @property
     def supports_function_calling(self) -> bool:
         """Check if the model supports function calling."""
         return self.supports_tools
 
-    def supports_stop_words(self) -> bool:
-        """Check if the model supports stop words."""
-        return True
-
     def get_context_window_size(self) -> int:
         """Get the context window size for the model."""
-        from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
-
-        min_context = 1024
-        max_context = 2097152
-
         for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
-            if value < min_context or value > max_context:
+            if value < MIN_CONTEXT_WINDOW or value > MAX_CONTEXT_WINDOW:
                 raise ValueError(
-                    f"Context window for {key} must be between {min_context} and {max_context}"
+                    f"Context window for {key} must be between {MIN_CONTEXT_WINDOW} and {MAX_CONTEXT_WINDOW}"
                 )
-        context_windows = {
-            "gemini-2.0-flash": 1048576,  # 1M tokens
-            "gemini-2.0-flash-thinking": 32768,
-            "gemini-2.0-flash-lite": 1048576,
-            "gemini-2.5-flash": 1048576,
-            "gemini-2.5-pro": 1048576,
-            "gemini-1.5-pro": 2097152,  # 2M tokens
-            "gemini-1.5-flash": 1048576,
-            "gemini-1.5-flash-8b": 1048576,
-            "gemini-1.0-pro": 32768,
-            "gemma-3-1b": 32000,
-            "gemma-3-4b": 128000,
-            "gemma-3-12b": 128000,
-            "gemma-3-27b": 128000,
-        }
-        # Find the best match for the model name
-        for model_prefix, size in context_windows.items():
+        for model_prefix, size in GEMINI_CONTEXT_WINDOWS.items():
             if self.model.startswith(model_prefix):
                 return int(size * CONTEXT_WINDOW_USAGE_RATIO)
 
         # Default context window size for Gemini models
         return int(1048576 * CONTEXT_WINDOW_USAGE_RATIO)  # 1M tokens
 
-    def _extract_token_usage(self, response: dict[str, Any]) -> dict[str, Any]:
+    @staticmethod
+    def _extract_token_usage(response: dict[str, Any]) -> dict[str, Any]:
         """Extract token usage from Gemini response."""
         if hasattr(response, "usage_metadata"):
             usage = response.usage_metadata
@@ -570,8 +638,8 @@ class GeminiCompletion(BaseLLM):
             }
         return {"total_tokens": 0}
 
+    @staticmethod
     def _convert_contents_to_dict(  # type: ignore[no-any-unimported]
-        self,
         contents: list[types.Content],
     ) -> list[dict[str, str]]:
         """Convert contents to dict format."""
diff --git a/lib/crewai/src/crewai/llms/providers/openai/completion.py b/lib/crewai/src/crewai/llms/providers/openai/completion.py
index fdf7b03c7..b49e2fb5a 100644
--- a/lib/crewai/src/crewai/llms/providers/openai/completion.py
+++ b/lib/crewai/src/crewai/llms/providers/openai/completion.py
@@ -4,16 +4,23 @@ from collections.abc import Iterator
 import json
 import logging
 import os
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Final
 
 import httpx
 from openai import APIConnectionError, NotFoundError, OpenAI
 from openai.types.chat import ChatCompletion, ChatCompletionChunk
 from openai.types.chat.chat_completion import Choice
 from openai.types.chat.chat_completion_chunk import ChoiceDelta
-from pydantic import BaseModel
+from pydantic import (
+    BaseModel,
+    Field,
+    PrivateAttr,
+    model_validator,
+)
+from typing_extensions import Self
 
 from crewai.events.types.llm_events import LLMCallType
+from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
 from crewai.llms.base_llm import BaseLLM
 from crewai.llms.hooks.transport import HTTPTransport
 from crewai.utilities.agent_utils import is_context_length_exceeded
@@ -25,11 +32,28 @@
 from crewai.utilities.types import LLMMessage
 
 if TYPE_CHECKING:
     from crewai.agent.core import Agent
-    from crewai.llms.hooks.base import BaseInterceptor
     from crewai.task import Task
     from crewai.tools.base_tool import BaseTool
 
 
+OPENAI_CONTEXT_WINDOWS: dict[str, int] = {
+    "gpt-4": 8192,
+    "gpt-4o": 128000,
+    "gpt-4o-mini": 200000,
+    "gpt-4-turbo": 128000,
+    "gpt-4.1": 1047576,
+    "gpt-4.1-mini-2025-04-14": 1047576,
+    "gpt-4.1-nano-2025-04-14": 1047576,
+    "o1-preview": 128000,
+    "o1-mini": 128000,
+    "o3-mini": 200000,
+    "o4-mini": 200000,
+}
+
+MIN_CONTEXT_WINDOW: Final[int] = 1024
+MAX_CONTEXT_WINDOW: Final[int] = 2097152
+
+
 class OpenAICompletion(BaseLLM):
     """OpenAI native completion implementation.
 
@@ -37,112 +61,125 @@
     offering native structured outputs, function calling, and streaming support.
     """
 
-    def __init__(
-        self,
-        model: str = "gpt-4o",
-        api_key: str | None = None,
-        base_url: str | None = None,
-        organization: str | None = None,
-        project: str | None = None,
-        timeout: float | None = None,
-        max_retries: int = 2,
-        default_headers: dict[str, str] | None = None,
-        default_query: dict[str, Any] | None = None,
-        client_params: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        frequency_penalty: float | None = None,
-        presence_penalty: float | None = None,
-        max_tokens: int | None = None,
-        max_completion_tokens: int | None = None,
-        seed: int | None = None,
-        stream: bool = False,
-        response_format: dict[str, Any] | type[BaseModel] | None = None,
-        logprobs: bool | None = None,
-        top_logprobs: int | None = None,
-        reasoning_effort: str | None = None,
-        provider: str | None = None,
-        interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
-        **kwargs: Any,
-    ) -> None:
-        """Initialize OpenAI chat completion client."""
+    model: str = Field(
+        default="gpt-4o",
+        description="OpenAI model name (e.g., 'gpt-4o')",
+    )
+    organization: str | None = Field(
+        default=None,
+        description="Name of the OpenAI organization",
+    )
+    project: str | None = Field(
+        default=None,
+        description="Name of the OpenAI project",
+    )
+    api_base: str | None = Field(
+        default=os.getenv("OPENAI_BASE_URL"),
+        description="Base URL for OpenAI API",
+    )
+    default_headers: dict[str, str] | None = Field(
+        default=None,
+        description="Default headers for OpenAI API requests",
+    )
+    default_query: dict[str, Any] | None = Field(
+        default=None,
+        description="Default query parameters for OpenAI API requests",
+    )
+    top_p: float | None = Field(
+        default=None,
+        description="Top-p sampling parameter",
+    )
+    frequency_penalty: float | None = Field(
+        default=None,
+        description="Frequency penalty parameter",
+    )
+    presence_penalty: float | None = Field(
+        default=None,
+        description="Presence penalty parameter",
+    )
+    max_completion_tokens: int | None = Field(
+        default=None,
+        description="Maximum tokens for completion",
+    )
+    seed: int | None = Field(
+        default=None,
+        description="Random seed for reproducibility",
+    )
+    response_format: dict[str, Any] | type[BaseModel] | None = Field(
+        default=None,
+        description="Response format for structured output",
+    )
+    logprobs: bool | None = Field(
+        default=None,
+        description="Whether to include log probabilities",
+    )
+    top_logprobs: int | None = Field(
+        default=None,
+        description="Number of top log probabilities to return",
+    )
+    reasoning_effort: str | None = Field(
+        default=None,
+        description="Reasoning effort level for o1 models",
+    )
+    supports_function_calling: bool = Field(
+        default=True,
+        description="Whether the model supports function calling",
+    )
+    is_o1_model: bool = Field(
+        default=False,
+        description="Whether the model is an o1 model",
+    )
+    is_gpt4_model: bool = Field(
+        default=False,
+        description="Whether the model is a GPT-4 model",
+    )
+    _client: OpenAI = PrivateAttr(
+        default_factory=OpenAI,
+    )
 
-        if provider is None:
-            provider = kwargs.pop("provider", "openai")
-
-        self.interceptor = interceptor
-        # Client configuration attributes
-        self.organization = organization
-        self.project = project
-        self.max_retries = max_retries
-        self.default_headers = default_headers
-        self.default_query = default_query
-        self.client_params = client_params
-        self.timeout = timeout
-        self.base_url = base_url
-        self.api_base = kwargs.pop("api_base", None)
-
-        super().__init__(
-            model=model,
-            temperature=temperature,
-            api_key=api_key or os.getenv("OPENAI_API_KEY"),
-            base_url=base_url,
-            timeout=timeout,
-            provider=provider,
-            **kwargs,
-        )
-
-        client_config = self._get_client_params()
-        if self.interceptor:
-            transport = HTTPTransport(interceptor=self.interceptor)
-            http_client = httpx.Client(transport=transport)
-            client_config["http_client"] = http_client
-
-        self.client = OpenAI(**client_config)
-
-        # Completion parameters
-        self.top_p = top_p
-        self.frequency_penalty = frequency_penalty
-        self.presence_penalty = presence_penalty
-        self.max_tokens = max_tokens
-        self.max_completion_tokens = max_completion_tokens
-        self.seed = seed
-        self.stream = stream
-        self.response_format = response_format
-        self.logprobs = logprobs
-        self.top_logprobs = top_logprobs
-        self.reasoning_effort = reasoning_effort
-        self.is_o1_model = "o1" in model.lower()
-        self.is_gpt4_model = "gpt-4" in model.lower()
-
-    def _get_client_params(self) -> dict[str, Any]:
-        """Get OpenAI client parameters."""
+    @model_validator(mode="after")
+    def initialize_client(self) -> Self:
+        """Initialize the OpenAI client after Pydantic validation.
+
+        This runs after all field validation is complete, ensuring that:
+        - All BaseLLM fields are set (model, temperature, stop_sequences, etc.)
+        - Field validators have run (stop_sequences is normalized to set[str])
+        - API key and other configuration is ready
+        """
         if self.api_key is None:
             self.api_key = os.getenv("OPENAI_API_KEY")
 
         if self.api_key is None:
             raise ValueError("OPENAI_API_KEY is required")
 
-        base_params = {
-            "api_key": self.api_key,
-            "organization": self.organization,
-            "project": self.project,
-            "base_url": self.base_url
-            or self.api_base
-            or os.getenv("OPENAI_BASE_URL")
-            or None,
-            "timeout": self.timeout,
-            "max_retries": self.max_retries,
-            "default_headers": self.default_headers,
-            "default_query": self.default_query,
-        }
+        self.is_o1_model = "o1" in self.model.lower()
+        self.supports_function_calling = not self.is_o1_model
+        self.is_gpt4_model = "gpt-4" in self.model.lower()
+        self.supports_stop_words = not self.is_o1_model
 
-        client_params = {k: v for k, v in base_params.items() if v is not None}
+        params = self.model_dump(
+            include={
+                "api_key",
+                "organization",
+                "project",
+                "base_url",
+                "timeout",
+                "max_retries",
+                "default_headers",
+                "default_query",
+            },
+            exclude_none=True,
+        )
+
+        if self.interceptor:
+            transport = HTTPTransport(interceptor=self.interceptor)
+            http_client = httpx.Client(transport=transport)
+            params["http_client"] = http_client
 
         if self.client_params:
-            client_params.update(self.client_params)
+            params.update(self.client_params)
 
-        return client_params
+        self._client = OpenAI(**params)
+        return self
 
     def call(
         self,
@@ -213,38 +250,26 @@
         self, messages: list[LLMMessage], tools: list[dict[str, BaseTool]] | None = None
     ) -> dict[str, Any]:
         """Prepare parameters for OpenAI chat completion."""
-        params: dict[str, Any] = {
-            "model": self.model,
-            "messages": messages,
-        }
-        if self.stream:
-            params["stream"] = self.stream
-
+        params = self.model_dump(
+            include={
+                "model",
+                "stream",
+                "temperature",
+                "top_p",
+                "frequency_penalty",
+                "presence_penalty",
+                "max_completion_tokens",
+                "max_tokens",
+                "seed",
+                "logprobs",
+                "top_logprobs",
+                "reasoning_effort",
+            },
+            exclude_none=True,
+        )
+        params["messages"] = messages
         params.update(self.additional_params)
 
-        if self.temperature is not None:
-            params["temperature"] = self.temperature
-        if self.top_p is not None:
-            params["top_p"] = self.top_p
-        if self.frequency_penalty is not None:
-            params["frequency_penalty"] = self.frequency_penalty
-        if self.presence_penalty is not None:
-            params["presence_penalty"] = self.presence_penalty
-        if self.max_completion_tokens is not None:
-            params["max_completion_tokens"] = self.max_completion_tokens
-        elif self.max_tokens is not None:
-            params["max_tokens"] = self.max_tokens
-        if self.seed is not None:
-            params["seed"] = self.seed
-        if self.logprobs is not None:
-            params["logprobs"] = self.logprobs
-        if self.top_logprobs is not None:
-            params["top_logprobs"] = self.top_logprobs
-
-        # Handle o1 model specific parameters
-        if self.is_o1_model and self.reasoning_effort:
-            params["reasoning_effort"] = self.reasoning_effort
-
         if tools:
             params["tools"] = self._convert_tools_for_interference(tools)
             params["tool_choice"] = "auto"
@@ -296,14 +321,14 @@
         self,
         params: dict[str, Any],
         available_functions: dict[str, Any] | None = None,
-        from_task: Any | None = None,
-        from_agent: Any | None = None,
+        from_task: Task | None = None,
+        from_agent: Agent | None = None,
         response_model: type[BaseModel] | None = None,
     ) -> str | Any:
         """Handle non-streaming chat completion."""
         try:
             if response_model:
-                parsed_response = self.client.beta.chat.completions.parse(
+                parsed_response = self._client.beta.chat.completions.parse(
                     **params,
                     response_format=response_model,
                 )
@@ -327,7 +352,7 @@
                 )
                 return structured_json
 
-        response: ChatCompletion = self.client.chat.completions.create(**params)
+        response: ChatCompletion = self._client.chat.completions.create(**params)
 
         usage = self._extract_openai_token_usage(response)
@@ -419,8 +444,8 @@
         self,
         params: dict[str, Any],
         available_functions: dict[str, Any] | None = None,
-        from_task: Any | None = None,
-        from_agent: Any | None = None,
+        from_task: Task | None = None,
+        from_agent: Agent | None = None,
         response_model: type[BaseModel] | None = None,
     ) -> str:
         """Handle streaming chat completion."""
@@ -429,7 +454,7 @@
         if response_model:
             completion_stream: Iterator[ChatCompletionChunk] = (
-                self.client.chat.completions.create(**params)
+                self._client.chat.completions.create(**params)
             )
 
             accumulated_content = ""
@@ -472,7 +497,7 @@
             )
             return accumulated_content
 
-        stream: Iterator[ChatCompletionChunk] = self.client.chat.completions.create(
+        stream: Iterator[ChatCompletionChunk] = self._client.chat.completions.create(
             **params
         )
@@ -550,58 +575,31 @@
         return full_response
 
-    def supports_function_calling(self) -> bool:
-        """Check if the model supports function calling."""
-        return not self.is_o1_model
-
-    def supports_stop_words(self) -> bool:
-        """Check if the model supports stop words."""
-        return not self.is_o1_model
-
     def get_context_window_size(self) -> int:
         """Get the context window size for the model."""
-        from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
-
-        min_context = 1024
-        max_context = 2097152
-
         for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
-            if value < min_context or value > max_context:
+            if value < MIN_CONTEXT_WINDOW or value > MAX_CONTEXT_WINDOW:
                 raise ValueError(
-                    f"Context window for {key} must be between {min_context} and {max_context}"
+                    f"Context window for {key} must be between {MIN_CONTEXT_WINDOW} and {MAX_CONTEXT_WINDOW}"
                 )
-        # Context window sizes for OpenAI models
-        context_windows = {
-            "gpt-4": 8192,
-            "gpt-4o": 128000,
-            "gpt-4o-mini": 200000,
-            "gpt-4-turbo": 128000,
-            "gpt-4.1": 1047576,
-            "gpt-4.1-mini-2025-04-14": 1047576,
-            "gpt-4.1-nano-2025-04-14": 1047576,
-            "o1-preview": 128000,
-            "o1-mini": 128000,
-            "o3-mini": 200000,
-            "o4-mini": 200000,
-        }
-        # Find the best match for the model name
-        for model_prefix, size in context_windows.items():
+        for model_prefix, size in OPENAI_CONTEXT_WINDOWS.items():
            if self.model.startswith(model_prefix):
                 return int(size * CONTEXT_WINDOW_USAGE_RATIO)
 
         # Default context window size
         return int(8192 * CONTEXT_WINDOW_USAGE_RATIO)
 
-    def _extract_openai_token_usage(self, response: ChatCompletion) -> dict[str, Any]:
+    @staticmethod
+    def _extract_openai_token_usage(response: ChatCompletion) -> dict[str, Any]:
         """Extract token usage from OpenAI ChatCompletion response."""
-        if hasattr(response, "usage") and response.usage:
+        if response.usage:
             usage = response.usage
             return {
-                "prompt_tokens": getattr(usage, "prompt_tokens", 0),
-                "completion_tokens": getattr(usage, "completion_tokens", 0),
-                "total_tokens": getattr(usage, "total_tokens", 0),
+                "prompt_tokens": usage.prompt_tokens,
+                "completion_tokens": usage.completion_tokens,
+                "total_tokens": usage.total_tokens,
             }
         return {"total_tokens": 0}
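Example usage (illustrative only, not part of the patch): a minimal sketch of the restructured, Pydantic-based OpenAICompletion, assuming the field names and import path shown in the hunks above and that temperature remains a BaseLLM field. OPENAI_API_KEY must be available before construction because the private client is built in the model_validator; the key below is a placeholder.

    import os

    from crewai.llms.providers.openai.completion import OpenAICompletion

    # Placeholder key so the validator can construct the client; a real key is
    # required for actual API calls.
    os.environ.setdefault("OPENAI_API_KEY", "sk-placeholder")

    llm = OpenAICompletion(model="gpt-4o", temperature=0.2)
    # initialize_client ran as a model_validator(mode="after"), so the private
    # OpenAI client already exists; no separate setup call is needed.
    print(llm.supports_function_calling)   # True: "gpt-4o" is not an o1 model
    print(llm.get_context_window_size())   # prefix lookup scaled by CONTEXT_WINDOW_USAGE_RATIO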
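The context-window helpers in both providers now do a first-match prefix lookup scaled by a usage ratio. The sketch below mirrors that logic with stand-in values; the real tables are GEMINI_CONTEXT_WINDOWS / OPENAI_CONTEXT_WINDOWS and the ratio comes from crewai.llm.CONTEXT_WINDOW_USAGE_RATIO (its value is assumed here). Because the first matching prefix wins, overlapping keys such as "gpt-4" and "gpt-4o" are order-sensitive.

    # Stand-in table and ratio for illustration only.
    CONTEXT_WINDOWS: dict[str, int] = {"gpt-4o": 128000, "gpt-4": 8192}
    USAGE_RATIO = 0.75  # assumed placeholder for CONTEXT_WINDOW_USAGE_RATIO

    def usable_context(model: str, default: int = 8192) -> int:
        """Return the usable context window for a model name via prefix match."""
        for prefix, size in CONTEXT_WINDOWS.items():
            if model.startswith(prefix):
                return int(size * USAGE_RATIO)
        return int(default * USAGE_RATIO)

    assert usable_context("gpt-4o-mini") == 96000  # matches "gpt-4o" before "gpt-4"
    assert usable_context("o3-mini") == 6144       # no prefix match, falls back to the default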