Compare commits

...

6 Commits

8 changed files with 719 additions and 716 deletions

View File

@@ -618,22 +618,22 @@ class Agent(BaseAgent):
response_template=self.response_template,
).task_execution()
stop_words = [self.i18n.slice("observation")]
stop_sequences = [self.i18n.slice("observation")]
if self.response_template:
stop_words.append(
stop_sequences.append(
self.response_template.split("{{ .Response }}")[1].strip()
)
self.agent_executor = CrewAgentExecutor(
llm=self.llm,
llm=self.llm, # type: ignore[arg-type]
task=task, # type: ignore[arg-type]
agent=self,
crew=self.crew,
tools=parsed_tools,
prompt=prompt,
original_tools=raw_tools,
stop_words=stop_words,
stop_sequences=stop_sequences,
max_iter=self.max_iter,
tools_handler=self.tools_handler,
tools_names=get_tool_names(parsed_tools),
@@ -974,7 +974,9 @@ class Agent(BaseAgent):
path = parsed.path.replace("/", "_").strip("_")
return f"{domain}_{path}" if path else domain
def _get_mcp_tool_schemas(self, server_params: dict) -> dict[str, dict]:
def _get_mcp_tool_schemas(
self, server_params: dict[str, Any]
) -> dict[str, dict[str, Any]] | Any:
"""Get tool schemas from MCP server for wrapper creation with caching."""
server_url = server_params["url"]
@@ -1006,7 +1008,7 @@ class Agent(BaseAgent):
async def _get_mcp_tool_schemas_async(
self, server_params: dict[str, Any]
) -> dict[str, dict]:
) -> dict[str, dict[str, Any]]:
"""Async implementation of MCP tool schema retrieval with timeouts and retries."""
server_url = server_params["url"]
return await self._retry_mcp_discovery(
@@ -1014,7 +1016,7 @@ class Agent(BaseAgent):
)
async def _retry_mcp_discovery(
self, operation_func, server_url: str
self, operation_func: Any, server_url: str
) -> dict[str, dict[str, Any]]:
"""Retry MCP discovery operation with exponential backoff, avoiding try-except in loop."""
last_error = None
@@ -1045,7 +1047,7 @@ class Agent(BaseAgent):
@staticmethod
async def _attempt_mcp_discovery(
operation_func, server_url: str
operation_func: Any, server_url: str
) -> tuple[dict[str, dict[str, Any]] | None, str, bool]:
"""Attempt single MCP discovery operation and return (result, error_message, should_retry)."""
try:
@@ -1149,13 +1151,13 @@ class Agent(BaseAgent):
Field(..., description=field_description),
)
else:
field_definitions[field_name] = (
field_definitions[field_name] = ( # type: ignore[assignment]
field_type | None,
Field(default=None, description=field_description),
)
model_name = f"{tool_name.replace('-', '_').replace(' ', '_')}Schema"
return create_model(model_name, **field_definitions)
return create_model(model_name, **field_definitions) # type: ignore[no-any-return,call-overload]
def _json_type_to_python(self, field_schema: dict[str, Any]) -> type:
"""Convert JSON Schema type to Python type.
@@ -1175,12 +1177,12 @@ class Agent(BaseAgent):
if "const" in option:
types.append(str)
else:
types.append(self._json_type_to_python(option))
types.append(self._json_type_to_python(option)) # type: ignore[arg-type]
unique_types = list(set(types))
if len(unique_types) > 1:
result = unique_types[0]
for t in unique_types[1:]:
result = result | t
result = result | t # type: ignore[assignment]
return result
return unique_types[0]
@@ -1193,10 +1195,10 @@ class Agent(BaseAgent):
"object": dict,
}
return type_mapping.get(json_type, Any)
return type_mapping.get(json_type, Any) # type: ignore[arg-type]
@staticmethod
def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict]:
def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict[str, Any]]:
"""Fetch MCP server configurations from CrewAI AMP API."""
# TODO: Implement AMP API call to "integrations/mcps" endpoint
# Should return list of server configs with URLs
@@ -1435,7 +1437,7 @@ class Agent(BaseAgent):
goal=self.goal,
backstory=self.backstory,
llm=self.llm,
tools=self.tools or [],
tools=self.tools,
max_iterations=self.max_iter,
max_execution_time=self.max_execution_time,
respect_context_window=self.respect_context_window,

View File

@@ -137,7 +137,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
default=False,
description="Enable agent to delegate and ask questions among each other.",
)
tools: list[BaseTool] | None = Field(
tools: list[BaseTool] = Field(
default_factory=list, description="Tools at agents' disposal"
)
max_iter: int = Field(

View File

@@ -73,7 +73,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
max_iter: int,
tools: list[CrewStructuredTool],
tools_names: str,
stop_words: list[str],
stop_sequences: list[str],
tools_description: str,
tools_handler: ToolsHandler,
step_callback: Any = None,
@@ -95,7 +95,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
max_iter: Maximum iterations.
tools: Available tools.
tools_names: Tool names string.
stop_words: Stop word list.
stop_sequences: Stop sequences list for halting generation.
tools_description: Tool descriptions.
tools_handler: Tool handler instance.
step_callback: Optional step callback.
@@ -114,7 +114,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self.prompt = prompt
self.tools = tools
self.tools_names = tools_names
self.stop = stop_words
self.max_iter = max_iter
self.callbacks = callbacks or []
self._printer: Printer = Printer()
@@ -131,15 +130,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self.iterations = 0
self.log_error_after = 3
if self.llm:
# This may be mutating the shared llm object and needs further evaluation
existing_stop = getattr(self.llm, "stop", [])
self.llm.stop = list(
set(
existing_stop + self.stop
if isinstance(existing_stop, list)
else self.stop
)
)
self.llm.stop_sequences.extend(stop_sequences)
@property
def use_stop_words(self) -> bool:
@@ -148,7 +139,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
Returns:
bool: True if tool should be used or not.
"""
return self.llm.supports_stop_words() if self.llm else False
return self.llm.supports_stop_words if self.llm else False
def invoke(self, inputs: dict[str, Any]) -> dict[str, Any]:
"""Execute the agent with given inputs.

View File

@@ -20,8 +20,7 @@ from typing import (
)
from dotenv import load_dotenv
import httpx
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self
from crewai.events.event_bus import crewai_event_bus
@@ -54,7 +53,6 @@ if TYPE_CHECKING:
from litellm.utils import supports_response_schema
from crewai.agent.core import Agent
from crewai.llms.hooks.base import BaseInterceptor
from crewai.task import Task
from crewai.tools.base_tool import BaseTool
from crewai.utilities.types import LLMMessage
@@ -320,7 +318,138 @@ class AccumulatedToolArgs(BaseModel):
class LLM(BaseLLM):
completion_cost: float | None = None
completion_cost: float | None = Field(
default=None, description="The completion cost of the LLM."
)
top_p: float | None = Field(
default=None, description="Sampling probability threshold."
)
n: int | None = Field(
default=None, description="Number of completions to generate."
)
max_completion_tokens: int | None = Field(
default=None,
description="Maximum number of tokens to generate in the completion.",
)
max_tokens: int | None = Field(
default=None,
description="Maximum number of tokens allowed in the prompt + completion.",
)
presence_penalty: float | None = Field(
default=None, description="Penalty on the presence penalty."
)
frequency_penalty: float | None = Field(
default=None, description="Penalty on the frequency penalty."
)
logit_bias: dict[int, float] | None = Field(
default=None,
description="Modifies the likelihood of specified tokens appearing in the completion.",
)
response_format: type[BaseModel] | None = Field(
default=None,
description="Pydantic model class for structured response parsing.",
)
seed: int | None = Field(
default=None,
description="Random seed for reproducibility.",
)
logprobs: int | None = Field(
default=None,
description="Number of top logprobs to return.",
)
top_logprobs: int | None = Field(
default=None,
description="Number of top logprobs to return.",
)
api_base: str | None = Field(
default=None,
description="Base URL for the API endpoint.",
)
api_version: str | None = Field(
default=None,
description="API version to use.",
)
callbacks: list[Any] = Field(
default_factory=list,
description="List of callback handlers for LLM events.",
)
reasoning_effort: Literal["none", "low", "medium", "high"] | None = Field(
default=None,
description="Level of reasoning effort for the LLM.",
)
context_window_size: int = Field(
default=0,
description="The context window size of the LLM.",
)
is_anthropic: bool = Field(
default=False,
description="Indicates if the model is from Anthropic provider.",
)
supports_function_calling: bool = Field(
default=False,
description="Indicates if the model supports function calling.",
)
supports_stop_words: bool = Field(
default=False,
description="Indicates if the model supports stop words.",
)
@model_validator(mode="after")
def initialize_client(self) -> Self:
self.is_anthropic = any(
prefix in self.model.lower() for prefix in ANTHROPIC_PREFIXES
)
try:
provider = self._get_custom_llm_provider()
self.supports_function_calling = litellm.utils.supports_function_calling(
self.model, custom_llm_provider=provider
)
except Exception as e:
logging.error(f"Failed to check function calling support: {e!s}")
self.supports_function_calling = False
try:
params = get_supported_openai_params(model=self.model)
self.supports_stop_words = params is not None and "stop" in params
except Exception as e:
logging.error(f"Failed to get supported params: {e!s}")
self.supports_stop_words = False
with suppress_warnings():
callback_types = [type(callback) for callback in self.callbacks]
for callback in litellm.success_callback[:]:
if type(callback) in callback_types:
litellm.success_callback.remove(callback)
for callback in litellm._async_success_callback[:]:
if type(callback) in callback_types:
litellm._async_success_callback.remove(callback)
litellm.callbacks = self.callbacks
with suppress_warnings():
success_callbacks_str = os.environ.get("LITELLM_SUCCESS_CALLBACKS", "")
success_callbacks: list[str | Callable[..., Any] | CustomLogger] = []
if success_callbacks_str:
success_callbacks = [
cb.strip() for cb in success_callbacks_str.split(",") if cb.strip()
]
failure_callbacks_str = os.environ.get("LITELLM_FAILURE_CALLBACKS", "")
if failure_callbacks_str:
failure_callbacks: list[str | Callable[..., Any] | CustomLogger] = [
cb.strip() for cb in failure_callbacks_str.split(",") if cb.strip()
]
litellm.success_callback = success_callbacks
litellm.failure_callback = failure_callbacks
return self
# @computed_field
# @property
# def is_anthropic(self) -> bool:
# """Determine if the model is from Anthropic provider."""
# anthropic_prefixes = ("anthropic/", "claude-", "claude/")
# return any(prefix in self.model.lower() for prefix in anthropic_prefixes)
def __new__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> LLM:
"""Factory method that routes to native SDK or falls back to LiteLLM."""
@@ -383,98 +512,6 @@ class LLM(BaseLLM):
return None
def __init__(
self,
model: str,
timeout: float | int | None = None,
temperature: float | None = None,
top_p: float | None = None,
n: int | None = None,
stop: str | list[str] | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | float | None = None,
presence_penalty: float | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[int, float] | None = None,
response_format: type[BaseModel] | None = None,
seed: int | None = None,
logprobs: int | None = None,
top_logprobs: int | None = None,
base_url: str | None = None,
api_base: str | None = None,
api_version: str | None = None,
api_key: str | None = None,
callbacks: list[Any] | None = None,
reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
stream: bool = False,
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
**kwargs: Any,
) -> None:
"""Initialize LLM instance.
Note: This __init__ method is only called for fallback instances.
Native provider instances handle their own initialization in their respective classes.
"""
super().__init__(
model=model,
temperature=temperature,
api_key=api_key,
base_url=base_url,
timeout=timeout,
**kwargs,
)
self.model = model
self.timeout = timeout
self.temperature = temperature
self.top_p = top_p
self.n = n
self.max_completion_tokens = max_completion_tokens
self.max_tokens = max_tokens
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self.logit_bias = logit_bias
self.response_format = response_format
self.seed = seed
self.logprobs = logprobs
self.top_logprobs = top_logprobs
self.base_url = base_url
self.api_base = api_base
self.api_version = api_version
self.api_key = api_key
self.callbacks = callbacks
self.context_window_size = 0
self.reasoning_effort = reasoning_effort
self.additional_params = kwargs
self.is_anthropic = self._is_anthropic_model(model)
self.stream = stream
self.interceptor = interceptor
litellm.drop_params = True
# Normalize self.stop to always be a list[str]
if stop is None:
self.stop: list[str] = []
elif isinstance(stop, str):
self.stop = [stop]
else:
self.stop = stop
self.set_callbacks(callbacks or [])
self.set_env_callbacks()
@staticmethod
def _is_anthropic_model(model: str) -> bool:
"""Determine if the model is from Anthropic provider.
Args:
model: The model identifier string.
Returns:
bool: True if the model is from Anthropic, False otherwise.
"""
anthropic_prefixes = ("anthropic/", "claude-", "claude/")
return any(prefix in model.lower() for prefix in anthropic_prefixes)
def _prepare_completion_params(
self,
messages: str | list[LLMMessage],
@@ -1188,8 +1225,6 @@ class LLM(BaseLLM):
message["role"] = msg_role
# --- 5) Set up callbacks if provided
with suppress_warnings():
if callbacks and len(callbacks) > 0:
self.set_callbacks(callbacks)
try:
# --- 6) Prepare parameters for the completion call
params = self._prepare_completion_params(messages, tools)
@@ -1378,24 +1413,6 @@ class LLM(BaseLLM):
"Please remove response_format or use a supported model."
)
def supports_function_calling(self) -> bool:
try:
provider = self._get_custom_llm_provider()
return litellm.utils.supports_function_calling(
self.model, custom_llm_provider=provider
)
except Exception as e:
logging.error(f"Failed to check function calling support: {e!s}")
return False
def supports_stop_words(self) -> bool:
try:
params = get_supported_openai_params(model=self.model)
return params is not None and "stop" in params
except Exception as e:
logging.error(f"Failed to get supported params: {e!s}")
return False
def get_context_window_size(self) -> int:
"""
Returns the context window size, using 75% of the maximum to avoid
@@ -1425,60 +1442,6 @@ class LLM(BaseLLM):
self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
return self.context_window_size
@staticmethod
def set_callbacks(callbacks: list[Any]) -> None:
"""
Attempt to keep a single set of callbacks in litellm by removing old
duplicates and adding new ones.
"""
with suppress_warnings():
callback_types = [type(callback) for callback in callbacks]
for callback in litellm.success_callback[:]:
if type(callback) in callback_types:
litellm.success_callback.remove(callback)
for callback in litellm._async_success_callback[:]:
if type(callback) in callback_types:
litellm._async_success_callback.remove(callback)
litellm.callbacks = callbacks
@staticmethod
def set_env_callbacks() -> None:
"""Sets the success and failure callbacks for the LiteLLM library from environment variables.
This method reads the `LITELLM_SUCCESS_CALLBACKS` and `LITELLM_FAILURE_CALLBACKS`
environment variables, which should contain comma-separated lists of callback names.
It then assigns these lists to `litellm.success_callback` and `litellm.failure_callback`,
respectively.
If the environment variables are not set or are empty, the corresponding callback lists
will be set to empty lists.
Examples:
LITELLM_SUCCESS_CALLBACKS="langfuse,langsmith"
LITELLM_FAILURE_CALLBACKS="langfuse"
This will set `litellm.success_callback` to ["langfuse", "langsmith"] and
`litellm.failure_callback` to ["langfuse"].
"""
with suppress_warnings():
success_callbacks_str = os.environ.get("LITELLM_SUCCESS_CALLBACKS", "")
success_callbacks: list[str | Callable[..., Any] | CustomLogger] = []
if success_callbacks_str:
success_callbacks = [
cb.strip() for cb in success_callbacks_str.split(",") if cb.strip()
]
failure_callbacks_str = os.environ.get("LITELLM_FAILURE_CALLBACKS", "")
if failure_callbacks_str:
failure_callbacks: list[str | Callable[..., Any] | CustomLogger] = [
cb.strip() for cb in failure_callbacks_str.split(",") if cb.strip()
]
litellm.success_callback = success_callbacks
litellm.failure_callback = failure_callbacks
def __copy__(self) -> LLM:
"""Create a shallow copy of the LLM instance."""
# Filter out parameters that are already explicitly passed to avoid conflicts
@@ -1539,7 +1502,7 @@ class LLM(BaseLLM):
**filtered_params,
)
def __deepcopy__(self, memo: dict[int, Any] | None) -> LLM:
def __deepcopy__(self, memo: dict[int, Any] | None) -> LLM: # type: ignore[override]
"""Create a deep copy of the LLM instance."""
import copy

