mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-16 04:18:35 +00:00
Fix RagTool validation and qdrant import issues (#3918)
- Fix issue #3: Make similarity_threshold and limit optional with defaults
  - Updated BaseTool._default_args_schema to extract default values from function signatures
  - Updated BaseTool._set_args_schema for consistency
  - Now uses inspect.signature() and pydantic.create_model() to properly handle defaults
- Fix issue #1: Make qdrant_client import conditional
  - Moved qdrant_client imports to TYPE_CHECKING block
  - Added runtime imports only when QdrantConfig is actually used
  - Prevents import errors when using ChromaDB without qdrant_client installed
- Added comprehensive tests:
  - test_rag_tool_default_parameters_are_optional: Verifies schema has correct required fields
  - test_rag_tool_chromadb_no_qdrant_import: Ensures ChromaDB usage doesn't import qdrant_client

All existing tests pass. Fixes reported issues where LLM calls would fail with ValidationError on first attempt due to required fields that should have been optional.

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -2,17 +2,15 @@
|
|||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, TypeAlias, TypedDict
|
from typing import TYPE_CHECKING, Any, TypeAlias, TypedDict
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from crewai.rag.config.types import RagConfigType
|
from crewai.rag.config.types import RagConfigType
|
||||||
from crewai.rag.config.utils import get_rag_client
|
from crewai.rag.config.utils import get_rag_client
|
||||||
from crewai.rag.core.base_client import BaseClient
|
from crewai.rag.core.base_client import BaseClient
|
||||||
from crewai.rag.factory import create_client
|
from crewai.rag.factory import create_client
|
||||||
from crewai.rag.qdrant.config import QdrantConfig
|
|
||||||
from crewai.rag.types import BaseRecord, SearchResult
|
from crewai.rag.types import BaseRecord, SearchResult
|
||||||
from pydantic import PrivateAttr
|
from pydantic import PrivateAttr
|
||||||
from qdrant_client.models import VectorParams
|
|
||||||
from typing_extensions import Unpack
|
from typing_extensions import Unpack
|
||||||
|
|
||||||
from crewai_tools.rag.data_types import DataType
|
from crewai_tools.rag.data_types import DataType
|
||||||
@@ -20,6 +18,10 @@ from crewai_tools.rag.misc import sanitize_metadata_for_chromadb
|
|||||||
from crewai_tools.tools.rag.rag_tool import Adapter
|
from crewai_tools.tools.rag.rag_tool import Adapter
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
ContentItem: TypeAlias = str | Path | dict[str, Any]
|
ContentItem: TypeAlias = str | Path | dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
@@ -56,9 +58,16 @@ class CrewAIRagAdapter(Adapter):
|
|||||||
else:
|
else:
|
||||||
self._client = get_rag_client()
|
self._client = get_rag_client()
|
||||||
collection_params: dict[str, Any] = {"collection_name": self.collection_name}
|
collection_params: dict[str, Any] = {"collection_name": self.collection_name}
|
||||||
if isinstance(self.config, QdrantConfig) and self.config.vectors_config:
|
|
||||||
if isinstance(self.config.vectors_config, VectorParams):
|
if self.config is not None:
|
||||||
collection_params["vectors_config"] = self.config.vectors_config
|
from crewai.rag.qdrant.config import QdrantConfig
|
||||||
|
|
||||||
|
if isinstance(self.config, QdrantConfig) and self.config.vectors_config:
|
||||||
|
from qdrant_client.models import VectorParams
|
||||||
|
|
||||||
|
if isinstance(self.config.vectors_config, VectorParams):
|
||||||
|
collection_params["vectors_config"] = self.config.vectors_config
|
||||||
|
|
||||||
self._client.get_or_create_collection(**collection_params)
|
self._client.get_or_create_collection(**collection_params)
|
||||||
|
|
||||||
def query(
|
def query(
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from pathlib import Path
|
|||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
from typing import cast
|
from typing import cast
|
||||||
from unittest.mock import MagicMock, Mock, patch
|
from unittest.mock import MagicMock, Mock, patch
|
||||||
|
import sys
|
||||||
|
|
||||||
from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
|
from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
|
||||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||||
@@ -176,3 +177,69 @@ def test_rag_tool_no_results(
|
|||||||
result = tool._run(query="Non-existent content")
|
result = tool._run(query="Non-existent content")
|
||||||
assert "Relevant Content:" in result
|
assert "Relevant Content:" in result
|
||||||
assert "No relevant content found" in result
|
assert "No relevant content found" in result
|
||||||
|
|
||||||
|
|
||||||
|
@patch("crewai_tools.adapters.crewai_rag_adapter.get_rag_client")
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_rag_tool_default_parameters_are_optional(
    mock_create_client: Mock, mock_get_rag_client: Mock
) -> None:
    """Check that only ``query`` is required in the generated args schema.

    ``similarity_threshold`` and ``limit`` must still be listed as schema
    properties, but not as required fields, and ``_run`` must succeed when
    only ``query`` is supplied.
    """
    # One stand-in client serves both construction paths patched above.
    search_hit = {"content": "Test content", "metadata": {}, "score": 0.9}
    fake_client = MagicMock()
    fake_client.get_or_create_collection = MagicMock(return_value=None)
    fake_client.search = MagicMock(return_value=[search_hit])
    for client_factory in (mock_get_rag_client, mock_create_client):
        client_factory.return_value = fake_client

    class MyTool(RagTool):
        pass

    tool = MyTool()
    json_schema = tool.args_schema.model_json_schema()

    # Required-ness: only the query itself is mandatory.
    required_fields = json_schema.get("required", [])
    assert "query" in required_fields, "query should be required"
    assert "similarity_threshold" not in required_fields, "similarity_threshold should be optional"
    assert "limit" not in required_fields, "limit should be optional"

    # All three parameters must nevertheless be advertised in the schema.
    properties = json_schema.get("properties", {})
    for field_name in ("query", "similarity_threshold", "limit"):
        assert field_name in properties

    # Calling with the query alone must work and surface the mocked hit.
    output = tool._run(query="Test query")
    assert "Relevant Content:" in output
    assert "Test content" in output
|
||||||
|
|
||||||
|
|
||||||
|
@patch("crewai_tools.adapters.crewai_rag_adapter.get_rag_client")
def test_rag_tool_chromadb_no_qdrant_import(mock_get_rag_client: Mock) -> None:
    """Test that building a RagTool works when qdrant_client cannot be imported.

    The qdrant modules are blocked in ``sys.modules`` for the duration of the
    test, so any attempt to import them while the tool and its adapter are
    constructed raises ``ImportError`` and fails the test.
    """
    mock_client = MagicMock()
    mock_client.get_or_create_collection = MagicMock(return_value=None)
    mock_get_rag_client.return_value = mock_client

    # A ``None`` entry in sys.modules makes ``import <name>`` raise
    # ImportError, which simulates qdrant_client not being installed.
    blocked_modules = {"qdrant_client": None, "qdrant_client.models": None}

    # patch.dict snapshots sys.modules on entry and restores it on exit,
    # even if an assertion fails, so no manual copy/clear/update is needed.
    with patch.dict(sys.modules, blocked_modules):

        class MyTool(RagTool):
            pass

        tool = MyTool()

        assert tool.adapter is not None
        assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||||
|
|||||||
@@ -80,20 +80,30 @@ class BaseTool(BaseModel, ABC):
|
|||||||
if v != cls._ArgsSchemaPlaceholder:
|
if v != cls._ArgsSchemaPlaceholder:
|
||||||
return v
|
return v
|
||||||
|
|
||||||
return cast(
|
# Extract both annotations and defaults from the _run method signature
|
||||||
type[PydanticBaseModel],
|
sig = signature(cls._run)
|
||||||
type(
|
fields: dict[str, Any] = {}
|
||||||
f"{cls.__name__}Schema",
|
|
||||||
(PydanticBaseModel,),
|
for param_name, param in sig.parameters.items():
|
||||||
{
|
if param_name == "self":
|
||||||
"__annotations__": {
|
continue
|
||||||
k: v
|
|
||||||
for k, v in cls._run.__annotations__.items()
|
annotation = cls._run.__annotations__.get(param_name, Any)
|
||||||
if k != "return"
|
|
||||||
},
|
if param.default is not param.empty:
|
||||||
},
|
default = param.default
|
||||||
),
|
else:
|
||||||
)
|
default = ...
|
||||||
|
|
||||||
|
fields[param_name] = (annotation, default)
|
||||||
|
|
||||||
|
if not fields:
|
||||||
|
return cast(
|
||||||
|
type[PydanticBaseModel],
|
||||||
|
type(f"{cls.__name__}Schema", (PydanticBaseModel,), {}),
|
||||||
|
)
|
||||||
|
|
||||||
|
return create_model(f"{cls.__name__}Schema", **fields)
|
||||||
|
|
||||||
@field_validator("max_usage_count", mode="before")
|
@field_validator("max_usage_count", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -196,20 +206,29 @@ class BaseTool(BaseModel, ABC):
|
|||||||
def _set_args_schema(self) -> None:
|
def _set_args_schema(self) -> None:
|
||||||
if self.args_schema is None:
|
if self.args_schema is None:
|
||||||
class_name = f"{self.__class__.__name__}Schema"
|
class_name = f"{self.__class__.__name__}Schema"
|
||||||
self.args_schema = cast(
|
sig = signature(self._run)
|
||||||
type[PydanticBaseModel],
|
fields: dict[str, Any] = {}
|
||||||
type(
|
|
||||||
class_name,
|
for param_name, param in sig.parameters.items():
|
||||||
(PydanticBaseModel,),
|
if param_name == "self":
|
||||||
{
|
continue
|
||||||
"__annotations__": {
|
|
||||||
k: v
|
annotation = self._run.__annotations__.get(param_name, Any)
|
||||||
for k, v in self._run.__annotations__.items()
|
|
||||||
if k != "return"
|
if param.default is not param.empty:
|
||||||
},
|
default = param.default
|
||||||
},
|
else:
|
||||||
),
|
default = ...
|
||||||
)
|
|
||||||
|
fields[param_name] = (annotation, default)
|
||||||
|
|
||||||
|
if not fields:
|
||||||
|
self.args_schema = cast(
|
||||||
|
type[PydanticBaseModel],
|
||||||
|
type(class_name, (PydanticBaseModel,), {}),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.args_schema = create_model(class_name, **fields)
|
||||||
|
|
||||||
def _generate_description(self) -> None:
|
def _generate_description(self) -> None:
|
||||||
args_schema = {
|
args_schema = {
|
||||||
|
|||||||
Reference in New Issue
Block a user