Compare commits

...

3 Commits

Author SHA1 Message Date
Devin AI
c2e1fa0231 chore: re-trigger CI tests
Co-Authored-By: João <joao@crewai.com>
2026-02-02 09:27:56 +00:00
Devin AI
5ecbfa26cb chore: re-trigger CI
Co-Authored-By: João <joao@crewai.com>
2026-02-02 09:24:48 +00:00
Devin AI
87cf7b9234 feat: add OceanBase vector search tool
- Add OceanBaseVectorSearchTool for vector similarity search on OceanBase database
- Support for L2, cosine, and inner product distance functions
- Integration with OpenAI embeddings for query vectorization
- Add comprehensive tests with mocked pyobvector module
- Add README documentation with usage examples

Closes #4332

Co-Authored-By: João <joao@crewai.com>
2026-02-02 09:21:05 +00:00
6 changed files with 645 additions and 0 deletions

View File

@@ -98,6 +98,11 @@ from crewai_tools.tools.mongodb_vector_search_tool.vector_search import (
MongoDBVectorSearchTool,
)
from crewai_tools.tools.multion_tool.multion_tool import MultiOnTool
from crewai_tools.tools.oceanbase_vector_search_tool.oceanbase_vector_search_tool import (
OceanBaseToolSchema,
OceanBaseVectorSearchConfig,
OceanBaseVectorSearchTool,
)
from crewai_tools.tools.mysql_search_tool.mysql_search_tool import MySQLSearchTool
from crewai_tools.tools.nl2sql.nl2sql_tool import NL2SQLTool
from crewai_tools.tools.ocr_tool.ocr_tool import OCRTool
@@ -243,6 +248,9 @@ __all__ = [
"MongoDBVectorSearchTool",
"MultiOnTool",
"MySQLSearchTool",
"OceanBaseToolSchema",
"OceanBaseVectorSearchConfig",
"OceanBaseVectorSearchTool",
"NL2SQLTool",
"OCRTool",
"OxylabsAmazonProductScraperTool",

View File

@@ -87,6 +87,11 @@ from crewai_tools.tools.mongodb_vector_search_tool import (
MongoDBVectorSearchConfig,
MongoDBVectorSearchTool,
)
from crewai_tools.tools.oceanbase_vector_search_tool import (
OceanBaseToolSchema,
OceanBaseVectorSearchConfig,
OceanBaseVectorSearchTool,
)
from crewai_tools.tools.multion_tool.multion_tool import MultiOnTool
from crewai_tools.tools.mysql_search_tool.mysql_search_tool import MySQLSearchTool
from crewai_tools.tools.nl2sql.nl2sql_tool import NL2SQLTool
@@ -226,6 +231,9 @@ __all__ = [
"MongoDBVectorSearchConfig",
"MongoDBVectorSearchTool",
"MultiOnTool",
"OceanBaseToolSchema",
"OceanBaseVectorSearchConfig",
"OceanBaseVectorSearchTool",
"MySQLSearchTool",
"NL2SQLTool",
"OCRTool",

View File

@@ -0,0 +1,144 @@
# OceanBaseVectorSearchTool
## Description
This tool is designed for performing vector similarity searches within an OceanBase database. OceanBase is a distributed relational database developed by Ant Group that supports native vector indexing and search capabilities using HNSW (Hierarchical Navigable Small World) algorithm.
Use this tool to find semantically similar documents to a given query by leveraging OceanBase's vector search functionality.
For more information about OceanBase vector capabilities, see:
https://en.oceanbase.com/docs/common-oceanbase-database-10000000001976351
## Installation
Install the crewai_tools package with OceanBase support by executing the following command in your terminal:
```shell
pip install crewai-tools[oceanbase]
```
or
```shell
uv add crewai-tools --extra oceanbase
```
## Example
### Basic Usage
```python
from crewai_tools import OceanBaseVectorSearchTool
tool = OceanBaseVectorSearchTool(
connection_uri="127.0.0.1:2881",
user="root@test",
password="",
db_name="test",
table_name="documents",
)
```
### With Custom Configuration
```python
from crewai_tools import OceanBaseVectorSearchConfig, OceanBaseVectorSearchTool
query_config = OceanBaseVectorSearchConfig(
limit=10,
distance_func="cosine",
distance_threshold=0.5,
)
tool = OceanBaseVectorSearchTool(
connection_uri="127.0.0.1:2881",
user="root@test",
password="your_password",
db_name="my_database",
table_name="my_documents",
vector_column_name="embedding",
text_column_name="content",
metadata_column_name="metadata",
query_config=query_config,
embedding_model="text-embedding-3-large",
dimensions=3072,
)
```
### Adding the Tool to an Agent
```python
from crewai import Agent
from crewai_tools import OceanBaseVectorSearchTool
tool = OceanBaseVectorSearchTool(
connection_uri="127.0.0.1:2881",
user="root@test",
db_name="test",
table_name="documents",
)
rag_agent = Agent(
name="rag_agent",
role="You are a helpful assistant that can answer questions using the OceanBaseVectorSearchTool.",
goal="Answer user questions by searching relevant documents",
backstory="You have access to a knowledge base stored in OceanBase",
llm="gpt-4o-mini",
tools=[tool],
)
```
### Preloading Documents
```python
from crewai_tools import OceanBaseVectorSearchTool
import os
tool = OceanBaseVectorSearchTool(
connection_uri="127.0.0.1:2881",
user="root@test",
db_name="test",
table_name="documents",
)
texts = []
metadatas = []
for filename in os.listdir("knowledge"):
with open(os.path.join("knowledge", filename), "r") as f:
texts.append(f.read())
metadatas.append({"source": filename})
tool.add_texts(texts, metadatas=metadatas)
```
## Configuration Options
### OceanBaseVectorSearchConfig
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `limit` | int | 4 | Number of documents to return |
| `distance_func` | str | "l2" | Distance function: "l2", "cosine", or "inner_product" |
| `distance_threshold` | float | None | Only return results with distance <= threshold |
| `include_embeddings` | bool | False | Whether to include embedding vectors in results |
### OceanBaseVectorSearchTool
| Parameter | Type | Required | Description |
|-----------|------|----------|-------------|
| `connection_uri` | str | Yes | OceanBase connection URI (e.g., "127.0.0.1:2881") |
| `user` | str | Yes | Username for connection (e.g., "root@test") |
| `password` | str | No | Password for connection |
| `db_name` | str | No | Database name (default: "test") |
| `table_name` | str | Yes | Table containing vector data |
| `vector_column_name` | str | No | Column with embeddings (default: "embedding") |
| `text_column_name` | str | No | Column with text content (default: "text") |
| `metadata_column_name` | str | No | Column with metadata (default: "metadata") |
| `embedding_model` | str | No | OpenAI model for embeddings (default: "text-embedding-3-large") |
| `dimensions` | int | No | Embedding dimensions (default: 1536) |
| `query_config` | OceanBaseVectorSearchConfig | No | Search configuration |
## Environment Variables
- `OPENAI_API_KEY`: Required for generating embeddings
- `AZURE_OPENAI_ENDPOINT`: Optional, for Azure OpenAI support

View File

@@ -0,0 +1,12 @@
from crewai_tools.tools.oceanbase_vector_search_tool.oceanbase_vector_search_tool import (
OceanBaseToolSchema,
OceanBaseVectorSearchConfig,
OceanBaseVectorSearchTool,
)
__all__ = [
"OceanBaseToolSchema",
"OceanBaseVectorSearchConfig",
"OceanBaseVectorSearchTool",
]

View File

@@ -0,0 +1,265 @@
from __future__ import annotations
import json
from logging import getLogger
import os
from typing import Any
from crewai.tools import BaseTool, EnvVar
from pydantic import BaseModel, Field
try:
import pyobvector # noqa: F401
PYOBVECTOR_AVAILABLE = True
except ImportError:
PYOBVECTOR_AVAILABLE = False
logger = getLogger(__name__)
class OceanBaseToolSchema(BaseModel):
"""Input schema for OceanBase vector search tool."""
query: str = Field(
...,
description="The query to search for relevant information in the OceanBase database.",
)
class OceanBaseVectorSearchConfig(BaseModel):
"""Configuration for OceanBase vector search queries."""
limit: int = Field(
default=4,
description="Number of documents to return.",
)
distance_threshold: float | None = Field(
default=None,
description="Only return results where distance is less than or equal to this threshold.",
)
distance_func: str = Field(
default="l2",
description="Distance function to use for similarity search. Options: 'l2', 'cosine', 'inner_product'.",
)
include_embeddings: bool = Field(
default=False,
description="Whether to include the embedding vector of each result.",
)
class OceanBaseVectorSearchTool(BaseTool):
"""Tool to perform vector search on OceanBase database."""
name: str = "OceanBaseVectorSearchTool"
description: str = (
"A tool to perform vector similarity search on an OceanBase database "
"for retrieving relevant information from stored documents."
)
args_schema: type[BaseModel] = OceanBaseToolSchema
query_config: OceanBaseVectorSearchConfig | None = Field(
default=None,
description="OceanBase vector search query configuration.",
)
embedding_model: str = Field(
default="text-embedding-3-large",
description="OpenAI embedding model to use for generating query embeddings.",
)
dimensions: int = Field(
default=1536,
description="Number of dimensions in the embedding vector.",
)
connection_uri: str = Field(
...,
description="Connection URI for OceanBase (e.g., '127.0.0.1:2881').",
)
user: str = Field(
...,
description="Username for OceanBase connection (e.g., 'root@test').",
)
password: str = Field(
default="",
description="Password for OceanBase connection.",
)
db_name: str = Field(
default="test",
description="Database name in OceanBase.",
)
table_name: str = Field(
...,
description="Name of the table containing vector data.",
)
vector_column_name: str = Field(
default="embedding",
description="Name of the column containing vector embeddings.",
)
text_column_name: str = Field(
default="text",
description="Name of the column containing text content.",
)
metadata_column_name: str | None = Field(
default="metadata",
description="Name of the column containing metadata (optional).",
)
env_vars: list[EnvVar] = Field(
default_factory=lambda: [
EnvVar(
name="OPENAI_API_KEY",
description="API key for OpenAI embeddings",
required=True,
),
]
)
package_dependencies: list[str] = Field(default_factory=lambda: ["pyobvector"])
_client: Any = None
_openai_client: Any = None
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
if not PYOBVECTOR_AVAILABLE:
import click
if click.confirm(
"You are missing the 'pyobvector' package. Would you like to install it?"
):
import subprocess
subprocess.run(["uv", "add", "pyobvector"], check=True) # noqa: S607
else:
raise ImportError(
"The 'pyobvector' package is required for OceanBaseVectorSearchTool."
)
if "AZURE_OPENAI_ENDPOINT" in os.environ:
from openai import AzureOpenAI
self._openai_client = AzureOpenAI()
elif "OPENAI_API_KEY" in os.environ:
from openai import Client
self._openai_client = Client()
else:
raise ValueError(
"OPENAI_API_KEY environment variable is required for OceanBaseVectorSearchTool."
)
from pyobvector import ObVecClient
self._client = ObVecClient(
uri=self.connection_uri,
user=self.user,
password=self.password,
db_name=self.db_name,
)
def _embed_text(self, text: str) -> list[float]:
"""Generate embedding for the given text using OpenAI."""
response = self._openai_client.embeddings.create(
input=[text],
model=self.embedding_model,
dimensions=self.dimensions,
)
return response.data[0].embedding
def _get_distance_func(self) -> Any:
"""Get the appropriate distance function from pyobvector."""
from pyobvector import cosine_distance, inner_product, l2_distance
config = self.query_config or OceanBaseVectorSearchConfig()
distance_funcs = {
"l2": l2_distance,
"cosine": cosine_distance,
"inner_product": inner_product,
}
return distance_funcs.get(config.distance_func, l2_distance)
def _run(self, query: str) -> str:
"""Execute vector search on OceanBase."""
try:
config = self.query_config or OceanBaseVectorSearchConfig()
query_vector = self._embed_text(query)
output_columns = [self.text_column_name]
if self.metadata_column_name:
output_columns.append(self.metadata_column_name)
results = self._client.ann_search(
table_name=self.table_name,
vec_data=query_vector,
vec_column_name=self.vector_column_name,
distance_func=self._get_distance_func(),
with_dist=True,
topk=config.limit,
output_column_names=output_columns,
distance_threshold=config.distance_threshold,
)
formatted_results = []
for row in results:
result_dict: dict[str, Any] = {}
if len(row) >= 1:
result_dict["text"] = row[0]
if self.metadata_column_name and len(row) >= 2:
result_dict["metadata"] = row[1]
if len(row) > len(output_columns):
result_dict["distance"] = row[-1]
formatted_results.append(result_dict)
return json.dumps(formatted_results, indent=2, default=str)
except Exception as e:
logger.error(f"Error during OceanBase vector search: {e}")
return json.dumps({"error": str(e)})
def add_texts(
self,
texts: list[str],
metadatas: list[dict[str, Any]] | None = None,
ids: list[str] | None = None,
) -> list[str]:
"""Add texts with embeddings to the OceanBase table.
Args:
texts: List of text strings to add.
metadatas: Optional list of metadata dictionaries for each text.
ids: Optional list of unique IDs for each text.
Returns:
List of IDs for the added texts.
"""
import uuid
if ids is None:
ids = [str(uuid.uuid4()) for _ in texts]
if metadatas is None:
metadatas = [{} for _ in texts]
data = []
for text, metadata, doc_id in zip(texts, metadatas, ids, strict=False):
embedding = self._embed_text(text)
row = {
"id": doc_id,
self.text_column_name: text,
self.vector_column_name: embedding,
}
if self.metadata_column_name:
row[self.metadata_column_name] = metadata
data.append(row)
self._client.insert(self.table_name, data=data)
return ids
def __del__(self) -> None:
"""Cleanup clients on deletion."""
try:
if hasattr(self, "_openai_client") and self._openai_client:
self._openai_client.close()
except Exception as e:
logger.error(f"Error closing OpenAI client: {e}")

View File

@@ -0,0 +1,208 @@
import json
import sys
from unittest.mock import MagicMock, patch
import pytest
from crewai_tools import OceanBaseVectorSearchConfig
mock_pyobvector = MagicMock()
mock_pyobvector.ObVecClient = MagicMock()
mock_pyobvector.l2_distance = MagicMock(return_value="l2_func")
mock_pyobvector.cosine_distance = MagicMock(return_value="cosine_func")
mock_pyobvector.inner_product = MagicMock(return_value="ip_func")
sys.modules["pyobvector"] = mock_pyobvector
@pytest.fixture
def mock_openai_client():
"""Create a mock OpenAI client."""
mock_client = MagicMock()
mock_embedding = MagicMock()
mock_embedding.embedding = [0.1] * 1536
mock_response = MagicMock()
mock_response.data = [mock_embedding]
mock_client.embeddings.create.return_value = mock_response
return mock_client
@pytest.fixture
def mock_obvec_client():
"""Create a mock OceanBase vector client."""
mock_client = MagicMock()
return mock_client
@pytest.fixture
def oceanbase_vector_search_tool(mock_openai_client, mock_obvec_client):
"""Create an OceanBaseVectorSearchTool with mocked clients."""
from crewai_tools import OceanBaseVectorSearchTool
with patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"}):
with patch(
"crewai_tools.tools.oceanbase_vector_search_tool.oceanbase_vector_search_tool.PYOBVECTOR_AVAILABLE",
True,
):
mock_pyobvector.ObVecClient.return_value = mock_obvec_client
with patch("openai.Client") as mock_openai_class:
mock_openai_class.return_value = mock_openai_client
tool = OceanBaseVectorSearchTool(
connection_uri="127.0.0.1:2881",
user="root@test",
password="",
db_name="test",
table_name="test_table",
)
tool._openai_client = mock_openai_client
tool._client = mock_obvec_client
yield tool
def test_successful_query_execution(oceanbase_vector_search_tool, mock_obvec_client):
"""Test successful vector search query execution."""
mock_obvec_client.ann_search.return_value = [
("test document content", {"source": "test.txt"}, 0.1),
("another document", {"source": "test2.txt"}, 0.2),
]
results = json.loads(oceanbase_vector_search_tool._run(query="test query"))
assert len(results) == 2
assert results[0]["text"] == "test document content"
assert results[0]["metadata"] == {"source": "test.txt"}
assert results[0]["distance"] == 0.1
def test_query_with_custom_config(mock_openai_client, mock_obvec_client):
"""Test vector search with custom configuration."""
from crewai_tools import OceanBaseVectorSearchTool
query_config = OceanBaseVectorSearchConfig(
limit=10,
distance_func="cosine",
distance_threshold=0.5,
)
with patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"}):
with patch(
"crewai_tools.tools.oceanbase_vector_search_tool.oceanbase_vector_search_tool.PYOBVECTOR_AVAILABLE",
True,
):
mock_pyobvector.ObVecClient.return_value = mock_obvec_client
with patch("openai.Client") as mock_openai_class:
mock_openai_class.return_value = mock_openai_client
tool = OceanBaseVectorSearchTool(
connection_uri="127.0.0.1:2881",
user="root@test",
db_name="test",
table_name="test_table",
query_config=query_config,
)
tool._openai_client = mock_openai_client
tool._client = mock_obvec_client
mock_obvec_client.ann_search.return_value = [("doc", {}, 0.3)]
tool._run(query="test")
call_kwargs = mock_obvec_client.ann_search.call_args.kwargs
assert call_kwargs["topk"] == 10
assert call_kwargs["distance_threshold"] == 0.5
def test_add_texts(oceanbase_vector_search_tool, mock_obvec_client):
"""Test adding texts to the OceanBase table."""
texts = ["document 1", "document 2"]
metadatas = [{"source": "file1.txt"}, {"source": "file2.txt"}]
result_ids = oceanbase_vector_search_tool.add_texts(texts, metadatas=metadatas)
assert len(result_ids) == 2
mock_obvec_client.insert.assert_called_once()
call_args = mock_obvec_client.insert.call_args
assert call_args[0][0] == "test_table"
assert len(call_args[1]["data"]) == 2
def test_add_texts_without_metadata(oceanbase_vector_search_tool, mock_obvec_client):
"""Test adding texts without metadata."""
texts = ["document 1", "document 2"]
result_ids = oceanbase_vector_search_tool.add_texts(texts)
assert len(result_ids) == 2
mock_obvec_client.insert.assert_called_once()
def test_error_handling(oceanbase_vector_search_tool, mock_obvec_client):
"""Test error handling during search."""
mock_obvec_client.ann_search.side_effect = Exception("Database connection error")
result = json.loads(oceanbase_vector_search_tool._run(query="test"))
assert "error" in result
assert "Database connection error" in result["error"]
def test_config_defaults():
"""Test OceanBaseVectorSearchConfig default values."""
config = OceanBaseVectorSearchConfig()
assert config.limit == 4
assert config.distance_func == "l2"
assert config.distance_threshold is None
assert config.include_embeddings is False
def test_config_custom_values():
"""Test OceanBaseVectorSearchConfig with custom values."""
config = OceanBaseVectorSearchConfig(
limit=20,
distance_func="cosine",
distance_threshold=0.8,
include_embeddings=True,
)
assert config.limit == 20
assert config.distance_func == "cosine"
assert config.distance_threshold == 0.8
assert config.include_embeddings is True
def test_tool_schema():
"""Test OceanBaseToolSchema validation."""
from crewai_tools import OceanBaseToolSchema
schema = OceanBaseToolSchema(query="test query")
assert schema.query == "test query"
def test_tool_schema_requires_query():
"""Test that OceanBaseToolSchema requires a query."""
from crewai_tools import OceanBaseToolSchema
from pydantic import ValidationError
with pytest.raises(ValidationError):
OceanBaseToolSchema()
def test_distance_function_selection(oceanbase_vector_search_tool):
"""Test that the correct distance function is selected."""
oceanbase_vector_search_tool.query_config = OceanBaseVectorSearchConfig(
distance_func="l2"
)
func = oceanbase_vector_search_tool._get_distance_func()
assert func == mock_pyobvector.l2_distance
oceanbase_vector_search_tool.query_config = OceanBaseVectorSearchConfig(
distance_func="cosine"
)
func = oceanbase_vector_search_tool._get_distance_func()
assert func == mock_pyobvector.cosine_distance
oceanbase_vector_search_tool.query_config = OceanBaseVectorSearchConfig(
distance_func="inner_product"
)
func = oceanbase_vector_search_tool._get_distance_func()
assert func == mock_pyobvector.inner_product