fix: Update embedding configuration and fix type errors

- Add configurable embedding provider support
- Remove OpenAI dependency for memory reset
- Add tests for different embedding providers
- Fix type hints and improve docstrings

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-02-09 23:25:38 +00:00
parent 409892d65f
commit d56523a01a
4 changed files with 121 additions and 19 deletions

View File

@@ -1,10 +1,11 @@
import asyncio
import json
import re
import uuid
import warnings
from concurrent.futures import Future
from hashlib import md5
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from pydantic import (
UUID4,

View File

@@ -163,12 +163,3 @@ class RAGStorage(BaseRAGStorage):
raise Exception(
f"An error occurred while resetting the {self.type} memory: {e}"
)
def _create_default_embedding_function(self):
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
return OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
)

View File

@@ -1,5 +1,5 @@
import os
from typing import Any, Dict, cast
from typing import Any, Dict, Optional, cast
from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.api.types import validate_embedding_function
@@ -21,9 +21,36 @@ class EmbeddingConfigurator:
def configure_embedder(
self,
embedder_config: Dict[str, Any] | None = None,
embedder_config: Optional[Dict[str, Any]] = None,
) -> EmbeddingFunction:
"""Configures and returns an embedding function based on the provided config."""
"""Configure and return an embedding function based on the provided config.
Args:
embedder_config: Optional configuration dictionary containing:
- provider: Name of the embedding provider or EmbeddingFunction instance
- config: Provider-specific configuration dictionary with options like:
- api_key: API key for the provider
- model: Model name to use for embeddings
- url: API endpoint URL (for some providers)
- session: Session object (for some providers)
Returns:
EmbeddingFunction: Configured embedding function for the specified provider
Raises:
ValueError: If custom embedding function is invalid
Exception: If provider is not supported or configuration is invalid
Examples:
>>> config = {
... "provider": "openai",
... "config": {
... "api_key": "your-api-key",
... "model": "text-embedding-3-small"
... }
... }
>>> embedder = EmbeddingConfigurator().configure_embedder(config)
"""
if embedder_config is None:
return self._create_default_embedding_function()
@@ -47,12 +74,23 @@ class EmbeddingConfigurator:
@staticmethod
def _create_default_embedding_function():
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
return OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
"""Create a default embedding function based on environment variables.
Environment Variables:
CREWAI_EMBEDDING_PROVIDER: The embedding provider to use (default: "openai")
CREWAI_EMBEDDING_MODEL: The model to use for embeddings
OPENAI_API_KEY: API key for OpenAI (required if using OpenAI provider)
Returns:
EmbeddingFunction: Configured embedding function
"""
provider = os.getenv("CREWAI_EMBEDDING_PROVIDER", "openai")
config = {
"api_key": os.getenv("OPENAI_API_KEY"),
"model": os.getenv("CREWAI_EMBEDDING_MODEL", "text-embedding-3-small")
}
return EmbeddingConfigurator().configure_embedder(
{"provider": provider, "config": config}
)
@staticmethod

View File

@@ -0,0 +1,72 @@
import os
import tempfile
from typing import Generator
from pathlib import Path
import pytest
from chromadb import Documents, EmbeddingFunction, Embeddings
from crewai.memory import ShortTermMemory, LongTermMemory, EntityMemory
from crewai.utilities import EmbeddingConfigurator
from crewai.utilities.exceptions.embedding_exceptions import (
EmbeddingConfigurationError,
EmbeddingProviderError
)
@pytest.fixture
def temp_db_dir() -> Generator[Path, None, None]:
"""Create a temporary directory for test databases."""
with tempfile.TemporaryDirectory() as tmpdir:
path = Path(tmpdir)
yield path
def test_memory_reset_with_openai(temp_db_dir):
"""Test memory reset with default OpenAI provider."""
os.environ["OPENAI_API_KEY"] = "test-key"
memory = ShortTermMemory(path=temp_db_dir)
memory.reset() # Should work with OpenAI as default
def test_memory_reset_with_ollama(temp_db_dir):
"""Test memory reset with Ollama provider."""
os.environ["CREWAI_EMBEDDING_PROVIDER"] = "ollama"
memory = ShortTermMemory(path=temp_db_dir)
memory.reset() # Should not raise any OpenAI-related errors
def test_memory_reset_with_custom_provider(temp_db_dir):
"""Test memory reset with custom embedding provider."""
class CustomEmbedder(EmbeddingFunction):
def __call__(self, input: Documents) -> Embeddings:
if isinstance(input, str):
input = [input]
return [[0.5] * 10] * len(input)
memory = ShortTermMemory(
path=temp_db_dir,
embedder_config={"provider": CustomEmbedder()}
)
memory.reset() # Should work with custom embedder
def test_memory_reset_with_invalid_provider(temp_db_dir):
"""Test memory reset with invalid provider raises appropriate error."""
os.environ["CREWAI_EMBEDDING_PROVIDER"] = "invalid_provider"
with pytest.raises(Exception) as exc_info:
memory = ShortTermMemory(path=temp_db_dir)
memory.reset()
assert "Unsupported embedding provider" in str(exc_info.value)
def test_memory_reset_with_missing_api_key(temp_db_dir):
"""Test memory reset with missing API key raises appropriate error."""
os.environ.pop("OPENAI_API_KEY", None) # Ensure key is not set
os.environ["CREWAI_EMBEDDING_PROVIDER"] = "openai"
with pytest.raises(Exception) as exc_info:
memory = ShortTermMemory(path=temp_db_dir)
memory.reset()
assert "api_key" in str(exc_info.value).lower()
def test_memory_reset_cleans_up_files(temp_db_dir):
"""Test that memory reset properly cleans up database files."""
memory = ShortTermMemory(path=temp_db_dir)
memory.save("test memory", {"test": "metadata"})
assert any(temp_db_dir.iterdir()) # Directory should have files
memory.reset()
assert not any(temp_db_dir.iterdir()) # Directory should be empty