From 8f99caf61be4c7fd938e473694ec0fa3805bd7dd Mon Sep 17 00:00:00 2001 From: Nick Fujita Date: Thu, 20 Feb 2025 17:58:46 +0900 Subject: [PATCH] 'type cleanup' --- src/crewai/crew.py | 3 ++- src/crewai/knowledge/knowledge.py | 5 +++-- src/crewai/knowledge/storage/knowledge_storage.py | 5 +++-- src/crewai/utilities/embedding_configurator.py | 4 ++-- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/crewai/crew.py b/src/crewai/crew.py index 31678ae88..64218f1a8 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -42,6 +42,7 @@ from crewai.traces.unified_trace_controller import init_crew_main_trace from crewai.types.usage_metrics import UsageMetrics from crewai.utilities import I18N, FileHandler, Logger, RPMController from crewai.utilities.constants import TRAINING_DATA_FILE +from crewai.utilities.embedding_configurator import EmbeddingConfig from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator from crewai.utilities.evaluators.task_evaluator import TaskEvaluator from crewai.utilities.events.crew_events import ( @@ -144,7 +145,7 @@ class Crew(BaseModel): default=None, description="An instance of the UserMemory to be used by the Crew to store/fetch memories of a specific user.", ) - embedder: Optional[dict] = Field( + embedder: Optional[EmbeddingConfig] = Field( default=None, description="Configuration for the embedder to be used for the crew.", ) diff --git a/src/crewai/knowledge/knowledge.py b/src/crewai/knowledge/knowledge.py index da1db90a8..f4057fe07 100644 --- a/src/crewai/knowledge/knowledge.py +++ b/src/crewai/knowledge/knowledge.py @@ -5,6 +5,7 @@ from pydantic import BaseModel, ConfigDict, Field from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage +from crewai.utilities.embedding_configurator import EmbeddingConfig os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed @@ -21,14 +22,14 @@ class Knowledge(BaseModel): sources: List[BaseKnowledgeSource] = Field(default_factory=list) model_config = ConfigDict(arbitrary_types_allowed=True) storage: Optional[KnowledgeStorage] = Field(default=None) - embedder: Optional[Dict[str, Any]] = None + embedder: Optional[EmbeddingConfig] = None collection_name: Optional[str] = None def __init__( self, collection_name: str, sources: List[BaseKnowledgeSource], - embedder: Optional[Dict[str, Any]] = None, + embedder: Optional[EmbeddingConfig] = None, storage: Optional[KnowledgeStorage] = None, **data, ): diff --git a/src/crewai/knowledge/storage/knowledge_storage.py b/src/crewai/knowledge/storage/knowledge_storage.py index 72240e2b6..c902e8b6e 100644 --- a/src/crewai/knowledge/storage/knowledge_storage.py +++ b/src/crewai/knowledge/storage/knowledge_storage.py @@ -15,6 +15,7 @@ from chromadb.config import Settings from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage from crewai.utilities import EmbeddingConfigurator from crewai.utilities.constants import KNOWLEDGE_DIRECTORY +from crewai.utilities.embedding_configurator import EmbeddingConfig from crewai.utilities.logger import Logger from crewai.utilities.paths import db_storage_path @@ -48,7 +49,7 @@ class KnowledgeStorage(BaseKnowledgeStorage): def __init__( self, - embedder: Optional[Dict[str, Any]] = None, + embedder: Optional[EmbeddingConfig] = None, collection_name: Optional[str] = None, ): self.collection_name = collection_name @@ -187,7 +188,7 @@ class KnowledgeStorage(BaseKnowledgeStorage): api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small" ) - def _set_embedder_config(self, embedder: Optional[Dict[str, Any]] = None) -> None: + def _set_embedder_config(self, embedder: Optional[EmbeddingConfig] = None) -> None: """Set the embedding configuration for the knowledge storage. Args: diff --git a/src/crewai/utilities/embedding_configurator.py b/src/crewai/utilities/embedding_configurator.py index 2b71ba09f..78704e954 100644 --- a/src/crewai/utilities/embedding_configurator.py +++ b/src/crewai/utilities/embedding_configurator.py @@ -1,5 +1,5 @@ import os -from typing import Any, Dict, Optional, cast +from typing import Any, Callable, cast from chromadb import Documents, EmbeddingFunction, Embeddings from chromadb.api.types import validate_embedding_function @@ -14,7 +14,7 @@ class EmbeddingProviderConfig(BaseModel): task_type: str | None = None session: str | None = None api_url: str | None = None - embedder: str | callable | None = None + embedder: str | Callable | None = None api_key: str | None = None api_base: str | None = None api_type: str | None = None