fix: rag tool embeddings config

* fix: ensure config is not flattened, add tests

* chore: refactor inits to model_validator

* chore: refactor rag tool config parsing

* chore: add initial docs

* chore: add additional validation aliases for provider env vars

* chore: add solid docs

* chore: move imports to top

* fix: revert circular import

* fix: lazy import qdrant-client

* fix: allow collection name config

* chore: narrow model names for google

* chore: update additional docs

* chore: add backward compat on model name aliases

* chore: add tests for config changes
This commit is contained in:
Greyson LaLonde
2025-11-24 16:51:28 -05:00
committed by GitHub
parent 9c84475691
commit a928cde6ee
46 changed files with 1850 additions and 291 deletions

View File

@@ -1,5 +1,3 @@
"""Tests for RAG tool with mocked embeddings and vector database."""
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import cast
@@ -117,15 +115,15 @@ def test_rag_tool_with_file(
assert "Python is a programming language" in result
@patch("crewai_tools.tools.rag.rag_tool.RagTool._create_embedding_function")
@patch("crewai_tools.tools.rag.rag_tool.build_embedder")
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_rag_tool_with_custom_embeddings(
mock_create_client: Mock, mock_create_embedding: Mock
mock_create_client: Mock, mock_build_embedder: Mock
) -> None:
"""Test RagTool with custom embeddings configuration to ensure no API calls."""
mock_embedding_func = MagicMock()
mock_embedding_func.return_value = [[0.2] * 1536]
mock_create_embedding.return_value = mock_embedding_func
mock_build_embedder.return_value = mock_embedding_func
mock_client = MagicMock()
mock_client.get_or_create_collection = MagicMock(return_value=None)
@@ -153,7 +151,7 @@ def test_rag_tool_with_custom_embeddings(
assert "Relevant Content:" in result
assert "Test content" in result
mock_create_embedding.assert_called()
mock_build_embedder.assert_called()
@patch("crewai_tools.adapters.crewai_rag_adapter.get_rag_client")
@@ -176,3 +174,128 @@ def test_rag_tool_no_results(
result = tool._run(query="Non-existent content")
assert "Relevant Content:" in result
assert "No relevant content found" in result
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_rag_tool_with_azure_config_without_env_vars(
mock_create_client: Mock,
) -> None:
"""Test that RagTool accepts Azure config without requiring env vars.
This test verifies the fix for the issue where RAG tools were ignoring
the embedding configuration passed via the config parameter and instead
requiring environment variables like EMBEDDINGS_OPENAI_API_KEY.
"""
mock_embedding_func = MagicMock()
mock_embedding_func.return_value = [[0.1] * 1536]
mock_client = MagicMock()
mock_client.get_or_create_collection = MagicMock(return_value=None)
mock_client.add_documents = MagicMock(return_value=None)
mock_create_client.return_value = mock_client
# Patch the embedding function builder to avoid actual API calls
with patch(
"crewai_tools.tools.rag.rag_tool.build_embedder",
return_value=mock_embedding_func,
):
class MyTool(RagTool):
pass
# Configuration with explicit Azure credentials - should work without env vars
config = {
"embedding_model": {
"provider": "azure",
"config": {
"model": "text-embedding-3-small",
"api_key": "test-api-key",
"api_base": "https://test.openai.azure.com/",
"api_version": "2024-02-01",
"api_type": "azure",
"deployment_id": "test-deployment",
},
}
}
# This should not raise a validation error about missing env vars
tool = MyTool(config=config)
assert tool.adapter is not None
assert isinstance(tool.adapter, CrewAIRagAdapter)
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_rag_tool_with_openai_config_without_env_vars(
mock_create_client: Mock,
) -> None:
"""Test that RagTool accepts OpenAI config without requiring env vars."""
mock_embedding_func = MagicMock()
mock_embedding_func.return_value = [[0.1] * 1536]
mock_client = MagicMock()
mock_client.get_or_create_collection = MagicMock(return_value=None)
mock_create_client.return_value = mock_client
with patch(
"crewai_tools.tools.rag.rag_tool.build_embedder",
return_value=mock_embedding_func,
):
class MyTool(RagTool):
pass
config = {
"embedding_model": {
"provider": "openai",
"config": {
"model": "text-embedding-3-small",
"api_key": "sk-test123",
},
}
}
tool = MyTool(config=config)
assert tool.adapter is not None
assert isinstance(tool.adapter, CrewAIRagAdapter)
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_rag_tool_config_with_qdrant_and_azure_embeddings(
mock_create_client: Mock,
) -> None:
"""Test RagTool with Qdrant vector DB and Azure embeddings config."""
mock_embedding_func = MagicMock()
mock_embedding_func.return_value = [[0.1] * 1536]
mock_client = MagicMock()
mock_client.get_or_create_collection = MagicMock(return_value=None)
mock_create_client.return_value = mock_client
with patch(
"crewai_tools.tools.rag.rag_tool.build_embedder",
return_value=mock_embedding_func,
):
class MyTool(RagTool):
pass
config = {
"vectordb": {"provider": "qdrant", "config": {}},
"embedding_model": {
"provider": "azure",
"config": {
"model": "text-embedding-3-large",
"api_key": "test-key",
"api_base": "https://test.openai.azure.com/",
"api_version": "2024-02-01",
"deployment_id": "test-deployment",
},
},
}
tool = MyTool(config=config)
assert tool.adapter is not None
assert isinstance(tool.adapter, CrewAIRagAdapter)

View File

@@ -0,0 +1,66 @@
"""Tests for improved RAG tool validation error messages."""
from unittest.mock import MagicMock, Mock, patch
import pytest
from pydantic import ValidationError
from crewai_tools.tools.rag.rag_tool import RagTool
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_azure_missing_deployment_id_gives_clear_error(mock_create_client: Mock) -> None:
"""Test that missing deployment_id for Azure gives a clear, focused error message."""
mock_client = MagicMock()
mock_client.get_or_create_collection = MagicMock(return_value=None)
mock_create_client.return_value = mock_client
class MyTool(RagTool):
pass
config = {
"embedding_model": {
"provider": "azure",
"config": {
"api_base": "http://localhost:4000/v1",
"api_key": "test-key",
"api_version": "2024-02-01",
},
}
}
with pytest.raises(ValueError) as exc_info:
MyTool(config=config)
error_msg = str(exc_info.value)
assert "azure" in error_msg.lower()
assert "deployment_id" in error_msg.lower()
assert "bedrock" not in error_msg.lower()
assert "cohere" not in error_msg.lower()
assert "huggingface" not in error_msg.lower()
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_valid_azure_config_works(mock_create_client: Mock) -> None:
"""Test that valid Azure config works without errors."""
mock_client = MagicMock()
mock_client.get_or_create_collection = MagicMock(return_value=None)
mock_create_client.return_value = mock_client
class MyTool(RagTool):
pass
config = {
"embedding_model": {
"provider": "azure",
"config": {
"api_base": "http://localhost:4000/v1",
"api_key": "test-key",
"api_version": "2024-02-01",
"deployment_id": "text-embedding-3-small",
},
}
}
tool = MyTool(config=config)
assert tool is not None