mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 16:18:30 +00:00
refactor: improve Ollama embedder with validation and tests
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -1,11 +1,36 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, cast
|
||||
import urllib.parse
|
||||
from typing import Any, Dict, Optional, cast
|
||||
|
||||
from chromadb import Documents, EmbeddingFunction, Embeddings
|
||||
from chromadb.api.types import validate_embedding_function
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmbeddingConfigurator:
|
||||
@staticmethod
def _validate_url(url: str) -> str:
    """Check that a URL is absolute and hand it back unchanged.

    A URL is accepted when parsing it yields both a scheme and a network
    location (for example ``http://localhost:11434``). The value is not
    normalised or otherwise modified.

    Args:
        url: URL to validate

    Returns:
        str: The validated URL, exactly as it was passed in

    Raises:
        ValueError: If the URL cannot be parsed, or if it lacks a scheme
            or a network location
    """
    # Keep the try body down to the parsing call so that the format error
    # raised further below is not caught and re-wrapped into a doubled
    # "Invalid URL: Invalid URL format: ..." message.
    try:
        parsed = urllib.parse.urlparse(url)
    except (AttributeError, TypeError, ValueError) as exc:
        # Surface parser failures (malformed IPv6 literal, non-string
        # input, ...) as ValueError while preserving the original cause.
        raise ValueError(f"Invalid URL: {exc}") from exc

    if not (parsed.scheme and parsed.netloc):
        raise ValueError(f"Invalid URL format: {url}")
    return url
