mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 08:38:30 +00:00
Address PR comments: Improve code quality and add validation
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -83,6 +83,11 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
raise Exception("Collection not initialized")
|
raise Exception("Collection not initialized")
|
||||||
|
|
||||||
def initialize_knowledge_storage(self):
|
def initialize_knowledge_storage(self):
|
||||||
|
"""Initialize the knowledge storage with ChromaDB.
|
||||||
|
|
||||||
|
Handles SQLite3 version incompatibility gracefully by logging a warning
|
||||||
|
and continuing without ChromaDB functionality.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
base_path = os.path.join(db_storage_path(), "knowledge")
|
base_path = os.path.join(db_storage_path(), "knowledge")
|
||||||
chroma_client = chromadb.PersistentClient(
|
chroma_client = chromadb.PersistentClient(
|
||||||
@@ -91,29 +96,29 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.app = chroma_client
|
self.app = chroma_client
|
||||||
|
|
||||||
try:
|
collection_name = (
|
||||||
collection_name = (
|
f"knowledge_{self.collection_name}"
|
||||||
f"knowledge_{self.collection_name}"
|
if self.collection_name
|
||||||
if self.collection_name
|
else "knowledge"
|
||||||
else "knowledge"
|
)
|
||||||
)
|
|
||||||
if self.app:
|
if not self.app:
|
||||||
self.collection = self.app.get_or_create_collection(
|
raise Exception("Vector Database Client not initialized")
|
||||||
name=collection_name, embedding_function=self.embedder
|
|
||||||
)
|
self.collection = self.app.get_or_create_collection(
|
||||||
else:
|
name=collection_name, embedding_function=self.embedder
|
||||||
raise Exception("Vector Database Client not initialized")
|
)
|
||||||
except Exception:
|
|
||||||
raise Exception("Failed to create or get collection")
|
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
if "unsupported version of sqlite3" in str(e).lower():
|
if "unsupported version of sqlite3" in str(e).lower():
|
||||||
# Log a warning but continue without ChromaDB
|
# Log a warning but continue without ChromaDB
|
||||||
logging.warning(f"ChromaDB requires SQLite3 >= 3.35.0. Current version is too old. Some features may be limited. Error: {e}")
|
logging.warning("ChromaDB requires SQLite3 >= 3.35.0. Current version is too old. Some features may be limited. Error: %s", e)
|
||||||
self.app = None
|
self.app = None
|
||||||
self.collection = None
|
self.collection = None
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Failed to create or get collection: {e}")
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
base_path = os.path.join(db_storage_path(), KNOWLEDGE_DIRECTORY)
|
base_path = os.path.join(db_storage_path(), KNOWLEDGE_DIRECTORY)
|
||||||
|
|||||||
@@ -4,15 +4,19 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, Union, Collection
|
||||||
|
|
||||||
from chromadb.api import ClientAPI
|
from chromadb.api import ClientAPI
|
||||||
|
from chromadb.api.models.Collection import Collection
|
||||||
|
|
||||||
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
||||||
from crewai.utilities import EmbeddingConfigurator
|
from crewai.utilities import EmbeddingConfigurator
|
||||||
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
|
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
|
||||||
from crewai.utilities.paths import db_storage_path
|
from crewai.utilities.paths import db_storage_path
|
||||||
|
|
||||||
|
# Constants
|
||||||
|
SQLITE_VERSION_ERROR = "ChromaDB requires SQLite3 >= 3.35.0. Current version is too old. Some features may be limited. Error: {}"
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def suppress_logging(
|
def suppress_logging(
|
||||||
@@ -89,7 +93,7 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
if "unsupported version of sqlite3" in str(e).lower():
|
if "unsupported version of sqlite3" in str(e).lower():
|
||||||
# Log a warning but continue without ChromaDB
|
# Log a warning but continue without ChromaDB
|
||||||
logging.warning(f"ChromaDB requires SQLite3 >= 3.35.0. Current version is too old. Some features may be limited. Error: {e}")
|
logging.warning(SQLITE_VERSION_ERROR.format(e))
|
||||||
self.app = None
|
self.app = None
|
||||||
self.collection = None
|
self.collection = None
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -40,16 +40,45 @@ class EmbeddingConfigurator:
|
|||||||
"custom": self._configure_custom,
|
"custom": self._configure_custom,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _validate_config(self, config: Dict[str, Any]) -> bool:
|
||||||
|
"""Validates that the configuration contains the required keys.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: The configuration dictionary to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the configuration is valid, False otherwise
|
||||||
|
"""
|
||||||
|
if not config:
|
||||||
|
return False
|
||||||
|
|
||||||
|
required_keys = {'provider'}
|
||||||
|
return all(key in config for key in required_keys)
|
||||||
|
|
||||||
def configure_embedder(
|
def configure_embedder(
|
||||||
self,
|
self,
|
||||||
embedder_config: Optional[Dict[str, Any]] = None,
|
embedder_config: Optional[Dict[str, Any]] = None,
|
||||||
) -> Optional[EmbeddingFunction]:
|
) -> Optional[EmbeddingFunction]:
|
||||||
"""Configures and returns an embedding function based on the provided config."""
|
"""Configures and returns an embedding function based on the provided config.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embedder_config: Configuration dictionary for the embedder
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[EmbeddingFunction]: The configured embedding function or None if ChromaDB is not available
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the configuration is invalid
|
||||||
|
Exception: If the provider is not supported
|
||||||
|
"""
|
||||||
if not CHROMADB_AVAILABLE:
|
if not CHROMADB_AVAILABLE:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if embedder_config is None:
|
if embedder_config is None:
|
||||||
return self._create_default_embedding_function()
|
return self._create_default_embedding_function()
|
||||||
|
|
||||||
|
if not self._validate_config(embedder_config):
|
||||||
|
raise ValueError("Invalid embedder configuration: missing required keys")
|
||||||
|
|
||||||
provider = embedder_config.get("provider")
|
provider = embedder_config.get("provider")
|
||||||
config = embedder_config.get("config", {})
|
config = embedder_config.get("config", {})
|
||||||
|
|||||||
@@ -31,3 +31,22 @@ class TestEmbeddingConfigurator(unittest.TestCase):
|
|||||||
|
|
||||||
# Verify that configure_embedder returns the mock embedding function
|
# Verify that configure_embedder returns the mock embedding function
|
||||||
self.assertEqual(configurator.configure_embedder(), "mock_embedding_function")
|
self.assertEqual(configurator.configure_embedder(), "mock_embedding_function")
|
||||||
|
|
||||||
|
@patch('crewai.utilities.embedding_configurator.CHROMADB_AVAILABLE', True)
|
||||||
|
def test_embedding_configurator_with_invalid_config(self):
|
||||||
|
from crewai.utilities.embedding_configurator import EmbeddingConfigurator
|
||||||
|
|
||||||
|
# Create an instance of EmbeddingConfigurator
|
||||||
|
configurator = EmbeddingConfigurator()
|
||||||
|
|
||||||
|
# Test with empty config
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
configurator.configure_embedder({})
|
||||||
|
|
||||||
|
# Test with missing required keys
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
configurator.configure_embedder({"config": {}})
|
||||||
|
|
||||||
|
# Test with unsupported provider
|
||||||
|
with self.assertRaises(Exception):
|
||||||
|
configurator.configure_embedder({"provider": "unsupported_provider", "config": {}})
|
||||||
|
|||||||
Reference in New Issue
Block a user