feat: improvements to QdrantVectorSearchTool

* Implement improvements to QdrantVectorSearchTool

- Allow search filters to be set at the constructor level (see the usage sketch below)
- Fix issue that prevented multiple records from being returned
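
A minimal usage sketch of the constructor-level filters, based on the QdrantConfig model introduced in this change. The import path, URL, API key, collection name, and filter values below are illustrative placeholders, not part of this commit:

    from crewai_tools import QdrantConfig, QdrantVectorSearchTool  # import path assumed

    # Filters set at construction time apply to every search this tool performs.
    tool = QdrantVectorSearchTool(
        qdrant_config=QdrantConfig(
            qdrant_url="https://qdrant.example.com",            # placeholder
            qdrant_api_key="YOUR_QDRANT_API_KEY",               # placeholder
            collection_name="internal_docs",                    # placeholder
            limit=5,
            filter_conditions=[("department", "engineering")],  # constructor-level filter
        ),
    )

    # Per-call filters are appended to the constructor-level conditions;
    # calling _run directly here only for illustration (agents normally invoke the tool).
    print(tool._run(query="onboarding checklist", filter_by="doc_type", filter_value="policy"))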

---------

Co-authored-by: Greyson LaLonde <greyson.r.lalonde@gmail.com>
Commit: 2ee27efca7 (parent: f6e13eb890)
Author: Daniel Barreto
Date: 2025-10-21 17:50:08 -03:00
Committed by: GitHub

@@ -1,80 +1,42 @@
-from collections.abc import Callable
+from __future__ import annotations
+
+import importlib
 import json
 import os
+from collections.abc import Callable
 from typing import Any
 
-try:
-    from qdrant_client import QdrantClient
-    from qdrant_client.http.models import FieldCondition, Filter, MatchValue
-
-    QDRANT_AVAILABLE = True
-except ImportError:
-    QDRANT_AVAILABLE = False
-    QdrantClient = Any  # type: ignore[assignment,misc] # type placeholder
-    Filter = Any  # type: ignore[assignment,misc]
-    FieldCondition = Any  # type: ignore[assignment,misc]
-    MatchValue = Any  # type: ignore[assignment,misc]
-
 from crewai.tools import BaseTool, EnvVar
-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field, model_validator
+from pydantic.types import ImportString
 
 
 class QdrantToolSchema(BaseModel):
-    """Input for QdrantTool."""
-
-    query: str = Field(
-        ...,
-        description="The query to search retrieve relevant information from the Qdrant database. Pass only the query, not the question.",
-    )
-    filter_by: str | None = Field(
-        default=None,
-        description="Filter by properties. Pass only the properties, not the question.",
-    )
-    filter_value: str | None = Field(
-        default=None,
-        description="Filter by value. Pass only the value, not the question.",
-    )
+    query: str = Field(..., description="Query to search in Qdrant DB.")
+    filter_by: str | None = None
+    filter_value: str | None = None
+
+
+class QdrantConfig(BaseModel):
+    """All Qdrant connection and search settings."""
+
+    qdrant_url: str
+    qdrant_api_key: str | None = None
+    collection_name: str
+    limit: int = 3
+    score_threshold: float = 0.35
+    filter_conditions: list[tuple[str, Any]] = Field(default_factory=list)
 
 
 class QdrantVectorSearchTool(BaseTool):
-    """Tool to query and filter results from a Qdrant database.
-
-    This tool enables vector similarity search on internal documents stored in Qdrant,
-    with optional filtering capabilities.
-
-    Attributes:
-        client: Configured QdrantClient instance
-        collection_name: Name of the Qdrant collection to search
-        limit: Maximum number of results to return
-        score_threshold: Minimum similarity score threshold
-        qdrant_url: Qdrant server URL
-        qdrant_api_key: Authentication key for Qdrant
-    """
+    """Vector search tool for Qdrant."""
 
     model_config = ConfigDict(arbitrary_types_allowed=True)
-    client: QdrantClient = None  # type: ignore[assignment]
+
+    # --- Metadata ---
     name: str = "QdrantVectorSearchTool"
-    description: str = "A tool to search the Qdrant database for relevant information on internal documents."
+    description: str = "Search Qdrant vector DB for relevant documents."
    args_schema: type[BaseModel] = QdrantToolSchema
-    query: str | None = None
-    filter_by: str | None = None
-    filter_value: str | None = None
-    collection_name: str | None = None
-    limit: int | None = Field(default=3)
-    score_threshold: float = Field(default=0.35)
-    qdrant_url: str = Field(
-        ...,
-        description="The URL of the Qdrant server",
-    )
-    qdrant_api_key: str | None = Field(
-        default=None,
-        description="The API key for the Qdrant server",
-    )
-    custom_embedding_fn: Callable | None = Field(
-        default=None,
-        description="A custom embedding function to use for vectorization. If not provided, the default model will be used.",
-    )
+
     package_dependencies: list[str] = Field(default_factory=lambda: ["qdrant-client"])
     env_vars: list[EnvVar] = Field(
         default_factory=lambda: [
@@ -83,107 +45,81 @@ class QdrantVectorSearchTool(BaseTool):
             )
         ]
     )
+    qdrant_config: QdrantConfig
+    qdrant_package: ImportString[Any] = Field(
+        default="qdrant_client",
+        description="Base package path for Qdrant. Will dynamically import client and models.",
+    )
+    custom_embedding_fn: ImportString[Callable[[str], list[float]]] | None = Field(
+        default=None,
+        description="Optional embedding function or import path.",
+    )
+    client: Any | None = None
 
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        if QDRANT_AVAILABLE:
-            self.client = QdrantClient(
-                url=self.qdrant_url,
-                api_key=self.qdrant_api_key if self.qdrant_api_key else None,
+    @model_validator(mode="after")
+    def _setup_qdrant(self) -> QdrantVectorSearchTool:
+        # Import the qdrant_package if it's a string
+        if isinstance(self.qdrant_package, str):
+            self.qdrant_package = importlib.import_module(self.qdrant_package)
+
+        if not self.client:
+            self.client = self.qdrant_package.QdrantClient(
+                url=self.qdrant_config.qdrant_url,
+                api_key=self.qdrant_config.qdrant_api_key or None,
             )
-        else:
-            import click
-
-            if click.confirm(
-                "The 'qdrant-client' package is required to use the QdrantVectorSearchTool. "
-                "Would you like to install it?"
-            ):
-                import subprocess
-
-                subprocess.run(["uv", "add", "qdrant-client"], check=True)  # noqa: S607
-            else:
-                raise ImportError(
-                    "The 'qdrant-client' package is required to use the QdrantVectorSearchTool. "
-                    "Please install it with: uv add qdrant-client"
-                )
+        return self
 
     def _run(
         self,
         query: str,
         filter_by: str | None = None,
-        filter_value: str | None = None,
+        filter_value: Any | None = None,
     ) -> str:
-        """Execute vector similarity search on Qdrant.
-
-        Args:
-            query: Search query to vectorize and match
-            filter_by: Optional metadata field to filter on
-            filter_value: Optional value to filter by
-
-        Returns:
-            JSON string containing search results with metadata and scores
-
-        Raises:
-            ImportError: If qdrant-client is not installed
-            ValueError: If Qdrant credentials are missing
-        """
-        if not self.qdrant_url:
-            raise ValueError("QDRANT_URL is not set")
-
-        # Create filter if filter parameters are provided
-        search_filter = None
-        if filter_by and filter_value:
-            search_filter = Filter(
+        """Perform vector similarity search."""
+        filter_ = self.qdrant_package.http.models.Filter
+        field_condition = self.qdrant_package.http.models.FieldCondition
+        match_value = self.qdrant_package.http.models.MatchValue
+
+        conditions = self.qdrant_config.filter_conditions.copy()
+        if filter_by and filter_value is not None:
+            conditions.append((filter_by, filter_value))
+
+        search_filter = (
+            filter_(
                 must=[
-                    FieldCondition(key=filter_by, match=MatchValue(value=filter_value))
+                    field_condition(key=k, match=match_value(value=v))
+                    for k, v in conditions
                 ]
             )
-
-        # Search in Qdrant using the built-in query method
-        query_vector = (
-            self._vectorize_query(query, embedding_model="text-embedding-3-large")
-            if not self.custom_embedding_fn
-            else self.custom_embedding_fn(query)
+            if conditions
+            else None
         )
-        search_results = self.client.query_points(
-            collection_name=self.collection_name,  # type: ignore[arg-type]
+
+        query_vector = (
+            self.custom_embedding_fn(query)
+            if self.custom_embedding_fn
+            else (
+                lambda: __import__("openai")
+                .Client(api_key=os.getenv("OPENAI_API_KEY"))
+                .embeddings.create(input=[query], model="text-embedding-3-large")
+                .data[0]
+                .embedding
+            )()
+        )
+
+        results = self.client.query_points(
+            collection_name=self.qdrant_config.collection_name,
             query=query_vector,
             query_filter=search_filter,
-            limit=self.limit,  # type: ignore[arg-type]
-            score_threshold=self.score_threshold,
+            limit=self.qdrant_config.limit,
+            score_threshold=self.qdrant_config.score_threshold,
         )
-
-        # Format results similar to storage implementation
-        results = []
-
-        # Extract the list of ScoredPoint objects from the tuple
-        for point in search_results:
-            result = {
-                "metadata": point[1][0].payload.get("metadata", {}),
-                "context": point[1][0].payload.get("text", ""),
-                "distance": point[1][0].score,
-            }
-            results.append(result)
-
-        return json.dumps(results, indent=2)
-
-    def _vectorize_query(self, query: str, embedding_model: str) -> list[float]:
-        """Default vectorization function with openai.
-
-        Args:
-            query (str): The query to vectorize
-            embedding_model (str): The embedding model to use
-
-        Returns:
-            list[float]: The vectorized query
-        """
-        import openai
-
-        client = openai.Client(api_key=os.getenv("OPENAI_API_KEY"))
-        return (
-            client.embeddings.create(
-                input=[query],
-                model=embedding_model,
-            )
-            .data[0]
-            .embedding
+
+        return json.dumps(
+            [
+                {
+                    "distance": p.score,
+                    "metadata": p.payload.get("metadata", {}) if p.payload else {},
+                    "context": p.payload.get("text", "") if p.payload else {},
+                }
+                for p in results.points
+            ],
+            indent=2,
         )
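
For reference, the reworked _run builds its response from results.points, so every match up to the configured limit comes back instead of only the first record. A hypothetical response for limit=2 (scores and payloads invented for illustration) looks like:

    [
      {
        "distance": 0.91,
        "metadata": {"department": "engineering"},
        "context": "Onboarding checklist for new engineering hires..."
      },
      {
        "distance": 0.87,
        "metadata": {"department": "engineering"},
        "context": "Laptop and access setup steps..."
      }
    ]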