diff --git a/lib/crewai/src/crewai/knowledge/storage/base_knowledge_storage.py b/lib/crewai/src/crewai/knowledge/storage/base_knowledge_storage.py index e8a2054f7..ea8aff734 100644 --- a/lib/crewai/src/crewai/knowledge/storage/base_knowledge_storage.py +++ b/lib/crewai/src/crewai/knowledge/storage/base_knowledge_storage.py @@ -3,12 +3,15 @@ from __future__ import annotations from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any +from pydantic import BaseModel, ConfigDict + if TYPE_CHECKING: from crewai.rag.types import SearchResult -class BaseKnowledgeStorage(ABC): +class BaseKnowledgeStorage(BaseModel, ABC): + model_config = ConfigDict(arbitrary_types_allowed=True) """Abstract base class for knowledge storage implementations.""" @abstractmethod diff --git a/lib/crewai/src/crewai/knowledge/storage/knowledge_storage.py b/lib/crewai/src/crewai/knowledge/storage/knowledge_storage.py index cfcbca25a..3c9615946 100644 --- a/lib/crewai/src/crewai/knowledge/storage/knowledge_storage.py +++ b/lib/crewai/src/crewai/knowledge/storage/knowledge_storage.py @@ -3,6 +3,9 @@ import traceback from typing import Any, cast import warnings +from pydantic import Field, PrivateAttr, model_validator +from typing_extensions import Self + from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage from crewai.rag.chromadb.config import ChromaDBConfig from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper @@ -22,31 +25,32 @@ class KnowledgeStorage(BaseKnowledgeStorage): search efficiency. """ - def __init__( - self, - embedder: ProviderSpec + collection_name: str | None = None + embedder: ( + ProviderSpec | BaseEmbeddingsProvider[Any] | type[BaseEmbeddingsProvider[Any]] - | None = None, - collection_name: str | None = None, - ) -> None: - self.collection_name = collection_name - self._client: BaseClient | None = None + | None + ) = Field(default=None, exclude=True) + _client: BaseClient | None = PrivateAttr(default=None) + @model_validator(mode="after") + def _init_client(self) -> Self: warnings.filterwarnings( "ignore", message=r".*'model_fields'.*is deprecated.*", module=r"^chromadb(\.|$)", ) - if embedder: - embedding_function = build_embedder(embedder) # type: ignore[arg-type] + if self.embedder: + embedding_function = build_embedder(self.embedder) # type: ignore[arg-type] config = ChromaDBConfig( embedding_function=cast( ChromaEmbeddingFunctionWrapper, embedding_function ) ) self._client = create_client(config) + return self def _get_client(self) -> BaseClient: """Get the appropriate client - instance-specific or global."""