mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-25 08:08:14 +00:00
Add MongoDB Vector Search Tool (#319)
* INTPYTHON-580 Design and Implement MongoDBVectorSearchTool * add implementation * wip * wip * finish tests * add todo * refactor to wrap langchain-mongodb * cleanup * address review * Fix usage of EnvVar class * inline code * lint * lint * fix usage of SearchIndexModel * Refactor: Update EnvVar import path and remove unused tests.utils module - Changed import of EnvVar from tests.utils to crewai.tools in multiple files. - Updated README.md for MongoDB vector search tool with additional context. - Modified subprocess command in vector_search.py for package installation. - Cleaned up test_generate_tool_specs.py to improve mock patching syntax. - Deleted unused tests/utils.py file. * update the crewai dep and the lockfile * chore: update package versions and dependencies in uv.lock - Removed `auth0-python` package. - Updated `crewai` version to 0.140.0 and adjusted its dependencies. - Changed `json-repair` version to 0.25.2. - Updated `litellm` version to 1.72.6. - Modified dependency markers for several packages to improve compatibility with Python versions. * refactor: improve MongoDB vector search tool with enhanced error handling and new dimensions field - Added logging for error handling in the _run method and during client cleanup. - Introduced a new 'dimensions' field in the MongoDBVectorSearchConfig for embedding vector size. - Refactored the _run method to return JSON formatted results and handle exceptions gracefully. - Cleaned up import statements and improved code readability. * address review * update tests * debug * fix test * fix test * fix test * support azure openai --------- Co-authored-by: lorenzejay <lorenzejaytech@gmail.com>
This commit is contained in:
@@ -1,12 +1,13 @@
|
||||
import json
|
||||
from typing import List, Optional, Type
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
from unittest import mock
|
||||
|
||||
from generate_tool_specs import ToolSpecExtractor
|
||||
import pytest
|
||||
from crewai.tools.base_tool import BaseTool, EnvVar
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from generate_tool_specs import ToolSpecExtractor
|
||||
|
||||
|
||||
class MockToolSchema(BaseModel):
|
||||
query: str = Field(..., description="The query parameter")
|
||||
@@ -19,15 +20,30 @@ class MockTool(BaseTool):
|
||||
description: str = "A tool that mocks search functionality"
|
||||
args_schema: Type[BaseModel] = MockToolSchema
|
||||
|
||||
another_parameter: str = Field("Another way to define a default value", description="")
|
||||
another_parameter: str = Field(
|
||||
"Another way to define a default value", description=""
|
||||
)
|
||||
my_parameter: str = Field("This is default value", description="What a description")
|
||||
my_parameter_bool: bool = Field(False)
|
||||
package_dependencies: List[str] = Field(["this-is-a-required-package", "another-required-package"], description="")
|
||||
package_dependencies: List[str] = Field(
|
||||
["this-is-a-required-package", "another-required-package"], description=""
|
||||
)
|
||||
env_vars: List[EnvVar] = [
|
||||
EnvVar(name="SERPER_API_KEY", description="API key for Serper", required=True, default=None),
|
||||
EnvVar(name="API_RATE_LIMIT", description="API rate limit", required=False, default="100")
|
||||
EnvVar(
|
||||
name="SERPER_API_KEY",
|
||||
description="API key for Serper",
|
||||
required=True,
|
||||
default=None,
|
||||
),
|
||||
EnvVar(
|
||||
name="API_RATE_LIMIT",
|
||||
description="API rate limit",
|
||||
required=False,
|
||||
default="100",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def extractor():
|
||||
ext = ToolSpecExtractor()
|
||||
@@ -37,7 +53,7 @@ def extractor():
|
||||
def test_unwrap_schema(extractor):
|
||||
nested_schema = {
|
||||
"type": "function-after",
|
||||
"schema": {"type": "default", "schema": {"type": "str", "value": "test"}}
|
||||
"schema": {"type": "default", "schema": {"type": "str", "value": "test"}},
|
||||
}
|
||||
result = extractor._unwrap_schema(nested_schema)
|
||||
assert result["type"] == "str"
|
||||
@@ -46,12 +62,15 @@ def test_unwrap_schema(extractor):
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool_extractor(extractor):
|
||||
with mock.patch("generate_tool_specs.dir", return_value=["MockTool"]), \
|
||||
mock.patch("generate_tool_specs.getattr", return_value=MockTool):
|
||||
with (
|
||||
mock.patch("generate_tool_specs.dir", return_value=["MockTool"]),
|
||||
mock.patch("generate_tool_specs.getattr", return_value=MockTool),
|
||||
):
|
||||
extractor.extract_all_tools()
|
||||
assert len(extractor.tools_spec) == 1
|
||||
return extractor.tools_spec[0]
|
||||
|
||||
|
||||
def test_extract_basic_tool_info(mock_tool_extractor):
|
||||
tool_info = mock_tool_extractor
|
||||
|
||||
@@ -69,6 +88,7 @@ def test_extract_basic_tool_info(mock_tool_extractor):
|
||||
assert tool_info["humanized_name"] == "Mock Search Tool"
|
||||
assert tool_info["description"] == "A tool that mocks search functionality"
|
||||
|
||||
|
||||
def test_extract_init_params_schema(mock_tool_extractor):
|
||||
tool_info = mock_tool_extractor
|
||||
init_params_schema = tool_info["init_params_schema"]
|
||||
@@ -80,20 +100,21 @@ def test_extract_init_params_schema(mock_tool_extractor):
|
||||
"type",
|
||||
}
|
||||
|
||||
another_parameter = init_params_schema['properties']['another_parameter']
|
||||
another_parameter = init_params_schema["properties"]["another_parameter"]
|
||||
assert another_parameter["description"] == ""
|
||||
assert another_parameter["default"] == "Another way to define a default value"
|
||||
assert another_parameter["type"] == "string"
|
||||
|
||||
my_parameter = init_params_schema['properties']['my_parameter']
|
||||
my_parameter = init_params_schema["properties"]["my_parameter"]
|
||||
assert my_parameter["description"] == "What a description"
|
||||
assert my_parameter["default"] == "This is default value"
|
||||
assert my_parameter["type"] == "string"
|
||||
|
||||
my_parameter_bool = init_params_schema['properties']['my_parameter_bool']
|
||||
my_parameter_bool = init_params_schema["properties"]["my_parameter_bool"]
|
||||
assert my_parameter_bool["default"] == False
|
||||
assert my_parameter_bool["type"] == "boolean"
|
||||
|
||||
|
||||
def test_extract_env_vars(mock_tool_extractor):
|
||||
tool_info = mock_tool_extractor
|
||||
|
||||
@@ -109,6 +130,7 @@ def test_extract_env_vars(mock_tool_extractor):
|
||||
assert rate_limit_var["required"] == False
|
||||
assert rate_limit_var["default"] == "100"
|
||||
|
||||
|
||||
def test_extract_run_params_schema(mock_tool_extractor):
|
||||
tool_info = mock_tool_extractor
|
||||
|
||||
@@ -131,22 +153,31 @@ def test_extract_run_params_schema(mock_tool_extractor):
|
||||
filters_param = run_params_schema["properties"]["filters"]
|
||||
assert filters_param["description"] == "Optional filters to apply"
|
||||
assert filters_param["default"] == None
|
||||
assert filters_param['anyOf'] == [{'items': {'type': 'string'}, 'type': 'array'}, {'type': 'null'}]
|
||||
assert filters_param["anyOf"] == [
|
||||
{"items": {"type": "string"}, "type": "array"},
|
||||
{"type": "null"},
|
||||
]
|
||||
|
||||
|
||||
def test_extract_package_dependencies(mock_tool_extractor):
|
||||
tool_info = mock_tool_extractor
|
||||
assert tool_info["package_dependencies"] == ["this-is-a-required-package", "another-required-package"]
|
||||
assert tool_info["package_dependencies"] == [
|
||||
"this-is-a-required-package",
|
||||
"another-required-package",
|
||||
]
|
||||
|
||||
|
||||
def test_save_to_json(extractor, tmp_path):
|
||||
extractor.tools_spec = [{
|
||||
"name": "TestTool",
|
||||
"humanized_name": "Test Tool",
|
||||
"description": "A test tool",
|
||||
"run_params_schema": [
|
||||
{"name": "param1", "description": "Test parameter", "type": "str"}
|
||||
]
|
||||
}]
|
||||
extractor.tools_spec = [
|
||||
{
|
||||
"name": "TestTool",
|
||||
"humanized_name": "Test Tool",
|
||||
"description": "A test tool",
|
||||
"run_params_schema": [
|
||||
{"name": "param1", "description": "Test parameter", "type": "str"}
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
file_path = tmp_path / "output.json"
|
||||
extractor.save_to_json(str(file_path))
|
||||
|
||||
75
tests/tools/test_mongodb_vector_search_tool.py
Normal file
75
tests/tools/test_mongodb_vector_search_tool.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai_tools import MongoDBVectorSearchConfig, MongoDBVectorSearchTool
|
||||
|
||||
|
||||
# Unit Test Fixtures
|
||||
@pytest.fixture
|
||||
def mongodb_vector_search_tool():
|
||||
tool = MongoDBVectorSearchTool(
|
||||
connection_string="foo", database_name="bar", collection_name="test"
|
||||
)
|
||||
tool._embed_texts = lambda x: [[0.1]]
|
||||
yield tool
|
||||
|
||||
|
||||
# Unit Tests
|
||||
def test_successful_query_execution(mongodb_vector_search_tool):
|
||||
# Enable embedding
|
||||
with patch.object(mongodb_vector_search_tool._coll, "aggregate") as mock_aggregate:
|
||||
mock_aggregate.return_value = [dict(text="foo", score=0.1, _id=1)]
|
||||
|
||||
results = json.loads(mongodb_vector_search_tool._run(query="sandwiches"))
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0]["text"] == "foo"
|
||||
assert results[0]["_id"] == 1
|
||||
|
||||
|
||||
def test_provide_config():
|
||||
query_config = MongoDBVectorSearchConfig(limit=10)
|
||||
tool = MongoDBVectorSearchTool(
|
||||
connection_string="foo",
|
||||
database_name="bar",
|
||||
collection_name="test",
|
||||
query_config=query_config,
|
||||
vector_index_name="foo",
|
||||
embedding_model="bar",
|
||||
)
|
||||
tool._embed_texts = lambda x: [[0.1]]
|
||||
with patch.object(tool._coll, "aggregate") as mock_aggregate:
|
||||
mock_aggregate.return_value = [dict(text="foo", score=0.1, _id=1)]
|
||||
|
||||
tool._run(query="sandwiches")
|
||||
assert mock_aggregate.mock_calls[-1].args[0][0]["$vectorSearch"]["limit"] == 10
|
||||
|
||||
mock_aggregate.return_value = [dict(text="foo", score=0.1, _id=1)]
|
||||
|
||||
|
||||
def test_cleanup_on_deletion(mongodb_vector_search_tool):
|
||||
with patch.object(mongodb_vector_search_tool, "_client") as mock_client:
|
||||
# Trigger cleanup
|
||||
mongodb_vector_search_tool.__del__()
|
||||
|
||||
mock_client.close.assert_called_once()
|
||||
|
||||
|
||||
def test_create_search_index(mongodb_vector_search_tool):
|
||||
with patch(
|
||||
"crewai_tools.tools.mongodb_vector_search_tool.vector_search.create_vector_search_index"
|
||||
) as mock_create_search_index:
|
||||
mongodb_vector_search_tool.create_vector_search_index(dimensions=10)
|
||||
kwargs = mock_create_search_index.mock_calls[0].kwargs
|
||||
assert kwargs["dimensions"] == 10
|
||||
assert kwargs["similarity"] == "cosine"
|
||||
|
||||
|
||||
def test_add_texts(mongodb_vector_search_tool):
|
||||
with patch.object(mongodb_vector_search_tool._coll, "bulk_write") as bulk_write:
|
||||
mongodb_vector_search_tool.add_texts(["foo"])
|
||||
args = bulk_write.mock_calls[0].args
|
||||
assert "ReplaceOne" in str(args[0][0])
|
||||
assert "foo" in str(args[0][0])
|
||||
Reference in New Issue
Block a user