Lorenze/native inference sdks (#3619)

* ruff linted

* using native sdks with litellm fallback

* drop exa

* drop print on completion

* Refactor LLM and utility functions for type consistency

- Updated `max_tokens` parameter in `LLM` class to accept `float` in addition to `int`.
- Modified `create_llm` function to ensure consistent type hints and return types, now returning `LLM | BaseLLM | None`.
- Adjusted type hints for various parameters in `create_llm` and `_llm_via_environment_or_fallback` functions for improved clarity and type safety.
- Enhanced test cases to reflect changes in type handling and ensure proper instantiation of LLM instances.

* fix agent_tests

* fix litellm tests and usage metrics handling

* drop print

* Refactor LLM event handling and improve test coverage

- Removed commented-out event emission for LLM call failures in `llm.py`.
- Added `from_agent` parameter to `CrewAgentExecutor` for better context in LLM responses.
- Enhanced test for LLM call failure to simulate OpenAI API failure and updated assertions for clarity.
- Updated agent and task ID assertions in tests to ensure they are consistently treated as strings (see the sketch below).
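
A minimal sketch of the id flattening described above, using `types.SimpleNamespace` stand-ins for a Task and an Agent (the real event classes perform the same coercion in `LLMEventBase.__init__`):

```python
from types import SimpleNamespace

# Hypothetical stand-ins for a Task and an Agent, purely for illustration.
task = SimpleNamespace(id="4f1c", name=None, description="Summarize the report")
agent = SimpleNamespace(id=7, role="Researcher")

# Mirrors LLMEventBase.__init__: ids are coerced to str and the source objects
# are dropped before the event data is stored.
data = {"from_task": task, "from_agent": agent}
if data.get("from_task"):
    t = data["from_task"]
    data["task_id"], data["task_name"] = str(t.id), t.name or t.description
    data["from_task"] = None
if data.get("from_agent"):
    a = data["from_agent"]
    data["agent_id"], data["agent_role"] = str(a.id), a.role
    data["from_agent"] = None

print(data["task_id"], data["agent_id"], data["agent_role"])  # 4f1c 7 Researcher
```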

* fix test_converter

* fixed tests/agents/test_agent.py

* Refactor LLM context length exception handling and improve provider integration

- Renamed `LLMContextLengthExceededException` to `LLMContextLengthExceededError` for clarity and consistency.
- Updated LLM class to pass the provider parameter correctly during initialization.
- Enhanced error handling in various LLM provider implementations to raise the new exception type (see the example below).
- Adjusted tests to reflect the updated exception name and ensure proper error handling in context length scenarios.
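
A hedged example of catching the renamed error; the import path matches the one used throughout this PR, while the raising function is a stand-in for a provider call:

```python
from crewai.utilities.exceptions.context_window_exceeding_exception import (
    LLMContextLengthExceededError,
)

def fake_provider_call() -> str:
    # Stand-in for a native SDK call; the real providers raise this error when
    # is_context_length_exceeded() detects a context-length failure.
    raise LLMContextLengthExceededError("prompt exceeds the model's context window")

try:
    fake_provider_call()
except LLMContextLengthExceededError as exc:
    print(f"caught: {exc}")  # callers can trim or summarize context and retry
```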

* Enhance LLM context window handling across providers

- Introduced CONTEXT_WINDOW_USAGE_RATIO to adjust context window sizes dynamically for Anthropic, Azure, Gemini, and OpenAI LLMs.
- Added validation for context window sizes in Azure and Gemini providers to ensure they fall within acceptable limits.
- Updated context window size calculations to use the new ratio, improving consistency and adaptability across different models.
- Removed hardcoded context window sizes in favor of ratio-based calculations for better flexibility (see the sketch below).
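
A minimal sketch of the ratio-based sizing. The per-model window sizes mirror the provider tables in this PR; the 0.75 ratio is an assumption for illustration, the real constant lives in `crewai.llm`:

```python
CONTEXT_WINDOW_USAGE_RATIO = 0.75  # assumed value, for illustration only

# Raw window sizes; the native providers keep per-model tables like this.
CONTEXT_WINDOWS = {
    "claude-3-5-sonnet": 200_000,
    "gpt-4o": 128_000,
}

def usable_context_window(model: str, default: int = 8_192) -> int:
    """Return the usable window: the raw size scaled by the usage ratio."""
    for prefix, size in CONTEXT_WINDOWS.items():
        if model.startswith(prefix):
            return int(size * CONTEXT_WINDOW_USAGE_RATIO)
    return int(default * CONTEXT_WINDOW_USAGE_RATIO)

print(usable_context_window("claude-3-5-sonnet-20241022"))  # 150000
```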

* fix test agent again

* fix test agent

* feat: add native LLM providers for Anthropic, Azure, and Gemini

- Introduced new completion implementations for Anthropic, Azure, and Gemini, integrating their respective SDKs (usage sketch below).
- Added utility functions for tool validation and extraction to support function calling across LLM providers.
- Enhanced context window management and token usage extraction for each provider.
- Created a common utility module for shared functionality among LLM providers.
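
A short usage sketch of the provider routing, assuming `LLM` is importable from the package root as in existing CrewAI code:

```python
from crewai import LLM  # assumption: exported at the package root

# "anthropic/<model>" routes to the native Anthropic SDK when the `anthropic`
# package is installed; if the native provider is missing or fails to
# initialize, the factory logs a warning and falls back to LiteLLM.
llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")

# Pin the LiteLLM code path explicitly, bypassing the native provider.
litellm_llm = LLM(model="anthropic/claude-3-5-sonnet-20241022", is_litellm=True)
```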

* chore: update dependencies and improve context management

- Removed direct dependency on `litellm` from the main dependencies and added it under extras for better modularity.
- Updated the `litellm` dependency specification to allow for greater flexibility in versioning.
- Refactored context length exception handling across various LLM providers to use a consistent error class.
- Enhanced platform-specific dependency markers for NVIDIA packages to ensure compatibility across different systems.

* refactor(tests): update LLM instantiation to include is_litellm flag in test cases

- Modified multiple test cases in test_llm.py to set the is_litellm parameter to True when instantiating the LLM class (see the example below).
- This change ensures that the tests are aligned with the latest LLM configuration requirements and improves consistency across test scenarios.
- Adjusted relevant assertions and comments to reflect the updated LLM behavior.
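
A sketch of the updated test pattern; the assertions in the actual suite differ, this only shows the flag:

```python
from crewai import LLM  # assumption: exported at the package root

def test_llm_pins_litellm_path():
    # is_litellm=True skips native-provider routing and uses the LiteLLM fallback.
    llm = LLM(model="gpt-4o", is_litellm=True)
    assert llm.is_litellm is True
    assert llm.model == "gpt-4o"
```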

* linter

* linted

* revert constants

* fix(tests): correct type hint in expected model description

- Updated the expected description in the test_generate_model_description_dict_field function to use 'Dict' instead of 'dict' for consistency with type hinting conventions.
- This change ensures that the test accurately reflects the expected output format for model descriptions.

* refactor(llm): enhance LLM instantiation and error handling

- Updated the LLM class to include validation for the model parameter, ensuring it is a non-empty string (see the test sketch below).
- Improved error handling by logging warnings when the native SDK fails, allowing for a fallback to LiteLLM.
- Adjusted the instantiation of LLM in test cases to consistently include the is_litellm flag, aligning with recent changes in LLM configuration.
- Modified relevant tests to reflect these updates, ensuring better coverage and accuracy in testing scenarios.
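
A test-style sketch of the new model validation, based on the `ValueError` raised in `LLM.__new__`; the import path is an assumption:

```python
import pytest

from crewai import LLM  # assumption: exported at the package root

def test_empty_model_is_rejected():
    # LLM.__new__ validates the model name before any provider routing.
    with pytest.raises(ValueError, match="Model must be a non-empty string"):
        LLM(model="")
```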

* fixed test

* refactor(llm): enhance token usage tracking and add copy methods

- Updated the LLM class to track token usage and log callbacks in streaming mode, improving monitoring capabilities.
- Introduced shallow and deep copy methods for the LLM instance, allowing for better management of LLM configurations and parameters (see the sketch below).
- Adjusted test cases to instantiate LLM with the is_litellm flag, ensuring alignment with recent changes in LLM configuration.
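
A hedged usage sketch of the new copy hooks and the per-instance usage summary; `LLM` is assumed to be importable from the package root:

```python
import copy

from crewai import LLM  # assumption: exported at the package root

llm = LLM(model="gpt-4o", is_litellm=True, temperature=0.2)

# __copy__/__deepcopy__ rebuild a fresh LLM with the same configuration.
clone = copy.copy(llm)
assert clone is not llm
assert clone.temperature == 0.2

# Token usage is tracked per instance and exposed as a UsageMetrics object.
print(llm.get_token_usage_summary())  # all counters start at zero
```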

* refactor(tests): reorganize imports and enhance error messages in test cases

- Cleaned up import statements in test_crew.py for better organization and readability.
- Enhanced error messages in test cases to use `re.escape` so expected messages containing regex metacharacters are matched literally (see the example below).
- Adjusted comments for clarity and consistency across test scenarios.
- Ensured that all necessary modules are imported correctly to avoid potential runtime issues.
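
An illustrative pytest snippet of the `re.escape` pattern described above; the error message here is hypothetical:

```python
import re

import pytest

def risky() -> None:
    raise ValueError("Invalid input: expected [a, b] (got c)")  # hypothetical message

def test_error_message_is_matched_literally():
    # re.escape keeps the brackets and parentheses literal inside pytest's regex match.
    with pytest.raises(ValueError, match=re.escape("expected [a, b] (got c)")):
        risky()
```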
Lorenze Jay authored on 2025-10-03 14:32:35 -07:00, committed by GitHub
parent 428810bd6f · commit 126b91eab3
77 changed files with 25026 additions and 493 deletions

View File

@@ -114,7 +114,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self.messages: list[dict[str, str]] = []
self.iterations = 0
self.log_error_after = 3
existing_stop = self.llm.stop or []
existing_stop = getattr(self.llm, "stop", [])
self.llm.stop = list(
set(
existing_stop + self.stop
@@ -192,6 +192,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
callbacks=self.callbacks,
printer=self._printer,
from_task=self.task,
from_agent=self.agent,
)
formatted_answer = process_llm_response(answer, self.use_stop_words)

View File

@@ -1,16 +1,16 @@
import asyncio
import json
import re
import uuid
import warnings
from collections.abc import Callable
from concurrent.futures import Future
from copy import copy as shallow_copy
from hashlib import md5
import json
import re
from typing import (
Any,
cast,
)
import uuid
import warnings
from opentelemetry import baggage
from opentelemetry.context import attach, detach
@@ -82,6 +82,7 @@ from crewai.utilities.planning_handler import CrewPlanner
from crewai.utilities.task_output_storage_handler import TaskOutputStorageHandler
from crewai.utilities.training_handler import CrewTrainingHandler
warnings.filterwarnings("ignore", category=SyntaxWarning, module="pysbd")
@@ -1353,13 +1354,34 @@ class Crew(FlowTrackable, BaseModel):
def calculate_usage_metrics(self) -> UsageMetrics:
"""Calculates and returns the usage metrics."""
total_usage_metrics = UsageMetrics()
for agent in self.agents:
if hasattr(agent, "_token_process"):
token_sum = agent._token_process.get_summary()
total_usage_metrics.add_usage_metrics(token_sum)
if isinstance(agent.llm, BaseLLM):
llm_usage = agent.llm.get_token_usage_summary()
total_usage_metrics.add_usage_metrics(llm_usage)
else:
# fallback litellm
if hasattr(agent, "_token_process"):
token_sum = agent._token_process.get_summary()
total_usage_metrics.add_usage_metrics(token_sum)
if self.manager_agent and hasattr(self.manager_agent, "_token_process"):
token_sum = self.manager_agent._token_process.get_summary()
total_usage_metrics.add_usage_metrics(token_sum)
if (
self.manager_agent
and hasattr(self.manager_agent, "llm")
and hasattr(self.manager_agent.llm, "get_token_usage_summary")
):
if isinstance(self.manager_agent.llm, BaseLLM):
llm_usage = self.manager_agent.llm.get_token_usage_summary()
else:
llm_usage = self.manager_agent.llm._token_process.get_summary()
total_usage_metrics.add_usage_metrics(llm_usage)
self.usage_metrics = total_usage_metrics
return total_usage_metrics

View File

@@ -17,6 +17,11 @@ class BaseEvent(BaseModel):
)
fingerprint_metadata: dict[str, Any] | None = None # Any relevant metadata
task_id: str | None = None
task_name: str | None = None
agent_id: str | None = None
agent_role: str | None = None
def to_json(self, exclude: set[str] | None = None):
"""
Converts the event to a JSON-serializable dictionary.
@@ -31,7 +36,7 @@ class BaseEvent(BaseModel):
def _set_task_params(self, data: dict[str, Any]):
if "from_task" in data and (task := data["from_task"]):
self.task_id = task.id
self.task_id = str(task.id)
self.task_name = task.name or task.description
self.from_task = None
@@ -42,6 +47,6 @@ class BaseEvent(BaseModel):
if not agent:
return
self.agent_id = agent.id
self.agent_id = str(agent.id)
self.agent_role = agent.role
self.from_agent = None

View File

@@ -7,19 +7,23 @@ from crewai.events.base_events import BaseEvent
class LLMEventBase(BaseEvent):
task_name: str | None = None
task_id: str | None = None
agent_id: str | None = None
agent_role: str | None = None
from_task: Any | None = None
from_agent: Any | None = None
def __init__(self, **data):
if data.get("from_task"):
task = data["from_task"]
data["task_id"] = str(task.id)
data["task_name"] = task.name or task.description
data["from_task"] = None
if data.get("from_agent"):
agent = data["from_agent"]
data["agent_id"] = str(agent.id)
data["agent_role"] = agent.role
data["from_agent"] = None
super().__init__(**data)
self._set_agent_params(data)
self._set_task_params(data)
class LLMCallType(Enum):

View File

@@ -27,9 +27,20 @@ class ToolUsageEvent(BaseEvent):
model_config = ConfigDict(arbitrary_types_allowed=True)
def __init__(self, **data):
if data.get("from_task"):
task = data["from_task"]
data["task_id"] = str(task.id)
data["task_name"] = task.name or task.description
data["from_task"] = None
if data.get("from_agent"):
agent = data["from_agent"]
data["agent_id"] = str(agent.id)
data["agent_role"] = agent.role
data["from_agent"] = None
super().__init__(**data)
self._set_agent_params(data)
self._set_task_params(data)
# Set fingerprint data from the agent
if self.agent and hasattr(self.agent, "fingerprint") and self.agent.fingerprint:
self.source_fingerprint = self.agent.fingerprint.uuid_str

View File

@@ -1,13 +1,13 @@
import asyncio
import inspect
import uuid
from collections.abc import Callable
import inspect
from typing import (
Any,
cast,
get_args,
get_origin,
)
import uuid
from pydantic import (
UUID4,
@@ -351,7 +351,10 @@ class LiteAgent(FlowTrackable, BaseModel):
)
# Calculate token usage metrics
usage_metrics = self._token_process.get_summary()
if isinstance(self.llm, BaseLLM):
usage_metrics = self.llm.get_token_usage_summary()
else:
usage_metrics = self._token_process.get_summary()
# Create output
output = LiteAgentOutput(
@@ -400,7 +403,10 @@ class LiteAgent(FlowTrackable, BaseModel):
elif isinstance(guardrail_result.result, BaseModel):
output.pydantic = guardrail_result.result
usage_metrics = self._token_process.get_summary()
if isinstance(self.llm, BaseLLM):
usage_metrics = self.llm.get_token_usage_summary()
else:
usage_metrics = self._token_process.get_summary()
output.usage_metrics = usage_metrics.model_dump() if usage_metrics else None
# Emit completion event

View File

@@ -1,13 +1,14 @@
from collections import defaultdict
from collections.abc import Callable
from datetime import datetime
import io
import json
import logging
import os
import sys
import threading
from collections import defaultdict
from collections.abc import Callable
from datetime import datetime
from typing import (
TYPE_CHECKING,
Any,
Final,
Literal,
@@ -17,7 +18,6 @@ from typing import (
)
from dotenv import load_dotenv
from litellm.types.utils import ChatCompletionDeltaToolCall
from pydantic import BaseModel, Field
from crewai.events.event_bus import crewai_event_bus
@@ -39,19 +39,42 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
)
from crewai.utilities.logger_utils import suppress_warnings
with suppress_warnings():
if TYPE_CHECKING:
from litellm import Choices
from litellm.exceptions import ContextWindowExceededError
from litellm.litellm_core_utils.get_supported_openai_params import (
get_supported_openai_params,
)
from litellm.types.utils import ChatCompletionDeltaToolCall, ModelResponse
from litellm.utils import supports_response_schema
try:
import litellm
from litellm import Choices, CustomLogger
from litellm.exceptions import ContextWindowExceededError
from litellm.litellm_core_utils.get_supported_openai_params import (
get_supported_openai_params,
)
from litellm.types.utils import ModelResponse
from litellm.types.utils import ChatCompletionDeltaToolCall, ModelResponse
from litellm.utils import supports_response_schema
LITELLM_AVAILABLE = True
except ImportError:
LITELLM_AVAILABLE = False
litellm = None # type: ignore
Choices = None # type: ignore
ContextWindowExceededError = Exception # type: ignore
get_supported_openai_params = None # type: ignore
ChatCompletionDeltaToolCall = None # type: ignore
ModelResponse = None # type: ignore
supports_response_schema = None # type: ignore
load_dotenv()
litellm.suppress_debug_info = True
if LITELLM_AVAILABLE:
litellm.suppress_debug_info = True
class FilteredStream(io.TextIOBase):
@@ -275,6 +298,77 @@ class AccumulatedToolArgs(BaseModel):
class LLM(BaseLLM):
completion_cost: float | None = None
def __new__(cls, model: str, is_litellm: bool = False, **kwargs) -> "LLM":
"""Factory method that routes to native SDK or falls back to LiteLLM."""
if not model or not isinstance(model, str):
raise ValueError("Model must be a non-empty string")
provider = model.partition("/")[0] if "/" in model else "openai"
native_class = cls._get_native_provider(provider)
if native_class and not is_litellm:
try:
model_string = model.partition("/")[2] if "/" in model else model
return native_class(model=model_string, provider=provider, **kwargs)
except Exception as e:
import logging
logger = logging.getLogger(__name__)
logger.warning(
f"Native SDK failed for {provider}: {e}, falling back to LiteLLM"
)
# FALLBACK to LiteLLM
if not LITELLM_AVAILABLE:
raise ImportError(
"Please install the required dependencies:\n"
"- For LiteLLM: uv add litellm"
)
instance = object.__new__(cls)
super(LLM, instance).__init__(model=model, is_litellm=True, **kwargs)
instance.is_litellm = True
return instance
@classmethod
def _get_native_provider(cls, provider: str) -> type | None:
"""Get native provider class if available."""
if provider == "openai":
try:
from crewai.llms.providers.openai.completion import OpenAICompletion
return OpenAICompletion
except ImportError:
return None
elif provider == "anthropic" or provider == "claude":
try:
from crewai.llms.providers.anthropic.completion import (
AnthropicCompletion,
)
return AnthropicCompletion
except ImportError:
return None
elif provider == "azure":
try:
from crewai.llms.providers.azure.completion import AzureCompletion
return AzureCompletion
except ImportError:
return None
elif provider == "google" or provider == "gemini":
try:
from crewai.llms.providers.gemini.completion import GeminiCompletion
return GeminiCompletion
except ImportError:
return None
return None
def __init__(
self,
model: str,
@@ -284,7 +378,7 @@ class LLM(BaseLLM):
n: int | None = None,
stop: str | list[str] | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
max_tokens: int | float | None = None,
presence_penalty: float | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[int, float] | None = None,
@@ -301,6 +395,11 @@ class LLM(BaseLLM):
stream: bool = False,
**kwargs,
):
"""Initialize LLM instance.
Note: This __init__ method is only called for fallback instances.
Native provider instances handle their own initialization in their respective classes.
"""
self.model = model
self.timeout = timeout
self.temperature = temperature
@@ -328,7 +427,7 @@ class LLM(BaseLLM):
litellm.drop_params = True
# Normalize self.stop to always be a List[str]
# Normalize self.stop to always be a list[str]
if stop is None:
self.stop: list[str] = []
elif isinstance(stop, str):
@@ -349,7 +448,8 @@ class LLM(BaseLLM):
Returns:
bool: True if the model is from Anthropic, False otherwise.
"""
return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES)
anthropic_prefixes = ("anthropic/", "claude-", "claude/")
return any(prefix in model.lower() for prefix in anthropic_prefixes)
def _prepare_completion_params(
self,
@@ -514,10 +614,6 @@ class LLM(BaseLLM):
# Add the chunk content to the full response
full_response += chunk_content
# Emit the chunk event
if not hasattr(crewai_event_bus, "emit"):
raise Exception("crewai_event_bus must have an `emit` method")
crewai_event_bus.emit(
self,
event=LLMStreamChunkEvent(
@@ -623,7 +719,9 @@ class LLM(BaseLLM):
# --- 8) If no tool calls or no available functions, return the text response directly
if not tool_calls or not available_functions:
# Log token usage if available in streaming mode
# Track token usage and log callbacks if available in streaming mode
if usage_info:
self._track_token_usage_internal(usage_info)
self._handle_streaming_callbacks(callbacks, usage_info, last_chunk)
# Emit completion event and return response
self._handle_emit_call_events(
@@ -640,7 +738,9 @@ class LLM(BaseLLM):
if tool_result is not None:
return tool_result
# --- 10) Log token usage if available in streaming mode
# --- 10) Track token usage and log callbacks if available in streaming mode
if usage_info:
self._track_token_usage_internal(usage_info)
self._handle_streaming_callbacks(callbacks, usage_info, last_chunk)
# --- 11) Emit completion event and return response
@@ -671,11 +771,6 @@ class LLM(BaseLLM):
)
return full_response
# Emit failed event and re-raise the exception
if not hasattr(crewai_event_bus, "emit"):
raise AttributeError(
"crewai_event_bus must have an 'emit' method"
) from e
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(
@@ -702,8 +797,7 @@ class LLM(BaseLLM):
current_tool_accumulator.function.arguments += (
tool_call.function.arguments
)
if not hasattr(crewai_event_bus, "emit"):
raise AttributeError("crewai_event_bus must have an 'emit' method")
crewai_event_bus.emit(
self,
event=LLMStreamChunkEvent(
@@ -832,6 +926,7 @@ class LLM(BaseLLM):
messages=params["messages"],
)
return text_response
# --- 6) If there is no text response, no available functions, but there are tool calls, return the tool calls
if tool_calls and not available_functions and not text_response:
return tool_calls
@@ -886,9 +981,6 @@ class LLM(BaseLLM):
function_args = json.loads(tool_call.function.arguments)
fn = available_functions[function_name]
# --- 3.2) Execute function
if not hasattr(crewai_event_bus, "emit"):
raise AttributeError("crewai_event_bus must have an 'emit' method")
started_at = datetime.now()
crewai_event_bus.emit(
self,
@@ -928,10 +1020,6 @@ class LLM(BaseLLM):
function_name, lambda: None
) # Ensure fn is always a callable
logging.error(f"Error executing function '{function_name}': {e}")
if not hasattr(crewai_event_bus, "emit"):
raise AttributeError(
"crewai_event_bus must have an 'emit' method"
) from e
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(error=f"Tool execution error: {e!s}"),
@@ -982,9 +1070,6 @@ class LLM(BaseLLM):
ValueError: If response format is not supported
LLMContextLengthExceededError: If input exceeds model's context limit
"""
# --- 1) Emit call started event
if not hasattr(crewai_event_bus, "emit"):
raise AttributeError("crewai_event_bus must have an 'emit' method")
crewai_event_bus.emit(
self,
event=LLMCallStartedEvent(
@@ -1021,10 +1106,10 @@ class LLM(BaseLLM):
return self._handle_streaming_response(
params, callbacks, available_functions, from_task, from_agent
)
return self._handle_non_streaming_response(
params, callbacks, available_functions, from_task, from_agent
)
except LLMContextLengthExceededError:
# Re-raise LLMContextLengthExceededError as it should be handled
# by the CrewAgentExecutor._invoke_loop method, which can then decide
@@ -1057,10 +1142,6 @@ class LLM(BaseLLM):
from_agent=from_agent,
)
if not hasattr(crewai_event_bus, "emit"):
raise AttributeError(
"crewai_event_bus must have an 'emit' method"
) from e
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(
@@ -1086,8 +1167,6 @@ class LLM(BaseLLM):
from_agent: Optional agent object
messages: Optional messages object
"""
if not hasattr(crewai_event_bus, "emit"):
raise AttributeError("crewai_event_bus must have an 'emit' method")
crewai_event_bus.emit(
self,
event=LLMCallCompletedEvent(
@@ -1225,11 +1304,14 @@ class LLM(BaseLLM):
if self.context_window_size != 0:
return self.context_window_size
min_context = 1024
max_context = 2097152 # Current max from gemini-1.5-pro
# Validate all context window sizes
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
if value < MIN_CONTEXT or value > MAX_CONTEXT:
if value < min_context or value > max_context:
raise ValueError(
f"Context window for {key} must be between {MIN_CONTEXT} and {MAX_CONTEXT}"
f"Context window for {key} must be between {min_context} and {max_context}"
)
self.context_window_size = int(
@@ -1293,3 +1375,129 @@ class LLM(BaseLLM):
litellm.success_callback = success_callbacks
litellm.failure_callback = failure_callbacks
def __copy__(self):
"""Create a shallow copy of the LLM instance."""
# Filter out parameters that are already explicitly passed to avoid conflicts
filtered_params = {
k: v
for k, v in self.additional_params.items()
if k
not in [
"model",
"is_litellm",
"temperature",
"top_p",
"n",
"max_completion_tokens",
"max_tokens",
"presence_penalty",
"frequency_penalty",
"logit_bias",
"response_format",
"seed",
"logprobs",
"top_logprobs",
"base_url",
"api_base",
"api_version",
"api_key",
"callbacks",
"reasoning_effort",
"stream",
"stop",
]
}
# Create a new instance with the same parameters
return LLM(
model=self.model,
is_litellm=self.is_litellm,
temperature=self.temperature,
top_p=self.top_p,
n=self.n,
max_completion_tokens=self.max_completion_tokens,
max_tokens=self.max_tokens,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
logit_bias=self.logit_bias,
response_format=self.response_format,
seed=self.seed,
logprobs=self.logprobs,
top_logprobs=self.top_logprobs,
base_url=self.base_url,
api_base=self.api_base,
api_version=self.api_version,
api_key=self.api_key,
callbacks=self.callbacks,
reasoning_effort=self.reasoning_effort,
stream=self.stream,
stop=self.stop,
**filtered_params,
)
def __deepcopy__(self, memo):
"""Create a deep copy of the LLM instance."""
import copy
# Filter out parameters that are already explicitly passed to avoid conflicts
filtered_params = {
k: copy.deepcopy(v, memo)
for k, v in self.additional_params.items()
if k
not in [
"model",
"is_litellm",
"temperature",
"top_p",
"n",
"max_completion_tokens",
"max_tokens",
"presence_penalty",
"frequency_penalty",
"logit_bias",
"response_format",
"seed",
"logprobs",
"top_logprobs",
"base_url",
"api_base",
"api_version",
"api_key",
"callbacks",
"reasoning_effort",
"stream",
"stop",
]
}
# Create a new instance with the same parameters
return LLM(
model=self.model,
is_litellm=self.is_litellm,
temperature=self.temperature,
top_p=self.top_p,
n=self.n,
max_completion_tokens=self.max_completion_tokens,
max_tokens=self.max_tokens,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
logit_bias=copy.deepcopy(self.logit_bias, memo)
if self.logit_bias
else None,
response_format=copy.deepcopy(self.response_format, memo)
if self.response_format
else None,
seed=self.seed,
logprobs=self.logprobs,
top_logprobs=self.top_logprobs,
base_url=self.base_url,
api_base=self.api_base,
api_version=self.api_version,
api_key=self.api_key,
callbacks=copy.deepcopy(self.callbacks, memo) if self.callbacks else None,
reasoning_effort=self.reasoning_effort,
stream=self.stream,
stop=copy.deepcopy(self.stop, memo) if self.stop else None,
**filtered_params,
)

View File

@@ -1,12 +1,33 @@
"""Base LLM abstract class for CrewAI.
This module provides the abstract base class for all LLM implementations
in CrewAI.
in CrewAI, including common functionality for native SDK implementations.
"""
from abc import ABC, abstractmethod
from datetime import datetime
import json
import logging
from typing import Any, Final
from pydantic import BaseModel
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.llm_events import (
LLMCallCompletedEvent,
LLMCallFailedEvent,
LLMCallStartedEvent,
LLMCallType,
LLMStreamChunkEvent,
)
from crewai.events.types.tool_usage_events import (
ToolUsageErrorEvent,
ToolUsageFinishedEvent,
ToolUsageStartedEvent,
)
from crewai.types.usage_metrics import UsageMetrics
DEFAULT_CONTEXT_WINDOW_SIZE: Final[int] = 4096
DEFAULT_SUPPORTS_STOP_WORDS: Final[bool] = True
@@ -27,13 +48,20 @@ class BaseLLM(ABC):
model: The model identifier/name.
temperature: Optional temperature setting for response generation.
stop: A list of stop sequences that the LLM should use to stop generation.
additional_params: Additional provider-specific parameters.
"""
is_litellm: bool = False
def __init__(
self,
model: str,
temperature: float | None = None,
stop: list[str] | None = None,
api_key: str | None = None,
base_url: str | None = None,
timeout: float | None = None,
provider: str | None = None,
**kwargs,
) -> None:
"""Initialize the BaseLLM with default attributes.
@@ -41,10 +69,44 @@ class BaseLLM(ABC):
model: The model identifier/name.
temperature: Optional temperature setting for response generation.
stop: Optional list of stop sequences for generation.
**kwargs: Additional provider-specific parameters.
"""
if not model:
raise ValueError("Model name is required and cannot be empty")
self.model = model
self.temperature = temperature
self.stop: list[str] = stop or []
self.api_key = api_key
self.base_url = base_url
# Store additional parameters for provider-specific use
self.additional_params = kwargs
self._provider = provider or "openai"
stop = kwargs.pop("stop", None)
if stop is None:
self.stop: list[str] = []
elif isinstance(stop, str):
self.stop = [stop]
else:
self.stop = stop
self._token_usage = {
"total_tokens": 0,
"prompt_tokens": 0,
"completion_tokens": 0,
"successful_requests": 0,
"cached_prompt_tokens": 0,
}
@property
def provider(self) -> str:
"""Get the provider of the LLM."""
return self._provider
@provider.setter
def provider(self, value: str) -> None:
"""Set the provider of the LLM."""
self._provider = value
@abstractmethod
def call(
@@ -82,6 +144,17 @@ class BaseLLM(ABC):
RuntimeError: If the LLM request fails for other reasons.
"""
def _convert_tools_for_interference(self, tools: list[dict]) -> list[dict]:
"""Convert tools to a format that can be used for interference.
Args:
tools: List of tools to convert.
Returns:
List of converted tools (default implementation returns as-is)
"""
return tools
def supports_stop_words(self) -> bool:
"""Check if the LLM supports stop words.
@@ -90,6 +163,58 @@ class BaseLLM(ABC):
"""
return DEFAULT_SUPPORTS_STOP_WORDS
def _supports_stop_words_implementation(self) -> bool:
"""Check if stop words are configured for this LLM instance.
Native providers can override supports_stop_words() to return this value
to ensure consistent behavior based on whether stop words are actually configured.
Returns:
True if stop words are configured and can be applied
"""
return bool(self.stop)
def _apply_stop_words(self, content: str) -> str:
"""Apply stop words to truncate response content.
This method provides consistent stop word behavior across all native SDK providers.
Native providers should call this method to post-process their responses.
Args:
content: The raw response content from the LLM
Returns:
Content truncated at the first occurrence of any stop word
Example:
>>> llm = MyNativeLLM(stop=["Observation:", "Final Answer:"])
>>> response = "I need to search.\\n\\nAction: search\\nObservation: Found results"
>>> llm._apply_stop_words(response)
"I need to search.\\n\\nAction: search"
"""
if not self.stop or not content:
return content
# Find the earliest occurrence of any stop word
earliest_stop_pos = len(content)
found_stop_word = None
for stop_word in self.stop:
stop_pos = content.find(stop_word)
if stop_pos != -1 and stop_pos < earliest_stop_pos:
earliest_stop_pos = stop_pos
found_stop_word = stop_word
# Truncate at the stop word if found
if found_stop_word is not None:
truncated = content[:earliest_stop_pos].strip()
logging.debug(
f"Applied stop word '{found_stop_word}' at position {earliest_stop_pos}"
)
return truncated
return content
def get_context_window_size(self) -> int:
"""Get the context window size for the LLM.
@@ -98,3 +223,314 @@ class BaseLLM(ABC):
"""
# Default implementation - subclasses should override with model-specific values
return DEFAULT_CONTEXT_WINDOW_SIZE
# Common helper methods for native SDK implementations
def _emit_call_started_event(
self,
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> None:
"""Emit LLM call started event."""
if not hasattr(crewai_event_bus, "emit"):
raise ValueError("crewai_event_bus does not have an emit method") from None
crewai_event_bus.emit(
self,
event=LLMCallStartedEvent(
messages=messages,
tools=tools,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
model=self.model,
),
)
def _emit_call_completed_event(
self,
response: Any,
call_type: LLMCallType,
from_task: Any | None = None,
from_agent: Any | None = None,
messages: str | list[dict[str, Any]] | None = None,
) -> None:
"""Emit LLM call completed event."""
crewai_event_bus.emit(
self,
event=LLMCallCompletedEvent(
messages=messages,
response=response,
call_type=call_type,
from_task=from_task,
from_agent=from_agent,
model=self.model,
),
)
def _emit_call_failed_event(
self,
error: str,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> None:
"""Emit LLM call failed event."""
if not hasattr(crewai_event_bus, "emit"):
raise ValueError("crewai_event_bus does not have an emit method") from None
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(
error=error,
from_task=from_task,
from_agent=from_agent,
),
)
def _emit_stream_chunk_event(
self,
chunk: str,
from_task: Any | None = None,
from_agent: Any | None = None,
tool_call: dict[str, Any] | None = None,
) -> None:
"""Emit stream chunk event."""
if not hasattr(crewai_event_bus, "emit"):
raise ValueError("crewai_event_bus does not have an emit method") from None
crewai_event_bus.emit(
self,
event=LLMStreamChunkEvent(
chunk=chunk,
tool_call=tool_call,
from_task=from_task,
from_agent=from_agent,
),
)
def _handle_tool_execution(
self,
function_name: str,
function_args: dict[str, Any],
available_functions: dict[str, Any],
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str | None:
"""Handle tool execution with proper event emission.
Args:
function_name: Name of the function to execute
function_args: Arguments to pass to the function
available_functions: Dict of available functions
from_task: Optional task object
from_agent: Optional agent object
Returns:
Result of function execution or None if function not found
"""
if function_name not in available_functions:
logging.warning(
f"Function '{function_name}' not found in available functions"
)
return None
try:
# Emit tool usage started event
started_at = datetime.now()
crewai_event_bus.emit(
self,
event=ToolUsageStartedEvent(
tool_name=function_name,
tool_args=function_args,
from_agent=from_agent,
from_task=from_task,
),
)
# Execute the function
fn = available_functions[function_name]
result = fn(**function_args)
# Emit tool usage finished event
crewai_event_bus.emit(
self,
event=ToolUsageFinishedEvent(
output=result,
tool_name=function_name,
tool_args=function_args,
started_at=started_at,
finished_at=datetime.now(),
from_task=from_task,
from_agent=from_agent,
),
)
# Emit LLM call completed event for tool call
self._emit_call_completed_event(
response=result,
call_type=LLMCallType.TOOL_CALL,
from_task=from_task,
from_agent=from_agent,
)
return str(result)
except Exception as e:
error_msg = f"Error executing function '{function_name}': {e!s}"
logging.error(error_msg)
# Emit tool usage error event
if not hasattr(crewai_event_bus, "emit"):
raise ValueError(
"crewai_event_bus does not have an emit method"
) from None
crewai_event_bus.emit(
self,
event=ToolUsageErrorEvent(
tool_name=function_name,
tool_args=function_args,
error=error_msg,
from_task=from_task,
from_agent=from_agent,
),
)
# Emit LLM call failed event
self._emit_call_failed_event(
error=error_msg,
from_task=from_task,
from_agent=from_agent,
)
return None
def _format_messages(
self, messages: str | list[dict[str, str]]
) -> list[dict[str, str]]:
"""Convert messages to standard format.
Args:
messages: Input messages (string or list of message dicts)
Returns:
List of message dictionaries with 'role' and 'content' keys
Raises:
ValueError: If message format is invalid
"""
if isinstance(messages, str):
return [{"role": "user", "content": messages}]
# Validate message format
for i, msg in enumerate(messages):
if not isinstance(msg, dict):
raise ValueError(f"Message at index {i} must be a dictionary")
if "role" not in msg or "content" not in msg:
raise ValueError(
f"Message at index {i} must have 'role' and 'content' keys"
)
return messages
def _validate_structured_output(
self,
response: str,
response_format: type[BaseModel] | None,
) -> str | BaseModel:
"""Validate and parse structured output.
Args:
response: Raw response string
response_format: Optional Pydantic model for structured output
Returns:
Parsed response (BaseModel instance if response_format provided, otherwise string)
Raises:
ValueError: If structured output validation fails
"""
if response_format is None:
return response
try:
# Try to parse as JSON first
if response.strip().startswith("{") or response.strip().startswith("["):
data = json.loads(response)
return response_format.model_validate(data)
# Try to extract JSON from response
import re
json_match = re.search(r"\{.*\}", response, re.DOTALL)
if json_match:
data = json.loads(json_match.group())
return response_format.model_validate(data)
raise ValueError("No JSON found in response")
except (json.JSONDecodeError, ValueError) as e:
logging.warning(f"Failed to parse structured output: {e}")
raise ValueError(
f"Failed to parse response into {response_format.__name__}: {e}"
) from e
def _extract_provider(self, model: str) -> str:
"""Extract provider from model string.
Args:
model: Model string (e.g., 'openai/gpt-4' or 'gpt-4')
Returns:
Provider name (e.g., 'openai')
"""
if "/" in model:
return model.partition("/")[0]
return "openai" # Default provider
def _track_token_usage_internal(self, usage_data: dict[str, Any]) -> None:
"""Track token usage internally in the LLM instance.
Args:
usage_data: Token usage data from the API response
"""
# Extract tokens in a provider-agnostic way
prompt_tokens = (
usage_data.get("prompt_tokens")
or usage_data.get("prompt_token_count")
or usage_data.get("input_tokens")
or 0
)
completion_tokens = (
usage_data.get("completion_tokens")
or usage_data.get("candidates_token_count")
or usage_data.get("output_tokens")
or 0
)
cached_tokens = (
usage_data.get("cached_tokens")
or usage_data.get("cached_prompt_tokens")
or 0
)
self._token_usage["prompt_tokens"] += prompt_tokens
self._token_usage["completion_tokens"] += completion_tokens
self._token_usage["total_tokens"] += prompt_tokens + completion_tokens
self._token_usage["successful_requests"] += 1
self._token_usage["cached_prompt_tokens"] += cached_tokens
def get_token_usage_summary(self) -> UsageMetrics:
"""Get summary of token usage for this LLM instance.
Returns:
Dictionary with token usage totals
"""
return UsageMetrics(**self._token_usage)

View File

@@ -0,0 +1,432 @@
import json
import logging
import os
from typing import Any
from crewai.events.types.llm_events import LLMCallType
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededError,
)
try:
from anthropic import Anthropic
from anthropic.types import Message
from anthropic.types.tool_use_block import ToolUseBlock
except ImportError:
raise ImportError(
"Anthropic native provider not available, to install: `uv add anthropic`"
) from None
class AnthropicCompletion(BaseLLM):
"""Anthropic native completion implementation.
This class provides direct integration with the Anthropic Python SDK,
offering native tool use, streaming support, and proper message formatting.
"""
def __init__(
self,
model: str = "claude-3-5-sonnet-20241022",
api_key: str | None = None,
base_url: str | None = None,
timeout: float | None = None,
max_retries: int = 2,
temperature: float | None = None,
max_tokens: int = 4096, # Required for Anthropic
top_p: float | None = None,
stop_sequences: list[str] | None = None,
stream: bool = False,
**kwargs,
):
"""Initialize Anthropic chat completion client.
Args:
model: Anthropic model name (e.g., 'claude-3-5-sonnet-20241022')
api_key: Anthropic API key (defaults to ANTHROPIC_API_KEY env var)
base_url: Custom base URL for Anthropic API
timeout: Request timeout in seconds
max_retries: Maximum number of retries
temperature: Sampling temperature (0-1)
max_tokens: Maximum tokens in response (required for Anthropic)
top_p: Nucleus sampling parameter
stop_sequences: Stop sequences (Anthropic uses stop_sequences, not stop)
stream: Enable streaming responses
**kwargs: Additional parameters
"""
super().__init__(
model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
)
# Initialize Anthropic client
self.client = Anthropic(
api_key=api_key or os.getenv("ANTHROPIC_API_KEY"),
base_url=base_url,
timeout=timeout,
max_retries=max_retries,
)
# Store completion parameters
self.max_tokens = max_tokens
self.top_p = top_p
self.stream = stream
self.stop_sequences = stop_sequences or []
# Model-specific settings
self.is_claude_3 = "claude-3" in model.lower()
self.supports_tools = self.is_claude_3 # Claude 3+ supports tool use
def call(
self,
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str | Any:
"""Call Anthropic messages API.
Args:
messages: Input messages for the chat completion
tools: List of tool/function definitions
callbacks: Callback functions (not used in native implementation)
available_functions: Available functions for tool calling
from_task: Task that initiated the call
from_agent: Agent that initiated the call
Returns:
Chat completion response or tool call result
"""
try:
# Emit call started event
self._emit_call_started_event(
messages=messages,
tools=tools,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)
# Format messages for Anthropic
formatted_messages, system_message = self._format_messages_for_anthropic(
messages
)
# Prepare completion parameters
completion_params = self._prepare_completion_params(
formatted_messages, system_message, tools
)
# Handle streaming vs non-streaming
if self.stream:
return self._handle_streaming_completion(
completion_params, available_functions, from_task, from_agent
)
return self._handle_completion(
completion_params, available_functions, from_task, from_agent
)
except Exception as e:
error_msg = f"Anthropic API call failed: {e!s}"
logging.error(error_msg)
self._emit_call_failed_event(
error=error_msg, from_task=from_task, from_agent=from_agent
)
raise
def _prepare_completion_params(
self,
messages: list[dict[str, str]],
system_message: str | None = None,
tools: list[dict] | None = None,
) -> dict[str, Any]:
"""Prepare parameters for Anthropic messages API.
Args:
messages: Formatted messages for Anthropic
system_message: Extracted system message
tools: Tool definitions
Returns:
Parameters dictionary for Anthropic API
"""
params = {
"model": self.model,
"messages": messages,
"max_tokens": self.max_tokens,
"stream": self.stream,
}
# Add system message if present
if system_message:
params["system"] = system_message
# Add optional parameters if set
if self.temperature is not None:
params["temperature"] = self.temperature
if self.top_p is not None:
params["top_p"] = self.top_p
if self.stop_sequences:
params["stop_sequences"] = self.stop_sequences
# Handle tools for Claude 3+
if tools and self.supports_tools:
params["tools"] = self._convert_tools_for_interference(tools)
return params
def _convert_tools_for_interference(self, tools: list[dict]) -> list[dict]:
"""Convert CrewAI tool format to Anthropic tool use format."""
from crewai.llms.providers.utils.common import safe_tool_conversion
anthropic_tools = []
for tool in tools:
name, description, parameters = safe_tool_conversion(tool, "Anthropic")
anthropic_tool = {
"name": name,
"description": description,
}
if parameters and isinstance(parameters, dict):
anthropic_tool["input_schema"] = parameters # type: ignore
anthropic_tools.append(anthropic_tool)
return anthropic_tools
def _format_messages_for_anthropic(
self, messages: str | list[dict[str, str]]
) -> tuple[list[dict[str, str]], str | None]:
"""Format messages for Anthropic API.
Anthropic has specific requirements:
- System messages are separate from conversation messages
- Messages must alternate between user and assistant
- First message must be from user
Args:
messages: Input messages
Returns:
Tuple of (formatted_messages, system_message)
"""
# Use base class formatting first
base_formatted = super()._format_messages(messages)
formatted_messages = []
system_message = None
for message in base_formatted:
role = message.get("role")
content = message.get("content", "")
if role == "system":
# Extract system message - Anthropic handles it separately
if system_message:
system_message += f"\n\n{content}"
else:
system_message = content
else:
# Add user/assistant messages - ensure both role and content are str, not None
role_str = role if role is not None else "user"
content_str = content if content is not None else ""
formatted_messages.append({"role": role_str, "content": content_str})
# Ensure first message is from user (Anthropic requirement)
if not formatted_messages:
# If no messages, add a default user message
formatted_messages.append({"role": "user", "content": "Hello"})
elif formatted_messages[0]["role"] != "user":
# If first message is not from user, insert a user message at the beginning
formatted_messages.insert(0, {"role": "user", "content": "Hello"})
return formatted_messages, system_message
def _handle_completion(
self,
params: dict[str, Any],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str | Any:
"""Handle non-streaming message completion."""
try:
response: Message = self.client.messages.create(**params)
except Exception as e:
if is_context_length_exceeded(e):
logging.error(f"Context window exceeded: {e}")
raise LLMContextLengthExceededError(str(e)) from e
raise e from e
usage = self._extract_anthropic_token_usage(response)
self._track_token_usage_internal(usage)
if response.content and available_functions:
for content_block in response.content:
if isinstance(content_block, ToolUseBlock):
function_name = content_block.name
function_args = content_block.input
result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args, # type: ignore
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)
if result is not None:
return result
# Extract text content
content = ""
if response.content:
for content_block in response.content:
if hasattr(content_block, "text"):
content += content_block.text
content = self._apply_stop_words(content)
self._emit_call_completed_event(
response=content,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
if usage.get("total_tokens", 0) > 0:
logging.info(f"Anthropic API usage: {usage}")
return content
def _handle_streaming_completion(
self,
params: dict[str, Any],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str:
"""Handle streaming message completion."""
full_response = ""
tool_uses = {}
# Make streaming API call
with self.client.messages.stream(**params) as stream:
for event in stream:
# Handle content delta events
if hasattr(event, "delta") and hasattr(event.delta, "text"):
text_delta = event.delta.text
full_response += text_delta
self._emit_stream_chunk_event(
chunk=text_delta,
from_task=from_task,
from_agent=from_agent,
)
# Handle tool use events
elif hasattr(event, "delta") and hasattr(event.delta, "partial_json"):
# Tool use streaming - accumulate JSON
tool_id = getattr(event, "index", "default")
if tool_id not in tool_uses:
tool_uses[tool_id] = {
"name": "",
"input": "",
}
if hasattr(event.delta, "name"):
tool_uses[tool_id]["name"] = event.delta.name
if hasattr(event.delta, "partial_json"):
tool_uses[tool_id]["input"] += event.delta.partial_json
# Handle completed tool uses
if tool_uses and available_functions:
for tool_data in tool_uses.values():
function_name = tool_data["name"]
try:
function_args = json.loads(tool_data["input"])
except json.JSONDecodeError as e:
logging.error(f"Failed to parse streamed tool arguments: {e}")
continue
# Execute tool
result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)
if result is not None:
return result
# Apply stop words to full response
full_response = self._apply_stop_words(full_response)
# Emit completion event and return full response
self._emit_call_completed_event(
response=full_response,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return full_response
def supports_function_calling(self) -> bool:
"""Check if the model supports function calling."""
return self.supports_tools
def supports_stop_words(self) -> bool:
"""Check if the model supports stop words."""
return True # All Claude models support stop sequences
def get_context_window_size(self) -> int:
"""Get the context window size for the model."""
from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO
# Context window sizes for Anthropic models
context_windows = {
"claude-3-5-sonnet": 200000,
"claude-3-5-haiku": 200000,
"claude-3-opus": 200000,
"claude-3-sonnet": 200000,
"claude-3-haiku": 200000,
"claude-3-7-sonnet": 200000,
"claude-2.1": 200000,
"claude-2": 100000,
"claude-instant": 100000,
}
# Find the best match for the model name
for model_prefix, size in context_windows.items():
if self.model.startswith(model_prefix):
return int(size * CONTEXT_WINDOW_USAGE_RATIO)
# Default context window size for Claude models
return int(200000 * CONTEXT_WINDOW_USAGE_RATIO)
def _extract_anthropic_token_usage(self, response: Message) -> dict[str, Any]:
"""Extract token usage from Anthropic response."""
if hasattr(response, "usage") and response.usage:
usage = response.usage
input_tokens = getattr(usage, "input_tokens", 0)
output_tokens = getattr(usage, "output_tokens", 0)
return {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"total_tokens": input_tokens + output_tokens,
}
return {"total_tokens": 0}

View File

@@ -0,0 +1,473 @@
import json
import logging
import os
from typing import Any
from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededError,
)
try:
from azure.ai.inference import ChatCompletionsClient # type: ignore
from azure.ai.inference.models import ( # type: ignore
ChatCompletions,
ChatCompletionsToolCall,
StreamingChatCompletionsUpdate,
)
from azure.core.credentials import AzureKeyCredential # type: ignore
from azure.core.exceptions import HttpResponseError # type: ignore
from crewai.events.types.llm_events import LLMCallType
from crewai.llms.base_llm import BaseLLM
except ImportError:
raise ImportError(
"Azure AI Inference native provider not available, to install: `uv add azure-ai-inference`"
) from None
class AzureCompletion(BaseLLM):
"""Azure AI Inference native completion implementation.
This class provides direct integration with the Azure AI Inference Python SDK,
offering native function calling, streaming support, and proper Azure authentication.
"""
def __init__(
self,
model: str,
api_key: str | None = None,
endpoint: str | None = None,
api_version: str | None = None,
timeout: float | None = None,
max_retries: int = 2,
temperature: float | None = None,
top_p: float | None = None,
frequency_penalty: float | None = None,
presence_penalty: float | None = None,
max_tokens: int | None = None,
stop: list[str] | None = None,
stream: bool = False,
**kwargs,
):
"""Initialize Azure AI Inference chat completion client.
Args:
model: Azure deployment name or model name
api_key: Azure API key (defaults to AZURE_API_KEY env var)
endpoint: Azure endpoint URL (defaults to AZURE_ENDPOINT env var)
api_version: Azure API version (defaults to AZURE_API_VERSION env var)
timeout: Request timeout in seconds
max_retries: Maximum number of retries
temperature: Sampling temperature (0-2)
top_p: Nucleus sampling parameter
frequency_penalty: Frequency penalty (-2 to 2)
presence_penalty: Presence penalty (-2 to 2)
max_tokens: Maximum tokens in response
stop: Stop sequences
stream: Enable streaming responses
**kwargs: Additional parameters
"""
super().__init__(
model=model, temperature=temperature, stop=stop or [], **kwargs
)
self.api_key = api_key or os.getenv("AZURE_API_KEY")
self.endpoint = (
endpoint
or os.getenv("AZURE_ENDPOINT")
or os.getenv("AZURE_OPENAI_ENDPOINT")
or os.getenv("AZURE_API_BASE")
)
self.api_version = api_version or os.getenv("AZURE_API_VERSION") or "2024-02-01"
if not self.api_key:
raise ValueError(
"Azure API key is required. Set AZURE_API_KEY environment variable or pass api_key parameter."
)
if not self.endpoint:
raise ValueError(
"Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter."
)
self.client = ChatCompletionsClient(
endpoint=self.endpoint,
credential=AzureKeyCredential(self.api_key),
)
self.top_p = top_p
self.frequency_penalty = frequency_penalty
self.presence_penalty = presence_penalty
self.max_tokens = max_tokens
self.stream = stream
self.is_openai_model = any(
prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
)
def call(
self,
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str | Any:
"""Call Azure AI Inference chat completions API.
Args:
messages: Input messages for the chat completion
tools: List of tool/function definitions
callbacks: Callback functions (not used in native implementation)
available_functions: Available functions for tool calling
from_task: Task that initiated the call
from_agent: Agent that initiated the call
Returns:
Chat completion response or tool call result
"""
try:
# Emit call started event
self._emit_call_started_event(
messages=messages,
tools=tools,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)
# Format messages for Azure
formatted_messages = self._format_messages_for_azure(messages)
# Prepare completion parameters
completion_params = self._prepare_completion_params(
formatted_messages, tools
)
# Handle streaming vs non-streaming
if self.stream:
return self._handle_streaming_completion(
completion_params, available_functions, from_task, from_agent
)
return self._handle_completion(
completion_params, available_functions, from_task, from_agent
)
except HttpResponseError as e:
error_msg = f"Azure API HTTP error: {e.status_code} - {e.message}"
logging.error(error_msg)
self._emit_call_failed_event(
error=error_msg, from_task=from_task, from_agent=from_agent
)
raise
except Exception as e:
error_msg = f"Azure API call failed: {e!s}"
logging.error(error_msg)
self._emit_call_failed_event(
error=error_msg, from_task=from_task, from_agent=from_agent
)
raise
def _prepare_completion_params(
self,
messages: list[dict[str, str]],
tools: list[dict] | None = None,
) -> dict[str, Any]:
"""Prepare parameters for Azure AI Inference chat completion.
Args:
messages: Formatted messages for Azure
tools: Tool definitions
Returns:
Parameters dictionary for Azure API
"""
params = {
"model": self.model,
"messages": messages,
"stream": self.stream,
}
# Add optional parameters if set
if self.temperature is not None:
params["temperature"] = self.temperature
if self.top_p is not None:
params["top_p"] = self.top_p
if self.frequency_penalty is not None:
params["frequency_penalty"] = self.frequency_penalty
if self.presence_penalty is not None:
params["presence_penalty"] = self.presence_penalty
if self.max_tokens is not None:
params["max_tokens"] = self.max_tokens
if self.stop:
params["stop"] = self.stop
# Handle tools/functions for Azure OpenAI models
if tools and self.is_openai_model:
params["tools"] = self._convert_tools_for_interference(tools)
params["tool_choice"] = "auto"
return params
def _convert_tools_for_interference(self, tools: list[dict]) -> list[dict]:
"""Convert CrewAI tool format to Azure OpenAI function calling format."""
from crewai.llms.providers.utils.common import safe_tool_conversion
azure_tools = []
for tool in tools:
name, description, parameters = safe_tool_conversion(tool, "Azure")
azure_tool = {
"type": "function",
"function": {
"name": name,
"description": description,
},
}
if parameters:
if isinstance(parameters, dict):
azure_tool["function"]["parameters"] = parameters # type: ignore
else:
azure_tool["function"]["parameters"] = dict(parameters)
azure_tools.append(azure_tool)
return azure_tools
def _format_messages_for_azure(
self, messages: str | list[dict[str, str]]
) -> list[dict[str, str]]:
"""Format messages for Azure AI Inference API.
Args:
messages: Input messages
Returns:
List of dict objects
"""
# Use base class formatting first
base_formatted = super()._format_messages(messages)
azure_messages = []
for message in base_formatted:
role = message.get("role")
content = message.get("content", "")
if role == "system":
azure_messages.append(dict(content=content))
elif role == "user":
azure_messages.append(dict(content=content))
elif role == "assistant":
azure_messages.append(dict(content=content))
else:
# Default to user message for unknown roles
azure_messages.append(dict(content=content))
return azure_messages
def _handle_completion(
self,
params: dict[str, Any],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str | Any:
"""Handle non-streaming chat completion."""
# Make API call
try:
response: ChatCompletions = self.client.complete(**params)
if not response.choices:
raise ValueError("No choices returned from Azure API")
choice = response.choices[0]
message = choice.message
# Extract and track token usage
usage = self._extract_azure_token_usage(response)
self._track_token_usage_internal(usage)
# Handle tool calls
if message.tool_calls and available_functions:
tool_call = message.tool_calls[0] # Handle first tool call
if isinstance(tool_call, ChatCompletionsToolCall):
function_name = tool_call.function.name
try:
function_args = json.loads(tool_call.function.arguments)
except json.JSONDecodeError as e:
logging.error(f"Failed to parse tool arguments: {e}")
function_args = {}
# Execute tool
result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)
if result is not None:
return result
# Extract content
content = message.content or ""
# Apply stop words
content = self._apply_stop_words(content)
# Emit completion event and return content
self._emit_call_completed_event(
response=content,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
except Exception as e:
if is_context_length_exceeded(e):
logging.error(f"Context window exceeded: {e}")
raise LLMContextLengthExceededError(str(e)) from e
return content
def _handle_streaming_completion(
self,
params: dict[str, Any],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str:
"""Handle streaming chat completion."""
full_response = ""
tool_calls = {}
# Make streaming API call
for update in self.client.complete(**params):
if isinstance(update, StreamingChatCompletionsUpdate):
if update.choices:
choice = update.choices[0]
if choice.delta and choice.delta.content:
content_delta = choice.delta.content
full_response += content_delta
self._emit_stream_chunk_event(
chunk=content_delta,
from_task=from_task,
from_agent=from_agent,
)
# Handle tool call streaming
if choice.delta and choice.delta.tool_calls:
for tool_call in choice.delta.tool_calls:
call_id = tool_call.id or "default"
if call_id not in tool_calls:
tool_calls[call_id] = {
"name": "",
"arguments": "",
}
if tool_call.function and tool_call.function.name:
tool_calls[call_id]["name"] = tool_call.function.name
if tool_call.function and tool_call.function.arguments:
tool_calls[call_id]["arguments"] += (
tool_call.function.arguments
)
# Handle completed tool calls
if tool_calls and available_functions:
for call_data in tool_calls.values():
function_name = call_data["name"]
try:
function_args = json.loads(call_data["arguments"])
except json.JSONDecodeError as e:
logging.error(f"Failed to parse streamed tool arguments: {e}")
continue
# Execute tool
result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)
if result is not None:
return result
# Apply stop words to full response
full_response = self._apply_stop_words(full_response)
# Emit completion event and return full response
self._emit_call_completed_event(
response=full_response,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return full_response
def supports_function_calling(self) -> bool:
"""Check if the model supports function calling."""
# Azure OpenAI models support function calling
return self.is_openai_model
def supports_stop_words(self) -> bool:
"""Check if the model supports stop words."""
return True # Most Azure models support stop sequences
def get_context_window_size(self) -> int:
"""Get the context window size for the model."""
from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
min_context = 1024
max_context = 2097152
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
if value < min_context or value > max_context:
raise ValueError(
f"Context window for {key} must be between {min_context} and {max_context}"
)
# Context window sizes for common Azure models
context_windows = {
"gpt-4": 8192,
"gpt-4o": 128000,
"gpt-4o-mini": 200000,
"gpt-4-turbo": 128000,
"gpt-35-turbo": 16385,
"gpt-3.5-turbo": 16385,
"text-embedding": 8191,
}
        # Match the longest model-name prefix first so "gpt-4o" is not shadowed by "gpt-4"
        for model_prefix, size in sorted(
            context_windows.items(), key=lambda item: len(item[0]), reverse=True
        ):
            if self.model.startswith(model_prefix):
                return int(size * CONTEXT_WINDOW_USAGE_RATIO)
# Default context window size
return int(8192 * CONTEXT_WINDOW_USAGE_RATIO)
def _extract_azure_token_usage(self, response: ChatCompletions) -> dict[str, Any]:
"""Extract token usage from Azure response."""
if hasattr(response, "usage") and response.usage:
usage = response.usage
return {
"prompt_tokens": getattr(usage, "prompt_tokens", 0),
"completion_tokens": getattr(usage, "completion_tokens", 0),
"total_tokens": getattr(usage, "total_tokens", 0),
}
return {"total_tokens": 0}

View File

@@ -0,0 +1,497 @@
import logging
import os
from typing import Any
from crewai.events.types.llm_events import LLMCallType
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededError,
)
try:
from google import genai # type: ignore
from google.genai import types # type: ignore
from google.genai.errors import APIError # type: ignore
except ImportError:
    raise ImportError(
        "Google Gen AI native provider is not available. Install it with: `uv add google-genai`"
    ) from None
class GeminiCompletion(BaseLLM):
"""Google Gemini native completion implementation.
This class provides direct integration with the Google Gen AI Python SDK,
offering native function calling, streaming support, and proper Gemini formatting.
"""
def __init__(
self,
model: str = "gemini-2.0-flash-001",
api_key: str | None = None,
project: str | None = None,
location: str | None = None,
temperature: float | None = None,
top_p: float | None = None,
top_k: int | None = None,
max_output_tokens: int | None = None,
stop_sequences: list[str] | None = None,
stream: bool = False,
safety_settings: dict[str, Any] | None = None,
**kwargs,
):
"""Initialize Google Gemini chat completion client.
Args:
model: Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro')
api_key: Google API key (defaults to GOOGLE_API_KEY or GEMINI_API_KEY env var)
project: Google Cloud project ID (for Vertex AI)
location: Google Cloud location (for Vertex AI, defaults to 'us-central1')
temperature: Sampling temperature (0-2)
top_p: Nucleus sampling parameter
top_k: Top-k sampling parameter
max_output_tokens: Maximum tokens in response
stop_sequences: Stop sequences
stream: Enable streaming responses
safety_settings: Safety filter settings
**kwargs: Additional parameters
"""
super().__init__(
model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
)
# Get API configuration
self.api_key = (
api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
)
self.project = project or os.getenv("GOOGLE_CLOUD_PROJECT")
self.location = location or os.getenv("GOOGLE_CLOUD_LOCATION") or "us-central1"
# Initialize client based on available configuration
if self.project:
# Use Vertex AI
self.client = genai.Client(
vertexai=True,
project=self.project,
location=self.location,
)
elif self.api_key:
# Use Gemini Developer API
self.client = genai.Client(api_key=self.api_key)
else:
raise ValueError(
"Either GOOGLE_API_KEY/GEMINI_API_KEY (for Gemini API) or "
"GOOGLE_CLOUD_PROJECT (for Vertex AI) must be set"
)
# Store completion parameters
self.top_p = top_p
self.top_k = top_k
self.max_output_tokens = max_output_tokens
self.stream = stream
self.safety_settings = safety_settings or {}
self.stop_sequences = stop_sequences or []
# Model-specific settings
self.is_gemini_2 = "gemini-2" in model.lower()
self.is_gemini_1_5 = "gemini-1.5" in model.lower()
self.supports_tools = self.is_gemini_1_5 or self.is_gemini_2
def call(
self,
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str | Any:
"""Call Google Gemini generate content API.
Args:
messages: Input messages for the chat completion
tools: List of tool/function definitions
            callbacks: Callback functions (not used; token usage is taken from the response)
available_functions: Available functions for tool calling
from_task: Task that initiated the call
from_agent: Agent that initiated the call
Returns:
Chat completion response or tool call result
"""
try:
self._emit_call_started_event(
messages=messages,
tools=tools,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)
self.tools = tools
formatted_content, system_instruction = self._format_messages_for_gemini(
messages
)
config = self._prepare_generation_config(system_instruction, tools)
if self.stream:
return self._handle_streaming_completion(
formatted_content,
config,
available_functions,
from_task,
from_agent,
)
return self._handle_completion(
formatted_content,
system_instruction,
config,
available_functions,
from_task,
from_agent,
)
except APIError as e:
error_msg = f"Google Gemini API error: {e.code} - {e.message}"
logging.error(error_msg)
self._emit_call_failed_event(
error=error_msg, from_task=from_task, from_agent=from_agent
)
raise
except Exception as e:
error_msg = f"Google Gemini API call failed: {e!s}"
logging.error(error_msg)
self._emit_call_failed_event(
error=error_msg, from_task=from_task, from_agent=from_agent
)
raise
def _prepare_generation_config(
self,
system_instruction: str | None = None,
tools: list[dict] | None = None,
) -> types.GenerateContentConfig:
"""Prepare generation config for Google Gemini API.
Args:
system_instruction: System instruction for the model
tools: Tool definitions
Returns:
GenerateContentConfig object for Gemini API
"""
self.tools = tools
config_params = {}
# Add system instruction if present
if system_instruction:
# Convert system instruction to Content format
system_content = types.Content(
role="user", parts=[types.Part.from_text(text=system_instruction)]
)
config_params["system_instruction"] = system_content
# Add generation config parameters
if self.temperature is not None:
config_params["temperature"] = self.temperature
if self.top_p is not None:
config_params["top_p"] = self.top_p
if self.top_k is not None:
config_params["top_k"] = self.top_k
if self.max_output_tokens is not None:
config_params["max_output_tokens"] = self.max_output_tokens
if self.stop_sequences:
config_params["stop_sequences"] = self.stop_sequences
# Handle tools for supported models
if tools and self.supports_tools:
config_params["tools"] = self._convert_tools_for_interference(tools)
if self.safety_settings:
config_params["safety_settings"] = self.safety_settings
return types.GenerateContentConfig(**config_params)
def _convert_tools_for_interference(self, tools: list[dict]) -> list[types.Tool]:
"""Convert CrewAI tool format to Gemini function declaration format."""
gemini_tools = []
from crewai.llms.providers.utils.common import safe_tool_conversion
for tool in tools:
name, description, parameters = safe_tool_conversion(tool, "Gemini")
function_declaration = types.FunctionDeclaration(
name=name,
description=description,
)
# Add parameters if present - ensure parameters is a dict
if parameters and isinstance(parameters, dict):
function_declaration.parameters = parameters
gemini_tool = types.Tool(function_declarations=[function_declaration])
gemini_tools.append(gemini_tool)
return gemini_tools
def _format_messages_for_gemini(
self, messages: str | list[dict[str, str]]
) -> tuple[list[types.Content], str | None]:
"""Format messages for Gemini API.
Gemini has specific requirements:
- System messages are separate system_instruction
- Content is organized as Content objects with Parts
- Roles are 'user' and 'model' (not 'assistant')
Args:
messages: Input messages
Returns:
Tuple of (formatted_contents, system_instruction)
"""
# Use base class formatting first
base_formatted = super()._format_messages(messages)
contents = []
system_instruction = None
for message in base_formatted:
role = message.get("role")
content = message.get("content", "")
if role == "system":
# Extract system instruction - Gemini handles it separately
if system_instruction:
system_instruction += f"\n\n{content}"
else:
system_instruction = content
else:
# Convert role for Gemini (assistant -> model)
gemini_role = "model" if role == "assistant" else "user"
# Create Content object
gemini_content = types.Content(
role=gemini_role, parts=[types.Part.from_text(text=content)]
)
contents.append(gemini_content)
return contents, system_instruction
def _handle_completion(
self,
contents: list[types.Content],
system_instruction: str | None,
config: types.GenerateContentConfig,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str | Any:
"""Handle non-streaming content generation."""
api_params = {
"model": self.model,
"contents": contents,
"config": config,
}
try:
response = self.client.models.generate_content(**api_params)
usage = self._extract_token_usage(response)
except Exception as e:
if is_context_length_exceeded(e):
logging.error(f"Context window exceeded: {e}")
raise LLMContextLengthExceededError(str(e)) from e
raise e from e
self._track_token_usage_internal(usage)
if response.candidates and (self.tools or available_functions):
candidate = response.candidates[0]
if candidate.content and candidate.content.parts:
for part in candidate.content.parts:
if hasattr(part, "function_call") and part.function_call:
function_name = part.function_call.name
function_args = (
dict(part.function_call.args)
if part.function_call.args
else {}
)
result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions, # type: ignore
from_task=from_task,
from_agent=from_agent,
)
if result is not None:
return result
        content = getattr(response, "text", None) or ""
content = self._apply_stop_words(content)
messages_for_event = self._convert_contents_to_dict(contents)
self._emit_call_completed_event(
response=content,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=messages_for_event,
)
return content
def _handle_streaming_completion(
self,
contents: list[types.Content],
config: types.GenerateContentConfig,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str:
"""Handle streaming content generation."""
full_response = ""
function_calls = {}
api_params = {
"model": self.model,
"contents": contents,
"config": config,
}
for chunk in self.client.models.generate_content_stream(**api_params):
if hasattr(chunk, "text") and chunk.text:
full_response += chunk.text
self._emit_stream_chunk_event(
chunk=chunk.text,
from_task=from_task,
from_agent=from_agent,
)
if hasattr(chunk, "candidates") and chunk.candidates:
candidate = chunk.candidates[0]
if candidate.content and candidate.content.parts:
for part in candidate.content.parts:
if hasattr(part, "function_call") and part.function_call:
call_id = part.function_call.name or "default"
if call_id not in function_calls:
function_calls[call_id] = {
"name": part.function_call.name,
"args": dict(part.function_call.args)
if part.function_call.args
else {},
}
# Handle completed function calls
if function_calls and available_functions:
for call_data in function_calls.values():
function_name = call_data["name"]
function_args = call_data["args"]
# Execute tool
result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)
if result is not None:
return result
messages_for_event = self._convert_contents_to_dict(contents)
self._emit_call_completed_event(
response=full_response,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=messages_for_event,
)
return full_response
def supports_function_calling(self) -> bool:
"""Check if the model supports function calling."""
return self.supports_tools
def supports_stop_words(self) -> bool:
"""Check if the model supports stop words."""
return self._supports_stop_words_implementation()
def get_context_window_size(self) -> int:
"""Get the context window size for the model."""
from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
min_context = 1024
max_context = 2097152
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
if value < min_context or value > max_context:
raise ValueError(
f"Context window for {key} must be between {min_context} and {max_context}"
)
context_windows = {
"gemini-2.0-flash": 1048576, # 1M tokens
"gemini-2.0-flash-thinking": 32768,
"gemini-2.0-flash-lite": 1048576,
"gemini-2.5-flash": 1048576,
"gemini-2.5-pro": 1048576,
"gemini-1.5-pro": 2097152, # 2M tokens
"gemini-1.5-flash": 1048576,
"gemini-1.5-flash-8b": 1048576,
"gemini-1.0-pro": 32768,
"gemma-3-1b": 32000,
"gemma-3-4b": 128000,
"gemma-3-12b": 128000,
"gemma-3-27b": 128000,
}
        # Match the longest model-name prefix first so variants like
        # "gemini-2.0-flash-thinking" are not shadowed by "gemini-2.0-flash"
        for model_prefix, size in sorted(
            context_windows.items(), key=lambda item: len(item[0]), reverse=True
        ):
            if self.model.startswith(model_prefix):
                return int(size * CONTEXT_WINDOW_USAGE_RATIO)
# Default context window size for Gemini models
return int(1048576 * CONTEXT_WINDOW_USAGE_RATIO) # 1M tokens
    def _extract_token_usage(self, response: Any) -> dict[str, Any]:
"""Extract token usage from Gemini response."""
if hasattr(response, "usage_metadata"):
usage = response.usage_metadata
return {
"prompt_token_count": getattr(usage, "prompt_token_count", 0),
"candidates_token_count": getattr(usage, "candidates_token_count", 0),
"total_token_count": getattr(usage, "total_token_count", 0),
"total_tokens": getattr(usage, "total_token_count", 0),
}
return {"total_tokens": 0}
def _convert_contents_to_dict(
self, contents: list[types.Content]
) -> list[dict[str, str]]:
"""Convert contents to dict format."""
return [
{
"role": "assistant"
if content_obj.role == "model"
else content_obj.role,
"content": " ".join(
part.text
for part in content_obj.parts
if hasattr(part, "text") and part.text
),
}
for content_obj in contents
]
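
A brief usage sketch for `GeminiCompletion`, based only on the constructor and `call` signature above; the import path, API key value, and prompt are illustrative assumptions. System messages become the `system_instruction`, and assistant turns are sent with the `model` role.

import os

from crewai.llms.providers.gemini.completion import GeminiCompletion  # assumed module path

# Gemini Developer API: picks up GOOGLE_API_KEY or GEMINI_API_KEY.
os.environ.setdefault("GEMINI_API_KEY", "<your-key>")
llm = GeminiCompletion(model="gemini-2.0-flash-001", temperature=0.2)

# Vertex AI alternative: supply a project (or set GOOGLE_CLOUD_PROJECT).
# llm = GeminiCompletion(project="my-gcp-project", location="us-central1")

result = llm.call(
    [
        {"role": "system", "content": "You are a terse assistant."},
        {"role": "user", "content": "Name one moon of Jupiter."},
    ]
)
print(result)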

View File

@@ -0,0 +1,484 @@
from collections.abc import Iterator
import json
import logging
import os
from typing import Any
from crewai.events.types.llm_events import LLMCallType
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededError,
)
from openai import OpenAI
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from pydantic import BaseModel
class OpenAICompletion(BaseLLM):
"""OpenAI native completion implementation.
This class provides direct integration with the OpenAI Python SDK,
offering native structured outputs, function calling, and streaming support.
"""
def __init__(
self,
model: str = "gpt-4o",
api_key: str | None = None,
base_url: str | None = None,
organization: str | None = None,
project: str | None = None,
timeout: float | None = None,
max_retries: int = 2,
temperature: float | None = None,
top_p: float | None = None,
frequency_penalty: float | None = None,
presence_penalty: float | None = None,
max_tokens: int | None = None,
max_completion_tokens: int | None = None,
seed: int | None = None,
stream: bool = False,
response_format: dict[str, Any] | type[BaseModel] | None = None,
logprobs: bool | None = None,
top_logprobs: int | None = None,
reasoning_effort: str | None = None, # For o1 models
        provider: str | None = None,  # Provider name; defaults to "openai"
**kwargs,
):
"""Initialize OpenAI chat completion client."""
if provider is None:
provider = kwargs.pop("provider", "openai")
super().__init__(
model=model,
temperature=temperature,
api_key=api_key or os.getenv("OPENAI_API_KEY"),
base_url=base_url,
timeout=timeout,
provider=provider,
**kwargs,
)
self.client = OpenAI(
api_key=api_key or os.getenv("OPENAI_API_KEY"),
base_url=base_url,
organization=organization,
project=project,
timeout=timeout,
max_retries=max_retries,
)
self.top_p = top_p
self.frequency_penalty = frequency_penalty
self.presence_penalty = presence_penalty
self.max_tokens = max_tokens
self.max_completion_tokens = max_completion_tokens
self.seed = seed
self.stream = stream
self.response_format = response_format
self.logprobs = logprobs
self.top_logprobs = top_logprobs
self.reasoning_effort = reasoning_effort
self.timeout = timeout
self.is_o1_model = "o1" in model.lower()
self.is_gpt4_model = "gpt-4" in model.lower()
def call(
self,
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str | Any:
"""Call OpenAI chat completion API.
Args:
messages: Input messages for the chat completion
tools: list of tool/function definitions
callbacks: Callback functions (not used in native implementation)
available_functions: Available functions for tool calling
from_task: Task that initiated the call
from_agent: Agent that initiated the call
Returns:
Chat completion response or tool call result
"""
try:
self._emit_call_started_event(
messages=messages,
tools=tools,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)
formatted_messages = self._format_messages(messages)
completion_params = self._prepare_completion_params(
formatted_messages, tools
)
if self.stream:
return self._handle_streaming_completion(
completion_params, available_functions, from_task, from_agent
)
return self._handle_completion(
completion_params, available_functions, from_task, from_agent
)
except Exception as e:
error_msg = f"OpenAI API call failed: {e!s}"
logging.error(error_msg)
self._emit_call_failed_event(
error=error_msg, from_task=from_task, from_agent=from_agent
)
raise
def _prepare_completion_params(
self, messages: list[dict[str, str]], tools: list[dict] | None = None
) -> dict[str, Any]:
"""Prepare parameters for OpenAI chat completion."""
params = {
"model": self.model,
"messages": messages,
"stream": self.stream,
}
params.update(self.additional_params)
if self.temperature is not None:
params["temperature"] = self.temperature
if self.top_p is not None:
params["top_p"] = self.top_p
if self.frequency_penalty is not None:
params["frequency_penalty"] = self.frequency_penalty
if self.presence_penalty is not None:
params["presence_penalty"] = self.presence_penalty
if self.max_completion_tokens is not None:
params["max_completion_tokens"] = self.max_completion_tokens
elif self.max_tokens is not None:
params["max_tokens"] = self.max_tokens
if self.seed is not None:
params["seed"] = self.seed
if self.logprobs is not None:
params["logprobs"] = self.logprobs
if self.top_logprobs is not None:
params["top_logprobs"] = self.top_logprobs
# Handle o1 model specific parameters
if self.is_o1_model and self.reasoning_effort:
params["reasoning_effort"] = self.reasoning_effort
# Handle response format for structured outputs
if self.response_format:
if isinstance(self.response_format, type) and issubclass(
self.response_format, BaseModel
):
# Convert Pydantic model to OpenAI response format
params["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": self.response_format.__name__,
"schema": self.response_format.model_json_schema(),
},
}
else:
params["response_format"] = self.response_format
if tools:
params["tools"] = self._convert_tools_for_interference(tools)
params["tool_choice"] = "auto"
# Filter out CrewAI-specific parameters that shouldn't go to the API
crewai_specific_params = {
"callbacks",
"available_functions",
"from_task",
"from_agent",
"provider",
"api_key",
"base_url",
"timeout",
"max_retries",
}
return {k: v for k, v in params.items() if k not in crewai_specific_params}
def _convert_tools_for_interference(self, tools: list[dict]) -> list[dict]:
"""Convert CrewAI tool format to OpenAI function calling format."""
from crewai.llms.providers.utils.common import safe_tool_conversion
openai_tools = []
for tool in tools:
name, description, parameters = safe_tool_conversion(tool, "OpenAI")
openai_tool = {
"type": "function",
"function": {
"name": name,
"description": description,
},
}
if parameters:
if isinstance(parameters, dict):
openai_tool["function"]["parameters"] = parameters # type: ignore
else:
openai_tool["function"]["parameters"] = dict(parameters)
openai_tools.append(openai_tool)
return openai_tools
def _handle_completion(
self,
params: dict[str, Any],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str | Any:
"""Handle non-streaming chat completion."""
try:
response: ChatCompletion = self.client.chat.completions.create(**params)
usage = self._extract_openai_token_usage(response)
self._track_token_usage_internal(usage)
choice: Choice = response.choices[0]
message = choice.message
if message.tool_calls and available_functions:
tool_call = message.tool_calls[0]
function_name = tool_call.function.name
try:
function_args = json.loads(tool_call.function.arguments)
except json.JSONDecodeError as e:
logging.error(f"Failed to parse tool arguments: {e}")
function_args = {}
result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)
if result is not None:
return result
content = message.content or ""
content = self._apply_stop_words(content)
if self.response_format and isinstance(self.response_format, type):
try:
structured_result = self._validate_structured_output(
content, self.response_format
)
self._emit_call_completed_event(
response=structured_result,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_result
except ValueError as e:
logging.warning(f"Structured output validation failed: {e}")
self._emit_call_completed_event(
response=content,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
if usage.get("total_tokens", 0) > 0:
logging.info(f"OpenAI API usage: {usage}")
except Exception as e:
if is_context_length_exceeded(e):
logging.error(f"Context window exceeded: {e}")
raise LLMContextLengthExceededError(str(e)) from e
raise e from e
return content
def _handle_streaming_completion(
self,
params: dict[str, Any],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str:
"""Handle streaming chat completion."""
full_response = ""
tool_calls = {}
# Make streaming API call
stream: Iterator[ChatCompletionChunk] = self.client.chat.completions.create(
**params
)
for chunk in stream:
if not chunk.choices:
continue
choice = chunk.choices[0]
delta: ChoiceDelta = choice.delta
# Handle content streaming
if delta.content:
full_response += delta.content
self._emit_stream_chunk_event(
chunk=delta.content,
from_task=from_task,
from_agent=from_agent,
)
# Handle tool call streaming
if delta.tool_calls:
for tool_call in delta.tool_calls:
call_id = tool_call.id or "default"
if call_id not in tool_calls:
tool_calls[call_id] = {
"name": "",
"arguments": "",
}
if tool_call.function and tool_call.function.name:
tool_calls[call_id]["name"] = tool_call.function.name
if tool_call.function and tool_call.function.arguments:
tool_calls[call_id]["arguments"] += tool_call.function.arguments
if tool_calls and available_functions:
for call_data in tool_calls.values():
function_name = call_data["name"]
arguments = call_data["arguments"]
# Skip if function name is empty or arguments are empty
if not function_name or not arguments:
continue
# Check if function exists in available functions
if function_name not in available_functions:
logging.warning(
f"Function '{function_name}' not found in available functions"
)
continue
try:
function_args = json.loads(arguments)
except json.JSONDecodeError as e:
logging.error(f"Failed to parse streamed tool arguments: {e}")
continue
result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)
if result is not None:
return result
# Apply stop words to full response
full_response = self._apply_stop_words(full_response)
# Emit completion event and return full response
self._emit_call_completed_event(
response=full_response,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return full_response
def supports_function_calling(self) -> bool:
"""Check if the model supports function calling."""
return not self.is_o1_model
def supports_stop_words(self) -> bool:
"""Check if the model supports stop words."""
return not self.is_o1_model
def get_context_window_size(self) -> int:
"""Get the context window size for the model."""
from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES
min_context = 1024
max_context = 2097152
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
if value < min_context or value > max_context:
raise ValueError(
f"Context window for {key} must be between {min_context} and {max_context}"
)
# Context window sizes for OpenAI models
context_windows = {
"gpt-4": 8192,
"gpt-4o": 128000,
"gpt-4o-mini": 200000,
"gpt-4-turbo": 128000,
"gpt-4.1": 1047576,
"gpt-4.1-mini-2025-04-14": 1047576,
"gpt-4.1-nano-2025-04-14": 1047576,
"o1-preview": 128000,
"o1-mini": 128000,
"o3-mini": 200000,
"o4-mini": 200000,
}
        # Match the longest model-name prefix first so "gpt-4o-mini" is not shadowed by "gpt-4"
        for model_prefix, size in sorted(
            context_windows.items(), key=lambda item: len(item[0]), reverse=True
        ):
            if self.model.startswith(model_prefix):
                return int(size * CONTEXT_WINDOW_USAGE_RATIO)
# Default context window size
return int(8192 * CONTEXT_WINDOW_USAGE_RATIO)
def _extract_openai_token_usage(self, response: ChatCompletion) -> dict[str, Any]:
"""Extract token usage from OpenAI ChatCompletion response."""
if hasattr(response, "usage") and response.usage:
usage = response.usage
return {
"prompt_tokens": getattr(usage, "prompt_tokens", 0),
"completion_tokens": getattr(usage, "completion_tokens", 0),
"total_tokens": getattr(usage, "total_tokens", 0),
}
return {"total_tokens": 0}
def _format_messages(
self, messages: str | list[dict[str, str]]
) -> list[dict[str, str]]:
"""Format messages for OpenAI API."""
# Use base class formatting first
base_formatted = super()._format_messages(messages)
# Apply OpenAI-specific formatting
formatted_messages = []
for message in base_formatted:
if self.is_o1_model and message.get("role") == "system":
formatted_messages.append(
{"role": "user", "content": f"System: {message['content']}"}
)
else:
formatted_messages.append(message)
return formatted_messages
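
A short sketch of structured outputs with `OpenAICompletion`, based on the `response_format` handling above; the import path and the `CityFact` model are illustrative assumptions, not part of the repository.

from pydantic import BaseModel

from crewai.llms.providers.openai.completion import OpenAICompletion  # assumed module path

class CityFact(BaseModel):
    city: str
    fact: str

# A Pydantic response_format is converted into a json_schema response format
# inside _prepare_completion_params(); _handle_completion then validates the
# returned content against the model when possible.
llm = OpenAICompletion(model="gpt-4o-mini", response_format=CityFact)

result = llm.call("Share one interesting fact about Lisbon.")
print(result)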

View File

@@ -0,0 +1,136 @@
import logging
import re
from typing import Any
def validate_function_name(name: str, provider: str = "LLM") -> str:
"""Validate function name according to common LLM provider requirements.
Most LLM providers (OpenAI, Gemini, Anthropic) have similar requirements:
- Must start with letter or underscore
- Only alphanumeric, underscore, dot, colon, dash allowed
- Maximum length of 64 characters
- Cannot be empty
Args:
name: The function name to validate
provider: The provider name for error messages
Returns:
The validated function name (unchanged if valid)
Raises:
ValueError: If the function name is invalid
"""
if not name or not isinstance(name, str):
raise ValueError(f"{provider} function name cannot be empty")
if not (name[0].isalpha() or name[0] == "_"):
raise ValueError(
f"{provider} function name '{name}' must start with a letter or underscore"
)
if len(name) > 64:
raise ValueError(
f"{provider} function name '{name}' exceeds 64 character limit"
)
# Check for invalid characters (most providers support these)
if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_.\-:]*$", name):
raise ValueError(
f"{provider} function name '{name}' contains invalid characters. "
f"Only letters, numbers, underscore, dot, colon, dash allowed"
)
return name
def extract_tool_info(tool: dict[str, Any]) -> tuple[str, str, dict[str, Any]]:
"""Extract tool information from various schema formats.
Handles both OpenAI/standard format and direct format:
- OpenAI format: {"type": "function", "function": {"name": "...", ...}}
- Direct format: {"name": "...", "description": "...", ...}
Args:
tool: Tool dictionary in any supported format
Returns:
Tuple of (name, description, parameters)
Raises:
ValueError: If tool format is invalid
"""
if not isinstance(tool, dict):
raise ValueError("Tool must be a dictionary")
# Handle nested function schema format (OpenAI/standard)
if "function" in tool:
function_info = tool["function"]
if not isinstance(function_info, dict):
raise ValueError("Tool function must be a dictionary")
name = function_info.get("name", "")
description = function_info.get("description", "")
parameters = function_info.get("parameters", {})
else:
# Direct format
name = tool.get("name", "")
description = tool.get("description", "")
parameters = tool.get("parameters", {})
# Also check for args_schema (Pydantic format)
if not parameters and "args_schema" in tool:
if hasattr(tool["args_schema"], "model_json_schema"):
parameters = tool["args_schema"].model_json_schema()
return name, description, parameters
def log_tool_conversion(tool: dict[str, Any], provider: str) -> None:
"""Log tool conversion for debugging.
Args:
tool: The tool being converted
provider: The provider name
"""
try:
name, description, parameters = extract_tool_info(tool)
logging.debug(
f"{provider}: Converting tool '{name}' (desc: {description[:50]}...)"
)
logging.debug(f"{provider}: Tool parameters: {parameters}")
except Exception as e:
logging.error(f"{provider}: Error extracting tool info: {e}")
logging.error(f"{provider}: Tool structure: {tool}")
def safe_tool_conversion(
tool: dict[str, Any], provider: str
) -> tuple[str, str, dict[str, Any]]:
"""Safely extract and validate tool information.
Combines extraction, validation, and logging for robust tool conversion.
Args:
tool: Tool dictionary to convert
provider: Provider name for error messages and logging
Returns:
Tuple of (validated_name, description, parameters)
Raises:
ValueError: If tool is invalid or name validation fails
"""
try:
log_tool_conversion(tool, provider)
name, description, parameters = extract_tool_info(tool)
validated_name = validate_function_name(name, provider)
logging.info(f"{provider}: Successfully validated tool '{validated_name}'")
return validated_name, description, parameters
except Exception as e:
logging.error(f"{provider}: Error converting tool: {e}")
raise
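
A small sketch showing the two tool formats the helpers above accept and how validation behaves; the tool definitions are illustrative.

from crewai.llms.providers.utils.common import (
    extract_tool_info,
    safe_tool_conversion,
    validate_function_name,
)

# Direct format: name/description/parameters at the top level.
direct_tool = {
    "name": "lookup_user",
    "description": "Fetch a user record by id",
    "parameters": {"type": "object", "properties": {"user_id": {"type": "string"}}},
}

# Nested OpenAI-style format resolves to the same (name, description, parameters) triple.
nested_tool = {"type": "function", "function": direct_tool}
assert extract_tool_info(direct_tool) == extract_tool_info(nested_tool)

name, description, parameters = safe_tool_conversion(nested_tool, "OpenAI")
print(name)  # -> "lookup_user"

# Invalid names are rejected before reaching a provider API.
try:
    validate_function_name("9bad name!", "OpenAI")
except ValueError as exc:
    print(exc)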

View File

@@ -1,10 +1,10 @@
import ast
import datetime
import json
import time
from difflib import SequenceMatcher
import json
from json import JSONDecodeError
from textwrap import dedent
import time
from typing import TYPE_CHECKING, Any, Union
import json5
@@ -29,6 +29,7 @@ from crewai.utilities.agent_utils import (
render_text_description_and_args,
)
if TYPE_CHECKING:
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.lite_agent import LiteAgent
@@ -587,7 +588,23 @@ class ToolUsage:
e: Exception,
) -> None:
event_data = self._prepare_event_data(tool, tool_calling)
crewai_event_bus.emit(self, ToolUsageErrorEvent(**{**event_data, "error": e}))
event_data.update(
{
"task_id": str(self.task.id) if self.task else None,
"task_name": self.task.name or self.task.description
if self.task
else None,
}
)
crewai_event_bus.emit(
self,
ToolUsageErrorEvent(
**{
**event_data,
"error": e,
}
),
)
def on_tool_use_finished(
self,

View File

@@ -1,8 +1,8 @@
from __future__ import annotations
from collections.abc import Callable, Sequence
import json
import re
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any, Final, Literal, TypedDict
from rich.console import Console
@@ -15,7 +15,6 @@ from crewai.agents.parser import (
parse,
)
from crewai.cli.config import Settings
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
from crewai.tools import BaseTool as CrewAITool
from crewai.tools.base_tool import BaseTool
@@ -29,11 +28,14 @@ from crewai.utilities.i18n import I18N
from crewai.utilities.printer import ColoredText, Printer
from crewai.utilities.types import LLMMessage
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.llm import LLM
from crewai.task import Task
class SummaryContent(TypedDict):
"""Structure for summary content entries.
@@ -392,8 +394,10 @@ def is_context_length_exceeded(exception: Exception) -> bool:
Returns:
bool: True if the exception is due to context length exceeding
"""
return LLMContextLengthExceededError(str(exception))._is_context_limit_error(
str(exception)
return (
LLMContextLengthExceededError(str(exception))
._is_context_limit_error(str(exception))
)

View File

@@ -6,6 +6,7 @@ from crewai.cli.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
logger = logging.getLogger(__name__)
@@ -42,7 +43,7 @@ def create_llm(
or str(llm_value)
)
temperature: float | None = getattr(llm_value, "temperature", None)
max_tokens: int | None = getattr(llm_value, "max_tokens", None)
max_tokens: float | int | None = getattr(llm_value, "max_tokens", None)
logprobs: int | None = getattr(llm_value, "logprobs", None)
timeout: float | None = getattr(llm_value, "timeout", None)
api_key: str | None = getattr(llm_value, "api_key", None)
@@ -59,6 +60,7 @@ def create_llm(
base_url=base_url,
api_base=api_base,
)
except Exception as e:
logger.debug(f"Error instantiating LLM from unknown object type: {e}")
return None
@@ -117,6 +119,7 @@ def _llm_via_environment_or_fallback() -> LLM | None:
elif api_base and not base_url:
base_url = api_base
# Initialize llm_params dictionary
llm_params: dict[str, Any] = {
"model": model,
"temperature": temperature,
@@ -140,6 +143,11 @@ def _llm_via_environment_or_fallback() -> LLM | None:
"callbacks": callbacks,
}
unaccepted_attributes = [
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
"AWS_REGION_NAME",
]
set_provider = model_name.partition("/")[0] if "/" in model_name else "openai"
if set_provider in ENV_VARS:
@@ -147,7 +155,7 @@ def _llm_via_environment_or_fallback() -> LLM | None:
if isinstance(env_vars_for_provider, (list, tuple)):
for env_var in env_vars_for_provider:
key_name = env_var.get("key_name")
if key_name and key_name not in UNACCEPTED_ATTRIBUTES:
if key_name and key_name not in unaccepted_attributes:
env_value = os.environ.get(key_name)
if env_value:
# Map environment variable names to recognized parameters
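
As a quick illustration of the provider detection used above: the text before the first "/" in the model string selects the provider block in ENV_VARS, falling back to "openai" when no prefix is present. The model names below are illustrative.

def detect_provider(model_name: str) -> str:
    # Mirrors: model_name.partition("/")[0] if "/" in model_name else "openai"
    return model_name.partition("/")[0] if "/" in model_name else "openai"

print(detect_provider("gemini/gemini-2.0-flash-001"))  # -> "gemini"
print(detect_provider("gpt-4o"))                       # -> "openai"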

View File

@@ -102,21 +102,18 @@ class AgentReasoning:
try:
output = self.__handle_agent_reasoning()
# Emit reasoning completed event
try:
crewai_event_bus.emit(
self.agent,
AgentReasoningCompletedEvent(
agent_role=self.agent.role,
task_id=str(self.task.id),
plan=output.plan.plan,
ready=output.plan.ready,
attempt=1,
from_task=self.task,
),
)
except Exception: # noqa: S110
pass
crewai_event_bus.emit(
self.agent,
AgentReasoningCompletedEvent(
agent_role=self.agent.role,
task_id=str(self.task.id),
plan=output.plan.plan,
ready=output.plan.ready,
attempt=1,
from_task=self.task,
from_agent=self.agent,
),
)
return output
except Exception as e:
@@ -130,10 +127,11 @@ class AgentReasoning:
error=str(e),
attempt=1,
from_task=self.task,
from_agent=self.agent,
),
)
except Exception: # noqa: S110
pass
except Exception as e:
logging.error(f"Error emitting reasoning failed event: {e}")
raise

View File

@@ -4,10 +4,24 @@ This module provides a callback handler that tracks token usage
for LLM API calls through the litellm library.
"""
from typing import Any
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import Usage
else:
try:
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import Usage
except ImportError:
class CustomLogger:
"""Fallback CustomLogger when litellm is not available."""
class Usage:
"""Fallback Usage when litellm is not available."""
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import Usage
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.utilities.logger_utils import suppress_warnings