mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-02-11 08:28:18 +00:00
Compare commits
3 Commits
codex/code
...
devin/1770
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c2e1fa0231 | ||
|
|
5ecbfa26cb | ||
|
|
87cf7b9234 |
@@ -98,6 +98,11 @@ from crewai_tools.tools.mongodb_vector_search_tool.vector_search import (
|
||||
MongoDBVectorSearchTool,
|
||||
)
|
||||
from crewai_tools.tools.multion_tool.multion_tool import MultiOnTool
|
||||
from crewai_tools.tools.oceanbase_vector_search_tool.oceanbase_vector_search_tool import (
|
||||
OceanBaseToolSchema,
|
||||
OceanBaseVectorSearchConfig,
|
||||
OceanBaseVectorSearchTool,
|
||||
)
|
||||
from crewai_tools.tools.mysql_search_tool.mysql_search_tool import MySQLSearchTool
|
||||
from crewai_tools.tools.nl2sql.nl2sql_tool import NL2SQLTool
|
||||
from crewai_tools.tools.ocr_tool.ocr_tool import OCRTool
|
||||
@@ -243,6 +248,9 @@ __all__ = [
|
||||
"MongoDBVectorSearchTool",
|
||||
"MultiOnTool",
|
||||
"MySQLSearchTool",
|
||||
"OceanBaseToolSchema",
|
||||
"OceanBaseVectorSearchConfig",
|
||||
"OceanBaseVectorSearchTool",
|
||||
"NL2SQLTool",
|
||||
"OCRTool",
|
||||
"OxylabsAmazonProductScraperTool",
|
||||
|
||||
@@ -87,6 +87,11 @@ from crewai_tools.tools.mongodb_vector_search_tool import (
|
||||
MongoDBVectorSearchConfig,
|
||||
MongoDBVectorSearchTool,
|
||||
)
|
||||
from crewai_tools.tools.oceanbase_vector_search_tool import (
|
||||
OceanBaseToolSchema,
|
||||
OceanBaseVectorSearchConfig,
|
||||
OceanBaseVectorSearchTool,
|
||||
)
|
||||
from crewai_tools.tools.multion_tool.multion_tool import MultiOnTool
|
||||
from crewai_tools.tools.mysql_search_tool.mysql_search_tool import MySQLSearchTool
|
||||
from crewai_tools.tools.nl2sql.nl2sql_tool import NL2SQLTool
|
||||
@@ -226,6 +231,9 @@ __all__ = [
|
||||
"MongoDBVectorSearchConfig",
|
||||
"MongoDBVectorSearchTool",
|
||||
"MultiOnTool",
|
||||
"OceanBaseToolSchema",
|
||||
"OceanBaseVectorSearchConfig",
|
||||
"OceanBaseVectorSearchTool",
|
||||
"MySQLSearchTool",
|
||||
"NL2SQLTool",
|
||||
"OCRTool",
|
||||
|
||||
@@ -0,0 +1,144 @@
|
||||
# OceanBaseVectorSearchTool
|
||||
|
||||
## Description
|
||||
|
||||
This tool performs vector similarity searches within an OceanBase database. OceanBase is a distributed relational database developed by Ant Group that supports native vector indexing and search using the HNSW (Hierarchical Navigable Small World) algorithm.
|
||||
|
||||
Use this tool to find semantically similar documents to a given query by leveraging OceanBase's vector search functionality.
|
||||
|
||||
For more information about OceanBase vector capabilities, see:
|
||||
https://en.oceanbase.com/docs/common-oceanbase-database-10000000001976351
|
||||
|
||||
## Installation
|
||||
|
||||
Install the crewai_tools package with OceanBase support by executing the following command in your terminal:
|
||||
|
||||
```shell
|
||||
pip install crewai-tools[oceanbase]
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```shell
|
||||
uv add crewai-tools --extra oceanbase
|
||||
```
|
||||
|
||||
## Example
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```python
|
||||
from crewai_tools import OceanBaseVectorSearchTool
|
||||
|
||||
tool = OceanBaseVectorSearchTool(
|
||||
connection_uri="127.0.0.1:2881",
|
||||
user="root@test",
|
||||
password="",
|
||||
db_name="test",
|
||||
table_name="documents",
|
||||
)
|
||||
```
|
||||
|
||||
### With Custom Configuration
|
||||
|
||||
```python
|
||||
from crewai_tools import OceanBaseVectorSearchConfig, OceanBaseVectorSearchTool
|
||||
|
||||
query_config = OceanBaseVectorSearchConfig(
|
||||
limit=10,
|
||||
distance_func="cosine",
|
||||
distance_threshold=0.5,
|
||||
)
|
||||
|
||||
tool = OceanBaseVectorSearchTool(
|
||||
connection_uri="127.0.0.1:2881",
|
||||
user="root@test",
|
||||
password="your_password",
|
||||
db_name="my_database",
|
||||
table_name="my_documents",
|
||||
vector_column_name="embedding",
|
||||
text_column_name="content",
|
||||
metadata_column_name="metadata",
|
||||
query_config=query_config,
|
||||
embedding_model="text-embedding-3-large",
|
||||
dimensions=3072,
|
||||
)
|
||||
```
|
||||
|
||||
### Adding the Tool to an Agent
|
||||
|
||||
```python
|
||||
from crewai import Agent
|
||||
from crewai_tools import OceanBaseVectorSearchTool
|
||||
|
||||
tool = OceanBaseVectorSearchTool(
|
||||
connection_uri="127.0.0.1:2881",
|
||||
user="root@test",
|
||||
db_name="test",
|
||||
table_name="documents",
|
||||
)
|
||||
|
||||
rag_agent = Agent(
|
||||
name="rag_agent",
|
||||
role="You are a helpful assistant that can answer questions using the OceanBaseVectorSearchTool.",
|
||||
goal="Answer user questions by searching relevant documents",
|
||||
backstory="You have access to a knowledge base stored in OceanBase",
|
||||
llm="gpt-4o-mini",
|
||||
tools=[tool],
|
||||
)
|
||||
```
|
||||
|
||||
### Preloading Documents
|
||||
|
||||
```python
|
||||
from crewai_tools import OceanBaseVectorSearchTool
|
||||
import os
|
||||
|
||||
tool = OceanBaseVectorSearchTool(
|
||||
connection_uri="127.0.0.1:2881",
|
||||
user="root@test",
|
||||
db_name="test",
|
||||
table_name="documents",
|
||||
)
|
||||
|
||||
texts = []
|
||||
metadatas = []
|
||||
for filename in os.listdir("knowledge"):
|
||||
with open(os.path.join("knowledge", filename), "r") as f:
|
||||
texts.append(f.read())
|
||||
metadatas.append({"source": filename})
|
||||
|
||||
tool.add_texts(texts, metadatas=metadatas)
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
### OceanBaseVectorSearchConfig
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `limit` | int | 4 | Number of documents to return |
|
||||
| `distance_func` | str | "l2" | Distance function: "l2", "cosine", or "inner_product" |
|
||||
| `distance_threshold` | float | None | Only return results with distance <= threshold |
|
||||
| `include_embeddings` | bool | False | Whether to include embedding vectors in results |
|
||||
|
||||
### OceanBaseVectorSearchTool
|
||||
|
||||
| Parameter | Type | Required | Description |
|
||||
|-----------|------|----------|-------------|
|
||||
| `connection_uri` | str | Yes | OceanBase connection URI (e.g., "127.0.0.1:2881") |
|
||||
| `user` | str | Yes | Username for connection (e.g., "root@test") |
|
||||
| `password` | str | No | Password for connection |
|
||||
| `db_name` | str | No | Database name (default: "test") |
|
||||
| `table_name` | str | Yes | Table containing vector data |
|
||||
| `vector_column_name` | str | No | Column with embeddings (default: "embedding") |
|
||||
| `text_column_name` | str | No | Column with text content (default: "text") |
|
||||
| `metadata_column_name` | str | No | Column with metadata (default: "metadata") |
|
||||
| `embedding_model` | str | No | OpenAI model for embeddings (default: "text-embedding-3-large") |
|
||||
| `dimensions` | int | No | Embedding dimensions (default: 1536) |
|
||||
| `query_config` | OceanBaseVectorSearchConfig | No | Search configuration |
|
||||
|
||||
## Environment Variables
|
||||
|
||||
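- `OPENAI_API_KEY`: Required for generating embeddings with the standard OpenAI client. The tool only checks for it when `AZURE_OPENAI_ENDPOINT` is not set.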
|
||||
- `AZURE_OPENAI_ENDPOINT`: Optional. When set, the tool creates an `AzureOpenAI` client for embeddings instead of the standard OpenAI client.
|
||||
@@ -0,0 +1,12 @@
|
||||
from crewai_tools.tools.oceanbase_vector_search_tool.oceanbase_vector_search_tool import (
|
||||
OceanBaseToolSchema,
|
||||
OceanBaseVectorSearchConfig,
|
||||
OceanBaseVectorSearchTool,
|
||||
)
|
||||
|
||||
|
||||
# Public API re-exported by the oceanbase_vector_search_tool package.
__all__ = ["OceanBaseToolSchema", "OceanBaseVectorSearchConfig", "OceanBaseVectorSearchTool"]
|
||||
@@ -0,0 +1,265 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from logging import getLogger
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from crewai.tools import BaseTool, EnvVar
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# Probe for the optional ``pyobvector`` dependency once at import time; the
# resulting flag is checked later when the tool is constructed.
try:
    import pyobvector  # noqa: F401
except ImportError:
    PYOBVECTOR_AVAILABLE = False
else:
    PYOBVECTOR_AVAILABLE = True

# Module-level logger shared by everything in this file.
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class OceanBaseToolSchema(BaseModel):
    """Arguments accepted by the OceanBase vector search tool."""

    # Required free-text query; no default is supplied, so validation fails without it.
    query: str = Field(
        description="The query to search for relevant information in the OceanBase database.",
    )
|
||||
|
||||
|
||||
class OceanBaseVectorSearchConfig(BaseModel):
    """Per-query settings consumed by the OceanBase vector search tool."""

    # Maximum number of rows requested from the search (forwarded as ``topk``).
    limit: int = Field(4, description="Number of documents to return.")
    # Optional distance cut-off forwarded to the search call; ``None`` disables it.
    distance_threshold: float | None = Field(
        None,
        description="Only return results where distance is less than or equal to this threshold.",
    )
    # Name of the distance metric; resolved to a pyobvector function by the tool.
    distance_func: str = Field(
        "l2",
        description="Distance function to use for similarity search. Options: 'l2', 'cosine', 'inner_product'.",
    )
    # Flag requesting that each result also carries its embedding vector.
    include_embeddings: bool = Field(
        False,
        description="Whether to include the embedding vector of each result.",
    )
|
||||
|
||||
|
||||
class OceanBaseVectorSearchTool(BaseTool):
    """Tool to perform vector search on OceanBase database.

    A query string is embedded with an OpenAI (or Azure OpenAI) embedding model
    and the resulting vector is handed to ``ObVecClient.ann_search``. Matches
    are returned to the caller as a JSON string.
    """

    name: str = "OceanBaseVectorSearchTool"
    description: str = (
        "A tool to perform vector similarity search on an OceanBase database "
        "for retrieving relevant information from stored documents."
    )

    args_schema: type[BaseModel] = OceanBaseToolSchema
    query_config: OceanBaseVectorSearchConfig | None = Field(
        default=None,
        description="OceanBase vector search query configuration.",
    )
    embedding_model: str = Field(
        default="text-embedding-3-large",
        description="OpenAI embedding model to use for generating query embeddings.",
    )
    dimensions: int = Field(
        default=1536,
        description="Number of dimensions in the embedding vector.",
    )
    connection_uri: str = Field(
        ...,
        description="Connection URI for OceanBase (e.g., '127.0.0.1:2881').",
    )
    user: str = Field(
        ...,
        description="Username for OceanBase connection (e.g., 'root@test').",
    )
    password: str = Field(
        default="",
        description="Password for OceanBase connection.",
    )
    db_name: str = Field(
        default="test",
        description="Database name in OceanBase.",
    )
    table_name: str = Field(
        ...,
        description="Name of the table containing vector data.",
    )
    vector_column_name: str = Field(
        default="embedding",
        description="Name of the column containing vector embeddings.",
    )
    text_column_name: str = Field(
        default="text",
        description="Name of the column containing text content.",
    )
    metadata_column_name: str | None = Field(
        default="metadata",
        description="Name of the column containing metadata (optional).",
    )
    env_vars: list[EnvVar] = Field(
        default_factory=lambda: [
            EnvVar(
                name="OPENAI_API_KEY",
                description="API key for OpenAI embeddings",
                required=True,
            ),
        ]
    )
    package_dependencies: list[str] = Field(default_factory=lambda: ["pyobvector"])

    # Lazily populated clients; both are created in ``__init__``.
    _client: Any = None
    _openai_client: Any = None

    def __init__(self, **kwargs: Any) -> None:
        """Validate the fields, then create the embedding and OceanBase clients.

        Raises:
            ImportError: If ``pyobvector`` is missing and the user declines to install it.
            ValueError: If neither ``AZURE_OPENAI_ENDPOINT`` nor ``OPENAI_API_KEY`` is set.
        """
        super().__init__(**kwargs)
        if not PYOBVECTOR_AVAILABLE:
            import click

            if click.confirm(
                "You are missing the 'pyobvector' package. Would you like to install it?"
            ):
                import subprocess

                subprocess.run(["uv", "add", "pyobvector"], check=True)  # noqa: S607
            else:
                raise ImportError(
                    "The 'pyobvector' package is required for OceanBaseVectorSearchTool."
                )

        # The Azure endpoint takes precedence over a plain OpenAI API key.
        if "AZURE_OPENAI_ENDPOINT" in os.environ:
            from openai import AzureOpenAI

            self._openai_client = AzureOpenAI()
        elif "OPENAI_API_KEY" in os.environ:
            from openai import Client

            self._openai_client = Client()
        else:
            raise ValueError(
                "OPENAI_API_KEY environment variable is required for OceanBaseVectorSearchTool."
            )

        from pyobvector import ObVecClient

        self._client = ObVecClient(
            uri=self.connection_uri,
            user=self.user,
            password=self.password,
            db_name=self.db_name,
        )

    def _embed_text(self, text: str) -> list[float]:
        """Generate embedding for the given text using OpenAI."""
        response = self._openai_client.embeddings.create(
            input=[text],
            model=self.embedding_model,
            dimensions=self.dimensions,
        )
        return response.data[0].embedding

    def _get_distance_func(self) -> Any:
        """Return the pyobvector distance function named by the active config.

        An unrecognised ``distance_func`` still falls back to l2 (the historical
        behaviour), but a warning is now logged so the misconfiguration is visible.
        """
        from pyobvector import cosine_distance, inner_product, l2_distance

        config = self.query_config or OceanBaseVectorSearchConfig()
        distance_funcs = {
            "l2": l2_distance,
            "cosine": cosine_distance,
            "inner_product": inner_product,
        }
        try:
            return distance_funcs[config.distance_func]
        except KeyError:
            logger.warning(
                "Unsupported distance_func %r; falling back to 'l2'. Supported values: %s",
                config.distance_func,
                ", ".join(sorted(distance_funcs)),
            )
            return l2_distance

    def _format_row(
        self,
        row: Any,
        output_columns: list[str],
        include_embeddings: bool,
    ) -> dict[str, Any]:
        """Convert one positional search row into a result dictionary.

        The row is expected to carry the requested output columns in order,
        followed by the distance as the last element.
        """
        result: dict[str, Any] = {}
        if len(row) >= 1:
            result["text"] = row[0]
        if self.metadata_column_name and len(row) >= 2:
            result["metadata"] = row[1]
        if include_embeddings:
            # The vector column is always requested last, so it sits at the final
            # output-column position.
            embedding_index = len(output_columns) - 1
            if len(row) > embedding_index:
                result["embedding"] = row[embedding_index]
        # Any element beyond the requested columns is the appended distance.
        if len(row) > len(output_columns):
            result["distance"] = row[-1]
        return result

    def _run(self, query: str) -> str:
        """Execute vector search on OceanBase.

        Args:
            query: Free-text query to embed and search for.

        Returns:
            A JSON array of result objects (``text``, optional ``metadata``,
            optional ``embedding``, ``distance``), or a JSON object with an
            ``error`` key if anything fails.
        """
        try:
            config = self.query_config or OceanBaseVectorSearchConfig()

            query_vector = self._embed_text(query)

            output_columns = [self.text_column_name]
            if self.metadata_column_name:
                output_columns.append(self.metadata_column_name)
            if config.include_embeddings:
                # Honour the config flag, which was previously ignored.
                output_columns.append(self.vector_column_name)

            results = self._client.ann_search(
                table_name=self.table_name,
                vec_data=query_vector,
                vec_column_name=self.vector_column_name,
                distance_func=self._get_distance_func(),
                with_dist=True,
                topk=config.limit,
                output_column_names=output_columns,
                distance_threshold=config.distance_threshold,
            )

            formatted_results = [
                self._format_row(row, output_columns, config.include_embeddings)
                for row in results
            ]

            return json.dumps(formatted_results, indent=2, default=str)

        except Exception as e:
            # Tool boundary: report the failure as JSON, but keep the traceback in the log.
            logger.exception("Error during OceanBase vector search: %s", e)
            return json.dumps({"error": str(e)})

    def add_texts(
        self,
        texts: list[str],
        metadatas: list[dict[str, Any]] | None = None,
        ids: list[str] | None = None,
    ) -> list[str]:
        """Add texts with embeddings to the OceanBase table.

        Args:
            texts: List of text strings to add.
            metadatas: Optional list of metadata dictionaries for each text.
            ids: Optional list of unique IDs for each text.

        Returns:
            List of IDs for the added texts.

        Raises:
            ValueError: If ``metadatas`` or ``ids`` is given with a length that
                differs from ``texts``.
        """
        import uuid

        if ids is None:
            ids = [str(uuid.uuid4()) for _ in texts]

        if metadatas is None:
            metadatas = [{} for _ in texts]

        # Mismatched lengths used to be truncated silently, returning IDs that
        # were never inserted; reject them instead.
        if not (len(texts) == len(metadatas) == len(ids)):
            raise ValueError(
                "texts, metadatas and ids must have the same length "
                f"(got {len(texts)}, {len(metadatas)} and {len(ids)})."
            )

        if not texts:
            # Nothing to embed or insert.
            return []

        data = []
        for text, metadata, doc_id in zip(texts, metadatas, ids, strict=True):
            embedding = self._embed_text(text)
            row = {
                "id": doc_id,
                self.text_column_name: text,
                self.vector_column_name: embedding,
            }
            if self.metadata_column_name:
                row[self.metadata_column_name] = metadata
            data.append(row)

        self._client.insert(self.table_name, data=data)
        return ids

    def __del__(self) -> None:
        """Cleanup clients on deletion."""
        try:
            if hasattr(self, "_openai_client") and self._openai_client:
                self._openai_client.close()
        except Exception as e:
            logger.error(f"Error closing OpenAI client: {e}")
|
||||
@@ -0,0 +1,208 @@
|
||||
import json
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai_tools import OceanBaseVectorSearchConfig
|
||||
|
||||
|
||||
# Register a stub ``pyobvector`` module so that imports of it resolve to this
# mock, with the client class and the three distance functions replaced.
mock_pyobvector = MagicMock(
    ObVecClient=MagicMock(),
    l2_distance=MagicMock(return_value="l2_func"),
    cosine_distance=MagicMock(return_value="cosine_func"),
    inner_product=MagicMock(return_value="ip_func"),
)
sys.modules["pyobvector"] = mock_pyobvector
|
||||
|
||||
|
||||
@pytest.fixture
def mock_openai_client():
    """Provide a stand-in OpenAI client whose embeddings call yields one 1536-float vector."""
    fake_vector = MagicMock(embedding=[0.1] * 1536)
    fake_response = MagicMock(data=[fake_vector])

    client = MagicMock()
    client.embeddings.create.return_value = fake_response
    return client
|
||||
|
||||
|
||||
@pytest.fixture
def mock_obvec_client():
    """Provide a stand-in for the OceanBase vector client."""
    return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
def oceanbase_vector_search_tool(mock_openai_client, mock_obvec_client):
    """Yield an OceanBaseVectorSearchTool wired to the mocked OpenAI and OceanBase clients."""
    from crewai_tools import OceanBaseVectorSearchTool

    availability_flag = (
        "crewai_tools.tools.oceanbase_vector_search_tool"
        ".oceanbase_vector_search_tool.PYOBVECTOR_AVAILABLE"
    )
    # The stubbed ObVecClient constructor hands back the mocked client.
    mock_pyobvector.ObVecClient.return_value = mock_obvec_client

    with (
        patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"}),
        patch(availability_flag, True),
        patch("openai.Client") as mock_openai_class,
    ):
        mock_openai_class.return_value = mock_openai_client
        tool = OceanBaseVectorSearchTool(
            connection_uri="127.0.0.1:2881",
            user="root@test",
            password="",
            db_name="test",
            table_name="test_table",
        )
        # Pin both private clients to the mocks regardless of what __init__ built.
        tool._openai_client = mock_openai_client
        tool._client = mock_obvec_client
        yield tool
|
||||
|
||||
|
||||
def test_successful_query_execution(oceanbase_vector_search_tool, mock_obvec_client):
    """A search returning two rows is serialized with text, metadata and distance."""
    rows = [
        ("test document content", {"source": "test.txt"}, 0.1),
        ("another document", {"source": "test2.txt"}, 0.2),
    ]
    mock_obvec_client.ann_search.return_value = rows

    raw_output = oceanbase_vector_search_tool._run(query="test query")
    results = json.loads(raw_output)

    assert len(results) == 2
    first = results[0]
    assert first["text"] == "test document content"
    assert first["metadata"] == {"source": "test.txt"}
    assert first["distance"] == 0.1
|
||||
|
||||
|
||||
def test_query_with_custom_config(mock_openai_client, mock_obvec_client):
    """Test vector search with custom configuration.

    Checks that every custom setting (limit, distance threshold and distance
    function) is forwarded to ``ann_search``.
    """
    from crewai_tools import OceanBaseVectorSearchTool

    query_config = OceanBaseVectorSearchConfig(
        limit=10,
        distance_func="cosine",
        distance_threshold=0.5,
    )

    with patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"}):
        with patch(
            "crewai_tools.tools.oceanbase_vector_search_tool.oceanbase_vector_search_tool.PYOBVECTOR_AVAILABLE",
            True,
        ):
            mock_pyobvector.ObVecClient.return_value = mock_obvec_client
            with patch("openai.Client") as mock_openai_class:
                mock_openai_class.return_value = mock_openai_client
                tool = OceanBaseVectorSearchTool(
                    connection_uri="127.0.0.1:2881",
                    user="root@test",
                    db_name="test",
                    table_name="test_table",
                    query_config=query_config,
                )
                tool._openai_client = mock_openai_client
                tool._client = mock_obvec_client

                mock_obvec_client.ann_search.return_value = [("doc", {}, 0.3)]

                tool._run(query="test")

                call_kwargs = mock_obvec_client.ann_search.call_args.kwargs
                assert call_kwargs["topk"] == 10
                assert call_kwargs["distance_threshold"] == 0.5
                # The configured metric was previously set but never verified.
                assert call_kwargs["distance_func"] == mock_pyobvector.cosine_distance
|
||||
|
||||
|
||||
def test_add_texts(oceanbase_vector_search_tool, mock_obvec_client):
    """Adding two texts with metadata issues a single insert of two rows."""
    result_ids = oceanbase_vector_search_tool.add_texts(
        ["document 1", "document 2"],
        metadatas=[{"source": "file1.txt"}, {"source": "file2.txt"}],
    )

    assert len(result_ids) == 2
    mock_obvec_client.insert.assert_called_once()

    insert_call = mock_obvec_client.insert.call_args
    assert insert_call.args[0] == "test_table"
    assert len(insert_call.kwargs["data"]) == 2
|
||||
|
||||
|
||||
def test_add_texts_without_metadata(oceanbase_vector_search_tool, mock_obvec_client):
    """Omitting metadata still returns one ID per text and inserts once."""
    result_ids = oceanbase_vector_search_tool.add_texts(["document 1", "document 2"])

    assert len(result_ids) == 2
    mock_obvec_client.insert.assert_called_once()
|
||||
|
||||
|
||||
def test_error_handling(oceanbase_vector_search_tool, mock_obvec_client):
    """A failing search is reported as a JSON object carrying the error message."""
    mock_obvec_client.ann_search.side_effect = Exception("Database connection error")

    raw_output = oceanbase_vector_search_tool._run(query="test")
    payload = json.loads(raw_output)

    assert "error" in payload
    assert "Database connection error" in payload["error"]
|
||||
|
||||
|
||||
def test_config_defaults():
    """A config built with no arguments exposes the documented defaults."""
    defaults = OceanBaseVectorSearchConfig()

    assert defaults.limit == 4
    assert defaults.distance_func == "l2"
    assert defaults.distance_threshold is None
    assert defaults.include_embeddings is False
|
||||
|
||||
|
||||
def test_config_custom_values():
    """Explicitly supplied config values are stored unchanged."""
    overrides = {
        "limit": 20,
        "distance_func": "cosine",
        "distance_threshold": 0.8,
        "include_embeddings": True,
    }
    config = OceanBaseVectorSearchConfig(**overrides)

    assert config.limit == 20
    assert config.distance_func == "cosine"
    assert config.distance_threshold == 0.8
    assert config.include_embeddings is True
|
||||
|
||||
|
||||
def test_tool_schema():
    """The input schema accepts and stores a query string."""
    from crewai_tools import OceanBaseToolSchema

    parsed = OceanBaseToolSchema(query="test query")

    assert parsed.query == "test query"
|
||||
|
||||
|
||||
def test_tool_schema_requires_query():
    """Constructing the input schema without a query fails validation."""
    from pydantic import ValidationError

    from crewai_tools import OceanBaseToolSchema

    with pytest.raises(ValidationError):
        OceanBaseToolSchema()
|
||||
|
||||
|
||||
def test_distance_function_selection(oceanbase_vector_search_tool):
    """Each supported distance name resolves to the matching pyobvector function."""
    expected_by_name = (
        ("l2", mock_pyobvector.l2_distance),
        ("cosine", mock_pyobvector.cosine_distance),
        ("inner_product", mock_pyobvector.inner_product),
    )

    for distance_name, expected_func in expected_by_name:
        oceanbase_vector_search_tool.query_config = OceanBaseVectorSearchConfig(
            distance_func=distance_name
        )
        func = oceanbase_vector_search_tool._get_distance_func()
        assert func == expected_func
|
||||
Reference in New Issue
Block a user