mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 07:38:29 +00:00
Compare commits
3 Commits
devin/1741
...
devin/1741
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cf35f5af75 | ||
|
|
c2245c7024 | ||
|
|
cf2a1346fd |
@@ -83,42 +83,28 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
raise Exception("Collection not initialized")
|
||||
|
||||
def initialize_knowledge_storage(self):
|
||||
"""Initialize the knowledge storage with ChromaDB.
|
||||
|
||||
Handles SQLite3 version incompatibility gracefully by logging a warning
|
||||
and continuing without ChromaDB functionality.
|
||||
"""
|
||||
try:
|
||||
base_path = os.path.join(db_storage_path(), "knowledge")
|
||||
chroma_client = chromadb.PersistentClient(
|
||||
path=base_path,
|
||||
settings=Settings(allow_reset=True),
|
||||
)
|
||||
base_path = os.path.join(db_storage_path(), "knowledge")
|
||||
chroma_client = chromadb.PersistentClient(
|
||||
path=base_path,
|
||||
settings=Settings(allow_reset=True),
|
||||
)
|
||||
|
||||
self.app = chroma_client
|
||||
|
||||
self.app = chroma_client
|
||||
|
||||
try:
|
||||
collection_name = (
|
||||
f"knowledge_{self.collection_name}"
|
||||
if self.collection_name
|
||||
else "knowledge"
|
||||
)
|
||||
|
||||
if not self.app:
|
||||
raise Exception("Vector Database Client not initialized")
|
||||
|
||||
self.collection = self.app.get_or_create_collection(
|
||||
name=collection_name, embedding_function=self.embedder
|
||||
)
|
||||
except RuntimeError as e:
|
||||
if "unsupported version of sqlite3" in str(e).lower():
|
||||
# Log a warning but continue without ChromaDB
|
||||
logging.warning("ChromaDB requires SQLite3 >= 3.35.0. Current version is too old. Some features may be limited. Error: %s", e)
|
||||
self.app = None
|
||||
self.collection = None
|
||||
if self.app:
|
||||
self.collection = self.app.get_or_create_collection(
|
||||
name=collection_name, embedding_function=self.embedder
|
||||
)
|
||||
else:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to create or get collection: {e}")
|
||||
raise Exception("Vector Database Client not initialized")
|
||||
except Exception:
|
||||
raise Exception("Failed to create or get collection")
|
||||
|
||||
def reset(self):
|
||||
base_path = os.path.join(db_storage_path(), KNOWLEDGE_DIRECTORY)
|
||||
|
||||
@@ -4,19 +4,15 @@ import logging
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Union, Collection as TypeCollection
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from chromadb.api import ClientAPI
|
||||
from chromadb.api.models.Collection import Collection as ChromaCollection
|
||||
|
||||
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
||||
from crewai.utilities import EmbeddingConfigurator
|
||||
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
# Constants
|
||||
SQLITE_VERSION_ERROR = "ChromaDB requires SQLite3 >= 3.35.0. Current version is too old. Some features may be limited. Error: {}"
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def suppress_logging(
|
||||
@@ -64,40 +60,25 @@ class RAGStorage(BaseRAGStorage):
|
||||
self.embedder_config = configurator.configure_embedder(self.embedder_config)
|
||||
|
||||
def _initialize_app(self):
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
|
||||
self._set_embedder_config()
|
||||
chroma_client = chromadb.PersistentClient(
|
||||
path=self.path if self.path else self.storage_file_name,
|
||||
settings=Settings(allow_reset=self.allow_reset),
|
||||
)
|
||||
|
||||
self.app = chroma_client
|
||||
|
||||
try:
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
|
||||
self._set_embedder_config()
|
||||
if self.embedder_config is None:
|
||||
# ChromaDB is not available, skip initialization
|
||||
self.app = None
|
||||
self.collection = None
|
||||
return
|
||||
|
||||
chroma_client = chromadb.PersistentClient(
|
||||
path=self.path if self.path else self.storage_file_name,
|
||||
settings=Settings(allow_reset=self.allow_reset),
|
||||
self.collection = self.app.get_collection(
|
||||
name=self.type, embedding_function=self.embedder_config
|
||||
)
|
||||
except Exception:
|
||||
self.collection = self.app.create_collection(
|
||||
name=self.type, embedding_function=self.embedder_config
|
||||
)
|
||||
|
||||
self.app = chroma_client
|
||||
|
||||
try:
|
||||
self.collection = self.app.get_collection(
|
||||
name=self.type, embedding_function=self.embedder_config
|
||||
)
|
||||
except Exception:
|
||||
self.collection = self.app.create_collection(
|
||||
name=self.type, embedding_function=self.embedder_config
|
||||
)
|
||||
except RuntimeError as e:
|
||||
if "unsupported version of sqlite3" in str(e).lower():
|
||||
# Log a warning but continue without ChromaDB
|
||||
logging.warning(SQLITE_VERSION_ERROR.format(e))
|
||||
self.app = None
|
||||
self.collection = None
|
||||
else:
|
||||
raise
|
||||
|
||||
def _sanitize_role(self, role: str) -> str:
|
||||
"""
|
||||
|
||||
@@ -432,7 +432,13 @@ class ToolUsage:
|
||||
# Attempt 1: Parse as JSON
|
||||
try:
|
||||
arguments = json.loads(tool_input)
|
||||
if isinstance(arguments, dict):
|
||||
# Handle case where arguments is a list
|
||||
if isinstance(arguments, list) and len(arguments) > 0 and isinstance(arguments[0], dict):
|
||||
self._printer.print(
|
||||
content=f"Tool input is a list, extracting first element: {arguments[0]}", color="blue"
|
||||
)
|
||||
return arguments[0]
|
||||
elif isinstance(arguments, dict):
|
||||
return arguments
|
||||
except (JSONDecodeError, TypeError):
|
||||
pass # Continue to the next parsing attempt
|
||||
@@ -440,7 +446,13 @@ class ToolUsage:
|
||||
# Attempt 2: Parse as Python literal
|
||||
try:
|
||||
arguments = ast.literal_eval(tool_input)
|
||||
if isinstance(arguments, dict):
|
||||
# Handle case where arguments is a list
|
||||
if isinstance(arguments, list) and len(arguments) > 0 and isinstance(arguments[0], dict):
|
||||
self._printer.print(
|
||||
content=f"Tool input is a list, extracting first element: {arguments[0]}", color="blue"
|
||||
)
|
||||
return arguments[0]
|
||||
elif isinstance(arguments, dict):
|
||||
return arguments
|
||||
except (ValueError, SyntaxError):
|
||||
pass # Continue to the next parsing attempt
|
||||
@@ -448,7 +460,13 @@ class ToolUsage:
|
||||
# Attempt 3: Parse as JSON5
|
||||
try:
|
||||
arguments = json5.loads(tool_input)
|
||||
if isinstance(arguments, dict):
|
||||
# Handle case where arguments is a list
|
||||
if isinstance(arguments, list) and len(arguments) > 0 and isinstance(arguments[0], dict):
|
||||
self._printer.print(
|
||||
content=f"Tool input is a list, extracting first element: {arguments[0]}", color="blue"
|
||||
)
|
||||
return arguments[0]
|
||||
elif isinstance(arguments, dict):
|
||||
return arguments
|
||||
except (JSONDecodeError, ValueError, TypeError):
|
||||
pass # Continue to the next parsing attempt
|
||||
@@ -460,7 +478,13 @@ class ToolUsage:
|
||||
content=f"Repaired JSON: {repaired_input}", color="blue"
|
||||
)
|
||||
arguments = json.loads(repaired_input)
|
||||
if isinstance(arguments, dict):
|
||||
# Handle case where arguments is a list
|
||||
if isinstance(arguments, list) and len(arguments) > 0 and isinstance(arguments[0], dict):
|
||||
self._printer.print(
|
||||
content=f"Tool input is a list, extracting first element: {arguments[0]}", color="blue"
|
||||
)
|
||||
return arguments[0]
|
||||
elif isinstance(arguments, dict):
|
||||
return arguments
|
||||
except Exception as e:
|
||||
error = f"Failed to repair JSON: {e}"
|
||||
|
||||
@@ -1,31 +1,12 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, Optional, cast
|
||||
|
||||
# Import chromadb conditionally to handle SQLite3 version errors
|
||||
try:
|
||||
from chromadb import Documents, EmbeddingFunction, Embeddings
|
||||
from chromadb.api.types import validate_embedding_function
|
||||
CHROMADB_AVAILABLE = True
|
||||
except RuntimeError as e:
|
||||
if "unsupported version of sqlite3" in str(e).lower():
|
||||
logging.warning(f"ChromaDB requires SQLite3 >= 3.35.0. Current version is too old. Some features may be limited. Error: {e}")
|
||||
CHROMADB_AVAILABLE = False
|
||||
# Define placeholder types for type hints
|
||||
Documents = Any
|
||||
EmbeddingFunction = Any
|
||||
Embeddings = Any
|
||||
validate_embedding_function = lambda x: x # noqa: E731
|
||||
else:
|
||||
raise
|
||||
from chromadb import Documents, EmbeddingFunction, Embeddings
|
||||
from chromadb.api.types import validate_embedding_function
|
||||
|
||||
|
||||
class EmbeddingConfigurator:
|
||||
def __init__(self):
|
||||
if not CHROMADB_AVAILABLE:
|
||||
self.embedding_functions = {}
|
||||
return
|
||||
|
||||
self.embedding_functions = {
|
||||
"openai": self._configure_openai,
|
||||
"azure": self._configure_azure,
|
||||
@@ -40,45 +21,13 @@ class EmbeddingConfigurator:
|
||||
"custom": self._configure_custom,
|
||||
}
|
||||
|
||||
def _validate_config(self, config: Dict[str, Any]) -> bool:
|
||||
"""Validates that the configuration contains the required keys.
|
||||
|
||||
Args:
|
||||
config: The configuration dictionary to validate
|
||||
|
||||
Returns:
|
||||
bool: True if the configuration is valid, False otherwise
|
||||
"""
|
||||
if not config:
|
||||
return False
|
||||
|
||||
required_keys = {'provider'}
|
||||
return all(key in config for key in required_keys)
|
||||
|
||||
def configure_embedder(
|
||||
self,
|
||||
embedder_config: Optional[Dict[str, Any]] = None,
|
||||
) -> Optional[EmbeddingFunction]:
|
||||
"""Configures and returns an embedding function based on the provided config.
|
||||
|
||||
Args:
|
||||
embedder_config: Configuration dictionary for the embedder
|
||||
|
||||
Returns:
|
||||
Optional[EmbeddingFunction]: The configured embedding function or None if ChromaDB is not available
|
||||
|
||||
Raises:
|
||||
ValueError: If the configuration is invalid
|
||||
Exception: If the provider is not supported
|
||||
"""
|
||||
if not CHROMADB_AVAILABLE:
|
||||
return None
|
||||
|
||||
) -> EmbeddingFunction:
|
||||
"""Configures and returns an embedding function based on the provided config."""
|
||||
if embedder_config is None:
|
||||
return self._create_default_embedding_function()
|
||||
|
||||
if not self._validate_config(embedder_config):
|
||||
raise ValueError("Invalid embedder configuration: missing required keys")
|
||||
|
||||
provider = embedder_config.get("provider")
|
||||
config = embedder_config.get("config", {})
|
||||
@@ -98,9 +47,6 @@ class EmbeddingConfigurator:
|
||||
|
||||
@staticmethod
|
||||
def _create_default_embedding_function():
|
||||
if not CHROMADB_AVAILABLE:
|
||||
return None
|
||||
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
|
||||
76
tests/tools/test_tool_input_validation.py
Normal file
76
tests/tools/test_tool_input_validation.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.tools.tool_usage import ToolUsage
|
||||
|
||||
|
||||
class TestToolInputValidation:
    """Exercise ``ToolUsage._validate_tool_input`` with dict and list payloads."""

    # Error text the failure-path tests expect ``_validate_tool_input`` to raise.
    INVALID_INPUT_MESSAGE = (
        "Tool input must be a valid dictionary in JSON or Python literal format"
    )

    def setup_method(self):
        """Build a ``ToolUsage`` wired entirely to mocks before each test."""
        # Create mock objects for testing
        self.mock_tools_handler = MagicMock()
        self.mock_tools = [MagicMock()]
        self.mock_original_tools = [MagicMock()]
        self.mock_tools_description = "Mock tools description"
        self.mock_tools_names = "Mock tools names"
        self.mock_task = MagicMock()
        self.mock_function_calling_llm = MagicMock()

        # Create mock agent with required string attributes
        self.mock_agent = MagicMock()
        self.mock_agent.key = "mock_agent_key"
        self.mock_agent.role = "mock_agent_role"
        self.mock_agent._original_role = "mock_original_role"

        # Create mock action with required string attributes
        self.mock_action = MagicMock()
        self.mock_action.tool = "mock_tool_name"
        self.mock_action.tool_input = "mock_tool_input"

        # Create ToolUsage instance
        self.tool_usage = ToolUsage(
            tools_handler=self.mock_tools_handler,
            tools=self.mock_tools,
            original_tools=self.mock_original_tools,
            tools_description=self.mock_tools_description,
            tools_names=self.mock_tools_names,
            task=self.mock_task,
            function_calling_llm=self.mock_function_calling_llm,
            agent=self.mock_agent,
            action=self.mock_action,
        )

        # Patch the _emit_validate_input_error method to avoid event emission
        self.original_emit_validate_input_error = (
            self.tool_usage._emit_validate_input_error
        )
        self.tool_usage._emit_validate_input_error = MagicMock()

    def teardown_method(self):
        """Put back the real ``_emit_validate_input_error`` after each test."""
        # Guarded because setup_method may have failed before saving the original.
        if hasattr(self, "original_emit_validate_input_error"):
            self.tool_usage._emit_validate_input_error = (
                self.original_emit_validate_input_error
            )

    def test_validate_tool_input_with_dict(self):
        """A JSON object string is returned as the parsed dictionary."""
        tool_input = '{"ticker": "VST"}'
        result = self.tool_usage._validate_tool_input(tool_input)
        assert result == {"ticker": "VST"}

    def test_validate_tool_input_with_list(self):
        """A list whose first element is a dict yields that first dict."""
        tool_input = (
            '[{"ticker": "VST"}, '
            '{"tool_code": "Stock Info", "tool_input": {"ticker": "VST"}}]'
        )
        result = self.tool_usage._validate_tool_input(tool_input)
        assert result == {"ticker": "VST"}

    def test_validate_tool_input_with_empty_list(self):
        """An empty list has no dict to extract and must be rejected."""
        tool_input = "[]"
        with pytest.raises(Exception, match=self.INVALID_INPUT_MESSAGE):
            self.tool_usage._validate_tool_input(tool_input)

    def test_validate_tool_input_with_list_of_non_dicts(self):
        """A list with no dictionary as its first element must be rejected."""
        tool_input = '["not a dict", 123]'
        with pytest.raises(Exception, match=self.INVALID_INPUT_MESSAGE):
            self.tool_usage._validate_tool_input(tool_input)
|
||||
@@ -1,52 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
class TestEmbeddingConfigurator(unittest.TestCase):
    """Check ``EmbeddingConfigurator`` with ChromaDB flagged available or not."""

    @patch('crewai.utilities.embedding_configurator.CHROMADB_AVAILABLE', False)
    def test_embedding_configurator_with_chromadb_unavailable(self):
        """With the flag off, no providers are registered and None is returned."""
        from crewai.utilities.embedding_configurator import EmbeddingConfigurator

        # Create an instance of EmbeddingConfigurator
        configurator = EmbeddingConfigurator()

        # Verify that embedding_functions is empty
        self.assertEqual(configurator.embedding_functions, {})

        # Verify that configure_embedder returns None
        self.assertIsNone(configurator.configure_embedder())

    @patch('crewai.utilities.embedding_configurator.CHROMADB_AVAILABLE', True)
    def test_embedding_configurator_with_chromadb_available(self):
        """With the flag on, providers exist and the default embedder is used."""
        from crewai.utilities.embedding_configurator import EmbeddingConfigurator

        # Create an instance of EmbeddingConfigurator
        configurator = EmbeddingConfigurator()

        # Verify that embedding_functions is not empty
        self.assertNotEqual(configurator.embedding_functions, {})

        # Mock the _create_default_embedding_function method
        configurator._create_default_embedding_function = MagicMock(
            return_value="mock_embedding_function"
        )

        # Verify that configure_embedder returns the mock embedding function
        self.assertEqual(
            configurator.configure_embedder(), "mock_embedding_function"
        )

    @patch('crewai.utilities.embedding_configurator.CHROMADB_AVAILABLE', True)
    def test_embedding_configurator_with_invalid_config(self):
        """Invalid or unsupported configurations are rejected with an error."""
        from crewai.utilities.embedding_configurator import EmbeddingConfigurator

        # Create an instance of EmbeddingConfigurator
        configurator = EmbeddingConfigurator()

        # Empty config, then config missing the required keys. Each case runs
        # as its own subtest so one failure does not hide the other.
        for invalid_config in ({}, {"config": {}}):
            with self.subTest(embedder_config=invalid_config):
                with self.assertRaises(ValueError):
                    configurator.configure_embedder(invalid_config)

        # Test with unsupported provider
        with self.assertRaises(Exception):
            configurator.configure_embedder(
                {"provider": "unsupported_provider", "config": {}}
            )
|
||||
Reference in New Issue
Block a user