cleaner refactor

This commit is contained in:
Lorenze Jay
2024-12-03 12:18:07 -08:00
parent 7e93285df1
commit 8b1aef9e2d
6 changed files with 24 additions and 46 deletions

View File

@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
from pathlib import Path
from typing import Union, List, Dict, Any
from pydantic import Field, PrivateAttr
from pydantic import Field
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.utilities.logger import Logger
@@ -19,11 +19,11 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
)
content: Dict[Path, str] = Field(init=False, default_factory=dict)
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
_safe_file_paths: List[Path] = PrivateAttr(default_factory=list)
safe_file_paths: List[Path] = Field(default_factory=list)
def model_post_init(self, _):
"""Post-initialization method to load content."""
self._safe_file_paths = self._process_file_paths()
self.safe_file_paths = self._process_file_paths()
self.validate_paths()
self.content = self.load_content()
@@ -34,21 +34,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
def validate_paths(self):
"""Validate the paths."""
if isinstance(self.file_path, str):
self.file_path = self.convert_to_path(self.file_path)
elif isinstance(self.file_path, list):
processed_paths = []
for path in self.file_path:
processed_paths.append(self.convert_to_path(path))
self.file_path = processed_paths
paths = [self.file_path] if isinstance(self.file_path, Path) else self.file_path
if not isinstance(paths, list):
raise ValueError("file_path must be a Path or a list of Paths")
paths = [Path(path) if isinstance(path, str) else path for path in paths]
for path in paths:
for path in self.safe_file_paths:
if not path.exists():
self._logger.log(
"error",

View File

@@ -10,18 +10,15 @@ class CSVKnowledgeSource(BaseFileKnowledgeSource):
def load_content(self) -> Dict[Path, str]:
"""Load and preprocess CSV file content."""
file_path = (
self.file_path[0] if isinstance(self.file_path, list) else self.file_path
)
file_path = Path(file_path) if isinstance(file_path, str) else file_path
with open(file_path, "r", encoding="utf-8") as csvfile:
reader = csv.reader(csvfile)
content = ""
for row in reader:
content += " ".join(row) + "\n"
return {file_path: content}
content_dict = {}
for file_path in self.safe_file_paths:
with open(file_path, "r", encoding="utf-8") as csvfile:
reader = csv.reader(csvfile)
content = ""
for row in reader:
content += " ".join(row) + "\n"
content_dict[file_path] = content
return content_dict
def add(self) -> None:
"""

View File

@@ -10,14 +10,13 @@ class ExcelKnowledgeSource(BaseFileKnowledgeSource):
"""Load and preprocess Excel file content."""
pd = self._import_dependencies()
if isinstance(self.file_path, list):
file_path = self.convert_to_path(self.file_path[0])
else:
file_path = self.convert_to_path(self.file_path)
df = pd.read_excel(file_path)
content = df.to_csv(index=False)
return {file_path: content}
content_dict = {}
for file_path in self.safe_file_paths:
file_path = self.convert_to_path(file_path)
df = pd.read_excel(file_path)
content = df.to_csv(index=False)
content_dict[file_path] = content
return content_dict
def _import_dependencies(self):
"""Dynamically import dependencies."""

View File

@@ -10,10 +10,8 @@ class JSONKnowledgeSource(BaseFileKnowledgeSource):
def load_content(self) -> Dict[Path, str]:
"""Load and preprocess JSON file content."""
paths = [self.file_path] if isinstance(self.file_path, Path) else self.file_path
content: Dict[Path, str] = {}
for path in paths:
for path in self.safe_file_paths:
path = self.convert_to_path(path)
with open(path, "r", encoding="utf-8") as json_file:
data = json.load(json_file)

View File

@@ -11,10 +11,9 @@ class PDFKnowledgeSource(BaseFileKnowledgeSource):
"""Load and preprocess PDF file content."""
pdfplumber = self._import_pdfplumber()
paths = [self.file_path] if isinstance(self.file_path, Path) else self.file_path
content = {}
for path in paths:
for path in self.safe_file_paths:
text = ""
path = self.convert_to_path(path)
with pdfplumber.open(path) as pdf:

View File

@@ -9,10 +9,9 @@ class TextFileKnowledgeSource(BaseFileKnowledgeSource):
def load_content(self) -> Dict[Path, str]:
"""Load and preprocess text file content."""
paths = [self.file_path] if isinstance(self.file_path, Path) else self.file_path
content = {}
for path in paths:
path = Path(path)
for path in self.safe_file_paths:
path = self.convert_to_path(path)
with open(path, "r", encoding="utf-8") as f:
content[path] = f.read()
return content