Compare commits

..

2 Commits

Author SHA1 Message Date
Devin AI
88d93cd65b Add comprehensive docstrings to LanceDB adapter
- Add Google-style docstrings to all public functions and classes in lancedb_adapter.py
- Include Args, Returns, Raises, and Example sections where appropriate
- Add test file to verify docstrings exist for all public API methods
- Addresses issue #3955

Co-Authored-By: João <joao@crewai.com>
2025-11-19 17:58:31 +00:00
Greyson LaLonde
d160f0874a chore: don't fail on cleanup error
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
Build uv cache / build-cache (3.10) (push) Has been cancelled
Build uv cache / build-cache (3.11) (push) Has been cancelled
Build uv cache / build-cache (3.12) (push) Has been cancelled
Build uv cache / build-cache (3.13) (push) Has been cancelled
2025-11-19 01:28:25 -05:00
8 changed files with 188 additions and 361 deletions

View File

@@ -14,6 +14,22 @@ from crewai_tools.tools.rag.rag_tool import Adapter
def _default_embedding_function():
"""Create a default embedding function using OpenAI's text-embedding-ada-002 model.
This function creates and returns an embedding function that uses OpenAI's API
to generate embeddings for text inputs. The embedding function is used by the
LanceDBAdapter to convert text queries into vector representations for similarity search.
Returns:
Callable: A function that takes a list of strings and returns their embeddings
as a list of vectors.
Example:
>>> embed_fn = _default_embedding_function()
>>> embeddings = embed_fn(["Hello world"])
>>> len(embeddings[0]) # Vector dimension
1536
"""
client = OpenAIClient()
def _embedding_function(input):
@@ -24,6 +40,32 @@ def _default_embedding_function():
class LanceDBAdapter(Adapter):
"""Adapter for integrating LanceDB vector database with CrewAI RAG tools.
LanceDBAdapter provides a bridge between CrewAI's RAG (Retrieval-Augmented Generation)
system and LanceDB, enabling efficient vector similarity search for knowledge retrieval.
It handles embedding generation, vector search, and data ingestion with precise control
over query parameters and column mappings.
Attributes:
uri: Database connection URI or path to the LanceDB database.
table_name: Name of the table to query within the LanceDB database.
embedding_function: Function to convert text into embeddings. Defaults to OpenAI's
text-embedding-ada-002 model.
top_k: Number of top results to return from similarity search. Defaults to 3.
vector_column_name: Name of the column containing vector embeddings. Defaults to "vector".
text_column_name: Name of the column containing text content. Defaults to "text".
Example:
>>> from crewai_tools.adapters.lancedb_adapter import LanceDBAdapter
>>> adapter = LanceDBAdapter(
... uri="./my_lancedb",
... table_name="documents",
... top_k=5
... )
>>> results = adapter.query("What is machine learning?")
>>> print(results)
"""
uri: str | Path
table_name: str
embedding_function: Callable = Field(default_factory=_default_embedding_function)
@@ -35,12 +77,44 @@ class LanceDBAdapter(Adapter):
_table: LanceDBTable = PrivateAttr()
def model_post_init(self, __context: Any) -> None:
    """Connect to LanceDB and open the configured table after model creation.

    The connection returned by ``lancedb_connect(self.uri)`` is kept on
    ``self._db`` and the table opened with ``self.table_name`` is kept on
    ``self._table``. The parent class's ``model_post_init`` is invoked last.

    Args:
        __context: Value forwarded unchanged to the parent implementation.

    Raises:
        Exception: Nothing is caught here, so any error raised while
            connecting or opening the table propagates to the caller.
    """
    connection = lancedb_connect(self.uri)
    self._db = connection
    self._table = connection.open_table(self.table_name)
    super().model_post_init(__context)
