fix: rag tool embeddings config

* fix: ensure config is not flattened, add tests

* chore: refactor inits to model_validator

* chore: refactor rag tool config parsing

* chore: add initial docs

* chore: add additional validation aliases for provider env vars

* chore: add solid docs

* chore: move imports to top

* fix: revert circular import

* fix: lazy import qdrant-client

* fix: allow collection name config

* chore: narrow model names for google

* chore: update additional docs

* chore: add backward compat on model name aliases

* chore: add tests for config changes
This commit is contained in:
Greyson LaLonde
2025-11-24 16:51:28 -05:00
committed by GitHub
parent 9c84475691
commit a928cde6ee
46 changed files with 1850 additions and 291 deletions

View File

@@ -388,8 +388,8 @@ crew = Crew(
agents=[sales_agent, tech_agent, support_agent],
tasks=[...],
embedder={ # Fallback embedder for agents without their own
"provider": "google",
"config": {"model": "text-embedding-004"}
"provider": "google-generativeai",
"config": {"model_name": "gemini-embedding-001"}
}
)
@@ -629,9 +629,9 @@ agent = Agent(
backstory="Expert researcher",
knowledge_sources=[knowledge_source],
embedder={
"provider": "google",
"provider": "google-generativeai",
"config": {
"model": "models/text-embedding-004",
"model_name": "gemini-embedding-001",
"api_key": "your-google-key"
}
}

View File

@@ -341,7 +341,7 @@ crew = Crew(
embedder={
"provider": "openai",
"config": {
"model": "text-embedding-3-small" # or "text-embedding-3-large"
"model_name": "text-embedding-3-small" # or "text-embedding-3-large"
}
}
)
@@ -353,7 +353,7 @@ crew = Crew(
"provider": "openai",
"config": {
"api_key": "your-openai-api-key", # Optional: override env var
"model": "text-embedding-3-large",
"model_name": "text-embedding-3-large",
"dimensions": 1536, # Optional: reduce dimensions for smaller storage
"organization_id": "your-org-id" # Optional: for organization accounts
}
@@ -375,7 +375,7 @@ crew = Crew(
"api_base": "https://your-resource.openai.azure.com/",
"api_type": "azure",
"api_version": "2023-05-15",
"model": "text-embedding-3-small",
"model_name": "text-embedding-3-small",
"deployment_id": "your-deployment-name" # Azure deployment name
}
}
@@ -390,10 +390,10 @@ Use Google's text embedding models for integration with Google Cloud services.
crew = Crew(
memory=True,
embedder={
"provider": "google",
"provider": "google-generativeai",
"config": {
"api_key": "your-google-api-key",
"model": "text-embedding-004" # or "text-embedding-preview-0409"
"model_name": "gemini-embedding-001" # or "text-embedding-005", "text-multilingual-embedding-002"
}
}
)
@@ -461,7 +461,7 @@ crew = Crew(
"provider": "cohere",
"config": {
"api_key": "your-cohere-api-key",
"model": "embed-english-v3.0" # or "embed-multilingual-v3.0"
"model_name": "embed-english-v3.0" # or "embed-multilingual-v3.0"
}
}
)
@@ -478,7 +478,7 @@ crew = Crew(
"provider": "voyageai",
"config": {
"api_key": "your-voyage-api-key",
"model": "voyage-large-2", # or "voyage-code-2" for code
"model": "voyage-3", # or "voyage-3-lite", "voyage-code-3"
"input_type": "document" # or "query"
}
}
@@ -912,10 +912,10 @@ crew = Crew(
crew = Crew(
memory=True,
embedder={
"provider": "google",
"provider": "google-generativeai",
"config": {
"api_key": "your-api-key",
"model": "text-embedding-004"
"model_name": "gemini-embedding-001"
}
}
)

View File

@@ -77,7 +77,7 @@ The `RagTool` accepts the following parameters:
- **summarize**: Optional. Whether to summarize the retrieved content. Default is `False`.
- **adapter**: Optional. A custom adapter for the knowledge base. If not provided, a CrewAIRagAdapter will be used.
- **config**: Optional. Configuration for the underlying CrewAI RAG system.
- **config**: Optional. Configuration for the underlying CrewAI RAG system. Accepts a `RagToolConfig` TypedDict with optional `embedding_model` (ProviderSpec) and `vectordb` (VectorDbConfig) keys. All configuration values provided programmatically take precedence over environment variables.
## Adding Content
@@ -127,26 +127,528 @@ You can customize the behavior of the `RagTool` by providing a configuration dic
```python Code
from crewai_tools import RagTool
from crewai_tools.tools.rag import RagToolConfig, VectorDbConfig, ProviderSpec
# Create a RAG tool with custom configuration
config = {
"vectordb": {
"provider": "qdrant",
"config": {
"collection_name": "my-collection"
}
},
"embedding_model": {
"provider": "openai",
"config": {
"model": "text-embedding-3-small"
}
vectordb: VectorDbConfig = {
"provider": "qdrant",
"config": {
"collection_name": "my-collection"
}
}
embedding_model: ProviderSpec = {
"provider": "openai",
"config": {
"model_name": "text-embedding-3-small"
}
}
config: RagToolConfig = {
"vectordb": vectordb,
"embedding_model": embedding_model
}
rag_tool = RagTool(config=config, summarize=True)
```
## Embedding Model Configuration
The `embedding_model` parameter accepts a `crewai.rag.embeddings.types.ProviderSpec` dictionary with the structure:
```python
{
"provider": "provider-name", # Required
"config": { # Optional
# Provider-specific configuration
}
}
```
### Supported Providers
<AccordionGroup>
<Accordion title="OpenAI">
```python main.py
from crewai.rag.embeddings.providers.openai.types import OpenAIProviderSpec
embedding_model: OpenAIProviderSpec = {
"provider": "openai",
"config": {
"api_key": "your-api-key",
"model_name": "text-embedding-ada-002",
"dimensions": 1536,
"organization_id": "your-org-id",
"api_base": "https://api.openai.com/v1",
"api_version": "v1",
"default_headers": {"Custom-Header": "value"}
}
}
```
**Config Options:**
- `api_key` (str): OpenAI API key
- `model_name` (str): Model to use. Default: `text-embedding-ada-002`. Options: `text-embedding-3-small`, `text-embedding-3-large`, `text-embedding-ada-002`
- `dimensions` (int): Number of dimensions for the embedding
- `organization_id` (str): OpenAI organization ID
- `api_base` (str): Custom API base URL
- `api_version` (str): API version
- `default_headers` (dict): Custom headers for API requests
**Environment Variables:**
- `OPENAI_API_KEY` or `EMBEDDINGS_OPENAI_API_KEY`: `api_key`
- `OPENAI_ORGANIZATION_ID` or `EMBEDDINGS_OPENAI_ORGANIZATION_ID`: `organization_id`
- `OPENAI_MODEL_NAME` or `EMBEDDINGS_OPENAI_MODEL_NAME`: `model_name`
- `OPENAI_API_BASE` or `EMBEDDINGS_OPENAI_API_BASE`: `api_base`
- `OPENAI_API_VERSION` or `EMBEDDINGS_OPENAI_API_VERSION`: `api_version`
- `OPENAI_DIMENSIONS` or `EMBEDDINGS_OPENAI_DIMENSIONS`: `dimensions`
</Accordion>
<Accordion title="Cohere">
```python main.py
from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec
embedding_model: CohereProviderSpec = {
"provider": "cohere",
"config": {
"api_key": "your-api-key",
"model_name": "embed-english-v3.0"
}
}
```
**Config Options:**
- `api_key` (str): Cohere API key
- `model_name` (str): Model to use. Default: `large`. Options: `embed-english-v3.0`, `embed-multilingual-v3.0`, `large`, `small`
**Environment Variables:**
- `COHERE_API_KEY` or `EMBEDDINGS_COHERE_API_KEY`: `api_key`
- `EMBEDDINGS_COHERE_MODEL_NAME`: `model_name`
</Accordion>
<Accordion title="VoyageAI">
```python main.py
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderSpec
embedding_model: VoyageAIProviderSpec = {
"provider": "voyageai",
"config": {
"api_key": "your-api-key",
"model": "voyage-3",
"input_type": "document",
"truncation": True,
"output_dtype": "float32",
"output_dimension": 1024,
"max_retries": 3,
"timeout": 60.0
}
}
```
**Config Options:**
- `api_key` (str): VoyageAI API key
- `model` (str): Model to use. Default: `voyage-2`. Options: `voyage-3`, `voyage-3-lite`, `voyage-code-3`, `voyage-large-2`
- `input_type` (str): Type of input. Options: `document` (for storage), `query` (for search)
- `truncation` (bool): Whether to truncate inputs that exceed max length. Default: `True`
- `output_dtype` (str): Output data type
- `output_dimension` (int): Dimension of output embeddings
- `max_retries` (int): Maximum number of retry attempts. Default: `0`
- `timeout` (float): Request timeout in seconds
**Environment Variables:**
- `VOYAGEAI_API_KEY` or `EMBEDDINGS_VOYAGEAI_API_KEY`: `api_key`
- `VOYAGEAI_MODEL` or `EMBEDDINGS_VOYAGEAI_MODEL`: `model`
- `VOYAGEAI_INPUT_TYPE` or `EMBEDDINGS_VOYAGEAI_INPUT_TYPE`: `input_type`
- `VOYAGEAI_TRUNCATION` or `EMBEDDINGS_VOYAGEAI_TRUNCATION`: `truncation`
- `VOYAGEAI_OUTPUT_DTYPE` or `EMBEDDINGS_VOYAGEAI_OUTPUT_DTYPE`: `output_dtype`
- `VOYAGEAI_OUTPUT_DIMENSION` or `EMBEDDINGS_VOYAGEAI_OUTPUT_DIMENSION`: `output_dimension`
- `VOYAGEAI_MAX_RETRIES` or `EMBEDDINGS_VOYAGEAI_MAX_RETRIES`: `max_retries`
- `VOYAGEAI_TIMEOUT` or `EMBEDDINGS_VOYAGEAI_TIMEOUT`: `timeout`
</Accordion>
<Accordion title="Ollama">
```python main.py
from crewai.rag.embeddings.providers.ollama.types import OllamaProviderSpec
embedding_model: OllamaProviderSpec = {
"provider": "ollama",
"config": {
"model_name": "llama2",
"url": "http://localhost:11434/api/embeddings"
}
}
```
**Config Options:**
- `model_name` (str): Ollama model name (e.g., `llama2`, `mistral`, `nomic-embed-text`)
- `url` (str): Ollama API endpoint URL. Default: `http://localhost:11434/api/embeddings`
**Environment Variables:**
- `OLLAMA_MODEL` or `EMBEDDINGS_OLLAMA_MODEL`: `model_name`
- `OLLAMA_URL` or `EMBEDDINGS_OLLAMA_URL`: `url`
</Accordion>
<Accordion title="Amazon Bedrock">
```python main.py
from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
embedding_model: BedrockProviderSpec = {
"provider": "amazon-bedrock",
"config": {
"model_name": "amazon.titan-embed-text-v2:0",
"session": boto3_session
}
}
```
**Config Options:**
- `model_name` (str): Bedrock model ID. Default: `amazon.titan-embed-text-v1`. Options: `amazon.titan-embed-text-v1`, `amazon.titan-embed-text-v2:0`, `cohere.embed-english-v3`, `cohere.embed-multilingual-v3`
- `session` (Any): Boto3 session object for AWS authentication
**Environment Variables:**
- `AWS_ACCESS_KEY_ID`: AWS access key
- `AWS_SECRET_ACCESS_KEY`: AWS secret key
- `AWS_REGION`: AWS region (e.g., `us-east-1`)
</Accordion>
<Accordion title="Azure OpenAI">
```python main.py
from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec
embedding_model: AzureProviderSpec = {
"provider": "azure",
"config": {
"deployment_id": "your-deployment-id",
"api_key": "your-api-key",
"api_base": "https://your-resource.openai.azure.com",
"api_version": "2024-02-01",
"model_name": "text-embedding-ada-002",
"api_type": "azure"
}
}
```
**Config Options:**
- `deployment_id` (str): **Required** - Azure OpenAI deployment ID
- `api_key` (str): Azure OpenAI API key
- `api_base` (str): Azure OpenAI resource endpoint
- `api_version` (str): API version. Example: `2024-02-01`
- `model_name` (str): Model name. Default: `text-embedding-ada-002`
- `api_type` (str): API type. Default: `azure`
- `dimensions` (int): Output dimensions
- `default_headers` (dict): Custom headers
**Environment Variables:**
- `AZURE_OPENAI_API_KEY` or `EMBEDDINGS_AZURE_API_KEY`: `api_key`
- `AZURE_OPENAI_ENDPOINT` or `EMBEDDINGS_AZURE_API_BASE`: `api_base`
- `EMBEDDINGS_AZURE_DEPLOYMENT_ID`: `deployment_id`
- `EMBEDDINGS_AZURE_API_VERSION`: `api_version`
- `EMBEDDINGS_AZURE_MODEL_NAME`: `model_name`
- `EMBEDDINGS_AZURE_API_TYPE`: `api_type`
- `EMBEDDINGS_AZURE_DIMENSIONS`: `dimensions`
</Accordion>
<Accordion title="Google Generative AI">
```python main.py
from crewai.rag.embeddings.providers.google.types import GenerativeAiProviderSpec
embedding_model: GenerativeAiProviderSpec = {
"provider": "google-generativeai",
"config": {
"api_key": "your-api-key",
"model_name": "gemini-embedding-001",
"task_type": "RETRIEVAL_DOCUMENT"
}
}
```
**Config Options:**
- `api_key` (str): Google AI API key
- `model_name` (str): Model name. Default: `gemini-embedding-001`. Options: `gemini-embedding-001`, `text-embedding-005`, `text-multilingual-embedding-002`
- `task_type` (str): Task type for embeddings. Default: `RETRIEVAL_DOCUMENT`. Options: `RETRIEVAL_DOCUMENT`, `RETRIEVAL_QUERY`
**Environment Variables:**
- `GOOGLE_API_KEY`, `GEMINI_API_KEY`, or `EMBEDDINGS_GOOGLE_API_KEY`: `api_key`
- `EMBEDDINGS_GOOGLE_GENERATIVE_AI_MODEL_NAME`: `model_name`
- `EMBEDDINGS_GOOGLE_GENERATIVE_AI_TASK_TYPE`: `task_type`
</Accordion>
<Accordion title="Google Vertex AI">
```python main.py
from crewai.rag.embeddings.providers.google.types import VertexAIProviderSpec
embedding_model: VertexAIProviderSpec = {
"provider": "google-vertex",
"config": {
"model_name": "text-embedding-004",
"project_id": "your-project-id",
"region": "us-central1",
"api_key": "your-api-key"
}
}
```
**Config Options:**
- `model_name` (str): Model name. Default: `textembedding-gecko`. Options: `text-embedding-004`, `textembedding-gecko`, `textembedding-gecko-multilingual`
- `project_id` (str): Google Cloud project ID. Default: `cloud-large-language-models`
- `region` (str): Google Cloud region. Default: `us-central1`
- `api_key` (str): API key for authentication
**Environment Variables:**
- `GOOGLE_APPLICATION_CREDENTIALS`: Path to service account JSON file
- `GOOGLE_CLOUD_PROJECT` or `EMBEDDINGS_GOOGLE_VERTEX_PROJECT_ID`: `project_id`
- `EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME`: `model_name`
- `EMBEDDINGS_GOOGLE_VERTEX_REGION`: `region`
- `EMBEDDINGS_GOOGLE_VERTEX_API_KEY`: `api_key`
</Accordion>
<Accordion title="Jina AI">
```python main.py
from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec
embedding_model: JinaProviderSpec = {
"provider": "jina",
"config": {
"api_key": "your-api-key",
"model_name": "jina-embeddings-v3"
}
}
```
**Config Options:**
- `api_key` (str): Jina AI API key
- `model_name` (str): Model name. Default: `jina-embeddings-v2-base-en`. Options: `jina-embeddings-v3`, `jina-embeddings-v2-base-en`, `jina-embeddings-v2-small-en`
**Environment Variables:**
- `JINA_API_KEY` or `EMBEDDINGS_JINA_API_KEY`: `api_key`
- `EMBEDDINGS_JINA_MODEL_NAME`: `model_name`
</Accordion>
<Accordion title="HuggingFace">
```python main.py
from crewai.rag.embeddings.providers.huggingface.types import HuggingFaceProviderSpec
embedding_model: HuggingFaceProviderSpec = {
"provider": "huggingface",
"config": {
"url": "https://api-inference.huggingface.co/models/sentence-transformers/all-MiniLM-L6-v2"
}
}
```
**Config Options:**
- `url` (str): Full URL to HuggingFace inference API endpoint
**Environment Variables:**
- `HUGGINGFACE_URL` or `EMBEDDINGS_HUGGINGFACE_URL`: `url`
</Accordion>
<Accordion title="Instructor">
```python main.py
from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec
embedding_model: InstructorProviderSpec = {
"provider": "instructor",
"config": {
"model_name": "hkunlp/instructor-xl",
"device": "cuda",
"instruction": "Represent the document"
}
}
```
**Config Options:**
- `model_name` (str): HuggingFace model ID. Default: `hkunlp/instructor-base`. Options: `hkunlp/instructor-xl`, `hkunlp/instructor-large`, `hkunlp/instructor-base`
- `device` (str): Device to run on. Default: `cpu`. Options: `cpu`, `cuda`, `mps`
- `instruction` (str): Instruction prefix for embeddings
**Environment Variables:**
- `EMBEDDINGS_INSTRUCTOR_MODEL_NAME`: `model_name`
- `EMBEDDINGS_INSTRUCTOR_DEVICE`: `device`
- `EMBEDDINGS_INSTRUCTOR_INSTRUCTION`: `instruction`
</Accordion>
<Accordion title="Sentence Transformer">
```python main.py
from crewai.rag.embeddings.providers.sentence_transformer.types import SentenceTransformerProviderSpec
embedding_model: SentenceTransformerProviderSpec = {
"provider": "sentence-transformer",
"config": {
"model_name": "all-mpnet-base-v2",
"device": "cuda",
"normalize_embeddings": True
}
}
```
**Config Options:**
- `model_name` (str): Sentence Transformers model name. Default: `all-MiniLM-L6-v2`. Options: `all-mpnet-base-v2`, `all-MiniLM-L6-v2`, `paraphrase-multilingual-MiniLM-L12-v2`
- `device` (str): Device to run on. Default: `cpu`. Options: `cpu`, `cuda`, `mps`
- `normalize_embeddings` (bool): Whether to normalize embeddings. Default: `False`
**Environment Variables:**
- `EMBEDDINGS_SENTENCE_TRANSFORMER_MODEL_NAME`: `model_name`
- `EMBEDDINGS_SENTENCE_TRANSFORMER_DEVICE`: `device`
- `EMBEDDINGS_SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS`: `normalize_embeddings`
</Accordion>
<Accordion title="ONNX">
```python main.py
from crewai.rag.embeddings.providers.onnx.types import ONNXProviderSpec
embedding_model: ONNXProviderSpec = {
"provider": "onnx",
"config": {
"preferred_providers": ["CUDAExecutionProvider", "CPUExecutionProvider"]
}
}
```
**Config Options:**
- `preferred_providers` (list[str]): List of ONNX execution providers in order of preference
**Environment Variables:**
- `EMBEDDINGS_ONNX_PREFERRED_PROVIDERS`: `preferred_providers` (comma-separated list)
</Accordion>
<Accordion title="OpenCLIP">
```python main.py
from crewai.rag.embeddings.providers.openclip.types import OpenCLIPProviderSpec
embedding_model: OpenCLIPProviderSpec = {
"provider": "openclip",
"config": {
"model_name": "ViT-B-32",
"checkpoint": "laion2b_s34b_b79k",
"device": "cuda"
}
}
```
**Config Options:**
- `model_name` (str): OpenCLIP model architecture. Default: `ViT-B-32`. Options: `ViT-B-32`, `ViT-B-16`, `ViT-L-14`
- `checkpoint` (str): Pretrained checkpoint name. Default: `laion2b_s34b_b79k`. Options: `laion2b_s34b_b79k`, `laion400m_e32`, `openai`
- `device` (str): Device to run on. Default: `cpu`. Options: `cpu`, `cuda`
**Environment Variables:**
- `EMBEDDINGS_OPENCLIP_MODEL_NAME`: `model_name`
- `EMBEDDINGS_OPENCLIP_CHECKPOINT`: `checkpoint`
- `EMBEDDINGS_OPENCLIP_DEVICE`: `device`
</Accordion>
<Accordion title="Text2Vec">
```python main.py
from crewai.rag.embeddings.providers.text2vec.types import Text2VecProviderSpec
embedding_model: Text2VecProviderSpec = {
"provider": "text2vec",
"config": {
"model_name": "shibing624/text2vec-base-multilingual"
}
}
```
**Config Options:**
- `model_name` (str): Text2Vec model name from HuggingFace. Default: `shibing624/text2vec-base-chinese`. Options: `shibing624/text2vec-base-multilingual`, `shibing624/text2vec-base-chinese`
**Environment Variables:**
- `EMBEDDINGS_TEXT2VEC_MODEL_NAME`: `model_name`
</Accordion>
<Accordion title="Roboflow">
```python main.py
from crewai.rag.embeddings.providers.roboflow.types import RoboflowProviderSpec
embedding_model: RoboflowProviderSpec = {
"provider": "roboflow",
"config": {
"api_key": "your-api-key",
"api_url": "https://infer.roboflow.com"
}
}
```
**Config Options:**
- `api_key` (str): Roboflow API key. Default: `""` (empty string)
- `api_url` (str): Roboflow inference API URL. Default: `https://infer.roboflow.com`
**Environment Variables:**
- `ROBOFLOW_API_KEY` or `EMBEDDINGS_ROBOFLOW_API_KEY`: `api_key`
- `ROBOFLOW_API_URL` or `EMBEDDINGS_ROBOFLOW_API_URL`: `api_url`
</Accordion>
<Accordion title="WatsonX (IBM)">
```python main.py
from crewai.rag.embeddings.providers.ibm.types import WatsonXProviderSpec
embedding_model: WatsonXProviderSpec = {
"provider": "watsonx",
"config": {
"model_id": "ibm/slate-125m-english-rtrvr",
"url": "https://us-south.ml.cloud.ibm.com",
"api_key": "your-api-key",
"project_id": "your-project-id",
"batch_size": 100,
"concurrency_limit": 10,
"persistent_connection": True
}
}
```
**Config Options:**
- `model_id` (str): WatsonX model identifier
- `url` (str): WatsonX API endpoint
- `api_key` (str): IBM Cloud API key
- `project_id` (str): WatsonX project ID
- `space_id` (str): WatsonX space ID (alternative to project_id)
- `batch_size` (int): Batch size for embeddings. Default: `100`
- `concurrency_limit` (int): Maximum concurrent requests. Default: `10`
- `persistent_connection` (bool): Use persistent connections. Default: `True`
- Plus 20+ additional authentication and configuration options
**Environment Variables:**
- `WATSONX_API_KEY` or `EMBEDDINGS_WATSONX_API_KEY`: `api_key`
- `WATSONX_URL` or `EMBEDDINGS_WATSONX_URL`: `url`
- `WATSONX_PROJECT_ID` or `EMBEDDINGS_WATSONX_PROJECT_ID`: `project_id`
- `EMBEDDINGS_WATSONX_MODEL_ID`: `model_id`
- `EMBEDDINGS_WATSONX_SPACE_ID`: `space_id`
- `EMBEDDINGS_WATSONX_BATCH_SIZE`: `batch_size`
- `EMBEDDINGS_WATSONX_CONCURRENCY_LIMIT`: `concurrency_limit`
- `EMBEDDINGS_WATSONX_PERSISTENT_CONNECTION`: `persistent_connection`
</Accordion>
<Accordion title="Custom">
```python main.py
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec
class MyEmbeddingFunction(EmbeddingFunction):
def __call__(self, input):
# Your custom embedding logic
return embeddings
embedding_model: CustomProviderSpec = {
"provider": "custom",
"config": {
"embedding_callable": MyEmbeddingFunction
}
}
```
**Config Options:**
- `embedding_callable` (type[EmbeddingFunction]): Custom embedding function class
**Note:** Custom embedding functions must implement the `EmbeddingFunction` protocol defined in `crewai.rag.core.base_embeddings_callable`. The `__call__` method should accept input data and return embeddings as a list of numpy arrays or a compatible format. The returned embeddings are automatically normalized and validated.
</Accordion>
</AccordionGroup>
### Notes
- All config fields are optional unless marked as **Required**
- API keys can typically be provided via environment variables instead of config
- Default values are shown where applicable
## Conclusion
The `RagTool` provides a powerful way to create and query knowledge bases from various data sources. By leveraging Retrieval-Augmented Generation, it enables agents to access and retrieve relevant information efficiently, enhancing their ability to provide accurate and contextually appropriate responses.

View File

@@ -58,10 +58,10 @@ tool = MySQLSearchTool(
),
),
embedder=dict(
provider="google",
provider="google-generativeai",
config=dict(
model="models/embedding-001",
task_type="retrieval_document",
model_name="gemini-embedding-001",
task_type="RETRIEVAL_DOCUMENT",
# title="Embeddings",
),
),

