Mirror of https://github.com/crewAIInc/crewAI.git, synced 2025-12-16 12:28:30 +00:00

Compare commits: main...gl/feat/py (6 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | 81850350e8 |  |
|  | 7404d8f198 |  |
|  | 138b9af274 |  |
|  | a5e0803f20 |  |
|  | c4279b0339 |  |
|  | 965aa48ea1 |  |
@@ -618,22 +618,22 @@ class Agent(BaseAgent):
             response_template=self.response_template,
         ).task_execution()

-        stop_words = [self.i18n.slice("observation")]
+        stop_sequences = [self.i18n.slice("observation")]

         if self.response_template:
-            stop_words.append(
+            stop_sequences.append(
                 self.response_template.split("{{ .Response }}")[1].strip()
             )

         self.agent_executor = CrewAgentExecutor(
-            llm=self.llm,
+            llm=self.llm,  # type: ignore[arg-type]
             task=task,  # type: ignore[arg-type]
             agent=self,
             crew=self.crew,
             tools=parsed_tools,
             prompt=prompt,
             original_tools=raw_tools,
-            stop_words=stop_words,
+            stop_sequences=stop_sequences,
             max_iter=self.max_iter,
             tools_handler=self.tools_handler,
             tools_names=get_tool_names(parsed_tools),
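For reference (not part of the diff above): the renamed stop_sequences list is built from the i18n "observation" slice plus, when a response_template is set, the text that follows the `{{ .Response }}` placeholder. A minimal illustrative sketch of that extraction; the template string below is made up, not one shipped with crewAI.

    # Hedged sketch: deriving a stop sequence from a response template.
    # The template is a hypothetical example, not a crewAI default.
    response_template = "### Answer\n{{ .Response }}\n### End of answer"

    stop_sequences = ["Observation:"]  # stands in for the i18n "observation" slice
    tail = response_template.split("{{ .Response }}")[1].strip()
    stop_sequences.append(tail)
    print(stop_sequences)  # ['Observation:', '### End of answer']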
@@ -974,7 +974,9 @@ class Agent(BaseAgent):
         path = parsed.path.replace("/", "_").strip("_")
         return f"{domain}_{path}" if path else domain

-    def _get_mcp_tool_schemas(self, server_params: dict) -> dict[str, dict]:
+    def _get_mcp_tool_schemas(
+        self, server_params: dict[str, Any]
+    ) -> dict[str, dict[str, Any]] | Any:
         """Get tool schemas from MCP server for wrapper creation with caching."""
         server_url = server_params["url"]

@@ -1006,7 +1008,7 @@ class Agent(BaseAgent):

     async def _get_mcp_tool_schemas_async(
         self, server_params: dict[str, Any]
-    ) -> dict[str, dict]:
+    ) -> dict[str, dict[str, Any]]:
         """Async implementation of MCP tool schema retrieval with timeouts and retries."""
         server_url = server_params["url"]
         return await self._retry_mcp_discovery(
@@ -1014,7 +1016,7 @@ class Agent(BaseAgent):
         )

     async def _retry_mcp_discovery(
-        self, operation_func, server_url: str
+        self, operation_func: Any, server_url: str
     ) -> dict[str, dict[str, Any]]:
         """Retry MCP discovery operation with exponential backoff, avoiding try-except in loop."""
         last_error = None
@@ -1045,7 +1047,7 @@ class Agent(BaseAgent):

     @staticmethod
     async def _attempt_mcp_discovery(
-        operation_func, server_url: str
+        operation_func: Any, server_url: str
     ) -> tuple[dict[str, dict[str, Any]] | None, str, bool]:
         """Attempt single MCP discovery operation and return (result, error_message, should_retry)."""
         try:
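For reference (not part of the diff): the two docstrings above describe a split where the attempt helper owns the try/except and reports (result, error_message, should_retry), while the retry loop itself applies exponential backoff without exception handling. A self-contained sketch of that pattern; the helper names, exception types, and retry limits here are illustrative assumptions, not the repository's actual values.

    # Hedged sketch of the retry-with-backoff pattern described above.
    import asyncio
    from typing import Any

    async def attempt(operation_func: Any, server_url: str) -> tuple[dict | None, str, bool]:
        try:
            return await operation_func(server_url), "", False
        except ConnectionError as e:   # transient failure -> retry
            return None, str(e), True
        except ValueError as e:        # permanent failure -> give up
            return None, str(e), False

    async def retry_with_backoff(operation_func: Any, server_url: str, max_attempts: int = 3) -> dict:
        last_error = ""
        for attempt_no in range(max_attempts):
            result, last_error, should_retry = await attempt(operation_func, server_url)
            if result is not None:
                return result
            if not should_retry:
                break
            await asyncio.sleep(2 ** attempt_no)  # exponential backoff: 1s, 2s, 4s ...
        raise RuntimeError(f"MCP discovery failed for {server_url}: {last_error}")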
@@ -1149,13 +1151,13 @@ class Agent(BaseAgent):
                 Field(..., description=field_description),
             )
         else:
-            field_definitions[field_name] = (
+            field_definitions[field_name] = (  # type: ignore[assignment]
                 field_type | None,
                 Field(default=None, description=field_description),
             )

         model_name = f"{tool_name.replace('-', '_').replace(' ', '_')}Schema"
-        return create_model(model_name, **field_definitions)
+        return create_model(model_name, **field_definitions)  # type: ignore[no-any-return,call-overload]

     def _json_type_to_python(self, field_schema: dict[str, Any]) -> type:
         """Convert JSON Schema type to Python type.
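For reference (not part of the diff): the code above builds a tool-arguments model dynamically with pydantic.create_model, typing optional fields as `T | None` with a None default. A minimal sketch of the same technique; the field names and descriptions below are made up for illustration.

    # Hedged sketch of dynamic schema creation with pydantic.create_model.
    from pydantic import Field, create_model

    field_definitions = {
        "query": (str, Field(..., description="Search query")),                     # required field
        "limit": (int | None, Field(default=None, description="Max results")),      # optional field
    }
    SearchToolSchema = create_model("SearchToolSchema", **field_definitions)
    print(SearchToolSchema(query="crewai").model_dump())  # {'query': 'crewai', 'limit': None}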
@@ -1175,12 +1177,12 @@ class Agent(BaseAgent):
             if "const" in option:
                 types.append(str)
             else:
-                types.append(self._json_type_to_python(option))
+                types.append(self._json_type_to_python(option))  # type: ignore[arg-type]
         unique_types = list(set(types))
         if len(unique_types) > 1:
             result = unique_types[0]
             for t in unique_types[1:]:
-                result = result | t
+                result = result | t  # type: ignore[assignment]
             return result
         return unique_types[0]

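For reference (not part of the diff): the loop above folds the distinct types found in a JSON Schema anyOf into a single runtime union with the `|` operator. A small sketch of that folding; the type list is an arbitrary example.

    # Hedged sketch: folding a list of Python types into one union at runtime (Python 3.10+).
    types = [str, int, float]
    result = types[0]
    for t in types[1:]:
        result = result | t   # builds str | int | float incrementally
    print(result)             # str | int | float (a types.UnionType instance)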
@@ -1193,10 +1195,10 @@ class Agent(BaseAgent):
             "object": dict,
         }

-        return type_mapping.get(json_type, Any)
+        return type_mapping.get(json_type, Any)  # type: ignore[arg-type]

     @staticmethod
-    def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict]:
+    def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict[str, Any]]:
         """Fetch MCP server configurations from CrewAI AMP API."""
         # TODO: Implement AMP API call to "integrations/mcps" endpoint
         # Should return list of server configs with URLs
@@ -1435,7 +1437,7 @@ class Agent(BaseAgent):
             goal=self.goal,
             backstory=self.backstory,
             llm=self.llm,
-            tools=self.tools or [],
+            tools=self.tools,
             max_iterations=self.max_iter,
             max_execution_time=self.max_execution_time,
             respect_context_window=self.respect_context_window,
@@ -137,7 +137,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
         default=False,
         description="Enable agent to delegate and ask questions among each other.",
     )
-    tools: list[BaseTool] | None = Field(
+    tools: list[BaseTool] = Field(
         default_factory=list, description="Tools at agents' disposal"
     )
     max_iter: int = Field(
@@ -73,7 +73,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         max_iter: int,
         tools: list[CrewStructuredTool],
         tools_names: str,
-        stop_words: list[str],
+        stop_sequences: list[str],
         tools_description: str,
         tools_handler: ToolsHandler,
         step_callback: Any = None,
@@ -95,7 +95,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
             max_iter: Maximum iterations.
             tools: Available tools.
             tools_names: Tool names string.
-            stop_words: Stop word list.
+            stop_sequences: Stop sequences list for halting generation.
             tools_description: Tool descriptions.
             tools_handler: Tool handler instance.
             step_callback: Optional step callback.
@@ -114,7 +114,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         self.prompt = prompt
         self.tools = tools
         self.tools_names = tools_names
-        self.stop = stop_words
         self.max_iter = max_iter
         self.callbacks = callbacks or []
         self._printer: Printer = Printer()
@@ -131,15 +130,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         self.iterations = 0
         self.log_error_after = 3
-        if self.llm:
-            # This may be mutating the shared llm object and needs further evaluation
-            existing_stop = getattr(self.llm, "stop", [])
-            self.llm.stop = list(
-                set(
-                    existing_stop + self.stop
-                    if isinstance(existing_stop, list)
-                    else self.stop
-                )
-            )
+        self.llm.stop_sequences.extend(stop_sequences)

     @property
     def use_stop_words(self) -> bool:
@@ -148,7 +139,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         Returns:
             bool: True if tool should be used or not.
         """
-        return self.llm.supports_stop_words() if self.llm else False
+        return self.llm.supports_stop_words if self.llm else False

     def invoke(self, inputs: dict[str, Any]) -> dict[str, Any]:
         """Execute the agent with given inputs.
@@ -20,8 +20,7 @@ from typing import (
 )

-from dotenv import load_dotenv
 import httpx
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, model_validator
 from typing_extensions import Self

 from crewai.events.event_bus import crewai_event_bus
@@ -54,7 +53,6 @@ if TYPE_CHECKING:
     from litellm.utils import supports_response_schema

-    from crewai.agent.core import Agent
     from crewai.llms.hooks.base import BaseInterceptor
     from crewai.task import Task
     from crewai.tools.base_tool import BaseTool
     from crewai.utilities.types import LLMMessage
@@ -320,7 +318,138 @@ class AccumulatedToolArgs(BaseModel):


 class LLM(BaseLLM):
-    completion_cost: float | None = None
+    completion_cost: float | None = Field(
+        default=None, description="The completion cost of the LLM."
+    )
+    top_p: float | None = Field(
+        default=None, description="Sampling probability threshold."
+    )
+    n: int | None = Field(
+        default=None, description="Number of completions to generate."
+    )
+    max_completion_tokens: int | None = Field(
+        default=None,
+        description="Maximum number of tokens to generate in the completion.",
+    )
+    max_tokens: int | None = Field(
+        default=None,
+        description="Maximum number of tokens allowed in the prompt + completion.",
+    )
+    presence_penalty: float | None = Field(
+        default=None, description="Penalty on the presence penalty."
+    )
+    frequency_penalty: float | None = Field(
+        default=None, description="Penalty on the frequency penalty."
+    )
+    logit_bias: dict[int, float] | None = Field(
+        default=None,
+        description="Modifies the likelihood of specified tokens appearing in the completion.",
+    )
+    response_format: type[BaseModel] | None = Field(
+        default=None,
+        description="Pydantic model class for structured response parsing.",
+    )
+    seed: int | None = Field(
+        default=None,
+        description="Random seed for reproducibility.",
+    )
+    logprobs: int | None = Field(
+        default=None,
+        description="Number of top logprobs to return.",
+    )
+    top_logprobs: int | None = Field(
+        default=None,
+        description="Number of top logprobs to return.",
+    )
+    api_base: str | None = Field(
+        default=None,
+        description="Base URL for the API endpoint.",
+    )
+    api_version: str | None = Field(
+        default=None,
+        description="API version to use.",
+    )
+    callbacks: list[Any] = Field(
+        default_factory=list,
+        description="List of callback handlers for LLM events.",
+    )
+    reasoning_effort: Literal["none", "low", "medium", "high"] | None = Field(
+        default=None,
+        description="Level of reasoning effort for the LLM.",
+    )
+    context_window_size: int = Field(
+        default=0,
+        description="The context window size of the LLM.",
+    )
+    is_anthropic: bool = Field(
+        default=False,
+        description="Indicates if the model is from Anthropic provider.",
+    )
+    supports_function_calling: bool = Field(
+        default=False,
+        description="Indicates if the model supports function calling.",
+    )
+    supports_stop_words: bool = Field(
+        default=False,
+        description="Indicates if the model supports stop words.",
+    )
+
+    @model_validator(mode="after")
+    def initialize_client(self) -> Self:
+        self.is_anthropic = any(
+            prefix in self.model.lower() for prefix in ANTHROPIC_PREFIXES
+        )
+        try:
+            provider = self._get_custom_llm_provider()
+            self.supports_function_calling = litellm.utils.supports_function_calling(
+                self.model, custom_llm_provider=provider
+            )
+        except Exception as e:
+            logging.error(f"Failed to check function calling support: {e!s}")
+            self.supports_function_calling = False
+        try:
+            params = get_supported_openai_params(model=self.model)
+            self.supports_stop_words = params is not None and "stop" in params
+        except Exception as e:
+            logging.error(f"Failed to get supported params: {e!s}")
+            self.supports_stop_words = False
+
+        with suppress_warnings():
+            callback_types = [type(callback) for callback in self.callbacks]
+            for callback in litellm.success_callback[:]:
+                if type(callback) in callback_types:
+                    litellm.success_callback.remove(callback)
+
+            for callback in litellm._async_success_callback[:]:
+                if type(callback) in callback_types:
+                    litellm._async_success_callback.remove(callback)
+
+            litellm.callbacks = self.callbacks
+
+        with suppress_warnings():
+            success_callbacks_str = os.environ.get("LITELLM_SUCCESS_CALLBACKS", "")
+            success_callbacks: list[str | Callable[..., Any] | CustomLogger] = []
+            if success_callbacks_str:
+                success_callbacks = [
+                    cb.strip() for cb in success_callbacks_str.split(",") if cb.strip()
+                ]
+
+            failure_callbacks_str = os.environ.get("LITELLM_FAILURE_CALLBACKS", "")
+            if failure_callbacks_str:
+                failure_callbacks: list[str | Callable[..., Any] | CustomLogger] = [
+                    cb.strip() for cb in failure_callbacks_str.split(",") if cb.strip()
+                ]
+
+            litellm.success_callback = success_callbacks
+            litellm.failure_callback = failure_callbacks
+        return self
+
+    # @computed_field
+    # @property
+    # def is_anthropic(self) -> bool:
+    #     """Determine if the model is from Anthropic provider."""
+    #     anthropic_prefixes = ("anthropic/", "claude-", "claude/")
+    #     return any(prefix in self.model.lower() for prefix in anthropic_prefixes)

     def __new__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> LLM:
         """Factory method that routes to native SDK or falls back to LiteLLM."""
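For reference (not part of the diff): the refactor above moves constructor-time setup into a Pydantic `@model_validator(mode="after")` hook that runs once all declared fields are validated. A minimal, self-contained sketch of that pattern; the MiniLLM class and its fields are illustrative, not crewAI's actual model.

    # Hedged sketch of the mode="after" validator pattern used above.
    from pydantic import BaseModel, Field, model_validator
    from typing_extensions import Self

    class MiniLLM(BaseModel):
        model: str = Field(description="Model identifier, e.g. 'anthropic/claude-3-5-sonnet'")
        is_anthropic: bool = False

        @model_validator(mode="after")
        def initialize_client(self) -> Self:
            # Runs after field validation; derive attributes or build a client here.
            self.is_anthropic = any(p in self.model.lower() for p in ("anthropic/", "claude-"))
            return self

    print(MiniLLM(model="anthropic/claude-3-5-sonnet-20241022").is_anthropic)  # True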
@@ -383,98 +512,6 @@ class LLM(BaseLLM):

         return None

-    def __init__(
-        self,
-        model: str,
-        timeout: float | int | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        n: int | None = None,
-        stop: str | list[str] | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | float | None = None,
-        presence_penalty: float | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[int, float] | None = None,
-        response_format: type[BaseModel] | None = None,
-        seed: int | None = None,
-        logprobs: int | None = None,
-        top_logprobs: int | None = None,
-        base_url: str | None = None,
-        api_base: str | None = None,
-        api_version: str | None = None,
-        api_key: str | None = None,
-        callbacks: list[Any] | None = None,
-        reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
-        stream: bool = False,
-        interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
-        **kwargs: Any,
-    ) -> None:
-        """Initialize LLM instance.
-
-        Note: This __init__ method is only called for fallback instances.
-        Native provider instances handle their own initialization in their respective classes.
-        """
-        super().__init__(
-            model=model,
-            temperature=temperature,
-            api_key=api_key,
-            base_url=base_url,
-            timeout=timeout,
-            **kwargs,
-        )
-        self.model = model
-        self.timeout = timeout
-        self.temperature = temperature
-        self.top_p = top_p
-        self.n = n
-        self.max_completion_tokens = max_completion_tokens
-        self.max_tokens = max_tokens
-        self.presence_penalty = presence_penalty
-        self.frequency_penalty = frequency_penalty
-        self.logit_bias = logit_bias
-        self.response_format = response_format
-        self.seed = seed
-        self.logprobs = logprobs
-        self.top_logprobs = top_logprobs
-        self.base_url = base_url
-        self.api_base = api_base
-        self.api_version = api_version
-        self.api_key = api_key
-        self.callbacks = callbacks
-        self.context_window_size = 0
-        self.reasoning_effort = reasoning_effort
-        self.additional_params = kwargs
-        self.is_anthropic = self._is_anthropic_model(model)
-        self.stream = stream
-        self.interceptor = interceptor
-
-        litellm.drop_params = True
-
-        # Normalize self.stop to always be a list[str]
-        if stop is None:
-            self.stop: list[str] = []
-        elif isinstance(stop, str):
-            self.stop = [stop]
-        else:
-            self.stop = stop
-
-        self.set_callbacks(callbacks or [])
-        self.set_env_callbacks()
-
-    @staticmethod
-    def _is_anthropic_model(model: str) -> bool:
-        """Determine if the model is from Anthropic provider.
-
-        Args:
-            model: The model identifier string.
-
-        Returns:
-            bool: True if the model is from Anthropic, False otherwise.
-        """
-        anthropic_prefixes = ("anthropic/", "claude-", "claude/")
-        return any(prefix in model.lower() for prefix in anthropic_prefixes)
-
     def _prepare_completion_params(
         self,
         messages: str | list[LLMMessage],
@@ -1188,8 +1225,6 @@ class LLM(BaseLLM):
                 message["role"] = msg_role
-        # --- 5) Set up callbacks if provided
-        with suppress_warnings():
-            if callbacks and len(callbacks) > 0:
-                self.set_callbacks(callbacks)
         try:
             # --- 6) Prepare parameters for the completion call
             params = self._prepare_completion_params(messages, tools)
@@ -1378,24 +1413,6 @@ class LLM(BaseLLM):
                 "Please remove response_format or use a supported model."
             )

-    def supports_function_calling(self) -> bool:
-        try:
-            provider = self._get_custom_llm_provider()
-            return litellm.utils.supports_function_calling(
-                self.model, custom_llm_provider=provider
-            )
-        except Exception as e:
-            logging.error(f"Failed to check function calling support: {e!s}")
-            return False
-
-    def supports_stop_words(self) -> bool:
-        try:
-            params = get_supported_openai_params(model=self.model)
-            return params is not None and "stop" in params
-        except Exception as e:
-            logging.error(f"Failed to get supported params: {e!s}")
-            return False
-
     def get_context_window_size(self) -> int:
         """
         Returns the context window size, using 75% of the maximum to avoid
@@ -1425,60 +1442,6 @@ class LLM(BaseLLM):
         self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
         return self.context_window_size

-    @staticmethod
-    def set_callbacks(callbacks: list[Any]) -> None:
-        """
-        Attempt to keep a single set of callbacks in litellm by removing old
-        duplicates and adding new ones.
-        """
-        with suppress_warnings():
-            callback_types = [type(callback) for callback in callbacks]
-            for callback in litellm.success_callback[:]:
-                if type(callback) in callback_types:
-                    litellm.success_callback.remove(callback)
-
-            for callback in litellm._async_success_callback[:]:
-                if type(callback) in callback_types:
-                    litellm._async_success_callback.remove(callback)
-
-            litellm.callbacks = callbacks
-
-    @staticmethod
-    def set_env_callbacks() -> None:
-        """Sets the success and failure callbacks for the LiteLLM library from environment variables.
-
-        This method reads the `LITELLM_SUCCESS_CALLBACKS` and `LITELLM_FAILURE_CALLBACKS`
-        environment variables, which should contain comma-separated lists of callback names.
-        It then assigns these lists to `litellm.success_callback` and `litellm.failure_callback`,
-        respectively.
-
-        If the environment variables are not set or are empty, the corresponding callback lists
-        will be set to empty lists.
-
-        Examples:
-            LITELLM_SUCCESS_CALLBACKS="langfuse,langsmith"
-            LITELLM_FAILURE_CALLBACKS="langfuse"
-
-        This will set `litellm.success_callback` to ["langfuse", "langsmith"] and
-        `litellm.failure_callback` to ["langfuse"].
-        """
-        with suppress_warnings():
-            success_callbacks_str = os.environ.get("LITELLM_SUCCESS_CALLBACKS", "")
-            success_callbacks: list[str | Callable[..., Any] | CustomLogger] = []
-            if success_callbacks_str:
-                success_callbacks = [
-                    cb.strip() for cb in success_callbacks_str.split(",") if cb.strip()
-                ]
-
-            failure_callbacks_str = os.environ.get("LITELLM_FAILURE_CALLBACKS", "")
-            if failure_callbacks_str:
-                failure_callbacks: list[str | Callable[..., Any] | CustomLogger] = [
-                    cb.strip() for cb in failure_callbacks_str.split(",") if cb.strip()
-                ]
-
-            litellm.success_callback = success_callbacks
-            litellm.failure_callback = failure_callbacks
-
     def __copy__(self) -> LLM:
         """Create a shallow copy of the LLM instance."""
         # Filter out parameters that are already explicitly passed to avoid conflicts
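For reference (not part of the diff): the removed set_env_callbacks helper, whose logic now lives in the initialize_client validator, parses comma-separated callback names from environment variables. A small illustrative sketch of that parsing; the environment value is an example taken from the docstring above.

    # Hedged sketch of the env-var parsing shown in the removed method.
    import os

    os.environ["LITELLM_SUCCESS_CALLBACKS"] = "langfuse, langsmith"
    raw = os.environ.get("LITELLM_SUCCESS_CALLBACKS", "")
    success_callbacks = [cb.strip() for cb in raw.split(",") if cb.strip()]
    print(success_callbacks)  # ['langfuse', 'langsmith']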
@@ -1539,7 +1502,7 @@ class LLM(BaseLLM):
             **filtered_params,
         )

-    def __deepcopy__(self, memo: dict[int, Any] | None) -> LLM:
+    def __deepcopy__(self, memo: dict[int, Any] | None) -> LLM:  # type: ignore[override]
         """Create a deep copy of the LLM instance."""
         import copy

@@ -13,8 +13,9 @@ import logging
 import re
 from typing import TYPE_CHECKING, Any, Final

-from pydantic import BaseModel
+from pydantic import AliasChoices, BaseModel, Field, PrivateAttr, field_validator

+from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
 from crewai.events.event_bus import crewai_event_bus
 from crewai.events.types.llm_events import (
     LLMCallCompletedEvent,
@@ -28,6 +29,7 @@ from crewai.events.types.tool_usage_events import (
     ToolUsageFinishedEvent,
     ToolUsageStartedEvent,
 )
+from crewai.llms.hooks import BaseInterceptor
 from crewai.types.usage_metrics import UsageMetrics


@@ -43,7 +45,7 @@ DEFAULT_SUPPORTS_STOP_WORDS: Final[bool] = True
 _JSON_EXTRACTION_PATTERN: Final[re.Pattern[str]] = re.compile(r"\{.*}", re.DOTALL)


-class BaseLLM(ABC):
+class BaseLLM(BaseModel, ABC):
     """Abstract base class for LLM implementations.

     This class defines the interface that all LLM implementations must follow.
@@ -55,70 +57,105 @@ class BaseLLM(ABC):
     implement proper validation for input parameters and provide clear error
     messages when things go wrong.

     Attributes:
         model: The model identifier/name.
         temperature: Optional temperature setting for response generation.
         stop: A list of stop sequences that the LLM should use to stop generation.
         additional_params: Additional provider-specific parameters.
     """

-    is_litellm: bool = False
+    provider: str | re.Pattern[str] = Field(
+        default="openai", description="The provider of the LLM."
+    )
+    model: str = Field(description="The model identifier/name.")
+    temperature: float | None = Field(
+        default=None, ge=0, le=2, description="Temperature for response generation."
+    )
+    api_key: str | None = Field(default=None, description="API key for authentication.")
+    base_url: str | None = Field(default=None, description="Base URL for API calls.")
+    timeout: float | None = Field(default=None, description="Timeout for API calls.")
+    max_retries: int = Field(
+        default=2, description="Maximum number of API requests to make."
+    )
+    max_tokens: int | None = Field(
+        default=None, description="Maximum tokens for response generation."
+    )
+    stream: bool | None = Field(default=False, description="Stream the API requests.")
+    client: Any = Field(description="Underlying LLM client instance.")
+    interceptor: BaseInterceptor[Any, Any] | None = Field(
+        default=None,
+        description="An optional HTTPX interceptor for modifying requests/responses.",
+    )
+    client_params: dict[str, Any] = Field(
+        default_factory=dict,
+        description="Additional parameters for the underlying LLM client.",
+    )
+    supports_stop_words: bool = Field(
+        default=DEFAULT_SUPPORTS_STOP_WORDS,
+        description="Whether or not to support stop words.",
+    )
+    stop_sequences: list[str] = Field(
+        default_factory=list,
+        validation_alias=AliasChoices("stop_sequences", "stop"),
+        description="Stop sequences for generation (synchronized with stop).",
+    )
+    is_litellm: bool = Field(
+        default=False, description="Is this LLM implementation in litellm?"
+    )
+    additional_params: dict[str, Any] = Field(
+        default_factory=dict,
+        description="Additional parameters for LLM calls.",
+    )
+    _token_usage: TokenProcess = PrivateAttr(default_factory=TokenProcess)

-    def __init__(
-        self,
-        model: str,
-        temperature: float | None = None,
-        api_key: str | None = None,
-        base_url: str | None = None,
-        provider: str | None = None,
-        **kwargs: Any,
-    ) -> None:
-        """Initialize the BaseLLM with default attributes.
-
-        Args:
-            model: The model identifier/name.
-            temperature: Optional temperature setting for response generation.
-            stop: Optional list of stop sequences for generation.
-            **kwargs: Additional provider-specific parameters.
-        """
-        if not model:
-            raise ValueError("Model name is required and cannot be empty")
-
-        self.model = model
-        self.temperature = temperature
-        self.api_key = api_key
-        self.base_url = base_url
-        # Store additional parameters for provider-specific use
-        self.additional_params = kwargs
-        self._provider = provider or "openai"
-
-        stop = kwargs.pop("stop", None)
-        if stop is None:
-            self.stop: list[str] = []
-        elif isinstance(stop, str):
-            self.stop = [stop]
-        elif isinstance(stop, list):
-            self.stop = stop
-        else:
-            self.stop = []
-
-        self._token_usage = {
-            "total_tokens": 0,
-            "prompt_tokens": 0,
-            "completion_tokens": 0,
-            "successful_requests": 0,
-            "cached_prompt_tokens": 0,
-        }
+    @field_validator("provider", mode="before")
+    @classmethod
+    def extract_provider_from_model(
+        cls, v: str | re.Pattern[str] | None, info: Any
+    ) -> str | re.Pattern[str]:
+        """Extract provider from model string if not explicitly provided.
+
+        Args:
+            v: Provided provider value (can be str, Pattern, or None)
+            info: Validation info containing other field values
+
+        Returns:
+            Provider name (str) or Pattern
+        """
+        # If provider explicitly provided, validate and return it
+        if v is not None:
+            if not isinstance(v, (str, re.Pattern)):
+                raise ValueError(f"Provider must be str or Pattern, got {type(v)}")
+            return v
+
+        model: str = info.data.get("model", "")
+        if "/" in model:
+            return model.partition("/")[0]
+        return "openai"
+
+    @field_validator("stop_sequences", mode="before")
+    @classmethod
+    def normalize_stop_sequences(
+        cls, v: str | list[str] | set[str] | None
+    ) -> list[str]:
+        """Validate and normalize stop sequences.
+
+        Converts string to list and handles None values.
+        AliasChoices handles accepting both 'stop' and 'stop_sequences' parameter names.
+        """
+        if v is None:
+            return []
+        if isinstance(v, str):
+            return [v]
+        if isinstance(v, set):
+            return list(v)
+        if isinstance(v, list):
+            return v
+        return []

     @property
-    def provider(self) -> str:
-        """Get the provider of the LLM."""
-        return self._provider
-
-    @provider.setter
-    def provider(self, value: str) -> None:
-        """Set the provider of the LLM."""
-        self._provider = value
+    def stop(self) -> list[str]:
+        """Alias for stop_sequences to maintain backward compatibility."""
+        return self.stop_sequences

     @abstractmethod
     def call(
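For reference (not part of the diff): the combination of validation_alias=AliasChoices("stop_sequences", "stop") and the mode="before" validator above lets callers pass either keyword and normalizes a bare string into a list. A self-contained sketch of that behavior; the StopConfig model is illustrative, not a crewAI class.

    # Hedged sketch of the alias + normalizer combination used above.
    from pydantic import AliasChoices, BaseModel, Field, field_validator

    class StopConfig(BaseModel):
        stop_sequences: list[str] = Field(
            default_factory=list,
            validation_alias=AliasChoices("stop_sequences", "stop"),
        )

        @field_validator("stop_sequences", mode="before")
        @classmethod
        def normalize(cls, v):
            if v is None:
                return []
            return [v] if isinstance(v, str) else list(v)

    print(StopConfig(stop="Observation:").stop_sequences)             # ['Observation:']
    print(StopConfig(stop_sequences=["END", "STOP"]).stop_sequences)  # ['END', 'STOP']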
@@ -171,14 +208,6 @@ class BaseLLM(ABC):
         """
         return tools

-    def supports_stop_words(self) -> bool:
-        """Check if the LLM supports stop words.
-
-        Returns:
-            True if the LLM supports stop words, False otherwise.
-        """
-        return DEFAULT_SUPPORTS_STOP_WORDS
-
     def _supports_stop_words_implementation(self) -> bool:
         """Check if stop words are configured for this LLM instance.

@@ -506,7 +535,7 @@ class BaseLLM(ABC):
         """
         if "/" in model:
             return model.partition("/")[0]
-        return "openai"  # Default provider
+        return "openai"

     def _track_token_usage_internal(self, usage_data: dict[str, Any]) -> None:
         """Track token usage internally in the LLM instance.
@@ -535,11 +564,11 @@ class BaseLLM(ABC):
             or 0
         )

-        self._token_usage["prompt_tokens"] += prompt_tokens
-        self._token_usage["completion_tokens"] += completion_tokens
-        self._token_usage["total_tokens"] += prompt_tokens + completion_tokens
-        self._token_usage["successful_requests"] += 1
-        self._token_usage["cached_prompt_tokens"] += cached_tokens
+        self._token_usage.prompt_tokens += prompt_tokens
+        self._token_usage.completion_tokens += completion_tokens
+        self._token_usage.total_tokens += prompt_tokens + completion_tokens
+        self._token_usage.successful_requests += 1
+        self._token_usage.cached_prompt_tokens += cached_tokens

     def get_token_usage_summary(self) -> UsageMetrics:
         """Get summary of token usage for this LLM instance.
@@ -547,4 +576,10 @@ class BaseLLM(ABC):
         Returns:
             Dictionary with token usage totals
         """
-        return UsageMetrics(**self._token_usage)
+        return UsageMetrics(
+            prompt_tokens=self._token_usage.prompt_tokens,
+            completion_tokens=self._token_usage.completion_tokens,
+            total_tokens=self._token_usage.total_tokens,
+            successful_requests=self._token_usage.successful_requests,
+            cached_prompt_tokens=self._token_usage.cached_prompt_tokens,
+        )

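For reference (not part of the diff): token accounting moves from a plain dict to attribute access on a TokenProcess private attribute. A simplified sketch of that attribute-based bookkeeping; the dataclass below is a stand-in for crewAI's TokenProcess, and the token counts are arbitrary.

    # Hedged sketch of attribute-based token accounting.
    from dataclasses import dataclass

    @dataclass
    class TokenProcessSketch:
        prompt_tokens: int = 0
        completion_tokens: int = 0
        total_tokens: int = 0
        successful_requests: int = 0

    usage = TokenProcessSketch()
    usage.prompt_tokens += 120
    usage.completion_tokens += 48
    usage.total_tokens += 120 + 48
    usage.successful_requests += 1
    print(usage)  # TokenProcessSketch(prompt_tokens=120, completion_tokens=48, total_tokens=168, successful_requests=1)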
@@ -5,11 +5,14 @@ import logging
 import os
 from typing import TYPE_CHECKING, Any, cast

-from pydantic import BaseModel
+from pydantic import BaseModel, Field, PrivateAttr, computed_field, model_validator
+from typing_extensions import Self

 from crewai.events.types.llm_events import LLMCallType
+from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO
 from crewai.llms.base_llm import BaseLLM
 from crewai.llms.hooks.transport import HTTPTransport
+from crewai.llms.providers.utils.common import safe_tool_conversion
 from crewai.utilities.agent_utils import is_context_length_exceeded
 from crewai.utilities.exceptions.context_window_exceeding_exception import (
     LLMContextLengthExceededError,
@@ -18,7 +21,8 @@ from crewai.utilities.types import LLMMessage


 if TYPE_CHECKING:
-    from crewai.llms.hooks.base import BaseInterceptor
+    from crewai.agent import Agent
+    from crewai.task import Task

 try:
     from anthropic import Anthropic
@@ -31,6 +35,19 @@ except ImportError:
     ) from None


+ANTHROPIC_CONTEXT_WINDOWS: dict[str, int] = {
+    "claude-3-5-sonnet": 200000,
+    "claude-3-5-haiku": 200000,
+    "claude-3-opus": 200000,
+    "claude-3-sonnet": 200000,
+    "claude-3-haiku": 200000,
+    "claude-3-7-sonnet": 200000,
+    "claude-2.1": 200000,
+    "claude-2": 100000,
+    "claude-instant": 100000,
+}
+
+
 class AnthropicCompletion(BaseLLM):
     """Anthropic native completion implementation.

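For reference (not part of the diff): get_context_window_size, shown further below, multiplies the matching entry of this table by CONTEXT_WINDOW_USAGE_RATIO. Per the llm.py docstring shown earlier ("using 75% of the maximum"), that ratio is assumed here to be 0.75; a worked arithmetic sketch under that assumption:

    # Hedged sketch: effective window for a claude-3-5-sonnet model, assuming
    # CONTEXT_WINDOW_USAGE_RATIO == 0.75 (the "75% of the maximum" described in llm.py).
    CONTEXT_WINDOW_USAGE_RATIO = 0.75
    raw_window = 200000  # ANTHROPIC_CONTEXT_WINDOWS["claude-3-5-sonnet"]
    print(int(raw_window * CONTEXT_WINDOW_USAGE_RATIO))  # 150000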
@@ -38,110 +55,69 @@ class AnthropicCompletion(BaseLLM):
     offering native tool use, streaming support, and proper message formatting.
     """

-    def __init__(
-        self,
-        model: str = "claude-3-5-sonnet-20241022",
-        api_key: str | None = None,
-        base_url: str | None = None,
-        timeout: float | None = None,
-        max_retries: int = 2,
-        temperature: float | None = None,
-        max_tokens: int = 4096,  # Required for Anthropic
-        top_p: float | None = None,
-        stop_sequences: list[str] | None = None,
-        stream: bool = False,
-        client_params: dict[str, Any] | None = None,
-        interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
-        **kwargs: Any,
-    ):
-        """Initialize Anthropic chat completion client.
-
-        Args:
-            model: Anthropic model name (e.g., 'claude-3-5-sonnet-20241022')
-            api_key: Anthropic API key (defaults to ANTHROPIC_API_KEY env var)
-            base_url: Custom base URL for Anthropic API
-            timeout: Request timeout in seconds
-            max_retries: Maximum number of retries
-            temperature: Sampling temperature (0-1)
-            max_tokens: Maximum tokens in response (required for Anthropic)
-            top_p: Nucleus sampling parameter
-            stop_sequences: Stop sequences (Anthropic uses stop_sequences, not stop)
-            stream: Enable streaming responses
-            client_params: Additional parameters for the Anthropic client
-            interceptor: HTTP interceptor for modifying requests/responses at transport level.
-            **kwargs: Additional parameters
-        """
-        super().__init__(
-            model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
-        )
-
-        # Client params
-        self.interceptor = interceptor
-        self.client_params = client_params
-        self.base_url = base_url
-        self.timeout = timeout
-        self.max_retries = max_retries
-
-        self.client = Anthropic(**self._get_client_params())
-
-        # Store completion parameters
-        self.max_tokens = max_tokens
-        self.top_p = top_p
-        self.stream = stream
-        self.stop_sequences = stop_sequences or []
-
-        # Model-specific settings
-        self.is_claude_3 = "claude-3" in model.lower()
-        self.supports_tools = self.is_claude_3  # Claude 3+ supports tool use
-
-    @property
-    def stop(self) -> list[str]:
-        """Get stop sequences sent to the API."""
-        return self.stop_sequences
-
-    @stop.setter
-    def stop(self, value: list[str] | str | None) -> None:
-        """Set stop sequences.
-
-        Synchronizes stop_sequences to ensure values set by CrewAgentExecutor
-        are properly sent to the Anthropic API.
-
-        Args:
-            value: Stop sequences as a list, single string, or None
-        """
-        if value is None:
-            self.stop_sequences = []
-        elif isinstance(value, str):
-            self.stop_sequences = [value]
-        elif isinstance(value, list):
-            self.stop_sequences = value
-        else:
-            self.stop_sequences = []
+    model: str = Field(
+        default="claude-3-5-sonnet-20241022",
+        description="Anthropic model name (e.g., 'claude-3-5-sonnet-20241022')",
+    )
+    max_tokens: int = Field(
+        default=4096,
+        description="Maximum number of allowed tokens in response.",
+    )
+    top_p: float | None = Field(
+        default=None,
+        description="Nucleus sampling parameter.",
+    )
+    _client: Anthropic = PrivateAttr(
+        default_factory=Anthropic,
+    )
+
+    @model_validator(mode="after")
+    def initialize_client(self) -> Self:
+        """Initialize the Anthropic client after Pydantic validation.
+
+        This runs after all field validation is complete, ensuring that:
+        - All BaseLLM fields are set (model, temperature, stop_sequences, etc.)
+        - Field validators have run (stop_sequences is normalized to set[str])
+        - API key and other configuration is ready
+        """

     def _get_client_params(self) -> dict[str, Any]:
         """Get client parameters."""

         if self.api_key is None:
             self.api_key = os.getenv("ANTHROPIC_API_KEY")
             if self.api_key is None:
                 raise ValueError("ANTHROPIC_API_KEY is required")

-        client_params = {
-            "api_key": self.api_key,
-            "base_url": self.base_url,
-            "timeout": self.timeout,
-            "max_retries": self.max_retries,
-        }
+        params = self.model_dump(
+            include={"api_key", "base_url", "timeout", "max_retries"},
+            exclude_none=True,
+        )

         if self.interceptor:
             transport = HTTPTransport(interceptor=self.interceptor)
             http_client = httpx.Client(transport=transport)
-            client_params["http_client"] = http_client  # type: ignore[assignment]
+            params["http_client"] = http_client

         if self.client_params:
-            client_params.update(self.client_params)
+            params.update(self.client_params)

-        return client_params
+        self._client = Anthropic(**params)
+        return self
+
+    @computed_field  # type: ignore[prop-decorator]
+    @property
+    def is_claude_3(self) -> bool:
+        """Check if the model is Claude 3 or higher."""
+        return "claude-3" in self.model.lower()
+
+    @computed_field  # type: ignore[prop-decorator]
+    @property
+    def supports_tools(self) -> bool:
+        """Check if the model supports tool use."""
+        return self.is_claude_3
+
+    @computed_field  # type: ignore[prop-decorator]
+    @property
+    def supports_function_calling(self) -> bool:
+        """Check if the model supports function calling."""
+        return self.supports_tools

     def call(
         self,
@@ -149,8 +125,8 @@ class AnthropicCompletion(BaseLLM):
         tools: list[dict[str, Any]] | None = None,
         callbacks: list[Any] | None = None,
         available_functions: dict[str, Any] | None = None,
-        from_task: Any | None = None,
-        from_agent: Any | None = None,
+        from_task: Task | None = None,
+        from_agent: Agent | None = None,
         response_model: type[BaseModel] | None = None,
     ) -> str | Any:
         """Call Anthropic messages API.
|
||||
Returns:
|
||||
Parameters dictionary for Anthropic API
|
||||
"""
|
||||
params = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"max_tokens": self.max_tokens,
|
||||
"stream": self.stream,
|
||||
}
|
||||
|
||||
params = self.model_dump(
|
||||
include={
|
||||
"model",
|
||||
"max_tokens",
|
||||
"stream",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"stop_sequences",
|
||||
},
|
||||
)
|
||||
params["messages"] = messages
|
||||
# Add system message if present
|
||||
if system_message:
|
||||
params["system"] = system_message
|
||||
|
||||
# Add optional parameters if set
|
||||
if self.temperature is not None:
|
||||
params["temperature"] = self.temperature
|
||||
if self.top_p is not None:
|
||||
params["top_p"] = self.top_p
|
||||
if self.stop_sequences:
|
||||
params["stop_sequences"] = self.stop_sequences
|
||||
|
||||
# Handle tools for Claude 3+
|
||||
if tools and self.supports_tools:
|
||||
params["tools"] = self._convert_tools_for_interference(tools)
|
||||
@@ -266,8 +238,6 @@ class AnthropicCompletion(BaseLLM):
|
||||
continue
|
||||
|
||||
try:
|
||||
from crewai.llms.providers.utils.common import safe_tool_conversion
|
||||
|
||||
name, description, parameters = safe_tool_conversion(tool, "Anthropic")
|
||||
except (ImportError, KeyError, ValueError) as e:
|
||||
logging.error(f"Error converting tool to Anthropic format: {e}")
|
||||
@@ -341,8 +311,8 @@ class AnthropicCompletion(BaseLLM):
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle non-streaming message completion."""
|
||||
@@ -357,7 +327,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
||||
|
||||
try:
|
||||
response: Message = self.client.messages.create(**params)
|
||||
response: Message = self._client.messages.create(**params)
|
||||
|
||||
except Exception as e:
|
||||
if is_context_length_exceeded(e):
|
||||
@@ -429,8 +399,8 @@ class AnthropicCompletion(BaseLLM):
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str:
|
||||
"""Handle streaming message completion."""
|
||||
@@ -451,7 +421,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
stream_params = {k: v for k, v in params.items() if k != "stream"}
|
||||
|
||||
# Make streaming API call
|
||||
with self.client.messages.stream(**stream_params) as stream:
|
||||
with self._client.messages.stream(**stream_params) as stream:
|
||||
for event in stream:
|
||||
if hasattr(event, "delta") and hasattr(event.delta, "text"):
|
||||
text_delta = event.delta.text
|
||||
@@ -525,8 +495,8 @@ class AnthropicCompletion(BaseLLM):
|
||||
tool_uses: list[ToolUseBlock],
|
||||
params: dict[str, Any],
|
||||
available_functions: dict[str, Any],
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | None = None,
|
||||
) -> str:
|
||||
"""Handle the complete tool use conversation flow.
|
||||
|
||||
@@ -579,7 +549,7 @@ class AnthropicCompletion(BaseLLM):

         try:
             # Send tool results back to Claude for final response
-            final_response: Message = self.client.messages.create(**follow_up_params)
+            final_response: Message = self._client.messages.create(**follow_up_params)

             # Track token usage for follow-up call
             follow_up_usage = self._extract_anthropic_token_usage(final_response)
|
||||
return tool_results[0]["content"]
|
||||
raise e
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Check if the model supports function calling."""
|
||||
return self.supports_tools
|
||||
|
||||
def supports_stop_words(self) -> bool:
|
||||
"""Check if the model supports stop words."""
|
||||
return True # All Claude models support stop sequences
|
||||
|
||||
def get_context_window_size(self) -> int:
|
||||
"""Get the context window size for the model."""
|
||||
from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO
|
||||
|
||||
# Context window sizes for Anthropic models
|
||||
context_windows = {
|
||||
"claude-3-5-sonnet": 200000,
|
||||
"claude-3-5-haiku": 200000,
|
||||
"claude-3-opus": 200000,
|
||||
"claude-3-sonnet": 200000,
|
||||
"claude-3-haiku": 200000,
|
||||
"claude-3-7-sonnet": 200000,
|
||||
"claude-2.1": 200000,
|
||||
"claude-2": 100000,
|
||||
"claude-instant": 100000,
|
||||
}
|
||||
|
||||
# Find the best match for the model name
|
||||
for model_prefix, size in context_windows.items():
|
||||
for model_prefix, size in ANTHROPIC_CONTEXT_WINDOWS.items():
|
||||
if self.model.startswith(model_prefix):
|
||||
return int(size * CONTEXT_WINDOW_USAGE_RATIO)
|
||||
|
||||
# Default context window size for Claude models
|
||||
return int(200000 * CONTEXT_WINDOW_USAGE_RATIO)
|
||||
|
||||
def _extract_anthropic_token_usage(self, response: Message) -> dict[str, Any]:
|
||||
@staticmethod
|
||||
def _extract_anthropic_token_usage(response: Message) -> dict[str, Any]:
|
||||
"""Extract token usage from Anthropic response."""
|
||||
if hasattr(response, "usage") and response.usage:
|
||||
if response.usage:
|
||||
usage = response.usage
|
||||
input_tokens = getattr(usage, "input_tokens", 0)
|
||||
output_tokens = getattr(usage, "output_tokens", 0)
|
||||
return {
|
||||
"input_tokens": input_tokens,
|
||||
"output_tokens": output_tokens,
|
||||
"total_tokens": input_tokens + output_tokens,
|
||||
"input_tokens": usage.input_tokens,
|
||||
"output_tokens": usage.output_tokens,
|
||||
"total_tokens": usage.input_tokens + usage.output_tokens,
|
||||
}
|
||||
return {"total_tokens": 0}
|
||||
|
||||
@@ -1,12 +1,14 @@
-import logging
 import os
-from typing import Any, cast
+from __future__ import annotations

-from pydantic import BaseModel
+import logging
+from typing import TYPE_CHECKING, Any, cast
+
+from pydantic import BaseModel, Field, PrivateAttr, computed_field, model_validator
+from typing_extensions import Self

 from crewai.events.types.llm_events import LLMCallType
+from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
 from crewai.llms.base_llm import BaseLLM
 from crewai.llms.hooks.base import BaseInterceptor
 from crewai.utilities.agent_utils import is_context_length_exceeded
 from crewai.utilities.exceptions.context_window_exceeding_exception import (
     LLMContextLengthExceededError,
@@ -14,6 +16,11 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
 from crewai.utilities.types import LLMMessage


+if TYPE_CHECKING:
+    from crewai.agent import Agent
+    from crewai.task import Task
+
+
 try:
     from google import genai  # type: ignore[import-untyped]
     from google.genai import types  # type: ignore[import-untyped]
@@ -24,6 +31,27 @@ except ImportError:
     ) from None


+GEMINI_CONTEXT_WINDOWS: dict[str, int] = {
+    "gemini-2.0-flash": 1048576,  # 1M tokens
+    "gemini-2.0-flash-thinking": 32768,
+    "gemini-2.0-flash-lite": 1048576,
+    "gemini-2.5-flash": 1048576,
+    "gemini-2.5-pro": 1048576,
+    "gemini-1.5-pro": 2097152,  # 2M tokens
+    "gemini-1.5-flash": 1048576,
+    "gemini-1.5-flash-8b": 1048576,
+    "gemini-1.0-pro": 32768,
+    "gemma-3-1b": 32000,
+    "gemma-3-4b": 128000,
+    "gemma-3-12b": 128000,
+    "gemma-3-27b": 128000,
+}
+
+# Context window validation constraints
+MIN_CONTEXT_WINDOW: int = 1024
+MAX_CONTEXT_WINDOW: int = 2097152
+
+
 class GeminiCompletion(BaseLLM):
     """Google Gemini native completion implementation.

@@ -31,78 +59,140 @@ class GeminiCompletion(BaseLLM):
     offering native function calling, streaming support, and proper Gemini formatting.
     """

-    def __init__(
-        self,
-        model: str = "gemini-2.0-flash-001",
-        api_key: str | None = None,
-        project: str | None = None,
-        location: str | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        top_k: int | None = None,
-        max_output_tokens: int | None = None,
-        stop_sequences: list[str] | None = None,
-        stream: bool = False,
-        safety_settings: dict[str, Any] | None = None,
-        client_params: dict[str, Any] | None = None,
-        interceptor: BaseInterceptor[Any, Any] | None = None,
-        **kwargs: Any,
-    ):
-        """Initialize Google Gemini chat completion client.
-
-        Args:
-            model: Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro')
-            api_key: Google API key (defaults to GOOGLE_API_KEY or GEMINI_API_KEY env var)
-            project: Google Cloud project ID (for Vertex AI)
-            location: Google Cloud location (for Vertex AI, defaults to 'us-central1')
-            temperature: Sampling temperature (0-2)
-            top_p: Nucleus sampling parameter
-            top_k: Top-k sampling parameter
-            max_output_tokens: Maximum tokens in response
-            stop_sequences: Stop sequences
-            stream: Enable streaming responses
-            safety_settings: Safety filter settings
-            client_params: Additional parameters to pass to the Google Gen AI Client constructor.
-                Supports parameters like http_options, credentials, debug_config, etc.
-            interceptor: HTTP interceptor (not yet supported for Gemini).
-            **kwargs: Additional parameters
-        """
-        if interceptor is not None:
-            raise NotImplementedError(
-                "HTTP interceptors are not yet supported for Google Gemini provider. "
-                "Interceptors are currently supported for OpenAI and Anthropic providers only."
-            )
-
-        super().__init__(
-            model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
-        )
-
-        # Store client params for later use
-        self.client_params = client_params or {}
-
-        # Get API configuration with environment variable fallbacks
-        self.api_key = (
-            api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
-        )
-        self.project = project or os.getenv("GOOGLE_CLOUD_PROJECT")
-        self.location = location or os.getenv("GOOGLE_CLOUD_LOCATION") or "us-central1"
-
-        use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"
-
-        self.client = self._initialize_client(use_vertexai)
-
-        # Store completion parameters
-        self.top_p = top_p
-        self.top_k = top_k
-        self.max_output_tokens = max_output_tokens
-        self.stream = stream
-        self.safety_settings = safety_settings or {}
-        self.stop_sequences = stop_sequences or []
-
-        # Model-specific settings
-        self.is_gemini_2 = "gemini-2" in model.lower()
-        self.is_gemini_1_5 = "gemini-1.5" in model.lower()
-        self.supports_tools = self.is_gemini_1_5 or self.is_gemini_2
+    model: str = Field(
+        default="gemini-2.0-flash-001",
+        description="Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro')",
+    )
+    project: str | None = Field(
+        default=None,
+        description="Google Cloud project ID (for Vertex AI)",
+    )
+    location: str = Field(
+        default="us-central1",
+        description="Google Cloud location (for Vertex AI)",
+    )
+    top_p: float | None = Field(
+        default=None,
+        description="Nucleus sampling parameter",
+    )
+    top_k: int | None = Field(
+        default=None,
+        description="Top-k sampling parameter",
+    )
+    max_output_tokens: int | None = Field(
+        default=None,
+        description="Maximum tokens in response",
+    )
+    safety_settings: dict[str, Any] | None = Field(
+        default=None,
+        description="Safety filter settings",
+    )
+    _client: genai.Client = PrivateAttr(  # type: ignore[no-any-unimported]
+        default_factory=genai.Client,
+    )
+
+    @model_validator(mode="after")
+    def initialize_client(self) -> Self:
+        """Initialize the Anthropic client after Pydantic validation.
+
+        This runs after all field validation is complete, ensuring that:
+        - All BaseLLM fields are set (model, temperature, stop_sequences, etc.)
+        - Field validators have run (stop_sequences is normalized to set[str])
+        - API key and other configuration is ready
+        """
+        self._client = genai.Client(**self._get_client_params())
+        return self
+
+    # def __init__(
+    #     self,
+    #     model: str = "gemini-2.0-flash-001",
+    #     api_key: str | None = None,
+    #     project: str | None = None,
+    #     location: str | None = None,
+    #     temperature: float | None = None,
+    #     top_p: float | None = None,
+    #     top_k: int | None = None,
+    #     max_output_tokens: int | None = None,
+    #     stop_sequences: list[str] | None = None,
+    #     stream: bool = False,
+    #     safety_settings: dict[str, Any] | None = None,
+    #     client_params: dict[str, Any] | None = None,
+    #     interceptor: BaseInterceptor[Any, Any] | None = None,
+    #     **kwargs: Any,
+    # # ):
+    #     """Initialize Google Gemini chat completion client.
+    #
+    #     Args:
+    #         model: Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro')
+    #         api_key: Google API key (defaults to GOOGLE_API_KEY or GEMINI_API_KEY env var)
+    #         project: Google Cloud project ID (for Vertex AI)
+    #         location: Google Cloud location (for Vertex AI, defaults to 'us-central1')
+    #         temperature: Sampling temperature (0-2)
+    #         top_p: Nucleus sampling parameter
+    #         top_k: Top-k sampling parameter
+    #         max_output_tokens: Maximum tokens in response
+    #         stop_sequences: Stop sequences
+    #         stream: Enable streaming responses
+    #         safety_settings: Safety filter settings
+    #         client_params: Additional parameters to pass to the Google Gen AI Client constructor.
+    #             Supports parameters like http_options, credentials, debug_config, etc.
+    #         interceptor: HTTP interceptor (not yet supported for Gemini).
+    #         **kwargs: Additional parameters
+    #     """
+    #     if interceptor is not None:
+    #         raise NotImplementedError(
+    #             "HTTP interceptors are not yet supported for Google Gemini provider. "
+    #             "Interceptors are currently supported for OpenAI and Anthropic providers only."
+    #         )
+    #
+    #     super().__init__(
+    #         model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
+    #     )
+    #
+    #     # Store client params for later use
+    #     self.client_params = client_params or {}
+    #
+    #     # Get API configuration with environment variable fallbacks
+    #     self.api_key = (
+    #         api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
+    #     )
+    #     self.project = project or os.getenv("GOOGLE_CLOUD_PROJECT")
+    #     self.location = location or os.getenv("GOOGLE_CLOUD_LOCATION") or "us-central1"
+    #
+    #     use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"
+    #
+    #     self.client = self._initialize_client(use_vertexai)
+    #
+    #     # Store completion parameters
+    #     self.top_p = top_p
+    #     self.top_k = top_k
+    #     self.max_output_tokens = max_output_tokens
+    #     self.stream = stream
+    #     self.safety_settings = safety_settings or {}
+    #     self.stop_sequences = stop_sequences or []
+    #
+    #     # Model-specific settings
+    #     self.is_gemini_2 = "gemini-2" in model.lower()
+    #     self.is_gemini_1_5 = "gemini-1.5" in model.lower()
+    #     self.supports_tools = self.is_gemini_1_5 or self.is_gemini_2
+
+    @computed_field  # type: ignore[prop-decorator]
+    @property
+    def is_gemini_2(self) -> bool:
+        """Check if the model is Gemini 2.x."""
+        return "gemini-2" in self.model.lower()
+
+    @computed_field  # type: ignore[prop-decorator]
+    @property
+    def is_gemini_1_5(self) -> bool:
+        """Check if the model is Gemini 1.5.x."""
+        return "gemini-1.5" in self.model.lower()
+
+    @computed_field  # type: ignore[prop-decorator]
+    @property
+    def supports_tools(self) -> bool:
+        """Check if the model supports tool/function calling."""
+        return self.is_gemini_1_5 or self.is_gemini_2

     @property
     def stop(self) -> list[str]:
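For reference (not part of the diff): the capability flags above move from instance attributes to @computed_field properties, which are derived read-only values that also show up in serialized output. A minimal sketch of that decorator combination; the ModelInfo class below is illustrative, not part of crewAI.

    # Hedged sketch of @computed_field + @property for derived capability flags.
    from pydantic import BaseModel, computed_field

    class ModelInfo(BaseModel):
        model: str

        @computed_field  # included in model_dump() output
        @property
        def is_gemini_2(self) -> bool:
            return "gemini-2" in self.model.lower()

    print(ModelInfo(model="gemini-2.0-flash-001").model_dump())
    # {'model': 'gemini-2.0-flash-001', 'is_gemini_2': True}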
@@ -142,6 +232,12 @@ class GeminiCompletion(BaseLLM):
         if self.client_params:
             client_params.update(self.client_params)

+        if self.interceptor:
+            raise NotImplementedError(
+                "HTTP interceptors are not yet supported for Google Gemini provider. "
+                "Interceptors are currently supported for OpenAI and Anthropic providers only."
+            )
+
         if use_vertexai or self.project:
             client_params.update(
                 {
@@ -181,7 +277,7 @@ class GeminiCompletion(BaseLLM):

         if (
             hasattr(self, "client")
-            and hasattr(self.client, "vertexai")
+            and hasattr(self._client, "vertexai")
             and self.client.vertexai
         ):
             # Vertex AI configuration
@@ -206,8 +302,8 @@ class GeminiCompletion(BaseLLM):
         tools: list[dict[str, Any]] | None = None,
         callbacks: list[Any] | None = None,
         available_functions: dict[str, Any] | None = None,
-        from_task: Any | None = None,
-        from_agent: Any | None = None,
+        from_task: Task | None = None,
+        from_agent: Agent | None = None,
         response_model: type[BaseModel] | None = None,
     ) -> str | Any:
         """Call Google Gemini generate content API.
@@ -294,7 +390,16 @@ class GeminiCompletion(BaseLLM):
             GenerateContentConfig object for Gemini API
         """
         self.tools = tools
-        config_params = {}
+        config_params = self.model_dump(
+            include={
+                "temperature",
+                "top_p",
+                "top_k",
+                "max_output_tokens",
+                "stop_sequences",
+                "safety_settings",
+            }
+        )

         # Add system instruction if present
         if system_instruction:
@@ -304,18 +409,6 @@ class GeminiCompletion(BaseLLM):
             )
             config_params["system_instruction"] = system_content

-        # Add generation config parameters
-        if self.temperature is not None:
-            config_params["temperature"] = self.temperature
-        if self.top_p is not None:
-            config_params["top_p"] = self.top_p
-        if self.top_k is not None:
-            config_params["top_k"] = self.top_k
-        if self.max_output_tokens is not None:
-            config_params["max_output_tokens"] = self.max_output_tokens
-        if self.stop_sequences:
-            config_params["stop_sequences"] = self.stop_sequences
-
         if response_model:
             config_params["response_mime_type"] = "application/json"
             config_params["response_schema"] = response_model.model_json_schema()
@@ -324,9 +417,6 @@ class GeminiCompletion(BaseLLM):
|
||||
if tools and self.supports_tools:
|
||||
config_params["tools"] = self._convert_tools_for_interference(tools)
|
||||
|
||||
if self.safety_settings:
|
||||
config_params["safety_settings"] = self.safety_settings
|
||||
|
||||
return types.GenerateContentConfig(**config_params)
|
||||
|
||||
def _convert_tools_for_interference( # type: ignore[no-any-unimported]
|
||||
@@ -404,8 +494,8 @@ class GeminiCompletion(BaseLLM):
|
||||
system_instruction: str | None,
|
||||
config: types.GenerateContentConfig,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle non-streaming content generation."""
|
||||
@@ -416,7 +506,7 @@ class GeminiCompletion(BaseLLM):
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.client.models.generate_content(**api_params)
|
||||
response = self._client.models.generate_content(**api_params)
|
||||
|
||||
usage = self._extract_token_usage(response)
|
||||
except Exception as e:
|
||||
@@ -470,8 +560,8 @@ class GeminiCompletion(BaseLLM):
|
||||
contents: list[types.Content],
|
||||
config: types.GenerateContentConfig,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str:
|
||||
"""Handle streaming content generation."""
|
||||
@@ -484,7 +574,7 @@ class GeminiCompletion(BaseLLM):
|
||||
"config": config,
|
||||
}
|
||||
|
||||
for chunk in self.client.models.generate_content_stream(**api_params):
|
||||
for chunk in self._client.models.generate_content_stream(**api_params):
|
||||
if hasattr(chunk, "text") and chunk.text:
|
||||
full_response += chunk.text
|
||||
self._emit_stream_chunk_event(
|
||||
@@ -537,52 +627,30 @@ class GeminiCompletion(BaseLLM):
|
||||
|
||||
return full_response
|
||||
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@property
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Check if the model supports function calling."""
|
||||
return self.supports_tools
|
||||
|
||||
def supports_stop_words(self) -> bool:
|
||||
"""Check if the model supports stop words."""
|
||||
return True
|
||||
|
||||
def get_context_window_size(self) -> int:
|
||||
"""Get the context window size for the model."""
|
||||
from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
|
||||
|
||||
min_context = 1024
|
||||
max_context = 2097152
|
||||
|
||||
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
|
||||
if value < min_context or value > max_context:
|
||||
if value < MIN_CONTEXT_WINDOW or value > MAX_CONTEXT_WINDOW:
|
||||
raise ValueError(
|
||||
f"Context window for {key} must be between {min_context} and {max_context}"
|
||||
f"Context window for {key} must be between {MIN_CONTEXT_WINDOW} and {MAX_CONTEXT_WINDOW}"
|
||||
)
|
||||
|
||||
context_windows = {
|
||||
"gemini-2.0-flash": 1048576, # 1M tokens
|
||||
"gemini-2.0-flash-thinking": 32768,
|
||||
"gemini-2.0-flash-lite": 1048576,
|
||||
"gemini-2.5-flash": 1048576,
|
||||
"gemini-2.5-pro": 1048576,
|
||||
"gemini-1.5-pro": 2097152, # 2M tokens
|
||||
"gemini-1.5-flash": 1048576,
|
||||
"gemini-1.5-flash-8b": 1048576,
|
||||
"gemini-1.0-pro": 32768,
|
||||
"gemma-3-1b": 32000,
|
||||
"gemma-3-4b": 128000,
|
||||
"gemma-3-12b": 128000,
|
||||
"gemma-3-27b": 128000,
|
||||
}
|
||||
|
||||
# Find the best match for the model name
|
||||
for model_prefix, size in context_windows.items():
|
||||
for model_prefix, size in GEMINI_CONTEXT_WINDOWS.items():
|
||||
if self.model.startswith(model_prefix):
|
||||
return int(size * CONTEXT_WINDOW_USAGE_RATIO)
|
||||
|
||||
# Default context window size for Gemini models
|
||||
return int(1048576 * CONTEXT_WINDOW_USAGE_RATIO) # 1M tokens
|
||||
|
||||
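For illustration only (this sketch is not part of the patch): the prefix lookup in get_context_window_size() reduces to the following, where GEMINI_CONTEXT_WINDOWS is a trimmed stand-in and the 0.75 ratio is assumed for the arithmetic rather than taken from crewai.llm.

# Standalone sketch of the prefix-match lookup; names and the ratio are illustrative.
CONTEXT_WINDOW_USAGE_RATIO = 0.75  # assumed value for this example only
GEMINI_CONTEXT_WINDOWS = {"gemini-2.0-flash": 1_048_576, "gemini-1.5-pro": 2_097_152}

def usable_window(model: str) -> int:
    # Return the usable share of the first matching model prefix.
    for prefix, size in GEMINI_CONTEXT_WINDOWS.items():
        if model.startswith(prefix):
            return int(size * CONTEXT_WINDOW_USAGE_RATIO)
    return int(1_048_576 * CONTEXT_WINDOW_USAGE_RATIO)  # default 1M-token window

print(usable_window("gemini-2.0-flash-001"))  # 786432 with the assumed 0.75 ratio
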
    def _extract_token_usage(self, response: dict[str, Any]) -> dict[str, Any]:
    @staticmethod
    def _extract_token_usage(response: dict[str, Any]) -> dict[str, Any]:
        """Extract token usage from Gemini response."""
        if hasattr(response, "usage_metadata"):
            usage = response.usage_metadata
@@ -594,8 +662,8 @@ class GeminiCompletion(BaseLLM):
            }
        return {"total_tokens": 0}

    @staticmethod
    def _convert_contents_to_dict(  # type: ignore[no-any-unimported]
        self,
        contents: list[types.Content],
    ) -> list[dict[str, str]]:
        """Convert contents to dict format."""

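The Gemini changes replace booleans that __init__ used to assign with @computed_field properties derived from the model name. A minimal, self-contained sketch of that Pydantic pattern (illustrative class, not the crewAI one):

from pydantic import BaseModel, computed_field

class ModelInfo(BaseModel):
    model: str

    @computed_field  # type: ignore[prop-decorator]
    @property
    def is_gemini_2(self) -> bool:
        # Derived on access instead of stored at construction time.
        return "gemini-2" in self.model.lower()

    @computed_field  # type: ignore[prop-decorator]
    @property
    def supports_tools(self) -> bool:
        return self.is_gemini_2 or "gemini-1.5" in self.model.lower()

info = ModelInfo(model="gemini-2.5-flash")
print(info.supports_tools)  # True
print(info.model_dump())    # computed fields are included in the dump
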
@@ -4,16 +4,23 @@ from collections.abc import Iterator
import json
import logging
import os
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Final

import httpx
from openai import APIConnectionError, NotFoundError, OpenAI
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from pydantic import BaseModel
from pydantic import (
    BaseModel,
    Field,
    PrivateAttr,
    model_validator,
)
from typing_extensions import Self

from crewai.events.types.llm_events import LLMCallType
from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
from crewai.llms.base_llm import BaseLLM
from crewai.llms.hooks.transport import HTTPTransport
from crewai.utilities.agent_utils import is_context_length_exceeded
@@ -25,11 +32,28 @@ from crewai.utilities.types import LLMMessage

if TYPE_CHECKING:
    from crewai.agent.core import Agent
    from crewai.llms.hooks.base import BaseInterceptor
    from crewai.task import Task
    from crewai.tools.base_tool import BaseTool


OPENAI_CONTEXT_WINDOWS: dict[str, int] = {
    "gpt-4": 8192,
    "gpt-4o": 128000,
    "gpt-4o-mini": 200000,
    "gpt-4-turbo": 128000,
    "gpt-4.1": 1047576,
    "gpt-4.1-mini-2025-04-14": 1047576,
    "gpt-4.1-nano-2025-04-14": 1047576,
    "o1-preview": 128000,
    "o1-mini": 128000,
    "o3-mini": 200000,
    "o4-mini": 200000,
}

MIN_CONTEXT_WINDOW: Final[int] = 1024
MAX_CONTEXT_WINDOW: Final[int] = 2097152


class OpenAICompletion(BaseLLM):
    """OpenAI native completion implementation.

@@ -37,112 +61,125 @@ class OpenAICompletion(BaseLLM):
    offering native structured outputs, function calling, and streaming support.
    """

    def __init__(
        self,
        model: str = "gpt-4o",
        api_key: str | None = None,
        base_url: str | None = None,
        organization: str | None = None,
        project: str | None = None,
        timeout: float | None = None,
        max_retries: int = 2,
        default_headers: dict[str, str] | None = None,
        default_query: dict[str, Any] | None = None,
        client_params: dict[str, Any] | None = None,
        temperature: float | None = None,
        top_p: float | None = None,
        frequency_penalty: float | None = None,
        presence_penalty: float | None = None,
        max_tokens: int | None = None,
        max_completion_tokens: int | None = None,
        seed: int | None = None,
        stream: bool = False,
        response_format: dict[str, Any] | type[BaseModel] | None = None,
        logprobs: bool | None = None,
        top_logprobs: int | None = None,
        reasoning_effort: str | None = None,
        provider: str | None = None,
        interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize OpenAI chat completion client."""
    model: str = Field(
        default="gpt-4o",
        description="OpenAI model name (e.g., 'gpt-4o')",
    )
    organization: str | None = Field(
        default=None,
        description="Name of the OpenAI organization",
    )
    project: str | None = Field(
        default=None,
        description="Name of the OpenAI project",
    )
    api_base: str | None = Field(
        default=os.getenv("OPENAI_BASE_URL"),
        description="Base URL for OpenAI API",
    )
    default_headers: dict[str, str] | None = Field(
        default=None,
        description="Default headers for OpenAI API requests",
    )
    default_query: dict[str, Any] | None = Field(
        default=None,
        description="Default query parameters for OpenAI API requests",
    )
    top_p: float | None = Field(
        default=None,
        description="Top-p sampling parameter",
    )
    frequency_penalty: float | None = Field(
        default=None,
        description="Frequency penalty parameter",
    )
    presence_penalty: float | None = Field(
        default=None,
        description="Presence penalty parameter",
    )
    max_completion_tokens: int | None = Field(
        default=None,
        description="Maximum tokens for completion",
    )
    seed: int | None = Field(
        default=None,
        description="Random seed for reproducibility",
    )
    response_format: dict[str, Any] | type[BaseModel] | None = Field(
        default=None,
        description="Response format for structured output",
    )
    logprobs: bool | None = Field(
        default=None,
        description="Whether to include log probabilities",
    )
    top_logprobs: int | None = Field(
        default=None,
        description="Number of top log probabilities to return",
    )
    reasoning_effort: str | None = Field(
        default=None,
        description="Reasoning effort level for o1 models",
    )
    supports_function_calling: bool = Field(
        default=True,
        description="Whether the model supports function calling",
    )
    is_o1_model: bool = Field(
        default=False,
        description="Whether the model is an o1 model",
    )
    is_gpt4_model: bool = Field(
        default=False,
        description="Whether the model is a GPT-4 model",
    )
    _client: OpenAI = PrivateAttr(
        default_factory=OpenAI,
    )

        if provider is None:
            provider = kwargs.pop("provider", "openai")

        self.interceptor = interceptor
        # Client configuration attributes
        self.organization = organization
        self.project = project
        self.max_retries = max_retries
        self.default_headers = default_headers
        self.default_query = default_query
        self.client_params = client_params
        self.timeout = timeout
        self.base_url = base_url
        self.api_base = kwargs.pop("api_base", None)

        super().__init__(
            model=model,
            temperature=temperature,
            api_key=api_key or os.getenv("OPENAI_API_KEY"),
            base_url=base_url,
            timeout=timeout,
            provider=provider,
            **kwargs,
        )

        client_config = self._get_client_params()
        if self.interceptor:
            transport = HTTPTransport(interceptor=self.interceptor)
            http_client = httpx.Client(transport=transport)
            client_config["http_client"] = http_client

        self.client = OpenAI(**client_config)

        # Completion parameters
        self.top_p = top_p
        self.frequency_penalty = frequency_penalty
        self.presence_penalty = presence_penalty
        self.max_tokens = max_tokens
        self.max_completion_tokens = max_completion_tokens
        self.seed = seed
        self.stream = stream
        self.response_format = response_format
        self.logprobs = logprobs
        self.top_logprobs = top_logprobs
        self.reasoning_effort = reasoning_effort
        self.is_o1_model = "o1" in model.lower()
        self.is_gpt4_model = "gpt-4" in model.lower()

    def _get_client_params(self) -> dict[str, Any]:
        """Get OpenAI client parameters."""
    @model_validator(mode="after")
    def initialize_client(self) -> Self:
"""Initialize the Anthropic client after Pydantic validation.
|
||||
|
||||
        This runs after all field validation is complete, ensuring that:
        - All BaseLLM fields are set (model, temperature, stop_sequences, etc.)
        - Field validators have run (stop_sequences is normalized to set[str])
        - API key and other configuration is ready
        """
        if self.api_key is None:
            self.api_key = os.getenv("OPENAI_API_KEY")
        if self.api_key is None:
            raise ValueError("OPENAI_API_KEY is required")

        base_params = {
            "api_key": self.api_key,
            "organization": self.organization,
            "project": self.project,
            "base_url": self.base_url
            or self.api_base
            or os.getenv("OPENAI_BASE_URL")
            or None,
            "timeout": self.timeout,
            "max_retries": self.max_retries,
            "default_headers": self.default_headers,
            "default_query": self.default_query,
        }
        self.is_o1_model = "o1" in self.model.lower()
        self.supports_function_calling = not self.is_o1_model
        self.is_gpt4_model = "gpt-4" in self.model.lower()
        self.supports_stop_words = not self.is_o1_model

        client_params = {k: v for k, v in base_params.items() if v is not None}
        params = self.model_dump(
            include={
                "api_key",
                "organization",
                "project",
                "base_url",
                "timeout",
                "max_retries",
                "default_headers",
                "default_query",
            },
            exclude_none=True,
        )

        if self.interceptor:
            transport = HTTPTransport(interceptor=self.interceptor)
            http_client = httpx.Client(transport=transport)
            params["http_client"] = http_client

        if self.client_params:
            client_params.update(self.client_params)
            params.update(self.client_params)

        return client_params
        self._client = OpenAI(**params)
        return self

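For context on initialize_client above: a Pydantic model_validator(mode="after") runs once every field has been validated, so the client can be built from already-normalized settings. A minimal sketch of the pattern, with a hypothetical FakeClient standing in for the real SDK client (not the crewAI code):

import os

from pydantic import BaseModel, PrivateAttr, model_validator
from typing_extensions import Self


class FakeClient:
    """Hypothetical stand-in for the real SDK client."""

    def __init__(self, api_key: str) -> None:
        self.api_key = api_key


class MyLLM(BaseModel):
    model: str = "gpt-4o"
    api_key: str | None = None
    _client: FakeClient | None = PrivateAttr(default=None)

    @model_validator(mode="after")
    def initialize_client(self) -> Self:
        # Runs after field validation; every field is safe to read here.
        key = self.api_key or os.getenv("OPENAI_API_KEY")
        if key is None:
            raise ValueError("an API key is required")
        self._client = FakeClient(api_key=key)
        return self
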
    def call(
        self,
@@ -213,38 +250,26 @@ class OpenAICompletion(BaseLLM):
        self, messages: list[LLMMessage], tools: list[dict[str, BaseTool]] | None = None
    ) -> dict[str, Any]:
        """Prepare parameters for OpenAI chat completion."""
        params: dict[str, Any] = {
            "model": self.model,
            "messages": messages,
        }
        if self.stream:
            params["stream"] = self.stream

        params = self.model_dump(
            include={
                "model",
                "stream",
                "temperature",
                "top_p",
                "frequency_penalty",
                "presence_penalty",
                "max_completion_tokens",
                "max_tokens",
                "seed",
                "logprobs",
                "top_logprobs",
                "reasoning_effort",
            },
            exclude_none=True,
        )
        params["messages"] = messages
        params.update(self.additional_params)

        if self.temperature is not None:
            params["temperature"] = self.temperature
        if self.top_p is not None:
            params["top_p"] = self.top_p
        if self.frequency_penalty is not None:
            params["frequency_penalty"] = self.frequency_penalty
        if self.presence_penalty is not None:
            params["presence_penalty"] = self.presence_penalty
        if self.max_completion_tokens is not None:
            params["max_completion_tokens"] = self.max_completion_tokens
        elif self.max_tokens is not None:
            params["max_tokens"] = self.max_tokens
        if self.seed is not None:
            params["seed"] = self.seed
        if self.logprobs is not None:
            params["logprobs"] = self.logprobs
        if self.top_logprobs is not None:
            params["top_logprobs"] = self.top_logprobs

        # Handle o1 model specific parameters
        if self.is_o1_model and self.reasoning_effort:
            params["reasoning_effort"] = self.reasoning_effort

        if tools:
            params["tools"] = self._convert_tools_for_interference(tools)
            params["tool_choice"] = "auto"
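
The rewritten _prepare_completion_params swaps the chain of `if ... is not None` checks for a single model_dump(include=..., exclude_none=True) call. A small sketch of that idea on a throwaway settings model (names here are illustrative, not crewAI's):

from pydantic import BaseModel


class CompletionSettings(BaseModel):
    model: str = "gpt-4o"
    temperature: float | None = None
    top_p: float | None = None
    seed: int | None = None
    stream: bool = False


settings = CompletionSettings(model="gpt-4o", temperature=0.2)

# Only the selected, non-None fields survive, so unset options never reach the API call.
params = settings.model_dump(
    include={"model", "temperature", "top_p", "seed", "stream"},
    exclude_none=True,
)
print(params)  # {'model': 'gpt-4o', 'temperature': 0.2, 'stream': False}
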
@@ -296,14 +321,14 @@ class OpenAICompletion(BaseLLM):
        self,
        params: dict[str, Any],
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        from_task: Task | None = None,
        from_agent: Agent | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """Handle non-streaming chat completion."""
        try:
            if response_model:
                parsed_response = self.client.beta.chat.completions.parse(
                parsed_response = self._client.beta.chat.completions.parse(
                    **params,
                    response_format=response_model,
                )
@@ -327,7 +352,7 @@ class OpenAICompletion(BaseLLM):
                )
                return structured_json

            response: ChatCompletion = self.client.chat.completions.create(**params)
            response: ChatCompletion = self._client.chat.completions.create(**params)

            usage = self._extract_openai_token_usage(response)

@@ -419,8 +444,8 @@ class OpenAICompletion(BaseLLM):
        self,
        params: dict[str, Any],
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        from_task: Task | None = None,
        from_agent: Agent | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str:
        """Handle streaming chat completion."""
@@ -429,7 +454,7 @@ class OpenAICompletion(BaseLLM):

        if response_model:
            completion_stream: Iterator[ChatCompletionChunk] = (
                self.client.chat.completions.create(**params)
                self._client.chat.completions.create(**params)
            )

            accumulated_content = ""
@@ -472,7 +497,7 @@ class OpenAICompletion(BaseLLM):
            )
            return accumulated_content

        stream: Iterator[ChatCompletionChunk] = self.client.chat.completions.create(
        stream: Iterator[ChatCompletionChunk] = self._client.chat.completions.create(
            **params
        )

@@ -550,58 +575,31 @@ class OpenAICompletion(BaseLLM):

        return full_response

    def supports_function_calling(self) -> bool:
        """Check if the model supports function calling."""
        return not self.is_o1_model

    def supports_stop_words(self) -> bool:
        """Check if the model supports stop words."""
        return not self.is_o1_model

    def get_context_window_size(self) -> int:
        """Get the context window size for the model."""
        from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES

        min_context = 1024
        max_context = 2097152

        for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
            if value < min_context or value > max_context:
            if value < MIN_CONTEXT_WINDOW or value > MAX_CONTEXT_WINDOW:
                raise ValueError(
                    f"Context window for {key} must be between {min_context} and {max_context}"
                    f"Context window for {key} must be between {MIN_CONTEXT_WINDOW} and {MAX_CONTEXT_WINDOW}"
                )

        # Context window sizes for OpenAI models
        context_windows = {
            "gpt-4": 8192,
            "gpt-4o": 128000,
            "gpt-4o-mini": 200000,
            "gpt-4-turbo": 128000,
            "gpt-4.1": 1047576,
            "gpt-4.1-mini-2025-04-14": 1047576,
            "gpt-4.1-nano-2025-04-14": 1047576,
            "o1-preview": 128000,
            "o1-mini": 128000,
            "o3-mini": 200000,
            "o4-mini": 200000,
        }

        # Find the best match for the model name
        for model_prefix, size in context_windows.items():
        for model_prefix, size in OPENAI_CONTEXT_WINDOWS.items():
            if self.model.startswith(model_prefix):
                return int(size * CONTEXT_WINDOW_USAGE_RATIO)

        # Default context window size
        return int(8192 * CONTEXT_WINDOW_USAGE_RATIO)

    def _extract_openai_token_usage(self, response: ChatCompletion) -> dict[str, Any]:
    @staticmethod
    def _extract_openai_token_usage(response: ChatCompletion) -> dict[str, Any]:
        """Extract token usage from OpenAI ChatCompletion response."""
        if hasattr(response, "usage") and response.usage:
        if response.usage:
            usage = response.usage
            return {
                "prompt_tokens": getattr(usage, "prompt_tokens", 0),
                "completion_tokens": getattr(usage, "completion_tokens", 0),
                "total_tokens": getattr(usage, "total_tokens", 0),
                "prompt_tokens": usage.prompt_tokens,
                "completion_tokens": usage.completion_tokens,
                "total_tokens": usage.total_tokens,
            }
        return {"total_tokens": 0}