mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-11 00:58:30 +00:00)
Add Record/Replay functionality for offline processing (Issue #2759)
Co-Authored-By: Joe Moura <joao@crewai.com>
@@ -201,9 +201,20 @@ def install(context):
 @crewai.command()
-def run():
+@click.option(
+    "--record",
+    is_flag=True,
+    help="Record LLM responses for later replay",
+)
+@click.option(
+    "--replay",
+    is_flag=True,
+    help="Replay from recorded LLM responses without making network calls",
+)
+def run(record: bool = False, replay: bool = False):
     """Run the Crew."""
-    run_crew()
+    click.echo("Running the Crew")
+    run_crew(record=record, replay=replay)
 
 
 @crewai.command()
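
A minimal sketch of how the new flags could be smoke-tested with click's test runner. The import path (crewai.cli.cli) and the group name are assumptions based on the current CLI layout, not part of this diff.

# Sketch only: confirms click registered --record/--replay on `crewai run`.
from click.testing import CliRunner

from crewai.cli.cli import crewai  # assumed location of the click command group


def test_run_exposes_record_replay_flags():
    runner = CliRunner()
    result = runner.invoke(crewai, ["run", "--help"])
    assert "--record" in result.output
    assert "--replay" in result.output
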
@@ -14,13 +14,17 @@ class CrewType(Enum):
     FLOW = "flow"
 
 
-def run_crew() -> None:
+def run_crew(record: bool = False, replay: bool = False) -> None:
     """
     Run the crew or flow by running a command in the UV environment.
 
     Starting from version 0.103.0, this command can be used to run both
     standard crews and flows. For flows, it detects the type from pyproject.toml
     and automatically runs the appropriate command.
+
+    Args:
+        record (bool, optional): Whether to record LLM responses. Defaults to False.
+        replay (bool, optional): Whether to replay from recorded LLM responses. Defaults to False.
     """
     crewai_version = get_crewai_version()
     min_required_version = "0.71.0"
@@ -44,17 +48,24 @@ def run_crew() -> None:
     click.echo(f"Running the {'Flow' if is_flow else 'Crew'}")
 
     # Execute the appropriate command
-    execute_command(crew_type)
+    execute_command(crew_type, record, replay)
 
 
-def execute_command(crew_type: CrewType) -> None:
+def execute_command(crew_type: CrewType, record: bool = False, replay: bool = False) -> None:
     """
     Execute the appropriate command based on crew type.
 
     Args:
         crew_type: The type of crew to run
+        record: Whether to record LLM responses
+        replay: Whether to replay from recorded LLM responses
     """
     command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"]
 
+    if record:
+        command.append("--record")
+    if replay:
+        command.append("--replay")
+
     try:
         subprocess.run(command, capture_output=False, text=True, check=True)
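
For reference, a standalone sketch of the argv mapping performed by execute_command above. The enum and function are re-implemented here purely for illustration; they are not the library objects themselves.

from enum import Enum


class CrewType(Enum):
    STANDARD = "standard"
    FLOW = "flow"


def build_command(crew_type: CrewType, record: bool = False, replay: bool = False) -> list:
    # Mirrors the logic in the hunk above: base command, then optional flags.
    command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"]
    if record:
        command.append("--record")
    if replay:
        command.append("--replay")
    return command


assert build_command(CrewType.STANDARD, record=True) == ["uv", "run", "run_crew", "--record"]
assert build_command(CrewType.FLOW, replay=True) == ["uv", "run", "kickoff", "--replay"]
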
@@ -244,6 +244,15 @@ class Crew(FlowTrackable, BaseModel):
         default_factory=SecurityConfig,
         description="Security configuration for the crew, including fingerprinting.",
     )
+    record_mode: bool = Field(
+        default=False,
+        description="Whether to record LLM responses for later replay.",
+    )
+    replay_mode: bool = Field(
+        default=False,
+        description="Whether to replay from recorded LLM responses without making network calls.",
+    )
+    _llm_response_cache_handler: Optional[Any] = PrivateAttr(default=None)
 
     @field_validator("id", mode="before")
     @classmethod
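
A hedged usage sketch for the two new fields; the agent and task definitions are illustrative and not taken from this diff.

from crewai import Agent, Crew, Process, Task

agent = Agent(role="Researcher", goal="Answer briefly", backstory="Example agent")
task = Task(description="Say hello", expected_output="A greeting", agent=agent)

# First run online and cache every LLM response...
recording_crew = Crew(agents=[agent], tasks=[task], process=Process.sequential, record_mode=True)
# ...then run offline against the cache. Setting both flags raises ValueError in kickoff (see below).
replaying_crew = Crew(agents=[agent], tasks=[task], process=Process.sequential, replay_mode=True)
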
@@ -633,6 +642,17 @@ class Crew(FlowTrackable, BaseModel):
         self._task_output_handler.reset()
         self._logging_color = "bold_purple"
 
+        if self.record_mode and self.replay_mode:
+            raise ValueError("Cannot use both record_mode and replay_mode at the same time")
+
+        if self.record_mode or self.replay_mode:
+            from crewai.utilities.llm_response_cache_handler import LLMResponseCacheHandler
+            self._llm_response_cache_handler = LLMResponseCacheHandler()
+            if self.record_mode:
+                self._llm_response_cache_handler.start_recording()
+            elif self.replay_mode:
+                self._llm_response_cache_handler.start_replaying()
+
         if inputs is not None:
             self._inputs = inputs
             self._interpolate_inputs(inputs)
@@ -651,6 +671,12 @@ class Crew(FlowTrackable, BaseModel):
 
             if not agent.step_callback:  # type: ignore # "BaseAgent" has no attribute "step_callback"
                 agent.step_callback = self.step_callback  # type: ignore # "BaseAgent" has no attribute "step_callback"
 
+            if self._llm_response_cache_handler:
+                if hasattr(agent, "llm") and agent.llm:
+                    agent.llm.set_response_cache_handler(self._llm_response_cache_handler)
+                if hasattr(agent, "function_calling_llm") and agent.function_calling_llm:
+                    agent.function_calling_llm.set_response_cache_handler(self._llm_response_cache_handler)
+
             agent.create_agent_executor()
 
@@ -1287,6 +1313,9 @@ class Crew(FlowTrackable, BaseModel):
     def _finish_execution(self, final_string_output: str) -> None:
         if self.max_rpm:
             self._rpm_controller.stop_rpm_counter()
 
+        if self._llm_response_cache_handler:
+            self._llm_response_cache_handler.stop()
+
     def calculate_usage_metrics(self) -> UsageMetrics:
         """Calculates and returns the usage metrics."""
@@ -296,6 +296,7 @@ class LLM(BaseLLM):
         self.additional_params = kwargs
         self.is_anthropic = self._is_anthropic_model(model)
         self.stream = stream
+        self._response_cache_handler = None
 
         litellm.drop_params = True
 
@@ -869,25 +870,43 @@ class LLM(BaseLLM):
             for message in messages:
                 if message.get("role") == "system":
                     message["role"] = "assistant"
 
+        if self._response_cache_handler and self._response_cache_handler.is_replaying():
+            cached_response = self._response_cache_handler.get_cached_response(
+                self.model, messages
+            )
+            if cached_response:
+                # Emit completion event for the cached response
+                self._handle_emit_call_events(cached_response, LLMCallType.LLM_CALL)
+                return cached_response
+
-        # --- 5) Set up callbacks if provided
+        # --- 6) Set up callbacks if provided
         with suppress_warnings():
             if callbacks and len(callbacks) > 0:
                 self.set_callbacks(callbacks)
 
         try:
-            # --- 6) Prepare parameters for the completion call
+            # --- 7) Prepare parameters for the completion call
             params = self._prepare_completion_params(messages, tools)
 
-            # --- 7) Make the completion call and handle response
+            # --- 8) Make the completion call and handle response
             if self.stream:
-                return self._handle_streaming_response(
+                response = self._handle_streaming_response(
                     params, callbacks, available_functions
                 )
             else:
-                return self._handle_non_streaming_response(
+                response = self._handle_non_streaming_response(
                     params, callbacks, available_functions
                 )
 
+            if (self._response_cache_handler and
+                self._response_cache_handler.is_recording() and
+                isinstance(response, str)):
+                self._response_cache_handler.cache_response(
+                    self.model, messages, response
+                )
+
+            return response
+
         except LLMContextLengthExceededException:
             # Re-raise LLMContextLengthExceededException as it should be handled
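
The control flow added to LLM.call can be summarized by this hypothetical helper (not part of the PR): replay short-circuits before any network call, and record stores the response after a successful one.

from typing import Callable, Optional


def call_with_cache(handler, model: str, messages: list, do_call: Callable[[], str]) -> str:
    # Hypothetical restatement of the branch added above, for clarity only.
    if handler and handler.is_replaying():
        cached: Optional[str] = handler.get_cached_response(model, messages)
        if cached:
            return cached
    response = do_call()
    if handler and handler.is_recording() and isinstance(response, str):
        handler.cache_response(model, messages, response)
    return response
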
@@ -1107,3 +1126,18 @@ class LLM(BaseLLM):
 
         litellm.success_callback = success_callbacks
         litellm.failure_callback = failure_callbacks
+
+    def set_response_cache_handler(self, handler):
+        """
+        Sets the response cache handler for record/replay functionality.
+
+        Args:
+            handler: An instance of LLMResponseCacheHandler.
+        """
+        self._response_cache_handler = handler
+
+    def clear_response_cache_handler(self):
+        """
+        Clears the response cache handler.
+        """
+        self._response_cache_handler = None
src/crewai/memory/storage/llm_response_cache_storage.py (new file, 131 lines)
@@ -0,0 +1,131 @@
+import json
+import sqlite3
+import hashlib
+from typing import Any, Dict, List, Optional
+
+from crewai.utilities import Printer
+from crewai.utilities.crew_json_encoder import CrewJSONEncoder
+from crewai.utilities.paths import db_storage_path
+
+
+class LLMResponseCacheStorage:
+    """
+    SQLite storage for caching LLM responses.
+    Used for offline record/replay functionality.
+    """
+
+    def __init__(
+        self, db_path: str = f"{db_storage_path()}/llm_response_cache.db"
+    ) -> None:
+        self.db_path = db_path
+        self._printer: Printer = Printer()
+        self._initialize_db()
+
+    def _initialize_db(self):
+        """
+        Initializes the SQLite database and creates the llm_response_cache table
+        """
+        try:
+            with sqlite3.connect(self.db_path) as conn:
+                cursor = conn.cursor()
+                cursor.execute(
+                    """
+                    CREATE TABLE IF NOT EXISTS llm_response_cache (
+                        request_hash TEXT PRIMARY KEY,
+                        model TEXT,
+                        messages TEXT,
+                        response TEXT,
+                        timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
+                    )
+                    """
+                )
+                conn.commit()
+        except sqlite3.Error as e:
+            self._printer.print(
+                content=f"LLM RESPONSE CACHE ERROR: An error occurred during database initialization: {e}",
+                color="red",
+            )
+
+    def _compute_request_hash(self, model: str, messages: List[Dict[str, str]]) -> str:
+        """
+        Computes a hash for the request based on the model and messages.
+        This hash is used as the key for caching.
+
+        Sensitive information like API keys should not be included in the hash.
+        """
+        message_str = json.dumps(messages, sort_keys=True)
+        request_hash = hashlib.sha256(f"{model}:{message_str}".encode()).hexdigest()
+        return request_hash
+
+    def add(self, model: str, messages: List[Dict[str, str]], response: str) -> None:
+        """
+        Adds a response to the cache.
+        """
+        try:
+            request_hash = self._compute_request_hash(model, messages)
+            messages_json = json.dumps(messages, cls=CrewJSONEncoder)
+
+            with sqlite3.connect(self.db_path) as conn:
+                cursor = conn.cursor()
+                cursor.execute(
+                    """
+                    INSERT OR REPLACE INTO llm_response_cache
+                    (request_hash, model, messages, response)
+                    VALUES (?, ?, ?, ?)
+                    """,
+                    (
+                        request_hash,
+                        model,
+                        messages_json,
+                        response,
+                    ),
+                )
+                conn.commit()
+        except sqlite3.Error as e:
+            self._printer.print(
+                content=f"LLM RESPONSE CACHE ERROR: Failed to add response: {e}",
+                color="red",
+            )
+
+    def get(self, model: str, messages: List[Dict[str, str]]) -> Optional[str]:
+        """
+        Retrieves a response from the cache based on the model and messages.
+        Returns None if not found.
+        """
+        try:
+            request_hash = self._compute_request_hash(model, messages)
+
+            with sqlite3.connect(self.db_path) as conn:
+                cursor = conn.cursor()
+                cursor.execute(
+                    """
+                    SELECT response
+                    FROM llm_response_cache
+                    WHERE request_hash = ?
+                    """,
+                    (request_hash,),
+                )
+
+                result = cursor.fetchone()
+                return result[0] if result else None
+
+        except sqlite3.Error as e:
+            self._printer.print(
+                content=f"LLM RESPONSE CACHE ERROR: Failed to retrieve response: {e}",
+                color="red",
+            )
+            return None
+
+    def delete_all(self) -> None:
+        """
+        Deletes all records from the llm_response_cache table.
+        """
+        try:
+            with sqlite3.connect(self.db_path) as conn:
+                cursor = conn.cursor()
+                cursor.execute("DELETE FROM llm_response_cache")
+                conn.commit()
+        except sqlite3.Error as e:
+            self._printer.print(
+                content=f"LLM RESPONSE CACHE ERROR: Failed to clear cache: {e}",
+                color="red",
+            )
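
The cache key is a deterministic SHA-256 over the model name and the JSON-serialized messages. A standalone re-implementation for illustration (not the class method itself):

import hashlib
import json


def request_hash(model: str, messages: list) -> str:
    # Same recipe as _compute_request_hash above: stable JSON, then SHA-256.
    message_str = json.dumps(messages, sort_keys=True)
    return hashlib.sha256(f"{model}:{message_str}".encode()).hexdigest()


msgs = [{"role": "user", "content": "Hello, world!"}]
# Equal requests hash to the same key, so a replayed run finds the recorded row.
assert request_hash("gpt-4o-mini", msgs) == request_hash("gpt-4o-mini", [dict(m) for m in msgs])
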
src/crewai/utilities/llm_response_cache_handler.py (new file, 69 lines)
@@ -0,0 +1,69 @@
+from typing import Any, Dict, List, Optional
+
+from crewai.memory.storage.llm_response_cache_storage import LLMResponseCacheStorage
+
+
+class LLMResponseCacheHandler:
+    """
+    Handler for the LLM response cache storage.
+    Used for record/replay functionality.
+    """
+
+    def __init__(self) -> None:
+        self.storage = LLMResponseCacheStorage()
+        self._recording = False
+        self._replaying = False
+
+    def start_recording(self) -> None:
+        """
+        Starts recording LLM responses.
+        """
+        self._recording = True
+        self._replaying = False
+
+    def start_replaying(self) -> None:
+        """
+        Starts replaying LLM responses from the cache.
+        """
+        self._recording = False
+        self._replaying = True
+
+    def stop(self) -> None:
+        """
+        Stops recording or replaying.
+        """
+        self._recording = False
+        self._replaying = False
+
+    def is_recording(self) -> bool:
+        """
+        Returns whether recording is active.
+        """
+        return self._recording
+
+    def is_replaying(self) -> bool:
+        """
+        Returns whether replaying is active.
+        """
+        return self._replaying
+
+    def cache_response(self, model: str, messages: List[Dict[str, str]], response: str) -> None:
+        """
+        Caches an LLM response if recording is active.
+        """
+        if self._recording:
+            self.storage.add(model, messages, response)
+
+    def get_cached_response(self, model: str, messages: List[Dict[str, str]]) -> Optional[str]:
+        """
+        Retrieves a cached LLM response if replaying is active.
+        Returns None if not found or if replaying is not active.
+        """
+        if self._replaying:
+            return self.storage.get(model, messages)
+        return None
+
+    def clear_cache(self) -> None:
+        """
+        Clears the LLM response cache.
+        """
+        self.storage.delete_all()
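
A quick round-trip through the handler API introduced here, using only names added by this PR. Note it writes to the default SQLite cache path on disk.

from crewai.utilities.llm_response_cache_handler import LLMResponseCacheHandler

handler = LLMResponseCacheHandler()

handler.start_recording()
handler.cache_response("gpt-4o-mini", [{"role": "user", "content": "Hi"}], "Hello!")

handler.start_replaying()
cached = handler.get_cached_response("gpt-4o-mini", [{"role": "user", "content": "Hi"}])
assert cached == "Hello!"  # served from the SQLite cache, no network call

handler.stop()
handler.clear_cache()
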
tests/llm_response_cache_test.py (new file, 78 lines)
@@ -0,0 +1,78 @@
+import pytest
+from unittest.mock import MagicMock, patch
+
+from crewai.llm import LLM
+from crewai.utilities.llm_response_cache_handler import LLMResponseCacheHandler
+
+
+@pytest.fixture
+def handler():
+    handler = LLMResponseCacheHandler()
+    handler.storage.add = MagicMock()
+    handler.storage.get = MagicMock()
+    return handler
+
+
+@pytest.mark.vcr(filter_headers=["authorization"])
+def test_llm_recording(handler):
+    handler.start_recording()
+
+    llm = LLM(model="gpt-4o-mini")
+    llm.set_response_cache_handler(handler)
+
+    messages = [{"role": "user", "content": "Hello, world!"}]
+
+    with patch('litellm.completion') as mock_completion:
+        mock_completion.return_value = {
+            "choices": [{"message": {"content": "Hello, human!"}}]
+        }
+
+        response = llm.call(messages)
+
+    assert response == "Hello, human!"
+
+    handler.storage.add.assert_called_once_with(
+        "gpt-4o-mini", messages, "Hello, human!"
+    )
+
+
+@pytest.mark.vcr(filter_headers=["authorization"])
+def test_llm_replaying(handler):
+    handler.start_replaying()
+    handler.storage.get.return_value = "Cached response"
+
+    llm = LLM(model="gpt-4o-mini")
+    llm.set_response_cache_handler(handler)
+
+    messages = [{"role": "user", "content": "Hello, world!"}]
+
+    with patch('litellm.completion') as mock_completion:
+        response = llm.call(messages)
+
+    assert response == "Cached response"
+
+    mock_completion.assert_not_called()
+
+    handler.storage.get.assert_called_once_with("gpt-4o-mini", messages)
+
+
+@pytest.mark.vcr(filter_headers=["authorization"])
+def test_llm_replay_fallback(handler):
+    handler.start_replaying()
+    handler.storage.get.return_value = None
+
+    llm = LLM(model="gpt-4o-mini")
+    llm.set_response_cache_handler(handler)
+
+    messages = [{"role": "user", "content": "Hello, world!"}]
+
+    with patch('litellm.completion') as mock_completion:
+        mock_completion.return_value = {
+            "choices": [{"message": {"content": "Hello, human!"}}]
+        }
+
+        response = llm.call(messages)
+
+    assert response == "Hello, human!"
+
+    mock_completion.assert_called_once()
tests/record_replay_test.py (new file, 90 lines)
@@ -0,0 +1,90 @@
+import pytest
+from unittest.mock import MagicMock, patch
+
+from crewai.agent import Agent
+from crewai.crew import Crew
+from crewai.process import Process
+from crewai.task import Task
+
+
+@pytest.mark.vcr(filter_headers=["authorization"])
+def test_crew_recording_mode():
+    agent = Agent(
+        role="Test Agent",
+        goal="Test the recording functionality",
+        backstory="A test agent for recording LLM responses",
+    )
+
+    task = Task(
+        description="Return a simple response",
+        expected_output="A simple response",
+        agent=agent,
+    )
+
+    crew = Crew(
+        agents=[agent],
+        tasks=[task],
+        process=Process.sequential,
+        record_mode=True,
+    )
+
+    mock_handler = MagicMock()
+    crew._llm_response_cache_handler = mock_handler
+
+    mock_llm = MagicMock()
+    agent.llm = mock_llm
+
+    with patch('crewai.utilities.llm_response_cache_handler.LLMResponseCacheHandler', return_value=mock_handler):
+        crew.kickoff()
+
+    mock_handler.start_recording.assert_called_once()
+
+    mock_llm.set_response_cache_handler.assert_called_once_with(mock_handler)
+
+
+@pytest.mark.vcr(filter_headers=["authorization"])
+def test_crew_replay_mode():
+    agent = Agent(
+        role="Test Agent",
+        goal="Test the replay functionality",
+        backstory="A test agent for replaying LLM responses",
+    )
+
+    task = Task(
+        description="Return a simple response",
+        expected_output="A simple response",
+        agent=agent,
+    )
+
+    crew = Crew(
+        agents=[agent],
+        tasks=[task],
+        process=Process.sequential,
+        replay_mode=True,
+    )
+
+    mock_handler = MagicMock()
+    crew._llm_response_cache_handler = mock_handler
+
+    mock_llm = MagicMock()
+    agent.llm = mock_llm
+
+    with patch('crewai.utilities.llm_response_cache_handler.LLMResponseCacheHandler', return_value=mock_handler):
+        crew.kickoff()
+
+    mock_handler.start_replaying.assert_called_once()
+
+    mock_llm.set_response_cache_handler.assert_called_once_with(mock_handler)
+
+
+@pytest.mark.vcr(filter_headers=["authorization"])
+def test_record_replay_flags_conflict():
+    with pytest.raises(ValueError):
+        crew = Crew(
+            agents=[],
+            tasks=[],
+            process=Process.sequential,
+            record_mode=True,
+            replay_mode=True,
+        )
+        crew.kickoff()
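
End to end, the intended workflow looks roughly like this; a sketch reusing the shapes from the tests above, with illustrative agent and task contents.

from crewai import Agent, Crew, Process, Task


def make_crew(**modes) -> Crew:
    agent = Agent(role="Test Agent", goal="Answer briefly", backstory="Example agent")
    task = Task(description="Return a simple response", expected_output="A simple response", agent=agent)
    return Crew(agents=[agent], tasks=[task], process=Process.sequential, **modes)


make_crew(record_mode=True).kickoff()  # online run: responses are cached to SQLite
make_crew(replay_mode=True).kickoff()  # offline run: responses served from the cache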