mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-16 20:38:29 +00:00
Compare commits
2 Commits
main
...
devin/1761
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e8c85e3bbb | ||
|
|
1d575d96e3 |
@@ -305,6 +305,121 @@ class Crew(FlowTrackable, BaseModel):
|
||||
# TODO: Improve typing
|
||||
return json.loads(v) if isinstance(v, Json) else v # type: ignore
|
||||
|
||||
@field_validator("embedder", mode="before")
|
||||
@classmethod
|
||||
def validate_embedder_config(cls, v: Any) -> Any:
|
||||
"""Validates embedder configuration and provides clear error messages.
|
||||
|
||||
Args:
|
||||
v: The embedder configuration to be validated.
|
||||
|
||||
Returns:
|
||||
The embedder config if it is valid.
|
||||
|
||||
Raises:
|
||||
PydanticCustomError: If the embedder configuration is invalid,
|
||||
with a clear, helpful error message.
|
||||
"""
|
||||
if v is None:
|
||||
return v
|
||||
|
||||
if not isinstance(v, dict):
|
||||
return v
|
||||
|
||||
provider = v.get("provider")
|
||||
if not provider:
|
||||
return v
|
||||
|
||||
valid_providers = [
|
||||
"azure",
|
||||
"amazon-bedrock",
|
||||
"cohere",
|
||||
"custom",
|
||||
"google-generativeai",
|
||||
"google-vertex",
|
||||
"huggingface",
|
||||
"instructor",
|
||||
"jina",
|
||||
"ollama",
|
||||
"onnx",
|
||||
"openai",
|
||||
"openclip",
|
||||
"roboflow",
|
||||
"sentence-transformer",
|
||||
"text2vec",
|
||||
"voyageai",
|
||||
"watsonx",
|
||||
]
|
||||
|
||||
if provider not in valid_providers:
|
||||
raise PydanticCustomError(
|
||||
"invalid_embedder_provider",
|
||||
(
|
||||
f"Invalid embedder provider: '{provider}'. "
|
||||
f"Valid providers are: {', '.join(valid_providers)}. "
|
||||
"Please check the documentation for the correct provider name."
|
||||
),
|
||||
{},
|
||||
)
|
||||
|
||||
providers_requiring_config = [
|
||||
"google-generativeai",
|
||||
"google-vertex",
|
||||
"openclip",
|
||||
"roboflow",
|
||||
"sentence-transformer",
|
||||
"text2vec",
|
||||
"voyageai",
|
||||
]
|
||||
|
||||
if provider in providers_requiring_config and "config" not in v:
|
||||
example_config = {}
|
||||
if provider == "google-generativeai":
|
||||
example_config = {
|
||||
"provider": "google-generativeai",
|
||||
"config": {
|
||||
"api_key": "your_api_key",
|
||||
"model_name": "models/embedding-001",
|
||||
}
|
||||
}
|
||||
elif provider == "google-vertex":
|
||||
example_config = {
|
||||
"provider": "google-vertex",
|
||||
"config": {
|
||||
"api_key": "your_api_key",
|
||||
"model_name": "textembedding-gecko",
|
||||
"project_id": "your_project_id",
|
||||
"region": "us-central1",
|
||||
}
|
||||
}
|
||||
elif provider == "openai":
|
||||
example_config = {
|
||||
"provider": "openai",
|
||||
"api_key": "your_api_key",
|
||||
"model": "text-embedding-3-small",
|
||||
}
|
||||
else:
|
||||
example_config = {
|
||||
"provider": provider,
|
||||
"config": {
|
||||
"api_key": "your_api_key",
|
||||
"model_name": "your_model_name",
|
||||
}
|
||||
}
|
||||
|
||||
raise PydanticCustomError(
|
||||
"invalid_embedder_config_structure",
|
||||
(
|
||||
f"Invalid embedder configuration for provider '{provider}'. "
|
||||
f"The configuration is missing the required 'config' field. "
|
||||
f"Expected structure:\n{json.dumps(example_config, indent=2)}\n"
|
||||
f"But received:\n{json.dumps(v, indent=2)}"
|
||||
),
|
||||
{},
|
||||
)
|
||||
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_private_attrs(self) -> Crew:
|
||||
"""set private attributes."""
|
||||
|
||||
257
lib/crewai/tests/test_embedder_validation.py
Normal file
257
lib/crewai/tests/test_embedder_validation.py
Normal file
@@ -0,0 +1,257 @@
|
||||
"""Tests for embedder configuration validation (Issue #3755)."""
|
||||
|
||||
import pytest
|
||||
from pydantic_core import ValidationError
|
||||
|
||||
from crewai import Agent, Crew, Task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def simple_agent():
|
||||
"""Create a simple agent for testing."""
|
||||
return Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def simple_task(simple_agent):
|
||||
"""Create a simple task for testing."""
|
||||
return Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
agent=simple_agent
|
||||
)
|
||||
|
||||
|
||||
def test_invalid_embedder_provider_name(simple_agent, simple_task):
|
||||
"""Test that an invalid provider name gives a clear error message."""
|
||||
invalid_config = {
|
||||
"provider": "invalid-provider",
|
||||
"config": {
|
||||
"api_key": "test_key",
|
||||
"model_name": "test_model"
|
||||
}
|
||||
}
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
Crew(
|
||||
agents=[simple_agent],
|
||||
tasks=[simple_task],
|
||||
embedder=invalid_config
|
||||
)
|
||||
|
||||
error_message = str(exc_info.value)
|
||||
assert "Invalid embedder provider: 'invalid-provider'" in error_message
|
||||
assert "Valid providers are:" in error_message
|
||||
assert "google-generativeai" in error_message
|
||||
assert "openai" in error_message
|
||||
|
||||
|
||||
def test_google_generativeai_missing_config_field(simple_agent, simple_task):
|
||||
"""Test that missing config field for google-generativeai gives a clear error."""
|
||||
invalid_config = {
|
||||
"provider": "google-generativeai",
|
||||
"model_name": "models/text-embedding-004",
|
||||
"api_key": "test_key"
|
||||
}
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
Crew(
|
||||
agents=[simple_agent],
|
||||
tasks=[simple_task],
|
||||
embedder=invalid_config
|
||||
)
|
||||
|
||||
error_message = str(exc_info.value)
|
||||
assert "Invalid embedder configuration for provider 'google-generativeai'" in error_message
|
||||
assert "missing the required 'config' field" in error_message
|
||||
assert "Expected structure:" in error_message
|
||||
assert '"provider": "google-generativeai"' in error_message
|
||||
assert '"config"' in error_message
|
||||
|
||||
|
||||
def test_google_vertex_missing_config_field(simple_agent, simple_task):
|
||||
"""Test that missing config field for google-vertex gives a clear error."""
|
||||
invalid_config = {
|
||||
"provider": "google-vertex",
|
||||
"model_name": "textembedding-gecko",
|
||||
"api_key": "test_key"
|
||||
}
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
Crew(
|
||||
agents=[simple_agent],
|
||||
tasks=[simple_task],
|
||||
embedder=invalid_config
|
||||
)
|
||||
|
||||
error_message = str(exc_info.value)
|
||||
assert "Invalid embedder configuration for provider 'google-vertex'" in error_message
|
||||
assert "missing the required 'config' field" in error_message
|
||||
assert "Expected structure:" in error_message
|
||||
|
||||
|
||||
def test_valid_google_generativeai_config(simple_agent, simple_task):
|
||||
"""Test that a valid google-generativeai config is accepted."""
|
||||
valid_config = {
|
||||
"provider": "google-generativeai",
|
||||
"config": {
|
||||
"api_key": "test_key",
|
||||
"model_name": "models/embedding-001"
|
||||
}
|
||||
}
|
||||
|
||||
crew = Crew(
|
||||
agents=[simple_agent],
|
||||
tasks=[simple_task],
|
||||
embedder=valid_config
|
||||
)
|
||||
|
||||
assert crew.embedder == valid_config
|
||||
|
||||
|
||||
def test_valid_openai_config(simple_agent, simple_task):
|
||||
"""Test that a valid openai config is accepted."""
|
||||
valid_config = {
|
||||
"provider": "openai",
|
||||
"api_key": "test_key",
|
||||
"model": "text-embedding-3-small"
|
||||
}
|
||||
|
||||
crew = Crew(
|
||||
agents=[simple_agent],
|
||||
tasks=[simple_task],
|
||||
embedder=valid_config
|
||||
)
|
||||
|
||||
assert crew.embedder is not None
|
||||
assert crew.embedder["provider"] == "openai"
|
||||
|
||||
|
||||
def test_valid_ollama_config(simple_agent, simple_task):
|
||||
"""Test that a valid ollama config is accepted."""
|
||||
valid_config = {
|
||||
"provider": "ollama",
|
||||
"model": "nomic-embed-text"
|
||||
}
|
||||
|
||||
crew = Crew(
|
||||
agents=[simple_agent],
|
||||
tasks=[simple_task],
|
||||
embedder=valid_config
|
||||
)
|
||||
|
||||
assert crew.embedder is not None
|
||||
assert crew.embedder["provider"] == "ollama"
|
||||
|
||||
|
||||
def test_none_embedder_config(simple_agent, simple_task):
|
||||
"""Test that None embedder config is accepted."""
|
||||
crew = Crew(
|
||||
agents=[simple_agent],
|
||||
tasks=[simple_task],
|
||||
embedder=None
|
||||
)
|
||||
|
||||
assert crew.embedder is None
|
||||
|
||||
|
||||
def test_embedder_config_without_provider_field(simple_agent, simple_task):
|
||||
"""Test that config without provider field is handled by Pydantic."""
|
||||
invalid_config = {
|
||||
"api_key": "test_key",
|
||||
"model_name": "test_model"
|
||||
}
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
Crew(
|
||||
agents=[simple_agent],
|
||||
tasks=[simple_task],
|
||||
embedder=invalid_config
|
||||
)
|
||||
|
||||
|
||||
def test_sentence_transformer_missing_config_field(simple_agent, simple_task):
|
||||
"""Test that missing config field for sentence-transformer gives a clear error."""
|
||||
invalid_config = {
|
||||
"provider": "sentence-transformer",
|
||||
"model_name": "all-MiniLM-L6-v2"
|
||||
}
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
Crew(
|
||||
agents=[simple_agent],
|
||||
tasks=[simple_task],
|
||||
embedder=invalid_config
|
||||
)
|
||||
|
||||
error_message = str(exc_info.value)
|
||||
assert "Invalid embedder configuration for provider 'sentence-transformer'" in error_message
|
||||
assert "missing the required 'config' field" in error_message
|
||||
|
||||
|
||||
def test_voyageai_missing_config_field(simple_agent, simple_task):
|
||||
"""Test that missing config field for voyageai gives a clear error."""
|
||||
invalid_config = {
|
||||
"provider": "voyageai",
|
||||
"api_key": "test_key"
|
||||
}
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
Crew(
|
||||
agents=[simple_agent],
|
||||
tasks=[simple_task],
|
||||
embedder=invalid_config
|
||||
)
|
||||
|
||||
error_message = str(exc_info.value)
|
||||
assert "Invalid embedder configuration for provider 'voyageai'" in error_message
|
||||
assert "missing the required 'config' field" in error_message
|
||||
|
||||
|
||||
def test_error_shows_received_config(simple_agent, simple_task):
|
||||
"""Test that error message shows the received configuration."""
|
||||
invalid_config = {
|
||||
"provider": "google-generativeai",
|
||||
"model_name": "models/text-embedding-004",
|
||||
"api_key": "test_key"
|
||||
}
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
Crew(
|
||||
agents=[simple_agent],
|
||||
tasks=[simple_task],
|
||||
embedder=invalid_config
|
||||
)
|
||||
|
||||
error_message = str(exc_info.value)
|
||||
assert "But received:" in error_message
|
||||
assert '"model_name": "models/text-embedding-004"' in error_message
|
||||
|
||||
|
||||
def test_single_validation_error_not_multiple(simple_agent, simple_task):
|
||||
"""Test that we get a single clear error, not multiple confusing errors."""
|
||||
invalid_config = {
|
||||
"provider": "google-generativeai",
|
||||
"model_name": "models/text-embedding-004",
|
||||
"api_key": "test_key"
|
||||
}
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
Crew(
|
||||
agents=[simple_agent],
|
||||
tasks=[simple_task],
|
||||
embedder=invalid_config
|
||||
)
|
||||
|
||||
error_message = str(exc_info.value)
|
||||
error_count = error_message.count("validation error")
|
||||
|
||||
assert error_count == 1, f"Expected 1 validation error, got {error_count}"
|
||||
|
||||
assert "AzureProviderSpec" not in error_message
|
||||
assert "BedrockProviderSpec" not in error_message
|
||||
assert "CohereProviderSpec" not in error_message
|
||||
Reference in New Issue
Block a user