mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
Enhance Azure OpenAI embedding validation with comprehensive parameter checks
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -74,7 +74,7 @@ class EmbeddingConfigurator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_azure(config, model_name):
|
def _configure_azure(config: Dict[str, Any], model_name: str) -> EmbeddingFunction:
|
||||||
"""
|
"""
|
||||||
Configure an Azure OpenAI embedding function.
|
Configure an Azure OpenAI embedding function.
|
||||||
|
|
||||||
@@ -91,18 +91,30 @@ class EmbeddingConfigurator:
|
|||||||
An OpenAIEmbeddingFunction configured for Azure OpenAI
|
An OpenAIEmbeddingFunction configured for Azure OpenAI
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If required parameters are missing
|
ValueError: If required parameters are missing or invalid
|
||||||
"""
|
"""
|
||||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||||
OpenAIEmbeddingFunction,
|
OpenAIEmbeddingFunction,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if deployment_id is provided for Azure OpenAI
|
# Check required parameters for Azure OpenAI
|
||||||
deployment_id = config.get("deployment_id")
|
required_params = {
|
||||||
if not deployment_id:
|
"api_key": "API key",
|
||||||
|
"api_base": "API base URL",
|
||||||
|
"api_version": "API version",
|
||||||
|
"deployment_id": "deployment ID"
|
||||||
|
}
|
||||||
|
|
||||||
|
missing_params = []
|
||||||
|
for param, description in required_params.items():
|
||||||
|
if not config.get(param):
|
||||||
|
missing_params.append(f"{description} ({param})")
|
||||||
|
|
||||||
|
if missing_params:
|
||||||
|
params_str = ", ".join(missing_params)
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Missing required parameter 'deployment_id' for Azure OpenAI embeddings. "
|
f"Missing required parameters for Azure OpenAI embeddings: {params_str}. "
|
||||||
"Please provide a deployment_id in your Azure embedder configuration."
|
f"Ensure these parameters match your Azure OpenAI embedding model configuration."
|
||||||
)
|
)
|
||||||
|
|
||||||
return OpenAIEmbeddingFunction(
|
return OpenAIEmbeddingFunction(
|
||||||
@@ -113,7 +125,7 @@ class EmbeddingConfigurator:
|
|||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
default_headers=config.get("default_headers"),
|
default_headers=config.get("default_headers"),
|
||||||
dimensions=config.get("dimensions"),
|
dimensions=config.get("dimensions"),
|
||||||
deployment_id=deployment_id,
|
deployment_id=config.get("deployment_id"),
|
||||||
organization_id=config.get("organization_id"),
|
organization_id=config.get("organization_id"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -4,51 +4,125 @@ import pytest
|
|||||||
|
|
||||||
from crewai.utilities.embedding_configurator import EmbeddingConfigurator
|
from crewai.utilities.embedding_configurator import EmbeddingConfigurator
|
||||||
|
|
||||||
|
# Test constants for Azure OpenAI configurations
|
||||||
|
AZURE_BASE_CONFIG = {
|
||||||
|
"provider": "azure",
|
||||||
|
"config": {
|
||||||
|
"model": "text-embedding-ada-002",
|
||||||
|
"api_key": "test-key",
|
||||||
|
"api_base": "https://test.openai.azure.com",
|
||||||
|
"api_version": "2023-05-15",
|
||||||
|
"api_type": "azure",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
AZURE_COMPLETE_CONFIG = {
|
||||||
|
"provider": "azure",
|
||||||
|
"config": {
|
||||||
|
"model": "text-embedding-ada-002",
|
||||||
|
"api_key": "test-key",
|
||||||
|
"api_base": "https://test.openai.azure.com",
|
||||||
|
"api_version": "2023-05-15",
|
||||||
|
"api_type": "azure",
|
||||||
|
"deployment_id": "text-embedding-ada-002",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_azure_embedder_missing_deployment_id():
|
def test_azure_embedder_missing_deployment_id():
|
||||||
"""Test that Azure embedder raises an error when deployment_id is missing"""
|
"""Test that Azure embedder raises an error when deployment_id is missing"""
|
||||||
embedder_config = {
|
embedder_config = AZURE_BASE_CONFIG.copy()
|
||||||
"provider": "azure",
|
|
||||||
"config": {
|
|
||||||
"model": "text-embedding-ada-002",
|
|
||||||
"api_key": "test-key",
|
|
||||||
"api_base": "https://test.openai.azure.com",
|
|
||||||
"api_version": "2023-05-15",
|
|
||||||
"api_type": "azure",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
configurator = EmbeddingConfigurator()
|
configurator = EmbeddingConfigurator()
|
||||||
|
|
||||||
with pytest.raises(ValueError) as excinfo:
|
with pytest.raises(ValueError) as excinfo:
|
||||||
configurator.configure_embedder(embedder_config)
|
configurator.configure_embedder(embedder_config)
|
||||||
|
|
||||||
assert "Missing required parameter 'deployment_id'" in str(excinfo.value)
|
assert "Missing required parameters" in str(excinfo.value)
|
||||||
|
assert "deployment ID (deployment_id)" in str(excinfo.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_azure_embedder_missing_api_key():
|
||||||
|
"""Test that Azure embedder raises an error when api_key is missing"""
|
||||||
|
embedder_config = AZURE_BASE_CONFIG.copy()
|
||||||
|
embedder_config["config"] = embedder_config["config"].copy()
|
||||||
|
embedder_config["config"]["deployment_id"] = "text-embedding-ada-002"
|
||||||
|
embedder_config["config"].pop("api_key")
|
||||||
|
|
||||||
|
configurator = EmbeddingConfigurator()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as excinfo:
|
||||||
|
configurator.configure_embedder(embedder_config)
|
||||||
|
|
||||||
|
assert "Missing required parameters" in str(excinfo.value)
|
||||||
|
assert "API key (api_key)" in str(excinfo.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_azure_embedder_missing_api_base():
|
||||||
|
"""Test that Azure embedder raises an error when api_base is missing"""
|
||||||
|
embedder_config = AZURE_BASE_CONFIG.copy()
|
||||||
|
embedder_config["config"] = embedder_config["config"].copy()
|
||||||
|
embedder_config["config"]["deployment_id"] = "text-embedding-ada-002"
|
||||||
|
embedder_config["config"].pop("api_base")
|
||||||
|
|
||||||
|
configurator = EmbeddingConfigurator()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as excinfo:
|
||||||
|
configurator.configure_embedder(embedder_config)
|
||||||
|
|
||||||
|
assert "Missing required parameters" in str(excinfo.value)
|
||||||
|
assert "API base URL (api_base)" in str(excinfo.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_azure_embedder_missing_api_version():
|
||||||
|
"""Test that Azure embedder raises an error when api_version is missing"""
|
||||||
|
embedder_config = AZURE_BASE_CONFIG.copy()
|
||||||
|
embedder_config["config"] = embedder_config["config"].copy()
|
||||||
|
embedder_config["config"]["deployment_id"] = "text-embedding-ada-002"
|
||||||
|
embedder_config["config"].pop("api_version")
|
||||||
|
|
||||||
|
configurator = EmbeddingConfigurator()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as excinfo:
|
||||||
|
configurator.configure_embedder(embedder_config)
|
||||||
|
|
||||||
|
assert "Missing required parameters" in str(excinfo.value)
|
||||||
|
assert "API version (api_version)" in str(excinfo.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_azure_embedder_empty_parameters():
|
||||||
|
"""Test that Azure embedder raises an error when parameters are empty strings"""
|
||||||
|
embedder_config = AZURE_BASE_CONFIG.copy()
|
||||||
|
embedder_config["config"] = embedder_config["config"].copy()
|
||||||
|
embedder_config["config"]["deployment_id"] = ""
|
||||||
|
embedder_config["config"]["api_key"] = ""
|
||||||
|
|
||||||
|
configurator = EmbeddingConfigurator()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as excinfo:
|
||||||
|
configurator.configure_embedder(embedder_config)
|
||||||
|
|
||||||
|
assert "Missing required parameters" in str(excinfo.value)
|
||||||
|
assert "API key (api_key)" in str(excinfo.value)
|
||||||
|
assert "deployment ID (deployment_id)" in str(excinfo.value)
|
||||||
|
|
||||||
|
|
||||||
@patch("chromadb.utils.embedding_functions.openai_embedding_function.OpenAIEmbeddingFunction")
|
@patch("chromadb.utils.embedding_functions.openai_embedding_function.OpenAIEmbeddingFunction")
|
||||||
def test_azure_embedder_with_deployment_id(mock_openai_embedding):
|
def test_azure_embedder_with_all_required_parameters(mock_openai_embedding):
|
||||||
"""Test that Azure embedder works when deployment_id is provided"""
|
"""Test that Azure embedder works when all required parameters are provided"""
|
||||||
mock_instance = MagicMock()
|
mock_instance = MagicMock()
|
||||||
mock_openai_embedding.return_value = mock_instance
|
mock_openai_embedding.return_value = mock_instance
|
||||||
|
|
||||||
embedder_config = {
|
embedder_config = AZURE_COMPLETE_CONFIG.copy()
|
||||||
"provider": "azure",
|
|
||||||
"config": {
|
|
||||||
"model": "text-embedding-ada-002",
|
|
||||||
"api_key": "test-key",
|
|
||||||
"api_base": "https://test.openai.azure.com",
|
|
||||||
"api_version": "2023-05-15",
|
|
||||||
"api_type": "azure",
|
|
||||||
"deployment_id": "text-embedding-ada-002",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
configurator = EmbeddingConfigurator()
|
configurator = EmbeddingConfigurator()
|
||||||
result = configurator.configure_embedder(embedder_config)
|
result = configurator.configure_embedder(embedder_config)
|
||||||
|
|
||||||
assert result == mock_instance
|
assert result == mock_instance
|
||||||
mock_openai_embedding.assert_called_once()
|
mock_openai_embedding.assert_called_once()
|
||||||
# Verify deployment_id was passed correctly
|
# Verify parameters were passed correctly
|
||||||
call_kwargs = mock_openai_embedding.call_args.kwargs
|
call_kwargs = mock_openai_embedding.call_args.kwargs
|
||||||
|
assert call_kwargs["api_key"] == "test-key"
|
||||||
|
assert call_kwargs["api_base"] == "https://test.openai.azure.com"
|
||||||
|
assert call_kwargs["api_version"] == "2023-05-15"
|
||||||
assert call_kwargs["deployment_id"] == "text-embedding-ada-002"
|
assert call_kwargs["deployment_id"] == "text-embedding-ada-002"
|
||||||
|
|||||||
Reference in New Issue
Block a user