from __future__ import annotations

import json
import logging
import os
from typing import TYPE_CHECKING, Any, TypedDict

from pydantic import BaseModel
from typing_extensions import Self

from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.exceptions.context_window_exceeding_exception import (
    LLMContextLengthExceededError,
)
from crewai.utilities.pydantic_schema_utils import generate_model_description
from crewai.utilities.types import LLMMessage


if TYPE_CHECKING:
    from crewai.llms.hooks.base import BaseInterceptor


try:
    from azure.ai.inference import ChatCompletionsClient
    from azure.ai.inference.aio import (
        ChatCompletionsClient as AsyncChatCompletionsClient,
    )
    from azure.ai.inference.models import (
        ChatCompletions,
        ChatCompletionsToolCall,
        ChatCompletionsToolDefinition,
        FunctionDefinition,
        JsonSchemaFormat,
        StreamingChatCompletionsUpdate,
    )
    from azure.core.credentials import AzureKeyCredential
    from azure.core.exceptions import HttpResponseError

    from crewai.events.types.llm_events import LLMCallType
    from crewai.llms.base_llm import BaseLLM
except ImportError:
    raise ImportError(
        "Azure AI Inference native provider is not available. "
        'To install it, run: uv add "crewai[azure-ai-inference]"'
    ) from None


class AzureCompletionParams(TypedDict, total=False):
    """Type definition for Azure chat completion parameters."""

    messages: list[LLMMessage]
    stream: bool
    model_extras: dict[str, Any]
    response_format: JsonSchemaFormat
    model: str
    temperature: float
    top_p: float
    frequency_penalty: float
    presence_penalty: float
    max_tokens: int
    stop: list[str]
    tools: list[ChatCompletionsToolDefinition]
    tool_choice: str
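
# A minimal sketch of the request dict this TypedDict describes (hypothetical
# values; every key is optional because of ``total=False``):
#
#     params: AzureCompletionParams = {
#         "messages": [{"role": "user", "content": "Hello"}],
#         "stream": False,
#         "temperature": 0.2,
#     }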


class AzureCompletion(BaseLLM):
    """Azure AI Inference native completion implementation.

    This class provides direct integration with the Azure AI Inference Python SDK,
    offering native function calling, streaming support, and proper Azure authentication.
    """
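
    # Usage sketch (hypothetical deployment name; assumes the AZURE_API_KEY
    # and AZURE_ENDPOINT environment variables are set):
    #
    #     llm = AzureCompletion(model="gpt-4o", temperature=0.2)
    #     answer = llm.call("Summarize CrewAI in one sentence.")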

    def __init__(
        self,
        model: str,
        api_key: str | None = None,
        endpoint: str | None = None,
        api_version: str | None = None,
        timeout: float | None = None,
        max_retries: int = 2,
        temperature: float | None = None,
        top_p: float | None = None,
        frequency_penalty: float | None = None,
        presence_penalty: float | None = None,
        max_tokens: int | None = None,
        stop: list[str] | None = None,
        stream: bool = False,
        interceptor: BaseInterceptor[Any, Any] | None = None,
        **kwargs: Any,
    ):
        """Initialize Azure AI Inference chat completion client.

        Args:
            model: Azure deployment name or model name
            api_key: Azure API key (defaults to AZURE_API_KEY env var)
            endpoint: Azure endpoint URL (defaults to AZURE_ENDPOINT env var)
            api_version: Azure API version (defaults to AZURE_API_VERSION env var)
            timeout: Request timeout in seconds
            max_retries: Maximum number of retries
            temperature: Sampling temperature (0-2)
            top_p: Nucleus sampling parameter
            frequency_penalty: Frequency penalty (-2 to 2)
            presence_penalty: Presence penalty (-2 to 2)
            max_tokens: Maximum tokens in response
            stop: Stop sequences
            stream: Enable streaming responses
            interceptor: HTTP interceptor (not yet supported for Azure)
            **kwargs: Additional parameters
        """
        if interceptor is not None:
            raise NotImplementedError(
                "HTTP interceptors are not yet supported for Azure AI Inference provider. "
                "Interceptors are currently supported for OpenAI and Anthropic providers only."
            )

        super().__init__(
            model=model, temperature=temperature, stop=stop or [], **kwargs
        )

        self.api_key = api_key or os.getenv("AZURE_API_KEY")
        self.endpoint = (
            endpoint
            or os.getenv("AZURE_ENDPOINT")
            or os.getenv("AZURE_OPENAI_ENDPOINT")
            or os.getenv("AZURE_API_BASE")
        )
        self.api_version = api_version or os.getenv("AZURE_API_VERSION") or "2024-06-01"
        self.timeout = timeout
        self.max_retries = max_retries

        if not self.api_key:
            raise ValueError(
                "Azure API key is required. Set AZURE_API_KEY environment variable or pass api_key parameter."
            )
        if not self.endpoint:
            raise ValueError(
                "Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter."
            )

        # Validate and potentially fix Azure OpenAI endpoint URL
        self.endpoint = self._validate_and_fix_endpoint(self.endpoint, model)

        # Build client kwargs
        client_kwargs = {
            "endpoint": self.endpoint,
            "credential": AzureKeyCredential(self.api_key),
        }

        # Add api_version if specified (primarily for Azure OpenAI endpoints)
        if self.api_version:
            client_kwargs["api_version"] = self.api_version

        self.client = ChatCompletionsClient(**client_kwargs)  # type: ignore[arg-type]
        self.async_client = AsyncChatCompletionsClient(**client_kwargs)  # type: ignore[arg-type]

        self.top_p = top_p
        self.frequency_penalty = frequency_penalty
        self.presence_penalty = presence_penalty
        self.max_tokens = max_tokens
        self.stream = stream

        self.is_openai_model = any(
            prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
        )

        self.is_azure_openai_endpoint = (
            "openai.azure.com" in self.endpoint
            and "/openai/deployments/" in self.endpoint
        )

    @staticmethod
    def _validate_and_fix_endpoint(endpoint: str, model: str) -> str:
        """Validate and fix Azure endpoint URL format.

        Azure OpenAI endpoints should be in the format:
        https://<resource-name>.openai.azure.com/openai/deployments/<deployment-name>

        Args:
            endpoint: The endpoint URL
            model: The model/deployment name

        Returns:
            Validated and potentially corrected endpoint URL
        """
        if "openai.azure.com" in endpoint and "/openai/deployments/" not in endpoint:
            endpoint = endpoint.rstrip("/")

            if not endpoint.endswith("/openai/deployments"):
                deployment_name = model.replace("azure/", "")
                endpoint = f"{endpoint}/openai/deployments/{deployment_name}"
                logging.info(f"Constructed Azure OpenAI endpoint URL: {endpoint}")

        return endpoint
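
    # Worked example (hypothetical resource name): a bare Azure OpenAI endpoint
    #     https://my-resource.openai.azure.com
    # combined with model="azure/gpt-4o" is rewritten to
    #     https://my-resource.openai.azure.com/openai/deployments/gpt-4o
    # Non-Azure-OpenAI endpoints are returned unchanged.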

    def _handle_api_error(
        self,
        error: Exception,
        from_task: Any | None = None,
        from_agent: Any | None = None,
    ) -> None:
        """Handle API errors with appropriate logging and events.

        Args:
            error: The exception that occurred
            from_task: Task that initiated the call
            from_agent: Agent that initiated the call

        Raises:
            The original exception after logging and emitting events
        """
        if isinstance(error, HttpResponseError):
            if error.status_code == 401:
                error_msg = "Azure authentication failed. Check your API key."
            elif error.status_code == 404:
                error_msg = (
                    f"Azure endpoint not found. Check endpoint URL: {self.endpoint}"
                )
            elif error.status_code == 429:
                error_msg = "Azure API rate limit exceeded. Please retry later."
            else:
                error_msg = (
                    f"Azure API HTTP error: {error.status_code} - {error.message}"
                )
        else:
            error_msg = f"Azure API call failed: {error!s}"

        logging.error(error_msg)
        self._emit_call_failed_event(
            error=error_msg, from_task=from_task, from_agent=from_agent
        )
        raise error

    def _handle_completion_error(
        self,
        error: Exception,
        from_task: Any | None = None,
        from_agent: Any | None = None,
    ) -> None:
        """Handle completion-specific errors including context length checks.

        Args:
            error: The exception that occurred
            from_task: Task that initiated the call
            from_agent: Agent that initiated the call

        Raises:
            LLMContextLengthExceededError if context window exceeded, otherwise the original exception
        """
        if is_context_length_exceeded(error):
            logging.error(f"Context window exceeded: {error}")
            raise LLMContextLengthExceededError(str(error)) from error

        error_msg = f"Azure API call failed: {error!s}"
        logging.error(error_msg)
        self._emit_call_failed_event(
            error=error_msg, from_task=from_task, from_agent=from_agent
        )
        raise error

    def call(
        self,
        messages: str | list[LLMMessage],
        tools: list[dict[str, Any]] | None = None,
        callbacks: list[Any] | None = None,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """Call Azure AI Inference chat completions API.

        Args:
            messages: Input messages for the chat completion
            tools: List of tool/function definitions
            callbacks: Callback functions (not used in native implementation)
            available_functions: Available functions for tool calling
            from_task: Task that initiated the call
            from_agent: Agent that initiated the call
            response_model: Pydantic model for structured output

        Returns:
            Chat completion response or tool call result
        """
        try:
            # Emit call started event
            self._emit_call_started_event(
                messages=messages,
                tools=tools,
                callbacks=callbacks,
                available_functions=available_functions,
                from_task=from_task,
                from_agent=from_agent,
            )

            # Format messages for Azure
            formatted_messages = self._format_messages_for_azure(messages)

            if not self._invoke_before_llm_call_hooks(formatted_messages, from_agent):
                raise ValueError("LLM call blocked by before_llm_call hook")

            # Prepare completion parameters
            completion_params = self._prepare_completion_params(
                formatted_messages, tools, response_model
            )

            # Handle streaming vs non-streaming
            if self.stream:
                return self._handle_streaming_completion(
                    completion_params,
                    available_functions,
                    from_task,
                    from_agent,
                    response_model,
                )

            return self._handle_completion(
                completion_params,
                available_functions,
                from_task,
                from_agent,
                response_model,
            )

        except Exception as e:
            return self._handle_api_error(e, from_task, from_agent)  # type: ignore[func-returns-value]
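
    # Tool-calling sketch (hypothetical tool, continuing the usage sketch above;
    # the dict follows the OpenAI-style function schema this provider converts):
    #
    #     def add(a: int, b: int) -> int:
    #         return a + b
    #
    #     result = llm.call(
    #         "What is 2 + 3? Use the add tool.",
    #         tools=[{
    #             "type": "function",
    #             "function": {
    #                 "name": "add",
    #                 "description": "Add two integers.",
    #                 "parameters": {
    #                     "type": "object",
    #                     "properties": {
    #                         "a": {"type": "integer"},
    #                         "b": {"type": "integer"},
    #                     },
    #                     "required": ["a", "b"],
    #                 },
    #             },
    #         }],
    #         available_functions={"add": add},
    #     )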

    async def acall(  # type: ignore[return]
        self,
        messages: str | list[LLMMessage],
        tools: list[dict[str, Any]] | None = None,
        callbacks: list[Any] | None = None,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """Call Azure AI Inference chat completions API asynchronously.

        Args:
            messages: Input messages for the chat completion
            tools: List of tool/function definitions
            callbacks: Callback functions (not used in native implementation)
            available_functions: Available functions for tool calling
            from_task: Task that initiated the call
            from_agent: Agent that initiated the call
            response_model: Pydantic model for structured output

        Returns:
            Chat completion response or tool call result
        """
        try:
            self._emit_call_started_event(
                messages=messages,
                tools=tools,
                callbacks=callbacks,
                available_functions=available_functions,
                from_task=from_task,
                from_agent=from_agent,
            )

            formatted_messages = self._format_messages_for_azure(messages)

            completion_params = self._prepare_completion_params(
                formatted_messages, tools, response_model
            )

            if self.stream:
                return await self._ahandle_streaming_completion(
                    completion_params,
                    available_functions,
                    from_task,
                    from_agent,
                    response_model,
                )

            return await self._ahandle_completion(
                completion_params,
                available_functions,
                from_task,
                from_agent,
                response_model,
            )

        except Exception as e:
            self._handle_api_error(e, from_task, from_agent)

    def _prepare_completion_params(
        self,
        messages: list[LLMMessage],
        tools: list[dict[str, Any]] | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> AzureCompletionParams:
        """Prepare parameters for Azure AI Inference chat completion.

        Args:
            messages: Formatted messages for Azure
            tools: Tool definitions
            response_model: Pydantic model for structured output

        Returns:
            Parameters dictionary for Azure API
        """
        params: AzureCompletionParams = {
            "messages": messages,
            "stream": self.stream,
        }

        if self.stream:
            params["model_extras"] = {"stream_options": {"include_usage": True}}

        if response_model and self.is_openai_model:
            model_description = generate_model_description(response_model)
            json_schema_info = model_description["json_schema"]
            json_schema_name = json_schema_info["name"]

            params["response_format"] = JsonSchemaFormat(
                name=json_schema_name,
                schema=json_schema_info["schema"],
                description=f"Schema for {json_schema_name}",
                strict=json_schema_info["strict"],
            )

        # Only include model parameter for non-Azure OpenAI endpoints;
        # Azure OpenAI endpoints have the deployment name in the URL
        if not self.is_azure_openai_endpoint:
            params["model"] = self.model

        # Add optional parameters if set
        if self.temperature is not None:
            params["temperature"] = self.temperature
        if self.top_p is not None:
            params["top_p"] = self.top_p
        if self.frequency_penalty is not None:
            params["frequency_penalty"] = self.frequency_penalty
        if self.presence_penalty is not None:
            params["presence_penalty"] = self.presence_penalty
        if self.max_tokens is not None:
            params["max_tokens"] = self.max_tokens
        if self.stop:
            params["stop"] = self.stop

        # Handle tools/functions for Azure OpenAI models
        if tools and self.is_openai_model:
            params["tools"] = self._convert_tools_for_interference(tools)
            params["tool_choice"] = "auto"

        additional_params = self.additional_params
        additional_drop_params = additional_params.get("additional_drop_params")
        drop_params = additional_params.get("drop_params")

        if drop_params and isinstance(additional_drop_params, list):
            for drop_param in additional_drop_params:
                if isinstance(drop_param, str):
                    params.pop(drop_param, None)  # type: ignore[misc]

        return params
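
    # Drop-params sketch (hypothetical kwargs): extra constructor kwargs are kept
    # in ``additional_params`` by the base class, so a deployment that rejects a
    # parameter such as ``temperature`` can have it stripped from the request:
    #
    #     llm = AzureCompletion(
    #         model="my-deployment",
    #         temperature=0.7,
    #         drop_params=True,
    #         additional_drop_params=["temperature"],
    #     )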

    def _convert_tools_for_interference(  # type: ignore[override]
        self, tools: list[dict[str, Any]]
    ) -> list[ChatCompletionsToolDefinition]:
        """Convert CrewAI tool format to Azure OpenAI function calling format.

        Note: the method name (sic) is kept as-is to match the base-class hook
        it overrides.

        Args:
            tools: List of CrewAI tool definitions

        Returns:
            List of Azure ChatCompletionsToolDefinition objects
        """
        from crewai.llms.providers.utils.common import safe_tool_conversion

        azure_tools: list[ChatCompletionsToolDefinition] = []

        for tool in tools:
            name, description, parameters = safe_tool_conversion(tool, "Azure")

            function_def = FunctionDefinition(
                name=name,
                description=description,
                parameters=parameters
                if isinstance(parameters, dict)
                else dict(parameters)
                if parameters
                else None,
            )

            tool_def = ChatCompletionsToolDefinition(function=function_def)
            azure_tools.append(tool_def)

        return azure_tools
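
    # Conversion sketch (hypothetical tool dict): an OpenAI-style definition
    #
    #     {"type": "function", "function": {"name": "add",
    #      "description": "Add two integers.",
    #      "parameters": {"type": "object", "properties": {...}}}}
    #
    # becomes roughly
    #
    #     ChatCompletionsToolDefinition(function=FunctionDefinition(
    #         name="add", description="Add two integers.", parameters={...}))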

    def _format_messages_for_azure(
        self, messages: str | list[LLMMessage]
    ) -> list[LLMMessage]:
        """Format messages for Azure AI Inference API.

        Args:
            messages: Input messages

        Returns:
            List of dict objects with 'role' and 'content' keys
        """
        # Use base class formatting first
        base_formatted = super()._format_messages(messages)

        azure_messages: list[LLMMessage] = []

        for message in base_formatted:
            role = message.get("role", "user")  # Default to user if no role
            content = message.get("content", "")

            # Azure AI Inference requires both 'role' and 'content'
            azure_messages.append({"role": role, "content": content})

        return azure_messages

    def _validate_and_emit_structured_output(
        self,
        content: str,
        response_model: type[BaseModel],
        params: AzureCompletionParams,
        from_task: Any | None = None,
        from_agent: Any | None = None,
    ) -> str:
        """Validate content against response model and emit completion event.

        Args:
            content: Response content to validate
            response_model: Pydantic model for validation
            params: Completion parameters containing messages
            from_task: Task that initiated the call
            from_agent: Agent that initiated the call

        Returns:
            Validated and serialized JSON string

        Raises:
            ValueError: If validation fails
        """
        try:
            structured_data = response_model.model_validate_json(content)
            structured_json = structured_data.model_dump_json()

            self._emit_call_completed_event(
                response=structured_json,
                call_type=LLMCallType.LLM_CALL,
                from_task=from_task,
                from_agent=from_agent,
                messages=params["messages"],
            )

            return structured_json
        except Exception as e:
            error_msg = f"Failed to validate structured output with model {response_model.__name__}: {e}"
            logging.error(error_msg)
            raise ValueError(error_msg) from e
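
    # Structured-output sketch (hypothetical model): with an OpenAI-family
    # deployment, ``response_model`` is enforced through the JSON-schema
    # response format and the validated JSON string is returned:
    #
    #     class City(BaseModel):
    #         name: str
    #         population: int
    #
    #     json_str = llm.call("Largest city in France?", response_model=City)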

    def _process_completion_response(
        self,
        response: ChatCompletions,
        params: AzureCompletionParams,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """Process completion response with usage tracking, tool execution, and events.

        Args:
            response: Chat completion response from Azure API
            params: Completion parameters containing messages
            available_functions: Available functions for tool calling
            from_task: Task that initiated the call
            from_agent: Agent that initiated the call
            response_model: Pydantic model for structured output

        Returns:
            Response content or structured output
        """
        if not response.choices:
            raise ValueError("No choices returned from Azure API")

        choice = response.choices[0]
        message = choice.message

        # Extract and track token usage
        usage = self._extract_azure_token_usage(response)
        self._track_token_usage_internal(usage)

        if response_model and self.is_openai_model:
            content = message.content or ""
            return self._validate_and_emit_structured_output(
                content=content,
                response_model=response_model,
                params=params,
                from_task=from_task,
                from_agent=from_agent,
            )

        # Handle tool calls
        if message.tool_calls and available_functions:
            tool_call = message.tool_calls[0]  # Handle first tool call
            if isinstance(tool_call, ChatCompletionsToolCall):
                function_name = tool_call.function.name

                try:
                    function_args = json.loads(tool_call.function.arguments)
                except json.JSONDecodeError as e:
                    logging.error(f"Failed to parse tool arguments: {e}")
                    function_args = {}

                # Execute tool
                result = self._handle_tool_execution(
                    function_name=function_name,
                    function_args=function_args,
                    available_functions=available_functions,
                    from_task=from_task,
                    from_agent=from_agent,
                )

                if result is not None:
                    return result

        # Extract content
        content = message.content or ""

        # Apply stop words
        content = self._apply_stop_words(content)

        # Emit completion event and return content
        self._emit_call_completed_event(
            response=content,
            call_type=LLMCallType.LLM_CALL,
            from_task=from_task,
            from_agent=from_agent,
            messages=params["messages"],
        )

        return self._invoke_after_llm_call_hooks(
            params["messages"], content, from_agent
        )

    def _handle_completion(
        self,
        params: AzureCompletionParams,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """Handle non-streaming chat completion."""
        try:
            # The ignore comment works around type-checker limits on TypedDict unpacking
            response: ChatCompletions = self.client.complete(**params)  # type: ignore[assignment,arg-type]
            return self._process_completion_response(
                response=response,
                params=params,
                available_functions=available_functions,
                from_task=from_task,
                from_agent=from_agent,
                response_model=response_model,
            )
        except Exception as e:
            return self._handle_completion_error(e, from_task, from_agent)  # type: ignore[func-returns-value]

    def _process_streaming_update(
        self,
        update: StreamingChatCompletionsUpdate,
        full_response: str,
        tool_calls: dict[str, dict[str, str]],
        from_task: Any | None = None,
        from_agent: Any | None = None,
    ) -> str:
        """Process a single streaming update chunk.

        Args:
            update: Streaming update from Azure API
            full_response: Accumulated response content
            tool_calls: Dictionary of accumulated tool calls
            from_task: Task that initiated the call
            from_agent: Agent that initiated the call

        Returns:
            Updated full_response string
        """
        if update.choices:
            choice = update.choices[0]
            if choice.delta and choice.delta.content:
                content_delta = choice.delta.content
                full_response += content_delta
                self._emit_stream_chunk_event(
                    chunk=content_delta,
                    from_task=from_task,
                    from_agent=from_agent,
                )

            if choice.delta and choice.delta.tool_calls:
                for tool_call in choice.delta.tool_calls:
                    call_id = tool_call.id or "default"
                    if call_id not in tool_calls:
                        tool_calls[call_id] = {
                            "name": "",
                            "arguments": "",
                        }

                    if tool_call.function and tool_call.function.name:
                        tool_calls[call_id]["name"] = tool_call.function.name
                    if tool_call.function and tool_call.function.arguments:
                        tool_calls[call_id]["arguments"] += tool_call.function.arguments

        return full_response
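
    # Accumulation sketch: tool-call arguments arrive as string fragments across
    # chunks and are concatenated per call id, e.g. '{"a": 2,' then ' "b": 3}'
    # accumulate to '{"a": 2, "b": 3}' and are json.loads()-ed during finalize.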

    def _finalize_streaming_response(
        self,
        full_response: str,
        tool_calls: dict[str, dict[str, str]],
        usage_data: dict[str, int],
        params: AzureCompletionParams,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """Finalize streaming response with usage tracking, tool execution, and events.

        Args:
            full_response: The complete streamed response content
            tool_calls: Dictionary of tool calls accumulated during streaming
            usage_data: Token usage data from the stream
            params: Completion parameters containing messages
            available_functions: Available functions for tool calling
            from_task: Task that initiated the call
            from_agent: Agent that initiated the call
            response_model: Pydantic model for structured output validation

        Returns:
            Final response content after processing, or structured output
        """
        self._track_token_usage_internal(usage_data)

        # Handle structured output validation
        if response_model and self.is_openai_model:
            return self._validate_and_emit_structured_output(
                content=full_response,
                response_model=response_model,
                params=params,
                from_task=from_task,
                from_agent=from_agent,
            )

        # Handle completed tool calls
        if tool_calls and available_functions:
            for call_data in tool_calls.values():
                function_name = call_data["name"]

                try:
                    function_args = json.loads(call_data["arguments"])
                except json.JSONDecodeError as e:
                    logging.error(f"Failed to parse streamed tool arguments: {e}")
                    continue

                # Execute tool
                result = self._handle_tool_execution(
                    function_name=function_name,
                    function_args=function_args,
                    available_functions=available_functions,
                    from_task=from_task,
                    from_agent=from_agent,
                )

                if result is not None:
                    return result

        # Apply stop words to full response
        full_response = self._apply_stop_words(full_response)

        # Emit completion event and return full response
        self._emit_call_completed_event(
            response=full_response,
            call_type=LLMCallType.LLM_CALL,
            from_task=from_task,
            from_agent=from_agent,
            messages=params["messages"],
        )

        return self._invoke_after_llm_call_hooks(
            params["messages"], full_response, from_agent
        )

    def _handle_streaming_completion(
        self,
        params: AzureCompletionParams,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """Handle streaming chat completion."""
        full_response = ""
        tool_calls: dict[str, dict[str, Any]] = {}

        usage_data = {"total_tokens": 0}
        for update in self.client.complete(**params):  # type: ignore[arg-type]
            if isinstance(update, StreamingChatCompletionsUpdate):
                if update.usage:
                    usage = update.usage
                    usage_data = {
                        "prompt_tokens": usage.prompt_tokens,
                        "completion_tokens": usage.completion_tokens,
                        "total_tokens": usage.total_tokens,
                    }
                    continue

                full_response = self._process_streaming_update(
                    update=update,
                    full_response=full_response,
                    tool_calls=tool_calls,
                    from_task=from_task,
                    from_agent=from_agent,
                )

        return self._finalize_streaming_response(
            full_response=full_response,
            tool_calls=tool_calls,
            usage_data=usage_data,
            params=params,
            available_functions=available_functions,
            from_task=from_task,
            from_agent=from_agent,
            response_model=response_model,
        )

    async def _ahandle_completion(
        self,
        params: AzureCompletionParams,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """Handle non-streaming chat completion asynchronously."""
        try:
            # The ignore comment works around type-checker limits on TypedDict unpacking
            response: ChatCompletions = await self.async_client.complete(**params)  # type: ignore[assignment,arg-type]
            return self._process_completion_response(
                response=response,
                params=params,
                available_functions=available_functions,
                from_task=from_task,
                from_agent=from_agent,
                response_model=response_model,
            )
        except Exception as e:
            return self._handle_completion_error(e, from_task, from_agent)  # type: ignore[func-returns-value]

    async def _ahandle_streaming_completion(
        self,
        params: AzureCompletionParams,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """Handle streaming chat completion asynchronously."""
        full_response = ""
        tool_calls: dict[str, dict[str, Any]] = {}

        usage_data = {"total_tokens": 0}

        stream = await self.async_client.complete(**params)  # type: ignore[arg-type]
        async for update in stream:  # type: ignore[union-attr]
            if isinstance(update, StreamingChatCompletionsUpdate):
                if hasattr(update, "usage") and update.usage:
                    usage = update.usage
                    usage_data = {
                        "prompt_tokens": getattr(usage, "prompt_tokens", 0),
                        "completion_tokens": getattr(usage, "completion_tokens", 0),
                        "total_tokens": getattr(usage, "total_tokens", 0),
                    }
                    continue

                full_response = self._process_streaming_update(
                    update=update,
                    full_response=full_response,
                    tool_calls=tool_calls,
                    from_task=from_task,
                    from_agent=from_agent,
                )

        return self._finalize_streaming_response(
            full_response=full_response,
            tool_calls=tool_calls,
            usage_data=usage_data,
            params=params,
            available_functions=available_functions,
            from_task=from_task,
            from_agent=from_agent,
            response_model=response_model,
        )

    def supports_function_calling(self) -> bool:
        """Check if the model supports function calling."""
        # Azure OpenAI models support function calling
        return self.is_openai_model

    def supports_stop_words(self) -> bool:
        """Check if the model supports stop words."""
        return True  # Most Azure models support stop sequences

    def get_context_window_size(self) -> int:
        """Get the context window size for the model."""
        from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES

        min_context = 1024
        max_context = 2097152

        for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
            if value < min_context or value > max_context:
                raise ValueError(
                    f"Context window for {key} must be between {min_context} and {max_context}"
                )

        # Context window sizes for common Azure models
        context_windows = {
            "gpt-4": 8192,
            "gpt-4o": 128000,
            "gpt-4o-mini": 128000,
            "gpt-4-turbo": 128000,
            "gpt-35-turbo": 16385,
            "gpt-3.5-turbo": 16385,
            "text-embedding": 8191,
        }

        # Find the best match for the model name
        for model_prefix, size in sorted(
            context_windows.items(), key=lambda x: len(x[0]), reverse=True
        ):
            if self.model.startswith(model_prefix):
                return int(size * CONTEXT_WINDOW_USAGE_RATIO)

        # Default context window size
        return int(8192 * CONTEXT_WINDOW_USAGE_RATIO)
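
    # Matching sketch: prefixes are tried longest-first, so a deployment named
    # "gpt-4o-mini-2024" matches "gpt-4o-mini" before "gpt-4o", and the matched
    # window is scaled by CONTEXT_WINDOW_USAGE_RATIO before being returned.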

    @staticmethod
    def _extract_azure_token_usage(response: ChatCompletions) -> dict[str, Any]:
        """Extract token usage from Azure response."""
        if hasattr(response, "usage") and response.usage:
            usage = response.usage
            return {
                "prompt_tokens": getattr(usage, "prompt_tokens", 0),
                "completion_tokens": getattr(usage, "completion_tokens", 0),
                "total_tokens": getattr(usage, "total_tokens", 0),
            }
        return {"total_tokens": 0}

    async def aclose(self) -> None:
        """Close the async client and clean up resources.

        This ensures proper cleanup of the underlying aiohttp session
        to avoid unclosed-connector warnings.
        """
        if hasattr(self.async_client, "close"):
            await self.async_client.close()

    async def __aenter__(self) -> Self:
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Async context manager exit."""
        await self.aclose()
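
# Async usage sketch (hypothetical prompt): the context manager guarantees the
# underlying aiohttp session is closed:
#
#     async def main() -> str:
#         async with AzureCompletion(model="gpt-4o") as llm:
#             return await llm.acall("Hello!")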