diff --git a/lib/crewai-tools/src/crewai_tools/tools/qdrant_vector_search_tool/qdrant_search_tool.py b/lib/crewai-tools/src/crewai_tools/tools/qdrant_vector_search_tool/qdrant_search_tool.py index c1a88114e..adcb733a5 100644 --- a/lib/crewai-tools/src/crewai_tools/tools/qdrant_vector_search_tool/qdrant_search_tool.py +++ b/lib/crewai-tools/src/crewai_tools/tools/qdrant_vector_search_tool/qdrant_search_tool.py @@ -1,80 +1,42 @@ -from collections.abc import Callable +from __future__ import annotations + +import importlib import json import os +from collections.abc import Callable from typing import Any - -try: - from qdrant_client import QdrantClient - from qdrant_client.http.models import FieldCondition, Filter, MatchValue - - QDRANT_AVAILABLE = True -except ImportError: - QDRANT_AVAILABLE = False - QdrantClient = Any # type: ignore[assignment,misc] # type placeholder - Filter = Any # type: ignore[assignment,misc] - FieldCondition = Any # type: ignore[assignment,misc] - MatchValue = Any # type: ignore[assignment,misc] - from crewai.tools import BaseTool, EnvVar -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic.types import ImportString class QdrantToolSchema(BaseModel): - """Input for QdrantTool.""" + query: str = Field(..., description="Query to search in Qdrant DB.") + filter_by: str | None = None + filter_value: str | None = None - query: str = Field( - ..., - description="The query to search retrieve relevant information from the Qdrant database. Pass only the query, not the question.", - ) - filter_by: str | None = Field( - default=None, - description="Filter by properties. Pass only the properties, not the question.", - ) - filter_value: str | None = Field( - default=None, - description="Filter by value. Pass only the value, not the question.", - ) + +class QdrantConfig(BaseModel): + """All Qdrant connection and search settings.""" + + qdrant_url: str + qdrant_api_key: str | None = None + collection_name: str + limit: int = 3 + score_threshold: float = 0.35 + filter_conditions: list[tuple[str, Any]] = Field(default_factory=list) class QdrantVectorSearchTool(BaseTool): - """Tool to query and filter results from a Qdrant database. - - This tool enables vector similarity search on internal documents stored in Qdrant, - with optional filtering capabilities. - - Attributes: - client: Configured QdrantClient instance - collection_name: Name of the Qdrant collection to search - limit: Maximum number of results to return - score_threshold: Minimum similarity score threshold - qdrant_url: Qdrant server URL - qdrant_api_key: Authentication key for Qdrant - """ + """Vector search tool for Qdrant.""" model_config = ConfigDict(arbitrary_types_allowed=True) - client: QdrantClient = None # type: ignore[assignment] + + # --- Metadata --- name: str = "QdrantVectorSearchTool" - description: str = "A tool to search the Qdrant database for relevant information on internal documents." + description: str = "Search Qdrant vector DB for relevant documents." args_schema: type[BaseModel] = QdrantToolSchema - query: str | None = None - filter_by: str | None = None - filter_value: str | None = None - collection_name: str | None = None - limit: int | None = Field(default=3) - score_threshold: float = Field(default=0.35) - qdrant_url: str = Field( - ..., - description="The URL of the Qdrant server", - ) - qdrant_api_key: str | None = Field( - default=None, - description="The API key for the Qdrant server", - ) - custom_embedding_fn: Callable | None = Field( - default=None, - description="A custom embedding function to use for vectorization. If not provided, the default model will be used.", - ) package_dependencies: list[str] = Field(default_factory=lambda: ["qdrant-client"]) env_vars: list[EnvVar] = Field( default_factory=lambda: [ @@ -83,107 +45,81 @@ class QdrantVectorSearchTool(BaseTool): ) ] ) + qdrant_config: QdrantConfig + qdrant_package: ImportString[Any] = Field( + default="qdrant_client", + description="Base package path for Qdrant. Will dynamically import client and models.", + ) + custom_embedding_fn: ImportString[Callable[[str], list[float]]] | None = Field( + default=None, + description="Optional embedding function or import path.", + ) + client: Any | None = None - def __init__(self, **kwargs): - super().__init__(**kwargs) - if QDRANT_AVAILABLE: - self.client = QdrantClient( - url=self.qdrant_url, - api_key=self.qdrant_api_key if self.qdrant_api_key else None, + @model_validator(mode="after") + def _setup_qdrant(self) -> QdrantVectorSearchTool: + # Import the qdrant_package if it's a string + if isinstance(self.qdrant_package, str): + self.qdrant_package = importlib.import_module(self.qdrant_package) + + if not self.client: + self.client = self.qdrant_package.QdrantClient( + url=self.qdrant_config.qdrant_url, + api_key=self.qdrant_config.qdrant_api_key or None, ) - else: - import click - - if click.confirm( - "The 'qdrant-client' package is required to use the QdrantVectorSearchTool. " - "Would you like to install it?" - ): - import subprocess - - subprocess.run(["uv", "add", "qdrant-client"], check=True) # noqa: S607 - else: - raise ImportError( - "The 'qdrant-client' package is required to use the QdrantVectorSearchTool. " - "Please install it with: uv add qdrant-client" - ) + return self def _run( self, query: str, filter_by: str | None = None, - filter_value: str | None = None, + filter_value: Any | None = None, ) -> str: - """Execute vector similarity search on Qdrant. + """Perform vector similarity search.""" + filter_ = self.qdrant_package.http.models.Filter + field_condition = self.qdrant_package.http.models.FieldCondition + match_value = self.qdrant_package.http.models.MatchValue + conditions = self.qdrant_config.filter_conditions.copy() + if filter_by and filter_value is not None: + conditions.append((filter_by, filter_value)) - Args: - query: Search query to vectorize and match - filter_by: Optional metadata field to filter on - filter_value: Optional value to filter by - - Returns: - JSON string containing search results with metadata and scores - - Raises: - ImportError: If qdrant-client is not installed - ValueError: If Qdrant credentials are missing - """ - if not self.qdrant_url: - raise ValueError("QDRANT_URL is not set") - - # Create filter if filter parameters are provided - search_filter = None - if filter_by and filter_value: - search_filter = Filter( + search_filter = ( + filter_( must=[ - FieldCondition(key=filter_by, match=MatchValue(value=filter_value)) + field_condition(key=k, match=match_value(value=v)) + for k, v in conditions ] ) - - # Search in Qdrant using the built-in query method - query_vector = ( - self._vectorize_query(query, embedding_model="text-embedding-3-large") - if not self.custom_embedding_fn - else self.custom_embedding_fn(query) + if conditions + else None ) - search_results = self.client.query_points( - collection_name=self.collection_name, # type: ignore[arg-type] + query_vector = ( + self.custom_embedding_fn(query) + if self.custom_embedding_fn + else ( + lambda: __import__("openai") + .Client(api_key=os.getenv("OPENAI_API_KEY")) + .embeddings.create(input=[query], model="text-embedding-3-large") + .data[0] + .embedding + )() + ) + results = self.client.query_points( + collection_name=self.qdrant_config.collection_name, query=query_vector, query_filter=search_filter, - limit=self.limit, # type: ignore[arg-type] - score_threshold=self.score_threshold, + limit=self.qdrant_config.limit, + score_threshold=self.qdrant_config.score_threshold, ) - # Format results similar to storage implementation - results = [] - # Extract the list of ScoredPoint objects from the tuple - for point in search_results: - result = { - "metadata": point[1][0].payload.get("metadata", {}), - "context": point[1][0].payload.get("text", ""), - "distance": point[1][0].score, - } - results.append(result) - - return json.dumps(results, indent=2) - - def _vectorize_query(self, query: str, embedding_model: str) -> list[float]: - """Default vectorization function with openai. - - Args: - query (str): The query to vectorize - embedding_model (str): The embedding model to use - - Returns: - list[float]: The vectorized query - """ - import openai - - client = openai.Client(api_key=os.getenv("OPENAI_API_KEY")) - return ( - client.embeddings.create( - input=[query], - model=embedding_model, - ) - .data[0] - .embedding + return json.dumps( + [ + { + "distance": p.score, + "metadata": p.payload.get("metadata", {}) if p.payload else {}, + "context": p.payload.get("text", "") if p.payload else {}, + } + for p in results.points + ], + indent=2, )