diff --git a/src/crewai/cli/cli.py b/src/crewai/cli/cli.py index b2d59adbe..c2a528663 100644 --- a/src/crewai/cli/cli.py +++ b/src/crewai/cli/cli.py @@ -201,9 +201,20 @@ def install(context): @crewai.command() -def run(): +@click.option( + "--record", + is_flag=True, + help="Record LLM responses for later replay", +) +@click.option( + "--replay", + is_flag=True, + help="Replay from recorded LLM responses without making network calls", +) +def run(record: bool = False, replay: bool = False): """Run the Crew.""" - run_crew() + click.echo("Running the Crew") + run_crew(record=record, replay=replay) @crewai.command() diff --git a/src/crewai/cli/run_crew.py b/src/crewai/cli/run_crew.py index 62241a4b5..37994ff11 100644 --- a/src/crewai/cli/run_crew.py +++ b/src/crewai/cli/run_crew.py @@ -14,13 +14,17 @@ class CrewType(Enum): FLOW = "flow" -def run_crew() -> None: +def run_crew(record: bool = False, replay: bool = False) -> None: """ Run the crew or flow by running a command in the UV environment. Starting from version 0.103.0, this command can be used to run both standard crews and flows. For flows, it detects the type from pyproject.toml and automatically runs the appropriate command. + + Args: + record (bool, optional): Whether to record LLM responses. Defaults to False. + replay (bool, optional): Whether to replay from recorded LLM responses. Defaults to False. """ crewai_version = get_crewai_version() min_required_version = "0.71.0" @@ -44,17 +48,24 @@ def run_crew() -> None: click.echo(f"Running the {'Flow' if is_flow else 'Crew'}") # Execute the appropriate command - execute_command(crew_type) + execute_command(crew_type, record, replay) -def execute_command(crew_type: CrewType) -> None: +def execute_command(crew_type: CrewType, record: bool = False, replay: bool = False) -> None: """ Execute the appropriate command based on crew type. 
Args: crew_type: The type of crew to run + record: Whether to record LLM responses + replay: Whether to replay from recorded LLM responses """ command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"] + + if record: + command.append("--record") + if replay: + command.append("--replay") try: subprocess.run(command, capture_output=False, text=True, check=True) diff --git a/src/crewai/crew.py b/src/crewai/crew.py index 102f22881..6dfefe9c3 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -244,6 +244,15 @@ class Crew(FlowTrackable, BaseModel): default_factory=SecurityConfig, description="Security configuration for the crew, including fingerprinting.", ) + record_mode: bool = Field( + default=False, + description="Whether to record LLM responses for later replay.", + ) + replay_mode: bool = Field( + default=False, + description="Whether to replay from recorded LLM responses without making network calls.", + ) + _llm_response_cache_handler: Optional[Any] = PrivateAttr(default=None) @field_validator("id", mode="before") @classmethod @@ -633,6 +642,17 @@ class Crew(FlowTrackable, BaseModel): self._task_output_handler.reset() self._logging_color = "bold_purple" + if self.record_mode and self.replay_mode: + raise ValueError("Cannot use both record_mode and replay_mode at the same time") + + if self.record_mode or self.replay_mode: + from crewai.utilities.llm_response_cache_handler import LLMResponseCacheHandler + self._llm_response_cache_handler = LLMResponseCacheHandler() + if self.record_mode: + self._llm_response_cache_handler.start_recording() + elif self.replay_mode: + self._llm_response_cache_handler.start_replaying() + if inputs is not None: self._inputs = inputs self._interpolate_inputs(inputs) @@ -651,6 +671,12 @@ class Crew(FlowTrackable, BaseModel): if not agent.step_callback: # type: ignore # "BaseAgent" has no attribute "step_callback" agent.step_callback = self.step_callback # type: ignore # "BaseAgent" has no attribute "step_callback" + + if self._llm_response_cache_handler: + if hasattr(agent, "llm") and agent.llm: + agent.llm.set_response_cache_handler(self._llm_response_cache_handler) + if hasattr(agent, "function_calling_llm") and agent.function_calling_llm: + agent.function_calling_llm.set_response_cache_handler(self._llm_response_cache_handler) agent.create_agent_executor() @@ -1287,6 +1313,9 @@ class Crew(FlowTrackable, BaseModel): def _finish_execution(self, final_string_output: str) -> None: if self.max_rpm: self._rpm_controller.stop_rpm_counter() + + if self._llm_response_cache_handler: + self._llm_response_cache_handler.stop() def calculate_usage_metrics(self) -> UsageMetrics: """Calculates and returns the usage metrics.""" diff --git a/src/crewai/llm.py b/src/crewai/llm.py index c8c456297..b187a6168 100644 --- a/src/crewai/llm.py +++ b/src/crewai/llm.py @@ -296,6 +296,7 @@ class LLM(BaseLLM): self.additional_params = kwargs self.is_anthropic = self._is_anthropic_model(model) self.stream = stream + self._response_cache_handler = None litellm.drop_params = True @@ -869,25 +870,43 @@ class LLM(BaseLLM): for message in messages: if message.get("role") == "system": message["role"] = "assistant" + + if self._response_cache_handler and self._response_cache_handler.is_replaying(): + cached_response = self._response_cache_handler.get_cached_response( + self.model, messages + ) + if cached_response: + # Emit completion event for the cached response + self._handle_emit_call_events(cached_response, LLMCallType.LLM_CALL) + return cached_response - # 
--- 5) Set up callbacks if provided + # --- 6) Set up callbacks if provided with suppress_warnings(): if callbacks and len(callbacks) > 0: self.set_callbacks(callbacks) try: - # --- 6) Prepare parameters for the completion call + # --- 7) Prepare parameters for the completion call params = self._prepare_completion_params(messages, tools) - # --- 7) Make the completion call and handle response + # --- 8) Make the completion call and handle response if self.stream: - return self._handle_streaming_response( + response = self._handle_streaming_response( params, callbacks, available_functions ) else: - return self._handle_non_streaming_response( + response = self._handle_non_streaming_response( params, callbacks, available_functions ) + + if (self._response_cache_handler and + self._response_cache_handler.is_recording() and + isinstance(response, str)): + self._response_cache_handler.cache_response( + self.model, messages, response + ) + + return response except LLMContextLengthExceededException: # Re-raise LLMContextLengthExceededException as it should be handled @@ -1107,3 +1126,18 @@ class LLM(BaseLLM): litellm.success_callback = success_callbacks litellm.failure_callback = failure_callbacks + + def set_response_cache_handler(self, handler): + """ + Sets the response cache handler for record/replay functionality. + + Args: + handler: An instance of LLMResponseCacheHandler. + """ + self._response_cache_handler = handler + + def clear_response_cache_handler(self): + """ + Clears the response cache handler. + """ + self._response_cache_handler = None diff --git a/src/crewai/memory/storage/llm_response_cache_storage.py b/src/crewai/memory/storage/llm_response_cache_storage.py new file mode 100644 index 000000000..350247c7d --- /dev/null +++ b/src/crewai/memory/storage/llm_response_cache_storage.py @@ -0,0 +1,131 @@ +import json +import sqlite3 +import hashlib +from typing import Any, Dict, List, Optional + +from crewai.utilities import Printer +from crewai.utilities.crew_json_encoder import CrewJSONEncoder +from crewai.utilities.paths import db_storage_path + +class LLMResponseCacheStorage: + """ + SQLite storage for caching LLM responses. + Used for offline record/replay functionality. + """ + + def __init__( + self, db_path: str = f"{db_storage_path()}/llm_response_cache.db" + ) -> None: + self.db_path = db_path + self._printer: Printer = Printer() + self._initialize_db() + + def _initialize_db(self): + """ + Initializes the SQLite database and creates the llm_response_cache table + """ + try: + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS llm_response_cache ( + request_hash TEXT PRIMARY KEY, + model TEXT, + messages TEXT, + response TEXT, + timestamp DATETIME DEFAULT CURRENT_TIMESTAMP + ) + """ + ) + conn.commit() + except sqlite3.Error as e: + self._printer.print( + content=f"LLM RESPONSE CACHE ERROR: An error occurred during database initialization: {e}", + color="red", + ) + + def _compute_request_hash(self, model: str, messages: List[Dict[str, str]]) -> str: + """ + Computes a hash for the request based on the model and messages. + This hash is used as the key for caching. + + Sensitive information like API keys should not be included in the hash. 
+ """ + message_str = json.dumps(messages, sort_keys=True) + request_hash = hashlib.sha256(f"{model}:{message_str}".encode()).hexdigest() + return request_hash + + def add(self, model: str, messages: List[Dict[str, str]], response: str) -> None: + """ + Adds a response to the cache. + """ + try: + request_hash = self._compute_request_hash(model, messages) + messages_json = json.dumps(messages, cls=CrewJSONEncoder) + + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute( + """ + INSERT OR REPLACE INTO llm_response_cache + (request_hash, model, messages, response) + VALUES (?, ?, ?, ?) + """, + ( + request_hash, + model, + messages_json, + response, + ), + ) + conn.commit() + except sqlite3.Error as e: + self._printer.print( + content=f"LLM RESPONSE CACHE ERROR: Failed to add response: {e}", + color="red", + ) + + def get(self, model: str, messages: List[Dict[str, str]]) -> Optional[str]: + """ + Retrieves a response from the cache based on the model and messages. + Returns None if not found. + """ + try: + request_hash = self._compute_request_hash(model, messages) + + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute( + """ + SELECT response + FROM llm_response_cache + WHERE request_hash = ? + """, + (request_hash,), + ) + + result = cursor.fetchone() + return result[0] if result else None + + except sqlite3.Error as e: + self._printer.print( + content=f"LLM RESPONSE CACHE ERROR: Failed to retrieve response: {e}", + color="red", + ) + return None + + def delete_all(self) -> None: + """ + Deletes all records from the llm_response_cache table. + """ + try: + with sqlite3.connect(self.db_path) as conn: + cursor = conn.cursor() + cursor.execute("DELETE FROM llm_response_cache") + conn.commit() + except sqlite3.Error as e: + self._printer.print( + content=f"LLM RESPONSE CACHE ERROR: Failed to clear cache: {e}", + color="red", + ) diff --git a/src/crewai/utilities/llm_response_cache_handler.py b/src/crewai/utilities/llm_response_cache_handler.py new file mode 100644 index 000000000..e8f4d573f --- /dev/null +++ b/src/crewai/utilities/llm_response_cache_handler.py @@ -0,0 +1,69 @@ +from typing import Any, Dict, List, Optional + +from crewai.memory.storage.llm_response_cache_storage import LLMResponseCacheStorage + +class LLMResponseCacheHandler: + """ + Handler for the LLM response cache storage. + Used for record/replay functionality. + """ + + def __init__(self) -> None: + self.storage = LLMResponseCacheStorage() + self._recording = False + self._replaying = False + + def start_recording(self) -> None: + """ + Starts recording LLM responses. + """ + self._recording = True + self._replaying = False + + def start_replaying(self) -> None: + """ + Starts replaying LLM responses from the cache. + """ + self._recording = False + self._replaying = True + + def stop(self) -> None: + """ + Stops recording or replaying. + """ + self._recording = False + self._replaying = False + + def is_recording(self) -> bool: + """ + Returns whether recording is active. + """ + return self._recording + + def is_replaying(self) -> bool: + """ + Returns whether replaying is active. + """ + return self._replaying + + def cache_response(self, model: str, messages: List[Dict[str, str]], response: str) -> None: + """ + Caches an LLM response if recording is active. 
+ """ + if self._recording: + self.storage.add(model, messages, response) + + def get_cached_response(self, model: str, messages: List[Dict[str, str]]) -> Optional[str]: + """ + Retrieves a cached LLM response if replaying is active. + Returns None if not found or if replaying is not active. + """ + if self._replaying: + return self.storage.get(model, messages) + return None + + def clear_cache(self) -> None: + """ + Clears the LLM response cache. + """ + self.storage.delete_all() diff --git a/tests/llm_response_cache_test.py b/tests/llm_response_cache_test.py new file mode 100644 index 000000000..ae9efee5a --- /dev/null +++ b/tests/llm_response_cache_test.py @@ -0,0 +1,78 @@ +import pytest +from unittest.mock import MagicMock, patch + +from crewai.llm import LLM +from crewai.utilities.llm_response_cache_handler import LLMResponseCacheHandler + + +@pytest.fixture +def handler(): + handler = LLMResponseCacheHandler() + handler.storage.add = MagicMock() + handler.storage.get = MagicMock() + return handler + + +@pytest.mark.vcr(filter_headers=["authorization"]) +def test_llm_recording(handler): + handler.start_recording() + + llm = LLM(model="gpt-4o-mini") + llm.set_response_cache_handler(handler) + + messages = [{"role": "user", "content": "Hello, world!"}] + + with patch('litellm.completion') as mock_completion: + mock_completion.return_value = { + "choices": [{"message": {"content": "Hello, human!"}}] + } + + response = llm.call(messages) + + assert response == "Hello, human!" + + handler.storage.add.assert_called_once_with( + "gpt-4o-mini", messages, "Hello, human!" + ) + + +@pytest.mark.vcr(filter_headers=["authorization"]) +def test_llm_replaying(handler): + handler.start_replaying() + handler.storage.get.return_value = "Cached response" + + llm = LLM(model="gpt-4o-mini") + llm.set_response_cache_handler(handler) + + messages = [{"role": "user", "content": "Hello, world!"}] + + with patch('litellm.completion') as mock_completion: + response = llm.call(messages) + + assert response == "Cached response" + + mock_completion.assert_not_called() + + handler.storage.get.assert_called_once_with("gpt-4o-mini", messages) + + +@pytest.mark.vcr(filter_headers=["authorization"]) +def test_llm_replay_fallback(handler): + handler.start_replaying() + handler.storage.get.return_value = None + + llm = LLM(model="gpt-4o-mini") + llm.set_response_cache_handler(handler) + + messages = [{"role": "user", "content": "Hello, world!"}] + + with patch('litellm.completion') as mock_completion: + mock_completion.return_value = { + "choices": [{"message": {"content": "Hello, human!"}}] + } + + response = llm.call(messages) + + assert response == "Hello, human!" 
+ + mock_completion.assert_called_once() diff --git a/tests/record_replay_test.py b/tests/record_replay_test.py new file mode 100644 index 000000000..087f2ab50 --- /dev/null +++ b/tests/record_replay_test.py @@ -0,0 +1,90 @@ +import pytest +from unittest.mock import MagicMock, patch + +from crewai.agent import Agent +from crewai.crew import Crew +from crewai.process import Process +from crewai.task import Task + + +@pytest.mark.vcr(filter_headers=["authorization"]) +def test_crew_recording_mode(): + agent = Agent( + role="Test Agent", + goal="Test the recording functionality", + backstory="A test agent for recording LLM responses", + ) + + task = Task( + description="Return a simple response", + expected_output="A simple response", + agent=agent, + ) + + crew = Crew( + agents=[agent], + tasks=[task], + process=Process.sequential, + record_mode=True, + ) + + mock_handler = MagicMock() + crew._llm_response_cache_handler = mock_handler + + mock_llm = MagicMock() + agent.llm = mock_llm + + with patch('crewai.utilities.llm_response_cache_handler.LLMResponseCacheHandler', return_value=mock_handler): + crew.kickoff() + + mock_handler.start_recording.assert_called_once() + + mock_llm.set_response_cache_handler.assert_called_once_with(mock_handler) + + +@pytest.mark.vcr(filter_headers=["authorization"]) +def test_crew_replay_mode(): + agent = Agent( + role="Test Agent", + goal="Test the replay functionality", + backstory="A test agent for replaying LLM responses", + ) + + task = Task( + description="Return a simple response", + expected_output="A simple response", + agent=agent, + ) + + crew = Crew( + agents=[agent], + tasks=[task], + process=Process.sequential, + replay_mode=True, + ) + + mock_handler = MagicMock() + crew._llm_response_cache_handler = mock_handler + + mock_llm = MagicMock() + agent.llm = mock_llm + + with patch('crewai.utilities.llm_response_cache_handler.LLMResponseCacheHandler', return_value=mock_handler): + crew.kickoff() + + mock_handler.start_replaying.assert_called_once() + + mock_llm.set_response_cache_handler.assert_called_once_with(mock_handler) + + +@pytest.mark.vcr(filter_headers=["authorization"]) +def test_record_replay_flags_conflict(): + with pytest.raises(ValueError): + crew = Crew( + agents=[], + tasks=[], + process=Process.sequential, + record_mode=True, + replay_mode=True, + ) + crew.kickoff()