feat: azure native tests

* feat: add Azure AI Inference support and related tests

- Introduced the `azure-ai-inference` package with version `1.0.0b9` and its dependencies in `uv.lock` and `pyproject.toml`.
- Added new test files for Azure LLM functionality, including tests for Azure completion and tool handling.
- Implemented comprehensive test cases to validate Azure-specific behavior and integration with the CrewAI framework.
- Enhanced the testing framework to mock Azure credentials and ensure proper isolation during tests.

* feat: enhance AzureCompletion class with Azure OpenAI support

- Added support for the Azure OpenAI endpoint in the AzureCompletion class, allowing for flexible endpoint configurations.
- Implemented endpoint validation and correction to ensure proper URL formats for Azure OpenAI deployments.
- Enhanced error handling to provide clearer messages for common HTTP errors, including authentication and rate limit issues.
- Updated tests to validate the new endpoint handling and error messaging, ensuring robust integration with Azure AI Inference.
- Refactored parameter preparation to conditionally include the model parameter based on the endpoint type.
This commit is contained in:
Lorenze Jay
2025-10-17 08:36:29 -07:00
committed by GitHub
parent 3b32793e78
commit 6515c7faeb
7 changed files with 1214 additions and 25 deletions

View File

@@ -90,6 +90,9 @@ boto3 = [
google-genai = [ google-genai = [
"google-genai>=1.2.0", "google-genai>=1.2.0",
] ]
azure-ai-inference = [
"azure-ai-inference>=1.0.0b9",
]
[project.scripts] [project.scripts]

View File

@@ -351,7 +351,7 @@ class LLM(BaseLLM):
except ImportError: except ImportError:
return None return None
elif provider == "azure": elif provider == "azure" or provider == "azure_openai":
try: try:
from crewai.llms.providers.azure.completion import AzureCompletion from crewai.llms.providers.azure.completion import AzureCompletion

View File

@@ -10,14 +10,14 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
try: try:
from azure.ai.inference import ChatCompletionsClient # type: ignore from azure.ai.inference import ChatCompletionsClient
from azure.ai.inference.models import ( # type: ignore from azure.ai.inference.models import (
ChatCompletions, ChatCompletions,
ChatCompletionsToolCall, ChatCompletionsToolCall,
StreamingChatCompletionsUpdate, StreamingChatCompletionsUpdate,
) )
from azure.core.credentials import AzureKeyCredential # type: ignore from azure.core.credentials import AzureKeyCredential
from azure.core.exceptions import HttpResponseError # type: ignore from azure.core.exceptions import HttpResponseError
from crewai.events.types.llm_events import LLMCallType from crewai.events.types.llm_events import LLMCallType
from crewai.llms.base_llm import BaseLLM from crewai.llms.base_llm import BaseLLM
@@ -80,7 +80,9 @@ class AzureCompletion(BaseLLM):
or os.getenv("AZURE_OPENAI_ENDPOINT") or os.getenv("AZURE_OPENAI_ENDPOINT")
or os.getenv("AZURE_API_BASE") or os.getenv("AZURE_API_BASE")
) )
self.api_version = api_version or os.getenv("AZURE_API_VERSION") or "2024-02-01" self.api_version = api_version or os.getenv("AZURE_API_VERSION") or "2024-06-01"
self.timeout = timeout
self.max_retries = max_retries
if not self.api_key: if not self.api_key:
raise ValueError( raise ValueError(
@@ -91,10 +93,20 @@ class AzureCompletion(BaseLLM):
"Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter." "Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter."
) )
self.client = ChatCompletionsClient( # Validate and potentially fix Azure OpenAI endpoint URL
endpoint=self.endpoint, self.endpoint = self._validate_and_fix_endpoint(self.endpoint, model)
credential=AzureKeyCredential(self.api_key),
) # Build client kwargs
client_kwargs = {
"endpoint": self.endpoint,
"credential": AzureKeyCredential(self.api_key),
}
# Add api_version if specified (primarily for Azure OpenAI endpoints)
if self.api_version:
client_kwargs["api_version"] = self.api_version
self.client = ChatCompletionsClient(**client_kwargs)
self.top_p = top_p self.top_p = top_p
self.frequency_penalty = frequency_penalty self.frequency_penalty = frequency_penalty
@@ -106,6 +118,34 @@ class AzureCompletion(BaseLLM):
prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"] prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
) )
self.is_azure_openai_endpoint = (
"openai.azure.com" in self.endpoint
and "/openai/deployments/" in self.endpoint
)
def _validate_and_fix_endpoint(self, endpoint: str, model: str) -> str:
"""Validate and fix Azure endpoint URL format.
Azure OpenAI endpoints should be in the format:
https://<resource-name>.openai.azure.com/openai/deployments/<deployment-name>
Args:
endpoint: The endpoint URL
model: The model/deployment name
Returns:
Validated and potentially corrected endpoint URL
"""
if "openai.azure.com" in endpoint and "/openai/deployments/" not in endpoint:
endpoint = endpoint.rstrip("/")
if not endpoint.endswith("/openai/deployments"):
deployment_name = model.replace("azure/", "")
endpoint = f"{endpoint}/openai/deployments/{deployment_name}"
logging.info(f"Constructed Azure OpenAI endpoint URL: {endpoint}")
return endpoint
def call( def call(
self, self,
messages: str | list[dict[str, str]], messages: str | list[dict[str, str]],
@@ -158,7 +198,17 @@ class AzureCompletion(BaseLLM):
) )
except HttpResponseError as e: except HttpResponseError as e:
error_msg = f"Azure API HTTP error: {e.status_code} - {e.message}" if e.status_code == 401:
error_msg = "Azure authentication failed. Check your API key."
elif e.status_code == 404:
error_msg = (
f"Azure endpoint not found. Check endpoint URL: {self.endpoint}"
)
elif e.status_code == 429:
error_msg = "Azure API rate limit exceeded. Please retry later."
else:
error_msg = f"Azure API HTTP error: {e.status_code} - {e.message}"
logging.error(error_msg) logging.error(error_msg)
self._emit_call_failed_event( self._emit_call_failed_event(
error=error_msg, from_task=from_task, from_agent=from_agent error=error_msg, from_task=from_task, from_agent=from_agent
@@ -187,11 +237,15 @@ class AzureCompletion(BaseLLM):
Parameters dictionary for Azure API Parameters dictionary for Azure API
""" """
params = { params = {
"model": self.model,
"messages": messages, "messages": messages,
"stream": self.stream, "stream": self.stream,
} }
# Only include model parameter for non-Azure OpenAI endpoints
# Azure OpenAI endpoints have the deployment name in the URL
if not self.is_azure_openai_endpoint:
params["model"] = self.model
# Add optional parameters if set # Add optional parameters if set
if self.temperature is not None: if self.temperature is not None:
params["temperature"] = self.temperature params["temperature"] = self.temperature
@@ -250,7 +304,7 @@ class AzureCompletion(BaseLLM):
messages: Input messages messages: Input messages
Returns: Returns:
List of dict objects List of dict objects with 'role' and 'content' keys
""" """
# Use base class formatting first # Use base class formatting first
base_formatted = super()._format_messages(messages) base_formatted = super()._format_messages(messages)
@@ -258,18 +312,11 @@ class AzureCompletion(BaseLLM):
azure_messages = [] azure_messages = []
for message in base_formatted: for message in base_formatted:
role = message.get("role") role = message.get("role", "user") # Default to user if no role
content = message.get("content", "") content = message.get("content", "")
if role == "system": # Azure AI Inference requires both 'role' and 'content'
azure_messages.append(dict(content=content)) azure_messages.append({"role": role, "content": content})
elif role == "user":
azure_messages.append(dict(content=content))
elif role == "assistant":
azure_messages.append(dict(content=content))
else:
# Default to user message for unknown roles
azure_messages.append(dict(content=content))
return azure_messages return azure_messages
@@ -339,6 +386,13 @@ class AzureCompletion(BaseLLM):
logging.error(f"Context window exceeded: {e}") logging.error(f"Context window exceeded: {e}")
raise LLMContextLengthExceededError(str(e)) from e raise LLMContextLengthExceededError(str(e)) from e
error_msg = f"Azure API call failed: {e!s}"
logging.error(error_msg)
self._emit_call_failed_event(
error=error_msg, from_task=from_task, from_agent=from_agent
)
raise e
return content return content
def _handle_streaming_completion( def _handle_streaming_completion(
@@ -454,7 +508,9 @@ class AzureCompletion(BaseLLM):
} }
# Find the best match for the model name # Find the best match for the model name
for model_prefix, size in context_windows.items(): for model_prefix, size in sorted(
context_windows.items(), key=lambda x: len(x[0]), reverse=True
):
if self.model.startswith(model_prefix): if self.model.startswith(model_prefix):
return int(size * CONTEXT_WINDOW_USAGE_RATIO) return int(size * CONTEXT_WINDOW_USAGE_RATIO)

View File

View File

@@ -0,0 +1,3 @@
# Azure LLM tests

File diff suppressed because it is too large Load Diff

42
uv.lock generated
View File

@@ -344,6 +344,33 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/f8/aa/5082412d1ee302e9e7d80b6949bc4d2a8fa1149aaab610c5fc24709605d6/authlib-1.6.5-py2.py3-none-any.whl", hash = "sha256:3e0e0507807f842b02175507bdee8957a1d5707fd4afb17c32fb43fee90b6e3a", size = 243608, upload-time = "2025-10-02T13:36:07.637Z" }, { url = "https://files.pythonhosted.org/packages/f8/aa/5082412d1ee302e9e7d80b6949bc4d2a8fa1149aaab610c5fc24709605d6/authlib-1.6.5-py2.py3-none-any.whl", hash = "sha256:3e0e0507807f842b02175507bdee8957a1d5707fd4afb17c32fb43fee90b6e3a", size = 243608, upload-time = "2025-10-02T13:36:07.637Z" },
] ]
[[package]]
name = "azure-ai-inference"
version = "1.0.0b9"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "azure-core" },
{ name = "isodate" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/4e/6a/ed85592e5c64e08c291992f58b1a94dab6869f28fb0f40fd753dced73ba6/azure_ai_inference-1.0.0b9.tar.gz", hash = "sha256:1feb496bd84b01ee2691befc04358fa25d7c344d8288e99364438859ad7cd5a4", size = 182408, upload-time = "2025-02-15T00:37:28.464Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/4f/0f/27520da74769db6e58327d96c98e7b9a07ce686dff582c9a5ec60b03f9dd/azure_ai_inference-1.0.0b9-py3-none-any.whl", hash = "sha256:49823732e674092dad83bb8b0d1b65aa73111fab924d61349eb2a8cdc0493990", size = 124885, upload-time = "2025-02-15T00:37:29.964Z" },
]
[[package]]
name = "azure-core"
version = "1.36.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "requests" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/0a/c4/d4ff3bc3ddf155156460bff340bbe9533f99fac54ddea165f35a8619f162/azure_core-1.36.0.tar.gz", hash = "sha256:22e5605e6d0bf1d229726af56d9e92bc37b6e726b141a18be0b4d424131741b7", size = 351139, upload-time = "2025-10-15T00:33:49.083Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b1/3c/b90d5afc2e47c4a45f4bba00f9c3193b0417fad5ad3bb07869f9d12832aa/azure_core-1.36.0-py3-none-any.whl", hash = "sha256:fee9923a3a753e94a259563429f3644aaf05c486d45b1215d098115102d91d3b", size = 213302, upload-time = "2025-10-15T00:33:51.058Z" },
]
[[package]] [[package]]
name = "backoff" name = "backoff"
version = "2.2.1" version = "2.2.1"
@@ -1042,6 +1069,9 @@ aisuite = [
aws = [ aws = [
{ name = "boto3" }, { name = "boto3" },
] ]
azure-ai-inference = [
{ name = "azure-ai-inference" },
]
boto3 = [ boto3 = [
{ name = "boto3" }, { name = "boto3" },
] ]
@@ -1086,6 +1116,7 @@ watson = [
requires-dist = [ requires-dist = [
{ name = "aisuite", marker = "extra == 'aisuite'", specifier = ">=0.1.10" }, { name = "aisuite", marker = "extra == 'aisuite'", specifier = ">=0.1.10" },
{ name = "appdirs", specifier = ">=1.4.4" }, { name = "appdirs", specifier = ">=1.4.4" },
{ name = "azure-ai-inference", marker = "extra == 'azure-ai-inference'", specifier = ">=1.0.0b9" },
{ name = "boto3", marker = "extra == 'aws'", specifier = ">=1.40.38" }, { name = "boto3", marker = "extra == 'aws'", specifier = ">=1.40.38" },
{ name = "boto3", marker = "extra == 'boto3'", specifier = ">=1.40.45" }, { name = "boto3", marker = "extra == 'boto3'", specifier = ">=1.40.45" },
{ name = "chromadb", specifier = "~=1.1.0" }, { name = "chromadb", specifier = "~=1.1.0" },
@@ -1124,7 +1155,7 @@ requires-dist = [
{ name = "uv", specifier = ">=0.4.25" }, { name = "uv", specifier = ">=0.4.25" },
{ name = "voyageai", marker = "extra == 'voyageai'", specifier = ">=0.3.5" }, { name = "voyageai", marker = "extra == 'voyageai'", specifier = ">=0.3.5" },
] ]
provides-extras = ["aisuite", "aws", "boto3", "docling", "embeddings", "google-genai", "litellm", "mem0", "openpyxl", "pandas", "pdfplumber", "qdrant", "tools", "voyageai", "watson"] provides-extras = ["aisuite", "aws", "azure-ai-inference", "boto3", "docling", "embeddings", "google-genai", "litellm", "mem0", "openpyxl", "pandas", "pdfplumber", "qdrant", "tools", "voyageai", "watson"]
[[package]] [[package]]
name = "crewai-devtools" name = "crewai-devtools"
@@ -2912,6 +2943,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/d9/33/1f075bf72b0b747cb3288d011319aaf64083cf2efef8354174e3ed4540e2/ipython_pygments_lexers-1.1.1-py3-none-any.whl", hash = "sha256:a9462224a505ade19a605f71f8fa63c2048833ce50abc86768a0d81d876dc81c", size = 8074, upload-time = "2025-01-17T11:24:33.271Z" }, { url = "https://files.pythonhosted.org/packages/d9/33/1f075bf72b0b747cb3288d011319aaf64083cf2efef8354174e3ed4540e2/ipython_pygments_lexers-1.1.1-py3-none-any.whl", hash = "sha256:a9462224a505ade19a605f71f8fa63c2048833ce50abc86768a0d81d876dc81c", size = 8074, upload-time = "2025-01-17T11:24:33.271Z" },
] ]
[[package]]
name = "isodate"
version = "0.7.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/54/4d/e940025e2ce31a8ce1202635910747e5a87cc3a6a6bb2d00973375014749/isodate-0.7.2.tar.gz", hash = "sha256:4cd1aa0f43ca76f4a6c6c0292a85f40b35ec2e43e315b59f06e6d32171a953e6", size = 29705, upload-time = "2024-10-08T23:04:11.5Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/15/aa/0aca39a37d3c7eb941ba736ede56d689e7be91cab5d9ca846bde3999eba6/isodate-0.7.2-py3-none-any.whl", hash = "sha256:28009937d8031054830160fce6d409ed342816b543597cece116d966c6d99e15", size = 22320, upload-time = "2024-10-08T23:04:09.501Z" },
]
[[package]] [[package]]
name = "jedi" name = "jedi"
version = "0.19.2" version = "0.19.2"