From 082cbd2c1caf874fe6fa752790973cae7e08ed38 Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Sat, 3 May 2025 02:25:10 +0000 Subject: [PATCH] Fix lint issues and improve patch implementation Co-Authored-By: Joe Moura --- src/crewai/patches/__init__.py | 10 ++--- src/crewai/patches/litellm_patch.py | 68 +++++++++++++++++++++++++---- tests/patches/test_litellm_patch.py | 36 ++++++++++++--- 3 files changed, 96 insertions(+), 18 deletions(-) diff --git a/src/crewai/patches/__init__.py b/src/crewai/patches/__init__.py index e8b0958a8..7ff793477 100644 --- a/src/crewai/patches/__init__.py +++ b/src/crewai/patches/__init__.py @@ -1,11 +1,11 @@ """ Patches module for CrewAI. -This module contains patches for dependencies that need to be fixed -without waiting for upstream changes. +This module contains patches for external dependencies to fix known issues. + +Version: 1.0.0 """ -from crewai.patches.litellm_patch import apply_patches +from crewai.patches.litellm_patch import apply_patches, patch_litellm_ollama_pt -# Apply all patches when the module is imported -apply_patches() +__all__ = ["apply_patches", "patch_litellm_ollama_pt"] diff --git a/src/crewai/patches/litellm_patch.py b/src/crewai/patches/litellm_patch.py index 349be7683..e65615e75 100644 --- a/src/crewai/patches/litellm_patch.py +++ b/src/crewai/patches/litellm_patch.py @@ -3,14 +3,31 @@ Patch for litellm to fix IndexError in ollama_pt function. This patch addresses issue #2744 in the crewAI repository, where an IndexError occurs in litellm's Ollama prompt template function when CrewAI Agent with Tools uses Ollama/Qwen models. + +Version: 1.0.0 """ -from typing import Any, Union +import json +import logging +from typing import Any, Dict, List, Union + +# Set up logging +logger = logging.getLogger(__name__) + +# Patch version +PATCH_VERSION = "1.0.0" + + +class PatchApplicationError(Exception): + """Exception raised when a patch fails to apply.""" + pass def apply_patches(): """Apply all patches to fix known issues with dependencies.""" - patch_litellm_ollama_pt() + success = patch_litellm_ollama_pt() + logger.info(f"LiteLLM ollama_pt patch applied: {success}") + return success def patch_litellm_ollama_pt(): @@ -20,23 +37,41 @@ def patch_litellm_ollama_pt(): The issue occurs when accessing messages[msg_i].get("tool_calls") without checking if msg_i is within bounds of the messages list. This happens after tool execution during the next LLM call. + + Returns: + bool: True if the patch was applied successfully, False otherwise. + + Raises: + PatchApplicationError: If there's an error during patch application. """ try: # Import the module containing the function to patch import litellm.litellm_core_utils.prompt_templates.factory as factory # Define a patched version of the function - def patched_ollama_pt(model: str, messages: list) -> Union[str, Any]: + def patched_ollama_pt(model: str, messages: List[Dict]) -> Dict[str, Any]: """ Patched version of ollama_pt that adds bounds checking. This fixes the IndexError that occurs when the assistant message is the last message in the list and msg_i goes out of bounds. + + Args: + model: The model name. + messages: The list of messages to process. + + Returns: + Dict containing the prompt and images. """ user_message_types = {"user", "tool", "function"} msg_i = 0 images = [] prompt = "" + + # Handle empty messages list + if not messages: + return {"prompt": prompt, "images": images} + while msg_i < len(messages): init_msg_i = msg_i user_content_str = "" @@ -81,9 +116,9 @@ def patch_litellm_ollama_pt(): ollama_tool_calls = [] if tool_calls: for call in tool_calls: - call_id: str = call["id"] - function_name: str = call["function"]["name"] - arguments = factory.json.loads(call["function"]["arguments"]) + call_id = call["id"] + function_name = call["function"]["name"] + arguments = json.loads(call["function"]["arguments"]) ollama_tool_calls.append( { @@ -98,7 +133,7 @@ def patch_litellm_ollama_pt(): if ollama_tool_calls: assistant_content_str += ( - f"Tool Calls: {factory.json.dumps(ollama_tool_calls, indent=2)}" + f"Tool Calls: {json.dumps(ollama_tool_calls, indent=2)}" ) msg_i += 1 @@ -123,7 +158,24 @@ def patch_litellm_ollama_pt(): # Replace the original function with our patched version factory.ollama_pt = patched_ollama_pt + logger.info(f"Successfully applied litellm ollama_pt patch version {PATCH_VERSION}") return True except Exception as e: - print(f"Failed to apply litellm ollama_pt patch: {e}") + error_msg = f"Failed to apply litellm ollama_pt patch: {e}" + logger.error(error_msg) + return False + + +# For backwards compatibility +def patch_litellm(): + """ + Legacy function for backwards compatibility. + + Returns: + bool: True if the patch was applied successfully, False otherwise. + """ + try: + return patch_litellm_ollama_pt() + except Exception as e: + logger.error(f"Failed to apply legacy litellm patch: {e}") return False diff --git a/tests/patches/test_litellm_patch.py b/tests/patches/test_litellm_patch.py index fa839fb0f..77be010a1 100644 --- a/tests/patches/test_litellm_patch.py +++ b/tests/patches/test_litellm_patch.py @@ -2,11 +2,12 @@ Test for the litellm patch that fixes the IndexError in ollama_pt function. """ -import unittest -from unittest.mock import patch, MagicMock import sys +import unittest +from unittest.mock import MagicMock, patch import litellm +import pytest from litellm.litellm_core_utils.prompt_templates.factory import ollama_pt from crewai.patches.litellm_patch import patch_litellm_ollama_pt @@ -26,10 +27,11 @@ class TestLitellmPatch(unittest.TestCase): try: # Apply the patch - patch_litellm_ollama_pt() + success = patch_litellm_ollama_pt() + self.assertTrue(success, "Patch application failed") - # The patched function should not raise an IndexError - result = ollama_pt("qwen3:4b", messages) + # Use the function from the module directly to ensure we're using the patched version + result = litellm.litellm_core_utils.prompt_templates.factory.ollama_pt("qwen3:4b", messages) # Verify the result is as expected self.assertIn("prompt", result) @@ -39,6 +41,30 @@ class TestLitellmPatch(unittest.TestCase): finally: # Restore the original function to avoid affecting other tests litellm.litellm_core_utils.prompt_templates.factory.ollama_pt = original_ollama_pt + + def test_ollama_pt_patch_with_empty_messages(self): + """Test that the patch handles empty message lists.""" + messages = [] + + # Store the original function to restore it after the test + original_ollama_pt = litellm.litellm_core_utils.prompt_templates.factory.ollama_pt + + try: + # Apply the patch + success = patch_litellm_ollama_pt() + self.assertTrue(success, "Patch application failed") + + # Use the function from the module directly to ensure we're using the patched version + result = litellm.litellm_core_utils.prompt_templates.factory.ollama_pt("qwen3:4b", messages) + + # Verify the result is as expected + self.assertIn("prompt", result) + self.assertIn("images", result) + self.assertEqual("", result["prompt"]) + self.assertEqual([], result["images"]) + finally: + # Restore the original function to avoid affecting other tests + litellm.litellm_core_utils.prompt_templates.factory.ollama_pt = original_ollama_pt if __name__ == "__main__":