refactor: improve Ollama embedder with validation and tests

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-02-09 05:42:43 +00:00
parent 26b62231db
commit 015ce7f550
2 changed files with 125 additions and 57 deletions

View File

@@ -2,55 +2,88 @@ import pytest
from unittest.mock import patch
from crewai.utilities.embedding_configurator import EmbeddingConfigurator
@pytest.mark.parametrize(
    "test_case",
    [
        pytest.param(
            {
                "config": {"provider": "ollama", "config": {"model": "test-model"}},
                "expected_url": "http://localhost:11434/api/embeddings",
            },
            id="default_url",
        ),
        pytest.param(
            {
                "config": {
                    "provider": "ollama",
                    "config": {"model": "test-model", "url": "http://custom:11434"},
                },
                "expected_url": "http://custom:11434",
            },
            id="legacy_url",
        ),
        pytest.param(
            {
                "config": {
                    "provider": "ollama",
                    "config": {"model": "test-model", "api_url": "http://api:11434"},
                },
                "expected_url": "http://api:11434",
            },
            id="api_url",
        ),
        pytest.param(
            {
                "config": {
                    "provider": "ollama",
                    "config": {"model": "test-model", "base_url": "http://base:11434"},
                },
                "expected_url": "http://base:11434",
            },
            id="base_url",
        ),
        pytest.param(
            {
                "config": {
                    "provider": "ollama",
                    "config": {"model": "test-model", "api_base": "http://base-api:11434"},
                },
                "expected_url": "http://base-api:11434",
            },
            id="api_base",
        ),
        pytest.param(
            {
                "config": {
                    "provider": "ollama",
                    "config": {
                        "model": "test-model",
                        "url": "http://url:11434",
                        "api_url": "http://api:11434",
                        "base_url": "http://base:11434",
                        "api_base": "http://base-api:11434",
                    },
                },
                # When every alias is supplied, the "url" key should have
                # the highest precedence.
                "expected_url": "http://url:11434",
            },
            id="url_precedence",
        ),
    ],
)
def test_ollama_embedder_url_config(test_case):
    """Check which URL is forwarded to the Ollama embedding function.

    Each parametrized case supplies an embedder ``config`` and the
    ``expected_url`` that must arrive as the ``url`` keyword argument of the
    patched ``OllamaEmbeddingFunction``. The cases cover the default URL, each
    accepted URL key on its own (``url``, ``api_url``, ``base_url``,
    ``api_base``), and the precedence when all of them are given together.
    """
    configurator = EmbeddingConfigurator()

    # Patch the embedding function so no real Ollama client is constructed;
    # only the keyword arguments it receives are inspected.
    with patch(
        "chromadb.utils.embedding_functions.ollama_embedding_function.OllamaEmbeddingFunction"
    ) as mock_ollama:
        configurator.configure_embedder(test_case["config"])

        mock_ollama.assert_called_once()
        _, kwargs = mock_ollama.call_args
        assert kwargs["url"] == test_case["expected_url"]
def test_ollama_embedder_invalid_url():
    """Expect a ValueError matching "Invalid URL format" for a malformed URL."""
    embedder_config = {
        "provider": "ollama",
        "config": {
            "model": "test-model",
            "url": "invalid-url",
        },
    }
    embedding_configurator = EmbeddingConfigurator()

    with pytest.raises(ValueError, match="Invalid URL format"):
        embedding_configurator.configure_embedder(embedder_config)
def test_ollama_embedder_missing_model():
    """Expect a ValueError matching "Model name is required" when no model is given."""
    embedder_config = {
        "provider": "ollama",
        "config": {
            "url": "http://valid:11434",
        },
    }
    embedding_configurator = EmbeddingConfigurator()

    with pytest.raises(ValueError, match="Model name is required"):
        embedding_configurator.configure_embedder(embedder_config)