Implement reviewer suggestions: CLI validation, enhanced error handling, and test improvements

Co-Authored-By: Joe Moura <joao@crewai.com>
Author: Devin AI
Date: 2025-05-05 22:36:15 +00:00
parent 6e8e066091
commit dd5f170f45
3 changed files with 103 additions and 11 deletions

View File

@@ -213,6 +213,8 @@ def install(context):
 )
 def run(record: bool = False, replay: bool = False):
     """Run the Crew."""
+    if record and replay:
+        raise click.UsageError("Cannot use --record and --replay simultaneously")
     click.echo("Running the Crew")
     run_crew(record=record, replay=replay)
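
For context when reviewing, here is a minimal sketch of how the new mutual-exclusion check could be exercised with click's test runner. The flag names and error message come from the hunk above; the import path of the `run` command is an assumption, not taken from this diff.

    from click.testing import CliRunner

    from crewai.cli.cli import run  # assumed module path for the `run` command

    runner = CliRunner()
    result = runner.invoke(run, ["--record", "--replay"])

    # click converts the UsageError into a non-zero exit code.
    assert result.exit_code != 0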

View File

@@ -1,17 +1,33 @@
+import logging
 from typing import Any, Dict, List, Optional

 from crewai.memory.storage.llm_response_cache_storage import LLMResponseCacheStorage

+logger = logging.getLogger(__name__)
+

 class LLMResponseCacheHandler:
     """
     Handler for the LLM response cache storage.
     Used for record/replay functionality.
     """

-    def __init__(self) -> None:
+    def __init__(self, max_cache_age_days: int = 7) -> None:
+        """
+        Initializes the LLM response cache handler.
+
+        Args:
+            max_cache_age_days: Maximum age of cache entries in days. Defaults to 7.
+        """
         self.storage = LLMResponseCacheStorage()
         self._recording = False
         self._replaying = False
+        self.max_cache_age_days = max_cache_age_days
+
+        try:
+            self.storage.cleanup_expired_cache(self.max_cache_age_days)
+        except Exception as e:
+            logger.warning(f"Failed to cleanup expired cache on initialization: {e}")

     def start_recording(self) -> None:
         """
@@ -19,6 +35,7 @@ class LLMResponseCacheHandler:
""" """
self._recording = True self._recording = True
self._replaying = False self._replaying = False
logger.info("Started recording LLM responses")
def start_replaying(self) -> None: def start_replaying(self) -> None:
""" """
@@ -26,13 +43,28 @@ class LLMResponseCacheHandler:
""" """
self._recording = False self._recording = False
self._replaying = True self._replaying = True
logger.info("Started replaying LLM responses from cache")
try:
stats = self.storage.get_cache_stats()
logger.info(f"Cache statistics: {stats}")
except Exception as e:
logger.warning(f"Failed to get cache statistics: {e}")
def stop(self) -> None: def stop(self) -> None:
""" """
Stops recording or replaying. Stops recording or replaying.
""" """
was_recording = self._recording
was_replaying = self._replaying
self._recording = False self._recording = False
self._replaying = False self._replaying = False
if was_recording:
logger.info("Stopped recording LLM responses")
if was_replaying:
logger.info("Stopped replaying LLM responses")
def is_recording(self) -> bool: def is_recording(self) -> bool:
""" """
@@ -49,21 +81,76 @@ class LLMResponseCacheHandler:
     def cache_response(self, model: str, messages: List[Dict[str, str]], response: str) -> None:
         """
         Caches an LLM response if recording is active.
+
+        Args:
+            model: The model used for the LLM call.
+            messages: The messages sent to the LLM.
+            response: The response from the LLM.
         """
-        if self._recording:
-            self.storage.add(model, messages, response)
+        if not self._recording:
+            return
+
+        try:
+            self.storage.add(model, messages, response)
+            logger.debug(f"Cached response for model {model}")
+        except Exception as e:
+            logger.error(f"Failed to cache response: {e}")

     def get_cached_response(self, model: str, messages: List[Dict[str, str]]) -> Optional[str]:
         """
         Retrieves a cached LLM response if replaying is active.
         Returns None if not found or if replaying is not active.
+
+        Args:
+            model: The model used for the LLM call.
+            messages: The messages sent to the LLM.
+
+        Returns:
+            The cached response, or None if not found or if replaying is not active.
         """
-        if self._replaying:
-            return self.storage.get(model, messages)
-        return None
+        if not self._replaying:
+            return None
+
+        try:
+            response = self.storage.get(model, messages)
+            if response:
+                logger.debug(f"Retrieved cached response for model {model}")
+            else:
+                logger.debug(f"No cached response found for model {model}")
+            return response
+        except Exception as e:
+            logger.error(f"Failed to retrieve cached response: {e}")
+            return None

     def clear_cache(self) -> None:
         """
         Clears the LLM response cache.
         """
-        self.storage.delete_all()
+        try:
+            self.storage.delete_all()
+            logger.info("Cleared LLM response cache")
+        except Exception as e:
+            logger.error(f"Failed to clear cache: {e}")
+
+    def cleanup_expired_cache(self) -> None:
+        """
+        Removes cache entries older than the maximum age.
+        """
+        try:
+            self.storage.cleanup_expired_cache(self.max_cache_age_days)
+            logger.info(f"Cleaned up expired cache entries (older than {self.max_cache_age_days} days)")
+        except Exception as e:
+            logger.error(f"Failed to cleanup expired cache: {e}")
+
+    def get_cache_stats(self) -> Dict[str, Any]:
+        """
+        Returns statistics about the cache.
+
+        Returns:
+            A dictionary containing cache statistics.
+        """
+        try:
+            return self.storage.get_cache_stats()
+        except Exception as e:
+            logger.error(f"Failed to get cache stats: {e}")
+            return {"error": str(e)}

View File

@@ -1,6 +1,7 @@
-import pytest
 from unittest.mock import MagicMock, patch

+import pytest
+
 from crewai.agent import Agent
 from crewai.crew import Crew
 from crewai.process import Process
@@ -34,8 +35,9 @@ def test_crew_recording_mode():
     mock_llm = MagicMock()
     agent.llm = mock_llm

-    with patch('crewai.utilities.llm_response_cache_handler.LLMResponseCacheHandler', return_value=mock_handler):
-        crew.kickoff()
+    with patch('crewai.agent.Agent.execute_task', return_value="Test response"):
+        with patch('crewai.utilities.llm_response_cache_handler.LLMResponseCacheHandler', return_value=mock_handler):
+            crew.kickoff()

     mock_handler.start_recording.assert_called_once()
@@ -69,8 +71,9 @@ def test_crew_replay_mode():
     mock_llm = MagicMock()
     agent.llm = mock_llm

-    with patch('crewai.utilities.llm_response_cache_handler.LLMResponseCacheHandler', return_value=mock_handler):
-        crew.kickoff()
+    with patch('crewai.agent.Agent.execute_task', return_value="Test response"):
+        with patch('crewai.utilities.llm_response_cache_handler.LLMResponseCacheHandler', return_value=mock_handler):
+            crew.kickoff()

     mock_handler.start_replaying.assert_called_once()
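
Stylistic note on the test changes: on Python 3.10+ the two nested `with patch(...)` blocks can be collapsed into a single parenthesized `with` statement, which keeps the test body one level shallower. A sketch using the same patch targets as above:

    with (
        patch('crewai.agent.Agent.execute_task', return_value="Test response"),
        patch('crewai.utilities.llm_response_cache_handler.LLMResponseCacheHandler', return_value=mock_handler),
    ):
        crew.kickoff()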