Enhance QdrantVectorSearchTool (#3806)
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled

This commit is contained in:
Daniel Barreto
2025-10-28 14:42:40 -03:00
committed by GitHub
parent 410db1ff39
commit 70b083945f
4 changed files with 322 additions and 100 deletions

View File

@@ -1,9 +1,9 @@
from __future__ import annotations
from collections.abc import Callable
import importlib
import json
import os
from collections.abc import Callable
from typing import Any
from crewai.tools import BaseTool, EnvVar
@@ -12,9 +12,13 @@ from pydantic.types import ImportString
class QdrantToolSchema(BaseModel):
query: str = Field(..., description="Query to search in Qdrant DB.")
filter_by: str | None = None
filter_value: str | None = None
query: str = Field(..., description="Query to search in Qdrant DB")
filter_by: str | None = Field(
default=None, description="Parameter to filter the search by."
)
filter_value: Any | None = Field(
default=None, description="Value to filter the search by."
)
class QdrantConfig(BaseModel):
@@ -25,7 +29,9 @@ class QdrantConfig(BaseModel):
collection_name: str
limit: int = 3
score_threshold: float = 0.35
filter_conditions: list[tuple[str, Any]] = Field(default_factory=list)
filter: Any | None = Field(
default=None, description="Qdrant Filter instance for advanced filtering."
)
class QdrantVectorSearchTool(BaseTool):
@@ -76,23 +82,26 @@ class QdrantVectorSearchTool(BaseTool):
filter_value: Any | None = None,
) -> str:
"""Perform vector similarity search."""
filter_ = self.qdrant_package.http.models.Filter
field_condition = self.qdrant_package.http.models.FieldCondition
match_value = self.qdrant_package.http.models.MatchValue
conditions = self.qdrant_config.filter_conditions.copy()
if filter_by and filter_value is not None:
conditions.append((filter_by, filter_value))
search_filter = (
filter_(
must=[
field_condition(key=k, match=match_value(value=v))
for k, v in conditions
]
)
if conditions
else None
self.qdrant_config.filter.model_copy()
if self.qdrant_config.filter is not None
else self.qdrant_package.http.models.Filter(must=[])
)
if filter_by and filter_value is not None:
if not hasattr(search_filter, "must") or not isinstance(
search_filter.must, list
):
search_filter.must = []
search_filter.must.append(
self.qdrant_package.http.models.FieldCondition(
key=filter_by,
match=self.qdrant_package.http.models.MatchValue(
value=filter_value
),
)
)
query_vector = (
self.custom_embedding_fn(query)
if self.custom_embedding_fn