mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-11 05:22:41 +00:00
Compare commits
45 Commits
1.0.0a4
...
lorenze/na
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
505e3ceea5 | ||
|
|
881b5befad | ||
|
|
97ecd327a8 | ||
|
|
84a3cbad7c | ||
|
|
81cd269318 | ||
|
|
f3da80a1f1 | ||
|
|
037e2b4631 | ||
|
|
fbd72ded44 | ||
|
|
83bc40eefe | ||
|
|
57052b94d3 | ||
|
|
18943babff | ||
|
|
fa5a901d93 | ||
|
|
7e8d33104a | ||
|
|
8a9835a59f | ||
|
|
7a59769c7e | ||
|
|
ba30374ac4 | ||
|
|
6150a358a3 | ||
|
|
61e3ec2e6f | ||
|
|
c5455142c3 | ||
|
|
44bbccdb75 | ||
|
|
0073b4206f | ||
|
|
dcd57ccc9f | ||
|
|
06a45b29db | ||
|
|
7351e4b0ef | ||
|
|
38e7a37485 | ||
|
|
21ba6d5b54 | ||
|
|
d9b68ddd85 | ||
|
|
2d5ad7a187 | ||
|
|
97c2cbd110 | ||
|
|
7045ed389a | ||
|
|
3fc1381e76 | ||
|
|
a9aff87db3 | ||
|
|
c9ff264e8e | ||
|
|
0816b810c7 | ||
|
|
96d142e353 | ||
|
|
53b239c6df | ||
|
|
f0fb349ddf | ||
|
|
cec4e4c2e9 | ||
|
|
bf2e2a42da | ||
|
|
814c962196 | ||
|
|
2ebb2e845f | ||
|
|
6c5ac13242 | ||
|
|
7b550ebfe8 | ||
|
|
52b2f07c9f | ||
|
|
da331ce422 |
@@ -825,6 +825,12 @@
|
||||
"group": "Triggers",
|
||||
"pages": [
|
||||
"pt-BR/enterprise/guides/automation-triggers",
|
||||
"pt-BR/enterprise/guides/gmail-trigger",
|
||||
"pt-BR/enterprise/guides/google-calendar-trigger",
|
||||
"pt-BR/enterprise/guides/google-drive-trigger",
|
||||
"pt-BR/enterprise/guides/outlook-trigger",
|
||||
"pt-BR/enterprise/guides/onedrive-trigger",
|
||||
"pt-BR/enterprise/guides/microsoft-teams-trigger",
|
||||
"pt-BR/enterprise/guides/slack-trigger",
|
||||
"pt-BR/enterprise/guides/hubspot-trigger",
|
||||
"pt-BR/enterprise/guides/salesforce-trigger",
|
||||
@@ -1250,6 +1256,12 @@
|
||||
"group": "트리거",
|
||||
"pages": [
|
||||
"ko/enterprise/guides/automation-triggers",
|
||||
"ko/enterprise/guides/gmail-trigger",
|
||||
"ko/enterprise/guides/google-calendar-trigger",
|
||||
"ko/enterprise/guides/google-drive-trigger",
|
||||
"ko/enterprise/guides/outlook-trigger",
|
||||
"ko/enterprise/guides/onedrive-trigger",
|
||||
"ko/enterprise/guides/microsoft-teams-trigger",
|
||||
"ko/enterprise/guides/slack-trigger",
|
||||
"ko/enterprise/guides/hubspot-trigger",
|
||||
"ko/enterprise/guides/salesforce-trigger",
|
||||
|
||||
@@ -57,6 +57,22 @@ Tools & Integrations is the central hub for connecting third‑party apps and ma
|
||||
uv add crewai-tools
|
||||
```
|
||||
|
||||
### Environment Variable Setup
|
||||
|
||||
<Note>
|
||||
To use integrations with `Agent(apps=[])`, you must set the `CREWAI_PLATFORM_INTEGRATION_TOKEN` environment variable with your Enterprise Token.
|
||||
</Note>
|
||||
|
||||
```bash
|
||||
export CREWAI_PLATFORM_INTEGRATION_TOKEN="your_enterprise_token"
|
||||
```
|
||||
|
||||
Or add it to your `.env` file:
|
||||
|
||||
```
|
||||
CREWAI_PLATFORM_INTEGRATION_TOKEN=your_enterprise_token
|
||||
```
|
||||
|
||||
### Usage Example
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -117,27 +117,50 @@ Before wiring a trigger into production, make sure you:
|
||||
- Decide whether to pass trigger context automatically using `allow_crewai_trigger_context`
|
||||
- Set up monitoring—webhook logs, CrewAI execution history, and optional external alerting
|
||||
|
||||
### Payload & Crew Examples Repository
|
||||
### Testing Triggers Locally with CLI
|
||||
|
||||
We maintain a comprehensive repository with end-to-end trigger examples to help you build and test your automations:
|
||||
The CrewAI CLI provides powerful commands to help you develop and test trigger-driven automations without deploying to production.
|
||||
|
||||
This repository contains:
|
||||
#### List Available Triggers
|
||||
|
||||
- **Realistic payload samples** for every supported trigger integration
|
||||
- **Ready-to-run crew implementations** that parse each payload and turn it into a business workflow
|
||||
- **Multiple scenarios per integration** (e.g., new events, updates, deletions) so you can match the shape of your data
|
||||
View all available triggers for your connected integrations:
|
||||
|
||||
| Integration | When it fires | Payload Samples | Crew Examples |
|
||||
| :-- | :-- | :-- | :-- |
|
||||
| Gmail | New messages, thread updates | [New alerts, thread updates](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/gmail) | [`new-email-crew.py`, `gmail-alert-crew.py`](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/gmail) |
|
||||
| Google Calendar | Event created / updated / started / ended / cancelled | [Event lifecycle payloads](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/google_calendar) | [`calendar-event-crew.py`, `calendar-meeting-crew.py`, `calendar-working-location-crew.py`](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/google_calendar) |
|
||||
| Google Drive | File created / updated / deleted | [File lifecycle payloads](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/google_drive) | [`drive-file-crew.py`, `drive-file-deletion-crew.py`](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/google_drive) |
|
||||
| Outlook | New email, calendar event removed | [Outlook payloads](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/outlook) | [`outlook-message-crew.py`, `outlook-event-removal-crew.py`](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/outlook) |
|
||||
| OneDrive | File operations (create, update, share, delete) | [OneDrive payloads](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/onedrive) | [`onedrive-file-crew.py`](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/onedrive) |
|
||||
| HubSpot | Record created / updated (contacts, companies, deals) | [HubSpot payloads](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/hubspot) | [`hubspot-company-crew.py`, `hubspot-contact-crew.py`, `hubspot-record-crew.py`](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/hubspot) |
|
||||
| Microsoft Teams | Chat thread created | [Teams chat payload](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/microsoft-teams) | [`teams-chat-created-crew.py`](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/microsoft-teams) |
|
||||
```bash
|
||||
crewai triggers list
|
||||
```
|
||||
|
||||
This command displays all triggers available based on your connected integrations, showing:
|
||||
- Integration name and connection status
|
||||
- Available trigger types
|
||||
- Trigger names and descriptions
|
||||
|
||||
#### Simulate Trigger Execution
|
||||
|
||||
Test your crew with realistic trigger payloads before deployment:
|
||||
|
||||
```bash
|
||||
crewai triggers run <trigger_name>
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
```bash
|
||||
crewai triggers run microsoft_onedrive/file_changed
|
||||
```
|
||||
|
||||
This command:
|
||||
- Executes your crew locally
|
||||
- Passes a complete, realistic trigger payload
|
||||
- Simulates exactly how your crew will be called in production
|
||||
|
||||
<Warning>
|
||||
**Important Development Notes:**
|
||||
- Use `crewai triggers run <trigger>` to simulate trigger execution during development
|
||||
- Using `crewai run` will NOT simulate trigger calls and won't pass the trigger payload
|
||||
- After deployment, your crew will be executed with the actual trigger payload
|
||||
- If your crew expects parameters that aren't in the trigger payload, execution may fail
|
||||
</Warning>
|
||||
|
||||
Use these samples to understand payload shape, copy the matching crew, and then replace the test payload with your live trigger data.
|
||||
|
||||
### Triggers with Crew
|
||||
|
||||
@@ -241,15 +264,20 @@ def delegate_to_crew(self, crewai_trigger_payload: dict = None):
|
||||
## Troubleshooting
|
||||
|
||||
**Trigger not firing:**
|
||||
- Verify the trigger is enabled
|
||||
- Check integration connection status
|
||||
- Verify the trigger is enabled in your deployment's Triggers tab
|
||||
- Check integration connection status under Tools & Integrations
|
||||
- Ensure all required environment variables are properly configured
|
||||
|
||||
**Execution failures:**
|
||||
- Check the execution logs for error details
|
||||
- If you are developing, make sure the inputs include the `crewai_trigger_payload` parameter with the correct payload
|
||||
- Use `crewai triggers run <trigger_name>` to test locally and see the exact payload structure
|
||||
- Verify your crew can handle the `crewai_trigger_payload` parameter
|
||||
- Ensure your crew doesn't expect parameters that aren't included in the trigger payload
|
||||
|
||||
**Development issues:**
|
||||
- Always test with `crewai triggers run <trigger>` before deploying to see the complete payload
|
||||
- Remember that `crewai run` does NOT simulate trigger calls—use `crewai triggers run` instead
|
||||
- Use `crewai triggers list` to verify which triggers are available for your connected integrations
|
||||
- After deployment, your crew will receive the actual trigger payload, so test thoroughly locally first
|
||||
|
||||
Automation triggers transform your CrewAI deployments into responsive, event-driven systems that can seamlessly integrate with your existing business processes and tools.
|
||||
|
||||
<Card title="CrewAI AMP Trigger Examples" href="https://github.com/crewAIInc/crewai-enterprise-trigger-examples" icon="github">
|
||||
Check them out on GitHub!
|
||||
</Card>
|
||||
|
||||
@@ -51,16 +51,25 @@ class GmailProcessingCrew:
|
||||
)
|
||||
```
|
||||
|
||||
The Gmail payload will be available via the standard context mechanisms. See the payload samples repository for structure and fields.
|
||||
The Gmail payload will be available via the standard context mechanisms.
|
||||
|
||||
### Sample payloads & crews
|
||||
### Testing Locally
|
||||
|
||||
The [CrewAI AMP Trigger Examples repository](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/gmail) includes:
|
||||
Test your Gmail trigger integration locally using the CrewAI CLI:
|
||||
|
||||
- `new-email-payload-1.json` / `new-email-payload-2.json` — production-style new message alerts with matching crews in `new-email-crew.py`
|
||||
- `thread-updated-sample-1.json` — follow-up messages on an existing thread, processed by `gmail-alert-crew.py`
|
||||
```bash
|
||||
# View all available triggers
|
||||
crewai triggers list
|
||||
|
||||
Use these samples to validate your parsing logic locally before wiring the trigger to your live Gmail accounts.
|
||||
# Simulate a Gmail trigger with realistic payload
|
||||
crewai triggers run gmail/new_email
|
||||
```
|
||||
|
||||
The `crewai triggers run` command will execute your crew with a complete Gmail payload, allowing you to test your parsing logic before deployment.
|
||||
|
||||
<Warning>
|
||||
Use `crewai triggers run gmail/new_email` (not `crewai run`) to simulate trigger execution during development. After deployment, your crew will automatically receive the trigger payload.
|
||||
</Warning>
|
||||
|
||||
## Monitoring Executions
|
||||
|
||||
@@ -70,16 +79,10 @@ Track history and performance of triggered runs:
|
||||
<img src="/images/enterprise/list-executions.png" alt="List of executions triggered by automation" />
|
||||
</Frame>
|
||||
|
||||
## Payload Reference
|
||||
|
||||
See the sample payloads and field descriptions:
|
||||
|
||||
<Card title="Gmail samples in Trigger Examples Repo" href="https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/gmail" icon="envelopes-bulk">
|
||||
Gmail samples in Trigger Examples Repo
|
||||
</Card>
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
- Ensure Gmail is connected in Tools & Integrations
|
||||
- Verify the Gmail Trigger is enabled on the Triggers tab
|
||||
- Test locally with `crewai triggers run gmail/new_email` to see the exact payload structure
|
||||
- Check the execution logs and confirm the payload is passed as `crewai_trigger_payload`
|
||||
- Remember: use `crewai triggers run` (not `crewai run`) to simulate trigger execution
|
||||
|
||||
@@ -39,16 +39,23 @@ print(result.raw)
|
||||
|
||||
Use `crewai_trigger_payload` exactly as it is delivered by the trigger so the crew can extract the proper fields.
|
||||
|
||||
## Sample payloads & crews
|
||||
## Testing Locally
|
||||
|
||||
The [Google Calendar examples](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/google_calendar) show how to handle multiple event types:
|
||||
Test your Google Calendar trigger integration locally using the CrewAI CLI:
|
||||
|
||||
- `new-event.json` → standard event creation handled by `calendar-event-crew.py`
|
||||
- `event-updated.json` / `event-started.json` / `event-ended.json` → in-flight updates processed by `calendar-meeting-crew.py`
|
||||
- `event-canceled.json` → cancellation workflow that alerts attendees via `calendar-meeting-crew.py`
|
||||
- Working location events use `calendar-working-location-crew.py` to extract on-site schedules
|
||||
```bash
|
||||
# View all available triggers
|
||||
crewai triggers list
|
||||
|
||||
Each crew transforms raw event metadata (attendees, rooms, working locations) into the summaries your teams need.
|
||||
# Simulate a Google Calendar trigger with realistic payload
|
||||
crewai triggers run google_calendar/event_changed
|
||||
```
|
||||
|
||||
The `crewai triggers run` command will execute your crew with a complete Calendar payload, allowing you to test your parsing logic before deployment.
|
||||
|
||||
<Warning>
|
||||
Use `crewai triggers run google_calendar/event_changed` (not `crewai run`) to simulate trigger execution during development. After deployment, your crew will automatically receive the trigger payload.
|
||||
</Warning>
|
||||
|
||||
## Monitoring Executions
|
||||
|
||||
@@ -61,5 +68,7 @@ The **Executions** list in the deployment dashboard tracks every triggered run a
|
||||
## Troubleshooting
|
||||
|
||||
- Ensure the correct Google account is connected and the trigger is enabled
|
||||
- Test locally with `crewai triggers run google_calendar/event_changed` to see the exact payload structure
|
||||
- Confirm your workflow handles all-day events (payloads use `start.date` and `end.date` instead of timestamps)
|
||||
- Check execution logs if reminders or attendee arrays are missing—calendar permissions can limit fields in the payload
|
||||
- Remember: use `crewai triggers run` (not `crewai run`) to simulate trigger execution
|
||||
|
||||
@@ -36,15 +36,23 @@ crew.kickoff({
|
||||
})
|
||||
```
|
||||
|
||||
## Sample payloads & crews
|
||||
## Testing Locally
|
||||
|
||||
Explore the [Google Drive examples](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/google_drive) to cover different operations:
|
||||
Test your Google Drive trigger integration locally using the CrewAI CLI:
|
||||
|
||||
- `new-file.json` → new uploads processed by `drive-file-crew.py`
|
||||
- `updated-file.json` → file edits and metadata changes handled by `drive-file-crew.py`
|
||||
- `deleted-file.json` → deletion events routed through `drive-file-deletion-crew.py`
|
||||
```bash
|
||||
# View all available triggers
|
||||
crewai triggers list
|
||||
|
||||
Each crew highlights the file name, operation type, owner, permissions, and security considerations so downstream systems can respond appropriately.
|
||||
# Simulate a Google Drive trigger with realistic payload
|
||||
crewai triggers run google_drive/file_changed
|
||||
```
|
||||
|
||||
The `crewai triggers run` command will execute your crew with a complete Drive payload, allowing you to test your parsing logic before deployment.
|
||||
|
||||
<Warning>
|
||||
Use `crewai triggers run google_drive/file_changed` (not `crewai run`) to simulate trigger execution during development. After deployment, your crew will automatically receive the trigger payload.
|
||||
</Warning>
|
||||
|
||||
## Monitoring Executions
|
||||
|
||||
@@ -57,5 +65,7 @@ Track history and performance of triggered runs with the **Executions** list in
|
||||
## Troubleshooting
|
||||
|
||||
- Verify Google Drive is connected and the trigger toggle is enabled
|
||||
- Test locally with `crewai triggers run google_drive/file_changed` to see the exact payload structure
|
||||
- If a payload is missing permission data, ensure the connected account has access to the file or folder
|
||||
- The trigger sends file IDs only; use the Drive API if you need to fetch binary content during the crew run
|
||||
- Remember: use `crewai triggers run` (not `crewai run`) to simulate trigger execution
|
||||
|
||||
@@ -49,16 +49,4 @@ This guide provides a step-by-step process to set up HubSpot triggers for CrewAI
|
||||
</Step>
|
||||
</Steps>
|
||||
|
||||
## Additional Resources
|
||||
|
||||
### Sample payloads & crews
|
||||
|
||||
You can jump-start development with the [HubSpot examples in the trigger repository](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/hubspot):
|
||||
|
||||
- `record-created-contact.json`, `record-updated-contact.json` → contact lifecycle events handled by `hubspot-contact-crew.py`
|
||||
- `record-created-company.json`, `record-updated-company.json` → company enrichment flows in `hubspot-company-crew.py`
|
||||
- `record-created-deals.json`, `record-updated-deals.json` → deal pipeline automation in `hubspot-record-crew.py`
|
||||
|
||||
Each crew demonstrates how to parse HubSpot record fields, enrich context, and return structured insights.
|
||||
|
||||
For more detailed information on available actions and customization options, refer to the [HubSpot Workflows Documentation](https://knowledge.hubspot.com/workflows/create-workflows).
|
||||
|
||||
@@ -37,16 +37,28 @@ print(result.raw)
|
||||
|
||||
The crew parses thread metadata (subject, created time, roster) and generates an action plan for the receiving team.
|
||||
|
||||
## Sample payloads & crews
|
||||
## Testing Locally
|
||||
|
||||
The [Microsoft Teams examples](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/microsoft-teams) include:
|
||||
Test your Microsoft Teams trigger integration locally using the CrewAI CLI:
|
||||
|
||||
- `chat-created.json` → chat creation payload processed by `teams-chat-created-crew.py`
|
||||
```bash
|
||||
# View all available triggers
|
||||
crewai triggers list
|
||||
|
||||
The crew demonstrates how to extract participants, initial messages, tenant information, and compliance metadata from the Microsoft Graph webhook payload.
|
||||
# Simulate a Microsoft Teams trigger with realistic payload
|
||||
crewai triggers run microsoft_teams/teams_message_created
|
||||
```
|
||||
|
||||
The `crewai triggers run` command will execute your crew with a complete Teams payload, allowing you to test your parsing logic before deployment.
|
||||
|
||||
<Warning>
|
||||
Use `crewai triggers run microsoft_teams/teams_message_created` (not `crewai run`) to simulate trigger execution during development. After deployment, your crew will automatically receive the trigger payload.
|
||||
</Warning>
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
- Ensure the Teams connection is active; it must be refreshed if the tenant revokes permissions
|
||||
- Test locally with `crewai triggers run microsoft_teams/teams_message_created` to see the exact payload structure
|
||||
- Confirm the webhook subscription in Microsoft 365 is still valid if payloads stop arriving
|
||||
- Review execution logs for payload shape mismatches—Graph notifications may omit fields when a chat is private or restricted
|
||||
- Remember: use `crewai triggers run` (not `crewai run`) to simulate trigger execution
|
||||
|
||||
@@ -36,18 +36,28 @@ crew.kickoff({
|
||||
|
||||
The crew inspects file metadata, user activity, and permission changes to produce a compliance-friendly summary.
|
||||
|
||||
## Sample payloads & crews
|
||||
## Testing Locally
|
||||
|
||||
The [OneDrive examples](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/onedrive) showcase how to:
|
||||
Test your OneDrive trigger integration locally using the CrewAI CLI:
|
||||
|
||||
- Parse file metadata, size, and folder paths
|
||||
- Track who created and last modified the file
|
||||
- Highlight permission and external sharing changes
|
||||
```bash
|
||||
# View all available triggers
|
||||
crewai triggers list
|
||||
|
||||
`onedrive-file-crew.py` bundles the analysis and summarization tasks so you can add remediation steps as needed.
|
||||
# Simulate a OneDrive trigger with realistic payload
|
||||
crewai triggers run microsoft_onedrive/file_changed
|
||||
```
|
||||
|
||||
The `crewai triggers run` command will execute your crew with a complete OneDrive payload, allowing you to test your parsing logic before deployment.
|
||||
|
||||
<Warning>
|
||||
Use `crewai triggers run microsoft_onedrive/file_changed` (not `crewai run`) to simulate trigger execution during development. After deployment, your crew will automatically receive the trigger payload.
|
||||
</Warning>
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
- Ensure the connected account has permission to read the file metadata included in the webhook
|
||||
- Test locally with `crewai triggers run microsoft_onedrive/file_changed` to see the exact payload structure
|
||||
- If the trigger fires but the payload is missing `permissions`, confirm the site-level sharing settings allow Graph to return this field
|
||||
- For large tenants, filter notifications upstream so the crew only runs on relevant directories
|
||||
- Remember: use `crewai triggers run` (not `crewai run`) to simulate trigger execution
|
||||
|
||||
@@ -36,17 +36,28 @@ crew.kickoff({
|
||||
|
||||
The crew extracts sender details, subject, body preview, and attachments before generating a structured response.
|
||||
|
||||
## Sample payloads & crews
|
||||
## Testing Locally
|
||||
|
||||
Review the [Outlook examples](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/outlook) for two common scenarios:
|
||||
Test your Outlook trigger integration locally using the CrewAI CLI:
|
||||
|
||||
- `new-message.json` → new mail notifications parsed by `outlook-message-crew.py`
|
||||
- `event-removed.json` → calendar cleanup handled by `outlook-event-removal-crew.py`
|
||||
```bash
|
||||
# View all available triggers
|
||||
crewai triggers list
|
||||
|
||||
Each crew demonstrates how to handle Microsoft Graph payloads, normalize headers, and keep humans in-the-loop with concise summaries.
|
||||
# Simulate an Outlook trigger with realistic payload
|
||||
crewai triggers run microsoft_outlook/email_received
|
||||
```
|
||||
|
||||
The `crewai triggers run` command will execute your crew with a complete Outlook payload, allowing you to test your parsing logic before deployment.
|
||||
|
||||
<Warning>
|
||||
Use `crewai triggers run microsoft_outlook/email_received` (not `crewai run`) to simulate trigger execution during development. After deployment, your crew will automatically receive the trigger payload.
|
||||
</Warning>
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
- Verify the Outlook connector is still authorized; the subscription must be renewed periodically
|
||||
- Test locally with `crewai triggers run microsoft_outlook/email_received` to see the exact payload structure
|
||||
- If attachments are missing, confirm the webhook subscription includes the `includeResourceData` flag
|
||||
- Review execution logs when events fail to match—cancellation payloads lack attendee lists by design and the crew should account for that
|
||||
- Remember: use `crewai triggers run` (not `crewai run`) to simulate trigger execution
|
||||
|
||||
@@ -57,6 +57,22 @@ mode: "wide"
|
||||
uv add crewai-tools
|
||||
```
|
||||
|
||||
### 환경 변수 설정
|
||||
|
||||
<Note>
|
||||
`Agent(apps=[])`와 함께 통합을 사용하려면 Enterprise Token으로 `CREWAI_PLATFORM_INTEGRATION_TOKEN` 환경 변수를 설정해야 합니다.
|
||||
</Note>
|
||||
|
||||
```bash
|
||||
export CREWAI_PLATFORM_INTEGRATION_TOKEN="your_enterprise_token"
|
||||
```
|
||||
|
||||
또는 `.env` 파일에 추가하세요:
|
||||
|
||||
```
|
||||
CREWAI_PLATFORM_INTEGRATION_TOKEN=your_enterprise_token
|
||||
```
|
||||
|
||||
### 사용 예시
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -110,19 +110,49 @@ CrewAI AMP 트리거는 팀이 이미 사용하고 있는 도구의 실시간
|
||||
- `allow_crewai_trigger_context` 옵션으로 컨텍스트 자동 주입 여부를 결정했나요?
|
||||
- 웹훅 로그, CrewAI 실행 기록, 외부 알림 등 모니터링을 준비했나요?
|
||||
|
||||
### Payload & Crew 예제 저장소
|
||||
### CLI로 로컬에서 트리거 테스트
|
||||
|
||||
| 통합 | 동작 시점 | Payload 예제 | Crew 예제 |
|
||||
| :-- | :-- | :-- | :-- |
|
||||
| Gmail | 신규 메일, 스레드 업데이트 | [Gmail payload](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/gmail) | [`new-email-crew.py`, `gmail-alert-crew.py`](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/gmail) |
|
||||
| Google Calendar | 이벤트 생성/수정/시작/종료/취소 | [Calendar payload](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/google_calendar) | [`calendar-event-crew.py`, `calendar-meeting-crew.py`, `calendar-working-location-crew.py`](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/google_calendar) |
|
||||
| Google Drive | 파일 생성/수정/삭제 | [Drive payload](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/google_drive) | [`drive-file-crew.py`, `drive-file-deletion-crew.py`](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/google_drive) |
|
||||
| Outlook | 새 이메일, 이벤트 제거 | [Outlook payload](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/outlook) | [`outlook-message-crew.py`, `outlook-event-removal-crew.py`](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/outlook) |
|
||||
| OneDrive | 파일 작업(생성, 수정, 공유, 삭제) | [OneDrive payload](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/onedrive) | [`onedrive-file-crew.py`](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/onedrive) |
|
||||
| HubSpot | 레코드 생성/업데이트(연락처, 회사, 딜) | [HubSpot payload](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/hubspot) | [`hubspot-company-crew.py`, `hubspot-contact-crew.py`, `hubspot-record-crew.py`](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/hubspot) |
|
||||
| Microsoft Teams | 채팅 생성 | [Teams payload](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/microsoft-teams) | [`teams-chat-created-crew.py`](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/microsoft-teams) |
|
||||
CrewAI CLI는 프로덕션에 배포하기 전에 트리거 기반 자동화를 개발하고 테스트할 수 있는 강력한 명령을 제공합니다.
|
||||
|
||||
예제 payload를 참고해 파싱 로직을 검증하고, 제공되는 crew를 복사해 실제 데이터로 교체하세요.
|
||||
#### 사용 가능한 트리거 목록 보기
|
||||
|
||||
연결된 통합에 사용 가능한 모든 트리거를 확인하세요:
|
||||
|
||||
```bash
|
||||
crewai triggers list
|
||||
```
|
||||
|
||||
이 명령은 연결된 통합을 기반으로 사용 가능한 모든 트리거를 표시합니다:
|
||||
- 통합 이름 및 연결 상태
|
||||
- 사용 가능한 트리거 유형
|
||||
- 트리거 이름 및 설명
|
||||
|
||||
#### 트리거 실행 시뮬레이션
|
||||
|
||||
배포 전에 실제 트리거 payload로 크루를 테스트하세요:
|
||||
|
||||
```bash
|
||||
crewai triggers run <트리거_이름>
|
||||
```
|
||||
|
||||
예시:
|
||||
|
||||
```bash
|
||||
crewai triggers run microsoft_onedrive/file_changed
|
||||
```
|
||||
|
||||
이 명령은:
|
||||
- 로컬에서 크루를 실행합니다
|
||||
- 완전하고 실제적인 트리거 payload를 전달합니다
|
||||
- 프로덕션에서 크루가 호출되는 방식을 정확히 시뮬레이션합니다
|
||||
|
||||
<Warning>
|
||||
**중요한 개발 노트:**
|
||||
- 개발 중 트리거 실행을 시뮬레이션하려면 `crewai triggers run <trigger>`를 사용하세요
|
||||
- `crewai run`을 사용하면 트리거 호출을 시뮬레이션하지 않으며 트리거 payload를 전달하지 않습니다
|
||||
- 배포 후에는 실제 트리거 payload로 크루가 실행됩니다
|
||||
- 크루가 트리거 payload에 없는 매개변수를 기대하면 실행이 실패할 수 있습니다
|
||||
</Warning>
|
||||
|
||||
### 트리거와 Crew 연동
|
||||
|
||||
@@ -191,17 +221,20 @@ def delegate_to_crew(self, crewai_trigger_payload: dict = None):
|
||||
## 문제 해결
|
||||
|
||||
**트리거가 실행되지 않나요?**
|
||||
- 트리거가 활성 상태인지 확인하세요.
|
||||
- 통합 연결 상태를 확인하세요.
|
||||
- 배포의 Triggers 탭에서 트리거가 활성화되어 있는지 확인하세요
|
||||
- Tools & Integrations에서 통합 연결 상태를 확인하세요
|
||||
- 필요한 모든 환경 변수가 올바르게 구성되어 있는지 확인하세요
|
||||
|
||||
**실행 중 오류가 발생하나요?**
|
||||
- 실행 로그에서 오류 메시지를 확인하세요.
|
||||
- 개발 중이라면 `crewai_trigger_payload`가 올바른 데이터로 전달되고 있는지 확인하세요.
|
||||
- 실행 로그에서 오류 세부 정보를 확인하세요
|
||||
- `crewai triggers run <트리거_이름>`을 사용하여 로컬에서 테스트하고 정확한 payload 구조를 확인하세요
|
||||
- 크루가 `crewai_trigger_payload` 매개변수를 처리할 수 있는지 확인하세요
|
||||
- 크루가 트리거 payload에 포함되지 않은 매개변수를 기대하지 않는지 확인하세요
|
||||
|
||||
**개발 문제:**
|
||||
- 배포하기 전에 항상 `crewai triggers run <trigger>`로 테스트하여 전체 payload를 확인하세요
|
||||
- `crewai run`은 트리거 호출을 시뮬레이션하지 않으므로 `crewai triggers run`을 대신 사용하세요
|
||||
- `crewai triggers list`를 사용하여 연결된 통합에 사용 가능한 트리거를 확인하세요
|
||||
- 배포 후 크루는 실제 트리거 payload를 받으므로 먼저 로컬에서 철저히 테스트하세요
|
||||
|
||||
트리거를 활용하면 CrewAI 자동화를 이벤트 기반 시스템으로 전환하여 기존 비즈니스 프로세스와 도구에 자연스럽게 녹여낼 수 있습니다.
|
||||
|
||||
<Callout icon="github" title="예제 저장소">
|
||||
<a href="https://github.com/crewAIInc/crewai-enterprise-trigger-examples">
|
||||
CrewAI AMP Trigger Examples
|
||||
</a>
|
||||
</Callout>
|
||||
|
||||
@@ -51,16 +51,25 @@ class GmailProcessingCrew:
|
||||
)
|
||||
```
|
||||
|
||||
The Gmail payload will be available via the standard context mechanisms. See the payload samples repository for structure and fields.
|
||||
The Gmail payload will be available via the standard context mechanisms.
|
||||
|
||||
### Sample payloads & crews
|
||||
### 로컬에서 테스트
|
||||
|
||||
The [CrewAI AMP Trigger Examples repository](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/gmail) includes:
|
||||
CrewAI CLI를 사용하여 Gmail 트리거 통합을 로컬에서 테스트하세요:
|
||||
|
||||
- `new-email-payload-1.json` / `new-email-payload-2.json` — production-style new message alerts with matching crews in `new-email-crew.py`
|
||||
- `thread-updated-sample-1.json` — follow-up messages on an existing thread, processed by `gmail-alert-crew.py`
|
||||
```bash
|
||||
# 사용 가능한 모든 트리거 보기
|
||||
crewai triggers list
|
||||
|
||||
Use these samples to validate your parsing logic locally before wiring the trigger to your live Gmail accounts.
|
||||
# 실제 payload로 Gmail 트리거 시뮬레이션
|
||||
crewai triggers run gmail/new_email
|
||||
```
|
||||
|
||||
`crewai triggers run` 명령은 완전한 Gmail payload로 크루를 실행하여 배포 전에 파싱 로직을 테스트할 수 있게 해줍니다.
|
||||
|
||||
<Warning>
|
||||
개발 중에는 `crewai triggers run gmail/new_email`을 사용하세요 (`crewai run`이 아님). 배포 후에는 크루가 자동으로 트리거 payload를 받습니다.
|
||||
</Warning>
|
||||
|
||||
## Monitoring Executions
|
||||
|
||||
@@ -70,16 +79,10 @@ Track history and performance of triggered runs:
|
||||
<img src="/images/enterprise/list-executions.png" alt="List of executions triggered by automation" />
|
||||
</Frame>
|
||||
|
||||
## Payload Reference
|
||||
|
||||
See the sample payloads and field descriptions:
|
||||
|
||||
<Card title="Gmail samples in Trigger Examples Repo" href="https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/gmail" icon="envelopes-bulk">
|
||||
Gmail samples in Trigger Examples Repo
|
||||
</Card>
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
- Ensure Gmail is connected in Tools & Integrations
|
||||
- Verify the Gmail Trigger is enabled on the Triggers tab
|
||||
- `crewai triggers run gmail/new_email`로 로컬 테스트하여 정확한 payload 구조를 확인하세요
|
||||
- Check the execution logs and confirm the payload is passed as `crewai_trigger_payload`
|
||||
- 주의: 트리거 실행을 시뮬레이션하려면 `crewai triggers run`을 사용하세요 (`crewai run`이 아님)
|
||||
|
||||
@@ -39,16 +39,23 @@ print(result.raw)
|
||||
|
||||
Use `crewai_trigger_payload` exactly as it is delivered by the trigger so the crew can extract the proper fields.
|
||||
|
||||
## Sample payloads & crews
|
||||
## 로컬에서 테스트
|
||||
|
||||
The [Google Calendar examples](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/google_calendar) show how to handle multiple event types:
|
||||
CrewAI CLI를 사용하여 Google Calendar 트리거 통합을 로컬에서 테스트하세요:
|
||||
|
||||
- `new-event.json` → standard event creation handled by `calendar-event-crew.py`
|
||||
- `event-updated.json` / `event-started.json` / `event-ended.json` → in-flight updates processed by `calendar-meeting-crew.py`
|
||||
- `event-canceled.json` → cancellation workflow that alerts attendees via `calendar-meeting-crew.py`
|
||||
- Working location events use `calendar-working-location-crew.py` to extract on-site schedules
|
||||
```bash
|
||||
# 사용 가능한 모든 트리거 보기
|
||||
crewai triggers list
|
||||
|
||||
Each crew transforms raw event metadata (attendees, rooms, working locations) into the summaries your teams need.
|
||||
# 실제 payload로 Google Calendar 트리거 시뮬레이션
|
||||
crewai triggers run google_calendar/event_changed
|
||||
```
|
||||
|
||||
`crewai triggers run` 명령은 완전한 Calendar payload로 크루를 실행하여 배포 전에 파싱 로직을 테스트할 수 있게 해줍니다.
|
||||
|
||||
<Warning>
|
||||
개발 중에는 `crewai triggers run google_calendar/event_changed`를 사용하세요 (`crewai run`이 아님). 배포 후에는 크루가 자동으로 트리거 payload를 받습니다.
|
||||
</Warning>
|
||||
|
||||
## Monitoring Executions
|
||||
|
||||
@@ -61,5 +68,7 @@ The **Executions** list in the deployment dashboard tracks every triggered run a
|
||||
## Troubleshooting
|
||||
|
||||
- Ensure the correct Google account is connected and the trigger is enabled
|
||||
- `crewai triggers run google_calendar/event_changed`로 로컬 테스트하여 정확한 payload 구조를 확인하세요
|
||||
- Confirm your workflow handles all-day events (payloads use `start.date` and `end.date` instead of timestamps)
|
||||
- Check execution logs if reminders or attendee arrays are missing—calendar permissions can limit fields in the payload
|
||||
- 주의: 트리거 실행을 시뮬레이션하려면 `crewai triggers run`을 사용하세요 (`crewai run`이 아님)
|
||||
|
||||
@@ -36,15 +36,23 @@ crew.kickoff({
|
||||
})
|
||||
```
|
||||
|
||||
## Sample payloads & crews
|
||||
## 로컬에서 테스트
|
||||
|
||||
Explore the [Google Drive examples](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/google_drive) to cover different operations:
|
||||
CrewAI CLI를 사용하여 Google Drive 트리거 통합을 로컬에서 테스트하세요:
|
||||
|
||||
- `new-file.json` → new uploads processed by `drive-file-crew.py`
|
||||
- `updated-file.json` → file edits and metadata changes handled by `drive-file-crew.py`
|
||||
- `deleted-file.json` → deletion events routed through `drive-file-deletion-crew.py`
|
||||
```bash
|
||||
# 사용 가능한 모든 트리거 보기
|
||||
crewai triggers list
|
||||
|
||||
Each crew highlights the file name, operation type, owner, permissions, and security considerations so downstream systems can respond appropriately.
|
||||
# 실제 payload로 Google Drive 트리거 시뮬레이션
|
||||
crewai triggers run google_drive/file_changed
|
||||
```
|
||||
|
||||
`crewai triggers run` 명령은 완전한 Drive payload로 크루를 실행하여 배포 전에 파싱 로직을 테스트할 수 있게 해줍니다.
|
||||
|
||||
<Warning>
|
||||
개발 중에는 `crewai triggers run google_drive/file_changed`를 사용하세요 (`crewai run`이 아님). 배포 후에는 크루가 자동으로 트리거 payload를 받습니다.
|
||||
</Warning>
|
||||
|
||||
## Monitoring Executions
|
||||
|
||||
@@ -57,5 +65,7 @@ Track history and performance of triggered runs with the **Executions** list in
|
||||
## Troubleshooting
|
||||
|
||||
- Verify Google Drive is connected and the trigger toggle is enabled
|
||||
- `crewai triggers run google_drive/file_changed`로 로컬 테스트하여 정확한 payload 구조를 확인하세요
|
||||
- If a payload is missing permission data, ensure the connected account has access to the file or folder
|
||||
- The trigger sends file IDs only; use the Drive API if you need to fetch binary content during the crew run
|
||||
- 주의: 트리거 실행을 시뮬레이션하려면 `crewai triggers run`을 사용하세요 (`crewai run`이 아님)
|
||||
|
||||
@@ -49,6 +49,4 @@ mode: "wide"
|
||||
</Step>
|
||||
</Steps>
|
||||
|
||||
## 추가 자료
|
||||
|
||||
사용 가능한 작업과 사용자 지정 옵션에 대한 자세한 정보는 [HubSpot 워크플로우 문서](https://knowledge.hubspot.com/workflows/create-workflows)를 참고하세요.
|
||||
|
||||
@@ -37,16 +37,28 @@ print(result.raw)
|
||||
|
||||
The crew parses thread metadata (subject, created time, roster) and generates an action plan for the receiving team.
|
||||
|
||||
## Sample payloads & crews
|
||||
## 로컬에서 테스트
|
||||
|
||||
The [Microsoft Teams examples](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/microsoft-teams) include:
|
||||
CrewAI CLI를 사용하여 Microsoft Teams 트리거 통합을 로컬에서 테스트하세요:
|
||||
|
||||
- `chat-created.json` → chat creation payload processed by `teams-chat-created-crew.py`
|
||||
```bash
|
||||
# 사용 가능한 모든 트리거 보기
|
||||
crewai triggers list
|
||||
|
||||
The crew demonstrates how to extract participants, initial messages, tenant information, and compliance metadata from the Microsoft Graph webhook payload.
|
||||
# 실제 payload로 Microsoft Teams 트리거 시뮬레이션
|
||||
crewai triggers run microsoft_teams/teams_message_created
|
||||
```
|
||||
|
||||
`crewai triggers run` 명령은 완전한 Teams payload로 크루를 실행하여 배포 전에 파싱 로직을 테스트할 수 있게 해줍니다.
|
||||
|
||||
<Warning>
|
||||
개발 중에는 `crewai triggers run microsoft_teams/teams_message_created`를 사용하세요 (`crewai run`이 아님). 배포 후에는 크루가 자동으로 트리거 payload를 받습니다.
|
||||
</Warning>
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
- Ensure the Teams connection is active; it must be refreshed if the tenant revokes permissions
|
||||
- `crewai triggers run microsoft_teams/teams_message_created`로 로컬 테스트하여 정확한 payload 구조를 확인하세요
|
||||
- Confirm the webhook subscription in Microsoft 365 is still valid if payloads stop arriving
|
||||
- Review execution logs for payload shape mismatches—Graph notifications may omit fields when a chat is private or restricted
|
||||
- 주의: 트리거 실행을 시뮬레이션하려면 `crewai triggers run`을 사용하세요 (`crewai run`이 아님)
|
||||
|
||||
@@ -36,18 +36,28 @@ crew.kickoff({
|
||||
|
||||
The crew inspects file metadata, user activity, and permission changes to produce a compliance-friendly summary.
|
||||
|
||||
## Sample payloads & crews
|
||||
## 로컬에서 테스트
|
||||
|
||||
The [OneDrive examples](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/onedrive) showcase how to:
|
||||
CrewAI CLI를 사용하여 OneDrive 트리거 통합을 로컬에서 테스트하세요:
|
||||
|
||||
- Parse file metadata, size, and folder paths
|
||||
- Track who created and last modified the file
|
||||
- Highlight permission and external sharing changes
|
||||
```bash
|
||||
# 사용 가능한 모든 트리거 보기
|
||||
crewai triggers list
|
||||
|
||||
`onedrive-file-crew.py` bundles the analysis and summarization tasks so you can add remediation steps as needed.
|
||||
# 실제 payload로 OneDrive 트리거 시뮬레이션
|
||||
crewai triggers run microsoft_onedrive/file_changed
|
||||
```
|
||||
|
||||
`crewai triggers run` 명령은 완전한 OneDrive payload로 크루를 실행하여 배포 전에 파싱 로직을 테스트할 수 있게 해줍니다.
|
||||
|
||||
<Warning>
|
||||
개발 중에는 `crewai triggers run microsoft_onedrive/file_changed`를 사용하세요 (`crewai run`이 아님). 배포 후에는 크루가 자동으로 트리거 payload를 받습니다.
|
||||
</Warning>
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
- Ensure the connected account has permission to read the file metadata included in the webhook
|
||||
- `crewai triggers run microsoft_onedrive/file_changed`로 로컬 테스트하여 정확한 payload 구조를 확인하세요
|
||||
- If the trigger fires but the payload is missing `permissions`, confirm the site-level sharing settings allow Graph to return this field
|
||||
- For large tenants, filter notifications upstream so the crew only runs on relevant directories
|
||||
- 주의: 트리거 실행을 시뮬레이션하려면 `crewai triggers run`을 사용하세요 (`crewai run`이 아님)
|
||||
|
||||
@@ -36,17 +36,28 @@ crew.kickoff({
|
||||
|
||||
The crew extracts sender details, subject, body preview, and attachments before generating a structured response.
|
||||
|
||||
## Sample payloads & crews
|
||||
## 로컬에서 테스트
|
||||
|
||||
Review the [Outlook examples](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/outlook) for two common scenarios:
|
||||
CrewAI CLI를 사용하여 Outlook 트리거 통합을 로컬에서 테스트하세요:
|
||||
|
||||
- `new-message.json` → new mail notifications parsed by `outlook-message-crew.py`
|
||||
- `event-removed.json` → calendar cleanup handled by `outlook-event-removal-crew.py`
|
||||
```bash
|
||||
# 사용 가능한 모든 트리거 보기
|
||||
crewai triggers list
|
||||
|
||||
Each crew demonstrates how to handle Microsoft Graph payloads, normalize headers, and keep humans in-the-loop with concise summaries.
|
||||
# 실제 payload로 Outlook 트리거 시뮬레이션
|
||||
crewai triggers run microsoft_outlook/email_received
|
||||
```
|
||||
|
||||
`crewai triggers run` 명령은 완전한 Outlook payload로 크루를 실행하여 배포 전에 파싱 로직을 테스트할 수 있게 해줍니다.
|
||||
|
||||
<Warning>
|
||||
개발 중에는 `crewai triggers run microsoft_outlook/email_received`를 사용하세요 (`crewai run`이 아님). 배포 후에는 크루가 자동으로 트리거 payload를 받습니다.
|
||||
</Warning>
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
- Verify the Outlook connector is still authorized; the subscription must be renewed periodically
|
||||
- `crewai triggers run microsoft_outlook/email_received`로 로컬 테스트하여 정확한 payload 구조를 확인하세요
|
||||
- If attachments are missing, confirm the webhook subscription includes the `includeResourceData` flag
|
||||
- Review execution logs when events fail to match—cancellation payloads lack attendee lists by design and the crew should account for that
|
||||
- 주의: 트리거 실행을 시뮬레이션하려면 `crewai triggers run`을 사용하세요 (`crewai run`이 아님)
|
||||
|
||||
@@ -57,6 +57,22 @@ Ferramentas & Integrações é o hub central para conectar aplicações de terce
|
||||
uv add crewai-tools
|
||||
```
|
||||
|
||||
### Configuração de variável de ambiente
|
||||
|
||||
<Note>
|
||||
Para usar integrações com `Agent(apps=[])`, você deve definir a variável de ambiente `CREWAI_PLATFORM_INTEGRATION_TOKEN` com seu Enterprise Token.
|
||||
</Note>
|
||||
|
||||
```bash
|
||||
export CREWAI_PLATFORM_INTEGRATION_TOKEN="seu_enterprise_token"
|
||||
```
|
||||
|
||||
Ou adicione ao seu arquivo `.env`:
|
||||
|
||||
```
|
||||
CREWAI_PLATFORM_INTEGRATION_TOKEN=seu_enterprise_token
|
||||
```
|
||||
|
||||
### Exemplo de uso
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -116,19 +116,49 @@ Antes de ativar em produção, confirme que você:
|
||||
- Decidiu se usará `allow_crewai_trigger_context` para injetar contexto automaticamente
|
||||
- Configurou monitoramento (webhooks, históricos da CrewAI, alertas externos)
|
||||
|
||||
### Repositório de Payloads e Crews de Exemplo
|
||||
### Testando Triggers Localmente com CLI
|
||||
|
||||
| Integração | Quando dispara | Amostras de payload | Crews de exemplo |
|
||||
| :-- | :-- | :-- | :-- |
|
||||
| Gmail | Novas mensagens, atualização de threads | [Payloads de alertas e threads](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/gmail) | [`new-email-crew.py`, `gmail-alert-crew.py`](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/gmail) |
|
||||
| Google Calendar | Evento criado/atualizado/iniciado/encerrado/cancelado | [Payloads de eventos](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/google_calendar) | [`calendar-event-crew.py`, `calendar-meeting-crew.py`, `calendar-working-location-crew.py`](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/google_calendar) |
|
||||
| Google Drive | Arquivo criado/atualizado/excluído | [Payloads de arquivos](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/google_drive) | [`drive-file-crew.py`, `drive-file-deletion-crew.py`](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/google_drive) |
|
||||
| Outlook | Novo e‑mail, evento removido | [Payloads do Outlook](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/outlook) | [`outlook-message-crew.py`, `outlook-event-removal-crew.py`](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/outlook) |
|
||||
| OneDrive | Operações de arquivo (criar, atualizar, compartilhar, excluir) | [Payloads do OneDrive](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/onedrive) | [`onedrive-file-crew.py`](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/onedrive) |
|
||||
| HubSpot | Registros criados/atualizados (contatos, empresas, negócios) | [Payloads do HubSpot](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/hubspot) | [`hubspot-company-crew.py`, `hubspot-contact-crew.py`, `hubspot-record-crew.py`](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/hubspot) |
|
||||
| Microsoft Teams | Chat criado | [Payload do Teams](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/microsoft-teams) | [`teams-chat-created-crew.py`](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/microsoft-teams) |
|
||||
A CLI da CrewAI fornece comandos poderosos para ajudá-lo a desenvolver e testar automações orientadas por triggers sem fazer deploy para produção.
|
||||
|
||||
Use essas amostras para ajustar o parsing, copiar a crew correspondente e substituir o payload de teste pelo dado real.
|
||||
#### Listar Triggers Disponíveis
|
||||
|
||||
Visualize todos os triggers disponíveis para suas integrações conectadas:
|
||||
|
||||
```bash
|
||||
crewai triggers list
|
||||
```
|
||||
|
||||
Este comando exibe todos os triggers disponíveis baseados nas suas integrações conectadas, mostrando:
|
||||
- Nome da integração e status de conexão
|
||||
- Tipos de triggers disponíveis
|
||||
- Nomes e descrições dos triggers
|
||||
|
||||
#### Simular Execução de Trigger
|
||||
|
||||
Teste sua crew com payloads realistas de triggers antes do deployment:
|
||||
|
||||
```bash
|
||||
crewai triggers run <nome_do_trigger>
|
||||
```
|
||||
|
||||
Por exemplo:
|
||||
|
||||
```bash
|
||||
crewai triggers run microsoft_onedrive/file_changed
|
||||
```
|
||||
|
||||
Este comando:
|
||||
- Executa sua crew localmente
|
||||
- Passa um payload de trigger completo e realista
|
||||
- Simula exatamente como sua crew será chamada em produção
|
||||
|
||||
<Warning>
|
||||
**Notas Importantes de Desenvolvimento:**
|
||||
- Use `crewai triggers run <trigger>` para simular execução de trigger durante o desenvolvimento
|
||||
- Usar `crewai run` NÃO simulará chamadas de trigger e não passará o payload do trigger
|
||||
- Após o deployment, sua crew será executada com o payload real do trigger
|
||||
- Se sua crew espera parâmetros que não estão no payload do trigger, a execução pode falhar
|
||||
</Warning>
|
||||
|
||||
### Triggers com Crews
|
||||
|
||||
@@ -203,17 +233,20 @@ def delegar_para_crew(self, crewai_trigger_payload: dict = None):
|
||||
## Solução de Problemas
|
||||
|
||||
**Trigger não dispara:**
|
||||
- Verifique se está habilitado
|
||||
- Confira o status da conexão
|
||||
- Verifique se o trigger está habilitado na aba Triggers do seu deployment
|
||||
- Confira o status da conexão em Tools & Integrations
|
||||
- Garanta que todas as variáveis de ambiente necessárias estão configuradas
|
||||
|
||||
**Falhas de execução:**
|
||||
- Consulte os logs para entender o erro
|
||||
- Durante o desenvolvimento, garanta que `crewai_trigger_payload` está presente com o payload correto
|
||||
- Consulte os logs de execução para detalhes do erro
|
||||
- Use `crewai triggers run <nome_do_trigger>` para testar localmente e ver a estrutura exata do payload
|
||||
- Verifique se sua crew pode processar o parâmetro `crewai_trigger_payload`
|
||||
- Garanta que sua crew não espera parâmetros que não estão incluídos no payload do trigger
|
||||
|
||||
**Problemas de desenvolvimento:**
|
||||
- Sempre teste com `crewai triggers run <trigger>` antes de fazer deploy para ver o payload completo
|
||||
- Lembre-se que `crewai run` NÃO simula chamadas de trigger—use `crewai triggers run` em vez disso
|
||||
- Use `crewai triggers list` para verificar quais triggers estão disponíveis para suas integrações conectadas
|
||||
- Após o deployment, sua crew receberá o payload real do trigger, então teste minuciosamente localmente primeiro
|
||||
|
||||
Os triggers transformam suas implantações CrewAI em sistemas orientados por eventos, integrando-se perfeitamente aos processos e ferramentas já usados pelo seu time.
|
||||
|
||||
<Callout icon="github" title="Exemplos na prática">
|
||||
<a href="https://github.com/crewAIInc/crewai-enterprise-trigger-examples">
|
||||
Repositório CrewAI AMP Trigger Examples
|
||||
</a>
|
||||
</Callout>
|
||||
|
||||
@@ -51,16 +51,25 @@ class GmailProcessingCrew:
|
||||
)
|
||||
```
|
||||
|
||||
The Gmail payload will be available via the standard context mechanisms. See the payload samples repository for structure and fields.
|
||||
The Gmail payload will be available via the standard context mechanisms.
|
||||
|
||||
### Sample payloads & crews
|
||||
### Testando Localmente
|
||||
|
||||
The [CrewAI AMP Trigger Examples repository](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/gmail) includes:
|
||||
Teste sua integração de trigger do Gmail localmente usando a CLI da CrewAI:
|
||||
|
||||
- `new-email-payload-1.json` / `new-email-payload-2.json` — production-style new message alerts with matching crews in `new-email-crew.py`
|
||||
- `thread-updated-sample-1.json` — follow-up messages on an existing thread, processed by `gmail-alert-crew.py`
|
||||
```bash
|
||||
# Visualize todos os triggers disponíveis
|
||||
crewai triggers list
|
||||
|
||||
Use these samples to validate your parsing logic locally before wiring the trigger to your live Gmail accounts.
|
||||
# Simule um trigger do Gmail com payload realista
|
||||
crewai triggers run gmail/new_email
|
||||
```
|
||||
|
||||
O comando `crewai triggers run` executará sua crew com um payload completo do Gmail, permitindo que você teste sua lógica de parsing antes do deployment.
|
||||
|
||||
<Warning>
|
||||
Use `crewai triggers run gmail/new_email` (não `crewai run`) para simular execução de trigger durante o desenvolvimento. Após o deployment, sua crew receberá automaticamente o payload do trigger.
|
||||
</Warning>
|
||||
|
||||
## Monitoring Executions
|
||||
|
||||
@@ -70,16 +79,10 @@ Track history and performance of triggered runs:
|
||||
<img src="/images/enterprise/list-executions.png" alt="List of executions triggered by automation" />
|
||||
</Frame>
|
||||
|
||||
## Payload Reference
|
||||
|
||||
See the sample payloads and field descriptions:
|
||||
|
||||
<Card title="Gmail samples in Trigger Examples Repo" href="https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/gmail" icon="envelopes-bulk">
|
||||
Gmail samples in Trigger Examples Repo
|
||||
</Card>
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
- Ensure Gmail is connected in Tools & Integrations
|
||||
- Verify the Gmail Trigger is enabled on the Triggers tab
|
||||
- Teste localmente com `crewai triggers run gmail/new_email` para ver a estrutura exata do payload
|
||||
- Check the execution logs and confirm the payload is passed as `crewai_trigger_payload`
|
||||
- Lembre-se: use `crewai triggers run` (não `crewai run`) para simular execução de trigger
|
||||
|
||||
@@ -39,16 +39,23 @@ print(result.raw)
|
||||
|
||||
Use `crewai_trigger_payload` exactly as it is delivered by the trigger so the crew can extract the proper fields.
|
||||
|
||||
## Sample payloads & crews
|
||||
## Testando Localmente
|
||||
|
||||
The [Google Calendar examples](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/google_calendar) show how to handle multiple event types:
|
||||
Teste sua integração de trigger do Google Calendar localmente usando a CLI da CrewAI:
|
||||
|
||||
- `new-event.json` → standard event creation handled by `calendar-event-crew.py`
|
||||
- `event-updated.json` / `event-started.json` / `event-ended.json` → in-flight updates processed by `calendar-meeting-crew.py`
|
||||
- `event-canceled.json` → cancellation workflow that alerts attendees via `calendar-meeting-crew.py`
|
||||
- Working location events use `calendar-working-location-crew.py` to extract on-site schedules
|
||||
```bash
|
||||
# Visualize todos os triggers disponíveis
|
||||
crewai triggers list
|
||||
|
||||
Each crew transforms raw event metadata (attendees, rooms, working locations) into the summaries your teams need.
|
||||
# Simule um trigger do Google Calendar com payload realista
|
||||
crewai triggers run google_calendar/event_changed
|
||||
```
|
||||
|
||||
O comando `crewai triggers run` executará sua crew com um payload completo do Calendar, permitindo que você teste sua lógica de parsing antes do deployment.
|
||||
|
||||
<Warning>
|
||||
Use `crewai triggers run google_calendar/event_changed` (não `crewai run`) para simular execução de trigger durante o desenvolvimento. Após o deployment, sua crew receberá automaticamente o payload do trigger.
|
||||
</Warning>
|
||||
|
||||
## Monitoring Executions
|
||||
|
||||
@@ -61,5 +68,7 @@ The **Executions** list in the deployment dashboard tracks every triggered run a
|
||||
## Troubleshooting
|
||||
|
||||
- Ensure the correct Google account is connected and the trigger is enabled
|
||||
- Teste localmente com `crewai triggers run google_calendar/event_changed` para ver a estrutura exata do payload
|
||||
- Confirm your workflow handles all-day events (payloads use `start.date` and `end.date` instead of timestamps)
|
||||
- Check execution logs if reminders or attendee arrays are missing—calendar permissions can limit fields in the payload
|
||||
- Lembre-se: use `crewai triggers run` (não `crewai run`) para simular execução de trigger
|
||||
|
||||
@@ -36,15 +36,23 @@ crew.kickoff({
|
||||
})
|
||||
```
|
||||
|
||||
## Sample payloads & crews
|
||||
## Testando Localmente
|
||||
|
||||
Explore the [Google Drive examples](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/google_drive) to cover different operations:
|
||||
Teste sua integração de trigger do Google Drive localmente usando a CLI da CrewAI:
|
||||
|
||||
- `new-file.json` → new uploads processed by `drive-file-crew.py`
|
||||
- `updated-file.json` → file edits and metadata changes handled by `drive-file-crew.py`
|
||||
- `deleted-file.json` → deletion events routed through `drive-file-deletion-crew.py`
|
||||
```bash
|
||||
# Visualize todos os triggers disponíveis
|
||||
crewai triggers list
|
||||
|
||||
Each crew highlights the file name, operation type, owner, permissions, and security considerations so downstream systems can respond appropriately.
|
||||
# Simule um trigger do Google Drive com payload realista
|
||||
crewai triggers run google_drive/file_changed
|
||||
```
|
||||
|
||||
O comando `crewai triggers run` executará sua crew com um payload completo do Drive, permitindo que você teste sua lógica de parsing antes do deployment.
|
||||
|
||||
<Warning>
|
||||
Use `crewai triggers run google_drive/file_changed` (não `crewai run`) para simular execução de trigger durante o desenvolvimento. Após o deployment, sua crew receberá automaticamente o payload do trigger.
|
||||
</Warning>
|
||||
|
||||
## Monitoring Executions
|
||||
|
||||
@@ -57,5 +65,7 @@ Track history and performance of triggered runs with the **Executions** list in
|
||||
## Troubleshooting
|
||||
|
||||
- Verify Google Drive is connected and the trigger toggle is enabled
|
||||
- Teste localmente com `crewai triggers run google_drive/file_changed` para ver a estrutura exata do payload
|
||||
- If a payload is missing permission data, ensure the connected account has access to the file or folder
|
||||
- The trigger sends file IDs only; use the Drive API if you need to fetch binary content during the crew run
|
||||
- Lembre-se: use `crewai triggers run` (não `crewai run`) para simular execução de trigger
|
||||
|
||||
@@ -37,16 +37,28 @@ print(result.raw)
|
||||
|
||||
The crew parses thread metadata (subject, created time, roster) and generates an action plan for the receiving team.
|
||||
|
||||
## Sample payloads & crews
|
||||
## Testando Localmente
|
||||
|
||||
The [Microsoft Teams examples](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/microsoft-teams) include:
|
||||
Teste sua integração de trigger do Microsoft Teams localmente usando a CLI da CrewAI:
|
||||
|
||||
- `chat-created.json` → chat creation payload processed by `teams-chat-created-crew.py`
|
||||
```bash
|
||||
# Visualize todos os triggers disponíveis
|
||||
crewai triggers list
|
||||
|
||||
The crew demonstrates how to extract participants, initial messages, tenant information, and compliance metadata from the Microsoft Graph webhook payload.
|
||||
# Simule um trigger do Microsoft Teams com payload realista
|
||||
crewai triggers run microsoft_teams/teams_message_created
|
||||
```
|
||||
|
||||
O comando `crewai triggers run` executará sua crew com um payload completo do Teams, permitindo que você teste sua lógica de parsing antes do deployment.
|
||||
|
||||
<Warning>
|
||||
Use `crewai triggers run microsoft_teams/teams_message_created` (não `crewai run`) para simular execução de trigger durante o desenvolvimento. Após o deployment, sua crew receberá automaticamente o payload do trigger.
|
||||
</Warning>
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
- Ensure the Teams connection is active; it must be refreshed if the tenant revokes permissions
|
||||
- Teste localmente com `crewai triggers run microsoft_teams/teams_message_created` para ver a estrutura exata do payload
|
||||
- Confirm the webhook subscription in Microsoft 365 is still valid if payloads stop arriving
|
||||
- Review execution logs for payload shape mismatches—Graph notifications may omit fields when a chat is private or restricted
|
||||
- Lembre-se: use `crewai triggers run` (não `crewai run`) para simular execução de trigger
|
||||
|
||||
@@ -36,18 +36,28 @@ crew.kickoff({
|
||||
|
||||
The crew inspects file metadata, user activity, and permission changes to produce a compliance-friendly summary.
|
||||
|
||||
## Sample payloads & crews
|
||||
## Testando Localmente
|
||||
|
||||
The [OneDrive examples](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/onedrive) showcase how to:
|
||||
Teste sua integração de trigger do OneDrive localmente usando a CLI da CrewAI:
|
||||
|
||||
- Parse file metadata, size, and folder paths
|
||||
- Track who created and last modified the file
|
||||
- Highlight permission and external sharing changes
|
||||
```bash
|
||||
# Visualize todos os triggers disponíveis
|
||||
crewai triggers list
|
||||
|
||||
`onedrive-file-crew.py` bundles the analysis and summarization tasks so you can add remediation steps as needed.
|
||||
# Simule um trigger do OneDrive com payload realista
|
||||
crewai triggers run microsoft_onedrive/file_changed
|
||||
```
|
||||
|
||||
O comando `crewai triggers run` executará sua crew com um payload completo do OneDrive, permitindo que você teste sua lógica de parsing antes do deployment.
|
||||
|
||||
<Warning>
|
||||
Use `crewai triggers run microsoft_onedrive/file_changed` (não `crewai run`) para simular execução de trigger durante o desenvolvimento. Após o deployment, sua crew receberá automaticamente o payload do trigger.
|
||||
</Warning>
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
- Ensure the connected account has permission to read the file metadata included in the webhook
|
||||
- Teste localmente com `crewai triggers run microsoft_onedrive/file_changed` para ver a estrutura exata do payload
|
||||
- If the trigger fires but the payload is missing `permissions`, confirm the site-level sharing settings allow Graph to return this field
|
||||
- For large tenants, filter notifications upstream so the crew only runs on relevant directories
|
||||
- Lembre-se: use `crewai triggers run` (não `crewai run`) para simular execução de trigger
|
||||
|
||||
@@ -36,17 +36,28 @@ crew.kickoff({
|
||||
|
||||
The crew extracts sender details, subject, body preview, and attachments before generating a structured response.
|
||||
|
||||
## Sample payloads & crews
|
||||
## Testando Localmente
|
||||
|
||||
Review the [Outlook examples](https://github.com/crewAIInc/crewai-enterprise-trigger-examples/tree/main/outlook) for two common scenarios:
|
||||
Teste sua integração de trigger do Outlook localmente usando a CLI da CrewAI:
|
||||
|
||||
- `new-message.json` → new mail notifications parsed by `outlook-message-crew.py`
|
||||
- `event-removed.json` → calendar cleanup handled by `outlook-event-removal-crew.py`
|
||||
```bash
|
||||
# Visualize todos os triggers disponíveis
|
||||
crewai triggers list
|
||||
|
||||
Each crew demonstrates how to handle Microsoft Graph payloads, normalize headers, and keep humans in-the-loop with concise summaries.
|
||||
# Simule um trigger do Outlook com payload realista
|
||||
crewai triggers run microsoft_outlook/email_received
|
||||
```
|
||||
|
||||
O comando `crewai triggers run` executará sua crew com um payload completo do Outlook, permitindo que você teste sua lógica de parsing antes do deployment.
|
||||
|
||||
<Warning>
|
||||
Use `crewai triggers run microsoft_outlook/email_received` (não `crewai run`) para simular execução de trigger durante o desenvolvimento. Após o deployment, sua crew receberá automaticamente o payload do trigger.
|
||||
</Warning>
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
- Verify the Outlook connector is still authorized; the subscription must be renewed periodically
|
||||
- Teste localmente com `crewai triggers run microsoft_outlook/email_received` para ver a estrutura exata do payload
|
||||
- If attachments are missing, confirm the webhook subscription includes the `includeResourceData` flag
|
||||
- Review execution logs when events fail to match—cancellation payloads lack attendee lists by design and the crew should account for that
|
||||
- Lembre-se: use `crewai triggers run` (não `crewai run`) para simular execução de trigger
|
||||
|
||||
@@ -12,10 +12,9 @@ dependencies = [
|
||||
"pytube>=15.0.0",
|
||||
"requests>=2.32.5",
|
||||
"docker>=7.1.0",
|
||||
"crewai==1.0.0a4",
|
||||
"crewai==1.0.0b2",
|
||||
"lancedb>=0.5.4",
|
||||
"tiktoken>=0.8.0",
|
||||
"stagehand>=0.4.1",
|
||||
"beautifulsoup4>=4.13.4",
|
||||
"pypdf>=5.9.0",
|
||||
"python-docx>=1.2.0",
|
||||
|
||||
@@ -291,4 +291,4 @@ __all__ = [
|
||||
"ZapierActionTools",
|
||||
]
|
||||
|
||||
__version__ = "1.0.0a4"
|
||||
__version__ = "1.0.0b2"
|
||||
|
||||
@@ -3,13 +3,16 @@
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeAlias, TypedDict
|
||||
import uuid
|
||||
|
||||
from crewai.rag.config.types import RagConfigType
|
||||
from crewai.rag.config.utils import get_rag_client
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.factory import create_client
|
||||
from crewai.rag.qdrant.config import QdrantConfig
|
||||
from crewai.rag.types import BaseRecord, SearchResult
|
||||
from pydantic import PrivateAttr
|
||||
from qdrant_client.models import VectorParams
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
@@ -52,7 +55,11 @@ class CrewAIRagAdapter(Adapter):
|
||||
self._client = create_client(self.config)
|
||||
else:
|
||||
self._client = get_rag_client()
|
||||
self._client.get_or_create_collection(collection_name=self.collection_name)
|
||||
collection_params: dict[str, Any] = {"collection_name": self.collection_name}
|
||||
if isinstance(self.config, QdrantConfig) and self.config.vectors_config:
|
||||
if isinstance(self.config.vectors_config, VectorParams):
|
||||
collection_params["vectors_config"] = self.config.vectors_config
|
||||
self._client.get_or_create_collection(**collection_params)
|
||||
|
||||
def query(
|
||||
self,
|
||||
@@ -76,6 +83,8 @@ class CrewAIRagAdapter(Adapter):
|
||||
if similarity_threshold is not None
|
||||
else self.similarity_threshold
|
||||
)
|
||||
if self._client is None:
|
||||
raise ValueError("Client is not initialized")
|
||||
|
||||
results: list[SearchResult] = self._client.search(
|
||||
collection_name=self.collection_name,
|
||||
@@ -201,9 +210,10 @@ class CrewAIRagAdapter(Adapter):
|
||||
if isinstance(arg, dict):
|
||||
file_metadata.update(arg.get("metadata", {}))
|
||||
|
||||
chunk_id = hashlib.sha256(
|
||||
chunk_hash = hashlib.sha256(
|
||||
f"{file_result.doc_id}_{chunk_idx}_{file_chunk}".encode()
|
||||
).hexdigest()
|
||||
chunk_id = str(uuid.UUID(chunk_hash[:32]))
|
||||
|
||||
documents.append(
|
||||
{
|
||||
@@ -251,9 +261,10 @@ class CrewAIRagAdapter(Adapter):
|
||||
if isinstance(arg, dict):
|
||||
chunk_metadata.update(arg.get("metadata", {}))
|
||||
|
||||
chunk_id = hashlib.sha256(
|
||||
chunk_hash = hashlib.sha256(
|
||||
f"{loader_result.doc_id}_{i}_{chunk}".encode()
|
||||
).hexdigest()
|
||||
chunk_id = str(uuid.UUID(chunk_hash[:32]))
|
||||
|
||||
documents.append(
|
||||
{
|
||||
@@ -264,6 +275,8 @@ class CrewAIRagAdapter(Adapter):
|
||||
)
|
||||
|
||||
if documents:
|
||||
if self._client is None:
|
||||
raise ValueError("Client is not initialized")
|
||||
self._client.add_documents(
|
||||
collection_name=self.collection_name, documents=documents
|
||||
)
|
||||
|
||||
@@ -4,12 +4,12 @@ from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import chromadb
|
||||
import litellm
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
|
||||
from crewai_tools.rag.base_loader import BaseLoader
|
||||
from crewai_tools.rag.chunkers.base_chunker import BaseChunker
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
from crewai_tools.rag.embedding_service import EmbeddingService
|
||||
from crewai_tools.rag.misc import compute_sha256
|
||||
from crewai_tools.rag.source_content import SourceContent
|
||||
from crewai_tools.tools.rag.rag_tool import Adapter
|
||||
@@ -18,31 +18,6 @@ from crewai_tools.tools.rag.rag_tool import Adapter
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmbeddingService:
|
||||
def __init__(self, model: str = "text-embedding-3-small", **kwargs):
|
||||
self.model = model
|
||||
self.kwargs = kwargs
|
||||
|
||||
def embed_text(self, text: str) -> list[float]:
|
||||
try:
|
||||
response = litellm.embedding(model=self.model, input=[text], **self.kwargs)
|
||||
return response.data[0]["embedding"]
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating embedding: {e}")
|
||||
raise
|
||||
|
||||
def embed_batch(self, texts: list[str]) -> list[list[float]]:
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
try:
|
||||
response = litellm.embedding(model=self.model, input=texts, **self.kwargs)
|
||||
return [data["embedding"] for data in response.data]
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating batch embeddings: {e}")
|
||||
raise
|
||||
|
||||
|
||||
class Document(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
content: str
|
||||
@@ -54,6 +29,7 @@ class Document(BaseModel):
|
||||
class RAG(Adapter):
|
||||
collection_name: str = "crewai_knowledge_base"
|
||||
persist_directory: str | None = None
|
||||
embedding_provider: str = "openai"
|
||||
embedding_model: str = "text-embedding-3-large"
|
||||
summarize: bool = False
|
||||
top_k: int = 5
|
||||
@@ -79,7 +55,9 @@ class RAG(Adapter):
|
||||
)
|
||||
|
||||
self._embedding_service = EmbeddingService(
|
||||
model=self.embedding_model, **self.embedding_config
|
||||
provider=self.embedding_provider,
|
||||
model=self.embedding_model,
|
||||
**self.embedding_config,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize ChromaDB: {e}")
|
||||
@@ -181,7 +159,7 @@ class RAG(Adapter):
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add documents to ChromaDB: {e}")
|
||||
|
||||
def query(self, question: str, where: dict[str, Any] | None = None) -> str:
|
||||
def query(self, question: str, where: dict[str, Any] | None = None) -> str: # type: ignore
|
||||
try:
|
||||
question_embedding = self._embedding_service.embed_text(question)
|
||||
|
||||
|
||||
508
lib/crewai-tools/src/crewai_tools/rag/embedding_service.py
Normal file
508
lib/crewai-tools/src/crewai_tools/rag/embedding_service.py
Normal file
@@ -0,0 +1,508 @@
|
||||
"""
|
||||
Enhanced embedding service that leverages CrewAI's existing embedding providers.
|
||||
This replaces the litellm-based EmbeddingService with a more flexible architecture.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmbeddingConfig(BaseModel):
|
||||
"""Configuration for embedding providers."""
|
||||
|
||||
provider: str = Field(description="Embedding provider name")
|
||||
model: str = Field(description="Model name to use")
|
||||
api_key: str | None = Field(default=None, description="API key for the provider")
|
||||
timeout: float | None = Field(
|
||||
default=30.0, description="Request timeout in seconds"
|
||||
)
|
||||
max_retries: int = Field(default=3, description="Maximum number of retries")
|
||||
batch_size: int = Field(
|
||||
default=100, description="Batch size for processing multiple texts"
|
||||
)
|
||||
extra_config: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Additional provider-specific configuration"
|
||||
)
|
||||
|
||||
|
||||
class EmbeddingService:
|
||||
"""
|
||||
Enhanced embedding service that uses CrewAI's existing embedding providers.
|
||||
|
||||
Supports multiple providers:
|
||||
- openai: OpenAI embeddings (text-embedding-3-small, text-embedding-3-large, etc.)
|
||||
- voyageai: Voyage AI embeddings (voyage-2, voyage-large-2, etc.)
|
||||
- cohere: Cohere embeddings (embed-english-v3.0, embed-multilingual-v3.0, etc.)
|
||||
- google-generativeai: Google Gemini embeddings (models/embedding-001, etc.)
|
||||
- google-vertex: Google Vertex embeddings (models/embedding-001, etc.)
|
||||
- huggingface: Hugging Face embeddings (sentence-transformers/all-MiniLM-L6-v2, etc.)
|
||||
- jina: Jina embeddings (jina-embeddings-v2-base-en, etc.)
|
||||
- ollama: Ollama embeddings (nomic-embed-text, etc.)
|
||||
- openai: OpenAI embeddings (text-embedding-3-small, text-embedding-3-large, etc.)
|
||||
- roboflow: Roboflow embeddings (roboflow-embeddings-v2-base-en, etc.)
|
||||
- voyageai: Voyage AI embeddings (voyage-2, voyage-large-2, etc.)
|
||||
- watsonx: Watson X embeddings (ibm/slate-125m-english-rtrvr, etc.)
|
||||
- custom: Custom embeddings (embedding_callable, etc.)
|
||||
- sentence-transformer: Sentence Transformers embeddings (all-MiniLM-L6-v2, etc.)
|
||||
- text2vec: Text2Vec embeddings (text2vec-base-en, etc.)
|
||||
- openclip: OpenClip embeddings (openclip-large-v2, etc.)
|
||||
- instructor: Instructor embeddings (hkunlp/instructor-large, etc.)
|
||||
- onnx: ONNX embeddings (onnx-large-v2, etc.)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider: str = "openai",
|
||||
model: str = "text-embedding-3-small",
|
||||
api_key: str | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""
|
||||
Initialize the embedding service.
|
||||
|
||||
Args:
|
||||
provider: The embedding provider to use
|
||||
model: The model name
|
||||
api_key: API key (if not provided, will look for environment variables)
|
||||
**kwargs: Additional configuration options
|
||||
"""
|
||||
self.config = EmbeddingConfig(
|
||||
provider=provider,
|
||||
model=model,
|
||||
api_key=api_key or self._get_default_api_key(provider),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._embedding_function = None
|
||||
self._initialize_embedding_function()
|
||||
|
||||
def _get_default_api_key(self, provider: str) -> str | None:
|
||||
"""Get default API key from environment variables."""
|
||||
env_key_map = {
|
||||
"azure": "AZURE_OPENAI_API_KEY",
|
||||
"amazon-bedrock": "AWS_ACCESS_KEY_ID", # or AWS_PROFILE
|
||||
"cohere": "COHERE_API_KEY",
|
||||
"google-generativeai": "GOOGLE_API_KEY",
|
||||
"google-vertex": "GOOGLE_APPLICATION_CREDENTIALS",
|
||||
"huggingface": "HUGGINGFACE_API_KEY",
|
||||
"jina": "JINA_API_KEY",
|
||||
"ollama": None, # Ollama typically runs locally without API key
|
||||
"openai": "OPENAI_API_KEY",
|
||||
"roboflow": "ROBOFLOW_API_KEY",
|
||||
"voyageai": "VOYAGE_API_KEY",
|
||||
"watsonx": "WATSONX_API_KEY",
|
||||
}
|
||||
|
||||
env_key = env_key_map.get(provider)
|
||||
if env_key:
|
||||
return os.getenv(env_key)
|
||||
return None
|
||||
|
||||
def _initialize_embedding_function(self):
|
||||
"""Initialize the embedding function using CrewAI's factory."""
|
||||
try:
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
|
||||
# Build the configuration for CrewAI's factory
|
||||
config = self._build_provider_config()
|
||||
|
||||
# Create the embedding function
|
||||
self._embedding_function = build_embedder(config)
|
||||
|
||||
logger.info(
|
||||
f"Initialized {self.config.provider} embedding service with model "
|
||||
f"{self.config.model}"
|
||||
)
|
||||
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
f"CrewAI embedding providers not available. "
|
||||
f"Make sure crewai is installed: {e}"
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize embedding function: {e}")
|
||||
raise RuntimeError(
|
||||
f"Failed to initialize {self.config.provider} embedding service: {e}"
|
||||
) from e
|
||||
|
||||
def _build_provider_config(self) -> dict[str, Any]:
|
||||
"""Build configuration dictionary for CrewAI's embedding factory."""
|
||||
base_config = {"provider": self.config.provider, "config": {}}
|
||||
|
||||
# Provider-specific configuration mapping
|
||||
if self.config.provider == "openai":
|
||||
base_config["config"] = {
|
||||
"api_key": self.config.api_key,
|
||||
"model_name": self.config.model,
|
||||
**self.config.extra_config,
|
||||
}
|
||||
elif self.config.provider == "azure":
|
||||
base_config["config"] = {
|
||||
"api_key": self.config.api_key,
|
||||
"model_name": self.config.model,
|
||||
**self.config.extra_config,
|
||||
}
|
||||
elif self.config.provider == "voyageai":
|
||||
base_config["config"] = {
|
||||
"api_key": self.config.api_key,
|
||||
"model": self.config.model,
|
||||
"max_retries": self.config.max_retries,
|
||||
"timeout": self.config.timeout,
|
||||
**self.config.extra_config,
|
||||
}
|
||||
elif self.config.provider == "cohere":
|
||||
base_config["config"] = {
|
||||
"api_key": self.config.api_key,
|
||||
"model_name": self.config.model,
|
||||
**self.config.extra_config,
|
||||
}
|
||||
elif self.config.provider in ["google-generativeai", "google-vertex"]:
|
||||
base_config["config"] = {
|
||||
"api_key": self.config.api_key,
|
||||
"model_name": self.config.model,
|
||||
**self.config.extra_config,
|
||||
}
|
||||
elif self.config.provider == "amazon-bedrock":
|
||||
base_config["config"] = {
|
||||
"aws_access_key_id": self.config.api_key,
|
||||
"model_name": self.config.model,
|
||||
**self.config.extra_config,
|
||||
}
|
||||
elif self.config.provider == "huggingface":
|
||||
base_config["config"] = {
|
||||
"api_key": self.config.api_key,
|
||||
"model_name": self.config.model,
|
||||
**self.config.extra_config,
|
||||
}
|
||||
elif self.config.provider == "jina":
|
||||
base_config["config"] = {
|
||||
"api_key": self.config.api_key,
|
||||
"model_name": self.config.model,
|
||||
**self.config.extra_config,
|
||||
}
|
||||
elif self.config.provider == "ollama":
|
||||
base_config["config"] = {
|
||||
"model": self.config.model,
|
||||
**self.config.extra_config,
|
||||
}
|
||||
elif self.config.provider == "sentence-transformer":
|
||||
base_config["config"] = {
|
||||
"model_name": self.config.model,
|
||||
**self.config.extra_config,
|
||||
}
|
||||
elif self.config.provider == "instructor":
|
||||
base_config["config"] = {
|
||||
"model_name": self.config.model,
|
||||
**self.config.extra_config,
|
||||
}
|
||||
elif self.config.provider == "onnx":
|
||||
base_config["config"] = {
|
||||
**self.config.extra_config,
|
||||
}
|
||||
elif self.config.provider == "roboflow":
|
||||
base_config["config"] = {
|
||||
"api_key": self.config.api_key,
|
||||
**self.config.extra_config,
|
||||
}
|
||||
elif self.config.provider == "openclip":
|
||||
base_config["config"] = {
|
||||
"model_name": self.config.model,
|
||||
**self.config.extra_config,
|
||||
}
|
||||
elif self.config.provider == "text2vec":
|
||||
base_config["config"] = {
|
||||
"model_name": self.config.model,
|
||||
**self.config.extra_config,
|
||||
}
|
||||
elif self.config.provider == "watsonx":
|
||||
base_config["config"] = {
|
||||
"api_key": self.config.api_key,
|
||||
"model_name": self.config.model,
|
||||
**self.config.extra_config,
|
||||
}
|
||||
elif self.config.provider == "custom":
|
||||
# Custom provider requires embedding_callable in extra_config
|
||||
base_config["config"] = {
|
||||
**self.config.extra_config,
|
||||
}
|
||||
else:
|
||||
# Generic configuration for any unlisted providers
|
||||
base_config["config"] = {
|
||||
"api_key": self.config.api_key,
|
||||
"model": self.config.model,
|
||||
**self.config.extra_config,
|
||||
}
|
||||
|
||||
return base_config
|
||||
|
||||
def embed_text(self, text: str) -> list[float]:
|
||||
"""
|
||||
Generate embedding for a single text.
|
||||
|
||||
Args:
|
||||
text: Text to embed
|
||||
|
||||
Returns:
|
||||
List of floats representing the embedding
|
||||
|
||||
Raises:
|
||||
RuntimeError: If embedding generation fails
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
logger.warning("Empty text provided for embedding")
|
||||
return []
|
||||
|
||||
try:
|
||||
# Use ChromaDB's embedding function interface
|
||||
embeddings = self._embedding_function([text]) # type: ignore
|
||||
return embeddings[0] if embeddings else []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating embedding for text: {e}")
|
||||
raise RuntimeError(f"Failed to generate embedding: {e}") from e
|
||||
|
||||
def embed_batch(self, texts: list[str]) -> list[list[float]]:
|
||||
"""
|
||||
Generate embeddings for multiple texts.
|
||||
|
||||
Args:
|
||||
texts: List of texts to embed
|
||||
|
||||
Returns:
|
||||
List of embedding vectors
|
||||
|
||||
Raises:
|
||||
RuntimeError: If embedding generation fails
|
||||
"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
# Filter out empty texts
|
||||
valid_texts = [text for text in texts if text and text.strip()]
|
||||
if not valid_texts:
|
||||
logger.warning("No valid texts provided for batch embedding")
|
||||
return []
|
||||
|
||||
try:
|
||||
# Process in batches to avoid API limits
|
||||
all_embeddings = []
|
||||
|
||||
for i in range(0, len(valid_texts), self.config.batch_size):
|
||||
batch = valid_texts[i : i + self.config.batch_size]
|
||||
batch_embeddings = self._embedding_function(batch) # type: ignore
|
||||
all_embeddings.extend(batch_embeddings)
|
||||
|
||||
return all_embeddings
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating batch embeddings: {e}")
|
||||
raise RuntimeError(f"Failed to generate batch embeddings: {e}") from e
|
||||
|
||||
def get_embedding_dimension(self) -> int | None:
|
||||
"""
|
||||
Get the dimension of embeddings produced by this service.
|
||||
|
||||
Returns:
|
||||
Embedding dimension or None if unknown
|
||||
"""
|
||||
# Try to get dimension by generating a test embedding
|
||||
try:
|
||||
test_embedding = self.embed_text("test")
|
||||
return len(test_embedding) if test_embedding else None
|
||||
except Exception:
|
||||
logger.warning("Could not determine embedding dimension")
|
||||
return None
|
||||
|
||||
def validate_connection(self) -> bool:
|
||||
"""
|
||||
Validate that the embedding service is working correctly.
|
||||
|
||||
Returns:
|
||||
True if the service is working, False otherwise
|
||||
"""
|
||||
try:
|
||||
test_embedding = self.embed_text("test connection")
|
||||
return len(test_embedding) > 0
|
||||
except Exception as e:
|
||||
logger.error(f"Connection validation failed: {e}")
|
||||
return False
|
||||
|
||||
def get_service_info(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get information about the current embedding service.
|
||||
|
||||
Returns:
|
||||
Dictionary with service information
|
||||
"""
|
||||
return {
|
||||
"provider": self.config.provider,
|
||||
"model": self.config.model,
|
||||
"embedding_dimension": self.get_embedding_dimension(),
|
||||
"batch_size": self.config.batch_size,
|
||||
"is_connected": self.validate_connection(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def list_supported_providers(cls) -> list[str]:
|
||||
"""
|
||||
List all supported embedding providers.
|
||||
|
||||
Returns:
|
||||
List of supported provider names
|
||||
"""
|
||||
return [
|
||||
"azure",
|
||||
"amazon-bedrock",
|
||||
"cohere",
|
||||
"custom",
|
||||
"google-generativeai",
|
||||
"google-vertex",
|
||||
"huggingface",
|
||||
"instructor",
|
||||
"jina",
|
||||
"ollama",
|
||||
"onnx",
|
||||
"openai",
|
||||
"openclip",
|
||||
"roboflow",
|
||||
"sentence-transformer",
|
||||
"text2vec",
|
||||
"voyageai",
|
||||
"watsonx",
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def create_openai_service(
|
||||
cls,
|
||||
model: str = "text-embedding-3-small",
|
||||
api_key: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> "EmbeddingService":
|
||||
"""Create an OpenAI embedding service."""
|
||||
return cls(provider="openai", model=model, api_key=api_key, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def create_voyage_service(
|
||||
cls, model: str = "voyage-2", api_key: str | None = None, **kwargs: Any
|
||||
) -> "EmbeddingService":
|
||||
"""Create a Voyage AI embedding service."""
|
||||
return cls(provider="voyageai", model=model, api_key=api_key, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def create_cohere_service(
|
||||
cls,
|
||||
model: str = "embed-english-v3.0",
|
||||
api_key: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> "EmbeddingService":
|
||||
"""Create a Cohere embedding service."""
|
||||
return cls(provider="cohere", model=model, api_key=api_key, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def create_gemini_service(
|
||||
cls,
|
||||
model: str = "models/embedding-001",
|
||||
api_key: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> "EmbeddingService":
|
||||
"""Create a Google Gemini embedding service."""
|
||||
return cls(
|
||||
provider="google-generativeai", model=model, api_key=api_key, **kwargs
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_azure_service(
|
||||
cls,
|
||||
model: str = "text-embedding-ada-002",
|
||||
api_key: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> "EmbeddingService":
|
||||
"""Create an Azure OpenAI embedding service."""
|
||||
return cls(provider="azure", model=model, api_key=api_key, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def create_bedrock_service(
|
||||
cls,
|
||||
model: str = "amazon.titan-embed-text-v1",
|
||||
api_key: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> "EmbeddingService":
|
||||
"""Create an Amazon Bedrock embedding service."""
|
||||
return cls(provider="amazon-bedrock", model=model, api_key=api_key, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def create_huggingface_service(
|
||||
cls,
|
||||
model: str = "sentence-transformers/all-MiniLM-L6-v2",
|
||||
api_key: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> "EmbeddingService":
|
||||
"""Create a Hugging Face embedding service."""
|
||||
return cls(provider="huggingface", model=model, api_key=api_key, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def create_sentence_transformer_service(
|
||||
cls,
|
||||
model: str = "all-MiniLM-L6-v2",
|
||||
**kwargs: Any,
|
||||
) -> "EmbeddingService":
|
||||
"""Create a Sentence Transformers embedding service (local)."""
|
||||
return cls(provider="sentence-transformer", model=model, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def create_ollama_service(
|
||||
cls,
|
||||
model: str = "nomic-embed-text",
|
||||
**kwargs: Any,
|
||||
) -> "EmbeddingService":
|
||||
"""Create an Ollama embedding service (local)."""
|
||||
return cls(provider="ollama", model=model, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def create_jina_service(
|
||||
cls,
|
||||
model: str = "jina-embeddings-v2-base-en",
|
||||
api_key: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> "EmbeddingService":
|
||||
"""Create a Jina AI embedding service."""
|
||||
return cls(provider="jina", model=model, api_key=api_key, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def create_instructor_service(
|
||||
cls,
|
||||
model: str = "hkunlp/instructor-large",
|
||||
**kwargs: Any,
|
||||
) -> "EmbeddingService":
|
||||
"""Create an Instructor embedding service."""
|
||||
return cls(provider="instructor", model=model, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def create_watsonx_service(
|
||||
cls,
|
||||
model: str = "ibm/slate-125m-english-rtrvr",
|
||||
api_key: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> "EmbeddingService":
|
||||
"""Create a Watson X embedding service."""
|
||||
return cls(provider="watsonx", model=model, api_key=api_key, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def create_custom_service(
|
||||
cls,
|
||||
embedding_callable: Any,
|
||||
**kwargs: Any,
|
||||
) -> "EmbeddingService":
|
||||
"""Create a custom embedding service with your own embedding function."""
|
||||
return cls(
|
||||
provider="custom",
|
||||
model="custom",
|
||||
extra_config={"embedding_callable": embedding_callable},
|
||||
**kwargs,
|
||||
)
|
||||
342
lib/crewai-tools/tests/rag/test_embedding_service.py
Normal file
342
lib/crewai-tools/tests/rag/test_embedding_service.py
Normal file
@@ -0,0 +1,342 @@
|
||||
"""
|
||||
Tests for the enhanced embedding service.
|
||||
"""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from crewai_tools.rag.embedding_service import EmbeddingService, EmbeddingConfig
|
||||
|
||||
|
||||
class TestEmbeddingConfig:
|
||||
"""Test the EmbeddingConfig model."""
|
||||
|
||||
def test_default_config(self):
|
||||
"""Test default configuration values."""
|
||||
config = EmbeddingConfig(provider="openai", model="text-embedding-3-small")
|
||||
|
||||
assert config.provider == "openai"
|
||||
assert config.model == "text-embedding-3-small"
|
||||
assert config.api_key is None
|
||||
assert config.timeout == 30.0
|
||||
assert config.max_retries == 3
|
||||
assert config.batch_size == 100
|
||||
assert config.extra_config == {}
|
||||
|
||||
def test_custom_config(self):
|
||||
"""Test custom configuration values."""
|
||||
config = EmbeddingConfig(
|
||||
provider="voyageai",
|
||||
model="voyage-2",
|
||||
api_key="test-key",
|
||||
timeout=60.0,
|
||||
max_retries=5,
|
||||
batch_size=50,
|
||||
extra_config={"input_type": "document"}
|
||||
)
|
||||
|
||||
assert config.provider == "voyageai"
|
||||
assert config.model == "voyage-2"
|
||||
assert config.api_key == "test-key"
|
||||
assert config.timeout == 60.0
|
||||
assert config.max_retries == 5
|
||||
assert config.batch_size == 50
|
||||
assert config.extra_config == {"input_type": "document"}
|
||||
|
||||
|
||||
class TestEmbeddingService:
|
||||
"""Test the EmbeddingService class."""
|
||||
|
||||
def test_list_supported_providers(self):
|
||||
"""Test listing supported providers."""
|
||||
providers = EmbeddingService.list_supported_providers()
|
||||
expected_providers = [
|
||||
"openai", "azure", "voyageai", "cohere", "google-generativeai",
|
||||
"amazon-bedrock", "huggingface", "jina", "ollama", "sentence-transformer",
|
||||
"instructor", "watsonx", "custom"
|
||||
]
|
||||
|
||||
assert isinstance(providers, list)
|
||||
assert len(providers) >= 15 # Should have at least 15 providers
|
||||
assert all(provider in providers for provider in expected_providers)
|
||||
|
||||
def test_get_default_api_key(self):
|
||||
"""Test getting default API keys from environment."""
|
||||
service = EmbeddingService.__new__(EmbeddingService) # Create without __init__
|
||||
|
||||
# Test with environment variable set
|
||||
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-openai-key"}):
|
||||
api_key = service._get_default_api_key("openai")
|
||||
assert api_key == "test-openai-key"
|
||||
|
||||
# Test with no environment variable
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
api_key = service._get_default_api_key("openai")
|
||||
assert api_key is None
|
||||
|
||||
# Test unknown provider
|
||||
api_key = service._get_default_api_key("unknown-provider")
|
||||
assert api_key is None
|
||||
|
||||
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||
def test_initialization_success(self, mock_build_embedder):
|
||||
"""Test successful initialization."""
|
||||
# Mock the embedding function
|
||||
mock_embedding_function = Mock()
|
||||
mock_build_embedder.return_value = mock_embedding_function
|
||||
|
||||
service = EmbeddingService(
|
||||
provider="openai",
|
||||
model="text-embedding-3-small",
|
||||
api_key="test-key"
|
||||
)
|
||||
|
||||
assert service.config.provider == "openai"
|
||||
assert service.config.model == "text-embedding-3-small"
|
||||
assert service.config.api_key == "test-key"
|
||||
assert service._embedding_function == mock_embedding_function
|
||||
|
||||
# Verify build_embedder was called with correct config
|
||||
mock_build_embedder.assert_called_once()
|
||||
call_args = mock_build_embedder.call_args[0][0]
|
||||
assert call_args["provider"] == "openai"
|
||||
assert call_args["config"]["api_key"] == "test-key"
|
||||
assert call_args["config"]["model_name"] == "text-embedding-3-small"
|
||||
|
||||
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||
def test_initialization_import_error(self, mock_build_embedder):
|
||||
"""Test initialization with import error."""
|
||||
mock_build_embedder.side_effect = ImportError("CrewAI not installed")
|
||||
|
||||
with pytest.raises(ImportError, match="CrewAI embedding providers not available"):
|
||||
EmbeddingService(provider="openai", model="test-model", api_key="test-key")
|
||||
|
||||
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||
def test_embed_text_success(self, mock_build_embedder):
|
||||
"""Test successful text embedding."""
|
||||
# Mock the embedding function
|
||||
mock_embedding_function = Mock()
|
||||
mock_embedding_function.return_value = [[0.1, 0.2, 0.3]]
|
||||
mock_build_embedder.return_value = mock_embedding_function
|
||||
|
||||
service = EmbeddingService(provider="openai", model="test-model", api_key="test-key")
|
||||
|
||||
result = service.embed_text("test text")
|
||||
|
||||
assert result == [0.1, 0.2, 0.3]
|
||||
mock_embedding_function.assert_called_once_with(["test text"])
|
||||
|
||||
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||
def test_embed_text_empty_input(self, mock_build_embedder):
|
||||
"""Test embedding empty text."""
|
||||
mock_embedding_function = Mock()
|
||||
mock_build_embedder.return_value = mock_embedding_function
|
||||
|
||||
service = EmbeddingService(provider="openai", model="test-model", api_key="test-key")
|
||||
|
||||
result = service.embed_text("")
|
||||
assert result == []
|
||||
|
||||
result = service.embed_text(" ")
|
||||
assert result == []
|
||||
|
||||
# Embedding function should not be called for empty text
|
||||
mock_embedding_function.assert_not_called()
|
||||
|
||||
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||
def test_embed_batch_success(self, mock_build_embedder):
|
||||
"""Test successful batch embedding."""
|
||||
# Mock the embedding function
|
||||
mock_embedding_function = Mock()
|
||||
mock_embedding_function.return_value = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
|
||||
mock_build_embedder.return_value = mock_embedding_function
|
||||
|
||||
service = EmbeddingService(provider="openai", model="test-model", api_key="test-key")
|
||||
|
||||
texts = ["text1", "text2", "text3"]
|
||||
result = service.embed_batch(texts)
|
||||
|
||||
assert result == [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
|
||||
mock_embedding_function.assert_called_once_with(texts)
|
||||
|
||||
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||
def test_embed_batch_empty_input(self, mock_build_embedder):
|
||||
"""Test batch embedding with empty input."""
|
||||
mock_embedding_function = Mock()
|
||||
mock_build_embedder.return_value = mock_embedding_function
|
||||
|
||||
service = EmbeddingService(provider="openai", model="test-model", api_key="test-key")
|
||||
|
||||
# Empty list
|
||||
result = service.embed_batch([])
|
||||
assert result == []
|
||||
|
||||
# List with empty strings
|
||||
result = service.embed_batch(["", " ", ""])
|
||||
assert result == []
|
||||
|
||||
# Embedding function should not be called for empty input
|
||||
mock_embedding_function.assert_not_called()
|
||||
|
||||
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||
def test_validate_connection(self, mock_build_embedder):
|
||||
"""Test connection validation."""
|
||||
# Mock successful embedding
|
||||
mock_embedding_function = Mock()
|
||||
mock_embedding_function.return_value = [[0.1, 0.2, 0.3]]
|
||||
mock_build_embedder.return_value = mock_embedding_function
|
||||
|
||||
service = EmbeddingService(provider="openai", model="test-model", api_key="test-key")
|
||||
|
||||
assert service.validate_connection() is True
|
||||
|
||||
# Mock failed embedding
|
||||
mock_embedding_function.side_effect = Exception("Connection failed")
|
||||
assert service.validate_connection() is False
|
||||
|
||||
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||
def test_get_service_info(self, mock_build_embedder):
|
||||
"""Test getting service information."""
|
||||
# Mock the embedding function
|
||||
mock_embedding_function = Mock()
|
||||
mock_embedding_function.return_value = [[0.1, 0.2, 0.3]]
|
||||
mock_build_embedder.return_value = mock_embedding_function
|
||||
|
||||
service = EmbeddingService(provider="openai", model="test-model", api_key="test-key")
|
||||
|
||||
info = service.get_service_info()
|
||||
|
||||
assert info["provider"] == "openai"
|
||||
assert info["model"] == "test-model"
|
||||
assert info["embedding_dimension"] == 3
|
||||
assert info["batch_size"] == 100
|
||||
assert info["is_connected"] is True
|
||||
|
||||
def test_create_openai_service(self):
|
||||
"""Test OpenAI service creation."""
|
||||
with patch('crewai.rag.embeddings.factory.build_embedder'):
|
||||
service = EmbeddingService.create_openai_service(
|
||||
model="text-embedding-3-large",
|
||||
api_key="test-key"
|
||||
)
|
||||
|
||||
assert service.config.provider == "openai"
|
||||
assert service.config.model == "text-embedding-3-large"
|
||||
assert service.config.api_key == "test-key"
|
||||
|
||||
def test_create_voyage_service(self):
|
||||
"""Test Voyage AI service creation."""
|
||||
with patch('crewai.rag.embeddings.factory.build_embedder'):
|
||||
service = EmbeddingService.create_voyage_service(
|
||||
model="voyage-large-2",
|
||||
api_key="test-key"
|
||||
)
|
||||
|
||||
assert service.config.provider == "voyageai"
|
||||
assert service.config.model == "voyage-large-2"
|
||||
assert service.config.api_key == "test-key"
|
||||
|
||||
def test_create_cohere_service(self):
|
||||
"""Test Cohere service creation."""
|
||||
with patch('crewai.rag.embeddings.factory.build_embedder'):
|
||||
service = EmbeddingService.create_cohere_service(
|
||||
model="embed-multilingual-v3.0",
|
||||
api_key="test-key"
|
||||
)
|
||||
|
||||
assert service.config.provider == "cohere"
|
||||
assert service.config.model == "embed-multilingual-v3.0"
|
||||
assert service.config.api_key == "test-key"
|
||||
|
||||
def test_create_gemini_service(self):
|
||||
"""Test Gemini service creation."""
|
||||
with patch('crewai.rag.embeddings.factory.build_embedder'):
|
||||
service = EmbeddingService.create_gemini_service(
|
||||
model="models/embedding-001",
|
||||
api_key="test-key"
|
||||
)
|
||||
|
||||
assert service.config.provider == "google-generativeai"
|
||||
assert service.config.model == "models/embedding-001"
|
||||
assert service.config.api_key == "test-key"
|
||||
|
||||
|
||||
class TestProviderConfigurations:
|
||||
"""Test provider-specific configurations."""
|
||||
|
||||
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||
def test_openai_config(self, mock_build_embedder):
|
||||
"""Test OpenAI configuration mapping."""
|
||||
mock_build_embedder.return_value = Mock()
|
||||
|
||||
service = EmbeddingService(
|
||||
provider="openai",
|
||||
model="text-embedding-3-small",
|
||||
api_key="test-key",
|
||||
extra_config={"dimensions": 1024}
|
||||
)
|
||||
|
||||
# Check the configuration passed to build_embedder
|
||||
call_args = mock_build_embedder.call_args[0][0]
|
||||
assert call_args["provider"] == "openai"
|
||||
assert call_args["config"]["api_key"] == "test-key"
|
||||
assert call_args["config"]["model_name"] == "text-embedding-3-small"
|
||||
assert call_args["config"]["dimensions"] == 1024
|
||||
|
||||
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||
def test_voyageai_config(self, mock_build_embedder):
|
||||
"""Test Voyage AI configuration mapping."""
|
||||
mock_build_embedder.return_value = Mock()
|
||||
|
||||
service = EmbeddingService(
|
||||
provider="voyageai",
|
||||
model="voyage-2",
|
||||
api_key="test-key",
|
||||
timeout=60.0,
|
||||
max_retries=5,
|
||||
extra_config={"input_type": "document"}
|
||||
)
|
||||
|
||||
# Check the configuration passed to build_embedder
|
||||
call_args = mock_build_embedder.call_args[0][0]
|
||||
assert call_args["provider"] == "voyageai"
|
||||
assert call_args["config"]["api_key"] == "test-key"
|
||||
assert call_args["config"]["model"] == "voyage-2"
|
||||
assert call_args["config"]["timeout"] == 60.0
|
||||
assert call_args["config"]["max_retries"] == 5
|
||||
assert call_args["config"]["input_type"] == "document"
|
||||
|
||||
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||
def test_cohere_config(self, mock_build_embedder):
|
||||
"""Test Cohere configuration mapping."""
|
||||
mock_build_embedder.return_value = Mock()
|
||||
|
||||
service = EmbeddingService(
|
||||
provider="cohere",
|
||||
model="embed-english-v3.0",
|
||||
api_key="test-key"
|
||||
)
|
||||
|
||||
# Check the configuration passed to build_embedder
|
||||
call_args = mock_build_embedder.call_args[0][0]
|
||||
assert call_args["provider"] == "cohere"
|
||||
assert call_args["config"]["api_key"] == "test-key"
|
||||
assert call_args["config"]["model_name"] == "embed-english-v3.0"
|
||||
|
||||
@patch('crewai.rag.embeddings.factory.build_embedder')
|
||||
def test_gemini_config(self, mock_build_embedder):
|
||||
"""Test Gemini configuration mapping."""
|
||||
mock_build_embedder.return_value = Mock()
|
||||
|
||||
service = EmbeddingService(
|
||||
provider="google-generativeai",
|
||||
model="models/embedding-001",
|
||||
api_key="test-key"
|
||||
)
|
||||
|
||||
# Check the configuration passed to build_embedder
|
||||
call_args = mock_build_embedder.call_args[0][0]
|
||||
assert call_args["provider"] == "google-generativeai"
|
||||
assert call_args["config"]["api_key"] == "test-key"
|
||||
assert call_args["config"]["model_name"] == "models/embedding-001"
|
||||
@@ -35,7 +35,6 @@ dependencies = [
|
||||
"uv>=0.4.25",
|
||||
"tomli-w>=1.1.0",
|
||||
"tomli>=2.0.2",
|
||||
"blinker>=1.9.0",
|
||||
"json5>=0.10.0",
|
||||
"portalocker==2.7.0",
|
||||
"pydantic-settings>=2.10.1",
|
||||
@@ -49,7 +48,7 @@ Repository = "https://github.com/crewAIInc/crewAI"
|
||||
|
||||
[project.optional-dependencies]
|
||||
tools = [
|
||||
"crewai-tools==1.0.0a4",
|
||||
"crewai-tools==1.0.0b2",
|
||||
]
|
||||
embeddings = [
|
||||
"tiktoken~=0.8.0"
|
||||
@@ -85,6 +84,9 @@ voyageai = [
|
||||
litellm = [
|
||||
"litellm>=1.74.9",
|
||||
]
|
||||
boto3 = [
|
||||
"boto3>=1.40.45",
|
||||
]
|
||||
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
|
||||
|
||||
_suppress_pydantic_deprecation_warnings()
|
||||
|
||||
__version__ = "1.0.0a4"
|
||||
__version__ = "1.0.0b2"
|
||||
_telemetry_submitted = False
|
||||
|
||||
|
||||
|
||||
@@ -30,6 +30,7 @@ def validate_jwt_token(
|
||||
algorithms=["RS256"],
|
||||
audience=audience,
|
||||
issuer=issuer,
|
||||
leeway=10.0,
|
||||
options={
|
||||
"verify_signature": True,
|
||||
"verify_exp": True,
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import os
|
||||
import subprocess
|
||||
from enum import Enum
|
||||
|
||||
import click
|
||||
from packaging import version
|
||||
|
||||
from crewai.cli.utils import read_toml
|
||||
from crewai.cli.utils import build_env_with_tool_repository_credentials, read_toml
|
||||
from crewai.cli.version import get_crewai_version
|
||||
|
||||
|
||||
@@ -55,8 +56,22 @@ def execute_command(crew_type: CrewType) -> None:
|
||||
"""
|
||||
command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"]
|
||||
|
||||
env = os.environ.copy()
|
||||
try:
|
||||
subprocess.run(command, capture_output=False, text=True, check=True) # noqa: S603
|
||||
pyproject_data = read_toml()
|
||||
sources = pyproject_data.get("tool", {}).get("uv", {}).get("sources", {})
|
||||
|
||||
for source_config in sources.values():
|
||||
if isinstance(source_config, dict):
|
||||
index = source_config.get("index")
|
||||
if index:
|
||||
index_env = build_env_with_tool_repository_credentials(index)
|
||||
env.update(index_env)
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
try:
|
||||
subprocess.run(command, capture_output=False, text=True, check=True, env=env) # noqa: S603
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
handle_error(e, crew_type)
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]>=0.203.0,<1.0.0"
|
||||
"crewai[tools]>=0.203.1,<1.0.0"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]>=0.203.0,<1.0.0",
|
||||
"crewai[tools]>=0.203.1,<1.0.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "Power up your crews with {{folder_name}}"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]>=0.203.0"
|
||||
"crewai[tools]>=0.203.1"
|
||||
]
|
||||
|
||||
[tool.crewai]
|
||||
|
||||
@@ -5,10 +5,13 @@ This module provides the event infrastructure that allows users to:
|
||||
- Track memory operations and performance
|
||||
- Build custom logging and analytics
|
||||
- Extend CrewAI with custom event handlers
|
||||
- Declare handler dependencies for ordered execution
|
||||
"""
|
||||
|
||||
from crewai.events.base_event_listener import BaseEventListener
|
||||
from crewai.events.depends import Depends
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.handler_graph import CircularDependencyError
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentEvaluationCompletedEvent,
|
||||
AgentEvaluationFailedEvent,
|
||||
@@ -109,6 +112,7 @@ __all__ = [
|
||||
"AgentReasoningFailedEvent",
|
||||
"AgentReasoningStartedEvent",
|
||||
"BaseEventListener",
|
||||
"CircularDependencyError",
|
||||
"CrewKickoffCompletedEvent",
|
||||
"CrewKickoffFailedEvent",
|
||||
"CrewKickoffStartedEvent",
|
||||
@@ -119,6 +123,7 @@ __all__ = [
|
||||
"CrewTrainCompletedEvent",
|
||||
"CrewTrainFailedEvent",
|
||||
"CrewTrainStartedEvent",
|
||||
"Depends",
|
||||
"FlowCreatedEvent",
|
||||
"FlowEvent",
|
||||
"FlowFinishedEvent",
|
||||
|
||||
@@ -9,6 +9,7 @@ class BaseEventListener(ABC):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.setup_listeners(crewai_event_bus)
|
||||
crewai_event_bus.validate_dependencies()
|
||||
|
||||
@abstractmethod
|
||||
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus):
|
||||
|
||||
105
lib/crewai/src/crewai/events/depends.py
Normal file
105
lib/crewai/src/crewai/events/depends.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""Dependency injection system for event handlers.
|
||||
|
||||
This module provides a FastAPI-style dependency system that allows event handlers
|
||||
to declare dependencies on other handlers, ensuring proper execution order while
|
||||
maintaining parallelism where possible.
|
||||
"""
|
||||
|
||||
from collections.abc import Coroutine
|
||||
from typing import Any, Generic, Protocol, TypeVar
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
|
||||
|
||||
EventT_co = TypeVar("EventT_co", bound=BaseEvent, contravariant=True)
|
||||
|
||||
|
||||
class EventHandler(Protocol[EventT_co]):
|
||||
"""Protocol for event handler functions.
|
||||
|
||||
Generic protocol that accepts any subclass of BaseEvent.
|
||||
Handlers can be either synchronous (returning None) or asynchronous
|
||||
(returning a coroutine).
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self, source: Any, event: EventT_co, /
|
||||
) -> None | Coroutine[Any, Any, None]:
|
||||
"""Event handler signature.
|
||||
|
||||
Args:
|
||||
source: The object that emitted the event
|
||||
event: The event instance (any BaseEvent subclass)
|
||||
|
||||
Returns:
|
||||
None for sync handlers, Coroutine for async handlers
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
T = TypeVar("T", bound=EventHandler[Any])
|
||||
|
||||
|
||||
class Depends(Generic[T]):
|
||||
"""Declares a dependency on another event handler.
|
||||
|
||||
Similar to FastAPI's Depends, this allows handlers to specify that they
|
||||
depend on other handlers completing first. Handlers with dependencies will
|
||||
execute after their dependencies, while independent handlers can run in parallel.
|
||||
|
||||
Args:
|
||||
handler: The handler function that this handler depends on
|
||||
|
||||
Example:
|
||||
>>> from crewai.events import Depends, crewai_event_bus
|
||||
>>> from crewai.events import LLMCallStartedEvent
|
||||
>>> @crewai_event_bus.on(LLMCallStartedEvent)
|
||||
>>> def setup_context(source, event):
|
||||
... return {"initialized": True}
|
||||
>>>
|
||||
>>> @crewai_event_bus.on(LLMCallStartedEvent, depends_on=Depends(setup_context))
|
||||
>>> def process(source, event):
|
||||
... # Runs after setup_context completes
|
||||
... pass
|
||||
"""
|
||||
|
||||
def __init__(self, handler: T) -> None:
|
||||
"""Initialize a dependency on a handler.
|
||||
|
||||
Args:
|
||||
handler: The handler function this depends on
|
||||
"""
|
||||
self.handler = handler
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return a string representation of the dependency.
|
||||
|
||||
Returns:
|
||||
A string showing the dependent handler name
|
||||
"""
|
||||
handler_name = getattr(self.handler, "__name__", repr(self.handler))
|
||||
return f"Depends({handler_name})"
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Check equality based on the handler reference.
|
||||
|
||||
Args:
|
||||
other: Another Depends instance to compare
|
||||
|
||||
Returns:
|
||||
True if both depend on the same handler, False otherwise
|
||||
"""
|
||||
if not isinstance(other, Depends):
|
||||
return False
|
||||
return self.handler is other.handler
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Return hash based on handler identity.
|
||||
|
||||
Since equality is based on identity (is), we hash the handler
|
||||
object directly rather than its id for consistency.
|
||||
|
||||
Returns:
|
||||
Hash of the handler object
|
||||
"""
|
||||
return id(self.handler)
|
||||
@@ -1,125 +1,507 @@
|
||||
from __future__ import annotations
|
||||
"""Event bus for managing and dispatching events in CrewAI.
|
||||
|
||||
import threading
|
||||
from collections.abc import Callable
|
||||
This module provides a singleton event bus that allows registration and handling
|
||||
of events throughout the CrewAI system, supporting both synchronous and asynchronous
|
||||
event handlers with optional dependency management.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import atexit
|
||||
from collections.abc import Callable, Generator
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, TypeVar, cast
|
||||
import threading
|
||||
from typing import Any, Final, ParamSpec, TypeVar
|
||||
|
||||
from blinker import Signal
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
from crewai.events.event_types import EventTypes
|
||||
from crewai.events.depends import Depends
|
||||
from crewai.events.handler_graph import build_execution_plan
|
||||
from crewai.events.types.event_bus_types import (
|
||||
AsyncHandler,
|
||||
AsyncHandlerSet,
|
||||
ExecutionPlan,
|
||||
Handler,
|
||||
SyncHandler,
|
||||
SyncHandlerSet,
|
||||
)
|
||||
from crewai.events.types.llm_events import LLMStreamChunkEvent
|
||||
from crewai.events.utils.console_formatter import ConsoleFormatter
|
||||
from crewai.events.utils.handlers import is_async_handler, is_call_handler_safe
|
||||
from crewai.events.utils.rw_lock import RWLock
|
||||
|
||||
EventT = TypeVar("EventT", bound=BaseEvent)
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class CrewAIEventsBus:
|
||||
"""
|
||||
A singleton event bus that uses blinker signals for event handling.
|
||||
Allows both internal (Flow/Crew) and external event handling.
|
||||
"""Singleton event bus for handling events in CrewAI.
|
||||
|
||||
This class manages event registration and emission for both synchronous
|
||||
and asynchronous event handlers, automatically scheduling async handlers
|
||||
in a dedicated background event loop.
|
||||
|
||||
Synchronous handlers execute in a thread pool executor to ensure completion
|
||||
before program exit. Asynchronous handlers execute in a dedicated event loop
|
||||
running in a daemon thread, with graceful shutdown waiting for completion.
|
||||
|
||||
Attributes:
|
||||
_instance: Singleton instance of the event bus
|
||||
_instance_lock: Reentrant lock for singleton initialization (class-level)
|
||||
_rwlock: Read-write lock for handler registration and access (instance-level)
|
||||
_sync_handlers: Mapping of event types to registered synchronous handlers
|
||||
_async_handlers: Mapping of event types to registered asynchronous handlers
|
||||
_sync_executor: Thread pool executor for running synchronous handlers
|
||||
_loop: Dedicated asyncio event loop for async handler execution
|
||||
_loop_thread: Background daemon thread running the event loop
|
||||
_console: Console formatter for error output
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
_instance: Self | None = None
|
||||
_instance_lock: threading.RLock = threading.RLock()
|
||||
_rwlock: RWLock
|
||||
_sync_handlers: dict[type[BaseEvent], SyncHandlerSet]
|
||||
_async_handlers: dict[type[BaseEvent], AsyncHandlerSet]
|
||||
_handler_dependencies: dict[type[BaseEvent], dict[Handler, list[Depends]]]
|
||||
_execution_plan_cache: dict[type[BaseEvent], ExecutionPlan]
|
||||
_console: ConsoleFormatter
|
||||
_shutting_down: bool
|
||||
|
||||
def __new__(cls):
|
||||
def __new__(cls) -> Self:
|
||||
"""Create or return the singleton instance.
|
||||
|
||||
Returns:
|
||||
The singleton CrewAIEventsBus instance
|
||||
"""
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None: # prevent race condition
|
||||
with cls._instance_lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialize()
|
||||
return cls._instance
|
||||
|
||||
def _initialize(self) -> None:
|
||||
"""Initialize the event bus internal state"""
|
||||
self._signal = Signal("crewai_event_bus")
|
||||
self._handlers: dict[type[BaseEvent], list[Callable]] = {}
|
||||
"""Initialize the event bus internal state.
|
||||
|
||||
Creates handler dictionaries and starts a dedicated background
|
||||
event loop for async handler execution.
|
||||
"""
|
||||
self._shutting_down = False
|
||||
self._rwlock = RWLock()
|
||||
self._sync_handlers: dict[type[BaseEvent], SyncHandlerSet] = {}
|
||||
self._async_handlers: dict[type[BaseEvent], AsyncHandlerSet] = {}
|
||||
self._handler_dependencies: dict[type[BaseEvent], dict[Handler, list[Depends]]] = {}
|
||||
self._execution_plan_cache: dict[type[BaseEvent], ExecutionPlan] = {}
|
||||
self._sync_executor = ThreadPoolExecutor(
|
||||
max_workers=10,
|
||||
thread_name_prefix="CrewAISyncHandler",
|
||||
)
|
||||
self._console = ConsoleFormatter()
|
||||
|
||||
self._loop = asyncio.new_event_loop()
|
||||
self._loop_thread = threading.Thread(
|
||||
target=self._run_loop,
|
||||
name="CrewAIEventsLoop",
|
||||
daemon=True,
|
||||
)
|
||||
self._loop_thread.start()
|
||||
|
||||
def _run_loop(self) -> None:
|
||||
"""Run the background async event loop."""
|
||||
asyncio.set_event_loop(self._loop)
|
||||
self._loop.run_forever()
|
||||
|
||||
def _register_handler(
|
||||
self,
|
||||
event_type: type[BaseEvent],
|
||||
handler: Callable[..., Any],
|
||||
dependencies: list[Depends] | None = None,
|
||||
) -> None:
|
||||
"""Register a handler for the given event type.
|
||||
|
||||
Args:
|
||||
event_type: The event class to listen for
|
||||
handler: The handler function to register
|
||||
dependencies: Optional list of dependencies
|
||||
"""
|
||||
with self._rwlock.w_locked():
|
||||
if is_async_handler(handler):
|
||||
existing_async = self._async_handlers.get(event_type, frozenset())
|
||||
self._async_handlers[event_type] = existing_async | {handler}
|
||||
else:
|
||||
existing_sync = self._sync_handlers.get(event_type, frozenset())
|
||||
self._sync_handlers[event_type] = existing_sync | {handler}
|
||||
|
||||
if dependencies:
|
||||
if event_type not in self._handler_dependencies:
|
||||
self._handler_dependencies[event_type] = {}
|
||||
self._handler_dependencies[event_type][handler] = dependencies
|
||||
|
||||
self._execution_plan_cache.pop(event_type, None)
|
||||
|
||||
def on(
|
||||
self, event_type: type[EventT]
|
||||
) -> Callable[[Callable[[Any, EventT], None]], Callable[[Any, EventT], None]]:
|
||||
"""
|
||||
Decorator to register an event handler for a specific event type.
|
||||
self,
|
||||
event_type: type[BaseEvent],
|
||||
depends_on: Depends | list[Depends] | None = None,
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
"""Decorator to register an event handler for a specific event type.
|
||||
|
||||
Usage:
|
||||
@crewai_event_bus.on(AgentExecutionCompletedEvent)
|
||||
def on_agent_execution_completed(
|
||||
source: Any, event: AgentExecutionCompletedEvent
|
||||
):
|
||||
print(f"👍 Agent '{event.agent}' completed task")
|
||||
print(f" Output: {event.output}")
|
||||
Args:
|
||||
event_type: The event class to listen for
|
||||
depends_on: Optional dependency or list of dependencies. Handlers with
|
||||
dependencies will execute after their dependencies complete.
|
||||
|
||||
Returns:
|
||||
Decorator function that registers the handler
|
||||
|
||||
Example:
|
||||
>>> from crewai.events import crewai_event_bus, Depends
|
||||
>>> from crewai.events.types.llm_events import LLMCallStartedEvent
|
||||
>>>
|
||||
>>> @crewai_event_bus.on(LLMCallStartedEvent)
|
||||
>>> def setup_context(source, event):
|
||||
... print("Setting up context")
|
||||
>>>
|
||||
>>> @crewai_event_bus.on(LLMCallStartedEvent, depends_on=Depends(setup_context))
|
||||
>>> def process(source, event):
|
||||
... print("Processing (runs after setup_context)")
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
handler: Callable[[Any, EventT], None],
|
||||
) -> Callable[[Any, EventT], None]:
|
||||
if event_type not in self._handlers:
|
||||
self._handlers[event_type] = []
|
||||
self._handlers[event_type].append(
|
||||
cast(Callable[[Any, EventT], None], handler)
|
||||
)
|
||||
def decorator(handler: Callable[P, R]) -> Callable[P, R]:
|
||||
"""Register the handler and return it unchanged.
|
||||
|
||||
Args:
|
||||
handler: Event handler function to register
|
||||
|
||||
Returns:
|
||||
The same handler function unchanged
|
||||
"""
|
||||
deps = None
|
||||
if depends_on is not None:
|
||||
deps = [depends_on] if isinstance(depends_on, Depends) else depends_on
|
||||
|
||||
self._register_handler(event_type, handler, dependencies=deps)
|
||||
return handler
|
||||
|
||||
return decorator
|
||||
|
||||
@staticmethod
|
||||
def _call_handler(
|
||||
handler: Callable, source: Any, event: BaseEvent, event_type: type
|
||||
def _call_handlers(
|
||||
self,
|
||||
source: Any,
|
||||
event: BaseEvent,
|
||||
handlers: SyncHandlerSet,
|
||||
) -> None:
|
||||
"""Call a single handler with error handling."""
|
||||
try:
|
||||
handler(source, event)
|
||||
except Exception as e:
|
||||
print(
|
||||
f"[EventBus Error] Handler '{handler.__name__}' failed for event '{event_type.__name__}': {e}"
|
||||
"""Call provided synchronous handlers.
|
||||
|
||||
Args:
|
||||
source: The emitting object
|
||||
event: The event instance
|
||||
handlers: Frozenset of sync handlers to call
|
||||
"""
|
||||
errors: list[tuple[SyncHandler, Exception]] = [
|
||||
(handler, error)
|
||||
for handler in handlers
|
||||
if (error := is_call_handler_safe(handler, source, event)) is not None
|
||||
]
|
||||
|
||||
if errors:
|
||||
for handler, error in errors:
|
||||
self._console.print(
|
||||
f"[CrewAIEventsBus] Sync handler error in {handler.__name__}: {error}"
|
||||
)
|
||||
|
||||
async def _acall_handlers(
|
||||
self,
|
||||
source: Any,
|
||||
event: BaseEvent,
|
||||
handlers: AsyncHandlerSet,
|
||||
) -> None:
|
||||
"""Asynchronously call provided async handlers.
|
||||
|
||||
Args:
|
||||
source: The object that emitted the event
|
||||
event: The event instance
|
||||
handlers: Frozenset of async handlers to call
|
||||
"""
|
||||
coros = [handler(source, event) for handler in handlers]
|
||||
results = await asyncio.gather(*coros, return_exceptions=True)
|
||||
for handler, result in zip(handlers, results, strict=False):
|
||||
if isinstance(result, Exception):
|
||||
self._console.print(
|
||||
f"[CrewAIEventsBus] Async handler error in {getattr(handler, '__name__', handler)}: {result}"
|
||||
)
|
||||
|
||||
async def _emit_with_dependencies(self, source: Any, event: BaseEvent) -> None:
|
||||
"""Emit an event with dependency-aware handler execution.
|
||||
|
||||
Handlers are grouped into execution levels based on their dependencies.
|
||||
Within each level, async handlers run concurrently while sync handlers
|
||||
run sequentially (or in thread pool). Each level completes before the
|
||||
next level starts.
|
||||
|
||||
Uses a cached execution plan for performance. The plan is built once
|
||||
per event type and cached until handlers are modified.
|
||||
|
||||
Args:
|
||||
source: The emitting object
|
||||
event: The event instance to emit
|
||||
"""
|
||||
event_type = type(event)
|
||||
|
||||
with self._rwlock.r_locked():
|
||||
if self._shutting_down:
|
||||
return
|
||||
cached_plan = self._execution_plan_cache.get(event_type)
|
||||
if cached_plan is not None:
|
||||
sync_handlers = self._sync_handlers.get(event_type, frozenset())
|
||||
async_handlers = self._async_handlers.get(event_type, frozenset())
|
||||
|
||||
if cached_plan is None:
|
||||
with self._rwlock.w_locked():
|
||||
if self._shutting_down:
|
||||
return
|
||||
cached_plan = self._execution_plan_cache.get(event_type)
|
||||
if cached_plan is None:
|
||||
sync_handlers = self._sync_handlers.get(event_type, frozenset())
|
||||
async_handlers = self._async_handlers.get(event_type, frozenset())
|
||||
dependencies = dict(self._handler_dependencies.get(event_type, {}))
|
||||
all_handlers = list(sync_handlers | async_handlers)
|
||||
|
||||
if not all_handlers:
|
||||
return
|
||||
|
||||
cached_plan = build_execution_plan(all_handlers, dependencies)
|
||||
self._execution_plan_cache[event_type] = cached_plan
|
||||
else:
|
||||
sync_handlers = self._sync_handlers.get(event_type, frozenset())
|
||||
async_handlers = self._async_handlers.get(event_type, frozenset())
|
||||
|
||||
for level in cached_plan:
|
||||
level_sync = frozenset(h for h in level if h in sync_handlers)
|
||||
level_async = frozenset(h for h in level if h in async_handlers)
|
||||
|
||||
if level_sync:
|
||||
if event_type is LLMStreamChunkEvent:
|
||||
self._call_handlers(source, event, level_sync)
|
||||
else:
|
||||
future = self._sync_executor.submit(
|
||||
self._call_handlers, source, event, level_sync
|
||||
)
|
||||
await asyncio.get_running_loop().run_in_executor(
|
||||
None, future.result
|
||||
)
|
||||
|
||||
if level_async:
|
||||
await self._acall_handlers(source, event, level_async)
|
||||
|
||||
def emit(self, source: Any, event: BaseEvent) -> Future[None] | None:
|
||||
"""Emit an event to all registered handlers.
|
||||
|
||||
If handlers have dependencies (registered with depends_on), they execute
|
||||
in dependency order. Otherwise, handlers execute as before (sync in thread
|
||||
pool, async fire-and-forget).
|
||||
|
||||
Stream chunk events always execute synchronously to preserve ordering.
|
||||
|
||||
Args:
|
||||
source: The emitting object
|
||||
event: The event instance to emit
|
||||
|
||||
Returns:
|
||||
Future that completes when handlers finish. Returns:
|
||||
- Future for sync-only handlers (ThreadPoolExecutor future)
|
||||
- Future for async handlers or mixed handlers (asyncio future)
|
||||
- Future for dependency-managed handlers (asyncio future)
|
||||
- None if no handlers or sync stream chunk events
|
||||
|
||||
Example:
|
||||
>>> future = crewai_event_bus.emit(source, event)
|
||||
>>> if future:
|
||||
... await asyncio.wrap_future(future) # In async test
|
||||
... # or future.result(timeout=5.0) in sync code
|
||||
"""
|
||||
event_type = type(event)
|
||||
|
||||
with self._rwlock.r_locked():
|
||||
if self._shutting_down:
|
||||
self._console.print(
|
||||
"[CrewAIEventsBus] Warning: Attempted to emit event during shutdown. Ignoring."
|
||||
)
|
||||
return None
|
||||
has_dependencies = event_type in self._handler_dependencies
|
||||
sync_handlers = self._sync_handlers.get(event_type, frozenset())
|
||||
async_handlers = self._async_handlers.get(event_type, frozenset())
|
||||
|
||||
if has_dependencies:
|
||||
return asyncio.run_coroutine_threadsafe(
|
||||
self._emit_with_dependencies(source, event),
|
||||
self._loop,
|
||||
)
|
||||
|
||||
def emit(self, source: Any, event: BaseEvent) -> None:
|
||||
"""
|
||||
Emit an event to all registered handlers
|
||||
if sync_handlers:
|
||||
if event_type is LLMStreamChunkEvent:
|
||||
self._call_handlers(source, event, sync_handlers)
|
||||
else:
|
||||
sync_future = self._sync_executor.submit(
|
||||
self._call_handlers, source, event, sync_handlers
|
||||
)
|
||||
if not async_handlers:
|
||||
return sync_future
|
||||
|
||||
if async_handlers:
|
||||
return asyncio.run_coroutine_threadsafe(
|
||||
self._acall_handlers(source, event, async_handlers),
|
||||
self._loop,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
async def aemit(self, source: Any, event: BaseEvent) -> None:
|
||||
"""Asynchronously emit an event to registered async handlers.
|
||||
|
||||
Only processes async handlers. Use in async contexts.
|
||||
|
||||
Args:
|
||||
source: The object emitting the event
|
||||
event: The event instance to emit
|
||||
"""
|
||||
for event_type, handlers in self._handlers.items():
|
||||
if isinstance(event, event_type):
|
||||
for handler in handlers:
|
||||
self._call_handler(handler, source, event, event_type)
|
||||
event_type = type(event)
|
||||
|
||||
self._signal.send(source, event=event)
|
||||
with self._rwlock.r_locked():
|
||||
if self._shutting_down:
|
||||
self._console.print(
|
||||
"[CrewAIEventsBus] Warning: Attempted to emit event during shutdown. Ignoring."
|
||||
)
|
||||
return
|
||||
async_handlers = self._async_handlers.get(event_type, frozenset())
|
||||
|
||||
if async_handlers:
|
||||
await self._acall_handlers(source, event, async_handlers)
|
||||
|
||||
def register_handler(
|
||||
self, event_type: type[EventTypes], handler: Callable[[Any, EventTypes], None]
|
||||
self,
|
||||
event_type: type[BaseEvent],
|
||||
handler: SyncHandler | AsyncHandler,
|
||||
) -> None:
|
||||
"""Register an event handler for a specific event type"""
|
||||
if event_type not in self._handlers:
|
||||
self._handlers[event_type] = []
|
||||
self._handlers[event_type].append(
|
||||
cast(Callable[[Any, EventTypes], None], handler)
|
||||
)
|
||||
"""Register an event handler for a specific event type.
|
||||
|
||||
Args:
|
||||
event_type: The event class to listen for
|
||||
handler: The handler function to register
|
||||
"""
|
||||
self._register_handler(event_type, handler)
|
||||
|
||||
def validate_dependencies(self) -> None:
|
||||
"""Validate all registered handler dependencies.
|
||||
|
||||
Attempts to build execution plans for all event types with dependencies.
|
||||
This detects circular dependencies and cross-event-type dependencies
|
||||
before events are emitted.
|
||||
|
||||
Raises:
|
||||
CircularDependencyError: If circular dependencies or unresolved
|
||||
dependencies (e.g., cross-event-type) are detected
|
||||
"""
|
||||
with self._rwlock.r_locked():
|
||||
for event_type in self._handler_dependencies:
|
||||
sync_handlers = self._sync_handlers.get(event_type, frozenset())
|
||||
async_handlers = self._async_handlers.get(event_type, frozenset())
|
||||
dependencies = dict(self._handler_dependencies.get(event_type, {}))
|
||||
all_handlers = list(sync_handlers | async_handlers)
|
||||
|
||||
if all_handlers and dependencies:
|
||||
build_execution_plan(all_handlers, dependencies)
|
||||
|
||||
@contextmanager
|
||||
def scoped_handlers(self):
|
||||
"""
|
||||
Context manager for temporary event handling scope.
|
||||
Useful for testing or temporary event handling.
|
||||
def scoped_handlers(self) -> Generator[None, Any, None]:
|
||||
"""Context manager for temporary event handling scope.
|
||||
|
||||
Usage:
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
@crewai_event_bus.on(CrewKickoffStarted)
|
||||
def temp_handler(source, event):
|
||||
print("Temporary handler")
|
||||
# Do stuff...
|
||||
# Handlers are cleared after the context
|
||||
Useful for testing or temporary event handling. All handlers registered
|
||||
within this context are cleared when the context exits.
|
||||
|
||||
Example:
|
||||
>>> from crewai.events.event_bus import crewai_event_bus
|
||||
>>> from crewai.events.event_types import CrewKickoffStartedEvent
|
||||
>>> with crewai_event_bus.scoped_handlers():
|
||||
...
|
||||
... @crewai_event_bus.on(CrewKickoffStartedEvent)
|
||||
... def temp_handler(source, event):
|
||||
... print("Temporary handler")
|
||||
...
|
||||
... # Do stuff...
|
||||
... # Handlers are cleared after the context
|
||||
"""
|
||||
previous_handlers = self._handlers.copy()
|
||||
self._handlers.clear()
|
||||
with self._rwlock.w_locked():
|
||||
prev_sync = self._sync_handlers
|
||||
prev_async = self._async_handlers
|
||||
prev_deps = self._handler_dependencies
|
||||
prev_cache = self._execution_plan_cache
|
||||
self._sync_handlers = {}
|
||||
self._async_handlers = {}
|
||||
self._handler_dependencies = {}
|
||||
self._execution_plan_cache = {}
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._handlers = previous_handlers
|
||||
with self._rwlock.w_locked():
|
||||
self._sync_handlers = prev_sync
|
||||
self._async_handlers = prev_async
|
||||
self._handler_dependencies = prev_deps
|
||||
self._execution_plan_cache = prev_cache
|
||||
|
||||
def shutdown(self, wait: bool = True) -> None:
|
||||
"""Gracefully shutdown the event loop and wait for all tasks to finish.
|
||||
|
||||
Args:
|
||||
wait: If True, wait for all pending tasks to complete before stopping.
|
||||
If False, cancel all pending tasks immediately.
|
||||
"""
|
||||
with self._rwlock.w_locked():
|
||||
self._shutting_down = True
|
||||
loop = getattr(self, "_loop", None)
|
||||
|
||||
if loop is None or loop.is_closed():
|
||||
return
|
||||
|
||||
if wait:
|
||||
|
||||
async def _wait_for_all_tasks() -> None:
|
||||
tasks = {
|
||||
t
|
||||
for t in asyncio.all_tasks(loop)
|
||||
if t is not asyncio.current_task()
|
||||
}
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
future = asyncio.run_coroutine_threadsafe(_wait_for_all_tasks(), loop)
|
||||
try:
|
||||
future.result()
|
||||
except Exception as e:
|
||||
self._console.print(f"[CrewAIEventsBus] Error waiting for tasks: {e}")
|
||||
else:
|
||||
|
||||
def _cancel_tasks() -> None:
|
||||
for task in asyncio.all_tasks(loop):
|
||||
if task is not asyncio.current_task():
|
||||
task.cancel()
|
||||
|
||||
loop.call_soon_threadsafe(_cancel_tasks)
|
||||
|
||||
loop.call_soon_threadsafe(loop.stop)
|
||||
self._loop_thread.join()
|
||||
loop.close()
|
||||
self._sync_executor.shutdown(wait=wait)
|
||||
|
||||
with self._rwlock.w_locked():
|
||||
self._sync_handlers.clear()
|
||||
self._async_handlers.clear()
|
||||
self._execution_plan_cache.clear()
|
||||
|
||||
|
||||
# Global instance
|
||||
crewai_event_bus = CrewAIEventsBus()
|
||||
crewai_event_bus: Final[CrewAIEventsBus] = CrewAIEventsBus()
|
||||
|
||||
atexit.register(crewai_event_bus.shutdown)
|
||||
|
||||
@@ -386,7 +386,7 @@ class EventListener(BaseEventListener):
|
||||
|
||||
# Read from the in-memory stream
|
||||
content = self.text_stream.read()
|
||||
_printer.print(content, end="", flush=True)
|
||||
_printer.print(content)
|
||||
self.next_chunk = self.text_stream.tell()
|
||||
|
||||
# ----------- LLM GUARDRAIL EVENTS -----------
|
||||
|
||||
130
lib/crewai/src/crewai/events/handler_graph.py
Normal file
130
lib/crewai/src/crewai/events/handler_graph.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""Dependency graph resolution for event handlers.
|
||||
|
||||
This module resolves handler dependencies into execution levels, ensuring
|
||||
handlers execute in correct order while maximizing parallelism.
|
||||
"""
|
||||
|
||||
from collections import defaultdict, deque
|
||||
from collections.abc import Sequence
|
||||
|
||||
from crewai.events.depends import Depends
|
||||
from crewai.events.types.event_bus_types import ExecutionPlan, Handler
|
||||
|
||||
|
||||
class CircularDependencyError(Exception):
|
||||
"""Exception raised when circular dependencies are detected in event handlers.
|
||||
|
||||
Attributes:
|
||||
handlers: The handlers involved in the circular dependency
|
||||
"""
|
||||
|
||||
def __init__(self, handlers: list[Handler]) -> None:
|
||||
"""Initialize the circular dependency error.
|
||||
|
||||
Args:
|
||||
handlers: The handlers involved in the circular dependency
|
||||
"""
|
||||
handler_names = ", ".join(
|
||||
getattr(h, "__name__", repr(h)) for h in handlers[:5]
|
||||
)
|
||||
message = f"Circular dependency detected in event handlers: {handler_names}"
|
||||
super().__init__(message)
|
||||
self.handlers = handlers
|
||||
|
||||
|
||||
class HandlerGraph:
|
||||
"""Resolves handler dependencies into parallel execution levels.
|
||||
|
||||
Handlers are organized into levels where:
|
||||
- Level 0: Handlers with no dependencies (can run first)
|
||||
- Level N: Handlers that depend on handlers in levels 0...N-1
|
||||
|
||||
Handlers within the same level can execute in parallel.
|
||||
|
||||
Attributes:
|
||||
levels: List of handler sets, where each level can execute in parallel
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
handlers: dict[Handler, list[Depends]],
|
||||
) -> None:
|
||||
"""Initialize the dependency graph.
|
||||
|
||||
Args:
|
||||
handlers: Mapping of handler -> list of `crewai.events.depends.Depends` objects
|
||||
"""
|
||||
self.handlers = handlers
|
||||
self.levels: ExecutionPlan = []
|
||||
self._resolve()
|
||||
|
||||
def _resolve(self) -> None:
|
||||
"""Resolve dependencies into execution levels using topological sort."""
|
||||
dependents: dict[Handler, set[Handler]] = defaultdict(set)
|
||||
in_degree: dict[Handler, int] = {}
|
||||
|
||||
for handler in self.handlers:
|
||||
in_degree[handler] = 0
|
||||
|
||||
for handler, deps in self.handlers.items():
|
||||
in_degree[handler] = len(deps)
|
||||
for dep in deps:
|
||||
dependents[dep.handler].add(handler)
|
||||
|
||||
queue: deque[Handler] = deque(
|
||||
[h for h, deg in in_degree.items() if deg == 0]
|
||||
)
|
||||
|
||||
while queue:
|
||||
current_level: set[Handler] = set()
|
||||
|
||||
for _ in range(len(queue)):
|
||||
handler = queue.popleft()
|
||||
current_level.add(handler)
|
||||
|
||||
for dependent in dependents[handler]:
|
||||
in_degree[dependent] -= 1
|
||||
if in_degree[dependent] == 0:
|
||||
queue.append(dependent)
|
||||
|
||||
if current_level:
|
||||
self.levels.append(current_level)
|
||||
|
||||
remaining = [h for h, deg in in_degree.items() if deg > 0]
|
||||
if remaining:
|
||||
raise CircularDependencyError(remaining)
|
||||
|
||||
def get_execution_plan(self) -> ExecutionPlan:
|
||||
"""Get the ordered execution plan.
|
||||
|
||||
Returns:
|
||||
List of handler sets, where each set represents handlers that can
|
||||
execute in parallel. Sets are ordered such that dependencies are
|
||||
satisfied.
|
||||
"""
|
||||
return self.levels
|
||||
|
||||
|
||||
def build_execution_plan(
|
||||
handlers: Sequence[Handler],
|
||||
dependencies: dict[Handler, list[Depends]],
|
||||
) -> ExecutionPlan:
|
||||
"""Build an execution plan from handlers and their dependencies.
|
||||
|
||||
Args:
|
||||
handlers: All handlers for an event type
|
||||
dependencies: Mapping of handler -> list of dependencies
|
||||
|
||||
Returns:
|
||||
Execution plan as list of levels, where each level is a set of
|
||||
handlers that can execute in parallel
|
||||
|
||||
Raises:
|
||||
CircularDependencyError: If circular dependencies are detected
|
||||
"""
|
||||
handler_dict: dict[Handler, list[Depends]] = {
|
||||
h: dependencies.get(h, []) for h in handlers
|
||||
}
|
||||
|
||||
graph = HandlerGraph(handler_dict)
|
||||
return graph.get_execution_plan()
|
||||
@@ -1,8 +1,9 @@
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from logging import getLogger
|
||||
from threading import Condition, Lock
|
||||
from typing import Any
|
||||
import uuid
|
||||
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
@@ -14,6 +15,7 @@ from crewai.events.listeners.tracing.types import TraceEvent
|
||||
from crewai.events.listeners.tracing.utils import should_auto_collect_first_time_traces
|
||||
from crewai.utilities.constants import CREWAI_BASE_URL
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
@@ -41,6 +43,11 @@ class TraceBatchManager:
|
||||
"""Single responsibility: Manage batches and event buffering"""
|
||||
|
||||
def __init__(self):
|
||||
self._init_lock = Lock()
|
||||
self._pending_events_lock = Lock()
|
||||
self._pending_events_cv = Condition(self._pending_events_lock)
|
||||
self._pending_events_count = 0
|
||||
|
||||
self.is_current_batch_ephemeral: bool = False
|
||||
self.trace_batch_id: str | None = None
|
||||
self.current_batch: TraceBatch | None = None
|
||||
@@ -64,24 +71,28 @@ class TraceBatchManager:
|
||||
execution_metadata: dict[str, Any],
|
||||
use_ephemeral: bool = False,
|
||||
) -> TraceBatch:
|
||||
"""Initialize a new trace batch"""
|
||||
self.current_batch = TraceBatch(
|
||||
user_context=user_context, execution_metadata=execution_metadata
|
||||
)
|
||||
self.event_buffer.clear()
|
||||
self.is_current_batch_ephemeral = use_ephemeral
|
||||
"""Initialize a new trace batch (thread-safe)"""
|
||||
with self._init_lock:
|
||||
if self.current_batch is not None:
|
||||
logger.debug("Batch already initialized, skipping duplicate initialization")
|
||||
return self.current_batch
|
||||
|
||||
self.record_start_time("execution")
|
||||
|
||||
if should_auto_collect_first_time_traces():
|
||||
self.trace_batch_id = self.current_batch.batch_id
|
||||
else:
|
||||
self._initialize_backend_batch(
|
||||
user_context, execution_metadata, use_ephemeral
|
||||
self.current_batch = TraceBatch(
|
||||
user_context=user_context, execution_metadata=execution_metadata
|
||||
)
|
||||
self.backend_initialized = True
|
||||
self.is_current_batch_ephemeral = use_ephemeral
|
||||
|
||||
return self.current_batch
|
||||
self.record_start_time("execution")
|
||||
|
||||
if should_auto_collect_first_time_traces():
|
||||
self.trace_batch_id = self.current_batch.batch_id
|
||||
else:
|
||||
self._initialize_backend_batch(
|
||||
user_context, execution_metadata, use_ephemeral
|
||||
)
|
||||
self.backend_initialized = True
|
||||
|
||||
return self.current_batch
|
||||
|
||||
def _initialize_backend_batch(
|
||||
self,
|
||||
@@ -148,6 +159,38 @@ class TraceBatchManager:
|
||||
f"Error initializing trace batch: {e}. Continuing without tracing."
|
||||
)
|
||||
|
||||
def begin_event_processing(self):
|
||||
"""Mark that an event handler started processing (for synchronization)"""
|
||||
with self._pending_events_lock:
|
||||
self._pending_events_count += 1
|
||||
|
||||
def end_event_processing(self):
|
||||
"""Mark that an event handler finished processing (for synchronization)"""
|
||||
with self._pending_events_cv:
|
||||
self._pending_events_count -= 1
|
||||
if self._pending_events_count == 0:
|
||||
self._pending_events_cv.notify_all()
|
||||
|
||||
def wait_for_pending_events(self, timeout: float = 2.0) -> bool:
|
||||
"""Wait for all pending event handlers to finish processing
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait in seconds (default: 2.0)
|
||||
|
||||
Returns:
|
||||
True if all handlers completed, False if timeout occurred
|
||||
"""
|
||||
with self._pending_events_cv:
|
||||
if self._pending_events_count > 0:
|
||||
logger.debug(f"Waiting for {self._pending_events_count} pending event handlers...")
|
||||
self._pending_events_cv.wait(timeout)
|
||||
if self._pending_events_count > 0:
|
||||
logger.error(
|
||||
f"Timeout waiting for event handlers. {self._pending_events_count} still pending. Events may be incomplete!"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
def add_event(self, trace_event: TraceEvent):
|
||||
"""Add event to buffer"""
|
||||
self.event_buffer.append(trace_event)
|
||||
@@ -180,8 +223,8 @@ class TraceBatchManager:
|
||||
self.event_buffer.clear()
|
||||
return 200
|
||||
|
||||
logger.warning(
|
||||
f"Failed to send events: {response.status_code}. Events will be lost."
|
||||
logger.error(
|
||||
f"Failed to send events: {response.status_code}. Response: {response.text}. Events will be lost."
|
||||
)
|
||||
return 500
|
||||
|
||||
@@ -196,15 +239,33 @@ class TraceBatchManager:
|
||||
if not self.current_batch:
|
||||
return None
|
||||
|
||||
self.current_batch.events = self.event_buffer.copy()
|
||||
if self.event_buffer:
|
||||
all_handlers_completed = self.wait_for_pending_events(timeout=2.0)
|
||||
|
||||
if not all_handlers_completed:
|
||||
logger.error("Event handler timeout - marking batch as failed due to incomplete events")
|
||||
self.plus_api.mark_trace_batch_as_failed(
|
||||
self.trace_batch_id, "Timeout waiting for event handlers - events incomplete"
|
||||
)
|
||||
return None
|
||||
|
||||
sorted_events = sorted(
|
||||
self.event_buffer,
|
||||
key=lambda e: e.timestamp if hasattr(e, 'timestamp') and e.timestamp else ''
|
||||
)
|
||||
|
||||
self.current_batch.events = sorted_events
|
||||
events_sent_count = len(sorted_events)
|
||||
if sorted_events:
|
||||
original_buffer = self.event_buffer
|
||||
self.event_buffer = sorted_events
|
||||
events_sent_to_backend_status = self._send_events_to_backend()
|
||||
self.event_buffer = original_buffer
|
||||
if events_sent_to_backend_status == 500:
|
||||
self.plus_api.mark_trace_batch_as_failed(
|
||||
self.trace_batch_id, "Error sending events to backend"
|
||||
)
|
||||
return None
|
||||
self._finalize_backend_batch()
|
||||
self._finalize_backend_batch(events_sent_count)
|
||||
|
||||
finalized_batch = self.current_batch
|
||||
|
||||
@@ -220,18 +281,20 @@ class TraceBatchManager:
|
||||
|
||||
return finalized_batch
|
||||
|
||||
def _finalize_backend_batch(self):
|
||||
"""Send batch finalization to backend"""
|
||||
def _finalize_backend_batch(self, events_count: int = 0):
|
||||
"""Send batch finalization to backend
|
||||
|
||||
Args:
|
||||
events_count: Number of events that were successfully sent
|
||||
"""
|
||||
if not self.plus_api or not self.trace_batch_id:
|
||||
return
|
||||
|
||||
try:
|
||||
total_events = len(self.current_batch.events) if self.current_batch else 0
|
||||
|
||||
payload = {
|
||||
"status": "completed",
|
||||
"duration_ms": self.calculate_duration("execution"),
|
||||
"final_event_count": total_events,
|
||||
"final_event_count": events_count,
|
||||
}
|
||||
|
||||
response = (
|
||||
|
||||
@@ -170,14 +170,6 @@ class TraceCollectionListener(BaseEventListener):
|
||||
def on_flow_finished(source, event):
|
||||
self._handle_trace_event("flow_finished", source, event)
|
||||
|
||||
if self.batch_manager.batch_owner_type == "flow":
|
||||
if self.first_time_handler.is_first_time:
|
||||
self.first_time_handler.mark_events_collected()
|
||||
self.first_time_handler.handle_execution_completion()
|
||||
else:
|
||||
# Normal flow finalization
|
||||
self.batch_manager.finalize_batch()
|
||||
|
||||
@event_bus.on(FlowPlotEvent)
|
||||
def on_flow_plot(source, event):
|
||||
self._handle_action_event("flow_plot", source, event)
|
||||
@@ -383,10 +375,12 @@ class TraceCollectionListener(BaseEventListener):
|
||||
|
||||
def _handle_trace_event(self, event_type: str, source: Any, event: Any):
|
||||
"""Generic handler for context end events"""
|
||||
|
||||
trace_event = self._create_trace_event(event_type, source, event)
|
||||
|
||||
self.batch_manager.add_event(trace_event)
|
||||
self.batch_manager.begin_event_processing()
|
||||
try:
|
||||
trace_event = self._create_trace_event(event_type, source, event)
|
||||
self.batch_manager.add_event(trace_event)
|
||||
finally:
|
||||
self.batch_manager.end_event_processing()
|
||||
|
||||
def _handle_action_event(self, event_type: str, source: Any, event: Any):
|
||||
"""Generic handler for action events (LLM calls, tool usage)"""
|
||||
@@ -399,18 +393,29 @@ class TraceCollectionListener(BaseEventListener):
|
||||
}
|
||||
self.batch_manager.initialize_batch(user_context, execution_metadata)
|
||||
|
||||
trace_event = self._create_trace_event(event_type, source, event)
|
||||
self.batch_manager.add_event(trace_event)
|
||||
self.batch_manager.begin_event_processing()
|
||||
try:
|
||||
trace_event = self._create_trace_event(event_type, source, event)
|
||||
self.batch_manager.add_event(trace_event)
|
||||
finally:
|
||||
self.batch_manager.end_event_processing()
|
||||
|
||||
def _create_trace_event(
|
||||
self, event_type: str, source: Any, event: Any
|
||||
) -> TraceEvent:
|
||||
"""Create a trace event"""
|
||||
trace_event = TraceEvent(
|
||||
type=event_type,
|
||||
)
|
||||
if hasattr(event, 'timestamp') and event.timestamp:
|
||||
trace_event = TraceEvent(
|
||||
type=event_type,
|
||||
timestamp=event.timestamp.isoformat(),
|
||||
)
|
||||
else:
|
||||
trace_event = TraceEvent(
|
||||
type=event_type,
|
||||
)
|
||||
|
||||
trace_event.event_data = self._build_event_data(event_type, event, source)
|
||||
|
||||
return trace_event
|
||||
|
||||
def _build_event_data(
|
||||
|
||||
@@ -358,7 +358,8 @@ def prompt_user_for_trace_viewing(timeout_seconds: int = 20) -> bool:
|
||||
try:
|
||||
response = input().strip().lower()
|
||||
result[0] = response in ["y", "yes"]
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
except (EOFError, KeyboardInterrupt, OSError, LookupError):
|
||||
# Handle all input-related errors silently
|
||||
result[0] = False
|
||||
|
||||
input_thread = threading.Thread(target=get_input, daemon=True)
|
||||
@@ -371,6 +372,7 @@ def prompt_user_for_trace_viewing(timeout_seconds: int = 20) -> bool:
|
||||
return result[0]
|
||||
|
||||
except Exception:
|
||||
# Suppress any warnings or errors and assume "no"
|
||||
return False
|
||||
|
||||
|
||||
|
||||
14
lib/crewai/src/crewai/events/types/event_bus_types.py
Normal file
14
lib/crewai/src/crewai/events/types/event_bus_types.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""Type definitions for event handlers."""
|
||||
|
||||
from collections.abc import Callable, Coroutine
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
|
||||
SyncHandler: TypeAlias = Callable[[Any, BaseEvent], None]
|
||||
AsyncHandler: TypeAlias = Callable[[Any, BaseEvent], Coroutine[Any, Any, None]]
|
||||
SyncHandlerSet: TypeAlias = frozenset[SyncHandler]
|
||||
AsyncHandlerSet: TypeAlias = frozenset[AsyncHandler]
|
||||
|
||||
Handler: TypeAlias = Callable[[Any, BaseEvent], Any]
|
||||
ExecutionPlan: TypeAlias = list[set[Handler]]
|
||||
59
lib/crewai/src/crewai/events/utils/handlers.py
Normal file
59
lib/crewai/src/crewai/events/utils/handlers.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""Handler utility functions for event processing."""
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
from typing import Any
|
||||
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
from crewai.events.types.event_bus_types import AsyncHandler, SyncHandler
|
||||
|
||||
|
||||
def is_async_handler(
|
||||
handler: Any,
|
||||
) -> TypeIs[AsyncHandler]:
|
||||
"""Type guard to check if handler is an async handler.
|
||||
|
||||
Args:
|
||||
handler: The handler to check
|
||||
|
||||
Returns:
|
||||
True if handler is an async coroutine function
|
||||
"""
|
||||
try:
|
||||
if inspect.iscoroutinefunction(handler) or (
|
||||
callable(handler) and inspect.iscoroutinefunction(handler.__call__)
|
||||
):
|
||||
return True
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
if isinstance(handler, functools.partial) and inspect.iscoroutinefunction(
|
||||
handler.func
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def is_call_handler_safe(
|
||||
handler: SyncHandler,
|
||||
source: Any,
|
||||
event: BaseEvent,
|
||||
) -> Exception | None:
|
||||
"""Safely call a single handler and return any exception.
|
||||
|
||||
Args:
|
||||
handler: The handler function to call
|
||||
source: The object that emitted the event
|
||||
event: The event instance
|
||||
|
||||
Returns:
|
||||
Exception if handler raised one, None otherwise
|
||||
"""
|
||||
try:
|
||||
handler(source, event)
|
||||
return None
|
||||
except Exception as e:
|
||||
return e
|
||||
81
lib/crewai/src/crewai/events/utils/rw_lock.py
Normal file
81
lib/crewai/src/crewai/events/utils/rw_lock.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Read-write lock for thread-safe concurrent access.
|
||||
|
||||
This module provides a reader-writer lock implementation that allows multiple
|
||||
concurrent readers or a single exclusive writer.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from threading import Condition
|
||||
|
||||
|
||||
class RWLock:
|
||||
"""Read-write lock for managing concurrent read and exclusive write access.
|
||||
|
||||
Allows multiple threads to acquire read locks simultaneously, but ensures
|
||||
exclusive access for write operations. Writers are prioritized when waiting.
|
||||
|
||||
Attributes:
|
||||
_cond: Condition variable for coordinating lock access
|
||||
_readers: Count of active readers
|
||||
_writer: Whether a writer currently holds the lock
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the read-write lock."""
|
||||
self._cond = Condition()
|
||||
self._readers = 0
|
||||
self._writer = False
|
||||
|
||||
def r_acquire(self) -> None:
|
||||
"""Acquire a read lock, blocking if a writer holds the lock."""
|
||||
with self._cond:
|
||||
while self._writer:
|
||||
self._cond.wait()
|
||||
self._readers += 1
|
||||
|
||||
def r_release(self) -> None:
|
||||
"""Release a read lock and notify waiting writers if last reader."""
|
||||
with self._cond:
|
||||
self._readers -= 1
|
||||
if self._readers == 0:
|
||||
self._cond.notify_all()
|
||||
|
||||
@contextmanager
|
||||
def r_locked(self) -> Generator[None, None, None]:
|
||||
"""Context manager for acquiring a read lock.
|
||||
|
||||
Yields:
|
||||
None
|
||||
"""
|
||||
try:
|
||||
self.r_acquire()
|
||||
yield
|
||||
finally:
|
||||
self.r_release()
|
||||
|
||||
def w_acquire(self) -> None:
|
||||
"""Acquire a write lock, blocking if any readers or writers are active."""
|
||||
with self._cond:
|
||||
while self._writer or self._readers > 0:
|
||||
self._cond.wait()
|
||||
self._writer = True
|
||||
|
||||
def w_release(self) -> None:
|
||||
"""Release a write lock and notify all waiting threads."""
|
||||
with self._cond:
|
||||
self._writer = False
|
||||
self._cond.notify_all()
|
||||
|
||||
@contextmanager
|
||||
def w_locked(self) -> Generator[None, None, None]:
|
||||
"""Context manager for acquiring a write lock.
|
||||
|
||||
Yields:
|
||||
None
|
||||
"""
|
||||
try:
|
||||
self.w_acquire()
|
||||
yield
|
||||
finally:
|
||||
self.w_release()
|
||||
@@ -52,19 +52,14 @@ class AgentEvaluator:
|
||||
self.console_formatter = ConsoleFormatter()
|
||||
self.display_formatter = EvaluationDisplayFormatter()
|
||||
|
||||
self._thread_local: threading.local = threading.local()
|
||||
self._execution_state = ExecutionState()
|
||||
self._state_lock = threading.Lock()
|
||||
|
||||
for agent in self.agents:
|
||||
self._execution_state.agent_evaluators[str(agent.id)] = self.evaluators
|
||||
|
||||
self._subscribe_to_events()
|
||||
|
||||
@property
|
||||
def _execution_state(self) -> ExecutionState:
|
||||
if not hasattr(self._thread_local, "execution_state"):
|
||||
self._thread_local.execution_state = ExecutionState()
|
||||
return self._thread_local.execution_state
|
||||
|
||||
def _subscribe_to_events(self) -> None:
|
||||
from typing import cast
|
||||
|
||||
@@ -112,21 +107,22 @@ class AgentEvaluator:
|
||||
state=state,
|
||||
)
|
||||
|
||||
current_iteration = self._execution_state.iteration
|
||||
if current_iteration not in self._execution_state.iterations_results:
|
||||
self._execution_state.iterations_results[current_iteration] = {}
|
||||
with self._state_lock:
|
||||
current_iteration = self._execution_state.iteration
|
||||
if current_iteration not in self._execution_state.iterations_results:
|
||||
self._execution_state.iterations_results[current_iteration] = {}
|
||||
|
||||
if (
|
||||
agent.role
|
||||
not in self._execution_state.iterations_results[current_iteration]
|
||||
):
|
||||
self._execution_state.iterations_results[current_iteration][
|
||||
agent.role
|
||||
] = []
|
||||
|
||||
if (
|
||||
agent.role
|
||||
not in self._execution_state.iterations_results[current_iteration]
|
||||
):
|
||||
self._execution_state.iterations_results[current_iteration][
|
||||
agent.role
|
||||
] = []
|
||||
|
||||
self._execution_state.iterations_results[current_iteration][
|
||||
agent.role
|
||||
].append(result)
|
||||
].append(result)
|
||||
|
||||
def _handle_lite_agent_completed(
|
||||
self, source: object, event: LiteAgentExecutionCompletedEvent
|
||||
@@ -164,22 +160,23 @@ class AgentEvaluator:
|
||||
state=state,
|
||||
)
|
||||
|
||||
current_iteration = self._execution_state.iteration
|
||||
if current_iteration not in self._execution_state.iterations_results:
|
||||
self._execution_state.iterations_results[current_iteration] = {}
|
||||
with self._state_lock:
|
||||
current_iteration = self._execution_state.iteration
|
||||
if current_iteration not in self._execution_state.iterations_results:
|
||||
self._execution_state.iterations_results[current_iteration] = {}
|
||||
|
||||
agent_role = target_agent.role
|
||||
if (
|
||||
agent_role
|
||||
not in self._execution_state.iterations_results[current_iteration]
|
||||
):
|
||||
self._execution_state.iterations_results[current_iteration][
|
||||
agent_role
|
||||
] = []
|
||||
|
||||
agent_role = target_agent.role
|
||||
if (
|
||||
agent_role
|
||||
not in self._execution_state.iterations_results[current_iteration]
|
||||
):
|
||||
self._execution_state.iterations_results[current_iteration][
|
||||
agent_role
|
||||
] = []
|
||||
|
||||
self._execution_state.iterations_results[current_iteration][
|
||||
agent_role
|
||||
].append(result)
|
||||
].append(result)
|
||||
|
||||
def set_iteration(self, iteration: int) -> None:
|
||||
self._execution_state.iteration = iteration
|
||||
|
||||
@@ -3,6 +3,7 @@ import copy
|
||||
import inspect
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import Future
|
||||
from typing import Any, ClassVar, Generic, TypeVar, cast
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -463,6 +464,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
self._completed_methods: set[str] = set() # Track completed methods for reload
|
||||
self._persistence: FlowPersistence | None = persistence
|
||||
self._is_execution_resuming: bool = False
|
||||
self._event_futures: list[Future[None]] = []
|
||||
|
||||
# Initialize state with initial values
|
||||
self._state = self._create_initial_state()
|
||||
@@ -855,7 +857,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
self._initialize_state(filtered_inputs)
|
||||
|
||||
# Emit FlowStartedEvent and log the start of the flow.
|
||||
crewai_event_bus.emit(
|
||||
future = crewai_event_bus.emit(
|
||||
self,
|
||||
FlowStartedEvent(
|
||||
type="flow_started",
|
||||
@@ -863,6 +865,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
inputs=inputs,
|
||||
),
|
||||
)
|
||||
if future:
|
||||
self._event_futures.append(future)
|
||||
self._log_flow_event(
|
||||
f"Flow started with ID: {self.flow_id}", color="bold_magenta"
|
||||
)
|
||||
@@ -881,7 +885,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
|
||||
final_output = self._method_outputs[-1] if self._method_outputs else None
|
||||
|
||||
crewai_event_bus.emit(
|
||||
future = crewai_event_bus.emit(
|
||||
self,
|
||||
FlowFinishedEvent(
|
||||
type="flow_finished",
|
||||
@@ -889,6 +893,25 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
result=final_output,
|
||||
),
|
||||
)
|
||||
if future:
|
||||
self._event_futures.append(future)
|
||||
|
||||
if self._event_futures:
|
||||
await asyncio.gather(*[asyncio.wrap_future(f) for f in self._event_futures])
|
||||
self._event_futures.clear()
|
||||
|
||||
if (
|
||||
is_tracing_enabled()
|
||||
or self.tracing
|
||||
or should_auto_collect_first_time_traces()
|
||||
):
|
||||
trace_listener = TraceCollectionListener()
|
||||
if trace_listener.batch_manager.batch_owner_type == "flow":
|
||||
if trace_listener.first_time_handler.is_first_time:
|
||||
trace_listener.first_time_handler.mark_events_collected()
|
||||
trace_listener.first_time_handler.handle_execution_completion()
|
||||
else:
|
||||
trace_listener.batch_manager.finalize_batch()
|
||||
|
||||
return final_output
|
||||
finally:
|
||||
@@ -971,7 +994,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
dumped_params = {f"_{i}": arg for i, arg in enumerate(args)} | (
|
||||
kwargs or {}
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
future = crewai_event_bus.emit(
|
||||
self,
|
||||
MethodExecutionStartedEvent(
|
||||
type="method_execution_started",
|
||||
@@ -981,6 +1004,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
state=self._copy_state(),
|
||||
),
|
||||
)
|
||||
if future:
|
||||
self._event_futures.append(future)
|
||||
|
||||
result = (
|
||||
await method(*args, **kwargs)
|
||||
@@ -994,7 +1019,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
)
|
||||
|
||||
self._completed_methods.add(method_name)
|
||||
crewai_event_bus.emit(
|
||||
future = crewai_event_bus.emit(
|
||||
self,
|
||||
MethodExecutionFinishedEvent(
|
||||
type="method_execution_finished",
|
||||
@@ -1004,10 +1029,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
result=result,
|
||||
),
|
||||
)
|
||||
if future:
|
||||
self._event_futures.append(future)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
future = crewai_event_bus.emit(
|
||||
self,
|
||||
MethodExecutionFailedEvent(
|
||||
type="method_execution_failed",
|
||||
@@ -1016,6 +1043,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
error=e,
|
||||
),
|
||||
)
|
||||
if future:
|
||||
self._event_futures.append(future)
|
||||
raise e
|
||||
|
||||
async def _execute_listeners(self, trigger_method: str, result: Any) -> None:
|
||||
|
||||
@@ -367,6 +367,14 @@ class LLM(BaseLLM):
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
elif provider == "bedrock":
|
||||
try:
|
||||
from crewai.llms.providers.bedrock.completion import BedrockCompletion
|
||||
|
||||
return BedrockCompletion
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
@@ -40,6 +39,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
top_p: float | None = None,
|
||||
stop_sequences: list[str] | None = None,
|
||||
stream: bool = False,
|
||||
client_params: dict[str, Any] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize Anthropic chat completion client.
|
||||
@@ -55,19 +55,20 @@ class AnthropicCompletion(BaseLLM):
|
||||
top_p: Nucleus sampling parameter
|
||||
stop_sequences: Stop sequences (Anthropic uses stop_sequences, not stop)
|
||||
stream: Enable streaming responses
|
||||
client_params: Additional parameters for the Anthropic client
|
||||
**kwargs: Additional parameters
|
||||
"""
|
||||
super().__init__(
|
||||
model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
|
||||
)
|
||||
|
||||
# Initialize Anthropic client
|
||||
self.client = Anthropic(
|
||||
api_key=api_key or os.getenv("ANTHROPIC_API_KEY"),
|
||||
base_url=base_url,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
# Client params
|
||||
self.client_params = client_params
|
||||
self.base_url = base_url
|
||||
self.timeout = timeout
|
||||
self.max_retries = max_retries
|
||||
|
||||
self.client = Anthropic(**self._get_client_params())
|
||||
|
||||
# Store completion parameters
|
||||
self.max_tokens = max_tokens
|
||||
@@ -79,6 +80,26 @@ class AnthropicCompletion(BaseLLM):
|
||||
self.is_claude_3 = "claude-3" in model.lower()
|
||||
self.supports_tools = self.is_claude_3 # Claude 3+ supports tool use
|
||||
|
||||
def _get_client_params(self) -> dict[str, Any]:
|
||||
"""Get client parameters."""
|
||||
|
||||
if self.api_key is None:
|
||||
self.api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
if self.api_key is None:
|
||||
raise ValueError("ANTHROPIC_API_KEY is required")
|
||||
|
||||
client_params = {
|
||||
"api_key": self.api_key,
|
||||
"base_url": self.base_url,
|
||||
"timeout": self.timeout,
|
||||
"max_retries": self.max_retries,
|
||||
}
|
||||
|
||||
if self.client_params:
|
||||
client_params.update(self.client_params)
|
||||
|
||||
return client_params
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: str | list[dict[str, str]],
|
||||
@@ -183,12 +204,20 @@ class AnthropicCompletion(BaseLLM):
|
||||
|
||||
def _convert_tools_for_interference(self, tools: list[dict]) -> list[dict]:
|
||||
"""Convert CrewAI tool format to Anthropic tool use format."""
|
||||
from crewai.llms.providers.utils.common import safe_tool_conversion
|
||||
|
||||
anthropic_tools = []
|
||||
|
||||
for tool in tools:
|
||||
name, description, parameters = safe_tool_conversion(tool, "Anthropic")
|
||||
if "input_schema" in tool and "name" in tool and "description" in tool:
|
||||
anthropic_tools.append(tool)
|
||||
continue
|
||||
|
||||
try:
|
||||
from crewai.llms.providers.utils.common import safe_tool_conversion
|
||||
|
||||
name, description, parameters = safe_tool_conversion(tool, "Anthropic")
|
||||
except (ImportError, KeyError, ValueError) as e:
|
||||
logging.error(f"Error converting tool to Anthropic format: {e}")
|
||||
raise e
|
||||
|
||||
anthropic_tool = {
|
||||
"name": name,
|
||||
@@ -196,7 +225,13 @@ class AnthropicCompletion(BaseLLM):
|
||||
}
|
||||
|
||||
if parameters and isinstance(parameters, dict):
|
||||
anthropic_tool["input_schema"] = parameters # type: ignore
|
||||
anthropic_tool["input_schema"] = parameters
|
||||
else:
|
||||
anthropic_tool["input_schema"] = {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
anthropic_tools.append(anthropic_tool)
|
||||
|
||||
@@ -229,13 +264,11 @@ class AnthropicCompletion(BaseLLM):
|
||||
content = message.get("content", "")
|
||||
|
||||
if role == "system":
|
||||
# Extract system message - Anthropic handles it separately
|
||||
if system_message:
|
||||
system_message += f"\n\n{content}"
|
||||
else:
|
||||
system_message = content
|
||||
else:
|
||||
# Add user/assistant messages - ensure both role and content are str, not None
|
||||
role_str = role if role is not None else "user"
|
||||
content_str = content if content is not None else ""
|
||||
formatted_messages.append({"role": role_str, "content": content_str})
|
||||
@@ -270,22 +303,22 @@ class AnthropicCompletion(BaseLLM):
|
||||
usage = self._extract_anthropic_token_usage(response)
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
# Check if Claude wants to use tools
|
||||
if response.content and available_functions:
|
||||
for content_block in response.content:
|
||||
if isinstance(content_block, ToolUseBlock):
|
||||
function_name = content_block.name
|
||||
function_args = content_block.input
|
||||
tool_uses = [
|
||||
block for block in response.content if isinstance(block, ToolUseBlock)
|
||||
]
|
||||
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args, # type: ignore
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
if tool_uses:
|
||||
# Handle tool use conversation flow
|
||||
return self._handle_tool_use_conversation(
|
||||
response,
|
||||
tool_uses,
|
||||
params,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
)
|
||||
|
||||
# Extract text content
|
||||
content = ""
|
||||
@@ -318,12 +351,14 @@ class AnthropicCompletion(BaseLLM):
|
||||
) -> str:
|
||||
"""Handle streaming message completion."""
|
||||
full_response = ""
|
||||
tool_uses = {}
|
||||
|
||||
# Remove 'stream' parameter as messages.stream() doesn't accept it
|
||||
# (the SDK sets it internally)
|
||||
stream_params = {k: v for k, v in params.items() if k != "stream"}
|
||||
|
||||
# Make streaming API call
|
||||
with self.client.messages.stream(**params) as stream:
|
||||
with self.client.messages.stream(**stream_params) as stream:
|
||||
for event in stream:
|
||||
# Handle content delta events
|
||||
if hasattr(event, "delta") and hasattr(event.delta, "text"):
|
||||
text_delta = event.delta.text
|
||||
full_response += text_delta
|
||||
@@ -333,44 +368,29 @@ class AnthropicCompletion(BaseLLM):
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
# Handle tool use events
|
||||
elif hasattr(event, "delta") and hasattr(event.delta, "partial_json"):
|
||||
# Tool use streaming - accumulate JSON
|
||||
tool_id = getattr(event, "index", "default")
|
||||
if tool_id not in tool_uses:
|
||||
tool_uses[tool_id] = {
|
||||
"name": "",
|
||||
"input": "",
|
||||
}
|
||||
final_message: Message = stream.get_final_message()
|
||||
|
||||
if hasattr(event.delta, "name"):
|
||||
tool_uses[tool_id]["name"] = event.delta.name
|
||||
if hasattr(event.delta, "partial_json"):
|
||||
tool_uses[tool_id]["input"] += event.delta.partial_json
|
||||
usage = self._extract_anthropic_token_usage(final_message)
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
# Handle completed tool uses
|
||||
if tool_uses and available_functions:
|
||||
for tool_data in tool_uses.values():
|
||||
function_name = tool_data["name"]
|
||||
if final_message.content and available_functions:
|
||||
tool_uses = [
|
||||
block
|
||||
for block in final_message.content
|
||||
if isinstance(block, ToolUseBlock)
|
||||
]
|
||||
|
||||
try:
|
||||
function_args = json.loads(tool_data["input"])
|
||||
except json.JSONDecodeError as e:
|
||||
logging.error(f"Failed to parse streamed tool arguments: {e}")
|
||||
continue
|
||||
|
||||
# Execute tool
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
if tool_uses:
|
||||
# Handle tool use conversation flow
|
||||
return self._handle_tool_use_conversation(
|
||||
final_message,
|
||||
tool_uses,
|
||||
params,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Apply stop words to full response
|
||||
full_response = self._apply_stop_words(full_response)
|
||||
|
||||
@@ -385,6 +405,113 @@ class AnthropicCompletion(BaseLLM):
|
||||
|
||||
return full_response
|
||||
|
||||
def _handle_tool_use_conversation(
|
||||
self,
|
||||
initial_response: Message,
|
||||
tool_uses: list[ToolUseBlock],
|
||||
params: dict[str, Any],
|
||||
available_functions: dict[str, Any],
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str:
|
||||
"""Handle the complete tool use conversation flow.
|
||||
|
||||
This implements the proper Anthropic tool use pattern:
|
||||
1. Claude requests tool use
|
||||
2. We execute the tools
|
||||
3. We send tool results back to Claude
|
||||
4. Claude processes results and generates final response
|
||||
"""
|
||||
# Execute all requested tools and collect results
|
||||
tool_results = []
|
||||
|
||||
for tool_use in tool_uses:
|
||||
function_name = tool_use.name
|
||||
function_args = tool_use.input
|
||||
|
||||
# Execute the tool
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args, # type: ignore
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
# Create tool result in Anthropic format
|
||||
tool_result = {
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_use.id,
|
||||
"content": str(result)
|
||||
if result is not None
|
||||
else "Tool execution completed",
|
||||
}
|
||||
tool_results.append(tool_result)
|
||||
|
||||
# Prepare follow-up conversation with tool results
|
||||
follow_up_params = params.copy()
|
||||
|
||||
# Add Claude's tool use response to conversation
|
||||
assistant_message = {"role": "assistant", "content": initial_response.content}
|
||||
|
||||
# Add user message with tool results
|
||||
user_message = {"role": "user", "content": tool_results}
|
||||
|
||||
# Update messages for follow-up call
|
||||
follow_up_params["messages"] = params["messages"] + [
|
||||
assistant_message,
|
||||
user_message,
|
||||
]
|
||||
|
||||
try:
|
||||
# Send tool results back to Claude for final response
|
||||
final_response: Message = self.client.messages.create(**follow_up_params)
|
||||
|
||||
# Track token usage for follow-up call
|
||||
follow_up_usage = self._extract_anthropic_token_usage(final_response)
|
||||
self._track_token_usage_internal(follow_up_usage)
|
||||
|
||||
# Extract final text content
|
||||
final_content = ""
|
||||
if final_response.content:
|
||||
for content_block in final_response.content:
|
||||
if hasattr(content_block, "text"):
|
||||
final_content += content_block.text
|
||||
|
||||
final_content = self._apply_stop_words(final_content)
|
||||
|
||||
# Emit completion event for the final response
|
||||
self._emit_call_completed_event(
|
||||
response=final_content,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=follow_up_params["messages"],
|
||||
)
|
||||
|
||||
# Log combined token usage
|
||||
total_usage = {
|
||||
"input_tokens": follow_up_usage.get("input_tokens", 0),
|
||||
"output_tokens": follow_up_usage.get("output_tokens", 0),
|
||||
"total_tokens": follow_up_usage.get("total_tokens", 0),
|
||||
}
|
||||
|
||||
if total_usage.get("total_tokens", 0) > 0:
|
||||
logging.info(f"Anthropic API tool conversation usage: {total_usage}")
|
||||
|
||||
return final_content
|
||||
|
||||
except Exception as e:
|
||||
if is_context_length_exceeded(e):
|
||||
logging.error(f"Context window exceeded in tool follow-up: {e}")
|
||||
raise LLMContextLengthExceededError(str(e)) from e
|
||||
|
||||
logging.error(f"Tool follow-up conversation failed: {e}")
|
||||
# Fallback: return the first tool result if follow-up fails
|
||||
if tool_results:
|
||||
return tool_results[0]["content"]
|
||||
raise e
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Check if the model supports function calling."""
|
||||
return self.supports_tools
|
||||
|
||||
553
lib/crewai/src/crewai/llms/providers/bedrock/completion.py
Normal file
553
lib/crewai/src/crewai/llms/providers/bedrock/completion.py
Normal file
@@ -0,0 +1,553 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededError,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
from boto3.session import Session
|
||||
from botocore.config import Config
|
||||
from botocore.exceptions import BotoCoreError, ClientError
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"AWS Bedrock native provider not available, to install: `uv add boto3`"
|
||||
) from None
|
||||
|
||||
|
||||
class BedrockCompletion(BaseLLM):
|
||||
"""AWS Bedrock native completion implementation using the Converse API.
|
||||
|
||||
This class provides direct integration with AWS Bedrock using the modern
|
||||
Converse API, which provides a unified interface across all Bedrock models.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
aws_access_key_id: str | None = None,
|
||||
aws_secret_access_key: str | None = None,
|
||||
aws_session_token: str | None = None,
|
||||
region_name: str = "us-east-1",
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
top_p: float | None = None,
|
||||
top_k: int | None = None,
|
||||
stop_sequences: Sequence[str] | None = None,
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize AWS Bedrock completion client."""
|
||||
# Extract provider from kwargs to avoid duplicate argument
|
||||
kwargs.pop("provider", None)
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
stop=stop_sequences or [],
|
||||
provider="bedrock",
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Initialize Bedrock client with proper configuration
|
||||
session = Session(
|
||||
aws_access_key_id=aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"),
|
||||
aws_secret_access_key=aws_secret_access_key
|
||||
or os.getenv("AWS_SECRET_ACCESS_KEY"),
|
||||
aws_session_token=aws_session_token or os.getenv("AWS_SESSION_TOKEN"),
|
||||
region_name=region_name,
|
||||
)
|
||||
|
||||
# Configure client with timeouts and retries following AWS best practices
|
||||
config = Config(
|
||||
connect_timeout=60,
|
||||
read_timeout=300,
|
||||
retries={
|
||||
"max_attempts": 3,
|
||||
"mode": "adaptive",
|
||||
},
|
||||
tcp_keepalive=True,
|
||||
)
|
||||
|
||||
self.client = session.client("bedrock-runtime", config=config)
|
||||
self.region_name = region_name
|
||||
|
||||
# Store completion parameters
|
||||
self.max_tokens = max_tokens
|
||||
self.top_p = top_p
|
||||
self.top_k = top_k
|
||||
self.stream = stream
|
||||
self.stop_sequences = stop_sequences or []
|
||||
|
||||
# Model-specific settings
|
||||
self.is_claude_model = "claude" in model.lower()
|
||||
self.supports_tools = True # Converse API supports tools for most models
|
||||
self.supports_streaming = True
|
||||
|
||||
# Handle inference profiles for newer models
|
||||
self.model_id = model
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: str | list[dict[str, str]],
|
||||
tools: Sequence[Mapping[str, Any]] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str | Any:
|
||||
"""Call AWS Bedrock Converse API."""
|
||||
try:
|
||||
# Emit call started event
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
# Format messages for Converse API
|
||||
formatted_messages, system_message = self._format_messages_for_converse(
|
||||
messages
|
||||
)
|
||||
|
||||
# Prepare tool configuration
|
||||
tool_config = None
|
||||
if tools:
|
||||
tool_config = {"tools": self._format_tools_for_converse(tools)}
|
||||
|
||||
# Prepare request body
|
||||
body = {
|
||||
"inferenceConfig": self._get_inference_config(),
|
||||
}
|
||||
|
||||
# Add system message if present
|
||||
if system_message:
|
||||
body["system"] = [{"text": system_message}]
|
||||
|
||||
# Add tool config if present
|
||||
if tool_config:
|
||||
body["toolConfig"] = tool_config
|
||||
|
||||
if self.stream:
|
||||
return self._handle_streaming_converse(
|
||||
formatted_messages, body, available_functions, from_task, from_agent
|
||||
)
|
||||
|
||||
return self._handle_converse(
|
||||
formatted_messages, body, available_functions, from_task, from_agent
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
if is_context_length_exceeded(e):
|
||||
logging.error(f"Context window exceeded: {e}")
|
||||
raise LLMContextLengthExceededError(str(e)) from e
|
||||
|
||||
error_msg = f"AWS Bedrock API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
|
||||
def _handle_converse(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
body: dict[str, Any],
|
||||
available_functions: Mapping[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str:
|
||||
"""Handle non-streaming converse API call following AWS best practices."""
|
||||
try:
|
||||
# Validate messages format before API call
|
||||
if not messages:
|
||||
raise ValueError("Messages cannot be empty")
|
||||
|
||||
# Ensure we have valid message structure
|
||||
for i, msg in enumerate(messages):
|
||||
if (
|
||||
not isinstance(msg, dict)
|
||||
or "role" not in msg
|
||||
or "content" not in msg
|
||||
):
|
||||
raise ValueError(f"Invalid message format at index {i}")
|
||||
|
||||
# Call Bedrock Converse API with proper error handling
|
||||
response = self.client.converse(
|
||||
modelId=self.model_id, messages=messages, **body
|
||||
)
|
||||
|
||||
# Track token usage according to AWS response format
|
||||
if "usage" in response:
|
||||
self._track_token_usage_internal(response["usage"])
|
||||
|
||||
# Extract content following AWS response structure
|
||||
output = response.get("output", {})
|
||||
message = output.get("message", {})
|
||||
content = message.get("content", [])
|
||||
|
||||
if not content:
|
||||
logging.warning("No content in Bedrock response")
|
||||
return (
|
||||
"I apologize, but I received an empty response. Please try again."
|
||||
)
|
||||
|
||||
# Extract text content from response
|
||||
text_content = ""
|
||||
for content_block in content:
|
||||
# Handle different content block types as per AWS documentation
|
||||
if "text" in content_block:
|
||||
text_content += content_block["text"]
|
||||
elif content_block.get("type") == "toolUse" and available_functions:
|
||||
# Handle tool use according to AWS format
|
||||
tool_use = content_block["toolUse"]
|
||||
function_name = tool_use.get("name")
|
||||
function_args = tool_use.get("input", {})
|
||||
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Apply stop sequences if configured
|
||||
text_content = self._apply_stop_words(text_content)
|
||||
|
||||
# Validate final response
|
||||
if not text_content or text_content.strip() == "":
|
||||
logging.warning("Extracted empty text content from Bedrock response")
|
||||
text_content = "I apologize, but I couldn't generate a proper response. Please try again."
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=text_content,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
return text_content
|
||||
|
||||
except ClientError as e:
|
||||
# Handle all AWS ClientError exceptions as per documentation
|
||||
error_code = e.response.get("Error", {}).get("Code", "Unknown")
|
||||
error_msg = e.response.get("Error", {}).get("Message", str(e))
|
||||
|
||||
# Log the specific error for debugging
|
||||
logging.error(f"AWS Bedrock ClientError ({error_code}): {error_msg}")
|
||||
|
||||
# Handle specific error codes as documented
|
||||
if error_code == "ValidationException":
|
||||
# This is the error we're seeing with Cohere
|
||||
if "last turn" in error_msg and "user message" in error_msg:
|
||||
raise ValueError(
|
||||
f"Conversation format error: {error_msg}. Check message alternation."
|
||||
) from e
|
||||
raise ValueError(f"Request validation failed: {error_msg}") from e
|
||||
if error_code == "AccessDeniedException":
|
||||
raise PermissionError(
|
||||
f"Access denied to model {self.model_id}: {error_msg}"
|
||||
) from e
|
||||
if error_code == "ResourceNotFoundException":
|
||||
raise ValueError(f"Model {self.model_id} not found: {error_msg}") from e
|
||||
if error_code == "ThrottlingException":
|
||||
raise RuntimeError(
|
||||
f"API throttled, please retry later: {error_msg}"
|
||||
) from e
|
||||
if error_code == "ModelTimeoutException":
|
||||
raise TimeoutError(f"Model request timed out: {error_msg}") from e
|
||||
if error_code == "ServiceQuotaExceededException":
|
||||
raise RuntimeError(f"Service quota exceeded: {error_msg}") from e
|
||||
if error_code == "ModelNotReadyException":
|
||||
raise RuntimeError(
|
||||
f"Model {self.model_id} not ready: {error_msg}"
|
||||
) from e
|
||||
if error_code == "ModelErrorException":
|
||||
raise RuntimeError(f"Model error: {error_msg}") from e
|
||||
if error_code == "InternalServerException":
|
||||
raise RuntimeError(f"Internal server error: {error_msg}") from e
|
||||
if error_code == "ServiceUnavailableException":
|
||||
raise RuntimeError(f"Service unavailable: {error_msg}") from e
|
||||
|
||||
raise RuntimeError(f"Bedrock API error ({error_code}): {error_msg}") from e
|
||||
|
||||
except BotoCoreError as e:
|
||||
error_msg = f"Bedrock connection error: {e}"
|
||||
logging.error(error_msg)
|
||||
raise ConnectionError(error_msg) from e
|
||||
except Exception as e:
|
||||
# Catch any other unexpected errors
|
||||
error_msg = f"Unexpected error in Bedrock converse call: {e}"
|
||||
logging.error(error_msg)
|
||||
raise RuntimeError(error_msg) from e
|
||||
|
||||
def _handle_streaming_converse(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
body: dict[str, Any],
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str:
|
||||
"""Handle streaming converse API call."""
|
||||
full_response = ""
|
||||
|
||||
try:
|
||||
response = self.client.converse_stream(
|
||||
modelId=self.model_id, messages=messages, **body
|
||||
)
|
||||
|
||||
stream = response.get("stream")
|
||||
if stream:
|
||||
for event in stream:
|
||||
if "contentBlockDelta" in event:
|
||||
delta = event["contentBlockDelta"]["delta"]
|
||||
if "text" in delta:
|
||||
text_chunk = delta["text"]
|
||||
logging.debug(f"Streaming text chunk: {text_chunk[:50]}...")
|
||||
full_response += text_chunk
|
||||
self._emit_stream_chunk_event(
|
||||
chunk=text_chunk,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
elif "messageStop" in event:
|
||||
# Handle end of message
|
||||
break
|
||||
|
||||
except ClientError as e:
|
||||
error_msg = self._handle_client_error(e)
|
||||
raise RuntimeError(error_msg) from e
|
||||
except BotoCoreError as e:
|
||||
error_msg = f"Bedrock streaming connection error: {e}"
|
||||
logging.error(error_msg)
|
||||
raise ConnectionError(error_msg) from e
|
||||
|
||||
# Apply stop words to full response
|
||||
full_response = self._apply_stop_words(full_response)
|
||||
|
||||
# Ensure we don't return empty content
|
||||
if not full_response or full_response.strip() == "":
|
||||
logging.warning("Bedrock streaming returned empty content, using fallback")
|
||||
full_response = (
|
||||
"I apologize, but I couldn't generate a response. Please try again."
|
||||
)
|
||||
|
||||
# Emit completion event
|
||||
self._emit_call_completed_event(
|
||||
response=full_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
return full_response
|
||||
|
||||
def _format_messages_for_converse(
|
||||
self, messages: str | list[dict[str, str]]
|
||||
) -> tuple[list[dict[str, Any]], str | None]:
|
||||
"""Format messages for Converse API following AWS documentation."""
|
||||
# Use base class formatting first
|
||||
formatted_messages = self._format_messages(messages)
|
||||
|
||||
converse_messages = []
|
||||
system_message = None
|
||||
|
||||
for message in formatted_messages:
|
||||
role = message.get("role")
|
||||
content = message.get("content", "")
|
||||
|
||||
if role == "system":
|
||||
# Extract system message - Converse API handles it separately
|
||||
if system_message:
|
||||
system_message += f"\n\n{content}"
|
||||
else:
|
||||
system_message = content
|
||||
else:
|
||||
# Convert to Converse API format with proper content structure
|
||||
converse_messages.append({"role": role, "content": [{"text": content}]})
|
||||
|
||||
# CRITICAL: Handle model-specific conversation requirements
|
||||
# Cohere and some other models require conversation to end with user message
|
||||
if converse_messages:
|
||||
last_message = converse_messages[-1]
|
||||
if last_message["role"] == "assistant":
|
||||
# For Cohere models, add a continuation user message
|
||||
if "cohere" in self.model.lower():
|
||||
converse_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"text": "Please continue and provide your final answer."
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
# For other models that might have similar requirements
|
||||
elif any(
|
||||
model_family in self.model.lower()
|
||||
for model_family in ["command", "coral"]
|
||||
):
|
||||
converse_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"text": "Continue your response."}],
|
||||
}
|
||||
)
|
||||
|
||||
# Ensure first message is from user (required by Converse API)
|
||||
if not converse_messages:
|
||||
converse_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"text": "Hello, please help me with my request."}],
|
||||
}
|
||||
)
|
||||
elif converse_messages[0]["role"] != "user":
|
||||
converse_messages.insert(
|
||||
0,
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"text": "Hello, please help me with my request."}],
|
||||
},
|
||||
)
|
||||
|
||||
return converse_messages, system_message
|
||||
|
||||
def _format_tools_for_converse(self, tools: list[dict]) -> list[dict]:
|
||||
"""Convert CrewAI tools to Converse API format following AWS specification."""
|
||||
from crewai.llms.providers.utils.common import safe_tool_conversion
|
||||
|
||||
converse_tools = []
|
||||
|
||||
for tool in tools:
|
||||
try:
|
||||
name, description, parameters = safe_tool_conversion(tool, "Bedrock")
|
||||
|
||||
converse_tool = {
|
||||
"toolSpec": {
|
||||
"name": name,
|
||||
"description": description,
|
||||
}
|
||||
}
|
||||
|
||||
if parameters and isinstance(parameters, dict):
|
||||
converse_tool["toolSpec"]["inputSchema"] = {"json": parameters}
|
||||
|
||||
converse_tools.append(converse_tool)
|
||||
|
||||
except Exception as e: # noqa: PERF203
|
||||
logging.warning(
|
||||
f"Failed to convert tool {tool.get('name', 'unknown')}: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
return converse_tools
|
||||
|
||||
def _get_inference_config(self) -> dict[str, Any]:
|
||||
"""Get inference configuration following AWS Converse API specification."""
|
||||
config = {}
|
||||
|
||||
if self.max_tokens:
|
||||
config["maxTokens"] = self.max_tokens
|
||||
|
||||
if self.temperature is not None:
|
||||
config["temperature"] = float(self.temperature)
|
||||
if self.top_p is not None:
|
||||
config["topP"] = float(self.top_p)
|
||||
if self.stop_sequences:
|
||||
config["stopSequences"] = self.stop_sequences
|
||||
|
||||
if self.is_claude_model and self.top_k is not None:
|
||||
# top_k is supported by Claude models
|
||||
config["topK"] = int(self.top_k)
|
||||
|
||||
return config
|
||||
|
||||
def _handle_client_error(self, e: ClientError) -> str:
|
||||
"""Handle AWS ClientError with specific error codes and return error message."""
|
||||
error_code = e.response.get("Error", {}).get("Code", "Unknown")
|
||||
error_msg = e.response.get("Error", {}).get("Message", str(e))
|
||||
|
||||
error_mapping = {
|
||||
"AccessDeniedException": f"Access denied to model {self.model_id}: {error_msg}",
|
||||
"ResourceNotFoundException": f"Model {self.model_id} not found: {error_msg}",
|
||||
"ThrottlingException": f"API throttled, please retry later: {error_msg}",
|
||||
"ValidationException": f"Invalid request: {error_msg}",
|
||||
"ModelTimeoutException": f"Model request timed out: {error_msg}",
|
||||
"ServiceQuotaExceededException": f"Service quota exceeded: {error_msg}",
|
||||
"ModelNotReadyException": f"Model {self.model_id} not ready: {error_msg}",
|
||||
"ModelErrorException": f"Model error: {error_msg}",
|
||||
}
|
||||
|
||||
full_error_msg = error_mapping.get(
|
||||
error_code, f"Bedrock API error: {error_msg}"
|
||||
)
|
||||
logging.error(f"Bedrock client error ({error_code}): {full_error_msg}")
|
||||
|
||||
return full_error_msg
|
||||
|
||||
def _track_token_usage_internal(self, usage: dict[str, Any]) -> None:
|
||||
"""Track token usage from Bedrock response."""
|
||||
input_tokens = usage.get("inputTokens", 0)
|
||||
output_tokens = usage.get("outputTokens", 0)
|
||||
total_tokens = usage.get("totalTokens", input_tokens + output_tokens)
|
||||
|
||||
self._token_usage["prompt_tokens"] += input_tokens
|
||||
self._token_usage["completion_tokens"] += output_tokens
|
||||
self._token_usage["total_tokens"] += total_tokens
|
||||
self._token_usage["successful_requests"] += 1
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Check if the model supports function calling."""
|
||||
return self.supports_tools
|
||||
|
||||
def supports_stop_words(self) -> bool:
|
||||
"""Check if the model supports stop words."""
|
||||
return True
|
||||
|
||||
def get_context_window_size(self) -> int:
|
||||
"""Get the context window size for the model."""
|
||||
from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO
|
||||
|
||||
# Context window sizes for common Bedrock models
|
||||
context_windows = {
|
||||
"anthropic.claude-3-5-sonnet": 200000,
|
||||
"anthropic.claude-3-5-haiku": 200000,
|
||||
"anthropic.claude-3-opus": 200000,
|
||||
"anthropic.claude-3-sonnet": 200000,
|
||||
"anthropic.claude-3-haiku": 200000,
|
||||
"anthropic.claude-3-7-sonnet": 200000,
|
||||
"anthropic.claude-v2": 100000,
|
||||
"amazon.titan-text-express": 8000,
|
||||
"ai21.j2-ultra": 8192,
|
||||
"cohere.command-text": 4096,
|
||||
"meta.llama2-13b-chat": 4096,
|
||||
"meta.llama2-70b-chat": 4096,
|
||||
"meta.llama3-70b-instruct": 128000,
|
||||
"deepseek.r1": 32768,
|
||||
}
|
||||
|
||||
# Find the best match for the model name
|
||||
for model_prefix, size in context_windows.items():
|
||||
if self.model.startswith(model_prefix):
|
||||
return int(size * CONTEXT_WINDOW_USAGE_RATIO)
|
||||
|
||||
# Default context window size
|
||||
return int(8192 * CONTEXT_WINDOW_USAGE_RATIO)
|
||||
@@ -11,9 +11,9 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
|
||||
|
||||
try:
|
||||
from google import genai # type: ignore
|
||||
from google.genai import types # type: ignore
|
||||
from google.genai.errors import APIError # type: ignore
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
from google.genai.errors import APIError
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Google Gen AI native provider not available, to install: `uv add google-genai`"
|
||||
@@ -40,6 +40,7 @@ class GeminiCompletion(BaseLLM):
|
||||
stop_sequences: list[str] | None = None,
|
||||
stream: bool = False,
|
||||
safety_settings: dict[str, Any] | None = None,
|
||||
client_params: dict[str, Any] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize Google Gemini chat completion client.
|
||||
@@ -56,35 +57,27 @@ class GeminiCompletion(BaseLLM):
|
||||
stop_sequences: Stop sequences
|
||||
stream: Enable streaming responses
|
||||
safety_settings: Safety filter settings
|
||||
client_params: Additional parameters to pass to the Google Gen AI Client constructor.
|
||||
Supports parameters like http_options, credentials, debug_config, etc.
|
||||
**kwargs: Additional parameters
|
||||
"""
|
||||
super().__init__(
|
||||
model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
|
||||
)
|
||||
|
||||
# Get API configuration
|
||||
# Store client params for later use
|
||||
self.client_params = client_params or {}
|
||||
|
||||
# Get API configuration with environment variable fallbacks
|
||||
self.api_key = (
|
||||
api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
|
||||
)
|
||||
self.project = project or os.getenv("GOOGLE_CLOUD_PROJECT")
|
||||
self.location = location or os.getenv("GOOGLE_CLOUD_LOCATION") or "us-central1"
|
||||
|
||||
# Initialize client based on available configuration
|
||||
if self.project:
|
||||
# Use Vertex AI
|
||||
self.client = genai.Client(
|
||||
vertexai=True,
|
||||
project=self.project,
|
||||
location=self.location,
|
||||
)
|
||||
elif self.api_key:
|
||||
# Use Gemini Developer API
|
||||
self.client = genai.Client(api_key=self.api_key)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Either GOOGLE_API_KEY/GEMINI_API_KEY (for Gemini API) or "
|
||||
"GOOGLE_CLOUD_PROJECT (for Vertex AI) must be set"
|
||||
)
|
||||
use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"
|
||||
|
||||
self.client = self._initialize_client(use_vertexai)
|
||||
|
||||
# Store completion parameters
|
||||
self.top_p = top_p
|
||||
@@ -99,6 +92,78 @@ class GeminiCompletion(BaseLLM):
|
||||
self.is_gemini_1_5 = "gemini-1.5" in model.lower()
|
||||
self.supports_tools = self.is_gemini_1_5 or self.is_gemini_2
|
||||
|
||||
def _initialize_client(self, use_vertexai: bool = False) -> genai.Client:
|
||||
"""Initialize the Google Gen AI client with proper parameter handling.
|
||||
|
||||
Args:
|
||||
use_vertexai: Whether to use Vertex AI (from environment variable)
|
||||
|
||||
Returns:
|
||||
Initialized Google Gen AI Client
|
||||
"""
|
||||
client_params = {}
|
||||
|
||||
if self.client_params:
|
||||
client_params.update(self.client_params)
|
||||
|
||||
if use_vertexai or self.project:
|
||||
client_params.update(
|
||||
{
|
||||
"vertexai": True,
|
||||
"project": self.project,
|
||||
"location": self.location,
|
||||
}
|
||||
)
|
||||
|
||||
client_params.pop("api_key", None)
|
||||
|
||||
elif self.api_key:
|
||||
client_params["api_key"] = self.api_key
|
||||
|
||||
client_params.pop("vertexai", None)
|
||||
client_params.pop("project", None)
|
||||
client_params.pop("location", None)
|
||||
|
||||
else:
|
||||
try:
|
||||
return genai.Client(**client_params)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
"Either GOOGLE_API_KEY/GEMINI_API_KEY (for Gemini API) or "
|
||||
"GOOGLE_CLOUD_PROJECT (for Vertex AI) must be set"
|
||||
) from e
|
||||
|
||||
return genai.Client(**client_params)
|
||||
|
||||
def _get_client_params(self) -> dict[str, Any]:
|
||||
"""Get client parameters for compatibility with base class.
|
||||
|
||||
Note: This method is kept for compatibility but the Google Gen AI SDK
|
||||
uses a different initialization pattern via the Client constructor.
|
||||
"""
|
||||
params = {}
|
||||
|
||||
if (
|
||||
hasattr(self, "client")
|
||||
and hasattr(self.client, "vertexai")
|
||||
and self.client.vertexai
|
||||
):
|
||||
# Vertex AI configuration
|
||||
params.update(
|
||||
{
|
||||
"vertexai": True,
|
||||
"project": self.project,
|
||||
"location": self.location,
|
||||
}
|
||||
)
|
||||
elif self.api_key:
|
||||
params["api_key"] = self.api_key
|
||||
|
||||
if self.client_params:
|
||||
params.update(self.client_params)
|
||||
|
||||
return params
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: str | list[dict[str, str]],
|
||||
@@ -427,7 +492,7 @@ class GeminiCompletion(BaseLLM):
|
||||
|
||||
def supports_stop_words(self) -> bool:
|
||||
"""Check if the model supports stop words."""
|
||||
return self._supports_stop_words_implementation()
|
||||
return True
|
||||
|
||||
def get_context_window_size(self) -> int:
|
||||
"""Get the context window size for the model."""
|
||||
|
||||
@@ -10,7 +10,7 @@ from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededError,
|
||||
)
|
||||
from openai import OpenAI
|
||||
from openai import APIConnectionError, NotFoundError, OpenAI
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion import Choice
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
@@ -33,6 +33,9 @@ class OpenAICompletion(BaseLLM):
|
||||
project: str | None = None,
|
||||
timeout: float | None = None,
|
||||
max_retries: int = 2,
|
||||
default_headers: dict[str, str] | None = None,
|
||||
default_query: dict[str, Any] | None = None,
|
||||
client_params: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
@@ -44,8 +47,8 @@ class OpenAICompletion(BaseLLM):
|
||||
response_format: dict[str, Any] | type[BaseModel] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
reasoning_effort: str | None = None, # For o1 models
|
||||
provider: str | None = None, # Add provider parameter
|
||||
reasoning_effort: str | None = None,
|
||||
provider: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize OpenAI chat completion client."""
|
||||
@@ -53,6 +56,16 @@ class OpenAICompletion(BaseLLM):
|
||||
if provider is None:
|
||||
provider = kwargs.pop("provider", "openai")
|
||||
|
||||
# Client configuration attributes
|
||||
self.organization = organization
|
||||
self.project = project
|
||||
self.max_retries = max_retries
|
||||
self.default_headers = default_headers
|
||||
self.default_query = default_query
|
||||
self.client_params = client_params
|
||||
self.timeout = timeout
|
||||
self.base_url = base_url
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
@@ -63,15 +76,10 @@ class OpenAICompletion(BaseLLM):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.client = OpenAI(
|
||||
api_key=api_key or os.getenv("OPENAI_API_KEY"),
|
||||
base_url=base_url,
|
||||
organization=organization,
|
||||
project=project,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
client_config = self._get_client_params()
|
||||
self.client = OpenAI(**client_config)
|
||||
|
||||
# Completion parameters
|
||||
self.top_p = top_p
|
||||
self.frequency_penalty = frequency_penalty
|
||||
self.presence_penalty = presence_penalty
|
||||
@@ -83,10 +91,35 @@ class OpenAICompletion(BaseLLM):
|
||||
self.logprobs = logprobs
|
||||
self.top_logprobs = top_logprobs
|
||||
self.reasoning_effort = reasoning_effort
|
||||
self.timeout = timeout
|
||||
self.is_o1_model = "o1" in model.lower()
|
||||
self.is_gpt4_model = "gpt-4" in model.lower()
|
||||
|
||||
def _get_client_params(self) -> dict[str, Any]:
|
||||
"""Get OpenAI client parameters."""
|
||||
|
||||
if self.api_key is None:
|
||||
self.api_key = os.getenv("OPENAI_API_KEY")
|
||||
if self.api_key is None:
|
||||
raise ValueError("OPENAI_API_KEY is required")
|
||||
|
||||
base_params = {
|
||||
"api_key": self.api_key,
|
||||
"organization": self.organization,
|
||||
"project": self.project,
|
||||
"base_url": self.base_url,
|
||||
"timeout": self.timeout,
|
||||
"max_retries": self.max_retries,
|
||||
"default_headers": self.default_headers,
|
||||
"default_query": self.default_query,
|
||||
}
|
||||
|
||||
client_params = {k: v for k, v in base_params.items() if v is not None}
|
||||
|
||||
if self.client_params:
|
||||
client_params.update(self.client_params)
|
||||
|
||||
return client_params
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: str | list[dict[str, str]],
|
||||
@@ -207,7 +240,6 @@ class OpenAICompletion(BaseLLM):
|
||||
"api_key",
|
||||
"base_url",
|
||||
"timeout",
|
||||
"max_retries",
|
||||
}
|
||||
|
||||
return {k: v for k, v in params.items() if k not in crewai_specific_params}
|
||||
@@ -306,10 +338,31 @@ class OpenAICompletion(BaseLLM):
|
||||
|
||||
if usage.get("total_tokens", 0) > 0:
|
||||
logging.info(f"OpenAI API usage: {usage}")
|
||||
except NotFoundError as e:
|
||||
error_msg = f"Model {self.model} not found: {e}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise ValueError(error_msg) from e
|
||||
except APIConnectionError as e:
|
||||
error_msg = f"Failed to connect to OpenAI API: {e}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise ConnectionError(error_msg) from e
|
||||
except Exception as e:
|
||||
# Handle context length exceeded and other errors
|
||||
if is_context_length_exceeded(e):
|
||||
logging.error(f"Context window exceeded: {e}")
|
||||
raise LLMContextLengthExceededError(str(e)) from e
|
||||
|
||||
error_msg = f"OpenAI API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise e from e
|
||||
|
||||
return content
|
||||
|
||||
@@ -4,6 +4,7 @@ from dataclasses import field
|
||||
from typing import Literal, cast
|
||||
|
||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||
from qdrant_client.models import VectorParams
|
||||
|
||||
from crewai.rag.config.base import BaseRagConfig
|
||||
from crewai.rag.qdrant.constants import DEFAULT_EMBEDDING_MODEL, DEFAULT_STORAGE_PATH
|
||||
@@ -53,3 +54,4 @@ class QdrantConfig(BaseRagConfig):
|
||||
embedding_function: QdrantEmbeddingFunctionWrapper = field(
|
||||
default_factory=_default_embedding_function
|
||||
)
|
||||
vectors_config: VectorParams | None = field(default=None)
|
||||
|
||||
@@ -4,8 +4,8 @@ import asyncio
|
||||
from typing import TypeGuard
|
||||
from uuid import uuid4
|
||||
|
||||
from qdrant_client import AsyncQdrantClient # type: ignore[import-not-found]
|
||||
from qdrant_client import (
|
||||
AsyncQdrantClient, # type: ignore[import-not-found]
|
||||
QdrantClient as SyncQdrantClient, # type: ignore[import-not-found]
|
||||
)
|
||||
from qdrant_client.models import ( # type: ignore[import-not-found]
|
||||
|
||||
@@ -7,7 +7,7 @@ import uuid
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import Future
|
||||
from copy import copy
|
||||
from copy import copy as shallow_copy
|
||||
from hashlib import md5
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
@@ -674,7 +674,9 @@ Follow these guidelines:
|
||||
copied_data = {k: v for k, v in copied_data.items() if v is not None}
|
||||
|
||||
cloned_context = (
|
||||
[task_mapping[context_task.key] for context_task in self.context]
|
||||
self.context
|
||||
if self.context is NOT_SPECIFIED
|
||||
else [task_mapping[context_task.key] for context_task in self.context]
|
||||
if isinstance(self.context, list)
|
||||
else None
|
||||
)
|
||||
@@ -683,7 +685,7 @@ Follow these guidelines:
|
||||
return next((agent for agent in agents if agent.role == role), None)
|
||||
|
||||
cloned_agent = get_agent_by_role(self.agent.role) if self.agent else None
|
||||
cloned_tools = copy(self.tools) if self.tools else []
|
||||
cloned_tools = shallow_copy(self.tools) if self.tools else []
|
||||
|
||||
return self.__class__(
|
||||
**copied_data,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Test Agent creation and execution basic functionality."""
|
||||
|
||||
import os
|
||||
import threading
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@@ -185,14 +186,17 @@ def test_agent_execution_with_tools():
|
||||
expected_output="The result of the multiplication.",
|
||||
)
|
||||
received_events = []
|
||||
event_received = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(ToolUsageFinishedEvent)
|
||||
def handle_tool_end(source, event):
|
||||
received_events.append(event)
|
||||
event_received.set()
|
||||
|
||||
output = agent.execute_task(task)
|
||||
assert output == "The result of the multiplication is 12."
|
||||
|
||||
assert event_received.wait(timeout=5), "Timeout waiting for tool usage event"
|
||||
assert len(received_events) == 1
|
||||
assert isinstance(received_events[0], ToolUsageFinishedEvent)
|
||||
assert received_events[0].tool_name == "multiplier"
|
||||
@@ -284,10 +288,12 @@ def test_cache_hitting():
|
||||
'multiplier-{"first_number": 12, "second_number": 3}': 36,
|
||||
}
|
||||
received_events = []
|
||||
event_received = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(ToolUsageFinishedEvent)
|
||||
def handle_tool_end(source, event):
|
||||
received_events.append(event)
|
||||
event_received.set()
|
||||
|
||||
with (
|
||||
patch.object(CacheHandler, "read") as read,
|
||||
@@ -303,6 +309,7 @@ def test_cache_hitting():
|
||||
read.assert_called_with(
|
||||
tool="multiplier", input='{"first_number": 2, "second_number": 6}'
|
||||
)
|
||||
assert event_received.wait(timeout=5), "Timeout waiting for tool usage event"
|
||||
assert len(received_events) == 1
|
||||
assert isinstance(received_events[0], ToolUsageFinishedEvent)
|
||||
assert received_events[0].from_cache
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
# mypy: ignore-errors
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from typing import cast
|
||||
from unittest.mock import Mock, patch
|
||||
@@ -156,14 +157,17 @@ def test_lite_agent_with_tools():
|
||||
)
|
||||
|
||||
received_events = []
|
||||
event_received = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(ToolUsageStartedEvent)
|
||||
def event_handler(source, event):
|
||||
received_events.append(event)
|
||||
event_received.set()
|
||||
|
||||
agent.kickoff("What are the effects of climate change on coral reefs?")
|
||||
|
||||
# Verify tool usage events were emitted
|
||||
assert event_received.wait(timeout=5), "Timeout waiting for tool usage events"
|
||||
assert len(received_events) > 0, "Tool usage events should be emitted"
|
||||
event = received_events[0]
|
||||
assert isinstance(event, ToolUsageStartedEvent)
|
||||
@@ -316,15 +320,18 @@ def test_sets_parent_flow_when_inside_flow():
|
||||
return agent.kickoff("Test query")
|
||||
|
||||
flow = MyFlow()
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
event_received = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(LiteAgentExecutionStartedEvent)
|
||||
def capture_agent(source, event):
|
||||
nonlocal captured_agent
|
||||
captured_agent = source
|
||||
@crewai_event_bus.on(LiteAgentExecutionStartedEvent)
|
||||
def capture_agent(source, event):
|
||||
nonlocal captured_agent
|
||||
captured_agent = source
|
||||
event_received.set()
|
||||
|
||||
flow.kickoff()
|
||||
assert captured_agent.parent_flow is flow
|
||||
flow.kickoff()
|
||||
|
||||
assert event_received.wait(timeout=5), "Timeout waiting for agent execution event"
|
||||
assert captured_agent.parent_flow is flow
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -342,30 +349,43 @@ def test_guardrail_is_called_using_string():
|
||||
guardrail="""Only include Brazilian players, both women and men""",
|
||||
)
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
all_events_received = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(LLMGuardrailStartedEvent)
|
||||
def capture_guardrail_started(source, event):
|
||||
assert isinstance(source, LiteAgent)
|
||||
assert source.original_agent == agent
|
||||
guardrail_events["started"].append(event)
|
||||
@crewai_event_bus.on(LLMGuardrailStartedEvent)
|
||||
def capture_guardrail_started(source, event):
|
||||
assert isinstance(source, LiteAgent)
|
||||
assert source.original_agent == agent
|
||||
guardrail_events["started"].append(event)
|
||||
if (
|
||||
len(guardrail_events["started"]) == 2
|
||||
and len(guardrail_events["completed"]) == 2
|
||||
):
|
||||
all_events_received.set()
|
||||
|
||||
@crewai_event_bus.on(LLMGuardrailCompletedEvent)
|
||||
def capture_guardrail_completed(source, event):
|
||||
assert isinstance(source, LiteAgent)
|
||||
assert source.original_agent == agent
|
||||
guardrail_events["completed"].append(event)
|
||||
@crewai_event_bus.on(LLMGuardrailCompletedEvent)
|
||||
def capture_guardrail_completed(source, event):
|
||||
assert isinstance(source, LiteAgent)
|
||||
assert source.original_agent == agent
|
||||
guardrail_events["completed"].append(event)
|
||||
if (
|
||||
len(guardrail_events["started"]) == 2
|
||||
and len(guardrail_events["completed"]) == 2
|
||||
):
|
||||
all_events_received.set()
|
||||
|
||||
result = agent.kickoff(messages="Top 10 best players in the world?")
|
||||
result = agent.kickoff(messages="Top 10 best players in the world?")
|
||||
|
||||
assert len(guardrail_events["started"]) == 2
|
||||
assert len(guardrail_events["completed"]) == 2
|
||||
assert not guardrail_events["completed"][0].success
|
||||
assert guardrail_events["completed"][1].success
|
||||
assert (
|
||||
"Here are the top 10 best soccer players in the world, focusing exclusively on Brazilian players"
|
||||
in result.raw
|
||||
)
|
||||
assert all_events_received.wait(timeout=10), (
|
||||
"Timeout waiting for all guardrail events"
|
||||
)
|
||||
assert len(guardrail_events["started"]) == 2
|
||||
assert len(guardrail_events["completed"]) == 2
|
||||
assert not guardrail_events["completed"][0].success
|
||||
assert guardrail_events["completed"][1].success
|
||||
assert (
|
||||
"Here are the top 10 best soccer players in the world, focusing exclusively on Brazilian players"
|
||||
in result.raw
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -376,29 +396,42 @@ def test_guardrail_is_called_using_callable():
|
||||
LLMGuardrailStartedEvent,
|
||||
)
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
all_events_received = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(LLMGuardrailStartedEvent)
|
||||
def capture_guardrail_started(source, event):
|
||||
guardrail_events["started"].append(event)
|
||||
@crewai_event_bus.on(LLMGuardrailStartedEvent)
|
||||
def capture_guardrail_started(source, event):
|
||||
guardrail_events["started"].append(event)
|
||||
if (
|
||||
len(guardrail_events["started"]) == 1
|
||||
and len(guardrail_events["completed"]) == 1
|
||||
):
|
||||
all_events_received.set()
|
||||
|
||||
@crewai_event_bus.on(LLMGuardrailCompletedEvent)
|
||||
def capture_guardrail_completed(source, event):
|
||||
guardrail_events["completed"].append(event)
|
||||
@crewai_event_bus.on(LLMGuardrailCompletedEvent)
|
||||
def capture_guardrail_completed(source, event):
|
||||
guardrail_events["completed"].append(event)
|
||||
if (
|
||||
len(guardrail_events["started"]) == 1
|
||||
and len(guardrail_events["completed"]) == 1
|
||||
):
|
||||
all_events_received.set()
|
||||
|
||||
agent = Agent(
|
||||
role="Sports Analyst",
|
||||
goal="Gather information about the best soccer players",
|
||||
backstory="""You are an expert at gathering and organizing information. You carefully collect details and present them in a structured way.""",
|
||||
guardrail=lambda output: (True, "Pelé - Santos, 1958"),
|
||||
)
|
||||
agent = Agent(
|
||||
role="Sports Analyst",
|
||||
goal="Gather information about the best soccer players",
|
||||
backstory="""You are an expert at gathering and organizing information. You carefully collect details and present them in a structured way.""",
|
||||
guardrail=lambda output: (True, "Pelé - Santos, 1958"),
|
||||
)
|
||||
|
||||
result = agent.kickoff(messages="Top 1 best players in the world?")
|
||||
result = agent.kickoff(messages="Top 1 best players in the world?")
|
||||
|
||||
assert len(guardrail_events["started"]) == 1
|
||||
assert len(guardrail_events["completed"]) == 1
|
||||
assert guardrail_events["completed"][0].success
|
||||
assert "Pelé - Santos, 1958" in result.raw
|
||||
assert all_events_received.wait(timeout=10), (
|
||||
"Timeout waiting for all guardrail events"
|
||||
)
|
||||
assert len(guardrail_events["started"]) == 1
|
||||
assert len(guardrail_events["completed"]) == 1
|
||||
assert guardrail_events["completed"][0].success
|
||||
assert "Pelé - Santos, 1958" in result.raw
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -409,37 +442,50 @@ def test_guardrail_reached_attempt_limit():
|
||||
LLMGuardrailStartedEvent,
|
||||
)
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
all_events_received = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(LLMGuardrailStartedEvent)
|
||||
def capture_guardrail_started(source, event):
|
||||
guardrail_events["started"].append(event)
|
||||
|
||||
@crewai_event_bus.on(LLMGuardrailCompletedEvent)
|
||||
def capture_guardrail_completed(source, event):
|
||||
guardrail_events["completed"].append(event)
|
||||
|
||||
agent = Agent(
|
||||
role="Sports Analyst",
|
||||
goal="Gather information about the best soccer players",
|
||||
backstory="""You are an expert at gathering and organizing information. You carefully collect details and present them in a structured way.""",
|
||||
guardrail=lambda output: (
|
||||
False,
|
||||
"You are not allowed to include Brazilian players",
|
||||
),
|
||||
guardrail_max_retries=2,
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
Exception, match="Agent's guardrail failed validation after 2 retries"
|
||||
@crewai_event_bus.on(LLMGuardrailStartedEvent)
|
||||
def capture_guardrail_started(source, event):
|
||||
guardrail_events["started"].append(event)
|
||||
if (
|
||||
len(guardrail_events["started"]) == 3
|
||||
and len(guardrail_events["completed"]) == 3
|
||||
):
|
||||
agent.kickoff(messages="Top 10 best players in the world?")
|
||||
all_events_received.set()
|
||||
|
||||
assert len(guardrail_events["started"]) == 3 # 2 retries + 1 initial call
|
||||
assert len(guardrail_events["completed"]) == 3 # 2 retries + 1 initial call
|
||||
assert not guardrail_events["completed"][0].success
|
||||
assert not guardrail_events["completed"][1].success
|
||||
assert not guardrail_events["completed"][2].success
|
||||
@crewai_event_bus.on(LLMGuardrailCompletedEvent)
|
||||
def capture_guardrail_completed(source, event):
|
||||
guardrail_events["completed"].append(event)
|
||||
if (
|
||||
len(guardrail_events["started"]) == 3
|
||||
and len(guardrail_events["completed"]) == 3
|
||||
):
|
||||
all_events_received.set()
|
||||
|
||||
agent = Agent(
|
||||
role="Sports Analyst",
|
||||
goal="Gather information about the best soccer players",
|
||||
backstory="""You are an expert at gathering and organizing information. You carefully collect details and present them in a structured way.""",
|
||||
guardrail=lambda output: (
|
||||
False,
|
||||
"You are not allowed to include Brazilian players",
|
||||
),
|
||||
guardrail_max_retries=2,
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
Exception, match="Agent's guardrail failed validation after 2 retries"
|
||||
):
|
||||
agent.kickoff(messages="Top 10 best players in the world?")
|
||||
|
||||
assert all_events_received.wait(timeout=10), (
|
||||
"Timeout waiting for all guardrail events"
|
||||
)
|
||||
assert len(guardrail_events["started"]) == 3 # 2 retries + 1 initial call
|
||||
assert len(guardrail_events["completed"]) == 3 # 2 retries + 1 initial call
|
||||
assert not guardrail_events["completed"][0].success
|
||||
assert not guardrail_events["completed"][1].success
|
||||
assert not guardrail_events["completed"][2].success
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
|
||||
227
lib/crewai/tests/cassettes/test_openai_completion_call.yaml
Normal file
227
lib/crewai/tests/cassettes/test_openai_completion_call.yaml
Normal file
@@ -0,0 +1,227 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"messages": [{"role": "user", "content": "Hello, how are you?"}], "model":
|
||||
"gpt-4o", "stream": false}'
|
||||
headers:
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- gzip, deflate
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '102'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- api.openai.com
|
||||
user-agent:
|
||||
- OpenAI/Python 1.109.1
|
||||
x-stainless-arch:
|
||||
- arm64
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- MacOS
|
||||
x-stainless-package-version:
|
||||
- 1.109.1
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.13.3
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: !!binary |
|
||||
H4sIAAAAAAAAAwAAAP//jFJNj9MwEL3nVwy+7CVdpd1+XxBCXbUSB7ggBFpFrj1JvDgeY08K1ar/
|
||||
HSXpNl1YJC4+zJs3fu/NPCUAwmixBqEqyar2dvT+04dNEzbOvXOrzz+228PHGS5pu/k6u7//ItKW
|
||||
QftHVPzMulVUe4tsyPWwCigZ26njxTy7W00n2bIDatJoW1rpeTSl0SSbTEfZcpTNz8SKjMIo1vAt
|
||||
AQB46t5WotP4S6whS58rNcYoSxTrSxOACGTbipAxmsjSsUgHUJFjdJ3qLVpLb2B3U8NjExkk+EBl
|
||||
kHUKkWAHmtwNQyUPCAWiNa6MKewb7hgVBgTpNASU+ghMUKH1cKTmFrb0E5R0sINeQlsFJi2Pb6+l
|
||||
BCyaKNskXGPtFSCdI5Ztkl0ID2fkdLFtqfSB9vEPqiiMM7HKA8pIrrUYmbzo0FMC8NDF27xITPhA
|
||||
teec6Tt2343v+nFi2OcATlZnkImlHerTSfrKtFwjS2Pj1XqEkqpCPTCHXcpGG7oCkivPf4t5bXbv
|
||||
27jyf8YPgFLoGXXuA2qjXhoe2gK21/6vtkvGnWARMRyMwpwNhnYPGgvZ2P4QRTxGxjovjCsx+GD6
|
||||
ayx8rvbFeLGczeYLkZyS3wAAAP//AwCZQodJlgMAAA==
|
||||
headers:
|
||||
CF-RAY:
|
||||
- 98e23dd86b0c4705-SJC
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Encoding:
|
||||
- gzip
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Mon, 13 Oct 2025 22:23:30 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Set-Cookie:
|
||||
- __cf_bm=wwEqnpcIZyBbBZ_COqrhykwhzQkjmXMsXhNFYjtokPs-1760394210-1.0.1.1-8gJdrt5_Ak6dIqzZox1X9WYI1a7OgSgwaiJdWzz3egks.yw87Cm9__k5K.j4aXQFrUQt7b3OBkTuyrhIysP_CtKEqT5ap_Gc6vH4XqNYXVw;
|
||||
path=/; expires=Mon, 13-Oct-25 22:53:30 GMT; domain=.api.openai.com; HttpOnly;
|
||||
Secure; SameSite=None
|
||||
- _cfuvid=MTZb.IlikCEE87xU.hPEMy_FZxe7wdzqB_xM1BQOjQs-1760394210023-0.0.1.1-604800000;
|
||||
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
|
||||
Strict-Transport-Security:
|
||||
- max-age=31536000; includeSubDomains; preload
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- nosniff
|
||||
access-control-expose-headers:
|
||||
- X-Request-ID
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
openai-organization:
|
||||
- crewai-iuxna1
|
||||
openai-processing-ms:
|
||||
- '1252'
|
||||
openai-project:
|
||||
- proj_xitITlrFeen7zjNSzML82h9x
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
x-envoy-upstream-service-time:
|
||||
- '1451'
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-project-requests:
|
||||
- '10000'
|
||||
x-ratelimit-limit-requests:
|
||||
- '10000'
|
||||
x-ratelimit-limit-tokens:
|
||||
- '30000000'
|
||||
x-ratelimit-remaining-project-requests:
|
||||
- '9999'
|
||||
x-ratelimit-remaining-requests:
|
||||
- '9999'
|
||||
x-ratelimit-remaining-tokens:
|
||||
- '29999993'
|
||||
x-ratelimit-reset-project-requests:
|
||||
- 6ms
|
||||
x-ratelimit-reset-requests:
|
||||
- 6ms
|
||||
x-ratelimit-reset-tokens:
|
||||
- 0s
|
||||
x-request-id:
|
||||
- req_bfe85ec6f9514d3093d79765a87c6c7b
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
- request:
|
||||
body: '{"messages": [{"role": "user", "content": "Hello, how are you?"}], "model":
|
||||
"gpt-4o", "stream": false}'
|
||||
headers:
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- gzip, deflate
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '102'
|
||||
content-type:
|
||||
- application/json
|
||||
cookie:
|
||||
- __cf_bm=wwEqnpcIZyBbBZ_COqrhykwhzQkjmXMsXhNFYjtokPs-1760394210-1.0.1.1-8gJdrt5_Ak6dIqzZox1X9WYI1a7OgSgwaiJdWzz3egks.yw87Cm9__k5K.j4aXQFrUQt7b3OBkTuyrhIysP_CtKEqT5ap_Gc6vH4XqNYXVw;
|
||||
_cfuvid=MTZb.IlikCEE87xU.hPEMy_FZxe7wdzqB_xM1BQOjQs-1760394210023-0.0.1.1-604800000
|
||||
host:
|
||||
- api.openai.com
|
||||
user-agent:
|
||||
- OpenAI/Python 1.109.1
|
||||
x-stainless-arch:
|
||||
- arm64
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- MacOS
|
||||
x-stainless-package-version:
|
||||
- 1.109.1
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.13.3
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: !!binary |
|
||||
H4sIAAAAAAAAAwAAAP//jFJNa9tAEL3rV0z3kosc5I/Iji8lFIJNPyBQSqEEsd4dSZusdpbdUVoT
|
||||
/N+LJMdy2hR62cO8ebPvvZnnBEAYLdYgVC1ZNd5OPtx9+qym7d3+a/4N69I9OpVtbubfP97efrkR
|
||||
aceg3QMqfmFdKmq8RTbkBlgFlIzd1Okyz+bXi3yV90BDGm1HqzxPFjSZZbPFJFtNsvxIrMkojGIN
|
||||
PxIAgOf+7SQ6jb/EGrL0pdJgjLJCsT41AYhAtqsIGaOJLB2LdAQVOUbXq96gtfQOthcNPLSRQYIP
|
||||
VAXZpBAJtqDJXTDU8gmhRLTGVTGFXcs9o8aAIJ2GgFLvgQlqtB721F7Chn6Ckg62MEjoqsCk5f79
|
||||
uZSAZRtll4RrrT0DpHPEskuyD+H+iBxOti1VPtAu/kEVpXEm1kVAGcl1FiOTFz16SADu+3jbV4kJ
|
||||
H6jxXDA9Yv/ddD6ME+M+R3B2fQSZWNqxvpilb0wrNLI0Np6tRyipatQjc9ylbLWhMyA58/y3mLdm
|
||||
D76Nq/5n/AgohZ5RFz6gNuq14bEtYHft/2o7ZdwLFhHDk1FYsMHQ7UFjKVs7HKKI+8jYFKVxFQYf
|
||||
zHCNpS/UrpwuV1dX+VIkh+Q3AAAA//8DAISwErWWAwAA
|
||||
headers:
|
||||
CF-RAY:
|
||||
- 98e249852df117c4-SJC
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Encoding:
|
||||
- gzip
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Mon, 13 Oct 2025 22:31:27 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Strict-Transport-Security:
|
||||
- max-age=31536000; includeSubDomains; preload
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- nosniff
|
||||
access-control-expose-headers:
|
||||
- X-Request-ID
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
openai-organization:
|
||||
- crewai-iuxna1
|
||||
openai-processing-ms:
|
||||
- '512'
|
||||
openai-project:
|
||||
- proj_xitITlrFeen7zjNSzML82h9x
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
x-envoy-upstream-service-time:
|
||||
- '670'
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-project-requests:
|
||||
- '10000'
|
||||
x-ratelimit-limit-requests:
|
||||
- '10000'
|
||||
x-ratelimit-limit-tokens:
|
||||
- '30000000'
|
||||
x-ratelimit-remaining-project-requests:
|
||||
- '9999'
|
||||
x-ratelimit-remaining-requests:
|
||||
- '9999'
|
||||
x-ratelimit-remaining-tokens:
|
||||
- '29999993'
|
||||
x-ratelimit-reset-project-requests:
|
||||
- 6ms
|
||||
x-ratelimit-reset-requests:
|
||||
- 6ms
|
||||
x-ratelimit-reset-tokens:
|
||||
- 0s
|
||||
x-request-id:
|
||||
- req_6d219ed625a24c38895b896c9e13dcef
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
@@ -0,0 +1,129 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"messages": [{"role": "system", "content": "You are Research Assistant.
|
||||
You are a helpful research assistant.\nYour personal goal is: Find information
|
||||
about the population of Tokyo\nTo give my best complete final answer to the
|
||||
task respond using the exact following format:\n\nThought: I now can give a
|
||||
great answer\nFinal Answer: Your final answer must be the great and the most
|
||||
complete as possible, it must be outcome described.\n\nI MUST use these formats,
|
||||
my job depends on it!"}, {"role": "user", "content": "\nCurrent Task: Find information
|
||||
about the population of Tokyo\n\nThis is the expected criteria for your final
|
||||
answer: The population of Tokyo is 10 million\nyou MUST return the actual complete
|
||||
content as the final answer, not a summary.\n\nBegin! This is VERY important
|
||||
to you, use the tools available and give your best Final Answer, your job depends
|
||||
on it!\n\nThought:"}], "model": "gpt-4o", "stream": false}'
|
||||
headers:
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- gzip, deflate
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '927'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- api.openai.com
|
||||
user-agent:
|
||||
- OpenAI/Python 1.109.1
|
||||
x-stainless-arch:
|
||||
- arm64
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- MacOS
|
||||
x-stainless-package-version:
|
||||
- 1.109.1
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.13.3
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: !!binary |
|
||||
H4sIAAAAAAAAAwAAAP//jFTbahsxEH33Vwx6Xgdf0sT2Wwi09AKlkFJoG8xYmt2dRqtRJa0dN+Tf
|
||||
i2Qndi6Fvixoz5yjc0Yj3Q0AFBu1AKVbTLrzdnj55RNtPm7N+s+cv9t6Mp/dXpL99vXz5fbig6oy
|
||||
Q1a/SKcH1omWzltKLG4H60CYKKuOz89G0/mb2WRSgE4M2UxrfBqeynAympwOR7Ph6GxPbIU1RbWA
|
||||
HwMAgLvyzRadoVu1gFH18KejGLEhtXgsAlBBbP6jMEaOCV1S1QHU4hK54vqqlb5p0wLeg5MNaHTQ
|
||||
8JoAocnWAV3cUAD46d6yQwsXZb2Aq5bAi+8t5rAgNVzJzVYqYKdtb9g1sJLUAqcIHaUgXiwndICB
|
||||
ENAZSC3BZArRk2a0sMFgIqQWE3R4Q9D7UqHJpYAWNKdtBRwhcuO4Zo0u2S1YDA2FTHMwHkHH1rK4
|
||||
E7iI2VIW6CQmCJR1wGDCamc0S4mjJ1Ulj/Qxb8YUgV3BNhKsqWDDqS3rd+VMw17nIudpcZ0T47OW
|
||||
yJoCTM8fbIEn8ZaqHDCXc3pl75fNOrZxUhr/om2H9m9a1m3mFQ797nmNNmeXGnDfxRbLAWvpVuzI
|
||||
HJsu/adbTWQizJ8ZPzmeoUB1HzGPsOutPQLQOUlFrUzv9R65f5xXK40PsorPqKpmx7FdBsIoLs9m
|
||||
TOJVQe8HANflXvRPRl35IJ1PyyQ3VLYbn093eupwE4/Q8dkeTZLQHoDJbF69Irg0lJBtPLpaSqNu
|
||||
yRyoh3uIvWE5AgZHsV/aeU17F51d8z/yB0Br8onM0gcyrJ9GPpQFyi/Vv8oe21wMq0hhzZqWiSnk
|
||||
ozBUY293j4iK25ioW9bsGgo+8O4lqf2SVlM91avZqVGD+8FfAAAA//8DAFlnuIlSBQAA
|
||||
headers:
|
||||
CF-RAY:
|
||||
- 98e26542adbbce40-SJC
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Encoding:
|
||||
- gzip
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Mon, 13 Oct 2025 22:50:26 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Set-Cookie:
|
||||
- __cf_bm=ZOY3aTF4ZQGyq1Ai5bME5tI2L4FUKjdaM76hKUktVgg-1760395826-1.0.1.1-6MNmhofBsqJxHCGxkDDtTbJUi9JDiJwdeBOsfQEvrMTovTmf8eAYxjskKbAxY0ZicvPhqx2bOD64cOAPUfREUiFdzz1oh3uKuy4_AL9Vma0;
|
||||
path=/; expires=Mon, 13-Oct-25 23:20:26 GMT; domain=.api.openai.com; HttpOnly;
|
||||
Secure; SameSite=None
|
||||
- _cfuvid=ETABAP9icJoaIxhFazEUuSnHhwqlBentj3YJUS501.w-1760395826352-0.0.1.1-604800000;
|
||||
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
|
||||
Strict-Transport-Security:
|
||||
- max-age=31536000; includeSubDomains; preload
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- nosniff
|
||||
access-control-expose-headers:
|
||||
- X-Request-ID
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
openai-organization:
|
||||
- crewai-iuxna1
|
||||
openai-processing-ms:
|
||||
- '3572'
|
||||
openai-project:
|
||||
- proj_xitITlrFeen7zjNSzML82h9x
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
x-envoy-upstream-service-time:
|
||||
- '3756'
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-project-requests:
|
||||
- '10000'
|
||||
x-ratelimit-limit-requests:
|
||||
- '10000'
|
||||
x-ratelimit-limit-tokens:
|
||||
- '30000000'
|
||||
x-ratelimit-remaining-project-requests:
|
||||
- '9999'
|
||||
x-ratelimit-remaining-requests:
|
||||
- '9999'
|
||||
x-ratelimit-remaining-tokens:
|
||||
- '29999798'
|
||||
x-ratelimit-reset-project-requests:
|
||||
- 6ms
|
||||
x-ratelimit-reset-requests:
|
||||
- 6ms
|
||||
x-ratelimit-reset-tokens:
|
||||
- 0s
|
||||
x-request-id:
|
||||
- req_3676b4edd10244929526ceb64a623a88
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
@@ -0,0 +1,133 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"messages": [{"role": "system", "content": "You are Research Assistant.
|
||||
You are a helpful research assistant.\nYour personal goal is: Find information
|
||||
about the population of Tokyo\nTo give my best complete final answer to the
|
||||
task respond using the exact following format:\n\nThought: I now can give a
|
||||
great answer\nFinal Answer: Your final answer must be the great and the most
|
||||
complete as possible, it must be outcome described.\n\nI MUST use these formats,
|
||||
my job depends on it!"}, {"role": "user", "content": "\nCurrent Task: Find information
|
||||
about the population of Tokyo\n\nThis is the expected criteria for your final
|
||||
answer: The population of Tokyo is 10 million\nyou MUST return the actual complete
|
||||
content as the final answer, not a summary.\n\nBegin! This is VERY important
|
||||
to you, use the tools available and give your best Final Answer, your job depends
|
||||
on it!\n\nThought:"}], "model": "gpt-4o-mini", "stream": false}'
|
||||
headers:
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- gzip, deflate
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '932'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- api.openai.com
|
||||
user-agent:
|
||||
- OpenAI/Python 1.109.1
|
||||
x-stainless-arch:
|
||||
- arm64
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- MacOS
|
||||
x-stainless-package-version:
|
||||
- 1.109.1
|
||||
x-stainless-read-timeout:
|
||||
- '600'
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.13.3
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: !!binary |
|
||||
H4sIAAAAAAAAA4xUTY8bNwy9+1cQcx4vbK+93vXNDdompyJF0Bb5gEFrODPMSqJASXa8wf73QmN7
|
||||
7W1ToJcBxMdHvscR9X0EUHFTraAyPSbjgh2/ef9+8eanj7M/lr8vzduv+eHP/c9/rZ8+2vXTnKu6
|
||||
MGT7lUw6s26MuGApsfgjbJQwUak6Xd5N5tPZ/exuAJw0ZAutC2k8l7Fjz+PZZDYfT5bj6f2J3Qsb
|
||||
itUKPo0AAL4P36LTN/StWsGkPkccxYgdVauXJIBKxZZIhTFyTOhTVV9AIz6RH6S/Ay97MOih4x0B
|
||||
QldkA/q4JwX47H9hjxbWw3kF6wjSgjuAxZgghwYTAXv4zSTZksJsMrutIfUEQUK2WMZRGB/k8SDA
|
||||
ETAElW/sMJE9wHQOjq0tSYEk2KFWYRvySdECKmEN+54tDfFfh6Hqqd76jJoe2BubG4oQs6pk37Dv
|
||||
ICi1ZFJWijX0GAEhJuw60gF9JVF2pHC7PAuqweFjyeI0dHYS05EhOYKjpBLEckI/iDwL34va5gY+
|
||||
FA+cDlfeUyTblhEoedl7aqAVLWHY8VbRJzDZFqk1NLwjjQRkxIs71IC+gcid55ZNyeysbNEC+9Zm
|
||||
8oaODa/8tNwV0+DwAK3NJuXyo5pMkAR2qFxMtGiSaJmY6QEjOO50oNeQdYuen06n0r4hJ51i6NlA
|
||||
UvJNrGGbE3TkSdHaQ10mpeSQfQTxVKyXiVjUjsplKSUBu86Ko2OfeDJiDzfX11OpzRHLivhs7RWA
|
||||
3ks6MstifDkhzy+rYKULKtv4D2rVsufYb5Qwii/XPiYJ1YA+jwC+DCuXX21RFVRcSJskjzS0my5v
|
||||
j/Wqy6ZfobPFCU2S0F6A2cN9/YOCm4YSso1XW1sZND01F+plxTE3LFfA6Mr2v+X8qPbROvvu/5S/
|
||||
AMZQSNRsglLD5rXlS5pSeQn/K+1lzIPgKpLu2NAmMWn5FQ21mO3xfariISZym5Z9RxqUj49UGzaL
|
||||
uwm2d7RYPFSj59HfAAAA//8DAB8kWOqyBQAA
|
||||
headers:
|
||||
CF-RAY:
|
||||
- 98e404605874fad2-SJC
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Encoding:
|
||||
- gzip
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Tue, 14 Oct 2025 03:33:48 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Set-Cookie:
|
||||
- __cf_bm=o5Vy5q.qstP73vjTrIb7GX6EjMltWq26Vk1ctm8rrcQ-1760412828-1.0.1.1-6PmDQhWH5.60C02WBN9ENJiBEZ0hYXY1YJ6TKxTAflRETSCaMVA2j1.xE2KPFpUrsSsmbkopxQ1p2NYmLzuRy08dingIYyz5HZGz8ghl.nM;
|
||||
path=/; expires=Tue, 14-Oct-25 04:03:48 GMT; domain=.api.openai.com; HttpOnly;
|
||||
Secure; SameSite=None
|
||||
- _cfuvid=TkrzMwZH3VZy7i4ED_kVxlx4MUrHeXnluoFfmeqTT2w-1760412828927-0.0.1.1-604800000;
|
||||
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
|
||||
Strict-Transport-Security:
|
||||
- max-age=31536000; includeSubDomains; preload
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- nosniff
|
||||
access-control-expose-headers:
|
||||
- X-Request-ID
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
openai-organization:
|
||||
- crewai-iuxna1
|
||||
openai-processing-ms:
|
||||
- '2644'
|
||||
openai-project:
|
||||
- proj_xitITlrFeen7zjNSzML82h9x
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
x-envoy-upstream-service-time:
|
||||
- '2793'
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-project-tokens:
|
||||
- '150000000'
|
||||
x-ratelimit-limit-requests:
|
||||
- '30000'
|
||||
x-ratelimit-limit-tokens:
|
||||
- '150000000'
|
||||
x-ratelimit-remaining-project-tokens:
|
||||
- '149999797'
|
||||
x-ratelimit-remaining-requests:
|
||||
- '29999'
|
||||
x-ratelimit-remaining-tokens:
|
||||
- '149999797'
|
||||
x-ratelimit-reset-project-tokens:
|
||||
- 0s
|
||||
x-ratelimit-reset-requests:
|
||||
- 2ms
|
||||
x-ratelimit-reset-tokens:
|
||||
- 0s
|
||||
x-request-id:
|
||||
- req_5c4fad6d3e4743d1a43ab65bd333b477
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
@@ -2,6 +2,7 @@ import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import jwt
|
||||
|
||||
from crewai.cli.authentication.utils import validate_jwt_token
|
||||
|
||||
|
||||
@@ -16,19 +17,22 @@ class TestUtils(unittest.TestCase):
|
||||
key="mock_signing_key"
|
||||
)
|
||||
|
||||
jwt_token = "aaaaa.bbbbbb.cccccc" # noqa: S105
|
||||
|
||||
decoded_token = validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc",
|
||||
jwt_token=jwt_token,
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
)
|
||||
|
||||
mock_jwt.decode.assert_called_with(
|
||||
"aaaaa.bbbbbb.cccccc",
|
||||
jwt_token,
|
||||
"mock_signing_key",
|
||||
algorithms=["RS256"],
|
||||
audience="app_id_xxxx",
|
||||
issuer="https://mock_issuer",
|
||||
leeway=10.0,
|
||||
options={
|
||||
"verify_signature": True,
|
||||
"verify_exp": True,
|
||||
@@ -42,9 +46,9 @@ class TestUtils(unittest.TestCase):
|
||||
|
||||
def test_validate_jwt_token_expired(self, mock_jwt, mock_pyjwkclient):
|
||||
mock_jwt.decode.side_effect = jwt.ExpiredSignatureError
|
||||
with self.assertRaises(Exception):
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc",
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
@@ -52,9 +56,9 @@ class TestUtils(unittest.TestCase):
|
||||
|
||||
def test_validate_jwt_token_invalid_audience(self, mock_jwt, mock_pyjwkclient):
|
||||
mock_jwt.decode.side_effect = jwt.InvalidAudienceError
|
||||
with self.assertRaises(Exception):
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc",
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
@@ -62,9 +66,9 @@ class TestUtils(unittest.TestCase):
|
||||
|
||||
def test_validate_jwt_token_invalid_issuer(self, mock_jwt, mock_pyjwkclient):
|
||||
mock_jwt.decode.side_effect = jwt.InvalidIssuerError
|
||||
with self.assertRaises(Exception):
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc",
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
@@ -74,9 +78,9 @@ class TestUtils(unittest.TestCase):
|
||||
self, mock_jwt, mock_pyjwkclient
|
||||
):
|
||||
mock_jwt.decode.side_effect = jwt.MissingRequiredClaimError
|
||||
with self.assertRaises(Exception):
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc",
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
@@ -84,9 +88,9 @@ class TestUtils(unittest.TestCase):
|
||||
|
||||
def test_validate_jwt_token_jwks_error(self, mock_jwt, mock_pyjwkclient):
|
||||
mock_jwt.decode.side_effect = jwt.exceptions.PyJWKClientError
|
||||
with self.assertRaises(Exception):
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc",
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
@@ -94,9 +98,9 @@ class TestUtils(unittest.TestCase):
|
||||
|
||||
def test_validate_jwt_token_invalid_token(self, mock_jwt, mock_pyjwkclient):
|
||||
mock_jwt.decode.side_effect = jwt.InvalidTokenError
|
||||
with self.assertRaises(Exception):
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc",
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
|
||||
@@ -33,7 +33,7 @@ def setup_test_environment():
|
||||
except (OSError, IOError) as e:
|
||||
raise RuntimeError(
|
||||
f"Test storage directory {storage_dir} is not writable: {e}"
|
||||
)
|
||||
) from e
|
||||
|
||||
os.environ["CREWAI_STORAGE_DIR"] = str(storage_dir)
|
||||
os.environ["CREWAI_TESTING"] = "true"
|
||||
@@ -159,6 +159,29 @@ def mock_opentelemetry_components():
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_event_bus_handlers():
|
||||
"""Clear event bus handlers after each test for isolation.
|
||||
|
||||
Handlers registered during the test are allowed to run, then cleaned up
|
||||
after the test completes.
|
||||
"""
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.experimental.evaluation.evaluation_listener import (
|
||||
EvaluationTraceCallback,
|
||||
)
|
||||
|
||||
yield
|
||||
|
||||
crewai_event_bus.shutdown(wait=True)
|
||||
crewai_event_bus._initialize()
|
||||
|
||||
callback = EvaluationTraceCallback()
|
||||
callback.traces.clear()
|
||||
callback.current_agent_id = None
|
||||
callback.current_task_id = None
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def vcr_config(request) -> dict:
|
||||
import os
|
||||
|
||||
286
lib/crewai/tests/events/test_depends.py
Normal file
286
lib/crewai/tests/events/test_depends.py
Normal file
@@ -0,0 +1,286 @@
|
||||
"""Tests for FastAPI-style dependency injection in event handlers."""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.events import Depends, crewai_event_bus
|
||||
from crewai.events.base_events import BaseEvent
|
||||
|
||||
|
||||
class DependsTestEvent(BaseEvent):
|
||||
"""Test event for dependency tests."""
|
||||
|
||||
value: int = 0
|
||||
type: str = "test_event"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_dependency():
|
||||
"""Test that handler with dependency runs after its dependency."""
|
||||
execution_order = []
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(DependsTestEvent)
|
||||
def setup(source, event: DependsTestEvent):
|
||||
execution_order.append("setup")
|
||||
|
||||
@crewai_event_bus.on(DependsTestEvent, Depends(setup))
|
||||
def process(source, event: DependsTestEvent):
|
||||
execution_order.append("process")
|
||||
|
||||
event = DependsTestEvent(value=1)
|
||||
future = crewai_event_bus.emit("test_source", event)
|
||||
|
||||
if future:
|
||||
await asyncio.wrap_future(future)
|
||||
|
||||
assert execution_order == ["setup", "process"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_dependencies():
|
||||
"""Test handler with multiple dependencies."""
|
||||
execution_order = []
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(DependsTestEvent)
|
||||
def setup_a(source, event: DependsTestEvent):
|
||||
execution_order.append("setup_a")
|
||||
|
||||
@crewai_event_bus.on(DependsTestEvent)
|
||||
def setup_b(source, event: DependsTestEvent):
|
||||
execution_order.append("setup_b")
|
||||
|
||||
@crewai_event_bus.on(
|
||||
DependsTestEvent, depends_on=[Depends(setup_a), Depends(setup_b)]
|
||||
)
|
||||
def process(source, event: DependsTestEvent):
|
||||
execution_order.append("process")
|
||||
|
||||
event = DependsTestEvent(value=1)
|
||||
future = crewai_event_bus.emit("test_source", event)
|
||||
|
||||
if future:
|
||||
await asyncio.wrap_future(future)
|
||||
|
||||
# setup_a and setup_b can run in any order (same level)
|
||||
assert "process" in execution_order
|
||||
assert execution_order.index("process") > execution_order.index("setup_a")
|
||||
assert execution_order.index("process") > execution_order.index("setup_b")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chain_of_dependencies():
|
||||
"""Test chain of dependencies (A -> B -> C)."""
|
||||
execution_order = []
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(DependsTestEvent)
|
||||
def handler_a(source, event: DependsTestEvent):
|
||||
execution_order.append("handler_a")
|
||||
|
||||
@crewai_event_bus.on(DependsTestEvent, depends_on=Depends(handler_a))
|
||||
def handler_b(source, event: DependsTestEvent):
|
||||
execution_order.append("handler_b")
|
||||
|
||||
@crewai_event_bus.on(DependsTestEvent, depends_on=Depends(handler_b))
|
||||
def handler_c(source, event: DependsTestEvent):
|
||||
execution_order.append("handler_c")
|
||||
|
||||
event = DependsTestEvent(value=1)
|
||||
future = crewai_event_bus.emit("test_source", event)
|
||||
|
||||
if future:
|
||||
await asyncio.wrap_future(future)
|
||||
|
||||
assert execution_order == ["handler_a", "handler_b", "handler_c"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_handler_with_dependency():
|
||||
"""Test async handler with dependency on sync handler."""
|
||||
execution_order = []
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(DependsTestEvent)
|
||||
def sync_setup(source, event: DependsTestEvent):
|
||||
execution_order.append("sync_setup")
|
||||
|
||||
@crewai_event_bus.on(DependsTestEvent, depends_on=Depends(sync_setup))
|
||||
async def async_process(source, event: DependsTestEvent):
|
||||
await asyncio.sleep(0.01)
|
||||
execution_order.append("async_process")
|
||||
|
||||
event = DependsTestEvent(value=1)
|
||||
future = crewai_event_bus.emit("test_source", event)
|
||||
|
||||
if future:
|
||||
await asyncio.wrap_future(future)
|
||||
|
||||
assert execution_order == ["sync_setup", "async_process"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_handlers_with_dependencies():
|
||||
"""Test mix of sync and async handlers with dependencies."""
|
||||
execution_order = []
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(DependsTestEvent)
|
||||
def setup(source, event: DependsTestEvent):
|
||||
execution_order.append("setup")
|
||||
|
||||
@crewai_event_bus.on(DependsTestEvent, depends_on=Depends(setup))
|
||||
def sync_process(source, event: DependsTestEvent):
|
||||
execution_order.append("sync_process")
|
||||
|
||||
@crewai_event_bus.on(DependsTestEvent, depends_on=Depends(setup))
|
||||
async def async_process(source, event: DependsTestEvent):
|
||||
await asyncio.sleep(0.01)
|
||||
execution_order.append("async_process")
|
||||
|
||||
@crewai_event_bus.on(
|
||||
DependsTestEvent, depends_on=[Depends(sync_process), Depends(async_process)]
|
||||
)
|
||||
def finalize(source, event: DependsTestEvent):
|
||||
execution_order.append("finalize")
|
||||
|
||||
event = DependsTestEvent(value=1)
|
||||
future = crewai_event_bus.emit("test_source", event)
|
||||
|
||||
if future:
|
||||
await asyncio.wrap_future(future)
|
||||
|
||||
# Verify execution order
|
||||
assert execution_order[0] == "setup"
|
||||
assert "finalize" in execution_order
|
||||
assert execution_order.index("finalize") > execution_order.index("sync_process")
|
||||
assert execution_order.index("finalize") > execution_order.index("async_process")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_independent_handlers_run_concurrently():
|
||||
"""Test that handlers without dependencies can run concurrently."""
|
||||
execution_order = []
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(DependsTestEvent)
|
||||
async def handler_a(source, event: DependsTestEvent):
|
||||
await asyncio.sleep(0.01)
|
||||
execution_order.append("handler_a")
|
||||
|
||||
@crewai_event_bus.on(DependsTestEvent)
|
||||
async def handler_b(source, event: DependsTestEvent):
|
||||
await asyncio.sleep(0.01)
|
||||
execution_order.append("handler_b")
|
||||
|
||||
event = DependsTestEvent(value=1)
|
||||
future = crewai_event_bus.emit("test_source", event)
|
||||
|
||||
if future:
|
||||
await asyncio.wrap_future(future)
|
||||
|
||||
# Both handlers should have executed
|
||||
assert len(execution_order) == 2
|
||||
assert "handler_a" in execution_order
|
||||
assert "handler_b" in execution_order
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_circular_dependency_detection():
|
||||
"""Test that circular dependencies are detected and raise an error."""
|
||||
from crewai.events.handler_graph import CircularDependencyError, build_execution_plan
|
||||
|
||||
# Create circular dependency: handler_a -> handler_b -> handler_c -> handler_a
|
||||
def handler_a(source, event: DependsTestEvent):
|
||||
pass
|
||||
|
||||
def handler_b(source, event: DependsTestEvent):
|
||||
pass
|
||||
|
||||
def handler_c(source, event: DependsTestEvent):
|
||||
pass
|
||||
|
||||
# Build a dependency graph with a cycle
|
||||
handlers = [handler_a, handler_b, handler_c]
|
||||
dependencies = {
|
||||
handler_a: [Depends(handler_b)],
|
||||
handler_b: [Depends(handler_c)],
|
||||
handler_c: [Depends(handler_a)], # Creates the cycle
|
||||
}
|
||||
|
||||
# Should raise CircularDependencyError about circular dependency
|
||||
with pytest.raises(CircularDependencyError, match="Circular dependency"):
|
||||
build_execution_plan(handlers, dependencies)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_without_dependency_runs_normally():
|
||||
"""Test that handlers without dependencies still work as before."""
|
||||
execution_order = []
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(DependsTestEvent)
|
||||
def simple_handler(source, event: DependsTestEvent):
|
||||
execution_order.append("simple_handler")
|
||||
|
||||
event = DependsTestEvent(value=1)
|
||||
future = crewai_event_bus.emit("test_source", event)
|
||||
|
||||
if future:
|
||||
await asyncio.wrap_future(future)
|
||||
|
||||
assert execution_order == ["simple_handler"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_depends_equality():
|
||||
"""Test Depends equality and hashing."""
|
||||
|
||||
def handler_a(source, event):
|
||||
pass
|
||||
|
||||
def handler_b(source, event):
|
||||
pass
|
||||
|
||||
dep_a1 = Depends(handler_a)
|
||||
dep_a2 = Depends(handler_a)
|
||||
dep_b = Depends(handler_b)
|
||||
|
||||
# Same handler should be equal
|
||||
assert dep_a1 == dep_a2
|
||||
assert hash(dep_a1) == hash(dep_a2)
|
||||
|
||||
# Different handlers should not be equal
|
||||
assert dep_a1 != dep_b
|
||||
assert hash(dep_a1) != hash(dep_b)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aemit_ignores_dependencies():
|
||||
"""Test that aemit only processes async handlers (no dependency support yet)."""
|
||||
execution_order = []
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(DependsTestEvent)
|
||||
def sync_handler(source, event: DependsTestEvent):
|
||||
execution_order.append("sync_handler")
|
||||
|
||||
@crewai_event_bus.on(DependsTestEvent)
|
||||
async def async_handler(source, event: DependsTestEvent):
|
||||
execution_order.append("async_handler")
|
||||
|
||||
event = DependsTestEvent(value=1)
|
||||
await crewai_event_bus.aemit("test_source", event)
|
||||
|
||||
# Only async handler should execute
|
||||
assert execution_order == ["async_handler"]
|
||||
@@ -1,3 +1,5 @@
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
from crewai.agent import Agent
|
||||
from crewai.crew import Crew
|
||||
@@ -19,7 +21,10 @@ from crewai.experimental.evaluation import (
|
||||
create_default_evaluator,
|
||||
)
|
||||
from crewai.experimental.evaluation.agent_evaluator import AgentEvaluator
|
||||
from crewai.experimental.evaluation.base_evaluator import AgentEvaluationResult
|
||||
from crewai.experimental.evaluation.base_evaluator import (
|
||||
AgentEvaluationResult,
|
||||
BaseEvaluator,
|
||||
)
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
@@ -51,12 +56,25 @@ class TestAgentEvaluator:
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_evaluate_current_iteration(self, mock_crew):
|
||||
from crewai.events.types.task_events import TaskCompletedEvent
|
||||
|
||||
agent_evaluator = AgentEvaluator(
|
||||
agents=mock_crew.agents, evaluators=[GoalAlignmentEvaluator()]
|
||||
)
|
||||
|
||||
task_completed_event = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(TaskCompletedEvent)
|
||||
async def on_task_completed(source, event):
|
||||
# TaskCompletedEvent fires AFTER evaluation results are stored
|
||||
task_completed_event.set()
|
||||
|
||||
mock_crew.kickoff()
|
||||
|
||||
assert task_completed_event.wait(timeout=5), (
|
||||
"Timeout waiting for task completion"
|
||||
)
|
||||
|
||||
results = agent_evaluator.get_evaluation_results()
|
||||
|
||||
assert isinstance(results, dict)
|
||||
@@ -98,73 +116,15 @@ class TestAgentEvaluator:
|
||||
]
|
||||
|
||||
assert len(agent_evaluator.evaluators) == len(expected_types)
|
||||
for evaluator, expected_type in zip(agent_evaluator.evaluators, expected_types):
|
||||
for evaluator, expected_type in zip(
|
||||
agent_evaluator.evaluators, expected_types, strict=False
|
||||
):
|
||||
assert isinstance(evaluator, expected_type)
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_eval_lite_agent(self):
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Complete test tasks successfully",
|
||||
backstory="An agent created for testing purposes",
|
||||
)
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
events = {}
|
||||
|
||||
@crewai_event_bus.on(AgentEvaluationStartedEvent)
|
||||
def capture_started(source, event):
|
||||
events["started"] = event
|
||||
|
||||
@crewai_event_bus.on(AgentEvaluationCompletedEvent)
|
||||
def capture_completed(source, event):
|
||||
events["completed"] = event
|
||||
|
||||
@crewai_event_bus.on(AgentEvaluationFailedEvent)
|
||||
def capture_failed(source, event):
|
||||
events["failed"] = event
|
||||
|
||||
agent_evaluator = AgentEvaluator(
|
||||
agents=[agent], evaluators=[GoalAlignmentEvaluator()]
|
||||
)
|
||||
|
||||
agent.kickoff(messages="Complete this task successfully")
|
||||
|
||||
assert events.keys() == {"started", "completed"}
|
||||
assert events["started"].agent_id == str(agent.id)
|
||||
assert events["started"].agent_role == agent.role
|
||||
assert events["started"].task_id is None
|
||||
assert events["started"].iteration == 1
|
||||
|
||||
assert events["completed"].agent_id == str(agent.id)
|
||||
assert events["completed"].agent_role == agent.role
|
||||
assert events["completed"].task_id is None
|
||||
assert events["completed"].iteration == 1
|
||||
assert events["completed"].metric_category == MetricCategory.GOAL_ALIGNMENT
|
||||
assert isinstance(events["completed"].score, EvaluationScore)
|
||||
assert events["completed"].score.score == 2.0
|
||||
|
||||
results = agent_evaluator.get_evaluation_results()
|
||||
|
||||
assert isinstance(results, dict)
|
||||
|
||||
(result,) = results[agent.role]
|
||||
assert isinstance(result, AgentEvaluationResult)
|
||||
|
||||
assert result.agent_id == str(agent.id)
|
||||
assert result.task_id == "lite_task"
|
||||
|
||||
(goal_alignment,) = result.metrics.values()
|
||||
assert goal_alignment.score == 2.0
|
||||
|
||||
expected_feedback = "The agent did not demonstrate a clear understanding of the task goal, which is to complete test tasks successfully"
|
||||
assert expected_feedback in goal_alignment.feedback
|
||||
|
||||
assert goal_alignment.raw_response is not None
|
||||
assert '"score": 2' in goal_alignment.raw_response
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_eval_specific_agents_from_crew(self, mock_crew):
|
||||
from crewai.events.types.task_events import TaskCompletedEvent
|
||||
|
||||
agent = Agent(
|
||||
role="Test Agent Eval",
|
||||
goal="Complete test tasks successfully",
|
||||
@@ -178,111 +138,132 @@ class TestAgentEvaluator:
|
||||
mock_crew.agents.append(agent)
|
||||
mock_crew.tasks.append(task)
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
events = {}
|
||||
events = {}
|
||||
started_event = threading.Event()
|
||||
completed_event = threading.Event()
|
||||
task_completed_event = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(AgentEvaluationStartedEvent)
|
||||
def capture_started(source, event):
|
||||
agent_evaluator = AgentEvaluator(
|
||||
agents=[agent], evaluators=[GoalAlignmentEvaluator()]
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(AgentEvaluationStartedEvent)
|
||||
async def capture_started(source, event):
|
||||
if event.agent_id == str(agent.id):
|
||||
events["started"] = event
|
||||
started_event.set()
|
||||
|
||||
@crewai_event_bus.on(AgentEvaluationCompletedEvent)
|
||||
def capture_completed(source, event):
|
||||
@crewai_event_bus.on(AgentEvaluationCompletedEvent)
|
||||
async def capture_completed(source, event):
|
||||
if event.agent_id == str(agent.id):
|
||||
events["completed"] = event
|
||||
completed_event.set()
|
||||
|
||||
@crewai_event_bus.on(AgentEvaluationFailedEvent)
|
||||
def capture_failed(source, event):
|
||||
events["failed"] = event
|
||||
@crewai_event_bus.on(AgentEvaluationFailedEvent)
|
||||
def capture_failed(source, event):
|
||||
events["failed"] = event
|
||||
|
||||
agent_evaluator = AgentEvaluator(
|
||||
agents=[agent], evaluators=[GoalAlignmentEvaluator()]
|
||||
)
|
||||
mock_crew.kickoff()
|
||||
@crewai_event_bus.on(TaskCompletedEvent)
|
||||
async def on_task_completed(source, event):
|
||||
# TaskCompletedEvent fires AFTER evaluation results are stored
|
||||
if event.task and event.task.id == task.id:
|
||||
task_completed_event.set()
|
||||
|
||||
assert events.keys() == {"started", "completed"}
|
||||
assert events["started"].agent_id == str(agent.id)
|
||||
assert events["started"].agent_role == agent.role
|
||||
assert events["started"].task_id == str(task.id)
|
||||
assert events["started"].iteration == 1
|
||||
mock_crew.kickoff()
|
||||
|
||||
assert events["completed"].agent_id == str(agent.id)
|
||||
assert events["completed"].agent_role == agent.role
|
||||
assert events["completed"].task_id == str(task.id)
|
||||
assert events["completed"].iteration == 1
|
||||
assert events["completed"].metric_category == MetricCategory.GOAL_ALIGNMENT
|
||||
assert isinstance(events["completed"].score, EvaluationScore)
|
||||
assert events["completed"].score.score == 5.0
|
||||
assert started_event.wait(timeout=5), "Timeout waiting for started event"
|
||||
assert completed_event.wait(timeout=5), "Timeout waiting for completed event"
|
||||
assert task_completed_event.wait(timeout=5), (
|
||||
"Timeout waiting for task completion"
|
||||
)
|
||||
|
||||
results = agent_evaluator.get_evaluation_results()
|
||||
assert events.keys() == {"started", "completed"}
|
||||
assert events["started"].agent_id == str(agent.id)
|
||||
assert events["started"].agent_role == agent.role
|
||||
assert events["started"].task_id == str(task.id)
|
||||
assert events["started"].iteration == 1
|
||||
|
||||
assert isinstance(results, dict)
|
||||
assert len(results.keys()) == 1
|
||||
(result,) = results[agent.role]
|
||||
assert isinstance(result, AgentEvaluationResult)
|
||||
assert events["completed"].agent_id == str(agent.id)
|
||||
assert events["completed"].agent_role == agent.role
|
||||
assert events["completed"].task_id == str(task.id)
|
||||
assert events["completed"].iteration == 1
|
||||
assert events["completed"].metric_category == MetricCategory.GOAL_ALIGNMENT
|
||||
assert isinstance(events["completed"].score, EvaluationScore)
|
||||
assert events["completed"].score.score == 5.0
|
||||
|
||||
assert result.agent_id == str(agent.id)
|
||||
assert result.task_id == str(task.id)
|
||||
results = agent_evaluator.get_evaluation_results()
|
||||
|
||||
(goal_alignment,) = result.metrics.values()
|
||||
assert goal_alignment.score == 5.0
|
||||
assert isinstance(results, dict)
|
||||
assert len(results.keys()) == 1
|
||||
(result,) = results[agent.role]
|
||||
assert isinstance(result, AgentEvaluationResult)
|
||||
|
||||
expected_feedback = "The agent provided a thorough guide on how to conduct a test task but failed to produce specific expected output"
|
||||
assert expected_feedback in goal_alignment.feedback
|
||||
assert result.agent_id == str(agent.id)
|
||||
assert result.task_id == str(task.id)
|
||||
|
||||
assert goal_alignment.raw_response is not None
|
||||
assert '"score": 5' in goal_alignment.raw_response
|
||||
(goal_alignment,) = result.metrics.values()
|
||||
assert goal_alignment.score == 5.0
|
||||
|
||||
expected_feedback = "The agent provided a thorough guide on how to conduct a test task but failed to produce specific expected output"
|
||||
assert expected_feedback in goal_alignment.feedback
|
||||
|
||||
assert goal_alignment.raw_response is not None
|
||||
assert '"score": 5' in goal_alignment.raw_response
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_failed_evaluation(self, mock_crew):
|
||||
(agent,) = mock_crew.agents
|
||||
(task,) = mock_crew.tasks
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
events = {}
|
||||
events = {}
|
||||
started_event = threading.Event()
|
||||
failed_event = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(AgentEvaluationStartedEvent)
|
||||
def capture_started(source, event):
|
||||
events["started"] = event
|
||||
@crewai_event_bus.on(AgentEvaluationStartedEvent)
|
||||
def capture_started(source, event):
|
||||
events["started"] = event
|
||||
started_event.set()
|
||||
|
||||
@crewai_event_bus.on(AgentEvaluationCompletedEvent)
|
||||
def capture_completed(source, event):
|
||||
events["completed"] = event
|
||||
@crewai_event_bus.on(AgentEvaluationCompletedEvent)
|
||||
def capture_completed(source, event):
|
||||
events["completed"] = event
|
||||
|
||||
@crewai_event_bus.on(AgentEvaluationFailedEvent)
|
||||
def capture_failed(source, event):
|
||||
events["failed"] = event
|
||||
@crewai_event_bus.on(AgentEvaluationFailedEvent)
|
||||
def capture_failed(source, event):
|
||||
events["failed"] = event
|
||||
failed_event.set()
|
||||
|
||||
# Create a mock evaluator that will raise an exception
|
||||
from crewai.experimental.evaluation import MetricCategory
|
||||
from crewai.experimental.evaluation.base_evaluator import BaseEvaluator
|
||||
class FailingEvaluator(BaseEvaluator):
|
||||
metric_category = MetricCategory.GOAL_ALIGNMENT
|
||||
|
||||
class FailingEvaluator(BaseEvaluator):
|
||||
metric_category = MetricCategory.GOAL_ALIGNMENT
|
||||
def evaluate(self, agent, task, execution_trace, final_output):
|
||||
raise ValueError("Forced evaluation failure")
|
||||
|
||||
def evaluate(self, agent, task, execution_trace, final_output):
|
||||
raise ValueError("Forced evaluation failure")
|
||||
agent_evaluator = AgentEvaluator(
|
||||
agents=[agent], evaluators=[FailingEvaluator()]
|
||||
)
|
||||
mock_crew.kickoff()
|
||||
|
||||
agent_evaluator = AgentEvaluator(
|
||||
agents=[agent], evaluators=[FailingEvaluator()]
|
||||
)
|
||||
mock_crew.kickoff()
|
||||
assert started_event.wait(timeout=5), "Timeout waiting for started event"
|
||||
assert failed_event.wait(timeout=5), "Timeout waiting for failed event"
|
||||
|
||||
assert events.keys() == {"started", "failed"}
|
||||
assert events["started"].agent_id == str(agent.id)
|
||||
assert events["started"].agent_role == agent.role
|
||||
assert events["started"].task_id == str(task.id)
|
||||
assert events["started"].iteration == 1
|
||||
assert events.keys() == {"started", "failed"}
|
||||
assert events["started"].agent_id == str(agent.id)
|
||||
assert events["started"].agent_role == agent.role
|
||||
assert events["started"].task_id == str(task.id)
|
||||
assert events["started"].iteration == 1
|
||||
|
||||
assert events["failed"].agent_id == str(agent.id)
|
||||
assert events["failed"].agent_role == agent.role
|
||||
assert events["failed"].task_id == str(task.id)
|
||||
assert events["failed"].iteration == 1
|
||||
assert events["failed"].error == "Forced evaluation failure"
|
||||
assert events["failed"].agent_id == str(agent.id)
|
||||
assert events["failed"].agent_role == agent.role
|
||||
assert events["failed"].task_id == str(task.id)
|
||||
assert events["failed"].iteration == 1
|
||||
assert events["failed"].error == "Forced evaluation failure"
|
||||
|
||||
results = agent_evaluator.get_evaluation_results()
|
||||
(result,) = results[agent.role]
|
||||
assert isinstance(result, AgentEvaluationResult)
|
||||
results = agent_evaluator.get_evaluation_results()
|
||||
(result,) = results[agent.role]
|
||||
assert isinstance(result, AgentEvaluationResult)
|
||||
|
||||
assert result.agent_id == str(agent.id)
|
||||
assert result.task_id == str(task.id)
|
||||
assert result.agent_id == str(agent.id)
|
||||
assert result.task_id == str(task.id)
|
||||
|
||||
assert result.metrics == {}
|
||||
assert result.metrics == {}
|
||||
|
||||
665
lib/crewai/tests/llms/anthropic/test_anthropic.py
Normal file
665
lib/crewai/tests/llms/anthropic/test_anthropic.py
Normal file
@@ -0,0 +1,665 @@
|
||||
import os
|
||||
import sys
|
||||
import types
|
||||
from unittest.mock import patch, MagicMock
|
||||
import pytest
|
||||
|
||||
from crewai.llm import LLM
|
||||
from crewai.crew import Crew
|
||||
from crewai.agent import Agent
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_anthropic_api_key():
|
||||
"""Automatically mock ANTHROPIC_API_KEY for all tests in this module."""
|
||||
with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
|
||||
yield
|
||||
|
||||
|
||||
def test_anthropic_completion_is_used_when_anthropic_provider():
|
||||
"""
|
||||
Test that AnthropicCompletion from completion.py is used when LLM uses provider 'anthropic'
|
||||
"""
|
||||
llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
|
||||
|
||||
assert llm.__class__.__name__ == "AnthropicCompletion"
|
||||
assert llm.provider == "anthropic"
|
||||
assert llm.model == "claude-3-5-sonnet-20241022"
|
||||
|
||||
|
||||
def test_anthropic_completion_is_used_when_claude_provider():
|
||||
"""
|
||||
Test that AnthropicCompletion is used when provider is 'claude'
|
||||
"""
|
||||
llm = LLM(model="claude/claude-3-5-sonnet-20241022")
|
||||
|
||||
from crewai.llms.providers.anthropic.completion import AnthropicCompletion
|
||||
assert isinstance(llm, AnthropicCompletion)
|
||||
assert llm.provider == "claude"
|
||||
assert llm.model == "claude-3-5-sonnet-20241022"
|
||||
|
||||
|
||||
|
||||
|
||||
def test_anthropic_tool_use_conversation_flow():
|
||||
"""
|
||||
Test that the Anthropic completion properly handles tool use conversation flow
|
||||
"""
|
||||
from unittest.mock import Mock, patch
|
||||
from crewai.llms.providers.anthropic.completion import AnthropicCompletion
|
||||
from anthropic.types.tool_use_block import ToolUseBlock
|
||||
|
||||
# Create AnthropicCompletion instance
|
||||
completion = AnthropicCompletion(model="claude-3-5-sonnet-20241022")
|
||||
|
||||
# Mock tool function
|
||||
def mock_weather_tool(location: str) -> str:
|
||||
return f"The weather in {location} is sunny and 75°F"
|
||||
|
||||
available_functions = {"get_weather": mock_weather_tool}
|
||||
|
||||
# Mock the Anthropic client responses
|
||||
with patch.object(completion.client.messages, 'create') as mock_create:
|
||||
# Mock initial response with tool use - need to properly mock ToolUseBlock
|
||||
mock_tool_use = Mock(spec=ToolUseBlock)
|
||||
mock_tool_use.id = "tool_123"
|
||||
mock_tool_use.name = "get_weather"
|
||||
mock_tool_use.input = {"location": "San Francisco"}
|
||||
|
||||
mock_initial_response = Mock()
|
||||
mock_initial_response.content = [mock_tool_use]
|
||||
mock_initial_response.usage = Mock()
|
||||
mock_initial_response.usage.input_tokens = 100
|
||||
mock_initial_response.usage.output_tokens = 50
|
||||
|
||||
# Mock final response after tool result - properly mock text content
|
||||
mock_text_block = Mock()
|
||||
# Set the text attribute as a string, not another Mock
|
||||
mock_text_block.configure_mock(text="Based on the weather data, it's a beautiful day in San Francisco with sunny skies and 75°F temperature.")
|
||||
|
||||
mock_final_response = Mock()
|
||||
mock_final_response.content = [mock_text_block]
|
||||
mock_final_response.usage = Mock()
|
||||
mock_final_response.usage.input_tokens = 150
|
||||
mock_final_response.usage.output_tokens = 75
|
||||
|
||||
# Configure mock to return different responses on successive calls
|
||||
mock_create.side_effect = [mock_initial_response, mock_final_response]
|
||||
|
||||
# Test the call
|
||||
messages = [{"role": "user", "content": "What's the weather like in San Francisco?"}]
|
||||
result = completion.call(
|
||||
messages=messages,
|
||||
available_functions=available_functions
|
||||
)
|
||||
|
||||
# Verify the result contains the final response
|
||||
assert "beautiful day in San Francisco" in result
|
||||
assert "sunny skies" in result
|
||||
assert "75°F" in result
|
||||
|
||||
# Verify that two API calls were made (initial + follow-up)
|
||||
assert mock_create.call_count == 2
|
||||
|
||||
# Verify the second call includes tool results
|
||||
second_call_args = mock_create.call_args_list[1][1] # kwargs of second call
|
||||
messages_in_second_call = second_call_args["messages"]
|
||||
|
||||
# Should have original user message + assistant tool use + user tool result
|
||||
assert len(messages_in_second_call) == 3
|
||||
assert messages_in_second_call[0]["role"] == "user"
|
||||
assert messages_in_second_call[1]["role"] == "assistant"
|
||||
assert messages_in_second_call[2]["role"] == "user"
|
||||
|
||||
# Verify tool result format
|
||||
tool_result = messages_in_second_call[2]["content"][0]
|
||||
assert tool_result["type"] == "tool_result"
|
||||
assert tool_result["tool_use_id"] == "tool_123"
|
||||
assert "sunny and 75°F" in tool_result["content"]
|
||||
|
||||
|
||||
def test_anthropic_completion_module_is_imported():
|
||||
"""
|
||||
Test that the completion module is properly imported when using Anthropic provider
|
||||
"""
|
||||
module_name = "crewai.llms.providers.anthropic.completion"
|
||||
|
||||
# Remove module from cache if it exists
|
||||
if module_name in sys.modules:
|
||||
del sys.modules[module_name]
|
||||
|
||||
# Create LLM instance - this should trigger the import
|
||||
LLM(model="anthropic/claude-3-5-sonnet-20241022")
|
||||
|
||||
# Verify the module was imported
|
||||
assert module_name in sys.modules
|
||||
completion_mod = sys.modules[module_name]
|
||||
assert isinstance(completion_mod, types.ModuleType)
|
||||
|
||||
# Verify the class exists in the module
|
||||
assert hasattr(completion_mod, 'AnthropicCompletion')
|
||||
|
||||
|
||||
def test_fallback_to_litellm_when_native_anthropic_fails():
|
||||
"""
|
||||
Test that LLM falls back to LiteLLM when native Anthropic completion fails
|
||||
"""
|
||||
# Mock the _get_native_provider to return a failing class
|
||||
with patch('crewai.llm.LLM._get_native_provider') as mock_get_provider:
|
||||
|
||||
class FailingCompletion:
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise Exception("Native Anthropic SDK failed")
|
||||
|
||||
mock_get_provider.return_value = FailingCompletion
|
||||
|
||||
# This should fall back to LiteLLM
|
||||
llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
|
||||
|
||||
# Check that it's using LiteLLM
|
||||
assert hasattr(llm, 'is_litellm')
|
||||
assert llm.is_litellm == True
|
||||
|
||||
|
||||
def test_anthropic_completion_initialization_parameters():
|
||||
"""
|
||||
Test that AnthropicCompletion is initialized with correct parameters
|
||||
"""
|
||||
llm = LLM(
|
||||
model="anthropic/claude-3-5-sonnet-20241022",
|
||||
temperature=0.7,
|
||||
max_tokens=2000,
|
||||
top_p=0.9,
|
||||
api_key="test-key"
|
||||
)
|
||||
|
||||
from crewai.llms.providers.anthropic.completion import AnthropicCompletion
|
||||
assert isinstance(llm, AnthropicCompletion)
|
||||
assert llm.model == "claude-3-5-sonnet-20241022"
|
||||
assert llm.temperature == 0.7
|
||||
assert llm.max_tokens == 2000
|
||||
assert llm.top_p == 0.9
|
||||
|
||||
|
||||
def test_anthropic_specific_parameters():
|
||||
"""
|
||||
Test Anthropic-specific parameters like stop_sequences and streaming
|
||||
"""
|
||||
llm = LLM(
|
||||
model="anthropic/claude-3-5-sonnet-20241022",
|
||||
stop_sequences=["Human:", "Assistant:"],
|
||||
stream=True,
|
||||
max_retries=5,
|
||||
timeout=60
|
||||
)
|
||||
|
||||
from crewai.llms.providers.anthropic.completion import AnthropicCompletion
|
||||
assert isinstance(llm, AnthropicCompletion)
|
||||
assert llm.stop_sequences == ["Human:", "Assistant:"]
|
||||
assert llm.stream == True
|
||||
assert llm.client.max_retries == 5
|
||||
assert llm.client.timeout == 60
|
||||
|
||||
|
||||
def test_anthropic_completion_call():
|
||||
"""
|
||||
Test that AnthropicCompletion call method works
|
||||
"""
|
||||
llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
|
||||
|
||||
# Mock the call method on the instance
|
||||
with patch.object(llm, 'call', return_value="Hello! I'm Claude, ready to help.") as mock_call:
|
||||
result = llm.call("Hello, how are you?")
|
||||
|
||||
assert result == "Hello! I'm Claude, ready to help."
|
||||
mock_call.assert_called_once_with("Hello, how are you?")
|
||||
|
||||
|
||||
def test_anthropic_completion_called_during_crew_execution():
|
||||
"""
|
||||
Test that AnthropicCompletion.call is actually invoked when running a crew
|
||||
"""
|
||||
# Create the LLM instance first
|
||||
anthropic_llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
|
||||
|
||||
# Mock the call method on the specific instance
|
||||
with patch.object(anthropic_llm, 'call', return_value="Tokyo has 14 million people.") as mock_call:
|
||||
|
||||
# Create agent with explicit LLM configuration
|
||||
agent = Agent(
|
||||
role="Research Assistant",
|
||||
goal="Find population info",
|
||||
backstory="You research populations.",
|
||||
llm=anthropic_llm,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Find Tokyo population",
|
||||
expected_output="Population number",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
result = crew.kickoff()
|
||||
|
||||
# Verify mock was called
|
||||
assert mock_call.called
|
||||
assert "14 million" in str(result)
|
||||
|
||||
|
||||
def test_anthropic_completion_call_arguments():
|
||||
"""
|
||||
Test that AnthropicCompletion.call is invoked with correct arguments
|
||||
"""
|
||||
# Create LLM instance first
|
||||
anthropic_llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
|
||||
|
||||
# Mock the instance method
|
||||
with patch.object(anthropic_llm, 'call') as mock_call:
|
||||
mock_call.return_value = "Task completed successfully."
|
||||
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Complete a simple task",
|
||||
backstory="You are a test agent.",
|
||||
llm=anthropic_llm # Use same instance
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Say hello world",
|
||||
expected_output="Hello world",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
crew.kickoff()
|
||||
|
||||
# Verify call was made
|
||||
assert mock_call.called
|
||||
|
||||
# Check the arguments passed to the call method
|
||||
call_args = mock_call.call_args
|
||||
assert call_args is not None
|
||||
|
||||
# The first argument should be the messages
|
||||
messages = call_args[0][0] # First positional argument
|
||||
assert isinstance(messages, (str, list))
|
||||
|
||||
# Verify that the task description appears in the messages
|
||||
if isinstance(messages, str):
|
||||
assert "hello world" in messages.lower()
|
||||
elif isinstance(messages, list):
|
||||
message_content = str(messages).lower()
|
||||
assert "hello world" in message_content
|
||||
|
||||
|
||||
def test_multiple_anthropic_calls_in_crew():
|
||||
"""
|
||||
Test that AnthropicCompletion.call is invoked multiple times for multiple tasks
|
||||
"""
|
||||
# Create LLM instance first
|
||||
anthropic_llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
|
||||
|
||||
# Mock the instance method
|
||||
with patch.object(anthropic_llm, 'call') as mock_call:
|
||||
mock_call.return_value = "Task completed."
|
||||
|
||||
agent = Agent(
|
||||
role="Multi-task Agent",
|
||||
goal="Complete multiple tasks",
|
||||
backstory="You can handle multiple tasks.",
|
||||
llm=anthropic_llm # Use same instance
|
||||
)
|
||||
|
||||
task1 = Task(
|
||||
description="First task",
|
||||
expected_output="First result",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
task2 = Task(
|
||||
description="Second task",
|
||||
expected_output="Second result",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
crew = Crew(
|
||||
agents=[agent],
|
||||
tasks=[task1, task2]
|
||||
)
|
||||
crew.kickoff()
|
||||
|
||||
# Verify multiple calls were made
|
||||
assert mock_call.call_count >= 2 # At least one call per task
|
||||
|
||||
# Verify each call had proper arguments
|
||||
for call in mock_call.call_args_list:
|
||||
assert len(call[0]) > 0 # Has positional arguments
|
||||
messages = call[0][0]
|
||||
assert messages is not None
|
||||
|
||||
|
||||
def test_anthropic_completion_with_tools():
|
||||
"""
|
||||
Test that AnthropicCompletion.call is invoked with tools when agent has tools
|
||||
"""
|
||||
from crewai.tools import tool
|
||||
|
||||
@tool
|
||||
def sample_tool(query: str) -> str:
|
||||
"""A sample tool for testing"""
|
||||
return f"Tool result for: {query}"
|
||||
|
||||
# Create LLM instance first
|
||||
anthropic_llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
|
||||
|
||||
# Mock the instance method
|
||||
with patch.object(anthropic_llm, 'call') as mock_call:
|
||||
mock_call.return_value = "Task completed with tools."
|
||||
|
||||
agent = Agent(
|
||||
role="Tool User",
|
||||
goal="Use tools to complete tasks",
|
||||
backstory="You can use tools.",
|
||||
llm=anthropic_llm, # Use same instance
|
||||
tools=[sample_tool]
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Use the sample tool",
|
||||
expected_output="Tool usage result",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
crew.kickoff()
|
||||
|
||||
assert mock_call.called
|
||||
|
||||
call_args = mock_call.call_args
|
||||
call_kwargs = call_args[1] if len(call_args) > 1 else {}
|
||||
|
||||
if 'tools' in call_kwargs:
|
||||
assert call_kwargs['tools'] is not None
|
||||
assert len(call_kwargs['tools']) > 0
|
||||
|
||||
|
||||
def test_anthropic_raises_error_when_model_not_supported():
|
||||
"""Test that AnthropicCompletion raises ValueError when model not supported"""
|
||||
|
||||
# Mock the Anthropic client to raise an error
|
||||
with patch('crewai.llms.providers.anthropic.completion.Anthropic') as mock_anthropic_class:
|
||||
mock_client = MagicMock()
|
||||
mock_anthropic_class.return_value = mock_client
|
||||
|
||||
# Mock the error that Anthropic would raise for unsupported models
|
||||
from anthropic import NotFoundError
|
||||
mock_client.messages.create.side_effect = NotFoundError(
|
||||
message="The model `model-doesnt-exist` does not exist",
|
||||
response=MagicMock(),
|
||||
body={}
|
||||
)
|
||||
|
||||
llm = LLM(model="anthropic/model-doesnt-exist")
|
||||
|
||||
with pytest.raises(Exception): # Should raise some error for unsupported model
|
||||
llm.call("Hello")
|
||||
|
||||
|
||||
def test_anthropic_client_params_setup():
|
||||
"""
|
||||
Test that client_params are properly merged with default client parameters
|
||||
"""
|
||||
# Use only valid Anthropic client parameters
|
||||
custom_client_params = {
|
||||
"default_headers": {"X-Custom-Header": "test-value"},
|
||||
}
|
||||
|
||||
with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
|
||||
llm = LLM(
|
||||
model="anthropic/claude-3-5-sonnet-20241022",
|
||||
api_key="test-key",
|
||||
base_url="https://custom-api.com",
|
||||
timeout=45,
|
||||
max_retries=5,
|
||||
client_params=custom_client_params
|
||||
)
|
||||
|
||||
from crewai.llms.providers.anthropic.completion import AnthropicCompletion
|
||||
assert isinstance(llm, AnthropicCompletion)
|
||||
|
||||
assert llm.client_params == custom_client_params
|
||||
|
||||
merged_params = llm._get_client_params()
|
||||
|
||||
assert merged_params["api_key"] == "test-key"
|
||||
assert merged_params["base_url"] == "https://custom-api.com"
|
||||
assert merged_params["timeout"] == 45
|
||||
assert merged_params["max_retries"] == 5
|
||||
|
||||
assert merged_params["default_headers"] == {"X-Custom-Header": "test-value"}
|
||||
|
||||
|
||||
def test_anthropic_client_params_override_defaults():
|
||||
"""
|
||||
Test that client_params can override default client parameters
|
||||
"""
|
||||
override_client_params = {
|
||||
"timeout": 120, # Override the timeout parameter
|
||||
"max_retries": 10, # Override the max_retries parameter
|
||||
"default_headers": {"X-Override": "true"} # Valid custom parameter
|
||||
}
|
||||
|
||||
with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
|
||||
llm = LLM(
|
||||
model="anthropic/claude-3-5-sonnet-20241022",
|
||||
api_key="test-key",
|
||||
timeout=30,
|
||||
max_retries=3,
|
||||
client_params=override_client_params
|
||||
)
|
||||
|
||||
# Verify this is actually AnthropicCompletion, not LiteLLM fallback
|
||||
from crewai.llms.providers.anthropic.completion import AnthropicCompletion
|
||||
assert isinstance(llm, AnthropicCompletion)
|
||||
|
||||
merged_params = llm._get_client_params()
|
||||
|
||||
# client_params should override the individual parameters
|
||||
assert merged_params["timeout"] == 120
|
||||
assert merged_params["max_retries"] == 10
|
||||
assert merged_params["default_headers"] == {"X-Override": "true"}
|
||||
|
||||
|
||||
def test_anthropic_client_params_none():
|
||||
"""
|
||||
Test that client_params=None works correctly (no additional parameters)
|
||||
"""
|
||||
with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
|
||||
llm = LLM(
|
||||
model="anthropic/claude-3-5-sonnet-20241022",
|
||||
api_key="test-key",
|
||||
base_url="https://api.anthropic.com",
|
||||
timeout=60,
|
||||
max_retries=2,
|
||||
client_params=None
|
||||
)
|
||||
|
||||
from crewai.llms.providers.anthropic.completion import AnthropicCompletion
|
||||
assert isinstance(llm, AnthropicCompletion)
|
||||
|
||||
assert llm.client_params is None
|
||||
|
||||
merged_params = llm._get_client_params()
|
||||
|
||||
expected_keys = {"api_key", "base_url", "timeout", "max_retries"}
|
||||
assert set(merged_params.keys()) == expected_keys
|
||||
|
||||
# Fixed assertions - all should be inside the with block and use correct values
|
||||
assert merged_params["api_key"] == "test-key" # Not "test-anthropic-key"
|
||||
assert merged_params["base_url"] == "https://api.anthropic.com"
|
||||
assert merged_params["timeout"] == 60
|
||||
assert merged_params["max_retries"] == 2
|
||||
|
||||
|
||||
def test_anthropic_client_params_empty_dict():
|
||||
"""
|
||||
Test that client_params={} works correctly (empty additional parameters)
|
||||
"""
|
||||
with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
|
||||
llm = LLM(
|
||||
model="anthropic/claude-3-5-sonnet-20241022",
|
||||
api_key="test-key",
|
||||
client_params={}
|
||||
)
|
||||
|
||||
from crewai.llms.providers.anthropic.completion import AnthropicCompletion
|
||||
assert isinstance(llm, AnthropicCompletion)
|
||||
|
||||
assert llm.client_params == {}
|
||||
|
||||
merged_params = llm._get_client_params()
|
||||
|
||||
assert "api_key" in merged_params
|
||||
assert merged_params["api_key"] == "test-key"
|
||||
|
||||
|
||||
def test_anthropic_model_detection():
|
||||
"""
|
||||
Test that various Anthropic model formats are properly detected
|
||||
"""
|
||||
# Test Anthropic model naming patterns that actually work with provider detection
|
||||
anthropic_test_cases = [
|
||||
"anthropic/claude-3-5-sonnet-20241022",
|
||||
"claude/claude-3-5-sonnet-20241022"
|
||||
]
|
||||
|
||||
for model_name in anthropic_test_cases:
|
||||
llm = LLM(model=model_name)
|
||||
from crewai.llms.providers.anthropic.completion import AnthropicCompletion
|
||||
assert isinstance(llm, AnthropicCompletion), f"Failed for model: {model_name}"
|
||||
|
||||
|
||||
def test_anthropic_supports_stop_words():
|
||||
"""
|
||||
Test that Anthropic models support stop sequences
|
||||
"""
|
||||
llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
|
||||
assert llm.supports_stop_words() == True
|
||||
|
||||
|
||||
def test_anthropic_context_window_size():
|
||||
"""
|
||||
Test that Anthropic models return correct context window sizes
|
||||
"""
|
||||
llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
|
||||
context_size = llm.get_context_window_size()
|
||||
|
||||
# Should return a reasonable context window size (Claude 3.5 has 200k tokens)
|
||||
assert context_size > 100000 # Should be substantial
|
||||
assert context_size <= 200000 # But not exceed the actual limit
|
||||
|
||||
|
||||
def test_anthropic_message_formatting():
|
||||
"""
|
||||
Test that messages are properly formatted for Anthropic API
|
||||
"""
|
||||
llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
|
||||
|
||||
# Test message formatting
|
||||
test_messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
{"role": "user", "content": "How are you?"}
|
||||
]
|
||||
|
||||
formatted_messages, system_message = llm._format_messages_for_anthropic(test_messages)
|
||||
|
||||
# System message should be extracted
|
||||
assert system_message == "You are a helpful assistant."
|
||||
|
||||
# Remaining messages should start with user
|
||||
assert formatted_messages[0]["role"] == "user"
|
||||
assert len(formatted_messages) >= 3 # Should have user, assistant, user messages
|
||||
|
||||
|
||||
def test_anthropic_streaming_parameter():
|
||||
"""
|
||||
Test that streaming parameter is properly handled
|
||||
"""
|
||||
# Test non-streaming
|
||||
llm_no_stream = LLM(model="anthropic/claude-3-5-sonnet-20241022", stream=False)
|
||||
assert llm_no_stream.stream == False
|
||||
|
||||
# Test streaming
|
||||
llm_stream = LLM(model="anthropic/claude-3-5-sonnet-20241022", stream=True)
|
||||
assert llm_stream.stream == True
|
||||
|
||||
|
||||
def test_anthropic_tool_conversion():
|
||||
"""
|
||||
Test that tools are properly converted to Anthropic format
|
||||
"""
|
||||
llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
|
||||
|
||||
# Mock tool in CrewAI format
|
||||
crewai_tools = [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"description": "A test tool",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "Search query"}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
}]
|
||||
|
||||
# Test tool conversion
|
||||
anthropic_tools = llm._convert_tools_for_interference(crewai_tools)
|
||||
|
||||
assert len(anthropic_tools) == 1
|
||||
assert anthropic_tools[0]["name"] == "test_tool"
|
||||
assert anthropic_tools[0]["description"] == "A test tool"
|
||||
assert "input_schema" in anthropic_tools[0]
|
||||
|
||||
|
||||
def test_anthropic_environment_variable_api_key():
|
||||
"""
|
||||
Test that Anthropic API key is properly loaded from environment
|
||||
"""
|
||||
with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-anthropic-key"}):
|
||||
llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
|
||||
|
||||
assert llm.client is not None
|
||||
assert hasattr(llm.client, 'messages')
|
||||
|
||||
|
||||
def test_anthropic_token_usage_tracking():
|
||||
"""
|
||||
Test that token usage is properly tracked for Anthropic responses
|
||||
"""
|
||||
llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
|
||||
|
||||
# Mock the Anthropic response with usage information
|
||||
with patch.object(llm.client.messages, 'create') as mock_create:
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock(text="test response")]
|
||||
mock_response.usage = MagicMock(input_tokens=50, output_tokens=25)
|
||||
mock_create.return_value = mock_response
|
||||
|
||||
result = llm.call("Hello")
|
||||
|
||||
# Verify the response
|
||||
assert result == "test response"
|
||||
|
||||
# Verify token usage was extracted
|
||||
usage = llm._extract_anthropic_token_usage(mock_response)
|
||||
assert usage["input_tokens"] == 50
|
||||
assert usage["output_tokens"] == 25
|
||||
assert usage["total_tokens"] == 75
|
||||
644
lib/crewai/tests/llms/google/test_google.py
Normal file
644
lib/crewai/tests/llms/google/test_google.py
Normal file
@@ -0,0 +1,644 @@
|
||||
import os
|
||||
import sys
|
||||
import types
|
||||
from unittest.mock import patch, MagicMock
|
||||
import pytest
|
||||
|
||||
from crewai.llm import LLM
|
||||
from crewai.crew import Crew
|
||||
from crewai.agent import Agent
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_anthropic_api_key():
|
||||
"""Automatically mock ANTHROPIC_API_KEY for all tests in this module."""
|
||||
with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
|
||||
yield
|
||||
|
||||
|
||||
def test_gemini_completion_is_used_when_google_provider():
|
||||
"""
|
||||
Test that GeminiCompletion from completion.py is used when LLM uses provider 'google'
|
||||
"""
|
||||
llm = LLM(model="google/gemini-2.0-flash-001")
|
||||
|
||||
assert llm.__class__.__name__ == "GeminiCompletion"
|
||||
assert llm.provider == "google"
|
||||
assert llm.model == "gemini-2.0-flash-001"
|
||||
|
||||
|
||||
def test_gemini_completion_is_used_when_gemini_provider():
|
||||
"""
|
||||
Test that GeminiCompletion is used when provider is 'gemini'
|
||||
"""
|
||||
llm = LLM(model="gemini/gemini-2.0-flash-001")
|
||||
|
||||
from crewai.llms.providers.gemini.completion import GeminiCompletion
|
||||
assert isinstance(llm, GeminiCompletion)
|
||||
assert llm.provider == "gemini"
|
||||
assert llm.model == "gemini-2.0-flash-001"
|
||||
|
||||
|
||||
|
||||
|
||||
def test_gemini_tool_use_conversation_flow():
|
||||
"""
|
||||
Test that the Gemini completion properly handles tool use conversation flow
|
||||
"""
|
||||
from unittest.mock import Mock, patch
|
||||
from crewai.llms.providers.gemini.completion import GeminiCompletion
|
||||
|
||||
# Create GeminiCompletion instance
|
||||
completion = GeminiCompletion(model="gemini-2.0-flash-001")
|
||||
|
||||
# Mock tool function
|
||||
def mock_weather_tool(location: str) -> str:
|
||||
return f"The weather in {location} is sunny and 75°F"
|
||||
|
||||
available_functions = {"get_weather": mock_weather_tool}
|
||||
|
||||
# Mock the Google Gemini client responses
|
||||
with patch.object(completion.client.models, 'generate_content') as mock_generate:
|
||||
# Mock function call in response
|
||||
mock_function_call = Mock()
|
||||
mock_function_call.name = "get_weather"
|
||||
mock_function_call.args = {"location": "San Francisco"}
|
||||
|
||||
mock_part = Mock()
|
||||
mock_part.function_call = mock_function_call
|
||||
|
||||
mock_content = Mock()
|
||||
mock_content.parts = [mock_part]
|
||||
|
||||
mock_candidate = Mock()
|
||||
mock_candidate.content = mock_content
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.candidates = [mock_candidate]
|
||||
mock_response.text = "Based on the weather data, it's a beautiful day in San Francisco with sunny skies and 75°F temperature."
|
||||
mock_response.usage_metadata = Mock()
|
||||
mock_response.usage_metadata.prompt_token_count = 100
|
||||
mock_response.usage_metadata.candidates_token_count = 50
|
||||
mock_response.usage_metadata.total_token_count = 150
|
||||
|
||||
mock_generate.return_value = mock_response
|
||||
|
||||
# Test the call
|
||||
messages = [{"role": "user", "content": "What's the weather like in San Francisco?"}]
|
||||
result = completion.call(
|
||||
messages=messages,
|
||||
available_functions=available_functions
|
||||
)
|
||||
|
||||
# Verify the tool was executed and returned the result
|
||||
assert result == "The weather in San Francisco is sunny and 75°F"
|
||||
|
||||
# Verify that the API was called
|
||||
assert mock_generate.called
|
||||
|
||||
|
||||
def test_gemini_completion_module_is_imported():
|
||||
"""
|
||||
Test that the completion module is properly imported when using Google provider
|
||||
"""
|
||||
module_name = "crewai.llms.providers.gemini.completion"
|
||||
|
||||
# Remove module from cache if it exists
|
||||
if module_name in sys.modules:
|
||||
del sys.modules[module_name]
|
||||
|
||||
# Create LLM instance - this should trigger the import
|
||||
LLM(model="google/gemini-2.0-flash-001")
|
||||
|
||||
# Verify the module was imported
|
||||
assert module_name in sys.modules
|
||||
completion_mod = sys.modules[module_name]
|
||||
assert isinstance(completion_mod, types.ModuleType)
|
||||
|
||||
# Verify the class exists in the module
|
||||
assert hasattr(completion_mod, 'GeminiCompletion')
|
||||
|
||||
|
||||
def test_fallback_to_litellm_when_native_gemini_fails():
|
||||
"""
|
||||
Test that LLM falls back to LiteLLM when native Gemini completion fails
|
||||
"""
|
||||
# Mock the _get_native_provider to return a failing class
|
||||
with patch('crewai.llm.LLM._get_native_provider') as mock_get_provider:
|
||||
|
||||
class FailingCompletion:
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise Exception("Native Google Gen AI SDK failed")
|
||||
|
||||
mock_get_provider.return_value = FailingCompletion
|
||||
|
||||
# This should fall back to LiteLLM
|
||||
llm = LLM(model="google/gemini-2.0-flash-001")
|
||||
|
||||
# Check that it's using LiteLLM
|
||||
assert hasattr(llm, 'is_litellm')
|
||||
assert llm.is_litellm == True
|
||||
|
||||
|
||||
def test_gemini_completion_initialization_parameters():
|
||||
"""
|
||||
Test that GeminiCompletion is initialized with correct parameters
|
||||
"""
|
||||
llm = LLM(
|
||||
model="google/gemini-2.0-flash-001",
|
||||
temperature=0.7,
|
||||
max_output_tokens=2000,
|
||||
top_p=0.9,
|
||||
top_k=40,
|
||||
api_key="test-key"
|
||||
)
|
||||
|
||||
from crewai.llms.providers.gemini.completion import GeminiCompletion
|
||||
assert isinstance(llm, GeminiCompletion)
|
||||
assert llm.model == "gemini-2.0-flash-001"
|
||||
assert llm.temperature == 0.7
|
||||
assert llm.max_output_tokens == 2000
|
||||
assert llm.top_p == 0.9
|
||||
assert llm.top_k == 40
|
||||
|
||||
|
||||
def test_gemini_specific_parameters():
|
||||
"""
|
||||
Test Gemini-specific parameters like stop_sequences, streaming, and safety settings
|
||||
"""
|
||||
safety_settings = {
|
||||
"HARM_CATEGORY_HARASSMENT": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
"HARM_CATEGORY_HATE_SPEECH": "BLOCK_MEDIUM_AND_ABOVE"
|
||||
}
|
||||
|
||||
llm = LLM(
|
||||
model="google/gemini-2.0-flash-001",
|
||||
stop_sequences=["Human:", "Assistant:"],
|
||||
stream=True,
|
||||
safety_settings=safety_settings,
|
||||
project="test-project",
|
||||
location="us-central1"
|
||||
)
|
||||
|
||||
from crewai.llms.providers.gemini.completion import GeminiCompletion
|
||||
assert isinstance(llm, GeminiCompletion)
|
||||
assert llm.stop_sequences == ["Human:", "Assistant:"]
|
||||
assert llm.stream == True
|
||||
assert llm.safety_settings == safety_settings
|
||||
assert llm.project == "test-project"
|
||||
assert llm.location == "us-central1"
|
||||
|
||||
|
||||
def test_gemini_completion_call():
|
||||
"""
|
||||
Test that GeminiCompletion call method works
|
||||
"""
|
||||
llm = LLM(model="google/gemini-2.0-flash-001")
|
||||
|
||||
# Mock the call method on the instance
|
||||
with patch.object(llm, 'call', return_value="Hello! I'm Gemini, ready to help.") as mock_call:
|
||||
result = llm.call("Hello, how are you?")
|
||||
|
||||
assert result == "Hello! I'm Gemini, ready to help."
|
||||
mock_call.assert_called_once_with("Hello, how are you?")
|
||||
|
||||
|
||||
def test_gemini_completion_called_during_crew_execution():
|
||||
"""
|
||||
Test that GeminiCompletion.call is actually invoked when running a crew
|
||||
"""
|
||||
# Create the LLM instance first
|
||||
gemini_llm = LLM(model="google/gemini-2.0-flash-001")
|
||||
|
||||
# Mock the call method on the specific instance
|
||||
with patch.object(gemini_llm, 'call', return_value="Tokyo has 14 million people.") as mock_call:
|
||||
|
||||
# Create agent with explicit LLM configuration
|
||||
agent = Agent(
|
||||
role="Research Assistant",
|
||||
goal="Find population info",
|
||||
backstory="You research populations.",
|
||||
llm=gemini_llm,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Find Tokyo population",
|
||||
expected_output="Population number",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
result = crew.kickoff()
|
||||
|
||||
# Verify mock was called
|
||||
assert mock_call.called
|
||||
assert "14 million" in str(result)
|
||||
|
||||
|
||||
def test_gemini_completion_call_arguments():
|
||||
"""
|
||||
Test that GeminiCompletion.call is invoked with correct arguments
|
||||
"""
|
||||
# Create LLM instance first
|
||||
gemini_llm = LLM(model="google/gemini-2.0-flash-001")
|
||||
|
||||
# Mock the instance method
|
||||
with patch.object(gemini_llm, 'call') as mock_call:
|
||||
mock_call.return_value = "Task completed successfully."
|
||||
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Complete a simple task",
|
||||
backstory="You are a test agent.",
|
||||
llm=gemini_llm # Use same instance
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Say hello world",
|
||||
expected_output="Hello world",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
crew.kickoff()
|
||||
|
||||
# Verify call was made
|
||||
assert mock_call.called
|
||||
|
||||
# Check the arguments passed to the call method
|
||||
call_args = mock_call.call_args
|
||||
assert call_args is not None
|
||||
|
||||
# The first argument should be the messages
|
||||
messages = call_args[0][0] # First positional argument
|
||||
assert isinstance(messages, (str, list))
|
||||
|
||||
# Verify that the task description appears in the messages
|
||||
if isinstance(messages, str):
|
||||
assert "hello world" in messages.lower()
|
||||
elif isinstance(messages, list):
|
||||
message_content = str(messages).lower()
|
||||
assert "hello world" in message_content
|
||||
|
||||
|
||||
def test_multiple_gemini_calls_in_crew():
|
||||
"""
|
||||
Test that GeminiCompletion.call is invoked multiple times for multiple tasks
|
||||
"""
|
||||
# Create LLM instance first
|
||||
gemini_llm = LLM(model="google/gemini-2.0-flash-001")
|
||||
|
||||
# Mock the instance method
|
||||
with patch.object(gemini_llm, 'call') as mock_call:
|
||||
mock_call.return_value = "Task completed."
|
||||
|
||||
agent = Agent(
|
||||
role="Multi-task Agent",
|
||||
goal="Complete multiple tasks",
|
||||
backstory="You can handle multiple tasks.",
|
||||
llm=gemini_llm # Use same instance
|
||||
)
|
||||
|
||||
task1 = Task(
|
||||
description="First task",
|
||||
expected_output="First result",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
task2 = Task(
|
||||
description="Second task",
|
||||
expected_output="Second result",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
crew = Crew(
|
||||
agents=[agent],
|
||||
tasks=[task1, task2]
|
||||
)
|
||||
crew.kickoff()
|
||||
|
||||
# Verify multiple calls were made
|
||||
assert mock_call.call_count >= 2 # At least one call per task
|
||||
|
||||
# Verify each call had proper arguments
|
||||
for call in mock_call.call_args_list:
|
||||
assert len(call[0]) > 0 # Has positional arguments
|
||||
messages = call[0][0]
|
||||
assert messages is not None
|
||||
|
||||
|
||||
def test_gemini_completion_with_tools():
|
||||
"""
|
||||
Test that GeminiCompletion.call is invoked with tools when agent has tools
|
||||
"""
|
||||
from crewai.tools import tool
|
||||
|
||||
@tool
|
||||
def sample_tool(query: str) -> str:
|
||||
"""A sample tool for testing"""
|
||||
return f"Tool result for: {query}"
|
||||
|
||||
# Create LLM instance first
|
||||
gemini_llm = LLM(model="google/gemini-2.0-flash-001")
|
||||
|
||||
# Mock the instance method
|
||||
with patch.object(gemini_llm, 'call') as mock_call:
|
||||
mock_call.return_value = "Task completed with tools."
|
||||
|
||||
agent = Agent(
|
||||
role="Tool User",
|
||||
goal="Use tools to complete tasks",
|
||||
backstory="You can use tools.",
|
||||
llm=gemini_llm, # Use same instance
|
||||
tools=[sample_tool]
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Use the sample tool",
|
||||
expected_output="Tool usage result",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
crew.kickoff()
|
||||
|
||||
assert mock_call.called
|
||||
|
||||
call_args = mock_call.call_args
|
||||
call_kwargs = call_args[1] if len(call_args) > 1 else {}
|
||||
|
||||
if 'tools' in call_kwargs:
|
||||
assert call_kwargs['tools'] is not None
|
||||
assert len(call_kwargs['tools']) > 0
|
||||
|
||||
|
||||
def test_gemini_raises_error_when_model_not_supported():
|
||||
"""Test that GeminiCompletion raises ValueError when model not supported"""
|
||||
|
||||
# Mock the Google client to raise an error
|
||||
with patch('crewai.llms.providers.gemini.completion.genai') as mock_genai:
|
||||
mock_client = MagicMock()
|
||||
mock_genai.Client.return_value = mock_client
|
||||
|
||||
# Mock the error that Google would raise for unsupported models
|
||||
from google.genai.errors import ClientError # type: ignore
|
||||
mock_client.models.generate_content.side_effect = ClientError(
|
||||
code=404,
|
||||
response_json={
|
||||
'error': {
|
||||
'code': 404,
|
||||
'message': 'models/model-doesnt-exist is not found for API version v1beta, or is not supported for generateContent.',
|
||||
'status': 'NOT_FOUND'
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
llm = LLM(model="google/model-doesnt-exist")
|
||||
|
||||
with pytest.raises(Exception): # Should raise some error for unsupported model
|
||||
llm.call("Hello")
|
||||
|
||||
|
||||
def test_gemini_vertex_ai_setup():
|
||||
"""
|
||||
Test that Vertex AI configuration is properly handled
|
||||
"""
|
||||
with patch.dict(os.environ, {
|
||||
"GOOGLE_CLOUD_PROJECT": "test-project",
|
||||
"GOOGLE_CLOUD_LOCATION": "us-west1"
|
||||
}):
|
||||
llm = LLM(
|
||||
model="google/gemini-2.0-flash-001",
|
||||
project="test-project",
|
||||
location="us-west1"
|
||||
)
|
||||
|
||||
from crewai.llms.providers.gemini.completion import GeminiCompletion
|
||||
assert isinstance(llm, GeminiCompletion)
|
||||
|
||||
assert llm.project == "test-project"
|
||||
assert llm.location == "us-west1"
|
||||
|
||||
|
||||
def test_gemini_api_key_configuration():
|
||||
"""
|
||||
Test that API key configuration works for both GOOGLE_API_KEY and GEMINI_API_KEY
|
||||
"""
|
||||
# Test with GOOGLE_API_KEY
|
||||
with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-google-key"}):
|
||||
llm = LLM(model="google/gemini-2.0-flash-001")
|
||||
|
||||
from crewai.llms.providers.gemini.completion import GeminiCompletion
|
||||
assert isinstance(llm, GeminiCompletion)
|
||||
assert llm.api_key == "test-google-key"
|
||||
|
||||
# Test with GEMINI_API_KEY
|
||||
with patch.dict(os.environ, {"GEMINI_API_KEY": "test-gemini-key"}, clear=True):
|
||||
llm = LLM(model="google/gemini-2.0-flash-001")
|
||||
|
||||
assert isinstance(llm, GeminiCompletion)
|
||||
assert llm.api_key == "test-gemini-key"
|
||||
|
||||
|
||||
def test_gemini_model_capabilities():
|
||||
"""
|
||||
Test that model capabilities are correctly identified
|
||||
"""
|
||||
# Test Gemini 2.0 model
|
||||
llm_2_0 = LLM(model="google/gemini-2.0-flash-001")
|
||||
from crewai.llms.providers.gemini.completion import GeminiCompletion
|
||||
assert isinstance(llm_2_0, GeminiCompletion)
|
||||
assert llm_2_0.is_gemini_2 == True
|
||||
assert llm_2_0.supports_tools == True
|
||||
|
||||
# Test Gemini 1.5 model
|
||||
llm_1_5 = LLM(model="google/gemini-1.5-pro")
|
||||
assert isinstance(llm_1_5, GeminiCompletion)
|
||||
assert llm_1_5.is_gemini_1_5 == True
|
||||
assert llm_1_5.supports_tools == True
|
||||
|
||||
|
||||
def test_gemini_generation_config():
|
||||
"""
|
||||
Test that generation config is properly prepared
|
||||
"""
|
||||
llm = LLM(
|
||||
model="google/gemini-2.0-flash-001",
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
top_k=40,
|
||||
max_output_tokens=1000
|
||||
)
|
||||
|
||||
from crewai.llms.providers.gemini.completion import GeminiCompletion
|
||||
assert isinstance(llm, GeminiCompletion)
|
||||
|
||||
# Test config preparation
|
||||
config = llm._prepare_generation_config()
|
||||
|
||||
# Verify config has the expected parameters
|
||||
assert hasattr(config, 'temperature') or 'temperature' in str(config)
|
||||
assert hasattr(config, 'top_p') or 'top_p' in str(config)
|
||||
assert hasattr(config, 'top_k') or 'top_k' in str(config)
|
||||
assert hasattr(config, 'max_output_tokens') or 'max_output_tokens' in str(config)
|
||||
|
||||
|
||||
def test_gemini_model_detection():
|
||||
"""
|
||||
Test that various Gemini model formats are properly detected
|
||||
"""
|
||||
# Test Gemini model naming patterns that actually work with provider detection
|
||||
gemini_test_cases = [
|
||||
"google/gemini-2.0-flash-001",
|
||||
"gemini/gemini-2.0-flash-001",
|
||||
"google/gemini-1.5-pro",
|
||||
"gemini/gemini-1.5-flash"
|
||||
]
|
||||
|
||||
for model_name in gemini_test_cases:
|
||||
llm = LLM(model=model_name)
|
||||
from crewai.llms.providers.gemini.completion import GeminiCompletion
|
||||
assert isinstance(llm, GeminiCompletion), f"Failed for model: {model_name}"
|
||||
|
||||
|
||||
def test_gemini_supports_stop_words():
|
||||
"""
|
||||
Test that Gemini models support stop sequences
|
||||
"""
|
||||
llm = LLM(model="google/gemini-2.0-flash-001")
|
||||
assert llm.supports_stop_words() == True
|
||||
|
||||
|
||||
def test_gemini_context_window_size():
|
||||
"""
|
||||
Test that Gemini models return correct context window sizes
|
||||
"""
|
||||
# Test Gemini 2.0 Flash
|
||||
llm_2_0 = LLM(model="google/gemini-2.0-flash-001")
|
||||
context_size_2_0 = llm_2_0.get_context_window_size()
|
||||
assert context_size_2_0 > 500000 # Should be substantial (1M tokens)
|
||||
|
||||
# Test Gemini 1.5 Pro
|
||||
llm_1_5 = LLM(model="google/gemini-1.5-pro")
|
||||
context_size_1_5 = llm_1_5.get_context_window_size()
|
||||
assert context_size_1_5 > 1000000 # Should be very large (2M tokens)
|
||||
|
||||
|
||||
def test_gemini_message_formatting():
|
||||
"""
|
||||
Test that messages are properly formatted for Gemini API
|
||||
"""
|
||||
llm = LLM(model="google/gemini-2.0-flash-001")
|
||||
|
||||
# Test message formatting
|
||||
test_messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
{"role": "user", "content": "How are you?"}
|
||||
]
|
||||
|
||||
formatted_contents, system_instruction = llm._format_messages_for_gemini(test_messages)
|
||||
|
||||
# System message should be extracted
|
||||
assert system_instruction == "You are a helpful assistant."
|
||||
|
||||
# Remaining messages should be Content objects
|
||||
assert len(formatted_contents) >= 3 # Should have user, model, user messages
|
||||
|
||||
# First content should be user role
|
||||
assert formatted_contents[0].role == "user"
|
||||
# Second should be model (converted from assistant)
|
||||
assert formatted_contents[1].role == "model"
|
||||
|
||||
|
||||
def test_gemini_streaming_parameter():
|
||||
"""
|
||||
Test that streaming parameter is properly handled
|
||||
"""
|
||||
# Test non-streaming
|
||||
llm_no_stream = LLM(model="google/gemini-2.0-flash-001", stream=False)
|
||||
assert llm_no_stream.stream == False
|
||||
|
||||
# Test streaming
|
||||
llm_stream = LLM(model="google/gemini-2.0-flash-001", stream=True)
|
||||
assert llm_stream.stream == True
|
||||
|
||||
|
||||
def test_gemini_tool_conversion():
|
||||
"""
|
||||
Test that tools are properly converted to Gemini format
|
||||
"""
|
||||
llm = LLM(model="google/gemini-2.0-flash-001")
|
||||
|
||||
# Mock tool in CrewAI format
|
||||
crewai_tools = [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"description": "A test tool",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "Search query"}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
}]
|
||||
|
||||
# Test tool conversion
|
||||
gemini_tools = llm._convert_tools_for_interference(crewai_tools)
|
||||
|
||||
assert len(gemini_tools) == 1
|
||||
# Gemini tools are Tool objects with function_declarations
|
||||
assert hasattr(gemini_tools[0], 'function_declarations')
|
||||
assert len(gemini_tools[0].function_declarations) == 1
|
||||
|
||||
func_decl = gemini_tools[0].function_declarations[0]
|
||||
assert func_decl.name == "test_tool"
|
||||
assert func_decl.description == "A test tool"
|
||||
|
||||
|
||||
def test_gemini_environment_variable_api_key():
|
||||
"""
|
||||
Test that Google API key is properly loaded from environment
|
||||
"""
|
||||
with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-google-key"}):
|
||||
llm = LLM(model="google/gemini-2.0-flash-001")
|
||||
|
||||
assert llm.client is not None
|
||||
assert hasattr(llm.client, 'models')
|
||||
assert llm.api_key == "test-google-key"
|
||||
|
||||
|
||||
def test_gemini_token_usage_tracking():
|
||||
"""
|
||||
Test that token usage is properly tracked for Gemini responses
|
||||
"""
|
||||
llm = LLM(model="google/gemini-2.0-flash-001")
|
||||
|
||||
# Mock the Gemini response with usage information
|
||||
with patch.object(llm.client.models, 'generate_content') as mock_generate:
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "test response"
|
||||
mock_response.candidates = []
|
||||
mock_response.usage_metadata = MagicMock(
|
||||
prompt_token_count=50,
|
||||
candidates_token_count=25,
|
||||
total_token_count=75
|
||||
)
|
||||
mock_generate.return_value = mock_response
|
||||
|
||||
result = llm.call("Hello")
|
||||
|
||||
# Verify the response
|
||||
assert result == "test response"
|
||||
|
||||
# Verify token usage was extracted
|
||||
usage = llm._extract_token_usage(mock_response)
|
||||
assert usage["prompt_token_count"] == 50
|
||||
assert usage["candidates_token_count"] == 25
|
||||
assert usage["total_token_count"] == 75
|
||||
assert usage["total_tokens"] == 75
|
||||
409
lib/crewai/tests/llms/openai/test_openai.py
Normal file
409
lib/crewai/tests/llms/openai/test_openai.py
Normal file
@@ -0,0 +1,409 @@
|
||||
import os
|
||||
import sys
|
||||
import types
|
||||
from unittest.mock import patch, MagicMock
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.providers.openai.completion import OpenAICompletion
|
||||
from crewai.crew import Crew
|
||||
from crewai.agent import Agent
|
||||
from crewai.task import Task
|
||||
from crewai.cli.constants import DEFAULT_LLM_MODEL
|
||||
|
||||
def test_openai_completion_is_used_when_openai_provider():
|
||||
"""
|
||||
Test that OpenAICompletion from completion.py is used when LLM uses provider 'openai'
|
||||
"""
|
||||
llm = LLM(model="openai/gpt-4o")
|
||||
|
||||
assert llm.__class__.__name__ == "OpenAICompletion"
|
||||
assert llm.provider == "openai"
|
||||
assert llm.model == "gpt-4o"
|
||||
|
||||
|
||||
def test_openai_completion_is_used_when_no_provider_prefix():
|
||||
"""
|
||||
Test that OpenAICompletion is used when no provider prefix is given (defaults to openai)
|
||||
"""
|
||||
llm = LLM(model="gpt-4o")
|
||||
|
||||
from crewai.llms.providers.openai.completion import OpenAICompletion
|
||||
assert isinstance(llm, OpenAICompletion)
|
||||
assert llm.provider == "openai"
|
||||
assert llm.model == "gpt-4o"
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_openai_is_default_provider_without_explicit_llm_set_on_agent():
|
||||
"""
|
||||
Test that OpenAI is the default provider when no explicit LLM is set on the agent
|
||||
"""
|
||||
agent = Agent(
|
||||
role="Research Assistant",
|
||||
goal="Find information about the population of Tokyo",
|
||||
backstory="You are a helpful research assistant.",
|
||||
)
|
||||
task = Task(
|
||||
description="Find information about the population of Tokyo",
|
||||
expected_output="The population of Tokyo is 10 million",
|
||||
agent=agent,
|
||||
)
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
crew.kickoff()
|
||||
assert crew.agents[0].llm.__class__.__name__ == "OpenAICompletion"
|
||||
assert crew.agents[0].llm.model == DEFAULT_LLM_MODEL
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def test_openai_completion_module_is_imported():
|
||||
"""
|
||||
Test that the completion module is properly imported when using OpenAI provider
|
||||
"""
|
||||
module_name = "crewai.llms.providers.openai.completion"
|
||||
|
||||
# Remove module from cache if it exists
|
||||
if module_name in sys.modules:
|
||||
del sys.modules[module_name]
|
||||
|
||||
# Create LLM instance - this should trigger the import
|
||||
LLM(model="openai/gpt-4o")
|
||||
|
||||
# Verify the module was imported
|
||||
assert module_name in sys.modules
|
||||
completion_mod = sys.modules[module_name]
|
||||
assert isinstance(completion_mod, types.ModuleType)
|
||||
|
||||
# Verify the class exists in the module
|
||||
assert hasattr(completion_mod, 'OpenAICompletion')
|
||||
|
||||
|
||||
def test_fallback_to_litellm_when_native_fails():
|
||||
"""
|
||||
Test that LLM falls back to LiteLLM when native OpenAI completion fails
|
||||
"""
|
||||
# Mock the _get_native_provider to return a failing class
|
||||
with patch('crewai.llm.LLM._get_native_provider') as mock_get_provider:
|
||||
|
||||
class FailingCompletion:
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise Exception("Native SDK failed")
|
||||
|
||||
mock_get_provider.return_value = FailingCompletion
|
||||
|
||||
# This should fall back to LiteLLM
|
||||
llm = LLM(model="openai/gpt-4o")
|
||||
|
||||
# Check that it's using LiteLLM
|
||||
assert hasattr(llm, 'is_litellm')
|
||||
assert llm.is_litellm == True
|
||||
|
||||
|
||||
def test_openai_completion_initialization_parameters():
|
||||
"""
|
||||
Test that OpenAICompletion is initialized with correct parameters
|
||||
"""
|
||||
llm = LLM(
|
||||
model="openai/gpt-4o",
|
||||
temperature=0.7,
|
||||
max_tokens=1000,
|
||||
api_key="test-key"
|
||||
)
|
||||
|
||||
from crewai.llms.providers.openai.completion import OpenAICompletion
|
||||
assert isinstance(llm, OpenAICompletion)
|
||||
assert llm.model == "gpt-4o"
|
||||
assert llm.temperature == 0.7
|
||||
assert llm.max_tokens == 1000
|
||||
|
||||
def test_openai_completion_call():
|
||||
"""
|
||||
Test that OpenAICompletion call method works
|
||||
"""
|
||||
llm = LLM(model="openai/gpt-4o")
|
||||
|
||||
# Mock the call method on the instance
|
||||
with patch.object(llm, 'call', return_value="Hello! I'm ready to help.") as mock_call:
|
||||
result = llm.call("Hello, how are you?")
|
||||
|
||||
assert result == "Hello! I'm ready to help."
|
||||
mock_call.assert_called_once_with("Hello, how are you?")
|
||||
|
||||
|
||||
def test_openai_completion_called_during_crew_execution():
|
||||
"""
|
||||
Test that OpenAICompletion.call is actually invoked when running a crew
|
||||
"""
|
||||
# Create the LLM instance first
|
||||
openai_llm = LLM(model="openai/gpt-4o")
|
||||
|
||||
# Mock the call method on the specific instance
|
||||
with patch.object(openai_llm, 'call', return_value="Tokyo has 14 million people.") as mock_call:
|
||||
|
||||
# Create agent with explicit LLM configuration
|
||||
agent = Agent(
|
||||
role="Research Assistant",
|
||||
goal="Find population info",
|
||||
backstory="You research populations.",
|
||||
llm=openai_llm,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Find Tokyo population",
|
||||
expected_output="Population number",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
result = crew.kickoff()
|
||||
|
||||
# Verify mock was called
|
||||
assert mock_call.called
|
||||
assert "14 million" in str(result)
|
||||
|
||||
|
||||
def test_openai_completion_call_arguments():
|
||||
"""
|
||||
Test that OpenAICompletion.call is invoked with correct arguments
|
||||
"""
|
||||
# Create LLM instance first (like working tests)
|
||||
openai_llm = LLM(model="openai/gpt-4o")
|
||||
|
||||
# Mock the instance method (like working tests)
|
||||
with patch.object(openai_llm, 'call') as mock_call:
|
||||
mock_call.return_value = "Task completed successfully."
|
||||
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Complete a simple task",
|
||||
backstory="You are a test agent.",
|
||||
llm=openai_llm # Use same instance
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Say hello world",
|
||||
expected_output="Hello world",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
crew.kickoff()
|
||||
|
||||
# Verify call was made
|
||||
assert mock_call.called
|
||||
|
||||
# Check the arguments passed to the call method
|
||||
call_args = mock_call.call_args
|
||||
assert call_args is not None
|
||||
|
||||
# The first argument should be the messages
|
||||
messages = call_args[0][0] # First positional argument
|
||||
assert isinstance(messages, (str, list))
|
||||
|
||||
# Verify that the task description appears in the messages
|
||||
if isinstance(messages, str):
|
||||
assert "hello world" in messages.lower()
|
||||
elif isinstance(messages, list):
|
||||
message_content = str(messages).lower()
|
||||
assert "hello world" in message_content
|
||||
|
||||
|
||||
def test_multiple_openai_calls_in_crew():
|
||||
"""
|
||||
Test that OpenAICompletion.call is invoked multiple times for multiple tasks
|
||||
"""
|
||||
# Create LLM instance first
|
||||
openai_llm = LLM(model="openai/gpt-4o")
|
||||
|
||||
# Mock the instance method
|
||||
with patch.object(openai_llm, 'call') as mock_call:
|
||||
mock_call.return_value = "Task completed."
|
||||
|
||||
agent = Agent(
|
||||
role="Multi-task Agent",
|
||||
goal="Complete multiple tasks",
|
||||
backstory="You can handle multiple tasks.",
|
||||
llm=openai_llm # Use same instance
|
||||
)
|
||||
|
||||
task1 = Task(
|
||||
description="First task",
|
||||
expected_output="First result",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
task2 = Task(
|
||||
description="Second task",
|
||||
expected_output="Second result",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
crew = Crew(
|
||||
agents=[agent],
|
||||
tasks=[task1, task2]
|
||||
)
|
||||
crew.kickoff()
|
||||
|
||||
# Verify multiple calls were made
|
||||
assert mock_call.call_count >= 2 # At least one call per task
|
||||
|
||||
# Verify each call had proper arguments
|
||||
for call in mock_call.call_args_list:
|
||||
assert len(call[0]) > 0 # Has positional arguments
|
||||
messages = call[0][0]
|
||||
assert messages is not None
|
||||
|
||||
|
||||
def test_openai_completion_with_tools():
|
||||
"""
|
||||
Test that OpenAICompletion.call is invoked with tools when agent has tools
|
||||
"""
|
||||
from crewai.tools import tool
|
||||
|
||||
@tool
|
||||
def sample_tool(query: str) -> str:
|
||||
"""A sample tool for testing"""
|
||||
return f"Tool result for: {query}"
|
||||
|
||||
# Create LLM instance first
|
||||
openai_llm = LLM(model="openai/gpt-4o")
|
||||
|
||||
# Mock the instance method (not the class method)
|
||||
with patch.object(openai_llm, 'call') as mock_call:
|
||||
mock_call.return_value = "Task completed with tools."
|
||||
|
||||
agent = Agent(
|
||||
role="Tool User",
|
||||
goal="Use tools to complete tasks",
|
||||
backstory="You can use tools.",
|
||||
llm=openai_llm, # Use same instance
|
||||
tools=[sample_tool]
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Use the sample tool",
|
||||
expected_output="Tool usage result",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
crew.kickoff()
|
||||
|
||||
assert mock_call.called
|
||||
|
||||
call_args = mock_call.call_args
|
||||
call_kwargs = call_args[1] if len(call_args) > 1 else {}
|
||||
|
||||
if 'tools' in call_kwargs:
|
||||
assert call_kwargs['tools'] is not None
|
||||
assert len(call_kwargs['tools']) > 0
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_openai_completion_call_returns_usage_metrics():
|
||||
"""
|
||||
Test that OpenAICompletion.call returns usage metrics
|
||||
"""
|
||||
agent = Agent(
|
||||
role="Research Assistant",
|
||||
goal="Find information about the population of Tokyo",
|
||||
backstory="You are a helpful research assistant.",
|
||||
llm=LLM(model="openai/gpt-4o"),
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Find information about the population of Tokyo",
|
||||
expected_output="The population of Tokyo is 10 million",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
result = crew.kickoff()
|
||||
assert result.token_usage is not None
|
||||
assert result.token_usage.total_tokens == 289
|
||||
assert result.token_usage.prompt_tokens == 173
|
||||
assert result.token_usage.completion_tokens == 116
|
||||
assert result.token_usage.successful_requests == 1
|
||||
assert result.token_usage.cached_prompt_tokens == 0
|
||||
|
||||
|
||||
def test_openai_raises_error_when_model_not_supported():
|
||||
"""Test that OpenAICompletion raises ValueError when model not supported"""
|
||||
|
||||
with patch('crewai.llms.providers.openai.completion.OpenAI') as mock_openai_class:
|
||||
mock_client = MagicMock()
|
||||
mock_openai_class.return_value = mock_client
|
||||
|
||||
mock_client.chat.completions.create.side_effect = openai.NotFoundError(
|
||||
message="The model `model-doesnt-exist` does not exist",
|
||||
response=MagicMock(),
|
||||
body={}
|
||||
)
|
||||
|
||||
llm = LLM(model="openai/model-doesnt-exist")
|
||||
|
||||
with pytest.raises(ValueError, match="Model.*not found"):
|
||||
llm.call("Hello")
|
||||
|
||||
def test_openai_client_setup_with_extra_arguments():
|
||||
"""
|
||||
Test that OpenAICompletion is initialized with correct parameters
|
||||
"""
|
||||
llm = LLM(
|
||||
model="openai/gpt-4o",
|
||||
temperature=0.7,
|
||||
max_tokens=1000,
|
||||
top_p=0.5,
|
||||
max_retries=3,
|
||||
timeout=30
|
||||
)
|
||||
|
||||
# Check that model parameters are stored on the LLM instance
|
||||
assert llm.temperature == 0.7
|
||||
assert llm.max_tokens == 1000
|
||||
assert llm.top_p == 0.5
|
||||
|
||||
# Check that client parameters are properly configured
|
||||
assert llm.client.max_retries == 3
|
||||
assert llm.client.timeout == 30
|
||||
|
||||
# Test that parameters are properly used in API calls
|
||||
with patch.object(llm.client.chat.completions, 'create') as mock_create:
|
||||
mock_create.return_value = MagicMock(
|
||||
choices=[MagicMock(message=MagicMock(content="test response", tool_calls=None))],
|
||||
usage=MagicMock(prompt_tokens=10, completion_tokens=20, total_tokens=30)
|
||||
)
|
||||
|
||||
llm.call("Hello")
|
||||
|
||||
# Verify the API was called with the right parameters
|
||||
call_args = mock_create.call_args[1] # keyword arguments
|
||||
assert call_args['temperature'] == 0.7
|
||||
assert call_args['max_tokens'] == 1000
|
||||
assert call_args['top_p'] == 0.5
|
||||
assert call_args['model'] == 'gpt-4o'
|
||||
|
||||
def test_extra_arguments_are_passed_to_openai_completion():
|
||||
"""
|
||||
Test that extra arguments are passed to OpenAICompletion
|
||||
"""
|
||||
llm = LLM(model="openai/gpt-4o", temperature=0.7, max_tokens=1000, top_p=0.5, max_retries=3)
|
||||
|
||||
with patch.object(llm.client.chat.completions, 'create') as mock_create:
|
||||
mock_create.return_value = MagicMock(
|
||||
choices=[MagicMock(message=MagicMock(content="test response", tool_calls=None))],
|
||||
usage=MagicMock(prompt_tokens=10, completion_tokens=20, total_tokens=30)
|
||||
)
|
||||
|
||||
llm.call("Hello, how are you?")
|
||||
|
||||
assert mock_create.called
|
||||
call_kwargs = mock_create.call_args[1]
|
||||
|
||||
assert call_kwargs['temperature'] == 0.7
|
||||
assert call_kwargs['max_tokens'] == 1000
|
||||
assert call_kwargs['top_p'] == 0.5
|
||||
assert call_kwargs['model'] == 'gpt-4o'
|
||||
@@ -1,23 +1,36 @@
|
||||
from unittest.mock import MagicMock, patch, ANY
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemorySaveStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemoryQueryCompletedEvent,
|
||||
)
|
||||
from unittest.mock import ANY, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from mem0.memory.main import Memory
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.crew import Crew, Process
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.memory.external.external_memory import ExternalMemory
|
||||
from crewai.memory.external.external_memory_item import ExternalMemoryItem
|
||||
from crewai.memory.storage.interface import Storage
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_event_handlers():
|
||||
"""Cleanup event handlers after each test"""
|
||||
yield
|
||||
with crewai_event_bus._rwlock.w_locked():
|
||||
crewai_event_bus._sync_handlers = {}
|
||||
crewai_event_bus._async_handlers = {}
|
||||
crewai_event_bus._handler_dependencies = {}
|
||||
crewai_event_bus._execution_plan_cache = {}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mem0_memory():
|
||||
mock_memory = MagicMock(spec=Memory)
|
||||
@@ -238,24 +251,26 @@ def test_external_memory_search_events(
|
||||
custom_storage, external_memory_with_mocked_config
|
||||
):
|
||||
events = defaultdict(list)
|
||||
event_received = threading.Event()
|
||||
|
||||
external_memory_with_mocked_config.storage = custom_storage
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
||||
def on_search_started(source, event):
|
||||
events["MemoryQueryStartedEvent"].append(event)
|
||||
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
||||
def on_search_started(source, event):
|
||||
events["MemoryQueryStartedEvent"].append(event)
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
||||
def on_search_completed(source, event):
|
||||
events["MemoryQueryCompletedEvent"].append(event)
|
||||
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
||||
def on_search_completed(source, event):
|
||||
events["MemoryQueryCompletedEvent"].append(event)
|
||||
event_received.set()
|
||||
|
||||
external_memory_with_mocked_config.search(
|
||||
query="test value",
|
||||
limit=3,
|
||||
score_threshold=0.35,
|
||||
)
|
||||
external_memory_with_mocked_config.search(
|
||||
query="test value",
|
||||
limit=3,
|
||||
score_threshold=0.35,
|
||||
)
|
||||
|
||||
assert event_received.wait(timeout=5), "Timeout waiting for search events"
|
||||
assert len(events["MemoryQueryStartedEvent"]) == 1
|
||||
assert len(events["MemoryQueryCompletedEvent"]) == 1
|
||||
|
||||
@@ -300,24 +315,25 @@ def test_external_memory_save_events(
|
||||
custom_storage, external_memory_with_mocked_config
|
||||
):
|
||||
events = defaultdict(list)
|
||||
event_received = threading.Event()
|
||||
|
||||
external_memory_with_mocked_config.storage = custom_storage
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||
def on_save_started(source, event):
|
||||
events["MemorySaveStartedEvent"].append(event)
|
||||
|
||||
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||
def on_save_started(source, event):
|
||||
events["MemorySaveStartedEvent"].append(event)
|
||||
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||
def on_save_completed(source, event):
|
||||
events["MemorySaveCompletedEvent"].append(event)
|
||||
event_received.set()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||
def on_save_completed(source, event):
|
||||
events["MemorySaveCompletedEvent"].append(event)
|
||||
|
||||
external_memory_with_mocked_config.save(
|
||||
value="saving value",
|
||||
metadata={"task": "test_task"},
|
||||
)
|
||||
external_memory_with_mocked_config.save(
|
||||
value="saving value",
|
||||
metadata={"task": "test_task"},
|
||||
)
|
||||
|
||||
assert event_received.wait(timeout=5), "Timeout waiting for save events"
|
||||
assert len(events["MemorySaveStartedEvent"]) == 1
|
||||
assert len(events["MemorySaveCompletedEvent"]) == 1
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from unittest.mock import ANY
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
@@ -21,27 +23,37 @@ def long_term_memory():
|
||||
|
||||
def test_long_term_memory_save_events(long_term_memory):
|
||||
events = defaultdict(list)
|
||||
all_events_received = threading.Event()
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||
def on_save_started(source, event):
|
||||
events["MemorySaveStartedEvent"].append(event)
|
||||
if (
|
||||
len(events["MemorySaveStartedEvent"]) == 1
|
||||
and len(events["MemorySaveCompletedEvent"]) == 1
|
||||
):
|
||||
all_events_received.set()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||
def on_save_started(source, event):
|
||||
events["MemorySaveStartedEvent"].append(event)
|
||||
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||
def on_save_completed(source, event):
|
||||
events["MemorySaveCompletedEvent"].append(event)
|
||||
if (
|
||||
len(events["MemorySaveStartedEvent"]) == 1
|
||||
and len(events["MemorySaveCompletedEvent"]) == 1
|
||||
):
|
||||
all_events_received.set()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||
def on_save_completed(source, event):
|
||||
events["MemorySaveCompletedEvent"].append(event)
|
||||
|
||||
memory = LongTermMemoryItem(
|
||||
agent="test_agent",
|
||||
task="test_task",
|
||||
expected_output="test_output",
|
||||
datetime="test_datetime",
|
||||
quality=0.5,
|
||||
metadata={"task": "test_task", "quality": 0.5},
|
||||
)
|
||||
long_term_memory.save(memory)
|
||||
memory = LongTermMemoryItem(
|
||||
agent="test_agent",
|
||||
task="test_task",
|
||||
expected_output="test_output",
|
||||
datetime="test_datetime",
|
||||
quality=0.5,
|
||||
metadata={"task": "test_task", "quality": 0.5},
|
||||
)
|
||||
long_term_memory.save(memory)
|
||||
|
||||
assert all_events_received.wait(timeout=5), "Timeout waiting for save events"
|
||||
assert len(events["MemorySaveStartedEvent"]) == 1
|
||||
assert len(events["MemorySaveCompletedEvent"]) == 1
|
||||
assert len(events["MemorySaveFailedEvent"]) == 0
|
||||
@@ -86,21 +98,31 @@ def test_long_term_memory_save_events(long_term_memory):
|
||||
|
||||
def test_long_term_memory_search_events(long_term_memory):
|
||||
events = defaultdict(list)
|
||||
all_events_received = threading.Event()
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
||||
def on_search_started(source, event):
|
||||
events["MemoryQueryStartedEvent"].append(event)
|
||||
if (
|
||||
len(events["MemoryQueryStartedEvent"]) == 1
|
||||
and len(events["MemoryQueryCompletedEvent"]) == 1
|
||||
):
|
||||
all_events_received.set()
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
||||
def on_search_started(source, event):
|
||||
events["MemoryQueryStartedEvent"].append(event)
|
||||
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
||||
def on_search_completed(source, event):
|
||||
events["MemoryQueryCompletedEvent"].append(event)
|
||||
if (
|
||||
len(events["MemoryQueryStartedEvent"]) == 1
|
||||
and len(events["MemoryQueryCompletedEvent"]) == 1
|
||||
):
|
||||
all_events_received.set()
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
||||
def on_search_completed(source, event):
|
||||
events["MemoryQueryCompletedEvent"].append(event)
|
||||
test_query = "test query"
|
||||
|
||||
test_query = "test query"
|
||||
|
||||
long_term_memory.search(test_query, latest_n=5)
|
||||
long_term_memory.search(test_query, latest_n=5)
|
||||
|
||||
assert all_events_received.wait(timeout=5), "Timeout waiting for search events"
|
||||
assert len(events["MemoryQueryStartedEvent"]) == 1
|
||||
assert len(events["MemoryQueryCompletedEvent"]) == 1
|
||||
assert len(events["MemoryQueryFailedEvent"]) == 0
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from unittest.mock import ANY, patch
|
||||
|
||||
@@ -37,24 +38,33 @@ def short_term_memory():
|
||||
|
||||
def test_short_term_memory_search_events(short_term_memory):
|
||||
events = defaultdict(list)
|
||||
search_started = threading.Event()
|
||||
search_completed = threading.Event()
|
||||
|
||||
with patch.object(short_term_memory.storage, "search", return_value=[]):
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
||||
def on_search_started(source, event):
|
||||
events["MemoryQueryStartedEvent"].append(event)
|
||||
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
||||
def on_search_started(source, event):
|
||||
events["MemoryQueryStartedEvent"].append(event)
|
||||
search_started.set()
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
||||
def on_search_completed(source, event):
|
||||
events["MemoryQueryCompletedEvent"].append(event)
|
||||
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
||||
def on_search_completed(source, event):
|
||||
events["MemoryQueryCompletedEvent"].append(event)
|
||||
search_completed.set()
|
||||
|
||||
# Call the save method
|
||||
short_term_memory.search(
|
||||
query="test value",
|
||||
limit=3,
|
||||
score_threshold=0.35,
|
||||
)
|
||||
short_term_memory.search(
|
||||
query="test value",
|
||||
limit=3,
|
||||
score_threshold=0.35,
|
||||
)
|
||||
|
||||
assert search_started.wait(timeout=2), (
|
||||
"Timeout waiting for search started event"
|
||||
)
|
||||
assert search_completed.wait(timeout=2), (
|
||||
"Timeout waiting for search completed event"
|
||||
)
|
||||
|
||||
assert len(events["MemoryQueryStartedEvent"]) == 1
|
||||
assert len(events["MemoryQueryCompletedEvent"]) == 1
|
||||
@@ -98,20 +108,26 @@ def test_short_term_memory_search_events(short_term_memory):
|
||||
|
||||
def test_short_term_memory_save_events(short_term_memory):
|
||||
events = defaultdict(list)
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
save_started = threading.Event()
|
||||
save_completed = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||
def on_save_started(source, event):
|
||||
events["MemorySaveStartedEvent"].append(event)
|
||||
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||
def on_save_started(source, event):
|
||||
events["MemorySaveStartedEvent"].append(event)
|
||||
save_started.set()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||
def on_save_completed(source, event):
|
||||
events["MemorySaveCompletedEvent"].append(event)
|
||||
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||
def on_save_completed(source, event):
|
||||
events["MemorySaveCompletedEvent"].append(event)
|
||||
save_completed.set()
|
||||
|
||||
short_term_memory.save(
|
||||
value="test value",
|
||||
metadata={"task": "test_task"},
|
||||
)
|
||||
short_term_memory.save(
|
||||
value="test value",
|
||||
metadata={"task": "test_task"},
|
||||
)
|
||||
|
||||
assert save_started.wait(timeout=2), "Timeout waiting for save started event"
|
||||
assert save_completed.wait(timeout=2), "Timeout waiting for save completed event"
|
||||
|
||||
assert len(events["MemorySaveStartedEvent"]) == 1
|
||||
assert len(events["MemorySaveCompletedEvent"]) == 1
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
"""Test Agent creation and execution basic functionality."""
|
||||
|
||||
import json
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import Future
|
||||
from hashlib import md5
|
||||
import json
|
||||
import re
|
||||
from unittest import mock
|
||||
from unittest.mock import ANY, MagicMock, patch
|
||||
@@ -2476,62 +2477,63 @@ def test_using_contextual_memory():
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_memory_events_are_emitted():
|
||||
events = defaultdict(list)
|
||||
event_received = threading.Event()
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||
def handle_memory_save_started(source, event):
|
||||
events["MemorySaveStartedEvent"].append(event)
|
||||
|
||||
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||
def handle_memory_save_started(source, event):
|
||||
events["MemorySaveStartedEvent"].append(event)
|
||||
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||
def handle_memory_save_completed(source, event):
|
||||
events["MemorySaveCompletedEvent"].append(event)
|
||||
|
||||
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||
def handle_memory_save_completed(source, event):
|
||||
events["MemorySaveCompletedEvent"].append(event)
|
||||
@crewai_event_bus.on(MemorySaveFailedEvent)
|
||||
def handle_memory_save_failed(source, event):
|
||||
events["MemorySaveFailedEvent"].append(event)
|
||||
|
||||
@crewai_event_bus.on(MemorySaveFailedEvent)
|
||||
def handle_memory_save_failed(source, event):
|
||||
events["MemorySaveFailedEvent"].append(event)
|
||||
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
||||
def handle_memory_query_started(source, event):
|
||||
events["MemoryQueryStartedEvent"].append(event)
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
||||
def handle_memory_query_started(source, event):
|
||||
events["MemoryQueryStartedEvent"].append(event)
|
||||
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
||||
def handle_memory_query_completed(source, event):
|
||||
events["MemoryQueryCompletedEvent"].append(event)
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
||||
def handle_memory_query_completed(source, event):
|
||||
events["MemoryQueryCompletedEvent"].append(event)
|
||||
@crewai_event_bus.on(MemoryQueryFailedEvent)
|
||||
def handle_memory_query_failed(source, event):
|
||||
events["MemoryQueryFailedEvent"].append(event)
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryFailedEvent)
|
||||
def handle_memory_query_failed(source, event):
|
||||
events["MemoryQueryFailedEvent"].append(event)
|
||||
@crewai_event_bus.on(MemoryRetrievalStartedEvent)
|
||||
def handle_memory_retrieval_started(source, event):
|
||||
events["MemoryRetrievalStartedEvent"].append(event)
|
||||
|
||||
@crewai_event_bus.on(MemoryRetrievalStartedEvent)
|
||||
def handle_memory_retrieval_started(source, event):
|
||||
events["MemoryRetrievalStartedEvent"].append(event)
|
||||
@crewai_event_bus.on(MemoryRetrievalCompletedEvent)
|
||||
def handle_memory_retrieval_completed(source, event):
|
||||
events["MemoryRetrievalCompletedEvent"].append(event)
|
||||
event_received.set()
|
||||
|
||||
@crewai_event_bus.on(MemoryRetrievalCompletedEvent)
|
||||
def handle_memory_retrieval_completed(source, event):
|
||||
events["MemoryRetrievalCompletedEvent"].append(event)
|
||||
math_researcher = Agent(
|
||||
role="Researcher",
|
||||
goal="You research about math.",
|
||||
backstory="You're an expert in research and you love to learn new things.",
|
||||
allow_delegation=False,
|
||||
)
|
||||
|
||||
math_researcher = Agent(
|
||||
role="Researcher",
|
||||
goal="You research about math.",
|
||||
backstory="You're an expert in research and you love to learn new things.",
|
||||
allow_delegation=False,
|
||||
)
|
||||
task1 = Task(
|
||||
description="Research a topic to teach a kid aged 6 about math.",
|
||||
expected_output="A topic, explanation, angle, and examples.",
|
||||
agent=math_researcher,
|
||||
)
|
||||
|
||||
task1 = Task(
|
||||
description="Research a topic to teach a kid aged 6 about math.",
|
||||
expected_output="A topic, explanation, angle, and examples.",
|
||||
agent=math_researcher,
|
||||
)
|
||||
crew = Crew(
|
||||
agents=[math_researcher],
|
||||
tasks=[task1],
|
||||
memory=True,
|
||||
)
|
||||
|
||||
crew = Crew(
|
||||
agents=[math_researcher],
|
||||
tasks=[task1],
|
||||
memory=True,
|
||||
)
|
||||
|
||||
crew.kickoff()
|
||||
crew.kickoff()
|
||||
|
||||
assert event_received.wait(timeout=5), "Timeout waiting for memory events"
|
||||
assert len(events["MemorySaveStartedEvent"]) == 3
|
||||
assert len(events["MemorySaveCompletedEvent"]) == 3
|
||||
assert len(events["MemorySaveFailedEvent"]) == 0
|
||||
@@ -2907,19 +2909,29 @@ def test_crew_train_success(
|
||||
copy_mock.return_value = crew
|
||||
|
||||
received_events = []
|
||||
lock = threading.Lock()
|
||||
all_events_received = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(CrewTrainStartedEvent)
|
||||
def on_crew_train_started(source, event: CrewTrainStartedEvent):
|
||||
received_events.append(event)
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
if len(received_events) == 2:
|
||||
all_events_received.set()
|
||||
|
||||
@crewai_event_bus.on(CrewTrainCompletedEvent)
|
||||
def on_crew_train_completed(source, event: CrewTrainCompletedEvent):
|
||||
received_events.append(event)
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
if len(received_events) == 2:
|
||||
all_events_received.set()
|
||||
|
||||
crew.train(
|
||||
n_iterations=2, inputs={"topic": "AI"}, filename="trained_agents_data.pkl"
|
||||
)
|
||||
|
||||
assert all_events_received.wait(timeout=5), "Timeout waiting for all train events"
|
||||
|
||||
# Ensure kickoff is called on the copied crew
|
||||
kickoff_mock.assert_has_calls(
|
||||
[mock.call(inputs={"topic": "AI"}), mock.call(inputs={"topic": "AI"})]
|
||||
@@ -3726,17 +3738,27 @@ def test_crew_testing_function(kickoff_mock, copy_mock, crew_evaluator, research
|
||||
llm_instance = LLM("gpt-4o-mini")
|
||||
|
||||
received_events = []
|
||||
lock = threading.Lock()
|
||||
all_events_received = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(CrewTestStartedEvent)
|
||||
def on_crew_test_started(source, event: CrewTestStartedEvent):
|
||||
received_events.append(event)
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
if len(received_events) == 2:
|
||||
all_events_received.set()
|
||||
|
||||
@crewai_event_bus.on(CrewTestCompletedEvent)
|
||||
def on_crew_test_completed(source, event: CrewTestCompletedEvent):
|
||||
received_events.append(event)
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
if len(received_events) == 2:
|
||||
all_events_received.set()
|
||||
|
||||
crew.test(n_iterations, llm_instance, inputs={"topic": "AI"})
|
||||
|
||||
assert all_events_received.wait(timeout=5), "Timeout waiting for all test events"
|
||||
|
||||
# Ensure kickoff is called on the copied crew
|
||||
kickoff_mock.assert_has_calls(
|
||||
[mock.call(inputs={"topic": "AI"}), mock.call(inputs={"topic": "AI"})]
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
"""Test Flow creation and execution basic functionality."""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.flow_events import (
|
||||
FlowFinishedEvent,
|
||||
@@ -13,7 +16,6 @@ from crewai.events.types.flow_events import (
|
||||
MethodExecutionStartedEvent,
|
||||
)
|
||||
from crewai.flow.flow import Flow, and_, listen, or_, router, start
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def test_simple_sequential_flow():
|
||||
@@ -439,20 +441,42 @@ def test_unstructured_flow_event_emission():
|
||||
|
||||
flow = PoemFlow()
|
||||
received_events = []
|
||||
lock = threading.Lock()
|
||||
all_events_received = threading.Event()
|
||||
expected_event_count = (
|
||||
7 # 1 FlowStarted + 5 MethodExecutionStarted + 1 FlowFinished
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(FlowStartedEvent)
|
||||
def handle_flow_start(source, event):
|
||||
received_events.append(event)
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
if len(received_events) == expected_event_count:
|
||||
all_events_received.set()
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionStartedEvent)
|
||||
def handle_method_start(source, event):
|
||||
received_events.append(event)
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
if len(received_events) == expected_event_count:
|
||||
all_events_received.set()
|
||||
|
||||
@crewai_event_bus.on(FlowFinishedEvent)
|
||||
def handle_flow_end(source, event):
|
||||
received_events.append(event)
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
if len(received_events) == expected_event_count:
|
||||
all_events_received.set()
|
||||
|
||||
flow.kickoff(inputs={"separator": ", "})
|
||||
|
||||
assert all_events_received.wait(timeout=5), "Timeout waiting for all flow events"
|
||||
|
||||
# Sort events by timestamp to ensure deterministic order
|
||||
# (async handlers may append out of order)
|
||||
with lock:
|
||||
received_events.sort(key=lambda e: e.timestamp)
|
||||
|
||||
assert isinstance(received_events[0], FlowStartedEvent)
|
||||
assert received_events[0].flow_name == "PoemFlow"
|
||||
assert received_events[0].inputs == {"separator": ", "}
|
||||
@@ -642,28 +666,48 @@ def test_structured_flow_event_emission():
|
||||
return f"Welcome, {self.state.name}!"
|
||||
|
||||
flow = OnboardingFlow()
|
||||
flow.kickoff(inputs={"name": "Anakin"})
|
||||
|
||||
received_events = []
|
||||
lock = threading.Lock()
|
||||
all_events_received = threading.Event()
|
||||
expected_event_count = 6 # 1 FlowStarted + 2 MethodExecutionStarted + 2 MethodExecutionFinished + 1 FlowFinished
|
||||
|
||||
@crewai_event_bus.on(FlowStartedEvent)
|
||||
def handle_flow_start(source, event):
|
||||
received_events.append(event)
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
if len(received_events) == expected_event_count:
|
||||
all_events_received.set()
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionStartedEvent)
|
||||
def handle_method_start(source, event):
|
||||
received_events.append(event)
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
if len(received_events) == expected_event_count:
|
||||
all_events_received.set()
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionFinishedEvent)
|
||||
def handle_method_end(source, event):
|
||||
received_events.append(event)
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
if len(received_events) == expected_event_count:
|
||||
all_events_received.set()
|
||||
|
||||
@crewai_event_bus.on(FlowFinishedEvent)
|
||||
def handle_flow_end(source, event):
|
||||
received_events.append(event)
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
if len(received_events) == expected_event_count:
|
||||
all_events_received.set()
|
||||
|
||||
flow.kickoff(inputs={"name": "Anakin"})
|
||||
|
||||
assert all_events_received.wait(timeout=5), "Timeout waiting for all flow events"
|
||||
|
||||
# Sort events by timestamp to ensure deterministic order
|
||||
with lock:
|
||||
received_events.sort(key=lambda e: e.timestamp)
|
||||
|
||||
assert isinstance(received_events[0], FlowStartedEvent)
|
||||
assert received_events[0].flow_name == "OnboardingFlow"
|
||||
assert received_events[0].inputs == {"name": "Anakin"}
|
||||
@@ -711,25 +755,46 @@ def test_stateless_flow_event_emission():
|
||||
|
||||
flow = StatelessFlow()
|
||||
received_events = []
|
||||
lock = threading.Lock()
|
||||
all_events_received = threading.Event()
|
||||
expected_event_count = 6 # 1 FlowStarted + 2 MethodExecutionStarted + 2 MethodExecutionFinished + 1 FlowFinished
|
||||
|
||||
@crewai_event_bus.on(FlowStartedEvent)
|
||||
def handle_flow_start(source, event):
|
||||
received_events.append(event)
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
if len(received_events) == expected_event_count:
|
||||
all_events_received.set()
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionStartedEvent)
|
||||
def handle_method_start(source, event):
|
||||
received_events.append(event)
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
if len(received_events) == expected_event_count:
|
||||
all_events_received.set()
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionFinishedEvent)
|
||||
def handle_method_end(source, event):
|
||||
received_events.append(event)
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
if len(received_events) == expected_event_count:
|
||||
all_events_received.set()
|
||||
|
||||
@crewai_event_bus.on(FlowFinishedEvent)
|
||||
def handle_flow_end(source, event):
|
||||
received_events.append(event)
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
if len(received_events) == expected_event_count:
|
||||
all_events_received.set()
|
||||
|
||||
flow.kickoff()
|
||||
|
||||
assert all_events_received.wait(timeout=5), "Timeout waiting for all flow events"
|
||||
|
||||
# Sort events by timestamp to ensure deterministic order
|
||||
with lock:
|
||||
received_events.sort(key=lambda e: e.timestamp)
|
||||
|
||||
assert isinstance(received_events[0], FlowStartedEvent)
|
||||
assert received_events[0].flow_name == "StatelessFlow"
|
||||
assert received_events[0].inputs is None
|
||||
@@ -769,13 +834,16 @@ def test_flow_plotting():
|
||||
flow = StatelessFlow()
|
||||
flow.kickoff()
|
||||
received_events = []
|
||||
event_received = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(FlowPlotEvent)
|
||||
def handle_flow_plot(source, event):
|
||||
received_events.append(event)
|
||||
event_received.set()
|
||||
|
||||
flow.plot("test_flow")
|
||||
|
||||
assert event_received.wait(timeout=5), "Timeout waiting for plot event"
|
||||
assert len(received_events) == 1
|
||||
assert isinstance(received_events[0], FlowPlotEvent)
|
||||
assert received_events[0].flow_name == "StatelessFlow"
|
||||
|
||||
@@ -1218,7 +1218,7 @@ def test_create_directory_false():
|
||||
assert not resolved_dir.exists()
|
||||
|
||||
with pytest.raises(
|
||||
RuntimeError, match="Directory .* does not exist and create_directory is False"
|
||||
RuntimeError, match=r"Directory .* does not exist and create_directory is False"
|
||||
):
|
||||
task._save_file("test content")
|
||||
|
||||
@@ -1635,3 +1635,48 @@ def test_task_interpolation_with_hyphens():
|
||||
assert "say hello world" in task.prompt()
|
||||
|
||||
assert result.raw == "Hello, World!"
|
||||
|
||||
|
||||
def test_task_copy_with_none_context():
|
||||
original_task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
context=None
|
||||
)
|
||||
|
||||
new_task = original_task.copy(agents=[], task_mapping={})
|
||||
assert original_task.context is None
|
||||
assert new_task.context is None
|
||||
|
||||
|
||||
def test_task_copy_with_not_specified_context():
|
||||
from crewai.utilities.constants import NOT_SPECIFIED
|
||||
original_task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
)
|
||||
|
||||
new_task = original_task.copy(agents=[], task_mapping={})
|
||||
assert original_task.context is NOT_SPECIFIED
|
||||
assert new_task.context is NOT_SPECIFIED
|
||||
|
||||
|
||||
def test_task_copy_with_list_context():
|
||||
"""Test that copying a task with list context works correctly."""
|
||||
task1 = Task(
|
||||
description="Task 1",
|
||||
expected_output="Output 1"
|
||||
)
|
||||
task2 = Task(
|
||||
description="Task 2",
|
||||
expected_output="Output 2",
|
||||
context=[task1]
|
||||
)
|
||||
|
||||
task_mapping = {task1.key: task1}
|
||||
|
||||
copied_task2 = task2.copy(agents=[], task_mapping=task_mapping)
|
||||
|
||||
assert isinstance(copied_task2.context, list)
|
||||
assert len(copied_task2.context) == 1
|
||||
assert copied_task2.context[0] is task1
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import threading
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
@@ -175,78 +176,92 @@ def test_task_guardrail_process_output(task_output):
|
||||
def test_guardrail_emits_events(sample_agent):
|
||||
started_guardrail = []
|
||||
completed_guardrail = []
|
||||
all_events_received = threading.Event()
|
||||
expected_started = 3 # 2 from first task, 1 from second
|
||||
expected_completed = 3 # 2 from first task, 1 from second
|
||||
|
||||
task = Task(
|
||||
task1 = Task(
|
||||
description="Gather information about available books on the First World War",
|
||||
agent=sample_agent,
|
||||
expected_output="A list of available books on the First World War",
|
||||
guardrail="Ensure the authors are from Italy",
|
||||
)
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(LLMGuardrailStartedEvent)
|
||||
def handle_guardrail_started(source, event):
|
||||
assert source == task
|
||||
started_guardrail.append(
|
||||
{"guardrail": event.guardrail, "retry_count": event.retry_count}
|
||||
)
|
||||
|
||||
@crewai_event_bus.on(LLMGuardrailCompletedEvent)
|
||||
def handle_guardrail_completed(source, event):
|
||||
assert source == task
|
||||
completed_guardrail.append(
|
||||
{
|
||||
"success": event.success,
|
||||
"result": event.result,
|
||||
"error": event.error,
|
||||
"retry_count": event.retry_count,
|
||||
}
|
||||
)
|
||||
|
||||
result = task.execute_sync(agent=sample_agent)
|
||||
|
||||
def custom_guardrail(result: TaskOutput):
|
||||
return (True, "good result from callable function")
|
||||
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Output",
|
||||
guardrail=custom_guardrail,
|
||||
@crewai_event_bus.on(LLMGuardrailStartedEvent)
|
||||
def handle_guardrail_started(source, event):
|
||||
started_guardrail.append(
|
||||
{"guardrail": event.guardrail, "retry_count": event.retry_count}
|
||||
)
|
||||
if (
|
||||
len(started_guardrail) >= expected_started
|
||||
and len(completed_guardrail) >= expected_completed
|
||||
):
|
||||
all_events_received.set()
|
||||
|
||||
task.execute_sync(agent=sample_agent)
|
||||
@crewai_event_bus.on(LLMGuardrailCompletedEvent)
|
||||
def handle_guardrail_completed(source, event):
|
||||
completed_guardrail.append(
|
||||
{
|
||||
"success": event.success,
|
||||
"result": event.result,
|
||||
"error": event.error,
|
||||
"retry_count": event.retry_count,
|
||||
}
|
||||
)
|
||||
if (
|
||||
len(started_guardrail) >= expected_started
|
||||
and len(completed_guardrail) >= expected_completed
|
||||
):
|
||||
all_events_received.set()
|
||||
|
||||
expected_started_events = [
|
||||
{"guardrail": "Ensure the authors are from Italy", "retry_count": 0},
|
||||
{"guardrail": "Ensure the authors are from Italy", "retry_count": 1},
|
||||
{
|
||||
"guardrail": """def custom_guardrail(result: TaskOutput):
|
||||
return (True, "good result from callable function")""",
|
||||
"retry_count": 0,
|
||||
},
|
||||
]
|
||||
result = task1.execute_sync(agent=sample_agent)
|
||||
|
||||
expected_completed_events = [
|
||||
{
|
||||
"success": False,
|
||||
"result": None,
|
||||
"error": "The task result does not comply with the guardrail because none of "
|
||||
"the listed authors are from Italy. All authors mentioned are from "
|
||||
"different countries, including Germany, the UK, the USA, and others, "
|
||||
"which violates the requirement that authors must be Italian.",
|
||||
"retry_count": 0,
|
||||
},
|
||||
{"success": True, "result": result.raw, "error": None, "retry_count": 1},
|
||||
{
|
||||
"success": True,
|
||||
"result": "good result from callable function",
|
||||
"error": None,
|
||||
"retry_count": 0,
|
||||
},
|
||||
]
|
||||
assert started_guardrail == expected_started_events
|
||||
assert completed_guardrail == expected_completed_events
|
||||
def custom_guardrail(result: TaskOutput):
|
||||
return (True, "good result from callable function")
|
||||
|
||||
task2 = Task(
|
||||
description="Test task",
|
||||
expected_output="Output",
|
||||
guardrail=custom_guardrail,
|
||||
)
|
||||
|
||||
task2.execute_sync(agent=sample_agent)
|
||||
|
||||
# Wait for all events to be received
|
||||
assert all_events_received.wait(timeout=10), (
|
||||
"Timeout waiting for all guardrail events"
|
||||
)
|
||||
|
||||
expected_started_events = [
|
||||
{"guardrail": "Ensure the authors are from Italy", "retry_count": 0},
|
||||
{"guardrail": "Ensure the authors are from Italy", "retry_count": 1},
|
||||
{
|
||||
"guardrail": """def custom_guardrail(result: TaskOutput):
|
||||
return (True, "good result from callable function")""",
|
||||
"retry_count": 0,
|
||||
},
|
||||
]
|
||||
|
||||
expected_completed_events = [
|
||||
{
|
||||
"success": False,
|
||||
"result": None,
|
||||
"error": "The task result does not comply with the guardrail because none of "
|
||||
"the listed authors are from Italy. All authors mentioned are from "
|
||||
"different countries, including Germany, the UK, the USA, and others, "
|
||||
"which violates the requirement that authors must be Italian.",
|
||||
"retry_count": 0,
|
||||
},
|
||||
{"success": True, "result": result.raw, "error": None, "retry_count": 1},
|
||||
{
|
||||
"success": True,
|
||||
"result": "good result from callable function",
|
||||
"error": None,
|
||||
"retry_count": 0,
|
||||
},
|
||||
]
|
||||
assert started_guardrail == expected_started_events
|
||||
assert completed_guardrail == expected_completed_events
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import datetime
|
||||
import json
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@@ -32,7 +33,7 @@ class RandomNumberTool(BaseTool):
|
||||
args_schema: type[BaseModel] = RandomNumberToolInput
|
||||
|
||||
def _run(self, min_value: int, max_value: int) -> int:
|
||||
return random.randint(min_value, max_value)
|
||||
return random.randint(min_value, max_value) # noqa: S311
|
||||
|
||||
|
||||
# Example agent and task
|
||||
@@ -470,13 +471,21 @@ def test_tool_selection_error_event_direct():
|
||||
)
|
||||
|
||||
received_events = []
|
||||
first_event_received = threading.Event()
|
||||
second_event_received = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(ToolSelectionErrorEvent)
|
||||
def event_handler(source, event):
|
||||
received_events.append(event)
|
||||
if event.tool_name == "Non Existent Tool":
|
||||
first_event_received.set()
|
||||
elif event.tool_name == "":
|
||||
second_event_received.set()
|
||||
|
||||
with pytest.raises(Exception):
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
tool_usage._select_tool("Non Existent Tool")
|
||||
|
||||
assert first_event_received.wait(timeout=5), "Timeout waiting for first event"
|
||||
assert len(received_events) == 1
|
||||
event = received_events[0]
|
||||
assert isinstance(event, ToolSelectionErrorEvent)
|
||||
@@ -488,12 +497,12 @@ def test_tool_selection_error_event_direct():
|
||||
assert "A test tool" in event.tool_class
|
||||
assert "don't exist" in event.error
|
||||
|
||||
received_events.clear()
|
||||
with pytest.raises(Exception):
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
tool_usage._select_tool("")
|
||||
|
||||
assert len(received_events) == 1
|
||||
event = received_events[0]
|
||||
assert second_event_received.wait(timeout=5), "Timeout waiting for second event"
|
||||
assert len(received_events) == 2
|
||||
event = received_events[1]
|
||||
assert isinstance(event, ToolSelectionErrorEvent)
|
||||
assert event.agent_key == "test_key"
|
||||
assert event.agent_role == "test_role"
|
||||
@@ -562,7 +571,7 @@ def test_tool_validate_input_error_event():
|
||||
|
||||
# Test invalid input
|
||||
invalid_input = "invalid json {[}"
|
||||
with pytest.raises(Exception):
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
tool_usage._validate_tool_input(invalid_input)
|
||||
|
||||
# Verify event was emitted
|
||||
@@ -616,12 +625,13 @@ def test_tool_usage_finished_event_with_result():
|
||||
action=MagicMock(),
|
||||
)
|
||||
|
||||
# Track received events
|
||||
received_events = []
|
||||
event_received = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(ToolUsageFinishedEvent)
|
||||
def event_handler(source, event):
|
||||
received_events.append(event)
|
||||
event_received.set()
|
||||
|
||||
# Call on_tool_use_finished with test data
|
||||
started_at = time.time()
|
||||
@@ -634,7 +644,7 @@ def test_tool_usage_finished_event_with_result():
|
||||
result=result,
|
||||
)
|
||||
|
||||
# Verify event was emitted
|
||||
assert event_received.wait(timeout=5), "Timeout waiting for event"
|
||||
assert len(received_events) == 1, "Expected one event to be emitted"
|
||||
event = received_events[0]
|
||||
assert isinstance(event, ToolUsageFinishedEvent)
|
||||
@@ -695,12 +705,13 @@ def test_tool_usage_finished_event_with_cached_result():
|
||||
action=MagicMock(),
|
||||
)
|
||||
|
||||
# Track received events
|
||||
received_events = []
|
||||
event_received = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(ToolUsageFinishedEvent)
|
||||
def event_handler(source, event):
|
||||
received_events.append(event)
|
||||
event_received.set()
|
||||
|
||||
# Call on_tool_use_finished with test data and from_cache=True
|
||||
started_at = time.time()
|
||||
@@ -713,7 +724,7 @@ def test_tool_usage_finished_event_with_cached_result():
|
||||
result=result,
|
||||
)
|
||||
|
||||
# Verify event was emitted
|
||||
assert event_received.wait(timeout=5), "Timeout waiting for event"
|
||||
assert len(received_events) == 1, "Expected one event to be emitted"
|
||||
event = received_events[0]
|
||||
assert isinstance(event, ToolUsageFinishedEvent)
|
||||
|
||||
@@ -14,6 +14,7 @@ from crewai.events.listeners.tracing.trace_listener import (
|
||||
)
|
||||
from crewai.events.listeners.tracing.types import TraceEvent
|
||||
from crewai.flow.flow import Flow, start
|
||||
from tests.utils import wait_for_event_handlers
|
||||
|
||||
|
||||
class TestTraceListenerSetup:
|
||||
@@ -39,38 +40,44 @@ class TestTraceListenerSetup:
|
||||
):
|
||||
yield
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_event_bus(self):
|
||||
"""Clear event bus listeners before and after each test"""
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
|
||||
# Store original handlers
|
||||
original_handlers = crewai_event_bus._handlers.copy()
|
||||
|
||||
# Clear for test
|
||||
crewai_event_bus._handlers.clear()
|
||||
|
||||
yield
|
||||
|
||||
# Restore original state
|
||||
crewai_event_bus._handlers.clear()
|
||||
crewai_event_bus._handlers.update(original_handlers)
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_tracing_singletons(self):
|
||||
"""Reset tracing singleton instances between tests"""
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.event_listener import EventListener
|
||||
|
||||
# Clear event bus handlers BEFORE creating any new singletons
|
||||
with crewai_event_bus._rwlock.w_locked():
|
||||
crewai_event_bus._sync_handlers = {}
|
||||
crewai_event_bus._async_handlers = {}
|
||||
crewai_event_bus._handler_dependencies = {}
|
||||
crewai_event_bus._execution_plan_cache = {}
|
||||
|
||||
# Reset TraceCollectionListener singleton
|
||||
if hasattr(TraceCollectionListener, "_instance"):
|
||||
TraceCollectionListener._instance = None
|
||||
TraceCollectionListener._initialized = False
|
||||
|
||||
# Reset EventListener singleton
|
||||
if hasattr(EventListener, "_instance"):
|
||||
EventListener._instance = None
|
||||
|
||||
yield
|
||||
|
||||
# Clean up after test
|
||||
with crewai_event_bus._rwlock.w_locked():
|
||||
crewai_event_bus._sync_handlers = {}
|
||||
crewai_event_bus._async_handlers = {}
|
||||
crewai_event_bus._handler_dependencies = {}
|
||||
crewai_event_bus._execution_plan_cache = {}
|
||||
|
||||
if hasattr(TraceCollectionListener, "_instance"):
|
||||
TraceCollectionListener._instance = None
|
||||
TraceCollectionListener._initialized = False
|
||||
|
||||
if hasattr(EventListener, "_instance"):
|
||||
EventListener._instance = None
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_plus_api_calls(self):
|
||||
"""Mock all PlusAPI HTTP calls to avoid network requests"""
|
||||
@@ -167,15 +174,26 @@ class TestTraceListenerSetup:
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
|
||||
trace_listener = None
|
||||
for handler_list in crewai_event_bus._handlers.values():
|
||||
for handler in handler_list:
|
||||
if hasattr(handler, "__self__") and isinstance(
|
||||
handler.__self__, TraceCollectionListener
|
||||
):
|
||||
trace_listener = handler.__self__
|
||||
with crewai_event_bus._rwlock.r_locked():
|
||||
for handler_set in crewai_event_bus._sync_handlers.values():
|
||||
for handler in handler_set:
|
||||
if hasattr(handler, "__self__") and isinstance(
|
||||
handler.__self__, TraceCollectionListener
|
||||
):
|
||||
trace_listener = handler.__self__
|
||||
break
|
||||
if trace_listener:
|
||||
break
|
||||
if trace_listener:
|
||||
break
|
||||
if not trace_listener:
|
||||
for handler_set in crewai_event_bus._async_handlers.values():
|
||||
for handler in handler_set:
|
||||
if hasattr(handler, "__self__") and isinstance(
|
||||
handler.__self__, TraceCollectionListener
|
||||
):
|
||||
trace_listener = handler.__self__
|
||||
break
|
||||
if trace_listener:
|
||||
break
|
||||
|
||||
if not trace_listener:
|
||||
pytest.skip(
|
||||
@@ -221,6 +239,7 @@ class TestTraceListenerSetup:
|
||||
wraps=trace_listener.batch_manager.add_event,
|
||||
) as add_event_mock:
|
||||
crew.kickoff()
|
||||
wait_for_event_handlers()
|
||||
|
||||
assert add_event_mock.call_count >= 2
|
||||
|
||||
@@ -267,24 +286,22 @@ class TestTraceListenerSetup:
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
|
||||
trace_handlers = []
|
||||
for handlers in crewai_event_bus._handlers.values():
|
||||
for handler in handlers:
|
||||
if hasattr(handler, "__self__") and isinstance(
|
||||
handler.__self__, TraceCollectionListener
|
||||
):
|
||||
trace_handlers.append(handler)
|
||||
elif hasattr(handler, "__name__") and any(
|
||||
trace_name in handler.__name__
|
||||
for trace_name in [
|
||||
"on_crew_started",
|
||||
"on_crew_completed",
|
||||
"on_flow_started",
|
||||
]
|
||||
):
|
||||
trace_handlers.append(handler)
|
||||
with crewai_event_bus._rwlock.r_locked():
|
||||
for handlers in crewai_event_bus._sync_handlers.values():
|
||||
for handler in handlers:
|
||||
if hasattr(handler, "__self__") and isinstance(
|
||||
handler.__self__, TraceCollectionListener
|
||||
):
|
||||
trace_handlers.append(handler)
|
||||
for handlers in crewai_event_bus._async_handlers.values():
|
||||
for handler in handlers:
|
||||
if hasattr(handler, "__self__") and isinstance(
|
||||
handler.__self__, TraceCollectionListener
|
||||
):
|
||||
trace_handlers.append(handler)
|
||||
|
||||
assert len(trace_handlers) == 0, (
|
||||
f"Found {len(trace_handlers)} trace handlers when tracing should be disabled"
|
||||
f"Found {len(trace_handlers)} TraceCollectionListener handlers when tracing should be disabled"
|
||||
)
|
||||
|
||||
def test_trace_listener_setup_correctly_for_crew(self):
|
||||
@@ -385,6 +402,7 @@ class TestTraceListenerSetup:
|
||||
):
|
||||
crew = Crew(agents=[agent], tasks=[task], tracing=True)
|
||||
crew.kickoff()
|
||||
wait_for_event_handlers()
|
||||
|
||||
mock_plus_api_class.assert_called_with(api_key="mock_token_12345")
|
||||
|
||||
@@ -396,15 +414,33 @@ class TestTraceListenerSetup:
|
||||
def teardown_method(self):
|
||||
"""Cleanup after each test method"""
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.event_listener import EventListener
|
||||
|
||||
crewai_event_bus._handlers.clear()
|
||||
with crewai_event_bus._rwlock.w_locked():
|
||||
crewai_event_bus._sync_handlers = {}
|
||||
crewai_event_bus._async_handlers = {}
|
||||
crewai_event_bus._handler_dependencies = {}
|
||||
crewai_event_bus._execution_plan_cache = {}
|
||||
|
||||
# Reset EventListener singleton
|
||||
if hasattr(EventListener, "_instance"):
|
||||
EventListener._instance = None
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
"""Final cleanup after all tests in this class"""
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.event_listener import EventListener
|
||||
|
||||
crewai_event_bus._handlers.clear()
|
||||
with crewai_event_bus._rwlock.w_locked():
|
||||
crewai_event_bus._sync_handlers = {}
|
||||
crewai_event_bus._async_handlers = {}
|
||||
crewai_event_bus._handler_dependencies = {}
|
||||
crewai_event_bus._execution_plan_cache = {}
|
||||
|
||||
# Reset EventListener singleton
|
||||
if hasattr(EventListener, "_instance"):
|
||||
EventListener._instance = None
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_first_time_user_trace_collection_with_timeout(self, mock_plus_api_calls):
|
||||
@@ -466,6 +502,7 @@ class TestTraceListenerSetup:
|
||||
) as mock_add_event,
|
||||
):
|
||||
result = crew.kickoff()
|
||||
wait_for_event_handlers()
|
||||
assert result is not None
|
||||
|
||||
assert mock_handle_completion.call_count >= 1
|
||||
@@ -543,6 +580,7 @@ class TestTraceListenerSetup:
|
||||
)
|
||||
|
||||
crew.kickoff()
|
||||
wait_for_event_handlers()
|
||||
|
||||
assert mock_handle_completion.call_count >= 1, (
|
||||
"handle_execution_completion should be called"
|
||||
@@ -561,7 +599,6 @@ class TestTraceListenerSetup:
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_first_time_user_trace_consolidation_logic(self, mock_plus_api_calls):
|
||||
"""Test the consolidation logic for first-time users vs regular tracing"""
|
||||
|
||||
with (
|
||||
patch.dict(os.environ, {"CREWAI_TRACING_ENABLED": "false"}),
|
||||
patch(
|
||||
@@ -579,7 +616,9 @@ class TestTraceListenerSetup:
|
||||
):
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
|
||||
crewai_event_bus._handlers.clear()
|
||||
with crewai_event_bus._rwlock.w_locked():
|
||||
crewai_event_bus._sync_handlers = {}
|
||||
crewai_event_bus._async_handlers = {}
|
||||
|
||||
trace_listener = TraceCollectionListener()
|
||||
trace_listener.setup_listeners(crewai_event_bus)
|
||||
@@ -600,6 +639,9 @@ class TestTraceListenerSetup:
|
||||
with patch.object(TraceBatchManager, "initialize_batch") as mock_initialize:
|
||||
result = crew.kickoff()
|
||||
|
||||
assert trace_listener.batch_manager.wait_for_pending_events(timeout=5.0), (
|
||||
"Timeout waiting for trace event handlers to complete"
|
||||
)
|
||||
assert mock_initialize.call_count >= 1
|
||||
assert mock_initialize.call_args_list[0][1]["use_ephemeral"] is True
|
||||
assert result is not None
|
||||
@@ -700,6 +742,7 @@ class TestTraceListenerSetup:
|
||||
) as mock_mark_failed,
|
||||
):
|
||||
crew.kickoff()
|
||||
wait_for_event_handlers()
|
||||
|
||||
mock_mark_failed.assert_called_once()
|
||||
call_args = mock_mark_failed.call_args_list[0]
|
||||
|
||||
206
lib/crewai/tests/utilities/events/test_async_event_bus.py
Normal file
206
lib/crewai/tests/utilities/events/test_async_event_bus.py
Normal file
@@ -0,0 +1,206 @@
|
||||
"""Tests for async event handling in CrewAI event bus.
|
||||
|
||||
This module tests async handler registration, execution, and the aemit method.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
|
||||
|
||||
class AsyncTestEvent(BaseEvent):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_handler_execution():
|
||||
received_events = []
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(AsyncTestEvent)
|
||||
async def async_handler(source: object, event: BaseEvent) -> None:
|
||||
await asyncio.sleep(0.01)
|
||||
received_events.append(event)
|
||||
|
||||
event = AsyncTestEvent(type="async_test")
|
||||
crewai_event_bus.emit("test_source", event)
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert len(received_events) == 1
|
||||
assert received_events[0] == event
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aemit_with_async_handlers():
|
||||
received_events = []
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(AsyncTestEvent)
|
||||
async def async_handler(source: object, event: BaseEvent) -> None:
|
||||
await asyncio.sleep(0.01)
|
||||
received_events.append(event)
|
||||
|
||||
event = AsyncTestEvent(type="async_test")
|
||||
await crewai_event_bus.aemit("test_source", event)
|
||||
|
||||
assert len(received_events) == 1
|
||||
assert received_events[0] == event
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_async_handlers():
|
||||
received_events_1 = []
|
||||
received_events_2 = []
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(AsyncTestEvent)
|
||||
async def handler_1(source: object, event: BaseEvent) -> None:
|
||||
await asyncio.sleep(0.01)
|
||||
received_events_1.append(event)
|
||||
|
||||
@crewai_event_bus.on(AsyncTestEvent)
|
||||
async def handler_2(source: object, event: BaseEvent) -> None:
|
||||
await asyncio.sleep(0.02)
|
||||
received_events_2.append(event)
|
||||
|
||||
event = AsyncTestEvent(type="async_test")
|
||||
await crewai_event_bus.aemit("test_source", event)
|
||||
|
||||
assert len(received_events_1) == 1
|
||||
assert len(received_events_2) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_sync_and_async_handlers():
|
||||
sync_events = []
|
||||
async_events = []
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(AsyncTestEvent)
|
||||
def sync_handler(source: object, event: BaseEvent) -> None:
|
||||
sync_events.append(event)
|
||||
|
||||
@crewai_event_bus.on(AsyncTestEvent)
|
||||
async def async_handler(source: object, event: BaseEvent) -> None:
|
||||
await asyncio.sleep(0.01)
|
||||
async_events.append(event)
|
||||
|
||||
event = AsyncTestEvent(type="mixed_test")
|
||||
crewai_event_bus.emit("test_source", event)
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert len(sync_events) == 1
|
||||
assert len(async_events) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_handler_error_handling():
|
||||
successful_handler_called = []
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(AsyncTestEvent)
|
||||
async def failing_handler(source: object, event: BaseEvent) -> None:
|
||||
raise ValueError("Async handler error")
|
||||
|
||||
@crewai_event_bus.on(AsyncTestEvent)
|
||||
async def successful_handler(source: object, event: BaseEvent) -> None:
|
||||
await asyncio.sleep(0.01)
|
||||
successful_handler_called.append(True)
|
||||
|
||||
event = AsyncTestEvent(type="error_test")
|
||||
await crewai_event_bus.aemit("test_source", event)
|
||||
|
||||
assert len(successful_handler_called) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aemit_with_no_handlers():
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
event = AsyncTestEvent(type="no_handlers")
|
||||
await crewai_event_bus.aemit("test_source", event)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_handler_registration_via_register_handler():
|
||||
received_events = []
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
async def custom_async_handler(source: object, event: BaseEvent) -> None:
|
||||
await asyncio.sleep(0.01)
|
||||
received_events.append(event)
|
||||
|
||||
crewai_event_bus.register_handler(AsyncTestEvent, custom_async_handler)
|
||||
|
||||
event = AsyncTestEvent(type="register_test")
|
||||
await crewai_event_bus.aemit("test_source", event)
|
||||
|
||||
assert len(received_events) == 1
|
||||
assert received_events[0] == event
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emit_async_handlers_fire_and_forget():
|
||||
received_events = []
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(AsyncTestEvent)
|
||||
async def slow_async_handler(source: object, event: BaseEvent) -> None:
|
||||
await asyncio.sleep(0.05)
|
||||
received_events.append(event)
|
||||
|
||||
event = AsyncTestEvent(type="fire_forget_test")
|
||||
crewai_event_bus.emit("test_source", event)
|
||||
|
||||
assert len(received_events) == 0
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert len(received_events) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scoped_handlers_with_async():
|
||||
received_before = []
|
||||
received_during = []
|
||||
received_after = []
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(AsyncTestEvent)
|
||||
async def before_handler(source: object, event: BaseEvent) -> None:
|
||||
received_before.append(event)
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(AsyncTestEvent)
|
||||
async def scoped_handler(source: object, event: BaseEvent) -> None:
|
||||
received_during.append(event)
|
||||
|
||||
event1 = AsyncTestEvent(type="during_scope")
|
||||
await crewai_event_bus.aemit("test_source", event1)
|
||||
|
||||
assert len(received_before) == 0
|
||||
assert len(received_during) == 1
|
||||
|
||||
@crewai_event_bus.on(AsyncTestEvent)
|
||||
async def after_handler(source: object, event: BaseEvent) -> None:
|
||||
received_after.append(event)
|
||||
|
||||
event2 = AsyncTestEvent(type="after_scope")
|
||||
await crewai_event_bus.aemit("test_source", event2)
|
||||
|
||||
assert len(received_before) == 1
|
||||
assert len(received_during) == 1
|
||||
assert len(received_after) == 1
|
||||
@@ -1,3 +1,4 @@
|
||||
import threading
|
||||
from unittest.mock import Mock
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
@@ -21,27 +22,42 @@ def test_specific_event_handler():
|
||||
mock_handler.assert_called_once_with("source_object", event)
|
||||
|
||||
|
||||
def test_wildcard_event_handler():
|
||||
mock_handler = Mock()
|
||||
def test_multiple_handlers_same_event():
|
||||
"""Test that multiple handlers can be registered for the same event type."""
|
||||
mock_handler1 = Mock()
|
||||
mock_handler2 = Mock()
|
||||
|
||||
@crewai_event_bus.on(BaseEvent)
|
||||
def handler(source, event):
|
||||
mock_handler(source, event)
|
||||
@crewai_event_bus.on(TestEvent)
|
||||
def handler1(source, event):
|
||||
mock_handler1(source, event)
|
||||
|
||||
@crewai_event_bus.on(TestEvent)
|
||||
def handler2(source, event):
|
||||
mock_handler2(source, event)
|
||||
|
||||
event = TestEvent(type="test_event")
|
||||
crewai_event_bus.emit("source_object", event)
|
||||
|
||||
mock_handler.assert_called_once_with("source_object", event)
|
||||
mock_handler1.assert_called_once_with("source_object", event)
|
||||
mock_handler2.assert_called_once_with("source_object", event)
|
||||
|
||||
|
||||
def test_event_bus_error_handling(capfd):
|
||||
@crewai_event_bus.on(BaseEvent)
|
||||
def test_event_bus_error_handling():
|
||||
"""Test that handler exceptions are caught and don't break the event bus."""
|
||||
called = threading.Event()
|
||||
error_caught = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(TestEvent)
|
||||
def broken_handler(source, event):
|
||||
called.set()
|
||||
raise ValueError("Simulated handler failure")
|
||||
|
||||
@crewai_event_bus.on(TestEvent)
|
||||
def working_handler(source, event):
|
||||
error_caught.set()
|
||||
|
||||
event = TestEvent(type="test_event")
|
||||
crewai_event_bus.emit("source_object", event)
|
||||
|
||||
out, err = capfd.readouterr()
|
||||
assert "Simulated handler failure" in out
|
||||
assert "Handler 'broken_handler' failed" in out
|
||||
assert called.wait(timeout=2), "Broken handler was never called"
|
||||
assert error_caught.wait(timeout=2), "Working handler was never called after error"
|
||||
|
||||
264
lib/crewai/tests/utilities/events/test_rw_lock.py
Normal file
264
lib/crewai/tests/utilities/events/test_rw_lock.py
Normal file
@@ -0,0 +1,264 @@
|
||||
"""Tests for read-write lock implementation.
|
||||
|
||||
This module tests the RWLock class for correct concurrent read and write behavior.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
|
||||
from crewai.events.utils.rw_lock import RWLock
|
||||
|
||||
|
||||
def test_multiple_readers_concurrent():
|
||||
lock = RWLock()
|
||||
active_readers = [0]
|
||||
max_concurrent_readers = [0]
|
||||
lock_for_counters = threading.Lock()
|
||||
|
||||
def reader(reader_id: int) -> None:
|
||||
with lock.r_locked():
|
||||
with lock_for_counters:
|
||||
active_readers[0] += 1
|
||||
max_concurrent_readers[0] = max(
|
||||
max_concurrent_readers[0], active_readers[0]
|
||||
)
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
with lock_for_counters:
|
||||
active_readers[0] -= 1
|
||||
|
||||
threads = [threading.Thread(target=reader, args=(i,)) for i in range(5)]
|
||||
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
assert max_concurrent_readers[0] == 5
|
||||
|
||||
|
||||
def test_writer_blocks_readers():
|
||||
lock = RWLock()
|
||||
writer_holding_lock = [False]
|
||||
reader_accessed_during_write = [False]
|
||||
|
||||
def writer() -> None:
|
||||
with lock.w_locked():
|
||||
writer_holding_lock[0] = True
|
||||
time.sleep(0.2)
|
||||
writer_holding_lock[0] = False
|
||||
|
||||
def reader() -> None:
|
||||
time.sleep(0.05)
|
||||
with lock.r_locked():
|
||||
if writer_holding_lock[0]:
|
||||
reader_accessed_during_write[0] = True
|
||||
|
||||
writer_thread = threading.Thread(target=writer)
|
||||
reader_thread = threading.Thread(target=reader)
|
||||
|
||||
writer_thread.start()
|
||||
reader_thread.start()
|
||||
|
||||
writer_thread.join()
|
||||
reader_thread.join()
|
||||
|
||||
assert not reader_accessed_during_write[0]
|
||||
|
||||
|
||||
def test_writer_blocks_other_writers():
|
||||
lock = RWLock()
|
||||
execution_order: list[int] = []
|
||||
lock_for_order = threading.Lock()
|
||||
|
||||
def writer(writer_id: int) -> None:
|
||||
with lock.w_locked():
|
||||
with lock_for_order:
|
||||
execution_order.append(writer_id)
|
||||
time.sleep(0.1)
|
||||
|
||||
threads = [threading.Thread(target=writer, args=(i,)) for i in range(3)]
|
||||
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
assert len(execution_order) == 3
|
||||
assert len(set(execution_order)) == 3
|
||||
|
||||
|
||||
def test_readers_block_writers():
|
||||
lock = RWLock()
|
||||
reader_count = [0]
|
||||
writer_accessed_during_read = [False]
|
||||
lock_for_counters = threading.Lock()
|
||||
|
||||
def reader() -> None:
|
||||
with lock.r_locked():
|
||||
with lock_for_counters:
|
||||
reader_count[0] += 1
|
||||
time.sleep(0.2)
|
||||
with lock_for_counters:
|
||||
reader_count[0] -= 1
|
||||
|
||||
def writer() -> None:
|
||||
time.sleep(0.05)
|
||||
with lock.w_locked():
|
||||
with lock_for_counters:
|
||||
if reader_count[0] > 0:
|
||||
writer_accessed_during_read[0] = True
|
||||
|
||||
reader_thread = threading.Thread(target=reader)
|
||||
writer_thread = threading.Thread(target=writer)
|
||||
|
||||
reader_thread.start()
|
||||
writer_thread.start()
|
||||
|
||||
reader_thread.join()
|
||||
writer_thread.join()
|
||||
|
||||
assert not writer_accessed_during_read[0]
|
||||
|
||||
|
||||
def test_alternating_readers_and_writers():
|
||||
lock = RWLock()
|
||||
operations: list[str] = []
|
||||
lock_for_operations = threading.Lock()
|
||||
|
||||
def reader(reader_id: int) -> None:
|
||||
with lock.r_locked():
|
||||
with lock_for_operations:
|
||||
operations.append(f"r{reader_id}_start")
|
||||
time.sleep(0.05)
|
||||
with lock_for_operations:
|
||||
operations.append(f"r{reader_id}_end")
|
||||
|
||||
def writer(writer_id: int) -> None:
|
||||
with lock.w_locked():
|
||||
with lock_for_operations:
|
||||
operations.append(f"w{writer_id}_start")
|
||||
time.sleep(0.05)
|
||||
with lock_for_operations:
|
||||
operations.append(f"w{writer_id}_end")
|
||||
|
||||
threads = [
|
||||
threading.Thread(target=reader, args=(0,)),
|
||||
threading.Thread(target=writer, args=(0,)),
|
||||
threading.Thread(target=reader, args=(1,)),
|
||||
threading.Thread(target=writer, args=(1,)),
|
||||
threading.Thread(target=reader, args=(2,)),
|
||||
]
|
||||
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
assert len(operations) == 10
|
||||
|
||||
start_ops = [op for op in operations if "_start" in op]
|
||||
end_ops = [op for op in operations if "_end" in op]
|
||||
assert len(start_ops) == 5
|
||||
assert len(end_ops) == 5
|
||||
|
||||
|
||||
def test_context_manager_releases_on_exception():
|
||||
lock = RWLock()
|
||||
exception_raised = False
|
||||
|
||||
try:
|
||||
with lock.r_locked():
|
||||
raise ValueError("Test exception")
|
||||
except ValueError:
|
||||
exception_raised = True
|
||||
|
||||
assert exception_raised
|
||||
|
||||
acquired = False
|
||||
with lock.w_locked():
|
||||
acquired = True
|
||||
|
||||
assert acquired
|
||||
|
||||
|
||||
def test_write_lock_releases_on_exception():
|
||||
lock = RWLock()
|
||||
exception_raised = False
|
||||
|
||||
try:
|
||||
with lock.w_locked():
|
||||
raise ValueError("Test exception")
|
||||
except ValueError:
|
||||
exception_raised = True
|
||||
|
||||
assert exception_raised
|
||||
|
||||
acquired = False
|
||||
with lock.r_locked():
|
||||
acquired = True
|
||||
|
||||
assert acquired
|
||||
|
||||
|
||||
def test_stress_many_readers_few_writers():
|
||||
lock = RWLock()
|
||||
read_count = [0]
|
||||
write_count = [0]
|
||||
lock_for_counters = threading.Lock()
|
||||
|
||||
def reader() -> None:
|
||||
for _ in range(10):
|
||||
with lock.r_locked():
|
||||
with lock_for_counters:
|
||||
read_count[0] += 1
|
||||
time.sleep(0.001)
|
||||
|
||||
def writer() -> None:
|
||||
for _ in range(5):
|
||||
with lock.w_locked():
|
||||
with lock_for_counters:
|
||||
write_count[0] += 1
|
||||
time.sleep(0.01)
|
||||
|
||||
reader_threads = [threading.Thread(target=reader) for _ in range(10)]
|
||||
writer_threads = [threading.Thread(target=writer) for _ in range(2)]
|
||||
|
||||
all_threads = reader_threads + writer_threads
|
||||
|
||||
for thread in all_threads:
|
||||
thread.start()
|
||||
|
||||
for thread in all_threads:
|
||||
thread.join()
|
||||
|
||||
assert read_count[0] == 100
|
||||
assert write_count[0] == 10
|
||||
|
||||
|
||||
def test_nested_read_locks_same_thread():
|
||||
lock = RWLock()
|
||||
nested_acquired = False
|
||||
|
||||
with lock.r_locked():
|
||||
with lock.r_locked():
|
||||
nested_acquired = True
|
||||
|
||||
assert nested_acquired
|
||||
|
||||
|
||||
def test_manual_acquire_release():
|
||||
lock = RWLock()
|
||||
|
||||
lock.r_acquire()
|
||||
lock.r_release()
|
||||
|
||||
lock.w_acquire()
|
||||
lock.w_release()
|
||||
|
||||
with lock.r_locked():
|
||||
pass
|
||||
247
lib/crewai/tests/utilities/events/test_shutdown.py
Normal file
247
lib/crewai/tests/utilities/events/test_shutdown.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""Tests for event bus shutdown and cleanup behavior.
|
||||
|
||||
This module tests graceful shutdown, task completion, and cleanup operations.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
from crewai.events.event_bus import CrewAIEventsBus
|
||||
|
||||
|
||||
class ShutdownTestEvent(BaseEvent):
|
||||
pass
|
||||
|
||||
|
||||
def test_shutdown_prevents_new_events():
|
||||
bus = CrewAIEventsBus()
|
||||
received_events = []
|
||||
|
||||
with bus.scoped_handlers():
|
||||
|
||||
@bus.on(ShutdownTestEvent)
|
||||
def handler(source: object, event: BaseEvent) -> None:
|
||||
received_events.append(event)
|
||||
|
||||
bus._shutting_down = True
|
||||
|
||||
event = ShutdownTestEvent(type="after_shutdown")
|
||||
bus.emit("test_source", event)
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
assert len(received_events) == 0
|
||||
|
||||
bus._shutting_down = False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aemit_during_shutdown():
|
||||
bus = CrewAIEventsBus()
|
||||
received_events = []
|
||||
|
||||
with bus.scoped_handlers():
|
||||
|
||||
@bus.on(ShutdownTestEvent)
|
||||
async def handler(source: object, event: BaseEvent) -> None:
|
||||
received_events.append(event)
|
||||
|
||||
bus._shutting_down = True
|
||||
|
||||
event = ShutdownTestEvent(type="aemit_during_shutdown")
|
||||
await bus.aemit("test_source", event)
|
||||
|
||||
assert len(received_events) == 0
|
||||
|
||||
bus._shutting_down = False
|
||||
|
||||
|
||||
def test_shutdown_flag_prevents_emit():
|
||||
bus = CrewAIEventsBus()
|
||||
emitted_count = [0]
|
||||
|
||||
with bus.scoped_handlers():
|
||||
|
||||
@bus.on(ShutdownTestEvent)
|
||||
def handler(source: object, event: BaseEvent) -> None:
|
||||
emitted_count[0] += 1
|
||||
|
||||
event1 = ShutdownTestEvent(type="before_shutdown")
|
||||
bus.emit("test_source", event1)
|
||||
|
||||
time.sleep(0.1)
|
||||
assert emitted_count[0] == 1
|
||||
|
||||
bus._shutting_down = True
|
||||
|
||||
event2 = ShutdownTestEvent(type="during_shutdown")
|
||||
bus.emit("test_source", event2)
|
||||
|
||||
time.sleep(0.1)
|
||||
assert emitted_count[0] == 1
|
||||
|
||||
bus._shutting_down = False
|
||||
|
||||
|
||||
def test_concurrent_access_during_shutdown_flag():
|
||||
bus = CrewAIEventsBus()
|
||||
received_events = []
|
||||
lock = threading.Lock()
|
||||
|
||||
with bus.scoped_handlers():
|
||||
|
||||
@bus.on(ShutdownTestEvent)
|
||||
def handler(source: object, event: BaseEvent) -> None:
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
|
||||
def emit_events() -> None:
|
||||
for i in range(10):
|
||||
event = ShutdownTestEvent(type=f"event_{i}")
|
||||
bus.emit("source", event)
|
||||
time.sleep(0.01)
|
||||
|
||||
def set_shutdown_flag() -> None:
|
||||
time.sleep(0.05)
|
||||
bus._shutting_down = True
|
||||
|
||||
emit_thread = threading.Thread(target=emit_events)
|
||||
shutdown_thread = threading.Thread(target=set_shutdown_flag)
|
||||
|
||||
emit_thread.start()
|
||||
shutdown_thread.start()
|
||||
|
||||
emit_thread.join()
|
||||
shutdown_thread.join()
|
||||
|
||||
time.sleep(0.2)
|
||||
|
||||
assert len(received_events) < 10
|
||||
assert len(received_events) > 0
|
||||
|
||||
bus._shutting_down = False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_handlers_complete_before_shutdown_flag():
|
||||
bus = CrewAIEventsBus()
|
||||
completed_handlers = []
|
||||
|
||||
with bus.scoped_handlers():
|
||||
|
||||
@bus.on(ShutdownTestEvent)
|
||||
async def async_handler(source: object, event: BaseEvent) -> None:
|
||||
await asyncio.sleep(0.05)
|
||||
if not bus._shutting_down:
|
||||
completed_handlers.append(event)
|
||||
|
||||
for i in range(5):
|
||||
event = ShutdownTestEvent(type=f"event_{i}")
|
||||
bus.emit("source", event)
|
||||
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
assert len(completed_handlers) == 5
|
||||
|
||||
|
||||
def test_scoped_handlers_cleanup():
|
||||
bus = CrewAIEventsBus()
|
||||
received_before = []
|
||||
received_during = []
|
||||
received_after = []
|
||||
|
||||
with bus.scoped_handlers():
|
||||
|
||||
@bus.on(ShutdownTestEvent)
|
||||
def before_handler(source: object, event: BaseEvent) -> None:
|
||||
received_before.append(event)
|
||||
|
||||
with bus.scoped_handlers():
|
||||
|
||||
@bus.on(ShutdownTestEvent)
|
||||
def during_handler(source: object, event: BaseEvent) -> None:
|
||||
received_during.append(event)
|
||||
|
||||
event1 = ShutdownTestEvent(type="during")
|
||||
bus.emit("source", event1)
|
||||
time.sleep(0.1)
|
||||
|
||||
assert len(received_before) == 0
|
||||
assert len(received_during) == 1
|
||||
|
||||
event2 = ShutdownTestEvent(type="after_inner_scope")
|
||||
bus.emit("source", event2)
|
||||
time.sleep(0.1)
|
||||
|
||||
assert len(received_before) == 1
|
||||
assert len(received_during) == 1
|
||||
|
||||
event3 = ShutdownTestEvent(type="after_outer_scope")
|
||||
bus.emit("source", event3)
|
||||
time.sleep(0.1)
|
||||
|
||||
assert len(received_before) == 1
|
||||
assert len(received_during) == 1
|
||||
assert len(received_after) == 0
|
||||
|
||||
|
||||
def test_handler_registration_thread_safety():
|
||||
bus = CrewAIEventsBus()
|
||||
handlers_registered = [0]
|
||||
lock = threading.Lock()
|
||||
|
||||
with bus.scoped_handlers():
|
||||
|
||||
def register_handlers() -> None:
|
||||
for _ in range(20):
|
||||
|
||||
@bus.on(ShutdownTestEvent)
|
||||
def handler(source: object, event: BaseEvent) -> None:
|
||||
pass
|
||||
|
||||
with lock:
|
||||
handlers_registered[0] += 1
|
||||
|
||||
time.sleep(0.001)
|
||||
|
||||
threads = [threading.Thread(target=register_handlers) for _ in range(3)]
|
||||
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
assert handlers_registered[0] == 60
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_sync_async_handler_execution():
|
||||
bus = CrewAIEventsBus()
|
||||
sync_executed = []
|
||||
async_executed = []
|
||||
|
||||
with bus.scoped_handlers():
|
||||
|
||||
@bus.on(ShutdownTestEvent)
|
||||
def sync_handler(source: object, event: BaseEvent) -> None:
|
||||
time.sleep(0.01)
|
||||
sync_executed.append(event)
|
||||
|
||||
@bus.on(ShutdownTestEvent)
|
||||
async def async_handler(source: object, event: BaseEvent) -> None:
|
||||
await asyncio.sleep(0.01)
|
||||
async_executed.append(event)
|
||||
|
||||
for i in range(5):
|
||||
event = ShutdownTestEvent(type=f"event_{i}")
|
||||
bus.emit("source", event)
|
||||
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
assert len(sync_executed) == 5
|
||||
assert len(async_executed) == 5
|
||||
189
lib/crewai/tests/utilities/events/test_thread_safety.py
Normal file
189
lib/crewai/tests/utilities/events/test_thread_safety.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""Tests for thread safety in CrewAI event bus.
|
||||
|
||||
This module tests concurrent event emission and handler registration.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
|
||||
|
||||
class ThreadSafetyTestEvent(BaseEvent):
|
||||
pass
|
||||
|
||||
|
||||
def test_concurrent_emit_from_multiple_threads():
|
||||
received_events: list[BaseEvent] = []
|
||||
lock = threading.Lock()
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(ThreadSafetyTestEvent)
|
||||
def handler(source: object, event: BaseEvent) -> None:
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
|
||||
threads: list[threading.Thread] = []
|
||||
num_threads = 10
|
||||
events_per_thread = 10
|
||||
|
||||
def emit_events(thread_id: int) -> None:
|
||||
for i in range(events_per_thread):
|
||||
event = ThreadSafetyTestEvent(type=f"thread_{thread_id}_event_{i}")
|
||||
crewai_event_bus.emit(f"source_{thread_id}", event)
|
||||
|
||||
for i in range(num_threads):
|
||||
thread = threading.Thread(target=emit_events, args=(i,))
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
time.sleep(0.5)
|
||||
|
||||
assert len(received_events) == num_threads * events_per_thread
|
||||
|
||||
|
||||
def test_concurrent_handler_registration():
|
||||
handlers_executed: list[int] = []
|
||||
lock = threading.Lock()
|
||||
|
||||
def create_handler(handler_id: int) -> Callable[[object, BaseEvent], None]:
|
||||
def handler(source: object, event: BaseEvent) -> None:
|
||||
with lock:
|
||||
handlers_executed.append(handler_id)
|
||||
|
||||
return handler
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
threads: list[threading.Thread] = []
|
||||
num_handlers = 20
|
||||
|
||||
def register_handler(handler_id: int) -> None:
|
||||
crewai_event_bus.register_handler(
|
||||
ThreadSafetyTestEvent, create_handler(handler_id)
|
||||
)
|
||||
|
||||
for i in range(num_handlers):
|
||||
thread = threading.Thread(target=register_handler, args=(i,))
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
event = ThreadSafetyTestEvent(type="registration_test")
|
||||
crewai_event_bus.emit("test_source", event)
|
||||
|
||||
time.sleep(0.5)
|
||||
|
||||
assert len(handlers_executed) == num_handlers
|
||||
assert set(handlers_executed) == set(range(num_handlers))
|
||||
|
||||
|
||||
def test_concurrent_emit_and_registration():
|
||||
received_events: list[BaseEvent] = []
|
||||
lock = threading.Lock()
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
def emit_continuously() -> None:
|
||||
for i in range(50):
|
||||
event = ThreadSafetyTestEvent(type=f"emit_event_{i}")
|
||||
crewai_event_bus.emit("emitter", event)
|
||||
time.sleep(0.001)
|
||||
|
||||
def register_continuously() -> None:
|
||||
for _ in range(10):
|
||||
|
||||
@crewai_event_bus.on(ThreadSafetyTestEvent)
|
||||
def handler(source: object, event: BaseEvent) -> None:
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
|
||||
time.sleep(0.005)
|
||||
|
||||
emit_thread = threading.Thread(target=emit_continuously)
|
||||
register_thread = threading.Thread(target=register_continuously)
|
||||
|
||||
emit_thread.start()
|
||||
register_thread.start()
|
||||
|
||||
emit_thread.join()
|
||||
register_thread.join()
|
||||
|
||||
time.sleep(0.5)
|
||||
|
||||
assert len(received_events) > 0
|
||||
|
||||
|
||||
def test_stress_test_rapid_emit():
|
||||
received_count = [0]
|
||||
lock = threading.Lock()
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(ThreadSafetyTestEvent)
|
||||
def counter_handler(source: object, event: BaseEvent) -> None:
|
||||
with lock:
|
||||
received_count[0] += 1
|
||||
|
||||
num_events = 1000
|
||||
|
||||
for i in range(num_events):
|
||||
event = ThreadSafetyTestEvent(type=f"rapid_event_{i}")
|
||||
crewai_event_bus.emit("rapid_source", event)
|
||||
|
||||
time.sleep(1.0)
|
||||
|
||||
assert received_count[0] == num_events
|
||||
|
||||
|
||||
def test_multiple_event_types_concurrent():
|
||||
class EventTypeA(BaseEvent):
|
||||
pass
|
||||
|
||||
class EventTypeB(BaseEvent):
|
||||
pass
|
||||
|
||||
received_a: list[BaseEvent] = []
|
||||
received_b: list[BaseEvent] = []
|
||||
lock = threading.Lock()
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(EventTypeA)
|
||||
def handler_a(source: object, event: BaseEvent) -> None:
|
||||
with lock:
|
||||
received_a.append(event)
|
||||
|
||||
@crewai_event_bus.on(EventTypeB)
|
||||
def handler_b(source: object, event: BaseEvent) -> None:
|
||||
with lock:
|
||||
received_b.append(event)
|
||||
|
||||
def emit_type_a() -> None:
|
||||
for i in range(50):
|
||||
crewai_event_bus.emit("source_a", EventTypeA(type=f"type_a_{i}"))
|
||||
|
||||
def emit_type_b() -> None:
|
||||
for i in range(50):
|
||||
crewai_event_bus.emit("source_b", EventTypeB(type=f"type_b_{i}"))
|
||||
|
||||
thread_a = threading.Thread(target=emit_type_a)
|
||||
thread_b = threading.Thread(target=emit_type_b)
|
||||
|
||||
thread_a.start()
|
||||
thread_b.start()
|
||||
|
||||
thread_a.join()
|
||||
thread_b.join()
|
||||
|
||||
time.sleep(0.5)
|
||||
|
||||
assert len(received_a) == 50
|
||||
assert len(received_b) == 50
|
||||
@@ -1,3 +1,4 @@
|
||||
import threading
|
||||
from datetime import datetime
|
||||
import os
|
||||
from unittest.mock import Mock, patch
|
||||
@@ -49,6 +50,8 @@ from crewai.tools.base_tool import BaseTool
|
||||
from pydantic import Field
|
||||
import pytest
|
||||
|
||||
from ..utils import wait_for_event_handlers
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def vcr_config(request) -> dict:
|
||||
@@ -118,6 +121,7 @@ def test_crew_emits_start_kickoff_event(
|
||||
# Now when Crew creates EventListener, it will use our mocked telemetry
|
||||
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
|
||||
crew.kickoff()
|
||||
wait_for_event_handlers()
|
||||
|
||||
mock_telemetry.crew_execution_span.assert_called_once_with(crew, None)
|
||||
mock_telemetry.end_crew.assert_called_once_with(crew, "hi")
|
||||
@@ -131,15 +135,20 @@ def test_crew_emits_start_kickoff_event(
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_crew_emits_end_kickoff_event(base_agent, base_task):
|
||||
received_events = []
|
||||
event_received = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(CrewKickoffCompletedEvent)
|
||||
def handle_crew_end(source, event):
|
||||
received_events.append(event)
|
||||
event_received.set()
|
||||
|
||||
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
|
||||
|
||||
crew.kickoff()
|
||||
|
||||
assert event_received.wait(timeout=5), (
|
||||
"Timeout waiting for crew kickoff completed event"
|
||||
)
|
||||
assert len(received_events) == 1
|
||||
assert received_events[0].crew_name == "TestCrew"
|
||||
assert isinstance(received_events[0].timestamp, datetime)
|
||||
@@ -165,6 +174,7 @@ def test_crew_emits_test_kickoff_type_event(base_agent, base_task):
|
||||
eval_llm = LLM(model="gpt-4o-mini")
|
||||
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
|
||||
crew.test(n_iterations=1, eval_llm=eval_llm)
|
||||
wait_for_event_handlers()
|
||||
|
||||
assert len(received_events) == 3
|
||||
assert received_events[0].crew_name == "TestCrew"
|
||||
@@ -181,40 +191,44 @@ def test_crew_emits_test_kickoff_type_event(base_agent, base_task):
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_crew_emits_kickoff_failed_event(base_agent, base_task):
|
||||
received_events = []
|
||||
event_received = threading.Event()
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
@crewai_event_bus.on(CrewKickoffFailedEvent)
|
||||
def handle_crew_failed(source, event):
|
||||
received_events.append(event)
|
||||
event_received.set()
|
||||
|
||||
@crewai_event_bus.on(CrewKickoffFailedEvent)
|
||||
def handle_crew_failed(source, event):
|
||||
received_events.append(event)
|
||||
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
|
||||
|
||||
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
|
||||
with patch.object(Crew, "_execute_tasks") as mock_execute:
|
||||
error_message = "Simulated crew kickoff failure"
|
||||
mock_execute.side_effect = Exception(error_message)
|
||||
|
||||
with patch.object(Crew, "_execute_tasks") as mock_execute:
|
||||
error_message = "Simulated crew kickoff failure"
|
||||
mock_execute.side_effect = Exception(error_message)
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
crew.kickoff()
|
||||
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
crew.kickoff()
|
||||
|
||||
assert len(received_events) == 1
|
||||
assert received_events[0].error == error_message
|
||||
assert isinstance(received_events[0].timestamp, datetime)
|
||||
assert received_events[0].type == "crew_kickoff_failed"
|
||||
assert event_received.wait(timeout=5), "Timeout waiting for failed event"
|
||||
assert len(received_events) == 1
|
||||
assert received_events[0].error == error_message
|
||||
assert isinstance(received_events[0].timestamp, datetime)
|
||||
assert received_events[0].type == "crew_kickoff_failed"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_crew_emits_start_task_event(base_agent, base_task):
|
||||
received_events = []
|
||||
event_received = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(TaskStartedEvent)
|
||||
def handle_task_start(source, event):
|
||||
received_events.append(event)
|
||||
event_received.set()
|
||||
|
||||
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
|
||||
|
||||
crew.kickoff()
|
||||
|
||||
assert event_received.wait(timeout=5), "Timeout waiting for task started event"
|
||||
assert len(received_events) == 1
|
||||
assert isinstance(received_events[0].timestamp, datetime)
|
||||
assert received_events[0].type == "task_started"
|
||||
@@ -225,10 +239,12 @@ def test_crew_emits_end_task_event(
|
||||
base_agent, base_task, reset_event_listener_singleton
|
||||
):
|
||||
received_events = []
|
||||
event_received = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(TaskCompletedEvent)
|
||||
def handle_task_end(source, event):
|
||||
received_events.append(event)
|
||||
event_received.set()
|
||||
|
||||
mock_span = Mock()
|
||||
|
||||
@@ -246,6 +262,7 @@ def test_crew_emits_end_task_event(
|
||||
mock_telemetry.task_started.assert_called_once_with(crew=crew, task=base_task)
|
||||
mock_telemetry.task_ended.assert_called_once_with(mock_span, base_task, crew)
|
||||
|
||||
assert event_received.wait(timeout=5), "Timeout waiting for task completed event"
|
||||
assert len(received_events) == 1
|
||||
assert isinstance(received_events[0].timestamp, datetime)
|
||||
assert received_events[0].type == "task_completed"
|
||||
@@ -255,11 +272,13 @@ def test_crew_emits_end_task_event(
|
||||
def test_task_emits_failed_event_on_execution_error(base_agent, base_task):
|
||||
received_events = []
|
||||
received_sources = []
|
||||
event_received = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(TaskFailedEvent)
|
||||
def handle_task_failed(source, event):
|
||||
received_events.append(event)
|
||||
received_sources.append(source)
|
||||
event_received.set()
|
||||
|
||||
with patch.object(
|
||||
Task,
|
||||
@@ -281,6 +300,9 @@ def test_task_emits_failed_event_on_execution_error(base_agent, base_task):
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
agent.execute_task(task=task)
|
||||
|
||||
assert event_received.wait(timeout=5), (
|
||||
"Timeout waiting for task failed event"
|
||||
)
|
||||
assert len(received_events) == 1
|
||||
assert received_sources[0] == task
|
||||
assert received_events[0].error == error_message
|
||||
@@ -291,17 +313,27 @@ def test_task_emits_failed_event_on_execution_error(base_agent, base_task):
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_agent_emits_execution_started_and_completed_events(base_agent, base_task):
|
||||
received_events = []
|
||||
lock = threading.Lock()
|
||||
all_events_received = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(AgentExecutionStartedEvent)
|
||||
def handle_agent_start(source, event):
|
||||
received_events.append(event)
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
|
||||
@crewai_event_bus.on(AgentExecutionCompletedEvent)
|
||||
def handle_agent_completed(source, event):
|
||||
received_events.append(event)
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
if len(received_events) >= 2:
|
||||
all_events_received.set()
|
||||
|
||||
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
|
||||
crew.kickoff()
|
||||
|
||||
assert all_events_received.wait(timeout=5), (
|
||||
"Timeout waiting for agent execution events"
|
||||
)
|
||||
assert len(received_events) == 2
|
||||
assert received_events[0].agent == base_agent
|
||||
assert received_events[0].task == base_task
|
||||
@@ -320,10 +352,12 @@ def test_agent_emits_execution_started_and_completed_events(base_agent, base_tas
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_agent_emits_execution_error_event(base_agent, base_task):
|
||||
received_events = []
|
||||
event_received = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(AgentExecutionErrorEvent)
|
||||
def handle_agent_start(source, event):
|
||||
received_events.append(event)
|
||||
event_received.set()
|
||||
|
||||
error_message = "Error happening while sending prompt to model."
|
||||
base_agent.max_retry_limit = 0
|
||||
@@ -337,6 +371,9 @@ def test_agent_emits_execution_error_event(base_agent, base_task):
|
||||
task=base_task,
|
||||
)
|
||||
|
||||
assert event_received.wait(timeout=5), (
|
||||
"Timeout waiting for agent execution error event"
|
||||
)
|
||||
assert len(received_events) == 1
|
||||
assert received_events[0].agent == base_agent
|
||||
assert received_events[0].task == base_task
|
||||
@@ -358,10 +395,12 @@ class SayHiTool(BaseTool):
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_tools_emits_finished_events():
|
||||
received_events = []
|
||||
event_received = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(ToolUsageFinishedEvent)
|
||||
def handle_tool_end(source, event):
|
||||
received_events.append(event)
|
||||
event_received.set()
|
||||
|
||||
agent = Agent(
|
||||
role="base_agent",
|
||||
@@ -377,6 +416,10 @@ def test_tools_emits_finished_events():
|
||||
)
|
||||
crew = Crew(agents=[agent], tasks=[task], name="TestCrew")
|
||||
crew.kickoff()
|
||||
|
||||
assert event_received.wait(timeout=5), (
|
||||
"Timeout waiting for tool usage finished event"
|
||||
)
|
||||
assert len(received_events) == 1
|
||||
assert received_events[0].agent_key == agent.key
|
||||
assert received_events[0].agent_role == agent.role
|
||||
@@ -389,10 +432,15 @@ def test_tools_emits_finished_events():
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_tools_emits_error_events():
|
||||
received_events = []
|
||||
lock = threading.Lock()
|
||||
all_events_received = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(ToolUsageErrorEvent)
|
||||
def handle_tool_end(source, event):
|
||||
received_events.append(event)
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
if len(received_events) >= 48:
|
||||
all_events_received.set()
|
||||
|
||||
class ErrorTool(BaseTool):
|
||||
name: str = Field(
|
||||
@@ -423,6 +471,9 @@ def test_tools_emits_error_events():
|
||||
crew = Crew(agents=[agent], tasks=[task], name="TestCrew")
|
||||
crew.kickoff()
|
||||
|
||||
assert all_events_received.wait(timeout=5), (
|
||||
"Timeout waiting for tool usage error events"
|
||||
)
|
||||
assert len(received_events) == 48
|
||||
assert received_events[0].agent_key == agent.key
|
||||
assert received_events[0].agent_role == agent.role
|
||||
@@ -435,11 +486,13 @@ def test_tools_emits_error_events():
|
||||
|
||||
def test_flow_emits_start_event(reset_event_listener_singleton):
|
||||
received_events = []
|
||||
event_received = threading.Event()
|
||||
mock_span = Mock()
|
||||
|
||||
@crewai_event_bus.on(FlowStartedEvent)
|
||||
def handle_flow_start(source, event):
|
||||
received_events.append(event)
|
||||
event_received.set()
|
||||
|
||||
class TestFlow(Flow[dict]):
|
||||
@start()
|
||||
@@ -458,6 +511,7 @@ def test_flow_emits_start_event(reset_event_listener_singleton):
|
||||
flow = TestFlow()
|
||||
flow.kickoff()
|
||||
|
||||
assert event_received.wait(timeout=5), "Timeout waiting for flow started event"
|
||||
mock_telemetry.flow_execution_span.assert_called_once_with("TestFlow", ["begin"])
|
||||
assert len(received_events) == 1
|
||||
assert received_events[0].flow_name == "TestFlow"
|
||||
@@ -466,6 +520,7 @@ def test_flow_emits_start_event(reset_event_listener_singleton):
|
||||
|
||||
def test_flow_name_emitted_to_event_bus():
|
||||
received_events = []
|
||||
event_received = threading.Event()
|
||||
|
||||
class MyFlowClass(Flow):
|
||||
name = "PRODUCTION_FLOW"
|
||||
@@ -477,118 +532,133 @@ def test_flow_name_emitted_to_event_bus():
|
||||
@crewai_event_bus.on(FlowStartedEvent)
|
||||
def handle_flow_start(source, event):
|
||||
received_events.append(event)
|
||||
event_received.set()
|
||||
|
||||
flow = MyFlowClass()
|
||||
flow.kickoff()
|
||||
|
||||
assert event_received.wait(timeout=5), "Timeout waiting for flow started event"
|
||||
assert len(received_events) == 1
|
||||
assert received_events[0].flow_name == "PRODUCTION_FLOW"
|
||||
|
||||
|
||||
def test_flow_emits_finish_event():
|
||||
received_events = []
|
||||
event_received = threading.Event()
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
@crewai_event_bus.on(FlowFinishedEvent)
|
||||
def handle_flow_finish(source, event):
|
||||
received_events.append(event)
|
||||
event_received.set()
|
||||
|
||||
@crewai_event_bus.on(FlowFinishedEvent)
|
||||
def handle_flow_finish(source, event):
|
||||
received_events.append(event)
|
||||
class TestFlow(Flow[dict]):
|
||||
@start()
|
||||
def begin(self):
|
||||
return "completed"
|
||||
|
||||
class TestFlow(Flow[dict]):
|
||||
@start()
|
||||
def begin(self):
|
||||
return "completed"
|
||||
flow = TestFlow()
|
||||
result = flow.kickoff()
|
||||
|
||||
flow = TestFlow()
|
||||
result = flow.kickoff()
|
||||
|
||||
assert len(received_events) == 1
|
||||
assert received_events[0].flow_name == "TestFlow"
|
||||
assert received_events[0].type == "flow_finished"
|
||||
assert received_events[0].result == "completed"
|
||||
assert result == "completed"
|
||||
assert event_received.wait(timeout=5), "Timeout waiting for finish event"
|
||||
assert len(received_events) == 1
|
||||
assert received_events[0].flow_name == "TestFlow"
|
||||
assert received_events[0].type == "flow_finished"
|
||||
assert received_events[0].result == "completed"
|
||||
assert result == "completed"
|
||||
|
||||
|
||||
def test_flow_emits_method_execution_started_event():
|
||||
received_events = []
|
||||
lock = threading.Lock()
|
||||
second_event_received = threading.Event()
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionStartedEvent)
|
||||
def handle_method_start(source, event):
|
||||
@crewai_event_bus.on(MethodExecutionStartedEvent)
|
||||
async def handle_method_start(source, event):
|
||||
with lock:
|
||||
received_events.append(event)
|
||||
if event.method_name == "second_method":
|
||||
second_event_received.set()
|
||||
|
||||
class TestFlow(Flow[dict]):
|
||||
@start()
|
||||
def begin(self):
|
||||
return "started"
|
||||
class TestFlow(Flow[dict]):
|
||||
@start()
|
||||
def begin(self):
|
||||
return "started"
|
||||
|
||||
@listen("begin")
|
||||
def second_method(self):
|
||||
return "executed"
|
||||
@listen("begin")
|
||||
def second_method(self):
|
||||
return "executed"
|
||||
|
||||
flow = TestFlow()
|
||||
flow.kickoff()
|
||||
flow = TestFlow()
|
||||
flow.kickoff()
|
||||
|
||||
assert len(received_events) == 2
|
||||
assert second_event_received.wait(timeout=5), (
|
||||
"Timeout waiting for second_method event"
|
||||
)
|
||||
assert len(received_events) == 2
|
||||
|
||||
assert received_events[0].method_name == "begin"
|
||||
assert received_events[0].flow_name == "TestFlow"
|
||||
assert received_events[0].type == "method_execution_started"
|
||||
# Events may arrive in any order due to async handlers, so check both are present
|
||||
method_names = {event.method_name for event in received_events}
|
||||
assert method_names == {"begin", "second_method"}
|
||||
|
||||
assert received_events[1].method_name == "second_method"
|
||||
assert received_events[1].flow_name == "TestFlow"
|
||||
assert received_events[1].type == "method_execution_started"
|
||||
for event in received_events:
|
||||
assert event.flow_name == "TestFlow"
|
||||
assert event.type == "method_execution_started"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_register_handler_adds_new_handler(base_agent, base_task):
|
||||
received_events = []
|
||||
event_received = threading.Event()
|
||||
|
||||
def custom_handler(source, event):
|
||||
received_events.append(event)
|
||||
event_received.set()
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
crewai_event_bus.register_handler(CrewKickoffStartedEvent, custom_handler)
|
||||
crewai_event_bus.register_handler(CrewKickoffStartedEvent, custom_handler)
|
||||
|
||||
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
|
||||
crew.kickoff()
|
||||
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
|
||||
crew.kickoff()
|
||||
|
||||
assert len(received_events) == 1
|
||||
assert isinstance(received_events[0].timestamp, datetime)
|
||||
assert received_events[0].type == "crew_kickoff_started"
|
||||
assert event_received.wait(timeout=5), "Timeout waiting for handler event"
|
||||
assert len(received_events) == 1
|
||||
assert isinstance(received_events[0].timestamp, datetime)
|
||||
assert received_events[0].type == "crew_kickoff_started"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_multiple_handlers_for_same_event(base_agent, base_task):
|
||||
received_events_1 = []
|
||||
received_events_2 = []
|
||||
event_received = threading.Event()
|
||||
|
||||
def handler_1(source, event):
|
||||
received_events_1.append(event)
|
||||
|
||||
def handler_2(source, event):
|
||||
received_events_2.append(event)
|
||||
event_received.set()
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
crewai_event_bus.register_handler(CrewKickoffStartedEvent, handler_1)
|
||||
crewai_event_bus.register_handler(CrewKickoffStartedEvent, handler_2)
|
||||
crewai_event_bus.register_handler(CrewKickoffStartedEvent, handler_1)
|
||||
crewai_event_bus.register_handler(CrewKickoffStartedEvent, handler_2)
|
||||
|
||||
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
|
||||
crew.kickoff()
|
||||
crew = Crew(agents=[base_agent], tasks=[base_task], name="TestCrew")
|
||||
crew.kickoff()
|
||||
|
||||
assert len(received_events_1) == 1
|
||||
assert len(received_events_2) == 1
|
||||
assert received_events_1[0].type == "crew_kickoff_started"
|
||||
assert received_events_2[0].type == "crew_kickoff_started"
|
||||
assert event_received.wait(timeout=5), "Timeout waiting for handler events"
|
||||
assert len(received_events_1) == 1
|
||||
assert len(received_events_2) == 1
|
||||
assert received_events_1[0].type == "crew_kickoff_started"
|
||||
assert received_events_2[0].type == "crew_kickoff_started"
|
||||
|
||||
|
||||
def test_flow_emits_created_event():
|
||||
received_events = []
|
||||
event_received = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(FlowCreatedEvent)
|
||||
def handle_flow_created(source, event):
|
||||
received_events.append(event)
|
||||
event_received.set()
|
||||
|
||||
class TestFlow(Flow[dict]):
|
||||
@start()
|
||||
@@ -598,6 +668,7 @@ def test_flow_emits_created_event():
|
||||
flow = TestFlow()
|
||||
flow.kickoff()
|
||||
|
||||
assert event_received.wait(timeout=5), "Timeout waiting for flow created event"
|
||||
assert len(received_events) == 1
|
||||
assert received_events[0].flow_name == "TestFlow"
|
||||
assert received_events[0].type == "flow_created"
|
||||
@@ -605,11 +676,13 @@ def test_flow_emits_created_event():
|
||||
|
||||
def test_flow_emits_method_execution_failed_event():
|
||||
received_events = []
|
||||
event_received = threading.Event()
|
||||
error = Exception("Simulated method failure")
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionFailedEvent)
|
||||
def handle_method_failed(source, event):
|
||||
received_events.append(event)
|
||||
event_received.set()
|
||||
|
||||
class TestFlow(Flow[dict]):
|
||||
@start()
|
||||
@@ -620,6 +693,9 @@ def test_flow_emits_method_execution_failed_event():
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
flow.kickoff()
|
||||
|
||||
assert event_received.wait(timeout=5), (
|
||||
"Timeout waiting for method execution failed event"
|
||||
)
|
||||
assert len(received_events) == 1
|
||||
assert received_events[0].method_name == "begin"
|
||||
assert received_events[0].flow_name == "TestFlow"
|
||||
@@ -641,6 +717,7 @@ def test_llm_emits_call_started_event():
|
||||
|
||||
llm = LLM(model="gpt-4o-mini")
|
||||
llm.call("Hello, how are you?")
|
||||
wait_for_event_handlers()
|
||||
|
||||
assert len(received_events) == 2
|
||||
assert received_events[0].type == "llm_call_started"
|
||||
@@ -656,10 +733,12 @@ def test_llm_emits_call_started_event():
|
||||
@pytest.mark.isolated
|
||||
def test_llm_emits_call_failed_event():
|
||||
received_events = []
|
||||
event_received = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(LLMCallFailedEvent)
|
||||
def handle_llm_call_failed(source, event):
|
||||
received_events.append(event)
|
||||
event_received.set()
|
||||
|
||||
error_message = "OpenAI API call failed: Simulated API failure"
|
||||
|
||||
@@ -673,6 +752,7 @@ def test_llm_emits_call_failed_event():
|
||||
llm.call("Hello, how are you?")
|
||||
|
||||
assert str(exc_info.value) == "Simulated API failure"
|
||||
assert event_received.wait(timeout=5), "Timeout waiting for failed event"
|
||||
assert len(received_events) == 1
|
||||
assert received_events[0].type == "llm_call_failed"
|
||||
assert received_events[0].error == error_message
|
||||
@@ -686,24 +766,28 @@ def test_llm_emits_call_failed_event():
|
||||
def test_llm_emits_stream_chunk_events():
|
||||
"""Test that LLM emits stream chunk events when streaming is enabled."""
|
||||
received_chunks = []
|
||||
event_received = threading.Event()
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
@crewai_event_bus.on(LLMStreamChunkEvent)
|
||||
def handle_stream_chunk(source, event):
|
||||
received_chunks.append(event.chunk)
|
||||
if len(received_chunks) >= 1:
|
||||
event_received.set()
|
||||
|
||||
@crewai_event_bus.on(LLMStreamChunkEvent)
|
||||
def handle_stream_chunk(source, event):
|
||||
received_chunks.append(event.chunk)
|
||||
# Create an LLM with streaming enabled
|
||||
llm = LLM(model="gpt-4o", stream=True)
|
||||
|
||||
# Create an LLM with streaming enabled
|
||||
llm = LLM(model="gpt-4o", stream=True)
|
||||
# Call the LLM with a simple message
|
||||
response = llm.call("Tell me a short joke")
|
||||
|
||||
# Call the LLM with a simple message
|
||||
response = llm.call("Tell me a short joke")
|
||||
# Wait for at least one chunk
|
||||
assert event_received.wait(timeout=5), "Timeout waiting for stream chunks"
|
||||
|
||||
# Verify that we received chunks
|
||||
assert len(received_chunks) > 0
|
||||
# Verify that we received chunks
|
||||
assert len(received_chunks) > 0
|
||||
|
||||
# Verify that concatenating all chunks equals the final response
|
||||
assert "".join(received_chunks) == response
|
||||
# Verify that concatenating all chunks equals the final response
|
||||
assert "".join(received_chunks) == response
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -711,23 +795,21 @@ def test_llm_no_stream_chunks_when_streaming_disabled():
|
||||
"""Test that LLM doesn't emit stream chunk events when streaming is disabled."""
|
||||
received_chunks = []
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
@crewai_event_bus.on(LLMStreamChunkEvent)
|
||||
def handle_stream_chunk(source, event):
|
||||
received_chunks.append(event.chunk)
|
||||
|
||||
@crewai_event_bus.on(LLMStreamChunkEvent)
|
||||
def handle_stream_chunk(source, event):
|
||||
received_chunks.append(event.chunk)
|
||||
# Create an LLM with streaming disabled
|
||||
llm = LLM(model="gpt-4o", stream=False)
|
||||
|
||||
# Create an LLM with streaming disabled
|
||||
llm = LLM(model="gpt-4o", stream=False)
|
||||
# Call the LLM with a simple message
|
||||
response = llm.call("Tell me a short joke")
|
||||
|
||||
# Call the LLM with a simple message
|
||||
response = llm.call("Tell me a short joke")
|
||||
# Verify that we didn't receive any chunks
|
||||
assert len(received_chunks) == 0
|
||||
|
||||
# Verify that we didn't receive any chunks
|
||||
assert len(received_chunks) == 0
|
||||
|
||||
# Verify we got a response
|
||||
assert response and isinstance(response, str)
|
||||
# Verify we got a response
|
||||
assert response and isinstance(response, str)
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -735,98 +817,105 @@ def test_streaming_fallback_to_non_streaming():
|
||||
"""Test that streaming falls back to non-streaming when there's an error."""
|
||||
received_chunks = []
|
||||
fallback_called = False
|
||||
event_received = threading.Event()
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
@crewai_event_bus.on(LLMStreamChunkEvent)
|
||||
def handle_stream_chunk(source, event):
|
||||
received_chunks.append(event.chunk)
|
||||
if len(received_chunks) >= 2:
|
||||
event_received.set()
|
||||
|
||||
@crewai_event_bus.on(LLMStreamChunkEvent)
|
||||
def handle_stream_chunk(source, event):
|
||||
received_chunks.append(event.chunk)
|
||||
# Create an LLM with streaming enabled
|
||||
llm = LLM(model="gpt-4o", stream=True)
|
||||
|
||||
# Create an LLM with streaming enabled
|
||||
llm = LLM(model="gpt-4o", stream=True)
|
||||
# Store original methods
|
||||
original_call = llm.call
|
||||
|
||||
# Store original methods
|
||||
original_call = llm.call
|
||||
# Create a mock call method that handles the streaming error
|
||||
def mock_call(messages, tools=None, callbacks=None, available_functions=None):
|
||||
nonlocal fallback_called
|
||||
# Emit a couple of chunks to simulate partial streaming
|
||||
crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Test chunk 1"))
|
||||
crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Test chunk 2"))
|
||||
|
||||
# Create a mock call method that handles the streaming error
|
||||
def mock_call(messages, tools=None, callbacks=None, available_functions=None):
|
||||
nonlocal fallback_called
|
||||
# Emit a couple of chunks to simulate partial streaming
|
||||
crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Test chunk 1"))
|
||||
crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Test chunk 2"))
|
||||
# Mark that fallback would be called
|
||||
fallback_called = True
|
||||
|
||||
# Mark that fallback would be called
|
||||
fallback_called = True
|
||||
# Return a response as if fallback succeeded
|
||||
return "Fallback response after streaming error"
|
||||
|
||||
# Return a response as if fallback succeeded
|
||||
return "Fallback response after streaming error"
|
||||
# Replace the call method with our mock
|
||||
llm.call = mock_call
|
||||
|
||||
# Replace the call method with our mock
|
||||
llm.call = mock_call
|
||||
try:
|
||||
# Call the LLM
|
||||
response = llm.call("Tell me a short joke")
|
||||
wait_for_event_handlers()
|
||||
|
||||
try:
|
||||
# Call the LLM
|
||||
response = llm.call("Tell me a short joke")
|
||||
assert event_received.wait(timeout=5), "Timeout waiting for stream chunks"
|
||||
|
||||
# Verify that we received some chunks
|
||||
assert len(received_chunks) == 2
|
||||
assert received_chunks[0] == "Test chunk 1"
|
||||
assert received_chunks[1] == "Test chunk 2"
|
||||
# Verify that we received some chunks
|
||||
assert len(received_chunks) == 2
|
||||
assert received_chunks[0] == "Test chunk 1"
|
||||
assert received_chunks[1] == "Test chunk 2"
|
||||
|
||||
# Verify fallback was triggered
|
||||
assert fallback_called
|
||||
# Verify fallback was triggered
|
||||
assert fallback_called
|
||||
|
||||
# Verify we got the fallback response
|
||||
assert response == "Fallback response after streaming error"
|
||||
# Verify we got the fallback response
|
||||
assert response == "Fallback response after streaming error"
|
||||
|
||||
finally:
|
||||
# Restore the original method
|
||||
llm.call = original_call
|
||||
finally:
|
||||
# Restore the original method
|
||||
llm.call = original_call
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_streaming_empty_response_handling():
|
||||
"""Test that streaming handles empty responses correctly."""
|
||||
received_chunks = []
|
||||
event_received = threading.Event()
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
@crewai_event_bus.on(LLMStreamChunkEvent)
|
||||
def handle_stream_chunk(source, event):
|
||||
received_chunks.append(event.chunk)
|
||||
if len(received_chunks) >= 3:
|
||||
event_received.set()
|
||||
|
||||
@crewai_event_bus.on(LLMStreamChunkEvent)
|
||||
def handle_stream_chunk(source, event):
|
||||
received_chunks.append(event.chunk)
|
||||
# Create an LLM with streaming enabled
|
||||
llm = LLM(model="gpt-3.5-turbo", stream=True)
|
||||
|
||||
# Create an LLM with streaming enabled
|
||||
llm = LLM(model="gpt-3.5-turbo", stream=True)
|
||||
# Store original methods
|
||||
original_call = llm.call
|
||||
|
||||
# Store original methods
|
||||
original_call = llm.call
|
||||
# Create a mock call method that simulates empty chunks
|
||||
def mock_call(messages, tools=None, callbacks=None, available_functions=None):
|
||||
# Emit a few empty chunks
|
||||
for _ in range(3):
|
||||
crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk=""))
|
||||
|
||||
# Create a mock call method that simulates empty chunks
|
||||
def mock_call(messages, tools=None, callbacks=None, available_functions=None):
|
||||
# Emit a few empty chunks
|
||||
for _ in range(3):
|
||||
crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk=""))
|
||||
# Return the default message for empty responses
|
||||
return "I apologize, but I couldn't generate a proper response. Please try again or rephrase your request."
|
||||
|
||||
# Return the default message for empty responses
|
||||
return "I apologize, but I couldn't generate a proper response. Please try again or rephrase your request."
|
||||
# Replace the call method with our mock
|
||||
llm.call = mock_call
|
||||
|
||||
# Replace the call method with our mock
|
||||
llm.call = mock_call
|
||||
try:
|
||||
# Call the LLM - this should handle empty response
|
||||
response = llm.call("Tell me a short joke")
|
||||
|
||||
try:
|
||||
# Call the LLM - this should handle empty response
|
||||
response = llm.call("Tell me a short joke")
|
||||
assert event_received.wait(timeout=5), "Timeout waiting for empty chunks"
|
||||
|
||||
# Verify that we received empty chunks
|
||||
assert len(received_chunks) == 3
|
||||
assert all(chunk == "" for chunk in received_chunks)
|
||||
# Verify that we received empty chunks
|
||||
assert len(received_chunks) == 3
|
||||
assert all(chunk == "" for chunk in received_chunks)
|
||||
|
||||
# Verify the response is the default message for empty responses
|
||||
assert "I apologize" in response and "couldn't generate" in response
|
||||
# Verify the response is the default message for empty responses
|
||||
assert "I apologize" in response and "couldn't generate" in response
|
||||
|
||||
finally:
|
||||
# Restore the original method
|
||||
llm.call = original_call
|
||||
finally:
|
||||
# Restore the original method
|
||||
llm.call = original_call
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@@ -835,41 +924,49 @@ def test_stream_llm_emits_event_with_task_and_agent_info():
|
||||
failed_event = []
|
||||
started_event = []
|
||||
stream_event = []
|
||||
event_received = threading.Event()
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
@crewai_event_bus.on(LLMCallFailedEvent)
|
||||
def handle_llm_failed(source, event):
|
||||
failed_event.append(event)
|
||||
|
||||
@crewai_event_bus.on(LLMCallFailedEvent)
|
||||
def handle_llm_failed(source, event):
|
||||
failed_event.append(event)
|
||||
@crewai_event_bus.on(LLMCallStartedEvent)
|
||||
def handle_llm_started(source, event):
|
||||
started_event.append(event)
|
||||
|
||||
@crewai_event_bus.on(LLMCallStartedEvent)
|
||||
def handle_llm_started(source, event):
|
||||
started_event.append(event)
|
||||
@crewai_event_bus.on(LLMCallCompletedEvent)
|
||||
def handle_llm_completed(source, event):
|
||||
completed_event.append(event)
|
||||
if len(started_event) >= 1 and len(stream_event) >= 12:
|
||||
event_received.set()
|
||||
|
||||
@crewai_event_bus.on(LLMCallCompletedEvent)
|
||||
def handle_llm_completed(source, event):
|
||||
completed_event.append(event)
|
||||
@crewai_event_bus.on(LLMStreamChunkEvent)
|
||||
def handle_llm_stream_chunk(source, event):
|
||||
stream_event.append(event)
|
||||
if (
|
||||
len(completed_event) >= 1
|
||||
and len(started_event) >= 1
|
||||
and len(stream_event) >= 12
|
||||
):
|
||||
event_received.set()
|
||||
|
||||
@crewai_event_bus.on(LLMStreamChunkEvent)
|
||||
def handle_llm_stream_chunk(source, event):
|
||||
stream_event.append(event)
|
||||
agent = Agent(
|
||||
role="TestAgent",
|
||||
llm=LLM(model="gpt-4o-mini", stream=True),
|
||||
goal="Just say hi",
|
||||
backstory="You are a helpful assistant that just says hi",
|
||||
)
|
||||
task = Task(
|
||||
description="Just say hi",
|
||||
expected_output="hi",
|
||||
llm=LLM(model="gpt-4o-mini", stream=True),
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
agent = Agent(
|
||||
role="TestAgent",
|
||||
llm=LLM(model="gpt-4o-mini", stream=True),
|
||||
goal="Just say hi",
|
||||
backstory="You are a helpful assistant that just says hi",
|
||||
)
|
||||
task = Task(
|
||||
description="Just say hi",
|
||||
expected_output="hi",
|
||||
llm=LLM(model="gpt-4o-mini", stream=True),
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
crew.kickoff()
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
crew.kickoff()
|
||||
|
||||
assert event_received.wait(timeout=10), "Timeout waiting for LLM events"
|
||||
assert len(completed_event) == 1
|
||||
assert len(failed_event) == 0
|
||||
assert len(started_event) == 1
|
||||
@@ -899,28 +996,30 @@ def test_llm_emits_event_with_task_and_agent_info(base_agent, base_task):
|
||||
failed_event = []
|
||||
started_event = []
|
||||
stream_event = []
|
||||
event_received = threading.Event()
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
@crewai_event_bus.on(LLMCallFailedEvent)
|
||||
def handle_llm_failed(source, event):
|
||||
failed_event.append(event)
|
||||
|
||||
@crewai_event_bus.on(LLMCallFailedEvent)
|
||||
def handle_llm_failed(source, event):
|
||||
failed_event.append(event)
|
||||
@crewai_event_bus.on(LLMCallStartedEvent)
|
||||
def handle_llm_started(source, event):
|
||||
started_event.append(event)
|
||||
|
||||
@crewai_event_bus.on(LLMCallStartedEvent)
|
||||
def handle_llm_started(source, event):
|
||||
started_event.append(event)
|
||||
@crewai_event_bus.on(LLMCallCompletedEvent)
|
||||
def handle_llm_completed(source, event):
|
||||
completed_event.append(event)
|
||||
if len(started_event) >= 1:
|
||||
event_received.set()
|
||||
|
||||
@crewai_event_bus.on(LLMCallCompletedEvent)
|
||||
def handle_llm_completed(source, event):
|
||||
completed_event.append(event)
|
||||
@crewai_event_bus.on(LLMStreamChunkEvent)
|
||||
def handle_llm_stream_chunk(source, event):
|
||||
stream_event.append(event)
|
||||
|
||||
@crewai_event_bus.on(LLMStreamChunkEvent)
|
||||
def handle_llm_stream_chunk(source, event):
|
||||
stream_event.append(event)
|
||||
|
||||
crew = Crew(agents=[base_agent], tasks=[base_task])
|
||||
crew.kickoff()
|
||||
crew = Crew(agents=[base_agent], tasks=[base_task])
|
||||
crew.kickoff()
|
||||
|
||||
assert event_received.wait(timeout=10), "Timeout waiting for LLM events"
|
||||
assert len(completed_event) == 1
|
||||
assert len(failed_event) == 0
|
||||
assert len(started_event) == 1
|
||||
@@ -950,32 +1049,41 @@ def test_llm_emits_event_with_lite_agent():
|
||||
failed_event = []
|
||||
started_event = []
|
||||
stream_event = []
|
||||
all_events_received = threading.Event()
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
@crewai_event_bus.on(LLMCallFailedEvent)
|
||||
def handle_llm_failed(source, event):
|
||||
failed_event.append(event)
|
||||
|
||||
@crewai_event_bus.on(LLMCallFailedEvent)
|
||||
def handle_llm_failed(source, event):
|
||||
failed_event.append(event)
|
||||
@crewai_event_bus.on(LLMCallStartedEvent)
|
||||
def handle_llm_started(source, event):
|
||||
started_event.append(event)
|
||||
|
||||
@crewai_event_bus.on(LLMCallStartedEvent)
|
||||
def handle_llm_started(source, event):
|
||||
started_event.append(event)
|
||||
@crewai_event_bus.on(LLMCallCompletedEvent)
|
||||
def handle_llm_completed(source, event):
|
||||
completed_event.append(event)
|
||||
if len(started_event) >= 1 and len(stream_event) >= 15:
|
||||
all_events_received.set()
|
||||
|
||||
@crewai_event_bus.on(LLMCallCompletedEvent)
|
||||
def handle_llm_completed(source, event):
|
||||
completed_event.append(event)
|
||||
@crewai_event_bus.on(LLMStreamChunkEvent)
|
||||
def handle_llm_stream_chunk(source, event):
|
||||
stream_event.append(event)
|
||||
if (
|
||||
len(completed_event) >= 1
|
||||
and len(started_event) >= 1
|
||||
and len(stream_event) >= 15
|
||||
):
|
||||
all_events_received.set()
|
||||
|
||||
@crewai_event_bus.on(LLMStreamChunkEvent)
|
||||
def handle_llm_stream_chunk(source, event):
|
||||
stream_event.append(event)
|
||||
agent = Agent(
|
||||
role="Speaker",
|
||||
llm=LLM(model="gpt-4o-mini", stream=True),
|
||||
goal="Just say hi",
|
||||
backstory="You are a helpful assistant that just says hi",
|
||||
)
|
||||
agent.kickoff(messages=[{"role": "user", "content": "say hi!"}])
|
||||
|
||||
agent = Agent(
|
||||
role="Speaker",
|
||||
llm=LLM(model="gpt-4o-mini", stream=True),
|
||||
goal="Just say hi",
|
||||
backstory="You are a helpful assistant that just says hi",
|
||||
)
|
||||
agent.kickoff(messages=[{"role": "user", "content": "say hi!"}])
|
||||
assert all_events_received.wait(timeout=10), "Timeout waiting for all events"
|
||||
|
||||
assert len(completed_event) == 1
|
||||
assert len(failed_event) == 0
|
||||
|
||||
39
lib/crewai/tests/utils.py
Normal file
39
lib/crewai/tests/utils.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""Test utilities for CrewAI tests."""
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
|
||||
def wait_for_event_handlers(timeout: float = 5.0) -> None:
|
||||
"""Wait for all pending event handlers to complete.
|
||||
|
||||
This helper ensures all sync and async handlers finish processing before
|
||||
proceeding. Useful in tests to make assertions deterministic.
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait in seconds.
|
||||
"""
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
|
||||
loop = getattr(crewai_event_bus, "_loop", None)
|
||||
|
||||
if loop and not loop.is_closed():
|
||||
|
||||
async def _wait_for_async_tasks() -> None:
|
||||
tasks = {
|
||||
t for t in asyncio.all_tasks(loop) if t is not asyncio.current_task()
|
||||
}
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
future = asyncio.run_coroutine_threadsafe(_wait_for_async_tasks(), loop)
|
||||
try:
|
||||
future.result(timeout=timeout)
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
crewai_event_bus._sync_executor.shutdown(wait=True)
|
||||
crewai_event_bus._sync_executor = ThreadPoolExecutor(
|
||||
max_workers=10,
|
||||
thread_name_prefix="CrewAISyncHandler",
|
||||
)
|
||||
@@ -1,3 +1,3 @@
|
||||
"""CrewAI development tools."""
|
||||
|
||||
__version__ = "1.0.0a4"
|
||||
__version__ = "1.0.0b2"
|
||||
|
||||
19
uv.lock
generated
19
uv.lock
generated
@@ -465,15 +465,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/01/f3/a9d961cfba236dc85f27f2f2c6eab88e12698754aaa02459ba7dfafc5062/bedrock_agentcore-0.1.7-py3-none-any.whl", hash = "sha256:441dde64fea596e9571e47ae37ee3b033e58d8d255018f13bdcde8ae8bef2075", size = 77216, upload-time = "2025-10-01T16:18:38.153Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "blinker"
|
||||
version = "1.9.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/21/28/9b3f50ce0e048515135495f198351908d99540d69bfdc8c1d15b73dc55ce/blinker-1.9.0.tar.gz", hash = "sha256:b4ce2265a7abece45e7cc896e98dbebe6cead56bcf805a3d23136d145f5445bf", size = 22460, upload-time = "2024-11-08T17:25:47.436Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/10/cb/f2ad4230dc2eb1a74edf38f1a38b9b52277f75bef262d8908e60d957e13c/blinker-1.9.0-py3-none-any.whl", hash = "sha256:ba0efaa9080b619ff2f3459d1d500c57bddea4a6b424b60a91141db6fd2f08bc", size = 8458, upload-time = "2024-11-08T17:25:46.184Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "boto3"
|
||||
version = "1.40.45"
|
||||
@@ -987,7 +978,6 @@ name = "crewai"
|
||||
source = { editable = "lib/crewai" }
|
||||
dependencies = [
|
||||
{ name = "appdirs" },
|
||||
{ name = "blinker" },
|
||||
{ name = "chromadb" },
|
||||
{ name = "click" },
|
||||
{ name = "instructor" },
|
||||
@@ -1020,6 +1010,9 @@ aisuite = [
|
||||
aws = [
|
||||
{ name = "boto3" },
|
||||
]
|
||||
boto3 = [
|
||||
{ name = "boto3" },
|
||||
]
|
||||
docling = [
|
||||
{ name = "docling" },
|
||||
]
|
||||
@@ -1058,8 +1051,8 @@ watson = [
|
||||
requires-dist = [
|
||||
{ name = "aisuite", marker = "extra == 'aisuite'", specifier = ">=0.1.10" },
|
||||
{ name = "appdirs", specifier = ">=1.4.4" },
|
||||
{ name = "blinker", specifier = ">=1.9.0" },
|
||||
{ name = "boto3", marker = "extra == 'aws'", specifier = ">=1.40.38" },
|
||||
{ name = "boto3", marker = "extra == 'boto3'", specifier = ">=1.40.45" },
|
||||
{ name = "chromadb", specifier = "~=1.1.0" },
|
||||
{ name = "click", specifier = ">=8.1.7" },
|
||||
{ name = "crewai-tools", marker = "extra == 'tools'", editable = "lib/crewai-tools" },
|
||||
@@ -1095,7 +1088,7 @@ requires-dist = [
|
||||
{ name = "uv", specifier = ">=0.4.25" },
|
||||
{ name = "voyageai", marker = "extra == 'voyageai'", specifier = ">=0.3.5" },
|
||||
]
|
||||
provides-extras = ["aisuite", "aws", "docling", "embeddings", "litellm", "mem0", "openpyxl", "pandas", "pdfplumber", "qdrant", "tools", "voyageai", "watson"]
|
||||
provides-extras = ["aisuite", "aws", "boto3", "docling", "embeddings", "litellm", "mem0", "openpyxl", "pandas", "pdfplumber", "qdrant", "tools", "voyageai", "watson"]
|
||||
|
||||
[[package]]
|
||||
name = "crewai-devtools"
|
||||
@@ -1131,7 +1124,6 @@ dependencies = [
|
||||
{ name = "python-docx" },
|
||||
{ name = "pytube" },
|
||||
{ name = "requests" },
|
||||
{ name = "stagehand" },
|
||||
{ name = "tiktoken" },
|
||||
{ name = "youtube-transcript-api" },
|
||||
]
|
||||
@@ -1302,7 +1294,6 @@ requires-dist = [
|
||||
{ name = "spider-client", marker = "extra == 'spider-client'", specifier = ">=0.1.25" },
|
||||
{ name = "sqlalchemy", marker = "extra == 'singlestore'", specifier = ">=2.0.40" },
|
||||
{ name = "sqlalchemy", marker = "extra == 'sqlalchemy'", specifier = ">=2.0.35" },
|
||||
{ name = "stagehand", specifier = ">=0.4.1" },
|
||||
{ name = "stagehand", marker = "extra == 'stagehand'", specifier = ">=0.4.1" },
|
||||
{ name = "tavily-python", marker = "extra == 'tavily-python'", specifier = ">=0.5.4" },
|
||||
{ name = "tiktoken", specifier = ">=0.8.0" },
|
||||
|
||||
Reference in New Issue
Block a user