diff --git a/src/crewai/agent.py b/src/crewai/agent.py index d10b768d4..87a31add2 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -20,6 +20,7 @@ from crewai.tools.agent_tools.agent_tools import AgentTools from crewai.utilities import Converter, Prompts from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE from crewai.utilities.converter import generate_model_description +from crewai.utilities.embedding_configurator import EmbeddingConfig from crewai.utilities.events.agent_events import ( AgentExecutionCompletedEvent, AgentExecutionErrorEvent, @@ -108,7 +109,7 @@ class Agent(BaseAgent): default="safe", description="Mode for code execution: 'safe' (using Docker) or 'unsafe' (direct execution).", ) - embedder: Optional[Dict[str, Any]] = Field( + embedder: Optional[EmbeddingConfig] = Field( default=None, description="Embedder configuration for the agent.", ) @@ -134,7 +135,7 @@ class Agent(BaseAgent): self.cache_handler = CacheHandler() self.set_cache_handler(self.cache_handler) - def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None): + def set_knowledge(self, crew_embedder: Optional[EmbeddingConfig] = None): try: if self.embedder is None and crew_embedder: self.embedder = crew_embedder diff --git a/src/crewai/agents/agent_builder/base_agent.py b/src/crewai/agents/agent_builder/base_agent.py index 47515d087..68df90777 100644 --- a/src/crewai/agents/agent_builder/base_agent.py +++ b/src/crewai/agents/agent_builder/base_agent.py @@ -25,6 +25,7 @@ from crewai.tools.base_tool import BaseTool, Tool from crewai.utilities import I18N, Logger, RPMController from crewai.utilities.config import process_config from crewai.utilities.converter import Converter +from crewai.utilities.embedding_configurator import EmbeddingConfig T = TypeVar("T", bound="BaseAgent") @@ -362,5 +363,5 @@ class BaseAgent(ABC, BaseModel): self._rpm_controller = rpm_controller self.create_agent_executor() - def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None): + def set_knowledge(self, crew_embedder: Optional[EmbeddingConfig] = None): pass