diff --git a/README.md b/README.md
index 4ce6d3807..3ee271370 100644
--- a/README.md
+++ b/README.md
@@ -25,6 +25,7 @@ CrewAI provides an extensive collection of powerful tools ready to enhance your
 - **File Management**: `FileReadTool`, `FileWriteTool`
 - **Web Scraping**: `ScrapeWebsiteTool`, `SeleniumScrapingTool`
 - **Database Integrations**: `PGSearchTool`, `MySQLSearchTool`
+- **Vector Database Integrations**: `MongoDBVectorSearchTool`, `QdrantVectorSearchTool`, `WeaviateVectorSearchTool`
 - **API Integrations**: `SerperApiTool`, `EXASearchTool`
 - **AI-powered Tools**: `DallETool`, `VisionTool`, `StagehandTool`
@@ -226,4 +227,3 @@ Join our rapidly growing community and receive real-time support:
 - [Open an Issue](https://github.com/crewAIInc/crewAI/issues)
 
 Build smarter, faster, and more powerful AI solutions—powered by CrewAI Tools.
-
diff --git a/src/crewai_tools/__init__.py b/src/crewai_tools/__init__.py
index 8df620788..7831b957d 100644
--- a/src/crewai_tools/__init__.py
+++ b/src/crewai_tools/__init__.py
@@ -1,5 +1,6 @@
 from .adapters.enterprise_adapter import EnterpriseActionTool
 from .adapters.mcp_adapter import MCPServerAdapter
+from .adapters.zapier_adapter import ZapierActionTool
 from .aws import (
     BedrockInvokeAgentTool,
     BedrockKBRetrieverTool,
@@ -23,9 +24,9 @@ from .tools import (
     DirectorySearchTool,
     DOCXSearchTool,
     EXASearchTool,
+    FileCompressorTool,
     FileReadTool,
     FileWriterTool,
-    FileCompressorTool,
     FirecrawlCrawlWebsiteTool,
     FirecrawlScrapeWebsiteTool,
     FirecrawlSearchTool,
@@ -35,6 +36,8 @@ from .tools import (
     LinkupSearchTool,
     LlamaIndexTool,
     MDXSearchTool,
+    MongoDBVectorSearchConfig,
+    MongoDBVectorSearchTool,
     MultiOnTool,
     MySQLSearchTool,
     NL2SQLTool,
@@ -76,4 +79,3 @@ from .tools import (
     YoutubeVideoSearchTool,
     ZapierActionTools,
 )
-from .adapters.zapier_adapter import ZapierActionTool
diff --git a/src/crewai_tools/tools/__init__.py b/src/crewai_tools/tools/__init__.py
index 47f3f5f80..d4b54c5ff 100644
--- a/src/crewai_tools/tools/__init__.py
+++ b/src/crewai_tools/tools/__init__.py
@@ -16,10 +16,10 @@ from .docx_search_tool.docx_search_tool import DOCXSearchTool
 from .exa_tools.exa_search_tool import EXASearchTool
 from .file_read_tool.file_read_tool import FileReadTool
 from .file_writer_tool.file_writer_tool import FileWriterTool
+from .files_compressor_tool.files_compressor_tool import FileCompressorTool
 from .firecrawl_crawl_website_tool.firecrawl_crawl_website_tool import (
     FirecrawlCrawlWebsiteTool,
 )
-from .files_compressor_tool.files_compressor_tool import FileCompressorTool
 from .firecrawl_scrape_website_tool.firecrawl_scrape_website_tool import (
     FirecrawlScrapeWebsiteTool,
 )
@@ -30,6 +30,11 @@ from .json_search_tool.json_search_tool import JSONSearchTool
 from .linkup.linkup_search_tool import LinkupSearchTool
 from .llamaindex_tool.llamaindex_tool import LlamaIndexTool
 from .mdx_search_tool.mdx_search_tool import MDXSearchTool
+from .mongodb_vector_search_tool import (
+    MongoDBToolSchema,
+    MongoDBVectorSearchConfig,
+    MongoDBVectorSearchTool,
+)
 from .multion_tool.multion_tool import MultiOnTool
 from .mysql_search_tool.mysql_search_tool import MySQLSearchTool
 from .nl2sql.nl2sql_tool import NL2SQLTool
diff --git a/src/crewai_tools/tools/mongodb_vector_search_tool/README.md b/src/crewai_tools/tools/mongodb_vector_search_tool/README.md
new file mode 100644
index 000000000..c66dfcf43
--- /dev/null
+++ b/src/crewai_tools/tools/mongodb_vector_search_tool/README.md
@@ -0,0 +1,87 @@
+# MongoDBVectorSearchTool
+
+## Description
+This tool is specifically crafted for performing vector searches over documents stored in a MongoDB database. Use this tool to find documents that are semantically similar to a given query.
+
+MongoDB Atlas can act as a vector database for storing and querying vector embeddings. You can follow the docs here:
+https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-overview/
+
+## Installation
+Install the crewai_tools package with MongoDB support by executing the following command in your terminal:
+
+```shell
+pip install crewai-tools[mongodb]
+```
+
+or
+
+```shell
+uv add crewai-tools --extra mongodb
+```
+
+## Example
+To utilize the MongoDBVectorSearchTool for different use cases, follow these examples:
+
+```python
+from crewai_tools import MongoDBVectorSearchTool
+
+# Create a tool that performs vector searches over the given collection
+tool = MongoDBVectorSearchTool(
+    database_name="example_database",
+    collection_name="example_collections",
+    connection_string="",  # your MongoDB Atlas connection string
+)
+```
+
+or
+
+```python
+from crewai import Agent
+from crewai_tools import MongoDBVectorSearchConfig, MongoDBVectorSearchTool
+
+# Customize the embedding model, vector index, and query parameters.
+query_config = MongoDBVectorSearchConfig(limit=10, oversampling_factor=2)
+tool = MongoDBVectorSearchTool(
+    database_name="example_database",
+    collection_name="example_collections",
+    connection_string="",  # your MongoDB Atlas connection string
+    query_config=query_config,
+    vector_index_name="my_vector_index",
+    embedding_model="text-embedding-3-large",
+)
+
+# Adding the tool to an agent
+rag_agent = Agent(
+    role="You are a helpful assistant that can answer questions with the help of the MongoDBVectorSearchTool.",
+    goal="...",
+    backstory="...",
+    llm="gpt-4o-mini",
+    tools=[tool],
+)
+```
+
+Preloading the MongoDB database with documents:
+
+```python
+import os
+
+from crewai_tools import MongoDBVectorSearchTool
+
+# Create the tool.
+tool = MongoDBVectorSearchTool(
+    database_name="example_database",
+    collection_name="example_collections",
+    connection_string="",  # your MongoDB Atlas connection string
+    dimensions=3072,
+)
+
+# Add the text from a set of CrewAI knowledge documents.
+texts = []
+for d in os.listdir("knowledge"):
+    with open(os.path.join("knowledge", d), "r") as f:
+        texts.append(f.read())
+tool.add_texts(texts)
+
+# Create the vector search index (if it wasn't already created in Atlas).
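+# Note: `dimensions` must match the dimensionality of the embeddings stored by
+# the tool; 3072 corresponds to text-embedding-3-large with the tool's
+# `dimensions` set to 3072 as above (illustrative values; the tool defaults to 1536).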
+tool.create_vector_search_index(dimensions=3072)
+```
diff --git a/src/crewai_tools/tools/mongodb_vector_search_tool/__init__.py b/src/crewai_tools/tools/mongodb_vector_search_tool/__init__.py
new file mode 100644
index 000000000..c7e991472
--- /dev/null
+++ b/src/crewai_tools/tools/mongodb_vector_search_tool/__init__.py
@@ -0,0 +1,11 @@
+from .vector_search import (
+    MongoDBToolSchema,
+    MongoDBVectorSearchConfig,
+    MongoDBVectorSearchTool,
+)
+
+__all__ = [
+    "MongoDBVectorSearchConfig",
+    "MongoDBVectorSearchTool",
+    "MongoDBToolSchema",
+]
diff --git a/src/crewai_tools/tools/mongodb_vector_search_tool/utils.py b/src/crewai_tools/tools/mongodb_vector_search_tool/utils.py
new file mode 100644
index 000000000..a66586f6f
--- /dev/null
+++ b/src/crewai_tools/tools/mongodb_vector_search_tool/utils.py
@@ -0,0 +1,120 @@
+from __future__ import annotations
+
+from time import monotonic, sleep
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
+
+if TYPE_CHECKING:
+    from pymongo.collection import Collection
+
+
+def _vector_search_index_definition(
+    dimensions: int,
+    path: str,
+    similarity: str,
+    filters: Optional[List[str]] = None,
+    **kwargs: Any,
+) -> Dict[str, Any]:
+    # https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/
+    fields = [
+        {
+            "numDimensions": dimensions,
+            "path": path,
+            "similarity": similarity,
+            "type": "vector",
+        },
+    ]
+    if filters:
+        for field in filters:
+            fields.append({"type": "filter", "path": field})
+    definition = {"fields": fields}
+    definition.update(kwargs)
+    return definition
+
+
+def create_vector_search_index(
+    collection: Collection,
+    index_name: str,
+    dimensions: int,
+    path: str,
+    similarity: str,
+    filters: Optional[List[str]] = None,
+    *,
+    wait_until_complete: Optional[float] = None,
+    **kwargs: Any,
+) -> None:
+    """Experimental utility function to create a vector search index.
+
+    Args:
+        collection (Collection): MongoDB Collection
+        index_name (str): Name of the index
+        dimensions (int): Number of dimensions in the embedding
+        path (str): Field that contains the vector embedding
+        similarity (str): The similarity score used for the index
+        filters (List[str]): Fields/paths to index to allow filtering in $vectorSearch
+        wait_until_complete (Optional[float]): If provided, number of seconds to wait
+            until the search index is ready.
+        kwargs: Keyword arguments supplying any additional options to SearchIndexModel.
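+
+    Example (illustrative sketch; `db` and the field names are placeholders,
+    and a reachable Atlas cluster is assumed):
+
+        create_vector_search_index(
+            collection=db["docs"],
+            index_name="vector_index",
+            dimensions=1536,
+            path="embedding",
+            similarity="cosine",
+            wait_until_complete=60.0,
+        )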
+ """ + from pymongo.operations import SearchIndexModel + + if collection.name not in collection.database.list_collection_names(): + collection.database.create_collection(collection.name) + + result = collection.create_search_index( + SearchIndexModel( + definition=_vector_search_index_definition( + dimensions=dimensions, + path=path, + similarity=similarity, + filters=filters, + **kwargs, + ), + name=index_name, + type="vectorSearch", + ) + ) + + if wait_until_complete: + _wait_for_predicate( + predicate=lambda: _is_index_ready(collection, index_name), + err=f"{index_name=} did not complete in {wait_until_complete}!", + timeout=wait_until_complete, + ) + + +def _is_index_ready(collection: Collection, index_name: str) -> bool: + """Check for the index name in the list of available search indexes to see if the + specified index is of status READY + + Args: + collection (Collection): MongoDB Collection to for the search indexes + index_name (str): Vector Search Index name + + Returns: + bool : True if the index is present and READY false otherwise + """ + for index in collection.list_search_indexes(index_name): + if index["status"] == "READY": + return True + return False + + +def _wait_for_predicate( + predicate: Callable, err: str, timeout: float = 120, interval: float = 0.5 +) -> None: + """Generic to block until the predicate returns true + + Args: + predicate (Callable[, bool]): A function that returns a boolean value + err (str): Error message to raise if nothing occurs + timeout (float, optional): Wait time for predicate. Defaults to TIMEOUT. + interval (float, optional): Interval to check predicate. Defaults to DELAY. + + Raises: + TimeoutError: _description_ + """ + start = monotonic() + while not predicate(): + if monotonic() - start > timeout: + raise TimeoutError(err) + sleep(interval) diff --git a/src/crewai_tools/tools/mongodb_vector_search_tool/vector_search.py b/src/crewai_tools/tools/mongodb_vector_search_tool/vector_search.py new file mode 100644 index 000000000..3f8af315d --- /dev/null +++ b/src/crewai_tools/tools/mongodb_vector_search_tool/vector_search.py @@ -0,0 +1,326 @@ +import json +import os +from importlib.metadata import version +from logging import getLogger +from typing import Any, Dict, Iterable, List, Optional, Type + +from crewai.tools import BaseTool, EnvVar +from openai import AzureOpenAI, Client +from pydantic import BaseModel, Field + +from crewai_tools.tools.mongodb_vector_search_tool.utils import ( + create_vector_search_index, +) + +try: + import pymongo # noqa: F403 + + MONGODB_AVAILABLE = True +except ImportError: + MONGODB_AVAILABLE = False + +logger = getLogger(__name__) + + +class MongoDBVectorSearchConfig(BaseModel): + """Configuration for MongoDB vector search queries.""" + + limit: Optional[int] = Field( + default=4, description="number of documents to return." 
+ ) + pre_filter: Optional[dict[str, Any]] = Field( + default=None, + description="List of MQL match expressions comparing an indexed field", + ) + post_filter_pipeline: Optional[list[dict]] = Field( + default=None, + description="Pipeline of MongoDB aggregation stages to filter/process results after $vectorSearch.", + ) + oversampling_factor: int = Field( + default=10, + description="Multiple of limit used when generating number of candidates at each step in the HNSW Vector Search", + ) + include_embeddings: bool = Field( + default=False, + description="Whether to include the embedding vector of each result in metadata.", + ) + + +class MongoDBToolSchema(MongoDBVectorSearchConfig): + """Input for MongoDBTool.""" + + query: str = Field( + ..., + description="The query to search retrieve relevant information from the MongoDB database. Pass only the query, not the question.", + ) + + +class MongoDBVectorSearchTool(BaseTool): + """Tool to perfrom a vector search the MongoDB database""" + + name: str = "MongoDBVectorSearchTool" + description: str = "A tool to perfrom a vector search on a MongoDB database for relevant information on internal documents." + + args_schema: Type[BaseModel] = MongoDBToolSchema + query_config: Optional[MongoDBVectorSearchConfig] = Field( + default=None, description="MongoDB Vector Search query configuration" + ) + embedding_model: str = Field( + default="text-embedding-3-large", + description="Text OpenAI embedding model to use", + ) + vector_index_name: str = Field( + default="vector_index", description="Name of the Atlas Search vector index" + ) + text_key: str = Field( + default="text", + description="MongoDB field that will contain the text for each document", + ) + embedding_key: str = Field( + default="embedding", + description="Field that will contain the embedding for each document", + ) + database_name: str = Field(..., description="The name of the MongoDB database") + collection_name: str = Field(..., description="The name of the MongoDB collection") + connection_string: str = Field( + ..., + description="The connection string of the MongoDB cluster", + ) + dimensions: int = Field( + default=1536, + description="Number of dimensions in the embedding vector", + ) + env_vars: List[EnvVar] = [ + EnvVar( + name="BROWSERBASE_API_KEY", + description="API key for Browserbase services", + required=False, + ), + EnvVar( + name="BROWSERBASE_PROJECT_ID", + description="Project ID for Browserbase services", + required=False, + ), + ] + package_dependencies: List[str] = ["mongdb"] + + def __init__(self, **kwargs): + super().__init__(**kwargs) + if not MONGODB_AVAILABLE: + import click + + if click.confirm( + "You are missing the 'mongodb' crewai tool. Would you like to install it?" + ): + import subprocess + + subprocess.run(["uv", "add", "pymongo"], check=True) + + else: + raise ImportError("You are missing the 'mongodb' crewai tool.") + + if "AZURE_OPENAI_ENDPOINT" in os.environ: + self._openai_client = AzureOpenAI() + elif "OPENAI_API_KEY" in os.environ: + self._openai_client = Client() + else: + raise ValueError( + "OPENAI_API_KEY environment variable is required for MongoDBVectorSearchTool and it is mandatory to use the tool." 
+ ) + + from pymongo import MongoClient + from pymongo.driver_info import DriverInfo + + self._client = MongoClient( + self.connection_string, + driver=DriverInfo(name="CrewAI", version=version("crewai-tools")), + ) + self._coll = self._client[self.database_name][self.collection_name] + + def create_vector_search_index( + self, + *, + dimensions: int, + relevance_score_fn: str = "cosine", + auto_index_timeout: int = 15, + ) -> None: + """Convenience function to create a vector search index. + + Args: + dimensions: Number of dimensions in embedding. If the value is set and + the index does not exist, an index will be created. + relevance_score_fn: The similarity score used for the index + Currently supported: 'euclidean', 'cosine', and 'dotProduct' + auto_index_timeout: Timeout in seconds to wait for an auto-created index + to be ready. + """ + + create_vector_search_index( + collection=self._coll, + index_name=self.vector_index_name, + dimensions=dimensions, + path=self.embedding_key, + similarity=relevance_score_fn, + wait_until_complete=auto_index_timeout, + ) + + def add_texts( + self, + texts: Iterable[str], + metadatas: Optional[List[Dict[str, Any]]] = None, + ids: Optional[List[str]] = None, + batch_size: int = 100, + **kwargs: Any, + ) -> List[str]: + """Add texts, create embeddings, and add to the Collection and index. + + Important notes on ids: + - If _id or id is a key in the metadatas dicts, one must + pop them and provide as separate list. + - They must be unique. + - If they are not provided, the VectorStore will create unique ones, + stored as bson.ObjectIds internally, and strings in Langchain. + These will appear in Document.metadata with key, '_id'. + + Args: + texts: Iterable of strings to add to the vectorstore. + metadatas: Optional list of metadatas associated with the texts. + ids: Optional list of unique ids that will be used as index in VectorStore. + See note on ids. + batch_size: Number of documents to insert at a time. + Tuning this may help with performance and sidestep MongoDB limits. + + Returns: + List of ids added to the vectorstore. 
+ """ + from bson import ObjectId + + _metadatas = metadatas or [{} for _ in texts] + ids = [str(ObjectId()) for _ in range(len(list(texts)))] + metadatas_batch = _metadatas + + result_ids = [] + texts_batch = [] + metadatas_batch = [] + size = 0 + i = 0 + for j, (text, metadata) in enumerate(zip(texts, _metadatas)): + size += len(text) + len(metadata) + texts_batch.append(text) + metadatas_batch.append(metadata) + if (j + 1) % batch_size == 0 or size >= 47_000_000: + batch_res = self._bulk_embed_and_insert_texts( + texts_batch, metadatas_batch, ids[i : j + 1] + ) + result_ids.extend(batch_res) + texts_batch = [] + metadatas_batch = [] + size = 0 + i = j + 1 + if texts_batch: + batch_res = self._bulk_embed_and_insert_texts( + texts_batch, metadatas_batch, ids[i : j + 1] + ) + result_ids.extend(batch_res) + return result_ids + + def _embed_texts(self, texts: List[str]) -> List[List[float]]: + return [ + i.embedding + for i in self._openai_client.embeddings.create( + input=texts, + model=self.embedding_model, + dimensions=self.dimensions, + ).data + ] + + def _bulk_embed_and_insert_texts( + self, + texts: List[str], + metadatas: List[dict], + ids: List[str], + ) -> List[str]: + """Bulk insert single batch of texts, embeddings, and ids.""" + from bson import ObjectId + from pymongo.operations import ReplaceOne + + if not texts: + return [] + # Compute embedding vectors + embeddings = self._embed_texts(texts) + docs = [ + { + "_id": ObjectId(i), + self.text_key: t, + self.embedding_key: embedding, + **m, + } + for i, t, m, embedding in zip(ids, texts, metadatas, embeddings) + ] + operations = [ReplaceOne({"_id": doc["_id"]}, doc, upsert=True) for doc in docs] + # insert the documents in MongoDB Atlas + result = self._coll.bulk_write(operations) + assert result.upserted_ids is not None + return [str(_id) for _id in result.upserted_ids.values()] + + def _run(self, query: str) -> str: + try: + query_config = self.query_config or MongoDBVectorSearchConfig() + limit = query_config.limit + oversampling_factor = query_config.oversampling_factor + pre_filter = query_config.pre_filter + include_embeddings = query_config.include_embeddings + post_filter_pipeline = query_config.post_filter_pipeline + + # Create the embedding for the query + query_vector = self._embed_texts([query])[0] + + # Atlas Vector Search, potentially with filter + stage = { + "index": self.vector_index_name, + "path": self.embedding_key, + "queryVector": query_vector, + "numCandidates": limit * oversampling_factor, + "limit": limit, + } + if pre_filter: + stage["filter"] = pre_filter + + pipeline = [ + {"$vectorSearch": stage}, + {"$set": {"score": {"$meta": "vectorSearchScore"}}}, + ] + + # Remove embeddings unless requested + if not include_embeddings: + pipeline.append({"$project": {self.embedding_key: 0}}) + + # Post-processing + if post_filter_pipeline is not None: + pipeline.extend(post_filter_pipeline) + + # Execution + cursor = self._coll.aggregate(pipeline) # type: ignore[arg-type] + docs = [] + + # Format + for doc in cursor: + docs.append(doc) + return json.dumps(docs) + except Exception as e: + logger.error(f"Error: {e}") + return "" + + def __del__(self): + """Cleanup clients on deletion.""" + try: + if hasattr(self, "_client") and self._client: + self._client.close() + except Exception as e: + logger.error(f"Error: {e}") + + try: + if hasattr(self, "_openai_client") and self._openai_client: + self._openai_client.close() + except Exception as e: + logger.error(f"Error: {e}") diff --git 
diff --git a/tests/test_generate_tool_specs.py b/tests/test_generate_tool_specs.py
index 73034a174..eeb407be1 100644
--- a/tests/test_generate_tool_specs.py
+++ b/tests/test_generate_tool_specs.py
@@ -1,12 +1,13 @@
 import json
 from typing import List, Optional, Type
-
-import pytest
-from pydantic import BaseModel, Field
 from unittest import mock
-from generate_tool_specs import ToolSpecExtractor
+import pytest
+from crewai.tools.base_tool import BaseTool, EnvVar
+from pydantic import BaseModel, Field
+
+from generate_tool_specs import ToolSpecExtractor
+
 
 class MockToolSchema(BaseModel):
     query: str = Field(..., description="The query parameter")
@@ -19,15 +20,30 @@ class MockTool(BaseTool):
     description: str = "A tool that mocks search functionality"
     args_schema: Type[BaseModel] = MockToolSchema
 
-    another_parameter: str = Field("Another way to define a default value", description="")
+    another_parameter: str = Field(
+        "Another way to define a default value", description=""
+    )
     my_parameter: str = Field("This is default value", description="What a description")
     my_parameter_bool: bool = Field(False)
-    package_dependencies: List[str] = Field(["this-is-a-required-package", "another-required-package"], description="")
+    package_dependencies: List[str] = Field(
+        ["this-is-a-required-package", "another-required-package"], description=""
+    )
     env_vars: List[EnvVar] = [
-        EnvVar(name="SERPER_API_KEY", description="API key for Serper", required=True, default=None),
-        EnvVar(name="API_RATE_LIMIT", description="API rate limit", required=False, default="100")
+        EnvVar(
+            name="SERPER_API_KEY",
+            description="API key for Serper",
+            required=True,
+            default=None,
+        ),
+        EnvVar(
+            name="API_RATE_LIMIT",
+            description="API rate limit",
+            required=False,
+            default="100",
+        ),
     ]
 
+
 @pytest.fixture
 def extractor():
     ext = ToolSpecExtractor()
@@ -37,7 +53,7 @@ def extractor():
 def test_unwrap_schema(extractor):
     nested_schema = {
         "type": "function-after",
-        "schema": {"type": "default", "schema": {"type": "str", "value": "test"}}
+        "schema": {"type": "default", "schema": {"type": "str", "value": "test"}},
     }
     result = extractor._unwrap_schema(nested_schema)
     assert result["type"] == "str"
@@ -46,12 +62,15 @@
 @pytest.fixture
 def mock_tool_extractor(extractor):
-    with mock.patch("generate_tool_specs.dir", return_value=["MockTool"]), \
-        mock.patch("generate_tool_specs.getattr", return_value=MockTool):
+    with (
+        mock.patch("generate_tool_specs.dir", return_value=["MockTool"]),
+        mock.patch("generate_tool_specs.getattr", return_value=MockTool),
+    ):
         extractor.extract_all_tools()
         assert len(extractor.tools_spec) == 1
         return extractor.tools_spec[0]
+
 
 def test_extract_basic_tool_info(mock_tool_extractor):
     tool_info = mock_tool_extractor
@@ -69,6 +88,7 @@
     assert tool_info["humanized_name"] == "Mock Search Tool"
     assert tool_info["description"] == "A tool that mocks search functionality"
 
+
 def test_extract_init_params_schema(mock_tool_extractor):
     tool_info = mock_tool_extractor
     init_params_schema = tool_info["init_params_schema"]
@@ -80,20 +100,21 @@
         "type",
     }
 
-    another_parameter = init_params_schema['properties']['another_parameter']
+    another_parameter = init_params_schema["properties"]["another_parameter"]
     assert another_parameter["description"] == ""
     assert another_parameter["default"] == "Another way to define a default value"
     assert another_parameter["type"] == "string"
 
-    my_parameter = init_params_schema['properties']['my_parameter']
+    my_parameter = init_params_schema["properties"]["my_parameter"]
     assert my_parameter["description"] == "What a description"
     assert my_parameter["default"] == "This is default value"
     assert my_parameter["type"] == "string"
 
-    my_parameter_bool = init_params_schema['properties']['my_parameter_bool']
+    my_parameter_bool = init_params_schema["properties"]["my_parameter_bool"]
     assert my_parameter_bool["default"] == False
     assert my_parameter_bool["type"] == "boolean"
 
+
 def test_extract_env_vars(mock_tool_extractor):
     tool_info = mock_tool_extractor
 
@@ -109,6 +130,7 @@
     assert rate_limit_var["required"] == False
     assert rate_limit_var["default"] == "100"
 
+
 def test_extract_run_params_schema(mock_tool_extractor):
     tool_info = mock_tool_extractor
 
@@ -131,22 +153,31 @@
     filters_param = run_params_schema["properties"]["filters"]
     assert filters_param["description"] == "Optional filters to apply"
     assert filters_param["default"] == None
-    assert filters_param['anyOf'] == [{'items': {'type': 'string'}, 'type': 'array'}, {'type': 'null'}]
+    assert filters_param["anyOf"] == [
+        {"items": {"type": "string"}, "type": "array"},
+        {"type": "null"},
+    ]
+
 
 def test_extract_package_dependencies(mock_tool_extractor):
     tool_info = mock_tool_extractor
-    assert tool_info["package_dependencies"] == ["this-is-a-required-package", "another-required-package"]
+    assert tool_info["package_dependencies"] == [
+        "this-is-a-required-package",
+        "another-required-package",
+    ]
 
 
 def test_save_to_json(extractor, tmp_path):
-    extractor.tools_spec = [{
-        "name": "TestTool",
-        "humanized_name": "Test Tool",
-        "description": "A test tool",
-        "run_params_schema": [
-            {"name": "param1", "description": "Test parameter", "type": "str"}
-        ]
-    }]
+    extractor.tools_spec = [
+        {
+            "name": "TestTool",
+            "humanized_name": "Test Tool",
+            "description": "A test tool",
+            "run_params_schema": [
+                {"name": "param1", "description": "Test parameter", "type": "str"}
+            ],
+        }
+    ]
 
     file_path = tmp_path / "output.json"
     extractor.save_to_json(str(file_path))
diff --git a/tests/tools/test_mongodb_vector_search_tool.py b/tests/tools/test_mongodb_vector_search_tool.py
new file mode 100644
index 000000000..b76debbde
--- /dev/null
+++ b/tests/tools/test_mongodb_vector_search_tool.py
@@ -0,0 +1,75 @@
+import json
+from unittest.mock import patch
+
+import pytest
+
+from crewai_tools import MongoDBVectorSearchConfig, MongoDBVectorSearchTool
+
+
+# Unit Test Fixtures
+@pytest.fixture
+def mongodb_vector_search_tool():
+    tool = MongoDBVectorSearchTool(
+        connection_string="foo", database_name="bar", collection_name="test"
+    )
+    tool._embed_texts = lambda x: [[0.1]]
+    yield tool
+
+
+# Unit Tests
+def test_successful_query_execution(mongodb_vector_search_tool):
+    # Stub the aggregation result returned by the collection
+    with patch.object(mongodb_vector_search_tool._coll, "aggregate") as mock_aggregate:
+        mock_aggregate.return_value = [dict(text="foo", score=0.1, _id=1)]
+
+        results = json.loads(mongodb_vector_search_tool._run(query="sandwiches"))
+
+        assert len(results) == 1
+        assert results[0]["text"] == "foo"
+        assert results[0]["_id"] == 1
+
+
+def test_provide_config():
+    query_config = MongoDBVectorSearchConfig(limit=10)
+    tool = MongoDBVectorSearchTool(
+        connection_string="foo",
+        database_name="bar",
+        collection_name="test",
+        query_config=query_config,
+        vector_index_name="foo",
+        embedding_model="bar",
+    )
+    tool._embed_texts = lambda x: [[0.1]]
+    with patch.object(tool._coll, "aggregate") as mock_aggregate:
+        mock_aggregate.return_value = [dict(text="foo", score=0.1, _id=1)]
+
+        tool._run(query="sandwiches")
+        assert mock_aggregate.mock_calls[-1].args[0][0]["$vectorSearch"]["limit"] == 10
+
+
+def test_cleanup_on_deletion(mongodb_vector_search_tool):
+    with patch.object(mongodb_vector_search_tool, "_client") as mock_client:
+        # Trigger cleanup
+        mongodb_vector_search_tool.__del__()
+
+        mock_client.close.assert_called_once()
+
+
+def test_create_search_index(mongodb_vector_search_tool):
+    with patch(
+        "crewai_tools.tools.mongodb_vector_search_tool.vector_search.create_vector_search_index"
+    ) as mock_create_search_index:
+        mongodb_vector_search_tool.create_vector_search_index(dimensions=10)
+        kwargs = mock_create_search_index.mock_calls[0].kwargs
+        assert kwargs["dimensions"] == 10
+        assert kwargs["similarity"] == "cosine"
+
+
+def test_add_texts(mongodb_vector_search_tool):
+    with patch.object(mongodb_vector_search_tool._coll, "bulk_write") as bulk_write:
+        mongodb_vector_search_tool.add_texts(["foo"])
+        args = bulk_write.mock_calls[0].args
+        assert "ReplaceOne" in str(args[0][0])
+        assert "foo" in str(args[0][0])
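+
+
+def test_pre_filter_forwarded_to_vector_search(mongodb_vector_search_tool):
+    # Illustrative sketch (not part of the original changeset): a pre_filter
+    # supplied via query_config should surface as the "filter" key of the
+    # $vectorSearch stage that _run builds. Filter values are hypothetical.
+    mongodb_vector_search_tool.query_config = MongoDBVectorSearchConfig(
+        pre_filter={"genre": {"$eq": "fiction"}}
+    )
+    with patch.object(mongodb_vector_search_tool._coll, "aggregate") as mock_aggregate:
+        mock_aggregate.return_value = []
+        mongodb_vector_search_tool._run(query="sandwiches")
+        stage = mock_aggregate.mock_calls[-1].args[0][0]["$vectorSearch"]
+        assert stage["filter"] == {"genre": {"$eq": "fiction"}}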