View File

@@ -13,8 +13,9 @@ import logging
import re
from typing import TYPE_CHECKING, Any, Final
from pydantic import BaseModel
from pydantic import AliasChoices, BaseModel, Field, PrivateAttr, field_validator
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.llm_events import (
LLMCallCompletedEvent,
@@ -28,6 +29,7 @@ from crewai.events.types.tool_usage_events import (
ToolUsageFinishedEvent,
ToolUsageStartedEvent,
)
from crewai.llms.hooks import BaseInterceptor
from crewai.types.usage_metrics import UsageMetrics
@@ -43,7 +45,7 @@ DEFAULT_SUPPORTS_STOP_WORDS: Final[bool] = True
_JSON_EXTRACTION_PATTERN: Final[re.Pattern[str]] = re.compile(r"\{.*}", re.DOTALL)
class BaseLLM(ABC):
class BaseLLM(BaseModel, ABC):
"""Abstract base class for LLM implementations.
This class defines the interface that all LLM implementations must follow.
@@ -55,70 +57,105 @@ class BaseLLM(ABC):
implement proper validation for input parameters and provide clear error
messages when things go wrong.
Attributes:
model: The model identifier/name.
temperature: Optional temperature setting for response generation.
stop: A list of stop sequences that the LLM should use to stop generation.
additional_params: Additional provider-specific parameters.
"""
is_litellm: bool = False
provider: str | re.Pattern[str] = Field(
default="openai", description="The provider of the LLM."
)
model: str = Field(description="The model identifier/name.")
temperature: float | None = Field(
default=None, ge=0, le=2, description="Temperature for response generation."
)
api_key: str | None = Field(default=None, description="API key for authentication.")
base_url: str | None = Field(default=None, description="Base URL for API calls.")
timeout: float | None = Field(default=None, description="Timeout for API calls.")
max_retries: int = Field(
default=2, description="Maximum number of API requests to make."
)
max_tokens: int | None = Field(
default=None, description="Maximum tokens for response generation."
)
stream: bool | None = Field(default=False, description="Stream the API requests.")
client: Any = Field(description="Underlying LLM client instance.")
interceptor: BaseInterceptor[Any, Any] | None = Field(
default=None,
description="An optional HTTPX interceptor for modifying requests/responses.",
)
client_params: dict[str, Any] = Field(
default_factory=dict,
description="Additional parameters for the underlying LLM client.",
)
supports_stop_words: bool = Field(
default=DEFAULT_SUPPORTS_STOP_WORDS,
description="Whether or not to support stop words.",
)
stop_sequences: list[str] = Field(
default_factory=list,
validation_alias=AliasChoices("stop_sequences", "stop"),
description="Stop sequences for generation (synchronized with stop).",
)
is_litellm: bool = Field(
default=False, description="Is this LLM implementation in litellm?"
)
additional_params: dict[str, Any] = Field(
default_factory=dict,
description="Additional parameters for LLM calls.",
)
_token_usage: TokenProcess = PrivateAttr(default_factory=TokenProcess)
def __init__(
self,
model: str,
temperature: float | None = None,
api_key: str | None = None,
base_url: str | None = None,
provider: str | None = None,
**kwargs: Any,
) -> None:
"""Initialize the BaseLLM with default attributes.
@field_validator("provider", mode="before")
@classmethod
def extract_provider_from_model(
cls, v: str | re.Pattern[str] | None, info: Any
) -> str | re.Pattern[str]:
"""Extract provider from model string if not explicitly provided.
Args:
model: The model identifier/name.
temperature: Optional temperature setting for response generation.
stop: Optional list of stop sequences for generation.
**kwargs: Additional provider-specific parameters.
v: Provided provider value (can be str, Pattern, or None)
info: Validation info containing other field values
Returns:
Provider name (str) or Pattern
"""
if not model:
raise ValueError("Model name is required and cannot be empty")
# If provider explicitly provided, validate and return it
if v is not None:
if not isinstance(v, (str, re.Pattern)):
raise ValueError(f"Provider must be str or Pattern, got {type(v)}")
return v
self.model = model
self.temperature = temperature
self.api_key = api_key
self.base_url = base_url
# Store additional parameters for provider-specific use
self.additional_params = kwargs
self._provider = provider or "openai"
model: str = info.data.get("model", "")
if "/" in model:
return model.partition("/")[0]
return "openai"
stop = kwargs.pop("stop", None)
if stop is None:
self.stop: list[str] = []
elif isinstance(stop, str):
self.stop = [stop]
elif isinstance(stop, list):
self.stop = stop
else:
self.stop = []
@field_validator("stop_sequences", mode="before")
@classmethod
def normalize_stop_sequences(
cls, v: str | list[str] | set[str] | None
) -> list[str]:
"""Validate and normalize stop sequences.
self._token_usage = {
"total_tokens": 0,
"prompt_tokens": 0,
"completion_tokens": 0,
"successful_requests": 0,
"cached_prompt_tokens": 0,
}
Converts string to list and handles None values.
AliasChoices handles accepting both 'stop' and 'stop_sequences' parameter names.
"""
if v is None:
return []
if isinstance(v, str):
return [v]
if isinstance(v, set):
return list(v)
if isinstance(v, list):
return v
return []
@property
def provider(self) -> str:
"""Get the provider of the LLM."""
return self._provider
@provider.setter
def provider(self, value: str) -> None:
"""Set the provider of the LLM."""
self._provider = value
def stop(self) -> list[str]:
"""Alias for stop_sequences to maintain backward compatibility."""
return self.stop_sequences
@abstractmethod
def call(
@@ -171,14 +208,6 @@ class BaseLLM(ABC):
"""
return tools
def supports_stop_words(self) -> bool:
"""Check if the LLM supports stop words.
Returns:
True if the LLM supports stop words, False otherwise.
"""
return DEFAULT_SUPPORTS_STOP_WORDS
def _supports_stop_words_implementation(self) -> bool:
"""Check if stop words are configured for this LLM instance.
@@ -506,7 +535,7 @@ class BaseLLM(ABC):
"""
if "/" in model:
return model.partition("/")[0]
return "openai" # Default provider
return "openai"
def _track_token_usage_internal(self, usage_data: dict[str, Any]) -> None:
"""Track token usage internally in the LLM instance.
@@ -535,11 +564,11 @@ class BaseLLM(ABC):
or 0
)
self._token_usage["prompt_tokens"] += prompt_tokens
self._token_usage["completion_tokens"] += completion_tokens
self._token_usage["total_tokens"] += prompt_tokens + completion_tokens
self._token_usage["successful_requests"] += 1
self._token_usage["cached_prompt_tokens"] += cached_tokens
self._token_usage.prompt_tokens += prompt_tokens
self._token_usage.completion_tokens += completion_tokens
self._token_usage.total_tokens += prompt_tokens + completion_tokens
self._token_usage.successful_requests += 1
self._token_usage.cached_prompt_tokens += cached_tokens
def get_token_usage_summary(self) -> UsageMetrics:
"""Get summary of token usage for this LLM instance.
@@ -547,4 +576,10 @@ class BaseLLM(ABC):
Returns:
Dictionary with token usage totals
"""
return UsageMetrics(**self._token_usage)
return UsageMetrics(
prompt_tokens=self._token_usage.prompt_tokens,
completion_tokens=self._token_usage.completion_tokens,
total_tokens=self._token_usage.total_tokens,
successful_requests=self._token_usage.successful_requests,
cached_prompt_tokens=self._token_usage.cached_prompt_tokens,
)

View File

@@ -5,11 +5,14 @@ import logging
import os
from typing import TYPE_CHECKING, Any, cast
from pydantic import BaseModel
from pydantic import BaseModel, Field, PrivateAttr, computed_field, model_validator
from typing_extensions import Self
from crewai.events.types.llm_events import LLMCallType
from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO
from crewai.llms.base_llm import BaseLLM
from crewai.llms.hooks.transport import HTTPTransport
from crewai.llms.providers.utils.common import safe_tool_conversion
from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededError,
@@ -18,7 +21,8 @@ from crewai.utilities.types import LLMMessage
if TYPE_CHECKING:
from crewai.llms.hooks.base import BaseInterceptor
from crewai.agent import Agent
from crewai.task import Task
try:
from anthropic import Anthropic
@@ -31,6 +35,19 @@ except ImportError:
) from None
ANTHROPIC_CONTEXT_WINDOWS: dict[str, int] = {
"claude-3-5-sonnet": 200000,
"claude-3-5-haiku": 200000,
"claude-3-opus": 200000,
"claude-3-sonnet": 200000,
"claude-3-haiku": 200000,
"claude-3-7-sonnet": 200000,
"claude-2.1": 200000,
"claude-2": 100000,
"claude-instant": 100000,
}
class AnthropicCompletion(BaseLLM):
"""Anthropic native completion implementation.
@@ -38,110 +55,69 @@ class AnthropicCompletion(BaseLLM):
offering native tool use, streaming support, and proper message formatting.
"""
def __init__(
self,
model: str = "claude-3-5-sonnet-20241022",
api_key: str | None = None,
base_url: str | None = None,
timeout: float | None = None,
max_retries: int = 2,
temperature: float | None = None,
max_tokens: int = 4096, # Required for Anthropic
top_p: float | None = None,
stop_sequences: list[str] | None = None,
stream: bool = False,
client_params: dict[str, Any] | None = None,
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
**kwargs: Any,
):
"""Initialize Anthropic chat completion client.
model: str = Field(
default="claude-3-5-sonnet-20241022",
description="Anthropic model name (e.g., 'claude-3-5-sonnet-20241022')",
)
max_tokens: int = Field(
default=4096,
description="Maximum number of allowed tokens in response.",
)
top_p: float | None = Field(
default=None,
description="Nucleus sampling parameter.",
)
_client: Anthropic = PrivateAttr(
default_factory=Anthropic,
)
Args:
model: Anthropic model name (e.g., 'claude-3-5-sonnet-20241022')
api_key: Anthropic API key (defaults to ANTHROPIC_API_KEY env var)
base_url: Custom base URL for Anthropic API
timeout: Request timeout in seconds
max_retries: Maximum number of retries
temperature: Sampling temperature (0-1)
max_tokens: Maximum tokens in response (required for Anthropic)
top_p: Nucleus sampling parameter
stop_sequences: Stop sequences (Anthropic uses stop_sequences, not stop)
stream: Enable streaming responses
client_params: Additional parameters for the Anthropic client
interceptor: HTTP interceptor for modifying requests/responses at transport level.
**kwargs: Additional parameters
@model_validator(mode="after")
def initialize_client(self) -> Self:
"""Initialize the Anthropic client after Pydantic validation.
This runs after all field validation is complete, ensuring that:
- All BaseLLM fields are set (model, temperature, stop_sequences, etc.)
- Field validators have run (stop_sequences is normalized to set[str])
- API key and other configuration is ready
"""
super().__init__(
model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
)
# Client params
self.interceptor = interceptor
self.client_params = client_params
self.base_url = base_url
self.timeout = timeout
self.max_retries = max_retries
self.client = Anthropic(**self._get_client_params())
# Store completion parameters
self.max_tokens = max_tokens
self.top_p = top_p
self.stream = stream
self.stop_sequences = stop_sequences or []
# Model-specific settings
self.is_claude_3 = "claude-3" in model.lower()
self.supports_tools = self.is_claude_3 # Claude 3+ supports tool use
@property
def stop(self) -> list[str]:
"""Get stop sequences sent to the API."""
return self.stop_sequences
@stop.setter
def stop(self, value: list[str] | str | None) -> None:
"""Set stop sequences.
Synchronizes stop_sequences to ensure values set by CrewAgentExecutor
are properly sent to the Anthropic API.
Args:
value: Stop sequences as a list, single string, or None
"""
if value is None:
self.stop_sequences = []
elif isinstance(value, str):
self.stop_sequences = [value]
elif isinstance(value, list):
self.stop_sequences = value
else:
self.stop_sequences = []
def _get_client_params(self) -> dict[str, Any]:
"""Get client parameters."""
if self.api_key is None:
self.api_key = os.getenv("ANTHROPIC_API_KEY")
if self.api_key is None:
raise ValueError("ANTHROPIC_API_KEY is required")
client_params = {
"api_key": self.api_key,
"base_url": self.base_url,
"timeout": self.timeout,
"max_retries": self.max_retries,
}
params = self.model_dump(
include={"api_key", "base_url", "timeout", "max_retries"},
exclude_none=True,
)
if self.interceptor:
transport = HTTPTransport(interceptor=self.interceptor)
http_client = httpx.Client(transport=transport)
client_params["http_client"] = http_client # type: ignore[assignment]
params["http_client"] = http_client
if self.client_params:
client_params.update(self.client_params)
params.update(self.client_params)
return client_params
self._client = Anthropic(**params)
return self
@computed_field # type: ignore[prop-decorator]
@property
def is_claude_3(self) -> bool:
"""Check if the model is Claude 3 or higher."""
return "claude-3" in self.model.lower()
@computed_field # type: ignore[prop-decorator]
@property
def supports_tools(self) -> bool:
"""Check if the model supports tool use."""
return self.is_claude_3
@computed_field # type: ignore[prop-decorator]
@property
def supports_function_calling(self) -> bool:
"""Check if the model supports function calling."""
return self.supports_tools
def call(
self,
@@ -149,8 +125,8 @@ class AnthropicCompletion(BaseLLM):
tools: list[dict[str, Any]] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Call Anthropic messages API.
@@ -229,25 +205,21 @@ class AnthropicCompletion(BaseLLM):
Returns:
Parameters dictionary for Anthropic API
"""
params = {
"model": self.model,
"messages": messages,
"max_tokens": self.max_tokens,
"stream": self.stream,
}
params = self.model_dump(
include={
"model",
"max_tokens",
"stream",
"temperature",
"top_p",
"stop_sequences",
},
)
params["messages"] = messages
# Add system message if present
if system_message:
params["system"] = system_message
# Add optional parameters if set
if self.temperature is not None:
params["temperature"] = self.temperature
if self.top_p is not None:
params["top_p"] = self.top_p
if self.stop_sequences:
params["stop_sequences"] = self.stop_sequences
# Handle tools for Claude 3+
if tools and self.supports_tools:
params["tools"] = self._convert_tools_for_interference(tools)
@@ -266,8 +238,6 @@ class AnthropicCompletion(BaseLLM):
continue
try:
from crewai.llms.providers.utils.common import safe_tool_conversion
name, description, parameters = safe_tool_conversion(tool, "Anthropic")
except (ImportError, KeyError, ValueError) as e:
logging.error(f"Error converting tool to Anthropic format: {e}")
@@ -341,8 +311,8 @@ class AnthropicCompletion(BaseLLM):
self,
params: dict[str, Any],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle non-streaming message completion."""
@@ -357,7 +327,7 @@ class AnthropicCompletion(BaseLLM):
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
try:
response: Message = self.client.messages.create(**params)
response: Message = self._client.messages.create(**params)
except Exception as e:
if is_context_length_exceeded(e):
@@ -429,8 +399,8 @@ class AnthropicCompletion(BaseLLM):
self,
params: dict[str, Any],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
response_model: type[BaseModel] | None = None,
) -> str:
"""Handle streaming message completion."""
@@ -451,7 +421,7 @@ class AnthropicCompletion(BaseLLM):
stream_params = {k: v for k, v in params.items() if k != "stream"}
# Make streaming API call
with self.client.messages.stream(**stream_params) as stream:
with self._client.messages.stream(**stream_params) as stream:
for event in stream:
if hasattr(event, "delta") and hasattr(event.delta, "text"):
text_delta = event.delta.text
@@ -525,8 +495,8 @@ class AnthropicCompletion(BaseLLM):
tool_uses: list[ToolUseBlock],
params: dict[str, Any],
available_functions: dict[str, Any],
from_task: Any | None = None,
from_agent: Any | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
) -> str:
"""Handle the complete tool use conversation flow.
@@ -579,7 +549,7 @@ class AnthropicCompletion(BaseLLM):
try:
# Send tool results back to Claude for final response
final_response: Message = self.client.messages.create(**follow_up_params)
final_response: Message = self._client.messages.create(**follow_up_params)
# Track token usage for follow-up call
follow_up_usage = self._extract_anthropic_token_usage(final_response)
@@ -626,48 +596,24 @@ class AnthropicCompletion(BaseLLM):
return tool_results[0]["content"]
raise e
def supports_function_calling(self) -> bool:
"""Check if the model supports function calling."""
return self.supports_tools
def supports_stop_words(self) -> bool:
"""Check if the model supports stop words."""
return True # All Claude models support stop sequences
def get_context_window_size(self) -> int:
"""Get the context window size for the model."""
from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO
# Context window sizes for Anthropic models
context_windows = {
"claude-3-5-sonnet": 200000,
"claude-3-5-haiku": 200000,
"claude-3-opus": 200000,
"claude-3-sonnet": 200000,
"claude-3-haiku": 200000,
"claude-3-7-sonnet": 200000,
"claude-2.1": 200000,
"claude-2": 100000,
"claude-instant": 100000,
}
# Find the best match for the model name
for model_prefix, size in context_windows.items():
for model_prefix, size in ANTHROPIC_CONTEXT_WINDOWS.items():
if self.model.startswith(model_prefix):
return int(size * CONTEXT_WINDOW_USAGE_RATIO)
# Default context window size for Claude models
return int(200000 * CONTEXT_WINDOW_USAGE_RATIO)
def _extract_anthropic_token_usage(self, response: Message) -> dict[str, Any]:
@staticmethod
def _extract_anthropic_token_usage(response: Message) -> dict[str, Any]:
"""Extract token usage from Anthropic response."""
if hasattr(response, "usage") and response.usage:
if response.usage:
usage = response.usage
input_tokens = getattr(usage, "input_tokens", 0)
output_tokens = getattr(usage, "output_tokens", 0)
return {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"total_tokens": input_tokens + output_tokens,
"input_tokens": usage.input_tokens,
"output_tokens": usage.output_tokens,
"total_tokens": usage.input_tokens + usage.output_tokens,
}
return {"total_tokens": 0}

View File

@@ -1,12 +1,14 @@
import logging
import os
from typing import Any, cast
from __future__ import annotations
from pydantic import BaseModel
import logging
from typing import TYPE_CHECKING, Any, cast
from pydantic import BaseModel, Field, PrivateAttr, computed_field, model_validator
from typing_extensions import Self
from crewai.events.types.llm_events import LLMCallType
from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
from crewai.llms.base_llm import BaseLLM
from crewai.llms.hooks.base import BaseInterceptor
from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededError,
@@ -14,6 +16,11 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
from crewai.utilities.types import LLMMessage
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.task import Task
try:
from google import genai # type: ignore[import-untyped]
from google.genai import types # type: ignore[import-untyped]
@@ -24,6 +31,27 @@ except ImportError:
) from None
GEMINI_CONTEXT_WINDOWS: dict[str, int] = {
"gemini-2.0-flash": 1048576, # 1M tokens
"gemini-2.0-flash-thinking": 32768,
"gemini-2.0-flash-lite": 1048576,
"gemini-2.5-flash": 1048576,
"gemini-2.5-pro": 1048576,
"gemini-1.5-pro": 2097152, # 2M tokens
"gemini-1.5-flash": 1048576,
"gemini-1.5-flash-8b": 1048576,
"gemini-1.0-pro": 32768,
"gemma-3-1b": 32000,
"gemma-3-4b": 128000,
"gemma-3-12b": 128000,
"gemma-3-27b": 128000,
}
# Context window validation constraints
MIN_CONTEXT_WINDOW: int = 1024
MAX_CONTEXT_WINDOW: int = 2097152
class GeminiCompletion(BaseLLM):
"""Google Gemini native completion implementation.
@@ -31,78 +59,140 @@ class GeminiCompletion(BaseLLM):
offering native function calling, streaming support, and proper Gemini formatting.
"""
def __init__(
self,
model: str = "gemini-2.0-flash-001",
api_key: str | None = None,
project: str | None = None,
location: str | None = None,
temperature: float | None = None,
top_p: float | None = None,
top_k: int | None = None,
max_output_tokens: int | None = None,
stop_sequences: list[str] | None = None,
stream: bool = False,
safety_settings: dict[str, Any] | None = None,
client_params: dict[str, Any] | None = None,
interceptor: BaseInterceptor[Any, Any] | None = None,
**kwargs: Any,
):
"""Initialize Google Gemini chat completion client.
model: str = Field(
default="gemini-2.0-flash-001",
description="Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro')",
)
project: str | None = Field(
default=None,
description="Google Cloud project ID (for Vertex AI)",
)
location: str = Field(
default="us-central1",
description="Google Cloud location (for Vertex AI)",
)
top_p: float | None = Field(
default=None,
description="Nucleus sampling parameter",
)
top_k: int | None = Field(
default=None,
description="Top-k sampling parameter",
)
max_output_tokens: int | None = Field(
default=None,
description="Maximum tokens in response",
)
safety_settings: dict[str, Any] | None = Field(
default=None,
description="Safety filter settings",
)
_client: genai.Client = PrivateAttr( # type: ignore[no-any-unimported]
default_factory=genai.Client,
)
Args:
model: Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro')
api_key: Google API key (defaults to GOOGLE_API_KEY or GEMINI_API_KEY env var)
project: Google Cloud project ID (for Vertex AI)
location: Google Cloud location (for Vertex AI, defaults to 'us-central1')
temperature: Sampling temperature (0-2)
top_p: Nucleus sampling parameter
top_k: Top-k sampling parameter
max_output_tokens: Maximum tokens in response
stop_sequences: Stop sequences
stream: Enable streaming responses
safety_settings: Safety filter settings
client_params: Additional parameters to pass to the Google Gen AI Client constructor.
Supports parameters like http_options, credentials, debug_config, etc.
interceptor: HTTP interceptor (not yet supported for Gemini).
**kwargs: Additional parameters
@model_validator(mode="after")
def initialize_client(self) -> Self:
"""Initialize the Anthropic client after Pydantic validation.
This runs after all field validation is complete, ensuring that:
- All BaseLLM fields are set (model, temperature, stop_sequences, etc.)
- Field validators have run (stop_sequences is normalized to set[str])
- API key and other configuration is ready
"""
if interceptor is not None:
raise NotImplementedError(
"HTTP interceptors are not yet supported for Google Gemini provider. "
"Interceptors are currently supported for OpenAI and Anthropic providers only."
)
self._client = genai.Client(**self._get_client_params())
return self
super().__init__(
model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
)
# def __init__(
# self,
# model: str = "gemini-2.0-flash-001",
# api_key: str | None = None,
# project: str | None = None,
# location: str | None = None,
# temperature: float | None = None,
# top_p: float | None = None,
# top_k: int | None = None,
# max_output_tokens: int | None = None,
# stop_sequences: list[str] | None = None,
# stream: bool = False,
# safety_settings: dict[str, Any] | None = None,
# client_params: dict[str, Any] | None = None,
# interceptor: BaseInterceptor[Any, Any] | None = None,
# **kwargs: Any,
# # ):
# """Initialize Google Gemini chat completion client.
#
# Args:
# model: Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro')
# api_key: Google API key (defaults to GOOGLE_API_KEY or GEMINI_API_KEY env var)
# project: Google Cloud project ID (for Vertex AI)
# location: Google Cloud location (for Vertex AI, defaults to 'us-central1')
# temperature: Sampling temperature (0-2)
# top_p: Nucleus sampling parameter
# top_k: Top-k sampling parameter
# max_output_tokens: Maximum tokens in response
# stop_sequences: Stop sequences
# stream: Enable streaming responses
# safety_settings: Safety filter settings
# client_params: Additional parameters to pass to the Google Gen AI Client constructor.
# Supports parameters like http_options, credentials, debug_config, etc.
# interceptor: HTTP interceptor (not yet supported for Gemini).
# **kwargs: Additional parameters
# """
# if interceptor is not None:
# raise NotImplementedError(
# "HTTP interceptors are not yet supported for Google Gemini provider. "
# "Interceptors are currently supported for OpenAI and Anthropic providers only."
# )
#
# super().__init__(
# model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
# )
#
# # Store client params for later use
# self.client_params = client_params or {}
#
# # Get API configuration with environment variable fallbacks
# self.api_key = (
# api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
# )
# self.project = project or os.getenv("GOOGLE_CLOUD_PROJECT")
# self.location = location or os.getenv("GOOGLE_CLOUD_LOCATION") or "us-central1"
#
# use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"
#
# self.client = self._initialize_client(use_vertexai)
#
# # Store completion parameters
# self.top_p = top_p
# self.top_k = top_k
# self.max_output_tokens = max_output_tokens
# self.stream = stream
# self.safety_settings = safety_settings or {}
# self.stop_sequences = stop_sequences or []
#
# # Model-specific settings
# self.is_gemini_2 = "gemini-2" in model.lower()
# self.is_gemini_1_5 = "gemini-1.5" in model.lower()
# self.supports_tools = self.is_gemini_1_5 or self.is_gemini_2
# Store client params for later use
self.client_params = client_params or {}
@computed_field # type: ignore[prop-decorator]
@property
def is_gemini_2(self) -> bool:
"""Check if the model is Gemini 2.x."""
return "gemini-2" in self.model.lower()
# Get API configuration with environment variable fallbacks
self.api_key = (
api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
)
self.project = project or os.getenv("GOOGLE_CLOUD_PROJECT")
self.location = location or os.getenv("GOOGLE_CLOUD_LOCATION") or "us-central1"
@computed_field # type: ignore[prop-decorator]
@property
def is_gemini_1_5(self) -> bool:
"""Check if the model is Gemini 1.5.x."""
return "gemini-1.5" in self.model.lower()
use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"
self.client = self._initialize_client(use_vertexai)
# Store completion parameters
self.top_p = top_p
self.top_k = top_k
self.max_output_tokens = max_output_tokens
self.stream = stream
self.safety_settings = safety_settings or {}
self.stop_sequences = stop_sequences or []
# Model-specific settings
self.is_gemini_2 = "gemini-2" in model.lower()
self.is_gemini_1_5 = "gemini-1.5" in model.lower()
self.supports_tools = self.is_gemini_1_5 or self.is_gemini_2
@computed_field # type: ignore[prop-decorator]
@property
def supports_tools(self) -> bool:
"""Check if the model supports tool/function calling."""
return self.is_gemini_1_5 or self.is_gemini_2
@property
def stop(self) -> list[str]:
@@ -142,6 +232,12 @@ class GeminiCompletion(BaseLLM):
if self.client_params:
client_params.update(self.client_params)
if self.interceptor:
raise NotImplementedError(
"HTTP interceptors are not yet supported for Google Gemini provider. "
"Interceptors are currently supported for OpenAI and Anthropic providers only."
)
if use_vertexai or self.project:
client_params.update(
{
@@ -181,7 +277,7 @@ class GeminiCompletion(BaseLLM):
if (
hasattr(self, "client")
and hasattr(self.client, "vertexai")
and hasattr(self._client, "vertexai")
and self.client.vertexai
):
# Vertex AI configuration
@@ -206,8 +302,8 @@ class GeminiCompletion(BaseLLM):
tools: list[dict[str, Any]] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Call Google Gemini generate content API.
@@ -294,7 +390,16 @@ class GeminiCompletion(BaseLLM):
GenerateContentConfig object for Gemini API
"""
self.tools = tools
config_params = {}
config_params = self.model_dump(
include={
"temperature",
"top_p",
"top_k",
"max_output_tokens",
"stop_sequences",
"safety_settings",
}
)
# Add system instruction if present
if system_instruction:
@@ -304,18 +409,6 @@ class GeminiCompletion(BaseLLM):
)
config_params["system_instruction"] = system_content
# Add generation config parameters
if self.temperature is not None:
config_params["temperature"] = self.temperature
if self.top_p is not None:
config_params["top_p"] = self.top_p
if self.top_k is not None:
config_params["top_k"] = self.top_k
if self.max_output_tokens is not None:
config_params["max_output_tokens"] = self.max_output_tokens
if self.stop_sequences:
config_params["stop_sequences"] = self.stop_sequences
if response_model:
config_params["response_mime_type"] = "application/json"
config_params["response_schema"] = response_model.model_json_schema()
@@ -324,9 +417,6 @@ class GeminiCompletion(BaseLLM):
if tools and self.supports_tools:
config_params["tools"] = self._convert_tools_for_interference(tools)
if self.safety_settings:
config_params["safety_settings"] = self.safety_settings
return types.GenerateContentConfig(**config_params)
def _convert_tools_for_interference( # type: ignore[no-any-unimported]
@@ -404,8 +494,8 @@ class GeminiCompletion(BaseLLM):
system_instruction: str | None,
config: types.GenerateContentConfig,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle non-streaming content generation."""
@@ -416,7 +506,7 @@ class GeminiCompletion(BaseLLM):
}
try:
response = self.client.models.generate_content(**api_params)
response = self._client.models.generate_content(**api_params)
usage = self._extract_token_usage(response)
except Exception as e:
@@ -470,8 +560,8 @@ class GeminiCompletion(BaseLLM):
contents: list[types.Content],
config: types.GenerateContentConfig,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
response_model: type[BaseModel] | None = None,
) -> str:
"""Handle streaming content generation."""
@@ -484,7 +574,7 @@ class GeminiCompletion(BaseLLM):
"config": config,
}
for chunk in self.client.models.generate_content_stream(**api_params):
for chunk in self._client.models.generate_content_stream(**api_params):
if hasattr(chunk, "text") and chunk.text:
full_response += chunk.text
self._emit_stream_chunk_event(
@@ -537,52 +627,30 @@ class GeminiCompletion(BaseLLM):
return full_response
@computed_field # type: ignore[prop-decorator]
@property
def supports_function_calling(self) -> bool:
"""Check if the model supports function calling."""
return self.supports_tools
def supports_stop_words(self) -> bool:
"""Check if the model supports stop words."""
return True
def get_context_window_size(self) -> int:
"""Get the context window size for the model."""
from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
min_context = 1024
max_context = 2097152
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
if value < min_context or value > max_context:
if value < MIN_CONTEXT_WINDOW or value > MAX_CONTEXT_WINDOW:
raise ValueError(
f"Context window for {key} must be between {min_context} and {max_context}"
f"Context window for {key} must be between {MIN_CONTEXT_WINDOW} and {MAX_CONTEXT_WINDOW}"
)
context_windows = {
"gemini-2.0-flash": 1048576, # 1M tokens
"gemini-2.0-flash-thinking": 32768,
"gemini-2.0-flash-lite": 1048576,
"gemini-2.5-flash": 1048576,
"gemini-2.5-pro": 1048576,
"gemini-1.5-pro": 2097152, # 2M tokens
"gemini-1.5-flash": 1048576,
"gemini-1.5-flash-8b": 1048576,
"gemini-1.0-pro": 32768,
"gemma-3-1b": 32000,
"gemma-3-4b": 128000,
"gemma-3-12b": 128000,
"gemma-3-27b": 128000,
}
# Find the best match for the model name
for model_prefix, size in context_windows.items():
for model_prefix, size in GEMINI_CONTEXT_WINDOWS.items():
if self.model.startswith(model_prefix):
return int(size * CONTEXT_WINDOW_USAGE_RATIO)
# Default context window size for Gemini models
return int(1048576 * CONTEXT_WINDOW_USAGE_RATIO) # 1M tokens
def _extract_token_usage(self, response: dict[str, Any]) -> dict[str, Any]:
@staticmethod
def _extract_token_usage(response: dict[str, Any]) -> dict[str, Any]:
"""Extract token usage from Gemini response."""
if hasattr(response, "usage_metadata"):
usage = response.usage_metadata
@@ -594,8 +662,8 @@ class GeminiCompletion(BaseLLM):
}
return {"total_tokens": 0}
@staticmethod
def _convert_contents_to_dict( # type: ignore[no-any-unimported]
self,
contents: list[types.Content],
) -> list[dict[str, str]]:
"""Convert contents to dict format."""

