Merge pull request #250 from parthbs/feat/local-qdrant-client-support

#249 feat: add support for local qdrant client
This commit is contained in:
Lucas Gomide
2025-04-02 10:16:14 -03:00
committed by GitHub
2 changed files with 9 additions and 8 deletions

View File

@@ -26,7 +26,7 @@ tool = QdrantVectorSearchTool(
collection_name="example_collections", collection_name="example_collections",
limit=3, limit=3,
qdrant_url="https://your-qdrant-cluster-url.com", qdrant_url="https://your-qdrant-cluster-url.com",
qdrant_api_key="your-qdrant-api-key", qdrant_api_key="your-qdrant-api-key", # (optional)
) )
@@ -43,7 +43,7 @@ rag_agent = Agent(
- `collection_name` : The name of the collection to search within. (Required) - `collection_name` : The name of the collection to search within. (Required)
- `qdrant_url` : The URL of the Qdrant cluster. (Required) - `qdrant_url` : The URL of the Qdrant cluster. (Required)
- `qdrant_api_key` : The API key for the Qdrant cluster. (Required) - `qdrant_api_key` : The API key for the Qdrant cluster. (Optional)
- `limit` : The number of results to return. (Optional) - `limit` : The number of results to return. (Optional)
- `vectorizer` : The vectorizer to use. (Optional) - `vectorizer` : The vectorizer to use. (Optional)

View File

@@ -66,8 +66,8 @@ class QdrantVectorSearchTool(BaseTool):
..., ...,
description="The URL of the Qdrant server", description="The URL of the Qdrant server",
) )
qdrant_api_key: str = Field( qdrant_api_key: Optional[str] = Field(
..., default=None,
description="The API key for the Qdrant server", description="The API key for the Qdrant server",
) )
custom_embedding_fn: Optional[callable] = Field( custom_embedding_fn: Optional[callable] = Field(
@@ -80,7 +80,7 @@ class QdrantVectorSearchTool(BaseTool):
if QDRANT_AVAILABLE: if QDRANT_AVAILABLE:
self.client = QdrantClient( self.client = QdrantClient(
url=self.qdrant_url, url=self.qdrant_url,
api_key=self.qdrant_api_key, api_key=self.qdrant_api_key if self.qdrant_api_key else None,
) )
else: else:
import click import click
@@ -133,7 +133,7 @@ class QdrantVectorSearchTool(BaseTool):
# Search in Qdrant using the built-in query method # Search in Qdrant using the built-in query method
query_vector = ( query_vector = (
self._vectorize_query(query) self._vectorize_query(query, embedding_model="text-embedding-3-large")
if not self.custom_embedding_fn if not self.custom_embedding_fn
else self.custom_embedding_fn(query) else self.custom_embedding_fn(query)
) )
@@ -158,11 +158,12 @@ class QdrantVectorSearchTool(BaseTool):
return json.dumps(results, indent=2) return json.dumps(results, indent=2)
def _vectorize_query(self, query: str) -> list[float]: def _vectorize_query(self, query: str, embedding_model: str) -> list[float]:
"""Default vectorization function with openai. """Default vectorization function with openai.
Args: Args:
query (str): The query to vectorize query (str): The query to vectorize
embedding_model (str): The embedding model to use
Returns: Returns:
list[float]: The vectorized query list[float]: The vectorized query
@@ -173,7 +174,7 @@ class QdrantVectorSearchTool(BaseTool):
embedding = ( embedding = (
client.embeddings.create( client.embeddings.create(
input=[query], input=[query],
model="text-embedding-3-small", model=embedding_model,
) )
.data[0] .data[0]
.embedding .embedding