mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-15 11:58:31 +00:00
fix: rag tool embeddings config
* fix: ensure config is not flattened, add tests
* chore: refactor inits to model_validator
* chore: refactor rag tool config parsing
* chore: add initial docs
* chore: add additional validation aliases for provider env vars
* chore: add solid docs
* chore: move imports to top
* fix: revert circular import
* fix: lazy import qdrant-client
* fix: allow collection name config
* chore: narrow model names for google
* chore: update additional docs
* chore: add backward compat on model name aliases
* chore: add tests for config changes
This commit is contained in:
@@ -388,8 +388,8 @@ crew = Crew(
|
||||
agents=[sales_agent, tech_agent, support_agent],
|
||||
tasks=[...],
|
||||
embedder={ # Fallback embedder for agents without their own
|
||||
"provider": "google",
|
||||
"config": {"model": "text-embedding-004"}
|
||||
"provider": "google-generativeai",
|
||||
"config": {"model_name": "gemini-embedding-001"}
|
||||
}
|
||||
)
|
||||
|
||||
@@ -629,9 +629,9 @@ agent = Agent(
|
||||
backstory="Expert researcher",
|
||||
knowledge_sources=[knowledge_source],
|
||||
embedder={
|
||||
"provider": "google",
|
||||
"provider": "google-generativeai",
|
||||
"config": {
|
||||
"model": "models/text-embedding-004",
|
||||
"model_name": "gemini-embedding-001",
|
||||
"api_key": "your-google-key"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -341,7 +341,7 @@ crew = Crew(
|
||||
embedder={
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": "text-embedding-3-small" # or "text-embedding-3-large"
|
||||
"model_name": "text-embedding-3-small" # or "text-embedding-3-large"
|
||||
}
|
||||
}
|
||||
)
|
||||
@@ -353,7 +353,7 @@ crew = Crew(
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"api_key": "your-openai-api-key", # Optional: override env var
|
||||
"model": "text-embedding-3-large",
|
||||
"model_name": "text-embedding-3-large",
|
||||
"dimensions": 1536, # Optional: reduce dimensions for smaller storage
|
||||
"organization_id": "your-org-id" # Optional: for organization accounts
|
||||
}
|
||||
@@ -375,7 +375,7 @@ crew = Crew(
|
||||
"api_base": "https://your-resource.openai.azure.com/",
|
||||
"api_type": "azure",
|
||||
"api_version": "2023-05-15",
|
||||
"model": "text-embedding-3-small",
|
||||
"model_name": "text-embedding-3-small",
|
||||
"deployment_id": "your-deployment-name" # Azure deployment name
|
||||
}
|
||||
}
|
||||
@@ -390,10 +390,10 @@ Use Google's text embedding models for integration with Google Cloud services.
|
||||
crew = Crew(
|
||||
memory=True,
|
||||
embedder={
|
||||
"provider": "google",
|
||||
"provider": "google-generativeai",
|
||||
"config": {
|
||||
"api_key": "your-google-api-key",
|
||||
"model": "text-embedding-004" # or "text-embedding-preview-0409"
|
||||
"model_name": "gemini-embedding-001" # or "text-embedding-005", "text-multilingual-embedding-002"
|
||||
}
|
||||
}
|
||||
)
|
||||
@@ -461,7 +461,7 @@ crew = Crew(
|
||||
"provider": "cohere",
|
||||
"config": {
|
||||
"api_key": "your-cohere-api-key",
|
||||
"model": "embed-english-v3.0" # or "embed-multilingual-v3.0"
|
||||
"model_name": "embed-english-v3.0" # or "embed-multilingual-v3.0"
|
||||
}
|
||||
}
|
||||
)
|
||||
@@ -478,7 +478,7 @@ crew = Crew(
|
||||
"provider": "voyageai",
|
||||
"config": {
|
||||
"api_key": "your-voyage-api-key",
|
||||
"model": "voyage-large-2", # or "voyage-code-2" for code
|
||||
"model": "voyage-3", # or "voyage-3-lite", "voyage-code-3"
|
||||
"input_type": "document" # or "query"
|
||||
}
|
||||
}
|
||||
@@ -912,10 +912,10 @@ crew = Crew(
|
||||
crew = Crew(
|
||||
memory=True,
|
||||
embedder={
|
||||
"provider": "google",
|
||||
"provider": "google-generativeai",
|
||||
"config": {
|
||||
"api_key": "your-api-key",
|
||||
"model": "text-embedding-004"
|
||||
"model_name": "gemini-embedding-001"
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
@@ -77,7 +77,7 @@ The `RagTool` accepts the following parameters:
|
||||
|
||||
- **summarize**: Optional. Whether to summarize the retrieved content. Default is `False`.
|
||||
- **adapter**: Optional. A custom adapter for the knowledge base. If not provided, a CrewAIRagAdapter will be used.
|
||||
- **config**: Optional. Configuration for the underlying CrewAI RAG system.
|
||||
- **config**: Optional. Configuration for the underlying CrewAI RAG system. Accepts a `RagToolConfig` TypedDict with optional `embedding_model` (ProviderSpec) and `vectordb` (VectorDbConfig) keys. All configuration values provided programmatically take precedence over environment variables.
|
||||
|
||||
## Adding Content
|
||||
|
||||
@@ -127,26 +127,528 @@ You can customize the behavior of the `RagTool` by providing a configuration dic
|
||||
|
||||
```python Code
|
||||
from crewai_tools import RagTool
|
||||
from crewai_tools.tools.rag import RagToolConfig, VectorDbConfig, ProviderSpec
|
||||
|
||||
# Create a RAG tool with custom configuration
|
||||
config = {
|
||||
"vectordb": {
|
||||
"provider": "qdrant",
|
||||
"config": {
|
||||
"collection_name": "my-collection"
|
||||
}
|
||||
},
|
||||
"embedding_model": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": "text-embedding-3-small"
|
||||
}
|
||||
|
||||
vectordb: VectorDbConfig = {
|
||||
"provider": "qdrant",
|
||||
"config": {
|
||||
"collection_name": "my-collection"
|
||||
}
|
||||
}
|
||||
|
||||
embedding_model: ProviderSpec = {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model_name": "text-embedding-3-small"
|
||||
}
|
||||
}
|
||||
|
||||
config: RagToolConfig = {
|
||||
"vectordb": vectordb,
|
||||
"embedding_model": embedding_model
|
||||
}
|
||||
|
||||
rag_tool = RagTool(config=config, summarize=True)
|
||||
```
|
||||
|
||||
## Embedding Model Configuration
|
||||
|
||||
The `embedding_model` parameter accepts a `crewai.rag.embeddings.types.ProviderSpec` dictionary with the structure:
|
||||
|
||||
```python
|
||||
{
|
||||
"provider": "provider-name", # Required
|
||||
"config": { # Optional
|
||||
# Provider-specific configuration
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Supported Providers
|
||||
|
||||
<AccordionGroup>
|
||||
<Accordion title="OpenAI">
|
||||
```python main.py
|
||||
from crewai.rag.embeddings.providers.openai.types import OpenAIProviderSpec
|
||||
|
||||
embedding_model: OpenAIProviderSpec = {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"api_key": "your-api-key",
|
||||
"model_name": "text-embedding-ada-002",
|
||||
"dimensions": 1536,
|
||||
"organization_id": "your-org-id",
|
||||
"api_base": "https://api.openai.com/v1",
|
||||
"api_version": "v1",
|
||||
"default_headers": {"Custom-Header": "value"}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `api_key` (str): OpenAI API key
|
||||
- `model_name` (str): Model to use. Default: `text-embedding-ada-002`. Options: `text-embedding-3-small`, `text-embedding-3-large`, `text-embedding-ada-002`
|
||||
- `dimensions` (int): Number of dimensions for the embedding
|
||||
- `organization_id` (str): OpenAI organization ID
|
||||
- `api_base` (str): Custom API base URL
|
||||
- `api_version` (str): API version
|
||||
- `default_headers` (dict): Custom headers for API requests
|
||||
|
||||
**Environment Variables:**
|
||||
- `OPENAI_API_KEY` or `EMBEDDINGS_OPENAI_API_KEY`: `api_key`
|
||||
- `OPENAI_ORGANIZATION_ID` or `EMBEDDINGS_OPENAI_ORGANIZATION_ID`: `organization_id`
|
||||
- `OPENAI_MODEL_NAME` or `EMBEDDINGS_OPENAI_MODEL_NAME`: `model_name`
|
||||
- `OPENAI_API_BASE` or `EMBEDDINGS_OPENAI_API_BASE`: `api_base`
|
||||
- `OPENAI_API_VERSION` or `EMBEDDINGS_OPENAI_API_VERSION`: `api_version`
|
||||
- `OPENAI_DIMENSIONS` or `EMBEDDINGS_OPENAI_DIMENSIONS`: `dimensions`
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Cohere">
|
||||
```python main.py
|
||||
from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec
|
||||
|
||||
embedding_model: CohereProviderSpec = {
|
||||
"provider": "cohere",
|
||||
"config": {
|
||||
"api_key": "your-api-key",
|
||||
"model_name": "embed-english-v3.0"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `api_key` (str): Cohere API key
|
||||
- `model_name` (str): Model to use. Default: `large`. Options: `embed-english-v3.0`, `embed-multilingual-v3.0`, `large`, `small`
|
||||
|
||||
**Environment Variables:**
|
||||
- `COHERE_API_KEY` or `EMBEDDINGS_COHERE_API_KEY`: `api_key`
|
||||
- `EMBEDDINGS_COHERE_MODEL_NAME`: `model_name`
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="VoyageAI">
|
||||
```python main.py
|
||||
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderSpec
|
||||
|
||||
embedding_model: VoyageAIProviderSpec = {
|
||||
"provider": "voyageai",
|
||||
"config": {
|
||||
"api_key": "your-api-key",
|
||||
"model": "voyage-3",
|
||||
"input_type": "document",
|
||||
"truncation": True,
|
||||
"output_dtype": "float32",
|
||||
"output_dimension": 1024,
|
||||
"max_retries": 3,
|
||||
"timeout": 60.0
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `api_key` (str): VoyageAI API key
|
||||
- `model` (str): Model to use. Default: `voyage-2`. Options: `voyage-3`, `voyage-3-lite`, `voyage-code-3`, `voyage-large-2`
|
||||
- `input_type` (str): Type of input. Options: `document` (for storage), `query` (for search)
|
||||
- `truncation` (bool): Whether to truncate inputs that exceed max length. Default: `True`
|
||||
- `output_dtype` (str): Output data type
|
||||
- `output_dimension` (int): Dimension of output embeddings
|
||||
- `max_retries` (int): Maximum number of retry attempts. Default: `0`
|
||||
- `timeout` (float): Request timeout in seconds
|
||||
|
||||
**Environment Variables:**
|
||||
- `VOYAGEAI_API_KEY` or `EMBEDDINGS_VOYAGEAI_API_KEY`: `api_key`
|
||||
- `VOYAGEAI_MODEL` or `EMBEDDINGS_VOYAGEAI_MODEL`: `model`
|
||||
- `VOYAGEAI_INPUT_TYPE` or `EMBEDDINGS_VOYAGEAI_INPUT_TYPE`: `input_type`
|
||||
- `VOYAGEAI_TRUNCATION` or `EMBEDDINGS_VOYAGEAI_TRUNCATION`: `truncation`
|
||||
- `VOYAGEAI_OUTPUT_DTYPE` or `EMBEDDINGS_VOYAGEAI_OUTPUT_DTYPE`: `output_dtype`
|
||||
- `VOYAGEAI_OUTPUT_DIMENSION` or `EMBEDDINGS_VOYAGEAI_OUTPUT_DIMENSION`: `output_dimension`
|
||||
- `VOYAGEAI_MAX_RETRIES` or `EMBEDDINGS_VOYAGEAI_MAX_RETRIES`: `max_retries`
|
||||
- `VOYAGEAI_TIMEOUT` or `EMBEDDINGS_VOYAGEAI_TIMEOUT`: `timeout`
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Ollama">
|
||||
```python main.py
|
||||
from crewai.rag.embeddings.providers.ollama.types import OllamaProviderSpec
|
||||
|
||||
embedding_model: OllamaProviderSpec = {
|
||||
"provider": "ollama",
|
||||
"config": {
|
||||
"model_name": "llama2",
|
||||
"url": "http://localhost:11434/api/embeddings"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `model_name` (str): Ollama model name (e.g., `llama2`, `mistral`, `nomic-embed-text`)
|
||||
- `url` (str): Ollama API endpoint URL. Default: `http://localhost:11434/api/embeddings`
|
||||
|
||||
**Environment Variables:**
|
||||
- `OLLAMA_MODEL` or `EMBEDDINGS_OLLAMA_MODEL`: `model_name`
|
||||
- `OLLAMA_URL` or `EMBEDDINGS_OLLAMA_URL`: `url`
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Amazon Bedrock">
|
||||
```python main.py
|
||||
from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
|
||||
|
||||
embedding_model: BedrockProviderSpec = {
|
||||
"provider": "amazon-bedrock",
|
||||
"config": {
|
||||
"model_name": "amazon.titan-embed-text-v2:0",
|
||||
"session": boto3_session
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `model_name` (str): Bedrock model ID. Default: `amazon.titan-embed-text-v1`. Options: `amazon.titan-embed-text-v1`, `amazon.titan-embed-text-v2:0`, `cohere.embed-english-v3`, `cohere.embed-multilingual-v3`
|
||||
- `session` (Any): Boto3 session object for AWS authentication
|
||||
|
||||
**Environment Variables:**
|
||||
- `AWS_ACCESS_KEY_ID`: AWS access key
|
||||
- `AWS_SECRET_ACCESS_KEY`: AWS secret key
|
||||
- `AWS_REGION`: AWS region (e.g., `us-east-1`)
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Azure OpenAI">
|
||||
```python main.py
|
||||
from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec
|
||||
|
||||
embedding_model: AzureProviderSpec = {
|
||||
"provider": "azure",
|
||||
"config": {
|
||||
"deployment_id": "your-deployment-id",
|
||||
"api_key": "your-api-key",
|
||||
"api_base": "https://your-resource.openai.azure.com",
|
||||
"api_version": "2024-02-01",
|
||||
"model_name": "text-embedding-ada-002",
|
||||
"api_type": "azure"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `deployment_id` (str): **Required** - Azure OpenAI deployment ID
|
||||
- `api_key` (str): Azure OpenAI API key
|
||||
- `api_base` (str): Azure OpenAI resource endpoint
|
||||
- `api_version` (str): API version. Example: `2024-02-01`
|
||||
- `model_name` (str): Model name. Default: `text-embedding-ada-002`
|
||||
- `api_type` (str): API type. Default: `azure`
|
||||
- `dimensions` (int): Output dimensions
|
||||
- `default_headers` (dict): Custom headers
|
||||
|
||||
**Environment Variables:**
|
||||
- `AZURE_OPENAI_API_KEY` or `EMBEDDINGS_AZURE_API_KEY`: `api_key`
|
||||
- `AZURE_OPENAI_ENDPOINT` or `EMBEDDINGS_AZURE_API_BASE`: `api_base`
|
||||
- `EMBEDDINGS_AZURE_DEPLOYMENT_ID`: `deployment_id`
|
||||
- `EMBEDDINGS_AZURE_API_VERSION`: `api_version`
|
||||
- `EMBEDDINGS_AZURE_MODEL_NAME`: `model_name`
|
||||
- `EMBEDDINGS_AZURE_API_TYPE`: `api_type`
|
||||
- `EMBEDDINGS_AZURE_DIMENSIONS`: `dimensions`
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Google Generative AI">
|
||||
```python main.py
|
||||
from crewai.rag.embeddings.providers.google.types import GenerativeAiProviderSpec
|
||||
|
||||
embedding_model: GenerativeAiProviderSpec = {
|
||||
"provider": "google-generativeai",
|
||||
"config": {
|
||||
"api_key": "your-api-key",
|
||||
"model_name": "gemini-embedding-001",
|
||||
"task_type": "RETRIEVAL_DOCUMENT"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `api_key` (str): Google AI API key
|
||||
- `model_name` (str): Model name. Default: `gemini-embedding-001`. Options: `gemini-embedding-001`, `text-embedding-005`, `text-multilingual-embedding-002`
|
||||
- `task_type` (str): Task type for embeddings. Default: `RETRIEVAL_DOCUMENT`. Options: `RETRIEVAL_DOCUMENT`, `RETRIEVAL_QUERY`
|
||||
|
||||
**Environment Variables:**
|
||||
- `GOOGLE_API_KEY`, `GEMINI_API_KEY`, or `EMBEDDINGS_GOOGLE_API_KEY`: `api_key`
|
||||
- `EMBEDDINGS_GOOGLE_GENERATIVE_AI_MODEL_NAME`: `model_name`
|
||||
- `EMBEDDINGS_GOOGLE_GENERATIVE_AI_TASK_TYPE`: `task_type`
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Google Vertex AI">
|
||||
```python main.py
|
||||
from crewai.rag.embeddings.providers.google.types import VertexAIProviderSpec
|
||||
|
||||
embedding_model: VertexAIProviderSpec = {
|
||||
"provider": "google-vertex",
|
||||
"config": {
|
||||
"model_name": "text-embedding-004",
|
||||
"project_id": "your-project-id",
|
||||
"region": "us-central1",
|
||||
"api_key": "your-api-key"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `model_name` (str): Model name. Default: `textembedding-gecko`. Options: `text-embedding-004`, `textembedding-gecko`, `textembedding-gecko-multilingual`
|
||||
- `project_id` (str): Google Cloud project ID. Default: `cloud-large-language-models`
|
||||
- `region` (str): Google Cloud region. Default: `us-central1`
|
||||
- `api_key` (str): API key for authentication
|
||||
|
||||
**Environment Variables:**
|
||||
- `GOOGLE_APPLICATION_CREDENTIALS`: Path to service account JSON file
|
||||
- `GOOGLE_CLOUD_PROJECT` or `EMBEDDINGS_GOOGLE_VERTEX_PROJECT_ID`: `project_id`
|
||||
- `EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME`: `model_name`
|
||||
- `EMBEDDINGS_GOOGLE_VERTEX_REGION`: `region`
|
||||
- `EMBEDDINGS_GOOGLE_VERTEX_API_KEY`: `api_key`
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Jina AI">
|
||||
```python main.py
|
||||
from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec
|
||||
|
||||
embedding_model: JinaProviderSpec = {
|
||||
"provider": "jina",
|
||||
"config": {
|
||||
"api_key": "your-api-key",
|
||||
"model_name": "jina-embeddings-v3"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `api_key` (str): Jina AI API key
|
||||
- `model_name` (str): Model name. Default: `jina-embeddings-v2-base-en`. Options: `jina-embeddings-v3`, `jina-embeddings-v2-base-en`, `jina-embeddings-v2-small-en`
|
||||
|
||||
**Environment Variables:**
|
||||
- `JINA_API_KEY` or `EMBEDDINGS_JINA_API_KEY`: `api_key`
|
||||
- `EMBEDDINGS_JINA_MODEL_NAME`: `model_name`
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="HuggingFace">
|
||||
```python main.py
|
||||
from crewai.rag.embeddings.providers.huggingface.types import HuggingFaceProviderSpec
|
||||
|
||||
embedding_model: HuggingFaceProviderSpec = {
|
||||
"provider": "huggingface",
|
||||
"config": {
|
||||
"url": "https://api-inference.huggingface.co/models/sentence-transformers/all-MiniLM-L6-v2"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `url` (str): Full URL to HuggingFace inference API endpoint
|
||||
|
||||
**Environment Variables:**
|
||||
- `HUGGINGFACE_URL` or `EMBEDDINGS_HUGGINGFACE_URL`: `url`
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Instructor">
|
||||
```python main.py
|
||||
from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec
|
||||
|
||||
embedding_model: InstructorProviderSpec = {
|
||||
"provider": "instructor",
|
||||
"config": {
|
||||
"model_name": "hkunlp/instructor-xl",
|
||||
"device": "cuda",
|
||||
"instruction": "Represent the document"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `model_name` (str): HuggingFace model ID. Default: `hkunlp/instructor-base`. Options: `hkunlp/instructor-xl`, `hkunlp/instructor-large`, `hkunlp/instructor-base`
|
||||
- `device` (str): Device to run on. Default: `cpu`. Options: `cpu`, `cuda`, `mps`
|
||||
- `instruction` (str): Instruction prefix for embeddings
|
||||
|
||||
**Environment Variables:**
|
||||
- `EMBEDDINGS_INSTRUCTOR_MODEL_NAME`: `model_name`
|
||||
- `EMBEDDINGS_INSTRUCTOR_DEVICE`: `device`
|
||||
- `EMBEDDINGS_INSTRUCTOR_INSTRUCTION`: `instruction`
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Sentence Transformer">
|
||||
```python main.py
|
||||
from crewai.rag.embeddings.providers.sentence_transformer.types import SentenceTransformerProviderSpec
|
||||
|
||||
embedding_model: SentenceTransformerProviderSpec = {
|
||||
"provider": "sentence-transformer",
|
||||
"config": {
|
||||
"model_name": "all-mpnet-base-v2",
|
||||
"device": "cuda",
|
||||
"normalize_embeddings": True
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `model_name` (str): Sentence Transformers model name. Default: `all-MiniLM-L6-v2`. Options: `all-mpnet-base-v2`, `all-MiniLM-L6-v2`, `paraphrase-multilingual-MiniLM-L12-v2`
|
||||
- `device` (str): Device to run on. Default: `cpu`. Options: `cpu`, `cuda`, `mps`
|
||||
- `normalize_embeddings` (bool): Whether to normalize embeddings. Default: `False`
|
||||
|
||||
**Environment Variables:**
|
||||
- `EMBEDDINGS_SENTENCE_TRANSFORMER_MODEL_NAME`: `model_name`
|
||||
- `EMBEDDINGS_SENTENCE_TRANSFORMER_DEVICE`: `device`
|
||||
- `EMBEDDINGS_SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS`: `normalize_embeddings`
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="ONNX">
|
||||
```python main.py
|
||||
from crewai.rag.embeddings.providers.onnx.types import ONNXProviderSpec
|
||||
|
||||
embedding_model: ONNXProviderSpec = {
|
||||
"provider": "onnx",
|
||||
"config": {
|
||||
"preferred_providers": ["CUDAExecutionProvider", "CPUExecutionProvider"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `preferred_providers` (list[str]): List of ONNX execution providers in order of preference
|
||||
|
||||
**Environment Variables:**
|
||||
- `EMBEDDINGS_ONNX_PREFERRED_PROVIDERS`: `preferred_providers` (comma-separated list)
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="OpenCLIP">
|
||||
```python main.py
|
||||
from crewai.rag.embeddings.providers.openclip.types import OpenCLIPProviderSpec
|
||||
|
||||
embedding_model: OpenCLIPProviderSpec = {
|
||||
"provider": "openclip",
|
||||
"config": {
|
||||
"model_name": "ViT-B-32",
|
||||
"checkpoint": "laion2b_s34b_b79k",
|
||||
"device": "cuda"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `model_name` (str): OpenCLIP model architecture. Default: `ViT-B-32`. Options: `ViT-B-32`, `ViT-B-16`, `ViT-L-14`
|
||||
- `checkpoint` (str): Pretrained checkpoint name. Default: `laion2b_s34b_b79k`. Options: `laion2b_s34b_b79k`, `laion400m_e32`, `openai`
|
||||
- `device` (str): Device to run on. Default: `cpu`. Options: `cpu`, `cuda`
|
||||
|
||||
**Environment Variables:**
|
||||
- `EMBEDDINGS_OPENCLIP_MODEL_NAME`: `model_name`
|
||||
- `EMBEDDINGS_OPENCLIP_CHECKPOINT`: `checkpoint`
|
||||
- `EMBEDDINGS_OPENCLIP_DEVICE`: `device`
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Text2Vec">
|
||||
```python main.py
|
||||
from crewai.rag.embeddings.providers.text2vec.types import Text2VecProviderSpec
|
||||
|
||||
embedding_model: Text2VecProviderSpec = {
|
||||
"provider": "text2vec",
|
||||
"config": {
|
||||
"model_name": "shibing624/text2vec-base-multilingual"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `model_name` (str): Text2Vec model name from HuggingFace. Default: `shibing624/text2vec-base-chinese`. Options: `shibing624/text2vec-base-multilingual`, `shibing624/text2vec-base-chinese`
|
||||
|
||||
**Environment Variables:**
|
||||
- `EMBEDDINGS_TEXT2VEC_MODEL_NAME`: `model_name`
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Roboflow">
|
||||
```python main.py
|
||||
from crewai.rag.embeddings.providers.roboflow.types import RoboflowProviderSpec
|
||||
|
||||
embedding_model: RoboflowProviderSpec = {
|
||||
"provider": "roboflow",
|
||||
"config": {
|
||||
"api_key": "your-api-key",
|
||||
"api_url": "https://infer.roboflow.com"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `api_key` (str): Roboflow API key. Default: `""` (empty string)
|
||||
- `api_url` (str): Roboflow inference API URL. Default: `https://infer.roboflow.com`
|
||||
|
||||
**Environment Variables:**
|
||||
- `ROBOFLOW_API_KEY` or `EMBEDDINGS_ROBOFLOW_API_KEY`: `api_key`
|
||||
- `ROBOFLOW_API_URL` or `EMBEDDINGS_ROBOFLOW_API_URL`: `api_url`
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="WatsonX (IBM)">
|
||||
```python main.py
|
||||
from crewai.rag.embeddings.providers.ibm.types import WatsonXProviderSpec
|
||||
|
||||
embedding_model: WatsonXProviderSpec = {
|
||||
"provider": "watsonx",
|
||||
"config": {
|
||||
"model_id": "ibm/slate-125m-english-rtrvr",
|
||||
"url": "https://us-south.ml.cloud.ibm.com",
|
||||
"api_key": "your-api-key",
|
||||
"project_id": "your-project-id",
|
||||
"batch_size": 100,
|
||||
"concurrency_limit": 10,
|
||||
"persistent_connection": True
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `model_id` (str): WatsonX model identifier
|
||||
- `url` (str): WatsonX API endpoint
|
||||
- `api_key` (str): IBM Cloud API key
|
||||
- `project_id` (str): WatsonX project ID
|
||||
- `space_id` (str): WatsonX space ID (alternative to project_id)
|
||||
- `batch_size` (int): Batch size for embeddings. Default: `100`
|
||||
- `concurrency_limit` (int): Maximum concurrent requests. Default: `10`
|
||||
- `persistent_connection` (bool): Use persistent connections. Default: `True`
|
||||
- Plus 20+ additional authentication and configuration options
|
||||
|
||||
**Environment Variables:**
|
||||
- `WATSONX_API_KEY` or `EMBEDDINGS_WATSONX_API_KEY`: `api_key`
|
||||
- `WATSONX_URL` or `EMBEDDINGS_WATSONX_URL`: `url`
|
||||
- `WATSONX_PROJECT_ID` or `EMBEDDINGS_WATSONX_PROJECT_ID`: `project_id`
|
||||
- `EMBEDDINGS_WATSONX_MODEL_ID`: `model_id`
|
||||
- `EMBEDDINGS_WATSONX_SPACE_ID`: `space_id`
|
||||
- `EMBEDDINGS_WATSONX_BATCH_SIZE`: `batch_size`
|
||||
- `EMBEDDINGS_WATSONX_CONCURRENCY_LIMIT`: `concurrency_limit`
|
||||
- `EMBEDDINGS_WATSONX_PERSISTENT_CONNECTION`: `persistent_connection`
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Custom">
|
||||
```python main.py
|
||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||
from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec
|
||||
|
||||
class MyEmbeddingFunction(EmbeddingFunction):
|
||||
def __call__(self, input):
|
||||
# Your custom embedding logic
|
||||
return embeddings
|
||||
|
||||
embedding_model: CustomProviderSpec = {
|
||||
"provider": "custom",
|
||||
"config": {
|
||||
"embedding_callable": MyEmbeddingFunction
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `embedding_callable` (type[EmbeddingFunction]): Custom embedding function class
|
||||
|
||||
**Note:** Custom embedding functions must implement the `EmbeddingFunction` protocol defined in `crewai.rag.core.base_embeddings_callable`. The `__call__` method should accept input data and return embeddings as a list of numpy arrays (or a compatible format). The returned embeddings are automatically normalized and validated.
|
||||
</Accordion>
|
||||
</AccordionGroup>
|
||||
|
||||
### Notes
|
||||
- All config fields are optional unless marked as **Required**
|
||||
- API keys can typically be provided via environment variables instead of config
|
||||
- Default values are shown where applicable
|
||||
|
||||
|
||||
## Conclusion
|
||||
The `RagTool` provides a powerful way to create and query knowledge bases from various data sources. By leveraging Retrieval-Augmented Generation, it enables agents to access and retrieve relevant information efficiently, enhancing their ability to provide accurate and contextually appropriate responses.
|
||||
|
||||
@@ -58,10 +58,10 @@ tool = MySQLSearchTool(
|
||||
),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="google",
|
||||
provider="google-generativeai",
|
||||
config=dict(
|
||||
model="models/embedding-001",
|
||||
task_type="retrieval_document",
|
||||
model_name="gemini-embedding-001",
|
||||
task_type="RETRIEVAL_DOCUMENT",
|
||||
# title="Embeddings",
|
||||
),
|
||||
),
|
||||
|
||||
@@ -71,10 +71,10 @@ tool = PGSearchTool(
|
||||
),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="google", # or openai, ollama, ...
|
||||
provider="google-generativeai", # or openai, ollama, ...
|
||||
config=dict(
|
||||
model="models/embedding-001",
|
||||
task_type="retrieval_document",
|
||||
model_name="gemini-embedding-001",
|
||||
task_type="RETRIEVAL_DOCUMENT",
|
||||
# title="Embeddings",
|
||||
),
|
||||
),
|
||||
|
||||
@@ -64,10 +64,10 @@ tool = JSONSearchTool(
|
||||
},
|
||||
},
|
||||
"embedding_model": {
|
||||
"provider": "google", # or openai, ollama, ...
|
||||
"provider": "google-generativeai", # or openai, ollama, ...
|
||||
"config": {
|
||||
"model": "models/embedding-001",
|
||||
"task_type": "retrieval_document",
|
||||
"model_name": "gemini-embedding-001",
|
||||
"task_type": "RETRIEVAL_DOCUMENT",
|
||||
# Further customization options can be added here.
|
||||
},
|
||||
},
|
||||
|
||||
@@ -63,15 +63,15 @@ tool = PDFSearchTool(
|
||||
"config": {
|
||||
# Model identifier for the chosen provider. "model" will be auto-mapped to "model_name" internally.
|
||||
"model": "text-embedding-3-small",
|
||||
# Optional: API key. If omitted, the tool will use provider-specific env vars when available
|
||||
# (e.g., OPENAI_API_KEY for provider="openai").
|
||||
# Optional: API key. If omitted, the tool will use provider-specific env vars
|
||||
# (e.g., OPENAI_API_KEY or EMBEDDINGS_OPENAI_API_KEY for OpenAI).
|
||||
# "api_key": "sk-...",
|
||||
|
||||
# Provider-specific examples:
|
||||
# --- Google Generative AI ---
|
||||
# (Set provider="google-generativeai" above)
|
||||
# "model": "models/embedding-001",
|
||||
# "task_type": "retrieval_document",
|
||||
# "model_name": "gemini-embedding-001",
|
||||
# "task_type": "RETRIEVAL_DOCUMENT",
|
||||
# "title": "Embeddings",
|
||||
|
||||
# --- Cohere ---
|
||||
|
||||
@@ -66,9 +66,9 @@ tool = TXTSearchTool(
|
||||
"provider": "openai", # or google-generativeai, cohere, ollama, ...
|
||||
"config": {
|
||||
"model": "text-embedding-3-small",
|
||||
# "api_key": "sk-...", # optional if env var is set
|
||||
# "api_key": "sk-...", # optional if env var is set (e.g., OPENAI_API_KEY or EMBEDDINGS_OPENAI_API_KEY)
|
||||
# Provider examples:
|
||||
# Google → model: "models/embedding-001", task_type: "retrieval_document"
|
||||
# Google → model_name: "gemini-embedding-001", task_type: "RETRIEVAL_DOCUMENT"
|
||||
# Cohere → model: "embed-english-v3.0"
|
||||
# Ollama → model: "nomic-embed-text"
|
||||
},
|
||||
|
||||
@@ -73,10 +73,10 @@ tool = CodeDocsSearchTool(
|
||||
),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="google", # or openai, ollama, ...
|
||||
provider="google-generativeai", # or openai, ollama, ...
|
||||
config=dict(
|
||||
model="models/embedding-001",
|
||||
task_type="retrieval_document",
|
||||
model_name="gemini-embedding-001",
|
||||
task_type="RETRIEVAL_DOCUMENT",
|
||||
# title="Embeddings",
|
||||
),
|
||||
),
|
||||
|
||||
@@ -75,10 +75,10 @@ tool = GithubSearchTool(
|
||||
),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="google", # or openai, ollama, ...
|
||||
provider="google-generativeai", # or openai, ollama, ...
|
||||
config=dict(
|
||||
model="models/embedding-001",
|
||||
task_type="retrieval_document",
|
||||
model_name="gemini-embedding-001",
|
||||
task_type="RETRIEVAL_DOCUMENT",
|
||||
# title="Embeddings",
|
||||
),
|
||||
),
|
||||
|
||||
@@ -66,10 +66,10 @@ tool = WebsiteSearchTool(
|
||||
),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="google", # or openai, ollama, ...
|
||||
provider="google-generativeai", # or openai, ollama, ...
|
||||
config=dict(
|
||||
model="models/embedding-001",
|
||||
task_type="retrieval_document",
|
||||
model_name="gemini-embedding-001",
|
||||
task_type="RETRIEVAL_DOCUMENT",
|
||||
# title="Embeddings",
|
||||
),
|
||||
),
|
||||
|
||||
@@ -106,10 +106,10 @@ youtube_channel_tool = YoutubeChannelSearchTool(
|
||||
),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="google", # or openai, ollama, ...
|
||||
provider="google-generativeai", # or openai, ollama, ...
|
||||
config=dict(
|
||||
model="models/embedding-001",
|
||||
task_type="retrieval_document",
|
||||
model_name="gemini-embedding-001",
|
||||
task_type="RETRIEVAL_DOCUMENT",
|
||||
# title="Embeddings",
|
||||
),
|
||||
),
|
||||
|
||||
@@ -108,10 +108,10 @@ youtube_search_tool = YoutubeVideoSearchTool(
|
||||
),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="google", # or openai, ollama, ...
|
||||
provider="google-generativeai", # or openai, ollama, ...
|
||||
config=dict(
|
||||
model="models/embedding-001",
|
||||
task_type="retrieval_document",
|
||||
model_name="gemini-embedding-001",
|
||||
task_type="RETRIEVAL_DOCUMENT",
|
||||
# title="Embeddings",
|
||||
),
|
||||
),
|
||||
|
||||
@@ -1,28 +1,51 @@
|
||||
"""Adapter for CrewAI's native RAG system."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeAlias, TypedDict
|
||||
from typing import TYPE_CHECKING, Any, TypeAlias, TypedDict, cast
|
||||
import uuid
|
||||
|
||||
from crewai.rag.config.types import RagConfigType
|
||||
from crewai.rag.config.utils import get_rag_client
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.factory import create_client
|
||||
from crewai.rag.qdrant.config import QdrantConfig
|
||||
from crewai.rag.types import BaseRecord, SearchResult
|
||||
from pydantic import PrivateAttr
|
||||
from qdrant_client.models import VectorParams
|
||||
from typing_extensions import Unpack
|
||||
from pydantic.dataclasses import is_pydantic_dataclass
|
||||
from typing_extensions import TypeIs, Unpack
|
||||
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
from crewai_tools.rag.misc import sanitize_metadata_for_chromadb
|
||||
from crewai_tools.tools.rag.rag_tool import Adapter
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.rag.qdrant.config import QdrantConfig
|
||||
|
||||
|
||||
ContentItem: TypeAlias = str | Path | dict[str, Any]
|
||||
|
||||
|
||||
def _is_qdrant_config(config: Any) -> TypeIs[QdrantConfig]:
    """Check whether ``config`` is a Qdrant configuration instance.

    The check is duck-typed so that ``QdrantConfig`` only has to be imported
    for type checking: the object must be an instance of a pydantic dataclass
    whose ``provider`` attribute equals ``"qdrant"``.

    Args:
        config: RAG configuration object to check.

    Returns:
        True if ``config`` is an instance of a pydantic dataclass with
        ``provider == "qdrant"``, False otherwise.
    """
    # is_pydantic_dataclass() expects a class: it looks for the pydantic marker
    # in ``__dict__`` of its argument, and that marker is set on the class, not
    # on instances. Passing the instance itself would therefore never match,
    # so inspect the instance's type instead.
    if not is_pydantic_dataclass(type(config)):
        return False

    try:
        return cast(bool, config.provider == "qdrant")  # type: ignore[attr-defined]
    except (AttributeError, ImportError):
        # No ``provider`` attribute (or a lazy import failed while reading it):
        # treat the object as "not a Qdrant config".
        return False
|
||||
|
||||
|
||||
class AddDocumentParams(TypedDict, total=False):
|
||||
"""Parameters for adding documents to the RAG system."""
|
||||
|
||||
@@ -56,8 +79,9 @@ class CrewAIRagAdapter(Adapter):
|
||||
else:
|
||||
self._client = get_rag_client()
|
||||
collection_params: dict[str, Any] = {"collection_name": self.collection_name}
|
||||
if isinstance(self.config, QdrantConfig) and self.config.vectors_config:
|
||||
if isinstance(self.config.vectors_config, VectorParams):
|
||||
|
||||
if self.config is not None and _is_qdrant_config(self.config):
|
||||
if self.config.vectors_config is not None:
|
||||
collection_params["vectors_config"] = self.config.vectors_config
|
||||
self._client.get_or_create_collection(**collection_params)
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||
@@ -24,14 +25,17 @@ class PDFSearchTool(RagTool):
|
||||
"A tool that can be used to semantic search a query from a PDF's content."
|
||||
)
|
||||
args_schema: type[BaseModel] = PDFSearchToolSchema
|
||||
pdf: str | None = None
|
||||
|
||||
def __init__(self, pdf: str | None = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if pdf is not None:
|
||||
self.add(pdf)
|
||||
self.description = f"A tool that can be used to semantic search a query the {pdf} PDF's content."
|
||||
@model_validator(mode="after")
|
||||
def _configure_for_pdf(self) -> Self:
|
||||
"""Configure tool for specific PDF if provided."""
|
||||
if self.pdf is not None:
|
||||
self.add(self.pdf)
|
||||
self.description = f"A tool that can be used to semantic search a query the {self.pdf} PDF's content."
|
||||
self.args_schema = FixedPDFSearchToolSchema
|
||||
self._generate_description()
|
||||
return self
|
||||
|
||||
def add(self, pdf: str) -> None:
|
||||
super().add(pdf, data_type=DataType.PDF_FILE)
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
from crewai.rag.embeddings.types import ProviderSpec
|
||||
|
||||
from crewai_tools.tools.rag.types import RagToolConfig, VectorDbConfig
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ProviderSpec",
|
||||
"RagToolConfig",
|
||||
"VectorDbConfig",
|
||||
]
|
||||
|
||||
@@ -1,10 +1,74 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import os
|
||||
from typing import Any, cast
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from crewai.rag.embeddings.factory import get_embedding_function
|
||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
from crewai.rag.embeddings.types import ProviderSpec
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
TypeAdapter,
|
||||
ValidationError,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai_tools.tools.rag.types import RagToolConfig, VectorDbConfig
|
||||
|
||||
|
||||
def _validate_embedding_config(
|
||||
value: dict[str, Any] | ProviderSpec,
|
||||
) -> dict[str, Any] | ProviderSpec:
|
||||
"""Validate embedding config and provide clearer error messages for union validation.
|
||||
|
||||
This pre-validator catches Pydantic ValidationErrors from the ProviderSpec union
|
||||
and provides a cleaner, more focused error message that only shows the relevant
|
||||
provider's validation errors instead of all 18 union members.
|
||||
|
||||
Args:
|
||||
value: The embedding configuration dictionary or validated ProviderSpec.
|
||||
|
||||
Returns:
|
||||
A validated ProviderSpec instance, or the original value if already validated
|
||||
or missing required fields.
|
||||
|
||||
Raises:
|
||||
ValueError: If the configuration is invalid for the specified provider.
|
||||
"""
|
||||
if not isinstance(value, dict):
|
||||
return value
|
||||
|
||||
provider = value.get("provider")
|
||||
if not provider:
|
||||
return value
|
||||
|
||||
try:
|
||||
type_adapter: TypeAdapter[ProviderSpec] = TypeAdapter(ProviderSpec)
|
||||
return type_adapter.validate_python(value)
|
||||
except ValidationError as e:
|
||||
provider_key = f"{provider.lower()}providerspec"
|
||||
provider_errors = [
|
||||
err for err in e.errors() if provider_key in str(err.get("loc", "")).lower()
|
||||
]
|
||||
|
||||
if provider_errors:
|
||||
error_msgs = []
|
||||
for err in provider_errors:
|
||||
loc_parts = err["loc"]
|
||||
if str(loc_parts[0]).lower() == provider_key:
|
||||
loc_parts = loc_parts[1:]
|
||||
loc = ".".join(str(x) for x in loc_parts)
|
||||
error_msgs.append(f" - {loc}: {err['msg']}")
|
||||
|
||||
raise ValueError(
|
||||
f"Invalid configuration for embedding provider '{provider}':\n"
|
||||
+ "\n".join(error_msgs)
|
||||
) from e
|
||||
|
||||
raise
|
||||
|
||||
|
||||
class Adapter(BaseModel, ABC):
|
||||
@@ -46,139 +110,100 @@ class RagTool(BaseTool):
|
||||
summarize: bool = False
|
||||
similarity_threshold: float = 0.6
|
||||
limit: int = 5
|
||||
collection_name: str = "rag_tool_collection"
|
||||
adapter: Adapter = Field(default_factory=_AdapterPlaceholder)
|
||||
config: Any | None = None
|
||||
config: RagToolConfig = Field(
|
||||
default_factory=RagToolConfig,
|
||||
description="Configuration format accepted by RagTool.",
|
||||
)
|
||||
|
||||
@field_validator("config", mode="before")
@classmethod
def _validate_config(cls, value: Any) -> Any:
    """Validate config with improved error messages for embedding providers.

    Runs before pydantic's own validation of the ``config`` field. When the
    raw value is a dict with a truthy ``embedding_model`` entry, that entry is
    passed through ``_validate_embedding_config``.

    Args:
        value: Raw value supplied for the ``config`` field.

    Returns:
        ``value`` unchanged when it is not a dict or has no embedding model;
        otherwise a shallow copy with ``embedding_model`` replaced by the
        result of ``_validate_embedding_config``.

    Raises:
        ValueError: Propagated from ``_validate_embedding_config`` when the
            embedding configuration is invalid.
    """
    if not isinstance(value, dict):
        return value

    embedding_model = value.get("embedding_model")
    if not embedding_model:
        return value

    # Build a new dict instead of assigning into ``value`` so the dict the
    # caller passed in is not modified as a side effect of validation.
    return {**value, "embedding_model": _validate_embedding_config(embedding_model)}
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _set_default_adapter(self):
|
||||
def _ensure_adapter(self) -> Self:
|
||||
if isinstance(self.adapter, RagTool._AdapterPlaceholder):
|
||||
from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
|
||||
|
||||
parsed_config = self._parse_config(self.config)
|
||||
|
||||
provider_cfg = self._parse_config(self.config)
|
||||
self.adapter = CrewAIRagAdapter(
|
||||
collection_name="rag_tool_collection",
|
||||
collection_name=self.collection_name,
|
||||
summarize=self.summarize,
|
||||
similarity_threshold=self.similarity_threshold,
|
||||
limit=self.limit,
|
||||
config=parsed_config,
|
||||
config=provider_cfg,
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def _parse_config(self, config: Any) -> Any:
|
||||
"""Parse complex config format to extract provider-specific config.
|
||||
def _parse_config(self, config: RagToolConfig) -> Any:
|
||||
"""Normalize the RagToolConfig into a provider-specific config object.
|
||||
|
||||
Raises:
|
||||
ValueError: If the config format is invalid or uses unsupported providers.
|
||||
Defaults to 'chromadb' with no extra provider config if none is supplied.
|
||||
"""
|
||||
if config is None:
|
||||
return None
|
||||
if not config:
|
||||
return self._create_provider_config("chromadb", {}, None)
|
||||
|
||||
if isinstance(config, dict) and "provider" in config:
|
||||
return config
|
||||
vectordb_cfg = cast(VectorDbConfig, config.get("vectordb", {}))
|
||||
provider: Literal["chromadb", "qdrant"] = vectordb_cfg.get(
|
||||
"provider", "chromadb"
|
||||
)
|
||||
provider_config: dict[str, Any] = vectordb_cfg.get("config", {})
|
||||
|
||||
if isinstance(config, dict):
|
||||
if "vectordb" in config:
|
||||
vectordb_config = config["vectordb"]
|
||||
if isinstance(vectordb_config, dict) and "provider" in vectordb_config:
|
||||
provider = vectordb_config["provider"]
|
||||
provider_config = vectordb_config.get("config", {})
|
||||
supported = ("chromadb", "qdrant")
|
||||
if provider not in supported:
|
||||
raise ValueError(
|
||||
f"Unsupported vector database provider: '{provider}'. "
|
||||
f"CrewAI RAG currently supports: {', '.join(supported)}."
|
||||
)
|
||||
|
||||
supported_providers = ["chromadb", "qdrant"]
|
||||
if provider not in supported_providers:
|
||||
raise ValueError(
|
||||
f"Unsupported vector database provider: '{provider}'. "
|
||||
f"CrewAI RAG currently supports: {', '.join(supported_providers)}."
|
||||
)
|
||||
embedding_spec: ProviderSpec | None = config.get("embedding_model")
|
||||
if embedding_spec:
|
||||
embedding_spec = cast(
|
||||
ProviderSpec, _validate_embedding_config(embedding_spec)
|
||||
)
|
||||
|
||||
embedding_config = config.get("embedding_model")
|
||||
embedding_function = None
|
||||
if embedding_config and isinstance(embedding_config, dict):
|
||||
embedding_function = self._create_embedding_function(
|
||||
embedding_config, provider
|
||||
)
|
||||
|
||||
return self._create_provider_config(
|
||||
provider, provider_config, embedding_function
|
||||
)
|
||||
return None
|
||||
embedding_config = config.get("embedding_model")
|
||||
embedding_function = None
|
||||
if embedding_config and isinstance(embedding_config, dict):
|
||||
embedding_function = self._create_embedding_function(
|
||||
embedding_config, "chromadb"
|
||||
)
|
||||
|
||||
return self._create_provider_config("chromadb", {}, embedding_function)
|
||||
return config
|
||||
|
||||
@staticmethod
|
||||
def _create_embedding_function(embedding_config: dict, provider: str) -> Any:
|
||||
"""Create embedding function for the specified vector database provider."""
|
||||
embedding_provider = embedding_config.get("provider")
|
||||
embedding_model_config = embedding_config.get("config", {}).copy()
|
||||
|
||||
if "model" in embedding_model_config:
|
||||
embedding_model_config["model_name"] = embedding_model_config.pop("model")
|
||||
|
||||
factory_config = {"provider": embedding_provider, **embedding_model_config}
|
||||
|
||||
if embedding_provider == "openai" and "api_key" not in factory_config:
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if api_key:
|
||||
factory_config["api_key"] = api_key
|
||||
|
||||
if provider == "chromadb":
|
||||
return get_embedding_function(factory_config) # type: ignore[call-overload]
|
||||
|
||||
if provider == "qdrant":
|
||||
chromadb_func = get_embedding_function(factory_config) # type: ignore[call-overload]
|
||||
|
||||
def qdrant_embed_fn(text: str) -> list[float]:
|
||||
"""Embed text using ChromaDB function and convert to list of floats for Qdrant.
|
||||
|
||||
Args:
|
||||
text: The input text to embed.
|
||||
|
||||
Returns:
|
||||
A list of floats representing the embedding.
|
||||
"""
|
||||
embeddings = chromadb_func([text])
|
||||
return embeddings[0] if embeddings and len(embeddings) > 0 else []
|
||||
|
||||
return cast(Any, qdrant_embed_fn)
|
||||
|
||||
return None
|
||||
embedding_function = build_embedder(embedding_spec) if embedding_spec else None
|
||||
return self._create_provider_config(
|
||||
provider, provider_config, embedding_function
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _create_provider_config(
|
||||
provider: str, provider_config: dict, embedding_function: Any
|
||||
provider: Literal["chromadb", "qdrant"],
|
||||
provider_config: dict[str, Any],
|
||||
embedding_function: EmbeddingFunction[Any] | None,
|
||||
) -> Any:
|
||||
"""Create proper provider config object."""
|
||||
"""Instantiate provider config with optional embedding_function injected."""
|
||||
if provider == "chromadb":
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
|
||||
config_kwargs = {}
|
||||
if embedding_function:
|
||||
config_kwargs["embedding_function"] = embedding_function
|
||||
|
||||
config_kwargs.update(provider_config)
|
||||
|
||||
return ChromaDBConfig(**config_kwargs)
|
||||
kwargs = dict(provider_config)
|
||||
if embedding_function is not None:
|
||||
kwargs["embedding_function"] = embedding_function
|
||||
return ChromaDBConfig(**kwargs)
|
||||
|
||||
if provider == "qdrant":
|
||||
from crewai.rag.qdrant.config import QdrantConfig
|
||||
|
||||
config_kwargs = {}
|
||||
if embedding_function:
|
||||
config_kwargs["embedding_function"] = embedding_function
|
||||
kwargs = dict(provider_config)
|
||||
if embedding_function is not None:
|
||||
kwargs["embedding_function"] = embedding_function
|
||||
return QdrantConfig(**kwargs)
|
||||
|
||||
config_kwargs.update(provider_config)
|
||||
|
||||
return QdrantConfig(**config_kwargs)
|
||||
|
||||
return None
|
||||
raise ValueError(f"Unhandled provider: {provider}")
|
||||
|
||||
def add(
|
||||
self,
|
||||
|
||||
32
lib/crewai-tools/src/crewai_tools/tools/rag/types.py
Normal file
32
lib/crewai-tools/src/crewai_tools/tools/rag/types.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Type definitions for RAG tool configuration."""
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from crewai.rag.embeddings.types import ProviderSpec
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
class VectorDbConfig(TypedDict, total=False):
    """Configuration for vector database provider.

    Both keys are optional (``total=False``) so that the fallbacks applied by
    ``RagTool._parse_config`` -- which reads this mapping with
    ``.get("provider", "chromadb")`` and ``.get("config", {})`` -- are actually
    reachable instead of being rejected as missing required keys.

    Attributes:
        provider: RAG provider literal.
        config: RAG configuration options.
    """

    provider: Literal["chromadb", "qdrant"]
    config: dict[str, Any]
|
||||
|
||||
|
||||
class RagToolConfig(TypedDict, total=False):
    """Top-level configuration mapping accepted by RAG tools.

    Declared with ``total=False``, so each key may be omitted.

    Attributes:
        embedding_model: Embedding provider specification.
        vectordb: Vector database provider and its options.
    """

    embedding_model: ProviderSpec
    vectordb: VectorDbConfig
|
||||
@@ -1,4 +1,5 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||
|
||||
@@ -24,14 +25,17 @@ class TXTSearchTool(RagTool):
|
||||
"A tool that can be used to semantic search a query from a txt's content."
|
||||
)
|
||||
args_schema: type[BaseModel] = TXTSearchToolSchema
|
||||
txt: str | None = None
|
||||
|
||||
def __init__(self, txt: str | None = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if txt is not None:
|
||||
self.add(txt)
|
||||
self.description = f"A tool that can be used to semantic search a query the {txt} txt's content."
|
||||
@model_validator(mode="after")
|
||||
def _configure_for_txt(self) -> Self:
|
||||
"""Configure tool for specific TXT file if provided."""
|
||||
if self.txt is not None:
|
||||
self.add(self.txt)
|
||||
self.description = f"A tool that can be used to semantic search a query the {self.txt} txt's content."
|
||||
self.args_schema = FixedTXTSearchToolSchema
|
||||
self._generate_description()
|
||||
return self
|
||||
|
||||
def _run( # type: ignore[override]
|
||||
self,
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
"""Tests for RAG tool with mocked embeddings and vector database."""
|
||||
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import cast
|
||||
@@ -117,15 +115,15 @@ def test_rag_tool_with_file(
|
||||
assert "Python is a programming language" in result
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.rag.rag_tool.RagTool._create_embedding_function")
|
||||
@patch("crewai_tools.tools.rag.rag_tool.build_embedder")
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
|
||||
def test_rag_tool_with_custom_embeddings(
|
||||
mock_create_client: Mock, mock_create_embedding: Mock
|
||||
mock_create_client: Mock, mock_build_embedder: Mock
|
||||
) -> None:
|
||||
"""Test RagTool with custom embeddings configuration to ensure no API calls."""
|
||||
mock_embedding_func = MagicMock()
|
||||
mock_embedding_func.return_value = [[0.2] * 1536]
|
||||
mock_create_embedding.return_value = mock_embedding_func
|
||||
mock_build_embedder.return_value = mock_embedding_func
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||
@@ -153,7 +151,7 @@ def test_rag_tool_with_custom_embeddings(
|
||||
assert "Relevant Content:" in result
|
||||
assert "Test content" in result
|
||||
|
||||
mock_create_embedding.assert_called()
|
||||
mock_build_embedder.assert_called()
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.get_rag_client")
|
||||
@@ -176,3 +174,128 @@ def test_rag_tool_no_results(
|
||||
result = tool._run(query="Non-existent content")
|
||||
assert "Relevant Content:" in result
|
||||
assert "No relevant content found" in result
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_rag_tool_with_azure_config_without_env_vars(
    mock_create_client: Mock,
) -> None:
    """Verify RagTool builds its adapter from an explicit Azure embedding config.

    Regression test: RAG tools used to ignore the embedding settings passed
    through ``config`` and demanded environment variables such as
    EMBEDDINGS_OPENAI_API_KEY instead.
    """
    fake_embedder = MagicMock(return_value=[[0.1] * 1536])

    fake_client = MagicMock()
    fake_client.get_or_create_collection = MagicMock(return_value=None)
    fake_client.add_documents = MagicMock(return_value=None)
    mock_create_client.return_value = fake_client

    # Explicit Azure credentials: no environment variables should be needed.
    azure_embedding = {
        "provider": "azure",
        "config": {
            "model": "text-embedding-3-small",
            "api_key": "test-api-key",
            "api_base": "https://test.openai.azure.com/",
            "api_version": "2024-02-01",
            "api_type": "azure",
            "deployment_id": "test-deployment",
        },
    }

    # Stub the embedder factory so no real API call is made.
    with patch(
        "crewai_tools.tools.rag.rag_tool.build_embedder",
        return_value=fake_embedder,
    ):

        class MyTool(RagTool):
            pass

        # Must not raise a validation error about missing env vars.
        tool = MyTool(config={"embedding_model": azure_embedding})

    assert tool.adapter is not None
    assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_rag_tool_with_openai_config_without_env_vars(
    mock_create_client: Mock,
) -> None:
    """Verify RagTool builds its adapter from an explicit OpenAI embedding config."""
    fake_embedder = MagicMock(return_value=[[0.1] * 1536])

    fake_client = MagicMock()
    fake_client.get_or_create_collection = MagicMock(return_value=None)
    mock_create_client.return_value = fake_client

    openai_embedding = {
        "provider": "openai",
        "config": {
            "model": "text-embedding-3-small",
            "api_key": "sk-test123",
        },
    }

    # Stub the embedder factory so no real API call is made.
    with patch(
        "crewai_tools.tools.rag.rag_tool.build_embedder",
        return_value=fake_embedder,
    ):

        class MyTool(RagTool):
            pass

        tool = MyTool(config={"embedding_model": openai_embedding})

    assert tool.adapter is not None
    assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_rag_tool_config_with_qdrant_and_azure_embeddings(
    mock_create_client: Mock,
) -> None:
    """Verify RagTool accepts a Qdrant vector DB combined with Azure embeddings."""
    fake_embedder = MagicMock(return_value=[[0.1] * 1536])

    fake_client = MagicMock()
    fake_client.get_or_create_collection = MagicMock(return_value=None)
    mock_create_client.return_value = fake_client

    azure_embedding = {
        "provider": "azure",
        "config": {
            "model": "text-embedding-3-large",
            "api_key": "test-key",
            "api_base": "https://test.openai.azure.com/",
            "api_version": "2024-02-01",
            "deployment_id": "test-deployment",
        },
    }
    tool_config = {
        "vectordb": {"provider": "qdrant", "config": {}},
        "embedding_model": azure_embedding,
    }

    # Stub the embedder factory so no real API call is made.
    with patch(
        "crewai_tools.tools.rag.rag_tool.build_embedder",
        return_value=fake_embedder,
    ):

        class MyTool(RagTool):
            pass

        tool = MyTool(config=tool_config)

    assert tool.adapter is not None
    assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||
|
||||
66
lib/crewai-tools/tests/tools/rag/test_rag_tool_validation.py
Normal file
66
lib/crewai-tools/tests/tools/rag/test_rag_tool_validation.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""Tests for improved RAG tool validation error messages."""
|
||||
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_azure_missing_deployment_id_gives_clear_error(mock_create_client: Mock) -> None:
    """An Azure config without deployment_id must fail with a focused message.

    The message has to mention the Azure provider and the missing field, and
    must not drag in errors from unrelated embedding providers.
    """
    fake_client = MagicMock()
    fake_client.get_or_create_collection = MagicMock(return_value=None)
    mock_create_client.return_value = fake_client

    class MyTool(RagTool):
        pass

    # Azure settings with ``deployment_id`` deliberately left out.
    incomplete_azure = {
        "provider": "azure",
        "config": {
            "api_base": "http://localhost:4000/v1",
            "api_key": "test-key",
            "api_version": "2024-02-01",
        },
    }

    with pytest.raises(ValueError) as exc_info:
        MyTool(config={"embedding_model": incomplete_azure})

    lowered = str(exc_info.value).lower()
    for expected in ("azure", "deployment_id"):
        assert expected in lowered
    for unrelated in ("bedrock", "cohere", "huggingface"):
        assert unrelated not in lowered
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_valid_azure_config_works(mock_create_client: Mock) -> None:
    """A complete Azure embedding config must be accepted without raising."""
    fake_client = MagicMock()
    fake_client.get_or_create_collection = MagicMock(return_value=None)
    mock_create_client.return_value = fake_client

    class MyTool(RagTool):
        pass

    complete_azure = {
        "provider": "azure",
        "config": {
            "api_base": "http://localhost:4000/v1",
            "api_key": "test-key",
            "api_version": "2024-02-01",
            "deployment_id": "text-embedding-3-small",
        },
    }

    tool = MyTool(config={"embedding_model": complete_azure})
    assert tool is not None
|
||||
116
lib/crewai-tools/tests/tools/test_pdf_search_tool_config.py
Normal file
116
lib/crewai-tools/tests/tools/test_pdf_search_tool_config.py
Normal file
@@ -0,0 +1,116 @@
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
|
||||
from crewai_tools.tools.pdf_search_tool.pdf_search_tool import PDFSearchTool
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_pdf_search_tool_with_azure_config_without_env_vars(
    mock_create_client: Mock,
) -> None:
    """Verify PDFSearchTool honours an explicit Azure embedding config.

    Regression test for the reported failure where constructing the tool
    raised:
        pydantic_core._pydantic_core.ValidationError: 1 validation error for PDFSearchTool
        EMBEDDINGS_OPENAI_API_KEY
          Field required [type=missing, input_value={}, input_type=dict]
    """
    fake_embedder = MagicMock(return_value=[[0.1] * 1536])

    fake_client = MagicMock()
    fake_client.get_or_create_collection = MagicMock(return_value=None)
    mock_create_client.return_value = fake_client

    # Same config shape as in the bug report.
    azure_embedding = {
        "provider": "azure",
        "config": {
            "model": "text-embedding-3-small",
            "api_key": "test-litellm-api-key",
            "api_base": "https://test.litellm.proxy/",
            "api_version": "2024-02-01",
            "api_type": "azure",
            "deployment_id": "test-deployment",
        },
    }

    # Stub the embedder factory so no real API call is made.
    with patch(
        "crewai_tools.tools.rag.rag_tool.build_embedder",
        return_value=fake_embedder,
    ):
        # Must not raise a validation error about missing env vars.
        tool = PDFSearchTool(config={"embedding_model": azure_embedding})

    assert tool.adapter is not None
    assert isinstance(tool.adapter, CrewAIRagAdapter)
    assert tool.name == "Search a PDF's content"
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_pdf_search_tool_with_openai_config_without_env_vars(
    mock_create_client: Mock,
) -> None:
    """Verify PDFSearchTool honours an explicit OpenAI embedding config."""
    fake_embedder = MagicMock(return_value=[[0.1] * 1536])

    fake_client = MagicMock()
    fake_client.get_or_create_collection = MagicMock(return_value=None)
    mock_create_client.return_value = fake_client

    openai_embedding = {
        "provider": "openai",
        "config": {
            "model": "text-embedding-3-small",
            "api_key": "sk-test123",
        },
    }

    # Stub the embedder factory so no real API call is made.
    with patch(
        "crewai_tools.tools.rag.rag_tool.build_embedder",
        return_value=fake_embedder,
    ):
        tool = PDFSearchTool(config={"embedding_model": openai_embedding})

    assert tool.adapter is not None
    assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_pdf_search_tool_with_vectordb_and_embedding_config(
    mock_create_client: Mock,
) -> None:
    """Verify PDFSearchTool accepts vector DB and embedding settings together."""
    fake_embedder = MagicMock(return_value=[[0.1] * 1536])

    fake_client = MagicMock()
    fake_client.get_or_create_collection = MagicMock(return_value=None)
    mock_create_client.return_value = fake_client

    tool_config = {
        "vectordb": {"provider": "chromadb", "config": {}},
        "embedding_model": {
            "provider": "openai",
            "config": {
                "model": "text-embedding-3-large",
                "api_key": "sk-test-key",
            },
        },
    }

    # Stub the embedder factory so no real API call is made.
    with patch(
        "crewai_tools.tools.rag.rag_tool.build_embedder",
        return_value=fake_embedder,
    ):
        tool = PDFSearchTool(config=tool_config)

    assert tool.adapter is not None
    assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||
104
lib/crewai-tools/tests/tools/test_txt_search_tool_config.py
Normal file
104
lib/crewai-tools/tests/tools/test_txt_search_tool_config.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
|
||||
from crewai_tools.tools.txt_search_tool.txt_search_tool import TXTSearchTool
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_txt_search_tool_with_azure_config_without_env_vars(
    mock_create_client: Mock,
) -> None:
    """Verify TXTSearchTool honours an explicit Azure embedding config."""
    fake_embedder = MagicMock(return_value=[[0.1] * 1536])

    fake_client = MagicMock()
    fake_client.get_or_create_collection = MagicMock(return_value=None)
    mock_create_client.return_value = fake_client

    azure_embedding = {
        "provider": "azure",
        "config": {
            "model": "text-embedding-3-small",
            "api_key": "test-api-key",
            "api_base": "https://test.openai.azure.com/",
            "api_version": "2024-02-01",
            "api_type": "azure",
            "deployment_id": "test-deployment",
        },
    }

    # Stub the embedder factory so no real API call is made.
    with patch(
        "crewai_tools.tools.rag.rag_tool.build_embedder",
        return_value=fake_embedder,
    ):
        # Must not raise a validation error about missing env vars.
        tool = TXTSearchTool(config={"embedding_model": azure_embedding})

    assert tool.adapter is not None
    assert isinstance(tool.adapter, CrewAIRagAdapter)
    assert tool.name == "Search a txt's content"
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_txt_search_tool_with_openai_config_without_env_vars(
    mock_create_client: Mock,
) -> None:
    """Verify TXTSearchTool honours an explicit OpenAI embedding config."""
    fake_embedder = MagicMock(return_value=[[0.1] * 1536])

    fake_client = MagicMock()
    fake_client.get_or_create_collection = MagicMock(return_value=None)
    mock_create_client.return_value = fake_client

    openai_embedding = {
        "provider": "openai",
        "config": {
            "model": "text-embedding-3-small",
            "api_key": "sk-test123",
        },
    }

    # Stub the embedder factory so no real API call is made.
    with patch(
        "crewai_tools.tools.rag.rag_tool.build_embedder",
        return_value=fake_embedder,
    ):
        tool = TXTSearchTool(config={"embedding_model": openai_embedding})

    assert tool.adapter is not None
    assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
def test_txt_search_tool_with_cohere_config(mock_create_client: Mock) -> None:
    """Verify TXTSearchTool accepts a Cohere embedding provider config."""
    fake_embedder = MagicMock(return_value=[[0.1] * 1024])

    fake_client = MagicMock()
    fake_client.get_or_create_collection = MagicMock(return_value=None)
    mock_create_client.return_value = fake_client

    cohere_embedding = {
        "provider": "cohere",
        "config": {
            "model": "embed-english-v3.0",
            "api_key": "test-cohere-key",
        },
    }

    # Stub the embedder factory so no real API call is made.
    with patch(
        "crewai_tools.tools.rag.rag_tool.build_embedder",
        return_value=fake_embedder,
    ):
        tool = TXTSearchTool(config={"embedding_model": cohere_embedding})

    assert tool.adapter is not None
    assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||
@@ -91,6 +91,7 @@ PROVIDER_PATHS = {
|
||||
"cohere": "crewai.rag.embeddings.providers.cohere.cohere_provider.CohereProvider",
|
||||
"custom": "crewai.rag.embeddings.providers.custom.custom_provider.CustomProvider",
|
||||
"google-generativeai": "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider",
|
||||
"google": "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider",
|
||||
"google-vertex": "crewai.rag.embeddings.providers.google.vertex.VertexAIProvider",
|
||||
"huggingface": "crewai.rag.embeddings.providers.huggingface.huggingface_provider.HuggingFaceProvider",
|
||||
"instructor": "crewai.rag.embeddings.providers.instructor.instructor_provider.InstructorProvider",
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any
|
||||
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
|
||||
AmazonBedrockEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -21,7 +21,7 @@ def create_aws_session() -> Any:
|
||||
ValueError: If AWS session creation fails
|
||||
"""
|
||||
try:
|
||||
import boto3 # type: ignore[import]
|
||||
import boto3
|
||||
|
||||
return boto3.Session()
|
||||
except ImportError as e:
|
||||
@@ -46,7 +46,12 @@ class BedrockProvider(BaseEmbeddingsProvider[AmazonBedrockEmbeddingFunction]):
|
||||
model_name: str = Field(
|
||||
default="amazon.titan-embed-text-v1",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_BEDROCK_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_BEDROCK_MODEL_NAME",
|
||||
"BEDROCK_MODEL_NAME",
|
||||
"AWS_BEDROCK_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
session: Any = Field(
|
||||
default_factory=create_aws_session, description="AWS session object"
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.cohere_embedding_function import (
|
||||
CohereEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -15,10 +15,14 @@ class CohereProvider(BaseEmbeddingsProvider[CohereEmbeddingFunction]):
|
||||
default=CohereEmbeddingFunction, description="Cohere embedding function class"
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Cohere API key", validation_alias="EMBEDDINGS_COHERE_API_KEY"
|
||||
description="Cohere API key",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_COHERE_API_KEY", "COHERE_API_KEY"),
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="large",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_COHERE_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_COHERE_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
"""Google Generative AI embeddings provider."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GoogleGenerativeAiEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -15,16 +17,27 @@ class GenerativeAiProvider(BaseEmbeddingsProvider[GoogleGenerativeAiEmbeddingFun
|
||||
default=GoogleGenerativeAiEmbeddingFunction,
|
||||
description="Google Generative AI embedding function class",
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="models/embedding-001",
|
||||
model_name: Literal[
|
||||
"gemini-embedding-001", "text-embedding-005", "text-multilingual-embedding-002"
|
||||
] = Field(
|
||||
default="gemini-embedding-001",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_GOOGLE_GENERATIVE_AI_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_GENERATIVE_AI_MODEL_NAME", "model"
|
||||
),
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Google API key", validation_alias="EMBEDDINGS_GOOGLE_API_KEY"
|
||||
description="Google API key",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_API_KEY", "GOOGLE_API_KEY", "GEMINI_API_KEY"
|
||||
),
|
||||
)
|
||||
task_type: str = Field(
|
||||
default="RETRIEVAL_DOCUMENT",
|
||||
description="Task type for embeddings",
|
||||
validation_alias="EMBEDDINGS_GOOGLE_GENERATIVE_AI_TASK_TYPE",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_GENERATIVE_AI_TASK_TYPE",
|
||||
"GOOGLE_GENERATIVE_AI_TASK_TYPE",
|
||||
"GEMINI_TASK_TYPE",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -6,10 +6,23 @@ from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class GenerativeAiProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for Google Generative AI provider."""
|
||||
"""Configuration for Google Generative AI provider.
|
||||
|
||||
Attributes:
|
||||
api_key: Google API key for authentication.
|
||||
model_name: Embedding model name.
|
||||
task_type: Task type for embeddings. Default is "RETRIEVAL_DOCUMENT".
|
||||
"""
|
||||
|
||||
api_key: str
|
||||
model_name: Annotated[str, "models/embedding-001"]
|
||||
model_name: Annotated[
|
||||
Literal[
|
||||
"gemini-embedding-001",
|
||||
"text-embedding-005",
|
||||
"text-multilingual-embedding-002",
|
||||
],
|
||||
"gemini-embedding-001",
|
||||
]
|
||||
task_type: Annotated[str, "RETRIEVAL_DOCUMENT"]
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GoogleVertexEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -18,18 +18,29 @@ class VertexAIProvider(BaseEmbeddingsProvider[GoogleVertexEmbeddingFunction]):
|
||||
model_name: str = Field(
|
||||
default="textembedding-gecko",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME",
|
||||
"GOOGLE_VERTEX_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Google API key", validation_alias="EMBEDDINGS_GOOGLE_CLOUD_API_KEY"
|
||||
description="Google API key",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_CLOUD_API_KEY", "GOOGLE_CLOUD_API_KEY"
|
||||
),
|
||||
)
|
||||
project_id: str = Field(
|
||||
default="cloud-large-language-models",
|
||||
description="GCP project ID",
|
||||
validation_alias="EMBEDDINGS_GOOGLE_CLOUD_PROJECT",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_CLOUD_PROJECT", "GOOGLE_CLOUD_PROJECT"
|
||||
),
|
||||
)
|
||||
region: str = Field(
|
||||
default="us-central1",
|
||||
description="GCP region",
|
||||
validation_alias="EMBEDDINGS_GOOGLE_CLOUD_REGION",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_CLOUD_REGION", "GOOGLE_CLOUD_REGION"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||
HuggingFaceEmbeddingServer,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -16,5 +16,6 @@ class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingServer]):
|
||||
description="HuggingFace embedding function class",
|
||||
)
|
||||
url: str = Field(
|
||||
description="HuggingFace API URL", validation_alias="EMBEDDINGS_HUGGINGFACE_URL"
|
||||
description="HuggingFace API URL",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_HUGGINGFACE_URL", "HUGGINGFACE_URL"),
|
||||
)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
from pydantic import AliasChoices, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
@@ -21,7 +21,10 @@ class WatsonXProvider(BaseEmbeddingsProvider[WatsonXEmbeddingFunction]):
|
||||
default=WatsonXEmbeddingFunction, description="WatsonX embedding function class"
|
||||
)
|
||||
model_id: str = Field(
|
||||
description="WatsonX model ID", validation_alias="EMBEDDINGS_WATSONX_MODEL_ID"
|
||||
description="WatsonX model ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_MODEL_ID", "WATSONX_MODEL_ID"
|
||||
),
|
||||
)
|
||||
params: dict[str, str | dict[str, str]] | None = Field(
|
||||
default=None, description="Additional parameters"
|
||||
@@ -30,109 +33,143 @@ class WatsonXProvider(BaseEmbeddingsProvider[WatsonXEmbeddingFunction]):
|
||||
project_id: str | None = Field(
|
||||
default=None,
|
||||
description="WatsonX project ID",
|
||||
validation_alias="EMBEDDINGS_WATSONX_PROJECT_ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_PROJECT_ID", "WATSONX_PROJECT_ID"
|
||||
),
|
||||
)
|
||||
space_id: str | None = Field(
|
||||
default=None,
|
||||
description="WatsonX space ID",
|
||||
validation_alias="EMBEDDINGS_WATSONX_SPACE_ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_SPACE_ID", "WATSONX_SPACE_ID"
|
||||
),
|
||||
)
|
||||
api_client: Any | None = Field(default=None, description="WatsonX API client")
|
||||
verify: bool | str | None = Field(
|
||||
default=None,
|
||||
description="SSL verification",
|
||||
validation_alias="EMBEDDINGS_WATSONX_VERIFY",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_VERIFY", "WATSONX_VERIFY"),
|
||||
)
|
||||
persistent_connection: bool = Field(
|
||||
default=True,
|
||||
description="Use persistent connection",
|
||||
validation_alias="EMBEDDINGS_WATSONX_PERSISTENT_CONNECTION",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_PERSISTENT_CONNECTION", "WATSONX_PERSISTENT_CONNECTION"
|
||||
),
|
||||
)
|
||||
batch_size: int = Field(
|
||||
default=100,
|
||||
description="Batch size for processing",
|
||||
validation_alias="EMBEDDINGS_WATSONX_BATCH_SIZE",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_BATCH_SIZE", "WATSONX_BATCH_SIZE"
|
||||
),
|
||||
)
|
||||
concurrency_limit: int = Field(
|
||||
default=10,
|
||||
description="Concurrency limit",
|
||||
validation_alias="EMBEDDINGS_WATSONX_CONCURRENCY_LIMIT",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_CONCURRENCY_LIMIT", "WATSONX_CONCURRENCY_LIMIT"
|
||||
),
|
||||
)
|
||||
max_retries: int | None = Field(
|
||||
default=None,
|
||||
description="Maximum retries",
|
||||
validation_alias="EMBEDDINGS_WATSONX_MAX_RETRIES",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_MAX_RETRIES", "WATSONX_MAX_RETRIES"
|
||||
),
|
||||
)
|
||||
delay_time: float | None = Field(
|
||||
default=None,
|
||||
description="Delay time between retries",
|
||||
validation_alias="EMBEDDINGS_WATSONX_DELAY_TIME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_DELAY_TIME", "WATSONX_DELAY_TIME"
|
||||
),
|
||||
)
|
||||
retry_status_codes: list[int] | None = Field(
|
||||
default=None, description="HTTP status codes to retry on"
|
||||
)
|
||||
url: str = Field(
|
||||
description="WatsonX API URL", validation_alias="EMBEDDINGS_WATSONX_URL"
|
||||
description="WatsonX API URL",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_URL", "WATSONX_URL"),
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="WatsonX API key", validation_alias="EMBEDDINGS_WATSONX_API_KEY"
|
||||
description="WatsonX API key",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_API_KEY", "WATSONX_API_KEY"),
|
||||
)
|
||||
name: str | None = Field(
|
||||
default=None,
|
||||
description="Service name",
|
||||
validation_alias="EMBEDDINGS_WATSONX_NAME",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_NAME", "WATSONX_NAME"),
|
||||
)
|
||||
iam_serviceid_crn: str | None = Field(
|
||||
default=None,
|
||||
description="IAM service ID CRN",
|
||||
validation_alias="EMBEDDINGS_WATSONX_IAM_SERVICEID_CRN",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_IAM_SERVICEID_CRN", "WATSONX_IAM_SERVICEID_CRN"
|
||||
),
|
||||
)
|
||||
trusted_profile_id: str | None = Field(
|
||||
default=None,
|
||||
description="Trusted profile ID",
|
||||
validation_alias="EMBEDDINGS_WATSONX_TRUSTED_PROFILE_ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_TRUSTED_PROFILE_ID", "WATSONX_TRUSTED_PROFILE_ID"
|
||||
),
|
||||
)
|
||||
token: str | None = Field(
|
||||
default=None,
|
||||
description="Bearer token",
|
||||
validation_alias="EMBEDDINGS_WATSONX_TOKEN",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_TOKEN", "WATSONX_TOKEN"),
|
||||
)
|
||||
projects_token: str | None = Field(
|
||||
default=None,
|
||||
description="Projects token",
|
||||
validation_alias="EMBEDDINGS_WATSONX_PROJECTS_TOKEN",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_PROJECTS_TOKEN", "WATSONX_PROJECTS_TOKEN"
|
||||
),
|
||||
)
|
||||
username: str | None = Field(
|
||||
default=None,
|
||||
description="Username",
|
||||
validation_alias="EMBEDDINGS_WATSONX_USERNAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_USERNAME", "WATSONX_USERNAME"
|
||||
),
|
||||
)
|
||||
password: str | None = Field(
|
||||
default=None,
|
||||
description="Password",
|
||||
validation_alias="EMBEDDINGS_WATSONX_PASSWORD",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_PASSWORD", "WATSONX_PASSWORD"
|
||||
),
|
||||
)
|
||||
instance_id: str | None = Field(
|
||||
default=None,
|
||||
description="Service instance ID",
|
||||
validation_alias="EMBEDDINGS_WATSONX_INSTANCE_ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_INSTANCE_ID", "WATSONX_INSTANCE_ID"
|
||||
),
|
||||
)
|
||||
version: str | None = Field(
|
||||
default=None,
|
||||
description="API version",
|
||||
validation_alias="EMBEDDINGS_WATSONX_VERSION",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_VERSION", "WATSONX_VERSION"),
|
||||
)
|
||||
bedrock_url: str | None = Field(
|
||||
default=None,
|
||||
description="Bedrock URL",
|
||||
validation_alias="EMBEDDINGS_WATSONX_BEDROCK_URL",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_BEDROCK_URL", "WATSONX_BEDROCK_URL"
|
||||
),
|
||||
)
|
||||
platform_url: str | None = Field(
|
||||
default=None,
|
||||
description="Platform URL",
|
||||
validation_alias="EMBEDDINGS_WATSONX_PLATFORM_URL",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_PLATFORM_URL", "WATSONX_PLATFORM_URL"
|
||||
),
|
||||
)
|
||||
proxies: dict[str, Any] | None = Field(
|
||||
default=None, description="Proxy configuration"
|
||||
)
|
||||
proxies: dict | None = Field(default=None, description="Proxy configuration")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_space_or_project(self) -> Self:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.instructor_embedding_function import (
|
||||
InstructorEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -18,15 +18,23 @@ class InstructorProvider(BaseEmbeddingsProvider[InstructorEmbeddingFunction]):
|
||||
model_name: str = Field(
|
||||
default="hkunlp/instructor-base",
|
||||
description="Model name to use",
|
||||
validation_alias="EMBEDDINGS_INSTRUCTOR_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_INSTRUCTOR_MODEL_NAME",
|
||||
"INSTRUCTOR_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
device: str = Field(
|
||||
default="cpu",
|
||||
description="Device to run model on (cpu or cuda)",
|
||||
validation_alias="EMBEDDINGS_INSTRUCTOR_DEVICE",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_INSTRUCTOR_DEVICE", "INSTRUCTOR_DEVICE"
|
||||
),
|
||||
)
|
||||
instruction: str | None = Field(
|
||||
default=None,
|
||||
description="Instruction for embeddings",
|
||||
validation_alias="EMBEDDINGS_INSTRUCTOR_INSTRUCTION",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_INSTRUCTOR_INSTRUCTION", "INSTRUCTOR_INSTRUCTION"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.jina_embedding_function import (
|
||||
JinaEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -15,10 +15,15 @@ class JinaProvider(BaseEmbeddingsProvider[JinaEmbeddingFunction]):
|
||||
default=JinaEmbeddingFunction, description="Jina embedding function class"
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Jina API key", validation_alias="EMBEDDINGS_JINA_API_KEY"
|
||||
description="Jina API key",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_JINA_API_KEY", "JINA_API_KEY"),
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="jina-embeddings-v2-base-en",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_JINA_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_JINA_MODEL_NAME",
|
||||
"JINA_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -18,27 +18,39 @@ class AzureProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
|
||||
description="Azure OpenAI embedding function class",
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Azure API key", validation_alias="EMBEDDINGS_OPENAI_API_KEY"
|
||||
description="Azure API key",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_KEY", "OPENAI_API_KEY"),
|
||||
)
|
||||
api_base: str | None = Field(
|
||||
default=None,
|
||||
description="Azure endpoint URL",
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_BASE",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_BASE", "OPENAI_API_BASE"),
|
||||
)
|
||||
api_type: str = Field(
|
||||
default="azure",
|
||||
description="API type for Azure",
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_TYPE",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_API_TYPE", "OPENAI_API_TYPE", "AZURE_OPENAI_API_TYPE"
|
||||
),
|
||||
)
|
||||
api_version: str | None = Field(
|
||||
default=None,
|
||||
default="2024-02-01",
|
||||
description="Azure API version",
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_VERSION",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_API_VERSION",
|
||||
"OPENAI_API_VERSION",
|
||||
"AZURE_OPENAI_API_VERSION",
|
||||
),
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="text-embedding-ada-002",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_OPENAI_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_MODEL_NAME",
|
||||
"OPENAI_MODEL_NAME",
|
||||
"AZURE_OPENAI_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
default_headers: dict[str, Any] | None = Field(
|
||||
default=None, description="Default headers for API requests"
|
||||
@@ -46,15 +58,26 @@ class AzureProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
|
||||
dimensions: int | None = Field(
|
||||
default=None,
|
||||
description="Embedding dimensions",
|
||||
validation_alias="EMBEDDINGS_OPENAI_DIMENSIONS",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_DIMENSIONS",
|
||||
"OPENAI_DIMENSIONS",
|
||||
"AZURE_OPENAI_DIMENSIONS",
|
||||
),
|
||||
)
|
||||
deployment_id: str | None = Field(
|
||||
default=None,
|
||||
deployment_id: str = Field(
|
||||
description="Azure deployment ID",
|
||||
validation_alias="EMBEDDINGS_OPENAI_DEPLOYMENT_ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_DEPLOYMENT_ID",
|
||||
"AZURE_OPENAI_DEPLOYMENT",
|
||||
"AZURE_DEPLOYMENT_ID",
|
||||
),
|
||||
)
|
||||
organization_id: str | None = Field(
|
||||
default=None,
|
||||
description="Organization ID",
|
||||
validation_alias="EMBEDDINGS_OPENAI_ORGANIZATION_ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_ORGANIZATION_ID",
|
||||
"OPENAI_ORGANIZATION_ID",
|
||||
"AZURE_OPENAI_ORGANIZATION_ID",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -15,7 +15,7 @@ class AzureProviderConfig(TypedDict, total=False):
|
||||
model_name: Annotated[str, "text-embedding-ada-002"]
|
||||
default_headers: dict[str, Any]
|
||||
dimensions: int
|
||||
deployment_id: str
|
||||
deployment_id: Required[str]
|
||||
organization_id: str
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.ollama_embedding_function import (
|
||||
OllamaEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -17,9 +17,14 @@ class OllamaProvider(BaseEmbeddingsProvider[OllamaEmbeddingFunction]):
|
||||
url: str = Field(
|
||||
default="http://localhost:11434/api/embeddings",
|
||||
description="Ollama API endpoint URL",
|
||||
validation_alias="EMBEDDINGS_OLLAMA_URL",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_OLLAMA_URL", "OLLAMA_URL"),
|
||||
)
|
||||
model_name: str = Field(
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_OLLAMA_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OLLAMA_MODEL_NAME",
|
||||
"OLLAMA_MODEL_NAME",
|
||||
"OLLAMA_MODEL",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""ONNX embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -15,5 +15,7 @@ class ONNXProvider(BaseEmbeddingsProvider[ONNXMiniLM_L6_V2]):
|
||||
preferred_providers: list[str] | None = Field(
|
||||
default=None,
|
||||
description="Preferred ONNX execution providers",
|
||||
validation_alias="EMBEDDINGS_ONNX_PREFERRED_PROVIDERS",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_ONNX_PREFERRED_PROVIDERS", "ONNX_PREFERRED_PROVIDERS"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -20,27 +20,33 @@ class OpenAIProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="OpenAI API key",
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_KEY",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_KEY", "OPENAI_API_KEY"),
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="text-embedding-ada-002",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_OPENAI_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_MODEL_NAME",
|
||||
"OPENAI_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
api_base: str | None = Field(
|
||||
default=None,
|
||||
description="Base URL for API requests",
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_BASE",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_BASE", "OPENAI_API_BASE"),
|
||||
)
|
||||
api_type: str | None = Field(
|
||||
default=None,
|
||||
description="API type (e.g., 'azure')",
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_TYPE",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_TYPE", "OPENAI_API_TYPE"),
|
||||
)
|
||||
api_version: str | None = Field(
|
||||
default=None,
|
||||
description="API version",
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_VERSION",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_API_VERSION", "OPENAI_API_VERSION"
|
||||
),
|
||||
)
|
||||
default_headers: dict[str, Any] | None = Field(
|
||||
default=None, description="Default headers for API requests"
|
||||
@@ -48,15 +54,21 @@ class OpenAIProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
|
||||
dimensions: int | None = Field(
|
||||
default=None,
|
||||
description="Embedding dimensions",
|
||||
validation_alias="EMBEDDINGS_OPENAI_DIMENSIONS",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_DIMENSIONS", "OPENAI_DIMENSIONS"
|
||||
),
|
||||
)
|
||||
deployment_id: str | None = Field(
|
||||
default=None,
|
||||
description="Azure deployment ID",
|
||||
validation_alias="EMBEDDINGS_OPENAI_DEPLOYMENT_ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_DEPLOYMENT_ID", "OPENAI_DEPLOYMENT_ID"
|
||||
),
|
||||
)
|
||||
organization_id: str | None = Field(
|
||||
default=None,
|
||||
description="OpenAI organization ID",
|
||||
validation_alias="EMBEDDINGS_OPENAI_ORGANIZATION_ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_ORGANIZATION_ID", "OPENAI_ORGANIZATION_ID"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.open_clip_embedding_function import (
|
||||
OpenCLIPEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -18,15 +18,21 @@ class OpenCLIPProvider(BaseEmbeddingsProvider[OpenCLIPEmbeddingFunction]):
|
||||
model_name: str = Field(
|
||||
default="ViT-B-32",
|
||||
description="Model name to use",
|
||||
validation_alias="EMBEDDINGS_OPENCLIP_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENCLIP_MODEL_NAME",
|
||||
"OPENCLIP_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
checkpoint: str = Field(
|
||||
default="laion2b_s34b_b79k",
|
||||
description="Model checkpoint",
|
||||
validation_alias="EMBEDDINGS_OPENCLIP_CHECKPOINT",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENCLIP_CHECKPOINT", "OPENCLIP_CHECKPOINT"
|
||||
),
|
||||
)
|
||||
device: str | None = Field(
|
||||
default="cpu",
|
||||
description="Device to run model on",
|
||||
validation_alias="EMBEDDINGS_OPENCLIP_DEVICE",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_OPENCLIP_DEVICE", "OPENCLIP_DEVICE"),
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.roboflow_embedding_function import (
|
||||
RoboflowEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -18,10 +18,14 @@ class RoboflowProvider(BaseEmbeddingsProvider[RoboflowEmbeddingFunction]):
|
||||
api_key: str = Field(
|
||||
default="",
|
||||
description="Roboflow API key",
|
||||
validation_alias="EMBEDDINGS_ROBOFLOW_API_KEY",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_ROBOFLOW_API_KEY", "ROBOFLOW_API_KEY"
|
||||
),
|
||||
)
|
||||
api_url: str = Field(
|
||||
default="https://infer.roboflow.com",
|
||||
description="Roboflow API URL",
|
||||
validation_alias="EMBEDDINGS_ROBOFLOW_API_URL",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_ROBOFLOW_API_URL", "ROBOFLOW_API_URL"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
|
||||
SentenceTransformerEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -20,15 +20,24 @@ class SentenceTransformerProvider(
|
||||
model_name: str = Field(
|
||||
default="all-MiniLM-L6-v2",
|
||||
description="Model name to use",
|
||||
validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_SENTENCE_TRANSFORMER_MODEL_NAME",
|
||||
"SENTENCE_TRANSFORMER_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
device: str = Field(
|
||||
default="cpu",
|
||||
description="Device to run model on (cpu or cuda)",
|
||||
validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_DEVICE",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_SENTENCE_TRANSFORMER_DEVICE", "SENTENCE_TRANSFORMER_DEVICE"
|
||||
),
|
||||
)
|
||||
normalize_embeddings: bool = Field(
|
||||
default=False,
|
||||
description="Whether to normalize embeddings",
|
||||
validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS",
|
||||
"SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
|
||||
Text2VecEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -18,5 +18,9 @@ class Text2VecProvider(BaseEmbeddingsProvider[Text2VecEmbeddingFunction]):
|
||||
model_name: str = Field(
|
||||
default="shibing624/text2vec-base-chinese",
|
||||
description="Model name to use",
|
||||
validation_alias="EMBEDDINGS_TEXT2VEC_MODEL_NAME",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_TEXT2VEC_MODEL_NAME",
|
||||
"TEXT2VEC_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Voyage AI embeddings provider."""
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.providers.voyageai.embedding_callable import (
|
||||
@@ -18,38 +18,53 @@ class VoyageAIProvider(BaseEmbeddingsProvider[VoyageAIEmbeddingFunction]):
|
||||
model: str = Field(
|
||||
default="voyage-2",
|
||||
description="Model to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_MODEL",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_VOYAGEAI_MODEL", "VOYAGEAI_MODEL"),
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Voyage AI API key", validation_alias="EMBEDDINGS_VOYAGEAI_API_KEY"
|
||||
description="Voyage AI API key",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_VOYAGEAI_API_KEY", "VOYAGEAI_API_KEY"
|
||||
),
|
||||
)
|
||||
input_type: str | None = Field(
|
||||
default=None,
|
||||
description="Input type for embeddings",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_INPUT_TYPE",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_VOYAGEAI_INPUT_TYPE", "VOYAGEAI_INPUT_TYPE"
|
||||
),
|
||||
)
|
||||
truncation: bool = Field(
|
||||
default=True,
|
||||
description="Whether to truncate inputs",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_TRUNCATION",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_VOYAGEAI_TRUNCATION", "VOYAGEAI_TRUNCATION"
|
||||
),
|
||||
)
|
||||
output_dtype: str | None = Field(
|
||||
default=None,
|
||||
description="Output data type",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_OUTPUT_DTYPE",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_VOYAGEAI_OUTPUT_DTYPE", "VOYAGEAI_OUTPUT_DTYPE"
|
||||
),
|
||||
)
|
||||
output_dimension: int | None = Field(
|
||||
default=None,
|
||||
description="Output dimension",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_OUTPUT_DIMENSION",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_VOYAGEAI_OUTPUT_DIMENSION", "VOYAGEAI_OUTPUT_DIMENSION"
|
||||
),
|
||||
)
|
||||
max_retries: int = Field(
|
||||
default=0,
|
||||
description="Maximum retries for API calls",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_MAX_RETRIES",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_VOYAGEAI_MAX_RETRIES", "VOYAGEAI_MAX_RETRIES"
|
||||
),
|
||||
)
|
||||
timeout: float | None = Field(
|
||||
default=None,
|
||||
description="Timeout for API calls",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_TIMEOUT",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_VOYAGEAI_TIMEOUT", "VOYAGEAI_TIMEOUT"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -29,7 +29,7 @@ from crewai.rag.embeddings.providers.text2vec.types import Text2VecProviderSpec
|
||||
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderSpec
|
||||
|
||||
|
||||
ProviderSpec = (
|
||||
ProviderSpec: TypeAlias = (
|
||||
AzureProviderSpec
|
||||
| BedrockProviderSpec
|
||||
| CohereProviderSpec
|
||||
|
||||
@@ -1,16 +1,23 @@
|
||||
"""Qdrant configuration model."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import field
|
||||
from typing import Literal, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||
from qdrant_client.models import VectorParams
|
||||
|
||||
from crewai.rag.config.base import BaseRagConfig
|
||||
from crewai.rag.qdrant.constants import DEFAULT_EMBEDDING_MODEL, DEFAULT_STORAGE_PATH
|
||||
from crewai.rag.qdrant.types import QdrantClientParams, QdrantEmbeddingFunctionWrapper
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qdrant_client.models import VectorParams
|
||||
else:
|
||||
VectorParams = Any
|
||||
|
||||
|
||||
def _default_options() -> QdrantClientParams:
|
||||
"""Create default Qdrant client options.
|
||||
|
||||
@@ -26,7 +33,7 @@ def _default_embedding_function() -> QdrantEmbeddingFunctionWrapper:
|
||||
Returns:
|
||||
Default embedding function using fastembed with all-MiniLM-L6-v2.
|
||||
"""
|
||||
from fastembed import TextEmbedding # type: ignore[import-not-found]
|
||||
from fastembed import TextEmbedding
|
||||
|
||||
model = TextEmbedding(model_name=DEFAULT_EMBEDDING_MODEL)
|
||||
|
||||
|
||||
364
lib/crewai/tests/rag/embeddings/test_backward_compatibility.py
Normal file
364
lib/crewai/tests/rag/embeddings/test_backward_compatibility.py
Normal file
@@ -0,0 +1,364 @@
|
||||
"""Tests for backward compatibility of embedding provider configurations."""
|
||||
|
||||
from crewai.rag.embeddings.factory import build_embedder, PROVIDER_PATHS
|
||||
from crewai.rag.embeddings.providers.openai.openai_provider import OpenAIProvider
|
||||
from crewai.rag.embeddings.providers.cohere.cohere_provider import CohereProvider
|
||||
from crewai.rag.embeddings.providers.google.generative_ai import GenerativeAiProvider
|
||||
from crewai.rag.embeddings.providers.google.vertex import VertexAIProvider
|
||||
from crewai.rag.embeddings.providers.microsoft.azure import AzureProvider
|
||||
from crewai.rag.embeddings.providers.jina.jina_provider import JinaProvider
|
||||
from crewai.rag.embeddings.providers.ollama.ollama_provider import OllamaProvider
|
||||
from crewai.rag.embeddings.providers.aws.bedrock import BedrockProvider
|
||||
from crewai.rag.embeddings.providers.text2vec.text2vec_provider import Text2VecProvider
|
||||
from crewai.rag.embeddings.providers.sentence_transformer.sentence_transformer_provider import (
|
||||
SentenceTransformerProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.instructor.instructor_provider import InstructorProvider
|
||||
from crewai.rag.embeddings.providers.openclip.openclip_provider import OpenCLIPProvider
|
||||
|
||||
|
||||
class TestGoogleProviderAlias:
    """Checks that the legacy 'google' provider name is still registered."""

    def test_google_alias_in_provider_paths(self):
        """Both names are present in PROVIDER_PATHS and map to the same value."""
        # Membership is checked for the legacy name first, then the canonical one.
        for provider_key in ("google", "google-generativeai"):
            assert provider_key in PROVIDER_PATHS

        legacy_path = PROVIDER_PATHS["google"]
        canonical_path = PROVIDER_PATHS["google-generativeai"]
        assert legacy_path == canonical_path
|
||||
|
||||
|
||||
class TestModelKeyBackwardCompatibility:
    """Checks that providers accept 'model' as an alias for 'model_name'.

    Each test constructs a provider with the legacy ``model`` keyword and
    verifies the value is exposed through the ``model_name`` attribute.
    """

    def test_openai_provider_accepts_model_key(self):
        """OpenAI provider maps the 'model' keyword onto model_name."""
        expected = "text-embedding-3-small"
        provider = OpenAIProvider(api_key="test-key", model=expected)
        assert provider.model_name == expected

    def test_openai_provider_model_name_takes_precedence(self):
        """OpenAI provider honors an explicit model_name keyword.

        NOTE(review): only ``model_name`` is passed here; no ``model`` alias is
        supplied alongside it, so precedence between the two keys is not
        actually exercised by this test.
        """
        expected = "text-embedding-3-large"
        provider = OpenAIProvider(api_key="test-key", model_name=expected)
        assert provider.model_name == expected

    def test_cohere_provider_accepts_model_key(self):
        """Cohere provider maps the 'model' keyword onto model_name."""
        expected = "embed-english-v3.0"
        provider = CohereProvider(api_key="test-key", model=expected)
        assert provider.model_name == expected

    def test_google_generativeai_provider_accepts_model_key(self):
        """Google Generative AI provider maps 'model' onto model_name."""
        expected = "gemini-embedding-001"
        provider = GenerativeAiProvider(api_key="test-key", model=expected)
        assert provider.model_name == expected

    def test_google_vertex_provider_accepts_model_key(self):
        """Google Vertex AI provider maps 'model' onto model_name."""
        expected = "text-embedding-004"
        provider = VertexAIProvider(api_key="test-key", model=expected)
        assert provider.model_name == expected

    def test_azure_provider_accepts_model_key(self):
        """Azure provider maps the 'model' keyword onto model_name."""
        expected = "text-embedding-ada-002"
        provider = AzureProvider(
            api_key="test-key",
            deployment_id="test-deployment",
            model=expected,
        )
        assert provider.model_name == expected

    def test_jina_provider_accepts_model_key(self):
        """Jina provider maps the 'model' keyword onto model_name."""
        expected = "jina-embeddings-v3"
        provider = JinaProvider(api_key="test-key", model=expected)
        assert provider.model_name == expected

    def test_ollama_provider_accepts_model_key(self):
        """Ollama provider maps the 'model' keyword onto model_name."""
        expected = "nomic-embed-text"
        provider = OllamaProvider(model=expected)
        assert provider.model_name == expected

    def test_text2vec_provider_accepts_model_key(self):
        """Text2Vec provider maps the 'model' keyword onto model_name."""
        expected = "shibing624/text2vec-base-multilingual"
        provider = Text2VecProvider(model=expected)
        assert provider.model_name == expected

    def test_sentence_transformer_provider_accepts_model_key(self):
        """SentenceTransformer provider maps 'model' onto model_name."""
        expected = "all-mpnet-base-v2"
        provider = SentenceTransformerProvider(model=expected)
        assert provider.model_name == expected

    def test_instructor_provider_accepts_model_key(self):
        """Instructor provider maps the 'model' keyword onto model_name."""
        expected = "hkunlp/instructor-xl"
        provider = InstructorProvider(model=expected)
        assert provider.model_name == expected

    def test_openclip_provider_accepts_model_key(self):
        """OpenCLIP provider maps the 'model' keyword onto model_name."""
        expected = "ViT-B-16"
        provider = OpenCLIPProvider(model=expected)
        assert provider.model_name == expected
|
||||
|
||||
|
||||
class TestTaskTypeConfiguration:
    """Checks how GenerativeAiProvider stores the task_type setting."""

    def test_google_provider_accepts_lowercase_task_type(self):
        """A lowercase task_type is accepted and stored unchanged."""
        requested = "retrieval_document"
        provider = GenerativeAiProvider(api_key="test-key", task_type=requested)
        assert provider.task_type == requested

    def test_google_provider_accepts_uppercase_task_type(self):
        """An uppercase task_type is accepted and stored unchanged."""
        requested = "RETRIEVAL_QUERY"
        provider = GenerativeAiProvider(api_key="test-key", task_type=requested)
        assert provider.task_type == requested

    def test_google_provider_default_task_type(self):
        """Omitting task_type yields the 'RETRIEVAL_DOCUMENT' default."""
        provider = GenerativeAiProvider(api_key="test-key")
        assert provider.task_type == "RETRIEVAL_DOCUMENT"
|
||||
|
||||
|
||||
class TestFactoryBackwardCompatibility:
    """Checks build_embedder against backward-compatible configurations.

    The factory's ``import_and_validate_definition`` lookup is patched so no
    real provider class is imported; the tests only inspect how the factory
    calls into it.
    """

    # Dotted path of the factory helper that is replaced with a mock.
    _IMPORT_HOOK = "crewai.rag.embeddings.factory.import_and_validate_definition"

    def test_factory_with_google_alias(self):
        """The 'google' provider name resolves to the GenerativeAiProvider path."""
        from unittest.mock import MagicMock, patch

        embedder_spec = {
            "provider": "google",
            "config": {
                "api_key": "test-key",
                "model": "gemini-embedding-001",
            },
        }
        # The stand-in provider class returns a mock instance when called.
        provider_cls = MagicMock(return_value=MagicMock())

        with patch(self._IMPORT_HOOK) as import_hook:
            import_hook.return_value = provider_cls
            build_embedder(embedder_spec)

            import_hook.assert_called_once_with(
                "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider"
            )

    def test_factory_with_model_key_openai(self):
        """The legacy 'model' config key is forwarded to the provider class."""
        from unittest.mock import MagicMock, patch

        model_id = "text-embedding-3-small"
        embedder_spec = {
            "provider": "openai",
            "config": {
                "api_key": "test-key",
                "model": model_id,
            },
        }
        # The stand-in provider class returns a mock instance when called.
        provider_cls = MagicMock(return_value=MagicMock())

        with patch(self._IMPORT_HOOK) as import_hook:
            import_hook.return_value = provider_cls
            build_embedder(embedder_spec)

            forwarded = provider_cls.call_args.kwargs
            assert forwarded["model"] == model_id
|
||||
|
||||
|
||||
class TestDocumentationCodeSnippets:
    """Checks that provider configs shown in the documentation still validate.

    Each test builds a provider from the keyword settings used in a docs
    example (memory.mdx, knowledge.mdx, ragtool.mdx) and verifies selected
    attributes round-trip.
    """

    def test_memory_openai_config(self):
        """OpenAI config from memory.mdx: model_name only."""
        expected = "text-embedding-3-small"
        provider = OpenAIProvider(model_name=expected)
        assert provider.model_name == expected

    def test_memory_openai_config_with_options(self):
        """OpenAI config from memory.mdx with the optional settings included."""
        settings = {
            "api_key": "your-openai-api-key",
            "model_name": "text-embedding-3-large",
            "dimensions": 1536,
            "organization_id": "your-org-id",
        }
        provider = OpenAIProvider(**settings)
        assert provider.model_name == settings["model_name"]
        assert provider.dimensions == settings["dimensions"]

    def test_memory_azure_config(self):
        """Azure config from memory.mdx."""
        settings = {
            "api_key": "your-azure-key",
            "api_base": "https://your-resource.openai.azure.com/",
            "api_type": "azure",
            "api_version": "2023-05-15",
            "model_name": "text-embedding-3-small",
            "deployment_id": "your-deployment-name",
        }
        provider = AzureProvider(**settings)
        assert provider.model_name == settings["model_name"]
        assert provider.api_type == settings["api_type"]

    def test_memory_google_generativeai_config(self):
        """Google Generative AI config from memory.mdx."""
        settings = {
            "api_key": "your-google-api-key",
            "model_name": "gemini-embedding-001",
        }
        provider = GenerativeAiProvider(**settings)
        assert provider.model_name == settings["model_name"]

    def test_memory_cohere_config(self):
        """Cohere config from memory.mdx."""
        settings = {
            "api_key": "your-cohere-api-key",
            "model_name": "embed-english-v3.0",
        }
        provider = CohereProvider(**settings)
        assert provider.model_name == settings["model_name"]

    def test_knowledge_agent_embedder_config(self):
        """Agent embedder config from knowledge.mdx."""
        settings = {
            "model_name": "gemini-embedding-001",
            "api_key": "your-google-key",
        }
        provider = GenerativeAiProvider(**settings)
        assert provider.model_name == settings["model_name"]

    def test_ragtool_openai_config(self):
        """RagTool OpenAI config from ragtool.mdx: model_name only."""
        expected = "text-embedding-3-small"
        provider = OpenAIProvider(model_name=expected)
        assert provider.model_name == expected

    def test_ragtool_cohere_config(self):
        """RagTool Cohere config from ragtool.mdx."""
        settings = {
            "api_key": "your-api-key",
            "model_name": "embed-english-v3.0",
        }
        provider = CohereProvider(**settings)
        assert provider.model_name == settings["model_name"]

    def test_ragtool_ollama_config(self):
        """RagTool Ollama config from ragtool.mdx."""
        settings = {
            "model_name": "llama2",
            "url": "http://localhost:11434/api/embeddings",
        }
        provider = OllamaProvider(**settings)
        assert provider.model_name == settings["model_name"]

    def test_ragtool_azure_config(self):
        """RagTool Azure config from ragtool.mdx."""
        settings = {
            "deployment_id": "your-deployment-id",
            "api_key": "your-api-key",
            "api_base": "https://your-resource.openai.azure.com",
            "api_version": "2024-02-01",
            "model_name": "text-embedding-ada-002",
            "api_type": "azure",
        }
        provider = AzureProvider(**settings)
        assert provider.model_name == settings["model_name"]
        assert provider.deployment_id == settings["deployment_id"]

    def test_ragtool_google_generativeai_config(self):
        """RagTool Google Generative AI config from ragtool.mdx."""
        settings = {
            "api_key": "your-api-key",
            "model_name": "gemini-embedding-001",
            "task_type": "RETRIEVAL_DOCUMENT",
        }
        provider = GenerativeAiProvider(**settings)
        assert provider.model_name == settings["model_name"]
        assert provider.task_type == settings["task_type"]

    def test_ragtool_jina_config(self):
        """RagTool Jina config from ragtool.mdx."""
        settings = {
            "api_key": "your-api-key",
            "model_name": "jina-embeddings-v3",
        }
        provider = JinaProvider(**settings)
        assert provider.model_name == settings["model_name"]

    def test_ragtool_sentence_transformer_config(self):
        """RagTool SentenceTransformer config from ragtool.mdx."""
        settings = {
            "model_name": "all-mpnet-base-v2",
            "device": "cuda",
            "normalize_embeddings": True,
        }
        provider = SentenceTransformerProvider(**settings)
        assert provider.model_name == settings["model_name"]
        assert provider.device == settings["device"]
        # Identity check: the flag must be the real boolean, not merely truthy.
        assert provider.normalize_embeddings is True
|
||||
|
||||
|
||||
class TestLegacyConfigurationFormats:
    """Checks legacy configs that use 'model' in place of 'model_name'."""

    def test_legacy_google_with_model_key(self):
        """Legacy Google config: 'model' key plus a lowercase task_type."""
        legacy = {
            "api_key": "test-key",
            "model": "text-embedding-005",
            "task_type": "retrieval_document",
        }
        provider = GenerativeAiProvider(**legacy)
        assert provider.model_name == legacy["model"]
        assert provider.task_type == legacy["task_type"]

    def test_legacy_openai_with_model_key(self):
        """Legacy OpenAI config: 'model' key populates model_name."""
        legacy = {
            "api_key": "test-key",
            "model": "text-embedding-ada-002",
        }
        provider = OpenAIProvider(**legacy)
        assert provider.model_name == legacy["model"]

    def test_legacy_cohere_with_model_key(self):
        """Legacy Cohere config: 'model' key populates model_name."""
        legacy = {
            "api_key": "test-key",
            "model": "embed-multilingual-v3.0",
        }
        provider = CohereProvider(**legacy)
        assert provider.model_name == legacy["model"]

    def test_legacy_azure_with_model_key(self):
        """Legacy Azure config: 'model' key populates model_name."""
        legacy = {
            "api_key": "test-key",
            "deployment_id": "test-deployment",
            "model": "text-embedding-3-large",
        }
        provider = AzureProvider(**legacy)
        assert provider.model_name == legacy["model"]
|
||||
Reference in New Issue
Block a user