View File

@@ -71,10 +71,10 @@ tool = PGSearchTool(
),
),
embedder=dict(
provider="google", # or openai, ollama, ...
provider="google-generativeai", # or openai, ollama, ...
config=dict(
model="models/embedding-001",
task_type="retrieval_document",
model_name="gemini-embedding-001",
task_type="RETRIEVAL_DOCUMENT",
# title="Embeddings",
),
),

View File

@@ -64,10 +64,10 @@ tool = JSONSearchTool(
},
},
"embedding_model": {
"provider": "google", # or openai, ollama, ...
"provider": "google-generativeai", # or openai, ollama, ...
"config": {
"model": "models/embedding-001",
"task_type": "retrieval_document",
"model_name": "gemini-embedding-001",
"task_type": "RETRIEVAL_DOCUMENT",
# Further customization options can be added here.
},
},

View File

@@ -63,15 +63,15 @@ tool = PDFSearchTool(
"config": {
# Model identifier for the chosen provider. "model" will be auto-mapped to "model_name" internally.
"model": "text-embedding-3-small",
# Optional: API key. If omitted, the tool will use provider-specific env vars when available
# (e.g., OPENAI_API_KEY for provider="openai").
# Optional: API key. If omitted, the tool will use provider-specific env vars
# (e.g., OPENAI_API_KEY or EMBEDDINGS_OPENAI_API_KEY for OpenAI).
# "api_key": "sk-...",
# Provider-specific examples:
# --- Google Generative AI ---
# (Set provider="google-generativeai" above)
# "model": "models/embedding-001",
# "task_type": "retrieval_document",
# "model_name": "gemini-embedding-001",
# "task_type": "RETRIEVAL_DOCUMENT",
# "title": "Embeddings",
# --- Cohere ---

View File

@@ -66,9 +66,9 @@ tool = TXTSearchTool(
"provider": "openai", # or google-generativeai, cohere, ollama, ...
"config": {
"model": "text-embedding-3-small",
# "api_key": "sk-...", # optional if env var is set
# "api_key": "sk-...", # optional if env var is set (e.g., OPENAI_API_KEY or EMBEDDINGS_OPENAI_API_KEY)
# Provider examples:
# Google → model: "models/embedding-001", task_type: "retrieval_document"
# Google → model_name: "gemini-embedding-001", task_type: "RETRIEVAL_DOCUMENT"
# Cohere → model: "embed-english-v3.0"
# Ollama → model: "nomic-embed-text"
},

View File

@@ -73,10 +73,10 @@ tool = CodeDocsSearchTool(
),
),
embedder=dict(
provider="google", # or openai, ollama, ...
provider="google-generativeai", # or openai, ollama, ...
config=dict(
model="models/embedding-001",
task_type="retrieval_document",
model_name="gemini-embedding-001",
task_type="RETRIEVAL_DOCUMENT",
# title="Embeddings",
),
),

View File

@@ -75,10 +75,10 @@ tool = GithubSearchTool(
),
),
embedder=dict(
provider="google", # or openai, ollama, ...
provider="google-generativeai", # or openai, ollama, ...
config=dict(
model="models/embedding-001",
task_type="retrieval_document",
model_name="gemini-embedding-001",
task_type="RETRIEVAL_DOCUMENT",
# title="Embeddings",
),
),

View File

@@ -66,10 +66,10 @@ tool = WebsiteSearchTool(
),
),
embedder=dict(
provider="google", # or openai, ollama, ...
provider="google-generativeai", # or openai, ollama, ...
config=dict(
model="models/embedding-001",
task_type="retrieval_document",
model_name="gemini-embedding-001",
task_type="RETRIEVAL_DOCUMENT",
# title="Embeddings",
),
),

