From 62d12543f2b83f7a66d631c8ea80325358a5dfad Mon Sep 17 00:00:00 2001
From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Date: Fri, 14 Nov 2025 15:09:34 +0000
Subject: [PATCH] Fix RagTool validation and qdrant import issues (#3918)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Fix issue #3: Make similarity_threshold and limit optional with defaults
  - Updated BaseTool._default_args_schema to extract default values from function signatures
  - Updated BaseTool._set_args_schema for consistency
  - Now uses inspect.signature() and pydantic.create_model() to properly handle defaults

- Fix issue #1: Make qdrant_client import conditional
  - Moved the qdrant_client imports out of module scope (the added TYPE_CHECKING block is an empty placeholder)
  - Added runtime imports only when QdrantConfig is actually used
  - Prevents import errors when using ChromaDB without qdrant_client installed

- Added comprehensive tests:
  - test_rag_tool_default_parameters_are_optional: Verifies schema has correct required fields
  - test_rag_tool_chromadb_no_qdrant_import: Ensures ChromaDB usage doesn't import qdrant_client

All existing tests pass. Fixes reported issues where LLM calls would fail with
ValidationError on first attempt due to required fields that should have been optional.

Co-Authored-By: João --- .../adapters/crewai_rag_adapter.py | 21 ++++-- .../tests/tools/rag/rag_tool_test.py | 67 +++++++++++++++++ lib/crewai/src/crewai/tools/base_tool.py | 75 ++++++++++++------- 3 files changed, 129 insertions(+), 34 deletions(-) diff --git a/lib/crewai-tools/src/crewai_tools/adapters/crewai_rag_adapter.py b/lib/crewai-tools/src/crewai_tools/adapters/crewai_rag_adapter.py index 9716ca4e9..08230d864 100644 --- a/lib/crewai-tools/src/crewai_tools/adapters/crewai_rag_adapter.py +++ b/lib/crewai-tools/src/crewai_tools/adapters/crewai_rag_adapter.py @@ -2,17 +2,15 @@ import hashlib from pathlib import Path -from typing import Any, TypeAlias, TypedDict +from typing import TYPE_CHECKING, Any, TypeAlias, TypedDict import uuid from crewai.rag.config.types import RagConfigType from crewai.rag.config.utils import get_rag_client from crewai.rag.core.base_client import BaseClient from crewai.rag.factory import create_client -from crewai.rag.qdrant.config import QdrantConfig from crewai.rag.types import BaseRecord, SearchResult from pydantic import PrivateAttr -from qdrant_client.models import VectorParams from typing_extensions import Unpack from crewai_tools.rag.data_types import DataType @@ -20,6 +18,10 @@ from crewai_tools.rag.misc import sanitize_metadata_for_chromadb from crewai_tools.tools.rag.rag_tool import Adapter +if TYPE_CHECKING: + pass + + ContentItem: TypeAlias = str | Path | dict[str, Any] @@ -56,9 +58,16 @@ class CrewAIRagAdapter(Adapter): else: self._client = get_rag_client() collection_params: dict[str, Any] = {"collection_name": self.collection_name} - if isinstance(self.config, QdrantConfig) and self.config.vectors_config: - if isinstance(self.config.vectors_config, VectorParams): - collection_params["vectors_config"] = self.config.vectors_config + + if self.config is not None: + from crewai.rag.qdrant.config import QdrantConfig + + if isinstance(self.config, QdrantConfig) and self.config.vectors_config: + from qdrant_client.models import 
VectorParams + + if isinstance(self.config.vectors_config, VectorParams): + collection_params["vectors_config"] = self.config.vectors_config + self._client.get_or_create_collection(**collection_params) def query( diff --git a/lib/crewai-tools/tests/tools/rag/rag_tool_test.py b/lib/crewai-tools/tests/tools/rag/rag_tool_test.py index 5298ce1e2..77b8ad9c8 100644 --- a/lib/crewai-tools/tests/tools/rag/rag_tool_test.py +++ b/lib/crewai-tools/tests/tools/rag/rag_tool_test.py @@ -4,6 +4,7 @@ from pathlib import Path from tempfile import TemporaryDirectory from typing import cast from unittest.mock import MagicMock, Mock, patch +import sys from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter from crewai_tools.tools.rag.rag_tool import RagTool @@ -176,3 +177,69 @@ def test_rag_tool_no_results( result = tool._run(query="Non-existent content") assert "Relevant Content:" in result assert "No relevant content found" in result + + +@patch("crewai_tools.adapters.crewai_rag_adapter.get_rag_client") +@patch("crewai_tools.adapters.crewai_rag_adapter.create_client") +def test_rag_tool_default_parameters_are_optional( + mock_create_client: Mock, mock_get_rag_client: Mock +) -> None: + """Test that similarity_threshold and limit parameters have defaults and are optional.""" + mock_client = MagicMock() + mock_client.get_or_create_collection = MagicMock(return_value=None) + mock_client.search = MagicMock( + return_value=[{"content": "Test content", "metadata": {}, "score": 0.9}] + ) + mock_get_rag_client.return_value = mock_client + mock_create_client.return_value = mock_client + + class MyTool(RagTool): + pass + + tool = MyTool() + + schema = tool.args_schema.model_json_schema() + required_fields = schema.get("required", []) + + assert "query" in required_fields, "query should be required" + assert "similarity_threshold" not in required_fields, "similarity_threshold should be optional" + assert "limit" not in required_fields, "limit should be optional" + + properties = 
schema.get("properties", {}) + assert "query" in properties + assert "similarity_threshold" in properties + assert "limit" in properties + + result = tool._run(query="Test query") + assert "Relevant Content:" in result + assert "Test content" in result + + +@patch("crewai_tools.adapters.crewai_rag_adapter.get_rag_client") +def test_rag_tool_chromadb_no_qdrant_import(mock_get_rag_client: Mock) -> None: + """Test that using ChromaDB config does not import qdrant_client.""" + mock_client = MagicMock() + mock_client.get_or_create_collection = MagicMock(return_value=None) + mock_get_rag_client.return_value = mock_client + + original_modules = sys.modules.copy() + + if "qdrant_client" in sys.modules: + del sys.modules["qdrant_client"] + if "qdrant_client.models" in sys.modules: + del sys.modules["qdrant_client.models"] + + sys.modules["qdrant_client"] = None + sys.modules["qdrant_client.models"] = None + + try: + class MyTool(RagTool): + pass + + tool = MyTool() + + assert tool.adapter is not None + assert isinstance(tool.adapter, CrewAIRagAdapter) + finally: + sys.modules.clear() + sys.modules.update(original_modules) diff --git a/lib/crewai/src/crewai/tools/base_tool.py b/lib/crewai/src/crewai/tools/base_tool.py index 19ed6b671..34fd45ded 100644 --- a/lib/crewai/src/crewai/tools/base_tool.py +++ b/lib/crewai/src/crewai/tools/base_tool.py @@ -80,20 +80,30 @@ class BaseTool(BaseModel, ABC): if v != cls._ArgsSchemaPlaceholder: return v - return cast( - type[PydanticBaseModel], - type( - f"{cls.__name__}Schema", - (PydanticBaseModel,), - { - "__annotations__": { - k: v - for k, v in cls._run.__annotations__.items() - if k != "return" - }, - }, - ), - ) + # Extract both annotations and defaults from the _run method signature + sig = signature(cls._run) + fields: dict[str, Any] = {} + + for param_name, param in sig.parameters.items(): + if param_name == "self": + continue + + annotation = cls._run.__annotations__.get(param_name, Any) + + if param.default is not param.empty: 
+ default = param.default + else: + default = ... + + fields[param_name] = (annotation, default) + + if not fields: + return cast( + type[PydanticBaseModel], + type(f"{cls.__name__}Schema", (PydanticBaseModel,), {}), + ) + + return create_model(f"{cls.__name__}Schema", **fields) @field_validator("max_usage_count", mode="before") @classmethod @@ -196,20 +206,29 @@ class BaseTool(BaseModel, ABC): def _set_args_schema(self) -> None: if self.args_schema is None: class_name = f"{self.__class__.__name__}Schema" - self.args_schema = cast( - type[PydanticBaseModel], - type( - class_name, - (PydanticBaseModel,), - { - "__annotations__": { - k: v - for k, v in self._run.__annotations__.items() - if k != "return" - }, - }, - ), - ) + sig = signature(self._run) + fields: dict[str, Any] = {} + + for param_name, param in sig.parameters.items(): + if param_name == "self": + continue + + annotation = self._run.__annotations__.get(param_name, Any) + + if param.default is not param.empty: + default = param.default + else: + default = ... + + fields[param_name] = (annotation, default) + + if not fields: + self.args_schema = cast( + type[PydanticBaseModel], + type(class_name, (PydanticBaseModel,), {}), + ) + else: + self.args_schema = create_model(class_name, **fields) def _generate_description(self) -> None: args_schema = {