cleaner refactor

This commit is contained in:
Lorenze Jay
2024-12-03 12:18:07 -08:00
parent 7e93285df1
commit 8b1aef9e2d
6 changed files with 24 additions and 46 deletions

View File

@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Union, List, Dict, Any from typing import Union, List, Dict, Any
from pydantic import Field, PrivateAttr from pydantic import Field
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.utilities.logger import Logger from crewai.utilities.logger import Logger
@@ -19,11 +19,11 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
) )
content: Dict[Path, str] = Field(init=False, default_factory=dict) content: Dict[Path, str] = Field(init=False, default_factory=dict)
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage) storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
_safe_file_paths: List[Path] = PrivateAttr(default_factory=list) safe_file_paths: List[Path] = Field(default_factory=list)
def model_post_init(self, _): def model_post_init(self, _):
"""Post-initialization method to load content.""" """Post-initialization method to load content."""
self._safe_file_paths = self._process_file_paths() self.safe_file_paths = self._process_file_paths()
self.validate_paths() self.validate_paths()
self.content = self.load_content() self.content = self.load_content()
@@ -34,21 +34,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
def validate_paths(self): def validate_paths(self):
"""Validate the paths.""" """Validate the paths."""
if isinstance(self.file_path, str): for path in self.safe_file_paths:
self.file_path = self.convert_to_path(self.file_path)
elif isinstance(self.file_path, list):
processed_paths = []
for path in self.file_path:
processed_paths.append(self.convert_to_path(path))
self.file_path = processed_paths
paths = [self.file_path] if isinstance(self.file_path, Path) else self.file_path
if not isinstance(paths, list):
raise ValueError("file_path must be a Path or a list of Paths")
paths = [Path(path) if isinstance(path, str) else path for path in paths]
for path in paths:
if not path.exists(): if not path.exists():
self._logger.log( self._logger.log(
"error", "error",

View File

@@ -10,18 +10,15 @@ class CSVKnowledgeSource(BaseFileKnowledgeSource):
def load_content(self) -> Dict[Path, str]: def load_content(self) -> Dict[Path, str]:
"""Load and preprocess CSV file content.""" """Load and preprocess CSV file content."""
content_dict = {}
file_path = ( for file_path in self.safe_file_paths:
self.file_path[0] if isinstance(self.file_path, list) else self.file_path with open(file_path, "r", encoding="utf-8") as csvfile:
) reader = csv.reader(csvfile)
file_path = Path(file_path) if isinstance(file_path, str) else file_path content = ""
for row in reader:
with open(file_path, "r", encoding="utf-8") as csvfile: content += " ".join(row) + "\n"
reader = csv.reader(csvfile) content_dict[file_path] = content
content = "" return content_dict
for row in reader:
content += " ".join(row) + "\n"
return {file_path: content}
def add(self) -> None: def add(self) -> None:
""" """

View File

@@ -10,14 +10,13 @@ class ExcelKnowledgeSource(BaseFileKnowledgeSource):
"""Load and preprocess Excel file content.""" """Load and preprocess Excel file content."""
pd = self._import_dependencies() pd = self._import_dependencies()
if isinstance(self.file_path, list): content_dict = {}
file_path = self.convert_to_path(self.file_path[0]) for file_path in self.safe_file_paths:
else: file_path = self.convert_to_path(file_path)
file_path = self.convert_to_path(self.file_path) df = pd.read_excel(file_path)
content = df.to_csv(index=False)
df = pd.read_excel(file_path) content_dict[file_path] = content
content = df.to_csv(index=False) return content_dict
return {file_path: content}
def _import_dependencies(self): def _import_dependencies(self):
"""Dynamically import dependencies.""" """Dynamically import dependencies."""

View File

@@ -10,10 +10,8 @@ class JSONKnowledgeSource(BaseFileKnowledgeSource):
def load_content(self) -> Dict[Path, str]: def load_content(self) -> Dict[Path, str]:
"""Load and preprocess JSON file content.""" """Load and preprocess JSON file content."""
paths = [self.file_path] if isinstance(self.file_path, Path) else self.file_path
content: Dict[Path, str] = {} content: Dict[Path, str] = {}
for path in paths: for path in self.safe_file_paths:
path = self.convert_to_path(path) path = self.convert_to_path(path)
with open(path, "r", encoding="utf-8") as json_file: with open(path, "r", encoding="utf-8") as json_file:
data = json.load(json_file) data = json.load(json_file)

View File

@@ -11,10 +11,9 @@ class PDFKnowledgeSource(BaseFileKnowledgeSource):
"""Load and preprocess PDF file content.""" """Load and preprocess PDF file content."""
pdfplumber = self._import_pdfplumber() pdfplumber = self._import_pdfplumber()
paths = [self.file_path] if isinstance(self.file_path, Path) else self.file_path
content = {} content = {}
for path in paths: for path in self.safe_file_paths:
text = "" text = ""
path = self.convert_to_path(path) path = self.convert_to_path(path)
with pdfplumber.open(path) as pdf: with pdfplumber.open(path) as pdf:

View File

@@ -9,10 +9,9 @@ class TextFileKnowledgeSource(BaseFileKnowledgeSource):
def load_content(self) -> Dict[Path, str]: def load_content(self) -> Dict[Path, str]:
"""Load and preprocess text file content.""" """Load and preprocess text file content."""
paths = [self.file_path] if isinstance(self.file_path, Path) else self.file_path
content = {} content = {}
for path in paths: for path in self.safe_file_paths:
path = Path(path) path = self.convert_to_path(path)
with open(path, "r", encoding="utf-8") as f: with open(path, "r", encoding="utf-8") as f:
content[path] = f.read() content[path] = f.read()
return content return content