Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-26 16:48:13 +00:00)

Compare commits: devin/1769...gl/feat/py (6 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 81850350e8 | |
| | 7404d8f198 | |
| | 138b9af274 | |
| | a5e0803f20 | |
| | c4279b0339 | |
| | 965aa48ea1 | |
@@ -618,22 +618,22 @@ class Agent(BaseAgent):
 response_template=self.response_template,
 ).task_execution()

-stop_words = [self.i18n.slice("observation")]
+stop_sequences = [self.i18n.slice("observation")]

 if self.response_template:
-stop_words.append(
+stop_sequences.append(
 self.response_template.split("{{ .Response }}")[1].strip()
 )

 self.agent_executor = CrewAgentExecutor(
-llm=self.llm,
+llm=self.llm, # type: ignore[arg-type]
 task=task, # type: ignore[arg-type]
 agent=self,
 crew=self.crew,
 tools=parsed_tools,
 prompt=prompt,
 original_tools=raw_tools,
-stop_words=stop_words,
+stop_sequences=stop_sequences,
 max_iter=self.max_iter,
 tools_handler=self.tools_handler,
 tools_names=get_tool_names(parsed_tools),
@@ -974,7 +974,9 @@ class Agent(BaseAgent):
 path = parsed.path.replace("/", "_").strip("_")
 return f"{domain}_{path}" if path else domain

-def _get_mcp_tool_schemas(self, server_params: dict) -> dict[str, dict]:
+def _get_mcp_tool_schemas(
+self, server_params: dict[str, Any]
+) -> dict[str, dict[str, Any]] | Any:
 """Get tool schemas from MCP server for wrapper creation with caching."""
 server_url = server_params["url"]

@@ -1006,7 +1008,7 @@ class Agent(BaseAgent):

 async def _get_mcp_tool_schemas_async(
 self, server_params: dict[str, Any]
-) -> dict[str, dict]:
+) -> dict[str, dict[str, Any]]:
 """Async implementation of MCP tool schema retrieval with timeouts and retries."""
 server_url = server_params["url"]
 return await self._retry_mcp_discovery(
@@ -1014,7 +1016,7 @@ class Agent(BaseAgent):
 )

 async def _retry_mcp_discovery(
-self, operation_func, server_url: str
+self, operation_func: Any, server_url: str
 ) -> dict[str, dict[str, Any]]:
 """Retry MCP discovery operation with exponential backoff, avoiding try-except in loop."""
 last_error = None
@@ -1045,7 +1047,7 @@ class Agent(BaseAgent):

 @staticmethod
 async def _attempt_mcp_discovery(
-operation_func, server_url: str
+operation_func: Any, server_url: str
 ) -> tuple[dict[str, dict[str, Any]] | None, str, bool]:
 """Attempt single MCP discovery operation and return (result, error_message, should_retry)."""
 try:
@@ -1149,13 +1151,13 @@ class Agent(BaseAgent):
 Field(..., description=field_description),
 )
 else:
-field_definitions[field_name] = (
+field_definitions[field_name] = ( # type: ignore[assignment]
 field_type | None,
 Field(default=None, description=field_description),
 )

 model_name = f"{tool_name.replace('-', '_').replace(' ', '_')}Schema"
-return create_model(model_name, **field_definitions)
+return create_model(model_name, **field_definitions) # type: ignore[no-any-return,call-overload]

 def _json_type_to_python(self, field_schema: dict[str, Any]) -> type:
 """Convert JSON Schema type to Python type.
@@ -1175,12 +1177,12 @@ class Agent(BaseAgent):
 if "const" in option:
 types.append(str)
 else:
-types.append(self._json_type_to_python(option))
+types.append(self._json_type_to_python(option)) # type: ignore[arg-type]
 unique_types = list(set(types))
 if len(unique_types) > 1:
 result = unique_types[0]
 for t in unique_types[1:]:
-result = result | t
+result = result | t # type: ignore[assignment]
 return result
 return unique_types[0]

@@ -1193,10 +1195,10 @@ class Agent(BaseAgent):
 "object": dict,
 }

-return type_mapping.get(json_type, Any)
+return type_mapping.get(json_type, Any) # type: ignore[arg-type]

 @staticmethod
-def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict]:
+def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict[str, Any]]:
 """Fetch MCP server configurations from CrewAI AMP API."""
 # TODO: Implement AMP API call to "integrations/mcps" endpoint
 # Should return list of server configs with URLs
@@ -1435,7 +1437,7 @@ class Agent(BaseAgent):
 goal=self.goal,
 backstory=self.backstory,
 llm=self.llm,
-tools=self.tools or [],
+tools=self.tools,
 max_iterations=self.max_iter,
 max_execution_time=self.max_execution_time,
 respect_context_window=self.respect_context_window,
@@ -137,7 +137,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
 default=False,
 description="Enable agent to delegate and ask questions among each other.",
 )
-tools: list[BaseTool] | None = Field(
+tools: list[BaseTool] = Field(
 default_factory=list, description="Tools at agents' disposal"
 )
 max_iter: int = Field(
@@ -73,7 +73,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
 max_iter: int,
 tools: list[CrewStructuredTool],
 tools_names: str,
-stop_words: list[str],
+stop_sequences: list[str],
 tools_description: str,
 tools_handler: ToolsHandler,
 step_callback: Any = None,
@@ -95,7 +95,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
 max_iter: Maximum iterations.
 tools: Available tools.
 tools_names: Tool names string.
-stop_words: Stop word list.
+stop_sequences: Stop sequences list for halting generation.
 tools_description: Tool descriptions.
 tools_handler: Tool handler instance.
 step_callback: Optional step callback.
@@ -114,7 +114,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
 self.prompt = prompt
 self.tools = tools
 self.tools_names = tools_names
-self.stop = stop_words
 self.max_iter = max_iter
 self.callbacks = callbacks or []
 self._printer: Printer = Printer()
@@ -131,15 +130,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
 self.iterations = 0
 self.log_error_after = 3
 if self.llm:
-# This may be mutating the shared llm object and needs further evaluation
-existing_stop = getattr(self.llm, "stop", [])
-self.llm.stop = list(
-set(
-existing_stop + self.stop
-if isinstance(existing_stop, list)
-else self.stop
-)
-)
+self.llm.stop_sequences.extend(stop_sequences)

 @property
 def use_stop_words(self) -> bool:
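The executor change above replaces the old read-merge-assign of the LLM's `stop` list with a direct `extend()` of the LLM's `stop_sequences` field. A minimal sketch of the new behaviour, using a stand-in class rather than the real crewAI `LLM` type:

```python
# Illustrative sketch only; FakeLLM stands in for the real crewAI LLM class.
from dataclasses import dataclass, field


@dataclass
class FakeLLM:
    stop_sequences: list[str] = field(default_factory=list)


def configure_executor_stops(llm: FakeLLM, stop_sequences: list[str]) -> None:
    # New behaviour from the hunk above: extend the shared LLM's stop_sequences
    # in place. The removed code instead read `llm.stop`, merged it with the
    # executor's list through set(), and reassigned it.
    llm.stop_sequences.extend(stop_sequences)


llm = FakeLLM(stop_sequences=["\nObservation:"])
configure_executor_stops(llm, ["\nObservation:"])
print(llm.stop_sequences)  # ['\nObservation:', '\nObservation:']
```

Unlike the removed `set()` merge, `extend()` keeps duplicates, which is worth noting if an executor is built more than once against the same LLM instance.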
@@ -148,7 +139,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
 Returns:
 bool: True if tool should be used or not.
 """
-return self.llm.supports_stop_words() if self.llm else False
+return self.llm.supports_stop_words if self.llm else False

 def invoke(self, inputs: dict[str, Any]) -> dict[str, Any]:
 """Execute the agent with given inputs.
@@ -20,8 +20,7 @@ from typing import (
 )

 from dotenv import load_dotenv
-import httpx
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, model_validator
 from typing_extensions import Self

 from crewai.events.event_bus import crewai_event_bus
@@ -54,7 +53,6 @@ if TYPE_CHECKING:
 from litellm.utils import supports_response_schema

 from crewai.agent.core import Agent
-from crewai.llms.hooks.base import BaseInterceptor
 from crewai.task import Task
 from crewai.tools.base_tool import BaseTool
 from crewai.utilities.types import LLMMessage
@@ -320,7 +318,138 @@ class AccumulatedToolArgs(BaseModel):


 class LLM(BaseLLM):
-completion_cost: float | None = None
+completion_cost: float | None = Field(
+default=None, description="The completion cost of the LLM."
+)
+top_p: float | None = Field(
+default=None, description="Sampling probability threshold."
+)
+n: int | None = Field(
+default=None, description="Number of completions to generate."
+)
+max_completion_tokens: int | None = Field(
+default=None,
+description="Maximum number of tokens to generate in the completion.",
+)
+max_tokens: int | None = Field(
+default=None,
+description="Maximum number of tokens allowed in the prompt + completion.",
+)
+presence_penalty: float | None = Field(
+default=None, description="Penalty on the presence penalty."
+)
+frequency_penalty: float | None = Field(
+default=None, description="Penalty on the frequency penalty."
+)
+logit_bias: dict[int, float] | None = Field(
+default=None,
+description="Modifies the likelihood of specified tokens appearing in the completion.",
+)
+response_format: type[BaseModel] | None = Field(
+default=None,
+description="Pydantic model class for structured response parsing.",
+)
+seed: int | None = Field(
+default=None,
+description="Random seed for reproducibility.",
+)
+logprobs: int | None = Field(
+default=None,
+description="Number of top logprobs to return.",
+)
+top_logprobs: int | None = Field(
+default=None,
+description="Number of top logprobs to return.",
+)
+api_base: str | None = Field(
+default=None,
+description="Base URL for the API endpoint.",
+)
+api_version: str | None = Field(
+default=None,
+description="API version to use.",
+)
+callbacks: list[Any] = Field(
+default_factory=list,
+description="List of callback handlers for LLM events.",
+)
+reasoning_effort: Literal["none", "low", "medium", "high"] | None = Field(
+default=None,
+description="Level of reasoning effort for the LLM.",
+)
+context_window_size: int = Field(
+default=0,
+description="The context window size of the LLM.",
+)
+is_anthropic: bool = Field(
+default=False,
+description="Indicates if the model is from Anthropic provider.",
+)
+supports_function_calling: bool = Field(
+default=False,
+description="Indicates if the model supports function calling.",
+)
+supports_stop_words: bool = Field(
+default=False,
+description="Indicates if the model supports stop words.",
+)
+
+@model_validator(mode="after")
+def initialize_client(self) -> Self:
+self.is_anthropic = any(
+prefix in self.model.lower() for prefix in ANTHROPIC_PREFIXES
+)
+try:
+provider = self._get_custom_llm_provider()
+self.supports_function_calling = litellm.utils.supports_function_calling(
+self.model, custom_llm_provider=provider
+)
+except Exception as e:
+logging.error(f"Failed to check function calling support: {e!s}")
+self.supports_function_calling = False
+try:
+params = get_supported_openai_params(model=self.model)
+self.supports_stop_words = params is not None and "stop" in params
+except Exception as e:
+logging.error(f"Failed to get supported params: {e!s}")
+self.supports_stop_words = False
+
+with suppress_warnings():
+callback_types = [type(callback) for callback in self.callbacks]
+for callback in litellm.success_callback[:]:
+if type(callback) in callback_types:
+litellm.success_callback.remove(callback)
+
+for callback in litellm._async_success_callback[:]:
+if type(callback) in callback_types:
+litellm._async_success_callback.remove(callback)
+
+litellm.callbacks = self.callbacks
+
+with suppress_warnings():
+success_callbacks_str = os.environ.get("LITELLM_SUCCESS_CALLBACKS", "")
+success_callbacks: list[str | Callable[..., Any] | CustomLogger] = []
+if success_callbacks_str:
+success_callbacks = [
+cb.strip() for cb in success_callbacks_str.split(",") if cb.strip()
+]
+
+failure_callbacks_str = os.environ.get("LITELLM_FAILURE_CALLBACKS", "")
+if failure_callbacks_str:
+failure_callbacks: list[str | Callable[..., Any] | CustomLogger] = [
+cb.strip() for cb in failure_callbacks_str.split(",") if cb.strip()
+]
+
+litellm.success_callback = success_callbacks
+litellm.failure_callback = failure_callbacks
+return self
+
+# @computed_field
+# @property
+# def is_anthropic(self) -> bool:
+# """Determine if the model is from Anthropic provider."""
+# anthropic_prefixes = ("anthropic/", "claude-", "claude/")
+# return any(prefix in self.model.lower() for prefix in anthropic_prefixes)
+
 def __new__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> LLM:
 """Factory method that routes to native SDK or falls back to LiteLLM."""
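The hunk above turns `LLM`'s ad-hoc attributes into declared Pydantic fields and moves capability detection into an `initialize_client` model validator, so `supports_function_calling` and `supports_stop_words` become boolean fields rather than methods. A rough sketch of that pattern in isolation; the detection logic below is simplified stand-in code, not the litellm-backed checks from the diff:

```python
# Sketch of the declared-fields + model_validator pattern used by the new LLM class.
# The capability checks here are stand-ins for litellm.utils.supports_function_calling
# and get_supported_openai_params.
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self

ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/")


class MiniLLM(BaseModel):
    model: str
    temperature: float | None = Field(default=None, description="Sampling temperature.")
    is_anthropic: bool = Field(default=False)
    supports_stop_words: bool = Field(default=False)

    @model_validator(mode="after")
    def initialize_client(self) -> Self:
        # Runs once all fields are validated, mirroring LLM.initialize_client above.
        self.is_anthropic = any(p in self.model.lower() for p in ANTHROPIC_PREFIXES)
        self.supports_stop_words = True  # stand-in for the "stop" capability probe
        return self


llm = MiniLLM(model="anthropic/claude-3-5-sonnet-20241022")
print(llm.is_anthropic, llm.supports_stop_words)  # True True
```

Because these are now attributes, call sites that previously invoked `supports_stop_words()` as a method drop the parentheses, which is exactly the change made to `use_stop_words` in the executor hunk earlier.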
@@ -383,98 +512,6 @@ class LLM(BaseLLM):

 return None

-def __init__(
-self,
-model: str,
-timeout: float | int | None = None,
-temperature: float | None = None,
-top_p: float | None = None,
-n: int | None = None,
-stop: str | list[str] | None = None,
-max_completion_tokens: int | None = None,
-max_tokens: int | float | None = None,
-presence_penalty: float | None = None,
-frequency_penalty: float | None = None,
-logit_bias: dict[int, float] | None = None,
-response_format: type[BaseModel] | None = None,
-seed: int | None = None,
-logprobs: int | None = None,
-top_logprobs: int | None = None,
-base_url: str | None = None,
-api_base: str | None = None,
-api_version: str | None = None,
-api_key: str | None = None,
-callbacks: list[Any] | None = None,
-reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
-stream: bool = False,
-interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
-**kwargs: Any,
-) -> None:
-"""Initialize LLM instance.
-
-Note: This __init__ method is only called for fallback instances.
-Native provider instances handle their own initialization in their respective classes.
-"""
-super().__init__(
-model=model,
-temperature=temperature,
-api_key=api_key,
-base_url=base_url,
-timeout=timeout,
-**kwargs,
-)
-self.model = model
-self.timeout = timeout
-self.temperature = temperature
-self.top_p = top_p
-self.n = n
-self.max_completion_tokens = max_completion_tokens
-self.max_tokens = max_tokens
-self.presence_penalty = presence_penalty
-self.frequency_penalty = frequency_penalty
-self.logit_bias = logit_bias
-self.response_format = response_format
-self.seed = seed
-self.logprobs = logprobs
-self.top_logprobs = top_logprobs
-self.base_url = base_url
-self.api_base = api_base
-self.api_version = api_version
-self.api_key = api_key
-self.callbacks = callbacks
-self.context_window_size = 0
-self.reasoning_effort = reasoning_effort
-self.additional_params = kwargs
-self.is_anthropic = self._is_anthropic_model(model)
-self.stream = stream
-self.interceptor = interceptor
-
-litellm.drop_params = True
-
-# Normalize self.stop to always be a list[str]
-if stop is None:
-self.stop: list[str] = []
-elif isinstance(stop, str):
-self.stop = [stop]
-else:
-self.stop = stop
-
-self.set_callbacks(callbacks or [])
-self.set_env_callbacks()
-
-@staticmethod
-def _is_anthropic_model(model: str) -> bool:
-"""Determine if the model is from Anthropic provider.
-
-Args:
-model: The model identifier string.
-
-Returns:
-bool: True if the model is from Anthropic, False otherwise.
-"""
-anthropic_prefixes = ("anthropic/", "claude-", "claude/")
-return any(prefix in model.lower() for prefix in anthropic_prefixes)
-
 def _prepare_completion_params(
 self,
 messages: str | list[LLMMessage],
@@ -1188,8 +1225,6 @@ class LLM(BaseLLM):
 message["role"] = msg_role
 # --- 5) Set up callbacks if provided
 with suppress_warnings():
-if callbacks and len(callbacks) > 0:
-self.set_callbacks(callbacks)
 try:
 # --- 6) Prepare parameters for the completion call
 params = self._prepare_completion_params(messages, tools)
@@ -1378,24 +1413,6 @@ class LLM(BaseLLM):
 "Please remove response_format or use a supported model."
 )

-def supports_function_calling(self) -> bool:
-try:
-provider = self._get_custom_llm_provider()
-return litellm.utils.supports_function_calling(
-self.model, custom_llm_provider=provider
-)
-except Exception as e:
-logging.error(f"Failed to check function calling support: {e!s}")
-return False
-
-def supports_stop_words(self) -> bool:
-try:
-params = get_supported_openai_params(model=self.model)
-return params is not None and "stop" in params
-except Exception as e:
-logging.error(f"Failed to get supported params: {e!s}")
-return False
-
 def get_context_window_size(self) -> int:
 """
 Returns the context window size, using 75% of the maximum to avoid
@@ -1425,60 +1442,6 @@ class LLM(BaseLLM):
 self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
 return self.context_window_size

-@staticmethod
-def set_callbacks(callbacks: list[Any]) -> None:
-"""
-Attempt to keep a single set of callbacks in litellm by removing old
-duplicates and adding new ones.
-"""
-with suppress_warnings():
-callback_types = [type(callback) for callback in callbacks]
-for callback in litellm.success_callback[:]:
-if type(callback) in callback_types:
-litellm.success_callback.remove(callback)
-
-for callback in litellm._async_success_callback[:]:
-if type(callback) in callback_types:
-litellm._async_success_callback.remove(callback)
-
-litellm.callbacks = callbacks
-
-@staticmethod
-def set_env_callbacks() -> None:
-"""Sets the success and failure callbacks for the LiteLLM library from environment variables.
-
-This method reads the `LITELLM_SUCCESS_CALLBACKS` and `LITELLM_FAILURE_CALLBACKS`
-environment variables, which should contain comma-separated lists of callback names.
-It then assigns these lists to `litellm.success_callback` and `litellm.failure_callback`,
-respectively.
-
-If the environment variables are not set or are empty, the corresponding callback lists
-will be set to empty lists.
-
-Examples:
-LITELLM_SUCCESS_CALLBACKS="langfuse,langsmith"
-LITELLM_FAILURE_CALLBACKS="langfuse"
-
-This will set `litellm.success_callback` to ["langfuse", "langsmith"] and
-`litellm.failure_callback` to ["langfuse"].
-"""
-with suppress_warnings():
-success_callbacks_str = os.environ.get("LITELLM_SUCCESS_CALLBACKS", "")
-success_callbacks: list[str | Callable[..., Any] | CustomLogger] = []
-if success_callbacks_str:
-success_callbacks = [
-cb.strip() for cb in success_callbacks_str.split(",") if cb.strip()
-]
-
-failure_callbacks_str = os.environ.get("LITELLM_FAILURE_CALLBACKS", "")
-if failure_callbacks_str:
-failure_callbacks: list[str | Callable[..., Any] | CustomLogger] = [
-cb.strip() for cb in failure_callbacks_str.split(",") if cb.strip()
-]
-
-litellm.success_callback = success_callbacks
-litellm.failure_callback = failure_callbacks
-
 def __copy__(self) -> LLM:
 """Create a shallow copy of the LLM instance."""
 # Filter out parameters that are already explicitly passed to avoid conflicts
@@ -1539,7 +1502,7 @@ class LLM(BaseLLM):
 **filtered_params,
 )

-def __deepcopy__(self, memo: dict[int, Any] | None) -> LLM:
+def __deepcopy__(self, memo: dict[int, Any] | None) -> LLM: # type: ignore[override]
 """Create a deep copy of the LLM instance."""
 import copy
@@ -13,8 +13,9 @@ import logging
 import re
 from typing import TYPE_CHECKING, Any, Final

-from pydantic import BaseModel
+from pydantic import AliasChoices, BaseModel, Field, PrivateAttr, field_validator

+from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
 from crewai.events.event_bus import crewai_event_bus
 from crewai.events.types.llm_events import (
 LLMCallCompletedEvent,
@@ -28,6 +29,7 @@ from crewai.events.types.tool_usage_events import (
 ToolUsageFinishedEvent,
 ToolUsageStartedEvent,
 )
+from crewai.llms.hooks import BaseInterceptor
 from crewai.types.usage_metrics import UsageMetrics


@@ -43,7 +45,7 @@ DEFAULT_SUPPORTS_STOP_WORDS: Final[bool] = True
 _JSON_EXTRACTION_PATTERN: Final[re.Pattern[str]] = re.compile(r"\{.*}", re.DOTALL)


-class BaseLLM(ABC):
+class BaseLLM(BaseModel, ABC):
 """Abstract base class for LLM implementations.

 This class defines the interface that all LLM implementations must follow.
@@ -55,70 +57,105 @@ class BaseLLM(ABC):
 implement proper validation for input parameters and provide clear error
 messages when things go wrong.


 Attributes:
 model: The model identifier/name.
 temperature: Optional temperature setting for response generation.
-stop: A list of stop sequences that the LLM should use to stop generation.
-additional_params: Additional provider-specific parameters.
 """

-is_litellm: bool = False
-
-def __init__(
-self,
-model: str,
-temperature: float | None = None,
-api_key: str | None = None,
-base_url: str | None = None,
-provider: str | None = None,
-**kwargs: Any,
-) -> None:
-"""Initialize the BaseLLM with default attributes.
-
-Args:
-model: The model identifier/name.
-temperature: Optional temperature setting for response generation.
-stop: Optional list of stop sequences for generation.
-**kwargs: Additional provider-specific parameters.
-"""
-if not model:
-raise ValueError("Model name is required and cannot be empty")
-
-self.model = model
-self.temperature = temperature
-self.api_key = api_key
-self.base_url = base_url
-# Store additional parameters for provider-specific use
-self.additional_params = kwargs
-self._provider = provider or "openai"
-
-stop = kwargs.pop("stop", None)
-if stop is None:
-self.stop: list[str] = []
-elif isinstance(stop, str):
-self.stop = [stop]
-elif isinstance(stop, list):
-self.stop = stop
-else:
-self.stop = []
-
-self._token_usage = {
-"total_tokens": 0,
-"prompt_tokens": 0,
-"completion_tokens": 0,
-"successful_requests": 0,
-"cached_prompt_tokens": 0,
-}
-
-@property
-def provider(self) -> str:
-"""Get the provider of the LLM."""
-return self._provider
-
-@provider.setter
-def provider(self, value: str) -> None:
-"""Set the provider of the LLM."""
-self._provider = value
+provider: str | re.Pattern[str] = Field(
+default="openai", description="The provider of the LLM."
+)
+model: str = Field(description="The model identifier/name.")
+temperature: float | None = Field(
+default=None, ge=0, le=2, description="Temperature for response generation."
+)
+api_key: str | None = Field(default=None, description="API key for authentication.")
+base_url: str | None = Field(default=None, description="Base URL for API calls.")
+timeout: float | None = Field(default=None, description="Timeout for API calls.")
+max_retries: int = Field(
+default=2, description="Maximum number of API requests to make."
+)
+max_tokens: int | None = Field(
+default=None, description="Maximum tokens for response generation."
+)
+stream: bool | None = Field(default=False, description="Stream the API requests.")
+client: Any = Field(description="Underlying LLM client instance.")
+interceptor: BaseInterceptor[Any, Any] | None = Field(
+default=None,
+description="An optional HTTPX interceptor for modifying requests/responses.",
+)
+client_params: dict[str, Any] = Field(
+default_factory=dict,
+description="Additional parameters for the underlying LLM client.",
+)
+supports_stop_words: bool = Field(
+default=DEFAULT_SUPPORTS_STOP_WORDS,
+description="Whether or not to support stop words.",
+)
+stop_sequences: list[str] = Field(
+default_factory=list,
+validation_alias=AliasChoices("stop_sequences", "stop"),
+description="Stop sequences for generation (synchronized with stop).",
+)
+is_litellm: bool = Field(
+default=False, description="Is this LLM implementation in litellm?"
+)
+additional_params: dict[str, Any] = Field(
+default_factory=dict,
+description="Additional parameters for LLM calls.",
+)
+_token_usage: TokenProcess = PrivateAttr(default_factory=TokenProcess)
+
+@field_validator("provider", mode="before")
+@classmethod
+def extract_provider_from_model(
+cls, v: str | re.Pattern[str] | None, info: Any
+) -> str | re.Pattern[str]:
+"""Extract provider from model string if not explicitly provided.
+
+Args:
+v: Provided provider value (can be str, Pattern, or None)
+info: Validation info containing other field values
+
+Returns:
+Provider name (str) or Pattern
+"""
+# If provider explicitly provided, validate and return it
+if v is not None:
+if not isinstance(v, (str, re.Pattern)):
+raise ValueError(f"Provider must be str or Pattern, got {type(v)}")
+return v
+
+model: str = info.data.get("model", "")
+if "/" in model:
+return model.partition("/")[0]
+return "openai"
+
+@field_validator("stop_sequences", mode="before")
+@classmethod
+def normalize_stop_sequences(
+cls, v: str | list[str] | set[str] | None
+) -> list[str]:
+"""Validate and normalize stop sequences.
+
+Converts string to list and handles None values.
+AliasChoices handles accepting both 'stop' and 'stop_sequences' parameter names.
+"""
+if v is None:
+return []
+if isinstance(v, str):
+return [v]
+if isinstance(v, set):
+return list(v)
+if isinstance(v, list):
+return v
+return []
+
+@property
+def stop(self) -> list[str]:
+"""Alias for stop_sequences to maintain backward compatibility."""
+return self.stop_sequences

 @abstractmethod
 def call(
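With `BaseLLM` now a `BaseModel`, construction goes through Pydantic validation: `provider` can be inferred from a `provider/model` prefix, `stop` is accepted as an alias for `stop_sequences`, and a single string is normalized to a list. A usage sketch under the assumption that the fields behave as declared above; `MyLLM` is a hypothetical concrete subclass, since `BaseLLM` itself is abstract:

```python
# Hypothetical concrete subclass used only to exercise the new BaseLLM validators.
from typing import Any

from crewai.llms.base_llm import BaseLLM


class MyLLM(BaseLLM):
    def call(self, *args: Any, **kwargs: Any) -> str:  # signature simplified
        return "ok"


llm = MyLLM(
    model="gemini/gemini-2.0-flash-001",
    provider=None,  # forces the before-validator to infer "gemini" from the model prefix
    stop="\nObservation:",  # accepted via the "stop" alias
    client=None,
)
print(llm.provider)        # "gemini"
print(llm.stop_sequences)  # ["\nObservation:"], normalized from a single string
print(llm.stop)            # same list, via the backward-compatible property
```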
@@ -171,14 +208,6 @@ class BaseLLM(ABC):
 """
 return tools

-def supports_stop_words(self) -> bool:
-"""Check if the LLM supports stop words.
-
-Returns:
-True if the LLM supports stop words, False otherwise.
-"""
-return DEFAULT_SUPPORTS_STOP_WORDS
-
 def _supports_stop_words_implementation(self) -> bool:
 """Check if stop words are configured for this LLM instance.

@@ -506,7 +535,7 @@ class BaseLLM(ABC):
 """
 if "/" in model:
 return model.partition("/")[0]
-return "openai" # Default provider
+return "openai"

 def _track_token_usage_internal(self, usage_data: dict[str, Any]) -> None:
 """Track token usage internally in the LLM instance.
@@ -535,11 +564,11 @@ class BaseLLM(ABC):
 or 0
 )

-self._token_usage["prompt_tokens"] += prompt_tokens
-self._token_usage["completion_tokens"] += completion_tokens
-self._token_usage["total_tokens"] += prompt_tokens + completion_tokens
-self._token_usage["successful_requests"] += 1
-self._token_usage["cached_prompt_tokens"] += cached_tokens
+self._token_usage.prompt_tokens += prompt_tokens
+self._token_usage.completion_tokens += completion_tokens
+self._token_usage.total_tokens += prompt_tokens + completion_tokens
+self._token_usage.successful_requests += 1
+self._token_usage.cached_prompt_tokens += cached_tokens

 def get_token_usage_summary(self) -> UsageMetrics:
 """Get summary of token usage for this LLM instance.
@@ -547,4 +576,10 @@ class BaseLLM(ABC):
 Returns:
 Dictionary with token usage totals
 """
-return UsageMetrics(**self._token_usage)
+return UsageMetrics(
+prompt_tokens=self._token_usage.prompt_tokens,
+completion_tokens=self._token_usage.completion_tokens,
+total_tokens=self._token_usage.total_tokens,
+successful_requests=self._token_usage.successful_requests,
+cached_prompt_tokens=self._token_usage.cached_prompt_tokens,
+)
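Token accounting now lives in a `TokenProcess` private attribute instead of a plain dict, and `get_token_usage_summary` builds `UsageMetrics` field by field. A hedged sketch of the accumulation logic with a stand-in counter class, since the real `TokenProcess` and `UsageMetrics` types are defined elsewhere in the codebase:

```python
# Stand-in type; the real TokenProcess and UsageMetrics are crewAI classes.
from dataclasses import dataclass


@dataclass
class TokenProcessLike:
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
    successful_requests: int = 0
    cached_prompt_tokens: int = 0


def track(usage: TokenProcessLike, prompt: int, completion: int, cached: int = 0) -> None:
    # Mirrors _track_token_usage_internal after the change: attribute updates
    # instead of dict-key updates.
    usage.prompt_tokens += prompt
    usage.completion_tokens += completion
    usage.total_tokens += prompt + completion
    usage.successful_requests += 1
    usage.cached_prompt_tokens += cached


usage = TokenProcessLike()
track(usage, prompt=120, completion=35)
print(usage.total_tokens)  # 155
```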
@@ -5,11 +5,14 @@ import logging
 import os
 from typing import TYPE_CHECKING, Any, cast

-from pydantic import BaseModel
+from pydantic import BaseModel, Field, PrivateAttr, computed_field, model_validator
+from typing_extensions import Self

 from crewai.events.types.llm_events import LLMCallType
+from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO
 from crewai.llms.base_llm import BaseLLM
 from crewai.llms.hooks.transport import HTTPTransport
+from crewai.llms.providers.utils.common import safe_tool_conversion
 from crewai.utilities.agent_utils import is_context_length_exceeded
 from crewai.utilities.exceptions.context_window_exceeding_exception import (
 LLMContextLengthExceededError,
@@ -18,7 +21,8 @@ from crewai.utilities.types import LLMMessage


 if TYPE_CHECKING:
-from crewai.llms.hooks.base import BaseInterceptor
+from crewai.agent import Agent
+from crewai.task import Task

 try:
 from anthropic import Anthropic
@@ -31,6 +35,19 @@ except ImportError:
 ) from None


+ANTHROPIC_CONTEXT_WINDOWS: dict[str, int] = {
+"claude-3-5-sonnet": 200000,
+"claude-3-5-haiku": 200000,
+"claude-3-opus": 200000,
+"claude-3-sonnet": 200000,
+"claude-3-haiku": 200000,
+"claude-3-7-sonnet": 200000,
+"claude-2.1": 200000,
+"claude-2": 100000,
+"claude-instant": 100000,
+}
+
+
 class AnthropicCompletion(BaseLLM):
 """Anthropic native completion implementation.

@@ -38,110 +55,69 @@ class AnthropicCompletion(BaseLLM):
 offering native tool use, streaming support, and proper message formatting.
 """

-def __init__(
-self,
-model: str = "claude-3-5-sonnet-20241022",
-api_key: str | None = None,
-base_url: str | None = None,
-timeout: float | None = None,
-max_retries: int = 2,
-temperature: float | None = None,
-max_tokens: int = 4096, # Required for Anthropic
-top_p: float | None = None,
-stop_sequences: list[str] | None = None,
-stream: bool = False,
-client_params: dict[str, Any] | None = None,
-interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
-**kwargs: Any,
-):
-"""Initialize Anthropic chat completion client.
-
-Args:
-model: Anthropic model name (e.g., 'claude-3-5-sonnet-20241022')
-api_key: Anthropic API key (defaults to ANTHROPIC_API_KEY env var)
-base_url: Custom base URL for Anthropic API
-timeout: Request timeout in seconds
-max_retries: Maximum number of retries
-temperature: Sampling temperature (0-1)
-max_tokens: Maximum tokens in response (required for Anthropic)
-top_p: Nucleus sampling parameter
-stop_sequences: Stop sequences (Anthropic uses stop_sequences, not stop)
-stream: Enable streaming responses
-client_params: Additional parameters for the Anthropic client
-interceptor: HTTP interceptor for modifying requests/responses at transport level.
-**kwargs: Additional parameters
-"""
-super().__init__(
-model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
-)
-
-# Client params
-self.interceptor = interceptor
-self.client_params = client_params
-self.base_url = base_url
-self.timeout = timeout
-self.max_retries = max_retries
-
-self.client = Anthropic(**self._get_client_params())
-
-# Store completion parameters
-self.max_tokens = max_tokens
-self.top_p = top_p
-self.stream = stream
-self.stop_sequences = stop_sequences or []
-
-# Model-specific settings
-self.is_claude_3 = "claude-3" in model.lower()
-self.supports_tools = self.is_claude_3 # Claude 3+ supports tool use
-
-@property
-def stop(self) -> list[str]:
-"""Get stop sequences sent to the API."""
-return self.stop_sequences
-
-@stop.setter
-def stop(self, value: list[str] | str | None) -> None:
-"""Set stop sequences.
-
-Synchronizes stop_sequences to ensure values set by CrewAgentExecutor
-are properly sent to the Anthropic API.
-
-Args:
-value: Stop sequences as a list, single string, or None
-"""
-if value is None:
-self.stop_sequences = []
-elif isinstance(value, str):
-self.stop_sequences = [value]
-elif isinstance(value, list):
-self.stop_sequences = value
-else:
-self.stop_sequences = []
-
-def _get_client_params(self) -> dict[str, Any]:
-"""Get client parameters."""
-
+model: str = Field(
+default="claude-3-5-sonnet-20241022",
+description="Anthropic model name (e.g., 'claude-3-5-sonnet-20241022')",
+)
+max_tokens: int = Field(
+default=4096,
+description="Maximum number of allowed tokens in response.",
+)
+top_p: float | None = Field(
+default=None,
+description="Nucleus sampling parameter.",
+)
+_client: Anthropic = PrivateAttr(
+default_factory=Anthropic,
+)
+
+@model_validator(mode="after")
+def initialize_client(self) -> Self:
+"""Initialize the Anthropic client after Pydantic validation.
+
+This runs after all field validation is complete, ensuring that:
+- All BaseLLM fields are set (model, temperature, stop_sequences, etc.)
+- Field validators have run (stop_sequences is normalized to set[str])
+- API key and other configuration is ready
+"""
 if self.api_key is None:
 self.api_key = os.getenv("ANTHROPIC_API_KEY")
 if self.api_key is None:
 raise ValueError("ANTHROPIC_API_KEY is required")

-client_params = {
-"api_key": self.api_key,
-"base_url": self.base_url,
-"timeout": self.timeout,
-"max_retries": self.max_retries,
-}
+params = self.model_dump(
+include={"api_key", "base_url", "timeout", "max_retries"},
+exclude_none=True,
+)

 if self.interceptor:
 transport = HTTPTransport(interceptor=self.interceptor)
 http_client = httpx.Client(transport=transport)
-client_params["http_client"] = http_client # type: ignore[assignment]
+params["http_client"] = http_client

 if self.client_params:
-client_params.update(self.client_params)
+params.update(self.client_params)

-return client_params
+self._client = Anthropic(**params)
+return self
+
+@computed_field # type: ignore[prop-decorator]
+@property
+def is_claude_3(self) -> bool:
+"""Check if the model is Claude 3 or higher."""
+return "claude-3" in self.model.lower()
+
+@computed_field # type: ignore[prop-decorator]
+@property
+def supports_tools(self) -> bool:
+"""Check if the model supports tool use."""
+return self.is_claude_3
+
+@computed_field # type: ignore[prop-decorator]
+@property
+def supports_function_calling(self) -> bool:
+"""Check if the model supports function calling."""
+return self.supports_tools

 def call(
 self,
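The Anthropic client follows the same migration: constructor arguments become declared fields, the `Anthropic` client is built inside an `after` validator from `model_dump(include=...)`, and `is_claude_3` / `supports_tools` / `supports_function_calling` become computed properties. A minimal sketch of the computed-property part only, not the real `AnthropicCompletion`:

```python
# Minimal sketch of computed_field properties; the real class also wires up the
# Anthropic SDK client and inherits the BaseLLM fields.
from pydantic import BaseModel, computed_field


class AnthropicModelInfo(BaseModel):
    model: str = "claude-3-5-sonnet-20241022"

    @computed_field  # type: ignore[prop-decorator]
    @property
    def is_claude_3(self) -> bool:
        return "claude-3" in self.model.lower()

    @computed_field  # type: ignore[prop-decorator]
    @property
    def supports_tools(self) -> bool:
        return self.is_claude_3


info = AnthropicModelInfo(model="claude-3-7-sonnet-latest")
print(info.supports_tools)  # True
print(info.model_dump())    # computed fields appear in the dump alongside `model`
```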
@@ -149,8 +125,8 @@ class AnthropicCompletion(BaseLLM):
 tools: list[dict[str, Any]] | None = None,
 callbacks: list[Any] | None = None,
 available_functions: dict[str, Any] | None = None,
-from_task: Any | None = None,
-from_agent: Any | None = None,
+from_task: Task | None = None,
+from_agent: Agent | None = None,
 response_model: type[BaseModel] | None = None,
 ) -> str | Any:
 """Call Anthropic messages API.
@@ -229,25 +205,21 @@ class AnthropicCompletion(BaseLLM):
 Returns:
 Parameters dictionary for Anthropic API
 """
-params = {
-"model": self.model,
-"messages": messages,
-"max_tokens": self.max_tokens,
-"stream": self.stream,
-}
+params = self.model_dump(
+include={
+"model",
+"max_tokens",
+"stream",
+"temperature",
+"top_p",
+"stop_sequences",
+},
+)
+params["messages"] = messages
 # Add system message if present
 if system_message:
 params["system"] = system_message

-# Add optional parameters if set
-if self.temperature is not None:
-params["temperature"] = self.temperature
-if self.top_p is not None:
-params["top_p"] = self.top_p
-if self.stop_sequences:
-params["stop_sequences"] = self.stop_sequences
-
 # Handle tools for Claude 3+
 if tools and self.supports_tools:
 params["tools"] = self._convert_tools_for_interference(tools)
@@ -266,8 +238,6 @@ class AnthropicCompletion(BaseLLM):
 continue

 try:
-from crewai.llms.providers.utils.common import safe_tool_conversion
-
 name, description, parameters = safe_tool_conversion(tool, "Anthropic")
 except (ImportError, KeyError, ValueError) as e:
 logging.error(f"Error converting tool to Anthropic format: {e}")
@@ -341,8 +311,8 @@ class AnthropicCompletion(BaseLLM):
 self,
 params: dict[str, Any],
 available_functions: dict[str, Any] | None = None,
-from_task: Any | None = None,
-from_agent: Any | None = None,
+from_task: Task | None = None,
+from_agent: Agent | None = None,
 response_model: type[BaseModel] | None = None,
 ) -> str | Any:
 """Handle non-streaming message completion."""
@@ -357,7 +327,7 @@ class AnthropicCompletion(BaseLLM):
 params["tool_choice"] = {"type": "tool", "name": "structured_output"}

 try:
-response: Message = self.client.messages.create(**params)
+response: Message = self._client.messages.create(**params)

 except Exception as e:
 if is_context_length_exceeded(e):
@@ -429,8 +399,8 @@ class AnthropicCompletion(BaseLLM):
 self,
 params: dict[str, Any],
 available_functions: dict[str, Any] | None = None,
-from_task: Any | None = None,
-from_agent: Any | None = None,
+from_task: Task | None = None,
+from_agent: Agent | None = None,
 response_model: type[BaseModel] | None = None,
 ) -> str:
 """Handle streaming message completion."""
@@ -451,7 +421,7 @@ class AnthropicCompletion(BaseLLM):
 stream_params = {k: v for k, v in params.items() if k != "stream"}

 # Make streaming API call
-with self.client.messages.stream(**stream_params) as stream:
+with self._client.messages.stream(**stream_params) as stream:
 for event in stream:
 if hasattr(event, "delta") and hasattr(event.delta, "text"):
 text_delta = event.delta.text
@@ -525,8 +495,8 @@ class AnthropicCompletion(BaseLLM):
 tool_uses: list[ToolUseBlock],
 params: dict[str, Any],
 available_functions: dict[str, Any],
-from_task: Any | None = None,
-from_agent: Any | None = None,
+from_task: Task | None = None,
+from_agent: Agent | None = None,
 ) -> str:
 """Handle the complete tool use conversation flow.

@@ -579,7 +549,7 @@ class AnthropicCompletion(BaseLLM):

 try:
 # Send tool results back to Claude for final response
-final_response: Message = self.client.messages.create(**follow_up_params)
+final_response: Message = self._client.messages.create(**follow_up_params)

 # Track token usage for follow-up call
 follow_up_usage = self._extract_anthropic_token_usage(final_response)
@@ -626,48 +596,24 @@ class AnthropicCompletion(BaseLLM):
 return tool_results[0]["content"]
 raise e

-def supports_function_calling(self) -> bool:
-"""Check if the model supports function calling."""
-return self.supports_tools
-
-def supports_stop_words(self) -> bool:
-"""Check if the model supports stop words."""
-return True # All Claude models support stop sequences
-
 def get_context_window_size(self) -> int:
 """Get the context window size for the model."""
-from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO
-
-# Context window sizes for Anthropic models
-context_windows = {
-"claude-3-5-sonnet": 200000,
-"claude-3-5-haiku": 200000,
-"claude-3-opus": 200000,
-"claude-3-sonnet": 200000,
-"claude-3-haiku": 200000,
-"claude-3-7-sonnet": 200000,
-"claude-2.1": 200000,
-"claude-2": 100000,
-"claude-instant": 100000,
-}
-
 # Find the best match for the model name
-for model_prefix, size in context_windows.items():
+for model_prefix, size in ANTHROPIC_CONTEXT_WINDOWS.items():
 if self.model.startswith(model_prefix):
 return int(size * CONTEXT_WINDOW_USAGE_RATIO)

 # Default context window size for Claude models
 return int(200000 * CONTEXT_WINDOW_USAGE_RATIO)

-def _extract_anthropic_token_usage(self, response: Message) -> dict[str, Any]:
+@staticmethod
+def _extract_anthropic_token_usage(response: Message) -> dict[str, Any]:
 """Extract token usage from Anthropic response."""
-if hasattr(response, "usage") and response.usage:
+if response.usage:
 usage = response.usage
-input_tokens = getattr(usage, "input_tokens", 0)
-output_tokens = getattr(usage, "output_tokens", 0)
 return {
-"input_tokens": input_tokens,
-"output_tokens": output_tokens,
-"total_tokens": input_tokens + output_tokens,
+"input_tokens": usage.input_tokens,
+"output_tokens": usage.output_tokens,
+"total_tokens": usage.input_tokens + usage.output_tokens,
 }
 return {"total_tokens": 0}
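Context-window lookup now reads module-level tables (`ANTHROPIC_CONTEXT_WINDOWS` above, `GEMINI_CONTEXT_WINDOWS` in the Gemini file below) instead of a dict rebuilt inside the method. A small sketch of the prefix-match lookup; `CONTEXT_WINDOW_USAGE_RATIO = 0.75` is an assumption based on the "75% of the maximum" wording in the `get_context_window_size` docstring earlier in the diff:

```python
# Prefix-match lookup sketch; the 0.75 usage ratio is an assumption.
CONTEXT_WINDOW_USAGE_RATIO = 0.75

ANTHROPIC_CONTEXT_WINDOWS: dict[str, int] = {
    "claude-3-5-sonnet": 200000,
    "claude-2": 100000,
    "claude-instant": 100000,
}


def usable_context_window(model: str, default: int = 200000) -> int:
    # Return 75% of the first matching prefix's window, or of the default.
    for prefix, size in ANTHROPIC_CONTEXT_WINDOWS.items():
        if model.startswith(prefix):
            return int(size * CONTEXT_WINDOW_USAGE_RATIO)
    return int(default * CONTEXT_WINDOW_USAGE_RATIO)


print(usable_context_window("claude-3-5-sonnet-20241022"))  # 150000
```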
@@ -1,12 +1,14 @@
-import logging
-import os
-from typing import Any, cast
+from __future__ import annotations

-from pydantic import BaseModel
+import logging
+from typing import TYPE_CHECKING, Any, cast
+
+from pydantic import BaseModel, Field, PrivateAttr, computed_field, model_validator
+from typing_extensions import Self

 from crewai.events.types.llm_events import LLMCallType
+from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
 from crewai.llms.base_llm import BaseLLM
-from crewai.llms.hooks.base import BaseInterceptor
 from crewai.utilities.agent_utils import is_context_length_exceeded
 from crewai.utilities.exceptions.context_window_exceeding_exception import (
 LLMContextLengthExceededError,
@@ -14,6 +16,11 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
 from crewai.utilities.types import LLMMessage


+if TYPE_CHECKING:
+from crewai.agent import Agent
+from crewai.task import Task
+
+
 try:
 from google import genai # type: ignore[import-untyped]
 from google.genai import types # type: ignore[import-untyped]
@@ -24,6 +31,27 @@ except ImportError:
 ) from None


+GEMINI_CONTEXT_WINDOWS: dict[str, int] = {
+"gemini-2.0-flash": 1048576, # 1M tokens
+"gemini-2.0-flash-thinking": 32768,
+"gemini-2.0-flash-lite": 1048576,
+"gemini-2.5-flash": 1048576,
+"gemini-2.5-pro": 1048576,
+"gemini-1.5-pro": 2097152, # 2M tokens
+"gemini-1.5-flash": 1048576,
+"gemini-1.5-flash-8b": 1048576,
+"gemini-1.0-pro": 32768,
+"gemma-3-1b": 32000,
+"gemma-3-4b": 128000,
+"gemma-3-12b": 128000,
+"gemma-3-27b": 128000,
+}
+
+# Context window validation constraints
+MIN_CONTEXT_WINDOW: int = 1024
+MAX_CONTEXT_WINDOW: int = 2097152
+
+
 class GeminiCompletion(BaseLLM):
 """Google Gemini native completion implementation.

@@ -31,78 +59,140 @@ class GeminiCompletion(BaseLLM):
|
|||||||
offering native function calling, streaming support, and proper Gemini formatting.
|
offering native function calling, streaming support, and proper Gemini formatting.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
model: str = Field(
|
||||||
self,
|
default="gemini-2.0-flash-001",
|
||||||
model: str = "gemini-2.0-flash-001",
|
description="Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro')",
|
||||||
api_key: str | None = None,
|
)
|
||||||
project: str | None = None,
|
project: str | None = Field(
|
||||||
location: str | None = None,
|
default=None,
|
||||||
temperature: float | None = None,
|
description="Google Cloud project ID (for Vertex AI)",
|
||||||
top_p: float | None = None,
|
)
|
||||||
top_k: int | None = None,
|
location: str = Field(
|
||||||
max_output_tokens: int | None = None,
|
default="us-central1",
|
||||||
stop_sequences: list[str] | None = None,
|
description="Google Cloud location (for Vertex AI)",
|
||||||
stream: bool = False,
|
)
|
||||||
safety_settings: dict[str, Any] | None = None,
|
top_p: float | None = Field(
|
||||||
client_params: dict[str, Any] | None = None,
|
default=None,
|
||||||
interceptor: BaseInterceptor[Any, Any] | None = None,
|
description="Nucleus sampling parameter",
|
||||||
**kwargs: Any,
|
)
|
||||||
):
|
top_k: int | None = Field(
|
||||||
"""Initialize Google Gemini chat completion client.
|
default=None,
|
||||||
|
description="Top-k sampling parameter",
|
||||||
|
)
|
||||||
|
max_output_tokens: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Maximum tokens in response",
|
||||||
|
)
|
||||||
|
safety_settings: dict[str, Any] | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Safety filter settings",
|
||||||
|
)
|
||||||
|
_client: genai.Client = PrivateAttr( # type: ignore[no-any-unimported]
|
||||||
|
default_factory=genai.Client,
|
||||||
|
)
|
||||||
|
|
||||||
Args:
|
@model_validator(mode="after")
|
||||||
model: Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro')
|
def initialize_client(self) -> Self:
|
||||||
api_key: Google API key (defaults to GOOGLE_API_KEY or GEMINI_API_KEY env var)
|
"""Initialize the Anthropic client after Pydantic validation.
|
||||||
project: Google Cloud project ID (for Vertex AI)
|
|
||||||
location: Google Cloud location (for Vertex AI, defaults to 'us-central1')
|
This runs after all field validation is complete, ensuring that:
|
||||||
temperature: Sampling temperature (0-2)
|
- All BaseLLM fields are set (model, temperature, stop_sequences, etc.)
|
||||||
top_p: Nucleus sampling parameter
|
- Field validators have run (stop_sequences is normalized to set[str])
|
||||||
top_k: Top-k sampling parameter
|
- API key and other configuration is ready
|
||||||
max_output_tokens: Maximum tokens in response
|
|
||||||
stop_sequences: Stop sequences
|
|
||||||
stream: Enable streaming responses
|
|
||||||
safety_settings: Safety filter settings
|
|
||||||
client_params: Additional parameters to pass to the Google Gen AI Client constructor.
|
|
||||||
Supports parameters like http_options, credentials, debug_config, etc.
|
|
||||||
interceptor: HTTP interceptor (not yet supported for Gemini).
|
|
||||||
**kwargs: Additional parameters
|
|
||||||
"""
|
"""
|
||||||
if interceptor is not None:
|
self._client = genai.Client(**self._get_client_params())
|
||||||
raise NotImplementedError(
|
return self
|
||||||
"HTTP interceptors are not yet supported for Google Gemini provider. "
|
|
||||||
"Interceptors are currently supported for OpenAI and Anthropic providers only."
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__(
|
# def __init__(
|
||||||
model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
|
# self,
|
||||||
)
|
# model: str = "gemini-2.0-flash-001",
|
||||||
|
# api_key: str | None = None,
|
||||||
|
# project: str | None = None,
|
||||||
|
# location: str | None = None,
|
||||||
|
# temperature: float | None = None,
|
||||||
|
# top_p: float | None = None,
|
||||||
|
# top_k: int | None = None,
|
||||||
|
# max_output_tokens: int | None = None,
|
||||||
|
# stop_sequences: list[str] | None = None,
|
||||||
|
# stream: bool = False,
|
||||||
|
# safety_settings: dict[str, Any] | None = None,
|
||||||
|
# client_params: dict[str, Any] | None = None,
|
||||||
|
# interceptor: BaseInterceptor[Any, Any] | None = None,
|
||||||
|
# **kwargs: Any,
|
||||||
|
# # ):
|
||||||
|
# """Initialize Google Gemini chat completion client.
|
||||||
|
#
|
||||||
|
# Args:
|
||||||
|
# model: Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro')
|
||||||
|
# api_key: Google API key (defaults to GOOGLE_API_KEY or GEMINI_API_KEY env var)
|
||||||
|
# project: Google Cloud project ID (for Vertex AI)
|
||||||
|
# location: Google Cloud location (for Vertex AI, defaults to 'us-central1')
|
||||||
|
# temperature: Sampling temperature (0-2)
|
||||||
|
# top_p: Nucleus sampling parameter
|
||||||
|
# top_k: Top-k sampling parameter
|
||||||
|
# max_output_tokens: Maximum tokens in response
|
||||||
|
# stop_sequences: Stop sequences
|
||||||
|
# stream: Enable streaming responses
|
||||||
|
# safety_settings: Safety filter settings
|
||||||
|
# client_params: Additional parameters to pass to the Google Gen AI Client constructor.
|
||||||
|
# Supports parameters like http_options, credentials, debug_config, etc.
|
||||||
|
# interceptor: HTTP interceptor (not yet supported for Gemini).
|
||||||
|
# **kwargs: Additional parameters
|
||||||
|
# """
|
||||||
|
# if interceptor is not None:
|
||||||
|
# raise NotImplementedError(
|
||||||
|
# "HTTP interceptors are not yet supported for Google Gemini provider. "
|
||||||
|
# "Interceptors are currently supported for OpenAI and Anthropic providers only."
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# super().__init__(
|
||||||
|
# model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# # Store client params for later use
|
||||||
|
# self.client_params = client_params or {}
|
||||||
|
#
|
||||||
|
# # Get API configuration with environment variable fallbacks
|
||||||
|
# self.api_key = (
|
||||||
|
# api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
|
||||||
|
# )
|
||||||
|
# self.project = project or os.getenv("GOOGLE_CLOUD_PROJECT")
|
||||||
|
# self.location = location or os.getenv("GOOGLE_CLOUD_LOCATION") or "us-central1"
|
||||||
|
#
|
||||||
|
# use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"
|
||||||
|
#
|
||||||
|
# self.client = self._initialize_client(use_vertexai)
|
||||||
|
#
|
||||||
|
# # Store completion parameters
|
||||||
|
# self.top_p = top_p
|
||||||
|
# self.top_k = top_k
|
||||||
|
# self.max_output_tokens = max_output_tokens
|
||||||
|
# self.stream = stream
|
||||||
|
# self.safety_settings = safety_settings or {}
|
||||||
|
# self.stop_sequences = stop_sequences or []
|
||||||
|
#
|
||||||
|
# # Model-specific settings
|
||||||
|
# self.is_gemini_2 = "gemini-2" in model.lower()
|
||||||
|
# self.is_gemini_1_5 = "gemini-1.5" in model.lower()
|
||||||
|
# self.supports_tools = self.is_gemini_1_5 or self.is_gemini_2
|
||||||
|
|
||||||
# Store client params for later use
|
@computed_field # type: ignore[prop-decorator]
|
||||||
self.client_params = client_params or {}
|
@property
|
||||||
|
def is_gemini_2(self) -> bool:
|
||||||
|
"""Check if the model is Gemini 2.x."""
|
||||||
|
return "gemini-2" in self.model.lower()
|
||||||
|
|
||||||
# Get API configuration with environment variable fallbacks
|
@computed_field # type: ignore[prop-decorator]
|
||||||
self.api_key = (
|
@property
|
||||||
api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
|
def is_gemini_1_5(self) -> bool:
|
||||||
)
|
"""Check if the model is Gemini 1.5.x."""
|
||||||
self.project = project or os.getenv("GOOGLE_CLOUD_PROJECT")
|
return "gemini-1.5" in self.model.lower()
|
||||||
self.location = location or os.getenv("GOOGLE_CLOUD_LOCATION") or "us-central1"
|
|
||||||
|
|
||||||
use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"
|
@computed_field # type: ignore[prop-decorator]
|
||||||
|
@property
|
||||||
self.client = self._initialize_client(use_vertexai)
|
def supports_tools(self) -> bool:
|
||||||
|
"""Check if the model supports tool/function calling."""
|
||||||
# Store completion parameters
|
return self.is_gemini_1_5 or self.is_gemini_2
|
||||||
self.top_p = top_p
|
|
||||||
self.top_k = top_k
|
|
||||||
self.max_output_tokens = max_output_tokens
|
|
||||||
self.stream = stream
|
|
||||||
self.safety_settings = safety_settings or {}
|
|
||||||
self.stop_sequences = stop_sequences or []
|
|
||||||
|
|
||||||
# Model-specific settings
|
|
||||||
self.is_gemini_2 = "gemini-2" in model.lower()
|
|
||||||
self.is_gemini_1_5 = "gemini-1.5" in model.lower()
|
|
||||||
self.supports_tools = self.is_gemini_1_5 or self.is_gemini_2
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def stop(self) -> list[str]:
|
def stop(self) -> list[str]:
|
||||||
@@ -142,6 +232,12 @@ class GeminiCompletion(BaseLLM):
|
|||||||
if self.client_params:
|
if self.client_params:
|
||||||
client_params.update(self.client_params)
|
client_params.update(self.client_params)
|
||||||
|
|
||||||
|
if self.interceptor:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"HTTP interceptors are not yet supported for Google Gemini provider. "
|
||||||
|
"Interceptors are currently supported for OpenAI and Anthropic providers only."
|
||||||
|
)
|
||||||
|
|
||||||
if use_vertexai or self.project:
|
if use_vertexai or self.project:
|
||||||
client_params.update(
|
client_params.update(
|
||||||
{
|
{
|
||||||
@@ -181,7 +277,7 @@ class GeminiCompletion(BaseLLM):
|
|||||||
|
|
||||||
if (
|
if (
|
||||||
hasattr(self, "client")
|
hasattr(self, "client")
|
||||||
and hasattr(self.client, "vertexai")
|
and hasattr(self._client, "vertexai")
|
||||||
and self.client.vertexai
|
and self.client.vertexai
|
||||||
):
|
):
|
||||||
# Vertex AI configuration
|
# Vertex AI configuration
|
||||||
@@ -206,8 +302,8 @@ class GeminiCompletion(BaseLLM):
|
|||||||
tools: list[dict[str, Any]] | None = None,
|
tools: list[dict[str, Any]] | None = None,
|
||||||
callbacks: list[Any] | None = None,
|
callbacks: list[Any] | None = None,
|
||||||
available_functions: dict[str, Any] | None = None,
|
available_functions: dict[str, Any] | None = None,
|
||||||
from_task: Any | None = None,
|
from_task: Task | None = None,
|
||||||
from_agent: Any | None = None,
|
from_agent: Agent | None = None,
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
) -> str | Any:
|
) -> str | Any:
|
||||||
"""Call Google Gemini generate content API.
|
"""Call Google Gemini generate content API.
|
||||||
@@ -294,7 +390,16 @@ class GeminiCompletion(BaseLLM):
|
|||||||
GenerateContentConfig object for Gemini API
|
GenerateContentConfig object for Gemini API
|
||||||
"""
|
"""
|
||||||
self.tools = tools
|
self.tools = tools
|
||||||
config_params = {}
|
config_params = self.model_dump(
|
||||||
|
include={
|
||||||
|
"temperature",
|
||||||
|
"top_p",
|
||||||
|
"top_k",
|
||||||
|
"max_output_tokens",
|
||||||
|
"stop_sequences",
|
||||||
|
"safety_settings",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Add system instruction if present
|
# Add system instruction if present
|
||||||
if system_instruction:
|
if system_instruction:
|
||||||
@@ -304,18 +409,6 @@ class GeminiCompletion(BaseLLM):
|
|||||||
)
|
)
|
||||||
config_params["system_instruction"] = system_content
|
config_params["system_instruction"] = system_content
|
||||||
|
|
||||||
# Add generation config parameters
|
|
||||||
if self.temperature is not None:
|
|
||||||
config_params["temperature"] = self.temperature
|
|
||||||
if self.top_p is not None:
|
|
||||||
config_params["top_p"] = self.top_p
|
|
||||||
if self.top_k is not None:
|
|
||||||
config_params["top_k"] = self.top_k
|
|
||||||
if self.max_output_tokens is not None:
|
|
||||||
config_params["max_output_tokens"] = self.max_output_tokens
|
|
||||||
if self.stop_sequences:
|
|
||||||
config_params["stop_sequences"] = self.stop_sequences
|
|
||||||
|
|
||||||
if response_model:
|
if response_model:
|
||||||
config_params["response_mime_type"] = "application/json"
|
config_params["response_mime_type"] = "application/json"
|
||||||
config_params["response_schema"] = response_model.model_json_schema()
|
config_params["response_schema"] = response_model.model_json_schema()
|
||||||
@@ -324,9 +417,6 @@ class GeminiCompletion(BaseLLM):
|
|||||||
if tools and self.supports_tools:
|
if tools and self.supports_tools:
|
||||||
config_params["tools"] = self._convert_tools_for_interference(tools)
|
config_params["tools"] = self._convert_tools_for_interference(tools)
|
||||||
|
|
||||||
if self.safety_settings:
|
|
||||||
config_params["safety_settings"] = self.safety_settings
|
|
||||||
|
|
||||||
return types.GenerateContentConfig(**config_params)
|
return types.GenerateContentConfig(**config_params)
|
||||||
|
|
||||||
def _convert_tools_for_interference( # type: ignore[no-any-unimported]
|
def _convert_tools_for_interference( # type: ignore[no-any-unimported]
|
||||||
@@ -404,8 +494,8 @@ class GeminiCompletion(BaseLLM):
|
|||||||
system_instruction: str | None,
|
system_instruction: str | None,
|
||||||
config: types.GenerateContentConfig,
|
config: types.GenerateContentConfig,
|
||||||
available_functions: dict[str, Any] | None = None,
|
available_functions: dict[str, Any] | None = None,
|
||||||
from_task: Any | None = None,
|
from_task: Task | None = None,
|
||||||
from_agent: Any | None = None,
|
from_agent: Agent | None = None,
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
) -> str | Any:
|
) -> str | Any:
|
||||||
"""Handle non-streaming content generation."""
|
"""Handle non-streaming content generation."""
|
||||||
@@ -416,7 +506,7 @@ class GeminiCompletion(BaseLLM):
|
|||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = self.client.models.generate_content(**api_params)
|
response = self._client.models.generate_content(**api_params)
|
||||||
|
|
||||||
usage = self._extract_token_usage(response)
|
usage = self._extract_token_usage(response)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -470,8 +560,8 @@ class GeminiCompletion(BaseLLM):
|
|||||||
contents: list[types.Content],
|
contents: list[types.Content],
|
||||||
config: types.GenerateContentConfig,
|
config: types.GenerateContentConfig,
|
||||||
available_functions: dict[str, Any] | None = None,
|
available_functions: dict[str, Any] | None = None,
|
||||||
from_task: Any | None = None,
|
from_task: Task | None = None,
|
||||||
from_agent: Any | None = None,
|
from_agent: Agent | None = None,
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Handle streaming content generation."""
|
"""Handle streaming content generation."""
|
||||||
@@ -484,7 +574,7 @@ class GeminiCompletion(BaseLLM):
|
|||||||
"config": config,
|
"config": config,
|
||||||
}
|
}
|
||||||
|
|
||||||
for chunk in self.client.models.generate_content_stream(**api_params):
|
for chunk in self._client.models.generate_content_stream(**api_params):
|
||||||
if hasattr(chunk, "text") and chunk.text:
|
if hasattr(chunk, "text") and chunk.text:
|
||||||
full_response += chunk.text
|
full_response += chunk.text
|
||||||
self._emit_stream_chunk_event(
|
self._emit_stream_chunk_event(
|
||||||
@@ -537,52 +627,30 @@ class GeminiCompletion(BaseLLM):
|
|||||||
|
|
||||||
return full_response
|
return full_response
|
||||||
|
|
||||||
|
@computed_field # type: ignore[prop-decorator]
|
||||||
|
@property
|
||||||
def supports_function_calling(self) -> bool:
|
def supports_function_calling(self) -> bool:
|
||||||
"""Check if the model supports function calling."""
|
"""Check if the model supports function calling."""
|
||||||
return self.supports_tools
|
return self.supports_tools
|
||||||
|
|
||||||
def supports_stop_words(self) -> bool:
|
|
||||||
"""Check if the model supports stop words."""
|
|
||||||
return True
|
|
||||||
|
|
||||||
def get_context_window_size(self) -> int:
|
def get_context_window_size(self) -> int:
|
||||||
"""Get the context window size for the model."""
|
"""Get the context window size for the model."""
|
||||||
from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
|
|
||||||
|
|
||||||
min_context = 1024
|
|
||||||
max_context = 2097152
|
|
||||||
|
|
||||||
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
|
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
|
||||||
if value < min_context or value > max_context:
|
if value < MIN_CONTEXT_WINDOW or value > MAX_CONTEXT_WINDOW:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Context window for {key} must be between {min_context} and {max_context}"
|
f"Context window for {key} must be between {MIN_CONTEXT_WINDOW} and {MAX_CONTEXT_WINDOW}"
|
||||||
)
|
)
|
||||||
|
|
||||||
context_windows = {
|
|
||||||
"gemini-2.0-flash": 1048576, # 1M tokens
|
|
||||||
"gemini-2.0-flash-thinking": 32768,
|
|
||||||
"gemini-2.0-flash-lite": 1048576,
|
|
||||||
"gemini-2.5-flash": 1048576,
|
|
||||||
"gemini-2.5-pro": 1048576,
|
|
||||||
"gemini-1.5-pro": 2097152, # 2M tokens
|
|
||||||
"gemini-1.5-flash": 1048576,
|
|
||||||
"gemini-1.5-flash-8b": 1048576,
|
|
||||||
"gemini-1.0-pro": 32768,
|
|
||||||
"gemma-3-1b": 32000,
|
|
||||||
"gemma-3-4b": 128000,
|
|
||||||
"gemma-3-12b": 128000,
|
|
||||||
"gemma-3-27b": 128000,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Find the best match for the model name
|
# Find the best match for the model name
|
||||||
for model_prefix, size in context_windows.items():
|
for model_prefix, size in GEMINI_CONTEXT_WINDOWS.items():
|
||||||
if self.model.startswith(model_prefix):
|
if self.model.startswith(model_prefix):
|
||||||
return int(size * CONTEXT_WINDOW_USAGE_RATIO)
|
return int(size * CONTEXT_WINDOW_USAGE_RATIO)
|
||||||
|
|
||||||
# Default context window size for Gemini models
|
# Default context window size for Gemini models
|
||||||
return int(1048576 * CONTEXT_WINDOW_USAGE_RATIO) # 1M tokens
|
return int(1048576 * CONTEXT_WINDOW_USAGE_RATIO) # 1M tokens
|
||||||
|
|
||||||
def _extract_token_usage(self, response: dict[str, Any]) -> dict[str, Any]:
|
@staticmethod
|
||||||
|
def _extract_token_usage(response: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""Extract token usage from Gemini response."""
|
"""Extract token usage from Gemini response."""
|
||||||
if hasattr(response, "usage_metadata"):
|
if hasattr(response, "usage_metadata"):
|
||||||
usage = response.usage_metadata
|
usage = response.usage_metadata
|
||||||
@@ -594,8 +662,8 @@ class GeminiCompletion(BaseLLM):
|
|||||||
}
|
}
|
||||||
return {"total_tokens": 0}
|
return {"total_tokens": 0}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def _convert_contents_to_dict( # type: ignore[no-any-unimported]
|
def _convert_contents_to_dict( # type: ignore[no-any-unimported]
|
||||||
self,
|
|
||||||
contents: list[types.Content],
|
contents: list[types.Content],
|
||||||
) -> list[dict[str, str]]:
|
) -> list[dict[str, str]]:
|
||||||
"""Convert contents to dict format."""
|
"""Convert contents to dict format."""
|
||||||
|
|||||||
@@ -4,16 +4,23 @@ from collections.abc import Iterator
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any, Final
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from openai import APIConnectionError, NotFoundError, OpenAI
|
from openai import APIConnectionError, NotFoundError, OpenAI
|
||||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||||
from openai.types.chat.chat_completion import Choice
|
from openai.types.chat.chat_completion import Choice
|
||||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||||
from pydantic import BaseModel
|
from pydantic import (
|
||||||
|
BaseModel,
|
||||||
|
Field,
|
||||||
|
PrivateAttr,
|
||||||
|
model_validator,
|
||||||
|
)
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
from crewai.events.types.llm_events import LLMCallType
|
from crewai.events.types.llm_events import LLMCallType
|
||||||
|
from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
|
||||||
from crewai.llms.base_llm import BaseLLM
|
from crewai.llms.base_llm import BaseLLM
|
||||||
from crewai.llms.hooks.transport import HTTPTransport
|
from crewai.llms.hooks.transport import HTTPTransport
|
||||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||||
@@ -25,11 +32,28 @@ from crewai.utilities.types import LLMMessage
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from crewai.agent.core import Agent
|
from crewai.agent.core import Agent
|
||||||
from crewai.llms.hooks.base import BaseInterceptor
|
|
||||||
from crewai.task import Task
|
from crewai.task import Task
|
||||||
from crewai.tools.base_tool import BaseTool
|
from crewai.tools.base_tool import BaseTool
|
||||||
|
|
||||||
|
|
||||||
|
OPENAI_CONTEXT_WINDOWS: dict[str, int] = {
|
||||||
|
"gpt-4": 8192,
|
||||||
|
"gpt-4o": 128000,
|
||||||
|
"gpt-4o-mini": 200000,
|
||||||
|
"gpt-4-turbo": 128000,
|
||||||
|
"gpt-4.1": 1047576,
|
||||||
|
"gpt-4.1-mini-2025-04-14": 1047576,
|
||||||
|
"gpt-4.1-nano-2025-04-14": 1047576,
|
||||||
|
"o1-preview": 128000,
|
||||||
|
"o1-mini": 128000,
|
||||||
|
"o3-mini": 200000,
|
||||||
|
"o4-mini": 200000,
|
||||||
|
}
|
||||||
|
|
||||||
|
MIN_CONTEXT_WINDOW: Final[int] = 1024
|
||||||
|
MAX_CONTEXT_WINDOW: Final[int] = 2097152
|
||||||
|
|
||||||
|
|
||||||
class OpenAICompletion(BaseLLM):
|
class OpenAICompletion(BaseLLM):
|
||||||
"""OpenAI native completion implementation.
|
"""OpenAI native completion implementation.
|
||||||
|
|
||||||
@@ -37,112 +61,125 @@ class OpenAICompletion(BaseLLM):
|
|||||||
offering native structured outputs, function calling, and streaming support.
|
offering native structured outputs, function calling, and streaming support.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
model: str = Field(
|
||||||
self,
|
default="gpt-4o",
|
||||||
model: str = "gpt-4o",
|
description="OpenAI model name (e.g., 'gpt-4o')",
|
||||||
api_key: str | None = None,
|
)
|
||||||
base_url: str | None = None,
|
organization: str | None = Field(
|
||||||
organization: str | None = None,
|
default=None,
|
||||||
project: str | None = None,
|
description="Name of the OpenAI organization",
|
||||||
timeout: float | None = None,
|
)
|
||||||
max_retries: int = 2,
|
project: str | None = Field(
|
||||||
default_headers: dict[str, str] | None = None,
|
default=None,
|
||||||
default_query: dict[str, Any] | None = None,
|
description="Name of the OpenAI project",
|
||||||
client_params: dict[str, Any] | None = None,
|
)
|
||||||
temperature: float | None = None,
|
api_base: str | None = Field(
|
||||||
top_p: float | None = None,
|
default=os.getenv("OPENAI_BASE_URL"),
|
||||||
frequency_penalty: float | None = None,
|
description="Base URL for OpenAI API",
|
||||||
presence_penalty: float | None = None,
|
)
|
||||||
max_tokens: int | None = None,
|
default_headers: dict[str, str] | None = Field(
|
||||||
max_completion_tokens: int | None = None,
|
default=None,
|
||||||
seed: int | None = None,
|
description="Default headers for OpenAI API requests",
|
||||||
stream: bool = False,
|
)
|
||||||
response_format: dict[str, Any] | type[BaseModel] | None = None,
|
default_query: dict[str, Any] | None = Field(
|
||||||
logprobs: bool | None = None,
|
default=None,
|
||||||
top_logprobs: int | None = None,
|
description="Default query parameters for OpenAI API requests",
|
||||||
reasoning_effort: str | None = None,
|
)
|
||||||
provider: str | None = None,
|
top_p: float | None = Field(
|
||||||
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
|
default=None,
|
||||||
**kwargs: Any,
|
description="Top-p sampling parameter",
|
||||||
) -> None:
|
)
|
||||||
"""Initialize OpenAI chat completion client."""
|
frequency_penalty: float | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Frequency penalty parameter",
|
||||||
|
)
|
||||||
|
presence_penalty: float | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Presence penalty parameter",
|
||||||
|
)
|
||||||
|
max_completion_tokens: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Maximum tokens for completion",
|
||||||
|
)
|
||||||
|
seed: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Random seed for reproducibility",
|
||||||
|
)
|
||||||
|
response_format: dict[str, Any] | type[BaseModel] | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Response format for structured output",
|
||||||
|
)
|
||||||
|
logprobs: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Whether to include log probabilities",
|
||||||
|
)
|
||||||
|
top_logprobs: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Number of top log probabilities to return",
|
||||||
|
)
|
||||||
|
reasoning_effort: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Reasoning effort level for o1 models",
|
||||||
|
)
|
||||||
|
supports_function_calling: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Whether the model supports function calling",
|
||||||
|
)
|
||||||
|
is_o1_model: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="Whether the model is an o1 model",
|
||||||
|
)
|
||||||
|
is_gpt4_model: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="Whether the model is a GPT-4 model",
|
||||||
|
)
|
||||||
|
_client: OpenAI = PrivateAttr(
|
||||||
|
default_factory=OpenAI,
|
||||||
|
)
|
||||||
|
|
||||||
if provider is None:
|
@model_validator(mode="after")
|
||||||
provider = kwargs.pop("provider", "openai")
|
def initialize_client(self) -> Self:
|
||||||
|
"""Initialize the Anthropic client after Pydantic validation.
|
||||||
self.interceptor = interceptor
|
|
||||||
# Client configuration attributes
|
|
||||||
self.organization = organization
|
|
||||||
self.project = project
|
|
||||||
self.max_retries = max_retries
|
|
||||||
self.default_headers = default_headers
|
|
||||||
self.default_query = default_query
|
|
||||||
self.client_params = client_params
|
|
||||||
self.timeout = timeout
|
|
||||||
self.base_url = base_url
|
|
||||||
self.api_base = kwargs.pop("api_base", None)
|
|
||||||
|
|
||||||
super().__init__(
|
|
||||||
model=model,
|
|
||||||
temperature=temperature,
|
|
||||||
api_key=api_key or os.getenv("OPENAI_API_KEY"),
|
|
||||||
base_url=base_url,
|
|
||||||
timeout=timeout,
|
|
||||||
provider=provider,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
client_config = self._get_client_params()
|
|
||||||
if self.interceptor:
|
|
||||||
transport = HTTPTransport(interceptor=self.interceptor)
|
|
||||||
http_client = httpx.Client(transport=transport)
|
|
||||||
client_config["http_client"] = http_client
|
|
||||||
|
|
||||||
self.client = OpenAI(**client_config)
|
|
||||||
|
|
||||||
# Completion parameters
|
|
||||||
self.top_p = top_p
|
|
||||||
self.frequency_penalty = frequency_penalty
|
|
||||||
self.presence_penalty = presence_penalty
|
|
||||||
self.max_tokens = max_tokens
|
|
||||||
self.max_completion_tokens = max_completion_tokens
|
|
||||||
self.seed = seed
|
|
||||||
self.stream = stream
|
|
||||||
self.response_format = response_format
|
|
||||||
self.logprobs = logprobs
|
|
||||||
self.top_logprobs = top_logprobs
|
|
||||||
self.reasoning_effort = reasoning_effort
|
|
||||||
self.is_o1_model = "o1" in model.lower()
|
|
||||||
self.is_gpt4_model = "gpt-4" in model.lower()
|
|
||||||
|
|
||||||
def _get_client_params(self) -> dict[str, Any]:
|
|
||||||
"""Get OpenAI client parameters."""
|
|
||||||
|
|
||||||
|
This runs after all field validation is complete, ensuring that:
|
||||||
|
- All BaseLLM fields are set (model, temperature, stop_sequences, etc.)
|
||||||
|
- Field validators have run (stop_sequences is normalized to set[str])
|
||||||
|
- API key and other configuration is ready
|
||||||
|
"""
|
||||||
if self.api_key is None:
|
if self.api_key is None:
|
||||||
self.api_key = os.getenv("OPENAI_API_KEY")
|
self.api_key = os.getenv("OPENAI_API_KEY")
|
||||||
if self.api_key is None:
|
if self.api_key is None:
|
||||||
raise ValueError("OPENAI_API_KEY is required")
|
raise ValueError("OPENAI_API_KEY is required")
|
||||||
|
|
||||||
base_params = {
|
self.is_o1_model = "o1" in self.model.lower()
|
||||||
"api_key": self.api_key,
|
self.supports_function_calling = not self.is_o1_model
|
||||||
"organization": self.organization,
|
self.is_gpt4_model = "gpt-4" in self.model.lower()
|
||||||
"project": self.project,
|
self.supports_stop_words = not self.is_o1_model
|
||||||
"base_url": self.base_url
|
|
||||||
or self.api_base
|
|
||||||
or os.getenv("OPENAI_BASE_URL")
|
|
||||||
or None,
|
|
||||||
"timeout": self.timeout,
|
|
||||||
"max_retries": self.max_retries,
|
|
||||||
"default_headers": self.default_headers,
|
|
||||||
"default_query": self.default_query,
|
|
||||||
}
|
|
||||||
|
|
||||||
client_params = {k: v for k, v in base_params.items() if v is not None}
|
params = self.model_dump(
|
||||||
|
include={
|
||||||
|
"api_key",
|
||||||
|
"organization",
|
||||||
|
"project",
|
||||||
|
"base_url",
|
||||||
|
"timeout",
|
||||||
|
"max_retries",
|
||||||
|
"default_headers",
|
||||||
|
"default_query",
|
||||||
|
},
|
||||||
|
exclude_none=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.interceptor:
|
||||||
|
transport = HTTPTransport(interceptor=self.interceptor)
|
||||||
|
http_client = httpx.Client(transport=transport)
|
||||||
|
params["http_client"] = http_client
|
||||||
|
|
||||||
if self.client_params:
|
if self.client_params:
|
||||||
client_params.update(self.client_params)
|
params.update(self.client_params)
|
||||||
|
|
||||||
return client_params
|
self._client = OpenAI(**params)
|
||||||
|
return self
|
||||||
|
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
@@ -213,38 +250,26 @@ class OpenAICompletion(BaseLLM):
|
|||||||
self, messages: list[LLMMessage], tools: list[dict[str, BaseTool]] | None = None
|
self, messages: list[LLMMessage], tools: list[dict[str, BaseTool]] | None = None
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Prepare parameters for OpenAI chat completion."""
|
"""Prepare parameters for OpenAI chat completion."""
|
||||||
params: dict[str, Any] = {
|
params = self.model_dump(
|
||||||
"model": self.model,
|
include={
|
||||||
"messages": messages,
|
"model",
|
||||||
}
|
"stream",
|
||||||
if self.stream:
|
"temperature",
|
||||||
params["stream"] = self.stream
|
"top_p",
|
||||||
|
"frequency_penalty",
|
||||||
|
"presence_penalty",
|
||||||
|
"max_completion_tokens",
|
||||||
|
"max_tokens",
|
||||||
|
"seed",
|
||||||
|
"logprobs",
|
||||||
|
"top_logprobs",
|
||||||
|
"reasoning_effort",
|
||||||
|
},
|
||||||
|
exclude_none=True,
|
||||||
|
)
|
||||||
|
params["messages"] = messages
|
||||||
params.update(self.additional_params)
|
params.update(self.additional_params)
|
||||||
|
|
||||||
if self.temperature is not None:
|
|
||||||
params["temperature"] = self.temperature
|
|
||||||
if self.top_p is not None:
|
|
||||||
params["top_p"] = self.top_p
|
|
||||||
if self.frequency_penalty is not None:
|
|
||||||
params["frequency_penalty"] = self.frequency_penalty
|
|
||||||
if self.presence_penalty is not None:
|
|
||||||
params["presence_penalty"] = self.presence_penalty
|
|
||||||
if self.max_completion_tokens is not None:
|
|
||||||
params["max_completion_tokens"] = self.max_completion_tokens
|
|
||||||
elif self.max_tokens is not None:
|
|
||||||
params["max_tokens"] = self.max_tokens
|
|
||||||
if self.seed is not None:
|
|
||||||
params["seed"] = self.seed
|
|
||||||
if self.logprobs is not None:
|
|
||||||
params["logprobs"] = self.logprobs
|
|
||||||
if self.top_logprobs is not None:
|
|
||||||
params["top_logprobs"] = self.top_logprobs
|
|
||||||
|
|
||||||
# Handle o1 model specific parameters
|
|
||||||
if self.is_o1_model and self.reasoning_effort:
|
|
||||||
params["reasoning_effort"] = self.reasoning_effort
|
|
||||||
|
|
||||||
if tools:
|
if tools:
|
||||||
params["tools"] = self._convert_tools_for_interference(tools)
|
params["tools"] = self._convert_tools_for_interference(tools)
|
||||||
params["tool_choice"] = "auto"
|
params["tool_choice"] = "auto"
|
||||||
@@ -296,14 +321,14 @@ class OpenAICompletion(BaseLLM):
|
|||||||
self,
|
self,
|
||||||
params: dict[str, Any],
|
params: dict[str, Any],
|
||||||
available_functions: dict[str, Any] | None = None,
|
available_functions: dict[str, Any] | None = None,
|
||||||
from_task: Any | None = None,
|
from_task: Task | None = None,
|
||||||
from_agent: Any | None = None,
|
from_agent: Agent | None = None,
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
) -> str | Any:
|
) -> str | Any:
|
||||||
"""Handle non-streaming chat completion."""
|
"""Handle non-streaming chat completion."""
|
||||||
try:
|
try:
|
||||||
if response_model:
|
if response_model:
|
||||||
parsed_response = self.client.beta.chat.completions.parse(
|
parsed_response = self._client.beta.chat.completions.parse(
|
||||||
**params,
|
**params,
|
||||||
response_format=response_model,
|
response_format=response_model,
|
||||||
)
|
)
|
||||||
@@ -327,7 +352,7 @@ class OpenAICompletion(BaseLLM):
|
|||||||
)
|
)
|
||||||
return structured_json
|
return structured_json
|
||||||
|
|
||||||
response: ChatCompletion = self.client.chat.completions.create(**params)
|
response: ChatCompletion = self._client.chat.completions.create(**params)
|
||||||
|
|
||||||
usage = self._extract_openai_token_usage(response)
|
usage = self._extract_openai_token_usage(response)
|
||||||
|
|
||||||
@@ -419,8 +444,8 @@ class OpenAICompletion(BaseLLM):
|
|||||||
self,
|
self,
|
||||||
params: dict[str, Any],
|
params: dict[str, Any],
|
||||||
available_functions: dict[str, Any] | None = None,
|
available_functions: dict[str, Any] | None = None,
|
||||||
from_task: Any | None = None,
|
from_task: Task | None = None,
|
||||||
from_agent: Any | None = None,
|
from_agent: Agent | None = None,
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Handle streaming chat completion."""
|
"""Handle streaming chat completion."""
|
||||||
@@ -429,7 +454,7 @@ class OpenAICompletion(BaseLLM):
|
|||||||
|
|
||||||
if response_model:
|
if response_model:
|
||||||
completion_stream: Iterator[ChatCompletionChunk] = (
|
completion_stream: Iterator[ChatCompletionChunk] = (
|
||||||
self.client.chat.completions.create(**params)
|
self._client.chat.completions.create(**params)
|
||||||
)
|
)
|
||||||
|
|
||||||
accumulated_content = ""
|
accumulated_content = ""
|
||||||
@@ -472,7 +497,7 @@ class OpenAICompletion(BaseLLM):
|
|||||||
)
|
)
|
||||||
return accumulated_content
|
return accumulated_content
|
||||||
|
|
||||||
stream: Iterator[ChatCompletionChunk] = self.client.chat.completions.create(
|
stream: Iterator[ChatCompletionChunk] = self._client.chat.completions.create(
|
||||||
**params
|
**params
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -550,58 +575,31 @@ class OpenAICompletion(BaseLLM):
|
|||||||
|
|
||||||
return full_response
|
return full_response
|
||||||
|
|
||||||
def supports_function_calling(self) -> bool:
|
|
||||||
"""Check if the model supports function calling."""
|
|
||||||
return not self.is_o1_model
|
|
||||||
|
|
||||||
def supports_stop_words(self) -> bool:
|
|
||||||
"""Check if the model supports stop words."""
|
|
||||||
return not self.is_o1_model
|
|
||||||
|
|
||||||
def get_context_window_size(self) -> int:
|
def get_context_window_size(self) -> int:
|
||||||
"""Get the context window size for the model."""
|
"""Get the context window size for the model."""
|
||||||
from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
|
|
||||||
|
|
||||||
min_context = 1024
|
|
||||||
max_context = 2097152
|
|
||||||
|
|
||||||
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
|
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
|
||||||
if value < min_context or value > max_context:
|
if value < MIN_CONTEXT_WINDOW or value > MAX_CONTEXT_WINDOW:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Context window for {key} must be between {min_context} and {max_context}"
|
f"Context window for {key} must be between {MIN_CONTEXT_WINDOW} and {MAX_CONTEXT_WINDOW}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Context window sizes for OpenAI models
|
|
||||||
context_windows = {
|
|
||||||
"gpt-4": 8192,
|
|
||||||
"gpt-4o": 128000,
|
|
||||||
"gpt-4o-mini": 200000,
|
|
||||||
"gpt-4-turbo": 128000,
|
|
||||||
"gpt-4.1": 1047576,
|
|
||||||
"gpt-4.1-mini-2025-04-14": 1047576,
|
|
||||||
"gpt-4.1-nano-2025-04-14": 1047576,
|
|
||||||
"o1-preview": 128000,
|
|
||||||
"o1-mini": 128000,
|
|
||||||
"o3-mini": 200000,
|
|
||||||
"o4-mini": 200000,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Find the best match for the model name
|
# Find the best match for the model name
|
||||||
for model_prefix, size in context_windows.items():
|
for model_prefix, size in OPENAI_CONTEXT_WINDOWS.items():
|
||||||
if self.model.startswith(model_prefix):
|
if self.model.startswith(model_prefix):
|
||||||
return int(size * CONTEXT_WINDOW_USAGE_RATIO)
|
return int(size * CONTEXT_WINDOW_USAGE_RATIO)
|
||||||
|
|
||||||
# Default context window size
|
# Default context window size
|
||||||
return int(8192 * CONTEXT_WINDOW_USAGE_RATIO)
|
return int(8192 * CONTEXT_WINDOW_USAGE_RATIO)
|
||||||
|
|
||||||
def _extract_openai_token_usage(self, response: ChatCompletion) -> dict[str, Any]:
|
@staticmethod
|
||||||
|
def _extract_openai_token_usage(response: ChatCompletion) -> dict[str, Any]:
|
||||||
"""Extract token usage from OpenAI ChatCompletion response."""
|
"""Extract token usage from OpenAI ChatCompletion response."""
|
||||||
if hasattr(response, "usage") and response.usage:
|
if response.usage:
|
||||||
usage = response.usage
|
usage = response.usage
|
||||||
return {
|
return {
|
||||||
"prompt_tokens": getattr(usage, "prompt_tokens", 0),
|
"prompt_tokens": usage.prompt_tokens,
|
||||||
"completion_tokens": getattr(usage, "completion_tokens", 0),
|
"completion_tokens": usage.completion_tokens,
|
||||||
"total_tokens": getattr(usage, "total_tokens", 0),
|
"total_tokens": usage.total_tokens,
|
||||||
}
|
}
|
||||||
return {"total_tokens": 0}
|
return {"total_tokens": 0}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user