mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 16:18:30 +00:00
Fix RagTool validation and qdrant import issues (#3918)
- Fix issue #3: Make similarity_threshold and limit optional with defaults - Updated BaseTool._default_args_schema to extract default values from function signatures - Updated BaseTool._set_args_schema for consistency - Now uses inspect.signature() and pydantic.create_model() to properly handle defaults - Fix issue #1: Make qdrant_client import conditional - Moved qdrant_client imports to TYPE_CHECKING block - Added runtime imports only when QdrantConfig is actually used - Prevents import errors when using ChromaDB without qdrant_client installed - Added comprehensive tests: - test_rag_tool_default_parameters_are_optional: Verifies schema has correct required fields - test_rag_tool_chromadb_no_qdrant_import: Ensures ChromaDB usage doesn't import qdrant_client All existing tests pass. Fixes reported issues where LLM calls would fail with ValidationError on first attempt due to required fields that should have been optional. Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -4,6 +4,7 @@ from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
import sys
|
||||
|
||||
from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
|
||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||
@@ -176,3 +177,69 @@ def test_rag_tool_no_results(
|
||||
result = tool._run(query="Non-existent content")
|
||||
assert "Relevant Content:" in result
|
||||
assert "No relevant content found" in result
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.get_rag_client")
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
|
||||
def test_rag_tool_default_parameters_are_optional(
|
||||
mock_create_client: Mock, mock_get_rag_client: Mock
|
||||
) -> None:
|
||||
"""Test that similarity_threshold and limit parameters have defaults and are optional."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||
mock_client.search = MagicMock(
|
||||
return_value=[{"content": "Test content", "metadata": {}, "score": 0.9}]
|
||||
)
|
||||
mock_get_rag_client.return_value = mock_client
|
||||
mock_create_client.return_value = mock_client
|
||||
|
||||
class MyTool(RagTool):
|
||||
pass
|
||||
|
||||
tool = MyTool()
|
||||
|
||||
schema = tool.args_schema.model_json_schema()
|
||||
required_fields = schema.get("required", [])
|
||||
|
||||
assert "query" in required_fields, "query should be required"
|
||||
assert "similarity_threshold" not in required_fields, "similarity_threshold should be optional"
|
||||
assert "limit" not in required_fields, "limit should be optional"
|
||||
|
||||
properties = schema.get("properties", {})
|
||||
assert "query" in properties
|
||||
assert "similarity_threshold" in properties
|
||||
assert "limit" in properties
|
||||
|
||||
result = tool._run(query="Test query")
|
||||
assert "Relevant Content:" in result
|
||||
assert "Test content" in result
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.get_rag_client")
|
||||
def test_rag_tool_chromadb_no_qdrant_import(mock_get_rag_client: Mock) -> None:
|
||||
"""Test that using ChromaDB config does not import qdrant_client."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||
mock_get_rag_client.return_value = mock_client
|
||||
|
||||
original_modules = sys.modules.copy()
|
||||
|
||||
if "qdrant_client" in sys.modules:
|
||||
del sys.modules["qdrant_client"]
|
||||
if "qdrant_client.models" in sys.modules:
|
||||
del sys.modules["qdrant_client.models"]
|
||||
|
||||
sys.modules["qdrant_client"] = None
|
||||
sys.modules["qdrant_client.models"] = None
|
||||
|
||||
try:
|
||||
class MyTool(RagTool):
|
||||
pass
|
||||
|
||||
tool = MyTool()
|
||||
|
||||
assert tool.adapter is not None
|
||||
assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||
finally:
|
||||
sys.modules.clear()
|
||||
sys.modules.update(original_modules)
|
||||
|
||||
Reference in New Issue
Block a user