From 436a458072e2df9f7e44b075d2e8ef19fa3dc0e8 Mon Sep 17 00:00:00 2001 From: Lorenze Jay Date: Tue, 17 Dec 2024 10:18:48 -0800 Subject: [PATCH] fix types --- .../source/base_file_knowledge_source.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/crewai/knowledge/source/base_file_knowledge_source.py b/src/crewai/knowledge/source/base_file_knowledge_source.py index 90f787e0a..b863197da 100644 --- a/src/crewai/knowledge/source/base_file_knowledge_source.py +++ b/src/crewai/knowledge/source/base_file_knowledge_source.py @@ -1,8 +1,8 @@ from abc import ABC, abstractmethod from pathlib import Path -from typing import Dict, List, Union +from typing import Dict, List, Optional, Union -from pydantic import Field +from pydantic import Field, field_validator from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage @@ -14,17 +14,24 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC): """Base class for knowledge sources that load content from files.""" _logger: Logger = Logger(verbose=True) - file_path: Union[Path, List[Path], str, List[str]] = Field( + file_path: Optional[Union[Path, List[Path], str, List[str]]] = Field( default=None, description="[Deprecated] The path to the file. Use file_paths instead.", ) - file_paths: Union[Path, List[Path], str, List[str]] = Field( - ..., description="The path to the file" + file_paths: Optional[Union[Path, List[Path], str, List[str]]] = Field( + default_factory=list, description="The path to the file" ) content: Dict[Path, str] = Field(init=False, default_factory=dict) storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage) safe_file_paths: List[Path] = Field(default_factory=list) + @field_validator("file_path", "file_paths", mode="before") + def validate_file_path(cls, v, values): + """Validate that at least one of file_path or file_paths is provided.""" + if v is None and ("file_path" not in values or values.get("file_path") is None): + raise ValueError("Either file_path or file_paths must be provided") + return v + def model_post_init(self, _): """Post-initialization method to load content.""" self.safe_file_paths = self._process_file_paths() @@ -77,6 +84,8 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC): else self.file_path ) else: + if self.file_paths is None: + raise ValueError("Your source must be provided with a file_paths: []") paths = ( [self.file_paths] if isinstance(self.file_paths, (str, Path))