Fix lint issues and improve patch implementation

Co-Authored-By: Joe Moura <joao@crewai.com>
Author: Devin AI
Date: 2025-05-03 02:25:10 +00:00
parent 3361fab293
commit 082cbd2c1c
3 changed files with 96 additions and 18 deletions

View File

@@ -1,11 +1,11 @@
 """
 Patches module for CrewAI.

-This module contains patches for dependencies that need to be fixed
-without waiting for upstream changes.
+This module contains patches for external dependencies to fix known issues.
+
+Version: 1.0.0
 """

 from crewai.patches.litellm_patch import apply_patches
+from crewai.patches.litellm_patch import apply_patches, patch_litellm_ollama_pt

 # Apply all patches when the module is imported
 apply_patches()
+
+__all__ = ["apply_patches", "patch_litellm_ollama_pt"]
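Since apply_patches() runs at import time, consumers activate the fix simply by importing the package before any litellm call goes out. A minimal usage sketch (the logging setup is illustrative, not part of this commit):

import logging

logging.basicConfig(level=logging.INFO)

# Importing the package runs apply_patches() as a side effect,
# replacing litellm's ollama_pt with the bounds-checked version.
import crewai.patches  # noqa: F401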

View File

@@ -3,14 +3,31 @@ Patch for litellm to fix IndexError in ollama_pt function.
 This patch addresses issue #2744 in the crewAI repository, where an IndexError occurs
 in litellm's Ollama prompt template function when CrewAI Agent with Tools uses Ollama/Qwen models.
+
+Version: 1.0.0
 """
-from typing import Any, Union
+import json
+import logging
+from typing import Any, Dict, List, Union
+
+# Set up logging
+logger = logging.getLogger(__name__)
+
+# Patch version
+PATCH_VERSION = "1.0.0"
+
+
+class PatchApplicationError(Exception):
+    """Exception raised when a patch fails to apply."""
+    pass
+
+
 def apply_patches():
     """Apply all patches to fix known issues with dependencies."""
-    patch_litellm_ollama_pt()
+    success = patch_litellm_ollama_pt()
+    logger.info(f"LiteLLM ollama_pt patch applied: {success}")
+    return success


 def patch_litellm_ollama_pt():
@@ -20,23 +37,41 @@ def patch_litellm_ollama_pt():
     The issue occurs when accessing messages[msg_i].get("tool_calls") without checking
     if msg_i is within bounds of the messages list. This happens after tool execution
     during the next LLM call.
+
+    Returns:
+        bool: True if the patch was applied successfully, False otherwise.
+
+    Raises:
+        PatchApplicationError: If there's an error during patch application.
     """
     try:
         # Import the module containing the function to patch
         import litellm.litellm_core_utils.prompt_templates.factory as factory

         # Define a patched version of the function
-        def patched_ollama_pt(model: str, messages: list) -> Union[str, Any]:
+        def patched_ollama_pt(model: str, messages: List[Dict]) -> Dict[str, Any]:
             """
             Patched version of ollama_pt that adds bounds checking.

             This fixes the IndexError that occurs when the assistant message is the last
             message in the list and msg_i goes out of bounds.
+
+            Args:
+                model: The model name.
+                messages: The list of messages to process.
+
+            Returns:
+                Dict containing the prompt and images.
             """
             user_message_types = {"user", "tool", "function"}
             msg_i = 0
             images = []
             prompt = ""
+
+            # Handle empty messages list
+            if not messages:
+                return {"prompt": prompt, "images": images}
+
             while msg_i < len(messages):
                 init_msg_i = msg_i
                 user_content_str = ""
@@ -81,9 +116,9 @@ def patch_litellm_ollama_pt():
                 ollama_tool_calls = []
                 if tool_calls:
                     for call in tool_calls:
-                        call_id: str = call["id"]
-                        function_name: str = call["function"]["name"]
-                        arguments = factory.json.loads(call["function"]["arguments"])
+                        call_id = call["id"]
+                        function_name = call["function"]["name"]
+                        arguments = json.loads(call["function"]["arguments"])
                         ollama_tool_calls.append(
                             {
@@ -98,7 +133,7 @@ def patch_litellm_ollama_pt():
                 if ollama_tool_calls:
                     assistant_content_str += (
-                        f"Tool Calls: {factory.json.dumps(ollama_tool_calls, indent=2)}"
+                        f"Tool Calls: {json.dumps(ollama_tool_calls, indent=2)}"
                     )

                 msg_i += 1
@@ -123,7 +158,24 @@ def patch_litellm_ollama_pt():
         # Replace the original function with our patched version
         factory.ollama_pt = patched_ollama_pt
+        logger.info(f"Successfully applied litellm ollama_pt patch version {PATCH_VERSION}")
         return True
     except Exception as e:
-        print(f"Failed to apply litellm ollama_pt patch: {e}")
+        error_msg = f"Failed to apply litellm ollama_pt patch: {e}"
+        logger.error(error_msg)
         return False
+
+
+# For backwards compatibility
+def patch_litellm():
+    """
+    Legacy function for backwards compatibility.
+
+    Returns:
+        bool: True if the patch was applied successfully, False otherwise.
+    """
+    try:
+        return patch_litellm_ollama_pt()
+    except Exception as e:
+        logger.error(f"Failed to apply legacy litellm patch: {e}")
+        return False
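The substance of the patch is the bounds checking around msg_i: litellm's loop advances msg_i inside nested branches and could index past the end of the list when the final message came from the assistant. A standalone sketch of the failure mode and the guard (a simplified stand-in for the real loop, not litellm's actual ollama_pt code):

def render_prompt(messages):
    """Simplified stand-in for ollama_pt's message loop (illustrative only)."""
    prompt = ""
    msg_i = 0
    # Guard from the patch: an empty list short-circuits immediately.
    if not messages:
        return {"prompt": prompt, "images": []}
    while msg_i < len(messages):
        msg = messages[msg_i]
        # Bounds-checked lookahead: only peek at the next message if it exists.
        # The unpatched code indexed messages[msg_i] after advancing msg_i,
        # which raised IndexError when the assistant message was last.
        if msg.get("role") == "assistant" and msg_i + 1 < len(messages):
            _next = messages[msg_i + 1]
        prompt += f"{msg.get('role')}: {msg.get('content', '')}\n"
        msg_i += 1
    return {"prompt": prompt, "images": []}

# An assistant message in last position no longer raises IndexError:
print(render_prompt([{"role": "assistant", "content": "done"}]))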

View File

@@ -2,11 +2,12 @@
 Test for the litellm patch that fixes the IndexError in ollama_pt function.
 """
-import unittest
-from unittest.mock import patch, MagicMock
+import sys
+import unittest
+from unittest.mock import MagicMock, patch

 import litellm
 import pytest
 from litellm.litellm_core_utils.prompt_templates.factory import ollama_pt

 from crewai.patches.litellm_patch import patch_litellm_ollama_pt
@@ -26,10 +27,11 @@ class TestLitellmPatch(unittest.TestCase):
         try:
             # Apply the patch
-            patch_litellm_ollama_pt()
+            success = patch_litellm_ollama_pt()
+            self.assertTrue(success, "Patch application failed")

             # The patched function should not raise an IndexError
-            result = ollama_pt("qwen3:4b", messages)
+            # Use the function from the module directly to ensure we're using the patched version
+            result = litellm.litellm_core_utils.prompt_templates.factory.ollama_pt("qwen3:4b", messages)

             # Verify the result is as expected
             self.assertIn("prompt", result)
@@ -39,6 +41,30 @@ class TestLitellmPatch(unittest.TestCase):
         finally:
             # Restore the original function to avoid affecting other tests
             litellm.litellm_core_utils.prompt_templates.factory.ollama_pt = original_ollama_pt
+
+    def test_ollama_pt_patch_with_empty_messages(self):
+        """Test that the patch handles empty message lists."""
+        messages = []
+
+        # Store the original function to restore it after the test
+        original_ollama_pt = litellm.litellm_core_utils.prompt_templates.factory.ollama_pt
+
+        try:
+            # Apply the patch
+            success = patch_litellm_ollama_pt()
+            self.assertTrue(success, "Patch application failed")
+
+            # Use the function from the module directly to ensure we're using the patched version
+            result = litellm.litellm_core_utils.prompt_templates.factory.ollama_pt("qwen3:4b", messages)
+
+            # Verify the result is as expected
+            self.assertIn("prompt", result)
+            self.assertIn("images", result)
+            self.assertEqual("", result["prompt"])
+            self.assertEqual([], result["images"])
+        finally:
+            # Restore the original function to avoid affecting other tests
+            litellm.litellm_core_utils.prompt_templates.factory.ollama_pt = original_ollama_pt


 if __name__ == "__main__":
     unittest.main()
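Both tests follow the same save-patch-restore discipline around the module attribute. An equivalent structure that centralizes the cleanup in unittest fixtures might look like this (a sketch reusing the import style from the diff above; the test class and method names are hypothetical):

import unittest

import litellm.litellm_core_utils.prompt_templates.factory as factory

from crewai.patches.litellm_patch import patch_litellm_ollama_pt


class TestLitellmPatchFixture(unittest.TestCase):
    def setUp(self):
        # Save the unpatched function once per test.
        self._original_ollama_pt = factory.ollama_pt

    def tearDown(self):
        # Restore it even if the test body fails, mirroring the finally blocks above.
        factory.ollama_pt = self._original_ollama_pt

    def test_patch_applies(self):
        self.assertTrue(patch_litellm_ollama_pt())


if __name__ == "__main__":
    unittest.main()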