mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 00:58:30 +00:00
feat: azure native tests
* feat: add Azure AI Inference support and related tests - Introduced the `azure-ai-inference` package with version `1.0.0b9` and its dependencies in `uv.lock` and `pyproject.toml`. - Added new test files for Azure LLM functionality, including tests for Azure completion and tool handling. - Implemented comprehensive test cases to validate Azure-specific behavior and integration with the CrewAI framework. - Enhanced the testing framework to mock Azure credentials and ensure proper isolation during tests. * feat: enhance AzureCompletion class with Azure OpenAI support - Added support for the Azure OpenAI endpoint in the AzureCompletion class, allowing for flexible endpoint configurations. - Implemented endpoint validation and correction to ensure proper URL formats for Azure OpenAI deployments. - Enhanced error handling to provide clearer messages for common HTTP errors, including authentication and rate limit issues. - Updated tests to validate the new endpoint handling and error messaging, ensuring robust integration with Azure AI Inference. - Refactored parameter preparation to conditionally include the model parameter based on the endpoint type.
This commit is contained in:
@@ -90,6 +90,9 @@ boto3 = [
|
||||
google-genai = [
|
||||
"google-genai>=1.2.0",
|
||||
]
|
||||
azure-ai-inference = [
|
||||
"azure-ai-inference>=1.0.0b9",
|
||||
]
|
||||
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -351,7 +351,7 @@ class LLM(BaseLLM):
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
elif provider == "azure":
|
||||
elif provider == "azure" or provider == "azure_openai":
|
||||
try:
|
||||
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||
|
||||
|
||||
@@ -10,14 +10,14 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
|
||||
|
||||
try:
|
||||
from azure.ai.inference import ChatCompletionsClient # type: ignore
|
||||
from azure.ai.inference.models import ( # type: ignore
|
||||
from azure.ai.inference import ChatCompletionsClient
|
||||
from azure.ai.inference.models import (
|
||||
ChatCompletions,
|
||||
ChatCompletionsToolCall,
|
||||
StreamingChatCompletionsUpdate,
|
||||
)
|
||||
from azure.core.credentials import AzureKeyCredential # type: ignore
|
||||
from azure.core.exceptions import HttpResponseError # type: ignore
|
||||
from azure.core.credentials import AzureKeyCredential
|
||||
from azure.core.exceptions import HttpResponseError
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
|
||||
@@ -80,7 +80,9 @@ class AzureCompletion(BaseLLM):
|
||||
or os.getenv("AZURE_OPENAI_ENDPOINT")
|
||||
or os.getenv("AZURE_API_BASE")
|
||||
)
|
||||
self.api_version = api_version or os.getenv("AZURE_API_VERSION") or "2024-02-01"
|
||||
self.api_version = api_version or os.getenv("AZURE_API_VERSION") or "2024-06-01"
|
||||
self.timeout = timeout
|
||||
self.max_retries = max_retries
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError(
|
||||
@@ -91,10 +93,20 @@ class AzureCompletion(BaseLLM):
|
||||
"Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter."
|
||||
)
|
||||
|
||||
self.client = ChatCompletionsClient(
|
||||
endpoint=self.endpoint,
|
||||
credential=AzureKeyCredential(self.api_key),
|
||||
)
|
||||
# Validate and potentially fix Azure OpenAI endpoint URL
|
||||
self.endpoint = self._validate_and_fix_endpoint(self.endpoint, model)
|
||||
|
||||
# Build client kwargs
|
||||
client_kwargs = {
|
||||
"endpoint": self.endpoint,
|
||||
"credential": AzureKeyCredential(self.api_key),
|
||||
}
|
||||
|
||||
# Add api_version if specified (primarily for Azure OpenAI endpoints)
|
||||
if self.api_version:
|
||||
client_kwargs["api_version"] = self.api_version
|
||||
|
||||
self.client = ChatCompletionsClient(**client_kwargs)
|
||||
|
||||
self.top_p = top_p
|
||||
self.frequency_penalty = frequency_penalty
|
||||
@@ -106,6 +118,34 @@ class AzureCompletion(BaseLLM):
|
||||
prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
|
||||
)
|
||||
|
||||
self.is_azure_openai_endpoint = (
|
||||
"openai.azure.com" in self.endpoint
|
||||
and "/openai/deployments/" in self.endpoint
|
||||
)
|
||||
|
||||
def _validate_and_fix_endpoint(self, endpoint: str, model: str) -> str:
|
||||
"""Validate and fix Azure endpoint URL format.
|
||||
|
||||
Azure OpenAI endpoints should be in the format:
|
||||
https://<resource-name>.openai.azure.com/openai/deployments/<deployment-name>
|
||||
|
||||
Args:
|
||||
endpoint: The endpoint URL
|
||||
model: The model/deployment name
|
||||
|
||||
Returns:
|
||||
Validated and potentially corrected endpoint URL
|
||||
"""
|
||||
if "openai.azure.com" in endpoint and "/openai/deployments/" not in endpoint:
|
||||
endpoint = endpoint.rstrip("/")
|
||||
|
||||
if not endpoint.endswith("/openai/deployments"):
|
||||
deployment_name = model.replace("azure/", "")
|
||||
endpoint = f"{endpoint}/openai/deployments/{deployment_name}"
|
||||
logging.info(f"Constructed Azure OpenAI endpoint URL: {endpoint}")
|
||||
|
||||
return endpoint
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: str | list[dict[str, str]],
|
||||
@@ -158,7 +198,17 @@ class AzureCompletion(BaseLLM):
|
||||
)
|
||||
|
||||
except HttpResponseError as e:
|
||||
error_msg = f"Azure API HTTP error: {e.status_code} - {e.message}"
|
||||
if e.status_code == 401:
|
||||
error_msg = "Azure authentication failed. Check your API key."
|
||||
elif e.status_code == 404:
|
||||
error_msg = (
|
||||
f"Azure endpoint not found. Check endpoint URL: {self.endpoint}"
|
||||
)
|
||||
elif e.status_code == 429:
|
||||
error_msg = "Azure API rate limit exceeded. Please retry later."
|
||||
else:
|
||||
error_msg = f"Azure API HTTP error: {e.status_code} - {e.message}"
|
||||
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
@@ -187,11 +237,15 @@ class AzureCompletion(BaseLLM):
|
||||
Parameters dictionary for Azure API
|
||||
"""
|
||||
params = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"stream": self.stream,
|
||||
}
|
||||
|
||||
# Only include model parameter for non-Azure OpenAI endpoints
|
||||
# Azure OpenAI endpoints have the deployment name in the URL
|
||||
if not self.is_azure_openai_endpoint:
|
||||
params["model"] = self.model
|
||||
|
||||
# Add optional parameters if set
|
||||
if self.temperature is not None:
|
||||
params["temperature"] = self.temperature
|
||||
@@ -250,7 +304,7 @@ class AzureCompletion(BaseLLM):
|
||||
messages: Input messages
|
||||
|
||||
Returns:
|
||||
List of dict objects
|
||||
List of dict objects with 'role' and 'content' keys
|
||||
"""
|
||||
# Use base class formatting first
|
||||
base_formatted = super()._format_messages(messages)
|
||||
@@ -258,18 +312,11 @@ class AzureCompletion(BaseLLM):
|
||||
azure_messages = []
|
||||
|
||||
for message in base_formatted:
|
||||
role = message.get("role")
|
||||
role = message.get("role", "user") # Default to user if no role
|
||||
content = message.get("content", "")
|
||||
|
||||
if role == "system":
|
||||
azure_messages.append(dict(content=content))
|
||||
elif role == "user":
|
||||
azure_messages.append(dict(content=content))
|
||||
elif role == "assistant":
|
||||
azure_messages.append(dict(content=content))
|
||||
else:
|
||||
# Default to user message for unknown roles
|
||||
azure_messages.append(dict(content=content))
|
||||
# Azure AI Inference requires both 'role' and 'content'
|
||||
azure_messages.append({"role": role, "content": content})
|
||||
|
||||
return azure_messages
|
||||
|
||||
@@ -339,6 +386,13 @@ class AzureCompletion(BaseLLM):
|
||||
logging.error(f"Context window exceeded: {e}")
|
||||
raise LLMContextLengthExceededError(str(e)) from e
|
||||
|
||||
error_msg = f"Azure API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise e
|
||||
|
||||
return content
|
||||
|
||||
def _handle_streaming_completion(
|
||||
@@ -454,7 +508,9 @@ class AzureCompletion(BaseLLM):
|
||||
}
|
||||
|
||||
# Find the best match for the model name
|
||||
for model_prefix, size in context_windows.items():
|
||||
for model_prefix, size in sorted(
|
||||
context_windows.items(), key=lambda x: len(x[0]), reverse=True
|
||||
):
|
||||
if self.model.startswith(model_prefix):
|
||||
return int(size * CONTEXT_WINDOW_USAGE_RATIO)
|
||||
|
||||
|
||||
0
lib/crewai/tests/llms/__init__.py
Normal file
0
lib/crewai/tests/llms/__init__.py
Normal file
3
lib/crewai/tests/llms/azure/__init__.py
Normal file
3
lib/crewai/tests/llms/azure/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# Azure LLM tests
|
||||
|
||||
|
||||
1087
lib/crewai/tests/llms/azure/test_azure.py
Normal file
1087
lib/crewai/tests/llms/azure/test_azure.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user