View File

@@ -106,10 +106,10 @@ youtube_channel_tool = YoutubeChannelSearchTool(
),
),
embedder=dict(
provider="google", # or openai, ollama, ...
provider="google-generativeai", # or openai, ollama, ...
config=dict(
model="models/embedding-001",
task_type="retrieval_document",
model_name="gemini-embedding-001",
task_type="RETRIEVAL_DOCUMENT",
# title="Embeddings",
),
),

View File

@@ -108,10 +108,10 @@ youtube_search_tool = YoutubeVideoSearchTool(
),
),
embedder=dict(
provider="google", # or openai, ollama, ...
provider="google-generativeai", # or openai, ollama, ...
config=dict(
model="models/embedding-001",
task_type="retrieval_document",
model_name="gemini-embedding-001",
task_type="RETRIEVAL_DOCUMENT",
# title="Embeddings",
),
),

View File

@@ -1,28 +1,51 @@
"""Adapter for CrewAI's native RAG system."""
from __future__ import annotations
import hashlib
from pathlib import Path
from typing import Any, TypeAlias, TypedDict
from typing import TYPE_CHECKING, Any, TypeAlias, TypedDict, cast
import uuid
from crewai.rag.config.types import RagConfigType
from crewai.rag.config.utils import get_rag_client
from crewai.rag.core.base_client import BaseClient
from crewai.rag.factory import create_client
from crewai.rag.qdrant.config import QdrantConfig
from crewai.rag.types import BaseRecord, SearchResult
from pydantic import PrivateAttr
from qdrant_client.models import VectorParams
from typing_extensions import Unpack
from pydantic.dataclasses import is_pydantic_dataclass
from typing_extensions import TypeIs, Unpack
from crewai_tools.rag.data_types import DataType
from crewai_tools.rag.misc import sanitize_metadata_for_chromadb
from crewai_tools.tools.rag.rag_tool import Adapter
if TYPE_CHECKING:
from crewai.rag.qdrant.config import QdrantConfig
ContentItem: TypeAlias = str | Path | dict[str, Any]
def _is_qdrant_config(config: Any) -> TypeIs[QdrantConfig]:
    """Check if config is a QdrantConfig using safe duck typing.

    The check relies on duck typing (pydantic dataclass + ``provider``
    attribute equal to ``"qdrant"``) rather than ``isinstance`` so that the
    Qdrant config class does not have to be referenced at runtime.

    Args:
        config: RAG configuration to check. May be a config instance or any
            other object (including ``None``).

    Returns:
        True if config is a pydantic dataclass whose ``provider`` is
        ``"qdrant"``, False otherwise.
    """
    # ``is_pydantic_dataclass`` is documented as taking a *class* and inspects
    # ``class_.__dict__``; handing it an instance looks at the instance dict
    # and never matches. Check the instance's type as well, and keep the
    # direct check so any input that was accepted before is still accepted.
    if not (is_pydantic_dataclass(config) or is_pydantic_dataclass(type(config))):
        return False

    try:
        return cast(bool, config.provider == "qdrant")  # type: ignore[attr-defined]
    except (AttributeError, ImportError):
        # No ``provider`` attribute (or a lazy attribute failed to import):
        # treat as "not a Qdrant config" instead of propagating.
        return False
class AddDocumentParams(TypedDict, total=False):
"""Parameters for adding documents to the RAG system."""
@@ -56,8 +79,9 @@ class CrewAIRagAdapter(Adapter):
else:
self._client = get_rag_client()
collection_params: dict[str, Any] = {"collection_name": self.collection_name}
if isinstance(self.config, QdrantConfig) and self.config.vectors_config:
if isinstance(self.config.vectors_config, VectorParams):
if self.config is not None and _is_qdrant_config(self.config):
if self.config.vectors_config is not None:
collection_params["vectors_config"] = self.config.vectors_config
self._client.get_or_create_collection(**collection_params)

View File

@@ -1,4 +1,5 @@
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self
from crewai_tools.rag.data_types import DataType
from crewai_tools.tools.rag.rag_tool import RagTool
@@ -24,14 +25,17 @@ class PDFSearchTool(RagTool):
"A tool that can be used to semantic search a query from a PDF's content."
)
args_schema: type[BaseModel] = PDFSearchToolSchema
pdf: str | None = None
def __init__(self, pdf: str | None = None, **kwargs):
super().__init__(**kwargs)
if pdf is not None:
self.add(pdf)
self.description = f"A tool that can be used to semantic search a query the {pdf} PDF's content."
@model_validator(mode="after")
def _configure_for_pdf(self) -> Self:
    """Bind the tool to ``self.pdf`` when one was supplied.

    When ``pdf`` is set, the file is added via ``self.add``, the description
    is rewritten to mention that PDF, the argument schema is switched to
    ``FixedPDFSearchToolSchema`` and the description is regenerated. When
    ``pdf`` is ``None`` the tool is left untouched.

    Returns:
        The validated tool instance, as required of an "after" validator.
    """
    if self.pdf is None:
        # No fixed PDF: keep the generic description and schema.
        return self

    self.add(self.pdf)
    self.description = f"A tool that can be used to semantic search a query the {self.pdf} PDF's content."
    self.args_schema = FixedPDFSearchToolSchema
    self._generate_description()
    return self
def add(self, pdf: str) -> None:
super().add(pdf, data_type=DataType.PDF_FILE)

View File

@@ -0,0 +1,10 @@
from crewai.rag.embeddings.types import ProviderSpec
from crewai_tools.tools.rag.types import RagToolConfig, VectorDbConfig
__all__ = [
"ProviderSpec",
"RagToolConfig",
"VectorDbConfig",
]

View File

@@ -1,10 +1,74 @@
from abc import ABC, abstractmethod
import os
from typing import Any, cast
from typing import Any, Literal, cast
from crewai.rag.embeddings.factory import get_embedding_function
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
from crewai.rag.embeddings.factory import build_embedder
from crewai.rag.embeddings.types import ProviderSpec
from crewai.tools import BaseTool
from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic import (
BaseModel,
ConfigDict,
Field,
TypeAdapter,
ValidationError,
field_validator,
model_validator,
)
from typing_extensions import Self
from crewai_tools.tools.rag.types import RagToolConfig, VectorDbConfig
def _validate_embedding_config(
    value: dict[str, Any] | ProviderSpec,
) -> dict[str, Any] | ProviderSpec:
    """Validate embedding config and provide clearer error messages for union validation.

    This pre-validator catches Pydantic ValidationErrors from the ProviderSpec union
    and provides a cleaner, more focused error message that only shows the relevant
    provider's validation errors instead of the errors of every union member.

    Args:
        value: The embedding configuration dictionary or validated ProviderSpec.

    Returns:
        A validated ProviderSpec instance, or the original value if it is not a
        dict or has no ``provider`` entry.

    Raises:
        ValueError: If the configuration is invalid for the specified provider
            and errors attributable to that provider could be isolated.
        ValidationError: The original pydantic error, when no provider-specific
            errors could be isolated.
    """
    if not isinstance(value, dict):
        return value

    provider = value.get("provider")
    if not provider:
        return value

    # Built outside the ``try`` so only the validation call itself is guarded.
    type_adapter: TypeAdapter[ProviderSpec] = TypeAdapter(ProviderSpec)
    try:
        return type_adapter.validate_python(value)
    except ValidationError as e:
        if not isinstance(provider, str):
            # A non-string provider cannot be mapped to a spec name; surface
            # the original validation error instead of failing on ``.lower()``.
            raise

        # Spec type names carry no separators (e.g. provider
        # "sentence-transformer" is described by SentenceTransformerProviderSpec),
        # so strip them before matching against the error locations.
        normalized = provider.lower().replace("-", "").replace("_", "")
        provider_key = f"{normalized}providerspec"
        provider_errors = [
            err for err in e.errors() if provider_key in str(err.get("loc", "")).lower()
        ]

        if provider_errors:
            error_msgs = []
            for err in provider_errors:
                loc_parts = err["loc"]
                # Drop the leading union-member tag so the path starts at the
                # user-visible field.
                if str(loc_parts[0]).lower() == provider_key:
                    loc_parts = loc_parts[1:]
                loc = ".".join(str(x) for x in loc_parts)
                error_msgs.append(f"  - {loc}: {err['msg']}")

            raise ValueError(
                f"Invalid configuration for embedding provider '{provider}':\n"
                + "\n".join(error_msgs)
            ) from e

        raise
class Adapter(BaseModel, ABC):
@@ -46,139 +110,100 @@ class RagTool(BaseTool):
summarize: bool = False
similarity_threshold: float = 0.6
limit: int = 5
collection_name: str = "rag_tool_collection"
adapter: Adapter = Field(default_factory=_AdapterPlaceholder)
config: Any | None = None
config: RagToolConfig = Field(
default_factory=RagToolConfig,
description="Configuration format accepted by RagTool.",
)
@field_validator("config", mode="before")
@classmethod
def _validate_config(cls, value: Any) -> Any:
    """Validate config with improved error messages for embedding providers.

    Args:
        value: Raw value supplied for ``config``. Anything that is not a dict
            is passed through for regular field validation.

    Returns:
        The value unchanged when there is no ``embedding_model`` entry,
        otherwise a shallow copy whose ``embedding_model`` has been run
        through ``_validate_embedding_config``.

    Raises:
        ValueError: Propagated from ``_validate_embedding_config`` when the
            embedding configuration is invalid.
    """
    if not isinstance(value, dict):
        return value

    embedding_model = value.get("embedding_model")
    if not embedding_model:
        return value

    # Build a new dict rather than writing into the caller's object, so a
    # config dict reused across tools is never modified behind the caller's back.
    return {**value, "embedding_model": _validate_embedding_config(embedding_model)}
