diff --git a/lib/crewai/src/crewai/crew.py b/lib/crewai/src/crewai/crew.py
index a8e88ce55..61b87d487 100644
--- a/lib/crewai/src/crewai/crew.py
+++ b/lib/crewai/src/crewai/crew.py
@@ -305,6 +305,122 @@ class Crew(FlowTrackable, BaseModel):
         # TODO: Improve typing
         return json.loads(v) if isinstance(v, Json) else v  # type: ignore
 
+    @field_validator("embedder", mode="before")
+    @classmethod
+    def validate_embedder_config(cls, v: Any) -> Any:
+        """Validate the embedder configuration and report clear errors.
+
+        Only dict configurations that name a truthy ``provider`` are
+        inspected; ``None``, non-dict values and dicts without a provider
+        are returned unchanged.
+
+        Args:
+            v: The embedder configuration to be validated.
+
+        Returns:
+            The embedder configuration, unchanged, when no problem is found.
+
+        Raises:
+            PydanticCustomError: If the provider name is not in the list of
+                valid providers, or if a provider that needs a nested
+                ``config`` mapping is given without one.
+        """
+        if not isinstance(v, dict):
+            # None is not a dict either, so it is passed through here as well.
+            return v
+
+        provider = v.get("provider")
+        if not provider:
+            return v
+
+        # NOTE(review): hand-maintained list; presumably it has to stay in sync
+        # with the provider specs accepted by the embedder field - confirm.
+        valid_providers = (
+            "azure",
+            "amazon-bedrock",
+            "cohere",
+            "custom",
+            "google-generativeai",
+            "google-vertex",
+            "huggingface",
+            "instructor",
+            "jina",
+            "ollama",
+            "onnx",
+            "openai",
+            "openclip",
+            "roboflow",
+            "sentence-transformer",
+            "text2vec",
+            "voyageai",
+            "watsonx",
+        )
+
+        if provider not in valid_providers:
+            raise PydanticCustomError(
+                "invalid_embedder_provider",
+                (
+                    f"Invalid embedder provider: '{provider}'. "
+                    f"Valid providers are: {', '.join(valid_providers)}. "
+                    "Please check the documentation for the correct provider name."
+                ),
+                {},
+            )
+
+        providers_requiring_config = (
+            "google-generativeai",
+            "google-vertex",
+            "openclip",
+            "roboflow",
+            "sentence-transformer",
+            "text2vec",
+            "voyageai",
+        )
+
+        if provider not in providers_requiring_config or "config" in v:
+            return v
+
+        # Provider-specific example settings; any other provider that needs a
+        # nested config falls back to the generic example below.
+        example_settings = {
+            "google-generativeai": {
+                "api_key": "your_api_key",
+                "model_name": "models/embedding-001",
+            },
+            "google-vertex": {
+                "api_key": "your_api_key",
+                "model_name": "textembedding-gecko",
+                "project_id": "your_project_id",
+                "region": "us-central1",
+            },
+        }
+        generic_settings = {
+            "api_key": "your_api_key",
+            "model_name": "your_model_name",
+        }
+        example_config = {
+            "provider": provider,
+            "config": example_settings.get(provider, generic_settings),
+        }
+
+        try:
+            received = json.dumps(v, indent=2, default=repr)
+        except (TypeError, ValueError):
+            # Non-string keys or circular references cannot be dumped even
+            # with a default hook; show the plain repr instead of crashing.
+            received = repr(v)
+
+        raise PydanticCustomError(
+            "invalid_embedder_config_structure",
+            (
+                f"Invalid embedder configuration for provider '{provider}'. "
+                "The configuration is missing the required 'config' field. "
+                f"Expected structure:\n{json.dumps(example_config, indent=2)}\n"
+                f"But received:\n{received}"
+            ),
+            {},
+        )
+
     @model_validator(mode="after")
     def set_private_attrs(self) -> Crew:
         """set private attributes."""
