mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 16:48:30 +00:00
feat: replace embedchain with native crewai adapter (#451)
- Remove embedchain adapter; add crewai rag adapter and update all search tools - Add loaders: pdf, youtube (video & channel), github, docs site, mysql, postgresql - Add configurable similarity threshold, limit params, and embedding_model support - Improve chromadb compatibility (sanitize metadata, convert columns, fix chunking) - Fix xml encoding, Python 3.10 issues, and youtube url spoofing - Update crewai dependency and instructions; refresh uv.lock - Update tests for new rag adapter and search params
This commit is contained in:
@@ -1,43 +1,54 @@
|
||||
import os
|
||||
from tempfile import NamedTemporaryFile
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import cast
|
||||
from unittest import mock
|
||||
from pathlib import Path
|
||||
|
||||
from pytest import fixture
|
||||
|
||||
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
|
||||
from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
|
||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||
|
||||
|
||||
@fixture(autouse=True)
|
||||
def mock_embedchain_db_uri():
|
||||
with NamedTemporaryFile() as tmp:
|
||||
uri = f"sqlite:///{tmp.name}"
|
||||
with mock.patch.dict(os.environ, {"EMBEDCHAIN_DB_URI": uri}):
|
||||
yield
|
||||
|
||||
|
||||
def test_custom_llm_and_embedder():
|
||||
def test_rag_tool_initialization():
|
||||
"""Test that RagTool initializes with CrewAI adapter by default."""
|
||||
class MyTool(RagTool):
|
||||
pass
|
||||
|
||||
tool = MyTool(
|
||||
config=dict(
|
||||
llm=dict(
|
||||
provider="openai",
|
||||
config=dict(model="gpt-3.5-custom"),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="openai",
|
||||
config=dict(model="text-embedding-3-custom"),
|
||||
),
|
||||
)
|
||||
)
|
||||
tool = MyTool()
|
||||
assert tool.adapter is not None
|
||||
assert isinstance(tool.adapter, EmbedchainAdapter)
|
||||
assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||
|
||||
adapter = cast(CrewAIRagAdapter, tool.adapter)
|
||||
assert adapter.collection_name == "rag_tool_collection"
|
||||
assert adapter._client is not None
|
||||
|
||||
adapter = cast(EmbedchainAdapter, tool.adapter)
|
||||
assert adapter.embedchain_app.llm.config.model == "gpt-3.5-custom"
|
||||
assert (
|
||||
adapter.embedchain_app.embedding_model.config.model == "text-embedding-3-custom"
|
||||
)
|
||||
|
||||
def test_rag_tool_add_and_query():
|
||||
"""Test adding content and querying with RagTool."""
|
||||
class MyTool(RagTool):
|
||||
pass
|
||||
|
||||
tool = MyTool()
|
||||
|
||||
tool.add("The sky is blue on a clear day.")
|
||||
tool.add("Machine learning is a subset of artificial intelligence.")
|
||||
|
||||
result = tool._run(query="What color is the sky?")
|
||||
assert "Relevant Content:" in result
|
||||
|
||||
result = tool._run(query="Tell me about machine learning")
|
||||
assert "Relevant Content:" in result
|
||||
|
||||
|
||||
def test_rag_tool_with_file():
|
||||
"""Test RagTool with file content."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("Python is a programming language known for its simplicity.")
|
||||
|
||||
class MyTool(RagTool):
|
||||
pass
|
||||
|
||||
tool = MyTool()
|
||||
tool.add(str(test_file))
|
||||
|
||||
result = tool._run(query="What is Python?")
|
||||
assert "Relevant Content:" in result
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import ANY, MagicMock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from embedchain.models.data_type import DataType
|
||||
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
from crewai_tools.tools import (
|
||||
CodeDocsSearchTool,
|
||||
CSVSearchTool,
|
||||
@@ -49,7 +49,7 @@ def test_pdf_search_tool(mock_adapter):
|
||||
result = tool._run(query="test content")
|
||||
assert "this is a test" in result.lower()
|
||||
mock_adapter.add.assert_called_once_with("test.pdf", data_type=DataType.PDF_FILE)
|
||||
mock_adapter.query.assert_called_once_with("test content")
|
||||
mock_adapter.query.assert_called_once_with("test content", similarity_threshold=0.6, limit=5)
|
||||
|
||||
mock_adapter.query.reset_mock()
|
||||
mock_adapter.add.reset_mock()
|
||||
@@ -58,7 +58,7 @@ def test_pdf_search_tool(mock_adapter):
|
||||
result = tool._run(pdf="test.pdf", query="test content")
|
||||
assert "this is a test" in result.lower()
|
||||
mock_adapter.add.assert_called_once_with("test.pdf", data_type=DataType.PDF_FILE)
|
||||
mock_adapter.query.assert_called_once_with("test content")
|
||||
mock_adapter.query.assert_called_once_with("test content", similarity_threshold=0.6, limit=5)
|
||||
|
||||
|
||||
def test_txt_search_tool():
|
||||
@@ -82,7 +82,7 @@ def test_docx_search_tool(mock_adapter):
|
||||
result = tool._run(search_query="test content")
|
||||
assert "this is a test" in result.lower()
|
||||
mock_adapter.add.assert_called_once_with("test.docx", data_type=DataType.DOCX)
|
||||
mock_adapter.query.assert_called_once_with("test content")
|
||||
mock_adapter.query.assert_called_once_with("test content", similarity_threshold=0.6, limit=5)
|
||||
|
||||
mock_adapter.query.reset_mock()
|
||||
mock_adapter.add.reset_mock()
|
||||
@@ -91,7 +91,7 @@ def test_docx_search_tool(mock_adapter):
|
||||
result = tool._run(docx="test.docx", search_query="test content")
|
||||
assert "this is a test" in result.lower()
|
||||
mock_adapter.add.assert_called_once_with("test.docx", data_type=DataType.DOCX)
|
||||
mock_adapter.query.assert_called_once_with("test content")
|
||||
mock_adapter.query.assert_called_once_with("test content", similarity_threshold=0.6, limit=5)
|
||||
|
||||
|
||||
def test_json_search_tool():
|
||||
@@ -114,7 +114,7 @@ def test_xml_search_tool(mock_adapter):
|
||||
result = tool._run(search_query="test XML", xml="test.xml")
|
||||
assert "this is a test" in result.lower()
|
||||
mock_adapter.add.assert_called_once_with("test.xml")
|
||||
mock_adapter.query.assert_called_once_with("test XML")
|
||||
mock_adapter.query.assert_called_once_with("test XML", similarity_threshold=0.6, limit=5)
|
||||
|
||||
|
||||
def test_csv_search_tool():
|
||||
@@ -153,8 +153,8 @@ def test_website_search_tool(mock_adapter):
|
||||
tool = WebsiteSearchTool(website=website, adapter=mock_adapter)
|
||||
result = tool._run(search_query=search_query)
|
||||
|
||||
mock_adapter.query.assert_called_once_with("what is crewai?")
|
||||
mock_adapter.add.assert_called_once_with(website, data_type=DataType.WEB_PAGE)
|
||||
mock_adapter.query.assert_called_once_with("what is crewai?", similarity_threshold=0.6, limit=5)
|
||||
mock_adapter.add.assert_called_once_with(website, data_type=DataType.WEBSITE)
|
||||
|
||||
assert "this is a test" in result.lower()
|
||||
|
||||
@@ -164,8 +164,8 @@ def test_website_search_tool(mock_adapter):
|
||||
tool = WebsiteSearchTool(adapter=mock_adapter)
|
||||
result = tool._run(website=website, search_query=search_query)
|
||||
|
||||
mock_adapter.query.assert_called_once_with("what is crewai?")
|
||||
mock_adapter.add.assert_called_once_with(website, data_type=DataType.WEB_PAGE)
|
||||
mock_adapter.query.assert_called_once_with("what is crewai?", similarity_threshold=0.6, limit=5)
|
||||
mock_adapter.add.assert_called_once_with(website, data_type=DataType.WEBSITE)
|
||||
|
||||
assert "this is a test" in result.lower()
|
||||
|
||||
@@ -185,7 +185,7 @@ def test_youtube_video_search_tool(mock_adapter):
|
||||
mock_adapter.add.assert_called_once_with(
|
||||
youtube_video_url, data_type=DataType.YOUTUBE_VIDEO
|
||||
)
|
||||
mock_adapter.query.assert_called_once_with(search_query)
|
||||
mock_adapter.query.assert_called_once_with(search_query, similarity_threshold=0.6, limit=5)
|
||||
|
||||
mock_adapter.query.reset_mock()
|
||||
mock_adapter.add.reset_mock()
|
||||
@@ -197,7 +197,7 @@ def test_youtube_video_search_tool(mock_adapter):
|
||||
mock_adapter.add.assert_called_once_with(
|
||||
youtube_video_url, data_type=DataType.YOUTUBE_VIDEO
|
||||
)
|
||||
mock_adapter.query.assert_called_once_with(search_query)
|
||||
mock_adapter.query.assert_called_once_with(search_query, similarity_threshold=0.6, limit=5)
|
||||
|
||||
|
||||
def test_youtube_channel_search_tool(mock_adapter):
|
||||
@@ -213,7 +213,7 @@ def test_youtube_channel_search_tool(mock_adapter):
|
||||
mock_adapter.add.assert_called_once_with(
|
||||
youtube_channel_handle, data_type=DataType.YOUTUBE_CHANNEL
|
||||
)
|
||||
mock_adapter.query.assert_called_once_with(search_query)
|
||||
mock_adapter.query.assert_called_once_with(search_query, similarity_threshold=0.6, limit=5)
|
||||
|
||||
mock_adapter.query.reset_mock()
|
||||
mock_adapter.add.reset_mock()
|
||||
@@ -227,7 +227,7 @@ def test_youtube_channel_search_tool(mock_adapter):
|
||||
mock_adapter.add.assert_called_once_with(
|
||||
youtube_channel_handle, data_type=DataType.YOUTUBE_CHANNEL
|
||||
)
|
||||
mock_adapter.query.assert_called_once_with(search_query)
|
||||
mock_adapter.query.assert_called_once_with(search_query, similarity_threshold=0.6, limit=5)
|
||||
|
||||
|
||||
def test_code_docs_search_tool(mock_adapter):
|
||||
@@ -239,7 +239,7 @@ def test_code_docs_search_tool(mock_adapter):
|
||||
result = tool._run(search_query=search_query)
|
||||
assert "test documentation" in result
|
||||
mock_adapter.add.assert_called_once_with(docs_url, data_type=DataType.DOCS_SITE)
|
||||
mock_adapter.query.assert_called_once_with(search_query)
|
||||
mock_adapter.query.assert_called_once_with(search_query, similarity_threshold=0.6, limit=5)
|
||||
|
||||
mock_adapter.query.reset_mock()
|
||||
mock_adapter.add.reset_mock()
|
||||
@@ -248,7 +248,7 @@ def test_code_docs_search_tool(mock_adapter):
|
||||
result = tool._run(docs_url=docs_url, search_query=search_query)
|
||||
assert "test documentation" in result
|
||||
mock_adapter.add.assert_called_once_with(docs_url, data_type=DataType.DOCS_SITE)
|
||||
mock_adapter.query.assert_called_once_with(search_query)
|
||||
mock_adapter.query.assert_called_once_with(search_query, similarity_threshold=0.6, limit=5)
|
||||
|
||||
|
||||
def test_github_search_tool(mock_adapter):
|
||||
@@ -264,9 +264,11 @@ def test_github_search_tool(mock_adapter):
|
||||
result = tool._run(search_query="tell me about crewai repo")
|
||||
assert "repo description" in result
|
||||
mock_adapter.add.assert_called_once_with(
|
||||
"repo:crewai/crewai type:code", data_type="github", loader=ANY
|
||||
"https://github.com/crewai/crewai",
|
||||
data_type=DataType.GITHUB,
|
||||
metadata={"content_types": ["code"], "gh_token": "test_token"}
|
||||
)
|
||||
mock_adapter.query.assert_called_once_with("tell me about crewai repo")
|
||||
mock_adapter.query.assert_called_once_with("tell me about crewai repo", similarity_threshold=0.6, limit=5)
|
||||
|
||||
# ensure content types provided by run call is used
|
||||
mock_adapter.query.reset_mock()
|
||||
@@ -280,9 +282,11 @@ def test_github_search_tool(mock_adapter):
|
||||
)
|
||||
assert "repo description" in result
|
||||
mock_adapter.add.assert_called_once_with(
|
||||
"repo:crewai/crewai type:code,issue", data_type="github", loader=ANY
|
||||
"https://github.com/crewai/crewai",
|
||||
data_type=DataType.GITHUB,
|
||||
metadata={"content_types": ["code", "issue"], "gh_token": "test_token"}
|
||||
)
|
||||
mock_adapter.query.assert_called_once_with("tell me about crewai repo")
|
||||
mock_adapter.query.assert_called_once_with("tell me about crewai repo", similarity_threshold=0.6, limit=5)
|
||||
|
||||
# ensure default content types are used if not provided
|
||||
mock_adapter.query.reset_mock()
|
||||
@@ -295,9 +299,11 @@ def test_github_search_tool(mock_adapter):
|
||||
)
|
||||
assert "repo description" in result
|
||||
mock_adapter.add.assert_called_once_with(
|
||||
"repo:crewai/crewai type:code,repo,pr,issue", data_type="github", loader=ANY
|
||||
"https://github.com/crewai/crewai",
|
||||
data_type=DataType.GITHUB,
|
||||
metadata={"content_types": ["code", "repo", "pr", "issue"], "gh_token": "test_token"}
|
||||
)
|
||||
mock_adapter.query.assert_called_once_with("tell me about crewai repo")
|
||||
mock_adapter.query.assert_called_once_with("tell me about crewai repo", similarity_threshold=0.6, limit=5)
|
||||
|
||||
# ensure nothing is added if no repo is provided
|
||||
mock_adapter.query.reset_mock()
|
||||
@@ -306,4 +312,4 @@ def test_github_search_tool(mock_adapter):
|
||||
tool = GithubSearchTool(gh_token="test_token", adapter=mock_adapter)
|
||||
result = tool._run(search_query="tell me about crewai repo")
|
||||
mock_adapter.add.assert_not_called()
|
||||
mock_adapter.query.assert_called_once_with("tell me about crewai repo")
|
||||
mock_adapter.query.assert_called_once_with("tell me about crewai repo", similarity_threshold=0.6, limit=5)
|
||||
|
||||
Reference in New Issue
Block a user