fix: rag tool embeddings config

* fix: ensure config is not flattened, add tests

* chore: refactor inits to model_validator

* chore: refactor rag tool config parsing

* chore: add initial docs

* chore: add additional validation aliases for provider env vars

* chore: add solid docs

* chore: move imports to top

* fix: revert circular import

* fix: lazy import qdrant-client

* fix: allow collection name config

* chore: narrow model names for google

* chore: update additional docs

* chore: add backward compat on model name aliases

* chore: add tests for config changes
This commit is contained in:
Greyson LaLonde
2025-11-24 16:51:28 -05:00
committed by GitHub
parent 9c84475691
commit a928cde6ee
46 changed files with 1850 additions and 291 deletions

View File

@@ -1,28 +1,51 @@
"""Adapter for CrewAI's native RAG system."""
from __future__ import annotations
import hashlib
from pathlib import Path
from typing import Any, TypeAlias, TypedDict
from typing import TYPE_CHECKING, Any, TypeAlias, TypedDict, cast
import uuid
from crewai.rag.config.types import RagConfigType
from crewai.rag.config.utils import get_rag_client
from crewai.rag.core.base_client import BaseClient
from crewai.rag.factory import create_client
from crewai.rag.qdrant.config import QdrantConfig
from crewai.rag.types import BaseRecord, SearchResult
from pydantic import PrivateAttr
from qdrant_client.models import VectorParams
from typing_extensions import Unpack
from pydantic.dataclasses import is_pydantic_dataclass
from typing_extensions import TypeIs, Unpack
from crewai_tools.rag.data_types import DataType
from crewai_tools.rag.misc import sanitize_metadata_for_chromadb
from crewai_tools.tools.rag.rag_tool import Adapter
if TYPE_CHECKING:
from crewai.rag.qdrant.config import QdrantConfig
ContentItem: TypeAlias = str | Path | dict[str, Any]
def _is_qdrant_config(config: Any) -> TypeIs[QdrantConfig]:
"""Check if config is a QdrantConfig using safe duck typing.
Args:
config: RAG configuration to check.
Returns:
True if config is a QdrantConfig instance.
"""
if not is_pydantic_dataclass(config):
return False
try:
return cast(bool, config.provider == "qdrant") # type: ignore[attr-defined]
except (AttributeError, ImportError):
return False
class AddDocumentParams(TypedDict, total=False):
"""Parameters for adding documents to the RAG system."""
@@ -56,8 +79,9 @@ class CrewAIRagAdapter(Adapter):
else:
self._client = get_rag_client()
collection_params: dict[str, Any] = {"collection_name": self.collection_name}
if isinstance(self.config, QdrantConfig) and self.config.vectors_config:
if isinstance(self.config.vectors_config, VectorParams):
if self.config is not None and _is_qdrant_config(self.config):
if self.config.vectors_config is not None:
collection_params["vectors_config"] = self.config.vectors_config
self._client.get_or_create_collection(**collection_params)

View File

@@ -1,4 +1,5 @@
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self
from crewai_tools.rag.data_types import DataType
from crewai_tools.tools.rag.rag_tool import RagTool
@@ -24,14 +25,17 @@ class PDFSearchTool(RagTool):
"A tool that can be used to semantic search a query from a PDF's content."
)
args_schema: type[BaseModel] = PDFSearchToolSchema
pdf: str | None = None
def __init__(self, pdf: str | None = None, **kwargs):
super().__init__(**kwargs)
if pdf is not None:
self.add(pdf)
self.description = f"A tool that can be used to semantic search a query the {pdf} PDF's content."
@model_validator(mode="after")
def _configure_for_pdf(self) -> Self:
"""Configure tool for specific PDF if provided."""
if self.pdf is not None:
self.add(self.pdf)
self.description = f"A tool that can be used to semantic search a query the {self.pdf} PDF's content."
self.args_schema = FixedPDFSearchToolSchema
self._generate_description()
return self
def add(self, pdf: str) -> None:
super().add(pdf, data_type=DataType.PDF_FILE)

