From 015ce7f550041d2637d74bb071d0da2205bb3dba Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Sun, 9 Feb 2025 05:42:43 +0000 Subject: [PATCH] refactor: improve Ollama embedder with validation and tests Co-Authored-By: Joe Moura --- .../utilities/embedding_configurator.py | 51 +++++-- tests/embedder_test.py | 131 +++++++++++------- 2 files changed, 125 insertions(+), 57 deletions(-) diff --git a/src/crewai/utilities/embedding_configurator.py b/src/crewai/utilities/embedding_configurator.py index c3d22f41f..0fac43c3b 100644 --- a/src/crewai/utilities/embedding_configurator.py +++ b/src/crewai/utilities/embedding_configurator.py @@ -1,11 +1,36 @@ +import logging import os -from typing import Any, Dict, cast +import urllib.parse +from typing import Any, Dict, Optional, cast from chromadb import Documents, EmbeddingFunction, Embeddings from chromadb.api.types import validate_embedding_function +logger = logging.getLogger(__name__) + class EmbeddingConfigurator: + @staticmethod + def _validate_url(url: str) -> str: + """Validate URL format. + + Args: + url: URL to validate + + Returns: + str: The validated URL + + Raises: + ValueError: If URL is invalid + """ + try: + result = urllib.parse.urlparse(url) + if all([result.scheme, result.netloc]): + return url + raise ValueError(f"Invalid URL format: {url}") + except Exception as e: + raise ValueError(f"Invalid URL: {str(e)}") + def __init__(self): self.embedding_functions = { "openai": self._configure_openai, @@ -81,24 +106,31 @@ class EmbeddingConfigurator: ) @staticmethod - def _configure_ollama(config, model_name): + def _configure_ollama(config: Dict[str, Any], model_name: Optional[str]) -> EmbeddingFunction: """Configure Ollama embedder with flexible URL configuration. Args: - config: Configuration dictionary that supports multiple URL keys: - - url: Legacy key (default: http://localhost:11434/api/embeddings) - - api_url: Alternative key following HuggingFace pattern - - base_url: Alternative key - - api_base: Alternative key following Azure pattern + config: Configuration dictionary that supports multiple URL keys in priority order: + 1. url: Legacy key (highest priority) + 2. api_url: Alternative key following HuggingFace pattern + 3. base_url: Alternative key + 4. api_base: Alternative key following Azure pattern + Default: http://localhost:11434/api/embeddings model_name: Name of the Ollama model to use Returns: OllamaEmbeddingFunction: Configured embedder instance + + Raises: + ValueError: If URL is invalid or model name is missing """ from chromadb.utils.embedding_functions.ollama_embedding_function import ( OllamaEmbeddingFunction, ) + if not model_name: + raise ValueError("Model name is required for Ollama embedder configuration") + url = ( config.get("url") or config.get("api_url") @@ -106,9 +138,12 @@ class EmbeddingConfigurator: or config.get("api_base") or "http://localhost:11434/api/embeddings" ) + + validated_url = EmbeddingConfigurator._validate_url(url) + logger.info(f"Configuring Ollama embedder with URL: {validated_url}") return OllamaEmbeddingFunction( - url=url, + url=validated_url, model_name=model_name, ) diff --git a/tests/embedder_test.py b/tests/embedder_test.py index e663bb2b6..f52751418 100644 --- a/tests/embedder_test.py +++ b/tests/embedder_test.py @@ -2,55 +2,88 @@ import pytest from unittest.mock import patch from crewai.utilities.embedding_configurator import EmbeddingConfigurator -def test_ollama_embedder_url_config(): - configurator = EmbeddingConfigurator() - - test_cases = [ - # Test default URL - { - "config": {"provider": "ollama", "config": {"model": "test-model"}}, - "expected_url": "http://localhost:11434/api/embeddings" - }, - # Test legacy url key - { - "config": {"provider": "ollama", "config": {"model": "test-model", "url": "http://custom:11434"}}, - "expected_url": "http://custom:11434" - }, - # Test api_url key - { - "config": {"provider": "ollama", "config": {"model": "test-model", "api_url": "http://api:11434"}}, - "expected_url": "http://api:11434" - }, - # Test base_url key - { - "config": {"provider": "ollama", "config": {"model": "test-model", "base_url": "http://base:11434"}}, - "expected_url": "http://base:11434" - }, - # Test api_base key - { - "config": {"provider": "ollama", "config": {"model": "test-model", "api_base": "http://base-api:11434"}}, - "expected_url": "http://base-api:11434" - }, - # Test URL precedence order - { - "config": { - "provider": "ollama", - "config": { - "model": "test-model", - "url": "http://url:11434", - "api_url": "http://api:11434", - "base_url": "http://base:11434", - "api_base": "http://base-api:11434" - } +@pytest.mark.parametrize( + "test_case", + [ + pytest.param( + { + "config": {"provider": "ollama", "config": {"model": "test-model"}}, + "expected_url": "http://localhost:11434/api/embeddings" }, - "expected_url": "http://url:11434" # url key should have highest precedence - } + id="default_url" + ), + pytest.param( + { + "config": {"provider": "ollama", "config": {"model": "test-model", "url": "http://custom:11434"}}, + "expected_url": "http://custom:11434" + }, + id="legacy_url" + ), + pytest.param( + { + "config": {"provider": "ollama", "config": {"model": "test-model", "api_url": "http://api:11434"}}, + "expected_url": "http://api:11434" + }, + id="api_url" + ), + pytest.param( + { + "config": {"provider": "ollama", "config": {"model": "test-model", "base_url": "http://base:11434"}}, + "expected_url": "http://base:11434" + }, + id="base_url" + ), + pytest.param( + { + "config": {"provider": "ollama", "config": {"model": "test-model", "api_base": "http://base-api:11434"}}, + "expected_url": "http://base-api:11434" + }, + id="api_base" + ), + pytest.param( + { + "config": { + "provider": "ollama", + "config": { + "model": "test-model", + "url": "http://url:11434", + "api_url": "http://api:11434", + "base_url": "http://base:11434", + "api_base": "http://base-api:11434" + } + }, + "expected_url": "http://url:11434" + }, + id="url_precedence" + ), ] +) +def test_ollama_embedder_url_config(test_case): + configurator = EmbeddingConfigurator() + with patch("chromadb.utils.embedding_functions.ollama_embedding_function.OllamaEmbeddingFunction") as mock_ollama: + configurator.configure_embedder(test_case["config"]) + mock_ollama.assert_called_once() + _, kwargs = mock_ollama.call_args + assert kwargs["url"] == test_case["expected_url"] + mock_ollama.reset_mock() - for test_case in test_cases: - with patch("chromadb.utils.embedding_functions.ollama_embedding_function.OllamaEmbeddingFunction") as mock_ollama: - configurator.configure_embedder(test_case["config"]) - mock_ollama.assert_called_once() - _, kwargs = mock_ollama.call_args - assert kwargs["url"] == test_case["expected_url"] - mock_ollama.reset_mock() +def test_ollama_embedder_invalid_url(): + configurator = EmbeddingConfigurator() + with pytest.raises(ValueError, match="Invalid URL format"): + configurator.configure_embedder({ + "provider": "ollama", + "config": { + "model": "test-model", + "url": "invalid-url" + } + }) + +def test_ollama_embedder_missing_model(): + configurator = EmbeddingConfigurator() + with pytest.raises(ValueError, match="Model name is required"): + configurator.configure_embedder({ + "provider": "ollama", + "config": { + "url": "http://valid:11434" + } + })