refactor: convert BaseKnowledgeStorage to BaseModel

This commit is contained in:
Greyson LaLonde
2026-03-31 07:51:54 +08:00
parent ced1d9da30
commit 5385c8370b
2 changed files with 18 additions and 11 deletions

View File

@@ -3,12 +3,15 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from pydantic import BaseModel, ConfigDict
if TYPE_CHECKING: if TYPE_CHECKING:
from crewai.rag.types import SearchResult from crewai.rag.types import SearchResult
class BaseKnowledgeStorage(ABC): class BaseKnowledgeStorage(BaseModel, ABC):
model_config = ConfigDict(arbitrary_types_allowed=True)
"""Abstract base class for knowledge storage implementations.""" """Abstract base class for knowledge storage implementations."""
@abstractmethod @abstractmethod

View File

@@ -3,6 +3,9 @@ import traceback
from typing import Any, cast from typing import Any, cast
import warnings import warnings
from pydantic import Field, PrivateAttr, model_validator
from typing_extensions import Self
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
from crewai.rag.chromadb.config import ChromaDBConfig from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
@@ -22,31 +25,32 @@ class KnowledgeStorage(BaseKnowledgeStorage):
search efficiency. search efficiency.
""" """
def __init__( collection_name: str | None = None
self, embedder: (
embedder: ProviderSpec ProviderSpec
| BaseEmbeddingsProvider[Any] | BaseEmbeddingsProvider[Any]
| type[BaseEmbeddingsProvider[Any]] | type[BaseEmbeddingsProvider[Any]]
| None = None, | None
collection_name: str | None = None, ) = Field(default=None, exclude=True)
) -> None: _client: BaseClient | None = PrivateAttr(default=None)
self.collection_name = collection_name
self._client: BaseClient | None = None
@model_validator(mode="after")
def _init_client(self) -> Self:
warnings.filterwarnings( warnings.filterwarnings(
"ignore", "ignore",
message=r".*'model_fields'.*is deprecated.*", message=r".*'model_fields'.*is deprecated.*",
module=r"^chromadb(\.|$)", module=r"^chromadb(\.|$)",
) )
if embedder: if self.embedder:
embedding_function = build_embedder(embedder) # type: ignore[arg-type] embedding_function = build_embedder(self.embedder) # type: ignore[arg-type]
config = ChromaDBConfig( config = ChromaDBConfig(
embedding_function=cast( embedding_function=cast(
ChromaEmbeddingFunctionWrapper, embedding_function ChromaEmbeddingFunctionWrapper, embedding_function
) )
) )
self._client = create_client(config) self._client = create_client(config)
return self
def _get_client(self) -> BaseClient: def _get_client(self) -> BaseClient:
"""Get the appropriate client - instance-specific or global.""" """Get the appropriate client - instance-specific or global."""