diff --git a/lib/crewai-tools/src/crewai_tools/__init__.py b/lib/crewai-tools/src/crewai_tools/__init__.py index 5aded17a7..48d288a88 100644 --- a/lib/crewai-tools/src/crewai_tools/__init__.py +++ b/lib/crewai-tools/src/crewai_tools/__init__.py @@ -98,6 +98,11 @@ from crewai_tools.tools.mongodb_vector_search_tool.vector_search import ( MongoDBVectorSearchTool, ) from crewai_tools.tools.multion_tool.multion_tool import MultiOnTool +from crewai_tools.tools.oceanbase_vector_search_tool.oceanbase_vector_search_tool import ( + OceanBaseToolSchema, + OceanBaseVectorSearchConfig, + OceanBaseVectorSearchTool, +) from crewai_tools.tools.mysql_search_tool.mysql_search_tool import MySQLSearchTool from crewai_tools.tools.nl2sql.nl2sql_tool import NL2SQLTool from crewai_tools.tools.ocr_tool.ocr_tool import OCRTool @@ -243,6 +248,9 @@ __all__ = [ "MongoDBVectorSearchTool", "MultiOnTool", "MySQLSearchTool", + "OceanBaseToolSchema", + "OceanBaseVectorSearchConfig", + "OceanBaseVectorSearchTool", "NL2SQLTool", "OCRTool", "OxylabsAmazonProductScraperTool", diff --git a/lib/crewai-tools/src/crewai_tools/tools/__init__.py b/lib/crewai-tools/src/crewai_tools/tools/__init__.py index 51d32ddc2..395963dd2 100644 --- a/lib/crewai-tools/src/crewai_tools/tools/__init__.py +++ b/lib/crewai-tools/src/crewai_tools/tools/__init__.py @@ -87,6 +87,11 @@ from crewai_tools.tools.mongodb_vector_search_tool import ( MongoDBVectorSearchConfig, MongoDBVectorSearchTool, ) +from crewai_tools.tools.oceanbase_vector_search_tool import ( + OceanBaseToolSchema, + OceanBaseVectorSearchConfig, + OceanBaseVectorSearchTool, +) from crewai_tools.tools.multion_tool.multion_tool import MultiOnTool from crewai_tools.tools.mysql_search_tool.mysql_search_tool import MySQLSearchTool from crewai_tools.tools.nl2sql.nl2sql_tool import NL2SQLTool @@ -226,6 +231,9 @@ __all__ = [ "MongoDBVectorSearchConfig", "MongoDBVectorSearchTool", "MultiOnTool", + "OceanBaseToolSchema", + "OceanBaseVectorSearchConfig", + "OceanBaseVectorSearchTool", "MySQLSearchTool", "NL2SQLTool", "OCRTool", diff --git a/lib/crewai-tools/src/crewai_tools/tools/oceanbase_vector_search_tool/README.md b/lib/crewai-tools/src/crewai_tools/tools/oceanbase_vector_search_tool/README.md new file mode 100644 index 000000000..9aafeef18 --- /dev/null +++ b/lib/crewai-tools/src/crewai_tools/tools/oceanbase_vector_search_tool/README.md @@ -0,0 +1,144 @@ +# OceanBaseVectorSearchTool + +## Description + +This tool is designed for performing vector similarity searches within an OceanBase database. OceanBase is a distributed relational database developed by Ant Group that supports native vector indexing and search capabilities using HNSW (Hierarchical Navigable Small World) algorithm. + +Use this tool to find semantically similar documents to a given query by leveraging OceanBase's vector search functionality. + +For more information about OceanBase vector capabilities, see: +https://en.oceanbase.com/docs/common-oceanbase-database-10000000001976351 + +## Installation + +Install the crewai_tools package with OceanBase support by executing the following command in your terminal: + +```shell +pip install crewai-tools[oceanbase] +``` + +or + +```shell +uv add crewai-tools --extra oceanbase +``` + +## Example + +### Basic Usage + +```python +from crewai_tools import OceanBaseVectorSearchTool + +tool = OceanBaseVectorSearchTool( + connection_uri="127.0.0.1:2881", + user="root@test", + password="", + db_name="test", + table_name="documents", +) +``` + +### With Custom Configuration + +```python +from crewai_tools import OceanBaseVectorSearchConfig, OceanBaseVectorSearchTool + +query_config = OceanBaseVectorSearchConfig( + limit=10, + distance_func="cosine", + distance_threshold=0.5, +) + +tool = OceanBaseVectorSearchTool( + connection_uri="127.0.0.1:2881", + user="root@test", + password="your_password", + db_name="my_database", + table_name="my_documents", + vector_column_name="embedding", + text_column_name="content", + metadata_column_name="metadata", + query_config=query_config, + embedding_model="text-embedding-3-large", + dimensions=3072, +) +``` + +### Adding the Tool to an Agent + +```python +from crewai import Agent +from crewai_tools import OceanBaseVectorSearchTool + +tool = OceanBaseVectorSearchTool( + connection_uri="127.0.0.1:2881", + user="root@test", + db_name="test", + table_name="documents", +) + +rag_agent = Agent( + name="rag_agent", + role="You are a helpful assistant that can answer questions using the OceanBaseVectorSearchTool.", + goal="Answer user questions by searching relevant documents", + backstory="You have access to a knowledge base stored in OceanBase", + llm="gpt-4o-mini", + tools=[tool], +) +``` + +### Preloading Documents + +```python +from crewai_tools import OceanBaseVectorSearchTool +import os + +tool = OceanBaseVectorSearchTool( + connection_uri="127.0.0.1:2881", + user="root@test", + db_name="test", + table_name="documents", +) + +texts = [] +metadatas = [] +for filename in os.listdir("knowledge"): + with open(os.path.join("knowledge", filename), "r") as f: + texts.append(f.read()) + metadatas.append({"source": filename}) + +tool.add_texts(texts, metadatas=metadatas) +``` + +## Configuration Options + +### OceanBaseVectorSearchConfig + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `limit` | int | 4 | Number of documents to return | +| `distance_func` | str | "l2" | Distance function: "l2", "cosine", or "inner_product" | +| `distance_threshold` | float | None | Only return results with distance <= threshold | +| `include_embeddings` | bool | False | Whether to include embedding vectors in results | + +### OceanBaseVectorSearchTool + +| Parameter | Type | Required | Description | +|-----------|------|----------|-------------| +| `connection_uri` | str | Yes | OceanBase connection URI (e.g., "127.0.0.1:2881") | +| `user` | str | Yes | Username for connection (e.g., "root@test") | +| `password` | str | No | Password for connection | +| `db_name` | str | No | Database name (default: "test") | +| `table_name` | str | Yes | Table containing vector data | +| `vector_column_name` | str | No | Column with embeddings (default: "embedding") | +| `text_column_name` | str | No | Column with text content (default: "text") | +| `metadata_column_name` | str | No | Column with metadata (default: "metadata") | +| `embedding_model` | str | No | OpenAI model for embeddings (default: "text-embedding-3-large") | +| `dimensions` | int | No | Embedding dimensions (default: 1536) | +| `query_config` | OceanBaseVectorSearchConfig | No | Search configuration | + +## Environment Variables + +- `OPENAI_API_KEY`: Required for generating embeddings +- `AZURE_OPENAI_ENDPOINT`: Optional, for Azure OpenAI support diff --git a/lib/crewai-tools/src/crewai_tools/tools/oceanbase_vector_search_tool/__init__.py b/lib/crewai-tools/src/crewai_tools/tools/oceanbase_vector_search_tool/__init__.py new file mode 100644 index 000000000..c9bc8b901 --- /dev/null +++ b/lib/crewai-tools/src/crewai_tools/tools/oceanbase_vector_search_tool/__init__.py @@ -0,0 +1,12 @@ +from crewai_tools.tools.oceanbase_vector_search_tool.oceanbase_vector_search_tool import ( + OceanBaseToolSchema, + OceanBaseVectorSearchConfig, + OceanBaseVectorSearchTool, +) + + +__all__ = [ + "OceanBaseToolSchema", + "OceanBaseVectorSearchConfig", + "OceanBaseVectorSearchTool", +] diff --git a/lib/crewai-tools/src/crewai_tools/tools/oceanbase_vector_search_tool/oceanbase_vector_search_tool.py b/lib/crewai-tools/src/crewai_tools/tools/oceanbase_vector_search_tool/oceanbase_vector_search_tool.py new file mode 100644 index 000000000..47f2c134d --- /dev/null +++ b/lib/crewai-tools/src/crewai_tools/tools/oceanbase_vector_search_tool/oceanbase_vector_search_tool.py @@ -0,0 +1,265 @@ +from __future__ import annotations + +import json +from logging import getLogger +import os +from typing import Any + +from crewai.tools import BaseTool, EnvVar +from pydantic import BaseModel, Field + + +try: + import pyobvector # noqa: F401 + + PYOBVECTOR_AVAILABLE = True +except ImportError: + PYOBVECTOR_AVAILABLE = False + +logger = getLogger(__name__) + + +class OceanBaseToolSchema(BaseModel): + """Input schema for OceanBase vector search tool.""" + + query: str = Field( + ..., + description="The query to search for relevant information in the OceanBase database.", + ) + + +class OceanBaseVectorSearchConfig(BaseModel): + """Configuration for OceanBase vector search queries.""" + + limit: int = Field( + default=4, + description="Number of documents to return.", + ) + distance_threshold: float | None = Field( + default=None, + description="Only return results where distance is less than or equal to this threshold.", + ) + distance_func: str = Field( + default="l2", + description="Distance function to use for similarity search. Options: 'l2', 'cosine', 'inner_product'.", + ) + include_embeddings: bool = Field( + default=False, + description="Whether to include the embedding vector of each result.", + ) + + +class OceanBaseVectorSearchTool(BaseTool): + """Tool to perform vector search on OceanBase database.""" + + name: str = "OceanBaseVectorSearchTool" + description: str = ( + "A tool to perform vector similarity search on an OceanBase database " + "for retrieving relevant information from stored documents." + ) + + args_schema: type[BaseModel] = OceanBaseToolSchema + query_config: OceanBaseVectorSearchConfig | None = Field( + default=None, + description="OceanBase vector search query configuration.", + ) + embedding_model: str = Field( + default="text-embedding-3-large", + description="OpenAI embedding model to use for generating query embeddings.", + ) + dimensions: int = Field( + default=1536, + description="Number of dimensions in the embedding vector.", + ) + connection_uri: str = Field( + ..., + description="Connection URI for OceanBase (e.g., '127.0.0.1:2881').", + ) + user: str = Field( + ..., + description="Username for OceanBase connection (e.g., 'root@test').", + ) + password: str = Field( + default="", + description="Password for OceanBase connection.", + ) + db_name: str = Field( + default="test", + description="Database name in OceanBase.", + ) + table_name: str = Field( + ..., + description="Name of the table containing vector data.", + ) + vector_column_name: str = Field( + default="embedding", + description="Name of the column containing vector embeddings.", + ) + text_column_name: str = Field( + default="text", + description="Name of the column containing text content.", + ) + metadata_column_name: str | None = Field( + default="metadata", + description="Name of the column containing metadata (optional).", + ) + env_vars: list[EnvVar] = Field( + default_factory=lambda: [ + EnvVar( + name="OPENAI_API_KEY", + description="API key for OpenAI embeddings", + required=True, + ), + ] + ) + package_dependencies: list[str] = Field(default_factory=lambda: ["pyobvector"]) + + _client: Any = None + _openai_client: Any = None + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + if not PYOBVECTOR_AVAILABLE: + import click + + if click.confirm( + "You are missing the 'pyobvector' package. Would you like to install it?" + ): + import subprocess + + subprocess.run(["uv", "add", "pyobvector"], check=True) # noqa: S607 + else: + raise ImportError( + "The 'pyobvector' package is required for OceanBaseVectorSearchTool." + ) + + if "AZURE_OPENAI_ENDPOINT" in os.environ: + from openai import AzureOpenAI + + self._openai_client = AzureOpenAI() + elif "OPENAI_API_KEY" in os.environ: + from openai import Client + + self._openai_client = Client() + else: + raise ValueError( + "OPENAI_API_KEY environment variable is required for OceanBaseVectorSearchTool." + ) + + from pyobvector import ObVecClient + + self._client = ObVecClient( + uri=self.connection_uri, + user=self.user, + password=self.password, + db_name=self.db_name, + ) + + def _embed_text(self, text: str) -> list[float]: + """Generate embedding for the given text using OpenAI.""" + response = self._openai_client.embeddings.create( + input=[text], + model=self.embedding_model, + dimensions=self.dimensions, + ) + return response.data[0].embedding + + def _get_distance_func(self) -> Any: + """Get the appropriate distance function from pyobvector.""" + from pyobvector import cosine_distance, inner_product, l2_distance + + config = self.query_config or OceanBaseVectorSearchConfig() + distance_funcs = { + "l2": l2_distance, + "cosine": cosine_distance, + "inner_product": inner_product, + } + return distance_funcs.get(config.distance_func, l2_distance) + + def _run(self, query: str) -> str: + """Execute vector search on OceanBase.""" + try: + config = self.query_config or OceanBaseVectorSearchConfig() + + query_vector = self._embed_text(query) + + output_columns = [self.text_column_name] + if self.metadata_column_name: + output_columns.append(self.metadata_column_name) + + results = self._client.ann_search( + table_name=self.table_name, + vec_data=query_vector, + vec_column_name=self.vector_column_name, + distance_func=self._get_distance_func(), + with_dist=True, + topk=config.limit, + output_column_names=output_columns, + distance_threshold=config.distance_threshold, + ) + + formatted_results = [] + for row in results: + result_dict: dict[str, Any] = {} + + if len(row) >= 1: + result_dict["text"] = row[0] + if self.metadata_column_name and len(row) >= 2: + result_dict["metadata"] = row[1] + if len(row) > len(output_columns): + result_dict["distance"] = row[-1] + + formatted_results.append(result_dict) + + return json.dumps(formatted_results, indent=2, default=str) + + except Exception as e: + logger.error(f"Error during OceanBase vector search: {e}") + return json.dumps({"error": str(e)}) + + def add_texts( + self, + texts: list[str], + metadatas: list[dict[str, Any]] | None = None, + ids: list[str] | None = None, + ) -> list[str]: + """Add texts with embeddings to the OceanBase table. + + Args: + texts: List of text strings to add. + metadatas: Optional list of metadata dictionaries for each text. + ids: Optional list of unique IDs for each text. + + Returns: + List of IDs for the added texts. + """ + import uuid + + if ids is None: + ids = [str(uuid.uuid4()) for _ in texts] + + if metadatas is None: + metadatas = [{} for _ in texts] + + data = [] + for text, metadata, doc_id in zip(texts, metadatas, ids, strict=False): + embedding = self._embed_text(text) + row = { + "id": doc_id, + self.text_column_name: text, + self.vector_column_name: embedding, + } + if self.metadata_column_name: + row[self.metadata_column_name] = metadata + data.append(row) + + self._client.insert(self.table_name, data=data) + return ids + + def __del__(self) -> None: + """Cleanup clients on deletion.""" + try: + if hasattr(self, "_openai_client") and self._openai_client: + self._openai_client.close() + except Exception as e: + logger.error(f"Error closing OpenAI client: {e}") diff --git a/lib/crewai-tools/tests/tools/test_oceanbase_vector_search_tool.py b/lib/crewai-tools/tests/tools/test_oceanbase_vector_search_tool.py new file mode 100644 index 000000000..43349ad36 --- /dev/null +++ b/lib/crewai-tools/tests/tools/test_oceanbase_vector_search_tool.py @@ -0,0 +1,208 @@ +import json +import sys +from unittest.mock import MagicMock, patch + +import pytest + +from crewai_tools import OceanBaseVectorSearchConfig + + +mock_pyobvector = MagicMock() +mock_pyobvector.ObVecClient = MagicMock() +mock_pyobvector.l2_distance = MagicMock(return_value="l2_func") +mock_pyobvector.cosine_distance = MagicMock(return_value="cosine_func") +mock_pyobvector.inner_product = MagicMock(return_value="ip_func") +sys.modules["pyobvector"] = mock_pyobvector + + +@pytest.fixture +def mock_openai_client(): + """Create a mock OpenAI client.""" + mock_client = MagicMock() + mock_embedding = MagicMock() + mock_embedding.embedding = [0.1] * 1536 + mock_response = MagicMock() + mock_response.data = [mock_embedding] + mock_client.embeddings.create.return_value = mock_response + return mock_client + + +@pytest.fixture +def mock_obvec_client(): + """Create a mock OceanBase vector client.""" + mock_client = MagicMock() + return mock_client + + +@pytest.fixture +def oceanbase_vector_search_tool(mock_openai_client, mock_obvec_client): + """Create an OceanBaseVectorSearchTool with mocked clients.""" + from crewai_tools import OceanBaseVectorSearchTool + + with patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"}): + with patch( + "crewai_tools.tools.oceanbase_vector_search_tool.oceanbase_vector_search_tool.PYOBVECTOR_AVAILABLE", + True, + ): + mock_pyobvector.ObVecClient.return_value = mock_obvec_client + with patch("openai.Client") as mock_openai_class: + mock_openai_class.return_value = mock_openai_client + tool = OceanBaseVectorSearchTool( + connection_uri="127.0.0.1:2881", + user="root@test", + password="", + db_name="test", + table_name="test_table", + ) + tool._openai_client = mock_openai_client + tool._client = mock_obvec_client + yield tool + + +def test_successful_query_execution(oceanbase_vector_search_tool, mock_obvec_client): + """Test successful vector search query execution.""" + mock_obvec_client.ann_search.return_value = [ + ("test document content", {"source": "test.txt"}, 0.1), + ("another document", {"source": "test2.txt"}, 0.2), + ] + + results = json.loads(oceanbase_vector_search_tool._run(query="test query")) + + assert len(results) == 2 + assert results[0]["text"] == "test document content" + assert results[0]["metadata"] == {"source": "test.txt"} + assert results[0]["distance"] == 0.1 + + +def test_query_with_custom_config(mock_openai_client, mock_obvec_client): + """Test vector search with custom configuration.""" + from crewai_tools import OceanBaseVectorSearchTool + + query_config = OceanBaseVectorSearchConfig( + limit=10, + distance_func="cosine", + distance_threshold=0.5, + ) + + with patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"}): + with patch( + "crewai_tools.tools.oceanbase_vector_search_tool.oceanbase_vector_search_tool.PYOBVECTOR_AVAILABLE", + True, + ): + mock_pyobvector.ObVecClient.return_value = mock_obvec_client + with patch("openai.Client") as mock_openai_class: + mock_openai_class.return_value = mock_openai_client + tool = OceanBaseVectorSearchTool( + connection_uri="127.0.0.1:2881", + user="root@test", + db_name="test", + table_name="test_table", + query_config=query_config, + ) + tool._openai_client = mock_openai_client + tool._client = mock_obvec_client + + mock_obvec_client.ann_search.return_value = [("doc", {}, 0.3)] + + tool._run(query="test") + + call_kwargs = mock_obvec_client.ann_search.call_args.kwargs + assert call_kwargs["topk"] == 10 + assert call_kwargs["distance_threshold"] == 0.5 + + +def test_add_texts(oceanbase_vector_search_tool, mock_obvec_client): + """Test adding texts to the OceanBase table.""" + texts = ["document 1", "document 2"] + metadatas = [{"source": "file1.txt"}, {"source": "file2.txt"}] + + result_ids = oceanbase_vector_search_tool.add_texts(texts, metadatas=metadatas) + + assert len(result_ids) == 2 + mock_obvec_client.insert.assert_called_once() + call_args = mock_obvec_client.insert.call_args + assert call_args[0][0] == "test_table" + assert len(call_args[1]["data"]) == 2 + + +def test_add_texts_without_metadata(oceanbase_vector_search_tool, mock_obvec_client): + """Test adding texts without metadata.""" + texts = ["document 1", "document 2"] + + result_ids = oceanbase_vector_search_tool.add_texts(texts) + + assert len(result_ids) == 2 + mock_obvec_client.insert.assert_called_once() + + +def test_error_handling(oceanbase_vector_search_tool, mock_obvec_client): + """Test error handling during search.""" + mock_obvec_client.ann_search.side_effect = Exception("Database connection error") + + result = json.loads(oceanbase_vector_search_tool._run(query="test")) + + assert "error" in result + assert "Database connection error" in result["error"] + + +def test_config_defaults(): + """Test OceanBaseVectorSearchConfig default values.""" + config = OceanBaseVectorSearchConfig() + + assert config.limit == 4 + assert config.distance_func == "l2" + assert config.distance_threshold is None + assert config.include_embeddings is False + + +def test_config_custom_values(): + """Test OceanBaseVectorSearchConfig with custom values.""" + config = OceanBaseVectorSearchConfig( + limit=20, + distance_func="cosine", + distance_threshold=0.8, + include_embeddings=True, + ) + + assert config.limit == 20 + assert config.distance_func == "cosine" + assert config.distance_threshold == 0.8 + assert config.include_embeddings is True + + +def test_tool_schema(): + """Test OceanBaseToolSchema validation.""" + from crewai_tools import OceanBaseToolSchema + + schema = OceanBaseToolSchema(query="test query") + assert schema.query == "test query" + + +def test_tool_schema_requires_query(): + """Test that OceanBaseToolSchema requires a query.""" + from crewai_tools import OceanBaseToolSchema + from pydantic import ValidationError + + with pytest.raises(ValidationError): + OceanBaseToolSchema() + + +def test_distance_function_selection(oceanbase_vector_search_tool): + """Test that the correct distance function is selected.""" + oceanbase_vector_search_tool.query_config = OceanBaseVectorSearchConfig( + distance_func="l2" + ) + func = oceanbase_vector_search_tool._get_distance_func() + assert func == mock_pyobvector.l2_distance + + oceanbase_vector_search_tool.query_config = OceanBaseVectorSearchConfig( + distance_func="cosine" + ) + func = oceanbase_vector_search_tool._get_distance_func() + assert func == mock_pyobvector.cosine_distance + + oceanbase_vector_search_tool.query_config = OceanBaseVectorSearchConfig( + distance_func="inner_product" + ) + func = oceanbase_vector_search_tool._get_distance_func() + assert func == mock_pyobvector.inner_product