mirror of https://github.com/crewAIInc/crewAI.git
synced 2026-01-22 22:58:13 +00:00

Compare commits
2 commits: gl/feat/na ... devin/1768

| Author | SHA1 | Date |
|---|---|---|
|  | 6cfc105d54 |  |
|  | 703e0f6191 |  |

.gitignore (vendored)
@@ -26,4 +26,3 @@ plan.md
conceptual_plan.md
build_image
chromadb-*.lock
.claude
@@ -19,7 +19,7 @@ repos:
        language: system
        pass_filenames: true
        types: [python]
        exclude: ^(lib/crewai/src/crewai/cli/templates/|lib/crewai/tests/|lib/crewai-tools/tests/|lib/crewai-files/tests/)
        exclude: ^(lib/crewai/src/crewai/cli/templates/|lib/crewai/tests/|lib/crewai-tools/tests/)
  - repo: https://github.com/astral-sh/uv-pre-commit
    rev: 0.9.3
    hooks:
@@ -160,10 +160,7 @@ def vcr_cassette_dir(request: Any) -> str:
    test_file = Path(request.fspath)

    for parent in test_file.parents:
        if (
            parent.name in ("crewai", "crewai-tools", "crewai-files")
            and parent.parent.name == "lib"
        ):
        if parent.name in ("crewai", "crewai-tools") and parent.parent.name == "lib":
            package_root = parent
            break
    else:
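The hunk above collapses the multi-line membership test into a single-line check and drops `crewai-files` from the recognized package roots. As orientation only, a hedged sketch of how the surrounding fixture logic plausibly reads after the change; the fallback branch and return value are assumptions, since the hunk ends at `else:`:

```python
from pathlib import Path
from typing import Any


def vcr_cassette_dir(request: Any) -> str:
    test_file = Path(request.fspath)

    for parent in test_file.parents:
        # Only crewai and crewai-tools under lib/ count as package roots now;
        # crewai-files was dropped by this change.
        if parent.name in ("crewai", "crewai-tools") and parent.parent.name == "lib":
            package_root = parent
            break
    else:
        # The fallback is elided in the hunk; assumed here for completeness.
        package_root = test_file.parent

    # Cassette location is an assumption; the real fixture may differ.
    return str(package_root / "tests" / "cassettes")
```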
@@ -291,7 +291,6 @@
            "en/observability/arize-phoenix",
            "en/observability/braintrust",
            "en/observability/datadog",
            "en/observability/galileo",
            "en/observability/langdb",
            "en/observability/langfuse",
            "en/observability/langtrace",

@@ -429,8 +428,7 @@
          "group": "How-To Guides",
          "pages": [
            "en/enterprise/guides/build-crew",
            "en/enterprise/guides/prepare-for-deployment",
            "en/enterprise/guides/deploy-to-amp",
            "en/enterprise/guides/deploy-crew",
            "en/enterprise/guides/kickoff-crew",
            "en/enterprise/guides/update-crew",
            "en/enterprise/guides/enable-crew-studio",

@@ -744,7 +742,6 @@
            "pt-BR/observability/arize-phoenix",
            "pt-BR/observability/braintrust",
            "pt-BR/observability/datadog",
            "pt-BR/observability/galileo",
            "pt-BR/observability/langdb",
            "pt-BR/observability/langfuse",
            "pt-BR/observability/langtrace",

@@ -865,8 +862,7 @@
          "group": "Guias",
          "pages": [
            "pt-BR/enterprise/guides/build-crew",
            "pt-BR/enterprise/guides/prepare-for-deployment",
            "pt-BR/enterprise/guides/deploy-to-amp",
            "pt-BR/enterprise/guides/deploy-crew",
            "pt-BR/enterprise/guides/kickoff-crew",
            "pt-BR/enterprise/guides/update-crew",
            "pt-BR/enterprise/guides/enable-crew-studio",

@@ -1207,7 +1203,6 @@
            "ko/observability/arize-phoenix",
            "ko/observability/braintrust",
            "ko/observability/datadog",
            "ko/observability/galileo",
            "ko/observability/langdb",
            "ko/observability/langfuse",
            "ko/observability/langtrace",

@@ -1328,8 +1323,7 @@
          "group": "How-To Guides",
          "pages": [
            "ko/enterprise/guides/build-crew",
            "ko/enterprise/guides/prepare-for-deployment",
            "ko/enterprise/guides/deploy-to-amp",
            "ko/enterprise/guides/deploy-crew",
            "ko/enterprise/guides/kickoff-crew",
            "ko/enterprise/guides/update-crew",
            "ko/enterprise/guides/enable-crew-studio",

@@ -1517,18 +1511,6 @@
        "source": "/enterprise/:path*",
        "destination": "/en/enterprise/:path*"
      },
      {
        "source": "/en/enterprise/guides/deploy-crew",
        "destination": "/en/enterprise/guides/deploy-to-amp"
      },
      {
        "source": "/ko/enterprise/guides/deploy-crew",
        "destination": "/ko/enterprise/guides/deploy-to-amp"
      },
      {
        "source": "/pt-BR/enterprise/guides/deploy-crew",
        "destination": "/pt-BR/enterprise/guides/deploy-to-amp"
      },
      {
        "source": "/api-reference/:path*",
        "destination": "/en/api-reference/:path*"
@@ -375,13 +375,10 @@ In this section, you'll find detailed examples that help you select, configure,
GOOGLE_API_KEY=<your-api-key>
GEMINI_API_KEY=<your-api-key>

# For Vertex AI Express mode (API key authentication)
GOOGLE_GENAI_USE_VERTEXAI=true
GOOGLE_API_KEY=<your-api-key>

# For Vertex AI with service account
# Optional - for Vertex AI
GOOGLE_CLOUD_PROJECT=<your-project-id>
GOOGLE_CLOUD_LOCATION=<location> # Defaults to us-central1
GOOGLE_GENAI_USE_VERTEXAI=true # Set to use Vertex AI
```

**Basic Usage:**

@@ -415,35 +412,7 @@ In this section, you'll find detailed examples that help you select, configure,
)
```

**Vertex AI Express Mode (API Key Authentication):**

Vertex AI Express mode allows you to use Vertex AI with simple API key authentication instead of service account credentials. This is the quickest way to get started with Vertex AI.

To enable Express mode, set both environment variables in your `.env` file:

```toml .env
GOOGLE_GENAI_USE_VERTEXAI=true
GOOGLE_API_KEY=<your-api-key>
```

Then use the LLM as usual:

```python Code
from crewai import LLM

llm = LLM(
    model="gemini/gemini-2.0-flash",
    temperature=0.7
)
```

<Info>
  To get an Express mode API key:
  - New Google Cloud users: Get an [express mode API key](https://cloud.google.com/vertex-ai/generative-ai/docs/start/quickstart?usertype=apikey)
  - Existing Google Cloud users: Get a [Google Cloud API key bound to a service account](https://cloud.google.com/docs/authentication/api-keys)

  For more details, see the [Vertex AI Express mode documentation](https://docs.cloud.google.com/vertex-ai/generative-ai/docs/start/quickstart?usertype=apikey).
</Info>

**Vertex AI Configuration (Service Account):**
**Vertex AI Configuration:**

```python Code
from crewai import LLM

@@ -455,10 +424,10 @@ In this section, you'll find detailed examples that help you select, configure,
)
```

**Supported Environment Variables:**
- `GOOGLE_API_KEY` or `GEMINI_API_KEY`: Your Google API key (required for Gemini API and Vertex AI Express mode)
- `GOOGLE_GENAI_USE_VERTEXAI`: Set to `true` to use Vertex AI (required for Express mode)
- `GOOGLE_CLOUD_PROJECT`: Google Cloud project ID (for Vertex AI with service account)
- `GOOGLE_API_KEY` or `GEMINI_API_KEY`: Your Google API key (required for Gemini API)
- `GOOGLE_CLOUD_PROJECT`: Google Cloud project ID (for Vertex AI)
- `GOOGLE_CLOUD_LOCATION`: GCP location (defaults to `us-central1`)
- `GOOGLE_GENAI_USE_VERTEXAI`: Set to `true` to use Vertex AI

**Features:**
- Native function calling support for Gemini 1.5+ and 2.x models
@@ -1,12 +1,12 @@
---
title: "Deploy to AMP"
description: "Deploy your Crew or Flow to CrewAI AMP"
title: "Deploy Crew"
description: "Deploying a Crew on CrewAI AMP"
icon: "rocket"
mode: "wide"
---

<Note>
  After creating a Crew or Flow locally (or through Crew Studio), the next step is
  After creating a crew locally or through Crew Studio, the next step is
  deploying it to the CrewAI AMP platform. This guide covers multiple deployment
  methods to help you choose the best approach for your workflow.
</Note>

@@ -14,26 +14,19 @@ mode: "wide"

## Prerequisites

<CardGroup cols={2}>
  <Card title="Project Ready for Deployment" icon="check-circle">
    You should have a working Crew or Flow that runs successfully locally.
    Follow our [preparation guide](/en/enterprise/guides/prepare-for-deployment) to verify your project structure.
  <Card title="Crew Ready for Deployment" icon="users">
    You should have a working crew either built locally or created through Crew
    Studio
  </Card>
  <Card title="GitHub Repository" icon="github">
    Your code should be in a GitHub repository (for GitHub integration
    Your crew code should be in a GitHub repository (for GitHub integration
    method)
  </Card>
</CardGroup>

<Info>
  **Crews vs Flows**: Both project types can be deployed as "automations" on CrewAI AMP.
  The deployment process is the same, but they have different project structures.
  See [Prepare for Deployment](/en/enterprise/guides/prepare-for-deployment) for details.
</Info>

## Option 1: Deploy Using CrewAI CLI

The CLI provides the fastest way to deploy locally developed Crews or Flows to the AMP platform.
The CLI automatically detects your project type from `pyproject.toml` and builds accordingly.
The CLI provides the fastest way to deploy locally developed crews to the Enterprise platform.

<Steps>
  <Step title="Install CrewAI CLI">

@@ -135,7 +128,7 @@ crewai deploy remove <deployment_id>

## Option 2: Deploy Directly via Web Interface

You can also deploy your Crews or Flows directly through the CrewAI AMP web interface by connecting your GitHub account. This approach doesn't require using the CLI on your local machine. The platform automatically detects your project type and handles the build appropriately.
You can also deploy your crews directly through the CrewAI AMP web interface by connecting your GitHub account. This approach doesn't require using the CLI on your local machine.

<Steps>

@@ -289,7 +282,68 @@ For automated deployments in CI/CD pipelines, you can use the CrewAI API to trigger

</Steps>
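The collapsed hunk above keeps only the context line about CI/CD pipelines, so as a hedged illustration of triggering a deployed automation from one: a minimal Python sketch against the REST routes this guide documents (`/kickoff` and `/status/{kickoff_id}`). The environment variable names and response field names are assumptions, not a documented client:

```python
import os
import time

import requests

BASE_URL = os.environ["CREW_BASE_URL"]    # deployment URL from the Status tab (assumed name)
TOKEN = os.environ["CREW_BEARER_TOKEN"]   # bearer token from the Status tab (assumed name)
HEADERS = {"Authorization": f"Bearer {TOKEN}"}

# Start an execution with the inputs your crew expects.
kickoff = requests.post(
    f"{BASE_URL}/kickoff",
    json={"inputs": {"topic": "AI in Healthcare"}},
    headers=HEADERS,
    timeout=30,
)
kickoff.raise_for_status()
kickoff_id = kickoff.json()["kickoff_id"]  # response field name is an assumption

# Poll the status route until the run finishes.
while True:
    status = requests.get(f"{BASE_URL}/status/{kickoff_id}", headers=HEADERS, timeout=30)
    status.raise_for_status()
    body = status.json()
    if body.get("state") in ("SUCCESS", "FAILED"):  # state values are assumptions
        print(body)
        break
    time.sleep(5)
```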
## Interact with Your Deployed Automation
## ⚠️ Environment Variable Security Requirements

<Warning>
  **Important**: CrewAI AMP has security restrictions on environment variable
  names that can cause deployment failures if not followed.
</Warning>

### Blocked Environment Variable Patterns

For security reasons, the following environment variable naming patterns are **automatically filtered** and will cause deployment issues:

**Blocked Patterns:**

- Variables ending with `_TOKEN` (e.g., `MY_API_TOKEN`)
- Variables ending with `_PASSWORD` (e.g., `DB_PASSWORD`)
- Variables ending with `_SECRET` (e.g., `API_SECRET`)
- Variables ending with `_KEY` in certain contexts

**Specific Blocked Variables:**

- `GITHUB_USER`, `GITHUB_TOKEN`
- `AWS_REGION`, `AWS_DEFAULT_REGION`
- Various internal CrewAI system variables

### Allowed Exceptions

Some variables are explicitly allowed despite matching blocked patterns:

- `AZURE_AD_TOKEN`
- `AZURE_OPENAI_AD_TOKEN`
- `ENTERPRISE_ACTION_TOKEN`
- `CREWAI_ENTEPRISE_TOOLS_TOKEN`

### How to Fix Naming Issues

If your deployment fails due to environment variable restrictions:

```bash
# ❌ These will cause deployment failures
OPENAI_TOKEN=sk-...
DATABASE_PASSWORD=mypassword
API_SECRET=secret123

# ✅ Use these naming patterns instead
OPENAI_API_KEY=sk-...
DATABASE_CREDENTIALS=mypassword
API_CONFIG=secret123
```

### Best Practices

1. **Use standard naming conventions**: `PROVIDER_API_KEY` instead of `PROVIDER_TOKEN`
2. **Test locally first**: Ensure your crew works with the renamed variables
3. **Update your code**: Change any references to the old variable names
4. **Document changes**: Keep track of renamed variables for your team

<Tip>
  If you encounter deployment failures with cryptic environment variable errors,
  check your variable names against these patterns first.
</Tip>
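Because the blocked patterns above are easy to trip over, a quick pre-flight scan of your `.env` can surface offending names before you deploy. A hedged sketch: the pattern list below mirrors this guide, not an official platform API, and the real server-side filter may differ:

```python
# Names transcribed from this guide; the platform's actual filter may differ.
BLOCKED_SUFFIXES = ("_TOKEN", "_PASSWORD", "_SECRET")
BLOCKED_NAMES = {"GITHUB_USER", "GITHUB_TOKEN", "AWS_REGION", "AWS_DEFAULT_REGION"}
ALLOWED = {
    "AZURE_AD_TOKEN",
    "AZURE_OPENAI_AD_TOKEN",
    "ENTERPRISE_ACTION_TOKEN",
    "CREWAI_ENTEPRISE_TOOLS_TOKEN",
}


def check_env_names(path: str = ".env") -> list[str]:
    """Return env var names in `path` likely to be filtered at deploy time."""
    flagged = []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue
            name = line.split("=", 1)[0].strip()
            if name in ALLOWED:
                continue
            if name in BLOCKED_NAMES or name.endswith(BLOCKED_SUFFIXES):
                flagged.append(name)
    return flagged


if __name__ == "__main__":
    for name in check_env_names():
        print(f"rename before deploying: {name}")
```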
### Interact with Your Deployed Crew

Once deployment is complete, you can access your crew through:

@@ -333,108 +387,7 @@ The Enterprise platform also offers:

- **Custom Tools Repository**: Create, share, and install tools
- **Crew Studio**: Build crews through a chat interface without writing code

## Troubleshooting Deployment Failures

If your deployment fails, check these common issues:

### Build Failures

#### Missing uv.lock File

**Symptom**: Build fails early with dependency resolution errors

**Solution**: Generate and commit the lock file:

```bash
uv lock
git add uv.lock
git commit -m "Add uv.lock for deployment"
git push
```

<Warning>
  The `uv.lock` file is required for all deployments. Without it, the platform
  cannot reliably install your dependencies.
</Warning>

#### Wrong Project Structure

**Symptom**: "Could not find entry point" or "Module not found" errors

**Solution**: Verify your project matches the expected structure:

- **Both Crews and Flows**: Must have entry point at `src/project_name/main.py`
- **Crews**: Use a `run()` function as entry point
- **Flows**: Use a `kickoff()` function as entry point

See [Prepare for Deployment](/en/enterprise/guides/prepare-for-deployment) for detailed structure diagrams.

#### Missing CrewBase Decorator

**Symptom**: "Crew not found", "Config not found", or agent/task configuration errors

**Solution**: Ensure **all** crew classes use the `@CrewBase` decorator:

```python
from crewai import Agent
from crewai.project import CrewBase, agent, crew, task

@CrewBase  # This decorator is REQUIRED
class YourCrew():
    """Your crew description"""

    @agent
    def my_agent(self) -> Agent:
        return Agent(
            config=self.agents_config['my_agent'],  # type: ignore[index]
            verbose=True
        )

    # ... rest of crew definition
```

<Info>
  This applies to standalone Crews AND crews embedded inside Flow projects.
  Every crew class needs the decorator.
</Info>

#### Incorrect pyproject.toml Type

**Symptom**: Build succeeds but runtime fails, or unexpected behavior

**Solution**: Verify the `[tool.crewai]` section matches your project type:

```toml
# For Crew projects:
[tool.crewai]
type = "crew"

# For Flow projects:
[tool.crewai]
type = "flow"
```

### Runtime Failures

#### LLM Connection Failures

**Symptom**: API key errors, "model not found", or authentication failures

**Solution**:
1. Verify your LLM provider's API key is correctly set in environment variables
2. Ensure the environment variable names match what your code expects
3. Test locally with the exact same environment variables before deploying

#### Crew Execution Errors

**Symptom**: Crew starts but fails during execution

**Solution**:
1. Check the execution logs in the AMP dashboard (Traces tab)
2. Verify all tools have required API keys configured
3. Ensure agent configurations in `agents.yaml` are valid
4. Check task configurations in `tasks.yaml` for syntax errors

<Card title="Need Help?" icon="headset" href="mailto:support@crewai.com">
  Contact our support team for assistance with deployment issues or questions
  about the AMP platform.
  about the Enterprise platform.
</Card>
@@ -1,305 +0,0 @@
---
title: "Prepare for Deployment"
description: "Ensure your Crew or Flow is ready for deployment to CrewAI AMP"
icon: "clipboard-check"
mode: "wide"
---

<Note>
  Before deploying to CrewAI AMP, it's crucial to verify your project is correctly structured.
  Both Crews and Flows can be deployed as "automations," but they have different project structures
  and requirements that must be met for successful deployment.
</Note>

## Understanding Automations

In CrewAI AMP, **automations** is the umbrella term for deployable Agentic AI projects. An automation can be either:

- **A Crew**: A standalone team of AI agents working together on tasks
- **A Flow**: An orchestrated workflow that can combine multiple crews, direct LLM calls, and procedural logic

Understanding which type you're deploying is essential because they have different project structures and entry points.

## Crews vs Flows: Key Differences

<CardGroup cols={2}>
  <Card title="Crew Projects" icon="users">
    Standalone AI agent teams with `crew.py` defining agents and tasks. Best for focused, collaborative tasks.
  </Card>
  <Card title="Flow Projects" icon="diagram-project">
    Orchestrated workflows with embedded crews in a `crews/` folder. Best for complex, multi-stage processes.
  </Card>
</CardGroup>

| Aspect | Crew | Flow |
|--------|------|------|
| **Project structure** | `src/project_name/` with `crew.py` | `src/project_name/` with `crews/` folder |
| **Main logic location** | `src/project_name/crew.py` | `src/project_name/main.py` (Flow class) |
| **Entry point function** | `run()` in `main.py` | `kickoff()` in `main.py` |
| **pyproject.toml type** | `type = "crew"` | `type = "flow"` |
| **CLI create command** | `crewai create crew name` | `crewai create flow name` |
| **Config location** | `src/project_name/config/` | `src/project_name/crews/crew_name/config/` |
| **Can contain other crews** | No | Yes (in `crews/` folder) |

## Project Structure Reference

### Crew Project Structure

When you run `crewai create crew my_crew`, you get this structure:

```
my_crew/
├── .gitignore
├── pyproject.toml          # Must have type = "crew"
├── README.md
├── .env
├── uv.lock                 # REQUIRED for deployment
└── src/
    └── my_crew/
        ├── __init__.py
        ├── main.py         # Entry point with run() function
        ├── crew.py         # Crew class with @CrewBase decorator
        ├── tools/
        │   ├── custom_tool.py
        │   └── __init__.py
        └── config/
            ├── agents.yaml # Agent definitions
            └── tasks.yaml  # Task definitions
```

<Warning>
  The nested `src/project_name/` structure is critical for Crews.
  Placing files at the wrong level will cause deployment failures.
</Warning>

### Flow Project Structure

When you run `crewai create flow my_flow`, you get this structure:

```
my_flow/
├── .gitignore
├── pyproject.toml          # Must have type = "flow"
├── README.md
├── .env
├── uv.lock                 # REQUIRED for deployment
└── src/
    └── my_flow/
        ├── __init__.py
        ├── main.py         # Entry point with kickoff() function + Flow class
        ├── crews/          # Embedded crews folder
        │   └── poem_crew/
        │       ├── __init__.py
        │       ├── poem_crew.py # Crew with @CrewBase decorator
        │       └── config/
        │           ├── agents.yaml
        │           └── tasks.yaml
        └── tools/
            ├── __init__.py
            └── custom_tool.py
```

<Info>
  Both Crews and Flows use the `src/project_name/` structure.
  The key difference is that Flows have a `crews/` folder for embedded crews,
  while Crews have `crew.py` directly in the project folder.
</Info>

## Pre-Deployment Checklist

Use this checklist to verify your project is ready for deployment.

### 1. Verify pyproject.toml Configuration

Your `pyproject.toml` must include the correct `[tool.crewai]` section:

<Tabs>
  <Tab title="For Crews">
    ```toml
    [tool.crewai]
    type = "crew"
    ```
  </Tab>
  <Tab title="For Flows">
    ```toml
    [tool.crewai]
    type = "flow"
    ```
  </Tab>
</Tabs>

<Warning>
  If the `type` doesn't match your project structure, the build will fail or
  the automation won't run correctly.
</Warning>

### 2. Ensure uv.lock File Exists

CrewAI uses `uv` for dependency management. The `uv.lock` file ensures reproducible builds and is **required** for deployment.

```bash
# Generate or update the lock file
uv lock

# Verify it exists
ls -la uv.lock
```

If the file doesn't exist, run `uv lock` and commit it to your repository:

```bash
uv lock
git add uv.lock
git commit -m "Add uv.lock for deployment"
git push
```

### 3. Validate CrewBase Decorator Usage

**Every crew class must use the `@CrewBase` decorator.** This applies to:

- Standalone crew projects
- Crews embedded inside Flow projects

```python
from crewai import Agent, Crew, Process, Task
from crewai.project import CrewBase, agent, crew, task
from crewai.agents.agent_builder.base_agent import BaseAgent
from typing import List

@CrewBase  # This decorator is REQUIRED
class MyCrew():
    """My crew description"""

    agents: List[BaseAgent]
    tasks: List[Task]

    @agent
    def my_agent(self) -> Agent:
        return Agent(
            config=self.agents_config['my_agent'],  # type: ignore[index]
            verbose=True
        )

    @task
    def my_task(self) -> Task:
        return Task(
            config=self.tasks_config['my_task']  # type: ignore[index]
        )

    @crew
    def crew(self) -> Crew:
        return Crew(
            agents=self.agents,
            tasks=self.tasks,
            process=Process.sequential,
            verbose=True,
        )
```

<Warning>
  If you forget the `@CrewBase` decorator, your deployment will fail with
  errors about missing agents or tasks configurations.
</Warning>

### 4. Check Project Entry Points

Both Crews and Flows have their entry point in `src/project_name/main.py`:

<Tabs>
  <Tab title="For Crews">
    The entry point uses a `run()` function:

    ```python
    # src/my_crew/main.py
    from my_crew.crew import MyCrew

    def run():
        """Run the crew."""
        inputs = {'topic': 'AI in Healthcare'}
        result = MyCrew().crew().kickoff(inputs=inputs)
        return result

    if __name__ == "__main__":
        run()
    ```
  </Tab>
  <Tab title="For Flows">
    The entry point uses a `kickoff()` function with a Flow class:

    ```python
    # src/my_flow/main.py
    from crewai.flow import Flow, listen, start
    from my_flow.crews.poem_crew.poem_crew import PoemCrew

    class MyFlow(Flow):
        @start()
        def begin(self):
            # Flow logic here
            result = PoemCrew().crew().kickoff(inputs={...})
            return result

    def kickoff():
        """Run the flow."""
        MyFlow().kickoff()

    if __name__ == "__main__":
        kickoff()
    ```
  </Tab>
</Tabs>

### 5. Prepare Environment Variables

Before deployment, ensure you have:

1. **LLM API keys** ready (OpenAI, Anthropic, Google, etc.)
2. **Tool API keys** if using external tools (Serper, etc.)

<Tip>
  Test your project locally with the same environment variables before deploying
  to catch configuration issues early.
</Tip>
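One lightweight way to follow that tip is a guard that fails fast when a required variable is missing. A minimal sketch; the variable names are illustrative, not prescribed by CrewAI:

```python
import os
import sys

# Illustrative list: put whatever your agents and tools actually read here.
REQUIRED_VARS = ["OPENAI_API_KEY", "SERPER_API_KEY"]

missing = [name for name in REQUIRED_VARS if not os.environ.get(name)]
if missing:
    sys.exit(f"Missing environment variables: {', '.join(missing)}")
```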
## Quick Validation Commands

Run these commands from your project root to quickly verify your setup:

```bash
# 1. Check project type in pyproject.toml
grep -A2 "\[tool.crewai\]" pyproject.toml

# 2. Verify uv.lock exists
ls -la uv.lock || echo "ERROR: uv.lock missing! Run 'uv lock'"

# 3. Verify src/ structure exists
ls -la src/*/main.py 2>/dev/null || echo "No main.py found in src/"

# 4. For Crews - verify crew.py exists
ls -la src/*/crew.py 2>/dev/null || echo "No crew.py (expected for Crews)"

# 5. For Flows - verify crews/ folder exists
ls -la src/*/crews/ 2>/dev/null || echo "No crews/ folder (expected for Flows)"

# 6. Check for CrewBase usage
grep -r "@CrewBase" . --include="*.py"
```

## Common Setup Mistakes

| Mistake | Symptom | Fix |
|---------|---------|-----|
| Missing `uv.lock` | Build fails during dependency resolution | Run `uv lock` and commit |
| Wrong `type` in pyproject.toml | Build succeeds but runtime fails | Change to correct type |
| Missing `@CrewBase` decorator | "Config not found" errors | Add decorator to all crew classes |
| Files at root instead of `src/` | Entry point not found | Move to `src/project_name/` |
| Missing `run()` or `kickoff()` | Cannot start automation | Add correct entry function |

## Next Steps

Once your project passes all checklist items, you're ready to deploy:

<Card title="Deploy to AMP" icon="rocket" href="/en/enterprise/guides/deploy-to-amp">
  Follow the deployment guide to deploy your Crew or Flow to CrewAI AMP using
  the CLI, web interface, or CI/CD integration.
</Card>
@@ -1,48 +1,43 @@
---
title: Agent-to-Agent (A2A) Protocol
description: Agents delegate tasks to remote A2A agents and/or operate as A2A-compliant server agents.
description: Enable CrewAI agents to delegate tasks to remote A2A-compliant agents for specialized handling
icon: network-wired
mode: "wide"
---

## A2A Agent Delegation

CrewAI treats the [A2A protocol](https://a2a-protocol.org/latest/) as a first-class delegation primitive, enabling agents to delegate tasks, request information, and collaborate with remote agents, as well as act as A2A-compliant server agents.
In client mode, agents autonomously choose between local execution and remote delegation based on task requirements.
CrewAI supports the Agent-to-Agent (A2A) protocol, allowing agents to delegate tasks to remote specialized agents. The agent's LLM automatically decides whether to handle a task directly or delegate to an A2A agent based on the task requirements.

<Note>
  A2A delegation requires the `a2a-sdk` package. Install with: `uv add 'crewai[a2a]'` or `pip install 'crewai[a2a]'`
</Note>

## How It Works

When an agent is configured with A2A capabilities:

1. The Agent analyzes each task
1. The LLM analyzes each task
2. It decides to either:
   - Handle the task directly using its own capabilities
   - Delegate to a remote A2A agent for specialized handling
3. If delegating, the agent communicates with the remote A2A agent through the protocol
4. Results are returned to the CrewAI workflow

<Note>
  A2A delegation requires the `a2a-sdk` package. Install with: `uv add 'crewai[a2a]'` or `pip install 'crewai[a2a]'`
</Note>

## Basic Configuration

<Warning>
  `crewai.a2a.config.A2AConfig` is deprecated and will be removed in v2.0.0. Use `A2AClientConfig` for connecting to remote agents and/or `A2AServerConfig` for exposing agents as servers.
</Warning>

Configure an agent for A2A delegation by setting the `a2a` parameter:

```python Code
from crewai import Agent, Crew, Task
from crewai.a2a import A2AClientConfig
from crewai.a2a import A2AConfig

agent = Agent(
    role="Research Coordinator",
    goal="Coordinate research tasks efficiently",
    backstory="Expert at delegating to specialized research agents",
    llm="gpt-4o",
    a2a=A2AClientConfig(
    a2a=A2AConfig(
        endpoint="https://example.com/.well-known/agent-card.json",
        timeout=120,
        max_turns=10

@@ -59,9 +54,9 @@ crew = Crew(agents=[agent], tasks=[task], verbose=True)
result = crew.kickoff()
```

## Client Configuration Options
## Configuration Options

The `A2AClientConfig` class accepts the following parameters:
The `A2AConfig` class accepts the following parameters:

<ParamField path="endpoint" type="str" required>
  The A2A agent endpoint URL (typically points to `.well-known/agent-card.json`)

@@ -96,34 +91,14 @@ The `A2AClientConfig` class accepts the following parameters:
  Update mechanism for receiving task status. Options: `StreamingConfig`, `PollingConfig`, or `PushNotificationConfig`.
</ParamField>

<ParamField path="transport_protocol" type="Literal['JSONRPC', 'GRPC', 'HTTP+JSON']" default="JSONRPC">
  Transport protocol for A2A communication. Options: `JSONRPC` (default), `GRPC`, or `HTTP+JSON`.
</ParamField>

<ParamField path="accepted_output_modes" type="list[str]" default='["application/json"]'>
  Media types the client can accept in responses.
</ParamField>

<ParamField path="supported_transports" type="list[str]" default='["JSONRPC"]'>
  Ordered list of transport protocols the client supports.
</ParamField>

<ParamField path="use_client_preference" type="bool" default="False">
  Whether to prioritize client transport preferences over server.
</ParamField>

<ParamField path="extensions" type="list[str]" default="[]">
  Extension URIs the client supports.
</ParamField>

## Authentication

For A2A agents that require authentication, use one of the provided auth schemes:

<Tabs>
  <Tab title="Bearer Token">
    ```python bearer_token_auth.py lines
    from crewai.a2a import A2AClientConfig
    ```python Code
    from crewai.a2a import A2AConfig
    from crewai.a2a.auth import BearerTokenAuth

    agent = Agent(

@@ -131,18 +106,18 @@ agent = Agent(
        goal="Coordinate tasks with secured agents",
        backstory="Manages secure agent communications",
        llm="gpt-4o",
        a2a=A2AClientConfig(
        a2a=A2AConfig(
            endpoint="https://secure-agent.example.com/.well-known/agent-card.json",
            auth=BearerTokenAuth(token="your-bearer-token"),
            timeout=120
        )
    )
    ```
    ```
  </Tab>

  <Tab title="API Key">
    ```python api_key_auth.py lines
    from crewai.a2a import A2AClientConfig
    ```python Code
    from crewai.a2a import A2AConfig
    from crewai.a2a.auth import APIKeyAuth

    agent = Agent(

@@ -150,7 +125,7 @@ agent = Agent(
        goal="Coordinate with API-based agents",
        backstory="Manages API-authenticated communications",
        llm="gpt-4o",
        a2a=A2AClientConfig(
        a2a=A2AConfig(
            endpoint="https://api-agent.example.com/.well-known/agent-card.json",
            auth=APIKeyAuth(
                api_key="your-api-key",

@@ -160,12 +135,12 @@ agent = Agent(
            timeout=120
        )
    )
    ```
    ```
  </Tab>

  <Tab title="OAuth2">
    ```python oauth2_auth.py lines
    from crewai.a2a import A2AClientConfig
    ```python Code
    from crewai.a2a import A2AConfig
    from crewai.a2a.auth import OAuth2ClientCredentials

    agent = Agent(

@@ -173,7 +148,7 @@ agent = Agent(
        goal="Coordinate with OAuth-secured agents",
        backstory="Manages OAuth-authenticated communications",
        llm="gpt-4o",
        a2a=A2AClientConfig(
        a2a=A2AConfig(
            endpoint="https://oauth-agent.example.com/.well-known/agent-card.json",
            auth=OAuth2ClientCredentials(
                token_url="https://auth.example.com/oauth/token",

@@ -184,12 +159,12 @@ agent = Agent(
            timeout=120
        )
    )
    ```
    ```
  </Tab>

  <Tab title="HTTP Basic">
    ```python http_basic_auth.py lines
    from crewai.a2a import A2AClientConfig
    ```python Code
    from crewai.a2a import A2AConfig
    from crewai.a2a.auth import HTTPBasicAuth

    agent = Agent(

@@ -197,7 +172,7 @@ agent = Agent(
        goal="Coordinate with basic auth agents",
        backstory="Manages basic authentication communications",
        llm="gpt-4o",
        a2a=A2AClientConfig(
        a2a=A2AConfig(
            endpoint="https://basic-agent.example.com/.well-known/agent-card.json",
            auth=HTTPBasicAuth(
                username="your-username",

@@ -206,7 +181,7 @@ agent = Agent(
            timeout=120
        )
    )
    ```
    ```
  </Tab>
</Tabs>

@@ -215,7 +190,7 @@ agent = Agent(

Configure multiple A2A agents for delegation by passing a list:

```python Code
from crewai.a2a import A2AClientConfig
from crewai.a2a import A2AConfig
from crewai.a2a.auth import BearerTokenAuth

agent = Agent(

@@ -224,11 +199,11 @@ agent = Agent(
    backstory="Expert at delegating to the right specialist",
    llm="gpt-4o",
    a2a=[
        A2AClientConfig(
        A2AConfig(
            endpoint="https://research.example.com/.well-known/agent-card.json",
            timeout=120
        ),
        A2AClientConfig(
        A2AConfig(
            endpoint="https://data.example.com/.well-known/agent-card.json",
            auth=BearerTokenAuth(token="data-token"),
            timeout=90

@@ -244,7 +219,7 @@ The LLM will automatically choose which A2A agent to delegate to based on the task

Control how agent connection failures are handled using the `fail_fast` parameter:

```python Code
from crewai.a2a import A2AClientConfig
from crewai.a2a import A2AConfig

# Fail immediately on connection errors (default)
agent = Agent(

@@ -252,7 +227,7 @@ agent = Agent(
    goal="Coordinate research tasks",
    backstory="Expert at delegation",
    llm="gpt-4o",
    a2a=A2AClientConfig(
    a2a=A2AConfig(
        endpoint="https://research.example.com/.well-known/agent-card.json",
        fail_fast=True
    )

@@ -265,11 +240,11 @@ agent = Agent(
    backstory="Expert at working with available resources",
    llm="gpt-4o",
    a2a=[
        A2AClientConfig(
        A2AConfig(
            endpoint="https://primary.example.com/.well-known/agent-card.json",
            fail_fast=False
        ),
        A2AClientConfig(
        A2AConfig(
            endpoint="https://backup.example.com/.well-known/agent-card.json",
            fail_fast=False
        )

@@ -288,8 +263,8 @@ Control how your agent receives task status updates from remote A2A agents:

<Tabs>
  <Tab title="Streaming (Default)">
    ```python streaming_config.py lines
    from crewai.a2a import A2AClientConfig
    ```python Code
    from crewai.a2a import A2AConfig
    from crewai.a2a.updates import StreamingConfig

    agent = Agent(

@@ -297,17 +272,17 @@ agent = Agent(
        goal="Coordinate research tasks",
        backstory="Expert at delegation",
        llm="gpt-4o",
        a2a=A2AClientConfig(
        a2a=A2AConfig(
            endpoint="https://research.example.com/.well-known/agent-card.json",
            updates=StreamingConfig()
        )
    )
    ```
    ```
  </Tab>

  <Tab title="Polling">
    ```python polling_config.py lines
    from crewai.a2a import A2AClientConfig
    ```python Code
    from crewai.a2a import A2AConfig
    from crewai.a2a.updates import PollingConfig

    agent = Agent(

@@ -315,7 +290,7 @@ agent = Agent(
        goal="Coordinate research tasks",
        backstory="Expert at delegation",
        llm="gpt-4o",
        a2a=A2AClientConfig(
        a2a=A2AConfig(
            endpoint="https://research.example.com/.well-known/agent-card.json",
            updates=PollingConfig(
                interval=2.0,

@@ -324,12 +299,12 @@ agent = Agent(
            )
        )
    )
    ```
    ```
  </Tab>

  <Tab title="Push Notifications">
    ```python push_notifications_config.py lines
    from crewai.a2a import A2AClientConfig
    ```python Code
    from crewai.a2a import A2AConfig
    from crewai.a2a.updates import PushNotificationConfig

    agent = Agent(

@@ -337,137 +312,19 @@ agent = Agent(
        goal="Coordinate research tasks",
        backstory="Expert at delegation",
        llm="gpt-4o",
        a2a=A2AClientConfig(
        a2a=A2AConfig(
            endpoint="https://research.example.com/.well-known/agent-card.json",
            updates=PushNotificationConfig(
                url="{base_url}/a2a/callback",
                url={base_url}/a2a/callback",
                token="your-validation-token",
                timeout=300.0
            )
        )
    )
    ```
    ```
  </Tab>
</Tabs>

## Exposing Agents as A2A Servers

You can expose your CrewAI agents as A2A-compliant servers, allowing other A2A clients to delegate tasks to them.

### Server Configuration

Add an `A2AServerConfig` to your agent to enable server capabilities:

```python a2a_server_agent.py lines
from crewai import Agent
from crewai.a2a import A2AServerConfig

agent = Agent(
    role="Data Analyst",
    goal="Analyze datasets and provide insights",
    backstory="Expert data scientist with statistical analysis skills",
    llm="gpt-4o",
    a2a=A2AServerConfig(url="https://your-server.com")
)
```

### Server Configuration Options

<ParamField path="name" type="str" default="None">
  Human-readable name for the agent. Defaults to the agent's role if not provided.
</ParamField>

<ParamField path="description" type="str" default="None">
  Human-readable description. Defaults to the agent's goal and backstory if not provided.
</ParamField>

<ParamField path="version" type="str" default="1.0.0">
  Version string for the agent card.
</ParamField>

<ParamField path="skills" type="list[AgentSkill]" default="[]">
  List of agent skills. Auto-generated from agent tools if not provided.
</ParamField>

<ParamField path="capabilities" type="AgentCapabilities" default="AgentCapabilities(streaming=True, push_notifications=False)">
  Declaration of optional capabilities supported by the agent.
</ParamField>

<ParamField path="default_input_modes" type="list[str]" default='["text/plain", "application/json"]'>
  Supported input MIME types.
</ParamField>

<ParamField path="default_output_modes" type="list[str]" default='["text/plain", "application/json"]'>
  Supported output MIME types.
</ParamField>

<ParamField path="url" type="str" default="None">
  Preferred endpoint URL. If set, overrides the URL passed to `to_agent_card()`.
</ParamField>

<ParamField path="preferred_transport" type="Literal['JSONRPC', 'GRPC', 'HTTP+JSON']" default="JSONRPC">
  Transport protocol for the preferred endpoint.
</ParamField>

<ParamField path="protocol_version" type="str" default="0.3">
  A2A protocol version this agent supports.
</ParamField>

<ParamField path="provider" type="AgentProvider" default="None">
  Information about the agent's service provider.
</ParamField>

<ParamField path="documentation_url" type="str" default="None">
  URL to the agent's documentation.
</ParamField>

<ParamField path="icon_url" type="str" default="None">
  URL to an icon for the agent.
</ParamField>

<ParamField path="additional_interfaces" type="list[AgentInterface]" default="[]">
  Additional supported interfaces (transport and URL combinations).
</ParamField>

<ParamField path="security" type="list[dict[str, list[str]]]" default="[]">
  Security requirement objects for all agent interactions.
</ParamField>

<ParamField path="security_schemes" type="dict[str, SecurityScheme]" default="{}">
  Security schemes available to authorize requests.
</ParamField>

<ParamField path="supports_authenticated_extended_card" type="bool" default="False">
  Whether the agent provides an extended card to authenticated users.
</ParamField>

<ParamField path="signatures" type="list[AgentCardSignature]" default="[]">
  JSON Web Signatures for the AgentCard.
</ParamField>

### Combined Client and Server

An agent can act as both client and server by providing both configurations:

```python Code
from crewai import Agent
from crewai.a2a import A2AClientConfig, A2AServerConfig

agent = Agent(
    role="Research Coordinator",
    goal="Coordinate research and serve analysis requests",
    backstory="Expert at delegation and analysis",
    llm="gpt-4o",
    a2a=[
        A2AClientConfig(
            endpoint="https://specialist.example.com/.well-known/agent-card.json",
            timeout=120
        ),
        A2AServerConfig(url="https://your-server.com")
    ]
)
```

## Best Practices

<CardGroup cols={2}>
@@ -1,115 +0,0 @@
---
title: Galileo
description: Galileo integration for CrewAI tracing and evaluation
icon: telescope
mode: "wide"
---

## Overview

This guide demonstrates how to integrate **Galileo** with **CrewAI** for comprehensive tracing and evaluation engineering. By the end of this guide, you will be able to trace your CrewAI agents, monitor their performance, and evaluate their behaviour with Galileo's observability platform.

> **What is Galileo?** [Galileo](https://galileo.ai) is an AI evaluation and observability platform that delivers end-to-end tracing, evaluation, and monitoring for AI applications. It enables teams to capture ground truth, create robust guardrails, and run systematic experiments with built-in experiment tracking and performance analytics, ensuring reliability, transparency, and continuous improvement across the AI lifecycle.

## Getting started

This tutorial follows the [CrewAI quickstart](/en/quickstart) and shows how to add Galileo's [CrewAIEventListener](https://v2docs.galileo.ai/sdk-api/python/reference/handlers/crewai/handler), an event handler. For more information, see Galileo's [Add Galileo to a CrewAI Application](https://v2docs.galileo.ai/how-to-guides/third-party-integrations/add-galileo-to-crewai/add-galileo-to-crewai) how-to guide.

> **Note** This tutorial assumes you have completed the [CrewAI quickstart](/en/quickstart). For a complete, comprehensive example, see the Galileo [CrewAI sdk-examples repo](https://github.com/rungalileo/sdk-examples/tree/main/python/agent/crew-ai).

### Step 1: Install dependencies

Install the required dependencies for your app. Create a virtual environment using your preferred method, then install dependencies inside that environment using your preferred tool:

```bash
uv add galileo
```

### Step 2: Add to the .env file from the [CrewAI quickstart](/en/quickstart)

```bash
# Your Galileo API key
GALILEO_API_KEY="your-galileo-api-key"

# Your Galileo project name
GALILEO_PROJECT="your-galileo-project-name"

# The name of the Log stream you want to use for logging
GALILEO_LOG_STREAM="your-galileo-log-stream"
```

### Step 3: Add the Galileo event listener

To enable logging with Galileo, create an instance of the `CrewAIEventListener`. Import the Galileo CrewAI handler package by adding the following code at the top of your main.py file:

```python
from galileo.handlers.crewai.handler import CrewAIEventListener
```

At the start of your run function, create the event listener:

```python
def run():
    # Create the event listener
    CrewAIEventListener()
    # The rest of your existing code goes here
```

When you create the listener instance, it is automatically registered with CrewAI.

### Step 4: Run your crew

Run your crew with the CrewAI CLI:

```bash
crewai run
```

### Step 5: View the traces in Galileo

Once your crew has finished, the traces will be flushed and appear in Galileo.

![Galileo trace view](/images/galileo-trace.png)

## Understanding the Galileo Integration

Galileo integrates with CrewAI by registering an event listener that captures crew execution events (e.g., agent actions, tool calls, model responses) and forwards them to Galileo for observability and evaluation.

### Understanding the event listener

Creating a `CrewAIEventListener()` instance is all that's required to enable Galileo for a CrewAI run. When instantiated, the listener:

- Automatically registers itself with CrewAI
- Reads Galileo configuration from environment variables
- Logs all run data to the Galileo project and log stream specified by `GALILEO_PROJECT` and `GALILEO_LOG_STREAM`

No additional configuration or code changes are required.

Binary file not shown.
Before: 239 KiB
@@ -107,7 +107,7 @@ There are several places in your CrewAI code where you can specify which model to use

## Provider Configuration Examples

CrewAI supports a wide range of LLM providers, each offering unique features, authentication methods, and model capabilities.
CrewAI supports a wide range of LLM providers, each offering unique features, authentication methods, and model capabilities.
This section provides detailed examples to help you select, configure, and optimize the LLM that best fits your project's needs.

<AccordionGroup>

@@ -153,8 +153,8 @@
</Accordion>

<Accordion title="Meta-Llama">
Meta's Llama API provides access to Meta's family of large language models.
The API is available at [Meta Llama API](https://llama.developer.meta.com?utm_source=partner-crewai&utm_medium=website).
Meta's Llama API provides access to Meta's family of large language models.
The API is available at [Meta Llama API](https://llama.developer.meta.com?utm_source=partner-crewai&utm_medium=website).
Set the following environment variables in your `.env` file:

```toml Code

@@ -207,20 +207,11 @@
Set your API key in your `.env` file. If you need a key, or want to find an existing one, check [AI Studio](https://aistudio.google.com/apikey).

```toml .env
# When using the Gemini API (one of the following)
GOOGLE_API_KEY=<your-api-key>
# https://ai.google.dev/gemini-api/docs/api-key
GEMINI_API_KEY=<your-api-key>

# When using Vertex AI Express mode (API key authentication)
GOOGLE_GENAI_USE_VERTEXAI=true
GOOGLE_API_KEY=<your-api-key>

# When using Vertex AI with a service account
GOOGLE_CLOUD_PROJECT=<your-project-id>
GOOGLE_CLOUD_LOCATION=<location> # Defaults to us-central1
```

**Basic Usage:**
Example usage in a CrewAI project:
```python Code
from crewai import LLM

@@ -230,34 +221,6 @@
)
```

**Vertex AI Express Mode (API key authentication):**

Vertex AI Express mode lets you use Vertex AI with simple API key authentication instead of service account credentials. It is the quickest way to get started with Vertex AI.

To enable Express mode, set both environment variables in your `.env` file:
```toml .env
GOOGLE_GENAI_USE_VERTEXAI=true
GOOGLE_API_KEY=<your-api-key>
```

Then use the LLM as usual:
```python Code
from crewai import LLM

llm = LLM(
    model="gemini/gemini-2.0-flash",
    temperature=0.7
)
```

<Info>
To get an Express mode API key:
- New Google Cloud users: get an [Express mode API key](https://cloud.google.com/vertex-ai/generative-ai/docs/start/quickstart?usertype=apikey)
- Existing Google Cloud users: get a [Google Cloud API key bound to a service account](https://cloud.google.com/docs/authentication/api-keys)

For details, see the [Vertex AI Express mode documentation](https://docs.cloud.google.com/vertex-ai/generative-ai/docs/start/quickstart?usertype=apikey).
</Info>

### Gemini models

Google offers a range of powerful models optimized for different use cases.

@@ -513,7 +476,7 @@

<Accordion title="Local NVIDIA NIM Deployed using WSL2">

NVIDIA NIM lets you run powerful LLMs locally on Windows machines via WSL2 (Windows Subsystem for Linux).
NVIDIA NIM lets you run powerful LLMs locally on Windows machines via WSL2 (Windows Subsystem for Linux).
This approach leverages your NVIDIA GPU for private, secure, and cost-effective AI inference without relying on cloud services.
It's ideal for development, testing, or production environments that need data privacy or offline capabilities.

@@ -991,4 +954,4 @@ Learn how to get the most out of your LLM configuration:
llm = LLM(model="openai/gpt-4o") # 128K tokens
```
</Tab>
</Tabs>
</Tabs>
@@ -128,7 +128,7 @@ When deploying a Flow, consider the following:

### CrewAI Enterprise
The easiest way to deploy a Flow is with CrewAI Enterprise, which handles infrastructure, authentication, and monitoring for you.

To get started, check the [deployment guide](/ko/enterprise/guides/deploy-to-amp).
To get started, check the [deployment guide](/ko/enterprise/guides/deploy-crew).

```bash
crewai deploy create

@@ -91,7 +91,7 @@ Deploy quickly without Git — upload a ZIP package of your project

## Related Documentation

<CardGroup cols={3}>
  <Card title="Deploy Crew" href="/ko/enterprise/guides/deploy-to-amp" icon="rocket">
  <Card title="Deploy Crew" href="/ko/enterprise/guides/deploy-crew" icon="rocket">
    Deploy a crew from GitHub or a ZIP file
  </Card>
  <Card title="Automation Triggers" href="/ko/enterprise/guides/automation-triggers" icon="trigger">

@@ -79,7 +79,7 @@ Crew Studio lets you build automations from scratch with natural language and a visual workflow editor

  <Card title="Build Crew" href="/ko/enterprise/guides/build-crew" icon="paintbrush">
    Build your crew.
  </Card>
  <Card title="Deploy Crew" href="/ko/enterprise/guides/deploy-to-amp" icon="rocket">
  <Card title="Deploy Crew" href="/ko/enterprise/guides/deploy-crew" icon="rocket">
    Deploy a crew from GitHub or a ZIP file.
  </Card>
  <Card title="Export React Component" href="/ko/enterprise/guides/react-component-export" icon="download">
docs/ko/enterprise/guides/deploy-crew.mdx (new file, 305 lines)
@@ -0,0 +1,305 @@
|
||||
---
|
||||
title: "Crew 배포"
|
||||
description: "CrewAI 엔터프라이즈에서 Crew 배포하기"
|
||||
icon: "rocket"
|
||||
mode: "wide"
|
||||
---
|
||||
|
||||
<Note>
|
||||
로컬에서 또는 Crew Studio를 통해 crew를 생성한 후, 다음 단계는 이를 CrewAI AMP
|
||||
플랫폼에 배포하는 것입니다. 본 가이드에서는 다양한 배포 방법을 다루며,
|
||||
여러분의 워크플로우에 가장 적합한 방식을 선택할 수 있도록 안내합니다.
|
||||
</Note>
|
||||
|
||||
## 사전 준비 사항
|
||||
|
||||
<CardGroup cols={2}>
|
||||
<Card title="배포 준비가 된 Crew" icon="users">
|
||||
작동 중인 crew가 로컬에서 빌드되었거나 Crew Studio를 통해 생성되어 있어야
|
||||
합니다.
|
||||
</Card>
|
||||
<Card title="GitHub 저장소" icon="github">
|
||||
crew 코드가 GitHub 저장소에 있어야 합니다(GitHub 연동 방식의 경우).
|
||||
</Card>
|
||||
</CardGroup>
|
||||
|
||||
## 옵션 1: CrewAI CLI를 사용한 배포
|
||||
|
||||
CLI는 로컬에서 개발된 crew를 Enterprise 플랫폼에 가장 빠르게 배포할 수 있는 방법을 제공합니다.
|
||||
|
||||
<Steps>
|
||||
<Step title="CrewAI CLI 설치">
|
||||
아직 설치하지 않았다면 CrewAI CLI를 설치하세요:
|
||||
|
||||
```bash
|
||||
pip install crewai[tools]
|
||||
```
|
||||
|
||||
<Tip>
|
||||
CLI는 기본 CrewAI 패키지에 포함되어 있지만, `[tools]` 추가 옵션을 사용하면 모든 배포 종속성을 함께 설치할 수 있습니다.
|
||||
</Tip>
|
||||
|
||||
</Step>
|
||||
|
||||
<Step title="Enterprise 플랫폼에 인증">
|
||||
먼저, CrewAI AMP 플랫폼에 CLI를 인증해야 합니다:
|
||||
|
||||
```bash
|
||||
# 이미 CrewAI AMP 계정이 있거나 새로 생성하고 싶을 때:
|
||||
crewai login
|
||||
```
|
||||
|
||||
위 명령어를 실행하면 CLI가 다음을 진행합니다:
|
||||
1. URL과 고유 기기 코드를 표시합니다
|
||||
2. 브라우저를 열어 인증 페이지로 이동합니다
|
||||
3. 기기 확인을 요청합니다
|
||||
4. 인증 과정을 완료합니다
|
||||
|
||||
인증이 성공적으로 완료되면 터미널에 확인 메시지가 표시됩니다!
|
||||
|
||||
</Step>
|
||||
|
||||
<Step title="배포 생성">
|
||||
|
||||
프로젝트 디렉터리에서 다음 명령어를 실행하세요:
|
||||
|
||||
```bash
|
||||
crewai deploy create
|
||||
```
|
||||
|
||||
이 명령어는 다음을 수행합니다:
|
||||
1. GitHub 저장소 정보를 감지합니다
|
||||
2. 로컬 `.env` 파일의 환경 변수를 식별합니다
|
||||
3. 이러한 변수를 Enterprise 플랫폼으로 안전하게 전송합니다
|
||||
4. 고유 식별자가 부여된 새 배포를 만듭니다
|
||||
|
||||
성공적으로 생성되면 다음과 같은 메시지가 표시됩니다:
|
||||
```shell
|
||||
Deployment created successfully!
|
||||
Name: your_project_name
|
||||
Deployment ID: 01234567-89ab-cdef-0123-456789abcdef
|
||||
Current Status: Deploy Enqueued
|
||||
```
|
||||
|
||||
</Step>
|
||||
|
||||
<Step title="배포 진행 상황 모니터링">
|
||||
|
||||
다음 명령어로 배포 상태를 추적할 수 있습니다:
|
||||
|
||||
```bash
|
||||
crewai deploy status
|
||||
```
|
||||
|
||||
빌드 과정의 상세 로그가 필요하다면:
|
||||
|
||||
```bash
|
||||
crewai deploy logs
|
||||
```
|
||||
|
||||
<Tip>
|
||||
첫 배포는 컨테이너 이미지를 빌드하므로 일반적으로 10~15분 정도 소요됩니다. 이후 배포는 훨씬 빠릅니다.
|
||||
</Tip>
|
||||
|
||||
</Step>
|
||||
</Steps>
|
||||
|
||||
## 추가 CLI 명령어
|
||||
|
||||
CrewAI CLI는 배포를 관리하기 위한 여러 명령어를 제공합니다:
|
||||
|
||||
```bash
|
||||
# 모든 배포 목록 확인
|
||||
crewai deploy list
|
||||
|
||||
# 배포 상태 확인
|
||||
crewai deploy status
|
||||
|
||||
# 배포 로그 보기
|
||||
crewai deploy logs
|
||||
|
||||
# 코드 변경 후 업데이트 푸시
|
||||
crewai deploy push
|
||||
|
||||
# 배포 삭제
|
||||
crewai deploy remove <deployment_id>
|
||||
```
|
||||
|
||||
## 옵션 2: 웹 인터페이스를 통한 직접 배포
|
||||
|
||||
GitHub 계정을 연결하여 CrewAI AMP 웹 인터페이스를 통해 crews를 직접 배포할 수도 있습니다. 이 방법은 로컬 머신에서 CLI를 사용할 필요가 없습니다.
|
||||
|
||||
<Steps>
|
||||
|
||||
<Step title="GitHub로 푸시하기">
|
||||
|
||||
crew를 GitHub 저장소에 푸시해야 합니다. 아직 crew를 만들지 않았다면, [이 튜토리얼](/ko/quickstart)을 따라할 수 있습니다.
|
||||
|
||||
</Step>
|
||||
|
||||
<Step title="GitHub를 CrewAI AOP에 연결하기">
|
||||
|
||||
1. [CrewAI AMP](https://app.crewai.com)에 로그인합니다.
|
||||
2. "Connect GitHub" 버튼을 클릭합니다.
|
||||
|
||||
<Frame>
|
||||

|
||||
</Frame>
|
||||
|
||||
</Step>
|
||||
|
||||
<Step title="저장소 선택하기">
|
||||
|
||||
GitHub 계정을 연결한 후 배포할 저장소를 선택할 수 있습니다:
|
||||
|
||||
<Frame>
|
||||

|
||||
</Frame>
|
||||
|
||||
</Step>
|
||||
|
||||
<Step title="환경 변수 설정하기">
|
||||
|
||||
배포 전에, LLM 제공업체 또는 기타 서비스에 연결할 환경 변수를 설정해야 합니다:
|
||||
|
||||
1. 변수를 개별적으로 또는 일괄적으로 추가할 수 있습니다.
|
||||
2. 환경 변수는 `KEY=VALUE` 형식(한 줄에 하나씩)으로 입력합니다.
|
||||
|
||||
<Frame>
|
||||

|
||||
</Frame>

</Step>

<Step title="Crew 배포하기">

1. "Deploy" 버튼을 클릭하여 배포 프로세스를 시작합니다.
2. 진행 바를 통해 진행 상황을 모니터링할 수 있습니다.
3. 첫 번째 배포에는 일반적으로 약 10-15분 정도 소요되며, 이후 배포는 더 빠릅니다.

<Frame>

</Frame>

배포가 완료되면 다음을 확인할 수 있습니다:
- crew의 고유 URL
- crew API를 보호할 Bearer 토큰
- 배포를 삭제해야 하는 경우 "Delete" 버튼

</Step>

</Steps>

## ⚠️ 환경 변수 보안 요구사항

<Warning>
**중요**: CrewAI AMP는 환경 변수 이름에 대한 보안 제한이 있으며, 이를 따르지
않을 경우 배포가 실패할 수 있습니다.
</Warning>

### 차단된 환경 변수 패턴

보안상의 이유로, 다음과 같은 환경 변수 명명 패턴은 **자동으로 필터링**되며 배포에 문제가 발생할 수 있습니다:

**차단된 패턴:**

- `_TOKEN`으로 끝나는 변수 (예: `MY_API_TOKEN`)
- `_PASSWORD`로 끝나는 변수 (예: `DB_PASSWORD`)
- `_SECRET`로 끝나는 변수 (예: `API_SECRET`)
- 특정 상황에서 `_KEY`로 끝나는 변수

**특정 차단 변수:**

- `GITHUB_USER`, `GITHUB_TOKEN`
- `AWS_REGION`, `AWS_DEFAULT_REGION`
- 다양한 내부 CrewAI 시스템 변수

### 허용된 예외

일부 변수는 차단된 패턴과 일치하더라도 명시적으로 허용됩니다:

- `AZURE_AD_TOKEN`
- `AZURE_OPENAI_AD_TOKEN`
- `ENTERPRISE_ACTION_TOKEN`
- `CREWAI_ENTEPRISE_TOOLS_TOKEN`

### 네이밍 문제 해결 방법

환경 변수 제한으로 인해 배포가 실패하는 경우:

```bash
# ❌ 이러한 이름은 배포 실패를 초래합니다
OPENAI_TOKEN=sk-...
DATABASE_PASSWORD=mypassword
API_SECRET=secret123

# ✅ 대신 다음과 같은 네이밍 패턴을 사용하세요
OPENAI_API_KEY=sk-...
DATABASE_CREDENTIALS=mypassword
API_CONFIG=secret123
```

### 모범 사례

1. **표준 명명 규칙 사용**: `PROVIDER_TOKEN` 대신 `PROVIDER_API_KEY` 사용
2. **먼저 로컬에서 테스트**: crew가 이름이 변경된 변수로 제대로 동작하는지 확인
3. **코드 업데이트**: 이전 변수 이름을 참조하는 부분을 모두 변경
4. **변경 내용 문서화**: 팀을 위해 이름이 변경된 변수를 기록

<Tip>
배포 실패 시, 환경 변수 에러 메시지가 난해하다면 먼저 변수 이름이 이 패턴을
따르는지 확인하세요.
</Tip>

### 배포된 Crew와 상호작용하기

배포가 완료되면 다음을 통해 crew에 접근할 수 있습니다:

1. **REST API**: 플랫폼에서 아래의 주요 경로가 포함된 고유한 HTTPS 엔드포인트를 생성합니다(이 목록 뒤의 호출 예시 참조):

   - `/inputs`: 필요한 입력 파라미터 목록
   - `/kickoff`: 제공된 입력값으로 실행 시작
   - `/status/{kickoff_id}`: 실행 상태 확인

2. **웹 인터페이스**: [app.crewai.com](https://app.crewai.com)에 방문하여 다음을 확인할 수 있습니다:
   - **Status 탭**: 배포 정보, API 엔드포인트 세부 정보 및 인증 토큰 확인
   - **Run 탭**: crew 구조의 시각적 표현
   - **Executions 탭**: 모든 실행 내역
   - **Metrics 탭**: 성능 분석
   - **Traces 탭**: 상세 실행 인사이트
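
예를 들어, 배포 시 발급받은 URL과 Bearer 토큰으로 REST API를 호출할 수 있습니다. 아래 URL·토큰·입력값은 가상의 예시이며, 실제 요청에 필요한 입력은 `/inputs` 응답으로 확인하세요:

```bash
# 필요한 입력 파라미터 확인
curl -H "Authorization: Bearer YOUR_BEARER_TOKEN" \
  https://your-crew-deployment.crewai.com/inputs

# 실행 시작 (응답으로 kickoff_id가 반환됩니다)
curl -X POST \
  -H "Authorization: Bearer YOUR_BEARER_TOKEN" \
  -H "Content-Type: application/json" \
  -d '{"inputs": {"topic": "AI in Healthcare"}}' \
  https://your-crew-deployment.crewai.com/kickoff

# 실행 상태 확인
curl -H "Authorization: Bearer YOUR_BEARER_TOKEN" \
  https://your-crew-deployment.crewai.com/status/YOUR_KICKOFF_ID
```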

### 실행 트리거하기

Enterprise 대시보드에서 다음 작업을 수행할 수 있습니다:

1. crew 이름을 클릭하여 상세 정보를 엽니다
2. 관리 인터페이스에서 "Trigger Crew"를 선택합니다
3. 나타나는 모달에 필요한 입력값을 입력합니다
4. 파이프라인을 따라 실행의 진행 상황을 모니터링합니다

### 모니터링 및 분석

Enterprise 플랫폼은 포괄적인 가시성 기능을 제공합니다:

- **실행 관리**: 활성 및 완료된 실행 추적
- **트레이스**: 각 실행의 상세 분해
- **메트릭**: 토큰 사용량, 실행 시간, 비용
- **타임라인 보기**: 작업 시퀀스의 시각적 표현

### 고급 기능

Enterprise 플랫폼은 또한 다음을 제공합니다:

- **환경 변수 관리**: API 키를 안전하게 저장 및 관리
- **LLM 연결**: 다양한 LLM 공급자와의 통합 구성
- **Custom Tools Repository**: 도구 생성, 공유 및 설치
- **Crew Studio**: 코드를 작성하지 않고 채팅 인터페이스를 통해 crew 빌드

<Card
  title="도움이 필요하신가요?"
  icon="headset"
  href="mailto:support@crewai.com"
>
  Enterprise 플랫폼의 배포 문제 또는 문의 사항이 있으시면 지원팀에 연락해
  주십시오.
</Card>
@@ -1,438 +0,0 @@
---
title: "AMP에 배포하기"
description: "Crew 또는 Flow를 CrewAI AMP에 배포하기"
icon: "rocket"
mode: "wide"
---

<Note>
로컬에서 또는 Crew Studio를 통해 Crew나 Flow를 생성한 후, 다음 단계는 이를 CrewAI AMP
플랫폼에 배포하는 것입니다. 본 가이드에서는 다양한 배포 방법을 다루며,
여러분의 워크플로우에 가장 적합한 방식을 선택할 수 있도록 안내합니다.
</Note>

## 사전 준비 사항

<CardGroup cols={2}>
  <Card title="배포 준비가 완료된 프로젝트" icon="check-circle">
    로컬에서 성공적으로 실행되는 Crew 또는 Flow가 있어야 합니다.
    [배포 준비 가이드](/ko/enterprise/guides/prepare-for-deployment)를 따라 프로젝트 구조를 확인하세요.
  </Card>
  <Card title="GitHub 저장소" icon="github">
    코드가 GitHub 저장소에 있어야 합니다(GitHub 연동 방식의 경우).
  </Card>
</CardGroup>

<Info>
**Crews vs Flows**: 두 프로젝트 유형 모두 CrewAI AMP에서 "자동화"로 배포할 수 있습니다.
배포 과정은 동일하지만, 프로젝트 구조가 다릅니다.
자세한 내용은 [배포 준비하기](/ko/enterprise/guides/prepare-for-deployment)를 참조하세요.
</Info>

## 옵션 1: CrewAI CLI를 사용한 배포

CLI는 로컬에서 개발된 Crew 또는 Flow를 AMP 플랫폼에 가장 빠르게 배포할 수 있는 방법을 제공합니다.
CLI는 `pyproject.toml`에서 프로젝트 유형을 자동으로 감지하고 그에 맞게 빌드합니다.

<Steps>
<Step title="CrewAI CLI 설치">
아직 설치하지 않았다면 CrewAI CLI를 설치하세요:

```bash
pip install crewai[tools]
```

<Tip>
CLI는 기본 CrewAI 패키지에 포함되어 있지만, `[tools]` 추가 옵션을 사용하면 모든 배포 종속성을 함께 설치할 수 있습니다.
</Tip>

</Step>

<Step title="Enterprise 플랫폼에 인증">
먼저, CrewAI AMP 플랫폼에 CLI를 인증해야 합니다:

```bash
# 이미 CrewAI AMP 계정이 있거나 새로 생성하고 싶을 때:
crewai login
```

위 명령어를 실행하면 CLI가 다음을 진행합니다:
1. URL과 고유 기기 코드를 표시합니다
2. 브라우저를 열어 인증 페이지로 이동합니다
3. 기기 확인을 요청합니다
4. 인증 과정을 완료합니다

인증이 성공적으로 완료되면 터미널에 확인 메시지가 표시됩니다!

</Step>

<Step title="배포 생성">

프로젝트 디렉터리에서 다음 명령어를 실행하세요:

```bash
crewai deploy create
```

이 명령어는 다음을 수행합니다:
1. GitHub 저장소 정보를 감지합니다
2. 로컬 `.env` 파일의 환경 변수를 식별합니다
3. 이러한 변수를 Enterprise 플랫폼으로 안전하게 전송합니다
4. 고유 식별자가 부여된 새 배포를 만듭니다

성공적으로 생성되면 다음과 같은 메시지가 표시됩니다:
```shell
Deployment created successfully!
Name: your_project_name
Deployment ID: 01234567-89ab-cdef-0123-456789abcdef
Current Status: Deploy Enqueued
```

</Step>

<Step title="배포 진행 상황 모니터링">

다음 명령어로 배포 상태를 추적할 수 있습니다:

```bash
crewai deploy status
```

빌드 과정의 상세 로그가 필요하다면:

```bash
crewai deploy logs
```

<Tip>
첫 배포는 컨테이너 이미지를 빌드하므로 일반적으로 10~15분 정도 소요됩니다. 이후 배포는 훨씬 빠릅니다.
</Tip>

</Step>
</Steps>

## 추가 CLI 명령어

CrewAI CLI는 배포를 관리하기 위한 여러 명령어를 제공합니다:

```bash
# 모든 배포 목록 확인
crewai deploy list

# 배포 상태 확인
crewai deploy status

# 배포 로그 보기
crewai deploy logs

# 코드 변경 후 업데이트 푸시
crewai deploy push

# 배포 삭제
crewai deploy remove <deployment_id>
```

## 옵션 2: 웹 인터페이스를 통한 직접 배포

GitHub 계정을 연결하여 CrewAI AMP 웹 인터페이스를 통해 Crew 또는 Flow를 직접 배포할 수도 있습니다. 이 방법은 로컬 머신에서 CLI를 사용할 필요가 없습니다. 플랫폼은 자동으로 프로젝트 유형을 감지하고 적절하게 빌드를 처리합니다.

<Steps>

<Step title="GitHub로 푸시하기">

Crew를 GitHub 저장소에 푸시해야 합니다. 아직 Crew를 만들지 않았다면, [이 튜토리얼](/ko/quickstart)을 따라할 수 있습니다.

</Step>

<Step title="GitHub를 CrewAI AMP에 연결하기">

1. [CrewAI AMP](https://app.crewai.com)에 로그인합니다.
2. "Connect GitHub" 버튼을 클릭합니다.

<Frame>

</Frame>

</Step>

<Step title="저장소 선택하기">

GitHub 계정을 연결한 후 배포할 저장소를 선택할 수 있습니다:

<Frame>

</Frame>

</Step>

<Step title="환경 변수 설정하기">

배포 전에, LLM 제공업체 또는 기타 서비스에 연결할 환경 변수를 설정해야 합니다:

1. 변수를 개별적으로 또는 일괄적으로 추가할 수 있습니다.
2. 환경 변수는 `KEY=VALUE` 형식(한 줄에 하나씩)으로 입력합니다.

<Frame>

</Frame>

</Step>

<Step title="Crew 배포하기">

1. "Deploy" 버튼을 클릭하여 배포 프로세스를 시작합니다.
2. 진행 바를 통해 진행 상황을 모니터링할 수 있습니다.
3. 첫 번째 배포에는 일반적으로 약 10-15분 정도 소요되며, 이후 배포는 더 빠릅니다.

<Frame>

</Frame>

배포가 완료되면 다음을 확인할 수 있습니다:
- Crew의 고유 URL
- Crew API를 보호할 Bearer 토큰
- 배포를 삭제해야 하는 경우 "Delete" 버튼

</Step>

</Steps>

## 옵션 3: API를 통한 재배포 (CI/CD 통합)

CI/CD 파이프라인에서 자동화된 배포를 위해 CrewAI API를 사용하여 기존 crew의 재배포를 트리거할 수 있습니다. 이 방법은 GitHub Actions, Jenkins 또는 기타 자동화 워크플로우에 특히 유용합니다.

<Steps>
<Step title="개인 액세스 토큰 발급">

CrewAI AMP 계정 설정에서 API 토큰을 생성합니다:

1. [app.crewai.com](https://app.crewai.com)으로 이동합니다
2. **Settings** → **Account** → **Personal Access Token**을 클릭합니다
3. 새 토큰을 생성하고 안전하게 복사합니다
4. 이 토큰을 CI/CD 시스템의 시크릿으로 저장합니다

</Step>

<Step title="Automation UUID 찾기">

배포된 crew의 고유 식별자를 찾습니다:

1. CrewAI AMP 대시보드에서 **Automations**로 이동합니다
2. 기존 automation/crew를 선택합니다
3. **Additional Details**를 클릭합니다
4. **UUID**를 복사합니다 - 이것이 특정 crew 배포를 식별합니다

</Step>

<Step title="API를 통한 재배포 트리거">

Deploy API 엔드포인트를 사용하여 재배포를 트리거합니다:

```bash
curl -i -X POST \
  -H "Authorization: Bearer YOUR_PERSONAL_ACCESS_TOKEN" \
  https://app.crewai.com/crewai_plus/api/v1/crews/YOUR-AUTOMATION-UUID/deploy

# HTTP/2 200
# content-type: application/json
#
# {
#   "uuid": "your-automation-uuid",
#   "status": "Deploy Enqueued",
#   "public_url": "https://your-crew-deployment.crewai.com",
#   "token": "your-bearer-token"
# }
```

<Info>
Git에 연결되어 처음 생성된 automation의 경우, API가 재배포 전에 자동으로 저장소에서 최신 변경 사항을 가져옵니다.
</Info>

</Step>

<Step title="GitHub Actions 통합 예시">

더 복잡한 배포 트리거가 있는 GitHub Actions 워크플로우 예시입니다:

```yaml
name: Deploy CrewAI Automation

on:
  push:
    branches: [ main ]
  pull_request:
    types: [ labeled ]
  release:
    types: [ published ]

jobs:
  deploy:
    runs-on: ubuntu-latest
    if: |
      (github.event_name == 'push' && github.ref == 'refs/heads/main') ||
      (github.event_name == 'pull_request' && contains(github.event.pull_request.labels.*.name, 'deploy')) ||
      (github.event_name == 'release')
    steps:
      - name: Trigger CrewAI Redeployment
        run: |
          curl -X POST \
            -H "Authorization: Bearer ${{ secrets.CREWAI_PAT }}" \
            https://app.crewai.com/crewai_plus/api/v1/crews/${{ secrets.CREWAI_AUTOMATION_UUID }}/deploy
```

<Tip>
`CREWAI_PAT`와 `CREWAI_AUTOMATION_UUID`를 저장소 시크릿으로 추가하세요. PR 배포의 경우 "deploy" 라벨을 추가하여 워크플로우를 트리거합니다.
</Tip>

</Step>

</Steps>

## 배포된 Automation과 상호작용하기

배포가 완료되면 다음을 통해 crew에 접근할 수 있습니다:

1. **REST API**: 플랫폼에서 아래의 주요 경로가 포함된 고유한 HTTPS 엔드포인트를 생성합니다:

   - `/inputs`: 필요한 입력 파라미터 목록
   - `/kickoff`: 제공된 입력값으로 실행 시작
   - `/status/{kickoff_id}`: 실행 상태 확인

2. **웹 인터페이스**: [app.crewai.com](https://app.crewai.com)에 방문하여 다음을 확인할 수 있습니다:
   - **Status 탭**: 배포 정보, API 엔드포인트 세부 정보 및 인증 토큰 확인
   - **Run 탭**: Crew 구조의 시각적 표현
   - **Executions 탭**: 모든 실행 내역
   - **Metrics 탭**: 성능 분석
   - **Traces 탭**: 상세 실행 인사이트

### 실행 트리거하기

Enterprise 대시보드에서 다음 작업을 수행할 수 있습니다:

1. Crew 이름을 클릭하여 상세 정보를 엽니다
2. 관리 인터페이스에서 "Trigger Crew"를 선택합니다
3. 나타나는 모달에 필요한 입력값을 입력합니다
4. 파이프라인을 따라 실행의 진행 상황을 모니터링합니다

### 모니터링 및 분석

Enterprise 플랫폼은 포괄적인 가시성 기능을 제공합니다:

- **실행 관리**: 활성 및 완료된 실행 추적
- **트레이스**: 각 실행의 상세 분해
- **메트릭**: 토큰 사용량, 실행 시간, 비용
- **타임라인 보기**: 작업 시퀀스의 시각적 표현

### 고급 기능

Enterprise 플랫폼은 또한 다음을 제공합니다:

- **환경 변수 관리**: API 키를 안전하게 저장 및 관리
- **LLM 연결**: 다양한 LLM 공급자와의 통합 구성
- **Custom Tools Repository**: 도구 생성, 공유 및 설치
- **Crew Studio**: 코드를 작성하지 않고 채팅 인터페이스를 통해 crew 빌드

## 배포 실패 문제 해결

배포가 실패하면 다음과 같은 일반적인 문제를 확인하세요:

### 빌드 실패

#### uv.lock 파일 누락

**증상**: 의존성 해결 오류와 함께 빌드 초기에 실패

**해결책**: lock 파일을 생성하고 커밋합니다:

```bash
uv lock
git add uv.lock
git commit -m "Add uv.lock for deployment"
git push
```

<Warning>
`uv.lock` 파일은 모든 배포에 필수입니다. 이 파일이 없으면 플랫폼에서
의존성을 안정적으로 설치할 수 없습니다.
</Warning>

#### 잘못된 프로젝트 구조

**증상**: "Could not find entry point" 또는 "Module not found" 오류

**해결책**: 프로젝트가 예상 구조와 일치하는지 확인합니다:

- **Crews와 Flows 모두**: 진입점이 `src/project_name/main.py`에 있어야 합니다
- **Crews**: 진입점으로 `run()` 함수 사용
- **Flows**: 진입점으로 `kickoff()` 함수 사용

자세한 구조 다이어그램은 [배포 준비하기](/ko/enterprise/guides/prepare-for-deployment)를 참조하세요.

#### CrewBase 데코레이터 누락

**증상**: "Crew not found", "Config not found" 또는 agent/task 구성 오류

**해결책**: **모든** crew 클래스가 `@CrewBase` 데코레이터를 사용하는지 확인합니다:

```python
from crewai.project import CrewBase, agent, crew, task

@CrewBase  # 이 데코레이터는 필수입니다
class YourCrew():
    """Crew 설명"""

    @agent
    def my_agent(self) -> Agent:
        return Agent(
            config=self.agents_config['my_agent'],  # type: ignore[index]
            verbose=True
        )

    # ... 나머지 crew 정의
```

<Info>
이것은 독립 실행형 Crews와 Flow 프로젝트 내에 포함된 crews 모두에 적용됩니다.
모든 crew 클래스에 데코레이터가 필요합니다.
</Info>

#### 잘못된 pyproject.toml 타입

**증상**: 빌드는 성공하지만 런타임에서 실패하거나 예상치 못한 동작

**해결책**: `[tool.crewai]` 섹션이 프로젝트 유형과 일치하는지 확인합니다:

```toml
# Crew 프로젝트의 경우:
[tool.crewai]
type = "crew"

# Flow 프로젝트의 경우:
[tool.crewai]
type = "flow"
```

### 런타임 실패

#### LLM 연결 실패

**증상**: API 키 오류, "model not found" 또는 인증 실패

**해결책**:
1. LLM 제공업체의 API 키가 환경 변수에 올바르게 설정되어 있는지 확인합니다
2. 환경 변수 이름이 코드에서 예상하는 것과 일치하는지 확인합니다
3. 배포 전에 동일한 환경 변수로 로컬에서 테스트합니다(아래 예시 참조)
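
예를 들어(키 이름과 값은 가상의 예시입니다), 배포에 사용할 값과 동일한 값으로 로컬에서 실행해 볼 수 있습니다:

```bash
# .env에 넣을 값과 동일한 값으로 환경 변수를 설정한 뒤 로컬에서 실행
export OPENAI_API_KEY=sk-your-key
crewai run
```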

#### Crew 실행 오류

**증상**: Crew가 시작되지만 실행 중에 실패

**해결책**:
1. AMP 대시보드에서 실행 로그를 확인합니다 (Traces 탭)
2. 모든 도구에 필요한 API 키가 구성되어 있는지 확인합니다
3. `agents.yaml`의 agent 구성이 유효한지 확인합니다
4. `tasks.yaml`의 task 구성에 구문 오류가 없는지 확인합니다

<Card title="도움이 필요하신가요?" icon="headset" href="mailto:support@crewai.com">
배포 문제 또는 AMP 플랫폼에 대한 문의 사항이 있으시면 지원팀에 연락해 주세요.
</Card>
@@ -1,305 +0,0 @@
---
title: "배포 준비하기"
description: "Crew 또는 Flow가 CrewAI AMP에 배포될 준비가 되었는지 확인하기"
icon: "clipboard-check"
mode: "wide"
---

<Note>
CrewAI AMP에 배포하기 전에, 프로젝트가 올바르게 구성되어 있는지 확인하는 것이 중요합니다.
Crews와 Flows 모두 "자동화"로 배포할 수 있지만, 성공적인 배포를 위해 충족해야 하는
서로 다른 프로젝트 구조와 요구 사항이 있습니다.
</Note>

## 자동화 이해하기

CrewAI AMP에서 **자동화(automations)**는 배포 가능한 Agentic AI 프로젝트의 총칭입니다. 자동화는 다음 중 하나일 수 있습니다:

- **Crew**: 작업을 함께 수행하는 AI 에이전트들의 독립 실행형 팀
- **Flow**: 여러 crew, 직접 LLM 호출 및 절차적 로직을 결합할 수 있는 오케스트레이션된 워크플로우

배포하는 유형을 이해하는 것은 프로젝트 구조와 진입점이 다르기 때문에 필수적입니다.

## Crews vs Flows: 주요 차이점

<CardGroup cols={2}>
  <Card title="Crew 프로젝트" icon="users">
    에이전트와 작업을 정의하는 `crew.py`가 있는 독립 실행형 AI 에이전트 팀. 집중적이고 협업적인 작업에 적합합니다.
  </Card>
  <Card title="Flow 프로젝트" icon="diagram-project">
    `crews/` 폴더에 포함된 crew가 있는 오케스트레이션된 워크플로우. 복잡한 다단계 프로세스에 적합합니다.
  </Card>
</CardGroup>

| 측면 | Crew | Flow |
|------|------|------|
| **프로젝트 구조** | `crew.py`가 있는 `src/project_name/` | `crews/` 폴더가 있는 `src/project_name/` |
| **메인 로직 위치** | `src/project_name/crew.py` | `src/project_name/main.py` (Flow 클래스) |
| **진입점 함수** | `main.py`의 `run()` | `main.py`의 `kickoff()` |
| **pyproject.toml 타입** | `type = "crew"` | `type = "flow"` |
| **CLI 생성 명령어** | `crewai create crew name` | `crewai create flow name` |
| **설정 위치** | `src/project_name/config/` | `src/project_name/crews/crew_name/config/` |
| **다른 crew 포함 가능** | 아니오 | 예 (`crews/` 폴더 내) |

## 프로젝트 구조 참조

### Crew 프로젝트 구조

`crewai create crew my_crew`를 실행하면 다음 구조를 얻습니다:

```
my_crew/
├── .gitignore
├── pyproject.toml      # type = "crew"여야 함
├── README.md
├── .env
├── uv.lock             # 배포에 필수
└── src/
    └── my_crew/
        ├── __init__.py
        ├── main.py     # run() 함수가 있는 진입점
        ├── crew.py     # @CrewBase 데코레이터가 있는 Crew 클래스
        ├── tools/
        │   ├── custom_tool.py
        │   └── __init__.py
        └── config/
            ├── agents.yaml   # 에이전트 정의
            └── tasks.yaml    # 작업 정의
```

<Warning>
중첩된 `src/project_name/` 구조는 Crews에 매우 중요합니다.
잘못된 레벨에 파일을 배치하면 배포 실패의 원인이 됩니다.
</Warning>

### Flow 프로젝트 구조

`crewai create flow my_flow`를 실행하면 다음 구조를 얻습니다:

```
my_flow/
├── .gitignore
├── pyproject.toml      # type = "flow"여야 함
├── README.md
├── .env
├── uv.lock             # 배포에 필수
└── src/
    └── my_flow/
        ├── __init__.py
        ├── main.py     # kickoff() 함수 + Flow 클래스가 있는 진입점
        ├── crews/      # 포함된 crews 폴더
        │   └── poem_crew/
        │       ├── __init__.py
        │       ├── poem_crew.py   # @CrewBase 데코레이터가 있는 Crew
        │       └── config/
        │           ├── agents.yaml
        │           └── tasks.yaml
        └── tools/
            ├── __init__.py
            └── custom_tool.py
```

<Info>
Crews와 Flows 모두 `src/project_name/` 구조를 사용합니다.
핵심 차이점은 Flows는 포함된 crews를 위한 `crews/` 폴더가 있고,
Crews는 프로젝트 폴더에 직접 `crew.py`가 있다는 것입니다.
</Info>

## 배포 전 체크리스트

이 체크리스트를 사용하여 프로젝트가 배포 준비가 되었는지 확인하세요.

### 1. pyproject.toml 설정 확인

`pyproject.toml`에 올바른 `[tool.crewai]` 섹션이 포함되어야 합니다:

<Tabs>
<Tab title="Crews의 경우">
```toml
[tool.crewai]
type = "crew"
```
</Tab>
<Tab title="Flows의 경우">
```toml
[tool.crewai]
type = "flow"
```
</Tab>
</Tabs>

<Warning>
`type`이 프로젝트 구조와 일치하지 않으면 빌드가 실패하거나
자동화가 올바르게 실행되지 않습니다.
</Warning>

### 2. uv.lock 파일 존재 확인

CrewAI는 의존성 관리를 위해 `uv`를 사용합니다. `uv.lock` 파일은 재현 가능한 빌드를 보장하며 배포에 **필수**입니다.

```bash
# lock 파일 생성 또는 업데이트
uv lock

# 존재 여부 확인
ls -la uv.lock
```

파일이 존재하지 않으면 `uv lock`을 실행하고 저장소에 커밋하세요:

```bash
uv lock
git add uv.lock
git commit -m "Add uv.lock for deployment"
git push
```

### 3. CrewBase 데코레이터 사용 확인

**모든 crew 클래스는 `@CrewBase` 데코레이터를 사용해야 합니다.** 이것은 다음에 적용됩니다:

- 독립 실행형 crew 프로젝트
- Flow 프로젝트 내에 포함된 crews

```python
from crewai import Agent, Crew, Process, Task
from crewai.project import CrewBase, agent, crew, task
from crewai.agents.agent_builder.base_agent import BaseAgent
from typing import List

@CrewBase  # 이 데코레이터는 필수입니다
class MyCrew():
    """내 crew 설명"""

    agents: List[BaseAgent]
    tasks: List[Task]

    @agent
    def my_agent(self) -> Agent:
        return Agent(
            config=self.agents_config['my_agent'],  # type: ignore[index]
            verbose=True
        )

    @task
    def my_task(self) -> Task:
        return Task(
            config=self.tasks_config['my_task']  # type: ignore[index]
        )

    @crew
    def crew(self) -> Crew:
        return Crew(
            agents=self.agents,
            tasks=self.tasks,
            process=Process.sequential,
            verbose=True,
        )
```

<Warning>
`@CrewBase` 데코레이터를 잊으면 에이전트나 작업 구성이 누락되었다는
오류와 함께 배포가 실패합니다.
</Warning>

### 4. 프로젝트 진입점 확인

Crews와 Flows 모두 `src/project_name/main.py`에 진입점이 있습니다:

<Tabs>
<Tab title="Crews의 경우">
진입점은 `run()` 함수를 사용합니다:

```python
# src/my_crew/main.py
from my_crew.crew import MyCrew

def run():
    """crew를 실행합니다."""
    inputs = {'topic': 'AI in Healthcare'}
    result = MyCrew().crew().kickoff(inputs=inputs)
    return result

if __name__ == "__main__":
    run()
```
</Tab>
<Tab title="Flows의 경우">
진입점은 Flow 클래스와 함께 `kickoff()` 함수를 사용합니다:

```python
# src/my_flow/main.py
from crewai.flow import Flow, listen, start
from my_flow.crews.poem_crew.poem_crew import PoemCrew

class MyFlow(Flow):
    @start()
    def begin(self):
        # Flow 로직
        result = PoemCrew().crew().kickoff(inputs={...})
        return result

def kickoff():
    """flow를 실행합니다."""
    MyFlow().kickoff()

if __name__ == "__main__":
    kickoff()
```
</Tab>
</Tabs>

### 5. 환경 변수 준비

배포 전에 다음을 준비해야 합니다:

1. **LLM API 키** (OpenAI, Anthropic, Google 등)
2. **도구 API 키** - 외부 도구를 사용하는 경우 (Serper 등)

<Tip>
구성 문제를 조기에 발견하기 위해 배포 전에 동일한 환경 변수로
로컬에서 프로젝트를 테스트하세요.
</Tip>
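
예를 들어, 최소한의 `.env` 파일은 다음과 같을 수 있습니다(키 이름은 사용하는 제공업체와 도구에 따라 달라집니다):

```bash
OPENAI_API_KEY=sk-your-key
SERPER_API_KEY=your-serper-key
```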

## 빠른 검증 명령어

프로젝트 루트에서 다음 명령어를 실행하여 설정을 빠르게 확인하세요:

```bash
# 1. pyproject.toml에서 프로젝트 타입 확인
grep -A2 "\[tool.crewai\]" pyproject.toml

# 2. uv.lock 존재 확인
ls -la uv.lock || echo "오류: uv.lock이 없습니다! 'uv lock'을 실행하세요"

# 3. src/ 구조 존재 확인
ls -la src/*/main.py 2>/dev/null || echo "src/에서 main.py를 찾을 수 없습니다"

# 4. Crews의 경우 - crew.py 존재 확인
ls -la src/*/crew.py 2>/dev/null || echo "crew.py가 없습니다 (Crews에서 예상됨)"

# 5. Flows의 경우 - crews/ 폴더 존재 확인
ls -la src/*/crews/ 2>/dev/null || echo "crews/ 폴더가 없습니다 (Flows에서 예상됨)"

# 6. CrewBase 사용 확인
grep -r "@CrewBase" . --include="*.py"
```

## 일반적인 설정 실수

| 실수 | 증상 | 해결 방법 |
|------|------|----------|
| `uv.lock` 누락 | 의존성 해결 중 빌드 실패 | `uv lock` 실행 후 커밋 |
| pyproject.toml의 잘못된 `type` | 빌드 성공하지만 런타임 실패 | 올바른 타입으로 변경 |
| `@CrewBase` 데코레이터 누락 | "Config not found" 오류 | 모든 crew 클래스에 데코레이터 추가 |
| `src/` 대신 루트에 파일 배치 | 진입점을 찾을 수 없음 | `src/project_name/`으로 이동 |
| `run()` 또는 `kickoff()` 누락 | 자동화를 시작할 수 없음 | 올바른 진입 함수 추가 |

## 다음 단계

프로젝트가 모든 체크리스트 항목을 통과하면 배포할 준비가 된 것입니다:

<Card title="AMP에 배포하기" icon="rocket" href="/ko/enterprise/guides/deploy-to-amp">
CLI, 웹 인터페이스 또는 CI/CD 통합을 사용하여 Crew 또는 Flow를 CrewAI AMP에
배포하려면 배포 가이드를 따르세요.
</Card>

@@ -79,7 +79,7 @@ CrewAI AOP는 오픈 소스 프레임워크의 강력함에 프로덕션 배포,

<Card
  title="Crew 배포"
  icon="rocket"
-  href="/ko/enterprise/guides/deploy-to-amp"
+  href="/ko/enterprise/guides/deploy-crew"
>
  Crew 배포
</Card>

@@ -96,4 +96,4 @@ CrewAI AOP는 오픈 소스 프레임워크의 강력함에 프로덕션 배포,

</Step>
</Steps>

-자세한 안내를 원하시면 [배포 가이드](/ko/enterprise/guides/deploy-to-amp)를 확인하거나 아래 버튼을 클릭해 시작하세요.
+자세한 안내를 원하시면 [배포 가이드](/ko/enterprise/guides/deploy-crew)를 확인하거나 아래 버튼을 클릭해 시작하세요.

@@ -1,115 +0,0 @@
---
title: Galileo
description: CrewAI 추적 및 평가를 위한 Galileo 통합
icon: telescope
mode: "wide"
---

## 개요

이 가이드는 포괄적인 추적과 평가 엔지니어링을 위해 **Galileo**를 **CrewAI**와 통합하는 방법을 보여줍니다. 가이드를 마치면 Galileo의 강력한 관측 플랫폼으로 CrewAI 에이전트를 추적하고, 성능을 모니터링하고, 동작을 평가할 수 있게 됩니다.

> **Galileo란 무엇인가요?** [Galileo](https://galileo.ai/)는 AI 애플리케이션의 엔드투엔드 추적, 평가, 모니터링을 제공하는 AI 평가 및 관측 플랫폼입니다. 팀이 실제 동작을 포착하고, 견고한 가드레일을 만들고, 체계적인 실험을 실행할 수 있게 해 주며, 내장된 실험 추적과 성능 분석을 통해 AI 수명주기 전반의 투명성과 지속적인 개선을 지원합니다.

## 시작하기

이 튜토리얼은 [CrewAI 빠른 시작](/ko/quickstart.mdx)을 기반으로 Galileo의 [CrewAIEventListener](https://v2docs.galileo.ai/sdk-api/python/reference/handlers/crewai/handler) 이벤트 핸들러를 추가하는 방법을 보여줍니다. 자세한 내용은 Galileo 문서의 [CrewAI 애플리케이션에 Galileo 추가](https://v2docs.galileo.ai/how-to-guides/third-party-integrations/add-galileo-to-crewai/add-galileo-to-crewai) 가이드를 참고하세요.

> **참고** 이 튜토리얼은 [CrewAI 빠른 시작](/ko/quickstart.mdx)을 완료했다고 가정합니다. 완전한 예제는 Galileo의 [CrewAI SDK 예제 저장소](https://github.com/rungalileo/sdk-examples/tree/main/python/agent/crew-ai)를 참고하세요.

### 1단계: 종속성 설치

앱에 필요한 종속성을 설치합니다. 원하는 방법으로 가상 환경을 생성한 다음, 선호하는 도구를 사용해 해당 환경에 종속성을 설치하세요:

```bash
uv add galileo
```

### 2단계: [CrewAI 빠른 시작](/ko/quickstart.mdx)에서 .env 파일에 추가

```bash
# Your Galileo API key
GALILEO_API_KEY="your-galileo-api-key"

# Your Galileo project name
GALILEO_PROJECT="your-galileo-project-name"

# The name of the Log stream you want to use for logging
GALILEO_LOG_STREAM="your-galileo-log-stream"
```

### 3단계: Galileo 이벤트 리스너 추가

Galileo로 로깅을 활성화하려면 `CrewAIEventListener` 인스턴스를 생성해야 합니다. 먼저 `main.py` 파일 상단에 다음 코드를 추가해 Galileo CrewAI 핸들러 패키지를 가져오세요:

```python
from galileo.handlers.crewai.handler import CrewAIEventListener
```

그런 다음 실행 함수 시작 부분에서 이벤트 리스너를 생성합니다:

```python
def run():
    # Create the event listener
    CrewAIEventListener()
    # The rest of your existing code goes here
```

리스너 인스턴스를 생성하면 CrewAI에 자동으로 등록됩니다.
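
다음은 빠른 시작 프로젝트를 기준으로 한 최소 스케치입니다(프로젝트/클래스 이름과 입력값은 가정한 예시입니다):

```python
from galileo.handlers.crewai.handler import CrewAIEventListener

from my_project.crew import MyProjectCrew  # 가정: 빠른 시작에서 만든 crew


def run():
    # 이벤트 리스너 생성 - 생성 즉시 CrewAI에 자동 등록됩니다
    CrewAIEventListener()
    # 기존 crew 실행 코드는 그대로 둡니다
    MyProjectCrew().crew().kickoff(inputs={"topic": "AI LLMs"})
```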

### 4단계: Crew Agent 실행

CrewAI CLI를 사용하여 Crew Agent를 실행하세요:

```bash
crewai run
```

### 5단계: Galileo에서 추적 보기

Crew 에이전트 실행이 완료되면 트레이스가 플러시되어 Galileo에 표시됩니다.



## Galileo 통합 이해

Galileo는 이벤트 리스너를 등록해 CrewAI와 통합됩니다. 이 리스너는 crew 실행 이벤트(예: 에이전트 작업, 도구 호출, 모델 응답)를 캡처하고, 관측과 평가를 위해 Galileo로 전달합니다.

### 이벤트 리스너 이해

CrewAI 실행에서 Galileo를 활성화하는 데 필요한 것은 `CrewAIEventListener()` 인스턴스를 생성하는 것뿐입니다. 인스턴스화되면 리스너는 다음을 수행합니다:

- CrewAI에 자동으로 등록됩니다
- 환경 변수에서 Galileo 구성을 읽습니다
- 모든 실행 데이터를 `GALILEO_PROJECT`와 `GALILEO_LOG_STREAM`으로 지정한 Galileo 프로젝트와 로그 스트림에 기록합니다

추가 구성이나 코드 변경은 필요하지 않습니다.
@@ -79,7 +79,7 @@ Existem diferentes locais no código do CrewAI onde você pode especificar o mod

# Configuração avançada com parâmetros detalhados
llm = LLM(
-    model="openai/gpt-4",
+    model="openai/gpt-4",
    temperature=0.8,
    max_tokens=150,
    top_p=0.9,
@@ -207,20 +207,11 @@ Nesta seção, você encontrará exemplos detalhados que ajudam a selecionar, co
Defina sua chave de API no seu arquivo `.env`. Se precisar de uma chave, ou encontrar uma existente, verifique o [AI Studio](https://aistudio.google.com/apikey).

```toml .env
# Para API Gemini (uma das seguintes)
GOOGLE_API_KEY=<your-api-key>
# https://ai.google.dev/gemini-api/docs/api-key
GEMINI_API_KEY=<your-api-key>

# Para Vertex AI Express mode (autenticação por chave de API)
GOOGLE_GENAI_USE_VERTEXAI=true
GOOGLE_API_KEY=<your-api-key>

# Para Vertex AI com conta de serviço
GOOGLE_CLOUD_PROJECT=<your-project-id>
GOOGLE_CLOUD_LOCATION=<location> # Padrão: us-central1
```

**Uso Básico:**
Exemplo de uso em seu projeto CrewAI:
```python Code
from crewai import LLM

@@ -230,34 +221,6 @@ Nesta seção, você encontrará exemplos detalhados que ajudam a selecionar, co
)
```

**Vertex AI Express Mode (Autenticação por Chave de API):**

O Vertex AI Express mode permite usar o Vertex AI com autenticação simples por chave de API, em vez de credenciais de conta de serviço. Esta é a maneira mais rápida de começar com o Vertex AI.

Para habilitar o Express mode, defina ambas as variáveis de ambiente no seu arquivo `.env`:
```toml .env
GOOGLE_GENAI_USE_VERTEXAI=true
GOOGLE_API_KEY=<your-api-key>
```

Em seguida, use o LLM normalmente:
```python Code
from crewai import LLM

llm = LLM(
    model="gemini/gemini-2.0-flash",
    temperature=0.7
)
```

<Info>
Para obter uma chave de API do Express mode:
- Novos usuários do Google Cloud: Obtenha uma [chave de API do Express mode](https://cloud.google.com/vertex-ai/generative-ai/docs/start/quickstart?usertype=apikey)
- Usuários existentes do Google Cloud: Obtenha uma [chave de API do Google Cloud vinculada a uma conta de serviço](https://cloud.google.com/docs/authentication/api-keys)

Para mais detalhes, consulte a [documentação do Vertex AI Express mode](https://docs.cloud.google.com/vertex-ai/generative-ai/docs/start/quickstart?usertype=apikey).
</Info>

### Modelos Gemini

O Google oferece uma variedade de modelos poderosos otimizados para diferentes casos de uso.
@@ -860,7 +823,7 @@ Saiba como obter o máximo da configuração do seu LLM:
Lembre-se de monitorar regularmente o uso de tokens e ajustar suas configurações para otimizar custos e desempenho.
</Info>
</Accordion>


<Accordion title="Descartar Parâmetros Adicionais">
O CrewAI usa Litellm internamente para chamadas LLM, permitindo descartar parâmetros adicionais desnecessários para seu caso de uso. Isso pode simplificar seu código e reduzir a complexidade da configuração do LLM.
Por exemplo, se não precisar enviar o parâmetro <code>stop</code>, basta omiti-lo na chamada do LLM:
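
Um esboço mínimo do que isso pode parecer (assumindo o parâmetro `additional_drop_params`, repassado ao LiteLLM; o modelo usado é apenas um exemplo):

```python Code
from crewai import LLM

# Descarta o parâmetro "stop" antes de enviar a chamada ao provedor
llm = LLM(
    model="openai/gpt-4o",
    additional_drop_params=["stop"]
)
```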

@@ -919,4 +882,4 @@ Saiba como obter o máximo da configuração do seu LLM:
llm = LLM(model="openai/gpt-4o")  # 128K tokens
```
</Tab>
</Tabs>
</Tabs>
@@ -128,7 +128,7 @@ Ao implantar seu Flow, considere o seguinte:
### CrewAI Enterprise
A maneira mais fácil de implantar seu Flow é usando o CrewAI Enterprise. Ele lida com a infraestrutura, autenticação e monitoramento para você.

-Confira o [Guia de Implantação](/pt-BR/enterprise/guides/deploy-to-amp) para começar.
+Confira o [Guia de Implantação](/pt-BR/enterprise/guides/deploy-crew) para começar.

```bash
crewai deploy create

@@ -91,7 +91,7 @@ Após implantar, você pode ver os detalhes da automação e usar o menu **Optio
## Relacionados

<CardGroup cols={3}>
-<Card title="Implantar um Crew" href="/pt-BR/enterprise/guides/deploy-to-amp" icon="rocket">
+<Card title="Implantar um Crew" href="/pt-BR/enterprise/guides/deploy-crew" icon="rocket">
Implante um Crew via GitHub ou arquivo ZIP.
</Card>
<Card title="Gatilhos de Automação" href="/pt-BR/enterprise/guides/automation-triggers" icon="trigger">

@@ -79,7 +79,7 @@ Após publicar, você pode visualizar os detalhes da automação e usar o menu *
<Card title="Criar um Crew" href="/pt-BR/enterprise/guides/build-crew" icon="paintbrush">
Crie um Crew.
</Card>
-<Card title="Implantar um Crew" href="/pt-BR/enterprise/guides/deploy-to-amp" icon="rocket">
+<Card title="Implantar um Crew" href="/pt-BR/enterprise/guides/deploy-crew" icon="rocket">
Implante um Crew via GitHub ou ZIP.
</Card>
<Card title="Exportar um Componente React" href="/pt-BR/enterprise/guides/react-component-export" icon="download">

304
docs/pt-BR/enterprise/guides/deploy-crew.mdx
Normal file
@@ -0,0 +1,304 @@
---
title: "Deploy Crew"
description: "Implantando um Crew na CrewAI AMP"
icon: "rocket"
mode: "wide"
---

<Note>
Depois de criar um crew localmente ou pelo Crew Studio, o próximo passo é
implantá-lo na plataforma CrewAI AMP. Este guia cobre múltiplos métodos de
implantação para ajudá-lo a escolher a melhor abordagem para o seu fluxo de
trabalho.
</Note>

## Pré-requisitos

<CardGroup cols={2}>
  <Card title="Crew Pronto para Implantação" icon="users">
    Você deve ter um crew funcional, criado localmente ou pelo Crew Studio
  </Card>
  <Card title="Repositório GitHub" icon="github">
    O código do seu crew deve estar em um repositório do GitHub (para o método
    de integração com GitHub)
  </Card>
</CardGroup>

## Opção 1: Implantar Usando o CrewAI CLI

A CLI fornece a maneira mais rápida de implantar crews desenvolvidos localmente na plataforma Enterprise.

<Steps>
<Step title="Instale o CrewAI CLI">
Se ainda não tiver, instale o CrewAI CLI:

```bash
pip install crewai[tools]
```

<Tip>
A CLI vem com o pacote principal CrewAI, mas o extra `[tools]` garante todas as dependências de implantação.
</Tip>

</Step>

<Step title="Autentique-se na Plataforma Enterprise">
Primeiro, você precisa autenticar sua CLI com a plataforma CrewAI AMP:

```bash
# Se já possui uma conta CrewAI AMP, ou deseja criar uma:
crewai login
```

Ao executar qualquer um dos comandos, a CLI irá:
1. Exibir uma URL e um código de dispositivo único
2. Abrir seu navegador para a página de autenticação
3. Solicitar a confirmação do dispositivo
4. Completar o processo de autenticação

Após a autenticação bem-sucedida, você verá uma mensagem de confirmação no terminal!

</Step>

<Step title="Criar uma Implantação">

No diretório do seu projeto, execute:

```bash
crewai deploy create
```

Este comando irá:
1. Detectar informações do seu repositório GitHub
2. Identificar variáveis de ambiente no seu arquivo `.env` local
3. Transferir essas variáveis com segurança para a plataforma Enterprise
4. Criar uma nova implantação com um identificador único

Com a criação bem-sucedida, você verá uma mensagem como:
```shell
Deployment created successfully!
Name: your_project_name
Deployment ID: 01234567-89ab-cdef-0123-456789abcdef
Current Status: Deploy Enqueued
```

</Step>

<Step title="Acompanhe o Progresso da Implantação">

Acompanhe o status da implantação com:

```bash
crewai deploy status
```

Para ver logs detalhados do processo de build:

```bash
crewai deploy logs
```

<Tip>
A primeira implantação normalmente leva de 10 a 15 minutos, pois as imagens dos containers são construídas. As próximas implantações são bem mais rápidas.
</Tip>

</Step>
</Steps>

## Comandos Adicionais da CLI

O CrewAI CLI oferece vários comandos para gerenciar suas implantações:

```bash
# Liste todas as suas implantações
crewai deploy list

# Consulte o status de uma implantação
crewai deploy status

# Veja os logs da implantação
crewai deploy logs

# Envie atualizações após alterações no código
crewai deploy push

# Remova uma implantação
crewai deploy remove <deployment_id>
```

## Opção 2: Implantar Diretamente pela Interface Web

Você também pode implantar seus crews diretamente pela interface web da CrewAI AMP conectando sua conta do GitHub. Esta abordagem não requer utilizar a CLI na sua máquina local.

<Steps>

<Step title="Enviar para o GitHub">

Você precisa subir seu crew para um repositório do GitHub. Caso ainda não tenha criado um crew, você pode [seguir este tutorial](/pt-BR/quickstart).

</Step>

<Step title="Conectando o GitHub ao CrewAI AMP">

1. Faça login em [CrewAI AMP](https://app.crewai.com)
2. Clique no botão "Connect GitHub"

<Frame>

</Frame>

</Step>

<Step title="Selecionar o Repositório">

Após conectar sua conta GitHub, você poderá selecionar qual repositório deseja implantar:

<Frame>

</Frame>

</Step>

<Step title="Definir as Variáveis de Ambiente">

Antes de implantar, você precisará configurar as variáveis de ambiente para conectar ao seu provedor de LLM ou outros serviços:

1. Você pode adicionar variáveis individualmente ou em lote
2. Digite suas variáveis no formato `KEY=VALUE` (uma por linha), como no exemplo abaixo

<Frame>

</Frame>
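
Por exemplo, ao adicionar em lote, a entrada pode ter este formato (nomes e valores são apenas exemplos; ajuste ao seu provedor):

```bash
OPENAI_API_KEY=sk-your-key
SERPER_API_KEY=your-serper-key
```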

</Step>

<Step title="Implante Seu Crew">

1. Clique no botão "Deploy" para iniciar o processo de implantação
2. Você pode monitorar o progresso pela barra de progresso
3. A primeira implantação geralmente demora de 10 a 15 minutos; as próximas serão mais rápidas

<Frame>

</Frame>

Após a conclusão, você verá:
- A URL exclusiva do seu crew
- Um Bearer token para proteger sua API crew
- Um botão "Delete" caso precise remover a implantação

</Step>

</Steps>

## ⚠️ Requisitos de Segurança para Variáveis de Ambiente

<Warning>
**Importante**: A CrewAI AMP possui restrições de segurança sobre os nomes de
variáveis de ambiente que podem causar falha na implantação caso não sejam
seguidas.
</Warning>

### Padrões de Variáveis de Ambiente Bloqueados

Por motivos de segurança, os seguintes padrões de nome de variável de ambiente são **automaticamente filtrados** e causarão problemas de implantação:

**Padrões Bloqueados:**

- Variáveis terminando em `_TOKEN` (ex: `MY_API_TOKEN`)
- Variáveis terminando em `_PASSWORD` (ex: `DB_PASSWORD`)
- Variáveis terminando em `_SECRET` (ex: `API_SECRET`)
- Variáveis terminando em `_KEY` em certos contextos

**Variáveis Bloqueadas Específicas:**

- `GITHUB_USER`, `GITHUB_TOKEN`
- `AWS_REGION`, `AWS_DEFAULT_REGION`
- Diversas variáveis internas do sistema CrewAI

### Exceções Permitidas

Algumas variáveis são explicitamente permitidas mesmo coincidindo com os padrões bloqueados:

- `AZURE_AD_TOKEN`
- `AZURE_OPENAI_AD_TOKEN`
- `ENTERPRISE_ACTION_TOKEN`
- `CREWAI_ENTEPRISE_TOOLS_TOKEN`

### Como Corrigir Problemas de Nomeação

Se sua implantação falhar devido a restrições de variáveis de ambiente:

```bash
# ❌ Estas irão causar falhas na implantação
OPENAI_TOKEN=sk-...
DATABASE_PASSWORD=mysenha
API_SECRET=segredo123

# ✅ Utilize estes padrões de nomeação
OPENAI_API_KEY=sk-...
DATABASE_CREDENTIALS=mysenha
API_CONFIG=segredo123
```

### Melhores Práticas

1. **Use convenções padrão de nomenclatura**: `PROVIDER_API_KEY` em vez de `PROVIDER_TOKEN`
2. **Teste localmente primeiro**: Certifique-se de que seu crew funciona com as variáveis renomeadas
3. **Atualize seu código**: Altere todas as referências aos nomes antigos das variáveis
4. **Documente as mudanças**: Mantenha registro das variáveis renomeadas para seu time

<Tip>
Se você se deparar com falhas de implantação com erros enigmáticos de
variáveis de ambiente, confira primeiro os nomes das variáveis em relação a
esses padrões.
</Tip>

### Interaja com Seu Crew Implantado

Após a implantação, você pode acessar seu crew por meio de:

1. **REST API**: A plataforma gera um endpoint HTTPS exclusivo com estas rotas principais (veja o exemplo de chamada abaixo):

   - `/inputs`: Lista os parâmetros de entrada requeridos
   - `/kickoff`: Inicia uma execução com os inputs fornecidos
   - `/status/{kickoff_id}`: Consulta o status da execução

2. **Interface Web**: Acesse [app.crewai.com](https://app.crewai.com) para visualizar:
   - **Aba Status**: Informações da implantação, detalhes do endpoint da API e token de autenticação
   - **Aba Run**: Visualização da estrutura do seu crew
   - **Aba Executions**: Histórico de todas as execuções
   - **Aba Metrics**: Análises de desempenho
   - **Aba Traces**: Insights detalhados das execuções
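
Por exemplo, com a URL e o Bearer token recebidos na implantação (os valores abaixo são fictícios; consulte `/inputs` para o corpo exato da requisição):

```bash
# Listar os parâmetros de entrada requeridos
curl -H "Authorization: Bearer YOUR_BEARER_TOKEN" \
  https://your-crew-deployment.crewai.com/inputs

# Iniciar uma execução (retorna um kickoff_id)
curl -X POST \
  -H "Authorization: Bearer YOUR_BEARER_TOKEN" \
  -H "Content-Type: application/json" \
  -d '{"inputs": {"topic": "AI in Healthcare"}}' \
  https://your-crew-deployment.crewai.com/kickoff

# Consultar o status da execução
curl -H "Authorization: Bearer YOUR_BEARER_TOKEN" \
  https://your-crew-deployment.crewai.com/status/YOUR_KICKOFF_ID
```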

### Dispare uma Execução

No dashboard Enterprise, você pode:

1. Clicar no nome do seu crew para abrir seus detalhes
2. Selecionar "Trigger Crew" na interface de gerenciamento
3. Inserir os inputs necessários no modal exibido
4. Monitorar o progresso à medida que a execução avança pelo pipeline

### Monitoramento e Análises

A plataforma Enterprise oferece recursos abrangentes de observabilidade:

- **Gestão das Execuções**: Acompanhe execuções ativas e concluídas
- **Traces**: Quebra detalhada de cada execução
- **Métricas**: Uso de tokens, tempos de execução e custos
- **Visualização em Linha do Tempo**: Representação visual das sequências de tarefas

### Funcionalidades Avançadas

A plataforma Enterprise também oferece:

- **Gerenciamento de Variáveis de Ambiente**: Armazene e gerencie com segurança as chaves de API
- **Conexões com LLM**: Configure integrações com diversos provedores de LLM
- **Repositório Custom Tools**: Crie, compartilhe e instale ferramentas
- **Crew Studio**: Monte crews via interface de chat sem escrever código

<Card title="Precisa de Ajuda?" icon="headset" href="mailto:support@crewai.com">
Entre em contato com nossa equipe de suporte para ajuda com questões de
implantação ou dúvidas sobre a plataforma Enterprise.
</Card>
@@ -1,439 +0,0 @@
---
title: "Deploy para AMP"
description: "Implante seu Crew ou Flow no CrewAI AMP"
icon: "rocket"
mode: "wide"
---

<Note>
Depois de criar um Crew ou Flow localmente (ou pelo Crew Studio), o próximo passo é
implantá-lo na plataforma CrewAI AMP. Este guia cobre múltiplos métodos de
implantação para ajudá-lo a escolher a melhor abordagem para o seu fluxo de trabalho.
</Note>

## Pré-requisitos

<CardGroup cols={2}>
  <Card title="Projeto Pronto para Implantação" icon="check-circle">
    Você deve ter um Crew ou Flow funcionando localmente com sucesso.
    Siga nosso [guia de preparação](/pt-BR/enterprise/guides/prepare-for-deployment) para verificar a estrutura do seu projeto.
  </Card>
  <Card title="Repositório GitHub" icon="github">
    Seu código deve estar em um repositório do GitHub (para o método de integração com GitHub).
  </Card>
</CardGroup>

<Info>
**Crews vs Flows**: Ambos os tipos de projeto podem ser implantados como "automações" no CrewAI AMP.
O processo de implantação é o mesmo, mas eles têm estruturas de projeto diferentes.
Veja [Preparar para Implantação](/pt-BR/enterprise/guides/prepare-for-deployment) para detalhes.
</Info>

## Opção 1: Implantar Usando o CrewAI CLI

A CLI fornece a maneira mais rápida de implantar Crews ou Flows desenvolvidos localmente na plataforma AMP.
A CLI detecta automaticamente o tipo do seu projeto a partir do `pyproject.toml` e faz o build adequadamente.

<Steps>
<Step title="Instale o CrewAI CLI">
Se ainda não tiver, instale o CrewAI CLI:

```bash
pip install crewai[tools]
```

<Tip>
A CLI vem com o pacote principal CrewAI, mas o extra `[tools]` garante todas as dependências de implantação.
</Tip>

</Step>

<Step title="Autentique-se na Plataforma Enterprise">
Primeiro, você precisa autenticar sua CLI com a plataforma CrewAI AMP:

```bash
# Se já possui uma conta CrewAI AMP, ou deseja criar uma:
crewai login
```

Ao executar qualquer um dos comandos, a CLI irá:
1. Exibir uma URL e um código de dispositivo único
2. Abrir seu navegador para a página de autenticação
3. Solicitar a confirmação do dispositivo
4. Completar o processo de autenticação

Após a autenticação bem-sucedida, você verá uma mensagem de confirmação no terminal!

</Step>

<Step title="Criar uma Implantação">

No diretório do seu projeto, execute:

```bash
crewai deploy create
```

Este comando irá:
1. Detectar informações do seu repositório GitHub
2. Identificar variáveis de ambiente no seu arquivo `.env` local
3. Transferir essas variáveis com segurança para a plataforma Enterprise
4. Criar uma nova implantação com um identificador único

Com a criação bem-sucedida, você verá uma mensagem como:
```shell
Deployment created successfully!
Name: your_project_name
Deployment ID: 01234567-89ab-cdef-0123-456789abcdef
Current Status: Deploy Enqueued
```

</Step>

<Step title="Acompanhe o Progresso da Implantação">

Acompanhe o status da implantação com:

```bash
crewai deploy status
```

Para ver logs detalhados do processo de build:

```bash
crewai deploy logs
```

<Tip>
A primeira implantação normalmente leva de 10 a 15 minutos, pois as imagens dos containers são construídas. As próximas implantações são bem mais rápidas.
</Tip>

</Step>
</Steps>

## Comandos Adicionais da CLI

O CrewAI CLI oferece vários comandos para gerenciar suas implantações:

```bash
# Liste todas as suas implantações
crewai deploy list

# Consulte o status de uma implantação
crewai deploy status

# Veja os logs da implantação
crewai deploy logs

# Envie atualizações após alterações no código
crewai deploy push

# Remova uma implantação
crewai deploy remove <deployment_id>
```

## Opção 2: Implantar Diretamente pela Interface Web

Você também pode implantar seus Crews ou Flows diretamente pela interface web do CrewAI AMP conectando sua conta do GitHub. Esta abordagem não requer utilizar a CLI na sua máquina local. A plataforma detecta automaticamente o tipo do seu projeto e trata o build adequadamente.

<Steps>

<Step title="Enviar para o GitHub">

Você precisa enviar seu crew para um repositório do GitHub. Caso ainda não tenha criado um crew, você pode [seguir este tutorial](/pt-BR/quickstart).

</Step>

<Step title="Conectando o GitHub ao CrewAI AMP">

1. Faça login em [CrewAI AMP](https://app.crewai.com)
2. Clique no botão "Connect GitHub"

<Frame>

</Frame>

</Step>

<Step title="Selecionar o Repositório">

Após conectar sua conta GitHub, você poderá selecionar qual repositório deseja implantar:

<Frame>

</Frame>

</Step>

<Step title="Definir as Variáveis de Ambiente">

Antes de implantar, você precisará configurar as variáveis de ambiente para conectar ao seu provedor de LLM ou outros serviços:

1. Você pode adicionar variáveis individualmente ou em lote
2. Digite suas variáveis no formato `KEY=VALUE` (uma por linha)

<Frame>

</Frame>

</Step>

<Step title="Implante Seu Crew">

1. Clique no botão "Deploy" para iniciar o processo de implantação
2. Você pode monitorar o progresso pela barra de progresso
3. A primeira implantação geralmente demora de 10 a 15 minutos; as próximas serão mais rápidas

<Frame>

</Frame>

Após a conclusão, você verá:
- A URL exclusiva do seu crew
- Um Bearer token para proteger sua API crew
- Um botão "Delete" caso precise remover a implantação

</Step>

</Steps>

## Opção 3: Reimplantar Usando API (Integração CI/CD)

Para implantações automatizadas em pipelines CI/CD, você pode usar a API do CrewAI para acionar reimplantações de crews existentes. Isso é particularmente útil para GitHub Actions, Jenkins ou outros workflows de automação.

<Steps>
<Step title="Obtenha Seu Token de Acesso Pessoal">

Navegue até as configurações da sua conta CrewAI AMP para gerar um token de API:

1. Acesse [app.crewai.com](https://app.crewai.com)
2. Clique em **Settings** → **Account** → **Personal Access Token**
3. Gere um novo token e copie-o com segurança
4. Armazene este token como um secret no seu sistema CI/CD

</Step>

<Step title="Encontre o UUID da Sua Automação">

Localize o identificador único do seu crew implantado:

1. Acesse **Automations** no seu dashboard CrewAI AMP
2. Selecione sua automação/crew existente
3. Clique em **Additional Details**
4. Copie o **UUID** - este identifica sua implantação específica do crew

</Step>

<Step title="Acione a Reimplantação via API">

Use o endpoint da API de Deploy para acionar uma reimplantação:

```bash
curl -i -X POST \
  -H "Authorization: Bearer YOUR_PERSONAL_ACCESS_TOKEN" \
  https://app.crewai.com/crewai_plus/api/v1/crews/YOUR-AUTOMATION-UUID/deploy

# HTTP/2 200
# content-type: application/json
#
# {
#   "uuid": "your-automation-uuid",
#   "status": "Deploy Enqueued",
#   "public_url": "https://your-crew-deployment.crewai.com",
#   "token": "your-bearer-token"
# }
```

<Info>
Se sua automação foi criada originalmente conectada ao Git, a API automaticamente puxará as últimas alterações do seu repositório antes de reimplantar.
</Info>

</Step>

<Step title="Exemplo de Integração com GitHub Actions">

Aqui está um workflow do GitHub Actions com gatilhos de implantação mais complexos:

```yaml
name: Deploy CrewAI Automation

on:
  push:
    branches: [ main ]
  pull_request:
    types: [ labeled ]
  release:
    types: [ published ]

jobs:
  deploy:
    runs-on: ubuntu-latest
    if: |
      (github.event_name == 'push' && github.ref == 'refs/heads/main') ||
      (github.event_name == 'pull_request' && contains(github.event.pull_request.labels.*.name, 'deploy')) ||
      (github.event_name == 'release')
    steps:
      - name: Trigger CrewAI Redeployment
        run: |
          curl -X POST \
            -H "Authorization: Bearer ${{ secrets.CREWAI_PAT }}" \
            https://app.crewai.com/crewai_plus/api/v1/crews/${{ secrets.CREWAI_AUTOMATION_UUID }}/deploy
```

<Tip>
Adicione `CREWAI_PAT` e `CREWAI_AUTOMATION_UUID` como secrets do repositório. Para implantações de PR, adicione um label "deploy" para acionar o workflow.
</Tip>

</Step>

</Steps>

## Interaja com Sua Automação Implantada

Após a implantação, você pode acessar seu crew através de:

1. **REST API**: A plataforma gera um endpoint HTTPS exclusivo com estas rotas principais:

   - `/inputs`: Lista os parâmetros de entrada requeridos
   - `/kickoff`: Inicia uma execução com os inputs fornecidos
   - `/status/{kickoff_id}`: Consulta o status da execução

2. **Interface Web**: Acesse [app.crewai.com](https://app.crewai.com) para visualizar:
   - **Aba Status**: Informações da implantação, detalhes do endpoint da API e token de autenticação
   - **Aba Run**: Visualização da estrutura do seu crew
   - **Aba Executions**: Histórico de todas as execuções
   - **Aba Metrics**: Análises de desempenho
   - **Aba Traces**: Insights detalhados das execuções

### Dispare uma Execução

No dashboard Enterprise, você pode:

1. Clicar no nome do seu crew para abrir seus detalhes
2. Selecionar "Trigger Crew" na interface de gerenciamento
3. Inserir os inputs necessários no modal exibido
4. Monitorar o progresso à medida que a execução avança pelo pipeline

### Monitoramento e Análises

A plataforma Enterprise oferece recursos abrangentes de observabilidade:

- **Gestão das Execuções**: Acompanhe execuções ativas e concluídas
- **Traces**: Quebra detalhada de cada execução
- **Métricas**: Uso de tokens, tempos de execução e custos
- **Visualização em Linha do Tempo**: Representação visual das sequências de tarefas

### Funcionalidades Avançadas

A plataforma Enterprise também oferece:

- **Gerenciamento de Variáveis de Ambiente**: Armazene e gerencie com segurança as chaves de API
- **Conexões com LLM**: Configure integrações com diversos provedores de LLM
- **Repositório Custom Tools**: Crie, compartilhe e instale ferramentas
- **Crew Studio**: Monte crews via interface de chat sem escrever código

## Solução de Problemas em Falhas de Implantação

Se sua implantação falhar, verifique estes problemas comuns:

### Falhas de Build

#### Arquivo uv.lock Ausente

**Sintoma**: Build falha no início com erros de resolução de dependências

**Solução**: Gere e faça commit do arquivo lock:

```bash
uv lock
git add uv.lock
git commit -m "Add uv.lock for deployment"
git push
```

<Warning>
O arquivo `uv.lock` é obrigatório para todas as implantações. Sem ele, a plataforma
não consegue instalar suas dependências de forma confiável.
</Warning>

#### Estrutura de Projeto Incorreta

**Sintoma**: Erros "Could not find entry point" ou "Module not found"

**Solução**: Verifique se seu projeto corresponde à estrutura esperada:

- **Tanto Crews quanto Flows**: Devem ter ponto de entrada em `src/project_name/main.py`
- **Crews**: Usam uma função `run()` como ponto de entrada
- **Flows**: Usam uma função `kickoff()` como ponto de entrada

Veja [Preparar para Implantação](/pt-BR/enterprise/guides/prepare-for-deployment) para diagramas de estrutura detalhados.

#### Decorador CrewBase Ausente

**Sintoma**: Erros "Crew not found", "Config not found" ou erros de configuração de agent/task

**Solução**: Certifique-se de que **todas** as classes crew usam o decorador `@CrewBase`:

```python
from crewai.project import CrewBase, agent, crew, task

@CrewBase  # Este decorador é OBRIGATÓRIO
class YourCrew():
    """Descrição do seu crew"""

    @agent
    def my_agent(self) -> Agent:
        return Agent(
            config=self.agents_config['my_agent'],  # type: ignore[index]
            verbose=True
        )

    # ... resto da definição do crew
```

<Info>
Isso se aplica a Crews independentes E crews embutidos dentro de projetos Flow.
Toda classe crew precisa do decorador.
</Info>

#### Tipo Incorreto no pyproject.toml

**Sintoma**: Build tem sucesso mas falha em runtime, ou comportamento inesperado

**Solução**: Verifique se a seção `[tool.crewai]` corresponde ao tipo do seu projeto:

```toml
# Para projetos Crew:
[tool.crewai]
type = "crew"

# Para projetos Flow:
[tool.crewai]
type = "flow"
```

### Falhas de Runtime

#### Falhas de Conexão com LLM

**Sintoma**: Erros de chave API, "model not found" ou falhas de autenticação

**Solução**:
1. Verifique se a chave API do seu provedor LLM está corretamente definida nas variáveis de ambiente
2. Certifique-se de que os nomes das variáveis de ambiente correspondem ao que seu código espera
3. Teste localmente com exatamente as mesmas variáveis de ambiente antes de implantar (veja o exemplo abaixo)
|
||||
|
||||
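A minimal pre-deploy sanity check along these lines can catch missing keys early; the variable names here are illustrative, so substitute whichever variables your crew actually reads:

```python
# Hypothetical pre-deploy check: confirm the keys your code expects are set.
import os

for var in ("OPENAI_API_KEY", "SERPER_API_KEY"):  # illustrative names
    print(f"{var}: {'set' if os.environ.get(var) else 'MISSING'}")
```
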
#### Crew Execution Errors

**Symptom**: Crew starts but fails during execution

**Solution**:
1. Check the execution logs in the AMP dashboard (Traces tab)
2. Verify that all tools have the required API keys configured
3. Make sure the agent configurations in `agents.yaml` are valid
4. Check for syntax errors in the task configurations in `tasks.yaml` (a quick local parse like the sketch below can catch these)

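One way to catch YAML syntax errors before deploying is to parse both config files locally. This is an illustrative sketch, assuming the default `src/project_name/config/` layout:

```python
# Hypothetical local check: parse the crew's YAML configs before deploying.
# Adjust the glob if your project does not use the default layout.
from pathlib import Path

import yaml  # PyYAML

for config in Path("src").glob("*/config/*.yaml"):
    try:
        yaml.safe_load(config.read_text())
        print(f"OK: {config}")
    except yaml.YAMLError as e:
        print(f"SYNTAX ERROR in {config}: {e}")
```
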
<Card title="Precisa de Ajuda?" icon="headset" href="mailto:support@crewai.com">
|
||||
Entre em contato com nossa equipe de suporte para ajuda com questões de
|
||||
implantação ou dúvidas sobre a plataforma AMP.
|
||||
</Card>
|
||||
@@ -1,305 +0,0 @@
---
title: "Prepare for Deployment"
description: "Make sure your Crew or Flow is ready to deploy to CrewAI AMP"
icon: "clipboard-check"
mode: "wide"
---

<Note>
Before deploying to CrewAI AMP, it is crucial to verify that your project is structured correctly.
Both Crews and Flows can be deployed as "automations", but they have different project structures
and requirements that must be met for a successful deployment.
</Note>

## Understanding Automations

In CrewAI AMP, **automations** is the umbrella term for deployable agentic AI projects. An automation can be:

- **A Crew**: A standalone team of AI agents working together on tasks
- **A Flow**: An orchestrated workflow that can combine multiple crews, direct LLM calls, and procedural logic

Understanding which type you are deploying is essential because they have different project structures and entry points.

## Crews vs Flows: Key Differences

<CardGroup cols={2}>
  <Card title="Crew Projects" icon="users">
    Standalone teams of AI agents with `crew.py` defining agents and tasks. Ideal for focused, collaborative tasks.
  </Card>
  <Card title="Flow Projects" icon="diagram-project">
    Orchestrated workflows with crews embedded in a `crews/` folder. Ideal for complex multi-step processes.
  </Card>
</CardGroup>

| Aspect | Crew | Flow |
|--------|------|------|
| **Project structure** | `src/project_name/` with `crew.py` | `src/project_name/` with a `crews/` folder |
| **Main logic location** | `src/project_name/crew.py` | `src/project_name/main.py` (Flow class) |
| **Entry point function** | `run()` in `main.py` | `kickoff()` in `main.py` |
| **Type in pyproject.toml** | `type = "crew"` | `type = "flow"` |
| **CLI creation command** | `crewai create crew name` | `crewai create flow name` |
| **Configuration location** | `src/project_name/config/` | `src/project_name/crews/crew_name/config/` |
| **Can contain other crews** | No | Yes (in the `crews/` folder) |

## Project Structure Reference

### Crew Project Structure

When you run `crewai create crew my_crew`, you get this structure:

```
my_crew/
├── .gitignore
├── pyproject.toml        # Must have type = "crew"
├── README.md
├── .env
├── uv.lock               # REQUIRED for deployment
└── src/
    └── my_crew/
        ├── __init__.py
        ├── main.py       # Entry point with run() function
        ├── crew.py       # Crew class with @CrewBase decorator
        ├── tools/
        │   ├── custom_tool.py
        │   └── __init__.py
        └── config/
            ├── agents.yaml   # Agent definitions
            └── tasks.yaml    # Task definitions
```

<Warning>
The nested `src/project_name/` structure is critical for Crews.
Placing files at the wrong level will cause deployment failures.
</Warning>

### Flow Project Structure

When you run `crewai create flow my_flow`, you get this structure:

```
my_flow/
├── .gitignore
├── pyproject.toml        # Must have type = "flow"
├── README.md
├── .env
├── uv.lock               # REQUIRED for deployment
└── src/
    └── my_flow/
        ├── __init__.py
        ├── main.py       # Entry point with kickoff() function + Flow class
        ├── crews/        # Embedded crews folder
        │   └── poem_crew/
        │       ├── __init__.py
        │       ├── poem_crew.py   # Crew with @CrewBase decorator
        │       └── config/
        │           ├── agents.yaml
        │           └── tasks.yaml
        └── tools/
            ├── __init__.py
            └── custom_tool.py
```

<Info>
Both Crews and Flows use the `src/project_name/` structure.
The key difference is that Flows have a `crews/` folder for embedded crews,
while Crews have `crew.py` directly in the project folder.
</Info>

## Pre-Deployment Checklist

Use this checklist to verify that your project is ready for deployment.

### 1. Verify the pyproject.toml Configuration

Your `pyproject.toml` must include the correct `[tool.crewai]` section:

<Tabs>
  <Tab title="For Crews">
    ```toml
    [tool.crewai]
    type = "crew"
    ```
  </Tab>
  <Tab title="For Flows">
    ```toml
    [tool.crewai]
    type = "flow"
    ```
  </Tab>
</Tabs>

<Warning>
If the `type` does not match your project structure, the build will fail or
the automation will not work correctly.
</Warning>

### 2. Ensure the uv.lock File Exists

CrewAI uses `uv` for dependency management. The `uv.lock` file guarantees reproducible builds and is **required** for deployment.

```bash
# Generate or update the lock file
uv lock

# Verify that it exists
ls -la uv.lock
```

If the file does not exist, run `uv lock` and commit it to your repository:

```bash
uv lock
git add uv.lock
git commit -m "Add uv.lock for deployment"
git push
```

### 3. Validate CrewBase Decorator Usage

**Every crew class must use the `@CrewBase` decorator.** This applies to:

- Standalone crew projects
- Crews embedded inside Flow projects

```python
from crewai import Agent, Crew, Process, Task
from crewai.project import CrewBase, agent, crew, task
from crewai.agents.agent_builder.base_agent import BaseAgent
from typing import List

@CrewBase  # This decorator is REQUIRED
class MyCrew():
    """My crew description"""

    agents: List[BaseAgent]
    tasks: List[Task]

    @agent
    def my_agent(self) -> Agent:
        return Agent(
            config=self.agents_config['my_agent'],  # type: ignore[index]
            verbose=True
        )

    @task
    def my_task(self) -> Task:
        return Task(
            config=self.tasks_config['my_task']  # type: ignore[index]
        )

    @crew
    def crew(self) -> Crew:
        return Crew(
            agents=self.agents,
            tasks=self.tasks,
            process=Process.sequential,
            verbose=True,
        )
```

<Warning>
If you forget the `@CrewBase` decorator, your deployment will fail with
errors about missing agent or task configurations.
</Warning>

### 4. Verify Project Entry Points

Both Crews and Flows have their entry point in `src/project_name/main.py`:

<Tabs>
  <Tab title="For Crews">
    The entry point uses a `run()` function:

    ```python
    # src/my_crew/main.py
    from my_crew.crew import MyCrew

    def run():
        """Run the crew."""
        inputs = {'topic': 'AI in Healthcare'}
        result = MyCrew().crew().kickoff(inputs=inputs)
        return result

    if __name__ == "__main__":
        run()
    ```
  </Tab>
  <Tab title="For Flows">
    The entry point uses a `kickoff()` function with a Flow class:

    ```python
    # src/my_flow/main.py
    from crewai.flow import Flow, listen, start
    from my_flow.crews.poem_crew.poem_crew import PoemCrew

    class MyFlow(Flow):
        @start()
        def begin(self):
            # Flow logic here
            result = PoemCrew().crew().kickoff(inputs={...})
            return result

    def kickoff():
        """Run the flow."""
        MyFlow().kickoff()

    if __name__ == "__main__":
        kickoff()
    ```
  </Tab>
</Tabs>

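As a quick local sanity check, you can verify the entry point exists before pushing. A sketch, where `my_crew` is a placeholder for your own package name:

```python
# Hypothetical check that the expected entry point exists before deploying
import importlib

mod = importlib.import_module("my_crew.main")  # replace with your package
assert hasattr(mod, "run") or hasattr(mod, "kickoff"), "no run()/kickoff() entry point"
print("entry point found")
```
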
### 5. Prepare Environment Variables

Before deploying, make sure you have:

1. **LLM API keys** ready (OpenAI, Anthropic, Google, etc.)
2. **Tool API keys** if you are using external tools (Serper, etc.)

<Tip>
Test your project locally with the same environment variables before deploying
to catch configuration problems early.
</Tip>

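For reference, a typical `.env` might look like the sketch below; the exact variable names depend on your LLM provider and tools, so these values are illustrative:

```bash
# .env - illustrative example; use the variables your code actually reads
OPENAI_API_KEY=sk-...
SERPER_API_KEY=your-serper-key
MODEL=gpt-4o
```
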
## Quick Validation Commands

Run these commands from your project root to quickly verify your setup:

```bash
# 1. Check the project type in pyproject.toml
grep -A2 "\[tool.crewai\]" pyproject.toml

# 2. Check that uv.lock exists
ls -la uv.lock || echo "ERROR: uv.lock missing! Run 'uv lock'"

# 3. Check that the src/ structure exists
ls -la src/*/main.py 2>/dev/null || echo "No main.py found in src/"

# 4. For Crews - check that crew.py exists
ls -la src/*/crew.py 2>/dev/null || echo "No crew.py (expected for Crews)"

# 5. For Flows - check that a crews/ folder exists
ls -la src/*/crews/ 2>/dev/null || echo "No crews/ folder (expected for Flows)"

# 6. Check CrewBase usage
grep -r "@CrewBase" . --include="*.py"
```

## Common Configuration Errors

| Error | Symptom | Fix |
|-------|---------|-----|
| Missing `uv.lock` | Build fails during dependency resolution | Run `uv lock` and commit it |
| Wrong `type` in pyproject.toml | Build succeeds but fails at runtime | Change it to the correct type |
| Missing `@CrewBase` decorator | "Config not found" errors | Add the decorator to all crew classes |
| Files at the root instead of `src/` | Entry point not found | Move them to `src/project_name/` |
| Missing `run()` or `kickoff()` | Automation cannot start | Add the correct entry function |

## Next Steps

Once your project passes every item on the checklist, you are ready to deploy:

<Card title="Deploy to AMP" icon="rocket" href="/pt-BR/enterprise/guides/deploy-to-amp">
  Follow the deployment guide to deploy your Crew or Flow to CrewAI AMP using
  the CLI, web interface, or CI/CD integration.
</Card>
@@ -82,7 +82,7 @@ CrewAI AMP extends the power of the open-source framework with features designed
|
||||
<Card
|
||||
title="Implantar Crew"
|
||||
icon="rocket"
|
||||
href="/pt-BR/enterprise/guides/deploy-to-amp"
|
||||
href="/pt-BR/enterprise/guides/deploy-crew"
|
||||
>
|
||||
Implantar Crew
|
||||
</Card>
|
||||
@@ -92,11 +92,11 @@ CrewAI AMP expande o poder do framework open-source com funcionalidades projetad
|
||||
<Card
|
||||
title="Acesso via API"
|
||||
icon="code"
|
||||
href="/pt-BR/enterprise/guides/kickoff-crew"
|
||||
href="/pt-BR/enterprise/guides/deploy-crew"
|
||||
>
|
||||
Usar a API do Crew
|
||||
</Card>
|
||||
</Step>
|
||||
</Steps>
|
||||
|
||||
Para instruções detalhadas, consulte nosso [guia de implantação](/pt-BR/enterprise/guides/deploy-to-amp) ou clique no botão abaixo para começar.
|
||||
Para instruções detalhadas, consulte nosso [guia de implantação](/pt-BR/enterprise/guides/deploy-crew) ou clique no botão abaixo para começar.
|
||||
|
||||
@@ -1,115 +0,0 @@
---
title: Galileo
description: Galileo integration for CrewAI tracing and evaluation
icon: telescope
mode: "wide"
---

## Overview

This guide demonstrates how to integrate **Galileo** with **CrewAI**
for comprehensive tracing and evaluation.
By the end of this guide, you will be able to trace your CrewAI agents,
monitor their performance, and evaluate their behavior with
Galileo's powerful observability platform.

> **What is Galileo?** [Galileo](https://galileo.ai/) is an AI evaluation and observability
platform that offers tracing, evaluation, and monitoring for AI applications.
It lets teams capture ground truth, build robust guardrails, and run systematic
experiments with built-in experiment tracking and performance analysis, ensuring
reliability, transparency, and continuous improvement across the AI lifecycle.

## Getting Started

This tutorial follows the [CrewAI Quickstart](/pt-BR/quickstart) and shows how to add
Galileo's [CrewAIEventListener](https://v2docs.galileo.ai/sdk-api/python/reference/handlers/crewai/handler),
an event handler.
For more information, see Galileo's
[Add Galileo to a CrewAI application](https://v2docs.galileo.ai/how-to-guides/third-party-integrations/add-galileo-to-crewai/add-galileo-to-crewai)
how-to guide.

> **Note**: This tutorial assumes you have completed the [CrewAI Quickstart](/pt-BR/quickstart).
If you want a complete end-to-end example, see Galileo's
[CrewAI SDK examples repository](https://github.com/rungalileo/sdk-examples/tree/main/python/agent/crew-ai).

### Step 1: Install dependencies

Install the dependencies required for your application.
Create a virtual environment using your preferred method,
then install the dependencies inside that environment with your
preferred tool:

```bash
uv add galileo
```

### Step 2: Add to the [CrewAI Quickstart](/pt-BR/quickstart) .env file

```bash
# Your Galileo API key
GALILEO_API_KEY="your-galileo-api-key"

# Your Galileo project name
GALILEO_PROJECT="your-galileo-project-name"

# The name of the Log stream you want to use for logging
GALILEO_LOG_STREAM="your-galileo-log-stream"
```

### Step 3: Add the Galileo event listener

To enable logging with Galileo, you need to create an instance of the `CrewAIEventListener`.
Import the Galileo CrewAI handler package by
adding the following code at the top of your main.py file:

```python
from galileo.handlers.crewai.handler import CrewAIEventListener
```

At the start of your run function, create the event listener:

```python
def run():
    # Create the event listener
    CrewAIEventListener()
    # The rest of your existing code goes here
```

When you create the listener instance, it is automatically
registered with CrewAI, as shown in the consolidated sketch below.

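Putting the pieces together, a minimal `main.py` might look like the following sketch. `LatestAiDevelopment` is the crew class generated by the quickstart template; adjust the import to match your own project:

```python
# src/latest_ai_development/main.py - illustrative sketch based on the quickstart
from galileo.handlers.crewai.handler import CrewAIEventListener

from latest_ai_development.crew import LatestAiDevelopment


def run():
    # Instantiating the listener auto-registers it with CrewAI
    CrewAIEventListener()

    inputs = {"topic": "AI LLMs"}
    LatestAiDevelopment().crew().kickoff(inputs=inputs)


if __name__ == "__main__":
    run()
```
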
### Step 4: Run your Crew

Run your Crew with the CrewAI CLI:

```bash
crewai run
```

### Step 5: View the traces in Galileo

Once your crew finishes, the traces are flushed and appear in Galileo.

![Galileo trace view]()

## Understanding the Galileo integration

Galileo integrates with CrewAI by registering an event listener
that captures crew execution events (for example, agent actions, tool calls, and model responses)
and forwards them to Galileo for observability and evaluation.

### Understanding the event listener

Creating a `CrewAIEventListener()` instance is all that is needed
to enable Galileo for a CrewAI run. When instantiated, the listener:

- Automatically registers itself with CrewAI
- Reads the Galileo configuration from environment variables
- Logs all execution data to the Galileo project and log stream specified by
  `GALILEO_PROJECT` and `GALILEO_LOG_STREAM`

No additional configuration or code changes are required; all data from the run
is logged to the Galileo project and log stream specified by your environment
configuration (for example, GALILEO_PROJECT and GALILEO_LOG_STREAM).
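If you prefer configuring the destination in code rather than in a `.env` file, you can set the variables before instantiating the listener. A sketch; the variable names are the documented ones above, while the values are placeholders:

```python
import os

# Set (or override) the Galileo destination before creating the listener
os.environ["GALILEO_PROJECT"] = "my-project"        # placeholder value
os.environ["GALILEO_LOG_STREAM"] = "my-log-stream"  # placeholder value

from galileo.handlers.crewai.handler import CrewAIEventListener

CrewAIEventListener()  # reads the configuration from the environment
```
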
@@ -1,25 +0,0 @@
[project]
name = "crewai-files"
dynamic = ["version"]
description = "Add your description here"
readme = "README.md"
authors = [
    { name = "Greyson LaLonde", email = "greyson@crewai.com" }
]
requires-python = ">=3.10, <3.14"
dependencies = [
    "Pillow~=10.4.0",
    "pypdf~=4.0.0",
    "python-magic>=0.4.27",
    "aiocache~=0.12.3",
    "aiofiles~=24.1.0",
    "tinytag~=1.10.0",
    "av~=13.0.0",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.version]
path = "src/crewai_files/__init__.py"
@@ -1,153 +0,0 @@
"""File handling utilities for crewAI tasks."""

from crewai_files.cache.cleanup import (
    cleanup_expired_files,
    cleanup_provider_files,
    cleanup_uploaded_files,
)
from crewai_files.cache.upload_cache import (
    CachedUpload,
    UploadCache,
    get_upload_cache,
    reset_upload_cache,
)
from crewai_files.core.resolved import (
    FileReference,
    InlineBase64,
    InlineBytes,
    ResolvedFile,
    ResolvedFileType,
    UrlReference,
)
from crewai_files.core.sources import (
    FileBytes,
    FilePath,
    FileSource,
    FileSourceInput,
    FileStream,
    FileUrl,
    RawFileInput,
)
from crewai_files.core.types import (
    AudioExtension,
    AudioFile,
    AudioMimeType,
    BaseFile,
    File,
    FileInput,
    FileMode,
    ImageExtension,
    ImageFile,
    ImageMimeType,
    PDFContentType,
    PDFExtension,
    PDFFile,
    TextContentType,
    TextExtension,
    TextFile,
    VideoExtension,
    VideoFile,
    VideoMimeType,
)
from crewai_files.formatting import (
    aformat_multimodal_content,
    format_multimodal_content,
)
from crewai_files.processing import (
    ANTHROPIC_CONSTRAINTS,
    BEDROCK_CONSTRAINTS,
    GEMINI_CONSTRAINTS,
    OPENAI_CONSTRAINTS,
    AudioConstraints,
    FileHandling,
    FileProcessingError,
    FileProcessor,
    FileTooLargeError,
    FileValidationError,
    ImageConstraints,
    PDFConstraints,
    ProcessingDependencyError,
    ProviderConstraints,
    UnsupportedFileTypeError,
    VideoConstraints,
    get_constraints_for_provider,
)
from crewai_files.resolution.resolver import (
    FileResolver,
    FileResolverConfig,
    create_resolver,
)
from crewai_files.resolution.utils import normalize_input_files, wrap_file_source
from crewai_files.uploaders import FileUploader, UploadResult, get_uploader


__all__ = [
    "ANTHROPIC_CONSTRAINTS",
    "BEDROCK_CONSTRAINTS",
    "GEMINI_CONSTRAINTS",
    "OPENAI_CONSTRAINTS",
    "AudioConstraints",
    "AudioExtension",
    "AudioFile",
    "AudioMimeType",
    "BaseFile",
    "CachedUpload",
    "File",
    "FileBytes",
    "FileHandling",
    "FileInput",
    "FileMode",
    "FilePath",
    "FileProcessingError",
    "FileProcessor",
    "FileReference",
    "FileResolver",
    "FileResolverConfig",
    "FileSource",
    "FileSourceInput",
    "FileStream",
    "FileTooLargeError",
    "FileUploader",
    "FileUrl",
    "FileValidationError",
    "ImageConstraints",
    "ImageExtension",
    "ImageFile",
    "ImageMimeType",
    "InlineBase64",
    "InlineBytes",
    "PDFConstraints",
    "PDFContentType",
    "PDFExtension",
    "PDFFile",
    "ProcessingDependencyError",
    "ProviderConstraints",
    "RawFileInput",
    "ResolvedFile",
    "ResolvedFileType",
    "TextContentType",
    "TextExtension",
    "TextFile",
    "UnsupportedFileTypeError",
    "UploadCache",
    "UploadResult",
    "UrlReference",
    "VideoConstraints",
    "VideoExtension",
    "VideoFile",
    "VideoMimeType",
    "aformat_multimodal_content",
    "cleanup_expired_files",
    "cleanup_provider_files",
    "cleanup_uploaded_files",
    "create_resolver",
    "format_multimodal_content",
    "get_constraints_for_provider",
    "get_upload_cache",
    "get_uploader",
    "normalize_input_files",
    "reset_upload_cache",
    "wrap_file_source",
]

__version__ = "1.8.1"
@@ -1,14 +0,0 @@
"""Upload caching and cleanup."""

from crewai_files.cache.cleanup import cleanup_uploaded_files
from crewai_files.cache.metrics import FileOperationMetrics, measure_operation
from crewai_files.cache.upload_cache import UploadCache, get_upload_cache


__all__ = [
    "FileOperationMetrics",
    "UploadCache",
    "cleanup_uploaded_files",
    "get_upload_cache",
    "measure_operation",
]
374
lib/crewai-files/src/crewai_files/cache/cleanup.py
vendored
@@ -1,374 +0,0 @@
"""Cleanup utilities for uploaded files."""

from __future__ import annotations

import asyncio
import logging
from typing import TYPE_CHECKING

from crewai_files.cache.upload_cache import CachedUpload, UploadCache
from crewai_files.uploaders import get_uploader
from crewai_files.uploaders.factory import ProviderType


if TYPE_CHECKING:
    from crewai_files.uploaders.base import FileUploader

logger = logging.getLogger(__name__)


def _safe_delete(
    uploader: FileUploader,
    file_id: str,
    provider: str,
) -> bool:
    """Safely delete a file, logging any errors.

    Args:
        uploader: The file uploader to use.
        file_id: The file ID to delete.
        provider: Provider name for logging.

    Returns:
        True if deleted successfully, False otherwise.
    """
    try:
        if uploader.delete(file_id):
            logger.debug(f"Deleted {file_id} from {provider}")
            return True
        logger.warning(f"Failed to delete {file_id} from {provider}")
        return False
    except Exception as e:
        logger.warning(f"Error deleting {file_id} from {provider}: {e}")
        return False


def cleanup_uploaded_files(
    cache: UploadCache,
    *,
    delete_from_provider: bool = True,
    providers: list[ProviderType] | None = None,
) -> int:
    """Clean up uploaded files from the cache and optionally from providers.

    Args:
        cache: The upload cache to clean up.
        delete_from_provider: If True, delete files from the provider as well.
        providers: Optional list of providers to clean up. If None, cleans all.

    Returns:
        Number of files cleaned up.
    """
    cleaned = 0

    provider_uploads: dict[ProviderType, list[CachedUpload]] = {}

    for provider in _get_providers_from_cache(cache):
        if providers is not None and provider not in providers:
            continue
        provider_uploads[provider] = cache.get_all_for_provider(provider)

    if delete_from_provider:
        for provider, uploads in provider_uploads.items():
            uploader = get_uploader(provider)
            if uploader is None:
                logger.warning(
                    f"No uploader available for {provider}, skipping cleanup"
                )
                continue

            for upload in uploads:
                if _safe_delete(uploader, upload.file_id, provider):
                    cleaned += 1

    cache.clear()

    logger.info(f"Cleaned up {cleaned} uploaded files")
    return cleaned


def cleanup_expired_files(
    cache: UploadCache,
    *,
    delete_from_provider: bool = False,
) -> int:
    """Clean up expired files from the cache.

    Args:
        cache: The upload cache to clean up.
        delete_from_provider: If True, attempt to delete from provider as well.
            Note: Expired files may already be deleted by the provider.

    Returns:
        Number of expired entries removed from cache.
    """
    expired_entries: list[CachedUpload] = []

    if delete_from_provider:
        for provider in _get_providers_from_cache(cache):
            expired_entries.extend(
                upload
                for upload in cache.get_all_for_provider(provider)
                if upload.is_expired()
            )

    removed = cache.clear_expired()

    if delete_from_provider:
        for upload in expired_entries:
            uploader = get_uploader(upload.provider)
            if uploader is not None:
                try:
                    uploader.delete(upload.file_id)
                except Exception as e:
                    logger.debug(f"Could not delete expired file {upload.file_id}: {e}")

    return removed


def cleanup_provider_files(
    provider: ProviderType,
    *,
    cache: UploadCache | None = None,
    delete_all_from_provider: bool = False,
) -> int:
    """Clean up all files for a specific provider.

    Args:
        provider: Provider name to clean up.
        cache: Optional upload cache to clear entries from.
        delete_all_from_provider: If True, delete all files from the provider,
            not just cached ones.

    Returns:
        Number of files deleted.
    """
    deleted = 0
    uploader = get_uploader(provider)

    if uploader is None:
        logger.warning(f"No uploader available for {provider}")
        return 0

    if delete_all_from_provider:
        try:
            files = uploader.list_files()
            for file_info in files:
                file_id = file_info.get("id") or file_info.get("name")
                if file_id and uploader.delete(file_id):
                    deleted += 1
        except Exception as e:
            logger.warning(f"Error listing/deleting files from {provider}: {e}")
    elif cache is not None:
        uploads = cache.get_all_for_provider(provider)
        for upload in uploads:
            if _safe_delete(uploader, upload.file_id, provider):
                deleted += 1
                cache.remove_by_file_id(upload.file_id, provider)

    logger.info(f"Deleted {deleted} files from {provider}")
    return deleted


def _get_providers_from_cache(cache: UploadCache) -> set[ProviderType]:
    """Get unique provider names from cache entries.

    Args:
        cache: The upload cache.

    Returns:
        Set of provider names.
    """
    return cache.get_providers()


async def _asafe_delete(
    uploader: FileUploader,
    file_id: str,
    provider: str,
) -> bool:
    """Async safely delete a file, logging any errors.

    Args:
        uploader: The file uploader to use.
        file_id: The file ID to delete.
        provider: Provider name for logging.

    Returns:
        True if deleted successfully, False otherwise.
    """
    try:
        if await uploader.adelete(file_id):
            logger.debug(f"Deleted {file_id} from {provider}")
            return True
        logger.warning(f"Failed to delete {file_id} from {provider}")
        return False
    except Exception as e:
        logger.warning(f"Error deleting {file_id} from {provider}: {e}")
        return False


async def acleanup_uploaded_files(
    cache: UploadCache,
    *,
    delete_from_provider: bool = True,
    providers: list[ProviderType] | None = None,
    max_concurrency: int = 10,
) -> int:
    """Async clean up uploaded files from the cache and optionally from providers.

    Args:
        cache: The upload cache to clean up.
        delete_from_provider: If True, delete files from the provider as well.
        providers: Optional list of providers to clean up. If None, cleans all.
        max_concurrency: Maximum number of concurrent delete operations.

    Returns:
        Number of files cleaned up.
    """
    cleaned = 0

    provider_uploads: dict[ProviderType, list[CachedUpload]] = {}

    for provider in _get_providers_from_cache(cache):
        if providers is not None and provider not in providers:
            continue
        provider_uploads[provider] = await cache.aget_all_for_provider(provider)

    if delete_from_provider:
        semaphore = asyncio.Semaphore(max_concurrency)

        async def delete_one(file_uploader: FileUploader, cached: CachedUpload) -> bool:
            """Delete a single file with semaphore limiting."""
            async with semaphore:
                return await _asafe_delete(
                    file_uploader, cached.file_id, cached.provider
                )

        tasks: list[asyncio.Task[bool]] = []
        for provider, uploads in provider_uploads.items():
            uploader = get_uploader(provider)
            if uploader is None:
                logger.warning(
                    f"No uploader available for {provider}, skipping cleanup"
                )
                continue

            tasks.extend(
                asyncio.create_task(delete_one(uploader, cached)) for cached in uploads
            )

        results = await asyncio.gather(*tasks, return_exceptions=True)
        cleaned = sum(1 for r in results if r is True)

    await cache.aclear()

    logger.info(f"Cleaned up {cleaned} uploaded files")
    return cleaned


async def acleanup_expired_files(
    cache: UploadCache,
    *,
    delete_from_provider: bool = False,
    max_concurrency: int = 10,
) -> int:
    """Async clean up expired files from the cache.

    Args:
        cache: The upload cache to clean up.
        delete_from_provider: If True, attempt to delete from provider as well.
        max_concurrency: Maximum number of concurrent delete operations.

    Returns:
        Number of expired entries removed from cache.
    """
    expired_entries: list[CachedUpload] = []

    if delete_from_provider:
        for provider in _get_providers_from_cache(cache):
            uploads = await cache.aget_all_for_provider(provider)
            expired_entries.extend(upload for upload in uploads if upload.is_expired())

    removed = await cache.aclear_expired()

    if delete_from_provider and expired_entries:
        semaphore = asyncio.Semaphore(max_concurrency)

        async def delete_expired(cached: CachedUpload) -> None:
            """Delete an expired file with semaphore limiting."""
            async with semaphore:
                file_uploader = get_uploader(cached.provider)
                if file_uploader is not None:
                    try:
                        await file_uploader.adelete(cached.file_id)
                    except Exception as e:
                        logger.debug(
                            f"Could not delete expired file {cached.file_id}: {e}"
                        )

        await asyncio.gather(
            *[delete_expired(cached) for cached in expired_entries],
            return_exceptions=True,
        )

    return removed


async def acleanup_provider_files(
    provider: ProviderType,
    *,
    cache: UploadCache | None = None,
    delete_all_from_provider: bool = False,
    max_concurrency: int = 10,
) -> int:
    """Async clean up all files for a specific provider.

    Args:
        provider: Provider name to clean up.
        cache: Optional upload cache to clear entries from.
        delete_all_from_provider: If True, delete all files from the provider.
        max_concurrency: Maximum number of concurrent delete operations.

    Returns:
        Number of files deleted.
    """
    deleted = 0
    uploader = get_uploader(provider)

    if uploader is None:
        logger.warning(f"No uploader available for {provider}")
        return 0

    semaphore = asyncio.Semaphore(max_concurrency)

    async def delete_single(target_file_id: str) -> bool:
        """Delete a single file with semaphore limiting."""
        async with semaphore:
            return await uploader.adelete(target_file_id)

    if delete_all_from_provider:
        try:
            files = uploader.list_files()
            tasks = []
            for file_info in files:
                fid = file_info.get("id") or file_info.get("name")
                if fid:
                    tasks.append(delete_single(fid))
            results = await asyncio.gather(*tasks, return_exceptions=True)
            deleted = sum(1 for r in results if r is True)
        except Exception as e:
            logger.warning(f"Error listing/deleting files from {provider}: {e}")
    elif cache is not None:
        uploads = await cache.aget_all_for_provider(provider)
        tasks = []
        for upload in uploads:
            tasks.append(delete_single(upload.file_id))
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for upload, result in zip(uploads, results, strict=False):
            if result is True:
                deleted += 1
                await cache.aremove_by_file_id(upload.file_id, provider)

    logger.info(f"Deleted {deleted} files from {provider}")
    return deleted
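A minimal usage sketch for these cleanup helpers, assuming the default in-memory cache (not part of the original file):

# Illustrative usage (not from the original file):
#
#     from crewai_files.cache.cleanup import cleanup_uploaded_files
#     from crewai_files.cache.upload_cache import get_upload_cache
#
#     cache = get_upload_cache()
#     ...  # uploads happen elsewhere and are recorded in the cache
#     removed = cleanup_uploaded_files(cache, delete_from_provider=True)
#     print(f"removed {removed} uploaded files")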
184
lib/crewai-files/src/crewai_files/cache/metrics.py
vendored
@@ -1,184 +0,0 @@
"""Performance metrics and structured logging for file operations."""

from __future__ import annotations

from collections.abc import Generator
from contextlib import contextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone
import logging
import time
from typing import Any


logger = logging.getLogger(__name__)


@dataclass
class FileOperationMetrics:
    """Metrics for a file operation.

    Attributes:
        operation: Name of the operation (e.g., "upload", "resolve", "process").
        filename: Name of the file being operated on.
        provider: Provider name if applicable.
        duration_ms: Duration of the operation in milliseconds.
        size_bytes: Size of the file in bytes.
        success: Whether the operation succeeded.
        error: Error message if operation failed.
        timestamp: When the operation occurred.
        metadata: Additional operation-specific metadata.
    """

    operation: str
    filename: str | None = None
    provider: str | None = None
    duration_ms: float = 0.0
    size_bytes: int | None = None
    success: bool = True
    error: str | None = None
    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Convert metrics to dictionary for logging.

        Returns:
            Dictionary representation of metrics.
        """
        result: dict[str, Any] = {
            "operation": self.operation,
            "duration_ms": round(self.duration_ms, 2),
            "success": self.success,
            "timestamp": self.timestamp.isoformat(),
        }

        if self.filename:
            result["filename"] = self.filename
        if self.provider:
            result["provider"] = self.provider
        if self.size_bytes is not None:
            result["size_bytes"] = self.size_bytes
        if self.error:
            result["error"] = self.error
        if self.metadata:
            result.update(self.metadata)

        return result


@contextmanager
def measure_operation(
    operation: str,
    *,
    filename: str | None = None,
    provider: str | None = None,
    size_bytes: int | None = None,
    log_level: int = logging.DEBUG,
    **extra_metadata: Any,
) -> Generator[FileOperationMetrics, None, None]:
    """Context manager to measure and log operation performance.

    Args:
        operation: Name of the operation.
        filename: Optional filename being operated on.
        provider: Optional provider name.
        size_bytes: Optional file size in bytes.
        log_level: Log level for the result message.
        **extra_metadata: Additional metadata to include.

    Yields:
        FileOperationMetrics object that will be populated with results.

    Example:
        with measure_operation("upload", filename="test.pdf", provider="openai") as metrics:
            result = upload_file(file)
            metrics.metadata["file_id"] = result.file_id
    """
    metrics = FileOperationMetrics(
        operation=operation,
        filename=filename,
        provider=provider,
        size_bytes=size_bytes,
        metadata=dict(extra_metadata),
    )

    start_time = time.perf_counter()

    try:
        yield metrics
        metrics.success = True
    except Exception as e:
        metrics.success = False
        metrics.error = str(e)
        raise
    finally:
        metrics.duration_ms = (time.perf_counter() - start_time) * 1000

        log_message = f"{operation}"
        if filename:
            log_message += f" [{filename}]"
        if provider:
            log_message += f" ({provider})"

        if metrics.success:
            log_message += f" completed in {metrics.duration_ms:.2f}ms"
        else:
            log_message += f" failed after {metrics.duration_ms:.2f}ms: {metrics.error}"

        logger.log(log_level, log_message, extra=metrics.to_dict())


def log_file_operation(
    operation: str,
    *,
    filename: str | None = None,
    provider: str | None = None,
    size_bytes: int | None = None,
    duration_ms: float | None = None,
    success: bool = True,
    error: str | None = None,
    level: int = logging.INFO,
    **extra: Any,
) -> None:
    """Log a file operation with structured data.

    Args:
        operation: Name of the operation.
        filename: Optional filename being operated on.
        provider: Optional provider name.
        size_bytes: Optional file size in bytes.
        duration_ms: Optional duration in milliseconds.
        success: Whether the operation succeeded.
        error: Optional error message.
        level: Log level to use.
        **extra: Additional metadata to include.
    """
    metrics = FileOperationMetrics(
        operation=operation,
        filename=filename,
        provider=provider,
        size_bytes=size_bytes,
        duration_ms=duration_ms or 0.0,
        success=success,
        error=error,
        metadata=dict(extra),
    )

    message = f"{operation}"
    if filename:
        message += f" [{filename}]"
    if provider:
        message += f" ({provider})"

    if success:
        if duration_ms:
            message += f" completed in {duration_ms:.2f}ms"
        else:
            message += " completed"
    else:
        message += " failed"
        if error:
            message += f": {error}"

    logger.log(level, message, extra=metrics.to_dict())
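A brief usage sketch for these logging helpers; the filename and id values are placeholders (not part of the original file):

# Illustrative usage (not from the original file):
#
#     from crewai_files.cache.metrics import log_file_operation, measure_operation
#
#     with measure_operation("upload", filename="report.pdf", provider="openai") as m:
#         m.metadata["file_id"] = "file-123"  # placeholder id
#
#     log_file_operation("resolve", filename="report.pdf", duration_ms=12.5)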
@@ -1,553 +0,0 @@
|
||||
"""Cache for tracking uploaded files using aiocache."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import atexit
|
||||
import builtins
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
import hashlib
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from aiocache import Cache # type: ignore[import-untyped]
|
||||
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
|
||||
|
||||
from crewai_files.core.constants import DEFAULT_MAX_CACHE_ENTRIES, DEFAULT_TTL_SECONDS
|
||||
from crewai_files.uploaders.factory import ProviderType
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai_files.core.types import FileInput
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CachedUpload:
|
||||
"""Represents a cached file upload.
|
||||
|
||||
Attributes:
|
||||
file_id: Provider-specific file identifier.
|
||||
provider: Name of the provider.
|
||||
file_uri: Optional URI for accessing the file.
|
||||
content_type: MIME type of the uploaded file.
|
||||
uploaded_at: When the file was uploaded.
|
||||
expires_at: When the upload expires (if applicable).
|
||||
"""
|
||||
|
||||
file_id: str
|
||||
provider: ProviderType
|
||||
file_uri: str | None
|
||||
content_type: str
|
||||
uploaded_at: datetime
|
||||
expires_at: datetime | None = None
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if this cached upload has expired."""
|
||||
if self.expires_at is None:
|
||||
return False
|
||||
return datetime.now(timezone.utc) >= self.expires_at
|
||||
|
||||
|
||||
def _make_key(file_hash: str, provider: str) -> str:
|
||||
"""Create a cache key from file hash and provider."""
|
||||
return f"upload:{provider}:{file_hash}"
|
||||
|
||||
|
||||
def _compute_file_hash_streaming(chunks: Iterator[bytes]) -> str:
|
||||
"""Compute SHA-256 hash from streaming chunks.
|
||||
|
||||
Args:
|
||||
chunks: Iterator of byte chunks.
|
||||
|
||||
Returns:
|
||||
Hexadecimal hash string.
|
||||
"""
|
||||
hasher = hashlib.sha256()
|
||||
for chunk in chunks:
|
||||
hasher.update(chunk)
|
||||
return hasher.hexdigest()
|
||||
|
||||
|
||||
def _compute_file_hash(file: FileInput) -> str:
|
||||
"""Compute SHA-256 hash of file content.
|
||||
|
||||
Uses streaming for FilePath sources to avoid loading large files into memory.
|
||||
"""
|
||||
from crewai_files.core.sources import FilePath
|
||||
|
||||
source = file._file_source
|
||||
if isinstance(source, FilePath):
|
||||
return _compute_file_hash_streaming(source.read_chunks(chunk_size=1024 * 1024))
|
||||
content = file.read()
|
||||
return hashlib.sha256(content).hexdigest()
|
||||
|
||||
|
||||
class UploadCache:
|
||||
"""Async cache for tracking uploaded files using aiocache.
|
||||
|
||||
Supports in-memory caching by default, with optional Redis backend
|
||||
for distributed setups.
|
||||
|
||||
Attributes:
|
||||
ttl: Default time-to-live in seconds for cached entries.
|
||||
namespace: Cache namespace for isolation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ttl: int = DEFAULT_TTL_SECONDS,
|
||||
namespace: str = "crewai_uploads",
|
||||
cache_type: str = "memory",
|
||||
max_entries: int | None = DEFAULT_MAX_CACHE_ENTRIES,
|
||||
**cache_kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the upload cache.
|
||||
|
||||
Args:
|
||||
ttl: Default TTL in seconds.
|
||||
namespace: Cache namespace.
|
||||
cache_type: Backend type ("memory" or "redis").
|
||||
max_entries: Maximum cache entries (None for unlimited).
|
||||
**cache_kwargs: Additional args for cache backend.
|
||||
"""
|
||||
self.ttl = ttl
|
||||
self.namespace = namespace
|
||||
self.max_entries = max_entries
|
||||
self._provider_keys: dict[ProviderType, set[str]] = {}
|
||||
self._key_access_order: list[str] = []
|
||||
|
||||
if cache_type == "redis":
|
||||
self._cache = Cache(
|
||||
Cache.REDIS,
|
||||
serializer=PickleSerializer(),
|
||||
namespace=namespace,
|
||||
**cache_kwargs,
|
||||
)
|
||||
else:
|
||||
self._cache = Cache(
|
||||
serializer=PickleSerializer(),
|
||||
namespace=namespace,
|
||||
)
|
||||
|
||||
def _track_key(self, provider: ProviderType, key: str) -> None:
|
||||
"""Track a key for a provider (for cleanup) and access order."""
|
||||
if provider not in self._provider_keys:
|
||||
self._provider_keys[provider] = set()
|
||||
self._provider_keys[provider].add(key)
|
||||
if key in self._key_access_order:
|
||||
self._key_access_order.remove(key)
|
||||
self._key_access_order.append(key)
|
||||
|
||||
def _untrack_key(self, provider: ProviderType, key: str) -> None:
|
||||
"""Remove key tracking for a provider."""
|
||||
if provider in self._provider_keys:
|
||||
self._provider_keys[provider].discard(key)
|
||||
if key in self._key_access_order:
|
||||
self._key_access_order.remove(key)
|
||||
|
||||
async def _evict_if_needed(self) -> int:
|
||||
"""Evict oldest entries if limit exceeded.
|
||||
|
||||
Returns:
|
||||
Number of entries evicted.
|
||||
"""
|
||||
if self.max_entries is None:
|
||||
return 0
|
||||
|
||||
current_count = len(self)
|
||||
if current_count < self.max_entries:
|
||||
return 0
|
||||
|
||||
to_evict = max(1, self.max_entries // 10)
|
||||
return await self._evict_oldest(to_evict)
|
||||
|
||||
async def _evict_oldest(self, count: int) -> int:
|
||||
"""Evict the oldest entries from the cache.
|
||||
|
||||
Args:
|
||||
count: Number of entries to evict.
|
||||
|
||||
Returns:
|
||||
Number of entries actually evicted.
|
||||
"""
|
||||
evicted = 0
|
||||
keys_to_evict = self._key_access_order[:count]
|
||||
|
||||
for key in keys_to_evict:
|
||||
await self._cache.delete(key)
|
||||
self._key_access_order.remove(key)
|
||||
for provider_keys in self._provider_keys.values():
|
||||
provider_keys.discard(key)
|
||||
evicted += 1
|
||||
|
||||
if evicted > 0:
|
||||
logger.debug(f"Evicted {evicted} oldest cache entries")
|
||||
|
||||
return evicted
|
||||
|
||||
async def aget(
|
||||
self, file: FileInput, provider: ProviderType
|
||||
) -> CachedUpload | None:
|
||||
"""Get a cached upload for a file.
|
||||
|
||||
Args:
|
||||
file: The file to look up.
|
||||
provider: The provider name.
|
||||
|
||||
Returns:
|
||||
Cached upload if found and not expired, None otherwise.
|
||||
"""
|
||||
file_hash = _compute_file_hash(file)
|
||||
return await self.aget_by_hash(file_hash, provider)
|
||||
|
||||
async def aget_by_hash(
|
||||
self, file_hash: str, provider: ProviderType
|
||||
) -> CachedUpload | None:
|
||||
"""Get a cached upload by file hash.
|
||||
|
||||
Args:
|
||||
file_hash: Hash of the file content.
|
||||
provider: The provider name.
|
||||
|
||||
Returns:
|
||||
Cached upload if found and not expired, None otherwise.
|
||||
"""
|
||||
key = _make_key(file_hash, provider)
|
||||
result = await self._cache.get(key)
|
||||
|
||||
if result is None:
|
||||
return None
|
||||
if isinstance(result, CachedUpload):
|
||||
if result.is_expired():
|
||||
await self._cache.delete(key)
|
||||
self._untrack_key(provider, key)
|
||||
return None
|
||||
return result
|
||||
return None
|
||||
|
||||
async def aset(
|
||||
self,
|
||||
file: FileInput,
|
||||
provider: ProviderType,
|
||||
file_id: str,
|
||||
file_uri: str | None = None,
|
||||
expires_at: datetime | None = None,
|
||||
) -> CachedUpload:
|
||||
"""Cache an uploaded file.
|
||||
|
||||
Args:
|
||||
file: The file that was uploaded.
|
||||
provider: The provider name.
|
||||
file_id: Provider-specific file identifier.
|
||||
file_uri: Optional URI for accessing the file.
|
||||
expires_at: When the upload expires.
|
||||
|
||||
Returns:
|
||||
The created cache entry.
|
||||
"""
|
||||
file_hash = _compute_file_hash(file)
|
||||
return await self.aset_by_hash(
|
||||
file_hash=file_hash,
|
||||
content_type=file.content_type,
|
||||
provider=provider,
|
||||
file_id=file_id,
|
||||
file_uri=file_uri,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
async def aset_by_hash(
|
||||
self,
|
||||
file_hash: str,
|
||||
content_type: str,
|
||||
provider: ProviderType,
|
||||
file_id: str,
|
||||
file_uri: str | None = None,
|
||||
expires_at: datetime | None = None,
|
||||
) -> CachedUpload:
|
||||
"""Cache an uploaded file by hash.
|
||||
|
||||
Args:
|
||||
file_hash: Hash of the file content.
|
||||
content_type: MIME type of the file.
|
||||
provider: The provider name.
|
||||
file_id: Provider-specific file identifier.
|
||||
file_uri: Optional URI for accessing the file.
|
||||
expires_at: When the upload expires.
|
||||
|
||||
Returns:
|
||||
The created cache entry.
|
||||
"""
|
||||
await self._evict_if_needed()
|
||||
|
||||
key = _make_key(file_hash, provider)
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
cached = CachedUpload(
|
||||
file_id=file_id,
|
||||
provider=provider,
|
||||
file_uri=file_uri,
|
||||
content_type=content_type,
|
||||
uploaded_at=now,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
ttl = self.ttl
|
||||
if expires_at is not None:
|
||||
ttl = max(0, int((expires_at - now).total_seconds()))
|
||||
|
||||
await self._cache.set(key, cached, ttl=ttl)
|
||||
self._track_key(provider, key)
|
||||
logger.debug(f"Cached upload: {file_id} for provider {provider}")
|
||||
return cached
|
||||
|
||||
async def aremove(self, file: FileInput, provider: ProviderType) -> bool:
|
||||
"""Remove a cached upload.
|
||||
|
||||
Args:
|
||||
file: The file to remove.
|
||||
provider: The provider name.
|
||||
|
||||
Returns:
|
||||
True if entry was removed, False if not found.
|
||||
"""
|
||||
file_hash = _compute_file_hash(file)
|
||||
key = _make_key(file_hash, provider)
|
||||
|
||||
result = await self._cache.delete(key)
|
||||
removed = bool(result > 0 if isinstance(result, int) else result)
|
||||
if removed:
|
||||
self._untrack_key(provider, key)
|
||||
return removed
|
||||
|
||||
async def aremove_by_file_id(self, file_id: str, provider: ProviderType) -> bool:
|
||||
"""Remove a cached upload by file ID.
|
||||
|
||||
Args:
|
||||
file_id: The file ID to remove.
|
||||
provider: The provider name.
|
||||
|
||||
Returns:
|
||||
True if entry was removed, False if not found.
|
||||
"""
|
||||
if provider not in self._provider_keys:
|
||||
return False
|
||||
|
||||
for key in list(self._provider_keys[provider]):
|
||||
cached = await self._cache.get(key)
|
||||
if isinstance(cached, CachedUpload) and cached.file_id == file_id:
|
||||
await self._cache.delete(key)
|
||||
self._untrack_key(provider, key)
|
||||
return True
|
||||
return False
|
||||
|
||||
async def aclear_expired(self) -> int:
|
||||
"""Remove all expired entries from the cache.
|
||||
|
||||
Returns:
|
||||
Number of entries removed.
|
||||
"""
|
||||
removed = 0
|
||||
|
||||
for provider, keys in list(self._provider_keys.items()):
|
||||
for key in list(keys):
|
||||
cached = await self._cache.get(key)
|
||||
if cached is None or (
|
||||
isinstance(cached, CachedUpload) and cached.is_expired()
|
||||
):
|
||||
await self._cache.delete(key)
|
||||
self._untrack_key(provider, key)
|
||||
removed += 1
|
||||
|
||||
if removed > 0:
|
||||
logger.debug(f"Cleared {removed} expired cache entries")
|
||||
return removed
|
||||
|
||||
async def aclear(self) -> int:
|
||||
"""Clear all entries from the cache.
|
||||
|
||||
Returns:
|
||||
Number of entries cleared.
|
||||
"""
|
||||
count = sum(len(keys) for keys in self._provider_keys.values())
|
||||
await self._cache.clear(namespace=self.namespace)
|
||||
self._provider_keys.clear()
|
||||
|
||||
if count > 0:
|
||||
logger.debug(f"Cleared {count} cache entries")
|
||||
return count
|
||||
|
||||
async def aget_all_for_provider(self, provider: ProviderType) -> list[CachedUpload]:
|
||||
"""Get all cached uploads for a provider.
|
||||
|
||||
Args:
|
||||
provider: The provider name.
|
||||
|
||||
Returns:
|
||||
List of cached uploads for the provider.
|
||||
"""
|
||||
if provider not in self._provider_keys:
|
||||
return []
|
||||
|
||||
results: list[CachedUpload] = []
|
||||
for key in list(self._provider_keys[provider]):
|
||||
cached = await self._cache.get(key)
|
||||
if isinstance(cached, CachedUpload) and not cached.is_expired():
|
||||
results.append(cached)
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def _run_sync(coro: Any) -> Any:
|
||||
"""Run an async coroutine from sync context without blocking event loop."""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
|
||||
        if loop is not None and loop.is_running():
            future = asyncio.run_coroutine_threadsafe(coro, loop)
            return future.result(timeout=30)
        return asyncio.run(coro)

    def get(self, file: FileInput, provider: ProviderType) -> CachedUpload | None:
        """Sync wrapper for aget."""
        result: CachedUpload | None = self._run_sync(self.aget(file, provider))
        return result

    def get_by_hash(
        self, file_hash: str, provider: ProviderType
    ) -> CachedUpload | None:
        """Sync wrapper for aget_by_hash."""
        result: CachedUpload | None = self._run_sync(
            self.aget_by_hash(file_hash, provider)
        )
        return result

    def set(
        self,
        file: FileInput,
        provider: ProviderType,
        file_id: str,
        file_uri: str | None = None,
        expires_at: datetime | None = None,
    ) -> CachedUpload:
        """Sync wrapper for aset."""
        result: CachedUpload = self._run_sync(
            self.aset(file, provider, file_id, file_uri, expires_at)
        )
        return result

    def set_by_hash(
        self,
        file_hash: str,
        content_type: str,
        provider: ProviderType,
        file_id: str,
        file_uri: str | None = None,
        expires_at: datetime | None = None,
    ) -> CachedUpload:
        """Sync wrapper for aset_by_hash."""
        result: CachedUpload = self._run_sync(
            self.aset_by_hash(
                file_hash, content_type, provider, file_id, file_uri, expires_at
            )
        )
        return result

    def remove(self, file: FileInput, provider: ProviderType) -> bool:
        """Sync wrapper for aremove."""
        result: bool = self._run_sync(self.aremove(file, provider))
        return result

    def remove_by_file_id(self, file_id: str, provider: ProviderType) -> bool:
        """Sync wrapper for aremove_by_file_id."""
        result: bool = self._run_sync(self.aremove_by_file_id(file_id, provider))
        return result

    def clear_expired(self) -> int:
        """Sync wrapper for aclear_expired."""
        result: int = self._run_sync(self.aclear_expired())
        return result

    def clear(self) -> int:
        """Sync wrapper for aclear."""
        result: int = self._run_sync(self.aclear())
        return result

    def get_all_for_provider(self, provider: ProviderType) -> list[CachedUpload]:
        """Sync wrapper for aget_all_for_provider."""
        result: list[CachedUpload] = self._run_sync(
            self.aget_all_for_provider(provider)
        )
        return result

    def __len__(self) -> int:
        """Return the number of cached entries."""
        return sum(len(keys) for keys in self._provider_keys.values())

    def get_providers(self) -> builtins.set[ProviderType]:
        """Get all provider names that have cached entries.

        Returns:
            Set of provider names.
        """
        return builtins.set(self._provider_keys.keys())


_default_cache: UploadCache | None = None


def get_upload_cache(
    ttl: int = DEFAULT_TTL_SECONDS,
    namespace: str = "crewai_uploads",
    cache_type: str = "memory",
    **cache_kwargs: Any,
) -> UploadCache:
    """Get or create the default upload cache.

    Args:
        ttl: Default TTL in seconds.
        namespace: Cache namespace.
        cache_type: Backend type ("memory" or "redis").
        **cache_kwargs: Additional args for cache backend.

    Returns:
        The upload cache instance.
    """
    global _default_cache
    if _default_cache is None:
        _default_cache = UploadCache(
            ttl=ttl,
            namespace=namespace,
            cache_type=cache_type,
            **cache_kwargs,
        )
    return _default_cache


def reset_upload_cache() -> None:
    """Reset the default upload cache (useful for testing)."""
    global _default_cache
    if _default_cache is not None:
        _default_cache.clear()
        _default_cache = None


def _cleanup_on_exit() -> None:
    """Clean up uploaded files on process exit."""
    global _default_cache
    if _default_cache is None or len(_default_cache) == 0:
        return

    from crewai_files.cache.cleanup import cleanup_uploaded_files

    try:
        cleanup_uploaded_files(_default_cache)
    except Exception as e:
        logger.debug(f"Error during exit cleanup: {e}")


atexit.register(_cleanup_on_exit)
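
# Usage sketch for the module-level helpers above (names are the ones defined
# in this file; the hash and provider values are placeholder examples):
#
#     cache = get_upload_cache(ttl=3600)              # in-memory backend by default
#     print(cache.get_by_hash("ab12cd34", "openai"))  # None on a cold cache
#     print(len(cache), cache.get_providers())        # 0 set()
#     reset_upload_cache()                            # clear global state, e.g. in tests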
@@ -1,92 +0,0 @@
"""Core file types and sources."""

from crewai_files.core.constants import (
    BACKOFF_BASE_DELAY,
    BACKOFF_JITTER_FACTOR,
    BACKOFF_MAX_DELAY,
    DEFAULT_MAX_CACHE_ENTRIES,
    DEFAULT_MAX_FILE_SIZE_BYTES,
    DEFAULT_TTL_SECONDS,
    DEFAULT_UPLOAD_CHUNK_SIZE,
    FILES_API_MAX_SIZE,
    GEMINI_FILE_TTL,
    MAGIC_BUFFER_SIZE,
    MAX_CONCURRENCY,
    MULTIPART_CHUNKSIZE,
    MULTIPART_THRESHOLD,
    UPLOAD_MAX_RETRIES,
    UPLOAD_RETRY_DELAY_BASE,
)
from crewai_files.core.resolved import (
    FileReference,
    InlineBase64,
    InlineBytes,
    ResolvedFile,
    UrlReference,
)
from crewai_files.core.sources import (
    AsyncFileStream,
    FileBytes,
    FilePath,
    FileSource,
    FileStream,
    FileUrl,
)
from crewai_files.core.types import (
    AudioFile,
    AudioMimeType,
    BaseFile,
    CoercedFileSource,
    File,
    FileInput,
    FileMode,
    ImageFile,
    ImageMimeType,
    PDFFile,
    TextFile,
    VideoFile,
    VideoMimeType,
)


__all__ = [
    "BACKOFF_BASE_DELAY",
    "BACKOFF_JITTER_FACTOR",
    "BACKOFF_MAX_DELAY",
    "DEFAULT_MAX_CACHE_ENTRIES",
    "DEFAULT_MAX_FILE_SIZE_BYTES",
    "DEFAULT_TTL_SECONDS",
    "DEFAULT_UPLOAD_CHUNK_SIZE",
    "FILES_API_MAX_SIZE",
    "GEMINI_FILE_TTL",
    "MAGIC_BUFFER_SIZE",
    "MAX_CONCURRENCY",
    "MULTIPART_CHUNKSIZE",
    "MULTIPART_THRESHOLD",
    "UPLOAD_MAX_RETRIES",
    "UPLOAD_RETRY_DELAY_BASE",
    "AsyncFileStream",
    "AudioFile",
    "AudioMimeType",
    "BaseFile",
    "CoercedFileSource",
    "File",
    "FileBytes",
    "FileInput",
    "FileMode",
    "FilePath",
    "FileReference",
    "FileSource",
    "FileStream",
    "FileUrl",
    "ImageFile",
    "ImageMimeType",
    "InlineBase64",
    "InlineBytes",
    "PDFFile",
    "ResolvedFile",
    "TextFile",
    "UrlReference",
    "VideoFile",
    "VideoMimeType",
]
@@ -1,26 +0,0 @@
"""Constants for file handling utilities."""

from datetime import timedelta
from typing import Final, Literal


DEFAULT_MAX_FILE_SIZE_BYTES: Final[Literal[524_288_000]] = 524_288_000
MAGIC_BUFFER_SIZE: Final[Literal[2048]] = 2048

UPLOAD_MAX_RETRIES: Final[Literal[3]] = 3
UPLOAD_RETRY_DELAY_BASE: Final[Literal[2]] = 2

DEFAULT_TTL_SECONDS: Final[Literal[86_400]] = 86_400
DEFAULT_MAX_CACHE_ENTRIES: Final[Literal[1000]] = 1000

GEMINI_FILE_TTL: Final[timedelta] = timedelta(hours=48)
BACKOFF_BASE_DELAY: Final[float] = 1.0
BACKOFF_MAX_DELAY: Final[float] = 30.0
BACKOFF_JITTER_FACTOR: Final[float] = 0.1

FILES_API_MAX_SIZE: Final[Literal[536_870_912]] = 536_870_912
DEFAULT_UPLOAD_CHUNK_SIZE: Final[Literal[67_108_864]] = 67_108_864

MULTIPART_THRESHOLD: Final[Literal[8_388_608]] = 8_388_608
MULTIPART_CHUNKSIZE: Final[Literal[8_388_608]] = 8_388_608
MAX_CONCURRENCY: Final[Literal[10]] = 10
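
# Sanity check of the byte constants above in human units (exact powers of two):
assert DEFAULT_MAX_FILE_SIZE_BYTES == 500 * 1024 * 1024             # 500 MiB
assert FILES_API_MAX_SIZE == 512 * 1024 * 1024                      # 512 MiB
assert DEFAULT_UPLOAD_CHUNK_SIZE == 64 * 1024 * 1024                # 64 MiB
assert MULTIPART_THRESHOLD == MULTIPART_CHUNKSIZE == 8 * 1024 * 1024  # 8 MiB
assert DEFAULT_TTL_SECONDS == 24 * 60 * 60                          # 24 hours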
@@ -1,84 +0,0 @@
"""Resolved file types representing different delivery methods for file content."""

from abc import ABC
from dataclasses import dataclass
from datetime import datetime


@dataclass(frozen=True)
class ResolvedFile(ABC):
    """Base class for resolved file representations.

    A ResolvedFile represents the final form of a file ready for delivery
    to an LLM provider, whether inline or via reference.

    Attributes:
        content_type: MIME type of the file content.
    """

    content_type: str


@dataclass(frozen=True)
class InlineBase64(ResolvedFile):
    """File content encoded as a base64 string.

    Used by most providers for inline file content in messages.

    Attributes:
        content_type: MIME type of the file content.
        data: Base64-encoded file content.
    """

    data: str


@dataclass(frozen=True)
class InlineBytes(ResolvedFile):
    """File content as raw bytes.

    Used by providers like Bedrock that accept raw bytes instead of base64.

    Attributes:
        content_type: MIME type of the file content.
        data: Raw file bytes.
    """

    data: bytes


@dataclass(frozen=True)
class FileReference(ResolvedFile):
    """Reference to an uploaded file.

    Used when files are uploaded via provider File APIs.

    Attributes:
        content_type: MIME type of the file content.
        file_id: Provider-specific file identifier.
        provider: Name of the provider the file was uploaded to.
        expires_at: When the uploaded file expires (if applicable).
        file_uri: Optional URI for accessing the file (used by Gemini).
    """

    file_id: str
    provider: str
    expires_at: datetime | None = None
    file_uri: str | None = None


@dataclass(frozen=True)
class UrlReference(ResolvedFile):
    """Reference to a file accessible via URL.

    Used by providers that support fetching files from URLs.

    Attributes:
        content_type: MIME type of the file content.
        url: URL where the file can be accessed.
    """

    url: str


ResolvedFileType = InlineBase64 | InlineBytes | FileReference | UrlReference
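
# Usage sketch: downstream formatters dispatch on the concrete ResolvedFile
# subtype; the MIME type and payload below are arbitrary example values.
resolved: ResolvedFileType = InlineBase64(content_type="image/png", data="iVBORw0KGgo=")
match resolved:
    case FileReference(file_id=fid):
        print(f"uploaded file {fid}")
    case UrlReference(url=url):
        print(f"remote file at {url}")
    case InlineBase64(data=payload) | InlineBytes(data=payload):
        print(f"inline payload of length {len(payload)}")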
@@ -1,513 +0,0 @@
"""Base file class for handling file inputs in tasks."""

from __future__ import annotations

from collections.abc import AsyncIterator, Iterator
import mimetypes
from pathlib import Path
from typing import Annotated, Any, BinaryIO, Protocol, cast, runtime_checkable

import aiofiles
from pydantic import (
    BaseModel,
    BeforeValidator,
    Field,
    GetCoreSchemaHandler,
    PrivateAttr,
    model_validator,
)
from pydantic_core import CoreSchema, core_schema
from typing_extensions import TypeIs

from crewai_files.core.constants import DEFAULT_MAX_FILE_SIZE_BYTES, MAGIC_BUFFER_SIZE


@runtime_checkable
class AsyncReadable(Protocol):
    """Protocol for async readable streams."""

    async def read(self, size: int = -1) -> bytes:
        """Read up to size bytes from the stream."""
        ...


class _AsyncReadableValidator:
    """Pydantic validator for AsyncReadable types."""

    @classmethod
    def __get_pydantic_core_schema__(
        cls, _source_type: Any, _handler: GetCoreSchemaHandler
    ) -> CoreSchema:
        return core_schema.no_info_plain_validator_function(
            cls._validate,
            serialization=core_schema.plain_serializer_function_ser_schema(
                lambda x: None, info_arg=False
            ),
        )

    @staticmethod
    def _validate(value: Any) -> AsyncReadable:
        if isinstance(value, AsyncReadable):
            return value
        raise ValueError("Expected an async readable object with async read() method")


ValidatedAsyncReadable = Annotated[AsyncReadable, _AsyncReadableValidator()]


def _fallback_content_type(filename: str | None) -> str:
    """Get content type from filename extension or return default."""
    if filename:
        mime_type, _ = mimetypes.guess_type(filename)
        if mime_type:
            return mime_type
    return "application/octet-stream"


def detect_content_type(data: bytes, filename: str | None = None) -> str:
    """Detect MIME type from file content.

    Uses python-magic if available for accurate content-based detection,
    falls back to the mimetypes module using the filename extension.

    Args:
        data: Raw bytes to analyze (only the first 2048 bytes are used).
        filename: Optional filename for extension-based fallback.

    Returns:
        The detected MIME type.
    """
    try:
        import magic

        result: str = magic.from_buffer(data[:MAGIC_BUFFER_SIZE], mime=True)
        return result
    except ImportError:
        return _fallback_content_type(filename)


def detect_content_type_from_path(path: Path, filename: str | None = None) -> str:
    """Detect MIME type from file path.

    Uses python-magic's from_file() for accurate detection without reading
    the entire file into memory.

    Args:
        path: Path to the file.
        filename: Optional filename for extension-based fallback.

    Returns:
        The detected MIME type.
    """
    try:
        import magic

        result: str = magic.from_file(str(path), mime=True)
        return result
    except ImportError:
        return _fallback_content_type(filename or path.name)


class _BinaryIOValidator:
    """Pydantic validator for BinaryIO types."""

    @classmethod
    def __get_pydantic_core_schema__(
        cls, _source_type: Any, _handler: GetCoreSchemaHandler
    ) -> CoreSchema:
        return core_schema.no_info_plain_validator_function(
            cls._validate,
            serialization=core_schema.plain_serializer_function_ser_schema(
                lambda x: None, info_arg=False
            ),
        )

    @staticmethod
    def _validate(value: Any) -> BinaryIO:
        if hasattr(value, "read") and hasattr(value, "seek"):
            return cast(BinaryIO, value)
        raise ValueError("Expected a binary file-like object with read() and seek()")


ValidatedBinaryIO = Annotated[BinaryIO, _BinaryIOValidator()]


class FilePath(BaseModel):
    """File loaded from a filesystem path."""

    path: Path = Field(description="Path to the file on the filesystem.")
    max_size_bytes: int = Field(
        default=DEFAULT_MAX_FILE_SIZE_BYTES,
        exclude=True,
        description="Maximum file size in bytes.",
    )
    _content: bytes | None = PrivateAttr(default=None)
    _content_type: str = PrivateAttr()

    @model_validator(mode="after")
    def _validate_file_exists(self) -> FilePath:
        """Validate that the file exists, is secure, and within size limits."""
        from crewai_files.processing.exceptions import FileTooLargeError

        path_str = str(self.path)
        if ".." in path_str:
            raise ValueError(f"Path traversal not allowed: {self.path}")

        if self.path.is_symlink():
            resolved = self.path.resolve()
            cwd = Path.cwd().resolve()
            if not str(resolved).startswith(str(cwd)):
                raise ValueError(f"Symlink escapes allowed directory: {self.path}")

        if not self.path.exists():
            raise ValueError(f"File not found: {self.path}")
        if not self.path.is_file():
            raise ValueError(f"Path is not a file: {self.path}")

        actual_size = self.path.stat().st_size
        if actual_size > self.max_size_bytes:
            raise FileTooLargeError(
                f"File exceeds max size ({actual_size} > {self.max_size_bytes})",
                file_name=str(self.path),
                actual_size=actual_size,
                max_size=self.max_size_bytes,
            )

        self._content_type = detect_content_type_from_path(self.path, self.path.name)
        return self

    @property
    def filename(self) -> str:
        """Get the filename from the path."""
        return self.path.name

    @property
    def content_type(self) -> str:
        """Get the content type."""
        return self._content_type

    def read(self) -> bytes:
        """Read the file content from disk."""
        if self._content is None:
            self._content = self.path.read_bytes()
        return self._content

    async def aread(self) -> bytes:
        """Async read the file content from disk."""
        if self._content is None:
            async with aiofiles.open(self.path, "rb") as f:
                self._content = await f.read()
        return self._content

    def read_chunks(self, chunk_size: int = 65536) -> Iterator[bytes]:
        """Stream file content in chunks without loading it entirely into memory.

        Args:
            chunk_size: Size of each chunk in bytes.

        Yields:
            Chunks of file content.
        """
        with open(self.path, "rb") as f:
            while chunk := f.read(chunk_size):
                yield chunk

    async def aread_chunks(self, chunk_size: int = 65536) -> AsyncIterator[bytes]:
        """Async streaming for non-blocking I/O.

        Args:
            chunk_size: Size of each chunk in bytes.

        Yields:
            Chunks of file content.
        """
        async with aiofiles.open(self.path, "rb") as f:
            while chunk := await f.read(chunk_size):
                yield chunk


class FileBytes(BaseModel):
    """File created from raw bytes content."""

    data: bytes = Field(description="Raw bytes content of the file.")
    filename: str | None = Field(default=None, description="Optional filename.")
    _content_type: str = PrivateAttr()

    @model_validator(mode="after")
    def _detect_content_type(self) -> FileBytes:
        """Detect and cache content type from data."""
        self._content_type = detect_content_type(self.data, self.filename)
        return self

    @property
    def content_type(self) -> str:
        """Get the content type."""
        return self._content_type

    def read(self) -> bytes:
        """Return the bytes content."""
        return self.data

    async def aread(self) -> bytes:
        """Async return the bytes content (immediate, already in memory)."""
        return self.data

    def read_chunks(self, chunk_size: int = 65536) -> Iterator[bytes]:
        """Stream bytes content in chunks.

        Args:
            chunk_size: Size of each chunk in bytes.

        Yields:
            Chunks of bytes content.
        """
        for i in range(0, len(self.data), chunk_size):
            yield self.data[i : i + chunk_size]

    async def aread_chunks(self, chunk_size: int = 65536) -> AsyncIterator[bytes]:
        """Async streaming (immediate yield since already in memory).

        Args:
            chunk_size: Size of each chunk in bytes.

        Yields:
            Chunks of bytes content.
        """
        for chunk in self.read_chunks(chunk_size):
            yield chunk


class FileStream(BaseModel):
    """File loaded from a file-like stream."""

    stream: ValidatedBinaryIO = Field(description="Binary file stream.")
    filename: str | None = Field(default=None, description="Optional filename.")
    _content: bytes | None = PrivateAttr(default=None)
    _content_type: str = PrivateAttr()

    @model_validator(mode="after")
    def _initialize(self) -> FileStream:
        """Extract filename and detect content type."""
        if self.filename is None:
            name = getattr(self.stream, "name", None)
            if name is not None:
                self.filename = Path(name).name

        position = self.stream.tell()
        self.stream.seek(0)
        header = self.stream.read(MAGIC_BUFFER_SIZE)
        self.stream.seek(position)
        self._content_type = detect_content_type(header, self.filename)
        return self

    @property
    def content_type(self) -> str:
        """Get the content type."""
        return self._content_type

    def read(self) -> bytes:
        """Read the stream content. Content is cached after first read."""
        if self._content is None:
            position = self.stream.tell()
            self.stream.seek(0)
            self._content = self.stream.read()
            self.stream.seek(position)
        return self._content

    def close(self) -> None:
        """Close the underlying stream."""
        self.stream.close()

    def __enter__(self) -> FileStream:
        """Enter context manager."""
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> None:
        """Exit context manager and close stream."""
        self.close()

    def read_chunks(self, chunk_size: int = 65536) -> Iterator[bytes]:
        """Stream from underlying stream in chunks.

        Args:
            chunk_size: Size of each chunk in bytes.

        Yields:
            Chunks of stream content.
        """
        position = self.stream.tell()
        self.stream.seek(0)
        try:
            while chunk := self.stream.read(chunk_size):
                yield chunk
        finally:
            self.stream.seek(position)


class AsyncFileStream(BaseModel):
    """File loaded from an async stream.

    Use for async file handles like aiofiles objects or aiohttp response bodies.
    This is an async-only type - use aread() instead of read().

    Attributes:
        stream: Async file-like object with async read() method.
        filename: Optional filename for the stream.
    """

    stream: ValidatedAsyncReadable = Field(
        description="Async file stream with async read() method."
    )
    filename: str | None = Field(default=None, description="Optional filename.")
    _content: bytes | None = PrivateAttr(default=None)
    _content_type: str | None = PrivateAttr(default=None)

    @property
    def content_type(self) -> str:
        """Get the content type from stream content (cached). Requires aread() first."""
        if self._content is None:
            raise RuntimeError("Call aread() first to load content")
        if self._content_type is None:
            self._content_type = detect_content_type(self._content, self.filename)
        return self._content_type

    async def aread(self) -> bytes:
        """Async read the stream content. Content is cached after first read."""
        if self._content is None:
            self._content = await self.stream.read()
        return self._content

    async def aclose(self) -> None:
        """Async close the underlying stream."""
        if hasattr(self.stream, "close"):
            result = self.stream.close()
            if hasattr(result, "__await__"):
                await result

    async def __aenter__(self) -> AsyncFileStream:
        """Async enter context manager."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> None:
        """Async exit context manager and close stream."""
        await self.aclose()

    async def aread_chunks(self, chunk_size: int = 65536) -> AsyncIterator[bytes]:
        """Async stream content in chunks.

        Args:
            chunk_size: Size of each chunk in bytes.

        Yields:
            Chunks of stream content.
        """
        while chunk := await self.stream.read(chunk_size):
            yield chunk


class FileUrl(BaseModel):
    """File referenced by URL.

    For providers that support URL references, the URL is passed directly.
    For providers that don't, content is fetched on demand.

    Attributes:
        url: URL where the file can be accessed.
        filename: Optional filename (extracted from URL if not provided).
    """

    url: str = Field(description="URL where the file can be accessed.")
    filename: str | None = Field(default=None, description="Optional filename.")
    _content_type: str | None = PrivateAttr(default=None)
    _content: bytes | None = PrivateAttr(default=None)

    @model_validator(mode="after")
    def _validate_url(self) -> FileUrl:
        """Validate URL format."""
        if not self.url.startswith(("http://", "https://")):
            raise ValueError(f"Invalid URL scheme: {self.url}")
        return self

    @property
    def content_type(self) -> str:
        """Get the content type, guessing from URL extension if not set."""
        if self._content_type is None:
            self._content_type = self._guess_content_type()
        return self._content_type

    def _guess_content_type(self) -> str:
        """Guess content type from URL extension."""
        from urllib.parse import urlparse

        parsed = urlparse(self.url)
        path = parsed.path
        guessed, _ = mimetypes.guess_type(path)
        return guessed or "application/octet-stream"

    def read(self) -> bytes:
        """Fetch content from URL (for providers that don't support URL references)."""
        if self._content is None:
            import httpx

            response = httpx.get(self.url, follow_redirects=True)
            response.raise_for_status()
            self._content = response.content
            if "content-type" in response.headers:
                self._content_type = response.headers["content-type"].split(";")[0]
        return self._content

    async def aread(self) -> bytes:
        """Async fetch content from URL."""
        if self._content is None:
            import httpx

            async with httpx.AsyncClient() as client:
                response = await client.get(self.url, follow_redirects=True)
                response.raise_for_status()
                self._content = response.content
            if "content-type" in response.headers:
                self._content_type = response.headers["content-type"].split(";")[0]
        return self._content


FileSource = FilePath | FileBytes | FileStream | AsyncFileStream | FileUrl


def is_file_source(v: object) -> TypeIs[FileSource]:
    """Type guard to narrow input to FileSource."""
    return isinstance(
        v, (FilePath, FileBytes, FileStream, AsyncFileStream, FileUrl)
    )


def _normalize_source(value: Any) -> FileSource:
    """Convert raw input to appropriate source type."""
    if isinstance(value, (FilePath, FileBytes, FileStream, AsyncFileStream, FileUrl)):
        return value
    if isinstance(value, str):
        if value.startswith(("http://", "https://")):
            return FileUrl(url=value)
        return FilePath(path=Path(value))
    if isinstance(value, Path):
        return FilePath(path=value)
    if isinstance(value, bytes):
        return FileBytes(data=value)
    if isinstance(value, AsyncReadable):
        return AsyncFileStream(stream=value)
    if hasattr(value, "read") and hasattr(value, "seek"):
        return FileStream(stream=value)
    raise ValueError(f"Cannot convert {type(value).__name__} to file source")


RawFileInput = str | Path | bytes
FileSourceInput = Annotated[
    RawFileInput | FileSource, BeforeValidator(_normalize_source)
]
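
# Usage sketch for the coercion rules above; io.BytesIO stands in for any
# seekable binary stream (content-type detection falls back to
# application/octet-stream when python-magic is absent).
import io

assert isinstance(_normalize_source("https://example.com/a.png"), FileUrl)
assert isinstance(_normalize_source(b"raw bytes"), FileBytes)
assert isinstance(_normalize_source(io.BytesIO(b"raw bytes")), FileStream)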
@@ -1,281 +0,0 @@
"""Content-type specific file classes."""

from __future__ import annotations

from abc import ABC
from io import IOBase
from pathlib import Path
from typing import Annotated, Any, BinaryIO, Literal, Self

from pydantic import BaseModel, Field, GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema

from crewai_files.core.sources import (
    AsyncFileStream,
    FileBytes,
    FilePath,
    FileSource,
    FileStream,
    FileUrl,
    is_file_source,
)


FileSourceInput = str | Path | bytes | IOBase | FileSource


class _FileSourceCoercer:
    """Pydantic-compatible type that coerces various inputs to FileSource."""

    @classmethod
    def _coerce(cls, v: Any) -> FileSource:
        """Convert raw input to appropriate FileSource type."""
        if isinstance(v, (FilePath, FileBytes, FileStream, FileUrl)):
            return v
        if isinstance(v, str):
            if v.startswith(("http://", "https://")):
                return FileUrl(url=v)
            return FilePath(path=Path(v))
        if isinstance(v, Path):
            return FilePath(path=v)
        if isinstance(v, bytes):
            return FileBytes(data=v)
        if isinstance(v, (IOBase, BinaryIO)):
            return FileStream(stream=v)
        raise ValueError(f"Cannot convert {type(v).__name__} to file source")

    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        _source_type: Any,
        _handler: GetCoreSchemaHandler,
    ) -> CoreSchema:
        """Generate Pydantic core schema for FileSource coercion."""
        return core_schema.no_info_plain_validator_function(
            cls._coerce,
            serialization=core_schema.plain_serializer_function_ser_schema(
                lambda v: v,
                info_arg=False,
                return_schema=core_schema.any_schema(),
            ),
        )


CoercedFileSource = Annotated[FileSourceInput, _FileSourceCoercer]

FileMode = Literal["strict", "auto", "warn", "chunk"]


ImageExtension = Literal[
    ".png",
    ".jpg",
    ".jpeg",
    ".gif",
    ".webp",
    ".bmp",
    ".tiff",
    ".tif",
    ".svg",
    ".heic",
    ".heif",
]
ImageMimeType = Literal[
    "image/png",
    "image/jpeg",
    "image/gif",
    "image/webp",
    "image/bmp",
    "image/tiff",
    "image/svg+xml",
    "image/heic",
    "image/heif",
]

PDFExtension = Literal[".pdf"]
PDFContentType = Literal["application/pdf"]

TextExtension = Literal[
    ".txt",
    ".md",
    ".rst",
    ".csv",
    ".json",
    ".xml",
    ".yaml",
    ".yml",
    ".html",
    ".htm",
    ".log",
    ".ini",
    ".cfg",
    ".conf",
]
TextContentType = Literal[
    "text/plain",
    "text/markdown",
    "text/csv",
    "application/json",
    "application/xml",
    "text/xml",
    "application/x-yaml",
    "text/yaml",
    "text/html",
]

AudioExtension = Literal[
    ".mp3", ".wav", ".ogg", ".flac", ".aac", ".m4a", ".wma", ".aiff", ".opus"
]
AudioMimeType = Literal[
    "audio/mp3",
    "audio/mpeg",
    "audio/wav",
    "audio/x-wav",
    "audio/ogg",
    "audio/flac",
    "audio/aac",
    "audio/m4a",
    "audio/mp4",
    "audio/x-ms-wma",
    "audio/aiff",
    "audio/opus",
]

VideoExtension = Literal[
    ".mp4", ".avi", ".mkv", ".mov", ".webm", ".flv", ".wmv", ".m4v", ".mpeg", ".mpg"
]
VideoMimeType = Literal[
    "video/mp4",
    "video/mpeg",
    "video/webm",
    "video/quicktime",
    "video/x-msvideo",
    "video/x-matroska",
    "video/x-flv",
    "video/x-ms-wmv",
]

class BaseFile(ABC, BaseModel):
    """Abstract base class for typed file wrappers.

    Provides common functionality for all file types, including:
    - File source management
    - Content reading
    - Dict unpacking support (`**` syntax)
    - Per-file handling mode

    Can be unpacked with ** syntax: `{**ImageFile(source="./chart.png")}`
    which unpacks to: `{"chart": <ImageFile instance>}` using the filename stem as key.

    Attributes:
        source: The underlying file source (path, bytes, or stream).
        mode: How to handle this file if it exceeds provider limits.
    """

    source: CoercedFileSource = Field(description="The underlying file source.")
    mode: FileMode = Field(
        default="auto",
        description="How to handle if file exceeds limits: strict, auto, warn, chunk.",
    )

    @property
    def _file_source(self) -> FileSource:
        """Get source with narrowed type (always FileSource after validation)."""
        if is_file_source(self.source):
            return self.source
        raise TypeError("source must be a FileSource after validation")

    @property
    def filename(self) -> str | None:
        """Get the filename from the source."""
        return self._file_source.filename

    @property
    def content_type(self) -> str:
        """Get the content type from the source."""
        return self._file_source.content_type

    def read(self) -> bytes:
        """Read the file content as bytes."""
        return self._file_source.read()  # type: ignore[union-attr]

    async def aread(self) -> bytes:
        """Async read the file content as bytes.

        Raises:
            TypeError: If the underlying source doesn't support async read.
        """
        source = self._file_source
        if isinstance(source, (FilePath, FileBytes, AsyncFileStream, FileUrl)):
            return await source.aread()
        raise TypeError(f"{type(source).__name__} does not support async read")

    def read_text(self, encoding: str = "utf-8") -> str:
        """Read the file content as string."""
        return self.read().decode(encoding)

    @property
    def _unpack_key(self) -> str:
        """Get the key to use when unpacking (filename stem)."""
        filename = self._file_source.filename
        if filename:
            return Path(filename).stem
        return "file"

    def keys(self) -> list[str]:
        """Return keys for dict unpacking."""
        return [self._unpack_key]

    def __getitem__(self, key: str) -> Self:
        """Return self for dict unpacking."""
        if key == self._unpack_key:
            return self
        raise KeyError(key)


class ImageFile(BaseFile):
    """File representing an image.

    Supports common image formats: PNG, JPEG, GIF, WebP, BMP, TIFF, SVG.
    """


class PDFFile(BaseFile):
    """File representing a PDF document."""


class TextFile(BaseFile):
    """File representing a text document.

    Supports common text formats: TXT, MD, RST, CSV, JSON, XML, YAML, HTML.
    """


class AudioFile(BaseFile):
    """File representing an audio file.

    Supports common audio formats: MP3, WAV, OGG, FLAC, AAC, M4A, WMA.
    """


class VideoFile(BaseFile):
    """File representing a video file.

    Supports common video formats: MP4, AVI, MKV, MOV, WebM, FLV, WMV.
    """


class File(BaseFile):
    """Generic file that auto-detects the appropriate type.

    Use this when you don't want to specify the exact file type.
    The content type is automatically detected from the file contents.

    Example:
        >>> pdf_file = File(source="./document.pdf")
        >>> image_file = File(source="./image.png")
        >>> bytes_file = File(source=b"file content")
    """


FileInput = AudioFile | File | ImageFile | PDFFile | TextFile | VideoFile
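
# Usage sketch of the ** unpacking documented on BaseFile; inline bytes keep
# the example off the filesystem, and "notes.txt" is a placeholder name.
f = File(source=FileBytes(data=b"hello", filename="notes.txt"))
files = {**f}                        # {"notes": <File instance>}
print(files["notes"].read_text())    # hello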
@@ -1,12 +0,0 @@
"""High-level formatting API for multimodal content."""

from crewai_files.formatting.api import (
    aformat_multimodal_content,
    format_multimodal_content,
)


__all__ = [
    "aformat_multimodal_content",
    "format_multimodal_content",
]
@@ -1,91 +0,0 @@
"""Anthropic content block formatter."""

from __future__ import annotations

import base64
from typing import Any

from crewai_files.core.resolved import (
    FileReference,
    InlineBase64,
    ResolvedFile,
    UrlReference,
)
from crewai_files.core.types import FileInput


class AnthropicFormatter:
    """Formats resolved files into Anthropic content blocks."""

    def format_block(
        self,
        file: FileInput,
        resolved: ResolvedFile,
    ) -> dict[str, Any] | None:
        """Format a resolved file into an Anthropic content block.

        Args:
            file: Original file input with metadata.
            resolved: Resolved file.

        Returns:
            Content block dict or None if not supported.
        """
        content_type = file.content_type
        block_type = self._get_block_type(content_type)
        if block_type is None:
            return None

        if isinstance(resolved, FileReference):
            return {
                "type": block_type,
                "source": {
                    "type": "file",
                    "file_id": resolved.file_id,
                },
            }

        if isinstance(resolved, UrlReference):
            return {
                "type": block_type,
                "source": {
                    "type": "url",
                    "url": resolved.url,
                },
            }

        if isinstance(resolved, InlineBase64):
            return {
                "type": block_type,
                "source": {
                    "type": "base64",
                    "media_type": resolved.content_type,
                    "data": resolved.data,
                },
            }

        data = base64.b64encode(file.read()).decode("ascii")
        return {
            "type": block_type,
            "source": {
                "type": "base64",
                "media_type": content_type,
                "data": data,
            },
        }

    @staticmethod
    def _get_block_type(content_type: str) -> str | None:
        """Get Anthropic block type for content type.

        Args:
            content_type: MIME type.

        Returns:
            Block type string or None if not supported.
        """
        if content_type.startswith("image/"):
            return "image"
        if content_type == "application/pdf":
            return "document"
        return None
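
# Quick sketch of the dispatch above: only images and PDFs map to a block type;
# everything else is skipped by returning None.
fmt = AnthropicFormatter()
print(fmt._get_block_type("image/png"))        # image
print(fmt._get_block_type("application/pdf"))  # document
print(fmt._get_block_type("audio/mpeg"))       # None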
@@ -1,277 +0,0 @@
"""High-level API for formatting multimodal content."""

from __future__ import annotations

import os
from typing import Any

from crewai_files.cache.upload_cache import get_upload_cache
from crewai_files.core.types import FileInput
from crewai_files.formatting.anthropic import AnthropicFormatter
from crewai_files.formatting.bedrock import BedrockFormatter
from crewai_files.formatting.gemini import GeminiFormatter
from crewai_files.formatting.openai import OpenAIFormatter
from crewai_files.processing.constraints import get_constraints_for_provider
from crewai_files.processing.processor import FileProcessor
from crewai_files.resolution.resolver import FileResolver, FileResolverConfig
from crewai_files.uploaders.factory import ProviderType


def _normalize_provider(provider: str | None) -> ProviderType:
    """Normalize provider string to ProviderType.

    Args:
        provider: Raw provider string.

    Returns:
        Normalized provider type.

    Raises:
        ValueError: If provider is None or empty.
    """
    if not provider:
        raise ValueError("provider is required")

    provider_lower = provider.lower()

    if "gemini" in provider_lower:
        return "gemini"
    if "google" in provider_lower:
        return "google"
    if "anthropic" in provider_lower:
        return "anthropic"
    if "claude" in provider_lower:
        return "claude"
    if "bedrock" in provider_lower:
        return "bedrock"
    if "aws" in provider_lower:
        return "aws"
    if "azure" in provider_lower:
        return "azure"
    if "gpt" in provider_lower:
        return "gpt"

    return "openai"


def format_multimodal_content(
    files: dict[str, FileInput],
    provider: str | None = None,
) -> list[dict[str, Any]]:
    """Format files as provider-specific multimodal content blocks.

    This is the main high-level API for converting files to content blocks
    suitable for sending to LLM providers. It handles:
    - File processing according to provider constraints
    - Resolution (upload vs inline) based on provider capabilities
    - Formatting into provider-specific content block structures

    Args:
        files: Dictionary mapping file names to FileInput objects.
        provider: Provider name (e.g., "openai", "anthropic", "bedrock", "gemini").

    Returns:
        List of content blocks in the provider's expected format.

    Example:
        >>> from crewai_files import format_multimodal_content, ImageFile
        >>> files = {"photo": ImageFile(source="image.jpg")}
        >>> blocks = format_multimodal_content(files, "openai")
    """
    if not files:
        return []

    provider_type = _normalize_provider(provider)

    processor = FileProcessor(constraints=provider_type)
    processed_files = processor.process_files(files)

    if not processed_files:
        return []

    constraints = get_constraints_for_provider(provider_type)
    supported_types = _get_supported_types(constraints)
    supported_files = _filter_supported_files(processed_files, supported_types)

    if not supported_files:
        return []

    config = _get_resolver_config(provider_type)
    upload_cache = get_upload_cache()
    resolver = FileResolver(config=config, upload_cache=upload_cache)

    formatter = _get_formatter(provider_type)
    content_blocks: list[dict[str, Any]] = []

    for name, file_input in supported_files.items():
        resolved = resolver.resolve(file_input, provider_type)
        block = _format_block(formatter, file_input, resolved, name)
        if block is not None:
            content_blocks.append(block)

    return content_blocks


async def aformat_multimodal_content(
    files: dict[str, FileInput],
    provider: str | None = None,
) -> list[dict[str, Any]]:
    """Async format files as provider-specific multimodal content blocks.

    Async version of format_multimodal_content with parallel file resolution.

    Args:
        files: Dictionary mapping file names to FileInput objects.
        provider: Provider name (e.g., "openai", "anthropic", "bedrock", "gemini").

    Returns:
        List of content blocks in the provider's expected format.
    """
    if not files:
        return []

    provider_type = _normalize_provider(provider)

    processor = FileProcessor(constraints=provider_type)
    processed_files = await processor.aprocess_files(files)

    if not processed_files:
        return []

    constraints = get_constraints_for_provider(provider_type)
    supported_types = _get_supported_types(constraints)
    supported_files = _filter_supported_files(processed_files, supported_types)

    if not supported_files:
        return []

    config = _get_resolver_config(provider_type)
    upload_cache = get_upload_cache()
    resolver = FileResolver(config=config, upload_cache=upload_cache)

    resolved_files = await resolver.aresolve_files(supported_files, provider_type)

    formatter = _get_formatter(provider_type)
    content_blocks: list[dict[str, Any]] = []

    for name, resolved in resolved_files.items():
        file_input = supported_files[name]
        block = _format_block(formatter, file_input, resolved, name)
        if block is not None:
            content_blocks.append(block)

    return content_blocks


def _get_supported_types(
    constraints: Any | None,
) -> list[str]:
    """Get list of supported MIME type prefixes from constraints.

    Args:
        constraints: Provider constraints.

    Returns:
        List of MIME type prefixes (e.g., ["image/", "application/pdf"]).
    """
    if constraints is None:
        return []

    supported: list[str] = []
    if constraints.image is not None:
        supported.append("image/")
    if constraints.pdf is not None:
        supported.append("application/pdf")
    if constraints.audio is not None:
        supported.append("audio/")
    if constraints.video is not None:
        supported.append("video/")
    return supported


def _filter_supported_files(
    files: dict[str, FileInput],
    supported_types: list[str],
) -> dict[str, FileInput]:
    """Filter files to those with supported content types.

    Args:
        files: All files.
        supported_types: MIME type prefixes to allow.

    Returns:
        Filtered dictionary of supported files.
    """
    return {
        name: f
        for name, f in files.items()
        if any(f.content_type.startswith(t) for t in supported_types)
    }


def _get_resolver_config(provider_lower: str) -> FileResolverConfig:
    """Get resolver config for provider.

    Args:
        provider_lower: Lowercase provider name.

    Returns:
        Configured FileResolverConfig.
    """
    if "bedrock" in provider_lower:
        s3_bucket = os.environ.get("CREWAI_BEDROCK_S3_BUCKET")
        prefer_upload = bool(s3_bucket)
        return FileResolverConfig(
            prefer_upload=prefer_upload, use_bytes_for_bedrock=True
        )

    return FileResolverConfig(prefer_upload=False)


def _get_formatter(
    provider_lower: str,
) -> OpenAIFormatter | AnthropicFormatter | BedrockFormatter | GeminiFormatter:
    """Get formatter for provider.

    Args:
        provider_lower: Lowercase provider name.

    Returns:
        Provider-specific formatter instance.
    """
    if "anthropic" in provider_lower or "claude" in provider_lower:
        return AnthropicFormatter()

    if "bedrock" in provider_lower or "aws" in provider_lower:
        s3_bucket_owner = os.environ.get("CREWAI_BEDROCK_S3_BUCKET_OWNER")
        return BedrockFormatter(s3_bucket_owner=s3_bucket_owner)

    if "gemini" in provider_lower or "google" in provider_lower:
        return GeminiFormatter()

    return OpenAIFormatter()


def _format_block(
    formatter: OpenAIFormatter
    | AnthropicFormatter
    | BedrockFormatter
    | GeminiFormatter,
    file_input: FileInput,
    resolved: Any,
    name: str,
) -> dict[str, Any] | None:
    """Format a single file block using the appropriate formatter.

    Args:
        formatter: Provider formatter.
        file_input: Original file input.
        resolved: Resolved file.
        name: File name.

    Returns:
        Content block dict or None.
    """
    if isinstance(formatter, BedrockFormatter):
        return formatter.format_block(file_input, resolved, name=name)
    return formatter.format_block(file_input, resolved)
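
# Usage sketch of the high-level entry point; "photo.png" is a placeholder path
# that would need to exist on disk, and the commented result shows the OpenAI
# inline-base64 shape produced above.
from crewai_files.core.types import ImageFile

blocks = format_multimodal_content(
    {"photo": ImageFile(source="photo.png")},
    provider="openai",
)
# blocks -> [{"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}]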
@@ -1,28 +0,0 @@
"""Base formatter protocol for provider-specific content blocks."""

from __future__ import annotations

from typing import Any, Protocol

from crewai_files.core.resolved import ResolvedFile
from crewai_files.core.types import FileInput


class ContentFormatter(Protocol):
    """Protocol for formatting resolved files into provider content blocks."""

    def format_block(
        self,
        file: FileInput,
        resolved: ResolvedFile,
    ) -> dict[str, Any] | None:
        """Format a resolved file into a provider-specific content block.

        Args:
            file: Original file input with metadata.
            resolved: Resolved file (FileReference, InlineBase64, etc.).

        Returns:
            Content block dict or None if file type not supported.
        """
        ...
@@ -1,188 +0,0 @@
"""Bedrock content block formatter."""

from __future__ import annotations

from typing import Any

from crewai_files.core.resolved import (
    FileReference,
    InlineBytes,
    ResolvedFile,
)
from crewai_files.core.types import FileInput


_DOCUMENT_FORMATS: dict[str, str] = {
    "application/pdf": "pdf",
    "text/csv": "csv",
    "text/plain": "txt",
    "text/markdown": "md",
    "text/html": "html",
    "application/msword": "doc",
    "application/vnd.openxmlformats-officedocument.wordprocessingml.document": "docx",
    "application/vnd.ms-excel": "xls",
    "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": "xlsx",
}

_VIDEO_FORMATS: dict[str, str] = {
    "video/mp4": "mp4",
    "video/quicktime": "mov",
    "video/x-matroska": "mkv",
    "video/webm": "webm",
    "video/x-flv": "flv",
    "video/mpeg": "mpeg",
    "video/3gpp": "three_gp",
}


class BedrockFormatter:
    """Formats resolved files into Bedrock Converse API content blocks."""

    def __init__(self, s3_bucket_owner: str | None = None) -> None:
        """Initialize formatter.

        Args:
            s3_bucket_owner: Optional S3 bucket owner for file references.
        """
        self.s3_bucket_owner = s3_bucket_owner

    def format_block(
        self,
        file: FileInput,
        resolved: ResolvedFile,
        name: str | None = None,
    ) -> dict[str, Any] | None:
        """Format a resolved file into a Bedrock content block.

        Args:
            file: Original file input with metadata.
            resolved: Resolved file.
            name: File name (required for document blocks).

        Returns:
            Content block dict or None if not supported.
        """
        content_type = file.content_type

        if isinstance(resolved, FileReference) and resolved.file_uri:
            return self._format_s3_block(content_type, resolved.file_uri, name)

        if isinstance(resolved, InlineBytes):
            file_bytes = resolved.data
        else:
            file_bytes = file.read()

        return self._format_bytes_block(content_type, file_bytes, name)

    def _format_s3_block(
        self,
        content_type: str,
        file_uri: str,
        name: str | None,
    ) -> dict[str, Any] | None:
        """Format block with S3 location source.

        Args:
            content_type: MIME type.
            file_uri: S3 URI.
            name: File name for documents.

        Returns:
            Content block dict or None.
        """
        s3_location: dict[str, Any] = {"uri": file_uri}
        if self.s3_bucket_owner:
            s3_location["bucketOwner"] = self.s3_bucket_owner

        if content_type.startswith("image/"):
            return {
                "image": {
                    "format": self._get_image_format(content_type),
                    "source": {"s3Location": s3_location},
                }
            }

        if content_type.startswith("video/"):
            video_format = _VIDEO_FORMATS.get(content_type)
            if video_format:
                return {
                    "video": {
                        "format": video_format,
                        "source": {"s3Location": s3_location},
                    }
                }
            return None

        doc_format = _DOCUMENT_FORMATS.get(content_type)
        if doc_format:
            return {
                "document": {
                    "name": name or "document",
                    "format": doc_format,
                    "source": {"s3Location": s3_location},
                }
            }

        return None

    def _format_bytes_block(
        self,
        content_type: str,
        file_bytes: bytes,
        name: str | None,
    ) -> dict[str, Any] | None:
        """Format block with inline bytes source.

        Args:
            content_type: MIME type.
            file_bytes: Raw file bytes.
            name: File name for documents.

        Returns:
            Content block dict or None.
        """
        if content_type.startswith("image/"):
            return {
                "image": {
                    "format": self._get_image_format(content_type),
                    "source": {"bytes": file_bytes},
                }
            }

        if content_type.startswith("video/"):
            video_format = _VIDEO_FORMATS.get(content_type)
            if video_format:
                return {
                    "video": {
                        "format": video_format,
                        "source": {"bytes": file_bytes},
                    }
                }
            return None

        doc_format = _DOCUMENT_FORMATS.get(content_type)
        if doc_format:
            return {
                "document": {
                    "name": name or "document",
                    "format": doc_format,
                    "source": {"bytes": file_bytes},
                }
            }

        return None

    @staticmethod
    def _get_image_format(content_type: str) -> str:
        """Get Bedrock image format from content type.

        Args:
            content_type: MIME type.

        Returns:
            Format string for Bedrock.
        """
        media_type = content_type.split("/")[-1]
        if media_type == "jpg":
            return "jpeg"
        return media_type
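
# Sketch of the inline-bytes block shape built above; the bytes are a stand-in,
# not a real PNG.
formatter = BedrockFormatter()
block = formatter._format_bytes_block("image/png", b"\x89PNG", name=None)
# block -> {"image": {"format": "png", "source": {"bytes": b"\x89PNG"}}}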
@@ -1,66 +0,0 @@
"""Gemini content block formatter."""

from __future__ import annotations

import base64
from typing import Any

from crewai_files.core.resolved import (
    FileReference,
    InlineBase64,
    ResolvedFile,
    UrlReference,
)
from crewai_files.core.types import FileInput


class GeminiFormatter:
    """Formats resolved files into Gemini content blocks."""

    def format_block(
        self,
        file: FileInput,
        resolved: ResolvedFile,
    ) -> dict[str, Any] | None:
        """Format a resolved file into a Gemini content block.

        Args:
            file: Original file input with metadata.
            resolved: Resolved file.

        Returns:
            Content block dict or None if not supported.
        """
        content_type = file.content_type

        if isinstance(resolved, FileReference) and resolved.file_uri:
            return {
                "fileData": {
                    "mimeType": resolved.content_type,
                    "fileUri": resolved.file_uri,
                }
            }

        if isinstance(resolved, UrlReference):
            return {
                "fileData": {
                    "mimeType": content_type,
                    "fileUri": resolved.url,
                }
            }

        if isinstance(resolved, InlineBase64):
            return {
                "inlineData": {
                    "mimeType": resolved.content_type,
                    "data": resolved.data,
                }
            }

        data = base64.b64encode(file.read()).decode("ascii")
        return {
            "inlineData": {
                "mimeType": content_type,
                "data": data,
            }
        }
@@ -1,60 +0,0 @@
"""OpenAI content block formatter."""

from __future__ import annotations

import base64
from typing import Any

from crewai_files.core.resolved import (
    FileReference,
    InlineBase64,
    ResolvedFile,
    UrlReference,
)
from crewai_files.core.types import FileInput


class OpenAIFormatter:
    """Formats resolved files into OpenAI content blocks."""

    def format_block(
        self,
        file: FileInput,
        resolved: ResolvedFile,
    ) -> dict[str, Any] | None:
        """Format a resolved file into an OpenAI content block.

        Args:
            file: Original file input with metadata.
            resolved: Resolved file.

        Returns:
            Content block dict or None if not supported.
        """
        content_type = file.content_type

        if isinstance(resolved, FileReference):
            return {
                "type": "file",
                "file": {"file_id": resolved.file_id},
            }

        if isinstance(resolved, UrlReference):
            return {
                "type": "image_url",
                "image_url": {"url": resolved.url},
            }

        if isinstance(resolved, InlineBase64):
            return {
                "type": "image_url",
                "image_url": {
                    "url": f"data:{resolved.content_type};base64,{resolved.data}"
                },
            }

        data = base64.b64encode(file.read()).decode("ascii")
        return {
            "type": "image_url",
            "image_url": {"url": f"data:{content_type};base64,{data}"},
        }
@@ -1,62 +0,0 @@
"""File processing module for multimodal content handling.

This module provides validation, transformation, and processing utilities
for files used in multimodal LLM interactions.
"""

from crewai_files.processing.constraints import (
    ANTHROPIC_CONSTRAINTS,
    BEDROCK_CONSTRAINTS,
    GEMINI_CONSTRAINTS,
    OPENAI_CONSTRAINTS,
    AudioConstraints,
    ImageConstraints,
    PDFConstraints,
    ProviderConstraints,
    VideoConstraints,
    get_constraints_for_provider,
)
from crewai_files.processing.enums import FileHandling
from crewai_files.processing.exceptions import (
    FileProcessingError,
    FileTooLargeError,
    FileValidationError,
    ProcessingDependencyError,
    UnsupportedFileTypeError,
)
from crewai_files.processing.processor import FileProcessor
from crewai_files.processing.validators import (
    validate_audio,
    validate_file,
    validate_image,
    validate_pdf,
    validate_text,
    validate_video,
)


__all__ = [
    "ANTHROPIC_CONSTRAINTS",
    "BEDROCK_CONSTRAINTS",
    "GEMINI_CONSTRAINTS",
    "OPENAI_CONSTRAINTS",
    "AudioConstraints",
    "FileHandling",
    "FileProcessingError",
    "FileProcessor",
    "FileTooLargeError",
    "FileValidationError",
    "ImageConstraints",
    "PDFConstraints",
    "ProcessingDependencyError",
    "ProviderConstraints",
    "UnsupportedFileTypeError",
    "VideoConstraints",
    "get_constraints_for_provider",
    "validate_audio",
    "validate_file",
    "validate_image",
    "validate_pdf",
    "validate_text",
    "validate_video",
]
@@ -1,285 +0,0 @@
"""Provider-specific file constraints for multimodal content."""

from dataclasses import dataclass
from functools import lru_cache
from typing import Literal

from crewai_files.core.types import (
    AudioMimeType,
    ImageMimeType,
    VideoMimeType,
)


ProviderName = Literal[
    "anthropic",
    "openai",
    "gemini",
    "bedrock",
    "azure",
]

DEFAULT_IMAGE_FORMATS: tuple[ImageMimeType, ...] = (
    "image/png",
    "image/jpeg",
    "image/gif",
    "image/webp",
)

GEMINI_IMAGE_FORMATS: tuple[ImageMimeType, ...] = (
    "image/png",
    "image/jpeg",
    "image/gif",
    "image/webp",
    "image/heic",
    "image/heif",
)

DEFAULT_AUDIO_FORMATS: tuple[AudioMimeType, ...] = (
    "audio/mp3",
    "audio/mpeg",
    "audio/wav",
    "audio/ogg",
    "audio/flac",
    "audio/aac",
    "audio/m4a",
)

GEMINI_AUDIO_FORMATS: tuple[AudioMimeType, ...] = (
    "audio/mp3",
    "audio/mpeg",
    "audio/wav",
    "audio/ogg",
    "audio/flac",
    "audio/aac",
    "audio/m4a",
    "audio/opus",
)

DEFAULT_VIDEO_FORMATS: tuple[VideoMimeType, ...] = (
    "video/mp4",
    "video/mpeg",
    "video/webm",
    "video/quicktime",
)

GEMINI_VIDEO_FORMATS: tuple[VideoMimeType, ...] = (
    "video/mp4",
    "video/mpeg",
    "video/webm",
    "video/quicktime",
    "video/x-msvideo",
    "video/x-flv",
)


@dataclass(frozen=True)
class ImageConstraints:
    """Constraints for image files.

    Attributes:
        max_size_bytes: Maximum file size in bytes.
        max_width: Maximum image width in pixels.
        max_height: Maximum image height in pixels.
        max_images_per_request: Maximum number of images per request.
        supported_formats: Supported image MIME types.
    """

    max_size_bytes: int
    max_width: int | None = None
    max_height: int | None = None
    max_images_per_request: int | None = None
    supported_formats: tuple[ImageMimeType, ...] = DEFAULT_IMAGE_FORMATS


@dataclass(frozen=True)
class PDFConstraints:
    """Constraints for PDF files.

    Attributes:
        max_size_bytes: Maximum file size in bytes.
        max_pages: Maximum number of pages.
    """

    max_size_bytes: int
    max_pages: int | None = None


@dataclass(frozen=True)
class AudioConstraints:
    """Constraints for audio files.

    Attributes:
        max_size_bytes: Maximum file size in bytes.
        max_duration_seconds: Maximum audio duration in seconds.
        supported_formats: Supported audio MIME types.
    """

    max_size_bytes: int
    max_duration_seconds: int | None = None
    supported_formats: tuple[AudioMimeType, ...] = DEFAULT_AUDIO_FORMATS


@dataclass(frozen=True)
class VideoConstraints:
    """Constraints for video files.

    Attributes:
        max_size_bytes: Maximum file size in bytes.
        max_duration_seconds: Maximum video duration in seconds.
        supported_formats: Supported video MIME types.
    """

    max_size_bytes: int
    max_duration_seconds: int | None = None
    supported_formats: tuple[VideoMimeType, ...] = DEFAULT_VIDEO_FORMATS


@dataclass(frozen=True)
class ProviderConstraints:
    """Complete set of constraints for a provider.

    Attributes:
        name: Provider name identifier.
        image: Image file constraints.
        pdf: PDF file constraints.
        audio: Audio file constraints.
        video: Video file constraints.
        general_max_size_bytes: Maximum size for any file type.
        supports_file_upload: Whether the provider supports file upload APIs.
        file_upload_threshold_bytes: Size threshold above which to use file upload.
        supports_url_references: Whether the provider supports URL-based file references.
    """

    name: ProviderName
    image: ImageConstraints | None = None
    pdf: PDFConstraints | None = None
    audio: AudioConstraints | None = None
    video: VideoConstraints | None = None
    general_max_size_bytes: int | None = None
    supports_file_upload: bool = False
    file_upload_threshold_bytes: int | None = None
    supports_url_references: bool = False


ANTHROPIC_CONSTRAINTS = ProviderConstraints(
    name="anthropic",
    image=ImageConstraints(
        max_size_bytes=5_242_880,  # 5 MB per image
        max_width=8000,
        max_height=8000,
        max_images_per_request=100,
    ),
    pdf=PDFConstraints(
        max_size_bytes=33_554_432,  # 32 MB request size limit
        max_pages=100,
    ),
    supports_file_upload=True,
    file_upload_threshold_bytes=5_242_880,
    supports_url_references=True,
)

OPENAI_CONSTRAINTS = ProviderConstraints(
    name="openai",
    image=ImageConstraints(
        max_size_bytes=20_971_520,
        max_images_per_request=10,
    ),
    audio=AudioConstraints(
        max_size_bytes=26_214_400,  # 25 MB - Whisper limit
        max_duration_seconds=1500,  # 25 minutes, from the transcription API limit
    ),
    supports_file_upload=True,
    file_upload_threshold_bytes=5_242_880,
    supports_url_references=True,
)

GEMINI_CONSTRAINTS = ProviderConstraints(
    name="gemini",
    image=ImageConstraints(
        max_size_bytes=104_857_600,
        supported_formats=GEMINI_IMAGE_FORMATS,
    ),
    pdf=PDFConstraints(
        max_size_bytes=52_428_800,
    ),
    audio=AudioConstraints(
        max_size_bytes=104_857_600,
        max_duration_seconds=34200,  # 9.5 hours
        supported_formats=GEMINI_AUDIO_FORMATS,
    ),
    video=VideoConstraints(
        max_size_bytes=2_147_483_648,
        max_duration_seconds=3600,  # 1 hour at default resolution
        supported_formats=GEMINI_VIDEO_FORMATS,
    ),
    supports_file_upload=True,
    file_upload_threshold_bytes=20_971_520,
    supports_url_references=True,
)

BEDROCK_CONSTRAINTS = ProviderConstraints(
    name="bedrock",
    image=ImageConstraints(
        max_size_bytes=4_608_000,
        max_width=8000,
        max_height=8000,
    ),
    pdf=PDFConstraints(
        max_size_bytes=3_840_000,
        max_pages=100,
    ),
)

AZURE_CONSTRAINTS = ProviderConstraints(
    name="azure",
    image=ImageConstraints(
        max_size_bytes=20_971_520,
        max_images_per_request=10,
    ),
    audio=AudioConstraints(
        max_size_bytes=26_214_400,  # 25 MB - same as OpenAI
        max_duration_seconds=1500,  # 25 minutes - same as OpenAI
    ),
    supports_url_references=True,
)


_PROVIDER_CONSTRAINTS_MAP: dict[str, ProviderConstraints] = {
    "anthropic": ANTHROPIC_CONSTRAINTS,
    "openai": OPENAI_CONSTRAINTS,
    "gemini": GEMINI_CONSTRAINTS,
    "bedrock": BEDROCK_CONSTRAINTS,
    "azure": AZURE_CONSTRAINTS,
    "claude": ANTHROPIC_CONSTRAINTS,
    "gpt": OPENAI_CONSTRAINTS,
    "google": GEMINI_CONSTRAINTS,
    "aws": BEDROCK_CONSTRAINTS,
}


@lru_cache(maxsize=32)
def get_constraints_for_provider(
    provider: str | ProviderConstraints,
) -> ProviderConstraints | None:
    """Get constraints for a provider by name or return if already ProviderConstraints.

    Args:
        provider: Provider name string or ProviderConstraints instance.

    Returns:
        ProviderConstraints for the provider, or None if not found.
    """
    if isinstance(provider, ProviderConstraints):
        return provider

    provider_lower = provider.lower()

    if provider_lower in _PROVIDER_CONSTRAINTS_MAP:
        return _PROVIDER_CONSTRAINTS_MAP[provider_lower]

    for key, constraints in _PROVIDER_CONSTRAINTS_MAP.items():
        if key in provider_lower:
            return constraints

    return None
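A short usage sketch for the lookup above; "gpt-4o" is a hypothetical model name chosen to fall through to the substring match on the "gpt" alias:

from crewai_files.processing.constraints import get_constraints_for_provider

# "gpt-4o" is not an exact map key, but the fallback loop matches the "gpt" alias.
constraints = get_constraints_for_provider("gpt-4o")
if constraints is not None and constraints.image is not None:
    print(constraints.name)                  # openai
    print(constraints.image.max_size_bytes)  # 20971520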
@@ -1,19 +0,0 @@
"""Enums for file processing configuration."""

from enum import Enum


class FileHandling(Enum):
    """Defines how files exceeding provider limits should be handled.

    Attributes:
        STRICT: Fail with an error if file exceeds limits.
        AUTO: Automatically resize, compress, or optimize to fit limits.
        WARN: Log a warning but attempt to process anyway.
        CHUNK: Split large files into smaller pieces.
    """

    STRICT = "strict"
    AUTO = "auto"
    WARN = "warn"
    CHUNK = "chunk"
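Because each member carries a string value, a mode can be supplied either as the enum member or as its raw string; a minimal sketch:

from crewai_files.processing.enums import FileHandling

assert FileHandling("warn") is FileHandling.WARN  # string value coerces to the member
assert FileHandling.STRICT.value == "strict"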
@@ -1,145 +0,0 @@
"""Exceptions for file processing operations."""


class FileProcessingError(Exception):
    """Base exception for file processing errors."""

    def __init__(self, message: str, file_name: str | None = None) -> None:
        """Initialize the exception.

        Args:
            message: Error message describing the issue.
            file_name: Optional name of the file that caused the error.
        """
        self.file_name = file_name
        super().__init__(message)


class FileValidationError(FileProcessingError):
    """Raised when file validation fails."""


class FileTooLargeError(FileValidationError):
    """Raised when a file exceeds the maximum allowed size."""

    def __init__(
        self,
        message: str,
        file_name: str | None = None,
        actual_size: int | None = None,
        max_size: int | None = None,
    ) -> None:
        """Initialize the exception.

        Args:
            message: Error message describing the issue.
            file_name: Optional name of the file that caused the error.
            actual_size: The actual size of the file in bytes.
            max_size: The maximum allowed size in bytes.
        """
        self.actual_size = actual_size
        self.max_size = max_size
        super().__init__(message, file_name)


class UnsupportedFileTypeError(FileValidationError):
    """Raised when a file type is not supported by the provider."""

    def __init__(
        self,
        message: str,
        file_name: str | None = None,
        content_type: str | None = None,
    ) -> None:
        """Initialize the exception.

        Args:
            message: Error message describing the issue.
            file_name: Optional name of the file that caused the error.
            content_type: The content type that is not supported.
        """
        self.content_type = content_type
        super().__init__(message, file_name)


class ProcessingDependencyError(FileProcessingError):
    """Raised when a required processing dependency is not installed."""

    def __init__(
        self,
        message: str,
        dependency: str,
        install_command: str | None = None,
    ) -> None:
        """Initialize the exception.

        Args:
            message: Error message describing the issue.
            dependency: Name of the missing dependency.
            install_command: Optional command to install the dependency.
        """
        self.dependency = dependency
        self.install_command = install_command
        super().__init__(message)


class TransientFileError(FileProcessingError):
    """Transient error that may succeed on retry (network, timeout)."""


class PermanentFileError(FileProcessingError):
    """Permanent error that will not succeed on retry (auth, format)."""


class UploadError(FileProcessingError):
    """Base exception for upload errors."""


class TransientUploadError(UploadError, TransientFileError):
    """Upload failed but may succeed on retry (network issues, rate limits)."""


class PermanentUploadError(UploadError, PermanentFileError):
    """Upload failed permanently (auth failure, invalid file, unsupported type)."""


def classify_upload_error(e: Exception, filename: str | None = None) -> Exception:
    """Classify an exception as transient or permanent upload error.

    Analyzes the exception type name and status code to determine if
    the error is likely transient (retryable) or permanent.

    Args:
        e: The exception to classify.
        filename: Optional filename for error context.

    Returns:
        A TransientUploadError or PermanentUploadError wrapping the original.
    """
    error_type = type(e).__name__

    if "RateLimit" in error_type or "APIConnection" in error_type:
        return TransientUploadError(f"Transient upload error: {e}", file_name=filename)
    if "Authentication" in error_type or "Permission" in error_type:
        return PermanentUploadError(
            f"Authentication/permission error: {e}", file_name=filename
        )
    if "BadRequest" in error_type or "InvalidRequest" in error_type:
        return PermanentUploadError(f"Invalid request: {e}", file_name=filename)

    status_code = getattr(e, "status_code", None)
    if status_code is not None:
        if status_code >= 500 or status_code == 429:
            return TransientUploadError(
                f"Server error ({status_code}): {e}", file_name=filename
            )
        if status_code in (401, 403):
            return PermanentUploadError(
                f"Auth error ({status_code}): {e}", file_name=filename
            )
        if status_code == 400:
            return PermanentUploadError(
                f"Bad request ({status_code}): {e}", file_name=filename
            )

    return TransientUploadError(f"Upload failed: {e}", file_name=filename)
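To illustrate the heuristic above, a small sketch; FakeRateLimitError is an exception class invented here purely to trigger the type-name check:

from crewai_files.processing.exceptions import (
    TransientUploadError,
    classify_upload_error,
)


class FakeRateLimitError(Exception):  # hypothetical: the name contains "RateLimit"
    status_code = 429


err = classify_upload_error(FakeRateLimitError("slow down"), filename="report.pdf")
assert isinstance(err, TransientUploadError)  # retryable by name (and by status 429)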
@@ -1,346 +0,0 @@
"""FileProcessor for validating and transforming files based on provider constraints."""

import asyncio
from collections.abc import Sequence
import logging

from crewai_files.core.types import (
    AudioFile,
    File,
    FileInput,
    ImageFile,
    PDFFile,
    TextFile,
    VideoFile,
)
from crewai_files.processing.constraints import (
    ProviderConstraints,
    get_constraints_for_provider,
)
from crewai_files.processing.enums import FileHandling
from crewai_files.processing.exceptions import (
    FileProcessingError,
    FileTooLargeError,
    FileValidationError,
    UnsupportedFileTypeError,
)
from crewai_files.processing.transformers import (
    chunk_pdf,
    chunk_text,
    get_image_dimensions,
    get_pdf_page_count,
    optimize_image,
    resize_image,
)
from crewai_files.processing.validators import validate_file


logger = logging.getLogger(__name__)


class FileProcessor:
    """Processes files according to provider constraints and per-file handling mode.

    Validates files against provider-specific limits and optionally transforms
    them (resize, compress, chunk) to meet those limits. Each file specifies
    its own handling mode via `file.mode`.

    Attributes:
        constraints: Provider constraints for validation.
    """

    def __init__(
        self,
        constraints: ProviderConstraints | str | None = None,
    ) -> None:
        """Initialize the FileProcessor.

        Args:
            constraints: Provider constraints or provider name string.
                If None, validation is skipped.
        """
        if isinstance(constraints, str):
            resolved = get_constraints_for_provider(constraints)
            if resolved is None:
                logger.warning(
                    f"Unknown provider '{constraints}' - validation disabled"
                )
            self.constraints = resolved
        else:
            self.constraints = constraints

    def validate(self, file: FileInput) -> Sequence[str]:
        """Validate a file against provider constraints.

        Args:
            file: The file to validate.

        Returns:
            List of validation error messages (empty if valid).

        Raises:
            FileValidationError: If file.mode is STRICT and validation fails.
        """
        if self.constraints is None:
            return []

        mode = self._get_mode(file)
        raise_on_error = mode == FileHandling.STRICT
        return validate_file(file, self.constraints, raise_on_error=raise_on_error)

    @staticmethod
    def _get_mode(file: FileInput) -> FileHandling:
        """Get the handling mode for a file.

        Args:
            file: The file to get the mode for.

        Returns:
            The file's handling mode, defaulting to AUTO.
        """
        mode = getattr(file, "mode", None)
        if mode is None:
            return FileHandling.AUTO
        if isinstance(mode, str):
            return FileHandling(mode)
        if isinstance(mode, FileHandling):
            return mode
        return FileHandling.AUTO

    def process(self, file: FileInput) -> FileInput | Sequence[FileInput]:
        """Process a single file according to constraints and its handling mode.

        Args:
            file: The file to process.

        Returns:
            The processed file (possibly transformed) or a sequence of files
            if the file was chunked.

        Raises:
            FileProcessingError: If file.mode is STRICT and processing fails.
        """
        if self.constraints is None:
            return file

        mode = self._get_mode(file)

        try:
            errors = self.validate(file)

            if not errors:
                return file

            if mode == FileHandling.STRICT:
                raise FileValidationError("; ".join(errors), file_name=file.filename)

            if mode == FileHandling.WARN:
                for error in errors:
                    logger.warning(error)
                return file

            if mode == FileHandling.AUTO:
                return self._auto_process(file)

            if mode == FileHandling.CHUNK:
                return self._chunk_process(file)

            return file

        except (FileValidationError, FileTooLargeError, UnsupportedFileTypeError):
            raise
        except Exception as e:
            logger.error(f"Error processing file '{file.filename}': {e}")
            if mode == FileHandling.STRICT:
                raise FileProcessingError(str(e), file_name=file.filename) from e
            return file

    def process_files(
        self,
        files: dict[str, FileInput],
    ) -> dict[str, FileInput]:
        """Process multiple files according to constraints.

        Args:
            files: Dictionary mapping names to file inputs.

        Returns:
            Dictionary mapping names to processed files. If a file is chunked,
            multiple entries are created with indexed names.
        """
        result: dict[str, FileInput] = {}

        for name, file in files.items():
            processed = self.process(file)

            if isinstance(processed, Sequence) and not isinstance(
                processed, (str, bytes)
            ):
                for i, chunk in enumerate(processed):
                    chunk_name = f"{name}_chunk_{i}"
                    result[chunk_name] = chunk
            else:
                result[name] = processed

        return result

    async def aprocess_files(
        self,
        files: dict[str, FileInput],
        max_concurrency: int = 10,
    ) -> dict[str, FileInput]:
        """Async process multiple files in parallel.

        Args:
            files: Dictionary mapping names to file inputs.
            max_concurrency: Maximum number of concurrent processing tasks.

        Returns:
            Dictionary mapping names to processed files. If a file is chunked,
            multiple entries are created with indexed names.
        """
        semaphore = asyncio.Semaphore(max_concurrency)

        async def process_single(
            key: str, input_file: FileInput
        ) -> tuple[str, FileInput | Sequence[FileInput]]:
            """Process a single file with semaphore limiting."""
            async with semaphore:
                loop = asyncio.get_running_loop()
                result = await loop.run_in_executor(None, self.process, input_file)
                return key, result

        tasks = [process_single(n, f) for n, f in files.items()]
        gather_results = await asyncio.gather(*tasks, return_exceptions=True)

        output: dict[str, FileInput] = {}
        for item in gather_results:
            if isinstance(item, BaseException):
                logger.error(f"Processing failed: {item}")
                continue
            entry_name, processed = item
            if isinstance(processed, Sequence) and not isinstance(
                processed, (str, bytes)
            ):
                for i, chunk in enumerate(processed):
                    output[f"{entry_name}_chunk_{i}"] = chunk
            elif isinstance(
                processed, (AudioFile, File, ImageFile, PDFFile, TextFile, VideoFile)
            ):
                output[entry_name] = processed

        return output

    def _auto_process(self, file: FileInput) -> FileInput:
        """Automatically resize/compress file to meet constraints.

        Args:
            file: The file to process.

        Returns:
            The processed file.
        """
        if self.constraints is None:
            return file

        if isinstance(file, ImageFile) and self.constraints.image is not None:
            return self._auto_process_image(file)

        if isinstance(file, PDFFile) and self.constraints.pdf is not None:
            logger.warning(
                f"Cannot auto-compress PDF '{file.filename}'. "
                "Consider using CHUNK mode for large PDFs."
            )
            return file

        if isinstance(file, (AudioFile, VideoFile)):
            logger.warning(
                f"Auto-processing not supported for {type(file).__name__}. "
                "File will be used as-is."
            )
            return file

        return file

    def _auto_process_image(self, file: ImageFile) -> ImageFile:
        """Auto-process an image file.

        Args:
            file: The image file to process.

        Returns:
            The processed image file.
        """
        if self.constraints is None or self.constraints.image is None:
            return file

        image_constraints = self.constraints.image
        processed = file
        content = file.read()
        current_size = len(content)

        if image_constraints.max_width or image_constraints.max_height:
            dimensions = get_image_dimensions(file)
            if dimensions:
                width, height = dimensions
                max_w = image_constraints.max_width or width
                max_h = image_constraints.max_height or height

                if width > max_w or height > max_h:
                    try:
                        processed = resize_image(file, max_w, max_h)
                        content = processed.read()
                        current_size = len(content)
                    except Exception as e:
                        logger.warning(f"Failed to resize image: {e}")

        if current_size > image_constraints.max_size_bytes:
            try:
                processed = optimize_image(processed, image_constraints.max_size_bytes)
            except Exception as e:
                logger.warning(f"Failed to optimize image: {e}")

        return processed

    def _chunk_process(self, file: FileInput) -> FileInput | Sequence[FileInput]:
        """Split file into chunks to meet constraints.

        Args:
            file: The file to chunk.

        Returns:
            Original file if chunking not needed, or sequence of chunked files.
        """
        if self.constraints is None:
            return file

        if isinstance(file, PDFFile) and self.constraints.pdf is not None:
            max_pages = self.constraints.pdf.max_pages
            if max_pages is not None:
                page_count = get_pdf_page_count(file)
                if page_count is not None and page_count > max_pages:
                    try:
                        return list(chunk_pdf(file, max_pages))
                    except Exception as e:
                        logger.warning(f"Failed to chunk PDF: {e}")
                        return file

        if isinstance(file, TextFile):
            # Use general max size as character limit approximation
            max_size = self.constraints.general_max_size_bytes
            if max_size is not None:
                content = file.read()
                if len(content) > max_size:
                    try:
                        return list(chunk_text(file, max_size))
                    except Exception as e:
                        logger.warning(f"Failed to chunk text file: {e}")
                        return file

        if isinstance(file, (ImageFile, AudioFile, VideoFile)):
            logger.warning(
                f"Chunking not supported for {type(file).__name__}. "
                "Consider using AUTO mode for images."
            )

        return file
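Putting the processor together, a minimal end-to-end sketch; `png_bytes` is placeholder data, and the ImageFile/FileBytes constructors follow the usage visible in the transformers module below:

from crewai_files.core.sources import FileBytes
from crewai_files.core.types import ImageFile
from crewai_files.processing.processor import FileProcessor

png_bytes = b"..."  # placeholder: real PNG bytes in practice

# A provider name resolves constraints; unknown names just disable validation.
processor = FileProcessor(constraints="anthropic")
image = ImageFile(source=FileBytes(data=png_bytes, filename="photo.png"))

# In the default AUTO mode, oversized images are resized/optimized; chunked
# results would appear under indexed names such as "photo_chunk_0".
processed = processor.process_files({"photo": image})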
@@ -1,336 +0,0 @@
"""File transformation functions for resizing, optimizing, and chunking."""

from collections.abc import Iterator
import io
import logging

from crewai_files.core.sources import FileBytes
from crewai_files.core.types import ImageFile, PDFFile, TextFile
from crewai_files.processing.exceptions import ProcessingDependencyError


logger = logging.getLogger(__name__)


def resize_image(
    file: ImageFile,
    max_width: int,
    max_height: int,
    *,
    preserve_aspect_ratio: bool = True,
) -> ImageFile:
    """Resize an image to fit within the specified dimensions.

    Args:
        file: The image file to resize.
        max_width: Maximum width in pixels.
        max_height: Maximum height in pixels.
        preserve_aspect_ratio: If True, maintain aspect ratio while fitting within bounds.

    Returns:
        A new ImageFile with the resized image data.

    Raises:
        ProcessingDependencyError: If Pillow is not installed.
    """
    try:
        from PIL import Image
    except ImportError as e:
        raise ProcessingDependencyError(
            "Pillow is required for image resizing",
            dependency="Pillow",
            install_command="pip install Pillow",
        ) from e

    content = file.read()

    with Image.open(io.BytesIO(content)) as img:
        original_width, original_height = img.size

        if original_width <= max_width and original_height <= max_height:
            return file

        if preserve_aspect_ratio:
            width_ratio = max_width / original_width
            height_ratio = max_height / original_height
            scale_factor = min(width_ratio, height_ratio)

            new_width = int(original_width * scale_factor)
            new_height = int(original_height * scale_factor)
        else:
            new_width = min(original_width, max_width)
            new_height = min(original_height, max_height)

        resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

        output_format = img.format or "PNG"
        if output_format.upper() == "JPEG":
            if resized_img.mode in ("RGBA", "LA", "P"):
                resized_img = resized_img.convert("RGB")

        output_buffer = io.BytesIO()
        resized_img.save(output_buffer, format=output_format)
        output_bytes = output_buffer.getvalue()

    logger.info(
        f"Resized image '{file.filename}' from {original_width}x{original_height} "
        f"to {new_width}x{new_height}"
    )

    return ImageFile(source=FileBytes(data=output_bytes, filename=file.filename))


def optimize_image(
    file: ImageFile,
    target_size_bytes: int,
    *,
    min_quality: int = 20,
    initial_quality: int = 85,
) -> ImageFile:
    """Optimize an image to fit within a target file size.

    Uses iterative quality reduction to achieve target size.

    Args:
        file: The image file to optimize.
        target_size_bytes: Target maximum file size in bytes.
        min_quality: Minimum quality to use (prevents excessive degradation).
        initial_quality: Starting quality for optimization.

    Returns:
        A new ImageFile with the optimized image data.

    Raises:
        ProcessingDependencyError: If Pillow is not installed.
    """
    try:
        from PIL import Image
    except ImportError as e:
        raise ProcessingDependencyError(
            "Pillow is required for image optimization",
            dependency="Pillow",
            install_command="pip install Pillow",
        ) from e

    content = file.read()
    current_size = len(content)

    if current_size <= target_size_bytes:
        return file

    with Image.open(io.BytesIO(content)) as img:
        if img.mode in ("RGBA", "LA", "P"):
            img = img.convert("RGB")
            output_format = "JPEG"
        else:
            output_format = img.format or "JPEG"
            if output_format.upper() not in ("JPEG", "JPG"):
                output_format = "JPEG"

        quality = initial_quality
        output_bytes = content

        while len(output_bytes) > target_size_bytes and quality >= min_quality:
            output_buffer = io.BytesIO()
            img.save(
                output_buffer, format=output_format, quality=quality, optimize=True
            )
            output_bytes = output_buffer.getvalue()

            if len(output_bytes) > target_size_bytes:
                quality -= 5

    logger.info(
        f"Optimized image '{file.filename}' from {current_size} bytes to "
        f"{len(output_bytes)} bytes (quality={quality})"
    )

    filename = file.filename
    if (
        filename
        and output_format.upper() == "JPEG"
        and not filename.lower().endswith((".jpg", ".jpeg"))
    ):
        filename = filename.rsplit(".", 1)[0] + ".jpg"

    return ImageFile(source=FileBytes(data=output_bytes, filename=filename))


def chunk_pdf(
    file: PDFFile,
    max_pages: int,
    *,
    overlap_pages: int = 0,
) -> Iterator[PDFFile]:
    """Split a PDF into chunks of maximum page count.

    Yields chunks one at a time to minimize memory usage.

    Args:
        file: The PDF file to chunk.
        max_pages: Maximum pages per chunk.
        overlap_pages: Number of overlapping pages between chunks (for context).

    Yields:
        PDFFile objects, one per chunk.

    Raises:
        ProcessingDependencyError: If pypdf is not installed.
    """
    try:
        from pypdf import PdfReader, PdfWriter
    except ImportError as e:
        raise ProcessingDependencyError(
            "pypdf is required for PDF chunking",
            dependency="pypdf",
            install_command="pip install pypdf",
        ) from e

    content = file.read()
    reader = PdfReader(io.BytesIO(content))
    total_pages = len(reader.pages)

    if total_pages <= max_pages:
        yield file
        return

    filename = file.filename or "document.pdf"
    base_filename = filename.rsplit(".", 1)[0]
    step = max_pages - overlap_pages

    chunk_num = 0
    start_page = 0

    while start_page < total_pages:
        end_page = min(start_page + max_pages, total_pages)

        writer = PdfWriter()
        for page_num in range(start_page, end_page):
            writer.add_page(reader.pages[page_num])

        output_buffer = io.BytesIO()
        writer.write(output_buffer)
        output_bytes = output_buffer.getvalue()

        chunk_filename = f"{base_filename}_chunk_{chunk_num}.pdf"

        logger.info(
            f"Created PDF chunk '{chunk_filename}' with pages {start_page + 1}-{end_page}"
        )

        yield PDFFile(source=FileBytes(data=output_bytes, filename=chunk_filename))

        start_page += step
        chunk_num += 1


def chunk_text(
    file: TextFile,
    max_chars: int,
    *,
    overlap_chars: int = 200,
    split_on_newlines: bool = True,
) -> Iterator[TextFile]:
    """Split a text file into chunks of maximum character count.

    Yields chunks one at a time to minimize memory usage.

    Args:
        file: The text file to chunk.
        max_chars: Maximum characters per chunk.
        overlap_chars: Number of overlapping characters between chunks.
        split_on_newlines: If True, prefer splitting at newline boundaries.

    Yields:
        TextFile objects, one per chunk.
    """
    content = file.read()
    text = content.decode(errors="replace")
    total_chars = len(text)

    if total_chars <= max_chars:
        yield file
        return

    filename = file.filename or "text.txt"
    base_filename = filename.rsplit(".", 1)[0]
    extension = filename.rsplit(".", 1)[-1] if "." in filename else "txt"

    chunk_num = 0
    start_pos = 0

    while start_pos < total_chars:
        end_pos = min(start_pos + max_chars, total_chars)

        if end_pos < total_chars and split_on_newlines:
            last_newline = text.rfind("\n", start_pos, end_pos)
            if last_newline > start_pos + max_chars // 2:
                end_pos = last_newline + 1

        chunk_content = text[start_pos:end_pos]
        chunk_bytes = chunk_content.encode()

        chunk_filename = f"{base_filename}_chunk_{chunk_num}.{extension}"

        logger.info(
            f"Created text chunk '{chunk_filename}' with {len(chunk_content)} characters"
        )

        yield TextFile(source=FileBytes(data=chunk_bytes, filename=chunk_filename))

        if end_pos < total_chars:
            start_pos = max(start_pos + 1, end_pos - overlap_chars)
        else:
            start_pos = total_chars
        chunk_num += 1


def get_image_dimensions(file: ImageFile) -> tuple[int, int] | None:
    """Get the dimensions of an image file.

    Args:
        file: The image file to measure.

    Returns:
        Tuple of (width, height) in pixels, or None if dimensions cannot be determined.
    """
    try:
        from PIL import Image
    except ImportError:
        logger.warning("Pillow not installed - cannot get image dimensions")
        return None

    content = file.read()

    try:
        with Image.open(io.BytesIO(content)) as img:
            width, height = img.size
            return width, height
    except Exception as e:
        logger.warning(f"Failed to get image dimensions: {e}")
        return None


def get_pdf_page_count(file: PDFFile) -> int | None:
    """Get the page count of a PDF file.

    Args:
        file: The PDF file to measure.

    Returns:
        Number of pages, or None if page count cannot be determined.
    """
    try:
        from pypdf import PdfReader
    except ImportError:
        logger.warning("pypdf not installed - cannot get PDF page count")
        return None

    content = file.read()

    try:
        reader = PdfReader(io.BytesIO(content))
        return len(reader.pages)
    except Exception as e:
        logger.warning(f"Failed to get PDF page count: {e}")
        return None
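A brief sketch of the text chunker above on synthetic input:

from crewai_files.core.sources import FileBytes
from crewai_files.core.types import TextFile
from crewai_files.processing.transformers import chunk_text

doc = TextFile(
    source=FileBytes(data=("line\n" * 100_000).encode(), filename="notes.txt")
)

# Chunks are yielded lazily, overlapping by 200 characters by default and
# preferring newline boundaries.
for chunk in chunk_text(doc, max_chars=50_000):
    print(chunk.filename, len(chunk.read()))  # e.g. notes_chunk_0.txt 50000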
@@ -1,564 +0,0 @@
"""File validation functions for checking against provider constraints."""

from collections.abc import Sequence
import io
import logging

from crewai_files.core.types import (
    AudioFile,
    FileInput,
    ImageFile,
    PDFFile,
    TextFile,
    VideoFile,
)
from crewai_files.processing.constraints import (
    AudioConstraints,
    ImageConstraints,
    PDFConstraints,
    ProviderConstraints,
    VideoConstraints,
)
from crewai_files.processing.exceptions import (
    FileTooLargeError,
    FileValidationError,
    UnsupportedFileTypeError,
)


logger = logging.getLogger(__name__)


def _get_image_dimensions(content: bytes) -> tuple[int, int] | None:
    """Get image dimensions using Pillow if available.

    Args:
        content: Raw image bytes.

    Returns:
        Tuple of (width, height) or None if Pillow unavailable.
    """
    try:
        from PIL import Image

        with Image.open(io.BytesIO(content)) as img:
            width, height = img.size
            return int(width), int(height)
    except ImportError:
        logger.warning(
            "Pillow not installed - cannot validate image dimensions. "
            "Install with: pip install Pillow"
        )
        return None


def _get_pdf_page_count(content: bytes) -> int | None:
    """Get PDF page count using pypdf if available.

    Args:
        content: Raw PDF bytes.

    Returns:
        Page count or None if pypdf unavailable.
    """
    try:
        from pypdf import PdfReader

        reader = PdfReader(io.BytesIO(content))
        return len(reader.pages)
    except ImportError:
        logger.warning(
            "pypdf not installed - cannot validate PDF page count. "
            "Install with: pip install pypdf"
        )
        return None


def _get_audio_duration(content: bytes, filename: str | None = None) -> float | None:
    """Get audio duration in seconds using tinytag if available.

    Args:
        content: Raw audio bytes.
        filename: Optional filename for format detection hint.

    Returns:
        Duration in seconds or None if tinytag unavailable.
    """
    try:
        from tinytag import TinyTag  # type: ignore[import-untyped]
    except ImportError:
        logger.warning(
            "tinytag not installed - cannot validate audio duration. "
            "Install with: pip install tinytag"
        )
        return None

    try:
        tag = TinyTag.get(file_obj=io.BytesIO(content), filename=filename)
        duration: float | None = tag.duration
        return duration
    except Exception as e:
        logger.debug(f"Could not determine audio duration: {e}")
        return None


_VIDEO_FORMAT_MAP: dict[str, str] = {
    "video/mp4": "mp4",
    "video/webm": "webm",
    "video/x-matroska": "matroska",
    "video/quicktime": "mov",
    "video/x-msvideo": "avi",
    "video/x-flv": "flv",
}


def _get_video_duration(
    content: bytes, content_type: str | None = None
) -> float | None:
    """Get video duration in seconds using av if available.

    Args:
        content: Raw video bytes.
        content_type: Optional MIME type for format detection hint.

    Returns:
        Duration in seconds or None if av unavailable.
    """
    try:
        import av
    except ImportError:
        logger.warning(
            "av (PyAV) not installed - cannot validate video duration. "
            "Install with: pip install av"
        )
        return None

    format_hint = _VIDEO_FORMAT_MAP.get(content_type) if content_type else None

    try:
        with av.open(io.BytesIO(content), format=format_hint) as container:  # type: ignore[attr-defined]
            duration: int | None = container.duration  # type: ignore[union-attr]
            if duration is None:
                return None
            return float(duration) / 1_000_000
    except Exception as e:
        logger.debug(f"Could not determine video duration: {e}")

    return None


def _format_size(size_bytes: int) -> str:
    """Format byte size to human-readable string."""
    if size_bytes >= 1024 * 1024 * 1024:
        return f"{size_bytes / (1024 * 1024 * 1024):.1f}GB"
    if size_bytes >= 1024 * 1024:
        return f"{size_bytes / (1024 * 1024):.1f}MB"
    if size_bytes >= 1024:
        return f"{size_bytes / 1024:.1f}KB"
    return f"{size_bytes}B"


def _validate_size(
    file_type: str,
    filename: str | None,
    file_size: int,
    max_size: int,
    errors: list[str],
    raise_on_error: bool,
) -> None:
    """Validate file size against maximum.

    Args:
        file_type: Type label for error messages (e.g., "Image", "PDF").
        filename: Name of the file being validated.
        file_size: Actual file size in bytes.
        max_size: Maximum allowed size in bytes.
        errors: List to append error messages to.
        raise_on_error: If True, raise FileTooLargeError on failure.
    """
    if file_size > max_size:
        msg = (
            f"{file_type} '{filename}' size ({_format_size(file_size)}) exceeds "
            f"maximum ({_format_size(max_size)})"
        )
        errors.append(msg)
        if raise_on_error:
            raise FileTooLargeError(
                msg,
                file_name=filename,
                actual_size=file_size,
                max_size=max_size,
            )


def _validate_format(
    file_type: str,
    filename: str | None,
    content_type: str,
    supported_formats: tuple[str, ...],
    errors: list[str],
    raise_on_error: bool,
) -> None:
    """Validate content type against supported formats.

    Args:
        file_type: Type label for error messages (e.g., "Image", "Audio").
        filename: Name of the file being validated.
        content_type: MIME type of the file.
        supported_formats: Tuple of supported MIME types.
        errors: List to append error messages to.
        raise_on_error: If True, raise UnsupportedFileTypeError on failure.
    """
    if content_type not in supported_formats:
        msg = (
            f"{file_type} format '{content_type}' is not supported. "
            f"Supported: {', '.join(supported_formats)}"
        )
        errors.append(msg)
        if raise_on_error:
            raise UnsupportedFileTypeError(
                msg, file_name=filename, content_type=content_type
            )


def validate_image(
    file: ImageFile,
    constraints: ImageConstraints,
    *,
    raise_on_error: bool = True,
) -> Sequence[str]:
    """Validate an image file against constraints.

    Args:
        file: The image file to validate.
        constraints: Image constraints to validate against.
        raise_on_error: If True, raise exceptions on validation failure.

    Returns:
        List of validation error messages (empty if valid).

    Raises:
        FileTooLargeError: If the file exceeds size limits.
        FileValidationError: If the file exceeds dimension limits.
        UnsupportedFileTypeError: If the format is not supported.
    """
    errors: list[str] = []
    content = file.read()
    file_size = len(content)
    filename = file.filename

    _validate_size(
        "Image", filename, file_size, constraints.max_size_bytes, errors, raise_on_error
    )
    _validate_format(
        "Image",
        filename,
        file.content_type,
        constraints.supported_formats,
        errors,
        raise_on_error,
    )

    if constraints.max_width is not None or constraints.max_height is not None:
        dimensions = _get_image_dimensions(content)
        if dimensions is not None:
            width, height = dimensions

            if constraints.max_width and width > constraints.max_width:
                msg = (
                    f"Image '{filename}' width ({width}px) exceeds "
                    f"maximum ({constraints.max_width}px)"
                )
                errors.append(msg)
                if raise_on_error:
                    raise FileValidationError(msg, file_name=filename)

            if constraints.max_height and height > constraints.max_height:
                msg = (
                    f"Image '{filename}' height ({height}px) exceeds "
                    f"maximum ({constraints.max_height}px)"
                )
                errors.append(msg)
                if raise_on_error:
                    raise FileValidationError(msg, file_name=filename)

    return errors


def validate_pdf(
    file: PDFFile,
    constraints: PDFConstraints,
    *,
    raise_on_error: bool = True,
) -> Sequence[str]:
    """Validate a PDF file against constraints.

    Args:
        file: The PDF file to validate.
        constraints: PDF constraints to validate against.
        raise_on_error: If True, raise exceptions on validation failure.

    Returns:
        List of validation error messages (empty if valid).

    Raises:
        FileTooLargeError: If the file exceeds size limits.
        FileValidationError: If the file exceeds page limits.
    """
    errors: list[str] = []
    content = file.read()
    file_size = len(content)
    filename = file.filename

    _validate_size(
        "PDF", filename, file_size, constraints.max_size_bytes, errors, raise_on_error
    )

    if constraints.max_pages is not None:
        page_count = _get_pdf_page_count(content)
        if page_count is not None and page_count > constraints.max_pages:
            msg = (
                f"PDF '{filename}' page count ({page_count}) exceeds "
                f"maximum ({constraints.max_pages})"
            )
            errors.append(msg)
            if raise_on_error:
                raise FileValidationError(msg, file_name=filename)

    return errors


def validate_audio(
    file: AudioFile,
    constraints: AudioConstraints,
    *,
    raise_on_error: bool = True,
) -> Sequence[str]:
    """Validate an audio file against constraints.

    Args:
        file: The audio file to validate.
        constraints: Audio constraints to validate against.
        raise_on_error: If True, raise exceptions on validation failure.

    Returns:
        List of validation error messages (empty if valid).

    Raises:
        FileTooLargeError: If the file exceeds size limits.
        FileValidationError: If the file exceeds duration limits.
        UnsupportedFileTypeError: If the format is not supported.
    """
    errors: list[str] = []
    content = file.read()
    file_size = len(content)
    filename = file.filename

    _validate_size(
        "Audio",
        filename,
        file_size,
        constraints.max_size_bytes,
        errors,
        raise_on_error,
    )
    _validate_format(
        "Audio",
        filename,
        file.content_type,
        constraints.supported_formats,
        errors,
        raise_on_error,
    )

    if constraints.max_duration_seconds is not None:
        duration = _get_audio_duration(content, filename)
        if duration is not None and duration > constraints.max_duration_seconds:
            msg = (
                f"Audio '{filename}' duration ({duration:.1f}s) exceeds "
                f"maximum ({constraints.max_duration_seconds}s)"
            )
            errors.append(msg)
            if raise_on_error:
                raise FileValidationError(msg, file_name=filename)

    return errors


def validate_video(
    file: VideoFile,
    constraints: VideoConstraints,
    *,
    raise_on_error: bool = True,
) -> Sequence[str]:
    """Validate a video file against constraints.

    Args:
        file: The video file to validate.
        constraints: Video constraints to validate against.
        raise_on_error: If True, raise exceptions on validation failure.

    Returns:
        List of validation error messages (empty if valid).

    Raises:
        FileTooLargeError: If the file exceeds size limits.
        FileValidationError: If the file exceeds duration limits.
        UnsupportedFileTypeError: If the format is not supported.
    """
    errors: list[str] = []
    content = file.read()
    file_size = len(content)
    filename = file.filename

    _validate_size(
        "Video",
        filename,
        file_size,
        constraints.max_size_bytes,
        errors,
        raise_on_error,
    )
    _validate_format(
        "Video",
        filename,
        file.content_type,
        constraints.supported_formats,
        errors,
        raise_on_error,
    )

    if constraints.max_duration_seconds is not None:
        duration = _get_video_duration(content)
        if duration is not None and duration > constraints.max_duration_seconds:
            msg = (
                f"Video '{filename}' duration ({duration:.1f}s) exceeds "
                f"maximum ({constraints.max_duration_seconds}s)"
            )
            errors.append(msg)
            if raise_on_error:
                raise FileValidationError(msg, file_name=filename)

    return errors


def validate_text(
    file: TextFile,
    constraints: ProviderConstraints,
    *,
    raise_on_error: bool = True,
) -> Sequence[str]:
    """Validate a text file against general constraints.

    Args:
        file: The text file to validate.
        constraints: Provider constraints to validate against.
        raise_on_error: If True, raise exceptions on validation failure.

    Returns:
        List of validation error messages (empty if valid).

    Raises:
        FileTooLargeError: If the file exceeds size limits.
    """
    errors: list[str] = []

    if constraints.general_max_size_bytes is None:
        return errors

    file_size = len(file.read())
    _validate_size(
        "Text file",
        file.filename,
        file_size,
        constraints.general_max_size_bytes,
        errors,
        raise_on_error,
    )

    return errors


def _check_unsupported_type(
    file: FileInput,
    provider_name: str,
    type_name: str,
    raise_on_error: bool,
) -> Sequence[str]:
    """Check if file type is unsupported and handle error.

    Args:
        file: The file being validated.
        provider_name: Name of the provider.
        type_name: Name of the file type (e.g., "images", "PDFs").
        raise_on_error: If True, raise exception instead of returning errors.

    Returns:
        List with error message (only returns when raise_on_error is False).

    Raises:
        UnsupportedFileTypeError: If raise_on_error is True.
    """
    msg = f"Provider '{provider_name}' does not support {type_name}"
    if raise_on_error:
        raise UnsupportedFileTypeError(
            msg, file_name=file.filename, content_type=file.content_type
        )
    return [msg]


def validate_file(
    file: FileInput,
    constraints: ProviderConstraints,
    *,
    raise_on_error: bool = True,
) -> Sequence[str]:
    """Validate a file against provider constraints.

    Dispatches to the appropriate validator based on file type.

    Args:
        file: The file to validate.
        constraints: Provider constraints to validate against.
        raise_on_error: If True, raise exceptions on validation failure.

    Returns:
        List of validation error messages (empty if valid).

    Raises:
        FileTooLargeError: If the file exceeds size limits.
        FileValidationError: If the file fails other validation checks.
        UnsupportedFileTypeError: If the file type is not supported.
    """
    if isinstance(file, ImageFile):
        if constraints.image is None:
            return _check_unsupported_type(
                file, constraints.name, "images", raise_on_error
            )
        return validate_image(file, constraints.image, raise_on_error=raise_on_error)

    if isinstance(file, PDFFile):
        if constraints.pdf is None:
            return _check_unsupported_type(
                file, constraints.name, "PDFs", raise_on_error
            )
        return validate_pdf(file, constraints.pdf, raise_on_error=raise_on_error)

    if isinstance(file, AudioFile):
        if constraints.audio is None:
            return _check_unsupported_type(
                file, constraints.name, "audio", raise_on_error
            )
        return validate_audio(file, constraints.audio, raise_on_error=raise_on_error)

    if isinstance(file, VideoFile):
        if constraints.video is None:
            return _check_unsupported_type(
                file, constraints.name, "video", raise_on_error
            )
        return validate_video(file, constraints.video, raise_on_error=raise_on_error)

    if isinstance(file, TextFile):
        return validate_text(file, constraints, raise_on_error=raise_on_error)

    return []
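A short sketch of the dispatching validator above, collecting messages instead of raising; `pdf_bytes` stands in for real PDF content:

from crewai_files.core.sources import FileBytes
from crewai_files.core.types import PDFFile
from crewai_files.processing.constraints import BEDROCK_CONSTRAINTS
from crewai_files.processing.validators import validate_file

pdf_bytes = b"%PDF-1.7 ..."  # placeholder: real PDF bytes in practice

report = PDFFile(source=FileBytes(data=pdf_bytes, filename="report.pdf"))

# With raise_on_error=False the size and page checks accumulate messages.
errors = validate_file(report, BEDROCK_CONSTRAINTS, raise_on_error=False)
for message in errors:
    print(message)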
@@ -1,16 +0,0 @@
"""File resolution logic."""

from crewai_files.resolution.resolver import FileResolver
from crewai_files.resolution.utils import (
    is_file_source,
    normalize_input_files,
    wrap_file_source,
)


__all__ = [
    "FileResolver",
    "is_file_source",
    "normalize_input_files",
    "wrap_file_source",
]
@@ -1,670 +0,0 @@
|
||||
"""FileResolver for deciding file delivery method and managing uploads."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
from dataclasses import dataclass, field
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
from crewai_files.cache.metrics import measure_operation
|
||||
from crewai_files.cache.upload_cache import CachedUpload, UploadCache
|
||||
from crewai_files.core.constants import UPLOAD_MAX_RETRIES, UPLOAD_RETRY_DELAY_BASE
|
||||
from crewai_files.core.resolved import (
|
||||
FileReference,
|
||||
InlineBase64,
|
||||
InlineBytes,
|
||||
ResolvedFile,
|
||||
UrlReference,
|
||||
)
|
||||
from crewai_files.core.sources import FileUrl
|
||||
from crewai_files.core.types import FileInput
|
||||
from crewai_files.processing.constraints import (
|
||||
AudioConstraints,
|
||||
ImageConstraints,
|
||||
PDFConstraints,
|
||||
ProviderConstraints,
|
||||
VideoConstraints,
|
||||
get_constraints_for_provider,
|
||||
)
|
||||
from crewai_files.uploaders import UploadResult, get_uploader
|
||||
from crewai_files.uploaders.base import FileUploader
|
||||
from crewai_files.uploaders.factory import ProviderType
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileContext:
|
||||
"""Cached file metadata to avoid redundant reads.
|
||||
|
||||
Attributes:
|
||||
content: Raw file bytes.
|
||||
size: Size of the file in bytes.
|
||||
content_hash: SHA-256 hash of the file content.
|
||||
content_type: MIME type of the file.
|
||||
"""
|
||||
|
||||
content: bytes
|
||||
size: int
|
||||
content_hash: str
|
||||
content_type: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileResolverConfig:
|
||||
"""Configuration for FileResolver.
|
||||
|
||||
Attributes:
|
||||
prefer_upload: If True, prefer uploading over inline for supported providers.
|
||||
upload_threshold_bytes: Size threshold above which to use upload.
|
||||
If None, uses provider-specific threshold.
|
||||
use_bytes_for_bedrock: If True, use raw bytes instead of base64 for Bedrock.
|
||||
"""
|
||||
|
||||
prefer_upload: bool = False
|
||||
upload_threshold_bytes: int | None = None
|
||||
use_bytes_for_bedrock: bool = True
|
||||
|
||||
|
||||
@dataclass
class FileResolver:
    """Resolves files to their delivery format based on provider capabilities.

    Decides whether to use inline base64, raw bytes, or file upload based on:
    - Provider constraints and capabilities
    - File size
    - Configuration preferences

    Caches uploaded files to avoid redundant uploads.

    Attributes:
        config: Resolver configuration.
        upload_cache: Cache for tracking uploaded files.
    """

    config: FileResolverConfig = field(default_factory=FileResolverConfig)
    upload_cache: UploadCache | None = None
    _uploaders: dict[str, FileUploader] = field(default_factory=dict)

    @staticmethod
    def _build_file_context(file: FileInput) -> FileContext:
        """Build context by reading file once.

        Args:
            file: The file to build context for.

        Returns:
            FileContext with cached metadata.
        """
        content = file.read()
        return FileContext(
            content=content,
            size=len(content),
            content_hash=hashlib.sha256(content).hexdigest(),
            content_type=file.content_type,
        )

    @staticmethod
    def _is_url_source(file: FileInput) -> bool:
        """Check if file source is a URL.

        Args:
            file: The file to check.

        Returns:
            True if the file source is a FileUrl, False otherwise.
        """
        return isinstance(file._file_source, FileUrl)

    @staticmethod
    def _supports_url(constraints: ProviderConstraints | None) -> bool:
        """Check if provider supports URL references.

        Args:
            constraints: Provider constraints.

        Returns:
            True if the provider supports URL references, False otherwise.
        """
        return constraints is not None and constraints.supports_url_references

    @staticmethod
    def _resolve_as_url(file: FileInput) -> UrlReference:
        """Resolve a URL source as UrlReference.

        Args:
            file: The file with URL source.

        Returns:
            UrlReference with the URL and content type.
        """
        source = file._file_source
        if not isinstance(source, FileUrl):
            raise TypeError(f"Expected FileUrl source, got {type(source).__name__}")
        return UrlReference(
            content_type=file.content_type,
            url=source.url,
        )

    def resolve(self, file: FileInput, provider: ProviderType) -> ResolvedFile:
        """Resolve a file to its delivery format for a provider.

        Args:
            file: The file to resolve.
            provider: Provider name (e.g., "gemini", "anthropic", "openai").

        Returns:
            ResolvedFile representing the appropriate delivery format.
        """
        constraints = get_constraints_for_provider(provider)

        if self._is_url_source(file) and self._supports_url(constraints):
            return self._resolve_as_url(file)

        context = self._build_file_context(file)

        should_upload = self._should_upload(file, provider, constraints, context.size)

        if should_upload:
            resolved = self._resolve_via_upload(file, provider, context)
            if resolved is not None:
                return resolved

        return self._resolve_inline(file, provider, context)

    def resolve_files(
        self,
        files: dict[str, FileInput],
        provider: ProviderType,
    ) -> dict[str, ResolvedFile]:
        """Resolve multiple files for a provider.

        Args:
            files: Dictionary mapping names to file inputs.
            provider: Provider name.

        Returns:
            Dictionary mapping names to resolved files.
        """
        return {name: self.resolve(file, provider) for name, file in files.items()}

    @staticmethod
    def _get_type_constraint(
        content_type: str,
        constraints: ProviderConstraints,
    ) -> ImageConstraints | PDFConstraints | AudioConstraints | VideoConstraints | None:
        """Get type-specific constraint based on content type.

        Args:
            content_type: MIME type of the file.
            constraints: Provider constraints.

        Returns:
            Type-specific constraint or None if not found.
        """
        if content_type.startswith("image/"):
            return constraints.image
        if content_type == "application/pdf":
            return constraints.pdf
        if content_type.startswith("audio/"):
            return constraints.audio
        if content_type.startswith("video/"):
            return constraints.video
        return None

    def _should_upload(
        self,
        file: FileInput,
        provider: str,
        constraints: ProviderConstraints | None,
        file_size: int,
    ) -> bool:
        """Determine if a file should be uploaded rather than inlined.

        Uses type-specific constraints to make smarter decisions:
        - Checks if file exceeds type-specific inline size limits
        - Falls back to general threshold if no type-specific constraint

        Args:
            file: The file to check.
            provider: Provider name.
            constraints: Provider constraints.
            file_size: Size of the file in bytes.

        Returns:
            True if the file should be uploaded, False otherwise.
        """
        if constraints is None or not constraints.supports_file_upload:
            return False

        if self.config.prefer_upload:
            return True

        content_type = file.content_type
        type_constraint = self._get_type_constraint(content_type, constraints)

        if type_constraint is not None:
            # Check if file exceeds type-specific inline limit
            if file_size > type_constraint.max_size_bytes:
                logger.debug(
                    f"File {file.filename} ({file_size}B) exceeds {content_type} "
                    f"inline limit ({type_constraint.max_size_bytes}B) for {provider}"
                )
                return True

        # Fall back to general threshold
        threshold = self.config.upload_threshold_bytes
        if threshold is None:
            threshold = constraints.file_upload_threshold_bytes

        if threshold is not None and file_size > threshold:
            return True

        return False

    def _resolve_via_upload(
        self,
        file: FileInput,
        provider: ProviderType,
        context: FileContext,
    ) -> ResolvedFile | None:
        """Resolve a file by uploading it.

        Args:
            file: The file to upload.
            provider: Provider name.
            context: Pre-computed file context.

        Returns:
            FileReference if upload succeeds, None otherwise.
        """
        if self.upload_cache is not None:
            cached = self.upload_cache.get_by_hash(context.content_hash, provider)
            if cached is not None:
                logger.debug(
                    f"Using cached upload for {file.filename}: {cached.file_id}"
                )
                return FileReference(
                    content_type=cached.content_type,
                    file_id=cached.file_id,
                    provider=cached.provider,
                    expires_at=cached.expires_at,
                    file_uri=cached.file_uri,
                )

        uploader = self._get_uploader(provider)
        if uploader is None:
            logger.debug(f"No uploader available for {provider}")
            return None

        result = self._upload_with_retry(uploader, file, provider, context.size)
        if result is None:
            return None

        if self.upload_cache is not None:
            self.upload_cache.set_by_hash(
                file_hash=context.content_hash,
                content_type=context.content_type,
                provider=provider,
                file_id=result.file_id,
                file_uri=result.file_uri,
                expires_at=result.expires_at,
            )

        return FileReference(
            content_type=result.content_type,
            file_id=result.file_id,
            provider=result.provider,
            expires_at=result.expires_at,
            file_uri=result.file_uri,
        )

    @staticmethod
    def _upload_with_retry(
        uploader: FileUploader,
        file: FileInput,
        provider: str,
        file_size: int,
    ) -> UploadResult | None:
        """Upload with exponential backoff retry.

        Args:
            uploader: The uploader to use.
            file: The file to upload.
            provider: Provider name for logging.
            file_size: Size of the file in bytes.

        Returns:
            UploadResult if successful, None otherwise.
        """
        import time

        from crewai_files.processing.exceptions import (
            PermanentUploadError,
            TransientUploadError,
        )

        last_error: Exception | None = None

        for attempt in range(UPLOAD_MAX_RETRIES):
            with measure_operation(
                "upload",
                filename=file.filename,
                provider=provider,
                size_bytes=file_size,
                attempt=attempt + 1,
            ) as metrics:
                try:
                    result = uploader.upload(file)
                    metrics.metadata["file_id"] = result.file_id
                    return result
                except PermanentUploadError as e:
                    metrics.metadata["error_type"] = "permanent"
                    logger.warning(
                        f"Non-retryable upload error for {file.filename}: {e}"
                    )
                    return None
                except TransientUploadError as e:
                    metrics.metadata["error_type"] = "transient"
                    last_error = e
                except Exception as e:
                    metrics.metadata["error_type"] = "unknown"
                    last_error = e

            if attempt < UPLOAD_MAX_RETRIES - 1:
                delay = UPLOAD_RETRY_DELAY_BASE**attempt
                logger.debug(
                    f"Retrying upload for {file.filename} in {delay}s (attempt {attempt + 1})"
                )
                time.sleep(delay)

        logger.warning(
            f"Upload failed for {file.filename} to {provider} after {UPLOAD_MAX_RETRIES} attempts: {last_error}"
        )
        return None

    def _resolve_inline(
        self,
        file: FileInput,
        provider: str,
        context: FileContext,
    ) -> ResolvedFile:
        """Resolve a file as inline content.

        Args:
            file: The file to resolve (used for logging).
            provider: Provider name.
            context: Pre-computed file context.

        Returns:
            InlineBase64 or InlineBytes depending on provider.
        """
        logger.debug(f"Resolving {file.filename} as inline for {provider}")
        if self.config.use_bytes_for_bedrock and "bedrock" in provider:
            return InlineBytes(
                content_type=context.content_type,
                data=context.content,
            )

        encoded = base64.b64encode(context.content).decode("ascii")
        return InlineBase64(
            content_type=context.content_type,
            data=encoded,
        )

    async def aresolve(self, file: FileInput, provider: ProviderType) -> ResolvedFile:
        """Async resolve a file to its delivery format for a provider.

        Args:
            file: The file to resolve.
            provider: Provider name (e.g., "gemini", "anthropic", "openai").

        Returns:
            ResolvedFile representing the appropriate delivery format.
        """
        constraints = get_constraints_for_provider(provider)

        if self._is_url_source(file) and self._supports_url(constraints):
            return self._resolve_as_url(file)

        context = self._build_file_context(file)

        should_upload = self._should_upload(file, provider, constraints, context.size)

        if should_upload:
            resolved = await self._aresolve_via_upload(file, provider, context)
            if resolved is not None:
                return resolved

        return self._resolve_inline(file, provider, context)

    async def aresolve_files(
        self,
        files: dict[str, FileInput],
        provider: ProviderType,
        max_concurrency: int = 10,
    ) -> dict[str, ResolvedFile]:
        """Async resolve multiple files in parallel.

        Args:
            files: Dictionary mapping names to file inputs.
            provider: Provider name.
            max_concurrency: Maximum number of concurrent resolutions.

        Returns:
            Dictionary mapping names to resolved files.
        """
        semaphore = asyncio.Semaphore(max_concurrency)

        async def resolve_single(
            entry_key: str, input_file: FileInput
        ) -> tuple[str, ResolvedFile]:
            """Resolve a single file with semaphore limiting."""
            async with semaphore:
                entry_resolved = await self.aresolve(input_file, provider)
                return entry_key, entry_resolved

        tasks = [resolve_single(n, f) for n, f in files.items()]
        gather_results = await asyncio.gather(*tasks, return_exceptions=True)

        output: dict[str, ResolvedFile] = {}
        for item in gather_results:
            if isinstance(item, BaseException):
                logger.error(f"Resolution failed: {item}")
                continue
            key, resolved = item
            output[key] = resolved

        return output

    async def _aresolve_via_upload(
        self,
        file: FileInput,
        provider: ProviderType,
        context: FileContext,
    ) -> ResolvedFile | None:
        """Async resolve a file by uploading it.

        Args:
            file: The file to upload.
            provider: Provider name.
            context: Pre-computed file context.

        Returns:
            FileReference if upload succeeds, None otherwise.
        """
        if self.upload_cache is not None:
            cached = await self.upload_cache.aget_by_hash(
                context.content_hash, provider
            )
            if cached is not None:
                logger.debug(
                    f"Using cached upload for {file.filename}: {cached.file_id}"
                )
                return FileReference(
                    content_type=cached.content_type,
                    file_id=cached.file_id,
                    provider=cached.provider,
                    expires_at=cached.expires_at,
                    file_uri=cached.file_uri,
                )

        uploader = self._get_uploader(provider)
        if uploader is None:
            logger.debug(f"No uploader available for {provider}")
            return None

        result = await self._aupload_with_retry(uploader, file, provider, context.size)
        if result is None:
            return None

        if self.upload_cache is not None:
            await self.upload_cache.aset_by_hash(
                file_hash=context.content_hash,
                content_type=context.content_type,
                provider=provider,
                file_id=result.file_id,
                file_uri=result.file_uri,
                expires_at=result.expires_at,
            )

        return FileReference(
            content_type=result.content_type,
            file_id=result.file_id,
            provider=result.provider,
            expires_at=result.expires_at,
            file_uri=result.file_uri,
        )

    @staticmethod
    async def _aupload_with_retry(
        uploader: FileUploader,
        file: FileInput,
        provider: str,
        file_size: int,
    ) -> UploadResult | None:
        """Async upload with exponential backoff retry.

        Args:
            uploader: The uploader to use.
            file: The file to upload.
            provider: Provider name for logging.
            file_size: Size of the file in bytes.

        Returns:
            UploadResult if successful, None otherwise.
        """
        from crewai_files.processing.exceptions import (
            PermanentUploadError,
            TransientUploadError,
        )

        last_error: Exception | None = None

        for attempt in range(UPLOAD_MAX_RETRIES):
            with measure_operation(
                "upload",
                filename=file.filename,
                provider=provider,
                size_bytes=file_size,
                attempt=attempt + 1,
            ) as metrics:
                try:
                    result = await uploader.aupload(file)
                    metrics.metadata["file_id"] = result.file_id
                    return result
                except PermanentUploadError as e:
                    metrics.metadata["error_type"] = "permanent"
                    logger.warning(
                        f"Non-retryable upload error for {file.filename}: {e}"
                    )
                    return None
                except TransientUploadError as e:
                    metrics.metadata["error_type"] = "transient"
                    last_error = e
                except Exception as e:
                    metrics.metadata["error_type"] = "unknown"
                    last_error = e

            if attempt < UPLOAD_MAX_RETRIES - 1:
                delay = UPLOAD_RETRY_DELAY_BASE**attempt
                logger.debug(
                    f"Retrying upload for {file.filename} in {delay}s (attempt {attempt + 1})"
                )
                await asyncio.sleep(delay)

        logger.warning(
            f"Upload failed for {file.filename} to {provider} after {UPLOAD_MAX_RETRIES} attempts: {last_error}"
        )
        return None

    def _get_uploader(self, provider: ProviderType) -> FileUploader | None:
        """Get or create an uploader for a provider.

        Args:
            provider: Provider name.

        Returns:
            FileUploader instance or None if not available.
        """
        if provider not in self._uploaders:
            uploader = get_uploader(provider)
            if uploader is not None:
                self._uploaders[provider] = uploader
            else:
                return None

        return self._uploaders.get(provider)

    def get_cached_uploads(self, provider: ProviderType) -> list[CachedUpload]:
        """Get all cached uploads for a provider.

        Args:
            provider: Provider name.

        Returns:
            List of cached uploads.
        """
        if self.upload_cache is None:
            return []
        return self.upload_cache.get_all_for_provider(provider)

    def clear_cache(self) -> None:
        """Clear the upload cache."""
        if self.upload_cache is not None:
            self.upload_cache.clear()


def create_resolver(
    provider: str | None = None,
    prefer_upload: bool = False,
    upload_threshold_bytes: int | None = None,
    enable_cache: bool = True,
) -> FileResolver:
    """Create a configured FileResolver.

    Args:
        provider: Optional provider name to load default threshold from constraints.
        prefer_upload: Whether to prefer upload over inline.
        upload_threshold_bytes: Size threshold for using upload. If None and
            provider is specified, uses provider's default threshold.
        enable_cache: Whether to enable upload caching.

    Returns:
        Configured FileResolver instance.
    """
    threshold = upload_threshold_bytes
    if threshold is None and provider is not None:
        constraints = get_constraints_for_provider(provider)
        if constraints is not None:
            threshold = constraints.file_upload_threshold_bytes

    config = FileResolverConfig(
        prefer_upload=prefer_upload,
        upload_threshold_bytes=threshold,
    )

    cache = UploadCache() if enable_cache else None

    return FileResolver(config=config, upload_cache=cache)
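
A hedged usage sketch of the factory above (the provider name is illustrative, and report_pdf stands in for any FileInput):

resolver = create_resolver(provider="gemini", prefer_upload=False, enable_cache=True)
resolved = resolver.resolve(report_pdf, provider="gemini")
# -> UrlReference, FileReference, InlineBytes, or InlineBase64, depending on
#    provider constraints, file size, and the thresholds configured above.
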
@@ -1,91 +0,0 @@
"""Utility functions for file handling."""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING

from crewai_files.core.sources import is_file_source


if TYPE_CHECKING:
    from crewai_files.core.sources import FileSource, FileSourceInput
    from crewai_files.core.types import FileInput


__all__ = ["is_file_source", "normalize_input_files", "wrap_file_source"]


def wrap_file_source(source: FileSource) -> FileInput:
    """Wrap a FileSource in the appropriate typed FileInput wrapper.

    Args:
        source: The file source to wrap.

    Returns:
        Typed FileInput wrapper based on content type.
    """
    from crewai_files.core.types import (
        AudioFile,
        ImageFile,
        PDFFile,
        TextFile,
        VideoFile,
    )

    content_type = source.content_type

    if content_type.startswith("image/"):
        return ImageFile(source=source)
    if content_type.startswith("audio/"):
        return AudioFile(source=source)
    if content_type.startswith("video/"):
        return VideoFile(source=source)
    if content_type == "application/pdf":
        return PDFFile(source=source)
    return TextFile(source=source)

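For illustration, the dispatch above in action (png_bytes is a placeholder, and it is an assumption here that FileBytes carries or infers the MIME type consulted by wrap_file_source):

source = FileBytes(data=png_bytes)  # hypothetical image payload
wrapped = wrap_file_source(source)
# image/* -> ImageFile, audio/* -> AudioFile, video/* -> VideoFile,
# application/pdf -> PDFFile, everything else -> TextFile
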
def normalize_input_files(
    input_files: list[FileSourceInput | FileInput],
) -> dict[str, FileInput]:
    """Convert a list of file sources to a named dictionary of FileInputs.

    Args:
        input_files: List of file source inputs or File objects.

    Returns:
        Dictionary mapping names to FileInput wrappers.
    """
    from crewai_files.core.sources import FileBytes, FilePath, FileStream, FileUrl
    from crewai_files.core.types import BaseFile

    result: dict[str, FileInput] = {}

    for i, item in enumerate(input_files):
        if isinstance(item, BaseFile):
            name = item.filename or f"file_{i}"
            if "." in name:
                name = name.rsplit(".", 1)[0]
            result[name] = item
            continue

        file_source: FilePath | FileBytes | FileStream | FileUrl
        if isinstance(item, (FilePath, FileBytes, FileStream, FileUrl)):
            file_source = item
        elif isinstance(item, Path):
            file_source = FilePath(path=item)
        elif isinstance(item, str):
            if item.startswith(("http://", "https://")):
                file_source = FileUrl(url=item)
            else:
                file_source = FilePath(path=Path(item))
        elif isinstance(item, (bytes, memoryview)):
            file_source = FileBytes(data=bytes(item))
        else:
            continue

        name = file_source.filename or f"file_{i}"
        result[name] = wrap_file_source(file_source)

    return result

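A sketch of this normalization on a mixed list (the resulting keys depend on what filename each source reports, which is an assumption here):

named = normalize_input_files([
    Path("notes.txt"),                # FilePath   -> key like "notes.txt"
    "https://example.com/chart.png",  # FileUrl    -> key like "chart.png"
    b"raw bytes with no filename",    # FileBytes  -> falls back to "file_2"
])
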
@@ -1,11 +0,0 @@
"""File uploader implementations for provider File APIs."""

from crewai_files.uploaders.base import FileUploader, UploadResult
from crewai_files.uploaders.factory import get_uploader


__all__ = [
    "FileUploader",
    "UploadResult",
    "get_uploader",
]
@@ -1,241 +0,0 @@
"""Anthropic Files API uploader implementation."""

from __future__ import annotations

import io
import logging
import os
from typing import Any

from crewai_files.core.types import FileInput
from crewai_files.processing.exceptions import classify_upload_error
from crewai_files.uploaders.base import FileUploader, UploadResult


logger = logging.getLogger(__name__)


class AnthropicFileUploader(FileUploader):
    """Uploader for Anthropic Files API.

    Uses the anthropic SDK to upload files. Files are stored persistently
    until explicitly deleted.
    """

    def __init__(self, api_key: str | None = None) -> None:
        """Initialize the Anthropic uploader.

        Args:
            api_key: Optional Anthropic API key. If not provided, uses
                ANTHROPIC_API_KEY environment variable.
        """
        self._api_key = api_key or os.environ.get("ANTHROPIC_API_KEY")
        self._client: Any = None
        self._async_client: Any = None

    @property
    def provider_name(self) -> str:
        """Return the provider name."""
        return "anthropic"

    def _get_client(self) -> Any:
        """Get or create the Anthropic client."""
        if self._client is None:
            try:
                import anthropic

                self._client = anthropic.Anthropic(api_key=self._api_key)
            except ImportError as e:
                raise ImportError(
                    "anthropic is required for Anthropic file uploads. "
                    "Install with: pip install anthropic"
                ) from e
        return self._client

    def _get_async_client(self) -> Any:
        """Get or create the async Anthropic client."""
        if self._async_client is None:
            try:
                import anthropic

                self._async_client = anthropic.AsyncAnthropic(api_key=self._api_key)
            except ImportError as e:
                raise ImportError(
                    "anthropic is required for Anthropic file uploads. "
                    "Install with: pip install anthropic"
                ) from e
        return self._async_client

    def upload(self, file: FileInput, purpose: str | None = None) -> UploadResult:
        """Upload a file to Anthropic.

        Args:
            file: The file to upload.
            purpose: Optional purpose for the file (default: "user_upload").

        Returns:
            UploadResult with the file ID and metadata.

        Raises:
            TransientUploadError: For retryable errors (network, rate limits).
            PermanentUploadError: For non-retryable errors (auth, validation).
        """
        try:
            client = self._get_client()

            content = file.read()
            file_purpose = purpose or "user_upload"

            file_data = io.BytesIO(content)

            logger.info(
                f"Uploading file '{file.filename}' to Anthropic ({len(content)} bytes)"
            )

            uploaded_file = client.files.create(
                file=(file.filename, file_data, file.content_type),
                purpose=file_purpose,
            )

            logger.info(f"Uploaded to Anthropic: {uploaded_file.id}")

            return UploadResult(
                file_id=uploaded_file.id,
                file_uri=None,
                content_type=file.content_type,
                expires_at=None,
                provider=self.provider_name,
            )
        except ImportError:
            raise
        except Exception as e:
            raise classify_upload_error(e, file.filename) from e

    def delete(self, file_id: str) -> bool:
        """Delete an uploaded file from Anthropic.

        Args:
            file_id: The file ID to delete.

        Returns:
            True if deletion was successful, False otherwise.
        """
        try:
            client = self._get_client()
            client.files.delete(file_id=file_id)
            logger.info(f"Deleted Anthropic file: {file_id}")
            return True
        except Exception as e:
            logger.warning(f"Failed to delete Anthropic file {file_id}: {e}")
            return False

    def get_file_info(self, file_id: str) -> dict[str, Any] | None:
        """Get information about an uploaded file.

        Args:
            file_id: The file ID.

        Returns:
            Dictionary with file information, or None if not found.
        """
        try:
            client = self._get_client()
            file_info = client.files.retrieve(file_id=file_id)
            return {
                "id": file_info.id,
                "filename": file_info.filename,
                "purpose": file_info.purpose,
                "size_bytes": file_info.size_bytes,
                "created_at": file_info.created_at,
            }
        except Exception as e:
            logger.debug(f"Failed to get Anthropic file info for {file_id}: {e}")
            return None

    def list_files(self) -> list[dict[str, Any]]:
        """List all uploaded files.

        Returns:
            List of dictionaries with file information.
        """
        try:
            client = self._get_client()
            files = client.files.list()
            return [
                {
                    "id": f.id,
                    "filename": f.filename,
                    "purpose": f.purpose,
                    "size_bytes": f.size_bytes,
                    "created_at": f.created_at,
                }
                for f in files.data
            ]
        except Exception as e:
            logger.warning(f"Failed to list Anthropic files: {e}")
            return []

    async def aupload(
        self, file: FileInput, purpose: str | None = None
    ) -> UploadResult:
        """Async upload a file to Anthropic using native async client.

        Args:
            file: The file to upload.
            purpose: Optional purpose for the file (default: "user_upload").

        Returns:
            UploadResult with the file ID and metadata.

        Raises:
            TransientUploadError: For retryable errors (network, rate limits).
            PermanentUploadError: For non-retryable errors (auth, validation).
        """
        try:
            client = self._get_async_client()

            content = await file.aread()
            file_purpose = purpose or "user_upload"

            file_data = io.BytesIO(content)

            logger.info(
                f"Uploading file '{file.filename}' to Anthropic ({len(content)} bytes)"
            )

            uploaded_file = await client.files.create(
                file=(file.filename, file_data, file.content_type),
                purpose=file_purpose,
            )

            logger.info(f"Uploaded to Anthropic: {uploaded_file.id}")

            return UploadResult(
                file_id=uploaded_file.id,
                file_uri=None,
                content_type=file.content_type,
                expires_at=None,
                provider=self.provider_name,
            )
        except ImportError:
            raise
        except Exception as e:
            raise classify_upload_error(e, file.filename) from e

    async def adelete(self, file_id: str) -> bool:
        """Async delete an uploaded file from Anthropic.

        Args:
            file_id: The file ID to delete.

        Returns:
            True if deletion was successful, False otherwise.
        """
        try:
            client = self._get_async_client()
            await client.files.delete(file_id=file_id)
            logger.info(f"Deleted Anthropic file: {file_id}")
            return True
        except Exception as e:
            logger.warning(f"Failed to delete Anthropic file {file_id}: {e}")
            return False

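A hedged usage sketch (requires the anthropic SDK and an API key; report_pdf stands in for any FileInput):

uploader = AnthropicFileUploader()       # falls back to ANTHROPIC_API_KEY
result = uploader.upload(report_pdf)
print(result.file_id, result.provider)   # a provider file ID and "anthropic"
uploader.delete(result.file_id)
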
@@ -1,118 +0,0 @@
"""Base class for file uploaders."""

from abc import ABC, abstractmethod
import asyncio
from dataclasses import dataclass
from datetime import datetime
from typing import Any

from crewai_files.core.types import FileInput


@dataclass
class UploadResult:
    """Result of a file upload operation.

    Attributes:
        file_id: Provider-specific file identifier.
        file_uri: Optional URI for accessing the file.
        content_type: MIME type of the uploaded file.
        expires_at: When the upload expires (if applicable).
        provider: Name of the provider.
    """

    file_id: str
    provider: str
    content_type: str
    file_uri: str | None = None
    expires_at: datetime | None = None


class FileUploader(ABC):
    """Abstract base class for provider file uploaders.

    Implementations handle uploading files to provider-specific File APIs.
    """

    @property
    @abstractmethod
    def provider_name(self) -> str:
        """Return the provider name."""

    @abstractmethod
    def upload(self, file: FileInput, purpose: str | None = None) -> UploadResult:
        """Upload a file to the provider.

        Args:
            file: The file to upload.
            purpose: Optional purpose/description for the upload.

        Returns:
            UploadResult with the file identifier and metadata.

        Raises:
            Exception: If upload fails.
        """

    async def aupload(
        self, file: FileInput, purpose: str | None = None
    ) -> UploadResult:
        """Async upload a file to the provider.

        Default implementation runs sync upload in executor.
        Override in subclasses for native async support.

        Args:
            file: The file to upload.
            purpose: Optional purpose/description for the upload.

        Returns:
            UploadResult with the file identifier and metadata.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.upload, file, purpose)

    @abstractmethod
    def delete(self, file_id: str) -> bool:
        """Delete an uploaded file.

        Args:
            file_id: The file identifier to delete.

        Returns:
            True if deletion was successful, False otherwise.
        """

    async def adelete(self, file_id: str) -> bool:
        """Async delete an uploaded file.

        Default implementation runs sync delete in executor.
        Override in subclasses for native async support.

        Args:
            file_id: The file identifier to delete.

        Returns:
            True if deletion was successful, False otherwise.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.delete, file_id)

    def get_file_info(self, file_id: str) -> dict[str, Any] | None:
        """Get information about an uploaded file.

        Args:
            file_id: The file identifier.

        Returns:
            Dictionary with file information, or None if not found.
        """
        return None

    def list_files(self) -> list[dict[str, Any]]:
        """List all uploaded files.

        Returns:
            List of dictionaries with file information.
        """
        return []

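A minimal concrete subclass, just to show the required surface (hypothetical, e.g. for tests; aupload/adelete fall back to the executor defaults above):

class InMemoryUploader(FileUploader):
    """Hypothetical in-memory uploader for tests."""

    def __init__(self) -> None:
        self._store: dict[str, bytes] = {}

    @property
    def provider_name(self) -> str:
        return "memory"

    def upload(self, file: FileInput, purpose: str | None = None) -> UploadResult:
        file_id = file.filename or "file"
        self._store[file_id] = file.read()
        return UploadResult(
            file_id=file_id,
            provider=self.provider_name,
            content_type=file.content_type,
        )

    def delete(self, file_id: str) -> bool:
        return self._store.pop(file_id, None) is not None
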
@@ -1,473 +0,0 @@
"""AWS Bedrock S3 file uploader implementation."""

from __future__ import annotations

import hashlib
import logging
import os
from pathlib import Path
from typing import Any

from crewai_files.core.constants import (
    MAX_CONCURRENCY,
    MULTIPART_CHUNKSIZE,
    MULTIPART_THRESHOLD,
)
from crewai_files.core.sources import FileBytes, FilePath
from crewai_files.core.types import FileInput
from crewai_files.processing.exceptions import (
    PermanentUploadError,
    TransientUploadError,
)
from crewai_files.uploaders.base import FileUploader, UploadResult


logger = logging.getLogger(__name__)


def _classify_s3_error(e: Exception, filename: str | None) -> Exception:
    """Classify an S3 exception as transient or permanent upload error.

    Args:
        e: The exception to classify.
        filename: The filename for error context.

    Returns:
        A TransientUploadError or PermanentUploadError wrapping the original.
    """
    error_type = type(e).__name__
    error_code = getattr(e, "response", {}).get("Error", {}).get("Code", "")

    if error_code in ("SlowDown", "ServiceUnavailable", "InternalError"):
        return TransientUploadError(f"Transient S3 error: {e}", file_name=filename)
    if error_code in ("AccessDenied", "InvalidAccessKeyId", "SignatureDoesNotMatch"):
        return PermanentUploadError(f"S3 authentication error: {e}", file_name=filename)
    if error_code in ("NoSuchBucket", "InvalidBucketName"):
        return PermanentUploadError(f"S3 bucket error: {e}", file_name=filename)
    if "Throttl" in error_type or "Throttl" in str(e):
        return TransientUploadError(f"S3 throttling: {e}", file_name=filename)
    return TransientUploadError(f"S3 upload failed: {e}", file_name=filename)


def _get_file_path(file: FileInput) -> Path | None:
    """Get the filesystem path if file source is FilePath.

    Args:
        file: The file input to check.

    Returns:
        Path if source is FilePath, None otherwise.
    """
    source = file._file_source
    if isinstance(source, FilePath):
        return source.path
    return None


def _get_file_size(file: FileInput) -> int | None:
    """Get file size without reading content if possible.

    Args:
        file: The file input.

    Returns:
        Size in bytes if determinable without reading, None otherwise.
    """
    source = file._file_source
    if isinstance(source, FilePath):
        return source.path.stat().st_size
    if isinstance(source, FileBytes):
        return len(source.data)
    return None


def _compute_hash_streaming(file_path: Path) -> str:
    """Compute SHA-256 hash by streaming file content.

    Args:
        file_path: Path to the file.

    Returns:
        First 16 characters of hex digest.
    """
    hasher = hashlib.sha256()
    with open(file_path, "rb") as f:
        while chunk := f.read(1024 * 1024):
            hasher.update(chunk)
    return hasher.hexdigest()[:16]


class BedrockFileUploader(FileUploader):
    """Uploader for AWS Bedrock via S3.

    Uploads files to S3 and returns S3 URIs that can be used with Bedrock's
    Converse API s3Location source format.
    """

    def __init__(
        self,
        bucket_name: str | None = None,
        bucket_owner: str | None = None,
        prefix: str = "crewai-files",
        region: str | None = None,
    ) -> None:
        """Initialize the Bedrock S3 uploader.

        Args:
            bucket_name: S3 bucket name. If not provided, uses
                CREWAI_BEDROCK_S3_BUCKET environment variable.
            bucket_owner: Optional bucket owner account ID for cross-account access.
                Uses CREWAI_BEDROCK_S3_BUCKET_OWNER environment variable if not provided.
            prefix: S3 key prefix for uploaded files (default: "crewai-files").
            region: AWS region. Uses AWS_REGION or AWS_DEFAULT_REGION if not provided.
        """
        self._bucket_name = bucket_name or os.environ.get("CREWAI_BEDROCK_S3_BUCKET")
        self._bucket_owner = bucket_owner or os.environ.get(
            "CREWAI_BEDROCK_S3_BUCKET_OWNER"
        )
        self._prefix = prefix
        self._region = region or os.environ.get(
            "AWS_REGION", os.environ.get("AWS_DEFAULT_REGION")
        )
        self._client: Any = None
        self._async_client: Any = None

    @property
    def provider_name(self) -> str:
        """Return the provider name."""
        return "bedrock"

    @property
    def bucket_name(self) -> str:
        """Return the configured bucket name."""
        if not self._bucket_name:
            raise ValueError(
                "S3 bucket name not configured. Set CREWAI_BEDROCK_S3_BUCKET "
                "environment variable or pass bucket_name parameter."
            )
        return self._bucket_name

    @property
    def bucket_owner(self) -> str | None:
        """Return the configured bucket owner."""
        return self._bucket_owner

    def _get_client(self) -> Any:
        """Get or create the S3 client."""
        if self._client is None:
            try:
                import boto3

                self._client = boto3.client("s3", region_name=self._region)
            except ImportError as e:
                raise ImportError(
                    "boto3 is required for Bedrock S3 file uploads. "
                    "Install with: pip install boto3"
                ) from e
        return self._client

    def _get_async_client(self) -> Any:
        """Get or create the async aioboto3 session."""
        if self._async_client is None:
            try:
                import aioboto3  # type: ignore[import-not-found]

                # Cache the session so repeated calls reuse it.
                self._async_client = aioboto3.Session()
            except ImportError as e:
                raise ImportError(
                    "aioboto3 is required for async Bedrock S3 file uploads. "
                    "Install with: pip install aioboto3"
                ) from e
        return self._async_client

    def _generate_s3_key(self, file: FileInput, content: bytes | None = None) -> str:
        """Generate a unique S3 key for the file.

        For FilePath sources with no content provided, computes hash via streaming.

        Args:
            file: The file being uploaded.
            content: The file content bytes (optional for FilePath sources).

        Returns:
            S3 key string.
        """
        if content is not None:
            content_hash = hashlib.sha256(content).hexdigest()[:16]
        else:
            file_path = _get_file_path(file)
            if file_path is not None:
                content_hash = _compute_hash_streaming(file_path)
            else:
                content_hash = hashlib.sha256(file.read()).hexdigest()[:16]

        filename = file.filename or "file"
        safe_filename = "".join(
            c if c.isalnum() or c in ".-_" else "_" for c in filename
        )
        return f"{self._prefix}/{content_hash}_{safe_filename}"

    def _build_s3_uri(self, key: str) -> str:
        """Build an S3 URI from a key.

        Args:
            key: The S3 object key.

        Returns:
            S3 URI string.
        """
        return f"s3://{self.bucket_name}/{key}"

    @staticmethod
    def _get_transfer_config() -> Any:
        """Get boto3 TransferConfig for multipart uploads."""
        from boto3.s3.transfer import TransferConfig

        return TransferConfig(
            multipart_threshold=MULTIPART_THRESHOLD,
            multipart_chunksize=MULTIPART_CHUNKSIZE,
            max_concurrency=MAX_CONCURRENCY,
        )

    def upload(self, file: FileInput, purpose: str | None = None) -> UploadResult:
        """Upload a file to S3 for use with Bedrock.

        Uses streaming upload with automatic multipart for large files.
        For FilePath sources, streams directly from disk without loading into memory.

        Args:
            file: The file to upload.
            purpose: Optional purpose (unused, kept for interface consistency).

        Returns:
            UploadResult with the S3 URI and metadata.

        Raises:
            TransientUploadError: For retryable errors (network, throttling).
            PermanentUploadError: For non-retryable errors (auth, validation).
        """
        import io

        try:
            client = self._get_client()
            transfer_config = self._get_transfer_config()
            file_path = _get_file_path(file)

            if file_path is not None:
                file_size = file_path.stat().st_size
                s3_key = self._generate_s3_key(file)

                logger.info(
                    f"Uploading file '{file.filename}' to S3 bucket "
                    f"'{self.bucket_name}' ({file_size} bytes, streaming)"
                )

                with open(file_path, "rb") as f:
                    client.upload_fileobj(
                        f,
                        self.bucket_name,
                        s3_key,
                        ExtraArgs={"ContentType": file.content_type},
                        Config=transfer_config,
                    )
            else:
                content = file.read()
                s3_key = self._generate_s3_key(file, content)

                logger.info(
                    f"Uploading file '{file.filename}' to S3 bucket "
                    f"'{self.bucket_name}' ({len(content)} bytes)"
                )

                client.upload_fileobj(
                    io.BytesIO(content),
                    self.bucket_name,
                    s3_key,
                    ExtraArgs={"ContentType": file.content_type},
                    Config=transfer_config,
                )

            s3_uri = self._build_s3_uri(s3_key)
            logger.info(f"Uploaded to S3: {s3_uri}")

            return UploadResult(
                file_id=s3_key,
                file_uri=s3_uri,
                content_type=file.content_type,
                expires_at=None,
                provider=self.provider_name,
            )
        except ImportError:
            raise
        except Exception as e:
            raise _classify_s3_error(e, file.filename) from e

    def delete(self, file_id: str) -> bool:
        """Delete an uploaded file from S3.

        Args:
            file_id: The S3 key to delete.

        Returns:
            True if deletion was successful, False otherwise.
        """
        try:
            client = self._get_client()
            client.delete_object(Bucket=self.bucket_name, Key=file_id)
            logger.info(f"Deleted S3 object: s3://{self.bucket_name}/{file_id}")
            return True
        except Exception as e:
            logger.warning(
                f"Failed to delete S3 object s3://{self.bucket_name}/{file_id}: {e}"
            )
            return False

    def get_file_info(self, file_id: str) -> dict[str, Any] | None:
        """Get information about an uploaded file.

        Args:
            file_id: The S3 key.

        Returns:
            Dictionary with file information, or None if not found.
        """
        try:
            client = self._get_client()
            response = client.head_object(Bucket=self.bucket_name, Key=file_id)
            return {
                "id": file_id,
                "uri": self._build_s3_uri(file_id),
                "content_type": response.get("ContentType"),
                "size": response.get("ContentLength"),
                "last_modified": response.get("LastModified"),
                "etag": response.get("ETag"),
            }
        except Exception as e:
            logger.debug(f"Failed to get S3 object info for {file_id}: {e}")
            return None

    def list_files(self) -> list[dict[str, Any]]:
        """List all uploaded files in the configured prefix.

        Returns:
            List of dictionaries with file information.
        """
        try:
            client = self._get_client()
            response = client.list_objects_v2(
                Bucket=self.bucket_name,
                Prefix=self._prefix,
            )
            return [
                {
                    "id": obj["Key"],
                    "uri": self._build_s3_uri(obj["Key"]),
                    "size": obj.get("Size"),
                    "last_modified": obj.get("LastModified"),
                    "etag": obj.get("ETag"),
                }
                for obj in response.get("Contents", [])
            ]
        except Exception as e:
            logger.warning(f"Failed to list S3 objects: {e}")
            return []

    async def aupload(
        self, file: FileInput, purpose: str | None = None
    ) -> UploadResult:
        """Async upload a file to S3 for use with Bedrock.

        Uses streaming upload with automatic multipart for large files.
        For FilePath sources, streams directly from disk without loading into memory.

        Args:
            file: The file to upload.
            purpose: Optional purpose (unused, kept for interface consistency).

        Returns:
            UploadResult with the S3 URI and metadata.

        Raises:
            TransientUploadError: For retryable errors (network, throttling).
            PermanentUploadError: For non-retryable errors (auth, validation).
        """
        import io

        import aiofiles

        try:
            session = self._get_async_client()
            transfer_config = self._get_transfer_config()
            file_path = _get_file_path(file)

            if file_path is not None:
                file_size = file_path.stat().st_size
                s3_key = self._generate_s3_key(file)

                logger.info(
                    f"Uploading file '{file.filename}' to S3 bucket "
                    f"'{self.bucket_name}' ({file_size} bytes, streaming)"
                )

                async with session.client("s3", region_name=self._region) as client:
                    async with aiofiles.open(file_path, "rb") as f:
                        await client.upload_fileobj(
                            f,
                            self.bucket_name,
                            s3_key,
                            ExtraArgs={"ContentType": file.content_type},
                            Config=transfer_config,
                        )
            else:
                content = await file.aread()
                s3_key = self._generate_s3_key(file, content)

                logger.info(
                    f"Uploading file '{file.filename}' to S3 bucket "
                    f"'{self.bucket_name}' ({len(content)} bytes)"
                )

                async with session.client("s3", region_name=self._region) as client:
                    await client.upload_fileobj(
                        io.BytesIO(content),
                        self.bucket_name,
                        s3_key,
                        ExtraArgs={"ContentType": file.content_type},
                        Config=transfer_config,
                    )

            s3_uri = self._build_s3_uri(s3_key)
            logger.info(f"Uploaded to S3: {s3_uri}")

            return UploadResult(
                file_id=s3_key,
                file_uri=s3_uri,
                content_type=file.content_type,
                expires_at=None,
                provider=self.provider_name,
            )
        except ImportError:
            raise
        except Exception as e:
            raise _classify_s3_error(e, file.filename) from e

    async def adelete(self, file_id: str) -> bool:
        """Async delete an uploaded file from S3.

        Args:
            file_id: The S3 key to delete.

        Returns:
            True if deletion was successful, False otherwise.
        """
        try:
            session = self._get_async_client()
            async with session.client("s3", region_name=self._region) as client:
                await client.delete_object(Bucket=self.bucket_name, Key=file_id)
            logger.info(f"Deleted S3 object: s3://{self.bucket_name}/{file_id}")
            return True
        except Exception as e:
            logger.warning(
                f"Failed to delete S3 object s3://{self.bucket_name}/{file_id}: {e}"
            )
            return False

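A configuration sketch (the bucket name, region, and large_video are placeholders; AWS credentials must already be available to boto3):

uploader = BedrockFileUploader(
    bucket_name="my-crewai-files",   # or set CREWAI_BEDROCK_S3_BUCKET
    region="us-east-1",
)
result = uploader.upload(large_video)
# result.file_uri is an s3:// URI suitable for a Converse API s3Location source.
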
@@ -1,196 +0,0 @@
"""Factory for creating file uploaders."""

from __future__ import annotations

import logging
from typing import Literal, TypeAlias, TypedDict, overload

from typing_extensions import NotRequired, Unpack

from crewai_files.uploaders.anthropic import AnthropicFileUploader
from crewai_files.uploaders.bedrock import BedrockFileUploader
from crewai_files.uploaders.gemini import GeminiFileUploader
from crewai_files.uploaders.openai import OpenAIFileUploader


logger = logging.getLogger(__name__)


FileUploaderType: TypeAlias = (
    GeminiFileUploader
    | AnthropicFileUploader
    | BedrockFileUploader
    | OpenAIFileUploader
)

GeminiProviderType = Literal["gemini", "google"]
AnthropicProviderType = Literal["anthropic", "claude"]
OpenAIProviderType = Literal["openai", "gpt", "azure"]
BedrockProviderType = Literal["bedrock", "aws"]

ProviderType: TypeAlias = (
    GeminiProviderType
    | AnthropicProviderType
    | OpenAIProviderType
    | BedrockProviderType
)


class _BaseOpts(TypedDict):
    """Kwargs for uploader factory."""

    api_key: NotRequired[str | None]


class OpenAIOpts(_BaseOpts):
    """Kwargs for openai uploader factory."""

    chunk_size: NotRequired[int]


class GeminiOpts(_BaseOpts):
    """Kwargs for gemini uploader factory."""


class AnthropicOpts(_BaseOpts):
    """Kwargs for anthropic uploader factory."""


class BedrockOpts(TypedDict):
    """Kwargs for bedrock uploader factory."""

    bucket_name: NotRequired[str | None]
    bucket_owner: NotRequired[str | None]
    prefix: NotRequired[str]
    region: NotRequired[str | None]


class AllOptions(TypedDict):
    """Kwargs for uploader factory."""

    api_key: NotRequired[str | None]
    chunk_size: NotRequired[int]
    bucket_name: NotRequired[str | None]
    bucket_owner: NotRequired[str | None]
    prefix: NotRequired[str]
    region: NotRequired[str | None]


@overload
def get_uploader(
    provider: GeminiProviderType,
    **kwargs: Unpack[GeminiOpts],
) -> GeminiFileUploader:
    """Get Gemini file uploader."""


@overload
def get_uploader(
    provider: AnthropicProviderType,
    **kwargs: Unpack[AnthropicOpts],
) -> AnthropicFileUploader:
    """Get Anthropic file uploader."""


@overload
def get_uploader(
    provider: OpenAIProviderType,
    **kwargs: Unpack[OpenAIOpts],
) -> OpenAIFileUploader:
    """Get OpenAI file uploader."""


@overload
def get_uploader(
    provider: BedrockProviderType,
    **kwargs: Unpack[BedrockOpts],
) -> BedrockFileUploader:
    """Get Bedrock file uploader."""


@overload
def get_uploader(
    provider: ProviderType, **kwargs: Unpack[AllOptions]
) -> FileUploaderType | None:
    """Get any file uploader."""


def get_uploader(
    provider: ProviderType, **kwargs: Unpack[AllOptions]
) -> FileUploaderType | None:
    """Get a file uploader for a specific provider.

    Args:
        provider: Provider name (e.g., "gemini", "anthropic").
        **kwargs: Additional arguments passed to the uploader constructor.

    Returns:
        FileUploader instance for the provider, or None if not supported.
    """
    provider_lower = provider.lower()

    if "gemini" in provider_lower or "google" in provider_lower:
        try:
            from crewai_files.uploaders.gemini import GeminiFileUploader

            return GeminiFileUploader(api_key=kwargs.get("api_key"))
        except ImportError:
            logger.warning(
                "google-genai not installed. Install with: pip install google-genai"
            )
            raise

    if "anthropic" in provider_lower or "claude" in provider_lower:
        try:
            from crewai_files.uploaders.anthropic import AnthropicFileUploader

            return AnthropicFileUploader(api_key=kwargs.get("api_key"))
        except ImportError:
            logger.warning(
                "anthropic not installed. Install with: pip install anthropic"
            )
            raise

    if (
        "openai" in provider_lower
        or "gpt" in provider_lower
        or "azure" in provider_lower
    ):
        try:
            from crewai_files.uploaders.openai import OpenAIFileUploader

            return OpenAIFileUploader(
                api_key=kwargs.get("api_key"),
                chunk_size=kwargs.get("chunk_size", 67_108_864),
            )
        except ImportError:
            logger.warning("openai not installed. Install with: pip install openai")
            raise

    if "bedrock" in provider_lower or "aws" in provider_lower:
        import os

        if (
            not os.environ.get("CREWAI_BEDROCK_S3_BUCKET")
            and "bucket_name" not in kwargs
        ):
            logger.debug(
                "Bedrock S3 uploader not configured. "
                "Set CREWAI_BEDROCK_S3_BUCKET environment variable to enable."
            )
            # No bucket configured; callers fall back to inline delivery.
            return None
        try:
            from crewai_files.uploaders.bedrock import BedrockFileUploader

            return BedrockFileUploader(
                bucket_name=kwargs.get("bucket_name"),
                bucket_owner=kwargs.get("bucket_owner"),
                prefix=kwargs.get("prefix", "crewai-files"),
                region=kwargs.get("region"),
            )
        except ImportError:
            logger.warning("boto3 not installed. Install with: pip install boto3")
            raise

    logger.debug(f"No file uploader available for provider: {provider}")
    return None

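Usage sketch for the factory (photo is a placeholder FileInput; the matching SDK must be installed, otherwise the ImportError propagates as shown above):

uploader = get_uploader("gemini")
if uploader is not None:
    result = uploader.upload(photo)
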
@@ -1,443 +0,0 @@
"""Gemini File API uploader implementation."""

from __future__ import annotations

import asyncio
from datetime import datetime, timezone
import io
import logging
import os
from pathlib import Path
import random
import time
from typing import Any

from crewai_files.core.constants import (
    BACKOFF_BASE_DELAY,
    BACKOFF_JITTER_FACTOR,
    BACKOFF_MAX_DELAY,
    GEMINI_FILE_TTL,
)
from crewai_files.core.sources import FilePath
from crewai_files.core.types import FileInput
from crewai_files.processing.exceptions import (
    PermanentUploadError,
    TransientUploadError,
    classify_upload_error,
)
from crewai_files.uploaders.base import FileUploader, UploadResult


logger = logging.getLogger(__name__)


def _compute_backoff_delay(attempt: int) -> float:
    """Compute exponential backoff delay with jitter.

    Args:
        attempt: The current attempt number (0-indexed).

    Returns:
        Delay in seconds with jitter applied.
    """
    delay: float = min(BACKOFF_BASE_DELAY * (2**attempt), BACKOFF_MAX_DELAY)
    jitter: float = random.uniform(0, delay * BACKOFF_JITTER_FACTOR)  # noqa: S311
    return float(delay + jitter)

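A worked example of the schedule, assuming illustrative constants BACKOFF_BASE_DELAY=1.0, BACKOFF_MAX_DELAY=32.0, BACKOFF_JITTER_FACTOR=0.1 (the real values live in crewai_files.core.constants):

# attempt 0: min(1.0 * 2**0, 32.0) = 1.0s,  plus jitter in [0, 0.1]
# attempt 1: min(1.0 * 2**1, 32.0) = 2.0s,  plus jitter in [0, 0.2]
# attempt 5: min(1.0 * 2**5, 32.0) = 32.0s, plus jitter in [0, 3.2]
delays = [_compute_backoff_delay(a) for a in range(6)]
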
def _classify_gemini_error(e: Exception, filename: str | None) -> Exception:
    """Classify a Gemini exception as transient or permanent upload error.

    Checks Gemini-specific error message patterns first, then falls back
    to generic status code classification.

    Args:
        e: The exception to classify.
        filename: The filename for error context.

    Returns:
        A TransientUploadError or PermanentUploadError wrapping the original.
    """
    error_msg = str(e).lower()

    if "quota" in error_msg or "rate" in error_msg or "limit" in error_msg:
        return TransientUploadError(f"Rate limit error: {e}", file_name=filename)
    if "auth" in error_msg or "permission" in error_msg or "denied" in error_msg:
        return PermanentUploadError(
            f"Authentication/permission error: {e}", file_name=filename
        )
    if "invalid" in error_msg or "unsupported" in error_msg:
        return PermanentUploadError(f"Invalid request: {e}", file_name=filename)

    return classify_upload_error(e, filename)


def _get_file_path(file: FileInput) -> Path | None:
    """Get the filesystem path if file source is FilePath.

    Args:
        file: The file input to check.

    Returns:
        Path if source is FilePath, None otherwise.
    """
    source = file._file_source
    if isinstance(source, FilePath):
        return source.path
    return None


class GeminiFileUploader(FileUploader):
    """Uploader for Google Gemini File API.

    Uses the google-genai SDK to upload files. Files are stored for 48 hours.
    """

    def __init__(self, api_key: str | None = None) -> None:
        """Initialize the Gemini uploader.

        Args:
            api_key: Optional Google API key. If not provided, uses
                GOOGLE_API_KEY environment variable.
        """
        self._api_key = api_key or os.environ.get("GOOGLE_API_KEY")
        self._client: Any = None

    @property
    def provider_name(self) -> str:
        """Return the provider name."""
        return "gemini"

    def _get_client(self) -> Any:
        """Get or create the Gemini client."""
        if self._client is None:
            try:
                from google import genai

                self._client = genai.Client(api_key=self._api_key)
            except ImportError as e:
                raise ImportError(
                    "google-genai is required for Gemini file uploads. "
                    "Install with: pip install google-genai"
                ) from e
        return self._client

    def upload(self, file: FileInput, purpose: str | None = None) -> UploadResult:
        """Upload a file to Gemini.

        For FilePath sources, passes the path directly to the SDK which handles
        streaming internally via resumable uploads, avoiding memory overhead.

        Args:
            file: The file to upload.
            purpose: Optional purpose/description (used as display name).

        Returns:
            UploadResult with the file URI and metadata.

        Raises:
            TransientUploadError: For retryable errors (network, rate limits).
            PermanentUploadError: For non-retryable errors (auth, validation).
        """
        try:
            client = self._get_client()
            display_name = purpose or file.filename

            file_path = _get_file_path(file)
            if file_path is not None:
                file_size = file_path.stat().st_size
                logger.info(
                    f"Uploading file '{file.filename}' to Gemini via path "
                    f"({file_size} bytes, streaming)"
                )
                uploaded_file = client.files.upload(
                    file=file_path,
                    config={
                        "display_name": display_name,
                        "mime_type": file.content_type,
                    },
                )
            else:
                content = file.read()
                file_data = io.BytesIO(content)
                file_data.name = file.filename

                logger.info(
                    f"Uploading file '{file.filename}' to Gemini ({len(content)} bytes)"
                )

                uploaded_file = client.files.upload(
                    file=file_data,
                    config={
                        "display_name": display_name,
                        "mime_type": file.content_type,
                    },
                )

            if file.content_type.startswith("video/"):
                if not self.wait_for_processing(uploaded_file.name):
                    raise PermanentUploadError(
                        f"Video processing failed for {file.filename}",
                        file_name=file.filename,
                    )

            expires_at = datetime.now(timezone.utc) + GEMINI_FILE_TTL

            logger.info(
                f"Uploaded to Gemini: {uploaded_file.name} (URI: {uploaded_file.uri})"
            )

            return UploadResult(
                file_id=uploaded_file.name,
                file_uri=uploaded_file.uri,
                content_type=file.content_type,
                expires_at=expires_at,
                provider=self.provider_name,
            )
        except ImportError:
            raise
        except (TransientUploadError, PermanentUploadError):
            raise
        except Exception as e:
            raise _classify_gemini_error(e, file.filename) from e

    async def aupload(
        self, file: FileInput, purpose: str | None = None
    ) -> UploadResult:
        """Async upload a file to Gemini using native async client.

        For FilePath sources, passes the path directly to the SDK which handles
        streaming internally via resumable uploads, avoiding memory overhead.

        Args:
            file: The file to upload.
            purpose: Optional purpose/description (used as display name).

        Returns:
            UploadResult with the file URI and metadata.

        Raises:
            TransientUploadError: For retryable errors (network, rate limits).
            PermanentUploadError: For non-retryable errors (auth, validation).
        """
        try:
            client = self._get_client()
            display_name = purpose or file.filename

            file_path = _get_file_path(file)
            if file_path is not None:
                file_size = file_path.stat().st_size
                logger.info(
                    f"Uploading file '{file.filename}' to Gemini via path "
                    f"({file_size} bytes, streaming)"
                )
                uploaded_file = await client.aio.files.upload(
                    file=file_path,
                    config={
                        "display_name": display_name,
                        "mime_type": file.content_type,
                    },
                )
            else:
                content = await file.aread()
                file_data = io.BytesIO(content)
                file_data.name = file.filename

                logger.info(
                    f"Uploading file '{file.filename}' to Gemini ({len(content)} bytes)"
                )

                uploaded_file = await client.aio.files.upload(
                    file=file_data,
                    config={
                        "display_name": display_name,
                        "mime_type": file.content_type,
                    },
                )

            if file.content_type.startswith("video/"):
                if not await self.await_for_processing(uploaded_file.name):
                    raise PermanentUploadError(
                        f"Video processing failed for {file.filename}",
                        file_name=file.filename,
                    )

            expires_at = datetime.now(timezone.utc) + GEMINI_FILE_TTL

            logger.info(
                f"Uploaded to Gemini: {uploaded_file.name} (URI: {uploaded_file.uri})"
            )

            return UploadResult(
                file_id=uploaded_file.name,
                file_uri=uploaded_file.uri,
                content_type=file.content_type,
                expires_at=expires_at,
                provider=self.provider_name,
            )
        except ImportError:
            raise
        except (TransientUploadError, PermanentUploadError):
            raise
        except Exception as e:
            raise _classify_gemini_error(e, file.filename) from e

    def delete(self, file_id: str) -> bool:
        """Delete an uploaded file from Gemini.

        Args:
            file_id: The file name/ID to delete.

        Returns:
            True if deletion was successful, False otherwise.
        """
        try:
            client = self._get_client()
|
||||
client.files.delete(name=file_id)
|
||||
logger.info(f"Deleted Gemini file: {file_id}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete Gemini file {file_id}: {e}")
|
||||
return False
|
||||
|
||||
async def adelete(self, file_id: str) -> bool:
|
||||
"""Async delete an uploaded file from Gemini.
|
||||
|
||||
Args:
|
||||
file_id: The file name/ID to delete.
|
||||
|
||||
Returns:
|
||||
True if deletion was successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
client = self._get_client()
|
||||
await client.aio.files.delete(name=file_id)
|
||||
logger.info(f"Deleted Gemini file: {file_id}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete Gemini file {file_id}: {e}")
|
||||
return False
|
||||
|
||||
def get_file_info(self, file_id: str) -> dict[str, Any] | None:
|
||||
"""Get information about an uploaded file.
|
||||
|
||||
Args:
|
||||
file_id: The file name/ID.
|
||||
|
||||
Returns:
|
||||
Dictionary with file information, or None if not found.
|
||||
"""
|
||||
try:
|
||||
client = self._get_client()
|
||||
file_info = client.files.get(name=file_id)
|
||||
return {
|
||||
"name": file_info.name,
|
||||
"uri": file_info.uri,
|
||||
"display_name": file_info.display_name,
|
||||
"mime_type": file_info.mime_type,
|
||||
"size_bytes": file_info.size_bytes,
|
||||
"state": str(file_info.state),
|
||||
"create_time": file_info.create_time,
|
||||
"expiration_time": file_info.expiration_time,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to get Gemini file info for {file_id}: {e}")
|
||||
return None
|
||||
|
||||
def list_files(self) -> list[dict[str, Any]]:
|
||||
"""List all uploaded files.
|
||||
|
||||
Returns:
|
||||
List of dictionaries with file information.
|
||||
"""
|
||||
try:
|
||||
client = self._get_client()
|
||||
files = client.files.list()
|
||||
return [
|
||||
{
|
||||
"name": f.name,
|
||||
"uri": f.uri,
|
||||
"display_name": f.display_name,
|
||||
"mime_type": f.mime_type,
|
||||
"size_bytes": f.size_bytes,
|
||||
"state": str(f.state),
|
||||
}
|
||||
for f in files
|
||||
]
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to list Gemini files: {e}")
|
||||
return []
|
||||
|
||||
def wait_for_processing(self, file_id: str, timeout_seconds: int = 300) -> bool:
|
||||
"""Wait for a file to finish processing with exponential backoff.
|
||||
|
||||
Some files (especially videos) need time to process after upload.
|
||||
|
||||
Args:
|
||||
file_id: The file name/ID.
|
||||
timeout_seconds: Maximum time to wait.
|
||||
|
||||
Returns:
|
||||
True if processing completed, False if timed out or failed.
|
||||
"""
|
||||
try:
|
||||
from google.genai.types import FileState
|
||||
except ImportError:
|
||||
return True
|
||||
|
||||
client = self._get_client()
|
||||
start_time = time.time()
|
||||
attempt = 0
|
||||
|
||||
while time.time() - start_time < timeout_seconds:
|
||||
file_info = client.files.get(name=file_id)
|
||||
|
||||
if file_info.state == FileState.ACTIVE:
|
||||
return True
|
||||
if file_info.state == FileState.FAILED:
|
||||
logger.error(f"Gemini file processing failed: {file_id}")
|
||||
return False
|
||||
|
||||
time.sleep(_compute_backoff_delay(attempt))
|
||||
attempt += 1
|
||||
|
||||
logger.warning(f"Timed out waiting for Gemini file processing: {file_id}")
|
||||
return False
|
||||
|
||||
async def await_for_processing(
|
||||
self, file_id: str, timeout_seconds: int = 300
|
||||
) -> bool:
|
||||
"""Async wait for a file to finish processing with exponential backoff.
|
||||
|
||||
Some files (especially videos) need time to process after upload.
|
||||
|
||||
Args:
|
||||
file_id: The file name/ID.
|
||||
timeout_seconds: Maximum time to wait.
|
||||
|
||||
Returns:
|
||||
True if processing completed, False if timed out or failed.
|
||||
"""
|
||||
try:
|
||||
from google.genai.types import FileState
|
||||
except ImportError:
|
||||
return True
|
||||
|
||||
client = self._get_client()
|
||||
start_time = time.time()
|
||||
attempt = 0
|
||||
|
||||
while time.time() - start_time < timeout_seconds:
|
||||
file_info = await client.aio.files.get(name=file_id)
|
||||
|
||||
if file_info.state == FileState.ACTIVE:
|
||||
return True
|
||||
if file_info.state == FileState.FAILED:
|
||||
logger.error(f"Gemini file processing failed: {file_id}")
|
||||
return False
|
||||
|
||||
await asyncio.sleep(_compute_backoff_delay(attempt))
|
||||
attempt += 1
|
||||
|
||||
logger.warning(f"Timed out waiting for Gemini file processing: {file_id}")
|
||||
return False
|
||||
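# --- Editor's illustration (hedged), not part of the original file ---
# A minimal sketch of driving GeminiFileUploader end to end, using the
# FileBytes/ImageFile constructor shape exercised in the test files below.
# The uploader's module path and the payload bytes are assumptions.
from crewai_files import FileBytes, ImageFile
from crewai_files.uploaders.gemini import GeminiFileUploader  # assumed module path

uploader = GeminiFileUploader()  # falls back to GOOGLE_API_KEY in the environment
image = ImageFile(source=FileBytes(data=b"<real png bytes>", filename="diagram.png"))

result = uploader.upload(image, purpose="architecture diagram")
print(result.file_uri, result.expires_at)  # Gemini retains uploads for ~48 hours

uploader.delete(result.file_id)  # delete early rather than waiting out the TTL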
@@ -1,669 +0,0 @@
"""OpenAI Files API uploader implementation."""

from __future__ import annotations

from collections.abc import AsyncIterator, Iterator
import io
import logging
import os
from typing import Any

from crewai_files.core.constants import DEFAULT_UPLOAD_CHUNK_SIZE, FILES_API_MAX_SIZE
from crewai_files.core.sources import FileBytes, FilePath, FileStream
from crewai_files.core.types import FileInput
from crewai_files.processing.exceptions import (
    PermanentUploadError,
    TransientUploadError,
    classify_upload_error,
)
from crewai_files.uploaders.base import FileUploader, UploadResult


logger = logging.getLogger(__name__)


def _get_file_size(file: FileInput) -> int | None:
    """Get file size without reading content if possible.

    Args:
        file: The file to get size for.

    Returns:
        File size in bytes, or None if size cannot be determined without reading.
    """
    source = file._file_source
    if isinstance(source, FilePath):
        return source.path.stat().st_size
    if isinstance(source, FileBytes):
        return len(source.data)
    return None


def _iter_file_chunks(file: FileInput, chunk_size: int) -> Iterator[bytes]:
    """Iterate over file content in chunks.

    Args:
        file: The file to read.
        chunk_size: Size of each chunk in bytes.

    Yields:
        Chunks of file content.
    """
    source = file._file_source
    if isinstance(source, (FilePath, FileBytes, FileStream)):
        yield from source.read_chunks(chunk_size)
    else:
        content = file.read()
        for i in range(0, len(content), chunk_size):
            yield content[i : i + chunk_size]


async def _aiter_file_chunks(
    file: FileInput, chunk_size: int, content: bytes | None = None
) -> AsyncIterator[bytes]:
    """Async iterate over file content in chunks.

    Args:
        file: The file to read.
        chunk_size: Size of each chunk in bytes.
        content: Optional pre-loaded content to chunk.

    Yields:
        Chunks of file content.
    """
    if content is not None:
        for i in range(0, len(content), chunk_size):
            yield content[i : i + chunk_size]
        return

    source = file._file_source
    if isinstance(source, FilePath):
        async for chunk in source.aread_chunks(chunk_size):
            yield chunk
    elif isinstance(source, (FileBytes, FileStream)):
        for chunk in source.read_chunks(chunk_size):
            yield chunk
    else:
        data = await file.aread()
        for i in range(0, len(data), chunk_size):
            yield data[i : i + chunk_size]

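# --- Editor's illustration (hedged), not part of the original file ---
# The helpers above normalize every source into fixed-size byte windows. A quick
# sketch of the expected arithmetic, assuming read_chunks yields full chunks
# followed by a short tail:
from crewai_files import TextFile

_demo = TextFile(source=FileBytes(data=b"x" * 100, filename="x.txt"))
assert [len(c) for c in _iter_file_chunks(_demo, chunk_size=32)] == [32, 32, 32, 4]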
class OpenAIFileUploader(FileUploader):
    """Uploader for OpenAI Files and Uploads APIs.

    Uses the Files API for files up to 512MB (single request).
    Uses the Uploads API for files larger than 512MB (multipart chunked).
    """

    def __init__(
        self,
        api_key: str | None = None,
        chunk_size: int = DEFAULT_UPLOAD_CHUNK_SIZE,
    ) -> None:
        """Initialize the OpenAI uploader.

        Args:
            api_key: Optional OpenAI API key. If not provided, uses
                OPENAI_API_KEY environment variable.
            chunk_size: Chunk size in bytes for multipart uploads (default 64MB).
        """
        self._api_key = api_key or os.environ.get("OPENAI_API_KEY")
        self._chunk_size = chunk_size
        self._client: Any = None
        self._async_client: Any = None

    @property
    def provider_name(self) -> str:
        """Return the provider name."""
        return "openai"

    def _build_upload_result(self, file_id: str, content_type: str) -> UploadResult:
        """Build an UploadResult for a completed upload.

        Args:
            file_id: The uploaded file ID.
            content_type: The file's content type.

        Returns:
            UploadResult with the file metadata.
        """
        return UploadResult(
            file_id=file_id,
            file_uri=None,
            content_type=content_type,
            expires_at=None,
            provider=self.provider_name,
        )

    def _get_client(self) -> Any:
        """Get or create the OpenAI client."""
        if self._client is None:
            try:
                from openai import OpenAI

                self._client = OpenAI(api_key=self._api_key)
            except ImportError as e:
                raise ImportError(
                    "openai is required for OpenAI file uploads. "
                    "Install with: pip install openai"
                ) from e
        return self._client

    def _get_async_client(self) -> Any:
        """Get or create the async OpenAI client."""
        if self._async_client is None:
            try:
                from openai import AsyncOpenAI

                self._async_client = AsyncOpenAI(api_key=self._api_key)
            except ImportError as e:
                raise ImportError(
                    "openai is required for OpenAI file uploads. "
                    "Install with: pip install openai"
                ) from e
        return self._async_client

    def upload(self, file: FileInput, purpose: str | None = None) -> UploadResult:
        """Upload a file to OpenAI.

        Uses Files API for files <= 512MB, Uploads API for larger files.
        For large files, streams chunks to avoid loading entire file in memory.

        Args:
            file: The file to upload.
            purpose: Optional purpose for the file (default: "user_data").

        Returns:
            UploadResult with the file ID and metadata.

        Raises:
            TransientUploadError: For retryable errors (network, rate limits).
            PermanentUploadError: For non-retryable errors (auth, validation).
        """
        try:
            file_size = _get_file_size(file)

            if file_size is not None and file_size > FILES_API_MAX_SIZE:
                return self._upload_multipart_streaming(file, file_size, purpose)

            content = file.read()
            if len(content) > FILES_API_MAX_SIZE:
                return self._upload_multipart(file, content, purpose)
            return self._upload_simple(file, content, purpose)
        except ImportError:
            raise
        except (TransientUploadError, PermanentUploadError):
            raise
        except Exception as e:
            raise classify_upload_error(e, file.filename) from e

    def _upload_simple(
        self,
        file: FileInput,
        content: bytes,
        purpose: str | None,
    ) -> UploadResult:
        """Upload using the Files API (single request, up to 512MB).

        Args:
            file: The file to upload.
            content: File content bytes.
            purpose: Optional purpose for the file.

        Returns:
            UploadResult with the file ID and metadata.
        """
        client = self._get_client()
        file_purpose = purpose or "user_data"

        file_data = io.BytesIO(content)
        file_data.name = file.filename or "file"

        logger.info(
            f"Uploading file '{file.filename}' to OpenAI Files API ({len(content)} bytes)"
        )

        uploaded_file = client.files.create(
            file=file_data,
            purpose=file_purpose,
        )

        logger.info(f"Uploaded to OpenAI: {uploaded_file.id}")

        return self._build_upload_result(uploaded_file.id, file.content_type)

    def _upload_multipart(
        self,
        file: FileInput,
        content: bytes,
        purpose: str | None,
    ) -> UploadResult:
        """Upload using the Uploads API with content already in memory.

        Args:
            file: The file to upload.
            content: File content bytes (already loaded).
            purpose: Optional purpose for the file.

        Returns:
            UploadResult with the file ID and metadata.
        """
        client = self._get_client()
        file_purpose = purpose or "user_data"
        filename = file.filename or "file"
        file_size = len(content)

        logger.info(
            f"Uploading file '{filename}' to OpenAI Uploads API "
            f"({file_size} bytes, {self._chunk_size} byte chunks)"
        )

        upload = client.uploads.create(
            bytes=file_size,
            filename=filename,
            mime_type=file.content_type,
            purpose=file_purpose,
        )

        part_ids: list[str] = []
        offset = 0
        part_num = 1

        try:
            while offset < file_size:
                chunk = content[offset : offset + self._chunk_size]
                chunk_io = io.BytesIO(chunk)

                logger.debug(
                    f"Uploading part {part_num} ({len(chunk)} bytes, offset {offset})"
                )

                part = client.uploads.parts.create(
                    upload_id=upload.id,
                    data=chunk_io,
                )
                part_ids.append(part.id)

                offset += self._chunk_size
                part_num += 1

            completed = client.uploads.complete(
                upload_id=upload.id,
                part_ids=part_ids,
            )

            file_id = completed.file.id if completed.file else upload.id
            logger.info(f"Completed multipart upload to OpenAI: {file_id}")

            return self._build_upload_result(file_id, file.content_type)
        except Exception:
            logger.warning(f"Multipart upload failed, cancelling upload {upload.id}")
            try:
                client.uploads.cancel(upload_id=upload.id)
            except Exception as cancel_err:
                logger.debug(f"Failed to cancel upload: {cancel_err}")
            raise

    def _upload_multipart_streaming(
        self,
        file: FileInput,
        file_size: int,
        purpose: str | None,
    ) -> UploadResult:
        """Upload using the Uploads API with streaming chunks.

        Streams chunks directly from the file source without loading
        the entire file into memory. Used for large files.

        Args:
            file: The file to upload.
            file_size: Total file size in bytes.
            purpose: Optional purpose for the file.

        Returns:
            UploadResult with the file ID and metadata.
        """
        client = self._get_client()
        file_purpose = purpose or "user_data"
        filename = file.filename or "file"

        logger.info(
            f"Uploading file '{filename}' to OpenAI Uploads API (streaming) "
            f"({file_size} bytes, {self._chunk_size} byte chunks)"
        )

        upload = client.uploads.create(
            bytes=file_size,
            filename=filename,
            mime_type=file.content_type,
            purpose=file_purpose,
        )

        part_ids: list[str] = []
        part_num = 1

        try:
            for chunk in _iter_file_chunks(file, self._chunk_size):
                chunk_io = io.BytesIO(chunk)

                logger.debug(f"Uploading part {part_num} ({len(chunk)} bytes)")

                part = client.uploads.parts.create(
                    upload_id=upload.id,
                    data=chunk_io,
                )
                part_ids.append(part.id)
                part_num += 1

            completed = client.uploads.complete(
                upload_id=upload.id,
                part_ids=part_ids,
            )

            file_id = completed.file.id if completed.file else upload.id
            logger.info(f"Completed streaming multipart upload to OpenAI: {file_id}")

            return self._build_upload_result(file_id, file.content_type)
        except Exception:
            logger.warning(f"Multipart upload failed, cancelling upload {upload.id}")
            try:
                client.uploads.cancel(upload_id=upload.id)
            except Exception as cancel_err:
                logger.debug(f"Failed to cancel upload: {cancel_err}")
            raise

    def delete(self, file_id: str) -> bool:
        """Delete an uploaded file from OpenAI.

        Args:
            file_id: The file ID to delete.

        Returns:
            True if deletion was successful, False otherwise.
        """
        try:
            client = self._get_client()
            client.files.delete(file_id)
            logger.info(f"Deleted OpenAI file: {file_id}")
            return True
        except Exception as e:
            logger.warning(f"Failed to delete OpenAI file {file_id}: {e}")
            return False

    def get_file_info(self, file_id: str) -> dict[str, Any] | None:
        """Get information about an uploaded file.

        Args:
            file_id: The file ID.

        Returns:
            Dictionary with file information, or None if not found.
        """
        try:
            client = self._get_client()
            file_info = client.files.retrieve(file_id)
            return {
                "id": file_info.id,
                "filename": file_info.filename,
                "purpose": file_info.purpose,
                "bytes": file_info.bytes,
                "created_at": file_info.created_at,
                "status": file_info.status,
            }
        except Exception as e:
            logger.debug(f"Failed to get OpenAI file info for {file_id}: {e}")
            return None

    def list_files(self) -> list[dict[str, Any]]:
        """List all uploaded files.

        Returns:
            List of dictionaries with file information.
        """
        try:
            client = self._get_client()
            files = client.files.list()
            return [
                {
                    "id": f.id,
                    "filename": f.filename,
                    "purpose": f.purpose,
                    "bytes": f.bytes,
                    "created_at": f.created_at,
                    "status": f.status,
                }
                for f in files.data
            ]
        except Exception as e:
            logger.warning(f"Failed to list OpenAI files: {e}")
            return []

    async def aupload(
        self, file: FileInput, purpose: str | None = None
    ) -> UploadResult:
        """Async upload a file to OpenAI using native async client.

        Uses Files API for files <= 512MB, Uploads API for larger files.
        For large files, streams chunks to avoid loading entire file in memory.

        Args:
            file: The file to upload.
            purpose: Optional purpose for the file (default: "user_data").

        Returns:
            UploadResult with the file ID and metadata.

        Raises:
            TransientUploadError: For retryable errors (network, rate limits).
            PermanentUploadError: For non-retryable errors (auth, validation).
        """
        try:
            file_size = _get_file_size(file)

            if file_size is not None and file_size > FILES_API_MAX_SIZE:
                return await self._aupload_multipart_streaming(file, file_size, purpose)

            content = await file.aread()
            if len(content) > FILES_API_MAX_SIZE:
                return await self._aupload_multipart(file, content, purpose)
            return await self._aupload_simple(file, content, purpose)
        except ImportError:
            raise
        except (TransientUploadError, PermanentUploadError):
            raise
        except Exception as e:
            raise classify_upload_error(e, file.filename) from e

    async def _aupload_simple(
        self,
        file: FileInput,
        content: bytes,
        purpose: str | None,
    ) -> UploadResult:
        """Async upload using the Files API (single request, up to 512MB).

        Args:
            file: The file to upload.
            content: File content bytes.
            purpose: Optional purpose for the file.

        Returns:
            UploadResult with the file ID and metadata.
        """
        client = self._get_async_client()
        file_purpose = purpose or "user_data"

        file_data = io.BytesIO(content)
        file_data.name = file.filename or "file"

        logger.info(
            f"Uploading file '{file.filename}' to OpenAI Files API ({len(content)} bytes)"
        )

        uploaded_file = await client.files.create(
            file=file_data,
            purpose=file_purpose,
        )

        logger.info(f"Uploaded to OpenAI: {uploaded_file.id}")

        return self._build_upload_result(uploaded_file.id, file.content_type)

    async def _aupload_multipart(
        self,
        file: FileInput,
        content: bytes,
        purpose: str | None,
    ) -> UploadResult:
        """Async upload using the Uploads API (multipart chunked, up to 8GB).

        Args:
            file: The file to upload.
            content: File content bytes.
            purpose: Optional purpose for the file.

        Returns:
            UploadResult with the file ID and metadata.
        """
        client = self._get_async_client()
        file_purpose = purpose or "user_data"
        filename = file.filename or "file"
        file_size = len(content)

        logger.info(
            f"Uploading file '{filename}' to OpenAI Uploads API "
            f"({file_size} bytes, {self._chunk_size} byte chunks)"
        )

        upload = await client.uploads.create(
            bytes=file_size,
            filename=filename,
            mime_type=file.content_type,
            purpose=file_purpose,
        )

        part_ids: list[str] = []
        offset = 0
        part_num = 1

        try:
            while offset < file_size:
                chunk = content[offset : offset + self._chunk_size]
                chunk_io = io.BytesIO(chunk)

                logger.debug(
                    f"Uploading part {part_num} ({len(chunk)} bytes, offset {offset})"
                )

                part = await client.uploads.parts.create(
                    upload_id=upload.id,
                    data=chunk_io,
                )
                part_ids.append(part.id)

                offset += self._chunk_size
                part_num += 1

            completed = await client.uploads.complete(
                upload_id=upload.id,
                part_ids=part_ids,
            )

            file_id = completed.file.id if completed.file else upload.id
            logger.info(f"Completed multipart upload to OpenAI: {file_id}")

            return self._build_upload_result(file_id, file.content_type)
        except Exception:
            logger.warning(f"Multipart upload failed, cancelling upload {upload.id}")
            try:
                await client.uploads.cancel(upload_id=upload.id)
            except Exception as cancel_err:
                logger.debug(f"Failed to cancel upload: {cancel_err}")
            raise

    async def _aupload_multipart_streaming(
        self,
        file: FileInput,
        file_size: int,
        purpose: str | None,
    ) -> UploadResult:
        """Async upload using the Uploads API with streaming chunks.

        Streams chunks directly from the file source without loading
        the entire file into memory. Used for large files.

        Args:
            file: The file to upload.
            file_size: Total file size in bytes.
            purpose: Optional purpose for the file.

        Returns:
            UploadResult with the file ID and metadata.
        """
        client = self._get_async_client()
        file_purpose = purpose or "user_data"
        filename = file.filename or "file"

        logger.info(
            f"Uploading file '{filename}' to OpenAI Uploads API (streaming) "
            f"({file_size} bytes, {self._chunk_size} byte chunks)"
        )

        upload = await client.uploads.create(
            bytes=file_size,
            filename=filename,
            mime_type=file.content_type,
            purpose=file_purpose,
        )

        part_ids: list[str] = []
        part_num = 1

        try:
            async for chunk in _aiter_file_chunks(file, self._chunk_size):
                chunk_io = io.BytesIO(chunk)

                logger.debug(f"Uploading part {part_num} ({len(chunk)} bytes)")

                part = await client.uploads.parts.create(
                    upload_id=upload.id,
                    data=chunk_io,
                )
                part_ids.append(part.id)
                part_num += 1

            completed = await client.uploads.complete(
                upload_id=upload.id,
                part_ids=part_ids,
            )

            file_id = completed.file.id if completed.file else upload.id
            logger.info(f"Completed streaming multipart upload to OpenAI: {file_id}")

            return self._build_upload_result(file_id, file.content_type)
        except Exception:
            logger.warning(f"Multipart upload failed, cancelling upload {upload.id}")
            try:
                await client.uploads.cancel(upload_id=upload.id)
            except Exception as cancel_err:
                logger.debug(f"Failed to cancel upload: {cancel_err}")
            raise

    async def adelete(self, file_id: str) -> bool:
        """Async delete an uploaded file from OpenAI.

        Args:
            file_id: The file ID to delete.

        Returns:
            True if deletion was successful, False otherwise.
        """
        try:
            client = self._get_async_client()
            await client.files.delete(file_id)
            logger.info(f"Deleted OpenAI file: {file_id}")
            return True
        except Exception as e:
            logger.warning(f"Failed to delete OpenAI file {file_id}: {e}")
            return False
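# --- Editor's illustration (hedged), not part of the original file ---
# Sketch of the size-based routing upload() implements: payloads at or under
# FILES_API_MAX_SIZE (512MB) go through the single-request Files API, larger
# ones through the chunked Uploads API. Module path and payload are assumptions.
from crewai_files import FileBytes, PDFFile
from crewai_files.uploaders.openai import OpenAIFileUploader  # assumed module path

uploader = OpenAIFileUploader(chunk_size=8 * 1024 * 1024)  # 8MB parts for large files
small = PDFFile(source=FileBytes(data=b"%PDF-1.4 placeholder", filename="notes.pdf"))
result = uploader.upload(small)  # small payload -> _upload_simple path
print(result.file_id, result.provider)  # e.g. "file-abc123", "openai"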
@@ -1,225 +0,0 @@
"""Tests for provider constraints."""

from crewai_files.processing.constraints import (
    ANTHROPIC_CONSTRAINTS,
    BEDROCK_CONSTRAINTS,
    GEMINI_CONSTRAINTS,
    OPENAI_CONSTRAINTS,
    AudioConstraints,
    ImageConstraints,
    PDFConstraints,
    ProviderConstraints,
    VideoConstraints,
    get_constraints_for_provider,
)
import pytest


class TestImageConstraints:
    """Tests for ImageConstraints dataclass."""

    def test_image_constraints_creation(self):
        """Test creating image constraints with all fields."""
        constraints = ImageConstraints(
            max_size_bytes=5 * 1024 * 1024,
            max_width=8000,
            max_height=8000,
            max_images_per_request=10,
        )

        assert constraints.max_size_bytes == 5 * 1024 * 1024
        assert constraints.max_width == 8000
        assert constraints.max_height == 8000
        assert constraints.max_images_per_request == 10

    def test_image_constraints_defaults(self):
        """Test image constraints with default values."""
        constraints = ImageConstraints(max_size_bytes=1000)

        assert constraints.max_size_bytes == 1000
        assert constraints.max_width is None
        assert constraints.max_height is None
        assert constraints.max_images_per_request is None
        assert "image/png" in constraints.supported_formats

    def test_image_constraints_frozen(self):
        """Test that image constraints are immutable."""
        constraints = ImageConstraints(max_size_bytes=1000)

        with pytest.raises(Exception):
            constraints.max_size_bytes = 2000


class TestPDFConstraints:
    """Tests for PDFConstraints dataclass."""

    def test_pdf_constraints_creation(self):
        """Test creating PDF constraints."""
        constraints = PDFConstraints(
            max_size_bytes=30 * 1024 * 1024,
            max_pages=100,
        )

        assert constraints.max_size_bytes == 30 * 1024 * 1024
        assert constraints.max_pages == 100

    def test_pdf_constraints_defaults(self):
        """Test PDF constraints with default values."""
        constraints = PDFConstraints(max_size_bytes=1000)

        assert constraints.max_size_bytes == 1000
        assert constraints.max_pages is None


class TestAudioConstraints:
    """Tests for AudioConstraints dataclass."""

    def test_audio_constraints_creation(self):
        """Test creating audio constraints."""
        constraints = AudioConstraints(
            max_size_bytes=100 * 1024 * 1024,
            max_duration_seconds=3600,
        )

        assert constraints.max_size_bytes == 100 * 1024 * 1024
        assert constraints.max_duration_seconds == 3600
        assert "audio/mp3" in constraints.supported_formats


class TestVideoConstraints:
    """Tests for VideoConstraints dataclass."""

    def test_video_constraints_creation(self):
        """Test creating video constraints."""
        constraints = VideoConstraints(
            max_size_bytes=2 * 1024 * 1024 * 1024,
            max_duration_seconds=7200,
        )

        assert constraints.max_size_bytes == 2 * 1024 * 1024 * 1024
        assert constraints.max_duration_seconds == 7200
        assert "video/mp4" in constraints.supported_formats


class TestProviderConstraints:
    """Tests for ProviderConstraints dataclass."""

    def test_provider_constraints_creation(self):
        """Test creating full provider constraints."""
        constraints = ProviderConstraints(
            name="test-provider",
            image=ImageConstraints(max_size_bytes=5 * 1024 * 1024),
            pdf=PDFConstraints(max_size_bytes=30 * 1024 * 1024),
            supports_file_upload=True,
            file_upload_threshold_bytes=10 * 1024 * 1024,
        )

        assert constraints.name == "test-provider"
        assert constraints.image is not None
        assert constraints.pdf is not None
        assert constraints.supports_file_upload is True

    def test_provider_constraints_defaults(self):
        """Test provider constraints with default values."""
        constraints = ProviderConstraints(name="test")

        assert constraints.name == "test"
        assert constraints.image is None
        assert constraints.pdf is None
        assert constraints.audio is None
        assert constraints.video is None
        assert constraints.supports_file_upload is False


class TestPredefinedConstraints:
    """Tests for predefined provider constraints."""

    def test_anthropic_constraints(self):
        """Test Anthropic constraints are properly defined."""
        assert ANTHROPIC_CONSTRAINTS.name == "anthropic"
        assert ANTHROPIC_CONSTRAINTS.image is not None
        assert ANTHROPIC_CONSTRAINTS.image.max_size_bytes == 5 * 1024 * 1024
        assert ANTHROPIC_CONSTRAINTS.image.max_width == 8000
        assert ANTHROPIC_CONSTRAINTS.pdf is not None
        assert ANTHROPIC_CONSTRAINTS.pdf.max_pages == 100
        assert ANTHROPIC_CONSTRAINTS.supports_file_upload is True

    def test_openai_constraints(self):
        """Test OpenAI constraints are properly defined."""
        assert OPENAI_CONSTRAINTS.name == "openai"
        assert OPENAI_CONSTRAINTS.image is not None
        assert OPENAI_CONSTRAINTS.image.max_size_bytes == 20 * 1024 * 1024
        assert OPENAI_CONSTRAINTS.pdf is None  # OpenAI doesn't support PDFs

    def test_gemini_constraints(self):
        """Test Gemini constraints are properly defined."""
        assert GEMINI_CONSTRAINTS.name == "gemini"
        assert GEMINI_CONSTRAINTS.image is not None
        assert GEMINI_CONSTRAINTS.pdf is not None
        assert GEMINI_CONSTRAINTS.audio is not None
        assert GEMINI_CONSTRAINTS.video is not None
        assert GEMINI_CONSTRAINTS.supports_file_upload is True

    def test_bedrock_constraints(self):
        """Test Bedrock constraints are properly defined."""
        assert BEDROCK_CONSTRAINTS.name == "bedrock"
        assert BEDROCK_CONSTRAINTS.image is not None
        assert BEDROCK_CONSTRAINTS.image.max_size_bytes == 4_608_000
        assert BEDROCK_CONSTRAINTS.pdf is not None
        assert BEDROCK_CONSTRAINTS.supports_file_upload is False


class TestGetConstraintsForProvider:
    """Tests for get_constraints_for_provider function."""

    def test_get_by_exact_name(self):
        """Test getting constraints by exact provider name."""
        result = get_constraints_for_provider("anthropic")
        assert result == ANTHROPIC_CONSTRAINTS

        result = get_constraints_for_provider("openai")
        assert result == OPENAI_CONSTRAINTS

        result = get_constraints_for_provider("gemini")
        assert result == GEMINI_CONSTRAINTS

    def test_get_by_alias(self):
        """Test getting constraints by alias name."""
        result = get_constraints_for_provider("claude")
        assert result == ANTHROPIC_CONSTRAINTS

        result = get_constraints_for_provider("gpt")
        assert result == OPENAI_CONSTRAINTS

        result = get_constraints_for_provider("google")
        assert result == GEMINI_CONSTRAINTS

    def test_get_case_insensitive(self):
        """Test case-insensitive lookup."""
        result = get_constraints_for_provider("ANTHROPIC")
        assert result == ANTHROPIC_CONSTRAINTS

        result = get_constraints_for_provider("OpenAI")
        assert result == OPENAI_CONSTRAINTS

    def test_get_with_provider_constraints_object(self):
        """Test passing ProviderConstraints object returns it unchanged."""
        custom = ProviderConstraints(name="custom")
        result = get_constraints_for_provider(custom)
        assert result is custom

    def test_get_unknown_provider(self):
        """Test unknown provider returns None."""
        result = get_constraints_for_provider("unknown-provider")
        assert result is None

    def test_get_by_partial_match(self):
        """Test partial match in provider string."""
        result = get_constraints_for_provider("claude-3-sonnet")
        assert result == ANTHROPIC_CONSTRAINTS

        result = get_constraints_for_provider("gpt-4o")
        assert result == OPENAI_CONSTRAINTS

        result = get_constraints_for_provider("gemini-pro")
        assert result == GEMINI_CONSTRAINTS
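# --- Editor's illustration (hedged), not part of the original file ---
# Condenses what the lookup tests above verify: resolution is case-insensitive
# and substring-based over model names, unknown providers yield None, and a
# ProviderConstraints instance passes through untouched.
from crewai_files.processing.constraints import (
    ImageConstraints,
    ProviderConstraints,
    get_constraints_for_provider,
)

constraints = get_constraints_for_provider("claude-3-sonnet")
assert constraints is not None and constraints.name == "anthropic"

# Unknown providers return None, so supply explicit limits for a local model.
local = ProviderConstraints(
    name="local-vlm",  # hypothetical provider name
    image=ImageConstraints(max_size_bytes=2 * 1024 * 1024),
)
assert get_constraints_for_provider(local) is local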
@@ -1,303 +0,0 @@
"""Tests for FileProcessor class."""

from crewai_files import FileBytes, ImageFile
from crewai_files.processing.constraints import (
    ANTHROPIC_CONSTRAINTS,
    ImageConstraints,
    ProviderConstraints,
)
from crewai_files.processing.enums import FileHandling
from crewai_files.processing.exceptions import (
    FileTooLargeError,
)
from crewai_files.processing.processor import FileProcessor
import pytest


# Minimal valid PNG: 8x8 pixel RGB image (valid for PIL)
MINIMAL_PNG = bytes(
    [
        # PNG signature
        0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A,
        # IHDR chunk: 8x8, bit depth 8, color type 2 (RGB)
        0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52,
        0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08,
        0x08, 0x02, 0x00, 0x00, 0x00, 0x4B, 0x6D, 0x29,
        0xDC,
        # IDAT chunk (18 bytes of zlib-compressed pixel data)
        0x00, 0x00, 0x00, 0x12, 0x49, 0x44, 0x41, 0x54,
        0x78, 0x9C, 0x63, 0xFC, 0xCF, 0x80, 0x1D, 0x30,
        0xE1, 0x10, 0x1F, 0xA4, 0x12, 0x00, 0xCD, 0x41,
        0x01, 0x0F, 0xE8, 0x41, 0xE2, 0x6F,
        # IEND chunk
        0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4E, 0x44,
        0xAE, 0x42, 0x60, 0x82,
    ]
)

# Minimal valid PDF
MINIMAL_PDF = (
    b"%PDF-1.4\n1 0 obj<</Type/Catalog/Pages 2 0 R>>endobj "
    b"2 0 obj<</Type/Pages/Kids[3 0 R]/Count 1>>endobj "
    b"3 0 obj<</Type/Page/MediaBox[0 0 612 792]/Parent 2 0 R>>endobj "
    b"xref\n0 4\n0000000000 65535 f \n0000000009 00000 n \n"
    b"0000000052 00000 n \n0000000101 00000 n \n"
    b"trailer<</Size 4/Root 1 0 R>>\nstartxref\n178\n%%EOF"
)


class TestFileProcessorInit:
    """Tests for FileProcessor initialization."""

    def test_init_with_constraints(self):
        """Test initialization with ProviderConstraints."""
        processor = FileProcessor(constraints=ANTHROPIC_CONSTRAINTS)

        assert processor.constraints == ANTHROPIC_CONSTRAINTS

    def test_init_with_provider_string(self):
        """Test initialization with provider name string."""
        processor = FileProcessor(constraints="anthropic")

        assert processor.constraints == ANTHROPIC_CONSTRAINTS

    def test_init_with_unknown_provider(self):
        """Test initialization with unknown provider sets constraints to None."""
        processor = FileProcessor(constraints="unknown")

        assert processor.constraints is None

    def test_init_with_none_constraints(self):
        """Test initialization with None constraints."""
        processor = FileProcessor(constraints=None)

        assert processor.constraints is None


class TestFileProcessorValidate:
    """Tests for FileProcessor.validate method."""

    def test_validate_valid_file(self):
        """Test validating a valid file returns no errors."""
        processor = FileProcessor(constraints=ANTHROPIC_CONSTRAINTS)
        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))

        errors = processor.validate(file)

        assert len(errors) == 0

    def test_validate_without_constraints(self):
        """Test validating without constraints returns empty list."""
        processor = FileProcessor(constraints=None)
        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))

        errors = processor.validate(file)

        assert len(errors) == 0

    def test_validate_strict_raises_on_error(self):
        """Test STRICT mode raises on validation error."""
        constraints = ProviderConstraints(
            name="test",
            image=ImageConstraints(max_size_bytes=10),
        )
        processor = FileProcessor(constraints=constraints)
        # Set mode to strict on the file
        file = ImageFile(
            source=FileBytes(data=MINIMAL_PNG, filename="test.png"), mode="strict"
        )

        with pytest.raises(FileTooLargeError):
            processor.validate(file)


class TestFileProcessorProcess:
    """Tests for FileProcessor.process method."""

    def test_process_valid_file(self):
        """Test processing a valid file returns it unchanged."""
        processor = FileProcessor(constraints=ANTHROPIC_CONSTRAINTS)
        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))

        result = processor.process(file)

        assert result == file

    def test_process_without_constraints(self):
        """Test processing without constraints returns file unchanged."""
        processor = FileProcessor(constraints=None)
        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))

        result = processor.process(file)

        assert result == file

    def test_process_strict_raises_on_error(self):
        """Test STRICT mode raises on processing error."""
        constraints = ProviderConstraints(
            name="test",
            image=ImageConstraints(max_size_bytes=10),
        )
        processor = FileProcessor(constraints=constraints)
        # Set mode to strict on the file
        file = ImageFile(
            source=FileBytes(data=MINIMAL_PNG, filename="test.png"), mode="strict"
        )

        with pytest.raises(FileTooLargeError):
            processor.process(file)

    def test_process_warn_returns_file(self):
        """Test WARN mode returns file with warning."""
        constraints = ProviderConstraints(
            name="test",
            image=ImageConstraints(max_size_bytes=10),
        )
        processor = FileProcessor(constraints=constraints)
        # Set mode to warn on the file
        file = ImageFile(
            source=FileBytes(data=MINIMAL_PNG, filename="test.png"), mode="warn"
        )

        result = processor.process(file)

        assert result == file


class TestFileProcessorProcessFiles:
    """Tests for FileProcessor.process_files method."""

    def test_process_files_multiple(self):
        """Test processing multiple files."""
        processor = FileProcessor(constraints=ANTHROPIC_CONSTRAINTS)
        files = {
            "image1": ImageFile(
                source=FileBytes(data=MINIMAL_PNG, filename="test1.png")
            ),
            "image2": ImageFile(
                source=FileBytes(data=MINIMAL_PNG, filename="test2.png")
            ),
        }

        result = processor.process_files(files)

        assert len(result) == 2
        assert "image1" in result
        assert "image2" in result

    def test_process_files_empty(self):
        """Test processing empty files dict."""
        processor = FileProcessor(constraints=ANTHROPIC_CONSTRAINTS)

        result = processor.process_files({})

        assert result == {}


class TestFileHandlingEnum:
    """Tests for FileHandling enum."""

    def test_enum_values(self):
        """Test all enum values are accessible."""
        assert FileHandling.STRICT.value == "strict"
        assert FileHandling.AUTO.value == "auto"
        assert FileHandling.WARN.value == "warn"
        assert FileHandling.CHUNK.value == "chunk"


class TestFileProcessorPerFileMode:
    """Tests for per-file mode handling."""

    def test_file_default_mode_is_auto(self):
        """Test that files default to auto mode."""
        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
        assert file.mode == "auto"

    def test_file_custom_mode(self):
        """Test setting custom mode on file."""
        file = ImageFile(
            source=FileBytes(data=MINIMAL_PNG, filename="test.png"), mode="strict"
        )
        assert file.mode == "strict"

    def test_processor_respects_file_mode(self):
        """Test processor uses each file's mode setting."""
        constraints = ProviderConstraints(
            name="test",
            image=ImageConstraints(max_size_bytes=10),
        )
        processor = FileProcessor(constraints=constraints)

        # File with strict mode should raise
        strict_file = ImageFile(
            source=FileBytes(data=MINIMAL_PNG, filename="test.png"), mode="strict"
        )
        with pytest.raises(FileTooLargeError):
            processor.process(strict_file)

        # File with warn mode should not raise
        warn_file = ImageFile(
            source=FileBytes(data=MINIMAL_PNG, filename="test.png"), mode="warn"
        )
        result = processor.process(warn_file)
        assert result == warn_file
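# --- Editor's illustration (hedged), not part of the original file ---
# Condenses the per-file mode behaviour pinned down above: the processor reads
# each file's own mode, so one oversized file can fail hard while another only
# logs a warning. The payload bytes are placeholders sized to trip the limit.
from crewai_files import FileBytes, ImageFile
from crewai_files.processing.constraints import ImageConstraints, ProviderConstraints
from crewai_files.processing.exceptions import FileTooLargeError
from crewai_files.processing.processor import FileProcessor

tiny_limit = ProviderConstraints(name="demo", image=ImageConstraints(max_size_bytes=10))
processor = FileProcessor(constraints=tiny_limit)
payload = FileBytes(data=b"x" * 64, filename="big.png")  # 64 bytes > 10-byte cap

try:
    processor.process(ImageFile(source=payload, mode="strict"))
except FileTooLargeError:
    pass  # strict mode raises on the size violation

processed = processor.process(ImageFile(source=payload, mode="warn"))  # returns the file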
@@ -1,362 +0,0 @@
"""Unit tests for file transformers."""

import io
from unittest.mock import patch

from crewai_files import ImageFile, PDFFile, TextFile
from crewai_files.core.sources import FileBytes
from crewai_files.processing.exceptions import ProcessingDependencyError
from crewai_files.processing.transformers import (
    chunk_pdf,
    chunk_text,
    get_image_dimensions,
    get_pdf_page_count,
    optimize_image,
    resize_image,
)
import pytest


def create_test_png(width: int = 100, height: int = 100) -> bytes:
    """Create a minimal valid PNG for testing."""
    from PIL import Image

    img = Image.new("RGB", (width, height), color="red")
    buffer = io.BytesIO()
    img.save(buffer, format="PNG")
    return buffer.getvalue()


def create_test_pdf(num_pages: int = 1) -> bytes:
    """Create a minimal valid PDF for testing."""
    from pypdf import PdfWriter

    writer = PdfWriter()
    for _ in range(num_pages):
        writer.add_blank_page(width=612, height=792)

    buffer = io.BytesIO()
    writer.write(buffer)
    return buffer.getvalue()


class TestResizeImage:
    """Tests for resize_image function."""

    def test_resize_larger_image(self) -> None:
        """Test resizing an image larger than max dimensions."""
        png_bytes = create_test_png(200, 150)
        img = ImageFile(source=FileBytes(data=png_bytes, filename="test.png"))

        result = resize_image(img, max_width=100, max_height=100)

        dims = get_image_dimensions(result)
        assert dims is not None
        width, height = dims
        assert width <= 100
        assert height <= 100

    def test_no_resize_if_within_bounds(self) -> None:
        """Test that small images are returned unchanged."""
        png_bytes = create_test_png(50, 50)
        img = ImageFile(source=FileBytes(data=png_bytes, filename="small.png"))

        result = resize_image(img, max_width=100, max_height=100)

        assert result is img

    def test_preserve_aspect_ratio(self) -> None:
        """Test that aspect ratio is preserved during resize."""
        png_bytes = create_test_png(200, 100)
        img = ImageFile(source=FileBytes(data=png_bytes, filename="wide.png"))

        result = resize_image(img, max_width=100, max_height=100)

        dims = get_image_dimensions(result)
        assert dims is not None
        width, height = dims
        assert width == 100
        assert height == 50

    def test_resize_without_aspect_ratio(self) -> None:
        """Test resizing without preserving aspect ratio."""
        png_bytes = create_test_png(200, 100)
        img = ImageFile(source=FileBytes(data=png_bytes, filename="wide.png"))

        result = resize_image(
            img, max_width=50, max_height=50, preserve_aspect_ratio=False
        )

        dims = get_image_dimensions(result)
        assert dims is not None
        width, height = dims
        assert width == 50
        assert height == 50

    def test_resize_returns_image_file(self) -> None:
        """Test that resize returns an ImageFile instance."""
        png_bytes = create_test_png(200, 200)
        img = ImageFile(source=FileBytes(data=png_bytes, filename="test.png"))

        result = resize_image(img, max_width=100, max_height=100)

        assert isinstance(result, ImageFile)

    def test_raises_without_pillow(self) -> None:
        """Test that ProcessingDependencyError is raised without Pillow."""
        img = ImageFile(source=FileBytes(data=b"fake", filename="test.png"))

        with patch.dict("sys.modules", {"PIL": None, "PIL.Image": None}):
            with pytest.raises(ProcessingDependencyError) as exc_info:
                # Force reimport to trigger ImportError
                import importlib

                import crewai_files.processing.transformers as t

                importlib.reload(t)
                t.resize_image(img, 100, 100)

        assert "Pillow" in str(exc_info.value)


class TestOptimizeImage:
    """Tests for optimize_image function."""

    def test_optimize_reduces_size(self) -> None:
        """Test that optimization reduces file size."""
        png_bytes = create_test_png(500, 500)
        original_size = len(png_bytes)
        img = ImageFile(source=FileBytes(data=png_bytes, filename="large.png"))

        result = optimize_image(img, target_size_bytes=original_size // 2)

        result_size = len(result.read())
        assert result_size < original_size

    def test_no_optimize_if_under_target(self) -> None:
        """Test that small images are returned unchanged."""
        png_bytes = create_test_png(50, 50)
        img = ImageFile(source=FileBytes(data=png_bytes, filename="small.png"))

        result = optimize_image(img, target_size_bytes=1024 * 1024)

        assert result is img

    def test_optimize_returns_image_file(self) -> None:
        """Test that optimize returns an ImageFile instance."""
        png_bytes = create_test_png(200, 200)
        img = ImageFile(source=FileBytes(data=png_bytes, filename="test.png"))

        result = optimize_image(img, target_size_bytes=100)

        assert isinstance(result, ImageFile)

    def test_optimize_respects_min_quality(self) -> None:
        """Test that optimization stops at minimum quality."""
        png_bytes = create_test_png(100, 100)
        img = ImageFile(source=FileBytes(data=png_bytes, filename="test.png"))

        # Request impossibly small size - should stop at min quality
        result = optimize_image(img, target_size_bytes=10, min_quality=50)

        assert isinstance(result, ImageFile)
        assert len(result.read()) > 10


class TestChunkPdf:
    """Tests for chunk_pdf function."""

    def test_chunk_splits_large_pdf(self) -> None:
        """Test that large PDFs are split into chunks."""
        pdf_bytes = create_test_pdf(num_pages=10)
        pdf = PDFFile(source=FileBytes(data=pdf_bytes, filename="large.pdf"))

        result = list(chunk_pdf(pdf, max_pages=3))

        assert len(result) == 4
        assert all(isinstance(chunk, PDFFile) for chunk in result)

    def test_no_chunk_if_within_limit(self) -> None:
        """Test that small PDFs are returned unchanged."""
        pdf_bytes = create_test_pdf(num_pages=3)
        pdf = PDFFile(source=FileBytes(data=pdf_bytes, filename="small.pdf"))

        result = list(chunk_pdf(pdf, max_pages=5))

        assert len(result) == 1
        assert result[0] is pdf

    def test_chunk_filenames(self) -> None:
        """Test that chunked files have indexed filenames."""
        pdf_bytes = create_test_pdf(num_pages=6)
        pdf = PDFFile(source=FileBytes(data=pdf_bytes, filename="document.pdf"))

        result = list(chunk_pdf(pdf, max_pages=2))

        assert result[0].filename == "document_chunk_0.pdf"
        assert result[1].filename == "document_chunk_1.pdf"
        assert result[2].filename == "document_chunk_2.pdf"

    def test_chunk_with_overlap(self) -> None:
        """Test chunking with overlapping pages."""
        pdf_bytes = create_test_pdf(num_pages=10)
        pdf = PDFFile(source=FileBytes(data=pdf_bytes, filename="doc.pdf"))

        result = list(chunk_pdf(pdf, max_pages=4, overlap_pages=1))

        # With overlap, we get more chunks
        assert len(result) >= 3

    def test_chunk_page_counts(self) -> None:
        """Test that each chunk has correct page count."""
        pdf_bytes = create_test_pdf(num_pages=7)
        pdf = PDFFile(source=FileBytes(data=pdf_bytes, filename="doc.pdf"))

        result = list(chunk_pdf(pdf, max_pages=3))

        page_counts = [get_pdf_page_count(chunk) for chunk in result]
        assert page_counts == [3, 3, 1]


class TestChunkText:
    """Tests for chunk_text function."""

    def test_chunk_splits_large_text(self) -> None:
        """Test that large text files are split into chunks."""
        content = "Hello world. " * 100
        text = TextFile(source=content.encode(), filename="large.txt")

        result = list(chunk_text(text, max_chars=200, overlap_chars=0))

        assert len(result) > 1
        assert all(isinstance(chunk, TextFile) for chunk in result)

    def test_no_chunk_if_within_limit(self) -> None:
        """Test that small text files are returned unchanged."""
        content = "Short text"
        text = TextFile(source=content.encode(), filename="small.txt")

        result = list(chunk_text(text, max_chars=1000, overlap_chars=0))

        assert len(result) == 1
        assert result[0] is text

    def test_chunk_filenames(self) -> None:
        """Test that chunked files have indexed filenames."""
        content = "A" * 500
        text = TextFile(source=FileBytes(data=content.encode(), filename="data.txt"))

        result = list(chunk_text(text, max_chars=200, overlap_chars=0))

        assert result[0].filename == "data_chunk_0.txt"
        assert result[1].filename == "data_chunk_1.txt"
        assert len(result) == 3

    def test_chunk_preserves_extension(self) -> None:
        """Test that file extension is preserved in chunks."""
        content = "A" * 500
        text = TextFile(source=FileBytes(data=content.encode(), filename="script.py"))

        result = list(chunk_text(text, max_chars=200, overlap_chars=0))

        assert all(chunk.filename.endswith(".py") for chunk in result)

    def test_chunk_prefers_newline_boundaries(self) -> None:
        """Test that chunking prefers to split at newlines."""
        content = "Line one\nLine two\nLine three\nLine four\nLine five"
        text = TextFile(source=content.encode(), filename="lines.txt")

        result = list(
            chunk_text(text, max_chars=25, overlap_chars=0, split_on_newlines=True)
        )

        # Should split at newline boundaries
        for chunk in result:
            chunk_text_content = chunk.read().decode()
            # Chunks should end at newlines (except possibly the last)
            if chunk != result[-1]:
                assert (
                    chunk_text_content.endswith("\n") or len(chunk_text_content) <= 25
                )

    def test_chunk_with_overlap(self) -> None:
        """Test chunking with overlapping characters."""
        content = "ABCDEFGHIJ" * 10
        text = TextFile(source=content.encode(), filename="data.txt")

        result = list(chunk_text(text, max_chars=30, overlap_chars=5))

        # With overlap, chunks should share some content
        assert len(result) >= 3

    def test_chunk_overlap_larger_than_max_chars(self) -> None:
        """Test that overlap > max_chars doesn't cause infinite loop."""
        content = "A" * 100
        text = TextFile(source=content.encode(), filename="data.txt")

        # overlap_chars > max_chars should still work (just with max overlap)
        result = list(chunk_text(text, max_chars=20, overlap_chars=50))

        assert len(result) > 1
        # Should still complete without hanging


class TestGetImageDimensions:
    """Tests for get_image_dimensions function."""

    def test_get_dimensions(self) -> None:
        """Test getting image dimensions."""
        png_bytes = create_test_png(150, 100)
        img = ImageFile(source=FileBytes(data=png_bytes, filename="test.png"))

        dims = get_image_dimensions(img)

        assert dims == (150, 100)

    def test_returns_none_for_invalid_image(self) -> None:
        """Test that None is returned for invalid image data."""
        img = ImageFile(source=FileBytes(data=b"not an image", filename="bad.png"))

        dims = get_image_dimensions(img)

        assert dims is None

    def test_returns_none_without_pillow(self) -> None:
        """Test that None is returned when Pillow is not installed."""
        png_bytes = create_test_png(100, 100)
        ImageFile(source=FileBytes(data=png_bytes, filename="test.png"))

        with patch.dict("sys.modules", {"PIL": None}):
            # Can't easily test this without unloading module
            # Just verify the function handles the case gracefully
            pass


class TestGetPdfPageCount:
    """Tests for get_pdf_page_count function."""

    def test_get_page_count(self) -> None:
        """Test getting PDF page count."""
        pdf_bytes = create_test_pdf(num_pages=5)
        pdf = PDFFile(source=FileBytes(data=pdf_bytes, filename="test.pdf"))

        count = get_pdf_page_count(pdf)

        assert count == 5

    def test_single_page(self) -> None:
        """Test page count for single page PDF."""
        pdf_bytes = create_test_pdf(num_pages=1)
        pdf = PDFFile(source=FileBytes(data=pdf_bytes, filename="single.pdf"))

        count = get_pdf_page_count(pdf)

        assert count == 1

    def test_returns_none_for_invalid_pdf(self) -> None:
        """Test that None is returned for invalid PDF data."""
        pdf = PDFFile(source=FileBytes(data=b"not a pdf", filename="bad.pdf"))

        count = get_pdf_page_count(pdf)

        assert count is None
@@ -1,644 +0,0 @@
"""Tests for file validators."""

from unittest.mock import patch

from crewai_files import AudioFile, FileBytes, ImageFile, PDFFile, TextFile, VideoFile
from crewai_files.processing.constraints import (
    ANTHROPIC_CONSTRAINTS,
    AudioConstraints,
    ImageConstraints,
    PDFConstraints,
    ProviderConstraints,
    VideoConstraints,
)
from crewai_files.processing.exceptions import (
    FileTooLargeError,
    FileValidationError,
    UnsupportedFileTypeError,
)
from crewai_files.processing.validators import (
    _get_audio_duration,
    _get_video_duration,
    validate_audio,
    validate_file,
    validate_image,
    validate_pdf,
    validate_text,
    validate_video,
)
import pytest


# Minimal valid PNG: 8x8 pixel RGB image (valid for PIL)
MINIMAL_PNG = bytes(
    [
        0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A,
        0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52,
        0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08,
        0x08, 0x02, 0x00, 0x00, 0x00, 0x4B, 0x6D, 0x29,
        0xDC, 0x00, 0x00, 0x00, 0x12, 0x49, 0x44, 0x41,
        0x54, 0x78, 0x9C, 0x63, 0xFC, 0xCF, 0x80, 0x1D,
        0x30, 0xE1, 0x10, 0x1F, 0xA4, 0x12, 0x00, 0xCD,
        0x41, 0x01, 0x0F, 0xE8, 0x41, 0xE2, 0x6F, 0x00,
        0x00, 0x00, 0x00, 0x49, 0x45, 0x4E, 0x44, 0xAE,
        0x42, 0x60, 0x82,
    ]
)

# Minimal valid PDF
MINIMAL_PDF = (
    b"%PDF-1.4\n1 0 obj<</Type/Catalog/Pages 2 0 R>>endobj "
    b"2 0 obj<</Type/Pages/Kids[3 0 R]/Count 1>>endobj "
    b"3 0 obj<</Type/Page/MediaBox[0 0 612 792]/Parent 2 0 R>>endobj "
    b"xref\n0 4\n0000000000 65535 f \n0000000009 00000 n \n"
    b"0000000052 00000 n \n0000000101 00000 n \n"
    b"trailer<</Size 4/Root 1 0 R>>\nstartxref\n178\n%%EOF"
)


class TestValidateImage:
    """Tests for validate_image function."""

    def test_validate_valid_image(self):
        """Test validating a valid image within constraints."""
        constraints = ImageConstraints(
            max_size_bytes=10 * 1024 * 1024,
            supported_formats=("image/png",),
        )
        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))

        errors = validate_image(file, constraints, raise_on_error=False)

        assert len(errors) == 0

    def test_validate_image_too_large(self):
        """Test validating an image that exceeds size limit."""
        constraints = ImageConstraints(
            max_size_bytes=10,  # Very small limit
            supported_formats=("image/png",),
        )
        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))

        with pytest.raises(FileTooLargeError) as exc_info:
            validate_image(file, constraints)

        assert "exceeds" in str(exc_info.value)
        assert exc_info.value.file_name == "test.png"

    def test_validate_image_unsupported_format(self):
        """Test validating an image with unsupported format."""
        constraints = ImageConstraints(
            max_size_bytes=10 * 1024 * 1024,
            supported_formats=("image/jpeg",),  # Only JPEG
        )
        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))

        with pytest.raises(UnsupportedFileTypeError) as exc_info:
            validate_image(file, constraints)

        assert "not supported" in str(exc_info.value)

    def test_validate_image_no_raise(self):
        """Test validating with raise_on_error=False returns errors list."""
        constraints = ImageConstraints(
            max_size_bytes=10,
            supported_formats=("image/jpeg",),
        )
        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))

        errors = validate_image(file, constraints, raise_on_error=False)

        assert len(errors) == 2  # Size error and format error


class TestValidatePDF:
    """Tests for validate_pdf function."""

    def test_validate_valid_pdf(self):
        """Test validating a valid PDF within constraints."""
        constraints = PDFConstraints(
            max_size_bytes=10 * 1024 * 1024,
        )
        file = PDFFile(source=FileBytes(data=MINIMAL_PDF, filename="test.pdf"))

        errors = validate_pdf(file, constraints, raise_on_error=False)

        assert len(errors) == 0

    def test_validate_pdf_too_large(self):
        """Test validating a PDF that exceeds size limit."""
        constraints = PDFConstraints(
            max_size_bytes=10,  # Very small limit
        )
        file = PDFFile(source=FileBytes(data=MINIMAL_PDF, filename="test.pdf"))

        with pytest.raises(FileTooLargeError) as exc_info:
            validate_pdf(file, constraints)

        assert "exceeds" in str(exc_info.value)


class TestValidateText:
    """Tests for validate_text function."""

    def test_validate_valid_text(self):
        """Test validating a valid text file."""
        constraints = ProviderConstraints(
            name="test",
            general_max_size_bytes=10 * 1024 * 1024,
        )
        file = TextFile(source=FileBytes(data=b"Hello, World!", filename="test.txt"))

        errors = validate_text(file, constraints, raise_on_error=False)

        assert len(errors) == 0

    def test_validate_text_too_large(self):
        """Test validating text that exceeds size limit."""
        constraints = ProviderConstraints(
            name="test",
            general_max_size_bytes=5,
        )
        file = TextFile(source=FileBytes(data=b"Hello, World!", filename="test.txt"))

        with pytest.raises(FileTooLargeError):
            validate_text(file, constraints)

    def test_validate_text_no_limit(self):
        """Test validating text with no size limit."""
        constraints = ProviderConstraints(name="test")
        file = TextFile(source=FileBytes(data=b"Hello, World!", filename="test.txt"))

        errors = validate_text(file, constraints, raise_on_error=False)

        assert len(errors) == 0


class TestValidateFile:
    """Tests for validate_file function."""

    def test_validate_file_dispatches_to_image(self):
        """Test validate_file dispatches to image validator."""
        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))

        errors = validate_file(file, ANTHROPIC_CONSTRAINTS, raise_on_error=False)

        assert len(errors) == 0

    def test_validate_file_dispatches_to_pdf(self):
        """Test validate_file dispatches to PDF validator."""
        file = PDFFile(source=FileBytes(data=MINIMAL_PDF, filename="test.pdf"))

        errors = validate_file(file, ANTHROPIC_CONSTRAINTS, raise_on_error=False)

        assert len(errors) == 0

    def test_validate_file_unsupported_type(self):
        """Test validating a file type not supported by provider."""
        constraints = ProviderConstraints(
            name="test",
            image=None,  # No image support
        )
        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))

        with pytest.raises(UnsupportedFileTypeError) as exc_info:
            validate_file(file, constraints)

        assert "does not support images" in str(exc_info.value)

    def test_validate_file_pdf_not_supported(self):
        """Test validating PDF when provider doesn't support it."""
        constraints = ProviderConstraints(
            name="test",
            pdf=None,  # No PDF support
        )
        file = PDFFile(source=FileBytes(data=MINIMAL_PDF, filename="test.pdf"))

        with pytest.raises(UnsupportedFileTypeError) as exc_info:
            validate_file(file, constraints)

        assert "does not support PDFs" in str(exc_info.value)


# Minimal audio bytes for testing (not a valid audio file, used for mocked tests)
MINIMAL_AUDIO = b"\x00" * 100

# Minimal video bytes for testing (not a valid video file, used for mocked tests)
MINIMAL_VIDEO = b"\x00" * 100

# Fallback content type when python-magic cannot detect
FALLBACK_CONTENT_TYPE = "application/octet-stream"


class TestValidateAudio:
    """Tests for validate_audio function and audio duration validation."""

    def test_validate_valid_audio(self):
        """Test validating a valid audio file within constraints."""
        constraints = AudioConstraints(
            max_size_bytes=10 * 1024 * 1024,
            supported_formats=("audio/mp3", "audio/mpeg", FALLBACK_CONTENT_TYPE),
        )
        file = AudioFile(source=FileBytes(data=MINIMAL_AUDIO, filename="test.mp3"))

        errors = validate_audio(file, constraints, raise_on_error=False)

        assert len(errors) == 0

    def test_validate_audio_too_large(self):
        """Test validating an audio file that exceeds size limit."""
        constraints = AudioConstraints(
            max_size_bytes=10,  # Very small limit
            supported_formats=("audio/mp3", "audio/mpeg", FALLBACK_CONTENT_TYPE),
        )
        file = AudioFile(source=FileBytes(data=MINIMAL_AUDIO, filename="test.mp3"))

        with pytest.raises(FileTooLargeError) as exc_info:
            validate_audio(file, constraints)

        assert "exceeds" in str(exc_info.value)
        assert exc_info.value.file_name == "test.mp3"

    def test_validate_audio_unsupported_format(self):
        """Test validating an audio file with unsupported format."""
        constraints = AudioConstraints(
            max_size_bytes=10 * 1024 * 1024,
            supported_formats=("audio/wav",),  # Only WAV
        )
        file = AudioFile(source=FileBytes(data=MINIMAL_AUDIO, filename="test.mp3"))

        with pytest.raises(UnsupportedFileTypeError) as exc_info:
            validate_audio(file, constraints)

        assert "not supported" in str(exc_info.value)

    @patch("crewai_files.processing.validators._get_audio_duration")
    def test_validate_audio_duration_passes(self, mock_get_duration):
        """Test validating audio when duration is under limit."""
        mock_get_duration.return_value = 30.0
        constraints = AudioConstraints(
            max_size_bytes=10 * 1024 * 1024,
            max_duration_seconds=60,
            supported_formats=("audio/mp3", "audio/mpeg", FALLBACK_CONTENT_TYPE),
        )
        file = AudioFile(source=FileBytes(data=MINIMAL_AUDIO, filename="test.mp3"))

        errors = validate_audio(file, constraints, raise_on_error=False)

        assert len(errors) == 0
        mock_get_duration.assert_called_once()

    @patch("crewai_files.processing.validators._get_audio_duration")
    def test_validate_audio_duration_fails(self, mock_get_duration):
        """Test validating audio when duration exceeds limit."""
        mock_get_duration.return_value = 120.5
        constraints = AudioConstraints(
            max_size_bytes=10 * 1024 * 1024,
            max_duration_seconds=60,
            supported_formats=("audio/mp3", "audio/mpeg", FALLBACK_CONTENT_TYPE),
        )
        file = AudioFile(source=FileBytes(data=MINIMAL_AUDIO, filename="test.mp3"))

        with pytest.raises(FileValidationError) as exc_info:
            validate_audio(file, constraints)

        assert "duration" in str(exc_info.value).lower()
        assert "120.5s" in str(exc_info.value)
        assert "60s" in str(exc_info.value)

    @patch("crewai_files.processing.validators._get_audio_duration")
    def test_validate_audio_duration_no_raise(self, mock_get_duration):
        """Test audio duration validation with raise_on_error=False."""
        mock_get_duration.return_value = 120.5
        constraints = AudioConstraints(
            max_size_bytes=10 * 1024 * 1024,
            max_duration_seconds=60,
            supported_formats=("audio/mp3", "audio/mpeg", FALLBACK_CONTENT_TYPE),
        )
        file = AudioFile(source=FileBytes(data=MINIMAL_AUDIO, filename="test.mp3"))

        errors = validate_audio(file, constraints, raise_on_error=False)

        assert len(errors) == 1
        assert "duration" in errors[0].lower()

    @patch("crewai_files.processing.validators._get_audio_duration")
    def test_validate_audio_duration_none_skips(self, mock_get_duration):
        """Test that duration validation is skipped when max_duration_seconds is None."""
        constraints = AudioConstraints(
            max_size_bytes=10 * 1024 * 1024,
            max_duration_seconds=None,
            supported_formats=("audio/mp3", "audio/mpeg", FALLBACK_CONTENT_TYPE),
        )
        file = AudioFile(source=FileBytes(data=MINIMAL_AUDIO, filename="test.mp3"))

        errors = validate_audio(file, constraints, raise_on_error=False)

        assert len(errors) == 0
        mock_get_duration.assert_not_called()

    @patch("crewai_files.processing.validators._get_audio_duration")
    def test_validate_audio_duration_detection_returns_none(self, mock_get_duration):
        """Test that validation passes when duration detection returns None."""
        mock_get_duration.return_value = None
        constraints = AudioConstraints(
            max_size_bytes=10 * 1024 * 1024,
            max_duration_seconds=60,
            supported_formats=("audio/mp3", "audio/mpeg", FALLBACK_CONTENT_TYPE),
        )
        file = AudioFile(source=FileBytes(data=MINIMAL_AUDIO, filename="test.mp3"))

        errors = validate_audio(file, constraints, raise_on_error=False)

        assert len(errors) == 0


class TestValidateVideo:
    """Tests for validate_video function and video duration validation."""

    def test_validate_valid_video(self):
        """Test validating a valid video file within constraints."""
        constraints = VideoConstraints(
            max_size_bytes=10 * 1024 * 1024,
            supported_formats=("video/mp4", FALLBACK_CONTENT_TYPE),
        )
        file = VideoFile(source=FileBytes(data=MINIMAL_VIDEO, filename="test.mp4"))

        errors = validate_video(file, constraints, raise_on_error=False)

        assert len(errors) == 0

    def test_validate_video_too_large(self):
        """Test validating a video file that exceeds size limit."""
        constraints = VideoConstraints(
            max_size_bytes=10,  # Very small limit
            supported_formats=("video/mp4", FALLBACK_CONTENT_TYPE),
        )
        file = VideoFile(source=FileBytes(data=MINIMAL_VIDEO, filename="test.mp4"))

        with pytest.raises(FileTooLargeError) as exc_info:
            validate_video(file, constraints)

        assert "exceeds" in str(exc_info.value)
        assert exc_info.value.file_name == "test.mp4"

    def test_validate_video_unsupported_format(self):
        """Test validating a video file with unsupported format."""
        constraints = VideoConstraints(
            max_size_bytes=10 * 1024 * 1024,
            supported_formats=("video/webm",),  # Only WebM
        )
        file = VideoFile(source=FileBytes(data=MINIMAL_VIDEO, filename="test.mp4"))

        with pytest.raises(UnsupportedFileTypeError) as exc_info:
            validate_video(file, constraints)

        assert "not supported" in str(exc_info.value)

    @patch("crewai_files.processing.validators._get_video_duration")
    def test_validate_video_duration_passes(self, mock_get_duration):
        """Test validating video when duration is under limit."""
        mock_get_duration.return_value = 30.0
        constraints = VideoConstraints(
            max_size_bytes=10 * 1024 * 1024,
            max_duration_seconds=60,
            supported_formats=("video/mp4", FALLBACK_CONTENT_TYPE),
        )
        file = VideoFile(source=FileBytes(data=MINIMAL_VIDEO, filename="test.mp4"))

        errors = validate_video(file, constraints, raise_on_error=False)

        assert len(errors) == 0
        mock_get_duration.assert_called_once()

    @patch("crewai_files.processing.validators._get_video_duration")
    def test_validate_video_duration_fails(self, mock_get_duration):
        """Test validating video when duration exceeds limit."""
        mock_get_duration.return_value = 180.0
        constraints = VideoConstraints(
            max_size_bytes=10 * 1024 * 1024,
            max_duration_seconds=60,
            supported_formats=("video/mp4", FALLBACK_CONTENT_TYPE),
        )
        file = VideoFile(source=FileBytes(data=MINIMAL_VIDEO, filename="test.mp4"))

        with pytest.raises(FileValidationError) as exc_info:
            validate_video(file, constraints)

        assert "duration" in str(exc_info.value).lower()
        assert "180.0s" in str(exc_info.value)
        assert "60s" in str(exc_info.value)

    @patch("crewai_files.processing.validators._get_video_duration")
    def test_validate_video_duration_no_raise(self, mock_get_duration):
        """Test video duration validation with raise_on_error=False."""
        mock_get_duration.return_value = 180.0
        constraints = VideoConstraints(
            max_size_bytes=10 * 1024 * 1024,
            max_duration_seconds=60,
            supported_formats=("video/mp4", FALLBACK_CONTENT_TYPE),
        )
        file = VideoFile(source=FileBytes(data=MINIMAL_VIDEO, filename="test.mp4"))

        errors = validate_video(file, constraints, raise_on_error=False)

        assert len(errors) == 1
        assert "duration" in errors[0].lower()

    @patch("crewai_files.processing.validators._get_video_duration")
    def test_validate_video_duration_none_skips(self, mock_get_duration):
        """Test that duration validation is skipped when max_duration_seconds is None."""
        constraints = VideoConstraints(
            max_size_bytes=10 * 1024 * 1024,
            max_duration_seconds=None,
            supported_formats=("video/mp4", FALLBACK_CONTENT_TYPE),
        )
        file = VideoFile(source=FileBytes(data=MINIMAL_VIDEO, filename="test.mp4"))

        errors = validate_video(file, constraints, raise_on_error=False)

        assert len(errors) == 0
        mock_get_duration.assert_not_called()

    @patch("crewai_files.processing.validators._get_video_duration")
    def test_validate_video_duration_detection_returns_none(self, mock_get_duration):
        """Test that validation passes when duration detection returns None."""
        mock_get_duration.return_value = None
        constraints = VideoConstraints(
            max_size_bytes=10 * 1024 * 1024,
            max_duration_seconds=60,
            supported_formats=("video/mp4", FALLBACK_CONTENT_TYPE),
        )
        file = VideoFile(source=FileBytes(data=MINIMAL_VIDEO, filename="test.mp4"))

        errors = validate_video(file, constraints, raise_on_error=False)

        assert len(errors) == 0


class TestGetAudioDuration:
    """Tests for _get_audio_duration helper function."""

    def test_get_audio_duration_corrupt_file(self):
        """Test handling of corrupt audio data."""
        corrupt_data = b"not valid audio data at all"
        result = _get_audio_duration(corrupt_data)

        assert result is None


class TestGetVideoDuration:
    """Tests for _get_video_duration helper function."""

    def test_get_video_duration_corrupt_file(self):
        """Test handling of corrupt video data."""
        corrupt_data = b"not valid video data at all"
        result = _get_video_duration(corrupt_data)

        assert result is None


class TestRealVideoFile:
    """Tests using real video fixture file."""

    @pytest.fixture
    def sample_video_path(self):
        """Path to sample video fixture."""
        from pathlib import Path

        path = Path(__file__).parent.parent.parent / "fixtures" / "sample_video.mp4"
        if not path.exists():
            pytest.skip("sample_video.mp4 fixture not found")
        return path

    @pytest.fixture
    def sample_video_content(self, sample_video_path):
        """Read sample video content."""
        return sample_video_path.read_bytes()

    def test_get_video_duration_real_file(self, sample_video_content):
        """Test duration detection with real video file."""
        try:
            import av  # noqa: F401
        except ImportError:
            pytest.skip("PyAV not installed")

        duration = _get_video_duration(sample_video_content, "video/mp4")

        assert duration is not None
        assert 4.5 <= duration <= 5.5  # ~5 seconds with tolerance

    def test_get_video_duration_real_file_no_format_hint(self, sample_video_content):
        """Test duration detection without format hint."""
        try:
            import av  # noqa: F401
        except ImportError:
            pytest.skip("PyAV not installed")

        duration = _get_video_duration(sample_video_content)

        assert duration is not None
        assert 4.5 <= duration <= 5.5

    def test_validate_video_real_file_passes(self, sample_video_path):
        """Test validating real video file within constraints."""
        try:
            import av  # noqa: F401
        except ImportError:
            pytest.skip("PyAV not installed")

        constraints = VideoConstraints(
            max_size_bytes=10 * 1024 * 1024,
            max_duration_seconds=60,
            supported_formats=("video/mp4",),
        )
        file = VideoFile(source=str(sample_video_path))

        errors = validate_video(file, constraints, raise_on_error=False)

        assert len(errors) == 0

    def test_validate_video_real_file_duration_exceeded(self, sample_video_path):
        """Test validating real video file that exceeds duration limit."""
        try:
            import av  # noqa: F401
        except ImportError:
            pytest.skip("PyAV not installed")

        constraints = VideoConstraints(
            max_size_bytes=10 * 1024 * 1024,
            max_duration_seconds=2,  # Video is ~5 seconds
            supported_formats=("video/mp4",),
        )
        file = VideoFile(source=str(sample_video_path))

        with pytest.raises(FileValidationError) as exc_info:
            validate_video(file, constraints)

        assert "duration" in str(exc_info.value).lower()
        assert "2s" in str(exc_info.value)
@@ -1,311 +0,0 @@
"""Tests for FileUrl source type and URL resolution."""

from unittest.mock import AsyncMock, MagicMock, patch

from crewai_files import FileBytes, FileUrl, ImageFile
from crewai_files.core.resolved import InlineBase64, UrlReference
from crewai_files.core.sources import FilePath, _normalize_source
from crewai_files.resolution.resolver import FileResolver
import pytest


class TestFileUrl:
    """Tests for FileUrl source type."""

    def test_create_file_url(self):
        """Test creating FileUrl with valid URL."""
        url = FileUrl(url="https://example.com/image.png")

        assert url.url == "https://example.com/image.png"
        assert url.filename is None

    def test_create_file_url_with_filename(self):
        """Test creating FileUrl with custom filename."""
        url = FileUrl(url="https://example.com/image.png", filename="custom.png")

        assert url.url == "https://example.com/image.png"
        assert url.filename == "custom.png"

    def test_invalid_url_scheme_raises(self):
        """Test that non-http(s) URLs raise ValueError."""
        with pytest.raises(ValueError, match="Invalid URL scheme"):
            FileUrl(url="ftp://example.com/file.txt")

    def test_invalid_url_scheme_file_raises(self):
        """Test that file:// URLs raise ValueError."""
        with pytest.raises(ValueError, match="Invalid URL scheme"):
            FileUrl(url="file:///path/to/file.txt")

    def test_http_url_valid(self):
        """Test that HTTP URLs are valid."""
        url = FileUrl(url="http://example.com/image.jpg")

        assert url.url == "http://example.com/image.jpg"

    def test_https_url_valid(self):
        """Test that HTTPS URLs are valid."""
        url = FileUrl(url="https://example.com/image.jpg")

        assert url.url == "https://example.com/image.jpg"

    def test_content_type_guessing_png(self):
        """Test content type guessing for PNG files."""
        url = FileUrl(url="https://example.com/image.png")

        assert url.content_type == "image/png"

    def test_content_type_guessing_jpeg(self):
        """Test content type guessing for JPEG files."""
        url = FileUrl(url="https://example.com/photo.jpg")

        assert url.content_type == "image/jpeg"

    def test_content_type_guessing_pdf(self):
        """Test content type guessing for PDF files."""
        url = FileUrl(url="https://example.com/document.pdf")

        assert url.content_type == "application/pdf"

    def test_content_type_guessing_with_query_params(self):
        """Test content type guessing with URL query parameters."""
        url = FileUrl(url="https://example.com/image.png?v=123&token=abc")

        assert url.content_type == "image/png"

    def test_content_type_fallback_unknown(self):
        """Test content type falls back to octet-stream for unknown extensions."""
        url = FileUrl(url="https://example.com/file.unknownext123")

        assert url.content_type == "application/octet-stream"

    def test_content_type_no_extension(self):
        """Test content type for URL without extension."""
        url = FileUrl(url="https://example.com/file")

        assert url.content_type == "application/octet-stream"

    def test_read_fetches_content(self):
        """Test that read() fetches content from URL."""
        url = FileUrl(url="https://example.com/image.png")
        mock_response = MagicMock()
        mock_response.content = b"fake image content"
        mock_response.headers = {"content-type": "image/png"}

        with patch("httpx.get", return_value=mock_response) as mock_get:
            content = url.read()

        mock_get.assert_called_once_with(
            "https://example.com/image.png", follow_redirects=True
        )
        assert content == b"fake image content"

    def test_read_caches_content(self):
        """Test that read() caches content."""
        url = FileUrl(url="https://example.com/image.png")
        mock_response = MagicMock()
        mock_response.content = b"fake content"
        mock_response.headers = {}

        with patch("httpx.get", return_value=mock_response) as mock_get:
            content1 = url.read()
            content2 = url.read()

        mock_get.assert_called_once()
        assert content1 == content2

    def test_read_updates_content_type_from_response(self):
        """Test that read() updates content type from response headers."""
        url = FileUrl(url="https://example.com/file")
        mock_response = MagicMock()
        mock_response.content = b"fake content"
        mock_response.headers = {"content-type": "image/webp; charset=utf-8"}

        with patch("httpx.get", return_value=mock_response):
            url.read()

        assert url.content_type == "image/webp"

    @pytest.mark.asyncio
    async def test_aread_fetches_content(self):
        """Test that aread() fetches content from URL asynchronously."""
        url = FileUrl(url="https://example.com/image.png")
        mock_response = MagicMock()
        mock_response.content = b"async fake content"
        mock_response.headers = {"content-type": "image/png"}
        mock_response.raise_for_status = MagicMock()

        mock_client = MagicMock()
        mock_client.get = AsyncMock(return_value=mock_response)
        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client.__aexit__ = AsyncMock(return_value=None)

        with patch("httpx.AsyncClient", return_value=mock_client):
            content = await url.aread()

        assert content == b"async fake content"

    @pytest.mark.asyncio
    async def test_aread_caches_content(self):
        """Test that aread() caches content."""
        url = FileUrl(url="https://example.com/image.png")
        mock_response = MagicMock()
        mock_response.content = b"cached content"
        mock_response.headers = {}
        mock_response.raise_for_status = MagicMock()

        mock_client = MagicMock()
        mock_client.get = AsyncMock(return_value=mock_response)
        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client.__aexit__ = AsyncMock(return_value=None)

        with patch("httpx.AsyncClient", return_value=mock_client):
            content1 = await url.aread()
            content2 = await url.aread()

        mock_client.get.assert_called_once()
        assert content1 == content2


class TestNormalizeSource:
    """Tests for _normalize_source with URL detection."""

    def test_normalize_url_string(self):
        """Test that URL strings are converted to FileUrl."""
        result = _normalize_source("https://example.com/image.png")

        assert isinstance(result, FileUrl)
        assert result.url == "https://example.com/image.png"

    def test_normalize_http_url_string(self):
        """Test that HTTP URL strings are converted to FileUrl."""
        result = _normalize_source("http://example.com/file.pdf")

        assert isinstance(result, FileUrl)
        assert result.url == "http://example.com/file.pdf"

    def test_normalize_file_path_string(self, tmp_path):
        """Test that file path strings are converted to FilePath."""
        test_file = tmp_path / "test.png"
        test_file.write_bytes(b"test content")

        result = _normalize_source(str(test_file))

        assert isinstance(result, FilePath)

    def test_normalize_relative_path_is_not_url(self):
        """Test that relative path strings are not treated as URLs."""
        result = _normalize_source("https://example.com/file.png")

        assert isinstance(result, FileUrl)
        assert not isinstance(result, FilePath)

    def test_normalize_file_url_passthrough(self):
        """Test that FileUrl instances pass through unchanged."""
        original = FileUrl(url="https://example.com/image.png")
        result = _normalize_source(original)

        assert result is original


class TestResolverUrlHandling:
    """Tests for FileResolver URL handling."""

    def test_resolve_url_source_for_supported_provider(self):
        """Test URL source resolves to UrlReference for supported providers."""
        resolver = FileResolver()
        file = ImageFile(source=FileUrl(url="https://example.com/image.png"))

        resolved = resolver.resolve(file, "anthropic")

        assert isinstance(resolved, UrlReference)
        assert resolved.url == "https://example.com/image.png"
        assert resolved.content_type == "image/png"

    def test_resolve_url_source_openai(self):
        """Test URL source resolves to UrlReference for OpenAI."""
        resolver = FileResolver()
        file = ImageFile(source=FileUrl(url="https://example.com/photo.jpg"))

        resolved = resolver.resolve(file, "openai")

        assert isinstance(resolved, UrlReference)
        assert resolved.url == "https://example.com/photo.jpg"

    def test_resolve_url_source_gemini(self):
        """Test URL source resolves to UrlReference for Gemini."""
        resolver = FileResolver()
        file = ImageFile(source=FileUrl(url="https://example.com/image.webp"))

        resolved = resolver.resolve(file, "gemini")

        assert isinstance(resolved, UrlReference)
        assert resolved.url == "https://example.com/image.webp"

    def test_resolve_url_source_azure(self):
        """Test URL source resolves to UrlReference for Azure."""
        resolver = FileResolver()
        file = ImageFile(source=FileUrl(url="https://example.com/image.gif"))

        resolved = resolver.resolve(file, "azure")

        assert isinstance(resolved, UrlReference)
        assert resolved.url == "https://example.com/image.gif"

    def test_resolve_url_source_bedrock_fetches_content(self):
        """Test URL source fetches content for Bedrock (unsupported URLs)."""
        resolver = FileResolver()
        file_url = FileUrl(url="https://example.com/image.png")
        file = ImageFile(source=file_url)

        mock_response = MagicMock()
        mock_response.content = b"\x89PNG\r\n\x1a\n" + b"\x00" * 50
        mock_response.headers = {"content-type": "image/png"}

        with patch("httpx.get", return_value=mock_response):
            resolved = resolver.resolve(file, "bedrock")

        assert not isinstance(resolved, UrlReference)

    def test_resolve_bytes_source_still_works(self):
        """Test that bytes source still resolves normally."""
        resolver = FileResolver()
        minimal_png = (
            b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x08\x00\x00\x00\x08"
            b"\x01\x00\x00\x00\x00\xf9Y\xab\xcd\x00\x00\x00\nIDATx\x9cc`\x00\x00"
            b"\x00\x02\x00\x01\xe2!\xbc3\x00\x00\x00\x00IEND\xaeB`\x82"
        )
        file = ImageFile(source=FileBytes(data=minimal_png, filename="test.png"))

        resolved = resolver.resolve(file, "anthropic")

        assert isinstance(resolved, InlineBase64)

    @pytest.mark.asyncio
    async def test_aresolve_url_source(self):
        """Test async URL resolution for supported provider."""
        resolver = FileResolver()
        file = ImageFile(source=FileUrl(url="https://example.com/image.png"))

        resolved = await resolver.aresolve(file, "anthropic")

        assert isinstance(resolved, UrlReference)
        assert resolved.url == "https://example.com/image.png"


class TestImageFileWithUrl:
    """Tests for creating ImageFile with URL source."""

    def test_image_file_from_url_string(self):
        """Test creating ImageFile from URL string."""
        file = ImageFile(source="https://example.com/image.png")

        assert isinstance(file.source, FileUrl)
        assert file.source.url == "https://example.com/image.png"

    def test_image_file_from_file_url(self):
        """Test creating ImageFile from FileUrl instance."""
        url = FileUrl(url="https://example.com/photo.jpg")
        file = ImageFile(source=url)

        assert file.source is url
        assert file.content_type == "image/jpeg"
@@ -1,134 +0,0 @@
"""Tests for resolved file types."""

from datetime import datetime, timezone

from crewai_files.core.resolved import (
    FileReference,
    InlineBase64,
    InlineBytes,
    ResolvedFile,
    UrlReference,
)
import pytest


class TestInlineBase64:
    """Tests for InlineBase64 resolved type."""

    def test_create_inline_base64(self):
        """Test creating InlineBase64 instance."""
        resolved = InlineBase64(
            content_type="image/png",
            data="iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==",
        )

        assert resolved.content_type == "image/png"
        assert len(resolved.data) > 0

    def test_inline_base64_is_resolved_file(self):
        """Test InlineBase64 is a ResolvedFile."""
        resolved = InlineBase64(content_type="image/png", data="abc123")

        assert isinstance(resolved, ResolvedFile)

    def test_inline_base64_frozen(self):
        """Test InlineBase64 is immutable."""
        resolved = InlineBase64(content_type="image/png", data="abc123")

        with pytest.raises(Exception):
            resolved.data = "xyz789"


class TestInlineBytes:
    """Tests for InlineBytes resolved type."""

    def test_create_inline_bytes(self):
        """Test creating InlineBytes instance."""
        data = b"\x89PNG\r\n\x1a\n"
        resolved = InlineBytes(
            content_type="image/png",
            data=data,
        )

        assert resolved.content_type == "image/png"
        assert resolved.data == data

    def test_inline_bytes_is_resolved_file(self):
        """Test InlineBytes is a ResolvedFile."""
        resolved = InlineBytes(content_type="image/png", data=b"test")

        assert isinstance(resolved, ResolvedFile)


class TestFileReference:
    """Tests for FileReference resolved type."""

    def test_create_file_reference(self):
        """Test creating FileReference instance."""
        resolved = FileReference(
            content_type="image/png",
            file_id="file-abc123",
            provider="gemini",
        )

        assert resolved.content_type == "image/png"
        assert resolved.file_id == "file-abc123"
        assert resolved.provider == "gemini"
        assert resolved.expires_at is None
        assert resolved.file_uri is None

    def test_file_reference_with_expiry(self):
        """Test FileReference with expiry time."""
        expiry = datetime.now(timezone.utc)
        resolved = FileReference(
            content_type="application/pdf",
            file_id="file-xyz789",
            provider="gemini",
            expires_at=expiry,
        )

        assert resolved.expires_at == expiry

    def test_file_reference_with_uri(self):
        """Test FileReference with URI."""
        resolved = FileReference(
            content_type="video/mp4",
            file_id="file-video123",
            provider="gemini",
            file_uri="https://generativelanguage.googleapis.com/v1/files/file-video123",
        )

        assert resolved.file_uri is not None

    def test_file_reference_is_resolved_file(self):
        """Test FileReference is a ResolvedFile."""
        resolved = FileReference(
            content_type="image/png",
            file_id="file-123",
            provider="anthropic",
        )

        assert isinstance(resolved, ResolvedFile)


class TestUrlReference:
    """Tests for UrlReference resolved type."""

    def test_create_url_reference(self):
        """Test creating UrlReference instance."""
        resolved = UrlReference(
            content_type="image/png",
            url="https://storage.googleapis.com/bucket/image.png",
        )

        assert resolved.content_type == "image/png"
        assert resolved.url == "https://storage.googleapis.com/bucket/image.png"

    def test_url_reference_is_resolved_file(self):
        """Test UrlReference is a ResolvedFile."""
        resolved = UrlReference(
            content_type="image/jpeg",
            url="https://example.com/photo.jpg",
        )

        assert isinstance(resolved, ResolvedFile)
@@ -1,176 +0,0 @@
"""Tests for FileResolver."""

from crewai_files import FileBytes, ImageFile
from crewai_files.cache.upload_cache import UploadCache
from crewai_files.core.resolved import InlineBase64, InlineBytes
from crewai_files.resolution.resolver import (
    FileResolver,
    FileResolverConfig,
    create_resolver,
)


# Minimal valid PNG
MINIMAL_PNG = (
    b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x08\x00\x00\x00\x08"
    b"\x01\x00\x00\x00\x00\xf9Y\xab\xcd\x00\x00\x00\nIDATx\x9cc`\x00\x00"
    b"\x00\x02\x00\x01\xe2!\xbc3\x00\x00\x00\x00IEND\xaeB`\x82"
)


class TestFileResolverConfig:
    """Tests for FileResolverConfig."""

    def test_default_config(self):
        """Test default configuration values."""
        config = FileResolverConfig()

        assert config.prefer_upload is False
        assert config.upload_threshold_bytes is None
        assert config.use_bytes_for_bedrock is True

    def test_custom_config(self):
        """Test custom configuration values."""
        config = FileResolverConfig(
            prefer_upload=True,
            upload_threshold_bytes=1024 * 1024,
            use_bytes_for_bedrock=False,
        )

        assert config.prefer_upload is True
        assert config.upload_threshold_bytes == 1024 * 1024
        assert config.use_bytes_for_bedrock is False


class TestFileResolver:
    """Tests for FileResolver class."""

    def test_resolve_inline_base64(self):
        """Test resolving file as inline base64."""
        resolver = FileResolver()
        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))

        resolved = resolver.resolve(file, "openai")

        assert isinstance(resolved, InlineBase64)
        assert resolved.content_type == "image/png"
        assert len(resolved.data) > 0

    def test_resolve_inline_bytes_for_bedrock(self):
        """Test resolving file as inline bytes for Bedrock."""
        config = FileResolverConfig(use_bytes_for_bedrock=True)
        resolver = FileResolver(config=config)
        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))

        resolved = resolver.resolve(file, "bedrock")

        assert isinstance(resolved, InlineBytes)
        assert resolved.content_type == "image/png"
        assert resolved.data == MINIMAL_PNG

    def test_resolve_files_multiple(self):
        """Test resolving multiple files."""
        resolver = FileResolver()
        files = {
            "image1": ImageFile(
                source=FileBytes(data=MINIMAL_PNG, filename="test1.png")
            ),
            "image2": ImageFile(
                source=FileBytes(data=MINIMAL_PNG, filename="test2.png")
            ),
        }

        resolved = resolver.resolve_files(files, "openai")

        assert len(resolved) == 2
        assert "image1" in resolved
        assert "image2" in resolved
        assert all(isinstance(r, InlineBase64) for r in resolved.values())

    def test_resolve_with_cache(self):
        """Test resolver uses cache."""
        cache = UploadCache()
        resolver = FileResolver(upload_cache=cache)
        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))

        # First resolution
        resolved1 = resolver.resolve(file, "openai")
        # Second resolution (should use same base64 encoding)
        resolved2 = resolver.resolve(file, "openai")

        assert isinstance(resolved1, InlineBase64)
        assert isinstance(resolved2, InlineBase64)
        # Data should be identical
        assert resolved1.data == resolved2.data

    def test_clear_cache(self):
        """Test clearing resolver cache."""
        cache = UploadCache()
        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))

        # Add something to cache manually
        cache.set(file=file, provider="gemini", file_id="test")

        resolver = FileResolver(upload_cache=cache)
        resolver.clear_cache()

        assert len(cache) == 0

    def test_get_cached_uploads(self):
        """Test getting cached uploads from resolver."""
        cache = UploadCache()
        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))

        cache.set(file=file, provider="gemini", file_id="test-1")
        cache.set(file=file, provider="anthropic", file_id="test-2")

        resolver = FileResolver(upload_cache=cache)

        gemini_uploads = resolver.get_cached_uploads("gemini")
        anthropic_uploads = resolver.get_cached_uploads("anthropic")

        assert len(gemini_uploads) == 1
        assert len(anthropic_uploads) == 1

    def test_get_cached_uploads_empty(self):
        """Test getting cached uploads when no cache."""
        resolver = FileResolver()  # No cache

        uploads = resolver.get_cached_uploads("gemini")

        assert uploads == []


class TestCreateResolver:
    """Tests for create_resolver factory function."""

    def test_create_default_resolver(self):
        """Test creating resolver with default settings."""
        resolver = create_resolver()

        assert resolver.config.prefer_upload is False
        assert resolver.upload_cache is not None

    def test_create_resolver_with_options(self):
        """Test creating resolver with custom options."""
        resolver = create_resolver(
            prefer_upload=True,
            upload_threshold_bytes=5 * 1024 * 1024,
            enable_cache=False,
        )

        assert resolver.config.prefer_upload is True
        assert resolver.config.upload_threshold_bytes == 5 * 1024 * 1024
        assert resolver.upload_cache is None

    def test_create_resolver_cache_enabled(self):
        """Test resolver has cache when enabled."""
        resolver = create_resolver(enable_cache=True)

        assert resolver.upload_cache is not None

    def test_create_resolver_cache_disabled(self):
        """Test resolver has no cache when disabled."""
        resolver = create_resolver(enable_cache=False)

        assert resolver.upload_cache is None
@@ -1,210 +0,0 @@
"""Tests for upload cache."""

from datetime import datetime, timedelta, timezone

from crewai_files import FileBytes, ImageFile
from crewai_files.cache.upload_cache import CachedUpload, UploadCache


# Minimal valid PNG
MINIMAL_PNG = (
    b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x08\x00\x00\x00\x08"
    b"\x01\x00\x00\x00\x00\xf9Y\xab\xcd\x00\x00\x00\nIDATx\x9cc`\x00\x00"
    b"\x00\x02\x00\x01\xe2!\xbc3\x00\x00\x00\x00IEND\xaeB`\x82"
)


class TestCachedUpload:
    """Tests for CachedUpload dataclass."""

    def test_cached_upload_creation(self):
        """Test creating a cached upload."""
        now = datetime.now(timezone.utc)
        cached = CachedUpload(
            file_id="file-123",
            provider="gemini",
            file_uri="files/file-123",
            content_type="image/png",
            uploaded_at=now,
            expires_at=now + timedelta(hours=48),
        )

        assert cached.file_id == "file-123"
        assert cached.provider == "gemini"
        assert cached.file_uri == "files/file-123"
        assert cached.content_type == "image/png"

    def test_is_expired_false(self):
        """Test is_expired returns False for non-expired upload."""
        future = datetime.now(timezone.utc) + timedelta(hours=24)
        cached = CachedUpload(
            file_id="file-123",
            provider="gemini",
            file_uri=None,
            content_type="image/png",
            uploaded_at=datetime.now(timezone.utc),
            expires_at=future,
        )

        assert cached.is_expired() is False

    def test_is_expired_true(self):
        """Test is_expired returns True for expired upload."""
        past = datetime.now(timezone.utc) - timedelta(hours=1)
        cached = CachedUpload(
            file_id="file-123",
            provider="gemini",
            file_uri=None,
            content_type="image/png",
            uploaded_at=datetime.now(timezone.utc) - timedelta(hours=2),
            expires_at=past,
        )

        assert cached.is_expired() is True

    def test_is_expired_no_expiry(self):
        """Test is_expired returns False when no expiry set."""
        cached = CachedUpload(
            file_id="file-123",
            provider="anthropic",
            file_uri=None,
            content_type="image/png",
            uploaded_at=datetime.now(timezone.utc),
            expires_at=None,
        )

        assert cached.is_expired() is False


class TestUploadCache:
    """Tests for UploadCache class."""

    def test_cache_creation(self):
        """Test creating an empty cache."""
        cache = UploadCache()

        assert len(cache) == 0

    def test_set_and_get(self):
        """Test setting and getting cached uploads."""
        cache = UploadCache()
        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))

        cache.set(
            file=file,
            provider="gemini",
            file_id="file-123",
            file_uri="files/file-123",
        )

        result = cache.get(file, "gemini")

        assert result is not None
        assert result.file_id == "file-123"
        assert result.provider == "gemini"

    def test_get_missing(self):
        """Test getting non-existent entry returns None."""
        cache = UploadCache()
        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))

        result = cache.get(file, "gemini")

        assert result is None

    def test_get_different_provider(self):
        """Test getting with different provider returns None."""
        cache = UploadCache()
        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))

        cache.set(file=file, provider="gemini", file_id="file-123")

        result = cache.get(file, "anthropic")  # Different provider

        assert result is None

    def test_remove(self):
        """Test removing cached entry."""
        cache = UploadCache()
        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))

        cache.set(file=file, provider="gemini", file_id="file-123")
        removed = cache.remove(file, "gemini")

        assert removed is True
        assert cache.get(file, "gemini") is None

    def test_remove_missing(self):
        """Test removing non-existent entry returns False."""
        cache = UploadCache()
        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))

        removed = cache.remove(file, "gemini")

        assert removed is False

    def test_remove_by_file_id(self):
        """Test removing by file ID."""
        cache = UploadCache()
        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))

        cache.set(file=file, provider="gemini", file_id="file-123")
        removed = cache.remove_by_file_id("file-123", "gemini")

        assert removed is True
        assert len(cache) == 0

    def test_clear_expired(self):
        """Test clearing expired entries."""
        cache = UploadCache()
        file1 = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test1.png"))
        file2 = ImageFile(
            source=FileBytes(data=MINIMAL_PNG + b"x", filename="test2.png")
        )

        # Add one expired and one valid entry
        past = datetime.now(timezone.utc) - timedelta(hours=1)
        future = datetime.now(timezone.utc) + timedelta(hours=24)

        cache.set(file=file1, provider="gemini", file_id="expired", expires_at=past)
        cache.set(file=file2, provider="gemini", file_id="valid", expires_at=future)

        removed = cache.clear_expired()

        assert removed == 1
        assert len(cache) == 1
        assert cache.get(file2, "gemini") is not None

    def test_clear(self):
        """Test clearing all entries."""
        cache = UploadCache()
        file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))

        cache.set(file=file, provider="gemini", file_id="file-123")
        cache.set(file=file, provider="anthropic", file_id="file-456")

        cleared = cache.clear()

        assert cleared == 2
        assert len(cache) == 0

    def test_get_all_for_provider(self):
        """Test getting all cached uploads for a provider."""
        cache = UploadCache()
        file1 = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test1.png"))
        file2 = ImageFile(
            source=FileBytes(data=MINIMAL_PNG + b"x", filename="test2.png")
        )
        file3 = ImageFile(
            source=FileBytes(data=MINIMAL_PNG + b"xx", filename="test3.png")
        )

        cache.set(file=file1, provider="gemini", file_id="file-1")
        cache.set(file=file2, provider="gemini", file_id="file-2")
        cache.set(file=file3, provider="anthropic", file_id="file-3")

        gemini_uploads = cache.get_all_for_provider("gemini")
        anthropic_uploads = cache.get_all_for_provider("anthropic")

        assert len(gemini_uploads) == 2
        assert len(anthropic_uploads) == 1
@@ -12,7 +12,7 @@ dependencies = [
    "pytube~=15.0.0",
    "requests~=2.32.5",
    "docker~=7.1.0",
    "crewai==1.8.1",
    "crewai==1.8.0",
    "lancedb~=0.5.4",
    "tiktoken~=0.8.0",
    "beautifulsoup4~=4.13.4",

@@ -291,4 +291,4 @@ __all__ = [
    "ZapierActionTools",
]

__version__ = "1.8.1"
__version__ = "1.8.0"

@@ -49,7 +49,7 @@ Repository = "https://github.com/crewAIInc/crewAI"

[project.optional-dependencies]
tools = [
    "crewai-tools==1.8.1",
    "crewai-tools==1.8.0",
]
embeddings = [
    "tiktoken~=0.8.0"
@@ -98,9 +98,6 @@ a2a = [
    "httpx-sse~=0.4.0",
    "aiocache[redis,memcached]~=0.12.3",
]
file-processing = [
    "crewai-files",
]


[project.scripts]
@@ -127,7 +124,6 @@ torchvision = [
    { index = "pytorch-nightly", marker = "python_version >= '3.13'" },
    { index = "pytorch", marker = "python_version < '3.13'" },
]
crewai-files = { workspace = true }


[build-system]

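The pins above keep crewai and crewai-tools versioned in lockstep. A quick runtime check of an installed pair, using only the standard library:

from importlib.metadata import version

crewai_v = version("crewai")
tools_v = version("crewai-tools")
# The pyproject pins above expect these two to match exactly.
assert crewai_v == tools_v, f"crewai {crewai_v} != crewai-tools {tools_v}"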
@@ -3,15 +3,6 @@ from typing import Any
import urllib.request
import warnings

from crewai_files import (
    AudioFile,
    File,
    ImageFile,
    PDFFile,
    TextFile,
    VideoFile,
)

from crewai.agent.core import Agent
from crewai.crew import Crew
from crewai.crews.crew_output import CrewOutput
@@ -49,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:

_suppress_pydantic_deprecation_warnings()

__version__ = "1.8.1"
__version__ = "1.8.0"
_telemetry_submitted = False


@@ -83,20 +74,14 @@ _track_install_async()
__all__ = [
    "LLM",
    "Agent",
    "AudioFile",
    "BaseLLM",
    "Crew",
    "CrewOutput",
    "File",
    "Flow",
    "ImageFile",
    "Knowledge",
    "LLMGuardrail",
    "PDFFile",
    "Process",
    "Task",
    "TaskOutput",
    "TextFile",
    "VideoFile",
    "__version__",
]

@@ -1,10 +1,8 @@
"""Agent-to-Agent (A2A) protocol communication module for CrewAI."""

from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
from crewai.a2a.config import A2AConfig


__all__ = [
    "A2AClientConfig",
    "A2AConfig",
    "A2AServerConfig",
]

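One side of this diff re-exports the file classes from the top-level crewai package, the other does not. With the re-exports gone, the classes come from crewai_files directly. A sketch, with the constructor shape taken from the UploadCache tests earlier in this diff:

from crewai_files import FileBytes, ImageFile

# Requires the crewai-files package to be installed separately once the
# file-processing extra (removed above) is no longer part of crewai itself.
image = ImageFile(source=FileBytes(data=b"\x89PNG\r\n\x1a\n", filename="logo.png"))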
@@ -5,57 +5,45 @@ This module is separate from experimental.a2a to avoid circular imports.

from __future__ import annotations

from importlib.metadata import version
from typing import Any, ClassVar, Literal
from typing import Annotated, Any, ClassVar

from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import deprecated
from pydantic import (
    BaseModel,
    BeforeValidator,
    ConfigDict,
    Field,
    HttpUrl,
    TypeAdapter,
)

from crewai.a2a.auth.schemas import AuthScheme
from crewai.a2a.types import TransportType, Url


try:
    from a2a.types import (
        AgentCapabilities,
        AgentCardSignature,
        AgentInterface,
        AgentProvider,
        AgentSkill,
        SecurityScheme,
    )

    from crewai.a2a.updates import UpdateConfig
except ImportError:
    UpdateConfig = Any
    AgentCapabilities = Any
    AgentCardSignature = Any
    AgentInterface = Any
    AgentProvider = Any
    SecurityScheme = Any
    AgentSkill = Any
    UpdateConfig = Any  # type: ignore[misc,assignment]


http_url_adapter = TypeAdapter(HttpUrl)

Url = Annotated[
    str,
    BeforeValidator(
        lambda value: str(http_url_adapter.validate_python(value, strict=True))
    ),
]


def _get_default_update_config() -> UpdateConfig:
    from crewai.a2a.updates import StreamingConfig

    return StreamingConfig()


@deprecated(
    """
`crewai.a2a.config.A2AConfig` is deprecated and will be removed in v2.0.0,
use `crewai.a2a.config.A2AClientConfig` or `crewai.a2a.config.A2AServerConfig` instead.
""",
    category=FutureWarning,
)
class A2AConfig(BaseModel):
    """Configuration for A2A protocol integration.

    Deprecated:
        Use A2AClientConfig instead. This class will be removed in a future version.

    Attributes:
        endpoint: A2A agent endpoint URL.
        auth: Authentication scheme.
@@ -65,7 +53,6 @@ class A2AConfig(BaseModel):
        fail_fast: If True, raise error when agent unreachable; if False, skip and continue.
        trust_remote_completion_status: If True, return A2A agent's result directly when completed.
        updates: Update mechanism config.
        transport_protocol: A2A transport protocol (grpc, jsonrpc, http+json).
    """

    model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
@@ -95,180 +82,3 @@ class A2AConfig(BaseModel):
        default_factory=_get_default_update_config,
        description="Update mechanism config",
    )
    transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"] = Field(
        default="JSONRPC",
        description="Specified mode of A2A transport protocol",
    )


class A2AClientConfig(BaseModel):
    """Configuration for connecting to remote A2A agents.

    Attributes:
        endpoint: A2A agent endpoint URL.
        auth: Authentication scheme.
        timeout: Request timeout in seconds.
        max_turns: Maximum conversation turns with A2A agent.
        response_model: Optional Pydantic model for structured A2A agent responses.
        fail_fast: If True, raise error when agent unreachable; if False, skip and continue.
        trust_remote_completion_status: If True, return A2A agent's result directly when completed.
        updates: Update mechanism config.
        accepted_output_modes: Media types the client can accept in responses.
        supported_transports: Ordered list of transport protocols the client supports.
        use_client_preference: Whether to prioritize client transport preferences over server.
        extensions: Extension URIs the client supports.
    """

    model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")

    endpoint: Url = Field(description="A2A agent endpoint URL")
    auth: AuthScheme | None = Field(
        default=None,
        description="Authentication scheme",
    )
    timeout: int = Field(default=120, description="Request timeout in seconds")
    max_turns: int = Field(
        default=10, description="Maximum conversation turns with A2A agent"
    )
    response_model: type[BaseModel] | None = Field(
        default=None,
        description="Optional Pydantic model for structured A2A agent responses",
    )
    fail_fast: bool = Field(
        default=True,
        description="If True, raise error when agent unreachable; if False, skip",
    )
    trust_remote_completion_status: bool = Field(
        default=False,
        description="If True, return A2A result directly when completed",
    )
    updates: UpdateConfig = Field(
        default_factory=_get_default_update_config,
        description="Update mechanism config",
    )
    accepted_output_modes: list[str] = Field(
        default_factory=lambda: ["application/json"],
        description="Media types the client can accept in responses",
    )
    supported_transports: list[str] = Field(
        default_factory=lambda: ["JSONRPC"],
        description="Ordered list of transport protocols the client supports",
    )
    use_client_preference: bool = Field(
        default=False,
        description="Whether to prioritize client transport preferences over server",
    )
    extensions: list[str] = Field(
        default_factory=list,
        description="Extension URIs the client supports",
    )
    transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"] = Field(
        default="JSONRPC",
        description="Specified mode of A2A transport protocol",
    )


class A2AServerConfig(BaseModel):
    """Configuration for exposing a Crew or Agent as an A2A server.

    All fields correspond to A2A AgentCard fields. Fields like name, description,
    and skills can be auto-derived from the Crew/Agent if not provided.

    Attributes:
        name: Human-readable name for the agent.
        description: Human-readable description of the agent.
        version: Version string for the agent card.
        skills: List of agent skills/capabilities.
        default_input_modes: Default supported input MIME types.
        default_output_modes: Default supported output MIME types.
        capabilities: Declaration of optional capabilities.
        preferred_transport: Transport protocol for the preferred endpoint.
        protocol_version: A2A protocol version this agent supports.
        provider: Information about the agent's service provider.
        documentation_url: URL to the agent's documentation.
        icon_url: URL to an icon for the agent.
        additional_interfaces: Additional supported interfaces.
        security: Security requirement objects for all interactions.
        security_schemes: Security schemes available to authorize requests.
        supports_authenticated_extended_card: Whether agent provides extended card to authenticated users.
        url: Preferred endpoint URL for the agent.
        signatures: JSON Web Signatures for the AgentCard.
    """

    model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")

    name: str | None = Field(
        default=None,
        description="Human-readable name for the agent. Auto-derived from Crew/Agent if not provided.",
    )
    description: str | None = Field(
        default=None,
        description="Human-readable description of the agent. Auto-derived from Crew/Agent if not provided.",
    )
    version: str = Field(
        default="1.0.0",
        description="Version string for the agent card",
    )
    skills: list[AgentSkill] = Field(
        default_factory=list,
        description="List of agent skills. Auto-derived from tasks/tools if not provided.",
    )
    default_input_modes: list[str] = Field(
        default_factory=lambda: ["text/plain", "application/json"],
        description="Default supported input MIME types",
    )
    default_output_modes: list[str] = Field(
        default_factory=lambda: ["text/plain", "application/json"],
        description="Default supported output MIME types",
    )
    capabilities: AgentCapabilities = Field(
        default_factory=lambda: AgentCapabilities(
            streaming=True,
            push_notifications=False,
        ),
        description="Declaration of optional capabilities supported by the agent",
    )
    preferred_transport: TransportType = Field(
        default="JSONRPC",
        description="Transport protocol for the preferred endpoint",
    )
    protocol_version: str = Field(
        default_factory=lambda: version("a2a-sdk"),
        description="A2A protocol version this agent supports",
    )
    provider: AgentProvider | None = Field(
        default=None,
        description="Information about the agent's service provider",
    )
    documentation_url: Url | None = Field(
        default=None,
        description="URL to the agent's documentation",
    )
    icon_url: Url | None = Field(
        default=None,
        description="URL to an icon for the agent",
    )
    additional_interfaces: list[AgentInterface] = Field(
        default_factory=list,
        description="Additional supported interfaces (transport and URL combinations)",
    )
    security: list[dict[str, list[str]]] = Field(
        default_factory=list,
        description="Security requirement objects for all agent interactions",
    )
    security_schemes: dict[str, SecurityScheme] = Field(
        default_factory=dict,
        description="Security schemes available to authorize requests",
    )
    supports_authenticated_extended_card: bool = Field(
        default=False,
        description="Whether agent provides extended card to authenticated users",
    )
    url: Url | None = Field(
        default=None,
        description="Preferred endpoint URL for the agent. Set at runtime if not provided.",
    )
    signatures: list[AgentCardSignature] = Field(
        default_factory=list,
        description="JSON Web Signatures for the AgentCard",
    )

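A construction sketch for the two config classes defined above, using only fields that appear in those definitions; the endpoint URL is hypothetical:

from crewai.a2a.config import A2AClientConfig, A2AServerConfig

client_cfg = A2AClientConfig(
    endpoint="https://agents.example.com/.well-known/agent-card.json",  # hypothetical
    timeout=60,
    max_turns=5,
    supported_transports=["JSONRPC", "HTTP+JSON"],
)

server_cfg = A2AServerConfig(
    name="Research Crew",   # auto-derived from the Crew/Agent if omitted
    version="1.0.0",
    default_output_modes=["application/json"],
)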
@@ -3,10 +3,9 @@
from __future__ import annotations

from collections.abc import AsyncIterator
from typing import TYPE_CHECKING, Any, TypedDict
from typing import TYPE_CHECKING, TypedDict
import uuid

from a2a.client.errors import A2AClientHTTPError
from a2a.types import (
    AgentCard,
    Message,
@@ -21,10 +20,7 @@ from a2a.types (
from typing_extensions import NotRequired

from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
    A2AConnectionErrorEvent,
    A2AResponseReceivedEvent,
)
from crewai.events.types.a2a_events import A2AResponseReceivedEvent


if TYPE_CHECKING:
@@ -59,8 +55,7 @@ class TaskStateResult(TypedDict):
    history: list[Message]
    result: NotRequired[str]
    error: NotRequired[str]
    agent_card: NotRequired[dict[str, Any]]
    a2a_agent_name: NotRequired[str | None]
    agent_card: NotRequired[AgentCard]


def extract_task_result_parts(a2a_task: A2ATask) -> list[str]:
@@ -136,69 +131,50 @@ def process_task_state(
    is_multiturn: bool,
    agent_role: str | None,
    result_parts: list[str] | None = None,
    endpoint: str | None = None,
    a2a_agent_name: str | None = None,
    from_task: Any | None = None,
    from_agent: Any | None = None,
    is_final: bool = True,
) -> TaskStateResult | None:
    """Process A2A task state and return result dictionary.

    Shared logic for both polling and streaming handlers.

    Args:
        a2a_task: The A2A task to process.
        new_messages: List to collect messages (modified in place).
        agent_card: The agent card.
        turn_number: Current turn number.
        is_multiturn: Whether multi-turn conversation.
        agent_role: Agent role for logging.
        a2a_task: The A2A task to process
        new_messages: List to collect messages (modified in place)
        agent_card: The agent card
        turn_number: Current turn number
        is_multiturn: Whether multi-turn conversation
        agent_role: Agent role for logging
        result_parts: Accumulated result parts (streaming passes accumulated,
            polling passes None to extract from task).
        endpoint: A2A agent endpoint URL.
        a2a_agent_name: Name of the A2A agent from agent card.
        from_task: Optional CrewAI Task for event metadata.
        from_agent: Optional CrewAI Agent for event metadata.
        is_final: Whether this is the final response in the stream.
            polling passes None to extract from task)

    Returns:
        Result dictionary if terminal/actionable state, None otherwise.
        Result dictionary if terminal/actionable state, None otherwise
    """
    should_extract = result_parts is None
    if result_parts is None:
        result_parts = []

    if a2a_task.status.state == TaskState.completed:
        if not result_parts:
        if should_extract:
            extracted_parts = extract_task_result_parts(a2a_task)
            result_parts.extend(extracted_parts)
        if a2a_task.history:
            new_messages.extend(a2a_task.history)

        response_text = " ".join(result_parts) if result_parts else ""
        message_id = None
        if a2a_task.status and a2a_task.status.message:
            message_id = a2a_task.status.message.message_id
        crewai_event_bus.emit(
            None,
            A2AResponseReceivedEvent(
                response=response_text,
                turn_number=turn_number,
                context_id=a2a_task.context_id,
                message_id=message_id,
                is_multiturn=is_multiturn,
                status="completed",
                final=is_final,
                agent_role=agent_role,
                endpoint=endpoint,
                a2a_agent_name=a2a_agent_name,
                from_task=from_task,
                from_agent=from_agent,
            ),
        )

        return TaskStateResult(
            status=TaskState.completed,
            agent_card=agent_card.model_dump(exclude_none=True),
            agent_card=agent_card,
            result=response_text,
            history=new_messages,
        )
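The result_parts contract documented above (streaming passes its accumulated chunks, polling passes None so parts are extracted from the task) reduces to the should_extract pattern in the hunk. A runnable miniature of just that logic, with an invented extract callback standing in for extract_task_result_parts:

def join_parts(result_parts=None, extract=lambda: ["from", "task"]):
    should_extract = result_parts is None
    if result_parts is None:
        result_parts = []
    if should_extract:  # polling mode: pull the parts out of the task itself
        result_parts.extend(extract())
    return " ".join(result_parts)

print(join_parts())                        # polling: "from task"
print(join_parts(["streamed", "chunks"]))  # streaming: "streamed chunks"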
@@ -218,24 +194,14 @@ process_task_state(
        )
        new_messages.append(agent_message)

        input_message_id = None
        if a2a_task.status and a2a_task.status.message:
            input_message_id = a2a_task.status.message.message_id
        crewai_event_bus.emit(
            None,
            A2AResponseReceivedEvent(
                response=response_text,
                turn_number=turn_number,
                context_id=a2a_task.context_id,
                message_id=input_message_id,
                is_multiturn=is_multiturn,
                status="input_required",
                final=is_final,
                agent_role=agent_role,
                endpoint=endpoint,
                a2a_agent_name=a2a_agent_name,
                from_task=from_task,
                from_agent=from_agent,
            ),
        )

@@ -243,7 +209,7 @@ process_task_state(
            status=TaskState.input_required,
            error=response_text,
            history=new_messages,
            agent_card=agent_card.model_dump(exclude_none=True),
            agent_card=agent_card,
        )

    if a2a_task.status.state in {TaskState.failed, TaskState.rejected}:
@@ -282,11 +248,6 @@ async def send_message_and_get_task_id(
    turn_number: int,
    is_multiturn: bool,
    agent_role: str | None,
    from_task: Any | None = None,
    from_agent: Any | None = None,
    endpoint: str | None = None,
    a2a_agent_name: str | None = None,
    context_id: str | None = None,
) -> str | TaskStateResult:
    """Send message and process initial response.

@@ -301,11 +262,6 @@ async def send_message_and_get_task_id(
        turn_number: Current turn number
        is_multiturn: Whether multi-turn conversation
        agent_role: Agent role for logging
        from_task: Optional CrewAI Task object for event metadata.
        from_agent: Optional CrewAI Agent object for event metadata.
        endpoint: Optional A2A endpoint URL.
        a2a_agent_name: Optional A2A agent name.
        context_id: Optional A2A context ID for correlation.

    Returns:
        Task ID string if agent needs polling/waiting, or TaskStateResult if done.
@@ -324,16 +280,9 @@ async def send_message_and_get_task_id(
                A2AResponseReceivedEvent(
                    response=response_text,
                    turn_number=turn_number,
                    context_id=event.context_id,
                    message_id=event.message_id,
                    is_multiturn=is_multiturn,
                    status="completed",
                    final=True,
                    agent_role=agent_role,
                    endpoint=endpoint,
                    a2a_agent_name=a2a_agent_name,
                    from_task=from_task,
                    from_agent=from_agent,
                ),
            )

@@ -341,7 +290,7 @@ async def send_message_and_get_task_id(
                status=TaskState.completed,
                result=response_text,
                history=new_messages,
                agent_card=agent_card.model_dump(exclude_none=True),
                agent_card=agent_card,
            )

        if isinstance(event, tuple):
@@ -355,10 +304,6 @@ async def send_message_and_get_task_id(
                turn_number=turn_number,
                is_multiturn=is_multiturn,
                agent_role=agent_role,
                endpoint=endpoint,
                a2a_agent_name=a2a_agent_name,
                from_task=from_task,
                from_agent=from_agent,
            )
            if result:
                return result
@@ -371,99 +316,6 @@ async def send_message_and_get_task_id(
            history=new_messages,
        )

    except A2AClientHTTPError as e:
        error_msg = f"HTTP Error {e.status_code}: {e!s}"

        error_message = Message(
            role=Role.agent,
            message_id=str(uuid.uuid4()),
            parts=[Part(root=TextPart(text=error_msg))],
            context_id=context_id,
        )
        new_messages.append(error_message)

        crewai_event_bus.emit(
            None,
            A2AConnectionErrorEvent(
                endpoint=endpoint or "",
                error=str(e),
                error_type="http_error",
                status_code=e.status_code,
                a2a_agent_name=a2a_agent_name,
                operation="send_message",
                context_id=context_id,
                from_task=from_task,
                from_agent=from_agent,
            ),
        )
        crewai_event_bus.emit(
            None,
            A2AResponseReceivedEvent(
                response=error_msg,
                turn_number=turn_number,
                context_id=context_id,
                is_multiturn=is_multiturn,
                status="failed",
                final=True,
                agent_role=agent_role,
                endpoint=endpoint,
                a2a_agent_name=a2a_agent_name,
                from_task=from_task,
                from_agent=from_agent,
            ),
        )
        return TaskStateResult(
            status=TaskState.failed,
            error=error_msg,
            history=new_messages,
        )

    except Exception as e:
        error_msg = f"Unexpected error during send_message: {e!s}"

        error_message = Message(
            role=Role.agent,
            message_id=str(uuid.uuid4()),
            parts=[Part(root=TextPart(text=error_msg))],
            context_id=context_id,
        )
        new_messages.append(error_message)

        crewai_event_bus.emit(
            None,
            A2AConnectionErrorEvent(
                endpoint=endpoint or "",
                error=str(e),
                error_type="unexpected_error",
                a2a_agent_name=a2a_agent_name,
                operation="send_message",
                context_id=context_id,
                from_task=from_task,
                from_agent=from_agent,
            ),
        )
        crewai_event_bus.emit(
            None,
            A2AResponseReceivedEvent(
                response=error_msg,
                turn_number=turn_number,
                context_id=context_id,
                is_multiturn=is_multiturn,
                status="failed",
                final=True,
                agent_role=agent_role,
                endpoint=endpoint,
                a2a_agent_name=a2a_agent_name,
                from_task=from_task,
                from_agent=from_agent,
            ),
        )
        return TaskStateResult(
            status=TaskState.failed,
            error=error_msg,
            history=new_messages,
        )

    finally:
        aclose = getattr(event_stream, "aclose", None)
        if aclose:

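Per the docstring above, send_message_and_get_task_id returns either a task id (the remote agent is still working) or a finished TaskStateResult. A dispatch sketch; the leading positional arguments are assumed from the fragments shown here, since the full signature is not in this diff:

result_or_task_id = await send_message_and_get_task_id(
    client, message, new_messages=[], agent_card=card,  # assumed in scope
    turn_number=1, is_multiturn=False, agent_role=None,
)
if isinstance(result_or_task_id, str):
    # Task id: hand off to a polling/streaming handler to await completion.
    final = await _poll_task_until_complete(
        client, result_or_task_id, polling_interval=2.0, polling_timeout=120.0
    )
else:
    final = result_or_task_id  # already terminal: completed/failed/input_required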
@@ -1,17 +1,7 @@
"""Type definitions for A2A protocol message parts."""

from __future__ import annotations
from typing import Any, Literal, Protocol, TypedDict, runtime_checkable

from typing import (
    Annotated,
    Any,
    Literal,
    Protocol,
    TypedDict,
    runtime_checkable,
)

from pydantic import BeforeValidator, HttpUrl, TypeAdapter
from typing_extensions import NotRequired

from crewai.a2a.updates import (
@@ -25,18 +15,6 @@ from crewai.a2a.updates import (
)


TransportType = Literal["JSONRPC", "GRPC", "HTTP+JSON"]

http_url_adapter: TypeAdapter[HttpUrl] = TypeAdapter(HttpUrl)

Url = Annotated[
    str,
    BeforeValidator(
        lambda value: str(http_url_adapter.validate_python(value, strict=True))
    ),
]


@runtime_checkable
class AgentResponseProtocol(Protocol):
    """Protocol for the dynamically created AgentResponse model."""

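The Url alias above runs strict HttpUrl validation through the BeforeValidator, then stores a plain str. A quick probe, assuming the Url alias from the definitions above is in scope:

from pydantic import BaseModel, ValidationError

class Probe(BaseModel):
    url: Url  # the Annotated alias defined above

Probe(url="https://agent.example.com/")  # accepted, stored as str
try:
    Probe(url="not-a-url")
except ValidationError:
    print("rejected by strict HttpUrl validation")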
@@ -22,13 +22,6 @@ class BaseHandlerKwargs(TypedDict, total=False):
    turn_number: int
    is_multiturn: bool
    agent_role: str | None
    context_id: str | None
    task_id: str | None
    endpoint: str | None
    agent_branch: Any
    a2a_agent_name: str | None
    from_task: Any
    from_agent: Any


class PollingHandlerKwargs(BaseHandlerKwargs, total=False):
@@ -36,6 +29,8 @@ class PollingHandlerKwargs(BaseHandlerKwargs, total=False):

    polling_interval: float
    polling_timeout: float
    endpoint: str
    agent_branch: Any
    history_length: int
    max_polls: int | None

@@ -43,6 +38,9 @@
class StreamingHandlerKwargs(BaseHandlerKwargs, total=False):
    """Kwargs for streaming handler."""

    context_id: str | None
    task_id: str | None


class PushNotificationHandlerKwargs(BaseHandlerKwargs, total=False):
    """Kwargs for push notification handler."""
@@ -51,6 +49,7 @@ class PushNotificationHandlerKwargs(BaseHandlerKwargs, total=False):
    result_store: PushNotificationResultStore
    polling_timeout: float
    polling_interval: float
    agent_branch: Any


class PushNotificationResultStore(Protocol):

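These TypedDicts are declared total=False, so every key is optional and handlers read them defensively with kwargs.get(...). A construction sketch, using the import path that appears in the polling module below:

from crewai.a2a.updates.base import PollingHandlerKwargs

kwargs: PollingHandlerKwargs = {
    "turn_number": 1,
    "is_multiturn": False,
    "polling_interval": 2.0,
    "polling_timeout": 120.0,
    "max_polls": 30,
}
# Missing keys simply come back as None/defaults via kwargs.get(...).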
@@ -31,7 +31,6 @@ from crewai.a2a.task_helpers import (
from crewai.a2a.updates.base import PollingHandlerKwargs
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
    A2AConnectionErrorEvent,
    A2APollingStartedEvent,
    A2APollingStatusEvent,
    A2AResponseReceivedEvent,
@@ -50,33 +49,23 @@ async def _poll_task_until_complete(
    agent_branch: Any | None = None,
    history_length: int = 100,
    max_polls: int | None = None,
    from_task: Any | None = None,
    from_agent: Any | None = None,
    context_id: str | None = None,
    endpoint: str | None = None,
    a2a_agent_name: str | None = None,
) -> A2ATask:
    """Poll task status until terminal state reached.

    Args:
        client: A2A client instance.
        task_id: Task ID to poll.
        polling_interval: Seconds between poll attempts.
        polling_timeout: Max seconds before timeout.
        agent_branch: Agent tree branch for logging.
        history_length: Number of messages to retrieve per poll.
        max_polls: Max number of poll attempts (None = unlimited).
        from_task: Optional CrewAI Task object for event metadata.
        from_agent: Optional CrewAI Agent object for event metadata.
        context_id: A2A context ID for correlation.
        endpoint: A2A agent endpoint URL.
        a2a_agent_name: Name of the A2A agent from agent card.
        client: A2A client instance
        task_id: Task ID to poll
        polling_interval: Seconds between poll attempts
        polling_timeout: Max seconds before timeout
        agent_branch: Agent tree branch for logging
        history_length: Number of messages to retrieve per poll
        max_polls: Max number of poll attempts (None = unlimited)

    Returns:
        Final task object in terminal state.
        Final task object in terminal state

    Raises:
        A2APollingTimeoutError: If polling exceeds timeout or max_polls.
        A2APollingTimeoutError: If polling exceeds timeout or max_polls
    """
    start_time = time.monotonic()
    poll_count = 0
@@ -88,19 +77,13 @@ async def _poll_task_until_complete(
        )

        elapsed = time.monotonic() - start_time
        effective_context_id = task.context_id or context_id
        crewai_event_bus.emit(
            agent_branch,
            A2APollingStatusEvent(
                task_id=task_id,
                context_id=effective_context_id,
                state=str(task.status.state.value) if task.status.state else "unknown",
                elapsed_seconds=elapsed,
                poll_count=poll_count,
                endpoint=endpoint,
                a2a_agent_name=a2a_agent_name,
                from_task=from_task,
                from_agent=from_agent,
            ),
        )

@@ -154,9 +137,6 @@ class PollingHandler:
        max_polls = kwargs.get("max_polls")
        context_id = kwargs.get("context_id")
        task_id = kwargs.get("task_id")
        a2a_agent_name = kwargs.get("a2a_agent_name")
        from_task = kwargs.get("from_task")
        from_agent = kwargs.get("from_agent")

        try:
            result_or_task_id = await send_message_and_get_task_id(
@@ -166,11 +146,6 @@ class PollingHandler:
                turn_number=turn_number,
                is_multiturn=is_multiturn,
                agent_role=agent_role,
                from_task=from_task,
                from_agent=from_agent,
                endpoint=endpoint,
                a2a_agent_name=a2a_agent_name,
                context_id=context_id,
            )

            if not isinstance(result_or_task_id, str):
@@ -182,12 +157,8 @@ class PollingHandler:
                agent_branch,
                A2APollingStartedEvent(
                    task_id=task_id,
                    context_id=context_id,
                    polling_interval=polling_interval,
                    endpoint=endpoint,
                    a2a_agent_name=a2a_agent_name,
                    from_task=from_task,
                    from_agent=from_agent,
                ),
            )

@@ -199,11 +170,6 @@ class PollingHandler:
                agent_branch=agent_branch,
                history_length=history_length,
                max_polls=max_polls,
                from_task=from_task,
                from_agent=from_agent,
                context_id=context_id,
                endpoint=endpoint,
                a2a_agent_name=a2a_agent_name,
            )

            result = process_task_state(
@@ -213,10 +179,6 @@ class PollingHandler:
                turn_number=turn_number,
                is_multiturn=is_multiturn,
                agent_role=agent_role,
                endpoint=endpoint,
                a2a_agent_name=a2a_agent_name,
                from_task=from_task,
                from_agent=from_agent,
            )
            if result:
                return result
@@ -244,15 +206,9 @@ class PollingHandler:
                A2AResponseReceivedEvent(
                    response=error_msg,
                    turn_number=turn_number,
                    context_id=context_id,
                    is_multiturn=is_multiturn,
                    status="failed",
                    final=True,
                    agent_role=agent_role,
                    endpoint=endpoint,
                    a2a_agent_name=a2a_agent_name,
                    from_task=from_task,
                    from_agent=from_agent,
                ),
            )
            return TaskStateResult(
@@ -273,83 +229,14 @@ class PollingHandler:
            )
            new_messages.append(error_message)

            crewai_event_bus.emit(
                agent_branch,
                A2AConnectionErrorEvent(
                    endpoint=endpoint,
                    error=str(e),
                    error_type="http_error",
                    status_code=e.status_code,
                    a2a_agent_name=a2a_agent_name,
                    operation="polling",
                    context_id=context_id,
                    task_id=task_id,
                    from_task=from_task,
                    from_agent=from_agent,
                ),
            )
            crewai_event_bus.emit(
                agent_branch,
                A2AResponseReceivedEvent(
                    response=error_msg,
                    turn_number=turn_number,
                    context_id=context_id,
                    is_multiturn=is_multiturn,
                    status="failed",
                    final=True,
                    agent_role=agent_role,
                    endpoint=endpoint,
                    a2a_agent_name=a2a_agent_name,
                    from_task=from_task,
                    from_agent=from_agent,
                ),
            )
            return TaskStateResult(
                status=TaskState.failed,
                error=error_msg,
                history=new_messages,
            )

        except Exception as e:
            error_msg = f"Unexpected error during polling: {e!s}"

            error_message = Message(
                role=Role.agent,
                message_id=str(uuid.uuid4()),
                parts=[Part(root=TextPart(text=error_msg))],
                context_id=context_id,
                task_id=task_id,
            )
            new_messages.append(error_message)

            crewai_event_bus.emit(
                agent_branch,
                A2AConnectionErrorEvent(
                    endpoint=endpoint or "",
                    error=str(e),
                    error_type="unexpected_error",
                    a2a_agent_name=a2a_agent_name,
                    operation="polling",
                    context_id=context_id,
                    task_id=task_id,
                    from_task=from_task,
                    from_agent=from_agent,
                ),
            )
            crewai_event_bus.emit(
                agent_branch,
                A2AResponseReceivedEvent(
                    response=error_msg,
                    turn_number=turn_number,
                    context_id=context_id,
                    is_multiturn=is_multiturn,
                    status="failed",
                    final=True,
                    agent_role=agent_role,
                    endpoint=endpoint,
                    a2a_agent_name=a2a_agent_name,
                    from_task=from_task,
                    from_agent=from_agent,
                ),
            )
            return TaskStateResult(

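_poll_task_until_complete above bounds its loop by both wall time (time.monotonic against polling_timeout) and attempt count (max_polls). The same guard in isolation, as a self-contained sketch:

import asyncio
import time

async def poll_until(check, interval: float = 1.0, timeout: float = 30.0,
                     max_polls: int | None = None):
    start, polls = time.monotonic(), 0
    while True:
        if await check():            # check() stands in for a client.get_task call
            return True
        polls += 1
        if time.monotonic() - start > timeout or (max_polls and polls >= max_polls):
            raise TimeoutError("polling exceeded timeout or max_polls")
        await asyncio.sleep(interval)  # fixed interval between attempts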
@@ -29,7 +29,6 @@ from crewai.a2a.updates.base import (
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
    A2AConnectionErrorEvent,
    A2APushNotificationRegisteredEvent,
    A2APushNotificationTimeoutEvent,
    A2AResponseReceivedEvent,
@@ -49,11 +48,6 @@ async def _wait_for_push_result(
    timeout: float,
    poll_interval: float,
    agent_branch: Any | None = None,
    from_task: Any | None = None,
    from_agent: Any | None = None,
    context_id: str | None = None,
    endpoint: str | None = None,
    a2a_agent_name: str | None = None,
) -> A2ATask | None:
    """Wait for push notification result.

@@ -63,11 +57,6 @@ async def _wait_for_push_result(
        timeout: Max seconds to wait.
        poll_interval: Seconds between polling attempts.
        agent_branch: Agent tree branch for logging.
        from_task: Optional CrewAI Task object for event metadata.
        from_agent: Optional CrewAI Agent object for event metadata.
        context_id: A2A context ID for correlation.
        endpoint: A2A agent endpoint URL.
        a2a_agent_name: Name of the A2A agent.

    Returns:
        Final task object, or None if timeout.
@@ -83,12 +72,7 @@ async def _wait_for_push_result(
            agent_branch,
            A2APushNotificationTimeoutEvent(
                task_id=task_id,
                context_id=context_id,
                timeout_seconds=timeout,
                endpoint=endpoint,
                a2a_agent_name=a2a_agent_name,
                from_task=from_task,
                from_agent=from_agent,
            ),
        )

@@ -131,56 +115,18 @@ class PushNotificationHandler:
        agent_role = kwargs.get("agent_role")
        context_id = kwargs.get("context_id")
        task_id = kwargs.get("task_id")
        endpoint = kwargs.get("endpoint")
        a2a_agent_name = kwargs.get("a2a_agent_name")
        from_task = kwargs.get("from_task")
        from_agent = kwargs.get("from_agent")

        if config is None:
            error_msg = (
                "PushNotificationConfig is required for push notification handler"
            )
            crewai_event_bus.emit(
                agent_branch,
                A2AConnectionErrorEvent(
                    endpoint=endpoint or "",
                    error=error_msg,
                    error_type="configuration_error",
                    a2a_agent_name=a2a_agent_name,
                    operation="push_notification",
                    context_id=context_id,
                    task_id=task_id,
                    from_task=from_task,
                    from_agent=from_agent,
                ),
            )
            return TaskStateResult(
                status=TaskState.failed,
                error=error_msg,
                error="PushNotificationConfig is required for push notification handler",
                history=new_messages,
            )

        if result_store is None:
            error_msg = (
                "PushNotificationResultStore is required for push notification handler"
            )
            crewai_event_bus.emit(
                agent_branch,
                A2AConnectionErrorEvent(
                    endpoint=endpoint or "",
                    error=error_msg,
                    error_type="configuration_error",
                    a2a_agent_name=a2a_agent_name,
                    operation="push_notification",
                    context_id=context_id,
                    task_id=task_id,
                    from_task=from_task,
                    from_agent=from_agent,
                ),
            )
            return TaskStateResult(
                status=TaskState.failed,
                error=error_msg,
                error="PushNotificationResultStore is required for push notification handler",
                history=new_messages,
            )

@@ -192,11 +138,6 @@ class PushNotificationHandler:
            turn_number=turn_number,
            is_multiturn=is_multiturn,
            agent_role=agent_role,
            from_task=from_task,
            from_agent=from_agent,
            endpoint=endpoint,
            a2a_agent_name=a2a_agent_name,
            context_id=context_id,
        )

        if not isinstance(result_or_task_id, str):
@@ -208,12 +149,7 @@ class PushNotificationHandler:
            agent_branch,
            A2APushNotificationRegisteredEvent(
                task_id=task_id,
                context_id=context_id,
                callback_url=str(config.url),
                endpoint=endpoint,
                a2a_agent_name=a2a_agent_name,
                from_task=from_task,
                from_agent=from_agent,
            ),
        )

@@ -229,11 +165,6 @@ class PushNotificationHandler:
            timeout=polling_timeout,
            poll_interval=polling_interval,
            agent_branch=agent_branch,
            from_task=from_task,
            from_agent=from_agent,
            context_id=context_id,
            endpoint=endpoint,
            a2a_agent_name=a2a_agent_name,
        )

        if final_task is None:
@@ -250,10 +181,6 @@ class PushNotificationHandler:
            turn_number=turn_number,
            is_multiturn=is_multiturn,
            agent_role=agent_role,
            endpoint=endpoint,
            a2a_agent_name=a2a_agent_name,
            from_task=from_task,
            from_agent=from_agent,
        )
        if result:
            return result
@@ -276,83 +203,14 @@ class PushNotificationHandler:
        )
        new_messages.append(error_message)

        crewai_event_bus.emit(
            agent_branch,
            A2AConnectionErrorEvent(
                endpoint=endpoint or "",
                error=str(e),
                error_type="http_error",
                status_code=e.status_code,
                a2a_agent_name=a2a_agent_name,
                operation="push_notification",
                context_id=context_id,
                task_id=task_id,
                from_task=from_task,
                from_agent=from_agent,
            ),
        )
        crewai_event_bus.emit(
            agent_branch,
            A2AResponseReceivedEvent(
                response=error_msg,
                turn_number=turn_number,
                context_id=context_id,
                is_multiturn=is_multiturn,
                status="failed",
                final=True,
                agent_role=agent_role,
                endpoint=endpoint,
                a2a_agent_name=a2a_agent_name,
                from_task=from_task,
                from_agent=from_agent,
            ),
        )
        return TaskStateResult(
            status=TaskState.failed,
            error=error_msg,
            history=new_messages,
        )

    except Exception as e:
        error_msg = f"Unexpected error during push notification: {e!s}"

        error_message = Message(
            role=Role.agent,
            message_id=str(uuid.uuid4()),
            parts=[Part(root=TextPart(text=error_msg))],
            context_id=context_id,
            task_id=task_id,
        )
        new_messages.append(error_message)

        crewai_event_bus.emit(
            agent_branch,
            A2AConnectionErrorEvent(
                endpoint=endpoint or "",
                error=str(e),
                error_type="unexpected_error",
                a2a_agent_name=a2a_agent_name,
                operation="push_notification",
                context_id=context_id,
                task_id=task_id,
                from_task=from_task,
                from_agent=from_agent,
            ),
        )
        crewai_event_bus.emit(
            agent_branch,
            A2AResponseReceivedEvent(
                response=error_msg,
                turn_number=turn_number,
                context_id=context_id,
                is_multiturn=is_multiturn,
                status="failed",
                final=True,
                agent_role=agent_role,
                endpoint=endpoint,
                a2a_agent_name=a2a_agent_name,
                from_task=from_task,
                from_agent=from_agent,
            ),
        )
        return TaskStateResult(

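The push-notification path registers a callback URL and then waits on a result store until the webhook delivers the finished task. An in-memory sketch of something shaped like the PushNotificationResultStore protocol; the method names here are assumptions, since the protocol body is not shown in this diff:

class InMemoryResultStore:
    """Toy store for local testing (assumed protocol shape)."""

    def __init__(self) -> None:
        self._results: dict[str, object] = {}

    async def set_result(self, task_id: str, task: object) -> None:
        # Called by the webhook endpoint when the A2A server posts a result.
        self._results[task_id] = task

    async def get_result(self, task_id: str):
        # Polled by the handler until a result lands or the timeout fires.
        return self._results.get(task_id)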
@@ -26,13 +26,7 @@ from crewai.a2a.task_helpers import (
)
from crewai.a2a.updates.base import StreamingHandlerKwargs
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
    A2AArtifactReceivedEvent,
    A2AConnectionErrorEvent,
    A2AResponseReceivedEvent,
    A2AStreamingChunkEvent,
    A2AStreamingStartedEvent,
)
from crewai.events.types.a2a_events import A2AResponseReceivedEvent


class StreamingHandler:
@@ -63,57 +57,19 @@ class StreamingHandler:
        turn_number = kwargs.get("turn_number", 0)
        is_multiturn = kwargs.get("is_multiturn", False)
        agent_role = kwargs.get("agent_role")
        endpoint = kwargs.get("endpoint")
        a2a_agent_name = kwargs.get("a2a_agent_name")
        from_task = kwargs.get("from_task")
        from_agent = kwargs.get("from_agent")
        agent_branch = kwargs.get("agent_branch")

        result_parts: list[str] = []
        final_result: TaskStateResult | None = None
        event_stream = client.send_message(message)
        chunk_index = 0

        crewai_event_bus.emit(
            agent_branch,
            A2AStreamingStartedEvent(
                task_id=task_id,
                context_id=context_id,
                endpoint=endpoint or "",
                a2a_agent_name=a2a_agent_name,
                turn_number=turn_number,
                is_multiturn=is_multiturn,
                agent_role=agent_role,
                from_task=from_task,
                from_agent=from_agent,
            ),
        )

        try:
            async for event in event_stream:
                if isinstance(event, Message):
                    new_messages.append(event)
                    message_context_id = event.context_id or context_id
                    for part in event.parts:
                        if part.root.kind == "text":
                            text = part.root.text
                            result_parts.append(text)
                            crewai_event_bus.emit(
                                agent_branch,
                                A2AStreamingChunkEvent(
                                    task_id=event.task_id or task_id,
                                    context_id=message_context_id,
                                    chunk=text,
                                    chunk_index=chunk_index,
                                    endpoint=endpoint,
                                    a2a_agent_name=a2a_agent_name,
                                    turn_number=turn_number,
                                    is_multiturn=is_multiturn,
                                    from_task=from_task,
                                    from_agent=from_agent,
                                ),
                            )
                            chunk_index += 1

                elif isinstance(event, tuple):
                    a2a_task, update = event
@@ -125,51 +81,10 @@ class StreamingHandler:
                        for part in artifact.parts
                        if part.root.kind == "text"
                    )
                    artifact_size = None
                    if artifact.parts:
                        artifact_size = sum(
                            len(p.root.text.encode("utf-8"))
                            if p.root.kind == "text"
                            else len(getattr(p.root, "data", b""))
                            for p in artifact.parts
                        )
                    effective_context_id = a2a_task.context_id or context_id
                    crewai_event_bus.emit(
                        agent_branch,
                        A2AArtifactReceivedEvent(
                            task_id=a2a_task.id,
                            artifact_id=artifact.artifact_id,
                            artifact_name=artifact.name,
                            artifact_description=artifact.description,
                            mime_type=artifact.parts[0].root.kind
                            if artifact.parts
                            else None,
                            size_bytes=artifact_size,
                            append=update.append or False,
                            last_chunk=update.last_chunk or False,
                            endpoint=endpoint,
                            a2a_agent_name=a2a_agent_name,
                            context_id=effective_context_id,
                            turn_number=turn_number,
                            is_multiturn=is_multiturn,
                            from_task=from_task,
                            from_agent=from_agent,
                        ),
                    )

                    is_final_update = False
                    if isinstance(update, TaskStatusUpdateEvent):
                        is_final_update = update.final
                        if (
                            update.status
                            and update.status.message
                            and update.status.message.parts
                        ):
                            result_parts.extend(
                                part.root.text
                                for part in update.status.message.parts
                                if part.root.kind == "text" and part.root.text
                            )

                    if (
                        not is_final_update
@@ -186,11 +101,6 @@ class StreamingHandler:
                        is_multiturn=is_multiturn,
                        agent_role=agent_role,
                        result_parts=result_parts,
                        endpoint=endpoint,
                        a2a_agent_name=a2a_agent_name,
                        from_task=from_task,
                        from_agent=from_agent,
                        is_final=is_final_update,
                    )
                    if final_result:
                        break
@@ -208,82 +118,13 @@ class StreamingHandler:
            new_messages.append(error_message)

            crewai_event_bus.emit(
                agent_branch,
                A2AConnectionErrorEvent(
                    endpoint=endpoint or "",
                    error=str(e),
                    error_type="http_error",
                    status_code=e.status_code,
                    a2a_agent_name=a2a_agent_name,
                    operation="streaming",
                    context_id=context_id,
                    task_id=task_id,
                    from_task=from_task,
                    from_agent=from_agent,
                ),
            )
            crewai_event_bus.emit(
                agent_branch,
                None,
                A2AResponseReceivedEvent(
                    response=error_msg,
                    turn_number=turn_number,
                    context_id=context_id,
                    is_multiturn=is_multiturn,
                    status="failed",
                    final=True,
                    agent_role=agent_role,
                    endpoint=endpoint,
                    a2a_agent_name=a2a_agent_name,
                    from_task=from_task,
                    from_agent=from_agent,
                ),
            )
            return TaskStateResult(
                status=TaskState.failed,
                error=error_msg,
                history=new_messages,
            )

        except Exception as e:
            error_msg = f"Unexpected error during streaming: {e!s}"

            error_message = Message(
                role=Role.agent,
                message_id=str(uuid.uuid4()),
                parts=[Part(root=TextPart(text=error_msg))],
                context_id=context_id,
                task_id=task_id,
            )
            new_messages.append(error_message)

            crewai_event_bus.emit(
                agent_branch,
                A2AConnectionErrorEvent(
                    endpoint=endpoint or "",
                    error=str(e),
                    error_type="unexpected_error",
                    a2a_agent_name=a2a_agent_name,
                    operation="streaming",
                    context_id=context_id,
                    task_id=task_id,
                    from_task=from_task,
                    from_agent=from_agent,
                ),
            )
            crewai_event_bus.emit(
                agent_branch,
                A2AResponseReceivedEvent(
                    response=error_msg,
                    turn_number=turn_number,
                    context_id=context_id,
                    is_multiturn=is_multiturn,
                    status="failed",
                    final=True,
                    agent_role=agent_role,
                    endpoint=endpoint,
                    a2a_agent_name=a2a_agent_name,
                    from_task=from_task,
                    from_agent=from_agent,
                ),
            )
            return TaskStateResult(
@@ -295,23 +136,7 @@ class StreamingHandler:
        finally:
            aclose = getattr(event_stream, "aclose", None)
            if aclose:
                try:
                    await aclose()
                except Exception as close_error:
                    crewai_event_bus.emit(
                        agent_branch,
                        A2AConnectionErrorEvent(
                            endpoint=endpoint or "",
                            error=str(close_error),
                            error_type="stream_close_error",
                            a2a_agent_name=a2a_agent_name,
                            operation="stream_close",
                            context_id=context_id,
                            task_id=task_id,
                            from_task=from_task,
                            from_agent=from_agent,
                        ),
                    )
                await aclose()

        if final_result:
            return final_result
@@ -320,5 +145,5 @@ class StreamingHandler:
            status=TaskState.completed,
            result=" ".join(result_parts) if result_parts else "",
            history=new_messages,
            agent_card=agent_card.model_dump(exclude_none=True),
            agent_card=agent_card,
        )

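The streaming handler above appends each text part to result_parts as it arrives and joins them with single spaces when the stream ends. The same accumulation in isolation, as a runnable sketch with stand-in chunks:

result_parts: list[str] = []
for chunk_index, chunk in enumerate(["stream", "of", "text", "parts"]):
    result_parts.append(chunk)  # the handler emits one chunk event per part
final_text = " ".join(result_parts) if result_parts else ""
assert final_text == "stream of text parts"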
@@ -1,14 +1,16 @@
"""A2A delegation utilities for executing tasks on remote agents."""
"""Utility functions for A2A (Agent-to-Agent) protocol delegation."""

from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator, MutableMapping
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any, Literal
from functools import lru_cache
import time
from typing import TYPE_CHECKING, Any
import uuid

from a2a.client import Client, ClientConfig, ClientFactory
from a2a.client import A2AClientHTTPError, Client, ClientConfig, ClientFactory
from a2a.types import (
    AgentCard,
    Message,
@@ -16,16 +18,21 @@ from a2a.types import (
    PushNotificationConfig as A2APushNotificationConfig,
    Role,
    TextPart,
    TransportProtocol,
)
from aiocache import cached  # type: ignore[import-untyped]
from aiocache.serializers import PickleSerializer  # type: ignore[import-untyped]
import httpx
from pydantic import BaseModel
from pydantic import BaseModel, Field, create_model

from crewai.a2a.auth.schemas import APIKeyAuth, HTTPDigestAuth
from crewai.a2a.auth.utils import (
    _auth_store,
    configure_auth_client,
    retry_on_401,
    validate_auth_against_agent_card,
)
from crewai.a2a.config import A2AConfig
from crewai.a2a.task_helpers import TaskStateResult
from crewai.a2a.types import (
    HANDLER_REGISTRY,
@@ -39,7 +46,6 @@ from crewai.a2a.updates import (
    StreamingHandler,
    UpdateConfig,
)
from crewai.a2a.utils.agent_card import _afetch_agent_card_cached
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
    A2AConversationStartedEvent,
@@ -47,6 +53,7 @@ from crewai.events.types.a2a_events import (
    A2ADelegationStartedEvent,
    A2AMessageSentEvent,
)
from crewai.types.utils import create_literals_from_strings


if TYPE_CHECKING:
@@ -69,9 +76,189 @@ def get_handler(config: UpdateConfig | None) -> HandlerType:
    return HANDLER_REGISTRY.get(type(config), StreamingHandler)


@lru_cache()
def _fetch_agent_card_cached(
    endpoint: str,
    auth_hash: int,
    timeout: int,
    _ttl_hash: int,
) -> AgentCard:
    """Cached sync version of fetch_agent_card."""
    auth = _auth_store.get(auth_hash)

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(
            _afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout)
        )
    finally:
        loop.close()


def fetch_agent_card(
    endpoint: str,
    auth: AuthScheme | None = None,
    timeout: int = 30,
    use_cache: bool = True,
    cache_ttl: int = 300,
) -> AgentCard:
    """Fetch AgentCard from an A2A endpoint with optional caching.

    Args:
        endpoint: A2A agent endpoint URL (AgentCard URL)
        auth: Optional AuthScheme for authentication
        timeout: Request timeout in seconds
        use_cache: Whether to use caching (default True)
        cache_ttl: Cache TTL in seconds (default 300 = 5 minutes)

    Returns:
        AgentCard object with agent capabilities and skills

    Raises:
        httpx.HTTPStatusError: If the request fails
        A2AClientHTTPError: If authentication fails
    """
    if use_cache:
        if auth:
            auth_data = auth.model_dump_json(
                exclude={
                    "_access_token",
                    "_token_expires_at",
                    "_refresh_token",
                    "_authorization_callback",
                }
            )
            auth_hash = hash((type(auth).__name__, auth_data))
        else:
            auth_hash = 0
        _auth_store[auth_hash] = auth
        ttl_hash = int(time.time() // cache_ttl)
        return _fetch_agent_card_cached(endpoint, auth_hash, timeout, ttl_hash)

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(
            afetch_agent_card(endpoint=endpoint, auth=auth, timeout=timeout)
        )
    finally:
        loop.close()


async def afetch_agent_card(
    endpoint: str,
    auth: AuthScheme | None = None,
    timeout: int = 30,
    use_cache: bool = True,
) -> AgentCard:
    """Fetch AgentCard from an A2A endpoint asynchronously.

    Native async implementation. Use this when running in an async context.

    Args:
        endpoint: A2A agent endpoint URL (AgentCard URL).
        auth: Optional AuthScheme for authentication.
        timeout: Request timeout in seconds.
        use_cache: Whether to use caching (default True).

    Returns:
        AgentCard object with agent capabilities and skills.

    Raises:
        httpx.HTTPStatusError: If the request fails.
        A2AClientHTTPError: If authentication fails.
    """
    if use_cache:
        if auth:
            auth_data = auth.model_dump_json(
                exclude={
                    "_access_token",
                    "_token_expires_at",
                    "_refresh_token",
                    "_authorization_callback",
                }
            )
            auth_hash = hash((type(auth).__name__, auth_data))
        else:
            auth_hash = 0
        _auth_store[auth_hash] = auth
        agent_card: AgentCard = await _afetch_agent_card_cached(
            endpoint, auth_hash, timeout
        )
        return agent_card

    return await _afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout)


@cached(ttl=300, serializer=PickleSerializer())  # type: ignore[untyped-decorator]
async def _afetch_agent_card_cached(
    endpoint: str,
    auth_hash: int,
    timeout: int,
) -> AgentCard:
    """Cached async implementation of AgentCard fetching."""
    auth = _auth_store.get(auth_hash)
    return await _afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout)


async def _afetch_agent_card_impl(
    endpoint: str,
    auth: AuthScheme | None,
    timeout: int,
) -> AgentCard:
    """Internal async implementation of AgentCard fetching."""
    if "/.well-known/agent-card.json" in endpoint:
        base_url = endpoint.replace("/.well-known/agent-card.json", "")
        agent_card_path = "/.well-known/agent-card.json"
    else:
        url_parts = endpoint.split("/", 3)
        base_url = f"{url_parts[0]}//{url_parts[2]}"
        agent_card_path = f"/{url_parts[3]}" if len(url_parts) > 3 else "/"

    headers: MutableMapping[str, str] = {}
    if auth:
        async with httpx.AsyncClient(timeout=timeout) as temp_auth_client:
            if isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
                configure_auth_client(auth, temp_auth_client)
            headers = await auth.apply_auth(temp_auth_client, {})

    async with httpx.AsyncClient(timeout=timeout, headers=headers) as temp_client:
        if auth and isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
            configure_auth_client(auth, temp_client)

        agent_card_url = f"{base_url}{agent_card_path}"

        async def _fetch_agent_card_request() -> httpx.Response:
            return await temp_client.get(agent_card_url)

        try:
            response = await retry_on_401(
                request_func=_fetch_agent_card_request,
                auth_scheme=auth,
                client=temp_client,
                headers=temp_client.headers,
                max_retries=2,
            )
            response.raise_for_status()

            return AgentCard.model_validate(response.json())

        except httpx.HTTPStatusError as e:
            if e.response.status_code == 401:
                error_details = ["Authentication failed"]
                www_auth = e.response.headers.get("WWW-Authenticate")
                if www_auth:
                    error_details.append(f"WWW-Authenticate: {www_auth}")
                if not auth:
                    error_details.append("No auth scheme provided")
                msg = " | ".join(error_details)
                raise A2AClientHTTPError(401, msg) from e
            raise


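fetch_agent_card above bolts a TTL onto functools.lru_cache by folding time into the cache key: int(time.time() // cache_ttl) changes every cache_ttl seconds, so older entries simply stop matching. The trick in isolation, as a self-contained sketch:

import time
from functools import lru_cache

@lru_cache(maxsize=128)
def _lookup(key: str, _ttl_hash: int) -> str:
    # Recomputed whenever _ttl_hash rolls over to a new window.
    return f"expensive({key})"

def lookup(key: str, ttl: int = 300) -> str:
    return _lookup(key, int(time.time() // ttl))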
def execute_a2a_delegation(
    endpoint: str,
    transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"],
    auth: AuthScheme | None,
    timeout: int,
    task_description: str,
@@ -88,9 +275,6 @@ def execute_a2a_delegation(
    response_model: type[BaseModel] | None = None,
    turn_number: int | None = None,
    updates: UpdateConfig | None = None,
    from_task: Any | None = None,
    from_agent: Any | None = None,
    skill_id: str | None = None,
) -> TaskStateResult:
    """Execute a task delegation to a remote A2A agent synchronously.

@@ -98,23 +282,6 @@ def execute_a2a_delegation(
    use aexecute_a2a_delegation directly.

    Args:
        endpoint: A2A agent endpoint URL (AgentCard URL)
        transport_protocol: Optional A2A transport protocol (grpc, jsonrpc, http+json)
        auth: Optional AuthScheme for authentication (Bearer, OAuth2, API Key, HTTP Basic/Digest)
        timeout: Request timeout in seconds
        task_description: The task to delegate
        context: Optional context information
        context_id: Context ID for correlating messages/tasks
        task_id: Specific task identifier
        reference_task_ids: List of related task IDs
        metadata: Additional metadata (external_id, request_id, etc.)
        extensions: Protocol extensions for custom fields
        conversation_history: Previous Message objects from conversation
        agent_id: Agent identifier for logging
        agent_role: Role of the CrewAI agent delegating the task
        agent_branch: Optional agent tree branch for logging
        response_model: Optional Pydantic model for structured outputs
        turn_number: Optional turn number for multi-turn conversations
        endpoint: A2A agent endpoint URL.
        auth: Optional AuthScheme for authentication.
        timeout: Request timeout in seconds.
@@ -132,9 +299,6 @@ def execute_a2a_delegation(
        response_model: Optional Pydantic model for structured outputs.
        turn_number: Optional turn number for multi-turn conversations.
        updates: Update mechanism config from A2AConfig.updates.
        from_task: Optional CrewAI Task object for event metadata.
        from_agent: Optional CrewAI Agent object for event metadata.
        skill_id: Optional skill ID to target a specific agent capability.

    Returns:
        TaskStateResult with status, result/error, history, and agent_card.
@@ -159,24 +323,16 @@ def execute_a2a_delegation(
                agent_role=agent_role,
                agent_branch=agent_branch,
                response_model=response_model,
                transport_protocol=transport_protocol,
                turn_number=turn_number,
                updates=updates,
                from_task=from_task,
                from_agent=from_agent,
                skill_id=skill_id,
            )
        )
    finally:
        try:
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            loop.close()
        loop.close()
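
# Hypothetical call site for the sync wrapper above; the endpoint and task
# are illustrative and the remaining keyword arguments (context, metadata,
# conversation history, ...) are elided:
result = execute_a2a_delegation(
    endpoint="https://agents.example.com/.well-known/agent-card.json",
    transport_protocol="JSONRPC",
    auth=None,
    timeout=60,
    task_description="Summarize the attached report",
)
# Inside an async context (e.g. Crew.akickoff()), prefer the native coroutine
# defined below: `result = await aexecute_a2a_delegation(...)`.
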


async def aexecute_a2a_delegation(
    endpoint: str,
    transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"],
    auth: AuthScheme | None,
    timeout: int,
    task_description: str,
@@ -193,9 +349,6 @@ async def aexecute_a2a_delegation(
    response_model: type[BaseModel] | None = None,
    turn_number: int | None = None,
    updates: UpdateConfig | None = None,
    from_task: Any | None = None,
    from_agent: Any | None = None,
    skill_id: str | None = None,
) -> TaskStateResult:
    """Execute a task delegation to a remote A2A agent asynchronously.

@@ -203,23 +356,6 @@ async def aexecute_a2a_delegation(
    in an async context (e.g., with Crew.akickoff() or agent.aexecute_task()).

    Args:
        endpoint: A2A agent endpoint URL
        transport_protocol: Optional A2A transport protocol (grpc, jsonrpc, http+json)
        auth: Optional AuthScheme for authentication
        timeout: Request timeout in seconds
        task_description: Task to delegate
        context: Optional context
        context_id: Context ID for correlation
        task_id: Specific task identifier
        reference_task_ids: Related task IDs
        metadata: Additional metadata
        extensions: Protocol extensions
        conversation_history: Previous Message objects
        turn_number: Current turn number
        agent_branch: Agent tree branch for logging
        agent_id: Agent identifier for logging
        agent_role: Agent role for logging
        response_model: Optional Pydantic model for structured outputs
        endpoint: A2A agent endpoint URL.
        auth: Optional AuthScheme for authentication.
        timeout: Request timeout in seconds.
@@ -237,9 +373,6 @@ async def aexecute_a2a_delegation(
        response_model: Optional Pydantic model for structured outputs.
        turn_number: Optional turn number for multi-turn conversations.
        updates: Update mechanism config from A2AConfig.updates.
        from_task: Optional CrewAI Task object for event metadata.
        from_agent: Optional CrewAI Agent object for event metadata.
        skill_id: Optional skill ID to target a specific agent capability.

    Returns:
        TaskStateResult with status, result/error, history, and agent_card.
@@ -251,6 +384,17 @@ async def aexecute_a2a_delegation(
    if turn_number is None:
        turn_number = len([m for m in conversation_history if m.role == Role.user]) + 1

    crewai_event_bus.emit(
        agent_branch,
        A2ADelegationStartedEvent(
            endpoint=endpoint,
            task_description=task_description,
            agent_id=agent_id,
            is_multiturn=is_multiturn,
            turn_number=turn_number,
        ),
    )

    result = await _aexecute_a2a_delegation_impl(
        endpoint=endpoint,
        auth=auth,
@@ -270,29 +414,15 @@ async def aexecute_a2a_delegation(
        agent_role=agent_role,
        response_model=response_model,
        updates=updates,
        transport_protocol=transport_protocol,
        from_task=from_task,
        from_agent=from_agent,
        skill_id=skill_id,
    )

    agent_card_data: dict[str, Any] = result.get("agent_card") or {}
    crewai_event_bus.emit(
        agent_branch,
        A2ADelegationCompletedEvent(
            status=result["status"],
            result=result.get("result"),
            error=result.get("error"),
            context_id=context_id,
            is_multiturn=is_multiturn,
            endpoint=endpoint,
            a2a_agent_name=result.get("a2a_agent_name"),
            agent_card=agent_card_data,
            provider=agent_card_data.get("provider"),
            metadata=metadata,
            extensions=list(extensions.keys()) if extensions else None,
            from_task=from_task,
            from_agent=from_agent,
        ),
    )
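
# A sketch of observing the delegation lifecycle events emitted above, using
# the event-bus `on` registration pattern from crewAI's events module
# (handler signatures assumed):
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
    A2ADelegationCompletedEvent,
    A2ADelegationStartedEvent,
)


@crewai_event_bus.on(A2ADelegationStartedEvent)
def log_delegation_started(source, event):
    # Fires once per turn, before the remote call is made.
    print(f"delegating to {event.endpoint} (turn {event.turn_number})")


@crewai_event_bus.on(A2ADelegationCompletedEvent)
def log_delegation_completed(source, event):
    # Fires after the remote agent responds or the call fails.
    print(f"delegation finished with status={event.status}")
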

@@ -301,7 +431,6 @@ async def aexecute_a2a_delegation(

async def _aexecute_a2a_delegation_impl(
    endpoint: str,
    transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"],
    auth: AuthScheme | None,
    timeout: int,
    task_description: str,
@@ -319,9 +448,6 @@ async def _aexecute_a2a_delegation_impl(
    agent_role: str | None,
    response_model: type[BaseModel] | None,
    updates: UpdateConfig | None,
    from_task: Any | None = None,
    from_agent: Any | None = None,
    skill_id: str | None = None,
) -> TaskStateResult:
    """Internal async implementation of A2A delegation."""
    if auth:
@@ -354,28 +480,6 @@ async def _aexecute_a2a_delegation_impl(
    if agent_card.name:
        a2a_agent_name = agent_card.name

    agent_card_dict = agent_card.model_dump(exclude_none=True)
    crewai_event_bus.emit(
        agent_branch,
        A2ADelegationStartedEvent(
            endpoint=endpoint,
            task_description=task_description,
            agent_id=agent_id or endpoint,
            context_id=context_id,
            is_multiturn=is_multiturn,
            turn_number=turn_number,
            a2a_agent_name=a2a_agent_name,
            agent_card=agent_card_dict,
            protocol_version=agent_card.protocol_version,
            provider=agent_card_dict.get("provider"),
            skill_id=skill_id,
            metadata=metadata,
            extensions=list(extensions.keys()) if extensions else None,
            from_task=from_task,
            from_agent=from_agent,
        ),
    )

    if turn_number == 1:
        agent_id_for_event = agent_id or endpoint
        crewai_event_bus.emit(
@@ -383,17 +487,7 @@ async def _aexecute_a2a_delegation_impl(
            A2AConversationStartedEvent(
                agent_id=agent_id_for_event,
                endpoint=endpoint,
                context_id=context_id,
                a2a_agent_name=a2a_agent_name,
                agent_card=agent_card_dict,
                protocol_version=agent_card.protocol_version,
                provider=agent_card_dict.get("provider"),
                skill_id=skill_id,
                reference_task_ids=reference_task_ids,
                metadata=metadata,
                extensions=list(extensions.keys()) if extensions else None,
                from_task=from_task,
                from_agent=from_agent,
            ),
        )

@@ -419,10 +513,6 @@ async def _aexecute_a2a_delegation_impl(
        }
    )

    message_metadata = metadata.copy() if metadata else {}
    if skill_id:
        message_metadata["skill_id"] = skill_id

    message = Message(
        role=Role.user,
        message_id=str(uuid.uuid4()),
@@ -430,27 +520,19 @@ async def _aexecute_a2a_delegation_impl(
        context_id=context_id,
        task_id=task_id,
        reference_task_ids=reference_task_ids,
        metadata=message_metadata if message_metadata else None,
        metadata=metadata,
        extensions=extensions,
    )

    transport_protocol = TransportProtocol("JSONRPC")
    new_messages: list[Message] = [*conversation_history, message]
    crewai_event_bus.emit(
        None,
        A2AMessageSentEvent(
            message=message_text,
            turn_number=turn_number,
            context_id=context_id,
            message_id=message.message_id,
            is_multiturn=is_multiturn,
            agent_role=agent_role,
            endpoint=endpoint,
            a2a_agent_name=a2a_agent_name,
            skill_id=skill_id,
            metadata=message_metadata if message_metadata else None,
            extensions=list(extensions.keys()) if extensions else None,
            from_task=from_task,
            from_agent=from_agent,
        ),
    )
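
# Minimal sketch of the outbound A2A Message assembled above; the text,
# context id, and skill id are illustrative, and the Part/TextPart
# construction follows the a2a-sdk types (assumed here):
import uuid

from a2a.types import Message, Part, Role, TextPart

message_metadata = {"skill_id": "report_summarizer"}  # set only when targeting a skill
message = Message(
    role=Role.user,
    message_id=str(uuid.uuid4()),
    parts=[Part(root=TextPart(text="Summarize the attached report"))],
    context_id="ctx-123",
    metadata=message_metadata or None,
)
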
@@ -465,9 +547,6 @@ async def _aexecute_a2a_delegation_impl(
        "task_id": task_id,
        "endpoint": endpoint,
        "agent_branch": agent_branch,
        "a2a_agent_name": a2a_agent_name,
        "from_task": from_task,
        "from_agent": from_agent,
    }

    if isinstance(updates, PollingConfig):
@@ -505,22 +584,19 @@ async def _aexecute_a2a_delegation_impl(
        use_polling=use_polling,
        push_notification_config=push_config_for_client,
    ) as client:
        result = await handler.execute(
        return await handler.execute(
            client=client,
            message=message,
            new_messages=new_messages,
            agent_card=agent_card,
            **handler_kwargs,
        )
    result["a2a_agent_name"] = a2a_agent_name
    result["agent_card"] = agent_card.model_dump(exclude_none=True)
    return result

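
# Sketch of how the update mechanism is selected from `updates` (mirrors the
# PollingConfig branch above; the import path for PollingConfig is assumed):
from crewai.a2a.config import PollingConfig

updates = PollingConfig()  # poll the remote task instead of streaming
use_polling = isinstance(updates, PollingConfig)
# Streaming and polling are mutually exclusive in the client config below:
# ClientConfig(streaming=streaming and not use_polling, polling=use_polling, ...)
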
@asynccontextmanager
async def _create_a2a_client(
    agent_card: AgentCard,
    transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"],
    transport_protocol: TransportProtocol,
    timeout: int,
    headers: MutableMapping[str, str],
    streaming: bool,
@@ -531,18 +607,19 @@ async def _create_a2a_client(
    """Create and configure an A2A client.

    Args:
        agent_card: The A2A agent card.
        transport_protocol: Transport protocol to use.
        timeout: Request timeout in seconds.
        headers: HTTP headers (already with auth applied).
        streaming: Enable streaming responses.
        auth: Optional AuthScheme for client configuration.
        use_polling: Enable polling mode.
        push_notification_config: Optional push notification config.
        agent_card: The A2A agent card
        transport_protocol: Transport protocol to use
        timeout: Request timeout in seconds
        headers: HTTP headers (already with auth applied)
        streaming: Enable streaming responses
        auth: Optional AuthScheme for client configuration
        use_polling: Enable polling mode
        push_notification_config: Optional push notification config to include in requests

    Yields:
        Configured A2A client instance.
        Configured A2A client instance
    """

    async with httpx.AsyncClient(
        timeout=timeout,
        headers=headers,
@@ -563,7 +640,7 @@ async def _create_a2a_client(

        config = ClientConfig(
            httpx_client=httpx_client,
            supported_transports=[transport_protocol],
            supported_transports=[str(transport_protocol.value)],
            streaming=streaming and not use_polling,
            polling=use_polling,
            accepted_output_modes=["application/json"],
@@ -573,3 +650,78 @@ async def _create_a2a_client(
        factory = ClientFactory(config)
        client = factory.create(agent_card)
        yield client
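
# Hypothetical use of the client context manager above; arguments mirror its
# signature with illustrative values, and `agent_card` is assumed already fetched:
async def send_with_client(agent_card, message):
    async with _create_a2a_client(
        agent_card=agent_card,
        transport_protocol=TransportProtocol("JSONRPC"),
        timeout=60,
        headers={},
        streaming=True,
        auth=None,
        use_polling=False,
        push_notification_config=None,
    ) as client:
        ...  # exchange messages through the configured client
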
def create_agent_response_model(agent_ids: tuple[str, ...]) -> type[BaseModel]:
    """Create a dynamic AgentResponse model with Literal types for agent IDs.

    Args:
        agent_ids: Tuple of available A2A agent IDs.

    Returns:
        Dynamically created Pydantic model with a Literal-constrained a2a_ids field.
    """

    DynamicLiteral = create_literals_from_strings(agent_ids)  # noqa: N806

    return create_model(
        "AgentResponse",
        a2a_ids=(
            tuple[DynamicLiteral, ...],  # type: ignore[valid-type]
            Field(
                default_factory=tuple,
                max_length=len(agent_ids),
                description="A2A agent IDs to delegate to.",
            ),
        ),
        message=(
            str,
            Field(
                description="The message content. If is_a2a=true, this is sent to the A2A agent. If is_a2a=false, this is your final answer ending the conversation."
            ),
        ),
        is_a2a=(
            bool,
            Field(
                description="Set to false when the remote agent has answered your question - extract their answer and return it as your final message. Set to true ONLY if you need to ask a NEW, DIFFERENT question. NEVER repeat the same request - if the conversation history shows the agent already answered, set is_a2a=false immediately."
            ),
        ),
        __base__=BaseModel,
    )


def extract_a2a_agent_ids_from_config(
    a2a_config: list[A2AConfig] | A2AConfig | None,
) -> tuple[list[A2AConfig], tuple[str, ...]]:
    """Extract A2A agent IDs from A2A configuration.

    Args:
        a2a_config: A2A configuration.

    Returns:
        Tuple of the A2A agent configs and their endpoint IDs.
    """
    if a2a_config is None:
        return [], ()

    if isinstance(a2a_config, A2AConfig):
        a2a_agents = [a2a_config]
    else:
        a2a_agents = a2a_config
    return a2a_agents, tuple(config.endpoint for config in a2a_agents)


def get_a2a_agents_and_response_model(
    a2a_config: list[A2AConfig] | A2AConfig | None,
) -> tuple[list[A2AConfig], type[BaseModel]]:
    """Get A2A agent configs and the response model.

    Args:
        a2a_config: A2A configuration.

    Returns:
        Tuple of A2A agent configs and the response model.
    """
    a2a_agents, agent_ids = extract_a2a_agent_ids_from_config(a2a_config=a2a_config)

    return a2a_agents, create_agent_response_model(agent_ids)
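
# Sketch: building and validating the dynamic response model (endpoint IDs
# are illustrative):
AgentResponse = create_agent_response_model(
    ("https://a.example.com/card", "https://b.example.com/card")
)
parsed = AgentResponse.model_validate(
    {
        "a2a_ids": ["https://a.example.com/card"],
        "message": "Please summarize the report.",
        "is_a2a": True,
    }
)
assert parsed.is_a2a  # any ID outside the Literal set would fail validation
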
@@ -1 +0,0 @@
"""A2A utility modules for client operations."""
@@ -1,513 +0,0 @@
"""AgentCard utilities for A2A client and server operations."""

from __future__ import annotations

import asyncio
from collections.abc import MutableMapping
from functools import lru_cache
import time
from types import MethodType
from typing import TYPE_CHECKING

from a2a.client.errors import A2AClientHTTPError
from a2a.types import AgentCapabilities, AgentCard, AgentSkill
from aiocache import cached  # type: ignore[import-untyped]
from aiocache.serializers import PickleSerializer  # type: ignore[import-untyped]
import httpx

from crewai.a2a.auth.schemas import APIKeyAuth, HTTPDigestAuth
from crewai.a2a.auth.utils import (
    _auth_store,
    configure_auth_client,
    retry_on_401,
)
from crewai.a2a.config import A2AServerConfig
from crewai.crew import Crew
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
    A2AAgentCardFetchedEvent,
    A2AAuthenticationFailedEvent,
    A2AConnectionErrorEvent,
)


if TYPE_CHECKING:
    from crewai.a2a.auth.schemas import AuthScheme
    from crewai.agent import Agent
    from crewai.task import Task


def _get_server_config(agent: Agent) -> A2AServerConfig | None:
    """Get A2AServerConfig from an agent's a2a configuration.

    Args:
        agent: The Agent instance to check.

    Returns:
        A2AServerConfig if present, None otherwise.
    """
    if agent.a2a is None:
        return None
    if isinstance(agent.a2a, A2AServerConfig):
        return agent.a2a
    if isinstance(agent.a2a, list):
        for config in agent.a2a:
            if isinstance(config, A2AServerConfig):
                return config
    return None


def fetch_agent_card(
    endpoint: str,
    auth: AuthScheme | None = None,
    timeout: int = 30,
    use_cache: bool = True,
    cache_ttl: int = 300,
) -> AgentCard:
    """Fetch AgentCard from an A2A endpoint with optional caching.

    Args:
        endpoint: A2A agent endpoint URL (AgentCard URL).
        auth: Optional AuthScheme for authentication.
        timeout: Request timeout in seconds.
        use_cache: Whether to use caching (default True).
        cache_ttl: Cache TTL in seconds (default 300 = 5 minutes).

    Returns:
        AgentCard object with agent capabilities and skills.

    Raises:
        httpx.HTTPStatusError: If the request fails.
        A2AClientHTTPError: If authentication fails.
    """
    if use_cache:
        if auth:
            auth_data = auth.model_dump_json(
                exclude={
                    "_access_token",
                    "_token_expires_at",
                    "_refresh_token",
                    "_authorization_callback",
                }
            )
            auth_hash = hash((type(auth).__name__, auth_data))
        else:
            auth_hash = 0
        _auth_store[auth_hash] = auth
        ttl_hash = int(time.time() // cache_ttl)
        return _fetch_agent_card_cached(endpoint, auth_hash, timeout, ttl_hash)

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(
            afetch_agent_card(endpoint=endpoint, auth=auth, timeout=timeout)
        )
    finally:
        loop.close()
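
# The ttl_hash argument passed above turns lru_cache into a coarse TTL cache:
# its value only changes once per cache_ttl window, so within a window the
# memoized entry is reused and afterwards the new key forces a refetch. The
# same trick in isolation:
import time
from functools import lru_cache


@lru_cache()
def _lookup(key: str, _ttl_hash: int) -> str:
    return f"fetched:{key}"  # stand-in for the real network fetch


def lookup(key: str, cache_ttl: int = 300) -> str:
    ttl_hash = int(time.time() // cache_ttl)  # constant within one TTL window
    return _lookup(key, ttl_hash)
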
async def afetch_agent_card(
    endpoint: str,
    auth: AuthScheme | None = None,
    timeout: int = 30,
    use_cache: bool = True,
) -> AgentCard:
    """Fetch AgentCard from an A2A endpoint asynchronously.

    Native async implementation. Use this when running in an async context.

    Args:
        endpoint: A2A agent endpoint URL (AgentCard URL).
        auth: Optional AuthScheme for authentication.
        timeout: Request timeout in seconds.
        use_cache: Whether to use caching (default True).

    Returns:
        AgentCard object with agent capabilities and skills.

    Raises:
        httpx.HTTPStatusError: If the request fails.
        A2AClientHTTPError: If authentication fails.
    """
    if use_cache:
        if auth:
            auth_data = auth.model_dump_json(
                exclude={
                    "_access_token",
                    "_token_expires_at",
                    "_refresh_token",
                    "_authorization_callback",
                }
            )
            auth_hash = hash((type(auth).__name__, auth_data))
        else:
            auth_hash = 0
        _auth_store[auth_hash] = auth
        agent_card: AgentCard = await _afetch_agent_card_cached(
            endpoint, auth_hash, timeout
        )
        return agent_card

    return await _afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout)


@lru_cache()
def _fetch_agent_card_cached(
    endpoint: str,
    auth_hash: int,
    timeout: int,
    _ttl_hash: int,
) -> AgentCard:
    """Cached sync version of fetch_agent_card."""
    auth = _auth_store.get(auth_hash)

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(
            _afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout)
        )
    finally:
        loop.close()


@cached(ttl=300, serializer=PickleSerializer())  # type: ignore[untyped-decorator]
async def _afetch_agent_card_cached(
    endpoint: str,
    auth_hash: int,
    timeout: int,
) -> AgentCard:
    """Cached async implementation of AgentCard fetching."""
    auth = _auth_store.get(auth_hash)
    return await _afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout)

async def _afetch_agent_card_impl(
    endpoint: str,
    auth: AuthScheme | None,
    timeout: int,
) -> AgentCard:
    """Internal async implementation of AgentCard fetching."""
    start_time = time.perf_counter()

    if "/.well-known/agent-card.json" in endpoint:
        base_url = endpoint.replace("/.well-known/agent-card.json", "")
        agent_card_path = "/.well-known/agent-card.json"
    else:
        url_parts = endpoint.split("/", 3)
        base_url = f"{url_parts[0]}//{url_parts[2]}"
        agent_card_path = f"/{url_parts[3]}" if len(url_parts) > 3 else "/"

    headers: MutableMapping[str, str] = {}
    if auth:
        async with httpx.AsyncClient(timeout=timeout) as temp_auth_client:
            if isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
                configure_auth_client(auth, temp_auth_client)
            headers = await auth.apply_auth(temp_auth_client, {})

    async with httpx.AsyncClient(timeout=timeout, headers=headers) as temp_client:
        if auth and isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
            configure_auth_client(auth, temp_client)

        agent_card_url = f"{base_url}{agent_card_path}"

        async def _fetch_agent_card_request() -> httpx.Response:
            return await temp_client.get(agent_card_url)

        try:
            response = await retry_on_401(
                request_func=_fetch_agent_card_request,
                auth_scheme=auth,
                client=temp_client,
                headers=temp_client.headers,
                max_retries=2,
            )
            response.raise_for_status()

            agent_card = AgentCard.model_validate(response.json())
            fetch_time_ms = (time.perf_counter() - start_time) * 1000
            agent_card_dict = agent_card.model_dump(exclude_none=True)

            crewai_event_bus.emit(
                None,
                A2AAgentCardFetchedEvent(
                    endpoint=endpoint,
                    a2a_agent_name=agent_card.name,
                    agent_card=agent_card_dict,
                    protocol_version=agent_card.protocol_version,
                    provider=agent_card_dict.get("provider"),
                    cached=False,
                    fetch_time_ms=fetch_time_ms,
                ),
            )

            return agent_card

        except httpx.HTTPStatusError as e:
            elapsed_ms = (time.perf_counter() - start_time) * 1000
            response_body = e.response.text[:1000] if e.response.text else None

            if e.response.status_code == 401:
                error_details = ["Authentication failed"]
                www_auth = e.response.headers.get("WWW-Authenticate")
                if www_auth:
                    error_details.append(f"WWW-Authenticate: {www_auth}")
                if not auth:
                    error_details.append("No auth scheme provided")
                msg = " | ".join(error_details)

                auth_type = type(auth).__name__ if auth else None
                crewai_event_bus.emit(
                    None,
                    A2AAuthenticationFailedEvent(
                        endpoint=endpoint,
                        auth_type=auth_type,
                        error=msg,
                        status_code=401,
                        metadata={
                            "elapsed_ms": elapsed_ms,
                            "response_body": response_body,
                            "www_authenticate": www_auth,
                            "request_url": str(e.request.url),
                        },
                    ),
                )

                raise A2AClientHTTPError(401, msg) from e

            crewai_event_bus.emit(
                None,
                A2AConnectionErrorEvent(
                    endpoint=endpoint,
                    error=str(e),
                    error_type="http_error",
                    status_code=e.response.status_code,
                    operation="fetch_agent_card",
                    metadata={
                        "elapsed_ms": elapsed_ms,
                        "response_body": response_body,
                        "request_url": str(e.request.url),
                    },
                ),
            )
            raise

        except httpx.TimeoutException as e:
            elapsed_ms = (time.perf_counter() - start_time) * 1000
            crewai_event_bus.emit(
                None,
                A2AConnectionErrorEvent(
                    endpoint=endpoint,
                    error=str(e),
                    error_type="timeout",
                    operation="fetch_agent_card",
                    metadata={
                        "elapsed_ms": elapsed_ms,
                        "timeout_config": timeout,
                        "request_url": str(e.request.url) if e.request else None,
                    },
                ),
            )
            raise

        except httpx.ConnectError as e:
            elapsed_ms = (time.perf_counter() - start_time) * 1000
            crewai_event_bus.emit(
                None,
                A2AConnectionErrorEvent(
                    endpoint=endpoint,
                    error=str(e),
                    error_type="connection_error",
                    operation="fetch_agent_card",
                    metadata={
                        "elapsed_ms": elapsed_ms,
                        "request_url": str(e.request.url) if e.request else None,
                    },
                ),
            )
            raise

        except httpx.RequestError as e:
            elapsed_ms = (time.perf_counter() - start_time) * 1000
            crewai_event_bus.emit(
                None,
                A2AConnectionErrorEvent(
                    endpoint=endpoint,
                    error=str(e),
                    error_type="request_error",
                    operation="fetch_agent_card",
                    metadata={
                        "elapsed_ms": elapsed_ms,
                        "request_url": str(e.request.url) if e.request else None,
                    },
                ),
            )
            raise

def _task_to_skill(task: Task) -> AgentSkill:
    """Convert a CrewAI Task to an A2A AgentSkill.

    Args:
        task: The CrewAI Task to convert.

    Returns:
        AgentSkill representing the task's capability.
    """
    task_name = task.name or task.description[:50]
    task_id = task_name.lower().replace(" ", "_")

    tags: list[str] = []
    if task.agent:
        tags.append(task.agent.role.lower().replace(" ", "-"))

    return AgentSkill(
        id=task_id,
        name=task_name,
        description=task.description,
        tags=tags,
        examples=[task.expected_output] if task.expected_output else None,
    )


def _tool_to_skill(tool_name: str, tool_description: str) -> AgentSkill:
    """Convert an Agent's tool to an A2A AgentSkill.

    Args:
        tool_name: Name of the tool.
        tool_description: Description of what the tool does.

    Returns:
        AgentSkill representing the tool's capability.
    """
    tool_id = tool_name.lower().replace(" ", "_")

    return AgentSkill(
        id=tool_id,
        name=tool_name,
        description=tool_description,
        tags=[tool_name.lower().replace(" ", "-")],
    )
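
# Illustrative result of the Task conversion above for a hypothetical task:
from crewai.task import Task

demo_task = Task(
    name="Quarterly summary",
    description="Summarize quarterly sales figures",
    expected_output="A one-page summary",
)
skill = _task_to_skill(demo_task)
assert skill.id == "quarterly_summary"  # name lowercased, spaces -> underscores
assert skill.examples == ["A one-page summary"]
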
def _crew_to_agent_card(crew: Crew, url: str) -> AgentCard:
    """Generate an A2A AgentCard from a Crew instance.

    Args:
        crew: The Crew instance to generate a card for.
        url: The base URL where this crew will be exposed.

    Returns:
        AgentCard describing the crew's capabilities.
    """
    crew_name = getattr(crew, "name", None) or crew.__class__.__name__

    description_parts: list[str] = []
    crew_description = getattr(crew, "description", None)
    if crew_description:
        description_parts.append(crew_description)
    else:
        agent_roles = [agent.role for agent in crew.agents]
        description_parts.append(
            f"A crew of {len(crew.agents)} agents: {', '.join(agent_roles)}"
        )

    skills = [_task_to_skill(task) for task in crew.tasks]

    return AgentCard(
        name=crew_name,
        description=" ".join(description_parts),
        url=url,
        version="1.0.0",
        capabilities=AgentCapabilities(
            streaming=True,
            push_notifications=True,
        ),
        default_input_modes=["text/plain", "application/json"],
        default_output_modes=["text/plain", "application/json"],
        skills=skills,
    )


def _agent_to_agent_card(agent: Agent, url: str) -> AgentCard:
    """Generate an A2A AgentCard from an Agent instance.

    Uses A2AServerConfig values when available, falling back to agent properties.

    Args:
        agent: The Agent instance to generate a card for.
        url: The base URL where this agent will be exposed.

    Returns:
        AgentCard describing the agent's capabilities.
    """
    server_config = _get_server_config(agent) or A2AServerConfig()

    name = server_config.name or agent.role

    description_parts = [agent.goal]
    if agent.backstory:
        description_parts.append(agent.backstory)
    description = server_config.description or " ".join(description_parts)

    skills: list[AgentSkill] = (
        server_config.skills.copy() if server_config.skills else []
    )

    if not skills:
        if agent.tools:
            for tool in agent.tools:
                tool_name = getattr(tool, "name", None) or tool.__class__.__name__
                tool_desc = getattr(tool, "description", None) or f"Tool: {tool_name}"
                skills.append(_tool_to_skill(tool_name, tool_desc))

        if not skills:
            skills.append(
                AgentSkill(
                    id=agent.role.lower().replace(" ", "_"),
                    name=agent.role,
                    description=agent.goal,
                    tags=[agent.role.lower().replace(" ", "-")],
                )
            )

    return AgentCard(
        name=name,
        description=description,
        url=server_config.url or url,
        version=server_config.version,
        capabilities=server_config.capabilities,
        default_input_modes=server_config.default_input_modes,
        default_output_modes=server_config.default_output_modes,
        skills=skills,
        protocol_version=server_config.protocol_version,
        provider=server_config.provider,
        documentation_url=server_config.documentation_url,
        icon_url=server_config.icon_url,
        additional_interfaces=server_config.additional_interfaces,
        security=server_config.security,
        security_schemes=server_config.security_schemes,
        supports_authenticated_extended_card=server_config.supports_authenticated_extended_card,
        signatures=server_config.signatures,
    )

def inject_a2a_server_methods(agent: Agent) -> None:
    """Inject A2A server methods onto an Agent instance.

    Adds a `to_agent_card(url: str) -> AgentCard` method to the agent
    that generates an A2A-compliant AgentCard.

    Only injects if the agent has an A2AServerConfig.

    Args:
        agent: The Agent instance to inject methods onto.
    """
    if _get_server_config(agent) is None:
        return

    def _to_agent_card(self: Agent, url: str) -> AgentCard:
        return _agent_to_agent_card(self, url)

    object.__setattr__(agent, "to_agent_card", MethodType(_to_agent_card, agent))
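
# Sketch: once injected, the agent exposes its card. The agent fields, the
# a2a= keyword, and the A2AServerConfig defaults below are all assumed:
from crewai import Agent

agent = Agent(
    role="Researcher",
    goal="Answer research questions",
    backstory="Veteran analyst",
    a2a=A2AServerConfig(),
)
inject_a2a_server_methods(agent)
card = agent.to_agent_card("https://agents.example.com")
print(card.name, card.url)
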
@@ -1,101 +0,0 @@
"""Response model utilities for A2A agent interactions."""

from __future__ import annotations

from typing import TypeAlias

from pydantic import BaseModel, Field, create_model

from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
from crewai.types.utils import create_literals_from_strings


A2AConfigTypes: TypeAlias = A2AConfig | A2AServerConfig | A2AClientConfig
A2AClientConfigTypes: TypeAlias = A2AConfig | A2AClientConfig


def create_agent_response_model(agent_ids: tuple[str, ...]) -> type[BaseModel] | None:
    """Create a dynamic AgentResponse model with Literal types for agent IDs.

    Args:
        agent_ids: List of available A2A agent IDs.

    Returns:
        Dynamically created Pydantic model with Literal-constrained a2a_ids field,
        or None if agent_ids is empty.
    """
    if not agent_ids:
        return None

    DynamicLiteral = create_literals_from_strings(agent_ids)  # noqa: N806

    return create_model(
        "AgentResponse",
        a2a_ids=(
            tuple[DynamicLiteral, ...],  # type: ignore[valid-type]
            Field(
                default_factory=tuple,
                max_length=len(agent_ids),
                description="A2A agent IDs to delegate to.",
            ),
        ),
        message=(
            str,
            Field(
                description="The message content. If is_a2a=true, this is sent to the A2A agent. If is_a2a=false, this is your final answer ending the conversation."
            ),
        ),
        is_a2a=(
            bool,
            Field(
                description="Set to false when the remote agent has answered your question - extract their answer and return it as your final message. Set to true ONLY if you need to ask a NEW, DIFFERENT question. NEVER repeat the same request - if the conversation history shows the agent already answered, set is_a2a=false immediately."
            ),
        ),
        __base__=BaseModel,
    )


def extract_a2a_agent_ids_from_config(
    a2a_config: list[A2AConfigTypes] | A2AConfigTypes | None,
) -> tuple[list[A2AClientConfigTypes], tuple[str, ...]]:
    """Extract A2A agent IDs from A2A configuration.

    Filters out A2AServerConfig since it doesn't have an endpoint for delegation.

    Args:
        a2a_config: A2A configuration (any type).

    Returns:
        Tuple of client A2A configs list and agent endpoint IDs.
    """
    if a2a_config is None:
        return [], ()

    configs: list[A2AConfigTypes]
    if isinstance(a2a_config, (A2AConfig, A2AClientConfig, A2AServerConfig)):
        configs = [a2a_config]
    else:
        configs = a2a_config

    # Filter to only client configs (those with endpoint)
    client_configs: list[A2AClientConfigTypes] = [
        config for config in configs if isinstance(config, (A2AConfig, A2AClientConfig))
    ]

    return client_configs, tuple(config.endpoint for config in client_configs)


def get_a2a_agents_and_response_model(
    a2a_config: list[A2AConfigTypes] | A2AConfigTypes | None,
) -> tuple[list[A2AClientConfigTypes], type[BaseModel] | None]:
    """Get A2A agent configs and response model.

    Args:
        a2a_config: A2A configuration (any type).

    Returns:
        Tuple of client A2A configs and response model.
    """
    a2a_agents, agent_ids = extract_a2a_agent_ids_from_config(a2a_config=a2a_config)

    return a2a_agents, create_agent_response_model(agent_ids)
@@ -1,399 +0,0 @@
"""A2A task utilities for server-side task management."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
from collections.abc import Callable, Coroutine
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from a2a.server.agent_execution import RequestContext
|
||||
from a2a.server.events import EventQueue
|
||||
from a2a.types import (
|
||||
InternalError,
|
||||
InvalidParamsError,
|
||||
Message,
|
||||
Task as A2ATask,
|
||||
TaskState,
|
||||
TaskStatus,
|
||||
TaskStatusUpdateEvent,
|
||||
)
|
||||
from a2a.utils import new_agent_text_message, new_text_artifact
|
||||
from a2a.utils.errors import ServerError
|
||||
from aiocache import SimpleMemoryCache, caches # type: ignore[import-untyped]
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import (
|
||||
A2AServerTaskCanceledEvent,
|
||||
A2AServerTaskCompletedEvent,
|
||||
A2AServerTaskFailedEvent,
|
||||
A2AServerTaskStartedEvent,
|
||||
)
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent import Agent
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _parse_redis_url(url: str) -> dict[str, Any]:
|
||||
"""Parse a Redis URL into aiocache configuration.
|
||||
|
||||
Args:
|
||||
url: Redis connection URL (e.g., redis://localhost:6379/0).
|
||||
|
||||
Returns:
|
||||
Configuration dict for aiocache.RedisCache.
|
||||
"""
|
||||
|
||||
parsed = urlparse(url)
|
||||
config: dict[str, Any] = {
|
||||
"cache": "aiocache.RedisCache",
|
||||
"endpoint": parsed.hostname or "localhost",
|
||||
"port": parsed.port or 6379,
|
||||
}
|
||||
if parsed.path and parsed.path != "/":
|
||||
try:
|
||||
config["db"] = int(parsed.path.lstrip("/"))
|
||||
except ValueError:
|
||||
pass
|
||||
if parsed.password:
|
||||
config["password"] = parsed.password
|
||||
return config
|
||||
|
||||
|
||||
_redis_url = os.environ.get("REDIS_URL")
|
||||
|
||||
caches.set_config(
|
||||
{
|
||||
"default": _parse_redis_url(_redis_url)
|
||||
if _redis_url
|
||||
else {
|
||||
"cache": "aiocache.SimpleMemoryCache",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
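
# Illustrative behavior of the URL parser above:
# _parse_redis_url("redis://:s3cret@cache.internal:6380/2")
# -> {"cache": "aiocache.RedisCache", "endpoint": "cache.internal",
#     "port": 6380, "db": 2, "password": "s3cret"}
# With no REDIS_URL set, the default cache falls back to SimpleMemoryCache.
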
def cancellable(
    fn: Callable[P, Coroutine[Any, Any, T]],
) -> Callable[P, Coroutine[Any, Any, T]]:
    """Decorator that enables cancellation for A2A task execution.

    Runs a cancellation watcher concurrently with the wrapped function.
    When a cancel event is published, the execution is cancelled.

    Args:
        fn: The async function to wrap.

    Returns:
        Wrapped function with cancellation support.
    """

    @wraps(fn)
    async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        """Wrap function with cancellation monitoring."""
        context: RequestContext | None = None
        for arg in args:
            if isinstance(arg, RequestContext):
                context = arg
                break
        if context is None:
            context = cast(RequestContext | None, kwargs.get("context"))

        if context is None:
            return await fn(*args, **kwargs)

        task_id = context.task_id
        cache = caches.get("default")

        async def poll_for_cancel() -> bool:
            """Poll cache for cancellation flag."""
            while True:
                if await cache.get(f"cancel:{task_id}"):
                    return True
                await asyncio.sleep(0.1)

        async def watch_for_cancel() -> bool:
            """Watch for cancellation events via pub/sub or polling."""
            if isinstance(cache, SimpleMemoryCache):
                return await poll_for_cancel()

            try:
                client = cache.client
                pubsub = client.pubsub()
                await pubsub.subscribe(f"cancel:{task_id}")
                async for message in pubsub.listen():
                    if message["type"] == "message":
                        return True
            except (OSError, ConnectionError) as e:
                logger.warning("Cancel watcher error for task_id=%s: %s", task_id, e)
                return await poll_for_cancel()
            return False

        execute_task = asyncio.create_task(fn(*args, **kwargs))
        cancel_watch = asyncio.create_task(watch_for_cancel())

        try:
            done, _ = await asyncio.wait(
                [execute_task, cancel_watch],
                return_when=asyncio.FIRST_COMPLETED,
            )

            if cancel_watch in done:
                execute_task.cancel()
                try:
                    await execute_task
                except asyncio.CancelledError:
                    pass
                raise asyncio.CancelledError(f"Task {task_id} was cancelled")
            cancel_watch.cancel()
            return execute_task.result()
        finally:
            await cache.delete(f"cancel:{task_id}")

    return wrapper

@cancellable
async def execute(
    agent: Agent,
    context: RequestContext,
    event_queue: EventQueue,
) -> None:
    """Execute an A2A task using a CrewAI agent.

    Args:
        agent: The CrewAI agent to execute the task.
        context: The A2A request context containing the user's message.
        event_queue: The event queue for sending responses back.

    TODOs:
        * Implement structured-output and file-input ingestion; both depend on
          `file_inputs` on `crewai.task.Task`. Pass the two values below to Task;
          both helpers live in `a2a.utils.parts`:
        * structured outputs ingestion: `structured_inputs = get_data_parts(parts=context.message.parts)`
        * file inputs ingestion: `file_inputs = get_file_parts(parts=context.message.parts)`
    """

    user_message = context.get_user_input()
    task_id = context.task_id
    context_id = context.context_id
    if task_id is None or context_id is None:
        msg = "task_id and context_id are required"
        crewai_event_bus.emit(
            agent,
            A2AServerTaskFailedEvent(
                task_id="",
                context_id="",
                error=msg,
                from_agent=agent,
            ),
        )
        raise ServerError(InvalidParamsError(message=msg)) from None

    task = Task(
        description=user_message,
        expected_output="Response to the user's request",
        agent=agent,
    )

    crewai_event_bus.emit(
        agent,
        A2AServerTaskStartedEvent(
            task_id=task_id,
            context_id=context_id,
            from_task=task,
            from_agent=agent,
        ),
    )

    try:
        result = await agent.aexecute_task(task=task, tools=agent.tools)
        result_str = str(result)
        history: list[Message] = [context.message] if context.message else []
        history.append(new_agent_text_message(result_str, context_id, task_id))
        await event_queue.enqueue_event(
            A2ATask(
                id=task_id,
                context_id=context_id,
                status=TaskStatus(state=TaskState.input_required),
                artifacts=[new_text_artifact(result_str, f"result_{task_id}")],
                history=history,
            )
        )
        crewai_event_bus.emit(
            agent,
            A2AServerTaskCompletedEvent(
                task_id=task_id,
                context_id=context_id,
                result=str(result),
                from_task=task,
                from_agent=agent,
            ),
        )
    except asyncio.CancelledError:
        crewai_event_bus.emit(
            agent,
            A2AServerTaskCanceledEvent(
                task_id=task_id,
                context_id=context_id,
                from_task=task,
                from_agent=agent,
            ),
        )
        raise
    except Exception as e:
        crewai_event_bus.emit(
            agent,
            A2AServerTaskFailedEvent(
                task_id=task_id,
                context_id=context_id,
                error=str(e),
                from_task=task,
                from_agent=agent,
            ),
        )
        raise ServerError(
            error=InternalError(message=f"Task execution failed: {e}")
        ) from e

async def cancel(
    context: RequestContext,
    event_queue: EventQueue,
) -> A2ATask | None:
    """Cancel an A2A task.

    Publishes a cancel event that the cancellable decorator listens for.

    Args:
        context: The A2A request context containing task information.
        event_queue: The event queue for sending the cancellation status.

    Returns:
        The canceled task with updated status.
    """
    task_id = context.task_id
    context_id = context.context_id
    if task_id is None or context_id is None:
        raise ServerError(InvalidParamsError(message="task_id and context_id required"))

    if context.current_task and context.current_task.status.state in (
        TaskState.completed,
        TaskState.failed,
        TaskState.canceled,
    ):
        return context.current_task

    cache = caches.get("default")

    await cache.set(f"cancel:{task_id}", True, ttl=3600)
    if not isinstance(cache, SimpleMemoryCache):
        await cache.client.publish(f"cancel:{task_id}", "cancel")

    await event_queue.enqueue_event(
        TaskStatusUpdateEvent(
            task_id=task_id,
            context_id=context_id,
            status=TaskStatus(state=TaskState.canceled),
            final=True,
        )
    )

    if context.current_task:
        context.current_task.status = TaskStatus(state=TaskState.canceled)
        return context.current_task
    return None

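
# Sketch of the cancellation handshake: `cancel` above sets the flag (and, on
# Redis, publishes), while the `cancellable` watcher sees it and aborts the
# running task. Triggering a cancel out-of-band might look like:
from aiocache import caches


async def request_cancel(task_id: str) -> None:
    cache = caches.get("default")
    await cache.set(f"cancel:{task_id}", True, ttl=3600)  # the watcher polls this key
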
def list_tasks(
    tasks: list[A2ATask],
    context_id: str | None = None,
    status: TaskState | None = None,
    status_timestamp_after: datetime | None = None,
    page_size: int = 50,
    page_token: str | None = None,
    history_length: int | None = None,
    include_artifacts: bool = False,
) -> tuple[list[A2ATask], str | None, int]:
    """Filter and paginate A2A tasks.

    Provides filtering by context, status, and timestamp, along with
    cursor-based pagination. This is a pure utility function that operates
    on an in-memory list of tasks - storage retrieval is handled separately.

    Args:
        tasks: All tasks to filter.
        context_id: Filter by context ID to get tasks in a conversation.
        status: Filter by task state (e.g., completed, working).
        status_timestamp_after: Filter to tasks updated after this time.
        page_size: Maximum tasks per page (default 50).
        page_token: Base64-encoded cursor from previous response.
        history_length: Limit history messages per task (None = full history).
        include_artifacts: Whether to include task artifacts (default False).

    Returns:
        Tuple of (filtered_tasks, next_page_token, total_count).
        - filtered_tasks: Tasks matching filters, paginated and trimmed.
        - next_page_token: Token for next page, or None if no more pages.
        - total_count: Total number of tasks matching filters (before pagination).
    """
    filtered: list[A2ATask] = []
    for task in tasks:
        if context_id and task.context_id != context_id:
            continue
        if status and task.status.state != status:
            continue
        if status_timestamp_after and task.status.timestamp:
            ts = datetime.fromisoformat(task.status.timestamp.replace("Z", "+00:00"))
            if ts <= status_timestamp_after:
                continue
        filtered.append(task)

    def get_timestamp(t: A2ATask) -> datetime:
        """Extract timestamp from task status for sorting."""
        if t.status.timestamp is None:
            return datetime.min
        return datetime.fromisoformat(t.status.timestamp.replace("Z", "+00:00"))

    filtered.sort(key=get_timestamp, reverse=True)
    total = len(filtered)

    start = 0
    if page_token:
        try:
            cursor_id = base64.b64decode(page_token).decode()
            for idx, task in enumerate(filtered):
                if task.id == cursor_id:
                    start = idx + 1
                    break
        except (ValueError, UnicodeDecodeError):
            pass

    page = filtered[start : start + page_size]

    result: list[A2ATask] = []
    for task in page:
        task = task.model_copy(deep=True)
        if history_length is not None and task.history:
            task.history = task.history[-history_length:]
        if not include_artifacts:
            task.artifacts = None
        result.append(task)

    next_token: str | None = None
    if result and len(result) == page_size:
        next_token = base64.b64encode(result[-1].id.encode()).decode()

    return result, next_token, total
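
# Sketch: walking every page with the base64 cursor returned above
# (the tasks list is assumed to come from storage):
def iter_all_tasks(tasks):
    token = None
    while True:
        page, token, _total = list_tasks(tasks, page_size=50, page_token=token)
        yield from page
        if token is None:
            break
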
@@ -6,17 +6,16 @@ Wraps agent classes with A2A delegation capabilities.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable, Coroutine, Mapping
|
||||
from collections.abc import Callable, Coroutine
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from functools import wraps
|
||||
import json
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from a2a.types import Role, TaskState
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from crewai.a2a.config import A2AClientConfig, A2AConfig
|
||||
from crewai.a2a.config import A2AConfig
|
||||
from crewai.a2a.extensions.base import ExtensionRegistry
|
||||
from crewai.a2a.task_helpers import TaskStateResult
|
||||
from crewai.a2a.templates import (
|
||||
@@ -27,16 +26,13 @@ from crewai.a2a.templates import (
|
||||
UNAVAILABLE_AGENTS_NOTICE_TEMPLATE,
|
||||
)
|
||||
from crewai.a2a.types import AgentResponseProtocol
|
||||
from crewai.a2a.utils.agent_card import (
|
||||
afetch_agent_card,
|
||||
fetch_agent_card,
|
||||
inject_a2a_server_methods,
|
||||
)
|
||||
from crewai.a2a.utils.delegation import (
|
||||
from crewai.a2a.utils import (
|
||||
aexecute_a2a_delegation,
|
||||
afetch_agent_card,
|
||||
execute_a2a_delegation,
|
||||
fetch_agent_card,
|
||||
get_a2a_agents_and_response_model,
|
||||
)
|
||||
from crewai.a2a.utils.response_model import get_a2a_agents_and_response_model
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.a2a_events import (
|
||||
A2AConversationCompletedEvent,
|
||||
@@ -126,12 +122,10 @@ def wrap_agent_with_a2a_instance(
|
||||
agent, "aexecute_task", MethodType(aexecute_task_with_a2a, agent)
|
||||
)
|
||||
|
||||
inject_a2a_server_methods(agent)
|
||||
|
||||
|
||||
def _fetch_card_from_config(
|
||||
config: A2AConfig | A2AClientConfig,
|
||||
) -> tuple[A2AConfig | A2AClientConfig, AgentCard | Exception]:
|
||||
config: A2AConfig,
|
||||
) -> tuple[A2AConfig, AgentCard | Exception]:
|
||||
"""Fetch agent card from A2A config.
|
||||
|
||||
Args:
|
||||
@@ -152,7 +146,7 @@ def _fetch_card_from_config(
|
||||
|
||||
|
||||
def _fetch_agent_cards_concurrently(
|
||||
a2a_agents: list[A2AConfig | A2AClientConfig],
|
||||
a2a_agents: list[A2AConfig],
|
||||
) -> tuple[dict[str, AgentCard], dict[str, str]]:
|
||||
"""Fetch agent cards concurrently for multiple A2A agents.
|
||||
|
||||
@@ -187,10 +181,10 @@ def _fetch_agent_cards_concurrently(
|
||||
|
||||
def _execute_task_with_a2a(
|
||||
self: Agent,
|
||||
a2a_agents: list[A2AConfig | A2AClientConfig],
|
||||
a2a_agents: list[A2AConfig],
|
||||
original_fn: Callable[..., str],
|
||||
task: Task,
|
||||
agent_response_model: type[BaseModel] | None,
|
||||
agent_response_model: type[BaseModel],
|
||||
context: str | None,
|
||||
tools: list[BaseTool] | None,
|
||||
extension_registry: ExtensionRegistry,
|
||||
@@ -276,9 +270,9 @@ def _execute_task_with_a2a(
|
||||
|
||||
|
||||
def _augment_prompt_with_a2a(
|
||||
a2a_agents: list[A2AConfig | A2AClientConfig],
|
||||
a2a_agents: list[A2AConfig],
|
||||
task_description: str,
|
||||
agent_cards: Mapping[str, AgentCard | dict[str, Any]],
|
||||
agent_cards: dict[str, AgentCard],
|
||||
conversation_history: list[Message] | None = None,
|
||||
turn_num: int = 0,
|
||||
max_turns: int | None = None,
|
||||
@@ -310,15 +304,7 @@ def _augment_prompt_with_a2a(
|
||||
for config in a2a_agents:
|
||||
if config.endpoint in agent_cards:
|
||||
card = agent_cards[config.endpoint]
|
||||
if isinstance(card, dict):
|
||||
filtered = {
|
||||
k: v
|
||||
for k, v in card.items()
|
||||
if k in {"description", "url", "skills"} and v is not None
|
||||
}
|
||||
agents_text += f"\n{json.dumps(filtered, indent=2)}\n"
|
||||
else:
|
||||
agents_text += f"\n{card.model_dump_json(indent=2, exclude_none=True, include={'description', 'url', 'skills'})}\n"
|
||||
agents_text += f"\n{card.model_dump_json(indent=2, exclude_none=True, include={'description', 'url', 'skills'})}\n"
|
||||
|
||||
failed_agents = failed_agents or {}
|
||||
if failed_agents:
|
||||
@@ -386,7 +372,7 @@ IMPORTANT: You have the ability to delegate this task to remote A2A agents.
|
||||
|
||||
|
||||
def _parse_agent_response(
|
||||
raw_result: str | dict[str, Any], agent_response_model: type[BaseModel] | None
|
||||
raw_result: str | dict[str, Any], agent_response_model: type[BaseModel]
|
||||
) -> BaseModel | str | dict[str, Any]:
|
||||
"""Parse LLM output as AgentResponse or return raw agent response."""
|
||||
if agent_response_model:
|
||||
@@ -403,11 +389,6 @@ def _parse_agent_response(
|
||||
def _handle_max_turns_exceeded(
|
||||
conversation_history: list[Message],
|
||||
max_turns: int,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
endpoint: str | None = None,
|
||||
a2a_agent_name: str | None = None,
|
||||
agent_card: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""Handle the case when max turns is exceeded.
|
||||
|
||||
@@ -435,11 +416,6 @@ def _handle_max_turns_exceeded(
|
||||
final_result=final_message,
|
||||
error=None,
|
||||
total_turns=max_turns,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
agent_card=agent_card,
|
||||
),
|
||||
)
|
||||
return final_message
|
||||
@@ -451,11 +427,6 @@ def _handle_max_turns_exceeded(
|
||||
final_result=None,
|
||||
error=f"Conversation exceeded maximum turns ({max_turns})",
|
||||
total_turns=max_turns,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
agent_card=agent_card,
|
||||
),
|
||||
)
|
||||
raise Exception(f"A2A conversation exceeded maximum turns ({max_turns})")
|
||||
@@ -466,12 +437,7 @@ def _process_response_result(
|
||||
disable_structured_output: bool,
|
||||
turn_num: int,
|
||||
agent_role: str,
|
||||
agent_response_model: type[BaseModel] | None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
endpoint: str | None = None,
|
||||
a2a_agent_name: str | None = None,
|
||||
agent_card: dict[str, Any] | None = None,
|
||||
agent_response_model: type[BaseModel],
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""Process LLM response and determine next action.
|
||||
|
||||
@@ -490,10 +456,6 @@ def _process_response_result(
|
||||
turn_number=final_turn_number,
|
||||
is_multiturn=True,
|
||||
agent_role=agent_role,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
@@ -503,11 +465,6 @@ def _process_response_result(
|
||||
final_result=result_text,
|
||||
error=None,
|
||||
total_turns=final_turn_number,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
agent_card=agent_card,
|
||||
),
|
||||
)
|
||||
return result_text, None
|
||||
@@ -528,10 +485,6 @@ def _process_response_result(
|
||||
turn_number=final_turn_number,
|
||||
is_multiturn=True,
|
||||
agent_role=agent_role,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
@@ -541,11 +494,6 @@ def _process_response_result(
|
||||
final_result=str(llm_response.message),
|
||||
error=None,
|
||||
total_turns=final_turn_number,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
endpoint=endpoint,
|
||||
a2a_agent_name=a2a_agent_name,
|
||||
agent_card=agent_card,
|
||||
),
|
||||
)
|
||||
return str(llm_response.message), None
|
||||
@@ -557,15 +505,13 @@ def _process_response_result(
def _prepare_agent_cards_dict(
a2a_result: TaskStateResult,
agent_id: str,
agent_cards: Mapping[str, AgentCard | dict[str, Any]] | None,
) -> dict[str, AgentCard | dict[str, Any]]:
agent_cards: dict[str, AgentCard] | None,
) -> dict[str, AgentCard]:
"""Prepare agent cards dictionary from result and existing cards.

Shared logic for both sync and async response handlers.
"""
agent_cards_dict: dict[str, AgentCard | dict[str, Any]] = (
dict(agent_cards) if agent_cards else {}
)
agent_cards_dict = agent_cards or {}
if "agent_card" in a2a_result and agent_id not in agent_cards_dict:
agent_cards_dict[agent_id] = a2a_result["agent_card"]
return agent_cards_dict
@@ -577,11 +523,11 @@ def _prepare_delegation_context(
task: Task,
original_task_description: str | None,
) -> tuple[
list[A2AConfig | A2AClientConfig],
type[BaseModel] | None,
list[A2AConfig],
type[BaseModel],
str,
str,
A2AConfig | A2AClientConfig,
A2AConfig,
str | None,
str | None,
dict[str, Any] | None,
@@ -645,13 +591,8 @@ def _handle_task_completion(
task: Task,
task_id_config: str | None,
reference_task_ids: list[str],
agent_config: A2AConfig | A2AClientConfig,
agent_config: A2AConfig,
turn_num: int,
from_task: Any | None = None,
from_agent: Any | None = None,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
agent_card: dict[str, Any] | None = None,
) -> tuple[str | None, str | None, list[str]]:
"""Handle task completion state including reference task updates.

@@ -678,11 +619,6 @@ def _handle_task_completion(
final_result=result_text,
error=None,
total_turns=final_turn_number,
from_task=from_task,
from_agent=from_agent,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
agent_card=agent_card,
),
)
return str(result_text), task_id_config, reference_task_ids
@@ -695,7 +631,7 @@ def _handle_agent_response_and_continue(
a2a_result: TaskStateResult,
agent_id: str,
agent_cards: dict[str, AgentCard] | None,
a2a_agents: list[A2AConfig | A2AClientConfig],
a2a_agents: list[A2AConfig],
original_task_description: str,
conversation_history: list[Message],
turn_num: int,
@@ -704,11 +640,8 @@ def _handle_agent_response_and_continue(
original_fn: Callable[..., str],
context: str | None,
tools: list[BaseTool] | None,
agent_response_model: type[BaseModel] | None,
agent_response_model: type[BaseModel],
remote_task_completed: bool = False,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
agent_card: dict[str, Any] | None = None,
) -> tuple[str | None, str | None]:
"""Handle A2A result and get CrewAI agent's response.

@@ -760,11 +693,6 @@ def _handle_agent_response_and_continue(
turn_num=turn_num,
agent_role=self.role,
agent_response_model=agent_response_model,
from_task=task,
from_agent=self,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
agent_card=agent_card,
)

@@ -817,12 +745,6 @@ def _delegate_to_a2a(

conversation_history: list[Message] = []

current_agent_card = agent_cards.get(agent_id) if agent_cards else None
current_agent_card_dict = (
current_agent_card.model_dump() if current_agent_card else None
)
current_a2a_agent_name = current_agent_card.name if current_agent_card else None

try:
for turn_num in range(max_turns):
console_formatter = getattr(crewai_event_bus, "_console", None)
@@ -849,9 +771,6 @@ def _delegate_to_a2a(
response_model=agent_config.response_model,
turn_number=turn_num + 1,
updates=agent_config.updates,
transport_protocol=agent_config.transport_protocol,
from_task=task,
from_agent=self,
)

conversation_history = a2a_result.get("history", [])
@@ -872,11 +791,6 @@ def _delegate_to_a2a(
reference_task_ids,
agent_config,
turn_num,
from_task=task,
from_agent=self,
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
)
)
if trusted_result is not None:
@@ -898,9 +812,6 @@ def _delegate_to_a2a(
tools=tools,
agent_response_model=agent_response_model,
remote_task_completed=(a2a_result["status"] == TaskState.completed),
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
)

if final_result is not None:
@@ -929,9 +840,6 @@ def _delegate_to_a2a(
tools=tools,
agent_response_model=agent_response_model,
remote_task_completed=False,
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
)

if final_result is not None:
@@ -948,32 +856,19 @@ def _delegate_to_a2a(
final_result=None,
error=error_msg,
total_turns=turn_num + 1,
from_task=task,
from_agent=self,
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
),
)
return f"A2A delegation failed: {error_msg}"

return _handle_max_turns_exceeded(
conversation_history,
max_turns,
from_task=task,
from_agent=self,
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
)
return _handle_max_turns_exceeded(conversation_history, max_turns)

finally:
task.description = original_task_description


async def _afetch_card_from_config(
config: A2AConfig | A2AClientConfig,
) -> tuple[A2AConfig | A2AClientConfig, AgentCard | Exception]:
config: A2AConfig,
) -> tuple[A2AConfig, AgentCard | Exception]:
"""Fetch agent card from A2A config asynchronously."""
try:
card = await afetch_agent_card(
@@ -987,7 +882,7 @@ async def _afetch_card_from_config(


async def _afetch_agent_cards_concurrently(
a2a_agents: list[A2AConfig | A2AClientConfig],
a2a_agents: list[A2AConfig],
) -> tuple[dict[str, AgentCard], dict[str, str]]:
"""Fetch agent cards concurrently for multiple A2A agents using asyncio."""
agent_cards: dict[str, AgentCard] = {}
@@ -1012,10 +907,10 @@ async def _afetch_agent_cards_concurrently(

async def _aexecute_task_with_a2a(
self: Agent,
a2a_agents: list[A2AConfig | A2AClientConfig],
a2a_agents: list[A2AConfig],
original_fn: Callable[..., Coroutine[Any, Any, str]],
task: Task,
agent_response_model: type[BaseModel] | None,
agent_response_model: type[BaseModel],
context: str | None,
tools: list[BaseTool] | None,
extension_registry: ExtensionRegistry,
@@ -1091,7 +986,7 @@ async def _ahandle_agent_response_and_continue(
a2a_result: TaskStateResult,
agent_id: str,
agent_cards: dict[str, AgentCard] | None,
a2a_agents: list[A2AConfig | A2AClientConfig],
a2a_agents: list[A2AConfig],
original_task_description: str,
conversation_history: list[Message],
turn_num: int,
@@ -1100,11 +995,8 @@ async def _ahandle_agent_response_and_continue(
original_fn: Callable[..., Coroutine[Any, Any, str]],
context: str | None,
tools: list[BaseTool] | None,
agent_response_model: type[BaseModel] | None,
agent_response_model: type[BaseModel],
remote_task_completed: bool = False,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
agent_card: dict[str, Any] | None = None,
) -> tuple[str | None, str | None]:
"""Async version of _handle_agent_response_and_continue."""
agent_cards_dict = _prepare_agent_cards_dict(a2a_result, agent_id, agent_cards)
@@ -1134,11 +1026,6 @@ async def _ahandle_agent_response_and_continue(
turn_num=turn_num,
agent_role=self.role,
agent_response_model=agent_response_model,
from_task=task,
from_agent=self,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
agent_card=agent_card,
)


@@ -1173,12 +1060,6 @@ async def _adelegate_to_a2a(

conversation_history: list[Message] = []

current_agent_card = agent_cards.get(agent_id) if agent_cards else None
current_agent_card_dict = (
current_agent_card.model_dump() if current_agent_card else None
)
current_a2a_agent_name = current_agent_card.name if current_agent_card else None

try:
for turn_num in range(max_turns):
console_formatter = getattr(crewai_event_bus, "_console", None)
@@ -1204,10 +1085,7 @@ async def _adelegate_to_a2a(
agent_branch=agent_branch,
response_model=agent_config.response_model,
turn_number=turn_num + 1,
transport_protocol=agent_config.transport_protocol,
updates=agent_config.updates,
from_task=task,
from_agent=self,
)

conversation_history = a2a_result.get("history", [])
@@ -1228,11 +1106,6 @@ async def _adelegate_to_a2a(
reference_task_ids,
agent_config,
turn_num,
from_task=task,
from_agent=self,
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
)
)
if trusted_result is not None:
@@ -1254,9 +1127,6 @@ async def _adelegate_to_a2a(
tools=tools,
agent_response_model=agent_response_model,
remote_task_completed=(a2a_result["status"] == TaskState.completed),
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
)

if final_result is not None:
@@ -1284,9 +1154,6 @@ async def _adelegate_to_a2a(
context=context,
tools=tools,
agent_response_model=agent_response_model,
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
)

if final_result is not None:
@@ -1303,24 +1170,11 @@ async def _adelegate_to_a2a(
final_result=None,
error=error_msg,
total_turns=turn_num + 1,
from_task=task,
from_agent=self,
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
),
)
return f"A2A delegation failed: {error_msg}"

return _handle_max_turns_exceeded(
conversation_history,
max_turns,
from_task=task,
from_agent=self,
endpoint=agent_config.endpoint,
a2a_agent_name=current_a2a_agent_name,
agent_card=current_agent_card_dict,
)
return _handle_max_turns_exceeded(conversation_history, max_turns)

finally:
task.description = original_task_description

@@ -1,7 +1,7 @@
from __future__ import annotations

import asyncio
from collections.abc import Callable, Coroutine, Sequence
from collections.abc import Callable, Sequence
import shutil
import subprocess
import time
@@ -17,6 +17,7 @@ from urllib.parse import urlparse
from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator
from typing_extensions import Self

from crewai.a2a.config import A2AConfig
from crewai.agent.utils import (
ahandle_knowledge_retrieval,
apply_training_data,
@@ -34,11 +35,6 @@ from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.cache.cache_handler import CacheHandler
from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.agent_events import (
LiteAgentExecutionCompletedEvent,
LiteAgentExecutionErrorEvent,
LiteAgentExecutionStartedEvent,
)
from crewai.events.types.knowledge_events import (
KnowledgeQueryCompletedEvent,
KnowledgeQueryFailedEvent,
@@ -48,10 +44,10 @@ from crewai.events.types.memory_events import (
MemoryRetrievalCompletedEvent,
MemoryRetrievalStartedEvent,
)
from crewai.experimental.agent_executor import AgentExecutor
from crewai.experimental.crew_agent_executor_flow import CrewAgentExecutorFlow
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.lite_agent_output import LiteAgentOutput
from crewai.lite_agent import LiteAgent
from crewai.llms.base_llm import BaseLLM
from crewai.mcp import (
MCPClient,
@@ -69,37 +65,26 @@ from crewai.security.fingerprint import Fingerprint
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.utilities.agent_utils import (
get_tool_names,
is_inside_event_loop,
load_agent_from_repository,
parse_tools,
render_text_description_and_args,
)
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.converter import Converter, ConverterError
from crewai.utilities.guardrail import process_guardrail
from crewai.utilities.converter import Converter
from crewai.utilities.guardrail_types import GuardrailType
from crewai.utilities.llm_utils import create_llm
from crewai.utilities.prompts import Prompts, StandardPromptResult, SystemPromptResult
from crewai.utilities.pydantic_schema_utils import generate_model_description
from crewai.utilities.prompts import Prompts
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.training_handler import CrewTrainingHandler


try:
from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
except ImportError:
A2AClientConfig = Any # type: ignore[assignment,misc]
A2AConfig = Any # type: ignore[assignment,misc]
A2AServerConfig = Any # type: ignore[assignment,misc]


if TYPE_CHECKING:
from crewai_tools import CodeInterpreterTool

from crewai.agents.agent_builder.base_agent import PlatformAppOrAction
from crewai.lite_agent_output import LiteAgentOutput
from crewai.task import Task
from crewai.tools.base_tool import BaseTool
from crewai.tools.structured_tool import CrewStructuredTool
from crewai.utilities.types import LLMMessage


@@ -121,7 +106,7 @@ class Agent(BaseAgent):
The agent can also have memory, can operate in verbose mode, and can delegate tasks to other agents.

Attributes:
agent_executor: An instance of the CrewAgentExecutor or AgentExecutor class.
agent_executor: An instance of the CrewAgentExecutor or CrewAgentExecutorFlow class.
role: The role of the agent.
goal: The objective of the agent.
backstory: The backstory of the agent.
@@ -184,8 +169,7 @@ class Agent(BaseAgent):
)
multimodal: bool = Field(
default=False,
deprecated=True,
description="[DEPRECATED, will be removed in v2.0 - pass files natively.] Whether the agent is multimodal.",
description="Whether the agent is multimodal.",
)
inject_date: bool = Field(
default=False,
@@ -234,22 +218,13 @@ class Agent(BaseAgent):
guardrail_max_retries: int = Field(
default=3, description="Maximum number of retries when guardrail fails"
)
a2a: (
list[A2AConfig | A2AServerConfig | A2AClientConfig]
| A2AConfig
| A2AServerConfig
| A2AClientConfig
| None
) = Field(
a2a: list[A2AConfig] | A2AConfig | None = Field(
default=None,
description="""
A2A (Agent-to-Agent) configuration for delegating tasks to remote agents.
Can be a single A2AConfig/A2AClientConfig/A2AServerConfig, or a list of any number of A2AConfig/A2AClientConfig with a single A2AServerConfig.
""",
description="A2A (Agent-to-Agent) configuration for delegating tasks to remote agents. Can be a single A2AConfig or a dict mapping agent IDs to configs.",
)
executor_class: type[CrewAgentExecutor] | type[AgentExecutor] = Field(
executor_class: type[CrewAgentExecutor] | type[CrewAgentExecutorFlow] = Field(
default=CrewAgentExecutor,
description="Class to use for the agent executor. Defaults to CrewAgentExecutor, can optionally use AgentExecutor.",
description="Class to use for the agent executor. Defaults to CrewAgentExecutor, can optionally use CrewAgentExecutorFlow.",
)

@model_validator(mode="before")
@@ -758,7 +733,7 @@ class Agent(BaseAgent):
if self.agent_executor is not None:
self._update_executor_parameters(
task=task,
tools=parsed_tools, # type: ignore[arg-type]
tools=parsed_tools,
raw_tools=raw_tools,
prompt=prompt,
stop_words=stop_words,
@@ -767,7 +742,7 @@ class Agent(BaseAgent):
else:
self.agent_executor = self.executor_class(
llm=cast(BaseLLM, self.llm),
task=task, # type: ignore[arg-type]
task=task,
i18n=self.i18n,
agent=self,
crew=self.crew,
@@ -790,11 +765,11 @@ class Agent(BaseAgent):
def _update_executor_parameters(
self,
task: Task | None,
tools: list[BaseTool],
tools: list,
raw_tools: list[BaseTool],
prompt: SystemPromptResult | StandardPromptResult,
prompt: dict,
stop_words: list[str],
rpm_limit_fn: Callable | None, # type: ignore[type-arg]
rpm_limit_fn: Callable | None,
) -> None:
"""Update executor parameters without recreating instance.

@@ -1592,25 +1567,26 @@ class Agent(BaseAgent):
)
return None

def _prepare_kickoff(
def kickoff(
self,
messages: str | list[LLMMessage],
response_format: type[Any] | None = None,
) -> tuple[AgentExecutor, dict[str, str], dict[str, Any], list[CrewStructuredTool]]:
"""Prepare common setup for kickoff execution.
) -> LiteAgentOutput:
"""
Execute the agent with the given messages using a LiteAgent instance.

This method handles all the common preparation logic shared between
kickoff() and kickoff_async(), including tool processing, prompt building,
executor creation, and input formatting.
This method is useful when you want to use the Agent configuration but
with the simpler and more direct execution flow of LiteAgent.

Args:
messages: Either a string query or a list of message dictionaries.
If a string is provided, it will be converted to a user message.
If a list is provided, each dict should have 'role' and 'content' keys.
response_format: Optional Pydantic model for structured output.

Returns:
Tuple of (executor, inputs, agent_info, parsed_tools) ready for execution.
LiteAgentOutput: The result of the agent execution.
"""
# Process platform apps and MCP tools
if self.apps:
platform_tools = self.get_platform_tools(self.apps)
if platform_tools and self.tools is not None:
@@ -1620,359 +1596,25 @@ class Agent(BaseAgent):
if mcps and self.tools is not None:
self.tools.extend(mcps)

# Prepare tools
raw_tools: list[BaseTool] = self.tools or []
parsed_tools = parse_tools(raw_tools)

# Build agent_info for backward-compatible event emission
agent_info = {
"id": self.id,
"role": self.role,
"goal": self.goal,
"backstory": self.backstory,
"tools": raw_tools,
"verbose": self.verbose,
}

# Build prompt for standalone execution
prompt = Prompts(
agent=self,
has_tools=len(raw_tools) > 0,
i18n=self.i18n,
use_system_prompt=self.use_system_prompt,
system_template=self.system_template,
prompt_template=self.prompt_template,
response_template=self.response_template,
).task_execution()

# Prepare stop words
stop_words = [self.i18n.slice("observation")]
if self.response_template:
stop_words.append(
self.response_template.split("{{ .Response }}")[1].strip()
)

# Get RPM limit function
rpm_limit_fn = (
self._rpm_controller.check_or_wait if self._rpm_controller else None
)

# Create the executor for standalone mode (no crew, no task)
executor = AgentExecutor(
task=None,
crew=None,
llm=cast(BaseLLM, self.llm),
agent=self,
prompt=prompt,
max_iter=self.max_iter,
tools=parsed_tools,
tools_names=get_tool_names(parsed_tools),
stop_words=stop_words,
tools_description=render_text_description_and_args(parsed_tools),
tools_handler=self.tools_handler,
original_tools=raw_tools,
step_callback=self.step_callback,
function_calling_llm=self.function_calling_llm,
lite_agent = LiteAgent(
id=self.id,
role=self.role,
goal=self.goal,
backstory=self.backstory,
llm=self.llm,
tools=self.tools or [],
max_iterations=self.max_iter,
max_execution_time=self.max_execution_time,
respect_context_window=self.respect_context_window,
request_within_rpm_limit=rpm_limit_fn,
callbacks=[TokenCalcHandler(self._token_process)],
response_model=response_format,
verbose=self.verbose,
response_format=response_format,
i18n=self.i18n,
original_agent=self,
guardrail=self.guardrail,
guardrail_max_retries=self.guardrail_max_retries,
)

# Format messages
if isinstance(messages, str):
formatted_messages = messages
else:
formatted_messages = "\n".join(
str(msg.get("content", "")) for msg in messages if msg.get("content")
)

# Build the input dict for the executor
inputs = {
"input": formatted_messages,
"tool_names": get_tool_names(parsed_tools),
"tools": render_text_description_and_args(parsed_tools),
}

return executor, inputs, agent_info, parsed_tools

def kickoff(
self,
messages: str | list[LLMMessage],
response_format: type[Any] | None = None,
) -> LiteAgentOutput | Coroutine[Any, Any, LiteAgentOutput]:
"""
Execute the agent with the given messages using the AgentExecutor.

This method provides standalone agent execution without requiring a Crew.
It supports tools, response formatting, and guardrails.

When called from within a Flow (sync or async method), this automatically
detects the event loop and returns a coroutine that the Flow framework
awaits. Users don't need to handle async explicitly.

Args:
messages: Either a string query or a list of message dictionaries.
If a string is provided, it will be converted to a user message.
If a list is provided, each dict should have 'role' and 'content' keys.
response_format: Optional Pydantic model for structured output.

Returns:
LiteAgentOutput: The result of the agent execution.
When inside a Flow, returns a coroutine that resolves to LiteAgentOutput.

Note:
For explicit async usage outside of Flow, use kickoff_async() directly.
"""
# Magic auto-async: if inside event loop (e.g., inside a Flow),
# return coroutine for Flow to await
if is_inside_event_loop():
return self.kickoff_async(messages, response_format)

executor, inputs, agent_info, parsed_tools = self._prepare_kickoff(
messages, response_format
)

try:
crewai_event_bus.emit(
self,
event=LiteAgentExecutionStartedEvent(
agent_info=agent_info,
tools=parsed_tools,
messages=messages,
),
)

output = self._execute_and_build_output(executor, inputs, response_format)

if self.guardrail is not None:
output = self._process_kickoff_guardrail(
output=output,
executor=executor,
inputs=inputs,
response_format=response_format,
)

crewai_event_bus.emit(
self,
event=LiteAgentExecutionCompletedEvent(
agent_info=agent_info,
output=output.raw,
),
)

return output

except Exception as e:
crewai_event_bus.emit(
self,
event=LiteAgentExecutionErrorEvent(
agent_info=agent_info,
error=str(e),
),
)
raise

def _execute_and_build_output(
self,
executor: AgentExecutor,
inputs: dict[str, str],
response_format: type[Any] | None = None,
) -> LiteAgentOutput:
"""Execute the agent and build the output object.

Args:
executor: The executor instance.
inputs: Input dictionary for execution.
response_format: Optional response format.

Returns:
LiteAgentOutput with raw output, formatted result, and metrics.
"""
import json

# Execute the agent (this is called from sync path, so invoke returns dict)
result = cast(dict[str, Any], executor.invoke(inputs))
raw_output = result.get("output", "")

# Handle response format conversion
formatted_result: BaseModel | None = None
if response_format:
try:
model_schema = generate_model_description(response_format)
schema = json.dumps(model_schema, indent=2)
instructions = self.i18n.slice("formatted_task_instructions").format(
output_format=schema
)

converter = Converter(
llm=self.llm,
text=raw_output,
model=response_format,
instructions=instructions,
)

conversion_result = converter.to_pydantic()
if isinstance(conversion_result, BaseModel):
formatted_result = conversion_result
except ConverterError:
pass # Keep raw output if conversion fails

# Get token usage metrics
if isinstance(self.llm, BaseLLM):
usage_metrics = self.llm.get_token_usage_summary()
else:
usage_metrics = self._token_process.get_summary()

return LiteAgentOutput(
raw=raw_output,
pydantic=formatted_result,
agent_role=self.role,
usage_metrics=usage_metrics.model_dump() if usage_metrics else None,
messages=executor.messages,
)

async def _execute_and_build_output_async(
self,
executor: AgentExecutor,
inputs: dict[str, str],
response_format: type[Any] | None = None,
) -> LiteAgentOutput:
"""Execute the agent asynchronously and build the output object.

This is the async version of _execute_and_build_output that uses
invoke_async() for native async execution within event loops.

Args:
executor: The executor instance.
inputs: Input dictionary for execution.
response_format: Optional response format.

Returns:
LiteAgentOutput with raw output, formatted result, and metrics.
"""
import json

# Execute the agent asynchronously
result = await executor.invoke_async(inputs)
raw_output = result.get("output", "")

# Handle response format conversion
formatted_result: BaseModel | None = None
if response_format:
try:
model_schema = generate_model_description(response_format)
schema = json.dumps(model_schema, indent=2)
instructions = self.i18n.slice("formatted_task_instructions").format(
output_format=schema
)

converter = Converter(
llm=self.llm,
text=raw_output,
model=response_format,
instructions=instructions,
)

conversion_result = converter.to_pydantic()
if isinstance(conversion_result, BaseModel):
formatted_result = conversion_result
except ConverterError:
pass # Keep raw output if conversion fails

# Get token usage metrics
if isinstance(self.llm, BaseLLM):
usage_metrics = self.llm.get_token_usage_summary()
else:
usage_metrics = self._token_process.get_summary()

return LiteAgentOutput(
raw=raw_output,
pydantic=formatted_result,
agent_role=self.role,
usage_metrics=usage_metrics.model_dump() if usage_metrics else None,
messages=executor.messages,
)

def _process_kickoff_guardrail(
self,
output: LiteAgentOutput,
executor: AgentExecutor,
inputs: dict[str, str],
response_format: type[Any] | None = None,
retry_count: int = 0,
) -> LiteAgentOutput:
"""Process guardrail for kickoff execution with retry logic.

Args:
output: Current agent output.
executor: The executor instance.
inputs: Input dictionary for re-execution.
response_format: Optional response format.
retry_count: Current retry count.

Returns:
Validated/updated output.
"""
from crewai.utilities.guardrail_types import GuardrailCallable

# Ensure guardrail is callable
guardrail_callable: GuardrailCallable
if isinstance(self.guardrail, str):
from crewai.tasks.llm_guardrail import LLMGuardrail

guardrail_callable = cast(
GuardrailCallable,
LLMGuardrail(description=self.guardrail, llm=cast(BaseLLM, self.llm)),
)
elif callable(self.guardrail):
guardrail_callable = self.guardrail
else:
# Should not happen if called from kickoff with guardrail check
return output

guardrail_result = process_guardrail(
output=output,
guardrail=guardrail_callable,
retry_count=retry_count,
event_source=self,
from_agent=self,
)

if not guardrail_result.success:
if retry_count >= self.guardrail_max_retries:
raise ValueError(
f"Agent's guardrail failed validation after {self.guardrail_max_retries} retries. "
f"Last error: {guardrail_result.error}"
)

# Add feedback and re-execute
executor._append_message_to_state(
guardrail_result.error or "Guardrail validation failed",
role="user",
)

# Re-execute and build new output
output = self._execute_and_build_output(executor, inputs, response_format)

# Recursively retry guardrail
return self._process_kickoff_guardrail(
output=output,
executor=executor,
inputs=inputs,
response_format=response_format,
retry_count=retry_count + 1,
)

# Apply guardrail result if available
if guardrail_result.result is not None:
if isinstance(guardrail_result.result, str):
output.raw = guardrail_result.result
elif isinstance(guardrail_result.result, BaseModel):
output.pydantic = guardrail_result.result

return output
return lite_agent.kickoff(messages)

async def kickoff_async(
self,
@@ -1980,11 +1622,9 @@ class Agent(BaseAgent):
response_format: type[Any] | None = None,
) -> LiteAgentOutput:
"""
Execute the agent asynchronously with the given messages.
Execute the agent asynchronously with the given messages using a LiteAgent instance.

This is the async version of the kickoff method that uses native async
execution. It is designed for use within async contexts, such as when
called from within an async Flow method.
This is the async version of the kickoff method.

Args:
messages: Either a string query or a list of message dictionaries.
@@ -1995,48 +1635,21 @@ class Agent(BaseAgent):
Returns:
LiteAgentOutput: The result of the agent execution.
"""
executor, inputs, agent_info, parsed_tools = self._prepare_kickoff(
messages, response_format
lite_agent = LiteAgent(
role=self.role,
goal=self.goal,
backstory=self.backstory,
llm=self.llm,
tools=self.tools or [],
max_iterations=self.max_iter,
max_execution_time=self.max_execution_time,
respect_context_window=self.respect_context_window,
verbose=self.verbose,
response_format=response_format,
i18n=self.i18n,
original_agent=self,
guardrail=self.guardrail,
guardrail_max_retries=self.guardrail_max_retries,
)

try:
crewai_event_bus.emit(
self,
event=LiteAgentExecutionStartedEvent(
agent_info=agent_info,
tools=parsed_tools,
messages=messages,
),
)

output = await self._execute_and_build_output_async(
executor, inputs, response_format
)

if self.guardrail is not None:
output = self._process_kickoff_guardrail(
output=output,
executor=executor,
inputs=inputs,
response_format=response_format,
)

crewai_event_bus.emit(
self,
event=LiteAgentExecutionCompletedEvent(
agent_info=agent_info,
output=output.raw,
),
)

return output

except Exception as e:
crewai_event_bus.emit(
self,
event=LiteAgentExecutionErrorEvent(
agent_info=agent_info,
error=str(e),
),
)
raise
return await lite_agent.kickoff_async(messages)

@@ -21,9 +21,9 @@ if TYPE_CHECKING:


class CrewAgentExecutorMixin:
crew: Crew | None
crew: Crew
agent: Agent
task: Task | None
task: Task
iterations: int
max_iter: int
messages: list[LLMMessage]

@@ -10,7 +10,6 @@ from collections.abc import Callable
import logging
from typing import TYPE_CHECKING, Any, Literal, cast

from crewai_files import aformat_multimodal_content, format_multimodal_content
from pydantic import BaseModel, GetCoreSchemaHandler, ValidationError
from pydantic_core import CoreSchema, core_schema

@@ -44,7 +43,6 @@ from crewai.utilities.agent_utils import (
process_llm_response,
)
from crewai.utilities.constants import TRAINING_DATA_FILE
from crewai.utilities.file_store import get_all_files
from crewai.utilities.i18n import I18N, get_i18n
from crewai.utilities.printer import Printer
from crewai.utilities.tool_utils import (
@@ -190,8 +188,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
user_prompt = self._format_prompt(self.prompt.get("prompt", ""), inputs)
self.messages.append(format_message_for_llm(user_prompt))

self._inject_multimodal_files()

self._show_start_logs()

self.ask_for_human_input = bool(inputs.get("ask_for_human_input", False))
@@ -216,74 +212,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self._create_external_memory(formatted_answer)
return {"output": formatted_answer.output}

def _inject_multimodal_files(self) -> None:
"""Inject files as multimodal content into messages.

For crews with input files and LLMs that support multimodal,
uses crewai_files to process, resolve, and format files into
provider-specific content blocks.
"""
if not self.crew or not self.task:
return

if not self.llm.supports_multimodal():
return

files = get_all_files(self.crew.id, self.task.id)
if not files:
return

provider = getattr(self.llm, "provider", None) or getattr(self.llm, "model", "")
content_blocks = format_multimodal_content(files, provider)

if not content_blocks:
return

for i in range(len(self.messages) - 1, -1, -1):
msg = self.messages[i]
if msg.get("role") == "user":
existing_content = msg.get("content", "")
if isinstance(existing_content, str):
msg["content"] = [
self.llm.format_text_content(existing_content),
*content_blocks,
]
break

async def _ainject_multimodal_files(self) -> None:
"""Async inject files as multimodal content into messages.

For crews with input files and LLMs that support multimodal,
uses crewai_files to process, resolve, and format files into
provider-specific content blocks with parallel file resolution.
"""
if not self.crew or not self.task:
return

if not self.llm.supports_multimodal():
return

files = get_all_files(self.crew.id, self.task.id)
if not files:
return

provider = getattr(self.llm, "provider", None) or getattr(self.llm, "model", "")
content_blocks = await aformat_multimodal_content(files, provider)

if not content_blocks:
return

for i in range(len(self.messages) - 1, -1, -1):
msg = self.messages[i]
if msg.get("role") == "user":
existing_content = msg.get("content", "")
if isinstance(existing_content, str):
msg["content"] = [
self.llm.format_text_content(existing_content),
*content_blocks,
]
break

def _invoke_loop(self) -> AgentFinish:
"""Execute agent loop until completion.

@@ -427,8 +355,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
user_prompt = self._format_prompt(self.prompt.get("prompt", ""), inputs)
self.messages.append(format_message_for_llm(user_prompt))

await self._ainject_multimodal_files()

self._show_start_logs()

self.ask_for_human_input = bool(inputs.get("ask_for_human_input", False))

@@ -1,32 +0,0 @@
from crewai.cli.authentication.providers.base_provider import BaseProvider


class KeycloakProvider(BaseProvider):
def get_authorize_url(self) -> str:
return f"{self._oauth2_base_url()}/realms/{self.settings.extra.get('realm')}/protocol/openid-connect/auth/device"

def get_token_url(self) -> str:
return f"{self._oauth2_base_url()}/realms/{self.settings.extra.get('realm')}/protocol/openid-connect/token"

def get_jwks_url(self) -> str:
return f"{self._oauth2_base_url()}/realms/{self.settings.extra.get('realm')}/protocol/openid-connect/certs"

def get_issuer(self) -> str:
return f"{self._oauth2_base_url()}/realms/{self.settings.extra.get('realm')}"

def get_audience(self) -> str:
return self.settings.audience or "no-audience-provided"

def get_client_id(self) -> str:
if self.settings.client_id is None:
raise ValueError(
"Client ID is required. Please set it in the configuration."
)
return self.settings.client_id

def get_required_fields(self) -> list[str]:
return ["realm"]

def _oauth2_base_url(self) -> str:
domain = self.settings.domain.removeprefix("https://").removeprefix("http://")
return f"https://{domain}"
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]==1.8.1"
"crewai[tools]==1.8.0"
]

[project.scripts]

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]==1.8.1"
"crewai[tools]==1.8.0"
]

[project.scripts]

@@ -80,7 +80,6 @@ from crewai.task import Task
from crewai.tasks.conditional_task import ConditionalTask
from crewai.tasks.task_output import TaskOutput
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.tools.agent_tools.read_file_tool import ReadFileTool
from crewai.tools.base_tool import BaseTool
from crewai.types.streaming import CrewStreamingOutput
from crewai.types.usage_metrics import UsageMetrics
@@ -89,7 +88,6 @@ from crewai.utilities.crew.models import CrewContext
from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
from crewai.utilities.file_handler import FileHandler
from crewai.utilities.file_store import clear_files, get_all_files
from crewai.utilities.formatter import (
aggregate_raw_outputs_from_task_outputs,
aggregate_raw_outputs_from_tasks,
@@ -108,7 +106,6 @@ from crewai.utilities.streaming import (
)
from crewai.utilities.task_output_storage_handler import TaskOutputStorageHandler
from crewai.utilities.training_handler import CrewTrainingHandler
from crewai.utilities.types import KickoffInputs


warnings.filterwarnings("ignore", category=SyntaxWarning, module="pysbd")
@@ -678,7 +675,7 @@ class Crew(FlowTrackable, BaseModel):

def kickoff(
self,
inputs: KickoffInputs | dict[str, Any] | None = None,
inputs: dict[str, Any] | None = None,
) -> CrewOutput | CrewStreamingOutput:
if self.stream:
enable_agent_streaming(self.agents)
@@ -735,7 +732,6 @@ class Crew(FlowTrackable, BaseModel):
)
raise
finally:
clear_files(self.id)
detach(token)

def kickoff_for_each(
@@ -766,7 +762,7 @@ class Crew(FlowTrackable, BaseModel):
return results

async def kickoff_async(
self, inputs: KickoffInputs | dict[str, Any] | None = None
self, inputs: dict[str, Any] | None = None
) -> CrewOutput | CrewStreamingOutput:
"""Asynchronous kickoff method to start the crew execution.

@@ -821,7 +817,7 @@ class Crew(FlowTrackable, BaseModel):
return await run_for_each_async(self, inputs, kickoff_fn)

async def akickoff(
self, inputs: KickoffInputs | dict[str, Any] | None = None
self, inputs: dict[str, Any] | None = None
) -> CrewOutput | CrewStreamingOutput:
"""Native async kickoff method using async task execution throughout.

@@ -884,7 +880,6 @@ class Crew(FlowTrackable, BaseModel):
)
raise
finally:
clear_files(self.id)
detach(token)

async def akickoff_for_each(
@@ -1220,8 +1215,7 @@ class Crew(FlowTrackable, BaseModel):
and hasattr(agent, "multimodal")
and getattr(agent, "multimodal", False)
):
if not (agent.llm and agent.llm.supports_multimodal()):
tools = self._add_multimodal_tools(agent, tools)
tools = self._add_multimodal_tools(agent, tools)

if agent and (hasattr(agent, "apps") and getattr(agent, "apps", None)):
tools = self._add_platform_tools(task, tools)
@@ -1229,24 +1223,7 @@ class Crew(FlowTrackable, BaseModel):
if agent and (hasattr(agent, "mcps") and getattr(agent, "mcps", None)):
tools = self._add_mcp_tools(task, tools)

files = get_all_files(self.id, task.id)
if files:
supported_types: list[str] = []
if agent and agent.llm and agent.llm.supports_multimodal():
supported_types = agent.llm.supported_multimodal_content_types()

def is_auto_injected(content_type: str) -> bool:
return any(content_type.startswith(t) for t in supported_types)

# Only add read_file tool if there are files that need it
files_needing_tool = {
name: f
for name, f in files.items()
if not is_auto_injected(f.content_type)
}
if files_needing_tool:
tools = self._add_file_tools(tools, files_needing_tool)

# Return a list[BaseTool] compatible with Task.execute_sync and execute_async
return tools

def _get_agent_to_use(self, task: Task) -> BaseAgent | None:
@@ -1326,22 +1303,6 @@ class Crew(FlowTrackable, BaseModel):
return self._merge_tools(tools, cast(list[BaseTool], code_tools))
return tools

def _add_file_tools(
self, tools: list[BaseTool], files: dict[str, Any]
) -> list[BaseTool]:
"""Add file reading tool when input files are available.

Args:
tools: Current list of tools.
files: Dictionary of input files.

Returns:
Updated list with file tool added.
"""
read_file_tool = ReadFileTool()
read_file_tool.set_files(files)
return self._merge_tools(tools, [read_file_tool])

def _add_delegation_tools(
self, task: Task, tools: list[BaseTool]
) -> list[BaseTool]:

@@ -6,25 +6,15 @@ import asyncio
from collections.abc import Callable, Coroutine, Iterable
from typing import TYPE_CHECKING, Any

from crewai_files import (
AudioFile,
ImageFile,
PDFFile,
TextFile,
VideoFile,
)

from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.crews.crew_output import CrewOutput
from crewai.rag.embeddings.types import EmbedderConfig
from crewai.types.streaming import CrewStreamingOutput, FlowStreamingOutput
from crewai.utilities.file_store import store_files
from crewai.utilities.streaming import (
StreamingState,
TaskInfo,
create_streaming_state,
)
from crewai.utilities.types import KickoffInputs


if TYPE_CHECKING:
@@ -186,36 +176,7 @@ def check_conditional_skip(
return None


def _extract_files_from_inputs(inputs: dict[str, Any]) -> dict[str, Any]:
"""Extract file objects from inputs dict.

Scans inputs for FileInput objects (ImageFile, TextFile, etc.) and
extracts them into a separate dict.

Args:
inputs: The inputs dictionary to scan.

Returns:
Dictionary of extracted file objects.
"""
file_types = (AudioFile, ImageFile, PDFFile, TextFile, VideoFile)
files: dict[str, Any] = {}
keys_to_remove: list[str] = []

for key, value in inputs.items():
if isinstance(value, file_types):
files[key] = value
keys_to_remove.append(key)

for key in keys_to_remove:
del inputs[key]

return files


def prepare_kickoff(
crew: Crew, inputs: KickoffInputs | dict[str, Any] | None
) -> dict[str, Any] | None:
def prepare_kickoff(crew: Crew, inputs: dict[str, Any] | None) -> dict[str, Any] | None:
"""Prepare crew for kickoff execution.

Handles before callbacks, event emission, task handler reset, input
@@ -231,17 +192,14 @@ def prepare_kickoff(
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.crew_events import CrewKickoffStartedEvent

# Normalize inputs to dict[str, Any] for internal processing
normalized: dict[str, Any] | None = dict(inputs) if inputs is not None else None

for before_callback in crew.before_kickoff_callbacks:
if normalized is None:
normalized = {}
normalized = before_callback(normalized)
if inputs is None:
inputs = {}
inputs = before_callback(inputs)

future = crewai_event_bus.emit(
crew,
CrewKickoffStartedEvent(crew_name=crew.name, inputs=normalized),
CrewKickoffStartedEvent(crew_name=crew.name, inputs=inputs),
)
if future is not None:
try:
@@ -252,20 +210,9 @@ def prepare_kickoff(
crew._task_output_handler.reset()
crew._logging_color = "bold_purple"

if normalized is not None:
# Extract files from dedicated "files" key
files = normalized.pop("files", None) or {}

# Extract file objects unpacked directly into inputs
unpacked_files = _extract_files_from_inputs(normalized)

# Merge files (unpacked files take precedence over explicit files dict)
all_files = {**files, **unpacked_files}
if all_files:
store_files(crew.id, all_files)

crew._inputs = normalized
crew._interpolate_inputs(normalized)
if inputs is not None:
crew._inputs = inputs
crew._interpolate_inputs(inputs)
crew._set_tasks_callbacks()
crew._set_allow_crewai_trigger_context_for_first_task()

@@ -280,7 +227,7 @@ def prepare_kickoff(
if crew.planning:
crew._handle_crew_planning()

return normalized
return inputs


class StreamingContext:

Some files were not shown because too many files have changed in this diff.