mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 23:58:34 +00:00
feat: azure native tests
* feat: add Azure AI Inference support and related tests - Introduced the `azure-ai-inference` package with version `1.0.0b9` and its dependencies in `uv.lock` and `pyproject.toml`. - Added new test files for Azure LLM functionality, including tests for Azure completion and tool handling. - Implemented comprehensive test cases to validate Azure-specific behavior and integration with the CrewAI framework. - Enhanced the testing framework to mock Azure credentials and ensure proper isolation during tests. * feat: enhance AzureCompletion class with Azure OpenAI support - Added support for the Azure OpenAI endpoint in the AzureCompletion class, allowing for flexible endpoint configurations. - Implemented endpoint validation and correction to ensure proper URL formats for Azure OpenAI deployments. - Enhanced error handling to provide clearer messages for common HTTP errors, including authentication and rate limit issues. - Updated tests to validate the new endpoint handling and error messaging, ensuring robust integration with Azure AI Inference. - Refactored parameter preparation to conditionally include the model parameter based on the endpoint type.
This commit is contained in:
@@ -90,6 +90,9 @@ boto3 = [
|
||||
google-genai = [
|
||||
"google-genai>=1.2.0",
|
||||
]
|
||||
azure-ai-inference = [
|
||||
"azure-ai-inference>=1.0.0b9",
|
||||
]
|
||||
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -351,7 +351,7 @@ class LLM(BaseLLM):
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
elif provider == "azure":
|
||||
elif provider == "azure" or provider == "azure_openai":
|
||||
try:
|
||||
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||
|
||||
|
||||
@@ -10,14 +10,14 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
|
||||
|
||||
try:
|
||||
from azure.ai.inference import ChatCompletionsClient # type: ignore
|
||||
from azure.ai.inference.models import ( # type: ignore
|
||||
from azure.ai.inference import ChatCompletionsClient
|
||||
from azure.ai.inference.models import (
|
||||
ChatCompletions,
|
||||
ChatCompletionsToolCall,
|
||||
StreamingChatCompletionsUpdate,
|
||||
)
|
||||
from azure.core.credentials import AzureKeyCredential # type: ignore
|
||||
from azure.core.exceptions import HttpResponseError # type: ignore
|
||||
from azure.core.credentials import AzureKeyCredential
|
||||
from azure.core.exceptions import HttpResponseError
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
|
||||
@@ -80,7 +80,9 @@ class AzureCompletion(BaseLLM):
|
||||
or os.getenv("AZURE_OPENAI_ENDPOINT")
|
||||
or os.getenv("AZURE_API_BASE")
|
||||
)
|
||||
self.api_version = api_version or os.getenv("AZURE_API_VERSION") or "2024-02-01"
|
||||
self.api_version = api_version or os.getenv("AZURE_API_VERSION") or "2024-06-01"
|
||||
self.timeout = timeout
|
||||
self.max_retries = max_retries
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError(
|
||||
@@ -91,10 +93,20 @@ class AzureCompletion(BaseLLM):
|
||||
"Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter."
|
||||
)
|
||||
|
||||
self.client = ChatCompletionsClient(
|
||||
endpoint=self.endpoint,
|
||||
credential=AzureKeyCredential(self.api_key),
|
||||
)
|
||||
# Validate and potentially fix Azure OpenAI endpoint URL
|
||||
self.endpoint = self._validate_and_fix_endpoint(self.endpoint, model)
|
||||
|
||||
# Build client kwargs
|
||||
client_kwargs = {
|
||||
"endpoint": self.endpoint,
|
||||
"credential": AzureKeyCredential(self.api_key),
|
||||
}
|
||||
|
||||
# Add api_version if specified (primarily for Azure OpenAI endpoints)
|
||||
if self.api_version:
|
||||
client_kwargs["api_version"] = self.api_version
|
||||
|
||||
self.client = ChatCompletionsClient(**client_kwargs)
|
||||
|
||||
self.top_p = top_p
|
||||
self.frequency_penalty = frequency_penalty
|
||||
@@ -106,6 +118,34 @@ class AzureCompletion(BaseLLM):
|
||||
prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
|
||||
)
|
||||
|
||||
self.is_azure_openai_endpoint = (
|
||||
"openai.azure.com" in self.endpoint
|
||||
and "/openai/deployments/" in self.endpoint
|
||||
)
|
||||
|
||||
def _validate_and_fix_endpoint(self, endpoint: str, model: str) -> str:
|
||||
"""Validate and fix Azure endpoint URL format.
|
||||
|
||||
Azure OpenAI endpoints should be in the format:
|
||||
https://<resource-name>.openai.azure.com/openai/deployments/<deployment-name>
|
||||
|
||||
Args:
|
||||
endpoint: The endpoint URL
|
||||
model: The model/deployment name
|
||||
|
||||
Returns:
|
||||
Validated and potentially corrected endpoint URL
|
||||
"""
|
||||
if "openai.azure.com" in endpoint and "/openai/deployments/" not in endpoint:
|
||||
endpoint = endpoint.rstrip("/")
|
||||
|
||||
if not endpoint.endswith("/openai/deployments"):
|
||||
deployment_name = model.replace("azure/", "")
|
||||
endpoint = f"{endpoint}/openai/deployments/{deployment_name}"
|
||||
logging.info(f"Constructed Azure OpenAI endpoint URL: {endpoint}")
|
||||
|
||||
return endpoint
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: str | list[dict[str, str]],
|
||||
@@ -158,7 +198,17 @@ class AzureCompletion(BaseLLM):
|
||||
)
|
||||
|
||||
except HttpResponseError as e:
|
||||
error_msg = f"Azure API HTTP error: {e.status_code} - {e.message}"
|
||||
if e.status_code == 401:
|
||||
error_msg = "Azure authentication failed. Check your API key."
|
||||
elif e.status_code == 404:
|
||||
error_msg = (
|
||||
f"Azure endpoint not found. Check endpoint URL: {self.endpoint}"
|
||||
)
|
||||
elif e.status_code == 429:
|
||||
error_msg = "Azure API rate limit exceeded. Please retry later."
|
||||
else:
|
||||
error_msg = f"Azure API HTTP error: {e.status_code} - {e.message}"
|
||||
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
@@ -187,11 +237,15 @@ class AzureCompletion(BaseLLM):
|
||||
Parameters dictionary for Azure API
|
||||
"""
|
||||
params = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"stream": self.stream,
|
||||
}
|
||||
|
||||
# Only include model parameter for non-Azure OpenAI endpoints
|
||||
# Azure OpenAI endpoints have the deployment name in the URL
|
||||
if not self.is_azure_openai_endpoint:
|
||||
params["model"] = self.model
|
||||
|
||||
# Add optional parameters if set
|
||||
if self.temperature is not None:
|
||||
params["temperature"] = self.temperature
|
||||
@@ -250,7 +304,7 @@ class AzureCompletion(BaseLLM):
|
||||
messages: Input messages
|
||||
|
||||
Returns:
|
||||
List of dict objects
|
||||
List of dict objects with 'role' and 'content' keys
|
||||
"""
|
||||
# Use base class formatting first
|
||||
base_formatted = super()._format_messages(messages)
|
||||
@@ -258,18 +312,11 @@ class AzureCompletion(BaseLLM):
|
||||
azure_messages = []
|
||||
|
||||
for message in base_formatted:
|
||||
role = message.get("role")
|
||||
role = message.get("role", "user") # Default to user if no role
|
||||
content = message.get("content", "")
|
||||
|
||||
if role == "system":
|
||||
azure_messages.append(dict(content=content))
|
||||
elif role == "user":
|
||||
azure_messages.append(dict(content=content))
|
||||
elif role == "assistant":
|
||||
azure_messages.append(dict(content=content))
|
||||
else:
|
||||
# Default to user message for unknown roles
|
||||
azure_messages.append(dict(content=content))
|
||||
# Azure AI Inference requires both 'role' and 'content'
|
||||
azure_messages.append({"role": role, "content": content})
|
||||
|
||||
return azure_messages
|
||||
|
||||
@@ -339,6 +386,13 @@ class AzureCompletion(BaseLLM):
|
||||
logging.error(f"Context window exceeded: {e}")
|
||||
raise LLMContextLengthExceededError(str(e)) from e
|
||||
|
||||
error_msg = f"Azure API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise e
|
||||
|
||||
return content
|
||||
|
||||
def _handle_streaming_completion(
|
||||
@@ -454,7 +508,9 @@ class AzureCompletion(BaseLLM):
|
||||
}
|
||||
|
||||
# Find the best match for the model name
|
||||
for model_prefix, size in context_windows.items():
|
||||
for model_prefix, size in sorted(
|
||||
context_windows.items(), key=lambda x: len(x[0]), reverse=True
|
||||
):
|
||||
if self.model.startswith(model_prefix):
|
||||
return int(size * CONTEXT_WINDOW_USAGE_RATIO)
|
||||
|
||||
|
||||
0
lib/crewai/tests/llms/__init__.py
Normal file
0
lib/crewai/tests/llms/__init__.py
Normal file
3
lib/crewai/tests/llms/azure/__init__.py
Normal file
3
lib/crewai/tests/llms/azure/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# Azure LLM tests
|
||||
|
||||
|
||||
1087
lib/crewai/tests/llms/azure/test_azure.py
Normal file
1087
lib/crewai/tests/llms/azure/test_azure.py
Normal file
File diff suppressed because it is too large
Load Diff
42
uv.lock
generated
42
uv.lock
generated
@@ -344,6 +344,33 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f8/aa/5082412d1ee302e9e7d80b6949bc4d2a8fa1149aaab610c5fc24709605d6/authlib-1.6.5-py2.py3-none-any.whl", hash = "sha256:3e0e0507807f842b02175507bdee8957a1d5707fd4afb17c32fb43fee90b6e3a", size = 243608, upload-time = "2025-10-02T13:36:07.637Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "azure-ai-inference"
|
||||
version = "1.0.0b9"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "azure-core" },
|
||||
{ name = "isodate" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/4e/6a/ed85592e5c64e08c291992f58b1a94dab6869f28fb0f40fd753dced73ba6/azure_ai_inference-1.0.0b9.tar.gz", hash = "sha256:1feb496bd84b01ee2691befc04358fa25d7c344d8288e99364438859ad7cd5a4", size = 182408, upload-time = "2025-02-15T00:37:28.464Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/4f/0f/27520da74769db6e58327d96c98e7b9a07ce686dff582c9a5ec60b03f9dd/azure_ai_inference-1.0.0b9-py3-none-any.whl", hash = "sha256:49823732e674092dad83bb8b0d1b65aa73111fab924d61349eb2a8cdc0493990", size = 124885, upload-time = "2025-02-15T00:37:29.964Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "azure-core"
|
||||
version = "1.36.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "requests" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/0a/c4/d4ff3bc3ddf155156460bff340bbe9533f99fac54ddea165f35a8619f162/azure_core-1.36.0.tar.gz", hash = "sha256:22e5605e6d0bf1d229726af56d9e92bc37b6e726b141a18be0b4d424131741b7", size = 351139, upload-time = "2025-10-15T00:33:49.083Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b1/3c/b90d5afc2e47c4a45f4bba00f9c3193b0417fad5ad3bb07869f9d12832aa/azure_core-1.36.0-py3-none-any.whl", hash = "sha256:fee9923a3a753e94a259563429f3644aaf05c486d45b1215d098115102d91d3b", size = 213302, upload-time = "2025-10-15T00:33:51.058Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "backoff"
|
||||
version = "2.2.1"
|
||||
@@ -1042,6 +1069,9 @@ aisuite = [
|
||||
aws = [
|
||||
{ name = "boto3" },
|
||||
]
|
||||
azure-ai-inference = [
|
||||
{ name = "azure-ai-inference" },
|
||||
]
|
||||
boto3 = [
|
||||
{ name = "boto3" },
|
||||
]
|
||||
@@ -1086,6 +1116,7 @@ watson = [
|
||||
requires-dist = [
|
||||
{ name = "aisuite", marker = "extra == 'aisuite'", specifier = ">=0.1.10" },
|
||||
{ name = "appdirs", specifier = ">=1.4.4" },
|
||||
{ name = "azure-ai-inference", marker = "extra == 'azure-ai-inference'", specifier = ">=1.0.0b9" },
|
||||
{ name = "boto3", marker = "extra == 'aws'", specifier = ">=1.40.38" },
|
||||
{ name = "boto3", marker = "extra == 'boto3'", specifier = ">=1.40.45" },
|
||||
{ name = "chromadb", specifier = "~=1.1.0" },
|
||||
@@ -1124,7 +1155,7 @@ requires-dist = [
|
||||
{ name = "uv", specifier = ">=0.4.25" },
|
||||
{ name = "voyageai", marker = "extra == 'voyageai'", specifier = ">=0.3.5" },
|
||||
]
|
||||
provides-extras = ["aisuite", "aws", "boto3", "docling", "embeddings", "google-genai", "litellm", "mem0", "openpyxl", "pandas", "pdfplumber", "qdrant", "tools", "voyageai", "watson"]
|
||||
provides-extras = ["aisuite", "aws", "azure-ai-inference", "boto3", "docling", "embeddings", "google-genai", "litellm", "mem0", "openpyxl", "pandas", "pdfplumber", "qdrant", "tools", "voyageai", "watson"]
|
||||
|
||||
[[package]]
|
||||
name = "crewai-devtools"
|
||||
@@ -2912,6 +2943,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/d9/33/1f075bf72b0b747cb3288d011319aaf64083cf2efef8354174e3ed4540e2/ipython_pygments_lexers-1.1.1-py3-none-any.whl", hash = "sha256:a9462224a505ade19a605f71f8fa63c2048833ce50abc86768a0d81d876dc81c", size = 8074, upload-time = "2025-01-17T11:24:33.271Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "isodate"
|
||||
version = "0.7.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/54/4d/e940025e2ce31a8ce1202635910747e5a87cc3a6a6bb2d00973375014749/isodate-0.7.2.tar.gz", hash = "sha256:4cd1aa0f43ca76f4a6c6c0292a85f40b35ec2e43e315b59f06e6d32171a953e6", size = 29705, upload-time = "2024-10-08T23:04:11.5Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/15/aa/0aca39a37d3c7eb941ba736ede56d689e7be91cab5d9ca846bde3999eba6/isodate-0.7.2-py3-none-any.whl", hash = "sha256:28009937d8031054830160fce6d409ed342816b543597cece116d966c6d99e15", size = 22320, upload-time = "2024-10-08T23:04:09.501Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jedi"
|
||||
version = "0.19.2"
|
||||
|
||||
Reference in New Issue
Block a user