Address code review comments: improve error handling, add thread safety, enhance documentation

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-05-06 00:11:28 +00:00
parent c2bf2b3210
commit 6e0f1fe38d
3 changed files with 118 additions and 22 deletions

View File

@@ -5,6 +5,7 @@ from pydantic import BaseModel, ConfigDict, Field
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
from crewai.utilities.logger import Logger
os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed
@@ -12,10 +13,19 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed
class Knowledge(BaseModel):
"""
Knowledge is a collection of sources and setup for the vector store to save and query relevant context.
This class manages knowledge sources and provides methods to query them for relevant information.
It automatically detects and reloads file-based knowledge sources when their underlying files change.
Args:
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
The knowledge sources to use for querying.
storage: Optional[KnowledgeStorage] = Field(default=None)
The storage backend for knowledge embeddings.
embedder: Optional[Dict[str, Any]] = None
Configuration for the embedding model.
collection_name: Optional[str] = None
Name of the collection to use for storage.
"""
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
@@ -23,6 +33,7 @@ class Knowledge(BaseModel):
storage: Optional[KnowledgeStorage] = Field(default=None)
embedder: Optional[Dict[str, Any]] = None
collection_name: Optional[str] = None
_logger: Logger = Logger(verbose=True)
def __init__(
self,
@@ -65,12 +76,30 @@ class Knowledge(BaseModel):
return results
def _check_and_reload_sources(self):
    """
    Reload every knowledge source whose underlying files have changed.

    Sources that do not expose ``files_have_changed`` are skipped. For each
    remaining source the method asks ``files_have_changed()``; when it
    returns True the content is re-read via ``load_content()``, pushed to
    storage via ``add()``, and only then are the file timestamps re-recorded.

    Errors are logged and never propagated, so one failing source does not
    prevent the remaining sources from being checked.
    """
    for source in self.sources:
        # Sources without change detection are left untouched.
        if not hasattr(source, "files_have_changed"):
            continue
        try:
            if not source.files_have_changed():
                continue
            self._logger.log("info", f"Reloading modified source: {source.__class__.__name__}")
            source.content = source.load_content()
            source.add()  # Reload and update storage
            # Record the new timestamps only after the reload succeeded.
            # Recording them first would mark the files as "seen" even when
            # load_content()/add() raised, so the change would never be
            # retried on a later call.
            source._record_file_mtimes()
        except FileNotFoundError as e:
            self._logger.log("error", f"File not found when checking for updates: {e}")
        except PermissionError as e:
            self._logger.log("error", f"Permission error when checking for updates: {e}")
        except IOError as e:
            self._logger.log("error", f"IO error when checking for updates: {e}")
        except Exception as e:
            # Best-effort refresh: a broken source must not abort the caller.
            self._logger.log("error", f"Unexpected error when checking for updates: {e}")
def add_sources(self):
try:

View File

@@ -1,5 +1,7 @@
import os
from abc import ABC, abstractmethod
from pathlib import Path
from threading import RLock
from typing import Dict, List, Optional, Union
from pydantic import Field, field_validator
@@ -11,9 +13,25 @@ from crewai.utilities.logger import Logger
class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
"""Base class for knowledge sources that load content from files."""
"""
Base class for knowledge sources that load content from files.
This class provides common functionality for file-based knowledge sources,
including file path validation, content loading, and change detection.
It automatically tracks file modification times to detect when files have
been updated and need to be reloaded.
Attributes:
file_path: Deprecated. Use file_paths instead.
file_paths: Path(s) to the file(s) containing knowledge data.
content: Dictionary mapping file paths to their loaded content.
storage: Storage backend for the knowledge data.
safe_file_paths: Validated list of Path objects.
"""
_logger: Logger = Logger(verbose=True)
_lock: RLock = RLock() # Thread-safe lock for file operations
file_path: Optional[Union[Path, List[Path], str, List[str]]] = Field(
default=None,
description="[Deprecated] The path to the file. Use file_paths instead.",
@@ -47,11 +65,30 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
self.content = self.load_content()
def _record_file_mtimes(self):
    """
    Record the current modification time of every readable file.

    Rebuilds ``self._file_mtimes`` as a mapping from each path in
    ``self.safe_file_paths`` to its ``st_mtime``. Paths that are not regular
    files (including missing paths) are skipped silently; unreadable files
    are skipped with a warning. Errors for one path are logged and do not
    stop the remaining paths from being recorded.

    Runs while holding ``self._lock``.
    """
    with self._lock:
        # Build the snapshot locally and publish it with one assignment, so
        # the attribute never holds a half-filled mapping.
        mtimes = {}
        for path in self.safe_file_paths:
            try:
                # is_file() is already False for a missing path, so a
                # separate exists() check is unnecessary.
                if not path.is_file():
                    continue
                if not os.access(path, os.R_OK):
                    self._logger.log("warning", f"File {path} is not readable.")
                    continue
                mtimes[path] = path.stat().st_mtime
            except PermissionError as e:
                self._logger.log("error", f"Permission error when recording file timestamp for {path}: {e}")
            except IOError as e:
                self._logger.log("error", f"IO error when recording file timestamp for {path}: {e}")
            except Exception as e:
                self._logger.log("error", f"Unexpected error when recording file timestamp for {path}: {e}")
        self._file_mtimes = mtimes
@abstractmethod
def load_content(self) -> Dict[Path, str]:
@@ -117,12 +154,42 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
return [self.convert_to_path(path) for path in path_list]
def files_have_changed(self) -> bool:
    """
    Report whether any tracked file differs from its recorded timestamp.

    Each path in ``self.safe_file_paths`` is compared against the value
    stored in ``self._file_mtimes``. A file counts as changed when it has no
    recorded timestamp or when its current ``st_mtime`` differs from the
    recorded one. Paths that are missing, are not regular files, or are not
    readable are skipped with a warning. Errors for one path are logged and
    the remaining paths are still checked.

    Runs while holding ``self._lock``.

    Returns:
        bool: True as soon as one changed file is found, False otherwise.
    """
    with self._lock:
        for path in self.safe_file_paths:
            try:
                if not path.exists():
                    self._logger.log("warning", f"File {path} no longer exists.")
                    continue
                if not path.is_file():
                    self._logger.log("warning", f"Path {path} is not a file.")
                    continue
                if not os.access(path, os.R_OK):
                    self._logger.log("warning", f"File {path} is not readable.")
                    continue

                current_mtime = path.stat().st_mtime
                # Compare with != rather than >: replacing a file with an
                # older copy (restored backup, moved-in file that keeps its
                # timestamp) moves mtime backwards and must still reload.
                if path not in self._file_mtimes or current_mtime != self._file_mtimes[path]:
                    self._logger.log("info", f"File {path} has been modified. Reloading data.")
                    return True
            except PermissionError as e:
                self._logger.log("error", f"Permission error when checking file {path}: {e}")
            except IOError as e:
                self._logger.log("error", f"IO error when checking file {path}: {e}")
            except Exception as e:
                self._logger.log("error", f"Unexpected error when checking file {path}: {e}")
    return False

View File

@@ -20,7 +20,7 @@ def test_csv_knowledge_source_updates(mock_add, mock_search, tmpdir):
[{"context": "name,age,city\nJohn,30,Boston\nAlice,25,San Francisco\nBob,28,Chicago\nEve,22,Miami"}]
]
csv_path = tmpdir / "test_updates.csv"
csv_path = str(tmpdir / "test_updates.csv")
initial_csv_content = [
["name", "age", "city"],
@@ -33,7 +33,7 @@ def test_csv_knowledge_source_updates(mock_add, mock_search, tmpdir):
for row in initial_csv_content:
f.write(",".join(row) + "\n")
csv_source = CSVKnowledgeSource(file_paths=[csv_path])
csv_source = CSVKnowledgeSource(file_paths=csv_path)
original_files_have_changed = csv_source.files_have_changed
files_changed_called = [False]