diff --git a/src/crewai/patches/litellm_patch.py b/src/crewai/patches/litellm_patch.py index 474702fdb..d8521ee77 100644 --- a/src/crewai/patches/litellm_patch.py +++ b/src/crewai/patches/litellm_patch.py @@ -4,6 +4,10 @@ Patch for litellm to fix UnicodeDecodeError on Windows systems. This patch ensures that all file open operations in litellm use UTF-8 encoding, which prevents UnicodeDecodeError when loading JSON files on Windows systems where the default encoding is cp1252 or cp1254. + +WARNING: This patch monkey-patches the built-in open() function globally on Windows. +It forces UTF-8 encoding on all text-mode file opens, which could affect third-party +libraries expecting default platform encodings. Apply with caution and test comprehensively. """ import builtins @@ -12,6 +16,7 @@ import io import json import logging import os +import sys from importlib import resources from typing import Any, Optional, Union @@ -19,12 +24,26 @@ logger = logging.getLogger(__name__) def apply_patches(): - """Apply all patches to fix litellm encoding issues.""" - logger.info("Applying litellm encoding patches") + """ + Apply patches to fix litellm encoding issues on Windows systems. - original_open = builtins.open + This function only applies the patch on Windows platforms where the issue occurs. + It stores the original open function for proper restoration later. + """ + # Only apply patch on Windows systems + if sys.platform != "win32": + logger.debug("Skipping litellm encoding patches on non-Windows platform") + return - @functools.wraps(original_open) + if hasattr(builtins, '_original_open'): + logger.debug("Litellm encoding patches already applied") + return + + logger.debug("Applying litellm encoding patches on Windows") + + builtins._original_open = builtins.open + + @functools.wraps(builtins._original_open) def patched_open( file, mode='r', buffering=-1, encoding=None, errors=None, newline=None, closefd=True, opener=None @@ -32,18 +51,23 @@ def apply_patches(): if 'r' in mode and encoding is None and 'b' not in mode: encoding = 'utf-8' - return original_open( + return builtins._original_open( file, mode, buffering, encoding, errors, newline, closefd, opener ) builtins.open = patched_open - logger.info("Successfully applied litellm encoding patches") + logger.debug("Successfully applied litellm encoding patches") def remove_patches(): - """Remove all patches (for testing purposes).""" + """ + Remove all patches (for testing purposes). + + This function properly restores the original open function if it was patched. + """ if hasattr(builtins, '_original_open'): builtins.open = builtins._original_open - logger.info("Removed litellm encoding patches") + delattr(builtins, '_original_open') + logger.debug("Removed litellm encoding patches") diff --git a/tests/litellm_tests/test_encoding.py b/tests/litellm_tests/test_encoding.py index aa2476296..fa307eba4 100644 --- a/tests/litellm_tests/test_encoding.py +++ b/tests/litellm_tests/test_encoding.py @@ -4,14 +4,21 @@ import sys import unittest from unittest.mock import mock_open, patch -import pytest - from crewai.llm import LLM +from crewai.patches.litellm_patch import apply_patches, remove_patches class TestLitellmEncoding(unittest.TestCase): """Test that the litellm encoding patch works correctly.""" + def setUp(self): + """Set up the test environment by applying the patch.""" + apply_patches() + + def tearDown(self): + """Clean up the test environment by removing the patch.""" + remove_patches() + def test_json_load_with_utf8_encoding(self): """Test that json.load is called with UTF-8 encoding.""" @@ -25,3 +32,20 @@ class TestLitellmEncoding(unittest.TestCase): with open('test.json', 'r') as f: data = json.load(f) self.assertEqual(data['test'], '日本語テキスト') + + def test_without_patch(self): + """Test that demonstrates the issue without the patch.""" + remove_patches() + + mock_content = '{"test": "日本語テキスト"}' # Japanese text that would fail with cp1252 + + with patch('sys.platform', 'win32'): + mock_open_without_encoding = mock_open(read_data=mock_content) + mock_open_without_encoding.side_effect = UnicodeDecodeError('cp1252', b'\x81', 0, 1, 'invalid start byte') + + with patch('builtins.open', mock_open_without_encoding): + with self.assertRaises(UnicodeDecodeError): + with open('test.json', 'r') as f: + json.load(f) + + apply_patches()