mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
Address code review comments: improve error handling, add thread safety, enhance documentation
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -5,6 +5,7 @@ from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
|
||||
from crewai.utilities.logger import Logger
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed
|
||||
|
||||
@@ -12,10 +13,19 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed
|
||||
class Knowledge(BaseModel):
|
||||
"""
|
||||
Knowledge is a collection of sources and setup for the vector store to save and query relevant context.
|
||||
|
||||
This class manages knowledge sources and provides methods to query them for relevant information.
|
||||
It automatically detects and reloads file-based knowledge sources when their underlying files change.
|
||||
|
||||
Args:
|
||||
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
|
||||
The knowledge sources to use for querying.
|
||||
storage: Optional[KnowledgeStorage] = Field(default=None)
|
||||
The storage backend for knowledge embeddings.
|
||||
embedder: Optional[Dict[str, Any]] = None
|
||||
Configuration for the embedding model.
|
||||
collection_name: Optional[str] = None
|
||||
Name of the collection to use for storage.
|
||||
"""
|
||||
|
||||
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
|
||||
@@ -23,6 +33,7 @@ class Knowledge(BaseModel):
|
||||
storage: Optional[KnowledgeStorage] = Field(default=None)
|
||||
embedder: Optional[Dict[str, Any]] = None
|
||||
collection_name: Optional[str] = None
|
||||
_logger: Logger = Logger(verbose=True)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -65,12 +76,30 @@ class Knowledge(BaseModel):
|
||||
return results
|
||||
|
||||
def _check_and_reload_sources(self):
|
||||
"""Check if any sources have changed and reload them if necessary."""
|
||||
"""
|
||||
Check if any file-based knowledge sources have changed and reload them if necessary.
|
||||
|
||||
This method detects modifications to source files by comparing their modification timestamps
|
||||
with previously recorded values. When changes are detected, the source is reloaded and
|
||||
the storage is updated with the new content.
|
||||
|
||||
Handles specific exceptions for file operations to provide better error reporting.
|
||||
"""
|
||||
for source in self.sources:
|
||||
if hasattr(source, 'files_have_changed') and source.files_have_changed():
|
||||
source._record_file_mtimes() # Update timestamps
|
||||
source.content = source.load_content()
|
||||
source.add() # Reload and update storage
|
||||
try:
|
||||
if hasattr(source, 'files_have_changed') and source.files_have_changed():
|
||||
self._logger.log("info", f"Reloading modified source: {source.__class__.__name__}")
|
||||
source._record_file_mtimes() # Update timestamps
|
||||
source.content = source.load_content()
|
||||
source.add() # Reload and update storage
|
||||
except FileNotFoundError as e:
|
||||
self._logger.log("error", f"File not found when checking for updates: {str(e)}")
|
||||
except PermissionError as e:
|
||||
self._logger.log("error", f"Permission error when checking for updates: {str(e)}")
|
||||
except IOError as e:
|
||||
self._logger.log("error", f"IO error when checking for updates: {str(e)}")
|
||||
except Exception as e:
|
||||
self._logger.log("error", f"Unexpected error when checking for updates: {str(e)}")
|
||||
|
||||
def add_sources(self):
|
||||
try:
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from threading import RLock
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
@@ -11,9 +13,25 @@ from crewai.utilities.logger import Logger
|
||||
|
||||
|
||||
class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
"""Base class for knowledge sources that load content from files."""
|
||||
"""
|
||||
Base class for knowledge sources that load content from files.
|
||||
|
||||
This class provides common functionality for file-based knowledge sources,
|
||||
including file path validation, content loading, and change detection.
|
||||
It automatically tracks file modification times to detect when files have
|
||||
been updated and need to be reloaded.
|
||||
|
||||
Attributes:
|
||||
file_path: Deprecated. Use file_paths instead.
|
||||
file_paths: Path(s) to the file(s) containing knowledge data.
|
||||
content: Dictionary mapping file paths to their loaded content.
|
||||
storage: Storage backend for the knowledge data.
|
||||
safe_file_paths: Validated list of Path objects.
|
||||
"""
|
||||
|
||||
_logger: Logger = Logger(verbose=True)
|
||||
_lock: RLock = RLock() # Thread-safe lock for file operations
|
||||
|
||||
file_path: Optional[Union[Path, List[Path], str, List[str]]] = Field(
|
||||
default=None,
|
||||
description="[Deprecated] The path to the file. Use file_paths instead.",
|
||||
@@ -47,11 +65,30 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
self.content = self.load_content()
|
||||
|
||||
def _record_file_mtimes(self):
|
||||
"""Record modification times of all files."""
|
||||
self._file_mtimes = {}
|
||||
for path in self.safe_file_paths:
|
||||
if path.exists() and path.is_file():
|
||||
self._file_mtimes[path] = path.stat().st_mtime
|
||||
"""
|
||||
Record modification times of all files.
|
||||
|
||||
This method stores the current modification timestamps of all files
|
||||
in the _file_mtimes dictionary. These timestamps are later used to
|
||||
detect when files have been modified and need to be reloaded.
|
||||
|
||||
Thread-safe: Uses a lock to prevent concurrent modifications.
|
||||
"""
|
||||
with self._lock:
|
||||
self._file_mtimes = {}
|
||||
for path in self.safe_file_paths:
|
||||
try:
|
||||
if path.exists() and path.is_file():
|
||||
if os.access(path, os.R_OK):
|
||||
self._file_mtimes[path] = path.stat().st_mtime
|
||||
else:
|
||||
self._logger.log("warning", f"File {path} is not readable.")
|
||||
except PermissionError as e:
|
||||
self._logger.log("error", f"Permission error when recording file timestamp for {path}: {str(e)}")
|
||||
except IOError as e:
|
||||
self._logger.log("error", f"IO error when recording file timestamp for {path}: {str(e)}")
|
||||
except Exception as e:
|
||||
self._logger.log("error", f"Unexpected error when recording file timestamp for {path}: {str(e)}")
|
||||
|
||||
@abstractmethod
|
||||
def load_content(self) -> Dict[Path, str]:
|
||||
@@ -117,12 +154,42 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
return [self.convert_to_path(path) for path in path_list]
|
||||
|
||||
def files_have_changed(self) -> bool:
|
||||
"""Check if any of the files have been modified since they were last loaded."""
|
||||
for path in self.safe_file_paths:
|
||||
if not path.exists() or not path.is_file():
|
||||
continue
|
||||
current_mtime = path.stat().st_mtime
|
||||
if path not in self._file_mtimes or current_mtime > self._file_mtimes[path]:
|
||||
self._logger.log("info", f"File {path} has been modified. Reloading data.")
|
||||
return True
|
||||
return False
|
||||
"""
|
||||
Check if any of the files have been modified since they were last loaded.
|
||||
|
||||
This method compares the current modification timestamps of files with the
|
||||
previously recorded timestamps to detect changes. When a file has been modified,
|
||||
it logs the change and returns True to trigger a reload.
|
||||
|
||||
Thread-safe: Uses a lock to prevent concurrent modifications.
|
||||
|
||||
Returns:
|
||||
bool: True if any file has been modified, False otherwise.
|
||||
"""
|
||||
with self._lock:
|
||||
for path in self.safe_file_paths:
|
||||
try:
|
||||
if not path.exists():
|
||||
self._logger.log("warning", f"File {path} no longer exists.")
|
||||
continue
|
||||
|
||||
if not path.is_file():
|
||||
self._logger.log("warning", f"Path {path} is not a file.")
|
||||
continue
|
||||
|
||||
if not os.access(path, os.R_OK):
|
||||
self._logger.log("warning", f"File {path} is not readable.")
|
||||
continue
|
||||
|
||||
current_mtime = path.stat().st_mtime
|
||||
if path not in self._file_mtimes or current_mtime > self._file_mtimes[path]:
|
||||
self._logger.log("info", f"File {path} has been modified. Reloading data.")
|
||||
return True
|
||||
except PermissionError as e:
|
||||
self._logger.log("error", f"Permission error when checking file {path}: {str(e)}")
|
||||
except IOError as e:
|
||||
self._logger.log("error", f"IO error when checking file {path}: {str(e)}")
|
||||
except Exception as e:
|
||||
self._logger.log("error", f"Unexpected error when checking file {path}: {str(e)}")
|
||||
|
||||
return False
|
||||
|
||||
@@ -20,7 +20,7 @@ def test_csv_knowledge_source_updates(mock_add, mock_search, tmpdir):
|
||||
[{"context": "name,age,city\nJohn,30,Boston\nAlice,25,San Francisco\nBob,28,Chicago\nEve,22,Miami"}]
|
||||
]
|
||||
|
||||
csv_path = tmpdir / "test_updates.csv"
|
||||
csv_path = str(tmpdir / "test_updates.csv")
|
||||
|
||||
initial_csv_content = [
|
||||
["name", "age", "city"],
|
||||
@@ -33,7 +33,7 @@ def test_csv_knowledge_source_updates(mock_add, mock_search, tmpdir):
|
||||
for row in initial_csv_content:
|
||||
f.write(",".join(row) + "\n")
|
||||
|
||||
csv_source = CSVKnowledgeSource(file_paths=[csv_path])
|
||||
csv_source = CSVKnowledgeSource(file_paths=csv_path)
|
||||
|
||||
original_files_have_changed = csv_source.files_have_changed
|
||||
files_changed_called = [False]
|
||||
|
||||
Reference in New Issue
Block a user