@model_validator(mode="after")
def _ensure_adapter(self) -> Self:
    """Replace the placeholder adapter with a configured CrewAIRagAdapter.

    Runs after model validation. When ``adapter`` is still the
    ``_AdapterPlaceholder`` default, the tool's ``config`` is converted by
    ``_parse_config`` and a ``CrewAIRagAdapter`` is created from the tool's
    ``collection_name``, ``summarize``, ``similarity_threshold`` and
    ``limit`` fields. An adapter supplied by the caller is left untouched.

    Returns:
        The tool instance itself.
    """
    if isinstance(self.adapter, RagTool._AdapterPlaceholder):
        # Imported here rather than at module level.
        from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter

        provider_cfg = self._parse_config(self.config)
        self.adapter = CrewAIRagAdapter(
            collection_name=self.collection_name,
            summarize=self.summarize,
            similarity_threshold=self.similarity_threshold,
            limit=self.limit,
            config=provider_cfg,
        )
    return self
def _parse_config(self, config: RagToolConfig) -> Any:
    """Normalize the RagToolConfig into a provider-specific config object.

    Defaults to 'chromadb' with no extra provider config if none is supplied.

    Args:
        config: Tool configuration; its optional ``"vectordb"`` entry selects
            the vector database and its optional ``"embedding_model"`` entry
            describes the embedder.

    Returns:
        The object produced by ``_create_provider_config`` for the selected
        provider.

    Raises:
        ValueError: If the vector database provider is not supported, or if
            the embedding configuration is rejected by
            ``_validate_embedding_config``.
    """
    if not config:
        return self._create_provider_config("chromadb", {}, None)

    vectordb_cfg = cast(VectorDbConfig, config.get("vectordb", {}))
    provider: Literal["chromadb", "qdrant"] = vectordb_cfg.get(
        "provider", "chromadb"
    )
    provider_config: dict[str, Any] = vectordb_cfg.get("config", {})

    supported = ("chromadb", "qdrant")
    if provider not in supported:
        raise ValueError(
            f"Unsupported vector database provider: '{provider}'. "
            f"CrewAI RAG currently supports: {', '.join(supported)}."
        )

    embedding_spec: ProviderSpec | None = config.get("embedding_model")
    if embedding_spec:
        embedding_spec = cast(
            ProviderSpec, _validate_embedding_config(embedding_spec)
        )

    # No embedding function is injected when no embedding model is configured.
    embedding_function = build_embedder(embedding_spec) if embedding_spec else None
    return self._create_provider_config(
        provider, provider_config, embedding_function
    )
@staticmethod
def _create_provider_config(
    provider: Literal["chromadb", "qdrant"],
    provider_config: dict[str, Any],
    embedding_function: EmbeddingFunction[Any] | None,
) -> Any:
    """Instantiate provider config with optional embedding_function injected.

    Args:
        provider: Name of the vector database provider.
        provider_config: Keyword arguments forwarded to the provider's config
            class.
        embedding_function: Embedding function to inject, or None to leave
            the provider config's own setting in place.

    Returns:
        A ``ChromaDBConfig`` or ``QdrantConfig`` instance.

    Raises:
        ValueError: If ``provider`` is neither "chromadb" nor "qdrant".
    """
    # Copy so the caller's dict is never modified; an explicit embedding
    # function takes precedence over one present in provider_config.
    kwargs = dict(provider_config)
    if embedding_function is not None:
        kwargs["embedding_function"] = embedding_function

    # Each provider's config class is imported only when that provider is selected.
    if provider == "chromadb":
        from crewai.rag.chromadb.config import ChromaDBConfig

        return ChromaDBConfig(**kwargs)

    if provider == "qdrant":
        from crewai.rag.qdrant.config import QdrantConfig

        return QdrantConfig(**kwargs)

    raise ValueError(f"Unhandled provider: {provider}")
def add(
self,

View File

@@ -0,0 +1,32 @@
"""Type definitions for RAG tool configuration."""
from typing import Any, Literal
from crewai.rag.embeddings.types import ProviderSpec
from typing_extensions import TypedDict
class VectorDbConfig(TypedDict, total=False):
    """Configuration for vector database provider.

    Both keys are optional so that a partial mapping such as
    ``{"provider": "qdrant"}`` is accepted.

    Attributes:
        provider: RAG provider literal.
        config: RAG configuration options.
    """

    provider: Literal["chromadb", "qdrant"]
    config: dict[str, Any]
class RagToolConfig(TypedDict, total=False):
    """Top-level configuration mapping accepted by RAG tools.

    Every key is optional (``total=False``).

    Attributes:
        embedding_model: Embedding provider specification.
        vectordb: Vector database provider and its options.
    """

    embedding_model: ProviderSpec
    vectordb: VectorDbConfig

View File

@@ -1,4 +1,5 @@
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self
from crewai_tools.tools.rag.rag_tool import RagTool
@@ -24,14 +25,17 @@ class TXTSearchTool(RagTool):
"A tool that can be used to semantic search a query from a txt's content."
)
args_schema: type[BaseModel] = TXTSearchToolSchema
txt: str | None = None
def __init__(self, txt: str | None = None, **kwargs):
super().__init__(**kwargs)
if txt is not None:
self.add(txt)
self.description = f"A tool that can be used to semantic search a query the {txt} txt's content."
@model_validator(mode="after")
def _configure_for_txt(self) -> Self:
    """Bind the tool to a fixed TXT source when ``txt`` was supplied.

    With ``txt`` set, the source is added through ``add``, the description is
    rewritten to name it, the args schema is switched to
    ``FixedTXTSearchToolSchema`` and the generated description is refreshed.
    Without ``txt`` the tool is returned unchanged.

    Returns:
        The tool instance itself.
    """
    if self.txt is None:
        return self

    self.add(self.txt)
    self.description = f"A tool that can be used to semantic search a query the {self.txt} txt's content."
    self.args_schema = FixedTXTSearchToolSchema
    self._generate_description()
    return self
