Compare commits


8 Commits

Author SHA1 Message Date
Devin AI
79547fba25 Remove lock usage entirely to fix pickling issues
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-05-06 00:20:18 +00:00
Devin AI
171f8b63fd Replace RLock with threading.Lock to fix pickling issues
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-05-06 00:19:24 +00:00
Devin AI
72df165b07 Fix RLock pickling issue in BaseFileKnowledgeSource
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-05-06 00:15:45 +00:00
Devin AI
63eccf5e30 Improve error handling and documentation in Knowledge._check_and_reload_sources
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-05-06 00:14:25 +00:00
Devin AI
a98a44afb2 Fix test file path handling for CSVKnowledgeSource
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-05-06 00:12:12 +00:00
Devin AI
6e0f1fe38d Address code review comments: improve error handling, add thread safety, enhance documentation
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-05-06 00:11:28 +00:00
Devin AI
c2bf2b3210 Fix import sorting in manual test script
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-05-06 00:06:18 +00:00
Devin AI
14579a7861 Fix #2762: Make CSV knowledge sources detect and load file updates
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-05-06 00:04:39 +00:00
12 changed files with 341 additions and 818 deletions

manual_test_csv_update.py (new file, 92 lines)

@@ -0,0 +1,92 @@
"""
Manual test script to verify CSV knowledge source update functionality.
This script creates a CSV file, loads it as a knowledge source, updates the file,
and verifies that the updated content is detected and loaded.
"""
import os
import sys
import tempfile
import time
from pathlib import Path
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.csv_knowledge_source import CSVKnowledgeSource
def test_csv_knowledge_source_updates():
"""Test that CSVKnowledgeSource properly detects and loads updates to CSV files."""
with tempfile.TemporaryDirectory() as tmpdir:
csv_path = Path(tmpdir) / "test_updates.csv"
initial_csv_content = [
["name", "age", "city"],
["John", "30", "New York"],
["Alice", "25", "San Francisco"],
["Bob", "28", "Chicago"],
]
with open(csv_path, "w") as f:
for row in initial_csv_content:
f.write(",".join(row) + "\n")
print(f"Created CSV file at {csv_path}")
csv_source = CSVKnowledgeSource(file_paths=[csv_path])
if not hasattr(csv_source, 'files_have_changed'):
print("❌ TEST FAILED: files_have_changed method not found in CSVKnowledgeSource")
return False
if not hasattr(csv_source, '_file_mtimes'):
print("❌ TEST FAILED: _file_mtimes attribute not found in CSVKnowledgeSource")
return False
knowledge = Knowledge(sources=[csv_source], collection_name="test_updates")
if not hasattr(knowledge, '_check_and_reload_sources'):
print("❌ TEST FAILED: _check_and_reload_sources method not found in Knowledge")
return False
print("✅ All required methods and attributes exist")
updated_csv_content = [
["name", "age", "city"],
["John", "30", "Boston"], # Changed city
["Alice", "25", "San Francisco"],
["Bob", "28", "Chicago"],
["Eve", "22", "Miami"], # Added new person
]
print("\nWaiting for 1 second before updating file...")
time.sleep(1)
with open(csv_path, "w") as f:
for row in updated_csv_content:
f.write(",".join(row) + "\n")
print(f"Updated CSV file at {csv_path}")
if not csv_source.files_have_changed():
print("❌ TEST FAILED: files_have_changed did not detect file modification")
return False
print("✅ files_have_changed correctly detected file modification")
csv_source._record_file_mtimes()
csv_source.content = csv_source.load_content()
content_str = str(csv_source.content)
if "Boston" in content_str and "Eve" in content_str and "Miami" in content_str:
print("✅ Content was correctly updated with new data")
else:
print("❌ TEST FAILED: Content was not updated with new data")
return False
print("\n✅ TEST PASSED: CSV knowledge source correctly detects and loads file updates")
return True
if __name__ == "__main__":
success = test_csv_knowledge_source_updates()
sys.exit(0 if success else 1)

View File

@@ -201,22 +201,9 @@ def install(context):
@crewai.command()
@click.option(
"--record",
is_flag=True,
help="Record LLM responses for later replay",
)
@click.option(
"--replay",
is_flag=True,
help="Replay from recorded LLM responses without making network calls",
)
def run(record: bool = False, replay: bool = False):
def run():
"""Run the Crew."""
if record and replay:
raise click.UsageError("Cannot use --record and --replay simultaneously")
click.echo("Running the Crew")
run_crew(record=record, replay=replay)
run_crew()
@crewai.command()

View File

@@ -14,17 +14,13 @@ class CrewType(Enum):
FLOW = "flow"
def run_crew(record: bool = False, replay: bool = False) -> None:
def run_crew() -> None:
"""
Run the crew or flow by running a command in the UV environment.
Starting from version 0.103.0, this command can be used to run both
standard crews and flows. For flows, it detects the type from pyproject.toml
and automatically runs the appropriate command.
Args:
record (bool, optional): Whether to record LLM responses. Defaults to False.
replay (bool, optional): Whether to replay from recorded LLM responses. Defaults to False.
"""
crewai_version = get_crewai_version()
min_required_version = "0.71.0"
@@ -48,24 +44,17 @@ def run_crew(record: bool = False, replay: bool = False) -> None:
click.echo(f"Running the {'Flow' if is_flow else 'Crew'}")
# Execute the appropriate command
execute_command(crew_type, record, replay)
execute_command(crew_type)
def execute_command(crew_type: CrewType, record: bool = False, replay: bool = False) -> None:
def execute_command(crew_type: CrewType) -> None:
"""
Execute the appropriate command based on crew type.
Args:
crew_type: The type of crew to run
record: Whether to record LLM responses
replay: Whether to replay from recorded LLM responses
"""
command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"]
if record:
command.append("--record")
if replay:
command.append("--replay")
try:
subprocess.run(command, capture_output=False, text=True, check=True)

View File

@@ -244,15 +244,6 @@ class Crew(FlowTrackable, BaseModel):
default_factory=SecurityConfig,
description="Security configuration for the crew, including fingerprinting.",
)
record_mode: bool = Field(
default=False,
description="Whether to record LLM responses for later replay.",
)
replay_mode: bool = Field(
default=False,
description="Whether to replay from recorded LLM responses without making network calls.",
)
_llm_response_cache_handler: Optional[Any] = PrivateAttr(default=None)
@field_validator("id", mode="before")
@classmethod
@@ -642,19 +633,6 @@ class Crew(FlowTrackable, BaseModel):
self._task_output_handler.reset()
self._logging_color = "bold_purple"
if self.record_mode and self.replay_mode:
raise ValueError("Cannot use both record_mode and replay_mode at the same time")
if self.record_mode or self.replay_mode:
from crewai.utilities.llm_response_cache_handler import (
LLMResponseCacheHandler,
)
self._llm_response_cache_handler = LLMResponseCacheHandler()
if self.record_mode:
self._llm_response_cache_handler.start_recording()
elif self.replay_mode:
self._llm_response_cache_handler.start_replaying()
if inputs is not None:
self._inputs = inputs
self._interpolate_inputs(inputs)
@@ -673,12 +651,6 @@ class Crew(FlowTrackable, BaseModel):
if not agent.step_callback: # type: ignore # "BaseAgent" has no attribute "step_callback"
agent.step_callback = self.step_callback # type: ignore # "BaseAgent" has no attribute "step_callback"
if self._llm_response_cache_handler:
if hasattr(agent, "llm") and agent.llm:
agent.llm.set_response_cache_handler(self._llm_response_cache_handler)
if hasattr(agent, "function_calling_llm") and agent.function_calling_llm:
agent.function_calling_llm.set_response_cache_handler(self._llm_response_cache_handler)
agent.create_agent_executor()
@@ -1315,9 +1287,6 @@ class Crew(FlowTrackable, BaseModel):
def _finish_execution(self, final_string_output: str) -> None:
if self.max_rpm:
self._rpm_controller.stop_rpm_counter()
if self._llm_response_cache_handler:
self._llm_response_cache_handler.stop()
def calculate_usage_metrics(self) -> UsageMetrics:
"""Calculates and returns the usage metrics."""

View File

@@ -5,6 +5,7 @@ from pydantic import BaseModel, ConfigDict, Field
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
from crewai.utilities.logger import Logger
os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed
@@ -12,10 +13,19 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed
class Knowledge(BaseModel):
"""
Knowledge is a collection of sources and setup for the vector store to save and query relevant context.
This class manages knowledge sources and provides methods to query them for relevant information.
It automatically detects and reloads file-based knowledge sources when their underlying files change.
Args:
    sources (List[BaseKnowledgeSource]): The knowledge sources to use for querying.
    storage (Optional[KnowledgeStorage]): The storage backend for knowledge embeddings.
    embedder (Optional[Dict[str, Any]]): Configuration for the embedding model.
    collection_name (Optional[str]): Name of the collection to use for storage.
"""
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
@@ -23,6 +33,7 @@ class Knowledge(BaseModel):
storage: Optional[KnowledgeStorage] = Field(default=None)
embedder: Optional[Dict[str, Any]] = None
collection_name: Optional[str] = None
_logger: Logger = Logger(verbose=True)
def __init__(
self,
@@ -54,6 +65,8 @@ class Knowledge(BaseModel):
"""
if self.storage is None:
raise ValueError("Storage is not initialized.")
self._check_and_reload_sources()
results = self.storage.search(
query,
@@ -61,6 +74,65 @@ class Knowledge(BaseModel):
score_threshold=score_threshold,
)
return results
def _check_and_reload_sources(self):
"""
Check if any file-based knowledge sources have changed and reload them if necessary.
This method detects modifications to source files by comparing their modification timestamps
with previously recorded values. When changes are detected, the source is reloaded and
the storage is updated with the new content.
The method handles various file-related exceptions with specific error messages:
- FileNotFoundError: When a source file no longer exists
- PermissionError: When there are permission issues accessing a file
- IOError: When there are I/O errors reading a file
- ValueError: When there are issues with file content format
- Other unexpected exceptions are also caught and logged
Each exception is logged with appropriate context to aid in troubleshooting.
"""
for source in self.sources:
source_name = source.__class__.__name__
try:
if hasattr(source, 'files_have_changed') and source.files_have_changed():
self._logger.log("info", f"Reloading modified source: {source_name}")
# Update file modification timestamps
try:
source._record_file_mtimes()
except (PermissionError, IOError) as e:
self._logger.log("warning", f"Could not record file timestamps for {source_name}: {str(e)}")
try:
source.content = source.load_content()
except FileNotFoundError as e:
self._logger.log("error", f"File not found when loading content for {source_name}: {str(e)}")
continue
except PermissionError as e:
self._logger.log("error", f"Permission error when loading content for {source_name}: {str(e)}")
continue
except IOError as e:
self._logger.log("error", f"IO error when loading content for {source_name}: {str(e)}")
continue
except ValueError as e:
self._logger.log("error", f"Invalid content format in {source_name}: {str(e)}")
continue
try:
source.add()
self._logger.log("info", f"Successfully reloaded and updated {source_name}")
except Exception as e:
self._logger.log("error", f"Failed to update storage for {source_name}: {str(e)}")
except FileNotFoundError as e:
self._logger.log("error", f"File not found when checking for updates in {source_name}: {str(e)}")
except PermissionError as e:
self._logger.log("error", f"Permission error when checking for updates in {source_name}: {str(e)}")
except IOError as e:
self._logger.log("error", f"IO error when checking for updates in {source_name}: {str(e)}")
except Exception as e:
self._logger.log("error", f"Unexpected error when checking for updates in {source_name}: {str(e)}")
def add_sources(self):
try:
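Taken together with the BaseFileKnowledgeSource changes below, reloading is transparent to callers of Knowledge.query(). A minimal sketch of the intended behavior, assuming a configured embedder/storage backend and a hypothetical people.csv; it mirrors the tests in this comparison rather than introducing new API:

from pathlib import Path
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.csv_knowledge_source import CSVKnowledgeSource

csv_path = Path("people.csv")  # hypothetical file for illustration
csv_path.write_text("name,city\nJohn,New York\n")

source = CSVKnowledgeSource(file_paths=[csv_path])
knowledge = Knowledge(sources=[source], collection_name="people")
knowledge.add_sources()  # index initial content (may happen automatically depending on version)

print(knowledge.query(["John"]))  # results reflect the file as first loaded

csv_path.write_text("name,city\nJohn,Boston\n")  # edit the file on disk

# No manual reload call is needed: query() first runs _check_and_reload_sources(),
# which polls file mtimes via files_have_changed() and re-adds any changed
# source to storage before searching.
print(knowledge.query(["John"]))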

View File

@@ -1,3 +1,4 @@
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List, Optional, Union
@@ -11,9 +12,24 @@ from crewai.utilities.logger import Logger
class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
"""Base class for knowledge sources that load content from files."""
"""
Base class for knowledge sources that load content from files.
This class provides common functionality for file-based knowledge sources,
including file path validation, content loading, and change detection.
It automatically tracks file modification times to detect when files have
been updated and need to be reloaded.
Attributes:
file_path: Deprecated. Use file_paths instead.
file_paths: Path(s) to the file(s) containing knowledge data.
content: Dictionary mapping file paths to their loaded content.
storage: Storage backend for the knowledge data.
safe_file_paths: Validated list of Path objects.
"""
_logger: Logger = Logger(verbose=True)
file_path: Optional[Union[Path, List[Path], str, List[str]]] = Field(
default=None,
description="[Deprecated] The path to the file. Use file_paths instead.",
@@ -43,7 +59,34 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
"""Post-initialization method to load content."""
self.safe_file_paths = self._process_file_paths()
self.validate_content()
self._record_file_mtimes()
self.content = self.load_content()
def _record_file_mtimes(self):
"""
Record modification times of all files.
This method stores the current modification timestamps of all files
in the _file_mtimes dictionary. These timestamps are later used to
detect when files have been modified and need to be reloaded.
Thread-safe: Uses a lock to prevent concurrent modifications.
"""
with self._lock:
self._file_mtimes = {}
for path in self.safe_file_paths:
try:
if path.exists() and path.is_file():
if os.access(path, os.R_OK):
self._file_mtimes[path] = path.stat().st_mtime
else:
self._logger.log("warning", f"File {path} is not readable.")
except PermissionError as e:
self._logger.log("error", f"Permission error when recording file timestamp for {path}: {str(e)}")
except IOError as e:
self._logger.log("error", f"IO error when recording file timestamp for {path}: {str(e)}")
except Exception as e:
self._logger.log("error", f"Unexpected error when recording file timestamp for {path}: {str(e)}")
@abstractmethod
def load_content(self) -> Dict[Path, str]:
@@ -107,3 +150,41 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
)
return [self.convert_to_path(path) for path in path_list]
def files_have_changed(self) -> bool:
"""
Check if any of the files have been modified since they were last loaded.
This method compares the current modification timestamps of files with the
previously recorded timestamps to detect changes. When a file has been modified,
it logs the change and returns True to trigger a reload.
Returns:
bool: True if any file has been modified, False otherwise.
"""
for path in self.safe_file_paths:
try:
if not path.exists():
self._logger.log("warning", f"File {path} no longer exists.")
continue
if not path.is_file():
self._logger.log("warning", f"Path {path} is not a file.")
continue
if not os.access(path, os.R_OK):
self._logger.log("warning", f"File {path} is not readable.")
continue
current_mtime = path.stat().st_mtime
if path not in self._file_mtimes or current_mtime > self._file_mtimes[path]:
self._logger.log("info", f"File {path} has been modified. Reloading data.")
return True
except PermissionError as e:
self._logger.log("error", f"Permission error when checking file {path}: {str(e)}")
except IOError as e:
self._logger.log("error", f"IO error when checking file {path}: {str(e)}")
except Exception as e:
self._logger.log("error", f"Unexpected error when checking file {path}: {str(e)}")
return False
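The detection above is plain mtime polling: record Path.stat().st_mtime per file, then compare on the next check. The same pattern in isolation, as a dependency-free sketch (illustrative only, not part of the class above):

import os
from pathlib import Path
from typing import Dict, Iterable


def record_mtimes(paths: Iterable[Path]) -> Dict[Path, float]:
    """Snapshot modification times for existing, readable files."""
    return {
        p: p.stat().st_mtime
        for p in paths
        if p.exists() and p.is_file() and os.access(p, os.R_OK)
    }


def any_changed(paths: Iterable[Path], snapshot: Dict[Path, float]) -> bool:
    """True if any file is new to the snapshot or has a newer mtime."""
    for p in paths:
        if not (p.exists() and p.is_file() and os.access(p, os.R_OK)):
            continue  # missing/unreadable files are skipped, as in files_have_changed()
        if p not in snapshot or p.stat().st_mtime > snapshot[p]:
            return True
    return False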

View File

@@ -296,7 +296,6 @@ class LLM(BaseLLM):
self.additional_params = kwargs
self.is_anthropic = self._is_anthropic_model(model)
self.stream = stream
self._response_cache_handler = None
litellm.drop_params = True
@@ -870,43 +869,25 @@ class LLM(BaseLLM):
for message in messages:
if message.get("role") == "system":
message["role"] = "assistant"
if self._response_cache_handler and self._response_cache_handler.is_replaying():
cached_response = self._response_cache_handler.get_cached_response(
self.model, messages
)
if cached_response:
# Emit completion event for the cached response
self._handle_emit_call_events(cached_response, LLMCallType.LLM_CALL)
return cached_response
# --- 6) Set up callbacks if provided
# --- 5) Set up callbacks if provided
with suppress_warnings():
if callbacks and len(callbacks) > 0:
self.set_callbacks(callbacks)
try:
# --- 7) Prepare parameters for the completion call
# --- 6) Prepare parameters for the completion call
params = self._prepare_completion_params(messages, tools)
# --- 8) Make the completion call and handle response
# --- 7) Make the completion call and handle response
if self.stream:
response = self._handle_streaming_response(
return self._handle_streaming_response(
params, callbacks, available_functions
)
else:
response = self._handle_non_streaming_response(
return self._handle_non_streaming_response(
params, callbacks, available_functions
)
if (self._response_cache_handler and
self._response_cache_handler.is_recording() and
isinstance(response, str)):
self._response_cache_handler.cache_response(
self.model, messages, response
)
return response
except LLMContextLengthExceededException:
# Re-raise LLMContextLengthExceededException as it should be handled
@@ -1126,18 +1107,3 @@ class LLM(BaseLLM):
litellm.success_callback = success_callbacks
litellm.failure_callback = failure_callbacks
def set_response_cache_handler(self, handler):
"""
Sets the response cache handler for record/replay functionality.
Args:
handler: An instance of LLMResponseCacheHandler.
"""
self._response_cache_handler = handler
def clear_response_cache_handler(self):
"""
Clears the response cache handler.
"""
self._response_cache_handler = None

View File

@@ -1,314 +0,0 @@
import hashlib
import json
import logging
import sqlite3
import threading
from typing import Any, Dict, List, Optional
from crewai.utilities import Printer
from crewai.utilities.crew_json_encoder import CrewJSONEncoder
from crewai.utilities.paths import db_storage_path
logger = logging.getLogger(__name__)
class LLMResponseCacheStorage:
"""
SQLite storage for caching LLM responses.
Used for offline record/replay functionality.
"""
def __init__(
self, db_path: str = f"{db_storage_path()}/llm_response_cache.db"
) -> None:
self.db_path = db_path
self._printer: Printer = Printer()
self._connection_pool: Dict[int, sqlite3.Connection] = {}
self._initialize_db()
def _get_connection(self) -> sqlite3.Connection:
"""
Gets a connection from the connection pool or creates a new one.
Uses thread-local storage to ensure thread safety.
"""
thread_id = threading.get_ident()
if thread_id not in self._connection_pool:
try:
conn = sqlite3.connect(self.db_path)
conn.execute("PRAGMA foreign_keys = ON")
conn.execute("PRAGMA journal_mode = WAL")
self._connection_pool[thread_id] = conn
except sqlite3.Error as e:
error_msg = f"Failed to create SQLite connection: {e}"
self._printer.print(
content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
color="red",
)
logger.error(error_msg)
raise
return self._connection_pool[thread_id]
def _close_connections(self) -> None:
"""
Closes all connections in the connection pool.
"""
for thread_id, conn in list(self._connection_pool.items()):
try:
conn.close()
del self._connection_pool[thread_id]
except sqlite3.Error as e:
error_msg = f"Failed to close SQLite connection: {e}"
self._printer.print(
content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
color="red",
)
logger.error(error_msg)
def _initialize_db(self) -> None:
"""
Initializes the SQLite database and creates the llm_response_cache table
"""
try:
conn = self._get_connection()
cursor = conn.cursor()
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS llm_response_cache (
request_hash TEXT PRIMARY KEY,
model TEXT,
messages TEXT,
response TEXT,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
)
"""
)
conn.commit()
except sqlite3.Error as e:
error_msg = f"Failed to initialize database: {e}"
self._printer.print(
content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
color="red",
)
logger.error(error_msg)
raise
def _compute_request_hash(self, model: str, messages: List[Dict[str, str]]) -> str:
"""
Computes a hash for the request based on the model and messages.
This hash is used as the key for caching.
Sensitive information like API keys should not be included in the hash.
"""
try:
message_str = json.dumps(messages, sort_keys=True)
request_hash = hashlib.sha256(f"{model}:{message_str}".encode()).hexdigest()
return request_hash
except Exception as e:
error_msg = f"Failed to compute request hash: {e}"
self._printer.print(
content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
color="red",
)
logger.error(error_msg)
raise
def add(self, model: str, messages: List[Dict[str, str]], response: str) -> None:
"""
Adds a response to the cache.
"""
try:
request_hash = self._compute_request_hash(model, messages)
messages_json = json.dumps(messages, cls=CrewJSONEncoder)
conn = self._get_connection()
cursor = conn.cursor()
cursor.execute(
"""
INSERT OR REPLACE INTO llm_response_cache
(request_hash, model, messages, response)
VALUES (?, ?, ?, ?)
""",
(
request_hash,
model,
messages_json,
response,
),
)
conn.commit()
except sqlite3.Error as e:
error_msg = f"Failed to add response to cache: {e}"
self._printer.print(
content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
color="red",
)
logger.error(error_msg)
raise
except Exception as e:
error_msg = f"Unexpected error when adding response: {e}"
self._printer.print(
content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
color="red",
)
logger.error(error_msg)
raise
def get(self, model: str, messages: List[Dict[str, str]]) -> Optional[str]:
"""
Retrieves a response from the cache based on the model and messages.
Returns None if not found.
"""
try:
request_hash = self._compute_request_hash(model, messages)
conn = self._get_connection()
cursor = conn.cursor()
cursor.execute(
"""
SELECT response
FROM llm_response_cache
WHERE request_hash = ?
""",
(request_hash,),
)
result = cursor.fetchone()
return result[0] if result else None
except sqlite3.Error as e:
error_msg = f"Failed to retrieve response from cache: {e}"
self._printer.print(
content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
color="red",
)
logger.error(error_msg)
return None
except Exception as e:
error_msg = f"Unexpected error when retrieving response: {e}"
self._printer.print(
content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
color="red",
)
logger.error(error_msg)
return None
def delete_all(self) -> None:
"""
Deletes all records from the llm_response_cache table.
"""
try:
conn = self._get_connection()
cursor = conn.cursor()
cursor.execute("DELETE FROM llm_response_cache")
conn.commit()
except sqlite3.Error as e:
error_msg = f"Failed to clear cache: {e}"
self._printer.print(
content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
color="red",
)
logger.error(error_msg)
raise
def cleanup_expired_cache(self, max_age_days: int = 7) -> None:
"""
Removes cache entries older than the specified number of days.
This method helps maintain the cache size and ensures that only recent
responses are kept, which is important for keeping the cache relevant
and preventing it from growing too large over time.
Args:
max_age_days: Maximum age of cache entries in days. Defaults to 7.
If set to 0, all entries will be deleted.
Must be a non-negative integer.
Raises:
ValueError: If max_age_days is not a non-negative integer.
"""
if not isinstance(max_age_days, int) or max_age_days < 0:
error_msg = "max_age_days must be a non-negative integer"
self._printer.print(
content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
color="red",
)
logger.error(error_msg)
raise ValueError(error_msg)
try:
conn = self._get_connection()
cursor = conn.cursor()
if max_age_days <= 0:
cursor.execute("DELETE FROM llm_response_cache")
deleted_count = cursor.rowcount
logger.info("Deleting all cache entries (max_age_days <= 0)")
else:
cursor.execute(
"""
DELETE FROM llm_response_cache
WHERE timestamp < datetime('now', ? || ' days')
""",
(f"-{max_age_days}",)
)
deleted_count = cursor.rowcount
conn.commit()
if deleted_count > 0:
self._printer.print(
content=f"LLM RESPONSE CACHE: Removed {deleted_count} expired cache entries",
color="green",
)
logger.info(f"Removed {deleted_count} expired cache entries")
except sqlite3.Error as e:
error_msg = f"Failed to cleanup expired cache: {e}"
self._printer.print(
content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
color="red",
)
logger.error(error_msg)
raise
def get_cache_stats(self) -> Dict[str, Any]:
"""
Returns statistics about the cache.
Returns:
A dictionary containing cache statistics.
"""
try:
conn = self._get_connection()
cursor = conn.cursor()
cursor.execute("SELECT COUNT(*) FROM llm_response_cache")
total_count = cursor.fetchone()[0]
cursor.execute("SELECT model, COUNT(*) FROM llm_response_cache GROUP BY model")
model_counts = {model: count for model, count in cursor.fetchall()}
cursor.execute("SELECT MIN(timestamp), MAX(timestamp) FROM llm_response_cache")
oldest, newest = cursor.fetchone()
return {
"total_entries": total_count,
"entries_by_model": model_counts,
"oldest_entry": oldest,
"newest_entry": newest,
}
except sqlite3.Error as e:
error_msg = f"Failed to get cache stats: {e}"
self._printer.print(
content=f"LLM RESPONSE CACHE ERROR: {error_msg}",
color="red",
)
logger.error(error_msg)
return {"error": str(e)}
def __del__(self) -> None:
"""
Closes all connections when the object is garbage collected.
"""
self._close_connections()

View File

@@ -1,156 +0,0 @@
import logging
from typing import Any, Dict, List, Optional
from crewai.memory.storage.llm_response_cache_storage import LLMResponseCacheStorage
logger = logging.getLogger(__name__)
class LLMResponseCacheHandler:
"""
Handler for the LLM response cache storage.
Used for record/replay functionality.
"""
def __init__(self, max_cache_age_days: int = 7) -> None:
"""
Initializes the LLM response cache handler.
Args:
max_cache_age_days: Maximum age of cache entries in days. Defaults to 7.
"""
self.storage = LLMResponseCacheStorage()
self._recording = False
self._replaying = False
self.max_cache_age_days = max_cache_age_days
try:
self.storage.cleanup_expired_cache(self.max_cache_age_days)
except Exception as e:
logger.warning(f"Failed to cleanup expired cache on initialization: {e}")
def start_recording(self) -> None:
"""
Starts recording LLM responses.
"""
self._recording = True
self._replaying = False
logger.info("Started recording LLM responses")
def start_replaying(self) -> None:
"""
Starts replaying LLM responses from the cache.
"""
self._recording = False
self._replaying = True
logger.info("Started replaying LLM responses from cache")
try:
stats = self.storage.get_cache_stats()
logger.info(f"Cache statistics: {stats}")
except Exception as e:
logger.warning(f"Failed to get cache statistics: {e}")
def stop(self) -> None:
"""
Stops recording or replaying.
"""
was_recording = self._recording
was_replaying = self._replaying
self._recording = False
self._replaying = False
if was_recording:
logger.info("Stopped recording LLM responses")
if was_replaying:
logger.info("Stopped replaying LLM responses")
def is_recording(self) -> bool:
"""
Returns whether recording is active.
"""
return self._recording
def is_replaying(self) -> bool:
"""
Returns whether replaying is active.
"""
return self._replaying
def cache_response(self, model: str, messages: List[Dict[str, str]], response: str) -> None:
"""
Caches an LLM response if recording is active.
Args:
model: The model used for the LLM call.
messages: The messages sent to the LLM.
response: The response from the LLM.
"""
if not self._recording:
return
try:
self.storage.add(model, messages, response)
logger.debug(f"Cached response for model {model}")
except Exception as e:
logger.error(f"Failed to cache response: {e}")
def get_cached_response(self, model: str, messages: List[Dict[str, str]]) -> Optional[str]:
"""
Retrieves a cached LLM response if replaying is active.
Returns None if not found or if replaying is not active.
Args:
model: The model used for the LLM call.
messages: The messages sent to the LLM.
Returns:
The cached response, or None if not found or if replaying is not active.
"""
if not self._replaying:
return None
try:
response = self.storage.get(model, messages)
if response:
logger.debug(f"Retrieved cached response for model {model}")
else:
logger.debug(f"No cached response found for model {model}")
return response
except Exception as e:
logger.error(f"Failed to retrieve cached response: {e}")
return None
def clear_cache(self) -> None:
"""
Clears the LLM response cache.
"""
try:
self.storage.delete_all()
logger.info("Cleared LLM response cache")
except Exception as e:
logger.error(f"Failed to clear cache: {e}")
def cleanup_expired_cache(self) -> None:
"""
Removes cache entries older than the maximum age.
"""
try:
self.storage.cleanup_expired_cache(self.max_cache_age_days)
logger.info(f"Cleaned up expired cache entries (older than {self.max_cache_age_days} days)")
except Exception as e:
logger.error(f"Failed to cleanup expired cache: {e}")
def get_cache_stats(self) -> Dict[str, Any]:
"""
Returns statistics about the cache.
Returns:
A dictionary containing cache statistics.
"""
try:
return self.storage.get_cache_stats()
except Exception as e:
logger.error(f"Failed to get cache stats: {e}")
return {"error": str(e)}

View File

@@ -0,0 +1,85 @@
import os
import time
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.csv_knowledge_source import CSVKnowledgeSource
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
@patch('crewai.knowledge.storage.knowledge_storage.KnowledgeStorage.search')
@patch('crewai.knowledge.source.csv_knowledge_source.CSVKnowledgeSource.add')
def test_csv_knowledge_source_updates(mock_add, mock_search, tmpdir):
"""Test that CSVKnowledgeSource properly detects and loads updates to CSV files."""
mock_search.side_effect = [
[{"context": "name,age,city\nJohn,30,New York\nAlice,25,San Francisco\nBob,28,Chicago"}],
[{"context": "name,age,city\nJohn,30,Boston\nAlice,25,San Francisco\nBob,28,Chicago\nEve,22,Miami"}],
[{"context": "name,age,city\nJohn,30,Boston\nAlice,25,San Francisco\nBob,28,Chicago\nEve,22,Miami"}]
]
csv_path = str(tmpdir / "test_updates.csv")
initial_csv_content = [
["name", "age", "city"],
["John", "30", "New York"],
["Alice", "25", "San Francisco"],
["Bob", "28", "Chicago"],
]
with open(csv_path, "w") as f:
for row in initial_csv_content:
f.write(",".join(row) + "\n")
csv_source = CSVKnowledgeSource(file_paths=[csv_path])
original_files_have_changed = csv_source.files_have_changed
files_changed_called = [False]
def spy_files_have_changed():
files_changed_called[0] = True
return original_files_have_changed()
csv_source.files_have_changed = spy_files_have_changed
knowledge = Knowledge(sources=[csv_source], collection_name="test_updates")
assert hasattr(knowledge, '_check_and_reload_sources'), "Knowledge class is missing _check_and_reload_sources method"
initial_results = knowledge.query(["John"])
assert any("John" in result["context"] for result in initial_results)
assert any("New York" in result["context"] for result in initial_results)
mock_add.reset_mock()
files_changed_called[0] = False
updated_csv_content = [
["name", "age", "city"],
["John", "30", "Boston"], # Changed city
["Alice", "25", "San Francisco"],
["Bob", "28", "Chicago"],
["Eve", "22", "Miami"], # Added new person
]
time.sleep(1)
csv_path_str = str(csv_path)
with open(csv_path_str, "w") as f:
for row in updated_csv_content:
f.write(",".join(row) + "\n")
updated_results = knowledge.query(["John"])
assert files_changed_called[0], "files_have_changed method was not called during query"
assert mock_add.called, "add method was not called to reload the data"
assert any("John" in result["context"] for result in updated_results)
assert any("Boston" in result["context"] for result in updated_results)
assert not any("New York" in result["context"] for result in updated_results)
new_results = knowledge.query(["Eve"])
assert any("Eve" in result["context"] for result in new_results)
assert any("Miami" in result["context"] for result in new_results)

View File

@@ -1,155 +0,0 @@
import sqlite3
import threading
from concurrent.futures import ThreadPoolExecutor
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from crewai.llm import LLM
from crewai.memory.storage.llm_response_cache_storage import LLMResponseCacheStorage
from crewai.utilities.llm_response_cache_handler import LLMResponseCacheHandler
@pytest.fixture
def handler():
handler = LLMResponseCacheHandler()
handler.storage.add = MagicMock()
handler.storage.get = MagicMock()
return handler
def create_mock_response(content):
"""Create a properly structured mock response object for litellm.completion"""
message = SimpleNamespace(content=content)
choice = SimpleNamespace(message=message)
response = SimpleNamespace(choices=[choice])
return response
@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_recording(handler):
handler.start_recording()
llm = LLM(model="gpt-4o-mini")
llm.set_response_cache_handler(handler)
messages = [{"role": "user", "content": "Hello, world!"}]
with patch('litellm.completion') as mock_completion:
mock_completion.return_value = create_mock_response("Hello, human!")
response = llm.call(messages)
assert response == "Hello, human!"
handler.storage.add.assert_called_once_with(
"gpt-4o-mini", messages, "Hello, human!"
)
@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_replaying(handler):
handler.start_replaying()
handler.storage.get.return_value = "Cached response"
llm = LLM(model="gpt-4o-mini")
llm.set_response_cache_handler(handler)
messages = [{"role": "user", "content": "Hello, world!"}]
with patch('litellm.completion') as mock_completion:
response = llm.call(messages)
assert response == "Cached response"
mock_completion.assert_not_called()
handler.storage.get.assert_called_once_with("gpt-4o-mini", messages)
@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_replay_fallback(handler):
handler.start_replaying()
handler.storage.get.return_value = None
llm = LLM(model="gpt-4o-mini")
llm.set_response_cache_handler(handler)
messages = [{"role": "user", "content": "Hello, world!"}]
with patch('litellm.completion') as mock_completion:
mock_completion.return_value = create_mock_response("Hello, human!")
response = llm.call(messages)
assert response == "Hello, human!"
mock_completion.assert_called_once()
@pytest.mark.vcr(filter_headers=["authorization"])
def test_cache_error_handling():
"""Test that errors during cache operations are handled gracefully."""
handler = LLMResponseCacheHandler()
handler.storage.add = MagicMock(side_effect=sqlite3.Error("Mock DB error"))
handler.storage.get = MagicMock(side_effect=sqlite3.Error("Mock DB error"))
handler.start_recording()
handler.cache_response("model", [{"role": "user", "content": "test"}], "response")
handler.start_replaying()
assert handler.get_cached_response("model", [{"role": "user", "content": "test"}]) is None
@pytest.mark.vcr(filter_headers=["authorization"])
def test_cache_expiration():
"""Test that cache expiration works correctly."""
import sqlite3
conn = sqlite3.connect(":memory:")
cursor = conn.cursor()
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS llm_response_cache (
request_hash TEXT PRIMARY KEY,
model TEXT,
messages TEXT,
response TEXT,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
)
"""
)
conn.commit()
storage = LLMResponseCacheStorage(":memory:")
original_get_connection = storage._get_connection
storage._get_connection = lambda: conn
try:
model = "test-model"
messages = [{"role": "user", "content": "test"}]
response = "test response"
storage.add(model, messages, response)
assert storage.get(model, messages) == response
storage.cleanup_expired_cache(max_age_days=0)
assert storage.get(model, messages) is None
finally:
storage._get_connection = original_get_connection
conn.close()
@pytest.mark.vcr(filter_headers=["authorization"])
def test_concurrent_cache_access():
"""Test that concurrent cache access works correctly."""
pytest.skip("SQLite in-memory databases are not shared between threads")
# storage = LLMResponseCacheStorage(temp_db.name)

View File

@@ -1,93 +0,0 @@
from unittest.mock import MagicMock, patch
import pytest
from crewai.agent import Agent
from crewai.crew import Crew
from crewai.process import Process
from crewai.task import Task
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_recording_mode():
agent = Agent(
role="Test Agent",
goal="Test the recording functionality",
backstory="A test agent for recording LLM responses",
)
task = Task(
description="Return a simple response",
expected_output="A simple response",
agent=agent,
)
crew = Crew(
agents=[agent],
tasks=[task],
process=Process.sequential,
record_mode=True,
)
mock_handler = MagicMock()
crew._llm_response_cache_handler = mock_handler
mock_llm = MagicMock()
agent.llm = mock_llm
with patch('crewai.agent.Agent.execute_task', return_value="Test response"):
with patch('crewai.utilities.llm_response_cache_handler.LLMResponseCacheHandler', return_value=mock_handler):
crew.kickoff()
mock_handler.start_recording.assert_called_once()
mock_llm.set_response_cache_handler.assert_called_once_with(mock_handler)
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_replay_mode():
agent = Agent(
role="Test Agent",
goal="Test the replay functionality",
backstory="A test agent for replaying LLM responses",
)
task = Task(
description="Return a simple response",
expected_output="A simple response",
agent=agent,
)
crew = Crew(
agents=[agent],
tasks=[task],
process=Process.sequential,
replay_mode=True,
)
mock_handler = MagicMock()
crew._llm_response_cache_handler = mock_handler
mock_llm = MagicMock()
agent.llm = mock_llm
with patch('crewai.agent.Agent.execute_task', return_value="Test response"):
with patch('crewai.utilities.llm_response_cache_handler.LLMResponseCacheHandler', return_value=mock_handler):
crew.kickoff()
mock_handler.start_replaying.assert_called_once()
mock_llm.set_response_cache_handler.assert_called_once_with(mock_handler)
@pytest.mark.vcr(filter_headers=["authorization"])
def test_record_replay_flags_conflict():
with pytest.raises(ValueError):
crew = Crew(
agents=[],
tasks=[],
process=Process.sequential,
record_mode=True,
replay_mode=True,
)
crew.kickoff()