mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
fix: Update embedding configuration and fix type errors
- Add configurable embedding provider support - Remove OpenAI dependency for memory reset - Add tests for different embedding providers - Fix type hints and improve docstrings Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -1,10 +1,11 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
import warnings
|
import warnings
|
||||||
from concurrent.futures import Future
|
from concurrent.futures import Future
|
||||||
from hashlib import md5
|
from hashlib import md5
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
UUID4,
|
UUID4,
|
||||||
|
|||||||
@@ -163,12 +163,3 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
raise Exception(
|
raise Exception(
|
||||||
f"An error occurred while resetting the {self.type} memory: {e}"
|
f"An error occurred while resetting the {self.type} memory: {e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_default_embedding_function(self):
|
|
||||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
|
||||||
OpenAIEmbeddingFunction,
|
|
||||||
)
|
|
||||||
|
|
||||||
return OpenAIEmbeddingFunction(
|
|
||||||
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Any, Dict, cast
|
from typing import Any, Dict, Optional, cast
|
||||||
|
|
||||||
from chromadb import Documents, EmbeddingFunction, Embeddings
|
from chromadb import Documents, EmbeddingFunction, Embeddings
|
||||||
from chromadb.api.types import validate_embedding_function
|
from chromadb.api.types import validate_embedding_function
|
||||||
@@ -21,9 +21,36 @@ class EmbeddingConfigurator:
|
|||||||
|
|
||||||
def configure_embedder(
|
def configure_embedder(
|
||||||
self,
|
self,
|
||||||
embedder_config: Dict[str, Any] | None = None,
|
embedder_config: Optional[Dict[str, Any]] = None,
|
||||||
) -> EmbeddingFunction:
|
) -> EmbeddingFunction:
|
||||||
"""Configures and returns an embedding function based on the provided config."""
|
"""Configure and return an embedding function based on the provided config.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embedder_config: Optional configuration dictionary containing:
|
||||||
|
- provider: Name of the embedding provider or EmbeddingFunction instance
|
||||||
|
- config: Provider-specific configuration dictionary with options like:
|
||||||
|
- api_key: API key for the provider
|
||||||
|
- model: Model name to use for embeddings
|
||||||
|
- url: API endpoint URL (for some providers)
|
||||||
|
- session: Session object (for some providers)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
EmbeddingFunction: Configured embedding function for the specified provider
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If custom embedding function is invalid
|
||||||
|
Exception: If provider is not supported or configuration is invalid
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> config = {
|
||||||
|
... "provider": "openai",
|
||||||
|
... "config": {
|
||||||
|
... "api_key": "your-api-key",
|
||||||
|
... "model": "text-embedding-3-small"
|
||||||
|
... }
|
||||||
|
... }
|
||||||
|
>>> embedder = EmbeddingConfigurator().configure_embedder(config)
|
||||||
|
"""
|
||||||
if embedder_config is None:
|
if embedder_config is None:
|
||||||
return self._create_default_embedding_function()
|
return self._create_default_embedding_function()
|
||||||
|
|
||||||
@@ -47,12 +74,23 @@ class EmbeddingConfigurator:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _create_default_embedding_function():
|
def _create_default_embedding_function():
|
||||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
"""Create a default embedding function based on environment variables.
|
||||||
OpenAIEmbeddingFunction,
|
|
||||||
)
|
Environment Variables:
|
||||||
|
CREWAI_EMBEDDING_PROVIDER: The embedding provider to use (default: "openai")
|
||||||
return OpenAIEmbeddingFunction(
|
CREWAI_EMBEDDING_MODEL: The model to use for embeddings
|
||||||
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
|
OPENAI_API_KEY: API key for OpenAI (required if using OpenAI provider)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
EmbeddingFunction: Configured embedding function
|
||||||
|
"""
|
||||||
|
provider = os.getenv("CREWAI_EMBEDDING_PROVIDER", "openai")
|
||||||
|
config = {
|
||||||
|
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||||
|
"model": os.getenv("CREWAI_EMBEDDING_MODEL", "text-embedding-3-small")
|
||||||
|
}
|
||||||
|
return EmbeddingConfigurator().configure_embedder(
|
||||||
|
{"provider": provider, "config": config}
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
72
tests/memory/test_memory_reset.py
Normal file
72
tests/memory/test_memory_reset.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from typing import Generator
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from chromadb import Documents, EmbeddingFunction, Embeddings
|
||||||
|
|
||||||
|
from crewai.memory import ShortTermMemory, LongTermMemory, EntityMemory
|
||||||
|
from crewai.utilities import EmbeddingConfigurator
|
||||||
|
from crewai.utilities.exceptions.embedding_exceptions import (
|
||||||
|
EmbeddingConfigurationError,
|
||||||
|
EmbeddingProviderError
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_db_dir() -> Generator[Path, None, None]:
|
||||||
|
"""Create a temporary directory for test databases."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
path = Path(tmpdir)
|
||||||
|
yield path
|
||||||
|
|
||||||
|
def test_memory_reset_with_openai(temp_db_dir):
|
||||||
|
"""Test memory reset with default OpenAI provider."""
|
||||||
|
os.environ["OPENAI_API_KEY"] = "test-key"
|
||||||
|
memory = ShortTermMemory(path=temp_db_dir)
|
||||||
|
memory.reset() # Should work with OpenAI as default
|
||||||
|
|
||||||
|
def test_memory_reset_with_ollama(temp_db_dir):
|
||||||
|
"""Test memory reset with Ollama provider."""
|
||||||
|
os.environ["CREWAI_EMBEDDING_PROVIDER"] = "ollama"
|
||||||
|
memory = ShortTermMemory(path=temp_db_dir)
|
||||||
|
memory.reset() # Should not raise any OpenAI-related errors
|
||||||
|
|
||||||
|
def test_memory_reset_with_custom_provider(temp_db_dir):
|
||||||
|
"""Test memory reset with custom embedding provider."""
|
||||||
|
class CustomEmbedder(EmbeddingFunction):
|
||||||
|
def __call__(self, input: Documents) -> Embeddings:
|
||||||
|
if isinstance(input, str):
|
||||||
|
input = [input]
|
||||||
|
return [[0.5] * 10] * len(input)
|
||||||
|
|
||||||
|
memory = ShortTermMemory(
|
||||||
|
path=temp_db_dir,
|
||||||
|
embedder_config={"provider": CustomEmbedder()}
|
||||||
|
)
|
||||||
|
memory.reset() # Should work with custom embedder
|
||||||
|
|
||||||
|
def test_memory_reset_with_invalid_provider(temp_db_dir):
|
||||||
|
"""Test memory reset with invalid provider raises appropriate error."""
|
||||||
|
os.environ["CREWAI_EMBEDDING_PROVIDER"] = "invalid_provider"
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
memory = ShortTermMemory(path=temp_db_dir)
|
||||||
|
memory.reset()
|
||||||
|
assert "Unsupported embedding provider" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_memory_reset_with_missing_api_key(temp_db_dir):
|
||||||
|
"""Test memory reset with missing API key raises appropriate error."""
|
||||||
|
os.environ.pop("OPENAI_API_KEY", None) # Ensure key is not set
|
||||||
|
os.environ["CREWAI_EMBEDDING_PROVIDER"] = "openai"
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
memory = ShortTermMemory(path=temp_db_dir)
|
||||||
|
memory.reset()
|
||||||
|
assert "api_key" in str(exc_info.value).lower()
|
||||||
|
|
||||||
|
def test_memory_reset_cleans_up_files(temp_db_dir):
|
||||||
|
"""Test that memory reset properly cleans up database files."""
|
||||||
|
memory = ShortTermMemory(path=temp_db_dir)
|
||||||
|
memory.save("test memory", {"test": "metadata"})
|
||||||
|
assert any(temp_db_dir.iterdir()) # Directory should have files
|
||||||
|
memory.reset()
|
||||||
|
assert not any(temp_db_dir.iterdir()) # Directory should be empty
|
||||||
Reference in New Issue
Block a user