mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-03 05:08:29 +00:00
Compare commits
1 Commits
devin/1765
...
devin/1764
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
36d9ca099e |
@@ -50,29 +50,72 @@ tool = WebsiteSearchTool(website='https://example.com')
|
||||
|
||||
## Customization Options
|
||||
|
||||
By default, the tool uses OpenAI for both embeddings and summarization. To customize the model, you can use a config dictionary as follows:
|
||||
|
||||
By default, the tool uses OpenAI for embeddings. To customize the embedding model, you can use a config dictionary as follows:
|
||||
|
||||
```python Code
|
||||
tool = WebsiteSearchTool(
|
||||
config=dict(
|
||||
llm=dict(
|
||||
provider="ollama", # or google, openai, anthropic, llama2, ...
|
||||
embedding_model=dict(
|
||||
provider="ollama", # or openai, google-generativeai, azure, etc.
|
||||
config=dict(
|
||||
model="llama2",
|
||||
# temperature=0.5,
|
||||
# top_p=1,
|
||||
# stream=True,
|
||||
),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="google-generativeai", # or openai, ollama, ...
|
||||
config=dict(
|
||||
model_name="gemini-embedding-001",
|
||||
task_type="RETRIEVAL_DOCUMENT",
|
||||
# title="Embeddings",
|
||||
model_name="nomic-embed-text",
|
||||
url="http://localhost:11434/api/embeddings",
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
```
|
||||
```
|
||||
|
||||
### Available Embedding Providers
|
||||
|
||||
The following embedding providers are supported:
|
||||
|
||||
- `openai` - OpenAI embeddings (default)
|
||||
- `ollama` - Ollama local embeddings
|
||||
- `google-generativeai` - Google Generative AI embeddings
|
||||
- `azure` - Azure OpenAI embeddings
|
||||
- `huggingface` - HuggingFace embeddings
|
||||
- `cohere` - Cohere embeddings
|
||||
- `voyageai` - Voyage AI embeddings
|
||||
- And more...
|
||||
|
||||
### Example with Google Generative AI
|
||||
|
||||
```python Code
|
||||
tool = WebsiteSearchTool(
|
||||
config=dict(
|
||||
embedding_model=dict(
|
||||
provider="google-generativeai",
|
||||
config=dict(
|
||||
model_name="models/embedding-001",
|
||||
task_type="RETRIEVAL_DOCUMENT",
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
### Example with Azure OpenAI
|
||||
|
||||
```python Code
|
||||
tool = WebsiteSearchTool(
|
||||
config=dict(
|
||||
embedding_model=dict(
|
||||
provider="azure",
|
||||
config=dict(
|
||||
model="text-embedding-3-small",
|
||||
api_key="your-api-key",
|
||||
api_base="https://your-resource.openai.azure.com/",
|
||||
api_version="2024-02-01",
|
||||
deployment_id="your-deployment-id",
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
<Note>
|
||||
The `llm` and `embedder` config keys from older documentation are deprecated.
|
||||
Please use `embedding_model` instead. The `llm` key is not used by RAG tools -
|
||||
the LLM for generation is controlled by the agent's LLM configuration.
|
||||
</Note>
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Literal, cast
|
||||
import warnings
|
||||
|
||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
@@ -131,10 +132,17 @@ class RagTool(BaseTool):
|
||||
@field_validator("config", mode="before")
|
||||
@classmethod
|
||||
def _validate_config(cls, value: Any) -> Any:
|
||||
"""Validate config with improved error messages for embedding providers."""
|
||||
"""Validate config with improved error messages for embedding providers.
|
||||
|
||||
Also provides backward compatibility for the legacy config format that used
|
||||
'llm' and 'embedder' keys instead of 'embedding_model' and 'vectordb'.
|
||||
"""
|
||||
if not isinstance(value, dict):
|
||||
return value
|
||||
|
||||
# Handle backward compatibility for legacy config format
|
||||
value = cls._normalize_legacy_config(value)
|
||||
|
||||
embedding_model = value.get("embedding_model")
|
||||
if embedding_model:
|
||||
try:
|
||||
@@ -144,6 +152,45 @@ class RagTool(BaseTool):
|
||||
|
||||
return value
|
||||
|
||||
@classmethod
def _normalize_legacy_config(cls, config: dict[str, Any]) -> dict[str, Any]:
    """Normalize legacy config format to the current format.

    The legacy format used 'llm' and 'embedder' keys, while the current format
    uses 'embedding_model' and 'vectordb' keys. The input mapping is never
    mutated; a shallow copy is normalized and returned.

    Handling of the legacy keys:

    * ``embedder`` is moved to ``embedding_model`` and a
      ``DeprecationWarning`` is emitted. If ``embedding_model`` is already
      present it takes precedence: the stale ``embedder`` entry is dropped
      without a warning so the result only contains current-format keys.
    * ``llm`` is removed and a ``DeprecationWarning`` is emitted.

    Args:
        config: The configuration dictionary to normalize.

    Returns:
        A normalized configuration dictionary using the current format.
    """
    normalized = dict(config)

    # Handle legacy 'embedder' key -> 'embedding_model'.
    # The legacy key is always removed so it can never leak into the
    # normalized result, even when the new key wins.
    if "embedder" in normalized:
        legacy_embedder = normalized.pop("embedder")
        if "embedding_model" not in normalized:
            warnings.warn(
                "The 'embedder' config key is deprecated. "
                "Please use 'embedding_model' instead. "
                "Example: config={'embedding_model': {'provider': 'ollama', 'config': {...}}}",
                DeprecationWarning,
                # NOTE(review): presumably chosen to skip the validator
                # frames so the warning points at user code - confirm.
                stacklevel=4,
            )
            normalized["embedding_model"] = legacy_embedder

    # Handle legacy 'llm' key - this is not used in RAG tools
    if "llm" in normalized:
        warnings.warn(
            "The 'llm' config key is not used by RAG tools and will be ignored. "
            "The LLM for generation is controlled by the agent's LLM configuration, "
            "not the tool configuration.",
            DeprecationWarning,
            stacklevel=4,
        )
        del normalized["llm"]

    return normalized
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _ensure_adapter(self) -> Self:
|
||||
if isinstance(self.adapter, RagTool._AdapterPlaceholder):
|
||||
|
||||
@@ -2,6 +2,9 @@ from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
|
||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||
@@ -299,3 +302,175 @@ def test_rag_tool_config_with_qdrant_and_azure_embeddings(
|
||||
|
||||
assert tool.adapter is not None
|
||||
assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_rag_tool_with_legacy_embedder_config(
    mock_create_client: Mock,
) -> None:
    """Test that RagTool accepts legacy 'embedder' config key with deprecation warning.

    This test verifies the fix for issue #4028 where WebsiteSearchTool and other
    RAG tools always required OpenAI API key even when using Ollama or other providers.
    The legacy config format used 'embedder' key instead of 'embedding_model'.
    """
    mock_embedding_func = MagicMock()
    mock_embedding_func.return_value = [[0.1] * 1536]

    mock_client = MagicMock()
    mock_client.get_or_create_collection = MagicMock(return_value=None)
    mock_create_client.return_value = mock_client

    with patch(
        "crewai_tools.tools.rag.rag_tool.build_embedder",
        return_value=mock_embedding_func,
    ):

        class MyTool(RagTool):
            pass

        # Legacy config format using 'embedder' key (as shown in old docs)
        legacy_config = {
            "embedder": {
                "provider": "ollama",
                "config": {
                    "model_name": "nomic-embed-text",
                    "url": "http://localhost:11434/api/embeddings",
                },
            },
        }

        # Should emit deprecation warning but still work
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            tool = MyTool(config=legacy_config)

        # Look for the migration warning anywhere in the recorded list
        # rather than at index 0: unrelated DeprecationWarnings raised by
        # dependencies during construction must not make this test fail.
        deprecation_messages = [
            str(warning.message)
            for warning in w
            if issubclass(warning.category, DeprecationWarning)
        ]
        assert any(
            "embedder" in message and "embedding_model" in message
            for message in deprecation_messages
        )

        assert tool.adapter is not None
        assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_rag_tool_with_legacy_llm_config_ignored(
    mock_create_client: Mock,
) -> None:
    """Check that a legacy 'llm' entry is dropped with a deprecation warning.

    Old documentation showed an 'llm' key, but RAG tools never consume it;
    generation is driven by the agent's own LLM configuration. The tool must
    still build its adapter when both legacy keys are supplied.
    """
    embedding_stub = MagicMock(return_value=[[0.1] * 1536])

    client_stub = MagicMock()
    client_stub.get_or_create_collection = MagicMock(return_value=None)
    mock_create_client.return_value = client_stub

    # Legacy config format with both 'llm' and 'embedder' keys
    ollama_llm = {"provider": "ollama", "config": {"model": "llama2"}}
    ollama_embedder = {
        "provider": "ollama",
        "config": {
            "model_name": "nomic-embed-text",
            "url": "http://localhost:11434/api/embeddings",
        },
    }
    legacy_config = {"llm": ollama_llm, "embedder": ollama_embedder}

    with patch(
        "crewai_tools.tools.rag.rag_tool.build_embedder",
        return_value=embedding_stub,
    ):

        class MyTool(RagTool):
            pass

        # Record every warning raised while the tool is constructed.
        with warnings.catch_warnings(record=True) as captured:
            warnings.simplefilter("always")
            tool = MyTool(config=legacy_config)

        # Both legacy keys must have produced a DeprecationWarning.
        warning_messages = []
        for entry in captured:
            if issubclass(entry.category, DeprecationWarning):
                warning_messages.append(str(entry.message))

        assert len(warning_messages) >= 2
        assert any("llm" in text for text in warning_messages)
        assert any("embedder" in text for text in warning_messages)

        assert tool.adapter is not None
        assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_rag_tool_legacy_config_does_not_override_new_config(
    mock_create_client: Mock,
) -> None:
    """Check that 'embedding_model' wins when 'embedder' is supplied as well."""
    embedding_stub = MagicMock(return_value=[[0.1] * 1536])

    client_stub = MagicMock()
    client_stub.get_or_create_collection = MagicMock(return_value=None)
    mock_create_client.return_value = client_stub

    # Config with both old and new keys - new key should take precedence
    config = {
        "embedder": {
            "provider": "ollama",
            "config": {"model_name": "old-model"},
        },
        "embedding_model": {
            "provider": "openai",
            "config": {"model": "text-embedding-3-small", "api_key": "test-key"},
        },
    }

    with patch(
        "crewai_tools.tools.rag.rag_tool.build_embedder",
        return_value=embedding_stub,
    ) as mock_build:

        class MyTool(RagTool):
            pass

        with warnings.catch_warnings(record=True) as captured:
            warnings.simplefilter("always")
            tool = MyTool(config=config)

        # No 'embedder' deprecation notice is expected because the new key
        # is present and takes precedence.
        embedder_warnings = []
        for entry in captured:
            if not issubclass(entry.category, DeprecationWarning):
                continue
            if "embedder" in str(entry.message):
                embedder_warnings.append(entry)
        assert len(embedder_warnings) == 0

        assert tool.adapter is not None

        # The embedder factory must have received the new-style spec.
        recorded_call = mock_build.call_args
        assert recorded_call is not None
        spec = recorded_call.args[0]
        assert spec["provider"] == "openai"
|
||||
|
||||
Reference in New Issue
Block a user