Enhance Azure OpenAI embedding validation with comprehensive parameter checks

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-03-10 23:56:06 +00:00
parent e5ca5fb1dd
commit ffdc9a1aa3
2 changed files with 119 additions and 33 deletions

View File

@@ -74,7 +74,7 @@ class EmbeddingConfigurator:
) )
@staticmethod @staticmethod
def _configure_azure(config, model_name): def _configure_azure(config: Dict[str, Any], model_name: str) -> EmbeddingFunction:
""" """
Configure an Azure OpenAI embedding function. Configure an Azure OpenAI embedding function.
@@ -91,18 +91,30 @@ class EmbeddingConfigurator:
An OpenAIEmbeddingFunction configured for Azure OpenAI An OpenAIEmbeddingFunction configured for Azure OpenAI
Raises: Raises:
ValueError: If required parameters are missing ValueError: If required parameters are missing or invalid
""" """
from chromadb.utils.embedding_functions.openai_embedding_function import ( from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction, OpenAIEmbeddingFunction,
) )
# Check if deployment_id is provided for Azure OpenAI # Check required parameters for Azure OpenAI
deployment_id = config.get("deployment_id") required_params = {
if not deployment_id: "api_key": "API key",
"api_base": "API base URL",
"api_version": "API version",
"deployment_id": "deployment ID"
}
missing_params = []
for param, description in required_params.items():
if not config.get(param):
missing_params.append(f"{description} ({param})")
if missing_params:
params_str = ", ".join(missing_params)
raise ValueError( raise ValueError(
"Missing required parameter 'deployment_id' for Azure OpenAI embeddings. " f"Missing required parameters for Azure OpenAI embeddings: {params_str}. "
"Please provide a deployment_id in your Azure embedder configuration." f"Ensure these parameters match your Azure OpenAI embedding model configuration."
) )
return OpenAIEmbeddingFunction( return OpenAIEmbeddingFunction(
@@ -113,7 +125,7 @@ class EmbeddingConfigurator:
model_name=model_name, model_name=model_name,
default_headers=config.get("default_headers"), default_headers=config.get("default_headers"),
dimensions=config.get("dimensions"), dimensions=config.get("dimensions"),
deployment_id=deployment_id, deployment_id=config.get("deployment_id"),
organization_id=config.get("organization_id"), organization_id=config.get("organization_id"),
) )

View File

