Address PR feedback: Add type hints, use Color enum, and enhance tests

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-03-26 09:39:08 +00:00
parent 67ad6afbde
commit 9f8f999bed
3 changed files with 159 additions and 19 deletions

View File

@@ -1,9 +1,33 @@
from typing import Optional from typing import Optional
import sys import sys
from enum import Enum
class Color(Enum):
"""Enum for text colors in terminal output."""
PURPLE = "\033[95m"
RED = "\033[91m"
GREEN = "\033[92m"
BLUE = "\033[94m"
YELLOW = "\033[93m"
BOLD = "\033[1m"
RESET = "\033[00m"
class Printer: class Printer:
def print(self, content: str, color: Optional[str] = None): """
Utility class for printing formatted text to stdout.
Uses direct stdout writing for compatibility with asynchronous environments.
"""
def print(self, content: str, color: Optional[str] = None) -> None:
"""
Print content with optional color formatting.
Args:
content: The text to print
color: Optional color name (e.g., "purple", "bold_green")
"""
output = content output = content
if color == "purple": if color == "purple":
output = self._format_purple(content) output = self._format_purple(content)
@@ -19,26 +43,53 @@ class Printer:
output = self._format_yellow(content) output = self._format_yellow(content)
elif color == "bold_yellow": elif color == "bold_yellow":
output = self._format_bold_yellow(content) output = self._format_bold_yellow(content)
sys.stdout.write(f"{output}\n")
sys.stdout.flush() try:
sys.stdout.write(f"{output}\n")
sys.stdout.flush()
except IOError:
pass
def _format_bold_purple(self, content): def _format_text(self, content: str, color: Color, bold: bool = False) -> str:
return "\033[1m\033[95m {}\033[00m".format(content) """
Format text with color and optional bold styling.
Args:
content: The text to format
color: The color to apply
bold: Whether to apply bold formatting
Returns:
Formatted text string
"""
if bold:
return f"{Color.BOLD.value}{color.value} {content}{Color.RESET.value}"
return f"{color.value} {content}{Color.RESET.value}"
def _format_bold_green(self, content): def _format_bold_purple(self, content: str) -> str:
return "\033[1m\033[92m {}\033[00m".format(content) """Format text as bold purple."""
return self._format_text(content, Color.PURPLE, bold=True)
def _format_purple(self, content): def _format_bold_green(self, content: str) -> str:
return "\033[95m {}\033[00m".format(content) """Format text as bold green."""
return self._format_text(content, Color.GREEN, bold=True)
def _format_red(self, content): def _format_purple(self, content: str) -> str:
return "\033[91m {}\033[00m".format(content) """Format text as purple."""
return self._format_text(content, Color.PURPLE)
def _format_bold_blue(self, content): def _format_red(self, content: str) -> str:
return "\033[1m\033[94m {}\033[00m".format(content) """Format text as red."""
return self._format_text(content, Color.RED)
def _format_yellow(self, content): def _format_bold_blue(self, content: str) -> str:
return "\033[93m {}\033[00m".format(content) """Format text as bold blue."""
return self._format_text(content, Color.BLUE, bold=True)
def _format_bold_yellow(self, content): def _format_yellow(self, content: str) -> str:
return "\033[1m\033[93m {}\033[00m".format(content) """Format text as yellow."""
return self._format_text(content, Color.YELLOW)
def _format_bold_yellow(self, content: str) -> str:
"""Format text as bold yellow."""
return self._format_text(content, Color.YELLOW, bold=True)

View File

