Mirror of https://github.com/crewAIInc/crewAI.git

Compare commits: 1.2.1...devin/1751 (6 commits)
| Author | SHA1 | Date |
|---|---|---|
| | b1d0072c13 | |
| | b022e06e0d | |
| | 2ede9ca9be | |
| | f4807ee858 | |
| | abd1d341da | |
| | d422439b7a | |
@@ -91,14 +91,52 @@ class EmbeddingConfigurator:
             organization_id=config.get("organization_id"),
         )
 
     @staticmethod
+    def _validate_url(url):
+        """Validate that a URL is properly formatted and uses HTTP/HTTPS scheme."""
+        if not url:
+            return False
+
+        from urllib.parse import urlparse
+
+        try:
+            result = urlparse(url)
+            return all([
+                result.scheme in ('http', 'https'),
+                result.netloc
+            ])
+        except ValueError:
+            return False
+
+    @staticmethod
     def _configure_ollama(config, model_name):
+        """Configure Ollama embedding function.
+
+        Supports configuration via:
+        1. config.url - Direct URL to Ollama embeddings endpoint
+        2. config.api_base - Base URL for Ollama API
+        3. API_BASE environment variable - Base URL from environment
+        4. Default: http://localhost:11434/api/embeddings
+
+        Note: When using api_base or API_BASE, ensure the URL includes the full
+        embeddings endpoint path (e.g., http://localhost:11434/api/embeddings)
+        """
         from chromadb.utils.embedding_functions.ollama_embedding_function import (
             OllamaEmbeddingFunction,
         )
 
+        url = (
+            config.get("url")
+            or config.get("api_base")
+            or os.getenv("API_BASE")
+            or "http://localhost:11434/api/embeddings"
+        )
+
+        if not EmbeddingConfigurator._validate_url(url):
+            raise ValueError(f"Invalid Ollama API URL: {url}")
+
         return OllamaEmbeddingFunction(
-            url=config.get("url", "http://localhost:11434/api/embeddings"),
+            url=url,
             model_name=model_name,
         )
 
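From the caller's side, the priority order described in the docstring plays out as in the sketch below. This is illustrative only, not part of the change: the config shape and the configure_embedder() call mirror the tests further down, while the host names are made up.

import os

from crewai.utilities.embedding_configurator import EmbeddingConfigurator

# An explicit config "url" wins over "api_base" and the API_BASE
# environment variable; the default endpoint is used only when none are set.
os.environ["API_BASE"] = "http://fallback-ollama:11434/api/embeddings"  # hypothetical host

embedder = EmbeddingConfigurator().configure_embedder(
    {
        "provider": "ollama",
        "config": {
            "model": "llama2",
            # hypothetical host; note the full /api/embeddings endpoint path
            "url": "http://my-ollama:11434/api/embeddings",
        },
    }
)

# A malformed value now fails fast:
#   {"provider": "ollama", "config": {"model": "llama2", "url": "not-a-url"}}
# raises ValueError("Invalid Ollama API URL: not-a-url")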
176 tests/utilities/test_ollama_embedding_configurator.py (new file)
@@ -0,0 +1,176 @@
import os
import pytest
from unittest.mock import patch
from crewai.utilities.embedding_configurator import EmbeddingConfigurator


@pytest.mark.url_configuration
class TestOllamaEmbeddingConfigurator:
    def setup_method(self):
        self.configurator = EmbeddingConfigurator()

    @patch.dict(os.environ, {}, clear=True)
    def test_ollama_default_url(self):
        config = {"provider": "ollama", "config": {"model": "llama2"}}

        with patch("chromadb.utils.embedding_functions.ollama_embedding_function.OllamaEmbeddingFunction") as mock_ollama:
            self.configurator.configure_embedder(config)
            mock_ollama.assert_called_once_with(
                url="http://localhost:11434/api/embeddings",
                model_name="llama2"
            )

    @patch.dict(os.environ, {"API_BASE": "http://custom-ollama:8080/api/embeddings"}, clear=True)
    def test_ollama_respects_api_base_env_var(self):
        config = {"provider": "ollama", "config": {"model": "llama2"}}

        with patch("chromadb.utils.embedding_functions.ollama_embedding_function.OllamaEmbeddingFunction") as mock_ollama:
            self.configurator.configure_embedder(config)
            mock_ollama.assert_called_once_with(
                url="http://custom-ollama:8080/api/embeddings",
                model_name="llama2"
            )

    @patch.dict(os.environ, {"API_BASE": "http://env-ollama:8080/api/embeddings"}, clear=True)
    def test_ollama_config_url_overrides_env_var(self):
        config = {
            "provider": "ollama",
            "config": {
                "model": "llama2",
                "url": "http://config-ollama:9090/api/embeddings"
            }
        }

        with patch("chromadb.utils.embedding_functions.ollama_embedding_function.OllamaEmbeddingFunction") as mock_ollama:
            self.configurator.configure_embedder(config)
            mock_ollama.assert_called_once_with(
                url="http://config-ollama:9090/api/embeddings",
                model_name="llama2"
            )

    @patch.dict(os.environ, {"API_BASE": "http://env-ollama:8080/api/embeddings"}, clear=True)
    def test_ollama_config_api_base_overrides_env_var(self):
        config = {
            "provider": "ollama",
            "config": {
                "model": "llama2",
                "api_base": "http://config-ollama:9090/api/embeddings"
            }
        }

        with patch("chromadb.utils.embedding_functions.ollama_embedding_function.OllamaEmbeddingFunction") as mock_ollama:
            self.configurator.configure_embedder(config)
            mock_ollama.assert_called_once_with(
                url="http://config-ollama:9090/api/embeddings",
                model_name="llama2"
            )

    @patch.dict(os.environ, {}, clear=True)
    def test_ollama_url_priority_order(self):
        config = {
            "provider": "ollama",
            "config": {
                "model": "llama2",
                "url": "http://url-config:1111/api/embeddings",
                "api_base": "http://api-base-config:2222/api/embeddings"
            }
        }

        with patch("chromadb.utils.embedding_functions.ollama_embedding_function.OllamaEmbeddingFunction") as mock_ollama:
            self.configurator.configure_embedder(config)
            mock_ollama.assert_called_once_with(
                url="http://url-config:1111/api/embeddings",
                model_name="llama2"
            )

    @patch.dict(os.environ, {"API_BASE": "http://localhost:11434/api/embeddings"}, clear=True)
    def test_ollama_uses_provided_url_as_is(self):
        config = {"provider": "ollama", "config": {"model": "llama2"}}

        with patch("chromadb.utils.embedding_functions.ollama_embedding_function.OllamaEmbeddingFunction") as mock_ollama:
            self.configurator.configure_embedder(config)
            mock_ollama.assert_called_once_with(
                url="http://localhost:11434/api/embeddings",
                model_name="llama2"
            )

    @patch.dict(os.environ, {"API_BASE": "http://custom-base:9000"}, clear=True)
    def test_ollama_requires_complete_url_in_api_base(self):
        """Test that demonstrates users must provide complete URLs including endpoint."""
        config = {"provider": "ollama", "config": {"model": "llama2"}}

        with patch("chromadb.utils.embedding_functions.ollama_embedding_function.OllamaEmbeddingFunction") as mock_ollama:
            self.configurator.configure_embedder(config)
            mock_ollama.assert_called_once_with(
                url="http://custom-base:9000",
                model_name="llama2"
            )

    @patch.dict(os.environ, {}, clear=True)
    def test_ollama_config_api_base_without_url(self):
        config = {
            "provider": "ollama",
            "config": {
                "model": "llama2",
                "api_base": "http://config-ollama:9090/api/embeddings"
            }
        }

        with patch("chromadb.utils.embedding_functions.ollama_embedding_function.OllamaEmbeddingFunction") as mock_ollama:
            self.configurator.configure_embedder(config)
            mock_ollama.assert_called_once_with(
                url="http://config-ollama:9090/api/embeddings",
                model_name="llama2"
            )


@pytest.mark.error_handling
class TestOllamaErrorHandling:
    def setup_method(self):
        self.configurator = EmbeddingConfigurator()

    @pytest.mark.parametrize("invalid_url", [
        "not-a-url",
        "ftp://invalid-scheme",
        "http://",
        "://missing-scheme",
        "http:///missing-netloc",
    ])
    def test_invalid_url_raises_error(self, invalid_url):
        """Test that invalid URLs raise ValueError with clear error message."""
        config = {
            "provider": "ollama",
            "config": {
                "model": "llama2",
                "url": invalid_url
            }
        }

        with pytest.raises(ValueError, match="Invalid Ollama API URL"):
            self.configurator.configure_embedder(config)

    @pytest.mark.parametrize("invalid_api_base", [
        "not-a-url",
        "ftp://invalid-scheme",
        "http://",
        "://missing-scheme",
    ])
    def test_invalid_api_base_raises_error(self, invalid_api_base):
        """Test that invalid api_base URLs raise ValueError with clear error message."""
        config = {
            "provider": "ollama",
            "config": {
                "model": "llama2",
                "api_base": invalid_api_base
            }
        }

        with pytest.raises(ValueError, match="Invalid Ollama API URL"):
            self.configurator.configure_embedder(config)

    @patch.dict(os.environ, {"API_BASE": "not-a-valid-url"}, clear=True)
    def test_invalid_env_var_raises_error(self):
        """Test that invalid API_BASE environment variable raises ValueError."""
        config = {"provider": "ollama", "config": {"model": "llama2"}}

        with pytest.raises(ValueError, match="Invalid Ollama API URL"):
            self.configurator.configure_embedder(config)
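One housekeeping note: url_configuration and error_handling are custom pytest marks, so unless the repository already registers them, collecting this module will emit PytestUnknownMarkWarning (and fail under --strict-markers). A minimal conftest.py sketch, assuming no existing registration:

# conftest.py (sketch) -- register the custom marks used in this test module
def pytest_configure(config):
    config.addinivalue_line(
        "markers", "url_configuration: Ollama URL resolution tests"
    )
    config.addinivalue_line(
        "markers", "error_handling: invalid Ollama URL error-path tests"
    )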