fix types

This commit is contained in:
Lorenze Jay
2024-12-17 10:18:48 -08:00
parent 7885c5f906
commit 436a458072

View File

@@ -1,8 +1,8 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List, Union
from typing import Dict, List, Optional, Union
from pydantic import Field
from pydantic import Field, field_validator
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
@@ -14,17 +14,24 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
"""Base class for knowledge sources that load content from files."""
_logger: Logger = Logger(verbose=True)
file_path: Union[Path, List[Path], str, List[str]] = Field(
file_path: Optional[Union[Path, List[Path], str, List[str]]] = Field(
default=None,
description="[Deprecated] The path to the file. Use file_paths instead.",
)
file_paths: Union[Path, List[Path], str, List[str]] = Field(
..., description="The path to the file"
file_paths: Optional[Union[Path, List[Path], str, List[str]]] = Field(
default_factory=list, description="The path to the file"
)
content: Dict[Path, str] = Field(init=False, default_factory=dict)
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
safe_file_paths: List[Path] = Field(default_factory=list)
@field_validator("file_path", "file_paths", mode="before")
def validate_file_path(cls, v, values):
"""Validate that at least one of file_path or file_paths is provided."""
if v is None and ("file_path" not in values or values.get("file_path") is None):
raise ValueError("Either file_path or file_paths must be provided")
return v
def model_post_init(self, _):
"""Post-initialization method to load content."""
self.safe_file_paths = self._process_file_paths()
@@ -77,6 +84,8 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
else self.file_path
)
else:
if self.file_paths is None:
raise ValueError("Your source must be provided with a file_paths: []")
paths = (
[self.file_paths]
if isinstance(self.file_paths, (str, Path))