mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-30 23:02:50 +00:00
refactor: convert BaseKnowledgeStorage to BaseModel
This commit is contained in:
@@ -3,12 +3,15 @@ from __future__ import annotations
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from crewai.rag.types import SearchResult
|
from crewai.rag.types import SearchResult
|
||||||
|
|
||||||
|
|
||||||
class BaseKnowledgeStorage(ABC):
|
class BaseKnowledgeStorage(BaseModel, ABC):
|
||||||
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
"""Abstract base class for knowledge storage implementations."""
|
"""Abstract base class for knowledge storage implementations."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|||||||
@@ -3,6 +3,9 @@ import traceback
|
|||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
from pydantic import Field, PrivateAttr, model_validator
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
||||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||||
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
|
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
|
||||||
@@ -22,31 +25,32 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
|||||||
search efficiency.
|
search efficiency.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
collection_name: str | None = None
|
||||||
self,
|
embedder: (
|
||||||
embedder: ProviderSpec
|
ProviderSpec
|
||||||
| BaseEmbeddingsProvider[Any]
|
| BaseEmbeddingsProvider[Any]
|
||||||
| type[BaseEmbeddingsProvider[Any]]
|
| type[BaseEmbeddingsProvider[Any]]
|
||||||
| None = None,
|
| None
|
||||||
collection_name: str | None = None,
|
) = Field(default=None, exclude=True)
|
||||||
) -> None:
|
_client: BaseClient | None = PrivateAttr(default=None)
|
||||||
self.collection_name = collection_name
|
|
||||||
self._client: BaseClient | None = None
|
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def _init_client(self) -> Self:
|
||||||
warnings.filterwarnings(
|
warnings.filterwarnings(
|
||||||
"ignore",
|
"ignore",
|
||||||
message=r".*'model_fields'.*is deprecated.*",
|
message=r".*'model_fields'.*is deprecated.*",
|
||||||
module=r"^chromadb(\.|$)",
|
module=r"^chromadb(\.|$)",
|
||||||
)
|
)
|
||||||
|
|
||||||
if embedder:
|
if self.embedder:
|
||||||
embedding_function = build_embedder(embedder) # type: ignore[arg-type]
|
embedding_function = build_embedder(self.embedder) # type: ignore[arg-type]
|
||||||
config = ChromaDBConfig(
|
config = ChromaDBConfig(
|
||||||
embedding_function=cast(
|
embedding_function=cast(
|
||||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self._client = create_client(config)
|
self._client = create_client(config)
|
||||||
|
return self
|
||||||
|
|
||||||
def _get_client(self) -> BaseClient:
|
def _get_client(self) -> BaseClient:
|
||||||
"""Get the appropriate client - instance-specific or global."""
|
"""Get the appropriate client - instance-specific or global."""
|
||||||
|
|||||||
Reference in New Issue
Block a user