View File

@@ -4,16 +4,23 @@ from collections.abc import Iterator
import json
import logging
import os
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Final
import httpx
from openai import APIConnectionError, NotFoundError, OpenAI
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from pydantic import BaseModel
from pydantic import (
BaseModel,
Field,
PrivateAttr,
model_validator,
)
from typing_extensions import Self
from crewai.events.types.llm_events import LLMCallType
from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
from crewai.llms.base_llm import BaseLLM
from crewai.llms.hooks.transport import HTTPTransport
from crewai.utilities.agent_utils import is_context_length_exceeded
@@ -25,11 +32,28 @@ from crewai.utilities.types import LLMMessage
if TYPE_CHECKING:
from crewai.agent.core import Agent
from crewai.llms.hooks.base import BaseInterceptor
from crewai.task import Task
from crewai.tools.base_tool import BaseTool
OPENAI_CONTEXT_WINDOWS: dict[str, int] = {
"gpt-4": 8192,
"gpt-4o": 128000,
"gpt-4o-mini": 200000,
"gpt-4-turbo": 128000,
"gpt-4.1": 1047576,
"gpt-4.1-mini-2025-04-14": 1047576,
"gpt-4.1-nano-2025-04-14": 1047576,
"o1-preview": 128000,
"o1-mini": 128000,
"o3-mini": 200000,
"o4-mini": 200000,
}
MIN_CONTEXT_WINDOW: Final[int] = 1024
MAX_CONTEXT_WINDOW: Final[int] = 2097152
class OpenAICompletion(BaseLLM):
"""OpenAI native completion implementation.
@@ -37,112 +61,125 @@ class OpenAICompletion(BaseLLM):
offering native structured outputs, function calling, and streaming support.
"""
def __init__(
self,
model: str = "gpt-4o",
api_key: str | None = None,
base_url: str | None = None,
organization: str | None = None,
project: str | None = None,
timeout: float | None = None,
max_retries: int = 2,
default_headers: dict[str, str] | None = None,
default_query: dict[str, Any] | None = None,
client_params: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
frequency_penalty: float | None = None,
presence_penalty: float | None = None,
max_tokens: int | None = None,
max_completion_tokens: int | None = None,
seed: int | None = None,
stream: bool = False,
response_format: dict[str, Any] | type[BaseModel] | None = None,
logprobs: bool | None = None,
top_logprobs: int | None = None,
reasoning_effort: str | None = None,
provider: str | None = None,
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
**kwargs: Any,
) -> None:
"""Initialize OpenAI chat completion client."""
model: str = Field(
default="gpt-4o",
description="OpenAI model name (e.g., 'gpt-4o')",
)
organization: str | None = Field(
default=None,
description="Name of the OpenAI organization",
)
project: str | None = Field(
default=None,
description="Name of the OpenAI project",
)
api_base: str | None = Field(
default=os.getenv("OPENAI_BASE_URL"),
description="Base URL for OpenAI API",
)
default_headers: dict[str, str] | None = Field(
default=None,
description="Default headers for OpenAI API requests",
)
default_query: dict[str, Any] | None = Field(
default=None,
description="Default query parameters for OpenAI API requests",
)
top_p: float | None = Field(
default=None,
description="Top-p sampling parameter",
)
frequency_penalty: float | None = Field(
default=None,
description="Frequency penalty parameter",
)
presence_penalty: float | None = Field(
default=None,
description="Presence penalty parameter",
)
max_completion_tokens: int | None = Field(
default=None,
description="Maximum tokens for completion",
)
seed: int | None = Field(
default=None,
description="Random seed for reproducibility",
)
response_format: dict[str, Any] | type[BaseModel] | None = Field(
default=None,
description="Response format for structured output",
)
logprobs: bool | None = Field(
default=None,
description="Whether to include log probabilities",
)
top_logprobs: int | None = Field(
default=None,
description="Number of top log probabilities to return",
)
reasoning_effort: str | None = Field(
default=None,
description="Reasoning effort level for o1 models",
)
supports_function_calling: bool = Field(
default=True,
description="Whether the model supports function calling",
)
is_o1_model: bool = Field(
default=False,
description="Whether the model is an o1 model",
)
is_gpt4_model: bool = Field(
default=False,
description="Whether the model is a GPT-4 model",
)
_client: OpenAI = PrivateAttr(
default_factory=OpenAI,
)
if provider is None:
provider = kwargs.pop("provider", "openai")
self.interceptor = interceptor
# Client configuration attributes
self.organization = organization
self.project = project
self.max_retries = max_retries
self.default_headers = default_headers
self.default_query = default_query
self.client_params = client_params
self.timeout = timeout
self.base_url = base_url
self.api_base = kwargs.pop("api_base", None)
super().__init__(
model=model,
temperature=temperature,
api_key=api_key or os.getenv("OPENAI_API_KEY"),
base_url=base_url,
timeout=timeout,
provider=provider,
**kwargs,
)
client_config = self._get_client_params()
if self.interceptor:
transport = HTTPTransport(interceptor=self.interceptor)
http_client = httpx.Client(transport=transport)
client_config["http_client"] = http_client
self.client = OpenAI(**client_config)
# Completion parameters
self.top_p = top_p
self.frequency_penalty = frequency_penalty
self.presence_penalty = presence_penalty
self.max_tokens = max_tokens
self.max_completion_tokens = max_completion_tokens
self.seed = seed
self.stream = stream
self.response_format = response_format
self.logprobs = logprobs
self.top_logprobs = top_logprobs
self.reasoning_effort = reasoning_effort
self.is_o1_model = "o1" in model.lower()
self.is_gpt4_model = "gpt-4" in model.lower()
def _get_client_params(self) -> dict[str, Any]:
"""Get OpenAI client parameters."""
@model_validator(mode="after")
def initialize_client(self) -> Self:
"""Initialize the Anthropic client after Pydantic validation.
This runs after all field validation is complete, ensuring that:
- All BaseLLM fields are set (model, temperature, stop_sequences, etc.)
- Field validators have run (stop_sequences is normalized to set[str])
- API key and other configuration is ready
"""
if self.api_key is None:
self.api_key = os.getenv("OPENAI_API_KEY")
if self.api_key is None:
raise ValueError("OPENAI_API_KEY is required")
base_params = {
"api_key": self.api_key,
"organization": self.organization,
"project": self.project,
"base_url": self.base_url
or self.api_base
or os.getenv("OPENAI_BASE_URL")
or None,
"timeout": self.timeout,
"max_retries": self.max_retries,
"default_headers": self.default_headers,
"default_query": self.default_query,
}
self.is_o1_model = "o1" in self.model.lower()
self.supports_function_calling = not self.is_o1_model
self.is_gpt4_model = "gpt-4" in self.model.lower()
self.supports_stop_words = not self.is_o1_model
client_params = {k: v for k, v in base_params.items() if v is not None}
params = self.model_dump(
include={
"api_key",
"organization",
"project",
"base_url",
"timeout",
"max_retries",
"default_headers",
"default_query",
},
exclude_none=True,
)
if self.interceptor:
transport = HTTPTransport(interceptor=self.interceptor)
http_client = httpx.Client(transport=transport)
params["http_client"] = http_client
if self.client_params:
client_params.update(self.client_params)
params.update(self.client_params)
return client_params
self._client = OpenAI(**params)
return self
def call(
self,
@@ -213,38 +250,26 @@ class OpenAICompletion(BaseLLM):
self, messages: list[LLMMessage], tools: list[dict[str, BaseTool]] | None = None
) -> dict[str, Any]:
"""Prepare parameters for OpenAI chat completion."""
params: dict[str, Any] = {
"model": self.model,
"messages": messages,
}
if self.stream:
params["stream"] = self.stream
params = self.model_dump(
include={
"model",
"stream",
"temperature",
"top_p",
"frequency_penalty",
"presence_penalty",
"max_completion_tokens",
"max_tokens",
"seed",
"logprobs",
"top_logprobs",
"reasoning_effort",
},
exclude_none=True,
)
params["messages"] = messages
params.update(self.additional_params)
if self.temperature is not None:
params["temperature"] = self.temperature
if self.top_p is not None:
params["top_p"] = self.top_p
if self.frequency_penalty is not None:
params["frequency_penalty"] = self.frequency_penalty
if self.presence_penalty is not None:
params["presence_penalty"] = self.presence_penalty
if self.max_completion_tokens is not None:
params["max_completion_tokens"] = self.max_completion_tokens
elif self.max_tokens is not None:
params["max_tokens"] = self.max_tokens
if self.seed is not None:
params["seed"] = self.seed
if self.logprobs is not None:
params["logprobs"] = self.logprobs
if self.top_logprobs is not None:
params["top_logprobs"] = self.top_logprobs
# Handle o1 model specific parameters
if self.is_o1_model and self.reasoning_effort:
params["reasoning_effort"] = self.reasoning_effort
if tools:
params["tools"] = self._convert_tools_for_interference(tools)
params["tool_choice"] = "auto"
@@ -296,14 +321,14 @@ class OpenAICompletion(BaseLLM):
self,
params: dict[str, Any],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle non-streaming chat completion."""
try:
if response_model:
parsed_response = self.client.beta.chat.completions.parse(
parsed_response = self._client.beta.chat.completions.parse(
**params,
response_format=response_model,
)
@@ -327,7 +352,7 @@ class OpenAICompletion(BaseLLM):
)
return structured_json
response: ChatCompletion = self.client.chat.completions.create(**params)
response: ChatCompletion = self._client.chat.completions.create(**params)
usage = self._extract_openai_token_usage(response)
@@ -419,8 +444,8 @@ class OpenAICompletion(BaseLLM):
self,
params: dict[str, Any],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
response_model: type[BaseModel] | None = None,
) -> str:
"""Handle streaming chat completion."""
@@ -429,7 +454,7 @@ class OpenAICompletion(BaseLLM):
if response_model:
completion_stream: Iterator[ChatCompletionChunk] = (
self.client.chat.completions.create(**params)
self._client.chat.completions.create(**params)
)
accumulated_content = ""
@@ -472,7 +497,7 @@ class OpenAICompletion(BaseLLM):
)
return accumulated_content
stream: Iterator[ChatCompletionChunk] = self.client.chat.completions.create(
stream: Iterator[ChatCompletionChunk] = self._client.chat.completions.create(
**params
)
@@ -550,58 +575,31 @@ class OpenAICompletion(BaseLLM):
return full_response
def supports_function_calling(self) -> bool:
"""Check if the model supports function calling."""
return not self.is_o1_model
def supports_stop_words(self) -> bool:
"""Check if the model supports stop words."""
return not self.is_o1_model
def get_context_window_size(self) -> int:
"""Get the context window size for the model."""
from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
min_context = 1024
max_context = 2097152
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
if value < min_context or value > max_context:
if value < MIN_CONTEXT_WINDOW or value > MAX_CONTEXT_WINDOW:
raise ValueError(
f"Context window for {key} must be between {min_context} and {max_context}"
f"Context window for {key} must be between {MIN_CONTEXT_WINDOW} and {MAX_CONTEXT_WINDOW}"
)
# Context window sizes for OpenAI models
context_windows = {
"gpt-4": 8192,
"gpt-4o": 128000,
"gpt-4o-mini": 200000,
"gpt-4-turbo": 128000,
"gpt-4.1": 1047576,
"gpt-4.1-mini-2025-04-14": 1047576,
"gpt-4.1-nano-2025-04-14": 1047576,
"o1-preview": 128000,
"o1-mini": 128000,
"o3-mini": 200000,
"o4-mini": 200000,
}
# Find the best match for the model name
for model_prefix, size in context_windows.items():
for model_prefix, size in OPENAI_CONTEXT_WINDOWS.items():
if self.model.startswith(model_prefix):
return int(size * CONTEXT_WINDOW_USAGE_RATIO)
# Default context window size
return int(8192 * CONTEXT_WINDOW_USAGE_RATIO)
def _extract_openai_token_usage(self, response: ChatCompletion) -> dict[str, Any]:
@staticmethod
def _extract_openai_token_usage(response: ChatCompletion) -> dict[str, Any]:
"""Extract token usage from OpenAI ChatCompletion response."""
if hasattr(response, "usage") and response.usage:
if response.usage:
usage = response.usage
return {
"prompt_tokens": getattr(usage, "prompt_tokens", 0),
"completion_tokens": getattr(usage, "completion_tokens", 0),
"total_tokens": getattr(usage, "total_tokens", 0),
"prompt_tokens": usage.prompt_tokens,
"completion_tokens": usage.completion_tokens,
"total_tokens": usage.total_tokens,
}
return {"total_tokens": 0}