View File

@@ -0,0 +1,10 @@
from crewai.rag.embeddings.types import ProviderSpec
from crewai_tools.tools.rag.types import RagToolConfig, VectorDbConfig
__all__ = [
"ProviderSpec",
"RagToolConfig",
"VectorDbConfig",
]

View File

@@ -1,10 +1,74 @@
from abc import ABC, abstractmethod
import os
from typing import Any, cast
from typing import Any, Literal, cast
from crewai.rag.embeddings.factory import get_embedding_function
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
from crewai.rag.embeddings.factory import build_embedder
from crewai.rag.embeddings.types import ProviderSpec
from crewai.tools import BaseTool
from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic import (
BaseModel,
ConfigDict,
Field,
TypeAdapter,
ValidationError,
field_validator,
model_validator,
)
from typing_extensions import Self
from crewai_tools.tools.rag.types import RagToolConfig, VectorDbConfig
def _validate_embedding_config(
value: dict[str, Any] | ProviderSpec,
) -> dict[str, Any] | ProviderSpec:
"""Validate embedding config and provide clearer error messages for union validation.
This pre-validator catches Pydantic ValidationErrors from the ProviderSpec union
and provides a cleaner, more focused error message that only shows the relevant
provider's validation errors instead of all 18 union members.
Args:
value: The embedding configuration dictionary or validated ProviderSpec.
Returns:
A validated ProviderSpec instance, or the original value if already validated
or missing required fields.
Raises:
ValueError: If the configuration is invalid for the specified provider.
"""
if not isinstance(value, dict):
return value
provider = value.get("provider")
if not provider:
return value
try:
type_adapter: TypeAdapter[ProviderSpec] = TypeAdapter(ProviderSpec)
return type_adapter.validate_python(value)
except ValidationError as e:
provider_key = f"{provider.lower()}providerspec"
provider_errors = [
err for err in e.errors() if provider_key in str(err.get("loc", "")).lower()
]
if provider_errors:
error_msgs = []
for err in provider_errors:
loc_parts = err["loc"]
if str(loc_parts[0]).lower() == provider_key:
loc_parts = loc_parts[1:]
loc = ".".join(str(x) for x in loc_parts)
error_msgs.append(f" - {loc}: {err['msg']}")
raise ValueError(
f"Invalid configuration for embedding provider '{provider}':\n"
+ "\n".join(error_msgs)
) from e
raise
class Adapter(BaseModel, ABC):
@@ -46,139 +110,100 @@ class RagTool(BaseTool):
summarize: bool = False
similarity_threshold: float = 0.6
limit: int = 5
collection_name: str = "rag_tool_collection"
adapter: Adapter = Field(default_factory=_AdapterPlaceholder)
config: Any | None = None
config: RagToolConfig = Field(
default_factory=RagToolConfig,
description="Configuration format accepted by RagTool.",
)
@field_validator("config", mode="before")
@classmethod
def _validate_config(cls, value: Any) -> Any:
"""Validate config with improved error messages for embedding providers."""
if not isinstance(value, dict):
return value
embedding_model = value.get("embedding_model")
if embedding_model:
try:
value["embedding_model"] = _validate_embedding_config(embedding_model)
except ValueError:
raise
return value
@model_validator(mode="after")
def _set_default_adapter(self):
def _ensure_adapter(self) -> Self:
if isinstance(self.adapter, RagTool._AdapterPlaceholder):
from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
parsed_config = self._parse_config(self.config)
provider_cfg = self._parse_config(self.config)
self.adapter = CrewAIRagAdapter(
collection_name="rag_tool_collection",
collection_name=self.collection_name,
summarize=self.summarize,
similarity_threshold=self.similarity_threshold,
limit=self.limit,
config=parsed_config,
config=provider_cfg,
)
return self
def _parse_config(self, config: Any) -> Any:
"""Parse complex config format to extract provider-specific config.
def _parse_config(self, config: RagToolConfig) -> Any:
"""Normalize the RagToolConfig into a provider-specific config object.
Raises:
ValueError: If the config format is invalid or uses unsupported providers.
Defaults to 'chromadb' with no extra provider config if none is supplied.
"""
if config is None:
return None
if not config:
return self._create_provider_config("chromadb", {}, None)
if isinstance(config, dict) and "provider" in config:
return config
vectordb_cfg = cast(VectorDbConfig, config.get("vectordb", {}))
provider: Literal["chromadb", "qdrant"] = vectordb_cfg.get(
"provider", "chromadb"
)
provider_config: dict[str, Any] = vectordb_cfg.get("config", {})
if isinstance(config, dict):
if "vectordb" in config:
vectordb_config = config["vectordb"]
if isinstance(vectordb_config, dict) and "provider" in vectordb_config:
provider = vectordb_config["provider"]
provider_config = vectordb_config.get("config", {})
supported = ("chromadb", "qdrant")
if provider not in supported:
raise ValueError(
f"Unsupported vector database provider: '{provider}'. "
f"CrewAI RAG currently supports: {', '.join(supported)}."
)
supported_providers = ["chromadb", "qdrant"]
if provider not in supported_providers:
raise ValueError(
f"Unsupported vector database provider: '{provider}'. "
f"CrewAI RAG currently supports: {', '.join(supported_providers)}."
)
embedding_spec: ProviderSpec | None = config.get("embedding_model")
if embedding_spec:
embedding_spec = cast(
ProviderSpec, _validate_embedding_config(embedding_spec)
)
embedding_config = config.get("embedding_model")
embedding_function = None
if embedding_config and isinstance(embedding_config, dict):
embedding_function = self._create_embedding_function(
embedding_config, provider
)
return self._create_provider_config(
provider, provider_config, embedding_function
)
return None
embedding_config = config.get("embedding_model")
embedding_function = None
if embedding_config and isinstance(embedding_config, dict):
embedding_function = self._create_embedding_function(
embedding_config, "chromadb"
)
return self._create_provider_config("chromadb", {}, embedding_function)
return config
@staticmethod
def _create_embedding_function(embedding_config: dict, provider: str) -> Any:
"""Create embedding function for the specified vector database provider."""
embedding_provider = embedding_config.get("provider")
embedding_model_config = embedding_config.get("config", {}).copy()
if "model" in embedding_model_config:
embedding_model_config["model_name"] = embedding_model_config.pop("model")
factory_config = {"provider": embedding_provider, **embedding_model_config}
if embedding_provider == "openai" and "api_key" not in factory_config:
api_key = os.getenv("OPENAI_API_KEY")
if api_key:
factory_config["api_key"] = api_key
if provider == "chromadb":
return get_embedding_function(factory_config) # type: ignore[call-overload]
if provider == "qdrant":
chromadb_func = get_embedding_function(factory_config) # type: ignore[call-overload]
def qdrant_embed_fn(text: str) -> list[float]:
"""Embed text using ChromaDB function and convert to list of floats for Qdrant.
Args:
text: The input text to embed.
Returns:
A list of floats representing the embedding.
"""
embeddings = chromadb_func([text])
return embeddings[0] if embeddings and len(embeddings) > 0 else []
return cast(Any, qdrant_embed_fn)
return None
embedding_function = build_embedder(embedding_spec) if embedding_spec else None
return self._create_provider_config(
provider, provider_config, embedding_function
)
@staticmethod
def _create_provider_config(
provider: str, provider_config: dict, embedding_function: Any
provider: Literal["chromadb", "qdrant"],
provider_config: dict[str, Any],
embedding_function: EmbeddingFunction[Any] | None,
) -> Any:
"""Create proper provider config object."""
"""Instantiate provider config with optional embedding_function injected."""
if provider == "chromadb":
from crewai.rag.chromadb.config import ChromaDBConfig
config_kwargs = {}
if embedding_function:
config_kwargs["embedding_function"] = embedding_function
config_kwargs.update(provider_config)
return ChromaDBConfig(**config_kwargs)
kwargs = dict(provider_config)
if embedding_function is not None:
kwargs["embedding_function"] = embedding_function
return ChromaDBConfig(**kwargs)
if provider == "qdrant":
from crewai.rag.qdrant.config import QdrantConfig
config_kwargs = {}
if embedding_function:
config_kwargs["embedding_function"] = embedding_function
kwargs = dict(provider_config)
if embedding_function is not None:
kwargs["embedding_function"] = embedding_function
return QdrantConfig(**kwargs)
config_kwargs.update(provider_config)
return QdrantConfig(**config_kwargs)
return None
raise ValueError(f"Unhandled provider: {provider}")
def add(
self,

View File

@@ -0,0 +1,32 @@
"""Type definitions for RAG tool configuration."""
from typing import Any, Literal
from crewai.rag.embeddings.types import ProviderSpec
from typing_extensions import TypedDict
class VectorDbConfig(TypedDict):
"""Configuration for vector database provider.
Attributes:
provider: RAG provider literal.
config: RAG configuration options.
"""
provider: Literal["chromadb", "qdrant"]
config: dict[str, Any]
class RagToolConfig(TypedDict, total=False):
"""Configuration accepted by RAG tools.
Supports embedding model and vector database configuration.
Attributes:
embedding_model: Embedding model configuration accepted by RAG tools.
vectordb: Vector database configuration accepted by RAG tools.
"""
embedding_model: ProviderSpec
vectordb: VectorDbConfig

View File

@@ -1,4 +1,5 @@
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self
from crewai_tools.tools.rag.rag_tool import RagTool
@@ -24,14 +25,17 @@ class TXTSearchTool(RagTool):
"A tool that can be used to semantic search a query from a txt's content."
)
args_schema: type[BaseModel] = TXTSearchToolSchema
txt: str | None = None
def __init__(self, txt: str | None = None, **kwargs):
super().__init__(**kwargs)
if txt is not None:
self.add(txt)
self.description = f"A tool that can be used to semantic search a query the {txt} txt's content."
@model_validator(mode="after")
def _configure_for_txt(self) -> Self:
"""Configure tool for specific TXT file if provided."""
if self.txt is not None:
self.add(self.txt)
self.description = f"A tool that can be used to semantic search a query the {self.txt} txt's content."
self.args_schema = FixedTXTSearchToolSchema
self._generate_description()
return self
def _run( # type: ignore[override]
self,

View File

@@ -1,5 +1,3 @@
"""Tests for RAG tool with mocked embeddings and vector database."""
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import cast
@@ -117,15 +115,15 @@ def test_rag_tool_with_file(
assert "Python is a programming language" in result
@patch("crewai_tools.tools.rag.rag_tool.RagTool._create_embedding_function")
@patch("crewai_tools.tools.rag.rag_tool.build_embedder")
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_rag_tool_with_custom_embeddings(
mock_create_client: Mock, mock_create_embedding: Mock
mock_create_client: Mock, mock_build_embedder: Mock
) -> None:
"""Test RagTool with custom embeddings configuration to ensure no API calls."""
mock_embedding_func = MagicMock()
mock_embedding_func.return_value = [[0.2] * 1536]
mock_create_embedding.return_value = mock_embedding_func
mock_build_embedder.return_value = mock_embedding_func
mock_client = MagicMock()
mock_client.get_or_create_collection = MagicMock(return_value=None)
@@ -153,7 +151,7 @@ def test_rag_tool_with_custom_embeddings(
assert "Relevant Content:" in result
assert "Test content" in result
mock_create_embedding.assert_called()
mock_build_embedder.assert_called()
@patch("crewai_tools.adapters.crewai_rag_adapter.get_rag_client")
@@ -176,3 +174,128 @@ def test_rag_tool_no_results(
result = tool._run(query="Non-existent content")
assert "Relevant Content:" in result
assert "No relevant content found" in result
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_rag_tool_with_azure_config_without_env_vars(
mock_create_client: Mock,
) -> None:
"""Test that RagTool accepts Azure config without requiring env vars.
This test verifies the fix for the issue where RAG tools were ignoring
the embedding configuration passed via the config parameter and instead
requiring environment variables like EMBEDDINGS_OPENAI_API_KEY.
"""
mock_embedding_func = MagicMock()
mock_embedding_func.return_value = [[0.1] * 1536]
mock_client = MagicMock()
mock_client.get_or_create_collection = MagicMock(return_value=None)
mock_client.add_documents = MagicMock(return_value=None)
mock_create_client.return_value = mock_client
# Patch the embedding function builder to avoid actual API calls
with patch(
"crewai_tools.tools.rag.rag_tool.build_embedder",
return_value=mock_embedding_func,
):
class MyTool(RagTool):
pass
# Configuration with explicit Azure credentials - should work without env vars
config = {
"embedding_model": {
"provider": "azure",
"config": {
"model": "text-embedding-3-small",
"api_key": "test-api-key",
"api_base": "https://test.openai.azure.com/",
"api_version": "2024-02-01",
"api_type": "azure",
"deployment_id": "test-deployment",
},
}
}
# This should not raise a validation error about missing env vars
tool = MyTool(config=config)
assert tool.adapter is not None
assert isinstance(tool.adapter, CrewAIRagAdapter)
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_rag_tool_with_openai_config_without_env_vars(
mock_create_client: Mock,
) -> None:
"""Test that RagTool accepts OpenAI config without requiring env vars."""
mock_embedding_func = MagicMock()
mock_embedding_func.return_value = [[0.1] * 1536]
mock_client = MagicMock()
mock_client.get_or_create_collection = MagicMock(return_value=None)
mock_create_client.return_value = mock_client
with patch(
"crewai_tools.tools.rag.rag_tool.build_embedder",
return_value=mock_embedding_func,
):
class MyTool(RagTool):
pass
config = {
"embedding_model": {
"provider": "openai",
"config": {
"model": "text-embedding-3-small",
"api_key": "sk-test123",
},
}
}
tool = MyTool(config=config)
assert tool.adapter is not None
assert isinstance(tool.adapter, CrewAIRagAdapter)
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_rag_tool_config_with_qdrant_and_azure_embeddings(
mock_create_client: Mock,
) -> None:
"""Test RagTool with Qdrant vector DB and Azure embeddings config."""
mock_embedding_func = MagicMock()
mock_embedding_func.return_value = [[0.1] * 1536]
mock_client = MagicMock()
mock_client.get_or_create_collection = MagicMock(return_value=None)
mock_create_client.return_value = mock_client
with patch(
"crewai_tools.tools.rag.rag_tool.build_embedder",
return_value=mock_embedding_func,
):
class MyTool(RagTool):
pass
config = {
"vectordb": {"provider": "qdrant", "config": {}},
"embedding_model": {
"provider": "azure",
"config": {
"model": "text-embedding-3-large",
"api_key": "test-key",
"api_base": "https://test.openai.azure.com/",
"api_version": "2024-02-01",
"deployment_id": "test-deployment",
},
},
}
tool = MyTool(config=config)
assert tool.adapter is not None
assert isinstance(tool.adapter, CrewAIRagAdapter)

View File

@@ -0,0 +1,66 @@
"""Tests for improved RAG tool validation error messages."""
from unittest.mock import MagicMock, Mock, patch
import pytest
from pydantic import ValidationError
from crewai_tools.tools.rag.rag_tool import RagTool
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_azure_missing_deployment_id_gives_clear_error(mock_create_client: Mock) -> None:
"""Test that missing deployment_id for Azure gives a clear, focused error message."""
mock_client = MagicMock()
mock_client.get_or_create_collection = MagicMock(return_value=None)
mock_create_client.return_value = mock_client
class MyTool(RagTool):
pass
config = {
"embedding_model": {
"provider": "azure",
"config": {
"api_base": "http://localhost:4000/v1",
"api_key": "test-key",
"api_version": "2024-02-01",
},
}
}
with pytest.raises(ValueError) as exc_info:
MyTool(config=config)
error_msg = str(exc_info.value)
assert "azure" in error_msg.lower()
assert "deployment_id" in error_msg.lower()
assert "bedrock" not in error_msg.lower()
assert "cohere" not in error_msg.lower()
assert "huggingface" not in error_msg.lower()
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_valid_azure_config_works(mock_create_client: Mock) -> None:
"""Test that valid Azure config works without errors."""
mock_client = MagicMock()
mock_client.get_or_create_collection = MagicMock(return_value=None)
mock_create_client.return_value = mock_client
class MyTool(RagTool):
pass
config = {
"embedding_model": {
"provider": "azure",
"config": {
"api_base": "http://localhost:4000/v1",
"api_key": "test-key",
"api_version": "2024-02-01",
"deployment_id": "text-embedding-3-small",
},
}
}
tool = MyTool(config=config)
assert tool is not None

View File

@@ -0,0 +1,116 @@
from unittest.mock import MagicMock, Mock, patch
from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
from crewai_tools.tools.pdf_search_tool.pdf_search_tool import PDFSearchTool
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_pdf_search_tool_with_azure_config_without_env_vars(
mock_create_client: Mock,
) -> None:
"""Test PDFSearchTool accepts Azure config without requiring env vars.
This verifies the fix for the reported issue where PDFSearchTool would
throw a validation error:
pydantic_core._pydantic_core.ValidationError: 1 validation error for PDFSearchTool
EMBEDDINGS_OPENAI_API_KEY
Field required [type=missing, input_value={}, input_type=dict]
"""
mock_embedding_func = MagicMock()
mock_embedding_func.return_value = [[0.1] * 1536]
mock_client = MagicMock()
mock_client.get_or_create_collection = MagicMock(return_value=None)
mock_create_client.return_value = mock_client
# Patch the embedding function builder to avoid actual API calls
with patch(
"crewai_tools.tools.rag.rag_tool.build_embedder",
return_value=mock_embedding_func,
):
# This is the exact config format from the bug report
config = {
"embedding_model": {
"provider": "azure",
"config": {
"model": "text-embedding-3-small",
"api_key": "test-litellm-api-key",
"api_base": "https://test.litellm.proxy/",
"api_version": "2024-02-01",
"api_type": "azure",
"deployment_id": "test-deployment",
},
}
}
# This should not raise a validation error about missing env vars
tool = PDFSearchTool(config=config)
assert tool.adapter is not None
assert isinstance(tool.adapter, CrewAIRagAdapter)
assert tool.name == "Search a PDF's content"
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_pdf_search_tool_with_openai_config_without_env_vars(
mock_create_client: Mock,
) -> None:
"""Test PDFSearchTool accepts OpenAI config without requiring env vars."""
mock_embedding_func = MagicMock()
mock_embedding_func.return_value = [[0.1] * 1536]
mock_client = MagicMock()
mock_client.get_or_create_collection = MagicMock(return_value=None)
mock_create_client.return_value = mock_client
with patch(
"crewai_tools.tools.rag.rag_tool.build_embedder",
return_value=mock_embedding_func,
):
config = {
"embedding_model": {
"provider": "openai",
"config": {
"model": "text-embedding-3-small",
"api_key": "sk-test123",
},
}
}
tool = PDFSearchTool(config=config)
assert tool.adapter is not None
assert isinstance(tool.adapter, CrewAIRagAdapter)
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_pdf_search_tool_with_vectordb_and_embedding_config(
mock_create_client: Mock,
) -> None:
"""Test PDFSearchTool with both vector DB and embedding config."""
mock_embedding_func = MagicMock()
mock_embedding_func.return_value = [[0.1] * 1536]
mock_client = MagicMock()
mock_client.get_or_create_collection = MagicMock(return_value=None)
mock_create_client.return_value = mock_client
with patch(
"crewai_tools.tools.rag.rag_tool.build_embedder",
return_value=mock_embedding_func,
):
config = {
"vectordb": {"provider": "chromadb", "config": {}},
"embedding_model": {
"provider": "openai",
"config": {
"model": "text-embedding-3-large",
"api_key": "sk-test-key",
},
},
}
tool = PDFSearchTool(config=config)
assert tool.adapter is not None
assert isinstance(tool.adapter, CrewAIRagAdapter)

View File

@@ -0,0 +1,104 @@
from unittest.mock import MagicMock, Mock, patch
from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
from crewai_tools.tools.txt_search_tool.txt_search_tool import TXTSearchTool
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_txt_search_tool_with_azure_config_without_env_vars(
mock_create_client: Mock,
) -> None:
"""Test TXTSearchTool accepts Azure config without requiring env vars."""
mock_embedding_func = MagicMock()
mock_embedding_func.return_value = [[0.1] * 1536]
mock_client = MagicMock()
mock_client.get_or_create_collection = MagicMock(return_value=None)
mock_create_client.return_value = mock_client
with patch(
"crewai_tools.tools.rag.rag_tool.build_embedder",
return_value=mock_embedding_func,
):
config = {
"embedding_model": {
"provider": "azure",
"config": {
"model": "text-embedding-3-small",
"api_key": "test-api-key",
"api_base": "https://test.openai.azure.com/",
"api_version": "2024-02-01",
"api_type": "azure",
"deployment_id": "test-deployment",
},
}
}
# This should not raise a validation error about missing env vars
tool = TXTSearchTool(config=config)
assert tool.adapter is not None
assert isinstance(tool.adapter, CrewAIRagAdapter)
assert tool.name == "Search a txt's content"
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_txt_search_tool_with_openai_config_without_env_vars(
mock_create_client: Mock,
) -> None:
"""Test TXTSearchTool accepts OpenAI config without requiring env vars."""
mock_embedding_func = MagicMock()
mock_embedding_func.return_value = [[0.1] * 1536]
mock_client = MagicMock()
mock_client.get_or_create_collection = MagicMock(return_value=None)
mock_create_client.return_value = mock_client
with patch(
"crewai_tools.tools.rag.rag_tool.build_embedder",
return_value=mock_embedding_func,
):
config = {
"embedding_model": {
"provider": "openai",
"config": {
"model": "text-embedding-3-small",
"api_key": "sk-test123",
},
}
}
tool = TXTSearchTool(config=config)
assert tool.adapter is not None
assert isinstance(tool.adapter, CrewAIRagAdapter)
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_txt_search_tool_with_cohere_config(mock_create_client: Mock) -> None:
"""Test TXTSearchTool with Cohere embedding provider."""
mock_embedding_func = MagicMock()
mock_embedding_func.return_value = [[0.1] * 1024]
mock_client = MagicMock()
mock_client.get_or_create_collection = MagicMock(return_value=None)
mock_create_client.return_value = mock_client
with patch(
"crewai_tools.tools.rag.rag_tool.build_embedder",
return_value=mock_embedding_func,
):
config = {
"embedding_model": {
"provider": "cohere",
"config": {
"model": "embed-english-v3.0",
"api_key": "test-cohere-key",
},
}
}
tool = TXTSearchTool(config=config)
assert tool.adapter is not None
assert isinstance(tool.adapter, CrewAIRagAdapter)