diff --git a/src/crewai/crew.py b/src/crewai/crew.py
index d488783ea..16fc70430 100644
--- a/src/crewai/crew.py
+++ b/src/crewai/crew.py
@@ -1,10 +1,11 @@
 import asyncio
 import json
+import re
 import uuid
 import warnings
 from concurrent.futures import Future
 from hashlib import md5
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 
 from pydantic import (
     UUID4,
diff --git a/src/crewai/memory/storage/rag_storage.py b/src/crewai/memory/storage/rag_storage.py
index fd4c77838..dc50bb823 100644
--- a/src/crewai/memory/storage/rag_storage.py
+++ b/src/crewai/memory/storage/rag_storage.py
@@ -163,12 +163,3 @@ class RAGStorage(BaseRAGStorage):
                 raise Exception(
                     f"An error occurred while resetting the {self.type} memory: {e}"
                 )
-
-    def _create_default_embedding_function(self):
-        from chromadb.utils.embedding_functions.openai_embedding_function import (
-            OpenAIEmbeddingFunction,
-        )
-
-        return OpenAIEmbeddingFunction(
-            api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
-        )
diff --git a/src/crewai/utilities/embedding_configurator.py b/src/crewai/utilities/embedding_configurator.py
index 44e832ec2..bceddffc4 100644
--- a/src/crewai/utilities/embedding_configurator.py
+++ b/src/crewai/utilities/embedding_configurator.py
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Dict, cast
+from typing import Any, Dict, Optional, cast
 
 from chromadb import Documents, EmbeddingFunction, Embeddings
 from chromadb.api.types import validate_embedding_function
@@ -21,9 +21,36 @@ class EmbeddingConfigurator:
 
     def configure_embedder(
         self,
-        embedder_config: Dict[str, Any] | None = None,
+        embedder_config: Optional[Dict[str, Any]] = None,
     ) -> EmbeddingFunction:
-        """Configures and returns an embedding function based on the provided config."""
+        """Configure and return an embedding function based on the provided config.
+
+        Args:
+            embedder_config: Optional configuration dictionary containing:
+                - provider: Name of the embedding provider or EmbeddingFunction instance
+                - config: Provider-specific configuration dictionary with options like:
+                    - api_key: API key for the provider
+                    - model: Model name to use for embeddings
+                    - url: API endpoint URL (for some providers)
+                    - session: Session object (for some providers)
+
+        Returns:
+            EmbeddingFunction: Configured embedding function for the specified provider
+
+        Raises:
+            ValueError: If custom embedding function is invalid
+            Exception: If provider is not supported or configuration is invalid
+
+        Examples:
+            >>> config = {
+            ...     "provider": "openai",
+            ...     "config": {
+            ...         "api_key": "your-api-key",
+            ...         "model": "text-embedding-3-small"
+            ...     }
+            ... }
+            >>> embedder = EmbeddingConfigurator().configure_embedder(config)
+        """
         if embedder_config is None:
             return self._create_default_embedding_function()
 
@@ -47,12 +74,29 @@ class EmbeddingConfigurator:
 
     @staticmethod
     def _create_default_embedding_function():
-        from chromadb.utils.embedding_functions.openai_embedding_function import (
-            OpenAIEmbeddingFunction,
-        )
-
-        return OpenAIEmbeddingFunction(
-            api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
+        """Create a default embedding function based on environment variables.
+
+        Environment Variables:
+            CREWAI_EMBEDDING_PROVIDER: The embedding provider to use (default: "openai")
+            CREWAI_EMBEDDING_MODEL: The model to use for embeddings (default:
+                "text-embedding-3-small" for the OpenAI provider; no default otherwise)
+            OPENAI_API_KEY: API key for OpenAI (only read when the provider is "openai")
+
+        Returns:
+            EmbeddingFunction: Configured embedding function
+        """
+        provider = os.getenv("CREWAI_EMBEDDING_PROVIDER", "openai")
+        model = os.getenv("CREWAI_EMBEDDING_MODEL")
+        config: Dict[str, Any] = {}
+        if provider == "openai":
+            # OpenAI-only defaults: never forward the OpenAI key or an OpenAI
+            # model name to a different provider.
+            config["api_key"] = os.getenv("OPENAI_API_KEY")
+            model = model or "text-embedding-3-small"
+        if model:
+            config["model"] = model
+        return EmbeddingConfigurator().configure_embedder(
+            {"provider": provider, "config": config}
         )
 
     @staticmethod
diff --git a/tests/memory/test_memory_reset.py b/tests/memory/test_memory_reset.py
new file mode 100644
index 000000000..4ef8115bd
--- /dev/null
+++ b/tests/memory/test_memory_reset.py
@@ -0,0 +1,85 @@
+import tempfile
+from pathlib import Path
+from typing import Generator
+
+import pytest
+from chromadb import Documents, EmbeddingFunction, Embeddings
+
+from crewai.memory import ShortTermMemory
+
+
+class CustomEmbedder(EmbeddingFunction):
+    """Deterministic in-process embedder so tests never call a real provider."""
+
+    def __call__(self, input: Documents) -> Embeddings:
+        if isinstance(input, str):
+            input = [input]
+        return [[0.5] * 10] * len(input)
+
+
+@pytest.fixture
+def temp_db_dir() -> Generator[Path, None, None]:
+    """Create a temporary directory for test databases."""
+    with tempfile.TemporaryDirectory() as tmpdir:
+        yield Path(tmpdir)
+
+
+@pytest.fixture(autouse=True)
+def clean_embedding_env(monkeypatch):
+    """Start every test without embedding settings inherited from the shell."""
+    monkeypatch.delenv("CREWAI_EMBEDDING_PROVIDER", raising=False)
+    monkeypatch.delenv("CREWAI_EMBEDDING_MODEL", raising=False)
+
+
+def test_memory_reset_with_openai(temp_db_dir, monkeypatch):
+    """Test memory reset with default OpenAI provider."""
+    monkeypatch.setenv("OPENAI_API_KEY", "test-key")
+    memory = ShortTermMemory(path=temp_db_dir)
+    memory.reset()  # Should work with OpenAI as default
+
+
+def test_memory_reset_with_ollama(temp_db_dir, monkeypatch):
+    """Test memory reset with Ollama provider."""
+    monkeypatch.setenv("CREWAI_EMBEDDING_PROVIDER", "ollama")
+    memory = ShortTermMemory(path=temp_db_dir)
+    memory.reset()  # Should not raise any OpenAI-related errors
+
+
+def test_memory_reset_with_custom_provider(temp_db_dir):
+    """Test memory reset with custom embedding provider."""
+    memory = ShortTermMemory(
+        path=temp_db_dir,
+        embedder_config={"provider": CustomEmbedder()},
+    )
+    memory.reset()  # Should work with custom embedder
+
+
+def test_memory_reset_with_invalid_provider(temp_db_dir, monkeypatch):
+    """Test memory reset with invalid provider raises appropriate error."""
+    monkeypatch.setenv("CREWAI_EMBEDDING_PROVIDER", "invalid_provider")
+    with pytest.raises(Exception) as exc_info:
+        memory = ShortTermMemory(path=temp_db_dir)
+        memory.reset()
+    assert "Unsupported embedding provider" in str(exc_info.value)
+
+
+def test_memory_reset_with_missing_api_key(temp_db_dir, monkeypatch):
+    """Test memory reset with missing API key raises appropriate error."""
+    monkeypatch.delenv("OPENAI_API_KEY", raising=False)  # Ensure key is not set
+    monkeypatch.setenv("CREWAI_EMBEDDING_PROVIDER", "openai")
+    with pytest.raises(Exception) as exc_info:
+        memory = ShortTermMemory(path=temp_db_dir)
+        memory.reset()
+    assert "api_key" in str(exc_info.value).lower()
+
+
+def test_memory_reset_cleans_up_files(temp_db_dir):
+    """Test that memory reset properly cleans up database files."""
+    memory = ShortTermMemory(
+        path=temp_db_dir,
+        embedder_config={"provider": CustomEmbedder()},
+    )
+    memory.save("test memory", {"test": "metadata"})
+    assert any(temp_db_dir.iterdir())  # Directory should have files
+    memory.reset()
+    assert not any(temp_db_dir.iterdir())  # Directory should be empty