Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-10 00:28:31 +00:00.
Add MongoDB Vector Search Tool (#319)
* INTPYTHON-580 Design and Implement MongoDBVectorSearchTool
* add implementation
* wip
* wip
* finish tests
* add todo
* refactor to wrap langchain-mongodb
* cleanup
* address review
* Fix usage of EnvVar class
* inline code
* lint
* lint
* fix usage of SearchIndexModel
* Refactor: Update EnvVar import path and remove unused tests.utils module
  - Changed import of EnvVar from tests.utils to crewai.tools in multiple files.
  - Updated README.md for MongoDB vector search tool with additional context.
  - Modified subprocess command in vector_search.py for package installation.
  - Cleaned up test_generate_tool_specs.py to improve mock patching syntax.
  - Deleted unused tests/utils.py file.
* update the crewai dep and the lockfile
* chore: update package versions and dependencies in uv.lock
  - Removed `auth0-python` package.
  - Updated `crewai` version to 0.140.0 and adjusted its dependencies.
  - Changed `json-repair` version to 0.25.2.
  - Updated `litellm` version to 1.72.6.
  - Modified dependency markers for several packages to improve compatibility with Python versions.
* refactor: improve MongoDB vector search tool with enhanced error handling and new dimensions field
  - Added logging for error handling in the _run method and during client cleanup.
  - Introduced a new 'dimensions' field in the MongoDBVectorSearchConfig for embedding vector size.
  - Refactored the _run method to return JSON formatted results and handle exceptions gracefully.
  - Cleaned up import statements and improved code readability.
* address review
* update tests
* debug
* fix test
* fix test
* fix test
* support azure openai

Co-authored-by: lorenzejay <lorenzejaytech@gmail.com>
```diff
@@ -25,6 +25,7 @@ CrewAI provides an extensive collection of powerful tools ready to enhance your
 - **File Management**: `FileReadTool`, `FileWriteTool`
 - **Web Scraping**: `ScrapeWebsiteTool`, `SeleniumScrapingTool`
 - **Database Integrations**: `PGSearchTool`, `MySQLSearchTool`
+- **Vector Database Integrations**: `MongoDBVectorSearchTool`, `QdrantVectorSearchTool`, `WeaviateVectorSearchTool`
 - **API Integrations**: `SerperApiTool`, `EXASearchTool`
 - **AI-powered Tools**: `DallETool`, `VisionTool`, `StagehandTool`
 
```
```diff
@@ -226,4 +227,3 @@ Join our rapidly growing community and receive real-time support:
 - [Open an Issue](https://github.com/crewAIInc/crewAI/issues)
 
 Build smarter, faster, and more powerful AI solutions—powered by CrewAI Tools.
-
```
```diff
@@ -1,5 +1,6 @@
 from .adapters.enterprise_adapter import EnterpriseActionTool
 from .adapters.mcp_adapter import MCPServerAdapter
+from .adapters.zapier_adapter import ZapierActionTool
 from .aws import (
     BedrockInvokeAgentTool,
     BedrockKBRetrieverTool,
@@ -23,9 +24,9 @@ from .tools import (
     DirectorySearchTool,
     DOCXSearchTool,
     EXASearchTool,
+    FileCompressorTool,
     FileReadTool,
     FileWriterTool,
-    FileCompressorTool,
     FirecrawlCrawlWebsiteTool,
     FirecrawlScrapeWebsiteTool,
     FirecrawlSearchTool,
@@ -35,6 +36,8 @@ from .tools import (
     LinkupSearchTool,
     LlamaIndexTool,
     MDXSearchTool,
+    MongoDBVectorSearchConfig,
+    MongoDBVectorSearchTool,
     MultiOnTool,
     MySQLSearchTool,
     NL2SQLTool,
@@ -76,4 +79,3 @@ from .tools import (
     YoutubeVideoSearchTool,
     ZapierActionTools,
 )
-from .adapters.zapier_adapter import ZapierActionTool
```
```diff
@@ -16,10 +16,10 @@ from .docx_search_tool.docx_search_tool import DOCXSearchTool
 from .exa_tools.exa_search_tool import EXASearchTool
 from .file_read_tool.file_read_tool import FileReadTool
 from .file_writer_tool.file_writer_tool import FileWriterTool
+from .files_compressor_tool.files_compressor_tool import FileCompressorTool
 from .firecrawl_crawl_website_tool.firecrawl_crawl_website_tool import (
     FirecrawlCrawlWebsiteTool,
 )
-from .files_compressor_tool.files_compressor_tool import FileCompressorTool
 from .firecrawl_scrape_website_tool.firecrawl_scrape_website_tool import (
     FirecrawlScrapeWebsiteTool,
 )
@@ -30,6 +30,11 @@ from .json_search_tool.json_search_tool import JSONSearchTool
 from .linkup.linkup_search_tool import LinkupSearchTool
 from .llamaindex_tool.llamaindex_tool import LlamaIndexTool
 from .mdx_search_tool.mdx_search_tool import MDXSearchTool
+from .mongodb_vector_search_tool import (
+    MongoDBToolSchema,
+    MongoDBVectorSearchConfig,
+    MongoDBVectorSearchTool,
+)
 from .multion_tool.multion_tool import MultiOnTool
 from .mysql_search_tool.mysql_search_tool import MySQLSearchTool
 from .nl2sql.nl2sql_tool import NL2SQLTool
```
src/crewai_tools/tools/mongodb_vector_search_tool/README.md (new file, 87 lines):
# MongoDBVectorSearchTool

## Description

This tool is specifically crafted for conducting vector searches over documents stored in a MongoDB database. Use it to find documents that are semantically similar to a given query.

MongoDB can act as a vector database for storing and querying vector embeddings. You can follow the docs here:
https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-overview/

## Installation

Install the crewai_tools package with MongoDB support by executing the following command in your terminal:

```shell
pip install crewai-tools[mongodb]
```

or

```shell
uv add crewai-tools --extra mongodb
```
## Example

To utilize the MongoDBVectorSearchTool for different use cases, follow these examples:

```python
from crewai_tools import MongoDBVectorSearchTool

# Create the tool with the minimal required connection settings.
tool = MongoDBVectorSearchTool(
    database_name="example_database",
    collection_name="example_collections",
    connection_string="<your_mongodb_connection_string>",
)
```

or

```python
from crewai import Agent
from crewai_tools import MongoDBVectorSearchConfig, MongoDBVectorSearchTool

# Customize the query parameters and point at a specific Atlas vector index.
query_config = MongoDBVectorSearchConfig(limit=10, oversampling_factor=2)
tool = MongoDBVectorSearchTool(
    database_name="example_database",
    collection_name="example_collections",
    connection_string="<your_mongodb_connection_string>",
    query_config=query_config,
    vector_index_name="my_vector_index",
)

# Adding the tool to an agent
rag_agent = Agent(
    name="rag_agent",
    role="You are a helpful assistant that can answer questions with the help of the MongoDBVectorSearchTool.",
    goal="...",
    backstory="...",
    llm="gpt-4o-mini",
    tools=[tool],
)
```
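
To execute the agent end to end, it can be wired into a crew. A minimal sketch, assuming the standard CrewAI `Task`/`Crew` API; the task description, expected output, and question are illustrative placeholders:

```python
from crewai import Crew, Task

# A hypothetical retrieval-and-answer task for the agent defined above.
task = Task(
    description="Answer the question using the MongoDB knowledge base: {question}",
    expected_output="A concise answer grounded in the retrieved documents.",
    agent=rag_agent,
)

crew = Crew(agents=[rag_agent], tasks=[task])
result = crew.kickoff(inputs={"question": "What is our refund policy?"})
print(result)
```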

Preloading the MongoDB database with documents:

```python
import os

from crewai_tools import MongoDBVectorSearchTool

# Create the tool.
tool = MongoDBVectorSearchTool(
    database_name="example_database",
    collection_name="example_collections",
    connection_string="<your_mongodb_connection_string>",
)

# Add the text from a set of CrewAI knowledge documents.
texts = []
for d in os.listdir("knowledge"):
    with open(os.path.join("knowledge", d), "r") as f:
        texts.append(f.read())
tool.add_texts(texts)

# Create the vector search index (if it wasn't already created in Atlas).
tool.create_vector_search_index(dimensions=3072)
```
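
Once documents and the index are in place, the tool can also be invoked directly to sanity-check retrieval. A small sketch, assuming the index has finished building; the question is a placeholder. `run()` validates the input against `MongoDBToolSchema` and dispatches to `_run`, which returns a JSON string:

```python
import json

raw = tool.run(query="What does the handbook say about vacation days?")
for doc in json.loads(raw):
    print(doc["score"], doc["text"][:80])
```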
src/crewai_tools/tools/mongodb_vector_search_tool/__init__.py (new file, 11 lines):
```python
from .vector_search import (
    MongoDBToolSchema,
    MongoDBVectorSearchConfig,
    MongoDBVectorSearchTool,
)

__all__ = [
    "MongoDBVectorSearchConfig",
    "MongoDBVectorSearchTool",
    "MongoDBToolSchema",
]
```
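
Together with the re-exports in `crewai_tools/tools/__init__.py` above, both import paths resolve to the same class, e.g.:

```python
# Package-root import (re-exported via crewai_tools.tools):
from crewai_tools import MongoDBVectorSearchTool
# Direct submodule import:
from crewai_tools.tools.mongodb_vector_search_tool import MongoDBVectorSearchTool
```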
src/crewai_tools/tools/mongodb_vector_search_tool/utils.py (new file, 120 lines):
```python
from __future__ import annotations

from time import monotonic, sleep
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional

if TYPE_CHECKING:
    from pymongo.collection import Collection


def _vector_search_index_definition(
    dimensions: int,
    path: str,
    similarity: str,
    filters: Optional[List[str]] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    # https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/
    fields = [
        {
            "numDimensions": dimensions,
            "path": path,
            "similarity": similarity,
            "type": "vector",
        },
    ]
    if filters:
        for field in filters:
            fields.append({"type": "filter", "path": field})
    definition = {"fields": fields}
    definition.update(kwargs)
    return definition


def create_vector_search_index(
    collection: Collection,
    index_name: str,
    dimensions: int,
    path: str,
    similarity: str,
    filters: Optional[List[str]] = None,
    *,
    wait_until_complete: Optional[float] = None,
    **kwargs: Any,
) -> None:
    """Experimental utility function to create a vector search index.

    Args:
        collection (Collection): MongoDB Collection
        index_name (str): Name of the index
        dimensions (int): Number of dimensions in the embedding
        path (str): Field that contains the vector embedding
        similarity (str): The similarity score used for the index
        filters (List[str]): Fields/paths to index to allow filtering in $vectorSearch
        wait_until_complete (Optional[float]): If provided, number of seconds to wait
            until the search index is ready.
        kwargs: Keyword arguments supplying any additional options to SearchIndexModel.
    """
    from pymongo.operations import SearchIndexModel

    if collection.name not in collection.database.list_collection_names():
        collection.database.create_collection(collection.name)

    collection.create_search_index(
        SearchIndexModel(
            definition=_vector_search_index_definition(
                dimensions=dimensions,
                path=path,
                similarity=similarity,
                filters=filters,
                **kwargs,
            ),
            name=index_name,
            type="vectorSearch",
        )
    )

    if wait_until_complete:
        _wait_for_predicate(
            predicate=lambda: _is_index_ready(collection, index_name),
            err=f"{index_name=} did not complete in {wait_until_complete}!",
            timeout=wait_until_complete,
        )


def _is_index_ready(collection: Collection, index_name: str) -> bool:
    """Check the list of available search indexes to see whether the
    specified index has status READY.

    Args:
        collection (Collection): MongoDB Collection to check for the search index
        index_name (str): Vector search index name

    Returns:
        bool: True if the index is present and READY, False otherwise
    """
    for index in collection.list_search_indexes(index_name):
        if index["status"] == "READY":
            return True
    return False


def _wait_for_predicate(
    predicate: Callable, err: str, timeout: float = 120, interval: float = 0.5
) -> None:
    """Generic helper to block until the predicate returns True.

    Args:
        predicate (Callable[..., bool]): A function that returns a boolean value
        err (str): Error message for the TimeoutError raised on timeout
        timeout (float, optional): Maximum time to wait for the predicate. Defaults to 120.
        interval (float, optional): Interval between predicate checks. Defaults to 0.5.

    Raises:
        TimeoutError: If the predicate does not return True within `timeout` seconds.
    """
    start = monotonic()
    while not predicate():
        if monotonic() - start > timeout:
            raise TimeoutError(err)
        sleep(interval)
```
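
As a usage sketch (not part of the commit), this helper can be driven directly with pymongo; the `MONGODB_URI` environment variable and the database/collection names below are illustrative placeholders:

```python
import os

from pymongo import MongoClient

from crewai_tools.tools.mongodb_vector_search_tool.utils import (
    create_vector_search_index,
)

client = MongoClient(os.environ["MONGODB_URI"])  # assumed env var
collection = client["example_database"]["example_collections"]

# Index the "embedding" field with cosine similarity and block for up to
# 60 seconds until Atlas reports the index READY.
create_vector_search_index(
    collection=collection,
    index_name="vector_index",
    dimensions=1536,
    path="embedding",
    similarity="cosine",
    wait_until_complete=60.0,
)
```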
src/crewai_tools/tools/mongodb_vector_search_tool/vector_search.py (new file, 326 lines):
```python
import json
import os
from importlib.metadata import version
from logging import getLogger
from typing import Any, Dict, Iterable, List, Optional, Type

from crewai.tools import BaseTool, EnvVar
from openai import AzureOpenAI, Client
from pydantic import BaseModel, Field

from crewai_tools.tools.mongodb_vector_search_tool.utils import (
    create_vector_search_index,
)

try:
    import pymongo  # noqa: F401

    MONGODB_AVAILABLE = True
except ImportError:
    MONGODB_AVAILABLE = False

logger = getLogger(__name__)


class MongoDBVectorSearchConfig(BaseModel):
    """Configuration for MongoDB vector search queries."""

    limit: Optional[int] = Field(
        default=4, description="Number of documents to return."
    )
    pre_filter: Optional[dict[str, Any]] = Field(
        default=None,
        description="MQL match expressions comparing indexed fields, applied inside $vectorSearch.",
    )
    post_filter_pipeline: Optional[list[dict]] = Field(
        default=None,
        description="Pipeline of MongoDB aggregation stages to filter/process results after $vectorSearch.",
    )
    oversampling_factor: int = Field(
        default=10,
        description="Multiple of limit used when generating the number of candidates at each step in the HNSW vector search",
    )
    include_embeddings: bool = Field(
        default=False,
        description="Whether to include the embedding vector of each result in metadata.",
    )


class MongoDBToolSchema(MongoDBVectorSearchConfig):
    """Input for MongoDBTool."""

    query: str = Field(
        ...,
        description="The query used to retrieve relevant information from the MongoDB database. Pass only the query, not the question.",
    )


class MongoDBVectorSearchTool(BaseTool):
    """Tool to perform a vector search on a MongoDB database."""

    name: str = "MongoDBVectorSearchTool"
    description: str = "A tool to perform a vector search on a MongoDB database for relevant information in internal documents."

    args_schema: Type[BaseModel] = MongoDBToolSchema
    query_config: Optional[MongoDBVectorSearchConfig] = Field(
        default=None, description="MongoDB Vector Search query configuration"
    )
    embedding_model: str = Field(
        default="text-embedding-3-large",
        description="OpenAI text embedding model to use",
    )
    vector_index_name: str = Field(
        default="vector_index", description="Name of the Atlas Search vector index"
    )
    text_key: str = Field(
        default="text",
        description="MongoDB field that will contain the text for each document",
    )
    embedding_key: str = Field(
        default="embedding",
        description="Field that will contain the embedding for each document",
    )
    database_name: str = Field(..., description="The name of the MongoDB database")
    collection_name: str = Field(..., description="The name of the MongoDB collection")
    connection_string: str = Field(
        ...,
        description="The connection string of the MongoDB cluster",
    )
    dimensions: int = Field(
        default=1536,
        description="Number of dimensions in the embedding vector",
    )
    env_vars: List[EnvVar] = [
        EnvVar(
            name="OPENAI_API_KEY",
            description="OpenAI API key used to compute embeddings (required unless Azure OpenAI is configured)",
            required=False,
        ),
        EnvVar(
            name="AZURE_OPENAI_ENDPOINT",
            description="Azure OpenAI endpoint, used instead of OPENAI_API_KEY when set",
            required=False,
        ),
    ]
    package_dependencies: List[str] = ["pymongo"]

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if not MONGODB_AVAILABLE:
            import click

            if click.confirm(
                "You are missing the 'pymongo' package, required by the MongoDB tool. Would you like to install it?"
            ):
                import subprocess

                subprocess.run(["uv", "add", "pymongo"], check=True)

            else:
                raise ImportError(
                    "The 'pymongo' package is required to use the MongoDB tool."
                )

        if "AZURE_OPENAI_ENDPOINT" in os.environ:
            self._openai_client = AzureOpenAI()
        elif "OPENAI_API_KEY" in os.environ:
            self._openai_client = Client()
        else:
            raise ValueError(
                "Either the AZURE_OPENAI_ENDPOINT or OPENAI_API_KEY environment variable is required to use MongoDBVectorSearchTool."
            )

        from pymongo import MongoClient
        from pymongo.driver_info import DriverInfo

        self._client = MongoClient(
            self.connection_string,
            driver=DriverInfo(name="CrewAI", version=version("crewai-tools")),
        )
        self._coll = self._client[self.database_name][self.collection_name]

    def create_vector_search_index(
        self,
        *,
        dimensions: int,
        relevance_score_fn: str = "cosine",
        auto_index_timeout: int = 15,
    ) -> None:
        """Convenience function to create a vector search index.

        Args:
            dimensions: Number of dimensions in the embedding. If the value is set and
                the index does not exist, an index will be created.
            relevance_score_fn: The similarity score used for the index.
                Currently supported: 'euclidean', 'cosine', and 'dotProduct'.
            auto_index_timeout: Timeout in seconds to wait for an auto-created index
                to be ready.
        """
        create_vector_search_index(
            collection=self._coll,
            index_name=self.vector_index_name,
            dimensions=dimensions,
            path=self.embedding_key,
            similarity=relevance_score_fn,
            wait_until_complete=auto_index_timeout,
        )

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[Dict[str, Any]]] = None,
        ids: Optional[List[str]] = None,
        batch_size: int = 100,
        **kwargs: Any,
    ) -> List[str]:
        """Add texts, create embeddings, and add to the Collection and index.

        Important notes on ids:
        - If _id or id is a key in the metadatas dicts, one must
          pop them and provide them as a separate list.
        - They must be unique.
        - If they are not provided, unique ones are created, stored as
          bson.ObjectIds internally and returned as strings. These will
          appear in the documents under the key '_id'.

        Args:
            texts: Iterable of strings to add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.
            ids: Optional list of unique ids that will be used as index in VectorStore.
                See note on ids.
            batch_size: Number of documents to insert at a time.
                Tuning this may help with performance and sidestep MongoDB limits.

        Returns:
            List of ids added to the vectorstore.
        """
        from bson import ObjectId

        texts = list(texts)
        _metadatas = metadatas or [{} for _ in texts]
        ids = ids or [str(ObjectId()) for _ in texts]

        result_ids = []
        texts_batch = []
        metadatas_batch = []
        size = 0
        i = 0
        for j, (text, metadata) in enumerate(zip(texts, _metadatas)):
            size += len(text) + len(metadata)
            texts_batch.append(text)
            metadatas_batch.append(metadata)
            # Flush a batch at batch_size documents or ~47 MB of payload.
            if (j + 1) % batch_size == 0 or size >= 47_000_000:
                batch_res = self._bulk_embed_and_insert_texts(
                    texts_batch, metadatas_batch, ids[i : j + 1]
                )
                result_ids.extend(batch_res)
                texts_batch = []
                metadatas_batch = []
                size = 0
                i = j + 1
        if texts_batch:
            batch_res = self._bulk_embed_and_insert_texts(
                texts_batch, metadatas_batch, ids[i : j + 1]
            )
            result_ids.extend(batch_res)
        return result_ids

    def _embed_texts(self, texts: List[str]) -> List[List[float]]:
        return [
            i.embedding
            for i in self._openai_client.embeddings.create(
                input=texts,
                model=self.embedding_model,
                dimensions=self.dimensions,
            ).data
        ]

    def _bulk_embed_and_insert_texts(
        self,
        texts: List[str],
        metadatas: List[dict],
        ids: List[str],
    ) -> List[str]:
        """Bulk insert a single batch of texts, embeddings, and ids."""
        from bson import ObjectId
        from pymongo.operations import ReplaceOne

        if not texts:
            return []
        # Compute embedding vectors
        embeddings = self._embed_texts(texts)
        docs = [
            {
                "_id": ObjectId(i),
                self.text_key: t,
                self.embedding_key: embedding,
                **m,
            }
            for i, t, m, embedding in zip(ids, texts, metadatas, embeddings)
        ]
        operations = [ReplaceOne({"_id": doc["_id"]}, doc, upsert=True) for doc in docs]
        # Insert the documents in MongoDB Atlas
        result = self._coll.bulk_write(operations)
        assert result.upserted_ids is not None
        return [str(_id) for _id in result.upserted_ids.values()]

    def _run(self, query: str) -> str:
        try:
            query_config = self.query_config or MongoDBVectorSearchConfig()
            limit = query_config.limit
            oversampling_factor = query_config.oversampling_factor
            pre_filter = query_config.pre_filter
            include_embeddings = query_config.include_embeddings
            post_filter_pipeline = query_config.post_filter_pipeline

            # Create the embedding for the query
            query_vector = self._embed_texts([query])[0]

            # Atlas Vector Search, potentially with filter
            stage = {
                "index": self.vector_index_name,
                "path": self.embedding_key,
                "queryVector": query_vector,
                "numCandidates": limit * oversampling_factor,
                "limit": limit,
            }
            if pre_filter:
                stage["filter"] = pre_filter

            pipeline = [
                {"$vectorSearch": stage},
                {"$set": {"score": {"$meta": "vectorSearchScore"}}},
            ]

            # Remove embeddings unless requested
            if not include_embeddings:
                pipeline.append({"$project": {self.embedding_key: 0}})

            # Post-processing
            if post_filter_pipeline is not None:
                pipeline.extend(post_filter_pipeline)

            # Execution
            cursor = self._coll.aggregate(pipeline)  # type: ignore[arg-type]
            docs = []

            # Format the results; default=str handles BSON types such as ObjectId.
            for doc in cursor:
                docs.append(doc)
            return json.dumps(docs, default=str)
        except Exception as e:
            logger.error(f"Error running MongoDB vector search: {e}")
            return ""

    def __del__(self):
        """Cleanup clients on deletion."""
        try:
            if hasattr(self, "_client") and self._client:
                self._client.close()
        except Exception as e:
            logger.error(f"Error closing MongoDB client: {e}")

        try:
            if hasattr(self, "_openai_client") and self._openai_client:
                self._openai_client.close()
        except Exception as e:
            logger.error(f"Error closing OpenAI client: {e}")
```
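
For reference, with the default configuration (`limit=4`, `oversampling_factor=10`, embeddings excluded), `_run` assembles an aggregation pipeline of the following shape; the query vector values here are illustrative:

```python
pipeline = [
    {
        "$vectorSearch": {
            "index": "vector_index",
            "path": "embedding",
            "queryVector": [0.12, -0.03, ...],  # embedding of the query text
            "numCandidates": 40,  # limit * oversampling_factor
            "limit": 4,
        }
    },
    # Surface the similarity score on each returned document.
    {"$set": {"score": {"$meta": "vectorSearchScore"}}},
    # Drop the stored embedding vector unless include_embeddings=True.
    {"$project": {"embedding": 0}},
]
```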
Changes to test_generate_tool_specs.py:

```diff
@@ -1,12 +1,13 @@
 import json
 from typing import List, Optional, Type
 
-import pytest
-from pydantic import BaseModel, Field
 from unittest import mock
 
-from generate_tool_specs import ToolSpecExtractor
+import pytest
 from crewai.tools.base_tool import BaseTool, EnvVar
+from pydantic import BaseModel, Field
+
+from generate_tool_specs import ToolSpecExtractor
 
 class MockToolSchema(BaseModel):
     query: str = Field(..., description="The query parameter")
@@ -19,15 +20,30 @@ class MockTool(BaseTool):
     description: str = "A tool that mocks search functionality"
     args_schema: Type[BaseModel] = MockToolSchema
-    another_parameter: str = Field("Another way to define a default value", description="")
+
+    another_parameter: str = Field(
+        "Another way to define a default value", description=""
+    )
     my_parameter: str = Field("This is default value", description="What a description")
     my_parameter_bool: bool = Field(False)
-    package_dependencies: List[str] = Field(["this-is-a-required-package", "another-required-package"], description="")
+    package_dependencies: List[str] = Field(
+        ["this-is-a-required-package", "another-required-package"], description=""
+    )
     env_vars: List[EnvVar] = [
-        EnvVar(name="SERPER_API_KEY", description="API key for Serper", required=True, default=None),
-        EnvVar(name="API_RATE_LIMIT", description="API rate limit", required=False, default="100")
+        EnvVar(
+            name="SERPER_API_KEY",
+            description="API key for Serper",
+            required=True,
+            default=None,
+        ),
+        EnvVar(
+            name="API_RATE_LIMIT",
+            description="API rate limit",
+            required=False,
+            default="100",
+        ),
     ]
 
 
 @pytest.fixture
 def extractor():
     ext = ToolSpecExtractor()
@@ -37,7 +53,7 @@ def extractor():
 def test_unwrap_schema(extractor):
     nested_schema = {
         "type": "function-after",
-        "schema": {"type": "default", "schema": {"type": "str", "value": "test"}}
+        "schema": {"type": "default", "schema": {"type": "str", "value": "test"}},
     }
     result = extractor._unwrap_schema(nested_schema)
     assert result["type"] == "str"
@@ -46,12 +62,15 @@ def test_unwrap_schema(extractor):
 
 @pytest.fixture
 def mock_tool_extractor(extractor):
-    with mock.patch("generate_tool_specs.dir", return_value=["MockTool"]), \
-        mock.patch("generate_tool_specs.getattr", return_value=MockTool):
+    with (
+        mock.patch("generate_tool_specs.dir", return_value=["MockTool"]),
+        mock.patch("generate_tool_specs.getattr", return_value=MockTool),
+    ):
         extractor.extract_all_tools()
         assert len(extractor.tools_spec) == 1
         return extractor.tools_spec[0]
 
 
 def test_extract_basic_tool_info(mock_tool_extractor):
     tool_info = mock_tool_extractor
@@ -69,6 +88,7 @@ def test_extract_basic_tool_info(mock_tool_extractor):
     assert tool_info["humanized_name"] == "Mock Search Tool"
     assert tool_info["description"] == "A tool that mocks search functionality"
+
 
 def test_extract_init_params_schema(mock_tool_extractor):
     tool_info = mock_tool_extractor
     init_params_schema = tool_info["init_params_schema"]
@@ -80,20 +100,21 @@ def test_extract_init_params_schema(mock_tool_extractor):
         "type",
     }
 
-    another_parameter = init_params_schema['properties']['another_parameter']
+    another_parameter = init_params_schema["properties"]["another_parameter"]
     assert another_parameter["description"] == ""
     assert another_parameter["default"] == "Another way to define a default value"
     assert another_parameter["type"] == "string"
 
-    my_parameter = init_params_schema['properties']['my_parameter']
+    my_parameter = init_params_schema["properties"]["my_parameter"]
     assert my_parameter["description"] == "What a description"
     assert my_parameter["default"] == "This is default value"
     assert my_parameter["type"] == "string"
 
-    my_parameter_bool = init_params_schema['properties']['my_parameter_bool']
+    my_parameter_bool = init_params_schema["properties"]["my_parameter_bool"]
     assert my_parameter_bool["default"] == False
     assert my_parameter_bool["type"] == "boolean"
+
 
 def test_extract_env_vars(mock_tool_extractor):
     tool_info = mock_tool_extractor
@@ -109,6 +130,7 @@ def test_extract_env_vars(mock_tool_extractor):
     assert rate_limit_var["required"] == False
     assert rate_limit_var["default"] == "100"
+
 
 def test_extract_run_params_schema(mock_tool_extractor):
     tool_info = mock_tool_extractor
@@ -131,22 +153,31 @@ def test_extract_run_params_schema(mock_tool_extractor):
     filters_param = run_params_schema["properties"]["filters"]
     assert filters_param["description"] == "Optional filters to apply"
     assert filters_param["default"] == None
-    assert filters_param['anyOf'] == [{'items': {'type': 'string'}, 'type': 'array'}, {'type': 'null'}]
+    assert filters_param["anyOf"] == [
+        {"items": {"type": "string"}, "type": "array"},
+        {"type": "null"},
+    ]
 
 
 def test_extract_package_dependencies(mock_tool_extractor):
     tool_info = mock_tool_extractor
-    assert tool_info["package_dependencies"] == ["this-is-a-required-package", "another-required-package"]
+    assert tool_info["package_dependencies"] == [
+        "this-is-a-required-package",
+        "another-required-package",
+    ]
 
 
 def test_save_to_json(extractor, tmp_path):
-    extractor.tools_spec = [{
-        "name": "TestTool",
-        "humanized_name": "Test Tool",
-        "description": "A test tool",
-        "run_params_schema": [
-            {"name": "param1", "description": "Test parameter", "type": "str"}
-        ]
-    }]
+    extractor.tools_spec = [
+        {
+            "name": "TestTool",
+            "humanized_name": "Test Tool",
+            "description": "A test tool",
+            "run_params_schema": [
+                {"name": "param1", "description": "Test parameter", "type": "str"}
+            ],
+        }
+    ]
 
     file_path = tmp_path / "output.json"
     extractor.save_to_json(str(file_path))
```
tests/tools/test_mongodb_vector_search_tool.py (new file, 75 lines):
```python
import json
from unittest.mock import patch

import pytest

from crewai_tools import MongoDBVectorSearchConfig, MongoDBVectorSearchTool


# Unit Test Fixtures
@pytest.fixture
def mongodb_vector_search_tool():
    tool = MongoDBVectorSearchTool(
        connection_string="foo", database_name="bar", collection_name="test"
    )
    # Stub out embedding so no OpenAI call is made.
    tool._embed_texts = lambda x: [[0.1]]
    yield tool


# Unit Tests
def test_successful_query_execution(mongodb_vector_search_tool):
    with patch.object(mongodb_vector_search_tool._coll, "aggregate") as mock_aggregate:
        mock_aggregate.return_value = [dict(text="foo", score=0.1, _id=1)]

        results = json.loads(mongodb_vector_search_tool._run(query="sandwiches"))

        assert len(results) == 1
        assert results[0]["text"] == "foo"
        assert results[0]["_id"] == 1


def test_provide_config():
    query_config = MongoDBVectorSearchConfig(limit=10)
    tool = MongoDBVectorSearchTool(
        connection_string="foo",
        database_name="bar",
        collection_name="test",
        query_config=query_config,
        vector_index_name="foo",
        embedding_model="bar",
    )
    tool._embed_texts = lambda x: [[0.1]]
    with patch.object(tool._coll, "aggregate") as mock_aggregate:
        mock_aggregate.return_value = [dict(text="foo", score=0.1, _id=1)]

        tool._run(query="sandwiches")
        assert mock_aggregate.mock_calls[-1].args[0][0]["$vectorSearch"]["limit"] == 10


def test_cleanup_on_deletion(mongodb_vector_search_tool):
    with patch.object(mongodb_vector_search_tool, "_client") as mock_client:
        # Trigger cleanup
        mongodb_vector_search_tool.__del__()

        mock_client.close.assert_called_once()


def test_create_search_index(mongodb_vector_search_tool):
    with patch(
        "crewai_tools.tools.mongodb_vector_search_tool.vector_search.create_vector_search_index"
    ) as mock_create_search_index:
        mongodb_vector_search_tool.create_vector_search_index(dimensions=10)
        kwargs = mock_create_search_index.mock_calls[0].kwargs
        assert kwargs["dimensions"] == 10
        assert kwargs["similarity"] == "cosine"


def test_add_texts(mongodb_vector_search_tool):
    with patch.object(mongodb_vector_search_tool._coll, "bulk_write") as bulk_write:
        mongodb_vector_search_tool.add_texts(["foo"])
        args = bulk_write.mock_calls[0].args
        assert "ReplaceOne" in str(args[0][0])
        assert "foo" in str(args[0][0])
```