@@ -2,26 +2,35 @@ import sys
import unittest import unittest
from unittest.mock import patch from unittest.mock import patch
import asyncio import asyncio
import pytest
from io import StringIO from io import StringIO
try: try:
import fastapi import fastapi
from fastapi import FastAPI
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
try:
from httpx import AsyncClient
ASYNC_CLIENT_AVAILABLE = True
except ImportError:
ASYNC_CLIENT_AVAILABLE = False
FASTAPI_AVAILABLE = True FASTAPI_AVAILABLE = True
except ImportError: except ImportError:
FASTAPI_AVAILABLE = False FASTAPI_AVAILABLE = False
ASYNC_CLIENT_AVAILABLE = False
from crewai.utilities.logger import Logger from crewai.utilities.logger import Logger
@unittest.skipIf(not FASTAPI_AVAILABLE, "FastAPI not installed") @unittest.skipIf(not FASTAPI_AVAILABLE, "FastAPI not installed")
class TestFastAPILogger(unittest.TestCase): class TestFastAPILogger(unittest.TestCase):
"""Test suite for Logger class in FastAPI context."""
def setUp(self): def setUp(self):
"""Set up test environment before each test."""
if not FASTAPI_AVAILABLE: if not FASTAPI_AVAILABLE:
self.skipTest("FastAPI not installed") self.skipTest("FastAPI not installed")
from fastapi import FastAPI
self.app = FastAPI() self.app = FastAPI()
self.logger = Logger(verbose=True) self.logger = Logger(verbose=True)
@@ -30,6 +39,11 @@ class TestFastAPILogger(unittest.TestCase):
self.logger.log("info", "This is a test log message from FastAPI") self.logger.log("info", "This is a test log message from FastAPI")
return {"message": "Hello World"} return {"message": "Hello World"}
@self.app.get("/error")
async def error_route():
self.logger.log("error", "This is an error log message from FastAPI")
return {"error": "Test error"}
self.client = TestClient(self.app) self.client = TestClient(self.app)
self.output = StringIO() self.output = StringIO()
@@ -37,9 +51,11 @@ class TestFastAPILogger(unittest.TestCase):
sys.stdout = self.output sys.stdout = self.output
def tearDown(self): def tearDown(self):
"""Clean up test environment after each test."""
sys.stdout = self.old_stdout sys.stdout = self.old_stdout
def test_logger_in_fastapi_context(self): def test_logger_in_fastapi_context(self):
"""Test that logger works in FastAPI context."""
response = self.client.get("/") response = self.client.get("/")
output = self.output.getvalue() output = self.output.getvalue()
@@ -48,3 +64,29 @@ class TestFastAPILogger(unittest.TestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"message": "Hello World"}) self.assertEqual(response.json(), {"message": "Hello World"})
@pytest.mark.parametrize("route,log_level,expected_message", [
("/", "info", "This is a test log message from FastAPI"),
("/error", "error", "This is an error log message from FastAPI")
])
def test_multiple_routes(self, route, log_level, expected_message):
"""Test logging from different routes with different log levels."""
response = self.client.get(route)
output = self.output.getvalue()
self.assertIn(f"[{log_level.upper()}]: {expected_message}", output)
self.assertEqual(response.status_code, 200)
@unittest.skipIf(not ASYNC_CLIENT_AVAILABLE, "AsyncClient not available")
@pytest.mark.asyncio
async def test_async_logger_in_fastapi(self):
"""Test logger in async context using AsyncClient."""
self.output = StringIO()
sys.stdout = self.output
async with AsyncClient(app=self.app, base_url="http://test") as ac:
response = await ac.get("/")
self.assertEqual(response.status_code, 200)
output = self.output.getvalue()
self.assertIn("[INFO]: This is a test log message from FastAPI", output)

View File

@@ -1,22 +1,29 @@
import sys import sys
import unittest import unittest
import threading
from unittest.mock import patch from unittest.mock import patch
from io import StringIO from io import StringIO
import pytest
from crewai.utilities.logger import Logger from crewai.utilities.logger import Logger
class TestLogger(unittest.TestCase): class TestLogger(unittest.TestCase):
"""Test suite for the Logger class."""
def setUp(self): def setUp(self):
"""Set up test environment before each test."""
self.logger = Logger(verbose=True) self.logger = Logger(verbose=True)
self.output = StringIO() self.output = StringIO()
self.old_stdout = sys.stdout self.old_stdout = sys.stdout
sys.stdout = self.output sys.stdout = self.output
def tearDown(self): def tearDown(self):
"""Clean up test environment after each test."""
sys.stdout = self.old_stdout sys.stdout = self.old_stdout
def test_log_in_sync_context(self): def test_log_in_sync_context(self):
"""Test logging in a regular synchronous context."""
self.logger.log("info", "Test message") self.logger.log("info", "Test message")
output = self.output.getvalue() output = self.output.getvalue()
self.assertIn("[INFO]: Test message", output) self.assertIn("[INFO]: Test message", output)
@@ -24,12 +31,52 @@ class TestLogger(unittest.TestCase):
@patch('sys.stdout.flush') @patch('sys.stdout.flush')
def test_stdout_is_flushed(self, mock_flush): def test_stdout_is_flushed(self, mock_flush):
"""Test that stdout is properly flushed after writing."""
self.logger.log("info", "Test message") self.logger.log("info", "Test message")
mock_flush.assert_called_once() mock_flush.assert_called_once()
@pytest.mark.parametrize("log_level,message", [
("info", "Info message"),
("error", "Error message"),
("warning", "Warning message"),
("debug", "Debug message")
])
def test_multiple_log_levels(self, log_level, message):
"""Test logging with different log levels."""
self.logger.log(log_level, message)
output = self.output.getvalue()
self.assertIn(f"[{log_level.upper()}]: {message}", output)
def test_thread_safety(self):
"""Test that logger is thread-safe."""
messages = []
for i in range(10):
messages.append(f"Message {i}")
threads = []
for message in messages:
thread = threading.Thread(
target=lambda msg: self.logger.log("info", msg),
args=(message,)
)
threads.append(thread)
for thread in threads:
thread.start()
for thread in threads:
thread.join()
output = self.output.getvalue()
for message in messages:
self.assertIn(message, output)
class TestFastAPICompatibility(unittest.TestCase): class TestFastAPICompatibility(unittest.TestCase):
"""Test compatibility with FastAPI."""
def test_import_in_fastapi(self): def test_import_in_fastapi(self):
"""Test that logger can be imported in a FastAPI context."""
try: try:
import fastapi import fastapi
from crewai.utilities.logger import Logger from crewai.utilities.logger import Logger