logged but file_path is backwards compatible

This commit is contained in:
Lorenze Jay
2024-12-16 16:30:47 -08:00
parent f1c9caa8ec
commit 054bc266b9

View File

@@ -19,11 +19,20 @@ class DoclingSource(BaseFileKnowledgeSource):
file_paths: List[str] = Field(default_factory=list) file_paths: List[str] = Field(default_factory=list)
document_converter: DocumentConverter = Field(default_factory=DocumentConverter) document_converter: DocumentConverter = Field(default_factory=DocumentConverter)
safe_file_paths: List[Union[Path, str]] = Field(default_factory=list)
content: List[DoclingDocument] | None = Field(default=None)
chunks: List[str] = Field(default_factory=list) chunks: List[str] = Field(default_factory=list)
# We accept URL strings here and validate that they are well-formed URLs
# Overriding content to be a list of DoclingDocuments
safe_file_paths: List[Union[Path, str]] = Field(default_factory=list) # type: ignore[assignment]
content: List[DoclingDocument] | None = Field(default=None) # type: ignore[assignment]
def model_post_init(self, _) -> None: def model_post_init(self, _) -> None:
if self.file_path:
self._logger.log(
"warning",
"The 'file_path' attribute is deprecated and will be removed in a future version. Please use 'file_paths' instead.",
color="yellow",
)
self.file_paths = self.file_path # type: ignore[assignment]
self.safe_file_paths = self._process_file_paths() self.safe_file_paths = self._process_file_paths()
self.document_converter = DocumentConverter( self.document_converter = DocumentConverter(
allowed_formats=[ allowed_formats=[
@@ -39,7 +48,7 @@ class DoclingSource(BaseFileKnowledgeSource):
) )
self.content = self.load_content() self.content = self.load_content()
def load_content(self) -> List[DoclingDocument] | None: def load_content(self) -> List[DoclingDocument] | None: # type: ignore[assignment]
try: try:
return self.convert_source_to_docling_documents() return self.convert_source_to_docling_documents()
except Exception as e: except Exception as e:
@@ -58,28 +67,43 @@ class DoclingSource(BaseFileKnowledgeSource):
conv_results_iter = self.document_converter.convert_all(self.safe_file_paths) conv_results_iter = self.document_converter.convert_all(self.safe_file_paths)
return [result.document for result in conv_results_iter] return [result.document for result in conv_results_iter]
def _chunk_text(self, doc: DoclingDocument) -> Iterator[str]: def _chunk_text(self, doc: DoclingDocument) -> Iterator[str]: # type: ignore[assignment]
chunker = HierarchicalChunker() chunker = HierarchicalChunker()
for chunk in chunker.chunk(doc): for chunk in chunker.chunk(doc):
yield chunk.text yield chunk.text
def _process_file_paths(self) -> list[Path | str]: def _process_file_paths(self) -> list[Path | str]: # type: ignore[assignment]
processed_paths = [] processed_paths = []
for path in self.file_paths: for path in self.file_paths:
if path.startswith("http"): if isinstance(path, str):
if path.startswith(("http://", "https://")): if path.startswith(("http://", "https://")):
try: try:
result = urlparse(path) if self._validate_url(path):
if all([result.scheme, result.netloc]): # Basic URL validation
processed_paths.append(path) processed_paths.append(path)
else: else:
raise ValueError(f"Invalid URL format: {path}") raise ValueError(f"Invalid URL format: {path}")
except Exception as e: except Exception as e:
raise ValueError(f"Invalid URL: {path}. Error: {str(e)}") raise ValueError(f"Invalid URL: {path}. Error: {str(e)}")
else:
local_path = Path(KNOWLEDGE_DIRECTORY + "/" + path)
if local_path.exists():
processed_paths.append(local_path)
else: else:
raise FileNotFoundError(f"File not found: {local_path}") local_path = Path(KNOWLEDGE_DIRECTORY + "/" + path)
if local_path.exists():
processed_paths.append(local_path)
else:
raise FileNotFoundError(f"File not found: {local_path}")
else:
# this is an instance of Path
processed_paths.append(path)
return processed_paths return processed_paths
def _validate_url(self, url: str) -> bool:
try:
result = urlparse(url)
return all(
[
result.scheme in ("http", "https"),
result.netloc,
len(result.netloc.split(".")) >= 2, # Ensure domain has TLD
]
)
except Exception:
return False