diff --git a/src/crewai_tools/tools/qdrant_vector_search_tool/qdrant_search_tool.py b/src/crewai_tools/tools/qdrant_vector_search_tool/qdrant_search_tool.py index 307fcb8d1..1dd8c6078 100644 --- a/src/crewai_tools/tools/qdrant_vector_search_tool/qdrant_search_tool.py +++ b/src/crewai_tools/tools/qdrant_vector_search_tool/qdrant_search_tool.py @@ -1,5 +1,8 @@ import json +import os from typing import Any, Optional, Type +import ollama + try: from qdrant_client import QdrantClient @@ -8,7 +11,7 @@ try: QDRANT_AVAILABLE = True except ImportError: QDRANT_AVAILABLE = False - QdrantClient = Any + QdrantClient = Any # type placeholder Filter = Any FieldCondition = Any MatchValue = Any @@ -35,101 +38,51 @@ class QdrantToolSchema(BaseModel): class QdrantVectorSearchTool(BaseTool): - """Tool to query and filter results from a Qdrant vector database. + """Tool to query and filter results from a Qdrant database. - This tool provides functionality to perform semantic search operations on documents - stored in a Qdrant collection, with optional filtering capabilities. + This tool enables vector similarity search on internal documents stored in Qdrant, + with optional filtering capabilities. Attributes: - name (str): Name of the tool - description (str): Description of the tool's functionality - client (QdrantClient): Qdrant client instance - collection_name (str): Name of the Qdrant collection to search - limit (int): Maximum number of results to return - score_threshold (float): Minimum similarity score threshold + client: Configured QdrantClient instance + collection_name: Name of the Qdrant collection to search + limit: Maximum number of results to return + score_threshold: Minimum similarity score threshold + qdrant_url: Qdrant server URL + qdrant_api_key: Authentication key for Qdrant """ + model_config = {"arbitrary_types_allowed": True} + client: QdrantClient = None name: str = "QdrantVectorSearchTool" description: str = "A tool to search the Qdrant database for relevant information on internal documents." args_schema: Type[BaseModel] = QdrantToolSchema - - model_config = {"arbitrary_types_allowed": True} - client: Optional[QdrantClient] = None - collection_name: str = Field( - ..., - description="The name of the Qdrant collection to search", - ) + query: Optional[str] = None + filter_by: Optional[str] = None + filter_value: Optional[str] = None + collection_name: Optional[str] = None limit: Optional[int] = Field(default=3) score_threshold: float = Field(default=0.35) qdrant_url: str = Field( ..., description="The URL of the Qdrant server", ) - qdrant_api_key: Optional[str] = Field( - default=None, + qdrant_api_key: str = Field( + ..., description="The API key for the Qdrant server", ) - vectorizer: Optional[str] = Field( - default="BAAI/bge-small-en-v1.5", - description="The vectorizer to use for the Qdrant server", + custom_embedding_fn: Optional[callable] = Field( + default=None, + description="A custom embedding function to use for vectorization. If not provided, the default model will be used.", ) - def __init__( - self, - qdrant_url: str, - collection_name: str, - qdrant_api_key: Optional[str] = None, - vectorizer: Optional[str] = None, - **kwargs, - ) -> None: - """Initialize the QdrantVectorSearchTool. - - Args: - qdrant_url: URL of the Qdrant server - collection_name: Name of the collection to search - qdrant_api_key: Optional API key for authentication - vectorizer: Optional model name for text vectorization - - Raises: - ImportError: If qdrant-client package is not installed - ConnectionError: If unable to connect to Qdrant server - """ - kwargs["qdrant_url"] = qdrant_url - kwargs["collection_name"] = collection_name - kwargs["qdrant_api_key"] = qdrant_api_key - if vectorizer: - kwargs["vectorizer"] = vectorizer - + def __init__(self, **kwargs): super().__init__(**kwargs) - if QDRANT_AVAILABLE: - try: - self.client = QdrantClient( - url=qdrant_url, - api_key=qdrant_api_key, - ) - # Verify connection - self.client.get_collections() - except Exception as e: - raise ConnectionError(f"Failed to connect to Qdrant server: {str(e)}") - else: - import click - - if click.confirm( - "You are missing the 'qdrant-client' package. Would you like to install it?" - ): - import subprocess - - subprocess.run( - ["uv", "add", "crewai[tools]", "qdrant-client"], check=True - ) - else: - raise ImportError( - "The 'qdrant-client' package is required to use the QdrantVectorSearchTool. " - "Please install it with: uv add crewai[tools] qdrant-client" - ) - if vectorizer: - self.client.set_model(self.vectorizer) + self.client = QdrantClient( + url=self.qdrant_url, + api_key=self.qdrant_api_key, + ) def _run( self, @@ -137,24 +90,30 @@ class QdrantVectorSearchTool(BaseTool): filter_by: Optional[str] = None, filter_value: Optional[str] = None, ) -> str: - """Execute the vector search query. + """Execute vector similarity search on Qdrant. Args: - query: Search query text - filter_by: Optional field name to filter results + query: Search query to vectorize and match + filter_by: Optional metadata field to filter on filter_value: Optional value to filter by Returns: - JSON string containing search results with metadata + JSON string containing search results with metadata and scores Raises: - ValueError: If filter_by is provided without filter_value or vice versa + ImportError: If qdrant-client is not installed + ValueError: If Qdrant credentials are missing """ - if bool(filter_by) != bool(filter_value): - raise ValueError( - "Both filter_by and filter_value must be provided together" + if not QDRANT_AVAILABLE: + raise ImportError( + "The 'qdrant-client' package is required to use the QdrantVectorSearchTool. " + "Please install it with: pip install qdrant-client" ) + if not self.qdrant_url: + raise ValueError("QDRANT_URL is not set") + + # Create filter if filter parameters are provided search_filter = None if filter_by and filter_value: search_filter = Filter( @@ -163,29 +122,52 @@ class QdrantVectorSearchTool(BaseTool): ] ) - try: - search_results = self.client.query( - collection_name=self.collection_name, - query_text=[query], - query_filter=search_filter, - limit=self.limit, - score_threshold=self.score_threshold, + # Search in Qdrant using the built-in query method + + query_vector = ( + self._vectorize_query(query) + if not self.custom_embedding_fn + else self.custom_embedding_fn(query) + ) + search_results = self.client.query_points( + collection_name=self.collection_name, + query=query_vector, + query_filter=search_filter, + limit=self.limit, + score_threshold=self.score_threshold, + ) + + # Format results similar to storage implementation + results = [] + # Extract the list of ScoredPoint objects from the tuple + for point in search_results: + result = { + "metadata": point[1][0].payload.get("metadata", {}), + "context": point[1][0].payload.get("text", ""), + "distance": point[1][0].score, + } + results.append(result) + + return json.dumps(results, indent=2) + + def _vectorize_query(self, query: str) -> list[float]: + """Default vectorization function with openai. + + Args: + query (str): The query to vectorize + + Returns: + list[float]: The vectorized query + """ + import openai + + client = openai.Client(api_key=os.getenv("OPENAI_API_KEY")) + embedding = ( + client.embeddings.create( + input=[query], + model="text-embedding-3-small", ) - - results = [ - { - "id": point.id, - "metadata": point.metadata, - "context": point.document, - "score": point.score, - } - for point in search_results - ] - - if not results: - return json.dumps({"message": "No results found", "results": []}) - - return json.dumps(results, indent=2) - - except Exception as e: - raise RuntimeError(f"Error executing Qdrant search: {str(e)}") + .data[0] + .embedding + ) + return embedding