mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 09:08:31 +00:00
Fix lint issues and improve patch implementation
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -1,11 +1,11 @@
|
||||
"""
|
||||
Patches module for CrewAI.
|
||||
|
||||
This module contains patches for dependencies that need to be fixed
|
||||
without waiting for upstream changes.
|
||||
This module contains patches for external dependencies to fix known issues.
|
||||
|
||||
Version: 1.0.0
|
||||
"""
|
||||
|
||||
from crewai.patches.litellm_patch import apply_patches
|
||||
from crewai.patches.litellm_patch import apply_patches, patch_litellm_ollama_pt
|
||||
|
||||
# Apply all patches when the module is imported
|
||||
apply_patches()
|
||||
__all__ = ["apply_patches", "patch_litellm_ollama_pt"]
|
||||
|
||||
@@ -3,14 +3,31 @@ Patch for litellm to fix IndexError in ollama_pt function.
|
||||
|
||||
This patch addresses issue #2744 in the crewAI repository, where an IndexError occurs
|
||||
in litellm's Ollama prompt template function when CrewAI Agent with Tools uses Ollama/Qwen models.
|
||||
|
||||
Version: 1.0.0
|
||||
"""
|
||||
|
||||
from typing import Any, Union
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
# Set up logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Patch version
|
||||
PATCH_VERSION = "1.0.0"
|
||||
|
||||
|
||||
class PatchApplicationError(Exception):
|
||||
"""Exception raised when a patch fails to apply."""
|
||||
pass
|
||||
|
||||
|
||||
def apply_patches():
|
||||
"""Apply all patches to fix known issues with dependencies."""
|
||||
patch_litellm_ollama_pt()
|
||||
success = patch_litellm_ollama_pt()
|
||||
logger.info(f"LiteLLM ollama_pt patch applied: {success}")
|
||||
return success
|
||||
|
||||
|
||||
def patch_litellm_ollama_pt():
|
||||
@@ -20,23 +37,41 @@ def patch_litellm_ollama_pt():
|
||||
The issue occurs when accessing messages[msg_i].get("tool_calls") without checking
|
||||
if msg_i is within bounds of the messages list. This happens after tool execution
|
||||
during the next LLM call.
|
||||
|
||||
Returns:
|
||||
bool: True if the patch was applied successfully, False otherwise.
|
||||
|
||||
Raises:
|
||||
PatchApplicationError: If there's an error during patch application.
|
||||
"""
|
||||
try:
|
||||
# Import the module containing the function to patch
|
||||
import litellm.litellm_core_utils.prompt_templates.factory as factory
|
||||
|
||||
# Define a patched version of the function
|
||||
def patched_ollama_pt(model: str, messages: list) -> Union[str, Any]:
|
||||
def patched_ollama_pt(model: str, messages: List[Dict]) -> Dict[str, Any]:
|
||||
"""
|
||||
Patched version of ollama_pt that adds bounds checking.
|
||||
|
||||
This fixes the IndexError that occurs when the assistant message is the last
|
||||
message in the list and msg_i goes out of bounds.
|
||||
|
||||
Args:
|
||||
model: The model name.
|
||||
messages: The list of messages to process.
|
||||
|
||||
Returns:
|
||||
Dict containing the prompt and images.
|
||||
"""
|
||||
user_message_types = {"user", "tool", "function"}
|
||||
msg_i = 0
|
||||
images = []
|
||||
prompt = ""
|
||||
|
||||
# Handle empty messages list
|
||||
if not messages:
|
||||
return {"prompt": prompt, "images": images}
|
||||
|
||||
while msg_i < len(messages):
|
||||
init_msg_i = msg_i
|
||||
user_content_str = ""
|
||||
@@ -81,9 +116,9 @@ def patch_litellm_ollama_pt():
|
||||
ollama_tool_calls = []
|
||||
if tool_calls:
|
||||
for call in tool_calls:
|
||||
call_id: str = call["id"]
|
||||
function_name: str = call["function"]["name"]
|
||||
arguments = factory.json.loads(call["function"]["arguments"])
|
||||
call_id = call["id"]
|
||||
function_name = call["function"]["name"]
|
||||
arguments = json.loads(call["function"]["arguments"])
|
||||
|
||||
ollama_tool_calls.append(
|
||||
{
|
||||
@@ -98,7 +133,7 @@ def patch_litellm_ollama_pt():
|
||||
|
||||
if ollama_tool_calls:
|
||||
assistant_content_str += (
|
||||
f"Tool Calls: {factory.json.dumps(ollama_tool_calls, indent=2)}"
|
||||
f"Tool Calls: {json.dumps(ollama_tool_calls, indent=2)}"
|
||||
)
|
||||
|
||||
msg_i += 1
|
||||
@@ -123,7 +158,24 @@ def patch_litellm_ollama_pt():
|
||||
# Replace the original function with our patched version
|
||||
factory.ollama_pt = patched_ollama_pt
|
||||
|
||||
logger.info(f"Successfully applied litellm ollama_pt patch version {PATCH_VERSION}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Failed to apply litellm ollama_pt patch: {e}")
|
||||
error_msg = f"Failed to apply litellm ollama_pt patch: {e}"
|
||||
logger.error(error_msg)
|
||||
return False
|
||||
|
||||
|
||||
# For backwards compatibility
|
||||
def patch_litellm():
|
||||
"""
|
||||
Legacy function for backwards compatibility.
|
||||
|
||||
Returns:
|
||||
bool: True if the patch was applied successfully, False otherwise.
|
||||
"""
|
||||
try:
|
||||
return patch_litellm_ollama_pt()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to apply legacy litellm patch: {e}")
|
||||
return False
|
||||
|
||||
@@ -2,11 +2,12 @@
|
||||
Test for the litellm patch that fixes the IndexError in ollama_pt function.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import sys
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import litellm
|
||||
import pytest
|
||||
from litellm.litellm_core_utils.prompt_templates.factory import ollama_pt
|
||||
|
||||
from crewai.patches.litellm_patch import patch_litellm_ollama_pt
|
||||
@@ -26,10 +27,11 @@ class TestLitellmPatch(unittest.TestCase):
|
||||
|
||||
try:
|
||||
# Apply the patch
|
||||
patch_litellm_ollama_pt()
|
||||
success = patch_litellm_ollama_pt()
|
||||
self.assertTrue(success, "Patch application failed")
|
||||
|
||||
# The patched function should not raise an IndexError
|
||||
result = ollama_pt("qwen3:4b", messages)
|
||||
# Use the function from the module directly to ensure we're using the patched version
|
||||
result = litellm.litellm_core_utils.prompt_templates.factory.ollama_pt("qwen3:4b", messages)
|
||||
|
||||
# Verify the result is as expected
|
||||
self.assertIn("prompt", result)
|
||||
@@ -39,6 +41,30 @@ class TestLitellmPatch(unittest.TestCase):
|
||||
finally:
|
||||
# Restore the original function to avoid affecting other tests
|
||||
litellm.litellm_core_utils.prompt_templates.factory.ollama_pt = original_ollama_pt
|
||||
|
||||
def test_ollama_pt_patch_with_empty_messages(self):
|
||||
"""Test that the patch handles empty message lists."""
|
||||
messages = []
|
||||
|
||||
# Store the original function to restore it after the test
|
||||
original_ollama_pt = litellm.litellm_core_utils.prompt_templates.factory.ollama_pt
|
||||
|
||||
try:
|
||||
# Apply the patch
|
||||
success = patch_litellm_ollama_pt()
|
||||
self.assertTrue(success, "Patch application failed")
|
||||
|
||||
# Use the function from the module directly to ensure we're using the patched version
|
||||
result = litellm.litellm_core_utils.prompt_templates.factory.ollama_pt("qwen3:4b", messages)
|
||||
|
||||
# Verify the result is as expected
|
||||
self.assertIn("prompt", result)
|
||||
self.assertIn("images", result)
|
||||
self.assertEqual("", result["prompt"])
|
||||
self.assertEqual([], result["images"])
|
||||
finally:
|
||||
# Restore the original function to avoid affecting other tests
|
||||
litellm.litellm_core_utils.prompt_templates.factory.ollama_pt = original_ollama_pt
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user