mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 17:18:29 +00:00
default openai
This commit is contained in:
@@ -1,5 +1,8 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Optional, Type
|
||||
import ollama
|
||||
|
||||
|
||||
try:
|
||||
from qdrant_client import QdrantClient
|
||||
@@ -8,7 +11,7 @@ try:
|
||||
QDRANT_AVAILABLE = True
|
||||
except ImportError:
|
||||
QDRANT_AVAILABLE = False
|
||||
QdrantClient = Any
|
||||
QdrantClient = Any # type placeholder
|
||||
Filter = Any
|
||||
FieldCondition = Any
|
||||
MatchValue = Any
|
||||
@@ -35,101 +38,51 @@ class QdrantToolSchema(BaseModel):
|
||||
|
||||
|
||||
class QdrantVectorSearchTool(BaseTool):
|
||||
"""Tool to query and filter results from a Qdrant vector database.
|
||||
"""Tool to query and filter results from a Qdrant database.
|
||||
|
||||
This tool provides functionality to perform semantic search operations on documents
|
||||
stored in a Qdrant collection, with optional filtering capabilities.
|
||||
This tool enables vector similarity search on internal documents stored in Qdrant,
|
||||
with optional filtering capabilities.
|
||||
|
||||
Attributes:
|
||||
name (str): Name of the tool
|
||||
description (str): Description of the tool's functionality
|
||||
client (QdrantClient): Qdrant client instance
|
||||
collection_name (str): Name of the Qdrant collection to search
|
||||
limit (int): Maximum number of results to return
|
||||
score_threshold (float): Minimum similarity score threshold
|
||||
client: Configured QdrantClient instance
|
||||
collection_name: Name of the Qdrant collection to search
|
||||
limit: Maximum number of results to return
|
||||
score_threshold: Minimum similarity score threshold
|
||||
qdrant_url: Qdrant server URL
|
||||
qdrant_api_key: Authentication key for Qdrant
|
||||
"""
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
client: QdrantClient = None
|
||||
name: str = "QdrantVectorSearchTool"
|
||||
description: str = "A tool to search the Qdrant database for relevant information on internal documents."
|
||||
args_schema: Type[BaseModel] = QdrantToolSchema
|
||||
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
client: Optional[QdrantClient] = None
|
||||
collection_name: str = Field(
|
||||
...,
|
||||
description="The name of the Qdrant collection to search",
|
||||
)
|
||||
query: Optional[str] = None
|
||||
filter_by: Optional[str] = None
|
||||
filter_value: Optional[str] = None
|
||||
collection_name: Optional[str] = None
|
||||
limit: Optional[int] = Field(default=3)
|
||||
score_threshold: float = Field(default=0.35)
|
||||
qdrant_url: str = Field(
|
||||
...,
|
||||
description="The URL of the Qdrant server",
|
||||
)
|
||||
qdrant_api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
qdrant_api_key: str = Field(
|
||||
...,
|
||||
description="The API key for the Qdrant server",
|
||||
)
|
||||
vectorizer: Optional[str] = Field(
|
||||
default="BAAI/bge-small-en-v1.5",
|
||||
description="The vectorizer to use for the Qdrant server",
|
||||
custom_embedding_fn: Optional[callable] = Field(
|
||||
default=None,
|
||||
description="A custom embedding function to use for vectorization. If not provided, the default model will be used.",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
qdrant_url: str,
|
||||
collection_name: str,
|
||||
qdrant_api_key: Optional[str] = None,
|
||||
vectorizer: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Initialize the QdrantVectorSearchTool.
|
||||
|
||||
Args:
|
||||
qdrant_url: URL of the Qdrant server
|
||||
collection_name: Name of the collection to search
|
||||
qdrant_api_key: Optional API key for authentication
|
||||
vectorizer: Optional model name for text vectorization
|
||||
|
||||
Raises:
|
||||
ImportError: If qdrant-client package is not installed
|
||||
ConnectionError: If unable to connect to Qdrant server
|
||||
"""
|
||||
kwargs["qdrant_url"] = qdrant_url
|
||||
kwargs["collection_name"] = collection_name
|
||||
kwargs["qdrant_api_key"] = qdrant_api_key
|
||||
if vectorizer:
|
||||
kwargs["vectorizer"] = vectorizer
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if QDRANT_AVAILABLE:
|
||||
try:
|
||||
self.client = QdrantClient(
|
||||
url=qdrant_url,
|
||||
api_key=qdrant_api_key,
|
||||
)
|
||||
# Verify connection
|
||||
self.client.get_collections()
|
||||
except Exception as e:
|
||||
raise ConnectionError(f"Failed to connect to Qdrant server: {str(e)}")
|
||||
else:
|
||||
import click
|
||||
|
||||
if click.confirm(
|
||||
"You are missing the 'qdrant-client' package. Would you like to install it?"
|
||||
):
|
||||
import subprocess
|
||||
|
||||
subprocess.run(
|
||||
["uv", "add", "crewai[tools]", "qdrant-client"], check=True
|
||||
)
|
||||
else:
|
||||
raise ImportError(
|
||||
"The 'qdrant-client' package is required to use the QdrantVectorSearchTool. "
|
||||
"Please install it with: uv add crewai[tools] qdrant-client"
|
||||
)
|
||||
if vectorizer:
|
||||
self.client.set_model(self.vectorizer)
|
||||
self.client = QdrantClient(
|
||||
url=self.qdrant_url,
|
||||
api_key=self.qdrant_api_key,
|
||||
)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
@@ -137,24 +90,30 @@ class QdrantVectorSearchTool(BaseTool):
|
||||
filter_by: Optional[str] = None,
|
||||
filter_value: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Execute the vector search query.
|
||||
"""Execute vector similarity search on Qdrant.
|
||||
|
||||
Args:
|
||||
query: Search query text
|
||||
filter_by: Optional field name to filter results
|
||||
query: Search query to vectorize and match
|
||||
filter_by: Optional metadata field to filter on
|
||||
filter_value: Optional value to filter by
|
||||
|
||||
Returns:
|
||||
JSON string containing search results with metadata
|
||||
JSON string containing search results with metadata and scores
|
||||
|
||||
Raises:
|
||||
ValueError: If filter_by is provided without filter_value or vice versa
|
||||
ImportError: If qdrant-client is not installed
|
||||
ValueError: If Qdrant credentials are missing
|
||||
"""
|
||||
if bool(filter_by) != bool(filter_value):
|
||||
raise ValueError(
|
||||
"Both filter_by and filter_value must be provided together"
|
||||
if not QDRANT_AVAILABLE:
|
||||
raise ImportError(
|
||||
"The 'qdrant-client' package is required to use the QdrantVectorSearchTool. "
|
||||
"Please install it with: pip install qdrant-client"
|
||||
)
|
||||
|
||||
if not self.qdrant_url:
|
||||
raise ValueError("QDRANT_URL is not set")
|
||||
|
||||
# Create filter if filter parameters are provided
|
||||
search_filter = None
|
||||
if filter_by and filter_value:
|
||||
search_filter = Filter(
|
||||
@@ -163,29 +122,52 @@ class QdrantVectorSearchTool(BaseTool):
|
||||
]
|
||||
)
|
||||
|
||||
try:
|
||||
search_results = self.client.query(
|
||||
collection_name=self.collection_name,
|
||||
query_text=[query],
|
||||
query_filter=search_filter,
|
||||
limit=self.limit,
|
||||
score_threshold=self.score_threshold,
|
||||
# Search in Qdrant using the built-in query method
|
||||
|
||||
query_vector = (
|
||||
self._vectorize_query(query)
|
||||
if not self.custom_embedding_fn
|
||||
else self.custom_embedding_fn(query)
|
||||
)
|
||||
search_results = self.client.query_points(
|
||||
collection_name=self.collection_name,
|
||||
query=query_vector,
|
||||
query_filter=search_filter,
|
||||
limit=self.limit,
|
||||
score_threshold=self.score_threshold,
|
||||
)
|
||||
|
||||
# Format results similar to storage implementation
|
||||
results = []
|
||||
# Extract the list of ScoredPoint objects from the tuple
|
||||
for point in search_results:
|
||||
result = {
|
||||
"metadata": point[1][0].payload.get("metadata", {}),
|
||||
"context": point[1][0].payload.get("text", ""),
|
||||
"distance": point[1][0].score,
|
||||
}
|
||||
results.append(result)
|
||||
|
||||
return json.dumps(results, indent=2)
|
||||
|
||||
def _vectorize_query(self, query: str) -> list[float]:
|
||||
"""Default vectorization function with openai.
|
||||
|
||||
Args:
|
||||
query (str): The query to vectorize
|
||||
|
||||
Returns:
|
||||
list[float]: The vectorized query
|
||||
"""
|
||||
import openai
|
||||
|
||||
client = openai.Client(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
embedding = (
|
||||
client.embeddings.create(
|
||||
input=[query],
|
||||
model="text-embedding-3-small",
|
||||
)
|
||||
|
||||
results = [
|
||||
{
|
||||
"id": point.id,
|
||||
"metadata": point.metadata,
|
||||
"context": point.document,
|
||||
"score": point.score,
|
||||
}
|
||||
for point in search_results
|
||||
]
|
||||
|
||||
if not results:
|
||||
return json.dumps({"message": "No results found", "results": []})
|
||||
|
||||
return json.dumps(results, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error executing Qdrant search: {str(e)}")
|
||||
.data[0]
|
||||
.embedding
|
||||
)
|
||||
return embedding
|
||||
|
||||
Reference in New Issue
Block a user