Compare commits

...

6 Commits

8 changed files with 719 additions and 716 deletions

View File

@@ -618,22 +618,22 @@ class Agent(BaseAgent):
             response_template=self.response_template,
         ).task_execution()

-        stop_words = [self.i18n.slice("observation")]
+        stop_sequences = [self.i18n.slice("observation")]
         if self.response_template:
-            stop_words.append(
+            stop_sequences.append(
                 self.response_template.split("{{ .Response }}")[1].strip()
             )

         self.agent_executor = CrewAgentExecutor(
-            llm=self.llm,
+            llm=self.llm,  # type: ignore[arg-type]
             task=task,  # type: ignore[arg-type]
             agent=self,
             crew=self.crew,
             tools=parsed_tools,
             prompt=prompt,
             original_tools=raw_tools,
-            stop_words=stop_words,
+            stop_sequences=stop_sequences,
             max_iter=self.max_iter,
             tools_handler=self.tools_handler,
             tools_names=get_tool_names(parsed_tools),
@@ -974,7 +974,9 @@ class Agent(BaseAgent):
         path = parsed.path.replace("/", "_").strip("_")
         return f"{domain}_{path}" if path else domain

-    def _get_mcp_tool_schemas(self, server_params: dict) -> dict[str, dict]:
+    def _get_mcp_tool_schemas(
+        self, server_params: dict[str, Any]
+    ) -> dict[str, dict[str, Any]] | Any:
         """Get tool schemas from MCP server for wrapper creation with caching."""
         server_url = server_params["url"]
@@ -1006,7 +1008,7 @@ class Agent(BaseAgent):
     async def _get_mcp_tool_schemas_async(
         self, server_params: dict[str, Any]
-    ) -> dict[str, dict]:
+    ) -> dict[str, dict[str, Any]]:
         """Async implementation of MCP tool schema retrieval with timeouts and retries."""
         server_url = server_params["url"]
         return await self._retry_mcp_discovery(
@@ -1014,7 +1016,7 @@ class Agent(BaseAgent):
         )

     async def _retry_mcp_discovery(
-        self, operation_func, server_url: str
+        self, operation_func: Any, server_url: str
     ) -> dict[str, dict[str, Any]]:
         """Retry MCP discovery operation with exponential backoff, avoiding try-except in loop."""
         last_error = None
@@ -1045,7 +1047,7 @@ class Agent(BaseAgent):
     @staticmethod
     async def _attempt_mcp_discovery(
-        operation_func, server_url: str
+        operation_func: Any, server_url: str
     ) -> tuple[dict[str, dict[str, Any]] | None, str, bool]:
         """Attempt single MCP discovery operation and return (result, error_message, should_retry)."""
         try:
@@ -1149,13 +1151,13 @@ class Agent(BaseAgent):
                 Field(..., description=field_description),
             )
         else:
-            field_definitions[field_name] = (
+            field_definitions[field_name] = (  # type: ignore[assignment]
                 field_type | None,
                 Field(default=None, description=field_description),
             )

     model_name = f"{tool_name.replace('-', '_').replace(' ', '_')}Schema"
-    return create_model(model_name, **field_definitions)
+    return create_model(model_name, **field_definitions)  # type: ignore[no-any-return,call-overload]

     def _json_type_to_python(self, field_schema: dict[str, Any]) -> type:
         """Convert JSON Schema type to Python type.
@@ -1175,12 +1177,12 @@ class Agent(BaseAgent):
             if "const" in option:
                 types.append(str)
             else:
-                types.append(self._json_type_to_python(option))
+                types.append(self._json_type_to_python(option))  # type: ignore[arg-type]
         unique_types = list(set(types))
         if len(unique_types) > 1:
             result = unique_types[0]
             for t in unique_types[1:]:
-                result = result | t
+                result = result | t  # type: ignore[assignment]
             return result
         return unique_types[0]
@@ -1193,10 +1195,10 @@ class Agent(BaseAgent):
             "object": dict,
         }
-        return type_mapping.get(json_type, Any)
+        return type_mapping.get(json_type, Any)  # type: ignore[arg-type]

     @staticmethod
-    def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict]:
+    def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict[str, Any]]:
         """Fetch MCP server configurations from CrewAI AMP API."""
         # TODO: Implement AMP API call to "integrations/mcps" endpoint
         # Should return list of server configs with URLs
@@ -1435,7 +1437,7 @@ class Agent(BaseAgent):
             goal=self.goal,
             backstory=self.backstory,
             llm=self.llm,
-            tools=self.tools or [],
+            tools=self.tools,
             max_iterations=self.max_iter,
             max_execution_time=self.max_execution_time,
             respect_context_window=self.respect_context_window,

View File

@@ -137,7 +137,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
         default=False,
         description="Enable agent to delegate and ask questions among each other.",
     )
-    tools: list[BaseTool] | None = Field(
+    tools: list[BaseTool] = Field(
         default_factory=list, description="Tools at agents' disposal"
     )
     max_iter: int = Field(

View File

@@ -73,7 +73,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         max_iter: int,
         tools: list[CrewStructuredTool],
         tools_names: str,
-        stop_words: list[str],
+        stop_sequences: list[str],
         tools_description: str,
         tools_handler: ToolsHandler,
         step_callback: Any = None,
@@ -95,7 +95,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
             max_iter: Maximum iterations.
             tools: Available tools.
             tools_names: Tool names string.
-            stop_words: Stop word list.
+            stop_sequences: Stop sequences list for halting generation.
             tools_description: Tool descriptions.
             tools_handler: Tool handler instance.
             step_callback: Optional step callback.
@@ -114,7 +114,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         self.prompt = prompt
         self.tools = tools
         self.tools_names = tools_names
-        self.stop = stop_words
         self.max_iter = max_iter
         self.callbacks = callbacks or []
         self._printer: Printer = Printer()
@@ -131,15 +130,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         self.iterations = 0
         self.log_error_after = 3
         if self.llm:
-            # This may be mutating the shared llm object and needs further evaluation
-            existing_stop = getattr(self.llm, "stop", [])
-            self.llm.stop = list(
-                set(
-                    existing_stop + self.stop
-                    if isinstance(existing_stop, list)
-                    else self.stop
-                )
-            )
+            self.llm.stop_sequences.extend(stop_sequences)

     @property
     def use_stop_words(self) -> bool:
@@ -148,7 +139,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         Returns:
             bool: True if tool should be used or not.
         """
-        return self.llm.supports_stop_words() if self.llm else False
+        return self.llm.supports_stop_words if self.llm else False

     def invoke(self, inputs: dict[str, Any]) -> dict[str, Any]:
         """Execute the agent with given inputs.

View File

@@ -20,8 +20,7 @@ from typing import (
 )

 from dotenv import load_dotenv
-import httpx
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, model_validator
 from typing_extensions import Self

 from crewai.events.event_bus import crewai_event_bus
@@ -54,7 +53,6 @@ if TYPE_CHECKING:
     from litellm.utils import supports_response_schema

     from crewai.agent.core import Agent
-    from crewai.llms.hooks.base import BaseInterceptor
     from crewai.task import Task
     from crewai.tools.base_tool import BaseTool
     from crewai.utilities.types import LLMMessage
@@ -320,7 +318,138 @@ class AccumulatedToolArgs(BaseModel):
 class LLM(BaseLLM):
-    completion_cost: float | None = None
+    completion_cost: float | None = Field(
+        default=None, description="The completion cost of the LLM."
+    )
+    top_p: float | None = Field(
+        default=None, description="Sampling probability threshold."
+    )
+    n: int | None = Field(
+        default=None, description="Number of completions to generate."
+    )
+    max_completion_tokens: int | None = Field(
+        default=None,
+        description="Maximum number of tokens to generate in the completion.",
+    )
+    max_tokens: int | None = Field(
+        default=None,
+        description="Maximum number of tokens allowed in the prompt + completion.",
+    )
+    presence_penalty: float | None = Field(
+        default=None, description="Penalty on the presence penalty."
+    )
+    frequency_penalty: float | None = Field(
+        default=None, description="Penalty on the frequency penalty."
+    )
+    logit_bias: dict[int, float] | None = Field(
+        default=None,
+        description="Modifies the likelihood of specified tokens appearing in the completion.",
+    )
+    response_format: type[BaseModel] | None = Field(
+        default=None,
+        description="Pydantic model class for structured response parsing.",
+    )
+    seed: int | None = Field(
+        default=None,
+        description="Random seed for reproducibility.",
+    )
+    logprobs: int | None = Field(
+        default=None,
+        description="Number of top logprobs to return.",
+    )
+    top_logprobs: int | None = Field(
+        default=None,
+        description="Number of top logprobs to return.",
+    )
+    api_base: str | None = Field(
+        default=None,
+        description="Base URL for the API endpoint.",
+    )
+    api_version: str | None = Field(
+        default=None,
+        description="API version to use.",
+    )
+    callbacks: list[Any] = Field(
+        default_factory=list,
+        description="List of callback handlers for LLM events.",
+    )
+    reasoning_effort: Literal["none", "low", "medium", "high"] | None = Field(
+        default=None,
+        description="Level of reasoning effort for the LLM.",
+    )
+    context_window_size: int = Field(
+        default=0,
+        description="The context window size of the LLM.",
+    )
+    is_anthropic: bool = Field(
+        default=False,
+        description="Indicates if the model is from Anthropic provider.",
+    )
+    supports_function_calling: bool = Field(
+        default=False,
+        description="Indicates if the model supports function calling.",
+    )
+    supports_stop_words: bool = Field(
+        default=False,
+        description="Indicates if the model supports stop words.",
+    )
+
+    @model_validator(mode="after")
+    def initialize_client(self) -> Self:
+        self.is_anthropic = any(
+            prefix in self.model.lower() for prefix in ANTHROPIC_PREFIXES
+        )
+
+        try:
+            provider = self._get_custom_llm_provider()
+            self.supports_function_calling = litellm.utils.supports_function_calling(
+                self.model, custom_llm_provider=provider
+            )
+        except Exception as e:
+            logging.error(f"Failed to check function calling support: {e!s}")
+            self.supports_function_calling = False
+
+        try:
+            params = get_supported_openai_params(model=self.model)
+            self.supports_stop_words = params is not None and "stop" in params
+        except Exception as e:
+            logging.error(f"Failed to get supported params: {e!s}")
+            self.supports_stop_words = False
+
+        with suppress_warnings():
+            callback_types = [type(callback) for callback in self.callbacks]
+            for callback in litellm.success_callback[:]:
+                if type(callback) in callback_types:
+                    litellm.success_callback.remove(callback)
+
+            for callback in litellm._async_success_callback[:]:
+                if type(callback) in callback_types:
+                    litellm._async_success_callback.remove(callback)
+
+            litellm.callbacks = self.callbacks
+
+        with suppress_warnings():
+            success_callbacks_str = os.environ.get("LITELLM_SUCCESS_CALLBACKS", "")
+            success_callbacks: list[str | Callable[..., Any] | CustomLogger] = []
+            if success_callbacks_str:
+                success_callbacks = [
+                    cb.strip() for cb in success_callbacks_str.split(",") if cb.strip()
+                ]
+
+            failure_callbacks_str = os.environ.get("LITELLM_FAILURE_CALLBACKS", "")
+            if failure_callbacks_str:
+                failure_callbacks: list[str | Callable[..., Any] | CustomLogger] = [
+                    cb.strip() for cb in failure_callbacks_str.split(",") if cb.strip()
+                ]
+
+                litellm.success_callback = success_callbacks
+                litellm.failure_callback = failure_callbacks
+        return self
+
+    # @computed_field
+    # @property
+    # def is_anthropic(self) -> bool:
+    #     """Determine if the model is from Anthropic provider."""
+    #     anthropic_prefixes = ("anthropic/", "claude-", "claude/")
+    #     return any(prefix in self.model.lower() for prefix in anthropic_prefixes)

     def __new__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> LLM:
         """Factory method that routes to native SDK or falls back to LiteLLM."""
@@ -383,98 +512,6 @@ class LLM(BaseLLM):
         return None

-    def __init__(
-        self,
-        model: str,
-        timeout: float | int | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        n: int | None = None,
-        stop: str | list[str] | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | float | None = None,
-        presence_penalty: float | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[int, float] | None = None,
-        response_format: type[BaseModel] | None = None,
-        seed: int | None = None,
-        logprobs: int | None = None,
-        top_logprobs: int | None = None,
-        base_url: str | None = None,
-        api_base: str | None = None,
-        api_version: str | None = None,
-        api_key: str | None = None,
-        callbacks: list[Any] | None = None,
-        reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
-        stream: bool = False,
-        interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
-        **kwargs: Any,
-    ) -> None:
-        """Initialize LLM instance.
-
-        Note: This __init__ method is only called for fallback instances.
-        Native provider instances handle their own initialization in their respective classes.
-        """
-        super().__init__(
-            model=model,
-            temperature=temperature,
-            api_key=api_key,
-            base_url=base_url,
-            timeout=timeout,
-            **kwargs,
-        )
-        self.model = model
-        self.timeout = timeout
-        self.temperature = temperature
-        self.top_p = top_p
-        self.n = n
-        self.max_completion_tokens = max_completion_tokens
-        self.max_tokens = max_tokens
-        self.presence_penalty = presence_penalty
-        self.frequency_penalty = frequency_penalty
-        self.logit_bias = logit_bias
-        self.response_format = response_format
-        self.seed = seed
-        self.logprobs = logprobs
-        self.top_logprobs = top_logprobs
-        self.base_url = base_url
-        self.api_base = api_base
-        self.api_version = api_version
-        self.api_key = api_key
-        self.callbacks = callbacks
-        self.context_window_size = 0
-        self.reasoning_effort = reasoning_effort
-        self.additional_params = kwargs
-        self.is_anthropic = self._is_anthropic_model(model)
-        self.stream = stream
-        self.interceptor = interceptor
-
-        litellm.drop_params = True
-
-        # Normalize self.stop to always be a list[str]
-        if stop is None:
-            self.stop: list[str] = []
-        elif isinstance(stop, str):
-            self.stop = [stop]
-        else:
-            self.stop = stop
-
-        self.set_callbacks(callbacks or [])
-        self.set_env_callbacks()
-
-    @staticmethod
-    def _is_anthropic_model(model: str) -> bool:
-        """Determine if the model is from Anthropic provider.
-
-        Args:
-            model: The model identifier string.
-
-        Returns:
-            bool: True if the model is from Anthropic, False otherwise.
-        """
-        anthropic_prefixes = ("anthropic/", "claude-", "claude/")
-        return any(prefix in model.lower() for prefix in anthropic_prefixes)
-
     def _prepare_completion_params(
         self,
         messages: str | list[LLMMessage],
@@ -1188,8 +1225,6 @@ class LLM(BaseLLM):
                 message["role"] = msg_role

         # --- 5) Set up callbacks if provided
         with suppress_warnings():
-            if callbacks and len(callbacks) > 0:
-                self.set_callbacks(callbacks)
             try:
                 # --- 6) Prepare parameters for the completion call
                 params = self._prepare_completion_params(messages, tools)
@@ -1378,24 +1413,6 @@ class LLM(BaseLLM):
                 "Please remove response_format or use a supported model."
             )

-    def supports_function_calling(self) -> bool:
-        try:
-            provider = self._get_custom_llm_provider()
-            return litellm.utils.supports_function_calling(
-                self.model, custom_llm_provider=provider
-            )
-        except Exception as e:
-            logging.error(f"Failed to check function calling support: {e!s}")
-            return False
-
-    def supports_stop_words(self) -> bool:
-        try:
-            params = get_supported_openai_params(model=self.model)
-            return params is not None and "stop" in params
-        except Exception as e:
-            logging.error(f"Failed to get supported params: {e!s}")
-            return False
-
     def get_context_window_size(self) -> int:
         """
         Returns the context window size, using 75% of the maximum to avoid
@@ -1425,60 +1442,6 @@ class LLM(BaseLLM):
         self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
         return self.context_window_size

-    @staticmethod
-    def set_callbacks(callbacks: list[Any]) -> None:
-        """
-        Attempt to keep a single set of callbacks in litellm by removing old
-        duplicates and adding new ones.
-        """
-        with suppress_warnings():
-            callback_types = [type(callback) for callback in callbacks]
-            for callback in litellm.success_callback[:]:
-                if type(callback) in callback_types:
-                    litellm.success_callback.remove(callback)
-
-            for callback in litellm._async_success_callback[:]:
-                if type(callback) in callback_types:
-                    litellm._async_success_callback.remove(callback)
-
-            litellm.callbacks = callbacks
-
-    @staticmethod
-    def set_env_callbacks() -> None:
-        """Sets the success and failure callbacks for the LiteLLM library from environment variables.
-
-        This method reads the `LITELLM_SUCCESS_CALLBACKS` and `LITELLM_FAILURE_CALLBACKS`
-        environment variables, which should contain comma-separated lists of callback names.
-        It then assigns these lists to `litellm.success_callback` and `litellm.failure_callback`,
-        respectively.
-
-        If the environment variables are not set or are empty, the corresponding callback lists
-        will be set to empty lists.
-
-        Examples:
-            LITELLM_SUCCESS_CALLBACKS="langfuse,langsmith"
-            LITELLM_FAILURE_CALLBACKS="langfuse"
-
-        This will set `litellm.success_callback` to ["langfuse", "langsmith"] and
-        `litellm.failure_callback` to ["langfuse"].
-        """
-        with suppress_warnings():
-            success_callbacks_str = os.environ.get("LITELLM_SUCCESS_CALLBACKS", "")
-            success_callbacks: list[str | Callable[..., Any] | CustomLogger] = []
-            if success_callbacks_str:
-                success_callbacks = [
-                    cb.strip() for cb in success_callbacks_str.split(",") if cb.strip()
-                ]
-
-            failure_callbacks_str = os.environ.get("LITELLM_FAILURE_CALLBACKS", "")
-            if failure_callbacks_str:
-                failure_callbacks: list[str | Callable[..., Any] | CustomLogger] = [
-                    cb.strip() for cb in failure_callbacks_str.split(",") if cb.strip()
-                ]
-
-            litellm.success_callback = success_callbacks
-            litellm.failure_callback = failure_callbacks
-
     def __copy__(self) -> LLM:
         """Create a shallow copy of the LLM instance."""
         # Filter out parameters that are already explicitly passed to avoid conflicts
@@ -1539,7 +1502,7 @@ class LLM(BaseLLM):
             **filtered_params,
         )

-    def __deepcopy__(self, memo: dict[int, Any] | None) -> LLM:
+    def __deepcopy__(self, memo: dict[int, Any] | None) -> LLM:  # type: ignore[override]
         """Create a deep copy of the LLM instance."""
         import copy

View File

@@ -13,8 +13,9 @@ import logging
 import re
 from typing import TYPE_CHECKING, Any, Final

-from pydantic import BaseModel
+from pydantic import AliasChoices, BaseModel, Field, PrivateAttr, field_validator

+from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
 from crewai.events.event_bus import crewai_event_bus
 from crewai.events.types.llm_events import (
     LLMCallCompletedEvent,
@@ -28,6 +29,7 @@ from crewai.events.types.tool_usage_events import (
     ToolUsageFinishedEvent,
     ToolUsageStartedEvent,
 )
+from crewai.llms.hooks import BaseInterceptor
 from crewai.types.usage_metrics import UsageMetrics
@@ -43,7 +45,7 @@ DEFAULT_SUPPORTS_STOP_WORDS: Final[bool] = True
 _JSON_EXTRACTION_PATTERN: Final[re.Pattern[str]] = re.compile(r"\{.*}", re.DOTALL)

-class BaseLLM(ABC):
+class BaseLLM(BaseModel, ABC):
     """Abstract base class for LLM implementations.

     This class defines the interface that all LLM implementations must follow.
@@ -55,70 +57,105 @@ class BaseLLM(ABC):
     implement proper validation for input parameters and provide clear error
     messages when things go wrong.

     Attributes:
         model: The model identifier/name.
         temperature: Optional temperature setting for response generation.
-        stop: A list of stop sequences that the LLM should use to stop generation.
-        additional_params: Additional provider-specific parameters.
     """

-    is_litellm: bool = False
-
-    def __init__(
-        self,
-        model: str,
-        temperature: float | None = None,
-        api_key: str | None = None,
-        base_url: str | None = None,
-        provider: str | None = None,
-        **kwargs: Any,
-    ) -> None:
-        """Initialize the BaseLLM with default attributes.
-
-        Args:
-            model: The model identifier/name.
-            temperature: Optional temperature setting for response generation.
-            stop: Optional list of stop sequences for generation.
-            **kwargs: Additional provider-specific parameters.
-        """
-        if not model:
-            raise ValueError("Model name is required and cannot be empty")
-
-        self.model = model
-        self.temperature = temperature
-        self.api_key = api_key
-        self.base_url = base_url
-        # Store additional parameters for provider-specific use
-        self.additional_params = kwargs
-        self._provider = provider or "openai"
-
-        stop = kwargs.pop("stop", None)
-        if stop is None:
-            self.stop: list[str] = []
-        elif isinstance(stop, str):
-            self.stop = [stop]
-        elif isinstance(stop, list):
-            self.stop = stop
-        else:
-            self.stop = []
-
-        self._token_usage = {
-            "total_tokens": 0,
-            "prompt_tokens": 0,
-            "completion_tokens": 0,
-            "successful_requests": 0,
-            "cached_prompt_tokens": 0,
-        }
-
-    @property
-    def provider(self) -> str:
-        """Get the provider of the LLM."""
-        return self._provider
-
-    @provider.setter
-    def provider(self, value: str) -> None:
-        """Set the provider of the LLM."""
-        self._provider = value
+    provider: str | re.Pattern[str] = Field(
+        default="openai", description="The provider of the LLM."
+    )
+    model: str = Field(description="The model identifier/name.")
+    temperature: float | None = Field(
+        default=None, ge=0, le=2, description="Temperature for response generation."
+    )
+    api_key: str | None = Field(default=None, description="API key for authentication.")
+    base_url: str | None = Field(default=None, description="Base URL for API calls.")
+    timeout: float | None = Field(default=None, description="Timeout for API calls.")
+    max_retries: int = Field(
+        default=2, description="Maximum number of API requests to make."
+    )
+    max_tokens: int | None = Field(
+        default=None, description="Maximum tokens for response generation."
+    )
+    stream: bool | None = Field(default=False, description="Stream the API requests.")
+    client: Any = Field(description="Underlying LLM client instance.")
+    interceptor: BaseInterceptor[Any, Any] | None = Field(
+        default=None,
+        description="An optional HTTPX interceptor for modifying requests/responses.",
+    )
+    client_params: dict[str, Any] = Field(
+        default_factory=dict,
+        description="Additional parameters for the underlying LLM client.",
+    )
+    supports_stop_words: bool = Field(
+        default=DEFAULT_SUPPORTS_STOP_WORDS,
+        description="Whether or not to support stop words.",
+    )
+    stop_sequences: list[str] = Field(
+        default_factory=list,
+        validation_alias=AliasChoices("stop_sequences", "stop"),
+        description="Stop sequences for generation (synchronized with stop).",
+    )
+    is_litellm: bool = Field(
+        default=False, description="Is this LLM implementation in litellm?"
+    )
+    additional_params: dict[str, Any] = Field(
+        default_factory=dict,
+        description="Additional parameters for LLM calls.",
+    )
+    _token_usage: TokenProcess = PrivateAttr(default_factory=TokenProcess)
+
+    @field_validator("provider", mode="before")
+    @classmethod
+    def extract_provider_from_model(
+        cls, v: str | re.Pattern[str] | None, info: Any
+    ) -> str | re.Pattern[str]:
+        """Extract provider from model string if not explicitly provided.
+
+        Args:
+            v: Provided provider value (can be str, Pattern, or None)
+            info: Validation info containing other field values
+
+        Returns:
+            Provider name (str) or Pattern
+        """
+        # If provider explicitly provided, validate and return it
+        if v is not None:
+            if not isinstance(v, (str, re.Pattern)):
+                raise ValueError(f"Provider must be str or Pattern, got {type(v)}")
+            return v
+
+        model: str = info.data.get("model", "")
+        if "/" in model:
+            return model.partition("/")[0]
+        return "openai"
+
+    @field_validator("stop_sequences", mode="before")
+    @classmethod
+    def normalize_stop_sequences(
+        cls, v: str | list[str] | set[str] | None
+    ) -> list[str]:
+        """Validate and normalize stop sequences.
+
+        Converts string to list and handles None values.
+        AliasChoices handles accepting both 'stop' and 'stop_sequences' parameter names.
+        """
+        if v is None:
+            return []
+        if isinstance(v, str):
+            return [v]
+        if isinstance(v, set):
+            return list(v)
+        if isinstance(v, list):
+            return v
+        return []
+
+    @property
+    def stop(self) -> list[str]:
+        """Alias for stop_sequences to maintain backward compatibility."""
+        return self.stop_sequences

     @abstractmethod
     def call(
@@ -171,14 +208,6 @@ class BaseLLM(ABC):
         """
         return tools

-    def supports_stop_words(self) -> bool:
-        """Check if the LLM supports stop words.
-
-        Returns:
-            True if the LLM supports stop words, False otherwise.
-        """
-        return DEFAULT_SUPPORTS_STOP_WORDS
-
     def _supports_stop_words_implementation(self) -> bool:
         """Check if stop words are configured for this LLM instance.
@@ -506,7 +535,7 @@ class BaseLLM(ABC):
         """
         if "/" in model:
             return model.partition("/")[0]
-        return "openai"  # Default provider
+        return "openai"

     def _track_token_usage_internal(self, usage_data: dict[str, Any]) -> None:
         """Track token usage internally in the LLM instance.
@@ -535,11 +564,11 @@ class BaseLLM(ABC):
             or 0
         )

-        self._token_usage["prompt_tokens"] += prompt_tokens
-        self._token_usage["completion_tokens"] += completion_tokens
-        self._token_usage["total_tokens"] += prompt_tokens + completion_tokens
-        self._token_usage["successful_requests"] += 1
-        self._token_usage["cached_prompt_tokens"] += cached_tokens
+        self._token_usage.prompt_tokens += prompt_tokens
+        self._token_usage.completion_tokens += completion_tokens
+        self._token_usage.total_tokens += prompt_tokens + completion_tokens
+        self._token_usage.successful_requests += 1
+        self._token_usage.cached_prompt_tokens += cached_tokens

     def get_token_usage_summary(self) -> UsageMetrics:
         """Get summary of token usage for this LLM instance.
@@ -547,4 +576,10 @@ class BaseLLM(ABC):
         Returns:
             Dictionary with token usage totals
         """
-        return UsageMetrics(**self._token_usage)
+        return UsageMetrics(
+            prompt_tokens=self._token_usage.prompt_tokens,
+            completion_tokens=self._token_usage.completion_tokens,
+            total_tokens=self._token_usage.total_tokens,
+            successful_requests=self._token_usage.successful_requests,
+            cached_prompt_tokens=self._token_usage.cached_prompt_tokens,
+        )
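Illustration (not from the changeset): a minimal sketch, using a hypothetical MiniLLM model, of how the stop/stop_sequences alias introduced above behaves. Because the field declares validation_alias=AliasChoices("stop_sequences", "stop") and a before-validator normalizes strings and None, callers can keep passing stop= while new code reads stop_sequences (or the stop property).

    from pydantic import AliasChoices, BaseModel, Field, field_validator

    class MiniLLM(BaseModel):
        """Hypothetical stand-in mirroring the BaseLLM fields shown in the diff."""

        model: str
        stop_sequences: list[str] = Field(
            default_factory=list,
            validation_alias=AliasChoices("stop_sequences", "stop"),
        )

        @field_validator("stop_sequences", mode="before")
        @classmethod
        def normalize_stop_sequences(cls, v: str | list[str] | None) -> list[str]:
            # Same normalization idea as the diff: None -> [], str -> [str]
            if v is None:
                return []
            if isinstance(v, str):
                return [v]
            return list(v)

        @property
        def stop(self) -> list[str]:
            # Backward-compatible read alias, as in the diff
            return self.stop_sequences

    # Both spellings are accepted at construction time:
    a = MiniLLM(model="openai/gpt-4o", stop="\nObservation:")
    b = MiniLLM(model="openai/gpt-4o", stop_sequences=["\nObservation:"])
    assert a.stop == b.stop == ["\nObservation:"]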

View File

@@ -5,11 +5,14 @@ import logging
 import os
 from typing import TYPE_CHECKING, Any, cast

-from pydantic import BaseModel
+from pydantic import BaseModel, Field, PrivateAttr, computed_field, model_validator
+from typing_extensions import Self

 from crewai.events.types.llm_events import LLMCallType
+from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO
 from crewai.llms.base_llm import BaseLLM
 from crewai.llms.hooks.transport import HTTPTransport
+from crewai.llms.providers.utils.common import safe_tool_conversion
 from crewai.utilities.agent_utils import is_context_length_exceeded
 from crewai.utilities.exceptions.context_window_exceeding_exception import (
     LLMContextLengthExceededError,
@@ -18,7 +21,8 @@ from crewai.utilities.types import LLMMessage
 if TYPE_CHECKING:
-    from crewai.llms.hooks.base import BaseInterceptor
+    from crewai.agent import Agent
+    from crewai.task import Task

 try:
     from anthropic import Anthropic
@@ -31,6 +35,19 @@ except ImportError:
     ) from None

+ANTHROPIC_CONTEXT_WINDOWS: dict[str, int] = {
+    "claude-3-5-sonnet": 200000,
+    "claude-3-5-haiku": 200000,
+    "claude-3-opus": 200000,
+    "claude-3-sonnet": 200000,
+    "claude-3-haiku": 200000,
+    "claude-3-7-sonnet": 200000,
+    "claude-2.1": 200000,
+    "claude-2": 100000,
+    "claude-instant": 100000,
+}
+

 class AnthropicCompletion(BaseLLM):
     """Anthropic native completion implementation.
@@ -38,110 +55,69 @@ class AnthropicCompletion(BaseLLM):
     offering native tool use, streaming support, and proper message formatting.
     """

-    def __init__(
-        self,
-        model: str = "claude-3-5-sonnet-20241022",
-        api_key: str | None = None,
-        base_url: str | None = None,
-        timeout: float | None = None,
-        max_retries: int = 2,
-        temperature: float | None = None,
-        max_tokens: int = 4096,  # Required for Anthropic
-        top_p: float | None = None,
-        stop_sequences: list[str] | None = None,
-        stream: bool = False,
-        client_params: dict[str, Any] | None = None,
-        interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
-        **kwargs: Any,
-    ):
-        """Initialize Anthropic chat completion client.
-
-        Args:
-            model: Anthropic model name (e.g., 'claude-3-5-sonnet-20241022')
-            api_key: Anthropic API key (defaults to ANTHROPIC_API_KEY env var)
-            base_url: Custom base URL for Anthropic API
-            timeout: Request timeout in seconds
-            max_retries: Maximum number of retries
-            temperature: Sampling temperature (0-1)
-            max_tokens: Maximum tokens in response (required for Anthropic)
-            top_p: Nucleus sampling parameter
-            stop_sequences: Stop sequences (Anthropic uses stop_sequences, not stop)
-            stream: Enable streaming responses
-            client_params: Additional parameters for the Anthropic client
-            interceptor: HTTP interceptor for modifying requests/responses at transport level.
-            **kwargs: Additional parameters
-        """
-        super().__init__(
-            model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
-        )
-
-        # Client params
-        self.interceptor = interceptor
-        self.client_params = client_params
-        self.base_url = base_url
-        self.timeout = timeout
-        self.max_retries = max_retries
-        self.client = Anthropic(**self._get_client_params())
-
-        # Store completion parameters
-        self.max_tokens = max_tokens
-        self.top_p = top_p
-        self.stream = stream
-        self.stop_sequences = stop_sequences or []
-
-        # Model-specific settings
-        self.is_claude_3 = "claude-3" in model.lower()
-        self.supports_tools = self.is_claude_3  # Claude 3+ supports tool use
-
-    @property
-    def stop(self) -> list[str]:
-        """Get stop sequences sent to the API."""
-        return self.stop_sequences
-
-    @stop.setter
-    def stop(self, value: list[str] | str | None) -> None:
-        """Set stop sequences.
-
-        Synchronizes stop_sequences to ensure values set by CrewAgentExecutor
-        are properly sent to the Anthropic API.
-
-        Args:
-            value: Stop sequences as a list, single string, or None
-        """
-        if value is None:
-            self.stop_sequences = []
-        elif isinstance(value, str):
-            self.stop_sequences = [value]
-        elif isinstance(value, list):
-            self.stop_sequences = value
-        else:
-            self.stop_sequences = []
-
-    def _get_client_params(self) -> dict[str, Any]:
-        """Get client parameters."""
+    model: str = Field(
+        default="claude-3-5-sonnet-20241022",
+        description="Anthropic model name (e.g., 'claude-3-5-sonnet-20241022')",
+    )
+    max_tokens: int = Field(
+        default=4096,
+        description="Maximum number of allowed tokens in response.",
+    )
+    top_p: float | None = Field(
+        default=None,
+        description="Nucleus sampling parameter.",
+    )
+    _client: Anthropic = PrivateAttr(
+        default_factory=Anthropic,
+    )
+
+    @model_validator(mode="after")
+    def initialize_client(self) -> Self:
+        """Initialize the Anthropic client after Pydantic validation.
+
+        This runs after all field validation is complete, ensuring that:
+        - All BaseLLM fields are set (model, temperature, stop_sequences, etc.)
+        - Field validators have run (stop_sequences is normalized to set[str])
+        - API key and other configuration is ready
+        """
         if self.api_key is None:
             self.api_key = os.getenv("ANTHROPIC_API_KEY")
         if self.api_key is None:
             raise ValueError("ANTHROPIC_API_KEY is required")

-        client_params = {
-            "api_key": self.api_key,
-            "base_url": self.base_url,
-            "timeout": self.timeout,
-            "max_retries": self.max_retries,
-        }
+        params = self.model_dump(
+            include={"api_key", "base_url", "timeout", "max_retries"},
+            exclude_none=True,
+        )

         if self.interceptor:
             transport = HTTPTransport(interceptor=self.interceptor)
             http_client = httpx.Client(transport=transport)
-            client_params["http_client"] = http_client  # type: ignore[assignment]
+            params["http_client"] = http_client

         if self.client_params:
-            client_params.update(self.client_params)
-
-        return client_params
+            params.update(self.client_params)
+
+        self._client = Anthropic(**params)
+        return self
+
+    @computed_field  # type: ignore[prop-decorator]
+    @property
+    def is_claude_3(self) -> bool:
+        """Check if the model is Claude 3 or higher."""
+        return "claude-3" in self.model.lower()
+
+    @computed_field  # type: ignore[prop-decorator]
+    @property
+    def supports_tools(self) -> bool:
+        """Check if the model supports tool use."""
+        return self.is_claude_3
+
+    @computed_field  # type: ignore[prop-decorator]
+    @property
+    def supports_function_calling(self) -> bool:
+        """Check if the model supports function calling."""
+        return self.supports_tools

     def call(
         self,
@@ -149,8 +125,8 @@ class AnthropicCompletion(BaseLLM):
         tools: list[dict[str, Any]] | None = None,
         callbacks: list[Any] | None = None,
         available_functions: dict[str, Any] | None = None,
-        from_task: Any | None = None,
-        from_agent: Any | None = None,
+        from_task: Task | None = None,
+        from_agent: Agent | None = None,
         response_model: type[BaseModel] | None = None,
     ) -> str | Any:
         """Call Anthropic messages API.
@@ -229,25 +205,21 @@ class AnthropicCompletion(BaseLLM):
         Returns:
             Parameters dictionary for Anthropic API
         """
-        params = {
-            "model": self.model,
-            "messages": messages,
-            "max_tokens": self.max_tokens,
-            "stream": self.stream,
-        }
+        params = self.model_dump(
+            include={
+                "model",
+                "max_tokens",
+                "stream",
+                "temperature",
+                "top_p",
+                "stop_sequences",
+            },
+        )
+        params["messages"] = messages

         # Add system message if present
         if system_message:
             params["system"] = system_message

-        # Add optional parameters if set
-        if self.temperature is not None:
-            params["temperature"] = self.temperature
-        if self.top_p is not None:
-            params["top_p"] = self.top_p
-        if self.stop_sequences:
-            params["stop_sequences"] = self.stop_sequences
-
         # Handle tools for Claude 3+
         if tools and self.supports_tools:
             params["tools"] = self._convert_tools_for_interference(tools)
@@ -266,8 +238,6 @@ class AnthropicCompletion(BaseLLM):
                 continue

             try:
-                from crewai.llms.providers.utils.common import safe_tool_conversion
-
                 name, description, parameters = safe_tool_conversion(tool, "Anthropic")
             except (ImportError, KeyError, ValueError) as e:
                 logging.error(f"Error converting tool to Anthropic format: {e}")
@@ -341,8 +311,8 @@ class AnthropicCompletion(BaseLLM):
         self,
         params: dict[str, Any],
         available_functions: dict[str, Any] | None = None,
-        from_task: Any | None = None,
-        from_agent: Any | None = None,
+        from_task: Task | None = None,
+        from_agent: Agent | None = None,
         response_model: type[BaseModel] | None = None,
     ) -> str | Any:
         """Handle non-streaming message completion."""
@@ -357,7 +327,7 @@ class AnthropicCompletion(BaseLLM):
             params["tool_choice"] = {"type": "tool", "name": "structured_output"}

         try:
-            response: Message = self.client.messages.create(**params)
+            response: Message = self._client.messages.create(**params)
         except Exception as e:
             if is_context_length_exceeded(e):
@@ -429,8 +399,8 @@ class AnthropicCompletion(BaseLLM):
         self,
         params: dict[str, Any],
         available_functions: dict[str, Any] | None = None,
-        from_task: Any | None = None,
-        from_agent: Any | None = None,
+        from_task: Task | None = None,
+        from_agent: Agent | None = None,
         response_model: type[BaseModel] | None = None,
     ) -> str:
         """Handle streaming message completion."""
@@ -451,7 +421,7 @@ class AnthropicCompletion(BaseLLM):
         stream_params = {k: v for k, v in params.items() if k != "stream"}

         # Make streaming API call
-        with self.client.messages.stream(**stream_params) as stream:
+        with self._client.messages.stream(**stream_params) as stream:
             for event in stream:
                 if hasattr(event, "delta") and hasattr(event.delta, "text"):
                     text_delta = event.delta.text
@@ -525,8 +495,8 @@ class AnthropicCompletion(BaseLLM):
         tool_uses: list[ToolUseBlock],
         params: dict[str, Any],
         available_functions: dict[str, Any],
-        from_task: Any | None = None,
-        from_agent: Any | None = None,
+        from_task: Task | None = None,
+        from_agent: Agent | None = None,
     ) -> str:
         """Handle the complete tool use conversation flow.
@@ -579,7 +549,7 @@ class AnthropicCompletion(BaseLLM):
         try:
             # Send tool results back to Claude for final response
-            final_response: Message = self.client.messages.create(**follow_up_params)
+            final_response: Message = self._client.messages.create(**follow_up_params)

             # Track token usage for follow-up call
             follow_up_usage = self._extract_anthropic_token_usage(final_response)
@@ -626,48 +596,24 @@ class AnthropicCompletion(BaseLLM):
                 return tool_results[0]["content"]
             raise e

-    def supports_function_calling(self) -> bool:
-        """Check if the model supports function calling."""
-        return self.supports_tools
-
-    def supports_stop_words(self) -> bool:
-        """Check if the model supports stop words."""
-        return True  # All Claude models support stop sequences
-
     def get_context_window_size(self) -> int:
         """Get the context window size for the model."""
-        from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO
-
-        # Context window sizes for Anthropic models
-        context_windows = {
-            "claude-3-5-sonnet": 200000,
-            "claude-3-5-haiku": 200000,
-            "claude-3-opus": 200000,
-            "claude-3-sonnet": 200000,
-            "claude-3-haiku": 200000,
-            "claude-3-7-sonnet": 200000,
-            "claude-2.1": 200000,
-            "claude-2": 100000,
-            "claude-instant": 100000,
-        }
-
         # Find the best match for the model name
-        for model_prefix, size in context_windows.items():
+        for model_prefix, size in ANTHROPIC_CONTEXT_WINDOWS.items():
             if self.model.startswith(model_prefix):
                 return int(size * CONTEXT_WINDOW_USAGE_RATIO)

         # Default context window size for Claude models
         return int(200000 * CONTEXT_WINDOW_USAGE_RATIO)

-    def _extract_anthropic_token_usage(self, response: Message) -> dict[str, Any]:
+    @staticmethod
+    def _extract_anthropic_token_usage(response: Message) -> dict[str, Any]:
         """Extract token usage from Anthropic response."""
-        if hasattr(response, "usage") and response.usage:
+        if response.usage:
             usage = response.usage
-            input_tokens = getattr(usage, "input_tokens", 0)
-            output_tokens = getattr(usage, "output_tokens", 0)
             return {
-                "input_tokens": input_tokens,
-                "output_tokens": output_tokens,
-                "total_tokens": input_tokens + output_tokens,
+                "input_tokens": usage.input_tokens,
+                "output_tokens": usage.output_tokens,
+                "total_tokens": usage.input_tokens + usage.output_tokens,
             }
         return {"total_tokens": 0}

View File

@@ -1,12 +1,14 @@
import logging from __future__ import annotations
import os
from typing import Any, cast
from pydantic import BaseModel import logging
from typing import TYPE_CHECKING, Any, cast
from pydantic import BaseModel, Field, PrivateAttr, computed_field, model_validator
from typing_extensions import Self
from crewai.events.types.llm_events import LLMCallType from crewai.events.types.llm_events import LLMCallType
from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
from crewai.llms.base_llm import BaseLLM from crewai.llms.base_llm import BaseLLM
from crewai.llms.hooks.base import BaseInterceptor
from crewai.utilities.agent_utils import is_context_length_exceeded from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.exceptions.context_window_exceeding_exception import ( from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededError, LLMContextLengthExceededError,
@@ -14,6 +16,11 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
from crewai.utilities.types import LLMMessage from crewai.utilities.types import LLMMessage
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.task import Task
try: try:
from google import genai # type: ignore[import-untyped] from google import genai # type: ignore[import-untyped]
from google.genai import types # type: ignore[import-untyped] from google.genai import types # type: ignore[import-untyped]
@@ -24,6 +31,27 @@ except ImportError:
) from None ) from None
GEMINI_CONTEXT_WINDOWS: dict[str, int] = {
"gemini-2.0-flash": 1048576, # 1M tokens
"gemini-2.0-flash-thinking": 32768,
"gemini-2.0-flash-lite": 1048576,
"gemini-2.5-flash": 1048576,
"gemini-2.5-pro": 1048576,
"gemini-1.5-pro": 2097152, # 2M tokens
"gemini-1.5-flash": 1048576,
"gemini-1.5-flash-8b": 1048576,
"gemini-1.0-pro": 32768,
"gemma-3-1b": 32000,
"gemma-3-4b": 128000,
"gemma-3-12b": 128000,
"gemma-3-27b": 128000,
}
# Context window validation constraints
MIN_CONTEXT_WINDOW: int = 1024
MAX_CONTEXT_WINDOW: int = 2097152
class GeminiCompletion(BaseLLM): class GeminiCompletion(BaseLLM):
"""Google Gemini native completion implementation. """Google Gemini native completion implementation.
@@ -31,78 +59,140 @@ class GeminiCompletion(BaseLLM):
offering native function calling, streaming support, and proper Gemini formatting. offering native function calling, streaming support, and proper Gemini formatting.
""" """
def __init__( model: str = Field(
self, default="gemini-2.0-flash-001",
model: str = "gemini-2.0-flash-001", description="Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro')",
api_key: str | None = None, )
project: str | None = None, project: str | None = Field(
location: str | None = None, default=None,
temperature: float | None = None, description="Google Cloud project ID (for Vertex AI)",
top_p: float | None = None, )
top_k: int | None = None, location: str = Field(
max_output_tokens: int | None = None, default="us-central1",
stop_sequences: list[str] | None = None, description="Google Cloud location (for Vertex AI)",
stream: bool = False, )
safety_settings: dict[str, Any] | None = None, top_p: float | None = Field(
client_params: dict[str, Any] | None = None, default=None,
interceptor: BaseInterceptor[Any, Any] | None = None, description="Nucleus sampling parameter",
**kwargs: Any, )
): top_k: int | None = Field(
"""Initialize Google Gemini chat completion client. default=None,
description="Top-k sampling parameter",
)
max_output_tokens: int | None = Field(
default=None,
description="Maximum tokens in response",
)
safety_settings: dict[str, Any] | None = Field(
default=None,
description="Safety filter settings",
)
_client: genai.Client = PrivateAttr( # type: ignore[no-any-unimported]
default_factory=genai.Client,
)
Args: @model_validator(mode="after")
model: Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro') def initialize_client(self) -> Self:
api_key: Google API key (defaults to GOOGLE_API_KEY or GEMINI_API_KEY env var) """Initialize the Anthropic client after Pydantic validation.
project: Google Cloud project ID (for Vertex AI)
location: Google Cloud location (for Vertex AI, defaults to 'us-central1') This runs after all field validation is complete, ensuring that:
temperature: Sampling temperature (0-2) - All BaseLLM fields are set (model, temperature, stop_sequences, etc.)
top_p: Nucleus sampling parameter - Field validators have run (stop_sequences is normalized to set[str])
top_k: Top-k sampling parameter - API key and other configuration is ready
max_output_tokens: Maximum tokens in response
stop_sequences: Stop sequences
stream: Enable streaming responses
safety_settings: Safety filter settings
client_params: Additional parameters to pass to the Google Gen AI Client constructor.
Supports parameters like http_options, credentials, debug_config, etc.
interceptor: HTTP interceptor (not yet supported for Gemini).
**kwargs: Additional parameters
""" """
if interceptor is not None: self._client = genai.Client(**self._get_client_params())
raise NotImplementedError( return self
"HTTP interceptors are not yet supported for Google Gemini provider. "
"Interceptors are currently supported for OpenAI and Anthropic providers only."
)
super().__init__( # def __init__(
model=model, temperature=temperature, stop=stop_sequences or [], **kwargs # self,
) # model: str = "gemini-2.0-flash-001",
# api_key: str | None = None,
# project: str | None = None,
# location: str | None = None,
# temperature: float | None = None,
# top_p: float | None = None,
# top_k: int | None = None,
# max_output_tokens: int | None = None,
# stop_sequences: list[str] | None = None,
# stream: bool = False,
# safety_settings: dict[str, Any] | None = None,
# client_params: dict[str, Any] | None = None,
# interceptor: BaseInterceptor[Any, Any] | None = None,
# **kwargs: Any,
# # ):
# """Initialize Google Gemini chat completion client.
#
# Args:
# model: Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro')
# api_key: Google API key (defaults to GOOGLE_API_KEY or GEMINI_API_KEY env var)
# project: Google Cloud project ID (for Vertex AI)
# location: Google Cloud location (for Vertex AI, defaults to 'us-central1')
# temperature: Sampling temperature (0-2)
# top_p: Nucleus sampling parameter
# top_k: Top-k sampling parameter
# max_output_tokens: Maximum tokens in response
# stop_sequences: Stop sequences
# stream: Enable streaming responses
# safety_settings: Safety filter settings
# client_params: Additional parameters to pass to the Google Gen AI Client constructor.
# Supports parameters like http_options, credentials, debug_config, etc.
# interceptor: HTTP interceptor (not yet supported for Gemini).
# **kwargs: Additional parameters
# """
# if interceptor is not None:
# raise NotImplementedError(
# "HTTP interceptors are not yet supported for Google Gemini provider. "
# "Interceptors are currently supported for OpenAI and Anthropic providers only."
# )
#
# super().__init__(
# model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
# )
#
# # Store client params for later use
# self.client_params = client_params or {}
#
# # Get API configuration with environment variable fallbacks
# self.api_key = (
# api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
# )
# self.project = project or os.getenv("GOOGLE_CLOUD_PROJECT")
# self.location = location or os.getenv("GOOGLE_CLOUD_LOCATION") or "us-central1"
#
# use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"
#
# self.client = self._initialize_client(use_vertexai)
#
# # Store completion parameters
# self.top_p = top_p
# self.top_k = top_k
# self.max_output_tokens = max_output_tokens
# self.stream = stream
# self.safety_settings = safety_settings or {}
# self.stop_sequences = stop_sequences or []
#
# # Model-specific settings
# self.is_gemini_2 = "gemini-2" in model.lower()
# self.is_gemini_1_5 = "gemini-1.5" in model.lower()
# self.supports_tools = self.is_gemini_1_5 or self.is_gemini_2
# Store client params for later use @computed_field # type: ignore[prop-decorator]
self.client_params = client_params or {} @property
def is_gemini_2(self) -> bool:
"""Check if the model is Gemini 2.x."""
return "gemini-2" in self.model.lower()
# Get API configuration with environment variable fallbacks @computed_field # type: ignore[prop-decorator]
self.api_key = ( @property
api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY") def is_gemini_1_5(self) -> bool:
) """Check if the model is Gemini 1.5.x."""
self.project = project or os.getenv("GOOGLE_CLOUD_PROJECT") return "gemini-1.5" in self.model.lower()
self.location = location or os.getenv("GOOGLE_CLOUD_LOCATION") or "us-central1"
use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true" @computed_field # type: ignore[prop-decorator]
@property
self.client = self._initialize_client(use_vertexai) def supports_tools(self) -> bool:
"""Check if the model supports tool/function calling."""
# Store completion parameters return self.is_gemini_1_5 or self.is_gemini_2
self.top_p = top_p
self.top_k = top_k
self.max_output_tokens = max_output_tokens
self.stream = stream
self.safety_settings = safety_settings or {}
self.stop_sequences = stop_sequences or []
# Model-specific settings
self.is_gemini_2 = "gemini-2" in model.lower()
self.is_gemini_1_5 = "gemini-1.5" in model.lower()
self.supports_tools = self.is_gemini_1_5 or self.is_gemini_2
@property @property
def stop(self) -> list[str]: def stop(self) -> list[str]:
@@ -142,6 +232,12 @@ class GeminiCompletion(BaseLLM):
if self.client_params: if self.client_params:
client_params.update(self.client_params) client_params.update(self.client_params)
if self.interceptor:
raise NotImplementedError(
"HTTP interceptors are not yet supported for Google Gemini provider. "
"Interceptors are currently supported for OpenAI and Anthropic providers only."
)
if use_vertexai or self.project: if use_vertexai or self.project:
client_params.update( client_params.update(
{ {
@@ -181,7 +277,7 @@ class GeminiCompletion(BaseLLM):
if ( if (
hasattr(self, "client") hasattr(self, "client")
and hasattr(self.client, "vertexai") and hasattr(self._client, "vertexai")
and self.client.vertexai and self.client.vertexai
): ):
# Vertex AI configuration # Vertex AI configuration
@@ -206,8 +302,8 @@ class GeminiCompletion(BaseLLM):
tools: list[dict[str, Any]] | None = None, tools: list[dict[str, Any]] | None = None,
callbacks: list[Any] | None = None, callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None, available_functions: dict[str, Any] | None = None,
from_task: Any | None = None, from_task: Task | None = None,
from_agent: Any | None = None, from_agent: Agent | None = None,
response_model: type[BaseModel] | None = None, response_model: type[BaseModel] | None = None,
) -> str | Any: ) -> str | Any:
"""Call Google Gemini generate content API. """Call Google Gemini generate content API.
@@ -294,7 +390,16 @@ class GeminiCompletion(BaseLLM):
GenerateContentConfig object for Gemini API GenerateContentConfig object for Gemini API
""" """
self.tools = tools self.tools = tools
config_params = {} config_params = self.model_dump(
include={
"temperature",
"top_p",
"top_k",
"max_output_tokens",
"stop_sequences",
"safety_settings",
}
)
# Add system instruction if present # Add system instruction if present
if system_instruction: if system_instruction:
@@ -304,18 +409,6 @@ class GeminiCompletion(BaseLLM):
) )
config_params["system_instruction"] = system_content config_params["system_instruction"] = system_content
# Add generation config parameters
if self.temperature is not None:
config_params["temperature"] = self.temperature
if self.top_p is not None:
config_params["top_p"] = self.top_p
if self.top_k is not None:
config_params["top_k"] = self.top_k
if self.max_output_tokens is not None:
config_params["max_output_tokens"] = self.max_output_tokens
if self.stop_sequences:
config_params["stop_sequences"] = self.stop_sequences
if response_model: if response_model:
config_params["response_mime_type"] = "application/json" config_params["response_mime_type"] = "application/json"
config_params["response_schema"] = response_model.model_json_schema() config_params["response_schema"] = response_model.model_json_schema()
@@ -324,9 +417,6 @@ class GeminiCompletion(BaseLLM):
if tools and self.supports_tools: if tools and self.supports_tools:
config_params["tools"] = self._convert_tools_for_interference(tools) config_params["tools"] = self._convert_tools_for_interference(tools)
if self.safety_settings:
config_params["safety_settings"] = self.safety_settings
return types.GenerateContentConfig(**config_params) return types.GenerateContentConfig(**config_params)
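Editor's note: the new `config_params = self.model_dump(include={...})` collapses the old chain of `if ... is not None` checks into one call. A small sketch of the pattern with a stand-in model (not the real `GeminiCompletion`); note that without `exclude_none=True` the dump also carries keys whose value is `None`, which the old per-field checks skipped:

```python
from typing import Any

from pydantic import BaseModel


class GenerationSettings(BaseModel):
    """Stand-in for the generation-related fields included in the dump."""

    temperature: float | None = None
    top_p: float | None = None
    top_k: int | None = None
    max_output_tokens: int | None = None
    stop_sequences: list[str] = []
    safety_settings: dict[str, Any] = {}


settings = GenerationSettings(temperature=0.2, max_output_tokens=1024)

# One call replaces the chain of `if self.x is not None: config[...] = self.x`.
config_params = settings.model_dump(
    include={"temperature", "top_p", "top_k", "max_output_tokens", "stop_sequences"},
)

# Adding exclude_none=True drops unset keys entirely, reproducing the old behaviour.
trimmed = settings.model_dump(
    include={"temperature", "max_output_tokens"},
    exclude_none=True,
)

print(config_params)  # {'temperature': 0.2, 'top_p': None, ...}
print(trimmed)        # {'temperature': 0.2, 'max_output_tokens': 1024}
```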
def _convert_tools_for_interference( # type: ignore[no-any-unimported] def _convert_tools_for_interference( # type: ignore[no-any-unimported]
@@ -404,8 +494,8 @@ class GeminiCompletion(BaseLLM):
system_instruction: str | None, system_instruction: str | None,
config: types.GenerateContentConfig, config: types.GenerateContentConfig,
available_functions: dict[str, Any] | None = None, available_functions: dict[str, Any] | None = None,
from_task: Any | None = None, from_task: Task | None = None,
from_agent: Any | None = None, from_agent: Agent | None = None,
response_model: type[BaseModel] | None = None, response_model: type[BaseModel] | None = None,
) -> str | Any: ) -> str | Any:
"""Handle non-streaming content generation.""" """Handle non-streaming content generation."""
@@ -416,7 +506,7 @@ class GeminiCompletion(BaseLLM):
} }
try: try:
response = self.client.models.generate_content(**api_params) response = self._client.models.generate_content(**api_params)
usage = self._extract_token_usage(response) usage = self._extract_token_usage(response)
except Exception as e: except Exception as e:
@@ -470,8 +560,8 @@ class GeminiCompletion(BaseLLM):
contents: list[types.Content], contents: list[types.Content],
config: types.GenerateContentConfig, config: types.GenerateContentConfig,
available_functions: dict[str, Any] | None = None, available_functions: dict[str, Any] | None = None,
from_task: Any | None = None, from_task: Task | None = None,
from_agent: Any | None = None, from_agent: Agent | None = None,
response_model: type[BaseModel] | None = None, response_model: type[BaseModel] | None = None,
) -> str: ) -> str:
"""Handle streaming content generation.""" """Handle streaming content generation."""
@@ -484,7 +574,7 @@ class GeminiCompletion(BaseLLM):
"config": config, "config": config,
} }
for chunk in self.client.models.generate_content_stream(**api_params): for chunk in self._client.models.generate_content_stream(**api_params):
if hasattr(chunk, "text") and chunk.text: if hasattr(chunk, "text") and chunk.text:
full_response += chunk.text full_response += chunk.text
self._emit_stream_chunk_event( self._emit_stream_chunk_event(
@@ -537,52 +627,30 @@ class GeminiCompletion(BaseLLM):
return full_response return full_response
@computed_field # type: ignore[prop-decorator]
@property
def supports_function_calling(self) -> bool: def supports_function_calling(self) -> bool:
"""Check if the model supports function calling.""" """Check if the model supports function calling."""
return self.supports_tools return self.supports_tools
def supports_stop_words(self) -> bool:
"""Check if the model supports stop words."""
return True
def get_context_window_size(self) -> int: def get_context_window_size(self) -> int:
"""Get the context window size for the model.""" """Get the context window size for the model."""
from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
min_context = 1024
max_context = 2097152
for key, value in LLM_CONTEXT_WINDOW_SIZES.items(): for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
if value < min_context or value > max_context: if value < MIN_CONTEXT_WINDOW or value > MAX_CONTEXT_WINDOW:
raise ValueError( raise ValueError(
f"Context window for {key} must be between {min_context} and {max_context}" f"Context window for {key} must be between {MIN_CONTEXT_WINDOW} and {MAX_CONTEXT_WINDOW}"
) )
context_windows = {
"gemini-2.0-flash": 1048576, # 1M tokens
"gemini-2.0-flash-thinking": 32768,
"gemini-2.0-flash-lite": 1048576,
"gemini-2.5-flash": 1048576,
"gemini-2.5-pro": 1048576,
"gemini-1.5-pro": 2097152, # 2M tokens
"gemini-1.5-flash": 1048576,
"gemini-1.5-flash-8b": 1048576,
"gemini-1.0-pro": 32768,
"gemma-3-1b": 32000,
"gemma-3-4b": 128000,
"gemma-3-12b": 128000,
"gemma-3-27b": 128000,
}
# Find the best match for the model name # Find the best match for the model name
for model_prefix, size in context_windows.items(): for model_prefix, size in GEMINI_CONTEXT_WINDOWS.items():
if self.model.startswith(model_prefix): if self.model.startswith(model_prefix):
return int(size * CONTEXT_WINDOW_USAGE_RATIO) return int(size * CONTEXT_WINDOW_USAGE_RATIO)
# Default context window size for Gemini models # Default context window size for Gemini models
return int(1048576 * CONTEXT_WINDOW_USAGE_RATIO) # 1M tokens return int(1048576 * CONTEXT_WINDOW_USAGE_RATIO) # 1M tokens
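Editor's note: `get_context_window_size` now validates the shared table against `MIN_CONTEXT_WINDOW`/`MAX_CONTEXT_WINDOW`, prefix-matches the model name against a module-level table, and scales the result by `CONTEXT_WINDOW_USAGE_RATIO`. A self-contained sketch of that lookup, simplified to validate its own local table and using an assumed ratio value (the real table and ratio live in the crewai modules):

```python
from typing import Final

CONTEXT_WINDOW_USAGE_RATIO: Final[float] = 0.85  # assumed value for this sketch
MIN_CONTEXT_WINDOW: Final[int] = 1024
MAX_CONTEXT_WINDOW: Final[int] = 2097152

GEMINI_CONTEXT_WINDOWS: Final[dict[str, int]] = {
    # Illustrative subset; the real module-level table covers more models.
    "gemini-1.5-pro": 2097152,
    "gemini-2.0-flash": 1048576,
    "gemini-2.5-flash": 1048576,
}


def get_context_window_size(model: str) -> int:
    """Prefix-match the model name and reserve headroom via the usage ratio."""
    for key, value in GEMINI_CONTEXT_WINDOWS.items():
        if value < MIN_CONTEXT_WINDOW or value > MAX_CONTEXT_WINDOW:
            raise ValueError(
                f"Context window for {key} must be between "
                f"{MIN_CONTEXT_WINDOW} and {MAX_CONTEXT_WINDOW}"
            )
    for model_prefix, size in GEMINI_CONTEXT_WINDOWS.items():
        if model.startswith(model_prefix):
            return int(size * CONTEXT_WINDOW_USAGE_RATIO)
    return int(1048576 * CONTEXT_WINDOW_USAGE_RATIO)  # default: 1M tokens


print(get_context_window_size("gemini-1.5-pro-002"))  # 2M-class window, scaled
```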
def _extract_token_usage(self, response: dict[str, Any]) -> dict[str, Any]: @staticmethod
def _extract_token_usage(response: dict[str, Any]) -> dict[str, Any]:
"""Extract token usage from Gemini response.""" """Extract token usage from Gemini response."""
if hasattr(response, "usage_metadata"): if hasattr(response, "usage_metadata"):
usage = response.usage_metadata usage = response.usage_metadata
@@ -594,8 +662,8 @@ class GeminiCompletion(BaseLLM):
} }
return {"total_tokens": 0} return {"total_tokens": 0}
@staticmethod
def _convert_contents_to_dict( # type: ignore[no-any-unimported] def _convert_contents_to_dict( # type: ignore[no-any-unimported]
self,
contents: list[types.Content], contents: list[types.Content],
) -> list[dict[str, str]]: ) -> list[dict[str, str]]:
"""Convert contents to dict format.""" """Convert contents to dict format."""

View File

@@ -4,16 +4,23 @@ from collections.abc import Iterator
import json import json
import logging import logging
import os import os
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any, Final
import httpx import httpx
from openai import APIConnectionError, NotFoundError, OpenAI from openai import APIConnectionError, NotFoundError, OpenAI
from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion import Choice from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_chunk import ChoiceDelta from openai.types.chat.chat_completion_chunk import ChoiceDelta
from pydantic import BaseModel from pydantic import (
BaseModel,
Field,
PrivateAttr,
model_validator,
)
from typing_extensions import Self
from crewai.events.types.llm_events import LLMCallType from crewai.events.types.llm_events import LLMCallType
from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
from crewai.llms.base_llm import BaseLLM from crewai.llms.base_llm import BaseLLM
from crewai.llms.hooks.transport import HTTPTransport from crewai.llms.hooks.transport import HTTPTransport
from crewai.utilities.agent_utils import is_context_length_exceeded from crewai.utilities.agent_utils import is_context_length_exceeded
@@ -25,11 +32,28 @@ from crewai.utilities.types import LLMMessage
if TYPE_CHECKING: if TYPE_CHECKING:
from crewai.agent.core import Agent from crewai.agent.core import Agent
from crewai.llms.hooks.base import BaseInterceptor
from crewai.task import Task from crewai.task import Task
from crewai.tools.base_tool import BaseTool from crewai.tools.base_tool import BaseTool
OPENAI_CONTEXT_WINDOWS: dict[str, int] = {
"gpt-4": 8192,
"gpt-4o": 128000,
"gpt-4o-mini": 200000,
"gpt-4-turbo": 128000,
"gpt-4.1": 1047576,
"gpt-4.1-mini-2025-04-14": 1047576,
"gpt-4.1-nano-2025-04-14": 1047576,
"o1-preview": 128000,
"o1-mini": 128000,
"o3-mini": 200000,
"o4-mini": 200000,
}
MIN_CONTEXT_WINDOW: Final[int] = 1024
MAX_CONTEXT_WINDOW: Final[int] = 2097152
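Editor's note on the new module-level `OPENAI_CONTEXT_WINDOWS` table: the lookup in `get_context_window_size` returns on the first `startswith` match in dict order, so a model id such as `gpt-4o` matches the shorter `gpt-4` prefix before the `gpt-4o` entry is reached. If the intent is to prefer the most specific entry, a longest-prefix variant would look like the sketch below; this is purely illustrative and not part of the diff:

```python
OPENAI_CONTEXT_WINDOWS = {
    "gpt-4": 8192,
    "gpt-4o": 128000,
    "gpt-4o-mini": 200000,
}


def context_window_for(model: str, default: int = 8192) -> int:
    """Match the most specific (longest) prefix instead of the first one."""
    best: int | None = None
    best_len = -1
    for prefix, size in OPENAI_CONTEXT_WINDOWS.items():
        if model.startswith(prefix) and len(prefix) > best_len:
            best, best_len = size, len(prefix)
    return best if best is not None else default


assert context_window_for("gpt-4o-mini-2024-07-18") == 200000
assert context_window_for("gpt-4-0613") == 8192
```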
class OpenAICompletion(BaseLLM): class OpenAICompletion(BaseLLM):
"""OpenAI native completion implementation. """OpenAI native completion implementation.
@@ -37,112 +61,125 @@ class OpenAICompletion(BaseLLM):
offering native structured outputs, function calling, and streaming support. offering native structured outputs, function calling, and streaming support.
""" """
def __init__( model: str = Field(
self, default="gpt-4o",
model: str = "gpt-4o", description="OpenAI model name (e.g., 'gpt-4o')",
api_key: str | None = None, )
base_url: str | None = None, organization: str | None = Field(
organization: str | None = None, default=None,
project: str | None = None, description="Name of the OpenAI organization",
timeout: float | None = None, )
max_retries: int = 2, project: str | None = Field(
default_headers: dict[str, str] | None = None, default=None,
default_query: dict[str, Any] | None = None, description="Name of the OpenAI project",
client_params: dict[str, Any] | None = None, )
temperature: float | None = None, api_base: str | None = Field(
top_p: float | None = None, default=os.getenv("OPENAI_BASE_URL"),
frequency_penalty: float | None = None, description="Base URL for OpenAI API",
presence_penalty: float | None = None, )
max_tokens: int | None = None, default_headers: dict[str, str] | None = Field(
max_completion_tokens: int | None = None, default=None,
seed: int | None = None, description="Default headers for OpenAI API requests",
stream: bool = False, )
response_format: dict[str, Any] | type[BaseModel] | None = None, default_query: dict[str, Any] | None = Field(
logprobs: bool | None = None, default=None,
top_logprobs: int | None = None, description="Default query parameters for OpenAI API requests",
reasoning_effort: str | None = None, )
provider: str | None = None, top_p: float | None = Field(
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None, default=None,
**kwargs: Any, description="Top-p sampling parameter",
) -> None: )
"""Initialize OpenAI chat completion client.""" frequency_penalty: float | None = Field(
default=None,
description="Frequency penalty parameter",
)
presence_penalty: float | None = Field(
default=None,
description="Presence penalty parameter",
)
max_completion_tokens: int | None = Field(
default=None,
description="Maximum tokens for completion",
)
seed: int | None = Field(
default=None,
description="Random seed for reproducibility",
)
response_format: dict[str, Any] | type[BaseModel] | None = Field(
default=None,
description="Response format for structured output",
)
logprobs: bool | None = Field(
default=None,
description="Whether to include log probabilities",
)
top_logprobs: int | None = Field(
default=None,
description="Number of top log probabilities to return",
)
reasoning_effort: str | None = Field(
default=None,
description="Reasoning effort level for o1 models",
)
supports_function_calling: bool = Field(
default=True,
description="Whether the model supports function calling",
)
is_o1_model: bool = Field(
default=False,
description="Whether the model is an o1 model",
)
is_gpt4_model: bool = Field(
default=False,
description="Whether the model is a GPT-4 model",
)
_client: OpenAI = PrivateAttr(
default_factory=OpenAI,
)
if provider is None: @model_validator(mode="after")
provider = kwargs.pop("provider", "openai") def initialize_client(self) -> Self:
"""Initialize the Anthropic client after Pydantic validation.
self.interceptor = interceptor
# Client configuration attributes
self.organization = organization
self.project = project
self.max_retries = max_retries
self.default_headers = default_headers
self.default_query = default_query
self.client_params = client_params
self.timeout = timeout
self.base_url = base_url
self.api_base = kwargs.pop("api_base", None)
super().__init__(
model=model,
temperature=temperature,
api_key=api_key or os.getenv("OPENAI_API_KEY"),
base_url=base_url,
timeout=timeout,
provider=provider,
**kwargs,
)
client_config = self._get_client_params()
if self.interceptor:
transport = HTTPTransport(interceptor=self.interceptor)
http_client = httpx.Client(transport=transport)
client_config["http_client"] = http_client
self.client = OpenAI(**client_config)
# Completion parameters
self.top_p = top_p
self.frequency_penalty = frequency_penalty
self.presence_penalty = presence_penalty
self.max_tokens = max_tokens
self.max_completion_tokens = max_completion_tokens
self.seed = seed
self.stream = stream
self.response_format = response_format
self.logprobs = logprobs
self.top_logprobs = top_logprobs
self.reasoning_effort = reasoning_effort
self.is_o1_model = "o1" in model.lower()
self.is_gpt4_model = "gpt-4" in model.lower()
def _get_client_params(self) -> dict[str, Any]:
"""Get OpenAI client parameters."""
This runs after all field validation is complete, ensuring that:
- All BaseLLM fields are set (model, temperature, stop_sequences, etc.)
- Field validators have run (stop_sequences is normalized to set[str])
- API key and other configuration is ready
"""
if self.api_key is None: if self.api_key is None:
self.api_key = os.getenv("OPENAI_API_KEY") self.api_key = os.getenv("OPENAI_API_KEY")
if self.api_key is None: if self.api_key is None:
raise ValueError("OPENAI_API_KEY is required") raise ValueError("OPENAI_API_KEY is required")
base_params = { self.is_o1_model = "o1" in self.model.lower()
"api_key": self.api_key, self.supports_function_calling = not self.is_o1_model
"organization": self.organization, self.is_gpt4_model = "gpt-4" in self.model.lower()
"project": self.project, self.supports_stop_words = not self.is_o1_model
"base_url": self.base_url
or self.api_base
or os.getenv("OPENAI_BASE_URL")
or None,
"timeout": self.timeout,
"max_retries": self.max_retries,
"default_headers": self.default_headers,
"default_query": self.default_query,
}
client_params = {k: v for k, v in base_params.items() if v is not None} params = self.model_dump(
include={
"api_key",
"organization",
"project",
"base_url",
"timeout",
"max_retries",
"default_headers",
"default_query",
},
exclude_none=True,
)
if self.interceptor:
transport = HTTPTransport(interceptor=self.interceptor)
http_client = httpx.Client(transport=transport)
params["http_client"] = http_client
if self.client_params: if self.client_params:
client_params.update(self.client_params) params.update(self.client_params)
return client_params self._client = OpenAI(**params)
return self
def call( def call(
self, self,
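Editor's note: in the hunk above the OpenAI client is no longer built inside `__init__`. Configuration becomes declared `Field`s, the SDK client lives in a `PrivateAttr`, and a `@model_validator(mode="after")` assembles the client once all fields are validated. A compact sketch of that lifecycle with a dummy client standing in for `openai.OpenAI`; names other than the Pydantic APIs are invented for illustration:

```python
import os
from typing import Any

from pydantic import BaseModel, Field, PrivateAttr, model_validator
from typing_extensions import Self


class DummyClient:
    """Stand-in for openai.OpenAI so the sketch runs without the SDK."""

    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs


class CompletionConfig(BaseModel):
    model: str = Field(default="gpt-4o", description="Model name")
    api_key: str | None = Field(default=None, description="API key")
    organization: str | None = Field(default=None, description="Organization id")
    timeout: float | None = Field(default=None, description="Request timeout")

    _client: DummyClient = PrivateAttr(default_factory=DummyClient)

    @model_validator(mode="after")
    def initialize_client(self) -> Self:
        """Runs after field validation, so every declared field is usable here."""
        if self.api_key is None:
            self.api_key = os.getenv("OPENAI_API_KEY", "sk-placeholder")
        params = self.model_dump(
            include={"api_key", "organization", "timeout"},
            exclude_none=True,
        )
        self._client = DummyClient(**params)
        return self


cfg = CompletionConfig(model="gpt-4o-mini", timeout=30.0)
print(cfg._client.kwargs)  # {'api_key': ..., 'timeout': 30.0}
```

Keeping the client in a `PrivateAttr` keeps it out of `model_dump()`/serialization while still making it available to every method as `self._client`.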
@@ -213,38 +250,26 @@ class OpenAICompletion(BaseLLM):
self, messages: list[LLMMessage], tools: list[dict[str, BaseTool]] | None = None self, messages: list[LLMMessage], tools: list[dict[str, BaseTool]] | None = None
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Prepare parameters for OpenAI chat completion.""" """Prepare parameters for OpenAI chat completion."""
params: dict[str, Any] = { params = self.model_dump(
"model": self.model, include={
"messages": messages, "model",
} "stream",
if self.stream: "temperature",
params["stream"] = self.stream "top_p",
"frequency_penalty",
"presence_penalty",
"max_completion_tokens",
"max_tokens",
"seed",
"logprobs",
"top_logprobs",
"reasoning_effort",
},
exclude_none=True,
)
params["messages"] = messages
params.update(self.additional_params) params.update(self.additional_params)
if self.temperature is not None:
params["temperature"] = self.temperature
if self.top_p is not None:
params["top_p"] = self.top_p
if self.frequency_penalty is not None:
params["frequency_penalty"] = self.frequency_penalty
if self.presence_penalty is not None:
params["presence_penalty"] = self.presence_penalty
if self.max_completion_tokens is not None:
params["max_completion_tokens"] = self.max_completion_tokens
elif self.max_tokens is not None:
params["max_tokens"] = self.max_tokens
if self.seed is not None:
params["seed"] = self.seed
if self.logprobs is not None:
params["logprobs"] = self.logprobs
if self.top_logprobs is not None:
params["top_logprobs"] = self.top_logprobs
# Handle o1 model specific parameters
if self.is_o1_model and self.reasoning_effort:
params["reasoning_effort"] = self.reasoning_effort
if tools: if tools:
params["tools"] = self._convert_tools_for_interference(tools) params["tools"] = self._convert_tools_for_interference(tools)
params["tool_choice"] = "auto" params["tool_choice"] = "auto"
@@ -296,14 +321,14 @@ class OpenAICompletion(BaseLLM):
self, self,
params: dict[str, Any], params: dict[str, Any],
available_functions: dict[str, Any] | None = None, available_functions: dict[str, Any] | None = None,
from_task: Any | None = None, from_task: Task | None = None,
from_agent: Any | None = None, from_agent: Agent | None = None,
response_model: type[BaseModel] | None = None, response_model: type[BaseModel] | None = None,
) -> str | Any: ) -> str | Any:
"""Handle non-streaming chat completion.""" """Handle non-streaming chat completion."""
try: try:
if response_model: if response_model:
parsed_response = self.client.beta.chat.completions.parse( parsed_response = self._client.beta.chat.completions.parse(
**params, **params,
response_format=response_model, response_format=response_model,
) )
@@ -327,7 +352,7 @@ class OpenAICompletion(BaseLLM):
) )
return structured_json return structured_json
response: ChatCompletion = self.client.chat.completions.create(**params) response: ChatCompletion = self._client.chat.completions.create(**params)
usage = self._extract_openai_token_usage(response) usage = self._extract_openai_token_usage(response)
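Editor's note: for structured output, the non-streaming path now calls `self._client.beta.chat.completions.parse(..., response_format=response_model)` and otherwise falls back to a plain `chat.completions.create`. A hedged sketch of those two branches, assuming a recent `openai` SDK that exposes `.beta.chat.completions.parse`; error handling and event emission from the real method are omitted:

```python
from typing import Any

from openai import OpenAI
from pydantic import BaseModel


class Answer(BaseModel):
    title: str
    score: int


def call_openai(params: dict[str, Any],
                response_model: type[BaseModel] | None = None) -> str:
    """Sketch of the two branches in the non-streaming path."""
    client = OpenAI()  # reads OPENAI_API_KEY from the environment

    if response_model is not None:
        parsed = client.beta.chat.completions.parse(
            **params,
            response_format=response_model,
        )
        message = parsed.choices[0].message
        # The SDK exposes the validated object on .parsed; fall back to raw text.
        if message.parsed is not None:
            return message.parsed.model_dump_json()
        return message.content or ""

    response = client.chat.completions.create(**params)
    return response.choices[0].message.content or ""
```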
@@ -419,8 +444,8 @@ class OpenAICompletion(BaseLLM):
self, self,
params: dict[str, Any], params: dict[str, Any],
available_functions: dict[str, Any] | None = None, available_functions: dict[str, Any] | None = None,
from_task: Any | None = None, from_task: Task | None = None,
from_agent: Any | None = None, from_agent: Agent | None = None,
response_model: type[BaseModel] | None = None, response_model: type[BaseModel] | None = None,
) -> str: ) -> str:
"""Handle streaming chat completion.""" """Handle streaming chat completion."""
@@ -429,7 +454,7 @@ class OpenAICompletion(BaseLLM):
if response_model: if response_model:
completion_stream: Iterator[ChatCompletionChunk] = ( completion_stream: Iterator[ChatCompletionChunk] = (
self.client.chat.completions.create(**params) self._client.chat.completions.create(**params)
) )
accumulated_content = "" accumulated_content = ""
@@ -472,7 +497,7 @@ class OpenAICompletion(BaseLLM):
) )
return accumulated_content return accumulated_content
stream: Iterator[ChatCompletionChunk] = self.client.chat.completions.create( stream: Iterator[ChatCompletionChunk] = self._client.chat.completions.create(
**params **params
) )
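Editor's note: the streaming branch keeps the same chunk-accumulation loop and only swaps `self.client` for the private `self._client`. A condensed sketch of that loop against the OpenAI SDK; the per-chunk event emission from the diff is replaced by a comment, and `params` is assumed not to already contain `stream`:

```python
from collections.abc import Iterator
from typing import Any

from openai import OpenAI
from openai.types.chat import ChatCompletionChunk


def stream_completion(params: dict[str, Any]) -> str:
    """Accumulate streamed deltas into one string."""
    client = OpenAI()
    stream: Iterator[ChatCompletionChunk] = client.chat.completions.create(
        **params, stream=True
    )
    full_response = ""
    for chunk in stream:
        delta = chunk.choices[0].delta if chunk.choices else None
        if delta and delta.content:
            full_response += delta.content
            # The real implementation emits a stream-chunk event here.
    return full_response
```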
@@ -550,58 +575,31 @@ class OpenAICompletion(BaseLLM):
return full_response return full_response
def supports_function_calling(self) -> bool:
"""Check if the model supports function calling."""
return not self.is_o1_model
def supports_stop_words(self) -> bool:
"""Check if the model supports stop words."""
return not self.is_o1_model
def get_context_window_size(self) -> int: def get_context_window_size(self) -> int:
"""Get the context window size for the model.""" """Get the context window size for the model."""
from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
min_context = 1024
max_context = 2097152
for key, value in LLM_CONTEXT_WINDOW_SIZES.items(): for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
if value < min_context or value > max_context: if value < MIN_CONTEXT_WINDOW or value > MAX_CONTEXT_WINDOW:
raise ValueError( raise ValueError(
f"Context window for {key} must be between {min_context} and {max_context}" f"Context window for {key} must be between {MIN_CONTEXT_WINDOW} and {MAX_CONTEXT_WINDOW}"
) )
# Context window sizes for OpenAI models
context_windows = {
"gpt-4": 8192,
"gpt-4o": 128000,
"gpt-4o-mini": 200000,
"gpt-4-turbo": 128000,
"gpt-4.1": 1047576,
"gpt-4.1-mini-2025-04-14": 1047576,
"gpt-4.1-nano-2025-04-14": 1047576,
"o1-preview": 128000,
"o1-mini": 128000,
"o3-mini": 200000,
"o4-mini": 200000,
}
# Find the best match for the model name # Find the best match for the model name
for model_prefix, size in context_windows.items(): for model_prefix, size in OPENAI_CONTEXT_WINDOWS.items():
if self.model.startswith(model_prefix): if self.model.startswith(model_prefix):
return int(size * CONTEXT_WINDOW_USAGE_RATIO) return int(size * CONTEXT_WINDOW_USAGE_RATIO)
# Default context window size # Default context window size
return int(8192 * CONTEXT_WINDOW_USAGE_RATIO) return int(8192 * CONTEXT_WINDOW_USAGE_RATIO)
def _extract_openai_token_usage(self, response: ChatCompletion) -> dict[str, Any]: @staticmethod
def _extract_openai_token_usage(response: ChatCompletion) -> dict[str, Any]:
"""Extract token usage from OpenAI ChatCompletion response.""" """Extract token usage from OpenAI ChatCompletion response."""
if hasattr(response, "usage") and response.usage: if response.usage:
usage = response.usage usage = response.usage
return { return {
"prompt_tokens": getattr(usage, "prompt_tokens", 0), "prompt_tokens": usage.prompt_tokens,
"completion_tokens": getattr(usage, "completion_tokens", 0), "completion_tokens": usage.completion_tokens,
"total_tokens": getattr(usage, "total_tokens", 0), "total_tokens": usage.total_tokens,
} }
return {"total_tokens": 0} return {"total_tokens": 0}