mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-04 22:49:23 +00:00
Compare commits
19 Commits
0.193.1
...
lg-trigger
| SHA1 |
|---|
| ae8e52b484 |
| e070c1400c |
| 6537e3737d |
| 346faf229f |
| a0b757a12c |
| 1dbe8aab52 |
| 4ac65eb0a6 |
| 3e97393f58 |
| 34bed359a6 |
| feeed505bb |
| cb0efd05b4 |
| db5f565dea |
| 58413b663a |
| 37636f0dd7 |
| 0e370593f1 |
| aa8dc9d77f |
| 9c1096dbdc |
| 47044450c0 |
| 0ee438c39d |
135
TRIGGER_IMPLEMENTATION_SUMMARY.md
Normal file
@@ -0,0 +1,135 @@
# CrewAI CLI Trigger Feature Implementation

## Overview
This change adds trigger support to the CrewAI CLI through two commands:
- `crewai trigger list` - Lists all triggers grouped by provider
- `crewai trigger <app/trigger_name>` - Runs the project's crew or flow with the sample payload of the specified trigger

## Implementation Details

### 1. Extended PlusAPI Client (`src/crewai/cli/plus_api.py`)
- Added `TRIGGERS_RESOURCE = "/v1/triggers"` endpoint constant
- Implemented `list_triggers()` method for GET `/v1/triggers`
- Implemented `get_trigger_sample_payload(trigger_identification)` method for POST `/v1/triggers/sample_payload`

||||
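The two client calls listed above can be illustrated without touching the network. This is a minimal sketch that only builds the (method, path, JSON body) triples matching the endpoints named in the summary; the helper names `list_triggers_request` and `sample_payload_request` are hypothetical and are not part of `PlusAPI`.

```python
from typing import Any, Optional

# Endpoint constant as stated in the summary.
TRIGGERS_RESOURCE = "/v1/triggers"

# (HTTP method, path, optional JSON body)
Request = tuple[str, str, Optional[dict[str, Any]]]


def list_triggers_request() -> Request:
    """Describe the GET request that lists triggers."""
    return ("GET", TRIGGERS_RESOURCE, None)


def sample_payload_request(trigger_identification: str) -> Request:
    """Describe the POST request that fetches a sample payload."""
    body = {"trigger_identification": trigger_identification}
    return ("POST", f"{TRIGGERS_RESOURCE}/sample_payload", body)
```

Keeping the request description separate from the transport makes the endpoint contract easy to check in a unit test before the backend exists.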
### 2. Created TriggerCommand Class (`src/crewai/cli/trigger_command.py`)
- Inherits from `BaseCommand` and `PlusAPIMixin` for authentication
- Implements `list_triggers()` method with:
  - Rich table display grouped by provider
  - Error handling for network failures and authentication or authorization errors
  - User-friendly messages and styling
- Implements `run_trigger(trigger_identification)` method with:
  - Trigger identification format validation (`app/trigger_name`)
  - Sample payload retrieval from API
  - Dynamic crew/flow execution with trigger payload injection
  - Temporary script generation and cleanup
  - Error handling and validation

||||
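The `app/trigger_name` format check mentioned above could look roughly like the following. This is a sketch under assumptions: the summary does not say which characters are allowed, so the pattern (letters, digits, `_` and `-` on each side of a single slash) and both function names are hypothetical.

```python
import re

# Hypothetical pattern: one provider segment, a slash, one trigger-name segment.
_TRIGGER_ID = re.compile(r"[A-Za-z0-9_-]+/[A-Za-z0-9_-]+")


def is_valid_trigger_identification(value: str) -> bool:
    """Return True when value looks like 'app/trigger_name'."""
    return _TRIGGER_ID.fullmatch(value) is not None


def split_trigger_identification(value: str) -> tuple[str, str]:
    """Split 'app/trigger_name' into (app, trigger_name) or raise ValueError."""
    if not is_valid_trigger_identification(value):
        raise ValueError(f"expected 'app/trigger_name', got {value!r}")
    app, trigger_name = value.split("/", 1)
    return app, trigger_name
```

Validating before calling the API avoids a round trip for input such as a bare provider name.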
### 3. Integrated CLI Commands (`src/crewai/cli/cli.py`)
- Added import for `TriggerCommand`
- Registered the `trigger` command with the `@crewai.command()` decorator
- Supports both `crewai trigger list` and `crewai trigger <app/trigger_name>` syntax
- Routes the single `action_or_trigger` argument to `list_triggers()` when it is `list`, otherwise to `run_trigger()`

### 4. Key Features

#### Trigger Listing
- Fetches triggers from `/v1/triggers` endpoint
- Displays triggers in a formatted table grouped by provider
- Shows trigger ID and description for each trigger
- Provides usage instructions

#### Trigger Execution
- Validates trigger identification format
- Fetches sample payload from `/v1/triggers/sample_payload` endpoint
- Detects project type (crew vs flow) from `pyproject.toml`
- Generates appropriate execution script with trigger payload injection
- Executes crew/flow with `uv run python` command
- Adds trigger payload to inputs as `crewai_trigger_payload`
- Handles cleanup of temporary files

||||
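Two of the execution steps above, project type detection and payload injection, can be sketched on plain dictionaries. The assumptions here are loud: the summary does not say where in `pyproject.toml` the type lives, so reading a `type` key under `[tool.crewai]` and falling back to `crew` is a guess, and both function names are hypothetical. Only the `crewai_trigger_payload` key comes from the summary.

```python
from typing import Any, Optional


def detect_project_type(pyproject: dict[str, Any]) -> str:
    """Return 'flow' or 'crew' from already-parsed pyproject.toml data.

    Assumption: the type is stored under [tool.crewai] as `type`;
    anything else falls back to 'crew'.
    """
    declared = pyproject.get("tool", {}).get("crewai", {}).get("type")
    return "flow" if declared == "flow" else "crew"


def inject_trigger_payload(
    sample_payload: Any, inputs: Optional[dict[str, Any]] = None
) -> dict[str, Any]:
    """Return a copy of inputs with the payload under 'crewai_trigger_payload'."""
    merged = dict(inputs or {})
    merged["crewai_trigger_payload"] = sample_payload
    return merged
```

Returning a copy rather than mutating the caller's dictionary keeps the original inputs reusable across runs.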
#### Error Handling
- Network connectivity issues
- Authentication failures (401)
- Authorization issues (403)
- Trigger not found (404)
- Invalid project structure
- Subprocess execution errors
- Comprehensive user feedback with actionable suggestions

||||
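One way to turn the HTTP cases above into user feedback is a small lookup table. The status codes are the ones listed in the summary; the message wording and the `describe_status` name are illustrative assumptions, not the text the CLI actually prints.

```python
# Status codes come from the list above; the wording is illustrative only.
_MESSAGES = {
    401: "Authentication failed. Try running `crewai login` again.",
    403: "You are not authorized to access this trigger.",
    404: "Trigger not found. Run `crewai trigger list` to see available triggers.",
}


def describe_status(status_code: int) -> str:
    """Map an HTTP status code to a short, actionable message."""
    if 200 <= status_code < 300:
        return "ok"
    return _MESSAGES.get(status_code, f"Unexpected response (HTTP {status_code}).")
```

A table keeps each message next to its code, so adding a new case does not require another branch.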
### 5. Usage Examples

```bash
# List all available triggers
crewai trigger list

# Run a specific trigger
crewai trigger github/pull_request_opened
crewai trigger slack/message_received
crewai trigger webhook/user_signup
```

### 6. API Integration Points

#### CrewAI Client → Rails App
- GET `/v1/triggers` - Returns triggers grouped by provider
- POST `/v1/triggers/sample_payload` with `{"trigger_identification": "app/trigger_name"}`

#### Expected Response Format
```json
{
  "github": {
    "github/pull_request_opened": {
      "description": "Triggered when a pull request is opened"
    },
    "github/issue_created": {
      "description": "Triggered when an issue is created"
    }
  },
  "slack": {
    "slack/message_received": {
      "description": "Triggered when a message is received"
    }
  }
}
```

||||
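To render the grouped response above as a table, it first has to be flattened into rows. The sketch below assumes the response has exactly the shape shown (provider, then trigger ID, then an object with `description`); the `flatten_triggers` name and the alphabetical ordering are choices made for the example.

```python
from typing import Any


def flatten_triggers(response: dict[str, dict[str, Any]]) -> list[tuple[str, str, str]]:
    """Turn {provider: {trigger_id: {"description": ...}}} into sorted rows."""
    rows = []
    for provider, triggers in sorted(response.items()):
        for trigger_id, details in sorted(triggers.items()):
            # A missing description becomes an empty cell instead of an error.
            rows.append((provider, trigger_id, details.get("description", "")))
    return rows
```

Each row can then be passed to a table renderer, with the provider column used for grouping.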
### 7. Crew/Flow Integration
The trigger payload is automatically injected into the crew/flow inputs as `crewai_trigger_payload`, allowing crews to access trigger data:

```python
# In crew/flow code
from crewai import Crew


def my_crew(inputs):
    # `inputs` is supplied by the caller and will contain 'crewai_trigger_payload'
    crew = Crew(...)
    result = crew.kickoff(inputs=inputs)
    return result
```

### 8. Dependencies
- `click` - CLI framework
- `rich` - Enhanced terminal output
- `requests` - HTTP client
- Existing CrewAI CLI infrastructure (authentication, configuration, etc.)

## Testing
- All imports work correctly
- CLI command structure is properly implemented
- Error handling is comprehensive
- Code follows CrewAI patterns and conventions

## Next Steps for Backend Implementation

### Rails App Requirements
1. Add `GET /v1/triggers` endpoint
2. Add `POST /v1/triggers/sample_payload` endpoint
3. Implement integration service method `summarize_triggers`
4. Each provider service must implement:
   - `list_triggers()` method
   - `get_sample_payload(trigger_identification)` method

### CrewAI OAuth Requirements
1. Implement endpoint that returns sample payload for trigger identification
2. Ensure trigger data format matches expected structure

The CLI implementation is complete and ready for integration with the backend services.

@@ -5,6 +5,82 @@ icon: "clock"
|
||||
mode: "wide"
|
||||
---
|
||||
|
||||
<Update label="Sep 20, 2025">
|
||||
## v0.193.2
|
||||
|
||||
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.193.2)
|
||||
|
||||
## What's Changed
|
||||
|
||||
- Updated pyproject templates to use the right version
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="Sep 20, 2025">
|
||||
## v0.193.1
|
||||
|
||||
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.193.1)
|
||||
|
||||
## What's Changed
|
||||
|
||||
- Series of minor fixes and linter improvements
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="Sep 19, 2025">
|
||||
## v0.193.0
|
||||
|
||||
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.193.0)
|
||||
|
||||
## Core Improvements & Fixes
|
||||
|
||||
- Fixed handling of the `model` parameter during OpenAI adapter initialization
|
||||
- Resolved test duration cache issues in CI workflows
|
||||
- Fixed flaky test related to repeated tool usage by agents
|
||||
- Added missing event exports to `__init__.py` for consistent module behavior
|
||||
- Dropped message storage from metadata in Mem0 to reduce bloat
|
||||
- Fixed L2 distance metric support for backward compatibility in vector search
|
||||
|
||||
## New Features & Enhancements
|
||||
|
||||
- Introduced thread-safe platform context management
|
||||
- Added test duration caching for optimized `pytest-split` runs
|
||||
- Added ephemeral trace improvements for better trace control
|
||||
- Made search parameters for RAG, knowledge, and memory fully configurable
|
||||
- Enabled ChromaDB to use OpenAI API for embedding functions
|
||||
- Added deeper observability tools for user-level insights
|
||||
- Unified RAG storage system with instance-specific client support
|
||||
|
||||
## Documentation & Guides
|
||||
|
||||
- Updated `RagTool` references to reflect CrewAI native RAG implementation
|
||||
- Improved internal docs for `langgraph` and `openai` agent adapters with type annotations and docstrings
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="Sep 11, 2025">
|
||||
## v0.186.1
|
||||
|
||||
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.186.1)
|
||||
|
||||
## What's Changed
|
||||
|
||||
- Fixed a silent reversion failure that occurred when the version was not found
|
||||
- Bumped CrewAI version to 0.186.1 and updated dependencies in the CLI
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="Sep 10, 2025">
|
||||
## v0.186.0
|
||||
|
||||
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.186.0)
|
||||
|
||||
## What's Changed
|
||||
|
||||
- Refer to the GitHub release notes for detailed changes
|
||||
|
||||
</Update>
|
||||
|
||||
<Update label="Sep 04, 2025">
|
||||
## v0.177.0
|
||||
|
||||
|
||||
@@ -404,6 +404,10 @@ crewai config reset
|
||||
After resetting configuration, re-run `crewai login` to authenticate again.
|
||||
</Tip>
|
||||
|
||||
<Tip>
|
||||
CrewAI CLI handles authentication to the Tool Repository automatically when adding packages to your project. Just prefix any `uv` command with `crewai`, e.g. `crewai uv add requests`. For more information, see the [Tool Repository](https://docs.crewai.com/enterprise/features/tool-repository) docs.
|
||||
</Tip>
|
||||
|
||||
<Note>
|
||||
Configuration settings are stored in `~/.config/crewai/settings.json`. Some settings like organization name and UUID are read-only and managed through authentication and organization commands. Tool repository related settings are hidden and cannot be set directly by users.
|
||||
</Note>
|
||||
|
||||
@@ -52,6 +52,36 @@ researcher = Agent(
|
||||
)
|
||||
```
|
||||
|
||||
## Adding other packages after installing a tool

After installing a tool from the CrewAI Enterprise Tool Repository, use the `crewai uv` command to add other packages to your project.
Plain `uv` commands will fail because authentication to the Tool Repository is handled by the CLI. The `crewai uv` wrapper supplies the credentials for you.
Any `uv` command can be run through `crewai uv`, so you can manage your project's dependencies without configuring authentication through environment variables or other methods.

Say that you have installed a custom tool from the CrewAI Enterprise Tool Repository called "my-tool":

```bash
crewai tool install my-tool
```

To add another package to your project, run:

```bash
crewai uv add requests
```

Other commands like `uv sync` or `uv remove` can also be used with the `crewai uv` command:

```bash
crewai uv sync
```

```bash
crewai uv remove requests
```

`crewai uv add` adds the package to your project and updates `pyproject.toml` accordingly; `crewai uv remove` removes it again.

||||
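The authentication that `crewai uv` handles comes down to per-index environment variables. The sketch below follows the `_build_env_with_credentials` helper that this diff removes from `ToolCommand`, which set `UV_INDEX_<NAME>_USERNAME` and `UV_INDEX_<NAME>_PASSWORD`. It takes the credentials as arguments instead of reading `Settings`, and the `build_uv_index_env` name is hypothetical; the body of the replacement helper in `crewai.cli.utils` is not shown in this diff, so treat this as an approximation.

```python
import os
from typing import Optional


def build_uv_index_env(
    index_name: str,
    username: Optional[str],
    password: Optional[str],
    base_env: Optional[dict[str, str]] = None,
) -> dict[str, str]:
    """Return a copy of the environment with uv index credentials added."""
    # uv expects the index name upper-cased with dashes turned into underscores.
    key = index_name.upper().replace("-", "_")
    env = dict(os.environ if base_env is None else base_env)
    env[f"UV_INDEX_{key}_USERNAME"] = username or ""
    env[f"UV_INDEX_{key}_PASSWORD"] = password or ""
    return env
```

The returned mapping can be passed as `env=` to `subprocess.run(["uv", ...])`, which is how the `uv` wrapper command in this diff invokes the tool.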
## Creating and Publishing Tools
|
||||
|
||||
To create a new tool project:
|
||||
|
||||
@@ -27,7 +27,7 @@ Follow the steps below to get Crewing! 🚣♂️
|
||||
<Step title="Navigate to your new crew project">
|
||||
<CodeGroup>
|
||||
```shell Terminal
|
||||
cd latest-ai-development
|
||||
cd latest_ai_development
|
||||
```
|
||||
</CodeGroup>
|
||||
</Step>
|
||||
|
||||
@@ -5,6 +5,82 @@ icon: "clock"
|
||||
mode: "wide"
|
||||
---
|
||||
|
||||
<Update label="Sep 20, 2025">
## v0.193.2

[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.193.2)

## What's Changed

- Updated pyproject templates to use the correct version

</Update>

<Update label="Sep 20, 2025">
## v0.193.1

[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.193.1)

## What's Changed

- A series of minor fixes and linter improvements

</Update>

<Update label="Sep 19, 2025">
## v0.193.0

[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.193.0)

## Core Improvements & Fixes

- Fixed handling of the `model` parameter during OpenAI adapter initialization
- Resolved test duration cache issues in CI workflows
- Fixed a flaky test related to repeated tool usage by agents
- Added missing event exports to `__init__.py` for consistent module behavior
- Removed message storage from Mem0 to reduce metadata bloat
- Fixed L2 distance metric support for backward compatibility in vector search

## New Features & Enhancements

- Introduced thread-safe platform context management
- Added test duration caching to optimize `pytest-split` runs
- Ephemeral trace improvements for better trace control
- Made search parameters for RAG, knowledge, and memory fully configurable
- Enabled ChromaDB to use the OpenAI API for embedding functions
- Added deeper observability tools for user-level insights
- Unified RAG storage system with instance-specific client support

## Documentation & Guides

- Updated `RagTool` references to reflect the CrewAI native RAG implementation
- Improved internal docs for the `langgraph` and `openai` agent adapters, including type annotations and docstrings

</Update>

<Update label="Sep 11, 2025">
## v0.186.1

[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.186.1)

## What's Changed

- Fixed an issue where the version was not found and reversion failed silently
- Bumped the CrewAI version to 0.186.1 in the CLI and updated dependencies

</Update>

<Update label="Sep 10, 2025">
## v0.186.0

[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.186.0)

## What's Changed

- See the GitHub release notes for detailed changes

</Update>

<Update label="Sep 4, 2025">
## v0.177.0
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ mode: "wide"
|
||||
<Step title="Navigate to your new crew project">
|
||||
<CodeGroup>
|
||||
```shell Terminal
|
||||
cd latest-ai-development
|
||||
cd latest_ai_development
|
||||
```
|
||||
</CodeGroup>
|
||||
</Step>
|
||||
|
||||
@@ -5,6 +5,82 @@ icon: "clock"
|
||||
mode: "wide"
|
||||
---
|
||||
|
||||
<Update label="Sep 20, 2025">
## v0.193.2

[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.193.2)

## What's Changed

- Updated pyproject templates to use the correct version

</Update>

<Update label="Sep 20, 2025">
## v0.193.1

[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.193.1)

## What's Changed

- A series of small fixes and linter improvements

</Update>

<Update label="Sep 19, 2025">
## v0.193.0

[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.193.0)

## Core Improvements & Fixes

- Fixed handling of the `model` parameter during OpenAI adapter initialization
- Resolved test duration cache issues in CI workflows
- Fixed a flaky test related to repeated tool usage by agents
- Added missing event exports to `__init__.py` for consistent module behavior
- Removed message storage from the metadata in Mem0 to reduce bloat
- Fixed L2 distance metric support for backward compatibility in vector search

## New Features & Enhancements

- Introduced thread-safe platform context management
- Added test duration caching for optimized `pytest-split` runs
- Ephemeral trace improvements for better trace control
- Search parameters for RAG, knowledge, and memory are now fully configurable
- Enabled ChromaDB to use the OpenAI API for embedding functions
- Added deeper observability tools for user-level insights
- Unified RAG storage system with instance-specific client support

## Documentation & Guides

- Updated `RagTool` references to reflect the CrewAI native RAG implementation
- Improved internal documentation for the `langgraph` and `openai` agent adapters with type annotations and docstrings

</Update>

<Update label="Sep 11, 2025">
## v0.186.1

[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.186.1)

## What's Changed

- Fixed a silent reversion failure when the version was not found
- Updated the CrewAI version to 0.186.1 and updated the CLI dependencies

</Update>

<Update label="Sep 10, 2025">
## v0.186.0

[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.186.0)

## What's Changed

- See the release notes on GitHub for full details

</Update>

<Update label="Sep 04, 2025">
## v0.177.0
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ Siga os passos abaixo para começar a tripular! 🚣♂️
|
||||
<Step title="Navigate to your new crew project">
|
||||
<CodeGroup>
|
||||
```shell Terminal
|
||||
cd latest-ai-development
|
||||
cd latest_ai_development
|
||||
```
|
||||
</CodeGroup>
|
||||
</Step>
|
||||
|
||||
@@ -9,7 +9,7 @@ authors = [
|
||||
]
|
||||
dependencies = [
|
||||
# Core Dependencies
|
||||
"pydantic>=2.4.2",
|
||||
"pydantic>=2.11.9",
|
||||
"openai>=1.13.3",
|
||||
"litellm==1.74.9",
|
||||
"instructor>=1.3.3",
|
||||
@@ -27,7 +27,7 @@ dependencies = [
|
||||
"openpyxl>=3.1.5",
|
||||
"pyvis>=0.3.2",
|
||||
# Authentication and Security
|
||||
"python-dotenv>=1.0.0",
|
||||
"python-dotenv>=1.1.1",
|
||||
"pyjwt>=2.9.0",
|
||||
# Configuration and Utils
|
||||
"click>=8.1.7",
|
||||
@@ -40,6 +40,7 @@ dependencies = [
|
||||
"blinker>=1.9.0",
|
||||
"json5>=0.10.0",
|
||||
"portalocker==2.7.0",
|
||||
"pydantic-settings>=2.10.1",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
@@ -72,23 +73,20 @@ qdrant = [
|
||||
"qdrant-client[fastembed]>=1.14.3",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
dev-dependencies = [
|
||||
"ruff>=0.12.11",
|
||||
"mypy>=1.17.1",
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"ruff>=0.13.1",
|
||||
"mypy>=1.18.2",
|
||||
"pre-commit>=4.3.0",
|
||||
"bandit>=1.8.6",
|
||||
"pillow>=10.2.0",
|
||||
"cairosvg>=2.7.1",
|
||||
"pytest>=8.0.0",
|
||||
"python-dotenv>=1.0.0",
|
||||
"pytest-asyncio>=0.23.7",
|
||||
"pytest-subprocess>=1.5.2",
|
||||
"pytest-recording>=0.13.2",
|
||||
"pytest-randomly>=3.16.0",
|
||||
"pytest-timeout>=2.3.1",
|
||||
"pytest-xdist>=3.6.1",
|
||||
"pytest-split>=0.9.0",
|
||||
"pytest>=8.4.2",
|
||||
"pytest-asyncio>=1.2.0",
|
||||
"pytest-subprocess>=1.5.3",
|
||||
"pytest-recording>=0.13.4",
|
||||
"pytest-randomly>=4.0.1",
|
||||
"pytest-timeout>=2.4.0",
|
||||
"pytest-xdist>=3.8.0",
|
||||
"pytest-split>=0.10.0",
|
||||
"types-requests==2.32.*",
|
||||
"types-pyyaml==6.0.*",
|
||||
"types-regex==2024.11.6.*",
|
||||
|
||||
@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
|
||||
|
||||
_suppress_pydantic_deprecation_warnings()
|
||||
|
||||
__version__ = "0.193.1"
|
||||
__version__ = "0.193.2"
|
||||
_telemetry_submitted = False
|
||||
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from crewai.agents.constants import (
|
||||
MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
|
||||
UNABLE_TO_REPAIR_JSON_RESULTS,
|
||||
)
|
||||
from crewai.utilities import I18N
|
||||
from crewai.utilities.i18n import I18N
|
||||
|
||||
_I18N = I18N()
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import os
|
||||
import subprocess
|
||||
from importlib.metadata import version as get_version
|
||||
|
||||
import click
|
||||
@@ -8,6 +10,7 @@ from crewai.cli.create_crew import create_crew
|
||||
from crewai.cli.create_flow import create_flow
|
||||
from crewai.cli.crew_chat import run_chat
|
||||
from crewai.cli.settings.main import SettingsCommand
|
||||
from crewai.cli.utils import build_env_with_tool_repository_credentials, read_toml
|
||||
from crewai.memory.storage.kickoff_task_outputs_storage import (
|
||||
KickoffTaskOutputsSQLiteStorage,
|
||||
)
|
||||
@@ -25,6 +28,7 @@ from .reset_memories_command import reset_memories_command
|
||||
from .run_crew import run_crew
|
||||
from .tools.main import ToolCommand
|
||||
from .train_crew import train_crew
|
||||
from .trigger_command import TriggerCommand
|
||||
from .update_crew import update_crew
|
||||
|
||||
|
||||
@@ -34,6 +38,46 @@ def crewai():
|
||||
"""Top-level command group for crewai."""
|
||||
|
||||
|
||||
@crewai.command(
|
||||
name="uv",
|
||||
context_settings=dict(
|
||||
ignore_unknown_options=True,
|
||||
),
|
||||
)
|
||||
@click.argument("uv_args", nargs=-1, type=click.UNPROCESSED)
|
||||
def uv(uv_args):
|
||||
"""A wrapper around uv commands that adds custom tool authentication through env vars."""
|
||||
env = os.environ.copy()
|
||||
try:
|
||||
pyproject_data = read_toml()
|
||||
sources = pyproject_data.get("tool", {}).get("uv", {}).get("sources", {})
|
||||
|
||||
for source_config in sources.values():
|
||||
if isinstance(source_config, dict):
|
||||
index = source_config.get("index")
|
||||
if index:
|
||||
index_env = build_env_with_tool_repository_credentials(index)
|
||||
env.update(index_env)
|
||||
except (FileNotFoundError, KeyError) as e:
|
||||
raise SystemExit(
|
||||
"Error. A valid pyproject.toml file is required. Check that a valid pyproject.toml file exists in the current directory."
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise SystemExit(f"Error: {e}") from e
|
||||
|
||||
try:
|
||||
subprocess.run( # noqa: S603
|
||||
["uv", *uv_args], # noqa: S607
|
||||
capture_output=False,
|
||||
env=env,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
except subprocess.CalledProcessError as e:
|
||||
click.secho(f"uv command failed with exit code {e.returncode}", fg="red")
|
||||
raise SystemExit(e.returncode) from e
|
||||
|
||||
|
||||
@crewai.command()
|
||||
@click.argument("type", type=click.Choice(["crew", "flow"]))
|
||||
@click.argument("name")
|
||||
@@ -239,11 +283,6 @@ def deploy():
|
||||
"""Deploy the Crew CLI group."""
|
||||
|
||||
|
||||
@crewai.group()
|
||||
def tool():
|
||||
"""Tool Repository related commands."""
|
||||
|
||||
|
||||
@deploy.command(name="create")
|
||||
@click.option("-y", "--yes", is_flag=True, help="Skip the confirmation prompt")
|
||||
def deploy_create(yes: bool):
|
||||
@@ -291,6 +330,11 @@ def deploy_remove(uuid: str | None):
|
||||
deploy_cmd.remove_crew(uuid=uuid)
|
||||
|
||||
|
||||
@crewai.group()
|
||||
def tool():
|
||||
"""Tool Repository related commands."""
|
||||
|
||||
|
||||
@tool.command(name="create")
|
||||
@click.argument("handle")
|
||||
def tool_create(handle: str):
|
||||
@@ -430,5 +474,18 @@ def config_reset():
|
||||
config_command.reset_all_settings()
|
||||
|
||||
|
||||
@crewai.command()
|
||||
@click.argument("action_or_trigger")
|
||||
def trigger(action_or_trigger: str):
|
||||
"""Trigger management. Use 'list' to list triggers or provide trigger identification to run."""
|
||||
trigger_cmd = TriggerCommand()
|
||||
|
||||
if action_or_trigger == "list":
|
||||
trigger_cmd.list_triggers()
|
||||
else:
|
||||
# Assume it's a trigger identification
|
||||
trigger_cmd.run_trigger(action_or_trigger)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
crewai()
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import json
|
||||
import tempfile
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -12,8 +14,48 @@ from crewai.cli.constants import (
|
||||
)
|
||||
from crewai.cli.shared.token_manager import TokenManager
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json"
|
||||
|
||||
|
||||
def get_writable_config_path() -> Path | None:
|
||||
"""
|
||||
Find a writable location for the config file with fallback options.
|
||||
|
||||
Tries in order:
|
||||
1. Default: ~/.config/crewai/settings.json
|
||||
2. Temp directory: /tmp/crewai_settings.json (or OS equivalent)
|
||||
3. Current directory: ./crewai_settings.json
|
||||
4. In-memory only (returns None)
|
||||
|
||||
Returns:
|
||||
Path object for writable config location, or None if no writable location found
|
||||
"""
|
||||
fallback_paths = [
|
||||
DEFAULT_CONFIG_PATH, # Default location
|
||||
Path(tempfile.gettempdir()) / "crewai_settings.json", # Temporary directory
|
||||
Path.cwd() / "crewai_settings.json", # Current working directory
|
||||
]
|
||||
|
||||
for config_path in fallback_paths:
|
||||
try:
|
||||
config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
test_file = config_path.parent / ".crewai_write_test"
|
||||
try:
|
||||
test_file.write_text("test")
|
||||
test_file.unlink() # Clean up test file
|
||||
logger.info(f"Using config path: {config_path}")
|
||||
return config_path
|
||||
except Exception: # noqa: S112
|
||||
continue
|
||||
|
||||
except Exception: # noqa: S112
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# Settings that are related to the user's account
|
||||
USER_SETTINGS_KEYS = [
|
||||
"tool_repository_username",
|
||||
@@ -93,16 +135,32 @@ class Settings(BaseModel):
|
||||
default=DEFAULT_CLI_SETTINGS["oauth2_domain"],
|
||||
)
|
||||
|
||||
def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data):
|
||||
"""Load Settings from config path"""
|
||||
config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
def __init__(self, config_path: Path | None = None, **data):
|
||||
"""Load Settings from config path with fallback support"""
|
||||
if config_path is None:
|
||||
config_path = get_writable_config_path()
|
||||
|
||||
# If config_path is None, we're in memory-only mode
|
||||
if config_path is None:
|
||||
merged_data = {**data}
|
||||
# Dummy path for memory-only mode
|
||||
super().__init__(config_path=Path("/dev/null"), **merged_data)
|
||||
return
|
||||
|
||||
try:
|
||||
config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
except Exception:
|
||||
merged_data = {**data}
|
||||
# Dummy path for memory-only mode
|
||||
super().__init__(config_path=Path("/dev/null"), **merged_data)
|
||||
return
|
||||
|
||||
file_data = {}
|
||||
if config_path.is_file():
|
||||
try:
|
||||
with config_path.open("r") as f:
|
||||
file_data = json.load(f)
|
||||
except json.JSONDecodeError:
|
||||
except Exception:
|
||||
file_data = {}
|
||||
|
||||
merged_data = {**file_data, **data}
|
||||
@@ -122,15 +180,22 @@ class Settings(BaseModel):
|
||||
|
||||
def dump(self) -> None:
|
||||
"""Save current settings to settings.json"""
|
||||
if self.config_path.is_file():
|
||||
with self.config_path.open("r") as f:
|
||||
existing_data = json.load(f)
|
||||
else:
|
||||
existing_data = {}
|
||||
if str(self.config_path) == "/dev/null":
|
||||
return
|
||||
|
||||
updated_data = {**existing_data, **self.model_dump(exclude_unset=True)}
|
||||
with self.config_path.open("w") as f:
|
||||
json.dump(updated_data, f, indent=4)
|
||||
try:
|
||||
if self.config_path.is_file():
|
||||
with self.config_path.open("r") as f:
|
||||
existing_data = json.load(f)
|
||||
else:
|
||||
existing_data = {}
|
||||
|
||||
updated_data = {**existing_data, **self.model_dump(exclude_unset=True)}
|
||||
with self.config_path.open("w") as f:
|
||||
json.dump(updated_data, f, indent=4)
|
||||
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
def _reset_user_settings(self) -> None:
|
||||
"""Reset all user settings to default values"""
|
||||
|
||||
@@ -18,6 +18,7 @@ class PlusAPI:
|
||||
AGENTS_RESOURCE = "/crewai_plus/api/v1/agents"
|
||||
TRACING_RESOURCE = "/crewai_plus/api/v1/tracing"
|
||||
EPHEMERAL_TRACING_RESOURCE = "/crewai_plus/api/v1/tracing/ephemeral"
|
||||
TRIGGERS_RESOURCE = "/v1/triggers"
|
||||
|
||||
def __init__(self, api_key: str) -> None:
|
||||
self.api_key = api_key
|
||||
@@ -166,3 +167,25 @@ class PlusAPI:
|
||||
json=payload,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
def mark_trace_batch_as_failed(
|
||||
self, trace_batch_id: str, error_message: str
|
||||
) -> requests.Response:
|
||||
return self._make_request(
|
||||
"PATCH",
|
||||
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}",
|
||||
json={"status": "failed", "failure_reason": error_message},
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
def list_triggers(self) -> requests.Response:
|
||||
"""List all triggers from the current user."""
|
||||
return self._make_request("GET", self.TRIGGERS_RESOURCE)
|
||||
|
||||
def get_trigger_sample_payload(self, trigger_identification: str) -> requests.Response:
|
||||
"""Get sample payload for a trigger identification."""
|
||||
return self._make_request(
|
||||
"POST",
|
||||
f"{self.TRIGGERS_RESOURCE}/sample_payload",
|
||||
json={"trigger_identification": trigger_identification}
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]>=0.193.0,<1.0.0"
|
||||
"crewai[tools]>=0.193.2,<1.0.0"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]>=0.193.0,<1.0.0",
|
||||
"crewai[tools]>=0.193.2,<1.0.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "Power up your crews with {{folder_name}}"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]>=0.193.0"
|
||||
"crewai[tools]>=0.193.2"
|
||||
]
|
||||
|
||||
[tool.crewai]
|
||||
|
||||
@@ -12,6 +12,7 @@ from crewai.cli import git
|
||||
from crewai.cli.command import BaseCommand, PlusAPIMixin
|
||||
from crewai.cli.config import Settings
|
||||
from crewai.cli.utils import (
|
||||
build_env_with_tool_repository_credentials,
|
||||
extract_available_exports,
|
||||
get_project_description,
|
||||
get_project_name,
|
||||
@@ -42,8 +43,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
if project_root.exists():
|
||||
click.secho(f"Folder {folder_name} already exists.", fg="red")
|
||||
raise SystemExit
|
||||
else:
|
||||
os.makedirs(project_root)
|
||||
os.makedirs(project_root)
|
||||
|
||||
click.secho(f"Creating custom tool {folder_name}...", fg="green", bold=True)
|
||||
|
||||
@@ -56,7 +56,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
os.chdir(project_root)
|
||||
try:
|
||||
self.login()
|
||||
subprocess.run(["git", "init"], check=True)
|
||||
subprocess.run(["git", "init"], check=True) # noqa: S607
|
||||
console.print(
|
||||
f"[green]Created custom tool [bold]{folder_name}[/bold]. Run [bold]cd {project_root}[/bold] to start working.[/green]"
|
||||
)
|
||||
@@ -76,10 +76,10 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
raise SystemExit()
|
||||
|
||||
project_name = get_project_name(require=True)
|
||||
assert isinstance(project_name, str)
|
||||
assert isinstance(project_name, str) # noqa: S101
|
||||
|
||||
project_version = get_project_version(require=True)
|
||||
assert isinstance(project_version, str)
|
||||
assert isinstance(project_version, str) # noqa: S101
|
||||
|
||||
project_description = get_project_description(require=False)
|
||||
encoded_tarball = None
|
||||
@@ -94,8 +94,8 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
         self._print_current_organization()
 
         with tempfile.TemporaryDirectory() as temp_build_dir:
-            subprocess.run(
-                ["uv", "build", "--sdist", "--out-dir", temp_build_dir],
+            subprocess.run(  # noqa: S603
+                ["uv", "build", "--sdist", "--out-dir", temp_build_dir],  # noqa: S607
                 check=True,
                 capture_output=False,
             )
@@ -146,7 +146,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
                 style="bold red",
             )
             raise SystemExit
-        elif get_response.status_code != 200:
+        if get_response.status_code != 200:
             console.print(
                 "Failed to get tool details. Please try again later.", style="bold red"
             )
@@ -196,10 +196,10 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
         else:
             add_package_command.extend(["--index", index, tool_handle])
 
-        add_package_result = subprocess.run(
+        add_package_result = subprocess.run(  # noqa: S603
             add_package_command,
             capture_output=False,
-            env=self._build_env_with_credentials(repository_handle),
+            env=build_env_with_tool_repository_credentials(repository_handle),
             text=True,
             check=True,
         )
@@ -221,20 +221,6 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
             )
             raise SystemExit
 
-    def _build_env_with_credentials(self, repository_handle: str):
-        repository_handle = repository_handle.upper().replace("-", "_")
-        settings = Settings()
-
-        env = os.environ.copy()
-        env[f"UV_INDEX_{repository_handle}_USERNAME"] = str(
-            settings.tool_repository_username or ""
-        )
-        env[f"UV_INDEX_{repository_handle}_PASSWORD"] = str(
-            settings.tool_repository_password or ""
-        )
-
-        return env
-
     def _print_current_organization(self) -> None:
         settings = Settings()
         if settings.org_uuid:
|
||||
|
||||
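Review note: the `# noqa: S603` and `# noqa: S607` markers in the hunks above silence Ruff's bandit-derived warnings about subprocess calls and partial executable paths. The calls already pass the command as a list and never go through a shell. A minimal sketch of that pattern follows; `run_with_extra_env` and `DEMO_TOKEN` are hypothetical names used only for illustration, not part of the CrewAI code.

```python
import os
import subprocess
import sys


def run_with_extra_env(argv: list[str], extra_env: dict[str, str]) -> subprocess.CompletedProcess:
    """Run argv without a shell, adding extra_env on top of the current environment."""
    env = os.environ.copy()  # copy so the parent process environment is left untouched
    env.update(extra_env)
    # argv is a list, so no shell parsing or interpolation happens
    return subprocess.run(argv, env=env, check=True, capture_output=True, text=True)


result = run_with_extra_env(
    [sys.executable, "-c", "import os; print(os.environ['DEMO_TOKEN'])"],
    {"DEMO_TOKEN": "example"},
)
```

Using `sys.executable` here gives a full path to the interpreter, which is the kind of call S607 would not flag; `uv` and `git` are looked up on `PATH`, hence the suppressions in the diff.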
315
src/crewai/cli/trigger_command.py
Normal file
@@ -0,0 +1,315 @@
|
||||
import sys
import os
import subprocess
from typing import Dict, Any

import click
import requests
from rich.console import Console
from rich.table import Table
from rich.text import Text

from crewai.cli.command import BaseCommand, PlusAPIMixin
from crewai.telemetry.telemetry import Telemetry

console = Console()


class TriggerCommand(BaseCommand, PlusAPIMixin):
    """Command handler for trigger-related operations."""

    def __init__(self):
        """Initialize the trigger command with telemetry and API client."""
        self._telemetry = Telemetry()
        super().__init__()
        PlusAPIMixin.__init__(self, self._telemetry)
|
||||
|
||||
    def list_triggers(self) -> None:
        """List all triggers grouped by provider name."""
        try:
            console.print("Fetching triggers from CrewAI API...", style="blue")

            # Fetch triggers from API
            response = self.plus_api_client.list_triggers()
            self._validate_response(response)

            triggers_data = response.json()

            if not triggers_data:
                console.print(
                    "No triggers found for the current user.", style="yellow"
                )
                return

            # Display triggers grouped by provider
            self._display_triggers(triggers_data)

        except requests.exceptions.ConnectionError:
            console.print(
                "Failed to connect to CrewAI API. Please check your internet connection.",
                style="bold red"
            )
            raise SystemExit(1)
        except requests.exceptions.Timeout:
            console.print(
                "Request to CrewAI API timed out. Please try again later.",
                style="bold red"
            )
            raise SystemExit(1)
        except requests.exceptions.HTTPError as e:
            if e.response.status_code == 401:
                console.print(
                    "Authentication failed. Please run 'crewai login' to authenticate.",
                    style="bold red"
                )
            elif e.response.status_code == 403:
                console.print(
                    "Access denied. You may not have permission to access triggers.",
                    style="bold red"
                )
            else:
                console.print(f"HTTP error occurred: {e}", style="bold red")
            raise SystemExit(1)
        except Exception as e:
            console.print(f"Unexpected error listing triggers: {e}", style="bold red")
            console.print("Please check your configuration and try again.", style="yellow")
            raise SystemExit(1)
|
||||
|
||||
    def run_trigger(self, trigger_identification: str) -> None:
        """Run a crew with the specified trigger payload."""
        try:
            # Validate trigger identification format
            if not trigger_identification or "/" not in trigger_identification:
                console.print(
                    "Invalid trigger identification format. Expected format: 'app/trigger_name'",
                    style="bold red"
                )
                console.print(
                    "Use 'crewai trigger list' to see available triggers.", style="yellow"
                )
                raise SystemExit(1)

            # Get sample payload for the trigger
            console.print(f"Getting sample payload for trigger: {trigger_identification}", style="blue")
            response = self.plus_api_client.get_trigger_sample_payload(trigger_identification)
            self._validate_response(response)

            trigger_payload = response.json()

            if not trigger_payload:
                console.print(
                    f"No sample payload found for trigger: {trigger_identification}",
                    style="yellow"
                )
                console.print(
                    "Use 'crewai trigger list' to see available triggers.", style="yellow"
                )
                return

            console.print("Sample payload retrieved successfully", style="green")

            # Import and run the crew with the trigger payload
            self._run_crew_with_payload(trigger_payload)

        except requests.exceptions.ConnectionError:
            console.print(
                "Failed to connect to CrewAI API. Please check your internet connection.",
                style="bold red"
            )
            raise SystemExit(1)
        except requests.exceptions.Timeout:
            console.print(
                "Request to CrewAI API timed out. Please try again later.",
                style="bold red"
            )
            raise SystemExit(1)
        except requests.exceptions.HTTPError as e:
            if e.response.status_code == 401:
                console.print(
                    "Authentication failed. Please run 'crewai login' to authenticate.",
                    style="bold red"
                )
            elif e.response.status_code == 404:
                console.print(
                    f"Trigger '{trigger_identification}' not found.",
                    style="bold red"
                )
                console.print(
                    "Use 'crewai trigger list' to see available triggers.", style="yellow"
                )
            elif e.response.status_code == 403:
                console.print(
                    "Access denied. You may not have permission to access this trigger.",
                    style="bold red"
                )
            else:
                console.print(f"HTTP error occurred: {e}", style="bold red")
            raise SystemExit(1)
        except FileNotFoundError as e:
            console.print(
                f"Project file not found: {e}", style="bold red"
            )
            console.print(
                "Make sure you're in a valid CrewAI project directory.", style="yellow"
            )
            raise SystemExit(1)
        except subprocess.CalledProcessError as e:
            console.print(f"Error running crew: {e}", style="bold red")
            if e.output:
                console.print(f"Output: {e.output}", style="red")
            raise SystemExit(1)
        except Exception as e:
            console.print(f"Unexpected error running trigger: {e}", style="bold red")
            console.print("Please check your configuration and try again.", style="yellow")
            raise SystemExit(1)
|
||||
|
||||
    def _display_triggers(self, triggers_data: Dict[str, Any]) -> None:
        """Display triggers in a formatted table grouped by provider."""
        table = Table(title="Available Triggers")
        table.add_column("Provider", style="cyan", no_wrap=True)
        table.add_column("Trigger ID", style="magenta")
        table.add_column("Description", style="green")

        # Group triggers by provider
        for provider_name, triggers in triggers_data.items():
            if isinstance(triggers, dict):
                # Add provider header
                first_trigger = True

                for trigger_id, trigger_info in triggers.items():
                    description = trigger_info.get("description", "No description available")

                    # Display provider name only for the first trigger of each provider
                    provider_display = provider_name if first_trigger else ""
                    first_trigger = False

                    table.add_row(
                        provider_display,
                        trigger_id,
                        description
                    )

                # Add separator between providers (except for the last one)
                if provider_name != list(triggers_data.keys())[-1]:
                    table.add_row("", "", "")

        console.print(table)
        console.print("\nTo run a trigger, use: [bold green]crewai trigger <trigger_id>[/bold green]")
|
||||
|
||||
    def _run_crew_with_payload(self, trigger_payload: Dict[str, Any]) -> None:
        """Run the crew with the trigger payload."""
        script_path = None
        try:
            from crewai.cli.utils import read_toml

            # Validate project structure
            if not os.path.exists("pyproject.toml"):
                raise FileNotFoundError("pyproject.toml not found. Make sure you're in a CrewAI project directory.")

            if not os.path.exists("src"):
                raise FileNotFoundError("src directory not found. Make sure you're in a CrewAI project directory.")

            if not os.path.exists("src/main.py"):
                raise FileNotFoundError("src/main.py not found. Make sure you have a valid CrewAI project.")

            # Read project configuration
            pyproject_data = read_toml()
            is_flow = pyproject_data.get("tool", {}).get("crewai", {}).get("type") == "flow"

            console.print(f"Project type detected: {'Flow' if is_flow else 'Crew'}")
            console.print("Preparing execution environment...")

            # Create a temporary script to run the crew with trigger payload
            script_content = self._generate_crew_script(trigger_payload, is_flow)

            # Write script to temporary file
            script_path = "temp_trigger_run.py"
            with open(script_path, "w") as f:
                f.write(script_content)

            console.print(f"Running {'flow' if is_flow else 'crew'} with trigger payload...", style="blue")

            # Execute the script
            command = ["uv", "run", "python", script_path]
            result = subprocess.run(command, check=True, capture_output=True, text=True)

            # Display success message
            console.print("✓ Execution completed successfully!", style="bold green")
            if result.stdout:
                console.print("Output:", style="blue")
                console.print(result.stdout)

        except FileNotFoundError:
            raise  # Re-raise so run_trigger's FileNotFoundError handler reports it
        except subprocess.CalledProcessError as e:
            error_msg = f"Crew execution failed with exit code {e.returncode}"
            if e.stderr:
                error_msg += f"\nError output: {e.stderr}"
            if e.stdout:
                error_msg += f"\nStandard output: {e.stdout}"
            # Chain the original error so the full traceback is preserved
            raise subprocess.CalledProcessError(e.returncode, e.cmd, error_msg) from e
        except Exception as e:
            raise Exception(f"Failed to execute crew: {e}") from e
        finally:
            # Clean up temporary script
            if script_path and os.path.exists(script_path):
                try:
                    os.remove(script_path)
                except OSError:
                    console.print(f"Warning: Could not remove temporary file {script_path}", style="yellow")
|
||||
|
||||
    def _generate_crew_script(self, trigger_payload: Dict[str, Any], is_flow: bool) -> str:
        """Generate a Python script to run the crew with trigger payload."""
        if is_flow:
            return f"""
import sys
sys.path.append('src')

# Import under an alias: the wrapper below is also named main(), and a
# star import would be shadowed by it, making main() call itself.
from main import main as flow_main

def main():
    try:
        # Initialize and run the flow with trigger payload
        flow = flow_main()

        # Add trigger payload to inputs
        inputs = {{"crewai_trigger_payload": {trigger_payload}}}

        result = flow.kickoff(inputs=inputs)
        print("Flow execution completed successfully")
        print(f"Result: {{result}}")

    except Exception as e:
        print(f"Error running flow: {{e}}", file=sys.stderr)
        sys.exit(1)

if __name__ == "__main__":
    main()
"""
        else:
            return f"""
import sys
sys.path.append('src')

def main():
    try:
        # Import the crew
        from main import main as crew_main

        # Get the crew instance
        crew = crew_main()

        # Add trigger payload to inputs
        inputs = {{"crewai_trigger_payload": {trigger_payload}}}

        result = crew.kickoff(inputs=inputs)
        print("Crew execution completed successfully")
        print(f"Result: {{result}}")

    except Exception as e:
        print(f"Error running crew: {{e}}", file=sys.stderr)
        sys.exit(1)

if __name__ == "__main__":
    main()
"""
|
||||
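Review note: `_generate_crew_script` above interpolates the payload dict directly into the generated source, so the script depends on the dict's `repr` being valid Python. One alternative, sketched here under assumptions, is to serialize the payload as JSON and have the generated script parse it back. `render_runner_script` is a hypothetical helper, not part of the PR, and it only builds the inputs rather than calling `kickoff`.

```python
import json
from typing import Any


def render_runner_script(trigger_payload: dict[str, Any]) -> str:
    """Return script source that rebuilds the payload from an embedded JSON string."""
    payload_json = json.dumps(trigger_payload)
    return (
        "import json\n"
        # !r emits a quoted, escaped Python string literal holding the JSON text
        f"payload = json.loads({payload_json!r})\n"
        "inputs = {'crewai_trigger_payload': payload}\n"
    )


sample = {"subject": "It's \"quoted\"", "unread": True, "labels": None}
script = render_runner_script(sample)

# Execute the generated source in an isolated namespace to check the round trip
namespace: dict[str, Any] = {}
exec(compile(script, "<generated>", "exec"), namespace)
```

The design choice is that only a string literal crosses into the generated code, so quoting is handled by `repr` of a `str` and the data types are handled by the JSON parser.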
@@ -11,6 +11,7 @@ import click
 import tomli
 from rich.console import Console
 
+from crewai.cli.config import Settings
 from crewai.cli.constants import ENV_VARS
 from crewai.crew import Crew
 from crewai.flow import Flow
@@ -417,6 +418,21 @@ def extract_available_exports(dir_path: str = "src"):
         raise SystemExit(1) from e
 
 
+def build_env_with_tool_repository_credentials(repository_handle: str):
+    repository_handle = repository_handle.upper().replace("-", "_")
+    settings = Settings()
+
+    env = os.environ.copy()
+    env[f"UV_INDEX_{repository_handle}_USERNAME"] = str(
+        settings.tool_repository_username or ""
+    )
+    env[f"UV_INDEX_{repository_handle}_PASSWORD"] = str(
+        settings.tool_repository_password or ""
+    )
+
+    return env
+
+
 def _load_tools_from_init(init_file: Path) -> list[dict[str, Any]]:
     """
     Load and validate tools from a given __init__.py file.
|
||||
|
||||
@@ -200,6 +200,9 @@ class TraceBatchManager:
         if self.event_buffer:
             events_sent_to_backend_status = self._send_events_to_backend()
             if events_sent_to_backend_status == 500:
+                self.plus_api.mark_trace_batch_as_failed(
+                    self.trace_batch_id, "Error sending events to backend"
+                )
                 return None
         self._finalize_backend_batch()
 
@@ -273,10 +276,13 @@ class TraceBatchManager:
                 logger.error(
                     f"❌ Failed to finalize trace batch: {response.status_code} - {response.text}"
                 )
+                self.plus_api.mark_trace_batch_as_failed(
+                    self.trace_batch_id, response.text
+                )
 
         except Exception as e:
             logger.error(f"❌ Error finalizing trace batch: {e}")
-            # TODO: send error to app marking as failed
+            self.plus_api.mark_trace_batch_as_failed(self.trace_batch_id, str(e))
 
     def _cleanup_batch_data(self):
         """Clean up batch data after successful finalization to free memory"""
|
||||
|
||||
@@ -1,40 +1,39 @@
|
||||
from crewai.experimental.evaluation import (
|
||||
AgentEvaluationResult,
|
||||
AgentEvaluator,
|
||||
BaseEvaluator,
|
||||
EvaluationScore,
|
||||
MetricCategory,
|
||||
AgentEvaluationResult,
|
||||
SemanticQualityEvaluator,
|
||||
GoalAlignmentEvaluator,
|
||||
ReasoningEfficiencyEvaluator,
|
||||
ToolSelectionEvaluator,
|
||||
ParameterExtractionEvaluator,
|
||||
ToolInvocationEvaluator,
|
||||
EvaluationTraceCallback,
|
||||
create_evaluation_callbacks,
|
||||
AgentEvaluator,
|
||||
create_default_evaluator,
|
||||
ExperimentRunner,
|
||||
ExperimentResults,
|
||||
ExperimentResult,
|
||||
ExperimentResults,
|
||||
ExperimentRunner,
|
||||
GoalAlignmentEvaluator,
|
||||
MetricCategory,
|
||||
ParameterExtractionEvaluator,
|
||||
ReasoningEfficiencyEvaluator,
|
||||
SemanticQualityEvaluator,
|
||||
ToolInvocationEvaluator,
|
||||
ToolSelectionEvaluator,
|
||||
create_default_evaluator,
|
||||
create_evaluation_callbacks,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AgentEvaluationResult",
|
||||
"AgentEvaluator",
|
||||
"BaseEvaluator",
|
||||
"EvaluationScore",
|
||||
"MetricCategory",
|
||||
"AgentEvaluationResult",
|
||||
"SemanticQualityEvaluator",
|
||||
"GoalAlignmentEvaluator",
|
||||
"ReasoningEfficiencyEvaluator",
|
||||
"ToolSelectionEvaluator",
|
||||
"ParameterExtractionEvaluator",
|
||||
"ToolInvocationEvaluator",
|
||||
"EvaluationTraceCallback",
|
||||
"create_evaluation_callbacks",
|
||||
"AgentEvaluator",
|
||||
"create_default_evaluator",
|
||||
"ExperimentRunner",
|
||||
"ExperimentResult",
|
||||
"ExperimentResults",
|
||||
"ExperimentResult"
|
||||
]
|
||||
"ExperimentRunner",
|
||||
"GoalAlignmentEvaluator",
|
||||
"MetricCategory",
|
||||
"ParameterExtractionEvaluator",
|
||||
"ReasoningEfficiencyEvaluator",
|
||||
"SemanticQualityEvaluator",
|
||||
"ToolInvocationEvaluator",
|
||||
"ToolSelectionEvaluator",
|
||||
"create_default_evaluator",
|
||||
"create_evaluation_callbacks",
|
||||
]
|
||||
|
||||
@@ -1,51 +1,47 @@
|
||||
from crewai.experimental.evaluation.agent_evaluator import (
|
||||
AgentEvaluator,
|
||||
create_default_evaluator,
|
||||
)
|
||||
from crewai.experimental.evaluation.base_evaluator import (
|
||||
AgentEvaluationResult,
|
||||
BaseEvaluator,
|
||||
EvaluationScore,
|
||||
MetricCategory,
|
||||
AgentEvaluationResult
|
||||
)
|
||||
|
||||
from crewai.experimental.evaluation.metrics import (
|
||||
SemanticQualityEvaluator,
|
||||
GoalAlignmentEvaluator,
|
||||
ReasoningEfficiencyEvaluator,
|
||||
ToolSelectionEvaluator,
|
||||
ParameterExtractionEvaluator,
|
||||
ToolInvocationEvaluator
|
||||
)
|
||||
|
||||
from crewai.experimental.evaluation.evaluation_listener import (
|
||||
EvaluationTraceCallback,
|
||||
create_evaluation_callbacks
|
||||
create_evaluation_callbacks,
|
||||
)
|
||||
|
||||
from crewai.experimental.evaluation.agent_evaluator import (
|
||||
AgentEvaluator,
|
||||
create_default_evaluator
|
||||
)
|
||||
|
||||
from crewai.experimental.evaluation.experiment import (
|
||||
ExperimentRunner,
|
||||
ExperimentResult,
|
||||
ExperimentResults,
|
||||
ExperimentResult
|
||||
ExperimentRunner,
|
||||
)
|
||||
from crewai.experimental.evaluation.metrics import (
|
||||
GoalAlignmentEvaluator,
|
||||
ParameterExtractionEvaluator,
|
||||
ReasoningEfficiencyEvaluator,
|
||||
SemanticQualityEvaluator,
|
||||
ToolInvocationEvaluator,
|
||||
ToolSelectionEvaluator,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AgentEvaluationResult",
|
||||
"AgentEvaluator",
|
||||
"BaseEvaluator",
|
||||
"EvaluationScore",
|
||||
"MetricCategory",
|
||||
"AgentEvaluationResult",
|
||||
"SemanticQualityEvaluator",
|
||||
"GoalAlignmentEvaluator",
|
||||
"ReasoningEfficiencyEvaluator",
|
||||
"ToolSelectionEvaluator",
|
||||
"ParameterExtractionEvaluator",
|
||||
"ToolInvocationEvaluator",
|
||||
"EvaluationTraceCallback",
|
||||
"create_evaluation_callbacks",
|
||||
"AgentEvaluator",
|
||||
"create_default_evaluator",
|
||||
"ExperimentRunner",
|
||||
"ExperimentResult",
|
||||
"ExperimentResults",
|
||||
"ExperimentResult"
|
||||
"ExperimentRunner",
|
||||
"GoalAlignmentEvaluator",
|
||||
"MetricCategory",
|
||||
"ParameterExtractionEvaluator",
|
||||
"ReasoningEfficiencyEvaluator",
|
||||
"SemanticQualityEvaluator",
|
||||
"ToolInvocationEvaluator",
|
||||
"ToolSelectionEvaluator",
|
||||
"create_default_evaluator",
|
||||
"create_evaluation_callbacks",
|
||||
]
|
||||
|
||||
@@ -1,34 +1,36 @@
|
||||
import threading
|
||||
from typing import Any, Optional
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from crewai.experimental.evaluation.base_evaluator import (
|
||||
AgentEvaluationResult,
|
||||
AggregationStrategy,
|
||||
)
|
||||
from crewai.agent import Agent
|
||||
from crewai.task import Task
|
||||
from crewai.experimental.evaluation.evaluation_display import EvaluationDisplayFormatter
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentEvaluationStartedEvent,
|
||||
AgentEvaluationCompletedEvent,
|
||||
AgentEvaluationFailedEvent,
|
||||
AgentEvaluationStartedEvent,
|
||||
LiteAgentExecutionCompletedEvent,
|
||||
)
|
||||
from crewai.experimental.evaluation import BaseEvaluator, create_evaluation_callbacks
|
||||
from collections.abc import Sequence
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.utils.console_formatter import ConsoleFormatter
|
||||
from crewai.events.types.task_events import TaskCompletedEvent
|
||||
from crewai.events.types.agent_events import LiteAgentExecutionCompletedEvent
|
||||
from crewai.events.utils.console_formatter import ConsoleFormatter
|
||||
from crewai.experimental.evaluation.base_evaluator import (
|
||||
AgentAggregatedEvaluationResult,
|
||||
AgentEvaluationResult,
|
||||
AggregationStrategy,
|
||||
BaseEvaluator,
|
||||
EvaluationScore,
|
||||
MetricCategory,
|
||||
)
|
||||
from crewai.experimental.evaluation.evaluation_display import EvaluationDisplayFormatter
|
||||
from crewai.experimental.evaluation.evaluation_listener import (
|
||||
create_evaluation_callbacks,
|
||||
)
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
class ExecutionState:
|
||||
current_agent_id: Optional[str] = None
|
||||
current_task_id: Optional[str] = None
|
||||
current_agent_id: str | None = None
|
||||
current_task_id: str | None = None
|
||||
|
||||
def __init__(self):
|
||||
self.traces = {}
|
||||
@@ -40,10 +42,10 @@ class ExecutionState:
|
||||
class AgentEvaluator:
|
||||
def __init__(
|
||||
self,
|
||||
agents: list[Agent],
|
||||
agents: list[Agent] | list[BaseAgent],
|
||||
evaluators: Sequence[BaseEvaluator] | None = None,
|
||||
):
|
||||
self.agents: list[Agent] = agents
|
||||
self.agents: list[Agent] | list[BaseAgent] = agents
|
||||
self.evaluators: Sequence[BaseEvaluator] | None = evaluators
|
||||
|
||||
self.callback = create_evaluation_callbacks()
|
||||
@@ -75,7 +77,8 @@ class AgentEvaluator:
|
||||
)
|
||||
|
||||
def _handle_task_completed(self, source: Any, event: TaskCompletedEvent) -> None:
|
||||
assert event.task is not None
|
||||
if event.task is None:
|
||||
raise ValueError("TaskCompletedEvent must have a task")
|
||||
agent = event.task.agent
|
||||
if (
|
||||
agent
|
||||
@@ -92,9 +95,8 @@ class AgentEvaluator:
|
||||
state.current_agent_id = str(agent.id)
|
||||
state.current_task_id = str(event.task.id)
|
||||
|
||||
assert (
|
||||
state.current_agent_id is not None and state.current_task_id is not None
|
||||
)
|
||||
if state.current_agent_id is None or state.current_task_id is None:
|
||||
raise ValueError("Agent ID and Task ID must not be None")
|
||||
trace = self.callback.get_trace(
|
||||
state.current_agent_id, state.current_task_id
|
||||
)
|
||||
@@ -146,9 +148,8 @@ class AgentEvaluator:
|
||||
if not target_agent:
|
||||
return
|
||||
|
||||
assert (
|
||||
state.current_agent_id is not None and state.current_task_id is not None
|
||||
)
|
||||
if state.current_agent_id is None or state.current_task_id is None:
|
||||
raise ValueError("Agent ID and Task ID must not be None")
|
||||
trace = self.callback.get_trace(
|
||||
state.current_agent_id, state.current_task_id
|
||||
)
|
||||
@@ -244,7 +245,7 @@ class AgentEvaluator:
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
agent: Agent,
|
||||
agent: Agent | BaseAgent,
|
||||
execution_trace: dict[str, Any],
|
||||
final_output: Any,
|
||||
state: ExecutionState,
|
||||
@@ -255,7 +256,8 @@ class AgentEvaluator:
|
||||
task_id=state.current_task_id or (str(task.id) if task else "unknown_task"),
|
||||
)
|
||||
|
||||
assert self.evaluators is not None
|
||||
if self.evaluators is None:
|
||||
raise ValueError("Evaluators must be initialized")
|
||||
task_id = str(task.id) if task else None
|
||||
for evaluator in self.evaluators:
|
||||
try:
|
||||
@@ -276,7 +278,7 @@ class AgentEvaluator:
|
||||
metric_category=evaluator.metric_category,
|
||||
score=score,
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception as e: # noqa: PERF203
|
||||
self.emit_evaluation_failed_event(
|
||||
agent_role=agent.role,
|
||||
agent_id=str(agent.id),
|
||||
@@ -284,7 +286,7 @@ class AgentEvaluator:
|
||||
error=str(e),
|
||||
)
|
||||
self.console_formatter.print(
|
||||
f"Error in {evaluator.metric_category.value} evaluator: {str(e)}"
|
||||
f"Error in {evaluator.metric_category.value} evaluator: {e!s}"
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -337,14 +339,14 @@ class AgentEvaluator:
|
||||
)
|
||||
|
||||
|
||||
def create_default_evaluator(agents: list[Agent], llm: None = None):
|
||||
def create_default_evaluator(agents: list[Agent] | list[BaseAgent], llm: None = None):
|
||||
from crewai.experimental.evaluation import (
|
||||
GoalAlignmentEvaluator,
|
||||
SemanticQualityEvaluator,
|
||||
ToolSelectionEvaluator,
|
||||
ParameterExtractionEvaluator,
|
||||
ToolInvocationEvaluator,
|
||||
ReasoningEfficiencyEvaluator,
|
||||
SemanticQualityEvaluator,
|
||||
ToolInvocationEvaluator,
|
||||
ToolSelectionEvaluator,
|
||||
)
|
||||
|
||||
evaluators = [
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
import abc
|
||||
import enum
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.task import Task
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.llm import BaseLLM
|
||||
from crewai.task import Task
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
|
||||
|
||||
class MetricCategory(enum.Enum):
|
||||
GOAL_ALIGNMENT = "goal_alignment"
|
||||
SEMANTIC_QUALITY = "semantic_quality"
|
||||
@@ -19,7 +21,7 @@ class MetricCategory(enum.Enum):
|
||||
TOOL_INVOCATION = "tool_invocation"
|
||||
|
||||
def title(self):
|
||||
return self.value.replace('_', ' ').title()
|
||||
return self.value.replace("_", " ").title()
|
||||
|
||||
|
||||
class EvaluationScore(BaseModel):
|
||||
@@ -27,15 +29,13 @@ class EvaluationScore(BaseModel):
|
||||
default=5.0,
|
||||
description="Numeric score from 0-10 where 0 is worst and 10 is best, None if not applicable",
|
||||
ge=0.0,
|
||||
le=10.0
|
||||
le=10.0,
|
||||
)
|
||||
feedback: str = Field(
|
||||
default="",
|
||||
description="Detailed feedback explaining the evaluation score"
|
||||
default="", description="Detailed feedback explaining the evaluation score"
|
||||
)
|
||||
raw_response: str | None = Field(
|
||||
default=None,
|
||||
description="Raw response from the evaluator (e.g., LLM)"
|
||||
default=None, description="Raw response from the evaluator (e.g., LLM)"
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
@@ -56,8 +56,8 @@ class BaseEvaluator(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def evaluate(
|
||||
self,
|
||||
agent: Agent,
|
||||
execution_trace: Dict[str, Any],
|
||||
agent: Agent | BaseAgent,
|
||||
execution_trace: dict[str, Any],
|
||||
final_output: Any,
|
||||
task: Task | None = None,
|
||||
) -> EvaluationScore:
|
||||
@@ -67,9 +67,8 @@ class BaseEvaluator(abc.ABC):
|
||||
class AgentEvaluationResult(BaseModel):
|
||||
agent_id: str = Field(description="ID of the evaluated agent")
|
||||
task_id: str = Field(description="ID of the task that was executed")
|
||||
metrics: Dict[MetricCategory, EvaluationScore] = Field(
|
||||
default_factory=dict,
|
||||
description="Evaluation scores for each metric category"
|
||||
metrics: dict[MetricCategory, EvaluationScore] = Field(
|
||||
default_factory=dict, description="Evaluation scores for each metric category"
|
||||
)
|
||||
|
||||
|
||||
@@ -81,33 +80,23 @@ class AggregationStrategy(Enum):
|
||||
|
||||
|
||||
class AgentAggregatedEvaluationResult(BaseModel):
|
||||
agent_id: str = Field(
|
||||
default="",
|
||||
description="ID of the agent"
|
||||
)
|
||||
agent_role: str = Field(
|
||||
default="",
|
||||
description="Role of the agent"
|
||||
)
|
||||
agent_id: str = Field(default="", description="ID of the agent")
|
||||
agent_role: str = Field(default="", description="Role of the agent")
|
||||
task_count: int = Field(
|
||||
default=0,
|
||||
description="Number of tasks included in this aggregation"
|
||||
default=0, description="Number of tasks included in this aggregation"
|
||||
)
|
||||
aggregation_strategy: AggregationStrategy = Field(
|
||||
default=AggregationStrategy.SIMPLE_AVERAGE,
|
||||
description="Strategy used for aggregation"
|
||||
description="Strategy used for aggregation",
|
||||
)
|
||||
metrics: Dict[MetricCategory, EvaluationScore] = Field(
|
||||
default_factory=dict,
|
||||
description="Aggregated metrics across all tasks"
|
||||
metrics: dict[MetricCategory, EvaluationScore] = Field(
|
||||
default_factory=dict, description="Aggregated metrics across all tasks"
|
||||
)
|
||||
task_results: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="IDs of tasks included in this aggregation"
|
||||
task_results: list[str] = Field(
|
||||
default_factory=list, description="IDs of tasks included in this aggregation"
|
||||
)
|
||||
overall_score: Optional[float] = Field(
|
||||
default=None,
|
||||
description="Overall score for this agent"
|
||||
overall_score: float | None = Field(
|
||||
default=None, description="Overall score for this agent"
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
@@ -119,7 +108,7 @@ class AgentAggregatedEvaluationResult(BaseModel):
|
||||
result += f"\n\n- {category.value.upper()}: {score.score}/10\n"
|
||||
|
||||
if score.feedback:
|
||||
detailed_feedback = "\n ".join(score.feedback.split('\n'))
|
||||
detailed_feedback = "\n ".join(score.feedback.split("\n"))
|
||||
result += f" {detailed_feedback}\n"
|
||||
|
||||
return result
|
||||
return result
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
from collections import defaultdict
|
||||
from typing import Dict, Any, List
|
||||
from rich.table import Table
|
||||
from rich.box import HEAVY_EDGE, ROUNDED
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from rich.box import HEAVY_EDGE, ROUNDED
|
||||
from rich.table import Table
|
||||
|
||||
from crewai.events.utils.console_formatter import ConsoleFormatter
|
||||
from crewai.experimental.evaluation.base_evaluator import (
|
||||
AgentAggregatedEvaluationResult,
|
||||
AggregationStrategy,
|
||||
AgentEvaluationResult,
|
||||
AggregationStrategy,
|
||||
EvaluationScore,
|
||||
MetricCategory,
|
||||
)
|
||||
from crewai.experimental.evaluation import EvaluationScore
|
||||
from crewai.events.utils.console_formatter import ConsoleFormatter
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
|
||||
|
||||
@@ -19,7 +21,7 @@ class EvaluationDisplayFormatter:
|
||||
self.console_formatter = ConsoleFormatter()
|
||||
|
||||
def display_evaluation_with_feedback(
|
||||
self, iterations_results: Dict[int, Dict[str, List[Any]]]
|
||||
self, iterations_results: dict[int, dict[str, list[Any]]]
|
||||
):
|
||||
if not iterations_results:
|
||||
self.console_formatter.print(
|
||||
@@ -99,7 +101,7 @@ class EvaluationDisplayFormatter:
|
||||
|
||||
def display_summary_results(
|
||||
self,
|
||||
iterations_results: Dict[int, Dict[str, List[AgentAggregatedEvaluationResult]]],
|
||||
iterations_results: dict[int, dict[str, list[AgentEvaluationResult]]],
|
||||
):
|
||||
if not iterations_results:
|
||||
self.console_formatter.print(
|
||||
@@ -280,7 +282,7 @@ class EvaluationDisplayFormatter:
|
||||
feedback_summary = feedbacks[0]
|
||||
|
||||
aggregated_metrics[category] = EvaluationScore(
|
||||
score=avg_score, feedback=feedback_summary
|
||||
score=avg_score, feedback=feedback_summary or ""
|
||||
)
|
||||
|
||||
overall_score = None
|
||||
@@ -304,25 +306,25 @@ class EvaluationDisplayFormatter:
|
||||
self,
|
||||
agent_role: str,
|
||||
metric: str,
|
||||
feedbacks: List[str],
|
||||
scores: List[float | None],
|
||||
feedbacks: list[str],
|
||||
scores: list[float | None],
|
||||
strategy: AggregationStrategy,
|
||||
) -> str:
|
||||
if len(feedbacks) <= 2 and all(len(fb) < 200 for fb in feedbacks):
|
||||
return "\n\n".join(
|
||||
[f"Feedback {i+1}: {fb}" for i, fb in enumerate(feedbacks)]
|
||||
[f"Feedback {i + 1}: {fb}" for i, fb in enumerate(feedbacks)]
|
||||
)
|
||||
|
||||
try:
|
||||
llm = create_llm()
|
||||
|
||||
formatted_feedbacks = []
|
||||
for i, (feedback, score) in enumerate(zip(feedbacks, scores)):
|
||||
for i, (feedback, score) in enumerate(zip(feedbacks, scores, strict=False)):
|
||||
if len(feedback) > 500:
|
||||
feedback = feedback[:500] + "..."
|
||||
score_text = f"{score:.1f}" if score is not None else "N/A"
|
||||
formatted_feedbacks.append(
|
||||
f"Feedback #{i+1} (Score: {score_text}):\n{feedback}"
|
||||
f"Feedback #{i + 1} (Score: {score_text}):\n{feedback}"
|
||||
)
|
||||
|
||||
all_feedbacks = "\n\n" + "\n\n---\n\n".join(formatted_feedbacks)
|
||||
@@ -365,10 +367,9 @@ class EvaluationDisplayFormatter:
|
||||
""",
|
||||
},
|
||||
]
|
||||
assert llm is not None
|
||||
response = llm.call(prompt)
|
||||
|
||||
return response
|
||||
if llm is None:
|
||||
raise ValueError("LLM must be initialized")
|
||||
return llm.call(prompt)
|
||||
|
||||
except Exception:
|
||||
return "Synthesized from multiple tasks: " + "\n\n".join(
|
||||
|
||||
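Note on the hunk above: `assert llm is not None` is replaced with an explicit check that raises `ValueError`. The sketch below shows one likely reason for that change: Python strips `assert` statements when run with `-O`, while an explicit `raise` always executes. `FakeLLM` and `call_llm` are hypothetical names used only for illustration; they are not part of crewAI.

```python
from __future__ import annotations


class FakeLLM:
    """Hypothetical stand-in for an LLM client; not a crewAI class."""

    def call(self, prompt: list[dict[str, str]]) -> str:
        return f"summary of {len(prompt)} message(s)"


def call_llm(llm: FakeLLM | None, prompt: list[dict[str, str]]) -> str:
    # An explicit check still runs under `python -O`,
    # whereas `assert llm is not None` would be removed.
    if llm is None:
        raise ValueError("LLM must be initialized")
    return llm.call(prompt)
```

In the diffed method the call sits inside a `try`/`except Exception` block, so this `ValueError` would lead to the fallback string rather than propagate to the caller.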
@@ -1,26 +1,25 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.task import Task
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.events.base_event_listener import BaseEventListener
|
||||
from crewai.events.event_bus import CrewAIEventsBus
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionStartedEvent,
|
||||
AgentExecutionCompletedEvent,
|
||||
LiteAgentExecutionStartedEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
LiteAgentExecutionCompletedEvent,
|
||||
LiteAgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.types.llm_events import LLMCallCompletedEvent, LLMCallStartedEvent
|
||||
from crewai.events.types.tool_usage_events import (
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageErrorEvent,
|
||||
ToolExecutionErrorEvent,
|
||||
ToolSelectionErrorEvent,
|
||||
ToolUsageErrorEvent,
|
||||
ToolUsageFinishedEvent,
|
||||
ToolValidateInputErrorEvent,
|
||||
)
|
||||
from crewai.events.types.llm_events import LLMCallStartedEvent, LLMCallCompletedEvent
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
class EvaluationTraceCallback(BaseEventListener):
|
||||
@@ -136,7 +135,7 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
def _init_trace(self, trace_key: str, **kwargs: Any):
|
||||
self.traces[trace_key] = kwargs
|
||||
|
||||
def on_agent_start(self, agent: Agent, task: Task):
|
||||
def on_agent_start(self, agent: BaseAgent, task: Task):
|
||||
self.current_agent_id = agent.id
|
||||
self.current_task_id = task.id
|
||||
|
||||
@@ -151,7 +150,7 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
final_output=None,
|
||||
)
|
||||
|
||||
def on_agent_finish(self, agent: Agent, task: Task, output: Any):
|
||||
def on_agent_finish(self, agent: BaseAgent, task: Task, output: Any):
|
||||
trace_key = f"{agent.id}_{task.id}"
|
||||
if trace_key in self.traces:
|
||||
self.traces[trace_key]["final_output"] = output
|
||||
@@ -253,7 +252,7 @@ class EvaluationTraceCallback(BaseEventListener):
|
||||
if hasattr(self, "current_llm_call"):
|
||||
self.current_llm_call = {}
|
||||
|
||||
def get_trace(self, agent_id: str, task_id: str) -> Optional[Dict[str, Any]]:
|
||||
def get_trace(self, agent_id: str, task_id: str) -> dict[str, Any] | None:
|
||||
trace_key = f"{agent_id}_{task_id}"
|
||||
return self.traces.get(trace_key)
|
||||
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
from crewai.experimental.evaluation.experiment.result import (
|
||||
ExperimentResult,
|
||||
ExperimentResults,
|
||||
)
|
||||
from crewai.experimental.evaluation.experiment.runner import ExperimentRunner
|
||||
from crewai.experimental.evaluation.experiment.result import ExperimentResults, ExperimentResult
|
||||
|
||||
__all__ = [
|
||||
"ExperimentRunner",
|
||||
"ExperimentResults",
|
||||
"ExperimentResult"
|
||||
]
|
||||
__all__ = ["ExperimentResult", "ExperimentResults", "ExperimentRunner"]
|
||||
|
||||
@@ -2,45 +2,60 @@ import json
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ExperimentResult(BaseModel):
|
||||
identifier: str
|
||||
inputs: dict[str, Any]
|
||||
score: int | dict[str, int | float]
|
||||
expected_score: int | dict[str, int | float]
|
||||
score: float | dict[str, float]
|
||||
expected_score: float | dict[str, float]
|
||||
passed: bool
|
||||
agent_evaluations: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ExperimentResults:
|
||||
def __init__(self, results: list[ExperimentResult], metadata: dict[str, Any] | None = None):
|
||||
def __init__(
|
||||
self, results: list[ExperimentResult], metadata: dict[str, Any] | None = None
|
||||
):
|
||||
self.results = results
|
||||
self.metadata = metadata or {}
|
||||
self.timestamp = datetime.now(timezone.utc)
|
||||
|
||||
from crewai.experimental.evaluation.experiment.result_display import ExperimentResultsDisplay
|
||||
from crewai.experimental.evaluation.experiment.result_display import (
|
||||
ExperimentResultsDisplay,
|
||||
)
|
||||
|
||||
self.display = ExperimentResultsDisplay()
|
||||
|
||||
def to_json(self, filepath: str | None = None) -> dict[str, Any]:
|
||||
data = {
|
||||
"timestamp": self.timestamp.isoformat(),
|
||||
"metadata": self.metadata,
|
||||
"results": [r.model_dump(exclude={"agent_evaluations"}) for r in self.results]
|
||||
"results": [
|
||||
r.model_dump(exclude={"agent_evaluations"}) for r in self.results
|
||||
],
|
||||
}
|
||||
|
||||
if filepath:
|
||||
with open(filepath, 'w') as f:
|
||||
with open(filepath, "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
self.display.console.print(f"[green]Results saved to {filepath}[/green]")
|
||||
|
||||
return data
|
||||
|
||||
def compare_with_baseline(self, baseline_filepath: str, save_current: bool = True, print_summary: bool = False) -> dict[str, Any]:
|
||||
def compare_with_baseline(
|
||||
self,
|
||||
baseline_filepath: str,
|
||||
save_current: bool = True,
|
||||
print_summary: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
baseline_runs = []
|
||||
|
||||
if os.path.exists(baseline_filepath) and os.path.getsize(baseline_filepath) > 0:
|
||||
try:
|
||||
with open(baseline_filepath, 'r') as f:
|
||||
with open(baseline_filepath, "r") as f:
|
||||
baseline_data = json.load(f)
|
||||
|
||||
if isinstance(baseline_data, dict) and "timestamp" in baseline_data:
|
||||
@@ -48,14 +63,18 @@ class ExperimentResults:
|
||||
elif isinstance(baseline_data, list):
|
||||
baseline_runs = baseline_data
|
||||
except (json.JSONDecodeError, FileNotFoundError) as e:
|
||||
self.display.console.print(f"[yellow]Warning: Could not load baseline file: {str(e)}[/yellow]")
|
||||
self.display.console.print(
|
||||
f"[yellow]Warning: Could not load baseline file: {e!s}[/yellow]"
|
||||
)
|
||||
|
||||
if not baseline_runs:
|
||||
if save_current:
|
||||
current_data = self.to_json()
|
||||
with open(baseline_filepath, 'w') as f:
|
||||
with open(baseline_filepath, "w") as f:
|
||||
json.dump([current_data], f, indent=2)
|
||||
self.display.console.print(f"[green]Saved current results as new baseline to {baseline_filepath}[/green]")
|
||||
self.display.console.print(
|
||||
f"[green]Saved current results as new baseline to {baseline_filepath}[/green]"
|
||||
)
|
||||
return {"is_baseline": True, "changes": {}}
|
||||
|
||||
baseline_runs.sort(key=lambda x: x.get("timestamp", ""), reverse=True)
|
||||
@@ -69,9 +88,11 @@ class ExperimentResults:
|
||||
if save_current:
|
||||
current_data = self.to_json()
|
||||
baseline_runs.append(current_data)
|
||||
with open(baseline_filepath, 'w') as f:
|
||||
with open(baseline_filepath, "w") as f:
|
||||
json.dump(baseline_runs, f, indent=2)
|
||||
self.display.console.print(f"[green]Added current results to baseline file {baseline_filepath}[/green]")
|
||||
self.display.console.print(
|
||||
f"[green]Added current results to baseline file {baseline_filepath}[/green]"
|
||||
)
|
||||
|
||||
return comparison
|
||||
|
||||
@@ -118,5 +139,5 @@ class ExperimentResults:
|
||||
"new_tests": new_tests,
|
||||
"missing_tests": missing_tests,
|
||||
"total_compared": len(improved) + len(regressed) + len(unchanged),
|
||||
"baseline_timestamp": baseline_run.get("timestamp", "unknown")
|
||||
"baseline_timestamp": baseline_run.get("timestamp", "unknown"),
|
||||
}
|
||||
|
||||
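Note on `compare_with_baseline`: the hunks above read a baseline file that may hold either a single run (a dict with a `timestamp`) or a list of runs, then sort the runs newest first. Here is a minimal sketch of that loading step, assuming a hypothetical helper named `load_baseline_runs`. Wrapping a single run in a list is an assumption, since that line is not visible in the diff.

```python
import json
import os
import tempfile


def load_baseline_runs(path: str) -> list[dict]:
    """Return baseline runs, newest first; empty list if the file is unusable."""
    runs: list[dict] = []
    if os.path.exists(path) and os.path.getsize(path) > 0:
        try:
            with open(path) as f:
                data = json.load(f)
            if isinstance(data, dict) and "timestamp" in data:
                runs = [data]  # assumption: a single run is wrapped in a list
            elif isinstance(data, list):
                runs = data
        except json.JSONDecodeError:
            runs = []
    # ISO 8601 timestamps in the same timezone sort chronologically as strings.
    runs.sort(key=lambda run: run.get("timestamp", ""), reverse=True)
    return runs
```

Sorting on the string works here because `to_json` writes `self.timestamp.isoformat()` from a UTC datetime. Timestamps with mixed offsets would need parsing first.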
@@ -1,9 +1,12 @@
|
||||
from typing import Dict, Any
|
||||
from typing import Any
|
||||
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
|
||||
from crewai.experimental.evaluation.experiment.result import ExperimentResults
|
||||
|
||||
|
||||
class ExperimentResultsDisplay:
|
||||
def __init__(self):
|
||||
self.console = Console()
|
||||
@@ -19,13 +22,19 @@ class ExperimentResultsDisplay:
|
||||
table.add_row("Total Test Cases", str(total))
|
||||
table.add_row("Passed", str(passed))
|
||||
table.add_row("Failed", str(total - passed))
|
||||
table.add_row("Success Rate", f"{(passed / total * 100):.1f}%" if total > 0 else "N/A")
|
||||
table.add_row(
|
||||
"Success Rate", f"{(passed / total * 100):.1f}%" if total > 0 else "N/A"
|
||||
)
|
||||
|
||||
self.console.print(table)
|
||||
|
||||
def comparison_summary(self, comparison: Dict[str, Any], baseline_timestamp: str):
|
||||
self.console.print(Panel(f"[bold]Comparison with baseline run from {baseline_timestamp}[/bold]",
|
||||
expand=False))
|
||||
def comparison_summary(self, comparison: dict[str, Any], baseline_timestamp: str):
|
||||
self.console.print(
|
||||
Panel(
|
||||
f"[bold]Comparison with baseline run from {baseline_timestamp}[/bold]",
|
||||
expand=False,
|
||||
)
|
||||
)
|
||||
|
||||
table = Table(title="Results Comparison")
|
||||
table.add_column("Metric", style="cyan")
|
||||
@@ -34,7 +43,9 @@ class ExperimentResultsDisplay:
|
||||
|
||||
improved = comparison.get("improved", [])
|
||||
if improved:
|
||||
details = ", ".join([f"{test_identifier}" for test_identifier in improved[:3]])
|
||||
details = ", ".join(
|
||||
[f"{test_identifier}" for test_identifier in improved[:3]]
|
||||
)
|
||||
if len(improved) > 3:
|
||||
details += f" and {len(improved) - 3} more"
|
||||
table.add_row("✅ Improved", str(len(improved)), details)
|
||||
@@ -43,7 +54,9 @@ class ExperimentResultsDisplay:
|
||||
|
||||
regressed = comparison.get("regressed", [])
|
||||
if regressed:
|
||||
details = ", ".join([f"{test_identifier}" for test_identifier in regressed[:3]])
|
||||
details = ", ".join(
|
||||
[f"{test_identifier}" for test_identifier in regressed[:3]]
|
||||
)
|
||||
if len(regressed) > 3:
|
||||
details += f" and {len(regressed) - 3} more"
|
||||
table.add_row("❌ Regressed", str(len(regressed)), details, style="red")
|
||||
@@ -58,13 +71,13 @@ class ExperimentResultsDisplay:
|
||||
details = ", ".join(new_tests[:3])
|
||||
if len(new_tests) > 3:
|
||||
details += f" and {len(new_tests) - 3} more"
|
||||
table.add_row("➕ New Tests", str(len(new_tests)), details)
|
||||
table.add_row("+ New Tests", str(len(new_tests)), details)
|
||||
|
||||
missing_tests = comparison.get("missing_tests", [])
|
||||
if missing_tests:
|
||||
details = ", ".join(missing_tests[:3])
|
||||
if len(missing_tests) > 3:
|
||||
details += f" and {len(missing_tests) - 3} more"
|
||||
table.add_row("➖ Missing Tests", str(len(missing_tests)), details)
|
||||
table.add_row("- Missing Tests", str(len(missing_tests)), details)
|
||||
|
||||
self.console.print(table)
|
||||
|
||||
@@ -2,11 +2,20 @@ from collections import defaultdict
|
||||
from hashlib import md5
|
||||
from typing import Any
|
||||
|
||||
from crewai import Crew, Agent
|
||||
from crewai import Agent, Crew
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.experimental.evaluation import AgentEvaluator, create_default_evaluator
|
||||
from crewai.experimental.evaluation.experiment.result_display import ExperimentResultsDisplay
|
||||
from crewai.experimental.evaluation.experiment.result import ExperimentResults, ExperimentResult
|
||||
from crewai.experimental.evaluation.evaluation_display import AgentAggregatedEvaluationResult
|
||||
from crewai.experimental.evaluation.evaluation_display import (
|
||||
AgentAggregatedEvaluationResult,
|
||||
)
|
||||
from crewai.experimental.evaluation.experiment.result import (
|
||||
ExperimentResult,
|
||||
ExperimentResults,
|
||||
)
|
||||
from crewai.experimental.evaluation.experiment.result_display import (
|
||||
ExperimentResultsDisplay,
|
||||
)
|
||||
|
||||
|
||||
class ExperimentRunner:
|
||||
def __init__(self, dataset: list[dict[str, Any]]):
|
||||
@@ -14,11 +23,17 @@ class ExperimentRunner:
|
||||
self.evaluator: AgentEvaluator | None = None
|
||||
self.display = ExperimentResultsDisplay()
|
||||
|
||||
def run(self, crew: Crew | None = None, agents: list[Agent] | None = None, print_summary: bool = False) -> ExperimentResults:
|
||||
def run(
|
||||
self,
|
||||
crew: Crew | None = None,
|
||||
agents: list[Agent] | list[BaseAgent] | None = None,
|
||||
print_summary: bool = False,
|
||||
) -> ExperimentResults:
|
||||
if crew and not agents:
|
||||
agents = crew.agents
|
||||
|
||||
assert agents is not None
|
||||
if agents is None:
|
||||
raise ValueError("Agents must be provided either directly or via a crew")
|
||||
self.evaluator = create_default_evaluator(agents=agents)
|
||||
|
||||
results = []
|
||||
@@ -35,21 +50,37 @@ class ExperimentRunner:
|
||||
|
||||
return experiment_results
|
||||
|
||||
def _run_test_case(self, test_case: dict[str, Any], agents: list[Agent], crew: Crew | None = None) -> ExperimentResult:
|
||||
def _run_test_case(
|
||||
self,
|
||||
test_case: dict[str, Any],
|
||||
agents: list[Agent] | list[BaseAgent],
|
||||
crew: Crew | None = None,
|
||||
) -> ExperimentResult:
|
||||
inputs = test_case["inputs"]
|
||||
expected_score = test_case["expected_score"]
|
||||
identifier = test_case.get("identifier") or md5(str(test_case).encode(), usedforsecurity=False).hexdigest()
|
||||
identifier = (
|
||||
test_case.get("identifier")
|
||||
or md5(str(test_case).encode(), usedforsecurity=False).hexdigest()
|
||||
)
|
||||
|
||||
try:
|
||||
self.display.console.print(f"[dim]Running crew with input: {str(inputs)[:50]}...[/dim]")
|
||||
self.display.console.print(
|
||||
f"[dim]Running crew with input: {str(inputs)[:50]}...[/dim]"
|
||||
)
|
||||
self.display.console.print("\n")
|
||||
if crew:
|
||||
crew.kickoff(inputs=inputs)
|
||||
else:
|
||||
for agent in agents:
|
||||
agent.kickoff(**inputs)
|
||||
if isinstance(agent, Agent):
|
||||
agent.kickoff(**inputs)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Agent {agent} is not an instance of Agent and cannot be kicked off directly"
|
||||
)
|
||||
|
||||
assert self.evaluator is not None
|
||||
if self.evaluator is None:
|
||||
raise ValueError("Evaluator must be initialized")
|
||||
agent_evaluations = self.evaluator.get_agent_evaluation()
|
||||
|
||||
actual_score = self._extract_scores(agent_evaluations)
|
||||
@@ -61,35 +92,38 @@ class ExperimentRunner:
|
||||
score=actual_score,
|
||||
expected_score=expected_score,
|
||||
passed=passed,
|
||||
agent_evaluations=agent_evaluations
|
||||
agent_evaluations=agent_evaluations,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.display.console.print(f"[red]Error running test case: {str(e)}[/red]")
|
||||
self.display.console.print(f"[red]Error running test case: {e!s}[/red]")
|
||||
return ExperimentResult(
|
||||
identifier=identifier,
|
||||
inputs=inputs,
|
||||
score=0,
|
||||
score=0.0,
|
||||
expected_score=expected_score,
|
||||
passed=False
|
||||
passed=False,
|
||||
)
|
||||
|
||||
def _extract_scores(self, agent_evaluations: dict[str, AgentAggregatedEvaluationResult]) -> float | dict[str, float]:
|
||||
def _extract_scores(
|
||||
self, agent_evaluations: dict[str, AgentAggregatedEvaluationResult]
|
||||
) -> float | dict[str, float]:
|
||||
all_scores: dict[str, list[float]] = defaultdict(list)
|
||||
for evaluation in agent_evaluations.values():
|
||||
for metric_name, score in evaluation.metrics.items():
|
||||
if score.score is not None:
|
||||
all_scores[metric_name.value].append(score.score)
|
||||
|
||||
avg_scores = {m: sum(s)/len(s) for m, s in all_scores.items()}
|
||||
avg_scores = {m: sum(s) / len(s) for m, s in all_scores.items()}
|
||||
|
||||
if len(avg_scores) == 1:
|
||||
return list(avg_scores.values())[0]
|
||||
return next(iter(avg_scores.values()))
|
||||
|
||||
return avg_scores
|
||||
|
||||
def _assert_scores(self, expected: float | dict[str, float],
|
||||
actual: float | dict[str, float]) -> bool:
|
||||
def _assert_scores(
|
||||
self, expected: float | dict[str, float], actual: float | dict[str, float]
|
||||
) -> bool:
|
||||
"""
|
||||
Compare expected and actual scores, and return whether the test case passed.
|
||||
|
||||
@@ -122,4 +156,4 @@ class ExperimentRunner:
|
||||
# All matching keys must have actual >= expected
|
||||
return all(actual[key] >= expected[key] for key in matching_keys)
|
||||
|
||||
return False
|
||||
return False
|
||||
|
||||
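Note on `_extract_scores`: the method averages each metric across agents and collapses the result to a bare float when only one metric is present. The standalone sketch below follows that logic using plain dicts instead of `AgentAggregatedEvaluationResult`. The function name `average_scores` and the input shape are assumptions for illustration.

```python
from __future__ import annotations

from collections import defaultdict


def average_scores(
    per_agent: dict[str, dict[str, float | None]],
) -> float | dict[str, float]:
    """Average each metric over all agents, skipping missing (None) scores."""
    all_scores: dict[str, list[float]] = defaultdict(list)
    for metrics in per_agent.values():
        for metric_name, score in metrics.items():
            if score is not None:
                all_scores[metric_name].append(score)

    avg_scores = {m: sum(s) / len(s) for m, s in all_scores.items()}

    # A single metric is returned as a bare float, as in the diffed method.
    if len(avg_scores) == 1:
        return next(iter(avg_scores.values()))
    return avg_scores
```

`next(iter(avg_scores.values()))` avoids building a temporary list just to read its first element, which is what the replaced `list(avg_scores.values())[0]` did.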
@@ -13,11 +13,11 @@ def extract_json_from_llm_response(text: str) -> dict[str, Any]:
|
||||
|
||||
json_patterns = [
|
||||
# Standard markdown code blocks with json
|
||||
r'```json\s*([\s\S]*?)\s*```',
|
||||
r"```json\s*([\s\S]*?)\s*```",
|
||||
# Code blocks without language specifier
|
||||
r'```\s*([\s\S]*?)\s*```',
|
||||
r"```\s*([\s\S]*?)\s*```",
|
||||
# Inline code with JSON
|
||||
r'`([{\\[].*[}\]])`',
|
||||
r"`([{\\[].*[}\]])`",
|
||||
]
|
||||
|
||||
for pattern in json_patterns:
|
||||
@@ -25,6 +25,6 @@ def extract_json_from_llm_response(text: str) -> dict[str, Any]:
|
||||
for match in matches:
|
||||
try:
|
||||
return json.loads(match.strip())
|
||||
except json.JSONDecodeError:
|
||||
except json.JSONDecodeError: # noqa: PERF203
|
||||
continue
|
||||
raise ValueError("No valid JSON found in the response")
|
||||
|
||||
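Note on `extract_json_from_llm_response`: the function tries a list of regex patterns in order and returns the first match that parses as JSON. The sketch below shows that fallback loop under stated assumptions. `re.findall` stands in for the matching call, which is not visible in the hunk. The fence marker is built at runtime so the example stays self-contained. The inline-code pattern is simplified from the original.

```python
import json
import re
from typing import Any

FENCE = "`" * 3  # three backticks, built at runtime

JSON_PATTERNS = [
    FENCE + r"json\s*([\s\S]*?)\s*" + FENCE,  # fenced block tagged json
    FENCE + r"\s*([\s\S]*?)\s*" + FENCE,      # fenced block without a language
    r"`([{\[].*[}\]])`",                      # inline code holding JSON
]


def extract_json(text: str) -> dict[str, Any]:
    """Return the first pattern match that parses as JSON."""
    for pattern in JSON_PATTERNS:
        for match in re.findall(pattern, text):
            try:
                return json.loads(match.strip())
            except json.JSONDecodeError:
                continue  # try the next candidate
    raise ValueError("No valid JSON found in the response")
```

Ordering the patterns from most to least specific means a fenced block tagged as JSON wins over a looser inline match in the same response.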
@@ -1,26 +1,21 @@
|
||||
from crewai.experimental.evaluation.metrics.goal_metrics import GoalAlignmentEvaluator
|
||||
from crewai.experimental.evaluation.metrics.reasoning_metrics import (
|
||||
ReasoningEfficiencyEvaluator
|
||||
ReasoningEfficiencyEvaluator,
|
||||
)
|
||||
|
||||
from crewai.experimental.evaluation.metrics.tools_metrics import (
|
||||
ToolSelectionEvaluator,
|
||||
ParameterExtractionEvaluator,
|
||||
ToolInvocationEvaluator
|
||||
)
|
||||
|
||||
from crewai.experimental.evaluation.metrics.goal_metrics import (
|
||||
GoalAlignmentEvaluator
|
||||
)
|
||||
|
||||
from crewai.experimental.evaluation.metrics.semantic_quality_metrics import (
|
||||
SemanticQualityEvaluator
|
||||
SemanticQualityEvaluator,
|
||||
)
|
||||
from crewai.experimental.evaluation.metrics.tools_metrics import (
|
||||
ParameterExtractionEvaluator,
|
||||
ToolInvocationEvaluator,
|
||||
ToolSelectionEvaluator,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ReasoningEfficiencyEvaluator",
|
||||
"ToolSelectionEvaluator",
|
||||
"ParameterExtractionEvaluator",
|
||||
"ToolInvocationEvaluator",
|
||||
"GoalAlignmentEvaluator",
|
||||
"SemanticQualityEvaluator"
|
||||
]
|
||||
"ParameterExtractionEvaluator",
|
||||
"ReasoningEfficiencyEvaluator",
|
||||
"SemanticQualityEvaluator",
|
||||
"ToolInvocationEvaluator",
|
||||
"ToolSelectionEvaluator",
|
||||
]
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.experimental.evaluation.base_evaluator import (
|
||||
BaseEvaluator,
|
||||
EvaluationScore,
|
||||
MetricCategory,
|
||||
)
|
||||
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
|
||||
from crewai.task import Task
|
||||
|
||||
from crewai.experimental.evaluation.base_evaluator import BaseEvaluator, EvaluationScore, MetricCategory
|
||||
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
|
||||
|
||||
class GoalAlignmentEvaluator(BaseEvaluator):
|
||||
@property
|
||||
@@ -13,8 +18,8 @@ class GoalAlignmentEvaluator(BaseEvaluator):
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
agent: Agent,
|
||||
execution_trace: Dict[str, Any],
|
||||
agent: Agent | BaseAgent,
|
||||
execution_trace: dict[str, Any],
|
||||
final_output: Any,
|
||||
task: Task | None = None,
|
||||
) -> EvaluationScore:
|
||||
@@ -23,7 +28,9 @@ class GoalAlignmentEvaluator(BaseEvaluator):
|
||||
task_context = f"Task description: {task.description}\nExpected output: {task.expected_output}\n"
|
||||
|
||||
prompt = [
|
||||
{"role": "system", "content": """You are an expert evaluator assessing how well an AI agent's output aligns with its assigned task goal.
|
||||
{
|
||||
"role": "system",
|
||||
"content": """You are an expert evaluator assessing how well an AI agent's output aligns with its assigned task goal.
|
||||
|
||||
Score the agent's goal alignment on a scale from 0-10 where:
|
||||
- 0: Complete misalignment, agent did not understand or attempt the task goal
|
||||
@@ -37,8 +44,11 @@ Consider:
|
||||
4. Did the agent provide all requested information or deliverables?
|
||||
|
||||
Return your evaluation as JSON with fields 'score' (number) and 'feedback' (string).
|
||||
"""},
|
||||
{"role": "user", "content": f"""
|
||||
""",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""
|
||||
Agent role: {agent.role}
|
||||
Agent goal: {agent.goal}
|
||||
{task_context}
|
||||
@@ -47,23 +57,26 @@ Agent's final output:
|
||||
{final_output}
|
||||
|
||||
Evaluate how well the agent's output aligns with the assigned task goal.
|
||||
"""}
|
||||
""",
|
||||
},
|
||||
]
|
||||
assert self.llm is not None
|
||||
if self.llm is None:
|
||||
raise ValueError("LLM must be initialized")
|
||||
response = self.llm.call(prompt)
|
||||
|
||||
try:
|
||||
evaluation_data: dict[str, Any] = extract_json_from_llm_response(response)
|
||||
assert evaluation_data is not None
|
||||
if evaluation_data is None:
|
||||
raise ValueError("Failed to extract evaluation data from LLM response")
|
||||
|
||||
return EvaluationScore(
|
||||
score=evaluation_data.get("score", 0),
|
||||
feedback=evaluation_data.get("feedback", response),
|
||||
raw_response=response
|
||||
raw_response=response,
|
||||
)
|
||||
except Exception:
|
||||
return EvaluationScore(
|
||||
score=None,
|
||||
feedback=f"Failed to parse evaluation. Raw response: {response}",
|
||||
raw_response=response
|
||||
raw_response=response,
|
||||
)
|
||||
|
||||
@@ -8,18 +8,24 @@ This module provides evaluator implementations for:
|
||||
|
||||
import logging
|
||||
import re
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Tuple
|
||||
import numpy as np
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.task import Task
|
||||
|
||||
from crewai.experimental.evaluation.base_evaluator import BaseEvaluator, EvaluationScore, MetricCategory
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.experimental.evaluation.base_evaluator import (
|
||||
BaseEvaluator,
|
||||
EvaluationScore,
|
||||
MetricCategory,
|
||||
)
|
||||
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
|
||||
from crewai.task import Task
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
|
||||
|
||||
class ReasoningPatternType(Enum):
|
||||
EFFICIENT = "efficient" # Good reasoning flow
|
||||
LOOP = "loop" # Agent is stuck in a loop
|
||||
@@ -35,8 +41,8 @@ class ReasoningEfficiencyEvaluator(BaseEvaluator):
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
agent: Agent,
|
||||
execution_trace: Dict[str, Any],
|
||||
agent: Agent | BaseAgent,
|
||||
execution_trace: dict[str, Any],
|
||||
final_output: TaskOutput | str,
|
||||
task: Task | None = None,
|
||||
) -> EvaluationScore:
|
||||
@@ -49,7 +55,7 @@ class ReasoningEfficiencyEvaluator(BaseEvaluator):
|
||||
if not llm_calls or len(llm_calls) < 2:
|
||||
return EvaluationScore(
|
||||
score=None,
|
||||
feedback="Insufficient LLM calls to evaluate reasoning efficiency."
|
||||
feedback="Insufficient LLM calls to evaluate reasoning efficiency.",
|
||||
)
|
||||
|
||||
total_calls = len(llm_calls)
|
||||
@@ -58,12 +64,16 @@ class ReasoningEfficiencyEvaluator(BaseEvaluator):
|
||||
time_intervals = []
|
||||
has_reliable_timing = True
|
||||
for i in range(1, len(llm_calls)):
|
||||
start_time = llm_calls[i-1].get("end_time")
|
||||
start_time = llm_calls[i - 1].get("end_time")
|
||||
end_time = llm_calls[i].get("start_time")
|
||||
if start_time and end_time and start_time != end_time:
|
||||
try:
|
||||
interval = end_time - start_time
|
||||
time_intervals.append(interval.total_seconds() if hasattr(interval, 'total_seconds') else 0)
|
||||
time_intervals.append(
|
||||
interval.total_seconds()
|
||||
if hasattr(interval, "total_seconds")
|
||||
else 0
|
||||
)
|
||||
except Exception:
|
||||
has_reliable_timing = False
|
||||
else:
|
||||
@@ -83,14 +93,22 @@ class ReasoningEfficiencyEvaluator(BaseEvaluator):
|
||||
if has_reliable_timing and time_intervals:
|
||||
efficiency_metrics["avg_time_between_calls"] = np.mean(time_intervals)
|
||||
|
||||
loop_info = f"Detected {len(loop_details)} potential reasoning loops." if loop_detected else "No significant reasoning loops detected."
|
||||
loop_info = (
|
||||
f"Detected {len(loop_details)} potential reasoning loops."
|
||||
if loop_detected
|
||||
else "No significant reasoning loops detected."
|
||||
)
|
||||
|
||||
call_samples = self._get_call_samples(llm_calls)
|
||||
|
||||
final_output = final_output.raw if isinstance(final_output, TaskOutput) else final_output
|
||||
final_output = (
|
||||
final_output.raw if isinstance(final_output, TaskOutput) else final_output
|
||||
)
|
||||
|
||||
prompt = [
|
||||
{"role": "system", "content": """You are an expert evaluator assessing the reasoning efficiency of an AI agent's thought process.
|
||||
{
|
||||
"role": "system",
|
||||
"content": """You are an expert evaluator assessing the reasoning efficiency of an AI agent's thought process.
|
||||
|
||||
Evaluate the agent's reasoning efficiency across these five key subcategories:
|
||||
|
||||
@@ -120,8 +138,11 @@ Return your evaluation as JSON with the following structure:
|
||||
"feedback": string (general feedback about overall reasoning efficiency),
|
||||
"optimization_suggestions": string (concrete suggestions for improving reasoning efficiency),
|
||||
"detected_patterns": string (describe any inefficient reasoning patterns you observe)
|
||||
}"""},
|
||||
{"role": "user", "content": f"""
|
||||
}""",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""
|
||||
Agent role: {agent.role}
|
||||
{task_context}
|
||||
|
||||
@@ -140,10 +161,12 @@ Agent's final output:
|
||||
|
||||
Evaluate the reasoning efficiency of this agent based on these interaction patterns.
|
||||
Identify any inefficient reasoning patterns and provide specific suggestions for optimization.
|
||||
"""}
|
||||
""",
|
||||
},
|
||||
]
|
||||
|
||||
assert self.llm is not None
|
||||
if self.llm is None:
|
||||
raise ValueError("LLM must be initialized")
|
||||
response = self.llm.call(prompt)
|
||||
|
||||
try:
|
||||
@@ -156,34 +179,46 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
||||
conciseness = scores.get("conciseness", 5.0)
|
||||
loop_avoidance = scores.get("loop_avoidance", 5.0)
|
||||
|
||||
overall_score = evaluation_data.get("overall_score", evaluation_data.get("score", 5.0))
|
||||
overall_score = evaluation_data.get(
|
||||
"overall_score", evaluation_data.get("score", 5.0)
|
||||
)
|
||||
feedback = evaluation_data.get("feedback", "No detailed feedback provided.")
|
||||
optimization_suggestions = evaluation_data.get("optimization_suggestions", "No specific suggestions provided.")
|
||||
optimization_suggestions = evaluation_data.get(
|
||||
"optimization_suggestions", "No specific suggestions provided."
|
||||
)
|
||||
|
||||
detailed_feedback = "Reasoning Efficiency Evaluation:\n"
|
||||
detailed_feedback += f"• Focus: {focus}/10 - Staying on topic without tangents\n"
|
||||
detailed_feedback += f"• Progression: {progression}/10 - Building on previous thinking\n"
|
||||
detailed_feedback += (
|
||||
f"• Focus: {focus}/10 - Staying on topic without tangents\n"
|
||||
)
|
||||
detailed_feedback += (
|
||||
f"• Progression: {progression}/10 - Building on previous thinking\n"
|
||||
)
|
||||
detailed_feedback += f"• Decision Quality: {decision_quality}/10 - Making appropriate decisions\n"
|
||||
detailed_feedback += f"• Conciseness: {conciseness}/10 - Communicating efficiently\n"
|
||||
detailed_feedback += (
|
||||
f"• Conciseness: {conciseness}/10 - Communicating efficiently\n"
|
||||
)
|
||||
detailed_feedback += f"• Loop Avoidance: {loop_avoidance}/10 - Avoiding repetitive patterns\n\n"
|
||||
|
||||
detailed_feedback += f"Feedback:\n{feedback}\n\n"
|
||||
detailed_feedback += f"Optimization Suggestions:\n{optimization_suggestions}"
|
||||
detailed_feedback += (
|
||||
f"Optimization Suggestions:\n{optimization_suggestions}"
|
||||
)
|
||||
|
||||
return EvaluationScore(
|
||||
score=float(overall_score),
|
||||
feedback=detailed_feedback,
|
||||
raw_response=response
|
||||
raw_response=response,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to parse reasoning efficiency evaluation: {e}")
|
||||
return EvaluationScore(
|
||||
score=None,
|
||||
feedback=f"Failed to parse reasoning efficiency evaluation. Raw response: {response[:200]}...",
|
||||
raw_response=response
|
||||
raw_response=response,
|
||||
)
|
||||
|
||||
def _detect_loops(self, llm_calls: List[Dict]) -> Tuple[bool, List[Dict]]:
|
||||
def _detect_loops(self, llm_calls: list[dict]) -> tuple[bool, list[dict]]:
|
||||
loop_details = []
|
||||
|
||||
messages = []
|
||||
@@ -193,9 +228,11 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
||||
messages.append(content)
|
||||
elif isinstance(content, list) and len(content) > 0:
|
||||
# Handle message list format
|
||||
for msg in content:
|
||||
if isinstance(msg, dict) and "content" in msg:
|
||||
messages.append(msg["content"])
|
||||
messages.extend(
|
||||
msg["content"]
|
||||
for msg in content
|
||||
if isinstance(msg, dict) and "content" in msg
|
||||
)
|
||||
|
||||
# Simple n-gram based similarity detection
|
||||
# For a more robust implementation, consider using embedding-based similarity
|
||||
@@ -205,18 +242,20 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
||||
# A more sophisticated approach would use semantic similarity
|
||||
similarity = self._calculate_text_similarity(messages[i], messages[j])
|
||||
if similarity > 0.7: # Arbitrary threshold
|
||||
loop_details.append({
|
||||
"first_occurrence": i,
|
||||
"second_occurrence": j,
|
||||
"similarity": similarity,
|
||||
"snippet": messages[i][:100] + "..."
|
||||
})
|
||||
loop_details.append(
|
||||
{
|
||||
"first_occurrence": i,
|
||||
"second_occurrence": j,
|
||||
"similarity": similarity,
|
||||
"snippet": messages[i][:100] + "...",
|
||||
}
|
||||
)
|
||||
|
||||
return len(loop_details) > 0, loop_details
|
||||
|
||||
def _calculate_text_similarity(self, text1: str, text2: str) -> float:
|
||||
text1 = re.sub(r'\s+', ' ', text1.lower()).strip()
|
||||
text2 = re.sub(r'\s+', ' ', text2.lower()).strip()
|
||||
text1 = re.sub(r"\s+", " ", text1.lower()).strip()
|
||||
text2 = re.sub(r"\s+", " ", text2.lower()).strip()
|
||||
|
||||
# Simple Jaccard similarity on word sets
|
||||
words1 = set(text1.split())
|
||||
@@ -227,7 +266,7 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
||||
|
||||
return intersection / union if union > 0 else 0.0
|
||||
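Review note: the loop detection in this hunk scores message pairs with Jaccard similarity over word sets and flags pairs above 0.7. Below is a minimal standalone sketch of that idea. The function names are hypothetical, and both the intersection/union lines and the pairwise iteration order are assumptions, since the diff elides them.

```python
import re


def jaccard_word_similarity(text1: str, text2: str) -> float:
    """Jaccard similarity over lowercase word sets (hypothetical helper name)."""
    # Normalize case and collapse whitespace, as the diff does with re.sub.
    text1 = re.sub(r"\s+", " ", text1.lower()).strip()
    text2 = re.sub(r"\s+", " ", text2.lower()).strip()

    words1 = set(text1.split())
    words2 = set(text2.split())

    # Assumption: the lines elided from the hunk compute these two counts.
    intersection = len(words1 & words2)
    union = len(words1 | words2)

    return intersection / union if union > 0 else 0.0


def find_similar_pairs(messages: list[str], threshold: float = 0.7) -> list[dict]:
    """Compare every pair of messages once; the iteration order is an assumption."""
    pairs = []
    for i in range(len(messages)):
        for j in range(i + 1, len(messages)):
            similarity = jaccard_word_similarity(messages[i], messages[j])
            if similarity > threshold:
                pairs.append(
                    {
                        "first_occurrence": i,
                        "second_occurrence": j,
                        "similarity": similarity,
                    }
                )
    return pairs
```

Word-set Jaccard ignores word order and repetition, which is why the source comments suggest embedding-based similarity for a more robust check.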
|
||||
def _analyze_reasoning_patterns(self, llm_calls: List[Dict]) -> Dict[str, Any]:
|
||||
def _analyze_reasoning_patterns(self, llm_calls: list[dict]) -> dict[str, Any]:
|
||||
call_lengths = []
|
||||
response_times = []
|
||||
|
||||
@@ -248,8 +287,8 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
||||
if start_time and end_time:
|
||||
try:
|
||||
response_times.append(end_time - start_time)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logging.debug(f"Failed to calculate response time: {e}")
|
||||
|
||||
avg_length = np.mean(call_lengths) if call_lengths else 0
|
||||
std_length = np.std(call_lengths) if call_lengths else 0
|
||||
@@ -267,7 +306,9 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
||||
details = "Agent is consistently verbose across interactions."
|
||||
elif len(llm_calls) > 10 and length_trend > 0.5:
|
||||
primary_pattern = ReasoningPatternType.INDECISIVE
|
||||
details = "Agent shows signs of indecisiveness with increasing message lengths."
|
||||
details = (
|
||||
"Agent shows signs of indecisiveness with increasing message lengths."
|
||||
)
|
||||
elif std_length / avg_length > 0.8:
|
||||
primary_pattern = ReasoningPatternType.SCATTERED
|
||||
details = "Agent shows inconsistent reasoning flow with highly variable responses."
|
||||
@@ -279,8 +320,8 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
||||
"avg_length": avg_length,
|
||||
"std_length": std_length,
|
||||
"length_trend": length_trend,
|
||||
"loop_score": loop_score
|
||||
}
|
||||
"loop_score": loop_score,
|
||||
},
|
||||
}
|
||||
|
||||
def _calculate_trend(self, values: Sequence[float | int]) -> float:
|
||||
@@ -303,7 +344,9 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
||||
except Exception:
|
||||
return 0.0
|
||||
|
||||
def _calculate_loop_likelihood(self, call_lengths: Sequence[float], response_times: Sequence[float]) -> float:
|
||||
def _calculate_loop_likelihood(
|
||||
self, call_lengths: Sequence[float], response_times: Sequence[float]
|
||||
) -> float:
|
||||
if not call_lengths or len(call_lengths) < 3:
|
||||
return 0.0
|
||||
|
||||
@@ -312,7 +355,11 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
||||
if len(call_lengths) >= 4:
|
||||
repeated_lengths = 0
|
||||
for i in range(len(call_lengths) - 2):
|
||||
ratio = call_lengths[i] / call_lengths[i + 2] if call_lengths[i + 2] > 0 else 0
|
||||
ratio = (
|
||||
call_lengths[i] / call_lengths[i + 2]
|
||||
if call_lengths[i + 2] > 0
|
||||
else 0
|
||||
)
|
||||
if 0.85 <= ratio <= 1.15:
|
||||
repeated_lengths += 1
|
||||
|
||||
@@ -324,21 +371,27 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
||||
std_time = np.std(response_times)
|
||||
mean_time = np.mean(response_times)
|
||||
if mean_time > 0:
|
||||
time_consistency = 1.0 - (std_time / mean_time)
|
||||
indicators.append(max(0, time_consistency - 0.3) * 1.5)
|
||||
except Exception:
|
||||
pass
|
||||
time_consistency = 1.0 - (float(std_time) / float(mean_time))
|
||||
indicators.append(max(0.0, float(time_consistency - 0.3)) * 1.5)
|
||||
except Exception as e:
|
||||
logging.debug(f"Time consistency calculation failed: {e}")
|
||||
|
||||
return np.mean(indicators) if indicators else 0.0
|
||||
return float(np.mean(indicators)) if indicators else 0.0
|
||||
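Review note: the time-based loop indicator above is one minus the coefficient of variation of the response times, shifted by 0.3 and scaled by 1.5. Here is a rough sketch using only the standard library. The helper name is hypothetical, and `statistics.pstdev` stands in for `np.std` (both default to the population standard deviation). Unlike the source, which appends nothing when the mean is not positive, this sketch returns 0.0 in that case.

```python
from collections.abc import Sequence
from statistics import fmean, pstdev


def time_consistency_indicator(response_times: Sequence[float]) -> float:
    """Return a loop indicator from response times (hypothetical helper)."""
    if not response_times:
        return 0.0

    mean_time = fmean(response_times)
    if mean_time <= 0:
        # The source skips the indicator here; this sketch returns 0.0 instead.
        return 0.0

    std_time = pstdev(response_times)  # population std, same default as np.std
    time_consistency = 1.0 - (std_time / mean_time)

    # Values at or below 0.3 contribute nothing; perfectly even timing gives ~1.05.
    return max(0.0, time_consistency - 0.3) * 1.5
```

Very regular response times push the indicator up, on the reasoning that an agent stuck in a loop tends to repeat calls of similar cost.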
|
||||
def _get_call_samples(self, llm_calls: List[Dict]) -> str:
|
||||
def _get_call_samples(self, llm_calls: list[dict]) -> str:
|
||||
samples = []
|
||||
|
||||
if len(llm_calls) <= 6:
|
||||
sample_indices = list(range(len(llm_calls)))
|
||||
else:
|
||||
sample_indices = [0, 1, len(llm_calls) // 2 - 1, len(llm_calls) // 2,
|
||||
len(llm_calls) - 2, len(llm_calls) - 1]
|
||||
sample_indices = [
|
||||
0,
|
||||
1,
|
||||
len(llm_calls) // 2 - 1,
|
||||
len(llm_calls) // 2,
|
||||
len(llm_calls) - 2,
|
||||
len(llm_calls) - 1,
|
||||
]
|
||||
|
||||
for idx in sample_indices:
|
||||
call = llm_calls[idx]
|
||||
@@ -347,10 +400,11 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
||||
if isinstance(content, str):
|
||||
sample = content
|
||||
elif isinstance(content, list) and len(content) > 0:
|
||||
sample_parts = []
|
||||
for msg in content:
|
||||
if isinstance(msg, dict) and "content" in msg:
|
||||
sample_parts.append(msg["content"])
|
||||
sample_parts = [
|
||||
msg["content"]
|
||||
for msg in content
|
||||
if isinstance(msg, dict) and "content" in msg
|
||||
]
|
||||
sample = "\n".join(sample_parts)
|
||||
else:
|
||||
sample = str(content)
|
||||
|
||||
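Review note: `_get_call_samples` takes every call when there are six or fewer, and otherwise the first two, the middle two, and the last two. The index selection can be sketched on its own as follows (the function name is hypothetical):

```python
def pick_sample_indices(n_calls: int) -> list[int]:
    """Pick up to six representative call indices (hypothetical helper)."""
    if n_calls <= 6:
        return list(range(n_calls))
    return [
        0,
        1,
        n_calls // 2 - 1,
        n_calls // 2,
        n_calls - 2,
        n_calls - 1,
    ]
```

For seven or more calls the six indices are always distinct and in increasing order, so no deduplication is needed.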
@@ -1,10 +1,15 @@
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.experimental.evaluation.base_evaluator import (
|
||||
BaseEvaluator,
|
||||
EvaluationScore,
|
||||
MetricCategory,
|
||||
)
|
||||
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
|
||||
from crewai.task import Task
|
||||
|
||||
from crewai.experimental.evaluation.base_evaluator import BaseEvaluator, EvaluationScore, MetricCategory
|
||||
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
|
||||
|
||||
class SemanticQualityEvaluator(BaseEvaluator):
|
||||
@property
|
||||
@@ -13,8 +18,8 @@ class SemanticQualityEvaluator(BaseEvaluator):
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
agent: Agent,
|
||||
execution_trace: Dict[str, Any],
|
||||
agent: Agent | BaseAgent,
|
||||
execution_trace: dict[str, Any],
|
||||
final_output: Any,
|
||||
task: Task | None = None,
|
||||
) -> EvaluationScore:
|
||||
@@ -22,7 +27,9 @@ class SemanticQualityEvaluator(BaseEvaluator):
|
||||
if task is not None:
|
||||
task_context = f"Task description: {task.description}"
|
||||
prompt = [
|
||||
{"role": "system", "content": """You are an expert evaluator assessing the semantic quality of an AI agent's output.
|
||||
{
|
||||
"role": "system",
|
||||
"content": """You are an expert evaluator assessing the semantic quality of an AI agent's output.
|
||||
|
||||
Score the semantic quality on a scale from 0-10 where:
|
||||
- 0: Completely incoherent, confusing, or logically flawed output
|
||||
@@ -37,8 +44,11 @@ Consider:
|
||||
5. Is the output free from contradictions and logical fallacies?
|
||||
|
||||
Return your evaluation as JSON with fields 'score' (number) and 'feedback' (string).
|
||||
"""},
|
||||
{"role": "user", "content": f"""
|
||||
""",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""
|
||||
Agent role: {agent.role}
|
||||
{task_context}
|
||||
|
||||
@@ -46,23 +56,28 @@ Agent's final output:
|
||||
{final_output}
|
||||
|
||||
Evaluate the semantic quality and reasoning of this output.
|
||||
"""}
|
||||
""",
|
||||
},
|
||||
]
|
||||
|
||||
assert self.llm is not None
|
||||
if self.llm is None:
|
||||
raise ValueError("LLM must be initialized")
|
||||
response = self.llm.call(prompt)
|
||||
|
||||
try:
|
||||
evaluation_data: dict[str, Any] = extract_json_from_llm_response(response)
|
||||
assert evaluation_data is not None
|
||||
if evaluation_data is None:
|
||||
raise ValueError("Failed to extract evaluation data from LLM response")
|
||||
return EvaluationScore(
|
||||
score=float(evaluation_data["score"]) if evaluation_data.get("score") is not None else None,
|
||||
score=float(evaluation_data["score"])
|
||||
if evaluation_data.get("score") is not None
|
||||
else None,
|
||||
feedback=evaluation_data.get("feedback", response),
|
||||
raw_response=response
|
||||
raw_response=response,
|
||||
)
|
||||
except Exception:
|
||||
return EvaluationScore(
|
||||
score=None,
|
||||
feedback=f"Failed to parse evaluation. Raw response: {response}",
|
||||
raw_response=response
|
||||
)
|
||||
raw_response=response,
|
||||
)
|
||||
|
||||
@@ -1,22 +1,26 @@
|
||||
import json
|
||||
from typing import Dict, Any
|
||||
from typing import Any
|
||||
|
||||
from crewai.experimental.evaluation.base_evaluator import BaseEvaluator, EvaluationScore, MetricCategory
|
||||
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
|
||||
from crewai.agent import Agent
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.experimental.evaluation.base_evaluator import (
|
||||
BaseEvaluator,
|
||||
EvaluationScore,
|
||||
MetricCategory,
|
||||
)
|
||||
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
class ToolSelectionEvaluator(BaseEvaluator):
|
||||
|
||||
@property
|
||||
def metric_category(self) -> MetricCategory:
|
||||
return MetricCategory.TOOL_SELECTION
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
agent: Agent,
|
||||
execution_trace: Dict[str, Any],
|
||||
agent: Agent | BaseAgent,
|
||||
execution_trace: dict[str, Any],
|
||||
final_output: str,
|
||||
task: Task | None = None,
|
||||
) -> EvaluationScore:
|
||||
@@ -26,19 +30,18 @@ class ToolSelectionEvaluator(BaseEvaluator):
|
||||
|
||||
tool_uses = execution_trace.get("tool_uses", [])
|
||||
tool_count = len(tool_uses)
|
||||
unique_tool_types = set([tool.get("tool", "Unknown tool") for tool in tool_uses])
|
||||
unique_tool_types = set(
|
||||
[tool.get("tool", "Unknown tool") for tool in tool_uses]
|
||||
)
|
||||
|
||||
if tool_count == 0:
|
||||
if not agent.tools:
|
||||
return EvaluationScore(
|
||||
score=None,
|
||||
feedback="Agent had no tools available to use."
|
||||
)
|
||||
else:
|
||||
return EvaluationScore(
|
||||
score=None,
|
||||
feedback="Agent had tools available but didn't use any."
|
||||
score=None, feedback="Agent had no tools available to use."
|
||||
)
|
||||
return EvaluationScore(
|
||||
score=None, feedback="Agent had tools available but didn't use any."
|
||||
)
|
||||
|
||||
available_tools_info = ""
|
||||
if agent.tools:
|
||||
@@ -52,7 +55,9 @@ class ToolSelectionEvaluator(BaseEvaluator):
|
||||
tool_types_summary += f"- {tool_type}\n"
|
||||
|
||||
prompt = [
|
||||
{"role": "system", "content": """You are an expert evaluator assessing if an AI agent selected the most appropriate tools for a given task.
|
||||
{
|
||||
"role": "system",
|
||||
"content": """You are an expert evaluator assessing if an AI agent selected the most appropriate tools for a given task.
|
||||
|
||||
You must evaluate based on these 2 criteria:
|
||||
1. Relevance (0-10): Were the tools chosen directly aligned with the task's goals?
|
||||
@@ -73,8 +78,11 @@ Return your evaluation as JSON with these fields:
|
||||
- overall_score: number (average of all scores, 0-10)
|
||||
- feedback: string (focused ONLY on tool selection decisions from available tools)
|
||||
- improvement_suggestions: string (ONLY suggest better selection from the AVAILABLE tools list, NOT new tools)
|
||||
"""},
|
||||
{"role": "user", "content": f"""
|
||||
""",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""
|
||||
Agent role: {agent.role}
|
||||
{task_context}
|
||||
|
||||
@@ -89,14 +97,17 @@ IMPORTANT:
|
||||
- ONLY evaluate selection from tools listed as available
|
||||
- DO NOT suggest new tools that aren't in the available tools list
|
||||
- DO NOT evaluate tool usage or results
|
||||
"""}
|
||||
""",
|
||||
},
|
||||
]
|
||||
assert self.llm is not None
|
||||
if self.llm is None:
|
||||
raise ValueError("LLM must be initialized")
|
||||
response = self.llm.call(prompt)
|
||||
|
||||
try:
|
||||
evaluation_data = extract_json_from_llm_response(response)
|
||||
assert evaluation_data is not None
|
||||
if evaluation_data is None:
|
||||
raise ValueError("Failed to extract evaluation data from LLM response")
|
||||
|
||||
scores = evaluation_data.get("scores", {})
|
||||
relevance = scores.get("relevance", 5.0)
|
||||
@@ -105,22 +116,24 @@ IMPORTANT:
|
||||
|
||||
feedback = "Tool Selection Evaluation:\n"
|
||||
feedback += f"• Relevance: {relevance}/10 - Selection of appropriate tool types for the task\n"
|
||||
feedback += f"• Coverage: {coverage}/10 - Selection of all necessary tool types\n"
|
||||
feedback += (
|
||||
f"• Coverage: {coverage}/10 - Selection of all necessary tool types\n"
|
||||
)
|
||||
if "improvement_suggestions" in evaluation_data:
|
||||
feedback += f"Improvement Suggestions:\n{evaluation_data['improvement_suggestions']}"
|
||||
else:
|
||||
feedback += evaluation_data.get("feedback", "No detailed feedback available.")
|
||||
feedback += evaluation_data.get(
|
||||
"feedback", "No detailed feedback available."
|
||||
)
|
||||
|
||||
return EvaluationScore(
|
||||
score=overall_score,
|
||||
feedback=feedback,
|
||||
raw_response=response
|
||||
score=overall_score, feedback=feedback, raw_response=response
|
||||
)
|
||||
except Exception as e:
|
||||
return EvaluationScore(
|
||||
score=None,
|
||||
feedback=f"Error evaluating tool selection: {e}",
|
||||
raw_response=response
|
||||
raw_response=response,
|
||||
)
|
||||
|
||||
|
||||
@@ -131,8 +144,8 @@ class ParameterExtractionEvaluator(BaseEvaluator):
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
agent: Agent,
|
||||
execution_trace: Dict[str, Any],
|
||||
agent: Agent | BaseAgent,
|
||||
execution_trace: dict[str, Any],
|
||||
final_output: str,
|
||||
task: Task | None = None,
|
||||
) -> EvaluationScore:
|
||||
@@ -145,19 +158,23 @@ class ParameterExtractionEvaluator(BaseEvaluator):
|
||||
if tool_count == 0:
|
||||
return EvaluationScore(
|
||||
score=None,
|
||||
feedback="No tool usage detected. Cannot evaluate parameter extraction."
|
||||
feedback="No tool usage detected. Cannot evaluate parameter extraction.",
|
||||
)
|
||||
|
||||
validation_errors = []
|
||||
for tool_use in tool_uses:
|
||||
if not tool_use.get("success", True) and tool_use.get("error_type") == "validation_error":
|
||||
validation_errors.append({
|
||||
"tool": tool_use.get("tool", "Unknown tool"),
|
||||
"error": tool_use.get("result"),
|
||||
"args": tool_use.get("args", {})
|
||||
})
|
||||
validation_errors = [
|
||||
{
|
||||
"tool": tool_use.get("tool", "Unknown tool"),
|
||||
"error": tool_use.get("result"),
|
||||
"args": tool_use.get("args", {}),
|
||||
}
|
||||
for tool_use in tool_uses
|
||||
if not tool_use.get("success", True)
|
||||
and tool_use.get("error_type") == "validation_error"
|
||||
]
|
||||
|
||||
validation_error_rate = len(validation_errors) / tool_count if tool_count > 0 else 0
|
||||
validation_error_rate = (
|
||||
len(validation_errors) / tool_count if tool_count > 0 else 0
|
||||
)
|
||||
|
||||
param_samples = []
|
||||
for i, tool_use in enumerate(tool_uses[:5]):
|
||||
@@ -168,7 +185,7 @@ class ParameterExtractionEvaluator(BaseEvaluator):
|
||||
|
||||
is_validation_error = error_type == "validation_error"
|
||||
|
||||
sample = f"Tool use #{i+1} - {tool_name}:\n"
|
||||
sample = f"Tool use #{i + 1} - {tool_name}:\n"
|
||||
sample += f"- Parameters: {json.dumps(tool_args, indent=2)}\n"
|
||||
sample += f"- Success: {'No' if not success else 'Yes'}"
|
||||
|
||||
@@ -187,13 +204,17 @@ class ParameterExtractionEvaluator(BaseEvaluator):
|
||||
tool_name = err.get("tool", "Unknown tool")
|
||||
error_msg = err.get("error", "Unknown error")
|
||||
args = err.get("args", {})
|
||||
validation_errors_info += f"\nValidation Error #{i+1}:\n- Tool: {tool_name}\n- Args: {json.dumps(args, indent=2)}\n- Error: {error_msg}"
|
||||
validation_errors_info += f"\nValidation Error #{i + 1}:\n- Tool: {tool_name}\n- Args: {json.dumps(args, indent=2)}\n- Error: {error_msg}"
|
||||
|
||||
if len(validation_errors) > 3:
|
||||
validation_errors_info += f"\n...and {len(validation_errors) - 3} more validation errors."
|
||||
validation_errors_info += (
|
||||
f"\n...and {len(validation_errors) - 3} more validation errors."
|
||||
)
|
||||
param_samples_text = "\n\n".join(param_samples)
|
||||
prompt = [
|
||||
{"role": "system", "content": """You are an expert evaluator assessing how well an AI agent extracts and formats PARAMETER VALUES for tool calls.
|
||||
{
|
||||
"role": "system",
|
||||
"content": """You are an expert evaluator assessing how well an AI agent extracts and formats PARAMETER VALUES for tool calls.
|
||||
|
||||
Your job is to evaluate ONLY whether the agent used the correct parameter VALUES, not whether the right tools were selected or how the tools were invoked.
|
||||
|
||||
@@ -216,8 +237,11 @@ Return your evaluation as JSON with these fields:
|
||||
- overall_score: number (average of all scores, 0-10)
|
||||
- feedback: string (focused ONLY on parameter value extraction quality)
|
||||
- improvement_suggestions: string (concrete suggestions for better parameter VALUE extraction)
|
||||
"""},
|
||||
{"role": "user", "content": f"""
|
||||
""",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""
|
||||
Agent role: {agent.role}
|
||||
{task_context}
|
||||
|
||||
@@ -226,15 +250,18 @@ Parameter extraction examples:
|
||||
{validation_errors_info}
|
||||
|
||||
Evaluate the quality of the agent's parameter extraction for this task.
|
||||
"""}
|
||||
""",
|
||||
},
|
||||
]
|
||||
|
||||
assert self.llm is not None
|
||||
if self.llm is None:
|
||||
raise ValueError("LLM must be initialized")
|
||||
response = self.llm.call(prompt)
|
||||
|
||||
try:
|
||||
evaluation_data = extract_json_from_llm_response(response)
|
||||
assert evaluation_data is not None
|
||||
if evaluation_data is None:
|
||||
raise ValueError("Failed to extract evaluation data from LLM response")
|
||||
|
||||
scores = evaluation_data.get("scores", {})
|
||||
accuracy = scores.get("accuracy", 5.0)
|
||||
@@ -251,18 +278,18 @@ Evaluate the quality of the agent's parameter extraction for this task.
|
||||
if "improvement_suggestions" in evaluation_data:
|
||||
feedback += f"Improvement Suggestions:\n{evaluation_data['improvement_suggestions']}"
|
||||
else:
|
||||
feedback += evaluation_data.get("feedback", "No detailed feedback available.")
|
||||
feedback += evaluation_data.get(
|
||||
"feedback", "No detailed feedback available."
|
||||
)
|
||||
|
||||
return EvaluationScore(
|
||||
score=overall_score,
|
||||
feedback=feedback,
|
||||
raw_response=response
|
||||
score=overall_score, feedback=feedback, raw_response=response
|
||||
)
|
||||
except Exception as e:
|
||||
return EvaluationScore(
|
||||
score=None,
|
||||
feedback=f"Error evaluating parameter extraction: {e}",
|
||||
raw_response=response
|
||||
raw_response=response,
|
||||
)
|
||||
|
||||
|
||||
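Review note: the parameter extraction hunk replaces an append loop with a list comprehension that keeps only failed tool uses whose `error_type` is `"validation_error"`. Here is a small sketch of that filter and the derived error rate, with hypothetical function names:

```python
from typing import Any


def collect_validation_errors(tool_uses: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Keep failed tool uses that were rejected by argument validation."""
    return [
        {
            "tool": tool_use.get("tool", "Unknown tool"),
            "error": tool_use.get("result"),
            "args": tool_use.get("args", {}),
        }
        for tool_use in tool_uses
        if not tool_use.get("success", True)
        and tool_use.get("error_type") == "validation_error"
    ]


def validation_error_rate(tool_uses: list[dict[str, Any]]) -> float:
    """Share of tool uses that failed validation; 0 when there were none."""
    tool_count = len(tool_uses)
    errors = collect_validation_errors(tool_uses)
    return len(errors) / tool_count if tool_count > 0 else 0
```

A tool use with no `success` key counts as successful because of the `True` default, so traces that omit the key never register as validation errors.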
@@ -273,8 +300,8 @@ class ToolInvocationEvaluator(BaseEvaluator):
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
agent: Agent,
|
||||
execution_trace: Dict[str, Any],
|
||||
agent: Agent | BaseAgent,
|
||||
execution_trace: dict[str, Any],
|
||||
final_output: str,
|
||||
task: Task | None = None,
|
||||
) -> EvaluationScore:
|
||||
@@ -288,7 +315,7 @@ class ToolInvocationEvaluator(BaseEvaluator):
|
||||
if tool_count == 0:
|
||||
return EvaluationScore(
|
||||
score=None,
|
||||
feedback="No tool usage detected. Cannot evaluate tool invocation."
|
||||
feedback="No tool usage detected. Cannot evaluate tool invocation.",
|
||||
)
|
||||
|
||||
for tool_use in tool_uses:
|
||||
@@ -296,7 +323,7 @@ class ToolInvocationEvaluator(BaseEvaluator):
|
||||
error_info = {
|
||||
"tool": tool_use.get("tool", "Unknown tool"),
|
||||
"error": tool_use.get("result"),
|
||||
"error_type": tool_use.get("error_type", "unknown_error")
|
||||
"error_type": tool_use.get("error_type", "unknown_error"),
|
||||
}
|
||||
tool_errors.append(error_info)
|
||||
|
||||
@@ -315,9 +342,11 @@ class ToolInvocationEvaluator(BaseEvaluator):
|
||||
tool_args = tool_use.get("args", {})
|
||||
success = tool_use.get("success", True) and not tool_use.get("error", False)
|
||||
error_type = tool_use.get("error_type", "") if not success else ""
|
||||
error_msg = tool_use.get("result", "No error") if not success else "No error"
|
||||
error_msg = (
|
||||
tool_use.get("result", "No error") if not success else "No error"
|
||||
)
|
||||
|
||||
sample = f"Tool invocation #{i+1}:\n"
|
||||
sample = f"Tool invocation #{i + 1}:\n"
|
||||
sample += f"- Tool: {tool_name}\n"
|
||||
sample += f"- Parameters: {json.dumps(tool_args, indent=2)}\n"
|
||||
sample += f"- Success: {'No' if not success else 'Yes'}\n"
|
||||
@@ -330,11 +359,13 @@ class ToolInvocationEvaluator(BaseEvaluator):
|
||||
if error_types:
|
||||
error_type_summary = "Error type breakdown:\n"
|
||||
for error_type, count in error_types.items():
|
||||
error_type_summary += f"- {error_type}: {count} occurrences ({(count/tool_count):.1%})\n"
|
||||
error_type_summary += f"- {error_type}: {count} occurrences ({(count / tool_count):.1%})\n"
|
||||
|
||||
invocation_samples_text = "\n\n".join(invocation_samples)
|
||||
prompt = [
|
||||
{"role": "system", "content": """You are an expert evaluator assessing how correctly an AI agent's tool invocations are STRUCTURED.
|
||||
{
|
||||
"role": "system",
|
||||
"content": """You are an expert evaluator assessing how correctly an AI agent's tool invocations are STRUCTURED.
|
||||
|
||||
Your job is to evaluate ONLY the structural and syntactical aspects of how the agent called tools, NOT which tools were selected or what parameter values were used.
|
||||
|
||||
@@ -359,8 +390,11 @@ Return your evaluation as JSON with these fields:
|
||||
- overall_score: number (average of all scores, 0-10)
|
||||
- feedback: string (focused ONLY on structural aspects of tool invocation)
|
||||
- improvement_suggestions: string (concrete suggestions for better structuring of tool calls)
|
||||
"""},
|
||||
{"role": "user", "content": f"""
|
||||
""",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""
|
||||
Agent role: {agent.role}
|
||||
{task_context}
|
||||
|
||||
@@ -371,15 +405,18 @@ Tool error rate: {error_rate:.2%} ({len(tool_errors)} errors out of {tool_count}
|
||||
{error_type_summary}
|
||||
|
||||
Evaluate the quality of the agent's tool invocation structure during this task.
|
||||
"""}
|
||||
""",
|
||||
},
|
||||
]
|
||||
|
||||
assert self.llm is not None
|
||||
if self.llm is None:
|
||||
raise ValueError("LLM must be initialized")
|
||||
response = self.llm.call(prompt)
|
||||
|
||||
try:
|
||||
evaluation_data = extract_json_from_llm_response(response)
|
||||
assert evaluation_data is not None
|
||||
if evaluation_data is None:
|
||||
raise ValueError("Failed to extract evaluation data from LLM response")
|
||||
scores = evaluation_data.get("scores", {})
|
||||
structure = scores.get("structure", 5.0)
|
||||
error_handling = scores.get("error_handling", 5.0)
|
||||
@@ -388,23 +425,25 @@ Evaluate the quality of the agent's tool invocation structure during this task.
|
||||
overall_score = float(evaluation_data.get("overall_score", 5.0))
|
||||
|
||||
feedback = "Tool Invocation Evaluation:\n"
|
||||
feedback += f"• Structure: {structure}/10 - Following proper syntax and format\n"
|
||||
feedback += (
|
||||
f"• Structure: {structure}/10 - Following proper syntax and format\n"
|
||||
)
|
||||
feedback += f"• Error Handling: {error_handling}/10 - Appropriately handling tool errors\n"
|
||||
feedback += f"• Invocation Patterns: {invocation_patterns}/10 - Proper sequencing and management of calls\n\n"
|
||||
|
||||
if "improvement_suggestions" in evaluation_data:
|
||||
feedback += f"Improvement Suggestions:\n{evaluation_data['improvement_suggestions']}"
|
||||
else:
|
||||
feedback += evaluation_data.get("feedback", "No detailed feedback available.")
|
||||
feedback += evaluation_data.get(
|
||||
"feedback", "No detailed feedback available."
|
||||
)
|
||||
|
||||
return EvaluationScore(
|
||||
score=overall_score,
|
||||
feedback=feedback,
|
||||
raw_response=response
|
||||
score=overall_score, feedback=feedback, raw_response=response
|
||||
)
|
||||
except Exception as e:
|
||||
return EvaluationScore(
|
||||
score=None,
|
||||
feedback=f"Error evaluating tool invocation: {e}",
|
||||
raw_response=response
|
||||
raw_response=response,
|
||||
)
|
||||
|
||||
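Review note: the invocation evaluator reports each error type with its share of all tool calls, formatted with the `.1%` format spec. A hedged sketch of that summary is below. How `error_types` is built is not visible in the diff, so the `Counter` is an assumption and the function name is hypothetical.

```python
from collections import Counter
from typing import Any


def summarize_error_types(tool_errors: list[dict[str, Any]], tool_count: int) -> str:
    """Build the error breakdown text; the caller must ensure tool_count > 0."""
    # Assumption: errors are tallied by their "error_type" field.
    error_types = Counter(
        err.get("error_type", "unknown_error") for err in tool_errors
    )
    if not error_types:
        return ""

    summary = "Error type breakdown:\n"
    for error_type, count in error_types.items():
        summary += f"- {error_type}: {count} occurrences ({(count / tool_count):.1%})\n"
    return summary
```

The evaluator returns early when `tool_count == 0`, which is what makes the division safe there.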
@@ -1,12 +1,21 @@
|
||||
import inspect
|
||||
import warnings
|
||||
|
||||
from typing_extensions import Any
|
||||
import warnings
|
||||
from crewai.experimental.evaluation.experiment import ExperimentResults, ExperimentRunner
|
||||
from crewai import Crew, Agent
|
||||
|
||||
def assert_experiment_successfully(experiment_results: ExperimentResults, baseline_filepath: str | None = None) -> None:
|
||||
failed_tests = [result for result in experiment_results.results if not result.passed]
|
||||
from crewai import Agent, Crew
|
||||
from crewai.experimental.evaluation.experiment import (
|
||||
ExperimentResults,
|
||||
ExperimentRunner,
|
||||
)
|
||||
|
||||
|
||||
def assert_experiment_successfully(
|
||||
experiment_results: ExperimentResults, baseline_filepath: str | None = None
|
||||
) -> None:
|
||||
failed_tests = [
|
||||
result for result in experiment_results.results if not result.passed
|
||||
]
|
||||
|
||||
if failed_tests:
|
||||
detailed_failures: list[str] = []
|
||||
@@ -14,39 +23,54 @@ def assert_experiment_successfully(experiment_results: ExperimentResults, baseli
|
||||
for result in failed_tests:
|
||||
expected = result.expected_score
|
||||
actual = result.score
|
||||
detailed_failures.append(f"- {result.identifier}: expected {expected}, got {actual}")
|
||||
detailed_failures.append(
|
||||
f"- {result.identifier}: expected {expected}, got {actual}"
|
||||
)
|
||||
|
||||
failure_details = "\n".join(detailed_failures)
|
||||
raise AssertionError(f"The following test cases failed:\n{failure_details}")
|
||||
|
||||
baseline_filepath = baseline_filepath or _get_baseline_filepath_fallback()
|
||||
comparison = experiment_results.compare_with_baseline(baseline_filepath=baseline_filepath)
|
||||
comparison = experiment_results.compare_with_baseline(
|
||||
baseline_filepath=baseline_filepath
|
||||
)
|
||||
assert_experiment_no_regression(comparison)
|
||||
|
||||
|
||||
def assert_experiment_no_regression(comparison_result: dict[str, list[str]]) -> None:
|
||||
regressed = comparison_result.get("regressed", [])
|
||||
if regressed:
|
||||
raise AssertionError(f"Regression detected! The following tests that previously passed now fail: {regressed}")
|
||||
raise AssertionError(
|
||||
f"Regression detected! The following tests that previously passed now fail: {regressed}"
|
||||
)
|
||||
|
||||
missing_tests = comparison_result.get("missing_tests", [])
|
||||
if missing_tests:
|
||||
warnings.warn(
|
||||
f"Warning: {len(missing_tests)} tests from the baseline are missing in the current run: {missing_tests}",
|
||||
UserWarning
|
||||
UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
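Review note: the added `stacklevel=2` makes the warning point at the caller of the helper rather than at the `warnings.warn` line itself. Here is a minimal sketch, with a hypothetical function name, that can be checked with `warnings.catch_warnings`:

```python
import warnings


def warn_on_missing_tests(missing_tests: list[str]) -> None:
    """Emit a UserWarning attributed to the caller (hypothetical helper)."""
    if missing_tests:
        warnings.warn(
            f"Warning: {len(missing_tests)} tests from the baseline are missing in the current run: {missing_tests}",
            UserWarning,
            stacklevel=2,  # report the caller's file and line, not this one
        )


if __name__ == "__main__":
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        warn_on_missing_tests(["test_a"])
    print(caught[0].category.__name__, caught[0].lineno)
```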
|
||||
def run_experiment(dataset: list[dict[str, Any]], crew: Crew | None = None, agents: list[Agent] | None = None, verbose: bool = False) -> ExperimentResults:
|
||||
|
||||
def run_experiment(
|
||||
dataset: list[dict[str, Any]],
|
||||
crew: Crew | None = None,
|
||||
agents: list[Agent] | None = None,
|
||||
verbose: bool = False,
|
||||
) -> ExperimentResults:
|
||||
runner = ExperimentRunner(dataset=dataset)
|
||||
|
||||
return runner.run(agents=agents, crew=crew, print_summary=verbose)
|
||||
|
||||
|
||||
def _get_baseline_filepath_fallback() -> str:
|
||||
test_func_name = "experiment_fallback"
|
||||
|
||||
try:
|
||||
current_frame = inspect.currentframe()
|
||||
if current_frame is not None:
|
||||
test_func_name = current_frame.f_back.f_back.f_code.co_name # type: ignore[union-attr]
|
||||
test_func_name = current_frame.f_back.f_back.f_code.co_name # type: ignore[union-attr]
|
||||
except Exception:
|
||||
...
|
||||
return f"{test_func_name}_results.json"
|
||||
return f"{test_func_name}_results.json"
|
||||
|
||||
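Review note: `_get_baseline_filepath_fallback` walks two frames up the call stack to name the baseline file after the calling test function. A self-contained sketch of that lookup follows. The wrapper names are hypothetical, and it assumes the helper is always called exactly one level below the test.

```python
import inspect


def baseline_filepath_fallback() -> str:
    """Derive '<caller-of-caller>_results.json' from the call stack."""
    test_func_name = "experiment_fallback"
    try:
        current_frame = inspect.currentframe()
        if current_frame is not None:
            # f_back is the direct caller; f_back.f_back is the test function.
            test_func_name = current_frame.f_back.f_back.f_code.co_name
    except Exception:
        pass  # keep the default name if the stack is shallower than expected
    return f"{test_func_name}_results.json"


def assert_helper() -> str:
    # Stands in for assert_experiment_successfully in the diff.
    return baseline_filepath_fallback()


def test_my_experiment() -> str:
    return assert_helper()
```

Because the depth is hard-coded, adding another wrapper between the test and the helper changes the resulting filename.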
@@ -1,5 +1,4 @@
|
||||
from crewai.flow.flow import Flow, start, listen, or_, and_, router
|
||||
from crewai.flow.flow import Flow, and_, listen, or_, router, start
|
||||
from crewai.flow.persistence import persist
|
||||
|
||||
__all__ = ["Flow", "start", "listen", "or_", "and_", "router", "persist"]
|
||||
|
||||
__all__ = ["Flow", "and_", "listen", "or_", "persist", "router", "start"]
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import inspect
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field, InstanceOf, model_validator
|
||||
|
||||
@@ -14,7 +13,7 @@ class FlowTrackable(BaseModel):
|
||||
inspecting the call stack.
|
||||
"""
|
||||
|
||||
parent_flow: Optional[InstanceOf[Flow]] = Field(
|
||||
parent_flow: InstanceOf[Flow] | None = Field(
|
||||
default=None,
|
||||
description="The parent flow of the instance, if it was created inside a flow.",
|
||||
)
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
# flow_visualizer.py
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from pyvis.network import Network
|
||||
from pyvis.network import Network # type: ignore[import-untyped]
|
||||
|
||||
from crewai.flow.config import COLORS, NODE_STYLES
|
||||
from crewai.flow.html_template_handler import HTMLTemplateHandler
|
||||
from crewai.flow.legend_generator import generate_legend_items_html, get_legend_items
|
||||
from crewai.flow.path_utils import safe_path_join, validate_path_exists
|
||||
from crewai.flow.path_utils import safe_path_join
|
||||
from crewai.flow.utils import calculate_node_levels
|
||||
from crewai.flow.visualization_utils import (
|
||||
add_edges,
|
||||
@@ -34,13 +33,13 @@ class FlowPlot:
|
||||
ValueError
|
||||
If flow object is invalid or missing required attributes.
|
||||
"""
|
||||
if not hasattr(flow, '_methods'):
|
||||
if not hasattr(flow, "_methods"):
|
||||
raise ValueError("Invalid flow object: missing '_methods' attribute")
|
||||
if not hasattr(flow, '_listeners'):
|
||||
if not hasattr(flow, "_listeners"):
|
||||
raise ValueError("Invalid flow object: missing '_listeners' attribute")
|
||||
if not hasattr(flow, '_start_methods'):
|
||||
if not hasattr(flow, "_start_methods"):
|
||||
raise ValueError("Invalid flow object: missing '_start_methods' attribute")
|
||||
|
||||
|
||||
self.flow = flow
|
||||
self.colors = COLORS
|
||||
self.node_styles = NODE_STYLES
|
||||
@@ -65,7 +64,7 @@ class FlowPlot:
|
||||
"""
|
||||
if not filename or not isinstance(filename, str):
|
||||
raise ValueError("Filename must be a non-empty string")
|
||||
|
||||
|
||||
try:
|
||||
# Initialize network
|
||||
net = Network(
|
||||
@@ -96,32 +95,34 @@ class FlowPlot:
|
||||
try:
|
||||
node_levels = calculate_node_levels(self.flow)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to calculate node levels: {str(e)}")
|
||||
raise ValueError(f"Failed to calculate node levels: {e!s}") from e
|
||||
|
||||
# Compute positions
|
||||
try:
|
||||
node_positions = compute_positions(self.flow, node_levels)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to compute node positions: {str(e)}")
|
||||
raise ValueError(f"Failed to compute node positions: {e!s}") from e
|
||||
|
||||
# Add nodes to the network
|
||||
try:
|
||||
add_nodes_to_network(net, self.flow, node_positions, self.node_styles)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to add nodes to network: {str(e)}")
|
||||
raise RuntimeError(f"Failed to add nodes to network: {e!s}") from e
|
||||
|
||||
# Add edges to the network
|
||||
try:
|
||||
add_edges(net, self.flow, node_positions, self.colors)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to add edges to network: {str(e)}")
|
||||
raise RuntimeError(f"Failed to add edges to network: {e!s}") from e
|
||||
|
||||
# Generate HTML
|
||||
try:
|
||||
network_html = net.generate_html()
|
||||
final_html_content = self._generate_final_html(network_html)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to generate network visualization: {str(e)}")
|
||||
raise RuntimeError(
|
||||
f"Failed to generate network visualization: {e!s}"
|
||||
) from e
|
||||
|
||||
# Save the final HTML content to the file
|
||||
try:
|
||||
@@ -129,12 +130,16 @@ class FlowPlot:
|
||||
f.write(final_html_content)
|
||||
print(f"Plot saved as {filename}.html")
|
||||
except IOError as e:
|
||||
raise IOError(f"Failed to save flow visualization to {filename}.html: {str(e)}")
|
||||
raise IOError(
|
||||
f"Failed to save flow visualization to {filename}.html: {e!s}"
|
||||
) from e
|
||||
|
||||
except (ValueError, RuntimeError, IOError) as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Unexpected error during flow visualization: {str(e)}")
|
||||
raise RuntimeError(
|
||||
f"Unexpected error during flow visualization: {e!s}"
|
||||
) from e
|
||||
finally:
|
||||
self._cleanup_pyvis_lib()
|
||||
|
||||
@@ -165,7 +170,9 @@ class FlowPlot:
|
||||
try:
|
||||
# Extract just the body content from the generated HTML
|
||||
current_dir = os.path.dirname(__file__)
|
||||
template_path = safe_path_join("assets", "crewai_flow_visual_template.html", root=current_dir)
|
||||
template_path = safe_path_join(
|
||||
"assets", "crewai_flow_visual_template.html", root=current_dir
|
||||
)
|
||||
logo_path = safe_path_join("assets", "crewai_logo.svg", root=current_dir)
|
||||
|
||||
if not os.path.exists(template_path):
|
||||
@@ -179,12 +186,9 @@ class FlowPlot:
|
||||
# Generate the legend items HTML
|
||||
legend_items = get_legend_items(self.colors)
|
||||
legend_items_html = generate_legend_items_html(legend_items)
|
||||
final_html_content = html_handler.generate_final_html(
|
||||
network_body, legend_items_html
|
||||
)
|
||||
return final_html_content
|
||||
return html_handler.generate_final_html(network_body, legend_items_html)
|
||||
except Exception as e:
|
||||
raise IOError(f"Failed to generate visualization HTML: {str(e)}")
|
||||
raise IOError(f"Failed to generate visualization HTML: {e!s}") from e
|
||||
|
||||
def _cleanup_pyvis_lib(self):
|
||||
"""
|
||||
@@ -197,6 +201,7 @@ class FlowPlot:
|
||||
lib_folder = safe_path_join("lib", root=os.getcwd())
|
||||
if os.path.exists(lib_folder) and os.path.isdir(lib_folder):
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(lib_folder)
|
||||
except ValueError as e:
|
||||
print(f"Error validating lib folder path: {e}")
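Most hunks in this file replace `raise X(f"...{str(e)}")` with `raise X(f"...{e!s}") from e`. The sketch below is not code from the repository: `compute`, `render`, and `demo` are hypothetical names used only to show what the change does in isolation. It assumes nothing beyond the standard library.

```python
def compute() -> None:
    # Hypothetical stand-in for a helper that fails.
    raise RuntimeError("boom")


def render() -> None:
    try:
        compute()
    except Exception as e:
        # `{e!s}` formats the exception with str(), like `{str(e)}`.
        # `from e` records the original exception on __cause__.
        raise ValueError(f"Failed to calculate node levels: {e!s}") from e


def demo() -> tuple[str, str]:
    # Return the wrapper's message and the type name of its explicit cause.
    try:
        render()
    except ValueError as err:
        return str(err), type(err.__cause__).__name__
    raise AssertionError("render() should have raised")
```

With explicit chaining, the traceback reports the original error as the direct cause of the new one, rather than as an error that occurred while handling another.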
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import base64
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from crewai.flow.path_utils import safe_path_join, validate_path_exists
|
||||
from crewai.flow.path_utils import validate_path_exists
|
||||
|
||||
|
||||
class HTMLTemplateHandler:
|
||||
@@ -28,7 +27,7 @@ class HTMLTemplateHandler:
|
||||
self.template_path = validate_path_exists(template_path, "file")
|
||||
self.logo_path = validate_path_exists(logo_path, "file")
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid template or logo path: {e}")
|
||||
raise ValueError(f"Invalid template or logo path: {e}") from e
|
||||
|
||||
def read_template(self):
|
||||
"""Read and return the HTML template file contents."""
|
||||
@@ -53,23 +52,23 @@ class HTMLTemplateHandler:
|
||||
if "border" in item:
|
||||
legend_items_html += f"""
|
||||
<div class="legend-item">
|
||||
<div class="legend-color-box" style="background-color: {item['color']}; border: 2px dashed {item['border']};"></div>
|
||||
<div>{item['label']}</div>
|
||||
<div class="legend-color-box" style="background-color: {item["color"]}; border: 2px dashed {item["border"]};"></div>
|
||||
<div>{item["label"]}</div>
|
||||
</div>
|
||||
"""
|
||||
elif item.get("dashed") is not None:
|
||||
style = "dashed" if item["dashed"] else "solid"
|
||||
legend_items_html += f"""
|
||||
<div class="legend-item">
|
||||
<div class="legend-{style}" style="border-bottom: 2px {style} {item['color']};"></div>
|
||||
<div>{item['label']}</div>
|
||||
<div class="legend-{style}" style="border-bottom: 2px {style} {item["color"]};"></div>
|
||||
<div>{item["label"]}</div>
|
||||
</div>
|
||||
"""
|
||||
else:
|
||||
legend_items_html += f"""
|
||||
<div class="legend-item">
|
||||
<div class="legend-color-box" style="background-color: {item['color']};"></div>
|
||||
<div>{item['label']}</div>
|
||||
<div class="legend-color-box" style="background-color: {item["color"]};"></div>
|
||||
<div>{item["label"]}</div>
|
||||
</div>
|
||||
"""
|
||||
return legend_items_html
|
||||
@@ -79,15 +78,9 @@ class HTMLTemplateHandler:
|
||||
html_template = self.read_template()
|
||||
logo_svg_base64 = self.encode_logo()
|
||||
|
||||
final_html_content = html_template.replace("{{ title }}", title)
|
||||
final_html_content = final_html_content.replace(
|
||||
"{{ network_content }}", network_body
|
||||
return (
|
||||
html_template.replace("{{ title }}", title)
|
||||
.replace("{{ network_content }}", network_body)
|
||||
.replace("{{ logo_svg_base64 }}", logo_svg_base64)
|
||||
.replace("<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html)
|
||||
)
|
||||
final_html_content = final_html_content.replace(
|
||||
"{{ logo_svg_base64 }}", logo_svg_base64
|
||||
)
|
||||
final_html_content = final_html_content.replace(
|
||||
"<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html
|
||||
)
|
||||
|
||||
return final_html_content
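The hunk above collapses four successive reassignments into one chained `str.replace` expression. A minimal sketch of the same pattern follows; `render` is a hypothetical function, and only the two placeholder strings are taken from the diff.

```python
def render(template: str, title: str, network_body: str) -> str:
    # Each replace() returns a new string, so the calls can be chained
    # without an intermediate variable.
    return (
        template.replace("{{ title }}", title)
        .replace("{{ network_content }}", network_body)
    )
```

Because each call operates on the output of the previous one, a value substituted early that itself contains a later placeholder would also be rewritten. The order of the calls is therefore part of the behavior.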
|
||||
|
||||
@@ -5,12 +5,10 @@ This module provides utilities for secure path handling to prevent directory
|
||||
traversal attacks and ensure paths remain within allowed boundaries.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
|
||||
|
||||
def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str:
|
||||
def safe_path_join(*parts: str, root: str | Path | None = None) -> str:
|
||||
"""
|
||||
Safely join path components and ensure the result is within allowed boundaries.
|
||||
|
||||
@@ -43,25 +41,25 @@ def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str:
|
||||
|
||||
# Establish root directory
|
||||
root_path = Path(root).resolve() if root else Path.cwd()
|
||||
|
||||
|
||||
# Join and resolve the full path
|
||||
full_path = Path(root_path, *clean_parts).resolve()
|
||||
|
||||
|
||||
# Check if the resolved path is within root
|
||||
if not str(full_path).startswith(str(root_path)):
|
||||
raise ValueError(
|
||||
f"Invalid path: Potential directory traversal. Path must be within {root_path}"
|
||||
)
|
||||
|
||||
|
||||
return str(full_path)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(e, ValueError):
|
||||
raise
|
||||
raise ValueError(f"Invalid path components: {str(e)}")
|
||||
raise ValueError(f"Invalid path components: {e!s}") from e
|
||||
|
||||
|
||||
def validate_path_exists(path: Union[str, Path], file_type: str = "file") -> str:
|
||||
def validate_path_exists(path: str | Path, file_type: str = "file") -> str:
|
||||
"""
|
||||
Validate that a path exists and is of the expected type.
|
||||
|
||||
@@ -84,24 +82,24 @@ def validate_path_exists(path: Union[str, Path], file_type: str = "file") -> str
|
||||
"""
|
||||
try:
|
||||
path_obj = Path(path).resolve()
|
||||
|
||||
|
||||
if not path_obj.exists():
|
||||
raise ValueError(f"Path does not exist: {path}")
|
||||
|
||||
|
||||
if file_type == "file" and not path_obj.is_file():
|
||||
raise ValueError(f"Path is not a file: {path}")
|
||||
elif file_type == "directory" and not path_obj.is_dir():
|
||||
if file_type == "directory" and not path_obj.is_dir():
|
||||
raise ValueError(f"Path is not a directory: {path}")
|
||||
|
||||
|
||||
return str(path_obj)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(e, ValueError):
|
||||
raise
|
||||
raise ValueError(f"Invalid path: {str(e)}")
|
||||
raise ValueError(f"Invalid path: {e!s}") from e
|
||||
|
||||
|
||||
def list_files(directory: Union[str, Path], pattern: str = "*") -> List[str]:
|
||||
def list_files(directory: str | Path, pattern: str = "*") -> list[str]:
|
||||
"""
|
||||
Safely list files in a directory matching a pattern.
|
||||
|
||||
@@ -126,10 +124,10 @@ def list_files(directory: Union[str, Path], pattern: str = "*") -> List[str]:
|
||||
dir_path = Path(directory).resolve()
|
||||
if not dir_path.is_dir():
|
||||
raise ValueError(f"Not a directory: {directory}")
|
||||
|
||||
|
||||
return [str(p) for p in dir_path.glob(pattern) if p.is_file()]
|
||||
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(e, ValueError):
|
||||
raise
|
||||
raise ValueError(f"Error listing files: {str(e)}")
|
||||
raise ValueError(f"Error listing files: {e!s}") from e
|
||||
|
||||
@@ -4,7 +4,7 @@ CrewAI Flow Persistence.
|
||||
This module provides interfaces and implementations for persisting flow states.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, TypeVar, Union
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -12,7 +12,7 @@ from crewai.flow.persistence.base import FlowPersistence
|
||||
from crewai.flow.persistence.decorators import persist
|
||||
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
|
||||
|
||||
__all__ = ["FlowPersistence", "persist", "SQLiteFlowPersistence"]
|
||||
__all__ = ["FlowPersistence", "SQLiteFlowPersistence", "persist"]
|
||||
|
||||
StateType = TypeVar('StateType', bound=Union[Dict[str, Any], BaseModel])
|
||||
DictStateType = Dict[str, Any]
|
||||
StateType = TypeVar("StateType", bound=dict[str, Any] | BaseModel)
|
||||
DictStateType = dict[str, Any]
|
||||
|
||||
@@ -1,53 +1,47 @@
|
||||
"""Base class for flow state persistence."""
|
||||
|
||||
import abc
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class FlowPersistence(abc.ABC):
|
||||
"""Abstract base class for flow state persistence.
|
||||
|
||||
|
||||
This class defines the interface that all persistence implementations must follow.
|
||||
It supports both structured (Pydantic BaseModel) and unstructured (dict) states.
|
||||
"""
|
||||
|
||||
|
||||
@abc.abstractmethod
|
||||
def init_db(self) -> None:
|
||||
"""Initialize the persistence backend.
|
||||
|
||||
|
||||
This method should handle any necessary setup, such as:
|
||||
- Creating tables
|
||||
- Establishing connections
|
||||
- Setting up indexes
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abc.abstractmethod
|
||||
def save_state(
|
||||
self,
|
||||
flow_uuid: str,
|
||||
method_name: str,
|
||||
state_data: Union[Dict[str, Any], BaseModel]
|
||||
self, flow_uuid: str, method_name: str, state_data: dict[str, Any] | BaseModel
|
||||
) -> None:
|
||||
"""Persist the flow state after method completion.
|
||||
|
||||
|
||||
Args:
|
||||
flow_uuid: Unique identifier for the flow instance
|
||||
method_name: Name of the method that just completed
|
||||
state_data: Current state data (either dict or Pydantic model)
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
|
||||
def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
|
||||
"""Load the most recent state for a given flow UUID.
|
||||
|
||||
|
||||
Args:
|
||||
flow_uuid: Unique identifier for the flow instance
|
||||
|
||||
|
||||
Returns:
|
||||
The most recent state as a dictionary, or None if no state exists
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -24,13 +24,10 @@ Example:
|
||||
import asyncio
|
||||
import functools
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
@@ -48,7 +45,7 @@ LOG_MESSAGES = {
|
||||
"save_state": "Saving flow state to memory for ID: {}",
|
||||
"save_error": "Failed to persist state for method {}: {}",
|
||||
"state_missing": "Flow instance has no state",
|
||||
"id_missing": "Flow state must have an 'id' field for persistence"
|
||||
"id_missing": "Flow state must have an 'id' field for persistence",
|
||||
}
|
||||
|
||||
|
||||
@@ -58,7 +55,13 @@ class PersistenceDecorator:
|
||||
_printer = Printer() # Class-level printer instance
|
||||
|
||||
@classmethod
|
||||
def persist_state(cls, flow_instance: Any, method_name: str, persistence_instance: FlowPersistence, verbose: bool = False) -> None:
|
||||
def persist_state(
|
||||
cls,
|
||||
flow_instance: Any,
|
||||
method_name: str,
|
||||
persistence_instance: FlowPersistence,
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
"""Persist flow state with proper error handling and logging.
|
||||
|
||||
This method handles the persistence of flow state data, including proper
|
||||
@@ -76,22 +79,24 @@ class PersistenceDecorator:
|
||||
AttributeError: If flow instance lacks required state attributes
|
||||
"""
|
||||
try:
|
||||
state = getattr(flow_instance, 'state', None)
|
||||
state = getattr(flow_instance, "state", None)
|
||||
if state is None:
|
||||
raise ValueError("Flow instance has no state")
|
||||
|
||||
flow_uuid: Optional[str] = None
|
||||
flow_uuid: str | None = None
|
||||
if isinstance(state, dict):
|
||||
flow_uuid = state.get('id')
|
||||
flow_uuid = state.get("id")
|
||||
elif isinstance(state, BaseModel):
|
||||
flow_uuid = getattr(state, 'id', None)
|
||||
flow_uuid = getattr(state, "id", None)
|
||||
|
||||
if not flow_uuid:
|
||||
raise ValueError("Flow state must have an 'id' field for persistence")
|
||||
|
||||
# Log state saving only if verbose is True
|
||||
if verbose:
|
||||
cls._printer.print(LOG_MESSAGES["save_state"].format(flow_uuid), color="cyan")
|
||||
cls._printer.print(
|
||||
LOG_MESSAGES["save_state"].format(flow_uuid), color="cyan"
|
||||
)
|
||||
logger.info(LOG_MESSAGES["save_state"].format(flow_uuid))
|
||||
|
||||
try:
|
||||
@@ -104,12 +109,12 @@ class PersistenceDecorator:
|
||||
error_msg = LOG_MESSAGES["save_error"].format(method_name, str(e))
|
||||
cls._printer.print(error_msg, color="red")
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(f"State persistence failed: {str(e)}") from e
|
||||
except AttributeError:
|
||||
raise RuntimeError(f"State persistence failed: {e!s}") from e
|
||||
except AttributeError as e:
|
||||
error_msg = LOG_MESSAGES["state_missing"]
|
||||
cls._printer.print(error_msg, color="red")
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
raise ValueError(error_msg) from e
|
||||
except (TypeError, ValueError) as e:
|
||||
error_msg = LOG_MESSAGES["id_missing"]
|
||||
cls._printer.print(error_msg, color="red")
|
||||
@@ -117,7 +122,7 @@ class PersistenceDecorator:
|
||||
raise ValueError(error_msg) from e
|
||||
|
||||
|
||||
def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False):
|
||||
def persist(persistence: FlowPersistence | None = None, verbose: bool = False):
|
||||
"""Decorator to persist flow state.
|
||||
|
||||
This decorator can be applied at either the class level or method level.
|
||||
@@ -144,111 +149,151 @@ def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False
|
||||
def begin(self):
|
||||
pass
|
||||
"""
|
||||
def decorator(target: Union[Type, Callable[..., T]]) -> Union[Type, Callable[..., T]]:
|
||||
|
||||
def decorator(target: type | Callable[..., T]) -> type | Callable[..., T]:
|
||||
"""Decorator that handles both class and method decoration."""
|
||||
actual_persistence = persistence or SQLiteFlowPersistence()
|
||||
|
||||
if isinstance(target, type):
|
||||
# Class decoration
|
||||
original_init = getattr(target, "__init__")
|
||||
original_init = target.__init__ # type: ignore[misc]
|
||||
|
||||
@functools.wraps(original_init)
|
||||
def new_init(self: Any, *args: Any, **kwargs: Any) -> None:
|
||||
if 'persistence' not in kwargs:
|
||||
kwargs['persistence'] = actual_persistence
|
||||
if "persistence" not in kwargs:
|
||||
kwargs["persistence"] = actual_persistence
|
||||
original_init(self, *args, **kwargs)
|
||||
|
||||
setattr(target, "__init__", new_init)
|
||||
target.__init__ = new_init # type: ignore[misc]
|
||||
|
||||
# Store original methods to preserve their decorators
|
||||
original_methods = {}
|
||||
|
||||
for name, method in target.__dict__.items():
|
||||
if callable(method) and (
|
||||
hasattr(method, "__is_start_method__") or
|
||||
hasattr(method, "__trigger_methods__") or
|
||||
hasattr(method, "__condition_type__") or
|
||||
hasattr(method, "__is_flow_method__") or
|
||||
hasattr(method, "__is_router__")
|
||||
):
|
||||
original_methods[name] = method
|
||||
original_methods = {
|
||||
name: method
|
||||
for name, method in target.__dict__.items()
|
||||
if callable(method)
|
||||
and (
|
||||
hasattr(method, "__is_start_method__")
|
||||
or hasattr(method, "__trigger_methods__")
|
||||
or hasattr(method, "__condition_type__")
|
||||
or hasattr(method, "__is_flow_method__")
|
||||
or hasattr(method, "__is_router__")
|
||||
)
|
||||
}
|
||||
|
||||
# Create wrapped versions of the methods that include persistence
|
||||
for name, method in original_methods.items():
|
||||
if asyncio.iscoroutinefunction(method):
|
||||
# Create a closure to capture the current name and method
|
||||
def create_async_wrapper(method_name: str, original_method: Callable):
|
||||
def create_async_wrapper(
|
||||
method_name: str, original_method: Callable
|
||||
):
|
||||
@functools.wraps(original_method)
|
||||
async def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
async def method_wrapper(
|
||||
self: Any, *args: Any, **kwargs: Any
|
||||
) -> Any:
|
||||
result = await original_method(self, *args, **kwargs)
|
||||
PersistenceDecorator.persist_state(self, method_name, actual_persistence, verbose)
|
||||
PersistenceDecorator.persist_state(
|
||||
self, method_name, actual_persistence, verbose
|
||||
)
|
||||
return result
|
||||
|
||||
return method_wrapper
|
||||
|
||||
wrapped = create_async_wrapper(name, method)
|
||||
|
||||
# Preserve all original decorators and attributes
|
||||
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
|
||||
for attr in [
|
||||
"__is_start_method__",
|
||||
"__trigger_methods__",
|
||||
"__condition_type__",
|
||||
"__is_router__",
|
||||
]:
|
||||
if hasattr(method, attr):
|
||||
setattr(wrapped, attr, getattr(method, attr))
|
||||
setattr(wrapped, "__is_flow_method__", True)
|
||||
wrapped.__is_flow_method__ = True # type: ignore[attr-defined]
|
||||
|
||||
# Update the class with the wrapped method
|
||||
setattr(target, name, wrapped)
|
||||
else:
|
||||
# Create a closure to capture the current name and method
|
||||
def create_sync_wrapper(method_name: str, original_method: Callable):
|
||||
def create_sync_wrapper(
|
||||
method_name: str, original_method: Callable
|
||||
):
|
||||
@functools.wraps(original_method)
|
||||
def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
result = original_method(self, *args, **kwargs)
|
||||
PersistenceDecorator.persist_state(self, method_name, actual_persistence, verbose)
|
||||
PersistenceDecorator.persist_state(
|
||||
self, method_name, actual_persistence, verbose
|
||||
)
|
||||
return result
|
||||
|
||||
return method_wrapper
|
||||
|
||||
wrapped = create_sync_wrapper(name, method)
|
||||
|
||||
# Preserve all original decorators and attributes
|
||||
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
|
||||
for attr in [
|
||||
"__is_start_method__",
|
||||
"__trigger_methods__",
|
||||
"__condition_type__",
|
||||
"__is_router__",
|
||||
]:
|
||||
if hasattr(method, attr):
|
||||
setattr(wrapped, attr, getattr(method, attr))
|
||||
setattr(wrapped, "__is_flow_method__", True)
|
||||
wrapped.__is_flow_method__ = True # type: ignore[attr-defined]
|
||||
|
||||
# Update the class with the wrapped method
|
||||
setattr(target, name, wrapped)
|
||||
|
||||
return target
|
||||
else:
|
||||
# Method decoration
|
||||
method = target
|
||||
setattr(method, "__is_flow_method__", True)
|
||||
# Method decoration
|
||||
method = target
|
||||
method.__is_flow_method__ = True # type: ignore[attr-defined]
|
||||
|
||||
if asyncio.iscoroutinefunction(method):
|
||||
@functools.wraps(method)
|
||||
async def method_async_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
|
||||
method_coro = method(flow_instance, *args, **kwargs)
|
||||
if asyncio.iscoroutine(method_coro):
|
||||
result = await method_coro
|
||||
else:
|
||||
result = method_coro
|
||||
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose)
|
||||
return result
|
||||
if asyncio.iscoroutinefunction(method):
|
||||
|
||||
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
|
||||
if hasattr(method, attr):
|
||||
setattr(method_async_wrapper, attr, getattr(method, attr))
|
||||
setattr(method_async_wrapper, "__is_flow_method__", True)
|
||||
return cast(Callable[..., T], method_async_wrapper)
|
||||
else:
|
||||
@functools.wraps(method)
|
||||
def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
|
||||
result = method(flow_instance, *args, **kwargs)
|
||||
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose)
|
||||
return result
|
||||
@functools.wraps(method)
|
||||
async def method_async_wrapper(
|
||||
flow_instance: Any, *args: Any, **kwargs: Any
|
||||
) -> T:
|
||||
method_coro = method(flow_instance, *args, **kwargs)
|
||||
if asyncio.iscoroutine(method_coro):
|
||||
result = await method_coro
|
||||
else:
|
||||
result = method_coro
|
||||
PersistenceDecorator.persist_state(
|
||||
flow_instance, method.__name__, actual_persistence, verbose
|
||||
)
|
||||
return result
|
||||
|
||||
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
|
||||
if hasattr(method, attr):
|
||||
setattr(method_sync_wrapper, attr, getattr(method, attr))
|
||||
setattr(method_sync_wrapper, "__is_flow_method__", True)
|
||||
return cast(Callable[..., T], method_sync_wrapper)
|
||||
for attr in [
|
||||
"__is_start_method__",
|
||||
"__trigger_methods__",
|
||||
"__condition_type__",
|
||||
"__is_router__",
|
||||
]:
|
||||
if hasattr(method, attr):
|
||||
setattr(method_async_wrapper, attr, getattr(method, attr))
|
||||
method_async_wrapper.__is_flow_method__ = True # type: ignore[attr-defined]
|
||||
return cast(Callable[..., T], method_async_wrapper)
|
||||
|
||||
@functools.wraps(method)
|
||||
def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
|
||||
result = method(flow_instance, *args, **kwargs)
|
||||
PersistenceDecorator.persist_state(
|
||||
flow_instance, method.__name__, actual_persistence, verbose
|
||||
)
|
||||
return result
|
||||
|
||||
for attr in [
|
||||
"__is_start_method__",
|
||||
"__trigger_methods__",
|
||||
"__condition_type__",
|
||||
"__is_router__",
|
||||
]:
|
||||
if hasattr(method, attr):
|
||||
setattr(method_sync_wrapper, attr, getattr(method, attr))
|
||||
method_sync_wrapper.__is_flow_method__ = True # type: ignore[attr-defined]
|
||||
return cast(Callable[..., T], method_sync_wrapper)
|
||||
|
||||
return decorator
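The `persist` decorator above wraps sync and async methods separately, runs a persistence step after the original call, and copies the flow marker attributes onto the wrapper. The sketch below reduces that pattern to a standalone example. `with_after_hook` and `demo` are hypothetical names, the hook is an assumed stand-in for `PersistenceDecorator.persist_state`, and `inspect.iscoroutinefunction` is swapped in for the `asyncio.iscoroutinefunction` call used in the diff.

```python
import asyncio
import functools
import inspect
from collections.abc import Callable
from typing import Any

# Marker attribute names taken from the diff.
MARKER_ATTRS = (
    "__is_start_method__",
    "__trigger_methods__",
    "__condition_type__",
    "__is_router__",
)


def with_after_hook(
    hook: Callable[[str], None],
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Run hook(method_name) after the wrapped method returns."""

    def decorator(method: Callable[..., Any]) -> Callable[..., Any]:
        if inspect.iscoroutinefunction(method):

            @functools.wraps(method)
            async def wrapper(*args: Any, **kwargs: Any) -> Any:
                result = await method(*args, **kwargs)
                hook(method.__name__)
                return result

        else:

            @functools.wraps(method)
            def wrapper(*args: Any, **kwargs: Any) -> Any:
                result = method(*args, **kwargs)
                hook(method.__name__)
                return result

        # functools.wraps already copies function attributes through
        # __dict__; the explicit loop mirrors the diff.
        for attr in MARKER_ATTRS:
            if hasattr(method, attr):
                setattr(wrapper, attr, getattr(method, attr))
        wrapper.__is_flow_method__ = True
        return wrapper

    return decorator


def demo() -> list[str]:
    """Wrap one sync and one async function and return the hook calls."""
    calls: list[str] = []

    def begin(x: int) -> int:
        return x + 1

    begin.__is_start_method__ = True

    async def fetch() -> str:
        return "ok"

    wrapped_begin = with_after_hook(calls.append)(begin)
    wrapped_fetch = with_after_hook(calls.append)(fetch)
    assert wrapped_begin(1) == 2
    assert asyncio.run(wrapped_fetch()) == "ok"
    assert wrapped_begin.__is_start_method__ is True
    return calls
```

The hook runs only after the wrapped call returns normally, so an exception in the method skips it. The wrappers in the diff place `persist_state` after the original call in the same way.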
|
||||
|
||||
@@ -6,7 +6,7 @@ import json
|
||||
import sqlite3
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -23,7 +23,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
|
||||
db_path: str
|
||||
|
||||
def __init__(self, db_path: Optional[str] = None):
|
||||
def __init__(self, db_path: str | None = None):
|
||||
"""Initialize SQLite persistence.
|
||||
|
||||
Args:
|
||||
@@ -70,7 +70,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
self,
|
||||
flow_uuid: str,
|
||||
method_name: str,
|
||||
state_data: Union[Dict[str, Any], BaseModel],
|
||||
state_data: dict[str, Any] | BaseModel,
|
||||
) -> None:
|
||||
"""Save the current flow state to SQLite.
|
||||
|
||||
@@ -107,7 +107,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
),
|
||||
)
|
||||
|
||||
def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
|
||||
def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
|
||||
"""Load the most recent state for a given flow UUID.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -5,6 +5,7 @@ the Flow system.
|
||||
"""
|
||||
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from typing_extensions import NotRequired, Required
|
||||
|
||||
|
||||
|
||||
@@ -17,10 +17,10 @@ import ast
|
||||
import inspect
|
||||
import textwrap
|
||||
from collections import defaultdict, deque
|
||||
from typing import Any, Deque, Dict, List, Optional, Set, Union
|
||||
from typing import Any
|
||||
|
||||
|
||||
def get_possible_return_constants(function: Any) -> Optional[List[str]]:
|
||||
def get_possible_return_constants(function: Any) -> list[str] | None:
|
||||
try:
|
||||
source = inspect.getsource(function)
|
||||
except OSError:
|
||||
@@ -58,12 +58,12 @@ def get_possible_return_constants(function: Any) -> Optional[List[str]]:
|
||||
target = node.targets[0]
|
||||
if isinstance(target, ast.Name):
|
||||
var_name = target.id
|
||||
dict_values = []
|
||||
# Extract string values from the dictionary
|
||||
for val in node.value.values:
|
||||
if isinstance(val, ast.Constant) and isinstance(val.value, str):
|
||||
dict_values.append(val.value)
|
||||
# If non-string, skip or just ignore
|
||||
dict_values = [
|
||||
val.value
|
||||
for val in node.value.values
|
||||
if isinstance(val, ast.Constant) and isinstance(val.value, str)
|
||||
]
|
||||
if dict_values:
|
||||
dict_definitions[var_name] = dict_values
|
||||
self.generic_visit(node)
|
||||
@@ -94,7 +94,7 @@ def get_possible_return_constants(function: Any) -> Optional[List[str]]:
|
||||
return list(return_values) if return_values else None
|
||||
|
||||
|
||||
def calculate_node_levels(flow: Any) -> Dict[str, int]:
|
||||
def calculate_node_levels(flow: Any) -> dict[str, int]:
|
||||
"""
|
||||
Calculate the hierarchical level of each node in the flow.
|
||||
|
||||
@@ -118,10 +118,10 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]:
|
||||
- Handles both OR and AND conditions for listeners
|
||||
- Processes router paths separately
|
||||
"""
|
||||
levels: Dict[str, int] = {}
|
||||
queue: Deque[str] = deque()
|
||||
visited: Set[str] = set()
|
||||
pending_and_listeners: Dict[str, Set[str]] = {}
|
||||
levels: dict[str, int] = {}
|
||||
queue: deque[str] = deque()
|
||||
visited: set[str] = set()
|
||||
pending_and_listeners: dict[str, set[str]] = {}
|
||||
|
||||
# Make all start methods at level 0
|
||||
for method_name, method in flow._methods.items():
|
||||
@@ -172,7 +172,7 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]:
|
||||
return levels
|
||||
|
||||
|
||||
def count_outgoing_edges(flow: Any) -> Dict[str, int]:
|
||||
def count_outgoing_edges(flow: Any) -> dict[str, int]:
|
||||
"""
|
||||
Count the number of outgoing edges for each method in the flow.
|
||||
|
||||
@@ -197,7 +197,7 @@ def count_outgoing_edges(flow: Any) -> Dict[str, int]:
|
||||
return counts
|
||||
|
||||
|
||||
def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:
|
||||
def build_ancestor_dict(flow: Any) -> dict[str, set[str]]:
|
||||
"""
|
||||
Build a dictionary mapping each node to its ancestor nodes.
|
||||
|
||||
@@ -211,8 +211,8 @@ def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:
|
||||
Dict[str, Set[str]]
|
||||
Dictionary mapping each node to a set of its ancestor nodes.
|
||||
"""
|
||||
ancestors: Dict[str, Set[str]] = {node: set() for node in flow._methods}
|
||||
visited: Set[str] = set()
|
||||
ancestors: dict[str, set[str]] = {node: set() for node in flow._methods}
|
||||
visited: set[str] = set()
|
||||
for node in flow._methods:
|
||||
if node not in visited:
|
||||
dfs_ancestors(node, ancestors, visited, flow)
|
||||
@@ -220,7 +220,7 @@ def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:
|
||||
|
||||
|
||||
def dfs_ancestors(
|
||||
node: str, ancestors: Dict[str, Set[str]], visited: Set[str], flow: Any
|
||||
node: str, ancestors: dict[str, set[str]], visited: set[str], flow: Any
|
||||
) -> None:
|
||||
"""
|
||||
Perform depth-first search to build ancestor relationships.
|
||||
@@ -265,7 +265,7 @@ def dfs_ancestors(
|
||||
|
||||
|
||||
def is_ancestor(
|
||||
node: str, ancestor_candidate: str, ancestors: Dict[str, Set[str]]
|
||||
node: str, ancestor_candidate: str, ancestors: dict[str, set[str]]
|
||||
) -> bool:
|
||||
"""
|
||||
Check if one node is an ancestor of another.
|
||||
@@ -287,7 +287,7 @@ def is_ancestor(
|
||||
return ancestor_candidate in ancestors.get(node, set())
|
||||
|
||||
|
||||
def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
|
||||
def build_parent_children_dict(flow: Any) -> dict[str, list[str]]:
|
||||
"""
|
||||
Build a dictionary mapping parent nodes to their children.
|
||||
|
||||
@@ -307,7 +307,7 @@ def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
|
||||
- Maps router methods to their paths and listeners
|
||||
- Children lists are sorted for consistent ordering
|
||||
"""
|
||||
parent_children: Dict[str, List[str]] = {}
|
||||
parent_children: dict[str, list[str]] = {}
|
||||
|
||||
# Map listeners to their trigger methods
|
||||
for listener_name, (_, trigger_methods) in flow._listeners.items():
|
||||
@@ -332,7 +332,7 @@ def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
|
||||
|
||||
|
||||
def get_child_index(
|
||||
parent: str, child: str, parent_children: Dict[str, List[str]]
|
||||
parent: str, child: str, parent_children: dict[str, list[str]]
|
||||
) -> int:
|
||||
"""
|
||||
Get the index of a child node in its parent's sorted children list.
|
||||
@@ -364,7 +364,7 @@ def process_router_paths(flow, current, current_level, levels, queue):
|
||||
paths = flow._router_paths.get(current, [])
|
||||
for path in paths:
|
||||
for listener_name, (
|
||||
condition_type,
|
||||
_condition_type,
|
||||
trigger_methods,
|
||||
) in flow._listeners.items():
|
||||
if path in trigger_methods:
|
||||
|
||||
@@ -17,7 +17,7 @@ Example
|
||||
|
||||
import ast
|
||||
import inspect
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
from typing import Any
|
||||
|
||||
from .utils import (
|
||||
build_ancestor_dict,
|
||||
@@ -56,6 +56,7 @@ def method_calls_crew(method: Any) -> bool:
|
||||
|
||||
class CrewCallVisitor(ast.NodeVisitor):
|
||||
"""AST visitor to detect .crew() method calls."""
|
||||
|
||||
def __init__(self):
|
||||
self.found = False
|
||||
|
||||
@@ -73,8 +74,8 @@ def method_calls_crew(method: Any) -> bool:
|
||||
def add_nodes_to_network(
|
||||
net: Any,
|
||||
flow: Any,
|
||||
node_positions: Dict[str, Tuple[float, float]],
|
||||
node_styles: Dict[str, Dict[str, Any]]
|
||||
node_positions: dict[str, tuple[float, float]],
|
||||
node_styles: dict[str, dict[str, Any]],
|
||||
) -> None:
|
||||
"""
|
||||
Add nodes to the network visualization with appropriate styling.
|
||||
@@ -98,6 +99,7 @@ def add_nodes_to_network(
     - Crew methods
     - Regular methods
     """
+
     def human_friendly_label(method_name):
         return method_name.replace("_", " ").title()

@@ -138,10 +140,10 @@ def add_nodes_to_network(

 def compute_positions(
     flow: Any,
-    node_levels: Dict[str, int],
+    node_levels: dict[str, int],
     y_spacing: float = 150,
-    x_spacing: float = 300
-) -> Dict[str, Tuple[float, float]]:
+    x_spacing: float = 300,
+) -> dict[str, tuple[float, float]]:
     """
     Compute the (x, y) positions for each node in the flow graph.

@@ -161,8 +163,8 @@ def compute_positions(
     Dict[str, Tuple[float, float]]
         Dictionary mapping node names to their (x, y) coordinates.
     """
-    level_nodes: Dict[int, List[str]] = {}
-    node_positions: Dict[str, Tuple[float, float]] = {}
+    level_nodes: dict[int, list[str]] = {}
+    node_positions: dict[str, tuple[float, float]] = {}

     for method_name, level in node_levels.items():
         level_nodes.setdefault(level, []).append(method_name)
@@ -180,10 +182,10 @@ def compute_positions(
 def add_edges(
     net: Any,
     flow: Any,
-    node_positions: Dict[str, Tuple[float, float]],
-    colors: Dict[str, str]
+    node_positions: dict[str, tuple[float, float]],
+    colors: dict[str, str],
 ) -> None:
-    edge_smooth: Dict[str, Union[str, float]] = {"type": "continuous"}  # Default value
+    edge_smooth: dict[str, str | float] = {"type": "continuous"}  # Default value
     """
     Add edges to the network visualization with appropriate styling.

@@ -269,7 +271,7 @@ def add_edges(
     for router_method_name, paths in flow._router_paths.items():
         for path in paths:
             for listener_name, (
-                condition_type,
+                _condition_type,
                 trigger_methods,
             ) in flow._listeners.items():
                 if path in trigger_methods:
@@ -8,7 +8,7 @@ from crewai.rag.chromadb.config import ChromaDBConfig
 from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
 from crewai.rag.config.utils import get_rag_client
 from crewai.rag.core.base_client import BaseClient
-from crewai.rag.embeddings.factory import get_embedding_function
+from crewai.rag.embeddings.factory import EmbedderConfig, get_embedding_function
 from crewai.rag.factory import create_client
 from crewai.rag.types import BaseRecord, SearchResult
 from crewai.utilities.logger import Logger
@@ -27,6 +27,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
     ) -> None:
         self.collection_name = collection_name
         self._client: BaseClient | None = None
+        self._embedder_config = embedder  # Store embedder config

         warnings.filterwarnings(
             "ignore",
@@ -35,12 +36,29 @@ class KnowledgeStorage(BaseKnowledgeStorage):
         )

         if embedder:
-            embedding_function = get_embedding_function(embedder)
-            config = ChromaDBConfig(
-                embedding_function=cast(
-                    ChromaEmbeddingFunctionWrapper, embedding_function
+            # Cast to EmbedderConfig for type checking
+            embedder_typed = cast(EmbedderConfig, embedder)
+            embedding_function = get_embedding_function(embedder_typed)
+            batch_size = None
+            if isinstance(embedder, dict) and "config" in embedder:
+                nested_config = embedder["config"]
+                if isinstance(nested_config, dict):
+                    batch_size = nested_config.get("batch_size")
+
+            # Create config with batch_size if provided
+            if batch_size is not None:
+                config = ChromaDBConfig(
+                    embedding_function=cast(
+                        ChromaEmbeddingFunctionWrapper, embedding_function
+                    ),
+                    batch_size=batch_size,
+                )
+            else:
+                config = ChromaDBConfig(
+                    embedding_function=cast(
+                        ChromaEmbeddingFunctionWrapper, embedding_function
+                    )
                 )
-            )
             self._client = create_client(config)

     def _get_client(self) -> BaseClient:
@@ -105,9 +123,23 @@ class KnowledgeStorage(BaseKnowledgeStorage):

             rag_documents: list[BaseRecord] = [{"content": doc} for doc in documents]

-            client.add_documents(
-                collection_name=collection_name, documents=rag_documents
-            )
+            batch_size = None
+            if self._embedder_config and isinstance(self._embedder_config, dict):
+                if "config" in self._embedder_config:
+                    nested_config = self._embedder_config["config"]
+                    if isinstance(nested_config, dict):
+                        batch_size = nested_config.get("batch_size")
+
+            if batch_size is not None:
+                client.add_documents(
+                    collection_name=collection_name,
+                    documents=rag_documents,
+                    batch_size=batch_size,
+                )
+            else:
+                client.add_documents(
+                    collection_name=collection_name, documents=rag_documents
+                )
         except Exception as e:
             if "dimension mismatch" in str(e).lower():
                 Logger(verbose=True).log(
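Review note: the nested `batch_size` lookup appears twice in this file's diff, once in `__init__` and once where `add_documents` is called. A minimal sketch of how it could be factored out, assuming two hypothetical helpers (`extract_batch_size` and `add_documents_kwargs`) that are not part of crewAI:

```python
from __future__ import annotations

from typing import Any


def extract_batch_size(embedder: Any) -> int | None:
    """Return embedder["config"]["batch_size"] when present, else None.

    Hypothetical helper; it mirrors the nested isinstance checks in the diff.
    """
    if not isinstance(embedder, dict):
        return None
    nested_config = embedder.get("config")
    if not isinstance(nested_config, dict):
        return None
    return nested_config.get("batch_size")


def add_documents_kwargs(
    collection_name: str, documents: list[dict[str, Any]], embedder: Any
) -> dict[str, Any]:
    """Build keyword arguments, adding batch_size only when it is configured."""
    kwargs: dict[str, Any] = {
        "collection_name": collection_name,
        "documents": documents,
    }
    batch_size = extract_batch_size(embedder)
    if batch_size is not None:
        kwargs["batch_size"] = batch_size
    return kwargs
```

With a helper like this, both call sites could pass `**add_documents_kwargs(...)` instead of branching on `batch_size is not None`.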
@@ -367,6 +367,7 @@ class LiteAgent(FlowTrackable, BaseModel):
                 output=output,
                 guardrail=self._guardrail,
                 retry_count=self._guardrail_retry_count,
+                event_source=self,
             )

             if not guardrail_result.success:
@@ -1,28 +1,26 @@
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
Any,
|
||||
DefaultDict,
|
||||
Dict,
|
||||
List,
|
||||
Final,
|
||||
Literal,
|
||||
Optional,
|
||||
Type,
|
||||
TextIO,
|
||||
TypedDict,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from datetime import datetime
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from litellm.types.utils import ChatCompletionDeltaToolCall
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.llm_events import (
|
||||
LLMCallCompletedEvent,
|
||||
LLMCallFailedEvent,
|
||||
@@ -31,15 +29,19 @@ from crewai.events.types.llm_events import (
     LLMStreamChunkEvent,
 )
 from crewai.events.types.tool_usage_events import (
-    ToolUsageStartedEvent,
-    ToolUsageFinishedEvent,
     ToolUsageErrorEvent,
+    ToolUsageFinishedEvent,
+    ToolUsageStartedEvent,
 )
+from crewai.llms.base_llm import BaseLLM
+from crewai.utilities.exceptions.context_window_exceeding_exception import (
+    LLMContextLengthExceededError,
+)
+from crewai.utilities.logger_utils import suppress_warnings

-with warnings.catch_warnings():
-    warnings.simplefilter("ignore", UserWarning)
+with suppress_warnings():
     import litellm
-    from litellm import Choices
+    from litellm import Choices, CustomLogger
     from litellm.exceptions import ContextWindowExceededError
     from litellm.litellm_core_utils.get_supported_openai_params import (
         get_supported_openai_params,
@@ -47,16 +49,6 @@ with warnings.catch_warnings():
     from litellm.types.utils import ModelResponse
     from litellm.utils import supports_response_schema

-
-import io
-from typing import TextIO
-
-from crewai.llms.base_llm import BaseLLM
-from crewai.events.event_bus import crewai_event_bus
-from crewai.utilities.exceptions.context_window_exceeding_exception import (
-    LLMContextLengthExceededException,
-)
-
 load_dotenv()

 litellm.suppress_debug_info = True
@@ -126,7 +118,11 @@ if not isinstance(sys.stderr, FilteredStream):
     sys.stderr = FilteredStream(sys.stderr)


-LLM_CONTEXT_WINDOW_SIZES = {
+MIN_CONTEXT: Final[int] = 1024
+MAX_CONTEXT: Final[int] = 2097152  # Current max from gemini-1.5-pro
+ANTHROPIC_PREFIXES: Final[tuple[str, str, str]] = ("anthropic/", "claude-", "claude/")
+
+LLM_CONTEXT_WINDOW_SIZES: Final[dict[str, int]] = {
     # openai
     "gpt-4": 8192,
     "gpt-4o": 128000,
@@ -252,30 +248,19 @@ LLM_CONTEXT_WINDOW_SIZES = {
     "mistral/mistral-large-2402": 32768,
 }

-DEFAULT_CONTEXT_WINDOW_SIZE = 8192
-CONTEXT_WINDOW_USAGE_RATIO = 0.85
-
-
-@contextmanager
-def suppress_warnings():
-    with warnings.catch_warnings():
-        warnings.filterwarnings("ignore")
-        warnings.filterwarnings(
-            "ignore", message="open_text is deprecated*", category=DeprecationWarning
-        )
-
-        yield
+DEFAULT_CONTEXT_WINDOW_SIZE: Final[int] = 8192
+CONTEXT_WINDOW_USAGE_RATIO: Final[float] = 0.85


 class Delta(TypedDict):
-    content: Optional[str]
-    role: Optional[str]
+    content: str | None
+    role: str | None


 class StreamingChoices(TypedDict):
     delta: Delta
     index: int
-    finish_reason: Optional[str]
+    finish_reason: str | None


 class FunctionArgs(BaseModel):
@@ -288,31 +273,31 @@ class AccumulatedToolArgs(BaseModel):


 class LLM(BaseLLM):
-    completion_cost: Optional[float] = None
+    completion_cost: float | None = None

     def __init__(
         self,
         model: str,
-        timeout: Optional[Union[float, int]] = None,
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        n: Optional[int] = None,
-        stop: Optional[Union[str, List[str]]] = None,
-        max_completion_tokens: Optional[int] = None,
-        max_tokens: Optional[int] = None,
-        presence_penalty: Optional[float] = None,
-        frequency_penalty: Optional[float] = None,
-        logit_bias: Optional[Dict[int, float]] = None,
-        response_format: Optional[Type[BaseModel]] = None,
-        seed: Optional[int] = None,
-        logprobs: Optional[int] = None,
-        top_logprobs: Optional[int] = None,
-        base_url: Optional[str] = None,
-        api_base: Optional[str] = None,
-        api_version: Optional[str] = None,
-        api_key: Optional[str] = None,
-        callbacks: List[Any] | None = None,
-        reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
+        timeout: float | int | None = None,
+        temperature: float | None = None,
+        top_p: float | None = None,
+        n: int | None = None,
+        stop: str | list[str] | None = None,
+        max_completion_tokens: int | None = None,
+        max_tokens: int | None = None,
+        presence_penalty: float | None = None,
+        frequency_penalty: float | None = None,
+        logit_bias: dict[int, float] | None = None,
+        response_format: type[BaseModel] | None = None,
+        seed: int | None = None,
+        logprobs: int | None = None,
+        top_logprobs: int | None = None,
+        base_url: str | None = None,
+        api_base: str | None = None,
+        api_version: str | None = None,
+        api_key: str | None = None,
+        callbacks: list[Any] | None = None,
+        reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
         stream: bool = False,
         **kwargs,
     ):
@@ -345,7 +330,7 @@ class LLM(BaseLLM):

         # Normalize self.stop to always be a List[str]
         if stop is None:
-            self.stop: List[str] = []
+            self.stop: list[str] = []
         elif isinstance(stop, str):
             self.stop = [stop]
         else:
@@ -354,7 +339,8 @@ class LLM(BaseLLM):
         self.set_callbacks(callbacks or [])
         self.set_env_callbacks()

-    def _is_anthropic_model(self, model: str) -> bool:
+    @staticmethod
+    def _is_anthropic_model(model: str) -> bool:
         """Determine if the model is from Anthropic provider.

         Args:
@@ -363,21 +349,18 @@ class LLM(BaseLLM):
         Returns:
             bool: True if the model is from Anthropic, False otherwise.
         """
-        ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/")
         return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES)

     def _prepare_completion_params(
         self,
-        messages: Union[str, List[Dict[str, str]]],
-        tools: Optional[List[dict]] = None,
-    ) -> Dict[str, Any]:
+        messages: str | list[dict[str, str]],
+        tools: list[dict] | None = None,
+    ) -> dict[str, Any]:
         """Prepare parameters for the completion call.

         Args:
             messages: Input messages for the LLM
             tools: Optional list of tool schemas
-            callbacks: Optional list of callback functions
-            available_functions: Optional dict of available functions

         Returns:
             Dict[str, Any]: Parameters for the completion call
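Review note: `_is_anthropic_model` tests each entry of `ANTHROPIC_PREFIXES` with `in`, so despite the constant's name it is a substring match, not a prefix match. The sketch below re-creates that check as a free function and contrasts it with a strict `startswith` variant. Both function names are hypothetical and exist only for illustration.

```python
ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/")


def is_anthropic_model(model: str) -> bool:
    """Substring match, as in the diff: any marker anywhere in the name."""
    return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES)


def starts_with_anthropic_prefix(model: str) -> bool:
    """Strict prefix match, shown only for comparison (not what the diff does)."""
    return model.lower().startswith(ANTHROPIC_PREFIXES)


# A routed model name matches the substring check but not the strict one.
routed = "openrouter/anthropic/claude-3-opus"
print(is_anthropic_model(routed), starts_with_anthropic_prefix(routed))
```

The substring behavior is what lets provider-routed names such as the one above be treated as Anthropic models.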
@@ -419,11 +402,11 @@ class LLM(BaseLLM):

     def _handle_streaming_response(
         self,
-        params: Dict[str, Any],
-        callbacks: Optional[List[Any]] = None,
-        available_functions: Optional[Dict[str, Any]] = None,
-        from_task: Optional[Any] = None,
-        from_agent: Optional[Any] = None,
+        params: dict[str, Any],
+        callbacks: list[Any] | None = None,
+        available_functions: dict[str, Any] | None = None,
+        from_task: Any | None = None,
+        from_agent: Any | None = None,
     ) -> str:
         """Handle a streaming response from the LLM.

@@ -445,9 +428,8 @@ class LLM(BaseLLM):
         last_chunk = None
         chunk_count = 0
         usage_info = None
-        tool_calls = None

-        accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs] = defaultdict(
+        accumulated_tool_args: defaultdict[int, AccumulatedToolArgs] = defaultdict(
             AccumulatedToolArgs
         )

@@ -472,16 +454,16 @@ class LLM(BaseLLM):
                         choices = chunk["choices"]
                     elif hasattr(chunk, "choices"):
                         # Check if choices is not a type but an actual attribute with value
-                        if not isinstance(getattr(chunk, "choices"), type):
-                            choices = getattr(chunk, "choices")
+                        if not isinstance(chunk.choices, type):
+                            choices = chunk.choices

                     # Try to extract usage information if available
                     if isinstance(chunk, dict) and "usage" in chunk:
                         usage_info = chunk["usage"]
                     elif hasattr(chunk, "usage"):
                         # Check if usage is not a type but an actual attribute with value
-                        if not isinstance(getattr(chunk, "usage"), type):
-                            usage_info = getattr(chunk, "usage")
+                        if not isinstance(chunk.usage, type):
+                            usage_info = chunk.usage

                     if choices and len(choices) > 0:
                         choice = choices[0]
@@ -491,7 +473,7 @@ class LLM(BaseLLM):
                         if isinstance(choice, dict) and "delta" in choice:
                             delta = choice["delta"]
                         elif hasattr(choice, "delta"):
-                            delta = getattr(choice, "delta")
+                            delta = choice.delta

                         # Extract content from delta
                         if delta:
@@ -501,7 +483,7 @@ class LLM(BaseLLM):
                                 chunk_content = delta["content"]
                             # Handle object format
                             elif hasattr(delta, "content"):
-                                chunk_content = getattr(delta, "content")
+                                chunk_content = delta.content

                             # Handle case where content might be None or empty
                             if chunk_content is None and isinstance(delta, dict):
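Review note: the streaming hunks above repeat one pattern, reading a field from a chunk that may be either a dict or an attribute-style object. A minimal sketch of that pattern as a single accessor follows. `get_field` and `extract_delta_content` are hypothetical names, and the sketch omits the diff's extra `isinstance(..., type)` guard.

```python
from __future__ import annotations

from types import SimpleNamespace
from typing import Any


def get_field(obj: Any, name: str, default: Any = None) -> Any:
    """Read `name` from a dict key or an object attribute."""
    if isinstance(obj, dict):
        return obj.get(name, default)
    return getattr(obj, name, default)


def extract_delta_content(chunk: Any) -> str | None:
    """Return choices[0].delta.content for dict- or object-shaped chunks."""
    choices = get_field(chunk, "choices")
    if not choices:
        return None
    delta = get_field(choices[0], "delta")
    if delta is None:
        return None
    return get_field(delta, "content")


# Both shapes yield the same text.
dict_chunk = {"choices": [{"delta": {"content": "hi"}}]}
obj_chunk = SimpleNamespace(
    choices=[SimpleNamespace(delta=SimpleNamespace(content="hi"))]
)
print(extract_delta_content(dict_chunk), extract_delta_content(obj_chunk))
```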
@@ -533,7 +515,9 @@ class LLM(BaseLLM):
                     full_response += chunk_content

                     # Emit the chunk event
-                    assert hasattr(crewai_event_bus, "emit")
+                    if not hasattr(crewai_event_bus, "emit"):
+                        raise Exception("crewai_event_bus must have an `emit` method")
+
                     crewai_event_bus.emit(
                         self,
                         event=LLMStreamChunkEvent(
@@ -572,8 +556,8 @@ class LLM(BaseLLM):
                         if isinstance(last_chunk, dict) and "choices" in last_chunk:
                             choices = last_chunk["choices"]
                         elif hasattr(last_chunk, "choices"):
-                            if not isinstance(getattr(last_chunk, "choices"), type):
-                                choices = getattr(last_chunk, "choices")
+                            if not isinstance(last_chunk.choices, type):
+                                choices = last_chunk.choices

                         if choices and len(choices) > 0:
                             choice = choices[0]
@@ -583,14 +567,14 @@ class LLM(BaseLLM):
                             if isinstance(choice, dict) and "message" in choice:
                                 message = choice["message"]
                             elif hasattr(choice, "message"):
-                                message = getattr(choice, "message")
+                                message = choice.message

                             if message:
                                 content = None
                                 if isinstance(message, dict) and "content" in message:
                                     content = message["content"]
                                 elif hasattr(message, "content"):
-                                    content = getattr(message, "content")
+                                    content = message.content

                                 if content:
                                     full_response = content
@@ -617,8 +601,8 @@ class LLM(BaseLLM):
                     if isinstance(last_chunk, dict) and "choices" in last_chunk:
                         choices = last_chunk["choices"]
                     elif hasattr(last_chunk, "choices"):
-                        if not isinstance(getattr(last_chunk, "choices"), type):
-                            choices = getattr(last_chunk, "choices")
+                        if not isinstance(last_chunk.choices, type):
+                            choices = last_chunk.choices

                     if choices and len(choices) > 0:
                         choice = choices[0]
@@ -627,13 +611,13 @@ class LLM(BaseLLM):
                         if isinstance(choice, dict) and "message" in choice:
                             message = choice["message"]
                         elif hasattr(choice, "message"):
-                            message = getattr(choice, "message")
+                            message = choice.message

                         if message:
                             if isinstance(message, dict) and "tool_calls" in message:
                                 tool_calls = message["tool_calls"]
                             elif hasattr(message, "tool_calls"):
-                                tool_calls = getattr(message, "tool_calls")
+                                tool_calls = message.tool_calls
             except Exception as e:
                 logging.debug(f"Error checking for tool calls: {e}")
             # --- 8) If no tool calls or no available functions, return the text response directly
@@ -673,11 +657,11 @@ class LLM(BaseLLM):
             # Catch context window errors from litellm and convert them to our own exception type.
             # This exception is handled by CrewAgentExecutor._invoke_loop() which can then
             # decide whether to summarize the content or abort based on the respect_context_window flag.
-            raise LLMContextLengthExceededException(str(e))
+            raise LLMContextLengthExceededError(str(e)) from e
         except Exception as e:
-            logging.error(f"Error in streaming response: {str(e)}")
+            logging.error(f"Error in streaming response: {e!s}")
             if full_response.strip():
-                logging.warning(f"Returning partial response despite error: {str(e)}")
+                logging.warning(f"Returning partial response despite error: {e!s}")
                 self._handle_emit_call_events(
                     response=full_response,
                     call_type=LLMCallType.LLM_CALL,
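Review note: several hunks in this diff change `raise X(str(e))` to `raise X(str(e)) from e`. The sketch below shows what the `from e` clause adds: the original exception is kept on `__cause__`, so tracebacks show both errors. The two exception classes and both functions are local stand-ins, not the real litellm or crewAI types.

```python
class ProviderContextError(Exception):
    """Stand-in for the provider-side context window error."""


class ContextLengthError(Exception):
    """Stand-in for the library's own context-length exception."""


def call_provider(prompt: str) -> str:
    """Pretend provider call that always rejects the prompt."""
    raise ProviderContextError("prompt exceeds context window")


def call_llm(prompt: str) -> str:
    """Translate the provider error while keeping it attached as the cause."""
    try:
        return call_provider(prompt)
    except ProviderContextError as e:
        raise ContextLengthError(str(e)) from e


try:
    call_llm("a very long prompt")
except ContextLengthError as err:
    # __cause__ is the original provider exception.
    print(type(err.__cause__).__name__)
```

Without `from e`, Python still records the original error on `__context__`, but the traceback describes it as raised "during handling" rather than as the direct cause.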
@@ -688,22 +672,25 @@ class LLM(BaseLLM):
                 return full_response

             # Emit failed event and re-raise the exception
-            assert hasattr(crewai_event_bus, "emit")
+            if not hasattr(crewai_event_bus, "emit"):
+                raise AttributeError(
+                    "crewai_event_bus must have an 'emit' method"
+                ) from e
             crewai_event_bus.emit(
                 self,
                 event=LLMCallFailedEvent(
                     error=str(e), from_task=from_task, from_agent=from_agent
                 ),
             )
-            raise Exception(f"Failed to get streaming response: {str(e)}")
+            raise Exception(f"Failed to get streaming response: {e!s}") from e

     def _handle_streaming_tool_calls(
         self,
-        tool_calls: List[ChatCompletionDeltaToolCall],
-        accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs],
-        available_functions: Optional[Dict[str, Any]] = None,
-        from_task: Optional[Any] = None,
-        from_agent: Optional[Any] = None,
+        tool_calls: list[ChatCompletionDeltaToolCall],
+        accumulated_tool_args: defaultdict[int, AccumulatedToolArgs],
+        available_functions: dict[str, Any] | None = None,
+        from_task: Any | None = None,
+        from_agent: Any | None = None,
     ) -> None | str:
         for tool_call in tool_calls:
             current_tool_accumulator = accumulated_tool_args[tool_call.index]
@@ -715,7 +702,8 @@ class LLM(BaseLLM):
                 current_tool_accumulator.function.arguments += (
                     tool_call.function.arguments
                 )
-            assert hasattr(crewai_event_bus, "emit")
+            if not hasattr(crewai_event_bus, "emit"):
+                raise AttributeError("crewai_event_bus must have an 'emit' method")
             crewai_event_bus.emit(
                 self,
                 event=LLMStreamChunkEvent(
@@ -742,11 +730,11 @@ class LLM(BaseLLM):
                     continue
         return None

+    @staticmethod
     def _handle_streaming_callbacks(
-        self,
-        callbacks: Optional[List[Any]],
-        usage_info: Optional[Dict[str, Any]],
-        last_chunk: Optional[Any],
+        callbacks: list[Any] | None,
+        usage_info: dict[str, Any] | None,
+        last_chunk: Any | None,
     ) -> None:
         """Handle callbacks with usage info for streaming responses.

@@ -769,10 +757,8 @@ class LLM(BaseLLM):
                                 ):
                                     usage_info = last_chunk["usage"]
                                 elif hasattr(last_chunk, "usage"):
-                                    if not isinstance(
-                                        getattr(last_chunk, "usage"), type
-                                    ):
-                                        usage_info = getattr(last_chunk, "usage")
+                                    if not isinstance(last_chunk.usage, type):
+                                        usage_info = last_chunk.usage
                         except Exception as e:
                             logging.debug(f"Error extracting usage info: {e}")

@@ -786,11 +772,11 @@ class LLM(BaseLLM):

     def _handle_non_streaming_response(
         self,
-        params: Dict[str, Any],
-        callbacks: Optional[List[Any]] = None,
-        available_functions: Optional[Dict[str, Any]] = None,
-        from_task: Optional[Any] = None,
-        from_agent: Optional[Any] = None,
+        params: dict[str, Any],
+        callbacks: list[Any] | None = None,
+        available_functions: dict[str, Any] | None = None,
+        from_task: Any | None = None,
+        from_agent: Any | None = None,
     ) -> str | Any:
         """Handle a non-streaming response from the LLM.

@@ -815,7 +801,7 @@ class LLM(BaseLLM):
         except ContextWindowExceededError as e:
             # Convert litellm's context window error to our own exception type
             # for consistent handling in the rest of the codebase
-            raise LLMContextLengthExceededException(str(e))
+            raise LLMContextLengthExceededError(str(e)) from e
         # --- 2) Extract response message and content
         response_message = cast(Choices, cast(ModelResponse, response).choices)[
             0
@@ -847,7 +833,7 @@ class LLM(BaseLLM):
             )
             return text_response
         # --- 6) If there is no text response, no available functions, but there are tool calls, return the tool calls
-        elif tool_calls and not available_functions and not text_response:
+        if tool_calls and not available_functions and not text_response:
             return tool_calls

         # --- 7) Handle tool calls if present
@@ -868,19 +854,21 @@ class LLM(BaseLLM):

     def _handle_tool_call(
         self,
-        tool_calls: List[Any],
-        available_functions: Optional[Dict[str, Any]] = None,
-        from_task: Optional[Any] = None,
-        from_agent: Optional[Any] = None,
-    ) -> Optional[str]:
+        tool_calls: list[Any],
+        available_functions: dict[str, Any] | None = None,
+        from_task: Any | None = None,
+        from_agent: Any | None = None,
+    ) -> str | None:
         """Handle a tool call from the LLM.

         Args:
             tool_calls: List of tool calls from the LLM
             available_functions: Dict of available functions
+            from_task: Optional Task that invoked the LLM
+            from_agent: Optional Agent that invoked the LLM

         Returns:
-            Optional[str]: The result of the tool call, or None if no tool call was made
+            The result of the tool call, or None if no tool call was made
         """
         # --- 1) Validate tool calls and available functions
         if not tool_calls or not available_functions:
@@ -899,7 +887,8 @@ class LLM(BaseLLM):
                 fn = available_functions[function_name]

                 # --- 3.2) Execute function
-                assert hasattr(crewai_event_bus, "emit")
+                if not hasattr(crewai_event_bus, "emit"):
+                    raise AttributeError("crewai_event_bus must have an 'emit' method")
                 started_at = datetime.now()
                 crewai_event_bus.emit(
                     self,
@@ -939,17 +928,20 @@ class LLM(BaseLLM):
                     function_name, lambda: None
                 )  # Ensure fn is always a callable
                 logging.error(f"Error executing function '{function_name}': {e}")
-                assert hasattr(crewai_event_bus, "emit")
+                if not hasattr(crewai_event_bus, "emit"):
+                    raise AttributeError(
+                        "crewai_event_bus must have an 'emit' method"
+                    ) from e
                 crewai_event_bus.emit(
                     self,
-                    event=LLMCallFailedEvent(error=f"Tool execution error: {str(e)}"),
+                    event=LLMCallFailedEvent(error=f"Tool execution error: {e!s}"),
                 )
                 crewai_event_bus.emit(
                     self,
                     event=ToolUsageErrorEvent(
                         tool_name=function_name,
                         tool_args=function_args,
-                        error=f"Tool execution error: {str(e)}",
+                        error=f"Tool execution error: {e!s}",
                         from_task=from_task,
                         from_agent=from_agent,
                     ),
@@ -958,13 +950,13 @@ class LLM(BaseLLM):

     def call(
         self,
-        messages: Union[str, List[Dict[str, str]]],
-        tools: Optional[List[dict]] = None,
-        callbacks: Optional[List[Any]] = None,
-        available_functions: Optional[Dict[str, Any]] = None,
-        from_task: Optional[Any] = None,
-        from_agent: Optional[Any] = None,
-    ) -> Union[str, Any]:
+        messages: str | list[dict[str, str]],
+        tools: list[dict] | None = None,
+        callbacks: list[Any] | None = None,
+        available_functions: dict[str, Any] | None = None,
+        from_task: Any | None = None,
+        from_agent: Any | None = None,
+    ) -> str | Any:
         """High-level LLM call method.

         Args:
@@ -988,10 +980,11 @@ class LLM(BaseLLM):
         Raises:
             TypeError: If messages format is invalid
             ValueError: If response format is not supported
-            LLMContextLengthExceededException: If input exceeds model's context limit
+            LLMContextLengthExceededError: If input exceeds model's context limit
         """
         # --- 1) Emit call started event
-        assert hasattr(crewai_event_bus, "emit")
+        if not hasattr(crewai_event_bus, "emit"):
+            raise AttributeError("crewai_event_bus must have an 'emit' method")
         crewai_event_bus.emit(
             self,
             event=LLMCallStartedEvent(
@@ -1028,13 +1021,12 @@ class LLM(BaseLLM):
                     return self._handle_streaming_response(
                         params, callbacks, available_functions, from_task, from_agent
                     )
-                else:
-                    return self._handle_non_streaming_response(
-                        params, callbacks, available_functions, from_task, from_agent
-                    )
+                return self._handle_non_streaming_response(
+                    params, callbacks, available_functions, from_task, from_agent
+                )

-            except LLMContextLengthExceededException:
-                # Re-raise LLMContextLengthExceededException as it should be handled
+            except LLMContextLengthExceededError:
+                # Re-raise LLMContextLengthExceededError as it should be handled
                 # by the CrewAgentExecutor._invoke_loop method, which can then decide
                 # whether to summarize the content or abort based on the respect_context_window flag
                 raise
@@ -1065,7 +1057,10 @@ class LLM(BaseLLM):
                         from_agent=from_agent,
                     )

-                assert hasattr(crewai_event_bus, "emit")
+                if not hasattr(crewai_event_bus, "emit"):
+                    raise AttributeError(
+                        "crewai_event_bus must have an 'emit' method"
+                    ) from e
                 crewai_event_bus.emit(
                     self,
                     event=LLMCallFailedEvent(
@@ -1078,8 +1073,8 @@ class LLM(BaseLLM):
         self,
         response: Any,
         call_type: LLMCallType,
-        from_task: Optional[Any] = None,
-        from_agent: Optional[Any] = None,
+        from_task: Any | None = None,
+        from_agent: Any | None = None,
         messages: str | list[dict[str, Any]] | None = None,
     ):
         """Handle the events for the LLM call.
@@ -1091,7 +1086,8 @@ class LLM(BaseLLM):
             from_agent: Optional agent object
             messages: Optional messages object
         """
-        assert hasattr(crewai_event_bus, "emit")
+        if not hasattr(crewai_event_bus, "emit"):
+            raise AttributeError("crewai_event_bus must have an 'emit' method")
         crewai_event_bus.emit(
             self,
             event=LLMCallCompletedEvent(
@@ -1105,8 +1101,8 @@ class LLM(BaseLLM):
         )

     def _format_messages_for_provider(
-        self, messages: List[Dict[str, str]]
-    ) -> List[Dict[str, str]]:
+        self, messages: list[dict[str, str]]
+    ) -> list[dict[str, str]]:
         """Format messages according to provider requirements.

         Args:
@@ -1147,7 +1143,7 @@ class LLM(BaseLLM):
         if "mistral" in self.model.lower():
             # Check if the last message has a role of 'assistant'
             if messages and messages[-1]["role"] == "assistant":
-                return messages + [{"role": "user", "content": "Please continue."}]
+                return [*messages, {"role": "user", "content": "Please continue."}]
             return messages

         # TODO: Remove this code after merging PR https://github.com/BerriAI/litellm/pull/10917
@@ -1157,7 +1153,7 @@ class LLM(BaseLLM):
             and messages
             and messages[-1]["role"] == "assistant"
         ):
-            return messages + [{"role": "user", "content": ""}]
+            return [*messages, {"role": "user", "content": ""}]

         # Handle Anthropic models
         if not self.is_anthropic:
@@ -1170,7 +1166,7 @@ class LLM(BaseLLM):

         return messages

-    def _get_custom_llm_provider(self) -> Optional[str]:
+    def _get_custom_llm_provider(self) -> str | None:
         """
         Derives the custom_llm_provider from the model string.
         - For example, if the model is "openrouter/deepseek/deepseek-chat", returns "openrouter".
@@ -1207,7 +1203,7 @@ class LLM(BaseLLM):
                 self.model, custom_llm_provider=provider
             )
         except Exception as e:
-            logging.error(f"Failed to check function calling support: {str(e)}")
+            logging.error(f"Failed to check function calling support: {e!s}")
             return False

     def supports_stop_words(self) -> bool:
@@ -1215,7 +1211,7 @@ class LLM(BaseLLM):
             params = get_supported_openai_params(model=self.model)
             return params is not None and "stop" in params
         except Exception as e:
-            logging.error(f"Failed to get supported params: {str(e)}")
+            logging.error(f"Failed to get supported params: {e!s}")
             return False

     def get_context_window_size(self) -> int:
@@ -1229,9 +1225,6 @@ class LLM(BaseLLM):
         if self.context_window_size != 0:
             return self.context_window_size

-        MIN_CONTEXT = 1024
-        MAX_CONTEXT = 2097152  # Current max from gemini-1.5-pro
-
         # Validate all context window sizes
         for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
             if value < MIN_CONTEXT or value > MAX_CONTEXT:
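Review note: with `MIN_CONTEXT`, `MAX_CONTEXT` and `CONTEXT_WINDOW_USAGE_RATIO` now module-level constants, the size logic in `get_context_window_size` can be read in isolation. The sketch below uses the constant values from this diff. The two function names are hypothetical, and raising `ValueError` on an out-of-range entry is an assumption, since the body of the `if` is not visible in the hunk.

```python
MIN_CONTEXT = 1024
MAX_CONTEXT = 2097152
CONTEXT_WINDOW_USAGE_RATIO = 0.85


def validate_context_sizes(sizes: dict) -> None:
    """Reject any configured window outside [MIN_CONTEXT, MAX_CONTEXT].

    Assumption: the real method's reaction to a bad entry is not shown in the
    diff; ValueError is used here purely for illustration.
    """
    for key, value in sizes.items():
        if value < MIN_CONTEXT or value > MAX_CONTEXT:
            raise ValueError(
                f"Context window for {key} must be between "
                f"{MIN_CONTEXT} and {MAX_CONTEXT}"
            )


def usable_context_window(size: int) -> int:
    """Apply the usage ratio and truncate to an int, as the diff does."""
    return int(size * CONTEXT_WINDOW_USAGE_RATIO)


validate_context_sizes({"gpt-4": 8192})
print(usable_context_window(8192))  # 6963
```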
@@ -1247,7 +1240,8 @@ class LLM(BaseLLM):
                 self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
         return self.context_window_size

-    def set_callbacks(self, callbacks: List[Any]):
+    @staticmethod
+    def set_callbacks(callbacks: list[Any]):
         """
         Attempt to keep a single set of callbacks in litellm by removing old
         duplicates and adding new ones.
@@ -1264,9 +1258,9 @@ class LLM(BaseLLM):

             litellm.callbacks = callbacks

-    def set_env_callbacks(self):
-        """
-        Sets the success and failure callbacks for the LiteLLM library from environment variables.
+    @staticmethod
+    def set_env_callbacks() -> None:
+        """Sets the success and failure callbacks for the LiteLLM library from environment variables.

         This method reads the `LITELLM_SUCCESS_CALLBACKS` and `LITELLM_FAILURE_CALLBACKS`
         environment variables, which should contain comma-separated lists of callback names.
@@ -1276,7 +1270,7 @@ class LLM(BaseLLM):
|
||||
If the environment variables are not set or are empty, the corresponding callback lists
|
||||
will be set to empty lists.
|
||||
|
||||
Example:
|
||||
Examples:
|
||||
LITELLM_SUCCESS_CALLBACKS="langfuse,langsmith"
|
||||
LITELLM_FAILURE_CALLBACKS="langfuse"
|
||||
|
||||
@@ -1285,16 +1279,15 @@ class LLM(BaseLLM):
|
||||
"""
|
||||
with suppress_warnings():
|
||||
success_callbacks_str = os.environ.get("LITELLM_SUCCESS_CALLBACKS", "")
|
||||
success_callbacks = []
|
||||
success_callbacks: list[str | Callable[..., Any] | CustomLogger] = []
|
||||
if success_callbacks_str:
|
||||
success_callbacks = [
|
||||
cb.strip() for cb in success_callbacks_str.split(",") if cb.strip()
|
||||
]
|
||||
|
||||
failure_callbacks_str = os.environ.get("LITELLM_FAILURE_CALLBACKS", "")
|
||||
failure_callbacks = []
|
||||
if failure_callbacks_str:
|
||||
failure_callbacks = [
|
||||
failure_callbacks: list[str | Callable[..., Any] | CustomLogger] = [
|
||||
cb.strip() for cb in failure_callbacks_str.split(",") if cb.strip()
|
||||
]
|
||||
|
||||
|
||||
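The `set_env_callbacks` hunk above builds its callback lists by splitting a comma-separated environment variable and dropping blank entries. A minimal sketch of that parsing step, assuming a hypothetical helper name `parse_callbacks` (not part of crewAI or LiteLLM) and leaving out the LiteLLM assignment itself:

```python
import os


def parse_callbacks(env_var: str) -> list[str]:
    """Split a comma-separated environment variable into callback names.

    Hypothetical helper for illustration; it mirrors the list
    comprehension used in the diff, nothing more.
    """
    raw = os.environ.get(env_var, "")
    # Strip surrounding whitespace and drop empty entries (e.g. from ",,").
    return [name.strip() for name in raw.split(",") if name.strip()]


# Example values only; real deployments set these outside the process.
os.environ["LITELLM_SUCCESS_CALLBACKS"] = "langfuse, langsmith,,"
os.environ.pop("LITELLM_FAILURE_CALLBACKS", None)

success = parse_callbacks("LITELLM_SUCCESS_CALLBACKS")
failure = parse_callbacks("LITELLM_FAILURE_CALLBACKS")
print(success)
print(failure)
```

An unset or empty variable yields an empty list, which matches the docstring's statement that the corresponding callback list is then set to empty.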
@@ -7,7 +7,8 @@ from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
|
||||
from crewai.rag.config.utils import get_rag_client
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.embeddings.factory import get_embedding_function
|
||||
from crewai.rag.embeddings.factory import EmbedderConfig, get_embedding_function
|
||||
from crewai.rag.embeddings.types import EmbeddingOptions
|
||||
from crewai.rag.factory import create_client
|
||||
from crewai.rag.storage.base_rag_storage import BaseRAGStorage
|
||||
from crewai.rag.types import BaseRecord
|
||||
@@ -25,7 +26,7 @@ class RAGStorage(BaseRAGStorage):
|
||||
self,
|
||||
type: str,
|
||||
allow_reset: bool = True,
|
||||
embedder_config: dict[str, Any] | None = None,
|
||||
embedder_config: EmbeddingOptions | EmbedderConfig | None = None,
|
||||
crew: Any = None,
|
||||
path: str | None = None,
|
||||
) -> None:
|
||||
@@ -50,11 +51,43 @@ class RAGStorage(BaseRAGStorage):
|
||||
|
||||
if self.embedder_config:
|
||||
embedding_function = get_embedding_function(self.embedder_config)
|
||||
config = ChromaDBConfig(
|
||||
embedding_function=cast(
|
||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||
|
||||
try:
|
||||
_ = embedding_function(["test"])
|
||||
except Exception as e:
|
||||
provider = (
|
||||
self.embedder_config.provider
|
||||
if isinstance(self.embedder_config, EmbeddingOptions)
|
||||
else self.embedder_config.get("provider", "unknown")
|
||||
)
|
||||
raise ValueError(
|
||||
f"Failed to initialize embedder. Please check your configuration or connection.\n"
|
||||
f"Provider: {provider}\n"
|
||||
f"Error: {e}"
|
||||
) from e
|
||||
|
||||
batch_size = None
|
||||
if (
|
||||
isinstance(self.embedder_config, dict)
|
||||
and "config" in self.embedder_config
|
||||
):
|
||||
nested_config = self.embedder_config["config"]
|
||||
if isinstance(nested_config, dict):
|
||||
batch_size = nested_config.get("batch_size")
|
||||
|
||||
if batch_size is not None:
|
||||
config = ChromaDBConfig(
|
||||
embedding_function=cast(
|
||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||
),
|
||||
batch_size=batch_size,
|
||||
)
|
||||
else:
|
||||
config = ChromaDBConfig(
|
||||
embedding_function=cast(
|
||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||
)
|
||||
)
|
||||
)
|
||||
self._client = create_client(config)
|
||||
|
||||
def _get_client(self) -> BaseClient:
|
||||
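The `RAGStorage` hunk above calls the embedding function once with `["test"]` so that a misconfigured embedder fails at construction time instead of on the first save. A minimal sketch of that fail-fast probe, assuming a hypothetical `probe_embedder` helper and stub embedders that stand in for a real provider:

```python
from collections.abc import Callable, Sequence


def probe_embedder(
    embed: Callable[[Sequence[str]], object], provider: str = "unknown"
) -> None:
    """Call the embedder once and re-raise any failure as a ValueError.

    Hypothetical helper; the message format follows the diff above.
    """
    try:
        embed(["test"])
    except Exception as e:
        raise ValueError(
            "Failed to initialize embedder. "
            "Please check your configuration or connection.\n"
            f"Provider: {provider}\n"
            f"Error: {e}"
        ) from e


def working_embedder(texts):
    # Stub: returns one fixed-size vector per input text.
    return [[0.0, 0.0, 0.0] for _ in texts]


def broken_embedder(texts):
    # Stub: simulates an unreachable embedding endpoint.
    raise ConnectionError("endpoint unreachable")


probe_embedder(working_embedder, provider="stub")  # passes silently

try:
    probe_embedder(broken_embedder, provider="stub")
except ValueError as err:
    message = str(err)
    cause = err.__cause__
    print(message)
```

Chaining with `from e` keeps the original exception available as `__cause__`, so the underlying connection or authentication error is not lost.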
@@ -95,7 +128,26 @@ class RAGStorage(BaseRAGStorage):
|
||||
if metadata:
|
||||
document["metadata"] = metadata
|
||||
|
||||
client.add_documents(collection_name=collection_name, documents=[document])
|
||||
batch_size = None
|
||||
if (
|
||||
self.embedder_config
|
||||
and isinstance(self.embedder_config, dict)
|
||||
and "config" in self.embedder_config
|
||||
):
|
||||
nested_config = self.embedder_config["config"]
|
||||
if isinstance(nested_config, dict):
|
||||
batch_size = nested_config.get("batch_size")
|
||||
|
||||
if batch_size is not None:
|
||||
client.add_documents(
|
||||
collection_name=collection_name,
|
||||
documents=[document],
|
||||
batch_size=batch_size,
|
||||
)
|
||||
else:
|
||||
client.add_documents(
|
||||
collection_name=collection_name, documents=[document]
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}"
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
"""RAG (Retrieval-Augmented Generation) infrastructure for CrewAI."""
|
||||
|
||||
import sys
|
||||
import importlib
|
||||
import sys
|
||||
from types import ModuleType
|
||||
from typing import Any
|
||||
|
||||
from crewai.rag.config.types import RagConfigType
|
||||
from crewai.rag.config.utils import set_rag_config
|
||||
|
||||
|
||||
_module_path = __path__
|
||||
_module_file = __file__
|
||||
|
||||
|
||||
class _RagModule(ModuleType):
|
||||
"""Module wrapper to intercept attribute setting for config."""
|
||||
|
||||
@@ -51,8 +51,10 @@ class _RagModule(ModuleType):
|
||||
"""
|
||||
try:
|
||||
return importlib.import_module(f"{self.__name__}.{name}")
|
||||
except ImportError:
|
||||
raise AttributeError(f"module '{self.__name__}' has no attribute '{name}'")
|
||||
except ImportError as e:
|
||||
raise AttributeError(
|
||||
f"module '{self.__name__}' has no attribute '{name}'"
|
||||
) from e
|
||||
|
||||
|
||||
sys.modules[__name__] = _RagModule(__name__)
|
||||
|
||||
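The `_RagModule` hunk above resolves unknown attributes by importing a submodule of the same name and converts an `ImportError` into a chained `AttributeError`. A minimal sketch of that pattern, assuming a hypothetical `LazyModule` class and using the standard-library `json` package as the target:

```python
import importlib
from types import ModuleType


class LazyModule(ModuleType):
    """Module wrapper that imports submodules on first attribute access.

    Hypothetical class for illustration; only the __getattr__ logic
    follows the diff above.
    """

    def __getattr__(self, name: str):
        # Only reached when normal attribute lookup has already failed.
        try:
            return importlib.import_module(f"{self.__name__}.{name}")
        except ImportError as e:
            raise AttributeError(
                f"module '{self.__name__}' has no attribute '{name}'"
            ) from e


lazy_json = LazyModule("json")
decoder = lazy_json.decoder  # imports json.decoder on demand
print(decoder.__name__)
print(hasattr(lazy_json, "does_not_exist"))
```

Raising `AttributeError` (rather than letting `ImportError` escape) keeps `hasattr` and `getattr` with a default working as callers expect, while `from e` preserves the import failure for debugging.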
@@ -17,6 +17,7 @@ from crewai.rag.chromadb.types import (
|
||||
ChromaDBCollectionSearchParams,
|
||||
)
|
||||
from crewai.rag.chromadb.utils import (
|
||||
_create_batch_slice,
|
||||
_extract_search_params,
|
||||
_is_async_client,
|
||||
_is_sync_client,
|
||||
@@ -52,6 +53,7 @@ class ChromaDBClient(BaseClient):
|
||||
embedding_function: ChromaEmbeddingFunction,
|
||||
default_limit: int = 5,
|
||||
default_score_threshold: float = 0.6,
|
||||
default_batch_size: int = 100,
|
||||
) -> None:
|
||||
"""Initialize ChromaDBClient with client and embedding function.
|
||||
|
||||
@@ -60,11 +62,13 @@ class ChromaDBClient(BaseClient):
|
||||
embedding_function: Embedding function for text to vector conversion.
|
||||
default_limit: Default number of results to return in searches.
|
||||
default_score_threshold: Default minimum score for search results.
|
||||
default_batch_size: Default batch size for adding documents.
|
||||
"""
|
||||
self.client = client
|
||||
self.embedding_function = embedding_function
|
||||
self.default_limit = default_limit
|
||||
self.default_score_threshold = default_score_threshold
|
||||
self.default_batch_size = default_batch_size
|
||||
|
||||
def create_collection(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
|
||||
@@ -291,6 +295,7 @@ class ChromaDBClient(BaseClient):
|
||||
- content: The text content (required)
|
||||
- doc_id: Optional unique identifier (auto-generated if missing)
|
||||
- metadata: Optional metadata dictionary
|
||||
batch_size: Optional batch size for processing documents (default: 100)
|
||||
|
||||
Raises:
|
||||
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
|
||||
@@ -305,6 +310,7 @@ class ChromaDBClient(BaseClient):
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
documents = kwargs["documents"]
|
||||
batch_size = kwargs.get("batch_size", self.default_batch_size)
|
||||
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
@@ -315,13 +321,17 @@ class ChromaDBClient(BaseClient):
|
||||
)
|
||||
|
||||
prepared = _prepare_documents_for_chromadb(documents)
|
||||
# ChromaDB doesn't accept empty metadata dicts, so pass None if all are empty
|
||||
metadatas = prepared.metadatas if any(m for m in prepared.metadatas) else None
|
||||
collection.upsert(
|
||||
ids=prepared.ids,
|
||||
documents=prepared.texts,
|
||||
metadatas=metadatas,
|
||||
)
|
||||
|
||||
for i in range(0, len(prepared.ids), batch_size):
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared=prepared, start_index=i, batch_size=batch_size
|
||||
)
|
||||
|
||||
collection.upsert(
|
||||
ids=batch_ids,
|
||||
documents=batch_texts,
|
||||
metadatas=batch_metadatas,
|
||||
)
|
||||
|
||||
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
"""Add documents with their embeddings to a collection asynchronously.
|
||||
@@ -335,6 +345,7 @@ class ChromaDBClient(BaseClient):
|
||||
- content: The text content (required)
|
||||
- doc_id: Optional unique identifier (auto-generated if missing)
|
||||
- metadata: Optional metadata dictionary
|
||||
batch_size: Optional batch size for processing documents (default: 100)
|
||||
|
||||
Raises:
|
||||
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
|
||||
@@ -349,6 +360,7 @@ class ChromaDBClient(BaseClient):
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
documents = kwargs["documents"]
|
||||
batch_size = kwargs.get("batch_size", self.default_batch_size)
|
||||
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
@@ -358,13 +370,17 @@ class ChromaDBClient(BaseClient):
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
prepared = _prepare_documents_for_chromadb(documents)
|
||||
# ChromaDB doesn't accept empty metadata dicts, so pass None if all are empty
|
||||
metadatas = prepared.metadatas if any(m for m in prepared.metadatas) else None
|
||||
await collection.upsert(
|
||||
ids=prepared.ids,
|
||||
documents=prepared.texts,
|
||||
metadatas=metadatas,
|
||||
)
|
||||
|
||||
for i in range(0, len(prepared.ids), batch_size):
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared=prepared, start_index=i, batch_size=batch_size
|
||||
)
|
||||
|
||||
await collection.upsert(
|
||||
ids=batch_ids,
|
||||
documents=batch_texts,
|
||||
metadatas=batch_metadatas,
|
||||
)
|
||||
|
||||
def search(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionSearchParams]
|
||||
|
||||
@@ -41,4 +41,5 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
|
||||
embedding_function=config.embedding_function,
|
||||
default_limit=config.limit,
|
||||
default_score_threshold=config.score_threshold,
|
||||
default_batch_size=config.batch_size,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Utility functions for ChromaDB client implementation."""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from typing import Literal, TypeGuard, cast
|
||||
|
||||
@@ -72,7 +73,15 @@ def _prepare_documents_for_chromadb(
|
||||
if "doc_id" in doc:
|
||||
ids.append(doc["doc_id"])
|
||||
else:
|
||||
content_hash = hashlib.sha256(doc["content"].encode()).hexdigest()[:16]
|
||||
content_for_hash = doc["content"]
|
||||
metadata = doc.get("metadata")
|
||||
if metadata:
|
||||
metadata_str = json.dumps(metadata, sort_keys=True)
|
||||
content_for_hash = f"{content_for_hash}|{metadata_str}"
|
||||
|
||||
content_hash = hashlib.blake2b(
|
||||
content_for_hash.encode(), digest_size=32
|
||||
).hexdigest()
|
||||
ids.append(content_hash)
|
||||
|
||||
texts.append(doc["content"])
|
||||
@@ -88,6 +97,32 @@ def _prepare_documents_for_chromadb(
|
||||
return PreparedDocuments(ids, texts, metadatas)
|
||||
|
||||
|
||||
def _create_batch_slice(
    prepared: PreparedDocuments, start_index: int, batch_size: int
) -> tuple[list[str], list[str], list[Mapping[str, str | int | float | bool]] | None]:
    """Create a batch slice from prepared documents.

    Args:
        prepared: PreparedDocuments containing ids, texts, and metadatas.
        start_index: Starting index for the batch.
        batch_size: Size of the batch.

    Returns:
        Tuple of (batch_ids, batch_texts, batch_metadatas).
    """
    batch_end = min(start_index + batch_size, len(prepared.ids))
    batch_ids = prepared.ids[start_index:batch_end]
    batch_texts = prepared.texts[start_index:batch_end]
    batch_metadatas = (
        prepared.metadatas[start_index:batch_end] if prepared.metadatas else None
    )

    if batch_metadatas and not any(m for m in batch_metadatas):
        batch_metadatas = None

    return batch_ids, batch_texts, batch_metadatas
|
||||
|
||||
|
||||
def _extract_search_params(
|
||||
kwargs: ChromaDBCollectionSearchParams,
|
||||
) -> ExtractedSearchParams:
|
||||
|
||||
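The `_create_batch_slice` helper added above is driven by a `range(0, len(prepared.ids), batch_size)` loop in `add_documents`. A minimal, self-contained sketch of that loop, assuming a simplified `Prepared` tuple in place of crewAI's `PreparedDocuments` and collecting the batches in a list instead of calling `collection.upsert`:

```python
from typing import NamedTuple


class Prepared(NamedTuple):
    """Simplified stand-in for PreparedDocuments (assumption)."""

    ids: list[str]
    texts: list[str]
    metadatas: list[dict]


def batch_slice(prepared: Prepared, start_index: int, batch_size: int):
    """Return (ids, texts, metadatas) for one batch, as in the diff."""
    end = min(start_index + batch_size, len(prepared.ids))
    ids = prepared.ids[start_index:end]
    texts = prepared.texts[start_index:end]
    metadatas = prepared.metadatas[start_index:end] if prepared.metadatas else None
    # A batch whose metadata dicts are all empty is sent with None instead.
    if metadatas and not any(m for m in metadatas):
        metadatas = None
    return ids, texts, metadatas


prepared = Prepared(
    ids=[str(i) for i in range(5)],
    texts=[f"doc {i}" for i in range(5)],
    metadatas=[{"a": 1}, {}, {}, {}, {}],
)

batch_size = 2
batches = [
    batch_slice(prepared, i, batch_size)
    for i in range(0, len(prepared.ids), batch_size)
]
for ids, _texts, metadatas in batches:
    print(ids, metadatas)
```

The empty-metadata check is applied per batch, so a later batch can carry `None` even when an earlier batch had real metadata.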
@@ -16,3 +16,4 @@ class BaseRagConfig:
|
||||
embedding_function: Any | None = field(default=None)
|
||||
limit: int = field(default=5)
|
||||
score_threshold: float = field(default=0.6)
|
||||
batch_size: int = field(default=100)
|
||||
|
||||
@@ -1 +1 @@
|
||||
"""Optional imports for RAG configuration providers."""
|
||||
"""Optional imports for RAG configuration providers."""
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Base classes for missing provider configurations."""
|
||||
|
||||
from typing import Literal
|
||||
from dataclasses import field
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import ConfigDict
|
||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.rag.chromadb.client import ChromaDBClient
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""Provider-specific missing configuration classes."""
|
||||
|
||||
from typing import Literal
|
||||
from dataclasses import field
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import ConfigDict
|
||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Type definitions for RAG configuration."""
|
||||
|
||||
from typing import Annotated, TypeAlias, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Annotated, TypeAlias
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.config.constants import DISCRIMINATOR
|
||||
|
||||
@@ -4,14 +4,14 @@ from contextvars import ContextVar
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.utilities.import_utils import require
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.config.types import RagConfigType
|
||||
from crewai.rag.config.constants import (
|
||||
DEFAULT_RAG_CONFIG_PATH,
|
||||
DEFAULT_RAG_CONFIG_CLASS,
|
||||
DEFAULT_RAG_CONFIG_PATH,
|
||||
)
|
||||
from crewai.rag.config.types import RagConfigType
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.factory import create_client
|
||||
from crewai.utilities.import_utils import require
|
||||
|
||||
|
||||
class RagContext(BaseModel):
|
||||
|
||||
@@ -29,7 +29,7 @@ class BaseCollectionParams(TypedDict):
|
||||
]
|
||||
|
||||
|
||||
class BaseCollectionAddParams(BaseCollectionParams):
|
||||
class BaseCollectionAddParams(BaseCollectionParams, total=False):
|
||||
"""Parameters for adding documents to a collection.
|
||||
|
||||
Extends BaseCollectionParams with document-specific fields.
|
||||
@@ -37,9 +37,11 @@ class BaseCollectionAddParams(BaseCollectionParams):
|
||||
Attributes:
|
||||
collection_name: The name of the collection to add documents to.
|
||||
documents: List of BaseRecord dictionaries containing document data.
|
||||
batch_size: Optional batch size for processing documents to avoid token limits.
|
||||
"""
|
||||
|
||||
documents: list[BaseRecord]
|
||||
documents: Required[list[BaseRecord]]
|
||||
batch_size: int
|
||||
|
||||
|
||||
class BaseCollectionSearchParams(BaseCollectionParams, total=False):
|
||||
|
||||
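The hunk above makes `BaseCollectionAddParams` a `total=False` TypedDict so that `batch_size` becomes optional, while `Required[...]` keeps `documents` mandatory. A minimal sketch of how those markers interact, assuming simplified stand-in classes (`CollectionParams`, `AddParams`) and a plain `dict` in place of `BaseRecord`:

```python
try:
    from typing import Required, TypedDict  # Python 3.11+
except ImportError:  # older interpreters: assumes typing_extensions is installed
    from typing_extensions import Required, TypedDict


class CollectionParams(TypedDict):
    """Stand-in for BaseCollectionParams (assumption)."""

    collection_name: str


class AddParams(CollectionParams, total=False):
    """Stand-in for BaseCollectionAddParams after the change."""

    documents: Required[list[dict]]  # still mandatory despite total=False
    batch_size: int  # optional because of total=False


def resolve_batch_size(params: AddParams, default: int = 100) -> int:
    # Same lookup shape as kwargs.get("batch_size", self.default_batch_size).
    return params.get("batch_size", default)


print(sorted(AddParams.__required_keys__))
print(sorted(AddParams.__optional_keys__))
print(resolve_batch_size({"collection_name": "docs", "documents": []}))
```

Keys inherited from the total base class stay required, so only `batch_size` may be omitted by callers.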
@@ -1 +1 @@
|
||||
"""Embedding components for RAG infrastructure."""
|
||||
"""Embedding components for RAG infrastructure."""
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, Dict, Optional, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from chromadb import Documents, EmbeddingFunction, Embeddings
|
||||
from chromadb.api.types import validate_embedding_function
|
||||
@@ -23,7 +23,7 @@ class EmbeddingConfigurator:
|
||||
|
||||
def configure_embedder(
|
||||
self,
|
||||
embedder_config: Optional[Dict[str, Any]] = None,
|
||||
embedder_config: dict[str, Any] | None = None,
|
||||
) -> EmbeddingFunction:
|
||||
"""Configures and returns an embedding function based on the provided config."""
|
||||
if embedder_config is None:
|
||||
@@ -42,9 +42,9 @@ class EmbeddingConfigurator:
|
||||
embedding_function = self.embedding_functions[provider]
|
||||
except ImportError as e:
|
||||
missing_package = str(e).split()[-1]
|
||||
raise ImportError(
|
||||
raise ImportError(
|
||||
f"{missing_package} is not installed. Please install it with: pip install {missing_package}"
|
||||
)
|
||||
) from e
|
||||
|
||||
return (
|
||||
embedding_function(config)
|
||||
@@ -147,7 +147,7 @@ class EmbeddingConfigurator:
|
||||
|
||||
@staticmethod
|
||||
def _configure_voyageai(config, model_name):
|
||||
from chromadb.utils.embedding_functions.voyageai_embedding_function import (
|
||||
from chromadb.utils.embedding_functions.voyageai_embedding_function import ( # type: ignore[import-not-found]
|
||||
VoyageAIEmbeddingFunction,
|
||||
)
|
||||
|
||||
@@ -181,9 +181,11 @@ class EmbeddingConfigurator:
|
||||
@staticmethod
|
||||
def _configure_watson(config, model_name):
|
||||
try:
|
||||
import ibm_watsonx_ai.foundation_models as watson_models
|
||||
from ibm_watsonx_ai import Credentials
|
||||
from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames as EmbedParams
|
||||
import ibm_watsonx_ai.foundation_models as watson_models # type: ignore[import-not-found]
|
||||
from ibm_watsonx_ai import Credentials # type: ignore[import-not-found]
|
||||
from ibm_watsonx_ai.metanames import ( # type: ignore[import-not-found]
|
||||
EmbedTextParamsMetaNames as EmbedParams,
|
||||
)
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
|
||||
@@ -225,7 +227,7 @@ class EmbeddingConfigurator:
|
||||
validate_embedding_function(custom_embedder)
|
||||
return custom_embedder
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid custom embedding function: {str(e)}")
|
||||
raise ValueError(f"Invalid custom embedding function: {e!s}") from e
|
||||
elif callable(custom_embedder):
|
||||
try:
|
||||
instance = custom_embedder()
|
||||
@@ -236,7 +238,7 @@ class EmbeddingConfigurator:
|
||||
"Custom embedder does not create an EmbeddingFunction instance"
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error instantiating custom embedder: {str(e)}")
|
||||
raise ValueError(f"Error instantiating custom embedder: {e!s}") from e
|
||||
else:
|
||||
raise ValueError(
|
||||
"Custom embedder must be an instance of `EmbeddingFunction` or a callable that creates one"
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Minimal embedding function factory for CrewAI."""
|
||||
|
||||
import os
|
||||
from collections.abc import Callable, MutableMapping
|
||||
from typing import Any, Final, Literal, TypedDict
|
||||
|
||||
from chromadb import EmbeddingFunction
|
||||
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
|
||||
@@ -42,19 +44,116 @@ from chromadb.utils.embedding_functions.sentence_transformer_embedding_function
|
||||
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
|
||||
Text2VecEmbeddingFunction,
|
||||
)
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from crewai.rag.embeddings.types import EmbeddingOptions
|
||||
|
||||
AllowedEmbeddingProviders = Literal[
|
||||
"openai",
|
||||
"cohere",
|
||||
"ollama",
|
||||
"huggingface",
|
||||
"sentence-transformer",
|
||||
"instructor",
|
||||
"google-palm",
|
||||
"google-generativeai",
|
||||
"google-vertex",
|
||||
"amazon-bedrock",
|
||||
"jina",
|
||||
"roboflow",
|
||||
"openclip",
|
||||
"text2vec",
|
||||
"onnx",
|
||||
]
|
||||
|
||||
|
||||
class EmbedderConfig(TypedDict):
|
||||
"""Configuration for embedding functions with nested format."""
|
||||
|
||||
provider: AllowedEmbeddingProviders
|
||||
config: NotRequired[dict[str, Any]]
|
||||
|
||||
|
||||
EMBEDDING_PROVIDERS: Final[
|
||||
dict[AllowedEmbeddingProviders, Callable[..., EmbeddingFunction]]
|
||||
] = {
|
||||
"openai": OpenAIEmbeddingFunction,
|
||||
"cohere": CohereEmbeddingFunction,
|
||||
"ollama": OllamaEmbeddingFunction,
|
||||
"huggingface": HuggingFaceEmbeddingFunction,
|
||||
"sentence-transformer": SentenceTransformerEmbeddingFunction,
|
||||
"instructor": InstructorEmbeddingFunction,
|
||||
"google-palm": GooglePalmEmbeddingFunction,
|
||||
"google-generativeai": GoogleGenerativeAiEmbeddingFunction,
|
||||
"google-vertex": GoogleVertexEmbeddingFunction,
|
||||
"amazon-bedrock": AmazonBedrockEmbeddingFunction,
|
||||
"jina": JinaEmbeddingFunction,
|
||||
"roboflow": RoboflowEmbeddingFunction,
|
||||
"openclip": OpenCLIPEmbeddingFunction,
|
||||
"text2vec": Text2VecEmbeddingFunction,
|
||||
"onnx": ONNXMiniLM_L6_V2,
|
||||
}
|
||||
|
||||
PROVIDER_ENV_MAPPING: Final[dict[AllowedEmbeddingProviders, tuple[str, str]]] = {
|
||||
"openai": ("OPENAI_API_KEY", "api_key"),
|
||||
"cohere": ("COHERE_API_KEY", "api_key"),
|
||||
"huggingface": ("HUGGINGFACE_API_KEY", "api_key"),
|
||||
"google-palm": ("GOOGLE_API_KEY", "api_key"),
|
||||
"google-generativeai": ("GOOGLE_API_KEY", "api_key"),
|
||||
"google-vertex": ("GOOGLE_API_KEY", "api_key"),
|
||||
"jina": ("JINA_API_KEY", "api_key"),
|
||||
"roboflow": ("ROBOFLOW_API_KEY", "api_key"),
|
||||
}
|
||||
|
||||
|
||||
def _inject_api_key_from_env(
    provider: AllowedEmbeddingProviders, config_dict: MutableMapping[str, Any]
) -> None:
    """Inject API key or other required configuration from environment if not explicitly provided.

    Args:
        provider: The embedding provider name
        config_dict: The configuration dictionary to modify in-place

    Raises:
        ImportError: If required libraries for certain providers are not installed
        ValueError: If AWS session creation fails for amazon-bedrock
    """
    if provider in PROVIDER_ENV_MAPPING:
        env_var_name, config_key = PROVIDER_ENV_MAPPING[provider]
        if config_key not in config_dict:
            env_value = os.getenv(env_var_name)
            if env_value:
                config_dict[config_key] = env_value

    if provider == "amazon-bedrock":
        if "session" not in config_dict:
            try:
                import boto3  # type: ignore[import]

                config_dict["session"] = boto3.Session()
            except ImportError as e:
                raise ImportError(
                    "boto3 is required for amazon-bedrock embeddings. "
                    "Install it with: uv add boto3"
                ) from e
            except Exception as e:
                raise ValueError(
                    f"Failed to create AWS session for amazon-bedrock. "
                    f"Ensure AWS credentials are configured. Error: {e}"
                ) from e
|
||||
|
||||
|
||||
def get_embedding_function(
|
||||
config: EmbeddingOptions | dict | None = None,
|
||||
config: EmbeddingOptions | EmbedderConfig | None = None,
|
||||
) -> EmbeddingFunction:
|
||||
"""Get embedding function - delegates to ChromaDB.
|
||||
|
||||
Args:
|
||||
config: Optional configuration - either an EmbeddingOptions object or a dict with:
|
||||
- provider: The embedding provider to use (default: "openai")
|
||||
- Any other provider-specific parameters
|
||||
config: Optional configuration - either:
|
||||
- EmbeddingOptions: Pydantic model with flat configuration
|
||||
- EmbedderConfig: TypedDict with nested format {"provider": str, "config": dict}
|
||||
- None: Uses default OpenAI configuration
|
||||
|
||||
Returns:
|
||||
EmbeddingFunction instance ready for use with ChromaDB
|
||||
@@ -81,31 +180,33 @@ def get_embedding_function(
|
||||
>>> embedder = get_embedding_function()
|
||||
|
||||
# Use Cohere with dict
|
||||
>>> embedder = get_embedding_function({
|
||||
>>> embedder = get_embedding_function(EmbedderConfig(**{
|
||||
... "provider": "cohere",
|
||||
... "api_key": "your-key",
|
||||
... "model_name": "embed-english-v3.0"
|
||||
... })
|
||||
... "config": {
|
||||
... "api_key": "your-key",
|
||||
... "model_name": "embed-english-v3.0"
|
||||
... }
|
||||
... }))
|
||||
|
||||
# Use with EmbeddingOptions
|
||||
>>> embedder = get_embedding_function(
|
||||
... EmbeddingOptions(provider="sentence-transformer", model_name="all-MiniLM-L6-v2")
|
||||
... )
|
||||
|
||||
# Use local sentence transformers (no API key needed)
|
||||
>>> embedder = get_embedding_function({
|
||||
... "provider": "sentence-transformer",
|
||||
... "model_name": "all-MiniLM-L6-v2"
|
||||
# Use Azure OpenAI
|
||||
>>> embedder = get_embedding_function(EmbedderConfig(**{
|
||||
... "provider": "openai",
|
||||
... "config": {
|
||||
... "api_key": "your-azure-key",
|
||||
... "api_base": "https://your-resource.openai.azure.com/",
|
||||
... "api_type": "azure",
|
||||
... "api_version": "2023-05-15",
|
||||
... "model": "text-embedding-3-small",
|
||||
... "deployment_id": "your-deployment-name"
|
||||
... }
|
||||
... })
|
||||
|
||||
# Use Ollama for local embeddings
|
||||
>>> embedder = get_embedding_function({
|
||||
... "provider": "ollama",
|
||||
... "model_name": "nomic-embed-text"
|
||||
... })
|
||||
|
||||
# Use ONNX (no API key needed)
|
||||
>>> embedder = get_embedding_function({
|
||||
>>> embedder = get_embedding_function(EmbedderConfig(**{
|
||||
... "provider": "onnx"
|
||||
... })
|
||||
"""
|
||||
@@ -114,35 +215,35 @@ def get_embedding_function(
|
||||
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
|
||||
)
|
||||
|
||||
# Handle EmbeddingOptions object
|
||||
provider: AllowedEmbeddingProviders
|
||||
config_dict: dict[str, Any]
|
||||
|
||||
if isinstance(config, EmbeddingOptions):
|
||||
config_dict = config.model_dump(exclude_none=True)
|
||||
provider = config_dict["provider"]
|
||||
else:
|
||||
config_dict = config.copy()
|
||||
provider = config["provider"]
|
||||
nested: dict[str, Any] = config.get("config", {})
|
||||
|
||||
provider = config_dict.pop("provider", "openai")
|
||||
if not nested and len(config) > 1:
|
||||
raise ValueError(
|
||||
"Invalid embedder configuration format. "
|
||||
"Configuration must be nested under a 'config' key. "
|
||||
"Example: {'provider': 'openai', 'config': {'api_key': '...', 'model': '...'}}"
|
||||
)
|
||||
|
||||
embedding_functions = {
|
||||
"openai": OpenAIEmbeddingFunction,
|
||||
"cohere": CohereEmbeddingFunction,
|
||||
"ollama": OllamaEmbeddingFunction,
|
||||
"huggingface": HuggingFaceEmbeddingFunction,
|
||||
"sentence-transformer": SentenceTransformerEmbeddingFunction,
|
||||
"instructor": InstructorEmbeddingFunction,
|
||||
"google-palm": GooglePalmEmbeddingFunction,
|
||||
"google-generativeai": GoogleGenerativeAiEmbeddingFunction,
|
||||
"google-vertex": GoogleVertexEmbeddingFunction,
|
||||
"amazon-bedrock": AmazonBedrockEmbeddingFunction,
|
||||
"jina": JinaEmbeddingFunction,
|
||||
"roboflow": RoboflowEmbeddingFunction,
|
||||
"openclip": OpenCLIPEmbeddingFunction,
|
||||
"text2vec": Text2VecEmbeddingFunction,
|
||||
"onnx": ONNXMiniLM_L6_V2,
|
||||
}
|
||||
config_dict = dict(nested)
|
||||
if "model" in config_dict and "model_name" not in config_dict:
|
||||
config_dict["model_name"] = config_dict.pop("model")
|
||||
|
||||
if provider not in embedding_functions:
|
||||
if provider not in EMBEDDING_PROVIDERS:
|
||||
raise ValueError(
|
||||
f"Unsupported provider: {provider}. "
|
||||
f"Available providers: {list(embedding_functions.keys())}"
|
||||
f"Available providers: {list(EMBEDDING_PROVIDERS.keys())}"
|
||||
)
|
||||
return embedding_functions[provider](**config_dict)
|
||||
|
||||
_inject_api_key_from_env(provider, config_dict)
|
||||
|
||||
config_dict.pop("batch_size", None)
|
||||
|
||||
return EMBEDDING_PROVIDERS[provider](**config_dict)
|
||||
|
||||
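The reworked `get_embedding_function` above expects the nested `{"provider": ..., "config": {...}}` format, renames `model` to `model_name`, fills a missing API key from the environment, and removes `batch_size` before calling the provider class. A minimal sketch of that normalization without ChromaDB, assuming a hypothetical `normalize_embedder_config` helper and a one-entry `ENV_KEYS` table standing in for `PROVIDER_ENV_MAPPING`:

```python
import os
from typing import Any

# One-entry stand-in for PROVIDER_ENV_MAPPING (assumption).
ENV_KEYS = {"openai": ("OPENAI_API_KEY", "api_key")}


def normalize_embedder_config(config: dict[str, Any]) -> tuple[str, dict[str, Any]]:
    """Return (provider, constructor kwargs) from a nested embedder config."""
    provider = config["provider"]
    nested = config.get("config", {})

    # Flat configs such as {"provider": ..., "api_key": ...} are rejected.
    if not nested and len(config) > 1:
        raise ValueError("Configuration must be nested under a 'config' key.")

    params = dict(nested)
    if "model" in params and "model_name" not in params:
        params["model_name"] = params.pop("model")

    if provider in ENV_KEYS:
        env_var, key = ENV_KEYS[provider]
        env_value = os.getenv(env_var)
        if key not in params and env_value:
            params[key] = env_value

    # batch_size is consumed by the storage layer, not the embedder.
    params.pop("batch_size", None)
    return provider, params


os.environ["OPENAI_API_KEY"] = "sk-test"  # placeholder value for the example
provider, params = normalize_embedder_config(
    {"provider": "openai", "config": {"model": "text-embedding-3-small", "batch_size": 50}}
)
print(provider, params)
```

An explicitly supplied `api_key` always wins over the environment, since the injection only runs when the key is absent from the nested config.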
@@ -1,11 +1,11 @@
|
||||
"""Type definitions for the embeddings module."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from crewai.rag.types import EmbeddingFunction
|
||||
|
||||
|
||||
EmbeddingProvider = Literal[
|
||||
"openai",
|
||||
"cohere",
|
||||
|
||||
@@ -6,8 +6,8 @@ from crewai.rag.config.optional_imports.protocols import (
|
||||
ChromaFactoryModule,
|
||||
QdrantFactoryModule,
|
||||
)
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.config.types import RagConfigType
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.utilities.import_utils import require
|
||||
|
||||
|
||||
@@ -43,3 +43,5 @@ def create_client(config: RagConfigType) -> BaseClient:
|
||||
),
|
||||
)
|
||||
return qdrant_mod.create_client(config)
|
||||
|
||||
raise ValueError(f"Unsupported provider: {config.provider}")
|
||||
|
||||
@@ -1 +1 @@
|
||||
"""Qdrant vector database client implementation."""
|
||||
"""Qdrant vector database client implementation."""
|
||||
|
||||
@@ -48,6 +48,7 @@ class QdrantClient(BaseClient):
|
||||
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction,
|
||||
default_limit: int = 5,
|
||||
default_score_threshold: float = 0.6,
|
||||
default_batch_size: int = 100,
|
||||
) -> None:
|
||||
"""Initialize QdrantClient with client and embedding function.
|
||||
|
||||
@@ -56,11 +57,13 @@ class QdrantClient(BaseClient):
|
||||
embedding_function: Embedding function for text to vector conversion.
|
||||
default_limit: Default number of results to return in searches.
|
||||
default_score_threshold: Default minimum score for search results.
|
||||
default_batch_size: Default batch size for adding documents.
|
||||
"""
|
||||
self.client = client
|
||||
self.embedding_function = embedding_function
|
||||
self.default_limit = default_limit
|
||||
self.default_score_threshold = default_score_threshold
|
||||
self.default_batch_size = default_batch_size
|
||||
|
||||
def create_collection(self, **kwargs: Unpack[QdrantCollectionCreateParams]) -> None:
|
||||
"""Create a new collection in Qdrant.
|
||||
@@ -234,6 +237,7 @@ class QdrantClient(BaseClient):
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to add documents to.
|
||||
documents: List of BaseRecord dicts containing document data.
|
||||
batch_size: Optional batch size for processing documents (default: 100)
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist or documents list is empty.
|
||||
@@ -249,6 +253,7 @@ class QdrantClient(BaseClient):
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
documents = kwargs["documents"]
|
||||
batch_size = kwargs.get("batch_size", self.default_batch_size)
|
||||
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
@@ -256,19 +261,20 @@ class QdrantClient(BaseClient):
         if not self.client.collection_exists(collection_name):
             raise ValueError(f"Collection '{collection_name}' does not exist")
 
-        points = []
-        for doc in documents:
-            if _is_async_embedding_function(self.embedding_function):
-                raise TypeError(
-                    "Async embedding function cannot be used with sync add_documents. "
-                    "Use aadd_documents instead."
-                )
-            sync_fn = cast(EmbeddingFunction, self.embedding_function)
-            embedding = sync_fn(doc["content"])
-            point = _create_point_from_document(doc, embedding)
-            points.append(point)
-
-        self.client.upsert(collection_name=collection_name, points=points)
+        for i in range(0, len(documents), batch_size):
+            batch_docs = documents[i : min(i + batch_size, len(documents))]
+            points = []
+            for doc in batch_docs:
+                if _is_async_embedding_function(self.embedding_function):
+                    raise TypeError(
+                        "Async embedding function cannot be used with sync add_documents. "
+                        "Use aadd_documents instead."
+                    )
+                sync_fn = cast(EmbeddingFunction, self.embedding_function)
+                embedding = sync_fn(doc["content"])
+                point = _create_point_from_document(doc, embedding)
+                points.append(point)
+            self.client.upsert(collection_name=collection_name, points=points)
 
     async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
         """Add documents with their embeddings to a collection asynchronously.
|
||||
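Review note: the batching introduced in this hunk can be sketched in isolation as follows. This is a minimal illustration, not the project's code: `upsert_in_batches`, `embed`, and `upsert` are hypothetical names standing in for the client method, the embedding function, and `self.client.upsert`.

```python
from collections.abc import Callable, Sequence
from typing import Any


def upsert_in_batches(
    documents: Sequence[dict[str, Any]],
    embed: Callable[[str], list[float]],
    upsert: Callable[[list[tuple[dict[str, Any], list[float]]]], None],
    batch_size: int = 100,
) -> int:
    """Embed and upsert documents in fixed-size batches.

    Returns the number of upsert calls made. All names here are
    hypothetical; the real method builds Qdrant points instead of tuples.
    """
    if not documents:
        raise ValueError("Documents list cannot be empty")
    if batch_size <= 0:
        # range() rejects a step of 0, so fail early with a clearer message.
        raise ValueError("batch_size must be a positive integer")

    calls = 0
    for start in range(0, len(documents), batch_size):
        # Slicing clamps at the end of the sequence, so no min() is needed.
        batch = documents[start : start + batch_size]
        points = [(doc, embed(doc["content"])) for doc in batch]
        upsert(points)
        calls += 1
    return calls
```

Two things may be worth checking against the hunk: the `min(i + batch_size, len(documents))` bound is redundant because slices already clamp, and a caller-supplied `batch_size` of `0` would make `range()` raise `ValueError`, which the sketch guards explicitly.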
@@ -276,6 +282,7 @@ class QdrantClient(BaseClient):
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to add documents to.
|
||||
documents: List of BaseRecord dicts containing document data.
|
||||
batch_size: Optional batch size for processing documents (default: 100)
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist or documents list is empty.
|
||||
@@ -291,6 +298,7 @@ class QdrantClient(BaseClient):
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
documents = kwargs["documents"]
|
||||
batch_size = kwargs.get("batch_size", self.default_batch_size)
|
||||
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
@@ -298,18 +306,19 @@ class QdrantClient(BaseClient):
         if not await self.client.collection_exists(collection_name):
             raise ValueError(f"Collection '{collection_name}' does not exist")
 
-        points = []
-        for doc in documents:
-            if _is_async_embedding_function(self.embedding_function):
-                async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
-                embedding = await async_fn(doc["content"])
-            else:
-                sync_fn = cast(EmbeddingFunction, self.embedding_function)
-                embedding = sync_fn(doc["content"])
-            point = _create_point_from_document(doc, embedding)
-            points.append(point)
-
-        await self.client.upsert(collection_name=collection_name, points=points)
+        for i in range(0, len(documents), batch_size):
+            batch_docs = documents[i : min(i + batch_size, len(documents))]
+            points = []
+            for doc in batch_docs:
+                if _is_async_embedding_function(self.embedding_function):
+                    async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
+                    embedding = await async_fn(doc["content"])
+                else:
+                    sync_fn = cast(EmbeddingFunction, self.embedding_function)
+                    embedding = sync_fn(doc["content"])
+                point = _create_point_from_document(doc, embedding)
+                points.append(point)
+            await self.client.upsert(collection_name=collection_name, points=points)
 
     def search(
         self, **kwargs: Unpack[BaseCollectionSearchParams]
|
||||
|
||||
@@ -2,11 +2,12 @@
|
||||
|
||||
from dataclasses import field
|
||||
from typing import Literal, cast
|
||||
|
||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||
|
||||
from crewai.rag.config.base import BaseRagConfig
|
||||
from crewai.rag.qdrant.types import QdrantClientParams, QdrantEmbeddingFunctionWrapper
|
||||
from crewai.rag.qdrant.constants import DEFAULT_EMBEDDING_MODEL, DEFAULT_STORAGE_PATH
|
||||
from crewai.rag.qdrant.types import QdrantClientParams, QdrantEmbeddingFunctionWrapper
|
||||
|
||||
|
||||
def _default_options() -> QdrantClientParams:
|
||||
@@ -24,7 +25,7 @@ def _default_embedding_function() -> QdrantEmbeddingFunctionWrapper:
|
||||
Returns:
|
||||
Default embedding function using fastembed with all-MiniLM-L6-v2.
|
||||
"""
|
||||
from fastembed import TextEmbedding
|
||||
from fastembed import TextEmbedding # type: ignore[import-not-found]
|
||||
|
||||
model = TextEmbedding(model_name=DEFAULT_EMBEDDING_MODEL)
|
||||
|
||||
|
||||
@@ -22,4 +22,5 @@ def create_client(config: QdrantConfig) -> QdrantClient:
|
||||
embedding_function=config.embedding_function,
|
||||
default_limit=config.limit,
|
||||
default_score_threshold=config.score_threshold,
|
||||
default_batch_size=config.batch_size,
|
||||
)
|
||||
|
||||
@@ -2,13 +2,15 @@
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Annotated, Any, Protocol, TypeAlias
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
import numpy as np
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient
|
||||
from qdrant_client.models import (
|
||||
from qdrant_client import AsyncQdrantClient # type: ignore[import-not-found]
|
||||
from qdrant_client import (
|
||||
QdrantClient as SyncQdrantClient, # type: ignore[import-not-found]
|
||||
)
|
||||
from qdrant_client.models import ( # type: ignore[import-not-found]
|
||||
FieldCondition,
|
||||
Filter,
|
||||
HasIdCondition,
|
||||
@@ -25,6 +27,7 @@ from qdrant_client.models import (
|
||||
VectorsConfig,
|
||||
WalConfigDiff,
|
||||
)
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from crewai.rag.core.base_client import BaseCollectionParams
|
||||
|
||||
@@ -134,8 +137,6 @@ class QdrantCollectionCreateParams(
|
||||
):
|
||||
"""High-level parameters for creating a Qdrant collection."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CreateCollectionParams(CommonCreateFields, total=False):
|
||||
"""Parameters for qdrant_client.create_collection."""
|
||||
|
||||
@@ -4,8 +4,11 @@ import asyncio
|
||||
from typing import TypeGuard
|
||||
from uuid import uuid4
|
||||
|
||||
from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient
|
||||
from qdrant_client.models import (
|
||||
from qdrant_client import AsyncQdrantClient # type: ignore[import-not-found]
|
||||
from qdrant_client import (
|
||||
QdrantClient as SyncQdrantClient, # type: ignore[import-not-found]
|
||||
)
|
||||
from qdrant_client.models import ( # type: ignore[import-not-found]
|
||||
FieldCondition,
|
||||
Filter,
|
||||
MatchValue,
|
||||
@@ -25,7 +28,7 @@ from crewai.rag.qdrant.types import (
|
||||
QdrantCollectionCreateParams,
|
||||
QueryEmbedding,
|
||||
)
|
||||
from crewai.rag.types import SearchResult, BaseRecord
|
||||
from crewai.rag.types import BaseRecord, SearchResult
|
||||
|
||||
|
||||
def _ensure_list_embedding(embedding: QueryEmbedding) -> list[float]:
|
||||
@@ -38,7 +41,8 @@ def _ensure_list_embedding(embedding: QueryEmbedding) -> list[float]:
         Embedding as list[float].
     """
     if not isinstance(embedding, list):
-        return embedding.tolist()
+        result = embedding.tolist()
+        return result if isinstance(result, list) else [result]
     return embedding
 
 
|
||||
@@ -1 +1 @@
|
||||
"""Storage components for RAG infrastructure."""
|
||||
"""Storage components for RAG infrastructure."""
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from crewai.rag.embeddings.factory import EmbedderConfig
|
||||
from crewai.rag.embeddings.types import EmbeddingOptions
|
||||
|
||||
|
||||
class BaseRAGStorage(ABC):
|
||||
"""
|
||||
@@ -13,7 +16,7 @@ class BaseRAGStorage(ABC):
|
||||
self,
|
||||
type: str,
|
||||
allow_reset: bool = True,
|
||||
embedder_config: dict[str, Any] | None = None,
|
||||
embedder_config: EmbeddingOptions | EmbedderConfig | None = None,
|
||||
crew: Any = None,
|
||||
):
|
||||
self.type = type
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Type definitions for RAG (Retrieval-Augmented Generation) systems."""
|
||||
|
||||
from collections.abc import Callable, Mapping
|
||||
from typing import TypeAlias, Any
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
@@ -5,20 +5,14 @@ import logging
|
||||
import threading
|
||||
import uuid
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import Future
|
||||
from copy import copy
|
||||
from hashlib import md5
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
get_args,
|
||||
get_origin,
|
||||
@@ -35,20 +29,20 @@ from pydantic import (
|
||||
from pydantic_core import PydanticCustomError
|
||||
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.event_types import (
|
||||
TaskCompletedEvent,
|
||||
TaskFailedEvent,
|
||||
TaskStartedEvent,
|
||||
)
|
||||
from crewai.security import Fingerprint, SecurityConfig
|
||||
from crewai.tasks.output_format import OutputFormat
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.utilities.config import process_config
|
||||
from crewai.utilities.constants import NOT_SPECIFIED, _NotSpecified
|
||||
from crewai.utilities.guardrail import process_guardrail, GuardrailResult
|
||||
from crewai.utilities.converter import Converter, convert_to_model
|
||||
from crewai.events.event_types import (
|
||||
TaskCompletedEvent,
|
||||
TaskFailedEvent,
|
||||
TaskStartedEvent,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.utilities.guardrail import process_guardrail
|
||||
from crewai.utilities.i18n import I18N
|
||||
from crewai.utilities.printer import Printer
|
||||
from crewai.utilities.string_utils import interpolate_only
|
||||
@@ -85,50 +79,50 @@ class Task(BaseModel):
|
||||
tools_errors: int = 0
|
||||
delegations: int = 0
|
||||
i18n: I18N = I18N()
|
||||
name: Optional[str] = Field(default=None)
|
||||
prompt_context: Optional[str] = None
|
||||
name: str | None = Field(default=None)
|
||||
prompt_context: str | None = None
|
||||
description: str = Field(description="Description of the actual task.")
|
||||
expected_output: str = Field(
|
||||
description="Clear definition of expected output for the task."
|
||||
)
|
||||
config: Optional[Dict[str, Any]] = Field(
|
||||
config: dict[str, Any] | None = Field(
|
||||
description="Configuration for the agent",
|
||||
default=None,
|
||||
)
|
||||
callback: Optional[Any] = Field(
|
||||
callback: Any | None = Field(
|
||||
description="Callback to be executed after the task is completed.", default=None
|
||||
)
|
||||
agent: Optional[BaseAgent] = Field(
|
||||
agent: BaseAgent | None = Field(
|
||||
description="Agent responsible for execution the task.", default=None
|
||||
)
|
||||
context: Union[List["Task"], None, _NotSpecified] = Field(
|
||||
context: list["Task"] | None | _NotSpecified = Field(
|
||||
description="Other tasks that will have their output used as context for this task.",
|
||||
default=NOT_SPECIFIED,
|
||||
)
|
||||
async_execution: Optional[bool] = Field(
|
||||
async_execution: bool | None = Field(
|
||||
description="Whether the task should be executed asynchronously or not.",
|
||||
default=False,
|
||||
)
|
||||
output_json: Optional[Type[BaseModel]] = Field(
|
||||
output_json: type[BaseModel] | None = Field(
|
||||
description="A Pydantic model to be used to create a JSON output.",
|
||||
default=None,
|
||||
)
|
||||
output_pydantic: Optional[Type[BaseModel]] = Field(
|
||||
output_pydantic: type[BaseModel] | None = Field(
|
||||
description="A Pydantic model to be used to create a Pydantic output.",
|
||||
default=None,
|
||||
)
|
||||
output_file: Optional[str] = Field(
|
||||
output_file: str | None = Field(
|
||||
description="A file path to be used to create a file output.",
|
||||
default=None,
|
||||
)
|
||||
create_directory: Optional[bool] = Field(
|
||||
create_directory: bool | None = Field(
|
||||
description="Whether to create the directory for output_file if it doesn't exist.",
|
||||
default=True,
|
||||
)
|
||||
output: Optional[TaskOutput] = Field(
|
||||
output: TaskOutput | None = Field(
|
||||
description="Task output, it's final result after being executed", default=None
|
||||
)
|
||||
tools: Optional[List[BaseTool]] = Field(
|
||||
tools: list[BaseTool] | None = Field(
|
||||
default_factory=list,
|
||||
description="Tools the agent is limited to use for this task.",
|
||||
)
|
||||
@@ -141,24 +135,24 @@ class Task(BaseModel):
|
||||
frozen=True,
|
||||
description="Unique identifier for the object, not set by user.",
|
||||
)
|
||||
human_input: Optional[bool] = Field(
|
||||
human_input: bool | None = Field(
|
||||
description="Whether the task should have a human review the final answer of the agent",
|
||||
default=False,
|
||||
)
|
||||
markdown: Optional[bool] = Field(
|
||||
markdown: bool | None = Field(
|
||||
description="Whether the task should instruct the agent to return the final answer formatted in Markdown",
|
||||
default=False,
|
||||
)
|
||||
converter_cls: Optional[Type[Converter]] = Field(
|
||||
converter_cls: type[Converter] | None = Field(
|
||||
description="A converter class used to export structured output",
|
||||
default=None,
|
||||
)
|
||||
processed_by_agents: Set[str] = Field(default_factory=set)
|
||||
guardrail: Optional[Union[Callable[[TaskOutput], Tuple[bool, Any]], str]] = Field(
|
||||
processed_by_agents: set[str] = Field(default_factory=set)
|
||||
guardrail: Callable[[TaskOutput], tuple[bool, Any]] | str | None = Field(
|
||||
default=None,
|
||||
description="Function or string description of a guardrail to validate task output before proceeding to next task",
|
||||
)
|
||||
max_retries: Optional[int] = Field(
|
||||
max_retries: int | None = Field(
|
||||
default=None,
|
||||
description="[DEPRECATED] Maximum number of retries when guardrail fails. Use guardrail_max_retries instead. Will be removed in v1.0.0",
|
||||
)
|
||||
@@ -166,13 +160,13 @@ class Task(BaseModel):
|
||||
default=3, description="Maximum number of retries when guardrail fails"
|
||||
)
|
||||
retry_count: int = Field(default=0, description="Current number of retries")
|
||||
start_time: Optional[datetime.datetime] = Field(
|
||||
start_time: datetime.datetime | None = Field(
|
||||
default=None, description="Start time of the task execution"
|
||||
)
|
||||
end_time: Optional[datetime.datetime] = Field(
|
||||
end_time: datetime.datetime | None = Field(
|
||||
default=None, description="End time of the task execution"
|
||||
)
|
||||
allow_crewai_trigger_context: Optional[bool] = Field(
|
||||
allow_crewai_trigger_context: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether this task should append 'Trigger Payload: {crewai_trigger_payload}' to the task description when crewai_trigger_payload exists in crew inputs.",
|
||||
)
|
||||
@@ -181,8 +175,8 @@ class Task(BaseModel):
|
||||
@field_validator("guardrail")
|
||||
@classmethod
|
||||
def validate_guardrail_function(
|
||||
cls, v: Optional[str | Callable]
|
||||
) -> Optional[str | Callable]:
|
||||
cls, v: str | Callable | None
|
||||
) -> str | Callable | None:
|
||||
"""
|
||||
If v is a callable, validate that the guardrail function has the correct signature and behavior.
|
||||
If v is a string, return it as is.
|
||||
@@ -229,7 +223,7 @@ class Task(BaseModel):
|
||||
return_annotation_args[1] is Any
|
||||
or return_annotation_args[1] is str
|
||||
or return_annotation_args[1] is TaskOutput
|
||||
or return_annotation_args[1] == Union[str, TaskOutput]
|
||||
or return_annotation_args[1] == str | TaskOutput
|
||||
)
|
||||
):
|
||||
raise ValueError(
|
||||
@@ -237,11 +231,11 @@ class Task(BaseModel):
|
||||
)
|
||||
return v
|
||||
|
||||
_guardrail: Optional[Callable] = PrivateAttr(default=None)
|
||||
_original_description: Optional[str] = PrivateAttr(default=None)
|
||||
_original_expected_output: Optional[str] = PrivateAttr(default=None)
|
||||
_original_output_file: Optional[str] = PrivateAttr(default=None)
|
||||
_thread: Optional[threading.Thread] = PrivateAttr(default=None)
|
||||
_guardrail: Callable | None = PrivateAttr(default=None)
|
||||
_original_description: str | None = PrivateAttr(default=None)
|
||||
_original_expected_output: str | None = PrivateAttr(default=None)
|
||||
_original_output_file: str | None = PrivateAttr(default=None)
|
||||
_thread: threading.Thread | None = PrivateAttr(default=None)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@@ -265,7 +259,9 @@ class Task(BaseModel):
|
||||
elif isinstance(self.guardrail, str):
|
||||
from crewai.tasks.llm_guardrail import LLMGuardrail
|
||||
|
||||
assert self.agent is not None
|
||||
if self.agent is None:
|
||||
raise ValueError("Agent is required to use LLMGuardrail")
|
||||
|
||||
self._guardrail = LLMGuardrail(
|
||||
description=self.guardrail, llm=self.agent.llm
|
||||
)
|
||||
@@ -274,7 +270,7 @@ class Task(BaseModel):
|
||||
|
||||
@field_validator("id", mode="before")
|
||||
@classmethod
|
||||
def _deny_user_set_id(cls, v: Optional[UUID4]) -> None:
|
||||
def _deny_user_set_id(cls, v: UUID4 | None) -> None:
|
||||
if v:
|
||||
raise PydanticCustomError(
|
||||
"may_not_set_field", "This field is not to be set by the user.", {}
|
||||
@@ -282,7 +278,7 @@ class Task(BaseModel):
|
||||
|
||||
@field_validator("output_file")
|
||||
@classmethod
|
||||
def output_file_validation(cls, value: Optional[str]) -> Optional[str]:
|
||||
def output_file_validation(cls, value: str | None) -> str | None:
|
||||
"""Validate the output file path.
|
||||
|
||||
Args:
|
||||
@@ -307,7 +303,7 @@ class Task(BaseModel):
             )
 
         # Check for shell expansion first
-        if value.startswith("~") or value.startswith("$"):
+        if value.startswith(("~", "$")):
             raise ValueError(
                 "Shell expansion characters are not allowed in output_file paths"
             )
|
||||
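Review note: the rewritten check relies on `str.startswith` accepting a tuple of prefixes. A minimal sketch of the equivalence, using a hypothetical helper name (`has_shell_expansion_prefix` is not part of the codebase):

```python
def has_shell_expansion_prefix(path: str) -> bool:
    """Return True when the path begins with a shell-expansion character."""
    # str.startswith returns True if the string starts with any tuple member.
    return path.startswith(("~", "$"))


def has_shell_expansion_prefix_verbose(path: str) -> bool:
    """The pre-change form of the same check, kept for comparison."""
    return path.startswith("~") or path.startswith("$")
```

Both helpers should agree on every input; the tuple form only removes the repeated call.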
@@ -373,9 +369,9 @@ class Task(BaseModel):
|
||||
|
||||
def execute_sync(
|
||||
self,
|
||||
agent: Optional[BaseAgent] = None,
|
||||
context: Optional[str] = None,
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
agent: BaseAgent | None = None,
|
||||
context: str | None = None,
|
||||
tools: list[BaseTool] | None = None,
|
||||
) -> TaskOutput:
|
||||
"""Execute the task synchronously."""
|
||||
return self._execute_core(agent, context, tools)
|
||||
@@ -397,8 +393,8 @@ class Task(BaseModel):
|
||||
def execute_async(
|
||||
self,
|
||||
agent: BaseAgent | None = None,
|
||||
context: Optional[str] = None,
|
||||
tools: Optional[List[BaseTool]] = None,
|
||||
context: str | None = None,
|
||||
tools: list[BaseTool] | None = None,
|
||||
) -> Future[TaskOutput]:
|
||||
"""Execute the task asynchronously."""
|
||||
future: Future[TaskOutput] = Future()
|
||||
@@ -411,9 +407,9 @@ class Task(BaseModel):
|
||||
|
||||
def _execute_task_async(
|
||||
self,
|
||||
agent: Optional[BaseAgent],
|
||||
context: Optional[str],
|
||||
tools: Optional[List[Any]],
|
||||
agent: BaseAgent | None,
|
||||
context: str | None,
|
||||
tools: list[Any] | None,
|
||||
future: Future[TaskOutput],
|
||||
) -> None:
|
||||
"""Execute the task asynchronously with context handling."""
|
||||
@@ -422,9 +418,9 @@ class Task(BaseModel):
|
||||
|
||||
def _execute_core(
|
||||
self,
|
||||
agent: Optional[BaseAgent],
|
||||
context: Optional[str],
|
||||
tools: Optional[List[Any]],
|
||||
agent: BaseAgent | None,
|
||||
context: str | None,
|
||||
tools: list[Any] | None,
|
||||
) -> TaskOutput:
|
||||
"""Run the core execution logic of the task."""
|
||||
try:
|
||||
@@ -465,6 +461,7 @@ class Task(BaseModel):
|
||||
output=task_output,
|
||||
guardrail=self._guardrail,
|
||||
retry_count=self.retry_count,
|
||||
event_source=self,
|
||||
)
|
||||
if not guardrail_result.success:
|
||||
if self.retry_count >= self.guardrail_max_retries:
|
||||
@@ -528,41 +525,6 @@ class Task(BaseModel):
|
||||
crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self))
|
||||
raise e # Re-raise the exception after emitting the event
|
||||
|
||||
def _process_guardrail(self, task_output: TaskOutput) -> GuardrailResult:
|
||||
assert self._guardrail is not None
|
||||
|
||||
from crewai.events.event_types import (
|
||||
LLMGuardrailCompletedEvent,
|
||||
LLMGuardrailStartedEvent,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
LLMGuardrailStartedEvent(
|
||||
guardrail=self._guardrail, retry_count=self.retry_count
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
result = self._guardrail(task_output)
|
||||
guardrail_result = GuardrailResult.from_tuple(result)
|
||||
except Exception as e:
|
||||
guardrail_result = GuardrailResult(
|
||||
success=False, result=None, error=f"Guardrail execution error: {str(e)}"
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
LLMGuardrailCompletedEvent(
|
||||
success=guardrail_result.success,
|
||||
result=guardrail_result.result,
|
||||
error=guardrail_result.error,
|
||||
retry_count=self.retry_count,
|
||||
),
|
||||
)
|
||||
return guardrail_result
|
||||
|
||||
def prompt(self) -> str:
|
||||
"""Generates the task prompt with optional markdown formatting.
|
||||
|
||||
@@ -604,7 +566,7 @@ Follow these guidelines:
|
||||
return "\n".join(tasks_slices)
|
||||
|
||||
def interpolate_inputs_and_add_conversation_history(
|
||||
self, inputs: Dict[str, Union[str, int, float, Dict[str, Any], List[Any]]]
|
||||
self, inputs: dict[str, str | int | float | dict[str, Any] | list[Any]]
|
||||
) -> None:
|
||||
"""Interpolate inputs into the task description, expected output, and output file path.
|
||||
Add conversation history if present.
|
||||
@@ -635,14 +597,14 @@ Follow these guidelines:
|
||||
f"Missing required template variable '{e.args[0]}' in description"
|
||||
) from e
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Error interpolating description: {str(e)}") from e
|
||||
raise ValueError(f"Error interpolating description: {e!s}") from e
|
||||
|
||||
try:
|
||||
self.expected_output = interpolate_only(
|
||||
input_string=self._original_expected_output, inputs=inputs
|
||||
)
|
||||
except (KeyError, ValueError) as e:
|
||||
raise ValueError(f"Error interpolating expected_output: {str(e)}") from e
|
||||
raise ValueError(f"Error interpolating expected_output: {e!s}") from e
|
||||
|
||||
if self.output_file is not None:
|
||||
try:
|
||||
@@ -650,11 +612,9 @@ Follow these guidelines:
|
||||
input_string=self._original_output_file, inputs=inputs
|
||||
)
|
||||
except (KeyError, ValueError) as e:
|
||||
raise ValueError(
|
||||
f"Error interpolating output_file path: {str(e)}"
|
||||
) from e
|
||||
raise ValueError(f"Error interpolating output_file path: {e!s}") from e
|
||||
|
||||
if "crew_chat_messages" in inputs and inputs["crew_chat_messages"]:
|
||||
if inputs.get("crew_chat_messages"):
|
||||
conversation_instruction = self.i18n.slice(
|
||||
"conversation_history_instruction"
|
||||
)
|
||||
@@ -681,14 +641,14 @@ Follow these guidelines:
|
||||
"""Increment the tools errors counter."""
|
||||
self.tools_errors += 1
|
||||
|
||||
def increment_delegations(self, agent_name: Optional[str]) -> None:
|
||||
def increment_delegations(self, agent_name: str | None) -> None:
|
||||
"""Increment the delegations counter."""
|
||||
if agent_name:
|
||||
self.processed_by_agents.add(agent_name)
|
||||
self.delegations += 1
|
||||
|
||||
def copy(
|
||||
self, agents: List["BaseAgent"], task_mapping: Dict[str, "Task"]
|
||||
def copy( # type: ignore
|
||||
self, agents: list["BaseAgent"], task_mapping: dict[str, "Task"]
|
||||
) -> "Task":
|
||||
"""Creates a deep copy of the Task while preserving its original class type.
|
||||
|
||||
@@ -721,20 +681,18 @@ Follow these guidelines:
|
||||
cloned_agent = get_agent_by_role(self.agent.role) if self.agent else None
|
||||
cloned_tools = copy(self.tools) if self.tools else []
|
||||
|
||||
copied_task = self.__class__(
|
||||
return self.__class__(
|
||||
**copied_data,
|
||||
context=cloned_context,
|
||||
agent=cloned_agent,
|
||||
tools=cloned_tools,
|
||||
)
|
||||
|
||||
return copied_task
|
||||
|
||||
def _export_output(
|
||||
self, result: str
|
||||
) -> Tuple[Optional[BaseModel], Optional[Dict[str, Any]]]:
|
||||
pydantic_output: Optional[BaseModel] = None
|
||||
json_output: Optional[Dict[str, Any]] = None
|
||||
) -> tuple[BaseModel | None, dict[str, Any] | None]:
|
||||
pydantic_output: BaseModel | None = None
|
||||
json_output: dict[str, Any] | None = None
|
||||
|
||||
if self.output_pydantic or self.output_json:
|
||||
model_output = convert_to_model(
|
||||
@@ -764,7 +722,7 @@ Follow these guidelines:
|
||||
return OutputFormat.PYDANTIC
|
||||
return OutputFormat.RAW
|
||||
|
||||
def _save_file(self, result: Union[Dict, str, Any]) -> None:
|
||||
def _save_file(self, result: dict | str | Any) -> None:
|
||||
"""Save task output to a file.
|
||||
|
||||
Note:
|
||||
@@ -785,7 +743,7 @@ Follow these guidelines:
|
||||
if self.output_file is None:
|
||||
raise ValueError("output_file is not set.")
|
||||
|
||||
FILEWRITER_RECOMMENDATION = (
|
||||
filewriter_recommendation = (
|
||||
"For cross-platform file writing, especially on Windows, "
|
||||
"use FileWriterTool from crewai_tools package."
|
||||
)
|
||||
@@ -811,10 +769,10 @@ Follow these guidelines:
         except (OSError, IOError) as e:
             raise RuntimeError(
                 "\n".join(
-                    [f"Failed to save output file: {e}", FILEWRITER_RECOMMENDATION]
+                    [f"Failed to save output file: {e}", filewriter_recommendation]
                 )
-            )
-        return None
+            ) from e
+        return
 
     def __repr__(self):
         return f"Task(description={self.description}, expected_output={self.expected_output})"
|
||||
|
||||
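Review note: the `) from e` added in this hunk turns the re-raise into explicit exception chaining. A minimal sketch of the effect, assuming a hypothetical `save_text` helper rather than the real `_save_file`:

```python
def save_text(path: str, text: str) -> None:
    """Write text to path, wrapping I/O failures in RuntimeError."""
    try:
        with open(path, "w", encoding="utf-8") as handle:
            handle.write(text)
    except OSError as e:
        # "from e" stores the original error on __cause__, so the traceback
        # shows the OSError as the direct cause of the RuntimeError.
        raise RuntimeError(f"Failed to save output file: {e}") from e
```

In Python 3, `IOError` is an alias of `OSError`, so catching `OSError` alone covers the same failures as the `(OSError, IOError)` tuple kept in the hunk.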
@@ -6,7 +6,7 @@ Classes:
|
||||
HallucinationGuardrail: Placeholder guardrail that validates task outputs.
|
||||
"""
|
||||
|
||||
from typing import Any, Optional, Tuple
|
||||
from typing import Any
|
||||
|
||||
from crewai.llm import LLM
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
@@ -48,7 +48,7 @@ class HallucinationGuardrail:
|
||||
self,
|
||||
context: str,
|
||||
llm: LLM,
|
||||
threshold: Optional[float] = None,
|
||||
threshold: float | None = None,
|
||||
tool_response: str = "",
|
||||
):
|
||||
"""Initialize the HallucinationGuardrail placeholder.
|
||||
@@ -75,7 +75,7 @@ class HallucinationGuardrail:
|
||||
"""Generate a description of this guardrail for event logging."""
|
||||
return "HallucinationGuardrail (no-op)"
|
||||
|
||||
def __call__(self, task_output: TaskOutput) -> Tuple[bool, Any]:
|
||||
def __call__(self, task_output: TaskOutput) -> tuple[bool, Any]:
|
||||
"""Validate a task output against hallucination criteria.
|
||||
|
||||
In the open source, this method always returns that the output is valid.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Tuple
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -53,7 +53,7 @@ class LLMGuardrail:
|
||||
|
||||
Guardrail:
|
||||
{self.description}
|
||||
|
||||
|
||||
Your task:
|
||||
- Confirm if the Task result complies with the guardrail.
|
||||
- If not, provide clear feedback explaining what is wrong (e.g., by how much it violates the rule, or what specific part fails).
|
||||
@@ -61,11 +61,9 @@ class LLMGuardrail:
|
||||
- If the Task result complies with the guardrail, saying that is valid
|
||||
"""
|
||||
|
||||
result = agent.kickoff(query, response_format=LLMGuardrailResult)
|
||||
return agent.kickoff(query, response_format=LLMGuardrailResult)
|
||||
|
||||
return result
|
||||
|
||||
def __call__(self, task_output: TaskOutput) -> Tuple[bool, Any]:
|
||||
def __call__(self, task_output: TaskOutput) -> tuple[bool, Any]:
|
||||
"""Validates the output of a task based on specified criteria.
|
||||
|
||||
Args:
|
||||
@@ -79,13 +77,11 @@ class LLMGuardrail:
 
         try:
             result = self._validate_output(task_output)
-            assert isinstance(
-                result.pydantic, LLMGuardrailResult
-            ), "The guardrail result is not a valid pydantic model"
+            if not isinstance(result.pydantic, LLMGuardrailResult):
+                raise ValueError("The guardrail result is not a valid pydantic model")
 
             if result.pydantic.valid:
                 return True, task_output.raw
-            else:
-                return False, result.pydantic.feedback
+            return False, result.pydantic.feedback
         except Exception as e:
-            return False, f"Error while validating the task output: {str(e)}"
+            return False, f"Error while validating the task output: {e!s}"
|
||||
|
||||
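Review note: replacing `assert isinstance(...)` with an explicit `raise` keeps the check active when Python runs with `-O`, which strips `assert` statements. The control flow of the new `__call__` body can be sketched as below; `Verdict` and `check` are hypothetical stand-ins for `LLMGuardrailResult` and the guardrail method, not the project's types.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Verdict:
    """Hypothetical stand-in for a structured guardrail result."""

    valid: bool
    feedback: str | None = None


def check(result: object, raw: str) -> tuple[bool, object]:
    """Return (True, raw) on success, otherwise (False, feedback or error)."""
    try:
        # An explicit raise survives "python -O"; an assert would not.
        if not isinstance(result, Verdict):
            raise ValueError("The guardrail result is not a valid pydantic model")
        if result.valid:
            return True, raw
        return False, result.feedback
    except Exception as e:
        return False, f"Error while validating the task output: {e!s}"
```

As in the hunk, the broad `except` converts the `ValueError` into a `(False, message)` tuple, so callers see a failed validation rather than an exception.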
@@ -1,7 +1,7 @@
|
||||
from .base_tool import BaseTool, tool, EnvVar
|
||||
from .base_tool import BaseTool, EnvVar, tool
|
||||
|
||||
__all__ = [
|
||||
"BaseTool",
|
||||
"tool",
|
||||
"EnvVar",
|
||||
]
|
||||
"tool",
|
||||
]
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
@@ -10,7 +8,7 @@ i18n = I18N()
|
||||
|
||||
class AddImageToolSchema(BaseModel):
|
||||
image_url: str = Field(..., description="The URL or path of the image to add")
|
||||
action: Optional[str] = Field(
|
||||
action: str | None = Field(
|
||||
default=None, description="Optional context or question about the image"
|
||||
)
|
||||
|
||||
@@ -18,14 +16,16 @@ class AddImageToolSchema(BaseModel):
|
||||
class AddImageTool(BaseTool):
|
||||
"""Tool for adding images to the content"""
|
||||
|
||||
name: str = Field(default_factory=lambda: i18n.tools("add_image")["name"]) # type: ignore
|
||||
description: str = Field(default_factory=lambda: i18n.tools("add_image")["description"]) # type: ignore
|
||||
name: str = Field(default_factory=lambda: i18n.tools("add_image")["name"]) # type: ignore[index]
|
||||
description: str = Field(
|
||||
default_factory=lambda: i18n.tools("add_image")["description"] # type: ignore[index]
|
||||
)
|
||||
args_schema: type[BaseModel] = AddImageToolSchema
|
||||
|
||||
def _run(
|
||||
self,
|
||||
image_url: str,
|
||||
action: Optional[str] = None,
|
||||
action: str | None = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
action = action or i18n.tools("add_image")["default_action"] # type: ignore
|
||||
|
||||
@@ -9,9 +9,9 @@ from .delegate_work_tool import DelegateWorkTool
 class AgentTools:
     """Manager class for agent-related tools"""
 
-    def __init__(self, agents: list[BaseAgent], i18n: I18N = I18N()):
+    def __init__(self, agents: list[BaseAgent], i18n: I18N | None = None):
         self.agents = agents
-        self.i18n = i18n
+        self.i18n = i18n if i18n is not None else I18N()
 
     def tools(self) -> list[BaseTool]:
         """Get all available agent tools"""
|
||||
|
||||
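Review note: the signature change from `i18n: I18N = I18N()` to `i18n: I18N | None = None` avoids a default that is evaluated once, when the `def` statement runs, and then shared by every call. A minimal sketch of the difference, using a hypothetical `Translations` class in place of `I18N`:

```python
from __future__ import annotations


class Translations:
    """Hypothetical stand-in for a stateful helper such as I18N."""

    def __init__(self) -> None:
        self.cache: dict[str, str] = {}


class SharedDefault:
    # The default object is created once, at function definition time.
    def __init__(self, i18n: Translations = Translations()) -> None:
        self.i18n = i18n


class PerInstanceDefault:
    # None is a sentinel; a fresh object is created on each call that omits it.
    def __init__(self, i18n: Translations | None = None) -> None:
        self.i18n = i18n if i18n is not None else Translations()
```

With `SharedDefault`, state written through one instance's `i18n` is visible through every other instance that used the default; `PerInstanceDefault` gives each instance its own object while still accepting an injected one.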
@@ -1,5 +1,3 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.tools.agent_tools.base_agent_tools import BaseAgentTool
|
||||
@@ -21,7 +19,7 @@ class AskQuestionTool(BaseAgentTool):
|
||||
self,
|
||||
question: str,
|
||||
context: str,
|
||||
coworker: Optional[str] = None,
|
||||
coworker: str | None = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
coworker = self._get_coworker(coworker, **kwargs)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
@@ -38,7 +37,7 @@ class BaseAgentTool(BaseTool):
|
||||
# Remove quotes and convert to lowercase
|
||||
return normalized.replace('"', "").casefold()
|
||||
|
||||
def _get_coworker(self, coworker: Optional[str], **kwargs) -> Optional[str]:
|
||||
def _get_coworker(self, coworker: str | None, **kwargs) -> str | None:
|
||||
coworker = coworker or kwargs.get("co_worker") or kwargs.get("coworker")
|
||||
if coworker:
|
||||
is_list = coworker.startswith("[") and coworker.endswith("]")
|
||||
@@ -47,10 +46,7 @@ class BaseAgentTool(BaseTool):
|
||||
return coworker
|
||||
|
||||
def _execute(
|
||||
self,
|
||||
agent_name: Optional[str],
|
||||
task: str,
|
||||
context: Optional[str] = None
|
||||
self, agent_name: str | None, task: str, context: str | None = None
|
||||
) -> str:
|
||||
"""
|
||||
Execute delegation to an agent with case-insensitive and whitespace-tolerant matching.
|
||||
@@ -77,7 +73,9 @@ class BaseAgentTool(BaseTool):
|
||||
# when it should look like this:
|
||||
# {"task": "....", "coworker": "...."}
|
||||
sanitized_name = self.sanitize_agent_name(agent_name)
|
||||
logger.debug(f"Sanitized agent name from '{agent_name}' to '{sanitized_name}'")
|
||||
logger.debug(
|
||||
f"Sanitized agent name from '{agent_name}' to '{sanitized_name}'"
|
||||
)
|
||||
|
||||
available_agents = [agent.role for agent in self.agents]
|
||||
logger.debug(f"Available agents: {available_agents}")
|
||||
@@ -87,38 +85,47 @@ class BaseAgentTool(BaseTool):
|
||||
for available_agent in self.agents
|
||||
if self.sanitize_agent_name(available_agent.role) == sanitized_name
|
||||
]
|
||||
logger.debug(f"Found {len(agent)} matching agents for role '{sanitized_name}'")
|
||||
logger.debug(
|
||||
f"Found {len(agent)} matching agents for role '{sanitized_name}'"
|
||||
)
|
||||
except (AttributeError, ValueError) as e:
|
||||
# Handle specific exceptions that might occur during role name processing
|
||||
return self.i18n.errors("agent_tool_unexisting_coworker").format(
|
||||
coworkers="\n".join(
|
||||
[f"- {self.sanitize_agent_name(agent.role)}" for agent in self.agents]
|
||||
[
|
||||
f"- {self.sanitize_agent_name(agent.role)}"
|
||||
for agent in self.agents
|
||||
]
|
||||
),
|
||||
error=str(e)
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
if not agent:
|
||||
# No matching agent found after sanitization
|
||||
return self.i18n.errors("agent_tool_unexisting_coworker").format(
|
||||
coworkers="\n".join(
|
||||
[f"- {self.sanitize_agent_name(agent.role)}" for agent in self.agents]
|
||||
[
|
||||
f"- {self.sanitize_agent_name(agent.role)}"
|
||||
for agent in self.agents
|
||||
]
|
||||
),
|
||||
error=f"No agent found with role '{sanitized_name}'"
|
||||
error=f"No agent found with role '{sanitized_name}'",
|
||||
)
|
||||
|
||||
agent = agent[0]
|
||||
selected_agent = agent[0]
|
||||
try:
|
||||
task_with_assigned_agent = Task(
|
||||
description=task,
|
||||
agent=agent,
|
||||
expected_output=agent.i18n.slice("manager_request"),
|
||||
i18n=agent.i18n,
|
||||
agent=selected_agent,
|
||||
expected_output=selected_agent.i18n.slice("manager_request"),
|
||||
i18n=selected_agent.i18n,
|
||||
)
|
||||
logger.debug(f"Created task for agent '{self.sanitize_agent_name(agent.role)}': {task}")
|
||||
return agent.execute_task(task_with_assigned_agent, context)
|
||||
logger.debug(
|
||||
f"Created task for agent '{self.sanitize_agent_name(selected_agent.role)}': {task}"
|
||||
)
|
||||
return selected_agent.execute_task(task_with_assigned_agent, context)
|
||||
except Exception as e:
|
||||
# Handle task creation or execution errors
|
||||
return self.i18n.errors("agent_tool_execution_error").format(
|
||||
agent_role=self.sanitize_agent_name(agent.role),
|
||||
error=str(e)
|
||||
agent_role=self.sanitize_agent_name(selected_agent.role), error=str(e)
|
||||
)
|
||||
|
||||
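The `_execute` hunks above match a requested coworker against agent roles after sanitizing both sides, then bind the first match to `selected_agent` instead of reusing the list variable `agent`. The sketch below shows that flow on plain strings. Only the quote removal and `casefold()` steps appear in the diff; the whitespace collapsing and the helper `find_matching_roles` are assumptions made for illustration.

```python
from __future__ import annotations


def sanitize_agent_name(name: str | None) -> str:
    """Normalize a role name for case-insensitive, whitespace-tolerant matching."""
    if not name:
        return ""
    # Assumption: collapse runs of whitespace (not shown in the diff).
    normalized = " ".join(name.split())
    # Shown in the diff: remove quotes and convert to lowercase.
    return normalized.replace('"', "").casefold()


def find_matching_roles(requested: str | None, roles: list[str]) -> list[str]:
    """Hypothetical helper: return every role whose sanitized form matches."""
    wanted = sanitize_agent_name(requested)
    return [role for role in roles if sanitize_agent_name(role) == wanted]


roles = ["Senior Researcher", "Writer"]
matches = find_matching_roles('  "senior   RESEARCHER" ', roles)

# Keep the list and the chosen element under different names, as the diff does.
selected_role = matches[0] if matches else None
print(selected_role)
```

Using a separate name for the chosen element keeps the list's type stable, which is likely what the rename to `selected_agent` is for.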
@@ -1,5 +1,3 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.tools.agent_tools.base_agent_tools import BaseAgentTool
|
||||
@@ -23,7 +21,7 @@ class DelegateWorkTool(BaseAgentTool):
|
||||
self,
|
||||
task: str,
|
||||
context: str,
|
||||
coworker: Optional[str] = None,
|
||||
coworker: str | None = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
coworker = self._get_coworker(coworker, **kwargs)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from inspect import signature
|
||||
from typing import Any, Callable, Type, get_args, get_origin, Optional, List
|
||||
from typing import Any, get_args, get_origin
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
@@ -19,7 +20,7 @@ class EnvVar(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
required: bool = True
|
||||
default: Optional[str] = None
|
||||
default: str | None = None
|
||||
|
||||
|
||||
class BaseTool(BaseModel, ABC):
|
||||
@@ -32,10 +33,10 @@ class BaseTool(BaseModel, ABC):
|
||||
"""The unique name of the tool that clearly communicates its purpose."""
|
||||
description: str
|
||||
"""Used to tell the model how/when/why to use the tool."""
|
||||
env_vars: List[EnvVar] = []
|
||||
env_vars: list[EnvVar] = []
|
||||
"""List of environment variables used by the tool."""
|
||||
args_schema: Type[PydanticBaseModel] = Field(
|
||||
default_factory=_ArgsSchemaPlaceholder, validate_default=True
|
||||
args_schema: type[PydanticBaseModel] = Field(
|
||||
default=_ArgsSchemaPlaceholder, validate_default=True
|
||||
)
|
||||
"""The schema for the arguments that the tool accepts."""
|
||||
description_updated: bool = False
|
||||
@@ -52,9 +53,9 @@ class BaseTool(BaseModel, ABC):
|
||||
@field_validator("args_schema", mode="before")
|
||||
@classmethod
|
||||
def _default_args_schema(
|
||||
cls, v: Type[PydanticBaseModel]
|
||||
) -> Type[PydanticBaseModel]:
|
||||
if not isinstance(v, cls._ArgsSchemaPlaceholder):
|
||||
cls, v: type[PydanticBaseModel]
|
||||
) -> type[PydanticBaseModel]:
|
||||
if v != cls._ArgsSchemaPlaceholder:
|
||||
return v
|
||||
|
||||
return type(
|
||||
@@ -139,7 +140,7 @@ class BaseTool(BaseModel, ABC):
|
||||
# Infer args_schema from the function signature if not provided
|
||||
func_signature = signature(tool.func)
|
||||
annotations = func_signature.parameters
|
||||
args_fields = {}
|
||||
args_fields: dict[str, Any] = {}
|
||||
for name, param in annotations.items():
|
||||
if name != "self":
|
||||
param_annotation = (
|
||||
@@ -247,7 +248,7 @@ class Tool(BaseTool):
|
||||
# Infer args_schema from the function signature if not provided
|
||||
func_signature = signature(tool.func)
|
||||
annotations = func_signature.parameters
|
||||
args_fields = {}
|
||||
args_fields: dict[str, Any] = {}
|
||||
for name, param in annotations.items():
|
||||
if name != "self":
|
||||
param_annotation = (
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel as PydanticBaseModel
|
||||
@@ -7,7 +7,7 @@ from pydantic import Field as PydanticField
|
||||
|
||||
class ToolCalling(BaseModel):
|
||||
tool_name: str = Field(..., description="The name of the tool to be called.")
|
||||
arguments: Optional[Dict[str, Any]] = Field(
|
||||
arguments: dict[str, Any] | None = Field(
|
||||
..., description="A dictionary of arguments to be passed to the tool."
|
||||
)
|
||||
|
||||
@@ -16,6 +16,6 @@ class InstructorToolCalling(PydanticBaseModel):
|
||||
tool_name: str = PydanticField(
|
||||
..., description="The name of the tool to be called."
|
||||
)
|
||||
arguments: Optional[Dict[str, Any]] = PydanticField(
|
||||
arguments: dict[str, Any] | None = PydanticField(
|
||||
..., description="A dictionary of arguments to be passed to the tool."
|
||||
)
|
||||
|
||||
@@ -1,26 +1,24 @@
|
||||
from .converter import Converter, ConverterError
|
||||
from .file_handler import FileHandler
|
||||
from .i18n import I18N
|
||||
from .internal_instructor import InternalInstructor
|
||||
from .logger import Logger
|
||||
from .parser import YamlParser
|
||||
from .printer import Printer
|
||||
from .prompts import Prompts
|
||||
from .rpm_controller import RPMController
|
||||
from .exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededException,
|
||||
from crewai.utilities.converter import Converter, ConverterError
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededError,
|
||||
)
|
||||
from crewai.utilities.file_handler import FileHandler
|
||||
from crewai.utilities.i18n import I18N
|
||||
from crewai.utilities.internal_instructor import InternalInstructor
|
||||
from crewai.utilities.logger import Logger
|
||||
from crewai.utilities.printer import Printer
|
||||
from crewai.utilities.prompts import Prompts
|
||||
from crewai.utilities.rpm_controller import RPMController
|
||||
|
||||
__all__ = [
|
||||
"I18N",
|
||||
"Converter",
|
||||
"ConverterError",
|
||||
"FileHandler",
|
||||
"I18N",
|
||||
"InternalInstructor",
|
||||
"LLMContextLengthExceededError",
|
||||
"Logger",
|
||||
"Printer",
|
||||
"Prompts",
|
||||
"RPMController",
|
||||
"YamlParser",
|
||||
"LLMContextLengthExceededException",
|
||||
]
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any, Final, Literal, TypedDict
|
||||
|
||||
from rich.console import Console
|
||||
|
||||
@@ -19,18 +21,47 @@ from crewai.tools import BaseTool as CrewAITool
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
from crewai.tools.tool_types import ToolResult
|
||||
from crewai.utilities import I18N, Printer
|
||||
from crewai.utilities.errors import AgentRepositoryError
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededException,
|
||||
LLMContextLengthExceededError,
|
||||
)
|
||||
from crewai.utilities.i18n import I18N
|
||||
from crewai.utilities.printer import ColoredText, Printer
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent import Agent
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
class SummaryContent(TypedDict):
|
||||
"""Structure for summary content entries.
|
||||
|
||||
Attributes:
|
||||
content: The summarized content.
|
||||
"""
|
||||
|
||||
content: str
|
||||
|
||||
|
||||
console = Console()
|
||||
|
||||
_MULTIPLE_NEWLINES: Final[re.Pattern[str]] = re.compile(r"\n+")
|
||||
|
||||
|
||||
def parse_tools(tools: list[BaseTool]) -> list[CrewStructuredTool]:
|
||||
"""Parse tools to be used for the task."""
|
||||
tools_list = []
|
||||
"""Parse tools to be used for the task.
|
||||
|
||||
Args:
|
||||
tools: List of tools to parse.
|
||||
|
||||
Returns:
|
||||
List of structured tools.
|
||||
|
||||
Raises:
|
||||
ValueError: If a tool is not a CrewStructuredTool or BaseTool.
|
||||
"""
|
||||
tools_list: list[CrewStructuredTool] = []
|
||||
|
||||
for tool in tools:
|
||||
if isinstance(tool, CrewAITool):
|
||||
@@ -42,7 +73,14 @@ def parse_tools(tools: list[BaseTool]) -> list[CrewStructuredTool]:
|
||||
|
||||
|
||||
def get_tool_names(tools: Sequence[CrewStructuredTool | BaseTool]) -> str:
|
||||
"""Get the names of the tools."""
|
||||
"""Get the names of the tools.
|
||||
|
||||
Args:
|
||||
tools: List of tools to get names from.
|
||||
|
||||
Returns:
|
||||
Comma-separated string of tool names.
|
||||
"""
|
||||
return ", ".join([t.name for t in tools])
|
||||
|
||||
|
||||
@@ -51,16 +89,30 @@ def render_text_description_and_args(
|
||||
) -> str:
|
||||
"""Render the tool name, description, and args in plain text.
|
||||
|
||||
search: This tool is used for search, args: {"query": {"type": "string"}}
|
||||
calculator: This tool is used for math, \
|
||||
args: {"expression": {"type": "string"}}
|
||||
search: This tool is used for search, args: {"query": {"type": "string"}}
|
||||
calculator: This tool is used for math, \
|
||||
args: {"expression": {"type": "string"}}
|
||||
|
||||
Args:
|
||||
tools: List of tools to render.
|
||||
|
||||
Returns:
|
||||
Plain text description of tools.
|
||||
"""
|
||||
tool_strings = [tool.description for tool in tools]
|
||||
return "\n".join(tool_strings)
|
||||
|
||||
|
||||
def has_reached_max_iterations(iterations: int, max_iterations: int) -> bool:
|
||||
"""Check if the maximum number of iterations has been reached."""
|
||||
"""Check if the maximum number of iterations has been reached.
|
||||
|
||||
Args:
|
||||
iterations: Current number of iterations.
|
||||
max_iterations: Maximum allowed iterations.
|
||||
|
||||
Returns:
|
||||
True if maximum iterations reached, False otherwise.
|
||||
"""
|
||||
return iterations >= max_iterations
|
||||
|
||||
|
||||
@@ -68,16 +120,19 @@ def handle_max_iterations_exceeded(
|
||||
formatted_answer: AgentAction | AgentFinish | None,
|
||||
printer: Printer,
|
||||
i18n: I18N,
|
||||
messages: list[dict[str, str]],
|
||||
messages: list[LLMMessage],
|
||||
llm: LLM | BaseLLM,
|
||||
callbacks: list[Any],
|
||||
callbacks: list[Callable[..., Any]],
|
||||
) -> AgentAction | AgentFinish:
|
||||
"""
|
||||
Handles the case when the maximum number of iterations is exceeded.
|
||||
Performs one more LLM call to get the final answer.
|
||||
"""Handles the case when the maximum number of iterations is exceeded. Performs one more LLM call to get the final answer.
|
||||
|
||||
Parameters:
|
||||
Args:
|
||||
formatted_answer: The last formatted answer from the agent.
|
||||
printer: Printer instance for output.
|
||||
i18n: I18N instance for internationalization.
|
||||
messages: List of messages to send to the LLM.
|
||||
llm: The LLM instance to call.
|
||||
callbacks: List of callbacks for the LLM call.
|
||||
|
||||
Returns:
|
||||
The final formatted answer after exceeding max iterations.
|
||||
@@ -98,7 +153,7 @@ def handle_max_iterations_exceeded(
|
||||
|
||||
# Perform one more LLM call to get the final answer
|
||||
answer = llm.call(
|
||||
messages,
|
||||
messages, # type: ignore[arg-type]
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
@@ -110,20 +165,38 @@ def handle_max_iterations_exceeded(
|
||||
raise ValueError("Invalid response from LLM call - None or empty.")
|
||||
|
||||
# Return the formatted answer, regardless of its type
|
||||
return format_answer(answer)
|
||||
return format_answer(answer=answer)
|
||||
|
||||
|
||||
def format_message_for_llm(prompt: str, role: str = "user") -> dict[str, str]:
|
||||
def format_message_for_llm(
|
||||
prompt: str, role: Literal["user", "assistant", "system"] = "user"
|
||||
) -> LLMMessage:
|
||||
"""Format a message for the LLM.
|
||||
|
||||
Args:
|
||||
prompt: The message content.
|
||||
role: The role of the message sender: 'user', 'assistant', or 'system'.
|
||||
|
||||
Returns:
|
||||
A dictionary with 'role' and 'content' keys.
|
||||
|
||||
"""
|
||||
prompt = prompt.rstrip()
|
||||
return {"role": role, "content": prompt}
|
||||
|
||||
|
||||
def format_answer(answer: str) -> AgentAction | AgentFinish:
|
||||
"""Format a response from the LLM into an AgentAction or AgentFinish."""
|
||||
"""Format a response from the LLM into an AgentAction or AgentFinish.
|
||||
|
||||
Args:
|
||||
answer: The raw response from the LLM
|
||||
|
||||
Returns:
|
||||
Either an AgentAction or AgentFinish
|
||||
"""
|
||||
try:
|
||||
return parse(answer)
|
||||
except Exception:
|
||||
# If parsing fails, return a default AgentFinish
|
||||
return AgentFinish(
|
||||
thought="Failed to parse LLM response",
|
||||
output=answer,
|
||||
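`format_message_for_llm` now takes a `Literal` role and returns `LLMMessage` rather than `dict[str, str]`. The definition of `LLMMessage` is not part of this hunk, so the sketch below uses a hypothetical `Message` TypedDict to show the general pattern: a type checker can then reject unknown roles and missing keys, while the runtime value stays a plain dict.

```python
from typing import Literal, TypedDict

Role = Literal["user", "assistant", "system"]


class Message(TypedDict):
    """Hypothetical stand-in for LLMMessage (its real fields are not shown here)."""

    role: Role
    content: str


def format_message(prompt: str, role: Role = "user") -> Message:
    """Build a chat message, trimming trailing whitespace from the prompt."""
    return {"role": role, "content": prompt.rstrip()}


user_message = format_message("Summarize the report.  \n")
system_message = format_message("You are a summarizer.", role="system")
print(user_message)
print(system_message)
```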
@@ -134,23 +207,43 @@ def format_answer(answer: str) -> AgentAction | AgentFinish:
|
||||
def enforce_rpm_limit(
|
||||
request_within_rpm_limit: Callable[[], bool] | None = None,
|
||||
) -> None:
|
||||
"""Enforce the requests per minute (RPM) limit if applicable."""
|
||||
"""Enforce the requests per minute (RPM) limit if applicable.
|
||||
|
||||
Args:
|
||||
request_within_rpm_limit: Function to enforce RPM limit.
|
||||
"""
|
||||
if request_within_rpm_limit:
|
||||
request_within_rpm_limit()
|
||||
|
||||
|
||||
def get_llm_response(
|
||||
llm: LLM | BaseLLM,
|
||||
messages: list[dict[str, str]],
|
||||
callbacks: list[Any],
|
||||
messages: list[LLMMessage],
|
||||
callbacks: list[Callable[..., Any]],
|
||||
printer: Printer,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | None = None,
|
||||
) -> str:
|
||||
"""Call the LLM and return the response, handling any invalid responses."""
|
||||
"""Call the LLM and return the response, handling any invalid responses.
|
||||
|
||||
Args:
|
||||
llm: The LLM instance to call
|
||||
messages: The messages to send to the LLM
|
||||
callbacks: List of callbacks for the LLM call
|
||||
printer: Printer instance for output
|
||||
from_task: Optional task context for the LLM call
|
||||
from_agent: Optional agent context for the LLM call
|
||||
|
||||
Returns:
|
||||
The response from the LLM as a string
|
||||
|
||||
Raises:
|
||||
Exception: If an error occurs.
|
||||
ValueError: If the response is None or empty.
|
||||
"""
|
||||
try:
|
||||
answer = llm.call(
|
||||
messages,
|
||||
messages, # type: ignore[arg-type]
|
||||
callbacks=callbacks,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
@@ -170,7 +263,15 @@ def get_llm_response(
|
||||
def process_llm_response(
|
||||
answer: str, use_stop_words: bool
|
||||
) -> AgentAction | AgentFinish:
|
||||
"""Process the LLM response and format it into an AgentAction or AgentFinish."""
|
||||
"""Process the LLM response and format it into an AgentAction or AgentFinish.
|
||||
|
||||
Args:
|
||||
answer: The raw response from the LLM
|
||||
use_stop_words: Whether to use stop words in the LLM call
|
||||
|
||||
Returns:
|
||||
Either an AgentAction or AgentFinish
|
||||
"""
|
||||
if not use_stop_words:
|
||||
try:
|
||||
# Preliminary parsing to check for errors.
|
||||
@@ -200,6 +301,9 @@ def handle_agent_action_core(
|
||||
|
||||
Returns:
|
||||
Either an AgentAction or AgentFinish
|
||||
|
||||
Notes:
|
||||
- TODO: Remove messages parameter and its usage.
|
||||
"""
|
||||
if step_callback:
|
||||
step_callback(tool_result)
|
||||
@@ -220,7 +324,7 @@ def handle_agent_action_core(
|
||||
return formatted_answer
|
||||
|
||||
|
||||
def handle_unknown_error(printer: Any, exception: Exception) -> None:
|
||||
def handle_unknown_error(printer: Printer, exception: Exception) -> None:
|
||||
"""Handle unknown errors by informing the user.
|
||||
|
||||
Args:
|
||||
@@ -244,10 +348,10 @@ def handle_unknown_error(printer: Any, exception: Exception) -> None:
|
||||
|
||||
def handle_output_parser_exception(
|
||||
e: OutputParserError,
|
||||
messages: list[dict[str, str]],
|
||||
messages: list[LLMMessage],
|
||||
iterations: int,
|
||||
log_error_after: int = 3,
|
||||
printer: Any | None = None,
|
||||
printer: Printer | None = None,
|
||||
) -> AgentAction:
|
||||
"""Handle OutputParserError by updating messages and formatted_answer.
|
||||
|
||||
@@ -288,18 +392,18 @@ def is_context_length_exceeded(exception: Exception) -> bool:
|
||||
Returns:
|
||||
bool: True if the exception is due to context length exceeding
|
||||
"""
|
||||
return LLMContextLengthExceededException(str(exception))._is_context_limit_error(
|
||||
return LLMContextLengthExceededError(str(exception))._is_context_limit_error(
|
||||
str(exception)
|
||||
)
|
||||
|
||||
|
||||
def handle_context_length(
|
||||
respect_context_window: bool,
|
||||
printer: Any,
|
||||
messages: list[dict[str, str]],
|
||||
llm: Any,
|
||||
callbacks: list[Any],
|
||||
i18n: Any,
|
||||
printer: Printer,
|
||||
messages: list[LLMMessage],
|
||||
llm: LLM | BaseLLM,
|
||||
callbacks: list[Callable[..., Any]],
|
||||
i18n: I18N,
|
||||
) -> None:
|
||||
"""Handle context length exceeded by either summarizing or raising an error.
|
||||
|
||||
@@ -310,13 +414,16 @@ def handle_context_length(
|
||||
llm: LLM instance for summarization
|
||||
callbacks: List of callbacks for LLM
|
||||
i18n: I18N instance for messages
|
||||
|
||||
Raises:
|
||||
SystemExit: If context length is exceeded and user opts not to summarize
|
||||
"""
|
||||
if respect_context_window:
|
||||
printer.print(
|
||||
content="Context length exceeded. Summarizing content to fit the model context window. Might take a while...",
|
||||
color="yellow",
|
||||
)
|
||||
summarize_messages(messages, llm, callbacks, i18n)
|
||||
summarize_messages(messages=messages, llm=llm, callbacks=callbacks, i18n=i18n)
|
||||
else:
|
||||
printer.print(
|
||||
content="Context length exceeded. Consider using smaller text or RAG tools from crewai_tools.",
|
||||
@@ -328,10 +435,10 @@ def handle_context_length(
|
||||
|
||||
|
||||
def summarize_messages(
|
||||
messages: list[dict[str, str]],
|
||||
llm: Any,
|
||||
callbacks: list[Any],
|
||||
i18n: Any,
|
||||
messages: list[LLMMessage],
|
||||
llm: LLM | BaseLLM,
|
||||
callbacks: list[Callable[..., Any]],
|
||||
i18n: I18N,
|
||||
) -> None:
|
||||
"""Summarize messages to fit within context window.
|
||||
|
||||
@@ -349,7 +456,7 @@ def summarize_messages(
|
||||
for i in range(0, len(messages_string), cut_size)
|
||||
]
|
||||
|
||||
summarized_contents = []
|
||||
summarized_contents: list[SummaryContent] = []
|
||||
|
||||
total_groups = len(messages_groups)
|
||||
for idx, group in enumerate(messages_groups, 1):
|
||||
@@ -357,15 +464,17 @@ def summarize_messages(
|
||||
content=f"Summarizing {idx}/{total_groups}...",
|
||||
color="yellow",
|
||||
)
|
||||
|
||||
messages = [
|
||||
format_message_for_llm(
|
||||
i18n.slice("summarizer_system_message"), role="system"
|
||||
),
|
||||
format_message_for_llm(
|
||||
i18n.slice("summarize_instruction").format(group=group["content"]),
|
||||
),
|
||||
]
|
||||
summary = llm.call(
|
||||
[
|
||||
format_message_for_llm(
|
||||
i18n.slice("summarizer_system_message"), role="system"
|
||||
),
|
||||
format_message_for_llm(
|
||||
i18n.slice("summarize_instruction").format(group=group["content"]),
|
||||
),
|
||||
],
|
||||
messages, # type: ignore[arg-type]
|
||||
callbacks=callbacks,
|
||||
)
|
||||
summarized_contents.append({"content": str(summary)})
|
||||
@@ -404,20 +513,29 @@ def show_agent_logs(
|
||||
if formatted_answer is None:
|
||||
# Start logs
|
||||
printer.print(
|
||||
content=f"\033[1m\033[95m# Agent:\033[00m \033[1m\033[92m{agent_role}\033[00m"
|
||||
content=[
|
||||
ColoredText("# Agent: ", "bold_purple"),
|
||||
ColoredText(agent_role, "bold_green"),
|
||||
]
|
||||
)
|
||||
if task_description:
|
||||
printer.print(
|
||||
content=f"\033[95m## Task:\033[00m \033[92m{task_description}\033[00m"
|
||||
content=[
|
||||
ColoredText("## Task: ", "purple"),
|
||||
ColoredText(task_description, "green"),
|
||||
]
|
||||
)
|
||||
else:
|
||||
# Execution logs
|
||||
printer.print(
|
||||
content=f"\n\n\033[1m\033[95m# Agent:\033[00m \033[1m\033[92m{agent_role}\033[00m"
|
||||
content=[
|
||||
ColoredText("\n\n# Agent: ", "bold_purple"),
|
||||
ColoredText(agent_role, "bold_green"),
|
||||
]
|
||||
)
|
||||
|
||||
if isinstance(formatted_answer, AgentAction):
|
||||
thought = re.sub(r"\n+", "\n", formatted_answer.thought)
|
||||
thought = _MULTIPLE_NEWLINES.sub("\n", formatted_answer.thought)
|
||||
formatted_json = json.dumps(
|
||||
formatted_answer.tool_input,
|
||||
indent=2,
|
||||
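The hunk above replaces an inline `re.sub(r"\n+", ...)` call with the module-level `_MULTIPLE_NEWLINES` pattern. A small standalone sketch of the same idea; the wrapper function `collapse_newlines` is invented for the example.

```python
import re
from typing import Final

# Compiled once at import time and reused on every call.
_MULTIPLE_NEWLINES: Final[re.Pattern[str]] = re.compile(r"\n+")


def collapse_newlines(text: str) -> str:
    """Replace each run of newlines with a single newline."""
    return _MULTIPLE_NEWLINES.sub("\n", text)


thought = "I should search first.\n\n\nThen summarize.\nDone."
print(collapse_newlines(thought))
```

Naming the pattern also documents its intent, which the bare `r"\n+"` literal did not.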
@@ -425,24 +543,39 @@ def show_agent_logs(
|
||||
)
|
||||
if thought and thought != "":
|
||||
printer.print(
|
||||
content=f"\033[95m## Thought:\033[00m \033[92m{thought}\033[00m"
|
||||
content=[
|
||||
ColoredText("## Thought: ", "purple"),
|
||||
ColoredText(thought, "green"),
|
||||
]
|
||||
)
|
||||
printer.print(
|
||||
content=f"\033[95m## Using tool:\033[00m \033[92m{formatted_answer.tool}\033[00m"
|
||||
content=[
|
||||
ColoredText("## Using tool: ", "purple"),
|
||||
ColoredText(formatted_answer.tool, "green"),
|
||||
]
|
||||
)
|
||||
printer.print(
|
||||
content=f"\033[95m## Tool Input:\033[00m \033[92m\n{formatted_json}\033[00m"
|
||||
content=[
|
||||
ColoredText("## Tool Input: ", "purple"),
|
||||
ColoredText(f"\n{formatted_json}", "green"),
|
||||
]
|
||||
)
|
||||
printer.print(
|
||||
content=f"\033[95m## Tool Output:\033[00m \033[92m\n{formatted_answer.result}\033[00m"
|
||||
content=[
|
||||
ColoredText("## Tool Output: ", "purple"),
|
||||
ColoredText(f"\n{formatted_answer.result}", "green"),
|
||||
]
|
||||
)
|
||||
elif isinstance(formatted_answer, AgentFinish):
|
||||
printer.print(
|
||||
content=f"\033[95m## Final Answer:\033[00m \033[92m\n{formatted_answer.output}\033[00m\n\n"
|
||||
content=[
|
||||
ColoredText("## Final Answer: ", "purple"),
|
||||
ColoredText(f"\n{formatted_answer.output}\n\n", "green"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
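`show_agent_logs` above stops embedding raw ANSI escape sequences in f-strings and passes lists of `ColoredText` values to `printer.print`. The real `ColoredText` and `Printer` live in `crewai.utilities.printer` and are not shown in this diff, so the following is only a sketch of how such a pairing could work. The class, the `render` function, and the color table are assumptions; the escape codes are the ones visible in the removed lines.

```python
from typing import NamedTuple

# Escape codes copied from the f-strings the diff removes.
_CODES = {
    "purple": "\033[95m",
    "green": "\033[92m",
    "bold_purple": "\033[1m\033[95m",
    "bold_green": "\033[1m\033[92m",
}
_RESET = "\033[00m"


class ColoredText(NamedTuple):
    """Hypothetical (text, color) pair; not the crewAI implementation."""

    text: str
    color: str


def render(parts: list[ColoredText]) -> str:
    """Wrap each segment in its color code followed by a reset."""
    return "".join(f"{_CODES[part.color]}{part.text}{_RESET}" for part in parts)


line = render(
    [
        ColoredText("# Agent: ", "bold_purple"),
        ColoredText("Researcher", "bold_green"),
    ]
)
print(line)
```

Keeping the color names separate from the text means the call sites no longer need to know any escape sequences.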
def _print_current_organization():
|
||||
def _print_current_organization() -> None:
|
||||
settings = Settings()
|
||||
if settings.org_uuid:
|
||||
console.print(
|
||||
@@ -457,6 +590,17 @@ def _print_current_organization():
|
||||
|
||||
|
||||
def load_agent_from_repository(from_repository: str) -> dict[str, Any]:
|
||||
"""Load an agent from the repository.
|
||||
|
||||
Args:
|
||||
from_repository: The name of the agent to load.
|
||||
|
||||
Returns:
|
||||
A dictionary of attributes to use for the agent.
|
||||
|
||||
Raises:
|
||||
AgentRepositoryError: If the agent cannot be loaded.
|
||||
"""
|
||||
attributes: dict[str, Any] = {}
|
||||
if from_repository:
|
||||
import importlib
|
||||
|
||||
@@ -1,20 +1,19 @@
|
||||
from typing import Any, Dict, Type
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def process_config(
|
||||
values: Dict[str, Any], model_class: Type[BaseModel]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Process the config dictionary and update the values accordingly.
|
||||
values: dict[str, Any], model_class: type[BaseModel]
|
||||
) -> dict[str, Any]:
|
||||
"""Process the config dictionary and update the values accordingly.
|
||||
|
||||
Args:
|
||||
values (Dict[str, Any]): The dictionary of values to update.
|
||||
model_class (Type[BaseModel]): The Pydantic model class to reference for field validation.
|
||||
values: The dictionary of values to update.
|
||||
model_class: The Pydantic model class to reference for field validation.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The updated values dictionary.
|
||||
The updated values dictionary.
|
||||
"""
|
||||
config = values.get("config", {})
|
||||
if not config:
|
||||
|
||||
@@ -1,19 +1,32 @@
|
||||
TRAINING_DATA_FILE = "training_data.pkl"
|
||||
TRAINED_AGENTS_DATA_FILE = "trained_agents_data.pkl"
|
||||
DEFAULT_SCORE_THRESHOLD = 0.35
|
||||
KNOWLEDGE_DIRECTORY = "knowledge"
|
||||
MAX_LLM_RETRY = 3
|
||||
MAX_FILE_NAME_LENGTH = 255
|
||||
EMITTER_COLOR = "bold_blue"
|
||||
from typing import Annotated, Final
|
||||
|
||||
from crewai.utilities.printer import PrinterColor
|
||||
|
||||
TRAINING_DATA_FILE: Final[str] = "training_data.pkl"
|
||||
TRAINED_AGENTS_DATA_FILE: Final[str] = "trained_agents_data.pkl"
|
||||
KNOWLEDGE_DIRECTORY: Final[str] = "knowledge"
|
||||
MAX_FILE_NAME_LENGTH: Final[int] = 255
|
||||
EMITTER_COLOR: Final[PrinterColor] = "bold_blue"
|
||||
|
||||
|
||||
class _NotSpecified:
|
||||
def __repr__(self):
|
||||
"""Sentinel class to detect when no value has been explicitly provided.
|
||||
|
||||
Notes:
|
||||
- TODO: Consider moving this class and NOT_SPECIFIED to types.py
|
||||
as they are more type-related constructs than business constants.
|
||||
"""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "NOT_SPECIFIED"
|
||||
|
||||
|
||||
# Sentinel value used to detect when no value has been explicitly provided.
|
||||
# Unlike `None`, which might be a valid value from the user, `NOT_SPECIFIED` allows
|
||||
# us to distinguish between "not passed at all" and "explicitly passed None" or "[]".
|
||||
NOT_SPECIFIED = _NotSpecified()
|
||||
CREWAI_BASE_URL = "https://app.crewai.com"
|
||||
NOT_SPECIFIED: Final[
|
||||
Annotated[
|
||||
_NotSpecified,
|
||||
"Sentinel value used to detect when no value has been explicitly provided. "
|
||||
"Unlike `None`, which might be a valid value from the user, `NOT_SPECIFIED` "
|
||||
"allows us to distinguish between 'not passed at all' and 'explicitly passed None' or '[]'.",
|
||||
]
|
||||
] = _NotSpecified()
|
||||
CREWAI_BASE_URL: Final[str] = "https://app.crewai.com"
|
||||
|
||||
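The constants hunk above keeps the `NOT_SPECIFIED` sentinel and adds type annotations to it. A short sketch of what such a sentinel is for: telling "argument omitted" apart from "argument explicitly set to `None` or `[]`". The function `resolve_tools` is hypothetical and exists only to exercise the sentinel.

```python
from typing import Any, Final


class _NotSpecified:
    """Sentinel class to detect when no value has been explicitly provided."""

    def __repr__(self) -> str:
        return "NOT_SPECIFIED"


NOT_SPECIFIED: Final = _NotSpecified()


def resolve_tools(tools: Any = NOT_SPECIFIED) -> str:
    """Hypothetical consumer: distinguish omitted, None, and explicit values."""
    if tools is NOT_SPECIFIED:
        return "use defaults"
    if tools is None:
        return "explicitly disabled"
    return f"use {len(tools)} custom tool(s)"


print(resolve_tools())
print(resolve_tools(None))
print(resolve_tools([]))
```

The identity check (`is`) is what makes this reliable: no user-supplied value can be the sentinel object by accident.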
@@ -1,18 +1,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Optional, Type, Union, get_args, get_origin
|
||||
from typing import TYPE_CHECKING, Any, Final, TypedDict, Union, get_args, get_origin
|
||||
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.agents.agent_builder.utilities.base_output_converter import OutputConverter
|
||||
from crewai.utilities.internal_instructor import InternalInstructor
|
||||
from crewai.utilities.printer import Printer
|
||||
from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent import Agent
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
|
||||
_JSON_PATTERN: Final[re.Pattern[str]] = re.compile(r"({.*})", re.DOTALL)
|
||||
|
||||
|
||||
class ConverterError(Exception):
|
||||
"""Error raised when Converter fails to parse the input."""
|
||||
|
||||
def __init__(self, message: str, *args: object) -> None:
|
||||
"""Initialize the ConverterError with a message.
|
||||
|
||||
Args:
|
||||
message: The error message.
|
||||
*args: Additional arguments for the base Exception class.
|
||||
"""
|
||||
super().__init__(message, *args)
|
||||
self.message = message
|
||||
|
||||
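The hunk above introduces `_JSON_PATTERN`, a greedy `({.*})` match compiled with `re.DOTALL`. How the converter uses it is outside this hunk, so the sketch below is an assumption about one plausible use: pulling a JSON object out of a model response that wraps it in extra text. `extract_json` is a made-up helper.

```python
from __future__ import annotations

import json
import re
from typing import Any, Final

# Same pattern as in the diff: from the first "{" to the last "}", across lines.
_JSON_PATTERN: Final[re.Pattern[str]] = re.compile(r"({.*})", re.DOTALL)


def extract_json(text: str) -> dict[str, Any] | None:
    """Hypothetical helper: return the embedded JSON object, or None."""
    match = _JSON_PATTERN.search(text)
    if match is None:
        return None
    try:
        return json.loads(match.group(1))
    except json.JSONDecodeError:
        return None


response = 'Here is the result:\n{"name": "Ada",\n "age": 36}\nLet me know!'
print(extract_json(response))
```

Because the match is greedy, a response containing two separate JSON objects would be captured as one span and fail to parse; the sketch returns `None` in that case.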
@@ -20,8 +37,18 @@ class ConverterError(Exception):
|
||||
class Converter(OutputConverter):
|
||||
"""Class that converts text into either pydantic or json."""
|
||||
|
||||
def to_pydantic(self, current_attempt=1) -> BaseModel:
|
||||
"""Convert text to pydantic."""
|
||||
def to_pydantic(self, current_attempt: int = 1) -> BaseModel:
|
||||
"""Convert text to pydantic.
|
||||
|
||||
Args:
|
||||
current_attempt: The current attempt number for conversion retries.
|
||||
|
||||
Returns:
|
||||
A Pydantic BaseModel instance.
|
||||
|
||||
Raises:
|
||||
ConverterError: If conversion fails after maximum attempts.
|
||||
"""
|
||||
try:
|
||||
if self.llm.supports_function_calling():
|
||||
result = self._create_instructor().to_pydantic()
|
||||
@@ -37,104 +64,124 @@ class Converter(OutputConverter):
                     result = self.model.model_validate_json(response)
                 except ValidationError:
                     # If direct validation fails, attempt to extract valid JSON
-                    result = handle_partial_json(response, self.model, False, None)
+                    result = handle_partial_json(
+                        result=response,
+                        model=self.model,
+                        is_json_output=False,
+                        agent=None,
+                    )
                     # Ensure result is a BaseModel instance
                     if not isinstance(result, BaseModel):
                         if isinstance(result, dict):
-                            result = self.model.parse_obj(result)
+                            result = self.model.model_validate(result)
                         elif isinstance(result, str):
                             try:
                                 parsed = json.loads(result)
-                                result = self.model.parse_obj(parsed)
+                                result = self.model.model_validate(parsed)
                             except Exception as parse_err:
                                 raise ConverterError(
                                     f"Failed to convert partial JSON result into Pydantic: {parse_err}"
-                                )
+                                ) from parse_err
                         else:
                             raise ConverterError(
                                 "handle_partial_json returned an unexpected type."
-                            )
+                            ) from None
             return result
         except ValidationError as e:
             if current_attempt < self.max_attempts:
                 return self.to_pydantic(current_attempt + 1)
             raise ConverterError(
                 f"Failed to convert text into a Pydantic model due to validation error: {e}"
-            )
+            ) from e
         except Exception as e:
             if current_attempt < self.max_attempts:
                 return self.to_pydantic(current_attempt + 1)
             raise ConverterError(
                 f"Failed to convert text into a Pydantic model due to error: {e}"
-            )
+            ) from e

-    def to_json(self, current_attempt=1):
-        """Convert text to json."""
+    def to_json(self, current_attempt: int = 1) -> str | ConverterError | Any:  # type: ignore[override]
+        """Convert text to json.
+
+        Args:
+            current_attempt: The current attempt number for conversion retries.
+
+        Returns:
+            A JSON string or ConverterError if conversion fails.
+
+        Raises:
+            ConverterError: If conversion fails after maximum attempts.
+
+        """
         try:
             if self.llm.supports_function_calling():
                 return self._create_instructor().to_json()
-            else:
-                return json.dumps(
-                    self.llm.call(
-                        [
-                            {"role": "system", "content": self.instructions},
-                            {"role": "user", "content": self.text},
-                        ]
-                    )
+            return json.dumps(
+                self.llm.call(
+                    [
+                        {"role": "system", "content": self.instructions},
+                        {"role": "user", "content": self.text},
+                    ]
                 )
+            )
         except Exception as e:
             if current_attempt < self.max_attempts:
                 return self.to_json(current_attempt + 1)
             return ConverterError(f"Failed to convert text into JSON, error: {e}.")

-    def _create_instructor(self):
+    def _create_instructor(self) -> InternalInstructor:
         """Create an instructor."""
-        from crewai.utilities import InternalInstructor

-        inst = InternalInstructor(
+        return InternalInstructor(
             llm=self.llm,
             model=self.model,
             content=self.text,
         )
-        return inst
-
-    def _convert_with_instructions(self):
-        """Create a chain."""
-        from crewai.utilities.crew_pydantic_output_parser import (
-            CrewPydanticOutputParser,
-        )
-
-        parser = CrewPydanticOutputParser(pydantic_object=self.model)
-        result = self.llm.call(
-            [
-                {"role": "system", "content": self.instructions},
-                {"role": "user", "content": self.text},
-            ]
-        )
-        return parser.parse_result(result)


 def convert_to_model(
     result: str,
-    output_pydantic: Optional[Type[BaseModel]],
-    output_json: Optional[Type[BaseModel]],
-    agent: Any,
-    converter_cls: Optional[Type[Converter]] = None,
-) -> Union[dict, BaseModel, str]:
+    output_pydantic: type[BaseModel] | None,
+    output_json: type[BaseModel] | None,
+    agent: Agent | None = None,
+    converter_cls: type[Converter] | None = None,
+) -> dict[str, Any] | BaseModel | str:
+    """Convert a result string to a Pydantic model or JSON.
+
+    Args:
+        result: The result string to convert.
+        output_pydantic: The Pydantic model class to convert to.
+        output_json: The Pydantic model class to convert to JSON.
+        agent: The agent instance.
+        converter_cls: The converter class to use.
+
+    Returns:
+        The converted result as a dict, BaseModel, or original string.
+    """
     model = output_pydantic or output_json
     if model is None:
         return result
     try:
         escaped_result = json.dumps(json.loads(result, strict=False))
-        return validate_model(escaped_result, model, bool(output_json))
+        return validate_model(
+            result=escaped_result, model=model, is_json_output=bool(output_json)
+        )
     except json.JSONDecodeError:
         return handle_partial_json(
-            result, model, bool(output_json), agent, converter_cls
+            result=result,
+            model=model,
+            is_json_output=bool(output_json),
+            agent=agent,
+            converter_cls=converter_cls,
         )

     except ValidationError:
         return handle_partial_json(
-            result, model, bool(output_json), agent, converter_cls
+            result=result,
+            model=model,
+            is_json_output=bool(output_json),
+            agent=agent,
+            converter_cls=converter_cls,
         )

     except Exception as e:
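The hunk above appends `from parse_err`, `from None`, and `from e` to the `raise ConverterError(...)` statements. As a rough illustration of what explicit exception chaining changes, here is a minimal standard-library sketch. The `ConverterError` class below is a simplified stand-in, and `parse_int_chained`, `parse_int_suppressed`, and `capture` are hypothetical helpers that do not appear in the diff.

```python
class ConverterError(Exception):
    """Simplified stand-in for the ConverterError class in the diff."""


def parse_int_chained(text: str) -> int:
    try:
        return int(text)
    except ValueError as e:
        # `from e` stores the original error on __cause__, so the traceback
        # reports it as the direct cause of the ConverterError.
        raise ConverterError(f"Failed to convert {text!r}: {e}") from e


def parse_int_suppressed(text: str) -> int:
    try:
        return int(text)
    except ValueError:
        # `from None` leaves __cause__ empty and hides the ValueError
        # from the printed traceback.
        raise ConverterError(f"Failed to convert {text!r}") from None


def capture(fn, text):
    """Call fn(text) and return the ConverterError it raises, if any."""
    try:
        fn(text)
    except ConverterError as err:
        return err
    return None


chained = capture(parse_int_chained, "abc")
suppressed = capture(parse_int_suppressed, "abc")
```

In both cases the caller only needs to catch `ConverterError`; the difference is whether the underlying error remains attached for debugging.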
@@ -146,8 +193,18 @@ def convert_to_model(


 def validate_model(
-    result: str, model: Type[BaseModel], is_json_output: bool
-) -> Union[dict, BaseModel]:
+    result: str, model: type[BaseModel], is_json_output: bool
+) -> dict[str, Any] | BaseModel:
+    """Validate and convert a JSON string to a Pydantic model or dict.
+
+    Args:
+        result: The JSON string to validate and convert.
+        model: The Pydantic model class to convert to.
+        is_json_output: Whether to return a dict (True) or Pydantic model (False).
+
+    Returns:
+        The converted result as a dict or BaseModel.
+    """
     exported_result = model.model_validate_json(result)
     if is_json_output:
         return exported_result.model_dump()
@@ -156,15 +213,27 @@ def validate_model(

 def handle_partial_json(
     result: str,
-    model: Type[BaseModel],
+    model: type[BaseModel],
     is_json_output: bool,
-    agent: Any,
-    converter_cls: Optional[Type[Converter]] = None,
-) -> Union[dict, BaseModel, str]:
-    match = re.search(r"({.*})", result, re.DOTALL)
+    agent: Agent | None,
+    converter_cls: type[Converter] | None = None,
+) -> dict[str, Any] | BaseModel | str:
+    """Handle partial JSON in a result string and convert to Pydantic model or dict.
+
+    Args:
+        result: The result string to process.
+        model: The Pydantic model class to convert to.
+        is_json_output: Whether to return a dict (True) or Pydantic model (False).
+        agent: The agent instance.
+        converter_cls: The converter class to use.
+
+    Returns:
+        The converted result as a dict, BaseModel, or original string.
+    """
+    match = _JSON_PATTERN.search(result)
     if match:
         try:
-            exported_result = model.model_validate_json(match.group(0))
+            exported_result = model.model_validate_json(match.group())
             if is_json_output:
                 return exported_result.model_dump()
             return exported_result
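The extraction step in `handle_partial_json` can be sketched with the standard library alone. The definition of `_JSON_PATTERN` is not visible in these hunks, so the sketch assumes it is a precompiled form of the removed inline call `re.search(r"({.*})", result, re.DOTALL)`. The name `extract_json_object` is hypothetical, and `json.loads` stands in for `model.model_validate_json` to keep the example free of third-party imports.

```python
from __future__ import annotations

import json
import re

# Assumption: equivalent to the inline pattern removed in the hunk above.
JSON_PATTERN = re.compile(r"({.*})", re.DOTALL)


def extract_json_object(text: str) -> dict | None:
    """Return the JSON object embedded in text, or None if none parses."""
    match = JSON_PATTERN.search(text)
    if match is None:
        return None
    try:
        # match.group() with no argument is the same as match.group(0).
        parsed = json.loads(match.group())
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) else None


reply = 'Sure, here it is:\n{"name": "Ada",\n "age": 36}\nHope that helps!'
person = extract_json_object(reply)

# The pattern is greedy: it spans from the first "{" to the last "}",
# so two separate objects in one reply do not parse as a single value.
two_objects = extract_json_object('first {"x": 1} then {"y": 2}')
```

When extraction fails like this, the function in the diff falls through to `convert_with_instructions`, which asks the LLM to redo the conversion.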
@@ -179,19 +248,43 @@ def handle_partial_json(
             )

     return convert_with_instructions(
-        result, model, is_json_output, agent, converter_cls
+        result=result,
+        model=model,
+        is_json_output=is_json_output,
+        agent=agent,
+        converter_cls=converter_cls,
     )


 def convert_with_instructions(
     result: str,
-    model: Type[BaseModel],
+    model: type[BaseModel],
     is_json_output: bool,
-    agent: Any,
-    converter_cls: Optional[Type[Converter]] = None,
-) -> Union[dict, BaseModel, str]:
+    agent: Agent | None,
+    converter_cls: type[Converter] | None = None,
+) -> dict | BaseModel | str:
+    """Convert a result string to a Pydantic model or JSON using instructions.
+
+    Args:
+        result: The result string to convert.
+        model: The Pydantic model class to convert to.
+        is_json_output: Whether to return a dict (True) or Pydantic model (False).
+        agent: The agent instance.
+        converter_cls: The converter class to use.
+
+    Returns:
+        The converted result as a dict, BaseModel, or original string.
+
+    Raises:
+        TypeError: If neither agent nor converter_cls is provided.
+
+    Notes:
+        - TODO: Fix llm typing issues, return llm should not be able to be str or None.
+    """
+    if agent is None:
+        raise TypeError("Agent must be provided if converter_cls is not specified.")
     llm = agent.function_calling_llm or agent.llm
-    instructions = get_conversion_instructions(model, llm)
+    instructions = get_conversion_instructions(model=model, llm=llm)
     converter = create_converter(
         agent=agent,
         converter_cls=converter_cls,
@@ -214,9 +307,25 @@ def convert_with_instructions(
     return exported_result


-def get_conversion_instructions(model: Type[BaseModel], llm: Any) -> str:
+def get_conversion_instructions(
+    model: type[BaseModel], llm: BaseLLM | LLM | str
+) -> str:
+    """Generate conversion instructions based on the model and LLM capabilities.
+
+    Args:
+        model: A Pydantic model class.
+        llm: The language model instance.
+
+    Returns:
+
+    """
     instructions = "Please convert the following text into valid JSON."
-    if llm and not isinstance(llm, str) and llm.supports_function_calling():
+    if (
+        llm
+        and not isinstance(llm, str)
+        and hasattr(llm, "supports_function_calling")
+        and llm.supports_function_calling()
+    ):
         model_schema = PydanticSchemaParser(model=model).get_schema()
         instructions += (
             f"\n\nOutput ONLY the valid JSON and nothing else.\n\n"
@@ -231,12 +340,45 @@ def get_conversion_instructions(model: Type[BaseModel], llm: Any) -> str:
     return instructions


+class CreateConverterKwargs(TypedDict, total=False):
+    """Keyword arguments for creating a converter.
+
+    Attributes:
+        llm: The language model instance.
+        text: The text to convert.
+        model: The Pydantic model class.
+        instructions: The conversion instructions.
+    """
+
+    llm: BaseLLM | LLM | str
+    text: str
+    model: type[BaseModel]
+    instructions: str
+
+
 def create_converter(
-    agent: Optional[Any] = None,
-    converter_cls: Optional[Type[Converter]] = None,
-    *args,
-    **kwargs,
+    agent: Agent | None = None,
+    converter_cls: type[Converter] | None = None,
+    *args: Any,
+    **kwargs: Unpack[CreateConverterKwargs],
 ) -> Converter:
+    """Create a converter instance based on the agent or provided class.
+
+    Args:
+        agent: The agent instance.
+        converter_cls: The converter class to instantiate.
+        *args: The positional arguments to pass to the converter.
+        **kwargs: The keyword arguments to pass to the converter.
+
+    Returns:
+        An instance of the specified converter class.
+
+    Raises:
+        ValueError: If neither agent nor converter_cls is provided.
+        AttributeError: If the agent does not have a 'get_output_converter' method.
+        Exception: If no converter instance is created.
+
+    """
     if agent and not converter_cls:
         if hasattr(agent, "get_output_converter"):
             converter = agent.get_output_converter(*args, **kwargs)
@@ -253,17 +395,30 @@ def create_converter(
     return converter


-def generate_model_description(model: Type[BaseModel]) -> str:
-    """
-    Generate a string description of a Pydantic model's fields and their types.
+def generate_model_description(model: type[BaseModel]) -> str:
+    """Generate a string description of a Pydantic model's fields and their types.

     This function takes a Pydantic model class and returns a string that describes
     the model's fields and their respective types. The description includes handling
     of complex types such as `Optional`, `List`, and `Dict`, as well as nested Pydantic
     models.
+
+    Args:
+        model: A Pydantic model class.
+
+    Returns:
+        A string representation of the model's fields and types.
     """

-    def describe_field(field_type):
+    def describe_field(field_type: Any) -> str:
+        """Recursively describe a field's type.
+
+        Args:
+            field_type: The type of the field to describe.
+
+        Returns:
+            A string representation of the field's type.
+        """
         origin = get_origin(field_type)
         args = get_args(field_type)

@@ -272,20 +427,18 @@ def generate_model_description(model: Type[BaseModel]) -> str:
             non_none_args = [arg for arg in args if arg is not type(None)]
             if len(non_none_args) == 1:
                 return f"Optional[{describe_field(non_none_args[0])}]"
-            else:
-                return f"Optional[Union[{', '.join(describe_field(arg) for arg in non_none_args)}]]"
-        elif origin is list:
+            return f"Optional[Union[{', '.join(describe_field(arg) for arg in non_none_args)}]]"
+        if origin is list:
             return f"List[{describe_field(args[0])}]"
-        elif origin is dict:
+        if origin is dict:
             key_type = describe_field(args[0])
             value_type = describe_field(args[1])
             return f"Dict[{key_type}, {value_type}]"
-        elif isinstance(field_type, type) and issubclass(field_type, BaseModel):
+        if isinstance(field_type, type) and issubclass(field_type, BaseModel):
             return generate_model_description(field_type)
-        elif hasattr(field_type, "__name__"):
+        if hasattr(field_type, "__name__"):
             return field_type.__name__
-        else:
-            return str(field_type)
+        return str(field_type)

     fields = model.model_fields
     field_descriptions = [
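The recursive walk in `describe_field` can be tried out without Pydantic. The sketch below is a simplified, hypothetical variant (`describe_type`) built only on `typing.get_origin` and `typing.get_args`. The union test is an assumption, since the condition guarding that branch falls outside the hunks shown, and the nested-`BaseModel` branch is omitted.

```python
from __future__ import annotations

from typing import Any, Dict, List, Optional, Union, get_args, get_origin


def describe_type(field_type: Any) -> str:
    """Return a readable description of a typing annotation."""
    origin = get_origin(field_type)
    args = get_args(field_type)

    # Assumption: only typing.Union / typing.Optional are treated as unions here.
    if origin is Union:
        non_none_args = [arg for arg in args if arg is not type(None)]
        if len(non_none_args) == 1:
            return f"Optional[{describe_type(non_none_args[0])}]"
        joined = ", ".join(describe_type(arg) for arg in non_none_args)
        return f"Optional[Union[{joined}]]"
    if origin is list:
        return f"List[{describe_type(args[0])}]"
    if origin is dict:
        return f"Dict[{describe_type(args[0])}, {describe_type(args[1])}]"
    if hasattr(field_type, "__name__"):
        return field_type.__name__
    return str(field_type)


simple = describe_type(Optional[int])
nested = describe_type(Dict[str, List[int]])
mixed = describe_type(Optional[Union[int, str]])
```

As in the diff, every union is rendered with an `Optional[...]` wrapper whether or not `None` is one of its members, and the early `return` statements make the `elif`/`else` ladder unnecessary.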
@@ -1 +1 @@
-"""Crew-specific utilities."""
+"""Crew-specific utilities."""
@@ -1,16 +1,16 @@
 """Context management utilities for tracking crew and task execution context using OpenTelemetry baggage."""

-from typing import Optional
+from typing import cast

 from opentelemetry import baggage

 from crewai.utilities.crew.models import CrewContext


-def get_crew_context() -> Optional[CrewContext]:
+def get_crew_context() -> CrewContext | None:
     """Get the current crew context from OpenTelemetry baggage.

     Returns:
         CrewContext instance containing crew context information, or None if no context is set
     """
-    return baggage.get_baggage("crew_context")
+    return cast(CrewContext | None, baggage.get_baggage("crew_context"))
@@ -1,16 +1,17 @@
 """Models for crew-related data structures."""

-from typing import Optional
-
 from pydantic import BaseModel, Field


 class CrewContext(BaseModel):
-    """Model representing crew context information."""
+    """Model representing crew context information.

-    id: Optional[str] = Field(
-        default=None, description="Unique identifier for the crew"
-    )
-    key: Optional[str] = Field(
+    Attributes:
+        id: Unique identifier for the crew.
+        key: Optional crew key/name for identification.
+    """
+
+    id: str | None = Field(default=None, description="Unique identifier for the crew")
+    key: str | None = Field(
         default=None, description="Optional crew key/name for identification"
     )
Some files were not shown because too many files have changed in this diff.