refactor: improve Ollama embedder with validation and tests

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-02-09 05:42:43 +00:00
parent 26b62231db
commit 015ce7f550
2 changed files with 125 additions and 57 deletions

View File

@@ -1,11 +1,36 @@
import logging
import os import os
from typing import Any, Dict, cast import urllib.parse
from typing import Any, Dict, Optional, cast
from chromadb import Documents, EmbeddingFunction, Embeddings from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.api.types import validate_embedding_function from chromadb.api.types import validate_embedding_function
logger = logging.getLogger(__name__)
class EmbeddingConfigurator: class EmbeddingConfigurator:
@staticmethod
def _validate_url(url: str) -> str:
    """Validate that a URL has both a scheme and a network location.

    Args:
        url: URL to validate.

    Returns:
        str: The URL, unchanged, when it is well-formed.

    Raises:
        ValueError: If the URL cannot be parsed, or if it lacks a scheme
            or a network location (e.g. ``"invalid-url"``).
    """
    # Keep the try body limited to the parser call so the format error
    # raised below is not caught and re-wrapped by this handler.
    try:
        parsed = urllib.parse.urlparse(url)
    except (ValueError, TypeError, AttributeError) as exc:
        # urlparse raises ValueError for malformed input such as an
        # unterminated IPv6 literal; non-string input can surface as
        # TypeError/AttributeError. Normalise all of them to ValueError.
        raise ValueError(f"Invalid URL: {exc}") from exc

    if not (parsed.scheme and parsed.netloc):
        raise ValueError(f"Invalid URL format: {url}")
    return url
