Mirror of https://github.com/crewAIInc/crewAI.git
Synced 2026-01-06 14:48:29 +00:00

Compare commits
5 commits: devin/1739 ... devin/1746
| Author | SHA1 | Date |
|---|---|---|
| | c63010daaa | |
| | d0191df996 | |
| | e27bcfb381 | |
| | 082cbd2c1c | |
| | 3361fab293 | |
11 src/crewai/patches/__init__.py Normal file
@@ -0,0 +1,11 @@
"""
Patches module for CrewAI.

This module contains patches for external dependencies to fix known issues.

Version: 1.0.0
"""

from crewai.patches.litellm_patch import apply_patches, patch_litellm_ollama_pt

__all__ = ["apply_patches", "patch_litellm_ollama_pt"]
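Note that this module only defines and exports the patch functions; nothing is applied on import. A minimal sketch of wiring the patch up at application startup follows (the call site is an assumption for illustration, not something this diff adds):

# Hypothetical call site: apply the dependency patches once, early,
# before any LLM calls are made.
from crewai.patches import apply_patches

if not apply_patches():
    # apply_patches() logs the underlying error itself; the caller only
    # sees a boolean, so decide here whether running unpatched is acceptable.
    print("Warning: litellm ollama_pt patch could not be applied")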
186 src/crewai/patches/litellm_patch.py Normal file
@@ -0,0 +1,186 @@
"""
Patch for litellm to fix IndexError in ollama_pt function.

This patch addresses issue #2744 in the crewAI repository, where an IndexError occurs
in litellm's Ollama prompt template function when a CrewAI Agent with Tools uses Ollama/Qwen models.

Version: 1.0.0
"""

import json
import logging
from typing import Any, Dict, List, Optional, Tuple, Union

# Set up logging
logger = logging.getLogger(__name__)

# Patch version
PATCH_VERSION = "1.0.0"


class PatchApplicationError(Exception):
    """Exception raised when a patch fails to apply."""
    pass


def apply_patches() -> bool:
    """
    Apply all patches to fix known issues with dependencies.

    Returns:
        bool: True if all patches were applied successfully, False otherwise.
    """
    success = patch_litellm_ollama_pt()
    logger.info(f"LiteLLM ollama_pt patch applied: {success}")
    return success


def patch_litellm_ollama_pt() -> bool:
    """
    Patch the ollama_pt function in litellm to fix IndexError.

    The issue occurs when accessing messages[msg_i].get("tool_calls") without checking
    if msg_i is within bounds of the messages list. This happens after tool execution
    during the next LLM call.

    Returns:
        bool: True if the patch was applied successfully, False otherwise.

    Raises:
        PatchApplicationError: If there's an error during patch application.
    """
    try:
        # Import the module containing the function to patch
        import litellm.litellm_core_utils.prompt_templates.factory as factory

        # Define a patched version of the function
        def patched_ollama_pt(model: str, messages: List[Dict]) -> Dict[str, Any]:
            """
            Patched version of ollama_pt that adds bounds checking.

            This fixes the IndexError that occurs when the assistant message is the last
            message in the list and msg_i goes out of bounds.

            Args:
                model: The model name.
                messages: The list of messages to process.

            Returns:
                Dict containing the prompt and images.
            """
            user_message_types = {"user", "tool", "function"}
            msg_i = 0
            images: List[str] = []
            prompt = ""

            # Handle empty messages list
            if not messages:
                return {"prompt": prompt, "images": images}

            while msg_i < len(messages):
                init_msg_i = msg_i
                user_content_str = ""
                ## MERGE CONSECUTIVE USER CONTENT ##
                while msg_i < len(messages) and messages[msg_i]["role"] in user_message_types:
                    msg_content = messages[msg_i].get("content")
                    if msg_content:
                        if isinstance(msg_content, list):
                            for m in msg_content:
                                if m.get("type", "") == "image_url":
                                    if isinstance(m["image_url"], str):
                                        images.append(m["image_url"])
                                    elif isinstance(m["image_url"], dict):
                                        images.append(m["image_url"]["url"])
                                elif m.get("type", "") == "text":
                                    user_content_str += m["text"]
                        else:
                            # Tool message content will always be a string
                            user_content_str += msg_content

                    msg_i += 1

                if user_content_str:
                    prompt += f"### User:\n{user_content_str}\n\n"

                system_content_str, msg_i = factory._handle_ollama_system_message(
                    messages, prompt, msg_i
                )
                if system_content_str:
                    prompt += f"### System:\n{system_content_str}\n\n"

                assistant_content_str = ""
                ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
                while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
                    assistant_content_str += factory.convert_content_list_to_str(messages[msg_i])
                    msg_i += 1

                # Add bounds check before accessing messages[msg_i]
                # This is the key fix for the IndexError
                if msg_i < len(messages):
                    tool_calls = messages[msg_i].get("tool_calls")
                    ollama_tool_calls = []
                    if tool_calls:
                        for call in tool_calls:
                            call_id = call["id"]
                            function_name = call["function"]["name"]
                            arguments = json.loads(call["function"]["arguments"])

                            ollama_tool_calls.append(
                                {
                                    "id": call_id,
                                    "type": "function",
                                    "function": {
                                        "name": function_name,
                                        "arguments": arguments,
                                    },
                                }
                            )

                    if ollama_tool_calls:
                        assistant_content_str += (
                            f"Tool Calls: {json.dumps(ollama_tool_calls, indent=2)}"
                        )

                    msg_i += 1

                if assistant_content_str:
                    prompt += f"### Assistant:\n{assistant_content_str}\n\n"

                if msg_i == init_msg_i:  # prevent infinite loops
                    raise factory.litellm.BadRequestError(
                        message=factory.BAD_MESSAGE_ERROR_STR + f"passed in {messages[msg_i]}",
                        model=model,
                        llm_provider="ollama",
                    )

            response_dict = {
                "prompt": prompt,
                "images": images,
            }

            return response_dict

        # Replace the original function with our patched version
        factory.ollama_pt = patched_ollama_pt

        logger.info(f"Successfully applied litellm ollama_pt patch version {PATCH_VERSION}")
        return True
    except Exception as e:
        error_msg = f"Failed to apply litellm ollama_pt patch: {e}"
        logger.error(error_msg)
        return False


# For backwards compatibility
def patch_litellm() -> bool:
    """
    Legacy function for backwards compatibility.

    Returns:
        bool: True if the patch was applied successfully, False otherwise.
    """
    try:
        return patch_litellm_ollama_pt()
    except Exception as e:
        logger.error(f"Failed to apply legacy litellm patch: {e}")
        return False
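For context, the crash this guards against needs nothing more than a conversation whose last message comes from the assistant, which is exactly what the message list looks like right after a tool round-trip. A hedged repro sketch (calling the prompt-template function directly like this is only for demonstration; per the issue this patch targets, the unpatched function raised IndexError on such input):

# After apply_patches(), the module-level ollama_pt has bounds checking,
# so a trailing assistant message no longer indexes past the end.
import litellm.litellm_core_utils.prompt_templates.factory as factory
from crewai.patches import apply_patches

messages = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "Calling the calculator tool."},  # last message
]

apply_patches()
result = factory.ollama_pt("qwen3:4b", messages)  # unpatched: IndexError here
print(result["prompt"])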

src/crewai/tools/__init__.py
@@ -1,2 +1 @@
from .base_tool import BaseTool, tool
from .human_tool import HumanTool

src/crewai/tools/human_tool.py (deleted file)
@@ -1,98 +0,0 @@
"""Tool for handling human input using LangGraph's interrupt mechanism."""

import logging
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field

from crewai.tools import BaseTool


class HumanToolSchema(BaseModel):
    """Schema for HumanTool input validation."""
    query: str = Field(
        ...,
        description="The question to ask the user. Must be a non-empty string."
    )
    timeout: Optional[float] = Field(
        default=None,
        description="Optional timeout in seconds for waiting for user response"
    )


class HumanTool(BaseTool):
    """Tool for getting human input using LangGraph's interrupt mechanism.

    This tool allows agents to request input from users through LangGraph's
    interrupt mechanism. It supports timeout configuration and input validation.
    """

    name: str = "human"
    description: str = "Useful to ask user to enter input."
    args_schema: type[BaseModel] = HumanToolSchema
    result_as_answer: bool = False  # Don't use the response as final answer

    def _run(self, query: str, timeout: Optional[float] = None) -> str:
        """Execute the human input tool.

        Args:
            query: The question to ask the user
            timeout: Optional timeout in seconds

        Returns:
            The user's response

        Raises:
            ImportError: If LangGraph is not installed
            TimeoutError: If response times out
            ValueError: If query is invalid
        """
        if not query or not isinstance(query, str):
            raise ValueError("Query must be a non-empty string")

        try:
            from langgraph.prebuilt.state_graphs import interrupt
            logging.info(f"Requesting human input: {query}")
            human_response = interrupt({"query": query, "timeout": timeout})
            return human_response["data"]
        except ImportError:
            logging.error("LangGraph not installed")
            raise ImportError(
                "LangGraph is required for HumanTool. "
                "Install with `pip install langgraph`"
            )
        except Exception as e:
            logging.error(f"Error during human input: {str(e)}")
            raise

    async def _arun(self, query: str, timeout: Optional[float] = None) -> str:
        """Execute the human input tool asynchronously.

        Args:
            query: The question to ask the user
            timeout: Optional timeout in seconds

        Returns:
            The user's response

        Raises:
            ImportError: If LangGraph is not installed
            TimeoutError: If response times out
            ValueError: If query is invalid
        """
        if not query or not isinstance(query, str):
            raise ValueError("Query must be a non-empty string")

        try:
            from langgraph.prebuilt.state_graphs import interrupt
            logging.info(f"Requesting async human input: {query}")
            human_response = interrupt({"query": query, "timeout": timeout})
            return human_response["data"]
        except ImportError:
            logging.error("LangGraph not installed")
            raise ImportError(
                "LangGraph is required for HumanTool. "
                "Install with `pip install langgraph`"
            )
        except Exception as e:
            logging.error(f"Error during async human input: {str(e)}")
            raise

src/crewai/tools/tool_usage.py
@@ -182,10 +182,6 @@ class ToolUsage:
                else:
                    result = tool.invoke(input={})
            except Exception as e:
                # Check if this is a LangGraph interrupt that should be propagated
                if hasattr(e, '__class__') and e.__class__.__name__ == 'Interrupt':
                    raise e  # Propagate interrupt up

                self.on_tool_error(tool=tool, tool_calling=calling, e=e)
                self._run_attempts += 1
                if self._run_attempts > self._max_parsing_attempts:
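The interrupt check in this hunk matches on the exception's class name rather than importing LangGraph's Interrupt type, which keeps langgraph an optional dependency at the cost of matching any exception class that happens to be named Interrupt. A minimal sketch of the same pattern in isolation (the helper and its names are hypothetical, not part of this diff):

# Hypothetical helper showing name-based exception matching: interrupts
# bubble up untouched, everything else falls into normal error handling.
def run_tool_guarded(tool_fn, *args, **kwargs):
    try:
        return tool_fn(*args, **kwargs)
    except Exception as e:
        if e.__class__.__name__ == "Interrupt":
            raise  # propagate the optional dependency's interrupt as-is
        return f"Tool failed: {e}"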
71 tests/patches/test_litellm_patch.py Normal file
@@ -0,0 +1,71 @@
"""
Test for the litellm patch that fixes the IndexError in ollama_pt function.
"""

import sys
import unittest
from unittest.mock import MagicMock, patch

import litellm
import pytest
from litellm.litellm_core_utils.prompt_templates.factory import ollama_pt

from crewai.patches.litellm_patch import patch_litellm_ollama_pt


class TestLitellmPatch(unittest.TestCase):
    def test_ollama_pt_patch_fixes_index_error(self):
        """Test that the patch fixes the IndexError in ollama_pt."""
        # Create a message list where the assistant message is the last one
        messages = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there"},
        ]

        # Store the original function to restore it after the test
        original_ollama_pt = litellm.litellm_core_utils.prompt_templates.factory.ollama_pt

        try:
            # Apply the patch
            success = patch_litellm_ollama_pt()
            self.assertTrue(success, "Patch application failed")

            # Use the function from the module directly to ensure we're using the patched version
            result = litellm.litellm_core_utils.prompt_templates.factory.ollama_pt("qwen3:4b", messages)

            # Verify the result is as expected
            self.assertIn("prompt", result)
            self.assertIn("images", result)
            self.assertIn("### User:\nHello", result["prompt"])
            self.assertIn("### Assistant:\nHi there", result["prompt"])
        finally:
            # Restore the original function to avoid affecting other tests
            litellm.litellm_core_utils.prompt_templates.factory.ollama_pt = original_ollama_pt

    def test_ollama_pt_patch_with_empty_messages(self):
        """Test that the patch handles empty message lists."""
        messages = []

        # Store the original function to restore it after the test
        original_ollama_pt = litellm.litellm_core_utils.prompt_templates.factory.ollama_pt

        try:
            # Apply the patch
            success = patch_litellm_ollama_pt()
            self.assertTrue(success, "Patch application failed")

            # Use the function from the module directly to ensure we're using the patched version
            result = litellm.litellm_core_utils.prompt_templates.factory.ollama_pt("qwen3:4b", messages)

            # Verify the result is as expected
            self.assertIn("prompt", result)
            self.assertIn("images", result)
            self.assertEqual("", result["prompt"])
            self.assertEqual([], result["images"])
        finally:
            # Restore the original function to avoid affecting other tests
            litellm.litellm_core_utils.prompt_templates.factory.ollama_pt = original_ollama_pt


if __name__ == "__main__":
    unittest.main()
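Both tests repeat the same save/patch/restore sequence around the module attribute; under pytest that boilerplate could live in a fixture instead. A sketch of that refactor (the fixture below is illustrative, not part of the diff):

# Hypothetical pytest fixture centralizing the save/patch/restore dance.
import pytest
import litellm.litellm_core_utils.prompt_templates.factory as factory
from crewai.patches.litellm_patch import patch_litellm_ollama_pt

@pytest.fixture
def patched_ollama_pt_fn():
    original = factory.ollama_pt      # save the unpatched function
    assert patch_litellm_ollama_pt()  # install the bounds-checked version
    yield factory.ollama_pt           # hand the patched function to the test
    factory.ollama_pt = original      # always restore on teardown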

@@ -1,83 +0,0 @@
"""Test HumanTool functionality."""

from unittest.mock import patch
import pytest

from crewai.tools import HumanTool


def test_human_tool_basic():
    """Test basic HumanTool creation and attributes."""
    tool = HumanTool()
    assert tool.name == "human"
    assert "ask user to enter input" in tool.description.lower()
    assert not tool.result_as_answer


@pytest.mark.vcr(filter_headers=["authorization"])
def test_human_tool_with_langgraph_interrupt():
    """Test HumanTool with LangGraph interrupt handling."""
    tool = HumanTool()

    with patch('langgraph.prebuilt.state_graphs.interrupt') as mock_interrupt:
        mock_interrupt.return_value = {"data": "test response"}
        result = tool._run("test query")
        assert result == "test response"
        mock_interrupt.assert_called_with({"query": "test query", "timeout": None})


def test_human_tool_timeout():
    """Test HumanTool timeout handling."""
    tool = HumanTool()
    timeout = 30.0

    with patch('langgraph.prebuilt.state_graphs.interrupt') as mock_interrupt:
        mock_interrupt.return_value = {"data": "test response"}
        result = tool._run("test query", timeout=timeout)
        assert result == "test response"
        mock_interrupt.assert_called_with({"query": "test query", "timeout": timeout})


def test_human_tool_invalid_input():
    """Test HumanTool input validation."""
    tool = HumanTool()

    with pytest.raises(ValueError, match="Query must be a non-empty string"):
        tool._run("")

    with pytest.raises(ValueError, match="Query must be a non-empty string"):
        tool._run(None)


@pytest.mark.asyncio
async def test_human_tool_async():
    """Test async HumanTool functionality."""
    tool = HumanTool()

    with patch('langgraph.prebuilt.state_graphs.interrupt') as mock_interrupt:
        mock_interrupt.return_value = {"data": "test response"}
        result = await tool._arun("test query")
        assert result == "test response"
        mock_interrupt.assert_called_with({"query": "test query", "timeout": None})


@pytest.mark.asyncio
async def test_human_tool_async_timeout():
    """Test async HumanTool timeout handling."""
    tool = HumanTool()
    timeout = 30.0

    with patch('langgraph.prebuilt.state_graphs.interrupt') as mock_interrupt:
        mock_interrupt.return_value = {"data": "test response"}
        result = await tool._arun("test query", timeout=timeout)
        assert result == "test response"
        mock_interrupt.assert_called_with({"query": "test query", "timeout": timeout})


def test_human_tool_without_langgraph():
    """Test HumanTool behavior when LangGraph is not installed."""
    tool = HumanTool()

    with patch.dict('sys.modules', {'langgraph': None}):
        with pytest.raises(ImportError) as exc_info:
            tool._run("test query")
        assert "LangGraph is required" in str(exc_info.value)
        assert "pip install langgraph" in str(exc_info.value)
@@ -1,13 +1,12 @@
import json
import random
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock

import pytest
from pydantic import BaseModel, Field

from crewai import Agent, Task
from crewai.tools import BaseTool
from crewai.tools.tool_calling import ToolCalling
from crewai.tools.tool_usage import ToolUsage


@@ -86,36 +85,6 @@ def test_random_number_tool_schema():
    )


def test_tool_usage_interrupt_handling():
    """Test that tool usage properly propagates LangGraph interrupts."""
    class InterruptingTool(BaseTool):
        name: str = "interrupt_test"
        description: str = "A tool that raises LangGraph interrupts"

        def _run(self, query: str) -> str:
            raise type('Interrupt', (Exception,), {})("test interrupt")

    tool = InterruptingTool()
    tool_usage = ToolUsage(
        tools_handler=MagicMock(),
        tools=[tool],
        original_tools=[tool],
        tools_description="Sample tool for testing",
        tools_names="interrupt_test",
        task=MagicMock(),
        function_calling_llm=MagicMock(),
        agent=MagicMock(),
        action=MagicMock(),
    )

    # Test that interrupt is propagated
    with pytest.raises(Exception) as exc_info:
        tool_usage.use(
            ToolCalling(tool_name="interrupt_test", arguments={"query": "test"}, log="test"),
            "test"
        )
    assert "test interrupt" in str(exc_info.value)


def test_tool_usage_render():
    tool = RandomNumberTool()