fix: add backward compatibility for legacy RAG tool config format

Fixes #4028 - WebsiteSearchTool always requires OpenAI API key even when
Ollama or other providers are specified.

The issue was that the documentation showed the old config format with
'llm' and 'embedder' keys, but the actual RagToolConfig type expects
'embedding_model' and 'vectordb' keys. When the old format was passed,
the embedder config was not recognized, causing the tool to fall back
to the default OpenAI embedding function which requires OPENAI_API_KEY.

Changes:
- Add _normalize_legacy_config method to RagTool that maps legacy
  'embedder' key to 'embedding_model'
- Emit deprecation warnings for legacy config keys
- Ignore 'llm' key with warning (not used in RAG tools)
- Add tests for backward compatibility
- Update documentation to show new config format with examples

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
Devin AI
2025-12-04 11:27:17 +00:00
parent 633e279b51
commit 36d9ca099e
3 changed files with 283 additions and 18 deletions

View File

@@ -50,29 +50,72 @@ tool = WebsiteSearchTool(website='https://example.com')
## Customization Options
By default, the tool uses OpenAI for both embeddings and summarization. To customize the model, you can use a config dictionary as follows:
By default, the tool uses OpenAI for embeddings. To customize the embedding model, you can use a config dictionary as follows:
```python Code
tool = WebsiteSearchTool(
config=dict(
llm=dict(
provider="ollama", # or google, openai, anthropic, llama2, ...
embedding_model=dict(
provider="ollama", # or openai, google-generativeai, azure, etc.
config=dict(
model="llama2",
# temperature=0.5,
# top_p=1,
# stream=true,
),
),
embedder=dict(
provider="google-generativeai", # or openai, ollama, ...
config=dict(
model_name="gemini-embedding-001",
task_type="RETRIEVAL_DOCUMENT",
# title="Embeddings",
model_name="nomic-embed-text",
url="http://localhost:11434/api/embeddings",
),
),
)
)
```
```
### Available Embedding Providers
The following embedding providers are supported:
- `openai` - OpenAI embeddings (default)
- `ollama` - Ollama local embeddings
- `google-generativeai` - Google Generative AI embeddings
- `azure` - Azure OpenAI embeddings
- `huggingface` - HuggingFace embeddings
- `cohere` - Cohere embeddings
- `voyageai` - Voyage AI embeddings
- And more...
### Example with Google Generative AI
```python Code
tool = WebsiteSearchTool(
config=dict(
embedding_model=dict(
provider="google-generativeai",
config=dict(
model_name="models/embedding-001",
task_type="RETRIEVAL_DOCUMENT",
),
),
)
)
```
### Example with Azure OpenAI
```python Code
tool = WebsiteSearchTool(
config=dict(
embedding_model=dict(
provider="azure",
config=dict(
model="text-embedding-3-small",
api_key="your-api-key",
api_base="https://your-resource.openai.azure.com/",
api_version="2024-02-01",
deployment_id="your-deployment-id",
),
),
)
)
```
<Note>
The `llm` and `embedder` config keys from older documentation are deprecated.
Please use `embedding_model` instead. The `llm` key is not used by RAG tools -
the LLM for generation is controlled by the agent's LLM configuration.
</Note>

View File

@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from typing import Any, Literal, cast
import warnings
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
from crewai.rag.embeddings.factory import build_embedder
@@ -131,10 +132,17 @@ class RagTool(BaseTool):
@field_validator("config", mode="before")
@classmethod
def _validate_config(cls, value: Any) -> Any:
"""Validate config with improved error messages for embedding providers."""
"""Validate config with improved error messages for embedding providers.
Also provides backward compatibility for the legacy config format that used
'llm' and 'embedder' keys instead of 'embedding_model' and 'vectordb'.
"""
if not isinstance(value, dict):
return value
# Handle backward compatibility for legacy config format
value = cls._normalize_legacy_config(value)
embedding_model = value.get("embedding_model")
if embedding_model:
try:
@@ -144,6 +152,45 @@ class RagTool(BaseTool):
return value
@classmethod
def _normalize_legacy_config(cls, config: dict[str, Any]) -> dict[str, Any]:
"""Normalize legacy config format to the current format.
The legacy format used 'llm' and 'embedder' keys, while the current format
uses 'embedding_model' and 'vectordb' keys.
Args:
config: The configuration dictionary to normalize.
Returns:
A normalized configuration dictionary using the current format.
"""
normalized = dict(config)
# Handle legacy 'embedder' key -> 'embedding_model'
if "embedder" in normalized and "embedding_model" not in normalized:
warnings.warn(
"The 'embedder' config key is deprecated. "
"Please use 'embedding_model' instead. "
"Example: config={'embedding_model': {'provider': 'ollama', 'config': {...}}}",
DeprecationWarning,
stacklevel=4,
)
normalized["embedding_model"] = normalized.pop("embedder")
# Handle legacy 'llm' key - this is not used in RAG tools
if "llm" in normalized:
warnings.warn(
"The 'llm' config key is not used by RAG tools and will be ignored. "
"The LLM for generation is controlled by the agent's LLM configuration, "
"not the tool configuration.",
DeprecationWarning,
stacklevel=4,
)
normalized.pop("llm")
return normalized
@model_validator(mode="after")
def _ensure_adapter(self) -> Self:
if isinstance(self.adapter, RagTool._AdapterPlaceholder):

View File

@@ -2,6 +2,9 @@ from pathlib import Path
from tempfile import TemporaryDirectory
from typing import cast
from unittest.mock import MagicMock, Mock, patch
import warnings
import pytest
from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
from crewai_tools.tools.rag.rag_tool import RagTool
@@ -299,3 +302,175 @@ def test_rag_tool_config_with_qdrant_and_azure_embeddings(
assert tool.adapter is not None
assert isinstance(tool.adapter, CrewAIRagAdapter)
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_rag_tool_with_legacy_embedder_config(
mock_create_client: Mock,
) -> None:
"""Test that RagTool accepts legacy 'embedder' config key with deprecation warning.
This test verifies the fix for issue #4028 where WebsiteSearchTool and other
RAG tools always required OpenAI API key even when using Ollama or other providers.
The legacy config format used 'embedder' key instead of 'embedding_model'.
"""
mock_embedding_func = MagicMock()
mock_embedding_func.return_value = [[0.1] * 1536]
mock_client = MagicMock()
mock_client.get_or_create_collection = MagicMock(return_value=None)
mock_create_client.return_value = mock_client
with patch(
"crewai_tools.tools.rag.rag_tool.build_embedder",
return_value=mock_embedding_func,
):
class MyTool(RagTool):
pass
# Legacy config format using 'embedder' key (as shown in old docs)
legacy_config = {
"embedder": {
"provider": "ollama",
"config": {
"model_name": "nomic-embed-text",
"url": "http://localhost:11434/api/embeddings",
},
},
}
# Should emit deprecation warning but still work
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
tool = MyTool(config=legacy_config)
# Check that deprecation warning was issued
deprecation_warnings = [
warning
for warning in w
if issubclass(warning.category, DeprecationWarning)
]
assert len(deprecation_warnings) >= 1
assert "embedder" in str(deprecation_warnings[0].message)
assert "embedding_model" in str(deprecation_warnings[0].message)
assert tool.adapter is not None
assert isinstance(tool.adapter, CrewAIRagAdapter)
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_rag_tool_with_legacy_llm_config_ignored(
mock_create_client: Mock,
) -> None:
"""Test that RagTool ignores legacy 'llm' config key with deprecation warning.
The 'llm' key was shown in old documentation but is not used by RAG tools.
The LLM for generation is controlled by the agent's LLM configuration.
"""
mock_embedding_func = MagicMock()
mock_embedding_func.return_value = [[0.1] * 1536]
mock_client = MagicMock()
mock_client.get_or_create_collection = MagicMock(return_value=None)
mock_create_client.return_value = mock_client
with patch(
"crewai_tools.tools.rag.rag_tool.build_embedder",
return_value=mock_embedding_func,
):
class MyTool(RagTool):
pass
# Legacy config format with both 'llm' and 'embedder' keys
legacy_config = {
"llm": {
"provider": "ollama",
"config": {
"model": "llama2",
},
},
"embedder": {
"provider": "ollama",
"config": {
"model_name": "nomic-embed-text",
"url": "http://localhost:11434/api/embeddings",
},
},
}
# Should emit deprecation warnings for both keys
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
tool = MyTool(config=legacy_config)
# Check that deprecation warnings were issued for both keys
deprecation_warnings = [
warning
for warning in w
if issubclass(warning.category, DeprecationWarning)
]
assert len(deprecation_warnings) >= 2
warning_messages = [str(warning.message) for warning in deprecation_warnings]
assert any("llm" in msg for msg in warning_messages)
assert any("embedder" in msg for msg in warning_messages)
assert tool.adapter is not None
assert isinstance(tool.adapter, CrewAIRagAdapter)
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_rag_tool_legacy_config_does_not_override_new_config(
mock_create_client: Mock,
) -> None:
"""Test that legacy 'embedder' key does not override 'embedding_model' if both present."""
mock_embedding_func = MagicMock()
mock_embedding_func.return_value = [[0.1] * 1536]
mock_client = MagicMock()
mock_client.get_or_create_collection = MagicMock(return_value=None)
mock_create_client.return_value = mock_client
with patch(
"crewai_tools.tools.rag.rag_tool.build_embedder",
return_value=mock_embedding_func,
) as mock_build:
class MyTool(RagTool):
pass
# Config with both old and new keys - new key should take precedence
config = {
"embedder": {
"provider": "ollama",
"config": {"model_name": "old-model"},
},
"embedding_model": {
"provider": "openai",
"config": {"model": "text-embedding-3-small", "api_key": "test-key"},
},
}
# No deprecation warning for 'embedder' since 'embedding_model' is present
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
tool = MyTool(config=config)
# Should NOT warn about 'embedder' since 'embedding_model' takes precedence
embedder_warnings = [
warning
for warning in w
if issubclass(warning.category, DeprecationWarning)
and "embedder" in str(warning.message)
]
assert len(embedder_warnings) == 0
assert tool.adapter is not None
# Verify that the new 'embedding_model' config was used, not the legacy 'embedder'
call_args = mock_build.call_args
assert call_args is not None
spec = call_args[0][0]
assert spec["provider"] == "openai"