Add MongoDB Vector Search Tool (#319)

* INTPYTHON-580 Design and Implement MongoDBVectorSearchTool

* add implementation

* wip

* wip

* finish tests

* add todo

* refactor to wrap langchain-mongodb

* cleanup

* address review

* Fix usage of EnvVar class

* inline code

* lint

* lint

* fix usage of SearchIndexModel

* Refactor: Update EnvVar import path and remove unused tests.utils module

- Changed import of EnvVar from tests.utils to crewai.tools in multiple files.
- Updated README.md for MongoDB vector search tool with additional context.
- Modified subprocess command in vector_search.py for package installation.
- Cleaned up test_generate_tool_specs.py to improve mock patching syntax.
- Deleted unused tests/utils.py file.

* update the crewai dep and the lockfile

* chore: update package versions and dependencies in uv.lock

- Removed `auth0-python` package.
- Updated `crewai` version to 0.140.0 and adjusted its dependencies.
- Changed `json-repair` version to 0.25.2.
- Updated `litellm` version to 1.72.6.
- Modified dependency markers for several packages to improve compatibility with Python versions.

* refactor: improve MongoDB vector search tool with enhanced error handling and new dimensions field

- Added logging for error handling in the _run method and during client cleanup.
- Introduced a new 'dimensions' field in the MongoDBVectorSearchConfig for embedding vector size.
- Refactored the _run method to return JSON formatted results and handle exceptions gracefully.
- Cleaned up import statements and improved code readability.

* address review

* update tests

* debug

* fix test

* fix test

* fix test

* support azure openai

---------

Co-authored-by: lorenzejay <lorenzejaytech@gmail.com>
Author: Steven Silvester
Date: 2025-07-09 10:44:23 -05:00 (committed by GitHub)
Parent: c45e92bd17
Commit: e0de166592
9 changed files with 685 additions and 28 deletions

View File

@@ -25,6 +25,7 @@ CrewAI provides an extensive collection of powerful tools ready to enhance your
- **File Management**: `FileReadTool`, `FileWriteTool`
- **Web Scraping**: `ScrapeWebsiteTool`, `SeleniumScrapingTool`
- **Database Integrations**: `PGSearchTool`, `MySQLSearchTool`
- **Vector Database Integrations**: `MongoDBVectorSearchTool`, `QdrantVectorSearchTool`, `WeaviateVectorSearchTool`
- **API Integrations**: `SerperApiTool`, `EXASearchTool`
- **AI-powered Tools**: `DallETool`, `VisionTool`, `StagehandTool`
@@ -226,4 +227,3 @@ Join our rapidly growing community and receive real-time support:
- [Open an Issue](https://github.com/crewAIInc/crewAI/issues)
Build smarter, faster, and more powerful AI solutions—powered by CrewAI Tools.

View File

@@ -1,5 +1,6 @@
from .adapters.enterprise_adapter import EnterpriseActionTool
from .adapters.mcp_adapter import MCPServerAdapter
from .adapters.zapier_adapter import ZapierActionTool
from .aws import (
BedrockInvokeAgentTool,
BedrockKBRetrieverTool,
@@ -23,9 +24,9 @@ from .tools import (
DirectorySearchTool,
DOCXSearchTool,
EXASearchTool,
FileCompressorTool,
FileReadTool,
FileWriterTool,
FileCompressorTool,
FirecrawlCrawlWebsiteTool,
FirecrawlScrapeWebsiteTool,
FirecrawlSearchTool,
@@ -35,6 +36,8 @@ from .tools import (
LinkupSearchTool,
LlamaIndexTool,
MDXSearchTool,
MongoDBVectorSearchConfig,
MongoDBVectorSearchTool,
MultiOnTool,
MySQLSearchTool,
NL2SQLTool,
@@ -76,4 +79,3 @@ from .tools import (
YoutubeVideoSearchTool,
ZapierActionTools,
)
from .adapters.zapier_adapter import ZapierActionTool

View File

@@ -16,10 +16,10 @@ from .docx_search_tool.docx_search_tool import DOCXSearchTool
from .exa_tools.exa_search_tool import EXASearchTool
from .file_read_tool.file_read_tool import FileReadTool
from .file_writer_tool.file_writer_tool import FileWriterTool
from .files_compressor_tool.files_compressor_tool import FileCompressorTool
from .firecrawl_crawl_website_tool.firecrawl_crawl_website_tool import (
FirecrawlCrawlWebsiteTool,
)
from .files_compressor_tool.files_compressor_tool import FileCompressorTool
from .firecrawl_scrape_website_tool.firecrawl_scrape_website_tool import (
FirecrawlScrapeWebsiteTool,
)
@@ -30,6 +30,11 @@ from .json_search_tool.json_search_tool import JSONSearchTool
from .linkup.linkup_search_tool import LinkupSearchTool
from .llamaindex_tool.llamaindex_tool import LlamaIndexTool
from .mdx_search_tool.mdx_search_tool import MDXSearchTool
from .mongodb_vector_search_tool import (
MongoDBToolSchema,
MongoDBVectorSearchConfig,
MongoDBVectorSearchTool,
)
from .multion_tool.multion_tool import MultiOnTool
from .mysql_search_tool.mysql_search_tool import MySQLSearchTool
from .nl2sql.nl2sql_tool import NL2SQLTool

View File

@@ -0,0 +1,87 @@
# MongoDBVectorSearchTool
## Description
This tool is specifically crafted for conducting vector searches over documents stored in a MongoDB database. Use this tool to find documents that are semantically similar to a given query.
MongoDB can act as a vector database that is used to store and query vector embeddings. You can follow the docs here:
https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-overview/
## Installation
Install the crewai_tools package with MongoDB support by executing the following command in your terminal:
```shell
pip install crewai-tools[mongodb]
```
or
```shell
uv add crewai-tools --extra mongodb
```
## Example
To utilize the MongoDBVectorSearchTool for different use cases, follow these examples:
```python
from crewai_tools import MongoDBVectorSearchTool
# Create the tool with the minimal required configuration
tool = MongoDBVectorSearchTool(
    database_name="example_database",
    collection_name="example_collections",
    connection_string="<your_mongodb_connection_string>",
)
```
or
```python
from crewai import Agent
from crewai_tools import MongoDBVectorSearchConfig, MongoDBVectorSearchTool

# Set up a custom embedding model and customize the query parameters.
query_config = MongoDBVectorSearchConfig(limit=10, oversampling_factor=2)
tool = MongoDBVectorSearchTool(
    database_name="example_database",
    collection_name="example_collections",
    connection_string="<your_mongodb_connection_string>",
    query_config=query_config,
    vector_index_name="my_vector_index",
    embedding_model="text-embedding-3-large",
)

# Adding the tool to an agent
rag_agent = Agent(
    name="rag_agent",
    role="You are a helpful assistant that can answer questions with the help of the MongoDBVectorSearchTool.",
    goal="...",
    backstory="...",
    llm="gpt-4o-mini",
    tools=[tool],
)
```
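
To execute end to end, the agent is typically placed in a Crew with a Task; a minimal sketch (the task description and inputs are illustrative):
```python
from crewai import Crew, Task

task = Task(
    description="Answer the question using the MongoDB knowledge base: {question}",
    expected_output="A concise answer grounded in the retrieved documents.",
    agent=rag_agent,
)
crew = Crew(agents=[rag_agent], tasks=[task])
result = crew.kickoff(inputs={"question": "What does our onboarding guide cover?"})
print(result)
```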
Preloading the MongoDB database with documents:
```python
import os

from crewai_tools import MongoDBVectorSearchTool

# Create the tool.
tool = MongoDBVectorSearchTool(
    database_name="example_database",
    collection_name="example_collections",
    connection_string="<your_mongodb_connection_string>",
)

# Add the text from a set of CrewAI knowledge documents.
texts = []
for d in os.listdir("knowledge"):
    with open(os.path.join("knowledge", d), "r") as f:
        texts.append(f.read())
tool.add_texts(texts)

# Create the vector search index (if it wasn't already created in Atlas).
tool.create_vector_search_index(dimensions=3072)
```
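
Once documents are indexed, the tool can also be invoked directly to sanity-check results outside of an agent. A minimal sketch; the query string is illustrative, and results come back as a JSON string:
```python
import json

# Run a semantic search and inspect the raw results.
raw = tool.run(query="How do I configure the vector index?")
for doc in json.loads(raw):
    print(doc["score"], doc["text"][:80])
```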

View File

@@ -0,0 +1,11 @@
from .vector_search import (
    MongoDBToolSchema,
    MongoDBVectorSearchConfig,
    MongoDBVectorSearchTool,
)

__all__ = [
    "MongoDBVectorSearchConfig",
    "MongoDBVectorSearchTool",
    "MongoDBToolSchema",
]

View File

@@ -0,0 +1,120 @@
from __future__ import annotations

from time import monotonic, sleep
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional

if TYPE_CHECKING:
    from pymongo.collection import Collection


def _vector_search_index_definition(
    dimensions: int,
    path: str,
    similarity: str,
    filters: Optional[List[str]] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    # https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/
    fields = [
        {
            "numDimensions": dimensions,
            "path": path,
            "similarity": similarity,
            "type": "vector",
        },
    ]
    if filters:
        for field in filters:
            fields.append({"type": "filter", "path": field})
    definition = {"fields": fields}
    definition.update(kwargs)
    return definition


def create_vector_search_index(
    collection: Collection,
    index_name: str,
    dimensions: int,
    path: str,
    similarity: str,
    filters: Optional[List[str]] = None,
    *,
    wait_until_complete: Optional[float] = None,
    **kwargs: Any,
) -> None:
    """Experimental utility function to create a vector search index.

    Args:
        collection (Collection): MongoDB Collection
        index_name (str): Name of the index
        dimensions (int): Number of dimensions in the embedding
        path (str): Field containing the vector embedding
        similarity (str): The similarity score used for the index
        filters (List[str]): Fields/paths to index to allow filtering in $vectorSearch
        wait_until_complete (Optional[float]): If provided, number of seconds to wait
            until the search index is ready.
        kwargs: Keyword arguments supplying any additional options to SearchIndexModel.
    """
    from pymongo.operations import SearchIndexModel

    if collection.name not in collection.database.list_collection_names():
        collection.database.create_collection(collection.name)
    result = collection.create_search_index(
        SearchIndexModel(
            definition=_vector_search_index_definition(
                dimensions=dimensions,
                path=path,
                similarity=similarity,
                filters=filters,
                **kwargs,
            ),
            name=index_name,
            type="vectorSearch",
        )
    )
    if wait_until_complete:
        _wait_for_predicate(
            predicate=lambda: _is_index_ready(collection, index_name),
            err=f"{index_name=} did not complete in {wait_until_complete}!",
            timeout=wait_until_complete,
        )


def _is_index_ready(collection: Collection, index_name: str) -> bool:
    """Check whether the named search index has status READY.

    Args:
        collection (Collection): MongoDB Collection to check for the search index
        index_name (str): Vector Search Index name

    Returns:
        bool: True if the index is present and READY, False otherwise
    """
    for index in collection.list_search_indexes(index_name):
        if index["status"] == "READY":
            return True
    return False


def _wait_for_predicate(
    predicate: Callable, err: str, timeout: float = 120, interval: float = 0.5
) -> None:
    """Block until the predicate returns True.

    Args:
        predicate (Callable[[], bool]): A function that returns a boolean value
        err (str): Error message to raise if the timeout is reached
        timeout (float, optional): Maximum wait time for the predicate. Defaults to 120.
        interval (float, optional): Interval between predicate checks. Defaults to 0.5.

    Raises:
        TimeoutError: If the predicate does not return True within `timeout` seconds.
    """
    start = monotonic()
    while not predicate():
        if monotonic() - start > timeout:
            raise TimeoutError(err)
        sleep(interval)
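
As a usage note, the helper above can be called directly against a pymongo collection; a minimal sketch, assuming a placeholder connection string and the tool's default field names:
```python
from pymongo import MongoClient

from crewai_tools.tools.mongodb_vector_search_tool.utils import (
    create_vector_search_index,
)

# Placeholder URI and namespace; replace with your own Atlas deployment.
client = MongoClient("<your_mongodb_connection_string>")
collection = client["example_database"]["example_collections"]

create_vector_search_index(
    collection=collection,
    index_name="vector_index",
    dimensions=1536,           # must match the embedding model's output size
    path="embedding",          # field that stores the vectors
    similarity="cosine",       # or 'euclidean' / 'dotProduct'
    wait_until_complete=60.0,  # block up to 60s until the index is READY
)
```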

View File

@@ -0,0 +1,326 @@
import json
import os
from importlib.metadata import version
from logging import getLogger
from typing import Any, Dict, Iterable, List, Optional, Type

from crewai.tools import BaseTool, EnvVar
from openai import AzureOpenAI, Client
from pydantic import BaseModel, Field

from crewai_tools.tools.mongodb_vector_search_tool.utils import (
    create_vector_search_index,
)

try:
    import pymongo  # noqa: F401

    MONGODB_AVAILABLE = True
except ImportError:
    MONGODB_AVAILABLE = False

logger = getLogger(__name__)


class MongoDBVectorSearchConfig(BaseModel):
    """Configuration for MongoDB vector search queries."""

    limit: Optional[int] = Field(
        default=4, description="Number of documents to return."
    )
    pre_filter: Optional[dict[str, Any]] = Field(
        default=None,
        description="List of MQL match expressions comparing an indexed field",
    )
    post_filter_pipeline: Optional[list[dict]] = Field(
        default=None,
        description="Pipeline of MongoDB aggregation stages to filter/process results after $vectorSearch.",
    )
    oversampling_factor: int = Field(
        default=10,
        description="Multiple of limit used when generating the number of candidates at each step in the HNSW vector search",
    )
    include_embeddings: bool = Field(
        default=False,
        description="Whether to include the embedding vector of each result in metadata.",
    )


class MongoDBToolSchema(MongoDBVectorSearchConfig):
    """Input for MongoDBTool."""

    query: str = Field(
        ...,
        description="The query used to retrieve relevant information from the MongoDB database. Pass only the query, not the question.",
    )


class MongoDBVectorSearchTool(BaseTool):
    """Tool to perform a vector search on a MongoDB database."""

    name: str = "MongoDBVectorSearchTool"
    description: str = "A tool to perform a vector search on a MongoDB database for relevant information in internal documents."
    args_schema: Type[BaseModel] = MongoDBToolSchema
    query_config: Optional[MongoDBVectorSearchConfig] = Field(
        default=None, description="MongoDB Vector Search query configuration"
    )
    embedding_model: str = Field(
        default="text-embedding-3-large",
        description="OpenAI text embedding model to use",
    )
    vector_index_name: str = Field(
        default="vector_index", description="Name of the Atlas Search vector index"
    )
    text_key: str = Field(
        default="text",
        description="MongoDB field that will contain the text for each document",
    )
    embedding_key: str = Field(
        default="embedding",
        description="Field that will contain the embedding for each document",
    )
    database_name: str = Field(..., description="The name of the MongoDB database")
    collection_name: str = Field(..., description="The name of the MongoDB collection")
    connection_string: str = Field(
        ...,
        description="The connection string of the MongoDB cluster",
    )
    dimensions: int = Field(
        default=1536,
        description="Number of dimensions in the embedding vector",
    )
    env_vars: List[EnvVar] = [
        EnvVar(
            name="OPENAI_API_KEY",
            description="API key for the OpenAI embeddings endpoint",
            required=False,
        ),
        EnvVar(
            name="AZURE_OPENAI_ENDPOINT",
            description="Endpoint for the Azure OpenAI embeddings service",
            required=False,
        ),
    ]
    package_dependencies: List[str] = ["pymongo"]

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if not MONGODB_AVAILABLE:
            import click

            if click.confirm(
                "You are missing the 'pymongo' package required by MongoDBVectorSearchTool. Would you like to install it?"
            ):
                import subprocess

                subprocess.run(["uv", "add", "pymongo"], check=True)
            else:
                raise ImportError(
                    "The 'pymongo' package is required to use MongoDBVectorSearchTool."
                )
        if "AZURE_OPENAI_ENDPOINT" in os.environ:
            self._openai_client = AzureOpenAI()
        elif "OPENAI_API_KEY" in os.environ:
            self._openai_client = Client()
        else:
            raise ValueError(
                "Either the AZURE_OPENAI_ENDPOINT or the OPENAI_API_KEY environment variable is required to use MongoDBVectorSearchTool."
            )
        from pymongo import MongoClient
        from pymongo.driver_info import DriverInfo

        self._client = MongoClient(
            self.connection_string,
            driver=DriverInfo(name="CrewAI", version=version("crewai-tools")),
        )
        self._coll = self._client[self.database_name][self.collection_name]

    def create_vector_search_index(
        self,
        *,
        dimensions: int,
        relevance_score_fn: str = "cosine",
        auto_index_timeout: int = 15,
    ) -> None:
        """Convenience function to create a vector search index.

        Args:
            dimensions: Number of dimensions in the embedding. If the value is set and
                the index does not exist, an index will be created.
            relevance_score_fn: The similarity score used for the index.
                Currently supported: 'euclidean', 'cosine', and 'dotProduct'.
            auto_index_timeout: Timeout in seconds to wait for an auto-created index
                to be ready.
        """
        create_vector_search_index(
            collection=self._coll,
            index_name=self.vector_index_name,
            dimensions=dimensions,
            path=self.embedding_key,
            similarity=relevance_score_fn,
            wait_until_complete=auto_index_timeout,
        )

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[Dict[str, Any]]] = None,
        ids: Optional[List[str]] = None,
        batch_size: int = 100,
        **kwargs: Any,
    ) -> List[str]:
        """Add texts, create embeddings, and insert them into the collection and index.

        Important notes on ids:
            - If _id or id is a key in the metadatas dicts, one must
              pop them and provide them as a separate list.
            - They must be unique.
            - If they are not provided, unique ones are created automatically,
              stored as bson.ObjectIds internally and returned as strings.
              These will appear in each document under the key '_id'.

        Args:
            texts: Iterable of strings to add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.
            ids: Optional list of unique ids that will be used as the index in the
                vectorstore. See note on ids.
            batch_size: Number of documents to insert at a time.
                Tuning this may help with performance and sidestep MongoDB limits.

        Returns:
            List of ids added to the vectorstore.
        """
        from bson import ObjectId

        # Materialize the iterable once so it is not consumed prematurely,
        # and respect caller-provided ids rather than always generating new ones.
        texts = list(texts)
        _metadatas = metadatas or [{} for _ in texts]
        ids = ids or [str(ObjectId()) for _ in texts]

        result_ids = []
        texts_batch = []
        metadatas_batch = []
        size = 0
        i = 0
        for j, (text, metadata) in enumerate(zip(texts, _metadatas)):
            size += len(text) + len(metadata)
            texts_batch.append(text)
            metadatas_batch.append(metadata)
            if (j + 1) % batch_size == 0 or size >= 47_000_000:
                batch_res = self._bulk_embed_and_insert_texts(
                    texts_batch, metadatas_batch, ids[i : j + 1]
                )
                result_ids.extend(batch_res)
                texts_batch = []
                metadatas_batch = []
                size = 0
                i = j + 1
        if texts_batch:
            batch_res = self._bulk_embed_and_insert_texts(
                texts_batch, metadatas_batch, ids[i : j + 1]
            )
            result_ids.extend(batch_res)
        return result_ids

    def _embed_texts(self, texts: List[str]) -> List[List[float]]:
        return [
            i.embedding
            for i in self._openai_client.embeddings.create(
                input=texts,
                model=self.embedding_model,
                dimensions=self.dimensions,
            ).data
        ]

    def _bulk_embed_and_insert_texts(
        self,
        texts: List[str],
        metadatas: List[dict],
        ids: List[str],
    ) -> List[str]:
        """Bulk insert a single batch of texts, embeddings, and ids."""
        from bson import ObjectId
        from pymongo.operations import ReplaceOne

        if not texts:
            return []
        # Compute embedding vectors
        embeddings = self._embed_texts(texts)
        docs = [
            {
                "_id": ObjectId(i),
                self.text_key: t,
                self.embedding_key: embedding,
                **m,
            }
            for i, t, m, embedding in zip(ids, texts, metadatas, embeddings)
        ]
        operations = [ReplaceOne({"_id": doc["_id"]}, doc, upsert=True) for doc in docs]
        # Insert the documents in MongoDB Atlas
        result = self._coll.bulk_write(operations)
        assert result.upserted_ids is not None
        return [str(_id) for _id in result.upserted_ids.values()]

    def _run(self, query: str) -> str:
        try:
            query_config = self.query_config or MongoDBVectorSearchConfig()
            limit = query_config.limit
            oversampling_factor = query_config.oversampling_factor
            pre_filter = query_config.pre_filter
            include_embeddings = query_config.include_embeddings
            post_filter_pipeline = query_config.post_filter_pipeline

            # Create the embedding for the query
            query_vector = self._embed_texts([query])[0]

            # Atlas Vector Search, potentially with filter
            stage = {
                "index": self.vector_index_name,
                "path": self.embedding_key,
                "queryVector": query_vector,
                "numCandidates": limit * oversampling_factor,
                "limit": limit,
            }
            if pre_filter:
                stage["filter"] = pre_filter
            pipeline = [
                {"$vectorSearch": stage},
                {"$set": {"score": {"$meta": "vectorSearchScore"}}},
            ]

            # Remove embeddings unless requested
            if not include_embeddings:
                pipeline.append({"$project": {self.embedding_key: 0}})

            # Post-processing
            if post_filter_pipeline is not None:
                pipeline.extend(post_filter_pipeline)

            # Execution
            cursor = self._coll.aggregate(pipeline)  # type: ignore[arg-type]

            # Format the results as JSON; default=str handles non-JSON
            # types such as bson.ObjectId in the _id field.
            docs = list(cursor)
            return json.dumps(docs, default=str)
        except Exception as e:
            logger.error(f"Error running MongoDB vector search: {e}")
            return ""

    def __del__(self):
        """Clean up clients on deletion."""
        try:
            if hasattr(self, "_client") and self._client:
                self._client.close()
        except Exception as e:
            logger.error(f"Error closing MongoDB client: {e}")
        try:
            if hasattr(self, "_openai_client") and self._openai_client:
                self._openai_client.close()
        except Exception as e:
            logger.error(f"Error closing OpenAI client: {e}")
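
For context, a sketch of the aggregation pipeline that `_run` builds under the default configuration (limit=4, oversampling_factor=10, embeddings excluded); the query vector is truncated for readability:
```python
# Hypothetical pipeline shape produced by _run() with default settings.
pipeline = [
    {
        "$vectorSearch": {
            "index": "vector_index",       # vector_index_name
            "path": "embedding",           # embedding_key
            "queryVector": [0.12, -0.03],  # truncated; really `dimensions` floats
            "numCandidates": 40,           # limit * oversampling_factor
            "limit": 4,
        }
    },
    # Surface the similarity score on each result document
    {"$set": {"score": {"$meta": "vectorSearchScore"}}},
    # Drop the raw vectors unless include_embeddings=True
    {"$project": {"embedding": 0}},
]
```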

View File

@@ -1,12 +1,13 @@
import json
from typing import List, Optional, Type
import pytest
from pydantic import BaseModel, Field
from unittest import mock
from generate_tool_specs import ToolSpecExtractor
import pytest
from crewai.tools.base_tool import BaseTool, EnvVar
from pydantic import BaseModel, Field
from generate_tool_specs import ToolSpecExtractor
class MockToolSchema(BaseModel):
query: str = Field(..., description="The query parameter")
@@ -19,15 +20,30 @@ class MockTool(BaseTool):
description: str = "A tool that mocks search functionality"
args_schema: Type[BaseModel] = MockToolSchema
another_parameter: str = Field("Another way to define a default value", description="")
another_parameter: str = Field(
"Another way to define a default value", description=""
)
my_parameter: str = Field("This is default value", description="What a description")
my_parameter_bool: bool = Field(False)
package_dependencies: List[str] = Field(["this-is-a-required-package", "another-required-package"], description="")
package_dependencies: List[str] = Field(
["this-is-a-required-package", "another-required-package"], description=""
)
env_vars: List[EnvVar] = [
EnvVar(name="SERPER_API_KEY", description="API key for Serper", required=True, default=None),
EnvVar(name="API_RATE_LIMIT", description="API rate limit", required=False, default="100")
EnvVar(
name="SERPER_API_KEY",
description="API key for Serper",
required=True,
default=None,
),
EnvVar(
name="API_RATE_LIMIT",
description="API rate limit",
required=False,
default="100",
),
]
@pytest.fixture
def extractor():
ext = ToolSpecExtractor()
@@ -37,7 +53,7 @@ def extractor():
def test_unwrap_schema(extractor):
nested_schema = {
"type": "function-after",
"schema": {"type": "default", "schema": {"type": "str", "value": "test"}}
"schema": {"type": "default", "schema": {"type": "str", "value": "test"}},
}
result = extractor._unwrap_schema(nested_schema)
assert result["type"] == "str"
@@ -46,12 +62,15 @@ def test_unwrap_schema(extractor):
@pytest.fixture
def mock_tool_extractor(extractor):
with mock.patch("generate_tool_specs.dir", return_value=["MockTool"]), \
mock.patch("generate_tool_specs.getattr", return_value=MockTool):
with (
mock.patch("generate_tool_specs.dir", return_value=["MockTool"]),
mock.patch("generate_tool_specs.getattr", return_value=MockTool),
):
extractor.extract_all_tools()
assert len(extractor.tools_spec) == 1
return extractor.tools_spec[0]
def test_extract_basic_tool_info(mock_tool_extractor):
tool_info = mock_tool_extractor
@@ -69,6 +88,7 @@ def test_extract_basic_tool_info(mock_tool_extractor):
assert tool_info["humanized_name"] == "Mock Search Tool"
assert tool_info["description"] == "A tool that mocks search functionality"
def test_extract_init_params_schema(mock_tool_extractor):
tool_info = mock_tool_extractor
init_params_schema = tool_info["init_params_schema"]
@@ -80,20 +100,21 @@ def test_extract_init_params_schema(mock_tool_extractor):
"type",
}
another_parameter = init_params_schema['properties']['another_parameter']
another_parameter = init_params_schema["properties"]["another_parameter"]
assert another_parameter["description"] == ""
assert another_parameter["default"] == "Another way to define a default value"
assert another_parameter["type"] == "string"
my_parameter = init_params_schema['properties']['my_parameter']
my_parameter = init_params_schema["properties"]["my_parameter"]
assert my_parameter["description"] == "What a description"
assert my_parameter["default"] == "This is default value"
assert my_parameter["type"] == "string"
my_parameter_bool = init_params_schema['properties']['my_parameter_bool']
my_parameter_bool = init_params_schema["properties"]["my_parameter_bool"]
assert my_parameter_bool["default"] == False
assert my_parameter_bool["type"] == "boolean"
def test_extract_env_vars(mock_tool_extractor):
tool_info = mock_tool_extractor
@@ -109,6 +130,7 @@ def test_extract_env_vars(mock_tool_extractor):
assert rate_limit_var["required"] == False
assert rate_limit_var["default"] == "100"
def test_extract_run_params_schema(mock_tool_extractor):
tool_info = mock_tool_extractor
@@ -131,22 +153,31 @@ def test_extract_run_params_schema(mock_tool_extractor):
filters_param = run_params_schema["properties"]["filters"]
assert filters_param["description"] == "Optional filters to apply"
assert filters_param["default"] == None
assert filters_param['anyOf'] == [{'items': {'type': 'string'}, 'type': 'array'}, {'type': 'null'}]
assert filters_param["anyOf"] == [
{"items": {"type": "string"}, "type": "array"},
{"type": "null"},
]
def test_extract_package_dependencies(mock_tool_extractor):
tool_info = mock_tool_extractor
assert tool_info["package_dependencies"] == ["this-is-a-required-package", "another-required-package"]
assert tool_info["package_dependencies"] == [
"this-is-a-required-package",
"another-required-package",
]
def test_save_to_json(extractor, tmp_path):
extractor.tools_spec = [{
"name": "TestTool",
"humanized_name": "Test Tool",
"description": "A test tool",
"run_params_schema": [
{"name": "param1", "description": "Test parameter", "type": "str"}
]
}]
extractor.tools_spec = [
{
"name": "TestTool",
"humanized_name": "Test Tool",
"description": "A test tool",
"run_params_schema": [
{"name": "param1", "description": "Test parameter", "type": "str"}
],
}
]
file_path = tmp_path / "output.json"
extractor.save_to_json(str(file_path))
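
For reference, the reworked fixture above uses the parenthesized form of `with` to group multiple `mock.patch` context managers, supported since Python 3.10. A standalone sketch with illustrative patch targets:
```python
import os
from unittest import mock

# Group several patches in one `with` statement (Python 3.10+ syntax).
with (
    mock.patch("os.getcwd", return_value="/tmp"),
    mock.patch("os.listdir", return_value=["a.txt"]),
):
    assert os.getcwd() == "/tmp"
    assert os.listdir(".") == ["a.txt"]
```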

View File

@@ -0,0 +1,75 @@
import json
from unittest.mock import patch

import pytest

from crewai_tools import MongoDBVectorSearchConfig, MongoDBVectorSearchTool


# Unit Test Fixtures
@pytest.fixture
def mongodb_vector_search_tool():
    tool = MongoDBVectorSearchTool(
        connection_string="foo", database_name="bar", collection_name="test"
    )
    tool._embed_texts = lambda x: [[0.1]]
    yield tool


# Unit Tests
def test_successful_query_execution(mongodb_vector_search_tool):
    # Enable embedding
    with patch.object(mongodb_vector_search_tool._coll, "aggregate") as mock_aggregate:
        mock_aggregate.return_value = [dict(text="foo", score=0.1, _id=1)]
        results = json.loads(mongodb_vector_search_tool._run(query="sandwiches"))
        assert len(results) == 1
        assert results[0]["text"] == "foo"
        assert results[0]["_id"] == 1


def test_provide_config():
    query_config = MongoDBVectorSearchConfig(limit=10)
    tool = MongoDBVectorSearchTool(
        connection_string="foo",
        database_name="bar",
        collection_name="test",
        query_config=query_config,
        vector_index_name="foo",
        embedding_model="bar",
    )
    tool._embed_texts = lambda x: [[0.1]]
    with patch.object(tool._coll, "aggregate") as mock_aggregate:
        mock_aggregate.return_value = [dict(text="foo", score=0.1, _id=1)]
        tool._run(query="sandwiches")
        assert mock_aggregate.mock_calls[-1].args[0][0]["$vectorSearch"]["limit"] == 10


def test_cleanup_on_deletion(mongodb_vector_search_tool):
    with patch.object(mongodb_vector_search_tool, "_client") as mock_client:
        # Trigger cleanup
        mongodb_vector_search_tool.__del__()
        mock_client.close.assert_called_once()


def test_create_search_index(mongodb_vector_search_tool):
    with patch(
        "crewai_tools.tools.mongodb_vector_search_tool.vector_search.create_vector_search_index"
    ) as mock_create_search_index:
        mongodb_vector_search_tool.create_vector_search_index(dimensions=10)
        kwargs = mock_create_search_index.mock_calls[0].kwargs
        assert kwargs["dimensions"] == 10
        assert kwargs["similarity"] == "cosine"


def test_add_texts(mongodb_vector_search_tool):
    with patch.object(mongodb_vector_search_tool._coll, "bulk_write") as bulk_write:
        mongodb_vector_search_tool.add_texts(["foo"])
        args = bulk_write.mock_calls[0].args
        assert "ReplaceOne" in str(args[0][0])
        assert "foo" in str(args[0][0])