#249 feat: add support for local qdrant client

This commit is contained in:
Parth Patel
2025-03-25 19:01:01 +00:00
parent baea6dc8a4
commit 5ded394e43

View File

@@ -66,8 +66,8 @@ class QdrantVectorSearchTool(BaseTool):
...,
description="The URL of the Qdrant server",
)
qdrant_api_key: str = Field(
...,
qdrant_api_key: Optional[str] = Field(
default=None,
description="The API key for the Qdrant server",
)
custom_embedding_fn: Optional[callable] = Field(
@@ -80,7 +80,7 @@ class QdrantVectorSearchTool(BaseTool):
if QDRANT_AVAILABLE:
self.client = QdrantClient(
url=self.qdrant_url,
api_key=self.qdrant_api_key,
api_key=self.qdrant_api_key if self.qdrant_api_key else None,
)
else:
import click
@@ -133,7 +133,7 @@ class QdrantVectorSearchTool(BaseTool):
# Search in Qdrant using the built-in query method
query_vector = (
self._vectorize_query(query)
self._vectorize_query(query, embedding_model="text-embedding-3-large")
if not self.custom_embedding_fn
else self.custom_embedding_fn(query)
)
@@ -158,11 +158,12 @@ class QdrantVectorSearchTool(BaseTool):
return json.dumps(results, indent=2)
def _vectorize_query(self, query: str) -> list[float]:
def _vectorize_query(self, query: str, embedding_model: str) -> list[float]:
"""Default vectorization function with openai.
Args:
query (str): The query to vectorize
embedding_model (str): The embedding model to use
Returns:
list[float]: The vectorized query
@@ -173,7 +174,7 @@ class QdrantVectorSearchTool(BaseTool):
embedding = (
client.embeddings.create(
input=[query],
model="text-embedding-3-small",
model=embedding_model,
)
.data[0]
.embedding