def query(self, question: str) -> str: # type: ignore[override]
"""Perform a vector similarity search for the given question.
This method converts the input question into an embedding vector and searches
the LanceDB table for the most similar entries. It returns the top-k results
based on vector similarity, providing precise retrieval for RAG applications.
Args:
question: The text query to search for in the vector database.
Returns:
A string containing the concatenated text results from the top-k most
similar entries, separated by newlines.
Example:
>>> adapter = LanceDBAdapter(uri="./db", table_name="docs")
>>> results = adapter.query("What is CrewAI?")
>>> print(results)
CrewAI is a framework for orchestrating AI agents...
CrewAI provides precise control over agent workflows...
"""
query = self.embedding_function([question])[0]
results = (
self._table.search(query, vector_column_name=self.vector_column_name)
@@ -56,4 +130,23 @@ class LanceDBAdapter(Adapter):
*args: Any,
**kwargs: Any,
) -> None:
"""Add data to the LanceDB table.
This method provides a direct interface to add new records to the underlying
LanceDB table. It accepts the same arguments as the LanceDB table's add method,
allowing flexible data ingestion for building knowledge bases.
Args:
*args: Positional arguments to pass to the LanceDB table's add method.
**kwargs: Keyword arguments to pass to the LanceDB table's add method.
Common kwargs include 'data' (list of records) and 'mode' (append/overwrite).
Example:
>>> adapter = LanceDBAdapter(uri="./db", table_name="docs")
>>> data = [
... {"text": "CrewAI enables agent collaboration", "vector": [0.1, 0.2, ...]},
... {"text": "LanceDB provides vector storage", "vector": [0.3, 0.4, ...]}
... ]
>>> adapter.add(data=data)
"""
self._table.add(*args, **kwargs)

View File

@@ -0,0 +1,62 @@
"""Test that LanceDB adapter has proper docstrings."""
import inspect
import pytest
lancedb = pytest.importorskip("lancedb")
from crewai_tools.adapters.lancedb_adapter import (
LanceDBAdapter,
_default_embedding_function,
)
def test_lancedb_adapter_class_has_docstring():
    """Check that the LanceDBAdapter class carries a non-blank docstring."""
    class_doc = LanceDBAdapter.__doc__
    assert class_doc is not None, "LanceDBAdapter class is missing a docstring"
    # A whitespace-only docstring strips to "" and is therefore falsy.
    assert class_doc.strip(), "LanceDBAdapter docstring is empty"
def test_lancedb_adapter_model_post_init_has_docstring():
    """Check that LanceDBAdapter.model_post_init carries a non-blank docstring."""
    hook_doc = LanceDBAdapter.model_post_init.__doc__
    assert hook_doc is not None, "model_post_init method is missing a docstring"
    # A whitespace-only docstring strips to "" and is therefore falsy.
    assert hook_doc.strip(), "model_post_init docstring is empty"
def test_lancedb_adapter_query_has_docstring():
    """Check that LanceDBAdapter.query carries a non-blank docstring."""
    query_doc = LanceDBAdapter.query.__doc__
    assert query_doc is not None, "query method is missing a docstring"
    # A whitespace-only docstring strips to "" and is therefore falsy.
    assert query_doc.strip(), "query docstring is empty"
def test_lancedb_adapter_add_has_docstring():
    """Check that LanceDBAdapter.add carries a non-blank docstring."""
    add_doc = LanceDBAdapter.add.__doc__
    assert add_doc is not None, "add method is missing a docstring"
    # A whitespace-only docstring strips to "" and is therefore falsy.
    assert add_doc.strip(), "add docstring is empty"
def test_default_embedding_function_has_docstring():
    """Check that _default_embedding_function carries a non-blank docstring."""
    factory_doc = _default_embedding_function.__doc__
    assert factory_doc is not None, "_default_embedding_function is missing a docstring"
    # A whitespace-only docstring strips to "" and is therefore falsy.
    assert factory_doc.strip(), "_default_embedding_function docstring is empty"
def test_docstrings_contain_required_sections():
    """Check the section headers present in the query and add docstrings.

    ``query`` must have an ``Args:`` or ``Parameters:`` header and a
    ``Returns:`` header; ``add`` must have an ``Args:`` or ``Parameters:``
    header.
    """
    argument_headers = ("Args:", "Parameters:")

    query_doc = LanceDBAdapter.query.__doc__
    assert query_doc is not None
    assert any(
        header in query_doc for header in argument_headers
    ), "query docstring should have Args/Parameters section"
    assert "Returns:" in query_doc, "query docstring should have Returns section"

    add_doc = LanceDBAdapter.add.__doc__
    assert add_doc is not None
    assert any(
        header in add_doc for header in argument_headers
    ), "add docstring should have Args/Parameters section"

