Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-07 23:28:30 +00:00

Compare commits: devin/1763 ... devin/1763 (2 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 88d93cd65b | |
| | d160f0874a | |
@@ -14,6 +14,22 @@ from crewai_tools.tools.rag.rag_tool import Adapter


def _default_embedding_function():
    """Create a default embedding function using OpenAI's text-embedding-ada-002 model.

    This function creates and returns an embedding function that uses OpenAI's API
    to generate embeddings for text inputs. The embedding function is used by the
    LanceDBAdapter to convert text queries into vector representations for similarity search.

    Returns:
        Callable: A function that takes a list of strings and returns their embeddings
            as a list of vectors.

    Example:
        >>> embed_fn = _default_embedding_function()
        >>> embeddings = embed_fn(["Hello world"])
        >>> len(embeddings[0])  # Vector dimension
        1536
    """
    client = OpenAIClient()

    def _embedding_function(input):
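The hunk above is truncated before the body of the inner _embedding_function. A minimal sketch of how such a default embedding wrapper is commonly completed, assuming the standard OpenAI Python client (the actual crewai-tools implementation is not shown in this hunk, so treat the client class and call as assumptions):

# Hedged sketch only: assumes OpenAIClient is the standard OpenAI Python client
# (openai.OpenAI) and that the adapter expects one vector per input string.
from openai import OpenAI


def _default_embedding_function_sketch():
    client = OpenAI()

    def _embedding_function(input):
        # Request embeddings for all inputs in a single API call and return
        # a list of vectors, one per input string.
        response = client.embeddings.create(
            model="text-embedding-ada-002", input=input
        )
        return [item.embedding for item in response.data]

    return _embedding_function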
@@ -24,6 +40,32 @@ def _default_embedding_function():


class LanceDBAdapter(Adapter):
    """Adapter for integrating LanceDB vector database with CrewAI RAG tools.

    LanceDBAdapter provides a bridge between CrewAI's RAG (Retrieval-Augmented Generation)
    system and LanceDB, enabling efficient vector similarity search for knowledge retrieval.
    It handles embedding generation, vector search, and data ingestion with precise control
    over query parameters and column mappings.

    Attributes:
        uri: Database connection URI or path to the LanceDB database.
        table_name: Name of the table to query within the LanceDB database.
        embedding_function: Function to convert text into embeddings. Defaults to OpenAI's
            text-embedding-ada-002 model.
        top_k: Number of top results to return from similarity search. Defaults to 3.
        vector_column_name: Name of the column containing vector embeddings. Defaults to "vector".
        text_column_name: Name of the column containing text content. Defaults to "text".

    Example:
        >>> from crewai_tools.adapters.lancedb_adapter import LanceDBAdapter
        >>> adapter = LanceDBAdapter(
        ...     uri="./my_lancedb",
        ...     table_name="documents",
        ...     top_k=5
        ... )
        >>> results = adapter.query("What is machine learning?")
        >>> print(results)
    """
    uri: str | Path
    table_name: str
    embedding_function: Callable = Field(default_factory=_default_embedding_function)
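For the defaults documented above to work, the underlying table needs columns whose names match vector_column_name ("vector") and text_column_name ("text"). A hedged sketch of creating such a table with the lancedb package; the path, table name, records, and the 1536-dimension vectors (the ada-002 size) are illustrative assumptions:

import lancedb

# Connect to (or create) a local LanceDB database and build a table whose
# column names line up with the adapter defaults: "vector" and "text".
db = lancedb.connect("./my_lancedb")
table = db.create_table(
    "documents",
    data=[
        {"text": "CrewAI orchestrates AI agents", "vector": [0.0] * 1536},
        {"text": "LanceDB stores embeddings", "vector": [0.1] * 1536},
    ],
)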
@@ -35,12 +77,44 @@ class LanceDBAdapter(Adapter):
    _table: LanceDBTable = PrivateAttr()

    def model_post_init(self, __context: Any) -> None:
        """Initialize the database connection and table after model instantiation.

        This method is automatically called after the Pydantic model is initialized.
        It establishes the connection to the LanceDB database and opens the specified
        table for querying and data operations.

        Args:
            __context: Pydantic context object passed during initialization.

        Raises:
            Exception: If the database connection fails or the table does not exist.
        """
        self._db = lancedb_connect(self.uri)
        self._table = self._db.open_table(self.table_name)

        super().model_post_init(__context)

    def query(self, question: str) -> str:  # type: ignore[override]
        """Perform a vector similarity search for the given question.

        This method converts the input question into an embedding vector and searches
        the LanceDB table for the most similar entries. It returns the top-k results
        based on vector similarity, providing precise retrieval for RAG applications.

        Args:
            question: The text query to search for in the vector database.

        Returns:
            A string containing the concatenated text results from the top-k most
            similar entries, separated by newlines.

        Example:
            >>> adapter = LanceDBAdapter(uri="./db", table_name="docs")
            >>> results = adapter.query("What is CrewAI?")
            >>> print(results)
            CrewAI is a framework for orchestrating AI agents...
            CrewAI provides precise control over agent workflows...
        """
        query = self.embedding_function([question])[0]
        results = (
            self._table.search(query, vector_column_name=self.vector_column_name)
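The hunk cuts off after the search call. Based on the documented behavior (top-k results, text joined by newlines), the rest of the query body plausibly continues as in the sketch below; the chained methods beyond search are standard lancedb query-builder calls, but their exact use here is an assumption, not shown in this hunk:

# Hedged continuation of the truncated query() body.
results = (
    self._table.search(query, vector_column_name=self.vector_column_name)
    .limit(self.top_k)                      # keep only the top-k nearest rows
    .select([self.text_column_name])        # fetch just the text column
    .to_list()
)
return "\n".join(row[self.text_column_name] for row in results)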
@@ -56,4 +130,23 @@ class LanceDBAdapter(Adapter):
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Add data to the LanceDB table.

        This method provides a direct interface to add new records to the underlying
        LanceDB table. It accepts the same arguments as the LanceDB table's add method,
        allowing flexible data ingestion for building knowledge bases.

        Args:
            *args: Positional arguments to pass to the LanceDB table's add method.
            **kwargs: Keyword arguments to pass to the LanceDB table's add method.
                Common kwargs include 'data' (list of records) and 'mode' (append/overwrite).

        Example:
            >>> adapter = LanceDBAdapter(uri="./db", table_name="docs")
            >>> data = [
            ...     {"text": "CrewAI enables agent collaboration", "vector": [0.1, 0.2, ...]},
            ...     {"text": "LanceDB provides vector storage", "vector": [0.3, 0.4, ...]}
            ... ]
            >>> adapter.add(data=data)
        """
        self._table.add(*args, **kwargs)
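Putting the adapter's documented surface together, a hedged end-to-end usage sketch; the path, table name, record contents, and vector size are illustrative, and it assumes an existing "documents" table plus an OPENAI_API_KEY for the default embedding function:

from crewai_tools.adapters.lancedb_adapter import LanceDBAdapter

# Open an existing table; top_k controls how many similar rows are
# concatenated into the string returned by query().
adapter = LanceDBAdapter(uri="./my_lancedb", table_name="documents", top_k=3)

# Ingest a pre-embedded record (vector shortened to a placeholder here),
# then retrieve by semantic similarity.
adapter.add(data=[{"text": "CrewAI enables agent collaboration", "vector": [0.0] * 1536}])
print(adapter.query("How do agents collaborate?"))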
lib/crewai-tools/tests/adapters/test_lancedb_adapter_docs.py (new file, 62 lines)
@@ -0,0 +1,62 @@
"""Test that LanceDB adapter has proper docstrings."""

import inspect

import pytest

lancedb = pytest.importorskip("lancedb")

from crewai_tools.adapters.lancedb_adapter import (
    LanceDBAdapter,
    _default_embedding_function,
)


def test_lancedb_adapter_class_has_docstring():
    """Verify that LanceDBAdapter class has a docstring."""
    assert LanceDBAdapter.__doc__ is not None, "LanceDBAdapter class is missing a docstring"
    assert len(LanceDBAdapter.__doc__.strip()) > 0, "LanceDBAdapter docstring is empty"


def test_lancedb_adapter_model_post_init_has_docstring():
    """Verify that model_post_init method has a docstring."""
    assert (
        LanceDBAdapter.model_post_init.__doc__ is not None
    ), "model_post_init method is missing a docstring"
    assert (
        len(LanceDBAdapter.model_post_init.__doc__.strip()) > 0
    ), "model_post_init docstring is empty"


def test_lancedb_adapter_query_has_docstring():
    """Verify that query method has a docstring."""
    assert LanceDBAdapter.query.__doc__ is not None, "query method is missing a docstring"
    assert len(LanceDBAdapter.query.__doc__.strip()) > 0, "query docstring is empty"


def test_lancedb_adapter_add_has_docstring():
    """Verify that add method has a docstring."""
    assert LanceDBAdapter.add.__doc__ is not None, "add method is missing a docstring"
    assert len(LanceDBAdapter.add.__doc__.strip()) > 0, "add docstring is empty"


def test_default_embedding_function_has_docstring():
    """Verify that _default_embedding_function has a docstring."""
    assert (
        _default_embedding_function.__doc__ is not None
    ), "_default_embedding_function is missing a docstring"
    assert (
        len(_default_embedding_function.__doc__.strip()) > 0
    ), "_default_embedding_function docstring is empty"


def test_docstrings_contain_required_sections():
    """Verify that docstrings contain Args, Returns, or Example sections where appropriate."""
    query_doc = LanceDBAdapter.query.__doc__
    assert query_doc is not None
    assert "Args:" in query_doc or "Parameters:" in query_doc, "query docstring should have Args/Parameters section"
    assert "Returns:" in query_doc, "query docstring should have Returns section"

    add_doc = LanceDBAdapter.add.__doc__
    assert add_doc is not None
    assert "Args:" in add_doc or "Parameters:" in add_doc, "add docstring should have Args/Parameters section"
@@ -1416,43 +1416,6 @@ class Agent(BaseAgent):
        )
        return None

    def _build_runtime_tools(self) -> list[BaseTool]:
        """Build a list of tools for runtime execution without mutating self.tools.

        This method combines tools from multiple sources:
        - Agent's configured tools (self.tools)
        - Platform tools (if self.apps is set)
        - MCP tools (if self.mcps is set)
        - Multimodal tools (if self.multimodal is True)

        Returns:
            A deduplicated list of tools ready for execution.
        """
        runtime_tools: list[BaseTool] = list(self.tools or [])

        if self.apps:
            platform_tools = self.get_platform_tools(self.apps)
            if platform_tools:
                runtime_tools.extend(platform_tools)

        if self.mcps:
            mcp_tools = self.get_mcp_tools(self.mcps)
            if mcp_tools:
                runtime_tools.extend(mcp_tools)

        if self.multimodal:
            multimodal_tools = self.get_multimodal_tools()
            runtime_tools.extend(multimodal_tools)

        seen_names: set[str] = set()
        deduplicated_tools: list[BaseTool] = []
        for tool in runtime_tools:
            if tool.name not in seen_names:
                seen_names.add(tool.name)
                deduplicated_tools.append(tool)

        return deduplicated_tools

    def kickoff(
        self,
        messages: str | list[LLMMessage],
@@ -1473,7 +1436,14 @@
        Returns:
            LiteAgentOutput: The result of the agent execution.
        """
        runtime_tools = self._build_runtime_tools()
        if self.apps:
            platform_tools = self.get_platform_tools(self.apps)
            if platform_tools:
                self.tools.extend(platform_tools)
        if self.mcps:
            mcps = self.get_mcp_tools(self.mcps)
            if mcps:
                self.tools.extend(mcps)

        lite_agent = LiteAgent(
            id=self.id,
@@ -1481,7 +1451,7 @@
            goal=self.goal,
            backstory=self.backstory,
            llm=self.llm,
            tools=runtime_tools,
            tools=self.tools or [],
            max_iterations=self.max_iter,
            max_execution_time=self.max_execution_time,
            respect_context_window=self.respect_context_window,
@@ -1514,15 +1484,12 @@
        Returns:
            LiteAgentOutput: The result of the agent execution.
        """
        runtime_tools = self._build_runtime_tools()

        lite_agent = LiteAgent(
            id=self.id,
            role=self.role,
            goal=self.goal,
            backstory=self.backstory,
            llm=self.llm,
            tools=runtime_tools,
            tools=self.tools or [],
            max_iterations=self.max_iter,
            max_execution_time=self.max_execution_time,
            respect_context_window=self.respect_context_window,
@@ -1,299 +0,0 @@
"""Test Agent multimodal kickoff functionality."""

from unittest.mock import MagicMock, patch

import pytest

from crewai import Agent
from crewai.lite_agent import LiteAgent
from crewai.lite_agent_output import LiteAgentOutput
from crewai.tools.agent_tools.add_image_tool import AddImageTool
from crewai.tools.base_tool import BaseTool


@pytest.fixture
def mock_lite_agent():
    """Fixture to mock LiteAgent to avoid LLM calls."""
    with patch("crewai.agent.core.LiteAgent") as mock_lite_agent_class:
        mock_instance = MagicMock(spec=LiteAgent)
        mock_output = LiteAgentOutput(
            raw="Test output",
            pydantic=None,
            agent_role="test role",
            usage_metrics=None,
            messages=[],
        )
        mock_instance.kickoff.return_value = mock_output
        mock_instance.kickoff_async.return_value = mock_output
        mock_lite_agent_class.return_value = mock_instance
        yield mock_lite_agent_class
def test_agent_kickoff_with_multimodal_true_adds_image_tool(mock_lite_agent):
    """Test that when multimodal=True, AddImageTool is added to the tools."""
    agent = Agent(
        role="test role",
        goal="test goal",
        backstory="test backstory",
        multimodal=True,
    )

    agent.kickoff("Test message")

    mock_lite_agent.assert_called_once()
    call_kwargs = mock_lite_agent.call_args[1]
    tools = call_kwargs["tools"]

    assert any(isinstance(tool, AddImageTool) for tool in tools)


def test_agent_kickoff_with_multimodal_false_does_not_add_image_tool(mock_lite_agent):
    """Test that when multimodal=False, AddImageTool is not added to the tools."""
    agent = Agent(
        role="test role",
        goal="test goal",
        backstory="test backstory",
        multimodal=False,
    )

    agent.kickoff("Test message")

    mock_lite_agent.assert_called_once()
    call_kwargs = mock_lite_agent.call_args[1]
    tools = call_kwargs["tools"]

    assert not any(isinstance(tool, AddImageTool) for tool in tools)


def test_agent_kickoff_does_not_mutate_self_tools(mock_lite_agent):
    """Test that calling kickoff does not mutate self.tools."""

    class DummyTool(BaseTool):
        name: str = "dummy_tool"
        description: str = "A dummy tool"

        def _run(self, **kwargs):
            return "dummy result"

    dummy_tool = DummyTool()
    agent = Agent(
        role="test role",
        goal="test goal",
        backstory="test backstory",
        tools=[dummy_tool],
        multimodal=True,
    )

    original_tools_count = len(agent.tools)
    original_tools = list(agent.tools)

    agent.kickoff("Test message")

    assert len(agent.tools) == original_tools_count
    assert agent.tools == original_tools
def test_agent_kickoff_multiple_calls_does_not_duplicate_tools(mock_lite_agent):
    """Test that calling kickoff multiple times does not duplicate tools."""
    agent = Agent(
        role="test role",
        goal="test goal",
        backstory="test backstory",
        multimodal=True,
    )

    agent.kickoff("Test message 1")
    first_call_tools = mock_lite_agent.call_args[1]["tools"]
    first_call_image_tools = [
        tool for tool in first_call_tools if isinstance(tool, AddImageTool)
    ]

    agent.kickoff("Test message 2")
    second_call_tools = mock_lite_agent.call_args[1]["tools"]
    second_call_image_tools = [
        tool for tool in second_call_tools if isinstance(tool, AddImageTool)
    ]

    assert len(first_call_image_tools) == 1
    assert len(second_call_image_tools) == 1


def test_agent_kickoff_async_with_multimodal_true_adds_image_tool(mock_lite_agent):
    """Test that when multimodal=True, AddImageTool is added in kickoff_async."""
    agent = Agent(
        role="test role",
        goal="test goal",
        backstory="test backstory",
        multimodal=True,
    )

    import asyncio

    asyncio.run(agent.kickoff_async("Test message"))

    mock_lite_agent.assert_called_once()
    call_kwargs = mock_lite_agent.call_args[1]
    tools = call_kwargs["tools"]

    assert any(isinstance(tool, AddImageTool) for tool in tools)


def test_agent_kickoff_async_with_multimodal_false_does_not_add_image_tool(
    mock_lite_agent,
):
    """Test that when multimodal=False, AddImageTool is not added in kickoff_async."""
    agent = Agent(
        role="test role",
        goal="test goal",
        backstory="test backstory",
        multimodal=False,
    )

    import asyncio

    asyncio.run(agent.kickoff_async("Test message"))

    mock_lite_agent.assert_called_once()
    call_kwargs = mock_lite_agent.call_args[1]
    tools = call_kwargs["tools"]

    assert not any(isinstance(tool, AddImageTool) for tool in tools)
def test_agent_kickoff_async_does_not_mutate_self_tools(mock_lite_agent):
    """Test that calling kickoff_async does not mutate self.tools."""

    class DummyTool(BaseTool):
        name: str = "dummy_tool"
        description: str = "A dummy tool"

        def _run(self, **kwargs):
            return "dummy result"

    dummy_tool = DummyTool()
    agent = Agent(
        role="test role",
        goal="test goal",
        backstory="test backstory",
        tools=[dummy_tool],
        multimodal=True,
    )

    original_tools_count = len(agent.tools)
    original_tools = list(agent.tools)

    import asyncio

    asyncio.run(agent.kickoff_async("Test message"))

    assert len(agent.tools) == original_tools_count
    assert agent.tools == original_tools


def test_agent_kickoff_with_existing_tools_and_multimodal(mock_lite_agent):
    """Test that multimodal tools are added alongside existing tools."""

    class DummyTool(BaseTool):
        name: str = "dummy_tool"
        description: str = "A dummy tool"

        def _run(self, **kwargs):
            return "dummy result"

    dummy_tool = DummyTool()
    agent = Agent(
        role="test role",
        goal="test goal",
        backstory="test backstory",
        tools=[dummy_tool],
        multimodal=True,
    )

    agent.kickoff("Test message")

    mock_lite_agent.assert_called_once()
    call_kwargs = mock_lite_agent.call_args[1]
    tools = call_kwargs["tools"]

    assert any(isinstance(tool, DummyTool) for tool in tools)
    assert any(isinstance(tool, AddImageTool) for tool in tools)
    assert len(tools) == 2
def test_agent_kickoff_deduplicates_tools_by_name(mock_lite_agent):
    """Test that tools with the same name are deduplicated."""

    class DummyTool(BaseTool):
        name: str = "dummy_tool"
        description: str = "A dummy tool"

        def _run(self, **kwargs):
            return "dummy result"

    dummy_tool1 = DummyTool()
    dummy_tool2 = DummyTool()

    agent = Agent(
        role="test role",
        goal="test goal",
        backstory="test backstory",
        tools=[dummy_tool1, dummy_tool2],
        multimodal=False,
    )

    agent.kickoff("Test message")

    mock_lite_agent.assert_called_once()
    call_kwargs = mock_lite_agent.call_args[1]
    tools = call_kwargs["tools"]

    dummy_tools = [tool for tool in tools if isinstance(tool, DummyTool)]
    assert len(dummy_tools) == 1


def test_agent_kickoff_async_includes_platform_and_mcp_tools(mock_lite_agent):
    """Test that kickoff_async includes platform and MCP tools like kickoff does."""
    with patch.object(Agent, "get_platform_tools") as mock_platform_tools, patch.object(
        Agent, "get_mcp_tools"
    ) as mock_mcp_tools:

        class PlatformTool(BaseTool):
            name: str = "platform_tool"
            description: str = "A platform tool"

            def _run(self, **kwargs):
                return "platform result"

        class MCPTool(BaseTool):
            name: str = "mcp_tool"
            description: str = "An MCP tool"

            def _run(self, **kwargs):
                return "mcp result"

        mock_platform_tools.return_value = [PlatformTool()]
        mock_mcp_tools.return_value = [MCPTool()]

        agent = Agent(
            role="test role",
            goal="test goal",
            backstory="test backstory",
            apps=["test_app"],
            mcps=["crewai-amp:dummy"],
            multimodal=True,
        )

        import asyncio

        asyncio.run(agent.kickoff_async("Test message"))

        mock_lite_agent.assert_called_once()
        call_kwargs = mock_lite_agent.call_args[1]
        tools = call_kwargs["tools"]

        mock_platform_tools.assert_called_once()
        mock_mcp_tools.assert_called_once()

        assert any(isinstance(tool, PlatformTool) for tool in tools)
        assert any(isinstance(tool, MCPTool) for tool in tools)
        assert any(isinstance(tool, AddImageTool) for tool in tools)
@@ -13,7 +13,7 @@ load_result = load_dotenv(override=True)
@pytest.fixture(autouse=True)
def setup_test_environment():
    """Set up test environment with a temporary directory for SQLite storage."""
    with tempfile.TemporaryDirectory() as temp_dir:
    with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as temp_dir:
        # Create the directory with proper permissions
        storage_dir = Path(temp_dir) / "crewai_test_storage"
        storage_dir.mkdir(parents=True, exist_ok=True)
@@ -144,9 +144,8 @@ class TestAgentEvaluator:
        mock_crew.tasks.append(task)

        events = {}
        started_event = threading.Event()
        completed_event = threading.Event()
        task_completed_event = threading.Event()
        results_condition = threading.Condition()
        results_ready = False

        agent_evaluator = AgentEvaluator(
            agents=[agent], evaluators=[GoalAlignmentEvaluator()]
@@ -156,13 +155,11 @@ class TestAgentEvaluator:
        async def capture_started(source, event):
            if event.agent_id == str(agent.id):
                events["started"] = event
                started_event.set()

        @crewai_event_bus.on(AgentEvaluationCompletedEvent)
        async def capture_completed(source, event):
            if event.agent_id == str(agent.id):
                events["completed"] = event
                completed_event.set()

        @crewai_event_bus.on(AgentEvaluationFailedEvent)
        def capture_failed(source, event):
@@ -170,17 +167,20 @@ class TestAgentEvaluator:

        @crewai_event_bus.on(TaskCompletedEvent)
        async def on_task_completed(source, event):
            # TaskCompletedEvent fires AFTER evaluation results are stored
            nonlocal results_ready
            if event.task and event.task.id == task.id:
                task_completed_event.set()
                while not agent_evaluator.get_evaluation_results().get(agent.role):
                    pass
                with results_condition:
                    results_ready = True
                    results_condition.notify()

        mock_crew.kickoff()

        assert started_event.wait(timeout=5), "Timeout waiting for started event"
        assert completed_event.wait(timeout=5), "Timeout waiting for completed event"
        assert task_completed_event.wait(timeout=5), (
            "Timeout waiting for task completion"
        )
        with results_condition:
            assert results_condition.wait_for(
                lambda: results_ready, timeout=5
            ), "Timeout waiting for evaluation results"

        assert events.keys() == {"started", "completed"}
        assert events["started"].agent_id == str(agent.id)
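The change above swaps a busy-wait on get_evaluation_results() for a threading.Condition handshake. A minimal, generic sketch of that pattern, with names that are illustrative rather than taken from the crewAI test suite:

import threading

results_condition = threading.Condition()
results_ready = False


def producer():
    # Signal the waiting thread once the shared state is in place.
    global results_ready
    with results_condition:
        results_ready = True
        results_condition.notify()


def consumer():
    # Block until the predicate holds or the timeout expires,
    # instead of spinning in a busy loop.
    with results_condition:
        assert results_condition.wait_for(lambda: results_ready, timeout=5)


t = threading.Thread(target=producer)
t.start()
consumer()
t.join()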
@@ -647,6 +647,7 @@ def test_handle_streaming_tool_calls_no_tools(mock_emit):


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.skip(reason="Highly flaky on ci")
def test_llm_call_when_stop_is_unsupported(caplog):
    llm = LLM(model="o1-mini", stop=["stop"], is_litellm=True)
    with caplog.at_level(logging.INFO):
@@ -657,6 +658,7 @@ def test_llm_call_when_stop_is_unsupported(caplog):


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.skip(reason="Highly flaky on ci")
def test_llm_call_when_stop_is_unsupported_when_additional_drop_params_is_provided(
    caplog,
):
@@ -664,7 +666,6 @@ def test_llm_call_when_stop_is_unsupported_when_additional_drop_params_is_provid
        model="o1-mini",
        stop=["stop"],
        additional_drop_params=["another_param"],
        is_litellm=True,
    )
    with caplog.at_level(logging.INFO):
        result = llm.call("What is the capital of France?")
@@ -273,12 +273,15 @@ def another_simple_tool():


def test_internal_crew_with_mcp():
    from crewai_tools import MCPServerAdapter
    from crewai_tools.adapters.mcp_adapter import ToolCollection
    from crewai_tools.adapters.tool_collection import ToolCollection

    mock = Mock(spec=MCPServerAdapter)
    mock.tools = ToolCollection([simple_tool, another_simple_tool])
    with patch("crewai_tools.MCPServerAdapter", return_value=mock) as adapter_mock:
    mock_adapter = Mock()
    mock_adapter.tools = ToolCollection([simple_tool, another_simple_tool])

    with (
        patch("crewai_tools.MCPServerAdapter", return_value=mock_adapter) as adapter_mock,
        patch("crewai.llm.LLM.__new__", return_value=Mock()),
    ):
        crew = InternalCrewWithMCP()
        assert crew.reporting_analyst().tools == [simple_tool, another_simple_tool]
        assert crew.researcher().tools == [simple_tool]