mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-21 22:08:21 +00:00
Compare commits
5 Commits
devin/1763
...
joaomdmour
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
02a6ffda8f | ||
|
|
65f75cd374 | ||
|
|
6ca6babd37 | ||
|
|
8d7fbc0c79 | ||
|
|
50d6553134 |
@@ -326,7 +326,7 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"tab": "AOP",
|
||||
"tab": "AMP",
|
||||
"icon": "briefcase",
|
||||
"groups": [
|
||||
{
|
||||
@@ -753,7 +753,7 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"tab": "AOP",
|
||||
"tab": "AMP",
|
||||
"icon": "briefcase",
|
||||
"groups": [
|
||||
{
|
||||
|
||||
@@ -388,8 +388,8 @@ crew = Crew(
|
||||
agents=[sales_agent, tech_agent, support_agent],
|
||||
tasks=[...],
|
||||
embedder={ # Fallback embedder for agents without their own
|
||||
"provider": "google-generativeai",
|
||||
"config": {"model_name": "gemini-embedding-001"}
|
||||
"provider": "google",
|
||||
"config": {"model": "text-embedding-004"}
|
||||
}
|
||||
)
|
||||
|
||||
@@ -629,9 +629,9 @@ agent = Agent(
|
||||
backstory="Expert researcher",
|
||||
knowledge_sources=[knowledge_source],
|
||||
embedder={
|
||||
"provider": "google-generativeai",
|
||||
"provider": "google",
|
||||
"config": {
|
||||
"model_name": "gemini-embedding-001",
|
||||
"model": "models/text-embedding-004",
|
||||
"api_key": "your-google-key"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -341,7 +341,7 @@ crew = Crew(
|
||||
embedder={
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model_name": "text-embedding-3-small" # or "text-embedding-3-large"
|
||||
"model": "text-embedding-3-small" # or "text-embedding-3-large"
|
||||
}
|
||||
}
|
||||
)
|
||||
@@ -353,7 +353,7 @@ crew = Crew(
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"api_key": "your-openai-api-key", # Optional: override env var
|
||||
"model_name": "text-embedding-3-large",
|
||||
"model": "text-embedding-3-large",
|
||||
"dimensions": 1536, # Optional: reduce dimensions for smaller storage
|
||||
"organization_id": "your-org-id" # Optional: for organization accounts
|
||||
}
|
||||
@@ -375,7 +375,7 @@ crew = Crew(
|
||||
"api_base": "https://your-resource.openai.azure.com/",
|
||||
"api_type": "azure",
|
||||
"api_version": "2023-05-15",
|
||||
"model_name": "text-embedding-3-small",
|
||||
"model": "text-embedding-3-small",
|
||||
"deployment_id": "your-deployment-name" # Azure deployment name
|
||||
}
|
||||
}
|
||||
@@ -390,10 +390,10 @@ Use Google's text embedding models for integration with Google Cloud services.
|
||||
crew = Crew(
|
||||
memory=True,
|
||||
embedder={
|
||||
"provider": "google-generativeai",
|
||||
"provider": "google",
|
||||
"config": {
|
||||
"api_key": "your-google-api-key",
|
||||
"model_name": "gemini-embedding-001" # or "text-embedding-005", "text-multilingual-embedding-002"
|
||||
"model": "text-embedding-004" # or "text-embedding-preview-0409"
|
||||
}
|
||||
}
|
||||
)
|
||||
@@ -461,7 +461,7 @@ crew = Crew(
|
||||
"provider": "cohere",
|
||||
"config": {
|
||||
"api_key": "your-cohere-api-key",
|
||||
"model_name": "embed-english-v3.0" # or "embed-multilingual-v3.0"
|
||||
"model": "embed-english-v3.0" # or "embed-multilingual-v3.0"
|
||||
}
|
||||
}
|
||||
)
|
||||
@@ -478,7 +478,7 @@ crew = Crew(
|
||||
"provider": "voyageai",
|
||||
"config": {
|
||||
"api_key": "your-voyage-api-key",
|
||||
"model": "voyage-3", # or "voyage-3-lite", "voyage-code-3"
|
||||
"model": "voyage-large-2", # or "voyage-code-2" for code
|
||||
"input_type": "document" # or "query"
|
||||
}
|
||||
}
|
||||
@@ -912,10 +912,10 @@ crew = Crew(
|
||||
crew = Crew(
|
||||
memory=True,
|
||||
embedder={
|
||||
"provider": "google-generativeai",
|
||||
"provider": "google",
|
||||
"config": {
|
||||
"api_key": "your-api-key",
|
||||
"model_name": "gemini-embedding-001"
|
||||
"model": "text-embedding-004"
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
@@ -7,7 +7,7 @@ mode: "wide"
|
||||
|
||||
## Introduction
|
||||
|
||||
CrewAI AOP(Agent Operations Platform) provides a platform for deploying, monitoring, and scaling your crews and agents in a production environment.
|
||||
CrewAI AOP(Agent Management Platform) provides a platform for deploying, monitoring, and scaling your crews and agents in a production environment.
|
||||
|
||||
<Frame>
|
||||
<img src="/images/enterprise/crewai-enterprise-dashboard.png" alt="CrewAI AOP Dashboard" />
|
||||
|
||||
@@ -77,7 +77,7 @@ The `RagTool` accepts the following parameters:
|
||||
|
||||
- **summarize**: Optional. Whether to summarize the retrieved content. Default is `False`.
|
||||
- **adapter**: Optional. A custom adapter for the knowledge base. If not provided, a CrewAIRagAdapter will be used.
|
||||
- **config**: Optional. Configuration for the underlying CrewAI RAG system. Accepts a `RagToolConfig` TypedDict with optional `embedding_model` (ProviderSpec) and `vectordb` (VectorDbConfig) keys. All configuration values provided programmatically take precedence over environment variables.
|
||||
- **config**: Optional. Configuration for the underlying CrewAI RAG system.
|
||||
|
||||
## Adding Content
|
||||
|
||||
@@ -127,528 +127,26 @@ You can customize the behavior of the `RagTool` by providing a configuration dic
|
||||
|
||||
```python Code
|
||||
from crewai_tools import RagTool
|
||||
from crewai_tools.tools.rag import RagToolConfig, VectorDbConfig, ProviderSpec
|
||||
|
||||
# Create a RAG tool with custom configuration
|
||||
|
||||
vectordb: VectorDbConfig = {
|
||||
"provider": "qdrant",
|
||||
"config": {
|
||||
"collection_name": "my-collection"
|
||||
config = {
|
||||
"vectordb": {
|
||||
"provider": "qdrant",
|
||||
"config": {
|
||||
"collection_name": "my-collection"
|
||||
}
|
||||
},
|
||||
"embedding_model": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": "text-embedding-3-small"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
embedding_model: ProviderSpec = {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model_name": "text-embedding-3-small"
|
||||
}
|
||||
}
|
||||
|
||||
config: RagToolConfig = {
|
||||
"vectordb": vectordb,
|
||||
"embedding_model": embedding_model
|
||||
}
|
||||
|
||||
rag_tool = RagTool(config=config, summarize=True)
|
||||
```
|
||||
|
||||
## Embedding Model Configuration
|
||||
|
||||
The `embedding_model` parameter accepts a `crewai.rag.embeddings.types.ProviderSpec` dictionary with the structure:
|
||||
|
||||
```python
|
||||
{
|
||||
"provider": "provider-name", # Required
|
||||
"config": { # Optional
|
||||
# Provider-specific configuration
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Supported Providers
|
||||
|
||||
<AccordionGroup>
|
||||
<Accordion title="OpenAI">
|
||||
```python main.py
|
||||
from crewai.rag.embeddings.providers.openai.types import OpenAIProviderSpec
|
||||
|
||||
embedding_model: OpenAIProviderSpec = {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"api_key": "your-api-key",
|
||||
"model_name": "text-embedding-ada-002",
|
||||
"dimensions": 1536,
|
||||
"organization_id": "your-org-id",
|
||||
"api_base": "https://api.openai.com/v1",
|
||||
"api_version": "v1",
|
||||
"default_headers": {"Custom-Header": "value"}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `api_key` (str): OpenAI API key
|
||||
- `model_name` (str): Model to use. Default: `text-embedding-ada-002`. Options: `text-embedding-3-small`, `text-embedding-3-large`, `text-embedding-ada-002`
|
||||
- `dimensions` (int): Number of dimensions for the embedding
|
||||
- `organization_id` (str): OpenAI organization ID
|
||||
- `api_base` (str): Custom API base URL
|
||||
- `api_version` (str): API version
|
||||
- `default_headers` (dict): Custom headers for API requests
|
||||
|
||||
**Environment Variables:**
|
||||
- `OPENAI_API_KEY` or `EMBEDDINGS_OPENAI_API_KEY`: `api_key`
|
||||
- `OPENAI_ORGANIZATION_ID` or `EMBEDDINGS_OPENAI_ORGANIZATION_ID`: `organization_id`
|
||||
- `OPENAI_MODEL_NAME` or `EMBEDDINGS_OPENAI_MODEL_NAME`: `model_name`
|
||||
- `OPENAI_API_BASE` or `EMBEDDINGS_OPENAI_API_BASE`: `api_base`
|
||||
- `OPENAI_API_VERSION` or `EMBEDDINGS_OPENAI_API_VERSION`: `api_version`
|
||||
- `OPENAI_DIMENSIONS` or `EMBEDDINGS_OPENAI_DIMENSIONS`: `dimensions`
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Cohere">
|
||||
```python main.py
|
||||
from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec
|
||||
|
||||
embedding_model: CohereProviderSpec = {
|
||||
"provider": "cohere",
|
||||
"config": {
|
||||
"api_key": "your-api-key",
|
||||
"model_name": "embed-english-v3.0"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `api_key` (str): Cohere API key
|
||||
- `model_name` (str): Model to use. Default: `large`. Options: `embed-english-v3.0`, `embed-multilingual-v3.0`, `large`, `small`
|
||||
|
||||
**Environment Variables:**
|
||||
- `COHERE_API_KEY` or `EMBEDDINGS_COHERE_API_KEY`: `api_key`
|
||||
- `EMBEDDINGS_COHERE_MODEL_NAME`: `model_name`
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="VoyageAI">
|
||||
```python main.py
|
||||
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderSpec
|
||||
|
||||
embedding_model: VoyageAIProviderSpec = {
|
||||
"provider": "voyageai",
|
||||
"config": {
|
||||
"api_key": "your-api-key",
|
||||
"model": "voyage-3",
|
||||
"input_type": "document",
|
||||
"truncation": True,
|
||||
"output_dtype": "float32",
|
||||
"output_dimension": 1024,
|
||||
"max_retries": 3,
|
||||
"timeout": 60.0
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `api_key` (str): VoyageAI API key
|
||||
- `model` (str): Model to use. Default: `voyage-2`. Options: `voyage-3`, `voyage-3-lite`, `voyage-code-3`, `voyage-large-2`
|
||||
- `input_type` (str): Type of input. Options: `document` (for storage), `query` (for search)
|
||||
- `truncation` (bool): Whether to truncate inputs that exceed max length. Default: `True`
|
||||
- `output_dtype` (str): Output data type
|
||||
- `output_dimension` (int): Dimension of output embeddings
|
||||
- `max_retries` (int): Maximum number of retry attempts. Default: `0`
|
||||
- `timeout` (float): Request timeout in seconds
|
||||
|
||||
**Environment Variables:**
|
||||
- `VOYAGEAI_API_KEY` or `EMBEDDINGS_VOYAGEAI_API_KEY`: `api_key`
|
||||
- `VOYAGEAI_MODEL` or `EMBEDDINGS_VOYAGEAI_MODEL`: `model`
|
||||
- `VOYAGEAI_INPUT_TYPE` or `EMBEDDINGS_VOYAGEAI_INPUT_TYPE`: `input_type`
|
||||
- `VOYAGEAI_TRUNCATION` or `EMBEDDINGS_VOYAGEAI_TRUNCATION`: `truncation`
|
||||
- `VOYAGEAI_OUTPUT_DTYPE` or `EMBEDDINGS_VOYAGEAI_OUTPUT_DTYPE`: `output_dtype`
|
||||
- `VOYAGEAI_OUTPUT_DIMENSION` or `EMBEDDINGS_VOYAGEAI_OUTPUT_DIMENSION`: `output_dimension`
|
||||
- `VOYAGEAI_MAX_RETRIES` or `EMBEDDINGS_VOYAGEAI_MAX_RETRIES`: `max_retries`
|
||||
- `VOYAGEAI_TIMEOUT` or `EMBEDDINGS_VOYAGEAI_TIMEOUT`: `timeout`
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Ollama">
|
||||
```python main.py
|
||||
from crewai.rag.embeddings.providers.ollama.types import OllamaProviderSpec
|
||||
|
||||
embedding_model: OllamaProviderSpec = {
|
||||
"provider": "ollama",
|
||||
"config": {
|
||||
"model_name": "llama2",
|
||||
"url": "http://localhost:11434/api/embeddings"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `model_name` (str): Ollama model name (e.g., `llama2`, `mistral`, `nomic-embed-text`)
|
||||
- `url` (str): Ollama API endpoint URL. Default: `http://localhost:11434/api/embeddings`
|
||||
|
||||
**Environment Variables:**
|
||||
- `OLLAMA_MODEL` or `EMBEDDINGS_OLLAMA_MODEL`: `model_name`
|
||||
- `OLLAMA_URL` or `EMBEDDINGS_OLLAMA_URL`: `url`
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Amazon Bedrock">
|
||||
```python main.py
|
||||
from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
|
||||
|
||||
embedding_model: BedrockProviderSpec = {
|
||||
"provider": "amazon-bedrock",
|
||||
"config": {
|
||||
"model_name": "amazon.titan-embed-text-v2:0",
|
||||
"session": boto3_session
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `model_name` (str): Bedrock model ID. Default: `amazon.titan-embed-text-v1`. Options: `amazon.titan-embed-text-v1`, `amazon.titan-embed-text-v2:0`, `cohere.embed-english-v3`, `cohere.embed-multilingual-v3`
|
||||
- `session` (Any): Boto3 session object for AWS authentication
|
||||
|
||||
**Environment Variables:**
|
||||
- `AWS_ACCESS_KEY_ID`: AWS access key
|
||||
- `AWS_SECRET_ACCESS_KEY`: AWS secret key
|
||||
- `AWS_REGION`: AWS region (e.g., `us-east-1`)
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Azure OpenAI">
|
||||
```python main.py
|
||||
from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec
|
||||
|
||||
embedding_model: AzureProviderSpec = {
|
||||
"provider": "azure",
|
||||
"config": {
|
||||
"deployment_id": "your-deployment-id",
|
||||
"api_key": "your-api-key",
|
||||
"api_base": "https://your-resource.openai.azure.com",
|
||||
"api_version": "2024-02-01",
|
||||
"model_name": "text-embedding-ada-002",
|
||||
"api_type": "azure"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `deployment_id` (str): **Required** - Azure OpenAI deployment ID
|
||||
- `api_key` (str): Azure OpenAI API key
|
||||
- `api_base` (str): Azure OpenAI resource endpoint
|
||||
- `api_version` (str): API version. Example: `2024-02-01`
|
||||
- `model_name` (str): Model name. Default: `text-embedding-ada-002`
|
||||
- `api_type` (str): API type. Default: `azure`
|
||||
- `dimensions` (int): Output dimensions
|
||||
- `default_headers` (dict): Custom headers
|
||||
|
||||
**Environment Variables:**
|
||||
- `AZURE_OPENAI_API_KEY` or `EMBEDDINGS_AZURE_API_KEY`: `api_key`
|
||||
- `AZURE_OPENAI_ENDPOINT` or `EMBEDDINGS_AZURE_API_BASE`: `api_base`
|
||||
- `EMBEDDINGS_AZURE_DEPLOYMENT_ID`: `deployment_id`
|
||||
- `EMBEDDINGS_AZURE_API_VERSION`: `api_version`
|
||||
- `EMBEDDINGS_AZURE_MODEL_NAME`: `model_name`
|
||||
- `EMBEDDINGS_AZURE_API_TYPE`: `api_type`
|
||||
- `EMBEDDINGS_AZURE_DIMENSIONS`: `dimensions`
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Google Generative AI">
|
||||
```python main.py
|
||||
from crewai.rag.embeddings.providers.google.types import GenerativeAiProviderSpec
|
||||
|
||||
embedding_model: GenerativeAiProviderSpec = {
|
||||
"provider": "google-generativeai",
|
||||
"config": {
|
||||
"api_key": "your-api-key",
|
||||
"model_name": "gemini-embedding-001",
|
||||
"task_type": "RETRIEVAL_DOCUMENT"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `api_key` (str): Google AI API key
|
||||
- `model_name` (str): Model name. Default: `gemini-embedding-001`. Options: `gemini-embedding-001`, `text-embedding-005`, `text-multilingual-embedding-002`
|
||||
- `task_type` (str): Task type for embeddings. Default: `RETRIEVAL_DOCUMENT`. Options: `RETRIEVAL_DOCUMENT`, `RETRIEVAL_QUERY`
|
||||
|
||||
**Environment Variables:**
|
||||
- `GOOGLE_API_KEY`, `GEMINI_API_KEY`, or `EMBEDDINGS_GOOGLE_API_KEY`: `api_key`
|
||||
- `EMBEDDINGS_GOOGLE_GENERATIVE_AI_MODEL_NAME`: `model_name`
|
||||
- `EMBEDDINGS_GOOGLE_GENERATIVE_AI_TASK_TYPE`: `task_type`
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Google Vertex AI">
|
||||
```python main.py
|
||||
from crewai.rag.embeddings.providers.google.types import VertexAIProviderSpec
|
||||
|
||||
embedding_model: VertexAIProviderSpec = {
|
||||
"provider": "google-vertex",
|
||||
"config": {
|
||||
"model_name": "text-embedding-004",
|
||||
"project_id": "your-project-id",
|
||||
"region": "us-central1",
|
||||
"api_key": "your-api-key"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `model_name` (str): Model name. Default: `textembedding-gecko`. Options: `text-embedding-004`, `textembedding-gecko`, `textembedding-gecko-multilingual`
|
||||
- `project_id` (str): Google Cloud project ID. Default: `cloud-large-language-models`
|
||||
- `region` (str): Google Cloud region. Default: `us-central1`
|
||||
- `api_key` (str): API key for authentication
|
||||
|
||||
**Environment Variables:**
|
||||
- `GOOGLE_APPLICATION_CREDENTIALS`: Path to service account JSON file
|
||||
- `GOOGLE_CLOUD_PROJECT` or `EMBEDDINGS_GOOGLE_VERTEX_PROJECT_ID`: `project_id`
|
||||
- `EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME`: `model_name`
|
||||
- `EMBEDDINGS_GOOGLE_VERTEX_REGION`: `region`
|
||||
- `EMBEDDINGS_GOOGLE_VERTEX_API_KEY`: `api_key`
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Jina AI">
|
||||
```python main.py
|
||||
from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec
|
||||
|
||||
embedding_model: JinaProviderSpec = {
|
||||
"provider": "jina",
|
||||
"config": {
|
||||
"api_key": "your-api-key",
|
||||
"model_name": "jina-embeddings-v3"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `api_key` (str): Jina AI API key
|
||||
- `model_name` (str): Model name. Default: `jina-embeddings-v2-base-en`. Options: `jina-embeddings-v3`, `jina-embeddings-v2-base-en`, `jina-embeddings-v2-small-en`
|
||||
|
||||
**Environment Variables:**
|
||||
- `JINA_API_KEY` or `EMBEDDINGS_JINA_API_KEY`: `api_key`
|
||||
- `EMBEDDINGS_JINA_MODEL_NAME`: `model_name`
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="HuggingFace">
|
||||
```python main.py
|
||||
from crewai.rag.embeddings.providers.huggingface.types import HuggingFaceProviderSpec
|
||||
|
||||
embedding_model: HuggingFaceProviderSpec = {
|
||||
"provider": "huggingface",
|
||||
"config": {
|
||||
"url": "https://api-inference.huggingface.co/models/sentence-transformers/all-MiniLM-L6-v2"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `url` (str): Full URL to HuggingFace inference API endpoint
|
||||
|
||||
**Environment Variables:**
|
||||
- `HUGGINGFACE_URL` or `EMBEDDINGS_HUGGINGFACE_URL`: `url`
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Instructor">
|
||||
```python main.py
|
||||
from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec
|
||||
|
||||
embedding_model: InstructorProviderSpec = {
|
||||
"provider": "instructor",
|
||||
"config": {
|
||||
"model_name": "hkunlp/instructor-xl",
|
||||
"device": "cuda",
|
||||
"instruction": "Represent the document"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `model_name` (str): HuggingFace model ID. Default: `hkunlp/instructor-base`. Options: `hkunlp/instructor-xl`, `hkunlp/instructor-large`, `hkunlp/instructor-base`
|
||||
- `device` (str): Device to run on. Default: `cpu`. Options: `cpu`, `cuda`, `mps`
|
||||
- `instruction` (str): Instruction prefix for embeddings
|
||||
|
||||
**Environment Variables:**
|
||||
- `EMBEDDINGS_INSTRUCTOR_MODEL_NAME`: `model_name`
|
||||
- `EMBEDDINGS_INSTRUCTOR_DEVICE`: `device`
|
||||
- `EMBEDDINGS_INSTRUCTOR_INSTRUCTION`: `instruction`
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Sentence Transformer">
|
||||
```python main.py
|
||||
from crewai.rag.embeddings.providers.sentence_transformer.types import SentenceTransformerProviderSpec
|
||||
|
||||
embedding_model: SentenceTransformerProviderSpec = {
|
||||
"provider": "sentence-transformer",
|
||||
"config": {
|
||||
"model_name": "all-mpnet-base-v2",
|
||||
"device": "cuda",
|
||||
"normalize_embeddings": True
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `model_name` (str): Sentence Transformers model name. Default: `all-MiniLM-L6-v2`. Options: `all-mpnet-base-v2`, `all-MiniLM-L6-v2`, `paraphrase-multilingual-MiniLM-L12-v2`
|
||||
- `device` (str): Device to run on. Default: `cpu`. Options: `cpu`, `cuda`, `mps`
|
||||
- `normalize_embeddings` (bool): Whether to normalize embeddings. Default: `False`
|
||||
|
||||
**Environment Variables:**
|
||||
- `EMBEDDINGS_SENTENCE_TRANSFORMER_MODEL_NAME`: `model_name`
|
||||
- `EMBEDDINGS_SENTENCE_TRANSFORMER_DEVICE`: `device`
|
||||
- `EMBEDDINGS_SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS`: `normalize_embeddings`
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="ONNX">
|
||||
```python main.py
|
||||
from crewai.rag.embeddings.providers.onnx.types import ONNXProviderSpec
|
||||
|
||||
embedding_model: ONNXProviderSpec = {
|
||||
"provider": "onnx",
|
||||
"config": {
|
||||
"preferred_providers": ["CUDAExecutionProvider", "CPUExecutionProvider"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `preferred_providers` (list[str]): List of ONNX execution providers in order of preference
|
||||
|
||||
**Environment Variables:**
|
||||
- `EMBEDDINGS_ONNX_PREFERRED_PROVIDERS`: `preferred_providers` (comma-separated list)
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="OpenCLIP">
|
||||
```python main.py
|
||||
from crewai.rag.embeddings.providers.openclip.types import OpenCLIPProviderSpec
|
||||
|
||||
embedding_model: OpenCLIPProviderSpec = {
|
||||
"provider": "openclip",
|
||||
"config": {
|
||||
"model_name": "ViT-B-32",
|
||||
"checkpoint": "laion2b_s34b_b79k",
|
||||
"device": "cuda"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `model_name` (str): OpenCLIP model architecture. Default: `ViT-B-32`. Options: `ViT-B-32`, `ViT-B-16`, `ViT-L-14`
|
||||
- `checkpoint` (str): Pretrained checkpoint name. Default: `laion2b_s34b_b79k`. Options: `laion2b_s34b_b79k`, `laion400m_e32`, `openai`
|
||||
- `device` (str): Device to run on. Default: `cpu`. Options: `cpu`, `cuda`
|
||||
|
||||
**Environment Variables:**
|
||||
- `EMBEDDINGS_OPENCLIP_MODEL_NAME`: `model_name`
|
||||
- `EMBEDDINGS_OPENCLIP_CHECKPOINT`: `checkpoint`
|
||||
- `EMBEDDINGS_OPENCLIP_DEVICE`: `device`
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Text2Vec">
|
||||
```python main.py
|
||||
from crewai.rag.embeddings.providers.text2vec.types import Text2VecProviderSpec
|
||||
|
||||
embedding_model: Text2VecProviderSpec = {
|
||||
"provider": "text2vec",
|
||||
"config": {
|
||||
"model_name": "shibing624/text2vec-base-multilingual"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `model_name` (str): Text2Vec model name from HuggingFace. Default: `shibing624/text2vec-base-chinese`. Options: `shibing624/text2vec-base-multilingual`, `shibing624/text2vec-base-chinese`
|
||||
|
||||
**Environment Variables:**
|
||||
- `EMBEDDINGS_TEXT2VEC_MODEL_NAME`: `model_name`
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Roboflow">
|
||||
```python main.py
|
||||
from crewai.rag.embeddings.providers.roboflow.types import RoboflowProviderSpec
|
||||
|
||||
embedding_model: RoboflowProviderSpec = {
|
||||
"provider": "roboflow",
|
||||
"config": {
|
||||
"api_key": "your-api-key",
|
||||
"api_url": "https://infer.roboflow.com"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `api_key` (str): Roboflow API key. Default: `""` (empty string)
|
||||
- `api_url` (str): Roboflow inference API URL. Default: `https://infer.roboflow.com`
|
||||
|
||||
**Environment Variables:**
|
||||
- `ROBOFLOW_API_KEY` or `EMBEDDINGS_ROBOFLOW_API_KEY`: `api_key`
|
||||
- `ROBOFLOW_API_URL` or `EMBEDDINGS_ROBOFLOW_API_URL`: `api_url`
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="WatsonX (IBM)">
|
||||
```python main.py
|
||||
from crewai.rag.embeddings.providers.ibm.types import WatsonXProviderSpec
|
||||
|
||||
embedding_model: WatsonXProviderSpec = {
|
||||
"provider": "watsonx",
|
||||
"config": {
|
||||
"model_id": "ibm/slate-125m-english-rtrvr",
|
||||
"url": "https://us-south.ml.cloud.ibm.com",
|
||||
"api_key": "your-api-key",
|
||||
"project_id": "your-project-id",
|
||||
"batch_size": 100,
|
||||
"concurrency_limit": 10,
|
||||
"persistent_connection": True
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `model_id` (str): WatsonX model identifier
|
||||
- `url` (str): WatsonX API endpoint
|
||||
- `api_key` (str): IBM Cloud API key
|
||||
- `project_id` (str): WatsonX project ID
|
||||
- `space_id` (str): WatsonX space ID (alternative to project_id)
|
||||
- `batch_size` (int): Batch size for embeddings. Default: `100`
|
||||
- `concurrency_limit` (int): Maximum concurrent requests. Default: `10`
|
||||
- `persistent_connection` (bool): Use persistent connections. Default: `True`
|
||||
- Plus 20+ additional authentication and configuration options
|
||||
|
||||
**Environment Variables:**
|
||||
- `WATSONX_API_KEY` or `EMBEDDINGS_WATSONX_API_KEY`: `api_key`
|
||||
- `WATSONX_URL` or `EMBEDDINGS_WATSONX_URL`: `url`
|
||||
- `WATSONX_PROJECT_ID` or `EMBEDDINGS_WATSONX_PROJECT_ID`: `project_id`
|
||||
- `EMBEDDINGS_WATSONX_MODEL_ID`: `model_id`
|
||||
- `EMBEDDINGS_WATSONX_SPACE_ID`: `space_id`
|
||||
- `EMBEDDINGS_WATSONX_BATCH_SIZE`: `batch_size`
|
||||
- `EMBEDDINGS_WATSONX_CONCURRENCY_LIMIT`: `concurrency_limit`
|
||||
- `EMBEDDINGS_WATSONX_PERSISTENT_CONNECTION`: `persistent_connection`
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Custom">
|
||||
```python main.py
|
||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||
from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec
|
||||
|
||||
class MyEmbeddingFunction(EmbeddingFunction):
|
||||
def __call__(self, input):
|
||||
# Your custom embedding logic
|
||||
return embeddings
|
||||
|
||||
embedding_model: CustomProviderSpec = {
|
||||
"provider": "custom",
|
||||
"config": {
|
||||
"embedding_callable": MyEmbeddingFunction
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Config Options:**
|
||||
- `embedding_callable` (type[EmbeddingFunction]): Custom embedding function class
|
||||
|
||||
**Note:** Custom embedding functions must implement the `EmbeddingFunction` protocol defined in `crewai.rag.core.base_embeddings_callable`. The `__call__` method should accept input data and return embeddings as a list of numpy arrays (or compatible format that will be normalized). The returned embeddings are automatically normalized and validated.
|
||||
</Accordion>
|
||||
</AccordionGroup>
|
||||
|
||||
### Notes
|
||||
- All config fields are optional unless marked as **Required**
|
||||
- API keys can typically be provided via environment variables instead of config
|
||||
- Default values are shown where applicable
|
||||
|
||||
|
||||
## Conclusion
|
||||
The `RagTool` provides a powerful way to create and query knowledge bases from various data sources. By leveraging Retrieval-Augmented Generation, it enables agents to access and retrieve relevant information efficiently, enhancing their ability to provide accurate and contextually appropriate responses.
|
||||
|
||||
@@ -58,10 +58,10 @@ tool = MySQLSearchTool(
|
||||
),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="google-generativeai",
|
||||
provider="google",
|
||||
config=dict(
|
||||
model_name="gemini-embedding-001",
|
||||
task_type="RETRIEVAL_DOCUMENT",
|
||||
model="models/embedding-001",
|
||||
task_type="retrieval_document",
|
||||
# title="Embeddings",
|
||||
),
|
||||
),
|
||||
|
||||
@@ -71,10 +71,10 @@ tool = PGSearchTool(
|
||||
),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="google-generativeai", # or openai, ollama, ...
|
||||
provider="google", # or openai, ollama, ...
|
||||
config=dict(
|
||||
model_name="gemini-embedding-001",
|
||||
task_type="RETRIEVAL_DOCUMENT",
|
||||
model="models/embedding-001",
|
||||
task_type="retrieval_document",
|
||||
# title="Embeddings",
|
||||
),
|
||||
),
|
||||
|
||||
@@ -64,10 +64,10 @@ tool = JSONSearchTool(
|
||||
},
|
||||
},
|
||||
"embedding_model": {
|
||||
"provider": "google-generativeai", # or openai, ollama, ...
|
||||
"provider": "google", # or openai, ollama, ...
|
||||
"config": {
|
||||
"model_name": "gemini-embedding-001",
|
||||
"task_type": "RETRIEVAL_DOCUMENT",
|
||||
"model": "models/embedding-001",
|
||||
"task_type": "retrieval_document",
|
||||
# Further customization options can be added here.
|
||||
},
|
||||
},
|
||||
|
||||
@@ -63,15 +63,15 @@ tool = PDFSearchTool(
|
||||
"config": {
|
||||
# Model identifier for the chosen provider. "model" will be auto-mapped to "model_name" internally.
|
||||
"model": "text-embedding-3-small",
|
||||
# Optional: API key. If omitted, the tool will use provider-specific env vars
|
||||
# (e.g., OPENAI_API_KEY or EMBEDDINGS_OPENAI_API_KEY for OpenAI).
|
||||
# Optional: API key. If omitted, the tool will use provider-specific env vars when available
|
||||
# (e.g., OPENAI_API_KEY for provider="openai").
|
||||
# "api_key": "sk-...",
|
||||
|
||||
# Provider-specific examples:
|
||||
# --- Google Generative AI ---
|
||||
# (Set provider="google-generativeai" above)
|
||||
# "model_name": "gemini-embedding-001",
|
||||
# "task_type": "RETRIEVAL_DOCUMENT",
|
||||
# "model": "models/embedding-001",
|
||||
# "task_type": "retrieval_document",
|
||||
# "title": "Embeddings",
|
||||
|
||||
# --- Cohere ---
|
||||
|
||||
@@ -66,9 +66,9 @@ tool = TXTSearchTool(
|
||||
"provider": "openai", # or google-generativeai, cohere, ollama, ...
|
||||
"config": {
|
||||
"model": "text-embedding-3-small",
|
||||
# "api_key": "sk-...", # optional if env var is set (e.g., OPENAI_API_KEY or EMBEDDINGS_OPENAI_API_KEY)
|
||||
# "api_key": "sk-...", # optional if env var is set
|
||||
# Provider examples:
|
||||
# Google → model_name: "gemini-embedding-001", task_type: "RETRIEVAL_DOCUMENT"
|
||||
# Google → model: "models/embedding-001", task_type: "retrieval_document"
|
||||
# Cohere → model: "embed-english-v3.0"
|
||||
# Ollama → model: "nomic-embed-text"
|
||||
},
|
||||
|
||||
@@ -73,10 +73,10 @@ tool = CodeDocsSearchTool(
|
||||
),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="google-generativeai", # or openai, ollama, ...
|
||||
provider="google", # or openai, ollama, ...
|
||||
config=dict(
|
||||
model_name="gemini-embedding-001",
|
||||
task_type="RETRIEVAL_DOCUMENT",
|
||||
model="models/embedding-001",
|
||||
task_type="retrieval_document",
|
||||
# title="Embeddings",
|
||||
),
|
||||
),
|
||||
|
||||
@@ -75,10 +75,10 @@ tool = GithubSearchTool(
|
||||
),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="google-generativeai", # or openai, ollama, ...
|
||||
provider="google", # or openai, ollama, ...
|
||||
config=dict(
|
||||
model_name="gemini-embedding-001",
|
||||
task_type="RETRIEVAL_DOCUMENT",
|
||||
model="models/embedding-001",
|
||||
task_type="retrieval_document",
|
||||
# title="Embeddings",
|
||||
),
|
||||
),
|
||||
|
||||
@@ -66,10 +66,10 @@ tool = WebsiteSearchTool(
|
||||
),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="google-generativeai", # or openai, ollama, ...
|
||||
provider="google", # or openai, ollama, ...
|
||||
config=dict(
|
||||
model_name="gemini-embedding-001",
|
||||
task_type="RETRIEVAL_DOCUMENT",
|
||||
model="models/embedding-001",
|
||||
task_type="retrieval_document",
|
||||
# title="Embeddings",
|
||||
),
|
||||
),
|
||||
|
||||
@@ -106,10 +106,10 @@ youtube_channel_tool = YoutubeChannelSearchTool(
|
||||
),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="google-generativeai", # or openai, ollama, ...
|
||||
provider="google", # or openai, ollama, ...
|
||||
config=dict(
|
||||
model_name="gemini-embedding-001",
|
||||
task_type="RETRIEVAL_DOCUMENT",
|
||||
model="models/embedding-001",
|
||||
task_type="retrieval_document",
|
||||
# title="Embeddings",
|
||||
),
|
||||
),
|
||||
|
||||
@@ -108,10 +108,10 @@ youtube_search_tool = YoutubeVideoSearchTool(
|
||||
),
|
||||
),
|
||||
embedder=dict(
|
||||
provider="google-generativeai", # or openai, ollama, ...
|
||||
provider="google", # or openai, ollama, ...
|
||||
config=dict(
|
||||
model_name="gemini-embedding-001",
|
||||
task_type="RETRIEVAL_DOCUMENT",
|
||||
model="models/embedding-001",
|
||||
task_type="retrieval_document",
|
||||
# title="Embeddings",
|
||||
),
|
||||
),
|
||||
|
||||
@@ -7,7 +7,7 @@ mode: "wide"
|
||||
|
||||
## 소개
|
||||
|
||||
CrewAI AOP(Agent Operation Platform)는 프로덕션 환경에서 crew와 agent를 배포, 모니터링, 확장할 수 있는 플랫폼을 제공합니다.
|
||||
CrewAI AOP(Agent Management Platform)는 프로덕션 환경에서 crew와 agent를 배포, 모니터링, 확장할 수 있는 플랫폼을 제공합니다.
|
||||
|
||||
<Frame>
|
||||
<img src="/images/enterprise/crewai-enterprise-dashboard.png" alt="CrewAI AOP Dashboard" />
|
||||
|
||||
@@ -7,7 +7,7 @@ mode: "wide"
|
||||
|
||||
## Introdução
|
||||
|
||||
CrewAI AOP(Agent Operation Platform) fornece uma plataforma para implementar, monitorar e escalar seus crews e agentes em um ambiente de produção.
|
||||
CrewAI AOP(Agent Management Platform) fornece uma plataforma para implementar, monitorar e escalar seus crews e agentes em um ambiente de produção.
|
||||
|
||||
<Frame>
|
||||
<img src="/images/enterprise/crewai-enterprise-dashboard.png" alt="CrewAI AOP Dashboard" />
|
||||
|
||||
@@ -12,13 +12,13 @@ dependencies = [
|
||||
"pytube>=15.0.0",
|
||||
"requests>=2.32.5",
|
||||
"docker>=7.1.0",
|
||||
"crewai==1.6.1",
|
||||
"crewai==1.5.0",
|
||||
"lancedb>=0.5.4",
|
||||
"tiktoken>=0.8.0",
|
||||
"beautifulsoup4>=4.13.4",
|
||||
"pypdf>=5.9.0",
|
||||
"python-docx>=1.2.0",
|
||||
"youtube-transcript-api>=1.2.2",
|
||||
"pymupdf>=1.26.6",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -291,4 +291,4 @@ __all__ = [
|
||||
"ZapierActionTools",
|
||||
]
|
||||
|
||||
__version__ = "1.6.1"
|
||||
__version__ = "1.5.0"
|
||||
|
||||
@@ -1,46 +1,39 @@
|
||||
"""Adapter for CrewAI's native RAG system."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeAlias, TypedDict
|
||||
import uuid
|
||||
|
||||
from crewai.rag.config.types import RagConfigType
|
||||
from crewai.rag.config.utils import get_rag_client
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.factory import create_client
|
||||
from crewai.rag.qdrant.config import QdrantConfig
|
||||
from crewai.rag.types import BaseRecord, SearchResult
|
||||
from pydantic import PrivateAttr
|
||||
from pydantic.dataclasses import is_pydantic_dataclass
|
||||
from typing_extensions import TypeIs, Unpack
|
||||
from qdrant_client.models import VectorParams
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
from crewai_tools.rag.misc import sanitize_metadata_for_chromadb
|
||||
from crewai_tools.tools.rag.rag_tool import Adapter
|
||||
from crewai_tools.tools.rag.types import AddDocumentParams, ContentItem
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.rag.qdrant.config import QdrantConfig
|
||||
ContentItem: TypeAlias = str | Path | dict[str, Any]
|
||||
|
||||
|
||||
def _is_qdrant_config(config: Any) -> TypeIs[QdrantConfig]:
|
||||
"""Check if config is a QdrantConfig using safe duck typing.
|
||||
class AddDocumentParams(TypedDict, total=False):
|
||||
"""Parameters for adding documents to the RAG system."""
|
||||
|
||||
Args:
|
||||
config: RAG configuration to check.
|
||||
|
||||
Returns:
|
||||
True if config is a QdrantConfig instance.
|
||||
"""
|
||||
if not is_pydantic_dataclass(config):
|
||||
return False
|
||||
|
||||
try:
|
||||
return cast(bool, config.provider == "qdrant") # type: ignore[attr-defined]
|
||||
except (AttributeError, ImportError):
|
||||
return False
|
||||
data_type: DataType
|
||||
metadata: dict[str, Any]
|
||||
website: str
|
||||
url: str
|
||||
file_path: str | Path
|
||||
github_url: str
|
||||
youtube_url: str
|
||||
directory_path: str | Path
|
||||
|
||||
|
||||
class CrewAIRagAdapter(Adapter):
|
||||
@@ -63,9 +56,8 @@ class CrewAIRagAdapter(Adapter):
|
||||
else:
|
||||
self._client = get_rag_client()
|
||||
collection_params: dict[str, Any] = {"collection_name": self.collection_name}
|
||||
|
||||
if self.config is not None and _is_qdrant_config(self.config):
|
||||
if self.config.vectors_config is not None:
|
||||
if isinstance(self.config, QdrantConfig) and self.config.vectors_config:
|
||||
if isinstance(self.config.vectors_config, VectorParams):
|
||||
collection_params["vectors_config"] = self.config.vectors_config
|
||||
self._client.get_or_create_collection(**collection_params)
|
||||
|
||||
@@ -115,26 +107,13 @@ class CrewAIRagAdapter(Adapter):
|
||||
def add(self, *args: ContentItem, **kwargs: Unpack[AddDocumentParams]) -> None:
|
||||
"""Add content to the knowledge base.
|
||||
|
||||
This method handles various input types and converts them to documents
|
||||
for the vector database. It supports the data_type parameter for
|
||||
compatibility with existing tools.
|
||||
|
||||
Args:
|
||||
*args: Content items to add (strings, paths, or document dicts)
|
||||
**kwargs: Additional parameters including:
|
||||
- data_type: DataType enum or string (e.g., "file", "pdf_file", "text")
|
||||
- path: Path to file or directory (alternative to positional arg)
|
||||
- file_path: Alias for path
|
||||
- metadata: Additional metadata to attach to documents
|
||||
- url: URL to fetch content from
|
||||
- website: Website URL to scrape
|
||||
- github_url: GitHub repository URL
|
||||
- youtube_url: YouTube video URL
|
||||
- directory_path: Path to directory
|
||||
|
||||
Examples:
|
||||
rag_tool.add("path/to/document.pdf", data_type=DataType.PDF_FILE)
|
||||
|
||||
rag_tool.add(path="path/to/document.pdf", data_type="file")
|
||||
rag_tool.add(file_path="path/to/document.pdf", data_type="pdf_file")
|
||||
|
||||
rag_tool.add("path/to/document.pdf") # auto-detects PDF
|
||||
**kwargs: Additional parameters including data_type, metadata, etc.
|
||||
"""
|
||||
import os
|
||||
|
||||
@@ -143,54 +122,10 @@ class CrewAIRagAdapter(Adapter):
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
documents: list[BaseRecord] = []
|
||||
raw_data_type = kwargs.get("data_type")
|
||||
data_type: DataType | None = kwargs.get("data_type")
|
||||
base_metadata: dict[str, Any] = kwargs.get("metadata", {})
|
||||
|
||||
data_type: DataType | None = None
|
||||
if raw_data_type is not None:
|
||||
if isinstance(raw_data_type, DataType):
|
||||
if raw_data_type != DataType.FILE:
|
||||
data_type = raw_data_type
|
||||
elif isinstance(raw_data_type, str):
|
||||
if raw_data_type != "file":
|
||||
try:
|
||||
data_type = DataType(raw_data_type)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Invalid data_type: '{raw_data_type}'. "
|
||||
f"Valid values are: 'file' (auto-detect), or one of: "
|
||||
f"{', '.join(dt.value for dt in DataType)}"
|
||||
) from None
|
||||
|
||||
content_items: list[ContentItem] = list(args)
|
||||
|
||||
path_value = kwargs.get("path") or kwargs.get("file_path")
|
||||
if path_value is not None:
|
||||
content_items.append(path_value)
|
||||
|
||||
if url := kwargs.get("url"):
|
||||
content_items.append(url)
|
||||
if website := kwargs.get("website"):
|
||||
content_items.append(website)
|
||||
if github_url := kwargs.get("github_url"):
|
||||
content_items.append(github_url)
|
||||
if youtube_url := kwargs.get("youtube_url"):
|
||||
content_items.append(youtube_url)
|
||||
if directory_path := kwargs.get("directory_path"):
|
||||
content_items.append(directory_path)
|
||||
|
||||
file_extensions = {
|
||||
".pdf",
|
||||
".txt",
|
||||
".csv",
|
||||
".json",
|
||||
".xml",
|
||||
".docx",
|
||||
".mdx",
|
||||
".md",
|
||||
}
|
||||
|
||||
for arg in content_items:
|
||||
for arg in args:
|
||||
source_ref: str
|
||||
if isinstance(arg, dict):
|
||||
source_ref = str(arg.get("source", arg.get("content", "")))
|
||||
@@ -198,14 +133,6 @@ class CrewAIRagAdapter(Adapter):
|
||||
source_ref = str(arg)
|
||||
|
||||
if not data_type:
|
||||
ext = os.path.splitext(source_ref)[1].lower()
|
||||
is_url = source_ref.startswith(("http://", "https://", "file://"))
|
||||
if (
|
||||
ext in file_extensions
|
||||
and not is_url
|
||||
and not os.path.isfile(source_ref)
|
||||
):
|
||||
raise FileNotFoundError(f"File does not exist: {source_ref}")
|
||||
data_type = DataTypes.from_content(source_ref)
|
||||
|
||||
if data_type == DataType.DIRECTORY:
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from enum import Enum
|
||||
from importlib import import_module
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader
|
||||
@@ -10,7 +8,6 @@ from crewai_tools.rag.chunkers.base_chunker import BaseChunker
|
||||
|
||||
|
||||
class DataType(str, Enum):
|
||||
FILE = "file"
|
||||
PDF_FILE = "pdf_file"
|
||||
TEXT_FILE = "text_file"
|
||||
CSV = "csv"
|
||||
@@ -18,14 +15,22 @@ class DataType(str, Enum):
|
||||
XML = "xml"
|
||||
DOCX = "docx"
|
||||
MDX = "mdx"
|
||||
|
||||
# Database types
|
||||
MYSQL = "mysql"
|
||||
POSTGRES = "postgres"
|
||||
|
||||
# Repository types
|
||||
GITHUB = "github"
|
||||
DIRECTORY = "directory"
|
||||
|
||||
# Web types
|
||||
WEBSITE = "website"
|
||||
DOCS_SITE = "docs_site"
|
||||
YOUTUBE_VIDEO = "youtube_video"
|
||||
YOUTUBE_CHANNEL = "youtube_channel"
|
||||
|
||||
# Raw types
|
||||
TEXT = "text"
|
||||
|
||||
def get_chunker(self) -> BaseChunker:
|
||||
@@ -58,11 +63,13 @@ class DataType(str, Enum):
|
||||
|
||||
try:
|
||||
module = import_module(module_path)
|
||||
return cast(BaseChunker, getattr(module, class_name)())
|
||||
return getattr(module, class_name)()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error loading chunker for {self}: {e}") from e
|
||||
|
||||
def get_loader(self) -> BaseLoader:
|
||||
from importlib import import_module
|
||||
|
||||
loaders = {
|
||||
DataType.PDF_FILE: ("pdf_loader", "PDFLoader"),
|
||||
DataType.TEXT_FILE: ("text_loader", "TextFileLoader"),
|
||||
@@ -91,7 +98,7 @@ class DataType(str, Enum):
|
||||
module_path = f"crewai_tools.rag.loaders.{module_name}"
|
||||
try:
|
||||
module = import_module(module_path)
|
||||
return cast(BaseLoader, getattr(module, class_name)())
|
||||
return getattr(module, class_name)()
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error loading loader for {self}: {e}") from e
|
||||
|
||||
|
||||
@@ -2,112 +2,70 @@
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
from urllib.parse import urlparse
|
||||
import urllib.request
|
||||
from typing import Any
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader, LoaderResult
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
|
||||
|
||||
class PDFLoader(BaseLoader):
|
||||
"""Loader for PDF files and URLs."""
|
||||
"""Loader for PDF files."""
|
||||
|
||||
@staticmethod
|
||||
def _is_url(path: str) -> bool:
|
||||
"""Check if the path is a URL."""
|
||||
try:
|
||||
parsed = urlparse(path)
|
||||
return parsed.scheme in ("http", "https")
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _download_pdf(url: str) -> bytes:
|
||||
"""Download PDF content from a URL.
|
||||
def load(self, source: SourceContent, **kwargs) -> LoaderResult: # type: ignore[override]
|
||||
"""Load and extract text from a PDF file.
|
||||
|
||||
Args:
|
||||
url: The URL to download from.
|
||||
source: The source content containing the PDF file path
|
||||
|
||||
Returns:
|
||||
The PDF content as bytes.
|
||||
LoaderResult with extracted text content
|
||||
|
||||
Raises:
|
||||
ValueError: If the download fails.
|
||||
"""
|
||||
|
||||
try:
|
||||
with urllib.request.urlopen(url, timeout=30) as response: # noqa: S310
|
||||
return cast(bytes, response.read())
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to download PDF from {url}: {e!s}") from e
|
||||
|
||||
def load(self, source: SourceContent, **kwargs: Any) -> LoaderResult: # type: ignore[override]
|
||||
"""Load and extract text from a PDF file or URL.
|
||||
|
||||
Args:
|
||||
source: The source content containing the PDF file path or URL.
|
||||
|
||||
Returns:
|
||||
LoaderResult with extracted text content.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the PDF file doesn't exist.
|
||||
ImportError: If required PDF libraries aren't installed.
|
||||
ValueError: If the PDF cannot be read or downloaded.
|
||||
FileNotFoundError: If the PDF file doesn't exist
|
||||
ImportError: If required PDF libraries aren't installed
|
||||
"""
|
||||
try:
|
||||
import pymupdf # type: ignore[import-untyped]
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"PDF support requires pymupdf. Install with: uv add pymupdf"
|
||||
) from e
|
||||
import pypdf
|
||||
except ImportError:
|
||||
try:
|
||||
import PyPDF2 as pypdf # type: ignore[import-not-found,no-redef] # noqa: N813
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"PDF support requires pypdf or PyPDF2. Install with: uv add pypdf"
|
||||
) from e
|
||||
|
||||
file_path = source.source
|
||||
is_url = self._is_url(file_path)
|
||||
|
||||
if is_url:
|
||||
source_name = Path(urlparse(file_path).path).name or "downloaded.pdf"
|
||||
else:
|
||||
source_name = Path(file_path).name
|
||||
if not os.path.isfile(file_path):
|
||||
raise FileNotFoundError(f"PDF file not found: {file_path}")
|
||||
|
||||
text_content: list[str] = []
|
||||
text_content = []
|
||||
metadata: dict[str, Any] = {
|
||||
"source": file_path,
|
||||
"file_name": source_name,
|
||||
"source": str(file_path),
|
||||
"file_name": Path(file_path).name,
|
||||
"file_type": "pdf",
|
||||
}
|
||||
|
||||
try:
|
||||
if is_url:
|
||||
pdf_bytes = self._download_pdf(file_path)
|
||||
doc = pymupdf.open(stream=pdf_bytes, filetype="pdf")
|
||||
else:
|
||||
if not os.path.isfile(file_path):
|
||||
raise FileNotFoundError(f"PDF file not found: {file_path}")
|
||||
doc = pymupdf.open(file_path)
|
||||
with open(file_path, "rb") as file:
|
||||
pdf_reader = pypdf.PdfReader(file)
|
||||
metadata["num_pages"] = len(pdf_reader.pages)
|
||||
|
||||
metadata["num_pages"] = len(doc)
|
||||
|
||||
for page_num, page in enumerate(doc, 1):
|
||||
page_text = page.get_text()
|
||||
if page_text.strip():
|
||||
text_content.append(f"Page {page_num}:\n{page_text}")
|
||||
|
||||
doc.close()
|
||||
except FileNotFoundError:
|
||||
raise
|
||||
for page_num, page in enumerate(pdf_reader.pages, 1):
|
||||
page_text = page.extract_text()
|
||||
if page_text.strip():
|
||||
text_content.append(f"Page {page_num}:\n{page_text}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error reading PDF from {file_path}: {e!s}") from e
|
||||
raise ValueError(f"Error reading PDF file {file_path}: {e!s}") from e
|
||||
|
||||
if not text_content:
|
||||
content = f"[PDF file with no extractable text: {source_name}]"
|
||||
content = f"[PDF file with no extractable text: {Path(file_path).name}]"
|
||||
else:
|
||||
content = "\n\n".join(text_content)
|
||||
|
||||
return LoaderResult(
|
||||
content=content,
|
||||
source=file_path,
|
||||
source=str(file_path),
|
||||
metadata=metadata,
|
||||
doc_id=self.generate_doc_id(source_ref=file_path, content=content),
|
||||
doc_id=self.generate_doc_id(source_ref=str(file_path), content=content),
|
||||
)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||
@@ -25,17 +24,14 @@ class PDFSearchTool(RagTool):
|
||||
"A tool that can be used to semantic search a query from a PDF's content."
|
||||
)
|
||||
args_schema: type[BaseModel] = PDFSearchToolSchema
|
||||
pdf: str | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _configure_for_pdf(self) -> Self:
|
||||
"""Configure tool for specific PDF if provided."""
|
||||
if self.pdf is not None:
|
||||
self.add(self.pdf)
|
||||
self.description = f"A tool that can be used to semantic search a query the {self.pdf} PDF's content."
|
||||
def __init__(self, pdf: str | None = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if pdf is not None:
|
||||
self.add(pdf)
|
||||
self.description = f"A tool that can be used to semantic search a query the {pdf} PDF's content."
|
||||
self.args_schema = FixedPDFSearchToolSchema
|
||||
self._generate_description()
|
||||
return self
|
||||
|
||||
def add(self, pdf: str) -> None:
|
||||
super().add(pdf, data_type=DataType.PDF_FILE)
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
from crewai.rag.embeddings.types import ProviderSpec
|
||||
|
||||
from crewai_tools.tools.rag.types import RagToolConfig, VectorDbConfig
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ProviderSpec",
|
||||
"RagToolConfig",
|
||||
"VectorDbConfig",
|
||||
]
|
||||
|
||||
@@ -1,84 +1,13 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Literal, cast
|
||||
import os
|
||||
from typing import Any, cast
|
||||
|
||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
from crewai.rag.embeddings.types import ProviderSpec
|
||||
from crewai.rag.embeddings.factory import get_embedding_function
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
TypeAdapter,
|
||||
ValidationError,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Self, Unpack
|
||||
|
||||
from crewai_tools.tools.rag.types import (
|
||||
AddDocumentParams,
|
||||
ContentItem,
|
||||
RagToolConfig,
|
||||
VectorDbConfig,
|
||||
)
|
||||
|
||||
|
||||
def _validate_embedding_config(
|
||||
value: dict[str, Any] | ProviderSpec,
|
||||
) -> dict[str, Any] | ProviderSpec:
|
||||
"""Validate embedding config and provide clearer error messages for union validation.
|
||||
|
||||
This pre-validator catches Pydantic ValidationErrors from the ProviderSpec union
|
||||
and provides a cleaner, more focused error message that only shows the relevant
|
||||
provider's validation errors instead of all 18 union members.
|
||||
|
||||
Args:
|
||||
value: The embedding configuration dictionary or validated ProviderSpec.
|
||||
|
||||
Returns:
|
||||
A validated ProviderSpec instance, or the original value if already validated
|
||||
or missing required fields.
|
||||
|
||||
Raises:
|
||||
ValueError: If the configuration is invalid for the specified provider.
|
||||
"""
|
||||
if not isinstance(value, dict):
|
||||
return value
|
||||
|
||||
provider = value.get("provider")
|
||||
if not provider:
|
||||
return value
|
||||
|
||||
try:
|
||||
type_adapter: TypeAdapter[ProviderSpec] = TypeAdapter(ProviderSpec)
|
||||
return type_adapter.validate_python(value)
|
||||
except ValidationError as e:
|
||||
provider_key = f"{provider.lower()}providerspec"
|
||||
provider_errors = [
|
||||
err for err in e.errors() if provider_key in str(err.get("loc", "")).lower()
|
||||
]
|
||||
|
||||
if provider_errors:
|
||||
error_msgs = []
|
||||
for err in provider_errors:
|
||||
loc_parts = err["loc"]
|
||||
if str(loc_parts[0]).lower() == provider_key:
|
||||
loc_parts = loc_parts[1:]
|
||||
loc = ".".join(str(x) for x in loc_parts)
|
||||
error_msgs.append(f" - {loc}: {err['msg']}")
|
||||
|
||||
raise ValueError(
|
||||
f"Invalid configuration for embedding provider '{provider}':\n"
|
||||
+ "\n".join(error_msgs)
|
||||
) from e
|
||||
|
||||
raise
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
|
||||
class Adapter(BaseModel, ABC):
|
||||
"""Abstract base class for RAG adapters."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@abstractmethod
|
||||
@@ -93,8 +22,8 @@ class Adapter(BaseModel, ABC):
|
||||
@abstractmethod
|
||||
def add(
|
||||
self,
|
||||
*args: ContentItem,
|
||||
**kwargs: Unpack[AddDocumentParams],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Add content to the knowledge base."""
|
||||
|
||||
@@ -109,11 +38,7 @@ class RagTool(BaseTool):
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: ContentItem,
|
||||
**kwargs: Unpack[AddDocumentParams],
|
||||
) -> None:
|
||||
def add(self, *args: Any, **kwargs: Any) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
name: str = "Knowledge base"
|
||||
@@ -121,131 +46,145 @@ class RagTool(BaseTool):
|
||||
summarize: bool = False
|
||||
similarity_threshold: float = 0.6
|
||||
limit: int = 5
|
||||
collection_name: str = "rag_tool_collection"
|
||||
adapter: Adapter = Field(default_factory=_AdapterPlaceholder)
|
||||
config: RagToolConfig = Field(
|
||||
default_factory=RagToolConfig,
|
||||
description="Configuration format accepted by RagTool.",
|
||||
)
|
||||
|
||||
@field_validator("config", mode="before")
|
||||
@classmethod
|
||||
def _validate_config(cls, value: Any) -> Any:
|
||||
"""Validate config with improved error messages for embedding providers."""
|
||||
if not isinstance(value, dict):
|
||||
return value
|
||||
|
||||
embedding_model = value.get("embedding_model")
|
||||
if embedding_model:
|
||||
try:
|
||||
value["embedding_model"] = _validate_embedding_config(embedding_model)
|
||||
except ValueError:
|
||||
raise
|
||||
|
||||
return value
|
||||
config: Any | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _ensure_adapter(self) -> Self:
|
||||
def _set_default_adapter(self):
|
||||
if isinstance(self.adapter, RagTool._AdapterPlaceholder):
|
||||
from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
|
||||
|
||||
provider_cfg = self._parse_config(self.config)
|
||||
parsed_config = self._parse_config(self.config)
|
||||
|
||||
self.adapter = CrewAIRagAdapter(
|
||||
collection_name=self.collection_name,
|
||||
collection_name="rag_tool_collection",
|
||||
summarize=self.summarize,
|
||||
similarity_threshold=self.similarity_threshold,
|
||||
limit=self.limit,
|
||||
config=provider_cfg,
|
||||
config=parsed_config,
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def _parse_config(self, config: RagToolConfig) -> Any:
|
||||
"""Normalize the RagToolConfig into a provider-specific config object.
|
||||
def _parse_config(self, config: Any) -> Any:
|
||||
"""Parse complex config format to extract provider-specific config.
|
||||
|
||||
Defaults to 'chromadb' with no extra provider config if none is supplied.
|
||||
Raises:
|
||||
ValueError: If the config format is invalid or uses unsupported providers.
|
||||
"""
|
||||
if not config:
|
||||
return self._create_provider_config("chromadb", {}, None)
|
||||
if config is None:
|
||||
return None
|
||||
|
||||
vectordb_cfg = cast(VectorDbConfig, config.get("vectordb", {}))
|
||||
provider: Literal["chromadb", "qdrant"] = vectordb_cfg.get(
|
||||
"provider", "chromadb"
|
||||
)
|
||||
provider_config: dict[str, Any] = vectordb_cfg.get("config", {})
|
||||
if isinstance(config, dict) and "provider" in config:
|
||||
return config
|
||||
|
||||
supported = ("chromadb", "qdrant")
|
||||
if provider not in supported:
|
||||
raise ValueError(
|
||||
f"Unsupported vector database provider: '{provider}'. "
|
||||
f"CrewAI RAG currently supports: {', '.join(supported)}."
|
||||
)
|
||||
if isinstance(config, dict):
|
||||
if "vectordb" in config:
|
||||
vectordb_config = config["vectordb"]
|
||||
if isinstance(vectordb_config, dict) and "provider" in vectordb_config:
|
||||
provider = vectordb_config["provider"]
|
||||
provider_config = vectordb_config.get("config", {})
|
||||
|
||||
embedding_spec: ProviderSpec | None = config.get("embedding_model")
|
||||
if embedding_spec:
|
||||
embedding_spec = cast(
|
||||
ProviderSpec, _validate_embedding_config(embedding_spec)
|
||||
)
|
||||
supported_providers = ["chromadb", "qdrant"]
|
||||
if provider not in supported_providers:
|
||||
raise ValueError(
|
||||
f"Unsupported vector database provider: '{provider}'. "
|
||||
f"CrewAI RAG currently supports: {', '.join(supported_providers)}."
|
||||
)
|
||||
|
||||
embedding_function = build_embedder(embedding_spec) if embedding_spec else None
|
||||
return self._create_provider_config(
|
||||
provider, provider_config, embedding_function
|
||||
)
|
||||
embedding_config = config.get("embedding_model")
|
||||
embedding_function = None
|
||||
if embedding_config and isinstance(embedding_config, dict):
|
||||
embedding_function = self._create_embedding_function(
|
||||
embedding_config, provider
|
||||
)
|
||||
|
||||
return self._create_provider_config(
|
||||
provider, provider_config, embedding_function
|
||||
)
|
||||
return None
|
||||
embedding_config = config.get("embedding_model")
|
||||
embedding_function = None
|
||||
if embedding_config and isinstance(embedding_config, dict):
|
||||
embedding_function = self._create_embedding_function(
|
||||
embedding_config, "chromadb"
|
||||
)
|
||||
|
||||
return self._create_provider_config("chromadb", {}, embedding_function)
|
||||
return config
|
||||
|
||||
@staticmethod
|
||||
def _create_embedding_function(embedding_config: dict, provider: str) -> Any:
|
||||
"""Create embedding function for the specified vector database provider."""
|
||||
embedding_provider = embedding_config.get("provider")
|
||||
embedding_model_config = embedding_config.get("config", {}).copy()
|
||||
|
||||
if "model" in embedding_model_config:
|
||||
embedding_model_config["model_name"] = embedding_model_config.pop("model")
|
||||
|
||||
factory_config = {"provider": embedding_provider, **embedding_model_config}
|
||||
|
||||
if embedding_provider == "openai" and "api_key" not in factory_config:
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
if api_key:
|
||||
factory_config["api_key"] = api_key
|
||||
|
||||
if provider == "chromadb":
|
||||
return get_embedding_function(factory_config) # type: ignore[call-overload]
|
||||
|
||||
if provider == "qdrant":
|
||||
chromadb_func = get_embedding_function(factory_config) # type: ignore[call-overload]
|
||||
|
||||
def qdrant_embed_fn(text: str) -> list[float]:
|
||||
"""Embed text using ChromaDB function and convert to list of floats for Qdrant.
|
||||
|
||||
Args:
|
||||
text: The input text to embed.
|
||||
|
||||
Returns:
|
||||
A list of floats representing the embedding.
|
||||
"""
|
||||
embeddings = chromadb_func([text])
|
||||
return embeddings[0] if embeddings and len(embeddings) > 0 else []
|
||||
|
||||
return cast(Any, qdrant_embed_fn)
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _create_provider_config(
|
||||
provider: Literal["chromadb", "qdrant"],
|
||||
provider_config: dict[str, Any],
|
||||
embedding_function: EmbeddingFunction[Any] | None,
|
||||
provider: str, provider_config: dict, embedding_function: Any
|
||||
) -> Any:
|
||||
"""Instantiate provider config with optional embedding_function injected."""
|
||||
"""Create proper provider config object."""
|
||||
if provider == "chromadb":
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
|
||||
kwargs = dict(provider_config)
|
||||
if embedding_function is not None:
|
||||
kwargs["embedding_function"] = embedding_function
|
||||
return ChromaDBConfig(**kwargs)
|
||||
config_kwargs = {}
|
||||
if embedding_function:
|
||||
config_kwargs["embedding_function"] = embedding_function
|
||||
|
||||
config_kwargs.update(provider_config)
|
||||
|
||||
return ChromaDBConfig(**config_kwargs)
|
||||
|
||||
if provider == "qdrant":
|
||||
from crewai.rag.qdrant.config import QdrantConfig
|
||||
|
||||
kwargs = dict(provider_config)
|
||||
if embedding_function is not None:
|
||||
kwargs["embedding_function"] = embedding_function
|
||||
return QdrantConfig(**kwargs)
|
||||
config_kwargs = {}
|
||||
if embedding_function:
|
||||
config_kwargs["embedding_function"] = embedding_function
|
||||
|
||||
raise ValueError(f"Unhandled provider: {provider}")
|
||||
config_kwargs.update(provider_config)
|
||||
|
||||
return QdrantConfig(**config_kwargs)
|
||||
|
||||
return None
|
||||
|
||||
def add(
|
||||
self,
|
||||
*args: ContentItem,
|
||||
**kwargs: Unpack[AddDocumentParams],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Add content to the knowledge base.
|
||||
|
||||
|
||||
Args:
|
||||
*args: Content items to add (strings, paths, or document dicts)
|
||||
data_type: DataType enum or string (e.g., "file", "pdf_file", "text")
|
||||
path: Path to file or directory, alias to positional arg
|
||||
file_path: Alias for path
|
||||
metadata: Additional metadata to attach to documents
|
||||
url: URL to fetch content from
|
||||
website: Website URL to scrape
|
||||
github_url: GitHub repository URL
|
||||
youtube_url: YouTube video URL
|
||||
directory_path: Path to directory
|
||||
|
||||
Examples:
|
||||
rag_tool.add("path/to/document.pdf", data_type=DataType.PDF_FILE)
|
||||
|
||||
# Keyword argument (documented API)
|
||||
rag_tool.add(path="path/to/document.pdf", data_type="file")
|
||||
rag_tool.add(file_path="path/to/document.pdf", data_type="pdf_file")
|
||||
|
||||
# Auto-detect type from extension
|
||||
rag_tool.add("path/to/document.pdf") # auto-detects PDF
|
||||
"""
|
||||
self.adapter.add(*args, **kwargs)
|
||||
|
||||
def _run(
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
"""Type definitions for RAG tool configuration."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, TypeAlias
|
||||
|
||||
from crewai.rag.embeddings.types import ProviderSpec
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
|
||||
|
||||
DataTypeStr: TypeAlias = Literal[
|
||||
"file",
|
||||
"pdf_file",
|
||||
"text_file",
|
||||
"csv",
|
||||
"json",
|
||||
"xml",
|
||||
"docx",
|
||||
"mdx",
|
||||
"mysql",
|
||||
"postgres",
|
||||
"github",
|
||||
"directory",
|
||||
"website",
|
||||
"docs_site",
|
||||
"youtube_video",
|
||||
"youtube_channel",
|
||||
"text",
|
||||
]
|
||||
|
||||
ContentItem: TypeAlias = str | Path | dict[str, Any]
|
||||
|
||||
|
||||
class AddDocumentParams(TypedDict, total=False):
|
||||
"""Parameters for adding documents to the RAG system."""
|
||||
|
||||
data_type: DataType | DataTypeStr
|
||||
metadata: dict[str, Any]
|
||||
path: str | Path
|
||||
file_path: str | Path
|
||||
website: str
|
||||
url: str
|
||||
github_url: str
|
||||
youtube_url: str
|
||||
directory_path: str | Path
|
||||
|
||||
|
||||
class VectorDbConfig(TypedDict):
|
||||
"""Configuration for vector database provider.
|
||||
|
||||
Attributes:
|
||||
provider: RAG provider literal.
|
||||
config: RAG configuration options.
|
||||
"""
|
||||
|
||||
provider: Literal["chromadb", "qdrant"]
|
||||
config: dict[str, Any]
|
||||
|
||||
|
||||
class RagToolConfig(TypedDict, total=False):
|
||||
"""Configuration accepted by RAG tools.
|
||||
|
||||
Supports embedding model and vector database configuration.
|
||||
|
||||
Attributes:
|
||||
embedding_model: Embedding model configuration accepted by RAG tools.
|
||||
vectordb: Vector database configuration accepted by RAG tools.
|
||||
"""
|
||||
|
||||
embedding_model: ProviderSpec
|
||||
vectordb: VectorDbConfig
|
||||
@@ -1,5 +1,4 @@
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||
|
||||
@@ -25,17 +24,14 @@ class TXTSearchTool(RagTool):
|
||||
"A tool that can be used to semantic search a query from a txt's content."
|
||||
)
|
||||
args_schema: type[BaseModel] = TXTSearchToolSchema
|
||||
txt: str | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _configure_for_txt(self) -> Self:
|
||||
"""Configure tool for specific TXT file if provided."""
|
||||
if self.txt is not None:
|
||||
self.add(self.txt)
|
||||
self.description = f"A tool that can be used to semantic search a query the {self.txt} txt's content."
|
||||
def __init__(self, txt: str | None = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if txt is not None:
|
||||
self.add(txt)
|
||||
self.description = f"A tool that can be used to semantic search a query the {txt} txt's content."
|
||||
self.args_schema = FixedTXTSearchToolSchema
|
||||
self._generate_description()
|
||||
return self
|
||||
|
||||
def _run( # type: ignore[override]
|
||||
self,
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Tests for RAG tool with mocked embeddings and vector database."""
|
||||
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import cast
|
||||
@@ -115,15 +117,15 @@ def test_rag_tool_with_file(
|
||||
assert "Python is a programming language" in result
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.rag.rag_tool.build_embedder")
|
||||
@patch("crewai_tools.tools.rag.rag_tool.RagTool._create_embedding_function")
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
|
||||
def test_rag_tool_with_custom_embeddings(
|
||||
mock_create_client: Mock, mock_build_embedder: Mock
|
||||
mock_create_client: Mock, mock_create_embedding: Mock
|
||||
) -> None:
|
||||
"""Test RagTool with custom embeddings configuration to ensure no API calls."""
|
||||
mock_embedding_func = MagicMock()
|
||||
mock_embedding_func.return_value = [[0.2] * 1536]
|
||||
mock_build_embedder.return_value = mock_embedding_func
|
||||
mock_create_embedding.return_value = mock_embedding_func
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||
@@ -151,7 +153,7 @@ def test_rag_tool_with_custom_embeddings(
|
||||
assert "Relevant Content:" in result
|
||||
assert "Test content" in result
|
||||
|
||||
mock_build_embedder.assert_called()
|
||||
mock_create_embedding.assert_called()
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.get_rag_client")
|
||||
@@ -174,128 +176,3 @@ def test_rag_tool_no_results(
|
||||
result = tool._run(query="Non-existent content")
|
||||
assert "Relevant Content:" in result
|
||||
assert "No relevant content found" in result
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
|
||||
def test_rag_tool_with_azure_config_without_env_vars(
|
||||
mock_create_client: Mock,
|
||||
) -> None:
|
||||
"""Test that RagTool accepts Azure config without requiring env vars.
|
||||
|
||||
This test verifies the fix for the issue where RAG tools were ignoring
|
||||
the embedding configuration passed via the config parameter and instead
|
||||
requiring environment variables like EMBEDDINGS_OPENAI_API_KEY.
|
||||
"""
|
||||
mock_embedding_func = MagicMock()
|
||||
mock_embedding_func.return_value = [[0.1] * 1536]
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||
mock_client.add_documents = MagicMock(return_value=None)
|
||||
mock_create_client.return_value = mock_client
|
||||
|
||||
# Patch the embedding function builder to avoid actual API calls
|
||||
with patch(
|
||||
"crewai_tools.tools.rag.rag_tool.build_embedder",
|
||||
return_value=mock_embedding_func,
|
||||
):
|
||||
|
||||
class MyTool(RagTool):
|
||||
pass
|
||||
|
||||
# Configuration with explicit Azure credentials - should work without env vars
|
||||
config = {
|
||||
"embedding_model": {
|
||||
"provider": "azure",
|
||||
"config": {
|
||||
"model": "text-embedding-3-small",
|
||||
"api_key": "test-api-key",
|
||||
"api_base": "https://test.openai.azure.com/",
|
||||
"api_version": "2024-02-01",
|
||||
"api_type": "azure",
|
||||
"deployment_id": "test-deployment",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
# This should not raise a validation error about missing env vars
|
||||
tool = MyTool(config=config)
|
||||
|
||||
assert tool.adapter is not None
|
||||
assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
|
||||
def test_rag_tool_with_openai_config_without_env_vars(
|
||||
mock_create_client: Mock,
|
||||
) -> None:
|
||||
"""Test that RagTool accepts OpenAI config without requiring env vars."""
|
||||
mock_embedding_func = MagicMock()
|
||||
mock_embedding_func.return_value = [[0.1] * 1536]
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||
mock_create_client.return_value = mock_client
|
||||
|
||||
with patch(
|
||||
"crewai_tools.tools.rag.rag_tool.build_embedder",
|
||||
return_value=mock_embedding_func,
|
||||
):
|
||||
|
||||
class MyTool(RagTool):
|
||||
pass
|
||||
|
||||
config = {
|
||||
"embedding_model": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": "text-embedding-3-small",
|
||||
"api_key": "sk-test123",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
tool = MyTool(config=config)
|
||||
|
||||
assert tool.adapter is not None
|
||||
assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
|
||||
def test_rag_tool_config_with_qdrant_and_azure_embeddings(
|
||||
mock_create_client: Mock,
|
||||
) -> None:
|
||||
"""Test RagTool with Qdrant vector DB and Azure embeddings config."""
|
||||
mock_embedding_func = MagicMock()
|
||||
mock_embedding_func.return_value = [[0.1] * 1536]
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||
mock_create_client.return_value = mock_client
|
||||
|
||||
with patch(
|
||||
"crewai_tools.tools.rag.rag_tool.build_embedder",
|
||||
return_value=mock_embedding_func,
|
||||
):
|
||||
|
||||
class MyTool(RagTool):
|
||||
pass
|
||||
|
||||
config = {
|
||||
"vectordb": {"provider": "qdrant", "config": {}},
|
||||
"embedding_model": {
|
||||
"provider": "azure",
|
||||
"config": {
|
||||
"model": "text-embedding-3-large",
|
||||
"api_key": "test-key",
|
||||
"api_base": "https://test.openai.azure.com/",
|
||||
"api_version": "2024-02-01",
|
||||
"deployment_id": "test-deployment",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tool = MyTool(config=config)
|
||||
|
||||
assert tool.adapter is not None
|
||||
assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||
|
||||
@@ -1,471 +0,0 @@
|
||||
"""Tests for RagTool.add() method with various data_type values."""
|
||||
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_rag_client() -> MagicMock:
|
||||
"""Create a mock RAG client for testing."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||
mock_client.add_documents = MagicMock(return_value=None)
|
||||
mock_client.search = MagicMock(return_value=[])
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rag_tool(mock_rag_client: MagicMock) -> RagTool:
|
||||
"""Create a RagTool instance with mocked client."""
|
||||
with (
|
||||
patch(
|
||||
"crewai_tools.adapters.crewai_rag_adapter.get_rag_client",
|
||||
return_value=mock_rag_client,
|
||||
),
|
||||
patch(
|
||||
"crewai_tools.adapters.crewai_rag_adapter.create_client",
|
||||
return_value=mock_rag_client,
|
||||
),
|
||||
):
|
||||
return RagTool()
|
||||
|
||||
|
||||
class TestDataTypeFileAlias:
|
||||
"""Tests for data_type='file' alias."""
|
||||
|
||||
def test_file_alias_with_existing_file(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test that data_type='file' works with existing files."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("Test content for file alias.")
|
||||
|
||||
rag_tool.add(path=str(test_file), data_type="file")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_file_alias_with_nonexistent_file_raises_error(
|
||||
self, rag_tool: RagTool
|
||||
) -> None:
|
||||
"""Test that data_type='file' raises FileNotFoundError for missing files."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent/path/to/file.pdf", data_type="file")
|
||||
|
||||
def test_file_alias_with_path_keyword(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test that path keyword argument works with data_type='file'."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "document.txt"
|
||||
test_file.write_text("Content via path keyword.")
|
||||
|
||||
rag_tool.add(data_type="file", path=str(test_file))
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_file_alias_with_file_path_keyword(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test that file_path keyword argument works with data_type='file'."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "document.txt"
|
||||
test_file.write_text("Content via file_path keyword.")
|
||||
|
||||
rag_tool.add(data_type="file", file_path=str(test_file))
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
|
||||
class TestDataTypeStringValues:
|
||||
"""Tests for data_type as string values matching DataType enum."""
|
||||
|
||||
def test_pdf_file_string(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type='pdf_file' with existing PDF file."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
# Create a minimal valid PDF file
|
||||
test_file = Path(tmpdir) / "test.pdf"
|
||||
test_file.write_bytes(
|
||||
b"%PDF-1.4\n1 0 obj\n<<\n/Type /Catalog\n>>\nendobj\ntrailer\n"
|
||||
b"<<\n/Root 1 0 R\n>>\n%%EOF"
|
||||
)
|
||||
|
||||
# Mock the PDF loader to avoid actual PDF parsing
|
||||
with patch(
|
||||
"crewai_tools.adapters.crewai_rag_adapter.DataType.get_loader"
|
||||
) as mock_loader:
|
||||
mock_loader_instance = MagicMock()
|
||||
mock_loader_instance.load.return_value = MagicMock(
|
||||
content="PDF content", metadata={}, doc_id="test-id"
|
||||
)
|
||||
mock_loader.return_value = mock_loader_instance
|
||||
|
||||
rag_tool.add(path=str(test_file), data_type="pdf_file")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_text_file_string(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type='text_file' with existing text file."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("Plain text content.")
|
||||
|
||||
rag_tool.add(path=str(test_file), data_type="text_file")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_csv_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None:
|
||||
"""Test data_type='csv' with existing CSV file."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.csv"
|
||||
test_file.write_text("name,value\nfoo,1\nbar,2")
|
||||
|
||||
rag_tool.add(path=str(test_file), data_type="csv")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_json_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None:
|
||||
"""Test data_type='json' with existing JSON file."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.json"
|
||||
test_file.write_text('{"key": "value", "items": [1, 2, 3]}')
|
||||
|
||||
rag_tool.add(path=str(test_file), data_type="json")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_xml_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None:
|
||||
"""Test data_type='xml' with existing XML file."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.xml"
|
||||
test_file.write_text('<?xml version="1.0"?><root><item>value</item></root>')
|
||||
|
||||
rag_tool.add(path=str(test_file), data_type="xml")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_mdx_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None:
|
||||
"""Test data_type='mdx' with existing MDX file."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.mdx"
|
||||
test_file.write_text("# Heading\n\nSome markdown content.")
|
||||
|
||||
rag_tool.add(path=str(test_file), data_type="mdx")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_text_string(self, rag_tool: RagTool, mock_rag_client: MagicMock) -> None:
|
||||
"""Test data_type='text' with raw text content."""
|
||||
rag_tool.add("This is raw text content.", data_type="text")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_directory_string(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type='directory' with existing directory."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
# Create some files in the directory
|
||||
(Path(tmpdir) / "file1.txt").write_text("Content 1")
|
||||
(Path(tmpdir) / "file2.txt").write_text("Content 2")
|
||||
|
||||
rag_tool.add(path=tmpdir, data_type="directory")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
|
||||
class TestDataTypeEnumValues:
|
||||
"""Tests for data_type as DataType enum values."""
|
||||
|
||||
def test_datatype_file_enum_with_existing_file(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type=DataType.FILE with existing file (auto-detect)."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("File enum auto-detect content.")
|
||||
|
||||
rag_tool.add(str(test_file), data_type=DataType.FILE)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_datatype_file_enum_with_nonexistent_file_raises_error(
|
||||
self, rag_tool: RagTool
|
||||
) -> None:
|
||||
"""Test data_type=DataType.FILE raises FileNotFoundError for missing files."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add("nonexistent/file.pdf", data_type=DataType.FILE)
|
||||
|
||||
def test_datatype_pdf_file_enum(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type=DataType.PDF_FILE with existing file."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.pdf"
|
||||
test_file.write_bytes(
|
||||
b"%PDF-1.4\n1 0 obj\n<<\n/Type /Catalog\n>>\nendobj\ntrailer\n"
|
||||
b"<<\n/Root 1 0 R\n>>\n%%EOF"
|
||||
)
|
||||
|
||||
with patch(
|
||||
"crewai_tools.adapters.crewai_rag_adapter.DataType.get_loader"
|
||||
) as mock_loader:
|
||||
mock_loader_instance = MagicMock()
|
||||
mock_loader_instance.load.return_value = MagicMock(
|
||||
content="PDF content", metadata={}, doc_id="test-id"
|
||||
)
|
||||
mock_loader.return_value = mock_loader_instance
|
||||
|
||||
rag_tool.add(str(test_file), data_type=DataType.PDF_FILE)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_datatype_text_file_enum(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type=DataType.TEXT_FILE with existing file."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("Text file content.")
|
||||
|
||||
rag_tool.add(str(test_file), data_type=DataType.TEXT_FILE)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_datatype_text_enum(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type=DataType.TEXT with raw text."""
|
||||
rag_tool.add("Raw text using enum.", data_type=DataType.TEXT)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_datatype_directory_enum(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test data_type=DataType.DIRECTORY with existing directory."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
(Path(tmpdir) / "file.txt").write_text("Directory file content.")
|
||||
|
||||
rag_tool.add(tmpdir, data_type=DataType.DIRECTORY)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
|
||||
class TestInvalidDataType:
|
||||
"""Tests for invalid data_type values."""
|
||||
|
||||
def test_invalid_string_data_type_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that invalid string data_type raises ValueError."""
|
||||
with pytest.raises(ValueError, match="Invalid data_type"):
|
||||
rag_tool.add("some content", data_type="invalid_type")
|
||||
|
||||
def test_invalid_data_type_error_message_contains_valid_values(
|
||||
self, rag_tool: RagTool
|
||||
) -> None:
|
||||
"""Test that error message lists valid data_type values."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
rag_tool.add("some content", data_type="not_a_type")
|
||||
|
||||
error_message = str(exc_info.value)
|
||||
assert "file" in error_message
|
||||
assert "pdf_file" in error_message
|
||||
assert "text_file" in error_message
|
||||
|
||||
|
||||
class TestFileExistenceValidation:
|
||||
"""Tests for file existence validation."""
|
||||
|
||||
def test_pdf_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent PDF file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent.pdf", data_type="pdf_file")
|
||||
|
||||
def test_text_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent text file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent.txt", data_type="text_file")
|
||||
|
||||
def test_csv_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent CSV file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent.csv", data_type="csv")
|
||||
|
||||
def test_json_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent JSON file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent.json", data_type="json")
|
||||
|
||||
def test_xml_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent XML file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent.xml", data_type="xml")
|
||||
|
||||
def test_docx_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent DOCX file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent.docx", data_type="docx")
|
||||
|
||||
def test_mdx_file_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent MDX file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add(path="nonexistent.mdx", data_type="mdx")
|
||||
|
||||
def test_directory_not_found_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that non-existent directory raises ValueError."""
|
||||
with pytest.raises(ValueError, match="Directory does not exist"):
|
||||
rag_tool.add(path="nonexistent/directory", data_type="directory")
|
||||
|
||||
|
||||
class TestKeywordArgumentVariants:
|
||||
"""Tests for different keyword argument combinations."""
|
||||
|
||||
def test_positional_argument_with_data_type(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test positional argument with data_type."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("Positional arg content.")
|
||||
|
||||
rag_tool.add(str(test_file), data_type="text_file")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_path_keyword_with_data_type(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test path keyword argument with data_type."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("Path keyword content.")
|
||||
|
||||
rag_tool.add(path=str(test_file), data_type="text_file")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_file_path_keyword_with_data_type(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test file_path keyword argument with data_type."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("File path keyword content.")
|
||||
|
||||
rag_tool.add(file_path=str(test_file), data_type="text_file")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_directory_path_keyword(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test directory_path keyword argument."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
(Path(tmpdir) / "file.txt").write_text("Directory content.")
|
||||
|
||||
rag_tool.add(directory_path=tmpdir)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
|
||||
class TestAutoDetection:
|
||||
"""Tests for auto-detection of data type from content."""
|
||||
|
||||
def test_auto_detect_nonexistent_file_raises_error(self, rag_tool: RagTool) -> None:
|
||||
"""Test that auto-detection raises FileNotFoundError for missing files."""
|
||||
with pytest.raises(FileNotFoundError, match="File does not exist"):
|
||||
rag_tool.add("path/to/document.pdf")
|
||||
|
||||
def test_auto_detect_txt_file(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test auto-detection of .txt file type."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "auto.txt"
|
||||
test_file.write_text("Auto-detected text file.")
|
||||
|
||||
# No data_type specified - should auto-detect
|
||||
rag_tool.add(str(test_file))
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_auto_detect_csv_file(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test auto-detection of .csv file type."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "auto.csv"
|
||||
test_file.write_text("col1,col2\nval1,val2")
|
||||
|
||||
rag_tool.add(str(test_file))
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_auto_detect_json_file(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test auto-detection of .json file type."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "auto.json"
|
||||
test_file.write_text('{"auto": "detected"}')
|
||||
|
||||
rag_tool.add(str(test_file))
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_auto_detect_directory(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test auto-detection of directory type."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
(Path(tmpdir) / "file.txt").write_text("Auto-detected directory.")
|
||||
|
||||
rag_tool.add(tmpdir)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
def test_auto_detect_raw_text(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test auto-detection of raw text (non-file content)."""
|
||||
rag_tool.add("Just some raw text content")
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
|
||||
|
||||
class TestMetadataHandling:
|
||||
"""Tests for metadata handling with data_type."""
|
||||
|
||||
def test_metadata_passed_to_documents(
|
||||
self, rag_tool: RagTool, mock_rag_client: MagicMock
|
||||
) -> None:
|
||||
"""Test that metadata is properly passed to documents."""
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
test_file = Path(tmpdir) / "test.txt"
|
||||
test_file.write_text("Content with metadata.")
|
||||
|
||||
rag_tool.add(
|
||||
path=str(test_file),
|
||||
data_type="text_file",
|
||||
metadata={"custom_key": "custom_value"},
|
||||
)
|
||||
|
||||
assert mock_rag_client.add_documents.called
|
||||
call_args = mock_rag_client.add_documents.call_args
|
||||
documents = call_args.kwargs.get("documents", call_args.args[0] if call_args.args else [])
|
||||
|
||||
# Check that at least one document has the custom metadata
|
||||
assert any(
|
||||
doc.get("metadata", {}).get("custom_key") == "custom_value"
|
||||
for doc in documents
|
||||
)
|
||||
@@ -1,66 +0,0 @@
|
||||
"""Tests for improved RAG tool validation error messages."""
|
||||
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
|
||||
def test_azure_missing_deployment_id_gives_clear_error(mock_create_client: Mock) -> None:
|
||||
"""Test that missing deployment_id for Azure gives a clear, focused error message."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||
mock_create_client.return_value = mock_client
|
||||
|
||||
class MyTool(RagTool):
|
||||
pass
|
||||
|
||||
config = {
|
||||
"embedding_model": {
|
||||
"provider": "azure",
|
||||
"config": {
|
||||
"api_base": "http://localhost:4000/v1",
|
||||
"api_key": "test-key",
|
||||
"api_version": "2024-02-01",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
MyTool(config=config)
|
||||
|
||||
error_msg = str(exc_info.value)
|
||||
assert "azure" in error_msg.lower()
|
||||
assert "deployment_id" in error_msg.lower()
|
||||
assert "bedrock" not in error_msg.lower()
|
||||
assert "cohere" not in error_msg.lower()
|
||||
assert "huggingface" not in error_msg.lower()
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
|
||||
def test_valid_azure_config_works(mock_create_client: Mock) -> None:
|
||||
"""Test that valid Azure config works without errors."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||
mock_create_client.return_value = mock_client
|
||||
|
||||
class MyTool(RagTool):
|
||||
pass
|
||||
|
||||
config = {
|
||||
"embedding_model": {
|
||||
"provider": "azure",
|
||||
"config": {
|
||||
"api_base": "http://localhost:4000/v1",
|
||||
"api_key": "test-key",
|
||||
"api_version": "2024-02-01",
|
||||
"deployment_id": "text-embedding-3-small",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
tool = MyTool(config=config)
|
||||
assert tool is not None
|
||||
@@ -1,116 +0,0 @@
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
|
||||
from crewai_tools.tools.pdf_search_tool.pdf_search_tool import PDFSearchTool
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
|
||||
def test_pdf_search_tool_with_azure_config_without_env_vars(
|
||||
mock_create_client: Mock,
|
||||
) -> None:
|
||||
"""Test PDFSearchTool accepts Azure config without requiring env vars.
|
||||
|
||||
This verifies the fix for the reported issue where PDFSearchTool would
|
||||
throw a validation error:
|
||||
pydantic_core._pydantic_core.ValidationError: 1 validation error for PDFSearchTool
|
||||
EMBEDDINGS_OPENAI_API_KEY
|
||||
Field required [type=missing, input_value={}, input_type=dict]
|
||||
"""
|
||||
mock_embedding_func = MagicMock()
|
||||
mock_embedding_func.return_value = [[0.1] * 1536]
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||
mock_create_client.return_value = mock_client
|
||||
|
||||
# Patch the embedding function builder to avoid actual API calls
|
||||
with patch(
|
||||
"crewai_tools.tools.rag.rag_tool.build_embedder",
|
||||
return_value=mock_embedding_func,
|
||||
):
|
||||
# This is the exact config format from the bug report
|
||||
config = {
|
||||
"embedding_model": {
|
||||
"provider": "azure",
|
||||
"config": {
|
||||
"model": "text-embedding-3-small",
|
||||
"api_key": "test-litellm-api-key",
|
||||
"api_base": "https://test.litellm.proxy/",
|
||||
"api_version": "2024-02-01",
|
||||
"api_type": "azure",
|
||||
"deployment_id": "test-deployment",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
# This should not raise a validation error about missing env vars
|
||||
tool = PDFSearchTool(config=config)
|
||||
|
||||
assert tool.adapter is not None
|
||||
assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||
assert tool.name == "Search a PDF's content"
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
|
||||
def test_pdf_search_tool_with_openai_config_without_env_vars(
|
||||
mock_create_client: Mock,
|
||||
) -> None:
|
||||
"""Test PDFSearchTool accepts OpenAI config without requiring env vars."""
|
||||
mock_embedding_func = MagicMock()
|
||||
mock_embedding_func.return_value = [[0.1] * 1536]
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||
mock_create_client.return_value = mock_client
|
||||
|
||||
with patch(
|
||||
"crewai_tools.tools.rag.rag_tool.build_embedder",
|
||||
return_value=mock_embedding_func,
|
||||
):
|
||||
config = {
|
||||
"embedding_model": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": "text-embedding-3-small",
|
||||
"api_key": "sk-test123",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
tool = PDFSearchTool(config=config)
|
||||
|
||||
assert tool.adapter is not None
|
||||
assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
|
||||
def test_pdf_search_tool_with_vectordb_and_embedding_config(
|
||||
mock_create_client: Mock,
|
||||
) -> None:
|
||||
"""Test PDFSearchTool with both vector DB and embedding config."""
|
||||
mock_embedding_func = MagicMock()
|
||||
mock_embedding_func.return_value = [[0.1] * 1536]
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||
mock_create_client.return_value = mock_client
|
||||
|
||||
with patch(
|
||||
"crewai_tools.tools.rag.rag_tool.build_embedder",
|
||||
return_value=mock_embedding_func,
|
||||
):
|
||||
config = {
|
||||
"vectordb": {"provider": "chromadb", "config": {}},
|
||||
"embedding_model": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": "text-embedding-3-large",
|
||||
"api_key": "sk-test-key",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tool = PDFSearchTool(config=config)
|
||||
|
||||
assert tool.adapter is not None
|
||||
assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||
@@ -1,104 +0,0 @@
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
|
||||
from crewai_tools.tools.txt_search_tool.txt_search_tool import TXTSearchTool
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
|
||||
def test_txt_search_tool_with_azure_config_without_env_vars(
|
||||
mock_create_client: Mock,
|
||||
) -> None:
|
||||
"""Test TXTSearchTool accepts Azure config without requiring env vars."""
|
||||
mock_embedding_func = MagicMock()
|
||||
mock_embedding_func.return_value = [[0.1] * 1536]
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||
mock_create_client.return_value = mock_client
|
||||
|
||||
with patch(
|
||||
"crewai_tools.tools.rag.rag_tool.build_embedder",
|
||||
return_value=mock_embedding_func,
|
||||
):
|
||||
config = {
|
||||
"embedding_model": {
|
||||
"provider": "azure",
|
||||
"config": {
|
||||
"model": "text-embedding-3-small",
|
||||
"api_key": "test-api-key",
|
||||
"api_base": "https://test.openai.azure.com/",
|
||||
"api_version": "2024-02-01",
|
||||
"api_type": "azure",
|
||||
"deployment_id": "test-deployment",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
# This should not raise a validation error about missing env vars
|
||||
tool = TXTSearchTool(config=config)
|
||||
|
||||
assert tool.adapter is not None
|
||||
assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||
assert tool.name == "Search a txt's content"
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
|
||||
def test_txt_search_tool_with_openai_config_without_env_vars(
|
||||
mock_create_client: Mock,
|
||||
) -> None:
|
||||
"""Test TXTSearchTool accepts OpenAI config without requiring env vars."""
|
||||
mock_embedding_func = MagicMock()
|
||||
mock_embedding_func.return_value = [[0.1] * 1536]
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||
mock_create_client.return_value = mock_client
|
||||
|
||||
with patch(
|
||||
"crewai_tools.tools.rag.rag_tool.build_embedder",
|
||||
return_value=mock_embedding_func,
|
||||
):
|
||||
config = {
|
||||
"embedding_model": {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"model": "text-embedding-3-small",
|
||||
"api_key": "sk-test123",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
tool = TXTSearchTool(config=config)
|
||||
|
||||
assert tool.adapter is not None
|
||||
assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||
|
||||
|
||||
@patch("crewai_tools.adapters.crewai_rag_adapter.create_client")
|
||||
def test_txt_search_tool_with_cohere_config(mock_create_client: Mock) -> None:
|
||||
"""Test TXTSearchTool with Cohere embedding provider."""
|
||||
mock_embedding_func = MagicMock()
|
||||
mock_embedding_func.return_value = [[0.1] * 1024]
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||
mock_create_client.return_value = mock_client
|
||||
|
||||
with patch(
|
||||
"crewai_tools.tools.rag.rag_tool.build_embedder",
|
||||
return_value=mock_embedding_func,
|
||||
):
|
||||
config = {
|
||||
"embedding_model": {
|
||||
"provider": "cohere",
|
||||
"config": {
|
||||
"model": "embed-english-v3.0",
|
||||
"api_key": "test-cohere-key",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
tool = TXTSearchTool(config=config)
|
||||
|
||||
assert tool.adapter is not None
|
||||
assert isinstance(tool.adapter, CrewAIRagAdapter)
|
||||
@@ -48,7 +48,7 @@ Repository = "https://github.com/crewAIInc/crewAI"
|
||||
|
||||
[project.optional-dependencies]
|
||||
tools = [
|
||||
"crewai-tools==1.6.1",
|
||||
"crewai-tools==1.5.0",
|
||||
]
|
||||
embeddings = [
|
||||
"tiktoken~=0.8.0"
|
||||
|
||||
@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
|
||||
|
||||
_suppress_pydantic_deprecation_warnings()
|
||||
|
||||
__version__ = "1.6.1"
|
||||
__version__ = "1.5.0"
|
||||
_telemetry_submitted = False
|
||||
|
||||
|
||||
|
||||
@@ -73,7 +73,6 @@ CLI_SETTINGS_KEYS = [
|
||||
"oauth2_audience",
|
||||
"oauth2_client_id",
|
||||
"oauth2_domain",
|
||||
"oauth2_extra",
|
||||
]
|
||||
|
||||
# Default values for CLI settings
|
||||
@@ -83,7 +82,6 @@ DEFAULT_CLI_SETTINGS = {
|
||||
"oauth2_audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
|
||||
"oauth2_client_id": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
|
||||
"oauth2_domain": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
|
||||
"oauth2_extra": {},
|
||||
}
|
||||
|
||||
# Readonly settings - cannot be set by the user
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]==1.6.1"
|
||||
"crewai[tools]==1.5.0"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]==1.6.1"
|
||||
"crewai[tools]==1.5.0"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -950,34 +950,15 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
def _handle_crew_planning(self) -> None:
|
||||
"""Handles the Crew planning."""
|
||||
import re
|
||||
|
||||
self._logger.log("info", "Planning the crew execution")
|
||||
result = CrewPlanner(
|
||||
tasks=self.tasks, planning_agent_llm=self.planning_llm
|
||||
)._handle_crew_planning()
|
||||
|
||||
plan_map: dict[int, str] = {}
|
||||
for step_plan in result.list_of_plans_per_task:
|
||||
match = re.search(r"Task Number (\d+)", step_plan.task, re.IGNORECASE)
|
||||
if match:
|
||||
task_number = int(match.group(1))
|
||||
plan_map[task_number] = step_plan.plan
|
||||
else:
|
||||
self._logger.log(
|
||||
"warning",
|
||||
f"Could not extract task number from plan task field: {step_plan.task}",
|
||||
)
|
||||
|
||||
for idx, task in enumerate(self.tasks):
|
||||
task_number = idx + 1 # Task numbers are 1-indexed
|
||||
if task_number in plan_map:
|
||||
task.description += plan_map[task_number]
|
||||
else:
|
||||
self._logger.log(
|
||||
"warning",
|
||||
f"No plan found for task {task_number}. Task description: {task.description}",
|
||||
)
|
||||
for task, step_plan in zip(
|
||||
self.tasks, result.list_of_plans_per_task, strict=False
|
||||
):
|
||||
task.description += step_plan.plan
|
||||
|
||||
def _store_execution_log(
|
||||
self,
|
||||
|
||||
@@ -64,7 +64,6 @@ class FlowFinishedEvent(FlowEvent):
|
||||
flow_name: str
|
||||
result: Any | None = None
|
||||
type: str = "flow_finished"
|
||||
state: dict[str, Any] | BaseModel
|
||||
|
||||
|
||||
class FlowPlotEvent(FlowEvent):
|
||||
|
||||
@@ -1008,7 +1008,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
type="flow_finished",
|
||||
flow_name=self.name or self.__class__.__name__,
|
||||
result=final_output,
|
||||
state=self._copy_and_serialize_state(),
|
||||
),
|
||||
)
|
||||
if future:
|
||||
@@ -1110,7 +1109,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
dumped_params = {f"_{i}": arg for i, arg in enumerate(args)} | (
|
||||
kwargs or {}
|
||||
)
|
||||
|
||||
future = crewai_event_bus.emit(
|
||||
self,
|
||||
MethodExecutionStartedEvent(
|
||||
@@ -1118,7 +1116,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
method_name=method_name,
|
||||
flow_name=self.name or self.__class__.__name__,
|
||||
params=dumped_params,
|
||||
state=self._copy_and_serialize_state(),
|
||||
state=self._copy_state(),
|
||||
),
|
||||
)
|
||||
if future:
|
||||
@@ -1136,14 +1134,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
)
|
||||
|
||||
self._completed_methods.add(method_name)
|
||||
|
||||
future = crewai_event_bus.emit(
|
||||
self,
|
||||
MethodExecutionFinishedEvent(
|
||||
type="method_execution_finished",
|
||||
method_name=method_name,
|
||||
flow_name=self.name or self.__class__.__name__,
|
||||
state=self._copy_and_serialize_state(),
|
||||
state=self._copy_state(),
|
||||
result=result,
|
||||
),
|
||||
)
|
||||
@@ -1165,16 +1162,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
self._event_futures.append(future)
|
||||
raise e
|
||||
|
||||
def _copy_and_serialize_state(self) -> dict[str, Any]:
|
||||
state_copy = self._copy_state()
|
||||
if isinstance(state_copy, BaseModel):
|
||||
try:
|
||||
return state_copy.model_dump(mode="json")
|
||||
except Exception:
|
||||
return state_copy.model_dump()
|
||||
else:
|
||||
return state_copy
|
||||
|
||||
async def _execute_listeners(
|
||||
self, trigger_method: FlowMethodName, result: Any
|
||||
) -> None:
|
||||
|
||||
@@ -17,7 +17,6 @@ from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from collections import defaultdict, deque
|
||||
from enum import Enum
|
||||
import inspect
|
||||
import textwrap
|
||||
from typing import TYPE_CHECKING, Any
|
||||
@@ -41,123 +40,11 @@ if TYPE_CHECKING:
|
||||
_printer = Printer()
|
||||
|
||||
|
||||
def _extract_string_literals_from_type_annotation(
|
||||
node: ast.expr,
|
||||
function_globals: dict[str, Any] | None = None,
|
||||
) -> list[str]:
|
||||
"""Extract string literals from a type annotation AST node.
|
||||
|
||||
Handles:
|
||||
- Literal["a", "b", "c"]
|
||||
- "a" | "b" | "c" (union of string literals)
|
||||
- Just "a" (single string constant annotation)
|
||||
- Enum types with string values (e.g., class MyEnum(str, Enum))
|
||||
|
||||
Args:
|
||||
node: The AST node representing a type annotation.
|
||||
function_globals: The globals dict from the function, used to resolve Enum types.
|
||||
|
||||
Returns:
|
||||
List of string literals found in the annotation.
|
||||
"""
|
||||
|
||||
strings: list[str] = []
|
||||
|
||||
if isinstance(node, ast.Constant) and isinstance(node.value, str):
|
||||
strings.append(node.value)
|
||||
|
||||
elif isinstance(node, ast.Name) and function_globals:
|
||||
enum_class = function_globals.get(node.id)
|
||||
if (
|
||||
enum_class is not None
|
||||
and isinstance(enum_class, type)
|
||||
and issubclass(enum_class, Enum)
|
||||
):
|
||||
strings.extend(
|
||||
member.value for member in enum_class if isinstance(member.value, str)
|
||||
)
|
||||
|
||||
elif isinstance(node, ast.Attribute) and function_globals:
|
||||
try:
|
||||
if isinstance(node.value, ast.Name):
|
||||
module = function_globals.get(node.value.id)
|
||||
if module is not None:
|
||||
enum_class = getattr(module, node.attr, None)
|
||||
if (
|
||||
enum_class is not None
|
||||
and isinstance(enum_class, type)
|
||||
and issubclass(enum_class, Enum)
|
||||
):
|
||||
strings.extend(
|
||||
member.value
|
||||
for member in enum_class
|
||||
if isinstance(member.value, str)
|
||||
)
|
||||
except (AttributeError, TypeError):
|
||||
pass
|
||||
|
||||
elif isinstance(node, ast.Subscript):
|
||||
is_literal = False
|
||||
if isinstance(node.value, ast.Name) and node.value.id == "Literal":
|
||||
is_literal = True
|
||||
elif isinstance(node.value, ast.Attribute) and node.value.attr == "Literal":
|
||||
is_literal = True
|
||||
|
||||
if is_literal:
|
||||
if isinstance(node.slice, ast.Tuple):
|
||||
strings.extend(
|
||||
elt.value
|
||||
for elt in node.slice.elts
|
||||
if isinstance(elt, ast.Constant) and isinstance(elt.value, str)
|
||||
)
|
||||
elif isinstance(node.slice, ast.Constant) and isinstance(
|
||||
node.slice.value, str
|
||||
):
|
||||
strings.append(node.slice.value)
|
||||
|
||||
elif isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
|
||||
strings.extend(
|
||||
_extract_string_literals_from_type_annotation(node.left, function_globals)
|
||||
)
|
||||
strings.extend(
|
||||
_extract_string_literals_from_type_annotation(node.right, function_globals)
|
||||
)
|
||||
|
||||
return strings
|
||||
|
||||
|
||||
def _unwrap_function(function: Any) -> Any:
|
||||
"""Unwrap a function to get the original function with correct globals.
|
||||
|
||||
Flow methods are wrapped by decorators like @router, @listen, etc.
|
||||
This function unwraps them to get the original function which has
|
||||
the correct __globals__ for resolving type annotations like Enums.
|
||||
|
||||
Args:
|
||||
function: The potentially wrapped function.
|
||||
|
||||
Returns:
|
||||
The unwrapped original function.
|
||||
"""
|
||||
if hasattr(function, "__func__"):
|
||||
function = function.__func__
|
||||
|
||||
if hasattr(function, "__wrapped__"):
|
||||
wrapped = function.__wrapped__
|
||||
if hasattr(wrapped, "unwrap"):
|
||||
return wrapped.unwrap()
|
||||
return wrapped
|
||||
|
||||
return function
|
||||
|
||||
|
||||
def get_possible_return_constants(function: Any) -> list[str] | None:
|
||||
"""Extract possible string return values from a function using AST parsing.
|
||||
|
||||
This function analyzes the source code of a router method to identify
|
||||
all possible string values it might return. It handles:
|
||||
- Return type annotations: -> Literal["a", "b"] or -> "a" | "b" | "c"
|
||||
- Enum type annotations: -> MyEnum (extracts string values from members)
|
||||
- Direct string literals: return "value"
|
||||
- Variable assignments: x = "value"; return x
|
||||
- Dictionary lookups: d = {"k": "v"}; return d[key]
|
||||
@@ -170,8 +57,6 @@ def get_possible_return_constants(function: Any) -> list[str] | None:
|
||||
Returns:
|
||||
List of possible string return values, or None if analysis fails.
|
||||
"""
|
||||
unwrapped = _unwrap_function(function)
|
||||
|
||||
try:
|
||||
source = inspect.getsource(function)
|
||||
except OSError:
|
||||
@@ -212,17 +97,6 @@ def get_possible_return_constants(function: Any) -> list[str] | None:
|
||||
return None
|
||||
|
||||
return_values: set[str] = set()
|
||||
|
||||
function_globals = getattr(unwrapped, "__globals__", None)
|
||||
|
||||
for node in ast.walk(code_ast):
|
||||
if isinstance(node, ast.FunctionDef):
|
||||
if node.returns:
|
||||
annotation_values = _extract_string_literals_from_type_annotation(
|
||||
node.returns, function_globals
|
||||
)
|
||||
return_values.update(annotation_values)
|
||||
break # Only process the first function definition
|
||||
dict_definitions: dict[str, list[str]] = {}
|
||||
variable_values: dict[str, list[str]] = {}
|
||||
state_attribute_values: dict[str, list[str]] = {}
|
||||
|
||||
@@ -3,13 +3,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
import inspect
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.flow.constants import AND_CONDITION, OR_CONDITION
|
||||
from crewai.flow.flow_wrappers import FlowCondition
|
||||
from crewai.flow.types import FlowMethodName
|
||||
from crewai.flow.types import FlowMethodName, FlowRouteName
|
||||
from crewai.flow.utils import (
|
||||
is_flow_condition_dict,
|
||||
is_simple_flow_condition,
|
||||
@@ -18,9 +18,6 @@ from crewai.flow.visualization.schema import extract_method_signature
|
||||
from crewai.flow.visualization.types import FlowStructure, NodeMetadata, StructureEdge
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.flow import Flow
|
||||
|
||||
@@ -349,44 +346,35 @@ def build_flow_structure(flow: Flow[Any]) -> FlowStructure:
|
||||
if trigger_method in nodes
|
||||
)
|
||||
|
||||
all_string_triggers: set[str] = set()
|
||||
for condition_data in flow._listeners.values():
|
||||
if is_simple_flow_condition(condition_data):
|
||||
_, methods = condition_data
|
||||
for m in methods:
|
||||
if str(m) not in nodes: # It's a string trigger, not a method name
|
||||
all_string_triggers.add(str(m))
|
||||
elif is_flow_condition_dict(condition_data):
|
||||
for trigger in _extract_direct_or_triggers(condition_data):
|
||||
if trigger not in nodes:
|
||||
all_string_triggers.add(trigger)
|
||||
|
||||
all_router_outputs: set[str] = set()
|
||||
for router_method_name in router_methods:
|
||||
if router_method_name not in flow._router_paths:
|
||||
flow._router_paths[FlowMethodName(router_method_name)] = []
|
||||
|
||||
current_paths = flow._router_paths.get(FlowMethodName(router_method_name), [])
|
||||
if current_paths and router_method_name in nodes:
|
||||
nodes[router_method_name]["router_paths"] = [str(p) for p in current_paths]
|
||||
all_router_outputs.update(str(p) for p in current_paths)
|
||||
|
||||
if not current_paths:
|
||||
logger.warning(
|
||||
f"Could not determine return paths for router '{router_method_name}'. "
|
||||
f"Add a return type annotation like "
|
||||
f"'-> Literal[\"path1\", \"path2\"]' or '-> YourEnum' "
|
||||
f"to enable proper flow visualization."
|
||||
)
|
||||
|
||||
orphaned_triggers = all_string_triggers - all_router_outputs
|
||||
if orphaned_triggers:
|
||||
logger.error(
|
||||
f"Found listeners waiting for triggers {orphaned_triggers} "
|
||||
f"but no router outputs these values explicitly. "
|
||||
f"If your router returns a non-static value, check that your router has proper return type annotations."
|
||||
inferred_paths: Iterable[FlowMethodName | FlowRouteName] = set(
|
||||
flow._router_paths.get(FlowMethodName(router_method_name), [])
|
||||
)
|
||||
|
||||
for condition_data in flow._listeners.values():
|
||||
trigger_strings: list[str] = []
|
||||
|
||||
if is_simple_flow_condition(condition_data):
|
||||
_, methods = condition_data
|
||||
trigger_strings = [str(m) for m in methods]
|
||||
elif is_flow_condition_dict(condition_data):
|
||||
trigger_strings = _extract_direct_or_triggers(condition_data)
|
||||
|
||||
for trigger_str in trigger_strings:
|
||||
if trigger_str not in nodes:
|
||||
# This is likely a router path output
|
||||
inferred_paths.add(trigger_str) # type: ignore[attr-defined]
|
||||
|
||||
if inferred_paths:
|
||||
flow._router_paths[FlowMethodName(router_method_name)] = list(
|
||||
inferred_paths # type: ignore[arg-type]
|
||||
)
|
||||
if router_method_name in nodes:
|
||||
nodes[router_method_name]["router_paths"] = list(inferred_paths)
|
||||
|
||||
for router_method_name in router_methods:
|
||||
if router_method_name not in flow._router_paths:
|
||||
continue
|
||||
@@ -395,9 +383,6 @@ def build_flow_structure(flow: Flow[Any]) -> FlowStructure:
|
||||
|
||||
for path in router_paths:
|
||||
for listener_name, condition_data in flow._listeners.items():
|
||||
if listener_name == router_method_name:
|
||||
continue
|
||||
|
||||
trigger_strings_from_cond: list[str] = []
|
||||
|
||||
if is_simple_flow_condition(condition_data):
|
||||
|
||||
@@ -406,100 +406,46 @@ class LLM(BaseLLM):
|
||||
instance.is_litellm = True
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
def _matches_provider_pattern(cls, model: str, provider: str) -> bool:
|
||||
"""Check if a model name matches provider-specific patterns.
|
||||
|
||||
This allows supporting models that aren't in the hardcoded constants list,
|
||||
including "latest" versions and new models that follow provider naming conventions.
|
||||
|
||||
Args:
|
||||
model: The model name to check
|
||||
provider: The provider to check against (canonical name)
|
||||
|
||||
Returns:
|
||||
True if the model matches the provider's naming pattern, False otherwise
|
||||
"""
|
||||
model_lower = model.lower()
|
||||
|
||||
if provider == "openai":
|
||||
return any(
|
||||
model_lower.startswith(prefix)
|
||||
for prefix in ["gpt-", "o1", "o3", "o4", "whisper-"]
|
||||
)
|
||||
|
||||
if provider == "anthropic" or provider == "claude":
|
||||
return any(
|
||||
model_lower.startswith(prefix) for prefix in ["claude-", "anthropic."]
|
||||
)
|
||||
|
||||
if provider == "gemini" or provider == "google":
|
||||
return any(
|
||||
model_lower.startswith(prefix)
|
||||
for prefix in ["gemini-", "gemma-", "learnlm-"]
|
||||
)
|
||||
|
||||
if provider == "bedrock":
|
||||
return "." in model_lower
|
||||
|
||||
if provider == "azure":
|
||||
return any(
|
||||
model_lower.startswith(prefix)
|
||||
for prefix in ["gpt-", "gpt-35-", "o1", "o3", "o4", "azure-"]
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _validate_model_in_constants(cls, model: str, provider: str) -> bool:
|
||||
"""Validate if a model name exists in the provider's constants or matches provider patterns.
|
||||
|
||||
This method first checks the hardcoded constants list for known models.
|
||||
If not found, it falls back to pattern matching to support new models,
|
||||
"latest" versions, and models that follow provider naming conventions.
|
||||
"""Validate if a model name exists in the provider's constants.
|
||||
|
||||
Args:
|
||||
model: The model name to validate
|
||||
provider: The provider to check against (canonical name)
|
||||
|
||||
Returns:
|
||||
True if the model exists in constants or matches provider patterns, False otherwise
|
||||
True if the model exists in the provider's constants, False otherwise
|
||||
"""
|
||||
if provider == "openai" and model in OPENAI_MODELS:
|
||||
return True
|
||||
if provider == "openai":
|
||||
return model in OPENAI_MODELS
|
||||
|
||||
if (
|
||||
provider == "anthropic" or provider == "claude"
|
||||
) and model in ANTHROPIC_MODELS:
|
||||
return True
|
||||
if provider == "anthropic" or provider == "claude":
|
||||
return model in ANTHROPIC_MODELS
|
||||
|
||||
if (provider == "gemini" or provider == "google") and model in GEMINI_MODELS:
|
||||
return True
|
||||
if provider == "gemini":
|
||||
return model in GEMINI_MODELS
|
||||
|
||||
if provider == "bedrock" and model in BEDROCK_MODELS:
|
||||
return True
|
||||
if provider == "bedrock":
|
||||
return model in BEDROCK_MODELS
|
||||
|
||||
if provider == "azure":
|
||||
# azure does not provide a list of available models, determine a better way to handle this
|
||||
return True
|
||||
|
||||
# Fallback to pattern matching for models not in constants
|
||||
return cls._matches_provider_pattern(model, provider)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _infer_provider_from_model(cls, model: str) -> str:
|
||||
"""Infer the provider from the model name.
|
||||
|
||||
This method first checks the hardcoded constants list for known models.
|
||||
If not found, it uses pattern matching to infer the provider from model name patterns.
|
||||
This allows supporting new models and "latest" versions without hardcoding.
|
||||
|
||||
Args:
|
||||
model: The model name without provider prefix
|
||||
|
||||
Returns:
|
||||
The inferred provider name, defaults to "openai"
|
||||
"""
|
||||
|
||||
if model in OPENAI_MODELS:
|
||||
return "openai"
|
||||
|
||||
@@ -1753,14 +1699,12 @@ class LLM(BaseLLM):
|
||||
max_tokens=self.max_tokens,
|
||||
presence_penalty=self.presence_penalty,
|
||||
frequency_penalty=self.frequency_penalty,
|
||||
logit_bias=(
|
||||
copy.deepcopy(self.logit_bias, memo) if self.logit_bias else None
|
||||
),
|
||||
response_format=(
|
||||
copy.deepcopy(self.response_format, memo)
|
||||
if self.response_format
|
||||
else None
|
||||
),
|
||||
logit_bias=copy.deepcopy(self.logit_bias, memo)
|
||||
if self.logit_bias
|
||||
else None,
|
||||
response_format=copy.deepcopy(self.response_format, memo)
|
||||
if self.response_format
|
||||
else None,
|
||||
seed=self.seed,
|
||||
logprobs=self.logprobs,
|
||||
top_logprobs=self.top_logprobs,
|
||||
|
||||
@@ -182,8 +182,6 @@ OPENAI_MODELS: list[OpenAIModels] = [
|
||||
|
||||
|
||||
AnthropicModels: TypeAlias = Literal[
|
||||
"claude-opus-4-5-20251101",
|
||||
"claude-opus-4-5",
|
||||
"claude-3-7-sonnet-latest",
|
||||
"claude-3-7-sonnet-20250219",
|
||||
"claude-3-5-haiku-latest",
|
||||
@@ -210,8 +208,6 @@ AnthropicModels: TypeAlias = Literal[
|
||||
"claude-3-haiku-20240307",
|
||||
]
|
||||
ANTHROPIC_MODELS: list[AnthropicModels] = [
|
||||
"claude-opus-4-5-20251101",
|
||||
"claude-opus-4-5",
|
||||
"claude-3-7-sonnet-latest",
|
||||
"claude-3-7-sonnet-20250219",
|
||||
"claude-3-5-haiku-latest",
|
||||
@@ -456,7 +452,6 @@ BedrockModels: TypeAlias = Literal[
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0:28k",
|
||||
"anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||
"anthropic.claude-instant-v1:2:100k",
|
||||
"anthropic.claude-opus-4-5-20251101-v1:0",
|
||||
"anthropic.claude-opus-4-1-20250805-v1:0",
|
||||
"anthropic.claude-opus-4-20250514-v1:0",
|
||||
"anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
@@ -529,7 +524,6 @@ BEDROCK_MODELS: list[BedrockModels] = [
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0:28k",
|
||||
"anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||
"anthropic.claude-instant-v1:2:100k",
|
||||
"anthropic.claude-opus-4-5-20251101-v1:0",
|
||||
"anthropic.claude-opus-4-1-20250805-v1:0",
|
||||
"anthropic.claude-opus-4-20250514-v1:0",
|
||||
"anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
|
||||
@@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, Any
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||
from crewai.utilities.converter import generate_model_description
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededError,
|
||||
)
|
||||
@@ -27,7 +26,6 @@ try:
|
||||
from azure.ai.inference.models import (
|
||||
ChatCompletions,
|
||||
ChatCompletionsToolCall,
|
||||
JsonSchemaFormat,
|
||||
StreamingChatCompletionsUpdate,
|
||||
)
|
||||
from azure.core.credentials import (
|
||||
@@ -280,16 +278,13 @@ class AzureCompletion(BaseLLM):
|
||||
}
|
||||
|
||||
if response_model and self.is_openai_model:
|
||||
model_description = generate_model_description(response_model)
|
||||
json_schema_info = model_description["json_schema"]
|
||||
json_schema_name = json_schema_info["name"]
|
||||
|
||||
params["response_format"] = JsonSchemaFormat(
|
||||
name=json_schema_name,
|
||||
schema=json_schema_info["schema"],
|
||||
description=f"Schema for {json_schema_name}",
|
||||
strict=json_schema_info["strict"],
|
||||
)
|
||||
params["response_format"] = {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": response_model.__name__,
|
||||
"schema": response_model.model_json_schema(),
|
||||
},
|
||||
}
|
||||
|
||||
# Only include model parameter for non-Azure OpenAI endpoints
|
||||
# Azure OpenAI endpoints have the deployment name in the URL
|
||||
@@ -315,14 +310,6 @@ class AzureCompletion(BaseLLM):
|
||||
params["tools"] = self._convert_tools_for_interference(tools)
|
||||
params["tool_choice"] = "auto"
|
||||
|
||||
additional_params = self.additional_params
|
||||
additional_drop_params = additional_params.get("additional_drop_params")
|
||||
drop_params = additional_params.get("drop_params")
|
||||
|
||||
if drop_params and isinstance(additional_drop_params, list):
|
||||
for drop_param in additional_drop_params:
|
||||
params.pop(drop_param, None)
|
||||
|
||||
return params
|
||||
|
||||
def _convert_tools_for_interference(
|
||||
|
||||
@@ -17,7 +17,6 @@ from crewai.events.types.llm_events import LLMCallType
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.llms.hooks.transport import HTTPTransport
|
||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||
from crewai.utilities.converter import generate_model_description
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededError,
|
||||
)
|
||||
@@ -246,16 +245,6 @@ class OpenAICompletion(BaseLLM):
|
||||
if self.is_o1_model and self.reasoning_effort:
|
||||
params["reasoning_effort"] = self.reasoning_effort
|
||||
|
||||
if self.response_format is not None:
|
||||
if isinstance(self.response_format, type) and issubclass(
|
||||
self.response_format, BaseModel
|
||||
):
|
||||
params["response_format"] = generate_model_description(
|
||||
self.response_format
|
||||
)
|
||||
elif isinstance(self.response_format, dict):
|
||||
params["response_format"] = self.response_format
|
||||
|
||||
if tools:
|
||||
params["tools"] = self._convert_tools_for_interference(tools)
|
||||
params["tool_choice"] = "auto"
|
||||
@@ -314,11 +303,8 @@ class OpenAICompletion(BaseLLM):
|
||||
"""Handle non-streaming chat completion."""
|
||||
try:
|
||||
if response_model:
|
||||
parse_params = {
|
||||
k: v for k, v in params.items() if k != "response_format"
|
||||
}
|
||||
parsed_response = self.client.beta.chat.completions.parse(
|
||||
**parse_params,
|
||||
**params,
|
||||
response_format=response_model,
|
||||
)
|
||||
math_reasoning = parsed_response.choices[0].message
|
||||
|
||||
@@ -66,6 +66,7 @@ class SSETransport(BaseTransport):
|
||||
self._transport_context = sse_client(
|
||||
self.url,
|
||||
headers=self.headers if self.headers else None,
|
||||
terminate_on_close=True,
|
||||
)
|
||||
|
||||
read, write = await self._transport_context.__aenter__()
|
||||
|
||||
@@ -16,7 +16,6 @@ from crewai.utilities.paths import db_storage_path
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.crew import Crew
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.types import ProviderSpec
|
||||
@@ -33,16 +32,16 @@ class RAGStorage(BaseRAGStorage):
|
||||
self,
|
||||
type: str,
|
||||
allow_reset: bool = True,
|
||||
embedder_config: ProviderSpec | BaseEmbeddingsProvider[Any] | None = None,
|
||||
crew: Crew | None = None,
|
||||
embedder_config: ProviderSpec | BaseEmbeddingsProvider | None = None,
|
||||
crew: Any = None,
|
||||
path: str | None = None,
|
||||
) -> None:
|
||||
super().__init__(type, allow_reset, embedder_config, crew)
|
||||
crew_agents = crew.agents if crew else []
|
||||
sanitized_roles = [self._sanitize_role(agent.role) for agent in crew_agents]
|
||||
agents_str = "_".join(sanitized_roles)
|
||||
self.agents = agents_str
|
||||
self.storage_file_name = self._build_storage_file_name(type, agents_str)
|
||||
agents = crew.agents if crew else []
|
||||
agents = [self._sanitize_role(agent.role) for agent in agents]
|
||||
agents = "_".join(agents)
|
||||
self.agents = agents
|
||||
self.storage_file_name = self._build_storage_file_name(type, agents)
|
||||
|
||||
self.type = type
|
||||
self._client: BaseClient | None = None
|
||||
@@ -97,10 +96,6 @@ class RAGStorage(BaseRAGStorage):
|
||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||
)
|
||||
)
|
||||
|
||||
if self.path:
|
||||
config.settings.persist_directory = self.path
|
||||
|
||||
self._client = create_client(config)
|
||||
|
||||
def _get_client(self) -> BaseClient:
|
||||
|
||||
@@ -2,10 +2,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, overload
|
||||
|
||||
from crewai.project.utils import memoize
|
||||
@@ -158,23 +156,6 @@ def cache_handler(meth: Callable[P, R]) -> CacheHandlerMethod[P, R]:
|
||||
return CacheHandlerMethod(memoize(meth))
|
||||
|
||||
|
||||
def _call_method(method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
|
||||
"""Call a method, awaiting it if async and running in an event loop."""
|
||||
result = method(*args, **kwargs)
|
||||
if inspect.iscoroutine(result):
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
if loop and loop.is_running():
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||
return pool.submit(asyncio.run, result).result()
|
||||
return asyncio.run(result)
|
||||
return result
|
||||
|
||||
|
||||
@overload
|
||||
def crew(
|
||||
meth: Callable[Concatenate[SelfT, P], Crew],
|
||||
@@ -217,7 +198,7 @@ def crew(
|
||||
|
||||
# Instantiate tasks in order
|
||||
for _, task_method in tasks:
|
||||
task_instance = _call_method(task_method, self)
|
||||
task_instance = task_method(self)
|
||||
instantiated_tasks.append(task_instance)
|
||||
agent_instance = getattr(task_instance, "agent", None)
|
||||
if agent_instance and agent_instance.role not in agent_roles:
|
||||
@@ -226,7 +207,7 @@ def crew(
|
||||
|
||||
# Instantiate agents not included by tasks
|
||||
for _, agent_method in agents:
|
||||
agent_instance = _call_method(agent_method, self)
|
||||
agent_instance = agent_method(self)
|
||||
if agent_instance.role not in agent_roles:
|
||||
instantiated_agents.append(agent_instance)
|
||||
agent_roles.add(agent_instance.role)
|
||||
@@ -234,7 +215,7 @@ def crew(
|
||||
self.agents = instantiated_agents
|
||||
self.tasks = instantiated_tasks
|
||||
|
||||
crew_instance: Crew = _call_method(meth, self, *args, **kwargs)
|
||||
crew_instance = meth(self, *args, **kwargs)
|
||||
|
||||
def callback_wrapper(
|
||||
hook: Callable[Concatenate[CrewInstance, P2], R2], instance: CrewInstance
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
"""Utility functions for the crewai project module."""
|
||||
|
||||
from collections.abc import Callable, Coroutine
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
import inspect
|
||||
from typing import Any, ParamSpec, TypeVar, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -38,8 +37,8 @@ def _make_hashable(arg: Any) -> Any:
|
||||
def memoize(meth: Callable[P, R]) -> Callable[P, R]:
|
||||
"""Memoize a method by caching its results based on arguments.
|
||||
|
||||
Handles both sync and async methods. Pydantic BaseModel instances are
|
||||
converted to JSON strings before hashing for cache lookup.
|
||||
Handles Pydantic BaseModel instances by converting them to JSON strings
|
||||
before hashing for cache lookup.
|
||||
|
||||
Args:
|
||||
meth: The method to memoize.
|
||||
@@ -47,16 +46,18 @@ def memoize(meth: Callable[P, R]) -> Callable[P, R]:
|
||||
Returns:
|
||||
A memoized version of the method that caches results.
|
||||
"""
|
||||
if inspect.iscoroutinefunction(meth):
|
||||
return cast(Callable[P, R], _memoize_async(meth))
|
||||
return _memoize_sync(meth)
|
||||
|
||||
|
||||
def _memoize_sync(meth: Callable[P, R]) -> Callable[P, R]:
|
||||
"""Memoize a synchronous method."""
|
||||
|
||||
@wraps(meth)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""Wrapper that converts arguments to hashable form before caching.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments to the memoized method.
|
||||
**kwargs: Keyword arguments to the memoized method.
|
||||
|
||||
Returns:
|
||||
The result of the memoized method call.
|
||||
"""
|
||||
hashable_args = tuple(_make_hashable(arg) for arg in args)
|
||||
hashable_kwargs = tuple(
|
||||
sorted((k, _make_hashable(v)) for k, v in kwargs.items())
|
||||
@@ -72,27 +73,3 @@ def _memoize_sync(meth: Callable[P, R]) -> Callable[P, R]:
|
||||
return result
|
||||
|
||||
return cast(Callable[P, R], wrapper)
|
||||
|
||||
|
||||
def _memoize_async(
|
||||
meth: Callable[P, Coroutine[Any, Any, R]],
|
||||
) -> Callable[P, Coroutine[Any, Any, R]]:
|
||||
"""Memoize an async method."""
|
||||
|
||||
@wraps(meth)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
hashable_args = tuple(_make_hashable(arg) for arg in args)
|
||||
hashable_kwargs = tuple(
|
||||
sorted((k, _make_hashable(v)) for k, v in kwargs.items())
|
||||
)
|
||||
cache_key = str((hashable_args, hashable_kwargs))
|
||||
|
||||
cached_result: R | None = cache.read(tool=meth.__name__, input=cache_key)
|
||||
if cached_result is not None:
|
||||
return cached_result
|
||||
|
||||
result = await meth(*args, **kwargs)
|
||||
cache.add(tool=meth.__name__, input=cache_key, output=result)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -2,10 +2,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
import inspect
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@@ -134,22 +132,6 @@ class CrewClass(Protocol):
|
||||
crew: Callable[..., Crew]
|
||||
|
||||
|
||||
def _resolve_result(result: Any) -> Any:
|
||||
"""Resolve a potentially async result to its value."""
|
||||
if inspect.iscoroutine(result):
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
if loop and loop.is_running():
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||
return pool.submit(asyncio.run, result).result()
|
||||
return asyncio.run(result)
|
||||
return result
|
||||
|
||||
|
||||
class DecoratedMethod(Generic[P, R]):
|
||||
"""Base wrapper for methods with decorator metadata.
|
||||
|
||||
@@ -180,12 +162,7 @@ class DecoratedMethod(Generic[P, R]):
|
||||
"""
|
||||
if obj is None:
|
||||
return self
|
||||
inner = partial(self._meth, obj)
|
||||
|
||||
def _bound(*args: Any, **kwargs: Any) -> R:
|
||||
result: R = _resolve_result(inner(*args, **kwargs)) # type: ignore[call-arg]
|
||||
return result
|
||||
|
||||
bound = partial(self._meth, obj)
|
||||
for attr in (
|
||||
"is_agent",
|
||||
"is_llm",
|
||||
@@ -197,8 +174,8 @@ class DecoratedMethod(Generic[P, R]):
|
||||
"is_crew",
|
||||
):
|
||||
if hasattr(self, attr):
|
||||
setattr(_bound, attr, getattr(self, attr))
|
||||
return _bound
|
||||
setattr(bound, attr, getattr(self, attr))
|
||||
return bound
|
||||
|
||||
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""Call the wrapped method.
|
||||
@@ -259,7 +236,6 @@ class BoundTaskMethod(Generic[TaskResultT]):
|
||||
The task result with name ensured.
|
||||
"""
|
||||
result = self._task_method.unwrap()(self._obj, *args, **kwargs)
|
||||
result = _resolve_result(result)
|
||||
return self._task_method.ensure_task_name(result)
|
||||
|
||||
|
||||
@@ -316,9 +292,7 @@ class TaskMethod(Generic[P, TaskResultT]):
|
||||
Returns:
|
||||
The task instance with name set if not already provided.
|
||||
"""
|
||||
result = self._meth(*args, **kwargs)
|
||||
result = _resolve_result(result)
|
||||
return self.ensure_task_name(result)
|
||||
return self.ensure_task_name(self._meth(*args, **kwargs))
|
||||
|
||||
def unwrap(self) -> Callable[P, TaskResultT]:
|
||||
"""Get the original unwrapped method.
|
||||
|
||||
@@ -91,7 +91,6 @@ PROVIDER_PATHS = {
|
||||
"cohere": "crewai.rag.embeddings.providers.cohere.cohere_provider.CohereProvider",
|
||||
"custom": "crewai.rag.embeddings.providers.custom.custom_provider.CustomProvider",
|
||||
"google-generativeai": "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider",
|
||||
"google": "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider",
|
||||
"google-vertex": "crewai.rag.embeddings.providers.google.vertex.VertexAIProvider",
|
||||
"huggingface": "crewai.rag.embeddings.providers.huggingface.huggingface_provider.HuggingFaceProvider",
|
||||
"instructor": "crewai.rag.embeddings.providers.instructor.instructor_provider.InstructorProvider",
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any
|
||||
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
|
||||
AmazonBedrockEmbeddingFunction,
|
||||
)
|
||||
from pydantic import AliasChoices, Field
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -21,7 +21,7 @@ def create_aws_session() -> Any:
|
||||
ValueError: If AWS session creation fails
|
||||
"""
|
||||
try:
|
||||
import boto3
|
||||
import boto3 # type: ignore[import]
|
||||
|
||||
return boto3.Session()
|
||||
except ImportError as e:
|
||||
@@ -46,12 +46,7 @@ class BedrockProvider(BaseEmbeddingsProvider[AmazonBedrockEmbeddingFunction]):
|
||||
model_name: str = Field(
|
||||
default="amazon.titan-embed-text-v1",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_BEDROCK_MODEL_NAME",
|
||||
"BEDROCK_MODEL_NAME",
|
||||
"AWS_BEDROCK_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
validation_alias="EMBEDDINGS_BEDROCK_MODEL_NAME",
|
||||
)
|
||||
session: Any = Field(
|
||||
default_factory=create_aws_session, description="AWS session object"
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.cohere_embedding_function import (
|
||||
CohereEmbeddingFunction,
|
||||
)
|
||||
from pydantic import AliasChoices, Field
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -15,14 +15,10 @@ class CohereProvider(BaseEmbeddingsProvider[CohereEmbeddingFunction]):
|
||||
default=CohereEmbeddingFunction, description="Cohere embedding function class"
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Cohere API key",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_COHERE_API_KEY", "COHERE_API_KEY"),
|
||||
description="Cohere API key", validation_alias="EMBEDDINGS_COHERE_API_KEY"
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="large",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_COHERE_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
validation_alias="EMBEDDINGS_COHERE_MODEL_NAME",
|
||||
)
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
"""Google Generative AI embeddings provider."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GoogleGenerativeAiEmbeddingFunction,
|
||||
)
|
||||
from pydantic import AliasChoices, Field
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -17,27 +15,16 @@ class GenerativeAiProvider(BaseEmbeddingsProvider[GoogleGenerativeAiEmbeddingFun
|
||||
default=GoogleGenerativeAiEmbeddingFunction,
|
||||
description="Google Generative AI embedding function class",
|
||||
)
|
||||
model_name: Literal[
|
||||
"gemini-embedding-001", "text-embedding-005", "text-multilingual-embedding-002"
|
||||
] = Field(
|
||||
default="gemini-embedding-001",
|
||||
model_name: str = Field(
|
||||
default="models/embedding-001",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_GENERATIVE_AI_MODEL_NAME", "model"
|
||||
),
|
||||
validation_alias="EMBEDDINGS_GOOGLE_GENERATIVE_AI_MODEL_NAME",
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Google API key",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_API_KEY", "GOOGLE_API_KEY", "GEMINI_API_KEY"
|
||||
),
|
||||
description="Google API key", validation_alias="EMBEDDINGS_GOOGLE_API_KEY"
|
||||
)
|
||||
task_type: str = Field(
|
||||
default="RETRIEVAL_DOCUMENT",
|
||||
description="Task type for embeddings",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_GENERATIVE_AI_TASK_TYPE",
|
||||
"GOOGLE_GENERATIVE_AI_TASK_TYPE",
|
||||
"GEMINI_TASK_TYPE",
|
||||
),
|
||||
validation_alias="EMBEDDINGS_GOOGLE_GENERATIVE_AI_TASK_TYPE",
|
||||
)
|
||||
|
||||
@@ -6,23 +6,10 @@ from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class GenerativeAiProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for Google Generative AI provider.
|
||||
|
||||
Attributes:
|
||||
api_key: Google API key for authentication.
|
||||
model_name: Embedding model name.
|
||||
task_type: Task type for embeddings. Default is "RETRIEVAL_DOCUMENT".
|
||||
"""
|
||||
"""Configuration for Google Generative AI provider."""
|
||||
|
||||
api_key: str
|
||||
model_name: Annotated[
|
||||
Literal[
|
||||
"gemini-embedding-001",
|
||||
"text-embedding-005",
|
||||
"text-multilingual-embedding-002",
|
||||
],
|
||||
"gemini-embedding-001",
|
||||
]
|
||||
model_name: Annotated[str, "models/embedding-001"]
|
||||
task_type: Annotated[str, "RETRIEVAL_DOCUMENT"]
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GoogleVertexEmbeddingFunction,
|
||||
)
|
||||
from pydantic import AliasChoices, Field
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -18,29 +18,18 @@ class VertexAIProvider(BaseEmbeddingsProvider[GoogleVertexEmbeddingFunction]):
|
||||
model_name: str = Field(
|
||||
default="textembedding-gecko",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME",
|
||||
"GOOGLE_VERTEX_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
validation_alias="EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME",
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Google API key",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_CLOUD_API_KEY", "GOOGLE_CLOUD_API_KEY"
|
||||
),
|
||||
description="Google API key", validation_alias="EMBEDDINGS_GOOGLE_CLOUD_API_KEY"
|
||||
)
|
||||
project_id: str = Field(
|
||||
default="cloud-large-language-models",
|
||||
description="GCP project ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_CLOUD_PROJECT", "GOOGLE_CLOUD_PROJECT"
|
||||
),
|
||||
validation_alias="EMBEDDINGS_GOOGLE_CLOUD_PROJECT",
|
||||
)
|
||||
region: str = Field(
|
||||
default="us-central1",
|
||||
description="GCP region",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_CLOUD_REGION", "GOOGLE_CLOUD_REGION"
|
||||
),
|
||||
validation_alias="EMBEDDINGS_GOOGLE_CLOUD_REGION",
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||
HuggingFaceEmbeddingServer,
|
||||
)
|
||||
from pydantic import AliasChoices, Field
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -16,6 +16,5 @@ class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingServer]):
|
||||
description="HuggingFace embedding function class",
|
||||
)
|
||||
url: str = Field(
|
||||
description="HuggingFace API URL",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_HUGGINGFACE_URL", "HUGGINGFACE_URL"),
|
||||
description="HuggingFace API URL", validation_alias="EMBEDDINGS_HUGGINGFACE_URL"
|
||||
)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import AliasChoices, Field, model_validator
|
||||
from pydantic import Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
@@ -21,10 +21,7 @@ class WatsonXProvider(BaseEmbeddingsProvider[WatsonXEmbeddingFunction]):
|
||||
default=WatsonXEmbeddingFunction, description="WatsonX embedding function class"
|
||||
)
|
||||
model_id: str = Field(
|
||||
description="WatsonX model ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_MODEL_ID", "WATSONX_MODEL_ID"
|
||||
),
|
||||
description="WatsonX model ID", validation_alias="EMBEDDINGS_WATSONX_MODEL_ID"
|
||||
)
|
||||
params: dict[str, str | dict[str, str]] | None = Field(
|
||||
default=None, description="Additional parameters"
|
||||
@@ -33,143 +30,109 @@ class WatsonXProvider(BaseEmbeddingsProvider[WatsonXEmbeddingFunction]):
|
||||
project_id: str | None = Field(
|
||||
default=None,
|
||||
description="WatsonX project ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_PROJECT_ID", "WATSONX_PROJECT_ID"
|
||||
),
|
||||
validation_alias="EMBEDDINGS_WATSONX_PROJECT_ID",
|
||||
)
|
||||
space_id: str | None = Field(
|
||||
default=None,
|
||||
description="WatsonX space ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_SPACE_ID", "WATSONX_SPACE_ID"
|
||||
),
|
||||
validation_alias="EMBEDDINGS_WATSONX_SPACE_ID",
|
||||
)
|
||||
api_client: Any | None = Field(default=None, description="WatsonX API client")
|
||||
verify: bool | str | None = Field(
|
||||
default=None,
|
||||
description="SSL verification",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_VERIFY", "WATSONX_VERIFY"),
|
||||
validation_alias="EMBEDDINGS_WATSONX_VERIFY",
|
||||
)
|
||||
persistent_connection: bool = Field(
|
||||
default=True,
|
||||
description="Use persistent connection",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_PERSISTENT_CONNECTION", "WATSONX_PERSISTENT_CONNECTION"
|
||||
),
|
||||
validation_alias="EMBEDDINGS_WATSONX_PERSISTENT_CONNECTION",
|
||||
)
|
||||
batch_size: int = Field(
|
||||
default=100,
|
||||
description="Batch size for processing",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_BATCH_SIZE", "WATSONX_BATCH_SIZE"
|
||||
),
|
||||
validation_alias="EMBEDDINGS_WATSONX_BATCH_SIZE",
|
||||
)
|
||||
concurrency_limit: int = Field(
|
||||
default=10,
|
||||
description="Concurrency limit",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_CONCURRENCY_LIMIT", "WATSONX_CONCURRENCY_LIMIT"
|
||||
),
|
||||
validation_alias="EMBEDDINGS_WATSONX_CONCURRENCY_LIMIT",
|
||||
)
|
||||
max_retries: int | None = Field(
|
||||
default=None,
|
||||
description="Maximum retries",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_MAX_RETRIES", "WATSONX_MAX_RETRIES"
|
||||
),
|
||||
validation_alias="EMBEDDINGS_WATSONX_MAX_RETRIES",
|
||||
)
|
||||
delay_time: float | None = Field(
|
||||
default=None,
|
||||
description="Delay time between retries",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_DELAY_TIME", "WATSONX_DELAY_TIME"
|
||||
),
|
||||
validation_alias="EMBEDDINGS_WATSONX_DELAY_TIME",
|
||||
)
|
||||
retry_status_codes: list[int] | None = Field(
|
||||
default=None, description="HTTP status codes to retry on"
|
||||
)
|
||||
url: str = Field(
|
||||
description="WatsonX API URL",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_URL", "WATSONX_URL"),
|
||||
description="WatsonX API URL", validation_alias="EMBEDDINGS_WATSONX_URL"
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="WatsonX API key",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_API_KEY", "WATSONX_API_KEY"),
|
||||
description="WatsonX API key", validation_alias="EMBEDDINGS_WATSONX_API_KEY"
|
||||
)
|
||||
name: str | None = Field(
|
||||
default=None,
|
||||
description="Service name",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_NAME", "WATSONX_NAME"),
|
||||
validation_alias="EMBEDDINGS_WATSONX_NAME",
|
||||
)
|
||||
iam_serviceid_crn: str | None = Field(
|
||||
default=None,
|
||||
description="IAM service ID CRN",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_IAM_SERVICEID_CRN", "WATSONX_IAM_SERVICEID_CRN"
|
||||
),
|
||||
validation_alias="EMBEDDINGS_WATSONX_IAM_SERVICEID_CRN",
|
||||
)
|
||||
trusted_profile_id: str | None = Field(
|
||||
default=None,
|
||||
description="Trusted profile ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_TRUSTED_PROFILE_ID", "WATSONX_TRUSTED_PROFILE_ID"
|
||||
),
|
||||
validation_alias="EMBEDDINGS_WATSONX_TRUSTED_PROFILE_ID",
|
||||
)
|
||||
token: str | None = Field(
|
||||
default=None,
|
||||
description="Bearer token",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_TOKEN", "WATSONX_TOKEN"),
|
||||
validation_alias="EMBEDDINGS_WATSONX_TOKEN",
|
||||
)
|
||||
projects_token: str | None = Field(
|
||||
default=None,
|
||||
description="Projects token",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_PROJECTS_TOKEN", "WATSONX_PROJECTS_TOKEN"
|
||||
),
|
||||
validation_alias="EMBEDDINGS_WATSONX_PROJECTS_TOKEN",
|
||||
)
|
||||
username: str | None = Field(
|
||||
default=None,
|
||||
description="Username",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_USERNAME", "WATSONX_USERNAME"
|
||||
),
|
||||
validation_alias="EMBEDDINGS_WATSONX_USERNAME",
|
||||
)
|
||||
password: str | None = Field(
|
||||
default=None,
|
||||
description="Password",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_PASSWORD", "WATSONX_PASSWORD"
|
||||
),
|
||||
validation_alias="EMBEDDINGS_WATSONX_PASSWORD",
|
||||
)
|
||||
instance_id: str | None = Field(
|
||||
default=None,
|
||||
description="Service instance ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_INSTANCE_ID", "WATSONX_INSTANCE_ID"
|
||||
),
|
||||
validation_alias="EMBEDDINGS_WATSONX_INSTANCE_ID",
|
||||
)
|
||||
version: str | None = Field(
|
||||
default=None,
|
||||
description="API version",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_WATSONX_VERSION", "WATSONX_VERSION"),
|
||||
validation_alias="EMBEDDINGS_WATSONX_VERSION",
|
||||
)
|
||||
bedrock_url: str | None = Field(
|
||||
default=None,
|
||||
description="Bedrock URL",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_BEDROCK_URL", "WATSONX_BEDROCK_URL"
|
||||
),
|
||||
validation_alias="EMBEDDINGS_WATSONX_BEDROCK_URL",
|
||||
)
|
||||
platform_url: str | None = Field(
|
||||
default=None,
|
||||
description="Platform URL",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_WATSONX_PLATFORM_URL", "WATSONX_PLATFORM_URL"
|
||||
),
|
||||
)
|
||||
proxies: dict[str, Any] | None = Field(
|
||||
default=None, description="Proxy configuration"
|
||||
validation_alias="EMBEDDINGS_WATSONX_PLATFORM_URL",
|
||||
)
|
||||
proxies: dict | None = Field(default=None, description="Proxy configuration")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_space_or_project(self) -> Self:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.instructor_embedding_function import (
|
||||
InstructorEmbeddingFunction,
|
||||
)
|
||||
from pydantic import AliasChoices, Field
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -18,23 +18,15 @@ class InstructorProvider(BaseEmbeddingsProvider[InstructorEmbeddingFunction]):
|
||||
model_name: str = Field(
|
||||
default="hkunlp/instructor-base",
|
||||
description="Model name to use",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_INSTRUCTOR_MODEL_NAME",
|
||||
"INSTRUCTOR_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
validation_alias="EMBEDDINGS_INSTRUCTOR_MODEL_NAME",
|
||||
)
|
||||
device: str = Field(
|
||||
default="cpu",
|
||||
description="Device to run model on (cpu or cuda)",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_INSTRUCTOR_DEVICE", "INSTRUCTOR_DEVICE"
|
||||
),
|
||||
validation_alias="EMBEDDINGS_INSTRUCTOR_DEVICE",
|
||||
)
|
||||
instruction: str | None = Field(
|
||||
default=None,
|
||||
description="Instruction for embeddings",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_INSTRUCTOR_INSTRUCTION", "INSTRUCTOR_INSTRUCTION"
|
||||
),
|
||||
validation_alias="EMBEDDINGS_INSTRUCTOR_INSTRUCTION",
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.jina_embedding_function import (
|
||||
JinaEmbeddingFunction,
|
||||
)
|
||||
from pydantic import AliasChoices, Field
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -15,15 +15,10 @@ class JinaProvider(BaseEmbeddingsProvider[JinaEmbeddingFunction]):
|
||||
default=JinaEmbeddingFunction, description="Jina embedding function class"
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Jina API key",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_JINA_API_KEY", "JINA_API_KEY"),
|
||||
description="Jina API key", validation_alias="EMBEDDINGS_JINA_API_KEY"
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="jina-embeddings-v2-base-en",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_JINA_MODEL_NAME",
|
||||
"JINA_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
validation_alias="EMBEDDINGS_JINA_MODEL_NAME",
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
from pydantic import AliasChoices, Field
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -18,39 +18,27 @@ class AzureProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
|
||||
description="Azure OpenAI embedding function class",
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Azure API key",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_KEY", "OPENAI_API_KEY"),
|
||||
description="Azure API key", validation_alias="EMBEDDINGS_OPENAI_API_KEY"
|
||||
)
|
||||
api_base: str | None = Field(
|
||||
default=None,
|
||||
description="Azure endpoint URL",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_BASE", "OPENAI_API_BASE"),
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_BASE",
|
||||
)
|
||||
api_type: str = Field(
|
||||
default="azure",
|
||||
description="API type for Azure",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_API_TYPE", "OPENAI_API_TYPE", "AZURE_OPENAI_API_TYPE"
|
||||
),
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_TYPE",
|
||||
)
|
||||
api_version: str | None = Field(
|
||||
default="2024-02-01",
|
||||
default=None,
|
||||
description="Azure API version",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_API_VERSION",
|
||||
"OPENAI_API_VERSION",
|
||||
"AZURE_OPENAI_API_VERSION",
|
||||
),
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_VERSION",
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="text-embedding-ada-002",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_MODEL_NAME",
|
||||
"OPENAI_MODEL_NAME",
|
||||
"AZURE_OPENAI_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
validation_alias="EMBEDDINGS_OPENAI_MODEL_NAME",
|
||||
)
|
||||
default_headers: dict[str, Any] | None = Field(
|
||||
default=None, description="Default headers for API requests"
|
||||
@@ -58,26 +46,15 @@ class AzureProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
|
||||
dimensions: int | None = Field(
|
||||
default=None,
|
||||
description="Embedding dimensions",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_DIMENSIONS",
|
||||
"OPENAI_DIMENSIONS",
|
||||
"AZURE_OPENAI_DIMENSIONS",
|
||||
),
|
||||
validation_alias="EMBEDDINGS_OPENAI_DIMENSIONS",
|
||||
)
|
||||
deployment_id: str = Field(
|
||||
deployment_id: str | None = Field(
|
||||
default=None,
|
||||
description="Azure deployment ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_DEPLOYMENT_ID",
|
||||
"AZURE_OPENAI_DEPLOYMENT",
|
||||
"AZURE_DEPLOYMENT_ID",
|
||||
),
|
||||
validation_alias="EMBEDDINGS_OPENAI_DEPLOYMENT_ID",
|
||||
)
|
||||
organization_id: str | None = Field(
|
||||
default=None,
|
||||
description="Organization ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_ORGANIZATION_ID",
|
||||
"OPENAI_ORGANIZATION_ID",
|
||||
"AZURE_OPENAI_ORGANIZATION_ID",
|
||||
),
|
||||
validation_alias="EMBEDDINGS_OPENAI_ORGANIZATION_ID",
|
||||
)
|
||||
|
||||
@@ -15,7 +15,7 @@ class AzureProviderConfig(TypedDict, total=False):
|
||||
model_name: Annotated[str, "text-embedding-ada-002"]
|
||||
default_headers: dict[str, Any]
|
||||
dimensions: int
|
||||
deployment_id: Required[str]
|
||||
deployment_id: str
|
||||
organization_id: str
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.ollama_embedding_function import (
|
||||
OllamaEmbeddingFunction,
|
||||
)
|
||||
from pydantic import AliasChoices, Field
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -17,14 +17,9 @@ class OllamaProvider(BaseEmbeddingsProvider[OllamaEmbeddingFunction]):
|
||||
url: str = Field(
|
||||
default="http://localhost:11434/api/embeddings",
|
||||
description="Ollama API endpoint URL",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_OLLAMA_URL", "OLLAMA_URL"),
|
||||
validation_alias="EMBEDDINGS_OLLAMA_URL",
|
||||
)
|
||||
model_name: str = Field(
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OLLAMA_MODEL_NAME",
|
||||
"OLLAMA_MODEL_NAME",
|
||||
"OLLAMA_MODEL",
|
||||
"model",
|
||||
),
|
||||
validation_alias="EMBEDDINGS_OLLAMA_MODEL_NAME",
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""ONNX embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2
|
||||
from pydantic import AliasChoices, Field
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -15,7 +15,5 @@ class ONNXProvider(BaseEmbeddingsProvider[ONNXMiniLM_L6_V2]):
|
||||
preferred_providers: list[str] | None = Field(
|
||||
default=None,
|
||||
description="Preferred ONNX execution providers",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_ONNX_PREFERRED_PROVIDERS", "ONNX_PREFERRED_PROVIDERS"
|
||||
),
|
||||
validation_alias="EMBEDDINGS_ONNX_PREFERRED_PROVIDERS",
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
from pydantic import AliasChoices, Field
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -20,33 +20,27 @@ class OpenAIProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="OpenAI API key",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_KEY", "OPENAI_API_KEY"),
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_KEY",
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="text-embedding-ada-002",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_MODEL_NAME",
|
||||
"OPENAI_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
validation_alias="EMBEDDINGS_OPENAI_MODEL_NAME",
|
||||
)
|
||||
api_base: str | None = Field(
|
||||
default=None,
|
||||
description="Base URL for API requests",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_BASE", "OPENAI_API_BASE"),
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_BASE",
|
||||
)
|
||||
api_type: str | None = Field(
|
||||
default=None,
|
||||
description="API type (e.g., 'azure')",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_TYPE", "OPENAI_API_TYPE"),
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_TYPE",
|
||||
)
|
||||
api_version: str | None = Field(
|
||||
default=None,
|
||||
description="API version",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_API_VERSION", "OPENAI_API_VERSION"
|
||||
),
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_VERSION",
|
||||
)
|
||||
default_headers: dict[str, Any] | None = Field(
|
||||
default=None, description="Default headers for API requests"
|
||||
@@ -54,21 +48,15 @@ class OpenAIProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
|
||||
dimensions: int | None = Field(
|
||||
default=None,
|
||||
description="Embedding dimensions",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_DIMENSIONS", "OPENAI_DIMENSIONS"
|
||||
),
|
||||
validation_alias="EMBEDDINGS_OPENAI_DIMENSIONS",
|
||||
)
|
||||
deployment_id: str | None = Field(
|
||||
default=None,
|
||||
description="Azure deployment ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_DEPLOYMENT_ID", "OPENAI_DEPLOYMENT_ID"
|
||||
),
|
||||
validation_alias="EMBEDDINGS_OPENAI_DEPLOYMENT_ID",
|
||||
)
|
||||
organization_id: str | None = Field(
|
||||
default=None,
|
||||
description="OpenAI organization ID",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENAI_ORGANIZATION_ID", "OPENAI_ORGANIZATION_ID"
|
||||
),
|
||||
validation_alias="EMBEDDINGS_OPENAI_ORGANIZATION_ID",
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.open_clip_embedding_function import (
|
||||
OpenCLIPEmbeddingFunction,
|
||||
)
|
||||
from pydantic import AliasChoices, Field
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -18,21 +18,15 @@ class OpenCLIPProvider(BaseEmbeddingsProvider[OpenCLIPEmbeddingFunction]):
|
||||
model_name: str = Field(
|
||||
default="ViT-B-32",
|
||||
description="Model name to use",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENCLIP_MODEL_NAME",
|
||||
"OPENCLIP_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
validation_alias="EMBEDDINGS_OPENCLIP_MODEL_NAME",
|
||||
)
|
||||
checkpoint: str = Field(
|
||||
default="laion2b_s34b_b79k",
|
||||
description="Model checkpoint",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_OPENCLIP_CHECKPOINT", "OPENCLIP_CHECKPOINT"
|
||||
),
|
||||
validation_alias="EMBEDDINGS_OPENCLIP_CHECKPOINT",
|
||||
)
|
||||
device: str | None = Field(
|
||||
default="cpu",
|
||||
description="Device to run model on",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_OPENCLIP_DEVICE", "OPENCLIP_DEVICE"),
|
||||
validation_alias="EMBEDDINGS_OPENCLIP_DEVICE",
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.roboflow_embedding_function import (
|
||||
RoboflowEmbeddingFunction,
|
||||
)
|
||||
from pydantic import AliasChoices, Field
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -18,14 +18,10 @@ class RoboflowProvider(BaseEmbeddingsProvider[RoboflowEmbeddingFunction]):
|
||||
api_key: str = Field(
|
||||
default="",
|
||||
description="Roboflow API key",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_ROBOFLOW_API_KEY", "ROBOFLOW_API_KEY"
|
||||
),
|
||||
validation_alias="EMBEDDINGS_ROBOFLOW_API_KEY",
|
||||
)
|
||||
api_url: str = Field(
|
||||
default="https://infer.roboflow.com",
|
||||
description="Roboflow API URL",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_ROBOFLOW_API_URL", "ROBOFLOW_API_URL"
|
||||
),
|
||||
validation_alias="EMBEDDINGS_ROBOFLOW_API_URL",
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
|
||||
SentenceTransformerEmbeddingFunction,
|
||||
)
|
||||
from pydantic import AliasChoices, Field
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -20,24 +20,15 @@ class SentenceTransformerProvider(
|
||||
model_name: str = Field(
|
||||
default="all-MiniLM-L6-v2",
|
||||
description="Model name to use",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_SENTENCE_TRANSFORMER_MODEL_NAME",
|
||||
"SENTENCE_TRANSFORMER_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_MODEL_NAME",
|
||||
)
|
||||
device: str = Field(
|
||||
default="cpu",
|
||||
description="Device to run model on (cpu or cuda)",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_SENTENCE_TRANSFORMER_DEVICE", "SENTENCE_TRANSFORMER_DEVICE"
|
||||
),
|
||||
validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_DEVICE",
|
||||
)
|
||||
normalize_embeddings: bool = Field(
|
||||
default=False,
|
||||
description="Whether to normalize embeddings",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS",
|
||||
"SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS",
|
||||
),
|
||||
validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS",
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
|
||||
Text2VecEmbeddingFunction,
|
||||
)
|
||||
from pydantic import AliasChoices, Field
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
@@ -18,9 +18,5 @@ class Text2VecProvider(BaseEmbeddingsProvider[Text2VecEmbeddingFunction]):
|
||||
model_name: str = Field(
|
||||
default="shibing624/text2vec-base-chinese",
|
||||
description="Model name to use",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_TEXT2VEC_MODEL_NAME",
|
||||
"TEXT2VEC_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
validation_alias="EMBEDDINGS_TEXT2VEC_MODEL_NAME",
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Voyage AI embeddings provider."""
|
||||
|
||||
from pydantic import AliasChoices, Field
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.providers.voyageai.embedding_callable import (
|
||||
@@ -18,53 +18,38 @@ class VoyageAIProvider(BaseEmbeddingsProvider[VoyageAIEmbeddingFunction]):
|
||||
model: str = Field(
|
||||
default="voyage-2",
|
||||
description="Model to use for embeddings",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_VOYAGEAI_MODEL", "VOYAGEAI_MODEL"),
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_MODEL",
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Voyage AI API key",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_VOYAGEAI_API_KEY", "VOYAGEAI_API_KEY"
|
||||
),
|
||||
description="Voyage AI API key", validation_alias="EMBEDDINGS_VOYAGEAI_API_KEY"
|
||||
)
|
||||
input_type: str | None = Field(
|
||||
default=None,
|
||||
description="Input type for embeddings",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_VOYAGEAI_INPUT_TYPE", "VOYAGEAI_INPUT_TYPE"
|
||||
),
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_INPUT_TYPE",
|
||||
)
|
||||
truncation: bool = Field(
|
||||
default=True,
|
||||
description="Whether to truncate inputs",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_VOYAGEAI_TRUNCATION", "VOYAGEAI_TRUNCATION"
|
||||
),
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_TRUNCATION",
|
||||
)
|
||||
output_dtype: str | None = Field(
|
||||
default=None,
|
||||
description="Output data type",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_VOYAGEAI_OUTPUT_DTYPE", "VOYAGEAI_OUTPUT_DTYPE"
|
||||
),
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_OUTPUT_DTYPE",
|
||||
)
|
||||
output_dimension: int | None = Field(
|
||||
default=None,
|
||||
description="Output dimension",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_VOYAGEAI_OUTPUT_DIMENSION", "VOYAGEAI_OUTPUT_DIMENSION"
|
||||
),
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_OUTPUT_DIMENSION",
|
||||
)
|
||||
max_retries: int = Field(
|
||||
default=0,
|
||||
description="Maximum retries for API calls",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_VOYAGEAI_MAX_RETRIES", "VOYAGEAI_MAX_RETRIES"
|
||||
),
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_MAX_RETRIES",
|
||||
)
|
||||
timeout: float | None = Field(
|
||||
default=None,
|
||||
description="Timeout for API calls",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_VOYAGEAI_TIMEOUT", "VOYAGEAI_TIMEOUT"
|
||||
),
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_TIMEOUT",
|
||||
)
|
||||
|
||||
@@ -29,7 +29,7 @@ from crewai.rag.embeddings.providers.text2vec.types import Text2VecProviderSpec
|
||||
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderSpec
|
||||
|
||||
|
||||
ProviderSpec: TypeAlias = (
|
||||
ProviderSpec = (
|
||||
AzureProviderSpec
|
||||
| BedrockProviderSpec
|
||||
| CohereProviderSpec
|
||||
|
||||
@@ -1,23 +1,16 @@
|
||||
"""Qdrant configuration model."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import field
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
from typing import Literal, cast
|
||||
|
||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||
from qdrant_client.models import VectorParams
|
||||
|
||||
from crewai.rag.config.base import BaseRagConfig
|
||||
from crewai.rag.qdrant.constants import DEFAULT_EMBEDDING_MODEL, DEFAULT_STORAGE_PATH
|
||||
from crewai.rag.qdrant.types import QdrantClientParams, QdrantEmbeddingFunctionWrapper
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qdrant_client.models import VectorParams
|
||||
else:
|
||||
VectorParams = Any
|
||||
|
||||
|
||||
def _default_options() -> QdrantClientParams:
|
||||
"""Create default Qdrant client options.
|
||||
|
||||
@@ -33,7 +26,7 @@ def _default_embedding_function() -> QdrantEmbeddingFunctionWrapper:
|
||||
Returns:
|
||||
Default embedding function using fastembed with all-MiniLM-L6-v2.
|
||||
"""
|
||||
from fastembed import TextEmbedding
|
||||
from fastembed import TextEmbedding # type: ignore[import-not-found]
|
||||
|
||||
model = TextEmbedding(model_name=DEFAULT_EMBEDDING_MODEL)
|
||||
|
||||
|
||||
@@ -1,112 +0,0 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"messages":[{"role":"user","content":"Say hello in one word"}],"model":"gpt-4o"}'
|
||||
headers:
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- gzip, deflate, zstd
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '81'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- api.openai.com
|
||||
user-agent:
|
||||
- OpenAI/Python 1.109.1
|
||||
x-stainless-arch:
|
||||
- arm64
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- MacOS
|
||||
x-stainless-package-version:
|
||||
- 1.109.1
|
||||
x-stainless-read-timeout:
|
||||
- '600'
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.12.10
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: !!binary |
|
||||
H4sIAAAAAAAAAwAAAP//jJJNT9wwEIbv+RXunDdVNtoPuteiCoEQJ7SiFYqMPcm6OB7LnvAhtP8d
|
||||
OWE3oYDUiw9+5h2/73heMiHAaNgIUDvJqvU2/1nf1K44vVzePG6vr5cPV2V5/nt7eqG36lcLs6Sg
|
||||
u7+o+KD6rqj1FtmQG7AKKBlT1/l6tSjKYr1a96AljTbJGs/5gvKyKBd5cZIXqzfhjozCCBvxJxNC
|
||||
iJf+TBadxifYiGJ2uGkxRtkgbI5FQkAgm25AxmgiS8cwG6Eix+h612doLX2bwoB1F2Xy5jprJ0A6
|
||||
RyxTtt7W7RvZH41Yanygu/iPFGrjTNxVAWUklx6NTB56us+EuO0Dd+8ygA/Ueq6Y7rF/bl4O7WCc
|
||||
8AgPjImlnWgWs0+aVRpZGhsn8wIl1Q71qByHKzttaAKySeSPXj7rPcQ2rvmf9iNQCj2jrnxAbdT7
|
||||
vGNZwLR+X5UdR9wbhojhwSis2GBI36Cxlp0dNgPic2Rsq9q4BoMPZliP2lfqxwkWSyXna8j22SsA
|
||||
AAD//wMAmJrFFCcDAAA=
|
||||
headers:
|
||||
CF-RAY:
|
||||
- 9a3c18dff8580f53-EWR
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Encoding:
|
||||
- gzip
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Mon, 24 Nov 2025 21:46:08 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Set-Cookie:
|
||||
- FILTERED
|
||||
Strict-Transport-Security:
|
||||
- max-age=31536000; includeSubDomains; preload
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- nosniff
|
||||
access-control-expose-headers:
|
||||
- X-Request-ID
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
openai-organization:
|
||||
- FILTERED
|
||||
openai-processing-ms:
|
||||
- '1096'
|
||||
openai-project:
|
||||
- FILTERED
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
x-envoy-upstream-service-time:
|
||||
- '1138'
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-project-requests:
|
||||
- '10000'
|
||||
x-ratelimit-limit-requests:
|
||||
- '10000'
|
||||
x-ratelimit-limit-tokens:
|
||||
- '30000000'
|
||||
x-ratelimit-remaining-project-requests:
|
||||
- '9999'
|
||||
x-ratelimit-remaining-requests:
|
||||
- '9999'
|
||||
x-ratelimit-remaining-tokens:
|
||||
- '29999992'
|
||||
x-ratelimit-reset-project-requests:
|
||||
- 6ms
|
||||
x-ratelimit-reset-requests:
|
||||
- 6ms
|
||||
x-ratelimit-reset-tokens:
|
||||
- 0s
|
||||
x-request-id:
|
||||
- req_670507131d6c455caf0e8cbc30a1a792
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
@@ -1,113 +0,0 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"messages":[{"role":"user","content":"Return a JSON object with a ''status''
|
||||
field set to ''success''"}],"model":"gpt-4o","response_format":{"type":"json_object"}}'
|
||||
headers:
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- gzip, deflate, zstd
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '160'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- api.openai.com
|
||||
user-agent:
|
||||
- OpenAI/Python 1.109.1
|
||||
x-stainless-arch:
|
||||
- arm64
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- MacOS
|
||||
x-stainless-package-version:
|
||||
- 1.109.1
|
||||
x-stainless-read-timeout:
|
||||
- '600'
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.12.10
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: !!binary |
|
||||
H4sIAAAAAAAAA4xSwW6cMBC98xXWnJeKkC274Zr01vbUVopKhLxmACfGdj1D1Wi1/14ZNgtpU6kX
|
||||
hObNe7z3mGMiBOgGSgGql6wGb9Lb9r4191/46eaD+/j509f9Lvs2PP+47u/usIdNZLjDIyp+Yb1T
|
||||
bvAGWTs7wyqgZIyqV7tim+XZrng/AYNr0ERa5zndujTP8m2a7dOsOBN7pxUSlOJ7IoQQx+kZLdoG
|
||||
f0Epss3LZEAi2SGUlyUhIDgTJyCJNLG0DJsFVM4y2sn1sbJCVEAseaQKyvg+KoVEFVT2tGYFbEeS
|
||||
0bQdjVkB0lrHMoae/D6ckdPFoXGdD+5Af1Ch1VZTXweU5Gx0Q+w8TOgpEeJhamJ8FQ58cIPnmt0T
|
||||
Tp/L81kOluoX8OaMsWNplvH11eYNsbpBltrQqkhQUvXYLMyldTk22q2AZBX5by9vac+xte3+R34B
|
||||
lELP2NQ+YKPV67zLWsB4l/9au1Q8GQbC8FMrrFljiL+hwVaOZj4ZoGdiHOpW2w6DD3q+m9bXO8TD
|
||||
tmizYg/JKfkNAAD//wMA0CE0wkADAAA=
|
||||
headers:
|
||||
CF-RAY:
|
||||
- 9a3c18d7de3c80dc-EWR
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Encoding:
|
||||
- gzip
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Mon, 24 Nov 2025 21:46:06 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Set-Cookie:
|
||||
- FILTERED
|
||||
Strict-Transport-Security:
|
||||
- max-age=31536000; includeSubDomains; preload
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- nosniff
|
||||
access-control-expose-headers:
|
||||
- X-Request-ID
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
openai-organization:
|
||||
- FILTERED
|
||||
openai-processing-ms:
|
||||
- '424'
|
||||
openai-project:
|
||||
- FILTERED
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
x-envoy-upstream-service-time:
|
||||
- '443'
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-project-requests:
|
||||
- '10000'
|
||||
x-ratelimit-limit-requests:
|
||||
- '10000'
|
||||
x-ratelimit-limit-tokens:
|
||||
- '30000000'
|
||||
x-ratelimit-remaining-project-requests:
|
||||
- '9999'
|
||||
x-ratelimit-remaining-requests:
|
||||
- '9999'
|
||||
x-ratelimit-remaining-tokens:
|
||||
- '29999983'
|
||||
x-ratelimit-reset-project-requests:
|
||||
- 6ms
|
||||
x-ratelimit-reset-requests:
|
||||
- 6ms
|
||||
x-ratelimit-reset-tokens:
|
||||
- 0s
|
||||
x-request-id:
|
||||
- req_71bc4c9f29f843d6b3788b119850dfde
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
@@ -1,116 +0,0 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"messages":[{"role":"user","content":"What is the capital of France? Be
|
||||
concise."}],"model":"gpt-4o","response_format":{"type":"json_schema","json_schema":{"name":"AnswerResponse","strict":true,"schema":{"description":"Response
|
||||
model with structured fields.","properties":{"answer":{"description":"The answer
|
||||
to the question","title":"Answer","type":"string"},"confidence":{"description":"Confidence
|
||||
score between 0 and 1","title":"Confidence","type":"number"}},"required":["answer","confidence"],"title":"AnswerResponse","type":"object","additionalProperties":false}}}}'
|
||||
headers:
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- gzip, deflate, zstd
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '571'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- api.openai.com
|
||||
user-agent:
|
||||
- OpenAI/Python 1.109.1
|
||||
x-stainless-arch:
|
||||
- arm64
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- MacOS
|
||||
x-stainless-package-version:
|
||||
- 1.109.1
|
||||
x-stainless-read-timeout:
|
||||
- '600'
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.12.10
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: !!binary |
|
||||
H4sIAAAAAAAAAwAAAP//jFLLbtswELzrK4g9SwFtyA/pmKCH9pRbUVSBwJArmTVFElyqbWr43wvK
|
||||
jqW0KdALDzs7w5ndPWWMgVZQM5AHEeXgTfHQfemO1f0H8cK73fDp82b/iz6uH4+WnC0hTwz3/A1l
|
||||
fGXdSTd4g1E7e4FlQBExqa5225Kv+W5bTsDgFJpE630sSles+bos+L7g2yvx4LREgpp9zRhj7DS9
|
||||
yaJV+BNqxvPXyoBEokeob02MQXAmVUAQaYrCRshnUDob0U6uTw0ISz8wNFA38CiCpgbyJrV0WqGV
|
||||
2EDN76rqvBQI2I0kkn87GrMAhLUuipR/sv50Rc43s8b1Prhn+oMKnbaaDm1AQc4mYxSdhwk9Z4w9
|
||||
TUMZ3+QEH9zgYxvdEafvqvIiB/MWZnC1uoLRRWEWdb7J35FrFUahDS2mClLIA6qZOq9AjEq7BZAt
|
||||
Qv/t5j3tS3Bt+/+RnwEp0UdUrQ+otHybeG4LmI70X223IU+GgTB81xLbqDGkRSjsxGgu9wP0QhGH
|
||||
ttO2x+CDvhxR51tZ7ZFvpFjtIDtnvwEAAP//AwAvoKedTQMAAA==
|
||||
headers:
|
||||
CF-RAY:
|
||||
- 9a3c18cf7fe04253-EWR
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Encoding:
|
||||
- gzip
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Mon, 24 Nov 2025 21:46:05 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Set-Cookie:
|
||||
- FILTERED
|
||||
Strict-Transport-Security:
|
||||
- max-age=31536000; includeSubDomains; preload
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- nosniff
|
||||
access-control-expose-headers:
|
||||
- X-Request-ID
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
openai-organization:
|
||||
- FILTERED
|
||||
openai-processing-ms:
|
||||
- '448'
|
||||
openai-project:
|
||||
- FILTERED
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
x-envoy-upstream-service-time:
|
||||
- '465'
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-project-requests:
|
||||
- '10000'
|
||||
x-ratelimit-limit-requests:
|
||||
- '10000'
|
||||
x-ratelimit-limit-tokens:
|
||||
- '30000000'
|
||||
x-ratelimit-remaining-project-requests:
|
||||
- '9999'
|
||||
x-ratelimit-remaining-requests:
|
||||
- '9999'
|
||||
x-ratelimit-remaining-tokens:
|
||||
- '29999987'
|
||||
x-ratelimit-reset-project-requests:
|
||||
- 6ms
|
||||
x-ratelimit-reset-requests:
|
||||
- 6ms
|
||||
x-ratelimit-reset-tokens:
|
||||
- 0s
|
||||
x-request-id:
|
||||
- req_765510cb1e614ed6a83e665bf7c5a07b
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
@@ -72,8 +72,7 @@ class TestSettings(unittest.TestCase):
|
||||
@patch("crewai.cli.config.TokenManager")
|
||||
def test_reset_settings(self, mock_token_manager):
|
||||
user_settings = {key: f"value_for_{key}" for key in USER_SETTINGS_KEYS}
|
||||
cli_settings = {key: f"value_for_{key}" for key in CLI_SETTINGS_KEYS if key != "oauth2_extra"}
|
||||
cli_settings["oauth2_extra"] = {"scope": "xxx", "other": "yyy"}
|
||||
cli_settings = {key: f"value_for_{key}" for key in CLI_SETTINGS_KEYS}
|
||||
|
||||
settings = Settings(
|
||||
config_path=self.config_path, **user_settings, **cli_settings
|
||||
|
||||
@@ -381,7 +381,6 @@ def test_azure_raises_error_when_endpoint_missing():
|
||||
with pytest.raises(ValueError, match="Azure endpoint is required"):
|
||||
AzureCompletion(model="gpt-4", api_key="test-key")
|
||||
|
||||
|
||||
def test_azure_raises_error_when_api_key_missing():
|
||||
"""Test that AzureCompletion raises ValueError when API key is missing"""
|
||||
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||
@@ -390,8 +389,6 @@ def test_azure_raises_error_when_api_key_missing():
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
with pytest.raises(ValueError, match="Azure API key is required"):
|
||||
AzureCompletion(model="gpt-4", endpoint="https://test.openai.azure.com")
|
||||
|
||||
|
||||
def test_azure_endpoint_configuration():
|
||||
"""
|
||||
Test that Azure endpoint configuration works with multiple environment variable names
|
||||
@@ -1089,27 +1086,3 @@ def test_azure_mistral_and_other_models():
|
||||
)
|
||||
assert "model" in params
|
||||
assert params["model"] == model_name
|
||||
|
||||
|
||||
def test_azure_completion_params_preparation_with_drop_params():
|
||||
"""
|
||||
Test that completion parameters are properly prepared with drop paramaeters attribute respected
|
||||
"""
|
||||
with patch.dict(os.environ, {
|
||||
"AZURE_API_KEY": "test-key",
|
||||
"AZURE_ENDPOINT": "https://models.inference.ai.azure.com"
|
||||
}):
|
||||
llm = LLM(
|
||||
model="azure/o4-mini",
|
||||
drop_params=True,
|
||||
additional_drop_params=["stop"],
|
||||
max_tokens=1000
|
||||
)
|
||||
|
||||
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||
assert isinstance(llm, AzureCompletion)
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
params = llm._prepare_completion_params(messages)
|
||||
|
||||
assert params.get('stop') == None
|
||||
@@ -528,50 +528,3 @@ def test_openai_streaming_with_response_model():
|
||||
|
||||
assert "input" not in call_kwargs
|
||||
assert "text_format" not in call_kwargs
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_openai_response_format_with_pydantic_model():
|
||||
"""
|
||||
Test that response_format with a Pydantic BaseModel returns structured output.
|
||||
"""
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class AnswerResponse(BaseModel):
|
||||
"""Response model with structured fields."""
|
||||
|
||||
answer: str = Field(description="The answer to the question")
|
||||
confidence: float = Field(description="Confidence score between 0 and 1")
|
||||
|
||||
llm = LLM(model="gpt-4o", response_format=AnswerResponse)
|
||||
result = llm.call("What is the capital of France? Be concise.")
|
||||
|
||||
assert isinstance(result, AnswerResponse)
|
||||
assert result.answer is not None
|
||||
assert 0 <= result.confidence <= 1
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_openai_response_format_with_dict():
|
||||
"""
|
||||
Test that response_format with a dict returns JSON output.
|
||||
"""
|
||||
import json
|
||||
|
||||
llm = LLM(model="gpt-4o", response_format={"type": "json_object"})
|
||||
result = llm.call("Return a JSON object with a 'status' field set to 'success'")
|
||||
|
||||
parsed = json.loads(result)
|
||||
assert "status" in parsed
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_openai_response_format_none():
|
||||
"""
|
||||
Test that when response_format is None, the API returns plain text.
|
||||
"""
|
||||
llm = LLM(model="gpt-4o", response_format=None)
|
||||
result = llm.call("Say hello in one word")
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert len(result) > 0
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
"""Tests for SSE transport."""
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.mcp.transports.sse import SSETransport
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sse_transport_connect_does_not_pass_invalid_args():
|
||||
"""Test that SSETransport.connect() doesn't pass invalid args to sse_client.
|
||||
|
||||
The sse_client function does not accept terminate_on_close parameter.
|
||||
"""
|
||||
transport = SSETransport(
|
||||
url="http://localhost:9999/sse",
|
||||
headers={"Authorization": "Bearer test"},
|
||||
)
|
||||
|
||||
with pytest.raises(ConnectionError) as exc_info:
|
||||
await transport.connect()
|
||||
|
||||
assert "unexpected keyword argument" not in str(exc_info.value)
|
||||
@@ -1,364 +0,0 @@
|
||||
"""Tests for backward compatibility of embedding provider configurations."""
|
||||
|
||||
from crewai.rag.embeddings.factory import build_embedder, PROVIDER_PATHS
|
||||
from crewai.rag.embeddings.providers.openai.openai_provider import OpenAIProvider
|
||||
from crewai.rag.embeddings.providers.cohere.cohere_provider import CohereProvider
|
||||
from crewai.rag.embeddings.providers.google.generative_ai import GenerativeAiProvider
|
||||
from crewai.rag.embeddings.providers.google.vertex import VertexAIProvider
|
||||
from crewai.rag.embeddings.providers.microsoft.azure import AzureProvider
|
||||
from crewai.rag.embeddings.providers.jina.jina_provider import JinaProvider
|
||||
from crewai.rag.embeddings.providers.ollama.ollama_provider import OllamaProvider
|
||||
from crewai.rag.embeddings.providers.aws.bedrock import BedrockProvider
|
||||
from crewai.rag.embeddings.providers.text2vec.text2vec_provider import Text2VecProvider
|
||||
from crewai.rag.embeddings.providers.sentence_transformer.sentence_transformer_provider import (
|
||||
SentenceTransformerProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.instructor.instructor_provider import InstructorProvider
|
||||
from crewai.rag.embeddings.providers.openclip.openclip_provider import OpenCLIPProvider
|
||||
|
||||
|
||||
class TestGoogleProviderAlias:
|
||||
"""Test that 'google' provider name alias works for backward compatibility."""
|
||||
|
||||
def test_google_alias_in_provider_paths(self):
|
||||
"""Verify 'google' is registered as an alias for google-generativeai."""
|
||||
assert "google" in PROVIDER_PATHS
|
||||
assert "google-generativeai" in PROVIDER_PATHS
|
||||
assert PROVIDER_PATHS["google"] == PROVIDER_PATHS["google-generativeai"]
|
||||
|
||||
|
||||
class TestModelKeyBackwardCompatibility:
|
||||
"""Test that 'model' config key works as alias for 'model_name'."""
|
||||
|
||||
def test_openai_provider_accepts_model_key(self):
|
||||
"""Test OpenAI provider accepts 'model' as alias for 'model_name'."""
|
||||
provider = OpenAIProvider(
|
||||
api_key="test-key",
|
||||
model="text-embedding-3-small",
|
||||
)
|
||||
assert provider.model_name == "text-embedding-3-small"
|
||||
|
||||
def test_openai_provider_model_name_takes_precedence(self):
|
||||
"""Test that model_name takes precedence when both are provided."""
|
||||
provider = OpenAIProvider(
|
||||
api_key="test-key",
|
||||
model_name="text-embedding-3-large",
|
||||
)
|
||||
assert provider.model_name == "text-embedding-3-large"
|
||||
|
||||
def test_cohere_provider_accepts_model_key(self):
|
||||
"""Test Cohere provider accepts 'model' as alias for 'model_name'."""
|
||||
provider = CohereProvider(
|
||||
api_key="test-key",
|
||||
model="embed-english-v3.0",
|
||||
)
|
||||
assert provider.model_name == "embed-english-v3.0"
|
||||
|
||||
def test_google_generativeai_provider_accepts_model_key(self):
|
||||
"""Test Google Generative AI provider accepts 'model' as alias."""
|
||||
provider = GenerativeAiProvider(
|
||||
api_key="test-key",
|
||||
model="gemini-embedding-001",
|
||||
)
|
||||
assert provider.model_name == "gemini-embedding-001"
|
||||
|
||||
def test_google_vertex_provider_accepts_model_key(self):
|
||||
"""Test Google Vertex AI provider accepts 'model' as alias."""
|
||||
provider = VertexAIProvider(
|
||||
api_key="test-key",
|
||||
model="text-embedding-004",
|
||||
)
|
||||
assert provider.model_name == "text-embedding-004"
|
||||
|
||||
def test_azure_provider_accepts_model_key(self):
|
||||
"""Test Azure provider accepts 'model' as alias for 'model_name'."""
|
||||
provider = AzureProvider(
|
||||
api_key="test-key",
|
||||
deployment_id="test-deployment",
|
||||
model="text-embedding-ada-002",
|
||||
)
|
||||
assert provider.model_name == "text-embedding-ada-002"
|
||||
|
||||
def test_jina_provider_accepts_model_key(self):
|
||||
"""Test Jina provider accepts 'model' as alias for 'model_name'."""
|
||||
provider = JinaProvider(
|
||||
api_key="test-key",
|
||||
model="jina-embeddings-v3",
|
||||
)
|
||||
assert provider.model_name == "jina-embeddings-v3"
|
||||
|
||||
def test_ollama_provider_accepts_model_key(self):
|
||||
"""Test Ollama provider accepts 'model' as alias for 'model_name'."""
|
||||
provider = OllamaProvider(
|
||||
model="nomic-embed-text",
|
||||
)
|
||||
assert provider.model_name == "nomic-embed-text"
|
||||
|
||||
def test_text2vec_provider_accepts_model_key(self):
|
||||
"""Test Text2Vec provider accepts 'model' as alias for 'model_name'."""
|
||||
provider = Text2VecProvider(
|
||||
model="shibing624/text2vec-base-multilingual",
|
||||
)
|
||||
assert provider.model_name == "shibing624/text2vec-base-multilingual"
|
||||
|
||||
def test_sentence_transformer_provider_accepts_model_key(self):
|
||||
"""Test SentenceTransformer provider accepts 'model' as alias."""
|
||||
provider = SentenceTransformerProvider(
|
||||
model="all-mpnet-base-v2",
|
||||
)
|
||||
assert provider.model_name == "all-mpnet-base-v2"
|
||||
|
||||
def test_instructor_provider_accepts_model_key(self):
|
||||
"""Test Instructor provider accepts 'model' as alias for 'model_name'."""
|
||||
provider = InstructorProvider(
|
||||
model="hkunlp/instructor-xl",
|
||||
)
|
||||
assert provider.model_name == "hkunlp/instructor-xl"
|
||||
|
||||
def test_openclip_provider_accepts_model_key(self):
|
||||
"""Test OpenCLIP provider accepts 'model' as alias for 'model_name'."""
|
||||
provider = OpenCLIPProvider(
|
||||
model="ViT-B-16",
|
||||
)
|
||||
assert provider.model_name == "ViT-B-16"
|
||||
|
||||
|
||||
class TestTaskTypeConfiguration:
|
||||
"""Test that task_type configuration works correctly."""
|
||||
|
||||
def test_google_provider_accepts_lowercase_task_type(self):
|
||||
"""Test Google provider accepts lowercase task_type."""
|
||||
provider = GenerativeAiProvider(
|
||||
api_key="test-key",
|
||||
task_type="retrieval_document",
|
||||
)
|
||||
assert provider.task_type == "retrieval_document"
|
||||
|
||||
def test_google_provider_accepts_uppercase_task_type(self):
|
||||
"""Test Google provider accepts uppercase task_type."""
|
||||
provider = GenerativeAiProvider(
|
||||
api_key="test-key",
|
||||
task_type="RETRIEVAL_QUERY",
|
||||
)
|
||||
assert provider.task_type == "RETRIEVAL_QUERY"
|
||||
|
||||
def test_google_provider_default_task_type(self):
|
||||
"""Test Google provider has correct default task_type."""
|
||||
provider = GenerativeAiProvider(
|
||||
api_key="test-key",
|
||||
)
|
||||
assert provider.task_type == "RETRIEVAL_DOCUMENT"
|
||||
|
||||
|
||||
class TestFactoryBackwardCompatibility:
|
||||
"""Test factory function with backward compatible configurations."""
|
||||
|
||||
def test_factory_with_google_alias(self):
|
||||
"""Test factory resolves 'google' to google-generativeai provider."""
|
||||
config = {
|
||||
"provider": "google",
|
||||
"config": {
|
||||
"api_key": "test-key",
|
||||
"model": "gemini-embedding-001",
|
||||
},
|
||||
}
|
||||
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
with patch("crewai.rag.embeddings.factory.import_and_validate_definition") as mock_import:
|
||||
mock_provider_class = MagicMock()
|
||||
mock_provider_instance = MagicMock()
|
||||
mock_import.return_value = mock_provider_class
|
||||
mock_provider_class.return_value = mock_provider_instance
|
||||
|
||||
build_embedder(config)
|
||||
|
||||
mock_import.assert_called_once_with(
|
||||
"crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider"
|
||||
)
|
||||
|
||||
def test_factory_with_model_key_openai(self):
|
||||
"""Test factory passes 'model' config to OpenAI provider."""
|
||||
config = {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"api_key": "test-key",
|
||||
"model": "text-embedding-3-small",
|
||||
},
|
||||
}
|
||||
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
with patch("crewai.rag.embeddings.factory.import_and_validate_definition") as mock_import:
|
||||
mock_provider_class = MagicMock()
|
||||
mock_provider_instance = MagicMock()
|
||||
mock_import.return_value = mock_provider_class
|
||||
mock_provider_class.return_value = mock_provider_instance
|
||||
|
||||
build_embedder(config)
|
||||
|
||||
call_kwargs = mock_provider_class.call_args.kwargs
|
||||
assert call_kwargs["model"] == "text-embedding-3-small"
|
||||
|
||||
|
||||
class TestDocumentationCodeSnippets:
|
||||
"""Test code snippets from documentation work correctly."""
|
||||
|
||||
def test_memory_openai_config(self):
|
||||
"""Test OpenAI config from memory.mdx documentation."""
|
||||
provider = OpenAIProvider(
|
||||
model_name="text-embedding-3-small",
|
||||
)
|
||||
assert provider.model_name == "text-embedding-3-small"
|
||||
|
||||
def test_memory_openai_config_with_options(self):
|
||||
"""Test OpenAI config with all options from memory.mdx."""
|
||||
provider = OpenAIProvider(
|
||||
api_key="your-openai-api-key",
|
||||
model_name="text-embedding-3-large",
|
||||
dimensions=1536,
|
||||
organization_id="your-org-id",
|
||||
)
|
||||
assert provider.model_name == "text-embedding-3-large"
|
||||
assert provider.dimensions == 1536
|
||||
|
||||
def test_memory_azure_config(self):
|
||||
"""Test Azure config from memory.mdx documentation."""
|
||||
provider = AzureProvider(
|
||||
api_key="your-azure-key",
|
||||
api_base="https://your-resource.openai.azure.com/",
|
||||
api_type="azure",
|
||||
api_version="2023-05-15",
|
||||
model_name="text-embedding-3-small",
|
||||
deployment_id="your-deployment-name",
|
||||
)
|
||||
assert provider.model_name == "text-embedding-3-small"
|
||||
assert provider.api_type == "azure"
|
||||
|
||||
def test_memory_google_generativeai_config(self):
|
||||
"""Test Google Generative AI config from memory.mdx documentation."""
|
||||
provider = GenerativeAiProvider(
|
||||
api_key="your-google-api-key",
|
||||
model_name="gemini-embedding-001",
|
||||
)
|
||||
assert provider.model_name == "gemini-embedding-001"
|
||||
|
||||
def test_memory_cohere_config(self):
|
||||
"""Test Cohere config from memory.mdx documentation."""
|
||||
provider = CohereProvider(
|
||||
api_key="your-cohere-api-key",
|
||||
model_name="embed-english-v3.0",
|
||||
)
|
||||
assert provider.model_name == "embed-english-v3.0"
|
||||
|
||||
def test_knowledge_agent_embedder_config(self):
|
||||
"""Test agent embedder config from knowledge.mdx documentation."""
|
||||
provider = GenerativeAiProvider(
|
||||
model_name="gemini-embedding-001",
|
||||
api_key="your-google-key",
|
||||
)
|
||||
assert provider.model_name == "gemini-embedding-001"
|
||||
|
||||
def test_ragtool_openai_config(self):
|
||||
"""Test RagTool OpenAI config from ragtool.mdx documentation."""
|
||||
provider = OpenAIProvider(
|
||||
model_name="text-embedding-3-small",
|
||||
)
|
||||
assert provider.model_name == "text-embedding-3-small"
|
||||
|
||||
def test_ragtool_cohere_config(self):
|
||||
"""Test RagTool Cohere config from ragtool.mdx documentation."""
|
||||
provider = CohereProvider(
|
||||
api_key="your-api-key",
|
||||
model_name="embed-english-v3.0",
|
||||
)
|
||||
assert provider.model_name == "embed-english-v3.0"
|
||||
|
||||
def test_ragtool_ollama_config(self):
|
||||
"""Test RagTool Ollama config from ragtool.mdx documentation."""
|
||||
provider = OllamaProvider(
|
||||
model_name="llama2",
|
||||
url="http://localhost:11434/api/embeddings",
|
||||
)
|
||||
assert provider.model_name == "llama2"
|
||||
|
||||
def test_ragtool_azure_config(self):
|
||||
"""Test RagTool Azure config from ragtool.mdx documentation."""
|
||||
provider = AzureProvider(
|
||||
deployment_id="your-deployment-id",
|
||||
api_key="your-api-key",
|
||||
api_base="https://your-resource.openai.azure.com",
|
||||
api_version="2024-02-01",
|
||||
model_name="text-embedding-ada-002",
|
||||
api_type="azure",
|
||||
)
|
||||
assert provider.model_name == "text-embedding-ada-002"
|
||||
assert provider.deployment_id == "your-deployment-id"
|
||||
|
||||
def test_ragtool_google_generativeai_config(self):
|
||||
"""Test RagTool Google Generative AI config from ragtool.mdx."""
|
||||
provider = GenerativeAiProvider(
|
||||
api_key="your-api-key",
|
||||
model_name="gemini-embedding-001",
|
||||
task_type="RETRIEVAL_DOCUMENT",
|
||||
)
|
||||
assert provider.model_name == "gemini-embedding-001"
|
||||
assert provider.task_type == "RETRIEVAL_DOCUMENT"
|
||||
|
||||
def test_ragtool_jina_config(self):
|
||||
"""Test RagTool Jina config from ragtool.mdx documentation."""
|
||||
provider = JinaProvider(
|
||||
api_key="your-api-key",
|
||||
model_name="jina-embeddings-v3",
|
||||
)
|
||||
assert provider.model_name == "jina-embeddings-v3"
|
||||
|
||||
def test_ragtool_sentence_transformer_config(self):
|
||||
"""Test RagTool SentenceTransformer config from ragtool.mdx."""
|
||||
provider = SentenceTransformerProvider(
|
||||
model_name="all-mpnet-base-v2",
|
||||
device="cuda",
|
||||
normalize_embeddings=True,
|
||||
)
|
||||
assert provider.model_name == "all-mpnet-base-v2"
|
||||
assert provider.device == "cuda"
|
||||
assert provider.normalize_embeddings is True
|
||||
|
||||
|
||||
class TestLegacyConfigurationFormats:
|
||||
"""Test legacy configuration formats that should still work."""
|
||||
|
||||
def test_legacy_google_with_model_key(self):
|
||||
"""Test legacy Google config using 'model' instead of 'model_name'."""
|
||||
provider = GenerativeAiProvider(
|
||||
api_key="test-key",
|
||||
model="text-embedding-005",
|
||||
task_type="retrieval_document",
|
||||
)
|
||||
assert provider.model_name == "text-embedding-005"
|
||||
assert provider.task_type == "retrieval_document"
|
||||
|
||||
def test_legacy_openai_with_model_key(self):
|
||||
"""Test legacy OpenAI config using 'model' instead of 'model_name'."""
|
||||
provider = OpenAIProvider(
|
||||
api_key="test-key",
|
||||
model="text-embedding-ada-002",
|
||||
)
|
||||
assert provider.model_name == "text-embedding-ada-002"
|
||||
|
||||
def test_legacy_cohere_with_model_key(self):
|
||||
"""Test legacy Cohere config using 'model' instead of 'model_name'."""
|
||||
provider = CohereProvider(
|
||||
api_key="test-key",
|
||||
model="embed-multilingual-v3.0",
|
||||
)
|
||||
assert provider.model_name == "embed-multilingual-v3.0"
|
||||
|
||||
def test_legacy_azure_with_model_key(self):
|
||||
"""Test legacy Azure config using 'model' instead of 'model_name'."""
|
||||
provider = AzureProvider(
|
||||
api_key="test-key",
|
||||
deployment_id="test-deployment",
|
||||
model="text-embedding-3-large",
|
||||
)
|
||||
assert provider.model_name == "text-embedding-3-large"
|
||||
@@ -1,82 +0,0 @@
|
||||
"""Tests for RAGStorage custom path functionality."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from crewai.memory.storage.rag_storage import RAGStorage
|
||||
|
||||
|
||||
@patch("crewai.memory.storage.rag_storage.create_client")
|
||||
@patch("crewai.memory.storage.rag_storage.build_embedder")
|
||||
def test_rag_storage_custom_path(
|
||||
mock_build_embedder: MagicMock,
|
||||
mock_create_client: MagicMock,
|
||||
) -> None:
|
||||
"""Test RAGStorage uses custom path when provided."""
|
||||
mock_build_embedder.return_value = MagicMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_create_client.return_value = MagicMock()
|
||||
|
||||
custom_path = "/custom/memory/path"
|
||||
embedder_config = {"provider": "openai", "config": {"model": "text-embedding-3-small"}}
|
||||
|
||||
RAGStorage(
|
||||
type="short_term",
|
||||
crew=None,
|
||||
path=custom_path,
|
||||
embedder_config=embedder_config,
|
||||
)
|
||||
|
||||
mock_create_client.assert_called_once()
|
||||
config_arg = mock_create_client.call_args[0][0]
|
||||
assert config_arg.settings.persist_directory == custom_path
|
||||
|
||||
|
||||
@patch("crewai.memory.storage.rag_storage.create_client")
|
||||
@patch("crewai.memory.storage.rag_storage.build_embedder")
|
||||
def test_rag_storage_default_path_when_none(
|
||||
mock_build_embedder: MagicMock,
|
||||
mock_create_client: MagicMock,
|
||||
) -> None:
|
||||
"""Test RAGStorage uses default path when no custom path is provided."""
|
||||
mock_build_embedder.return_value = MagicMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_create_client.return_value = MagicMock()
|
||||
|
||||
embedder_config = {"provider": "openai", "config": {"model": "text-embedding-3-small"}}
|
||||
|
||||
storage = RAGStorage(
|
||||
type="short_term",
|
||||
crew=None,
|
||||
path=None,
|
||||
embedder_config=embedder_config,
|
||||
)
|
||||
|
||||
mock_create_client.assert_called_once()
|
||||
assert storage.path is None
|
||||
|
||||
|
||||
@patch("crewai.memory.storage.rag_storage.create_client")
|
||||
@patch("crewai.memory.storage.rag_storage.build_embedder")
|
||||
def test_rag_storage_custom_path_with_batch_size(
|
||||
mock_build_embedder: MagicMock,
|
||||
mock_create_client: MagicMock,
|
||||
) -> None:
|
||||
"""Test RAGStorage uses custom path with batch_size in config."""
|
||||
mock_build_embedder.return_value = MagicMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_create_client.return_value = MagicMock()
|
||||
|
||||
custom_path = "/custom/batch/path"
|
||||
embedder_config = {
|
||||
"provider": "openai",
|
||||
"config": {"model": "text-embedding-3-small", "batch_size": 100},
|
||||
}
|
||||
|
||||
RAGStorage(
|
||||
type="long_term",
|
||||
crew=None,
|
||||
path=custom_path,
|
||||
embedder_config=embedder_config,
|
||||
)
|
||||
|
||||
mock_create_client.assert_called_once()
|
||||
config_arg = mock_create_client.call_args[0][0]
|
||||
assert config_arg.settings.persist_directory == custom_path
|
||||
assert config_arg.batch_size == 100
|
||||
@@ -4772,93 +4772,3 @@ def test_ensure_exchanged_messages_are_propagated_to_external_memory():
|
||||
assert "Researcher" in messages[0]["content"]
|
||||
assert messages[1]["role"] == "user"
|
||||
assert "Research a topic to teach a kid aged 6 about math" in messages[1]["content"]
|
||||
|
||||
|
||||
def test_crew_planning_with_mismatched_task_order():
|
||||
"""Test that crew planning correctly matches plans to tasks even when LLM returns them out of order.
|
||||
|
||||
This test reproduces the bug reported in issue #3953 where the task planner
|
||||
returns plans in the wrong order (e.g., starting with Task 21 instead of Task 1),
|
||||
causing plans to be attached to the wrong tasks.
|
||||
"""
|
||||
from crewai.utilities.planning_handler import PlanPerTask, PlannerTaskPydanticOutput
|
||||
|
||||
# Create 5 tasks with distinct descriptions
|
||||
tasks = []
|
||||
agents = []
|
||||
for i in range(1, 6):
|
||||
agent = Agent(
|
||||
role=f"Agent {i}",
|
||||
goal=f"Goal {i}",
|
||||
backstory=f"Backstory {i}",
|
||||
)
|
||||
agents.append(agent)
|
||||
task = Task(
|
||||
description=f"Task {i} description",
|
||||
expected_output=f"Output {i}",
|
||||
agent=agent,
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
crew = Crew(
|
||||
agents=agents,
|
||||
tasks=tasks,
|
||||
planning=True,
|
||||
planning_llm="gpt-4o-mini",
|
||||
)
|
||||
|
||||
# Mock the LLM response to return plans in the WRONG order
|
||||
# Simulating the bug where Task 5 plan comes first, then Task 3, etc.
|
||||
wrong_order_plans = [
|
||||
PlanPerTask(
|
||||
task="Task Number 5 - Task 5 description",
|
||||
plan="\n\nPlan for task 5"
|
||||
),
|
||||
PlanPerTask(
|
||||
task="Task Number 3 - Task 3 description",
|
||||
plan="\n\nPlan for task 3"
|
||||
),
|
||||
PlanPerTask(
|
||||
task="Task Number 1 - Task 1 description",
|
||||
plan="\n\nPlan for task 1"
|
||||
),
|
||||
PlanPerTask(
|
||||
task="Task Number 4 - Task 4 description",
|
||||
plan="\n\nPlan for task 4"
|
||||
),
|
||||
PlanPerTask(
|
||||
task="Task Number 2 - Task 2 description",
|
||||
plan="\n\nPlan for task 2"
|
||||
),
|
||||
]
|
||||
|
||||
with patch.object(Task, "execute_sync") as mock_execute:
|
||||
mock_execute.return_value = TaskOutput(
|
||||
description="Planning task",
|
||||
agent="planner",
|
||||
pydantic=PlannerTaskPydanticOutput(
|
||||
list_of_plans_per_task=wrong_order_plans
|
||||
),
|
||||
)
|
||||
|
||||
# Call the planning method
|
||||
crew._handle_crew_planning()
|
||||
|
||||
# Verify that each task has the CORRECT plan appended to its description
|
||||
# Task 1 should have "Plan for task 1", not "Plan for task 5"
|
||||
assert "Plan for task 1" in crew.tasks[0].description, \
|
||||
f"Task 1 should have 'Plan for task 1' but got: {crew.tasks[0].description}"
|
||||
assert "Plan for task 2" in crew.tasks[1].description, \
|
||||
f"Task 2 should have 'Plan for task 2' but got: {crew.tasks[1].description}"
|
||||
assert "Plan for task 3" in crew.tasks[2].description, \
|
||||
f"Task 3 should have 'Plan for task 3' but got: {crew.tasks[2].description}"
|
||||
assert "Plan for task 4" in crew.tasks[3].description, \
|
||||
f"Task 4 should have 'Plan for task 4' but got: {crew.tasks[3].description}"
|
||||
assert "Plan for task 5" in crew.tasks[4].description, \
|
||||
f"Task 5 should have 'Plan for task 5' but got: {crew.tasks[4].description}"
|
||||
|
||||
# Also verify that wrong plans are NOT in the wrong tasks
|
||||
assert "Plan for task 5" not in crew.tasks[0].description, \
|
||||
"Task 1 should not have Plan for task 5"
|
||||
assert "Plan for task 3" not in crew.tasks[1].description, \
|
||||
"Task 2 should not have Plan for task 3"
|
||||
|
||||
@@ -723,11 +723,11 @@ def test_structured_flow_event_emission():
|
||||
assert isinstance(received_events[3], MethodExecutionStartedEvent)
|
||||
assert received_events[3].method_name == "send_welcome_message"
|
||||
assert received_events[3].params == {}
|
||||
assert received_events[3].state["sent"] is False
|
||||
assert received_events[3].state.sent is False
|
||||
|
||||
assert isinstance(received_events[4], MethodExecutionFinishedEvent)
|
||||
assert received_events[4].method_name == "send_welcome_message"
|
||||
assert received_events[4].state["sent"] is True
|
||||
assert received_events[4].state.sent is True
|
||||
assert received_events[4].result == "Welcome, Anakin!"
|
||||
|
||||
assert isinstance(received_events[5], FlowFinishedEvent)
|
||||
|
||||
@@ -415,256 +415,4 @@ def test_router_paths_not_in_and_conditions():
|
||||
|
||||
assert "step_1" in targets
|
||||
assert "step_3_or" in targets
|
||||
assert "step_2_and" not in targets
|
||||
|
||||
|
||||
def test_chained_routers_no_self_loops():
|
||||
"""Test that chained routers don't create self-referencing edges.
|
||||
|
||||
This tests the bug where routers with string triggers (like 'auth', 'exp')
|
||||
would incorrectly create edges to themselves when another router outputs
|
||||
those strings.
|
||||
"""
|
||||
|
||||
class ChainedRouterFlow(Flow):
|
||||
"""Flow with multiple chained routers using string outputs."""
|
||||
|
||||
@start()
|
||||
def entrance(self):
|
||||
return "started"
|
||||
|
||||
@router(entrance)
|
||||
def session_in_cache(self):
|
||||
return "exp"
|
||||
|
||||
@router("exp")
|
||||
def check_exp(self):
|
||||
return "auth"
|
||||
|
||||
@router("auth")
|
||||
def call_ai_auth(self):
|
||||
return "action"
|
||||
|
||||
@listen("action")
|
||||
def forward_to_action(self):
|
||||
return "done"
|
||||
|
||||
@listen("authenticate")
|
||||
def forward_to_authenticate(self):
|
||||
return "need_auth"
|
||||
|
||||
flow = ChainedRouterFlow()
|
||||
structure = build_flow_structure(flow)
|
||||
|
||||
# Check that no self-loops exist
|
||||
for edge in structure["edges"]:
|
||||
assert edge["source"] != edge["target"], (
|
||||
f"Self-loop detected: {edge['source']} -> {edge['target']}"
|
||||
)
|
||||
|
||||
# Verify correct connections
|
||||
router_edges = [edge for edge in structure["edges"] if edge["is_router_path"]]
|
||||
|
||||
# session_in_cache -> check_exp (via 'exp')
|
||||
exp_edges = [
|
||||
edge
|
||||
for edge in router_edges
|
||||
if edge["router_path_label"] == "exp" and edge["source"] == "session_in_cache"
|
||||
]
|
||||
assert len(exp_edges) == 1
|
||||
assert exp_edges[0]["target"] == "check_exp"
|
||||
|
||||
# check_exp -> call_ai_auth (via 'auth')
|
||||
auth_edges = [
|
||||
edge
|
||||
for edge in router_edges
|
||||
if edge["router_path_label"] == "auth" and edge["source"] == "check_exp"
|
||||
]
|
||||
assert len(auth_edges) == 1
|
||||
assert auth_edges[0]["target"] == "call_ai_auth"
|
||||
|
||||
# call_ai_auth -> forward_to_action (via 'action')
|
||||
action_edges = [
|
||||
edge
|
||||
for edge in router_edges
|
||||
if edge["router_path_label"] == "action" and edge["source"] == "call_ai_auth"
|
||||
]
|
||||
assert len(action_edges) == 1
|
||||
assert action_edges[0]["target"] == "forward_to_action"
|
||||
|
||||
|
||||
def test_routers_with_shared_output_strings():
|
||||
"""Test that routers with shared output strings don't create incorrect edges.
|
||||
|
||||
This tests a scenario where multiple routers can output the same string,
|
||||
ensuring the visualization only creates edges for the router that actually
|
||||
outputs the string, not all routers.
|
||||
"""
|
||||
|
||||
class SharedOutputRouterFlow(Flow):
|
||||
"""Flow where multiple routers can output 'auth'."""
|
||||
|
||||
@start()
|
||||
def start(self):
|
||||
return "started"
|
||||
|
||||
@router(start)
|
||||
def router_a(self):
|
||||
# This router can output 'auth' or 'skip'
|
||||
return "auth"
|
||||
|
||||
@router("auth")
|
||||
def router_b(self):
|
||||
# This router listens to 'auth' but outputs 'done'
|
||||
return "done"
|
||||
|
||||
@listen("done")
|
||||
def finalize(self):
|
||||
return "complete"
|
||||
|
||||
@listen("skip")
|
||||
def handle_skip(self):
|
||||
return "skipped"
|
||||
|
||||
flow = SharedOutputRouterFlow()
|
||||
structure = build_flow_structure(flow)
|
||||
|
||||
# Check no self-loops
|
||||
for edge in structure["edges"]:
|
||||
assert edge["source"] != edge["target"], (
|
||||
f"Self-loop detected: {edge['source']} -> {edge['target']}"
|
||||
)
|
||||
|
||||
# router_a should connect to router_b via 'auth'
|
||||
router_edges = [edge for edge in structure["edges"] if edge["is_router_path"]]
|
||||
auth_from_a = [
|
||||
edge
|
||||
for edge in router_edges
|
||||
if edge["source"] == "router_a" and edge["router_path_label"] == "auth"
|
||||
]
|
||||
assert len(auth_from_a) == 1
|
||||
assert auth_from_a[0]["target"] == "router_b"
|
||||
|
||||
# router_b should connect to finalize via 'done'
|
||||
done_from_b = [
|
||||
edge
|
||||
for edge in router_edges
|
||||
if edge["source"] == "router_b" and edge["router_path_label"] == "done"
|
||||
]
|
||||
assert len(done_from_b) == 1
|
||||
assert done_from_b[0]["target"] == "finalize"
|
||||
|
||||
|
||||
def test_warning_for_router_without_paths(caplog):
|
||||
"""Test that a warning is logged when a router has no determinable paths."""
|
||||
import logging
|
||||
|
||||
class RouterWithoutPathsFlow(Flow):
|
||||
"""Flow with a router that returns a dynamic value."""
|
||||
|
||||
@start()
|
||||
def begin(self):
|
||||
return "started"
|
||||
|
||||
@router(begin)
|
||||
def dynamic_router(self):
|
||||
# Returns a variable that can't be statically analyzed
|
||||
import random
|
||||
return random.choice(["path_a", "path_b"])
|
||||
|
||||
@listen("path_a")
|
||||
def handle_a(self):
|
||||
return "a"
|
||||
|
||||
@listen("path_b")
|
||||
def handle_b(self):
|
||||
return "b"
|
||||
|
||||
flow = RouterWithoutPathsFlow()
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
build_flow_structure(flow)
|
||||
|
||||
# Check that warning was logged for the router
|
||||
assert any(
|
||||
"Could not determine return paths for router 'dynamic_router'" in record.message
|
||||
for record in caplog.records
|
||||
)
|
||||
|
||||
# Check that error was logged for orphaned triggers
|
||||
assert any(
|
||||
"Found listeners waiting for triggers" in record.message
|
||||
for record in caplog.records
|
||||
)
|
||||
|
||||
|
||||
def test_warning_for_orphaned_listeners(caplog):
|
||||
"""Test that an error is logged when listeners wait for triggers no router outputs."""
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
class OrphanedListenerFlow(Flow):
|
||||
"""Flow where a listener waits for a trigger that no router outputs."""
|
||||
|
||||
@start()
|
||||
def begin(self):
|
||||
return "started"
|
||||
|
||||
@router(begin)
|
||||
def my_router(self) -> Literal["option_a", "option_b"]:
|
||||
return "option_a"
|
||||
|
||||
@listen("option_a")
|
||||
def handle_a(self):
|
||||
return "a"
|
||||
|
||||
@listen("option_c") # This trigger is never output by any router
|
||||
def handle_orphan(self):
|
||||
return "orphan"
|
||||
|
||||
flow = OrphanedListenerFlow()
|
||||
|
||||
with caplog.at_level(logging.ERROR):
|
||||
build_flow_structure(flow)
|
||||
|
||||
# Check that error was logged for orphaned trigger
|
||||
assert any(
|
||||
"Found listeners waiting for triggers" in record.message
|
||||
and "option_c" in record.message
|
||||
for record in caplog.records
|
||||
)
|
||||
|
||||
|
||||
def test_no_warning_for_properly_typed_router(caplog):
|
||||
"""Test that no warning is logged when router has proper type annotations."""
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
class ProperlyTypedRouterFlow(Flow):
|
||||
"""Flow with properly typed router."""
|
||||
|
||||
@start()
|
||||
def begin(self):
|
||||
return "started"
|
||||
|
||||
@router(begin)
|
||||
def typed_router(self) -> Literal["path_a", "path_b"]:
|
||||
return "path_a"
|
||||
|
||||
@listen("path_a")
|
||||
def handle_a(self):
|
||||
return "a"
|
||||
|
||||
@listen("path_b")
|
||||
def handle_b(self):
|
||||
return "b"
|
||||
|
||||
flow = ProperlyTypedRouterFlow()
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
build_flow_structure(flow)
|
||||
|
||||
# No warnings should be logged
|
||||
warning_messages = [r.message for r in caplog.records if r.levelno >= logging.WARNING]
|
||||
assert not any("Could not determine return paths" in msg for msg in warning_messages)
|
||||
assert not any("Found listeners waiting for triggers" in msg for msg in warning_messages)
|
||||
assert "step_2_and" not in targets
|
||||
@@ -243,11 +243,7 @@ def test_validate_call_params_not_supported():
|
||||
|
||||
# Patch supports_response_schema to simulate an unsupported model.
|
||||
with patch("crewai.llm.supports_response_schema", return_value=False):
|
||||
llm = LLM(
|
||||
model="gemini/gemini-1.5-pro",
|
||||
response_format=DummyResponse,
|
||||
is_litellm=True,
|
||||
)
|
||||
llm = LLM(model="gemini/gemini-1.5-pro", response_format=DummyResponse, is_litellm=True)
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
llm._validate_call_params()
|
||||
assert "does not support response_format" in str(excinfo.value)
|
||||
@@ -706,16 +702,13 @@ def test_ollama_does_not_modify_when_last_is_user(ollama_llm):
|
||||
|
||||
assert formatted == original_messages
|
||||
|
||||
|
||||
def test_native_provider_raises_error_when_supported_but_fails():
|
||||
"""Test that when a native provider is in SUPPORTED_NATIVE_PROVIDERS but fails to instantiate, we raise the error."""
|
||||
with patch("crewai.llm.SUPPORTED_NATIVE_PROVIDERS", ["openai"]):
|
||||
with patch("crewai.llm.LLM._get_native_provider") as mock_get_native:
|
||||
# Mock that provider exists but throws an error when instantiated
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.side_effect = ValueError(
|
||||
"Native provider initialization failed"
|
||||
)
|
||||
mock_provider.side_effect = ValueError("Native provider initialization failed")
|
||||
mock_get_native.return_value = mock_provider
|
||||
|
||||
with pytest.raises(ImportError) as excinfo:
|
||||
@@ -758,16 +751,16 @@ def test_prefixed_models_with_valid_constants_use_native_sdk():
|
||||
|
||||
|
||||
def test_prefixed_models_with_invalid_constants_use_litellm():
|
||||
"""Test that models with native provider prefixes use LiteLLM when model is NOT in constants and does NOT match patterns."""
|
||||
"""Test that models with native provider prefixes use LiteLLM when model is NOT in constants."""
|
||||
# Test openai/ prefix with non-OpenAI model (not in OPENAI_MODELS) → LiteLLM
|
||||
llm = LLM(model="openai/gemini-2.5-flash", is_litellm=False)
|
||||
assert llm.is_litellm is True
|
||||
assert llm.model == "openai/gemini-2.5-flash"
|
||||
|
||||
# Test openai/ prefix with model that doesn't match patterns (e.g. no gpt- prefix) → LiteLLM
|
||||
llm2 = LLM(model="openai/custom-finetune-model", is_litellm=False)
|
||||
# Test openai/ prefix with unknown future model → LiteLLM
|
||||
llm2 = LLM(model="openai/gpt-future-6", is_litellm=False)
|
||||
assert llm2.is_litellm is True
|
||||
assert llm2.model == "openai/custom-finetune-model"
|
||||
assert llm2.model == "openai/gpt-future-6"
|
||||
|
||||
# Test anthropic/ prefix with non-Anthropic model → LiteLLM
|
||||
llm3 = LLM(model="anthropic/gpt-4o", is_litellm=False)
|
||||
@@ -775,21 +768,6 @@ def test_prefixed_models_with_invalid_constants_use_litellm():
|
||||
assert llm3.model == "anthropic/gpt-4o"
|
||||
|
||||
|
||||
def test_prefixed_models_with_valid_patterns_use_native_sdk():
|
||||
"""Test that models matching provider patterns use native SDK even if not in constants."""
|
||||
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
|
||||
llm = LLM(model="openai/gpt-future-6", is_litellm=False)
|
||||
assert llm.is_litellm is False
|
||||
assert llm.provider == "openai"
|
||||
assert llm.model == "gpt-future-6"
|
||||
|
||||
with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
|
||||
llm2 = LLM(model="anthropic/claude-future-5", is_litellm=False)
|
||||
assert llm2.is_litellm is False
|
||||
assert llm2.provider == "anthropic"
|
||||
assert llm2.model == "claude-future-5"
|
||||
|
||||
|
||||
def test_prefixed_models_with_non_native_providers_use_litellm():
|
||||
"""Test that models with non-native provider prefixes always use LiteLLM."""
|
||||
# Test groq/ prefix (not a native provider) → LiteLLM
|
||||
@@ -843,36 +821,19 @@ def test_validate_model_in_constants():
|
||||
"""Test the _validate_model_in_constants method."""
|
||||
# OpenAI models
|
||||
assert LLM._validate_model_in_constants("gpt-4o", "openai") is True
|
||||
assert LLM._validate_model_in_constants("gpt-future-6", "openai") is True
|
||||
assert LLM._validate_model_in_constants("o1-latest", "openai") is True
|
||||
assert LLM._validate_model_in_constants("unknown-model", "openai") is False
|
||||
assert LLM._validate_model_in_constants("gpt-future-6", "openai") is False
|
||||
|
||||
# Anthropic models
|
||||
assert LLM._validate_model_in_constants("claude-opus-4-0", "claude") is True
|
||||
assert LLM._validate_model_in_constants("claude-future-5", "claude") is True
|
||||
assert (
|
||||
LLM._validate_model_in_constants("claude-3-5-sonnet-latest", "claude") is True
|
||||
)
|
||||
assert LLM._validate_model_in_constants("unknown-model", "claude") is False
|
||||
assert LLM._validate_model_in_constants("claude-future-5", "claude") is False
|
||||
|
||||
# Gemini models
|
||||
assert LLM._validate_model_in_constants("gemini-2.5-pro", "gemini") is True
|
||||
assert LLM._validate_model_in_constants("gemini-future", "gemini") is True
|
||||
assert LLM._validate_model_in_constants("gemma-3-latest", "gemini") is True
|
||||
assert LLM._validate_model_in_constants("unknown-model", "gemini") is False
|
||||
assert LLM._validate_model_in_constants("gemini-future", "gemini") is False
|
||||
|
||||
# Azure models
|
||||
assert LLM._validate_model_in_constants("gpt-4o", "azure") is True
|
||||
assert LLM._validate_model_in_constants("gpt-35-turbo", "azure") is True
|
||||
|
||||
# Bedrock models
|
||||
assert (
|
||||
LLM._validate_model_in_constants(
|
||||
"anthropic.claude-opus-4-1-20250805-v1:0", "bedrock"
|
||||
)
|
||||
is True
|
||||
)
|
||||
assert (
|
||||
LLM._validate_model_in_constants("anthropic.claude-future-v1:0", "bedrock")
|
||||
is True
|
||||
)
|
||||
assert LLM._validate_model_in_constants("anthropic.claude-opus-4-1-20250805-v1:0", "bedrock") is True
|
||||
|
||||
@@ -272,99 +272,6 @@ def another_simple_tool():
|
||||
return "Hi!"
|
||||
|
||||
|
||||
class TestAsyncDecoratorSupport:
|
||||
"""Tests for async method support in @agent, @task decorators."""
|
||||
|
||||
def test_async_agent_memoization(self):
|
||||
"""Async agent methods should be properly memoized."""
|
||||
|
||||
class AsyncAgentCrew:
|
||||
call_count = 0
|
||||
|
||||
@agent
|
||||
async def async_agent(self):
|
||||
AsyncAgentCrew.call_count += 1
|
||||
return Agent(
|
||||
role="Async Agent", goal="Async Goal", backstory="Async Backstory"
|
||||
)
|
||||
|
||||
crew = AsyncAgentCrew()
|
||||
first_call = crew.async_agent()
|
||||
second_call = crew.async_agent()
|
||||
|
||||
assert first_call is second_call, "Async agent memoization failed"
|
||||
assert AsyncAgentCrew.call_count == 1, "Async agent called more than once"
|
||||
|
||||
def test_async_task_memoization(self):
|
||||
"""Async task methods should be properly memoized."""
|
||||
|
||||
class AsyncTaskCrew:
|
||||
call_count = 0
|
||||
|
||||
@task
|
||||
async def async_task(self):
|
||||
AsyncTaskCrew.call_count += 1
|
||||
return Task(
|
||||
description="Async Description", expected_output="Async Output"
|
||||
)
|
||||
|
||||
crew = AsyncTaskCrew()
|
||||
first_call = crew.async_task()
|
||||
second_call = crew.async_task()
|
||||
|
||||
assert first_call is second_call, "Async task memoization failed"
|
||||
assert AsyncTaskCrew.call_count == 1, "Async task called more than once"
|
||||
|
||||
def test_async_task_name_inference(self):
|
||||
"""Async task should have name inferred from method name."""
|
||||
|
||||
class AsyncTaskNameCrew:
|
||||
@task
|
||||
async def my_async_task(self):
|
||||
return Task(
|
||||
description="Async Description", expected_output="Async Output"
|
||||
)
|
||||
|
||||
crew = AsyncTaskNameCrew()
|
||||
task_instance = crew.my_async_task()
|
||||
|
||||
assert task_instance.name == "my_async_task", (
|
||||
"Async task name not inferred correctly"
|
||||
)
|
||||
|
||||
def test_async_agent_returns_agent_not_coroutine(self):
|
||||
"""Async agent decorator should return Agent, not coroutine."""
|
||||
|
||||
class AsyncAgentTypeCrew:
|
||||
@agent
|
||||
async def typed_async_agent(self):
|
||||
return Agent(
|
||||
role="Typed Agent", goal="Typed Goal", backstory="Typed Backstory"
|
||||
)
|
||||
|
||||
crew = AsyncAgentTypeCrew()
|
||||
result = crew.typed_async_agent()
|
||||
|
||||
assert isinstance(result, Agent), (
|
||||
f"Expected Agent, got {type(result).__name__}"
|
||||
)
|
||||
|
||||
def test_async_task_returns_task_not_coroutine(self):
|
||||
"""Async task decorator should return Task, not coroutine."""
|
||||
|
||||
class AsyncTaskTypeCrew:
|
||||
@task
|
||||
async def typed_async_task(self):
|
||||
return Task(
|
||||
description="Typed Description", expected_output="Typed Output"
|
||||
)
|
||||
|
||||
crew = AsyncTaskTypeCrew()
|
||||
result = crew.typed_async_task()
|
||||
|
||||
assert isinstance(result, Task), f"Expected Task, got {type(result).__name__}"
|
||||
|
||||
|
||||
def test_internal_crew_with_mcp():
|
||||
from crewai_tools.adapters.tool_collection import ToolCollection
|
||||
|
||||
|
||||
@@ -26,7 +26,6 @@ from crewai.events.types.flow_events import (
|
||||
FlowFinishedEvent,
|
||||
FlowStartedEvent,
|
||||
MethodExecutionFailedEvent,
|
||||
MethodExecutionFinishedEvent,
|
||||
MethodExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.types.llm_events import (
|
||||
@@ -48,7 +47,7 @@ from crewai.flow.flow import Flow, listen, start
|
||||
from crewai.llm import LLM
|
||||
from crewai.task import Task
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import Field
|
||||
import pytest
|
||||
|
||||
from ..utils import wait_for_event_handlers
|
||||
@@ -704,156 +703,6 @@ def test_flow_emits_method_execution_failed_event():
|
||||
assert received_events[0].error == error
|
||||
|
||||
|
||||
def test_flow_method_execution_started_includes_unstructured_state():
|
||||
"""Test that MethodExecutionStartedEvent includes unstructured (dict) state."""
|
||||
received_events = []
|
||||
event_received = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionStartedEvent)
|
||||
def handle_method_started(source, event):
|
||||
received_events.append(event)
|
||||
if event.method_name == "process":
|
||||
event_received.set()
|
||||
|
||||
class TestFlow(Flow[dict]):
|
||||
@start()
|
||||
def begin(self):
|
||||
self.state["counter"] = 1
|
||||
self.state["message"] = "test"
|
||||
return "started"
|
||||
|
||||
@listen("begin")
|
||||
def process(self):
|
||||
self.state["counter"] = 2
|
||||
return "processed"
|
||||
|
||||
flow = TestFlow()
|
||||
flow.kickoff()
|
||||
|
||||
assert event_received.wait(timeout=5), (
|
||||
"Timeout waiting for method execution started event"
|
||||
)
|
||||
|
||||
# Find the events for each method
|
||||
begin_event = next(e for e in received_events if e.method_name == "begin")
|
||||
process_event = next(e for e in received_events if e.method_name == "process")
|
||||
|
||||
# Verify state is included and is a dict
|
||||
assert begin_event.state is not None
|
||||
assert isinstance(begin_event.state, dict)
|
||||
assert "id" in begin_event.state # Auto-generated ID
|
||||
|
||||
# Verify state from begin method is captured in process event
|
||||
assert process_event.state is not None
|
||||
assert isinstance(process_event.state, dict)
|
||||
assert process_event.state["counter"] == 1
|
||||
assert process_event.state["message"] == "test"
|
||||
|
||||
|
||||
def test_flow_method_execution_started_includes_structured_state():
|
||||
"""Test that MethodExecutionStartedEvent includes structured (BaseModel) state and serializes it properly."""
|
||||
received_events = []
|
||||
event_received = threading.Event()
|
||||
|
||||
class FlowState(BaseModel):
|
||||
counter: int = 0
|
||||
message: str = ""
|
||||
items: list[str] = []
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionStartedEvent)
|
||||
def handle_method_started(source, event):
|
||||
received_events.append(event)
|
||||
if event.method_name == "process":
|
||||
event_received.set()
|
||||
|
||||
class TestFlow(Flow[FlowState]):
|
||||
@start()
|
||||
def begin(self):
|
||||
self.state.counter = 1
|
||||
self.state.message = "initial"
|
||||
self.state.items = ["a", "b"]
|
||||
return "started"
|
||||
|
||||
@listen("begin")
|
||||
def process(self):
|
||||
self.state.counter += 1
|
||||
return "processed"
|
||||
|
||||
flow = TestFlow()
|
||||
flow.kickoff()
|
||||
|
||||
assert event_received.wait(timeout=5), (
|
||||
"Timeout waiting for method execution started event"
|
||||
)
|
||||
|
||||
begin_event = next(e for e in received_events if e.method_name == "begin")
|
||||
process_event = next(e for e in received_events if e.method_name == "process")
|
||||
|
||||
assert begin_event.state is not None
|
||||
assert isinstance(begin_event.state, dict)
|
||||
assert begin_event.state["counter"] == 0 # Initial state
|
||||
assert begin_event.state["message"] == ""
|
||||
assert begin_event.state["items"] == []
|
||||
|
||||
assert process_event.state is not None
|
||||
assert isinstance(process_event.state, dict)
|
||||
assert process_event.state["counter"] == 1
|
||||
assert process_event.state["message"] == "initial"
|
||||
assert process_event.state["items"] == ["a", "b"]
|
||||
|
||||
|
||||
def test_flow_method_execution_finished_includes_serialized_state():
|
||||
"""Test that MethodExecutionFinishedEvent includes properly serialized state."""
|
||||
received_events = []
|
||||
event_received = threading.Event()
|
||||
|
||||
class FlowState(BaseModel):
|
||||
result: str = ""
|
||||
completed: bool = False
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionFinishedEvent)
|
||||
def handle_method_finished(source, event):
|
||||
received_events.append(event)
|
||||
if event.method_name == "process":
|
||||
event_received.set()
|
||||
|
||||
class TestFlow(Flow[FlowState]):
|
||||
@start()
|
||||
def begin(self):
|
||||
self.state.result = "begin done"
|
||||
return "started"
|
||||
|
||||
@listen("begin")
|
||||
def process(self):
|
||||
self.state.result = "process done"
|
||||
self.state.completed = True
|
||||
return "final_result"
|
||||
|
||||
flow = TestFlow()
|
||||
final_output = flow.kickoff()
|
||||
|
||||
assert event_received.wait(timeout=5), (
|
||||
"Timeout waiting for method execution finished event"
|
||||
)
|
||||
|
||||
begin_finished = next(e for e in received_events if e.method_name == "begin")
|
||||
process_finished = next(e for e in received_events if e.method_name == "process")
|
||||
|
||||
assert begin_finished.state is not None
|
||||
assert isinstance(begin_finished.state, dict)
|
||||
assert begin_finished.state["result"] == "begin done"
|
||||
assert begin_finished.state["completed"] is False
|
||||
assert begin_finished.result == "started"
|
||||
|
||||
# Verify process finished event has final state and result
|
||||
assert process_finished.state is not None
|
||||
assert isinstance(process_finished.state, dict)
|
||||
assert process_finished.state["result"] == "process done"
|
||||
assert process_finished.state["completed"] is True
|
||||
assert process_finished.result == "final_result"
|
||||
assert final_output == "final_result"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_llm_emits_call_started_event():
|
||||
received_events = []
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
"""CrewAI development tools."""
|
||||
|
||||
__version__ = "1.6.1"
|
||||
__version__ = "1.5.0"
|
||||
|
||||
26
uv.lock
generated
26
uv.lock
generated
@@ -1225,7 +1225,7 @@ dependencies = [
|
||||
{ name = "crewai" },
|
||||
{ name = "docker" },
|
||||
{ name = "lancedb" },
|
||||
{ name = "pymupdf" },
|
||||
{ name = "pypdf" },
|
||||
{ name = "python-docx" },
|
||||
{ name = "pytube" },
|
||||
{ name = "requests" },
|
||||
@@ -1382,8 +1382,8 @@ requires-dist = [
|
||||
{ name = "psycopg2-binary", marker = "extra == 'postgresql'", specifier = ">=2.9.10" },
|
||||
{ name = "pygithub", marker = "extra == 'github'", specifier = "==1.59.1" },
|
||||
{ name = "pymongo", marker = "extra == 'mongodb'", specifier = ">=4.13" },
|
||||
{ name = "pymupdf", specifier = ">=1.26.6" },
|
||||
{ name = "pymysql", marker = "extra == 'mysql'", specifier = ">=1.1.1" },
|
||||
{ name = "pypdf", specifier = ">=5.9.0" },
|
||||
{ name = "python-docx", specifier = ">=1.2.0" },
|
||||
{ name = "python-docx", marker = "extra == 'rag'", specifier = ">=1.1.0" },
|
||||
{ name = "pytube", specifier = ">=15.0.0" },
|
||||
@@ -2224,8 +2224,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/7f/91/ae2eb6b7979e2f9b035a9f612cf70f1bf54aad4e1d125129bef1eae96f19/greenlet-3.2.4-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c2ca18a03a8cfb5b25bc1cbe20f3d9a4c80d8c3b13ba3df49ac3961af0b1018d", size = 584358, upload-time = "2025-08-07T13:18:23.708Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f7/85/433de0c9c0252b22b16d413c9407e6cb3b41df7389afc366ca204dbc1393/greenlet-3.2.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9fe0a28a7b952a21e2c062cd5756d34354117796c6d9215a87f55e38d15402c5", size = 1113550, upload-time = "2025-08-07T13:42:37.467Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a1/8d/88f3ebd2bc96bf7747093696f4335a0a8a4c5acfcf1b757717c0d2474ba3/greenlet-3.2.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8854167e06950ca75b898b104b63cc646573aa5fef1353d4508ecdd1ee76254f", size = 1137126, upload-time = "2025-08-07T13:18:20.239Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f1/29/74242b7d72385e29bcc5563fba67dad94943d7cd03552bac320d597f29b2/greenlet-3.2.4-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f47617f698838ba98f4ff4189aef02e7343952df3a615f847bb575c3feb177a7", size = 1544904, upload-time = "2025-11-04T12:42:04.763Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c8/e2/1572b8eeab0f77df5f6729d6ab6b141e4a84ee8eb9bc8c1e7918f94eda6d/greenlet-3.2.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:af41be48a4f60429d5cad9d22175217805098a9ef7c40bfef44f7669fb9d74d8", size = 1611228, upload-time = "2025-11-04T12:42:08.423Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d6/6f/b60b0291d9623c496638c582297ead61f43c4b72eef5e9c926ef4565ec13/greenlet-3.2.4-cp310-cp310-win_amd64.whl", hash = "sha256:73f49b5368b5359d04e18d15828eecc1806033db5233397748f4ca813ff1056c", size = 298654, upload-time = "2025-08-07T13:50:00.469Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a4/de/f28ced0a67749cac23fecb02b694f6473f47686dff6afaa211d186e2ef9c/greenlet-3.2.4-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:96378df1de302bc38e99c3a9aa311967b7dc80ced1dcc6f171e99842987882a2", size = 272305, upload-time = "2025-08-07T13:15:41.288Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/09/16/2c3792cba130000bf2a31c5272999113f4764fd9d874fb257ff588ac779a/greenlet-3.2.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:1ee8fae0519a337f2329cb78bd7a8e128ec0f881073d43f023c7b8d4831d5246", size = 632472, upload-time = "2025-08-07T13:42:55.044Z" },
|
||||
@@ -2235,8 +2233,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/1f/8e/abdd3f14d735b2929290a018ecf133c901be4874b858dd1c604b9319f064/greenlet-3.2.4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2523e5246274f54fdadbce8494458a2ebdcdbc7b802318466ac5606d3cded1f8", size = 587684, upload-time = "2025-08-07T13:18:25.164Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5d/65/deb2a69c3e5996439b0176f6651e0052542bb6c8f8ec2e3fba97c9768805/greenlet-3.2.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1987de92fec508535687fb807a5cea1560f6196285a4cde35c100b8cd632cc52", size = 1116647, upload-time = "2025-08-07T13:42:38.655Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3f/cc/b07000438a29ac5cfb2194bfc128151d52f333cee74dd7dfe3fb733fc16c/greenlet-3.2.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:55e9c5affaa6775e2c6b67659f3a71684de4c549b3dd9afca3bc773533d284fa", size = 1142073, upload-time = "2025-08-07T13:18:21.737Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/67/24/28a5b2fa42d12b3d7e5614145f0bd89714c34c08be6aabe39c14dd52db34/greenlet-3.2.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c9c6de1940a7d828635fbd254d69db79e54619f165ee7ce32fda763a9cb6a58c", size = 1548385, upload-time = "2025-11-04T12:42:11.067Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6a/05/03f2f0bdd0b0ff9a4f7b99333d57b53a7709c27723ec8123056b084e69cd/greenlet-3.2.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:03c5136e7be905045160b1b9fdca93dd6727b180feeafda6818e6496434ed8c5", size = 1613329, upload-time = "2025-11-04T12:42:12.928Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d8/0f/30aef242fcab550b0b3520b8e3561156857c94288f0332a79928c31a52cf/greenlet-3.2.4-cp311-cp311-win_amd64.whl", hash = "sha256:9c40adce87eaa9ddb593ccb0fa6a07caf34015a29bf8d344811665b573138db9", size = 299100, upload-time = "2025-08-07T13:44:12.287Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/44/69/9b804adb5fd0671f367781560eb5eb586c4d495277c93bde4307b9e28068/greenlet-3.2.4-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:3b67ca49f54cede0186854a008109d6ee71f66bd57bb36abd6d0a0267b540cdd", size = 274079, upload-time = "2025-08-07T13:15:45.033Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/46/e9/d2a80c99f19a153eff70bc451ab78615583b8dac0754cfb942223d2c1a0d/greenlet-3.2.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ddf9164e7a5b08e9d22511526865780a576f19ddd00d62f8a665949327fde8bb", size = 640997, upload-time = "2025-08-07T13:42:56.234Z" },
|
||||
@@ -2246,8 +2242,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/19/0d/6660d55f7373b2ff8152401a83e02084956da23ae58cddbfb0b330978fe9/greenlet-3.2.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3b3812d8d0c9579967815af437d96623f45c0f2ae5f04e366de62a12d83a8fb0", size = 607586, upload-time = "2025-08-07T13:18:28.544Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8e/1a/c953fdedd22d81ee4629afbb38d2f9d71e37d23caace44775a3a969147d4/greenlet-3.2.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:abbf57b5a870d30c4675928c37278493044d7c14378350b3aa5d484fa65575f0", size = 1123281, upload-time = "2025-08-07T13:42:39.858Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3f/c7/12381b18e21aef2c6bd3a636da1088b888b97b7a0362fac2e4de92405f97/greenlet-3.2.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:20fb936b4652b6e307b8f347665e2c615540d4b42b3b4c8a321d8286da7e520f", size = 1151142, upload-time = "2025-08-07T13:18:22.981Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/27/45/80935968b53cfd3f33cf99ea5f08227f2646e044568c9b1555b58ffd61c2/greenlet-3.2.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ee7a6ec486883397d70eec05059353b8e83eca9168b9f3f9a361971e77e0bcd0", size = 1564846, upload-time = "2025-11-04T12:42:15.191Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/69/02/b7c30e5e04752cb4db6202a3858b149c0710e5453b71a3b2aec5d78a1aab/greenlet-3.2.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:326d234cbf337c9c3def0676412eb7040a35a768efc92504b947b3e9cfc7543d", size = 1633814, upload-time = "2025-11-04T12:42:17.175Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e9/08/b0814846b79399e585f974bbeebf5580fbe59e258ea7be64d9dfb253c84f/greenlet-3.2.4-cp312-cp312-win_amd64.whl", hash = "sha256:a7d4e128405eea3814a12cc2605e0e6aedb4035bf32697f72deca74de4105e02", size = 299899, upload-time = "2025-08-07T13:38:53.448Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/49/e8/58c7f85958bda41dafea50497cbd59738c5c43dbbea5ee83d651234398f4/greenlet-3.2.4-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:1a921e542453fe531144e91e1feedf12e07351b1cf6c9e8a3325ea600a715a31", size = 272814, upload-time = "2025-08-07T13:15:50.011Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/62/dd/b9f59862e9e257a16e4e610480cfffd29e3fae018a68c2332090b53aac3d/greenlet-3.2.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cd3c8e693bff0fff6ba55f140bf390fa92c994083f838fece0f63be121334945", size = 641073, upload-time = "2025-08-07T13:42:57.23Z" },
|
||||
@@ -2257,8 +2251,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ee/43/3cecdc0349359e1a527cbf2e3e28e5f8f06d3343aaf82ca13437a9aa290f/greenlet-3.2.4-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:23768528f2911bcd7e475210822ffb5254ed10d71f4028387e5a99b4c6699671", size = 610497, upload-time = "2025-08-07T13:18:31.636Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b8/19/06b6cf5d604e2c382a6f31cafafd6f33d5dea706f4db7bdab184bad2b21d/greenlet-3.2.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:00fadb3fedccc447f517ee0d3fd8fe49eae949e1cd0f6a611818f4f6fb7dc83b", size = 1121662, upload-time = "2025-08-07T13:42:41.117Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a2/15/0d5e4e1a66fab130d98168fe984c509249c833c1a3c16806b90f253ce7b9/greenlet-3.2.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:d25c5091190f2dc0eaa3f950252122edbbadbb682aa7b1ef2f8af0f8c0afefae", size = 1149210, upload-time = "2025-08-07T13:18:24.072Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1c/53/f9c440463b3057485b8594d7a638bed53ba531165ef0ca0e6c364b5cc807/greenlet-3.2.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6e343822feb58ac4d0a1211bd9399de2b3a04963ddeec21530fc426cc121f19b", size = 1564759, upload-time = "2025-11-04T12:42:19.395Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/47/e4/3bb4240abdd0a8d23f4f88adec746a3099f0d86bfedb623f063b2e3b4df0/greenlet-3.2.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ca7f6f1f2649b89ce02f6f229d7c19f680a6238af656f61e0115b24857917929", size = 1634288, upload-time = "2025-11-04T12:42:21.174Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0b/55/2321e43595e6801e105fcfdee02b34c0f996eb71e6ddffca6b10b7e1d771/greenlet-3.2.4-cp313-cp313-win_amd64.whl", hash = "sha256:554b03b6e73aaabec3745364d6239e9e012d64c68ccd0b8430c64ccc14939a8b", size = 299685, upload-time = "2025-08-07T13:24:38.824Z" },
|
||||
]
|
||||
|
||||
@@ -5978,20 +5970,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/48/7c/42f0b6997324023e94939f8f32b9a8dd928499f4b5d7b4412905368686b5/pymongo-4.15.3-cp313-cp313-win_arm64.whl", hash = "sha256:fb384623ece34db78d445dd578a52d28b74e8319f4d9535fbaff79d0eae82b3d", size = 944300, upload-time = "2025-10-07T21:56:58.969Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pymupdf"
|
||||
version = "1.26.6"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ec/d7/a6f0e03a117fa2ad79c4b898203bb212b17804f92558a6a339298faca7bb/pymupdf-1.26.6.tar.gz", hash = "sha256:a2b4531cd4ab36d6f1f794bb6d3c33b49bda22f36d58bb1f3e81cbc10183bd2b", size = 84322494, upload-time = "2025-11-05T15:20:46.786Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/9e/5c/dec354eee5fe4966c715f33818ed4193e0e6c986cf8484de35b6c167fb8e/pymupdf-1.26.6-cp310-abi3-macosx_10_9_x86_64.whl", hash = "sha256:e46f320a136ad55e5219e8f0f4061bdf3e4c12b126d2740d5a49f73fae7ea176", size = 23178988, upload-time = "2025-11-05T14:31:19.834Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ec/a0/11adb742d18142bd623556cd3b5d64649816decc5eafd30efc9498657e76/pymupdf-1.26.6-cp310-abi3-macosx_11_0_arm64.whl", hash = "sha256:6844cd2396553c0fa06de4869d5d5ecb1260e6fc3b9d85abe8fa35f14dd9d688", size = 22469764, upload-time = "2025-11-05T14:32:34.654Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e4/c8/377cf20e31f58d4c243bfcf2d3cb7466d5b97003b10b9f1161f11eb4a994/pymupdf-1.26.6-cp310-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:617ba69e02c44f0da1c0e039ea4a26cf630849fd570e169c71daeb8ac52a81d6", size = 23502227, upload-time = "2025-11-06T11:03:56.934Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4f/bf/6e02e3d84b32c137c71a0a3dcdba8f2f6e9950619a3bc272245c7c06a051/pymupdf-1.26.6-cp310-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:7777d0b7124c2ebc94849536b6a1fb85d158df3b9d873935e63036559391534c", size = 24115381, upload-time = "2025-11-05T14:33:54.338Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ab/9d/30f7fcb3776bfedde66c06297960debe4883b1667294a1ee9426c942e94d/pymupdf-1.26.6-cp310-abi3-win32.whl", hash = "sha256:8f3ef05befc90ca6bb0f12983200a7048d5bff3e1c1edef1bb3de60b32cb5274", size = 17203613, upload-time = "2025-11-05T17:19:47.494Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f9/e8/989f4eaa369c7166dc24f0eaa3023f13788c40ff1b96701f7047421554a8/pymupdf-1.26.6-cp310-abi3-win_amd64.whl", hash = "sha256:ce02ca96ed0d1acfd00331a4d41a34c98584d034155b06fd4ec0f051718de7ba", size = 18405680, upload-time = "2025-11-05T14:34:48.672Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pymysql"
|
||||
version = "1.1.2"
|
||||
|
||||
Reference in New Issue
Block a user