mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-16 04:18:35 +00:00
Fix RagTool validation and qdrant import issues (#3918)
- Fix issue #3: Make similarity_threshold and limit optional with defaults - Updated BaseTool._default_args_schema to extract default values from function signatures - Updated BaseTool._set_args_schema for consistency - Now uses inspect.signature() and pydantic.create_model() to properly handle defaults - Fix issue #1: Make qdrant_client import conditional - Moved qdrant_client imports to TYPE_CHECKING block - Added runtime imports only when QdrantConfig is actually used - Prevents import errors when using ChromaDB without qdrant_client installed - Added comprehensive tests: - test_rag_tool_default_parameters_are_optional: Verifies schema has correct required fields - test_rag_tool_chromadb_no_qdrant_import: Ensures ChromaDB usage doesn't import qdrant_client All existing tests pass. Fixes reported issues where LLM calls would fail with ValidationError on first attempt due to required fields that should have been optional. Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -2,17 +2,15 @@
|
||||
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeAlias, TypedDict
|
||||
from typing import TYPE_CHECKING, Any, TypeAlias, TypedDict
|
||||
import uuid
|
||||
|
||||
from crewai.rag.config.types import RagConfigType
|
||||
from crewai.rag.config.utils import get_rag_client
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.factory import create_client
|
||||
from crewai.rag.qdrant.config import QdrantConfig
|
||||
from crewai.rag.types import BaseRecord, SearchResult
|
||||
from pydantic import PrivateAttr
|
||||
from qdrant_client.models import VectorParams
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
@@ -20,6 +18,10 @@ from crewai_tools.rag.misc import sanitize_metadata_for_chromadb
|
||||
from crewai_tools.tools.rag.rag_tool import Adapter
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
ContentItem: TypeAlias = str | Path | dict[str, Any]
|
||||
|
||||
|
||||
@@ -56,9 +58,16 @@ class CrewAIRagAdapter(Adapter):
|
||||
else:
|
||||
self._client = get_rag_client()
|
||||
collection_params: dict[str, Any] = {"collection_name": self.collection_name}
|
||||
if isinstance(self.config, QdrantConfig) and self.config.vectors_config:
|
||||
if isinstance(self.config.vectors_config, VectorParams):
|
||||
collection_params["vectors_config"] = self.config.vectors_config
|
||||
|
||||
if self.config is not None:
|
||||
from crewai.rag.qdrant.config import QdrantConfig
|
||||
|
||||
if isinstance(self.config, QdrantConfig) and self.config.vectors_config:
|
||||
from qdrant_client.models import VectorParams
|
||||
|
||||
if isinstance(self.config.vectors_config, VectorParams):
|
||||
collection_params["vectors_config"] = self.config.vectors_config
|
||||
|
||||
self._client.get_or_create_collection(**collection_params)
|
||||
|
||||
def query(
|
||||
|
||||
@@ -4,6 +4,7 @@ from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
import sys
|
||||
|
||||
from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
|
||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||
@@ -176,3 +177,69 @@ def test_rag_tool_no_results(
|
||||
result = tool._run(query="Non-existent content")
|
||||
assert "Relevant Content:" in result
|
||||
assert "No relevant content found" in result
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.get_rag_client")
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
|
||||
def test_rag_tool_default_parameters_are_optional(
|
||||
mock_create_client: Mock, mock_get_rag_client: Mock
|
||||
) -> None:
|
||||
"""Test that similarity_threshold and limit parameters have defaults and are optional."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||
mock_client.search = MagicMock(
|
||||
return_value=[{"content": "Test content", "metadata": {}, "score": 0.9}]
|
||||
)
|
||||
mock_get_rag_client.return_value = mock_client
|
||||
mock_create_client.return_value = mock_client
|
||||
|
||||
class MyTool(RagTool):
|
||||
pass
|
||||
|
||||
tool = MyTool()
|
||||
|
||||
schema = tool.args_schema.model_json_schema()
|
||||
required_fields = schema.get("required", [])
|
||||
|
||||
assert "query" in required_fields, "query should be required"
|
||||
assert "similarity_threshold" not in required_fields, "similarity_threshold should be optional"
|
||||
assert "limit" not in required_fields, "limit should be optional"
|
||||
|
||||
properties = schema.get("properties", {})
|
||||
assert "query" in properties
|
||||
assert "similarity_threshold" in properties
|
||||
assert "limit" in properties
|
||||
|
||||
result = tool._run(query="Test query")
|
||||
assert "Relevant Content:" in result
|
||||
assert "Test content" in result
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.get_rag_client")
|
||||
def test_rag_tool_chromadb_no_qdrant_import(mock_get_rag_client: Mock) -> None:
|
||||
"""Test that using ChromaDB config does not import qdrant_client."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||
mock_get_rag_client.return_value = mock_client
|
||||
|
||||
original_modules = sys.modules.copy()
|
||||
|
||||
if "qdrant_client" in sys.modules:
|
||||
del sys.modules["qdrant_client"]
|
||||
if "qdrant_client.models" in sys.modules:
|
||||
del sys.modules["qdrant_client.models"]
|
||||
|
||||
sys.modules["qdrant_client"] = None
|
||||
sys.modules["qdrant_client.models"] = None
|
||||
|
||||
try:
|
||||
class MyTool(RagTool):
|
||||
pass
|
||||
|
||||
tool = MyTool()
|
||||
|
||||
assert tool.adapter is not None
|
||||
assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||
finally:
|
||||
sys.modules.clear()
|
||||
sys.modules.update(original_modules)
|
||||
|
||||
@@ -80,20 +80,30 @@ class BaseTool(BaseModel, ABC):
|
||||
if v != cls._ArgsSchemaPlaceholder:
|
||||
return v
|
||||
|
||||
return cast(
|
||||
type[PydanticBaseModel],
|
||||
type(
|
||||
f"{cls.__name__}Schema",
|
||||
(PydanticBaseModel,),
|
||||
{
|
||||
"__annotations__": {
|
||||
k: v
|
||||
for k, v in cls._run.__annotations__.items()
|
||||
if k != "return"
|
||||
},
|
||||
},
|
||||
),
|
||||
)
|
||||
# Extract both annotations and defaults from the _run method signature
|
||||
sig = signature(cls._run)
|
||||
fields: dict[str, Any] = {}
|
||||
|
||||
for param_name, param in sig.parameters.items():
|
||||
if param_name == "self":
|
||||
continue
|
||||
|
||||
annotation = cls._run.__annotations__.get(param_name, Any)
|
||||
|
||||
if param.default is not param.empty:
|
||||
default = param.default
|
||||
else:
|
||||
default = ...
|
||||
|
||||
fields[param_name] = (annotation, default)
|
||||
|
||||
if not fields:
|
||||
return cast(
|
||||
type[PydanticBaseModel],
|
||||
type(f"{cls.__name__}Schema", (PydanticBaseModel,), {}),
|
||||
)
|
||||
|
||||
return create_model(f"{cls.__name__}Schema", **fields)
|
||||
|
||||
@field_validator("max_usage_count", mode="before")
|
||||
@classmethod
|
||||
@@ -196,20 +206,29 @@ class BaseTool(BaseModel, ABC):
|
||||
def _set_args_schema(self) -> None:
|
||||
if self.args_schema is None:
|
||||
class_name = f"{self.__class__.__name__}Schema"
|
||||
self.args_schema = cast(
|
||||
type[PydanticBaseModel],
|
||||
type(
|
||||
class_name,
|
||||
(PydanticBaseModel,),
|
||||
{
|
||||
"__annotations__": {
|
||||
k: v
|
||||
for k, v in self._run.__annotations__.items()
|
||||
if k != "return"
|
||||
},
|
||||
},
|
||||
),
|
||||
)
|
||||
sig = signature(self._run)
|
||||
fields: dict[str, Any] = {}
|
||||
|
||||
for param_name, param in sig.parameters.items():
|
||||
if param_name == "self":
|
||||
continue
|
||||
|
||||
annotation = self._run.__annotations__.get(param_name, Any)
|
||||
|
||||
if param.default is not param.empty:
|
||||
default = param.default
|
||||
else:
|
||||
default = ...
|
||||
|
||||
fields[param_name] = (annotation, default)
|
||||
|
||||
if not fields:
|
||||
self.args_schema = cast(
|
||||
type[PydanticBaseModel],
|
||||
type(class_name, (PydanticBaseModel,), {}),
|
||||
)
|
||||
else:
|
||||
self.args_schema = create_model(class_name, **fields)
|
||||
|
||||
def _generate_description(self) -> None:
|
||||
args_schema = {
|
||||
|
||||
Reference in New Issue
Block a user