def _run( # type: ignore[override]
self,

View File

@@ -1,5 +1,3 @@
"""Tests for RAG tool with mocked embeddings and vector database."""
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import cast
@@ -117,15 +115,15 @@ def test_rag_tool_with_file(
assert "Python is a programming language" in result
@patch("crewai_tools.tools.rag.rag_tool.RagTool._create_embedding_function")
@patch("crewai_tools.tools.rag.rag_tool.build_embedder")
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_rag_tool_with_custom_embeddings(
mock_create_client: Mock, mock_create_embedding: Mock
mock_create_client: Mock, mock_build_embedder: Mock
) -> None:
"""Test RagTool with custom embeddings configuration to ensure no API calls."""
mock_embedding_func = MagicMock()
mock_embedding_func.return_value = [[0.2] * 1536]
mock_create_embedding.return_value = mock_embedding_func
mock_build_embedder.return_value = mock_embedding_func
mock_client = MagicMock()
mock_client.get_or_create_collection = MagicMock(return_value=None)
@@ -153,7 +151,7 @@ def test_rag_tool_with_custom_embeddings(
assert "Relevant Content:" in result
assert "Test content" in result
mock_create_embedding.assert_called()
mock_build_embedder.assert_called()
@patch("crewai_tools.adapters.crewai_rag_adapter.get_rag_client")
@@ -176,3 +174,128 @@ def test_rag_tool_no_results(
result = tool._run(query="Non-existent content")
assert "Relevant Content:" in result
assert "No relevant content found" in result
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_rag_tool_with_azure_config_without_env_vars(
    mock_create_client: Mock,
) -> None:
    """RagTool builds its adapter from an explicit Azure embedding config.

    Regression guard: the embedding configuration passed through ``config``
    used to be ignored, and environment variables such as
    EMBEDDINGS_OPENAI_API_KEY were required instead.
    """
    fake_embedder = MagicMock(return_value=[[0.1] * 1536])

    fake_client = MagicMock()
    fake_client.get_or_create_collection = MagicMock(return_value=None)
    fake_client.add_documents = MagicMock(return_value=None)
    mock_create_client.return_value = fake_client

    # Explicit Azure credentials: no environment variables should be needed.
    azure_config = {
        "embedding_model": {
            "provider": "azure",
            "config": {
                "model": "text-embedding-3-small",
                "api_key": "test-api-key",
                "api_base": "https://test.openai.azure.com/",
                "api_version": "2024-02-01",
                "api_type": "azure",
                "deployment_id": "test-deployment",
            },
        }
    }

    class MyTool(RagTool):
        pass

    # The embedder builder is stubbed to avoid actual API calls.
    with patch(
        "crewai_tools.tools.rag.rag_tool.build_embedder",
        return_value=fake_embedder,
    ):
        rag_tool = MyTool(config=azure_config)

    assert rag_tool.adapter is not None
    assert isinstance(rag_tool.adapter, CrewAIRagAdapter)
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_rag_tool_with_openai_config_without_env_vars(
    mock_create_client: Mock,
) -> None:
    """RagTool builds its adapter from an explicit OpenAI embedding config."""
    fake_embedder = MagicMock(return_value=[[0.1] * 1536])

    fake_client = MagicMock()
    fake_client.get_or_create_collection = MagicMock(return_value=None)
    mock_create_client.return_value = fake_client

    openai_config = {
        "embedding_model": {
            "provider": "openai",
            "config": {
                "model": "text-embedding-3-small",
                "api_key": "sk-test123",
            },
        }
    }

    class MyTool(RagTool):
        pass

    # The embedder builder is stubbed while the tool is constructed.
    with patch(
        "crewai_tools.tools.rag.rag_tool.build_embedder",
        return_value=fake_embedder,
    ):
        rag_tool = MyTool(config=openai_config)

    assert rag_tool.adapter is not None
    assert isinstance(rag_tool.adapter, CrewAIRagAdapter)
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_rag_tool_config_with_qdrant_and_azure_embeddings(
    mock_create_client: Mock,
) -> None:
    """RagTool accepts a Qdrant vector DB combined with Azure embeddings."""
    fake_embedder = MagicMock(return_value=[[0.1] * 1536])

    fake_client = MagicMock()
    fake_client.get_or_create_collection = MagicMock(return_value=None)
    mock_create_client.return_value = fake_client

    azure_embedding = {
        "provider": "azure",
        "config": {
            "model": "text-embedding-3-large",
            "api_key": "test-key",
            "api_base": "https://test.openai.azure.com/",
            "api_version": "2024-02-01",
            "deployment_id": "test-deployment",
        },
    }
    combined_config = {
        "vectordb": {"provider": "qdrant", "config": {}},
        "embedding_model": azure_embedding,
    }

    class MyTool(RagTool):
        pass

    # The embedder builder is stubbed while the tool is constructed.
    with patch(
        "crewai_tools.tools.rag.rag_tool.build_embedder",
        return_value=fake_embedder,
    ):
        rag_tool = MyTool(config=combined_config)

    assert rag_tool.adapter is not None
    assert isinstance(rag_tool.adapter, CrewAIRagAdapter)

View File

@@ -0,0 +1,66 @@
"""Tests for improved RAG tool validation error messages."""
from unittest.mock import MagicMock, Mock, patch
import pytest
from pydantic import ValidationError
from crewai_tools.tools.rag.rag_tool import RagTool
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_azure_missing_deployment_id_gives_clear_error(mock_create_client: Mock) -> None:
    """An Azure config without deployment_id fails with a focused message.

    The message must mention the Azure provider and the missing field, and
    must not mention the other providers checked below.
    """
    fake_client = MagicMock()
    fake_client.get_or_create_collection = MagicMock(return_value=None)
    mock_create_client.return_value = fake_client

    class MyTool(RagTool):
        pass

    # "deployment_id" is intentionally left out.
    incomplete_config = {
        "embedding_model": {
            "provider": "azure",
            "config": {
                "api_base": "http://localhost:4000/v1",
                "api_key": "test-key",
                "api_version": "2024-02-01",
            },
        }
    }

    with pytest.raises(ValueError) as exc_info:
        MyTool(config=incomplete_config)

    message = str(exc_info.value).lower()
    for expected in ("azure", "deployment_id"):
        assert expected in message
    for unrelated in ("bedrock", "cohere", "huggingface"):
        assert unrelated not in message
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_valid_azure_config_works(mock_create_client: Mock) -> None:
    """A complete Azure embedding config constructs the tool without errors."""
    fake_client = MagicMock()
    fake_client.get_or_create_collection = MagicMock(return_value=None)
    mock_create_client.return_value = fake_client

    class MyTool(RagTool):
        pass

    azure_settings = {
        "api_base": "http://localhost:4000/v1",
        "api_key": "test-key",
        "api_version": "2024-02-01",
        "deployment_id": "text-embedding-3-small",
    }
    complete_config = {
        "embedding_model": {"provider": "azure", "config": azure_settings}
    }

    rag_tool = MyTool(config=complete_config)

    assert rag_tool is not None

View File

@@ -0,0 +1,116 @@
from unittest.mock import MagicMock, Mock, patch
from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
from crewai_tools.tools.pdf_search_tool.pdf_search_tool import PDFSearchTool
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_pdf_search_tool_with_azure_config_without_env_vars(
    mock_create_client: Mock,
) -> None:
    """PDFSearchTool builds from an Azure config without env vars.

    Regression guard for the reported failure:
        pydantic_core._pydantic_core.ValidationError: 1 validation error for PDFSearchTool
        EMBEDDINGS_OPENAI_API_KEY
          Field required [type=missing, input_value={}, input_type=dict]
    """
    fake_embedder = MagicMock(return_value=[[0.1] * 1536])

    fake_client = MagicMock()
    fake_client.get_or_create_collection = MagicMock(return_value=None)
    mock_create_client.return_value = fake_client

    # Same config format as in the bug report.
    reported_config = {
        "embedding_model": {
            "provider": "azure",
            "config": {
                "model": "text-embedding-3-small",
                "api_key": "test-litellm-api-key",
                "api_base": "https://test.litellm.proxy/",
                "api_version": "2024-02-01",
                "api_type": "azure",
                "deployment_id": "test-deployment",
            },
        }
    }

    # The embedder builder is stubbed to avoid actual API calls.
    with patch(
        "crewai_tools.tools.rag.rag_tool.build_embedder",
        return_value=fake_embedder,
    ):
        pdf_tool = PDFSearchTool(config=reported_config)

    assert pdf_tool.adapter is not None
    assert isinstance(pdf_tool.adapter, CrewAIRagAdapter)
    assert pdf_tool.name == "Search a PDF's content"
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_pdf_search_tool_with_openai_config_without_env_vars(
    mock_create_client: Mock,
) -> None:
    """PDFSearchTool builds from an explicit OpenAI embedding config."""
    fake_embedder = MagicMock(return_value=[[0.1] * 1536])

    fake_client = MagicMock()
    fake_client.get_or_create_collection = MagicMock(return_value=None)
    mock_create_client.return_value = fake_client

    openai_config = {
        "embedding_model": {
            "provider": "openai",
            "config": {
                "model": "text-embedding-3-small",
                "api_key": "sk-test123",
            },
        }
    }

    # The embedder builder is stubbed while the tool is constructed.
    with patch(
        "crewai_tools.tools.rag.rag_tool.build_embedder",
        return_value=fake_embedder,
    ):
        pdf_tool = PDFSearchTool(config=openai_config)

    assert pdf_tool.adapter is not None
    assert isinstance(pdf_tool.adapter, CrewAIRagAdapter)
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_pdf_search_tool_with_vectordb_and_embedding_config(
    mock_create_client: Mock,
) -> None:
    """PDFSearchTool accepts vector DB and embedding settings together."""
    fake_embedder = MagicMock(return_value=[[0.1] * 1536])

    fake_client = MagicMock()
    fake_client.get_or_create_collection = MagicMock(return_value=None)
    mock_create_client.return_value = fake_client

    openai_embedding = {
        "provider": "openai",
        "config": {
            "model": "text-embedding-3-large",
            "api_key": "sk-test-key",
        },
    }
    combined_config = {
        "vectordb": {"provider": "chromadb", "config": {}},
        "embedding_model": openai_embedding,
    }

    # The embedder builder is stubbed while the tool is constructed.
    with patch(
        "crewai_tools.tools.rag.rag_tool.build_embedder",
        return_value=fake_embedder,
    ):
        pdf_tool = PDFSearchTool(config=combined_config)

    assert pdf_tool.adapter is not None
    assert isinstance(pdf_tool.adapter, CrewAIRagAdapter)

View File

@@ -0,0 +1,104 @@
from unittest.mock import MagicMock, Mock, patch
from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
from crewai_tools.tools.txt_search_tool.txt_search_tool import TXTSearchTool
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_txt_search_tool_with_azure_config_without_env_vars(
    mock_create_client: Mock,
) -> None:
    """TXTSearchTool builds from an explicit Azure config without env vars."""
    fake_embedder = MagicMock(return_value=[[0.1] * 1536])

    fake_client = MagicMock()
    fake_client.get_or_create_collection = MagicMock(return_value=None)
    mock_create_client.return_value = fake_client

    azure_config = {
        "embedding_model": {
            "provider": "azure",
            "config": {
                "model": "text-embedding-3-small",
                "api_key": "test-api-key",
                "api_base": "https://test.openai.azure.com/",
                "api_version": "2024-02-01",
                "api_type": "azure",
                "deployment_id": "test-deployment",
            },
        }
    }

    # Construction must not fail with a missing-env-var validation error.
    with patch(
        "crewai_tools.tools.rag.rag_tool.build_embedder",
        return_value=fake_embedder,
    ):
        txt_tool = TXTSearchTool(config=azure_config)

    assert txt_tool.adapter is not None
    assert isinstance(txt_tool.adapter, CrewAIRagAdapter)
    assert txt_tool.name == "Search a txt's content"
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_txt_search_tool_with_openai_config_without_env_vars(
    mock_create_client: Mock,
) -> None:
    """TXTSearchTool builds from an explicit OpenAI embedding config."""
    fake_embedder = MagicMock(return_value=[[0.1] * 1536])

    fake_client = MagicMock()
    fake_client.get_or_create_collection = MagicMock(return_value=None)
    mock_create_client.return_value = fake_client

    openai_config = {
        "embedding_model": {
            "provider": "openai",
            "config": {
                "model": "text-embedding-3-small",
                "api_key": "sk-test123",
            },
        }
    }

    # The embedder builder is stubbed while the tool is constructed.
    with patch(
        "crewai_tools.tools.rag.rag_tool.build_embedder",
        return_value=fake_embedder,
    ):
        txt_tool = TXTSearchTool(config=openai_config)

    assert txt_tool.adapter is not None
    assert isinstance(txt_tool.adapter, CrewAIRagAdapter)
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_txt_search_tool_with_cohere_config(mock_create_client: Mock) -> None:
    """TXTSearchTool builds its adapter with a Cohere embedding provider."""
    fake_embedder = MagicMock(return_value=[[0.1] * 1024])

    fake_client = MagicMock()
    fake_client.get_or_create_collection = MagicMock(return_value=None)
    mock_create_client.return_value = fake_client

    cohere_config = {
        "embedding_model": {
            "provider": "cohere",
            "config": {
                "model": "embed-english-v3.0",
                "api_key": "test-cohere-key",
            },
        }
    }

    # The embedder builder is stubbed while the tool is constructed.
    with patch(
        "crewai_tools.tools.rag.rag_tool.build_embedder",
        return_value=fake_embedder,
    ):
        txt_tool = TXTSearchTool(config=cohere_config)

    assert txt_tool.adapter is not None
    assert isinstance(txt_tool.adapter, CrewAIRagAdapter)

View File

@@ -91,6 +91,7 @@ PROVIDER_PATHS = {
"cohere": "crewai.rag.embeddings.providers.cohere.cohere_provider.CohereProvider",
"custom": "crewai.rag.embeddings.providers.custom.custom_provider.CustomProvider",
"google-generativeai": "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider",
"google": "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider",
"google-vertex": "crewai.rag.embeddings.providers.google.vertex.VertexAIProvider",
"huggingface": "crewai.rag.embeddings.providers.huggingface.huggingface_provider.HuggingFaceProvider",
"instructor": "crewai.rag.embeddings.providers.instructor.instructor_provider.InstructorProvider",

View File

@@ -5,7 +5,7 @@ from typing import Any
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
AmazonBedrockEmbeddingFunction,
)
from pydantic import Field
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
@@ -21,7 +21,7 @@ def create_aws_session() -> Any:
ValueError: If AWS session creation fails
"""
try:
import boto3 # type: ignore[import]
import boto3
return boto3.Session()
except ImportError as e:
@@ -46,7 +46,12 @@ class BedrockProvider(BaseEmbeddingsProvider[AmazonBedrockEmbeddingFunction]):
model_name: str = Field(
default="amazon.titan-embed-text-v1",
description="Model name to use for embeddings",
validation_alias="EMBEDDINGS_BEDROCK_MODEL_NAME",
validation_alias=AliasChoices(
"EMBEDDINGS_BEDROCK_MODEL_NAME",
"BEDROCK_MODEL_NAME",
"AWS_BEDROCK_MODEL_NAME",
"model",
),
)
session: Any = Field(
default_factory=create_aws_session, description="AWS session object"

View File

@@ -3,7 +3,7 @@
from chromadb.utils.embedding_functions.cohere_embedding_function import (
CohereEmbeddingFunction,
)
from pydantic import Field
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
@@ -15,10 +15,14 @@ class CohereProvider(BaseEmbeddingsProvider[CohereEmbeddingFunction]):
default=CohereEmbeddingFunction, description="Cohere embedding function class"
)
api_key: str = Field(
description="Cohere API key", validation_alias="EMBEDDINGS_COHERE_API_KEY"
description="Cohere API key",
validation_alias=AliasChoices("EMBEDDINGS_COHERE_API_KEY", "COHERE_API_KEY"),
)
model_name: str = Field(
default="large",
description="Model name to use for embeddings",
validation_alias="EMBEDDINGS_COHERE_MODEL_NAME",
validation_alias=AliasChoices(
"EMBEDDINGS_COHERE_MODEL_NAME",
"model",
),
)

View File

@@ -1,9 +1,11 @@
"""Google Generative AI embeddings provider."""
from typing import Literal
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleGenerativeAiEmbeddingFunction,
)
from pydantic import Field
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
@@ -15,16 +17,27 @@ class GenerativeAiProvider(BaseEmbeddingsProvider[GoogleGenerativeAiEmbeddingFun
default=GoogleGenerativeAiEmbeddingFunction,
description="Google Generative AI embedding function class",
)
model_name: str = Field(
default="models/embedding-001",
model_name: Literal[
"gemini-embedding-001", "text-embedding-005", "text-multilingual-embedding-002"
] = Field(
default="gemini-embedding-001",
description="Model name to use for embeddings",
validation_alias="EMBEDDINGS_GOOGLE_GENERATIVE_AI_MODEL_NAME",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_GENERATIVE_AI_MODEL_NAME", "model"
),
)
api_key: str = Field(
description="Google API key", validation_alias="EMBEDDINGS_GOOGLE_API_KEY"
description="Google API key",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_API_KEY", "GOOGLE_API_KEY", "GEMINI_API_KEY"
),
)
task_type: str = Field(
default="RETRIEVAL_DOCUMENT",
description="Task type for embeddings",
validation_alias="EMBEDDINGS_GOOGLE_GENERATIVE_AI_TASK_TYPE",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_GENERATIVE_AI_TASK_TYPE",
"GOOGLE_GENERATIVE_AI_TASK_TYPE",
"GEMINI_TASK_TYPE",
),
)

View File

@@ -6,10 +6,23 @@ from typing_extensions import Required, TypedDict
class GenerativeAiProviderConfig(TypedDict, total=False):
"""Configuration for Google Generative AI provider."""
"""Configuration for Google Generative AI provider.
Attributes:
api_key: Google API key for authentication.
model_name: Embedding model name.
task_type: Task type for embeddings. Default is "RETRIEVAL_DOCUMENT".
"""
api_key: str
model_name: Annotated[str, "models/embedding-001"]
model_name: Annotated[
Literal[
"gemini-embedding-001",
"text-embedding-005",
"text-multilingual-embedding-002",
],
"gemini-embedding-001",
]
task_type: Annotated[str, "RETRIEVAL_DOCUMENT"]

View File

@@ -3,7 +3,7 @@
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleVertexEmbeddingFunction,
)
from pydantic import Field
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
@@ -18,18 +18,29 @@ class VertexAIProvider(BaseEmbeddingsProvider[GoogleVertexEmbeddingFunction]):
model_name: str = Field(
default="textembedding-gecko",
description="Model name to use for embeddings",
validation_alias="EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME",
"GOOGLE_VERTEX_MODEL_NAME",
"model",
),
)
api_key: str = Field(
description="Google API key", validation_alias="EMBEDDINGS_GOOGLE_CLOUD_API_KEY"
description="Google API key",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_CLOUD_API_KEY", "GOOGLE_CLOUD_API_KEY"
),
)
project_id: str = Field(
default="cloud-large-language-models",
description="GCP project ID",
validation_alias="EMBEDDINGS_GOOGLE_CLOUD_PROJECT",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_CLOUD_PROJECT", "GOOGLE_CLOUD_PROJECT"
),
)
region: str = Field(
default="us-central1",
description="GCP region",
validation_alias="EMBEDDINGS_GOOGLE_CLOUD_REGION",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_CLOUD_REGION", "GOOGLE_CLOUD_REGION"
),
)

View File

@@ -3,7 +3,7 @@
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
HuggingFaceEmbeddingServer,
)
from pydantic import Field
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
@@ -16,5 +16,6 @@ class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingServer]):
description="HuggingFace embedding function class",
)
url: str = Field(
description="HuggingFace API URL", validation_alias="EMBEDDINGS_HUGGINGFACE_URL"
description="HuggingFace API URL",
validation_alias=AliasChoices("EMBEDDINGS_HUGGINGFACE_URL", "HUGGINGFACE_URL"),
)

View File

@@ -2,7 +2,7 @@
from typing import Any
from pydantic import Field, model_validator
from pydantic import AliasChoices, Field, model_validator
from typing_extensions import Self
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
@@ -21,7 +21,10 @@ class WatsonXProvider(BaseEmbeddingsProvider[WatsonXEmbeddingFunction]):
default=WatsonXEmbeddingFunction, description="WatsonX embedding function class"
)
model_id: str = Field(
description="WatsonX model ID", validation_alias="EMBEDDINGS_WATSONX_MODEL_ID"
description="WatsonX model ID",
validation_alias=AliasChoices(
"EMBEDDINGS_WATSONX_MODEL_ID", "WATSONX_MODEL_ID"
),
)
params: dict[str, str | dict[str, str]] | None = Field(
default=None, description="Additional parameters"
@@ -30,109 +33,143 @@ class WatsonXProvider(BaseEmbeddingsProvider[WatsonXEmbeddingFunction]):
project_id: str | None = Field(
default=None,
description="WatsonX project ID",
validation_alias="EMBEDDINGS_WATSONX_PROJECT_ID",
validation_alias=AliasChoices(
"EMBEDDINGS_WATSONX_PROJECT_ID", "WATSONX_PROJECT_ID"
),
)
space_id: str | None = Field(
default=None,
description="WatsonX space ID",
validation_alias="EMBEDDINGS_WATSONX_SPACE_ID",
validation_alias=AliasChoices(
"EMBEDDINGS_WATSONX_SPACE_ID", "WATSONX_SPACE_ID"
),
)
api_client: Any | None = Field(default=None, description="WatsonX API client")
verify: bool | str | None = Field(
default=None,
description="SSL verification",
validation_alias="EMBEDDINGS_WATSONX_VERIFY",
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_VERIFY", "WATSONX_VERIFY"),
)
persistent_connection: bool = Field(
default=True,
description="Use persistent connection",
validation_alias="EMBEDDINGS_WATSONX_PERSISTENT_CONNECTION",
validation_alias=AliasChoices(
"EMBEDDINGS_WATSONX_PERSISTENT_CONNECTION", "WATSONX_PERSISTENT_CONNECTION"
),
)
batch_size: int = Field(
default=100,
description="Batch size for processing",
validation_alias="EMBEDDINGS_WATSONX_BATCH_SIZE",
validation_alias=AliasChoices(
"EMBEDDINGS_WATSONX_BATCH_SIZE", "WATSONX_BATCH_SIZE"
),
)
concurrency_limit: int = Field(
default=10,
description="Concurrency limit",
validation_alias="EMBEDDINGS_WATSONX_CONCURRENCY_LIMIT",
validation_alias=AliasChoices(
"EMBEDDINGS_WATSONX_CONCURRENCY_LIMIT", "WATSONX_CONCURRENCY_LIMIT"
),
)
max_retries: int | None = Field(
default=None,
description="Maximum retries",
validation_alias="EMBEDDINGS_WATSONX_MAX_RETRIES",
validation_alias=AliasChoices(
"EMBEDDINGS_WATSONX_MAX_RETRIES", "WATSONX_MAX_RETRIES"
),
)
delay_time: float | None = Field(
default=None,
description="Delay time between retries",
validation_alias="EMBEDDINGS_WATSONX_DELAY_TIME",
validation_alias=AliasChoices(
"EMBEDDINGS_WATSONX_DELAY_TIME", "WATSONX_DELAY_TIME"
),
)
retry_status_codes: list[int] | None = Field(
default=None, description="HTTP status codes to retry on"
)
url: str = Field(
description="WatsonX API URL", validation_alias="EMBEDDINGS_WATSONX_URL"
description="WatsonX API URL",
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_URL", "WATSONX_URL"),
)
api_key: str = Field(
description="WatsonX API key", validation_alias="EMBEDDINGS_WATSONX_API_KEY"
description="WatsonX API key",
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_API_KEY", "WATSONX_API_KEY"),
)
name: str | None = Field(
default=None,
description="Service name",
validation_alias="EMBEDDINGS_WATSONX_NAME",
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_NAME", "WATSONX_NAME"),
)
iam_serviceid_crn: str | None = Field(
default=None,
description="IAM service ID CRN",
validation_alias="EMBEDDINGS_WATSONX_IAM_SERVICEID_CRN",
validation_alias=AliasChoices(
"EMBEDDINGS_WATSONX_IAM_SERVICEID_CRN", "WATSONX_IAM_SERVICEID_CRN"
),
)
trusted_profile_id: str | None = Field(
default=None,
description="Trusted profile ID",
validation_alias="EMBEDDINGS_WATSONX_TRUSTED_PROFILE_ID",
validation_alias=AliasChoices(
"EMBEDDINGS_WATSONX_TRUSTED_PROFILE_ID", "WATSONX_TRUSTED_PROFILE_ID"
),
)
token: str | None = Field(
default=None,
description="Bearer token",
validation_alias="EMBEDDINGS_WATSONX_TOKEN",
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_TOKEN", "WATSONX_TOKEN"),
)
projects_token: str | None = Field(
default=None,
description="Projects token",
validation_alias="EMBEDDINGS_WATSONX_PROJECTS_TOKEN",
validation_alias=AliasChoices(
"EMBEDDINGS_WATSONX_PROJECTS_TOKEN", "WATSONX_PROJECTS_TOKEN"
),
)
username: str | None = Field(
default=None,
description="Username",
validation_alias="EMBEDDINGS_WATSONX_USERNAME",
validation_alias=AliasChoices(
"EMBEDDINGS_WATSONX_USERNAME", "WATSONX_USERNAME"
),
)
password: str | None = Field(
default=None,
description="Password",
validation_alias="EMBEDDINGS_WATSONX_PASSWORD",
validation_alias=AliasChoices(
"EMBEDDINGS_WATSONX_PASSWORD", "WATSONX_PASSWORD"
),
)
instance_id: str | None = Field(
default=None,
description="Service instance ID",
validation_alias="EMBEDDINGS_WATSONX_INSTANCE_ID",
validation_alias=AliasChoices(
"EMBEDDINGS_WATSONX_INSTANCE_ID", "WATSONX_INSTANCE_ID"
),
)
version: str | None = Field(
default=None,
description="API version",
validation_alias="EMBEDDINGS_WATSONX_VERSION",
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_VERSION", "WATSONX_VERSION"),
)
bedrock_url: str | None = Field(
default=None,
description="Bedrock URL",
validation_alias="EMBEDDINGS_WATSONX_BEDROCK_URL",
validation_alias=AliasChoices(
"EMBEDDINGS_WATSONX_BEDROCK_URL", "WATSONX_BEDROCK_URL"
),
)
platform_url: str | None = Field(
default=None,
description="Platform URL",
validation_alias="EMBEDDINGS_WATSONX_PLATFORM_URL",
validation_alias=AliasChoices(
"EMBEDDINGS_WATSONX_PLATFORM_URL", "WATSONX_PLATFORM_URL"
),
)
proxies: dict[str, Any] | None = Field(
default=None, description="Proxy configuration"
)
proxies: dict | None = Field(default=None, description="Proxy configuration")
@model_validator(mode="after")
def validate_space_or_project(self) -> Self:

View File

@@ -3,7 +3,7 @@
from chromadb.utils.embedding_functions.instructor_embedding_function import (
InstructorEmbeddingFunction,
)
from pydantic import Field
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
@@ -18,15 +18,23 @@ class InstructorProvider(BaseEmbeddingsProvider[InstructorEmbeddingFunction]):
model_name: str = Field(
default="hkunlp/instructor-base",
description="Model name to use",
validation_alias="EMBEDDINGS_INSTRUCTOR_MODEL_NAME",
validation_alias=AliasChoices(
"EMBEDDINGS_INSTRUCTOR_MODEL_NAME",
"INSTRUCTOR_MODEL_NAME",
"model",
),
)
device: str = Field(
default="cpu",
description="Device to run model on (cpu or cuda)",
validation_alias="EMBEDDINGS_INSTRUCTOR_DEVICE",
validation_alias=AliasChoices(
"EMBEDDINGS_INSTRUCTOR_DEVICE", "INSTRUCTOR_DEVICE"
),
)
instruction: str | None = Field(
default=None,
description="Instruction for embeddings",
validation_alias="EMBEDDINGS_INSTRUCTOR_INSTRUCTION",
validation_alias=AliasChoices(
"EMBEDDINGS_INSTRUCTOR_INSTRUCTION", "INSTRUCTOR_INSTRUCTION"
),
)

View File

@@ -3,7 +3,7 @@
from chromadb.utils.embedding_functions.jina_embedding_function import (
JinaEmbeddingFunction,
)
from pydantic import Field
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
@@ -15,10 +15,15 @@ class JinaProvider(BaseEmbeddingsProvider[JinaEmbeddingFunction]):
default=JinaEmbeddingFunction, description="Jina embedding function class"
)
api_key: str = Field(
description="Jina API key", validation_alias="EMBEDDINGS_JINA_API_KEY"
description="Jina API key",
validation_alias=AliasChoices("EMBEDDINGS_JINA_API_KEY", "JINA_API_KEY"),
)
model_name: str = Field(
default="jina-embeddings-v2-base-en",
description="Model name to use for embeddings",
validation_alias="EMBEDDINGS_JINA_MODEL_NAME",
validation_alias=AliasChoices(
"EMBEDDINGS_JINA_MODEL_NAME",
"JINA_MODEL_NAME",
"model",
),
)

View File

@@ -5,7 +5,7 @@ from typing import Any
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
from pydantic import Field
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
@@ -18,27 +18,39 @@ class AzureProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
description="Azure OpenAI embedding function class",
)
api_key: str = Field(
description="Azure API key", validation_alias="EMBEDDINGS_OPENAI_API_KEY"
description="Azure API key",
validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_KEY", "OPENAI_API_KEY"),
)
api_base: str | None = Field(
default=None,
description="Azure endpoint URL",
validation_alias="EMBEDDINGS_OPENAI_API_BASE",
validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_BASE", "OPENAI_API_BASE"),
)
api_type: str = Field(
default="azure",
description="API type for Azure",
validation_alias="EMBEDDINGS_OPENAI_API_TYPE",
validation_alias=AliasChoices(
"EMBEDDINGS_OPENAI_API_TYPE", "OPENAI_API_TYPE", "AZURE_OPENAI_API_TYPE"
),
)
api_version: str | None = Field(
default=None,
default="2024-02-01",
description="Azure API version",
validation_alias="EMBEDDINGS_OPENAI_API_VERSION",
validation_alias=AliasChoices(
"EMBEDDINGS_OPENAI_API_VERSION",
"OPENAI_API_VERSION",
"AZURE_OPENAI_API_VERSION",
),
)
model_name: str = Field(
default="text-embedding-ada-002",
description="Model name to use for embeddings",
validation_alias="EMBEDDINGS_OPENAI_MODEL_NAME",
validation_alias=AliasChoices(
"EMBEDDINGS_OPENAI_MODEL_NAME",
"OPENAI_MODEL_NAME",
"AZURE_OPENAI_MODEL_NAME",
"model",
),
)
default_headers: dict[str, Any] | None = Field(
default=None, description="Default headers for API requests"
@@ -46,15 +58,26 @@ class AzureProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
dimensions: int | None = Field(
default=None,
description="Embedding dimensions",
validation_alias="EMBEDDINGS_OPENAI_DIMENSIONS",
validation_alias=AliasChoices(
"EMBEDDINGS_OPENAI_DIMENSIONS",
"OPENAI_DIMENSIONS",
"AZURE_OPENAI_DIMENSIONS",
),
)
deployment_id: str | None = Field(
default=None,
deployment_id: str = Field(
description="Azure deployment ID",
validation_alias="EMBEDDINGS_OPENAI_DEPLOYMENT_ID",
validation_alias=AliasChoices(
"EMBEDDINGS_OPENAI_DEPLOYMENT_ID",
"AZURE_OPENAI_DEPLOYMENT",
"AZURE_DEPLOYMENT_ID",
),
)
organization_id: str | None = Field(
default=None,
description="Organization ID",
validation_alias="EMBEDDINGS_OPENAI_ORGANIZATION_ID",
validation_alias=AliasChoices(
"EMBEDDINGS_OPENAI_ORGANIZATION_ID",
"OPENAI_ORGANIZATION_ID",
"AZURE_OPENAI_ORGANIZATION_ID",
),
)

View File

@@ -15,7 +15,7 @@ class AzureProviderConfig(TypedDict, total=False):
model_name: Annotated[str, "text-embedding-ada-002"]
default_headers: dict[str, Any]
dimensions: int
deployment_id: str
deployment_id: Required[str]
organization_id: str

View File

@@ -3,7 +3,7 @@
from chromadb.utils.embedding_functions.ollama_embedding_function import (
OllamaEmbeddingFunction,
)
from pydantic import Field
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
@@ -17,9 +17,14 @@ class OllamaProvider(BaseEmbeddingsProvider[OllamaEmbeddingFunction]):
url: str = Field(
default="http://localhost:11434/api/embeddings",
description="Ollama API endpoint URL",
validation_alias="EMBEDDINGS_OLLAMA_URL",
validation_alias=AliasChoices("EMBEDDINGS_OLLAMA_URL", "OLLAMA_URL"),
)
model_name: str = Field(
description="Model name to use for embeddings",
validation_alias="EMBEDDINGS_OLLAMA_MODEL_NAME",
validation_alias=AliasChoices(
"EMBEDDINGS_OLLAMA_MODEL_NAME",
"OLLAMA_MODEL_NAME",
"OLLAMA_MODEL",
"model",
),
)

View File

@@ -1,7 +1,7 @@
"""ONNX embeddings provider."""
from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2
from pydantic import Field
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
@@ -15,5 +15,7 @@ class ONNXProvider(BaseEmbeddingsProvider[ONNXMiniLM_L6_V2]):
preferred_providers: list[str] | None = Field(
default=None,
description="Preferred ONNX execution providers",
validation_alias="EMBEDDINGS_ONNX_PREFERRED_PROVIDERS",
validation_alias=AliasChoices(
"EMBEDDINGS_ONNX_PREFERRED_PROVIDERS", "ONNX_PREFERRED_PROVIDERS"
),
)

View File

@@ -5,7 +5,7 @@ from typing import Any
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
from pydantic import Field
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
@@ -20,27 +20,33 @@ class OpenAIProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
api_key: str | None = Field(
default=None,
description="OpenAI API key",
validation_alias="EMBEDDINGS_OPENAI_API_KEY",
validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_KEY", "OPENAI_API_KEY"),
)
model_name: str = Field(
default="text-embedding-ada-002",
description="Model name to use for embeddings",
validation_alias="EMBEDDINGS_OPENAI_MODEL_NAME",
validation_alias=AliasChoices(
"EMBEDDINGS_OPENAI_MODEL_NAME",
"OPENAI_MODEL_NAME",
"model",
),
)
api_base: str | None = Field(
default=None,
description="Base URL for API requests",
validation_alias="EMBEDDINGS_OPENAI_API_BASE",
validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_BASE", "OPENAI_API_BASE"),
)
api_type: str | None = Field(
default=None,
description="API type (e.g., 'azure')",
validation_alias="EMBEDDINGS_OPENAI_API_TYPE",
validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_TYPE", "OPENAI_API_TYPE"),
)
api_version: str | None = Field(
default=None,
description="API version",
validation_alias="EMBEDDINGS_OPENAI_API_VERSION",
validation_alias=AliasChoices(
"EMBEDDINGS_OPENAI_API_VERSION", "OPENAI_API_VERSION"
),
)
default_headers: dict[str, Any] | None = Field(
default=None, description="Default headers for API requests"
@@ -48,15 +54,21 @@ class OpenAIProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
dimensions: int | None = Field(
default=None,
description="Embedding dimensions",
validation_alias="EMBEDDINGS_OPENAI_DIMENSIONS",
validation_alias=AliasChoices(
"EMBEDDINGS_OPENAI_DIMENSIONS", "OPENAI_DIMENSIONS"
),
)
deployment_id: str | None = Field(
default=None,
description="Azure deployment ID",
validation_alias="EMBEDDINGS_OPENAI_DEPLOYMENT_ID",
validation_alias=AliasChoices(
"EMBEDDINGS_OPENAI_DEPLOYMENT_ID", "OPENAI_DEPLOYMENT_ID"
),
)
organization_id: str | None = Field(
default=None,
description="OpenAI organization ID",
validation_alias="EMBEDDINGS_OPENAI_ORGANIZATION_ID",
validation_alias=AliasChoices(
"EMBEDDINGS_OPENAI_ORGANIZATION_ID", "OPENAI_ORGANIZATION_ID"
),
)

View File

@@ -3,7 +3,7 @@
from chromadb.utils.embedding_functions.open_clip_embedding_function import (
OpenCLIPEmbeddingFunction,
)
from pydantic import Field
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
@@ -18,15 +18,21 @@ class OpenCLIPProvider(BaseEmbeddingsProvider[OpenCLIPEmbeddingFunction]):
model_name: str = Field(
default="ViT-B-32",
description="Model name to use",
validation_alias="EMBEDDINGS_OPENCLIP_MODEL_NAME",
validation_alias=AliasChoices(
"EMBEDDINGS_OPENCLIP_MODEL_NAME",
"OPENCLIP_MODEL_NAME",
"model",
),
)
checkpoint: str = Field(
default="laion2b_s34b_b79k",
description="Model checkpoint",
validation_alias="EMBEDDINGS_OPENCLIP_CHECKPOINT",
validation_alias=AliasChoices(
"EMBEDDINGS_OPENCLIP_CHECKPOINT", "OPENCLIP_CHECKPOINT"
),
)
device: str | None = Field(
default="cpu",
description="Device to run model on",
validation_alias="EMBEDDINGS_OPENCLIP_DEVICE",
validation_alias=AliasChoices("EMBEDDINGS_OPENCLIP_DEVICE", "OPENCLIP_DEVICE"),
)

View File

@@ -3,7 +3,7 @@
from chromadb.utils.embedding_functions.roboflow_embedding_function import (
RoboflowEmbeddingFunction,
)
from pydantic import Field
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
@@ -18,10 +18,14 @@ class RoboflowProvider(BaseEmbeddingsProvider[RoboflowEmbeddingFunction]):
api_key: str = Field(
default="",
description="Roboflow API key",
validation_alias="EMBEDDINGS_ROBOFLOW_API_KEY",
validation_alias=AliasChoices(
"EMBEDDINGS_ROBOFLOW_API_KEY", "ROBOFLOW_API_KEY"
),
)
api_url: str = Field(
default="https://infer.roboflow.com",
description="Roboflow API URL",
validation_alias="EMBEDDINGS_ROBOFLOW_API_URL",
validation_alias=AliasChoices(
"EMBEDDINGS_ROBOFLOW_API_URL", "ROBOFLOW_API_URL"
),
)

View File

@@ -3,7 +3,7 @@
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
SentenceTransformerEmbeddingFunction,
)
from pydantic import Field
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
@@ -20,15 +20,24 @@ class SentenceTransformerProvider(
model_name: str = Field(
default="all-MiniLM-L6-v2",
description="Model name to use",
validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_MODEL_NAME",
validation_alias=AliasChoices(
"EMBEDDINGS_SENTENCE_TRANSFORMER_MODEL_NAME",
"SENTENCE_TRANSFORMER_MODEL_NAME",
"model",
),
)
device: str = Field(
default="cpu",
description="Device to run model on (cpu or cuda)",
validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_DEVICE",
validation_alias=AliasChoices(
"EMBEDDINGS_SENTENCE_TRANSFORMER_DEVICE", "SENTENCE_TRANSFORMER_DEVICE"
),
)
normalize_embeddings: bool = Field(
default=False,
description="Whether to normalize embeddings",
validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS",
validation_alias=AliasChoices(
"EMBEDDINGS_SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS",
"SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS",
),
)

View File

@@ -3,7 +3,7 @@
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
Text2VecEmbeddingFunction,
)
from pydantic import Field
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
@@ -18,5 +18,9 @@ class Text2VecProvider(BaseEmbeddingsProvider[Text2VecEmbeddingFunction]):
model_name: str = Field(
default="shibing624/text2vec-base-chinese",
description="Model name to use",
validation_alias="EMBEDDINGS_TEXT2VEC_MODEL_NAME",
validation_alias=AliasChoices(
"EMBEDDINGS_TEXT2VEC_MODEL_NAME",
"TEXT2VEC_MODEL_NAME",
"model",
),
)

View File

@@ -1,6 +1,6 @@
"""Voyage AI embeddings provider."""
from pydantic import Field
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.providers.voyageai.embedding_callable import (
@@ -18,38 +18,53 @@ class VoyageAIProvider(BaseEmbeddingsProvider[VoyageAIEmbeddingFunction]):
model: str = Field(
default="voyage-2",
description="Model to use for embeddings",
validation_alias="EMBEDDINGS_VOYAGEAI_MODEL",
validation_alias=AliasChoices("EMBEDDINGS_VOYAGEAI_MODEL", "VOYAGEAI_MODEL"),
)
api_key: str = Field(
description="Voyage AI API key", validation_alias="EMBEDDINGS_VOYAGEAI_API_KEY"
description="Voyage AI API key",
validation_alias=AliasChoices(
"EMBEDDINGS_VOYAGEAI_API_KEY", "VOYAGEAI_API_KEY"
),
)
input_type: str | None = Field(
default=None,
description="Input type for embeddings",
validation_alias="EMBEDDINGS_VOYAGEAI_INPUT_TYPE",
validation_alias=AliasChoices(
"EMBEDDINGS_VOYAGEAI_INPUT_TYPE", "VOYAGEAI_INPUT_TYPE"
),
)
truncation: bool = Field(
default=True,
description="Whether to truncate inputs",
validation_alias="EMBEDDINGS_VOYAGEAI_TRUNCATION",
validation_alias=AliasChoices(
"EMBEDDINGS_VOYAGEAI_TRUNCATION", "VOYAGEAI_TRUNCATION"
),
)
output_dtype: str | None = Field(
default=None,
description="Output data type",
validation_alias="EMBEDDINGS_VOYAGEAI_OUTPUT_DTYPE",
validation_alias=AliasChoices(
"EMBEDDINGS_VOYAGEAI_OUTPUT_DTYPE", "VOYAGEAI_OUTPUT_DTYPE"
),
)
output_dimension: int | None = Field(
default=None,
description="Output dimension",
validation_alias="EMBEDDINGS_VOYAGEAI_OUTPUT_DIMENSION",
validation_alias=AliasChoices(
"EMBEDDINGS_VOYAGEAI_OUTPUT_DIMENSION", "VOYAGEAI_OUTPUT_DIMENSION"
),
)
max_retries: int = Field(
default=0,
description="Maximum retries for API calls",
validation_alias="EMBEDDINGS_VOYAGEAI_MAX_RETRIES",
validation_alias=AliasChoices(
"EMBEDDINGS_VOYAGEAI_MAX_RETRIES", "VOYAGEAI_MAX_RETRIES"
),
)
timeout: float | None = Field(
default=None,
description="Timeout for API calls",
validation_alias="EMBEDDINGS_VOYAGEAI_TIMEOUT",
validation_alias=AliasChoices(
"EMBEDDINGS_VOYAGEAI_TIMEOUT", "VOYAGEAI_TIMEOUT"
),
)

View File

@@ -29,7 +29,7 @@ from crewai.rag.embeddings.providers.text2vec.types import Text2VecProviderSpec
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderSpec
ProviderSpec = (
ProviderSpec: TypeAlias = (
AzureProviderSpec
| BedrockProviderSpec
| CohereProviderSpec

View File

@@ -1,16 +1,23 @@
"""Qdrant configuration model."""
from __future__ import annotations
from dataclasses import field
from typing import Literal, cast
from typing import TYPE_CHECKING, Any, Literal, cast
from pydantic.dataclasses import dataclass as pyd_dataclass
from qdrant_client.models import VectorParams
from crewai.rag.config.base import BaseRagConfig
from crewai.rag.qdrant.constants import DEFAULT_EMBEDDING_MODEL, DEFAULT_STORAGE_PATH
from crewai.rag.qdrant.types import QdrantClientParams, QdrantEmbeddingFunctionWrapper
if TYPE_CHECKING:
from qdrant_client.models import VectorParams
else:
VectorParams = Any
def _default_options() -> QdrantClientParams:
"""Create default Qdrant client options.
@@ -26,7 +33,7 @@ def _default_embedding_function() -> QdrantEmbeddingFunctionWrapper:
Returns:
Default embedding function using fastembed with all-MiniLM-L6-v2.
"""
from fastembed import TextEmbedding # type: ignore[import-not-found]
from fastembed import TextEmbedding
model = TextEmbedding(model_name=DEFAULT_EMBEDDING_MODEL)

View File

@@ -0,0 +1,364 @@
"""Tests for backward compatibility of embedding provider configurations."""
from crewai.rag.embeddings.factory import build_embedder, PROVIDER_PATHS
from crewai.rag.embeddings.providers.openai.openai_provider import OpenAIProvider
from crewai.rag.embeddings.providers.cohere.cohere_provider import CohereProvider
from crewai.rag.embeddings.providers.google.generative_ai import GenerativeAiProvider
from crewai.rag.embeddings.providers.google.vertex import VertexAIProvider
from crewai.rag.embeddings.providers.microsoft.azure import AzureProvider
from crewai.rag.embeddings.providers.jina.jina_provider import JinaProvider
from crewai.rag.embeddings.providers.ollama.ollama_provider import OllamaProvider
from crewai.rag.embeddings.providers.aws.bedrock import BedrockProvider
from crewai.rag.embeddings.providers.text2vec.text2vec_provider import Text2VecProvider
from crewai.rag.embeddings.providers.sentence_transformer.sentence_transformer_provider import (
SentenceTransformerProvider,
)
from crewai.rag.embeddings.providers.instructor.instructor_provider import InstructorProvider
from crewai.rag.embeddings.providers.openclip.openclip_provider import OpenCLIPProvider
class TestGoogleProviderAlias:
    """Checks that the short 'google' provider key is still registered."""

    def test_google_alias_in_provider_paths(self):
        """Both keys must be present in PROVIDER_PATHS and map to the same path."""
        for provider_key in ("google", "google-generativeai"):
            assert provider_key in PROVIDER_PATHS

        legacy_path = PROVIDER_PATHS["google"]
        canonical_path = PROVIDER_PATHS["google-generativeai"]
        assert legacy_path == canonical_path
class TestModelKeyBackwardCompatibility:
    """Checks that providers accept the 'model' keyword and expose it as model_name."""

    def test_openai_provider_accepts_model_key(self):
        """OpenAI: a value passed as 'model' is readable via model_name."""
        expected = "text-embedding-3-small"
        embedder = OpenAIProvider(api_key="test-key", model=expected)
        assert embedder.model_name == expected

    def test_openai_provider_model_name_takes_precedence(self):
        """OpenAI: a value passed as 'model_name' is stored unchanged.

        NOTE(review): despite the test name, only 'model_name' is supplied
        here, so precedence over the 'model' alias is not exercised.
        """
        expected = "text-embedding-3-large"
        embedder = OpenAIProvider(api_key="test-key", model_name=expected)
        assert embedder.model_name == expected

    def test_cohere_provider_accepts_model_key(self):
        """Cohere: a value passed as 'model' is readable via model_name."""
        expected = "embed-english-v3.0"
        embedder = CohereProvider(api_key="test-key", model=expected)
        assert embedder.model_name == expected

    def test_google_generativeai_provider_accepts_model_key(self):
        """Google Generative AI: 'model' is readable via model_name."""
        expected = "gemini-embedding-001"
        embedder = GenerativeAiProvider(api_key="test-key", model=expected)
        assert embedder.model_name == expected

    def test_google_vertex_provider_accepts_model_key(self):
        """Google Vertex AI: 'model' is readable via model_name."""
        expected = "text-embedding-004"
        embedder = VertexAIProvider(api_key="test-key", model=expected)
        assert embedder.model_name == expected

    def test_azure_provider_accepts_model_key(self):
        """Azure: 'model' is readable via model_name (deployment_id supplied)."""
        expected = "text-embedding-ada-002"
        embedder = AzureProvider(
            api_key="test-key",
            deployment_id="test-deployment",
            model=expected,
        )
        assert embedder.model_name == expected

    def test_jina_provider_accepts_model_key(self):
        """Jina: a value passed as 'model' is readable via model_name."""
        expected = "jina-embeddings-v3"
        embedder = JinaProvider(api_key="test-key", model=expected)
        assert embedder.model_name == expected

    def test_ollama_provider_accepts_model_key(self):
        """Ollama: a value passed as 'model' is readable via model_name."""
        expected = "nomic-embed-text"
        embedder = OllamaProvider(model=expected)
        assert embedder.model_name == expected

    def test_text2vec_provider_accepts_model_key(self):
        """Text2Vec: a value passed as 'model' is readable via model_name."""
        expected = "shibing624/text2vec-base-multilingual"
        embedder = Text2VecProvider(model=expected)
        assert embedder.model_name == expected

    def test_sentence_transformer_provider_accepts_model_key(self):
        """SentenceTransformer: 'model' is readable via model_name."""
        expected = "all-mpnet-base-v2"
        embedder = SentenceTransformerProvider(model=expected)
        assert embedder.model_name == expected

    def test_instructor_provider_accepts_model_key(self):
        """Instructor: a value passed as 'model' is readable via model_name."""
        expected = "hkunlp/instructor-xl"
        embedder = InstructorProvider(model=expected)
        assert embedder.model_name == expected

    def test_openclip_provider_accepts_model_key(self):
        """OpenCLIP: a value passed as 'model' is readable via model_name."""
        expected = "ViT-B-16"
        embedder = OpenCLIPProvider(model=expected)
        assert embedder.model_name == expected
class TestTaskTypeConfiguration:
    """Checks how GenerativeAiProvider stores and defaults task_type."""

    def test_google_provider_accepts_lowercase_task_type(self):
        """A lowercase task_type is stored exactly as given."""
        requested = "retrieval_document"
        embedder = GenerativeAiProvider(api_key="test-key", task_type=requested)
        assert embedder.task_type == requested

    def test_google_provider_accepts_uppercase_task_type(self):
        """An uppercase task_type is stored exactly as given."""
        requested = "RETRIEVAL_QUERY"
        embedder = GenerativeAiProvider(api_key="test-key", task_type=requested)
        assert embedder.task_type == requested

    def test_google_provider_default_task_type(self):
        """Omitting task_type yields "RETRIEVAL_DOCUMENT"."""
        embedder = GenerativeAiProvider(api_key="test-key")
        assert embedder.task_type == "RETRIEVAL_DOCUMENT"
class TestFactoryBackwardCompatibility:
    """Checks build_embedder against legacy-style provider specs.

    The provider import hook is patched in both tests, so the provider
    class seen by the factory is a MagicMock rather than a real provider.
    """

    def test_factory_with_google_alias(self):
        """The 'google' provider key makes the factory import GenerativeAiProvider."""
        from unittest.mock import patch, MagicMock

        spec = {
            "provider": "google",
            "config": {
                "api_key": "test-key",
                "model": "gemini-embedding-001",
            },
        }
        import_hook = "crewai.rag.embeddings.factory.import_and_validate_definition"

        with patch(import_hook) as import_mock:
            # Calling the fake class hands back a fake provider instance.
            fake_provider_cls = MagicMock()
            fake_provider_cls.return_value = MagicMock()
            import_mock.return_value = fake_provider_cls

            build_embedder(spec)

            import_mock.assert_called_once_with(
                "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider"
            )

    def test_factory_with_model_key_openai(self):
        """The 'model' config entry reaches the provider class as a keyword argument."""
        from unittest.mock import patch, MagicMock

        spec = {
            "provider": "openai",
            "config": {
                "api_key": "test-key",
                "model": "text-embedding-3-small",
            },
        }
        import_hook = "crewai.rag.embeddings.factory.import_and_validate_definition"

        with patch(import_hook) as import_mock:
            # Calling the fake class hands back a fake provider instance.
            fake_provider_cls = MagicMock()
            fake_provider_cls.return_value = MagicMock()
            import_mock.return_value = fake_provider_cls

            build_embedder(spec)

            passed_kwargs = fake_provider_cls.call_args.kwargs
            assert passed_kwargs["model"] == "text-embedding-3-small"
class TestDocumentationCodeSnippets:
    """Checks that provider configs shown in the docs construct successfully."""

    def test_memory_openai_config(self):
        """memory.mdx: minimal OpenAI config with only model_name."""
        settings = {"model_name": "text-embedding-3-small"}
        embedder = OpenAIProvider(**settings)
        assert embedder.model_name == "text-embedding-3-small"

    def test_memory_openai_config_with_options(self):
        """memory.mdx: OpenAI config with key, model, dimensions and org id."""
        settings = {
            "api_key": "your-openai-api-key",
            "model_name": "text-embedding-3-large",
            "dimensions": 1536,
            "organization_id": "your-org-id",
        }
        embedder = OpenAIProvider(**settings)
        assert embedder.model_name == "text-embedding-3-large"
        assert embedder.dimensions == 1536

    def test_memory_azure_config(self):
        """memory.mdx: Azure config including endpoint, version and deployment."""
        settings = {
            "api_key": "your-azure-key",
            "api_base": "https://your-resource.openai.azure.com/",
            "api_type": "azure",
            "api_version": "2023-05-15",
            "model_name": "text-embedding-3-small",
            "deployment_id": "your-deployment-name",
        }
        embedder = AzureProvider(**settings)
        assert embedder.model_name == "text-embedding-3-small"
        assert embedder.api_type == "azure"

    def test_memory_google_generativeai_config(self):
        """memory.mdx: Google Generative AI config with key and model_name."""
        settings = {
            "api_key": "your-google-api-key",
            "model_name": "gemini-embedding-001",
        }
        embedder = GenerativeAiProvider(**settings)
        assert embedder.model_name == "gemini-embedding-001"

    def test_memory_cohere_config(self):
        """memory.mdx: Cohere config with key and model_name."""
        settings = {
            "api_key": "your-cohere-api-key",
            "model_name": "embed-english-v3.0",
        }
        embedder = CohereProvider(**settings)
        assert embedder.model_name == "embed-english-v3.0"

    def test_knowledge_agent_embedder_config(self):
        """knowledge.mdx: agent-level Google Generative AI embedder config."""
        settings = {
            "model_name": "gemini-embedding-001",
            "api_key": "your-google-key",
        }
        embedder = GenerativeAiProvider(**settings)
        assert embedder.model_name == "gemini-embedding-001"

    def test_ragtool_openai_config(self):
        """ragtool.mdx: minimal OpenAI config with only model_name."""
        settings = {"model_name": "text-embedding-3-small"}
        embedder = OpenAIProvider(**settings)
        assert embedder.model_name == "text-embedding-3-small"

    def test_ragtool_cohere_config(self):
        """ragtool.mdx: Cohere config with key and model_name."""
        settings = {
            "api_key": "your-api-key",
            "model_name": "embed-english-v3.0",
        }
        embedder = CohereProvider(**settings)
        assert embedder.model_name == "embed-english-v3.0"

    def test_ragtool_ollama_config(self):
        """ragtool.mdx: Ollama config with model_name and explicit url."""
        settings = {
            "model_name": "llama2",
            "url": "http://localhost:11434/api/embeddings",
        }
        embedder = OllamaProvider(**settings)
        assert embedder.model_name == "llama2"

    def test_ragtool_azure_config(self):
        """ragtool.mdx: Azure config; deployment_id must round-trip."""
        settings = {
            "deployment_id": "your-deployment-id",
            "api_key": "your-api-key",
            "api_base": "https://your-resource.openai.azure.com",
            "api_version": "2024-02-01",
            "model_name": "text-embedding-ada-002",
            "api_type": "azure",
        }
        embedder = AzureProvider(**settings)
        assert embedder.model_name == "text-embedding-ada-002"
        assert embedder.deployment_id == "your-deployment-id"

    def test_ragtool_google_generativeai_config(self):
        """ragtool.mdx: Google Generative AI config with explicit task_type."""
        settings = {
            "api_key": "your-api-key",
            "model_name": "gemini-embedding-001",
            "task_type": "RETRIEVAL_DOCUMENT",
        }
        embedder = GenerativeAiProvider(**settings)
        assert embedder.model_name == "gemini-embedding-001"
        assert embedder.task_type == "RETRIEVAL_DOCUMENT"

    def test_ragtool_jina_config(self):
        """ragtool.mdx: Jina config with key and model_name."""
        settings = {
            "api_key": "your-api-key",
            "model_name": "jina-embeddings-v3",
        }
        embedder = JinaProvider(**settings)
        assert embedder.model_name == "jina-embeddings-v3"

    def test_ragtool_sentence_transformer_config(self):
        """ragtool.mdx: SentenceTransformer config with device and normalization."""
        settings = {
            "model_name": "all-mpnet-base-v2",
            "device": "cuda",
            "normalize_embeddings": True,
        }
        embedder = SentenceTransformerProvider(**settings)
        assert embedder.model_name == "all-mpnet-base-v2"
        assert embedder.device == "cuda"
        assert embedder.normalize_embeddings is True
class TestLegacyConfigurationFormats:
    """Checks that older configs using the 'model' key still construct providers."""

    def test_legacy_google_with_model_key(self):
        """Google Generative AI: 'model' plus a lowercase task_type."""
        legacy_model = "text-embedding-005"
        legacy_task = "retrieval_document"
        embedder = GenerativeAiProvider(
            api_key="test-key",
            model=legacy_model,
            task_type=legacy_task,
        )
        assert embedder.model_name == legacy_model
        assert embedder.task_type == legacy_task

    def test_legacy_openai_with_model_key(self):
        """OpenAI: 'model' is exposed as model_name."""
        legacy_model = "text-embedding-ada-002"
        embedder = OpenAIProvider(api_key="test-key", model=legacy_model)
        assert embedder.model_name == legacy_model

    def test_legacy_cohere_with_model_key(self):
        """Cohere: 'model' is exposed as model_name."""
        legacy_model = "embed-multilingual-v3.0"
        embedder = CohereProvider(api_key="test-key", model=legacy_model)
        assert embedder.model_name == legacy_model

    def test_legacy_azure_with_model_key(self):
        """Azure: 'model' is exposed as model_name (deployment_id supplied)."""
        legacy_model = "text-embedding-3-large"
        embedder = AzureProvider(
            api_key="test-key",
            deployment_id="test-deployment",
            model=legacy_model,
        )
        assert embedder.model_name == legacy_model