Compare commits

..

5 Commits

Author SHA1 Message Date
Devin AI
c63010daaa Fix import sorting in litellm_patch.py
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-05-03 02:35:40 +00:00
Devin AI
d0191df996 Fix type annotations for all functions in litellm_patch.py
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-05-03 02:31:34 +00:00
Devin AI
e27bcfb381 Fix type annotation for images variable
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-05-03 02:30:01 +00:00
Devin AI
082cbd2c1c Fix lint issues and improve patch implementation
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-05-03 02:25:10 +00:00
Devin AI
3361fab293 Fix IndexError in litellm's ollama_pt function when using Ollama/Qwen models with tools
This patch addresses issue #2744 by adding a bounds check before accessing
messages[msg_i].get('tool_calls') in the ollama_pt function. The error occurs
when an assistant message is the last message in the list: the assistant-merging
loop advances msg_i past the end of the list before that access.

The fix is implemented as a monkey patch in CrewAI to avoid waiting for
an upstream fix in litellm.

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-05-03 02:07:03 +00:00
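
For context, the failure mode and the guard are easy to see in isolation. A minimal, self-contained sketch of the pattern (simplified; not litellm's actual code):

messages = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi there"},  # assistant turn is the last message
]

msg_i = 0
# Skip the leading user/tool messages, then merge consecutive assistant messages,
# mirroring the loop structure in ollama_pt.
while msg_i < len(messages) and messages[msg_i]["role"] in {"user", "tool", "function"}:
    msg_i += 1
while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
    msg_i += 1

# At this point msg_i == len(messages), so an unguarded
# messages[msg_i].get("tool_calls") raises IndexError.
# The patch adds a bounds check instead:
tool_calls = messages[msg_i].get("tool_calls") if msg_i < len(messages) else None
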
8 changed files with 269 additions and 218 deletions

View File

@@ -0,0 +1,11 @@
"""
Patches module for CrewAI.
This module contains patches for external dependencies to fix known issues.
Version: 1.0.0
"""
from crewai.patches.litellm_patch import apply_patches, patch_litellm_ollama_pt
__all__ = ["apply_patches", "patch_litellm_ollama_pt"]
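
A possible usage sketch for this module. The explicit call below is an assumption; this diff does not show whether CrewAI also applies the patches automatically at import time.

from crewai.patches import apply_patches

# Apply the litellm patch before any LLM calls are made.
# apply_patches() logs the outcome and returns a bool rather than raising,
# so callers decide how to handle a failed patch.
if not apply_patches():
    raise RuntimeError("litellm ollama_pt patch could not be applied")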

View File

@@ -0,0 +1,186 @@
"""
Patch for litellm to fix IndexError in ollama_pt function.
This patch addresses issue #2744 in the crewAI repository, where an IndexError occurs
in litellm's Ollama prompt template function when a CrewAI agent with tools uses Ollama/Qwen models.
Version: 1.0.0
"""
import json
import logging
from typing import Any, Dict, List, Optional, Tuple, Union
# Set up logging
logger = logging.getLogger(__name__)
# Patch version
PATCH_VERSION = "1.0.0"
class PatchApplicationError(Exception):
"""Exception raised when a patch fails to apply."""
pass
def apply_patches() -> bool:
"""
Apply all patches to fix known issues with dependencies.
Returns:
bool: True if all patches were applied successfully, False otherwise.
"""
success = patch_litellm_ollama_pt()
logger.info(f"LiteLLM ollama_pt patch applied: {success}")
return success
def patch_litellm_ollama_pt() -> bool:
"""
Patch the ollama_pt function in litellm to fix IndexError.
The issue occurs when accessing messages[msg_i].get("tool_calls") without checking
if msg_i is within bounds of the messages list. This happens after tool execution
during the next LLM call.
Returns:
bool: True if the patch was applied successfully, False otherwise.
Raises:
PatchApplicationError: If there's an error during patch application.
"""
try:
# Import the module containing the function to patch
import litellm.litellm_core_utils.prompt_templates.factory as factory
# Define a patched version of the function
def patched_ollama_pt(model: str, messages: List[Dict]) -> Dict[str, Any]:
"""
Patched version of ollama_pt that adds bounds checking.
This fixes the IndexError that occurs when the assistant message is the last
message in the list and msg_i goes out of bounds.
Args:
model: The model name.
messages: The list of messages to process.
Returns:
Dict containing the prompt and images.
"""
user_message_types = {"user", "tool", "function"}
msg_i = 0
images: List[str] = []
prompt = ""
# Handle empty messages list
if not messages:
return {"prompt": prompt, "images": images}
while msg_i < len(messages):
init_msg_i = msg_i
user_content_str = ""
## MERGE CONSECUTIVE USER CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] in user_message_types:
msg_content = messages[msg_i].get("content")
if msg_content:
if isinstance(msg_content, list):
for m in msg_content:
if m.get("type", "") == "image_url":
if isinstance(m["image_url"], str):
images.append(m["image_url"])
elif isinstance(m["image_url"], dict):
images.append(m["image_url"]["url"])
elif m.get("type", "") == "text":
user_content_str += m["text"]
else:
# Tool message content will always be a string
user_content_str += msg_content
msg_i += 1
if user_content_str:
prompt += f"### User:\n{user_content_str}\n\n"
system_content_str, msg_i = factory._handle_ollama_system_message(
messages, prompt, msg_i
)
if system_content_str:
prompt += f"### System:\n{system_content_str}\n\n"
assistant_content_str = ""
## MERGE CONSECUTIVE ASSISTANT CONTENT ##
while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
assistant_content_str += factory.convert_content_list_to_str(messages[msg_i])
msg_i += 1
# Add bounds check before accessing messages[msg_i]
# This is the key fix for the IndexError
if msg_i < len(messages):
tool_calls = messages[msg_i].get("tool_calls")
ollama_tool_calls = []
if tool_calls:
for call in tool_calls:
call_id = call["id"]
function_name = call["function"]["name"]
arguments = json.loads(call["function"]["arguments"])
ollama_tool_calls.append(
{
"id": call_id,
"type": "function",
"function": {
"name": function_name,
"arguments": arguments,
},
}
)
if ollama_tool_calls:
assistant_content_str += (
f"Tool Calls: {json.dumps(ollama_tool_calls, indent=2)}"
)
msg_i += 1
if assistant_content_str:
prompt += f"### Assistant:\n{assistant_content_str}\n\n"
if msg_i == init_msg_i: # prevent infinite loops
raise factory.litellm.BadRequestError(
message=factory.BAD_MESSAGE_ERROR_STR + f"passed in {messages[msg_i]}",
model=model,
llm_provider="ollama",
)
response_dict = {
"prompt": prompt,
"images": images,
}
return response_dict
# Replace the original function with our patched version
factory.ollama_pt = patched_ollama_pt
logger.info(f"Successfully applied litellm ollama_pt patch version {PATCH_VERSION}")
return True
except Exception as e:
error_msg = f"Failed to apply litellm ollama_pt patch: {e}"
logger.error(error_msg)
return False
# For backwards compatibility
def patch_litellm() -> bool:
"""
Legacy function for backwards compatibility.
Returns:
bool: True if the patch was applied successfully, False otherwise.
"""
try:
return patch_litellm_ollama_pt()
except Exception as e:
logger.error(f"Failed to apply legacy litellm patch: {e}")
return False
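
A repro sketch of the scenario described in the docstrings above: a conversation that ends on an assistant turn after a tool call. The model name and tool payload are illustrative; on affected litellm versions the unpatched ollama_pt is expected to raise IndexError here, while the patched version returns a prompt/images dict.

import json

import litellm.litellm_core_utils.prompt_templates.factory as factory

from crewai.patches.litellm_patch import patch_litellm_ollama_pt

messages = [
    {"role": "user", "content": "What is 2 + 2?"},
    {
        "role": "assistant",
        "content": "",
        "tool_calls": [
            {
                "id": "call_1",
                "type": "function",
                "function": {
                    "name": "calculator",
                    "arguments": json.dumps({"a": 2, "b": 2}),
                },
            }
        ],
    },
    {"role": "tool", "content": "4"},
    {"role": "assistant", "content": "The answer is 4."},  # conversation ends on an assistant turn
]

assert patch_litellm_ollama_pt(), "patch was not applied"
result = factory.ollama_pt(model="qwen3:4b", messages=messages)
print(result["prompt"])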

View File

@@ -1,2 +1 @@
from .base_tool import BaseTool, tool
from .human_tool import HumanTool

View File

@@ -1,98 +0,0 @@
"""Tool for handling human input using LangGraph's interrupt mechanism."""
import logging
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
from crewai.tools import BaseTool
class HumanToolSchema(BaseModel):
"""Schema for HumanTool input validation."""
query: str = Field(
...,
description="The question to ask the user. Must be a non-empty string."
)
timeout: Optional[float] = Field(
default=None,
description="Optional timeout in seconds for waiting for user response"
)
class HumanTool(BaseTool):
"""Tool for getting human input using LangGraph's interrupt mechanism.
This tool allows agents to request input from users through LangGraph's
interrupt mechanism. It supports timeout configuration and input validation.
"""
name: str = "human"
description: str = "Useful to ask user to enter input."
args_schema: type[BaseModel] = HumanToolSchema
result_as_answer: bool = False # Don't use the response as final answer
def _run(self, query: str, timeout: Optional[float] = None) -> str:
"""Execute the human input tool.
Args:
query: The question to ask the user
timeout: Optional timeout in seconds
Returns:
The user's response
Raises:
ImportError: If LangGraph is not installed
TimeoutError: If response times out
ValueError: If query is invalid
"""
if not query or not isinstance(query, str):
raise ValueError("Query must be a non-empty string")
try:
from langgraph.prebuilt.state_graphs import interrupt
logging.info(f"Requesting human input: {query}")
human_response = interrupt({"query": query, "timeout": timeout})
return human_response["data"]
except ImportError:
logging.error("LangGraph not installed")
raise ImportError(
"LangGraph is required for HumanTool. "
"Install with `pip install langgraph`"
)
except Exception as e:
logging.error(f"Error during human input: {str(e)}")
raise
async def _arun(self, query: str, timeout: Optional[float] = None) -> str:
"""Execute the human input tool asynchronously.
Args:
query: The question to ask the user
timeout: Optional timeout in seconds
Returns:
The user's response
Raises:
ImportError: If LangGraph is not installed
TimeoutError: If response times out
ValueError: If query is invalid
"""
if not query or not isinstance(query, str):
raise ValueError("Query must be a non-empty string")
try:
from langgraph.prebuilt.state_graphs import interrupt
logging.info(f"Requesting async human input: {query}")
human_response = interrupt({"query": query, "timeout": timeout})
return human_response["data"]
except ImportError:
logging.error("LangGraph not installed")
raise ImportError(
"LangGraph is required for HumanTool. "
"Install with `pip install langgraph`"
)
except Exception as e:
logging.error(f"Error during async human input: {str(e)}")
raise

View File

@@ -182,10 +182,6 @@ class ToolUsage:
else:
result = tool.invoke(input={})
except Exception as e:
# Check if this is a LangGraph interrupt that should be propagated
if hasattr(e, '__class__') and e.__class__.__name__ == 'Interrupt':
raise e # Propagate interrupt up
self.on_tool_error(tool=tool, tool_calling=calling, e=e)
self._run_attempts += 1
if self._run_attempts > self._max_parsing_attempts:

View File

@@ -0,0 +1,71 @@
"""
Test for the litellm patch that fixes the IndexError in ollama_pt function.
"""
import sys
import unittest
from unittest.mock import MagicMock, patch
import litellm
import pytest
from litellm.litellm_core_utils.prompt_templates.factory import ollama_pt
from crewai.patches.litellm_patch import patch_litellm_ollama_pt
class TestLitellmPatch(unittest.TestCase):
def test_ollama_pt_patch_fixes_index_error(self):
"""Test that the patch fixes the IndexError in ollama_pt."""
# Create a message list where the assistant message is the last one
messages = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there"},
]
# Store the original function to restore it after the test
original_ollama_pt = litellm.litellm_core_utils.prompt_templates.factory.ollama_pt
try:
# Apply the patch
success = patch_litellm_ollama_pt()
self.assertTrue(success, "Patch application failed")
# Use the function from the module directly to ensure we're using the patched version
result = litellm.litellm_core_utils.prompt_templates.factory.ollama_pt("qwen3:4b", messages)
# Verify the result is as expected
self.assertIn("prompt", result)
self.assertIn("images", result)
self.assertIn("### User:\nHello", result["prompt"])
self.assertIn("### Assistant:\nHi there", result["prompt"])
finally:
# Restore the original function to avoid affecting other tests
litellm.litellm_core_utils.prompt_templates.factory.ollama_pt = original_ollama_pt
def test_ollama_pt_patch_with_empty_messages(self):
"""Test that the patch handles empty message lists."""
messages = []
# Store the original function to restore it after the test
original_ollama_pt = litellm.litellm_core_utils.prompt_templates.factory.ollama_pt
try:
# Apply the patch
success = patch_litellm_ollama_pt()
self.assertTrue(success, "Patch application failed")
# Use the function from the module directly to ensure we're using the patched version
result = litellm.litellm_core_utils.prompt_templates.factory.ollama_pt("qwen3:4b", messages)
# Verify the result is as expected
self.assertIn("prompt", result)
self.assertIn("images", result)
self.assertEqual("", result["prompt"])
self.assertEqual([], result["images"])
finally:
# Restore the original function to avoid affecting other tests
litellm.litellm_core_utils.prompt_templates.factory.ollama_pt = original_ollama_pt
if __name__ == "__main__":
unittest.main()

View File

@@ -1,83 +0,0 @@
"""Test HumanTool functionality."""
from unittest.mock import patch
import pytest
from crewai.tools import HumanTool
def test_human_tool_basic():
"""Test basic HumanTool creation and attributes."""
tool = HumanTool()
assert tool.name == "human"
assert "ask user to enter input" in tool.description.lower()
assert not tool.result_as_answer
@pytest.mark.vcr(filter_headers=["authorization"])
def test_human_tool_with_langgraph_interrupt():
"""Test HumanTool with LangGraph interrupt handling."""
tool = HumanTool()
with patch('langgraph.prebuilt.state_graphs.interrupt') as mock_interrupt:
mock_interrupt.return_value = {"data": "test response"}
result = tool._run("test query")
assert result == "test response"
mock_interrupt.assert_called_with({"query": "test query", "timeout": None})
def test_human_tool_timeout():
"""Test HumanTool timeout handling."""
tool = HumanTool()
timeout = 30.0
with patch('langgraph.prebuilt.state_graphs.interrupt') as mock_interrupt:
mock_interrupt.return_value = {"data": "test response"}
result = tool._run("test query", timeout=timeout)
assert result == "test response"
mock_interrupt.assert_called_with({"query": "test query", "timeout": timeout})
def test_human_tool_invalid_input():
"""Test HumanTool input validation."""
tool = HumanTool()
with pytest.raises(ValueError, match="Query must be a non-empty string"):
tool._run("")
with pytest.raises(ValueError, match="Query must be a non-empty string"):
tool._run(None)
@pytest.mark.asyncio
async def test_human_tool_async():
"""Test async HumanTool functionality."""
tool = HumanTool()
with patch('langgraph.prebuilt.state_graphs.interrupt') as mock_interrupt:
mock_interrupt.return_value = {"data": "test response"}
result = await tool._arun("test query")
assert result == "test response"
mock_interrupt.assert_called_with({"query": "test query", "timeout": None})
@pytest.mark.asyncio
async def test_human_tool_async_timeout():
"""Test async HumanTool timeout handling."""
tool = HumanTool()
timeout = 30.0
with patch('langgraph.prebuilt.state_graphs.interrupt') as mock_interrupt:
mock_interrupt.return_value = {"data": "test response"}
result = await tool._arun("test query", timeout=timeout)
assert result == "test response"
mock_interrupt.assert_called_with({"query": "test query", "timeout": timeout})
def test_human_tool_without_langgraph():
"""Test HumanTool behavior when LangGraph is not installed."""
tool = HumanTool()
with patch.dict('sys.modules', {'langgraph': None}):
with pytest.raises(ImportError) as exc_info:
tool._run("test query")
assert "LangGraph is required" in str(exc_info.value)
assert "pip install langgraph" in str(exc_info.value)

View File

@@ -1,13 +1,12 @@
import json
import random
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock
import pytest
from pydantic import BaseModel, Field
from crewai import Agent, Task
from crewai.tools import BaseTool
from crewai.tools.tool_calling import ToolCalling
from crewai.tools.tool_usage import ToolUsage
@@ -86,36 +85,6 @@ def test_random_number_tool_schema():
)
def test_tool_usage_interrupt_handling():
"""Test that tool usage properly propagates LangGraph interrupts."""
class InterruptingTool(BaseTool):
name: str = "interrupt_test"
description: str = "A tool that raises LangGraph interrupts"
def _run(self, query: str) -> str:
raise type('Interrupt', (Exception,), {})("test interrupt")
tool = InterruptingTool()
tool_usage = ToolUsage(
tools_handler=MagicMock(),
tools=[tool],
original_tools=[tool],
tools_description="Sample tool for testing",
tools_names="interrupt_test",
task=MagicMock(),
function_calling_llm=MagicMock(),
agent=MagicMock(),
action=MagicMock(),
)
# Test that interrupt is propagated
with pytest.raises(Exception) as exc_info:
tool_usage.use(
ToolCalling(tool_name="interrupt_test", arguments={"query": "test"}, log="test"),
"test"
)
assert "test interrupt" in str(exc_info.value)
def test_tool_usage_render():
tool = RandomNumberTool()