feat: azure native tests

* feat: add Azure AI Inference support and related tests

- Introduced the `azure-ai-inference` optional dependency (`>=1.0.0b9`) in `pyproject.toml`, with its transitive dependencies locked in `uv.lock`.
- Added new test files for Azure LLM functionality, including tests for Azure completion and tool handling.
- Implemented comprehensive test cases to validate Azure-specific behavior and integration with the CrewAI framework.
- Enhanced the testing framework to mock Azure credentials and ensure proper isolation during tests (see the sketch below).
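
Not part of the diff, but a minimal sketch of what that isolation typically looks like with pytest's monkeypatch (the fixture name and the AZURE_API_KEY variable name are assumptions, not taken from the actual test files):

import pytest

@pytest.fixture(autouse=True)
def azure_test_env(monkeypatch: pytest.MonkeyPatch) -> None:
    # Fake credentials let AzureCompletion initialize without a real Azure
    # resource; monkeypatch restores the environment after each test.
    monkeypatch.setenv("AZURE_API_KEY", "test-key")  # variable name assumed
    monkeypatch.setenv("AZURE_ENDPOINT", "https://example.openai.azure.com")
    monkeypatch.setenv("AZURE_API_VERSION", "2024-06-01")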

* feat: enhance AzureCompletion class with Azure OpenAI support

- Added support for Azure OpenAI endpoints in the `AzureCompletion` class, allowing flexible endpoint configuration (see the usage sketch after this list).
- Implemented endpoint validation and correction to ensure proper URL formats for Azure OpenAI deployments.
- Enhanced error handling to provide clearer messages for common HTTP errors, including authentication and rate limit issues.
- Updated tests to validate the new endpoint handling and error messaging, ensuring robust integration with Azure AI Inference.
- Refactored parameter preparation to conditionally include the model parameter based on the endpoint type.
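
A hypothetical usage sketch of the two endpoint styles (constructor keywords are inferred from the env-var fallbacks in the diff below, not a verbatim API reference):

from crewai.llms.providers.azure.completion import AzureCompletion

# Azure OpenAI: the deployment name lives in the endpoint URL, so the
# "model" request parameter is omitted; a bare resource URL is corrected
# to .../openai/deployments/<model>.
openai_llm = AzureCompletion(
    model="gpt-4o",
    endpoint="https://my-resource.openai.azure.com",
    api_key="test-key",
)

# Serverless Azure AI Inference: the endpoint is used as-is and "model"
# is sent with each request.
serverless_llm = AzureCompletion(
    model="mistral-large",
    endpoint="https://my-model.region.models.ai.azure.com",
    api_key="test-key",
)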
Author: Lorenze Jay
Date: 2025-10-17 08:36:29 -07:00
Committed by: GitHub
Parent: 3b32793e78
Commit: 6515c7faeb
7 changed files with 1214 additions and 25 deletions


@@ -90,6 +90,9 @@ boto3 = [
 google-genai = [
     "google-genai>=1.2.0",
 ]
+azure-ai-inference = [
+    "azure-ai-inference>=1.0.0b9",
+]
 [project.scripts]


@@ -351,7 +351,7 @@ class LLM(BaseLLM):
             except ImportError:
                 return None
-        elif provider == "azure":
+        elif provider == "azure" or provider == "azure_openai":
            try:
                from crewai.llms.providers.azure.completion import AzureCompletion
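
Assuming CrewAI's usual "<provider>/<model>" convention for model strings, either spelling should now resolve to the same backend; a hypothetical sketch:

from crewai import LLM

# Both provider prefixes route to AzureCompletion after this change
# (model strings here are illustrative).
llm_a = LLM(model="azure/gpt-4o")
llm_b = LLM(model="azure_openai/gpt-4o")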


@@ -10,14 +10,14 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
 try:
-    from azure.ai.inference import ChatCompletionsClient  # type: ignore
-    from azure.ai.inference.models import (  # type: ignore
+    from azure.ai.inference import ChatCompletionsClient
+    from azure.ai.inference.models import (
         ChatCompletions,
         ChatCompletionsToolCall,
         StreamingChatCompletionsUpdate,
     )
-    from azure.core.credentials import AzureKeyCredential  # type: ignore
-    from azure.core.exceptions import HttpResponseError  # type: ignore
+    from azure.core.credentials import AzureKeyCredential
+    from azure.core.exceptions import HttpResponseError

 from crewai.events.types.llm_events import LLMCallType
 from crewai.llms.base_llm import BaseLLM
@@ -80,7 +80,9 @@ class AzureCompletion(BaseLLM):
             or os.getenv("AZURE_OPENAI_ENDPOINT")
             or os.getenv("AZURE_API_BASE")
         )
-        self.api_version = api_version or os.getenv("AZURE_API_VERSION") or "2024-02-01"
+        self.api_version = api_version or os.getenv("AZURE_API_VERSION") or "2024-06-01"
+        self.timeout = timeout
+        self.max_retries = max_retries

         if not self.api_key:
             raise ValueError(
@@ -91,10 +93,20 @@ class AzureCompletion(BaseLLM):
                 "Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter."
             )

-        self.client = ChatCompletionsClient(
-            endpoint=self.endpoint,
-            credential=AzureKeyCredential(self.api_key),
-        )
+        # Validate and potentially fix Azure OpenAI endpoint URL
+        self.endpoint = self._validate_and_fix_endpoint(self.endpoint, model)
+
+        # Build client kwargs
+        client_kwargs = {
+            "endpoint": self.endpoint,
+            "credential": AzureKeyCredential(self.api_key),
+        }
+
+        # Add api_version if specified (primarily for Azure OpenAI endpoints)
+        if self.api_version:
+            client_kwargs["api_version"] = self.api_version
+
+        self.client = ChatCompletionsClient(**client_kwargs)

         self.top_p = top_p
         self.frequency_penalty = frequency_penalty
@@ -106,6 +118,34 @@ class AzureCompletion(BaseLLM):
             prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
         )

+        self.is_azure_openai_endpoint = (
+            "openai.azure.com" in self.endpoint
+            and "/openai/deployments/" in self.endpoint
+        )
+
+    def _validate_and_fix_endpoint(self, endpoint: str, model: str) -> str:
+        """Validate and fix Azure endpoint URL format.
+
+        Azure OpenAI endpoints should be in the format:
+        https://<resource-name>.openai.azure.com/openai/deployments/<deployment-name>
+
+        Args:
+            endpoint: The endpoint URL
+            model: The model/deployment name
+
+        Returns:
+            Validated and potentially corrected endpoint URL
+        """
+        if "openai.azure.com" in endpoint and "/openai/deployments/" not in endpoint:
+            endpoint = endpoint.rstrip("/")
+
+            if not endpoint.endswith("/openai/deployments"):
+                deployment_name = model.replace("azure/", "")
+                endpoint = f"{endpoint}/openai/deployments/{deployment_name}"
+
+                logging.info(f"Constructed Azure OpenAI endpoint URL: {endpoint}")
+
+        return endpoint
+
     def call(
         self,
         messages: str | list[dict[str, str]],
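
For illustration, the correction rule can be exercised standalone; this mirrors the method body above rather than importing it from the package:

def validate_and_fix_endpoint(endpoint: str, model: str) -> str:
    # Append the deployment path when given a bare Azure OpenAI resource URL.
    if "openai.azure.com" in endpoint and "/openai/deployments/" not in endpoint:
        endpoint = endpoint.rstrip("/")
        if not endpoint.endswith("/openai/deployments"):
            endpoint = f"{endpoint}/openai/deployments/{model.replace('azure/', '')}"
    return endpoint

assert (
    validate_and_fix_endpoint("https://my-res.openai.azure.com/", "azure/gpt-4o")
    == "https://my-res.openai.azure.com/openai/deployments/gpt-4o"
)
# Non-Azure-OpenAI endpoints (e.g. serverless deployments) pass through unchanged.
assert (
    validate_and_fix_endpoint("https://my-model.models.ai.azure.com", "mistral-large")
    == "https://my-model.models.ai.azure.com"
)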
@@ -158,7 +198,17 @@ class AzureCompletion(BaseLLM):
             )
         except HttpResponseError as e:
-            error_msg = f"Azure API HTTP error: {e.status_code} - {e.message}"
+            if e.status_code == 401:
+                error_msg = "Azure authentication failed. Check your API key."
+            elif e.status_code == 404:
+                error_msg = (
+                    f"Azure endpoint not found. Check endpoint URL: {self.endpoint}"
+                )
+            elif e.status_code == 429:
+                error_msg = "Azure API rate limit exceeded. Please retry later."
+            else:
+                error_msg = f"Azure API HTTP error: {e.status_code} - {e.message}"
+
             logging.error(error_msg)
             self._emit_call_failed_event(
                 error=error_msg, from_task=from_task, from_agent=from_agent
@@ -187,11 +237,15 @@ class AzureCompletion(BaseLLM):
             Parameters dictionary for Azure API
         """
         params = {
-            "model": self.model,
             "messages": messages,
             "stream": self.stream,
         }

+        # Only include model parameter for non-Azure OpenAI endpoints
+        # Azure OpenAI endpoints have the deployment name in the URL
+        if not self.is_azure_openai_endpoint:
+            params["model"] = self.model
+
         # Add optional parameters if set
         if self.temperature is not None:
             params["temperature"] = self.temperature
@@ -250,7 +304,7 @@ class AzureCompletion(BaseLLM):
             messages: Input messages

         Returns:
-            List of dict objects
+            List of dict objects with 'role' and 'content' keys
         """
         # Use base class formatting first
         base_formatted = super()._format_messages(messages)
@@ -258,18 +312,11 @@ class AzureCompletion(BaseLLM):
         azure_messages = []
         for message in base_formatted:
-            role = message.get("role")
+            role = message.get("role", "user")  # Default to user if no role
             content = message.get("content", "")

-            if role == "system":
-                azure_messages.append(dict(content=content))
-            elif role == "user":
-                azure_messages.append(dict(content=content))
-            elif role == "assistant":
-                azure_messages.append(dict(content=content))
-            else:
-                # Default to user message for unknown roles
-                azure_messages.append(dict(content=content))
+            # Azure AI Inference requires both 'role' and 'content'
+            azure_messages.append({"role": role, "content": content})

         return azure_messages
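
Concretely, assuming the base-class formatter passes dicts through, an entry with a missing role now comes out as a complete message:

message = {"content": "You are a helpful assistant."}  # no "role" key
role = message.get("role", "user")
content = message.get("content", "")
assert {"role": role, "content": content} == {
    "role": "user",
    "content": "You are a helpful assistant.",
}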
@@ -339,6 +386,13 @@ class AzureCompletion(BaseLLM):
                 logging.error(f"Context window exceeded: {e}")
                 raise LLMContextLengthExceededError(str(e)) from e
+
+            error_msg = f"Azure API call failed: {e!s}"
+            logging.error(error_msg)
+            self._emit_call_failed_event(
+                error=error_msg, from_task=from_task, from_agent=from_agent
+            )
+            raise e

         return content

     def _handle_streaming_completion(
@@ -454,7 +508,9 @@ class AzureCompletion(BaseLLM):
         }

         # Find the best match for the model name
-        for model_prefix, size in context_windows.items():
+        for model_prefix, size in sorted(
+            context_windows.items(), key=lambda x: len(x[0]), reverse=True
+        ):
             if self.model.startswith(model_prefix):
                 return int(size * CONTEXT_WINDOW_USAGE_RATIO)
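
The longest-prefix-first sort matters whenever one model name is a prefix of another; the window sizes below are illustrative, not values from the file:

context_windows = {"gpt-4": 8_192, "gpt-4o": 128_000, "gpt-4o-mini": 128_000}

# Plain dict iteration could match "gpt-4o-mini" against the shorter prefix
# "gpt-4" first; sorting by prefix length, longest first, picks the most
# specific entry.
for prefix, size in sorted(context_windows.items(), key=lambda x: len(x[0]), reverse=True):
    if "gpt-4o-mini".startswith(prefix):
        assert (prefix, size) == ("gpt-4o-mini", 128_000)
        break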


@@ -0,0 +1,3 @@
+# Azure LLM tests

File diff suppressed because it is too large.