|
||||
|
||||
def __init__(self):
|
||||
self.embedding_functions = {
|
||||
"openai": self._configure_openai,
|
||||
@@ -81,24 +106,31 @@ class EmbeddingConfigurator:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _configure_ollama(config, model_name):
|
||||
def _configure_ollama(config: Dict[str, Any], model_name: Optional[str]) -> EmbeddingFunction:
|
||||
"""Configure Ollama embedder with flexible URL configuration.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary that supports multiple URL keys:
|
||||
- url: Legacy key (default: http://localhost:11434/api/embeddings)
|
||||
- api_url: Alternative key following HuggingFace pattern
|
||||
- base_url: Alternative key
|
||||
- api_base: Alternative key following Azure pattern
|
||||
config: Configuration dictionary that supports multiple URL keys in priority order:
|
||||
1. url: Legacy key (highest priority)
|
||||
2. api_url: Alternative key following HuggingFace pattern
|
||||
3. base_url: Alternative key
|
||||
4. api_base: Alternative key following Azure pattern
|
||||
Default: http://localhost:11434/api/embeddings
|
||||
model_name: Name of the Ollama model to use
|
||||
|
||||
Returns:
|
||||
OllamaEmbeddingFunction: Configured embedder instance
|
||||
|
||||
Raises:
|
||||
ValueError: If URL is invalid or model name is missing
|
||||
"""
|
||||
from chromadb.utils.embedding_functions.ollama_embedding_function import (
|
||||
OllamaEmbeddingFunction,
|
||||
)
|
||||
|
||||
if not model_name:
|
||||
raise ValueError("Model name is required for Ollama embedder configuration")
|
||||
|
||||
url = (
|
||||
config.get("url")
|
||||
or config.get("api_url")
|
||||
@@ -106,9 +138,12 @@ class EmbeddingConfigurator:
|
||||
or config.get("api_base")
|
||||
or "http://localhost:11434/api/embeddings"
|
||||
)
|
||||
|
||||
validated_url = EmbeddingConfigurator._validate_url(url)
|
||||
logger.info(f"Configuring Ollama embedder with URL: {validated_url}")
|
||||
|
||||
return OllamaEmbeddingFunction(
|
||||
url=url,
|
||||
url=validated_url,
|
||||
model_name=model_name,
|
||||
)
|
||||
|
||||
|
||||
@@ -2,55 +2,88 @@ import pytest
|
||||
from unittest.mock import patch
|
||||
from crewai.utilities.embedding_configurator import EmbeddingConfigurator
|
||||
|
||||
def test_ollama_embedder_url_config():
|
||||
configurator = EmbeddingConfigurator()
|
||||
|
||||
test_cases = [
|
||||
# Test default URL
|
||||
{
|
||||
"config": {"provider": "ollama", "config": {"model": "test-model"}},
|
||||
"expected_url": "http://localhost:11434/api/embeddings"
|
||||
},
|
||||
# Test legacy url key
|
||||
{
|
||||
"config": {"provider": "ollama", "config": {"model": "test-model", "url": "http://custom:11434"}},
|
||||
"expected_url": "http://custom:11434"
|
||||
},
|
||||
# Test api_url key
|
||||
{
|
||||
"config": {"provider": "ollama", "config": {"model": "test-model", "api_url": "http://api:11434"}},
|
||||
"expected_url": "http://api:11434"
|
||||
},
|
||||
# Test base_url key
|
||||
{
|
||||
"config": {"provider": "ollama", "config": {"model": "test-model", "base_url": "http://base:11434"}},
|
||||
"expected_url": "http://base:11434"
|
||||
},
|
||||
# Test api_base key
|
||||
{
|
||||
"config": {"provider": "ollama", "config": {"model": "test-model", "api_base": "http://base-api:11434"}},
|
||||
"expected_url": "http://base-api:11434"
|
||||
},
|
||||
# Test URL precedence order
|
||||
{
|
||||
"config": {
|
||||
"provider": "ollama",
|
||||
"config": {
|
||||
"model": "test-model",
|
||||
"url": "http://url:11434",
|
||||
"api_url": "http://api:11434",
|
||||
"base_url": "http://base:11434",
|
||||
"api_base": "http://base-api:11434"
|
||||
}
|
||||
@pytest.mark.parametrize(
|
||||
"test_case",
|
||||
[
|
||||
pytest.param(
|
||||
{
|
||||
"config": {"provider": "ollama", "config": {"model": "test-model"}},
|
||||
"expected_url": "http://localhost:11434/api/embeddings"
|
||||
},
|
||||
"expected_url": "http://url:11434" # url key should have highest precedence
|
||||
}
|
||||
id="default_url"
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"config": {"provider": "ollama", "config": {"model": "test-model", "url": "http://custom:11434"}},
|
||||
"expected_url": "http://custom:11434"
|
||||
},
|
||||
id="legacy_url"
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"config": {"provider": "ollama", "config": {"model": "test-model", "api_url": "http://api:11434"}},
|
||||
"expected_url": "http://api:11434"
|
||||
},
|
||||
id="api_url"
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"config": {"provider": "ollama", "config": {"model": "test-model", "base_url": "http://base:11434"}},
|
||||
"expected_url": "http://base:11434"
|
||||
},
|
||||
id="base_url"
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"config": {"provider": "ollama", "config": {"model": "test-model", "api_base": "http://base-api:11434"}},
|
||||
"expected_url": "http://base-api:11434"
|
||||
},
|
||||
id="api_base"
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"config": {
|
||||
"provider": "ollama",
|
||||
"config": {
|
||||
"model": "test-model",
|
||||
"url": "http://url:11434",
|
||||
"api_url": "http://api:11434",
|
||||
"base_url": "http://base:11434",
|
||||
"api_base": "http://base-api:11434"
|
||||
}
|
||||
},
|
||||
"expected_url": "http://url:11434"
|
||||
},
|
||||
id="url_precedence"
|
||||
),
|
||||
]
|
||||
)
|
||||
def test_ollama_embedder_url_config(test_case):
|
||||
configurator = EmbeddingConfigurator()
|
||||
with patch("chromadb.utils.embedding_functions.ollama_embedding_function.OllamaEmbeddingFunction") as mock_ollama:
|
||||
configurator.configure_embedder(test_case["config"])
|
||||
mock_ollama.assert_called_once()
|
||||
_, kwargs = mock_ollama.call_args
|
||||
assert kwargs["url"] == test_case["expected_url"]
|
||||
mock_ollama.reset_mock()
|
||||
|
||||
for test_case in test_cases:
|
||||
with patch("chromadb.utils.embedding_functions.ollama_embedding_function.OllamaEmbeddingFunction") as mock_ollama:
|
||||
configurator.configure_embedder(test_case["config"])
|
||||
mock_ollama.assert_called_once()
|
||||
_, kwargs = mock_ollama.call_args
|
||||
assert kwargs["url"] == test_case["expected_url"]
|
||||
mock_ollama.reset_mock()
|
||||
def test_ollama_embedder_invalid_url():
    """Configuring Ollama with the URL "invalid-url" must raise ValueError.

    The raised message is expected to contain "Invalid URL format".
    """
    bad_settings = {
        "provider": "ollama",
        "config": {"model": "test-model", "url": "invalid-url"},
    }
    embedder_factory = EmbeddingConfigurator()

    with pytest.raises(ValueError, match="Invalid URL format"):
        embedder_factory.configure_embedder(bad_settings)
|
||||
|
||||
def test_ollama_embedder_missing_model():
    """Configuring Ollama without a "model" entry must raise ValueError.

    The raised message is expected to contain "Model name is required".
    """
    settings_without_model = {
        "provider": "ollama",
        "config": {"url": "http://valid:11434"},
    }
    embedder_factory = EmbeddingConfigurator()

    with pytest.raises(ValueError, match="Model name is required"):
        embedder_factory.configure_embedder(settings_without_model)
|
||||
|
||||
Reference in New Issue
Block a user