Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-04 21:58:29 +00:00.

Compare commits: devin/1746 ... devin/1746 (8 commits)
| SHA1 |
|---|
| 79547fba25 |
| 171f8b63fd |
| 72df165b07 |
| 63eccf5e30 |
| a98a44afb2 |
| 6e0f1fe38d |
| c2bf2b3210 |
| 14579a7861 |
manual_test_csv_update.py (new file, 92 lines)
@@ -0,0 +1,92 @@
"""
Manual test script to verify CSV knowledge source update functionality.

This script creates a CSV file, loads it as a knowledge source, updates the file,
and verifies that the updated content is detected and loaded.
"""

import os
import sys
import tempfile
import time
from pathlib import Path

from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.csv_knowledge_source import CSVKnowledgeSource


def test_csv_knowledge_source_updates():
    """Test that CSVKnowledgeSource properly detects and loads updates to CSV files."""
    with tempfile.TemporaryDirectory() as tmpdir:
        csv_path = Path(tmpdir) / "test_updates.csv"

        initial_csv_content = [
            ["name", "age", "city"],
            ["John", "30", "New York"],
            ["Alice", "25", "San Francisco"],
            ["Bob", "28", "Chicago"],
        ]

        with open(csv_path, "w") as f:
            for row in initial_csv_content:
                f.write(",".join(row) + "\n")

        print(f"Created CSV file at {csv_path}")

        csv_source = CSVKnowledgeSource(file_paths=[csv_path])

        if not hasattr(csv_source, 'files_have_changed'):
            print("❌ TEST FAILED: files_have_changed method not found in CSVKnowledgeSource")
            return False

        if not hasattr(csv_source, '_file_mtimes'):
            print("❌ TEST FAILED: _file_mtimes attribute not found in CSVKnowledgeSource")
            return False

        knowledge = Knowledge(sources=[csv_source], collection_name="test_updates")

        if not hasattr(knowledge, '_check_and_reload_sources'):
            print("❌ TEST FAILED: _check_and_reload_sources method not found in Knowledge")
            return False

        print("✅ All required methods and attributes exist")

        updated_csv_content = [
            ["name", "age", "city"],
            ["John", "30", "Boston"],  # Changed city
            ["Alice", "25", "San Francisco"],
            ["Bob", "28", "Chicago"],
            ["Eve", "22", "Miami"],  # Added new person
        ]

        print("\nWaiting for 1 second before updating file...")
        time.sleep(1)

        with open(csv_path, "w") as f:
            for row in updated_csv_content:
                f.write(",".join(row) + "\n")

        print(f"Updated CSV file at {csv_path}")

        if not csv_source.files_have_changed():
            print("❌ TEST FAILED: files_have_changed did not detect file modification")
            return False

        print("✅ files_have_changed correctly detected file modification")

        csv_source._record_file_mtimes()
        csv_source.content = csv_source.load_content()

        content_str = str(csv_source.content)
        if "Boston" in content_str and "Eve" in content_str and "Miami" in content_str:
            print("✅ Content was correctly updated with new data")
        else:
            print("❌ TEST FAILED: Content was not updated with new data")
            return False

        print("\n✅ TEST PASSED: CSV knowledge source correctly detects and loads file updates")
        return True


if __name__ == "__main__":
    success = test_csv_knowledge_source_updates()
    sys.exit(0 if success else 1)
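For orientation, a minimal sketch of the flow this script exercises, using only the APIs introduced in this diff (CSVKnowledgeSource and Knowledge.query); the file name is hypothetical and a working embedding/storage backend is assumed, so treat it as illustrative rather than a drop-in snippet:

```python
from pathlib import Path

from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.csv_knowledge_source import CSVKnowledgeSource

csv_path = Path("people.csv")  # hypothetical file for the example
csv_path.write_text("name,age,city\nJohn,30,New York\n")

source = CSVKnowledgeSource(file_paths=[csv_path])
knowledge = Knowledge(sources=[source], collection_name="people")

# First query searches the content loaded at construction time.
print(knowledge.query(["John"]))

# Editing the file bumps its mtime; the next query runs
# _check_and_reload_sources(), sees files_have_changed() return True,
# and reloads the source before searching.
csv_path.write_text("name,age,city\nJohn,30,Boston\n")
print(knowledge.query(["John"]))
```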
@@ -201,22 +201,9 @@ def install(context):


 @crewai.command()
-@click.option(
-    "--record",
-    is_flag=True,
-    help="Record LLM responses for later replay",
-)
-@click.option(
-    "--replay",
-    is_flag=True,
-    help="Replay from recorded LLM responses without making network calls",
-)
-def run(record: bool = False, replay: bool = False):
+def run():
     """Run the Crew."""
-    if record and replay:
-        raise click.UsageError("Cannot use --record and --replay simultaneously")
     click.echo("Running the Crew")
-    run_crew(record=record, replay=replay)
+    run_crew()


 @crewai.command()
@@ -14,17 +14,13 @@ class CrewType(Enum):
     FLOW = "flow"


-def run_crew(record: bool = False, replay: bool = False) -> None:
+def run_crew() -> None:
     """
     Run the crew or flow by running a command in the UV environment.

     Starting from version 0.103.0, this command can be used to run both
     standard crews and flows. For flows, it detects the type from pyproject.toml
     and automatically runs the appropriate command.
-
-    Args:
-        record (bool, optional): Whether to record LLM responses. Defaults to False.
-        replay (bool, optional): Whether to replay from recorded LLM responses. Defaults to False.
     """
     crewai_version = get_crewai_version()
     min_required_version = "0.71.0"
@@ -48,24 +44,17 @@ def run_crew(record: bool = False, replay: bool = False) -> None:
     click.echo(f"Running the {'Flow' if is_flow else 'Crew'}")

     # Execute the appropriate command
-    execute_command(crew_type, record, replay)
+    execute_command(crew_type)


-def execute_command(crew_type: CrewType, record: bool = False, replay: bool = False) -> None:
+def execute_command(crew_type: CrewType) -> None:
     """
     Execute the appropriate command based on crew type.

     Args:
         crew_type: The type of crew to run
-        record: Whether to record LLM responses
-        replay: Whether to replay from recorded LLM responses
     """
     command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"]

-    if record:
-        command.append("--record")
-    if replay:
-        command.append("--replay")
-
     try:
         subprocess.run(command, capture_output=False, text=True, check=True)
@@ -244,15 +244,6 @@ class Crew(FlowTrackable, BaseModel):
         default_factory=SecurityConfig,
         description="Security configuration for the crew, including fingerprinting.",
     )
-    record_mode: bool = Field(
-        default=False,
-        description="Whether to record LLM responses for later replay.",
-    )
-    replay_mode: bool = Field(
-        default=False,
-        description="Whether to replay from recorded LLM responses without making network calls.",
-    )
-    _llm_response_cache_handler: Optional[Any] = PrivateAttr(default=None)

     @field_validator("id", mode="before")
     @classmethod
@@ -642,19 +633,6 @@ class Crew(FlowTrackable, BaseModel):
         self._task_output_handler.reset()
         self._logging_color = "bold_purple"

-        if self.record_mode and self.replay_mode:
-            raise ValueError("Cannot use both record_mode and replay_mode at the same time")
-
-        if self.record_mode or self.replay_mode:
-            from crewai.utilities.llm_response_cache_handler import (
-                LLMResponseCacheHandler,
-            )
-            self._llm_response_cache_handler = LLMResponseCacheHandler()
-            if self.record_mode:
-                self._llm_response_cache_handler.start_recording()
-            elif self.replay_mode:
-                self._llm_response_cache_handler.start_replaying()
-
         if inputs is not None:
             self._inputs = inputs
             self._interpolate_inputs(inputs)
@@ -673,12 +651,6 @@ class Crew(FlowTrackable, BaseModel):
             if not agent.step_callback:  # type: ignore # "BaseAgent" has no attribute "step_callback"
                 agent.step_callback = self.step_callback  # type: ignore # "BaseAgent" has no attribute "step_callback"

-            if self._llm_response_cache_handler:
-                if hasattr(agent, "llm") and agent.llm:
-                    agent.llm.set_response_cache_handler(self._llm_response_cache_handler)
-                if hasattr(agent, "function_calling_llm") and agent.function_calling_llm:
-                    agent.function_calling_llm.set_response_cache_handler(self._llm_response_cache_handler)
-
             agent.create_agent_executor()
@@ -1315,9 +1287,6 @@ class Crew(FlowTrackable, BaseModel):
     def _finish_execution(self, final_string_output: str) -> None:
         if self.max_rpm:
             self._rpm_controller.stop_rpm_counter()

-        if self._llm_response_cache_handler:
-            self._llm_response_cache_handler.stop()
-
     def calculate_usage_metrics(self) -> UsageMetrics:
         """Calculates and returns the usage metrics."""
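For reference, the removed record_mode/replay_mode fields were driven roughly as in the sketch below, mirroring the deleted tests at the end of this diff; since this comparison deletes the feature, the snippet documents the old surface only and will not run against the new code:

```python
from crewai.agent import Agent
from crewai.crew import Crew
from crewai.process import Process
from crewai.task import Task

agent = Agent(
    role="Test Agent",
    goal="Test the recording functionality",
    backstory="A test agent for recording LLM responses",
)
task = Task(
    description="Return a simple response",
    expected_output="A simple response",
    agent=agent,
)

# Record LLM responses into the SQLite cache during this run.
Crew(agents=[agent], tasks=[task], process=Process.sequential, record_mode=True).kickoff()

# Replay the same run from the cache, without making network calls.
Crew(agents=[agent], tasks=[task], process=Process.sequential, replay_mode=True).kickoff()

# Setting both flags at once raised ValueError inside kickoff().
```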
@@ -5,6 +5,7 @@ from pydantic import BaseModel, ConfigDict, Field

from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
from crewai.utilities.logger import Logger

os.environ["TOKENIZERS_PARALLELISM"] = "false"  # removes logging from fastembed
@@ -12,10 +13,19 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed
class Knowledge(BaseModel):
    """
    Knowledge is a collection of sources and setup for the vector store to save and query relevant context.

    This class manages knowledge sources and provides methods to query them for relevant information.
    It automatically detects and reloads file-based knowledge sources when their underlying files change.

    Args:
        sources: List[BaseKnowledgeSource] = Field(default_factory=list)
            The knowledge sources to use for querying.
        storage: Optional[KnowledgeStorage] = Field(default=None)
            The storage backend for knowledge embeddings.
        embedder: Optional[Dict[str, Any]] = None
            Configuration for the embedding model.
        collection_name: Optional[str] = None
            Name of the collection to use for storage.
    """

    sources: List[BaseKnowledgeSource] = Field(default_factory=list)
@@ -23,6 +33,7 @@ class Knowledge(BaseModel):
    storage: Optional[KnowledgeStorage] = Field(default=None)
    embedder: Optional[Dict[str, Any]] = None
    collection_name: Optional[str] = None
    _logger: Logger = Logger(verbose=True)

    def __init__(
        self,
@@ -54,6 +65,8 @@ class Knowledge(BaseModel):
        """
        if self.storage is None:
            raise ValueError("Storage is not initialized.")

        self._check_and_reload_sources()

        results = self.storage.search(
            query,
@@ -61,6 +74,65 @@ class Knowledge(BaseModel):
            score_threshold=score_threshold,
        )
        return results

    def _check_and_reload_sources(self):
        """
        Check if any file-based knowledge sources have changed and reload them if necessary.

        This method detects modifications to source files by comparing their modification timestamps
        with previously recorded values. When changes are detected, the source is reloaded and
        the storage is updated with the new content.

        The method handles various file-related exceptions with specific error messages:
        - FileNotFoundError: When a source file no longer exists
        - PermissionError: When there are permission issues accessing a file
        - IOError: When there are I/O errors reading a file
        - ValueError: When there are issues with file content format
        - Other unexpected exceptions are also caught and logged

        Each exception is logged with appropriate context to aid in troubleshooting.
        """
        for source in self.sources:
            source_name = source.__class__.__name__
            try:
                if hasattr(source, 'files_have_changed') and source.files_have_changed():
                    self._logger.log("info", f"Reloading modified source: {source_name}")

                    # Update file modification timestamps
                    try:
                        source._record_file_mtimes()
                    except (PermissionError, IOError) as e:
                        self._logger.log("warning", f"Could not record file timestamps for {source_name}: {str(e)}")

                    try:
                        source.content = source.load_content()
                    except FileNotFoundError as e:
                        self._logger.log("error", f"File not found when loading content for {source_name}: {str(e)}")
                        continue
                    except PermissionError as e:
                        self._logger.log("error", f"Permission error when loading content for {source_name}: {str(e)}")
                        continue
                    except IOError as e:
                        self._logger.log("error", f"IO error when loading content for {source_name}: {str(e)}")
                        continue
                    except ValueError as e:
                        self._logger.log("error", f"Invalid content format in {source_name}: {str(e)}")
                        continue

                    try:
                        source.add()
                        self._logger.log("info", f"Successfully reloaded and updated {source_name}")
                    except Exception as e:
                        self._logger.log("error", f"Failed to update storage for {source_name}: {str(e)}")

            except FileNotFoundError as e:
                self._logger.log("error", f"File not found when checking for updates in {source_name}: {str(e)}")
            except PermissionError as e:
                self._logger.log("error", f"Permission error when checking for updates in {source_name}: {str(e)}")
            except IOError as e:
                self._logger.log("error", f"IO error when checking for updates in {source_name}: {str(e)}")
            except Exception as e:
                self._logger.log("error", f"Unexpected error when checking for updates in {source_name}: {str(e)}")

    def add_sources(self):
        try:
@@ -1,3 +1,4 @@
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List, Optional, Union
@@ -11,9 +12,24 @@ from crewai.utilities.logger import Logger
 class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
-    """Base class for knowledge sources that load content from files."""
+    """
+    Base class for knowledge sources that load content from files.
+
+    This class provides common functionality for file-based knowledge sources,
+    including file path validation, content loading, and change detection.
+    It automatically tracks file modification times to detect when files have
+    been updated and need to be reloaded.
+
+    Attributes:
+        file_path: Deprecated. Use file_paths instead.
+        file_paths: Path(s) to the file(s) containing knowledge data.
+        content: Dictionary mapping file paths to their loaded content.
+        storage: Storage backend for the knowledge data.
+        safe_file_paths: Validated list of Path objects.
+    """

     _logger: Logger = Logger(verbose=True)

     file_path: Optional[Union[Path, List[Path], str, List[str]]] = Field(
         default=None,
         description="[Deprecated] The path to the file. Use file_paths instead.",
@@ -43,7 +59,34 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
        """Post-initialization method to load content."""
        self.safe_file_paths = self._process_file_paths()
        self.validate_content()
        self._record_file_mtimes()
        self.content = self.load_content()

    def _record_file_mtimes(self):
        """
        Record modification times of all files.

        This method stores the current modification timestamps of all files
        in the _file_mtimes dictionary. These timestamps are later used to
        detect when files have been modified and need to be reloaded.

        Thread-safe: Uses a lock to prevent concurrent modifications.
        """
        with self._lock:
            self._file_mtimes = {}
            for path in self.safe_file_paths:
                try:
                    if path.exists() and path.is_file():
                        if os.access(path, os.R_OK):
                            self._file_mtimes[path] = path.stat().st_mtime
                        else:
                            self._logger.log("warning", f"File {path} is not readable.")
                except PermissionError as e:
                    self._logger.log("error", f"Permission error when recording file timestamp for {path}: {str(e)}")
                except IOError as e:
                    self._logger.log("error", f"IO error when recording file timestamp for {path}: {str(e)}")
                except Exception as e:
                    self._logger.log("error", f"Unexpected error when recording file timestamp for {path}: {str(e)}")

    @abstractmethod
    def load_content(self) -> Dict[Path, str]:
@@ -107,3 +150,41 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
            )

        return [self.convert_to_path(path) for path in path_list]

    def files_have_changed(self) -> bool:
        """
        Check if any of the files have been modified since they were last loaded.

        This method compares the current modification timestamps of files with the
        previously recorded timestamps to detect changes. When a file has been modified,
        it logs the change and returns True to trigger a reload.

        Returns:
            bool: True if any file has been modified, False otherwise.
        """
        for path in self.safe_file_paths:
            try:
                if not path.exists():
                    self._logger.log("warning", f"File {path} no longer exists.")
                    continue

                if not path.is_file():
                    self._logger.log("warning", f"Path {path} is not a file.")
                    continue

                if not os.access(path, os.R_OK):
                    self._logger.log("warning", f"File {path} is not readable.")
                    continue

                current_mtime = path.stat().st_mtime
                if path not in self._file_mtimes or current_mtime > self._file_mtimes[path]:
                    self._logger.log("info", f"File {path} has been modified. Reloading data.")
                    return True
            except PermissionError as e:
                self._logger.log("error", f"Permission error when checking file {path}: {str(e)}")
            except IOError as e:
                self._logger.log("error", f"IO error when checking file {path}: {str(e)}")
            except Exception as e:
                self._logger.log("error", f"Unexpected error when checking file {path}: {str(e)}")

        return False
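The change detection above boils down to a small mtime-tracking pattern; a self-contained sketch of the same idea using only pathlib, independent of crewAI:

```python
import time
from pathlib import Path

mtimes: dict[Path, float] = {}

def record(paths: list[Path]) -> None:
    # Remember the last-seen modification time of each readable file.
    for p in paths:
        if p.is_file():
            mtimes[p] = p.stat().st_mtime

def has_changed(paths: list[Path]) -> bool:
    # A file counts as changed if it is new to us or its mtime moved forward.
    return any(
        p.is_file() and (p not in mtimes or p.stat().st_mtime > mtimes[p])
        for p in paths
    )

f = Path("demo.txt")
f.write_text("v1")
record([f])
time.sleep(1)            # mtime resolution can be coarse, as in the manual test above
f.write_text("v2")
print(has_changed([f]))  # True
```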
@@ -296,7 +296,6 @@ class LLM(BaseLLM):
         self.additional_params = kwargs
         self.is_anthropic = self._is_anthropic_model(model)
         self.stream = stream
-        self._response_cache_handler = None

         litellm.drop_params = True
@@ -870,43 +869,25 @@ class LLM(BaseLLM):
         for message in messages:
             if message.get("role") == "system":
                 message["role"] = "assistant"

-        if self._response_cache_handler and self._response_cache_handler.is_replaying():
-            cached_response = self._response_cache_handler.get_cached_response(
-                self.model, messages
-            )
-            if cached_response:
-                # Emit completion event for the cached response
-                self._handle_emit_call_events(cached_response, LLMCallType.LLM_CALL)
-                return cached_response
-
-        # --- 6) Set up callbacks if provided
+        # --- 5) Set up callbacks if provided
         with suppress_warnings():
             if callbacks and len(callbacks) > 0:
                 self.set_callbacks(callbacks)

         try:
-            # --- 7) Prepare parameters for the completion call
+            # --- 6) Prepare parameters for the completion call
             params = self._prepare_completion_params(messages, tools)

-            # --- 8) Make the completion call and handle response
+            # --- 7) Make the completion call and handle response
             if self.stream:
-                response = self._handle_streaming_response(
+                return self._handle_streaming_response(
                     params, callbacks, available_functions
                 )
             else:
-                response = self._handle_non_streaming_response(
+                return self._handle_non_streaming_response(
                     params, callbacks, available_functions
                 )
-
-            if (self._response_cache_handler and
-                    self._response_cache_handler.is_recording() and
-                    isinstance(response, str)):
-                self._response_cache_handler.cache_response(
-                    self.model, messages, response
-                )
-
-            return response

         except LLMContextLengthExceededException:
             # Re-raise LLMContextLengthExceededException as it should be handled
@@ -1126,18 +1107,3 @@ class LLM(BaseLLM):
         litellm.success_callback = success_callbacks
         litellm.failure_callback = failure_callbacks

-    def set_response_cache_handler(self, handler):
-        """
-        Sets the response cache handler for record/replay functionality.
-
-        Args:
-            handler: An instance of LLMResponseCacheHandler.
-        """
-        self._response_cache_handler = handler
-
-    def clear_response_cache_handler(self):
-        """
-        Clears the response cache handler.
-        """
-        self._response_cache_handler = None
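Putting the removed pieces together, the record/replay hooks deleted here and in the modules below were wired roughly as follows; this is a sketch of the old path based on the deleted tests, not of the code that remains after this change:

```python
from crewai.llm import LLM
from crewai.utilities.llm_response_cache_handler import LLMResponseCacheHandler

handler = LLMResponseCacheHandler()
llm = LLM(model="gpt-4o-mini")
llm.set_response_cache_handler(handler)

messages = [{"role": "user", "content": "Hello, world!"}]

# Recording: responses returned by llm.call() were written to the SQLite cache.
handler.start_recording()
answer = llm.call(messages)

# Replaying: the same call was answered from the cache, with no network request.
handler.start_replaying()
print(llm.call(messages))
```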
@@ -1,314 +0,0 @@
import hashlib
import json
import logging
import sqlite3
import threading
from typing import Any, Dict, List, Optional

from crewai.utilities import Printer
from crewai.utilities.crew_json_encoder import CrewJSONEncoder
from crewai.utilities.paths import db_storage_path

logger = logging.getLogger(__name__)


class LLMResponseCacheStorage:
    """
    SQLite storage for caching LLM responses.
    Used for offline record/replay functionality.
    """

    def __init__(
        self, db_path: str = f"{db_storage_path()}/llm_response_cache.db"
    ) -> None:
        self.db_path = db_path
        self._printer: Printer = Printer()
        self._connection_pool: Dict[int, sqlite3.Connection] = {}
        self._initialize_db()

    def _get_connection(self) -> sqlite3.Connection:
        """
        Gets a connection from the connection pool or creates a new one.
        Uses thread-local storage to ensure thread safety.
        """
        thread_id = threading.get_ident()
        if thread_id not in self._connection_pool:
            try:
                conn = sqlite3.connect(self.db_path)
                conn.execute("PRAGMA foreign_keys = ON")
                conn.execute("PRAGMA journal_mode = WAL")
                self._connection_pool[thread_id] = conn
            except sqlite3.Error as e:
                error_msg = f"Failed to create SQLite connection: {e}"
                self._printer.print(
                    content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
                    color="red",
                )
                logger.error(error_msg)
                raise
        return self._connection_pool[thread_id]

    def _close_connections(self) -> None:
        """
        Closes all connections in the connection pool.
        """
        for thread_id, conn in list(self._connection_pool.items()):
            try:
                conn.close()
                del self._connection_pool[thread_id]
            except sqlite3.Error as e:
                error_msg = f"Failed to close SQLite connection: {e}"
                self._printer.print(
                    content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
                    color="red",
                )
                logger.error(error_msg)

    def _initialize_db(self) -> None:
        """
        Initializes the SQLite database and creates the llm_response_cache table
        """
        try:
            conn = self._get_connection()
            cursor = conn.cursor()
            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS llm_response_cache (
                    request_hash TEXT PRIMARY KEY,
                    model TEXT,
                    messages TEXT,
                    response TEXT,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
                )
                """
            )
            conn.commit()
        except sqlite3.Error as e:
            error_msg = f"Failed to initialize database: {e}"
            self._printer.print(
                content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
                color="red",
            )
            logger.error(error_msg)
            raise

    def _compute_request_hash(self, model: str, messages: List[Dict[str, str]]) -> str:
        """
        Computes a hash for the request based on the model and messages.
        This hash is used as the key for caching.

        Sensitive information like API keys should not be included in the hash.
        """
        try:
            message_str = json.dumps(messages, sort_keys=True)
            request_hash = hashlib.sha256(f"{model}:{message_str}".encode()).hexdigest()
            return request_hash
        except Exception as e:
            error_msg = f"Failed to compute request hash: {e}"
            self._printer.print(
                content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
                color="red",
            )
            logger.error(error_msg)
            raise

    def add(self, model: str, messages: List[Dict[str, str]], response: str) -> None:
        """
        Adds a response to the cache.
        """
        try:
            request_hash = self._compute_request_hash(model, messages)
            messages_json = json.dumps(messages, cls=CrewJSONEncoder)

            conn = self._get_connection()
            cursor = conn.cursor()
            cursor.execute(
                """
                INSERT OR REPLACE INTO llm_response_cache
                (request_hash, model, messages, response)
                VALUES (?, ?, ?, ?)
                """,
                (
                    request_hash,
                    model,
                    messages_json,
                    response,
                ),
            )
            conn.commit()
        except sqlite3.Error as e:
            error_msg = f"Failed to add response to cache: {e}"
            self._printer.print(
                content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
                color="red",
            )
            logger.error(error_msg)
            raise
        except Exception as e:
            error_msg = f"Unexpected error when adding response: {e}"
            self._printer.print(
                content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
                color="red",
            )
            logger.error(error_msg)
            raise

    def get(self, model: str, messages: List[Dict[str, str]]) -> Optional[str]:
        """
        Retrieves a response from the cache based on the model and messages.
        Returns None if not found.
        """
        try:
            request_hash = self._compute_request_hash(model, messages)

            conn = self._get_connection()
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT response
                FROM llm_response_cache
                WHERE request_hash = ?
                """,
                (request_hash,),
            )

            result = cursor.fetchone()
            return result[0] if result else None

        except sqlite3.Error as e:
            error_msg = f"Failed to retrieve response from cache: {e}"
            self._printer.print(
                content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
                color="red",
            )
            logger.error(error_msg)
            return None
        except Exception as e:
            error_msg = f"Unexpected error when retrieving response: {e}"
            self._printer.print(
                content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
                color="red",
            )
            logger.error(error_msg)
            return None

    def delete_all(self) -> None:
        """
        Deletes all records from the llm_response_cache table.
        """
        try:
            conn = self._get_connection()
            cursor = conn.cursor()
            cursor.execute("DELETE FROM llm_response_cache")
            conn.commit()
        except sqlite3.Error as e:
            error_msg = f"Failed to clear cache: {e}"
            self._printer.print(
                content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
                color="red",
            )
            logger.error(error_msg)
            raise

    def cleanup_expired_cache(self, max_age_days: int = 7) -> None:
        """
        Removes cache entries older than the specified number of days.

        This method helps maintain the cache size and ensures that only recent
        responses are kept, which is important for keeping the cache relevant
        and preventing it from growing too large over time.

        Args:
            max_age_days: Maximum age of cache entries in days. Defaults to 7.
                If set to 0, all entries will be deleted.
                Must be a non-negative integer.

        Raises:
            ValueError: If max_age_days is not a non-negative integer.
        """
        if not isinstance(max_age_days, int) or max_age_days < 0:
            error_msg = "max_age_days must be a non-negative integer"
            self._printer.print(
                content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
                color="red",
            )
            logger.error(error_msg)
            raise ValueError(error_msg)

        try:
            conn = self._get_connection()
            cursor = conn.cursor()

            if max_age_days <= 0:
                cursor.execute("DELETE FROM llm_response_cache")
                deleted_count = cursor.rowcount
                logger.info("Deleting all cache entries (max_age_days <= 0)")
            else:
                cursor.execute(
                    """
                    DELETE FROM llm_response_cache
                    WHERE timestamp < datetime('now', ? || ' days')
                    """,
                    (f"-{max_age_days}",)
                )
                deleted_count = cursor.rowcount

            conn.commit()

            if deleted_count > 0:
                self._printer.print(
                    content=f"LLM RESPONSE CACHE: Removed {deleted_count} expired cache entries",
                    color="green",
                )
                logger.info(f"Removed {deleted_count} expired cache entries")

        except sqlite3.Error as e:
            error_msg = f"Failed to cleanup expired cache: {e}"
            self._printer.print(
                content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
                color="red",
            )
            logger.error(error_msg)
            raise

    def get_cache_stats(self) -> Dict[str, Any]:
        """
        Returns statistics about the cache.

        Returns:
            A dictionary containing cache statistics.
        """
        try:
            conn = self._get_connection()
            cursor = conn.cursor()

            cursor.execute("SELECT COUNT(*) FROM llm_response_cache")
            total_count = cursor.fetchone()[0]

            cursor.execute("SELECT model, COUNT(*) FROM llm_response_cache GROUP BY model")
            model_counts = {model: count for model, count in cursor.fetchall()}

            cursor.execute("SELECT MIN(timestamp), MAX(timestamp) FROM llm_response_cache")
            oldest, newest = cursor.fetchone()

            return {
                "total_entries": total_count,
                "entries_by_model": model_counts,
                "oldest_entry": oldest,
                "newest_entry": newest,
            }

        except sqlite3.Error as e:
            error_msg = f"Failed to get cache stats: {e}"
            self._printer.print(
                content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
                color="red",
            )
            logger.error(error_msg)
            return {"error": str(e)}

    def __del__(self) -> None:
        """
        Closes all connections when the object is garbage collected.
        """
        self._close_connections()
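A minimal round trip through the deleted storage class, following its own API; entries are keyed by a SHA-256 hash of the model name plus the JSON-serialized messages, so only an identical request hits the cache:

```python
from crewai.memory.storage.llm_response_cache_storage import LLMResponseCacheStorage

storage = LLMResponseCacheStorage()  # defaults to <db_storage_path>/llm_response_cache.db

model = "gpt-4o-mini"
messages = [{"role": "user", "content": "Hello, world!"}]

storage.add(model, messages, "Hello, human!")   # INSERT OR REPLACE keyed by request hash
print(storage.get(model, messages))             # -> "Hello, human!"
print(storage.get(model, [{"role": "user", "content": "Other"}]))  # -> None (different hash)

storage.cleanup_expired_cache(max_age_days=0)   # 0 deletes every entry
print(storage.get(model, messages))             # -> None
```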
@@ -1,156 +0,0 @@
import logging
from typing import Any, Dict, List, Optional

from crewai.memory.storage.llm_response_cache_storage import LLMResponseCacheStorage

logger = logging.getLogger(__name__)


class LLMResponseCacheHandler:
    """
    Handler for the LLM response cache storage.
    Used for record/replay functionality.
    """

    def __init__(self, max_cache_age_days: int = 7) -> None:
        """
        Initializes the LLM response cache handler.

        Args:
            max_cache_age_days: Maximum age of cache entries in days. Defaults to 7.
        """
        self.storage = LLMResponseCacheStorage()
        self._recording = False
        self._replaying = False
        self.max_cache_age_days = max_cache_age_days

        try:
            self.storage.cleanup_expired_cache(self.max_cache_age_days)
        except Exception as e:
            logger.warning(f"Failed to cleanup expired cache on initialization: {e}")

    def start_recording(self) -> None:
        """
        Starts recording LLM responses.
        """
        self._recording = True
        self._replaying = False
        logger.info("Started recording LLM responses")

    def start_replaying(self) -> None:
        """
        Starts replaying LLM responses from the cache.
        """
        self._recording = False
        self._replaying = True
        logger.info("Started replaying LLM responses from cache")

        try:
            stats = self.storage.get_cache_stats()
            logger.info(f"Cache statistics: {stats}")
        except Exception as e:
            logger.warning(f"Failed to get cache statistics: {e}")

    def stop(self) -> None:
        """
        Stops recording or replaying.
        """
        was_recording = self._recording
        was_replaying = self._replaying

        self._recording = False
        self._replaying = False

        if was_recording:
            logger.info("Stopped recording LLM responses")
        if was_replaying:
            logger.info("Stopped replaying LLM responses")

    def is_recording(self) -> bool:
        """
        Returns whether recording is active.
        """
        return self._recording

    def is_replaying(self) -> bool:
        """
        Returns whether replaying is active.
        """
        return self._replaying

    def cache_response(self, model: str, messages: List[Dict[str, str]], response: str) -> None:
        """
        Caches an LLM response if recording is active.

        Args:
            model: The model used for the LLM call.
            messages: The messages sent to the LLM.
            response: The response from the LLM.
        """
        if not self._recording:
            return

        try:
            self.storage.add(model, messages, response)
            logger.debug(f"Cached response for model {model}")
        except Exception as e:
            logger.error(f"Failed to cache response: {e}")

    def get_cached_response(self, model: str, messages: List[Dict[str, str]]) -> Optional[str]:
        """
        Retrieves a cached LLM response if replaying is active.
        Returns None if not found or if replaying is not active.

        Args:
            model: The model used for the LLM call.
            messages: The messages sent to the LLM.

        Returns:
            The cached response, or None if not found or if replaying is not active.
        """
        if not self._replaying:
            return None

        try:
            response = self.storage.get(model, messages)
            if response:
                logger.debug(f"Retrieved cached response for model {model}")
            else:
                logger.debug(f"No cached response found for model {model}")
            return response
        except Exception as e:
            logger.error(f"Failed to retrieve cached response: {e}")
            return None

    def clear_cache(self) -> None:
        """
        Clears the LLM response cache.
        """
        try:
            self.storage.delete_all()
            logger.info("Cleared LLM response cache")
        except Exception as e:
            logger.error(f"Failed to clear cache: {e}")

    def cleanup_expired_cache(self) -> None:
        """
        Removes cache entries older than the maximum age.
        """
        try:
            self.storage.cleanup_expired_cache(self.max_cache_age_days)
            logger.info(f"Cleaned up expired cache entries (older than {self.max_cache_age_days} days)")
        except Exception as e:
            logger.error(f"Failed to cleanup expired cache: {e}")

    def get_cache_stats(self) -> Dict[str, Any]:
        """
        Returns statistics about the cache.

        Returns:
            A dictionary containing cache statistics.
        """
        try:
            return self.storage.get_cache_stats()
        except Exception as e:
            logger.error(f"Failed to get cache stats: {e}")
            return {"error": str(e)}
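And the deleted handler on its own, without going through an LLM instance (again a sketch of the removed API; the same storage file backs both phases, so the recorded entry is visible when replaying):

```python
from crewai.utilities.llm_response_cache_handler import LLMResponseCacheHandler

handler = LLMResponseCacheHandler(max_cache_age_days=1)
messages = [{"role": "user", "content": "ping"}]

handler.start_recording()
handler.cache_response("gpt-4o-mini", messages, "pong")       # no-op unless recording

handler.start_replaying()
print(handler.get_cached_response("gpt-4o-mini", messages))   # -> "pong"

handler.stop()
handler.clear_cache()
```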
tests/knowledge/test_csv_knowledge_source_updates.py (new file, 85 lines)
@@ -0,0 +1,85 @@
import os
import time
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest

from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.csv_knowledge_source import CSVKnowledgeSource
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage


@patch('crewai.knowledge.storage.knowledge_storage.KnowledgeStorage.search')
@patch('crewai.knowledge.source.csv_knowledge_source.CSVKnowledgeSource.add')
def test_csv_knowledge_source_updates(mock_add, mock_search, tmpdir):
    """Test that CSVKnowledgeSource properly detects and loads updates to CSV files."""
    mock_search.side_effect = [
        [{"context": "name,age,city\nJohn,30,New York\nAlice,25,San Francisco\nBob,28,Chicago"}],
        [{"context": "name,age,city\nJohn,30,Boston\nAlice,25,San Francisco\nBob,28,Chicago\nEve,22,Miami"}],
        [{"context": "name,age,city\nJohn,30,Boston\nAlice,25,San Francisco\nBob,28,Chicago\nEve,22,Miami"}]
    ]

    csv_path = str(tmpdir / "test_updates.csv")

    initial_csv_content = [
        ["name", "age", "city"],
        ["John", "30", "New York"],
        ["Alice", "25", "San Francisco"],
        ["Bob", "28", "Chicago"],
    ]

    with open(csv_path, "w") as f:
        for row in initial_csv_content:
            f.write(",".join(row) + "\n")

    csv_source = CSVKnowledgeSource(file_paths=[csv_path])

    original_files_have_changed = csv_source.files_have_changed
    files_changed_called = [False]

    def spy_files_have_changed():
        files_changed_called[0] = True
        return original_files_have_changed()

    csv_source.files_have_changed = spy_files_have_changed

    knowledge = Knowledge(sources=[csv_source], collection_name="test_updates")

    assert hasattr(knowledge, '_check_and_reload_sources'), "Knowledge class is missing _check_and_reload_sources method"

    initial_results = knowledge.query(["John"])
    assert any("John" in result["context"] for result in initial_results)
    assert any("New York" in result["context"] for result in initial_results)

    mock_add.reset_mock()
    files_changed_called[0] = False

    updated_csv_content = [
        ["name", "age", "city"],
        ["John", "30", "Boston"],  # Changed city
        ["Alice", "25", "San Francisco"],
        ["Bob", "28", "Chicago"],
        ["Eve", "22", "Miami"],  # Added new person
    ]

    time.sleep(1)

    csv_path_str = str(csv_path)
    with open(csv_path_str, "w") as f:
        for row in updated_csv_content:
            f.write(",".join(row) + "\n")

    updated_results = knowledge.query(["John"])

    assert files_changed_called[0], "files_have_changed method was not called during query"

    assert mock_add.called, "add method was not called to reload the data"

    assert any("John" in result["context"] for result in updated_results)
    assert any("Boston" in result["context"] for result in updated_results)
    assert not any("New York" in result["context"] for result in updated_results)

    new_results = knowledge.query(["Eve"])
    assert any("Eve" in result["context"] for result in new_results)
    assert any("Miami" in result["context"] for result in new_results)
@@ -1,155 +0,0 @@
import sqlite3
import threading
from concurrent.futures import ThreadPoolExecutor
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import pytest

from crewai.llm import LLM
from crewai.memory.storage.llm_response_cache_storage import LLMResponseCacheStorage
from crewai.utilities.llm_response_cache_handler import LLMResponseCacheHandler


@pytest.fixture
def handler():
    handler = LLMResponseCacheHandler()
    handler.storage.add = MagicMock()
    handler.storage.get = MagicMock()
    return handler


def create_mock_response(content):
    """Create a properly structured mock response object for litellm.completion"""
    message = SimpleNamespace(content=content)
    choice = SimpleNamespace(message=message)
    response = SimpleNamespace(choices=[choice])
    return response


@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_recording(handler):
    handler.start_recording()

    llm = LLM(model="gpt-4o-mini")
    llm.set_response_cache_handler(handler)

    messages = [{"role": "user", "content": "Hello, world!"}]

    with patch('litellm.completion') as mock_completion:
        mock_completion.return_value = create_mock_response("Hello, human!")

        response = llm.call(messages)

        assert response == "Hello, human!"

        handler.storage.add.assert_called_once_with(
            "gpt-4o-mini", messages, "Hello, human!"
        )


@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_replaying(handler):
    handler.start_replaying()
    handler.storage.get.return_value = "Cached response"

    llm = LLM(model="gpt-4o-mini")
    llm.set_response_cache_handler(handler)

    messages = [{"role": "user", "content": "Hello, world!"}]

    with patch('litellm.completion') as mock_completion:
        response = llm.call(messages)

        assert response == "Cached response"

        mock_completion.assert_not_called()

        handler.storage.get.assert_called_once_with("gpt-4o-mini", messages)


@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_replay_fallback(handler):
    handler.start_replaying()
    handler.storage.get.return_value = None

    llm = LLM(model="gpt-4o-mini")
    llm.set_response_cache_handler(handler)

    messages = [{"role": "user", "content": "Hello, world!"}]

    with patch('litellm.completion') as mock_completion:
        mock_completion.return_value = create_mock_response("Hello, human!")

        response = llm.call(messages)

        assert response == "Hello, human!"

        mock_completion.assert_called_once()


@pytest.mark.vcr(filter_headers=["authorization"])
def test_cache_error_handling():
    """Test that errors during cache operations are handled gracefully."""
    handler = LLMResponseCacheHandler()

    handler.storage.add = MagicMock(side_effect=sqlite3.Error("Mock DB error"))
    handler.storage.get = MagicMock(side_effect=sqlite3.Error("Mock DB error"))

    handler.start_recording()

    handler.cache_response("model", [{"role": "user", "content": "test"}], "response")

    handler.start_replaying()

    assert handler.get_cached_response("model", [{"role": "user", "content": "test"}]) is None


@pytest.mark.vcr(filter_headers=["authorization"])
def test_cache_expiration():
    """Test that cache expiration works correctly."""
    import sqlite3

    conn = sqlite3.connect(":memory:")
    cursor = conn.cursor()
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS llm_response_cache (
            request_hash TEXT PRIMARY KEY,
            model TEXT,
            messages TEXT,
            response TEXT,
            timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
        )
        """
    )
    conn.commit()

    storage = LLMResponseCacheStorage(":memory:")

    original_get_connection = storage._get_connection
    storage._get_connection = lambda: conn

    try:
        model = "test-model"
        messages = [{"role": "user", "content": "test"}]
        response = "test response"
        storage.add(model, messages, response)

        assert storage.get(model, messages) == response

        storage.cleanup_expired_cache(max_age_days=0)

        assert storage.get(model, messages) is None
    finally:
        storage._get_connection = original_get_connection
        conn.close()


@pytest.mark.vcr(filter_headers=["authorization"])
def test_concurrent_cache_access():
    """Test that concurrent cache access works correctly."""
    pytest.skip("SQLite in-memory databases are not shared between threads")


    # storage = LLMResponseCacheStorage(temp_db.name)
@@ -1,93 +0,0 @@
from unittest.mock import MagicMock, patch

import pytest

from crewai.agent import Agent
from crewai.crew import Crew
from crewai.process import Process
from crewai.task import Task


@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_recording_mode():
    agent = Agent(
        role="Test Agent",
        goal="Test the recording functionality",
        backstory="A test agent for recording LLM responses",
    )

    task = Task(
        description="Return a simple response",
        expected_output="A simple response",
        agent=agent,
    )

    crew = Crew(
        agents=[agent],
        tasks=[task],
        process=Process.sequential,
        record_mode=True,
    )

    mock_handler = MagicMock()
    crew._llm_response_cache_handler = mock_handler

    mock_llm = MagicMock()
    agent.llm = mock_llm

    with patch('crewai.agent.Agent.execute_task', return_value="Test response"):
        with patch('crewai.utilities.llm_response_cache_handler.LLMResponseCacheHandler', return_value=mock_handler):
            crew.kickoff()

            mock_handler.start_recording.assert_called_once()

            mock_llm.set_response_cache_handler.assert_called_once_with(mock_handler)


@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_replay_mode():
    agent = Agent(
        role="Test Agent",
        goal="Test the replay functionality",
        backstory="A test agent for replaying LLM responses",
    )

    task = Task(
        description="Return a simple response",
        expected_output="A simple response",
        agent=agent,
    )

    crew = Crew(
        agents=[agent],
        tasks=[task],
        process=Process.sequential,
        replay_mode=True,
    )

    mock_handler = MagicMock()
    crew._llm_response_cache_handler = mock_handler

    mock_llm = MagicMock()
    agent.llm = mock_llm

    with patch('crewai.agent.Agent.execute_task', return_value="Test response"):
        with patch('crewai.utilities.llm_response_cache_handler.LLMResponseCacheHandler', return_value=mock_handler):
            crew.kickoff()

            mock_handler.start_replaying.assert_called_once()

            mock_llm.set_response_cache_handler.assert_called_once_with(mock_handler)


@pytest.mark.vcr(filter_headers=["authorization"])
def test_record_replay_flags_conflict():
    with pytest.raises(ValueError):
        crew = Crew(
            agents=[],
            tasks=[],
            process=Process.sequential,
            record_mode=True,
            replay_mode=True,
        )
        crew.kickoff()