View File

@@ -1416,43 +1416,6 @@ class Agent(BaseAgent):
)
return None
def _build_runtime_tools(self) -> list[BaseTool]:
    """Assemble the tool list for one execution, leaving ``self.tools`` untouched.

    Starts from a copy of ``self.tools`` and appends, in this order:

    - the result of ``get_platform_tools`` when ``self.apps`` is truthy,
    - the result of ``get_mcp_tools`` when ``self.mcps`` is truthy,
    - the result of ``get_multimodal_tools`` when ``self.multimodal`` is truthy.

    Returns:
        The collected tools with duplicates removed by ``name``; for each
        name only the first tool encountered is kept, in encounter order.
    """
    # Copy so the additions below never leak back into the agent's own list.
    candidates: list[BaseTool] = list(self.tools or [])

    if self.apps:
        from_platform = self.get_platform_tools(self.apps)
        if from_platform:
            candidates += from_platform

    if self.mcps:
        from_mcp = self.get_mcp_tools(self.mcps)
        if from_mcp:
            candidates += from_mcp

    if self.multimodal:
        candidates.extend(self.get_multimodal_tools())

    # dict preserves insertion order and setdefault keeps the first tool per name.
    first_by_name: dict[str, BaseTool] = {}
    for candidate in candidates:
        first_by_name.setdefault(candidate.name, candidate)
    return list(first_by_name.values())
