diff --git a/src/crewai/utilities/printer.py b/src/crewai/utilities/printer.py index 7a3b2f7d8..e082877ea 100644 --- a/src/crewai/utilities/printer.py +++ b/src/crewai/utilities/printer.py @@ -1,9 +1,33 @@ from typing import Optional import sys +from enum import Enum + + +class Color(Enum): + """Enum for text colors in terminal output.""" + PURPLE = "\033[95m" + RED = "\033[91m" + GREEN = "\033[92m" + BLUE = "\033[94m" + YELLOW = "\033[93m" + BOLD = "\033[1m" + RESET = "\033[00m" class Printer: - def print(self, content: str, color: Optional[str] = None): + """ + Utility class for printing formatted text to stdout. + Uses direct stdout writing for compatibility with asynchronous environments. + """ + + def print(self, content: str, color: Optional[str] = None) -> None: + """ + Print content with optional color formatting. + + Args: + content: The text to print + color: Optional color name (e.g., "purple", "bold_green") + """ output = content if color == "purple": output = self._format_purple(content) @@ -19,26 +43,53 @@ class Printer: output = self._format_yellow(content) elif color == "bold_yellow": output = self._format_bold_yellow(content) - sys.stdout.write(f"{output}\n") - sys.stdout.flush() + + try: + sys.stdout.write(f"{output}\n") + sys.stdout.flush() + except IOError: + pass - def _format_bold_purple(self, content): - return "\033[1m\033[95m {}\033[00m".format(content) + def _format_text(self, content: str, color: Color, bold: bool = False) -> str: + """ + Format text with color and optional bold styling. + + Args: + content: The text to format + color: The color to apply + bold: Whether to apply bold formatting + + Returns: + Formatted text string + """ + if bold: + return f"{Color.BOLD.value}{color.value} {content}{Color.RESET.value}" + return f"{color.value} {content}{Color.RESET.value}" - def _format_bold_green(self, content): - return "\033[1m\033[92m {}\033[00m".format(content) + def _format_bold_purple(self, content: str) -> str: + """Format text as bold purple.""" + return self._format_text(content, Color.PURPLE, bold=True) - def _format_purple(self, content): - return "\033[95m {}\033[00m".format(content) + def _format_bold_green(self, content: str) -> str: + """Format text as bold green.""" + return self._format_text(content, Color.GREEN, bold=True) - def _format_red(self, content): - return "\033[91m {}\033[00m".format(content) + def _format_purple(self, content: str) -> str: + """Format text as purple.""" + return self._format_text(content, Color.PURPLE) - def _format_bold_blue(self, content): - return "\033[1m\033[94m {}\033[00m".format(content) + def _format_red(self, content: str) -> str: + """Format text as red.""" + return self._format_text(content, Color.RED) - def _format_yellow(self, content): - return "\033[93m {}\033[00m".format(content) + def _format_bold_blue(self, content: str) -> str: + """Format text as bold blue.""" + return self._format_text(content, Color.BLUE, bold=True) - def _format_bold_yellow(self, content): - return "\033[1m\033[93m {}\033[00m".format(content) + def _format_yellow(self, content: str) -> str: + """Format text as yellow.""" + return self._format_text(content, Color.YELLOW) + + def _format_bold_yellow(self, content: str) -> str: + """Format text as bold yellow.""" + return self._format_text(content, Color.YELLOW, bold=True) diff --git a/tests/utilities/test_fastapi_logger.py b/tests/utilities/test_fastapi_logger.py index dc4957ddd..e96d0d25e 100644 --- a/tests/utilities/test_fastapi_logger.py +++ b/tests/utilities/test_fastapi_logger.py @@ -2,26 +2,35 @@ import sys import unittest from unittest.mock import patch import asyncio +import pytest from io import StringIO try: import fastapi + from fastapi import FastAPI from fastapi.testclient import TestClient + try: + from httpx import AsyncClient + ASYNC_CLIENT_AVAILABLE = True + except ImportError: + ASYNC_CLIENT_AVAILABLE = False FASTAPI_AVAILABLE = True except ImportError: FASTAPI_AVAILABLE = False + ASYNC_CLIENT_AVAILABLE = False from crewai.utilities.logger import Logger @unittest.skipIf(not FASTAPI_AVAILABLE, "FastAPI not installed") class TestFastAPILogger(unittest.TestCase): + """Test suite for Logger class in FastAPI context.""" + def setUp(self): + """Set up test environment before each test.""" if not FASTAPI_AVAILABLE: self.skipTest("FastAPI not installed") - from fastapi import FastAPI - self.app = FastAPI() self.logger = Logger(verbose=True) @@ -30,6 +39,11 @@ class TestFastAPILogger(unittest.TestCase): self.logger.log("info", "This is a test log message from FastAPI") return {"message": "Hello World"} + @self.app.get("/error") + async def error_route(): + self.logger.log("error", "This is an error log message from FastAPI") + return {"error": "Test error"} + self.client = TestClient(self.app) self.output = StringIO() @@ -37,9 +51,11 @@ class TestFastAPILogger(unittest.TestCase): sys.stdout = self.output def tearDown(self): + """Clean up test environment after each test.""" sys.stdout = self.old_stdout def test_logger_in_fastapi_context(self): + """Test that logger works in FastAPI context.""" response = self.client.get("/") output = self.output.getvalue() @@ -48,3 +64,29 @@ class TestFastAPILogger(unittest.TestCase): self.assertEqual(response.status_code, 200) self.assertEqual(response.json(), {"message": "Hello World"}) + + @pytest.mark.parametrize("route,log_level,expected_message", [ + ("/", "info", "This is a test log message from FastAPI"), + ("/error", "error", "This is an error log message from FastAPI") + ]) + def test_multiple_routes(self, route, log_level, expected_message): + """Test logging from different routes with different log levels.""" + response = self.client.get(route) + + output = self.output.getvalue() + self.assertIn(f"[{log_level.upper()}]: {expected_message}", output) + self.assertEqual(response.status_code, 200) + + @unittest.skipIf(not ASYNC_CLIENT_AVAILABLE, "AsyncClient not available") + @pytest.mark.asyncio + async def test_async_logger_in_fastapi(self): + """Test logger in async context using AsyncClient.""" + self.output = StringIO() + sys.stdout = self.output + + async with AsyncClient(app=self.app, base_url="http://test") as ac: + response = await ac.get("/") + self.assertEqual(response.status_code, 200) + + output = self.output.getvalue() + self.assertIn("[INFO]: This is a test log message from FastAPI", output) diff --git a/tests/utilities/test_logger.py b/tests/utilities/test_logger.py index b53222016..1ee271886 100644 --- a/tests/utilities/test_logger.py +++ b/tests/utilities/test_logger.py @@ -1,22 +1,29 @@ import sys import unittest +import threading from unittest.mock import patch from io import StringIO +import pytest from crewai.utilities.logger import Logger class TestLogger(unittest.TestCase): + """Test suite for the Logger class.""" + def setUp(self): + """Set up test environment before each test.""" self.logger = Logger(verbose=True) self.output = StringIO() self.old_stdout = sys.stdout sys.stdout = self.output def tearDown(self): + """Clean up test environment after each test.""" sys.stdout = self.old_stdout def test_log_in_sync_context(self): + """Test logging in a regular synchronous context.""" self.logger.log("info", "Test message") output = self.output.getvalue() self.assertIn("[INFO]: Test message", output) @@ -24,12 +31,52 @@ class TestLogger(unittest.TestCase): @patch('sys.stdout.flush') def test_stdout_is_flushed(self, mock_flush): + """Test that stdout is properly flushed after writing.""" self.logger.log("info", "Test message") mock_flush.assert_called_once() + + @pytest.mark.parametrize("log_level,message", [ + ("info", "Info message"), + ("error", "Error message"), + ("warning", "Warning message"), + ("debug", "Debug message") + ]) + def test_multiple_log_levels(self, log_level, message): + """Test logging with different log levels.""" + self.logger.log(log_level, message) + output = self.output.getvalue() + self.assertIn(f"[{log_level.upper()}]: {message}", output) + + def test_thread_safety(self): + """Test that logger is thread-safe.""" + messages = [] + for i in range(10): + messages.append(f"Message {i}") + + threads = [] + for message in messages: + thread = threading.Thread( + target=lambda msg: self.logger.log("info", msg), + args=(message,) + ) + threads.append(thread) + + for thread in threads: + thread.start() + + for thread in threads: + thread.join() + + output = self.output.getvalue() + for message in messages: + self.assertIn(message, output) class TestFastAPICompatibility(unittest.TestCase): + """Test compatibility with FastAPI.""" + def test_import_in_fastapi(self): + """Test that logger can be imported in a FastAPI context.""" try: import fastapi from crewai.utilities.logger import Logger