default openai

This commit is contained in:
Lorenze Jay
2025-02-03 16:19:03 -08:00
parent 05982aeef2
commit 5a9bb24b63

View File

@@ -1,5 +1,8 @@
import json
import os
from typing import Any, Optional, Type
import ollama
try:
from qdrant_client import QdrantClient
@@ -8,7 +11,7 @@ try:
QDRANT_AVAILABLE = True
except ImportError:
QDRANT_AVAILABLE = False
QdrantClient = Any
QdrantClient = Any # type placeholder
Filter = Any
FieldCondition = Any
MatchValue = Any
@@ -35,101 +38,51 @@ class QdrantToolSchema(BaseModel):
class QdrantVectorSearchTool(BaseTool):
"""Tool to query and filter results from a Qdrant vector database.
"""Tool to query and filter results from a Qdrant database.
This tool provides functionality to perform semantic search operations on documents
stored in a Qdrant collection, with optional filtering capabilities.
This tool enables vector similarity search on internal documents stored in Qdrant,
with optional filtering capabilities.
Attributes:
name (str): Name of the tool
description (str): Description of the tool's functionality
client (QdrantClient): Qdrant client instance
collection_name (str): Name of the Qdrant collection to search
limit (int): Maximum number of results to return
score_threshold (float): Minimum similarity score threshold
client: Configured QdrantClient instance
collection_name: Name of the Qdrant collection to search
limit: Maximum number of results to return
score_threshold: Minimum similarity score threshold
qdrant_url: Qdrant server URL
qdrant_api_key: Authentication key for Qdrant
"""
model_config = {"arbitrary_types_allowed": True}
client: QdrantClient = None
name: str = "QdrantVectorSearchTool"
description: str = "A tool to search the Qdrant database for relevant information on internal documents."
args_schema: Type[BaseModel] = QdrantToolSchema
model_config = {"arbitrary_types_allowed": True}
client: Optional[QdrantClient] = None
collection_name: str = Field(
...,
description="The name of the Qdrant collection to search",
)
query: Optional[str] = None
filter_by: Optional[str] = None
filter_value: Optional[str] = None
collection_name: Optional[str] = None
limit: Optional[int] = Field(default=3)
score_threshold: float = Field(default=0.35)
qdrant_url: str = Field(
...,
description="The URL of the Qdrant server",
)
qdrant_api_key: Optional[str] = Field(
default=None,
qdrant_api_key: str = Field(
...,
description="The API key for the Qdrant server",
)
vectorizer: Optional[str] = Field(
default="BAAI/bge-small-en-v1.5",
description="The vectorizer to use for the Qdrant server",
custom_embedding_fn: Optional[callable] = Field(
default=None,
description="A custom embedding function to use for vectorization. If not provided, the default model will be used.",
)
# NOTE(review): this span comes from a rendered diff in which removed and added
# lines are merged and indentation is stripped. It contains two `__init__`
# definitions (the one directly below and a second one further down); confirm
# against the real file which statements belong to which version.
# First signature: takes the connection settings as explicit arguments.
def __init__(
self,
qdrant_url: str,
collection_name: str,
qdrant_api_key: Optional[str] = None,
vectorizer: Optional[str] = None,
**kwargs,
) -> None:
"""Initialize the QdrantVectorSearchTool.
Args:
qdrant_url: URL of the Qdrant server
collection_name: Name of the collection to search
qdrant_api_key: Optional API key for authentication
vectorizer: Optional model name for text vectorization
Raises:
ImportError: If qdrant-client package is not installed
ConnectionError: If unable to connect to Qdrant server
"""
# Fold the explicit arguments back into kwargs so the base-class initializer
# receives them as field values; `vectorizer` is only forwarded when truthy.
kwargs["qdrant_url"] = qdrant_url
kwargs["collection_name"] = collection_name
kwargs["qdrant_api_key"] = qdrant_api_key
if vectorizer:
kwargs["vectorizer"] = vectorizer
# Second signature: keyword arguments only, forwarded unchanged to the base class.
def __init__(self, **kwargs):
super().__init__(**kwargs)
# The client is created only when the optional qdrant-client import succeeded.
if QDRANT_AVAILABLE:
try:
self.client = QdrantClient(
url=qdrant_url,
api_key=qdrant_api_key,
)
# Verify connection
self.client.get_collections()
except Exception as e:
raise ConnectionError(f"Failed to connect to Qdrant server: {str(e)}")
else:
# qdrant-client is missing: offer an interactive install through click,
# otherwise raise ImportError with install instructions.
import click
if click.confirm(
"You are missing the 'qdrant-client' package. Would you like to install it?"
):
import subprocess
subprocess.run(
["uv", "add", "crewai[tools]", "qdrant-client"], check=True
)
else:
raise ImportError(
"The 'qdrant-client' package is required to use the QdrantVectorSearchTool. "
"Please install it with: uv add crewai[tools] qdrant-client"
)
# NOTE(review): nesting is not recoverable from this view. If this check runs
# after the interactive-install branch, no statement in that branch assigns
# self.client before set_model is called on it -- confirm.
if vectorizer:
self.client.set_model(self.vectorizer)
# This construction reads the URL and API key from the instance fields rather
# than from constructor arguments.
self.client = QdrantClient(
url=self.qdrant_url,
api_key=self.qdrant_api_key,
)
def _run(
self,
@@ -137,24 +90,30 @@ class QdrantVectorSearchTool(BaseTool):
filter_by: Optional[str] = None,
filter_value: Optional[str] = None,
) -> str:
"""Execute the vector search query.
"""Execute vector similarity search on Qdrant.
Args:
query: Search query text
filter_by: Optional field name to filter results
query: Search query to vectorize and match
filter_by: Optional metadata field to filter on
filter_value: Optional value to filter by
Returns:
JSON string containing search results with metadata
JSON string containing search results with metadata and scores
Raises:
ValueError: If filter_by is provided without filter_value or vice versa
ImportError: If qdrant-client is not installed
ValueError: If Qdrant credentials are missing
"""
if bool(filter_by) != bool(filter_value):
raise ValueError(
"Both filter_by and filter_value must be provided together"
if not QDRANT_AVAILABLE:
raise ImportError(
"The 'qdrant-client' package is required to use the QdrantVectorSearchTool. "
"Please install it with: pip install qdrant-client"
)
if not self.qdrant_url:
raise ValueError("QDRANT_URL is not set")
# Create filter if filter parameters are provided
search_filter = None
if filter_by and filter_value:
search_filter = Filter(
@@ -163,29 +122,52 @@ class QdrantVectorSearchTool(BaseTool):
]
)
try:
search_results = self.client.query(
collection_name=self.collection_name,
query_text=[query],
query_filter=search_filter,
limit=self.limit,
score_threshold=self.score_threshold,
# Search in Qdrant using the built-in query method
query_vector = (
self._vectorize_query(query)
if not self.custom_embedding_fn
else self.custom_embedding_fn(query)
)
search_results = self.client.query_points(
collection_name=self.collection_name,
query=query_vector,
query_filter=search_filter,
limit=self.limit,
score_threshold=self.score_threshold,
)
# Format results similar to storage implementation
results = []
# Extract the list of ScoredPoint objects from the tuple
for point in search_results:
result = {
"metadata": point[1][0].payload.get("metadata", {}),
"context": point[1][0].payload.get("text", ""),
"distance": point[1][0].score,
}
results.append(result)
return json.dumps(results, indent=2)
def _vectorize_query(self, query: str) -> list[float]:
    """Embed ``query`` with OpenAI's ``text-embedding-3-small`` model.

    The OpenAI client is created on every call and is given the API key
    found in the ``OPENAI_API_KEY`` environment variable.

    Args:
        query: The text to vectorize.

    Returns:
        The embedding of ``query``, taken from the first (and only) item of
        the embeddings response.
    """
    # Function-scope import: openai is only loaded when this vectorizer runs.
    import openai

    client = openai.Client(api_key=os.getenv("OPENAI_API_KEY"))
    # A single-element input list yields a single result, hence data[0].
    response = client.embeddings.create(
        input=[query],
        model="text-embedding-3-small",
    )
    return response.data[0].embedding