diff --git a/src/crewai/agent.py b/src/crewai/agent.py index 1b450ea6f..e90d0c601 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -3,9 +3,10 @@ import shutil import subprocess from typing import Any, Dict, List, Literal, Optional, Sequence, Union -from pydantic import Field, InstanceOf, PrivateAttr, model_validator +from pydantic import Field, InstanceOf, PrivateAttr, field_validator, model_validator from crewai.agents import CacheHandler +from crewai.utilities import EmbeddingConfigurator from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.crew_agent_executor import CrewAgentExecutor from crewai.knowledge.knowledge import Knowledge @@ -120,12 +121,22 @@ class Agent(BaseAgent): description="Embedder configuration for the agent. Must include 'provider' and relevant configuration parameters.", ) - @validator("embedder_config") - def validate_embedder_config(cls, v): + @field_validator("embedder_config") + @classmethod + def validate_embedder_config(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: """Validate embedder configuration. Args: v: The embedder configuration to validate. + Must include 'provider' and 'config' keys. + Example: + { + 'provider': 'openai', + 'config': { + 'api_key': 'your-key', + 'model': 'text-embedding-3-small' + } + } Returns: The validated embedder configuration. @@ -134,10 +145,17 @@ class Agent(BaseAgent): ValueError: If the embedder configuration is invalid. """ if v is not None: - if not isinstance(v, dict) or "provider" not in v: + if not isinstance(v, dict): + raise ValueError("embedder_config must be a dictionary") + if "provider" not in v: raise ValueError("embedder_config must contain 'provider' key") if "config" not in v: raise ValueError("embedder_config must contain 'config' key") + if v["provider"] not in EmbeddingConfigurator().embedding_functions: + raise ValueError( + f"Unsupported embedding provider: {v['provider']}, " + f"supported providers: {list(EmbeddingConfigurator().embedding_functions.keys())}" + ) return v @model_validator(mode="after")