Address code review feedback: Windows-only patch, proper restoration, improved tests

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-04-28 21:07:54 +00:00
parent b45bb89e10
commit 9916bfd2f3
2 changed files with 58 additions and 10 deletions

View File

@@ -4,6 +4,10 @@ Patch for litellm to fix UnicodeDecodeError on Windows systems.
This patch ensures that all file open operations in litellm use UTF-8 encoding,
which prevents UnicodeDecodeError when loading JSON files on Windows systems
where the default encoding is cp1252 or cp1254.
WARNING: This patch monkey-patches the built-in open() function globally on Windows.
It forces UTF-8 encoding on all text-mode file opens, which could affect third-party
libraries expecting default platform encodings. Apply with caution and test comprehensively.
"""
import builtins
@@ -12,6 +16,7 @@ import io
import json
import logging
import os
import sys
from importlib import resources
from typing import Any, Optional, Union
@@ -19,12 +24,26 @@ logger = logging.getLogger(__name__)
def apply_patches():
"""Apply all patches to fix litellm encoding issues."""
logger.info("Applying litellm encoding patches")
"""
Apply patches to fix litellm encoding issues on Windows systems.
original_open = builtins.open
This function only applies the patch on Windows platforms where the issue occurs.
It stores the original open function for proper restoration later.
"""
# Only apply patch on Windows systems
if sys.platform != "win32":
logger.debug("Skipping litellm encoding patches on non-Windows platform")
return
@functools.wraps(original_open)
if hasattr(builtins, '_original_open'):
logger.debug("Litellm encoding patches already applied")
return
logger.debug("Applying litellm encoding patches on Windows")
builtins._original_open = builtins.open
@functools.wraps(builtins._original_open)
def patched_open(
file, mode='r', buffering=-1, encoding=None,
errors=None, newline=None, closefd=True, opener=None
@@ -32,18 +51,23 @@ def apply_patches():
if 'r' in mode and encoding is None and 'b' not in mode:
encoding = 'utf-8'
return original_open(
return builtins._original_open(
file, mode, buffering, encoding,
errors, newline, closefd, opener
)
builtins.open = patched_open
logger.info("Successfully applied litellm encoding patches")
logger.debug("Successfully applied litellm encoding patches")
def remove_patches():
"""Remove all patches (for testing purposes)."""
"""
Remove all patches (for testing purposes).
This function properly restores the original open function if it was patched.
"""
if hasattr(builtins, '_original_open'):
builtins.open = builtins._original_open
logger.info("Removed litellm encoding patches")
delattr(builtins, '_original_open')
logger.debug("Removed litellm encoding patches")

View File

@@ -4,14 +4,21 @@ import sys
import unittest
from unittest.mock import mock_open, patch
import pytest
from crewai.llm import LLM
from crewai.patches.litellm_patch import apply_patches, remove_patches
class TestLitellmEncoding(unittest.TestCase):
"""Test that the litellm encoding patch works correctly."""
def setUp(self):
"""Set up the test environment by applying the patch."""
apply_patches()
def tearDown(self):
"""Clean up the test environment by removing the patch."""
remove_patches()
def test_json_load_with_utf8_encoding(self):
"""Test that json.load is called with UTF-8 encoding."""
@@ -25,3 +32,20 @@ class TestLitellmEncoding(unittest.TestCase):
with open('test.json', 'r') as f:
data = json.load(f)
self.assertEqual(data['test'], '日本語テキスト')
def test_without_patch(self):
"""Test that demonstrates the issue without the patch."""
remove_patches()
mock_content = '{"test": "日本語テキスト"}' # Japanese text that would fail with cp1252
with patch('sys.platform', 'win32'):
mock_open_without_encoding = mock_open(read_data=mock_content)
mock_open_without_encoding.side_effect = UnicodeDecodeError('cp1252', b'\x81', 0, 1, 'invalid start byte')
with patch('builtins.open', mock_open_without_encoding):
with self.assertRaises(UnicodeDecodeError):
with open('test.json', 'r') as f:
json.load(f)
apply_patches()