Compare commits


10 Commits

Author SHA1 Message Date
Devin AI
18a38ba436 fix: reorder imports according to project style
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-11 11:21:36 +00:00
Devin AI
369ee46ff3 fix: alphabetically order standard library imports
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-11 11:20:36 +00:00
Devin AI
39a290b4d3 fix: reorder standard library imports
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-11 11:19:47 +00:00
Devin AI
d2cc61028f fix: reorder imports in test file
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-11 11:18:47 +00:00
Devin AI
edcd55d19f fix: organize imports according to linter rules
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-11 11:17:55 +00:00
Devin AI
097fac6c87 feat: enhance HumanTool with validation, timeout, and async support
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-11 11:16:46 +00:00
Devin AI
ae4ca7748c fix: sort imports according to linter rules
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-11 11:11:01 +00:00
Devin AI
8b58feb5e0 fix: sort imports according to linter rules
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-11 11:09:29 +00:00
Devin AI
a4856a9805 fix: add missing ToolCalling import in test_tool_usage.py
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-11 11:07:04 +00:00
Devin AI
364a31ca8b fix: handle LangGraph interrupts in human tool
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-11 11:06:49 +00:00
13 changed files with 237 additions and 122 deletions

View File

@@ -1,5 +1,3 @@
-from typing import Optional
-
 from crewai.memory.entity.entity_memory_item import EntityMemoryItem
 from crewai.memory.memory import Memory
 from crewai.memory.storage.rag_storage import RAGStorage
@@ -40,7 +38,7 @@ class EntityMemory(Memory):
         )
         super().__init__(storage)

-    def save(self, item: EntityMemoryItem, custom_key: Optional[str] = None) -> None:  # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
+    def save(self, item: EntityMemoryItem) -> None:  # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
         """Saves an entity item into the SQLite storage."""
         if self.memory_provider == "mem0":
             data = f"""
@@ -51,7 +49,7 @@ class EntityMemory(Memory):
             """
         else:
             data = f"{item.name}({item.type}): {item.description}"
-        super().save(data, item.metadata, custom_key=custom_key)
+        super().save(data, item.metadata)

     def reset(self) -> None:
         try:

View File

@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List

 from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
 from crewai.memory.memory import Memory
@@ -19,12 +19,9 @@ class LongTermMemory(Memory):
         storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
         super().__init__(storage)

-    def save(self, item: LongTermMemoryItem, custom_key: Optional[str] = None) -> None:  # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
+    def save(self, item: LongTermMemoryItem) -> None:  # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
         metadata = item.metadata
         metadata.update({"agent": item.agent, "expected_output": item.expected_output})
-        if custom_key:
-            metadata.update({"custom_key": custom_key})
         self.storage.save(  # type: ignore # BUG?: Unexpected keyword argument "task_description","score","datetime" for "save" of "Storage"
             task_description=item.task,
             score=metadata["quality"],
@@ -32,8 +29,8 @@ class LongTermMemory(Memory):
             datetime=item.datetime,
         )

-    def search(self, task: str, latest_n: int = 3, custom_key: Optional[str] = None) -> List[Dict[str, Any]]:  # type: ignore # signature of "search" incompatible with supertype "Memory"
-        return self.storage.load(task, latest_n, custom_key)  # type: ignore # BUG?: "Storage" has no attribute "load"
+    def search(self, task: str, latest_n: int = 3) -> List[Dict[str, Any]]:  # type: ignore # signature of "search" incompatible with supertype "Memory"
+        return self.storage.load(task, latest_n)  # type: ignore # BUG?: "Storage" has no attribute "load"

     def reset(self) -> None:
         self.storage.reset()

View File

@@ -5,10 +5,7 @@ from crewai.memory.storage.rag_storage import RAGStorage

 class Memory:
     """
-    Base class for memory, now supporting agent tags, generic metadata, and custom keys.
-
-    Custom keys allow scoping memories to specific entities (users, accounts, sessions),
-    retrieving memories contextually, and preventing data leakage across logical boundaries.
+    Base class for memory, now supporting agent tags and generic metadata.
     """

     def __init__(self, storage: RAGStorage):
@@ -19,13 +16,10 @@ class Memory:
         value: Any,
         metadata: Optional[Dict[str, Any]] = None,
         agent: Optional[str] = None,
-        custom_key: Optional[str] = None,
     ) -> None:
         metadata = metadata or {}
         if agent:
             metadata["agent"] = agent
-        if custom_key:
-            metadata["custom_key"] = custom_key

         self.storage.save(value, metadata)
@@ -34,12 +28,7 @@ class Memory:
         query: str,
         limit: int = 3,
         score_threshold: float = 0.35,
-        custom_key: Optional[str] = None,
     ) -> List[Any]:
-        filter_dict = None
-        if custom_key:
-            filter_dict = {"custom_key": {"$eq": custom_key}}
         return self.storage.search(
-            query=query, limit=limit, score_threshold=score_threshold, filter=filter_dict
+            query=query, limit=limit, score_threshold=score_threshold
         )
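For context, the scoping that this hunk removes worked by stamping custom_key into the saved metadata and turning it into an equality filter on search. Below is a minimal, illustrative sketch of that pattern; ScopedStore is a hypothetical stand-in for the vector storage, not code from the repository.

from typing import Any, Dict, List, Optional


class ScopedStore:
    """Hypothetical stand-in for a metadata-filtering vector store."""

    def __init__(self) -> None:
        self._items: List[Dict[str, Any]] = []

    def save(self, value: Any, metadata: Dict[str, Any]) -> None:
        self._items.append({"value": value, "metadata": metadata})

    def search(self, query: str, filter: Optional[Dict[str, Any]] = None) -> List[Any]:
        hits = self._items
        if filter:  # shape mirrors the removed code: {"custom_key": {"$eq": key}}
            field, cond = next(iter(filter.items()))
            hits = [i for i in hits if i["metadata"].get(field) == cond["$eq"]]
        return [i["value"] for i in hits if query.lower() in str(i["value"]).lower()]


store = ScopedStore()
store.save("prefers dark mode", {"custom_key": "user123"})
store.save("prefers light mode", {"custom_key": "user456"})
# Only user123's memory is returned once the filter is applied.
print(store.search("prefers", filter={"custom_key": {"$eq": "user123"}}))

Dropping the parameter means all callers now share one unscoped store, which is also why the custom_key test file deleted further down is no longer needed.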

View File

@@ -46,31 +46,22 @@ class ShortTermMemory(Memory):
         value: Any,
         metadata: Optional[Dict[str, Any]] = None,
         agent: Optional[str] = None,
-        custom_key: Optional[str] = None,
     ) -> None:
         item = ShortTermMemoryItem(data=value, metadata=metadata, agent=agent)
         if self.memory_provider == "mem0":
             item.data = f"Remember the following insights from Agent run: {item.data}"

-        super().save(value=item.data, metadata=item.metadata, agent=item.agent, custom_key=custom_key)
+        super().save(value=item.data, metadata=item.metadata, agent=item.agent)

     def search(
         self,
         query: str,
         limit: int = 3,
         score_threshold: float = 0.35,
-        custom_key: Optional[str] = None,
     ):
-        filter_dict = None
-        if custom_key:
-            filter_dict = {"custom_key": {"$eq": custom_key}}
         return self.storage.search(
-            query=query,
-            limit=limit,
-            score_threshold=score_threshold,
-            filter=filter_dict
-        )
+            query=query, limit=limit, score_threshold=score_threshold
+        )  # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters

     def reset(self) -> None:
         try:

View File

@@ -70,31 +70,22 @@ class LTMSQLiteStorage:
         )

     def load(
-        self, task_description: str, latest_n: int, custom_key: Optional[str] = None
+        self, task_description: str, latest_n: int
     ) -> Optional[List[Dict[str, Any]]]:
         """Queries the LTM table by task description with error handling."""
         try:
             with sqlite3.connect(self.db_path) as conn:
                 cursor = conn.cursor()
-                query = """
+                cursor.execute(
+                    f"""
                     SELECT metadata, datetime, score
                     FROM long_term_memories
                     WHERE task_description = ?
-                """
-                params = [task_description]
-                if custom_key:
-                    query += " AND json_extract(metadata, '$.custom_key') = ?"
-                    params.append(custom_key)
-                query += f"""
                     ORDER BY datetime DESC, score ASC
                     LIMIT {latest_n}
-                """
-                cursor.execute(query, params)
+                """,  # nosec
+                    (task_description,),
+                )
                 rows = cursor.fetchall()
                 if rows:
                     return [
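The removed branch did its per-entity filtering inside SQLite via json_extract on the stored metadata. A small, self-contained sketch of that query shape (illustrative only; it assumes a SQLite build with the JSON1 functions, with table and column names copied from the hunk):

import json
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE long_term_memories"
    " (task_description TEXT, metadata TEXT, datetime TEXT, score REAL)"
)
conn.execute(
    "INSERT INTO long_term_memories VALUES (?, ?, ?, ?)",
    ("research task", json.dumps({"custom_key": "user123", "quality": 8}), "2025-02-11", 8.0),
)
rows = conn.execute(
    """
    SELECT metadata, datetime, score
    FROM long_term_memories
    WHERE task_description = ?
      AND json_extract(metadata, '$.custom_key') = ?
    ORDER BY datetime DESC, score ASC
    LIMIT 3
    """,
    ("research task", "user123"),
).fetchall()
print(rows)

The replacement query in the hunk drops the dynamic string building: it interpolates only latest_n into the f-string and binds task_description as the sole parameter.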

View File

@@ -120,11 +120,7 @@ class RAGStorage(BaseRAGStorage):
         try:
             with suppress_logging():
-                response = self.collection.query(
-                    query_texts=query,
-                    n_results=limit,
-                    where=filter
-                )
+                response = self.collection.query(query_texts=query, n_results=limit)

             results = []
             for i in range(len(response["ids"][0])):

View File

@@ -26,27 +26,20 @@ class UserMemory(Memory):
         value,
         metadata: Optional[Dict[str, Any]] = None,
         agent: Optional[str] = None,
-        custom_key: Optional[str] = None,
     ) -> None:
         # TODO: Change this function since we want to take care of the case where we save memories for the usr
         data = f"Remember the details about the user: {value}"
-        super().save(data, metadata, custom_key=custom_key)
+        super().save(data, metadata)

     def search(
         self,
         query: str,
         limit: int = 3,
         score_threshold: float = 0.35,
-        custom_key: Optional[str] = None,
     ):
-        filter_dict = None
-        if custom_key:
-            filter_dict = {"custom_key": {"$eq": custom_key}}
         results = self.storage.search(
             query=query,
             limit=limit,
             score_threshold=score_threshold,
-            filter=filter_dict,
         )
         return results

View File

@@ -1 +1,2 @@
 from .base_tool import BaseTool, tool
+from .human_tool import HumanTool

View File

@@ -0,0 +1,98 @@
+"""Tool for handling human input using LangGraph's interrupt mechanism."""
+
+import logging
+from typing import Any, Dict, Optional
+
+from pydantic import BaseModel, Field
+
+from crewai.tools import BaseTool
+
+
+class HumanToolSchema(BaseModel):
+    """Schema for HumanTool input validation."""
+
+    query: str = Field(
+        ...,
+        description="The question to ask the user. Must be a non-empty string."
+    )
+    timeout: Optional[float] = Field(
+        default=None,
+        description="Optional timeout in seconds for waiting for user response"
+    )
+
+
+class HumanTool(BaseTool):
+    """Tool for getting human input using LangGraph's interrupt mechanism.
+
+    This tool allows agents to request input from users through LangGraph's
+    interrupt mechanism. It supports timeout configuration and input validation.
+    """
+
+    name: str = "human"
+    description: str = "Useful to ask user to enter input."
+    args_schema: type[BaseModel] = HumanToolSchema
+    result_as_answer: bool = False  # Don't use the response as final answer
+
+    def _run(self, query: str, timeout: Optional[float] = None) -> str:
+        """Execute the human input tool.
+
+        Args:
+            query: The question to ask the user
+            timeout: Optional timeout in seconds
+
+        Returns:
+            The user's response
+
+        Raises:
+            ImportError: If LangGraph is not installed
+            TimeoutError: If response times out
+            ValueError: If query is invalid
+        """
+        if not query or not isinstance(query, str):
+            raise ValueError("Query must be a non-empty string")
+
+        try:
+            from langgraph.prebuilt.state_graphs import interrupt
+
+            logging.info(f"Requesting human input: {query}")
+            human_response = interrupt({"query": query, "timeout": timeout})
+            return human_response["data"]
+        except ImportError:
+            logging.error("LangGraph not installed")
+            raise ImportError(
+                "LangGraph is required for HumanTool. "
+                "Install with `pip install langgraph`"
+            )
+        except Exception as e:
+            logging.error(f"Error during human input: {str(e)}")
+            raise
+
+    async def _arun(self, query: str, timeout: Optional[float] = None) -> str:
+        """Execute the human input tool asynchronously.
+
+        Args:
+            query: The question to ask the user
+            timeout: Optional timeout in seconds
+
+        Returns:
+            The user's response
+
+        Raises:
+            ImportError: If LangGraph is not installed
+            TimeoutError: If response times out
+            ValueError: If query is invalid
+        """
+        if not query or not isinstance(query, str):
+            raise ValueError("Query must be a non-empty string")
+
+        try:
+            from langgraph.prebuilt.state_graphs import interrupt
+
+            logging.info(f"Requesting async human input: {query}")
+            human_response = interrupt({"query": query, "timeout": timeout})
+            return human_response["data"]
+        except ImportError:
+            logging.error("LangGraph not installed")
+            raise ImportError(
+                "LangGraph is required for HumanTool. "
+                "Install with `pip install langgraph`"
+            )
+        except Exception as e:
+            logging.error(f"Error during async human input: {str(e)}")
+            raise
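A minimal usage sketch for the new tool, assuming this branch plus langgraph are installed; the agent, task, and crew wiring below is illustrative and not part of the diff.

from crewai import Agent, Crew, Task
from crewai.tools import HumanTool

human_tool = HumanTool()

agent = Agent(
    role="Support Analyst",
    goal="Resolve the ticket, asking the user whenever details are missing",
    backstory="You triage incoming support requests.",
    tools=[human_tool],
    verbose=True,
)

task = Task(
    description="Clarify the user's issue and propose a fix.",
    expected_output="A short resolution plan.",
    agent=agent,
)

crew = Crew(agents=[agent], tasks=[task])
# When the agent invokes the `human` tool inside a LangGraph-managed run, the
# tool calls interrupt({"query": ..., "timeout": ...}) and execution pauses
# until the graph resumes with the user's answer.
# result = crew.kickoff()

Note that the import path the tool uses (langgraph.prebuilt.state_graphs) must expose an interrupt function in the installed LangGraph version; otherwise the ImportError branch above is taken.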

View File

@@ -182,6 +182,10 @@ class ToolUsage:
             else:
                 result = tool.invoke(input={})
         except Exception as e:
+            # Check if this is a LangGraph interrupt that should be propagated
+            if hasattr(e, '__class__') and e.__class__.__name__ == 'Interrupt':
+                raise e  # Propagate interrupt up
+
             self.on_tool_error(tool=tool, tool_calling=calling, e=e)
             self._run_attempts += 1
             if self._run_attempts > self._max_parsing_attempts:
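The added check matches the interrupt by class name instead of importing it, so the tool-usage path works whether or not langgraph is installed. A stripped-down, illustrative sketch of the same pattern (names are hypothetical, not repository code):

class Interrupt(Exception):
    """Stand-in for LangGraph's interrupt exception."""


def run_tool(tool_fn, on_error):
    try:
        return tool_fn()
    except Exception as e:
        if e.__class__.__name__ == "Interrupt":
            raise  # hand the pause back to the surrounding graph
        on_error(e)  # every other failure goes through normal error handling
        return None


def needs_human():
    raise Interrupt("ask the user")


try:
    run_tool(needs_human, on_error=print)
except Interrupt as caught:
    print(f"propagated: {caught}")

The trade-off is that any exception class named Interrupt is treated as one, which is exactly what the new test_tool_usage_interrupt_handling test below exploits with its dynamically created type('Interrupt', ...).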

View File

@@ -1,57 +0,0 @@
-import pytest
-from unittest.mock import patch, MagicMock
-
-from crewai.memory.short_term.short_term_memory import ShortTermMemory
-from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
-from crewai.agent import Agent
-from crewai.crew import Crew
-from crewai.task import Task
-
-
-@pytest.fixture
-def short_term_memory():
-    """Fixture to create a ShortTermMemory instance"""
-    agent = Agent(
-        role="Researcher",
-        goal="Search relevant data and provide results",
-        backstory="You are a researcher at a leading tech think tank.",
-        tools=[],
-        verbose=True,
-    )
-    task = Task(
-        description="Perform a search on specific topics.",
-        expected_output="A list of relevant URLs based on the search query.",
-        agent=agent,
-    )
-    return ShortTermMemory(crew=Crew(agents=[agent], tasks=[task]))
-
-
-def test_save_with_custom_key(short_term_memory):
-    """Test that save method correctly passes custom_key to storage"""
-    with patch.object(short_term_memory.storage, 'save') as mock_save:
-        short_term_memory.save(
-            value="Test data",
-            metadata={"task": "test_task"},
-            agent="test_agent",
-            custom_key="user123",
-        )
-        called_args = mock_save.call_args[0]
-        called_kwargs = mock_save.call_args[1]
-        assert "custom_key" in called_args[1]
-        assert called_args[1]["custom_key"] == "user123"
-
-
-def test_search_with_custom_key(short_term_memory):
-    """Test that search method correctly passes custom_key to storage"""
-    expected_results = [{"context": "Test data", "metadata": {"custom_key": "user123"}, "score": 0.95}]
-    with patch.object(short_term_memory.storage, 'search', return_value=expected_results) as mock_search:
-        results = short_term_memory.search("test query", custom_key="user123")
-        mock_search.assert_called_once()
-        filter_arg = mock_search.call_args[1].get('filter')
-        assert filter_arg == {"custom_key": {"$eq": "user123"}}
-        assert results == expected_results

View File

@@ -0,0 +1,83 @@
+"""Test HumanTool functionality."""
+
+from unittest.mock import patch
+
+import pytest
+
+from crewai.tools import HumanTool
+
+
+def test_human_tool_basic():
+    """Test basic HumanTool creation and attributes."""
+    tool = HumanTool()
+    assert tool.name == "human"
+    assert "ask user to enter input" in tool.description.lower()
+    assert not tool.result_as_answer
+
+
+@pytest.mark.vcr(filter_headers=["authorization"])
+def test_human_tool_with_langgraph_interrupt():
+    """Test HumanTool with LangGraph interrupt handling."""
+    tool = HumanTool()
+    with patch('langgraph.prebuilt.state_graphs.interrupt') as mock_interrupt:
+        mock_interrupt.return_value = {"data": "test response"}
+        result = tool._run("test query")
+        assert result == "test response"
+        mock_interrupt.assert_called_with({"query": "test query", "timeout": None})
+
+
+def test_human_tool_timeout():
+    """Test HumanTool timeout handling."""
+    tool = HumanTool()
+    timeout = 30.0
+    with patch('langgraph.prebuilt.state_graphs.interrupt') as mock_interrupt:
+        mock_interrupt.return_value = {"data": "test response"}
+        result = tool._run("test query", timeout=timeout)
+        assert result == "test response"
+        mock_interrupt.assert_called_with({"query": "test query", "timeout": timeout})
+
+
+def test_human_tool_invalid_input():
+    """Test HumanTool input validation."""
+    tool = HumanTool()
+    with pytest.raises(ValueError, match="Query must be a non-empty string"):
+        tool._run("")
+    with pytest.raises(ValueError, match="Query must be a non-empty string"):
+        tool._run(None)
+
+
+@pytest.mark.asyncio
+async def test_human_tool_async():
+    """Test async HumanTool functionality."""
+    tool = HumanTool()
+    with patch('langgraph.prebuilt.state_graphs.interrupt') as mock_interrupt:
+        mock_interrupt.return_value = {"data": "test response"}
+        result = await tool._arun("test query")
+        assert result == "test response"
+        mock_interrupt.assert_called_with({"query": "test query", "timeout": None})
+
+
+@pytest.mark.asyncio
+async def test_human_tool_async_timeout():
+    """Test async HumanTool timeout handling."""
+    tool = HumanTool()
+    timeout = 30.0
+    with patch('langgraph.prebuilt.state_graphs.interrupt') as mock_interrupt:
+        mock_interrupt.return_value = {"data": "test response"}
+        result = await tool._arun("test query", timeout=timeout)
+        assert result == "test response"
+        mock_interrupt.assert_called_with({"query": "test query", "timeout": timeout})
+
+
+def test_human_tool_without_langgraph():
+    """Test HumanTool behavior when LangGraph is not installed."""
+    tool = HumanTool()
+    with patch.dict('sys.modules', {'langgraph': None}):
+        with pytest.raises(ImportError) as exc_info:
+            tool._run("test query")
+        assert "LangGraph is required" in str(exc_info.value)
+        assert "pip install langgraph" in str(exc_info.value)

View File

@@ -1,12 +1,13 @@
 import json
 import random
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch

 import pytest
 from pydantic import BaseModel, Field

 from crewai import Agent, Task
 from crewai.tools import BaseTool
+from crewai.tools.tool_calling import ToolCalling
 from crewai.tools.tool_usage import ToolUsage
@@ -85,6 +86,36 @@ def test_random_number_tool_schema():
     )


+def test_tool_usage_interrupt_handling():
+    """Test that tool usage properly propagates LangGraph interrupts."""
+
+    class InterruptingTool(BaseTool):
+        name: str = "interrupt_test"
+        description: str = "A tool that raises LangGraph interrupts"
+
+        def _run(self, query: str) -> str:
+            raise type('Interrupt', (Exception,), {})("test interrupt")
+
+    tool = InterruptingTool()
+    tool_usage = ToolUsage(
+        tools_handler=MagicMock(),
+        tools=[tool],
+        original_tools=[tool],
+        tools_description="Sample tool for testing",
+        tools_names="interrupt_test",
+        task=MagicMock(),
+        function_calling_llm=MagicMock(),
+        agent=MagicMock(),
+        action=MagicMock(),
+    )
+
+    # Test that interrupt is propagated
+    with pytest.raises(Exception) as exc_info:
+        tool_usage.use(
+            ToolCalling(tool_name="interrupt_test", arguments={"query": "test"}, log="test"),
+            "test"
+        )
+    assert "test interrupt" in str(exc_info.value)
+
+
 def test_tool_usage_render():
     tool = RandomNumberTool()