Mirror of https://github.com/crewAIInc/crewAI.git

Compare commits: devin/1756 ... devin/1739 (8 commits)
| SHA1 |
|---|
| deba76cb9d |
| 5fd64d7f51 |
| dea20a5010 |
| 5eefd90512 |
| 015ce7f550 |
| 26b62231db |
| 5b710cf2f9 |
| 97c8a8ab72 |
Documentation: the `embedder` provider list gains an Ollama URL-configuration example.

````diff
@@ -285,8 +285,37 @@ The `embedder` parameter supports various embedding model providers that include
 - `openai`: OpenAI's embedding models
 - `google`: Google's text embedding models
 - `azure`: Azure OpenAI embeddings
-- `ollama`: Local embeddings with Ollama
+- `ollama`: Local embeddings with Ollama (supports flexible URL configuration)
 - `vertexai`: Google Cloud VertexAI embeddings
+
+Here's an example of configuring the Ollama embedder with custom URL settings:
+```python
+# Configure Ollama embedder with custom URL
+agent = Agent(
+    role="Data Analyst",
+    goal="Analyze data efficiently",
+    embedder={
+        "provider": "ollama",
+        "config": {
+            "model": "llama2",
+            # URL configuration supports multiple keys in priority order:
+            # 1. url: Legacy key (highest priority)
+            # 2. api_url: Alternative key following HuggingFace pattern
+            # 3. base_url: Alternative key
+            # 4. api_base: Alternative key following Azure pattern
+            "url": "http://ollama:11434/api/embeddings"  # Example for Docker setup
+        }
+    }
+)
+```
+
+The Ollama embedder supports multiple URL configuration keys for flexibility:
+- `url`: Legacy key (highest priority)
+- `api_url`: Alternative key following HuggingFace pattern
+- `base_url`: Alternative key
+- `api_base`: Alternative key following Azure pattern
+
+If no URL is specified, it defaults to `http://localhost:11434/api/embeddings`.
 - `cohere`: Cohere's embedding models
 - `voyageai`: VoyageAI's embedding models
 - `bedrock`: AWS Bedrock embeddings
````
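The documented priority can also be exercised directly against the configurator that the new tests below target. This is a minimal sketch, not part of the diff: the model name and host are placeholders, and it assumes `configure_embedder` returns the constructed embedding function.

```python
from crewai.utilities.embedding_configurator import EmbeddingConfigurator

# Same dict shape as the Agent(embedder=...) example above, but using the
# alternative "base_url" key instead of the legacy "url" key.
embedder_config = {
    "provider": "ollama",
    "config": {
        "model": "nomic-embed-text",  # placeholder; any embedding model pulled into Ollama
        "base_url": "http://ollama:11434/api/embeddings",  # placeholder Docker hostname
    },
}

# Assumed to return a chromadb OllamaEmbeddingFunction pointed at the resolved URL.
embedding_function = EmbeddingConfigurator().configure_embedder(embedder_config)
```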
@@ -1,11 +1,36 @@
|
|||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import urllib.parse
|
||||||
from typing import Any, Dict, Optional, cast
|
from typing import Any, Dict, Optional, cast
|
||||||
|
|
||||||
from chromadb import Documents, EmbeddingFunction, Embeddings
|
from chromadb import Documents, EmbeddingFunction, Embeddings
|
||||||
from chromadb.api.types import validate_embedding_function
|
from chromadb.api.types import validate_embedding_function
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingConfigurator:
|
class EmbeddingConfigurator:
|
||||||
|
@staticmethod
|
||||||
|
def _validate_url(url: str) -> str:
|
||||||
|
"""Validate URL format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: URL to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The validated URL
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If URL is invalid
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = urllib.parse.urlparse(url)
|
||||||
|
if all([result.scheme, result.netloc]):
|
||||||
|
return url
|
||||||
|
raise ValueError(f"Invalid URL format: {url}")
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Invalid URL: {str(e)}")
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.embedding_functions = {
|
self.embedding_functions = {
|
||||||
"openai": self._configure_openai,
|
"openai": self._configure_openai,
|
||||||
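For context on what `_validate_url` accepts: `urllib.parse.urlparse` does not raise for a plain malformed string such as `invalid-url`; it simply returns empty components, so the `scheme`/`netloc` check is what actually rejects bad values (and the `ValueError` raised inside the `try` is re-wrapped by the `except Exception` clause, so callers see an `Invalid URL:` prefix). A standalone illustration using only the standard library:

```python
from urllib.parse import urlparse

# A well-formed URL yields both a scheme and a network location.
good = urlparse("http://ollama:11434/api/embeddings")
print(good.scheme, good.netloc)  # -> http ollama:11434

# A bare string parses without error but with empty scheme/netloc,
# so all([result.scheme, result.netloc]) is False and validation fails.
bad = urlparse("invalid-url")
print(repr(bad.scheme), repr(bad.netloc))  # -> '' ''
```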
`EmbeddingConfigurator._configure_ollama`: typed signature, a required model name, flexible URL keys, and validation before the embedder is constructed.

```diff
@@ -92,13 +117,44 @@ class EmbeddingConfigurator:
         )
 
     @staticmethod
-    def _configure_ollama(config, model_name):
+    def _configure_ollama(config: Dict[str, Any], model_name: Optional[str]) -> EmbeddingFunction:
+        """Configure Ollama embedder with flexible URL configuration.
+
+        Args:
+            config: Configuration dictionary that supports multiple URL keys in priority order:
+                1. url: Legacy key (highest priority)
+                2. api_url: Alternative key following HuggingFace pattern
+                3. base_url: Alternative key
+                4. api_base: Alternative key following Azure pattern
+                Default: http://localhost:11434/api/embeddings
+            model_name: Name of the Ollama model to use
+
+        Returns:
+            OllamaEmbeddingFunction: Configured embedder instance
+
+        Raises:
+            ValueError: If URL is invalid or model name is missing
+        """
         from chromadb.utils.embedding_functions.ollama_embedding_function import (
             OllamaEmbeddingFunction,
         )
 
+        if not model_name:
+            raise ValueError("Model name is required for Ollama embedder configuration")
+
+        url = (
+            config.get("url")
+            or config.get("api_url")
+            or config.get("base_url")
+            or config.get("api_base")
+            or "http://localhost:11434/api/embeddings"
+        )
+
+        validated_url = EmbeddingConfigurator._validate_url(url)
+        logger.info(f"Configuring Ollama embedder with URL: {validated_url}")
+
         return OllamaEmbeddingFunction(
-            url=config.get("url", "http://localhost:11434/api/embeddings"),
+            url=validated_url,
             model_name=model_name,
         )
 
```
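The key precedence is just an `or` chain over `dict.get`: the first key that is present and truthy wins, and an empty string falls through to the next key. A minimal standalone sketch of the same resolution order; `resolve_ollama_url` is illustrative and not part of the codebase:

```python
# Illustrative helper mirroring the resolution order in _configure_ollama above.
def resolve_ollama_url(config: dict) -> str:
    return (
        config.get("url")
        or config.get("api_url")
        or config.get("base_url")
        or config.get("api_base")
        or "http://localhost:11434/api/embeddings"
    )

assert resolve_ollama_url({}) == "http://localhost:11434/api/embeddings"
assert resolve_ollama_url({"api_base": "http://a:1"}) == "http://a:1"
# "url" wins even when later keys are also set, matching the documented priority.
assert resolve_ollama_url({"url": "http://u:1", "base_url": "http://b:1"}) == "http://u:1"
# Caveat of or-chaining: an empty string is treated as "not set".
assert resolve_ollama_url({"url": "", "api_url": "http://a:2"}) == "http://a:2"
```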
tests/embedder_test.py (new file, 92 lines)

@@ -0,0 +1,92 @@

```python
from unittest.mock import patch

import pytest

from crewai.utilities.embedding_configurator import EmbeddingConfigurator


@pytest.mark.parametrize(
    "test_case",
    [
        pytest.param(
            {
                "config": {"provider": "ollama", "config": {"model": "test-model"}},
                "expected_url": "http://localhost:11434/api/embeddings"
            },
            id="default_url"
        ),
        pytest.param(
            {
                "config": {"provider": "ollama", "config": {"model": "test-model", "url": "http://custom:11434"}},
                "expected_url": "http://custom:11434"
            },
            id="legacy_url"
        ),
        pytest.param(
            {
                "config": {"provider": "ollama", "config": {"model": "test-model", "api_url": "http://api:11434"}},
                "expected_url": "http://api:11434"
            },
            id="api_url"
        ),
        pytest.param(
            {
                "config": {"provider": "ollama", "config": {"model": "test-model", "base_url": "http://base:11434"}},
                "expected_url": "http://base:11434"
            },
            id="base_url"
        ),
        pytest.param(
            {
                "config": {"provider": "ollama", "config": {"model": "test-model", "api_base": "http://base-api:11434"}},
                "expected_url": "http://base-api:11434"
            },
            id="api_base"
        ),
        pytest.param(
            {
                "config": {
                    "provider": "ollama",
                    "config": {
                        "model": "test-model",
                        "url": "http://url:11434",
                        "api_url": "http://api:11434",
                        "base_url": "http://base:11434",
                        "api_base": "http://base-api:11434"
                    }
                },
                "expected_url": "http://url:11434"
            },
            id="url_precedence"
        ),
    ]
)
def test_ollama_embedder_url_config(test_case):
    configurator = EmbeddingConfigurator()
    with patch("chromadb.utils.embedding_functions.ollama_embedding_function.OllamaEmbeddingFunction") as mock_ollama:
        configurator.configure_embedder(test_case["config"])
        mock_ollama.assert_called_once()
        _, kwargs = mock_ollama.call_args
        assert kwargs["url"] == test_case["expected_url"]
        mock_ollama.reset_mock()

def test_ollama_embedder_invalid_url():
    configurator = EmbeddingConfigurator()
    with pytest.raises(ValueError, match="Invalid URL format"):
        configurator.configure_embedder({
            "provider": "ollama",
            "config": {
                "model": "test-model",
                "url": "invalid-url"
            }
        })

def test_ollama_embedder_missing_model():
    configurator = EmbeddingConfigurator()
    with pytest.raises(ValueError, match="Model name is required"):
        configurator.configure_embedder({
            "provider": "ollama",
            "config": {
                "url": "http://valid:11434"
            }
        })
```
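Two details make the mock assertions above work: `call_args` unpacks into an `(args, kwargs)` pair, and patching the `chromadb.utils.embedding_functions.ollama_embedding_function` path is effective because `_configure_ollama` imports `OllamaEmbeddingFunction` from that module at call time, so the patched attribute is what gets looked up. A small self-contained illustration of the same inspection pattern, with a plain `MagicMock` standing in for the patched class (no crewAI or chromadb required):

```python
from unittest.mock import MagicMock

# Stand-in for the patched OllamaEmbeddingFunction.
mock_ollama = MagicMock()

# Simulate what _configure_ollama does once the URL has been resolved and validated.
mock_ollama(url="http://custom:11434", model_name="test-model")

mock_ollama.assert_called_once()
args, kwargs = mock_ollama.call_args  # call_args behaves like an (args, kwargs) tuple
assert args == ()
assert kwargs == {"url": "http://custom:11434", "model_name": "test-model"}
```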
Converter tests: `test_converter_with_llama3_1_model` now uses a mocked `LLM` instead of a live Ollama endpoint, and the nested-model test's VCR cassette is allowed to record new episodes.

```diff
@@ -369,7 +369,9 @@ def test_converter_with_llama3_2_model():
 
 @pytest.mark.vcr(filter_headers=["authorization"])
 def test_converter_with_llama3_1_model():
-    llm = LLM(model="ollama/llama3.1", base_url="http://localhost:11434")
+    llm = Mock(spec=LLM)
+    llm.supports_function_calling.return_value = False
+    llm.call.return_value = '{"name": "Alice Llama", "age": 30}'
     sample_text = "Name: Alice Llama, Age: 30"
 
     instructions = get_conversion_instructions(SimpleModel, llm)
@@ -385,9 +387,10 @@ def test_converter_with_llama3_1_model():
     assert isinstance(output, SimpleModel)
     assert output.name == "Alice Llama"
     assert output.age == 30
+    llm.call.assert_called_once()
 
 
-@pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr(filter_headers=["authorization"], record_mode="new_episodes")
 def test_converter_with_nested_model():
     llm = LLM(model="gpt-4o-mini")
     sample_text = "Name: John Doe\nAge: 30\nAddress: 123 Main St, Anytown, 12345"
```
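Replacing the live `ollama/llama3.1` call with `Mock(spec=LLM)` keeps the converter test hermetic: `spec=` limits the mock to attributes the real `LLM` class exposes, so a typo in `supports_function_calling` or `call` raises instead of silently returning another mock. A minimal sketch of the pattern; the `from crewai import LLM` import path is an assumption, since the test file's imports are not shown in this hunk:

```python
from unittest.mock import Mock

from crewai import LLM  # assumed import path for the LLM class used in these tests

llm = Mock(spec=LLM)
llm.supports_function_calling.return_value = False
llm.call.return_value = '{"name": "Alice Llama", "age": 30}'

assert llm.supports_function_calling() is False
assert llm.call("any prompt") == '{"name": "Alice Llama", "age": 30}'
# Attributes that do not exist on the real LLM class raise AttributeError:
# llm.not_a_real_method  # -> AttributeError
```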