Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-05-02 15:52:34 +00:00)
Lorenze/native inference sdks (#3619)
* ruff linted
* using native SDKs with LiteLLM fallback
* drop exa
* drop print on completion
* Refactor LLM and utility functions for type consistency
  - Updated the `max_tokens` parameter in the `LLM` class to accept `float` in addition to `int`.
  - Modified the `create_llm` function to use consistent type hints and return types; it now returns `LLM | BaseLLM | None`.
  - Adjusted type hints for various parameters in `create_llm` and `_llm_via_environment_or_fallback` for improved clarity and type safety.
  - Enhanced test cases to reflect the type-handling changes and ensure proper instantiation of LLM instances.
* fix agent_tests
* fix litellm tests and usage metrics
* drop print
* Refactor LLM event handling and improve test coverage
  - Removed commented-out event emission for LLM call failures in `llm.py`.
  - Added a `from_agent` parameter to `CrewAgentExecutor` for better context in LLM responses.
  - Enhanced the LLM call failure test to simulate an OpenAI API failure and updated assertions for clarity.
  - Updated agent and task ID assertions in tests so that IDs are consistently treated as strings.
* fix test_converter
* fixed tests/agents/test_agent.py
* Refactor LLM context length exception handling and improve provider integration
  - Renamed `LLMContextLengthExceededException` to `LLMContextLengthExceededError` for clarity and consistency.
  - Updated the LLM class to pass the provider parameter correctly during initialization.
  - Enhanced error handling in the LLM provider implementations to raise the new exception type.
  - Adjusted tests to reflect the updated exception name and ensure proper error handling in context length scenarios.
* Enhance LLM context window handling across providers
  - Introduced CONTEXT_WINDOW_USAGE_RATIO to adjust context window sizes dynamically for the Anthropic, Azure, Gemini, and OpenAI LLMs.
  - Added validation for context window sizes in the Azure and Gemini providers to ensure they fall within acceptable limits.
  - Updated context window size calculations to use the new ratio, improving consistency across models.
  - Removed hardcoded context window sizes in favor of ratio-based calculations for better flexibility.
* fix test agent again
* fix test agent
* feat: add native LLM providers for Anthropic, Azure, and Gemini
  - Introduced new completion implementations for Anthropic, Azure, and Gemini, integrating their respective SDKs.
  - Added utility functions for tool validation and extraction to support function calling across LLM providers.
  - Enhanced context window management and token usage extraction for each provider.
  - Created a common utility module for functionality shared among LLM providers.
* chore: update dependencies and improve context management
  - Removed the direct dependency on `litellm` from the main dependencies and moved it under extras for better modularity.
  - Relaxed the `litellm` dependency specification to allow greater flexibility in versioning.
  - Refactored context length exception handling across the LLM providers to use a consistent error class.
  - Enhanced platform-specific dependency markers for NVIDIA packages to ensure compatibility across systems.
* refactor(tests): update LLM instantiation to include the is_litellm flag in test cases
  - Modified multiple test cases in test_llm.py to set is_litellm=True when instantiating the LLM class, aligning the tests with the latest LLM configuration requirements.
  - Adjusted relevant assertions and comments to reflect the updated LLM behavior.
* linter
* linted
* revert constants
* fix(tests): correct type hint in expected model description
  - Updated the expected description in test_generate_model_description_dict_field to use 'Dict' instead of 'dict' for consistency with type hinting conventions, so the test accurately reflects the expected output format.
* refactor(llm): enhance LLM instantiation and error handling
  - Added validation for the model parameter in the LLM class, ensuring it is a non-empty string.
  - Improved error handling by logging a warning when the native SDK fails and falling back to LiteLLM.
  - Adjusted LLM instantiation in test cases to consistently include the is_litellm flag, and updated the affected tests for better coverage and accuracy.
* fixed test
* refactor(llm): enhance token usage tracking and add copy methods
  - Updated the LLM class to track token usage and log callbacks in streaming mode, improving monitoring.
  - Introduced shallow and deep copy methods for the LLM instance, allowing better management of LLM configurations and parameters.
  - Adjusted test cases to instantiate LLM with the is_litellm flag.
* refactor(tests): reorganize imports and enhance error messages in test cases
  - Cleaned up import statements in test_crew.py for better organization and readability.
  - Enhanced error messages in test cases to use `re.escape` for more robust regex matching.
  - Adjusted comments for clarity and consistency across test scenarios.
  - Ensured that all necessary modules are imported correctly to avoid runtime issues.
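For context, here is a minimal sketch of how the new routing is exercised from user code, based on the `LLM.__new__` factory in the diff below. The model strings, the `is_litellm` flag, and the provider prefixes come from the diff itself; the presence of API keys and SDK extras is assumed for illustration:

```python
from crewai.llm import LLM

# "provider/model" strings route to a native SDK class when that SDK is installed;
# bare model names default to the "openai" provider.
claude = LLM(model="anthropic/claude-3-5-sonnet-20241022")   # -> AnthropicCompletion
gpt = LLM(model="gpt-4o")                                    # -> OpenAICompletion

# is_litellm=True skips the native path and uses the LiteLLM fallback,
# which raises ImportError when the litellm extra is not installed.
fallback = LLM(model="openai/gpt-4o", is_litellm=True)
```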
@@ -114,7 +114,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
self.messages: list[dict[str, str]] = []
|
||||
self.iterations = 0
|
||||
self.log_error_after = 3
|
||||
existing_stop = self.llm.stop or []
|
||||
existing_stop = getattr(self.llm, "stop", [])
|
||||
self.llm.stop = list(
|
||||
set(
|
||||
existing_stop + self.stop
|
||||
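A small runnable sketch of what this hunk changes; `NativeLLMStub` is an illustrative stand-in, not a real CrewAI class:

```python
class NativeLLMStub:
    """Stand-in for a native-SDK LLM that may not expose a `stop` attribute."""

llm = NativeLLMStub()
executor_stop = ["\nObservation:"]

# The old code read `self.llm.stop or []`, assuming the attribute existed;
# getattr tolerates providers that never set it, and set() dedupes the merge.
existing_stop = getattr(llm, "stop", []) or []
llm.stop = list(set(existing_stop + executor_stop))
print(llm.stop)  # ['\nObservation:']
```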
@@ -192,6 +192,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
callbacks=self.callbacks,
|
||||
printer=self._printer,
|
||||
from_task=self.task,
|
||||
from_agent=self.agent,
|
||||
)
|
||||
formatted_answer = process_llm_response(answer, self.use_stop_words)
|
||||
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import Future
|
||||
from copy import copy as shallow_copy
|
||||
from hashlib import md5
|
||||
import json
|
||||
import re
|
||||
from typing import (
|
||||
Any,
|
||||
cast,
|
||||
)
|
||||
import uuid
|
||||
import warnings
|
||||
|
||||
from opentelemetry import baggage
|
||||
from opentelemetry.context import attach, detach
|
||||
@@ -82,6 +82,7 @@ from crewai.utilities.planning_handler import CrewPlanner
|
||||
from crewai.utilities.task_output_storage_handler import TaskOutputStorageHandler
|
||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
|
||||
|
||||
warnings.filterwarnings("ignore", category=SyntaxWarning, module="pysbd")
|
||||
|
||||
|
||||
@@ -1353,13 +1354,34 @@ class Crew(FlowTrackable, BaseModel):
|
||||
def calculate_usage_metrics(self) -> UsageMetrics:
|
||||
"""Calculates and returns the usage metrics."""
|
||||
total_usage_metrics = UsageMetrics()
|
||||
|
||||
for agent in self.agents:
|
||||
if hasattr(agent, "_token_process"):
|
||||
token_sum = agent._token_process.get_summary()
|
||||
total_usage_metrics.add_usage_metrics(token_sum)
|
||||
if isinstance(agent.llm, BaseLLM):
|
||||
llm_usage = agent.llm.get_token_usage_summary()
|
||||
|
||||
total_usage_metrics.add_usage_metrics(llm_usage)
|
||||
else:
|
||||
# fallback litellm
|
||||
if hasattr(agent, "_token_process"):
|
||||
token_sum = agent._token_process.get_summary()
|
||||
total_usage_metrics.add_usage_metrics(token_sum)
|
||||
|
||||
if self.manager_agent and hasattr(self.manager_agent, "_token_process"):
|
||||
token_sum = self.manager_agent._token_process.get_summary()
|
||||
total_usage_metrics.add_usage_metrics(token_sum)
|
||||
|
||||
if (
|
||||
self.manager_agent
|
||||
and hasattr(self.manager_agent, "llm")
|
||||
and hasattr(self.manager_agent.llm, "get_token_usage_summary")
|
||||
):
|
||||
if isinstance(self.manager_agent.llm, BaseLLM):
|
||||
llm_usage = self.manager_agent.llm.get_token_usage_summary()
|
||||
else:
|
||||
llm_usage = self.manager_agent.llm._token_process.get_summary()
|
||||
|
||||
total_usage_metrics.add_usage_metrics(llm_usage)
|
||||
|
||||
self.usage_metrics = total_usage_metrics
|
||||
return total_usage_metrics
|
||||
|
||||
|
||||
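A rough sketch of the aggregation this method performs, assuming `UsageMetrics` accepts the token fields used elsewhere in this diff (`prompt_tokens`, `completion_tokens`, `total_tokens`, `successful_requests`) as keyword arguments and that `add_usage_metrics` sums them field by field:

```python
from crewai.types.usage_metrics import UsageMetrics

total = UsageMetrics()
per_agent_summaries = [
    UsageMetrics(prompt_tokens=90, completion_tokens=30, total_tokens=120, successful_requests=1),
    UsageMetrics(prompt_tokens=50, completion_tokens=30, total_tokens=80, successful_requests=1),
]
for summary in per_agent_summaries:
    total.add_usage_metrics(summary)
# crew.usage_metrics ends up holding the summed values (total_tokens=200 here).
```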
@@ -17,6 +17,11 @@ class BaseEvent(BaseModel):
|
||||
)
|
||||
fingerprint_metadata: dict[str, Any] | None = None # Any relevant metadata
|
||||
|
||||
task_id: str | None = None
|
||||
task_name: str | None = None
|
||||
agent_id: str | None = None
|
||||
agent_role: str | None = None
|
||||
|
||||
def to_json(self, exclude: set[str] | None = None):
|
||||
"""
|
||||
Converts the event to a JSON-serializable dictionary.
|
||||
@@ -31,7 +36,7 @@ class BaseEvent(BaseModel):
|
||||
|
||||
def _set_task_params(self, data: dict[str, Any]):
|
||||
if "from_task" in data and (task := data["from_task"]):
|
||||
self.task_id = task.id
|
||||
self.task_id = str(task.id)
|
||||
self.task_name = task.name or task.description
|
||||
self.from_task = None
|
||||
|
||||
@@ -42,6 +47,6 @@ class BaseEvent(BaseModel):
|
||||
if not agent:
|
||||
return
|
||||
|
||||
self.agent_id = agent.id
|
||||
self.agent_id = str(agent.id)
|
||||
self.agent_role = agent.role
|
||||
self.from_agent = None
|
||||
|
||||
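The two one-line changes above cast IDs to strings before storing them on the event. A quick illustration of why this matters, assuming task and agent IDs are UUIDs as elsewhere in CrewAI:

```python
import json
import uuid

task_id = uuid.uuid4()
# json.dumps({"task_id": task_id}) raises TypeError (UUID is not JSON serializable),
# so BaseEvent stores str(task.id) / str(agent.id) instead.
payload = json.dumps({"task_id": str(task_id)})
```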
@@ -7,19 +7,23 @@ from crewai.events.base_events import BaseEvent
|
||||
|
||||
|
||||
class LLMEventBase(BaseEvent):
|
||||
task_name: str | None = None
|
||||
task_id: str | None = None
|
||||
|
||||
agent_id: str | None = None
|
||||
agent_role: str | None = None
|
||||
|
||||
from_task: Any | None = None
|
||||
from_agent: Any | None = None
|
||||
|
||||
def __init__(self, **data):
|
||||
if data.get("from_task"):
|
||||
task = data["from_task"]
|
||||
data["task_id"] = str(task.id)
|
||||
data["task_name"] = task.name or task.description
|
||||
data["from_task"] = None
|
||||
|
||||
if data.get("from_agent"):
|
||||
agent = data["from_agent"]
|
||||
data["agent_id"] = str(agent.id)
|
||||
data["agent_role"] = agent.role
|
||||
data["from_agent"] = None
|
||||
|
||||
super().__init__(**data)
|
||||
self._set_agent_params(data)
|
||||
self._set_task_params(data)
|
||||
|
||||
|
||||
class LLMCallType(Enum):
|
||||
|
||||
@@ -27,9 +27,20 @@ class ToolUsageEvent(BaseEvent):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def __init__(self, **data):
|
||||
if data.get("from_task"):
|
||||
task = data["from_task"]
|
||||
data["task_id"] = str(task.id)
|
||||
data["task_name"] = task.name or task.description
|
||||
data["from_task"] = None
|
||||
|
||||
if data.get("from_agent"):
|
||||
agent = data["from_agent"]
|
||||
data["agent_id"] = str(agent.id)
|
||||
data["agent_role"] = agent.role
|
||||
data["from_agent"] = None
|
||||
|
||||
super().__init__(**data)
|
||||
self._set_agent_params(data)
|
||||
self._set_task_params(data)
|
||||
|
||||
# Set fingerprint data from the agent
|
||||
if self.agent and hasattr(self.agent, "fingerprint") and self.agent.fingerprint:
|
||||
self.source_fingerprint = self.agent.fingerprint.uuid_str
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
import inspect
|
||||
from typing import (
|
||||
Any,
|
||||
cast,
|
||||
get_args,
|
||||
get_origin,
|
||||
)
|
||||
import uuid
|
||||
|
||||
from pydantic import (
|
||||
UUID4,
|
||||
@@ -351,7 +351,10 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
)
|
||||
|
||||
# Calculate token usage metrics
|
||||
usage_metrics = self._token_process.get_summary()
|
||||
if isinstance(self.llm, BaseLLM):
|
||||
usage_metrics = self.llm.get_token_usage_summary()
|
||||
else:
|
||||
usage_metrics = self._token_process.get_summary()
|
||||
|
||||
# Create output
|
||||
output = LiteAgentOutput(
|
||||
@@ -400,7 +403,10 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
elif isinstance(guardrail_result.result, BaseModel):
|
||||
output.pydantic = guardrail_result.result
|
||||
|
||||
usage_metrics = self._token_process.get_summary()
|
||||
if isinstance(self.llm, BaseLLM):
|
||||
usage_metrics = self.llm.get_token_usage_summary()
|
||||
else:
|
||||
usage_metrics = self._token_process.get_summary()
|
||||
output.usage_metrics = usage_metrics.model_dump() if usage_metrics else None
|
||||
|
||||
# Emit completion event
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Final,
|
||||
Literal,
|
||||
@@ -17,7 +18,6 @@ from typing import (
|
||||
)
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from litellm.types.utils import ChatCompletionDeltaToolCall
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
@@ -39,19 +39,42 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
)
|
||||
from crewai.utilities.logger_utils import suppress_warnings
|
||||
|
||||
with suppress_warnings():
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm import Choices
|
||||
from litellm.exceptions import ContextWindowExceededError
|
||||
from litellm.litellm_core_utils.get_supported_openai_params import (
|
||||
get_supported_openai_params,
|
||||
)
|
||||
from litellm.types.utils import ChatCompletionDeltaToolCall, ModelResponse
|
||||
from litellm.utils import supports_response_schema
|
||||
|
||||
try:
|
||||
import litellm
|
||||
from litellm import Choices, CustomLogger
|
||||
from litellm.exceptions import ContextWindowExceededError
|
||||
from litellm.litellm_core_utils.get_supported_openai_params import (
|
||||
get_supported_openai_params,
|
||||
)
|
||||
from litellm.types.utils import ModelResponse
|
||||
from litellm.types.utils import ChatCompletionDeltaToolCall, ModelResponse
|
||||
from litellm.utils import supports_response_schema
|
||||
|
||||
LITELLM_AVAILABLE = True
|
||||
except ImportError:
|
||||
LITELLM_AVAILABLE = False
|
||||
litellm = None # type: ignore
|
||||
Choices = None # type: ignore
|
||||
ContextWindowExceededError = Exception # type: ignore
|
||||
get_supported_openai_params = None # type: ignore
|
||||
ChatCompletionDeltaToolCall = None # type: ignore
|
||||
ModelResponse = None # type: ignore
|
||||
supports_response_schema = None # type: ignore
|
||||
|
||||
|
||||
load_dotenv()
|
||||
|
||||
litellm.suppress_debug_info = True
|
||||
if LITELLM_AVAILABLE:
|
||||
litellm.suppress_debug_info = True
|
||||
|
||||
|
||||
class FilteredStream(io.TextIOBase):
|
||||
@@ -275,6 +298,77 @@ class AccumulatedToolArgs(BaseModel):
|
||||
class LLM(BaseLLM):
|
||||
completion_cost: float | None = None
|
||||
|
||||
def __new__(cls, model: str, is_litellm: bool = False, **kwargs) -> "LLM":
|
||||
"""Factory method that routes to native SDK or falls back to LiteLLM."""
|
||||
if not model or not isinstance(model, str):
|
||||
raise ValueError("Model must be a non-empty string")
|
||||
|
||||
provider = model.partition("/")[0] if "/" in model else "openai"
|
||||
|
||||
native_class = cls._get_native_provider(provider)
|
||||
if native_class and not is_litellm:
|
||||
try:
|
||||
model_string = model.partition("/")[2] if "/" in model else model
|
||||
return native_class(model=model_string, provider=provider, **kwargs)
|
||||
except Exception as e:
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(
|
||||
f"Native SDK failed for {provider}: {e}, falling back to LiteLLM"
|
||||
)
|
||||
|
||||
# FALLBACK to LiteLLM
|
||||
if not LITELLM_AVAILABLE:
|
||||
raise ImportError(
|
||||
"Please install the required dependencies:\n"
|
||||
"- For LiteLLM: uv add litellm"
|
||||
)
|
||||
|
||||
instance = object.__new__(cls)
|
||||
super(LLM, instance).__init__(model=model, is_litellm=True, **kwargs)
|
||||
instance.is_litellm = True
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
def _get_native_provider(cls, provider: str) -> type | None:
|
||||
"""Get native provider class if available."""
|
||||
if provider == "openai":
|
||||
try:
|
||||
from crewai.llms.providers.openai.completion import OpenAICompletion
|
||||
|
||||
return OpenAICompletion
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
elif provider == "anthropic" or provider == "claude":
|
||||
try:
|
||||
from crewai.llms.providers.anthropic.completion import (
|
||||
AnthropicCompletion,
|
||||
)
|
||||
|
||||
return AnthropicCompletion
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
elif provider == "azure":
|
||||
try:
|
||||
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||
|
||||
return AzureCompletion
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
elif provider == "google" or provider == "gemini":
|
||||
try:
|
||||
from crewai.llms.providers.gemini.completion import GeminiCompletion
|
||||
|
||||
return GeminiCompletion
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
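Taken together, `__new__` and `_get_native_provider` behave as a factory: the provider prefix of the model string picks a native completion class, and anything else (or `is_litellm=True`) falls through to the LiteLLM-backed `LLM`. A standalone sketch of that pattern follows; the class names are illustrative, not the real providers:

```python
class BaseChat:
    def __init__(self, model: str, **kwargs):
        self.model = model

class OpenAIChat(BaseChat): ...
class AnthropicChat(BaseChat): ...

class Chat(BaseChat):
    _providers = {"openai": OpenAIChat, "anthropic": AnthropicChat, "claude": AnthropicChat}

    def __new__(cls, model: str, force_fallback: bool = False, **kwargs):
        provider, _, rest = model.partition("/") if "/" in model else ("openai", "", model)
        native = cls._providers.get(provider)
        if native and not force_fallback:
            return native(model=rest, **kwargs)  # native SDK path: construct and return directly
        return object.__new__(cls)               # fallback path: __init__ runs afterwards

print(type(Chat("anthropic/claude-3-5-sonnet")).__name__)   # AnthropicChat
print(type(Chat("gpt-4o")).__name__)                         # OpenAIChat (bare names default to openai)
print(type(Chat("gpt-4o", force_fallback=True)).__name__)    # Chat (fallback)
```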
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
@@ -284,7 +378,7 @@ class LLM(BaseLLM):
|
||||
n: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
max_tokens: int | float | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[int, float] | None = None,
|
||||
@@ -301,6 +395,11 @@ class LLM(BaseLLM):
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize LLM instance.
|
||||
|
||||
Note: This __init__ method is only called for fallback instances.
|
||||
Native provider instances handle their own initialization in their respective classes.
|
||||
"""
|
||||
self.model = model
|
||||
self.timeout = timeout
|
||||
self.temperature = temperature
|
||||
@@ -328,7 +427,7 @@ class LLM(BaseLLM):
|
||||
|
||||
litellm.drop_params = True
|
||||
|
||||
# Normalize self.stop to always be a List[str]
|
||||
# Normalize self.stop to always be a list[str]
|
||||
if stop is None:
|
||||
self.stop: list[str] = []
|
||||
elif isinstance(stop, str):
|
||||
@@ -349,7 +448,8 @@ class LLM(BaseLLM):
|
||||
Returns:
|
||||
bool: True if the model is from Anthropic, False otherwise.
|
||||
"""
|
||||
return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES)
|
||||
anthropic_prefixes = ("anthropic/", "claude-", "claude/")
|
||||
return any(prefix in model.lower() for prefix in anthropic_prefixes)
|
||||
|
||||
def _prepare_completion_params(
|
||||
self,
|
||||
@@ -514,10 +614,6 @@ class LLM(BaseLLM):
|
||||
# Add the chunk content to the full response
|
||||
full_response += chunk_content
|
||||
|
||||
# Emit the chunk event
|
||||
if not hasattr(crewai_event_bus, "emit"):
|
||||
raise Exception("crewai_event_bus must have an `emit` method")
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMStreamChunkEvent(
|
||||
@@ -623,7 +719,9 @@ class LLM(BaseLLM):
|
||||
# --- 8) If no tool calls or no available functions, return the text response directly
|
||||
|
||||
if not tool_calls or not available_functions:
|
||||
# Log token usage if available in streaming mode
|
||||
# Track token usage and log callbacks if available in streaming mode
|
||||
if usage_info:
|
||||
self._track_token_usage_internal(usage_info)
|
||||
self._handle_streaming_callbacks(callbacks, usage_info, last_chunk)
|
||||
# Emit completion event and return response
|
||||
self._handle_emit_call_events(
|
||||
@@ -640,7 +738,9 @@ class LLM(BaseLLM):
|
||||
if tool_result is not None:
|
||||
return tool_result
|
||||
|
||||
# --- 10) Log token usage if available in streaming mode
|
||||
# --- 10) Track token usage and log callbacks if available in streaming mode
|
||||
if usage_info:
|
||||
self._track_token_usage_internal(usage_info)
|
||||
self._handle_streaming_callbacks(callbacks, usage_info, last_chunk)
|
||||
|
||||
# --- 11) Emit completion event and return response
|
||||
@@ -671,11 +771,6 @@ class LLM(BaseLLM):
|
||||
)
|
||||
return full_response
|
||||
|
||||
# Emit failed event and re-raise the exception
|
||||
if not hasattr(crewai_event_bus, "emit"):
|
||||
raise AttributeError(
|
||||
"crewai_event_bus must have an 'emit' method"
|
||||
) from e
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallFailedEvent(
|
||||
@@ -702,8 +797,7 @@ class LLM(BaseLLM):
|
||||
current_tool_accumulator.function.arguments += (
|
||||
tool_call.function.arguments
|
||||
)
|
||||
if not hasattr(crewai_event_bus, "emit"):
|
||||
raise AttributeError("crewai_event_bus must have an 'emit' method")
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMStreamChunkEvent(
|
||||
@@ -832,6 +926,7 @@ class LLM(BaseLLM):
|
||||
messages=params["messages"],
|
||||
)
|
||||
return text_response
|
||||
|
||||
# --- 6) If there is no text response, no available functions, but there are tool calls, return the tool calls
|
||||
if tool_calls and not available_functions and not text_response:
|
||||
return tool_calls
|
||||
@@ -886,9 +981,6 @@ class LLM(BaseLLM):
|
||||
function_args = json.loads(tool_call.function.arguments)
|
||||
fn = available_functions[function_name]
|
||||
|
||||
# --- 3.2) Execute function
|
||||
if not hasattr(crewai_event_bus, "emit"):
|
||||
raise AttributeError("crewai_event_bus must have an 'emit' method")
|
||||
started_at = datetime.now()
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -928,10 +1020,6 @@ class LLM(BaseLLM):
|
||||
function_name, lambda: None
|
||||
) # Ensure fn is always a callable
|
||||
logging.error(f"Error executing function '{function_name}': {e}")
|
||||
if not hasattr(crewai_event_bus, "emit"):
|
||||
raise AttributeError(
|
||||
"crewai_event_bus must have an 'emit' method"
|
||||
) from e
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallFailedEvent(error=f"Tool execution error: {e!s}"),
|
||||
@@ -982,9 +1070,6 @@ class LLM(BaseLLM):
|
||||
ValueError: If response format is not supported
|
||||
LLMContextLengthExceededError: If input exceeds model's context limit
|
||||
"""
|
||||
# --- 1) Emit call started event
|
||||
if not hasattr(crewai_event_bus, "emit"):
|
||||
raise AttributeError("crewai_event_bus must have an 'emit' method")
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallStartedEvent(
|
||||
@@ -1021,10 +1106,10 @@ class LLM(BaseLLM):
|
||||
return self._handle_streaming_response(
|
||||
params, callbacks, available_functions, from_task, from_agent
|
||||
)
|
||||
|
||||
return self._handle_non_streaming_response(
|
||||
params, callbacks, available_functions, from_task, from_agent
|
||||
)
|
||||
|
||||
except LLMContextLengthExceededError:
|
||||
# Re-raise LLMContextLengthExceededError as it should be handled
|
||||
# by the CrewAgentExecutor._invoke_loop method, which can then decide
|
||||
@@ -1057,10 +1142,6 @@ class LLM(BaseLLM):
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if not hasattr(crewai_event_bus, "emit"):
|
||||
raise AttributeError(
|
||||
"crewai_event_bus must have an 'emit' method"
|
||||
) from e
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallFailedEvent(
|
||||
@@ -1086,8 +1167,6 @@ class LLM(BaseLLM):
|
||||
from_agent: Optional agent object
|
||||
messages: Optional messages object
|
||||
"""
|
||||
if not hasattr(crewai_event_bus, "emit"):
|
||||
raise AttributeError("crewai_event_bus must have an 'emit' method")
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallCompletedEvent(
|
||||
@@ -1225,11 +1304,14 @@ class LLM(BaseLLM):
|
||||
if self.context_window_size != 0:
|
||||
return self.context_window_size
|
||||
|
||||
min_context = 1024
|
||||
max_context = 2097152 # Current max from gemini-1.5-pro
|
||||
|
||||
# Validate all context window sizes
|
||||
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
|
||||
if value < MIN_CONTEXT or value > MAX_CONTEXT:
|
||||
if value < min_context or value > max_context:
|
||||
raise ValueError(
|
||||
f"Context window for {key} must be between {MIN_CONTEXT} and {MAX_CONTEXT}"
|
||||
f"Context window for {key} must be between {min_context} and {max_context}"
|
||||
)
|
||||
|
||||
self.context_window_size = int(
|
||||
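The hunk above validates every entry of `LLM_CONTEXT_WINDOW_SIZES` and then scales the selected size by `CONTEXT_WINDOW_USAGE_RATIO`. The ratio's actual value is not shown in this excerpt, so 0.75 below is only a placeholder, and the table entries are illustrative:

```python
CONTEXT_WINDOW_USAGE_RATIO = 0.75  # assumed value, for illustration only
LLM_CONTEXT_WINDOW_SIZES = {"gpt-4o": 128_000, "gemini-1.5-pro": 2_097_152}

min_context = 1024
max_context = 2_097_152  # current max from gemini-1.5-pro
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
    if value < min_context or value > max_context:
        raise ValueError(f"Context window for {key} must be between {min_context} and {max_context}")

# Usable window = raw window scaled down to leave headroom for the response.
usable = int(LLM_CONTEXT_WINDOW_SIZES["gpt-4o"] * CONTEXT_WINDOW_USAGE_RATIO)  # 96000
```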
@@ -1293,3 +1375,129 @@ class LLM(BaseLLM):
|
||||
|
||||
litellm.success_callback = success_callbacks
|
||||
litellm.failure_callback = failure_callbacks
|
||||
|
||||
def __copy__(self):
|
||||
"""Create a shallow copy of the LLM instance."""
|
||||
# Filter out parameters that are already explicitly passed to avoid conflicts
|
||||
filtered_params = {
|
||||
k: v
|
||||
for k, v in self.additional_params.items()
|
||||
if k
|
||||
not in [
|
||||
"model",
|
||||
"is_litellm",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"n",
|
||||
"max_completion_tokens",
|
||||
"max_tokens",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
"logit_bias",
|
||||
"response_format",
|
||||
"seed",
|
||||
"logprobs",
|
||||
"top_logprobs",
|
||||
"base_url",
|
||||
"api_base",
|
||||
"api_version",
|
||||
"api_key",
|
||||
"callbacks",
|
||||
"reasoning_effort",
|
||||
"stream",
|
||||
"stop",
|
||||
]
|
||||
}
|
||||
|
||||
# Create a new instance with the same parameters
|
||||
return LLM(
|
||||
model=self.model,
|
||||
is_litellm=self.is_litellm,
|
||||
temperature=self.temperature,
|
||||
top_p=self.top_p,
|
||||
n=self.n,
|
||||
max_completion_tokens=self.max_completion_tokens,
|
||||
max_tokens=self.max_tokens,
|
||||
presence_penalty=self.presence_penalty,
|
||||
frequency_penalty=self.frequency_penalty,
|
||||
logit_bias=self.logit_bias,
|
||||
response_format=self.response_format,
|
||||
seed=self.seed,
|
||||
logprobs=self.logprobs,
|
||||
top_logprobs=self.top_logprobs,
|
||||
base_url=self.base_url,
|
||||
api_base=self.api_base,
|
||||
api_version=self.api_version,
|
||||
api_key=self.api_key,
|
||||
callbacks=self.callbacks,
|
||||
reasoning_effort=self.reasoning_effort,
|
||||
stream=self.stream,
|
||||
stop=self.stop,
|
||||
**filtered_params,
|
||||
)
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
"""Create a deep copy of the LLM instance."""
|
||||
import copy
|
||||
|
||||
# Filter out parameters that are already explicitly passed to avoid conflicts
|
||||
filtered_params = {
|
||||
k: copy.deepcopy(v, memo)
|
||||
for k, v in self.additional_params.items()
|
||||
if k
|
||||
not in [
|
||||
"model",
|
||||
"is_litellm",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"n",
|
||||
"max_completion_tokens",
|
||||
"max_tokens",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
"logit_bias",
|
||||
"response_format",
|
||||
"seed",
|
||||
"logprobs",
|
||||
"top_logprobs",
|
||||
"base_url",
|
||||
"api_base",
|
||||
"api_version",
|
||||
"api_key",
|
||||
"callbacks",
|
||||
"reasoning_effort",
|
||||
"stream",
|
||||
"stop",
|
||||
]
|
||||
}
|
||||
|
||||
# Create a new instance with the same parameters
|
||||
return LLM(
|
||||
model=self.model,
|
||||
is_litellm=self.is_litellm,
|
||||
temperature=self.temperature,
|
||||
top_p=self.top_p,
|
||||
n=self.n,
|
||||
max_completion_tokens=self.max_completion_tokens,
|
||||
max_tokens=self.max_tokens,
|
||||
presence_penalty=self.presence_penalty,
|
||||
frequency_penalty=self.frequency_penalty,
|
||||
logit_bias=copy.deepcopy(self.logit_bias, memo)
|
||||
if self.logit_bias
|
||||
else None,
|
||||
response_format=copy.deepcopy(self.response_format, memo)
|
||||
if self.response_format
|
||||
else None,
|
||||
seed=self.seed,
|
||||
logprobs=self.logprobs,
|
||||
top_logprobs=self.top_logprobs,
|
||||
base_url=self.base_url,
|
||||
api_base=self.api_base,
|
||||
api_version=self.api_version,
|
||||
api_key=self.api_key,
|
||||
callbacks=copy.deepcopy(self.callbacks, memo) if self.callbacks else None,
|
||||
reasoning_effort=self.reasoning_effort,
|
||||
stream=self.stream,
|
||||
stop=copy.deepcopy(self.stop, memo) if self.stop else None,
|
||||
**filtered_params,
|
||||
)
|
||||
|
||||
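A brief sketch of what the new dunder methods enable, assuming the litellm extra is installed (these methods live on the LiteLLM-backed `LLM` class) and that the instance was constructed with `is_litellm=True`:

```python
import copy
from crewai.llm import LLM

base_llm = LLM(model="openai/gpt-4o", temperature=0.2, is_litellm=True)

clone = copy.copy(base_llm)            # __copy__: rebuilds an LLM with the same explicit parameters
independent = copy.deepcopy(base_llm)  # __deepcopy__: also deep-copies callbacks, logit_bias,
                                       # response_format and stop

independent.stop.append("Observation:")
assert "Observation:" not in base_llm.stop  # the deep copy's mutable state is isolated
```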
@@ -1,12 +1,33 @@
|
||||
"""Base LLM abstract class for CrewAI.
|
||||
|
||||
This module provides the abstract base class for all LLM implementations
|
||||
in CrewAI.
|
||||
in CrewAI, including common functionality for native SDK implementations.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Final
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.llm_events import (
|
||||
LLMCallCompletedEvent,
|
||||
LLMCallFailedEvent,
|
||||
LLMCallStartedEvent,
|
||||
LLMCallType,
|
||||
LLMStreamChunkEvent,
|
||||
)
|
||||
from crewai.events.types.tool_usage_events import (
|
||||
ToolUsageErrorEvent,
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageStartedEvent,
|
||||
)
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
|
||||
|
||||
DEFAULT_CONTEXT_WINDOW_SIZE: Final[int] = 4096
|
||||
DEFAULT_SUPPORTS_STOP_WORDS: Final[bool] = True
|
||||
|
||||
@@ -27,13 +48,20 @@ class BaseLLM(ABC):
|
||||
model: The model identifier/name.
|
||||
temperature: Optional temperature setting for response generation.
|
||||
stop: A list of stop sequences that the LLM should use to stop generation.
|
||||
additional_params: Additional provider-specific parameters.
|
||||
"""
|
||||
|
||||
is_litellm: bool = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
temperature: float | None = None,
|
||||
stop: list[str] | None = None,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
timeout: float | None = None,
|
||||
provider: str | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Initialize the BaseLLM with default attributes.
|
||||
|
||||
@@ -41,10 +69,44 @@ class BaseLLM(ABC):
|
||||
model: The model identifier/name.
|
||||
temperature: Optional temperature setting for response generation.
|
||||
stop: Optional list of stop sequences for generation.
|
||||
**kwargs: Additional provider-specific parameters.
|
||||
"""
|
||||
if not model:
|
||||
raise ValueError("Model name is required and cannot be empty")
|
||||
|
||||
self.model = model
|
||||
self.temperature = temperature
|
||||
self.stop: list[str] = stop or []
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
# Store additional parameters for provider-specific use
|
||||
self.additional_params = kwargs
|
||||
self._provider = provider or "openai"
|
||||
|
||||
stop = kwargs.pop("stop", None)
|
||||
if stop is None:
|
||||
self.stop: list[str] = []
|
||||
elif isinstance(stop, str):
|
||||
self.stop = [stop]
|
||||
else:
|
||||
self.stop = stop
|
||||
|
||||
self._token_usage = {
|
||||
"total_tokens": 0,
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"successful_requests": 0,
|
||||
"cached_prompt_tokens": 0,
|
||||
}
|
||||
|
||||
@property
|
||||
def provider(self) -> str:
|
||||
"""Get the provider of the LLM."""
|
||||
return self._provider
|
||||
|
||||
@provider.setter
|
||||
def provider(self, value: str) -> None:
|
||||
"""Set the provider of the LLM."""
|
||||
self._provider = value
|
||||
|
||||
@abstractmethod
|
||||
def call(
|
||||
@@ -82,6 +144,17 @@ class BaseLLM(ABC):
|
||||
RuntimeError: If the LLM request fails for other reasons.
|
||||
"""
|
||||
|
||||
def _convert_tools_for_interference(self, tools: list[dict]) -> list[dict]:
|
||||
"""Convert tools to a format that can be used for interference.
|
||||
|
||||
Args:
|
||||
tools: List of tools to convert.
|
||||
|
||||
Returns:
|
||||
List of converted tools (default implementation returns as-is)
|
||||
"""
|
||||
return tools
|
||||
|
||||
def supports_stop_words(self) -> bool:
|
||||
"""Check if the LLM supports stop words.
|
||||
|
||||
@@ -90,6 +163,58 @@ class BaseLLM(ABC):
|
||||
"""
|
||||
return DEFAULT_SUPPORTS_STOP_WORDS
|
||||
|
||||
def _supports_stop_words_implementation(self) -> bool:
|
||||
"""Check if stop words are configured for this LLM instance.
|
||||
|
||||
Native providers can override supports_stop_words() to return this value
|
||||
to ensure consistent behavior based on whether stop words are actually configured.
|
||||
|
||||
Returns:
|
||||
True if stop words are configured and can be applied
|
||||
"""
|
||||
return bool(self.stop)
|
||||
|
||||
def _apply_stop_words(self, content: str) -> str:
|
||||
"""Apply stop words to truncate response content.
|
||||
|
||||
This method provides consistent stop word behavior across all native SDK providers.
|
||||
Native providers should call this method to post-process their responses.
|
||||
|
||||
Args:
|
||||
content: The raw response content from the LLM
|
||||
|
||||
Returns:
|
||||
Content truncated at the first occurrence of any stop word
|
||||
|
||||
Example:
|
||||
>>> llm = MyNativeLLM(stop=["Observation:", "Final Answer:"])
|
||||
>>> response = "I need to search.\\n\\nAction: search\\nObservation: Found results"
|
||||
>>> llm._apply_stop_words(response)
|
||||
"I need to search.\\n\\nAction: search"
|
||||
"""
|
||||
if not self.stop or not content:
|
||||
return content
|
||||
|
||||
# Find the earliest occurrence of any stop word
|
||||
earliest_stop_pos = len(content)
|
||||
found_stop_word = None
|
||||
|
||||
for stop_word in self.stop:
|
||||
stop_pos = content.find(stop_word)
|
||||
if stop_pos != -1 and stop_pos < earliest_stop_pos:
|
||||
earliest_stop_pos = stop_pos
|
||||
found_stop_word = stop_word
|
||||
|
||||
# Truncate at the stop word if found
|
||||
if found_stop_word is not None:
|
||||
truncated = content[:earliest_stop_pos].strip()
|
||||
logging.debug(
|
||||
f"Applied stop word '{found_stop_word}' at position {earliest_stop_pos}"
|
||||
)
|
||||
return truncated
|
||||
|
||||
return content
|
||||
|
||||
def get_context_window_size(self) -> int:
|
||||
"""Get the context window size for the LLM.
|
||||
|
||||
@@ -98,3 +223,314 @@ class BaseLLM(ABC):
|
||||
"""
|
||||
# Default implementation - subclasses should override with model-specific values
|
||||
return DEFAULT_CONTEXT_WINDOW_SIZE
|
||||
|
||||
# Common helper methods for native SDK implementations
|
||||
|
||||
def _emit_call_started_event(
|
||||
self,
|
||||
messages: str | list[dict[str, str]],
|
||||
tools: list[dict] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> None:
|
||||
"""Emit LLM call started event."""
|
||||
if not hasattr(crewai_event_bus, "emit"):
|
||||
raise ValueError("crewai_event_bus does not have an emit method") from None
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallStartedEvent(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
model=self.model,
|
||||
),
|
||||
)
|
||||
|
||||
def _emit_call_completed_event(
|
||||
self,
|
||||
response: Any,
|
||||
call_type: LLMCallType,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
messages: str | list[dict[str, Any]] | None = None,
|
||||
) -> None:
|
||||
"""Emit LLM call completed event."""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallCompletedEvent(
|
||||
messages=messages,
|
||||
response=response,
|
||||
call_type=call_type,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
model=self.model,
|
||||
),
|
||||
)
|
||||
|
||||
def _emit_call_failed_event(
|
||||
self,
|
||||
error: str,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> None:
|
||||
"""Emit LLM call failed event."""
|
||||
if not hasattr(crewai_event_bus, "emit"):
|
||||
raise ValueError("crewai_event_bus does not have an emit method") from None
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallFailedEvent(
|
||||
error=error,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
def _emit_stream_chunk_event(
|
||||
self,
|
||||
chunk: str,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
tool_call: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Emit stream chunk event."""
|
||||
if not hasattr(crewai_event_bus, "emit"):
|
||||
raise ValueError("crewai_event_bus does not have an emit method") from None
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMStreamChunkEvent(
|
||||
chunk=chunk,
|
||||
tool_call=tool_call,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
def _handle_tool_execution(
|
||||
self,
|
||||
function_name: str,
|
||||
function_args: dict[str, Any],
|
||||
available_functions: dict[str, Any],
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str | None:
|
||||
"""Handle tool execution with proper event emission.
|
||||
|
||||
Args:
|
||||
function_name: Name of the function to execute
|
||||
function_args: Arguments to pass to the function
|
||||
available_functions: Dict of available functions
|
||||
from_task: Optional task object
|
||||
from_agent: Optional agent object
|
||||
|
||||
Returns:
|
||||
Result of function execution or None if function not found
|
||||
"""
|
||||
if function_name not in available_functions:
|
||||
logging.warning(
|
||||
f"Function '{function_name}' not found in available functions"
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
# Emit tool usage started event
|
||||
started_at = datetime.now()
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=ToolUsageStartedEvent(
|
||||
tool_name=function_name,
|
||||
tool_args=function_args,
|
||||
from_agent=from_agent,
|
||||
from_task=from_task,
|
||||
),
|
||||
)
|
||||
|
||||
# Execute the function
|
||||
fn = available_functions[function_name]
|
||||
result = fn(**function_args)
|
||||
|
||||
# Emit tool usage finished event
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=ToolUsageFinishedEvent(
|
||||
output=result,
|
||||
tool_name=function_name,
|
||||
tool_args=function_args,
|
||||
started_at=started_at,
|
||||
finished_at=datetime.now(),
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
# Emit LLM call completed event for tool call
|
||||
self._emit_call_completed_event(
|
||||
response=result,
|
||||
call_type=LLMCallType.TOOL_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
return str(result)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error executing function '{function_name}': {e!s}"
|
||||
logging.error(error_msg)
|
||||
|
||||
# Emit tool usage error event
|
||||
if not hasattr(crewai_event_bus, "emit"):
|
||||
raise ValueError(
|
||||
"crewai_event_bus does not have an emit method"
|
||||
) from None
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=ToolUsageErrorEvent(
|
||||
tool_name=function_name,
|
||||
tool_args=function_args,
|
||||
error=error_msg,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
),
|
||||
)
|
||||
|
||||
# Emit LLM call failed event
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def _format_messages(
|
||||
self, messages: str | list[dict[str, str]]
|
||||
) -> list[dict[str, str]]:
|
||||
"""Convert messages to standard format.
|
||||
|
||||
Args:
|
||||
messages: Input messages (string or list of message dicts)
|
||||
|
||||
Returns:
|
||||
List of message dictionaries with 'role' and 'content' keys
|
||||
|
||||
Raises:
|
||||
ValueError: If message format is invalid
|
||||
"""
|
||||
if isinstance(messages, str):
|
||||
return [{"role": "user", "content": messages}]
|
||||
|
||||
# Validate message format
|
||||
for i, msg in enumerate(messages):
|
||||
if not isinstance(msg, dict):
|
||||
raise ValueError(f"Message at index {i} must be a dictionary")
|
||||
if "role" not in msg or "content" not in msg:
|
||||
raise ValueError(
|
||||
f"Message at index {i} must have 'role' and 'content' keys"
|
||||
)
|
||||
|
||||
return messages
|
||||
|
||||
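For quick reference, the normalization above behaves as sketched here; the free function is an illustrative copy of the method so the snippet runs standalone:

```python
def format_messages(messages):
    """Illustrative copy of BaseLLM._format_messages from the hunk above."""
    if isinstance(messages, str):
        return [{"role": "user", "content": messages}]
    for i, msg in enumerate(messages):
        if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
            raise ValueError(f"Message at index {i} must be a dict with 'role' and 'content' keys")
    return messages

print(format_messages("Summarize the report"))
# [{'role': 'user', 'content': 'Summarize the report'}]
```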
def _validate_structured_output(
|
||||
self,
|
||||
response: str,
|
||||
response_format: type[BaseModel] | None,
|
||||
) -> str | BaseModel:
|
||||
"""Validate and parse structured output.
|
||||
|
||||
Args:
|
||||
response: Raw response string
|
||||
response_format: Optional Pydantic model for structured output
|
||||
|
||||
Returns:
|
||||
Parsed response (BaseModel instance if response_format provided, otherwise string)
|
||||
|
||||
Raises:
|
||||
ValueError: If structured output validation fails
|
||||
"""
|
||||
if response_format is None:
|
||||
return response
|
||||
|
||||
try:
|
||||
# Try to parse as JSON first
|
||||
if response.strip().startswith("{") or response.strip().startswith("["):
|
||||
data = json.loads(response)
|
||||
return response_format.model_validate(data)
|
||||
|
||||
# Try to extract JSON from response
|
||||
import re
|
||||
|
||||
json_match = re.search(r"\{.*\}", response, re.DOTALL)
|
||||
if json_match:
|
||||
data = json.loads(json_match.group())
|
||||
return response_format.model_validate(data)
|
||||
|
||||
raise ValueError("No JSON found in response")
|
||||
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
logging.warning(f"Failed to parse structured output: {e}")
|
||||
raise ValueError(
|
||||
f"Failed to parse response into {response_format.__name__}: {e}"
|
||||
) from e
|
||||
|
||||
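A sketch of how the structured-output path is used; `EchoLLM` and `TicketTriage` are illustrative names defined only for this example, and the import paths assume the package layout introduced in this commit:

```python
from pydantic import BaseModel
from crewai.llms.base_llm import BaseLLM

class EchoLLM(BaseLLM):
    """Minimal concrete subclass, defined only so the example can run."""
    def call(self, messages, **kwargs):
        return messages if isinstance(messages, str) else messages[-1]["content"]

class TicketTriage(BaseModel):
    priority: str
    summary: str

llm = EchoLLM(model="echo-1")
raw = 'Sure: {"priority": "high", "summary": "API timeouts on checkout"}'
triaged = llm._validate_structured_output(raw, TicketTriage)
print(triaged.priority)  # "high" -- embedded JSON is extracted and validated; otherwise ValueError
```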
def _extract_provider(self, model: str) -> str:
|
||||
"""Extract provider from model string.
|
||||
|
||||
Args:
|
||||
model: Model string (e.g., 'openai/gpt-4' or 'gpt-4')
|
||||
|
||||
Returns:
|
||||
Provider name (e.g., 'openai')
|
||||
"""
|
||||
if "/" in model:
|
||||
return model.partition("/")[0]
|
||||
return "openai" # Default provider
|
||||
|
||||
def _track_token_usage_internal(self, usage_data: dict[str, Any]) -> None:
|
||||
"""Track token usage internally in the LLM instance.
|
||||
|
||||
Args:
|
||||
usage_data: Token usage data from the API response
|
||||
"""
|
||||
# Extract tokens in a provider-agnostic way
|
||||
prompt_tokens = (
|
||||
usage_data.get("prompt_tokens")
|
||||
or usage_data.get("prompt_token_count")
|
||||
or usage_data.get("input_tokens")
|
||||
or 0
|
||||
)
|
||||
|
||||
completion_tokens = (
|
||||
usage_data.get("completion_tokens")
|
||||
or usage_data.get("candidates_token_count")
|
||||
or usage_data.get("output_tokens")
|
||||
or 0
|
||||
)
|
||||
|
||||
cached_tokens = (
|
||||
usage_data.get("cached_tokens")
|
||||
or usage_data.get("cached_prompt_tokens")
|
||||
or 0
|
||||
)
|
||||
|
||||
self._token_usage["prompt_tokens"] += prompt_tokens
|
||||
self._token_usage["completion_tokens"] += completion_tokens
|
||||
self._token_usage["total_tokens"] += prompt_tokens + completion_tokens
|
||||
self._token_usage["successful_requests"] += 1
|
||||
self._token_usage["cached_prompt_tokens"] += cached_tokens
|
||||
|
||||
def get_token_usage_summary(self) -> UsageMetrics:
|
||||
"""Get summary of token usage for this LLM instance.
|
||||
|
||||
Returns:
|
||||
Dictionary with token usage totals
|
||||
"""
|
||||
return UsageMetrics(**self._token_usage)
|
||||
|
||||
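These two methods give every native provider the same internal counter. A short sketch of the provider-agnostic key handling, again using an illustrative minimal subclass:

```python
from crewai.llms.base_llm import BaseLLM

class EchoLLM(BaseLLM):
    """Minimal concrete subclass, defined only so the example can run."""
    def call(self, messages, **kwargs):
        return ""

llm = EchoLLM(model="echo-1")
llm._track_token_usage_internal({"input_tokens": 900, "output_tokens": 100})     # Anthropic-style keys
llm._track_token_usage_internal({"prompt_tokens": 40, "completion_tokens": 10})  # OpenAI-style keys

summary = llm.get_token_usage_summary()
# prompt_tokens=940, completion_tokens=110, total_tokens=1050, successful_requests=2
```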
lib/crewai/src/crewai/llms/providers/__init__.py (new file, 0 lines)
lib/crewai/src/crewai/llms/providers/anthropic/completion.py (new file, 432 lines)
@@ -0,0 +1,432 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededError,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
from anthropic import Anthropic
|
||||
from anthropic.types import Message
|
||||
from anthropic.types.tool_use_block import ToolUseBlock
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Anthropic native provider not available, to install: `uv add anthropic`"
|
||||
) from None
|
||||
|
||||
|
||||
class AnthropicCompletion(BaseLLM):
|
||||
"""Anthropic native completion implementation.
|
||||
|
||||
This class provides direct integration with the Anthropic Python SDK,
|
||||
offering native tool use, streaming support, and proper message formatting.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "claude-3-5-sonnet-20241022",
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
timeout: float | None = None,
|
||||
max_retries: int = 2,
|
||||
temperature: float | None = None,
|
||||
max_tokens: int = 4096, # Required for Anthropic
|
||||
top_p: float | None = None,
|
||||
stop_sequences: list[str] | None = None,
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize Anthropic chat completion client.
|
||||
|
||||
Args:
|
||||
model: Anthropic model name (e.g., 'claude-3-5-sonnet-20241022')
|
||||
api_key: Anthropic API key (defaults to ANTHROPIC_API_KEY env var)
|
||||
base_url: Custom base URL for Anthropic API
|
||||
timeout: Request timeout in seconds
|
||||
max_retries: Maximum number of retries
|
||||
temperature: Sampling temperature (0-1)
|
||||
max_tokens: Maximum tokens in response (required for Anthropic)
|
||||
top_p: Nucleus sampling parameter
|
||||
stop_sequences: Stop sequences (Anthropic uses stop_sequences, not stop)
|
||||
stream: Enable streaming responses
|
||||
**kwargs: Additional parameters
|
||||
"""
|
||||
super().__init__(
|
||||
model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
|
||||
)
|
||||
|
||||
# Initialize Anthropic client
|
||||
self.client = Anthropic(
|
||||
api_key=api_key or os.getenv("ANTHROPIC_API_KEY"),
|
||||
base_url=base_url,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
|
||||
# Store completion parameters
|
||||
self.max_tokens = max_tokens
|
||||
self.top_p = top_p
|
||||
self.stream = stream
|
||||
self.stop_sequences = stop_sequences or []
|
||||
|
||||
# Model-specific settings
|
||||
self.is_claude_3 = "claude-3" in model.lower()
|
||||
self.supports_tools = self.is_claude_3 # Claude 3+ supports tool use
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: str | list[dict[str, str]],
|
||||
tools: list[dict] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str | Any:
|
||||
"""Call Anthropic messages API.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the chat completion
|
||||
tools: List of tool/function definitions
|
||||
callbacks: Callback functions (not used in native implementation)
|
||||
available_functions: Available functions for tool calling
|
||||
from_task: Task that initiated the call
|
||||
from_agent: Agent that initiated the call
|
||||
|
||||
Returns:
|
||||
Chat completion response or tool call result
|
||||
"""
|
||||
try:
|
||||
# Emit call started event
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
# Format messages for Anthropic
|
||||
formatted_messages, system_message = self._format_messages_for_anthropic(
|
||||
messages
|
||||
)
|
||||
|
||||
# Prepare completion parameters
|
||||
completion_params = self._prepare_completion_params(
|
||||
formatted_messages, system_message, tools
|
||||
)
|
||||
|
||||
# Handle streaming vs non-streaming
|
||||
if self.stream:
|
||||
return self._handle_streaming_completion(
|
||||
completion_params, available_functions, from_task, from_agent
|
||||
)
|
||||
|
||||
return self._handle_completion(
|
||||
completion_params, available_functions, from_task, from_agent
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Anthropic API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
|
||||
def _prepare_completion_params(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
system_message: str | None = None,
|
||||
tools: list[dict] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Prepare parameters for Anthropic messages API.
|
||||
|
||||
Args:
|
||||
messages: Formatted messages for Anthropic
|
||||
system_message: Extracted system message
|
||||
tools: Tool definitions
|
||||
|
||||
Returns:
|
||||
Parameters dictionary for Anthropic API
|
||||
"""
|
||||
params = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"max_tokens": self.max_tokens,
|
||||
"stream": self.stream,
|
||||
}
|
||||
|
||||
# Add system message if present
|
||||
if system_message:
|
||||
params["system"] = system_message
|
||||
|
||||
# Add optional parameters if set
|
||||
if self.temperature is not None:
|
||||
params["temperature"] = self.temperature
|
||||
if self.top_p is not None:
|
||||
params["top_p"] = self.top_p
|
||||
if self.stop_sequences:
|
||||
params["stop_sequences"] = self.stop_sequences
|
||||
|
||||
# Handle tools for Claude 3+
|
||||
if tools and self.supports_tools:
|
||||
params["tools"] = self._convert_tools_for_interference(tools)
|
||||
|
||||
return params
|
||||
|
||||
def _convert_tools_for_interference(self, tools: list[dict]) -> list[dict]:
|
||||
"""Convert CrewAI tool format to Anthropic tool use format."""
|
||||
from crewai.llms.providers.utils.common import safe_tool_conversion
|
||||
|
||||
anthropic_tools = []
|
||||
|
||||
for tool in tools:
|
||||
name, description, parameters = safe_tool_conversion(tool, "Anthropic")
|
||||
|
||||
anthropic_tool = {
|
||||
"name": name,
|
||||
"description": description,
|
||||
}
|
||||
|
||||
if parameters and isinstance(parameters, dict):
|
||||
anthropic_tool["input_schema"] = parameters # type: ignore
|
||||
|
||||
anthropic_tools.append(anthropic_tool)
|
||||
|
||||
return anthropic_tools
|
||||
|
||||
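A rough illustration of the shape this conversion targets. `safe_tool_conversion` itself is not shown in this excerpt, so the OpenAI-style input format below is an assumption; only the Anthropic output shape (name/description/input_schema) is taken from the code above:

```python
crewai_tool = {
    "type": "function",
    "function": {
        "name": "search_web",
        "description": "Search the web for a query",
        "parameters": {
            "type": "object",
            "properties": {"query": {"type": "string"}},
            "required": ["query"],
        },
    },
}

# Expected Anthropic tool-use shape produced by _convert_tools_for_interference:
anthropic_tool = {
    "name": "search_web",
    "description": "Search the web for a query",
    "input_schema": crewai_tool["function"]["parameters"],
}
```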
def _format_messages_for_anthropic(
|
||||
self, messages: str | list[dict[str, str]]
|
||||
) -> tuple[list[dict[str, str]], str | None]:
|
||||
"""Format messages for Anthropic API.
|
||||
|
||||
Anthropic has specific requirements:
|
||||
- System messages are separate from conversation messages
|
||||
- Messages must alternate between user and assistant
|
||||
- First message must be from user
|
||||
|
||||
Args:
|
||||
messages: Input messages
|
||||
|
||||
Returns:
|
||||
Tuple of (formatted_messages, system_message)
|
||||
"""
|
||||
# Use base class formatting first
|
||||
base_formatted = super()._format_messages(messages)
|
||||
|
||||
formatted_messages = []
|
||||
system_message = None
|
||||
|
||||
for message in base_formatted:
|
||||
role = message.get("role")
|
||||
content = message.get("content", "")
|
||||
|
||||
if role == "system":
|
||||
# Extract system message - Anthropic handles it separately
|
||||
if system_message:
|
||||
system_message += f"\n\n{content}"
|
||||
else:
|
||||
system_message = content
|
||||
else:
|
||||
# Add user/assistant messages - ensure both role and content are str, not None
|
||||
role_str = role if role is not None else "user"
|
||||
content_str = content if content is not None else ""
|
||||
formatted_messages.append({"role": role_str, "content": content_str})
|
||||
|
||||
# Ensure first message is from user (Anthropic requirement)
|
||||
if not formatted_messages:
|
||||
# If no messages, add a default user message
|
||||
formatted_messages.append({"role": "user", "content": "Hello"})
|
||||
elif formatted_messages[0]["role"] != "user":
|
||||
# If first message is not from user, insert a user message at the beginning
|
||||
formatted_messages.insert(0, {"role": "user", "content": "Hello"})
|
||||
|
||||
return formatted_messages, system_message
|
||||
|
||||
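For instance, given the formatting rules above (system messages pulled out, first turn forced to come from the user), the expected behaviour is roughly:

```python
incoming = [
    {"role": "system", "content": "You are a terse analyst."},
    {"role": "assistant", "content": "Ready."},
    {"role": "user", "content": "Summarize Q3."},
]

# _format_messages_for_anthropic(incoming) should return approximately:
formatted = [
    {"role": "user", "content": "Hello"},   # inserted so the first message is from the user
    {"role": "assistant", "content": "Ready."},
    {"role": "user", "content": "Summarize Q3."},
]
system_message = "You are a terse analyst."
```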
def _handle_completion(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle non-streaming message completion."""
|
||||
try:
|
||||
response: Message = self.client.messages.create(**params)
|
||||
|
||||
except Exception as e:
|
||||
if is_context_length_exceeded(e):
|
||||
logging.error(f"Context window exceeded: {e}")
|
||||
raise LLMContextLengthExceededError(str(e)) from e
|
||||
raise e from e
|
||||
|
||||
usage = self._extract_anthropic_token_usage(response)
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
if response.content and available_functions:
|
||||
for content_block in response.content:
|
||||
if isinstance(content_block, ToolUseBlock):
|
||||
function_name = content_block.name
|
||||
function_args = content_block.input
|
||||
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args, # type: ignore
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Extract text content
|
||||
content = ""
|
||||
if response.content:
|
||||
for content_block in response.content:
|
||||
if hasattr(content_block, "text"):
|
||||
content += content_block.text
|
||||
|
||||
content = self._apply_stop_words(content)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=content,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
if usage.get("total_tokens", 0) > 0:
|
||||
logging.info(f"Anthropic API usage: {usage}")
|
||||
|
||||
return content
|
||||
|
||||
def _handle_streaming_completion(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str:
|
||||
"""Handle streaming message completion."""
|
||||
full_response = ""
|
||||
tool_uses = {}
|
||||
|
||||
# Make streaming API call
|
||||
with self.client.messages.stream(**params) as stream:
|
||||
for event in stream:
|
||||
# Handle content delta events
|
||||
if hasattr(event, "delta") and hasattr(event.delta, "text"):
|
||||
text_delta = event.delta.text
|
||||
full_response += text_delta
|
||||
self._emit_stream_chunk_event(
|
||||
chunk=text_delta,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
# Handle tool use events
|
||||
elif hasattr(event, "delta") and hasattr(event.delta, "partial_json"):
|
||||
# Tool use streaming - accumulate JSON
|
||||
tool_id = getattr(event, "index", "default")
|
||||
if tool_id not in tool_uses:
|
||||
tool_uses[tool_id] = {
|
||||
"name": "",
|
||||
"input": "",
|
||||
}
|
||||
|
||||
if hasattr(event.delta, "name"):
|
||||
tool_uses[tool_id]["name"] = event.delta.name
|
||||
if hasattr(event.delta, "partial_json"):
|
||||
tool_uses[tool_id]["input"] += event.delta.partial_json
|
||||
|
||||
# Handle completed tool uses
|
||||
if tool_uses and available_functions:
|
||||
for tool_data in tool_uses.values():
|
||||
function_name = tool_data["name"]
|
||||
|
||||
try:
|
||||
function_args = json.loads(tool_data["input"])
|
||||
except json.JSONDecodeError as e:
|
||||
logging.error(f"Failed to parse streamed tool arguments: {e}")
|
||||
continue
|
||||
|
||||
# Execute tool
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Apply stop words to full response
|
||||
full_response = self._apply_stop_words(full_response)
|
||||
|
||||
# Emit completion event and return full response
|
||||
self._emit_call_completed_event(
|
||||
response=full_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
return full_response
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Check if the model supports function calling."""
|
||||
return self.supports_tools
|
||||
|
||||
def supports_stop_words(self) -> bool:
|
||||
"""Check if the model supports stop words."""
|
||||
return True # All Claude models support stop sequences
|
||||
|
||||
def get_context_window_size(self) -> int:
|
||||
"""Get the context window size for the model."""
|
||||
from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO
|
||||
|
||||
# Context window sizes for Anthropic models
|
||||
context_windows = {
|
||||
"claude-3-5-sonnet": 200000,
|
||||
"claude-3-5-haiku": 200000,
|
||||
"claude-3-opus": 200000,
|
||||
"claude-3-sonnet": 200000,
|
||||
"claude-3-haiku": 200000,
|
||||
"claude-3-7-sonnet": 200000,
|
||||
"claude-2.1": 200000,
|
||||
"claude-2": 100000,
|
||||
"claude-instant": 100000,
|
||||
}
|
||||
|
||||
# Find the best match for the model name
|
||||
for model_prefix, size in context_windows.items():
|
||||
if self.model.startswith(model_prefix):
|
||||
return int(size * CONTEXT_WINDOW_USAGE_RATIO)
|
||||
|
||||
# Default context window size for Claude models
|
||||
return int(200000 * CONTEXT_WINDOW_USAGE_RATIO)
|
||||
|
||||
def _extract_anthropic_token_usage(self, response: Message) -> dict[str, Any]:
|
||||
"""Extract token usage from Anthropic response."""
|
||||
if hasattr(response, "usage") and response.usage:
|
||||
usage = response.usage
|
||||
input_tokens = getattr(usage, "input_tokens", 0)
|
||||
output_tokens = getattr(usage, "output_tokens", 0)
|
||||
return {
|
||||
"input_tokens": input_tokens,
|
||||
"output_tokens": output_tokens,
|
||||
"total_tokens": input_tokens + output_tokens,
|
||||
}
|
||||
return {"total_tokens": 0}
|
||||
473
lib/crewai/src/crewai/llms/providers/azure/completion.py
Normal file
473
lib/crewai/src/crewai/llms/providers/azure/completion.py
Normal file
@@ -0,0 +1,473 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededError,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
from azure.ai.inference import ChatCompletionsClient # type: ignore
|
||||
from azure.ai.inference.models import ( # type: ignore
|
||||
ChatCompletions,
|
||||
ChatCompletionsToolCall,
|
||||
StreamingChatCompletionsUpdate,
|
||||
)
|
||||
from azure.core.credentials import AzureKeyCredential # type: ignore
|
||||
from azure.core.exceptions import HttpResponseError # type: ignore
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Azure AI Inference native provider not available, to install: `uv add azure-ai-inference`"
|
||||
) from None
|
||||
|
||||
|
||||
class AzureCompletion(BaseLLM):
|
||||
"""Azure AI Inference native completion implementation.
|
||||
|
||||
This class provides direct integration with the Azure AI Inference Python SDK,
|
||||
offering native function calling, streaming support, and proper Azure authentication.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
api_key: str | None = None,
|
||||
endpoint: str | None = None,
|
||||
api_version: str | None = None,
|
||||
timeout: float | None = None,
|
||||
max_retries: int = 2,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize Azure AI Inference chat completion client.
|
||||
|
||||
Args:
|
||||
model: Azure deployment name or model name
|
||||
api_key: Azure API key (defaults to AZURE_API_KEY env var)
|
||||
endpoint: Azure endpoint URL (defaults to AZURE_ENDPOINT env var)
|
||||
api_version: Azure API version (defaults to AZURE_API_VERSION env var)
|
||||
timeout: Request timeout in seconds
|
||||
max_retries: Maximum number of retries
|
||||
temperature: Sampling temperature (0-2)
|
||||
top_p: Nucleus sampling parameter
|
||||
frequency_penalty: Frequency penalty (-2 to 2)
|
||||
presence_penalty: Presence penalty (-2 to 2)
|
||||
max_tokens: Maximum tokens in response
|
||||
stop: Stop sequences
|
||||
stream: Enable streaming responses
|
||||
**kwargs: Additional parameters
|
||||
"""
|
||||
super().__init__(
|
||||
model=model, temperature=temperature, stop=stop or [], **kwargs
|
||||
)
|
||||
|
||||
self.api_key = api_key or os.getenv("AZURE_API_KEY")
|
||||
self.endpoint = (
|
||||
endpoint
|
||||
or os.getenv("AZURE_ENDPOINT")
|
||||
or os.getenv("AZURE_OPENAI_ENDPOINT")
|
||||
or os.getenv("AZURE_API_BASE")
|
||||
)
|
||||
self.api_version = api_version or os.getenv("AZURE_API_VERSION") or "2024-02-01"
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError(
|
||||
"Azure API key is required. Set AZURE_API_KEY environment variable or pass api_key parameter."
|
||||
)
|
||||
if not self.endpoint:
|
||||
raise ValueError(
|
||||
"Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter."
|
||||
)
|
||||
|
||||
self.client = ChatCompletionsClient(
|
||||
endpoint=self.endpoint,
|
||||
credential=AzureKeyCredential(self.api_key),
|
||||
)
|
||||
|
||||
self.top_p = top_p
|
||||
self.frequency_penalty = frequency_penalty
|
||||
self.presence_penalty = presence_penalty
|
||||
self.max_tokens = max_tokens
|
||||
self.stream = stream
|
||||
|
||||
self.is_openai_model = any(
|
||||
prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
|
||||
)
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: str | list[dict[str, str]],
|
||||
tools: list[dict] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str | Any:
|
||||
"""Call Azure AI Inference chat completions API.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the chat completion
|
||||
tools: List of tool/function definitions
|
||||
callbacks: Callback functions (not used in native implementation)
|
||||
available_functions: Available functions for tool calling
|
||||
from_task: Task that initiated the call
|
||||
from_agent: Agent that initiated the call
|
||||
|
||||
Returns:
|
||||
Chat completion response or tool call result
|
||||
"""
|
||||
try:
|
||||
# Emit call started event
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
# Format messages for Azure
|
||||
formatted_messages = self._format_messages_for_azure(messages)
|
||||
|
||||
# Prepare completion parameters
|
||||
completion_params = self._prepare_completion_params(
|
||||
formatted_messages, tools
|
||||
)
|
||||
|
||||
# Handle streaming vs non-streaming
|
||||
if self.stream:
|
||||
return self._handle_streaming_completion(
|
||||
completion_params, available_functions, from_task, from_agent
|
||||
)
|
||||
|
||||
return self._handle_completion(
|
||||
completion_params, available_functions, from_task, from_agent
|
||||
)
|
||||
|
||||
except HttpResponseError as e:
|
||||
error_msg = f"Azure API HTTP error: {e.status_code} - {e.message}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"Azure API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
|
||||
def _prepare_completion_params(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
tools: list[dict] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Prepare parameters for Azure AI Inference chat completion.
|
||||
|
||||
Args:
|
||||
messages: Formatted messages for Azure
|
||||
tools: Tool definitions
|
||||
|
||||
Returns:
|
||||
Parameters dictionary for Azure API
|
||||
"""
|
||||
params = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"stream": self.stream,
|
||||
}
|
||||
|
||||
# Add optional parameters if set
|
||||
if self.temperature is not None:
|
||||
params["temperature"] = self.temperature
|
||||
if self.top_p is not None:
|
||||
params["top_p"] = self.top_p
|
||||
if self.frequency_penalty is not None:
|
||||
params["frequency_penalty"] = self.frequency_penalty
|
||||
if self.presence_penalty is not None:
|
||||
params["presence_penalty"] = self.presence_penalty
|
||||
if self.max_tokens is not None:
|
||||
params["max_tokens"] = self.max_tokens
|
||||
if self.stop:
|
||||
params["stop"] = self.stop
|
||||
|
||||
# Handle tools/functions for Azure OpenAI models
|
||||
if tools and self.is_openai_model:
|
||||
params["tools"] = self._convert_tools_for_interference(tools)
|
||||
params["tool_choice"] = "auto"
|
||||
|
||||
return params
|
||||
|
||||
def _convert_tools_for_interference(self, tools: list[dict]) -> list[dict]:
|
||||
"""Convert CrewAI tool format to Azure OpenAI function calling format."""
|
||||
|
||||
from crewai.llms.providers.utils.common import safe_tool_conversion
|
||||
|
||||
azure_tools = []
|
||||
|
||||
for tool in tools:
|
||||
name, description, parameters = safe_tool_conversion(tool, "Azure")
|
||||
|
||||
azure_tool = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": name,
|
||||
"description": description,
|
||||
},
|
||||
}
|
||||
|
||||
if parameters:
|
||||
if isinstance(parameters, dict):
|
||||
azure_tool["function"]["parameters"] = parameters # type: ignore
|
||||
else:
|
||||
azure_tool["function"]["parameters"] = dict(parameters)
|
||||
|
||||
azure_tools.append(azure_tool)
|
||||
|
||||
return azure_tools
|
||||
|
||||
def _format_messages_for_azure(
|
||||
self, messages: str | list[dict[str, str]]
|
||||
) -> list[dict[str, str]]:
|
||||
"""Format messages for Azure AI Inference API.
|
||||
|
||||
Args:
|
||||
messages: Input messages
|
||||
|
||||
Returns:
|
||||
List of dict objects
|
||||
"""
|
||||
# Use base class formatting first
|
||||
base_formatted = super()._format_messages(messages)
|
||||
|
||||
azure_messages = []
|
||||
|
||||
for message in base_formatted:
|
||||
role = message.get("role")
|
||||
content = message.get("content", "")
|
||||
|
||||
if role == "system":
|
||||
azure_messages.append(dict(content=content))
|
||||
elif role == "user":
|
||||
azure_messages.append(dict(content=content))
|
||||
elif role == "assistant":
|
||||
azure_messages.append(dict(content=content))
|
||||
else:
|
||||
# Default to user message for unknown roles
|
||||
azure_messages.append(dict(content=content))
|
||||
|
||||
return azure_messages
|
||||
|
||||
def _handle_completion(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle non-streaming chat completion."""
|
||||
# Make API call
|
||||
try:
|
||||
response: ChatCompletions = self.client.complete(**params)
|
||||
|
||||
if not response.choices:
|
||||
raise ValueError("No choices returned from Azure API")
|
||||
|
||||
choice = response.choices[0]
|
||||
message = choice.message
|
||||
|
||||
# Extract and track token usage
|
||||
usage = self._extract_azure_token_usage(response)
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
# Handle tool calls
|
||||
if message.tool_calls and available_functions:
|
||||
tool_call = message.tool_calls[0] # Handle first tool call
|
||||
if isinstance(tool_call, ChatCompletionsToolCall):
|
||||
function_name = tool_call.function.name
|
||||
|
||||
try:
|
||||
function_args = json.loads(tool_call.function.arguments)
|
||||
except json.JSONDecodeError as e:
|
||||
logging.error(f"Failed to parse tool arguments: {e}")
|
||||
function_args = {}
|
||||
|
||||
# Execute tool
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Extract content
|
||||
content = message.content or ""
|
||||
|
||||
# Apply stop words
|
||||
content = self._apply_stop_words(content)
|
||||
|
||||
# Emit completion event and return content
|
||||
self._emit_call_completed_event(
|
||||
response=content,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
if is_context_length_exceeded(e):
|
||||
logging.error(f"Context window exceeded: {e}")
|
||||
raise LLMContextLengthExceededError(str(e)) from e
|
||||
|
||||
return content
|
||||
|
||||
def _handle_streaming_completion(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str:
|
||||
"""Handle streaming chat completion."""
|
||||
full_response = ""
|
||||
tool_calls = {}
|
||||
|
||||
# Make streaming API call
|
||||
for update in self.client.complete(**params):
|
||||
if isinstance(update, StreamingChatCompletionsUpdate):
|
||||
if update.choices:
|
||||
choice = update.choices[0]
|
||||
if choice.delta and choice.delta.content:
|
||||
content_delta = choice.delta.content
|
||||
full_response += content_delta
|
||||
self._emit_stream_chunk_event(
|
||||
chunk=content_delta,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
# Handle tool call streaming
|
||||
if choice.delta and choice.delta.tool_calls:
|
||||
for tool_call in choice.delta.tool_calls:
|
||||
call_id = tool_call.id or "default"
|
||||
if call_id not in tool_calls:
|
||||
tool_calls[call_id] = {
|
||||
"name": "",
|
||||
"arguments": "",
|
||||
}
|
||||
|
||||
if tool_call.function and tool_call.function.name:
|
||||
tool_calls[call_id]["name"] = tool_call.function.name
|
||||
if tool_call.function and tool_call.function.arguments:
|
||||
tool_calls[call_id]["arguments"] += (
|
||||
tool_call.function.arguments
|
||||
)
|
||||
|
||||
# Handle completed tool calls
|
||||
if tool_calls and available_functions:
|
||||
for call_data in tool_calls.values():
|
||||
function_name = call_data["name"]
|
||||
|
||||
try:
|
||||
function_args = json.loads(call_data["arguments"])
|
||||
except json.JSONDecodeError as e:
|
||||
logging.error(f"Failed to parse streamed tool arguments: {e}")
|
||||
continue
|
||||
|
||||
# Execute tool
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Apply stop words to full response
|
||||
full_response = self._apply_stop_words(full_response)
|
||||
|
||||
# Emit completion event and return full response
|
||||
self._emit_call_completed_event(
|
||||
response=full_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
return full_response
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Check if the model supports function calling."""
|
||||
# Azure OpenAI models support function calling
|
||||
return self.is_openai_model
|
||||
|
||||
def supports_stop_words(self) -> bool:
|
||||
"""Check if the model supports stop words."""
|
||||
return True # Most Azure models support stop sequences
|
||||
|
||||
def get_context_window_size(self) -> int:
|
||||
"""Get the context window size for the model."""
|
||||
from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
|
||||
|
||||
min_context = 1024
|
||||
max_context = 2097152
|
||||
|
||||
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
|
||||
if value < min_context or value > max_context:
|
||||
raise ValueError(
|
||||
f"Context window for {key} must be between {min_context} and {max_context}"
|
||||
)
|
||||
|
||||
# Context window sizes for common Azure models
|
||||
context_windows = {
|
||||
"gpt-4": 8192,
|
||||
"gpt-4o": 128000,
|
||||
"gpt-4o-mini": 200000,
|
||||
"gpt-4-turbo": 128000,
|
||||
"gpt-35-turbo": 16385,
|
||||
"gpt-3.5-turbo": 16385,
|
||||
"text-embedding": 8191,
|
||||
}
|
||||
|
||||
# Find the best match for the model name
|
||||
for model_prefix, size in context_windows.items():
|
||||
if self.model.startswith(model_prefix):
|
||||
return int(size * CONTEXT_WINDOW_USAGE_RATIO)
|
||||
|
||||
# Default context window size
|
||||
return int(8192 * CONTEXT_WINDOW_USAGE_RATIO)
|
||||
|
||||
def _extract_azure_token_usage(self, response: ChatCompletions) -> dict[str, Any]:
|
||||
"""Extract token usage from Azure response."""
|
||||
if hasattr(response, "usage") and response.usage:
|
||||
usage = response.usage
|
||||
return {
|
||||
"prompt_tokens": getattr(usage, "prompt_tokens", 0),
|
||||
"completion_tokens": getattr(usage, "completion_tokens", 0),
|
||||
"total_tokens": getattr(usage, "total_tokens", 0),
|
||||
}
|
||||
return {"total_tokens": 0}
|
||||
497
lib/crewai/src/crewai/llms/providers/gemini/completion.py
Normal file
497
lib/crewai/src/crewai/llms/providers/gemini/completion.py
Normal file
@@ -0,0 +1,497 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededError,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
from google import genai # type: ignore
|
||||
from google.genai import types # type: ignore
|
||||
from google.genai.errors import APIError # type: ignore
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Google Gen AI native provider not available, to install: `uv add google-genai`"
|
||||
) from None
|
||||
|
||||
|
||||
class GeminiCompletion(BaseLLM):
|
||||
"""Google Gemini native completion implementation.
|
||||
|
||||
This class provides direct integration with the Google Gen AI Python SDK,
|
||||
offering native function calling, streaming support, and proper Gemini formatting.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "gemini-2.0-flash-001",
|
||||
api_key: str | None = None,
|
||||
project: str | None = None,
|
||||
location: str | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
top_k: int | None = None,
|
||||
max_output_tokens: int | None = None,
|
||||
stop_sequences: list[str] | None = None,
|
||||
stream: bool = False,
|
||||
safety_settings: dict[str, Any] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize Google Gemini chat completion client.
|
||||
|
||||
Args:
|
||||
model: Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro')
|
||||
api_key: Google API key (defaults to GOOGLE_API_KEY or GEMINI_API_KEY env var)
|
||||
project: Google Cloud project ID (for Vertex AI)
|
||||
location: Google Cloud location (for Vertex AI, defaults to 'us-central1')
|
||||
temperature: Sampling temperature (0-2)
|
||||
top_p: Nucleus sampling parameter
|
||||
top_k: Top-k sampling parameter
|
||||
max_output_tokens: Maximum tokens in response
|
||||
stop_sequences: Stop sequences
|
||||
stream: Enable streaming responses
|
||||
safety_settings: Safety filter settings
|
||||
**kwargs: Additional parameters
|
||||
"""
|
||||
super().__init__(
|
||||
model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
|
||||
)
|
||||
|
||||
# Get API configuration
|
||||
self.api_key = (
|
||||
api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
|
||||
)
|
||||
self.project = project or os.getenv("GOOGLE_CLOUD_PROJECT")
|
||||
self.location = location or os.getenv("GOOGLE_CLOUD_LOCATION") or "us-central1"
|
||||
|
||||
# Initialize client based on available configuration
|
||||
if self.project:
|
||||
# Use Vertex AI
|
||||
self.client = genai.Client(
|
||||
vertexai=True,
|
||||
project=self.project,
|
||||
location=self.location,
|
||||
)
|
||||
elif self.api_key:
|
||||
# Use Gemini Developer API
|
||||
self.client = genai.Client(api_key=self.api_key)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Either GOOGLE_API_KEY/GEMINI_API_KEY (for Gemini API) or "
|
||||
"GOOGLE_CLOUD_PROJECT (for Vertex AI) must be set"
|
||||
)
|
||||
|
||||
# Store completion parameters
|
||||
self.top_p = top_p
|
||||
self.top_k = top_k
|
||||
self.max_output_tokens = max_output_tokens
|
||||
self.stream = stream
|
||||
self.safety_settings = safety_settings or {}
|
||||
self.stop_sequences = stop_sequences or []
|
||||
|
||||
# Model-specific settings
|
||||
self.is_gemini_2 = "gemini-2" in model.lower()
|
||||
self.is_gemini_1_5 = "gemini-1.5" in model.lower()
|
||||
self.supports_tools = self.is_gemini_1_5 or self.is_gemini_2
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: str | list[dict[str, str]],
|
||||
tools: list[dict] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str | Any:
|
||||
"""Call Google Gemini generate content API.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the chat completion
|
||||
tools: List of tool/function definitions
|
||||
callbacks: Callback functions (not used as token counts are handled by the reponse)
|
||||
available_functions: Available functions for tool calling
|
||||
from_task: Task that initiated the call
|
||||
from_agent: Agent that initiated the call
|
||||
|
||||
Returns:
|
||||
Chat completion response or tool call result
|
||||
"""
|
||||
try:
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
self.tools = tools
|
||||
|
||||
formatted_content, system_instruction = self._format_messages_for_gemini(
|
||||
messages
|
||||
)
|
||||
|
||||
config = self._prepare_generation_config(system_instruction, tools)
|
||||
|
||||
if self.stream:
|
||||
return self._handle_streaming_completion(
|
||||
formatted_content,
|
||||
config,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
)
|
||||
|
||||
return self._handle_completion(
|
||||
formatted_content,
|
||||
system_instruction,
|
||||
config,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
)
|
||||
|
||||
except APIError as e:
|
||||
error_msg = f"Google Gemini API error: {e.code} - {e.message}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"Google Gemini API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
|
||||
def _prepare_generation_config(
|
||||
self,
|
||||
system_instruction: str | None = None,
|
||||
tools: list[dict] | None = None,
|
||||
) -> types.GenerateContentConfig:
|
||||
"""Prepare generation config for Google Gemini API.
|
||||
|
||||
Args:
|
||||
system_instruction: System instruction for the model
|
||||
tools: Tool definitions
|
||||
|
||||
Returns:
|
||||
GenerateContentConfig object for Gemini API
|
||||
"""
|
||||
self.tools = tools
|
||||
config_params = {}
|
||||
|
||||
# Add system instruction if present
|
||||
if system_instruction:
|
||||
# Convert system instruction to Content format
|
||||
system_content = types.Content(
|
||||
role="user", parts=[types.Part.from_text(text=system_instruction)]
|
||||
)
|
||||
config_params["system_instruction"] = system_content
|
||||
|
||||
# Add generation config parameters
|
||||
if self.temperature is not None:
|
||||
config_params["temperature"] = self.temperature
|
||||
if self.top_p is not None:
|
||||
config_params["top_p"] = self.top_p
|
||||
if self.top_k is not None:
|
||||
config_params["top_k"] = self.top_k
|
||||
if self.max_output_tokens is not None:
|
||||
config_params["max_output_tokens"] = self.max_output_tokens
|
||||
if self.stop_sequences:
|
||||
config_params["stop_sequences"] = self.stop_sequences
|
||||
|
||||
# Handle tools for supported models
|
||||
if tools and self.supports_tools:
|
||||
config_params["tools"] = self._convert_tools_for_interference(tools)
|
||||
|
||||
if self.safety_settings:
|
||||
config_params["safety_settings"] = self.safety_settings
|
||||
|
||||
return types.GenerateContentConfig(**config_params)
|
||||
|
||||
def _convert_tools_for_interference(self, tools: list[dict]) -> list[types.Tool]:
|
||||
"""Convert CrewAI tool format to Gemini function declaration format."""
|
||||
gemini_tools = []
|
||||
|
||||
from crewai.llms.providers.utils.common import safe_tool_conversion
|
||||
|
||||
for tool in tools:
|
||||
name, description, parameters = safe_tool_conversion(tool, "Gemini")
|
||||
|
||||
function_declaration = types.FunctionDeclaration(
|
||||
name=name,
|
||||
description=description,
|
||||
)
|
||||
|
||||
# Add parameters if present - ensure parameters is a dict
|
||||
if parameters and isinstance(parameters, dict):
|
||||
function_declaration.parameters = parameters
|
||||
|
||||
gemini_tool = types.Tool(function_declarations=[function_declaration])
|
||||
gemini_tools.append(gemini_tool)
|
||||
|
||||
return gemini_tools
|
||||
|
||||
def _format_messages_for_gemini(
|
||||
self, messages: str | list[dict[str, str]]
|
||||
) -> tuple[list[types.Content], str | None]:
|
||||
"""Format messages for Gemini API.
|
||||
|
||||
Gemini has specific requirements:
|
||||
- System messages are separate system_instruction
|
||||
- Content is organized as Content objects with Parts
|
||||
- Roles are 'user' and 'model' (not 'assistant')
|
||||
|
||||
Args:
|
||||
messages: Input messages
|
||||
|
||||
Returns:
|
||||
Tuple of (formatted_contents, system_instruction)
|
||||
"""
|
||||
# Use base class formatting first
|
||||
base_formatted = super()._format_messages(messages)
|
||||
|
||||
contents = []
|
||||
system_instruction = None
|
||||
|
||||
for message in base_formatted:
|
||||
role = message.get("role")
|
||||
content = message.get("content", "")
|
||||
|
||||
if role == "system":
|
||||
# Extract system instruction - Gemini handles it separately
|
||||
if system_instruction:
|
||||
system_instruction += f"\n\n{content}"
|
||||
else:
|
||||
system_instruction = content
|
||||
else:
|
||||
# Convert role for Gemini (assistant -> model)
|
||||
gemini_role = "model" if role == "assistant" else "user"
|
||||
|
||||
# Create Content object
|
||||
gemini_content = types.Content(
|
||||
role=gemini_role, parts=[types.Part.from_text(text=content)]
|
||||
)
|
||||
contents.append(gemini_content)
|
||||
|
||||
return contents, system_instruction
|
||||
|
||||
def _handle_completion(
|
||||
self,
|
||||
contents: list[types.Content],
|
||||
system_instruction: str | None,
|
||||
config: types.GenerateContentConfig,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle non-streaming content generation."""
|
||||
api_params = {
|
||||
"model": self.model,
|
||||
"contents": contents,
|
||||
"config": config,
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.client.models.generate_content(**api_params)
|
||||
|
||||
usage = self._extract_token_usage(response)
|
||||
except Exception as e:
|
||||
if is_context_length_exceeded(e):
|
||||
logging.error(f"Context window exceeded: {e}")
|
||||
raise LLMContextLengthExceededError(str(e)) from e
|
||||
raise e from e
|
||||
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
if response.candidates and (self.tools or available_functions):
|
||||
candidate = response.candidates[0]
|
||||
if candidate.content and candidate.content.parts:
|
||||
for part in candidate.content.parts:
|
||||
if hasattr(part, "function_call") and part.function_call:
|
||||
function_name = part.function_call.name
|
||||
function_args = (
|
||||
dict(part.function_call.args)
|
||||
if part.function_call.args
|
||||
else {}
|
||||
)
|
||||
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
available_functions=available_functions, # type: ignore
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
content = response.text if hasattr(response, "text") else ""
|
||||
content = self._apply_stop_words(content)
|
||||
|
||||
messages_for_event = self._convert_contents_to_dict(contents)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=content,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=messages_for_event,
|
||||
)
|
||||
|
||||
return content
|
||||
|
||||
def _handle_streaming_completion(
|
||||
self,
|
||||
contents: list[types.Content],
|
||||
config: types.GenerateContentConfig,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str:
|
||||
"""Handle streaming content generation."""
|
||||
full_response = ""
|
||||
function_calls = {}
|
||||
|
||||
api_params = {
|
||||
"model": self.model,
|
||||
"contents": contents,
|
||||
"config": config,
|
||||
}
|
||||
|
||||
for chunk in self.client.models.generate_content_stream(**api_params):
|
||||
if hasattr(chunk, "text") and chunk.text:
|
||||
full_response += chunk.text
|
||||
self._emit_stream_chunk_event(
|
||||
chunk=chunk.text,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if hasattr(chunk, "candidates") and chunk.candidates:
|
||||
candidate = chunk.candidates[0]
|
||||
if candidate.content and candidate.content.parts:
|
||||
for part in candidate.content.parts:
|
||||
if hasattr(part, "function_call") and part.function_call:
|
||||
call_id = part.function_call.name or "default"
|
||||
if call_id not in function_calls:
|
||||
function_calls[call_id] = {
|
||||
"name": part.function_call.name,
|
||||
"args": dict(part.function_call.args)
|
||||
if part.function_call.args
|
||||
else {},
|
||||
}
|
||||
|
||||
# Handle completed function calls
|
||||
if function_calls and available_functions:
|
||||
for call_data in function_calls.values():
|
||||
function_name = call_data["name"]
|
||||
function_args = call_data["args"]
|
||||
|
||||
# Execute tool
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
messages_for_event = self._convert_contents_to_dict(contents)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=full_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=messages_for_event,
|
||||
)
|
||||
|
||||
return full_response
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Check if the model supports function calling."""
|
||||
return self.supports_tools
|
||||
|
||||
def supports_stop_words(self) -> bool:
|
||||
"""Check if the model supports stop words."""
|
||||
return self._supports_stop_words_implementation()
|
||||
|
||||
def get_context_window_size(self) -> int:
|
||||
"""Get the context window size for the model."""
|
||||
from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
|
||||
|
||||
min_context = 1024
|
||||
max_context = 2097152
|
||||
|
||||
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
|
||||
if value < min_context or value > max_context:
|
||||
raise ValueError(
|
||||
f"Context window for {key} must be between {min_context} and {max_context}"
|
||||
)
|
||||
|
||||
context_windows = {
|
||||
"gemini-2.0-flash": 1048576, # 1M tokens
|
||||
"gemini-2.0-flash-thinking": 32768,
|
||||
"gemini-2.0-flash-lite": 1048576,
|
||||
"gemini-2.5-flash": 1048576,
|
||||
"gemini-2.5-pro": 1048576,
|
||||
"gemini-1.5-pro": 2097152, # 2M tokens
|
||||
"gemini-1.5-flash": 1048576,
|
||||
"gemini-1.5-flash-8b": 1048576,
|
||||
"gemini-1.0-pro": 32768,
|
||||
"gemma-3-1b": 32000,
|
||||
"gemma-3-4b": 128000,
|
||||
"gemma-3-12b": 128000,
|
||||
"gemma-3-27b": 128000,
|
||||
}
|
||||
|
||||
# Find the best match for the model name
|
||||
for model_prefix, size in context_windows.items():
|
||||
if self.model.startswith(model_prefix):
|
||||
return int(size * CONTEXT_WINDOW_USAGE_RATIO)
|
||||
|
||||
# Default context window size for Gemini models
|
||||
return int(1048576 * CONTEXT_WINDOW_USAGE_RATIO) # 1M tokens
|
||||
|
||||
def _extract_token_usage(self, response: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Extract token usage from Gemini response."""
|
||||
if hasattr(response, "usage_metadata"):
|
||||
usage = response.usage_metadata
|
||||
return {
|
||||
"prompt_token_count": getattr(usage, "prompt_token_count", 0),
|
||||
"candidates_token_count": getattr(usage, "candidates_token_count", 0),
|
||||
"total_token_count": getattr(usage, "total_token_count", 0),
|
||||
"total_tokens": getattr(usage, "total_token_count", 0),
|
||||
}
|
||||
return {"total_tokens": 0}
|
||||
|
||||
def _convert_contents_to_dict(
|
||||
self, contents: list[types.Content]
|
||||
) -> list[dict[str, str]]:
|
||||
"""Convert contents to dict format."""
|
||||
return [
|
||||
{
|
||||
"role": "assistant"
|
||||
if content_obj.role == "model"
|
||||
else content_obj.role,
|
||||
"content": " ".join(
|
||||
part.text
|
||||
for part in content_obj.parts
|
||||
if hasattr(part, "text") and part.text
|
||||
),
|
||||
}
|
||||
for content_obj in contents
|
||||
]
|
||||
484
lib/crewai/src/crewai/llms/providers/openai/completion.py
Normal file
484
lib/crewai/src/crewai/llms/providers/openai/completion.py
Normal file
@@ -0,0 +1,484 @@
|
||||
from collections.abc import Iterator
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededError,
|
||||
)
|
||||
from openai import OpenAI
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion import Choice
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class OpenAICompletion(BaseLLM):
|
||||
"""OpenAI native completion implementation.
|
||||
|
||||
This class provides direct integration with the OpenAI Python SDK,
|
||||
offering native structured outputs, function calling, and streaming support.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "gpt-4o",
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
organization: str | None = None,
|
||||
project: str | None = None,
|
||||
timeout: float | None = None,
|
||||
max_retries: int = 2,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
seed: int | None = None,
|
||||
stream: bool = False,
|
||||
response_format: dict[str, Any] | type[BaseModel] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
reasoning_effort: str | None = None, # For o1 models
|
||||
provider: str | None = None, # Add provider parameter
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize OpenAI chat completion client."""
|
||||
|
||||
if provider is None:
|
||||
provider = kwargs.pop("provider", "openai")
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
api_key=api_key or os.getenv("OPENAI_API_KEY"),
|
||||
base_url=base_url,
|
||||
timeout=timeout,
|
||||
provider=provider,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.client = OpenAI(
|
||||
api_key=api_key or os.getenv("OPENAI_API_KEY"),
|
||||
base_url=base_url,
|
||||
organization=organization,
|
||||
project=project,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
|
||||
self.top_p = top_p
|
||||
self.frequency_penalty = frequency_penalty
|
||||
self.presence_penalty = presence_penalty
|
||||
self.max_tokens = max_tokens
|
||||
self.max_completion_tokens = max_completion_tokens
|
||||
self.seed = seed
|
||||
self.stream = stream
|
||||
self.response_format = response_format
|
||||
self.logprobs = logprobs
|
||||
self.top_logprobs = top_logprobs
|
||||
self.reasoning_effort = reasoning_effort
|
||||
self.timeout = timeout
|
||||
self.is_o1_model = "o1" in model.lower()
|
||||
self.is_gpt4_model = "gpt-4" in model.lower()
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: str | list[dict[str, str]],
|
||||
tools: list[dict] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str | Any:
|
||||
"""Call OpenAI chat completion API.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the chat completion
|
||||
tools: list of tool/function definitions
|
||||
callbacks: Callback functions (not used in native implementation)
|
||||
available_functions: Available functions for tool calling
|
||||
from_task: Task that initiated the call
|
||||
from_agent: Agent that initiated the call
|
||||
|
||||
Returns:
|
||||
Chat completion response or tool call result
|
||||
"""
|
||||
try:
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
formatted_messages = self._format_messages(messages)
|
||||
|
||||
completion_params = self._prepare_completion_params(
|
||||
formatted_messages, tools
|
||||
)
|
||||
|
||||
if self.stream:
|
||||
return self._handle_streaming_completion(
|
||||
completion_params, available_functions, from_task, from_agent
|
||||
)
|
||||
|
||||
return self._handle_completion(
|
||||
completion_params, available_functions, from_task, from_agent
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"OpenAI API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
|
||||
def _prepare_completion_params(
|
||||
self, messages: list[dict[str, str]], tools: list[dict] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Prepare parameters for OpenAI chat completion."""
|
||||
params = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"stream": self.stream,
|
||||
}
|
||||
|
||||
params.update(self.additional_params)
|
||||
|
||||
if self.temperature is not None:
|
||||
params["temperature"] = self.temperature
|
||||
if self.top_p is not None:
|
||||
params["top_p"] = self.top_p
|
||||
if self.frequency_penalty is not None:
|
||||
params["frequency_penalty"] = self.frequency_penalty
|
||||
if self.presence_penalty is not None:
|
||||
params["presence_penalty"] = self.presence_penalty
|
||||
if self.max_completion_tokens is not None:
|
||||
params["max_completion_tokens"] = self.max_completion_tokens
|
||||
elif self.max_tokens is not None:
|
||||
params["max_tokens"] = self.max_tokens
|
||||
if self.seed is not None:
|
||||
params["seed"] = self.seed
|
||||
if self.logprobs is not None:
|
||||
params["logprobs"] = self.logprobs
|
||||
if self.top_logprobs is not None:
|
||||
params["top_logprobs"] = self.top_logprobs
|
||||
|
||||
# Handle o1 model specific parameters
|
||||
if self.is_o1_model and self.reasoning_effort:
|
||||
params["reasoning_effort"] = self.reasoning_effort
|
||||
|
||||
# Handle response format for structured outputs
|
||||
if self.response_format:
|
||||
if isinstance(self.response_format, type) and issubclass(
|
||||
self.response_format, BaseModel
|
||||
):
|
||||
# Convert Pydantic model to OpenAI response format
|
||||
params["response_format"] = {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": self.response_format.__name__,
|
||||
"schema": self.response_format.model_json_schema(),
|
||||
},
|
||||
}
|
||||
else:
|
||||
params["response_format"] = self.response_format
|
||||
|
||||
if tools:
|
||||
params["tools"] = self._convert_tools_for_interference(tools)
|
||||
params["tool_choice"] = "auto"
|
||||
|
||||
# Filter out CrewAI-specific parameters that shouldn't go to the API
|
||||
crewai_specific_params = {
|
||||
"callbacks",
|
||||
"available_functions",
|
||||
"from_task",
|
||||
"from_agent",
|
||||
"provider",
|
||||
"api_key",
|
||||
"base_url",
|
||||
"timeout",
|
||||
"max_retries",
|
||||
}
|
||||
|
||||
return {k: v for k, v in params.items() if k not in crewai_specific_params}
|
||||
|
||||
def _convert_tools_for_interference(self, tools: list[dict]) -> list[dict]:
|
||||
"""Convert CrewAI tool format to OpenAI function calling format."""
|
||||
from crewai.llms.providers.utils.common import safe_tool_conversion
|
||||
|
||||
openai_tools = []
|
||||
|
||||
for tool in tools:
|
||||
name, description, parameters = safe_tool_conversion(tool, "OpenAI")
|
||||
|
||||
openai_tool = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": name,
|
||||
"description": description,
|
||||
},
|
||||
}
|
||||
|
||||
if parameters:
|
||||
if isinstance(parameters, dict):
|
||||
openai_tool["function"]["parameters"] = parameters # type: ignore
|
||||
else:
|
||||
openai_tool["function"]["parameters"] = dict(parameters)
|
||||
|
||||
openai_tools.append(openai_tool)
|
||||
return openai_tools
|
||||
|
||||
def _handle_completion(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle non-streaming chat completion."""
|
||||
try:
|
||||
response: ChatCompletion = self.client.chat.completions.create(**params)
|
||||
|
||||
usage = self._extract_openai_token_usage(response)
|
||||
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
choice: Choice = response.choices[0]
|
||||
message = choice.message
|
||||
|
||||
if message.tool_calls and available_functions:
|
||||
tool_call = message.tool_calls[0]
|
||||
function_name = tool_call.function.name
|
||||
|
||||
try:
|
||||
function_args = json.loads(tool_call.function.arguments)
|
||||
except json.JSONDecodeError as e:
|
||||
logging.error(f"Failed to parse tool arguments: {e}")
|
||||
function_args = {}
|
||||
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
content = message.content or ""
|
||||
content = self._apply_stop_words(content)
|
||||
|
||||
if self.response_format and isinstance(self.response_format, type):
|
||||
try:
|
||||
structured_result = self._validate_structured_output(
|
||||
content, self.response_format
|
||||
)
|
||||
self._emit_call_completed_event(
|
||||
response=structured_result,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
return structured_result
|
||||
except ValueError as e:
|
||||
logging.warning(f"Structured output validation failed: {e}")
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=content,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
if usage.get("total_tokens", 0) > 0:
|
||||
logging.info(f"OpenAI API usage: {usage}")
|
||||
except Exception as e:
|
||||
if is_context_length_exceeded(e):
|
||||
logging.error(f"Context window exceeded: {e}")
|
||||
raise LLMContextLengthExceededError(str(e)) from e
|
||||
raise e from e
|
||||
|
||||
return content
|
||||
|
||||
def _handle_streaming_completion(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str:
|
||||
"""Handle streaming chat completion."""
|
||||
full_response = ""
|
||||
tool_calls = {}
|
||||
|
||||
# Make streaming API call
|
||||
stream: Iterator[ChatCompletionChunk] = self.client.chat.completions.create(
|
||||
**params
|
||||
)
|
||||
|
||||
for chunk in stream:
|
||||
if not chunk.choices:
|
||||
continue
|
||||
|
||||
choice = chunk.choices[0]
|
||||
delta: ChoiceDelta = choice.delta
|
||||
|
||||
# Handle content streaming
|
||||
if delta.content:
|
||||
full_response += delta.content
|
||||
self._emit_stream_chunk_event(
|
||||
chunk=delta.content,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
# Handle tool call streaming
|
||||
if delta.tool_calls:
|
||||
for tool_call in delta.tool_calls:
|
||||
call_id = tool_call.id or "default"
|
||||
if call_id not in tool_calls:
|
||||
tool_calls[call_id] = {
|
||||
"name": "",
|
||||
"arguments": "",
|
||||
}
|
||||
|
||||
if tool_call.function and tool_call.function.name:
|
||||
tool_calls[call_id]["name"] = tool_call.function.name
|
||||
if tool_call.function and tool_call.function.arguments:
|
||||
tool_calls[call_id]["arguments"] += tool_call.function.arguments
|
||||
|
||||
if tool_calls and available_functions:
|
||||
for call_data in tool_calls.values():
|
||||
function_name = call_data["name"]
|
||||
arguments = call_data["arguments"]
|
||||
|
||||
# Skip if function name is empty or arguments are empty
|
||||
if not function_name or not arguments:
|
||||
continue
|
||||
|
||||
# Check if function exists in available functions
|
||||
if function_name not in available_functions:
|
||||
logging.warning(
|
||||
f"Function '{function_name}' not found in available functions"
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
function_args = json.loads(arguments)
|
||||
except json.JSONDecodeError as e:
|
||||
logging.error(f"Failed to parse streamed tool arguments: {e}")
|
||||
continue
|
||||
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Apply stop words to full response
|
||||
full_response = self._apply_stop_words(full_response)
|
||||
|
||||
# Emit completion event and return full response
|
||||
self._emit_call_completed_event(
|
||||
response=full_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
return full_response
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Check if the model supports function calling."""
|
||||
return not self.is_o1_model
|
||||
|
||||
def supports_stop_words(self) -> bool:
|
||||
"""Check if the model supports stop words."""
|
||||
return not self.is_o1_model
|
||||
|
||||
def get_context_window_size(self) -> int:
|
||||
"""Get the context window size for the model."""
|
||||
from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
|
||||
|
||||
min_context = 1024
|
||||
max_context = 2097152
|
||||
|
||||
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
|
||||
if value < min_context or value > max_context:
|
||||
raise ValueError(
|
||||
f"Context window for {key} must be between {min_context} and {max_context}"
|
||||
)
|
||||
|
||||
# Context window sizes for OpenAI models
|
||||
context_windows = {
|
||||
"gpt-4": 8192,
|
||||
"gpt-4o": 128000,
|
||||
"gpt-4o-mini": 200000,
|
||||
"gpt-4-turbo": 128000,
|
||||
"gpt-4.1": 1047576,
|
||||
"gpt-4.1-mini-2025-04-14": 1047576,
|
||||
"gpt-4.1-nano-2025-04-14": 1047576,
|
||||
"o1-preview": 128000,
|
||||
"o1-mini": 128000,
|
||||
"o3-mini": 200000,
|
||||
"o4-mini": 200000,
|
||||
}
|
||||
|
||||
# Find the best match for the model name
|
||||
for model_prefix, size in context_windows.items():
|
||||
if self.model.startswith(model_prefix):
|
||||
return int(size * CONTEXT_WINDOW_USAGE_RATIO)
|
||||
|
||||
# Default context window size
|
||||
return int(8192 * CONTEXT_WINDOW_USAGE_RATIO)
|
||||
|
||||
def _extract_openai_token_usage(self, response: ChatCompletion) -> dict[str, Any]:
|
||||
"""Extract token usage from OpenAI ChatCompletion response."""
|
||||
if hasattr(response, "usage") and response.usage:
|
||||
usage = response.usage
|
||||
return {
|
||||
"prompt_tokens": getattr(usage, "prompt_tokens", 0),
|
||||
"completion_tokens": getattr(usage, "completion_tokens", 0),
|
||||
"total_tokens": getattr(usage, "total_tokens", 0),
|
||||
}
|
||||
return {"total_tokens": 0}
|
||||
|
||||
def _format_messages(
|
||||
self, messages: str | list[dict[str, str]]
|
||||
) -> list[dict[str, str]]:
|
||||
"""Format messages for OpenAI API."""
|
||||
# Use base class formatting first
|
||||
base_formatted = super()._format_messages(messages)
|
||||
|
||||
# Apply OpenAI-specific formatting
|
||||
formatted_messages = []
|
||||
|
||||
for message in base_formatted:
|
||||
if self.is_o1_model and message.get("role") == "system":
|
||||
formatted_messages.append(
|
||||
{"role": "user", "content": f"System: {message['content']}"}
|
||||
)
|
||||
else:
|
||||
formatted_messages.append(message)
|
||||
|
||||
return formatted_messages
|
||||
136
lib/crewai/src/crewai/llms/providers/utils/common.py
Normal file
136
lib/crewai/src/crewai/llms/providers/utils/common.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
|
||||
def validate_function_name(name: str, provider: str = "LLM") -> str:
|
||||
"""Validate function name according to common LLM provider requirements.
|
||||
|
||||
Most LLM providers (OpenAI, Gemini, Anthropic) have similar requirements:
|
||||
- Must start with letter or underscore
|
||||
- Only alphanumeric, underscore, dot, colon, dash allowed
|
||||
- Maximum length of 64 characters
|
||||
- Cannot be empty
|
||||
|
||||
Args:
|
||||
name: The function name to validate
|
||||
provider: The provider name for error messages
|
||||
|
||||
Returns:
|
||||
The validated function name (unchanged if valid)
|
||||
|
||||
Raises:
|
||||
ValueError: If the function name is invalid
|
||||
"""
|
||||
if not name or not isinstance(name, str):
|
||||
raise ValueError(f"{provider} function name cannot be empty")
|
||||
|
||||
if not (name[0].isalpha() or name[0] == "_"):
|
||||
raise ValueError(
|
||||
f"{provider} function name '{name}' must start with a letter or underscore"
|
||||
)
|
||||
|
||||
if len(name) > 64:
|
||||
raise ValueError(
|
||||
f"{provider} function name '{name}' exceeds 64 character limit"
|
||||
)
|
||||
|
||||
# Check for invalid characters (most providers support these)
|
||||
if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_.\-:]*$", name):
|
||||
raise ValueError(
|
||||
f"{provider} function name '{name}' contains invalid characters. "
|
||||
f"Only letters, numbers, underscore, dot, colon, dash allowed"
|
||||
)
|
||||
|
||||
return name
|
||||
|
||||
|
||||
def extract_tool_info(tool: dict[str, Any]) -> tuple[str, str, dict[str, Any]]:
|
||||
"""Extract tool information from various schema formats.
|
||||
|
||||
Handles both OpenAI/standard format and direct format:
|
||||
- OpenAI format: {"type": "function", "function": {"name": "...", ...}}
|
||||
- Direct format: {"name": "...", "description": "...", ...}
|
||||
|
||||
Args:
|
||||
tool: Tool dictionary in any supported format
|
||||
|
||||
Returns:
|
||||
Tuple of (name, description, parameters)
|
||||
|
||||
Raises:
|
||||
ValueError: If tool format is invalid
|
||||
"""
|
||||
if not isinstance(tool, dict):
|
||||
raise ValueError("Tool must be a dictionary")
|
||||
|
||||
# Handle nested function schema format (OpenAI/standard)
|
||||
if "function" in tool:
|
||||
function_info = tool["function"]
|
||||
if not isinstance(function_info, dict):
|
||||
raise ValueError("Tool function must be a dictionary")
|
||||
|
||||
name = function_info.get("name", "")
|
||||
description = function_info.get("description", "")
|
||||
parameters = function_info.get("parameters", {})
|
||||
else:
|
||||
# Direct format
|
||||
name = tool.get("name", "")
|
||||
description = tool.get("description", "")
|
||||
parameters = tool.get("parameters", {})
|
||||
|
||||
# Also check for args_schema (Pydantic format)
|
||||
if not parameters and "args_schema" in tool:
|
||||
if hasattr(tool["args_schema"], "model_json_schema"):
|
||||
parameters = tool["args_schema"].model_json_schema()
|
||||
|
||||
return name, description, parameters
|
||||
|
||||
|
||||
def log_tool_conversion(tool: dict[str, Any], provider: str) -> None:
|
||||
"""Log tool conversion for debugging.
|
||||
|
||||
Args:
|
||||
tool: The tool being converted
|
||||
provider: The provider name
|
||||
"""
|
||||
try:
|
||||
name, description, parameters = extract_tool_info(tool)
|
||||
logging.debug(
|
||||
f"{provider}: Converting tool '{name}' (desc: {description[:50]}...)"
|
||||
)
|
||||
logging.debug(f"{provider}: Tool parameters: {parameters}")
|
||||
except Exception as e:
|
||||
logging.error(f"{provider}: Error extracting tool info: {e}")
|
||||
logging.error(f"{provider}: Tool structure: {tool}")
|
||||
|
||||
|
||||
def safe_tool_conversion(
|
||||
tool: dict[str, Any], provider: str
|
||||
) -> tuple[str, str, dict[str, Any]]:
|
||||
"""Safely extract and validate tool information.
|
||||
|
||||
Combines extraction, validation, and logging for robust tool conversion.
|
||||
|
||||
Args:
|
||||
tool: Tool dictionary to convert
|
||||
provider: Provider name for error messages and logging
|
||||
|
||||
Returns:
|
||||
Tuple of (validated_name, description, parameters)
|
||||
|
||||
Raises:
|
||||
ValueError: If tool is invalid or name validation fails
|
||||
"""
|
||||
try:
|
||||
log_tool_conversion(tool, provider)
|
||||
|
||||
name, description, parameters = extract_tool_info(tool)
|
||||
|
||||
validated_name = validate_function_name(name, provider)
|
||||
|
||||
logging.info(f"{provider}: Successfully validated tool '{validated_name}'")
|
||||
return validated_name, description, parameters
|
||||
except Exception as e:
|
||||
logging.error(f"{provider}: Error converting tool: {e}")
|
||||
raise
|
||||
@@ -1,10 +1,10 @@
|
||||
import ast
|
||||
import datetime
|
||||
import json
|
||||
import time
|
||||
from difflib import SequenceMatcher
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
from textwrap import dedent
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
import json5
|
||||
@@ -29,6 +29,7 @@ from crewai.utilities.agent_utils import (
     render_text_description_and_args,
 )


 if TYPE_CHECKING:
     from crewai.agents.agent_builder.base_agent import BaseAgent
     from crewai.lite_agent import LiteAgent
@@ -587,7 +588,23 @@ class ToolUsage:
         e: Exception,
     ) -> None:
         event_data = self._prepare_event_data(tool, tool_calling)
-        crewai_event_bus.emit(self, ToolUsageErrorEvent(**{**event_data, "error": e}))
+        event_data.update(
+            {
+                "task_id": str(self.task.id) if self.task else None,
+                "task_name": self.task.name or self.task.description
+                if self.task
+                else None,
+            }
+        )
+        crewai_event_bus.emit(
+            self,
+            ToolUsageErrorEvent(
+                **{
+                    **event_data,
+                    "error": e,
+                }
+            ),
+        )

     def on_tool_use_finished(
         self,
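
For context, a hypothetical listener for the enriched error event could look like the sketch below; it reads only the fields visible in the hunk above (error, task_id, task_name), and the import paths are assumptions, not part of this change.

# Hypothetical listener sketch; not part of the diff above.
from crewai.events import crewai_event_bus  # import path assumed
from crewai.events.types.tool_usage_events import ToolUsageErrorEvent  # path assumed


@crewai_event_bus.on(ToolUsageErrorEvent)
def log_tool_error(source, event) -> None:
    # task_id and task_name are the fields added to the payload in the hunk above.
    print(f"Tool failed on task {event.task_id} ({event.task_name}): {event.error}")
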
@@ -1,8 +1,8 @@
 from __future__ import annotations

-from collections.abc import Callable, Sequence
 import json
 import re
+from collections.abc import Callable, Sequence
 from typing import TYPE_CHECKING, Any, Final, Literal, TypedDict

 from rich.console import Console
@@ -15,7 +15,6 @@ from crewai.agents.parser import (
     parse,
 )
 from crewai.cli.config import Settings
-from crewai.llm import LLM
 from crewai.llms.base_llm import BaseLLM
 from crewai.tools import BaseTool as CrewAITool
 from crewai.tools.base_tool import BaseTool
@@ -29,11 +28,14 @@ from crewai.utilities.i18n import I18N
 from crewai.utilities.printer import ColoredText, Printer
 from crewai.utilities.types import LLMMessage


 if TYPE_CHECKING:
     from crewai.agent import Agent
+    from crewai.llm import LLM
     from crewai.task import Task


 class SummaryContent(TypedDict):
     """Structure for summary content entries.
@@ -392,8 +394,10 @@ def is_context_length_exceeded(exception: Exception) -> bool:
     Returns:
         bool: True if the exception is due to context length exceeding
     """
-    return LLMContextLengthExceededError(str(exception))._is_context_limit_error(
-        str(exception)
+    return (
+        LLMContextLengthExceededError(str(exception))
+        ._is_context_limit_error(str(exception))
     )
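
A minimal sketch of how a caller might use this predicate around an LLM call is shown below; `call_llm`, the trimming strategy, and the import path are illustrative assumptions rather than part of the change.

# Illustrative only: call_llm and the trimming fallback are hypothetical,
# and the import path for the helper is assumed.
from crewai.utilities.agent_utils import is_context_length_exceeded


def invoke_with_context_guard(call_llm, messages: list[dict[str, str]]) -> str:
    try:
        return call_llm(messages)
    except Exception as e:
        if is_context_length_exceeded(e):
            # Context window exceeded: keep the first (system) message plus the
            # most recent turns and retry once instead of failing outright.
            trimmed = [messages[0], *messages[-4:]]
            return call_llm(trimmed)
        raise
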
@@ -6,6 +6,7 @@ from crewai.cli.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS
 from crewai.llm import LLM
+from crewai.llms.base_llm import BaseLLM


 logger = logging.getLogger(__name__)

@@ -42,7 +43,7 @@ def create_llm(
             or str(llm_value)
         )
         temperature: float | None = getattr(llm_value, "temperature", None)
-        max_tokens: int | None = getattr(llm_value, "max_tokens", None)
+        max_tokens: float | int | None = getattr(llm_value, "max_tokens", None)
         logprobs: int | None = getattr(llm_value, "logprobs", None)
         timeout: float | None = getattr(llm_value, "timeout", None)
         api_key: str | None = getattr(llm_value, "api_key", None)
@@ -59,6 +60,7 @@
             base_url=base_url,
             api_base=api_base,
         )

     except Exception as e:
         logger.debug(f"Error instantiating LLM from unknown object type: {e}")
         return None
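
Because the attributes above are read with getattr, create_llm effectively duck-types its input. A rough sketch of that, with a made-up config object and an assumed import path, could look like this:

# Rough sketch of the duck typing used by create_llm; MyLLMConfig is invented
# and the import path for create_llm is assumed.
from dataclasses import dataclass

from crewai.utilities.llm_utils import create_llm


@dataclass
class MyLLMConfig:
    model: str = "gpt-4o"
    temperature: float | None = 0.2
    max_tokens: float | int | None = 4096  # float now accepted per the hunk above
    timeout: float | None = 30.0
    api_key: str | None = None


llm = create_llm(MyLLMConfig())
if llm is None:
    # Per the hunk above, create_llm returns None when it cannot instantiate
    # an LLM from an unknown object type.
    raise RuntimeError("Could not build an LLM from the provided config")
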
@@ -117,6 +119,7 @@ def _llm_via_environment_or_fallback() -> LLM | None:
     elif api_base and not base_url:
         base_url = api_base

     # Initialize llm_params dictionary
     llm_params: dict[str, Any] = {
         "model": model,
         "temperature": temperature,
@@ -140,6 +143,11 @@ def _llm_via_environment_or_fallback() -> LLM | None:
         "callbacks": callbacks,
     }

+    unaccepted_attributes = [
+        "AWS_ACCESS_KEY_ID",
+        "AWS_SECRET_ACCESS_KEY",
+        "AWS_REGION_NAME",
+    ]
     set_provider = model_name.partition("/")[0] if "/" in model_name else "openai"

     if set_provider in ENV_VARS:
@@ -147,7 +155,7 @@ def _llm_via_environment_or_fallback() -> LLM | None:
         if isinstance(env_vars_for_provider, (list, tuple)):
             for env_var in env_vars_for_provider:
                 key_name = env_var.get("key_name")
-                if key_name and key_name not in UNACCEPTED_ATTRIBUTES:
+                if key_name and key_name not in unaccepted_attributes:
                     env_value = os.environ.get(key_name)
                     if env_value:
                         # Map environment variable names to recognized parameters
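
To make the fallback path above easier to follow, here is a self-contained sketch of the provider-prefix and credential-filtering logic; the ENV_VARS mapping is a simplified stand-in, not the real table from crewai.cli.constants, and the values are invented.

# Standalone sketch of the logic above; ENV_VARS here is a toy stand-in.
import os

ENV_VARS = {
    "openai": [{"key_name": "OPENAI_API_KEY"}],
    "bedrock": [
        {"key_name": "AWS_ACCESS_KEY_ID"},
        {"key_name": "AWS_SECRET_ACCESS_KEY"},
        {"key_name": "AWS_REGION_NAME"},
    ],
}
unaccepted_attributes = ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_REGION_NAME"]

model_name = os.environ.get("MODEL", "bedrock/anthropic.claude-3-sonnet")
# "provider/model" strings select the provider; bare model names default to openai.
set_provider = model_name.partition("/")[0] if "/" in model_name else "openai"

llm_params: dict[str, str] = {}
for env_var in ENV_VARS.get(set_provider, []):
    key_name = env_var.get("key_name")
    # AWS credentials are consumed by the provider SDK directly, so they are
    # filtered out rather than forwarded as plain LLM parameters.
    if key_name and key_name not in unaccepted_attributes:
        env_value = os.environ.get(key_name)
        if env_value:
            llm_params[key_name.lower()] = env_value
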
@@ -102,21 +102,18 @@ class AgentReasoning:
         try:
             output = self.__handle_agent_reasoning()

-            # Emit reasoning completed event
-            try:
-                crewai_event_bus.emit(
-                    self.agent,
-                    AgentReasoningCompletedEvent(
-                        agent_role=self.agent.role,
-                        task_id=str(self.task.id),
-                        plan=output.plan.plan,
-                        ready=output.plan.ready,
-                        attempt=1,
-                        from_task=self.task,
-                    ),
-                )
-            except Exception:  # noqa: S110
-                pass
+            crewai_event_bus.emit(
+                self.agent,
+                AgentReasoningCompletedEvent(
+                    agent_role=self.agent.role,
+                    task_id=str(self.task.id),
+                    plan=output.plan.plan,
+                    ready=output.plan.ready,
+                    attempt=1,
+                    from_task=self.task,
+                    from_agent=self.agent,
+                ),
+            )

             return output
         except Exception as e:
@@ -130,10 +127,11 @@ class AgentReasoning:
                         error=str(e),
                         attempt=1,
                         from_task=self.task,
                         from_agent=self.agent,
                     ),
                 )
-            except Exception:  # noqa: S110
-                pass
+            except Exception as e:
+                logging.error(f"Error emitting reasoning failed event: {e}")

             raise
@@ -4,10 +4,24 @@ This module provides a callback handler that tracks token usage
 for LLM API calls through the litellm library.
 """

-from typing import Any
+from typing import TYPE_CHECKING, Any
+
+
+if TYPE_CHECKING:
+    from litellm.integrations.custom_logger import CustomLogger
+    from litellm.types.utils import Usage
+else:
+    try:
+        from litellm.integrations.custom_logger import CustomLogger
+        from litellm.types.utils import Usage
+    except ImportError:
+
+        class CustomLogger:
+            """Fallback CustomLogger when litellm is not available."""
+
+        class Usage:
+            """Fallback Usage when litellm is not available."""
+
-from litellm.integrations.custom_logger import CustomLogger
-from litellm.types.utils import Usage

 from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
 from crewai.utilities.logger_utils import suppress_warnings
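
The guarded import above is an instance of a common optional-dependency pattern: type checkers always see the real litellm classes, while empty placeholder classes stand in at runtime when litellm is not installed. A generic sketch of the same idea, with an invented SDK name, is shown below.

# Generic sketch of the import-guard pattern above; some_optional_sdk and
# Client are invented names used only for illustration.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from some_optional_sdk import Client  # seen only by type checkers
else:
    try:
        from some_optional_sdk import Client
    except ImportError:

        class Client:
            """Placeholder so the module still imports without the SDK."""

Keeping the fallbacks as empty placeholders, rather than raising at import time, lets modules that merely reference the types load cleanly; code paths that actually need litellm can still fail later with a clearer error.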