mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-26 08:38:15 +00:00
Implement reviewer suggestions: CLI validation, enhanced error handling, and test improvements
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -213,6 +213,8 @@ def install(context):
|
|||||||
)
|
)
|
||||||
def run(record: bool = False, replay: bool = False):
|
def run(record: bool = False, replay: bool = False):
|
||||||
"""Run the Crew."""
|
"""Run the Crew."""
|
||||||
|
if record and replay:
|
||||||
|
raise click.UsageError("Cannot use --record and --replay simultaneously")
|
||||||
click.echo("Running the Crew")
|
click.echo("Running the Crew")
|
||||||
run_crew(record=record, replay=replay)
|
run_crew(record=record, replay=replay)
|
||||||
|
|
||||||
|
|||||||
@@ -1,17 +1,33 @@
|
|||||||
|
import logging
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from crewai.memory.storage.llm_response_cache_storage import LLMResponseCacheStorage
|
from crewai.memory.storage.llm_response_cache_storage import LLMResponseCacheStorage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LLMResponseCacheHandler:
|
class LLMResponseCacheHandler:
|
||||||
"""
|
"""
|
||||||
Handler for the LLM response cache storage.
|
Handler for the LLM response cache storage.
|
||||||
Used for record/replay functionality.
|
Used for record/replay functionality.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self, max_cache_age_days: int = 7) -> None:
|
||||||
|
"""
|
||||||
|
Initializes the LLM response cache handler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_cache_age_days: Maximum age of cache entries in days. Defaults to 7.
|
||||||
|
"""
|
||||||
self.storage = LLMResponseCacheStorage()
|
self.storage = LLMResponseCacheStorage()
|
||||||
self._recording = False
|
self._recording = False
|
||||||
self._replaying = False
|
self._replaying = False
|
||||||
|
self.max_cache_age_days = max_cache_age_days
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.storage.cleanup_expired_cache(self.max_cache_age_days)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to cleanup expired cache on initialization: {e}")
|
||||||
|
|
||||||
def start_recording(self) -> None:
|
def start_recording(self) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -19,6 +35,7 @@ class LLMResponseCacheHandler:
|
|||||||
"""
|
"""
|
||||||
self._recording = True
|
self._recording = True
|
||||||
self._replaying = False
|
self._replaying = False
|
||||||
|
logger.info("Started recording LLM responses")
|
||||||
|
|
||||||
def start_replaying(self) -> None:
|
def start_replaying(self) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -26,13 +43,28 @@ class LLMResponseCacheHandler:
|
|||||||
"""
|
"""
|
||||||
self._recording = False
|
self._recording = False
|
||||||
self._replaying = True
|
self._replaying = True
|
||||||
|
logger.info("Started replaying LLM responses from cache")
|
||||||
|
|
||||||
|
try:
|
||||||
|
stats = self.storage.get_cache_stats()
|
||||||
|
logger.info(f"Cache statistics: {stats}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to get cache statistics: {e}")
|
||||||
|
|
||||||
def stop(self) -> None:
|
def stop(self) -> None:
|
||||||
"""
|
"""
|
||||||
Stops recording or replaying.
|
Stops recording or replaying.
|
||||||
"""
|
"""
|
||||||
|
was_recording = self._recording
|
||||||
|
was_replaying = self._replaying
|
||||||
|
|
||||||
self._recording = False
|
self._recording = False
|
||||||
self._replaying = False
|
self._replaying = False
|
||||||
|
|
||||||
|
if was_recording:
|
||||||
|
logger.info("Stopped recording LLM responses")
|
||||||
|
if was_replaying:
|
||||||
|
logger.info("Stopped replaying LLM responses")
|
||||||
|
|
||||||
def is_recording(self) -> bool:
|
def is_recording(self) -> bool:
|
||||||
"""
|
"""
|
||||||
@@ -49,21 +81,76 @@ class LLMResponseCacheHandler:
|
|||||||
def cache_response(self, model: str, messages: List[Dict[str, str]], response: str) -> None:
|
def cache_response(self, model: str, messages: List[Dict[str, str]], response: str) -> None:
|
||||||
"""
|
"""
|
||||||
Caches an LLM response if recording is active.
|
Caches an LLM response if recording is active.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The model used for the LLM call.
|
||||||
|
messages: The messages sent to the LLM.
|
||||||
|
response: The response from the LLM.
|
||||||
"""
|
"""
|
||||||
if self._recording:
|
if not self._recording:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
self.storage.add(model, messages, response)
|
self.storage.add(model, messages, response)
|
||||||
|
logger.debug(f"Cached response for model {model}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to cache response: {e}")
|
||||||
|
|
||||||
def get_cached_response(self, model: str, messages: List[Dict[str, str]]) -> Optional[str]:
|
def get_cached_response(self, model: str, messages: List[Dict[str, str]]) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
Retrieves a cached LLM response if replaying is active.
|
Retrieves a cached LLM response if replaying is active.
|
||||||
Returns None if not found or if replaying is not active.
|
Returns None if not found or if replaying is not active.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The model used for the LLM call.
|
||||||
|
messages: The messages sent to the LLM.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The cached response, or None if not found or if replaying is not active.
|
||||||
"""
|
"""
|
||||||
if self._replaying:
|
if not self._replaying:
|
||||||
return self.storage.get(model, messages)
|
return None
|
||||||
return None
|
|
||||||
|
try:
|
||||||
|
response = self.storage.get(model, messages)
|
||||||
|
if response:
|
||||||
|
logger.debug(f"Retrieved cached response for model {model}")
|
||||||
|
else:
|
||||||
|
logger.debug(f"No cached response found for model {model}")
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to retrieve cached response: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
def clear_cache(self) -> None:
|
def clear_cache(self) -> None:
|
||||||
"""
|
"""
|
||||||
Clears the LLM response cache.
|
Clears the LLM response cache.
|
||||||
"""
|
"""
|
||||||
self.storage.delete_all()
|
try:
|
||||||
|
self.storage.delete_all()
|
||||||
|
logger.info("Cleared LLM response cache")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to clear cache: {e}")
|
||||||
|
|
||||||
|
def cleanup_expired_cache(self) -> None:
|
||||||
|
"""
|
||||||
|
Removes cache entries older than the maximum age.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.storage.cleanup_expired_cache(self.max_cache_age_days)
|
||||||
|
logger.info(f"Cleaned up expired cache entries (older than {self.max_cache_age_days} days)")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to cleanup expired cache: {e}")
|
||||||
|
|
||||||
|
def get_cache_stats(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Returns statistics about the cache.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary containing cache statistics.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return self.storage.get_cache_stats()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to get cache stats: {e}")
|
||||||
|
return {"error": str(e)}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import pytest
|
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from crewai.agent import Agent
|
from crewai.agent import Agent
|
||||||
from crewai.crew import Crew
|
from crewai.crew import Crew
|
||||||
from crewai.process import Process
|
from crewai.process import Process
|
||||||
@@ -34,8 +35,9 @@ def test_crew_recording_mode():
|
|||||||
mock_llm = MagicMock()
|
mock_llm = MagicMock()
|
||||||
agent.llm = mock_llm
|
agent.llm = mock_llm
|
||||||
|
|
||||||
with patch('crewai.utilities.llm_response_cache_handler.LLMResponseCacheHandler', return_value=mock_handler):
|
with patch('crewai.agent.Agent.execute_task', return_value="Test response"):
|
||||||
crew.kickoff()
|
with patch('crewai.utilities.llm_response_cache_handler.LLMResponseCacheHandler', return_value=mock_handler):
|
||||||
|
crew.kickoff()
|
||||||
|
|
||||||
mock_handler.start_recording.assert_called_once()
|
mock_handler.start_recording.assert_called_once()
|
||||||
|
|
||||||
@@ -69,8 +71,9 @@ def test_crew_replay_mode():
|
|||||||
mock_llm = MagicMock()
|
mock_llm = MagicMock()
|
||||||
agent.llm = mock_llm
|
agent.llm = mock_llm
|
||||||
|
|
||||||
with patch('crewai.utilities.llm_response_cache_handler.LLMResponseCacheHandler', return_value=mock_handler):
|
with patch('crewai.agent.Agent.execute_task', return_value="Test response"):
|
||||||
crew.kickoff()
|
with patch('crewai.utilities.llm_response_cache_handler.LLMResponseCacheHandler', return_value=mock_handler):
|
||||||
|
crew.kickoff()
|
||||||
|
|
||||||
mock_handler.start_replaying.assert_called_once()
|
mock_handler.start_replaying.assert_called_once()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user