Compare commits

..

3 Commits

Author SHA1 Message Date
Devin AI
cf35f5af75 Fix linting issue: Properly sort imports in test file
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-03-05 22:20:40 +00:00
Devin AI
c2245c7024 Fix linting issue: Sort imports in test file
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-03-05 22:19:03 +00:00
Devin AI
cf2a1346fd Fix issue 2288: Handle list inputs in tool_usage._validate_tool_input
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-03-05 22:17:34 +00:00
6 changed files with 141 additions and 180 deletions

View File

@@ -83,42 +83,28 @@ class KnowledgeStorage(BaseKnowledgeStorage):
raise Exception("Collection not initialized")
def initialize_knowledge_storage(self):
"""Initialize the knowledge storage with ChromaDB.
Handles SQLite3 version incompatibility gracefully by logging a warning
and continuing without ChromaDB functionality.
"""
try:
base_path = os.path.join(db_storage_path(), "knowledge")
chroma_client = chromadb.PersistentClient(
path=base_path,
settings=Settings(allow_reset=True),
)
base_path = os.path.join(db_storage_path(), "knowledge")
chroma_client = chromadb.PersistentClient(
path=base_path,
settings=Settings(allow_reset=True),
)
self.app = chroma_client
self.app = chroma_client
try:
collection_name = (
f"knowledge_{self.collection_name}"
if self.collection_name
else "knowledge"
)
if not self.app:
raise Exception("Vector Database Client not initialized")
self.collection = self.app.get_or_create_collection(
name=collection_name, embedding_function=self.embedder
)
except RuntimeError as e:
if "unsupported version of sqlite3" in str(e).lower():
# Log a warning but continue without ChromaDB
logging.warning("ChromaDB requires SQLite3 >= 3.35.0. Current version is too old. Some features may be limited. Error: %s", e)
self.app = None
self.collection = None
if self.app:
self.collection = self.app.get_or_create_collection(
name=collection_name, embedding_function=self.embedder
)
else:
raise
except Exception as e:
raise Exception(f"Failed to create or get collection: {e}")
raise Exception("Vector Database Client not initialized")
except Exception:
raise Exception("Failed to create or get collection")
def reset(self):
base_path = os.path.join(db_storage_path(), KNOWLEDGE_DIRECTORY)

View File

