Address code review feedback: Windows-only patch, proper restoration, improved tests

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-04-28 21:07:54 +00:00
parent b45bb89e10
commit 9916bfd2f3
2 changed files with 58 additions and 10 deletions

View File

@@ -4,6 +4,10 @@ Patch for litellm to fix UnicodeDecodeError on Windows systems.
This patch ensures that all file open operations in litellm use UTF-8 encoding, This patch ensures that all file open operations in litellm use UTF-8 encoding,
which prevents UnicodeDecodeError when loading JSON files on Windows systems which prevents UnicodeDecodeError when loading JSON files on Windows systems
where the default encoding is cp1252 or cp1254. where the default encoding is cp1252 or cp1254.
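WARNING: This patch monkey-patches the built-in open() function globally on Windows.
It forces UTF-8 encoding on every text-mode open whose mode includes 'r' and that does
not pass an explicit encoding (binary-mode opens and opens with an explicit encoding are
left untouched), which could affect third-party libraries expecting default platform
encodings. Apply with caution and test comprehensively.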
""" """
import builtins import builtins
@@ -12,6 +16,7 @@ import io
import json import json
import logging import logging
import os import os
import sys
from importlib import resources from importlib import resources
from typing import Any, Optional, Union from typing import Any, Optional, Union
@@ -19,12 +24,26 @@ logger = logging.getLogger(__name__)
def apply_patches(): def apply_patches():
"""Apply all patches to fix litellm encoding issues.""" """
logger.info("Applying litellm encoding patches") Apply patches to fix litellm encoding issues on Windows systems.
original_open = builtins.open This function only applies the patch on Windows platforms where the issue occurs.
It stores the original open function for proper restoration later.
"""
# Only apply patch on Windows systems
if sys.platform != "win32":
logger.debug("Skipping litellm encoding patches on non-Windows platform")
return
@functools.wraps(original_open) if hasattr(builtins, '_original_open'):
logger.debug("Litellm encoding patches already applied")
return
logger.debug("Applying litellm encoding patches on Windows")
builtins._original_open = builtins.open
@functools.wraps(builtins._original_open)
def patched_open( def patched_open(
file, mode='r', buffering=-1, encoding=None, file, mode='r', buffering=-1, encoding=None,
errors=None, newline=None, closefd=True, opener=None errors=None, newline=None, closefd=True, opener=None
@@ -32,18 +51,23 @@ def apply_patches():
if 'r' in mode and encoding is None and 'b' not in mode: if 'r' in mode and encoding is None and 'b' not in mode:
encoding = 'utf-8' encoding = 'utf-8'
return original_open( return builtins._original_open(
file, mode, buffering, encoding, file, mode, buffering, encoding,
errors, newline, closefd, opener errors, newline, closefd, opener
) )
builtins.open = patched_open builtins.open = patched_open
logger.info("Successfully applied litellm encoding patches") logger.debug("Successfully applied litellm encoding patches")
def remove_patches(): def remove_patches():
"""Remove all patches (for testing purposes).""" """
Remove all patches (for testing purposes).
This function properly restores the original open function if it was patched.
"""
if hasattr(builtins, '_original_open'): if hasattr(builtins, '_original_open'):
builtins.open = builtins._original_open builtins.open = builtins._original_open
logger.info("Removed litellm encoding patches") delattr(builtins, '_original_open')
logger.debug("Removed litellm encoding patches")

View File

@@ -4,14 +4,21 @@ import sys
import unittest import unittest
from unittest.mock import mock_open, patch from unittest.mock import mock_open, patch
import pytest
from crewai.llm import LLM from crewai.llm import LLM
from crewai.patches.litellm_patch import apply_patches, remove_patches
class TestLitellmEncoding(unittest.TestCase): class TestLitellmEncoding(unittest.TestCase):
"""Test that the litellm encoding patch works correctly.""" """Test that the litellm encoding patch works correctly."""
def setUp(self):
"""Set up the test environment by applying the patch."""
apply_patches()
def tearDown(self):
"""Clean up the test environment by removing the patch."""
remove_patches()
def test_json_load_with_utf8_encoding(self): def test_json_load_with_utf8_encoding(self):
"""Test that json.load is called with UTF-8 encoding.""" """Test that json.load is called with UTF-8 encoding."""
@@ -25,3 +32,20 @@ class TestLitellmEncoding(unittest.TestCase):
with open('test.json', 'r') as f: with open('test.json', 'r') as f:
data = json.load(f) data = json.load(f)
self.assertEqual(data['test'], '日本語テキスト') self.assertEqual(data['test'], '日本語テキスト')
def test_without_patch(self):
"""Test that demonstrates the issue without the patch."""
remove_patches()
mock_content = '{"test": "日本語テキスト"}' # Japanese text that would fail with cp1252
with patch('sys.platform', 'win32'):
mock_open_without_encoding = mock_open(read_data=mock_content)
mock_open_without_encoding.side_effect = UnicodeDecodeError('cp1252', b'\x81', 0, 1, 'invalid start byte')
with patch('builtins.open', mock_open_without_encoding):
with self.assertRaises(UnicodeDecodeError):
with open('test.json', 'r') as f:
json.load(f)
apply_patches()