Compare commits

...

28 Commits

Author SHA1 Message Date
Greyson LaLonde
8b3acb58a4 chore: add crewai-a2a package README 2026-02-27 09:23:55 -05:00
Greyson LaLonde
8a3c2d5ca6 refactor: extract crewai.a2a to crewai-a2a workspace package 2026-02-27 09:12:57 -05:00
João Moura
1bdb9496a3 refactor: update step callback methods to support asynchronous invocation (#4633)
* refactor: update step callback methods to support asynchronous invocation

- Replaced synchronous step callback invocations with asynchronous counterparts in the CrewAgentExecutor class.
- Introduced a new async method _ainvoke_step_callback to handle step callbacks in an async context, improving responsiveness and performance in asynchronous workflows.

* chore: bump version to 1.10.1b1 across multiple files

- Updated version strings from 1.10.1b to 1.10.1b1 in various project files including pyproject.toml and __init__.py files.
- Adjusted dependency specifications to reflect the new version in relevant templates and modules.
2026-02-27 07:35:03 -03:00
Joao Moura
979aa26c3d bump new alpha version 2026-02-27 01:43:33 -08:00
João Moura
514c082882 refactor: implement lazy loading for heavy dependencies in Memory module (#4632)
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Check Documentation Broken Links / Check broken links (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
- Introduced lazy imports for the Memory and EncodingFlow classes to optimize import time and reduce initial load, particularly beneficial for deployment scenarios like Celery pre-fork.
- Updated the Memory class to include new configuration options for aggregation queries, enhancing its functionality.
- Adjusted the __getattr__ method in both the crewai and memory modules to support lazy loading of specified attributes.
2026-02-27 03:20:02 -03:00
Greyson LaLonde
c9e8068578 docs: update changelog and version for v1.10.0 2026-02-26 19:14:25 -05:00
Greyson LaLonde
df2778f08b fix: make branch for release notes
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Build uv cache / build-cache (3.10) (push) Has been cancelled
Build uv cache / build-cache (3.11) (push) Has been cancelled
Build uv cache / build-cache (3.12) (push) Has been cancelled
Build uv cache / build-cache (3.13) (push) Has been cancelled
2026-02-26 18:49:13 -05:00
Greyson LaLonde
d8fea2518d feat: bump versions to 1.10.0
* feat: bump versions to 1.10.0

* chore: update tool specifications

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Lorenze Jay <63378463+lorenzejay@users.noreply.github.com>
2026-02-26 18:31:14 -05:00
Lucas Gomide
d259150d8d Enhance MCP tool resolution and related events (#4580)
* feat: enhance MCP tool resolution

* feat: emit event when MCP configuration fails

* feat: emit event when MCP tool execution has failed

* style: resolve linter issues

* refactor: use clear and natural mcp tool name resolution

* test: fix broken tests

* fix: resolve MCP connection leaks, slug validation, duplicate connections, and httpx exception handling

---------

Co-authored-by: Greyson LaLonde <greyson.r.lalonde@gmail.com>
Co-authored-by: Greyson LaLonde <greyson@crewai.com>
2026-02-26 13:59:30 -08:00
Greyson LaLonde
c4a328c9d5 fix: validate tool kwargs even when empty to prevent cryptic TypeError (#4611) 2026-02-26 16:18:03 -05:00
Greyson LaLonde
373abbb6b7 fix: add dict overload to build_embedder and type default embedder 2026-02-26 16:04:28 -05:00
João Moura
86d3ee022d feat: update lancedb version and add lance-namespace packages
* chore(deps): update lancedb version and add lance-namespace packages

- Updated lancedb dependency version from 0.4.0 to 0.29.2 in multiple files.
- Added new packages: lance-namespace and lance-namespace-urllib3-client with version 0.5.2, including their dependencies and installation details.
- Enhanced MemoryTUI to display a limit on entries and improved the LanceDBStorage class with automatic background compaction and index creation for better performance.

* linter

* refactor: update memory recall limit and formatting in Agent class

- Reduced the memory recall limit from 10 to 5 in multiple locations within the Agent class.
- Updated the memory formatting to use a new `format` method in the MemoryMatch class for improved readability and metadata inclusion.

* refactor: enhance memory handling with read-only support

- Updated memory-related classes and methods to support read-only functionality, allowing for silent no-ops when attempting to remember data in read-only mode.
- Modified the LiteAgent and CrewAgentExecutorMixin classes to check for read-only status before saving memories.
- Adjusted MemorySlice and Memory classes to reflect changes in behavior when read-only is enabled.
- Updated tests to verify that memory operations behave correctly under read-only conditions.

* test: set mock memory to read-write in unit tests

- Updated unit tests in test_unified_memory.py to set mock_memory._read_only to False, ensuring that memory operations can be tested in a writable state.

* fix test

* fix: preserve falsy metadata values and fix remember() return type

---------

Co-authored-by: lorenzejay <lorenzejaytech@gmail.com>
Co-authored-by: Greyson LaLonde <greyson@crewai.com>
2026-02-26 15:05:10 -05:00
Lucas Gomide
09e3b81ca3 fix: preserve null types in tool parameter schemas for LLM (#4579)
* fix: preserve null types in tool parameter schemas for LLM

Tool parameter schemas were stripping null from optional fields via
generate_model_description, forcing the LLM to provide non-null values
for fields.
Adds strip_null_types parameter to generate_model_description and passes False when generating tool
schemas, so optional fields keep anyOf: [{type: T}, {type: null}]

* Update lib/crewai/src/crewai/utilities/pydantic_schema_utils.py

Co-authored-by: Gabe Milani <gabriel@crewai.com>

---------

Co-authored-by: Greyson LaLonde <greyson.r.lalonde@gmail.com>
Co-authored-by: Gabe Milani <gabriel@crewai.com>
Co-authored-by: Lorenze Jay <63378463+lorenzejay@users.noreply.github.com>
2026-02-26 11:51:34 -05:00
Heitor Carvalho
b6d8ce5c55 docs: add litellm dependency note for non-native LLM providers (#4600)
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Check Documentation Broken Links / Check broken links (push) Has been cancelled
2026-02-26 10:57:37 -03:00
Greyson LaLonde
b371f97a2f fix: map output_pydantic/output_json to native structured output
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Check Documentation Broken Links / Check broken links (push) Has been cancelled
Build uv cache / build-cache (3.10) (push) Has been cancelled
Build uv cache / build-cache (3.11) (push) Has been cancelled
Build uv cache / build-cache (3.12) (push) Has been cancelled
Build uv cache / build-cache (3.13) (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
* fix: map output_pydantic/output_json to native structured output

* test: add crew+tools+structured output integration test for Gemini

* fix: re-record stale cassette for test_crew_testing_function

* fix: re-record remaining stale cassettes for native structured output

* fix: enable native structured output for lite agent and fix mypy errors
2026-02-25 17:13:34 -05:00
dependabot[bot]
017189db78 chore(deps): bump nltk in the security-updates group across 1 directory (#4598)
Bumps the security-updates group with 1 update in the / directory: [nltk](https://github.com/nltk/nltk).


Updates `nltk` from 3.9.2 to 3.9.3
- [Changelog](https://github.com/nltk/nltk/blob/develop/ChangeLog)
- [Commits](https://github.com/nltk/nltk/compare/3.9.2...3.9.3)

---
updated-dependencies:
- dependency-name: nltk
  dependency-version: 3.9.3
  dependency-type: indirect
  dependency-group: security-updates
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-25 15:37:21 -06:00
dependabot[bot]
02d911494f chore(deps): bump cryptography (#4506)
Bumps the security-updates group with 1 update in the / directory: [cryptography](https://github.com/pyca/cryptography).


Updates `cryptography` from 46.0.4 to 46.0.5
- [Changelog](https://github.com/pyca/cryptography/blob/main/CHANGELOG.rst)
- [Commits](https://github.com/pyca/cryptography/compare/46.0.4...46.0.5)

---
updated-dependencies:
- dependency-name: cryptography
  dependency-version: 46.0.5
  dependency-type: indirect
  dependency-group: security-updates
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-25 15:04:07 -06:00
João Moura
8102d0a6ca feat: enhance JSON argument parsing and validation in CrewAgentExecutor and BaseTool
* feat: enhance JSON argument parsing and validation in CrewAgentExecutor and BaseTool

- Added error handling for malformed JSON tool arguments in CrewAgentExecutor, providing descriptive error messages.
- Implemented schema validation for tool arguments in BaseTool, ensuring that invalid arguments raise appropriate exceptions.
- Introduced tests to verify correct behavior for both valid and invalid JSON inputs, enhancing robustness of tool execution.

* refactor: improve argument validation in BaseTool

- Introduced a new private method  to handle argument validation for tools, enhancing code clarity and reusability.
- Updated the  method to utilize the new validation method, ensuring consistent error handling for invalid arguments.
- Enhanced exception handling to specifically catch , providing clearer error messages for tool argument validation failures.

* feat: introduce parse_tool_call_args for improved argument parsing

- Added a new utility function, parse_tool_call_args, to handle parsing of tool call arguments from JSON strings or dictionaries, enhancing error handling for malformed JSON inputs.
- Updated CrewAgentExecutor and AgentExecutor to utilize the new parsing function, streamlining argument validation and improving clarity in error reporting.
- Introduced unit tests for parse_tool_call_args to ensure robust functionality and correct handling of various input scenarios.

* feat: add keyword argument validation in BaseTool and Tool classes

- Introduced a new method `_validate_kwargs` in BaseTool to validate keyword arguments against the defined schema, ensuring proper argument handling.
- Updated the `run` and `arun` methods in both BaseTool and Tool classes to utilize the new validation method, improving error handling and robustness.
- Added comprehensive tests for asynchronous execution in `TestBaseToolArunValidation` to verify correct behavior for valid and invalid keyword arguments.

* Potential fix for pull request finding 'Syntax error'

Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com>

---------

Co-authored-by: lorenzejay <lorenzejaytech@gmail.com>
Co-authored-by: Lorenze Jay <63378463+lorenzejay@users.noreply.github.com>
Co-authored-by: Greyson LaLonde <greyson.r.lalonde@gmail.com>
Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com>
2026-02-25 13:13:31 -05:00
Greyson LaLonde
ee374d01de chore: add versioning logic for devtools 2026-02-25 12:13:00 -05:00
Greyson LaLonde
9914e51199 feat: add versioned docs
starting with 1.10.0
2026-02-25 11:05:31 -05:00
nicoferdi96
2dbb83ae31 Private package registry (#4583)
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Check Documentation Broken Links / Check broken links (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
adding reference and explaination for package registry

Co-authored-by: Lorenze Jay <63378463+lorenzejay@users.noreply.github.com>
2026-02-24 19:37:17 +01:00
Mike Plachta
7377e1aa26 fix: bedrock region was always set to "us-east-1" not respecting the env var. (#4582)
* fix: bedrock region was always set to "us-east-1" not respecting the env
var.

code had AWS_REGION_NAME referenced, but not used, unified to
AWS_DEFAULT_REGION as per documentation

* DRY code improvement and fix caught by tests.

* Supporting litellm configuration
2026-02-24 09:59:01 -08:00
Greyson LaLonde
51754899a2 feat: migrate CLI http client from requests to httpx
Some checks failed
Build uv cache / build-cache (3.10) (push) Has been cancelled
Build uv cache / build-cache (3.11) (push) Has been cancelled
Build uv cache / build-cache (3.12) (push) Has been cancelled
Build uv cache / build-cache (3.13) (push) Has been cancelled
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
2026-02-20 18:21:05 -05:00
Greyson LaLonde
71b4f8402a fix: ensure callbacks are ran/awaited if promise
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Build uv cache / build-cache (3.12) (push) Has been cancelled
Build uv cache / build-cache (3.13) (push) Has been cancelled
Build uv cache / build-cache (3.10) (push) Has been cancelled
Build uv cache / build-cache (3.11) (push) Has been cancelled
2026-02-20 13:15:50 -05:00
Greyson LaLonde
4a4c99d8a2 fix: capture method name in exception context
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
2026-02-19 17:51:18 -05:00
Greyson LaLonde
28a6b855a2 fix: preserve enum type in router result; improve types 2026-02-19 17:30:47 -05:00
Lorenze Jay
d09656664d supporting parallel tool use (#4513)
* supporting parallel tool use

* ensure we respect max_usage_count

* ensure result_as_answer, hooks, and cache parodity

* improve crew agent executor

* address test comments
2026-02-19 14:07:28 -08:00
Lucas Gomide
49aa29bb41 docs: correct broken human_feedback examples with working self-loop patterns (#4520)
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Check Documentation Broken Links / Check broken links (push) Has been cancelled
2026-02-19 09:02:01 -08:00
231 changed files with 26115 additions and 16452 deletions

View File

@@ -21,7 +21,6 @@ OPENROUTER_API_KEY=fake-openrouter-key
AWS_ACCESS_KEY_ID=fake-aws-access-key
AWS_SECRET_ACCESS_KEY=fake-aws-secret-key
AWS_DEFAULT_REGION=us-east-1
AWS_REGION_NAME=us-east-1
# -----------------------------------------------------------------------------
# Azure OpenAI Configuration

View File

@@ -1,8 +1,6 @@
name: Publish to PyPI
on:
repository_dispatch:
types: [deployment-tests-passed]
workflow_dispatch:
inputs:
release_tag:
@@ -20,11 +18,8 @@ jobs:
- name: Determine release tag
id: release
run: |
# Priority: workflow_dispatch input > repository_dispatch payload > default branch
if [ -n "${{ inputs.release_tag }}" ]; then
echo "tag=${{ inputs.release_tag }}" >> $GITHUB_OUTPUT
elif [ -n "${{ github.event.client_payload.release_tag }}" ]; then
echo "tag=${{ github.event.client_payload.release_tag }}" >> $GITHUB_OUTPUT
else
echo "tag=" >> $GITHUB_OUTPUT
fi

View File

@@ -1,18 +0,0 @@
name: Trigger Deployment Tests
on:
release:
types: [published]
jobs:
trigger:
name: Trigger deployment tests
runs-on: ubuntu-latest
steps:
- name: Trigger deployment tests
uses: peter-evans/repository-dispatch@v3
with:
token: ${{ secrets.CREWAI_DEPLOYMENTS_PAT }}
repository: ${{ secrets.CREWAI_DEPLOYMENTS_REPOSITORY }}
event-type: crewai-release
client-payload: '{"release_tag": "${{ github.event.release.tag_name }}", "release_name": "${{ github.event.release.name }}"}'

View File

@@ -19,7 +19,7 @@ repos:
language: system
pass_filenames: true
types: [python]
exclude: ^(lib/crewai/src/crewai/cli/templates/|lib/crewai/tests/|lib/crewai-tools/tests/|lib/crewai-files/tests/)
exclude: ^(lib/crewai/src/crewai/cli/templates/|lib/crewai/tests/|lib/crewai-tools/tests/|lib/crewai-files/tests/|lib/crewai-a2a/tests/)
- repo: https://github.com/astral-sh/uv-pre-commit
rev: 0.9.3
hooks:

View File

@@ -12,6 +12,7 @@ from dotenv import load_dotenv
import pytest
from vcr.request import Request # type: ignore[import-untyped]
try:
import vcr.stubs.httpx_stubs as httpx_stubs # type: ignore[import-untyped]
except ModuleNotFoundError:
@@ -225,7 +226,7 @@ def vcr_cassette_dir(request: Any) -> str:
for parent in test_file.parents:
if (
parent.name in ("crewai", "crewai-tools", "crewai-files")
parent.name in ("crewai", "crewai-tools", "crewai-files", "crewai-a2a")
and parent.parent.name == "lib"
):
package_root = parent

File diff suppressed because it is too large Load Diff

View File

@@ -4,6 +4,56 @@ description: "Product updates, improvements, and bug fixes for CrewAI"
icon: "clock"
mode: "wide"
---
<Update label="Feb 26, 2026">
## v1.10.0
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.10.0)
## What's Changed
### Features
- Enhance MCP tool resolution and related events
- Update lancedb version and add lance-namespace packages
- Enhance JSON argument parsing and validation in CrewAgentExecutor and BaseTool
- Migrate CLI HTTP client from requests to httpx
- Add versioned documentation
- Add yanked detection for version notes
- Implement user input handling in Flows
- Enhance HITL self-loop functionality in human feedback integration tests
- Add started_event_id and set in eventbus
- Auto update tools.specs
### Bug Fixes
- Validate tool kwargs even when empty to prevent cryptic TypeError
- Preserve null types in tool parameter schemas for LLM
- Map output_pydantic/output_json to native structured output
- Ensure callbacks are ran/awaited if promise
- Capture method name in exception context
- Preserve enum type in router result; improve types
- Fix cyclic flows silently breaking when persistence ID is passed in inputs
- Correct CLI flag format from --skip-provider to --skip_provider
- Ensure OpenAI tool call stream is finalized
- Resolve complex schema $ref pointers in MCP tools
- Enforce additionalProperties=false in schemas
- Reject reserved script names for crew folders
- Resolve race condition in guardrail event emission test
### Documentation
- Add litellm dependency note for non-native LLM providers
- Clarify NL2SQL security model and hardening guidance
- Add 96 missing actions across 9 integrations
### Refactoring
- Refactor crew to provider
- Extract HITL to provider pattern
- Improve hook typing and registration
## Contributors
@dependabot[bot], @github-actions[bot], @github-code-quality[bot], @greysonlalonde, @heitorado, @hobostay, @joaomdmoura, @johnvan7, @jonathansampson, @lorenzejay, @lucasgomide, @mattatcha, @mplachta, @nicoferdi96, @theCyberTech, @thiagomoretto, @vinibrsl
</Update>
<Update label="Jan 26, 2026">
## v1.9.0

View File

@@ -106,6 +106,15 @@ There are different places in CrewAI code where you can specify the model to use
</Tab>
</Tabs>
<Info>
CrewAI provides native SDK integrations for OpenAI, Anthropic, Google (Gemini API), Azure, and AWS Bedrock — no extra install needed beyond the provider-specific extras (e.g. `uv add "crewai[openai]"`).
All other providers are powered by **LiteLLM**. If you plan to use any of them, add it as a dependency to your project:
```bash
uv add 'crewai[litellm]'
```
</Info>
## Provider Configuration Examples
CrewAI supports a multitude of LLM providers, each offering unique features, authentication methods, and model capabilities.
@@ -275,6 +284,11 @@ In this section, you'll find detailed examples that help you select, configure,
| `meta_llama/Llama-4-Maverick-17B-128E-Instruct-FP8` | 128k | 4028 | Text, Image | Text |
| `meta_llama/Llama-3.3-70B-Instruct` | 128k | 4028 | Text | Text |
| `meta_llama/Llama-3.3-8B-Instruct` | 128k | 4028 | Text | Text |
**Note:** This provider uses LiteLLM. Add it as a dependency to your project:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Anthropic">
@@ -470,7 +484,7 @@ In this section, you'll find detailed examples that help you select, configure,
To get an Express mode API key:
- New Google Cloud users: Get an [express mode API key](https://cloud.google.com/vertex-ai/generative-ai/docs/start/quickstart?usertype=apikey)
- Existing Google Cloud users: Get a [Google Cloud API key bound to a service account](https://cloud.google.com/docs/authentication/api-keys)
For more details, see the [Vertex AI Express mode documentation](https://docs.cloud.google.com/vertex-ai/generative-ai/docs/start/quickstart?usertype=apikey).
</Info>
@@ -571,6 +585,11 @@ In this section, you'll find detailed examples that help you select, configure,
| gemini-1.5-flash | 1M tokens | Balanced multimodal model, good for most tasks |
| gemini-1.5-flash-8B | 1M tokens | Fastest, most cost-efficient, good for high-frequency tasks |
| gemini-1.5-pro | 2M tokens | Best performing, wide variety of reasoning tasks including logical reasoning, coding, and creative collaboration |
**Note:** This provider uses LiteLLM. Add it as a dependency to your project:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Azure">
@@ -652,6 +671,7 @@ In this section, you'll find detailed examples that help you select, configure,
# Optional
AWS_SESSION_TOKEN=<your-session-token> # For temporary credentials
AWS_DEFAULT_REGION=<your-region> # Defaults to us-east-1
AWS_REGION_NAME=<your-region> # Alternative configuration for backwards compatibility with LiteLLM. Defaults to us-east-1
```
**Basic Usage:**
@@ -695,6 +715,7 @@ In this section, you'll find detailed examples that help you select, configure,
- `AWS_SECRET_ACCESS_KEY`: AWS secret key (required)
- `AWS_SESSION_TOKEN`: AWS session token for temporary credentials (optional)
- `AWS_DEFAULT_REGION`: AWS region (defaults to `us-east-1`)
- `AWS_REGION_NAME`: AWS region (defaults to `us-east-1`). Alternative configuration for backwards compatibility with LiteLLM
**Features:**
- Native tool calling support via Converse API
@@ -764,6 +785,11 @@ In this section, you'll find detailed examples that help you select, configure,
model="sagemaker/<my-endpoint>"
)
```
**Note:** This provider uses LiteLLM. Add it as a dependency to your project:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Mistral">
@@ -779,6 +805,11 @@ In this section, you'll find detailed examples that help you select, configure,
temperature=0.7
)
```
**Note:** This provider uses LiteLLM. Add it as a dependency to your project:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Nvidia NIM">
@@ -865,6 +896,11 @@ In this section, you'll find detailed examples that help you select, configure,
| rakuten/rakutenai-7b-instruct | 1,024 tokens | Advanced state-of-the-art LLM with language understanding, superior reasoning, and text generation. |
| rakuten/rakutenai-7b-chat | 1,024 tokens | Advanced state-of-the-art LLM with language understanding, superior reasoning, and text generation. |
| baichuan-inc/baichuan2-13b-chat | 4,096 tokens | Support Chinese and English chat, coding, math, instruction following, solving quizzes |
**Note:** This provider uses LiteLLM. Add it as a dependency to your project:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Local NVIDIA NIM Deployed using WSL2">
@@ -905,6 +941,11 @@ In this section, you'll find detailed examples that help you select, configure,
# ...
```
**Note:** This provider uses LiteLLM. Add it as a dependency to your project:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Groq">
@@ -926,6 +967,11 @@ In this section, you'll find detailed examples that help you select, configure,
| Llama 3.1 70B/8B | 131,072 tokens | High-performance, large context tasks |
| Llama 3.2 Series | 8,192 tokens | General-purpose tasks |
| Mixtral 8x7B | 32,768 tokens | Balanced performance and context |
**Note:** This provider uses LiteLLM. Add it as a dependency to your project:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="IBM watsonx.ai">
@@ -948,6 +994,11 @@ In this section, you'll find detailed examples that help you select, configure,
base_url="https://api.watsonx.ai/v1"
)
```
**Note:** This provider uses LiteLLM. Add it as a dependency to your project:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Ollama (Local LLMs)">
@@ -961,6 +1012,11 @@ In this section, you'll find detailed examples that help you select, configure,
base_url="http://localhost:11434"
)
```
**Note:** This provider uses LiteLLM. Add it as a dependency to your project:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Fireworks AI">
@@ -976,6 +1032,11 @@ In this section, you'll find detailed examples that help you select, configure,
temperature=0.7
)
```
**Note:** This provider uses LiteLLM. Add it as a dependency to your project:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Perplexity AI">
@@ -991,6 +1052,11 @@ In this section, you'll find detailed examples that help you select, configure,
base_url="https://api.perplexity.ai/"
)
```
**Note:** This provider uses LiteLLM. Add it as a dependency to your project:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Hugging Face">
@@ -1005,6 +1071,11 @@ In this section, you'll find detailed examples that help you select, configure,
model="huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct"
)
```
**Note:** This provider uses LiteLLM. Add it as a dependency to your project:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="SambaNova">
@@ -1028,6 +1099,11 @@ In this section, you'll find detailed examples that help you select, configure,
| Llama 3.2 Series | 8,192 tokens | General-purpose, multimodal tasks |
| Llama 3.3 70B | Up to 131,072 tokens | High-performance and output quality |
| Qwen2 familly | 8,192 tokens | High-performance and output quality |
**Note:** This provider uses LiteLLM. Add it as a dependency to your project:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Cerebras">
@@ -1053,6 +1129,11 @@ In this section, you'll find detailed examples that help you select, configure,
- Good balance of speed and quality
- Support for long context windows
</Info>
**Note:** This provider uses LiteLLM. Add it as a dependency to your project:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Open Router">
@@ -1075,6 +1156,11 @@ In this section, you'll find detailed examples that help you select, configure,
- openrouter/deepseek/deepseek-r1
- openrouter/deepseek/deepseek-chat
</Info>
**Note:** This provider uses LiteLLM. Add it as a dependency to your project:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Nebius AI Studio">
@@ -1097,6 +1183,11 @@ In this section, you'll find detailed examples that help you select, configure,
- Competitive pricing
- Good balance of speed and quality
</Info>
**Note:** This provider uses LiteLLM. Add it as a dependency to your project:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
</AccordionGroup>

View File

@@ -38,22 +38,21 @@ CrewAI Enterprise provides a comprehensive Human-in-the-Loop (HITL) management s
Configure human review checkpoints within your Flows using the `@human_feedback` decorator. When execution reaches a review point, the system pauses, notifies the assignee via email, and waits for a response.
```python
from crewai.flow.flow import Flow, start, listen
from crewai.flow.flow import Flow, start, listen, or_
from crewai.flow.human_feedback import human_feedback, HumanFeedbackResult
class ContentApprovalFlow(Flow):
@start()
def generate_content(self):
# AI generates content
return "Generated marketing copy for Q1 campaign..."
@listen(generate_content)
@human_feedback(
message="Please review this content for brand compliance:",
emit=["approved", "rejected", "needs_revision"],
)
def review_content(self, content):
return content
@listen(or_("generate_content", "needs_revision"))
def review_content(self):
return "Marketing copy for review..."
@listen("approved")
def publish_content(self, result: HumanFeedbackResult):
@@ -62,10 +61,6 @@ class ContentApprovalFlow(Flow):
@listen("rejected")
def archive_content(self, result: HumanFeedbackResult):
print(f"Content rejected. Reason: {result.feedback}")
@listen("needs_revision")
def revise_content(self, result: HumanFeedbackResult):
print(f"Revision requested: {result.feedback}")
```
For complete implementation details, see the [Human Feedback in Flows](/en/learn/human-feedback-in-flows) guide.

View File

@@ -177,6 +177,11 @@ You need to push your crew to a GitHub repository. If you haven't created a crew
![Set Environment Variables](/images/enterprise/set-env-variables.png)
</Frame>
<Info>
Using private Python packages? You'll need to add your registry credentials here too.
See [Private Package Registries](/en/enterprise/guides/private-package-registry) for the required variables.
</Info>
</Step>
<Step title="Deploy Your Crew">

View File

@@ -256,6 +256,12 @@ Before deployment, ensure you have:
1. **LLM API keys** ready (OpenAI, Anthropic, Google, etc.)
2. **Tool API keys** if using external tools (Serper, etc.)
<Info>
If your project depends on packages from a **private PyPI registry**, you'll also need to configure
registry authentication credentials as environment variables. See the
[Private Package Registries](/en/enterprise/guides/private-package-registry) guide for details.
</Info>
<Tip>
Test your project locally with the same environment variables before deploying
to catch configuration issues early.

View File

@@ -0,0 +1,263 @@
---
title: "Private Package Registries"
description: "Install private Python packages from authenticated PyPI registries in CrewAI AMP"
icon: "lock"
mode: "wide"
---
<Note>
This guide covers how to configure your CrewAI project to install Python packages
from private PyPI registries (Azure DevOps Artifacts, GitHub Packages, GitLab, AWS CodeArtifact, etc.)
when deploying to CrewAI AMP.
</Note>
## When You Need This
If your project depends on internal or proprietary Python packages hosted on a private registry
rather than the public PyPI, you'll need to:
1. Tell UV **where** to find the package (an index URL)
2. Tell UV **which** packages come from that index (a source mapping)
3. Provide **credentials** so UV can authenticate during install
CrewAI AMP uses [UV](https://docs.astral.sh/uv/) for dependency resolution and installation.
UV supports authenticated private registries through `pyproject.toml` configuration combined
with environment variables for credentials.
## Step 1: Configure pyproject.toml
Three pieces work together in your `pyproject.toml`:
### 1a. Declare the dependency
Add the private package to your `[project.dependencies]` like any other dependency:
```toml
[project]
dependencies = [
"crewai[tools]>=0.100.1,<1.0.0",
"my-private-package>=1.2.0",
]
```
### 1b. Define the index
Register your private registry as a named index under `[[tool.uv.index]]`:
```toml
[[tool.uv.index]]
name = "my-private-registry"
url = "https://pkgs.dev.azure.com/my-org/_packaging/my-feed/pypi/simple/"
explicit = true
```
<Info>
The `name` field is important — UV uses it to construct the environment variable names
for authentication (see [Step 2](#step-2-set-authentication-credentials) below).
Setting `explicit = true` means UV won't search this index for every package — only the
ones you explicitly map to it in `[tool.uv.sources]`. This avoids unnecessary queries
against your private registry and protects against dependency confusion attacks.
</Info>
### 1c. Map the package to the index
Tell UV which packages should be resolved from your private index using `[tool.uv.sources]`:
```toml
[tool.uv.sources]
my-private-package = { index = "my-private-registry" }
```
### Complete example
```toml
[project]
name = "my-crew-project"
version = "0.1.0"
requires-python = ">=3.10,<=3.13"
dependencies = [
"crewai[tools]>=0.100.1,<1.0.0",
"my-private-package>=1.2.0",
]
[tool.crewai]
type = "crew"
[[tool.uv.index]]
name = "my-private-registry"
url = "https://pkgs.dev.azure.com/my-org/_packaging/my-feed/pypi/simple/"
explicit = true
[tool.uv.sources]
my-private-package = { index = "my-private-registry" }
```
After updating `pyproject.toml`, regenerate your lock file:
```bash
uv lock
```
<Warning>
Always commit the updated `uv.lock` along with your `pyproject.toml` changes.
The lock file is required for deployment — see [Prepare for Deployment](/en/enterprise/guides/prepare-for-deployment).
</Warning>
## Step 2: Set Authentication Credentials
UV authenticates against private indexes using environment variables that follow a naming convention
based on the index name you defined in `pyproject.toml`:
```
UV_INDEX_{UPPER_NAME}_USERNAME
UV_INDEX_{UPPER_NAME}_PASSWORD
```
Where `{UPPER_NAME}` is your index name converted to **uppercase** with **hyphens replaced by underscores**.
For example, an index named `my-private-registry` uses:
| Variable | Value |
|----------|-------|
| `UV_INDEX_MY_PRIVATE_REGISTRY_USERNAME` | Your registry username or token name |
| `UV_INDEX_MY_PRIVATE_REGISTRY_PASSWORD` | Your registry password or token/PAT |
<Warning>
These environment variables **must** be added via the CrewAI AMP **Environment Variables** settings —
either globally or at the deployment level. They cannot be set in `.env` files or hardcoded in your project.
See [Setting Environment Variables in AMP](#setting-environment-variables-in-amp) below.
</Warning>
## Registry Provider Reference
The table below shows the index URL format and credential values for common registry providers.
Replace placeholder values with your actual organization and feed details.
| Provider | Index URL | Username | Password |
|----------|-----------|----------|----------|
| **Azure DevOps Artifacts** | `https://pkgs.dev.azure.com/{org}/_packaging/{feed}/pypi/simple/` | Any non-empty string (e.g. `token`) | Personal Access Token (PAT) with Packaging Read scope |
| **GitHub Packages** | `https://pypi.pkg.github.com/{owner}/simple/` | GitHub username | Personal Access Token (classic) with `read:packages` scope |
| **GitLab Package Registry** | `https://gitlab.com/api/v4/projects/{project_id}/packages/pypi/simple/` | `__token__` | Project or Personal Access Token with `read_api` scope |
| **AWS CodeArtifact** | Use the URL from `aws codeartifact get-repository-endpoint` | `aws` | Token from `aws codeartifact get-authorization-token` |
| **Google Artifact Registry** | `https://{region}-python.pkg.dev/{project}/{repo}/simple/` | `_json_key_base64` | Base64-encoded service account key |
| **JFrog Artifactory** | `https://{instance}.jfrog.io/artifactory/api/pypi/{repo}/simple/` | Username or email | API key or identity token |
| **Self-hosted (devpi, Nexus, etc.)** | Your registry's simple API URL | Registry username | Registry password |
<Tip>
For **AWS CodeArtifact**, the authorization token expires periodically.
You'll need to refresh the `UV_INDEX_*_PASSWORD` value when it expires.
Consider automating this in your CI/CD pipeline.
</Tip>
## Setting Environment Variables in AMP
Private registry credentials must be configured as environment variables in CrewAI AMP.
You have two options:
<Tabs>
<Tab title="Web Interface">
1. Log in to [CrewAI AMP](https://app.crewai.com)
2. Navigate to your automation
3. Open the **Environment Variables** tab
4. Add each variable (`UV_INDEX_*_USERNAME` and `UV_INDEX_*_PASSWORD`) with its value
See the [Deploy to AMP — Set Environment Variables](/en/enterprise/guides/deploy-to-amp#set-environment-variables) step for details.
</Tab>
<Tab title="CLI Deployment">
Add the variables to your local `.env` file before running `crewai deploy create`.
The CLI will securely transfer them to the platform:
```bash
# .env
OPENAI_API_KEY=sk-...
UV_INDEX_MY_PRIVATE_REGISTRY_USERNAME=token
UV_INDEX_MY_PRIVATE_REGISTRY_PASSWORD=your-pat-here
```
```bash
crewai deploy create
```
</Tab>
</Tabs>
<Warning>
**Never** commit credentials to your repository. Use AMP environment variables for all secrets.
The `.env` file should be listed in `.gitignore`.
</Warning>
To update credentials on an existing deployment, see [Update Your Crew — Environment Variables](/en/enterprise/guides/update-crew).
## How It All Fits Together
When CrewAI AMP builds your automation, the resolution flow works like this:
<Steps>
<Step title="Build starts">
AMP pulls your repository and reads `pyproject.toml` and `uv.lock`.
</Step>
<Step title="UV resolves dependencies">
UV reads `[tool.uv.sources]` to determine which index each package should come from.
</Step>
<Step title="UV authenticates">
For each private index, UV looks up `UV_INDEX_{NAME}_USERNAME` and `UV_INDEX_{NAME}_PASSWORD`
from the environment variables you configured in AMP.
</Step>
<Step title="Packages install">
UV downloads and installs all packages — both public (from PyPI) and private (from your registry).
</Step>
<Step title="Automation runs">
Your crew or flow starts with all dependencies available.
</Step>
</Steps>
## Troubleshooting
### Authentication Errors During Build
**Symptom**: Build fails with `401 Unauthorized` or `403 Forbidden` when resolving a private package.
**Check**:
- The `UV_INDEX_*` environment variable names match your index name exactly (uppercased, hyphens → underscores)
- Credentials are set in AMP environment variables, not just in a local `.env`
- Your token/PAT has the required read permissions for the package feed
- The token hasn't expired (especially relevant for AWS CodeArtifact)
### Package Not Found
**Symptom**: `No matching distribution found for my-private-package`.
**Check**:
- The index URL in `pyproject.toml` ends with `/simple/`
- The `[tool.uv.sources]` entry maps the correct package name to the correct index name
- The package is actually published to your private registry
- Run `uv lock` locally with the same credentials to verify resolution works
### Lock File Conflicts
**Symptom**: `uv lock` fails or produces unexpected results after adding a private index.
**Solution**: Set the credentials locally and regenerate:
```bash
export UV_INDEX_MY_PRIVATE_REGISTRY_USERNAME=token
export UV_INDEX_MY_PRIVATE_REGISTRY_PASSWORD=your-pat
uv lock
```
Then commit the updated `uv.lock`.
## Related Guides
<CardGroup cols={3}>
<Card title="Prepare for Deployment" icon="clipboard-check" href="/en/enterprise/guides/prepare-for-deployment">
Verify project structure and dependencies before deploying.
</Card>
<Card title="Deploy to AMP" icon="rocket" href="/en/enterprise/guides/deploy-to-amp">
Deploy your crew or flow and configure environment variables.
</Card>
<Card title="Update Your Crew" icon="arrows-rotate" href="/en/enterprise/guides/update-crew">
Update environment variables and push changes to a running deployment.
</Card>
</CardGroup>

View File

@@ -98,33 +98,43 @@ def handle_feedback(self, result):
When you specify `emit`, the decorator becomes a router. The human's free-form feedback is interpreted by an LLM and collapsed into one of the specified outcomes:
```python Code
@start()
@human_feedback(
message="Do you approve this content for publication?",
emit=["approved", "rejected", "needs_revision"],
llm="gpt-4o-mini",
default_outcome="needs_revision",
)
def review_content(self):
return "Draft blog post content here..."
from crewai.flow.flow import Flow, start, listen, or_
from crewai.flow.human_feedback import human_feedback
@listen("approved")
def publish(self, result):
print(f"Publishing! User said: {result.feedback}")
class ReviewFlow(Flow):
@start()
def generate_content(self):
return "Draft blog post content here..."
@listen("rejected")
def discard(self, result):
print(f"Discarding. Reason: {result.feedback}")
@human_feedback(
message="Do you approve this content for publication?",
emit=["approved", "rejected", "needs_revision"],
llm="gpt-4o-mini",
default_outcome="needs_revision",
)
@listen(or_("generate_content", "needs_revision"))
def review_content(self):
return "Draft blog post content here..."
@listen("needs_revision")
def revise(self, result):
print(f"Revising based on: {result.feedback}")
@listen("approved")
def publish(self, result):
print(f"Publishing! User said: {result.feedback}")
@listen("rejected")
def discard(self, result):
print(f"Discarding. Reason: {result.feedback}")
```
When the human says something like "needs more detail", the LLM collapses that to `"needs_revision"`, which triggers `review_content` again via `or_()` — creating a revision loop. The loop continues until the outcome is `"approved"` or `"rejected"`.
<Tip>
The LLM uses structured outputs (function calling) when available to guarantee the response is one of your specified outcomes. This makes routing reliable and predictable.
</Tip>
<Warning>
A `@start()` method only runs once at the beginning of the flow. If you need a revision loop, separate the start method from the review method and use `@listen(or_("trigger", "revision_outcome"))` on the review method to enable the self-loop.
</Warning>
## HumanFeedbackResult
The `HumanFeedbackResult` dataclass contains all information about a human feedback interaction:
@@ -188,127 +198,183 @@ Each `HumanFeedbackResult` is appended to `human_feedback_history`, so multiple
## Complete Example: Content Approval Workflow
Here's a full example implementing a content review and approval workflow:
Here's a full example implementing a content review and approval workflow with a revision loop:
<CodeGroup>
```python Code
from crewai.flow.flow import Flow, start, listen
from crewai.flow.flow import Flow, start, listen, or_
from crewai.flow.human_feedback import human_feedback, HumanFeedbackResult
from pydantic import BaseModel
class ContentState(BaseModel):
topic: str = ""
draft: str = ""
final_content: str = ""
revision_count: int = 0
status: str = "pending"
class ContentApprovalFlow(Flow[ContentState]):
"""A flow that generates content and gets human approval."""
"""A flow that generates content and loops until the human approves."""
@start()
def get_topic(self):
self.state.topic = input("What topic should I write about? ")
return self.state.topic
@listen(get_topic)
def generate_draft(self, topic):
# In real use, this would call an LLM
self.state.draft = f"# {topic}\n\nThis is a draft about {topic}..."
def generate_draft(self):
self.state.draft = "# AI Safety\n\nThis is a draft about AI Safety..."
return self.state.draft
@listen(generate_draft)
@human_feedback(
message="Please review this draft. Reply 'approved', 'rejected', or provide revision feedback:",
message="Please review this draft. Approve, reject, or describe what needs changing:",
emit=["approved", "rejected", "needs_revision"],
llm="gpt-4o-mini",
default_outcome="needs_revision",
)
def review_draft(self, draft):
return draft
@listen(or_("generate_draft", "needs_revision"))
def review_draft(self):
self.state.revision_count += 1
return f"{self.state.draft} (v{self.state.revision_count})"
@listen("approved")
def publish_content(self, result: HumanFeedbackResult):
self.state.final_content = result.output
print("\n✅ Content approved and published!")
print(f"Reviewer comment: {result.feedback}")
self.state.status = "published"
print(f"Content approved and published! Reviewer said: {result.feedback}")
return "published"
@listen("rejected")
def handle_rejection(self, result: HumanFeedbackResult):
print("\n❌ Content rejected")
print(f"Reason: {result.feedback}")
self.state.status = "rejected"
print(f"Content rejected. Reason: {result.feedback}")
return "rejected"
@listen("needs_revision")
def revise_content(self, result: HumanFeedbackResult):
self.state.revision_count += 1
print(f"\n📝 Revision #{self.state.revision_count} requested")
print(f"Feedback: {result.feedback}")
# In a real flow, you might loop back to generate_draft
# For this example, we just acknowledge
return "revision_requested"
# Run the flow
flow = ContentApprovalFlow()
result = flow.kickoff()
print(f"\nFlow completed. Revisions requested: {flow.state.revision_count}")
print(f"\nFlow completed. Status: {flow.state.status}, Reviews: {flow.state.revision_count}")
```
```text Output
What topic should I write about? AI Safety
==================================================
OUTPUT FOR REVIEW:
==================================================
# AI Safety
This is a draft about AI Safety... (v1)
==================================================
Please review this draft. Approve, reject, or describe what needs changing:
(Press Enter to skip, or type your feedback)
Your feedback: Needs more detail on alignment research
==================================================
OUTPUT FOR REVIEW:
==================================================
# AI Safety
This is a draft about AI Safety...
This is a draft about AI Safety... (v2)
==================================================
Please review this draft. Reply 'approved', 'rejected', or provide revision feedback:
Please review this draft. Approve, reject, or describe what needs changing:
(Press Enter to skip, or type your feedback)
Your feedback: Looks good, approved!
Content approved and published!
Reviewer comment: Looks good, approved!
Content approved and published! Reviewer said: Looks good, approved!
Flow completed. Revisions requested: 0
Flow completed. Status: published, Reviews: 2
```
</CodeGroup>
The key pattern is `@listen(or_("generate_draft", "needs_revision"))` — the review method listens to both the initial trigger and its own revision outcome, creating a self-loop that repeats until the human approves or rejects.
## Combining with Other Decorators
The `@human_feedback` decorator works with other flow decorators. Place it as the innermost decorator (closest to the function):
The `@human_feedback` decorator works with `@start()`, `@listen()`, and `or_()`. Both decorator orderings work — the framework propagates attributes in both directions — but the recommended patterns are:
```python Code
# Correct: @human_feedback is innermost (closest to the function)
# One-shot review at the start of a flow (no self-loop)
@start()
@human_feedback(message="Review this:")
@human_feedback(message="Review this:", emit=["approved", "rejected"], llm="gpt-4o-mini")
def my_start_method(self):
return "content"
# Linear review on a listener (no self-loop)
@listen(other_method)
@human_feedback(message="Review this too:")
@human_feedback(message="Review this too:", emit=["good", "bad"], llm="gpt-4o-mini")
def my_listener(self, data):
return f"processed: {data}"
# Self-loop: review that can loop back for revisions
@human_feedback(message="Approve or revise?", emit=["approved", "revise"], llm="gpt-4o-mini")
@listen(or_("upstream_method", "revise"))
def review_with_loop(self):
return "content for review"
```
<Tip>
Place `@human_feedback` as the innermost decorator (last/closest to the function) so it wraps the method directly and can capture the return value before passing to the flow system.
</Tip>
### Self-loop pattern
To create a revision loop, the review method must listen to **both** an upstream trigger and its own revision outcome using `or_()`:
```python Code
@start()
def generate(self):
return "initial draft"
@human_feedback(
message="Approve or request changes?",
emit=["revise", "approved"],
llm="gpt-4o-mini",
default_outcome="approved",
)
@listen(or_("generate", "revise"))
def review(self):
return "content"
@listen("approved")
def publish(self):
return "published"
```
When the outcome is `"revise"`, the flow routes back to `review` (because it listens to `"revise"` via `or_()`). When the outcome is `"approved"`, the flow continues to `publish`. This works because the flow engine exempts routers from the "fire once" rule, allowing them to re-execute on each loop iteration.
### Chained routers
A listener triggered by one router's outcome can itself be a router:
```python Code
@start()
def generate(self):
return "draft content"
@human_feedback(message="First review:", emit=["approved", "rejected"], llm="gpt-4o-mini")
@listen("generate")
def first_review(self):
return "draft content"
@human_feedback(message="Final review:", emit=["publish", "hold"], llm="gpt-4o-mini")
@listen("approved")
def final_review(self, prev):
return "final content"
@listen("publish")
def on_publish(self, prev):
return "published"
@listen("hold")
def on_hold(self, prev):
return "held for later"
```
### Limitations
- **`@start()` methods run once**: A `@start()` method cannot self-loop. If you need a revision cycle, use a separate `@start()` method as the entry point and put the `@human_feedback` on a `@listen()` method.
- **No `@start()` + `@listen()` on the same method**: This is a Flow framework constraint. A method is either a start point or a listener, not both.
## Best Practices
### 1. Write Clear Request Messages
The `request` parameter is what the human sees. Make it actionable:
The `message` parameter is what the human sees. Make it actionable:
```python Code
# ✅ Good - clear and actionable
@@ -516,9 +582,9 @@ class ContentPipeline(Flow):
@start()
@human_feedback(
message="Approve this content for publication?",
emit=["approved", "rejected", "needs_revision"],
emit=["approved", "rejected"],
llm="gpt-4o-mini",
default_outcome="needs_revision",
default_outcome="rejected",
provider=SlackNotificationProvider("#content-reviews"),
)
def generate_content(self):
@@ -534,11 +600,6 @@ class ContentPipeline(Flow):
print(f"Archived. Reason: {result.feedback}")
return {"status": "archived"}
@listen("needs_revision")
def queue_revision(self, result):
print(f"Queued for revision: {result.feedback}")
return {"status": "revision_needed"}
# Starting the flow (will pause and wait for Slack response)
def start_content_pipeline():
@@ -594,22 +655,22 @@ Over time, the human sees progressively better pre-reviewed output because each
```python Code
class ArticleReviewFlow(Flow):
@start()
def generate_article(self):
return self.crew.kickoff(inputs={"topic": "AI Safety"}).raw
@human_feedback(
message="Review this article draft:",
emit=["approved", "needs_revision"],
llm="gpt-4o-mini",
learn=True, # enable HITL learning
)
def generate_article(self):
return self.crew.kickoff(inputs={"topic": "AI Safety"}).raw
@listen(or_("generate_article", "needs_revision"))
def review_article(self):
return self.last_human_feedback.output if self.last_human_feedback else "article draft"
@listen("approved")
def publish(self):
print(f"Publishing: {self.last_human_feedback.output}")
@listen("needs_revision")
def revise(self):
print("Revising based on feedback...")
```
**First run**: The human sees the raw output and says "Always include citations for factual claims." The lesson is distilled and stored in memory.

View File

@@ -7,7 +7,7 @@ mode: "wide"
## Connect CrewAI to LLMs
CrewAI uses LiteLLM to connect to a wide variety of Language Models (LLMs). This integration provides extensive versatility, allowing you to use models from numerous providers with a simple, unified interface.
CrewAI connects to LLMs through native SDK integrations for the most popular providers (OpenAI, Anthropic, Google Gemini, Azure, and AWS Bedrock), and uses LiteLLM as a flexible fallback for all other providers.
<Note>
By default, CrewAI uses the `gpt-4o-mini` model. This is determined by the `OPENAI_MODEL_NAME` environment variable, which defaults to "gpt-4o-mini" if not set.
@@ -41,6 +41,14 @@ LiteLLM supports a wide range of providers, including but not limited to:
For a complete and up-to-date list of supported providers, please refer to the [LiteLLM Providers documentation](https://docs.litellm.ai/docs/providers).
<Info>
To use any provider not covered by a native integration, add LiteLLM as a dependency to your project:
```bash
uv add 'crewai[litellm]'
```
Native providers (OpenAI, Anthropic, Google Gemini, Azure, AWS Bedrock) use their own SDK extras — see the [Provider Configuration Examples](/en/concepts/llms#provider-configuration-examples).
</Info>
## Changing the LLM
To use a different LLM with your CrewAI agents, you have several options:

View File

@@ -35,7 +35,7 @@ Visit [app.crewai.com](https://app.crewai.com) and create your free account. Thi
If you haven't already, install CrewAI with the CLI tools:
```bash
uv add crewai[tools]
uv add 'crewai[tools]'
```
Then authenticate your CLI with your CrewAI AMP account:

View File

@@ -4,6 +4,56 @@ description: "CrewAI의 제품 업데이트, 개선 사항 및 버그 수정"
icon: "clock"
mode: "wide"
---
<Update label="2026년 2월 26일">
## v1.10.0
[GitHub 릴리스 보기](https://github.com/crewAIInc/crewAI/releases/tag/1.10.0)
## 변경 사항
### 기능
- MCP 도구 해상도 및 관련 이벤트 개선
- lancedb 버전 업데이트 및 lance-namespace 패키지 추가
- CrewAgentExecutor 및 BaseTool에서 JSON 인수 파싱 및 검증 개선
- CLI HTTP 클라이언트를 requests에서 httpx로 마이그레이션
- 버전화된 문서 추가
- 버전 노트에 대한 yanked 감지 추가
- Flows에서 사용자 입력 처리 구현
- 인간 피드백 통합 테스트에서 HITL 자기 루프 기능 개선
- eventbus에 started_event_id 추가 및 설정
- tools.specs 자동 업데이트
### 버그 수정
- 빈 경우에도 도구 kwargs를 검증하여 모호한 TypeError 방지
- LLM을 위한 도구 매개변수 스키마에서 null 타입 유지
- output_pydantic/output_json을 네이티브 구조화된 출력으로 매핑
- 약속이 있는 경우 콜백이 실행/대기되도록 보장
- 예외 컨텍스트에서 메서드 이름 캡처
- 라우터 결과에서 enum 타입 유지; 타입 개선
- 입력으로 지속성 ID가 전달될 때 조용히 깨지는 순환 흐름 수정
- CLI 플래그 형식을 --skip-provider에서 --skip_provider로 수정
- OpenAI 도구 호출 스트림이 완료되도록 보장
- MCP 도구에서 복잡한 스키마 $ref 포인터 해결
- 스키마에서 additionalProperties=false 강제 적용
- 크루 폴더에 대해 예약된 스크립트 이름 거부
- 가드레일 이벤트 방출 테스트에서 경쟁 조건 해결
### 문서
- 비네이티브 LLM 공급자를 위한 litellm 종속성 노트 추가
- NL2SQL 보안 모델 및 강화 지침 명확화
- 9개 통합에서 96개의 누락된 작업 추가
### 리팩토링
- crew를 provider로 리팩토링
- HITL을 provider 패턴으로 추출
- 훅 타이핑 및 등록 개선
## 기여자
@dependabot[bot], @github-actions[bot], @github-code-quality[bot], @greysonlalonde, @heitorado, @hobostay, @joaomdmoura, @johnvan7, @jonathansampson, @lorenzejay, @lucasgomide, @mattatcha, @mplachta, @nicoferdi96, @theCyberTech, @thiagomoretto, @vinibrsl
</Update>
<Update label="2026년 1월 26일">
## v1.9.0

View File

@@ -105,6 +105,15 @@ CrewAI 코드 내에는 사용할 모델을 지정할 수 있는 여러 위치
</Tab>
</Tabs>
<Info>
CrewAI는 OpenAI, Anthropic, Google (Gemini API), Azure, AWS Bedrock에 대해 네이티브 SDK 통합을 제공합니다 — 제공자별 extras(예: `uv add "crewai[openai]"`) 외에 추가 설치가 필요하지 않습니다.
그 외 모든 제공자는 **LiteLLM**을 통해 지원됩니다. 이를 사용하려면 프로젝트에 의존성으로 추가하세요:
```bash
uv add 'crewai[litellm]'
```
</Info>
## 공급자 구성 예시
CrewAI는 고유한 기능, 인증 방법, 모델 역량을 제공하는 다양한 LLM 공급자를 지원합니다.
@@ -214,6 +223,11 @@ CrewAI는 고유한 기능, 인증 방법, 모델 역량을 제공하는 다양
| `meta_llama/Llama-4-Maverick-17B-128E-Instruct-FP8` | 128k | 4028 | 텍스트, 이미지 | 텍스트 |
| `meta_llama/Llama-3.3-70B-Instruct` | 128k | 4028 | 텍스트 | 텍스트 |
| `meta_llama/Llama-3.3-8B-Instruct` | 128k | 4028 | 텍스트 | 텍스트 |
**참고:** 이 제공자는 LiteLLM을 사용합니다. 프로젝트에 의존성으로 추가하세요:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Anthropic">
@@ -354,6 +368,11 @@ CrewAI는 고유한 기능, 인증 방법, 모델 역량을 제공하는 다양
| gemini-1.5-flash | 1M 토큰 | 밸런스 잡힌 멀티모달 모델, 대부분의 작업에 적합 |
| gemini-1.5-flash-8B | 1M 토큰 | 가장 빠르고, 비용 효율적, 고빈도 작업에 적합 |
| gemini-1.5-pro | 2M 토큰 | 최고의 성능, 논리적 추론, 코딩, 창의적 협업 등 다양한 추론 작업에 적합 |
**참고:** 이 제공자는 LiteLLM을 사용합니다. 프로젝트에 의존성으로 추가하세요:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Azure">
@@ -439,6 +458,11 @@ CrewAI는 고유한 기능, 인증 방법, 모델 역량을 제공하는 다양
model="sagemaker/<my-endpoint>"
)
```
**참고:** 이 제공자는 LiteLLM을 사용합니다. 프로젝트에 의존성으로 추가하세요:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Mistral">
@@ -454,6 +478,11 @@ CrewAI는 고유한 기능, 인증 방법, 모델 역량을 제공하는 다양
temperature=0.7
)
```
**참고:** 이 제공자는 LiteLLM을 사용합니다. 프로젝트에 의존성으로 추가하세요:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Nvidia NIM">
@@ -540,6 +569,11 @@ CrewAI는 고유한 기능, 인증 방법, 모델 역량을 제공하는 다양
| rakuten/rakutenai-7b-instruct | 1,024 토큰 | 언어 이해, 추론, 텍스트 생성이 탁월한 최첨단 LLM |
| rakuten/rakutenai-7b-chat | 1,024 토큰 | 언어 이해, 추론, 텍스트 생성이 탁월한 최첨단 LLM |
| baichuan-inc/baichuan2-13b-chat | 4,096 토큰 | 중국어 및 영어 대화, 코딩, 수학, 지시 따르기, 퀴즈 풀이 지원 |
**참고:** 이 제공자는 LiteLLM을 사용합니다. 프로젝트에 의존성으로 추가하세요:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Local NVIDIA NIM Deployed using WSL2">
@@ -580,6 +614,11 @@ CrewAI는 고유한 기능, 인증 방법, 모델 역량을 제공하는 다양
# ...
```
**참고:** 이 제공자는 LiteLLM을 사용합니다. 프로젝트에 의존성으로 추가하세요:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Groq">
@@ -601,6 +640,11 @@ CrewAI는 고유한 기능, 인증 방법, 모델 역량을 제공하는 다양
| Llama 3.1 70B/8B| 131,072 토큰 | 고성능, 대용량 문맥 작업 |
| Llama 3.2 Series| 8,192 토큰 | 범용 작업 |
| Mixtral 8x7B | 32,768 토큰 | 성능과 문맥의 균형 |
**참고:** 이 제공자는 LiteLLM을 사용합니다. 프로젝트에 의존성으로 추가하세요:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="IBM watsonx.ai">
@@ -623,6 +667,11 @@ CrewAI는 고유한 기능, 인증 방법, 모델 역량을 제공하는 다양
base_url="https://api.watsonx.ai/v1"
)
```
**참고:** 이 제공자는 LiteLLM을 사용합니다. 프로젝트에 의존성으로 추가하세요:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Ollama (Local LLMs)">
@@ -636,6 +685,11 @@ CrewAI는 고유한 기능, 인증 방법, 모델 역량을 제공하는 다양
base_url="http://localhost:11434"
)
```
**참고:** 이 제공자는 LiteLLM을 사용합니다. 프로젝트에 의존성으로 추가하세요:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Fireworks AI">
@@ -651,6 +705,11 @@ CrewAI는 고유한 기능, 인증 방법, 모델 역량을 제공하는 다양
temperature=0.7
)
```
**참고:** 이 제공자는 LiteLLM을 사용합니다. 프로젝트에 의존성으로 추가하세요:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Perplexity AI">
@@ -666,6 +725,11 @@ CrewAI는 고유한 기능, 인증 방법, 모델 역량을 제공하는 다양
base_url="https://api.perplexity.ai/"
)
```
**참고:** 이 제공자는 LiteLLM을 사용합니다. 프로젝트에 의존성으로 추가하세요:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Hugging Face">
@@ -680,6 +744,11 @@ CrewAI는 고유한 기능, 인증 방법, 모델 역량을 제공하는 다양
model="huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct"
)
```
**참고:** 이 제공자는 LiteLLM을 사용합니다. 프로젝트에 의존성으로 추가하세요:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="SambaNova">
@@ -703,6 +772,11 @@ CrewAI는 고유한 기능, 인증 방법, 모델 역량을 제공하는 다양
| Llama 3.2 Series| 8,192 토큰 | 범용, 멀티모달 작업 |
| Llama 3.3 70B | 최대 131,072 토큰 | 고성능, 높은 출력 품질 |
| Qwen2 familly | 8,192 토큰 | 고성능, 높은 출력 품질 |
**참고:** 이 제공자는 LiteLLM을 사용합니다. 프로젝트에 의존성으로 추가하세요:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Cerebras">
@@ -728,6 +802,11 @@ CrewAI는 고유한 기능, 인증 방법, 모델 역량을 제공하는 다양
- 속도와 품질의 우수한 밸런스
- 긴 컨텍스트 윈도우 지원
</Info>
**참고:** 이 제공자는 LiteLLM을 사용합니다. 프로젝트에 의존성으로 추가하세요:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Open Router">
@@ -750,6 +829,11 @@ CrewAI는 고유한 기능, 인증 방법, 모델 역량을 제공하는 다양
- openrouter/deepseek/deepseek-r1
- openrouter/deepseek/deepseek-chat
</Info>
**참고:** 이 제공자는 LiteLLM을 사용합니다. 프로젝트에 의존성으로 추가하세요:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Nebius AI Studio">
@@ -772,6 +856,11 @@ CrewAI는 고유한 기능, 인증 방법, 모델 역량을 제공하는 다양
- 경쟁력 있는 가격
- 속도와 품질의 우수한 밸런스
</Info>
**참고:** 이 제공자는 LiteLLM을 사용합니다. 프로젝트에 의존성으로 추가하세요:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
</AccordionGroup>

View File

@@ -38,22 +38,21 @@ CrewAI Enterprise는 AI 워크플로우를 협업적인 인간-AI 프로세스
`@human_feedback` 데코레이터를 사용하여 Flow 내에 인간 검토 체크포인트를 구성합니다. 실행이 검토 포인트에 도달하면 시스템이 일시 중지되고, 담당자에게 이메일로 알리며, 응답을 기다립니다.
```python
from crewai.flow.flow import Flow, start, listen
from crewai.flow.flow import Flow, start, listen, or_
from crewai.flow.human_feedback import human_feedback, HumanFeedbackResult
class ContentApprovalFlow(Flow):
@start()
def generate_content(self):
# AI가 콘텐츠 생성
return "Q1 캠페인용 마케팅 카피 생성..."
@listen(generate_content)
@human_feedback(
message="브랜드 준수를 위해 이 콘텐츠를 검토해 주세요:",
emit=["approved", "rejected", "needs_revision"],
)
def review_content(self, content):
return content
@listen(or_("generate_content", "needs_revision"))
def review_content(self):
return "검토용 마케팅 카피..."
@listen("approved")
def publish_content(self, result: HumanFeedbackResult):
@@ -62,10 +61,6 @@ class ContentApprovalFlow(Flow):
@listen("rejected")
def archive_content(self, result: HumanFeedbackResult):
print(f"콘텐츠 거부됨. 사유: {result.feedback}")
@listen("needs_revision")
def revise_content(self, result: HumanFeedbackResult):
print(f"수정 요청: {result.feedback}")
```
완전한 구현 세부 사항은 [Flow에서 인간 피드백](/ko/learn/human-feedback-in-flows) 가이드를 참조하세요.

View File

@@ -176,6 +176,11 @@ Crew를 GitHub 저장소에 푸시해야 합니다. 아직 Crew를 만들지 않
![Set Environment Variables](/images/enterprise/set-env-variables.png)
</Frame>
<Info>
프라이빗 Python 패키지를 사용하시나요? 여기에 레지스트리 자격 증명도 추가해야 합니다.
필요한 변수는 [프라이빗 패키지 레지스트리](/ko/enterprise/guides/private-package-registry)를 참조하세요.
</Info>
</Step>
<Step title="Crew 배포하기">

View File

@@ -256,6 +256,12 @@ Crews와 Flows 모두 `src/project_name/main.py`에 진입점이 있습니다:
1. **LLM API 키** (OpenAI, Anthropic, Google 등)
2. **도구 API 키** - 외부 도구를 사용하는 경우 (Serper 등)
<Info>
프로젝트가 **프라이빗 PyPI 레지스트리**의 패키지에 의존하는 경우, 레지스트리 인증 자격 증명도
환경 변수로 구성해야 합니다. 자세한 내용은
[프라이빗 패키지 레지스트리](/ko/enterprise/guides/private-package-registry) 가이드를 참조하세요.
</Info>
<Tip>
구성 문제를 조기에 발견하기 위해 배포 전에 동일한 환경 변수로
로컬에서 프로젝트를 테스트하세요.

View File

@@ -0,0 +1,261 @@
---
title: "프라이빗 패키지 레지스트리"
description: "CrewAI AMP에서 인증된 PyPI 레지스트리의 프라이빗 Python 패키지 설치하기"
icon: "lock"
mode: "wide"
---
<Note>
이 가이드는 CrewAI AMP에 배포할 때 프라이빗 PyPI 레지스트리(Azure DevOps Artifacts, GitHub Packages,
GitLab, AWS CodeArtifact 등)에서 Python 패키지를 설치하도록 CrewAI 프로젝트를 구성하는 방법을 다룹니다.
</Note>
## 이 가이드가 필요한 경우
프로젝트가 공개 PyPI가 아닌 프라이빗 레지스트리에 호스팅된 내부 또는 독점 Python 패키지에
의존하는 경우, 다음을 수행해야 합니다:
1. UV에 패키지를 **어디서** 찾을지 알려줍니다 (index URL)
2. UV에 **어떤** 패키지가 해당 index에서 오는지 알려줍니다 (source 매핑)
3. UV가 설치 중에 인증할 수 있도록 **자격 증명**을 제공합니다
CrewAI AMP는 의존성 해결 및 설치에 [UV](https://docs.astral.sh/uv/)를 사용합니다.
UV는 `pyproject.toml` 구성과 자격 증명용 환경 변수를 결합하여 인증된 프라이빗 레지스트리를 지원합니다.
## 1단계: pyproject.toml 구성
`pyproject.toml`에서 세 가지 요소가 함께 작동합니다:
### 1a. 의존성 선언
프라이빗 패키지를 다른 의존성과 마찬가지로 `[project.dependencies]`에 추가합니다:
```toml
[project]
dependencies = [
"crewai[tools]>=0.100.1,<1.0.0",
"my-private-package>=1.2.0",
]
```
### 1b. index 정의
프라이빗 레지스트리를 `[[tool.uv.index]]` 아래에 명명된 index로 등록합니다:
```toml
[[tool.uv.index]]
name = "my-private-registry"
url = "https://pkgs.dev.azure.com/my-org/_packaging/my-feed/pypi/simple/"
explicit = true
```
<Info>
`name` 필드는 중요합니다 — UV는 이를 사용하여 인증을 위한 환경 변수 이름을
구성합니다 (아래 [2단계](#2단계-인증-자격-증명-설정)를 참조하세요).
`explicit = true`를 설정하면 UV가 모든 패키지에 대해 이 index를 검색하지 않습니다 —
`[tool.uv.sources]`에서 명시적으로 매핑한 패키지만 검색합니다. 이렇게 하면 프라이빗
레지스트리에 대한 불필요한 쿼리를 방지하고 의존성 혼동 공격을 차단할 수 있습니다.
</Info>
### 1c. 패키지를 index에 매핑
`[tool.uv.sources]`를 사용하여 프라이빗 index에서 해결해야 할 패키지를 UV에 알려줍니다:
```toml
[tool.uv.sources]
my-private-package = { index = "my-private-registry" }
```
### 전체 예시
```toml
[project]
name = "my-crew-project"
version = "0.1.0"
requires-python = ">=3.10,<=3.13"
dependencies = [
"crewai[tools]>=0.100.1,<1.0.0",
"my-private-package>=1.2.0",
]
[tool.crewai]
type = "crew"
[[tool.uv.index]]
name = "my-private-registry"
url = "https://pkgs.dev.azure.com/my-org/_packaging/my-feed/pypi/simple/"
explicit = true
[tool.uv.sources]
my-private-package = { index = "my-private-registry" }
```
`pyproject.toml`을 업데이트한 후 lock 파일을 다시 생성합니다:
```bash
uv lock
```
<Warning>
업데이트된 `uv.lock`을 항상 `pyproject.toml` 변경 사항과 함께 커밋하세요.
lock 파일은 배포에 필수입니다 — [배포 준비하기](/ko/enterprise/guides/prepare-for-deployment)를 참조하세요.
</Warning>
## 2단계: 인증 자격 증명 설정
UV는 `pyproject.toml`에서 정의한 index 이름을 기반으로 한 명명 규칙을 따르는
환경 변수를 사용하여 프라이빗 index에 인증합니다:
```
UV_INDEX_{UPPER_NAME}_USERNAME
UV_INDEX_{UPPER_NAME}_PASSWORD
```
여기서 `{UPPER_NAME}`은 index 이름을 **대문자**로 변환하고 **하이픈을 언더스코어로 대체**한 것입니다.
예를 들어, `my-private-registry`라는 이름의 index는 다음을 사용합니다:
| 변수 | 값 |
|------|-----|
| `UV_INDEX_MY_PRIVATE_REGISTRY_USERNAME` | 레지스트리 사용자 이름 또는 토큰 이름 |
| `UV_INDEX_MY_PRIVATE_REGISTRY_PASSWORD` | 레지스트리 비밀번호 또는 토큰/PAT |
<Warning>
이 환경 변수는 CrewAI AMP **환경 변수** 설정을 통해 **반드시** 추가해야 합니다 —
전역적으로 또는 배포 수준에서. `.env` 파일에 설정하거나 프로젝트에 하드코딩할 수 없습니다.
아래 [AMP에서 환경 변수 설정](#amp에서-환경-변수-설정)을 참조하세요.
</Warning>
## 레지스트리 제공업체 참조
아래 표는 일반적인 레지스트리 제공업체의 index URL 형식과 자격 증명 값을 보여줍니다.
자리 표시자 값을 실제 조직 및 피드 세부 정보로 대체하세요.
| 제공업체 | Index URL | 사용자 이름 | 비밀번호 |
|---------|-----------|-----------|---------|
| **Azure DevOps Artifacts** | `https://pkgs.dev.azure.com/{org}/_packaging/{feed}/pypi/simple/` | 비어 있지 않은 임의의 문자열 (예: `token`) | Packaging Read 범위의 Personal Access Token (PAT) |
| **GitHub Packages** | `https://pypi.pkg.github.com/{owner}/simple/` | GitHub 사용자 이름 | `read:packages` 범위의 Personal Access Token (classic) |
| **GitLab Package Registry** | `https://gitlab.com/api/v4/projects/{project_id}/packages/pypi/simple/` | `__token__` | `read_api` 범위의 Project 또는 Personal Access Token |
| **AWS CodeArtifact** | `aws codeartifact get-repository-endpoint`의 URL 사용 | `aws` | `aws codeartifact get-authorization-token`의 토큰 |
| **Google Artifact Registry** | `https://{region}-python.pkg.dev/{project}/{repo}/simple/` | `_json_key_base64` | Base64로 인코딩된 서비스 계정 키 |
| **JFrog Artifactory** | `https://{instance}.jfrog.io/artifactory/api/pypi/{repo}/simple/` | 사용자 이름 또는 이메일 | API 키 또는 ID 토큰 |
| **자체 호스팅 (devpi, Nexus 등)** | 레지스트리의 simple API URL | 레지스트리 사용자 이름 | 레지스트리 비밀번호 |
<Tip>
**AWS CodeArtifact**의 경우 인증 토큰이 주기적으로 만료됩니다.
만료되면 `UV_INDEX_*_PASSWORD` 값을 갱신해야 합니다.
CI/CD 파이프라인에서 이를 자동화하는 것을 고려하세요.
</Tip>
## AMP에서 환경 변수 설정
프라이빗 레지스트리 자격 증명은 CrewAI AMP에서 환경 변수로 구성해야 합니다.
두 가지 옵션이 있습니다:
<Tabs>
<Tab title="웹 인터페이스">
1. [CrewAI AMP](https://app.crewai.com)에 로그인합니다
2. 자동화로 이동합니다
3. **Environment Variables** 탭을 엽니다
4. 각 변수 (`UV_INDEX_*_USERNAME` 및 `UV_INDEX_*_PASSWORD`)에 값을 추가합니다
자세한 내용은 [AMP에 배포하기 — 환경 변수 설정하기](/ko/enterprise/guides/deploy-to-amp#환경-변수-설정하기) 단계를 참조하세요.
</Tab>
<Tab title="CLI 배포">
`crewai deploy create`를 실행하기 전에 로컬 `.env` 파일에 변수를 추가합니다.
CLI가 이를 안전하게 플랫폼으로 전송합니다:
```bash
# .env
OPENAI_API_KEY=sk-...
UV_INDEX_MY_PRIVATE_REGISTRY_USERNAME=token
UV_INDEX_MY_PRIVATE_REGISTRY_PASSWORD=your-pat-here
```
```bash
crewai deploy create
```
</Tab>
</Tabs>
<Warning>
자격 증명을 저장소에 **절대** 커밋하지 마세요. 모든 비밀 정보에는 AMP 환경 변수를 사용하세요.
`.env` 파일은 `.gitignore`에 포함되어야 합니다.
</Warning>
기존 배포의 자격 증명을 업데이트하려면 [Crew 업데이트하기 — 환경 변수](/ko/enterprise/guides/update-crew)를 참조하세요.
## 전체 동작 흐름
CrewAI AMP가 자동화를 빌드할 때, 해결 흐름은 다음과 같이 작동합니다:
<Steps>
<Step title="빌드 시작">
AMP가 저장소를 가져오고 `pyproject.toml`과 `uv.lock`을 읽습니다.
</Step>
<Step title="UV가 의존성 해결">
UV가 `[tool.uv.sources]`를 읽어 각 패키지가 어떤 index에서 와야 하는지 결정합니다.
</Step>
<Step title="UV가 인증">
각 프라이빗 index에 대해 UV가 AMP에서 구성한 환경 변수에서
`UV_INDEX_{NAME}_USERNAME`과 `UV_INDEX_{NAME}_PASSWORD`를 조회합니다.
</Step>
<Step title="패키지 설치">
UV가 공개(PyPI) 및 프라이빗(레지스트리) 패키지를 모두 다운로드하고 설치합니다.
</Step>
<Step title="자동화 실행">
모든 의존성이 사용 가능한 상태에서 crew 또는 flow가 시작됩니다.
</Step>
</Steps>
## 문제 해결
### 빌드 중 인증 오류
**증상**: 프라이빗 패키지를 해결할 때 `401 Unauthorized` 또는 `403 Forbidden`으로 빌드가 실패합니다.
**확인사항**:
- `UV_INDEX_*` 환경 변수 이름이 index 이름과 정확히 일치하는지 확인합니다 (대문자, 하이픈 -> 언더스코어)
- 자격 증명이 로컬 `.env`뿐만 아니라 AMP 환경 변수에 설정되어 있는지 확인합니다
- 토큰/PAT에 패키지 피드에 필요한 읽기 권한이 있는지 확인합니다
- 토큰이 만료되지 않았는지 확인합니다 (특히 AWS CodeArtifact의 경우)
### 패키지를 찾을 수 없음
**증상**: `No matching distribution found for my-private-package`.
**확인사항**:
- `pyproject.toml`의 index URL이 `/simple/`로 끝나는지 확인합니다
- `[tool.uv.sources]` 항목이 올바른 패키지 이름을 올바른 index 이름에 매핑하는지 확인합니다
- 패키지가 실제로 프라이빗 레지스트리에 게시되어 있는지 확인합니다
- 동일한 자격 증명으로 로컬에서 `uv lock`을 실행하여 해결이 작동하는지 확인합니다
### Lock 파일 충돌
**증상**: 프라이빗 index를 추가한 후 `uv lock`이 실패하거나 예상치 못한 결과를 생성합니다.
**해결책**: 로컬에서 자격 증명을 설정하고 다시 생성합니다:
```bash
export UV_INDEX_MY_PRIVATE_REGISTRY_USERNAME=token
export UV_INDEX_MY_PRIVATE_REGISTRY_PASSWORD=your-pat
uv lock
```
그런 다음 업데이트된 `uv.lock`을 커밋합니다.
## 관련 가이드
<CardGroup cols={3}>
<Card title="배포 준비하기" icon="clipboard-check" href="/ko/enterprise/guides/prepare-for-deployment">
배포 전에 프로젝트 구조와 의존성을 확인합니다.
</Card>
<Card title="AMP에 배포하기" icon="rocket" href="/ko/enterprise/guides/deploy-to-amp">
crew 또는 flow를 배포하고 환경 변수를 구성합니다.
</Card>
<Card title="Crew 업데이트하기" icon="arrows-rotate" href="/ko/enterprise/guides/update-crew">
환경 변수를 업데이트하고 실행 중인 배포에 변경 사항을 푸시합니다.
</Card>
</CardGroup>

View File

@@ -98,33 +98,43 @@ def handle_feedback(self, result):
`emit`을 지정하면, 데코레이터는 라우터가 됩니다. 인간의 자유 형식 피드백이 LLM에 의해 해석되어 지정된 outcome 중 하나로 매핑됩니다:
```python Code
@start()
@human_feedback(
message="이 콘텐츠의 출판을 승인하시겠습니까?",
emit=["approved", "rejected", "needs_revision"],
llm="gpt-4o-mini",
default_outcome="needs_revision",
)
def review_content(self):
return "블로그 게시물 초안 내용..."
from crewai.flow.flow import Flow, start, listen, or_
from crewai.flow.human_feedback import human_feedback
@listen("approved")
def publish(self, result):
print(f"출판 중! 사용자 의견: {result.feedback}")
class ReviewFlow(Flow):
@start()
def generate_content(self):
return "블로그 게시물 초안 내용..."
@listen("rejected")
def discard(self, result):
print(f"폐기됨. 이유: {result.feedback}")
@human_feedback(
message="이 콘텐츠의 출판을 승인하시겠습니까?",
emit=["approved", "rejected", "needs_revision"],
llm="gpt-4o-mini",
default_outcome="needs_revision",
)
@listen(or_("generate_content", "needs_revision"))
def review_content(self):
return "블로그 게시물 초안 내용..."
@listen("needs_revision")
def revise(self, result):
print(f"다음을 기반으로 수정 중: {result.feedback}")
@listen("approved")
def publish(self, result):
print(f"출판 중! 사용자 의견: {result.feedback}")
@listen("rejected")
def discard(self, result):
print(f"폐기됨. 이유: {result.feedback}")
```
사용자가 "더 자세한 내용이 필요합니다"와 같이 말하면, LLM이 이를 `"needs_revision"`으로 매핑하고, `or_()`를 통해 `review_content`가 다시 트리거됩니다 — 수정 루프가 생성됩니다. outcome이 `"approved"` 또는 `"rejected"`가 될 때까지 루프가 계속됩니다.
<Tip>
LLM은 가능한 경우 구조화된 출력(function calling)을 사용하여 응답이 지정된 outcome 중 하나임을 보장합니다. 이로 인해 라우팅이 신뢰할 수 있고 예측 가능해집니다.
</Tip>
<Warning>
`@start()` 메서드는 flow 시작 시 한 번만 실행됩니다. 수정 루프가 필요한 경우, start 메서드를 review 메서드와 분리하고 review 메서드에 `@listen(or_("trigger", "revision_outcome"))`를 사용하여 self-loop을 활성화하세요.
</Warning>
## HumanFeedbackResult
`HumanFeedbackResult` 데이터클래스는 인간 피드백 상호작용에 대한 모든 정보를 포함합니다:
@@ -193,116 +203,162 @@ def summarize(self):
<CodeGroup>
```python Code
from crewai.flow.flow import Flow, start, listen
from crewai.flow.flow import Flow, start, listen, or_
from crewai.flow.human_feedback import human_feedback, HumanFeedbackResult
from pydantic import BaseModel
class ContentState(BaseModel):
topic: str = ""
draft: str = ""
final_content: str = ""
revision_count: int = 0
status: str = "pending"
class ContentApprovalFlow(Flow[ContentState]):
"""콘텐츠를 생성하고 인간의 승인을 받는 Flow입니다."""
"""콘텐츠를 생성하고 승인될 때까지 반복하는 Flow."""
@start()
def get_topic(self):
self.state.topic = input("어떤 주제에 대해 글을 쓸까요? ")
return self.state.topic
@listen(get_topic)
def generate_draft(self, topic):
# 실제 사용에서는 LLM을 호출합니다
self.state.draft = f"# {topic}\n\n{topic}에 대한 초안입니다..."
def generate_draft(self):
self.state.draft = "# AI 안전\n\nAI 안전에 대한 초안..."
return self.state.draft
@listen(generate_draft)
@human_feedback(
message="이 초안을 검토해 주세요. 'approved', 'rejected'로 답하거나 수정 피드백을 제공해 주세요:",
message="이 초안을 검토해 주세요. 승인, 거부 또는 변경이 필요한 사항을 설명해 주세요:",
emit=["approved", "rejected", "needs_revision"],
llm="gpt-4o-mini",
default_outcome="needs_revision",
)
def review_draft(self, draft):
return draft
@listen(or_("generate_draft", "needs_revision"))
def review_draft(self):
self.state.revision_count += 1
return f"{self.state.draft} (v{self.state.revision_count})"
@listen("approved")
def publish_content(self, result: HumanFeedbackResult):
self.state.final_content = result.output
print("\n✅ 콘텐츠 승인되어 출판되었습니다!")
print(f"검토자 코멘트: {result.feedback}")
self.state.status = "published"
print(f"콘텐츠 승인 및 게시! 리뷰어 의견: {result.feedback}")
return "published"
@listen("rejected")
def handle_rejection(self, result: HumanFeedbackResult):
print("\n❌ 콘텐츠가 거부되었습니다")
print(f"이유: {result.feedback}")
self.state.status = "rejected"
print(f"콘텐츠 거부됨. 이유: {result.feedback}")
return "rejected"
@listen("needs_revision")
def revise_content(self, result: HumanFeedbackResult):
self.state.revision_count += 1
print(f"\n📝 수정 #{self.state.revision_count} 요청됨")
print(f"피드백: {result.feedback}")
# 실제 Flow에서는 generate_draft로 돌아갈 수 있습니다
# 이 예제에서는 단순히 확인합니다
return "revision_requested"
# Flow 실행
flow = ContentApprovalFlow()
result = flow.kickoff()
print(f"\nFlow 완료. 요청된 수정: {flow.state.revision_count}")
print(f"\nFlow 완료. 상태: {flow.state.status}, 검토 횟수: {flow.state.revision_count}")
```
```text Output
어떤 주제에 대해 글을 쓸까요? AI 안전
==================================================
OUTPUT FOR REVIEW:
==================================================
# AI 안전
AI 안전에 대한 초안... (v1)
==================================================
이 초안을 검토해 주세요. 승인, 거부 또는 변경이 필요한 사항을 설명해 주세요:
(Press Enter to skip, or type your feedback)
Your feedback: 더 자세한 내용이 필요합니다
==================================================
OUTPUT FOR REVIEW:
==================================================
# AI 안전
AI 안전에 대한 초안입니다...
AI 안전에 대한 초안... (v2)
==================================================
이 초안을 검토해 주세요. 'approved', 'rejected'로 답하거나 수정 피드백을 제공해 주세요:
이 초안을 검토해 주세요. 승인, 거부 또는 변경이 필요한 사항을 설명해 주세요:
(Press Enter to skip, or type your feedback)
Your feedback: 좋아 보입니다, 승인!
콘텐츠 승인되어 출판되었습니다!
검토자 코멘트: 좋아 보입니다, 승인!
콘텐츠 승인 및 게시! 리뷰어 의견: 좋아 보입니다, 승인!
Flow 완료. 요청된 수정: 0
Flow 완료. 상태: published, 검토 횟수: 2
```
</CodeGroup>
## 다른 데코레이터와 결합하기
`@human_feedback` 데코레이터는 다른 Flow 데코레이터와 함께 작동합니다. 가장 안쪽 데코레이터(함수에 가장 가까운)로 배치하세요:
`@human_feedback` 데코레이터는 `@start()`, `@listen()`, `or_()`와 함께 작동합니다. 데코레이터 순서는 두 가지 모두 동작합니다—프레임워크가 양방향으로 속성을 전파합니다—하지만 권장 패턴은 다음과 같습니다:
```python Code
# 올바름: @human_feedback이 가장 안쪽(함수에 가장 가까움)
# Flow 시작 시 일회성 검토 (self-loop 없음)
@start()
@human_feedback(message="이것을 검토해 주세요:")
@human_feedback(message="이것을 검토해 주세요:", emit=["approved", "rejected"], llm="gpt-4o-mini")
def my_start_method(self):
return "content"
# 리스너에서 선형 검토 (self-loop 없음)
@listen(other_method)
@human_feedback(message="이것도 검토해 주세요:")
@human_feedback(message="이것도 검토해 주세요:", emit=["good", "bad"], llm="gpt-4o-mini")
def my_listener(self, data):
return f"processed: {data}"
# Self-loop: 수정을 위해 반복할 수 있는 검토
@human_feedback(message="승인 또는 수정 요청?", emit=["approved", "revise"], llm="gpt-4o-mini")
@listen(or_("upstream_method", "revise"))
def review_with_loop(self):
return "content for review"
```
<Tip>
`@human_feedback`를 가장 안쪽 데코레이터(마지막/함수에 가장 가까움)로 배치하여 메서드를 직접 래핑하고 Flow 시스템에 전달하기 전에 반환 값을 캡처할 수 있도록 하세요.
</Tip>
### Self-loop 패턴
수정 루프를 만들려면 `or_()`를 사용하여 검토 메서드가 **상위 트리거**와 **자체 수정 outcome**을 모두 리스닝해야 합니다:
```python Code
@start()
def generate(self):
return "initial draft"
@human_feedback(
message="승인하시겠습니까, 아니면 변경을 요청하시겠습니까?",
emit=["revise", "approved"],
llm="gpt-4o-mini",
default_outcome="approved",
)
@listen(or_("generate", "revise"))
def review(self):
return "content"
@listen("approved")
def publish(self):
return "published"
```
outcome이 `"revise"`이면 flow가 `review`로 다시 라우팅됩니다 (`or_()`를 통해 `"revise"`를 리스닝하기 때문). outcome이 `"approved"`이면 flow가 `publish`로 계속됩니다. flow 엔진이 라우터를 "한 번만 실행" 규칙에서 제외하여 각 루프 반복마다 재실행할 수 있기 때문에 이 패턴이 동작합니다.
### 체인된 라우터
한 라우터의 outcome으로 트리거된 리스너가 그 자체로 라우터가 될 수 있습니다:
```python Code
@start()
@human_feedback(message="첫 번째 검토:", emit=["approved", "rejected"], llm="gpt-4o-mini")
def draft(self):
return "draft content"
@listen("approved")
@human_feedback(message="최종 검토:", emit=["publish", "revise"], llm="gpt-4o-mini")
def final_review(self, prev):
return "final content"
@listen("publish")
def on_publish(self, prev):
return "published"
```
### 제한 사항
- **`@start()` 메서드는 한 번만 실행**: `@start()` 메서드는 self-loop할 수 없습니다. 수정 주기가 필요하면 별도의 `@start()` 메서드를 진입점으로 사용하고 `@listen()` 메서드에 `@human_feedback`를 배치하세요.
- **동일 메서드에 `@start()` + `@listen()` 불가**: 이는 Flow 프레임워크 제약입니다. 메서드는 시작점이거나 리스너여야 하며, 둘 다일 수 없습니다.
## 모범 사례
@@ -516,9 +572,9 @@ class ContentPipeline(Flow):
@start()
@human_feedback(
message="이 콘텐츠의 출판을 승인하시겠습니까?",
emit=["approved", "rejected", "needs_revision"],
emit=["approved", "rejected"],
llm="gpt-4o-mini",
default_outcome="needs_revision",
default_outcome="rejected",
provider=SlackNotificationProvider("#content-reviews"),
)
def generate_content(self):
@@ -534,11 +590,6 @@ class ContentPipeline(Flow):
print(f"보관됨. 이유: {result.feedback}")
return {"status": "archived"}
@listen("needs_revision")
def queue_revision(self, result):
print(f"수정 대기열에 추가됨: {result.feedback}")
return {"status": "revision_needed"}
# Flow 시작 (Slack 응답을 기다리며 일시 중지)
def start_content_pipeline():
@@ -594,22 +645,22 @@ async def on_slack_feedback_async(flow_id: str, slack_message: str):
```python Code
class ArticleReviewFlow(Flow):
@start()
@human_feedback(
message="Review this article draft:",
emit=["approved", "needs_revision"],
llm="gpt-4o-mini",
learn=True, # HITL 학습 활성화
)
def generate_article(self):
return self.crew.kickoff(inputs={"topic": "AI Safety"}).raw
@human_feedback(
message="이 글 초안을 검토해 주세요:",
emit=["approved", "needs_revision"],
llm="gpt-4o-mini",
learn=True,
)
@listen(or_("generate_article", "needs_revision"))
def review_article(self):
return self.last_human_feedback.output if self.last_human_feedback else "article draft"
@listen("approved")
def publish(self):
print(f"Publishing: {self.last_human_feedback.output}")
@listen("needs_revision")
def revise(self):
print("Revising based on feedback...")
```
**첫 번째 실행**: 인간이 원시 출력을 보고 "사실에 대한 주장에는 항상 인용을 포함하세요."라고 말합니다. 교훈이 추출되어 메모리에 저장됩니다.

View File

@@ -7,7 +7,7 @@ mode: "wide"
## CrewAI를 LLM에 연결하기
CrewAI는 LiteLLM을 사용하여 다양한 언어 모델(LLM)에 연결합니다. 이 통합은 높은 다양성을 제공하여, 여러 공급자의 모델을 간단하고 통합된 인터페이스로 사용할 수 있게 해줍니다.
CrewAI는 가장 인기 있는 제공자(OpenAI, Anthropic, Google Gemini, Azure, AWS Bedrock)에 대해 네이티브 SDK 통합을 통해 LLM에 연결하며, 그 외 모든 제공자에 대해서는 LiteLLM을 유연한 폴백으로 사용합니다.
<Note>
기본적으로 CrewAI는 `gpt-4o-mini` 모델을 사용합니다. 이는 `OPENAI_MODEL_NAME` 환경 변수에 의해 결정되며, 설정되지 않은 경우 기본값은 "gpt-4o-mini"입니다.
@@ -41,6 +41,14 @@ LiteLLM은 다음을 포함하되 이에 국한되지 않는 다양한 프로바
지원되는 프로바이더의 전체 및 최신 목록은 [LiteLLM 프로바이더 문서](https://docs.litellm.ai/docs/providers)를 참조하세요.
<Info>
네이티브 통합에서 지원하지 않는 제공자를 사용하려면 LiteLLM을 프로젝트에 의존성으로 추가하세요:
```bash
uv add 'crewai[litellm]'
```
네이티브 제공자(OpenAI, Anthropic, Google Gemini, Azure, AWS Bedrock)는 자체 SDK extras를 사용합니다 — [공급자 구성 예시](/ko/concepts/llms#공급자-구성-예시)를 참조하세요.
</Info>
## LLM 변경하기
CrewAI agent에서 다른 LLM을 사용하려면 여러 가지 방법이 있습니다:

View File

@@ -35,7 +35,7 @@ crewai login
아직 설치하지 않았다면 CLI 도구와 함께 CrewAI를 설치하세요:
```bash
uv add crewai[tools]
uv add 'crewai[tools]'
```
그런 다음 CrewAI AMP 계정으로 CLI를 인증하세요:

View File

@@ -4,6 +4,56 @@ description: "Atualizações de produto, melhorias e correções do CrewAI"
icon: "clock"
mode: "wide"
---
<Update label="26 fev 2026">
## v1.10.0
[Ver release no GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.10.0)
## O que Mudou
### Recursos
- Aprimorar a resolução da ferramenta MCP e eventos relacionados
- Atualizar a versão do lancedb e adicionar pacotes lance-namespace
- Aprimorar a análise e validação de argumentos JSON no CrewAgentExecutor e BaseTool
- Migrar o cliente HTTP da CLI de requests para httpx
- Adicionar documentação versionada
- Adicionar detecção de versões removidas para notas de versão
- Implementar tratamento de entrada do usuário em Flows
- Aprimorar a funcionalidade de auto-loop HITL nos testes de integração de feedback humano
- Adicionar started_event_id e definir no eventbus
- Atualizar automaticamente tools.specs
### Correções de Bugs
- Validar kwargs da ferramenta mesmo quando vazios para evitar TypeError crípticos
- Preservar tipos nulos nos esquemas de parâmetros da ferramenta para LLM
- Mapear output_pydantic/output_json para saída estruturada nativa
- Garantir que callbacks sejam executados/aguardados se forem promessas
- Capturar o nome do método no contexto da exceção
- Preservar tipo enum no resultado do roteador; melhorar tipos
- Corrigir fluxos cíclicos que quebram silenciosamente quando o ID de persistência é passado nas entradas
- Corrigir o formato da flag da CLI de --skip-provider para --skip_provider
- Garantir que o fluxo de chamada da ferramenta OpenAI seja finalizado
- Resolver ponteiros $ref de esquema complexos nas ferramentas MCP
- Impor additionalProperties=false nos esquemas
- Rejeitar nomes de scripts reservados para pastas de equipe
- Resolver condição de corrida no teste de emissão de eventos de guardrail
### Documentação
- Adicionar nota de dependência litellm para provedores de LLM não nativos
- Esclarecer o modelo de segurança NL2SQL e orientações de fortalecimento
- Adicionar 96 ações ausentes em 9 integrações
### Refatoração
- Refatorar crew para provider
- Extrair HITL para padrão de provider
- Melhorar tipagem e registro de hooks
## Contribuidores
@dependabot[bot], @github-actions[bot], @github-code-quality[bot], @greysonlalonde, @heitorado, @hobostay, @joaomdmoura, @johnvan7, @jonathansampson, @lorenzejay, @lucasgomide, @mattatcha, @mplachta, @nicoferdi96, @theCyberTech, @thiagomoretto, @vinibrsl
</Update>
<Update label="26 jan 2026">
## v1.9.0

View File

@@ -105,6 +105,15 @@ Existem diferentes locais no código do CrewAI onde você pode especificar o mod
</Tab>
</Tabs>
<Info>
O CrewAI oferece integrações nativas via SDK para OpenAI, Anthropic, Google (Gemini API), Azure e AWS Bedrock — sem necessidade de instalação extra além dos extras específicos do provedor (ex.: `uv add "crewai[openai]"`).
Todos os outros provedores são alimentados pelo **LiteLLM**. Se você planeja usar algum deles, adicione-o como dependência ao seu projeto:
```bash
uv add 'crewai[litellm]'
```
</Info>
## Exemplos de Configuração de Provedores
O CrewAI suporta uma grande variedade de provedores de LLM, cada um com recursos, métodos de autenticação e capacidades de modelo únicos.
@@ -214,6 +223,11 @@ Nesta seção, você encontrará exemplos detalhados que ajudam a selecionar, co
| `meta_llama/Llama-4-Maverick-17B-128E-Instruct-FP8` | 128k | 4028 | Texto, Imagem | Texto |
| `meta_llama/Llama-3.3-70B-Instruct` | 128k | 4028 | Texto | Texto |
| `meta_llama/Llama-3.3-8B-Instruct` | 128k | 4028 | Texto | Texto |
**Nota:** Este provedor usa o LiteLLM. Adicione-o como dependência ao seu projeto:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Anthropic">
@@ -354,6 +368,11 @@ Nesta seção, você encontrará exemplos detalhados que ajudam a selecionar, co
| gemini-1.5-flash | 1M tokens | Modelo multimodal equilibrado, bom para maioria das tarefas |
| gemini-1.5-flash-8B | 1M tokens | Mais rápido, mais eficiente em custo, adequado para tarefas de alta frequência |
| gemini-1.5-pro | 2M tokens | Melhor desempenho para uma ampla variedade de tarefas de raciocínio, incluindo lógica, codificação e colaboração criativa |
**Nota:** Este provedor usa o LiteLLM. Adicione-o como dependência ao seu projeto:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Azure">
@@ -438,6 +457,11 @@ Nesta seção, você encontrará exemplos detalhados que ajudam a selecionar, co
model="sagemaker/<my-endpoint>"
)
```
**Nota:** Este provedor usa o LiteLLM. Adicione-o como dependência ao seu projeto:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Mistral">
@@ -453,6 +477,11 @@ Nesta seção, você encontrará exemplos detalhados que ajudam a selecionar, co
temperature=0.7
)
```
**Nota:** Este provedor usa o LiteLLM. Adicione-o como dependência ao seu projeto:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Nvidia NIM">
@@ -539,6 +568,11 @@ Nesta seção, você encontrará exemplos detalhados que ajudam a selecionar, co
| rakuten/rakutenai-7b-instruct | 1.024 tokens | LLM topo de linha, compreensão, raciocínio e geração textual.|
| rakuten/rakutenai-7b-chat | 1.024 tokens | LLM topo de linha, compreensão, raciocínio e geração textual.|
| baichuan-inc/baichuan2-13b-chat | 4.096 tokens | Suporte a chat em chinês/inglês, programação, matemática, seguir instruções, resolver quizzes.|
**Nota:** Este provedor usa o LiteLLM. Adicione-o como dependência ao seu projeto:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Local NVIDIA NIM Deployed using WSL2">
@@ -579,6 +613,11 @@ Nesta seção, você encontrará exemplos detalhados que ajudam a selecionar, co
# ...
```
**Nota:** Este provedor usa o LiteLLM. Adicione-o como dependência ao seu projeto:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Groq">
@@ -600,6 +639,11 @@ Nesta seção, você encontrará exemplos detalhados que ajudam a selecionar, co
| Llama 3.1 70B/8B | 131.072 tokens | Alta performance e tarefas de contexto grande|
| Llama 3.2 Série | 8.192 tokens | Tarefas gerais |
| Mixtral 8x7B | 32.768 tokens | Equilíbrio entre performance e contexto |
**Nota:** Este provedor usa o LiteLLM. Adicione-o como dependência ao seu projeto:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="IBM watsonx.ai">
@@ -622,6 +666,11 @@ Nesta seção, você encontrará exemplos detalhados que ajudam a selecionar, co
base_url="https://api.watsonx.ai/v1"
)
```
**Nota:** Este provedor usa o LiteLLM. Adicione-o como dependência ao seu projeto:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Ollama (LLMs Locais)">
@@ -635,6 +684,11 @@ Nesta seção, você encontrará exemplos detalhados que ajudam a selecionar, co
base_url="http://localhost:11434"
)
```
**Nota:** Este provedor usa o LiteLLM. Adicione-o como dependência ao seu projeto:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Fireworks AI">
@@ -650,6 +704,11 @@ Nesta seção, você encontrará exemplos detalhados que ajudam a selecionar, co
temperature=0.7
)
```
**Nota:** Este provedor usa o LiteLLM. Adicione-o como dependência ao seu projeto:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Perplexity AI">
@@ -665,6 +724,11 @@ Nesta seção, você encontrará exemplos detalhados que ajudam a selecionar, co
base_url="https://api.perplexity.ai/"
)
```
**Nota:** Este provedor usa o LiteLLM. Adicione-o como dependência ao seu projeto:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Hugging Face">
@@ -679,6 +743,11 @@ Nesta seção, você encontrará exemplos detalhados que ajudam a selecionar, co
model="huggingface/meta-llama/Meta-Llama-3.1-8B-Instruct"
)
```
**Nota:** Este provedor usa o LiteLLM. Adicione-o como dependência ao seu projeto:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="SambaNova">
@@ -702,6 +771,11 @@ Nesta seção, você encontrará exemplos detalhados que ajudam a selecionar, co
| Llama 3.2 Série | 8.192 tokens | Tarefas gerais e multimodais |
| Llama 3.3 70B | Até 131.072 tokens | Desempenho e qualidade de saída elevada |
| Família Qwen2 | 8.192 tokens | Desempenho e qualidade de saída elevada |
**Nota:** Este provedor usa o LiteLLM. Adicione-o como dependência ao seu projeto:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Cerebras">
@@ -727,6 +801,11 @@ Nesta seção, você encontrará exemplos detalhados que ajudam a selecionar, co
- Equilíbrio entre velocidade e qualidade
- Suporte a longas janelas de contexto
</Info>
**Nota:** Este provedor usa o LiteLLM. Adicione-o como dependência ao seu projeto:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
<Accordion title="Open Router">
@@ -749,6 +828,11 @@ Nesta seção, você encontrará exemplos detalhados que ajudam a selecionar, co
- openrouter/deepseek/deepseek-r1
- openrouter/deepseek/deepseek-chat
</Info>
**Nota:** Este provedor usa o LiteLLM. Adicione-o como dependência ao seu projeto:
```bash
uv add 'crewai[litellm]'
```
</Accordion>
</AccordionGroup>

View File

@@ -38,22 +38,21 @@ O CrewAI Enterprise oferece um sistema abrangente de gerenciamento Human-in-the-
Configure checkpoints de revisão humana em seus Flows usando o decorador `@human_feedback`. Quando a execução atinge um ponto de revisão, o sistema pausa, notifica o responsável via email e aguarda uma resposta.
```python
from crewai.flow.flow import Flow, start, listen
from crewai.flow.flow import Flow, start, listen, or_
from crewai.flow.human_feedback import human_feedback, HumanFeedbackResult
class ContentApprovalFlow(Flow):
@start()
def generate_content(self):
# IA gera conteúdo
return "Texto de marketing gerado para campanha Q1..."
@listen(generate_content)
@human_feedback(
message="Por favor, revise este conteúdo para conformidade com a marca:",
emit=["approved", "rejected", "needs_revision"],
)
def review_content(self, content):
return content
@listen(or_("generate_content", "needs_revision"))
def review_content(self):
return "Texto de marketing para revisão..."
@listen("approved")
def publish_content(self, result: HumanFeedbackResult):
@@ -62,10 +61,6 @@ class ContentApprovalFlow(Flow):
@listen("rejected")
def archive_content(self, result: HumanFeedbackResult):
print(f"Conteúdo rejeitado. Motivo: {result.feedback}")
@listen("needs_revision")
def revise_content(self, result: HumanFeedbackResult):
print(f"Revisão solicitada: {result.feedback}")
```
Para detalhes completos de implementação, consulte o guia [Feedback Humano em Flows](/pt-BR/learn/human-feedback-in-flows).

View File

@@ -176,6 +176,11 @@ Você precisa enviar seu crew para um repositório do GitHub. Caso ainda não te
![Definir Variáveis de Ambiente](/images/enterprise/set-env-variables.png)
</Frame>
<Info>
Usando pacotes Python privados? Você também precisará adicionar suas credenciais de registro aqui.
Consulte [Registros de Pacotes Privados](/pt-BR/enterprise/guides/private-package-registry) para as variáveis necessárias.
</Info>
</Step>
<Step title="Implante Seu Crew">

View File

@@ -256,6 +256,12 @@ Antes da implantação, certifique-se de ter:
1. **Chaves de API de LLM** prontas (OpenAI, Anthropic, Google, etc.)
2. **Chaves de API de ferramentas** se estiver usando ferramentas externas (Serper, etc.)
<Info>
Se seu projeto depende de pacotes de um **registro PyPI privado**, você também precisará configurar
credenciais de autenticação do registro como variáveis de ambiente. Consulte o guia
[Registros de Pacotes Privados](/pt-BR/enterprise/guides/private-package-registry) para mais detalhes.
</Info>
<Tip>
Teste seu projeto localmente com as mesmas variáveis de ambiente antes de implantar
para detectar problemas de configuração antecipadamente.

View File

@@ -0,0 +1,263 @@
---
title: "Registros de Pacotes Privados"
description: "Instale pacotes Python privados de registros PyPI autenticados no CrewAI AMP"
icon: "lock"
mode: "wide"
---
<Note>
Este guia aborda como configurar seu projeto CrewAI para instalar pacotes Python
de registros PyPI privados (Azure DevOps Artifacts, GitHub Packages, GitLab, AWS CodeArtifact, etc.)
ao implantar no CrewAI AMP.
</Note>
## Quando Você Precisa Disso
Se seu projeto depende de pacotes Python internos ou proprietários hospedados em um registro privado
em vez do PyPI público, você precisará:
1. Informar ao UV **onde** encontrar o pacote (uma URL de index)
2. Informar ao UV **quais** pacotes vêm desse index (um mapeamento de source)
3. Fornecer **credenciais** para que o UV possa autenticar durante a instalação
O CrewAI AMP usa [UV](https://docs.astral.sh/uv/) para resolução e instalação de dependências.
O UV suporta registros privados autenticados por meio da configuração do `pyproject.toml` combinada
com variáveis de ambiente para credenciais.
## Passo 1: Configurar o pyproject.toml
Três elementos trabalham juntos no seu `pyproject.toml`:
### 1a. Declarar a dependência
Adicione o pacote privado ao seu `[project.dependencies]` como qualquer outra dependência:
```toml
[project]
dependencies = [
"crewai[tools]>=0.100.1,<1.0.0",
"my-private-package>=1.2.0",
]
```
### 1b. Definir o index
Registre seu registro privado como um index nomeado em `[[tool.uv.index]]`:
```toml
[[tool.uv.index]]
name = "my-private-registry"
url = "https://pkgs.dev.azure.com/my-org/_packaging/my-feed/pypi/simple/"
explicit = true
```
<Info>
O campo `name` é importante — o UV o utiliza para construir os nomes das variáveis de ambiente
para autenticação (veja o [Passo 2](#passo-2-configurar-credenciais-de-autenticação) abaixo).
Definir `explicit = true` significa que o UV não consultará esse index para todos os pacotes — apenas
os que você mapear explicitamente em `[tool.uv.sources]`. Isso evita consultas desnecessárias
ao seu registro privado e protege contra ataques de confusão de dependências.
</Info>
### 1c. Mapear o pacote para o index
Informe ao UV quais pacotes devem ser resolvidos a partir do seu index privado usando `[tool.uv.sources]`:
```toml
[tool.uv.sources]
my-private-package = { index = "my-private-registry" }
```
### Exemplo completo
```toml
[project]
name = "my-crew-project"
version = "0.1.0"
requires-python = ">=3.10,<=3.13"
dependencies = [
"crewai[tools]>=0.100.1,<1.0.0",
"my-private-package>=1.2.0",
]
[tool.crewai]
type = "crew"
[[tool.uv.index]]
name = "my-private-registry"
url = "https://pkgs.dev.azure.com/my-org/_packaging/my-feed/pypi/simple/"
explicit = true
[tool.uv.sources]
my-private-package = { index = "my-private-registry" }
```
Após atualizar o `pyproject.toml`, regenere seu arquivo lock:
```bash
uv lock
```
<Warning>
Sempre faça commit do `uv.lock` atualizado junto com as alterações no `pyproject.toml`.
O arquivo lock é obrigatório para implantação — veja [Preparar para Implantação](/pt-BR/enterprise/guides/prepare-for-deployment).
</Warning>
## Passo 2: Configurar Credenciais de Autenticação
O UV autentica em indexes privados usando variáveis de ambiente que seguem uma convenção de nomenclatura
baseada no nome do index que você definiu no `pyproject.toml`:
```
UV_INDEX_{UPPER_NAME}_USERNAME
UV_INDEX_{UPPER_NAME}_PASSWORD
```
Onde `{UPPER_NAME}` é o nome do seu index convertido para **maiúsculas** com **hifens substituídos por underscores**.
Por exemplo, um index chamado `my-private-registry` usa:
| Variável | Valor |
|----------|-------|
| `UV_INDEX_MY_PRIVATE_REGISTRY_USERNAME` | Seu nome de usuário ou nome do token do registro |
| `UV_INDEX_MY_PRIVATE_REGISTRY_PASSWORD` | Sua senha ou token/PAT do registro |
<Warning>
Essas variáveis de ambiente **devem** ser adicionadas pelas configurações de **Variáveis de Ambiente** do CrewAI AMP —
globalmente ou no nível da implantação. Elas não podem ser definidas em arquivos `.env` ou codificadas no seu projeto.
Veja [Configurar Variáveis de Ambiente no AMP](#configurar-variáveis-de-ambiente-no-amp) abaixo.
</Warning>
## Referência de Provedores de Registro
A tabela abaixo mostra o formato da URL de index e os valores de credenciais para provedores de registro comuns.
Substitua os valores de exemplo pelos detalhes reais da sua organização e feed.
| Provedor | URL do Index | Usuário | Senha |
|----------|-------------|---------|-------|
| **Azure DevOps Artifacts** | `https://pkgs.dev.azure.com/{org}/_packaging/{feed}/pypi/simple/` | Qualquer string não vazia (ex: `token`) | Personal Access Token (PAT) com escopo Packaging Read |
| **GitHub Packages** | `https://pypi.pkg.github.com/{owner}/simple/` | Nome de usuário do GitHub | Personal Access Token (classic) com escopo `read:packages` |
| **GitLab Package Registry** | `https://gitlab.com/api/v4/projects/{project_id}/packages/pypi/simple/` | `__token__` | Project ou Personal Access Token com escopo `read_api` |
| **AWS CodeArtifact** | Use a URL de `aws codeartifact get-repository-endpoint` | `aws` | Token de `aws codeartifact get-authorization-token` |
| **Google Artifact Registry** | `https://{region}-python.pkg.dev/{project}/{repo}/simple/` | `_json_key_base64` | Chave de conta de serviço codificada em Base64 |
| **JFrog Artifactory** | `https://{instance}.jfrog.io/artifactory/api/pypi/{repo}/simple/` | Nome de usuário ou email | Chave API ou token de identidade |
| **Auto-hospedado (devpi, Nexus, etc.)** | URL da API simple do seu registro | Nome de usuário do registro | Senha do registro |
<Tip>
Para **AWS CodeArtifact**, o token de autorização expira periodicamente.
Você precisará atualizar o valor de `UV_INDEX_*_PASSWORD` quando ele expirar.
Considere automatizar isso no seu pipeline de CI/CD.
</Tip>
## Configurar Variáveis de Ambiente no AMP
As credenciais do registro privado devem ser configuradas como variáveis de ambiente no CrewAI AMP.
Você tem duas opções:
<Tabs>
<Tab title="Interface Web">
1. Faça login no [CrewAI AMP](https://app.crewai.com)
2. Navegue até sua automação
3. Abra a aba **Environment Variables**
4. Adicione cada variável (`UV_INDEX_*_USERNAME` e `UV_INDEX_*_PASSWORD`) com seu valor
Veja o passo [Deploy para AMP — Definir Variáveis de Ambiente](/pt-BR/enterprise/guides/deploy-to-amp#definir-as-variáveis-de-ambiente) para detalhes.
</Tab>
<Tab title="Implantação via CLI">
Adicione as variáveis ao seu arquivo `.env` local antes de executar `crewai deploy create`.
A CLI as transferirá com segurança para a plataforma:
```bash
# .env
OPENAI_API_KEY=sk-...
UV_INDEX_MY_PRIVATE_REGISTRY_USERNAME=token
UV_INDEX_MY_PRIVATE_REGISTRY_PASSWORD=your-pat-here
```
```bash
crewai deploy create
```
</Tab>
</Tabs>
<Warning>
**Nunca** faça commit de credenciais no seu repositório. Use variáveis de ambiente do AMP para todos os segredos.
O arquivo `.env` deve estar listado no `.gitignore`.
</Warning>
Para atualizar credenciais em uma implantação existente, veja [Atualizar Seu Crew — Variáveis de Ambiente](/pt-BR/enterprise/guides/update-crew).
## Como Tudo se Conecta
Quando o CrewAI AMP faz o build da sua automação, o fluxo de resolução funciona assim:
<Steps>
<Step title="Build inicia">
O AMP busca seu repositório e lê o `pyproject.toml` e o `uv.lock`.
</Step>
<Step title="UV resolve dependências">
O UV lê `[tool.uv.sources]` para determinar de qual index cada pacote deve vir.
</Step>
<Step title="UV autentica">
Para cada index privado, o UV busca `UV_INDEX_{NAME}_USERNAME` e `UV_INDEX_{NAME}_PASSWORD`
nas variáveis de ambiente que você configurou no AMP.
</Step>
<Step title="Pacotes são instalados">
O UV baixa e instala todos os pacotes — tanto públicos (do PyPI) quanto privados (do seu registro).
</Step>
<Step title="Automação executa">
Seu crew ou flow inicia com todas as dependências disponíveis.
</Step>
</Steps>
## Solução de Problemas
### Erros de Autenticação Durante o Build
**Sintoma**: Build falha com `401 Unauthorized` ou `403 Forbidden` ao resolver um pacote privado.
**Verifique**:
- Os nomes das variáveis de ambiente `UV_INDEX_*` correspondem exatamente ao nome do seu index (maiúsculas, hifens -> underscores)
- As credenciais estão definidas nas variáveis de ambiente do AMP, não apenas em um `.env` local
- Seu token/PAT tem as permissões de leitura necessárias para o feed de pacotes
- O token não expirou (especialmente relevante para AWS CodeArtifact)
### Pacote Não Encontrado
**Sintoma**: `No matching distribution found for my-private-package`.
**Verifique**:
- A URL do index no `pyproject.toml` termina com `/simple/`
- A entrada `[tool.uv.sources]` mapeia o nome correto do pacote para o nome correto do index
- O pacote está realmente publicado no seu registro privado
- Execute `uv lock` localmente com as mesmas credenciais para verificar se a resolução funciona
### Conflitos no Arquivo Lock
**Sintoma**: `uv lock` falha ou produz resultados inesperados após adicionar um index privado.
**Solução**: Defina as credenciais localmente e regenere:
```bash
export UV_INDEX_MY_PRIVATE_REGISTRY_USERNAME=token
export UV_INDEX_MY_PRIVATE_REGISTRY_PASSWORD=your-pat
uv lock
```
Em seguida, faça commit do `uv.lock` atualizado.
## Guias Relacionados
<CardGroup cols={3}>
<Card title="Preparar para Implantação" icon="clipboard-check" href="/pt-BR/enterprise/guides/prepare-for-deployment">
Verifique a estrutura do projeto e as dependências antes de implantar.
</Card>
<Card title="Deploy para AMP" icon="rocket" href="/pt-BR/enterprise/guides/deploy-to-amp">
Implante seu crew ou flow e configure variáveis de ambiente.
</Card>
<Card title="Atualizar Seu Crew" icon="arrows-rotate" href="/pt-BR/enterprise/guides/update-crew">
Atualize variáveis de ambiente e envie alterações para uma implantação em execução.
</Card>
</CardGroup>

View File

@@ -98,33 +98,43 @@ def handle_feedback(self, result):
Quando você especifica `emit`, o decorador se torna um roteador. O feedback livre do humano é interpretado por um LLM e mapeado para um dos outcomes especificados:
```python Code
@start()
@human_feedback(
message="Você aprova este conteúdo para publicação?",
emit=["approved", "rejected", "needs_revision"],
llm="gpt-4o-mini",
default_outcome="needs_revision",
)
def review_content(self):
return "Rascunho do post do blog aqui..."
from crewai.flow.flow import Flow, start, listen, or_
from crewai.flow.human_feedback import human_feedback
@listen("approved")
def publish(self, result):
print(f"Publicando! Usuário disse: {result.feedback}")
class ReviewFlow(Flow):
@start()
def generate_content(self):
return "Rascunho do post do blog aqui..."
@listen("rejected")
def discard(self, result):
print(f"Descartando. Motivo: {result.feedback}")
@human_feedback(
message="Você aprova este conteúdo para publicação?",
emit=["approved", "rejected", "needs_revision"],
llm="gpt-4o-mini",
default_outcome="needs_revision",
)
@listen(or_("generate_content", "needs_revision"))
def review_content(self):
return "Rascunho do post do blog aqui..."
@listen("needs_revision")
def revise(self, result):
print(f"Revisando baseado em: {result.feedback}")
@listen("approved")
def publish(self, result):
print(f"Publicando! Usuário disse: {result.feedback}")
@listen("rejected")
def discard(self, result):
print(f"Descartando. Motivo: {result.feedback}")
```
Quando o humano diz algo como "precisa de mais detalhes", o LLM mapeia para `"needs_revision"`, que dispara `review_content` novamente via `or_()` — criando um loop de revisão. O loop continua até que o outcome seja `"approved"` ou `"rejected"`.
<Tip>
O LLM usa saídas estruturadas (function calling) quando disponível para garantir que a resposta seja um dos seus outcomes especificados. Isso torna o roteamento confiável e previsível.
</Tip>
<Warning>
Um método `@start()` só executa uma vez no início do flow. Se você precisa de um loop de revisão, separe o método start do método de revisão e use `@listen(or_("trigger", "revision_outcome"))` no método de revisão para habilitar o self-loop.
</Warning>
## HumanFeedbackResult
O dataclass `HumanFeedbackResult` contém todas as informações sobre uma interação de feedback humano:
@@ -193,116 +203,162 @@ Aqui está um exemplo completo implementando um fluxo de revisão e aprovação
<CodeGroup>
```python Code
from crewai.flow.flow import Flow, start, listen
from crewai.flow.flow import Flow, start, listen, or_
from crewai.flow.human_feedback import human_feedback, HumanFeedbackResult
from pydantic import BaseModel
class ContentState(BaseModel):
topic: str = ""
draft: str = ""
final_content: str = ""
revision_count: int = 0
status: str = "pending"
class ContentApprovalFlow(Flow[ContentState]):
"""Um flow que gera conteúdo e obtém aprovação humana."""
"""Um flow que gera conteúdo e faz loop até o humano aprovar."""
@start()
def get_topic(self):
self.state.topic = input("Sobre qual tópico devo escrever? ")
return self.state.topic
@listen(get_topic)
def generate_draft(self, topic):
# Em uso real, isso chamaria um LLM
self.state.draft = f"# {topic}\n\nEste é um rascunho sobre {topic}..."
def generate_draft(self):
self.state.draft = "# IA Segura\n\nEste é um rascunho sobre IA Segura..."
return self.state.draft
@listen(generate_draft)
@human_feedback(
message="Por favor, revise este rascunho. Responda 'approved', 'rejected', ou forneça feedback de revisão:",
message="Por favor, revise este rascunho. Aprove, rejeite ou descreva o que precisa mudar:",
emit=["approved", "rejected", "needs_revision"],
llm="gpt-4o-mini",
default_outcome="needs_revision",
)
def review_draft(self, draft):
return draft
@listen(or_("generate_draft", "needs_revision"))
def review_draft(self):
self.state.revision_count += 1
return f"{self.state.draft} (v{self.state.revision_count})"
@listen("approved")
def publish_content(self, result: HumanFeedbackResult):
self.state.final_content = result.output
print("\n✅ Conteúdo aprovado e publicado!")
print(f"Comentário do revisor: {result.feedback}")
self.state.status = "published"
print(f"Conteúdo aprovado e publicado! Revisor disse: {result.feedback}")
return "published"
@listen("rejected")
def handle_rejection(self, result: HumanFeedbackResult):
print("\n❌ Conteúdo rejeitado")
print(f"Motivo: {result.feedback}")
self.state.status = "rejected"
print(f"Conteúdo rejeitado. Motivo: {result.feedback}")
return "rejected"
@listen("needs_revision")
def revise_content(self, result: HumanFeedbackResult):
self.state.revision_count += 1
print(f"\n📝 Revisão #{self.state.revision_count} solicitada")
print(f"Feedback: {result.feedback}")
# Em um flow real, você pode voltar para generate_draft
# Para este exemplo, apenas reconhecemos
return "revision_requested"
# Executar o flow
flow = ContentApprovalFlow()
result = flow.kickoff()
print(f"\nFlow concluído. Revisões solicitadas: {flow.state.revision_count}")
print(f"\nFlow finalizado. Status: {flow.state.status}, Revisões: {flow.state.revision_count}")
```
```text Output
Sobre qual tópico devo escrever? Segurança em IA
==================================================
OUTPUT FOR REVIEW:
==================================================
# IA Segura
Este é um rascunho sobre IA Segura... (v1)
==================================================
Por favor, revise este rascunho. Aprove, rejeite ou descreva o que precisa mudar:
(Press Enter to skip, or type your feedback)
Your feedback: Preciso de mais detalhes sobre segurança em IA.
==================================================
OUTPUT FOR REVIEW:
==================================================
# Segurança em IA
# IA Segura
Este é um rascunho sobre Segurança em IA...
Este é um rascunho sobre IA Segura... (v2)
==================================================
Por favor, revise este rascunho. Responda 'approved', 'rejected', ou forneça feedback de revisão:
Por favor, revise este rascunho. Aprove, rejeite ou descreva o que precisa mudar:
(Press Enter to skip, or type your feedback)
Your feedback: Parece bom, aprovado!
Conteúdo aprovado e publicado!
Comentário do revisor: Parece bom, aprovado!
Conteúdo aprovado e publicado! Revisor disse: Parece bom, aprovado!
Flow concluído. Revisões solicitadas: 0
Flow finalizado. Status: published, Revisões: 2
```
</CodeGroup>
## Combinando com Outros Decoradores
O decorador `@human_feedback` funciona com outros decoradores de flow. Coloque-o como o decorador mais interno (mais próximo da função):
O decorador `@human_feedback` funciona com `@start()`, `@listen()` e `or_()`. Ambas as ordens de decoradores funcionam — o framework propaga atributos em ambas as direções — mas os padrões recomendados são:
```python Code
# Correto: @human_feedback é o mais interno (mais próximo da função)
# Revisão única no início do flow (sem self-loop)
@start()
@human_feedback(message="Revise isto:")
@human_feedback(message="Revise isto:", emit=["approved", "rejected"], llm="gpt-4o-mini")
def my_start_method(self):
return "content"
# Revisão linear em um listener (sem self-loop)
@listen(other_method)
@human_feedback(message="Revise isto também:")
@human_feedback(message="Revise isto também:", emit=["good", "bad"], llm="gpt-4o-mini")
def my_listener(self, data):
return f"processed: {data}"
# Self-loop: revisão que pode voltar para revisões
@human_feedback(message="Aprovar ou revisar?", emit=["approved", "revise"], llm="gpt-4o-mini")
@listen(or_("upstream_method", "revise"))
def review_with_loop(self):
return "content for review"
```
<Tip>
Coloque `@human_feedback` como o decorador mais interno (último/mais próximo da função) para que ele envolva o método diretamente e possa capturar o valor de retorno antes de passar para o sistema de flow.
</Tip>
### Padrão de self-loop
Para criar um loop de revisão, o método de revisão deve escutar **ambos** um gatilho upstream e seu próprio outcome de revisão usando `or_()`:
```python Code
@start()
def generate(self):
return "initial draft"
@human_feedback(
message="Aprovar ou solicitar alterações?",
emit=["revise", "approved"],
llm="gpt-4o-mini",
default_outcome="approved",
)
@listen(or_("generate", "revise"))
def review(self):
return "content"
@listen("approved")
def publish(self):
return "published"
```
Quando o outcome é `"revise"`, o flow roteia de volta para `review` (porque ele escuta `"revise"` via `or_()`). Quando o outcome é `"approved"`, o flow continua para `publish`. Isso funciona porque o engine de flow isenta roteadores da regra "fire once", permitindo que eles re-executem em cada iteração do loop.
### Roteadores encadeados
Um listener disparado pelo outcome de um roteador pode ser ele mesmo um roteador:
```python Code
@start()
@human_feedback(message="Primeira revisão:", emit=["approved", "rejected"], llm="gpt-4o-mini")
def draft(self):
return "draft content"
@listen("approved")
@human_feedback(message="Revisão final:", emit=["publish", "revise"], llm="gpt-4o-mini")
def final_review(self, prev):
return "final content"
@listen("publish")
def on_publish(self, prev):
return "published"
```
### Limitações
- **Métodos `@start()` executam uma vez**: Um método `@start()` não pode fazer self-loop. Se você precisa de um ciclo de revisão, use um método `@start()` separado como ponto de entrada e coloque o `@human_feedback` em um método `@listen()`.
- **Sem `@start()` + `@listen()` no mesmo método**: Esta é uma restrição do framework de Flow. Um método é ou um ponto de início ou um listener, não ambos.
## Melhores Práticas
@@ -516,9 +572,9 @@ class ContentPipeline(Flow):
@start()
@human_feedback(
message="Aprova este conteúdo para publicação?",
emit=["approved", "rejected", "needs_revision"],
emit=["approved", "rejected"],
llm="gpt-4o-mini",
default_outcome="needs_revision",
default_outcome="rejected",
provider=SlackNotificationProvider("#content-reviews"),
)
def generate_content(self):
@@ -534,11 +590,6 @@ class ContentPipeline(Flow):
print(f"Arquivado. Motivo: {result.feedback}")
return {"status": "archived"}
@listen("needs_revision")
def queue_revision(self, result):
print(f"Na fila para revisão: {result.feedback}")
return {"status": "revision_needed"}
# Iniciando o flow (vai pausar e aguardar resposta do Slack)
def start_content_pipeline():
@@ -594,22 +645,22 @@ Com o tempo, o humano vê saídas pré-revisadas progressivamente melhores porqu
```python Code
class ArticleReviewFlow(Flow):
@start()
def generate_article(self):
return self.crew.kickoff(inputs={"topic": "AI Safety"}).raw
@human_feedback(
message="Review this article draft:",
message="Revise este rascunho do artigo:",
emit=["approved", "needs_revision"],
llm="gpt-4o-mini",
learn=True, # enable HITL learning
)
def generate_article(self):
return self.crew.kickoff(inputs={"topic": "AI Safety"}).raw
@listen(or_("generate_article", "needs_revision"))
def review_article(self):
return self.last_human_feedback.output if self.last_human_feedback else "article draft"
@listen("approved")
def publish(self):
print(f"Publishing: {self.last_human_feedback.output}")
@listen("needs_revision")
def revise(self):
print("Revising based on feedback...")
```
**Primeira execução**: O humano vê a saída bruta e diz "Sempre inclua citações para afirmações factuais." A lição é destilada e armazenada na memória.

View File

@@ -7,7 +7,7 @@ mode: "wide"
## Conecte o CrewAI a LLMs
O CrewAI utiliza o LiteLLM para conectar-se a uma grande variedade de Modelos de Linguagem (LLMs). Essa integração proporciona grande versatilidade, permitindo que você utilize modelos de inúmeros provedores por meio de uma interface simples e unificada.
O CrewAI conecta-se a LLMs por meio de integrações nativas via SDK para os provedores mais populares (OpenAI, Anthropic, Google Gemini, Azure e AWS Bedrock), e usa o LiteLLM como alternativa flexível para todos os demais provedores.
<Note>
Por padrão, o CrewAI usa o modelo `gpt-4o-mini`. Isso é determinado pela variável de ambiente `OPENAI_MODEL_NAME`, que tem como padrão "gpt-4o-mini" se não for definida.
@@ -40,6 +40,14 @@ O LiteLLM oferece suporte a uma ampla gama de provedores, incluindo, mas não se
Para uma lista completa e sempre atualizada dos provedores suportados, consulte a [documentação de Provedores do LiteLLM](https://docs.litellm.ai/docs/providers).
<Info>
Para usar qualquer provedor não coberto por uma integração nativa, adicione o LiteLLM como dependência ao seu projeto:
```bash
uv add 'crewai[litellm]'
```
Provedores nativos (OpenAI, Anthropic, Google Gemini, Azure, AWS Bedrock) usam seus próprios extras de SDK — consulte os [Exemplos de Configuração de Provedores](/pt-BR/concepts/llms#exemplos-de-configuração-de-provedores).
</Info>
## Alterando a LLM
Para utilizar uma LLM diferente com seus agentes CrewAI, você tem várias opções:

189
lib/crewai-a2a/README.md Normal file
View File

@@ -0,0 +1,189 @@
# crewai-a2a
Agent-to-Agent (A2A) protocol support for CrewAI. Enables agents to discover, authenticate, and communicate with remote A2A-compatible agents.
## Quick Links
[Homepage](https://www.crewai.com/) | [Documentation](https://docs.crewai.com/) | [Community](https://community.crewai.com/)
## Installation
```bash
uv pip install crewai[a2a]
# or
uv add 'crewai[a2a]'
```
## Usage
### Connecting to a Remote A2A Agent
```python
from crewai import Agent
from crewai_a2a import A2AClientConfig
agent = Agent(
role="Coordinator",
goal="Delegate research tasks",
a2a=[
A2AClientConfig(endpoint="https://research-agent.example.com"),
],
)
```
### Exposing an Agent as an A2A Server
```python
from crewai import Agent
from crewai_a2a import A2AServerConfig
agent = Agent(
role="Researcher",
goal="Answer research questions",
a2a_server=A2AServerConfig(
name="Research Agent",
description="Answers research questions using web search",
),
)
```
## Authentication
### Client Schemes
```python
from crewai_a2a.auth import (
BearerTokenAuth,
HTTPBasicAuth,
APIKeyAuth,
OAuth2ClientCredentials,
)
from crewai_a2a.config import A2AClientConfig
# Bearer token
A2AClientConfig(
endpoint="https://agent.example.com",
auth=BearerTokenAuth(token="my-token"),
)
# API key
A2AClientConfig(
endpoint="https://agent.example.com",
auth=APIKeyAuth(api_key="key", location="header", name="X-API-Key"),
)
# OAuth2 client credentials
A2AClientConfig(
endpoint="https://agent.example.com",
auth=OAuth2ClientCredentials(
token_url="https://auth.example.com/token",
client_id="id",
client_secret="secret",
),
)
```
### Server Schemes
```python
from crewai_a2a.auth import SimpleTokenAuth, OIDCAuth
from crewai_a2a.config import A2AServerConfig
# Simple token validation
A2AServerConfig(auth=SimpleTokenAuth(token="expected-token"))
# OpenID Connect
A2AServerConfig(
auth=OIDCAuth(
issuer="https://auth.example.com",
audience="my-agent",
),
)
```
## Update Mechanisms
Control how the client receives task updates from remote agents.
```python
from crewai_a2a.updates import PollingConfig, StreamingConfig, PushNotificationConfig
from crewai_a2a.config import A2AClientConfig
# Polling
A2AClientConfig(
endpoint="https://agent.example.com",
updates=PollingConfig(interval=2.0, timeout=60),
)
# Server-Sent Events streaming
A2AClientConfig(
endpoint="https://agent.example.com",
updates=StreamingConfig(),
)
# Webhook push notifications
A2AClientConfig(
endpoint="https://agent.example.com",
updates=PushNotificationConfig(
url="https://my-server.example.com/webhook",
timeout=300,
),
)
```
## Extensions
### Client Extensions
Client extensions inject tools, augment prompts, and process responses.
```python
from crewai_a2a.extensions import A2AExtension
class MyExtension(A2AExtension):
def inject_tools(self, agent):
...
def augment_prompt(self, base_prompt, conversation_state):
return f"{base_prompt}\n\nAdditional context from extension."
```
### Server Extensions
Server extensions add protocol-level capabilities to your A2A server.
```python
from crewai_a2a.extensions import ServerExtension
class MyServerExtension(ServerExtension):
uri = "urn:my-org:my-extension"
description = "Custom protocol extension"
async def on_request(self, context):
...
async def on_response(self, context, result):
...
```
## Transport
Three transport protocols are supported: JSON-RPC (default), gRPC, and HTTP+JSON.
```python
from crewai_a2a.config import ClientTransportConfig, GRPCClientConfig
from crewai_a2a.config import A2AClientConfig
A2AClientConfig(
endpoint="https://agent.example.com",
transport=ClientTransportConfig(
preferred="GRPC",
grpc=GRPCClientConfig(
max_send_message_length=4 * 1024 * 1024,
),
),
)
```

View File

@@ -0,0 +1,25 @@
[project]
name = "crewai-a2a"
dynamic = ["version"]
description = "A2A (Agent-to-Agent) protocol support for CrewAI"
readme = "README.md"
authors = [{ name = "Greyson LaLonde", email = "greyson@crewai.com" }]
requires-python = ">=3.10, <3.14"
dependencies = [
"crewai==1.10.1b1",
"a2a-sdk~=0.3.10",
"httpx-auth~=0.23.1",
"httpx-sse~=0.4.0",
"aiocache[redis,memcached]~=0.12.3",
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.version]
path = "src/crewai_a2a/__init__.py"
[tool.uv.sources]
crewai = { workspace = true }
crewai-files = { workspace = true }

View File

@@ -0,0 +1,12 @@
"""Agent-to-Agent (A2A) protocol communication module for CrewAI."""
__version__ = "1.10.1b1"
from crewai_a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
__all__ = [
"A2AClientConfig",
"A2AConfig",
"A2AServerConfig",
]

View File

@@ -0,0 +1,36 @@
"""A2A authentication schemas."""
from crewai_a2a.auth.client_schemes import (
APIKeyAuth,
AuthScheme,
BearerTokenAuth,
ClientAuthScheme,
HTTPBasicAuth,
HTTPDigestAuth,
OAuth2AuthorizationCode,
OAuth2ClientCredentials,
TLSConfig,
)
from crewai_a2a.auth.server_schemes import (
AuthenticatedUser,
OIDCAuth,
ServerAuthScheme,
SimpleTokenAuth,
)
__all__ = [
"APIKeyAuth",
"AuthScheme",
"AuthenticatedUser",
"BearerTokenAuth",
"ClientAuthScheme",
"HTTPBasicAuth",
"HTTPDigestAuth",
"OAuth2AuthorizationCode",
"OAuth2ClientCredentials",
"OIDCAuth",
"ServerAuthScheme",
"SimpleTokenAuth",
"TLSConfig",
]

View File

@@ -0,0 +1,550 @@
"""Authentication schemes for A2A protocol clients.
Supported authentication methods:
- Bearer tokens
- OAuth2 (Client Credentials, Authorization Code)
- API Keys (header, query, cookie)
- HTTP Basic authentication
- HTTP Digest authentication
- mTLS (mutual TLS) client certificate authentication
"""
from __future__ import annotations
from abc import ABC, abstractmethod
import asyncio
import base64
from collections.abc import Awaitable, Callable, MutableMapping
from pathlib import Path
import ssl
import time
from typing import TYPE_CHECKING, ClassVar, Literal
import urllib.parse
import httpx
from httpx import DigestAuth
from pydantic import BaseModel, ConfigDict, Field, FilePath, PrivateAttr
from typing_extensions import deprecated
if TYPE_CHECKING:
import grpc # type: ignore[import-untyped]
class TLSConfig(BaseModel):
"""TLS/mTLS configuration for secure client connections.
Supports mutual TLS (mTLS) where the client presents a certificate to the server,
and standard TLS with custom CA verification.
Attributes:
client_cert_path: Path to client certificate file (PEM format) for mTLS.
client_key_path: Path to client private key file (PEM format) for mTLS.
ca_cert_path: Path to CA certificate bundle for server verification.
verify: Whether to verify server certificates. Set False only for development.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
client_cert_path: FilePath | None = Field(
default=None,
description="Path to client certificate file (PEM format) for mTLS",
)
client_key_path: FilePath | None = Field(
default=None,
description="Path to client private key file (PEM format) for mTLS",
)
ca_cert_path: FilePath | None = Field(
default=None,
description="Path to CA certificate bundle for server verification",
)
verify: bool = Field(
default=True,
description="Whether to verify server certificates. Set False only for development.",
)
def get_httpx_ssl_context(self) -> ssl.SSLContext | bool | str:
"""Build SSL context for httpx client.
Returns:
SSL context if certificates configured, True for default verification,
False if verification disabled, or path to CA bundle.
"""
if not self.verify:
return False
if self.client_cert_path and self.client_key_path:
context = ssl.create_default_context()
if self.ca_cert_path:
context.load_verify_locations(cafile=str(self.ca_cert_path))
context.load_cert_chain(
certfile=str(self.client_cert_path),
keyfile=str(self.client_key_path),
)
return context
if self.ca_cert_path:
return str(self.ca_cert_path)
return True
def get_grpc_credentials(self) -> grpc.ChannelCredentials | None: # type: ignore[no-any-unimported]
"""Build gRPC channel credentials for secure connections.
Returns:
gRPC SSL credentials if certificates configured, None otherwise.
"""
try:
import grpc
except ImportError:
return None
if not self.verify and not self.client_cert_path:
return None
root_certs: bytes | None = None
private_key: bytes | None = None
certificate_chain: bytes | None = None
if self.ca_cert_path:
root_certs = Path(self.ca_cert_path).read_bytes()
if self.client_cert_path and self.client_key_path:
private_key = Path(self.client_key_path).read_bytes()
certificate_chain = Path(self.client_cert_path).read_bytes()
return grpc.ssl_channel_credentials(
root_certificates=root_certs,
private_key=private_key,
certificate_chain=certificate_chain,
)
class ClientAuthScheme(ABC, BaseModel):
"""Base class for client-side authentication schemes.
Client auth schemes apply credentials to outgoing requests.
Attributes:
tls: Optional TLS/mTLS configuration for secure connections.
"""
tls: TLSConfig | None = Field(
default=None,
description="TLS/mTLS configuration for secure connections",
)
@abstractmethod
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Apply authentication to request headers.
Args:
client: HTTP client for making auth requests.
headers: Current request headers.
Returns:
Updated headers with authentication applied.
"""
...
@deprecated("Use ClientAuthScheme instead", category=FutureWarning)
class AuthScheme(ClientAuthScheme):
"""Deprecated: Use ClientAuthScheme instead."""
class BearerTokenAuth(ClientAuthScheme):
"""Bearer token authentication (Authorization: Bearer <token>).
Attributes:
token: Bearer token for authentication.
"""
token: str = Field(description="Bearer token")
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Apply Bearer token to Authorization header.
Args:
client: HTTP client for making auth requests.
headers: Current request headers.
Returns:
Updated headers with Bearer token in Authorization header.
"""
headers["Authorization"] = f"Bearer {self.token}"
return headers
class HTTPBasicAuth(ClientAuthScheme):
"""HTTP Basic authentication.
Attributes:
username: Username for Basic authentication.
password: Password for Basic authentication.
"""
username: str = Field(description="Username")
password: str = Field(description="Password")
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Apply HTTP Basic authentication.
Args:
client: HTTP client for making auth requests.
headers: Current request headers.
Returns:
Updated headers with Basic auth in Authorization header.
"""
credentials = f"{self.username}:{self.password}"
encoded = base64.b64encode(credentials.encode()).decode()
headers["Authorization"] = f"Basic {encoded}"
return headers
class HTTPDigestAuth(ClientAuthScheme):
"""HTTP Digest authentication.
Note: Uses httpx-auth library for digest implementation.
Attributes:
username: Username for Digest authentication.
password: Password for Digest authentication.
"""
username: str = Field(description="Username")
password: str = Field(description="Password")
_configured_client_id: int | None = PrivateAttr(default=None)
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Digest auth is handled by httpx auth flow, not headers.
Args:
client: HTTP client for making auth requests.
headers: Current request headers.
Returns:
Unchanged headers (Digest auth handled by httpx auth flow).
"""
return headers
def configure_client(self, client: httpx.AsyncClient) -> None:
"""Configure client with Digest auth.
Idempotent: Only configures the client once. Subsequent calls on the same
client instance are no-ops to prevent overwriting auth configuration.
Args:
client: HTTP client to configure with Digest authentication.
"""
client_id = id(client)
if self._configured_client_id == client_id:
return
client.auth = DigestAuth(self.username, self.password)
self._configured_client_id = client_id
class APIKeyAuth(ClientAuthScheme):
"""API Key authentication (header, query, or cookie).
Attributes:
api_key: API key value for authentication.
location: Where to send the API key (header, query, or cookie).
name: Parameter name for the API key (default: X-API-Key).
"""
api_key: str = Field(description="API key value")
location: Literal["header", "query", "cookie"] = Field(
default="header", description="Where to send the API key"
)
name: str = Field(default="X-API-Key", description="Parameter name for the API key")
_configured_client_ids: set[int] = PrivateAttr(default_factory=set)
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Apply API key authentication.
Args:
client: HTTP client for making auth requests.
headers: Current request headers.
Returns:
Updated headers with API key (for header/cookie locations).
"""
if self.location == "header":
headers[self.name] = self.api_key
elif self.location == "cookie":
headers["Cookie"] = f"{self.name}={self.api_key}"
return headers
def configure_client(self, client: httpx.AsyncClient) -> None:
"""Configure client for query param API keys.
Idempotent: Only adds the request hook once per client instance.
Subsequent calls on the same client are no-ops to prevent hook accumulation.
Args:
client: HTTP client to configure with query param API key hook.
"""
if self.location == "query":
client_id = id(client)
if client_id in self._configured_client_ids:
return
async def _add_api_key_param(request: httpx.Request) -> None:
url = httpx.URL(request.url)
request.url = url.copy_add_param(self.name, self.api_key)
client.event_hooks["request"].append(_add_api_key_param)
self._configured_client_ids.add(client_id)
class OAuth2ClientCredentials(ClientAuthScheme):
"""OAuth2 Client Credentials flow authentication.
Thread-safe implementation with asyncio.Lock to prevent concurrent token fetches
when multiple requests share the same auth instance.
Attributes:
token_url: OAuth2 token endpoint URL.
client_id: OAuth2 client identifier.
client_secret: OAuth2 client secret.
scopes: List of required OAuth2 scopes.
"""
token_url: str = Field(description="OAuth2 token endpoint")
client_id: str = Field(description="OAuth2 client ID")
client_secret: str = Field(description="OAuth2 client secret")
scopes: list[str] = Field(
default_factory=list, description="Required OAuth2 scopes"
)
_access_token: str | None = PrivateAttr(default=None)
_token_expires_at: float | None = PrivateAttr(default=None)
_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Apply OAuth2 access token to Authorization header.
Uses asyncio.Lock to ensure only one coroutine fetches tokens at a time,
preventing race conditions when multiple concurrent requests use the same
auth instance.
Args:
client: HTTP client for making token requests.
headers: Current request headers.
Returns:
Updated headers with OAuth2 access token in Authorization header.
"""
if (
self._access_token is None
or self._token_expires_at is None
or time.time() >= self._token_expires_at
):
async with self._lock:
if (
self._access_token is None
or self._token_expires_at is None
or time.time() >= self._token_expires_at
):
await self._fetch_token(client)
if self._access_token:
headers["Authorization"] = f"Bearer {self._access_token}"
return headers
async def _fetch_token(self, client: httpx.AsyncClient) -> None:
"""Fetch OAuth2 access token using client credentials flow.
Args:
client: HTTP client for making token request.
Raises:
httpx.HTTPStatusError: If token request fails.
"""
data = {
"grant_type": "client_credentials",
"client_id": self.client_id,
"client_secret": self.client_secret,
}
if self.scopes:
data["scope"] = " ".join(self.scopes)
response = await client.post(self.token_url, data=data)
response.raise_for_status()
token_data = response.json()
self._access_token = token_data["access_token"]
expires_in = token_data.get("expires_in", 3600)
self._token_expires_at = time.time() + expires_in - 60
class OAuth2AuthorizationCode(ClientAuthScheme):
"""OAuth2 Authorization Code flow authentication.
Thread-safe implementation with asyncio.Lock to prevent concurrent token operations.
Note: Requires interactive authorization.
Attributes:
authorization_url: OAuth2 authorization endpoint URL.
token_url: OAuth2 token endpoint URL.
client_id: OAuth2 client identifier.
client_secret: OAuth2 client secret.
redirect_uri: OAuth2 redirect URI for callback.
scopes: List of required OAuth2 scopes.
"""
authorization_url: str = Field(description="OAuth2 authorization endpoint")
token_url: str = Field(description="OAuth2 token endpoint")
client_id: str = Field(description="OAuth2 client ID")
client_secret: str = Field(description="OAuth2 client secret")
redirect_uri: str = Field(description="OAuth2 redirect URI")
scopes: list[str] = Field(
default_factory=list, description="Required OAuth2 scopes"
)
_access_token: str | None = PrivateAttr(default=None)
_refresh_token: str | None = PrivateAttr(default=None)
_token_expires_at: float | None = PrivateAttr(default=None)
_authorization_callback: Callable[[str], Awaitable[str]] | None = PrivateAttr(
default=None
)
_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
def set_authorization_callback(
self, callback: Callable[[str], Awaitable[str]] | None
) -> None:
"""Set callback to handle authorization URL.
Args:
callback: Async function that receives authorization URL and returns auth code.
"""
self._authorization_callback = callback
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Apply OAuth2 access token to Authorization header.
Uses asyncio.Lock to ensure only one coroutine handles token operations
(initial fetch or refresh) at a time.
Args:
client: HTTP client for making token requests.
headers: Current request headers.
Returns:
Updated headers with OAuth2 access token in Authorization header.
Raises:
ValueError: If authorization callback is not set.
"""
if self._access_token is None:
if self._authorization_callback is None:
msg = "Authorization callback not set. Use set_authorization_callback()"
raise ValueError(msg)
async with self._lock:
if self._access_token is None:
await self._fetch_initial_token(client)
elif self._token_expires_at and time.time() >= self._token_expires_at:
async with self._lock:
if self._token_expires_at and time.time() >= self._token_expires_at:
await self._refresh_access_token(client)
if self._access_token:
headers["Authorization"] = f"Bearer {self._access_token}"
return headers
async def _fetch_initial_token(self, client: httpx.AsyncClient) -> None:
"""Fetch initial access token using authorization code flow.
Args:
client: HTTP client for making token request.
Raises:
ValueError: If authorization callback is not set.
httpx.HTTPStatusError: If token request fails.
"""
params = {
"response_type": "code",
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"scope": " ".join(self.scopes),
}
auth_url = f"{self.authorization_url}?{urllib.parse.urlencode(params)}"
if self._authorization_callback is None:
msg = "Authorization callback not set"
raise ValueError(msg)
auth_code = await self._authorization_callback(auth_url)
data = {
"grant_type": "authorization_code",
"code": auth_code,
"client_id": self.client_id,
"client_secret": self.client_secret,
"redirect_uri": self.redirect_uri,
}
response = await client.post(self.token_url, data=data)
response.raise_for_status()
token_data = response.json()
self._access_token = token_data["access_token"]
self._refresh_token = token_data.get("refresh_token")
expires_in = token_data.get("expires_in", 3600)
self._token_expires_at = time.time() + expires_in - 60
async def _refresh_access_token(self, client: httpx.AsyncClient) -> None:
"""Refresh the access token using refresh token.
Args:
client: HTTP client for making token request.
Raises:
httpx.HTTPStatusError: If token refresh request fails.
"""
if not self._refresh_token:
await self._fetch_initial_token(client)
return
data = {
"grant_type": "refresh_token",
"refresh_token": self._refresh_token,
"client_id": self.client_id,
"client_secret": self.client_secret,
}
response = await client.post(self.token_url, data=data)
response.raise_for_status()
token_data = response.json()
self._access_token = token_data["access_token"]
if "refresh_token" in token_data:
self._refresh_token = token_data["refresh_token"]
expires_in = token_data.get("expires_in", 3600)
self._token_expires_at = time.time() + expires_in - 60

View File

@@ -0,0 +1,71 @@
"""Deprecated: Authentication schemes for A2A protocol agents.
This module is deprecated. Import from crewai_a2a.auth instead:
- crewai_a2a.auth.ClientAuthScheme (replaces AuthScheme)
- crewai_a2a.auth.BearerTokenAuth
- crewai_a2a.auth.HTTPBasicAuth
- crewai_a2a.auth.HTTPDigestAuth
- crewai_a2a.auth.APIKeyAuth
- crewai_a2a.auth.OAuth2ClientCredentials
- crewai_a2a.auth.OAuth2AuthorizationCode
"""
from __future__ import annotations
from typing_extensions import deprecated
from crewai_a2a.auth.client_schemes import (
APIKeyAuth as _APIKeyAuth,
BearerTokenAuth as _BearerTokenAuth,
ClientAuthScheme as _ClientAuthScheme,
HTTPBasicAuth as _HTTPBasicAuth,
HTTPDigestAuth as _HTTPDigestAuth,
OAuth2AuthorizationCode as _OAuth2AuthorizationCode,
OAuth2ClientCredentials as _OAuth2ClientCredentials,
)
@deprecated("Use ClientAuthScheme from crewai_a2a.auth instead", category=FutureWarning)
class AuthScheme(_ClientAuthScheme):
"""Deprecated: Use ClientAuthScheme from crewai_a2a.auth instead."""
@deprecated("Import from crewai_a2a.auth instead", category=FutureWarning)
class BearerTokenAuth(_BearerTokenAuth):
"""Deprecated: Import from crewai_a2a.auth instead."""
@deprecated("Import from crewai_a2a.auth instead", category=FutureWarning)
class HTTPBasicAuth(_HTTPBasicAuth):
"""Deprecated: Import from crewai_a2a.auth instead."""
@deprecated("Import from crewai_a2a.auth instead", category=FutureWarning)
class HTTPDigestAuth(_HTTPDigestAuth):
"""Deprecated: Import from crewai_a2a.auth instead."""
@deprecated("Import from crewai_a2a.auth instead", category=FutureWarning)
class APIKeyAuth(_APIKeyAuth):
"""Deprecated: Import from crewai_a2a.auth instead."""
@deprecated("Import from crewai_a2a.auth instead", category=FutureWarning)
class OAuth2ClientCredentials(_OAuth2ClientCredentials):
"""Deprecated: Import from crewai_a2a.auth instead."""
@deprecated("Import from crewai_a2a.auth instead", category=FutureWarning)
class OAuth2AuthorizationCode(_OAuth2AuthorizationCode):
"""Deprecated: Import from crewai_a2a.auth instead."""
__all__ = [
"APIKeyAuth",
"AuthScheme",
"BearerTokenAuth",
"HTTPBasicAuth",
"HTTPDigestAuth",
"OAuth2AuthorizationCode",
"OAuth2ClientCredentials",
]

View File

@@ -0,0 +1,742 @@
"""Server-side authentication schemes for A2A protocol.
These schemes validate incoming requests to A2A server endpoints.
Supported authentication methods:
- Simple token validation with static bearer tokens
- OpenID Connect with JWT validation using JWKS
- OAuth2 with JWT validation or token introspection
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
import logging
import os
from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal
import jwt
from jwt import PyJWKClient
from pydantic import (
BaseModel,
BeforeValidator,
ConfigDict,
Field,
HttpUrl,
PrivateAttr,
SecretStr,
model_validator,
)
from typing_extensions import Self
if TYPE_CHECKING:
from a2a.types import OAuth2SecurityScheme
logger = logging.getLogger(__name__)
try:
from fastapi import ( # type: ignore[import-not-found]
HTTPException,
status as http_status,
)
HTTP_401_UNAUTHORIZED = http_status.HTTP_401_UNAUTHORIZED
HTTP_500_INTERNAL_SERVER_ERROR = http_status.HTTP_500_INTERNAL_SERVER_ERROR
HTTP_503_SERVICE_UNAVAILABLE = http_status.HTTP_503_SERVICE_UNAVAILABLE
except ImportError:
class HTTPException(Exception): # type: ignore[no-redef] # noqa: N818
"""Fallback HTTPException when FastAPI is not installed."""
def __init__(
self,
status_code: int,
detail: str | None = None,
headers: dict[str, str] | None = None,
) -> None:
self.status_code = status_code
self.detail = detail
self.headers = headers
super().__init__(detail)
HTTP_401_UNAUTHORIZED = 401
HTTP_500_INTERNAL_SERVER_ERROR = 500
HTTP_503_SERVICE_UNAVAILABLE = 503
def _coerce_secret_str(v: str | SecretStr | None) -> SecretStr | None:
"""Coerce string to SecretStr."""
if v is None or isinstance(v, SecretStr):
return v
return SecretStr(v)
CoercedSecretStr = Annotated[SecretStr, BeforeValidator(_coerce_secret_str)]
JWTAlgorithm = Literal[
"RS256",
"RS384",
"RS512",
"ES256",
"ES384",
"ES512",
"PS256",
"PS384",
"PS512",
]
@dataclass
class AuthenticatedUser:
"""Result of successful authentication.
Attributes:
token: The original token that was validated.
scheme: Name of the authentication scheme used.
claims: JWT claims from OIDC or OAuth2 authentication.
"""
token: str
scheme: str
claims: dict[str, Any] | None = None
class ServerAuthScheme(ABC, BaseModel):
"""Base class for server-side authentication schemes.
Each scheme validates incoming requests and returns an AuthenticatedUser
on success, or raises HTTPException on failure.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
@abstractmethod
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Authenticate the provided token.
Args:
token: The bearer token to authenticate.
Returns:
AuthenticatedUser on successful authentication.
Raises:
HTTPException: If authentication fails.
"""
...
class SimpleTokenAuth(ServerAuthScheme):
"""Simple bearer token authentication.
Validates tokens against a configured static token or AUTH_TOKEN env var.
Attributes:
token: Expected token value. Falls back to AUTH_TOKEN env var if not set.
"""
token: CoercedSecretStr | None = Field(
default=None,
description="Expected token. Falls back to AUTH_TOKEN env var.",
)
def _get_expected_token(self) -> str | None:
"""Get the expected token value."""
if self.token:
return self.token.get_secret_value()
return os.environ.get("AUTH_TOKEN")
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Authenticate using simple token comparison.
Args:
token: The bearer token to authenticate.
Returns:
AuthenticatedUser on successful authentication.
Raises:
HTTPException: If authentication fails.
"""
expected = self._get_expected_token()
if expected is None:
logger.warning(
"Simple token authentication failed",
extra={"reason": "no_token_configured"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Authentication not configured",
)
if token != expected:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid or missing authentication credentials",
)
return AuthenticatedUser(
token=token,
scheme="simple_token",
)
class OIDCAuth(ServerAuthScheme):
"""OpenID Connect authentication.
Validates JWTs using JWKS with caching support via PyJWT.
Attributes:
issuer: The OpenID Connect issuer URL.
audience: The expected audience claim.
jwks_url: Optional explicit JWKS URL. Derived from issuer if not set.
algorithms: List of allowed signing algorithms.
required_claims: List of claims that must be present in the token.
jwks_cache_ttl: TTL for JWKS cache in seconds.
clock_skew_seconds: Allowed clock skew for token validation.
"""
issuer: HttpUrl = Field(
description="OpenID Connect issuer URL (e.g., https://auth.example.com)"
)
audience: str = Field(description="Expected audience claim (e.g., api://my-agent)")
jwks_url: HttpUrl | None = Field(
default=None,
description="Explicit JWKS URL. Derived from issuer if not set.",
)
algorithms: list[str] = Field(
default_factory=lambda: ["RS256"],
description="List of allowed signing algorithms (RS256, ES256, etc.)",
)
required_claims: list[str] = Field(
default_factory=lambda: ["exp", "iat", "iss", "aud", "sub"],
description="List of claims that must be present in the token",
)
jwks_cache_ttl: int = Field(
default=3600,
description="TTL for JWKS cache in seconds",
ge=60,
)
clock_skew_seconds: float = Field(
default=30.0,
description="Allowed clock skew for token validation",
ge=0.0,
)
_jwk_client: PyJWKClient | None = PrivateAttr(default=None)
@model_validator(mode="after")
def _init_jwk_client(self) -> Self:
"""Initialize the JWK client after model creation."""
jwks_url = (
str(self.jwks_url)
if self.jwks_url
else f"{str(self.issuer).rstrip('/')}/.well-known/jwks.json"
)
self._jwk_client = PyJWKClient(jwks_url, lifespan=self.jwks_cache_ttl)
return self
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Authenticate using OIDC JWT validation.
Args:
token: The JWT to authenticate.
Returns:
AuthenticatedUser on successful authentication.
Raises:
HTTPException: If authentication fails.
"""
if self._jwk_client is None:
raise HTTPException(
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
detail="OIDC not initialized",
)
try:
signing_key = self._jwk_client.get_signing_key_from_jwt(token)
claims = jwt.decode(
token,
signing_key.key,
algorithms=self.algorithms,
audience=self.audience,
issuer=str(self.issuer).rstrip("/"),
leeway=self.clock_skew_seconds,
options={
"require": self.required_claims,
},
)
return AuthenticatedUser(
token=token,
scheme="oidc",
claims=claims,
)
except jwt.ExpiredSignatureError:
logger.debug(
"OIDC authentication failed",
extra={"reason": "token_expired", "scheme": "oidc"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Token has expired",
) from None
except jwt.InvalidAudienceError:
logger.debug(
"OIDC authentication failed",
extra={"reason": "invalid_audience", "scheme": "oidc"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid token audience",
) from None
except jwt.InvalidIssuerError:
logger.debug(
"OIDC authentication failed",
extra={"reason": "invalid_issuer", "scheme": "oidc"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid token issuer",
) from None
except jwt.MissingRequiredClaimError as e:
logger.debug(
"OIDC authentication failed",
extra={"reason": "missing_claim", "claim": e.claim, "scheme": "oidc"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail=f"Missing required claim: {e.claim}",
) from None
except jwt.PyJWKClientError as e:
logger.error(
"OIDC authentication failed",
extra={
"reason": "jwks_client_error",
"error": str(e),
"scheme": "oidc",
},
)
raise HTTPException(
status_code=HTTP_503_SERVICE_UNAVAILABLE,
detail="Unable to fetch signing keys",
) from None
except jwt.InvalidTokenError as e:
logger.debug(
"OIDC authentication failed",
extra={"reason": "invalid_token", "error": str(e), "scheme": "oidc"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid or missing authentication credentials",
) from None
class OAuth2ServerAuth(ServerAuthScheme):
"""OAuth2 authentication for A2A server.
Declares OAuth2 security scheme in AgentCard and validates tokens using
either JWKS for JWT tokens or token introspection for opaque tokens.
This is distinct from OIDCAuth in that it declares an explicit OAuth2SecurityScheme
with flows, rather than an OpenIdConnectSecurityScheme with discovery URL.
Attributes:
token_url: OAuth2 token endpoint URL for client_credentials flow.
authorization_url: OAuth2 authorization endpoint for authorization_code flow.
refresh_url: Optional refresh token endpoint URL.
scopes: Available OAuth2 scopes with descriptions.
jwks_url: JWKS URL for JWT validation. Required if not using introspection.
introspection_url: Token introspection endpoint (RFC 7662). Alternative to JWKS.
introspection_client_id: Client ID for introspection endpoint authentication.
introspection_client_secret: Client secret for introspection endpoint.
audience: Expected audience claim for JWT validation.
issuer: Expected issuer claim for JWT validation.
algorithms: Allowed JWT signing algorithms.
required_claims: Claims that must be present in the token.
jwks_cache_ttl: TTL for JWKS cache in seconds.
clock_skew_seconds: Allowed clock skew for token validation.
"""
token_url: HttpUrl = Field(
description="OAuth2 token endpoint URL",
)
authorization_url: HttpUrl | None = Field(
default=None,
description="OAuth2 authorization endpoint URL for authorization_code flow",
)
refresh_url: HttpUrl | None = Field(
default=None,
description="OAuth2 refresh token endpoint URL",
)
scopes: dict[str, str] = Field(
default_factory=dict,
description="Available OAuth2 scopes with descriptions",
)
jwks_url: HttpUrl | None = Field(
default=None,
description="JWKS URL for JWT validation. Required if not using introspection.",
)
introspection_url: HttpUrl | None = Field(
default=None,
description="Token introspection endpoint (RFC 7662). Alternative to JWKS.",
)
introspection_client_id: str | None = Field(
default=None,
description="Client ID for introspection endpoint authentication",
)
introspection_client_secret: CoercedSecretStr | None = Field(
default=None,
description="Client secret for introspection endpoint authentication",
)
audience: str | None = Field(
default=None,
description="Expected audience claim for JWT validation",
)
issuer: str | None = Field(
default=None,
description="Expected issuer claim for JWT validation",
)
algorithms: list[str] = Field(
default_factory=lambda: ["RS256"],
description="Allowed JWT signing algorithms",
)
required_claims: list[str] = Field(
default_factory=lambda: ["exp", "iat"],
description="Claims that must be present in the token",
)
jwks_cache_ttl: int = Field(
default=3600,
description="TTL for JWKS cache in seconds",
ge=60,
)
clock_skew_seconds: float = Field(
default=30.0,
description="Allowed clock skew for token validation",
ge=0.0,
)
_jwk_client: PyJWKClient | None = PrivateAttr(default=None)
@model_validator(mode="after")
def _validate_and_init(self) -> Self:
"""Validate configuration and initialize JWKS client if needed."""
if not self.jwks_url and not self.introspection_url:
raise ValueError(
"Either jwks_url or introspection_url must be provided for token validation"
)
if self.introspection_url:
if not self.introspection_client_id or not self.introspection_client_secret:
raise ValueError(
"introspection_client_id and introspection_client_secret are required "
"when using token introspection"
)
if self.jwks_url:
self._jwk_client = PyJWKClient(
str(self.jwks_url), lifespan=self.jwks_cache_ttl
)
return self
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Authenticate using OAuth2 token validation.
Uses JWKS validation if jwks_url is configured, otherwise falls back
to token introspection.
Args:
token: The OAuth2 access token to authenticate.
Returns:
AuthenticatedUser on successful authentication.
Raises:
HTTPException: If authentication fails.
"""
if self._jwk_client:
return await self._authenticate_jwt(token)
return await self._authenticate_introspection(token)
async def _authenticate_jwt(self, token: str) -> AuthenticatedUser:
"""Authenticate using JWKS JWT validation."""
if self._jwk_client is None:
raise HTTPException(
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
detail="OAuth2 JWKS not initialized",
)
try:
signing_key = self._jwk_client.get_signing_key_from_jwt(token)
decode_options: dict[str, Any] = {
"require": self.required_claims,
}
claims = jwt.decode(
token,
signing_key.key,
algorithms=self.algorithms,
audience=self.audience,
issuer=self.issuer,
leeway=self.clock_skew_seconds,
options=decode_options, # type: ignore[arg-type]
)
return AuthenticatedUser(
token=token,
scheme="oauth2",
claims=claims,
)
except jwt.ExpiredSignatureError:
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "token_expired", "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Token has expired",
) from None
except jwt.InvalidAudienceError:
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "invalid_audience", "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid token audience",
) from None
except jwt.InvalidIssuerError:
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "invalid_issuer", "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid token issuer",
) from None
except jwt.MissingRequiredClaimError as e:
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "missing_claim", "claim": e.claim, "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail=f"Missing required claim: {e.claim}",
) from None
except jwt.PyJWKClientError as e:
logger.error(
"OAuth2 authentication failed",
extra={
"reason": "jwks_client_error",
"error": str(e),
"scheme": "oauth2",
},
)
raise HTTPException(
status_code=HTTP_503_SERVICE_UNAVAILABLE,
detail="Unable to fetch signing keys",
) from None
except jwt.InvalidTokenError as e:
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "invalid_token", "error": str(e), "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid or missing authentication credentials",
) from None
async def _authenticate_introspection(self, token: str) -> AuthenticatedUser:
"""Authenticate using OAuth2 token introspection (RFC 7662)."""
import httpx
if not self.introspection_url:
raise HTTPException(
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
detail="OAuth2 introspection not configured",
)
try:
async with httpx.AsyncClient() as client:
response = await client.post(
str(self.introspection_url),
data={"token": token},
auth=(
self.introspection_client_id or "",
self.introspection_client_secret.get_secret_value()
if self.introspection_client_secret
else "",
),
)
response.raise_for_status()
introspection_result = response.json()
except httpx.HTTPStatusError as e:
logger.error(
"OAuth2 introspection failed",
extra={"reason": "http_error", "status_code": e.response.status_code},
)
raise HTTPException(
status_code=HTTP_503_SERVICE_UNAVAILABLE,
detail="Token introspection service unavailable",
) from None
except Exception as e:
logger.error(
"OAuth2 introspection failed",
extra={"reason": "unexpected_error", "error": str(e)},
)
raise HTTPException(
status_code=HTTP_503_SERVICE_UNAVAILABLE,
detail="Token introspection failed",
) from None
if not introspection_result.get("active", False):
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "token_not_active", "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Token is not active",
)
return AuthenticatedUser(
token=token,
scheme="oauth2",
claims=introspection_result,
)
def to_security_scheme(self) -> OAuth2SecurityScheme:
"""Generate OAuth2SecurityScheme for AgentCard declaration.
Creates an OAuth2SecurityScheme with appropriate flows based on
the configured URLs. Includes client_credentials flow if token_url
is set, and authorization_code flow if authorization_url is set.
Returns:
OAuth2SecurityScheme suitable for use in AgentCard security_schemes.
"""
from a2a.types import (
AuthorizationCodeOAuthFlow,
ClientCredentialsOAuthFlow,
OAuth2SecurityScheme,
OAuthFlows,
)
client_credentials = None
authorization_code = None
if self.token_url:
client_credentials = ClientCredentialsOAuthFlow(
token_url=str(self.token_url),
refresh_url=str(self.refresh_url) if self.refresh_url else None,
scopes=self.scopes,
)
if self.authorization_url:
authorization_code = AuthorizationCodeOAuthFlow(
authorization_url=str(self.authorization_url),
token_url=str(self.token_url),
refresh_url=str(self.refresh_url) if self.refresh_url else None,
scopes=self.scopes,
)
return OAuth2SecurityScheme(
flows=OAuthFlows(
client_credentials=client_credentials,
authorization_code=authorization_code,
),
description="OAuth2 authentication",
)
class APIKeyServerAuth(ServerAuthScheme):
"""API Key authentication for A2A server.
Validates requests using an API key in a header, query parameter, or cookie.
Attributes:
name: The name of the API key parameter (default: X-API-Key).
location: Where to look for the API key (header, query, or cookie).
api_key: The expected API key value.
"""
name: str = Field(
default="X-API-Key",
description="Name of the API key parameter",
)
location: Literal["header", "query", "cookie"] = Field(
default="header",
description="Where to look for the API key",
)
api_key: CoercedSecretStr = Field(
description="Expected API key value",
)
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Authenticate using API key comparison.
Args:
token: The API key to authenticate.
Returns:
AuthenticatedUser on successful authentication.
Raises:
HTTPException: If authentication fails.
"""
if token != self.api_key.get_secret_value():
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
)
return AuthenticatedUser(
token=token,
scheme="api_key",
)
class MTLSServerAuth(ServerAuthScheme):
"""Mutual TLS authentication marker for AgentCard declaration.
This scheme is primarily for AgentCard security_schemes declaration.
Actual mTLS verification happens at the TLS/transport layer, not
at the application layer via token validation.
When configured, this signals to clients that the server requires
client certificates for authentication.
"""
description: str = Field(
default="Mutual TLS certificate authentication",
description="Description for the security scheme",
)
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Return authenticated user for mTLS.
mTLS verification happens at the transport layer before this is called.
If we reach this point, the TLS handshake with client cert succeeded.
Args:
token: Certificate subject or identifier (from TLS layer).
Returns:
AuthenticatedUser indicating mTLS authentication.
"""
return AuthenticatedUser(
token=token or "mtls-verified",
scheme="mtls",
)

View File

@@ -0,0 +1,273 @@
"""Authentication utilities for A2A protocol agent communication.
Provides validation and retry logic for various authentication schemes including
OAuth2, API keys, and HTTP authentication methods.
"""
import asyncio
from collections.abc import Awaitable, Callable, MutableMapping
import hashlib
import re
import threading
from typing import Final, Literal, cast
from a2a.client.errors import A2AClientHTTPError
from a2a.types import (
APIKeySecurityScheme,
AgentCard,
HTTPAuthSecurityScheme,
OAuth2SecurityScheme,
)
from httpx import AsyncClient, Response
from crewai_a2a.auth.client_schemes import (
APIKeyAuth,
BearerTokenAuth,
ClientAuthScheme,
HTTPBasicAuth,
HTTPDigestAuth,
OAuth2AuthorizationCode,
OAuth2ClientCredentials,
)
class _AuthStore:
"""Store for authentication schemes with safe concurrent access."""
def __init__(self) -> None:
self._store: dict[str, ClientAuthScheme | None] = {}
self._lock = threading.RLock()
@staticmethod
def compute_key(auth_type: str, auth_data: str) -> str:
"""Compute a collision-resistant key using SHA-256."""
content = f"{auth_type}:{auth_data}"
return hashlib.sha256(content.encode()).hexdigest()
def set(self, key: str, auth: ClientAuthScheme | None) -> None:
"""Store an auth scheme."""
with self._lock:
self._store[key] = auth
def get(self, key: str) -> ClientAuthScheme | None:
"""Retrieve an auth scheme by key."""
with self._lock:
return self._store.get(key)
def __setitem__(self, key: str, value: ClientAuthScheme | None) -> None:
with self._lock:
self._store[key] = value
def __getitem__(self, key: str) -> ClientAuthScheme | None:
with self._lock:
return self._store[key]
_auth_store = _AuthStore()
_SCHEME_PATTERN: Final[re.Pattern[str]] = re.compile(r"(\w+)\s+(.+?)(?=,\s*\w+\s+|$)")
_PARAM_PATTERN: Final[re.Pattern[str]] = re.compile(r'(\w+)=(?:"([^"]*)"|([^\s,]+))')
_SCHEME_AUTH_MAPPING: Final[dict[type, tuple[type[ClientAuthScheme], ...]]] = {
OAuth2SecurityScheme: (
OAuth2ClientCredentials,
OAuth2AuthorizationCode,
BearerTokenAuth,
),
APIKeySecurityScheme: (APIKeyAuth,),
}
_HTTPSchemeType = Literal["basic", "digest", "bearer"]
_HTTP_SCHEME_MAPPING: Final[dict[_HTTPSchemeType, type[ClientAuthScheme]]] = {
"basic": HTTPBasicAuth,
"digest": HTTPDigestAuth,
"bearer": BearerTokenAuth,
}
def _raise_auth_mismatch(
expected_classes: type[ClientAuthScheme] | tuple[type[ClientAuthScheme], ...],
provided_auth: ClientAuthScheme,
) -> None:
"""Raise authentication mismatch error.
Args:
expected_classes: Expected authentication class or tuple of classes.
provided_auth: Actually provided authentication instance.
Raises:
A2AClientHTTPError: Always raises with 401 status code.
"""
if isinstance(expected_classes, tuple):
if len(expected_classes) == 1:
required = expected_classes[0].__name__
else:
names = [cls.__name__ for cls in expected_classes]
required = f"one of ({', '.join(names)})"
else:
required = expected_classes.__name__
msg = (
f"AgentCard requires {required} authentication, "
f"but {type(provided_auth).__name__} was provided"
)
raise A2AClientHTTPError(401, msg)
def parse_www_authenticate(header_value: str) -> dict[str, dict[str, str]]:
"""Parse WWW-Authenticate header into auth challenges.
Args:
header_value: The WWW-Authenticate header value.
Returns:
Dictionary mapping auth scheme to its parameters.
Example: {"Bearer": {"realm": "api", "scope": "read write"}}
"""
if not header_value:
return {}
challenges: dict[str, dict[str, str]] = {}
for match in _SCHEME_PATTERN.finditer(header_value):
scheme = match.group(1)
params_str = match.group(2)
params: dict[str, str] = {}
for param_match in _PARAM_PATTERN.finditer(params_str):
key = param_match.group(1)
value = param_match.group(2) or param_match.group(3)
params[key] = value
challenges[scheme] = params
return challenges
def validate_auth_against_agent_card(
agent_card: AgentCard, auth: ClientAuthScheme | None
) -> None:
"""Validate that provided auth matches AgentCard security requirements.
Args:
agent_card: The A2A AgentCard containing security requirements.
auth: User-provided authentication scheme (or None).
Raises:
A2AClientHTTPError: If auth doesn't match AgentCard requirements (status_code=401).
"""
if not agent_card.security or not agent_card.security_schemes:
return
if not auth:
msg = "AgentCard requires authentication but no auth scheme provided"
raise A2AClientHTTPError(401, msg)
first_security_req = agent_card.security[0] if agent_card.security else {}
for scheme_name in first_security_req.keys():
security_scheme_wrapper = agent_card.security_schemes.get(scheme_name)
if not security_scheme_wrapper:
continue
scheme = security_scheme_wrapper.root
if allowed_classes := _SCHEME_AUTH_MAPPING.get(type(scheme)):
if not isinstance(auth, allowed_classes):
_raise_auth_mismatch(allowed_classes, auth)
return
if isinstance(scheme, HTTPAuthSecurityScheme):
scheme_key = cast(_HTTPSchemeType, scheme.scheme.lower())
if required_class := _HTTP_SCHEME_MAPPING.get(scheme_key):
if not isinstance(auth, required_class):
_raise_auth_mismatch(required_class, auth)
return
msg = "Could not validate auth against AgentCard security requirements"
raise A2AClientHTTPError(401, msg)
async def retry_on_401(
request_func: Callable[[], Awaitable[Response]],
auth_scheme: ClientAuthScheme | None,
client: AsyncClient,
headers: MutableMapping[str, str],
max_retries: int = 3,
) -> Response:
"""Retry a request on 401 authentication error.
Handles 401 errors by:
1. Parsing WWW-Authenticate header
2. Re-acquiring credentials
3. Retrying the request
Args:
request_func: Async function that makes the HTTP request.
auth_scheme: Authentication scheme to refresh credentials with.
client: HTTP client for making requests.
headers: Request headers to update with new auth.
max_retries: Maximum number of retry attempts (default: 3).
Returns:
HTTP response from the request.
Raises:
httpx.HTTPStatusError: If retries are exhausted or auth scheme is None.
"""
last_response: Response | None = None
last_challenges: dict[str, dict[str, str]] = {}
for attempt in range(max_retries):
response = await request_func()
if response.status_code != 401:
return response
last_response = response
if auth_scheme is None:
response.raise_for_status()
return response
www_authenticate = response.headers.get("WWW-Authenticate", "")
challenges = parse_www_authenticate(www_authenticate)
last_challenges = challenges
if attempt >= max_retries - 1:
break
backoff_time = 2**attempt
await asyncio.sleep(backoff_time)
await auth_scheme.apply_auth(client, headers)
if last_response:
last_response.raise_for_status()
return last_response
msg = "retry_on_401 failed without making any requests"
if last_challenges:
challenge_info = ", ".join(
f"{scheme} (realm={params.get('realm', 'N/A')})"
for scheme, params in last_challenges.items()
)
msg = f"{msg}. Server challenges: {challenge_info}"
raise RuntimeError(msg)
def configure_auth_client(
auth: HTTPDigestAuth | APIKeyAuth, client: AsyncClient
) -> None:
"""Configure HTTP client with auth-specific settings.
Only HTTPDigestAuth and APIKeyAuth need client configuration.
Args:
auth: Authentication scheme that requires client configuration.
client: HTTP client to configure.
"""
auth.configure_client(client)

View File

@@ -0,0 +1,690 @@
"""A2A configuration types.
This module is separate from experimental.a2a to avoid circular imports.
"""
from __future__ import annotations
from pathlib import Path
from typing import Any, ClassVar, Literal, cast
import warnings
from pydantic import (
BaseModel,
ConfigDict,
Field,
FilePath,
PrivateAttr,
SecretStr,
model_validator,
)
from typing_extensions import Self, deprecated
from crewai_a2a.auth.client_schemes import ClientAuthScheme
from crewai_a2a.auth.server_schemes import ServerAuthScheme
from crewai_a2a.extensions.base import ValidatedA2AExtension
from crewai_a2a.types import ProtocolVersion, TransportType, Url
try:
from a2a.types import (
AgentCapabilities,
AgentCardSignature,
AgentInterface,
AgentProvider,
AgentSkill,
SecurityScheme,
)
from crewai_a2a.extensions.server import ServerExtension
from crewai_a2a.updates import UpdateConfig
except ImportError:
UpdateConfig: Any = Any # type: ignore[no-redef]
AgentCapabilities: Any = Any # type: ignore[no-redef]
AgentCardSignature: Any = Any # type: ignore[no-redef]
AgentInterface: Any = Any # type: ignore[no-redef]
AgentProvider: Any = Any # type: ignore[no-redef]
SecurityScheme: Any = Any # type: ignore[no-redef]
AgentSkill: Any = Any # type: ignore[no-redef]
ServerExtension: Any = Any # type: ignore[no-redef]
def _get_default_update_config() -> UpdateConfig:
from crewai_a2a.updates import StreamingConfig
return StreamingConfig()
SigningAlgorithm = Literal[
"RS256",
"RS384",
"RS512",
"ES256",
"ES384",
"ES512",
"PS256",
"PS384",
"PS512",
]
class AgentCardSigningConfig(BaseModel):
"""Configuration for AgentCard JWS signing.
Provides the private key and algorithm settings for signing AgentCards.
Either private_key_path or private_key_pem must be provided, but not both.
Attributes:
private_key_path: Path to a PEM-encoded private key file.
private_key_pem: PEM-encoded private key as a secret string.
key_id: Optional key identifier for the JWS header (kid claim).
algorithm: Signing algorithm (RS256, ES256, PS256, etc.).
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
private_key_path: FilePath | None = Field(
default=None,
description="Path to PEM-encoded private key file",
)
private_key_pem: SecretStr | None = Field(
default=None,
description="PEM-encoded private key",
)
key_id: str | None = Field(
default=None,
description="Key identifier for JWS header (kid claim)",
)
algorithm: SigningAlgorithm = Field(
default="RS256",
description="Signing algorithm (RS256, ES256, PS256, etc.)",
)
@model_validator(mode="after")
def _validate_key_source(self) -> Self:
"""Ensure exactly one key source is provided."""
has_path = self.private_key_path is not None
has_pem = self.private_key_pem is not None
if not has_path and not has_pem:
raise ValueError(
"Either private_key_path or private_key_pem must be provided"
)
if has_path and has_pem:
raise ValueError(
"Only one of private_key_path or private_key_pem should be provided"
)
return self
def get_private_key(self) -> str:
"""Get the private key content.
Returns:
The PEM-encoded private key as a string.
"""
if self.private_key_pem:
return self.private_key_pem.get_secret_value()
if self.private_key_path:
return Path(self.private_key_path).read_text()
raise ValueError("No private key configured")
class GRPCServerConfig(BaseModel):
"""gRPC server transport configuration.
Presence of this config in ServerTransportConfig.grpc enables gRPC transport.
Attributes:
host: Hostname to advertise in agent cards (default: localhost).
Use docker service name (e.g., 'web') for docker-compose setups.
port: Port for the gRPC server.
tls_cert_path: Path to TLS certificate file for gRPC.
tls_key_path: Path to TLS private key file for gRPC.
max_workers: Maximum number of workers for the gRPC thread pool.
reflection_enabled: Whether to enable gRPC reflection for debugging.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
host: str = Field(
default="localhost",
description="Hostname to advertise in agent cards for gRPC connections",
)
port: int = Field(
default=50051,
description="Port for the gRPC server",
)
tls_cert_path: str | None = Field(
default=None,
description="Path to TLS certificate file for gRPC",
)
tls_key_path: str | None = Field(
default=None,
description="Path to TLS private key file for gRPC",
)
max_workers: int = Field(
default=10,
description="Maximum number of workers for the gRPC thread pool",
)
reflection_enabled: bool = Field(
default=False,
description="Whether to enable gRPC reflection for debugging",
)
class GRPCClientConfig(BaseModel):
"""gRPC client transport configuration.
Attributes:
max_send_message_length: Maximum size for outgoing messages in bytes.
max_receive_message_length: Maximum size for incoming messages in bytes.
keepalive_time_ms: Time between keepalive pings in milliseconds.
keepalive_timeout_ms: Timeout for keepalive ping response in milliseconds.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
max_send_message_length: int | None = Field(
default=None,
description="Maximum size for outgoing messages in bytes",
)
max_receive_message_length: int | None = Field(
default=None,
description="Maximum size for incoming messages in bytes",
)
keepalive_time_ms: int | None = Field(
default=None,
description="Time between keepalive pings in milliseconds",
)
keepalive_timeout_ms: int | None = Field(
default=None,
description="Timeout for keepalive ping response in milliseconds",
)
class JSONRPCServerConfig(BaseModel):
"""JSON-RPC server transport configuration.
Presence of this config in ServerTransportConfig.jsonrpc enables JSON-RPC transport.
Attributes:
rpc_path: URL path for the JSON-RPC endpoint.
agent_card_path: URL path for the agent card endpoint.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
rpc_path: str = Field(
default="/a2a",
description="URL path for the JSON-RPC endpoint",
)
agent_card_path: str = Field(
default="/.well-known/agent-card.json",
description="URL path for the agent card endpoint",
)
class JSONRPCClientConfig(BaseModel):
"""JSON-RPC client transport configuration.
Attributes:
max_request_size: Maximum request body size in bytes.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
max_request_size: int | None = Field(
default=None,
description="Maximum request body size in bytes",
)
class HTTPJSONConfig(BaseModel):
"""HTTP+JSON transport configuration.
Presence of this config in ServerTransportConfig.http_json enables HTTP+JSON transport.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
class ServerPushNotificationConfig(BaseModel):
"""Configuration for outgoing webhook push notifications.
Controls how the server signs and delivers push notifications to clients.
Attributes:
signature_secret: Shared secret for HMAC-SHA256 signing of outgoing webhooks.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
signature_secret: SecretStr | None = Field(
default=None,
description="Shared secret for HMAC-SHA256 signing of outgoing push notifications",
)
class ServerTransportConfig(BaseModel):
"""Transport configuration for A2A server.
Groups all transport-related settings including preferred transport
and protocol-specific configurations.
Attributes:
preferred: Transport protocol for the preferred endpoint.
jsonrpc: JSON-RPC server transport configuration.
grpc: gRPC server transport configuration.
http_json: HTTP+JSON transport configuration.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
preferred: TransportType = Field(
default="JSONRPC",
description="Transport protocol for the preferred endpoint",
)
jsonrpc: JSONRPCServerConfig = Field(
default_factory=JSONRPCServerConfig,
description="JSON-RPC server transport configuration",
)
grpc: GRPCServerConfig | None = Field(
default=None,
description="gRPC server transport configuration",
)
http_json: HTTPJSONConfig | None = Field(
default=None,
description="HTTP+JSON transport configuration",
)
def _migrate_client_transport_fields(
transport: ClientTransportConfig,
transport_protocol: TransportType | None,
supported_transports: list[TransportType] | None,
) -> None:
"""Migrate deprecated transport fields to new config."""
if transport_protocol is not None:
warnings.warn(
"transport_protocol is deprecated, use transport=ClientTransportConfig(preferred=...) instead",
FutureWarning,
stacklevel=5,
)
object.__setattr__(transport, "preferred", transport_protocol)
if supported_transports is not None:
warnings.warn(
"supported_transports is deprecated, use transport=ClientTransportConfig(supported=...) instead",
FutureWarning,
stacklevel=5,
)
object.__setattr__(transport, "supported", supported_transports)
class ClientTransportConfig(BaseModel):
"""Transport configuration for A2A client.
Groups all client transport-related settings including preferred transport,
supported transports for negotiation, and protocol-specific configurations.
Transport negotiation logic:
1. If `preferred` is set and server supports it → use client's preferred
2. Otherwise, if server's preferred is in client's `supported` → use server's preferred
3. Otherwise, find first match from client's `supported` in server's interfaces
Attributes:
preferred: Client's preferred transport. If set, client preference takes priority.
supported: Transports the client can use, in order of preference.
jsonrpc: JSON-RPC client transport configuration.
grpc: gRPC client transport configuration.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
preferred: TransportType | None = Field(
default=None,
description="Client's preferred transport. If set, takes priority over server preference.",
)
supported: list[TransportType] = Field(
default_factory=lambda: cast(list[TransportType], ["JSONRPC"]),
description="Transports the client can use, in order of preference",
)
jsonrpc: JSONRPCClientConfig = Field(
default_factory=JSONRPCClientConfig,
description="JSON-RPC client transport configuration",
)
grpc: GRPCClientConfig = Field(
default_factory=GRPCClientConfig,
description="gRPC client transport configuration",
)
@deprecated(
"""
`crewai_a2a.config.A2AConfig` is deprecated and will be removed in v2.0.0,
use `crewai_a2a.config.A2AClientConfig` or `crewai_a2a.config.A2AServerConfig` instead.
""",
category=FutureWarning,
)
class A2AConfig(BaseModel):
"""Configuration for A2A protocol integration.
Deprecated:
Use A2AClientConfig instead. This class will be removed in a future version.
Attributes:
endpoint: A2A agent endpoint URL.
auth: Authentication scheme.
timeout: Request timeout in seconds.
max_turns: Maximum conversation turns with A2A agent.
response_model: Optional Pydantic model for structured A2A agent responses.
fail_fast: If True, raise error when agent unreachable; if False, skip and continue.
trust_remote_completion_status: If True, return A2A agent's result directly when completed.
updates: Update mechanism config.
client_extensions: Client-side processing hooks for tool injection and prompt augmentation.
transport: Transport configuration (preferred, supported transports, gRPC settings).
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
endpoint: Url = Field(description="A2A agent endpoint URL")
auth: ClientAuthScheme | None = Field(
default=None,
description="Authentication scheme",
)
timeout: int = Field(default=120, description="Request timeout in seconds")
max_turns: int = Field(
default=10, description="Maximum conversation turns with A2A agent"
)
response_model: type[BaseModel] | None = Field(
default=None,
description="Optional Pydantic model for structured A2A agent responses",
)
fail_fast: bool = Field(
default=True,
description="If True, raise error when agent unreachable; if False, skip",
)
trust_remote_completion_status: bool = Field(
default=False,
description="If True, return A2A result directly when completed",
)
updates: UpdateConfig = Field(
default_factory=_get_default_update_config,
description="Update mechanism config",
)
client_extensions: list[ValidatedA2AExtension] = Field(
default_factory=list,
description="Client-side processing hooks for tool injection and prompt augmentation",
)
transport: ClientTransportConfig = Field(
default_factory=ClientTransportConfig,
description="Transport configuration (preferred, supported transports, gRPC settings)",
)
transport_protocol: TransportType | None = Field(
default=None,
description="Deprecated: Use transport.preferred instead",
exclude=True,
)
supported_transports: list[TransportType] | None = Field(
default=None,
description="Deprecated: Use transport.supported instead",
exclude=True,
)
use_client_preference: bool | None = Field(
default=None,
description="Deprecated: Set transport.preferred to enable client preference",
exclude=True,
)
_parallel_delegation: bool = PrivateAttr(default=False)
@model_validator(mode="after")
def _migrate_deprecated_transport_fields(self) -> Self:
"""Migrate deprecated transport fields to new config."""
_migrate_client_transport_fields(
self.transport, self.transport_protocol, self.supported_transports
)
if self.use_client_preference is not None:
warnings.warn(
"use_client_preference is deprecated, set transport.preferred to enable client preference",
FutureWarning,
stacklevel=4,
)
if self.use_client_preference and self.transport.supported:
object.__setattr__(
self.transport, "preferred", self.transport.supported[0]
)
return self
class A2AClientConfig(BaseModel):
"""Configuration for connecting to remote A2A agents.
Attributes:
endpoint: A2A agent endpoint URL.
auth: Authentication scheme.
timeout: Request timeout in seconds.
max_turns: Maximum conversation turns with A2A agent.
response_model: Optional Pydantic model for structured A2A agent responses.
fail_fast: If True, raise error when agent unreachable; if False, skip and continue.
trust_remote_completion_status: If True, return A2A agent's result directly when completed.
updates: Update mechanism config.
accepted_output_modes: Media types the client can accept in responses.
extensions: Extension URIs the client supports (A2A protocol extensions).
client_extensions: Client-side processing hooks for tool injection and prompt augmentation.
transport: Transport configuration (preferred, supported transports, gRPC settings).
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
endpoint: Url = Field(description="A2A agent endpoint URL")
auth: ClientAuthScheme | None = Field(
default=None,
description="Authentication scheme",
)
timeout: int = Field(default=120, description="Request timeout in seconds")
max_turns: int = Field(
default=10, description="Maximum conversation turns with A2A agent"
)
response_model: type[BaseModel] | None = Field(
default=None,
description="Optional Pydantic model for structured A2A agent responses",
)
fail_fast: bool = Field(
default=True,
description="If True, raise error when agent unreachable; if False, skip",
)
trust_remote_completion_status: bool = Field(
default=False,
description="If True, return A2A result directly when completed",
)
updates: UpdateConfig = Field(
default_factory=_get_default_update_config,
description="Update mechanism config",
)
accepted_output_modes: list[str] = Field(
default_factory=lambda: ["application/json"],
description="Media types the client can accept in responses",
)
extensions: list[str] = Field(
default_factory=list,
description="Extension URIs the client supports",
)
client_extensions: list[ValidatedA2AExtension] = Field(
default_factory=list,
description="Client-side processing hooks for tool injection and prompt augmentation",
)
transport: ClientTransportConfig = Field(
default_factory=ClientTransportConfig,
description="Transport configuration (preferred, supported transports, gRPC settings)",
)
transport_protocol: TransportType | None = Field(
default=None,
description="Deprecated: Use transport.preferred instead",
exclude=True,
)
supported_transports: list[TransportType] | None = Field(
default=None,
description="Deprecated: Use transport.supported instead",
exclude=True,
)
_parallel_delegation: bool = PrivateAttr(default=False)
@model_validator(mode="after")
def _migrate_deprecated_transport_fields(self) -> Self:
"""Migrate deprecated transport fields to new config."""
_migrate_client_transport_fields(
self.transport, self.transport_protocol, self.supported_transports
)
return self
class A2AServerConfig(BaseModel):
"""Configuration for exposing a Crew or Agent as an A2A server.
All fields correspond to A2A AgentCard fields. Fields like name, description,
and skills can be auto-derived from the Crew/Agent if not provided.
Attributes:
name: Human-readable name for the agent.
description: Human-readable description of the agent.
version: Version string for the agent card.
skills: List of agent skills/capabilities.
default_input_modes: Default supported input MIME types.
default_output_modes: Default supported output MIME types.
capabilities: Declaration of optional capabilities.
protocol_version: A2A protocol version this agent supports.
provider: Information about the agent's service provider.
documentation_url: URL to the agent's documentation.
icon_url: URL to an icon for the agent.
additional_interfaces: Additional supported interfaces.
security: Security requirement objects for all interactions.
security_schemes: Security schemes available to authorize requests.
supports_authenticated_extended_card: Whether agent provides extended card to authenticated users.
url: Preferred endpoint URL for the agent.
signing_config: Configuration for signing the AgentCard with JWS.
signatures: Deprecated. Pre-computed JWS signatures. Use signing_config instead.
server_extensions: Server-side A2A protocol extensions with on_request/on_response hooks.
push_notifications: Configuration for outgoing push notifications.
transport: Transport configuration (preferred transport, gRPC, REST settings).
auth: Authentication scheme for A2A endpoints.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
name: str | None = Field(
default=None,
description="Human-readable name for the agent. Auto-derived from Crew/Agent if not provided.",
)
description: str | None = Field(
default=None,
description="Human-readable description of the agent. Auto-derived from Crew/Agent if not provided.",
)
version: str = Field(
default="1.0.0",
description="Version string for the agent card",
)
skills: list[AgentSkill] = Field(
default_factory=list,
description="List of agent skills. Auto-derived from tasks/tools if not provided.",
)
default_input_modes: list[str] = Field(
default_factory=lambda: ["text/plain", "application/json"],
description="Default supported input MIME types",
)
default_output_modes: list[str] = Field(
default_factory=lambda: ["text/plain", "application/json"],
description="Default supported output MIME types",
)
capabilities: AgentCapabilities = Field(
default_factory=lambda: AgentCapabilities(
streaming=True,
push_notifications=False,
),
description="Declaration of optional capabilities supported by the agent",
)
protocol_version: ProtocolVersion = Field(
default="0.3.0",
description="A2A protocol version this agent supports",
)
provider: AgentProvider | None = Field(
default=None,
description="Information about the agent's service provider",
)
documentation_url: Url | None = Field(
default=None,
description="URL to the agent's documentation",
)
icon_url: Url | None = Field(
default=None,
description="URL to an icon for the agent",
)
additional_interfaces: list[AgentInterface] = Field(
default_factory=list,
description="Additional supported interfaces.",
)
security: list[dict[str, list[str]]] = Field(
default_factory=list,
description="Security requirement objects for all agent interactions",
)
security_schemes: dict[str, SecurityScheme] = Field(
default_factory=dict,
description="Security schemes available to authorize requests",
)
supports_authenticated_extended_card: bool = Field(
default=False,
description="Whether agent provides extended card to authenticated users",
)
url: Url | None = Field(
default=None,
description="Preferred endpoint URL for the agent. Set at runtime if not provided.",
)
signing_config: AgentCardSigningConfig | None = Field(
default=None,
description="Configuration for signing the AgentCard with JWS",
)
signatures: list[AgentCardSignature] | None = Field(
default=None,
description="Deprecated: Use signing_config instead. Pre-computed JWS signatures for the AgentCard.",
exclude=True,
deprecated=True,
)
server_extensions: list[ServerExtension] = Field(
default_factory=list,
description="Server-side A2A protocol extensions that modify agent behavior",
)
push_notifications: ServerPushNotificationConfig | None = Field(
default=None,
description="Configuration for outgoing push notifications",
)
transport: ServerTransportConfig = Field(
default_factory=ServerTransportConfig,
description="Transport configuration (preferred transport, gRPC, REST settings)",
)
preferred_transport: TransportType | None = Field(
default=None,
description="Deprecated: Use transport.preferred instead",
exclude=True,
deprecated=True,
)
auth: ServerAuthScheme | None = Field(
default=None,
description="Authentication scheme for A2A endpoints. Defaults to SimpleTokenAuth using AUTH_TOKEN env var.",
)
@model_validator(mode="after")
def _migrate_deprecated_fields(self) -> Self:
"""Migrate deprecated fields to new config."""
if self.preferred_transport is not None:
warnings.warn(
"preferred_transport is deprecated, use transport=ServerTransportConfig(preferred=...) instead",
FutureWarning,
stacklevel=4,
)
object.__setattr__(self.transport, "preferred", self.preferred_transport)
if self.signatures is not None:
warnings.warn(
"signatures is deprecated, use signing_config=AgentCardSigningConfig(...) instead. "
"The signatures field will be removed in v2.0.0.",
FutureWarning,
stacklevel=4,
)
return self

View File

@@ -0,0 +1,491 @@
"""A2A error codes and error response utilities.
This module provides a centralized mapping of all A2A protocol error codes
as defined in the A2A specification, plus custom CrewAI extensions.
Error codes follow JSON-RPC 2.0 conventions:
- -32700 to -32600: Standard JSON-RPC errors
- -32099 to -32000: Server errors (A2A-specific)
- -32768 to -32100: Reserved for implementation-defined errors
"""
from __future__ import annotations
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Any
from a2a.client.errors import A2AClientTimeoutError
class A2APollingTimeoutError(A2AClientTimeoutError):
"""Raised when polling exceeds the configured timeout."""
class A2AErrorCode(IntEnum):
"""A2A protocol error codes.
Codes follow JSON-RPC 2.0 specification with A2A-specific extensions.
"""
# JSON-RPC 2.0 Standard Errors (-32700 to -32600)
JSON_PARSE_ERROR = -32700
"""Invalid JSON was received by the server."""
INVALID_REQUEST = -32600
"""The JSON sent is not a valid Request object."""
METHOD_NOT_FOUND = -32601
"""The method does not exist / is not available."""
INVALID_PARAMS = -32602
"""Invalid method parameter(s)."""
INTERNAL_ERROR = -32603
"""Internal JSON-RPC error."""
# A2A-Specific Errors (-32099 to -32000)
TASK_NOT_FOUND = -32001
"""The specified task was not found."""
TASK_NOT_CANCELABLE = -32002
"""The task cannot be canceled (already completed/failed)."""
PUSH_NOTIFICATION_NOT_SUPPORTED = -32003
"""Push notifications are not supported by this agent."""
UNSUPPORTED_OPERATION = -32004
"""The requested operation is not supported."""
CONTENT_TYPE_NOT_SUPPORTED = -32005
"""Incompatible content types between client and server."""
INVALID_AGENT_RESPONSE = -32006
"""The agent produced an invalid response."""
# CrewAI Custom Extensions (-32768 to -32100)
UNSUPPORTED_VERSION = -32009
"""The requested A2A protocol version is not supported."""
UNSUPPORTED_EXTENSION = -32010
"""Client does not support required protocol extensions."""
AUTHENTICATION_REQUIRED = -32011
"""Authentication is required for this operation."""
AUTHORIZATION_FAILED = -32012
"""Authorization check failed (insufficient permissions)."""
RATE_LIMIT_EXCEEDED = -32013
"""Rate limit exceeded for this client/operation."""
TASK_TIMEOUT = -32014
"""Task execution timed out."""
TRANSPORT_NEGOTIATION_FAILED = -32015
"""Failed to negotiate a compatible transport protocol."""
CONTEXT_NOT_FOUND = -32016
"""The specified context was not found."""
SKILL_NOT_FOUND = -32017
"""The specified skill was not found."""
ARTIFACT_NOT_FOUND = -32018
"""The specified artifact was not found."""
# Error code to default message mapping
ERROR_MESSAGES: dict[int, str] = {
A2AErrorCode.JSON_PARSE_ERROR: "Parse error",
A2AErrorCode.INVALID_REQUEST: "Invalid Request",
A2AErrorCode.METHOD_NOT_FOUND: "Method not found",
A2AErrorCode.INVALID_PARAMS: "Invalid params",
A2AErrorCode.INTERNAL_ERROR: "Internal error",
A2AErrorCode.TASK_NOT_FOUND: "Task not found",
A2AErrorCode.TASK_NOT_CANCELABLE: "Task not cancelable",
A2AErrorCode.PUSH_NOTIFICATION_NOT_SUPPORTED: "Push Notification is not supported",
A2AErrorCode.UNSUPPORTED_OPERATION: "This operation is not supported",
A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED: "Incompatible content types",
A2AErrorCode.INVALID_AGENT_RESPONSE: "Invalid agent response",
A2AErrorCode.UNSUPPORTED_VERSION: "Unsupported A2A version",
A2AErrorCode.UNSUPPORTED_EXTENSION: "Client does not support required extensions",
A2AErrorCode.AUTHENTICATION_REQUIRED: "Authentication required",
A2AErrorCode.AUTHORIZATION_FAILED: "Authorization failed",
A2AErrorCode.RATE_LIMIT_EXCEEDED: "Rate limit exceeded",
A2AErrorCode.TASK_TIMEOUT: "Task execution timed out",
A2AErrorCode.TRANSPORT_NEGOTIATION_FAILED: "Transport negotiation failed",
A2AErrorCode.CONTEXT_NOT_FOUND: "Context not found",
A2AErrorCode.SKILL_NOT_FOUND: "Skill not found",
A2AErrorCode.ARTIFACT_NOT_FOUND: "Artifact not found",
}
@dataclass
class A2AError(Exception):
"""Base exception for A2A protocol errors.
Attributes:
code: The A2A/JSON-RPC error code.
message: Human-readable error message.
data: Optional additional error data.
"""
code: int
message: str | None = None
data: Any = None
def __post_init__(self) -> None:
if self.message is None:
self.message = ERROR_MESSAGES.get(self.code, "Unknown error")
super().__init__(self.message)
def to_dict(self) -> dict[str, Any]:
"""Convert to JSON-RPC error object format."""
error: dict[str, Any] = {
"code": self.code,
"message": self.message,
}
if self.data is not None:
error["data"] = self.data
return error
def to_response(self, request_id: str | int | None = None) -> dict[str, Any]:
"""Convert to full JSON-RPC error response."""
return {
"jsonrpc": "2.0",
"error": self.to_dict(),
"id": request_id,
}
@dataclass
class JSONParseError(A2AError):
"""Invalid JSON was received."""
code: int = field(default=A2AErrorCode.JSON_PARSE_ERROR, init=False)
@dataclass
class InvalidRequestError(A2AError):
"""The JSON sent is not a valid Request object."""
code: int = field(default=A2AErrorCode.INVALID_REQUEST, init=False)
@dataclass
class MethodNotFoundError(A2AError):
"""The method does not exist / is not available."""
code: int = field(default=A2AErrorCode.METHOD_NOT_FOUND, init=False)
method: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.method:
self.message = f"Method not found: {self.method}"
super().__post_init__()
@dataclass
class InvalidParamsError(A2AError):
"""Invalid method parameter(s)."""
code: int = field(default=A2AErrorCode.INVALID_PARAMS, init=False)
param: str | None = None
reason: str | None = None
def __post_init__(self) -> None:
if self.message is None:
if self.param and self.reason:
self.message = f"Invalid parameter '{self.param}': {self.reason}"
elif self.param:
self.message = f"Invalid parameter: {self.param}"
super().__post_init__()
@dataclass
class InternalError(A2AError):
"""Internal JSON-RPC error."""
code: int = field(default=A2AErrorCode.INTERNAL_ERROR, init=False)
@dataclass
class TaskNotFoundError(A2AError):
"""The specified task was not found."""
code: int = field(default=A2AErrorCode.TASK_NOT_FOUND, init=False)
task_id: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.task_id:
self.message = f"Task not found: {self.task_id}"
super().__post_init__()
@dataclass
class TaskNotCancelableError(A2AError):
"""The task cannot be canceled."""
code: int = field(default=A2AErrorCode.TASK_NOT_CANCELABLE, init=False)
task_id: str | None = None
reason: str | None = None
def __post_init__(self) -> None:
if self.message is None:
if self.task_id and self.reason:
self.message = f"Task {self.task_id} cannot be canceled: {self.reason}"
elif self.task_id:
self.message = f"Task {self.task_id} cannot be canceled"
super().__post_init__()
@dataclass
class PushNotificationNotSupportedError(A2AError):
"""Push notifications are not supported."""
code: int = field(default=A2AErrorCode.PUSH_NOTIFICATION_NOT_SUPPORTED, init=False)
@dataclass
class UnsupportedOperationError(A2AError):
"""The requested operation is not supported."""
code: int = field(default=A2AErrorCode.UNSUPPORTED_OPERATION, init=False)
operation: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.operation:
self.message = f"Operation not supported: {self.operation}"
super().__post_init__()
@dataclass
class ContentTypeNotSupportedError(A2AError):
"""Incompatible content types."""
code: int = field(default=A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED, init=False)
requested_types: list[str] | None = None
supported_types: list[str] | None = None
def __post_init__(self) -> None:
if self.message is None and self.requested_types and self.supported_types:
self.message = (
f"Content type not supported. Requested: {self.requested_types}, "
f"Supported: {self.supported_types}"
)
super().__post_init__()
@dataclass
class InvalidAgentResponseError(A2AError):
"""The agent produced an invalid response."""
code: int = field(default=A2AErrorCode.INVALID_AGENT_RESPONSE, init=False)
@dataclass
class UnsupportedVersionError(A2AError):
"""The requested A2A version is not supported."""
code: int = field(default=A2AErrorCode.UNSUPPORTED_VERSION, init=False)
requested_version: str | None = None
supported_versions: list[str] | None = None
def __post_init__(self) -> None:
if self.message is None and self.requested_version:
msg = f"Unsupported A2A version: {self.requested_version}"
if self.supported_versions:
msg += f". Supported versions: {', '.join(self.supported_versions)}"
self.message = msg
super().__post_init__()
@dataclass
class UnsupportedExtensionError(A2AError):
"""Client does not support required extensions."""
code: int = field(default=A2AErrorCode.UNSUPPORTED_EXTENSION, init=False)
required_extensions: list[str] | None = None
def __post_init__(self) -> None:
if self.message is None and self.required_extensions:
self.message = f"Client does not support required extensions: {', '.join(self.required_extensions)}"
super().__post_init__()
@dataclass
class AuthenticationRequiredError(A2AError):
"""Authentication is required."""
code: int = field(default=A2AErrorCode.AUTHENTICATION_REQUIRED, init=False)
@dataclass
class AuthorizationFailedError(A2AError):
"""Authorization check failed."""
code: int = field(default=A2AErrorCode.AUTHORIZATION_FAILED, init=False)
required_scope: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.required_scope:
self.message = (
f"Authorization failed. Required scope: {self.required_scope}"
)
super().__post_init__()
@dataclass
class RateLimitExceededError(A2AError):
"""Rate limit exceeded."""
code: int = field(default=A2AErrorCode.RATE_LIMIT_EXCEEDED, init=False)
retry_after: int | None = None
def __post_init__(self) -> None:
if self.message is None and self.retry_after:
self.message = (
f"Rate limit exceeded. Retry after {self.retry_after} seconds"
)
if self.retry_after:
self.data = {"retry_after": self.retry_after}
super().__post_init__()
@dataclass
class TaskTimeoutError(A2AError):
"""Task execution timed out."""
code: int = field(default=A2AErrorCode.TASK_TIMEOUT, init=False)
task_id: str | None = None
timeout_seconds: float | None = None
def __post_init__(self) -> None:
if self.message is None:
if self.task_id and self.timeout_seconds:
self.message = (
f"Task {self.task_id} timed out after {self.timeout_seconds}s"
)
elif self.task_id:
self.message = f"Task {self.task_id} timed out"
super().__post_init__()
@dataclass
class TransportNegotiationFailedError(A2AError):
"""Failed to negotiate a compatible transport protocol."""
code: int = field(default=A2AErrorCode.TRANSPORT_NEGOTIATION_FAILED, init=False)
client_transports: list[str] | None = None
server_transports: list[str] | None = None
def __post_init__(self) -> None:
if self.message is None and self.client_transports and self.server_transports:
self.message = (
f"Transport negotiation failed. Client: {self.client_transports}, "
f"Server: {self.server_transports}"
)
super().__post_init__()
@dataclass
class ContextNotFoundError(A2AError):
"""The specified context was not found."""
code: int = field(default=A2AErrorCode.CONTEXT_NOT_FOUND, init=False)
context_id: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.context_id:
self.message = f"Context not found: {self.context_id}"
super().__post_init__()
@dataclass
class SkillNotFoundError(A2AError):
"""The specified skill was not found."""
code: int = field(default=A2AErrorCode.SKILL_NOT_FOUND, init=False)
skill_id: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.skill_id:
self.message = f"Skill not found: {self.skill_id}"
super().__post_init__()
@dataclass
class ArtifactNotFoundError(A2AError):
"""The specified artifact was not found."""
code: int = field(default=A2AErrorCode.ARTIFACT_NOT_FOUND, init=False)
artifact_id: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.artifact_id:
self.message = f"Artifact not found: {self.artifact_id}"
super().__post_init__()
def create_error_response(
code: int | A2AErrorCode,
message: str | None = None,
data: Any = None,
request_id: str | int | None = None,
) -> dict[str, Any]:
"""Create a JSON-RPC error response.
Args:
code: Error code (A2AErrorCode or int).
message: Optional error message (uses default if not provided).
data: Optional additional error data.
request_id: Request ID for correlation.
Returns:
Dict in JSON-RPC error response format.
"""
error = A2AError(code=int(code), message=message, data=data)
return error.to_response(request_id)
def is_retryable_error(code: int) -> bool:
"""Check if an error is potentially retryable.
Args:
code: Error code to check.
Returns:
True if the error might be resolved by retrying.
"""
retryable_codes = {
A2AErrorCode.INTERNAL_ERROR,
A2AErrorCode.RATE_LIMIT_EXCEEDED,
A2AErrorCode.TASK_TIMEOUT,
}
return code in retryable_codes
def is_client_error(code: int) -> bool:
"""Check if an error is a client-side error.
Args:
code: Error code to check.
Returns:
True if the error is due to client request issues.
"""
client_error_codes = {
A2AErrorCode.JSON_PARSE_ERROR,
A2AErrorCode.INVALID_REQUEST,
A2AErrorCode.METHOD_NOT_FOUND,
A2AErrorCode.INVALID_PARAMS,
A2AErrorCode.TASK_NOT_FOUND,
A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED,
A2AErrorCode.UNSUPPORTED_VERSION,
A2AErrorCode.UNSUPPORTED_EXTENSION,
A2AErrorCode.CONTEXT_NOT_FOUND,
A2AErrorCode.SKILL_NOT_FOUND,
A2AErrorCode.ARTIFACT_NOT_FOUND,
}
return code in client_error_codes

View File

@@ -0,0 +1,37 @@
"""A2A Protocol Extensions for CrewAI.
This module contains extensions to the A2A (Agent-to-Agent) protocol.
**Client-side extensions** (A2AExtension) allow customizing how the A2A wrapper
processes requests and responses during delegation to remote agents. These provide
hooks for tool injection, prompt augmentation, and response processing.
**Server-side extensions** (ServerExtension) allow agents to offer additional
functionality beyond the core A2A specification. Clients activate extensions
via the X-A2A-Extensions header.
See: https://a2a-protocol.org/latest/topics/extensions/
"""
from crewai_a2a.extensions.base import (
A2AExtension,
ConversationState,
ExtensionRegistry,
ValidatedA2AExtension,
)
from crewai_a2a.extensions.server import (
ExtensionContext,
ServerExtension,
ServerExtensionRegistry,
)
__all__ = [
"A2AExtension",
"ConversationState",
"ExtensionContext",
"ExtensionRegistry",
"ServerExtension",
"ServerExtensionRegistry",
"ValidatedA2AExtension",
]

View File

@@ -0,0 +1,237 @@
"""Base extension interface for CrewAI A2A wrapper processing hooks.
This module defines the protocol for extending CrewAI's A2A wrapper functionality
with custom logic for tool injection, prompt augmentation, and response processing.
Note: These are CrewAI-specific processing hooks, NOT A2A protocol extensions.
A2A protocol extensions are capability declarations using AgentExtension objects
in AgentCard.capabilities.extensions, activated via the A2A-Extensions HTTP header.
See: https://a2a-protocol.org/latest/topics/extensions/
"""
from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING, Annotated, Any, Protocol, runtime_checkable
from pydantic import BeforeValidator
if TYPE_CHECKING:
from a2a.types import Message
from crewai.agent.core import Agent
def _validate_a2a_extension(v: Any) -> Any:
"""Validate that value implements A2AExtension protocol."""
if not isinstance(v, A2AExtension):
raise ValueError(
f"Value must implement A2AExtension protocol. "
f"Got {type(v).__name__} which is missing required methods."
)
return v
ValidatedA2AExtension = Annotated[Any, BeforeValidator(_validate_a2a_extension)]
@runtime_checkable
class ConversationState(Protocol):
"""Protocol for extension-specific conversation state.
Extensions can define their own state classes that implement this protocol
to track conversation-specific data extracted from message history.
"""
def is_ready(self) -> bool:
"""Check if the state indicates readiness for some action.
Returns:
True if the state is ready, False otherwise.
"""
...
@runtime_checkable
class A2AExtension(Protocol):
"""Protocol for A2A wrapper extensions.
Extensions can implement this protocol to inject custom logic into
the A2A conversation flow at various integration points.
Example:
class MyExtension:
def inject_tools(self, agent: Agent) -> None:
# Add custom tools to the agent
pass
def extract_state_from_history(
self, conversation_history: Sequence[Message]
) -> ConversationState | None:
# Extract state from conversation
return None
def augment_prompt(
self, base_prompt: str, conversation_state: ConversationState | None
) -> str:
# Add custom instructions
return base_prompt
def process_response(
self, agent_response: Any, conversation_state: ConversationState | None
) -> Any:
# Modify response if needed
return agent_response
"""
def inject_tools(self, agent: Agent) -> None:
"""Inject extension-specific tools into the agent.
Called when an agent is wrapped with A2A capabilities. Extensions
can add tools that enable extension-specific functionality.
Args:
agent: The agent instance to inject tools into.
"""
...
def extract_state_from_history(
self, conversation_history: Sequence[Message]
) -> ConversationState | None:
"""Extract extension-specific state from conversation history.
Called during prompt augmentation to allow extensions to analyze
the conversation history and extract relevant state information.
Args:
conversation_history: The sequence of A2A messages exchanged.
Returns:
Extension-specific conversation state, or None if no relevant state.
"""
...
def augment_prompt(
self,
base_prompt: str,
conversation_state: ConversationState | None,
) -> str:
"""Augment the task prompt with extension-specific instructions.
Called during prompt augmentation to allow extensions to add
custom instructions based on conversation state.
Args:
base_prompt: The base prompt to augment.
conversation_state: Extension-specific state from extract_state_from_history.
Returns:
The augmented prompt with extension-specific instructions.
"""
...
def process_response(
self,
agent_response: Any,
conversation_state: ConversationState | None,
) -> Any:
"""Process and potentially modify the agent response.
Called after parsing the agent's response, allowing extensions to
enhance or modify the response based on conversation state.
Args:
agent_response: The parsed agent response.
conversation_state: Extension-specific state from extract_state_from_history.
Returns:
The processed agent response (may be modified or original).
"""
...
class ExtensionRegistry:
"""Registry for managing A2A extensions.
Maintains a collection of extensions and provides methods to invoke
their hooks at various integration points.
"""
def __init__(self) -> None:
"""Initialize the extension registry."""
self._extensions: list[A2AExtension] = []
def register(self, extension: A2AExtension) -> None:
"""Register an extension.
Args:
extension: The extension to register.
"""
self._extensions.append(extension)
def inject_all_tools(self, agent: Agent) -> None:
"""Inject tools from all registered extensions.
Args:
agent: The agent instance to inject tools into.
"""
for extension in self._extensions:
extension.inject_tools(agent)
def extract_all_states(
self, conversation_history: Sequence[Message]
) -> dict[type[A2AExtension], ConversationState]:
"""Extract conversation states from all registered extensions.
Args:
conversation_history: The sequence of A2A messages exchanged.
Returns:
Mapping of extension types to their conversation states.
"""
states: dict[type[A2AExtension], ConversationState] = {}
for extension in self._extensions:
state = extension.extract_state_from_history(conversation_history)
if state is not None:
states[type(extension)] = state
return states
def augment_prompt_with_all(
self,
base_prompt: str,
extension_states: dict[type[A2AExtension], ConversationState],
) -> str:
"""Augment prompt with instructions from all registered extensions.
Args:
base_prompt: The base prompt to augment.
extension_states: Mapping of extension types to conversation states.
Returns:
The fully augmented prompt.
"""
augmented = base_prompt
for extension in self._extensions:
state = extension_states.get(type(extension))
augmented = extension.augment_prompt(augmented, state)
return augmented
def process_response_with_all(
self,
agent_response: Any,
extension_states: dict[type[A2AExtension], ConversationState],
) -> Any:
"""Process response through all registered extensions.
Args:
agent_response: The parsed agent response.
extension_states: Mapping of extension types to conversation states.
Returns:
The processed agent response.
"""
processed = agent_response
for extension in self._extensions:
state = extension_states.get(type(extension))
processed = extension.process_response(processed, state)
return processed

View File

@@ -0,0 +1,170 @@
"""A2A Protocol extension utilities.
This module provides utilities for working with A2A protocol extensions as
defined in the A2A specification. Extensions are capability declarations in
AgentCard.capabilities.extensions using AgentExtension objects, activated
via the X-A2A-Extensions HTTP header.
See: https://a2a-protocol.org/latest/topics/extensions/
"""
from __future__ import annotations
from typing import Any
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
from a2a.extensions.common import (
HTTP_EXTENSION_HEADER,
)
from a2a.types import AgentCard, AgentExtension
from crewai_a2a.config import A2AClientConfig, A2AConfig
from crewai_a2a.extensions.base import ExtensionRegistry
def get_extensions_from_config(
a2a_config: list[A2AConfig | A2AClientConfig] | A2AConfig | A2AClientConfig,
) -> list[str]:
"""Extract extension URIs from A2A configuration.
Args:
a2a_config: A2A configuration (single or list).
Returns:
Deduplicated list of extension URIs from all configs.
"""
configs = a2a_config if isinstance(a2a_config, list) else [a2a_config]
seen: set[str] = set()
result: list[str] = []
for config in configs:
if not isinstance(config, A2AClientConfig):
continue
for uri in config.extensions:
if uri not in seen:
seen.add(uri)
result.append(uri)
return result
class ExtensionsMiddleware(ClientCallInterceptor):
"""Middleware to add X-A2A-Extensions header to requests.
This middleware adds the extensions header to all outgoing requests,
declaring which A2A protocol extensions the client supports.
"""
def __init__(self, extensions: list[str]) -> None:
"""Initialize with extension URIs.
Args:
extensions: List of extension URIs the client supports.
"""
self._extensions = extensions
async def intercept(
self,
method_name: str,
request_payload: dict[str, Any],
http_kwargs: dict[str, Any],
agent_card: AgentCard | None,
context: ClientCallContext | None,
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Add extensions header to the request.
Args:
method_name: The A2A method being called.
request_payload: The JSON-RPC request payload.
http_kwargs: HTTP request kwargs (headers, etc).
agent_card: The target agent's card.
context: Optional call context.
Returns:
Tuple of (request_payload, modified_http_kwargs).
"""
if self._extensions:
headers = http_kwargs.setdefault("headers", {})
headers[HTTP_EXTENSION_HEADER] = ",".join(self._extensions)
return request_payload, http_kwargs
def validate_required_extensions(
agent_card: AgentCard,
client_extensions: list[str] | None,
) -> list[AgentExtension]:
"""Validate that client supports all required extensions from agent.
Args:
agent_card: The agent's card with declared extensions.
client_extensions: Extension URIs the client supports.
Returns:
List of unsupported required extensions.
Raises:
None - returns list of unsupported extensions for caller to handle.
"""
unsupported: list[AgentExtension] = []
client_set = set(client_extensions or [])
if not agent_card.capabilities or not agent_card.capabilities.extensions:
return unsupported
unsupported.extend(
ext
for ext in agent_card.capabilities.extensions
if ext.required and ext.uri not in client_set
)
return unsupported
def create_extension_registry_from_config(
a2a_config: list[A2AConfig | A2AClientConfig] | A2AConfig | A2AClientConfig,
) -> ExtensionRegistry:
"""Create an extension registry from A2A client configuration.
Extracts client_extensions from each A2AClientConfig and registers them
with the ExtensionRegistry. These extensions provide CrewAI-specific
processing hooks (tool injection, prompt augmentation, response processing).
Note: A2A protocol extensions (URI strings sent via X-A2A-Extensions header)
are handled separately via get_extensions_from_config() and ExtensionsMiddleware.
Args:
a2a_config: A2A configuration (single or list).
Returns:
Extension registry with all client_extensions registered.
Example:
class LoggingExtension:
def inject_tools(self, agent): pass
def extract_state_from_history(self, history): return None
def augment_prompt(self, prompt, state): return prompt
def process_response(self, response, state):
print(f"Response: {response}")
return response
config = A2AClientConfig(
endpoint="https://agent.example.com",
client_extensions=[LoggingExtension()],
)
registry = create_extension_registry_from_config(config)
"""
registry = ExtensionRegistry()
configs = a2a_config if isinstance(a2a_config, list) else [a2a_config]
seen: set[int] = set()
for config in configs:
if isinstance(config, (A2AConfig, A2AClientConfig)):
client_exts = getattr(config, "client_extensions", [])
for extension in client_exts:
ext_id = id(extension)
if ext_id not in seen:
seen.add(ext_id)
registry.register(extension)
return registry

View File

@@ -0,0 +1,305 @@
"""A2A protocol server extensions for CrewAI agents.
This module provides the base class and context for implementing A2A protocol
extensions on the server side. Extensions allow agents to offer additional
functionality beyond the core A2A specification.
See: https://a2a-protocol.org/latest/topics/extensions/
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
import logging
from typing import TYPE_CHECKING, Annotated, Any
from a2a.types import AgentExtension
from pydantic_core import CoreSchema, core_schema
if TYPE_CHECKING:
from a2a.server.context import ServerCallContext
from pydantic import GetCoreSchemaHandler
logger = logging.getLogger(__name__)
@dataclass
class ExtensionContext:
"""Context passed to extension hooks during request processing.
Provides access to request metadata, client extensions, and shared state
that extensions can read from and write to.
Attributes:
metadata: Request metadata dict, includes extension-namespaced keys.
client_extensions: Set of extension URIs the client declared support for.
state: Mutable dict for extensions to share data during request lifecycle.
server_context: The underlying A2A server call context.
"""
metadata: dict[str, Any]
client_extensions: set[str]
state: dict[str, Any] = field(default_factory=dict)
server_context: ServerCallContext | None = None
def get_extension_metadata(self, uri: str, key: str) -> Any | None:
"""Get extension-specific metadata value.
Extension metadata uses namespaced keys in the format:
"{extension_uri}/{key}"
Args:
uri: The extension URI.
key: The metadata key within the extension namespace.
Returns:
The metadata value, or None if not present.
"""
full_key = f"{uri}/{key}"
return self.metadata.get(full_key)
def set_extension_metadata(self, uri: str, key: str, value: Any) -> None:
"""Set extension-specific metadata value.
Args:
uri: The extension URI.
key: The metadata key within the extension namespace.
value: The value to set.
"""
full_key = f"{uri}/{key}"
self.metadata[full_key] = value
class ServerExtension(ABC):
"""Base class for A2A protocol server extensions.
Subclass this to create custom extensions that modify agent behavior
when clients activate them. Extensions are identified by URI and can
be marked as required.
Example:
class SamplingExtension(ServerExtension):
uri = "urn:crewai:ext:sampling/v1"
required = True
def __init__(self, max_tokens: int = 4096):
self.max_tokens = max_tokens
@property
def params(self) -> dict[str, Any]:
return {"max_tokens": self.max_tokens}
async def on_request(self, context: ExtensionContext) -> None:
limit = context.get_extension_metadata(self.uri, "limit")
if limit:
context.state["token_limit"] = int(limit)
async def on_response(self, context: ExtensionContext, result: Any) -> Any:
return result
"""
uri: Annotated[str, "Extension URI identifier. Must be unique."]
required: Annotated[bool, "Whether clients must support this extension."] = False
description: Annotated[
str | None, "Human-readable description of the extension."
] = None
@classmethod
def __get_pydantic_core_schema__(
cls,
_source_type: Any,
_handler: GetCoreSchemaHandler,
) -> CoreSchema:
"""Tell Pydantic how to validate ServerExtension instances."""
return core_schema.is_instance_schema(cls)
@property
def params(self) -> dict[str, Any] | None:
"""Extension parameters to advertise in AgentCard.
Override this property to expose configuration that clients can read.
Returns:
Dict of parameter names to values, or None.
"""
return None
def agent_extension(self) -> AgentExtension:
"""Generate the AgentExtension object for the AgentCard.
Returns:
AgentExtension with this extension's URI, required flag, and params.
"""
return AgentExtension(
uri=self.uri,
required=self.required if self.required else None,
description=self.description,
params=self.params,
)
def is_active(self, context: ExtensionContext) -> bool:
"""Check if this extension is active for the current request.
An extension is active if the client declared support for it.
Args:
context: The extension context for the current request.
Returns:
True if the client supports this extension.
"""
return self.uri in context.client_extensions
@abstractmethod
async def on_request(self, context: ExtensionContext) -> None:
"""Called before agent execution if extension is active.
Use this hook to:
- Read extension-specific metadata from the request
- Set up state for the execution
- Modify execution parameters via context.state
Args:
context: The extension context with request metadata and state.
"""
...
@abstractmethod
async def on_response(self, context: ExtensionContext, result: Any) -> Any:
"""Called after agent execution if extension is active.
Use this hook to:
- Modify or enhance the result
- Add extension-specific metadata to the response
- Clean up any resources
Args:
context: The extension context with request metadata and state.
result: The agent execution result.
Returns:
The result, potentially modified.
"""
...
class ServerExtensionRegistry:
"""Registry for managing server-side A2A protocol extensions.
Collects extensions and provides methods to generate AgentCapabilities
and invoke extension hooks during request processing.
"""
def __init__(self, extensions: list[ServerExtension] | None = None) -> None:
"""Initialize the registry with optional extensions.
Args:
extensions: Initial list of extensions to register.
"""
self._extensions: list[ServerExtension] = list(extensions) if extensions else []
self._by_uri: dict[str, ServerExtension] = {
ext.uri: ext for ext in self._extensions
}
def register(self, extension: ServerExtension) -> None:
"""Register an extension.
Args:
extension: The extension to register.
Raises:
ValueError: If an extension with the same URI is already registered.
"""
if extension.uri in self._by_uri:
raise ValueError(f"Extension already registered: {extension.uri}")
self._extensions.append(extension)
self._by_uri[extension.uri] = extension
def get_agent_extensions(self) -> list[AgentExtension]:
"""Get AgentExtension objects for all registered extensions.
Returns:
List of AgentExtension objects for the AgentCard.
"""
return [ext.agent_extension() for ext in self._extensions]
def get_extension(self, uri: str) -> ServerExtension | None:
"""Get an extension by URI.
Args:
uri: The extension URI.
Returns:
The extension, or None if not found.
"""
return self._by_uri.get(uri)
@staticmethod
def create_context(
metadata: dict[str, Any],
client_extensions: set[str],
server_context: ServerCallContext | None = None,
) -> ExtensionContext:
"""Create an ExtensionContext for a request.
Args:
metadata: Request metadata dict.
client_extensions: Set of extension URIs from client.
server_context: Optional server call context.
Returns:
ExtensionContext for use in hooks.
"""
return ExtensionContext(
metadata=metadata,
client_extensions=client_extensions,
server_context=server_context,
)
async def invoke_on_request(self, context: ExtensionContext) -> None:
"""Invoke on_request hooks for all active extensions.
Tracks activated extensions and isolates errors from individual hooks.
Args:
context: The extension context for the request.
"""
for extension in self._extensions:
if extension.is_active(context):
try:
await extension.on_request(context)
if context.server_context is not None:
context.server_context.activated_extensions.add(extension.uri)
except Exception:
logger.exception(
"Extension on_request hook failed",
extra={"extension": extension.uri},
)
async def invoke_on_response(self, context: ExtensionContext, result: Any) -> Any:
"""Invoke on_response hooks for all active extensions.
Isolates errors from individual hooks to prevent one failing extension
from breaking the entire response.
Args:
context: The extension context for the request.
result: The agent execution result.
Returns:
The result after all extensions have processed it.
"""
processed = result
for extension in self._extensions:
if extension.is_active(context):
try:
processed = await extension.on_response(context, processed)
except Exception:
logger.exception(
"Extension on_response hook failed",
extra={"extension": extension.uri},
)
return processed

View File

View File

@@ -0,0 +1,479 @@
"""Helper functions for processing A2A task results."""
from __future__ import annotations
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING, Any, TypedDict
import uuid
from a2a.client.errors import A2AClientHTTPError
from a2a.types import (
AgentCard,
Message,
Part,
Role,
Task,
TaskArtifactUpdateEvent,
TaskState,
TaskStatusUpdateEvent,
TextPart,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AConnectionErrorEvent,
A2AResponseReceivedEvent,
)
from typing_extensions import NotRequired
if TYPE_CHECKING:
from a2a.types import Task as A2ATask
SendMessageEvent = (
tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | Message
)
TERMINAL_STATES: frozenset[TaskState] = frozenset(
{
TaskState.completed,
TaskState.failed,
TaskState.rejected,
TaskState.canceled,
}
)
ACTIONABLE_STATES: frozenset[TaskState] = frozenset(
{
TaskState.input_required,
TaskState.auth_required,
}
)
PENDING_STATES: frozenset[TaskState] = frozenset(
{
TaskState.submitted,
TaskState.working,
}
)
class TaskStateResult(TypedDict):
"""Result dictionary from processing A2A task state."""
status: TaskState
history: list[Message]
result: NotRequired[str]
error: NotRequired[str]
agent_card: NotRequired[dict[str, Any]]
a2a_agent_name: NotRequired[str | None]
def extract_task_result_parts(a2a_task: A2ATask) -> list[str]:
"""Extract result parts from A2A task status message, history, and artifacts.
Args:
a2a_task: A2A Task object with status, history, and artifacts
Returns:
List of result text parts
"""
result_parts: list[str] = []
if a2a_task.status and a2a_task.status.message:
msg = a2a_task.status.message
result_parts.extend(
part.root.text for part in msg.parts if part.root.kind == "text"
)
if not result_parts and a2a_task.history:
for history_msg in reversed(a2a_task.history):
if history_msg.role == Role.agent:
result_parts.extend(
part.root.text
for part in history_msg.parts
if part.root.kind == "text"
)
break
if a2a_task.artifacts:
result_parts.extend(
part.root.text
for artifact in a2a_task.artifacts
for part in artifact.parts
if part.root.kind == "text"
)
return result_parts
def extract_error_message(a2a_task: A2ATask, default: str) -> str:
"""Extract error message from A2A task.
Args:
a2a_task: A2A Task object
default: Default message if no error found
Returns:
Error message string
"""
if a2a_task.status and a2a_task.status.message:
msg = a2a_task.status.message
if msg:
for part in msg.parts:
if part.root.kind == "text":
return str(part.root.text)
return str(msg)
if a2a_task.history:
for history_msg in reversed(a2a_task.history):
for part in history_msg.parts:
if part.root.kind == "text":
return str(part.root.text)
return default
def process_task_state(
a2a_task: A2ATask,
new_messages: list[Message],
agent_card: AgentCard,
turn_number: int,
is_multiturn: bool,
agent_role: str | None,
result_parts: list[str] | None = None,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
is_final: bool = True,
) -> TaskStateResult | None:
"""Process A2A task state and return result dictionary.
Shared logic for both polling and streaming handlers.
Args:
a2a_task: The A2A task to process.
new_messages: List to collect messages (modified in place).
agent_card: The agent card.
turn_number: Current turn number.
is_multiturn: Whether multi-turn conversation.
agent_role: Agent role for logging.
result_parts: Accumulated result parts (streaming passes accumulated,
polling passes None to extract from task).
endpoint: A2A agent endpoint URL.
a2a_agent_name: Name of the A2A agent from agent card.
from_task: Optional CrewAI Task for event metadata.
from_agent: Optional CrewAI Agent for event metadata.
is_final: Whether this is the final response in the stream.
Returns:
Result dictionary if terminal/actionable state, None otherwise.
"""
if result_parts is None:
result_parts = []
if a2a_task.status.state == TaskState.completed:
if not result_parts:
extracted_parts = extract_task_result_parts(a2a_task)
result_parts.extend(extracted_parts)
if a2a_task.history:
new_messages.extend(a2a_task.history)
response_text = " ".join(result_parts) if result_parts else ""
message_id = None
if a2a_task.status and a2a_task.status.message:
message_id = a2a_task.status.message.message_id
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=response_text,
turn_number=turn_number,
context_id=a2a_task.context_id,
message_id=message_id,
is_multiturn=is_multiturn,
status="completed",
final=is_final,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.completed,
agent_card=agent_card.model_dump(exclude_none=True),
result=response_text,
history=new_messages,
)
if a2a_task.status.state == TaskState.input_required:
if a2a_task.history:
new_messages.extend(a2a_task.history)
response_text = extract_error_message(a2a_task, "Additional input required")
if response_text and not a2a_task.history:
agent_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=response_text))],
context_id=a2a_task.context_id,
task_id=a2a_task.id,
)
new_messages.append(agent_message)
input_message_id = None
if a2a_task.status and a2a_task.status.message:
input_message_id = a2a_task.status.message.message_id
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=response_text,
turn_number=turn_number,
context_id=a2a_task.context_id,
message_id=input_message_id,
is_multiturn=is_multiturn,
status="input_required",
final=is_final,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.input_required,
error=response_text,
history=new_messages,
agent_card=agent_card.model_dump(exclude_none=True),
)
if a2a_task.status.state in {TaskState.failed, TaskState.rejected}:
error_msg = extract_error_message(a2a_task, "Task failed without error message")
if a2a_task.history:
new_messages.extend(a2a_task.history)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
if a2a_task.status.state == TaskState.auth_required:
error_msg = extract_error_message(a2a_task, "Authentication required")
return TaskStateResult(
status=TaskState.auth_required,
error=error_msg,
history=new_messages,
)
if a2a_task.status.state == TaskState.canceled:
error_msg = extract_error_message(a2a_task, "Task was canceled")
return TaskStateResult(
status=TaskState.canceled,
error=error_msg,
history=new_messages,
)
if a2a_task.status.state in PENDING_STATES:
return None
return None
async def send_message_and_get_task_id(
event_stream: AsyncIterator[SendMessageEvent],
new_messages: list[Message],
agent_card: AgentCard,
turn_number: int,
is_multiturn: bool,
agent_role: str | None,
from_task: Any | None = None,
from_agent: Any | None = None,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
context_id: str | None = None,
) -> str | TaskStateResult:
"""Send message and process initial response.
Handles the common pattern of sending a message and either:
- Getting an immediate Message response (task completed synchronously)
- Getting a Task that needs polling/waiting for completion
Args:
event_stream: Async iterator from client.send_message()
new_messages: List to collect messages (modified in place)
agent_card: The agent card
turn_number: Current turn number
is_multiturn: Whether multi-turn conversation
agent_role: Agent role for logging
from_task: Optional CrewAI Task object for event metadata.
from_agent: Optional CrewAI Agent object for event metadata.
endpoint: Optional A2A endpoint URL.
a2a_agent_name: Optional A2A agent name.
context_id: Optional A2A context ID for correlation.
Returns:
Task ID string if agent needs polling/waiting, or TaskStateResult if done.
"""
try:
async for event in event_stream:
if isinstance(event, Message):
new_messages.append(event)
result_parts = [
part.root.text for part in event.parts if part.root.kind == "text"
]
response_text = " ".join(result_parts) if result_parts else ""
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=response_text,
turn_number=turn_number,
context_id=event.context_id,
message_id=event.message_id,
is_multiturn=is_multiturn,
status="completed",
final=True,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.completed,
result=response_text,
history=new_messages,
agent_card=agent_card.model_dump(exclude_none=True),
)
if isinstance(event, tuple):
a2a_task, _ = event
if a2a_task.status.state in TERMINAL_STATES | ACTIONABLE_STATES:
result = process_task_state(
a2a_task=a2a_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
)
if result:
return result
return a2a_task.id
return TaskStateResult(
status=TaskState.failed,
error="No task ID received from initial message",
history=new_messages,
)
except A2AClientHTTPError as e:
error_msg = f"HTTP Error {e.status_code}: {e!s}"
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
None,
A2AConnectionErrorEvent(
endpoint=endpoint or "",
error=str(e),
error_type="http_error",
status_code=e.status_code,
a2a_agent_name=a2a_agent_name,
operation="send_message",
context_id=context_id,
from_task=from_task,
from_agent=from_agent,
),
)
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
context_id=context_id,
is_multiturn=is_multiturn,
status="failed",
final=True,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
except Exception as e:
error_msg = f"Unexpected error during send_message: {e!s}"
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
None,
A2AConnectionErrorEvent(
endpoint=endpoint or "",
error=str(e),
error_type="unexpected_error",
a2a_agent_name=a2a_agent_name,
operation="send_message",
context_id=context_id,
from_task=from_task,
from_agent=from_agent,
),
)
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
context_id=context_id,
is_multiturn=is_multiturn,
status="failed",
final=True,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
finally:
aclose = getattr(event_stream, "aclose", None)
if aclose:
await aclose()

View File

@@ -0,0 +1,55 @@
"""String templates for A2A (Agent-to-Agent) protocol messaging and status."""
from string import Template
from typing import Final
AVAILABLE_AGENTS_TEMPLATE: Final[Template] = Template(
"\n<AVAILABLE_A2A_AGENTS>\n $available_a2a_agents\n</AVAILABLE_A2A_AGENTS>\n"
)
PREVIOUS_A2A_CONVERSATION_TEMPLATE: Final[Template] = Template(
"\n<PREVIOUS_A2A_CONVERSATION>\n"
" $previous_a2a_conversation"
"\n</PREVIOUS_A2A_CONVERSATION>\n"
)
CONVERSATION_TURN_INFO_TEMPLATE: Final[Template] = Template(
"\n<CONVERSATION_PROGRESS>\n"
' turn="$turn_count"\n'
' max_turns="$max_turns"\n'
" $warning"
"\n</CONVERSATION_PROGRESS>\n"
)
UNAVAILABLE_AGENTS_NOTICE_TEMPLATE: Final[Template] = Template(
"\n<A2A_AGENTS_STATUS>\n"
" NOTE: A2A agents were configured but are currently unavailable.\n"
" You cannot delegate to remote agents for this task.\n\n"
" Unavailable Agents:\n"
" $unavailable_agents"
"\n</A2A_AGENTS_STATUS>\n"
)
REMOTE_AGENT_COMPLETED_NOTICE: Final[str] = """
<REMOTE_AGENT_STATUS>
STATUS: COMPLETED
The remote agent has finished processing your request. Their response is in the conversation history above.
You MUST now:
1. Extract the answer from the conversation history
2. Set is_a2a=false
3. Return the answer as your final message
DO NOT send another request - the task is already done.
</REMOTE_AGENT_STATUS>
"""
REMOTE_AGENT_RESPONSE_NOTICE: Final[str] = """
<REMOTE_AGENT_STATUS>
STATUS: RESPONSE_RECEIVED
The remote agent has responded. Their response is in the conversation history above.
You MUST now:
1. Set is_a2a=false (the remote task is complete and cannot receive more messages)
2. Provide YOUR OWN response to the original task based on the information received
IMPORTANT: Your response should be addressed to the USER who gave you the original task.
Report what the remote agent told you in THIRD PERSON (e.g., "The remote agent said..." or "I learned that...").
Do NOT address the remote agent directly or use "you" to refer to them.
</REMOTE_AGENT_STATUS>
"""

View File

@@ -0,0 +1,104 @@
"""Type definitions for A2A protocol message parts."""
from __future__ import annotations
from typing import (
Annotated,
Any,
Literal,
Protocol,
TypedDict,
runtime_checkable,
)
from pydantic import BeforeValidator, HttpUrl, TypeAdapter
from typing_extensions import NotRequired
try:
from crewai_a2a.updates import (
PollingConfig,
PollingHandler,
PushNotificationConfig,
PushNotificationHandler,
StreamingConfig,
StreamingHandler,
UpdateConfig,
)
except ImportError:
PollingConfig = Any # type: ignore[misc,assignment]
PollingHandler = Any # type: ignore[misc,assignment]
PushNotificationConfig = Any # type: ignore[misc,assignment]
PushNotificationHandler = Any # type: ignore[misc,assignment]
StreamingConfig = Any # type: ignore[misc,assignment]
StreamingHandler = Any # type: ignore[misc,assignment]
UpdateConfig = Any # type: ignore[misc,assignment]
TransportType = Literal["JSONRPC", "GRPC", "HTTP+JSON"]
ProtocolVersion = Literal[
"0.2.0",
"0.2.1",
"0.2.2",
"0.2.3",
"0.2.4",
"0.2.5",
"0.2.6",
"0.3.0",
"0.4.0",
]
http_url_adapter: TypeAdapter[HttpUrl] = TypeAdapter(HttpUrl)
Url = Annotated[
str,
BeforeValidator(
lambda value: str(http_url_adapter.validate_python(value, strict=True))
),
]
@runtime_checkable
class AgentResponseProtocol(Protocol):
"""Protocol for the dynamically created AgentResponse model."""
a2a_ids: tuple[str, ...]
message: str
is_a2a: bool
class PartsMetadataDict(TypedDict, total=False):
"""Metadata for A2A message parts.
Attributes:
mimeType: MIME type for the part content.
schema: JSON schema for the part content.
"""
mimeType: Literal["application/json"]
schema: dict[str, Any]
class PartsDict(TypedDict):
"""A2A message part containing text and optional metadata.
Attributes:
text: The text content of the message part.
metadata: Optional metadata describing the part content.
"""
text: str
metadata: NotRequired[PartsMetadataDict]
PollingHandlerType = type[PollingHandler]
StreamingHandlerType = type[StreamingHandler]
PushNotificationHandlerType = type[PushNotificationHandler]
HandlerType = PollingHandlerType | StreamingHandlerType | PushNotificationHandlerType
HANDLER_REGISTRY: dict[type[UpdateConfig], HandlerType] = {
PollingConfig: PollingHandler,
StreamingConfig: StreamingHandler,
PushNotificationConfig: PushNotificationHandler,
}

View File

@@ -0,0 +1,35 @@
"""A2A update mechanism configuration types."""
from crewai_a2a.updates.base import (
BaseHandlerKwargs,
PollingHandlerKwargs,
PushNotificationHandlerKwargs,
PushNotificationResultStore,
StreamingHandlerKwargs,
UpdateHandler,
)
from crewai_a2a.updates.polling.config import PollingConfig
from crewai_a2a.updates.polling.handler import PollingHandler
from crewai_a2a.updates.push_notifications.config import PushNotificationConfig
from crewai_a2a.updates.push_notifications.handler import PushNotificationHandler
from crewai_a2a.updates.streaming.config import StreamingConfig
from crewai_a2a.updates.streaming.handler import StreamingHandler
UpdateConfig = PollingConfig | StreamingConfig | PushNotificationConfig
__all__ = [
"BaseHandlerKwargs",
"PollingConfig",
"PollingHandler",
"PollingHandlerKwargs",
"PushNotificationConfig",
"PushNotificationHandler",
"PushNotificationHandlerKwargs",
"PushNotificationResultStore",
"StreamingConfig",
"StreamingHandler",
"StreamingHandlerKwargs",
"UpdateConfig",
"UpdateHandler",
]

View File

@@ -0,0 +1,176 @@
"""Base types for A2A update mechanism handlers."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, TypedDict
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
class CommonParams(NamedTuple):
"""Common parameters shared across all update handlers.
Groups the frequently-passed parameters to reduce duplication.
"""
turn_number: int
is_multiturn: bool
agent_role: str | None
endpoint: str
a2a_agent_name: str | None
context_id: str | None
from_task: Any
from_agent: Any
if TYPE_CHECKING:
from a2a.client import Client
from a2a.types import AgentCard, Message, Task
from crewai_a2a.task_helpers import TaskStateResult
from crewai_a2a.updates.push_notifications.config import PushNotificationConfig
class BaseHandlerKwargs(TypedDict, total=False):
"""Base kwargs shared by all handlers."""
turn_number: int
is_multiturn: bool
agent_role: str | None
context_id: str | None
task_id: str | None
endpoint: str | None
agent_branch: Any
a2a_agent_name: str | None
from_task: Any
from_agent: Any
class PollingHandlerKwargs(BaseHandlerKwargs, total=False):
"""Kwargs for polling handler."""
polling_interval: float
polling_timeout: float
history_length: int
max_polls: int | None
class StreamingHandlerKwargs(BaseHandlerKwargs, total=False):
"""Kwargs for streaming handler."""
class PushNotificationHandlerKwargs(BaseHandlerKwargs, total=False):
"""Kwargs for push notification handler."""
config: PushNotificationConfig
result_store: PushNotificationResultStore
polling_timeout: float
polling_interval: float
class PushNotificationResultStore(Protocol):
"""Protocol for storing and retrieving push notification results.
This protocol defines the interface for a result store that the
PushNotificationHandler uses to wait for task completion.
"""
@classmethod
def __get_pydantic_core_schema__(
cls,
_source_type: Any,
_handler: GetCoreSchemaHandler,
) -> CoreSchema:
return core_schema.any_schema()
async def wait_for_result(
self,
task_id: str,
timeout: float,
poll_interval: float = 1.0,
) -> Task | None:
"""Wait for a task result to be available.
Args:
task_id: The task ID to wait for.
timeout: Max seconds to wait before returning None.
poll_interval: Seconds between polling attempts.
Returns:
The completed Task object, or None if timeout.
"""
...
async def get_result(self, task_id: str) -> Task | None:
"""Get a task result if available.
Args:
task_id: The task ID to retrieve.
Returns:
The Task object if available, None otherwise.
"""
...
async def store_result(self, task: Task) -> None:
"""Store a task result.
Args:
task: The Task object to store.
"""
...
class UpdateHandler(Protocol):
"""Protocol for A2A update mechanism handlers."""
@staticmethod
async def execute(
client: Client,
message: Message,
new_messages: list[Message],
agent_card: AgentCard,
**kwargs: Any,
) -> TaskStateResult:
"""Execute the update mechanism and return result.
Args:
client: A2A client instance.
message: Message to send.
new_messages: List to collect messages (modified in place).
agent_card: The agent card.
**kwargs: Additional handler-specific parameters.
Returns:
Result dictionary with status, result/error, and history.
"""
...
def extract_common_params(kwargs: BaseHandlerKwargs) -> CommonParams:
"""Extract common parameters from handler kwargs.
Args:
kwargs: Handler kwargs dict.
Returns:
CommonParams with extracted values.
Raises:
ValueError: If endpoint is not provided.
"""
endpoint = kwargs.get("endpoint")
if endpoint is None:
raise ValueError("endpoint is required for update handlers")
return CommonParams(
turn_number=kwargs.get("turn_number", 0),
is_multiturn=kwargs.get("is_multiturn", False),
agent_role=kwargs.get("agent_role"),
endpoint=endpoint,
a2a_agent_name=kwargs.get("a2a_agent_name"),
context_id=kwargs.get("context_id"),
from_task=kwargs.get("from_task"),
from_agent=kwargs.get("from_agent"),
)

View File

@@ -0,0 +1 @@
"""Polling update mechanism module."""

View File

@@ -0,0 +1,25 @@
"""Polling update mechanism configuration."""
from __future__ import annotations
from pydantic import BaseModel, Field
class PollingConfig(BaseModel):
"""Configuration for polling-based task updates.
Attributes:
interval: Seconds between poll attempts.
timeout: Max seconds to poll before raising timeout error.
max_polls: Max number of poll attempts.
history_length: Number of messages to retrieve per poll.
"""
interval: float = Field(
default=2.0, gt=0, description="Seconds between poll attempts"
)
timeout: float | None = Field(default=None, gt=0, description="Max seconds to poll")
max_polls: int | None = Field(default=None, gt=0, description="Max poll attempts")
history_length: int = Field(
default=100, gt=0, description="Messages to retrieve per poll"
)

View File

@@ -0,0 +1,359 @@
"""Polling update mechanism handler."""
from __future__ import annotations
import asyncio
import time
from typing import TYPE_CHECKING, Any
import uuid
from a2a.client import Client
from a2a.client.errors import A2AClientHTTPError
from a2a.types import (
AgentCard,
Message,
Part,
Role,
TaskQueryParams,
TaskState,
TextPart,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AConnectionErrorEvent,
A2APollingStartedEvent,
A2APollingStatusEvent,
A2AResponseReceivedEvent,
)
from typing_extensions import Unpack
from crewai_a2a.errors import A2APollingTimeoutError
from crewai_a2a.task_helpers import (
ACTIONABLE_STATES,
TERMINAL_STATES,
TaskStateResult,
process_task_state,
send_message_and_get_task_id,
)
from crewai_a2a.updates.base import PollingHandlerKwargs
if TYPE_CHECKING:
from a2a.types import Task as A2ATask
async def _poll_task_until_complete(
client: Client,
task_id: str,
polling_interval: float,
polling_timeout: float,
agent_branch: Any | None = None,
history_length: int = 100,
max_polls: int | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
context_id: str | None = None,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
) -> A2ATask:
"""Poll task status until terminal state reached.
Args:
client: A2A client instance.
task_id: Task ID to poll.
polling_interval: Seconds between poll attempts.
polling_timeout: Max seconds before timeout.
agent_branch: Agent tree branch for logging.
history_length: Number of messages to retrieve per poll.
max_polls: Max number of poll attempts (None = unlimited).
from_task: Optional CrewAI Task object for event metadata.
from_agent: Optional CrewAI Agent object for event metadata.
context_id: A2A context ID for correlation.
endpoint: A2A agent endpoint URL.
a2a_agent_name: Name of the A2A agent from agent card.
Returns:
Final task object in terminal state.
Raises:
A2APollingTimeoutError: If polling exceeds timeout or max_polls.
"""
start_time = time.monotonic()
poll_count = 0
while True:
poll_count += 1
task = await client.get_task(
TaskQueryParams(id=task_id, history_length=history_length)
)
elapsed = time.monotonic() - start_time
effective_context_id = task.context_id or context_id
crewai_event_bus.emit(
agent_branch,
A2APollingStatusEvent(
task_id=task_id,
context_id=effective_context_id,
state=str(task.status.state.value),
elapsed_seconds=elapsed,
poll_count=poll_count,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
if task.status.state in TERMINAL_STATES | ACTIONABLE_STATES:
return task
if elapsed > polling_timeout:
raise A2APollingTimeoutError(
f"Polling timeout after {polling_timeout}s ({poll_count} polls)"
)
if max_polls and poll_count >= max_polls:
raise A2APollingTimeoutError(
f"Max polls ({max_polls}) exceeded after {elapsed:.1f}s"
)
await asyncio.sleep(polling_interval)
class PollingHandler:
"""Polling-based update handler."""
@staticmethod
async def execute(
client: Client,
message: Message,
new_messages: list[Message],
agent_card: AgentCard,
**kwargs: Unpack[PollingHandlerKwargs],
) -> TaskStateResult:
"""Execute A2A delegation using polling for updates.
Args:
client: A2A client instance.
message: Message to send.
new_messages: List to collect messages.
agent_card: The agent card.
**kwargs: Polling-specific parameters.
Returns:
Dictionary with status, result/error, and history.
"""
polling_interval = kwargs.get("polling_interval", 2.0)
polling_timeout = kwargs.get("polling_timeout", 300.0)
endpoint = kwargs.get("endpoint", "")
agent_branch = kwargs.get("agent_branch")
turn_number = kwargs.get("turn_number", 0)
is_multiturn = kwargs.get("is_multiturn", False)
agent_role = kwargs.get("agent_role")
history_length = kwargs.get("history_length", 100)
max_polls = kwargs.get("max_polls")
context_id = kwargs.get("context_id")
task_id = kwargs.get("task_id")
a2a_agent_name = kwargs.get("a2a_agent_name")
from_task = kwargs.get("from_task")
from_agent = kwargs.get("from_agent")
try:
result_or_task_id = await send_message_and_get_task_id(
event_stream=client.send_message(message),
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
from_task=from_task,
from_agent=from_agent,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
context_id=context_id,
)
if not isinstance(result_or_task_id, str):
return result_or_task_id
task_id = result_or_task_id
crewai_event_bus.emit(
agent_branch,
A2APollingStartedEvent(
task_id=task_id,
context_id=context_id,
polling_interval=polling_interval,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
final_task = await _poll_task_until_complete(
client=client,
task_id=task_id,
polling_interval=polling_interval,
polling_timeout=polling_timeout,
agent_branch=agent_branch,
history_length=history_length,
max_polls=max_polls,
from_task=from_task,
from_agent=from_agent,
context_id=context_id,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
)
result = process_task_state(
a2a_task=final_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
)
if result:
return result
return TaskStateResult(
status=TaskState.failed,
error=f"Unexpected task state: {final_task.status.state}",
history=new_messages,
)
except A2APollingTimeoutError as e:
error_msg = str(e)
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
context_id=context_id,
is_multiturn=is_multiturn,
status="failed",
final=True,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
except A2AClientHTTPError as e:
error_msg = f"HTTP Error {e.status_code}: {e!s}"
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=endpoint,
error=str(e),
error_type="http_error",
status_code=e.status_code,
a2a_agent_name=a2a_agent_name,
operation="polling",
context_id=context_id,
task_id=task_id,
from_task=from_task,
from_agent=from_agent,
),
)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
context_id=context_id,
is_multiturn=is_multiturn,
status="failed",
final=True,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
except Exception as e:
error_msg = f"Unexpected error during polling: {e!s}"
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=endpoint,
error=str(e),
error_type="unexpected_error",
a2a_agent_name=a2a_agent_name,
operation="polling",
context_id=context_id,
task_id=task_id,
from_task=from_task,
from_agent=from_agent,
),
)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
context_id=context_id,
is_multiturn=is_multiturn,
status="failed",
final=True,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)

View File

@@ -0,0 +1 @@
"""Push notification update mechanism module."""

View File

@@ -0,0 +1,65 @@
"""Push notification update mechanism configuration."""
from __future__ import annotations
from typing import Annotated
from a2a.types import PushNotificationAuthenticationInfo
from pydantic import AnyHttpUrl, BaseModel, BeforeValidator, Field
from crewai_a2a.updates.base import PushNotificationResultStore
from crewai_a2a.updates.push_notifications.signature import WebhookSignatureConfig
def _coerce_signature(
value: str | WebhookSignatureConfig | None,
) -> WebhookSignatureConfig | None:
"""Convert string secret to WebhookSignatureConfig."""
if value is None:
return None
if isinstance(value, str):
return WebhookSignatureConfig.hmac_sha256(secret=value)
return value
SignatureInput = Annotated[
WebhookSignatureConfig | None,
BeforeValidator(_coerce_signature),
]
class PushNotificationConfig(BaseModel):
"""Configuration for webhook-based task updates.
Attributes:
url: Callback URL where agent sends push notifications.
id: Unique identifier for this config.
token: Token to validate incoming notifications.
authentication: Auth info for agent to use when calling webhook.
timeout: Max seconds to wait for task completion.
interval: Seconds between result polling attempts.
result_store: Store for receiving push notification results.
signature: HMAC signature config. Pass a string (secret) for defaults,
or WebhookSignatureConfig for custom settings.
"""
url: AnyHttpUrl = Field(description="Callback URL for push notifications")
id: str | None = Field(default=None, description="Unique config identifier")
token: str | None = Field(default=None, description="Validation token")
authentication: PushNotificationAuthenticationInfo | None = Field(
default=None, description="Auth info for agent to use when calling webhook"
)
timeout: float | None = Field(
default=300.0, gt=0, description="Max seconds to wait for task completion"
)
interval: float = Field(
default=2.0, gt=0, description="Seconds between result polling attempts"
)
result_store: PushNotificationResultStore | None = Field(
default=None, description="Result store for push notification handling"
)
signature: SignatureInput = Field(
default=None,
description="HMAC signature config. Pass a string (secret) for simple usage, "
"or WebhookSignatureConfig for custom headers/tolerance.",
)

View File

@@ -0,0 +1,354 @@
"""Push notification (webhook) update mechanism handler."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
import uuid
from a2a.client import Client
from a2a.client.errors import A2AClientHTTPError
from a2a.types import (
AgentCard,
Message,
Part,
Role,
TaskState,
TextPart,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AConnectionErrorEvent,
A2APushNotificationRegisteredEvent,
A2APushNotificationTimeoutEvent,
A2AResponseReceivedEvent,
)
from typing_extensions import Unpack
from crewai_a2a.task_helpers import (
TaskStateResult,
process_task_state,
send_message_and_get_task_id,
)
from crewai_a2a.updates.base import (
CommonParams,
PushNotificationHandlerKwargs,
PushNotificationResultStore,
extract_common_params,
)
if TYPE_CHECKING:
from a2a.types import Task as A2ATask
logger = logging.getLogger(__name__)
def _handle_push_error(
error: Exception,
error_msg: str,
error_type: str,
new_messages: list[Message],
agent_branch: Any | None,
params: CommonParams,
task_id: str | None,
status_code: int | None = None,
) -> TaskStateResult:
"""Handle push notification errors with consistent event emission.
Args:
error: The exception that occurred.
error_msg: Formatted error message for the result.
error_type: Type of error for the event.
new_messages: List to append error message to.
agent_branch: Agent tree branch for events.
params: Common handler parameters.
task_id: A2A task ID.
status_code: HTTP status code if applicable.
Returns:
TaskStateResult with failed status.
"""
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=params.context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=params.endpoint,
error=str(error),
error_type=error_type,
status_code=status_code,
a2a_agent_name=params.a2a_agent_name,
operation="push_notification",
context_id=params.context_id,
task_id=task_id,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=params.turn_number,
context_id=params.context_id,
is_multiturn=params.is_multiturn,
status="failed",
final=True,
agent_role=params.agent_role,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
async def _wait_for_push_result(
task_id: str,
result_store: PushNotificationResultStore,
timeout: float,
poll_interval: float,
agent_branch: Any | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
context_id: str | None = None,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
) -> A2ATask | None:
"""Wait for push notification result.
Args:
task_id: Task ID to wait for.
result_store: Store to retrieve results from.
timeout: Max seconds to wait.
poll_interval: Seconds between polling attempts.
agent_branch: Agent tree branch for logging.
from_task: Optional CrewAI Task object for event metadata.
from_agent: Optional CrewAI Agent object for event metadata.
context_id: A2A context ID for correlation.
endpoint: A2A agent endpoint URL.
a2a_agent_name: Name of the A2A agent.
Returns:
Final task object, or None if timeout.
"""
task = await result_store.wait_for_result(
task_id=task_id,
timeout=timeout,
poll_interval=poll_interval,
)
if task is None:
crewai_event_bus.emit(
agent_branch,
A2APushNotificationTimeoutEvent(
task_id=task_id,
context_id=context_id,
timeout_seconds=timeout,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return task
class PushNotificationHandler:
"""Push notification (webhook) based update handler."""
@staticmethod
async def execute(
client: Client,
message: Message,
new_messages: list[Message],
agent_card: AgentCard,
**kwargs: Unpack[PushNotificationHandlerKwargs],
) -> TaskStateResult:
"""Execute A2A delegation using push notifications for updates.
Args:
client: A2A client instance.
message: Message to send.
new_messages: List to collect messages.
agent_card: The agent card.
**kwargs: Push notification-specific parameters.
Returns:
Dictionary with status, result/error, and history.
Raises:
ValueError: If result_store or config not provided.
"""
config = kwargs.get("config")
result_store = kwargs.get("result_store")
polling_timeout = kwargs.get("polling_timeout", 300.0)
polling_interval = kwargs.get("polling_interval", 2.0)
agent_branch = kwargs.get("agent_branch")
task_id = kwargs.get("task_id")
params = extract_common_params(kwargs)
if config is None:
error_msg = (
"PushNotificationConfig is required for push notification handler"
)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=params.endpoint,
error=error_msg,
error_type="configuration_error",
a2a_agent_name=params.a2a_agent_name,
operation="push_notification",
context_id=params.context_id,
task_id=task_id,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
if result_store is None:
error_msg = (
"PushNotificationResultStore is required for push notification handler"
)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=params.endpoint,
error=error_msg,
error_type="configuration_error",
a2a_agent_name=params.a2a_agent_name,
operation="push_notification",
context_id=params.context_id,
task_id=task_id,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
try:
result_or_task_id = await send_message_and_get_task_id(
event_stream=client.send_message(message),
new_messages=new_messages,
agent_card=agent_card,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
from_task=params.from_task,
from_agent=params.from_agent,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
context_id=params.context_id,
)
if not isinstance(result_or_task_id, str):
return result_or_task_id
task_id = result_or_task_id
crewai_event_bus.emit(
agent_branch,
A2APushNotificationRegisteredEvent(
task_id=task_id,
context_id=params.context_id,
callback_url=str(config.url),
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
logger.debug(
"Push notification callback for task %s configured at %s (via initial request)",
task_id,
config.url,
)
final_task = await _wait_for_push_result(
task_id=task_id,
result_store=result_store,
timeout=polling_timeout,
poll_interval=polling_interval,
agent_branch=agent_branch,
from_task=params.from_task,
from_agent=params.from_agent,
context_id=params.context_id,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
)
if final_task is None:
return TaskStateResult(
status=TaskState.failed,
error=f"Push notification timeout after {polling_timeout}s",
history=new_messages,
)
result = process_task_state(
a2a_task=final_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
)
if result:
return result
return TaskStateResult(
status=TaskState.failed,
error=f"Unexpected task state: {final_task.status.state}",
history=new_messages,
)
except A2AClientHTTPError as e:
return _handle_push_error(
error=e,
error_msg=f"HTTP Error {e.status_code}: {e!s}",
error_type="http_error",
new_messages=new_messages,
agent_branch=agent_branch,
params=params,
task_id=task_id,
status_code=e.status_code,
)
except Exception as e:
return _handle_push_error(
error=e,
error_msg=f"Unexpected error during push notification: {e!s}",
error_type="unexpected_error",
new_messages=new_messages,
agent_branch=agent_branch,
params=params,
task_id=task_id,
)

View File

@@ -0,0 +1,87 @@
"""Webhook signature configuration for push notifications."""
from __future__ import annotations
from enum import Enum
import secrets
from pydantic import BaseModel, Field, SecretStr
class WebhookSignatureMode(str, Enum):
"""Signature mode for webhook push notifications."""
NONE = "none"
HMAC_SHA256 = "hmac_sha256"
class WebhookSignatureConfig(BaseModel):
"""Configuration for webhook signature verification.
Provides cryptographic integrity verification and replay attack protection
for A2A push notifications.
Attributes:
mode: Signature mode (none or hmac_sha256).
secret: Shared secret for HMAC computation (required for hmac_sha256 mode).
timestamp_tolerance_seconds: Max allowed age of timestamps for replay protection.
header_name: HTTP header name for the signature.
timestamp_header_name: HTTP header name for the timestamp.
"""
mode: WebhookSignatureMode = Field(
default=WebhookSignatureMode.NONE,
description="Signature verification mode",
)
secret: SecretStr | None = Field(
default=None,
description="Shared secret for HMAC computation",
)
timestamp_tolerance_seconds: int = Field(
default=300,
ge=0,
description="Max allowed timestamp age in seconds (5 min default)",
)
header_name: str = Field(
default="X-A2A-Signature",
description="HTTP header name for the signature",
)
timestamp_header_name: str = Field(
default="X-A2A-Signature-Timestamp",
description="HTTP header name for the timestamp",
)
@classmethod
def generate_secret(cls, length: int = 32) -> str:
"""Generate a cryptographically secure random secret.
Args:
length: Number of random bytes to generate (default 32).
Returns:
URL-safe base64-encoded secret string.
"""
return secrets.token_urlsafe(length)
@classmethod
def hmac_sha256(
cls,
secret: str | SecretStr,
timestamp_tolerance_seconds: int = 300,
) -> WebhookSignatureConfig:
"""Create an HMAC-SHA256 signature configuration.
Args:
secret: Shared secret for HMAC computation.
timestamp_tolerance_seconds: Max allowed timestamp age in seconds.
Returns:
Configured WebhookSignatureConfig for HMAC-SHA256.
"""
if isinstance(secret, str):
secret = SecretStr(secret)
return cls(
mode=WebhookSignatureMode.HMAC_SHA256,
secret=secret,
timestamp_tolerance_seconds=timestamp_tolerance_seconds,
)

View File

@@ -0,0 +1 @@
"""Streaming update mechanism module."""

View File

@@ -0,0 +1,9 @@
"""Streaming update mechanism configuration."""
from __future__ import annotations
from pydantic import BaseModel
class StreamingConfig(BaseModel):
"""Configuration for SSE-based task updates."""

View File

@@ -0,0 +1,646 @@
"""Streaming (SSE) update mechanism handler."""
from __future__ import annotations
import asyncio
import logging
from typing import Final
import uuid
from a2a.client import Client
from a2a.client.errors import A2AClientHTTPError
from a2a.types import (
AgentCard,
Message,
Part,
Role,
Task,
TaskArtifactUpdateEvent,
TaskIdParams,
TaskQueryParams,
TaskState,
TaskStatusUpdateEvent,
TextPart,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AArtifactReceivedEvent,
A2AConnectionErrorEvent,
A2AResponseReceivedEvent,
A2AStreamingChunkEvent,
A2AStreamingStartedEvent,
)
from typing_extensions import Unpack
from crewai_a2a.task_helpers import (
ACTIONABLE_STATES,
TERMINAL_STATES,
TaskStateResult,
process_task_state,
)
from crewai_a2a.updates.base import StreamingHandlerKwargs, extract_common_params
from crewai_a2a.updates.streaming.params import (
process_status_update,
)
logger = logging.getLogger(__name__)
MAX_RESUBSCRIBE_ATTEMPTS: Final[int] = 3
RESUBSCRIBE_BACKOFF_BASE: Final[float] = 1.0
class StreamingHandler:
"""SSE streaming-based update handler."""
@staticmethod
async def _try_recover_from_interruption( # type: ignore[misc]
client: Client,
task_id: str,
new_messages: list[Message],
agent_card: AgentCard,
result_parts: list[str],
**kwargs: Unpack[StreamingHandlerKwargs],
) -> TaskStateResult | None:
"""Attempt to recover from a stream interruption by checking task state.
If the task completed while we were disconnected, returns the result.
If the task is still running, attempts to resubscribe and continue.
Args:
client: A2A client instance.
task_id: The task ID to recover.
new_messages: List of collected messages.
agent_card: The agent card.
result_parts: Accumulated result text parts.
**kwargs: Handler parameters.
Returns:
TaskStateResult if recovery succeeded (task finished or resubscribe worked).
None if recovery not possible (caller should handle failure).
Note:
When None is returned, recovery failed and the original exception should
be handled by the caller. All recovery attempts are logged.
"""
params = extract_common_params(kwargs) # type: ignore[arg-type]
try:
a2a_task: Task = await client.get_task(TaskQueryParams(id=task_id))
if a2a_task.status.state in TERMINAL_STATES:
logger.info(
"Task completed during stream interruption",
extra={"task_id": task_id, "state": str(a2a_task.status.state)},
)
return process_task_state(
a2a_task=a2a_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
result_parts=result_parts,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
)
if a2a_task.status.state in ACTIONABLE_STATES:
logger.info(
"Task in actionable state during stream interruption",
extra={"task_id": task_id, "state": str(a2a_task.status.state)},
)
return process_task_state(
a2a_task=a2a_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
result_parts=result_parts,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
is_final=False,
)
logger.info(
"Task still running, attempting resubscribe",
extra={"task_id": task_id, "state": str(a2a_task.status.state)},
)
for attempt in range(MAX_RESUBSCRIBE_ATTEMPTS):
try:
backoff = RESUBSCRIBE_BACKOFF_BASE * (2**attempt)
if attempt > 0:
await asyncio.sleep(backoff)
event_stream = client.resubscribe(TaskIdParams(id=task_id))
async for event in event_stream:
if isinstance(event, tuple):
resubscribed_task, update = event
is_final_update = (
process_status_update(update, result_parts)
if isinstance(update, TaskStatusUpdateEvent)
else False
)
if isinstance(update, TaskArtifactUpdateEvent):
artifact = update.artifact
result_parts.extend(
part.root.text
for part in artifact.parts
if part.root.kind == "text"
)
if (
is_final_update
or resubscribed_task.status.state
in TERMINAL_STATES | ACTIONABLE_STATES
):
return process_task_state(
a2a_task=resubscribed_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
result_parts=result_parts,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
is_final=is_final_update,
)
elif isinstance(event, Message):
new_messages.append(event)
result_parts.extend(
part.root.text
for part in event.parts
if part.root.kind == "text"
)
final_task = await client.get_task(TaskQueryParams(id=task_id))
return process_task_state(
a2a_task=final_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
result_parts=result_parts,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
)
except Exception as resubscribe_error: # noqa: PERF203
logger.warning(
"Resubscribe attempt failed",
extra={
"task_id": task_id,
"attempt": attempt + 1,
"max_attempts": MAX_RESUBSCRIBE_ATTEMPTS,
"error": str(resubscribe_error),
},
)
if attempt == MAX_RESUBSCRIBE_ATTEMPTS - 1:
return None
except Exception as e:
logger.warning(
"Failed to recover from stream interruption due to unexpected error",
extra={
"task_id": task_id,
"error": str(e),
"error_type": type(e).__name__,
},
exc_info=True,
)
return None
logger.warning(
"Recovery exhausted all resubscribe attempts without success",
extra={"task_id": task_id, "max_attempts": MAX_RESUBSCRIBE_ATTEMPTS},
)
return None
@staticmethod
async def execute(
client: Client,
message: Message,
new_messages: list[Message],
agent_card: AgentCard,
**kwargs: Unpack[StreamingHandlerKwargs],
) -> TaskStateResult:
"""Execute A2A delegation using SSE streaming for updates.
Args:
client: A2A client instance.
message: Message to send.
new_messages: List to collect messages.
agent_card: The agent card.
**kwargs: Streaming-specific parameters.
Returns:
Dictionary with status, result/error, and history.
"""
task_id = kwargs.get("task_id")
agent_branch = kwargs.get("agent_branch")
params = extract_common_params(kwargs)
result_parts: list[str] = []
final_result: TaskStateResult | None = None
event_stream = client.send_message(message)
chunk_index = 0
current_task_id: str | None = task_id
crewai_event_bus.emit(
agent_branch,
A2AStreamingStartedEvent(
task_id=task_id,
context_id=params.context_id,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
try:
async for event in event_stream:
if isinstance(event, tuple):
a2a_task, _ = event
current_task_id = a2a_task.id
if isinstance(event, Message):
new_messages.append(event)
message_context_id = event.context_id or params.context_id
for part in event.parts:
if part.root.kind == "text":
text = part.root.text
result_parts.append(text)
crewai_event_bus.emit(
agent_branch,
A2AStreamingChunkEvent(
task_id=event.task_id or task_id,
context_id=message_context_id,
chunk=text,
chunk_index=chunk_index,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
chunk_index += 1
elif isinstance(event, tuple):
a2a_task, update = event
if isinstance(update, TaskArtifactUpdateEvent):
artifact = update.artifact
result_parts.extend(
part.root.text
for part in artifact.parts
if part.root.kind == "text"
)
artifact_size = None
if artifact.parts:
artifact_size = sum(
len(p.root.text.encode())
if p.root.kind == "text"
else len(getattr(p.root, "data", b""))
for p in artifact.parts
)
effective_context_id = a2a_task.context_id or params.context_id
crewai_event_bus.emit(
agent_branch,
A2AArtifactReceivedEvent(
task_id=a2a_task.id,
artifact_id=artifact.artifact_id,
artifact_name=artifact.name,
artifact_description=artifact.description,
mime_type=artifact.parts[0].root.kind
if artifact.parts
else None,
size_bytes=artifact_size,
append=update.append or False,
last_chunk=update.last_chunk or False,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
context_id=effective_context_id,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
is_final_update = (
process_status_update(update, result_parts)
if isinstance(update, TaskStatusUpdateEvent)
else False
)
if (
not is_final_update
and a2a_task.status.state
not in TERMINAL_STATES | ACTIONABLE_STATES
):
continue
final_result = process_task_state(
a2a_task=a2a_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=params.turn_number,
is_multiturn=params.is_multiturn,
agent_role=params.agent_role,
result_parts=result_parts,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
is_final=is_final_update,
)
if final_result:
break
except A2AClientHTTPError as e:
if current_task_id:
logger.info(
"Stream interrupted with HTTP error, attempting recovery",
extra={
"task_id": current_task_id,
"error": str(e),
"status_code": e.status_code,
},
)
recovery_kwargs = {k: v for k, v in kwargs.items() if k != "task_id"}
recovered_result = (
await StreamingHandler._try_recover_from_interruption(
client=client,
task_id=current_task_id,
new_messages=new_messages,
agent_card=agent_card,
result_parts=result_parts,
**recovery_kwargs,
)
)
if recovered_result:
logger.info(
"Successfully recovered task after HTTP error",
extra={
"task_id": current_task_id,
"status": str(recovered_result.get("status")),
},
)
return recovered_result
logger.warning(
"Failed to recover from HTTP error, returning failure",
extra={
"task_id": current_task_id,
"status_code": e.status_code,
"original_error": str(e),
},
)
error_msg = f"HTTP Error {e.status_code}: {e!s}"
error_type = "http_error"
status_code = e.status_code
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=params.context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=params.endpoint,
error=str(e),
error_type=error_type,
status_code=status_code,
a2a_agent_name=params.a2a_agent_name,
operation="streaming",
context_id=params.context_id,
task_id=task_id,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=params.turn_number,
context_id=params.context_id,
is_multiturn=params.is_multiturn,
status="failed",
final=True,
agent_role=params.agent_role,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
except (asyncio.TimeoutError, asyncio.CancelledError, ConnectionError) as e:
error_type = type(e).__name__.lower()
if current_task_id:
logger.info(
f"Stream interrupted with {error_type}, attempting recovery",
extra={"task_id": current_task_id, "error": str(e)},
)
recovery_kwargs = {k: v for k, v in kwargs.items() if k != "task_id"}
recovered_result = (
await StreamingHandler._try_recover_from_interruption(
client=client,
task_id=current_task_id,
new_messages=new_messages,
agent_card=agent_card,
result_parts=result_parts,
**recovery_kwargs,
)
)
if recovered_result:
logger.info(
f"Successfully recovered task after {error_type}",
extra={
"task_id": current_task_id,
"status": str(recovered_result.get("status")),
},
)
return recovered_result
logger.warning(
f"Failed to recover from {error_type}, returning failure",
extra={
"task_id": current_task_id,
"error_type": error_type,
"original_error": str(e),
},
)
error_msg = f"Connection error during streaming: {e!s}"
status_code = None
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=params.context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=params.endpoint,
error=str(e),
error_type=error_type,
status_code=status_code,
a2a_agent_name=params.a2a_agent_name,
operation="streaming",
context_id=params.context_id,
task_id=task_id,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=params.turn_number,
context_id=params.context_id,
is_multiturn=params.is_multiturn,
status="failed",
final=True,
agent_role=params.agent_role,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
except Exception as e:
logger.exception(
"Unexpected error during streaming",
extra={
"task_id": current_task_id,
"error_type": type(e).__name__,
"endpoint": params.endpoint,
},
)
error_msg = f"Unexpected error during streaming: {type(e).__name__}: {e!s}"
error_type = "unexpected_error"
status_code = None
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=params.context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=params.endpoint,
error=str(e),
error_type=error_type,
status_code=status_code,
a2a_agent_name=params.a2a_agent_name,
operation="streaming",
context_id=params.context_id,
task_id=task_id,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
crewai_event_bus.emit(
agent_branch,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=params.turn_number,
context_id=params.context_id,
is_multiturn=params.is_multiturn,
status="failed",
final=True,
agent_role=params.agent_role,
endpoint=params.endpoint,
a2a_agent_name=params.a2a_agent_name,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
finally:
aclose = getattr(event_stream, "aclose", None)
if aclose:
try:
await aclose()
except Exception as close_error:
crewai_event_bus.emit(
agent_branch,
A2AConnectionErrorEvent(
endpoint=params.endpoint,
error=str(close_error),
error_type="stream_close_error",
a2a_agent_name=params.a2a_agent_name,
operation="stream_close",
context_id=params.context_id,
task_id=task_id,
from_task=params.from_task,
from_agent=params.from_agent,
),
)
if final_result:
return final_result
return TaskStateResult(
status=TaskState.completed,
result=" ".join(result_parts) if result_parts else "",
history=new_messages,
agent_card=agent_card.model_dump(exclude_none=True),
)

View File

@@ -0,0 +1,28 @@
"""Common parameter extraction for streaming handlers."""
from __future__ import annotations
from a2a.types import TaskStatusUpdateEvent
def process_status_update(
update: TaskStatusUpdateEvent,
result_parts: list[str],
) -> bool:
"""Process a status update event and extract text parts.
Args:
update: The status update event.
result_parts: List to append text parts to (modified in place).
Returns:
True if this is a final update, False otherwise.
"""
is_final = update.final
if update.status and update.status.message and update.status.message.parts:
result_parts.extend(
part.root.text
for part in update.status.message.parts
if part.root.kind == "text" and part.root.text
)
return is_final

View File

@@ -0,0 +1 @@
"""A2A utility modules for client operations."""

View File

@@ -0,0 +1,587 @@
"""AgentCard utilities for A2A client and server operations."""
from __future__ import annotations
import asyncio
from collections.abc import MutableMapping
from functools import lru_cache
import ssl
import time
from types import MethodType
from typing import TYPE_CHECKING
from a2a.client.errors import A2AClientHTTPError
from a2a.types import AgentCapabilities, AgentCard, AgentSkill
from aiocache import cached # type: ignore[import-untyped]
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
from crewai.crew import Crew
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AAgentCardFetchedEvent,
A2AAuthenticationFailedEvent,
A2AConnectionErrorEvent,
)
import httpx
from crewai_a2a.auth.client_schemes import APIKeyAuth, HTTPDigestAuth
from crewai_a2a.auth.utils import (
_auth_store,
configure_auth_client,
retry_on_401,
)
from crewai_a2a.config import A2AServerConfig
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.task import Task
from crewai_a2a.auth.client_schemes import ClientAuthScheme
def _get_tls_verify(auth: ClientAuthScheme | None) -> ssl.SSLContext | bool | str:
"""Get TLS verify parameter from auth scheme.
Args:
auth: Optional authentication scheme with TLS config.
Returns:
SSL context, CA cert path, True for default verification,
or False if verification disabled.
"""
if auth and auth.tls:
return auth.tls.get_httpx_ssl_context()
return True
async def _prepare_auth_headers(
auth: ClientAuthScheme | None,
timeout: int,
) -> tuple[MutableMapping[str, str], ssl.SSLContext | bool | str]:
"""Prepare authentication headers and TLS verification settings.
Args:
auth: Optional authentication scheme.
timeout: Request timeout in seconds.
Returns:
Tuple of (headers dict, TLS verify setting).
"""
headers: MutableMapping[str, str] = {}
verify = _get_tls_verify(auth)
if auth:
async with httpx.AsyncClient(
timeout=timeout, verify=verify
) as temp_auth_client:
if isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
configure_auth_client(auth, temp_auth_client)
headers = await auth.apply_auth(temp_auth_client, {})
return headers, verify
def _get_server_config(agent: Agent) -> A2AServerConfig | None:
"""Get A2AServerConfig from an agent's a2a configuration.
Args:
agent: The Agent instance to check.
Returns:
A2AServerConfig if present, None otherwise.
"""
if agent.a2a is None:
return None
if isinstance(agent.a2a, A2AServerConfig):
return agent.a2a
if isinstance(agent.a2a, list):
for config in agent.a2a:
if isinstance(config, A2AServerConfig):
return config
return None
def fetch_agent_card(
endpoint: str,
auth: ClientAuthScheme | None = None,
timeout: int = 30,
use_cache: bool = True,
cache_ttl: int = 300,
) -> AgentCard:
"""Fetch AgentCard from an A2A endpoint with optional caching.
Args:
endpoint: A2A agent endpoint URL (AgentCard URL).
auth: Optional ClientAuthScheme for authentication.
timeout: Request timeout in seconds.
use_cache: Whether to use caching (default True).
cache_ttl: Cache TTL in seconds (default 300 = 5 minutes).
Returns:
AgentCard object with agent capabilities and skills.
Raises:
httpx.HTTPStatusError: If the request fails.
A2AClientHTTPError: If authentication fails.
"""
if use_cache:
if auth:
auth_data = auth.model_dump_json(
exclude={
"_access_token",
"_token_expires_at",
"_refresh_token",
"_authorization_callback",
}
)
auth_hash = _auth_store.compute_key(type(auth).__name__, auth_data)
else:
auth_hash = _auth_store.compute_key("none", "")
_auth_store.set(auth_hash, auth)
ttl_hash = int(time.time() // cache_ttl)
return _fetch_agent_card_cached(endpoint, auth_hash, timeout, ttl_hash)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(
afetch_agent_card(endpoint=endpoint, auth=auth, timeout=timeout)
)
finally:
loop.close()
async def afetch_agent_card(
endpoint: str,
auth: ClientAuthScheme | None = None,
timeout: int = 30,
use_cache: bool = True,
) -> AgentCard:
"""Fetch AgentCard from an A2A endpoint asynchronously.
Native async implementation. Use this when running in an async context.
Args:
endpoint: A2A agent endpoint URL (AgentCard URL).
auth: Optional ClientAuthScheme for authentication.
timeout: Request timeout in seconds.
use_cache: Whether to use caching (default True).
Returns:
AgentCard object with agent capabilities and skills.
Raises:
httpx.HTTPStatusError: If the request fails.
A2AClientHTTPError: If authentication fails.
"""
if use_cache:
if auth:
auth_data = auth.model_dump_json(
exclude={
"_access_token",
"_token_expires_at",
"_refresh_token",
"_authorization_callback",
}
)
auth_hash = _auth_store.compute_key(type(auth).__name__, auth_data)
else:
auth_hash = _auth_store.compute_key("none", "")
_auth_store.set(auth_hash, auth)
agent_card: AgentCard = await _afetch_agent_card_cached(
endpoint, auth_hash, timeout
)
return agent_card
return await _afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout)
@lru_cache()
def _fetch_agent_card_cached(
endpoint: str,
auth_hash: str,
timeout: int,
_ttl_hash: int,
) -> AgentCard:
"""Cached sync version of fetch_agent_card."""
auth = _auth_store.get(auth_hash)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(
_afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout)
)
finally:
loop.close()
@cached(ttl=300, serializer=PickleSerializer()) # type: ignore[untyped-decorator]
async def _afetch_agent_card_cached(
endpoint: str,
auth_hash: str,
timeout: int,
) -> AgentCard:
"""Cached async implementation of AgentCard fetching."""
auth = _auth_store.get(auth_hash)
return await _afetch_agent_card_impl(endpoint=endpoint, auth=auth, timeout=timeout)
async def _afetch_agent_card_impl(
endpoint: str,
auth: ClientAuthScheme | None,
timeout: int,
) -> AgentCard:
"""Internal async implementation of AgentCard fetching."""
start_time = time.perf_counter()
if "/.well-known/agent-card.json" in endpoint:
base_url = endpoint.replace("/.well-known/agent-card.json", "")
agent_card_path = "/.well-known/agent-card.json"
else:
url_parts = endpoint.split("/", 3)
base_url = f"{url_parts[0]}//{url_parts[2]}"
agent_card_path = (
f"/{url_parts[3]}"
if len(url_parts) > 3 and url_parts[3]
else "/.well-known/agent-card.json"
)
headers, verify = await _prepare_auth_headers(auth, timeout)
async with httpx.AsyncClient(
timeout=timeout, headers=headers, verify=verify
) as temp_client:
if auth and isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
configure_auth_client(auth, temp_client)
agent_card_url = f"{base_url}{agent_card_path}"
async def _fetch_agent_card_request() -> httpx.Response:
return await temp_client.get(agent_card_url)
try:
response = await retry_on_401(
request_func=_fetch_agent_card_request,
auth_scheme=auth,
client=temp_client,
headers=temp_client.headers,
max_retries=2,
)
response.raise_for_status()
agent_card = AgentCard.model_validate(response.json())
fetch_time_ms = (time.perf_counter() - start_time) * 1000
agent_card_dict = agent_card.model_dump(exclude_none=True)
crewai_event_bus.emit(
None,
A2AAgentCardFetchedEvent(
endpoint=endpoint,
a2a_agent_name=agent_card.name,
agent_card=agent_card_dict,
protocol_version=agent_card.protocol_version,
provider=agent_card_dict.get("provider"),
cached=False,
fetch_time_ms=fetch_time_ms,
),
)
return agent_card
except httpx.HTTPStatusError as e:
elapsed_ms = (time.perf_counter() - start_time) * 1000
response_body = e.response.text[:1000] if e.response.text else None
if e.response.status_code == 401:
error_details = ["Authentication failed"]
www_auth = e.response.headers.get("WWW-Authenticate")
if www_auth:
error_details.append(f"WWW-Authenticate: {www_auth}")
if not auth:
error_details.append("No auth scheme provided")
msg = " | ".join(error_details)
auth_type = type(auth).__name__ if auth else None
crewai_event_bus.emit(
None,
A2AAuthenticationFailedEvent(
endpoint=endpoint,
auth_type=auth_type,
error=msg,
status_code=401,
metadata={
"elapsed_ms": elapsed_ms,
"response_body": response_body,
"www_authenticate": www_auth,
"request_url": str(e.request.url),
},
),
)
raise A2AClientHTTPError(401, msg) from e
crewai_event_bus.emit(
None,
A2AConnectionErrorEvent(
endpoint=endpoint,
error=str(e),
error_type="http_error",
status_code=e.response.status_code,
operation="fetch_agent_card",
metadata={
"elapsed_ms": elapsed_ms,
"response_body": response_body,
"request_url": str(e.request.url),
},
),
)
raise
except httpx.TimeoutException as e:
elapsed_ms = (time.perf_counter() - start_time) * 1000
crewai_event_bus.emit(
None,
A2AConnectionErrorEvent(
endpoint=endpoint,
error=str(e),
error_type="timeout",
operation="fetch_agent_card",
metadata={
"elapsed_ms": elapsed_ms,
"timeout_config": timeout,
"request_url": str(e.request.url) if e.request else None,
},
),
)
raise
except httpx.ConnectError as e:
elapsed_ms = (time.perf_counter() - start_time) * 1000
crewai_event_bus.emit(
None,
A2AConnectionErrorEvent(
endpoint=endpoint,
error=str(e),
error_type="connection_error",
operation="fetch_agent_card",
metadata={
"elapsed_ms": elapsed_ms,
"request_url": str(e.request.url) if e.request else None,
},
),
)
raise
except httpx.RequestError as e:
elapsed_ms = (time.perf_counter() - start_time) * 1000
crewai_event_bus.emit(
None,
A2AConnectionErrorEvent(
endpoint=endpoint,
error=str(e),
error_type="request_error",
operation="fetch_agent_card",
metadata={
"elapsed_ms": elapsed_ms,
"request_url": str(e.request.url) if e.request else None,
},
),
)
raise
def _task_to_skill(task: Task) -> AgentSkill:
"""Convert a CrewAI Task to an A2A AgentSkill.
Args:
task: The CrewAI Task to convert.
Returns:
AgentSkill representing the task's capability.
"""
task_name = task.name or task.description[:50]
task_id = task_name.lower().replace(" ", "_")
tags: list[str] = []
if task.agent:
tags.append(task.agent.role.lower().replace(" ", "-"))
return AgentSkill(
id=task_id,
name=task_name,
description=task.description,
tags=tags,
examples=[task.expected_output] if task.expected_output else None,
)
def _tool_to_skill(tool_name: str, tool_description: str) -> AgentSkill:
"""Convert an Agent's tool to an A2A AgentSkill.
Args:
tool_name: Name of the tool.
tool_description: Description of what the tool does.
Returns:
AgentSkill representing the tool's capability.
"""
tool_id = tool_name.lower().replace(" ", "_")
return AgentSkill(
id=tool_id,
name=tool_name,
description=tool_description,
tags=[tool_name.lower().replace(" ", "-")],
)
def _crew_to_agent_card(crew: Crew, url: str) -> AgentCard:
"""Generate an A2A AgentCard from a Crew instance.
Args:
crew: The Crew instance to generate a card for.
url: The base URL where this crew will be exposed.
Returns:
AgentCard describing the crew's capabilities.
"""
crew_name = getattr(crew, "name", None) or crew.__class__.__name__
description_parts: list[str] = []
crew_description = getattr(crew, "description", None)
if crew_description:
description_parts.append(crew_description)
else:
agent_roles = [agent.role for agent in crew.agents]
description_parts.append(
f"A crew of {len(crew.agents)} agents: {', '.join(agent_roles)}"
)
skills = [_task_to_skill(task) for task in crew.tasks]
return AgentCard(
name=crew_name,
description=" ".join(description_parts),
url=url,
version="1.0.0",
capabilities=AgentCapabilities(
streaming=True,
push_notifications=True,
),
default_input_modes=["text/plain", "application/json"],
default_output_modes=["text/plain", "application/json"],
skills=skills,
)
def _agent_to_agent_card(agent: Agent, url: str) -> AgentCard:
"""Generate an A2A AgentCard from an Agent instance.
Uses A2AServerConfig values when available, falling back to agent properties.
If signing_config is provided, the card will be signed with JWS.
Args:
agent: The Agent instance to generate a card for.
url: The base URL where this agent will be exposed.
Returns:
AgentCard describing the agent's capabilities.
"""
from crewai_a2a.utils.agent_card_signing import sign_agent_card
server_config = _get_server_config(agent) or A2AServerConfig()
name = server_config.name or agent.role
description_parts = [agent.goal]
if agent.backstory:
description_parts.append(agent.backstory)
description = server_config.description or " ".join(description_parts)
skills: list[AgentSkill] = (
server_config.skills.copy() if server_config.skills else []
)
if not skills:
if agent.tools:
for tool in agent.tools:
tool_name = getattr(tool, "name", None) or tool.__class__.__name__
tool_desc = getattr(tool, "description", None) or f"Tool: {tool_name}"
skills.append(_tool_to_skill(tool_name, tool_desc))
if not skills:
skills.append(
AgentSkill(
id=agent.role.lower().replace(" ", "_"),
name=agent.role,
description=agent.goal,
tags=[agent.role.lower().replace(" ", "-")],
)
)
capabilities = server_config.capabilities
if server_config.server_extensions:
from crewai_a2a.extensions.server import ServerExtensionRegistry
registry = ServerExtensionRegistry(server_config.server_extensions)
ext_list = registry.get_agent_extensions()
existing_exts = list(capabilities.extensions) if capabilities.extensions else []
existing_uris = {e.uri for e in existing_exts}
for ext in ext_list:
if ext.uri not in existing_uris:
existing_exts.append(ext)
capabilities = capabilities.model_copy(update={"extensions": existing_exts})
card = AgentCard(
name=name,
description=description,
url=server_config.url or url,
version=server_config.version,
capabilities=capabilities,
default_input_modes=server_config.default_input_modes,
default_output_modes=server_config.default_output_modes,
skills=skills,
preferred_transport=server_config.transport.preferred,
protocol_version=server_config.protocol_version,
provider=server_config.provider,
documentation_url=server_config.documentation_url,
icon_url=server_config.icon_url,
additional_interfaces=server_config.additional_interfaces,
security=server_config.security,
security_schemes=server_config.security_schemes,
supports_authenticated_extended_card=server_config.supports_authenticated_extended_card,
)
if server_config.signing_config:
signature = sign_agent_card(
card,
private_key=server_config.signing_config.get_private_key(),
key_id=server_config.signing_config.key_id,
algorithm=server_config.signing_config.algorithm,
)
card = card.model_copy(update={"signatures": [signature]})
elif server_config.signatures:
card = card.model_copy(update={"signatures": server_config.signatures})
return card
def inject_a2a_server_methods(agent: Agent) -> None:
"""Inject A2A server methods onto an Agent instance.
Adds a `to_agent_card(url: str) -> AgentCard` method to the agent
that generates an A2A-compliant AgentCard.
Only injects if the agent has an A2AServerConfig.
Args:
agent: The Agent instance to inject methods onto.
"""
if _get_server_config(agent) is None:
return
def _to_agent_card(self: Agent, url: str) -> AgentCard:
return _agent_to_agent_card(self, url)
object.__setattr__(agent, "to_agent_card", MethodType(_to_agent_card, agent))

View File

@@ -0,0 +1,236 @@
"""AgentCard JWS signing utilities.
This module provides functions for signing and verifying AgentCards using
JSON Web Signatures (JWS) as per RFC 7515. Signed agent cards allow clients
to verify the authenticity and integrity of agent card information.
Example:
>>> from crewai_a2a.utils.agent_card_signing import sign_agent_card
>>> signature = sign_agent_card(agent_card, private_key_pem, key_id="key-1")
>>> card_with_sig = card.model_copy(update={"signatures": [signature]})
"""
from __future__ import annotations
import base64
import json
import logging
from typing import Any, Literal
from a2a.types import AgentCard, AgentCardSignature
import jwt
from pydantic import SecretStr
logger = logging.getLogger(__name__)
SigningAlgorithm = Literal[
"RS256", "RS384", "RS512", "ES256", "ES384", "ES512", "PS256", "PS384", "PS512"
]
def _normalize_private_key(private_key: str | bytes | SecretStr) -> bytes:
"""Normalize private key to bytes format.
Args:
private_key: PEM-encoded private key as string, bytes, or SecretStr.
Returns:
Private key as bytes.
"""
if isinstance(private_key, SecretStr):
private_key = private_key.get_secret_value()
if isinstance(private_key, str):
private_key = private_key.encode()
return private_key
def _serialize_agent_card(agent_card: AgentCard) -> str:
"""Serialize AgentCard to canonical JSON for signing.
Excludes the signatures field to avoid circular reference during signing.
Uses sorted keys and compact separators for deterministic output.
Args:
agent_card: The AgentCard to serialize.
Returns:
Canonical JSON string representation.
"""
card_dict = agent_card.model_dump(exclude={"signatures"}, exclude_none=True)
return json.dumps(card_dict, sort_keys=True, separators=(",", ":"))
def _base64url_encode(data: bytes | str) -> str:
"""Encode data to URL-safe base64 without padding.
Args:
data: Data to encode.
Returns:
URL-safe base64 encoded string without padding.
"""
if isinstance(data, str):
data = data.encode()
return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
def sign_agent_card(
agent_card: AgentCard,
private_key: str | bytes | SecretStr,
key_id: str | None = None,
algorithm: SigningAlgorithm = "RS256",
) -> AgentCardSignature:
"""Sign an AgentCard using JWS (RFC 7515).
Creates a detached JWS signature for the AgentCard. The signature covers
all fields except the signatures field itself.
Args:
agent_card: The AgentCard to sign.
private_key: PEM-encoded private key (RSA, EC, or RSA-PSS).
key_id: Optional key identifier for the JWS header (kid claim).
algorithm: Signing algorithm (RS256, ES256, PS256, etc.).
Returns:
AgentCardSignature with protected header and signature.
Raises:
jwt.exceptions.InvalidKeyError: If the private key is invalid.
ValueError: If the algorithm is not supported for the key type.
Example:
>>> signature = sign_agent_card(
... agent_card,
... private_key_pem="-----BEGIN PRIVATE KEY-----...",
... key_id="my-key-id",
... )
"""
key_bytes = _normalize_private_key(private_key)
payload = _serialize_agent_card(agent_card)
protected_header: dict[str, Any] = {"typ": "JWS"}
if key_id:
protected_header["kid"] = key_id
jws_token = jwt.api_jws.encode(
payload.encode(),
key_bytes,
algorithm=algorithm,
headers=protected_header,
)
parts = jws_token.split(".")
protected_b64 = parts[0]
signature_b64 = parts[2]
header: dict[str, Any] | None = None
if key_id:
header = {"kid": key_id}
return AgentCardSignature(
protected=protected_b64,
signature=signature_b64,
header=header,
)
def verify_agent_card_signature(
agent_card: AgentCard,
signature: AgentCardSignature,
public_key: str | bytes,
algorithms: list[str] | None = None,
) -> bool:
"""Verify an AgentCard JWS signature.
Validates that the signature was created with the corresponding private key
and that the AgentCard content has not been modified.
Args:
agent_card: The AgentCard to verify.
signature: The AgentCardSignature to validate.
public_key: PEM-encoded public key (RSA, EC, or RSA-PSS).
algorithms: List of allowed algorithms. Defaults to common asymmetric algorithms.
Returns:
True if signature is valid, False otherwise.
Example:
>>> is_valid = verify_agent_card_signature(
... agent_card, signature, public_key_pem="-----BEGIN PUBLIC KEY-----..."
... )
"""
if algorithms is None:
algorithms = [
"RS256",
"RS384",
"RS512",
"ES256",
"ES384",
"ES512",
"PS256",
"PS384",
"PS512",
]
if isinstance(public_key, str):
public_key = public_key.encode()
payload = _serialize_agent_card(agent_card)
payload_b64 = _base64url_encode(payload)
jws_token = f"{signature.protected}.{payload_b64}.{signature.signature}"
try:
jwt.api_jws.decode(
jws_token,
public_key,
algorithms=algorithms,
)
return True
except jwt.InvalidSignatureError:
logger.debug(
"AgentCard signature verification failed",
extra={"reason": "invalid_signature"},
)
return False
except jwt.DecodeError as e:
logger.debug(
"AgentCard signature verification failed",
extra={"reason": "decode_error", "error": str(e)},
)
return False
except jwt.InvalidAlgorithmError as e:
logger.debug(
"AgentCard signature verification failed",
extra={"reason": "algorithm_error", "error": str(e)},
)
return False
def get_key_id_from_signature(signature: AgentCardSignature) -> str | None:
"""Extract the key ID (kid) from an AgentCardSignature.
Checks both the unprotected header and the protected header for the kid claim.
Args:
signature: The AgentCardSignature to extract from.
Returns:
The key ID if present, None otherwise.
"""
if signature.header and "kid" in signature.header:
kid: str = signature.header["kid"]
return kid
try:
protected = signature.protected
padding_needed = 4 - (len(protected) % 4)
if padding_needed != 4:
protected += "=" * padding_needed
protected_json = base64.urlsafe_b64decode(protected).decode()
protected_header: dict[str, Any] = json.loads(protected_json)
return protected_header.get("kid")
except (ValueError, json.JSONDecodeError):
return None

View File

@@ -0,0 +1,338 @@
"""Content type negotiation for A2A protocol.
This module handles negotiation of input/output MIME types between A2A clients
and servers based on AgentCard capabilities.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Annotated, Final, Literal, cast
from a2a.types import Part
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import A2AContentTypeNegotiatedEvent
if TYPE_CHECKING:
from a2a.types import AgentCard, AgentSkill
TEXT_PLAIN: Literal["text/plain"] = "text/plain"
APPLICATION_JSON: Literal["application/json"] = "application/json"
IMAGE_PNG: Literal["image/png"] = "image/png"
IMAGE_JPEG: Literal["image/jpeg"] = "image/jpeg"
IMAGE_WILDCARD: Literal["image/*"] = "image/*"
APPLICATION_PDF: Literal["application/pdf"] = "application/pdf"
APPLICATION_OCTET_STREAM: Literal["application/octet-stream"] = (
"application/octet-stream"
)
DEFAULT_CLIENT_INPUT_MODES: Final[list[Literal["text/plain", "application/json"]]] = [
TEXT_PLAIN,
APPLICATION_JSON,
]
DEFAULT_CLIENT_OUTPUT_MODES: Final[list[Literal["text/plain", "application/json"]]] = [
TEXT_PLAIN,
APPLICATION_JSON,
]
@dataclass
class NegotiatedContentTypes:
"""Result of content type negotiation."""
input_modes: Annotated[list[str], "Negotiated input MIME types the client can send"]
output_modes: Annotated[
list[str], "Negotiated output MIME types the server will produce"
]
effective_input_modes: Annotated[list[str], "Server's effective input modes"]
effective_output_modes: Annotated[list[str], "Server's effective output modes"]
skill_name: Annotated[
str | None, "Skill name if negotiation was skill-specific"
] = None
class ContentTypeNegotiationError(Exception):
"""Raised when no compatible content types can be negotiated."""
def __init__(
self,
client_input_modes: list[str],
client_output_modes: list[str],
server_input_modes: list[str],
server_output_modes: list[str],
direction: str = "both",
message: str | None = None,
) -> None:
self.client_input_modes = client_input_modes
self.client_output_modes = client_output_modes
self.server_input_modes = server_input_modes
self.server_output_modes = server_output_modes
self.direction = direction
if message is None:
if direction == "input":
message = (
f"No compatible input content types. "
f"Client supports: {client_input_modes}, "
f"Server accepts: {server_input_modes}"
)
elif direction == "output":
message = (
f"No compatible output content types. "
f"Client accepts: {client_output_modes}, "
f"Server produces: {server_output_modes}"
)
else:
message = (
f"No compatible content types. "
f"Input - Client: {client_input_modes}, Server: {server_input_modes}. "
f"Output - Client: {client_output_modes}, Server: {server_output_modes}"
)
super().__init__(message)
def _normalize_mime_type(mime_type: str) -> str:
"""Normalize MIME type for comparison (lowercase, strip whitespace)."""
return mime_type.lower().strip()
def _mime_types_compatible(client_type: str, server_type: str) -> bool:
"""Check if two MIME types are compatible.
Handles wildcards like image/* matching image/png.
"""
client_normalized = _normalize_mime_type(client_type)
server_normalized = _normalize_mime_type(server_type)
if client_normalized == server_normalized:
return True
if "*" in client_normalized or "*" in server_normalized:
client_parts = client_normalized.split("/")
server_parts = server_normalized.split("/")
if len(client_parts) == 2 and len(server_parts) == 2:
type_match = (
client_parts[0] == server_parts[0]
or client_parts[0] == "*"
or server_parts[0] == "*"
)
subtype_match = (
client_parts[1] == server_parts[1]
or client_parts[1] == "*"
or server_parts[1] == "*"
)
return type_match and subtype_match
return False
def _find_compatible_modes(
client_modes: list[str], server_modes: list[str]
) -> list[str]:
"""Find compatible MIME types between client and server.
Returns modes in client preference order.
"""
compatible = []
for client_mode in client_modes:
for server_mode in server_modes:
if _mime_types_compatible(client_mode, server_mode):
if "*" in client_mode and "*" not in server_mode:
if server_mode not in compatible:
compatible.append(server_mode)
else:
if client_mode not in compatible:
compatible.append(client_mode)
break
return compatible
def _get_effective_modes(
agent_card: AgentCard,
skill_name: str | None = None,
) -> tuple[list[str], list[str], AgentSkill | None]:
"""Get effective input/output modes from agent card.
If skill_name is provided and the skill has custom modes, those are used.
Otherwise, falls back to agent card defaults.
"""
skill: AgentSkill | None = None
if skill_name and agent_card.skills:
for s in agent_card.skills:
if s.name == skill_name or s.id == skill_name:
skill = s
break
if skill:
input_modes = (
skill.input_modes if skill.input_modes else agent_card.default_input_modes
)
output_modes = (
skill.output_modes
if skill.output_modes
else agent_card.default_output_modes
)
else:
input_modes = agent_card.default_input_modes
output_modes = agent_card.default_output_modes
return input_modes, output_modes, skill
def negotiate_content_types(
agent_card: AgentCard,
client_input_modes: list[str] | None = None,
client_output_modes: list[str] | None = None,
skill_name: str | None = None,
emit_event: bool = True,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
strict: bool = False,
) -> NegotiatedContentTypes:
"""Negotiate content types between client and server.
Args:
agent_card: The remote agent's card with capability info.
client_input_modes: MIME types the client can send. Defaults to text/plain and application/json.
client_output_modes: MIME types the client can accept. Defaults to text/plain and application/json.
skill_name: Optional skill to use for mode lookup.
emit_event: Whether to emit a content type negotiation event.
endpoint: Agent endpoint (for event metadata).
a2a_agent_name: Agent name (for event metadata).
strict: If True, raises error when no compatible types found.
If False, returns empty lists for incompatible directions.
Returns:
NegotiatedContentTypes with compatible input and output modes.
Raises:
ContentTypeNegotiationError: If strict=True and no compatible types found.
"""
if client_input_modes is None:
client_input_modes = cast(list[str], DEFAULT_CLIENT_INPUT_MODES.copy())
if client_output_modes is None:
client_output_modes = cast(list[str], DEFAULT_CLIENT_OUTPUT_MODES.copy())
server_input_modes, server_output_modes, skill = _get_effective_modes(
agent_card, skill_name
)
compatible_input = _find_compatible_modes(client_input_modes, server_input_modes)
compatible_output = _find_compatible_modes(client_output_modes, server_output_modes)
if strict:
if not compatible_input and not compatible_output:
raise ContentTypeNegotiationError(
client_input_modes=client_input_modes,
client_output_modes=client_output_modes,
server_input_modes=server_input_modes,
server_output_modes=server_output_modes,
)
if not compatible_input:
raise ContentTypeNegotiationError(
client_input_modes=client_input_modes,
client_output_modes=client_output_modes,
server_input_modes=server_input_modes,
server_output_modes=server_output_modes,
direction="input",
)
if not compatible_output:
raise ContentTypeNegotiationError(
client_input_modes=client_input_modes,
client_output_modes=client_output_modes,
server_input_modes=server_input_modes,
server_output_modes=server_output_modes,
direction="output",
)
result = NegotiatedContentTypes(
input_modes=compatible_input,
output_modes=compatible_output,
effective_input_modes=server_input_modes,
effective_output_modes=server_output_modes,
skill_name=skill.name if skill else None,
)
if emit_event:
crewai_event_bus.emit(
None,
A2AContentTypeNegotiatedEvent(
endpoint=endpoint or agent_card.url,
a2a_agent_name=a2a_agent_name or agent_card.name,
skill_name=skill_name,
client_input_modes=client_input_modes,
client_output_modes=client_output_modes,
server_input_modes=server_input_modes,
server_output_modes=server_output_modes,
negotiated_input_modes=compatible_input,
negotiated_output_modes=compatible_output,
negotiation_success=bool(compatible_input and compatible_output),
),
)
return result
def validate_content_type(
content_type: str,
allowed_modes: list[str],
) -> bool:
"""Validate that a content type is allowed by a list of modes.
Args:
content_type: The MIME type to validate.
allowed_modes: List of allowed MIME types (may include wildcards).
Returns:
True if content_type is compatible with any allowed mode.
"""
for mode in allowed_modes:
if _mime_types_compatible(content_type, mode):
return True
return False
def get_part_content_type(part: Part) -> str:
"""Extract MIME type from an A2A Part.
Args:
part: A Part object containing TextPart, DataPart, or FilePart.
Returns:
The MIME type string for this part.
"""
root = part.root
if root.kind == "text":
return TEXT_PLAIN
if root.kind == "data":
return APPLICATION_JSON
if root.kind == "file":
return root.file.mime_type or APPLICATION_OCTET_STREAM
return APPLICATION_OCTET_STREAM
def validate_message_parts(
parts: list[Part],
allowed_modes: list[str],
) -> list[str]:
"""Validate that all message parts have allowed content types.
Args:
parts: List of Parts from the incoming message.
allowed_modes: List of allowed MIME types (from default_input_modes).
Returns:
List of invalid content types found (empty if all valid).
"""
invalid_types: list[str] = []
for part in parts:
content_type = get_part_content_type(part)
if not validate_content_type(content_type, allowed_modes):
if content_type not in invalid_types:
invalid_types.append(content_type)
return invalid_types

View File

@@ -0,0 +1,980 @@
"""A2A delegation utilities for executing tasks on remote agents."""
from __future__ import annotations
import asyncio
import base64
from collections.abc import AsyncIterator, Callable, MutableMapping
from contextlib import asynccontextmanager
import logging
from typing import TYPE_CHECKING, Any, Final, Literal
import uuid
from a2a.client import Client, ClientConfig, ClientFactory
from a2a.types import (
AgentCard,
FilePart,
FileWithBytes,
Message,
Part,
PushNotificationConfig as A2APushNotificationConfig,
Role,
TextPart,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AConversationStartedEvent,
A2ADelegationCompletedEvent,
A2ADelegationStartedEvent,
A2AMessageSentEvent,
)
import httpx
from pydantic import BaseModel
from crewai_a2a.auth.client_schemes import APIKeyAuth, HTTPDigestAuth
from crewai_a2a.auth.utils import (
_auth_store,
configure_auth_client,
validate_auth_against_agent_card,
)
from crewai_a2a.config import ClientTransportConfig, GRPCClientConfig
from crewai_a2a.extensions.registry import (
ExtensionsMiddleware,
validate_required_extensions,
)
from crewai_a2a.task_helpers import TaskStateResult
from crewai_a2a.types import (
HANDLER_REGISTRY,
HandlerType,
PartsDict,
PartsMetadataDict,
TransportType,
)
from crewai_a2a.updates import (
PollingConfig,
PushNotificationConfig,
StreamingHandler,
UpdateConfig,
)
from crewai_a2a.utils.agent_card import (
_afetch_agent_card_cached,
_get_tls_verify,
_prepare_auth_headers,
)
from crewai_a2a.utils.content_type import (
DEFAULT_CLIENT_OUTPUT_MODES,
negotiate_content_types,
)
from crewai_a2a.utils.transport import (
NegotiatedTransport,
TransportNegotiationError,
negotiate_transport,
)
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from a2a.types import Message
from crewai_a2a.auth.client_schemes import ClientAuthScheme
_DEFAULT_TRANSPORT: Final[TransportType] = "JSONRPC"
def _create_file_parts(input_files: dict[str, Any] | None) -> list[Part]:
"""Convert FileInput dictionary to FilePart objects.
Args:
input_files: Dictionary mapping names to FileInput objects.
Returns:
List of Part objects containing FilePart data.
"""
if not input_files:
return []
try:
import crewai_files # noqa: F401
except ImportError:
logger.debug("crewai_files not installed, skipping file parts")
return []
parts: list[Part] = []
for name, file_input in input_files.items():
content_bytes = file_input.read()
content_base64 = base64.b64encode(content_bytes).decode()
file_with_bytes = FileWithBytes(
bytes=content_base64,
mimeType=file_input.content_type,
name=file_input.filename or name,
)
parts.append(Part(root=FilePart(file=file_with_bytes)))
return parts
def get_handler(config: UpdateConfig | None) -> HandlerType:
"""Get the handler class for a given update config.
Args:
config: Update mechanism configuration.
Returns:
Handler class for the config type, defaults to StreamingHandler.
"""
if config is None:
return StreamingHandler
return HANDLER_REGISTRY.get(type(config), StreamingHandler)
def execute_a2a_delegation(
endpoint: str,
auth: ClientAuthScheme | None,
timeout: int,
task_description: str,
context: str | None = None,
context_id: str | None = None,
task_id: str | None = None,
reference_task_ids: list[str] | None = None,
metadata: dict[str, Any] | None = None,
extensions: dict[str, Any] | None = None,
conversation_history: list[Message] | None = None,
agent_id: str | None = None,
agent_role: Role | None = None,
agent_branch: Any | None = None,
response_model: type[BaseModel] | None = None,
turn_number: int | None = None,
updates: UpdateConfig | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
skill_id: str | None = None,
client_extensions: list[str] | None = None,
transport: ClientTransportConfig | None = None,
accepted_output_modes: list[str] | None = None,
input_files: dict[str, Any] | None = None,
) -> TaskStateResult:
"""Execute a task delegation to a remote A2A agent synchronously.
WARNING: This function blocks the entire thread by creating and running a new
event loop. Prefer using 'await aexecute_a2a_delegation()' in async contexts
for better performance and resource efficiency.
This is a synchronous wrapper around aexecute_a2a_delegation that creates a
new event loop to run the async implementation. It is provided for compatibility
with synchronous code paths only.
Args:
endpoint: A2A agent endpoint URL (AgentCard URL).
auth: Optional ClientAuthScheme for authentication.
timeout: Request timeout in seconds.
task_description: The task to delegate.
context: Optional context information.
context_id: Context ID for correlating messages/tasks.
task_id: Specific task identifier.
reference_task_ids: List of related task IDs.
metadata: Additional metadata.
extensions: Protocol extensions for custom fields.
conversation_history: Previous Message objects from conversation.
agent_id: Agent identifier for logging.
agent_role: Role of the CrewAI agent delegating the task.
agent_branch: Optional agent tree branch for logging.
response_model: Optional Pydantic model for structured outputs.
turn_number: Optional turn number for multi-turn conversations.
updates: Update mechanism config from A2AConfig.updates.
from_task: Optional CrewAI Task object for event metadata.
from_agent: Optional CrewAI Agent object for event metadata.
skill_id: Optional skill ID to target a specific agent capability.
client_extensions: A2A protocol extension URIs the client supports.
transport: Transport configuration (preferred, supported transports, gRPC settings).
accepted_output_modes: MIME types the client can accept in responses.
input_files: Optional dictionary of files to send to remote agent.
Returns:
TaskStateResult with status, result/error, history, and agent_card.
Raises:
RuntimeError: If called from an async context with a running event loop.
"""
try:
asyncio.get_running_loop()
raise RuntimeError(
"execute_a2a_delegation() cannot be called from an async context. "
"Use 'await aexecute_a2a_delegation()' instead."
)
except RuntimeError as e:
if "no running event loop" not in str(e).lower():
raise
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(
aexecute_a2a_delegation(
endpoint=endpoint,
auth=auth,
timeout=timeout,
task_description=task_description,
context=context,
context_id=context_id,
task_id=task_id,
reference_task_ids=reference_task_ids,
metadata=metadata,
extensions=extensions,
conversation_history=conversation_history,
agent_id=agent_id,
agent_role=agent_role,
agent_branch=agent_branch,
response_model=response_model,
turn_number=turn_number,
updates=updates,
from_task=from_task,
from_agent=from_agent,
skill_id=skill_id,
client_extensions=client_extensions,
transport=transport,
accepted_output_modes=accepted_output_modes,
input_files=input_files,
)
)
finally:
try:
loop.run_until_complete(loop.shutdown_asyncgens())
finally:
loop.close()
async def aexecute_a2a_delegation(
endpoint: str,
auth: ClientAuthScheme | None,
timeout: int,
task_description: str,
context: str | None = None,
context_id: str | None = None,
task_id: str | None = None,
reference_task_ids: list[str] | None = None,
metadata: dict[str, Any] | None = None,
extensions: dict[str, Any] | None = None,
conversation_history: list[Message] | None = None,
agent_id: str | None = None,
agent_role: Role | None = None,
agent_branch: Any | None = None,
response_model: type[BaseModel] | None = None,
turn_number: int | None = None,
updates: UpdateConfig | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
skill_id: str | None = None,
client_extensions: list[str] | None = None,
transport: ClientTransportConfig | None = None,
accepted_output_modes: list[str] | None = None,
input_files: dict[str, Any] | None = None,
) -> TaskStateResult:
"""Execute a task delegation to a remote A2A agent asynchronously.
Native async implementation with multi-turn support. Use this when running
in an async context (e.g., with Crew.akickoff() or agent.aexecute_task()).
Args:
endpoint: A2A agent endpoint URL.
auth: Optional ClientAuthScheme for authentication.
timeout: Request timeout in seconds.
task_description: The task to delegate.
context: Optional context information.
context_id: Context ID for correlating messages/tasks.
task_id: Specific task identifier.
reference_task_ids: List of related task IDs.
metadata: Additional metadata.
extensions: Protocol extensions for custom fields.
conversation_history: Previous Message objects from conversation.
agent_id: Agent identifier for logging.
agent_role: Role of the CrewAI agent delegating the task.
agent_branch: Optional agent tree branch for logging.
response_model: Optional Pydantic model for structured outputs.
turn_number: Optional turn number for multi-turn conversations.
updates: Update mechanism config from A2AConfig.updates.
from_task: Optional CrewAI Task object for event metadata.
from_agent: Optional CrewAI Agent object for event metadata.
skill_id: Optional skill ID to target a specific agent capability.
client_extensions: A2A protocol extension URIs the client supports.
transport: Transport configuration (preferred, supported transports, gRPC settings).
accepted_output_modes: MIME types the client can accept in responses.
input_files: Optional dictionary of files to send to remote agent.
Returns:
TaskStateResult with status, result/error, history, and agent_card.
"""
if conversation_history is None:
conversation_history = []
is_multiturn = len(conversation_history) > 0
if turn_number is None:
turn_number = len([m for m in conversation_history if m.role == Role.user]) + 1
try:
result = await _aexecute_a2a_delegation_impl(
endpoint=endpoint,
auth=auth,
timeout=timeout,
task_description=task_description,
context=context,
context_id=context_id,
task_id=task_id,
reference_task_ids=reference_task_ids,
metadata=metadata,
extensions=extensions,
conversation_history=conversation_history,
is_multiturn=is_multiturn,
turn_number=turn_number,
agent_branch=agent_branch,
agent_id=agent_id,
agent_role=agent_role,
response_model=response_model,
updates=updates,
from_task=from_task,
from_agent=from_agent,
skill_id=skill_id,
client_extensions=client_extensions,
transport=transport,
accepted_output_modes=accepted_output_modes,
input_files=input_files,
)
except Exception as e:
crewai_event_bus.emit(
agent_branch,
A2ADelegationCompletedEvent(
status="failed",
result=None,
error=str(e),
context_id=context_id,
is_multiturn=is_multiturn,
endpoint=endpoint,
metadata=metadata,
extensions=list(extensions.keys()) if extensions else None,
from_task=from_task,
from_agent=from_agent,
),
)
raise
agent_card_data = result.get("agent_card")
crewai_event_bus.emit(
agent_branch,
A2ADelegationCompletedEvent(
status=result["status"],
result=result.get("result"),
error=result.get("error"),
context_id=context_id,
is_multiturn=is_multiturn,
endpoint=endpoint,
a2a_agent_name=result.get("a2a_agent_name"),
agent_card=agent_card_data,
provider=agent_card_data.get("provider") if agent_card_data else None,
metadata=metadata,
extensions=list(extensions.keys()) if extensions else None,
from_task=from_task,
from_agent=from_agent,
),
)
return result
async def _aexecute_a2a_delegation_impl(
endpoint: str,
auth: ClientAuthScheme | None,
timeout: int,
task_description: str,
context: str | None,
context_id: str | None,
task_id: str | None,
reference_task_ids: list[str] | None,
metadata: dict[str, Any] | None,
extensions: dict[str, Any] | None,
conversation_history: list[Message],
is_multiturn: bool,
turn_number: int,
agent_branch: Any | None,
agent_id: str | None,
agent_role: str | None,
response_model: type[BaseModel] | None,
updates: UpdateConfig | None,
from_task: Any | None = None,
from_agent: Any | None = None,
skill_id: str | None = None,
client_extensions: list[str] | None = None,
transport: ClientTransportConfig | None = None,
accepted_output_modes: list[str] | None = None,
input_files: dict[str, Any] | None = None,
) -> TaskStateResult:
"""Internal async implementation of A2A delegation."""
if transport is None:
transport = ClientTransportConfig()
if auth:
auth_data = auth.model_dump_json(
exclude={
"_access_token",
"_token_expires_at",
"_refresh_token",
"_authorization_callback",
}
)
auth_hash = _auth_store.compute_key(type(auth).__name__, auth_data)
else:
auth_hash = _auth_store.compute_key("none", endpoint)
_auth_store.set(auth_hash, auth)
agent_card = await _afetch_agent_card_cached(
endpoint=endpoint, auth_hash=auth_hash, timeout=timeout
)
validate_auth_against_agent_card(agent_card, auth)
unsupported_exts = validate_required_extensions(agent_card, client_extensions)
if unsupported_exts:
ext_uris = [ext.uri for ext in unsupported_exts]
raise ValueError(
f"Agent requires extensions not supported by client: {ext_uris}"
)
negotiated: NegotiatedTransport | None = None
effective_transport: TransportType = transport.preferred or _DEFAULT_TRANSPORT
effective_url = endpoint
client_transports: list[str] = (
list(transport.supported) if transport.supported else [_DEFAULT_TRANSPORT]
)
try:
negotiated = negotiate_transport(
agent_card=agent_card,
client_supported_transports=client_transports,
client_preferred_transport=transport.preferred,
endpoint=endpoint,
a2a_agent_name=agent_card.name,
)
effective_transport = negotiated.transport # type: ignore[assignment]
effective_url = negotiated.url
except TransportNegotiationError as e:
logger.warning(
"Transport negotiation failed, using fallback",
extra={
"error": str(e),
"fallback_transport": effective_transport,
"fallback_url": effective_url,
"endpoint": endpoint,
"client_transports": client_transports,
"server_transports": [
iface.transport for iface in agent_card.additional_interfaces or []
]
+ [agent_card.preferred_transport or "JSONRPC"],
},
)
effective_output_modes = accepted_output_modes or DEFAULT_CLIENT_OUTPUT_MODES.copy()
content_negotiated = negotiate_content_types(
agent_card=agent_card,
client_output_modes=accepted_output_modes,
skill_name=skill_id,
endpoint=endpoint,
a2a_agent_name=agent_card.name,
)
if content_negotiated.output_modes:
effective_output_modes = content_negotiated.output_modes
headers, _ = await _prepare_auth_headers(auth, timeout)
a2a_agent_name = None
if agent_card.name:
a2a_agent_name = agent_card.name
agent_card_dict = agent_card.model_dump(exclude_none=True)
crewai_event_bus.emit(
agent_branch,
A2ADelegationStartedEvent(
endpoint=endpoint,
task_description=task_description,
agent_id=agent_id or endpoint,
context_id=context_id,
is_multiturn=is_multiturn,
turn_number=turn_number,
a2a_agent_name=a2a_agent_name,
agent_card=agent_card_dict,
protocol_version=agent_card.protocol_version,
provider=agent_card_dict.get("provider"),
skill_id=skill_id,
metadata=metadata,
extensions=list(extensions.keys()) if extensions else None,
from_task=from_task,
from_agent=from_agent,
),
)
if turn_number == 1:
agent_id_for_event = agent_id or endpoint
crewai_event_bus.emit(
agent_branch,
A2AConversationStartedEvent(
agent_id=agent_id_for_event,
endpoint=endpoint,
context_id=context_id,
a2a_agent_name=a2a_agent_name,
agent_card=agent_card_dict,
protocol_version=agent_card.protocol_version,
provider=agent_card_dict.get("provider"),
skill_id=skill_id,
reference_task_ids=reference_task_ids,
metadata=metadata,
extensions=list(extensions.keys()) if extensions else None,
from_task=from_task,
from_agent=from_agent,
),
)
message_parts = []
if context:
message_parts.append(f"Context:\n{context}\n\n")
message_parts.append(f"{task_description}")
message_text = "".join(message_parts)
if is_multiturn and conversation_history and not task_id:
if first_task_id := conversation_history[0].task_id:
task_id = first_task_id
parts: PartsDict = {"text": message_text}
if response_model:
parts.update(
{
"metadata": PartsMetadataDict(
mimeType="application/json",
schema=response_model.model_json_schema(),
)
}
)
message_metadata = metadata.copy() if metadata else {}
if skill_id:
message_metadata["skill_id"] = skill_id
parts_list: list[Part] = [Part(root=TextPart(**parts))]
parts_list.extend(_create_file_parts(input_files))
message = Message(
role=Role.user,
message_id=str(uuid.uuid4()),
parts=parts_list,
context_id=context_id,
task_id=task_id,
reference_task_ids=reference_task_ids,
metadata=message_metadata if message_metadata else None,
extensions=extensions,
)
new_messages: list[Message] = [*conversation_history, message]
crewai_event_bus.emit(
None,
A2AMessageSentEvent(
message=message_text,
turn_number=turn_number,
context_id=context_id,
message_id=message.message_id,
is_multiturn=is_multiturn,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
skill_id=skill_id,
metadata=message_metadata if message_metadata else None,
extensions=list(extensions.keys()) if extensions else None,
from_task=from_task,
from_agent=from_agent,
),
)
handler = get_handler(updates)
use_polling = isinstance(updates, PollingConfig)
handler_kwargs: dict[str, Any] = {
"turn_number": turn_number,
"is_multiturn": is_multiturn,
"agent_role": agent_role,
"context_id": context_id,
"task_id": task_id,
"endpoint": endpoint,
"agent_branch": agent_branch,
"a2a_agent_name": a2a_agent_name,
"from_task": from_task,
"from_agent": from_agent,
}
if isinstance(updates, PollingConfig):
handler_kwargs.update(
{
"polling_interval": updates.interval,
"polling_timeout": updates.timeout or float(timeout),
"history_length": updates.history_length,
"max_polls": updates.max_polls,
}
)
elif isinstance(updates, PushNotificationConfig):
handler_kwargs.update(
{
"config": updates,
"result_store": updates.result_store,
"polling_timeout": updates.timeout or float(timeout),
"polling_interval": updates.interval,
}
)
push_config_for_client = (
updates if isinstance(updates, PushNotificationConfig) else None
)
use_streaming = not use_polling and push_config_for_client is None
client_agent_card = agent_card
if effective_url != agent_card.url:
client_agent_card = agent_card.model_copy(update={"url": effective_url})
async with _create_a2a_client(
agent_card=client_agent_card,
transport_protocol=effective_transport,
timeout=timeout,
headers=headers,
streaming=use_streaming,
auth=auth,
use_polling=use_polling,
push_notification_config=push_config_for_client,
client_extensions=client_extensions,
accepted_output_modes=effective_output_modes, # type: ignore[arg-type]
grpc_config=transport.grpc,
) as client:
result = await handler.execute(
client=client,
message=message,
new_messages=new_messages,
agent_card=agent_card,
**handler_kwargs,
)
result["a2a_agent_name"] = a2a_agent_name
result["agent_card"] = agent_card.model_dump(exclude_none=True)
return result
def _normalize_grpc_metadata(
metadata: tuple[tuple[str, str], ...] | None,
) -> tuple[tuple[str, str], ...] | None:
"""Lowercase all gRPC metadata keys.
gRPC requires lowercase metadata keys, but some libraries (like the A2A SDK)
use mixed-case headers like 'X-A2A-Extensions'. This normalizes them.
"""
if metadata is None:
return None
return tuple((key.lower(), value) for key, value in metadata)
def _create_grpc_interceptors(
auth_metadata: list[tuple[str, str]] | None = None,
) -> list[Any]:
"""Create gRPC interceptors for metadata normalization and auth injection.
Args:
auth_metadata: Optional auth metadata to inject into all calls.
Used for insecure channels that need auth (non-localhost without TLS).
Returns a list of interceptors that lowercase metadata keys for gRPC
compatibility. Must be called after grpc is imported.
"""
import grpc.aio # type: ignore[import-untyped]
def _merge_metadata(
existing: tuple[tuple[str, str], ...] | None,
auth: list[tuple[str, str]] | None,
) -> tuple[tuple[str, str], ...] | None:
"""Merge existing metadata with auth metadata and normalize keys."""
merged: list[tuple[str, str]] = []
if existing:
merged.extend(existing)
if auth:
merged.extend(auth)
if not merged:
return None
return tuple((key.lower(), value) for key, value in merged)
def _inject_metadata(client_call_details: Any) -> Any:
"""Inject merged metadata into call details."""
return client_call_details._replace(
metadata=_merge_metadata(client_call_details.metadata, auth_metadata)
)
class MetadataUnaryUnary(grpc.aio.UnaryUnaryClientInterceptor): # type: ignore[misc,no-any-unimported]
"""Interceptor for unary-unary calls that injects auth metadata."""
async def intercept_unary_unary( # type: ignore[no-untyped-def]
self, continuation, client_call_details, request
):
"""Intercept unary-unary call and inject metadata."""
return await continuation(_inject_metadata(client_call_details), request)
class MetadataUnaryStream(grpc.aio.UnaryStreamClientInterceptor): # type: ignore[misc,no-any-unimported]
"""Interceptor for unary-stream calls that injects auth metadata."""
async def intercept_unary_stream( # type: ignore[no-untyped-def]
self, continuation, client_call_details, request
):
"""Intercept unary-stream call and inject metadata."""
return await continuation(_inject_metadata(client_call_details), request)
class MetadataStreamUnary(grpc.aio.StreamUnaryClientInterceptor): # type: ignore[misc,no-any-unimported]
"""Interceptor for stream-unary calls that injects auth metadata."""
async def intercept_stream_unary( # type: ignore[no-untyped-def]
self, continuation, client_call_details, request_iterator
):
"""Intercept stream-unary call and inject metadata."""
return await continuation(
_inject_metadata(client_call_details), request_iterator
)
class MetadataStreamStream(grpc.aio.StreamStreamClientInterceptor): # type: ignore[misc,no-any-unimported]
"""Interceptor for stream-stream calls that injects auth metadata."""
async def intercept_stream_stream( # type: ignore[no-untyped-def]
self, continuation, client_call_details, request_iterator
):
"""Intercept stream-stream call and inject metadata."""
return await continuation(
_inject_metadata(client_call_details), request_iterator
)
return [
MetadataUnaryUnary(),
MetadataUnaryStream(),
MetadataStreamUnary(),
MetadataStreamStream(),
]
def _create_grpc_channel_factory(
grpc_config: GRPCClientConfig,
auth: ClientAuthScheme | None = None,
) -> Callable[[str], Any]:
"""Create a gRPC channel factory with the given configuration.
Args:
grpc_config: gRPC client configuration with channel options.
auth: Optional ClientAuthScheme for TLS and auth configuration.
Returns:
A callable that creates gRPC channels from URLs.
"""
try:
import grpc
except ImportError as e:
raise ImportError(
"gRPC transport requires grpcio. Install with: pip install a2a-sdk[grpc]"
) from e
auth_metadata: list[tuple[str, str]] = []
if auth is not None:
from crewai_a2a.auth.client_schemes import (
APIKeyAuth,
BearerTokenAuth,
HTTPBasicAuth,
HTTPDigestAuth,
OAuth2AuthorizationCode,
OAuth2ClientCredentials,
)
if isinstance(auth, HTTPDigestAuth):
raise ValueError(
"HTTPDigestAuth is not supported with gRPC transport. "
"Digest authentication requires HTTP challenge-response flow. "
"Use BearerTokenAuth, HTTPBasicAuth, APIKeyAuth (header), or OAuth2 instead."
)
if isinstance(auth, APIKeyAuth) and auth.location in ("query", "cookie"):
raise ValueError(
f"APIKeyAuth with location='{auth.location}' is not supported with gRPC transport. "
"gRPC only supports header-based authentication. "
"Use APIKeyAuth with location='header' instead."
)
if isinstance(auth, BearerTokenAuth):
auth_metadata.append(("authorization", f"Bearer {auth.token}"))
elif isinstance(auth, HTTPBasicAuth):
import base64
basic_credentials = f"{auth.username}:{auth.password}"
encoded = base64.b64encode(basic_credentials.encode()).decode()
auth_metadata.append(("authorization", f"Basic {encoded}"))
elif isinstance(auth, APIKeyAuth) and auth.location == "header":
header_name = auth.name.lower()
auth_metadata.append((header_name, auth.api_key))
elif isinstance(auth, (OAuth2ClientCredentials, OAuth2AuthorizationCode)):
if auth._access_token:
auth_metadata.append(("authorization", f"Bearer {auth._access_token}"))
def factory(url: str) -> Any:
"""Create a gRPC channel for the given URL."""
target = url
use_tls = False
if url.startswith("grpcs://"):
target = url[8:]
use_tls = True
elif url.startswith("grpc://"):
target = url[7:]
elif url.startswith("https://"):
target = url[8:]
use_tls = True
elif url.startswith("http://"):
target = url[7:]
options: list[tuple[str, Any]] = []
if grpc_config.max_send_message_length is not None:
options.append(
("grpc.max_send_message_length", grpc_config.max_send_message_length)
)
if grpc_config.max_receive_message_length is not None:
options.append(
(
"grpc.max_receive_message_length",
grpc_config.max_receive_message_length,
)
)
if grpc_config.keepalive_time_ms is not None:
options.append(("grpc.keepalive_time_ms", grpc_config.keepalive_time_ms))
if grpc_config.keepalive_timeout_ms is not None:
options.append(
("grpc.keepalive_timeout_ms", grpc_config.keepalive_timeout_ms)
)
channel_credentials = None
if auth and hasattr(auth, "tls") and auth.tls:
channel_credentials = auth.tls.get_grpc_credentials()
elif use_tls:
channel_credentials = grpc.ssl_channel_credentials()
if channel_credentials and auth_metadata:
class AuthMetadataPlugin(grpc.AuthMetadataPlugin): # type: ignore[misc,no-any-unimported]
"""gRPC auth metadata plugin that adds auth headers as metadata."""
def __init__(self, metadata: list[tuple[str, str]]) -> None:
self._metadata = tuple(metadata)
def __call__( # type: ignore[no-any-unimported]
self,
context: grpc.AuthMetadataContext,
callback: grpc.AuthMetadataPluginCallback,
) -> None:
callback(self._metadata, None)
call_creds = grpc.metadata_call_credentials(
AuthMetadataPlugin(auth_metadata)
)
credentials = grpc.composite_channel_credentials(
channel_credentials, call_creds
)
interceptors = _create_grpc_interceptors()
return grpc.aio.secure_channel(
target, credentials, options=options or None, interceptors=interceptors
)
if channel_credentials:
interceptors = _create_grpc_interceptors()
return grpc.aio.secure_channel(
target,
channel_credentials,
options=options or None,
interceptors=interceptors,
)
interceptors = _create_grpc_interceptors(
auth_metadata=auth_metadata if auth_metadata else None
)
return grpc.aio.insecure_channel(
target, options=options or None, interceptors=interceptors
)
return factory
@asynccontextmanager
async def _create_a2a_client(
agent_card: AgentCard,
transport_protocol: Literal["JSONRPC", "GRPC", "HTTP+JSON"],
timeout: int,
headers: MutableMapping[str, str],
streaming: bool,
auth: ClientAuthScheme | None = None,
use_polling: bool = False,
push_notification_config: PushNotificationConfig | None = None,
client_extensions: list[str] | None = None,
accepted_output_modes: list[str] | None = None,
grpc_config: GRPCClientConfig | None = None,
) -> AsyncIterator[Client]:
"""Create and configure an A2A client.
Args:
agent_card: The A2A agent card.
transport_protocol: Transport protocol to use.
timeout: Request timeout in seconds.
headers: HTTP headers (already with auth applied).
streaming: Enable streaming responses.
auth: Optional ClientAuthScheme for client configuration.
use_polling: Enable polling mode.
push_notification_config: Optional push notification config.
client_extensions: A2A protocol extension URIs to declare support for.
accepted_output_modes: MIME types the client can accept in responses.
grpc_config: Optional gRPC client configuration.
Yields:
Configured A2A client instance.
"""
verify = _get_tls_verify(auth)
async with httpx.AsyncClient(
timeout=timeout,
headers=headers,
verify=verify,
) as httpx_client:
if auth and isinstance(auth, (HTTPDigestAuth, APIKeyAuth)):
configure_auth_client(auth, httpx_client)
push_configs: list[A2APushNotificationConfig] = []
if push_notification_config is not None:
push_configs.append(
A2APushNotificationConfig(
url=str(push_notification_config.url),
id=push_notification_config.id,
token=push_notification_config.token,
authentication=push_notification_config.authentication,
)
)
grpc_channel_factory = None
if transport_protocol == "GRPC":
grpc_channel_factory = _create_grpc_channel_factory(
grpc_config or GRPCClientConfig(),
auth=auth,
)
config = ClientConfig(
httpx_client=httpx_client,
supported_transports=[transport_protocol],
streaming=streaming and not use_polling,
polling=use_polling,
accepted_output_modes=accepted_output_modes or DEFAULT_CLIENT_OUTPUT_MODES, # type: ignore[arg-type]
push_notification_configs=push_configs,
grpc_channel_factory=grpc_channel_factory,
)
factory = ClientFactory(config)
client = factory.create(agent_card)
if client_extensions:
await client.add_request_middleware(ExtensionsMiddleware(client_extensions))
yield client

View File

@@ -0,0 +1,131 @@
"""Structured JSON logging utilities for A2A module."""
from __future__ import annotations
from contextvars import ContextVar
from datetime import datetime, timezone
import json
import logging
from typing import Any
_log_context: ContextVar[dict[str, Any] | None] = ContextVar(
"log_context", default=None
)
class JSONFormatter(logging.Formatter):
"""JSON formatter for structured logging.
Outputs logs as JSON with consistent fields for log aggregators.
"""
def format(self, record: logging.LogRecord) -> str:
"""Format log record as JSON string."""
log_data: dict[str, Any] = {
"timestamp": datetime.now(timezone.utc).isoformat(),
"level": record.levelname,
"logger": record.name,
"message": record.getMessage(),
}
if record.exc_info:
log_data["exception"] = self.formatException(record.exc_info)
context = _log_context.get()
if context is not None:
log_data.update(context)
if hasattr(record, "task_id"):
log_data["task_id"] = record.task_id
if hasattr(record, "context_id"):
log_data["context_id"] = record.context_id
if hasattr(record, "agent"):
log_data["agent"] = record.agent
if hasattr(record, "endpoint"):
log_data["endpoint"] = record.endpoint
if hasattr(record, "extension"):
log_data["extension"] = record.extension
if hasattr(record, "error"):
log_data["error"] = record.error
for key, value in record.__dict__.items():
if key.startswith("_") or key in (
"name",
"msg",
"args",
"created",
"filename",
"funcName",
"levelname",
"levelno",
"lineno",
"module",
"msecs",
"pathname",
"process",
"processName",
"relativeCreated",
"stack_info",
"exc_info",
"exc_text",
"thread",
"threadName",
"taskName",
"message",
):
continue
if key not in log_data:
log_data[key] = value
return json.dumps(log_data, default=str)
class LogContext:
"""Context manager for adding fields to all logs within a scope.
Example:
with LogContext(task_id="abc", context_id="xyz"):
logger.info("Processing task") # Includes task_id and context_id
"""
def __init__(self, **fields: Any) -> None:
self._fields = fields
self._token: Any = None
def __enter__(self) -> LogContext:
current = _log_context.get() or {}
new_context = {**current, **self._fields}
self._token = _log_context.set(new_context)
return self
def __exit__(self, *args: Any) -> None:
_log_context.reset(self._token)
def configure_json_logging(logger_name: str = "crewai.a2a") -> None:
"""Configure JSON logging for the A2A module.
Args:
logger_name: Logger name to configure.
"""
logger = logging.getLogger(logger_name)
for handler in logger.handlers[:]:
logger.removeHandler(handler)
handler = logging.StreamHandler()
handler.setFormatter(JSONFormatter())
logger.addHandler(handler)
def get_logger(name: str) -> logging.Logger:
"""Get a logger configured for structured JSON output.
Args:
name: Logger name.
Returns:
Configured logger instance.
"""
return logging.getLogger(name)

View File

@@ -0,0 +1,101 @@
"""Response model utilities for A2A agent interactions."""
from __future__ import annotations
from typing import TypeAlias
from crewai.types.utils import create_literals_from_strings
from pydantic import BaseModel, Field, create_model
from crewai_a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
A2AConfigTypes: TypeAlias = A2AConfig | A2AServerConfig | A2AClientConfig
A2AClientConfigTypes: TypeAlias = A2AConfig | A2AClientConfig
def create_agent_response_model(agent_ids: tuple[str, ...]) -> type[BaseModel] | None:
"""Create a dynamic AgentResponse model with Literal types for agent IDs.
Args:
agent_ids: List of available A2A agent IDs.
Returns:
Dynamically created Pydantic model with Literal-constrained a2a_ids field,
or None if agent_ids is empty.
"""
if not agent_ids:
return None
DynamicLiteral = create_literals_from_strings(agent_ids) # noqa: N806
return create_model(
"AgentResponse",
a2a_ids=(
tuple[DynamicLiteral, ...], # type: ignore[valid-type]
Field(
default_factory=tuple,
max_length=len(agent_ids),
description="A2A agent IDs to delegate to.",
),
),
message=(
str,
Field(
description="The message content. If is_a2a=true, this is sent to the A2A agent. If is_a2a=false, this is your final answer ending the conversation."
),
),
is_a2a=(
bool,
Field(
description="Set to false when the remote agent has answered your question - extract their answer and return it as your final message. Set to true ONLY if you need to ask a NEW, DIFFERENT question. NEVER repeat the same request - if the conversation history shows the agent already answered, set is_a2a=false immediately."
),
),
__base__=BaseModel,
)
def extract_a2a_agent_ids_from_config(
a2a_config: list[A2AConfigTypes] | A2AConfigTypes | None,
) -> tuple[list[A2AClientConfigTypes], tuple[str, ...]]:
"""Extract A2A agent IDs from A2A configuration.
Filters out A2AServerConfig since it doesn't have an endpoint for delegation.
Args:
a2a_config: A2A configuration (any type).
Returns:
Tuple of client A2A configs list and agent endpoint IDs.
"""
if a2a_config is None:
return [], ()
configs: list[A2AConfigTypes]
if isinstance(a2a_config, (A2AConfig, A2AClientConfig, A2AServerConfig)):
configs = [a2a_config]
else:
configs = a2a_config
# Filter to only client configs (those with endpoint)
client_configs: list[A2AClientConfigTypes] = [
config for config in configs if isinstance(config, (A2AConfig, A2AClientConfig))
]
return client_configs, tuple(config.endpoint for config in client_configs)
def get_a2a_agents_and_response_model(
a2a_config: list[A2AConfigTypes] | A2AConfigTypes | None,
) -> tuple[list[A2AClientConfigTypes], type[BaseModel] | None]:
"""Get A2A agent configs and response model.
Args:
a2a_config: A2A configuration (any type).
Returns:
Tuple of client A2A configs and response model.
"""
a2a_agents, agent_ids = extract_a2a_agent_ids_from_config(a2a_config=a2a_config)
return a2a_agents, create_agent_response_model(agent_ids)

View File

@@ -0,0 +1,585 @@
"""A2A task utilities for server-side task management."""
from __future__ import annotations
import asyncio
import base64
from collections.abc import Callable, Coroutine
from datetime import datetime
from functools import wraps
import json
import logging
import os
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, TypedDict, cast
from urllib.parse import urlparse
from a2a.server.agent_execution import RequestContext
from a2a.server.events import EventQueue
from a2a.types import (
Artifact,
FileWithBytes,
FileWithUri,
InternalError,
InvalidParamsError,
Message,
Part,
Task as A2ATask,
TaskState,
TaskStatus,
TaskStatusUpdateEvent,
)
from a2a.utils import (
get_data_parts,
get_file_parts,
new_agent_text_message,
new_data_artifact,
new_text_artifact,
)
from a2a.utils.errors import ServerError
from aiocache import SimpleMemoryCache, caches # type: ignore[import-untyped]
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AServerTaskCanceledEvent,
A2AServerTaskCompletedEvent,
A2AServerTaskFailedEvent,
A2AServerTaskStartedEvent,
)
from crewai.task import Task
from crewai.utilities.pydantic_schema_utils import create_model_from_schema
from pydantic import BaseModel
from crewai_a2a.utils.agent_card import _get_server_config
from crewai_a2a.utils.content_type import validate_message_parts
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai_a2a.extensions.server import ExtensionContext, ServerExtensionRegistry
logger = logging.getLogger(__name__)
P = ParamSpec("P")
T = TypeVar("T")
class RedisCacheConfig(TypedDict, total=False):
"""Configuration for aiocache Redis backend."""
cache: str
endpoint: str
port: int
db: int
password: str
def _parse_redis_url(url: str) -> RedisCacheConfig:
"""Parse a Redis URL into aiocache configuration.
Args:
url: Redis connection URL (e.g., redis://localhost:6379/0).
Returns:
Configuration dict for aiocache.RedisCache.
"""
parsed = urlparse(url)
config: RedisCacheConfig = {
"cache": "aiocache.RedisCache",
"endpoint": parsed.hostname or "localhost",
"port": parsed.port or 6379,
}
if parsed.path and parsed.path != "/":
try:
config["db"] = int(parsed.path.lstrip("/"))
except ValueError:
pass
if parsed.password:
config["password"] = parsed.password
return config
_redis_url = os.environ.get("REDIS_URL")
caches.set_config(
{
"default": _parse_redis_url(_redis_url)
if _redis_url
else {
"cache": "aiocache.SimpleMemoryCache",
}
}
)
def cancellable(
fn: Callable[P, Coroutine[Any, Any, T]],
) -> Callable[P, Coroutine[Any, Any, T]]:
"""Decorator that enables cancellation for A2A task execution.
Runs a cancellation watcher concurrently with the wrapped function.
When a cancel event is published, the execution is cancelled.
Args:
fn: The async function to wrap.
Returns:
Wrapped function with cancellation support.
"""
@wraps(fn)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
"""Wrap function with cancellation monitoring."""
context: RequestContext | None = None
for arg in args:
if isinstance(arg, RequestContext):
context = arg
break
if context is None:
context = cast(RequestContext | None, kwargs.get("context"))
if context is None:
return await fn(*args, **kwargs)
task_id = context.task_id
cache = caches.get("default")
async def poll_for_cancel() -> bool:
"""Poll cache for cancellation flag."""
while True:
if await cache.get(f"cancel:{task_id}"):
return True
await asyncio.sleep(0.1)
async def watch_for_cancel() -> bool:
"""Watch for cancellation events via pub/sub or polling."""
if isinstance(cache, SimpleMemoryCache):
return await poll_for_cancel()
try:
client = cache.client
pubsub = client.pubsub()
await pubsub.subscribe(f"cancel:{task_id}")
async for message in pubsub.listen():
if message["type"] == "message":
return True
except (OSError, ConnectionError) as e:
logger.warning(
"Cancel watcher Redis error, falling back to polling",
extra={"task_id": task_id, "error": str(e)},
)
return await poll_for_cancel()
return False
execute_task = asyncio.create_task(fn(*args, **kwargs))
cancel_watch = asyncio.create_task(watch_for_cancel())
try:
done, _ = await asyncio.wait(
[execute_task, cancel_watch],
return_when=asyncio.FIRST_COMPLETED,
)
if cancel_watch in done:
execute_task.cancel()
try:
await execute_task
except asyncio.CancelledError:
pass
raise asyncio.CancelledError(f"Task {task_id} was cancelled")
cancel_watch.cancel()
return execute_task.result()
finally:
await cache.delete(f"cancel:{task_id}")
return wrapper
def _convert_a2a_files_to_file_inputs(
a2a_files: list[FileWithBytes | FileWithUri],
) -> dict[str, Any]:
"""Convert a2a file types to crewai FileInput dict.
Args:
a2a_files: List of FileWithBytes or FileWithUri from a2a SDK.
Returns:
Dictionary mapping file names to FileInput objects.
"""
try:
from crewai_files import File, FileBytes
except ImportError:
logger.debug("crewai_files not installed, returning empty file dict")
return {}
file_dict: dict[str, Any] = {}
for idx, a2a_file in enumerate(a2a_files):
if isinstance(a2a_file, FileWithBytes):
file_bytes = base64.b64decode(a2a_file.bytes)
name = a2a_file.name or f"file_{idx}"
file_source = FileBytes(data=file_bytes, filename=a2a_file.name)
file_dict[name] = File(source=file_source)
elif isinstance(a2a_file, FileWithUri):
name = a2a_file.name or f"file_{idx}"
file_dict[name] = File(source=a2a_file.uri)
return file_dict
def _extract_response_schema(parts: list[Part]) -> dict[str, Any] | None:
"""Extract response schema from message parts metadata.
The client may include a JSON schema in TextPart metadata to specify
the expected response format (see delegation.py line 463).
Args:
parts: List of message parts.
Returns:
JSON schema dict if found, None otherwise.
"""
for part in parts:
if part.root.kind == "text" and part.root.metadata:
schema = part.root.metadata.get("schema")
if schema and isinstance(schema, dict):
return schema # type: ignore[no-any-return]
return None
def _create_result_artifact(
result: Any,
task_id: str,
) -> Artifact:
"""Create artifact from task result, using DataPart for structured data.
Args:
result: The task execution result.
task_id: The task ID for naming the artifact.
Returns:
Artifact with appropriate part type (DataPart for dict/Pydantic, TextPart for strings).
"""
artifact_name = f"result_{task_id}"
if isinstance(result, dict):
return new_data_artifact(artifact_name, result)
if isinstance(result, BaseModel):
return new_data_artifact(artifact_name, result.model_dump())
return new_text_artifact(artifact_name, str(result))
def _build_task_description(
user_message: str,
structured_inputs: list[dict[str, Any]],
) -> str:
"""Build task description including structured data if present.
Args:
user_message: The original user message text.
structured_inputs: List of structured data from DataParts.
Returns:
Task description with structured data appended if present.
"""
if not structured_inputs:
return user_message
structured_json = json.dumps(structured_inputs, indent=2)
return f"{user_message}\n\nStructured Data:\n{structured_json}"
async def execute(
agent: Agent,
context: RequestContext,
event_queue: EventQueue,
) -> None:
"""Execute an A2A task using a CrewAI agent.
Args:
agent: The CrewAI agent to execute the task.
context: The A2A request context containing the user's message.
event_queue: The event queue for sending responses back.
"""
await _execute_impl(agent, context, event_queue, None, None)
@cancellable
async def _execute_impl(
agent: Agent,
context: RequestContext,
event_queue: EventQueue,
extension_registry: ServerExtensionRegistry | None,
extension_context: ExtensionContext | None,
) -> None:
"""Internal implementation for task execution with optional extensions."""
server_config = _get_server_config(agent)
if context.message and context.message.parts and server_config:
allowed_modes = server_config.default_input_modes
invalid_types = validate_message_parts(context.message.parts, allowed_modes)
if invalid_types:
raise ServerError(
InvalidParamsError(
message=f"Unsupported content type(s): {', '.join(invalid_types)}. "
f"Supported: {', '.join(allowed_modes)}"
)
)
if extension_registry and extension_context:
await extension_registry.invoke_on_request(extension_context)
user_message = context.get_user_input()
response_model: type[BaseModel] | None = None
structured_inputs: list[dict[str, Any]] = []
a2a_files: list[FileWithBytes | FileWithUri] = []
if context.message and context.message.parts:
schema = _extract_response_schema(context.message.parts)
if schema:
try:
response_model = create_model_from_schema(schema)
except Exception as e:
logger.debug(
"Failed to create response model from schema",
extra={"error": str(e), "schema_title": schema.get("title")},
)
structured_inputs = get_data_parts(context.message.parts)
a2a_files = get_file_parts(context.message.parts)
task_id = context.task_id
context_id = context.context_id
if task_id is None or context_id is None:
msg = "task_id and context_id are required"
crewai_event_bus.emit(
agent,
A2AServerTaskFailedEvent(
task_id="",
context_id="",
error=msg,
from_agent=agent,
),
)
raise ServerError(InvalidParamsError(message=msg)) from None
task = Task(
description=_build_task_description(user_message, structured_inputs),
expected_output="Response to the user's request",
agent=agent,
response_model=response_model,
input_files=_convert_a2a_files_to_file_inputs(a2a_files),
)
crewai_event_bus.emit(
agent,
A2AServerTaskStartedEvent(
task_id=task_id,
context_id=context_id,
from_task=task,
from_agent=agent,
),
)
try:
result = await agent.aexecute_task(task=task, tools=agent.tools)
if extension_registry and extension_context:
result = await extension_registry.invoke_on_response(
extension_context, result
)
result_str = str(result)
history: list[Message] = [context.message] if context.message else []
history.append(new_agent_text_message(result_str, context_id, task_id))
await event_queue.enqueue_event(
A2ATask(
id=task_id,
context_id=context_id,
status=TaskStatus(state=TaskState.completed),
artifacts=[_create_result_artifact(result, task_id)],
history=history,
)
)
crewai_event_bus.emit(
agent,
A2AServerTaskCompletedEvent(
task_id=task_id,
context_id=context_id,
result=str(result),
from_task=task,
from_agent=agent,
),
)
except asyncio.CancelledError:
crewai_event_bus.emit(
agent,
A2AServerTaskCanceledEvent(
task_id=task_id,
context_id=context_id,
from_task=task,
from_agent=agent,
),
)
raise
except Exception as e:
crewai_event_bus.emit(
agent,
A2AServerTaskFailedEvent(
task_id=task_id,
context_id=context_id,
error=str(e),
from_task=task,
from_agent=agent,
),
)
raise ServerError(
error=InternalError(message=f"Task execution failed: {e}")
) from e
async def execute_with_extensions(
agent: Agent,
context: RequestContext,
event_queue: EventQueue,
extension_registry: ServerExtensionRegistry,
extension_context: ExtensionContext,
) -> None:
"""Execute an A2A task with extension hooks.
Args:
agent: The CrewAI agent to execute the task.
context: The A2A request context containing the user's message.
event_queue: The event queue for sending responses back.
extension_registry: Registry of server extensions.
extension_context: Context for extension hooks.
"""
await _execute_impl(
agent, context, event_queue, extension_registry, extension_context
)
async def cancel(
context: RequestContext,
event_queue: EventQueue,
) -> A2ATask | None:
"""Cancel an A2A task.
Publishes a cancel event that the cancellable decorator listens for.
Args:
context: The A2A request context containing task information.
event_queue: The event queue for sending the cancellation status.
Returns:
The canceled task with updated status.
"""
task_id = context.task_id
context_id = context.context_id
if task_id is None or context_id is None:
raise ServerError(InvalidParamsError(message="task_id and context_id required"))
if context.current_task and context.current_task.status.state in (
TaskState.completed,
TaskState.failed,
TaskState.canceled,
):
return context.current_task
cache = caches.get("default")
await cache.set(f"cancel:{task_id}", True, ttl=3600)
if not isinstance(cache, SimpleMemoryCache):
await cache.client.publish(f"cancel:{task_id}", "cancel")
await event_queue.enqueue_event(
TaskStatusUpdateEvent(
task_id=task_id,
context_id=context_id,
status=TaskStatus(state=TaskState.canceled),
final=True,
)
)
if context.current_task:
context.current_task.status = TaskStatus(state=TaskState.canceled)
return context.current_task
return None
def list_tasks(
tasks: list[A2ATask],
context_id: str | None = None,
status: TaskState | None = None,
status_timestamp_after: datetime | None = None,
page_size: int = 50,
page_token: str | None = None,
history_length: int | None = None,
include_artifacts: bool = False,
) -> tuple[list[A2ATask], str | None, int]:
"""Filter and paginate A2A tasks.
Provides filtering by context, status, and timestamp, along with
cursor-based pagination. This is a pure utility function that operates
on an in-memory list of tasks - storage retrieval is handled separately.
Args:
tasks: All tasks to filter.
context_id: Filter by context ID to get tasks in a conversation.
status: Filter by task state (e.g., completed, working).
status_timestamp_after: Filter to tasks updated after this time.
page_size: Maximum tasks per page (default 50).
page_token: Base64-encoded cursor from previous response.
history_length: Limit history messages per task (None = full history).
include_artifacts: Whether to include task artifacts (default False).
Returns:
Tuple of (filtered_tasks, next_page_token, total_count).
- filtered_tasks: Tasks matching filters, paginated and trimmed.
- next_page_token: Token for next page, or None if no more pages.
- total_count: Total number of tasks matching filters (before pagination).
"""
filtered: list[A2ATask] = []
for task in tasks:
if context_id and task.context_id != context_id:
continue
if status and task.status.state != status:
continue
if status_timestamp_after and task.status.timestamp:
ts = datetime.fromisoformat(task.status.timestamp.replace("Z", "+00:00"))
if ts <= status_timestamp_after:
continue
filtered.append(task)
def get_timestamp(t: A2ATask) -> datetime:
"""Extract timestamp from task status for sorting."""
if t.status.timestamp is None:
return datetime.min
return datetime.fromisoformat(t.status.timestamp.replace("Z", "+00:00"))
filtered.sort(key=get_timestamp, reverse=True)
total = len(filtered)
start = 0
if page_token:
try:
cursor_id = base64.b64decode(page_token).decode()
for idx, task in enumerate(filtered):
if task.id == cursor_id:
start = idx + 1
break
except (ValueError, UnicodeDecodeError):
pass
page = filtered[start : start + page_size]
result: list[A2ATask] = []
for task in page:
task = task.model_copy(deep=True)
if history_length is not None and task.history:
task.history = task.history[-history_length:]
if not include_artifacts:
task.artifacts = None
result.append(task)
next_token: str | None = None
if result and len(result) == page_size:
next_token = base64.b64encode(result[-1].id.encode()).decode()
return result, next_token, total

View File

@@ -0,0 +1,214 @@
"""Transport negotiation utilities for A2A protocol.
This module provides functionality for negotiating the transport protocol
between an A2A client and server based on their respective capabilities
and preferences.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Final, Literal
from a2a.types import AgentCard, AgentInterface
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import A2ATransportNegotiatedEvent
TransportProtocol = Literal["JSONRPC", "GRPC", "HTTP+JSON"]
NegotiationSource = Literal["client_preferred", "server_preferred", "fallback"]
JSONRPC_TRANSPORT: Literal["JSONRPC"] = "JSONRPC"
GRPC_TRANSPORT: Literal["GRPC"] = "GRPC"
HTTP_JSON_TRANSPORT: Literal["HTTP+JSON"] = "HTTP+JSON"
DEFAULT_TRANSPORT_PREFERENCE: Final[list[TransportProtocol]] = [
JSONRPC_TRANSPORT,
GRPC_TRANSPORT,
HTTP_JSON_TRANSPORT,
]
@dataclass
class NegotiatedTransport:
"""Result of transport negotiation.
Attributes:
transport: The negotiated transport protocol.
url: The URL to use for this transport.
source: How the transport was selected ('preferred', 'additional', 'fallback').
"""
transport: str
url: str
source: NegotiationSource
class TransportNegotiationError(Exception):
"""Raised when no compatible transport can be negotiated."""
def __init__(
self,
client_transports: list[str],
server_transports: list[str],
message: str | None = None,
) -> None:
"""Initialize the error with negotiation details.
Args:
client_transports: Transports supported by the client.
server_transports: Transports supported by the server.
message: Optional custom error message.
"""
self.client_transports = client_transports
self.server_transports = server_transports
if message is None:
message = (
f"No compatible transport found. "
f"Client supports: {client_transports}. "
f"Server supports: {server_transports}."
)
super().__init__(message)
def _get_server_interfaces(agent_card: AgentCard) -> list[AgentInterface]:
"""Extract all available interfaces from an AgentCard.
Creates a unified list of interfaces including the primary URL and
any additional interfaces declared by the agent.
Args:
agent_card: The agent's card containing transport information.
Returns:
List of AgentInterface objects representing all available endpoints.
"""
interfaces: list[AgentInterface] = []
primary_transport = agent_card.preferred_transport or JSONRPC_TRANSPORT
interfaces.append(
AgentInterface(
transport=primary_transport,
url=agent_card.url,
)
)
if agent_card.additional_interfaces:
for interface in agent_card.additional_interfaces:
is_duplicate = any(
i.url == interface.url and i.transport == interface.transport
for i in interfaces
)
if not is_duplicate:
interfaces.append(interface)
return interfaces
def negotiate_transport(
agent_card: AgentCard,
client_supported_transports: list[str] | None = None,
client_preferred_transport: str | None = None,
emit_event: bool = True,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
) -> NegotiatedTransport:
"""Negotiate the transport protocol between client and server.
Compares the client's supported transports with the server's available
interfaces to find a compatible transport and URL.
Negotiation logic:
1. If client_preferred_transport is set and server supports it → use it
2. Otherwise, if server's preferred is in client's supported → use server's
3. Otherwise, find first match from client's supported in server's interfaces
Args:
agent_card: The server's AgentCard with transport information.
client_supported_transports: Transports the client can use.
Defaults to ["JSONRPC"] if not specified.
client_preferred_transport: Client's preferred transport. If set and
server supports it, takes priority over server preference.
emit_event: Whether to emit a transport negotiation event.
endpoint: Original endpoint URL for event metadata.
a2a_agent_name: Agent name for event metadata.
Returns:
NegotiatedTransport with the selected transport, URL, and source.
Raises:
TransportNegotiationError: If no compatible transport is found.
"""
if client_supported_transports is None:
client_supported_transports = [JSONRPC_TRANSPORT]
client_transports = [t.upper() for t in client_supported_transports]
client_preferred = (
client_preferred_transport.upper() if client_preferred_transport else None
)
server_interfaces = _get_server_interfaces(agent_card)
server_transports = [i.transport.upper() for i in server_interfaces]
transport_to_interface: dict[str, AgentInterface] = {}
for interface in server_interfaces:
transport_upper = interface.transport.upper()
if transport_upper not in transport_to_interface:
transport_to_interface[transport_upper] = interface
result: NegotiatedTransport | None = None
if client_preferred and client_preferred in transport_to_interface:
interface = transport_to_interface[client_preferred]
result = NegotiatedTransport(
transport=interface.transport,
url=interface.url,
source="client_preferred",
)
else:
server_preferred = (agent_card.preferred_transport or JSONRPC_TRANSPORT).upper()
if (
server_preferred in client_transports
and server_preferred in transport_to_interface
):
interface = transport_to_interface[server_preferred]
result = NegotiatedTransport(
transport=interface.transport,
url=interface.url,
source="server_preferred",
)
else:
for transport in client_transports:
if transport in transport_to_interface:
interface = transport_to_interface[transport]
result = NegotiatedTransport(
transport=interface.transport,
url=interface.url,
source="fallback",
)
break
if result is None:
raise TransportNegotiationError(
client_transports=client_transports,
server_transports=server_transports,
)
if emit_event:
crewai_event_bus.emit(
None,
A2ATransportNegotiatedEvent(
endpoint=endpoint or agent_card.url,
a2a_agent_name=a2a_agent_name or agent_card.name,
negotiated_transport=result.transport,
negotiated_url=result.url,
source=result.source,
client_supported_transports=client_transports,
server_supported_transports=server_transports,
server_preferred_transport=agent_card.preferred_transport
or JSONRPC_TRANSPORT,
client_preferred_transport=client_preferred,
),
)
return result

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,44 @@
interactions:
- request:
body: ''
headers:
User-Agent:
- X-USER-AGENT-XXX
accept:
- '*/*'
accept-encoding:
- ACCEPT-ENCODING-XXX
connection:
- keep-alive
host:
- localhost:9999
method: GET
uri: http://localhost:9999/.well-known/agent-card.json
response:
body:
string: '{"capabilities":{"streaming":true},"defaultInputModes":["text"],"defaultOutputModes":["text"],"description":"An
AI assistant powered by OpenAI GPT with calculator and time tools. Ask questions,
perform calculations, or get the current time in any timezone.","name":"GPT
Assistant","preferredTransport":"JSONRPC","protocolVersion":"0.3.0","skills":[{"description":"Have
a general conversation with the AI assistant. Ask questions, get explanations,
or just chat.","examples":["Hello, how are you?","Explain quantum computing
in simple terms","What can you help me with?"],"id":"conversation","name":"General
Conversation","tags":["chat","conversation","general"]},{"description":"Perform
mathematical calculations including arithmetic, exponents, and more.","examples":["What
is 25 * 17?","Calculate 2^10","What''s (100 + 50) / 3?"],"id":"calculator","name":"Calculator","tags":["math","calculator","arithmetic"]},{"description":"Get
the current date and time in any timezone.","examples":["What time is it?","What''s
the current time in Tokyo?","What''s today''s date in New York?"],"id":"time","name":"Current
Time","tags":["time","date","timezone"]}],"url":"http://localhost:9999/","version":"1.0.0"}'
headers:
content-length:
- '1198'
content-type:
- application/json
date:
- Tue, 06 Jan 2026 14:17:00 GMT
server:
- uvicorn
status:
code: 200
message: OK
version: 1

View File

@@ -0,0 +1,126 @@
interactions:
- request:
body: ''
headers:
User-Agent:
- X-USER-AGENT-XXX
accept:
- '*/*'
accept-encoding:
- ACCEPT-ENCODING-XXX
connection:
- keep-alive
host:
- localhost:9999
method: GET
uri: http://localhost:9999/.well-known/agent-card.json
response:
body:
string: '{"capabilities":{"streaming":true},"defaultInputModes":["text"],"defaultOutputModes":["text"],"description":"An
AI assistant powered by OpenAI GPT with calculator and time tools. Ask questions,
perform calculations, or get the current time in any timezone.","name":"GPT
Assistant","preferredTransport":"JSONRPC","protocolVersion":"0.3.0","skills":[{"description":"Have
a general conversation with the AI assistant. Ask questions, get explanations,
or just chat.","examples":["Hello, how are you?","Explain quantum computing
in simple terms","What can you help me with?"],"id":"conversation","name":"General
Conversation","tags":["chat","conversation","general"]},{"description":"Perform
mathematical calculations including arithmetic, exponents, and more.","examples":["What
is 25 * 17?","Calculate 2^10","What''s (100 + 50) / 3?"],"id":"calculator","name":"Calculator","tags":["math","calculator","arithmetic"]},{"description":"Get
the current date and time in any timezone.","examples":["What time is it?","What''s
the current time in Tokyo?","What''s today''s date in New York?"],"id":"time","name":"Current
Time","tags":["time","date","timezone"]}],"url":"http://localhost:9999/","version":"1.0.0"}'
headers:
content-length:
- '1198'
content-type:
- application/json
date:
- Tue, 06 Jan 2026 14:16:58 GMT
server:
- uvicorn
status:
code: 200
message: OK
- request:
body: '{"id":"e5ac2160-ae9b-4bf9-aad7-14bf0d53d6d9","jsonrpc":"2.0","method":"message/stream","params":{"configuration":{"acceptedOutputModes":[],"blocking":true},"message":{"kind":"message","messageId":"e1e63c75-3ea0-49fb-b512-5128a2476416","parts":[{"kind":"text","text":"What
is 2 + 2?"}],"role":"user"}}}'
headers:
User-Agent:
- X-USER-AGENT-XXX
accept:
- '*/*, text/event-stream'
accept-encoding:
- ACCEPT-ENCODING-XXX
cache-control:
- no-store
connection:
- keep-alive
content-length:
- '301'
content-type:
- application/json
host:
- localhost:9999
method: POST
uri: http://localhost:9999/
response:
body:
string: "data: {\"id\":\"e5ac2160-ae9b-4bf9-aad7-14bf0d53d6d9\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"b9e14c1b-734d-4d1e-864a-e6dda5231d71\",\"final\":false,\"kind\":\"status-update\",\"status\":{\"state\":\"submitted\"},\"taskId\":\"0dd4d3af-f35d-409d-9462-01218e5641f9\"}}\r\n\r\ndata:
{\"id\":\"e5ac2160-ae9b-4bf9-aad7-14bf0d53d6d9\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"b9e14c1b-734d-4d1e-864a-e6dda5231d71\",\"final\":false,\"kind\":\"status-update\",\"status\":{\"state\":\"working\"},\"taskId\":\"0dd4d3af-f35d-409d-9462-01218e5641f9\"}}\r\n\r\ndata:
{\"id\":\"e5ac2160-ae9b-4bf9-aad7-14bf0d53d6d9\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"b9e14c1b-734d-4d1e-864a-e6dda5231d71\",\"final\":true,\"kind\":\"status-update\",\"status\":{\"message\":{\"kind\":\"message\",\"messageId\":\"54bb7ff3-f2c0-4eb3-b427-bf1c8cf90832\",\"parts\":[{\"kind\":\"text\",\"text\":\"\\n[Tool:
calculator] 2 + 2 = 4\\n2 + 2 equals 4.\"}],\"role\":\"agent\"},\"state\":\"completed\"},\"taskId\":\"0dd4d3af-f35d-409d-9462-01218e5641f9\"}}\r\n\r\n"
headers:
Transfer-Encoding:
- chunked
cache-control:
- no-store
connection:
- keep-alive
content-type:
- text/event-stream; charset=utf-8
date:
- Tue, 06 Jan 2026 14:16:58 GMT
server:
- uvicorn
x-accel-buffering:
- 'no'
status:
code: 200
message: OK
- request:
body: '{"id":"cb1e4af3-d2d0-4848-96b8-7082ee6171d1","jsonrpc":"2.0","method":"tasks/get","params":{"historyLength":100,"id":"0dd4d3af-f35d-409d-9462-01218e5641f9"}}'
headers:
User-Agent:
- X-USER-AGENT-XXX
accept:
- '*/*'
accept-encoding:
- ACCEPT-ENCODING-XXX
connection:
- keep-alive
content-length:
- '157'
content-type:
- application/json
host:
- localhost:9999
method: POST
uri: http://localhost:9999/
response:
body:
string: '{"id":"cb1e4af3-d2d0-4848-96b8-7082ee6171d1","jsonrpc":"2.0","result":{"contextId":"b9e14c1b-734d-4d1e-864a-e6dda5231d71","history":[{"contextId":"b9e14c1b-734d-4d1e-864a-e6dda5231d71","kind":"message","messageId":"e1e63c75-3ea0-49fb-b512-5128a2476416","parts":[{"kind":"text","text":"What
is 2 + 2?"}],"role":"user","taskId":"0dd4d3af-f35d-409d-9462-01218e5641f9"}],"id":"0dd4d3af-f35d-409d-9462-01218e5641f9","kind":"task","status":{"message":{"kind":"message","messageId":"54bb7ff3-f2c0-4eb3-b427-bf1c8cf90832","parts":[{"kind":"text","text":"\n[Tool:
calculator] 2 + 2 = 4\n2 + 2 equals 4."}],"role":"agent"},"state":"completed"}}}'
headers:
content-length:
- '635'
content-type:
- application/json
date:
- Tue, 06 Jan 2026 14:17:00 GMT
server:
- uvicorn
status:
code: 200
message: OK
version: 1

View File

@@ -0,0 +1,90 @@
interactions:
- request:
body: ''
headers:
User-Agent:
- X-USER-AGENT-XXX
accept:
- '*/*'
accept-encoding:
- ACCEPT-ENCODING-XXX
connection:
- keep-alive
host:
- localhost:9999
method: GET
uri: http://localhost:9999/.well-known/agent-card.json
response:
body:
string: '{"capabilities":{"streaming":true},"defaultInputModes":["text"],"defaultOutputModes":["text"],"description":"An
AI assistant powered by OpenAI GPT with calculator and time tools. Ask questions,
perform calculations, or get the current time in any timezone.","name":"GPT
Assistant","preferredTransport":"JSONRPC","protocolVersion":"0.3.0","skills":[{"description":"Have
a general conversation with the AI assistant. Ask questions, get explanations,
or just chat.","examples":["Hello, how are you?","Explain quantum computing
in simple terms","What can you help me with?"],"id":"conversation","name":"General
Conversation","tags":["chat","conversation","general"]},{"description":"Perform
mathematical calculations including arithmetic, exponents, and more.","examples":["What
is 25 * 17?","Calculate 2^10","What''s (100 + 50) / 3?"],"id":"calculator","name":"Calculator","tags":["math","calculator","arithmetic"]},{"description":"Get
the current date and time in any timezone.","examples":["What time is it?","What''s
the current time in Tokyo?","What''s today''s date in New York?"],"id":"time","name":"Current
Time","tags":["time","date","timezone"]}],"url":"http://localhost:9999/","version":"1.0.0"}'
headers:
content-length:
- '1198'
content-type:
- application/json
date:
- Tue, 06 Jan 2026 14:17:02 GMT
server:
- uvicorn
status:
code: 200
message: OK
- request:
body: '{"id":"8cf25b61-8884-4246-adce-fccb32e176ab","jsonrpc":"2.0","method":"message/stream","params":{"configuration":{"acceptedOutputModes":[],"blocking":true},"message":{"kind":"message","messageId":"c145297f-7331-4835-adcc-66b51de92a2b","parts":[{"kind":"text","text":"What
is 2 + 2?"}],"role":"user"}}}'
headers:
User-Agent:
- X-USER-AGENT-XXX
accept:
- '*/*, text/event-stream'
accept-encoding:
- ACCEPT-ENCODING-XXX
cache-control:
- no-store
connection:
- keep-alive
content-length:
- '301'
content-type:
- application/json
host:
- localhost:9999
method: POST
uri: http://localhost:9999/
response:
body:
string: "data: {\"id\":\"8cf25b61-8884-4246-adce-fccb32e176ab\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"30601267-ab3b-48ef-afc8-916c37a18651\",\"final\":false,\"kind\":\"status-update\",\"status\":{\"state\":\"submitted\"},\"taskId\":\"3083d3da-4739-4f4f-a4e8-7c048ea819c1\"}}\r\n\r\ndata:
{\"id\":\"8cf25b61-8884-4246-adce-fccb32e176ab\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"30601267-ab3b-48ef-afc8-916c37a18651\",\"final\":false,\"kind\":\"status-update\",\"status\":{\"state\":\"working\"},\"taskId\":\"3083d3da-4739-4f4f-a4e8-7c048ea819c1\"}}\r\n\r\ndata:
{\"id\":\"8cf25b61-8884-4246-adce-fccb32e176ab\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"30601267-ab3b-48ef-afc8-916c37a18651\",\"final\":true,\"kind\":\"status-update\",\"status\":{\"message\":{\"kind\":\"message\",\"messageId\":\"25f81e3c-b7e8-48b5-a98a-4066f3637a13\",\"parts\":[{\"kind\":\"text\",\"text\":\"\\n[Tool:
calculator] 2 + 2 = 4\\n2 + 2 equals 4.\"}],\"role\":\"agent\"},\"state\":\"completed\"},\"taskId\":\"3083d3da-4739-4f4f-a4e8-7c048ea819c1\"}}\r\n\r\n"
headers:
Transfer-Encoding:
- chunked
cache-control:
- no-store
connection:
- keep-alive
content-type:
- text/event-stream; charset=utf-8
date:
- Tue, 06 Jan 2026 14:17:02 GMT
server:
- uvicorn
x-accel-buffering:
- 'no'
status:
code: 200
message: OK
version: 1

View File

@@ -0,0 +1,90 @@
interactions:
- request:
body: ''
headers:
User-Agent:
- X-USER-AGENT-XXX
accept:
- '*/*'
accept-encoding:
- ACCEPT-ENCODING-XXX
connection:
- keep-alive
host:
- localhost:9999
method: GET
uri: http://localhost:9999/.well-known/agent-card.json
response:
body:
string: '{"capabilities":{"streaming":true},"defaultInputModes":["text"],"defaultOutputModes":["text"],"description":"An
AI assistant powered by OpenAI GPT with calculator and time tools. Ask questions,
perform calculations, or get the current time in any timezone.","name":"GPT
Assistant","preferredTransport":"JSONRPC","protocolVersion":"0.3.0","skills":[{"description":"Have
a general conversation with the AI assistant. Ask questions, get explanations,
or just chat.","examples":["Hello, how are you?","Explain quantum computing
in simple terms","What can you help me with?"],"id":"conversation","name":"General
Conversation","tags":["chat","conversation","general"]},{"description":"Perform
mathematical calculations including arithmetic, exponents, and more.","examples":["What
is 25 * 17?","Calculate 2^10","What''s (100 + 50) / 3?"],"id":"calculator","name":"Calculator","tags":["math","calculator","arithmetic"]},{"description":"Get
the current date and time in any timezone.","examples":["What time is it?","What''s
the current time in Tokyo?","What''s today''s date in New York?"],"id":"time","name":"Current
Time","tags":["time","date","timezone"]}],"url":"http://localhost:9999/","version":"1.0.0"}'
headers:
content-length:
- '1198'
content-type:
- application/json
date:
- Tue, 06 Jan 2026 14:17:00 GMT
server:
- uvicorn
status:
code: 200
message: OK
- request:
body: '{"id":"3a17c6bf-8db6-45a6-8535-34c45c0c4936","jsonrpc":"2.0","method":"message/stream","params":{"configuration":{"acceptedOutputModes":[],"blocking":true},"message":{"kind":"message","messageId":"712558a3-6d92-4591-be8a-9dd8566dde82","parts":[{"kind":"text","text":"What
is 2 + 2?"}],"role":"user"}}}'
headers:
User-Agent:
- X-USER-AGENT-XXX
accept:
- '*/*, text/event-stream'
accept-encoding:
- ACCEPT-ENCODING-XXX
cache-control:
- no-store
connection:
- keep-alive
content-length:
- '301'
content-type:
- application/json
host:
- localhost:9999
method: POST
uri: http://localhost:9999/
response:
body:
string: "data: {\"id\":\"3a17c6bf-8db6-45a6-8535-34c45c0c4936\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"ca2fbbc9-761e-45d9-a929-0c68b1f8acbf\",\"final\":false,\"kind\":\"status-update\",\"status\":{\"state\":\"submitted\"},\"taskId\":\"c6e88db0-36e9-4269-8b9a-ecb6dfdcf6a1\"}}\r\n\r\ndata:
{\"id\":\"3a17c6bf-8db6-45a6-8535-34c45c0c4936\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"ca2fbbc9-761e-45d9-a929-0c68b1f8acbf\",\"final\":false,\"kind\":\"status-update\",\"status\":{\"state\":\"working\"},\"taskId\":\"c6e88db0-36e9-4269-8b9a-ecb6dfdcf6a1\"}}\r\n\r\ndata:
{\"id\":\"3a17c6bf-8db6-45a6-8535-34c45c0c4936\",\"jsonrpc\":\"2.0\",\"result\":{\"contextId\":\"ca2fbbc9-761e-45d9-a929-0c68b1f8acbf\",\"final\":true,\"kind\":\"status-update\",\"status\":{\"message\":{\"kind\":\"message\",\"messageId\":\"916324aa-fd25-4849-bceb-c4644e2fcbb0\",\"parts\":[{\"kind\":\"text\",\"text\":\"\\n[Tool:
calculator] 2 + 2 = 4\\n2 + 2 equals 4.\"}],\"role\":\"agent\"},\"state\":\"completed\"},\"taskId\":\"c6e88db0-36e9-4269-8b9a-ecb6dfdcf6a1\"}}\r\n\r\n"
headers:
Transfer-Encoding:
- chunked
cache-control:
- no-store
connection:
- keep-alive
content-type:
- text/event-stream; charset=utf-8
date:
- Tue, 06 Jan 2026 14:17:00 GMT
server:
- uvicorn
x-accel-buffering:
- 'no'
status:
code: 200
message: OK
version: 1

View File

@@ -0,0 +1,21 @@
"""Pytest configuration for crewai-a2a tests.
Ensures Agent model is properly rebuilt with A2A types,
which can fail silently during circular import resolution.
"""
from crewai_a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
def pytest_configure() -> None:
"""Rebuild Agent/LiteAgent models after crewai_a2a is fully loaded."""
from crewai.agent.core import Agent
from crewai.lite_agent import LiteAgent
ns = {
"A2AConfig": A2AConfig,
"A2AClientConfig": A2AClientConfig,
"A2AServerConfig": A2AServerConfig,
}
Agent.model_rebuild(_types_namespace=ns)
LiteAgent.model_rebuild(_types_namespace=ns)

View File

@@ -3,15 +3,13 @@ from __future__ import annotations
import os
import uuid
from a2a.client import ClientFactory
from a2a.types import AgentCard, Message, Part, Role, Task, TaskState, TextPart
from crewai_a2a.updates.polling.handler import PollingHandler
from crewai_a2a.updates.streaming.handler import StreamingHandler
import pytest
import pytest_asyncio
from a2a.client import ClientFactory
from a2a.types import AgentCard, Message, Part, Role, TaskState, TextPart
from crewai.a2a.updates.polling.handler import PollingHandler
from crewai.a2a.updates.streaming.handler import StreamingHandler
A2A_TEST_ENDPOINT = os.getenv("A2A_TEST_ENDPOINT", "http://localhost:9999")
@@ -162,7 +160,7 @@ class TestA2APushNotificationHandler:
)
@pytest.fixture
def mock_task(self) -> "Task":
def mock_task(self) -> Task:
"""Create a minimal valid task for testing."""
from a2a.types import Task, TaskStatus
@@ -182,11 +180,12 @@ class TestA2APushNotificationHandler:
from unittest.mock import AsyncMock, MagicMock
from a2a.types import Task, TaskStatus
from crewai_a2a.updates.push_notifications.config import PushNotificationConfig
from crewai_a2a.updates.push_notifications.handler import (
PushNotificationHandler,
)
from pydantic import AnyHttpUrl
from crewai.a2a.updates.push_notifications.config import PushNotificationConfig
from crewai.a2a.updates.push_notifications.handler import PushNotificationHandler
completed_task = Task(
id="task-123",
context_id="ctx-123",
@@ -246,11 +245,12 @@ class TestA2APushNotificationHandler:
from unittest.mock import AsyncMock, MagicMock
from a2a.types import Task, TaskStatus
from crewai_a2a.updates.push_notifications.config import PushNotificationConfig
from crewai_a2a.updates.push_notifications.handler import (
PushNotificationHandler,
)
from pydantic import AnyHttpUrl
from crewai.a2a.updates.push_notifications.config import PushNotificationConfig
from crewai.a2a.updates.push_notifications.handler import PushNotificationHandler
mock_store = MagicMock()
mock_store.wait_for_result = AsyncMock(return_value=None)
@@ -303,7 +303,9 @@ class TestA2APushNotificationHandler:
"""Test that push handler fails gracefully without config."""
from unittest.mock import MagicMock
from crewai.a2a.updates.push_notifications.handler import PushNotificationHandler
from crewai_a2a.updates.push_notifications.handler import (
PushNotificationHandler,
)
mock_client = MagicMock()

View File

@@ -3,10 +3,9 @@
from __future__ import annotations
from a2a.types import AgentCard, AgentSkill
from crewai import Agent
from crewai.a2a.config import A2AClientConfig, A2AServerConfig
from crewai.a2a.utils.agent_card import inject_a2a_server_methods
from crewai_a2a.config import A2AClientConfig, A2AServerConfig
from crewai_a2a.utils.agent_card import inject_a2a_server_methods
class TestInjectA2AServerMethods:

View File

@@ -6,13 +6,12 @@ import asyncio
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from a2a.server.agent_execution import RequestContext
from a2a.server.events import EventQueue
from a2a.types import Message, Task as A2ATask, TaskState, TaskStatus
from crewai.a2a.utils.task import cancel, cancellable, execute
from crewai_a2a.utils.task import cancel, cancellable, execute
import pytest
import pytest_asyncio
@pytest.fixture
@@ -85,8 +84,11 @@ class TestCancellableDecorator:
assert call_count == 1
@pytest.mark.asyncio
async def test_executes_function_with_context(self, mock_context: MagicMock) -> None:
async def test_executes_function_with_context(
self, mock_context: MagicMock
) -> None:
"""Function executes normally with RequestContext when not cancelled."""
@cancellable
async def my_func(context: RequestContext) -> str:
await asyncio.sleep(0.01)
@@ -134,6 +136,7 @@ class TestCancellableDecorator:
@pytest.mark.asyncio
async def test_extracts_context_from_kwargs(self, mock_context: MagicMock) -> None:
"""Context can be passed as keyword argument."""
@cancellable
async def my_func(value: int, context: RequestContext | None = None) -> int:
return value + 1
@@ -156,8 +159,8 @@ class TestExecute:
) -> None:
"""Execute completes successfully and enqueues completed task."""
with (
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus,
patch("crewai_a2a.utils.task.Task", return_value=mock_task),
patch("crewai_a2a.utils.task.crewai_event_bus") as mock_bus,
):
await execute(mock_agent, mock_context, mock_event_queue)
@@ -175,8 +178,8 @@ class TestExecute:
) -> None:
"""Execute emits A2AServerTaskStartedEvent."""
with (
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus,
patch("crewai_a2a.utils.task.Task", return_value=mock_task),
patch("crewai_a2a.utils.task.crewai_event_bus") as mock_bus,
):
await execute(mock_agent, mock_context, mock_event_queue)
@@ -197,8 +200,8 @@ class TestExecute:
) -> None:
"""Execute emits A2AServerTaskCompletedEvent on success."""
with (
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus,
patch("crewai_a2a.utils.task.Task", return_value=mock_task),
patch("crewai_a2a.utils.task.crewai_event_bus") as mock_bus,
):
await execute(mock_agent, mock_context, mock_event_queue)
@@ -221,8 +224,8 @@ class TestExecute:
mock_agent.aexecute_task = AsyncMock(side_effect=ValueError("Test error"))
with (
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus,
patch("crewai_a2a.utils.task.Task", return_value=mock_task),
patch("crewai_a2a.utils.task.crewai_event_bus") as mock_bus,
):
with pytest.raises(Exception):
await execute(mock_agent, mock_context, mock_event_queue)
@@ -245,8 +248,8 @@ class TestExecute:
mock_agent.aexecute_task = AsyncMock(side_effect=asyncio.CancelledError())
with (
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
patch("crewai.a2a.utils.task.crewai_event_bus") as mock_bus,
patch("crewai_a2a.utils.task.Task", return_value=mock_task),
patch("crewai_a2a.utils.task.crewai_event_bus") as mock_bus,
):
with pytest.raises(asyncio.CancelledError):
await execute(mock_agent, mock_context, mock_event_queue)
@@ -354,6 +357,7 @@ class TestExecuteAndCancelIntegration:
mock_task: MagicMock,
) -> None:
"""Calling cancel stops a running execute."""
async def slow_task(**kwargs: Any) -> str:
await asyncio.sleep(2.0)
return "should not complete"
@@ -361,8 +365,8 @@ class TestExecuteAndCancelIntegration:
mock_agent.aexecute_task = slow_task
with (
patch("crewai.a2a.utils.task.Task", return_value=mock_task),
patch("crewai.a2a.utils.task.crewai_event_bus"),
patch("crewai_a2a.utils.task.Task", return_value=mock_task),
patch("crewai_a2a.utils.task.crewai_event_bus"),
):
execute_task = asyncio.create_task(
execute(mock_agent, mock_context, mock_event_queue)
@@ -372,4 +376,4 @@ class TestExecuteAndCancelIntegration:
await cancel(mock_context, mock_event_queue)
with pytest.raises(asyncio.CancelledError):
await execute_task
await execute_task

View File

@@ -152,4 +152,4 @@ __all__ = [
"wrap_file_source",
]
__version__ = "1.9.3"
__version__ = "1.10.1b1"

View File

@@ -8,12 +8,10 @@ authors = [
]
requires-python = ">=3.10, <3.14"
dependencies = [
"lancedb~=0.5.4",
"pytube~=15.0.0",
"requests~=2.32.5",
"docker~=7.1.0",
"crewai==1.9.3",
"lancedb~=0.5.4",
"crewai==1.10.1b1",
"tiktoken~=0.8.0",
"beautifulsoup4~=4.13.4",
"python-docx~=1.2.0",

View File

@@ -291,4 +291,4 @@ __all__ = [
"ZapierActionTools",
]
__version__ = "1.9.3"
__version__ = "1.10.1b1"

View File

@@ -20117,18 +20117,6 @@
"humanized_name": "Web Automation Tool",
"init_params_schema": {
"$defs": {
"AvailableModel": {
"enum": [
"gpt-4o",
"gpt-4o-mini",
"claude-3-5-sonnet-latest",
"claude-3-7-sonnet-latest",
"computer-use-preview",
"gemini-2.0-flash"
],
"title": "AvailableModel",
"type": "string"
},
"EnvVar": {
"properties": {
"default": {
@@ -20206,17 +20194,6 @@
"default": null,
"title": "Model Api Key"
},
"model_name": {
"anyOf": [
{
"$ref": "#/$defs/AvailableModel"
},
{
"type": "null"
}
],
"default": "claude-3-7-sonnet-latest"
},
"project_id": {
"anyOf": [
{

View File

@@ -38,10 +38,11 @@ dependencies = [
"json5~=0.10.0",
"portalocker~=2.7.0",
"pydantic-settings~=2.10.1",
"httpx~=0.28.1",
"mcp~=1.26.0",
"uv~=0.9.13",
"aiosqlite~=0.21.0",
"lancedb>=0.4.0",
"lancedb>=0.29.2",
]
[project.urls]
@@ -52,7 +53,7 @@ Repository = "https://github.com/crewAIInc/crewAI"
[project.optional-dependencies]
tools = [
"crewai-tools==1.9.3",
"crewai-tools==1.10.1b1",
]
embeddings = [
"tiktoken~=0.8.0"
@@ -95,12 +96,7 @@ azure-ai-inference = [
anthropic = [
"anthropic~=0.73.0",
]
a2a = [
"a2a-sdk~=0.3.10",
"httpx-auth~=0.23.1",
"httpx-sse~=0.4.0",
"aiocache[redis,memcached]~=0.12.3",
]
a2a = ["crewai-a2a==1.10.1b1"]
file-processing = [
"crewai-files",
]
@@ -131,6 +127,7 @@ torchvision = [
{ index = "pytorch", marker = "python_version < '3.13'" },
]
crewai-files = { workspace = true }
crewai-a2a = { workspace = true }
[build-system]

View File

@@ -10,7 +10,6 @@ from crewai.flow.flow import Flow
from crewai.knowledge.knowledge import Knowledge
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
from crewai.memory.unified_memory import Memory
from crewai.process import Process
from crewai.task import Task
from crewai.tasks.llm_guardrail import LLMGuardrail
@@ -41,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
_suppress_pydantic_deprecation_warnings()
__version__ = "1.9.3"
__version__ = "1.10.1b1"
_telemetry_submitted = False
@@ -72,6 +71,25 @@ def _track_install_async() -> None:
_track_install_async()
_LAZY_IMPORTS: dict[str, tuple[str, str]] = {
"Memory": ("crewai.memory.unified_memory", "Memory"),
}
def __getattr__(name: str) -> Any:
"""Lazily import heavy modules (e.g. Memory → lancedb) on first access."""
if name in _LAZY_IMPORTS:
module_path, attr = _LAZY_IMPORTS[name]
import importlib
mod = importlib.import_module(module_path)
val = getattr(mod, attr)
globals()[name] = val
return val
raise AttributeError(f"module 'crewai' has no attribute {name!r}")
__all__ = [
"LLM",
"Agent",

View File

@@ -1,10 +1,13 @@
"""Agent-to-Agent (A2A) protocol communication module for CrewAI."""
"""Backward-compatibility shim — use ``crewai_a2a`` instead."""
from crewai.a2a.config import A2AClientConfig, A2AConfig, A2AServerConfig
import warnings
__all__ = [
"A2AClientConfig",
"A2AConfig",
"A2AServerConfig",
]
warnings.warn(
"'crewai.a2a' has been moved to 'crewai_a2a'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
from crewai_a2a import * # noqa: E402, F403

View File

@@ -1,36 +1,13 @@
"""A2A authentication schemas."""
"""Backward-compatibility shim — use ``crewai_a2a.auth`` instead."""
from crewai.a2a.auth.client_schemes import (
APIKeyAuth,
AuthScheme,
BearerTokenAuth,
ClientAuthScheme,
HTTPBasicAuth,
HTTPDigestAuth,
OAuth2AuthorizationCode,
OAuth2ClientCredentials,
TLSConfig,
)
from crewai.a2a.auth.server_schemes import (
AuthenticatedUser,
OIDCAuth,
ServerAuthScheme,
SimpleTokenAuth,
import warnings
warnings.warn(
"'crewai.a2a.auth' has been moved to 'crewai_a2a.auth'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
__all__ = [
"APIKeyAuth",
"AuthScheme",
"AuthenticatedUser",
"BearerTokenAuth",
"ClientAuthScheme",
"HTTPBasicAuth",
"HTTPDigestAuth",
"OAuth2AuthorizationCode",
"OAuth2ClientCredentials",
"OIDCAuth",
"ServerAuthScheme",
"SimpleTokenAuth",
"TLSConfig",
]
from crewai_a2a.auth import * # noqa: E402, F403

View File

@@ -1,550 +1,13 @@
"""Authentication schemes for A2A protocol clients.
"""Backward-compatibility shim — use ``crewai_a2a.auth.client_schemes`` instead."""
Supported authentication methods:
- Bearer tokens
- OAuth2 (Client Credentials, Authorization Code)
- API Keys (header, query, cookie)
- HTTP Basic authentication
- HTTP Digest authentication
- mTLS (mutual TLS) client certificate authentication
"""
import warnings
from __future__ import annotations
from abc import ABC, abstractmethod
import asyncio
import base64
from collections.abc import Awaitable, Callable, MutableMapping
from pathlib import Path
import ssl
import time
from typing import TYPE_CHECKING, ClassVar, Literal
import urllib.parse
warnings.warn(
"'crewai.a2a.auth.client_schemes' has been moved to 'crewai_a2a.auth.client_schemes'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
import httpx
from httpx import DigestAuth
from pydantic import BaseModel, ConfigDict, Field, FilePath, PrivateAttr
from typing_extensions import deprecated
if TYPE_CHECKING:
import grpc # type: ignore[import-untyped]
class TLSConfig(BaseModel):
"""TLS/mTLS configuration for secure client connections.
Supports mutual TLS (mTLS) where the client presents a certificate to the server,
and standard TLS with custom CA verification.
Attributes:
client_cert_path: Path to client certificate file (PEM format) for mTLS.
client_key_path: Path to client private key file (PEM format) for mTLS.
ca_cert_path: Path to CA certificate bundle for server verification.
verify: Whether to verify server certificates. Set False only for development.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
client_cert_path: FilePath | None = Field(
default=None,
description="Path to client certificate file (PEM format) for mTLS",
)
client_key_path: FilePath | None = Field(
default=None,
description="Path to client private key file (PEM format) for mTLS",
)
ca_cert_path: FilePath | None = Field(
default=None,
description="Path to CA certificate bundle for server verification",
)
verify: bool = Field(
default=True,
description="Whether to verify server certificates. Set False only for development.",
)
def get_httpx_ssl_context(self) -> ssl.SSLContext | bool | str:
"""Build SSL context for httpx client.
Returns:
SSL context if certificates configured, True for default verification,
False if verification disabled, or path to CA bundle.
"""
if not self.verify:
return False
if self.client_cert_path and self.client_key_path:
context = ssl.create_default_context()
if self.ca_cert_path:
context.load_verify_locations(cafile=str(self.ca_cert_path))
context.load_cert_chain(
certfile=str(self.client_cert_path),
keyfile=str(self.client_key_path),
)
return context
if self.ca_cert_path:
return str(self.ca_cert_path)
return True
def get_grpc_credentials(self) -> grpc.ChannelCredentials | None: # type: ignore[no-any-unimported]
"""Build gRPC channel credentials for secure connections.
Returns:
gRPC SSL credentials if certificates configured, None otherwise.
"""
try:
import grpc
except ImportError:
return None
if not self.verify and not self.client_cert_path:
return None
root_certs: bytes | None = None
private_key: bytes | None = None
certificate_chain: bytes | None = None
if self.ca_cert_path:
root_certs = Path(self.ca_cert_path).read_bytes()
if self.client_cert_path and self.client_key_path:
private_key = Path(self.client_key_path).read_bytes()
certificate_chain = Path(self.client_cert_path).read_bytes()
return grpc.ssl_channel_credentials(
root_certificates=root_certs,
private_key=private_key,
certificate_chain=certificate_chain,
)
class ClientAuthScheme(ABC, BaseModel):
"""Base class for client-side authentication schemes.
Client auth schemes apply credentials to outgoing requests.
Attributes:
tls: Optional TLS/mTLS configuration for secure connections.
"""
tls: TLSConfig | None = Field(
default=None,
description="TLS/mTLS configuration for secure connections",
)
@abstractmethod
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Apply authentication to request headers.
Args:
client: HTTP client for making auth requests.
headers: Current request headers.
Returns:
Updated headers with authentication applied.
"""
...
@deprecated("Use ClientAuthScheme instead", category=FutureWarning)
class AuthScheme(ClientAuthScheme):
"""Deprecated: Use ClientAuthScheme instead."""
class BearerTokenAuth(ClientAuthScheme):
"""Bearer token authentication (Authorization: Bearer <token>).
Attributes:
token: Bearer token for authentication.
"""
token: str = Field(description="Bearer token")
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Apply Bearer token to Authorization header.
Args:
client: HTTP client for making auth requests.
headers: Current request headers.
Returns:
Updated headers with Bearer token in Authorization header.
"""
headers["Authorization"] = f"Bearer {self.token}"
return headers
class HTTPBasicAuth(ClientAuthScheme):
"""HTTP Basic authentication.
Attributes:
username: Username for Basic authentication.
password: Password for Basic authentication.
"""
username: str = Field(description="Username")
password: str = Field(description="Password")
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Apply HTTP Basic authentication.
Args:
client: HTTP client for making auth requests.
headers: Current request headers.
Returns:
Updated headers with Basic auth in Authorization header.
"""
credentials = f"{self.username}:{self.password}"
encoded = base64.b64encode(credentials.encode()).decode()
headers["Authorization"] = f"Basic {encoded}"
return headers
class HTTPDigestAuth(ClientAuthScheme):
"""HTTP Digest authentication.
Note: Uses httpx-auth library for digest implementation.
Attributes:
username: Username for Digest authentication.
password: Password for Digest authentication.
"""
username: str = Field(description="Username")
password: str = Field(description="Password")
_configured_client_id: int | None = PrivateAttr(default=None)
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Digest auth is handled by httpx auth flow, not headers.
Args:
client: HTTP client for making auth requests.
headers: Current request headers.
Returns:
Unchanged headers (Digest auth handled by httpx auth flow).
"""
return headers
def configure_client(self, client: httpx.AsyncClient) -> None:
"""Configure client with Digest auth.
Idempotent: Only configures the client once. Subsequent calls on the same
client instance are no-ops to prevent overwriting auth configuration.
Args:
client: HTTP client to configure with Digest authentication.
"""
client_id = id(client)
if self._configured_client_id == client_id:
return
client.auth = DigestAuth(self.username, self.password)
self._configured_client_id = client_id
class APIKeyAuth(ClientAuthScheme):
"""API Key authentication (header, query, or cookie).
Attributes:
api_key: API key value for authentication.
location: Where to send the API key (header, query, or cookie).
name: Parameter name for the API key (default: X-API-Key).
"""
api_key: str = Field(description="API key value")
location: Literal["header", "query", "cookie"] = Field(
default="header", description="Where to send the API key"
)
name: str = Field(default="X-API-Key", description="Parameter name for the API key")
_configured_client_ids: set[int] = PrivateAttr(default_factory=set)
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Apply API key authentication.
Args:
client: HTTP client for making auth requests.
headers: Current request headers.
Returns:
Updated headers with API key (for header/cookie locations).
"""
if self.location == "header":
headers[self.name] = self.api_key
elif self.location == "cookie":
headers["Cookie"] = f"{self.name}={self.api_key}"
return headers
def configure_client(self, client: httpx.AsyncClient) -> None:
"""Configure client for query param API keys.
Idempotent: Only adds the request hook once per client instance.
Subsequent calls on the same client are no-ops to prevent hook accumulation.
Args:
client: HTTP client to configure with query param API key hook.
"""
if self.location == "query":
client_id = id(client)
if client_id in self._configured_client_ids:
return
async def _add_api_key_param(request: httpx.Request) -> None:
url = httpx.URL(request.url)
request.url = url.copy_add_param(self.name, self.api_key)
client.event_hooks["request"].append(_add_api_key_param)
self._configured_client_ids.add(client_id)
class OAuth2ClientCredentials(ClientAuthScheme):
"""OAuth2 Client Credentials flow authentication.
Thread-safe implementation with asyncio.Lock to prevent concurrent token fetches
when multiple requests share the same auth instance.
Attributes:
token_url: OAuth2 token endpoint URL.
client_id: OAuth2 client identifier.
client_secret: OAuth2 client secret.
scopes: List of required OAuth2 scopes.
"""
token_url: str = Field(description="OAuth2 token endpoint")
client_id: str = Field(description="OAuth2 client ID")
client_secret: str = Field(description="OAuth2 client secret")
scopes: list[str] = Field(
default_factory=list, description="Required OAuth2 scopes"
)
_access_token: str | None = PrivateAttr(default=None)
_token_expires_at: float | None = PrivateAttr(default=None)
_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Apply OAuth2 access token to Authorization header.
Uses asyncio.Lock to ensure only one coroutine fetches tokens at a time,
preventing race conditions when multiple concurrent requests use the same
auth instance.
Args:
client: HTTP client for making token requests.
headers: Current request headers.
Returns:
Updated headers with OAuth2 access token in Authorization header.
"""
if (
self._access_token is None
or self._token_expires_at is None
or time.time() >= self._token_expires_at
):
async with self._lock:
if (
self._access_token is None
or self._token_expires_at is None
or time.time() >= self._token_expires_at
):
await self._fetch_token(client)
if self._access_token:
headers["Authorization"] = f"Bearer {self._access_token}"
return headers
async def _fetch_token(self, client: httpx.AsyncClient) -> None:
"""Fetch OAuth2 access token using client credentials flow.
Args:
client: HTTP client for making token request.
Raises:
httpx.HTTPStatusError: If token request fails.
"""
data = {
"grant_type": "client_credentials",
"client_id": self.client_id,
"client_secret": self.client_secret,
}
if self.scopes:
data["scope"] = " ".join(self.scopes)
response = await client.post(self.token_url, data=data)
response.raise_for_status()
token_data = response.json()
self._access_token = token_data["access_token"]
expires_in = token_data.get("expires_in", 3600)
self._token_expires_at = time.time() + expires_in - 60
class OAuth2AuthorizationCode(ClientAuthScheme):
"""OAuth2 Authorization Code flow authentication.
Thread-safe implementation with asyncio.Lock to prevent concurrent token operations.
Note: Requires interactive authorization.
Attributes:
authorization_url: OAuth2 authorization endpoint URL.
token_url: OAuth2 token endpoint URL.
client_id: OAuth2 client identifier.
client_secret: OAuth2 client secret.
redirect_uri: OAuth2 redirect URI for callback.
scopes: List of required OAuth2 scopes.
"""
authorization_url: str = Field(description="OAuth2 authorization endpoint")
token_url: str = Field(description="OAuth2 token endpoint")
client_id: str = Field(description="OAuth2 client ID")
client_secret: str = Field(description="OAuth2 client secret")
redirect_uri: str = Field(description="OAuth2 redirect URI")
scopes: list[str] = Field(
default_factory=list, description="Required OAuth2 scopes"
)
_access_token: str | None = PrivateAttr(default=None)
_refresh_token: str | None = PrivateAttr(default=None)
_token_expires_at: float | None = PrivateAttr(default=None)
_authorization_callback: Callable[[str], Awaitable[str]] | None = PrivateAttr(
default=None
)
_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
def set_authorization_callback(
self, callback: Callable[[str], Awaitable[str]] | None
) -> None:
"""Set callback to handle authorization URL.
Args:
callback: Async function that receives authorization URL and returns auth code.
"""
self._authorization_callback = callback
async def apply_auth(
self, client: httpx.AsyncClient, headers: MutableMapping[str, str]
) -> MutableMapping[str, str]:
"""Apply OAuth2 access token to Authorization header.
Uses asyncio.Lock to ensure only one coroutine handles token operations
(initial fetch or refresh) at a time.
Args:
client: HTTP client for making token requests.
headers: Current request headers.
Returns:
Updated headers with OAuth2 access token in Authorization header.
Raises:
ValueError: If authorization callback is not set.
"""
if self._access_token is None:
if self._authorization_callback is None:
msg = "Authorization callback not set. Use set_authorization_callback()"
raise ValueError(msg)
async with self._lock:
if self._access_token is None:
await self._fetch_initial_token(client)
elif self._token_expires_at and time.time() >= self._token_expires_at:
async with self._lock:
if self._token_expires_at and time.time() >= self._token_expires_at:
await self._refresh_access_token(client)
if self._access_token:
headers["Authorization"] = f"Bearer {self._access_token}"
return headers
async def _fetch_initial_token(self, client: httpx.AsyncClient) -> None:
"""Fetch initial access token using authorization code flow.
Args:
client: HTTP client for making token request.
Raises:
ValueError: If authorization callback is not set.
httpx.HTTPStatusError: If token request fails.
"""
params = {
"response_type": "code",
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"scope": " ".join(self.scopes),
}
auth_url = f"{self.authorization_url}?{urllib.parse.urlencode(params)}"
if self._authorization_callback is None:
msg = "Authorization callback not set"
raise ValueError(msg)
auth_code = await self._authorization_callback(auth_url)
data = {
"grant_type": "authorization_code",
"code": auth_code,
"client_id": self.client_id,
"client_secret": self.client_secret,
"redirect_uri": self.redirect_uri,
}
response = await client.post(self.token_url, data=data)
response.raise_for_status()
token_data = response.json()
self._access_token = token_data["access_token"]
self._refresh_token = token_data.get("refresh_token")
expires_in = token_data.get("expires_in", 3600)
self._token_expires_at = time.time() + expires_in - 60
async def _refresh_access_token(self, client: httpx.AsyncClient) -> None:
"""Refresh the access token using refresh token.
Args:
client: HTTP client for making token request.
Raises:
httpx.HTTPStatusError: If token refresh request fails.
"""
if not self._refresh_token:
await self._fetch_initial_token(client)
return
data = {
"grant_type": "refresh_token",
"refresh_token": self._refresh_token,
"client_id": self.client_id,
"client_secret": self.client_secret,
}
response = await client.post(self.token_url, data=data)
response.raise_for_status()
token_data = response.json()
self._access_token = token_data["access_token"]
if "refresh_token" in token_data:
self._refresh_token = token_data["refresh_token"]
expires_in = token_data.get("expires_in", 3600)
self._token_expires_at = time.time() + expires_in - 60
from crewai_a2a.auth.client_schemes import * # noqa: E402, F403

View File

@@ -1,71 +1,13 @@
"""Deprecated: Authentication schemes for A2A protocol agents.
"""Backward-compatibility shim — use ``crewai_a2a.auth.schemas`` instead."""
This module is deprecated. Import from crewai.a2a.auth instead:
- crewai.a2a.auth.ClientAuthScheme (replaces AuthScheme)
- crewai.a2a.auth.BearerTokenAuth
- crewai.a2a.auth.HTTPBasicAuth
- crewai.a2a.auth.HTTPDigestAuth
- crewai.a2a.auth.APIKeyAuth
- crewai.a2a.auth.OAuth2ClientCredentials
- crewai.a2a.auth.OAuth2AuthorizationCode
"""
import warnings
from __future__ import annotations
from typing_extensions import deprecated
from crewai.a2a.auth.client_schemes import (
APIKeyAuth as _APIKeyAuth,
BearerTokenAuth as _BearerTokenAuth,
ClientAuthScheme as _ClientAuthScheme,
HTTPBasicAuth as _HTTPBasicAuth,
HTTPDigestAuth as _HTTPDigestAuth,
OAuth2AuthorizationCode as _OAuth2AuthorizationCode,
OAuth2ClientCredentials as _OAuth2ClientCredentials,
warnings.warn(
"'crewai.a2a.auth.schemas' has been moved to 'crewai_a2a.auth.schemas'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
@deprecated("Use ClientAuthScheme from crewai.a2a.auth instead", category=FutureWarning)
class AuthScheme(_ClientAuthScheme):
"""Deprecated: Use ClientAuthScheme from crewai.a2a.auth instead."""
@deprecated("Import from crewai.a2a.auth instead", category=FutureWarning)
class BearerTokenAuth(_BearerTokenAuth):
"""Deprecated: Import from crewai.a2a.auth instead."""
@deprecated("Import from crewai.a2a.auth instead", category=FutureWarning)
class HTTPBasicAuth(_HTTPBasicAuth):
"""Deprecated: Import from crewai.a2a.auth instead."""
@deprecated("Import from crewai.a2a.auth instead", category=FutureWarning)
class HTTPDigestAuth(_HTTPDigestAuth):
"""Deprecated: Import from crewai.a2a.auth instead."""
@deprecated("Import from crewai.a2a.auth instead", category=FutureWarning)
class APIKeyAuth(_APIKeyAuth):
"""Deprecated: Import from crewai.a2a.auth instead."""
@deprecated("Import from crewai.a2a.auth instead", category=FutureWarning)
class OAuth2ClientCredentials(_OAuth2ClientCredentials):
"""Deprecated: Import from crewai.a2a.auth instead."""
@deprecated("Import from crewai.a2a.auth instead", category=FutureWarning)
class OAuth2AuthorizationCode(_OAuth2AuthorizationCode):
"""Deprecated: Import from crewai.a2a.auth instead."""
__all__ = [
"APIKeyAuth",
"AuthScheme",
"BearerTokenAuth",
"HTTPBasicAuth",
"HTTPDigestAuth",
"OAuth2AuthorizationCode",
"OAuth2ClientCredentials",
]
from crewai_a2a.auth.schemas import * # noqa: E402, F403

View File

@@ -1,739 +1,13 @@
"""Server-side authentication schemes for A2A protocol.
"""Backward-compatibility shim — use ``crewai_a2a.auth.server_schemes`` instead."""
These schemes validate incoming requests to A2A server endpoints.
import warnings
Supported authentication methods:
- Simple token validation with static bearer tokens
- OpenID Connect with JWT validation using JWKS
- OAuth2 with JWT validation or token introspection
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
import logging
import os
from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal
import jwt
from jwt import PyJWKClient
from pydantic import (
BaseModel,
BeforeValidator,
ConfigDict,
Field,
HttpUrl,
PrivateAttr,
SecretStr,
model_validator,
warnings.warn(
"'crewai.a2a.auth.server_schemes' has been moved to 'crewai_a2a.auth.server_schemes'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
from typing_extensions import Self
if TYPE_CHECKING:
from a2a.types import OAuth2SecurityScheme
logger = logging.getLogger(__name__)
try:
from fastapi import HTTPException, status as http_status
HTTP_401_UNAUTHORIZED = http_status.HTTP_401_UNAUTHORIZED
HTTP_500_INTERNAL_SERVER_ERROR = http_status.HTTP_500_INTERNAL_SERVER_ERROR
HTTP_503_SERVICE_UNAVAILABLE = http_status.HTTP_503_SERVICE_UNAVAILABLE
except ImportError:
class HTTPException(Exception): # type: ignore[no-redef] # noqa: N818
"""Fallback HTTPException when FastAPI is not installed."""
def __init__(
self,
status_code: int,
detail: str | None = None,
headers: dict[str, str] | None = None,
) -> None:
self.status_code = status_code
self.detail = detail
self.headers = headers
super().__init__(detail)
HTTP_401_UNAUTHORIZED = 401
HTTP_500_INTERNAL_SERVER_ERROR = 500
HTTP_503_SERVICE_UNAVAILABLE = 503
def _coerce_secret_str(v: str | SecretStr | None) -> SecretStr | None:
"""Coerce string to SecretStr."""
if v is None or isinstance(v, SecretStr):
return v
return SecretStr(v)
CoercedSecretStr = Annotated[SecretStr, BeforeValidator(_coerce_secret_str)]
JWTAlgorithm = Literal[
"RS256",
"RS384",
"RS512",
"ES256",
"ES384",
"ES512",
"PS256",
"PS384",
"PS512",
]
@dataclass
class AuthenticatedUser:
"""Result of successful authentication.
Attributes:
token: The original token that was validated.
scheme: Name of the authentication scheme used.
claims: JWT claims from OIDC or OAuth2 authentication.
"""
token: str
scheme: str
claims: dict[str, Any] | None = None
class ServerAuthScheme(ABC, BaseModel):
"""Base class for server-side authentication schemes.
Each scheme validates incoming requests and returns an AuthenticatedUser
on success, or raises HTTPException on failure.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
@abstractmethod
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Authenticate the provided token.
Args:
token: The bearer token to authenticate.
Returns:
AuthenticatedUser on successful authentication.
Raises:
HTTPException: If authentication fails.
"""
...
class SimpleTokenAuth(ServerAuthScheme):
"""Simple bearer token authentication.
Validates tokens against a configured static token or AUTH_TOKEN env var.
Attributes:
token: Expected token value. Falls back to AUTH_TOKEN env var if not set.
"""
token: CoercedSecretStr | None = Field(
default=None,
description="Expected token. Falls back to AUTH_TOKEN env var.",
)
def _get_expected_token(self) -> str | None:
"""Get the expected token value."""
if self.token:
return self.token.get_secret_value()
return os.environ.get("AUTH_TOKEN")
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Authenticate using simple token comparison.
Args:
token: The bearer token to authenticate.
Returns:
AuthenticatedUser on successful authentication.
Raises:
HTTPException: If authentication fails.
"""
expected = self._get_expected_token()
if expected is None:
logger.warning(
"Simple token authentication failed",
extra={"reason": "no_token_configured"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Authentication not configured",
)
if token != expected:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid or missing authentication credentials",
)
return AuthenticatedUser(
token=token,
scheme="simple_token",
)
class OIDCAuth(ServerAuthScheme):
"""OpenID Connect authentication.
Validates JWTs using JWKS with caching support via PyJWT.
Attributes:
issuer: The OpenID Connect issuer URL.
audience: The expected audience claim.
jwks_url: Optional explicit JWKS URL. Derived from issuer if not set.
algorithms: List of allowed signing algorithms.
required_claims: List of claims that must be present in the token.
jwks_cache_ttl: TTL for JWKS cache in seconds.
clock_skew_seconds: Allowed clock skew for token validation.
"""
issuer: HttpUrl = Field(
description="OpenID Connect issuer URL (e.g., https://auth.example.com)"
)
audience: str = Field(description="Expected audience claim (e.g., api://my-agent)")
jwks_url: HttpUrl | None = Field(
default=None,
description="Explicit JWKS URL. Derived from issuer if not set.",
)
algorithms: list[str] = Field(
default_factory=lambda: ["RS256"],
description="List of allowed signing algorithms (RS256, ES256, etc.)",
)
required_claims: list[str] = Field(
default_factory=lambda: ["exp", "iat", "iss", "aud", "sub"],
description="List of claims that must be present in the token",
)
jwks_cache_ttl: int = Field(
default=3600,
description="TTL for JWKS cache in seconds",
ge=60,
)
clock_skew_seconds: float = Field(
default=30.0,
description="Allowed clock skew for token validation",
ge=0.0,
)
_jwk_client: PyJWKClient | None = PrivateAttr(default=None)
@model_validator(mode="after")
def _init_jwk_client(self) -> Self:
"""Initialize the JWK client after model creation."""
jwks_url = (
str(self.jwks_url)
if self.jwks_url
else f"{str(self.issuer).rstrip('/')}/.well-known/jwks.json"
)
self._jwk_client = PyJWKClient(jwks_url, lifespan=self.jwks_cache_ttl)
return self
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Authenticate using OIDC JWT validation.
Args:
token: The JWT to authenticate.
Returns:
AuthenticatedUser on successful authentication.
Raises:
HTTPException: If authentication fails.
"""
if self._jwk_client is None:
raise HTTPException(
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
detail="OIDC not initialized",
)
try:
signing_key = self._jwk_client.get_signing_key_from_jwt(token)
claims = jwt.decode(
token,
signing_key.key,
algorithms=self.algorithms,
audience=self.audience,
issuer=str(self.issuer).rstrip("/"),
leeway=self.clock_skew_seconds,
options={
"require": self.required_claims,
},
)
return AuthenticatedUser(
token=token,
scheme="oidc",
claims=claims,
)
except jwt.ExpiredSignatureError:
logger.debug(
"OIDC authentication failed",
extra={"reason": "token_expired", "scheme": "oidc"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Token has expired",
) from None
except jwt.InvalidAudienceError:
logger.debug(
"OIDC authentication failed",
extra={"reason": "invalid_audience", "scheme": "oidc"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid token audience",
) from None
except jwt.InvalidIssuerError:
logger.debug(
"OIDC authentication failed",
extra={"reason": "invalid_issuer", "scheme": "oidc"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid token issuer",
) from None
except jwt.MissingRequiredClaimError as e:
logger.debug(
"OIDC authentication failed",
extra={"reason": "missing_claim", "claim": e.claim, "scheme": "oidc"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail=f"Missing required claim: {e.claim}",
) from None
except jwt.PyJWKClientError as e:
logger.error(
"OIDC authentication failed",
extra={
"reason": "jwks_client_error",
"error": str(e),
"scheme": "oidc",
},
)
raise HTTPException(
status_code=HTTP_503_SERVICE_UNAVAILABLE,
detail="Unable to fetch signing keys",
) from None
except jwt.InvalidTokenError as e:
logger.debug(
"OIDC authentication failed",
extra={"reason": "invalid_token", "error": str(e), "scheme": "oidc"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid or missing authentication credentials",
) from None
class OAuth2ServerAuth(ServerAuthScheme):
"""OAuth2 authentication for A2A server.
Declares OAuth2 security scheme in AgentCard and validates tokens using
either JWKS for JWT tokens or token introspection for opaque tokens.
This is distinct from OIDCAuth in that it declares an explicit OAuth2SecurityScheme
with flows, rather than an OpenIdConnectSecurityScheme with discovery URL.
Attributes:
token_url: OAuth2 token endpoint URL for client_credentials flow.
authorization_url: OAuth2 authorization endpoint for authorization_code flow.
refresh_url: Optional refresh token endpoint URL.
scopes: Available OAuth2 scopes with descriptions.
jwks_url: JWKS URL for JWT validation. Required if not using introspection.
introspection_url: Token introspection endpoint (RFC 7662). Alternative to JWKS.
introspection_client_id: Client ID for introspection endpoint authentication.
introspection_client_secret: Client secret for introspection endpoint.
audience: Expected audience claim for JWT validation.
issuer: Expected issuer claim for JWT validation.
algorithms: Allowed JWT signing algorithms.
required_claims: Claims that must be present in the token.
jwks_cache_ttl: TTL for JWKS cache in seconds.
clock_skew_seconds: Allowed clock skew for token validation.
"""
token_url: HttpUrl = Field(
description="OAuth2 token endpoint URL",
)
authorization_url: HttpUrl | None = Field(
default=None,
description="OAuth2 authorization endpoint URL for authorization_code flow",
)
refresh_url: HttpUrl | None = Field(
default=None,
description="OAuth2 refresh token endpoint URL",
)
scopes: dict[str, str] = Field(
default_factory=dict,
description="Available OAuth2 scopes with descriptions",
)
jwks_url: HttpUrl | None = Field(
default=None,
description="JWKS URL for JWT validation. Required if not using introspection.",
)
introspection_url: HttpUrl | None = Field(
default=None,
description="Token introspection endpoint (RFC 7662). Alternative to JWKS.",
)
introspection_client_id: str | None = Field(
default=None,
description="Client ID for introspection endpoint authentication",
)
introspection_client_secret: CoercedSecretStr | None = Field(
default=None,
description="Client secret for introspection endpoint authentication",
)
audience: str | None = Field(
default=None,
description="Expected audience claim for JWT validation",
)
issuer: str | None = Field(
default=None,
description="Expected issuer claim for JWT validation",
)
algorithms: list[str] = Field(
default_factory=lambda: ["RS256"],
description="Allowed JWT signing algorithms",
)
required_claims: list[str] = Field(
default_factory=lambda: ["exp", "iat"],
description="Claims that must be present in the token",
)
jwks_cache_ttl: int = Field(
default=3600,
description="TTL for JWKS cache in seconds",
ge=60,
)
clock_skew_seconds: float = Field(
default=30.0,
description="Allowed clock skew for token validation",
ge=0.0,
)
_jwk_client: PyJWKClient | None = PrivateAttr(default=None)
@model_validator(mode="after")
def _validate_and_init(self) -> Self:
"""Validate configuration and initialize JWKS client if needed."""
if not self.jwks_url and not self.introspection_url:
raise ValueError(
"Either jwks_url or introspection_url must be provided for token validation"
)
if self.introspection_url:
if not self.introspection_client_id or not self.introspection_client_secret:
raise ValueError(
"introspection_client_id and introspection_client_secret are required "
"when using token introspection"
)
if self.jwks_url:
self._jwk_client = PyJWKClient(
str(self.jwks_url), lifespan=self.jwks_cache_ttl
)
return self
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Authenticate using OAuth2 token validation.
Uses JWKS validation if jwks_url is configured, otherwise falls back
to token introspection.
Args:
token: The OAuth2 access token to authenticate.
Returns:
AuthenticatedUser on successful authentication.
Raises:
HTTPException: If authentication fails.
"""
if self._jwk_client:
return await self._authenticate_jwt(token)
return await self._authenticate_introspection(token)
async def _authenticate_jwt(self, token: str) -> AuthenticatedUser:
"""Authenticate using JWKS JWT validation."""
if self._jwk_client is None:
raise HTTPException(
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
detail="OAuth2 JWKS not initialized",
)
try:
signing_key = self._jwk_client.get_signing_key_from_jwt(token)
decode_options: dict[str, Any] = {
"require": self.required_claims,
}
claims = jwt.decode(
token,
signing_key.key,
algorithms=self.algorithms,
audience=self.audience,
issuer=self.issuer,
leeway=self.clock_skew_seconds,
options=decode_options,
)
return AuthenticatedUser(
token=token,
scheme="oauth2",
claims=claims,
)
except jwt.ExpiredSignatureError:
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "token_expired", "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Token has expired",
) from None
except jwt.InvalidAudienceError:
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "invalid_audience", "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid token audience",
) from None
except jwt.InvalidIssuerError:
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "invalid_issuer", "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid token issuer",
) from None
except jwt.MissingRequiredClaimError as e:
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "missing_claim", "claim": e.claim, "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail=f"Missing required claim: {e.claim}",
) from None
except jwt.PyJWKClientError as e:
logger.error(
"OAuth2 authentication failed",
extra={
"reason": "jwks_client_error",
"error": str(e),
"scheme": "oauth2",
},
)
raise HTTPException(
status_code=HTTP_503_SERVICE_UNAVAILABLE,
detail="Unable to fetch signing keys",
) from None
except jwt.InvalidTokenError as e:
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "invalid_token", "error": str(e), "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid or missing authentication credentials",
) from None
async def _authenticate_introspection(self, token: str) -> AuthenticatedUser:
"""Authenticate using OAuth2 token introspection (RFC 7662)."""
import httpx
if not self.introspection_url:
raise HTTPException(
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
detail="OAuth2 introspection not configured",
)
try:
async with httpx.AsyncClient() as client:
response = await client.post(
str(self.introspection_url),
data={"token": token},
auth=(
self.introspection_client_id or "",
self.introspection_client_secret.get_secret_value()
if self.introspection_client_secret
else "",
),
)
response.raise_for_status()
introspection_result = response.json()
except httpx.HTTPStatusError as e:
logger.error(
"OAuth2 introspection failed",
extra={"reason": "http_error", "status_code": e.response.status_code},
)
raise HTTPException(
status_code=HTTP_503_SERVICE_UNAVAILABLE,
detail="Token introspection service unavailable",
) from None
except Exception as e:
logger.error(
"OAuth2 introspection failed",
extra={"reason": "unexpected_error", "error": str(e)},
)
raise HTTPException(
status_code=HTTP_503_SERVICE_UNAVAILABLE,
detail="Token introspection failed",
) from None
if not introspection_result.get("active", False):
logger.debug(
"OAuth2 authentication failed",
extra={"reason": "token_not_active", "scheme": "oauth2"},
)
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Token is not active",
)
return AuthenticatedUser(
token=token,
scheme="oauth2",
claims=introspection_result,
)
def to_security_scheme(self) -> OAuth2SecurityScheme:
"""Generate OAuth2SecurityScheme for AgentCard declaration.
Creates an OAuth2SecurityScheme with appropriate flows based on
the configured URLs. Includes client_credentials flow if token_url
is set, and authorization_code flow if authorization_url is set.
Returns:
OAuth2SecurityScheme suitable for use in AgentCard security_schemes.
"""
from a2a.types import (
AuthorizationCodeOAuthFlow,
ClientCredentialsOAuthFlow,
OAuth2SecurityScheme,
OAuthFlows,
)
client_credentials = None
authorization_code = None
if self.token_url:
client_credentials = ClientCredentialsOAuthFlow(
token_url=str(self.token_url),
refresh_url=str(self.refresh_url) if self.refresh_url else None,
scopes=self.scopes,
)
if self.authorization_url:
authorization_code = AuthorizationCodeOAuthFlow(
authorization_url=str(self.authorization_url),
token_url=str(self.token_url),
refresh_url=str(self.refresh_url) if self.refresh_url else None,
scopes=self.scopes,
)
return OAuth2SecurityScheme(
flows=OAuthFlows(
client_credentials=client_credentials,
authorization_code=authorization_code,
),
description="OAuth2 authentication",
)
class APIKeyServerAuth(ServerAuthScheme):
"""API Key authentication for A2A server.
Validates requests using an API key in a header, query parameter, or cookie.
Attributes:
name: The name of the API key parameter (default: X-API-Key).
location: Where to look for the API key (header, query, or cookie).
api_key: The expected API key value.
"""
name: str = Field(
default="X-API-Key",
description="Name of the API key parameter",
)
location: Literal["header", "query", "cookie"] = Field(
default="header",
description="Where to look for the API key",
)
api_key: CoercedSecretStr = Field(
description="Expected API key value",
)
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Authenticate using API key comparison.
Args:
token: The API key to authenticate.
Returns:
AuthenticatedUser on successful authentication.
Raises:
HTTPException: If authentication fails.
"""
if token != self.api_key.get_secret_value():
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
)
return AuthenticatedUser(
token=token,
scheme="api_key",
)
class MTLSServerAuth(ServerAuthScheme):
"""Mutual TLS authentication marker for AgentCard declaration.
This scheme is primarily for AgentCard security_schemes declaration.
Actual mTLS verification happens at the TLS/transport layer, not
at the application layer via token validation.
When configured, this signals to clients that the server requires
client certificates for authentication.
"""
description: str = Field(
default="Mutual TLS certificate authentication",
description="Description for the security scheme",
)
async def authenticate(self, token: str) -> AuthenticatedUser:
"""Return authenticated user for mTLS.
mTLS verification happens at the transport layer before this is called.
If we reach this point, the TLS handshake with client cert succeeded.
Args:
token: Certificate subject or identifier (from TLS layer).
Returns:
AuthenticatedUser indicating mTLS authentication.
"""
return AuthenticatedUser(
token=token or "mtls-verified",
scheme="mtls",
)
from crewai_a2a.auth.server_schemes import * # noqa: E402, F403

View File

@@ -1,273 +1,13 @@
"""Authentication utilities for A2A protocol agent communication.
"""Backward-compatibility shim — use ``crewai_a2a.auth.utils`` instead."""
Provides validation and retry logic for various authentication schemes including
OAuth2, API keys, and HTTP authentication methods.
"""
import warnings
import asyncio
from collections.abc import Awaitable, Callable, MutableMapping
import hashlib
import re
import threading
from typing import Final, Literal, cast
from a2a.client.errors import A2AClientHTTPError
from a2a.types import (
APIKeySecurityScheme,
AgentCard,
HTTPAuthSecurityScheme,
OAuth2SecurityScheme,
)
from httpx import AsyncClient, Response
from crewai.a2a.auth.client_schemes import (
APIKeyAuth,
BearerTokenAuth,
ClientAuthScheme,
HTTPBasicAuth,
HTTPDigestAuth,
OAuth2AuthorizationCode,
OAuth2ClientCredentials,
warnings.warn(
"'crewai.a2a.auth.utils' has been moved to 'crewai_a2a.auth.utils'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
class _AuthStore:
"""Store for authentication schemes with safe concurrent access."""
def __init__(self) -> None:
self._store: dict[str, ClientAuthScheme | None] = {}
self._lock = threading.RLock()
@staticmethod
def compute_key(auth_type: str, auth_data: str) -> str:
"""Compute a collision-resistant key using SHA-256."""
content = f"{auth_type}:{auth_data}"
return hashlib.sha256(content.encode()).hexdigest()
def set(self, key: str, auth: ClientAuthScheme | None) -> None:
"""Store an auth scheme."""
with self._lock:
self._store[key] = auth
def get(self, key: str) -> ClientAuthScheme | None:
"""Retrieve an auth scheme by key."""
with self._lock:
return self._store.get(key)
def __setitem__(self, key: str, value: ClientAuthScheme | None) -> None:
with self._lock:
self._store[key] = value
def __getitem__(self, key: str) -> ClientAuthScheme | None:
with self._lock:
return self._store[key]
_auth_store = _AuthStore()
_SCHEME_PATTERN: Final[re.Pattern[str]] = re.compile(r"(\w+)\s+(.+?)(?=,\s*\w+\s+|$)")
_PARAM_PATTERN: Final[re.Pattern[str]] = re.compile(r'(\w+)=(?:"([^"]*)"|([^\s,]+))')
_SCHEME_AUTH_MAPPING: Final[dict[type, tuple[type[ClientAuthScheme], ...]]] = {
OAuth2SecurityScheme: (
OAuth2ClientCredentials,
OAuth2AuthorizationCode,
BearerTokenAuth,
),
APIKeySecurityScheme: (APIKeyAuth,),
}
_HTTPSchemeType = Literal["basic", "digest", "bearer"]
_HTTP_SCHEME_MAPPING: Final[dict[_HTTPSchemeType, type[ClientAuthScheme]]] = {
"basic": HTTPBasicAuth,
"digest": HTTPDigestAuth,
"bearer": BearerTokenAuth,
}
def _raise_auth_mismatch(
expected_classes: type[ClientAuthScheme] | tuple[type[ClientAuthScheme], ...],
provided_auth: ClientAuthScheme,
) -> None:
"""Raise authentication mismatch error.
Args:
expected_classes: Expected authentication class or tuple of classes.
provided_auth: Actually provided authentication instance.
Raises:
A2AClientHTTPError: Always raises with 401 status code.
"""
if isinstance(expected_classes, tuple):
if len(expected_classes) == 1:
required = expected_classes[0].__name__
else:
names = [cls.__name__ for cls in expected_classes]
required = f"one of ({', '.join(names)})"
else:
required = expected_classes.__name__
msg = (
f"AgentCard requires {required} authentication, "
f"but {type(provided_auth).__name__} was provided"
)
raise A2AClientHTTPError(401, msg)
def parse_www_authenticate(header_value: str) -> dict[str, dict[str, str]]:
"""Parse WWW-Authenticate header into auth challenges.
Args:
header_value: The WWW-Authenticate header value.
Returns:
Dictionary mapping auth scheme to its parameters.
Example: {"Bearer": {"realm": "api", "scope": "read write"}}
"""
if not header_value:
return {}
challenges: dict[str, dict[str, str]] = {}
for match in _SCHEME_PATTERN.finditer(header_value):
scheme = match.group(1)
params_str = match.group(2)
params: dict[str, str] = {}
for param_match in _PARAM_PATTERN.finditer(params_str):
key = param_match.group(1)
value = param_match.group(2) or param_match.group(3)
params[key] = value
challenges[scheme] = params
return challenges
def validate_auth_against_agent_card(
agent_card: AgentCard, auth: ClientAuthScheme | None
) -> None:
"""Validate that provided auth matches AgentCard security requirements.
Args:
agent_card: The A2A AgentCard containing security requirements.
auth: User-provided authentication scheme (or None).
Raises:
A2AClientHTTPError: If auth doesn't match AgentCard requirements (status_code=401).
"""
if not agent_card.security or not agent_card.security_schemes:
return
if not auth:
msg = "AgentCard requires authentication but no auth scheme provided"
raise A2AClientHTTPError(401, msg)
first_security_req = agent_card.security[0] if agent_card.security else {}
for scheme_name in first_security_req.keys():
security_scheme_wrapper = agent_card.security_schemes.get(scheme_name)
if not security_scheme_wrapper:
continue
scheme = security_scheme_wrapper.root
if allowed_classes := _SCHEME_AUTH_MAPPING.get(type(scheme)):
if not isinstance(auth, allowed_classes):
_raise_auth_mismatch(allowed_classes, auth)
return
if isinstance(scheme, HTTPAuthSecurityScheme):
scheme_key = cast(_HTTPSchemeType, scheme.scheme.lower())
if required_class := _HTTP_SCHEME_MAPPING.get(scheme_key):
if not isinstance(auth, required_class):
_raise_auth_mismatch(required_class, auth)
return
msg = "Could not validate auth against AgentCard security requirements"
raise A2AClientHTTPError(401, msg)
async def retry_on_401(
request_func: Callable[[], Awaitable[Response]],
auth_scheme: ClientAuthScheme | None,
client: AsyncClient,
headers: MutableMapping[str, str],
max_retries: int = 3,
) -> Response:
"""Retry a request on 401 authentication error.
Handles 401 errors by:
1. Parsing WWW-Authenticate header
2. Re-acquiring credentials
3. Retrying the request
Args:
request_func: Async function that makes the HTTP request.
auth_scheme: Authentication scheme to refresh credentials with.
client: HTTP client for making requests.
headers: Request headers to update with new auth.
max_retries: Maximum number of retry attempts (default: 3).
Returns:
HTTP response from the request.
Raises:
httpx.HTTPStatusError: If retries are exhausted or auth scheme is None.
"""
last_response: Response | None = None
last_challenges: dict[str, dict[str, str]] = {}
for attempt in range(max_retries):
response = await request_func()
if response.status_code != 401:
return response
last_response = response
if auth_scheme is None:
response.raise_for_status()
return response
www_authenticate = response.headers.get("WWW-Authenticate", "")
challenges = parse_www_authenticate(www_authenticate)
last_challenges = challenges
if attempt >= max_retries - 1:
break
backoff_time = 2**attempt
await asyncio.sleep(backoff_time)
await auth_scheme.apply_auth(client, headers)
if last_response:
last_response.raise_for_status()
return last_response
msg = "retry_on_401 failed without making any requests"
if last_challenges:
challenge_info = ", ".join(
f"{scheme} (realm={params.get('realm', 'N/A')})"
for scheme, params in last_challenges.items()
)
msg = f"{msg}. Server challenges: {challenge_info}"
raise RuntimeError(msg)
def configure_auth_client(
auth: HTTPDigestAuth | APIKeyAuth, client: AsyncClient
) -> None:
"""Configure HTTP client with auth-specific settings.
Only HTTPDigestAuth and APIKeyAuth need client configuration.
Args:
auth: Authentication scheme that requires client configuration.
client: HTTP client to configure.
"""
auth.configure_client(client)
from crewai_a2a.auth.utils import * # noqa: E402, F403

View File

@@ -1,690 +1,13 @@
"""A2A configuration types.
"""Backward-compatibility shim — use ``crewai_a2a.config`` instead."""
This module is separate from experimental.a2a to avoid circular imports.
"""
from __future__ import annotations
from pathlib import Path
from typing import Any, ClassVar, Literal, cast
import warnings
from pydantic import (
BaseModel,
ConfigDict,
Field,
FilePath,
PrivateAttr,
SecretStr,
model_validator,
warnings.warn(
"'crewai.a2a.config' has been moved to 'crewai_a2a.config'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
from typing_extensions import Self, deprecated
from crewai.a2a.auth.client_schemes import ClientAuthScheme
from crewai.a2a.auth.server_schemes import ServerAuthScheme
from crewai.a2a.extensions.base import ValidatedA2AExtension
from crewai.a2a.types import ProtocolVersion, TransportType, Url
try:
from a2a.types import (
AgentCapabilities,
AgentCardSignature,
AgentInterface,
AgentProvider,
AgentSkill,
SecurityScheme,
)
from crewai.a2a.extensions.server import ServerExtension
from crewai.a2a.updates import UpdateConfig
except ImportError:
UpdateConfig: Any = Any # type: ignore[no-redef]
AgentCapabilities: Any = Any # type: ignore[no-redef]
AgentCardSignature: Any = Any # type: ignore[no-redef]
AgentInterface: Any = Any # type: ignore[no-redef]
AgentProvider: Any = Any # type: ignore[no-redef]
SecurityScheme: Any = Any # type: ignore[no-redef]
AgentSkill: Any = Any # type: ignore[no-redef]
ServerExtension: Any = Any # type: ignore[no-redef]
def _get_default_update_config() -> UpdateConfig:
from crewai.a2a.updates import StreamingConfig
return StreamingConfig()
SigningAlgorithm = Literal[
"RS256",
"RS384",
"RS512",
"ES256",
"ES384",
"ES512",
"PS256",
"PS384",
"PS512",
]
class AgentCardSigningConfig(BaseModel):
"""Configuration for AgentCard JWS signing.
Provides the private key and algorithm settings for signing AgentCards.
Either private_key_path or private_key_pem must be provided, but not both.
Attributes:
private_key_path: Path to a PEM-encoded private key file.
private_key_pem: PEM-encoded private key as a secret string.
key_id: Optional key identifier for the JWS header (kid claim).
algorithm: Signing algorithm (RS256, ES256, PS256, etc.).
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
private_key_path: FilePath | None = Field(
default=None,
description="Path to PEM-encoded private key file",
)
private_key_pem: SecretStr | None = Field(
default=None,
description="PEM-encoded private key",
)
key_id: str | None = Field(
default=None,
description="Key identifier for JWS header (kid claim)",
)
algorithm: SigningAlgorithm = Field(
default="RS256",
description="Signing algorithm (RS256, ES256, PS256, etc.)",
)
@model_validator(mode="after")
def _validate_key_source(self) -> Self:
"""Ensure exactly one key source is provided."""
has_path = self.private_key_path is not None
has_pem = self.private_key_pem is not None
if not has_path and not has_pem:
raise ValueError(
"Either private_key_path or private_key_pem must be provided"
)
if has_path and has_pem:
raise ValueError(
"Only one of private_key_path or private_key_pem should be provided"
)
return self
def get_private_key(self) -> str:
"""Get the private key content.
Returns:
The PEM-encoded private key as a string.
"""
if self.private_key_pem:
return self.private_key_pem.get_secret_value()
if self.private_key_path:
return Path(self.private_key_path).read_text()
raise ValueError("No private key configured")
class GRPCServerConfig(BaseModel):
"""gRPC server transport configuration.
Presence of this config in ServerTransportConfig.grpc enables gRPC transport.
Attributes:
host: Hostname to advertise in agent cards (default: localhost).
Use docker service name (e.g., 'web') for docker-compose setups.
port: Port for the gRPC server.
tls_cert_path: Path to TLS certificate file for gRPC.
tls_key_path: Path to TLS private key file for gRPC.
max_workers: Maximum number of workers for the gRPC thread pool.
reflection_enabled: Whether to enable gRPC reflection for debugging.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
host: str = Field(
default="localhost",
description="Hostname to advertise in agent cards for gRPC connections",
)
port: int = Field(
default=50051,
description="Port for the gRPC server",
)
tls_cert_path: str | None = Field(
default=None,
description="Path to TLS certificate file for gRPC",
)
tls_key_path: str | None = Field(
default=None,
description="Path to TLS private key file for gRPC",
)
max_workers: int = Field(
default=10,
description="Maximum number of workers for the gRPC thread pool",
)
reflection_enabled: bool = Field(
default=False,
description="Whether to enable gRPC reflection for debugging",
)
class GRPCClientConfig(BaseModel):
"""gRPC client transport configuration.
Attributes:
max_send_message_length: Maximum size for outgoing messages in bytes.
max_receive_message_length: Maximum size for incoming messages in bytes.
keepalive_time_ms: Time between keepalive pings in milliseconds.
keepalive_timeout_ms: Timeout for keepalive ping response in milliseconds.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
max_send_message_length: int | None = Field(
default=None,
description="Maximum size for outgoing messages in bytes",
)
max_receive_message_length: int | None = Field(
default=None,
description="Maximum size for incoming messages in bytes",
)
keepalive_time_ms: int | None = Field(
default=None,
description="Time between keepalive pings in milliseconds",
)
keepalive_timeout_ms: int | None = Field(
default=None,
description="Timeout for keepalive ping response in milliseconds",
)
class JSONRPCServerConfig(BaseModel):
"""JSON-RPC server transport configuration.
Presence of this config in ServerTransportConfig.jsonrpc enables JSON-RPC transport.
Attributes:
rpc_path: URL path for the JSON-RPC endpoint.
agent_card_path: URL path for the agent card endpoint.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
rpc_path: str = Field(
default="/a2a",
description="URL path for the JSON-RPC endpoint",
)
agent_card_path: str = Field(
default="/.well-known/agent-card.json",
description="URL path for the agent card endpoint",
)
class JSONRPCClientConfig(BaseModel):
"""JSON-RPC client transport configuration.
Attributes:
max_request_size: Maximum request body size in bytes.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
max_request_size: int | None = Field(
default=None,
description="Maximum request body size in bytes",
)
class HTTPJSONConfig(BaseModel):
"""HTTP+JSON transport configuration.
Presence of this config in ServerTransportConfig.http_json enables HTTP+JSON transport.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
class ServerPushNotificationConfig(BaseModel):
"""Configuration for outgoing webhook push notifications.
Controls how the server signs and delivers push notifications to clients.
Attributes:
signature_secret: Shared secret for HMAC-SHA256 signing of outgoing webhooks.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
signature_secret: SecretStr | None = Field(
default=None,
description="Shared secret for HMAC-SHA256 signing of outgoing push notifications",
)
class ServerTransportConfig(BaseModel):
"""Transport configuration for A2A server.
Groups all transport-related settings including preferred transport
and protocol-specific configurations.
Attributes:
preferred: Transport protocol for the preferred endpoint.
jsonrpc: JSON-RPC server transport configuration.
grpc: gRPC server transport configuration.
http_json: HTTP+JSON transport configuration.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
preferred: TransportType = Field(
default="JSONRPC",
description="Transport protocol for the preferred endpoint",
)
jsonrpc: JSONRPCServerConfig = Field(
default_factory=JSONRPCServerConfig,
description="JSON-RPC server transport configuration",
)
grpc: GRPCServerConfig | None = Field(
default=None,
description="gRPC server transport configuration",
)
http_json: HTTPJSONConfig | None = Field(
default=None,
description="HTTP+JSON transport configuration",
)
def _migrate_client_transport_fields(
transport: ClientTransportConfig,
transport_protocol: TransportType | None,
supported_transports: list[TransportType] | None,
) -> None:
"""Migrate deprecated transport fields to new config."""
if transport_protocol is not None:
warnings.warn(
"transport_protocol is deprecated, use transport=ClientTransportConfig(preferred=...) instead",
FutureWarning,
stacklevel=5,
)
object.__setattr__(transport, "preferred", transport_protocol)
if supported_transports is not None:
warnings.warn(
"supported_transports is deprecated, use transport=ClientTransportConfig(supported=...) instead",
FutureWarning,
stacklevel=5,
)
object.__setattr__(transport, "supported", supported_transports)
class ClientTransportConfig(BaseModel):
"""Transport configuration for A2A client.
Groups all client transport-related settings including preferred transport,
supported transports for negotiation, and protocol-specific configurations.
Transport negotiation logic:
1. If `preferred` is set and server supports it → use client's preferred
2. Otherwise, if server's preferred is in client's `supported` → use server's preferred
3. Otherwise, find first match from client's `supported` in server's interfaces
Attributes:
preferred: Client's preferred transport. If set, client preference takes priority.
supported: Transports the client can use, in order of preference.
jsonrpc: JSON-RPC client transport configuration.
grpc: gRPC client transport configuration.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
preferred: TransportType | None = Field(
default=None,
description="Client's preferred transport. If set, takes priority over server preference.",
)
supported: list[TransportType] = Field(
default_factory=lambda: cast(list[TransportType], ["JSONRPC"]),
description="Transports the client can use, in order of preference",
)
jsonrpc: JSONRPCClientConfig = Field(
default_factory=JSONRPCClientConfig,
description="JSON-RPC client transport configuration",
)
grpc: GRPCClientConfig = Field(
default_factory=GRPCClientConfig,
description="gRPC client transport configuration",
)
@deprecated(
"""
`crewai.a2a.config.A2AConfig` is deprecated and will be removed in v2.0.0,
use `crewai.a2a.config.A2AClientConfig` or `crewai.a2a.config.A2AServerConfig` instead.
""",
category=FutureWarning,
)
class A2AConfig(BaseModel):
"""Configuration for A2A protocol integration.
Deprecated:
Use A2AClientConfig instead. This class will be removed in a future version.
Attributes:
endpoint: A2A agent endpoint URL.
auth: Authentication scheme.
timeout: Request timeout in seconds.
max_turns: Maximum conversation turns with A2A agent.
response_model: Optional Pydantic model for structured A2A agent responses.
fail_fast: If True, raise error when agent unreachable; if False, skip and continue.
trust_remote_completion_status: If True, return A2A agent's result directly when completed.
updates: Update mechanism config.
client_extensions: Client-side processing hooks for tool injection and prompt augmentation.
transport: Transport configuration (preferred, supported transports, gRPC settings).
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
endpoint: Url = Field(description="A2A agent endpoint URL")
auth: ClientAuthScheme | None = Field(
default=None,
description="Authentication scheme",
)
timeout: int = Field(default=120, description="Request timeout in seconds")
max_turns: int = Field(
default=10, description="Maximum conversation turns with A2A agent"
)
response_model: type[BaseModel] | None = Field(
default=None,
description="Optional Pydantic model for structured A2A agent responses",
)
fail_fast: bool = Field(
default=True,
description="If True, raise error when agent unreachable; if False, skip",
)
trust_remote_completion_status: bool = Field(
default=False,
description="If True, return A2A result directly when completed",
)
updates: UpdateConfig = Field(
default_factory=_get_default_update_config,
description="Update mechanism config",
)
client_extensions: list[ValidatedA2AExtension] = Field(
default_factory=list,
description="Client-side processing hooks for tool injection and prompt augmentation",
)
transport: ClientTransportConfig = Field(
default_factory=ClientTransportConfig,
description="Transport configuration (preferred, supported transports, gRPC settings)",
)
transport_protocol: TransportType | None = Field(
default=None,
description="Deprecated: Use transport.preferred instead",
exclude=True,
)
supported_transports: list[TransportType] | None = Field(
default=None,
description="Deprecated: Use transport.supported instead",
exclude=True,
)
use_client_preference: bool | None = Field(
default=None,
description="Deprecated: Set transport.preferred to enable client preference",
exclude=True,
)
_parallel_delegation: bool = PrivateAttr(default=False)
@model_validator(mode="after")
def _migrate_deprecated_transport_fields(self) -> Self:
"""Migrate deprecated transport fields to new config."""
_migrate_client_transport_fields(
self.transport, self.transport_protocol, self.supported_transports
)
if self.use_client_preference is not None:
warnings.warn(
"use_client_preference is deprecated, set transport.preferred to enable client preference",
FutureWarning,
stacklevel=4,
)
if self.use_client_preference and self.transport.supported:
object.__setattr__(
self.transport, "preferred", self.transport.supported[0]
)
return self
class A2AClientConfig(BaseModel):
"""Configuration for connecting to remote A2A agents.
Attributes:
endpoint: A2A agent endpoint URL.
auth: Authentication scheme.
timeout: Request timeout in seconds.
max_turns: Maximum conversation turns with A2A agent.
response_model: Optional Pydantic model for structured A2A agent responses.
fail_fast: If True, raise error when agent unreachable; if False, skip and continue.
trust_remote_completion_status: If True, return A2A agent's result directly when completed.
updates: Update mechanism config.
accepted_output_modes: Media types the client can accept in responses.
extensions: Extension URIs the client supports (A2A protocol extensions).
client_extensions: Client-side processing hooks for tool injection and prompt augmentation.
transport: Transport configuration (preferred, supported transports, gRPC settings).
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
endpoint: Url = Field(description="A2A agent endpoint URL")
auth: ClientAuthScheme | None = Field(
default=None,
description="Authentication scheme",
)
timeout: int = Field(default=120, description="Request timeout in seconds")
max_turns: int = Field(
default=10, description="Maximum conversation turns with A2A agent"
)
response_model: type[BaseModel] | None = Field(
default=None,
description="Optional Pydantic model for structured A2A agent responses",
)
fail_fast: bool = Field(
default=True,
description="If True, raise error when agent unreachable; if False, skip",
)
trust_remote_completion_status: bool = Field(
default=False,
description="If True, return A2A result directly when completed",
)
updates: UpdateConfig = Field(
default_factory=_get_default_update_config,
description="Update mechanism config",
)
accepted_output_modes: list[str] = Field(
default_factory=lambda: ["application/json"],
description="Media types the client can accept in responses",
)
extensions: list[str] = Field(
default_factory=list,
description="Extension URIs the client supports",
)
client_extensions: list[ValidatedA2AExtension] = Field(
default_factory=list,
description="Client-side processing hooks for tool injection and prompt augmentation",
)
transport: ClientTransportConfig = Field(
default_factory=ClientTransportConfig,
description="Transport configuration (preferred, supported transports, gRPC settings)",
)
transport_protocol: TransportType | None = Field(
default=None,
description="Deprecated: Use transport.preferred instead",
exclude=True,
)
supported_transports: list[TransportType] | None = Field(
default=None,
description="Deprecated: Use transport.supported instead",
exclude=True,
)
_parallel_delegation: bool = PrivateAttr(default=False)
@model_validator(mode="after")
def _migrate_deprecated_transport_fields(self) -> Self:
"""Migrate deprecated transport fields to new config."""
_migrate_client_transport_fields(
self.transport, self.transport_protocol, self.supported_transports
)
return self
class A2AServerConfig(BaseModel):
"""Configuration for exposing a Crew or Agent as an A2A server.
All fields correspond to A2A AgentCard fields. Fields like name, description,
and skills can be auto-derived from the Crew/Agent if not provided.
Attributes:
name: Human-readable name for the agent.
description: Human-readable description of the agent.
version: Version string for the agent card.
skills: List of agent skills/capabilities.
default_input_modes: Default supported input MIME types.
default_output_modes: Default supported output MIME types.
capabilities: Declaration of optional capabilities.
protocol_version: A2A protocol version this agent supports.
provider: Information about the agent's service provider.
documentation_url: URL to the agent's documentation.
icon_url: URL to an icon for the agent.
additional_interfaces: Additional supported interfaces.
security: Security requirement objects for all interactions.
security_schemes: Security schemes available to authorize requests.
supports_authenticated_extended_card: Whether agent provides extended card to authenticated users.
url: Preferred endpoint URL for the agent.
signing_config: Configuration for signing the AgentCard with JWS.
signatures: Deprecated. Pre-computed JWS signatures. Use signing_config instead.
server_extensions: Server-side A2A protocol extensions with on_request/on_response hooks.
push_notifications: Configuration for outgoing push notifications.
transport: Transport configuration (preferred transport, gRPC, REST settings).
auth: Authentication scheme for A2A endpoints.
"""
model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
name: str | None = Field(
default=None,
description="Human-readable name for the agent. Auto-derived from Crew/Agent if not provided.",
)
description: str | None = Field(
default=None,
description="Human-readable description of the agent. Auto-derived from Crew/Agent if not provided.",
)
version: str = Field(
default="1.0.0",
description="Version string for the agent card",
)
skills: list[AgentSkill] = Field(
default_factory=list,
description="List of agent skills. Auto-derived from tasks/tools if not provided.",
)
default_input_modes: list[str] = Field(
default_factory=lambda: ["text/plain", "application/json"],
description="Default supported input MIME types",
)
default_output_modes: list[str] = Field(
default_factory=lambda: ["text/plain", "application/json"],
description="Default supported output MIME types",
)
capabilities: AgentCapabilities = Field(
default_factory=lambda: AgentCapabilities(
streaming=True,
push_notifications=False,
),
description="Declaration of optional capabilities supported by the agent",
)
protocol_version: ProtocolVersion = Field(
default="0.3.0",
description="A2A protocol version this agent supports",
)
provider: AgentProvider | None = Field(
default=None,
description="Information about the agent's service provider",
)
documentation_url: Url | None = Field(
default=None,
description="URL to the agent's documentation",
)
icon_url: Url | None = Field(
default=None,
description="URL to an icon for the agent",
)
additional_interfaces: list[AgentInterface] = Field(
default_factory=list,
description="Additional supported interfaces.",
)
security: list[dict[str, list[str]]] = Field(
default_factory=list,
description="Security requirement objects for all agent interactions",
)
security_schemes: dict[str, SecurityScheme] = Field(
default_factory=dict,
description="Security schemes available to authorize requests",
)
supports_authenticated_extended_card: bool = Field(
default=False,
description="Whether agent provides extended card to authenticated users",
)
url: Url | None = Field(
default=None,
description="Preferred endpoint URL for the agent. Set at runtime if not provided.",
)
signing_config: AgentCardSigningConfig | None = Field(
default=None,
description="Configuration for signing the AgentCard with JWS",
)
signatures: list[AgentCardSignature] | None = Field(
default=None,
description="Deprecated: Use signing_config instead. Pre-computed JWS signatures for the AgentCard.",
exclude=True,
deprecated=True,
)
server_extensions: list[ServerExtension] = Field(
default_factory=list,
description="Server-side A2A protocol extensions that modify agent behavior",
)
push_notifications: ServerPushNotificationConfig | None = Field(
default=None,
description="Configuration for outgoing push notifications",
)
transport: ServerTransportConfig = Field(
default_factory=ServerTransportConfig,
description="Transport configuration (preferred transport, gRPC, REST settings)",
)
preferred_transport: TransportType | None = Field(
default=None,
description="Deprecated: Use transport.preferred instead",
exclude=True,
deprecated=True,
)
auth: ServerAuthScheme | None = Field(
default=None,
description="Authentication scheme for A2A endpoints. Defaults to SimpleTokenAuth using AUTH_TOKEN env var.",
)
@model_validator(mode="after")
def _migrate_deprecated_fields(self) -> Self:
"""Migrate deprecated fields to new config."""
if self.preferred_transport is not None:
warnings.warn(
"preferred_transport is deprecated, use transport=ServerTransportConfig(preferred=...) instead",
FutureWarning,
stacklevel=4,
)
object.__setattr__(self.transport, "preferred", self.preferred_transport)
if self.signatures is not None:
warnings.warn(
"signatures is deprecated, use signing_config=AgentCardSigningConfig(...) instead. "
"The signatures field will be removed in v2.0.0.",
FutureWarning,
stacklevel=4,
)
return self
from crewai_a2a.config import * # noqa: E402, F403

View File

@@ -1,491 +1,13 @@
"""A2A error codes and error response utilities.
"""Backward-compatibility shim — use ``crewai_a2a.errors`` instead."""
This module provides a centralized mapping of all A2A protocol error codes
as defined in the A2A specification, plus custom CrewAI extensions.
import warnings
Error codes follow JSON-RPC 2.0 conventions:
- -32700 to -32600: Standard JSON-RPC errors
- -32099 to -32000: Server errors (A2A-specific)
- -32768 to -32100: Reserved for implementation-defined errors
"""
from __future__ import annotations
warnings.warn(
"'crewai.a2a.errors' has been moved to 'crewai_a2a.errors'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Any
from a2a.client.errors import A2AClientTimeoutError
class A2APollingTimeoutError(A2AClientTimeoutError):
"""Raised when polling exceeds the configured timeout."""
class A2AErrorCode(IntEnum):
"""A2A protocol error codes.
Codes follow JSON-RPC 2.0 specification with A2A-specific extensions.
"""
# JSON-RPC 2.0 Standard Errors (-32700 to -32600)
JSON_PARSE_ERROR = -32700
"""Invalid JSON was received by the server."""
INVALID_REQUEST = -32600
"""The JSON sent is not a valid Request object."""
METHOD_NOT_FOUND = -32601
"""The method does not exist / is not available."""
INVALID_PARAMS = -32602
"""Invalid method parameter(s)."""
INTERNAL_ERROR = -32603
"""Internal JSON-RPC error."""
# A2A-Specific Errors (-32099 to -32000)
TASK_NOT_FOUND = -32001
"""The specified task was not found."""
TASK_NOT_CANCELABLE = -32002
"""The task cannot be canceled (already completed/failed)."""
PUSH_NOTIFICATION_NOT_SUPPORTED = -32003
"""Push notifications are not supported by this agent."""
UNSUPPORTED_OPERATION = -32004
"""The requested operation is not supported."""
CONTENT_TYPE_NOT_SUPPORTED = -32005
"""Incompatible content types between client and server."""
INVALID_AGENT_RESPONSE = -32006
"""The agent produced an invalid response."""
# CrewAI Custom Extensions (-32768 to -32100)
UNSUPPORTED_VERSION = -32009
"""The requested A2A protocol version is not supported."""
UNSUPPORTED_EXTENSION = -32010
"""Client does not support required protocol extensions."""
AUTHENTICATION_REQUIRED = -32011
"""Authentication is required for this operation."""
AUTHORIZATION_FAILED = -32012
"""Authorization check failed (insufficient permissions)."""
RATE_LIMIT_EXCEEDED = -32013
"""Rate limit exceeded for this client/operation."""
TASK_TIMEOUT = -32014
"""Task execution timed out."""
TRANSPORT_NEGOTIATION_FAILED = -32015
"""Failed to negotiate a compatible transport protocol."""
CONTEXT_NOT_FOUND = -32016
"""The specified context was not found."""
SKILL_NOT_FOUND = -32017
"""The specified skill was not found."""
ARTIFACT_NOT_FOUND = -32018
"""The specified artifact was not found."""
# Error code to default message mapping
ERROR_MESSAGES: dict[int, str] = {
A2AErrorCode.JSON_PARSE_ERROR: "Parse error",
A2AErrorCode.INVALID_REQUEST: "Invalid Request",
A2AErrorCode.METHOD_NOT_FOUND: "Method not found",
A2AErrorCode.INVALID_PARAMS: "Invalid params",
A2AErrorCode.INTERNAL_ERROR: "Internal error",
A2AErrorCode.TASK_NOT_FOUND: "Task not found",
A2AErrorCode.TASK_NOT_CANCELABLE: "Task not cancelable",
A2AErrorCode.PUSH_NOTIFICATION_NOT_SUPPORTED: "Push Notification is not supported",
A2AErrorCode.UNSUPPORTED_OPERATION: "This operation is not supported",
A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED: "Incompatible content types",
A2AErrorCode.INVALID_AGENT_RESPONSE: "Invalid agent response",
A2AErrorCode.UNSUPPORTED_VERSION: "Unsupported A2A version",
A2AErrorCode.UNSUPPORTED_EXTENSION: "Client does not support required extensions",
A2AErrorCode.AUTHENTICATION_REQUIRED: "Authentication required",
A2AErrorCode.AUTHORIZATION_FAILED: "Authorization failed",
A2AErrorCode.RATE_LIMIT_EXCEEDED: "Rate limit exceeded",
A2AErrorCode.TASK_TIMEOUT: "Task execution timed out",
A2AErrorCode.TRANSPORT_NEGOTIATION_FAILED: "Transport negotiation failed",
A2AErrorCode.CONTEXT_NOT_FOUND: "Context not found",
A2AErrorCode.SKILL_NOT_FOUND: "Skill not found",
A2AErrorCode.ARTIFACT_NOT_FOUND: "Artifact not found",
}
@dataclass
class A2AError(Exception):
"""Base exception for A2A protocol errors.
Attributes:
code: The A2A/JSON-RPC error code.
message: Human-readable error message.
data: Optional additional error data.
"""
code: int
message: str | None = None
data: Any = None
def __post_init__(self) -> None:
if self.message is None:
self.message = ERROR_MESSAGES.get(self.code, "Unknown error")
super().__init__(self.message)
def to_dict(self) -> dict[str, Any]:
"""Convert to JSON-RPC error object format."""
error: dict[str, Any] = {
"code": self.code,
"message": self.message,
}
if self.data is not None:
error["data"] = self.data
return error
def to_response(self, request_id: str | int | None = None) -> dict[str, Any]:
"""Convert to full JSON-RPC error response."""
return {
"jsonrpc": "2.0",
"error": self.to_dict(),
"id": request_id,
}
@dataclass
class JSONParseError(A2AError):
"""Invalid JSON was received."""
code: int = field(default=A2AErrorCode.JSON_PARSE_ERROR, init=False)
@dataclass
class InvalidRequestError(A2AError):
"""The JSON sent is not a valid Request object."""
code: int = field(default=A2AErrorCode.INVALID_REQUEST, init=False)
@dataclass
class MethodNotFoundError(A2AError):
"""The method does not exist / is not available."""
code: int = field(default=A2AErrorCode.METHOD_NOT_FOUND, init=False)
method: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.method:
self.message = f"Method not found: {self.method}"
super().__post_init__()
@dataclass
class InvalidParamsError(A2AError):
"""Invalid method parameter(s)."""
code: int = field(default=A2AErrorCode.INVALID_PARAMS, init=False)
param: str | None = None
reason: str | None = None
def __post_init__(self) -> None:
if self.message is None:
if self.param and self.reason:
self.message = f"Invalid parameter '{self.param}': {self.reason}"
elif self.param:
self.message = f"Invalid parameter: {self.param}"
super().__post_init__()
@dataclass
class InternalError(A2AError):
"""Internal JSON-RPC error."""
code: int = field(default=A2AErrorCode.INTERNAL_ERROR, init=False)
@dataclass
class TaskNotFoundError(A2AError):
"""The specified task was not found."""
code: int = field(default=A2AErrorCode.TASK_NOT_FOUND, init=False)
task_id: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.task_id:
self.message = f"Task not found: {self.task_id}"
super().__post_init__()
@dataclass
class TaskNotCancelableError(A2AError):
"""The task cannot be canceled."""
code: int = field(default=A2AErrorCode.TASK_NOT_CANCELABLE, init=False)
task_id: str | None = None
reason: str | None = None
def __post_init__(self) -> None:
if self.message is None:
if self.task_id and self.reason:
self.message = f"Task {self.task_id} cannot be canceled: {self.reason}"
elif self.task_id:
self.message = f"Task {self.task_id} cannot be canceled"
super().__post_init__()
@dataclass
class PushNotificationNotSupportedError(A2AError):
"""Push notifications are not supported."""
code: int = field(default=A2AErrorCode.PUSH_NOTIFICATION_NOT_SUPPORTED, init=False)
@dataclass
class UnsupportedOperationError(A2AError):
"""The requested operation is not supported."""
code: int = field(default=A2AErrorCode.UNSUPPORTED_OPERATION, init=False)
operation: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.operation:
self.message = f"Operation not supported: {self.operation}"
super().__post_init__()
@dataclass
class ContentTypeNotSupportedError(A2AError):
"""Incompatible content types."""
code: int = field(default=A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED, init=False)
requested_types: list[str] | None = None
supported_types: list[str] | None = None
def __post_init__(self) -> None:
if self.message is None and self.requested_types and self.supported_types:
self.message = (
f"Content type not supported. Requested: {self.requested_types}, "
f"Supported: {self.supported_types}"
)
super().__post_init__()
@dataclass
class InvalidAgentResponseError(A2AError):
"""The agent produced an invalid response."""
code: int = field(default=A2AErrorCode.INVALID_AGENT_RESPONSE, init=False)
@dataclass
class UnsupportedVersionError(A2AError):
"""The requested A2A version is not supported."""
code: int = field(default=A2AErrorCode.UNSUPPORTED_VERSION, init=False)
requested_version: str | None = None
supported_versions: list[str] | None = None
def __post_init__(self) -> None:
if self.message is None and self.requested_version:
msg = f"Unsupported A2A version: {self.requested_version}"
if self.supported_versions:
msg += f". Supported versions: {', '.join(self.supported_versions)}"
self.message = msg
super().__post_init__()
@dataclass
class UnsupportedExtensionError(A2AError):
"""Client does not support required extensions."""
code: int = field(default=A2AErrorCode.UNSUPPORTED_EXTENSION, init=False)
required_extensions: list[str] | None = None
def __post_init__(self) -> None:
if self.message is None and self.required_extensions:
self.message = f"Client does not support required extensions: {', '.join(self.required_extensions)}"
super().__post_init__()
@dataclass
class AuthenticationRequiredError(A2AError):
"""Authentication is required."""
code: int = field(default=A2AErrorCode.AUTHENTICATION_REQUIRED, init=False)
@dataclass
class AuthorizationFailedError(A2AError):
"""Authorization check failed."""
code: int = field(default=A2AErrorCode.AUTHORIZATION_FAILED, init=False)
required_scope: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.required_scope:
self.message = (
f"Authorization failed. Required scope: {self.required_scope}"
)
super().__post_init__()
@dataclass
class RateLimitExceededError(A2AError):
"""Rate limit exceeded."""
code: int = field(default=A2AErrorCode.RATE_LIMIT_EXCEEDED, init=False)
retry_after: int | None = None
def __post_init__(self) -> None:
if self.message is None and self.retry_after:
self.message = (
f"Rate limit exceeded. Retry after {self.retry_after} seconds"
)
if self.retry_after:
self.data = {"retry_after": self.retry_after}
super().__post_init__()
@dataclass
class TaskTimeoutError(A2AError):
"""Task execution timed out."""
code: int = field(default=A2AErrorCode.TASK_TIMEOUT, init=False)
task_id: str | None = None
timeout_seconds: float | None = None
def __post_init__(self) -> None:
if self.message is None:
if self.task_id and self.timeout_seconds:
self.message = (
f"Task {self.task_id} timed out after {self.timeout_seconds}s"
)
elif self.task_id:
self.message = f"Task {self.task_id} timed out"
super().__post_init__()
@dataclass
class TransportNegotiationFailedError(A2AError):
"""Failed to negotiate a compatible transport protocol."""
code: int = field(default=A2AErrorCode.TRANSPORT_NEGOTIATION_FAILED, init=False)
client_transports: list[str] | None = None
server_transports: list[str] | None = None
def __post_init__(self) -> None:
if self.message is None and self.client_transports and self.server_transports:
self.message = (
f"Transport negotiation failed. Client: {self.client_transports}, "
f"Server: {self.server_transports}"
)
super().__post_init__()
@dataclass
class ContextNotFoundError(A2AError):
"""The specified context was not found."""
code: int = field(default=A2AErrorCode.CONTEXT_NOT_FOUND, init=False)
context_id: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.context_id:
self.message = f"Context not found: {self.context_id}"
super().__post_init__()
@dataclass
class SkillNotFoundError(A2AError):
"""The specified skill was not found."""
code: int = field(default=A2AErrorCode.SKILL_NOT_FOUND, init=False)
skill_id: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.skill_id:
self.message = f"Skill not found: {self.skill_id}"
super().__post_init__()
@dataclass
class ArtifactNotFoundError(A2AError):
"""The specified artifact was not found."""
code: int = field(default=A2AErrorCode.ARTIFACT_NOT_FOUND, init=False)
artifact_id: str | None = None
def __post_init__(self) -> None:
if self.message is None and self.artifact_id:
self.message = f"Artifact not found: {self.artifact_id}"
super().__post_init__()
def create_error_response(
code: int | A2AErrorCode,
message: str | None = None,
data: Any = None,
request_id: str | int | None = None,
) -> dict[str, Any]:
"""Create a JSON-RPC error response.
Args:
code: Error code (A2AErrorCode or int).
message: Optional error message (uses default if not provided).
data: Optional additional error data.
request_id: Request ID for correlation.
Returns:
Dict in JSON-RPC error response format.
"""
error = A2AError(code=int(code), message=message, data=data)
return error.to_response(request_id)
def is_retryable_error(code: int) -> bool:
"""Check if an error is potentially retryable.
Args:
code: Error code to check.
Returns:
True if the error might be resolved by retrying.
"""
retryable_codes = {
A2AErrorCode.INTERNAL_ERROR,
A2AErrorCode.RATE_LIMIT_EXCEEDED,
A2AErrorCode.TASK_TIMEOUT,
}
return code in retryable_codes
def is_client_error(code: int) -> bool:
"""Check if an error is a client-side error.
Args:
code: Error code to check.
Returns:
True if the error is due to client request issues.
"""
client_error_codes = {
A2AErrorCode.JSON_PARSE_ERROR,
A2AErrorCode.INVALID_REQUEST,
A2AErrorCode.METHOD_NOT_FOUND,
A2AErrorCode.INVALID_PARAMS,
A2AErrorCode.TASK_NOT_FOUND,
A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED,
A2AErrorCode.UNSUPPORTED_VERSION,
A2AErrorCode.UNSUPPORTED_EXTENSION,
A2AErrorCode.CONTEXT_NOT_FOUND,
A2AErrorCode.SKILL_NOT_FOUND,
A2AErrorCode.ARTIFACT_NOT_FOUND,
}
return code in client_error_codes
from crewai_a2a.errors import * # noqa: E402, F403

View File

@@ -1,37 +1,13 @@
"""A2A Protocol Extensions for CrewAI.
"""Backward-compatibility shim — use ``crewai_a2a.extensions`` instead."""
This module contains extensions to the A2A (Agent-to-Agent) protocol.
import warnings
**Client-side extensions** (A2AExtension) allow customizing how the A2A wrapper
processes requests and responses during delegation to remote agents. These provide
hooks for tool injection, prompt augmentation, and response processing.
**Server-side extensions** (ServerExtension) allow agents to offer additional
functionality beyond the core A2A specification. Clients activate extensions
via the X-A2A-Extensions header.
See: https://a2a-protocol.org/latest/topics/extensions/
"""
from crewai.a2a.extensions.base import (
A2AExtension,
ConversationState,
ExtensionRegistry,
ValidatedA2AExtension,
)
from crewai.a2a.extensions.server import (
ExtensionContext,
ServerExtension,
ServerExtensionRegistry,
warnings.warn(
"'crewai.a2a.extensions' has been moved to 'crewai_a2a.extensions'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
__all__ = [
"A2AExtension",
"ConversationState",
"ExtensionContext",
"ExtensionRegistry",
"ServerExtension",
"ServerExtensionRegistry",
"ValidatedA2AExtension",
]
from crewai_a2a.extensions import * # noqa: E402, F403

View File

@@ -1,238 +1,13 @@
"""Base extension interface for CrewAI A2A wrapper processing hooks.
"""Backward-compatibility shim — use ``crewai_a2a.extensions.base`` instead."""
This module defines the protocol for extending CrewAI's A2A wrapper functionality
with custom logic for tool injection, prompt augmentation, and response processing.
Note: These are CrewAI-specific processing hooks, NOT A2A protocol extensions.
A2A protocol extensions are capability declarations using AgentExtension objects
in AgentCard.capabilities.extensions, activated via the A2A-Extensions HTTP header.
See: https://a2a-protocol.org/latest/topics/extensions/
"""
from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING, Annotated, Any, Protocol, runtime_checkable
from pydantic import BeforeValidator
import warnings
if TYPE_CHECKING:
from a2a.types import Message
warnings.warn(
"'crewai.a2a.extensions.base' has been moved to 'crewai_a2a.extensions.base'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
from crewai.agent.core import Agent
def _validate_a2a_extension(v: Any) -> Any:
"""Validate that value implements A2AExtension protocol."""
if not isinstance(v, A2AExtension):
raise ValueError(
f"Value must implement A2AExtension protocol. "
f"Got {type(v).__name__} which is missing required methods."
)
return v
ValidatedA2AExtension = Annotated[Any, BeforeValidator(_validate_a2a_extension)]
@runtime_checkable
class ConversationState(Protocol):
"""Protocol for extension-specific conversation state.
Extensions can define their own state classes that implement this protocol
to track conversation-specific data extracted from message history.
"""
def is_ready(self) -> bool:
"""Check if the state indicates readiness for some action.
Returns:
True if the state is ready, False otherwise.
"""
...
@runtime_checkable
class A2AExtension(Protocol):
"""Protocol for A2A wrapper extensions.
Extensions can implement this protocol to inject custom logic into
the A2A conversation flow at various integration points.
Example:
class MyExtension:
def inject_tools(self, agent: Agent) -> None:
# Add custom tools to the agent
pass
def extract_state_from_history(
self, conversation_history: Sequence[Message]
) -> ConversationState | None:
# Extract state from conversation
return None
def augment_prompt(
self, base_prompt: str, conversation_state: ConversationState | None
) -> str:
# Add custom instructions
return base_prompt
def process_response(
self, agent_response: Any, conversation_state: ConversationState | None
) -> Any:
# Modify response if needed
return agent_response
"""
def inject_tools(self, agent: Agent) -> None:
"""Inject extension-specific tools into the agent.
Called when an agent is wrapped with A2A capabilities. Extensions
can add tools that enable extension-specific functionality.
Args:
agent: The agent instance to inject tools into.
"""
...
def extract_state_from_history(
self, conversation_history: Sequence[Message]
) -> ConversationState | None:
"""Extract extension-specific state from conversation history.
Called during prompt augmentation to allow extensions to analyze
the conversation history and extract relevant state information.
Args:
conversation_history: The sequence of A2A messages exchanged.
Returns:
Extension-specific conversation state, or None if no relevant state.
"""
...
def augment_prompt(
self,
base_prompt: str,
conversation_state: ConversationState | None,
) -> str:
"""Augment the task prompt with extension-specific instructions.
Called during prompt augmentation to allow extensions to add
custom instructions based on conversation state.
Args:
base_prompt: The base prompt to augment.
conversation_state: Extension-specific state from extract_state_from_history.
Returns:
The augmented prompt with extension-specific instructions.
"""
...
def process_response(
self,
agent_response: Any,
conversation_state: ConversationState | None,
) -> Any:
"""Process and potentially modify the agent response.
Called after parsing the agent's response, allowing extensions to
enhance or modify the response based on conversation state.
Args:
agent_response: The parsed agent response.
conversation_state: Extension-specific state from extract_state_from_history.
Returns:
The processed agent response (may be modified or original).
"""
...
class ExtensionRegistry:
"""Registry for managing A2A extensions.
Maintains a collection of extensions and provides methods to invoke
their hooks at various integration points.
"""
def __init__(self) -> None:
"""Initialize the extension registry."""
self._extensions: list[A2AExtension] = []
def register(self, extension: A2AExtension) -> None:
"""Register an extension.
Args:
extension: The extension to register.
"""
self._extensions.append(extension)
def inject_all_tools(self, agent: Agent) -> None:
"""Inject tools from all registered extensions.
Args:
agent: The agent instance to inject tools into.
"""
for extension in self._extensions:
extension.inject_tools(agent)
def extract_all_states(
self, conversation_history: Sequence[Message]
) -> dict[type[A2AExtension], ConversationState]:
"""Extract conversation states from all registered extensions.
Args:
conversation_history: The sequence of A2A messages exchanged.
Returns:
Mapping of extension types to their conversation states.
"""
states: dict[type[A2AExtension], ConversationState] = {}
for extension in self._extensions:
state = extension.extract_state_from_history(conversation_history)
if state is not None:
states[type(extension)] = state
return states
def augment_prompt_with_all(
self,
base_prompt: str,
extension_states: dict[type[A2AExtension], ConversationState],
) -> str:
"""Augment prompt with instructions from all registered extensions.
Args:
base_prompt: The base prompt to augment.
extension_states: Mapping of extension types to conversation states.
Returns:
The fully augmented prompt.
"""
augmented = base_prompt
for extension in self._extensions:
state = extension_states.get(type(extension))
augmented = extension.augment_prompt(augmented, state)
return augmented
def process_response_with_all(
self,
agent_response: Any,
extension_states: dict[type[A2AExtension], ConversationState],
) -> Any:
"""Process response through all registered extensions.
Args:
agent_response: The parsed agent response.
extension_states: Mapping of extension types to conversation states.
Returns:
The processed agent response.
"""
processed = agent_response
for extension in self._extensions:
state = extension_states.get(type(extension))
processed = extension.process_response(processed, state)
return processed
from crewai_a2a.extensions.base import * # noqa: E402, F403

View File

@@ -1,170 +1,13 @@
"""A2A Protocol extension utilities.
"""Backward-compatibility shim — use ``crewai_a2a.extensions.registry`` instead."""
This module provides utilities for working with A2A protocol extensions as
defined in the A2A specification. Extensions are capability declarations in
AgentCard.capabilities.extensions using AgentExtension objects, activated
via the X-A2A-Extensions HTTP header.
import warnings
See: https://a2a-protocol.org/latest/topics/extensions/
"""
from __future__ import annotations
from typing import Any
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
from a2a.extensions.common import (
HTTP_EXTENSION_HEADER,
warnings.warn(
"'crewai.a2a.extensions.registry' has been moved to 'crewai_a2a.extensions.registry'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
from a2a.types import AgentCard, AgentExtension
from crewai.a2a.config import A2AClientConfig, A2AConfig
from crewai.a2a.extensions.base import ExtensionRegistry
def get_extensions_from_config(
a2a_config: list[A2AConfig | A2AClientConfig] | A2AConfig | A2AClientConfig,
) -> list[str]:
"""Extract extension URIs from A2A configuration.
Args:
a2a_config: A2A configuration (single or list).
Returns:
Deduplicated list of extension URIs from all configs.
"""
configs = a2a_config if isinstance(a2a_config, list) else [a2a_config]
seen: set[str] = set()
result: list[str] = []
for config in configs:
if not isinstance(config, A2AClientConfig):
continue
for uri in config.extensions:
if uri not in seen:
seen.add(uri)
result.append(uri)
return result
class ExtensionsMiddleware(ClientCallInterceptor):
"""Middleware to add X-A2A-Extensions header to requests.
This middleware adds the extensions header to all outgoing requests,
declaring which A2A protocol extensions the client supports.
"""
def __init__(self, extensions: list[str]) -> None:
"""Initialize with extension URIs.
Args:
extensions: List of extension URIs the client supports.
"""
self._extensions = extensions
async def intercept(
self,
method_name: str,
request_payload: dict[str, Any],
http_kwargs: dict[str, Any],
agent_card: AgentCard | None,
context: ClientCallContext | None,
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Add extensions header to the request.
Args:
method_name: The A2A method being called.
request_payload: The JSON-RPC request payload.
http_kwargs: HTTP request kwargs (headers, etc).
agent_card: The target agent's card.
context: Optional call context.
Returns:
Tuple of (request_payload, modified_http_kwargs).
"""
if self._extensions:
headers = http_kwargs.setdefault("headers", {})
headers[HTTP_EXTENSION_HEADER] = ",".join(self._extensions)
return request_payload, http_kwargs
def validate_required_extensions(
agent_card: AgentCard,
client_extensions: list[str] | None,
) -> list[AgentExtension]:
"""Validate that client supports all required extensions from agent.
Args:
agent_card: The agent's card with declared extensions.
client_extensions: Extension URIs the client supports.
Returns:
List of unsupported required extensions.
Raises:
None - returns list of unsupported extensions for caller to handle.
"""
unsupported: list[AgentExtension] = []
client_set = set(client_extensions or [])
if not agent_card.capabilities or not agent_card.capabilities.extensions:
return unsupported
unsupported.extend(
ext
for ext in agent_card.capabilities.extensions
if ext.required and ext.uri not in client_set
)
return unsupported
def create_extension_registry_from_config(
a2a_config: list[A2AConfig | A2AClientConfig] | A2AConfig | A2AClientConfig,
) -> ExtensionRegistry:
"""Create an extension registry from A2A client configuration.
Extracts client_extensions from each A2AClientConfig and registers them
with the ExtensionRegistry. These extensions provide CrewAI-specific
processing hooks (tool injection, prompt augmentation, response processing).
Note: A2A protocol extensions (URI strings sent via X-A2A-Extensions header)
are handled separately via get_extensions_from_config() and ExtensionsMiddleware.
Args:
a2a_config: A2A configuration (single or list).
Returns:
Extension registry with all client_extensions registered.
Example:
class LoggingExtension:
def inject_tools(self, agent): pass
def extract_state_from_history(self, history): return None
def augment_prompt(self, prompt, state): return prompt
def process_response(self, response, state):
print(f"Response: {response}")
return response
config = A2AClientConfig(
endpoint="https://agent.example.com",
client_extensions=[LoggingExtension()],
)
registry = create_extension_registry_from_config(config)
"""
registry = ExtensionRegistry()
configs = a2a_config if isinstance(a2a_config, list) else [a2a_config]
seen: set[int] = set()
for config in configs:
if isinstance(config, (A2AConfig, A2AClientConfig)):
client_exts = getattr(config, "client_extensions", [])
for extension in client_exts:
ext_id = id(extension)
if ext_id not in seen:
seen.add(ext_id)
registry.register(extension)
return registry
from crewai_a2a.extensions.registry import * # noqa: E402, F403

View File

@@ -1,305 +1,13 @@
"""A2A protocol server extensions for CrewAI agents.
"""Backward-compatibility shim — use ``crewai_a2a.extensions.server`` instead."""
This module provides the base class and context for implementing A2A protocol
extensions on the server side. Extensions allow agents to offer additional
functionality beyond the core A2A specification.
import warnings
See: https://a2a-protocol.org/latest/topics/extensions/
"""
from __future__ import annotations
warnings.warn(
"'crewai.a2a.extensions.server' has been moved to 'crewai_a2a.extensions.server'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
import logging
from typing import TYPE_CHECKING, Annotated, Any
from a2a.types import AgentExtension
from pydantic_core import CoreSchema, core_schema
if TYPE_CHECKING:
from a2a.server.context import ServerCallContext
from pydantic import GetCoreSchemaHandler
logger = logging.getLogger(__name__)
@dataclass
class ExtensionContext:
"""Context passed to extension hooks during request processing.
Provides access to request metadata, client extensions, and shared state
that extensions can read from and write to.
Attributes:
metadata: Request metadata dict, includes extension-namespaced keys.
client_extensions: Set of extension URIs the client declared support for.
state: Mutable dict for extensions to share data during request lifecycle.
server_context: The underlying A2A server call context.
"""
metadata: dict[str, Any]
client_extensions: set[str]
state: dict[str, Any] = field(default_factory=dict)
server_context: ServerCallContext | None = None
def get_extension_metadata(self, uri: str, key: str) -> Any | None:
"""Get extension-specific metadata value.
Extension metadata uses namespaced keys in the format:
"{extension_uri}/{key}"
Args:
uri: The extension URI.
key: The metadata key within the extension namespace.
Returns:
The metadata value, or None if not present.
"""
full_key = f"{uri}/{key}"
return self.metadata.get(full_key)
def set_extension_metadata(self, uri: str, key: str, value: Any) -> None:
"""Set extension-specific metadata value.
Args:
uri: The extension URI.
key: The metadata key within the extension namespace.
value: The value to set.
"""
full_key = f"{uri}/{key}"
self.metadata[full_key] = value
class ServerExtension(ABC):
"""Base class for A2A protocol server extensions.
Subclass this to create custom extensions that modify agent behavior
when clients activate them. Extensions are identified by URI and can
be marked as required.
Example:
class SamplingExtension(ServerExtension):
uri = "urn:crewai:ext:sampling/v1"
required = True
def __init__(self, max_tokens: int = 4096):
self.max_tokens = max_tokens
@property
def params(self) -> dict[str, Any]:
return {"max_tokens": self.max_tokens}
async def on_request(self, context: ExtensionContext) -> None:
limit = context.get_extension_metadata(self.uri, "limit")
if limit:
context.state["token_limit"] = int(limit)
async def on_response(self, context: ExtensionContext, result: Any) -> Any:
return result
"""
uri: Annotated[str, "Extension URI identifier. Must be unique."]
required: Annotated[bool, "Whether clients must support this extension."] = False
description: Annotated[
str | None, "Human-readable description of the extension."
] = None
@classmethod
def __get_pydantic_core_schema__(
cls,
_source_type: Any,
_handler: GetCoreSchemaHandler,
) -> CoreSchema:
"""Tell Pydantic how to validate ServerExtension instances."""
return core_schema.is_instance_schema(cls)
@property
def params(self) -> dict[str, Any] | None:
"""Extension parameters to advertise in AgentCard.
Override this property to expose configuration that clients can read.
Returns:
Dict of parameter names to values, or None.
"""
return None
def agent_extension(self) -> AgentExtension:
"""Generate the AgentExtension object for the AgentCard.
Returns:
AgentExtension with this extension's URI, required flag, and params.
"""
return AgentExtension(
uri=self.uri,
required=self.required if self.required else None,
description=self.description,
params=self.params,
)
def is_active(self, context: ExtensionContext) -> bool:
"""Check if this extension is active for the current request.
An extension is active if the client declared support for it.
Args:
context: The extension context for the current request.
Returns:
True if the client supports this extension.
"""
return self.uri in context.client_extensions
@abstractmethod
async def on_request(self, context: ExtensionContext) -> None:
"""Called before agent execution if extension is active.
Use this hook to:
- Read extension-specific metadata from the request
- Set up state for the execution
- Modify execution parameters via context.state
Args:
context: The extension context with request metadata and state.
"""
...
@abstractmethod
async def on_response(self, context: ExtensionContext, result: Any) -> Any:
"""Called after agent execution if extension is active.
Use this hook to:
- Modify or enhance the result
- Add extension-specific metadata to the response
- Clean up any resources
Args:
context: The extension context with request metadata and state.
result: The agent execution result.
Returns:
The result, potentially modified.
"""
...
class ServerExtensionRegistry:
"""Registry for managing server-side A2A protocol extensions.
Collects extensions and provides methods to generate AgentCapabilities
and invoke extension hooks during request processing.
"""
def __init__(self, extensions: list[ServerExtension] | None = None) -> None:
"""Initialize the registry with optional extensions.
Args:
extensions: Initial list of extensions to register.
"""
self._extensions: list[ServerExtension] = list(extensions) if extensions else []
self._by_uri: dict[str, ServerExtension] = {
ext.uri: ext for ext in self._extensions
}
def register(self, extension: ServerExtension) -> None:
"""Register an extension.
Args:
extension: The extension to register.
Raises:
ValueError: If an extension with the same URI is already registered.
"""
if extension.uri in self._by_uri:
raise ValueError(f"Extension already registered: {extension.uri}")
self._extensions.append(extension)
self._by_uri[extension.uri] = extension
def get_agent_extensions(self) -> list[AgentExtension]:
"""Get AgentExtension objects for all registered extensions.
Returns:
List of AgentExtension objects for the AgentCard.
"""
return [ext.agent_extension() for ext in self._extensions]
def get_extension(self, uri: str) -> ServerExtension | None:
"""Get an extension by URI.
Args:
uri: The extension URI.
Returns:
The extension, or None if not found.
"""
return self._by_uri.get(uri)
@staticmethod
def create_context(
metadata: dict[str, Any],
client_extensions: set[str],
server_context: ServerCallContext | None = None,
) -> ExtensionContext:
"""Create an ExtensionContext for a request.
Args:
metadata: Request metadata dict.
client_extensions: Set of extension URIs from client.
server_context: Optional server call context.
Returns:
ExtensionContext for use in hooks.
"""
return ExtensionContext(
metadata=metadata,
client_extensions=client_extensions,
server_context=server_context,
)
async def invoke_on_request(self, context: ExtensionContext) -> None:
"""Invoke on_request hooks for all active extensions.
Tracks activated extensions and isolates errors from individual hooks.
Args:
context: The extension context for the request.
"""
for extension in self._extensions:
if extension.is_active(context):
try:
await extension.on_request(context)
if context.server_context is not None:
context.server_context.activated_extensions.add(extension.uri)
except Exception:
logger.exception(
"Extension on_request hook failed",
extra={"extension": extension.uri},
)
async def invoke_on_response(self, context: ExtensionContext, result: Any) -> Any:
"""Invoke on_response hooks for all active extensions.
Isolates errors from individual hooks to prevent one failing extension
from breaking the entire response.
Args:
context: The extension context for the request.
result: The agent execution result.
Returns:
The result after all extensions have processed it.
"""
processed = result
for extension in self._extensions:
if extension.is_active(context):
try:
processed = await extension.on_response(context, processed)
except Exception:
logger.exception(
"Extension on_response hook failed",
extra={"extension": extension.uri},
)
return processed
from crewai_a2a.extensions.server import * # noqa: E402, F403

View File

@@ -1,480 +1,13 @@
"""Helper functions for processing A2A task results."""
"""Backward-compatibility shim — use ``crewai_a2a.task_helpers`` instead."""
from __future__ import annotations
import warnings
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING, Any, TypedDict
import uuid
from a2a.client.errors import A2AClientHTTPError
from a2a.types import (
AgentCard,
Message,
Part,
Role,
Task,
TaskArtifactUpdateEvent,
TaskState,
TaskStatusUpdateEvent,
TextPart,
)
from typing_extensions import NotRequired
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.a2a_events import (
A2AConnectionErrorEvent,
A2AResponseReceivedEvent,
warnings.warn(
"'crewai.a2a.task_helpers' has been moved to 'crewai_a2a.task_helpers'. "
"Please update your imports. The old path will be removed in v2.0.0.",
FutureWarning,
stacklevel=2,
)
if TYPE_CHECKING:
from a2a.types import Task as A2ATask
SendMessageEvent = (
tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | Message
)
TERMINAL_STATES: frozenset[TaskState] = frozenset(
{
TaskState.completed,
TaskState.failed,
TaskState.rejected,
TaskState.canceled,
}
)
ACTIONABLE_STATES: frozenset[TaskState] = frozenset(
{
TaskState.input_required,
TaskState.auth_required,
}
)
PENDING_STATES: frozenset[TaskState] = frozenset(
{
TaskState.submitted,
TaskState.working,
}
)
class TaskStateResult(TypedDict):
"""Result dictionary from processing A2A task state."""
status: TaskState
history: list[Message]
result: NotRequired[str]
error: NotRequired[str]
agent_card: NotRequired[dict[str, Any]]
a2a_agent_name: NotRequired[str | None]
def extract_task_result_parts(a2a_task: A2ATask) -> list[str]:
"""Extract result parts from A2A task status message, history, and artifacts.
Args:
a2a_task: A2A Task object with status, history, and artifacts
Returns:
List of result text parts
"""
result_parts: list[str] = []
if a2a_task.status and a2a_task.status.message:
msg = a2a_task.status.message
result_parts.extend(
part.root.text for part in msg.parts if part.root.kind == "text"
)
if not result_parts and a2a_task.history:
for history_msg in reversed(a2a_task.history):
if history_msg.role == Role.agent:
result_parts.extend(
part.root.text
for part in history_msg.parts
if part.root.kind == "text"
)
break
if a2a_task.artifacts:
result_parts.extend(
part.root.text
for artifact in a2a_task.artifacts
for part in artifact.parts
if part.root.kind == "text"
)
return result_parts
def extract_error_message(a2a_task: A2ATask, default: str) -> str:
"""Extract error message from A2A task.
Args:
a2a_task: A2A Task object
default: Default message if no error found
Returns:
Error message string
"""
if a2a_task.status and a2a_task.status.message:
msg = a2a_task.status.message
if msg:
for part in msg.parts:
if part.root.kind == "text":
return str(part.root.text)
return str(msg)
if a2a_task.history:
for history_msg in reversed(a2a_task.history):
for part in history_msg.parts:
if part.root.kind == "text":
return str(part.root.text)
return default
def process_task_state(
a2a_task: A2ATask,
new_messages: list[Message],
agent_card: AgentCard,
turn_number: int,
is_multiturn: bool,
agent_role: str | None,
result_parts: list[str] | None = None,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
is_final: bool = True,
) -> TaskStateResult | None:
"""Process A2A task state and return result dictionary.
Shared logic for both polling and streaming handlers.
Args:
a2a_task: The A2A task to process.
new_messages: List to collect messages (modified in place).
agent_card: The agent card.
turn_number: Current turn number.
is_multiturn: Whether multi-turn conversation.
agent_role: Agent role for logging.
result_parts: Accumulated result parts (streaming passes accumulated,
polling passes None to extract from task).
endpoint: A2A agent endpoint URL.
a2a_agent_name: Name of the A2A agent from agent card.
from_task: Optional CrewAI Task for event metadata.
from_agent: Optional CrewAI Agent for event metadata.
is_final: Whether this is the final response in the stream.
Returns:
Result dictionary if terminal/actionable state, None otherwise.
"""
if result_parts is None:
result_parts = []
if a2a_task.status.state == TaskState.completed:
if not result_parts:
extracted_parts = extract_task_result_parts(a2a_task)
result_parts.extend(extracted_parts)
if a2a_task.history:
new_messages.extend(a2a_task.history)
response_text = " ".join(result_parts) if result_parts else ""
message_id = None
if a2a_task.status and a2a_task.status.message:
message_id = a2a_task.status.message.message_id
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=response_text,
turn_number=turn_number,
context_id=a2a_task.context_id,
message_id=message_id,
is_multiturn=is_multiturn,
status="completed",
final=is_final,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.completed,
agent_card=agent_card.model_dump(exclude_none=True),
result=response_text,
history=new_messages,
)
if a2a_task.status.state == TaskState.input_required:
if a2a_task.history:
new_messages.extend(a2a_task.history)
response_text = extract_error_message(a2a_task, "Additional input required")
if response_text and not a2a_task.history:
agent_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=response_text))],
context_id=a2a_task.context_id,
task_id=a2a_task.id,
)
new_messages.append(agent_message)
input_message_id = None
if a2a_task.status and a2a_task.status.message:
input_message_id = a2a_task.status.message.message_id
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=response_text,
turn_number=turn_number,
context_id=a2a_task.context_id,
message_id=input_message_id,
is_multiturn=is_multiturn,
status="input_required",
final=is_final,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.input_required,
error=response_text,
history=new_messages,
agent_card=agent_card.model_dump(exclude_none=True),
)
if a2a_task.status.state in {TaskState.failed, TaskState.rejected}:
error_msg = extract_error_message(a2a_task, "Task failed without error message")
if a2a_task.history:
new_messages.extend(a2a_task.history)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
if a2a_task.status.state == TaskState.auth_required:
error_msg = extract_error_message(a2a_task, "Authentication required")
return TaskStateResult(
status=TaskState.auth_required,
error=error_msg,
history=new_messages,
)
if a2a_task.status.state == TaskState.canceled:
error_msg = extract_error_message(a2a_task, "Task was canceled")
return TaskStateResult(
status=TaskState.canceled,
error=error_msg,
history=new_messages,
)
if a2a_task.status.state in PENDING_STATES:
return None
return None
async def send_message_and_get_task_id(
event_stream: AsyncIterator[SendMessageEvent],
new_messages: list[Message],
agent_card: AgentCard,
turn_number: int,
is_multiturn: bool,
agent_role: str | None,
from_task: Any | None = None,
from_agent: Any | None = None,
endpoint: str | None = None,
a2a_agent_name: str | None = None,
context_id: str | None = None,
) -> str | TaskStateResult:
"""Send message and process initial response.
Handles the common pattern of sending a message and either:
- Getting an immediate Message response (task completed synchronously)
- Getting a Task that needs polling/waiting for completion
Args:
event_stream: Async iterator from client.send_message()
new_messages: List to collect messages (modified in place)
agent_card: The agent card
turn_number: Current turn number
is_multiturn: Whether multi-turn conversation
agent_role: Agent role for logging
from_task: Optional CrewAI Task object for event metadata.
from_agent: Optional CrewAI Agent object for event metadata.
endpoint: Optional A2A endpoint URL.
a2a_agent_name: Optional A2A agent name.
context_id: Optional A2A context ID for correlation.
Returns:
Task ID string if agent needs polling/waiting, or TaskStateResult if done.
"""
try:
async for event in event_stream:
if isinstance(event, Message):
new_messages.append(event)
result_parts = [
part.root.text for part in event.parts if part.root.kind == "text"
]
response_text = " ".join(result_parts) if result_parts else ""
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=response_text,
turn_number=turn_number,
context_id=event.context_id,
message_id=event.message_id,
is_multiturn=is_multiturn,
status="completed",
final=True,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.completed,
result=response_text,
history=new_messages,
agent_card=agent_card.model_dump(exclude_none=True),
)
if isinstance(event, tuple):
a2a_task, _ = event
if a2a_task.status.state in TERMINAL_STATES | ACTIONABLE_STATES:
result = process_task_state(
a2a_task=a2a_task,
new_messages=new_messages,
agent_card=agent_card,
turn_number=turn_number,
is_multiturn=is_multiturn,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
)
if result:
return result
return a2a_task.id
return TaskStateResult(
status=TaskState.failed,
error="No task ID received from initial message",
history=new_messages,
)
except A2AClientHTTPError as e:
error_msg = f"HTTP Error {e.status_code}: {e!s}"
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
None,
A2AConnectionErrorEvent(
endpoint=endpoint or "",
error=str(e),
error_type="http_error",
status_code=e.status_code,
a2a_agent_name=a2a_agent_name,
operation="send_message",
context_id=context_id,
from_task=from_task,
from_agent=from_agent,
),
)
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
context_id=context_id,
is_multiturn=is_multiturn,
status="failed",
final=True,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
except Exception as e:
error_msg = f"Unexpected error during send_message: {e!s}"
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
None,
A2AConnectionErrorEvent(
endpoint=endpoint or "",
error=str(e),
error_type="unexpected_error",
a2a_agent_name=a2a_agent_name,
operation="send_message",
context_id=context_id,
from_task=from_task,
from_agent=from_agent,
),
)
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
context_id=context_id,
is_multiturn=is_multiturn,
status="failed",
final=True,
agent_role=agent_role,
endpoint=endpoint,
a2a_agent_name=a2a_agent_name,
from_task=from_task,
from_agent=from_agent,
),
)
return TaskStateResult(
status=TaskState.failed,
error=error_msg,
history=new_messages,
)
finally:
aclose = getattr(event_stream, "aclose", None)
if aclose:
await aclose()
from crewai_a2a.task_helpers import * # noqa: E402, F403

Some files were not shown because too many files have changed in this diff Show More