def kickoff(
self,
messages: str | list[LLMMessage],
@@ -1473,7 +1436,14 @@ class Agent(BaseAgent):
Returns:
LiteAgentOutput: The result of the agent execution.
"""
runtime_tools = self._build_runtime_tools()
if self.apps:
platform_tools = self.get_platform_tools(self.apps)
if platform_tools:
self.tools.extend(platform_tools)
if self.mcps:
mcps = self.get_mcp_tools(self.mcps)
if mcps:
self.tools.extend(mcps)
lite_agent = LiteAgent(
id=self.id,
@@ -1481,7 +1451,7 @@ class Agent(BaseAgent):
goal=self.goal,
backstory=self.backstory,
llm=self.llm,
tools=runtime_tools,
tools=self.tools or [],
max_iterations=self.max_iter,
max_execution_time=self.max_execution_time,
respect_context_window=self.respect_context_window,
@@ -1514,15 +1484,12 @@ class Agent(BaseAgent):
Returns:
LiteAgentOutput: The result of the agent execution.
"""
runtime_tools = self._build_runtime_tools()
lite_agent = LiteAgent(
id=self.id,
role=self.role,
goal=self.goal,
backstory=self.backstory,
llm=self.llm,
tools=runtime_tools,
tools=self.tools or [],
max_iterations=self.max_iter,
max_execution_time=self.max_execution_time,
respect_context_window=self.respect_context_window,

View File

@@ -1,299 +0,0 @@
"""Test Agent multimodal kickoff functionality."""
from unittest.mock import MagicMock, patch
import pytest
from crewai import Agent
from crewai.lite_agent import LiteAgent
from crewai.lite_agent_output import LiteAgentOutput
from crewai.tools.agent_tools.add_image_tool import AddImageTool
from crewai.tools.base_tool import BaseTool
@pytest.fixture
def mock_lite_agent():
    """Patch ``crewai.agent.core.LiteAgent`` so kickoff tests make no LLM calls.

    Yields:
        The patched class mock. Calling it returns a ``MagicMock`` spec'd on
        ``LiteAgent`` whose ``kickoff`` and ``kickoff_async`` both return the
        same canned ``LiteAgentOutput``.
    """
    with patch("crewai.agent.core.LiteAgent") as patched_class:
        canned_output = LiteAgentOutput(
            raw="Test output",
            pydantic=None,
            agent_role="test role",
            usage_metrics=None,
            messages=[],
        )

        stub_instance = MagicMock(spec=LiteAgent)
        stub_instance.kickoff.return_value = canned_output
        stub_instance.kickoff_async.return_value = canned_output

        patched_class.return_value = stub_instance
        yield patched_class
def test_agent_kickoff_with_multimodal_true_adds_image_tool(mock_lite_agent):
    """kickoff on a multimodal agent passes an AddImageTool to LiteAgent."""
    multimodal_agent = Agent(
        role="test role",
        goal="test goal",
        backstory="test backstory",
        multimodal=True,
    )

    multimodal_agent.kickoff("Test message")

    mock_lite_agent.assert_called_once()
    passed_tools = mock_lite_agent.call_args.kwargs["tools"]
    assert any(isinstance(candidate, AddImageTool) for candidate in passed_tools)
def test_agent_kickoff_with_multimodal_false_does_not_add_image_tool(mock_lite_agent):
    """kickoff on a non-multimodal agent passes no AddImageTool to LiteAgent."""
    plain_agent = Agent(
        role="test role",
        goal="test goal",
        backstory="test backstory",
        multimodal=False,
    )

    plain_agent.kickoff("Test message")

    mock_lite_agent.assert_called_once()
    passed_tools = mock_lite_agent.call_args.kwargs["tools"]
    image_tools = [
        candidate for candidate in passed_tools if isinstance(candidate, AddImageTool)
    ]
    assert not image_tools
def test_agent_kickoff_does_not_mutate_self_tools(mock_lite_agent):
    """kickoff leaves ``agent.tools`` with the same length and contents."""

    class DummyTool(BaseTool):
        name: str = "dummy_tool"
        description: str = "A dummy tool"

        def _run(self, **kwargs):
            return "dummy result"

    agent = Agent(
        role="test role",
        goal="test goal",
        backstory="test backstory",
        tools=[DummyTool()],
        multimodal=True,
    )

    # Snapshot the state before kickoff so it can be compared afterwards.
    count_before = len(agent.tools)
    tools_before = list(agent.tools)

    agent.kickoff("Test message")

    assert len(agent.tools) == count_before
    assert agent.tools == tools_before
def test_agent_kickoff_multiple_calls_does_not_duplicate_tools(mock_lite_agent):
    """Two consecutive kickoff calls each pass exactly one AddImageTool."""

    def image_tools_from_latest_call():
        # Reads the tools kwarg of the most recent LiteAgent construction.
        latest_tools = mock_lite_agent.call_args.kwargs["tools"]
        return [tool for tool in latest_tools if isinstance(tool, AddImageTool)]

    agent = Agent(
        role="test role",
        goal="test goal",
        backstory="test backstory",
        multimodal=True,
    )

    # Capture after each call, before the next kickoff replaces call_args.
    agent.kickoff("Test message 1")
    image_tools_first = image_tools_from_latest_call()

    agent.kickoff("Test message 2")
    image_tools_second = image_tools_from_latest_call()

    assert len(image_tools_first) == 1
    assert len(image_tools_second) == 1
def test_agent_kickoff_async_with_multimodal_true_adds_image_tool(mock_lite_agent):
    """kickoff_async on a multimodal agent passes an AddImageTool to LiteAgent."""
    import asyncio

    multimodal_agent = Agent(
        role="test role",
        goal="test goal",
        backstory="test backstory",
        multimodal=True,
    )

    asyncio.run(multimodal_agent.kickoff_async("Test message"))

    mock_lite_agent.assert_called_once()
    passed_tools = mock_lite_agent.call_args.kwargs["tools"]
    assert any(isinstance(candidate, AddImageTool) for candidate in passed_tools)
def test_agent_kickoff_async_with_multimodal_false_does_not_add_image_tool(
    mock_lite_agent,
):
    """kickoff_async on a non-multimodal agent passes no AddImageTool to LiteAgent."""
    import asyncio

    plain_agent = Agent(
        role="test role",
        goal="test goal",
        backstory="test backstory",
        multimodal=False,
    )

    asyncio.run(plain_agent.kickoff_async("Test message"))

    mock_lite_agent.assert_called_once()
    passed_tools = mock_lite_agent.call_args.kwargs["tools"]
    image_tools = [
        candidate for candidate in passed_tools if isinstance(candidate, AddImageTool)
    ]
    assert not image_tools
def test_agent_kickoff_async_does_not_mutate_self_tools(mock_lite_agent):
    """kickoff_async leaves ``agent.tools`` with the same length and contents."""
    import asyncio

    class DummyTool(BaseTool):
        name: str = "dummy_tool"
        description: str = "A dummy tool"

        def _run(self, **kwargs):
            return "dummy result"

    agent = Agent(
        role="test role",
        goal="test goal",
        backstory="test backstory",
        tools=[DummyTool()],
        multimodal=True,
    )

    # Snapshot the state before kickoff_async so it can be compared afterwards.
    count_before = len(agent.tools)
    tools_before = list(agent.tools)

    asyncio.run(agent.kickoff_async("Test message"))

    assert len(agent.tools) == count_before
    assert agent.tools == tools_before
def test_agent_kickoff_with_existing_tools_and_multimodal(mock_lite_agent):
    """kickoff passes both the configured tool and an AddImageTool, two in total."""

    class DummyTool(BaseTool):
        name: str = "dummy_tool"
        description: str = "A dummy tool"

        def _run(self, **kwargs):
            return "dummy result"

    agent = Agent(
        role="test role",
        goal="test goal",
        backstory="test backstory",
        tools=[DummyTool()],
        multimodal=True,
    )

    agent.kickoff("Test message")

    mock_lite_agent.assert_called_once()
    passed_tools = mock_lite_agent.call_args.kwargs["tools"]

    assert any(isinstance(candidate, DummyTool) for candidate in passed_tools)
    assert any(isinstance(candidate, AddImageTool) for candidate in passed_tools)
    assert len(passed_tools) == 2
def test_agent_kickoff_deduplicates_tools_by_name(mock_lite_agent):
    """Two configured tools sharing one name reach LiteAgent as a single tool."""

    class DummyTool(BaseTool):
        name: str = "dummy_tool"
        description: str = "A dummy tool"

        def _run(self, **kwargs):
            return "dummy result"

    # Two distinct instances that share the name "dummy_tool".
    agent = Agent(
        role="test role",
        goal="test goal",
        backstory="test backstory",
        tools=[DummyTool(), DummyTool()],
        multimodal=False,
    )

    agent.kickoff("Test message")

    mock_lite_agent.assert_called_once()
    passed_tools = mock_lite_agent.call_args.kwargs["tools"]
    surviving_dummies = [
        candidate for candidate in passed_tools if isinstance(candidate, DummyTool)
    ]
    assert len(surviving_dummies) == 1
def test_agent_kickoff_async_includes_platform_and_mcp_tools(mock_lite_agent):
    """kickoff_async passes platform, MCP and image tools to LiteAgent.

    ``Agent.get_platform_tools`` and ``Agent.get_mcp_tools`` are patched to
    return one stub tool each; the test checks each getter is called once and
    that both stubs plus an AddImageTool appear in the ``tools`` argument.
    """
    import asyncio

    class PlatformTool(BaseTool):
        name: str = "platform_tool"
        description: str = "A platform tool"

        def _run(self, **kwargs):
            return "platform result"

    class MCPTool(BaseTool):
        name: str = "mcp_tool"
        description: str = "An MCP tool"

        def _run(self, **kwargs):
            return "mcp result"

    with (
        patch.object(Agent, "get_platform_tools") as platform_getter,
        patch.object(Agent, "get_mcp_tools") as mcp_getter,
    ):
        platform_getter.return_value = [PlatformTool()]
        mcp_getter.return_value = [MCPTool()]

        agent = Agent(
            role="test role",
            goal="test goal",
            backstory="test backstory",
            apps=["test_app"],
            mcps=["crewai-amp:dummy"],
            multimodal=True,
        )

        asyncio.run(agent.kickoff_async("Test message"))

        mock_lite_agent.assert_called_once()
        passed_tools = mock_lite_agent.call_args.kwargs["tools"]

        platform_getter.assert_called_once()
        mcp_getter.assert_called_once()

        assert any(isinstance(candidate, PlatformTool) for candidate in passed_tools)
        assert any(isinstance(candidate, MCPTool) for candidate in passed_tools)
        assert any(isinstance(candidate, AddImageTool) for candidate in passed_tools)

View File

@@ -13,7 +13,7 @@ load_result = load_dotenv(override=True)
@pytest.fixture(autouse=True)
def setup_test_environment():
"""Set up test environment with a temporary directory for SQLite storage."""
with tempfile.TemporaryDirectory() as temp_dir:
with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as temp_dir:
# Create the directory with proper permissions
storage_dir = Path(temp_dir) / "crewai_test_storage"
storage_dir.mkdir(parents=True, exist_ok=True)

View File

@@ -144,9 +144,8 @@ class TestAgentEvaluator:
mock_crew.tasks.append(task)
events = {}
started_event = threading.Event()
completed_event = threading.Event()
task_completed_event = threading.Event()
results_condition = threading.Condition()
results_ready = False
agent_evaluator = AgentEvaluator(
agents=[agent], evaluators=[GoalAlignmentEvaluator()]
@@ -156,13 +155,11 @@ class TestAgentEvaluator:
async def capture_started(source, event):
if event.agent_id == str(agent.id):
events["started"] = event
started_event.set()
@crewai_event_bus.on(AgentEvaluationCompletedEvent)
async def capture_completed(source, event):
if event.agent_id == str(agent.id):
events["completed"] = event
completed_event.set()
@crewai_event_bus.on(AgentEvaluationFailedEvent)
def capture_failed(source, event):
@@ -170,17 +167,20 @@ class TestAgentEvaluator:
@crewai_event_bus.on(TaskCompletedEvent)
async def on_task_completed(source, event):
# TaskCompletedEvent fires AFTER evaluation results are stored
nonlocal results_ready
if event.task and event.task.id == task.id:
task_completed_event.set()
while not agent_evaluator.get_evaluation_results().get(agent.role):
pass
with results_condition:
results_ready = True
results_condition.notify()
mock_crew.kickoff()
assert started_event.wait(timeout=5), "Timeout waiting for started event"
assert completed_event.wait(timeout=5), "Timeout waiting for completed event"
assert task_completed_event.wait(timeout=5), (
"Timeout waiting for task completion"
)
with results_condition:
assert results_condition.wait_for(
lambda: results_ready, timeout=5
), "Timeout waiting for evaluation results"
assert events.keys() == {"started", "completed"}
assert events["started"].agent_id == str(agent.id)

View File

@@ -647,6 +647,7 @@ def test_handle_streaming_tool_calls_no_tools(mock_emit):
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.skip(reason="Highly flaky on ci")
def test_llm_call_when_stop_is_unsupported(caplog):
llm = LLM(model="o1-mini", stop=["stop"], is_litellm=True)
with caplog.at_level(logging.INFO):
@@ -657,6 +658,7 @@ def test_llm_call_when_stop_is_unsupported(caplog):
@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.skip(reason="Highly flaky on ci")
def test_llm_call_when_stop_is_unsupported_when_additional_drop_params_is_provided(
caplog,
):
@@ -664,7 +666,6 @@ def test_llm_call_when_stop_is_unsupported_when_additional_drop_params_is_provid
model="o1-mini",
stop=["stop"],
additional_drop_params=["another_param"],
is_litellm=True,
)
with caplog.at_level(logging.INFO):
result = llm.call("What is the capital of France?")

View File

@@ -273,12 +273,15 @@ def another_simple_tool():
def test_internal_crew_with_mcp():
from crewai_tools import MCPServerAdapter
from crewai_tools.adapters.mcp_adapter import ToolCollection
from crewai_tools.adapters.tool_collection import ToolCollection
mock = Mock(spec=MCPServerAdapter)
mock.tools = ToolCollection([simple_tool, another_simple_tool])
with patch("crewai_tools.MCPServerAdapter", return_value=mock) as adapter_mock:
mock_adapter = Mock()
mock_adapter.tools = ToolCollection([simple_tool, another_simple_tool])
with (
patch("crewai_tools.MCPServerAdapter", return_value=mock_adapter) as adapter_mock,
patch("crewai.llm.LLM.__new__", return_value=Mock()),
):
crew = InternalCrewWithMCP()
assert crew.reporting_analyst().tools == [simple_tool, another_simple_tool]
assert crew.researcher().tools == [simple_tool]