Add Record/Replay functionality for offline processing (Issue #2759)

Co-Authored-By: Joe Moura <joao@crewai.com>
Author: Devin AI
Date: 2025-05-05 22:23:03 +00:00
parent dabf02a90d
commit d5dfd5a1f5
8 changed files with 463 additions and 10 deletions

View File

@@ -201,9 +201,20 @@ def install(context):
@crewai.command()
def run():
@click.option(
"--record",
is_flag=True,
help="Record LLM responses for later replay",
)
@click.option(
"--replay",
is_flag=True,
help="Replay from recorded LLM responses without making network calls",
)
def run(record: bool = False, replay: bool = False):
"""Run the Crew."""
run_crew()
click.echo("Running the Crew")
run_crew(record=record, replay=replay)
@crewai.command()
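
A quick way to sanity-check the new flags is click's test runner. A minimal sketch, assuming the command object is importable as crewai.cli.cli.run (the exact module path is an assumption) and with run_crew patched out so no subprocess is started:

    from unittest.mock import patch
    from click.testing import CliRunner

    from crewai.cli.cli import run  # assumed import path for the click command

    runner = CliRunner()
    with patch("crewai.cli.cli.run_crew") as mock_run_crew:
        result = runner.invoke(run, ["--record"])

    assert result.exit_code == 0
    # The flag values are forwarded straight into run_crew(record=..., replay=...).
    mock_run_crew.assert_called_once_with(record=True, replay=False)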

View File

@@ -14,13 +14,17 @@ class CrewType(Enum):
FLOW = "flow"
def run_crew() -> None:
def run_crew(record: bool = False, replay: bool = False) -> None:
"""
Run the crew or flow by running a command in the UV environment.
Starting from version 0.103.0, this command can be used to run both
standard crews and flows. For flows, it detects the type from pyproject.toml
and automatically runs the appropriate command.
Args:
record (bool, optional): Whether to record LLM responses. Defaults to False.
replay (bool, optional): Whether to replay from recorded LLM responses. Defaults to False.
"""
crewai_version = get_crewai_version()
min_required_version = "0.71.0"
@@ -44,17 +48,24 @@ def run_crew() -> None:
click.echo(f"Running the {'Flow' if is_flow else 'Crew'}")
# Execute the appropriate command
execute_command(crew_type)
execute_command(crew_type, record, replay)
def execute_command(crew_type: CrewType) -> None:
def execute_command(crew_type: CrewType, record: bool = False, replay: bool = False) -> None:
"""
Execute the appropriate command based on crew type.
Args:
crew_type: The type of crew to run
record: Whether to record LLM responses
replay: Whether to replay from recorded LLM responses
"""
command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"]
if record:
command.append("--record")
if replay:
command.append("--replay")
try:
subprocess.run(command, capture_output=False, text=True, check=True)
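
The flags only affect the argv handed to subprocess.run. A standalone sketch of the same list-building logic (build_command is an illustrative name, not the shipped helper):

    from typing import List

    def build_command(is_flow: bool, record: bool = False, replay: bool = False) -> List[str]:
        # Mirror execute_command: choose the UV entry point, then append the flags.
        command = ["uv", "run", "kickoff" if is_flow else "run_crew"]
        if record:
            command.append("--record")
        if replay:
            command.append("--replay")
        return command

    assert build_command(is_flow=False, record=True) == ["uv", "run", "run_crew", "--record"]
    assert build_command(is_flow=True, replay=True) == ["uv", "run", "kickoff", "--replay"]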

View File

@@ -244,6 +244,15 @@ class Crew(FlowTrackable, BaseModel):
default_factory=SecurityConfig,
description="Security configuration for the crew, including fingerprinting.",
)
record_mode: bool = Field(
default=False,
description="Whether to record LLM responses for later replay.",
)
replay_mode: bool = Field(
default=False,
description="Whether to replay from recorded LLM responses without making network calls.",
)
_llm_response_cache_handler: Optional[Any] = PrivateAttr(default=None)
@field_validator("id", mode="before")
@classmethod
@@ -633,6 +642,17 @@ class Crew(FlowTrackable, BaseModel):
self._task_output_handler.reset()
self._logging_color = "bold_purple"
if self.record_mode and self.replay_mode:
raise ValueError("Cannot use both record_mode and replay_mode at the same time")
if self.record_mode or self.replay_mode:
from crewai.utilities.llm_response_cache_handler import LLMResponseCacheHandler
self._llm_response_cache_handler = LLMResponseCacheHandler()
if self.record_mode:
self._llm_response_cache_handler.start_recording()
elif self.replay_mode:
self._llm_response_cache_handler.start_replaying()
if inputs is not None:
self._inputs = inputs
self._interpolate_inputs(inputs)
@@ -651,6 +671,12 @@ class Crew(FlowTrackable, BaseModel):
if not agent.step_callback: # type: ignore # "BaseAgent" has no attribute "step_callback"
agent.step_callback = self.step_callback # type: ignore # "BaseAgent" has no attribute "step_callback"
if self._llm_response_cache_handler:
if hasattr(agent, "llm") and agent.llm:
agent.llm.set_response_cache_handler(self._llm_response_cache_handler)
if hasattr(agent, "function_calling_llm") and agent.function_calling_llm:
agent.function_calling_llm.set_response_cache_handler(self._llm_response_cache_handler)
agent.create_agent_executor()
@@ -1287,6 +1313,9 @@ class Crew(FlowTrackable, BaseModel):
def _finish_execution(self, final_string_output: str) -> None:
if self.max_rpm:
self._rpm_controller.stop_rpm_counter()
if self._llm_response_cache_handler:
self._llm_response_cache_handler.stop()
def calculate_usage_metrics(self) -> UsageMetrics:
"""Calculates and returns the usage metrics."""

View File

@@ -296,6 +296,7 @@ class LLM(BaseLLM):
self.additional_params = kwargs
self.is_anthropic = self._is_anthropic_model(model)
self.stream = stream
self._response_cache_handler = None
litellm.drop_params = True
@@ -869,25 +870,43 @@ class LLM(BaseLLM):
for message in messages:
if message.get("role") == "system":
message["role"] = "assistant"
if self._response_cache_handler and self._response_cache_handler.is_replaying():
cached_response = self._response_cache_handler.get_cached_response(
self.model, messages
)
if cached_response:
# Emit completion event for the cached response
self._handle_emit_call_events(cached_response, LLMCallType.LLM_CALL)
return cached_response
# --- 5) Set up callbacks if provided
# --- 6) Set up callbacks if provided
with suppress_warnings():
if callbacks and len(callbacks) > 0:
self.set_callbacks(callbacks)
try:
# --- 6) Prepare parameters for the completion call
# --- 7) Prepare parameters for the completion call
params = self._prepare_completion_params(messages, tools)
# --- 7) Make the completion call and handle response
# --- 8) Make the completion call and handle response
if self.stream:
return self._handle_streaming_response(
response = self._handle_streaming_response(
params, callbacks, available_functions
)
else:
return self._handle_non_streaming_response(
response = self._handle_non_streaming_response(
params, callbacks, available_functions
)
if (self._response_cache_handler and
self._response_cache_handler.is_recording() and
isinstance(response, str)):
self._response_cache_handler.cache_response(
self.model, messages, response
)
return response
except LLMContextLengthExceededException:
# Re-raise LLMContextLengthExceededException as it should be handled
@@ -1107,3 +1126,18 @@ class LLM(BaseLLM):
litellm.success_callback = success_callbacks
litellm.failure_callback = failure_callbacks
def set_response_cache_handler(self, handler):
"""
Sets the response cache handler for record/replay functionality.
Args:
handler: An instance of LLMResponseCacheHandler.
"""
self._response_cache_handler = handler
def clear_response_cache_handler(self):
"""
Clears the response cache handler.
"""
self._response_cache_handler = None
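
The handler can also be attached to a bare LLM instance. A minimal record-then-replay round trip, assuming a valid API key is configured for the first, live call:

    from crewai.llm import LLM
    from crewai.utilities.llm_response_cache_handler import LLMResponseCacheHandler

    handler = LLMResponseCacheHandler()
    llm = LLM(model="gpt-4o-mini")
    llm.set_response_cache_handler(handler)

    messages = [{"role": "user", "content": "Hello, world!"}]

    handler.start_recording()
    live_answer = llm.call(messages)    # real completion; stored in the SQLite cache

    handler.start_replaying()
    cached_answer = llm.call(messages)  # same prompt; answered from the cache, no network call

    assert cached_answer == live_answer
    handler.stop()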

View File

@@ -0,0 +1,131 @@
import json
import sqlite3
import hashlib
from typing import Any, Dict, List, Optional
from crewai.utilities import Printer
from crewai.utilities.crew_json_encoder import CrewJSONEncoder
from crewai.utilities.paths import db_storage_path
class LLMResponseCacheStorage:
"""
SQLite storage for caching LLM responses.
Used for offline record/replay functionality.
"""
def __init__(
self, db_path: str = f"{db_storage_path()}/llm_response_cache.db"
) -> None:
self.db_path = db_path
self._printer: Printer = Printer()
self._initialize_db()
def _initialize_db(self):
"""
Initializes the SQLite database and creates the llm_response_cache table
"""
try:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS llm_response_cache (
request_hash TEXT PRIMARY KEY,
model TEXT,
messages TEXT,
response TEXT,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
)
"""
)
conn.commit()
except sqlite3.Error as e:
self._printer.print(
content=f"LLM RESPONSE CACHE ERROR: An error occurred during database initialization: {e}",
color="red",
)
def _compute_request_hash(self, model: str, messages: List[Dict[str, str]]) -> str:
"""
Computes a hash for the request based on the model and messages.
This hash is used as the key for caching.
Sensitive information like API keys should not be included in the hash.
"""
message_str = json.dumps(messages, sort_keys=True)
request_hash = hashlib.sha256(f"{model}:{message_str}".encode()).hexdigest()
return request_hash
def add(self, model: str, messages: List[Dict[str, str]], response: str) -> None:
"""
Adds a response to the cache.
"""
try:
request_hash = self._compute_request_hash(model, messages)
messages_json = json.dumps(messages, cls=CrewJSONEncoder)
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute(
"""
INSERT OR REPLACE INTO llm_response_cache
(request_hash, model, messages, response)
VALUES (?, ?, ?, ?)
""",
(
request_hash,
model,
messages_json,
response,
),
)
conn.commit()
except sqlite3.Error as e:
self._printer.print(
content=f"LLM RESPONSE CACHE ERROR: Failed to add response: {e}",
color="red",
)
def get(self, model: str, messages: List[Dict[str, str]]) -> Optional[str]:
"""
Retrieves a response from the cache based on the model and messages.
Returns None if not found.
"""
try:
request_hash = self._compute_request_hash(model, messages)
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT response
FROM llm_response_cache
WHERE request_hash = ?
""",
(request_hash,),
)
result = cursor.fetchone()
return result[0] if result else None
except sqlite3.Error as e:
self._printer.print(
content=f"LLM RESPONSE CACHE ERROR: Failed to retrieve response: {e}",
color="red",
)
return None
def delete_all(self) -> None:
"""
Deletes all records from the llm_response_cache table.
"""
try:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute("DELETE FROM llm_response_cache")
conn.commit()
except sqlite3.Error as e:
self._printer.print(
content=f"LLM RESPONSE CACHE ERROR: Failed to clear cache: {e}",
color="red",
)
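
The storage is keyed purely by a SHA-256 hash of the model name plus the JSON-serialized messages, so the same prompt always maps to the same row. A small sketch exercising it in isolation via the db_path constructor argument (here pointed at a throwaway directory):

    import os
    import tempfile

    from crewai.memory.storage.llm_response_cache_storage import LLMResponseCacheStorage

    db_path = os.path.join(tempfile.mkdtemp(), "llm_response_cache.db")
    storage = LLMResponseCacheStorage(db_path=db_path)

    messages = [{"role": "user", "content": "Hello, world!"}]
    storage.add("gpt-4o-mini", messages, "Hello, human!")

    # Identical model + messages -> identical request_hash -> cache hit.
    assert storage.get("gpt-4o-mini", messages) == "Hello, human!"
    # Any change to the messages changes the hash -> cache miss.
    assert storage.get("gpt-4o-mini", [{"role": "user", "content": "Bye"}]) is None

    storage.delete_all()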

View File

@@ -0,0 +1,69 @@
from typing import Any, Dict, List, Optional
from crewai.memory.storage.llm_response_cache_storage import LLMResponseCacheStorage
class LLMResponseCacheHandler:
"""
Handler for the LLM response cache storage.
Used for record/replay functionality.
"""
def __init__(self) -> None:
self.storage = LLMResponseCacheStorage()
self._recording = False
self._replaying = False
def start_recording(self) -> None:
"""
Starts recording LLM responses.
"""
self._recording = True
self._replaying = False
def start_replaying(self) -> None:
"""
Starts replaying LLM responses from the cache.
"""
self._recording = False
self._replaying = True
def stop(self) -> None:
"""
Stops recording or replaying.
"""
self._recording = False
self._replaying = False
def is_recording(self) -> bool:
"""
Returns whether recording is active.
"""
return self._recording
def is_replaying(self) -> bool:
"""
Returns whether replaying is active.
"""
return self._replaying
def cache_response(self, model: str, messages: List[Dict[str, str]], response: str) -> None:
"""
Caches an LLM response if recording is active.
"""
if self._recording:
self.storage.add(model, messages, response)
def get_cached_response(self, model: str, messages: List[Dict[str, str]]) -> Optional[str]:
"""
Retrieves a cached LLM response if replaying is active.
Returns None if not found or if replaying is not active.
"""
if self._replaying:
return self.storage.get(model, messages)
return None
def clear_cache(self) -> None:
"""
Clears the LLM response cache.
"""
self.storage.delete_all()
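
The handler is a thin state machine over the storage: cache_response only writes while recording, get_cached_response only reads while replaying, and stop() turns both into no-ops. A short sketch of that behaviour (note it touches the real on-disk cache under db_storage_path()):

    from crewai.utilities.llm_response_cache_handler import LLMResponseCacheHandler

    handler = LLMResponseCacheHandler()
    messages = [{"role": "user", "content": "ping"}]

    handler.start_recording()
    handler.cache_response("gpt-4o-mini", messages, "pong")   # written to the cache

    handler.start_replaying()
    assert handler.get_cached_response("gpt-4o-mini", messages) == "pong"

    handler.stop()
    # Neither mode is active, so lookups return None even though the row exists.
    assert handler.get_cached_response("gpt-4o-mini", messages) is None

    handler.clear_cache()   # wipe the table between recording sessions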

View File

@@ -0,0 +1,78 @@
import pytest
from unittest.mock import MagicMock, patch
from crewai.llm import LLM
from crewai.utilities.llm_response_cache_handler import LLMResponseCacheHandler
@pytest.fixture
def handler():
handler = LLMResponseCacheHandler()
handler.storage.add = MagicMock()
handler.storage.get = MagicMock()
return handler
@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_recording(handler):
handler.start_recording()
llm = LLM(model="gpt-4o-mini")
llm.set_response_cache_handler(handler)
messages = [{"role": "user", "content": "Hello, world!"}]
with patch('litellm.completion') as mock_completion:
mock_completion.return_value = {
"choices": [{"message": {"content": "Hello, human!"}}]
}
response = llm.call(messages)
assert response == "Hello, human!"
handler.storage.add.assert_called_once_with(
"gpt-4o-mini", messages, "Hello, human!"
)
@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_replaying(handler):
handler.start_replaying()
handler.storage.get.return_value = "Cached response"
llm = LLM(model="gpt-4o-mini")
llm.set_response_cache_handler(handler)
messages = [{"role": "user", "content": "Hello, world!"}]
with patch('litellm.completion') as mock_completion:
response = llm.call(messages)
assert response == "Cached response"
mock_completion.assert_not_called()
handler.storage.get.assert_called_once_with("gpt-4o-mini", messages)
@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_replay_fallback(handler):
handler.start_replaying()
handler.storage.get.return_value = None
llm = LLM(model="gpt-4o-mini")
llm.set_response_cache_handler(handler)
messages = [{"role": "user", "content": "Hello, world!"}]
with patch('litellm.completion') as mock_completion:
mock_completion.return_value = {
"choices": [{"message": {"content": "Hello, human!"}}]
}
response = llm.call(messages)
assert response == "Hello, human!"
mock_completion.assert_called_once()

View File

@@ -0,0 +1,90 @@
import pytest
from unittest.mock import MagicMock, patch
from crewai.agent import Agent
from crewai.crew import Crew
from crewai.process import Process
from crewai.task import Task
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_recording_mode():
agent = Agent(
role="Test Agent",
goal="Test the recording functionality",
backstory="A test agent for recording LLM responses",
)
task = Task(
description="Return a simple response",
expected_output="A simple response",
agent=agent,
)
crew = Crew(
agents=[agent],
tasks=[task],
process=Process.sequential,
record_mode=True,
)
mock_handler = MagicMock()
crew._llm_response_cache_handler = mock_handler
mock_llm = MagicMock()
agent.llm = mock_llm
with patch('crewai.utilities.llm_response_cache_handler.LLMResponseCacheHandler', return_value=mock_handler):
crew.kickoff()
mock_handler.start_recording.assert_called_once()
mock_llm.set_response_cache_handler.assert_called_once_with(mock_handler)
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_replay_mode():
agent = Agent(
role="Test Agent",
goal="Test the replay functionality",
backstory="A test agent for replaying LLM responses",
)
task = Task(
description="Return a simple response",
expected_output="A simple response",
agent=agent,
)
crew = Crew(
agents=[agent],
tasks=[task],
process=Process.sequential,
replay_mode=True,
)
mock_handler = MagicMock()
crew._llm_response_cache_handler = mock_handler
mock_llm = MagicMock()
agent.llm = mock_llm
with patch('crewai.utilities.llm_response_cache_handler.LLMResponseCacheHandler', return_value=mock_handler):
crew.kickoff()
mock_handler.start_replaying.assert_called_once()
mock_llm.set_response_cache_handler.assert_called_once_with(mock_handler)
@pytest.mark.vcr(filter_headers=["authorization"])
def test_record_replay_flags_conflict():
with pytest.raises(ValueError):
crew = Crew(
agents=[],
tasks=[],
process=Process.sequential,
record_mode=True,
replay_mode=True,
)
crew.kickoff()