@@ -4,51 +4,125 @@ import pytest
from crewai.utilities.embedding_configurator import EmbeddingConfigurator from crewai.utilities.embedding_configurator import EmbeddingConfigurator
# Test constants for Azure OpenAI configurations
AZURE_BASE_CONFIG = {
"provider": "azure",
"config": {
"model": "text-embedding-ada-002",
"api_key": "test-key",
"api_base": "https://test.openai.azure.com",
"api_version": "2023-05-15",
"api_type": "azure",
}
}
AZURE_COMPLETE_CONFIG = {
"provider": "azure",
"config": {
"model": "text-embedding-ada-002",
"api_key": "test-key",
"api_base": "https://test.openai.azure.com",
"api_version": "2023-05-15",
"api_type": "azure",
"deployment_id": "text-embedding-ada-002",
}
}
def test_azure_embedder_missing_deployment_id(): def test_azure_embedder_missing_deployment_id():
"""Test that Azure embedder raises an error when deployment_id is missing""" """Test that Azure embedder raises an error when deployment_id is missing"""
embedder_config = { embedder_config = AZURE_BASE_CONFIG.copy()
"provider": "azure",
"config": {
"model": "text-embedding-ada-002",
"api_key": "test-key",
"api_base": "https://test.openai.azure.com",
"api_version": "2023-05-15",
"api_type": "azure",
}
}
configurator = EmbeddingConfigurator() configurator = EmbeddingConfigurator()
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:
configurator.configure_embedder(embedder_config) configurator.configure_embedder(embedder_config)
assert "Missing required parameter 'deployment_id'" in str(excinfo.value) assert "Missing required parameters" in str(excinfo.value)
assert "deployment ID (deployment_id)" in str(excinfo.value)
def test_azure_embedder_missing_api_key():
"""Test that Azure embedder raises an error when api_key is missing"""
embedder_config = AZURE_BASE_CONFIG.copy()
embedder_config["config"] = embedder_config["config"].copy()
embedder_config["config"]["deployment_id"] = "text-embedding-ada-002"
embedder_config["config"].pop("api_key")
configurator = EmbeddingConfigurator()
with pytest.raises(ValueError) as excinfo:
configurator.configure_embedder(embedder_config)
assert "Missing required parameters" in str(excinfo.value)
assert "API key (api_key)" in str(excinfo.value)
def test_azure_embedder_missing_api_base():
"""Test that Azure embedder raises an error when api_base is missing"""
embedder_config = AZURE_BASE_CONFIG.copy()
embedder_config["config"] = embedder_config["config"].copy()
embedder_config["config"]["deployment_id"] = "text-embedding-ada-002"
embedder_config["config"].pop("api_base")
configurator = EmbeddingConfigurator()
with pytest.raises(ValueError) as excinfo:
configurator.configure_embedder(embedder_config)
assert "Missing required parameters" in str(excinfo.value)
assert "API base URL (api_base)" in str(excinfo.value)
def test_azure_embedder_missing_api_version():
"""Test that Azure embedder raises an error when api_version is missing"""
embedder_config = AZURE_BASE_CONFIG.copy()
embedder_config["config"] = embedder_config["config"].copy()
embedder_config["config"]["deployment_id"] = "text-embedding-ada-002"
embedder_config["config"].pop("api_version")
configurator = EmbeddingConfigurator()
with pytest.raises(ValueError) as excinfo:
configurator.configure_embedder(embedder_config)
assert "Missing required parameters" in str(excinfo.value)
assert "API version (api_version)" in str(excinfo.value)
def test_azure_embedder_empty_parameters():
"""Test that Azure embedder raises an error when parameters are empty strings"""
embedder_config = AZURE_BASE_CONFIG.copy()
embedder_config["config"] = embedder_config["config"].copy()
embedder_config["config"]["deployment_id"] = ""
embedder_config["config"]["api_key"] = ""
configurator = EmbeddingConfigurator()
with pytest.raises(ValueError) as excinfo:
configurator.configure_embedder(embedder_config)
assert "Missing required parameters" in str(excinfo.value)
assert "API key (api_key)" in str(excinfo.value)
assert "deployment ID (deployment_id)" in str(excinfo.value)
@patch("chromadb.utils.embedding_functions.openai_embedding_function.OpenAIEmbeddingFunction") @patch("chromadb.utils.embedding_functions.openai_embedding_function.OpenAIEmbeddingFunction")
def test_azure_embedder_with_deployment_id(mock_openai_embedding): def test_azure_embedder_with_all_required_parameters(mock_openai_embedding):
"""Test that Azure embedder works when deployment_id is provided""" """Test that Azure embedder works when all required parameters are provided"""
mock_instance = MagicMock() mock_instance = MagicMock()
mock_openai_embedding.return_value = mock_instance mock_openai_embedding.return_value = mock_instance
embedder_config = { embedder_config = AZURE_COMPLETE_CONFIG.copy()
"provider": "azure",
"config": {
"model": "text-embedding-ada-002",
"api_key": "test-key",
"api_base": "https://test.openai.azure.com",
"api_version": "2023-05-15",
"api_type": "azure",
"deployment_id": "text-embedding-ada-002",
}
}
configurator = EmbeddingConfigurator() configurator = EmbeddingConfigurator()
result = configurator.configure_embedder(embedder_config) result = configurator.configure_embedder(embedder_config)
assert result == mock_instance assert result == mock_instance
mock_openai_embedding.assert_called_once() mock_openai_embedding.assert_called_once()
# Verify deployment_id was passed correctly # Verify parameters were passed correctly
call_kwargs = mock_openai_embedding.call_args.kwargs call_kwargs = mock_openai_embedding.call_args.kwargs
assert call_kwargs["api_key"] == "test-key"
assert call_kwargs["api_base"] == "https://test.openai.azure.com"
assert call_kwargs["api_version"] == "2023-05-15"
assert call_kwargs["deployment_id"] == "text-embedding-ada-002" assert call_kwargs["deployment_id"] == "text-embedding-ada-002"