Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-08 15:48:29 +00:00
Add MongoDB Vector Search Tool (#319)
* INTPYTHON-580 Design and Implement MongoDBVectorSearchTool
* add implementation
* wip
* wip
* finish tests
* add todo
* refactor to wrap langchain-mongodb
* cleanup
* address review
* Fix usage of EnvVar class
* inline code
* lint
* lint
* fix usage of SearchIndexModel
* Refactor: Update EnvVar import path and remove unused tests.utils module
  - Changed import of EnvVar from tests.utils to crewai.tools in multiple files.
  - Updated README.md for MongoDB vector search tool with additional context.
  - Modified subprocess command in vector_search.py for package installation.
  - Cleaned up test_generate_tool_specs.py to improve mock patching syntax.
  - Deleted unused tests/utils.py file.
* update the crewai dep and the lockfile
* chore: update package versions and dependencies in uv.lock
  - Removed `auth0-python` package.
  - Updated `crewai` version to 0.140.0 and adjusted its dependencies.
  - Changed `json-repair` version to 0.25.2.
  - Updated `litellm` version to 1.72.6.
  - Modified dependency markers for several packages to improve compatibility with Python versions.
* refactor: improve MongoDB vector search tool with enhanced error handling and new dimensions field
  - Added logging for error handling in the _run method and during client cleanup.
  - Introduced a new 'dimensions' field for the embedding vector size.
  - Refactored the _run method to return JSON formatted results and handle exceptions gracefully.
  - Cleaned up import statements and improved code readability.
* address review
* update tests
* debug
* fix test
* fix test
* fix test
* support azure openai

---------

Co-authored-by: lorenzejay <lorenzejaytech@gmail.com>
README.md
@@ -25,6 +25,7 @@ CrewAI provides an extensive collection of powerful tools ready to enhance your
- **File Management**: `FileReadTool`, `FileWriteTool`
- **Web Scraping**: `ScrapeWebsiteTool`, `SeleniumScrapingTool`
- **Database Integrations**: `PGSearchTool`, `MySQLSearchTool`
- **Vector Database Integrations**: `MongoDBVectorSearchTool`, `QdrantVectorSearchTool`, `WeaviateVectorSearchTool`
- **API Integrations**: `SerperApiTool`, `EXASearchTool`
- **AI-powered Tools**: `DallETool`, `VisionTool`, `StagehandTool`

@@ -226,4 +227,3 @@ Join our rapidly growing community and receive real-time support:
- [Open an Issue](https://github.com/crewAIInc/crewAI/issues)

Build smarter, faster, and more powerful AI solutions—powered by CrewAI Tools.
src/crewai_tools/__init__.py
@@ -1,5 +1,6 @@
from .adapters.enterprise_adapter import EnterpriseActionTool
from .adapters.mcp_adapter import MCPServerAdapter
from .adapters.zapier_adapter import ZapierActionTool
from .aws import (
    BedrockInvokeAgentTool,
    BedrockKBRetrieverTool,
@@ -23,9 +24,9 @@ from .tools import (
    DirectorySearchTool,
    DOCXSearchTool,
    EXASearchTool,
    FileCompressorTool,
    FileReadTool,
    FileWriterTool,
    FileCompressorTool,
    FirecrawlCrawlWebsiteTool,
    FirecrawlScrapeWebsiteTool,
    FirecrawlSearchTool,
@@ -35,6 +36,8 @@ from .tools import (
    LinkupSearchTool,
    LlamaIndexTool,
    MDXSearchTool,
    MongoDBVectorSearchConfig,
    MongoDBVectorSearchTool,
    MultiOnTool,
    MySQLSearchTool,
    NL2SQLTool,
@@ -76,4 +79,3 @@ from .tools import (
    YoutubeVideoSearchTool,
    ZapierActionTools,
)
from .adapters.zapier_adapter import ZapierActionTool
src/crewai_tools/tools/__init__.py
@@ -16,10 +16,10 @@ from .docx_search_tool.docx_search_tool import DOCXSearchTool
from .exa_tools.exa_search_tool import EXASearchTool
from .file_read_tool.file_read_tool import FileReadTool
from .file_writer_tool.file_writer_tool import FileWriterTool
from .files_compressor_tool.files_compressor_tool import FileCompressorTool
from .firecrawl_crawl_website_tool.firecrawl_crawl_website_tool import (
    FirecrawlCrawlWebsiteTool,
)
from .files_compressor_tool.files_compressor_tool import FileCompressorTool
from .firecrawl_scrape_website_tool.firecrawl_scrape_website_tool import (
    FirecrawlScrapeWebsiteTool,
)
@@ -30,6 +30,11 @@ from .json_search_tool.json_search_tool import JSONSearchTool
from .linkup.linkup_search_tool import LinkupSearchTool
from .llamaindex_tool.llamaindex_tool import LlamaIndexTool
from .mdx_search_tool.mdx_search_tool import MDXSearchTool
from .mongodb_vector_search_tool import (
    MongoDBToolSchema,
    MongoDBVectorSearchConfig,
    MongoDBVectorSearchTool,
)
from .multion_tool.multion_tool import MultiOnTool
from .mysql_search_tool.mysql_search_tool import MySQLSearchTool
from .nl2sql.nl2sql_tool import NL2SQLTool
src/crewai_tools/tools/mongodb_vector_search_tool/README.md (new file, 87 lines)
@@ -0,0 +1,87 @@
# MongoDBVectorSearchTool

## Description

This tool is specifically crafted for conducting vector searches over documents stored in a MongoDB database. Use this tool to find documents that are semantically similar to a given query.

MongoDB can act as a vector database that is used to store and query vector embeddings. You can follow the docs here:
https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-overview/
## Installation

Install the crewai_tools package with MongoDB support by executing the following command in your terminal:

```shell
pip install crewai-tools[mongodb]
```

or

```shell
uv add crewai-tools --extra mongodb
```
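
The tool generates embeddings with OpenAI, so OpenAI credentials must be present in the environment before the tool is constructed (otherwise it raises a `ValueError`). A minimal sketch, assuming a standard OpenAI key; if you use Azure OpenAI, set `AZURE_OPENAI_ENDPOINT` and the related Azure variables instead:

```python
import os

# Assumed setup: the tool reads OPENAI_API_KEY (or AZURE_OPENAI_ENDPOINT) at construction time.
os.environ["OPENAI_API_KEY"] = "<your_openai_api_key>"
```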
## Example

To utilize the MongoDBVectorSearchTool for different use cases, follow these examples:

```python
from crewai_tools import MongoDBVectorSearchTool

# Search the configured MongoDB collection for documents relevant to the agent's queries
tool = MongoDBVectorSearchTool(
    database_name="example_database",
    collection_name="example_collections",
    connection_string="<your_mongodb_connection_string>",
)
```
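
When the agent calls the tool, it embeds the query, runs an Atlas `$vectorSearch` aggregation against the collection, and returns the matching documents as a JSON string. A minimal sketch of calling it directly (agents normally invoke the tool for you; the exact fields in each result depend on your documents):

```python
import json

# Hypothetical direct invocation of the tool's run method.
raw = tool._run(query="How do I reset my password?")
for doc in json.loads(raw):
    print(doc.get("text"), doc.get("score"))
```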
or

```python
from crewai import Agent
from crewai_tools import MongoDBVectorSearchConfig, MongoDBVectorSearchTool

# Set up a custom embedding model and customize the query parameters.
query_config = MongoDBVectorSearchConfig(limit=10, oversampling_factor=2)
tool = MongoDBVectorSearchTool(
    database_name="example_database",
    collection_name="example_collections",
    connection_string="<your_mongodb_connection_string>",
    query_config=query_config,
    vector_index_name="my_vector_index",
    embedding_model="text-embedding-3-large",
)

# Adding the tool to an agent
rag_agent = Agent(
    name="rag_agent",
    role="You are a helpful assistant that can answer questions with the help of the MongoDBVectorSearchTool.",
    goal="...",
    backstory="...",
    llm="gpt-4o-mini",
    tools=[tool],
)
```
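
To actually execute the agent, wire it into a task and a crew as usual. A short sketch, assuming the standard `crewai` Task/Crew API and the `rag_agent` defined above (the task text and inputs are illustrative):

```python
from crewai import Crew, Task

# Hypothetical task that relies on the MongoDB vector search tool for retrieval.
task = Task(
    description="Answer the user's question using documents retrieved from MongoDB: {question}",
    expected_output="A concise answer grounded in the retrieved documents.",
    agent=rag_agent,
)

crew = Crew(agents=[rag_agent], tasks=[task])
result = crew.kickoff(inputs={"question": "What is our refund policy?"})
print(result)
```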
Preloading the MongoDB database with documents:

```python
import os

from crewai_tools import MongoDBVectorSearchTool

# Create the tool.
tool = MongoDBVectorSearchTool(
    database_name="example_database",
    collection_name="example_collections",
    connection_string="<your_mongodb_connection_string>",
)

# Add the text from a set of CrewAI knowledge documents to the MongoDB database.
texts = []
for d in os.listdir("knowledge"):
    with open(os.path.join("knowledge", d), "r") as f:
        texts.append(f.read())
tool.add_texts(texts)

# Create the vector search index (if it wasn't already created in Atlas).
tool.create_vector_search_index(dimensions=3072)
```
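
Note that the index dimensionality has to match the size of the embeddings the tool writes. `dimensions=3072` above assumes the default `text-embedding-3-large` model at its full size; the tool also has its own `dimensions` field (default 1536) that it passes to the embeddings API, so it is safest to set both explicitly. A sketch, with assumed values:

```python
# Keep the embedding size and the Atlas index definition in sync (assumed values).
tool = MongoDBVectorSearchTool(
    database_name="example_database",
    collection_name="example_collections",
    connection_string="<your_mongodb_connection_string>",
    dimensions=3072,
)
tool.create_vector_search_index(dimensions=3072)
```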
src/crewai_tools/tools/mongodb_vector_search_tool/__init__.py (new file, 11 lines)
@@ -0,0 +1,11 @@
from .vector_search import (
    MongoDBToolSchema,
    MongoDBVectorSearchConfig,
    MongoDBVectorSearchTool,
)

__all__ = [
    "MongoDBVectorSearchConfig",
    "MongoDBVectorSearchTool",
    "MongoDBToolSchema",
]
src/crewai_tools/tools/mongodb_vector_search_tool/utils.py (new file, 120 lines)
@@ -0,0 +1,120 @@
from __future__ import annotations

from time import monotonic, sleep
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional

if TYPE_CHECKING:
    from pymongo.collection import Collection


def _vector_search_index_definition(
    dimensions: int,
    path: str,
    similarity: str,
    filters: Optional[List[str]] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    # https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/
    fields = [
        {
            "numDimensions": dimensions,
            "path": path,
            "similarity": similarity,
            "type": "vector",
        },
    ]
    if filters:
        for field in filters:
            fields.append({"type": "filter", "path": field})
    definition = {"fields": fields}
    definition.update(kwargs)
    return definition

def create_vector_search_index(
    collection: Collection,
    index_name: str,
    dimensions: int,
    path: str,
    similarity: str,
    filters: Optional[List[str]] = None,
    *,
    wait_until_complete: Optional[float] = None,
    **kwargs: Any,
) -> None:
    """Experimental utility function to create a vector search index.

    Args:
        collection (Collection): MongoDB Collection
        index_name (str): Name of Index
        dimensions (int): Number of dimensions in embedding
        path (str): Field with vector embedding
        similarity (str): The similarity score used for the index
        filters (List[str]): Fields/paths to index to allow filtering in $vectorSearch
        wait_until_complete (Optional[float]): If provided, number of seconds to wait
            until search index is ready.
        kwargs: Keyword arguments supplying any additional options to SearchIndexModel.
    """
    from pymongo.operations import SearchIndexModel

    if collection.name not in collection.database.list_collection_names():
        collection.database.create_collection(collection.name)

    result = collection.create_search_index(
        SearchIndexModel(
            definition=_vector_search_index_definition(
                dimensions=dimensions,
                path=path,
                similarity=similarity,
                filters=filters,
                **kwargs,
            ),
            name=index_name,
            type="vectorSearch",
        )
    )

    if wait_until_complete:
        _wait_for_predicate(
            predicate=lambda: _is_index_ready(collection, index_name),
            err=f"{index_name=} did not complete in {wait_until_complete}!",
            timeout=wait_until_complete,
        )

def _is_index_ready(collection: Collection, index_name: str) -> bool:
    """Check the list of available search indexes to see if the
    specified index has status READY.

    Args:
        collection (Collection): MongoDB Collection to check for the search indexes
        index_name (str): Vector Search Index name

    Returns:
        bool: True if the index is present and READY, False otherwise
    """
    for index in collection.list_search_indexes(index_name):
        if index["status"] == "READY":
            return True
    return False


def _wait_for_predicate(
    predicate: Callable, err: str, timeout: float = 120, interval: float = 0.5
) -> None:
    """Generic helper to block until the predicate returns True.

    Args:
        predicate (Callable[, bool]): A function that returns a boolean value
        err (str): Error message to raise if the predicate never succeeds
        timeout (float, optional): Wait time for predicate. Defaults to 120.
        interval (float, optional): Interval to check predicate. Defaults to 0.5.

    Raises:
        TimeoutError: If the predicate does not return True within `timeout` seconds.
    """
    start = monotonic()
    while not predicate():
        if monotonic() - start > timeout:
            raise TimeoutError(err)
        sleep(interval)
src/crewai_tools/tools/mongodb_vector_search_tool/vector_search.py (new file, 326 lines)
@@ -0,0 +1,326 @@
import json
import os
from importlib.metadata import version
from logging import getLogger
from typing import Any, Dict, Iterable, List, Optional, Type

from crewai.tools import BaseTool, EnvVar
from openai import AzureOpenAI, Client
from pydantic import BaseModel, Field

from crewai_tools.tools.mongodb_vector_search_tool.utils import (
    create_vector_search_index,
)

try:
    import pymongo  # noqa: F403

    MONGODB_AVAILABLE = True
except ImportError:
    MONGODB_AVAILABLE = False

logger = getLogger(__name__)


class MongoDBVectorSearchConfig(BaseModel):
    """Configuration for MongoDB vector search queries."""

    limit: Optional[int] = Field(
        default=4, description="Number of documents to return."
    )
    pre_filter: Optional[dict[str, Any]] = Field(
        default=None,
        description="MQL match expression(s) comparing an indexed field, applied before the vector search.",
    )
    post_filter_pipeline: Optional[list[dict]] = Field(
        default=None,
        description="Pipeline of MongoDB aggregation stages to filter/process results after $vectorSearch.",
    )
    oversampling_factor: int = Field(
        default=10,
        description="Multiple of limit used when generating the number of candidates at each step in the HNSW vector search.",
    )
    include_embeddings: bool = Field(
        default=False,
        description="Whether to include the embedding vector of each result in metadata.",
    )


class MongoDBToolSchema(MongoDBVectorSearchConfig):
    """Input for MongoDBVectorSearchTool."""

    query: str = Field(
        ...,
        description="The query used to retrieve relevant information from the MongoDB database. Pass only the query, not the question.",
    )


class MongoDBVectorSearchTool(BaseTool):
    """Tool to perform a vector search on a MongoDB database."""

    name: str = "MongoDBVectorSearchTool"
    description: str = "A tool to perform a vector search on a MongoDB database for relevant information on internal documents."

    args_schema: Type[BaseModel] = MongoDBToolSchema
    query_config: Optional[MongoDBVectorSearchConfig] = Field(
        default=None, description="MongoDB Vector Search query configuration"
    )
    embedding_model: str = Field(
        default="text-embedding-3-large",
        description="OpenAI text embedding model to use",
    )
    vector_index_name: str = Field(
        default="vector_index", description="Name of the Atlas Search vector index"
    )
    text_key: str = Field(
        default="text",
        description="MongoDB field that will contain the text for each document",
    )
    embedding_key: str = Field(
        default="embedding",
        description="Field that will contain the embedding for each document",
    )
    database_name: str = Field(..., description="The name of the MongoDB database")
    collection_name: str = Field(..., description="The name of the MongoDB collection")
    connection_string: str = Field(
        ...,
        description="The connection string of the MongoDB cluster",
    )
    dimensions: int = Field(
        default=1536,
        description="Number of dimensions in the embedding vector",
    )
    env_vars: List[EnvVar] = [
        EnvVar(
            name="OPENAI_API_KEY",
            description="OpenAI API key used to generate embeddings",
            required=False,
        ),
        EnvVar(
            name="AZURE_OPENAI_ENDPOINT",
            description="Azure OpenAI endpoint, used instead of OPENAI_API_KEY when set",
            required=False,
        ),
    ]
    package_dependencies: List[str] = ["pymongo"]

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if not MONGODB_AVAILABLE:
            import click

            if click.confirm(
                "You are missing the 'mongodb' crewai tool. Would you like to install it?"
            ):
                import subprocess

                subprocess.run(["uv", "add", "pymongo"], check=True)

            else:
                raise ImportError("You are missing the 'mongodb' crewai tool.")

        if "AZURE_OPENAI_ENDPOINT" in os.environ:
            self._openai_client = AzureOpenAI()
        elif "OPENAI_API_KEY" in os.environ:
            self._openai_client = Client()
        else:
            raise ValueError(
                "The OPENAI_API_KEY (or AZURE_OPENAI_ENDPOINT) environment variable is required to use MongoDBVectorSearchTool."
            )

        from pymongo import MongoClient
        from pymongo.driver_info import DriverInfo

        self._client = MongoClient(
            self.connection_string,
            driver=DriverInfo(name="CrewAI", version=version("crewai-tools")),
        )
        self._coll = self._client[self.database_name][self.collection_name]

|
||||
self,
|
||||
*,
|
||||
dimensions: int,
|
||||
relevance_score_fn: str = "cosine",
|
||||
auto_index_timeout: int = 15,
|
||||
) -> None:
|
||||
"""Convenience function to create a vector search index.
|
||||
|
||||
Args:
|
||||
dimensions: Number of dimensions in embedding. If the value is set and
|
||||
the index does not exist, an index will be created.
|
||||
relevance_score_fn: The similarity score used for the index
|
||||
Currently supported: 'euclidean', 'cosine', and 'dotProduct'
|
||||
auto_index_timeout: Timeout in seconds to wait for an auto-created index
|
||||
to be ready.
|
||||
"""
|
||||
|
||||
create_vector_search_index(
|
||||
collection=self._coll,
|
||||
index_name=self.vector_index_name,
|
||||
dimensions=dimensions,
|
||||
path=self.embedding_key,
|
||||
similarity=relevance_score_fn,
|
||||
wait_until_complete=auto_index_timeout,
|
||||
)
|
||||
|
||||
    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[Dict[str, Any]]] = None,
        ids: Optional[List[str]] = None,
        batch_size: int = 100,
        **kwargs: Any,
    ) -> List[str]:
        """Add texts, create embeddings, and add to the Collection and index.

        Important notes on ids:
            - If _id or id is a key in the metadatas dicts, one must
                pop them and provide them as a separate list.
            - They must be unique.
            - If they are not provided, unique ones are created,
                stored as bson.ObjectIds internally and returned as strings.
                These appear in the stored documents under the '_id' key.

        Args:
            texts: Iterable of strings to add to the collection.
            metadatas: Optional list of metadatas associated with the texts.
            ids: Optional list of unique ids used as the document _id values.
                See note on ids.
            batch_size: Number of documents to insert at a time.
                Tuning this may help with performance and sidestep MongoDB limits.

        Returns:
            List of ids added to the collection.
        """
        from bson import ObjectId

        texts = list(texts)
        _metadatas = metadatas or [{} for _ in texts]
        ids = ids or [str(ObjectId()) for _ in texts]

        result_ids = []
        texts_batch = []
        metadatas_batch = []
        size = 0
        i = 0
        for j, (text, metadata) in enumerate(zip(texts, _metadatas)):
            size += len(text) + len(metadata)
            texts_batch.append(text)
            metadatas_batch.append(metadata)
            if (j + 1) % batch_size == 0 or size >= 47_000_000:
                batch_res = self._bulk_embed_and_insert_texts(
                    texts_batch, metadatas_batch, ids[i : j + 1]
                )
                result_ids.extend(batch_res)
                texts_batch = []
                metadatas_batch = []
                size = 0
                i = j + 1
        if texts_batch:
            batch_res = self._bulk_embed_and_insert_texts(
                texts_batch, metadatas_batch, ids[i : j + 1]
            )
            result_ids.extend(batch_res)
        return result_ids

    def _embed_texts(self, texts: List[str]) -> List[List[float]]:
        return [
            i.embedding
            for i in self._openai_client.embeddings.create(
                input=texts,
                model=self.embedding_model,
                dimensions=self.dimensions,
            ).data
        ]

    def _bulk_embed_and_insert_texts(
        self,
        texts: List[str],
        metadatas: List[dict],
        ids: List[str],
    ) -> List[str]:
        """Bulk insert a single batch of texts, embeddings, and ids."""
        from bson import ObjectId
        from pymongo.operations import ReplaceOne

        if not texts:
            return []
        # Compute embedding vectors
        embeddings = self._embed_texts(texts)
        docs = [
            {
                "_id": ObjectId(i),
                self.text_key: t,
                self.embedding_key: embedding,
                **m,
            }
            for i, t, m, embedding in zip(ids, texts, metadatas, embeddings)
        ]
        operations = [ReplaceOne({"_id": doc["_id"]}, doc, upsert=True) for doc in docs]
        # Insert the documents into MongoDB Atlas
        result = self._coll.bulk_write(operations)
        assert result.upserted_ids is not None
        return [str(_id) for _id in result.upserted_ids.values()]

    def _run(self, query: str) -> str:
        try:
            query_config = self.query_config or MongoDBVectorSearchConfig()
            limit = query_config.limit
            oversampling_factor = query_config.oversampling_factor
            pre_filter = query_config.pre_filter
            include_embeddings = query_config.include_embeddings
            post_filter_pipeline = query_config.post_filter_pipeline

            # Create the embedding for the query
            query_vector = self._embed_texts([query])[0]

            # Atlas Vector Search, potentially with filter
            stage = {
                "index": self.vector_index_name,
                "path": self.embedding_key,
                "queryVector": query_vector,
                "numCandidates": limit * oversampling_factor,
                "limit": limit,
            }
            if pre_filter:
                stage["filter"] = pre_filter

            pipeline = [
                {"$vectorSearch": stage},
                {"$set": {"score": {"$meta": "vectorSearchScore"}}},
            ]

            # Remove embeddings unless requested
            if not include_embeddings:
                pipeline.append({"$project": {self.embedding_key: 0}})

            # Post-processing
            if post_filter_pipeline is not None:
                pipeline.extend(post_filter_pipeline)

            # Execution
            cursor = self._coll.aggregate(pipeline)  # type: ignore[arg-type]
            docs = []

            # Format the results as a JSON string; default=str handles
            # non-JSON-serializable values such as ObjectId _ids.
            for doc in cursor:
                docs.append(doc)
            return json.dumps(docs, default=str)
        except Exception as e:
            logger.error(f"Error: {e}")
            return ""

    def __del__(self):
        """Cleanup clients on deletion."""
        try:
            if hasattr(self, "_client") and self._client:
                self._client.close()
        except Exception as e:
            logger.error(f"Error: {e}")

        try:
            if hasattr(self, "_openai_client") and self._openai_client:
                self._openai_client.close()
        except Exception as e:
            logger.error(f"Error: {e}")
test_generate_tool_specs.py
@@ -1,12 +1,13 @@
import json
from typing import List, Optional, Type

import pytest
from pydantic import BaseModel, Field
from unittest import mock

from generate_tool_specs import ToolSpecExtractor
import pytest
from crewai.tools.base_tool import BaseTool, EnvVar
from pydantic import BaseModel, Field

from generate_tool_specs import ToolSpecExtractor


class MockToolSchema(BaseModel):
    query: str = Field(..., description="The query parameter")
@@ -19,15 +20,30 @@ class MockTool(BaseTool):
    description: str = "A tool that mocks search functionality"
    args_schema: Type[BaseModel] = MockToolSchema

    another_parameter: str = Field("Another way to define a default value", description="")
    another_parameter: str = Field(
        "Another way to define a default value", description=""
    )
    my_parameter: str = Field("This is default value", description="What a description")
    my_parameter_bool: bool = Field(False)
    package_dependencies: List[str] = Field(["this-is-a-required-package", "another-required-package"], description="")
    package_dependencies: List[str] = Field(
        ["this-is-a-required-package", "another-required-package"], description=""
    )
    env_vars: List[EnvVar] = [
        EnvVar(name="SERPER_API_KEY", description="API key for Serper", required=True, default=None),
        EnvVar(name="API_RATE_LIMIT", description="API rate limit", required=False, default="100")
        EnvVar(
            name="SERPER_API_KEY",
            description="API key for Serper",
            required=True,
            default=None,
        ),
        EnvVar(
            name="API_RATE_LIMIT",
            description="API rate limit",
            required=False,
            default="100",
        ),
    ]


@pytest.fixture
def extractor():
    ext = ToolSpecExtractor()
@@ -37,7 +53,7 @@ def extractor():
def test_unwrap_schema(extractor):
    nested_schema = {
        "type": "function-after",
        "schema": {"type": "default", "schema": {"type": "str", "value": "test"}}
        "schema": {"type": "default", "schema": {"type": "str", "value": "test"}},
    }
    result = extractor._unwrap_schema(nested_schema)
    assert result["type"] == "str"
@@ -46,12 +62,15 @@ def test_unwrap_schema(extractor):

@pytest.fixture
def mock_tool_extractor(extractor):
    with mock.patch("generate_tool_specs.dir", return_value=["MockTool"]), \
        mock.patch("generate_tool_specs.getattr", return_value=MockTool):
    with (
        mock.patch("generate_tool_specs.dir", return_value=["MockTool"]),
        mock.patch("generate_tool_specs.getattr", return_value=MockTool),
    ):
        extractor.extract_all_tools()
        assert len(extractor.tools_spec) == 1
        return extractor.tools_spec[0]


def test_extract_basic_tool_info(mock_tool_extractor):
    tool_info = mock_tool_extractor

@@ -69,6 +88,7 @@ def test_extract_basic_tool_info(mock_tool_extractor):
    assert tool_info["humanized_name"] == "Mock Search Tool"
    assert tool_info["description"] == "A tool that mocks search functionality"


def test_extract_init_params_schema(mock_tool_extractor):
    tool_info = mock_tool_extractor
    init_params_schema = tool_info["init_params_schema"]
@@ -80,20 +100,21 @@ def test_extract_init_params_schema(mock_tool_extractor):
        "type",
    }

    another_parameter = init_params_schema['properties']['another_parameter']
    another_parameter = init_params_schema["properties"]["another_parameter"]
    assert another_parameter["description"] == ""
    assert another_parameter["default"] == "Another way to define a default value"
    assert another_parameter["type"] == "string"

    my_parameter = init_params_schema['properties']['my_parameter']
    my_parameter = init_params_schema["properties"]["my_parameter"]
    assert my_parameter["description"] == "What a description"
    assert my_parameter["default"] == "This is default value"
    assert my_parameter["type"] == "string"

    my_parameter_bool = init_params_schema['properties']['my_parameter_bool']
    my_parameter_bool = init_params_schema["properties"]["my_parameter_bool"]
    assert my_parameter_bool["default"] == False
    assert my_parameter_bool["type"] == "boolean"


def test_extract_env_vars(mock_tool_extractor):
    tool_info = mock_tool_extractor

@@ -109,6 +130,7 @@ def test_extract_env_vars(mock_tool_extractor):
    assert rate_limit_var["required"] == False
    assert rate_limit_var["default"] == "100"


def test_extract_run_params_schema(mock_tool_extractor):
    tool_info = mock_tool_extractor

@@ -131,22 +153,31 @@ def test_extract_run_params_schema(mock_tool_extractor):
    filters_param = run_params_schema["properties"]["filters"]
    assert filters_param["description"] == "Optional filters to apply"
    assert filters_param["default"] == None
    assert filters_param['anyOf'] == [{'items': {'type': 'string'}, 'type': 'array'}, {'type': 'null'}]
    assert filters_param["anyOf"] == [
        {"items": {"type": "string"}, "type": "array"},
        {"type": "null"},
    ]


def test_extract_package_dependencies(mock_tool_extractor):
    tool_info = mock_tool_extractor
    assert tool_info["package_dependencies"] == ["this-is-a-required-package", "another-required-package"]
    assert tool_info["package_dependencies"] == [
        "this-is-a-required-package",
        "another-required-package",
    ]


def test_save_to_json(extractor, tmp_path):
    extractor.tools_spec = [{
        "name": "TestTool",
        "humanized_name": "Test Tool",
        "description": "A test tool",
        "run_params_schema": [
            {"name": "param1", "description": "Test parameter", "type": "str"}
        ]
    }]
    extractor.tools_spec = [
        {
            "name": "TestTool",
            "humanized_name": "Test Tool",
            "description": "A test tool",
            "run_params_schema": [
                {"name": "param1", "description": "Test parameter", "type": "str"}
            ],
        }
    ]

    file_path = tmp_path / "output.json"
    extractor.save_to_json(str(file_path))

tests/tools/test_mongodb_vector_search_tool.py (new file, 75 lines)
@@ -0,0 +1,75 @@
import json
from unittest.mock import patch

import pytest

from crewai_tools import MongoDBVectorSearchConfig, MongoDBVectorSearchTool


# Unit Test Fixtures
@pytest.fixture
def mongodb_vector_search_tool():
    tool = MongoDBVectorSearchTool(
        connection_string="foo", database_name="bar", collection_name="test"
    )
    tool._embed_texts = lambda x: [[0.1]]
    yield tool


# Unit Tests
def test_successful_query_execution(mongodb_vector_search_tool):
    # Patch the Atlas aggregation call
    with patch.object(mongodb_vector_search_tool._coll, "aggregate") as mock_aggregate:
        mock_aggregate.return_value = [dict(text="foo", score=0.1, _id=1)]

        results = json.loads(mongodb_vector_search_tool._run(query="sandwiches"))

        assert len(results) == 1
        assert results[0]["text"] == "foo"
        assert results[0]["_id"] == 1


def test_provide_config():
    query_config = MongoDBVectorSearchConfig(limit=10)
    tool = MongoDBVectorSearchTool(
        connection_string="foo",
        database_name="bar",
        collection_name="test",
        query_config=query_config,
        vector_index_name="foo",
        embedding_model="bar",
    )
    tool._embed_texts = lambda x: [[0.1]]
    with patch.object(tool._coll, "aggregate") as mock_aggregate:
        mock_aggregate.return_value = [dict(text="foo", score=0.1, _id=1)]

        tool._run(query="sandwiches")
        assert mock_aggregate.mock_calls[-1].args[0][0]["$vectorSearch"]["limit"] == 10


|
||||
with patch.object(mongodb_vector_search_tool, "_client") as mock_client:
|
||||
# Trigger cleanup
|
||||
mongodb_vector_search_tool.__del__()
|
||||
|
||||
mock_client.close.assert_called_once()
|
||||
|
||||
|
||||
def test_create_search_index(mongodb_vector_search_tool):
|
||||
with patch(
|
||||
"crewai_tools.tools.mongodb_vector_search_tool.vector_search.create_vector_search_index"
|
||||
) as mock_create_search_index:
|
||||
mongodb_vector_search_tool.create_vector_search_index(dimensions=10)
|
||||
kwargs = mock_create_search_index.mock_calls[0].kwargs
|
||||
assert kwargs["dimensions"] == 10
|
||||
assert kwargs["similarity"] == "cosine"
|
||||
|
||||
|
||||
def test_add_texts(mongodb_vector_search_tool):
|
||||
with patch.object(mongodb_vector_search_tool._coll, "bulk_write") as bulk_write:
|
||||
mongodb_vector_search_tool.add_texts(["foo"])
|
||||
args = bulk_write.mock_calls[0].args
|
||||
assert "ReplaceOne" in str(args[0][0])
|
||||
assert "foo" in str(args[0][0])
|
||||