Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-08 15:48:29 +00:00)
Add Record/Replay functionality for offline processing (Issue #2759)

Adds --record/--replay flags to `crewai run`, record_mode/replay_mode fields on Crew,
a response cache handler hook on LLM, and a SQLite-backed LLMResponseCacheStorage so
recorded LLM responses can be replayed later without making network calls.

Co-Authored-By: Joe Moura <joao@crewai.com>
@@ -201,9 +201,20 @@ def install(context):
 @crewai.command()
-def run():
+@click.option(
+    "--record",
+    is_flag=True,
+    help="Record LLM responses for later replay",
+)
+@click.option(
+    "--replay",
+    is_flag=True,
+    help="Replay from recorded LLM responses without making network calls",
+)
+def run(record: bool = False, replay: bool = False):
     """Run the Crew."""
-    run_crew()
+    click.echo("Running the Crew")
+    run_crew(record=record, replay=replay)


 @crewai.command()
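For orientation, a minimal, self-contained sketch of how the new flags behave under click, exercised with click.testing.CliRunner; the standalone @click.command() and the fake_run_crew stub are stand-ins for the real crewai group and run_crew, not code from this commit:

import click
from click.testing import CliRunner


def fake_run_crew(record: bool = False, replay: bool = False) -> None:
    # Stand-in for the real run_crew(); just echoes the parsed flags.
    click.echo(f"record={record}, replay={replay}")


@click.command()
@click.option("--record", is_flag=True, help="Record LLM responses for later replay")
@click.option(
    "--replay",
    is_flag=True,
    help="Replay from recorded LLM responses without making network calls",
)
def run(record: bool = False, replay: bool = False):
    """Run the Crew."""
    click.echo("Running the Crew")
    fake_run_crew(record=record, replay=replay)


if __name__ == "__main__":
    result = CliRunner().invoke(run, ["--record"])
    print(result.output)  # "Running the Crew" then "record=True, replay=False"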
@@ -14,13 +14,17 @@ class CrewType(Enum):
     FLOW = "flow"


-def run_crew() -> None:
+def run_crew(record: bool = False, replay: bool = False) -> None:
     """
     Run the crew or flow by running a command in the UV environment.

     Starting from version 0.103.0, this command can be used to run both
     standard crews and flows. For flows, it detects the type from pyproject.toml
     and automatically runs the appropriate command.
+
+    Args:
+        record (bool, optional): Whether to record LLM responses. Defaults to False.
+        replay (bool, optional): Whether to replay from recorded LLM responses. Defaults to False.
     """
     crewai_version = get_crewai_version()
     min_required_version = "0.71.0"
@@ -44,17 +48,24 @@ def run_crew() -> None:
     click.echo(f"Running the {'Flow' if is_flow else 'Crew'}")

     # Execute the appropriate command
-    execute_command(crew_type)
+    execute_command(crew_type, record, replay)


-def execute_command(crew_type: CrewType) -> None:
+def execute_command(crew_type: CrewType, record: bool = False, replay: bool = False) -> None:
     """
     Execute the appropriate command based on crew type.

     Args:
         crew_type: The type of crew to run
+        record: Whether to record LLM responses
+        replay: Whether to replay from recorded LLM responses
     """
     command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"]

+    if record:
+        command.append("--record")
+    if replay:
+        command.append("--replay")
+
     try:
         subprocess.run(command, capture_output=False, text=True, check=True)
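As a quick illustration of the flag plumbing above, a small sketch (with a local copy of the CrewType enum, purely for illustration) showing the command lists that end up being handed to subprocess.run:

from enum import Enum
from typing import List


class CrewType(Enum):  # local copy for illustration only
    CREW = "crew"
    FLOW = "flow"


def build_command(crew_type: CrewType, record: bool = False, replay: bool = False) -> List[str]:
    # Same shape as execute_command() above, minus the subprocess call.
    command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"]
    if record:
        command.append("--record")
    if replay:
        command.append("--replay")
    return command


print(build_command(CrewType.CREW, record=True))  # ['uv', 'run', 'run_crew', '--record']
print(build_command(CrewType.FLOW, replay=True))  # ['uv', 'run', 'kickoff', '--replay']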
@@ -244,6 +244,15 @@ class Crew(FlowTrackable, BaseModel):
         default_factory=SecurityConfig,
         description="Security configuration for the crew, including fingerprinting.",
     )
+    record_mode: bool = Field(
+        default=False,
+        description="Whether to record LLM responses for later replay.",
+    )
+    replay_mode: bool = Field(
+        default=False,
+        description="Whether to replay from recorded LLM responses without making network calls.",
+    )
+    _llm_response_cache_handler: Optional[Any] = PrivateAttr(default=None)

     @field_validator("id", mode="before")
     @classmethod
@@ -633,6 +642,17 @@ class Crew(FlowTrackable, BaseModel):
         self._task_output_handler.reset()
         self._logging_color = "bold_purple"

+        if self.record_mode and self.replay_mode:
+            raise ValueError("Cannot use both record_mode and replay_mode at the same time")
+
+        if self.record_mode or self.replay_mode:
+            from crewai.utilities.llm_response_cache_handler import LLMResponseCacheHandler
+            self._llm_response_cache_handler = LLMResponseCacheHandler()
+            if self.record_mode:
+                self._llm_response_cache_handler.start_recording()
+            elif self.replay_mode:
+                self._llm_response_cache_handler.start_replaying()
+
         if inputs is not None:
             self._inputs = inputs
             self._interpolate_inputs(inputs)
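A usage sketch for the new Crew fields, assuming an agent/task pair like the ones in the tests below; the first kickoff records LLM responses into the local cache (it still calls the provider), the second replays them without network calls, and enabling both modes raises ValueError:

from crewai import Agent, Crew, Task

researcher = Agent(
    role="Researcher",
    goal="Answer a simple question",
    backstory="Illustrative agent for the record/replay sketch",
)
question = Task(
    description="Say hello",
    expected_output="A greeting",
    agent=researcher,
)

# First run: record LLM responses into the local SQLite cache.
Crew(agents=[researcher], tasks=[question], record_mode=True).kickoff()

# Later run: replay the cached responses, no network calls needed.
Crew(agents=[researcher], tasks=[question], replay_mode=True).kickoff()

# record_mode=True together with replay_mode=True raises ValueError at kickoff.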
@@ -651,6 +671,12 @@ class Crew(FlowTrackable, BaseModel):

             if not agent.step_callback:  # type: ignore # "BaseAgent" has no attribute "step_callback"
                 agent.step_callback = self.step_callback  # type: ignore # "BaseAgent" has no attribute "step_callback"

+            if self._llm_response_cache_handler:
+                if hasattr(agent, "llm") and agent.llm:
+                    agent.llm.set_response_cache_handler(self._llm_response_cache_handler)
+                if hasattr(agent, "function_calling_llm") and agent.function_calling_llm:
+                    agent.function_calling_llm.set_response_cache_handler(self._llm_response_cache_handler)
+
             agent.create_agent_executor()
@@ -1287,6 +1313,9 @@ class Crew(FlowTrackable, BaseModel):
     def _finish_execution(self, final_string_output: str) -> None:
         if self.max_rpm:
             self._rpm_controller.stop_rpm_counter()

+        if self._llm_response_cache_handler:
+            self._llm_response_cache_handler.stop()
+
     def calculate_usage_metrics(self) -> UsageMetrics:
         """Calculates and returns the usage metrics."""
@@ -296,6 +296,7 @@ class LLM(BaseLLM):
         self.additional_params = kwargs
         self.is_anthropic = self._is_anthropic_model(model)
         self.stream = stream
+        self._response_cache_handler = None

         litellm.drop_params = True
@@ -869,25 +870,43 @@ class LLM(BaseLLM):
             for message in messages:
                 if message.get("role") == "system":
                     message["role"] = "assistant"

-        # --- 5) Set up callbacks if provided
+        if self._response_cache_handler and self._response_cache_handler.is_replaying():
+            cached_response = self._response_cache_handler.get_cached_response(
+                self.model, messages
+            )
+            if cached_response:
+                # Emit completion event for the cached response
+                self._handle_emit_call_events(cached_response, LLMCallType.LLM_CALL)
+                return cached_response
+
+        # --- 6) Set up callbacks if provided
         with suppress_warnings():
             if callbacks and len(callbacks) > 0:
                 self.set_callbacks(callbacks)

             try:
-                # --- 6) Prepare parameters for the completion call
+                # --- 7) Prepare parameters for the completion call
                 params = self._prepare_completion_params(messages, tools)

-                # --- 7) Make the completion call and handle response
+                # --- 8) Make the completion call and handle response
                 if self.stream:
-                    return self._handle_streaming_response(
+                    response = self._handle_streaming_response(
                         params, callbacks, available_functions
                     )
                 else:
-                    return self._handle_non_streaming_response(
+                    response = self._handle_non_streaming_response(
                         params, callbacks, available_functions
                     )

+                if (self._response_cache_handler and
+                        self._response_cache_handler.is_recording() and
+                        isinstance(response, str)):
+                    self._response_cache_handler.cache_response(
+                        self.model, messages, response
+                    )
+
+                return response

             except LLMContextLengthExceededException:
                 # Re-raise LLMContextLengthExceededException as it should be handled
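To make the handler contract used by LLM.call explicit, here is an in-memory stand-in (hypothetical, not the shipped LLMResponseCacheHandler) exposing the same methods the code above relies on:

from typing import Dict, List, Optional, Tuple


class InMemoryResponseCache:
    """Hypothetical stand-in for LLMResponseCacheHandler, backed by a dict."""

    def __init__(self) -> None:
        self._store: Dict[Tuple[str, str], str] = {}
        self._mode: Optional[str] = None

    def start_recording(self) -> None:
        self._mode = "record"

    def start_replaying(self) -> None:
        self._mode = "replay"

    def is_recording(self) -> bool:
        return self._mode == "record"

    def is_replaying(self) -> bool:
        return self._mode == "replay"

    def cache_response(self, model: str, messages: List[Dict[str, str]], response: str) -> None:
        # Key on the model plus a stable representation of the messages.
        self._store[(model, repr(messages))] = response

    def get_cached_response(self, model: str, messages: List[Dict[str, str]]) -> Optional[str]:
        return self._store.get((model, repr(messages)))

Anything with this shape could be passed to set_response_cache_handler() below; the real handler persists to SQLite instead of a dict.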
@@ -1107,3 +1126,18 @@ class LLM(BaseLLM):

         litellm.success_callback = success_callbacks
         litellm.failure_callback = failure_callbacks
+
+    def set_response_cache_handler(self, handler):
+        """
+        Sets the response cache handler for record/replay functionality.
+
+        Args:
+            handler: An instance of LLMResponseCacheHandler.
+        """
+        self._response_cache_handler = handler
+
+    def clear_response_cache_handler(self):
+        """
+        Clears the response cache handler.
+        """
+        self._response_cache_handler = None
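Typical attach/detach flow for the new methods, as a sketch assuming the modules added in this commit are importable:

from crewai.llm import LLM
from crewai.utilities.llm_response_cache_handler import LLMResponseCacheHandler

handler = LLMResponseCacheHandler()
handler.start_recording()

llm = LLM(model="gpt-4o-mini")
llm.set_response_cache_handler(handler)  # responses from llm.call() are now cached
# ... run calls ...
handler.stop()
llm.clear_response_cache_handler()       # detach when done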
src/crewai/memory/storage/llm_response_cache_storage.py  (new file, 131 lines)
@@ -0,0 +1,131 @@
import json
import sqlite3
import hashlib
from typing import Any, Dict, List, Optional

from crewai.utilities import Printer
from crewai.utilities.crew_json_encoder import CrewJSONEncoder
from crewai.utilities.paths import db_storage_path


class LLMResponseCacheStorage:
    """
    SQLite storage for caching LLM responses.
    Used for offline record/replay functionality.
    """

    def __init__(
        self, db_path: str = f"{db_storage_path()}/llm_response_cache.db"
    ) -> None:
        self.db_path = db_path
        self._printer: Printer = Printer()
        self._initialize_db()

    def _initialize_db(self):
        """
        Initializes the SQLite database and creates the llm_response_cache table
        """
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute(
                    """
                    CREATE TABLE IF NOT EXISTS llm_response_cache (
                        request_hash TEXT PRIMARY KEY,
                        model TEXT,
                        messages TEXT,
                        response TEXT,
                        timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
                    )
                    """
                )
                conn.commit()
        except sqlite3.Error as e:
            self._printer.print(
                content=f"LLM RESPONSE CACHE ERROR: An error occurred during database initialization: {e}",
                color="red",
            )

    def _compute_request_hash(self, model: str, messages: List[Dict[str, str]]) -> str:
        """
        Computes a hash for the request based on the model and messages.
        This hash is used as the key for caching.

        Sensitive information like API keys should not be included in the hash.
        """
        message_str = json.dumps(messages, sort_keys=True)
        request_hash = hashlib.sha256(f"{model}:{message_str}".encode()).hexdigest()
        return request_hash

    def add(self, model: str, messages: List[Dict[str, str]], response: str) -> None:
        """
        Adds a response to the cache.
        """
        try:
            request_hash = self._compute_request_hash(model, messages)
            messages_json = json.dumps(messages, cls=CrewJSONEncoder)

            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute(
                    """
                    INSERT OR REPLACE INTO llm_response_cache
                    (request_hash, model, messages, response)
                    VALUES (?, ?, ?, ?)
                    """,
                    (
                        request_hash,
                        model,
                        messages_json,
                        response,
                    ),
                )
                conn.commit()
        except sqlite3.Error as e:
            self._printer.print(
                content=f"LLM RESPONSE CACHE ERROR: Failed to add response: {e}",
                color="red",
            )

    def get(self, model: str, messages: List[Dict[str, str]]) -> Optional[str]:
        """
        Retrieves a response from the cache based on the model and messages.
        Returns None if not found.
        """
        try:
            request_hash = self._compute_request_hash(model, messages)

            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute(
                    """
                    SELECT response
                    FROM llm_response_cache
                    WHERE request_hash = ?
                    """,
                    (request_hash,),
                )

                result = cursor.fetchone()
                return result[0] if result else None

        except sqlite3.Error as e:
            self._printer.print(
                content=f"LLM RESPONSE CACHE ERROR: Failed to retrieve response: {e}",
                color="red",
            )
            return None

    def delete_all(self) -> None:
        """
        Deletes all records from the llm_response_cache table.
        """
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()
                cursor.execute("DELETE FROM llm_response_cache")
                conn.commit()
        except sqlite3.Error as e:
            self._printer.print(
                content=f"LLM RESPONSE CACHE ERROR: Failed to clear cache: {e}",
                color="red",
            )
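A small sketch of the storage in isolation, passing an explicit db_path inside a temporary directory so the default cache under db_storage_path() is left untouched; the cache key is the sha256 of the model plus the canonical JSON of the messages, so identical requests resolve to the same row:

import os
import tempfile

from crewai.memory.storage.llm_response_cache_storage import LLMResponseCacheStorage

with tempfile.TemporaryDirectory() as tmp:
    storage = LLMResponseCacheStorage(db_path=os.path.join(tmp, "llm_response_cache.db"))

    messages = [{"role": "user", "content": "Hello, world!"}]
    storage.add("gpt-4o-mini", messages, "Hello, human!")

    # Same model/messages pair resolves to the same row; anything else misses.
    assert storage.get("gpt-4o-mini", messages) == "Hello, human!"
    assert storage.get("gpt-4o-mini", [{"role": "user", "content": "Different"}]) is None

    storage.delete_all()
    assert storage.get("gpt-4o-mini", messages) is None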
src/crewai/utilities/llm_response_cache_handler.py  (new file, 69 lines)
@@ -0,0 +1,69 @@
from typing import Any, Dict, List, Optional

from crewai.memory.storage.llm_response_cache_storage import LLMResponseCacheStorage


class LLMResponseCacheHandler:
    """
    Handler for the LLM response cache storage.
    Used for record/replay functionality.
    """

    def __init__(self) -> None:
        self.storage = LLMResponseCacheStorage()
        self._recording = False
        self._replaying = False

    def start_recording(self) -> None:
        """
        Starts recording LLM responses.
        """
        self._recording = True
        self._replaying = False

    def start_replaying(self) -> None:
        """
        Starts replaying LLM responses from the cache.
        """
        self._recording = False
        self._replaying = True

    def stop(self) -> None:
        """
        Stops recording or replaying.
        """
        self._recording = False
        self._replaying = False

    def is_recording(self) -> bool:
        """
        Returns whether recording is active.
        """
        return self._recording

    def is_replaying(self) -> bool:
        """
        Returns whether replaying is active.
        """
        return self._replaying

    def cache_response(self, model: str, messages: List[Dict[str, str]], response: str) -> None:
        """
        Caches an LLM response if recording is active.
        """
        if self._recording:
            self.storage.add(model, messages, response)

    def get_cached_response(self, model: str, messages: List[Dict[str, str]]) -> Optional[str]:
        """
        Retrieves a cached LLM response if replaying is active.
        Returns None if not found or if replaying is not active.
        """
        if self._replaying:
            return self.storage.get(model, messages)
        return None

    def clear_cache(self) -> None:
        """
        Clears the LLM response cache.
        """
        self.storage.delete_all()
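And a round trip through the handler itself, as a sketch; cache_response() and get_cached_response() are deliberately no-ops unless the matching mode is active:

from crewai.utilities.llm_response_cache_handler import LLMResponseCacheHandler

handler = LLMResponseCacheHandler()  # uses the default cache DB under db_storage_path()
messages = [{"role": "user", "content": "Hello, world!"}]

handler.start_recording()
handler.cache_response("gpt-4o-mini", messages, "Hello, human!")  # persisted

handler.start_replaying()
print(handler.get_cached_response("gpt-4o-mini", messages))       # "Hello, human!"

handler.stop()
print(handler.get_cached_response("gpt-4o-mini", messages))       # None (not replaying)

handler.clear_cache()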
tests/llm_response_cache_test.py  (new file, 78 lines)
@@ -0,0 +1,78 @@
import pytest
from unittest.mock import MagicMock, patch

from crewai.llm import LLM
from crewai.utilities.llm_response_cache_handler import LLMResponseCacheHandler


@pytest.fixture
def handler():
    handler = LLMResponseCacheHandler()
    handler.storage.add = MagicMock()
    handler.storage.get = MagicMock()
    return handler


@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_recording(handler):
    handler.start_recording()

    llm = LLM(model="gpt-4o-mini")
    llm.set_response_cache_handler(handler)

    messages = [{"role": "user", "content": "Hello, world!"}]

    with patch('litellm.completion') as mock_completion:
        mock_completion.return_value = {
            "choices": [{"message": {"content": "Hello, human!"}}]
        }

        response = llm.call(messages)

        assert response == "Hello, human!"

        handler.storage.add.assert_called_once_with(
            "gpt-4o-mini", messages, "Hello, human!"
        )


@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_replaying(handler):
    handler.start_replaying()
    handler.storage.get.return_value = "Cached response"

    llm = LLM(model="gpt-4o-mini")
    llm.set_response_cache_handler(handler)

    messages = [{"role": "user", "content": "Hello, world!"}]

    with patch('litellm.completion') as mock_completion:
        response = llm.call(messages)

        assert response == "Cached response"

        mock_completion.assert_not_called()

    handler.storage.get.assert_called_once_with("gpt-4o-mini", messages)


@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_replay_fallback(handler):
    handler.start_replaying()
    handler.storage.get.return_value = None

    llm = LLM(model="gpt-4o-mini")
    llm.set_response_cache_handler(handler)

    messages = [{"role": "user", "content": "Hello, world!"}]

    with patch('litellm.completion') as mock_completion:
        mock_completion.return_value = {
            "choices": [{"message": {"content": "Hello, human!"}}]
        }

        response = llm.call(messages)

        assert response == "Hello, human!"

        mock_completion.assert_called_once()
tests/record_replay_test.py  (new file, 90 lines)
@@ -0,0 +1,90 @@
import pytest
from unittest.mock import MagicMock, patch

from crewai.agent import Agent
from crewai.crew import Crew
from crewai.process import Process
from crewai.task import Task


@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_recording_mode():
    agent = Agent(
        role="Test Agent",
        goal="Test the recording functionality",
        backstory="A test agent for recording LLM responses",
    )

    task = Task(
        description="Return a simple response",
        expected_output="A simple response",
        agent=agent,
    )

    crew = Crew(
        agents=[agent],
        tasks=[task],
        process=Process.sequential,
        record_mode=True,
    )

    mock_handler = MagicMock()
    crew._llm_response_cache_handler = mock_handler

    mock_llm = MagicMock()
    agent.llm = mock_llm

    with patch('crewai.utilities.llm_response_cache_handler.LLMResponseCacheHandler', return_value=mock_handler):
        crew.kickoff()

    mock_handler.start_recording.assert_called_once()

    mock_llm.set_response_cache_handler.assert_called_once_with(mock_handler)


@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_replay_mode():
    agent = Agent(
        role="Test Agent",
        goal="Test the replay functionality",
        backstory="A test agent for replaying LLM responses",
    )

    task = Task(
        description="Return a simple response",
        expected_output="A simple response",
        agent=agent,
    )

    crew = Crew(
        agents=[agent],
        tasks=[task],
        process=Process.sequential,
        replay_mode=True,
    )

    mock_handler = MagicMock()
    crew._llm_response_cache_handler = mock_handler

    mock_llm = MagicMock()
    agent.llm = mock_llm

    with patch('crewai.utilities.llm_response_cache_handler.LLMResponseCacheHandler', return_value=mock_handler):
        crew.kickoff()

    mock_handler.start_replaying.assert_called_once()

    mock_llm.set_response_cache_handler.assert_called_once_with(mock_handler)


@pytest.mark.vcr(filter_headers=["authorization"])
def test_record_replay_flags_conflict():
    with pytest.raises(ValueError):
        crew = Crew(
            agents=[],
            tasks=[],
            process=Process.sequential,
            record_mode=True,
            replay_mode=True,
        )
        crew.kickoff()
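The tests in this commit mock the storage layer; below is a sketch of an end-to-end variant that instead isolates the SQLite file with pytest's tmp_path fixture (re-pointing handler.storage is an assumption for illustration, not an API the commit provides):

import pytest

from crewai.memory.storage.llm_response_cache_storage import LLMResponseCacheStorage
from crewai.utilities.llm_response_cache_handler import LLMResponseCacheHandler


@pytest.fixture
def isolated_handler(tmp_path):
    # Point the handler's storage at a throwaway database file.
    handler = LLMResponseCacheHandler()
    handler.storage = LLMResponseCacheStorage(db_path=str(tmp_path / "cache.db"))
    return handler


def test_record_then_replay_roundtrip(isolated_handler):
    messages = [{"role": "user", "content": "Hello, world!"}]

    isolated_handler.start_recording()
    isolated_handler.cache_response("gpt-4o-mini", messages, "Hello, human!")

    isolated_handler.start_replaying()
    assert isolated_handler.get_cached_response("gpt-4o-mini", messages) == "Hello, human!"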