mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 00:58:30 +00:00
fix types
This commit is contained in:
@@ -1,8 +1,8 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field, field_validator
|
||||||
|
|
||||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||||
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
|
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
|
||||||
@@ -14,17 +14,24 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
|||||||
"""Base class for knowledge sources that load content from files."""
|
"""Base class for knowledge sources that load content from files."""
|
||||||
|
|
||||||
_logger: Logger = Logger(verbose=True)
|
_logger: Logger = Logger(verbose=True)
|
||||||
file_path: Union[Path, List[Path], str, List[str]] = Field(
|
file_path: Optional[Union[Path, List[Path], str, List[str]]] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="[Deprecated] The path to the file. Use file_paths instead.",
|
description="[Deprecated] The path to the file. Use file_paths instead.",
|
||||||
)
|
)
|
||||||
file_paths: Union[Path, List[Path], str, List[str]] = Field(
|
file_paths: Optional[Union[Path, List[Path], str, List[str]]] = Field(
|
||||||
..., description="The path to the file"
|
default_factory=list, description="The path to the file"
|
||||||
)
|
)
|
||||||
content: Dict[Path, str] = Field(init=False, default_factory=dict)
|
content: Dict[Path, str] = Field(init=False, default_factory=dict)
|
||||||
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
|
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
|
||||||
safe_file_paths: List[Path] = Field(default_factory=list)
|
safe_file_paths: List[Path] = Field(default_factory=list)
|
||||||
|
|
||||||
|
@field_validator("file_path", "file_paths", mode="before")
|
||||||
|
def validate_file_path(cls, v, values):
|
||||||
|
"""Validate that at least one of file_path or file_paths is provided."""
|
||||||
|
if v is None and ("file_path" not in values or values.get("file_path") is None):
|
||||||
|
raise ValueError("Either file_path or file_paths must be provided")
|
||||||
|
return v
|
||||||
|
|
||||||
def model_post_init(self, _):
|
def model_post_init(self, _):
|
||||||
"""Post-initialization method to load content."""
|
"""Post-initialization method to load content."""
|
||||||
self.safe_file_paths = self._process_file_paths()
|
self.safe_file_paths = self._process_file_paths()
|
||||||
@@ -77,6 +84,8 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
|||||||
else self.file_path
|
else self.file_path
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
if self.file_paths is None:
|
||||||
|
raise ValueError("Your source must be provided with a file_paths: []")
|
||||||
paths = (
|
paths = (
|
||||||
[self.file_paths]
|
[self.file_paths]
|
||||||
if isinstance(self.file_paths, (str, Path))
|
if isinstance(self.file_paths, (str, Path))
|
||||||
|
|||||||
Reference in New Issue
Block a user