Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-09 08:08:32 +00:00
Fix IndexError in litellm's ollama_pt function when using Ollama/Qwen models with tools
This patch addresses issue #2744 by adding bounds checking before accessing messages[msg_i].get('tool_calls') in the ollama_pt function. The issue occurs when an assistant message is the last message in the list, causing msg_i to go out of bounds. The fix is implemented as a monkey patch in CrewAI to avoid waiting for an upstream fix in litellm.

Co-Authored-By: Joe Moura <joao@crewai.com>
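For illustration, a minimal standalone sketch of the failure mode and of the guard added below (a hypothetical snippet, not part of the commit): when the last message comes from the assistant, the merge loop advances msg_i past the end of the list, so the tool_calls lookup has to be bounds-checked.

# Hypothetical reproduction sketch: a conversation that ends with an assistant message.
messages = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi there"},  # assistant message is last
]

msg_i = len(messages)  # where the merge loop leaves msg_i after consuming the last assistant message

# Unpatched: messages[msg_i].get("tool_calls") raises IndexError here.
# Patched: the access is guarded first.
tool_calls = messages[msg_i].get("tool_calls") if msg_i < len(messages) else None
assert tool_calls is None  # out of bounds, so no tool calls are read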
src/crewai/patches/__init__.py (new file, 11 lines)
@@ -0,0 +1,11 @@
"""
Patches module for CrewAI.

This module contains patches for dependencies that need to be fixed
without waiting for upstream changes.
"""

from crewai.patches.litellm_patch import apply_patches

# Apply all patches when the module is imported
apply_patches()
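A minimal usage sketch (an assumption about how callers are expected to wire this up, since the commit only adds the package itself): importing crewai.patches before litellm builds an Ollama prompt applies the monkey patch as a side effect of the import above.

# Hypothetical caller-side sketch: the import runs apply_patches() via the module body.
import crewai.patches  # noqa: F401  (imported for its side effect)

from litellm.litellm_core_utils.prompt_templates.factory import ollama_pt
# Because the patch was applied before this import, the name now refers to the bounds-checked replacement.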
src/crewai/patches/litellm_patch.py (new file, 129 lines)
@@ -0,0 +1,129 @@
"""
Patch for litellm to fix IndexError in ollama_pt function.

This patch addresses issue #2744 in the crewAI repository, where an IndexError occurs
in litellm's Ollama prompt template function when CrewAI Agent with Tools uses Ollama/Qwen models.
"""

from typing import Any, Union


def apply_patches():
    """Apply all patches to fix known issues with dependencies."""
    patch_litellm_ollama_pt()


def patch_litellm_ollama_pt():
    """
    Patch the ollama_pt function in litellm to fix IndexError.

    The issue occurs when accessing messages[msg_i].get("tool_calls") without checking
    if msg_i is within bounds of the messages list. This happens after tool execution
    during the next LLM call.
    """
    try:
        # Import the module containing the function to patch
        import litellm.litellm_core_utils.prompt_templates.factory as factory

        # Define a patched version of the function
        def patched_ollama_pt(model: str, messages: list) -> Union[str, Any]:
            """
            Patched version of ollama_pt that adds bounds checking.

            This fixes the IndexError that occurs when the assistant message is the last
            message in the list and msg_i goes out of bounds.
            """
            user_message_types = {"user", "tool", "function"}
            msg_i = 0
            images = []
            prompt = ""
            while msg_i < len(messages):
                init_msg_i = msg_i
                user_content_str = ""
                ## MERGE CONSECUTIVE USER CONTENT ##
                while msg_i < len(messages) and messages[msg_i]["role"] in user_message_types:
                    msg_content = messages[msg_i].get("content")
                    if msg_content:
                        if isinstance(msg_content, list):
                            for m in msg_content:
                                if m.get("type", "") == "image_url":
                                    if isinstance(m["image_url"], str):
                                        images.append(m["image_url"])
                                    elif isinstance(m["image_url"], dict):
                                        images.append(m["image_url"]["url"])
                                elif m.get("type", "") == "text":
                                    user_content_str += m["text"]
                        else:
                            # Tool message content will always be a string
                            user_content_str += msg_content

                    msg_i += 1

                if user_content_str:
                    prompt += f"### User:\n{user_content_str}\n\n"

                system_content_str, msg_i = factory._handle_ollama_system_message(
                    messages, prompt, msg_i
                )
                if system_content_str:
                    prompt += f"### System:\n{system_content_str}\n\n"

                assistant_content_str = ""
                ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
                while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
                    assistant_content_str += factory.convert_content_list_to_str(messages[msg_i])
                    msg_i += 1

                    # Add bounds check before accessing messages[msg_i]
                    # This is the key fix for the IndexError
                    if msg_i < len(messages):
                        tool_calls = messages[msg_i].get("tool_calls")
                        ollama_tool_calls = []
                        if tool_calls:
                            for call in tool_calls:
                                call_id: str = call["id"]
                                function_name: str = call["function"]["name"]
                                arguments = factory.json.loads(call["function"]["arguments"])

                                ollama_tool_calls.append(
                                    {
                                        "id": call_id,
                                        "type": "function",
                                        "function": {
                                            "name": function_name,
                                            "arguments": arguments,
                                        },
                                    }
                                )

                        if ollama_tool_calls:
                            assistant_content_str += (
                                f"Tool Calls: {factory.json.dumps(ollama_tool_calls, indent=2)}"
                            )

                            msg_i += 1

                if assistant_content_str:
                    prompt += f"### Assistant:\n{assistant_content_str}\n\n"

                if msg_i == init_msg_i:  # prevent infinite loops
                    raise factory.litellm.BadRequestError(
                        message=factory.BAD_MESSAGE_ERROR_STR + f"passed in {messages[msg_i]}",
                        model=model,
                        llm_provider="ollama",
                    )

            response_dict = {
                "prompt": prompt,
                "images": images,
            }

            return response_dict

        # Replace the original function with our patched version
        factory.ollama_pt = patched_ollama_pt

        return True
    except Exception as e:
        print(f"Failed to apply litellm ollama_pt patch: {e}")
        return False
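Since a failed patch is only reported with a print and a False return value, here is a hedged caller-side sketch (hypothetical, not part of the commit) of treating that outcome as a hard error instead:

# Hypothetical strict wrapper around the patch entry point.
from crewai.patches.litellm_patch import patch_litellm_ollama_pt

if not patch_litellm_ollama_pt():
    raise RuntimeError("litellm ollama_pt patch could not be applied")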
tests/patches/test_litellm_patch.py (new file, 45 lines)
@@ -0,0 +1,45 @@
"""
Test for the litellm patch that fixes the IndexError in ollama_pt function.
"""

import unittest
from unittest.mock import patch, MagicMock
import sys

import litellm
from litellm.litellm_core_utils.prompt_templates.factory import ollama_pt

from crewai.patches.litellm_patch import patch_litellm_ollama_pt


class TestLitellmPatch(unittest.TestCase):
    def test_ollama_pt_patch_fixes_index_error(self):
        """Test that the patch fixes the IndexError in ollama_pt."""
        # Create a message list where the assistant message is the last one
        messages = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there"},
        ]

        # Store the original function to restore it after the test
        original_ollama_pt = litellm.litellm_core_utils.prompt_templates.factory.ollama_pt

        try:
            # Apply the patch
            patch_litellm_ollama_pt()

            # Call through the factory module so the patched function (not the module-level import) is used; it should not raise an IndexError
            result = litellm.litellm_core_utils.prompt_templates.factory.ollama_pt("qwen3:4b", messages)

            # Verify the result is as expected
            self.assertIn("prompt", result)
            self.assertIn("images", result)
            self.assertIn("### User:\nHello", result["prompt"])
            self.assertIn("### Assistant:\nHi there", result["prompt"])
        finally:
            # Restore the original function to avoid affecting other tests
            litellm.litellm_core_utils.prompt_templates.factory.ollama_pt = original_ollama_pt


if __name__ == "__main__":
    unittest.main()