@@ -4,19 +4,15 @@ import logging
import os
import shutil
import uuid
from typing import Any, Dict, List, Optional, Union, Collection as TypeCollection
from typing import Any, Dict, List, Optional
from chromadb.api import ClientAPI
from chromadb.api.models.Collection import Collection as ChromaCollection
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
from crewai.utilities import EmbeddingConfigurator
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
from crewai.utilities.paths import db_storage_path
# Constants
SQLITE_VERSION_ERROR = "ChromaDB requires SQLite3 >= 3.35.0. Current version is too old. Some features may be limited. Error: {}"
@contextlib.contextmanager
def suppress_logging(
@@ -64,40 +60,25 @@ class RAGStorage(BaseRAGStorage):
self.embedder_config = configurator.configure_embedder(self.embedder_config)
def _initialize_app(self):
import chromadb
from chromadb.config import Settings
self._set_embedder_config()
chroma_client = chromadb.PersistentClient(
path=self.path if self.path else self.storage_file_name,
settings=Settings(allow_reset=self.allow_reset),
)
self.app = chroma_client
try:
import chromadb
from chromadb.config import Settings
self._set_embedder_config()
if self.embedder_config is None:
# ChromaDB is not available, skip initialization
self.app = None
self.collection = None
return
chroma_client = chromadb.PersistentClient(
path=self.path if self.path else self.storage_file_name,
settings=Settings(allow_reset=self.allow_reset),
self.collection = self.app.get_collection(
name=self.type, embedding_function=self.embedder_config
)
except Exception:
self.collection = self.app.create_collection(
name=self.type, embedding_function=self.embedder_config
)
self.app = chroma_client
try:
self.collection = self.app.get_collection(
name=self.type, embedding_function=self.embedder_config
)
except Exception:
self.collection = self.app.create_collection(
name=self.type, embedding_function=self.embedder_config
)
except RuntimeError as e:
if "unsupported version of sqlite3" in str(e).lower():
# Log a warning but continue without ChromaDB
logging.warning(SQLITE_VERSION_ERROR.format(e))
self.app = None
self.collection = None
else:
raise
def _sanitize_role(self, role: str) -> str:
"""

View File

@@ -432,7 +432,13 @@ class ToolUsage:
# Attempt 1: Parse as JSON
try:
arguments = json.loads(tool_input)
if isinstance(arguments, dict):
# Handle case where arguments is a list
if isinstance(arguments, list) and len(arguments) > 0 and isinstance(arguments[0], dict):
self._printer.print(
content=f"Tool input is a list, extracting first element: {arguments[0]}", color="blue"
)
return arguments[0]
elif isinstance(arguments, dict):
return arguments
except (JSONDecodeError, TypeError):
pass # Continue to the next parsing attempt
@@ -440,7 +446,13 @@ class ToolUsage:
# Attempt 2: Parse as Python literal
try:
arguments = ast.literal_eval(tool_input)
if isinstance(arguments, dict):
# Handle case where arguments is a list
if isinstance(arguments, list) and len(arguments) > 0 and isinstance(arguments[0], dict):
self._printer.print(
content=f"Tool input is a list, extracting first element: {arguments[0]}", color="blue"
)
return arguments[0]
elif isinstance(arguments, dict):
return arguments
except (ValueError, SyntaxError):
pass # Continue to the next parsing attempt
@@ -448,7 +460,13 @@ class ToolUsage:
# Attempt 3: Parse as JSON5
try:
arguments = json5.loads(tool_input)
if isinstance(arguments, dict):
# Handle case where arguments is a list
if isinstance(arguments, list) and len(arguments) > 0 and isinstance(arguments[0], dict):
self._printer.print(
content=f"Tool input is a list, extracting first element: {arguments[0]}", color="blue"
)
return arguments[0]
elif isinstance(arguments, dict):
return arguments
except (JSONDecodeError, ValueError, TypeError):
pass # Continue to the next parsing attempt
@@ -460,7 +478,13 @@ class ToolUsage:
content=f"Repaired JSON: {repaired_input}", color="blue"
)
arguments = json.loads(repaired_input)
if isinstance(arguments, dict):
# Handle case where arguments is a list
if isinstance(arguments, list) and len(arguments) > 0 and isinstance(arguments[0], dict):
self._printer.print(
content=f"Tool input is a list, extracting first element: {arguments[0]}", color="blue"
)
return arguments[0]
elif isinstance(arguments, dict):
return arguments
except Exception as e:
error = f"Failed to repair JSON: {e}"

View File

@@ -1,31 +1,12 @@
import logging
import os
from typing import Any, Dict, Optional, cast
# Import chromadb conditionally to handle SQLite3 version errors
try:
from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.api.types import validate_embedding_function
CHROMADB_AVAILABLE = True
except RuntimeError as e:
if "unsupported version of sqlite3" in str(e).lower():
logging.warning(f"ChromaDB requires SQLite3 >= 3.35.0. Current version is too old. Some features may be limited. Error: {e}")
CHROMADB_AVAILABLE = False
# Define placeholder types for type hints
Documents = Any
EmbeddingFunction = Any
Embeddings = Any
validate_embedding_function = lambda x: x # noqa: E731
else:
raise
from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.api.types import validate_embedding_function
class EmbeddingConfigurator:
def __init__(self):
if not CHROMADB_AVAILABLE:
self.embedding_functions = {}
return
self.embedding_functions = {
"openai": self._configure_openai,
"azure": self._configure_azure,
@@ -40,45 +21,13 @@ class EmbeddingConfigurator:
"custom": self._configure_custom,
}
def _validate_config(self, config: Dict[str, Any]) -> bool:
"""Validates that the configuration contains the required keys.
Args:
config: The configuration dictionary to validate
Returns:
bool: True if the configuration is valid, False otherwise
"""
if not config:
return False
required_keys = {'provider'}
return all(key in config for key in required_keys)
def configure_embedder(
self,
embedder_config: Optional[Dict[str, Any]] = None,
) -> Optional[EmbeddingFunction]:
"""Configures and returns an embedding function based on the provided config.
Args:
embedder_config: Configuration dictionary for the embedder
Returns:
Optional[EmbeddingFunction]: The configured embedding function or None if ChromaDB is not available
Raises:
ValueError: If the configuration is invalid
Exception: If the provider is not supported
"""
if not CHROMADB_AVAILABLE:
return None
) -> EmbeddingFunction:
"""Configures and returns an embedding function based on the provided config."""
if embedder_config is None:
return self._create_default_embedding_function()
if not self._validate_config(embedder_config):
raise ValueError("Invalid embedder configuration: missing required keys")
provider = embedder_config.get("provider")
config = embedder_config.get("config", {})
@@ -98,9 +47,6 @@ class EmbeddingConfigurator:
@staticmethod
def _create_default_embedding_function():
if not CHROMADB_AVAILABLE:
return None
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)

View File

@@ -0,0 +1,76 @@
from unittest.mock import MagicMock, patch
import pytest
from crewai.tools.tool_usage import ToolUsage
class TestToolInputValidation:
def setup_method(self):
# Create mock objects for testing
self.mock_tools_handler = MagicMock()
self.mock_tools = [MagicMock()]
self.mock_original_tools = [MagicMock()]
self.mock_tools_description = "Mock tools description"
self.mock_tools_names = "Mock tools names"
self.mock_task = MagicMock()
self.mock_function_calling_llm = MagicMock()
# Create mock agent with required string attributes
self.mock_agent = MagicMock()
self.mock_agent.key = "mock_agent_key"
self.mock_agent.role = "mock_agent_role"
self.mock_agent._original_role = "mock_original_role"
# Create mock action with required string attributes
self.mock_action = MagicMock()
self.mock_action.tool = "mock_tool_name"
self.mock_action.tool_input = "mock_tool_input"
# Create ToolUsage instance
self.tool_usage = ToolUsage(
tools_handler=self.mock_tools_handler,
tools=self.mock_tools,
original_tools=self.mock_original_tools,
tools_description=self.mock_tools_description,
tools_names=self.mock_tools_names,
task=self.mock_task,
function_calling_llm=self.mock_function_calling_llm,
agent=self.mock_agent,
action=self.mock_action,
)
# Patch the _emit_validate_input_error method to avoid event emission
self.original_emit_validate_input_error = self.tool_usage._emit_validate_input_error
self.tool_usage._emit_validate_input_error = MagicMock()
def teardown_method(self):
# Restore the original method
if hasattr(self, 'original_emit_validate_input_error'):
self.tool_usage._emit_validate_input_error = self.original_emit_validate_input_error
def test_validate_tool_input_with_dict(self):
# Test with a valid dictionary input
tool_input = '{"ticker": "VST"}'
result = self.tool_usage._validate_tool_input(tool_input)
assert result == {"ticker": "VST"}
def test_validate_tool_input_with_list(self):
# Test with a list input containing a dictionary as the first element
tool_input = '[{"ticker": "VST"}, {"tool_code": "Stock Info", "tool_input": {"ticker": "VST"}}]'
result = self.tool_usage._validate_tool_input(tool_input)
assert result == {"ticker": "VST"}
def test_validate_tool_input_with_empty_list(self):
# Test with an empty list input
tool_input = '[]'
with pytest.raises(Exception) as excinfo:
self.tool_usage._validate_tool_input(tool_input)
assert "Tool input must be a valid dictionary in JSON or Python literal format" in str(excinfo.value)
def test_validate_tool_input_with_list_of_non_dicts(self):
# Test with a list input containing non-dictionary elements
tool_input = '["not a dict", 123]'
with pytest.raises(Exception) as excinfo:
self.tool_usage._validate_tool_input(tool_input)
assert "Tool input must be a valid dictionary in JSON or Python literal format" in str(excinfo.value)

View File

@@ -1,52 +0,0 @@
import unittest
from unittest.mock import MagicMock, patch
class TestEmbeddingConfigurator(unittest.TestCase):
@patch('crewai.utilities.embedding_configurator.CHROMADB_AVAILABLE', False)
def test_embedding_configurator_with_chromadb_unavailable(self):
from crewai.utilities.embedding_configurator import EmbeddingConfigurator
# Create an instance of EmbeddingConfigurator
configurator = EmbeddingConfigurator()
# Verify that embedding_functions is empty
self.assertEqual(configurator.embedding_functions, {})
# Verify that configure_embedder returns None
self.assertIsNone(configurator.configure_embedder())
@patch('crewai.utilities.embedding_configurator.CHROMADB_AVAILABLE', True)
def test_embedding_configurator_with_chromadb_available(self):
from crewai.utilities.embedding_configurator import EmbeddingConfigurator
# Create an instance of EmbeddingConfigurator
configurator = EmbeddingConfigurator()
# Verify that embedding_functions is not empty
self.assertNotEqual(configurator.embedding_functions, {})
# Mock the _create_default_embedding_function method
configurator._create_default_embedding_function = MagicMock(return_value="mock_embedding_function")
# Verify that configure_embedder returns the mock embedding function
self.assertEqual(configurator.configure_embedder(), "mock_embedding_function")
@patch('crewai.utilities.embedding_configurator.CHROMADB_AVAILABLE', True)
def test_embedding_configurator_with_invalid_config(self):
from crewai.utilities.embedding_configurator import EmbeddingConfigurator
# Create an instance of EmbeddingConfigurator
configurator = EmbeddingConfigurator()
# Test with empty config
with self.assertRaises(ValueError):
configurator.configure_embedder({})
# Test with missing required keys
with self.assertRaises(ValueError):
configurator.configure_embedder({"config": {}})
# Test with unsupported provider
with self.assertRaises(Exception):
configurator.configure_embedder({"provider": "unsupported_provider", "config": {}})