Fix RagTool validation and qdrant import issues (#3918)

- Fix issue #3: Make similarity_threshold and limit optional with defaults
  - Updated BaseTool._default_args_schema to extract default values from function signatures
  - Updated BaseTool._set_args_schema for consistency
  - Now uses inspect.signature() and pydantic.create_model() to properly handle defaults

- Fix issue #1: Make qdrant_client import conditional
  - Moved qdrant_client imports to TYPE_CHECKING block
  - Added runtime imports only when QdrantConfig is actually used
  - Prevents import errors when using ChromaDB without qdrant_client installed

- Added comprehensive tests:
  - test_rag_tool_default_parameters_are_optional: Verifies schema has correct required fields
  - test_rag_tool_chromadb_no_qdrant_import: Ensures ChromaDB usage doesn't import qdrant_client

All existing tests pass. Fixes reported issues where LLM calls would fail with ValidationError
on first attempt due to required fields that should have been optional.

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
Devin AI
2025-11-14 15:09:34 +00:00
parent d7bdac12a2
commit 62d12543f2
3 changed files with 129 additions and 34 deletions

View File

@@ -4,6 +4,7 @@ from pathlib import Path
from tempfile import TemporaryDirectory
from typing import cast
from unittest.mock import MagicMock, Mock, patch
import sys
from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
from crewai_tools.tools.rag.rag_tool import RagTool
@@ -176,3 +177,69 @@ def test_rag_tool_no_results(
result = tool._run(query="Non-existent content")
assert "Relevant Content:" in result
assert "No relevant content found" in result
@patch("crewai_tools.adapters.crewai_rag_adapter.get_rag_client")
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_rag_tool_default_parameters_are_optional(
mock_create_client: Mock, mock_get_rag_client: Mock
) -> None:
"""Test that similarity_threshold and limit parameters have defaults and are optional."""
mock_client = MagicMock()
mock_client.get_or_create_collection = MagicMock(return_value=None)
mock_client.search = MagicMock(
return_value=[{"content": "Test content", "metadata": {}, "score": 0.9}]
)
mock_get_rag_client.return_value = mock_client
mock_create_client.return_value = mock_client
class MyTool(RagTool):
pass
tool = MyTool()
schema = tool.args_schema.model_json_schema()
required_fields = schema.get("required", [])
assert "query" in required_fields, "query should be required"
assert "similarity_threshold" not in required_fields, "similarity_threshold should be optional"
assert "limit" not in required_fields, "limit should be optional"
properties = schema.get("properties", {})
assert "query" in properties
assert "similarity_threshold" in properties
assert "limit" in properties
result = tool._run(query="Test query")
assert "Relevant Content:" in result
assert "Test content" in result
@patch("crewai_tools.adapters.crewai_rag_adapter.get_rag_client")
def test_rag_tool_chromadb_no_qdrant_import(mock_get_rag_client: Mock) -> None:
"""Test that using ChromaDB config does not import qdrant_client."""
mock_client = MagicMock()
mock_client.get_or_create_collection = MagicMock(return_value=None)
mock_get_rag_client.return_value = mock_client
original_modules = sys.modules.copy()
if "qdrant_client" in sys.modules:
del sys.modules["qdrant_client"]
if "qdrant_client.models" in sys.modules:
del sys.modules["qdrant_client.models"]
sys.modules["qdrant_client"] = None
sys.modules["qdrant_client.models"] = None
try:
class MyTool(RagTool):
pass
tool = MyTool()
assert tool.adapter is not None
assert isinstance(tool.adapter, CrewAIRagAdapter)
finally:
sys.modules.clear()
sys.modules.update(original_modules)