mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
feat: improvements on QdrantVectorSearchTool
* Implement improvements on QdrantVectorSearchTool - Allow search filters to be set at the constructor level - Fix issue that prevented multiple records from being returned * Implement improvements on QdrantVectorSearchTool - Allow search filters to be set at the constructor level - Fix issue that prevented multiple records from being returned --------- Co-authored-by: Greyson LaLonde <greyson.r.lalonde@gmail.com>
This commit is contained in:
@@ -1,80 +1,42 @@
|
|||||||
from collections.abc import Callable
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import importlib
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
from collections.abc import Callable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
try:
|
|
||||||
from qdrant_client import QdrantClient
|
|
||||||
from qdrant_client.http.models import FieldCondition, Filter, MatchValue
|
|
||||||
|
|
||||||
QDRANT_AVAILABLE = True
|
|
||||||
except ImportError:
|
|
||||||
QDRANT_AVAILABLE = False
|
|
||||||
QdrantClient = Any # type: ignore[assignment,misc] # type placeholder
|
|
||||||
Filter = Any # type: ignore[assignment,misc]
|
|
||||||
FieldCondition = Any # type: ignore[assignment,misc]
|
|
||||||
MatchValue = Any # type: ignore[assignment,misc]
|
|
||||||
|
|
||||||
from crewai.tools import BaseTool, EnvVar
|
from crewai.tools import BaseTool, EnvVar
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||||
|
from pydantic.types import ImportString
|
||||||
|
|
||||||
|
|
||||||
class QdrantToolSchema(BaseModel):
|
class QdrantToolSchema(BaseModel):
|
||||||
"""Input for QdrantTool."""
|
query: str = Field(..., description="Query to search in Qdrant DB.")
|
||||||
|
filter_by: str | None = None
|
||||||
|
filter_value: str | None = None
|
||||||
|
|
||||||
query: str = Field(
|
|
||||||
...,
|
class QdrantConfig(BaseModel):
|
||||||
description="The query to search retrieve relevant information from the Qdrant database. Pass only the query, not the question.",
|
"""All Qdrant connection and search settings."""
|
||||||
)
|
|
||||||
filter_by: str | None = Field(
|
qdrant_url: str
|
||||||
default=None,
|
qdrant_api_key: str | None = None
|
||||||
description="Filter by properties. Pass only the properties, not the question.",
|
collection_name: str
|
||||||
)
|
limit: int = 3
|
||||||
filter_value: str | None = Field(
|
score_threshold: float = 0.35
|
||||||
default=None,
|
filter_conditions: list[tuple[str, Any]] = Field(default_factory=list)
|
||||||
description="Filter by value. Pass only the value, not the question.",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class QdrantVectorSearchTool(BaseTool):
|
class QdrantVectorSearchTool(BaseTool):
|
||||||
"""Tool to query and filter results from a Qdrant database.
|
"""Vector search tool for Qdrant."""
|
||||||
|
|
||||||
This tool enables vector similarity search on internal documents stored in Qdrant,
|
|
||||||
with optional filtering capabilities.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
client: Configured QdrantClient instance
|
|
||||||
collection_name: Name of the Qdrant collection to search
|
|
||||||
limit: Maximum number of results to return
|
|
||||||
score_threshold: Minimum similarity score threshold
|
|
||||||
qdrant_url: Qdrant server URL
|
|
||||||
qdrant_api_key: Authentication key for Qdrant
|
|
||||||
"""
|
|
||||||
|
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
client: QdrantClient = None # type: ignore[assignment]
|
|
||||||
|
# --- Metadata ---
|
||||||
name: str = "QdrantVectorSearchTool"
|
name: str = "QdrantVectorSearchTool"
|
||||||
description: str = "A tool to search the Qdrant database for relevant information on internal documents."
|
description: str = "Search Qdrant vector DB for relevant documents."
|
||||||
args_schema: type[BaseModel] = QdrantToolSchema
|
args_schema: type[BaseModel] = QdrantToolSchema
|
||||||
query: str | None = None
|
|
||||||
filter_by: str | None = None
|
|
||||||
filter_value: str | None = None
|
|
||||||
collection_name: str | None = None
|
|
||||||
limit: int | None = Field(default=3)
|
|
||||||
score_threshold: float = Field(default=0.35)
|
|
||||||
qdrant_url: str = Field(
|
|
||||||
...,
|
|
||||||
description="The URL of the Qdrant server",
|
|
||||||
)
|
|
||||||
qdrant_api_key: str | None = Field(
|
|
||||||
default=None,
|
|
||||||
description="The API key for the Qdrant server",
|
|
||||||
)
|
|
||||||
custom_embedding_fn: Callable | None = Field(
|
|
||||||
default=None,
|
|
||||||
description="A custom embedding function to use for vectorization. If not provided, the default model will be used.",
|
|
||||||
)
|
|
||||||
package_dependencies: list[str] = Field(default_factory=lambda: ["qdrant-client"])
|
package_dependencies: list[str] = Field(default_factory=lambda: ["qdrant-client"])
|
||||||
env_vars: list[EnvVar] = Field(
|
env_vars: list[EnvVar] = Field(
|
||||||
default_factory=lambda: [
|
default_factory=lambda: [
|
||||||
@@ -83,107 +45,81 @@ class QdrantVectorSearchTool(BaseTool):
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
qdrant_config: QdrantConfig
|
||||||
|
qdrant_package: ImportString[Any] = Field(
|
||||||
|
default="qdrant_client",
|
||||||
|
description="Base package path for Qdrant. Will dynamically import client and models.",
|
||||||
|
)
|
||||||
|
custom_embedding_fn: ImportString[Callable[[str], list[float]]] | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Optional embedding function or import path.",
|
||||||
|
)
|
||||||
|
client: Any | None = None
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
@model_validator(mode="after")
|
||||||
super().__init__(**kwargs)
|
def _setup_qdrant(self) -> QdrantVectorSearchTool:
|
||||||
if QDRANT_AVAILABLE:
|
# Import the qdrant_package if it's a string
|
||||||
self.client = QdrantClient(
|
if isinstance(self.qdrant_package, str):
|
||||||
url=self.qdrant_url,
|
self.qdrant_package = importlib.import_module(self.qdrant_package)
|
||||||
api_key=self.qdrant_api_key if self.qdrant_api_key else None,
|
|
||||||
|
if not self.client:
|
||||||
|
self.client = self.qdrant_package.QdrantClient(
|
||||||
|
url=self.qdrant_config.qdrant_url,
|
||||||
|
api_key=self.qdrant_config.qdrant_api_key or None,
|
||||||
)
|
)
|
||||||
else:
|
return self
|
||||||
import click
|
|
||||||
|
|
||||||
if click.confirm(
|
|
||||||
"The 'qdrant-client' package is required to use the QdrantVectorSearchTool. "
|
|
||||||
"Would you like to install it?"
|
|
||||||
):
|
|
||||||
import subprocess
|
|
||||||
|
|
||||||
subprocess.run(["uv", "add", "qdrant-client"], check=True) # noqa: S607
|
|
||||||
else:
|
|
||||||
raise ImportError(
|
|
||||||
"The 'qdrant-client' package is required to use the QdrantVectorSearchTool. "
|
|
||||||
"Please install it with: uv add qdrant-client"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _run(
|
def _run(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
filter_by: str | None = None,
|
filter_by: str | None = None,
|
||||||
filter_value: str | None = None,
|
filter_value: Any | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Execute vector similarity search on Qdrant.
|
"""Perform vector similarity search."""
|
||||||
|
filter_ = self.qdrant_package.http.models.Filter
|
||||||
|
field_condition = self.qdrant_package.http.models.FieldCondition
|
||||||
|
match_value = self.qdrant_package.http.models.MatchValue
|
||||||
|
conditions = self.qdrant_config.filter_conditions.copy()
|
||||||
|
if filter_by and filter_value is not None:
|
||||||
|
conditions.append((filter_by, filter_value))
|
||||||
|
|
||||||
Args:
|
search_filter = (
|
||||||
query: Search query to vectorize and match
|
filter_(
|
||||||
filter_by: Optional metadata field to filter on
|
|
||||||
filter_value: Optional value to filter by
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
JSON string containing search results with metadata and scores
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ImportError: If qdrant-client is not installed
|
|
||||||
ValueError: If Qdrant credentials are missing
|
|
||||||
"""
|
|
||||||
if not self.qdrant_url:
|
|
||||||
raise ValueError("QDRANT_URL is not set")
|
|
||||||
|
|
||||||
# Create filter if filter parameters are provided
|
|
||||||
search_filter = None
|
|
||||||
if filter_by and filter_value:
|
|
||||||
search_filter = Filter(
|
|
||||||
must=[
|
must=[
|
||||||
FieldCondition(key=filter_by, match=MatchValue(value=filter_value))
|
field_condition(key=k, match=match_value(value=v))
|
||||||
|
for k, v in conditions
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
if conditions
|
||||||
# Search in Qdrant using the built-in query method
|
else None
|
||||||
query_vector = (
|
|
||||||
self._vectorize_query(query, embedding_model="text-embedding-3-large")
|
|
||||||
if not self.custom_embedding_fn
|
|
||||||
else self.custom_embedding_fn(query)
|
|
||||||
)
|
)
|
||||||
search_results = self.client.query_points(
|
query_vector = (
|
||||||
collection_name=self.collection_name, # type: ignore[arg-type]
|
self.custom_embedding_fn(query)
|
||||||
|
if self.custom_embedding_fn
|
||||||
|
else (
|
||||||
|
lambda: __import__("openai")
|
||||||
|
.Client(api_key=os.getenv("OPENAI_API_KEY"))
|
||||||
|
.embeddings.create(input=[query], model="text-embedding-3-large")
|
||||||
|
.data[0]
|
||||||
|
.embedding
|
||||||
|
)()
|
||||||
|
)
|
||||||
|
results = self.client.query_points(
|
||||||
|
collection_name=self.qdrant_config.collection_name,
|
||||||
query=query_vector,
|
query=query_vector,
|
||||||
query_filter=search_filter,
|
query_filter=search_filter,
|
||||||
limit=self.limit, # type: ignore[arg-type]
|
limit=self.qdrant_config.limit,
|
||||||
score_threshold=self.score_threshold,
|
score_threshold=self.qdrant_config.score_threshold,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Format results similar to storage implementation
|
return json.dumps(
|
||||||
results = []
|
[
|
||||||
# Extract the list of ScoredPoint objects from the tuple
|
{
|
||||||
for point in search_results:
|
"distance": p.score,
|
||||||
result = {
|
"metadata": p.payload.get("metadata", {}) if p.payload else {},
|
||||||
"metadata": point[1][0].payload.get("metadata", {}),
|
"context": p.payload.get("text", "") if p.payload else {},
|
||||||
"context": point[1][0].payload.get("text", ""),
|
}
|
||||||
"distance": point[1][0].score,
|
for p in results.points
|
||||||
}
|
],
|
||||||
results.append(result)
|
indent=2,
|
||||||
|
|
||||||
return json.dumps(results, indent=2)
|
|
||||||
|
|
||||||
def _vectorize_query(self, query: str, embedding_model: str) -> list[float]:
|
|
||||||
"""Default vectorization function with openai.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query (str): The query to vectorize
|
|
||||||
embedding_model (str): The embedding model to use
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[float]: The vectorized query
|
|
||||||
"""
|
|
||||||
import openai
|
|
||||||
|
|
||||||
client = openai.Client(api_key=os.getenv("OPENAI_API_KEY"))
|
|
||||||
return (
|
|
||||||
client.embeddings.create(
|
|
||||||
input=[query],
|
|
||||||
model=embedding_model,
|
|
||||||
)
|
|
||||||
.data[0]
|
|
||||||
.embedding
|
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user