def __init__(self): def __init__(self):
self.embedding_functions = { self.embedding_functions = {
"openai": self._configure_openai, "openai": self._configure_openai,
@@ -81,24 +106,31 @@ class EmbeddingConfigurator:
) )
@staticmethod @staticmethod
def _configure_ollama(config, model_name): def _configure_ollama(config: Dict[str, Any], model_name: Optional[str]) -> EmbeddingFunction:
"""Configure Ollama embedder with flexible URL configuration. """Configure Ollama embedder with flexible URL configuration.
Args: Args:
config: Configuration dictionary that supports multiple URL keys: config: Configuration dictionary that supports multiple URL keys in priority order:
- url: Legacy key (default: http://localhost:11434/api/embeddings) 1. url: Legacy key (highest priority)
- api_url: Alternative key following HuggingFace pattern 2. api_url: Alternative key following HuggingFace pattern
- base_url: Alternative key 3. base_url: Alternative key
- api_base: Alternative key following Azure pattern 4. api_base: Alternative key following Azure pattern
Default: http://localhost:11434/api/embeddings
model_name: Name of the Ollama model to use model_name: Name of the Ollama model to use
Returns: Returns:
OllamaEmbeddingFunction: Configured embedder instance OllamaEmbeddingFunction: Configured embedder instance
Raises:
ValueError: If URL is invalid or model name is missing
""" """
from chromadb.utils.embedding_functions.ollama_embedding_function import ( from chromadb.utils.embedding_functions.ollama_embedding_function import (
OllamaEmbeddingFunction, OllamaEmbeddingFunction,
) )
if not model_name:
raise ValueError("Model name is required for Ollama embedder configuration")
url = ( url = (
config.get("url") config.get("url")
or config.get("api_url") or config.get("api_url")
@@ -106,9 +138,12 @@ class EmbeddingConfigurator:
or config.get("api_base") or config.get("api_base")
or "http://localhost:11434/api/embeddings" or "http://localhost:11434/api/embeddings"
) )
validated_url = EmbeddingConfigurator._validate_url(url)
logger.info(f"Configuring Ollama embedder with URL: {validated_url}")
return OllamaEmbeddingFunction( return OllamaEmbeddingFunction(
url=url, url=validated_url,
model_name=model_name, model_name=model_name,
) )

View File

@@ -2,55 +2,88 @@ import pytest
from unittest.mock import patch from unittest.mock import patch
from crewai.utilities.embedding_configurator import EmbeddingConfigurator from crewai.utilities.embedding_configurator import EmbeddingConfigurator
def test_ollama_embedder_url_config(): @pytest.mark.parametrize(
configurator = EmbeddingConfigurator() "test_case",
[
test_cases = [ pytest.param(
# Test default URL {
{ "config": {"provider": "ollama", "config": {"model": "test-model"}},
"config": {"provider": "ollama", "config": {"model": "test-model"}}, "expected_url": "http://localhost:11434/api/embeddings"
"expected_url": "http://localhost:11434/api/embeddings"
},
# Test legacy url key
{
"config": {"provider": "ollama", "config": {"model": "test-model", "url": "http://custom:11434"}},
"expected_url": "http://custom:11434"
},
# Test api_url key
{
"config": {"provider": "ollama", "config": {"model": "test-model", "api_url": "http://api:11434"}},
"expected_url": "http://api:11434"
},
# Test base_url key
{
"config": {"provider": "ollama", "config": {"model": "test-model", "base_url": "http://base:11434"}},
"expected_url": "http://base:11434"
},
# Test api_base key
{
"config": {"provider": "ollama", "config": {"model": "test-model", "api_base": "http://base-api:11434"}},
"expected_url": "http://base-api:11434"
},
# Test URL precedence order
{
"config": {
"provider": "ollama",
"config": {
"model": "test-model",
"url": "http://url:11434",
"api_url": "http://api:11434",
"base_url": "http://base:11434",
"api_base": "http://base-api:11434"
}
}, },
"expected_url": "http://url:11434" # url key should have highest precedence id="default_url"
} ),
pytest.param(
{
"config": {"provider": "ollama", "config": {"model": "test-model", "url": "http://custom:11434"}},
"expected_url": "http://custom:11434"
},
id="legacy_url"
),
pytest.param(
{
"config": {"provider": "ollama", "config": {"model": "test-model", "api_url": "http://api:11434"}},
"expected_url": "http://api:11434"
},
id="api_url"
),
pytest.param(
{
"config": {"provider": "ollama", "config": {"model": "test-model", "base_url": "http://base:11434"}},
"expected_url": "http://base:11434"
},
id="base_url"
),
pytest.param(
{
"config": {"provider": "ollama", "config": {"model": "test-model", "api_base": "http://base-api:11434"}},
"expected_url": "http://base-api:11434"
},
id="api_base"
),
pytest.param(
{
"config": {
"provider": "ollama",
"config": {
"model": "test-model",
"url": "http://url:11434",
"api_url": "http://api:11434",
"base_url": "http://base:11434",
"api_base": "http://base-api:11434"
}
},
"expected_url": "http://url:11434"
},
id="url_precedence"
),
] ]
)
def test_ollama_embedder_url_config(test_case):
    """Check that the Ollama embedder is built with the expected URL.

    Args:
        test_case: Dict with a ``config`` entry passed to
            ``configure_embedder`` and the ``expected_url`` that should be
            forwarded to ``OllamaEmbeddingFunction``.
    """
    configurator = EmbeddingConfigurator()

    # Each parametrized case enters its own patch context and therefore gets
    # a fresh mock, so no reset is needed afterwards.
    with patch(
        "chromadb.utils.embedding_functions.ollama_embedding_function.OllamaEmbeddingFunction"
    ) as mock_ollama:
        configurator.configure_embedder(test_case["config"])

        mock_ollama.assert_called_once()
        _, kwargs = mock_ollama.call_args
        assert kwargs["url"] == test_case["expected_url"]
for test_case in test_cases: def test_ollama_embedder_invalid_url():
with patch("chromadb.utils.embedding_functions.ollama_embedding_function.OllamaEmbeddingFunction") as mock_ollama: configurator = EmbeddingConfigurator()
configurator.configure_embedder(test_case["config"]) with pytest.raises(ValueError, match="Invalid URL format"):
mock_ollama.assert_called_once() configurator.configure_embedder({
_, kwargs = mock_ollama.call_args "provider": "ollama",
assert kwargs["url"] == test_case["expected_url"] "config": {
mock_ollama.reset_mock() "model": "test-model",
"url": "invalid-url"
}
})
def test_ollama_embedder_missing_model():
    """Expect a ``ValueError`` when the Ollama config has no model entry.

    The embedder spec supplies a well-formed ``url`` but no ``model`` key.
    """
    embedder_spec = {
        "provider": "ollama",
        "config": {"url": "http://valid:11434"},
    }
    embedding_configurator = EmbeddingConfigurator()

    with pytest.raises(ValueError, match="Model name is required"):
        embedding_configurator.configure_embedder(embedder_spec)