diff --git a/lib/crewai/tests/test_embedder_validation.py b/lib/crewai/tests/test_embedder_validation.py
new file mode 100644
index 000000000..92d594b06
--- /dev/null
+++ b/lib/crewai/tests/test_embedder_validation.py
@@ -0,0 +1,275 @@
+"""Tests for embedder configuration validation (Issue #3755)."""
+
+import pytest
+from pydantic_core import ValidationError
+
+from crewai import Agent, Crew, Task
+
+
+@pytest.fixture
+def simple_agent():
+    """Create a simple agent for testing."""
+    return Agent(
+        role="Test Agent",
+        goal="Test goal",
+        backstory="Test backstory"
+    )
+
+
+@pytest.fixture
+def simple_task(simple_agent):
+    """Create a simple task for testing."""
+    return Task(
+        description="Test task",
+        expected_output="Test output",
+        agent=simple_agent
+    )
+
+
+def test_invalid_embedder_provider_name(simple_agent, simple_task):
+    """Test that an invalid provider name gives a clear error message."""
+    invalid_config = {
+        "provider": "invalid-provider",
+        "config": {
+            "api_key": "test_key",
+            "model_name": "test_model"
+        }
+    }
+
+    with pytest.raises(ValidationError) as exc_info:
+        Crew(
+            agents=[simple_agent],
+            tasks=[simple_task],
+            embedder=invalid_config
+        )
+
+    error_message = str(exc_info.value)
+    assert "Invalid embedder provider: 'invalid-provider'" in error_message
+    assert "Valid providers are:" in error_message
+    assert "google-generativeai" in error_message
+    assert "openai" in error_message
+
+
+def test_google_generativeai_missing_config_field(simple_agent, simple_task):
+    """Test that missing config field for google-generativeai gives a clear error."""
+    invalid_config = {
+        "provider": "google-generativeai",
+        "model_name": "models/text-embedding-004",
+        "api_key": "test_key"
+    }
+
+    with pytest.raises(ValidationError) as exc_info:
+        Crew(
+            agents=[simple_agent],
+            tasks=[simple_task],
+            embedder=invalid_config
+        )
+
+    error_message = str(exc_info.value)
+    assert "Invalid embedder configuration for provider 'google-generativeai'" in error_message
+    assert "missing the required 'config' field" in error_message
+    assert "Expected structure:" in error_message
+    assert '"provider": "google-generativeai"' in error_message
+    assert '"config"' in error_message
+
+
+def test_google_vertex_missing_config_field(simple_agent, simple_task):
+    """Test that missing config field for google-vertex gives a clear error."""
+    invalid_config = {
+        "provider": "google-vertex",
+        "model_name": "textembedding-gecko",
+        "api_key": "test_key"
+    }
+
+    with pytest.raises(ValidationError) as exc_info:
+        Crew(
+            agents=[simple_agent],
+            tasks=[simple_task],
+            embedder=invalid_config
+        )
+
+    error_message = str(exc_info.value)
+    assert "Invalid embedder configuration for provider 'google-vertex'" in error_message
+    assert "missing the required 'config' field" in error_message
+    assert "Expected structure:" in error_message
+
+
+def test_valid_google_generativeai_config(simple_agent, simple_task):
+    """Test that a valid google-generativeai config is accepted."""
+    valid_config = {
+        "provider": "google-generativeai",
+        "config": {
+            "api_key": "test_key",
+            "model_name": "models/embedding-001"
+        }
+    }
+
+    crew = Crew(
+        agents=[simple_agent],
+        tasks=[simple_task],
+        embedder=valid_config
+    )
+
+    assert crew.embedder == valid_config
+
+
+def test_valid_openai_config(simple_agent, simple_task):
+    """Test that a valid openai config is accepted."""
+    valid_config = {
+        "provider": "openai",
+        "api_key": "test_key",
+        "model": "text-embedding-3-small"
+    }
+
+    crew = Crew(
+        agents=[simple_agent],
+        tasks=[simple_task],
+        embedder=valid_config
+    )
+
+    assert crew.embedder is not None
+    assert crew.embedder["provider"] == "openai"
+
+
+def test_valid_ollama_config(simple_agent, simple_task):
+    """Test that a valid ollama config is accepted."""
+    valid_config = {
+        "provider": "ollama",
+        "model": "nomic-embed-text"
+    }
+
+    crew = Crew(
+        agents=[simple_agent],
+        tasks=[simple_task],
+        embedder=valid_config
+    )
+
+    assert crew.embedder is not None
+    assert crew.embedder["provider"] == "ollama"
+
+
+def test_none_embedder_config(simple_agent, simple_task):
+    """Test that None embedder config is accepted."""
+    crew = Crew(
+        agents=[simple_agent],
+        tasks=[simple_task],
+        embedder=None
+    )
+
+    assert crew.embedder is None
+
+
+def test_embedder_config_without_provider_field(simple_agent, simple_task):
+    """Test that config without provider field is handled by Pydantic."""
+    invalid_config = {
+        "api_key": "test_key",
+        "model_name": "test_model"
+    }
+
+    with pytest.raises(ValidationError):
+        Crew(
+            agents=[simple_agent],
+            tasks=[simple_task],
+            embedder=invalid_config
+        )
+
+
+def test_sentence_transformer_missing_config_field(simple_agent, simple_task):
+    """Test that missing config field for sentence-transformer gives a clear error."""
+    invalid_config = {
+        "provider": "sentence-transformer",
+        "model_name": "all-MiniLM-L6-v2"
+    }
+
+    with pytest.raises(ValidationError) as exc_info:
+        Crew(
+            agents=[simple_agent],
+            tasks=[simple_task],
+            embedder=invalid_config
+        )
+
+    error_message = str(exc_info.value)
+    assert "Invalid embedder configuration for provider 'sentence-transformer'" in error_message
+    assert "missing the required 'config' field" in error_message
+
+
+def test_voyageai_missing_config_field(simple_agent, simple_task):
+    """Test that missing config field for voyageai gives a clear error."""
+    invalid_config = {
+        "provider": "voyageai",
+        "api_key": "test_key"
+    }
+
+    with pytest.raises(ValidationError) as exc_info:
+        Crew(
+            agents=[simple_agent],
+            tasks=[simple_task],
+            embedder=invalid_config
+        )
+
+    error_message = str(exc_info.value)
+    assert "Invalid embedder configuration for provider 'voyageai'" in error_message
+    assert "missing the required 'config' field" in error_message
+
+
+def test_error_shows_received_config(simple_agent, simple_task):
+    """Test that error message shows the received configuration."""
+    invalid_config = {
+        "provider": "google-generativeai",
+        "model_name": "models/text-embedding-004",
+        "api_key": "test_key"
+    }
+
+    with pytest.raises(ValidationError) as exc_info:
+        Crew(
+            agents=[simple_agent],
+            tasks=[simple_task],
+            embedder=invalid_config
+        )
+
+    error_message = str(exc_info.value)
+    assert "But received:" in error_message
+    assert '"model_name": "models/text-embedding-004"' in error_message
+
+
+def test_single_validation_error_not_multiple(simple_agent, simple_task):
+    """Test that we get a single clear error, not multiple confusing errors."""
+    invalid_config = {
+        "provider": "google-generativeai",
+        "model_name": "models/text-embedding-004",
+        "api_key": "test_key"
+    }
+
+    with pytest.raises(ValidationError) as exc_info:
+        Crew(
+            agents=[simple_agent],
+            tasks=[simple_task],
+            embedder=invalid_config
+        )
+
+    error_message = str(exc_info.value)
+    error_count = error_message.count("validation error")
+
+    assert error_count == 1, f"Expected 1 validation error, got {error_count}"
+
+    assert "AzureProviderSpec" not in error_message
+    assert "BedrockProviderSpec" not in error_message
+    assert "CohereProviderSpec" not in error_message
+
+
+def test_missing_config_with_unserializable_value(simple_agent, simple_task):
+    """Test that values JSON cannot encode do not hide the clear error."""
+    invalid_config = {
+        "provider": "voyageai",
+        "api_key": object()
+    }
+
+    with pytest.raises(ValidationError) as exc_info:
+        Crew(
+            agents=[simple_agent],
+            tasks=[simple_task],
+            embedder=invalid_config
+        )
+
+    error_message = str(exc_info.value)
+    assert "missing the required 'config' field" in error_message