mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-28 01:28:14 +00:00
Compare commits
1 Commits
llm-event-
...
devin/1764
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
36d9ca099e |
@@ -50,29 +50,72 @@ tool = WebsiteSearchTool(website='https://example.com')
|
|||||||
|
|
||||||
## Customization Options
|
## Customization Options
|
||||||
|
|
||||||
By default, the tool uses OpenAI for both embeddings and summarization. To customize the model, you can use a config dictionary as follows:
|
By default, the tool uses OpenAI for embeddings. To customize the embedding model, you can use a config dictionary as follows:
|
||||||
|
|
||||||
|
|
||||||
```python Code
|
```python Code
|
||||||
tool = WebsiteSearchTool(
|
tool = WebsiteSearchTool(
|
||||||
config=dict(
|
config=dict(
|
||||||
llm=dict(
|
embedding_model=dict(
|
||||||
provider="ollama", # or google, openai, anthropic, llama2, ...
|
provider="ollama", # or openai, google-generativeai, azure, etc.
|
||||||
config=dict(
|
config=dict(
|
||||||
model="llama2",
|
model_name="nomic-embed-text",
|
||||||
# temperature=0.5,
|
url="http://localhost:11434/api/embeddings",
|
||||||
# top_p=1,
|
|
||||||
# stream=true,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
embedder=dict(
|
|
||||||
provider="google-generativeai", # or openai, ollama, ...
|
|
||||||
config=dict(
|
|
||||||
model_name="gemini-embedding-001",
|
|
||||||
task_type="RETRIEVAL_DOCUMENT",
|
|
||||||
# title="Embeddings",
|
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Available Embedding Providers
|
||||||
|
|
||||||
|
The following embedding providers are supported:
|
||||||
|
|
||||||
|
- `openai` - OpenAI embeddings (default)
|
||||||
|
- `ollama` - Ollama local embeddings
|
||||||
|
- `google-generativeai` - Google Generative AI embeddings
|
||||||
|
- `azure` - Azure OpenAI embeddings
|
||||||
|
- `huggingface` - HuggingFace embeddings
|
||||||
|
- `cohere` - Cohere embeddings
|
||||||
|
- `voyageai` - Voyage AI embeddings
|
||||||
|
- And more...
|
||||||
|
|
||||||
|
### Example with Google Generative AI
|
||||||
|
|
||||||
|
```python Code
|
||||||
|
tool = WebsiteSearchTool(
|
||||||
|
config=dict(
|
||||||
|
embedding_model=dict(
|
||||||
|
provider="google-generativeai",
|
||||||
|
config=dict(
|
||||||
|
model_name="models/embedding-001",
|
||||||
|
task_type="RETRIEVAL_DOCUMENT",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Example with Azure OpenAI
|
||||||
|
|
||||||
|
```python Code
|
||||||
|
tool = WebsiteSearchTool(
|
||||||
|
config=dict(
|
||||||
|
embedding_model=dict(
|
||||||
|
provider="azure",
|
||||||
|
config=dict(
|
||||||
|
model="text-embedding-3-small",
|
||||||
|
api_key="your-api-key",
|
||||||
|
api_base="https://your-resource.openai.azure.com/",
|
||||||
|
api_version="2024-02-01",
|
||||||
|
deployment_id="your-deployment-id",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
<Note>
|
||||||
|
The `llm` and `embedder` config keys from older documentation are deprecated.
|
||||||
|
Please use `embedding_model` instead of `embedder`. The `llm` key is not used by RAG tools -
|
||||||
|
the LLM for generation is controlled by the agent's LLM configuration.
|
||||||
|
</Note>
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Literal, cast
|
from typing import Any, Literal, cast
|
||||||
|
import warnings
|
||||||
|
|
||||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||||
from crewai.rag.embeddings.factory import build_embedder
|
from crewai.rag.embeddings.factory import build_embedder
|
||||||
@@ -131,10 +132,17 @@ class RagTool(BaseTool):
|
|||||||
@field_validator("config", mode="before")
|
@field_validator("config", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_config(cls, value: Any) -> Any:
|
def _validate_config(cls, value: Any) -> Any:
|
||||||
"""Validate config with improved error messages for embedding providers."""
|
"""Validate config with improved error messages for embedding providers.
|
||||||
|
|
||||||
|
Also provides backward compatibility for the legacy config format that used
|
||||||
|
'llm' and 'embedder' keys instead of 'embedding_model' and 'vectordb': 'embedder' is renamed to 'embedding_model' and the unused 'llm' key is dropped.
|
||||||
|
"""
|
||||||
if not isinstance(value, dict):
|
if not isinstance(value, dict):
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
# Handle backward compatibility for legacy config format
|
||||||
|
value = cls._normalize_legacy_config(value)
|
||||||
|
|
||||||
embedding_model = value.get("embedding_model")
|
embedding_model = value.get("embedding_model")
|
||||||
if embedding_model:
|
if embedding_model:
|
||||||
try:
|
try:
|
||||||
@@ -144,6 +152,45 @@ class RagTool(BaseTool):
|
|||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _normalize_legacy_config(cls, config: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Normalize legacy config format to the current format.
|
||||||
|
|
||||||
|
The legacy format used 'llm' and 'embedder' keys, while the current format
|
||||||
|
uses 'embedding_model' and 'vectordb' keys. Only 'embedder' is renamed (to 'embedding_model'); 'llm' has no replacement and is removed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: The configuration dictionary to normalize.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A normalized configuration dictionary using the current format.
|
||||||
|
"""
|
||||||
|
normalized = dict(config)
|
||||||
|
|
||||||
|
# Handle legacy 'embedder' key -> 'embedding_model'
|
||||||
|
if "embedder" in normalized and "embedding_model" not in normalized:
|
||||||
|
warnings.warn(
|
||||||
|
"The 'embedder' config key is deprecated. "
|
||||||
|
"Please use 'embedding_model' instead. "
|
||||||
|
"Example: config={'embedding_model': {'provider': 'ollama', 'config': {...}}}",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=4,
|
||||||
|
)
|
||||||
|
normalized["embedding_model"] = normalized.pop("embedder")
|
||||||
|
|
||||||
|
# Handle legacy 'llm' key - this is not used in RAG tools
|
||||||
|
if "llm" in normalized:
|
||||||
|
warnings.warn(
|
||||||
|
"The 'llm' config key is not used by RAG tools and will be ignored. "
|
||||||
|
"The LLM for generation is controlled by the agent's LLM configuration, "
|
||||||
|
"not the tool configuration.",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=4,
|
||||||
|
)
|
||||||
|
normalized.pop("llm")
|
||||||
|
|
||||||
|
return normalized
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def _ensure_adapter(self) -> Self:
|
def _ensure_adapter(self) -> Self:
|
||||||
if isinstance(self.adapter, RagTool._AdapterPlaceholder):
|
if isinstance(self.adapter, RagTool._AdapterPlaceholder):
|
||||||
|
|||||||
@@ -2,6 +2,9 @@ from pathlib import Path
|
|||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
from typing import cast
|
from typing import cast
|
||||||
from unittest.mock import MagicMock, Mock, patch
|
from unittest.mock import MagicMock, Mock, patch
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
|
from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
|
||||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||||
@@ -299,3 +302,175 @@ def test_rag_tool_config_with_qdrant_and_azure_embeddings(
|
|||||||
|
|
||||||
assert tool.adapter is not None
|
assert tool.adapter is not None
|
||||||
assert isinstance(tool.adapter, CrewAIRagAdapter)
|
assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
|
||||||
|
def test_rag_tool_with_legacy_embedder_config(
|
||||||
|
mock_create_client: Mock,
|
||||||
|
) -> None:
|
||||||
|
"""Test that RagTool accepts legacy 'embedder' config key with deprecation warning.
|
||||||
|
|
||||||
|
This test verifies the fix for issue #4028 where WebsiteSearchTool and other
|
||||||
|
RAG tools always required an OpenAI API key even when using Ollama or other providers.
|
||||||
|
The legacy config format used 'embedder' key instead of 'embedding_model'.
|
||||||
|
"""
|
||||||
|
mock_embedding_func = MagicMock()
|
||||||
|
mock_embedding_func.return_value = [[0.1] * 1536]
|
||||||
|
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||||
|
mock_create_client.return_value = mock_client
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"crewai_tools.tools.rag.rag_tool.build_embedder",
|
||||||
|
return_value=mock_embedding_func,
|
||||||
|
):
|
||||||
|
|
||||||
|
class MyTool(RagTool):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Legacy config format using 'embedder' key (as shown in old docs)
|
||||||
|
legacy_config = {
|
||||||
|
"embedder": {
|
||||||
|
"provider": "ollama",
|
||||||
|
"config": {
|
||||||
|
"model_name": "nomic-embed-text",
|
||||||
|
"url": "http://localhost:11434/api/embeddings",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Should emit deprecation warning but still work
|
||||||
|
with warnings.catch_warnings(record=True) as w:
|
||||||
|
warnings.simplefilter("always")
|
||||||
|
tool = MyTool(config=legacy_config)
|
||||||
|
|
||||||
|
# Check that deprecation warning was issued
|
||||||
|
deprecation_warnings = [
|
||||||
|
warning
|
||||||
|
for warning in w
|
||||||
|
if issubclass(warning.category, DeprecationWarning)
|
||||||
|
]
|
||||||
|
assert len(deprecation_warnings) >= 1
|
||||||
|
assert "embedder" in str(deprecation_warnings[0].message)
|
||||||
|
assert "embedding_model" in str(deprecation_warnings[0].message)
|
||||||
|
|
||||||
|
assert tool.adapter is not None
|
||||||
|
assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
|
||||||
|
def test_rag_tool_with_legacy_llm_config_ignored(
|
||||||
|
mock_create_client: Mock,
|
||||||
|
) -> None:
|
||||||
|
"""Test that RagTool ignores legacy 'llm' config key with deprecation warning.
|
||||||
|
|
||||||
|
The 'llm' key was shown in old documentation but is not used by RAG tools.
|
||||||
|
The LLM for generation is controlled by the agent's LLM configuration.
|
||||||
|
"""
|
||||||
|
mock_embedding_func = MagicMock()
|
||||||
|
mock_embedding_func.return_value = [[0.1] * 1536]
|
||||||
|
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||||
|
mock_create_client.return_value = mock_client
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"crewai_tools.tools.rag.rag_tool.build_embedder",
|
||||||
|
return_value=mock_embedding_func,
|
||||||
|
):
|
||||||
|
|
||||||
|
class MyTool(RagTool):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Legacy config format with both 'llm' and 'embedder' keys
|
||||||
|
legacy_config = {
|
||||||
|
"llm": {
|
||||||
|
"provider": "ollama",
|
||||||
|
"config": {
|
||||||
|
"model": "llama2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"embedder": {
|
||||||
|
"provider": "ollama",
|
||||||
|
"config": {
|
||||||
|
"model_name": "nomic-embed-text",
|
||||||
|
"url": "http://localhost:11434/api/embeddings",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Should emit deprecation warnings for both keys
|
||||||
|
with warnings.catch_warnings(record=True) as w:
|
||||||
|
warnings.simplefilter("always")
|
||||||
|
tool = MyTool(config=legacy_config)
|
||||||
|
|
||||||
|
# Check that deprecation warnings were issued for both keys
|
||||||
|
deprecation_warnings = [
|
||||||
|
warning
|
||||||
|
for warning in w
|
||||||
|
if issubclass(warning.category, DeprecationWarning)
|
||||||
|
]
|
||||||
|
assert len(deprecation_warnings) >= 2
|
||||||
|
|
||||||
|
warning_messages = [str(warning.message) for warning in deprecation_warnings]
|
||||||
|
assert any("llm" in msg for msg in warning_messages)
|
||||||
|
assert any("embedder" in msg for msg in warning_messages)
|
||||||
|
|
||||||
|
assert tool.adapter is not None
|
||||||
|
assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||||
|
|
||||||
|
|
||||||
|
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
|
||||||
|
def test_rag_tool_legacy_config_does_not_override_new_config(
|
||||||
|
mock_create_client: Mock,
|
||||||
|
) -> None:
|
||||||
|
"""Test that legacy 'embedder' key does not override 'embedding_model' if both are present."""
|
||||||
|
mock_embedding_func = MagicMock()
|
||||||
|
mock_embedding_func.return_value = [[0.1] * 1536]
|
||||||
|
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||||
|
mock_create_client.return_value = mock_client
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"crewai_tools.tools.rag.rag_tool.build_embedder",
|
||||||
|
return_value=mock_embedding_func,
|
||||||
|
) as mock_build:
|
||||||
|
|
||||||
|
class MyTool(RagTool):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Config with both old and new keys - new key should take precedence
|
||||||
|
config = {
|
||||||
|
"embedder": {
|
||||||
|
"provider": "ollama",
|
||||||
|
"config": {"model_name": "old-model"},
|
||||||
|
},
|
||||||
|
"embedding_model": {
|
||||||
|
"provider": "openai",
|
||||||
|
"config": {"model": "text-embedding-3-small", "api_key": "test-key"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# No deprecation warning for 'embedder' since 'embedding_model' is present
|
||||||
|
with warnings.catch_warnings(record=True) as w:
|
||||||
|
warnings.simplefilter("always")
|
||||||
|
tool = MyTool(config=config)
|
||||||
|
|
||||||
|
# Should NOT warn about 'embedder' since 'embedding_model' takes precedence
|
||||||
|
embedder_warnings = [
|
||||||
|
warning
|
||||||
|
for warning in w
|
||||||
|
if issubclass(warning.category, DeprecationWarning)
|
||||||
|
and "embedder" in str(warning.message)
|
||||||
|
]
|
||||||
|
assert len(embedder_warnings) == 0
|
||||||
|
|
||||||
|
assert tool.adapter is not None
|
||||||
|
|
||||||
|
# Verify that the new 'embedding_model' config was used, not the legacy 'embedder'
|
||||||
|
call_args = mock_build.call_args
|
||||||
|
assert call_args is not None
|
||||||
|
spec = call_args[0][0]
|
||||||
|
assert spec["provider"] == "openai"
|
||||||
|
|||||||
Reference in New Issue
Block a user