Compare commits

...

19 Commits

Author SHA1 Message Date
Lucas Gomide
ae8e52b484 wip 2025-09-26 09:11:31 -03:00
Greyson LaLonde
e070c1400c feat: update pydantic, add pydantic-settings, migrate to dependency-groups
Some checks failed
Build uv cache / build-cache (3.10) (push) Has been cancelled
Build uv cache / build-cache (3.11) (push) Has been cancelled
Build uv cache / build-cache (3.12) (push) Has been cancelled
Build uv cache / build-cache (3.13) (push) Has been cancelled
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Update Test Durations / update-durations (3.10) (push) Has been cancelled
Update Test Durations / update-durations (3.11) (push) Has been cancelled
Update Test Durations / update-durations (3.12) (push) Has been cancelled
Update Test Durations / update-durations (3.13) (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
- Add pydantic-settings>=2.10.1 dependency for configuration management
- Update pydantic to 2.11.9 and python-dotenv to 1.1.1
- Migrate from deprecated tool.uv.dev-dependencies to dependency-groups.dev format
- Remove unnecessary dev dependencies: pillow, cairosvg
- Update all dev tooling to latest versions
- Remove duplicate python-dotenv from dev dependencies
2025-09-24 14:42:18 -04:00
Greyson LaLonde
6537e3737d fix: correct directory name in quickstart documentation 2025-09-24 11:41:33 -04:00
Greyson LaLonde
346faf229f feat: add pydantic-compatible import validation and deprecate old utilities 2025-09-24 11:36:02 -04:00
Lorenze Jay
a0b757a12c Lorenze/traces mark as failed (#3586)
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Update Test Durations / update-durations (3.10) (push) Has been cancelled
Update Test Durations / update-durations (3.11) (push) Has been cancelled
Update Test Durations / update-durations (3.12) (push) Has been cancelled
Update Test Durations / update-durations (3.13) (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
* marking trace batch as failed if its failed

* fix test
2025-09-23 22:02:27 -07:00
Greyson LaLonde
1dbe8aab52 fix: add batch_size support to prevent embedder token limit errors
- add batch_size field to baseragconfig (default=100)  
- update chromadb/qdrant clients and factories to use batch_size  
- extract and filter batch_size from embedder config in knowledgestorage  
- fix large csv files exceeding embedder token limits (#3574)  
- remove unneeded conditional for type  

Co-authored-by: Vini Brasil <vini@hey.com>
2025-09-24 00:05:43 -04:00
Greyson LaLonde
4ac65eb0a6 fix: support nested config format for embedder configuration
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Update Test Durations / update-durations (3.10) (push) Has been cancelled
Update Test Durations / update-durations (3.11) (push) Has been cancelled
Update Test Durations / update-durations (3.12) (push) Has been cancelled
Update Test Durations / update-durations (3.13) (push) Has been cancelled
- support nested config format with embedderconfig typeddict  
- fix parsing for model/model_name compatibility  
- add validation, typing_extensions, and improved type hints  
- enhance embedding factory with env var injection and provider support  
- add tests for openai, azure, and all embedding providers  
- misc fixes: test file rename, updated mocking patterns
2025-09-23 11:57:46 -04:00
Greyson LaLonde
3e97393f58 chore: improve typing and consolidate utilities
- add type annotations across utility modules  
- refactor printer system, agent utils, and imports for consistency  
- remove unused modules, constants, and redundant patterns  
- improve runtime type checks, exception handling, and guardrail validation  
- standardize warning suppression and logging utilities  
- fix llm typing, threading/typing edge cases, and test behavior
2025-09-23 11:33:46 -04:00
Heitor Carvalho
34bed359a6 feat: add crewai uv wrapper for uv commands (#3581) 2025-09-23 10:55:15 -04:00
Tony Kipkemboi
feeed505bb docs(changelog): add releases 0.193.2, 0.193.1, 0.193.0, 0.186.1, 0.186.0 across en/ko/pt-BR (#3577)
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
2025-09-22 16:19:55 -07:00
Greyson LaLonde
cb0efd05b4 chore: fix ruff linting issues in tools module
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Update Test Durations / update-durations (3.13) (push) Has been cancelled
Update Test Durations / update-durations (3.10) (push) Has been cancelled
Update Test Durations / update-durations (3.11) (push) Has been cancelled
Update Test Durations / update-durations (3.12) (push) Has been cancelled
linting, args_schema default, and validator check
2025-09-22 13:13:23 -04:00
Greyson LaLonde
db5f565dea fix: apply ruff linting fixes to tasks module 2025-09-22 13:09:53 -04:00
Greyson LaLonde
58413b663a chore: fix ruff linting issues in rag module
linting, list embedding handling, and test update
2025-09-22 13:06:22 -04:00
Greyson LaLonde
37636f0dd7 chore: fix ruff linting and mypy issues in flow module 2025-09-22 13:03:06 -04:00
Greyson LaLonde
0e370593f1 chore: resolve all ruff and mypy issues in experimental module
resolve linting, typing, and import issues; update Okta test
2025-09-22 12:56:28 -04:00
Vini Brasil
aa8dc9d77f Add source to LLM Guardrail events (#3572)
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Update Test Durations / update-durations (3.10) (push) Has been cancelled
Update Test Durations / update-durations (3.11) (push) Has been cancelled
Update Test Durations / update-durations (3.12) (push) Has been cancelled
Update Test Durations / update-durations (3.13) (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
This commit adds the source attribute to LLM Guardrail event calls to
identify the Lite Agent or Task that executed the guardrail.
2025-09-22 11:58:00 +09:00
Jonathan Hill
9c1096dbdc fix: Make 'ready' parameter optional in _create_reasoning_plan function (#3561)
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Update Test Durations / update-durations (3.10) (push) Has been cancelled
Update Test Durations / update-durations (3.11) (push) Has been cancelled
Update Test Durations / update-durations (3.12) (push) Has been cancelled
Update Test Durations / update-durations (3.13) (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
* fix: Make 'ready' parameter optional in _create_reasoning_plan function

This PR fixes Issue #3466 where the _create_reasoning_plan function was missing
the 'ready' parameter when called by the LLM. The fix makes the 'ready' parameter
optional with a default value of False, which allows the function to be called
with only the 'plan' argument.

Fixes #3466

* Change default value of 'ready' parameter to True

---------

Co-authored-by: João Moura <joaomdmoura@gmail.com>
2025-09-20 22:57:18 -03:00
João Moura
47044450c0 Adding fallback to crew settings (#3562)
* Adding fallback to crew settings

* fix: resolve ruff and mypy issues in cli/config.py

---------

Co-authored-by: Greyson Lalonde <greyson.r.lalonde@gmail.com>
2025-09-20 22:54:36 -03:00
João Moura
0ee438c39d fix version (#3557) 2025-09-20 17:14:28 -03:00
150 changed files with 7655 additions and 4370 deletions

View File

@@ -0,0 +1,135 @@
# CrewAI CLI Trigger Feature Implementation
## Overview
Successfully implemented the trigger functionality for CrewAI CLI as requested, adding two main commands:
- `crewai trigger list` - Lists all triggers grouped by provider
- `crewai trigger <app/trigger_name>` - Runs a crew with the specified trigger payload
## Implementation Details
### 1. Extended PlusAPI Client (`src/crewai/cli/plus_api.py`)
- Added `TRIGGERS_RESOURCE = "/v1/triggers"` endpoint constant
- Implemented `list_triggers()` method for GET `/v1/triggers`
- Implemented `get_trigger_sample_payload(trigger_identification)` method for POST `/v1/triggers/sample_payload`
### 2. Created TriggerCommand Class (`src/crewai/cli/trigger_command.py`)
- Inherits from `BaseCommand` and `PlusAPIMixin` for proper authentication
- Implements `list_triggers()` method with:
- Rich table display grouped by provider
- Comprehensive error handling for network issues, authentication, etc.
- User-friendly messages and styling
- Implements `run_trigger(trigger_identification)` method with:
- Trigger identification format validation (`app/trigger_name`)
- Sample payload retrieval from API
- Dynamic crew/flow execution with trigger payload injection
- Temporary script generation and cleanup
- Robust error handling and validation
### 3. Integrated CLI Commands (`src/crewai/cli/cli.py`)
- Added import for `TriggerCommand`
- Implemented `@crewai.command()` decorator for `trigger` command
- Supports both `crewai trigger list` and `crewai trigger <app/trigger_name>` syntax
- Proper argument parsing and command routing
### 4. Key Features
#### Trigger Listing
- Fetches triggers from `/v1/triggers` endpoint
- Displays triggers in a formatted table grouped by provider
- Shows trigger ID and description for each trigger
- Provides usage instructions
#### Trigger Execution
- Validates trigger identification format
- Fetches sample payload from `/v1/triggers/sample_payload` endpoint
- Detects project type (crew vs flow) from `pyproject.toml`
- Generates appropriate execution script with trigger payload injection
- Executes crew/flow with `uv run python` command
- Adds trigger payload to inputs as `crewai_trigger_payload`
- Handles cleanup of temporary files
#### Error Handling
- Network connectivity issues
- Authentication failures (401)
- Authorization issues (403)
- Trigger not found (404)
- Invalid project structure
- Subprocess execution errors
- Comprehensive user feedback with actionable suggestions
### 5. Usage Examples
```bash
# List all available triggers
crewai trigger list
# Run a specific trigger
crewai trigger github/pull_request_opened
crewai trigger slack/message_received
crewai trigger webhook/user_signup
```
### 6. API Integration Points
#### CrewAI Client → Rails App
- GET `/v1/triggers` - Returns triggers grouped by provider
- POST `/v1/triggers/sample_payload` with `{"trigger_identification": "app/trigger_name"}`
#### Expected Response Format
```json
{
"github": {
"github/pull_request_opened": {
"description": "Triggered when a pull request is opened"
},
"github/issue_created": {
"description": "Triggered when an issue is created"
}
},
"slack": {
"slack/message_received": {
"description": "Triggered when a message is received"
}
}
}
```
### 7. Crew/Flow Integration
The trigger payload is automatically injected into the crew/flow inputs as `crewai_trigger_payload`, allowing crews to access trigger data:
```python
# In crew/flow code
def my_crew():
crew = Crew(...)
result = crew.kickoff(inputs=inputs) # inputs will contain 'crewai_trigger_payload'
return result
```
### 8. Dependencies
- `click` - CLI framework
- `rich` - Enhanced terminal output
- `requests` - HTTP client
- Existing CrewAI CLI infrastructure (authentication, configuration, etc.)
## Testing
- All imports work correctly
- CLI command structure is properly implemented
- Error handling is comprehensive
- Code follows CrewAI patterns and conventions
## Next Steps for Backend Implementation
### Rails App Requirements
1. Add `GET /v1/triggers` endpoint
2. Add `POST /v1/triggers/sample_payload` endpoint
3. Implement integration service method `summarize_triggers`
4. Each provider service must implement:
- `list_triggers()` method
- `get_sample_payload(trigger_identification)` method
### CrewAI OAuth Requirements
1. Implement endpoint that returns sample payload for trigger identification
2. Ensure trigger data format matches expected structure
The CLI implementation is complete and ready for integration with the backend services.

View File

@@ -5,6 +5,82 @@ icon: "clock"
mode: "wide"
---
<Update label="Sep 20, 2025">
## v0.193.2
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.193.2)
## What's Changed
- Updated pyproject templates to use the right version
</Update>
<Update label="Sep 20, 2025">
## v0.193.1
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.193.1)
## What's Changed
- Series of minor fixes and linter improvements
</Update>
<Update label="Sep 19, 2025">
## v0.193.0
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.193.0)
## Core Improvements & Fixes
- Fixed handling of the `model` parameter during OpenAI adapter initialization
- Resolved test duration cache issues in CI workflows
- Fixed flaky test related to repeated tool usage by agents
- Added missing event exports to `__init__.py` for consistent module behavior
- Dropped message storage from metadata in Mem0 to reduce bloat
- Fixed L2 distance metric support for backward compatibility in vector search
## New Features & Enhancements
- Introduced thread-safe platform context management
- Added test duration caching for optimized `pytest-split` runs
- Added ephemeral trace improvements for better trace control
- Made search parameters for RAG, knowledge, and memory fully configurable
- Enabled ChromaDB to use OpenAI API for embedding functions
- Added deeper observability tools for user-level insights
- Unified RAG storage system with instance-specific client support
## Documentation & Guides
- Updated `RagTool` references to reflect CrewAI native RAG implementation
- Improved internal docs for `langgraph` and `openai` agent adapters with type annotations and docstrings
</Update>
<Update label="Sep 11, 2025">
## v0.186.1
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.186.1)
## What's Changed
- Fixed version not being found and silently failing reversion
- Bumped CrewAI version to 0.186.1 and updated dependencies in the CLI
</Update>
<Update label="Sep 10, 2025">
## v0.186.0
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.186.0)
## What's Changed
- Refer to the GitHub release notes for detailed changes
</Update>
<Update label="Sep 04, 2025">
## v0.177.0

View File

@@ -404,6 +404,10 @@ crewai config reset
After resetting configuration, re-run `crewai login` to authenticate again.
</Tip>
<Tip>
CrewAI CLI handles authentication to the Tool Repository automatically when adding packages to your project. Just append `crewai` before any `uv` command to use it. E.g. `crewai uv add requests`. For more information, see [Tool Repository](https://docs.crewai.com/enterprise/features/tool-repository) docs.
</Tip>
<Note>
Configuration settings are stored in `~/.config/crewai/settings.json`. Some settings like organization name and UUID are read-only and managed through authentication and organization commands. Tool repository related settings are hidden and cannot be set directly by users.
</Note>

View File

@@ -52,6 +52,36 @@ researcher = Agent(
)
```
## Adding other packages after installing a tool
After installing a tool from the CrewAI Enterprise Tool Repository, you need to use the `crewai uv` command to add other packages to your project.
Using pure `uv` commands will fail due to authentication to tool repository being handled by the CLI. By using the `crewai uv` command, you can add other packages to your project without having to worry about authentication.
Any `uv` command can be used with the `crewai uv` command, making it a powerful tool for managing your project's dependencies without the hassle of managing authentication through environment variables or other methods.
Say that you have installed a custom tool from the CrewAI Enterprise Tool Repository called "my-tool":
```bash
crewai tool install my-tool
```
And now you want to add another package to your project, you can use the following command:
```bash
crewai uv add requests
```
Other commands like `uv sync` or `uv remove` can also be used with the `crewai uv` command:
```bash
crewai uv sync
```
```bash
crewai uv remove requests
```
This will add the package to your project and update `pyproject.toml` accordingly.
## Creating and Publishing Tools
To create a new tool project:

View File

@@ -27,7 +27,7 @@ Follow the steps below to get Crewing! 🚣‍♂️
<Step title="Navigate to your new crew project">
<CodeGroup>
```shell Terminal
cd latest-ai-development
cd latest_ai_development
```
</CodeGroup>
</Step>

View File

@@ -5,6 +5,82 @@ icon: "clock"
mode: "wide"
---
<Update label="2025년 9월 20일">
## v0.193.2
[GitHub 릴리스 보기](https://github.com/crewAIInc/crewAI/releases/tag/0.193.2)
## 변경 사항
- 올바른 버전을 사용하도록 pyproject 템플릿 업데이트
</Update>
<Update label="2025년 9월 20일">
## v0.193.1
[GitHub 릴리스 보기](https://github.com/crewAIInc/crewAI/releases/tag/0.193.1)
## 변경 사항
- 일련의 사소한 수정 및 린터 개선
</Update>
<Update label="2025년 9월 19일">
## v0.193.0
[GitHub 릴리스 보기](https://github.com/crewAIInc/crewAI/releases/tag/0.193.0)
## 핵심 개선 사항 및 수정 사항
- OpenAI 어댑터 초기화 중 `model` 매개변수 처리 수정
- CI 워크플로에서 테스트 소요 시간 캐시 문제 해결
- 에이전트의 반복 도구 사용과 관련된 불안정한 테스트 수정
- 일관된 모듈 동작을 위해 누락된 이벤트 내보내기를 `__init__.py`에 추가
- 메타데이터 부하를 줄이기 위해 Mem0에서 메시지 저장 제거
- 벡터 검색의 하위 호환성을 위해 L2 거리 메트릭 지원 수정
## 새로운 기능 및 향상 사항
- 스레드 안전한 플랫폼 컨텍스트 관리 도입
- `pytest-split` 실행 최적화를 위한 테스트 소요 시간 캐싱 추가
- 더 나은 추적 제어를 위한 일시적(trace) 개선
- RAG, 지식, 메모리 검색 매개변수를 완전 구성 가능하게 변경
- ChromaDB가 임베딩 함수에 OpenAI API를 사용할 수 있도록 지원
- 사용자 수준 인사이트를 위한 심화된 관찰 가능성 도구 추가
- 인스턴스별 클라이언트를 지원하는 통합 RAG 스토리지 시스템
## 문서 및 가이드
- CrewAI 네이티브 RAG 구현을 반영하도록 `RagTool` 참조 업데이트
- 타입 주석과 도크스트링을 포함해 `langgraph` 및 `openai` 에이전트 어댑터 내부 문서 개선
</Update>
<Update label="2025년 9월 11일">
## v0.186.1
[GitHub 릴리스 보기](https://github.com/crewAIInc/crewAI/releases/tag/0.186.1)
## 변경 사항
- 버전을 찾지 못해 조용히 되돌리는(reversion) 문제 수정
- CLI에서 CrewAI 버전을 0.186.1로 올리고 의존성 업데이트
</Update>
<Update label="2025년 9월 10일">
## v0.186.0
[GitHub 릴리스 보기](https://github.com/crewAIInc/crewAI/releases/tag/0.186.0)
## 변경 사항
- 자세한 변경 사항은 GitHub 릴리스 노트를 참조하세요
</Update>
<Update label="2025년 9월 4일">
## v0.177.0

View File

@@ -27,7 +27,7 @@ mode: "wide"
<Step title="새로운 crew 프로젝트로 이동하기">
<CodeGroup>
```shell Terminal
cd latest-ai-development
cd latest_ai_development
```
</CodeGroup>
</Step>

View File

@@ -5,6 +5,82 @@ icon: "clock"
mode: "wide"
---
<Update label="20 set 2025">
## v0.193.2
[Ver release no GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.193.2)
## O que Mudou
- Atualizados templates do pyproject para usar a versão correta
</Update>
<Update label="20 set 2025">
## v0.193.1
[Ver release no GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.193.1)
## O que Mudou
- Série de pequenas correções e melhorias de linter
</Update>
<Update label="19 set 2025">
## v0.193.0
[Ver release no GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.193.0)
## Melhorias e Correções Principais
- Corrigido manuseio do parâmetro `model` durante a inicialização do adaptador OpenAI
- Resolvidos problemas de cache da duração de testes nos fluxos de CI
- Corrigido teste instável relacionado ao uso repetido de ferramentas pelos agentes
- Adicionadas exportações de eventos ausentes no `__init__.py` para comportamento consistente do módulo
- Removido armazenamento de mensagem dos metadados no Mem0 para reduzir inchaço
- Corrigido suporte à métrica de distância L2 para compatibilidade retroativa na busca vetorial
## Novos Recursos e Melhorias
- Introduzida gestão de contexto de plataforma com segurança de threads
- Adicionado cache da duração de testes para execuções otimizadas do `pytest-split`
- Melhorias de traces efêmeros para melhor controle de rastreamento
- Parâmetros de busca para RAG, conhecimento e memória totalmente configuráveis
- Habilitado ChromaDB para usar a OpenAI API para funções de embedding
- Adicionadas ferramentas de observabilidade mais profundas para insights ao nível do usuário
- Sistema de armazenamento RAG unificado com suporte a cliente específico por instância
## Documentação e Guias
- Atualizadas referências do `RagTool` para refletir a implementação nativa de RAG do CrewAI
- Melhorada documentação interna para adaptadores de agente `langgraph` e `openai` com anotações de tipo e docstrings
</Update>
<Update label="11 set 2025">
## v0.186.1
[Ver release no GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.186.1)
## O que Mudou
- Corrigida falha silenciosa de reversão quando a versão não era encontrada
- Versão do CrewAI atualizada para 0.186.1 e dependências do CLI atualizadas
</Update>
<Update label="10 set 2025">
## v0.186.0
[Ver release no GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.186.0)
## O que Mudou
- Consulte as notas de lançamento no GitHub para detalhes completos
</Update>
<Update label="04 set 2025">
## v0.177.0

View File

@@ -27,7 +27,7 @@ Siga os passos abaixo para começar a tripular! 🚣‍♂️
<Step title="Navegue até o novo projeto da sua tripulação">
<CodeGroup>
```shell Terminal
cd latest-ai-development
cd latest_ai_development
```
</CodeGroup>
</Step>

View File

@@ -9,7 +9,7 @@ authors = [
]
dependencies = [
# Core Dependencies
"pydantic>=2.4.2",
"pydantic>=2.11.9",
"openai>=1.13.3",
"litellm==1.74.9",
"instructor>=1.3.3",
@@ -27,7 +27,7 @@ dependencies = [
"openpyxl>=3.1.5",
"pyvis>=0.3.2",
# Authentication and Security
"python-dotenv>=1.0.0",
"python-dotenv>=1.1.1",
"pyjwt>=2.9.0",
# Configuration and Utils
"click>=8.1.7",
@@ -40,6 +40,7 @@ dependencies = [
"blinker>=1.9.0",
"json5>=0.10.0",
"portalocker==2.7.0",
"pydantic-settings>=2.10.1",
]
[project.urls]
@@ -72,23 +73,20 @@ qdrant = [
"qdrant-client[fastembed]>=1.14.3",
]
[tool.uv]
dev-dependencies = [
"ruff>=0.12.11",
"mypy>=1.17.1",
[dependency-groups]
dev = [
"ruff>=0.13.1",
"mypy>=1.18.2",
"pre-commit>=4.3.0",
"bandit>=1.8.6",
"pillow>=10.2.0",
"cairosvg>=2.7.1",
"pytest>=8.0.0",
"python-dotenv>=1.0.0",
"pytest-asyncio>=0.23.7",
"pytest-subprocess>=1.5.2",
"pytest-recording>=0.13.2",
"pytest-randomly>=3.16.0",
"pytest-timeout>=2.3.1",
"pytest-xdist>=3.6.1",
"pytest-split>=0.9.0",
"pytest>=8.4.2",
"pytest-asyncio>=1.2.0",
"pytest-subprocess>=1.5.3",
"pytest-recording>=0.13.4",
"pytest-randomly>=4.0.1",
"pytest-timeout>=2.4.0",
"pytest-xdist>=3.8.0",
"pytest-split>=0.10.0",
"types-requests==2.32.*",
"types-pyyaml==6.0.*",
"types-regex==2024.11.6.*",

View File

@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
_suppress_pydantic_deprecation_warnings()
__version__ = "0.193.1"
__version__ = "0.193.2"
_telemetry_submitted = False

View File

@@ -18,7 +18,7 @@ from crewai.agents.constants import (
MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
UNABLE_TO_REPAIR_JSON_RESULTS,
)
from crewai.utilities import I18N
from crewai.utilities.i18n import I18N
_I18N = I18N()

View File

@@ -1,3 +1,5 @@
import os
import subprocess
from importlib.metadata import version as get_version
import click
@@ -8,6 +10,7 @@ from crewai.cli.create_crew import create_crew
from crewai.cli.create_flow import create_flow
from crewai.cli.crew_chat import run_chat
from crewai.cli.settings.main import SettingsCommand
from crewai.cli.utils import build_env_with_tool_repository_credentials, read_toml
from crewai.memory.storage.kickoff_task_outputs_storage import (
KickoffTaskOutputsSQLiteStorage,
)
@@ -25,6 +28,7 @@ from .reset_memories_command import reset_memories_command
from .run_crew import run_crew
from .tools.main import ToolCommand
from .train_crew import train_crew
from .trigger_command import TriggerCommand
from .update_crew import update_crew
@@ -34,6 +38,46 @@ def crewai():
"""Top-level command group for crewai."""
@crewai.command(
name="uv",
context_settings=dict(
ignore_unknown_options=True,
),
)
@click.argument("uv_args", nargs=-1, type=click.UNPROCESSED)
def uv(uv_args):
"""A wrapper around uv commands that adds custom tool authentication through env vars."""
env = os.environ.copy()
try:
pyproject_data = read_toml()
sources = pyproject_data.get("tool", {}).get("uv", {}).get("sources", {})
for source_config in sources.values():
if isinstance(source_config, dict):
index = source_config.get("index")
if index:
index_env = build_env_with_tool_repository_credentials(index)
env.update(index_env)
except (FileNotFoundError, KeyError) as e:
raise SystemExit(
"Error. A valid pyproject.toml file is required. Check that a valid pyproject.toml file exists in the current directory."
) from e
except Exception as e:
raise SystemExit(f"Error: {e}") from e
try:
subprocess.run( # noqa: S603
["uv", *uv_args], # noqa: S607
capture_output=False,
env=env,
text=True,
check=True,
)
except subprocess.CalledProcessError as e:
click.secho(f"uv command failed with exit code {e.returncode}", fg="red")
raise SystemExit(e.returncode) from e
@crewai.command()
@click.argument("type", type=click.Choice(["crew", "flow"]))
@click.argument("name")
@@ -239,11 +283,6 @@ def deploy():
"""Deploy the Crew CLI group."""
@crewai.group()
def tool():
"""Tool Repository related commands."""
@deploy.command(name="create")
@click.option("-y", "--yes", is_flag=True, help="Skip the confirmation prompt")
def deploy_create(yes: bool):
@@ -291,6 +330,11 @@ def deploy_remove(uuid: str | None):
deploy_cmd.remove_crew(uuid=uuid)
@crewai.group()
def tool():
"""Tool Repository related commands."""
@tool.command(name="create")
@click.argument("handle")
def tool_create(handle: str):
@@ -430,5 +474,18 @@ def config_reset():
config_command.reset_all_settings()
@crewai.command()
@click.argument("action_or_trigger")
def trigger(action_or_trigger: str):
"""Trigger management. Use 'list' to list triggers or provide trigger identification to run."""
trigger_cmd = TriggerCommand()
if action_or_trigger == "list":
trigger_cmd.list_triggers()
else:
# Assume it's a trigger identification
trigger_cmd.run_trigger(action_or_trigger)
if __name__ == "__main__":
crewai()

View File

@@ -1,4 +1,6 @@
import json
import tempfile
from logging import getLogger
from pathlib import Path
from pydantic import BaseModel, Field
@@ -12,8 +14,48 @@ from crewai.cli.constants import (
)
from crewai.cli.shared.token_manager import TokenManager
logger = getLogger(__name__)
DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json"
def get_writable_config_path() -> Path | None:
"""
Find a writable location for the config file with fallback options.
Tries in order:
1. Default: ~/.config/crewai/settings.json
2. Temp directory: /tmp/crewai_settings.json (or OS equivalent)
3. Current directory: ./crewai_settings.json
4. In-memory only (returns None)
Returns:
Path object for writable config location, or None if no writable location found
"""
fallback_paths = [
DEFAULT_CONFIG_PATH, # Default location
Path(tempfile.gettempdir()) / "crewai_settings.json", # Temporary directory
Path.cwd() / "crewai_settings.json", # Current working directory
]
for config_path in fallback_paths:
try:
config_path.parent.mkdir(parents=True, exist_ok=True)
test_file = config_path.parent / ".crewai_write_test"
try:
test_file.write_text("test")
test_file.unlink() # Clean up test file
logger.info(f"Using config path: {config_path}")
return config_path
except Exception: # noqa: S112
continue
except Exception: # noqa: S112
continue
return None
# Settings that are related to the user's account
USER_SETTINGS_KEYS = [
"tool_repository_username",
@@ -93,16 +135,32 @@ class Settings(BaseModel):
default=DEFAULT_CLI_SETTINGS["oauth2_domain"],
)
def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data):
"""Load Settings from config path"""
config_path.parent.mkdir(parents=True, exist_ok=True)
def __init__(self, config_path: Path | None = None, **data):
"""Load Settings from config path with fallback support"""
if config_path is None:
config_path = get_writable_config_path()
# If config_path is None, we're in memory-only mode
if config_path is None:
merged_data = {**data}
# Dummy path for memory-only mode
super().__init__(config_path=Path("/dev/null"), **merged_data)
return
try:
config_path.parent.mkdir(parents=True, exist_ok=True)
except Exception:
merged_data = {**data}
# Dummy path for memory-only mode
super().__init__(config_path=Path("/dev/null"), **merged_data)
return
file_data = {}
if config_path.is_file():
try:
with config_path.open("r") as f:
file_data = json.load(f)
except json.JSONDecodeError:
except Exception:
file_data = {}
merged_data = {**file_data, **data}
@@ -122,15 +180,22 @@ class Settings(BaseModel):
def dump(self) -> None:
"""Save current settings to settings.json"""
if self.config_path.is_file():
with self.config_path.open("r") as f:
existing_data = json.load(f)
else:
existing_data = {}
if str(self.config_path) == "/dev/null":
return
updated_data = {**existing_data, **self.model_dump(exclude_unset=True)}
with self.config_path.open("w") as f:
json.dump(updated_data, f, indent=4)
try:
if self.config_path.is_file():
with self.config_path.open("r") as f:
existing_data = json.load(f)
else:
existing_data = {}
updated_data = {**existing_data, **self.model_dump(exclude_unset=True)}
with self.config_path.open("w") as f:
json.dump(updated_data, f, indent=4)
except Exception: # noqa: S110
pass
def _reset_user_settings(self) -> None:
"""Reset all user settings to default values"""

View File

@@ -18,6 +18,7 @@ class PlusAPI:
AGENTS_RESOURCE = "/crewai_plus/api/v1/agents"
TRACING_RESOURCE = "/crewai_plus/api/v1/tracing"
EPHEMERAL_TRACING_RESOURCE = "/crewai_plus/api/v1/tracing/ephemeral"
TRIGGERS_RESOURCE = "/v1/triggers"
def __init__(self, api_key: str) -> None:
self.api_key = api_key
@@ -166,3 +167,25 @@ class PlusAPI:
json=payload,
timeout=30,
)
def mark_trace_batch_as_failed(
self, trace_batch_id: str, error_message: str
) -> requests.Response:
return self._make_request(
"PATCH",
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}",
json={"status": "failed", "failure_reason": error_message},
timeout=30,
)
def list_triggers(self) -> requests.Response:
"""List all triggers from the current user."""
return self._make_request("GET", self.TRIGGERS_RESOURCE)
def get_trigger_sample_payload(self, trigger_identification: str) -> requests.Response:
"""Get sample payload for a trigger identification."""
return self._make_request(
"POST",
f"{self.TRIGGERS_RESOURCE}/sample_payload",
json={"trigger_identification": trigger_identification}
)

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]>=0.193.0,<1.0.0"
"crewai[tools]>=0.193.2,<1.0.0"
]
[project.scripts]

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]>=0.193.0,<1.0.0",
"crewai[tools]>=0.193.2,<1.0.0",
]
[project.scripts]

View File

@@ -5,7 +5,7 @@ description = "Power up your crews with {{folder_name}}"
readme = "README.md"
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]>=0.193.0"
"crewai[tools]>=0.193.2"
]
[tool.crewai]

View File

@@ -12,6 +12,7 @@ from crewai.cli import git
from crewai.cli.command import BaseCommand, PlusAPIMixin
from crewai.cli.config import Settings
from crewai.cli.utils import (
build_env_with_tool_repository_credentials,
extract_available_exports,
get_project_description,
get_project_name,
@@ -42,8 +43,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
if project_root.exists():
click.secho(f"Folder {folder_name} already exists.", fg="red")
raise SystemExit
else:
os.makedirs(project_root)
os.makedirs(project_root)
click.secho(f"Creating custom tool {folder_name}...", fg="green", bold=True)
@@ -56,7 +56,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
os.chdir(project_root)
try:
self.login()
subprocess.run(["git", "init"], check=True)
subprocess.run(["git", "init"], check=True) # noqa: S607
console.print(
f"[green]Created custom tool [bold]{folder_name}[/bold]. Run [bold]cd {project_root}[/bold] to start working.[/green]"
)
@@ -76,10 +76,10 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
raise SystemExit()
project_name = get_project_name(require=True)
assert isinstance(project_name, str)
assert isinstance(project_name, str) # noqa: S101
project_version = get_project_version(require=True)
assert isinstance(project_version, str)
assert isinstance(project_version, str) # noqa: S101
project_description = get_project_description(require=False)
encoded_tarball = None
@@ -94,8 +94,8 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
self._print_current_organization()
with tempfile.TemporaryDirectory() as temp_build_dir:
subprocess.run(
["uv", "build", "--sdist", "--out-dir", temp_build_dir],
subprocess.run( # noqa: S603
["uv", "build", "--sdist", "--out-dir", temp_build_dir], # noqa: S607
check=True,
capture_output=False,
)
@@ -146,7 +146,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
style="bold red",
)
raise SystemExit
elif get_response.status_code != 200:
if get_response.status_code != 200:
console.print(
"Failed to get tool details. Please try again later.", style="bold red"
)
@@ -196,10 +196,10 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
else:
add_package_command.extend(["--index", index, tool_handle])
add_package_result = subprocess.run(
add_package_result = subprocess.run( # noqa: S603
add_package_command,
capture_output=False,
env=self._build_env_with_credentials(repository_handle),
env=build_env_with_tool_repository_credentials(repository_handle),
text=True,
check=True,
)
@@ -221,20 +221,6 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
)
raise SystemExit
def _build_env_with_credentials(self, repository_handle: str):
repository_handle = repository_handle.upper().replace("-", "_")
settings = Settings()
env = os.environ.copy()
env[f"UV_INDEX_{repository_handle}_USERNAME"] = str(
settings.tool_repository_username or ""
)
env[f"UV_INDEX_{repository_handle}_PASSWORD"] = str(
settings.tool_repository_password or ""
)
return env
def _print_current_organization(self) -> None:
settings = Settings()
if settings.org_uuid:

View File

@@ -0,0 +1,315 @@
import sys
import os
import subprocess
from typing import Dict, Any
import click
import requests
from rich.console import Console
from rich.table import Table
from rich.text import Text
from crewai.cli.command import BaseCommand, PlusAPIMixin
from crewai.telemetry.telemetry import Telemetry
console = Console()
class TriggerCommand(BaseCommand, PlusAPIMixin):
"""Command handler for trigger-related operations."""
def __init__(self):
"""Initialize the trigger command with telemetry and API client."""
self._telemetry = Telemetry()
super().__init__()
PlusAPIMixin.__init__(self, self._telemetry)
def list_triggers(self) -> None:
"""List all triggers grouped by provider name."""
try:
console.print("Fetching triggers from CrewAI API...", style="blue")
# Fetch triggers from API
response = self.plus_api_client.list_triggers()
self._validate_response(response)
triggers_data = response.json()
if not triggers_data:
console.print(
"No triggers found for the current user.", style="yellow"
)
return
# Display triggers grouped by provider
self._display_triggers(triggers_data)
except requests.exceptions.ConnectionError:
console.print(
"Failed to connect to CrewAI API. Please check your internet connection.",
style="bold red"
)
raise SystemExit(1)
except requests.exceptions.Timeout:
console.print(
"Request to CrewAI API timed out. Please try again later.",
style="bold red"
)
raise SystemExit(1)
except requests.exceptions.HTTPError as e:
if e.response.status_code == 401:
console.print(
"Authentication failed. Please run 'crewai login' to authenticate.",
style="bold red"
)
elif e.response.status_code == 403:
console.print(
"Access denied. You may not have permission to access triggers.",
style="bold red"
)
else:
console.print(f"HTTP error occurred: {e}", style="bold red")
raise SystemExit(1)
except Exception as e:
console.print(f"Unexpected error listing triggers: {e}", style="bold red")
console.print("Please check your configuration and try again.", style="yellow")
raise SystemExit(1)
def run_trigger(self, trigger_identification: str) -> None:
"""Run a crew with the specified trigger payload."""
try:
# Validate trigger identification format
if not trigger_identification or "/" not in trigger_identification:
console.print(
"Invalid trigger identification format. Expected format: 'app/trigger_name'",
style="bold red"
)
console.print(
"Use 'crewai trigger list' to see available triggers.", style="yellow"
)
raise SystemExit(1)
# Get sample payload for the trigger
console.print(f"Getting sample payload for trigger: {trigger_identification}", style="blue")
response = self.plus_api_client.get_trigger_sample_payload(trigger_identification)
self._validate_response(response)
trigger_payload = response.json()
if not trigger_payload:
console.print(
f"No sample payload found for trigger: {trigger_identification}",
style="yellow"
)
console.print(
"Use 'crewai trigger list' to see available triggers.", style="yellow"
)
return
console.print("Sample payload retrieved successfully", style="green")
# Import and run the crew with the trigger payload
self._run_crew_with_payload(trigger_payload)
except requests.exceptions.ConnectionError:
console.print(
"Failed to connect to CrewAI API. Please check your internet connection.",
style="bold red"
)
raise SystemExit(1)
except requests.exceptions.Timeout:
console.print(
"Request to CrewAI API timed out. Please try again later.",
style="bold red"
)
raise SystemExit(1)
except requests.exceptions.HTTPError as e:
if e.response.status_code == 401:
console.print(
"Authentication failed. Please run 'crewai login' to authenticate.",
style="bold red"
)
elif e.response.status_code == 404:
console.print(
f"Trigger '{trigger_identification}' not found.",
style="bold red"
)
console.print(
"Use 'crewai trigger list' to see available triggers.", style="yellow"
)
elif e.response.status_code == 403:
console.print(
"Access denied. You may not have permission to access this trigger.",
style="bold red"
)
else:
console.print(f"HTTP error occurred: {e}", style="bold red")
raise SystemExit(1)
except FileNotFoundError as e:
console.print(
f"Project file not found: {e}", style="bold red"
)
console.print(
"Make sure you're in a valid CrewAI project directory.", style="yellow"
)
raise SystemExit(1)
except subprocess.CalledProcessError as e:
console.print(f"Error running crew: {e}", style="bold red")
if e.output:
console.print(f"Output: {e.output}", style="red")
raise SystemExit(1)
except Exception as e:
console.print(f"Unexpected error running trigger: {e}", style="bold red")
console.print("Please check your configuration and try again.", style="yellow")
raise SystemExit(1)
def _display_triggers(self, triggers_data: Dict[str, Any]) -> None:
"""Display triggers in a formatted table grouped by provider."""
table = Table(title="Available Triggers")
table.add_column("Provider", style="cyan", no_wrap=True)
table.add_column("Trigger ID", style="magenta")
table.add_column("Description", style="green")
# Group triggers by provider
for provider_name, triggers in triggers_data.items():
if isinstance(triggers, dict):
# Add provider header
first_trigger = True
for trigger_id, trigger_info in triggers.items():
description = trigger_info.get("description", "No description available")
# Display provider name only for the first trigger of each provider
provider_display = provider_name if first_trigger else ""
first_trigger = False
table.add_row(
provider_display,
trigger_id,
description
)
# Add separator between providers (except for the last one)
if provider_name != list(triggers_data.keys())[-1]:
table.add_row("", "", "")
console.print(table)
console.print("\nTo run a trigger, use: [bold green]crewai trigger <trigger_id>[/bold green]")
def _run_crew_with_payload(self, trigger_payload: Dict[str, Any]) -> None:
"""Run the crew with the trigger payload."""
script_path = None
try:
from crewai.cli.utils import read_toml
# Validate project structure
if not os.path.exists("pyproject.toml"):
raise FileNotFoundError("pyproject.toml not found. Make sure you're in a CrewAI project directory.")
if not os.path.exists("src"):
raise FileNotFoundError("src directory not found. Make sure you're in a CrewAI project directory.")
if not os.path.exists("src/main.py"):
raise FileNotFoundError("src/main.py not found. Make sure you have a valid CrewAI project.")
# Read project configuration
pyproject_data = read_toml()
is_flow = pyproject_data.get("tool", {}).get("crewai", {}).get("type") == "flow"
console.print(f"Project type detected: {'Flow' if is_flow else 'Crew'}")
console.print("Preparing execution environment...")
# Create a temporary script to run the crew with trigger payload
script_content = self._generate_crew_script(trigger_payload, is_flow)
# Write script to temporary file
script_path = "temp_trigger_run.py"
with open(script_path, "w") as f:
f.write(script_content)
console.print(f"Running {'flow' if is_flow else 'crew'} with trigger payload...", style="blue")
# Execute the script
command = ["uv", "run", "python", script_path]
result = subprocess.run(command, check=True, capture_output=True, text=True)
# Display success message
console.print("✓ Execution completed successfully!", style="bold green")
if result.stdout:
console.print("Output:", style="blue")
console.print(result.stdout)
except FileNotFoundError as e:
raise # Re-raise to be caught by the outer try-catch
except subprocess.CalledProcessError as e:
error_msg = f"Crew execution failed with exit code {e.returncode}"
if e.stderr:
error_msg += f"\nError output: {e.stderr}"
if e.stdout:
error_msg += f"\nStandard output: {e.stdout}"
raise subprocess.CalledProcessError(e.returncode, e.cmd, error_msg)
except Exception as e:
raise Exception(f"Failed to execute crew: {str(e)}")
finally:
# Clean up temporary script
if script_path and os.path.exists(script_path):
try:
os.remove(script_path)
except OSError:
console.print(f"Warning: Could not remove temporary file {script_path}", style="yellow")
def _generate_crew_script(self, trigger_payload: Dict[str, Any], is_flow: bool) -> str:
"""Generate a Python script to run the crew with trigger payload."""
if is_flow:
return f"""
import sys
sys.path.append('src')
from main import *
def main():
try:
# Initialize and run the flow with trigger payload
flow = main()
# Add trigger payload to inputs
inputs = {{"crewai_trigger_payload": {trigger_payload}}}
result = flow.kickoff(inputs=inputs)
print("Flow execution completed successfully")
print(f"Result: {{result}}")
except Exception as e:
print(f"Error running flow: {{e}}", file=sys.stderr)
sys.exit(1)
if __name__ == "__main__":
main()
"""
else:
return f"""
import sys
sys.path.append('src')
def main():
try:
# Import the crew
from main import main as crew_main
# Get the crew instance
crew = crew_main()
# Add trigger payload to inputs
inputs = {{"crewai_trigger_payload": {trigger_payload}}}
result = crew.kickoff(inputs=inputs)
print("Crew execution completed successfully")
print(f"Result: {{result}}")
except Exception as e:
print(f"Error running crew: {{e}}", file=sys.stderr)
sys.exit(1)
if __name__ == "__main__":
main()
"""

View File

@@ -11,6 +11,7 @@ import click
import tomli
from rich.console import Console
from crewai.cli.config import Settings
from crewai.cli.constants import ENV_VARS
from crewai.crew import Crew
from crewai.flow import Flow
@@ -417,6 +418,21 @@ def extract_available_exports(dir_path: str = "src"):
raise SystemExit(1) from e
def build_env_with_tool_repository_credentials(repository_handle: str):
repository_handle = repository_handle.upper().replace("-", "_")
settings = Settings()
env = os.environ.copy()
env[f"UV_INDEX_{repository_handle}_USERNAME"] = str(
settings.tool_repository_username or ""
)
env[f"UV_INDEX_{repository_handle}_PASSWORD"] = str(
settings.tool_repository_password or ""
)
return env
def _load_tools_from_init(init_file: Path) -> list[dict[str, Any]]:
"""
Load and validate tools from a given __init__.py file.

View File

@@ -200,6 +200,9 @@ class TraceBatchManager:
if self.event_buffer:
events_sent_to_backend_status = self._send_events_to_backend()
if events_sent_to_backend_status == 500:
self.plus_api.mark_trace_batch_as_failed(
self.trace_batch_id, "Error sending events to backend"
)
return None
self._finalize_backend_batch()
@@ -273,10 +276,13 @@ class TraceBatchManager:
logger.error(
f"❌ Failed to finalize trace batch: {response.status_code} - {response.text}"
)
self.plus_api.mark_trace_batch_as_failed(
self.trace_batch_id, response.text
)
except Exception as e:
logger.error(f"❌ Error finalizing trace batch: {e}")
# TODO: send error to app marking as failed
self.plus_api.mark_trace_batch_as_failed(self.trace_batch_id, str(e))
def _cleanup_batch_data(self):
"""Clean up batch data after successful finalization to free memory"""

View File

@@ -1,40 +1,39 @@
from crewai.experimental.evaluation import (
AgentEvaluationResult,
AgentEvaluator,
BaseEvaluator,
EvaluationScore,
MetricCategory,
AgentEvaluationResult,
SemanticQualityEvaluator,
GoalAlignmentEvaluator,
ReasoningEfficiencyEvaluator,
ToolSelectionEvaluator,
ParameterExtractionEvaluator,
ToolInvocationEvaluator,
EvaluationTraceCallback,
create_evaluation_callbacks,
AgentEvaluator,
create_default_evaluator,
ExperimentRunner,
ExperimentResults,
ExperimentResult,
ExperimentResults,
ExperimentRunner,
GoalAlignmentEvaluator,
MetricCategory,
ParameterExtractionEvaluator,
ReasoningEfficiencyEvaluator,
SemanticQualityEvaluator,
ToolInvocationEvaluator,
ToolSelectionEvaluator,
create_default_evaluator,
create_evaluation_callbacks,
)
__all__ = [
"AgentEvaluationResult",
"AgentEvaluator",
"BaseEvaluator",
"EvaluationScore",
"MetricCategory",
"AgentEvaluationResult",
"SemanticQualityEvaluator",
"GoalAlignmentEvaluator",
"ReasoningEfficiencyEvaluator",
"ToolSelectionEvaluator",
"ParameterExtractionEvaluator",
"ToolInvocationEvaluator",
"EvaluationTraceCallback",
"create_evaluation_callbacks",
"AgentEvaluator",
"create_default_evaluator",
"ExperimentRunner",
"ExperimentResult",
"ExperimentResults",
"ExperimentResult"
]
"ExperimentRunner",
"GoalAlignmentEvaluator",
"MetricCategory",
"ParameterExtractionEvaluator",
"ReasoningEfficiencyEvaluator",
"SemanticQualityEvaluator",
"ToolInvocationEvaluator",
"ToolSelectionEvaluator",
"create_default_evaluator",
"create_evaluation_callbacks",
]

View File

@@ -1,51 +1,47 @@
from crewai.experimental.evaluation.agent_evaluator import (
AgentEvaluator,
create_default_evaluator,
)
from crewai.experimental.evaluation.base_evaluator import (
AgentEvaluationResult,
BaseEvaluator,
EvaluationScore,
MetricCategory,
AgentEvaluationResult
)
from crewai.experimental.evaluation.metrics import (
SemanticQualityEvaluator,
GoalAlignmentEvaluator,
ReasoningEfficiencyEvaluator,
ToolSelectionEvaluator,
ParameterExtractionEvaluator,
ToolInvocationEvaluator
)
from crewai.experimental.evaluation.evaluation_listener import (
EvaluationTraceCallback,
create_evaluation_callbacks
create_evaluation_callbacks,
)
from crewai.experimental.evaluation.agent_evaluator import (
AgentEvaluator,
create_default_evaluator
)
from crewai.experimental.evaluation.experiment import (
ExperimentRunner,
ExperimentResult,
ExperimentResults,
ExperimentResult
ExperimentRunner,
)
from crewai.experimental.evaluation.metrics import (
GoalAlignmentEvaluator,
ParameterExtractionEvaluator,
ReasoningEfficiencyEvaluator,
SemanticQualityEvaluator,
ToolInvocationEvaluator,
ToolSelectionEvaluator,
)
__all__ = [
"AgentEvaluationResult",
"AgentEvaluator",
"BaseEvaluator",
"EvaluationScore",
"MetricCategory",
"AgentEvaluationResult",
"SemanticQualityEvaluator",
"GoalAlignmentEvaluator",
"ReasoningEfficiencyEvaluator",
"ToolSelectionEvaluator",
"ParameterExtractionEvaluator",
"ToolInvocationEvaluator",
"EvaluationTraceCallback",
"create_evaluation_callbacks",
"AgentEvaluator",
"create_default_evaluator",
"ExperimentRunner",
"ExperimentResult",
"ExperimentResults",
"ExperimentResult"
"ExperimentRunner",
"GoalAlignmentEvaluator",
"MetricCategory",
"ParameterExtractionEvaluator",
"ReasoningEfficiencyEvaluator",
"SemanticQualityEvaluator",
"ToolInvocationEvaluator",
"ToolSelectionEvaluator",
"create_default_evaluator",
"create_evaluation_callbacks",
]

View File

@@ -1,34 +1,36 @@
import threading
from typing import Any, Optional
from collections.abc import Sequence
from typing import Any
from crewai.experimental.evaluation.base_evaluator import (
AgentEvaluationResult,
AggregationStrategy,
)
from crewai.agent import Agent
from crewai.task import Task
from crewai.experimental.evaluation.evaluation_display import EvaluationDisplayFormatter
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.agent_events import (
AgentEvaluationStartedEvent,
AgentEvaluationCompletedEvent,
AgentEvaluationFailedEvent,
AgentEvaluationStartedEvent,
LiteAgentExecutionCompletedEvent,
)
from crewai.experimental.evaluation import BaseEvaluator, create_evaluation_callbacks
from collections.abc import Sequence
from crewai.events.event_bus import crewai_event_bus
from crewai.events.utils.console_formatter import ConsoleFormatter
from crewai.events.types.task_events import TaskCompletedEvent
from crewai.events.types.agent_events import LiteAgentExecutionCompletedEvent
from crewai.events.utils.console_formatter import ConsoleFormatter
from crewai.experimental.evaluation.base_evaluator import (
AgentAggregatedEvaluationResult,
AgentEvaluationResult,
AggregationStrategy,
BaseEvaluator,
EvaluationScore,
MetricCategory,
)
from crewai.experimental.evaluation.evaluation_display import EvaluationDisplayFormatter
from crewai.experimental.evaluation.evaluation_listener import (
create_evaluation_callbacks,
)
from crewai.task import Task
class ExecutionState:
current_agent_id: Optional[str] = None
current_task_id: Optional[str] = None
current_agent_id: str | None = None
current_task_id: str | None = None
def __init__(self):
self.traces = {}
@@ -40,10 +42,10 @@ class ExecutionState:
class AgentEvaluator:
def __init__(
self,
agents: list[Agent],
agents: list[Agent] | list[BaseAgent],
evaluators: Sequence[BaseEvaluator] | None = None,
):
self.agents: list[Agent] = agents
self.agents: list[Agent] | list[BaseAgent] = agents
self.evaluators: Sequence[BaseEvaluator] | None = evaluators
self.callback = create_evaluation_callbacks()
@@ -75,7 +77,8 @@ class AgentEvaluator:
)
def _handle_task_completed(self, source: Any, event: TaskCompletedEvent) -> None:
assert event.task is not None
if event.task is None:
raise ValueError("TaskCompletedEvent must have a task")
agent = event.task.agent
if (
agent
@@ -92,9 +95,8 @@ class AgentEvaluator:
state.current_agent_id = str(agent.id)
state.current_task_id = str(event.task.id)
assert (
state.current_agent_id is not None and state.current_task_id is not None
)
if state.current_agent_id is None or state.current_task_id is None:
raise ValueError("Agent ID and Task ID must not be None")
trace = self.callback.get_trace(
state.current_agent_id, state.current_task_id
)
@@ -146,9 +148,8 @@ class AgentEvaluator:
if not target_agent:
return
assert (
state.current_agent_id is not None and state.current_task_id is not None
)
if state.current_agent_id is None or state.current_task_id is None:
raise ValueError("Agent ID and Task ID must not be None")
trace = self.callback.get_trace(
state.current_agent_id, state.current_task_id
)
@@ -244,7 +245,7 @@ class AgentEvaluator:
def evaluate(
self,
agent: Agent,
agent: Agent | BaseAgent,
execution_trace: dict[str, Any],
final_output: Any,
state: ExecutionState,
@@ -255,7 +256,8 @@ class AgentEvaluator:
task_id=state.current_task_id or (str(task.id) if task else "unknown_task"),
)
assert self.evaluators is not None
if self.evaluators is None:
raise ValueError("Evaluators must be initialized")
task_id = str(task.id) if task else None
for evaluator in self.evaluators:
try:
@@ -276,7 +278,7 @@ class AgentEvaluator:
metric_category=evaluator.metric_category,
score=score,
)
except Exception as e:
except Exception as e: # noqa: PERF203
self.emit_evaluation_failed_event(
agent_role=agent.role,
agent_id=str(agent.id),
@@ -284,7 +286,7 @@ class AgentEvaluator:
error=str(e),
)
self.console_formatter.print(
f"Error in {evaluator.metric_category.value} evaluator: {str(e)}"
f"Error in {evaluator.metric_category.value} evaluator: {e!s}"
)
return result
@@ -337,14 +339,14 @@ class AgentEvaluator:
)
def create_default_evaluator(agents: list[Agent], llm: None = None):
def create_default_evaluator(agents: list[Agent] | list[BaseAgent], llm: None = None):
from crewai.experimental.evaluation import (
GoalAlignmentEvaluator,
SemanticQualityEvaluator,
ToolSelectionEvaluator,
ParameterExtractionEvaluator,
ToolInvocationEvaluator,
ReasoningEfficiencyEvaluator,
SemanticQualityEvaluator,
ToolInvocationEvaluator,
ToolSelectionEvaluator,
)
evaluators = [

View File

@@ -1,15 +1,17 @@
import abc
import enum
from enum import Enum
from typing import Any, Dict, List, Optional
from typing import Any
from pydantic import BaseModel, Field
from crewai.agent import Agent
from crewai.task import Task
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.llm import BaseLLM
from crewai.task import Task
from crewai.utilities.llm_utils import create_llm
class MetricCategory(enum.Enum):
GOAL_ALIGNMENT = "goal_alignment"
SEMANTIC_QUALITY = "semantic_quality"
@@ -19,7 +21,7 @@ class MetricCategory(enum.Enum):
TOOL_INVOCATION = "tool_invocation"
def title(self):
return self.value.replace('_', ' ').title()
return self.value.replace("_", " ").title()
class EvaluationScore(BaseModel):
@@ -27,15 +29,13 @@ class EvaluationScore(BaseModel):
default=5.0,
description="Numeric score from 0-10 where 0 is worst and 10 is best, None if not applicable",
ge=0.0,
le=10.0
le=10.0,
)
feedback: str = Field(
default="",
description="Detailed feedback explaining the evaluation score"
default="", description="Detailed feedback explaining the evaluation score"
)
raw_response: str | None = Field(
default=None,
description="Raw response from the evaluator (e.g., LLM)"
default=None, description="Raw response from the evaluator (e.g., LLM)"
)
def __str__(self) -> str:
@@ -56,8 +56,8 @@ class BaseEvaluator(abc.ABC):
@abc.abstractmethod
def evaluate(
self,
agent: Agent,
execution_trace: Dict[str, Any],
agent: Agent | BaseAgent,
execution_trace: dict[str, Any],
final_output: Any,
task: Task | None = None,
) -> EvaluationScore:
@@ -67,9 +67,8 @@ class BaseEvaluator(abc.ABC):
class AgentEvaluationResult(BaseModel):
agent_id: str = Field(description="ID of the evaluated agent")
task_id: str = Field(description="ID of the task that was executed")
metrics: Dict[MetricCategory, EvaluationScore] = Field(
default_factory=dict,
description="Evaluation scores for each metric category"
metrics: dict[MetricCategory, EvaluationScore] = Field(
default_factory=dict, description="Evaluation scores for each metric category"
)
@@ -81,33 +80,23 @@ class AggregationStrategy(Enum):
class AgentAggregatedEvaluationResult(BaseModel):
agent_id: str = Field(
default="",
description="ID of the agent"
)
agent_role: str = Field(
default="",
description="Role of the agent"
)
agent_id: str = Field(default="", description="ID of the agent")
agent_role: str = Field(default="", description="Role of the agent")
task_count: int = Field(
default=0,
description="Number of tasks included in this aggregation"
default=0, description="Number of tasks included in this aggregation"
)
aggregation_strategy: AggregationStrategy = Field(
default=AggregationStrategy.SIMPLE_AVERAGE,
description="Strategy used for aggregation"
description="Strategy used for aggregation",
)
metrics: Dict[MetricCategory, EvaluationScore] = Field(
default_factory=dict,
description="Aggregated metrics across all tasks"
metrics: dict[MetricCategory, EvaluationScore] = Field(
default_factory=dict, description="Aggregated metrics across all tasks"
)
task_results: List[str] = Field(
default_factory=list,
description="IDs of tasks included in this aggregation"
task_results: list[str] = Field(
default_factory=list, description="IDs of tasks included in this aggregation"
)
overall_score: Optional[float] = Field(
default=None,
description="Overall score for this agent"
overall_score: float | None = Field(
default=None, description="Overall score for this agent"
)
def __str__(self) -> str:
@@ -119,7 +108,7 @@ class AgentAggregatedEvaluationResult(BaseModel):
result += f"\n\n- {category.value.upper()}: {score.score}/10\n"
if score.feedback:
detailed_feedback = "\n ".join(score.feedback.split('\n'))
detailed_feedback = "\n ".join(score.feedback.split("\n"))
result += f" {detailed_feedback}\n"
return result
return result

View File

@@ -1,16 +1,18 @@
from collections import defaultdict
from typing import Dict, Any, List
from rich.table import Table
from rich.box import HEAVY_EDGE, ROUNDED
from collections.abc import Sequence
from typing import Any
from rich.box import HEAVY_EDGE, ROUNDED
from rich.table import Table
from crewai.events.utils.console_formatter import ConsoleFormatter
from crewai.experimental.evaluation.base_evaluator import (
AgentAggregatedEvaluationResult,
AggregationStrategy,
AgentEvaluationResult,
AggregationStrategy,
EvaluationScore,
MetricCategory,
)
from crewai.experimental.evaluation import EvaluationScore
from crewai.events.utils.console_formatter import ConsoleFormatter
from crewai.utilities.llm_utils import create_llm
@@ -19,7 +21,7 @@ class EvaluationDisplayFormatter:
self.console_formatter = ConsoleFormatter()
def display_evaluation_with_feedback(
self, iterations_results: Dict[int, Dict[str, List[Any]]]
self, iterations_results: dict[int, dict[str, list[Any]]]
):
if not iterations_results:
self.console_formatter.print(
@@ -99,7 +101,7 @@ class EvaluationDisplayFormatter:
def display_summary_results(
self,
iterations_results: Dict[int, Dict[str, List[AgentAggregatedEvaluationResult]]],
iterations_results: dict[int, dict[str, list[AgentEvaluationResult]]],
):
if not iterations_results:
self.console_formatter.print(
@@ -280,7 +282,7 @@ class EvaluationDisplayFormatter:
feedback_summary = feedbacks[0]
aggregated_metrics[category] = EvaluationScore(
score=avg_score, feedback=feedback_summary
score=avg_score, feedback=feedback_summary or ""
)
overall_score = None
@@ -304,25 +306,25 @@ class EvaluationDisplayFormatter:
self,
agent_role: str,
metric: str,
feedbacks: List[str],
scores: List[float | None],
feedbacks: list[str],
scores: list[float | None],
strategy: AggregationStrategy,
) -> str:
if len(feedbacks) <= 2 and all(len(fb) < 200 for fb in feedbacks):
return "\n\n".join(
[f"Feedback {i+1}: {fb}" for i, fb in enumerate(feedbacks)]
[f"Feedback {i + 1}: {fb}" for i, fb in enumerate(feedbacks)]
)
try:
llm = create_llm()
formatted_feedbacks = []
for i, (feedback, score) in enumerate(zip(feedbacks, scores)):
for i, (feedback, score) in enumerate(zip(feedbacks, scores, strict=False)):
if len(feedback) > 500:
feedback = feedback[:500] + "..."
score_text = f"{score:.1f}" if score is not None else "N/A"
formatted_feedbacks.append(
f"Feedback #{i+1} (Score: {score_text}):\n{feedback}"
f"Feedback #{i + 1} (Score: {score_text}):\n{feedback}"
)
all_feedbacks = "\n\n" + "\n\n---\n\n".join(formatted_feedbacks)
@@ -365,10 +367,9 @@ class EvaluationDisplayFormatter:
""",
},
]
assert llm is not None
response = llm.call(prompt)
return response
if llm is None:
raise ValueError("LLM must be initialized")
return llm.call(prompt)
except Exception:
return "Synthesized from multiple tasks: " + "\n\n".join(

View File

@@ -1,26 +1,25 @@
from datetime import datetime
from typing import Any, Dict, Optional
from collections.abc import Sequence
from datetime import datetime
from typing import Any
from crewai.agent import Agent
from crewai.task import Task
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.events.base_event_listener import BaseEventListener
from crewai.events.event_bus import CrewAIEventsBus
from crewai.events.types.agent_events import (
AgentExecutionStartedEvent,
AgentExecutionCompletedEvent,
LiteAgentExecutionStartedEvent,
AgentExecutionStartedEvent,
LiteAgentExecutionCompletedEvent,
LiteAgentExecutionStartedEvent,
)
from crewai.events.types.llm_events import LLMCallCompletedEvent, LLMCallStartedEvent
from crewai.events.types.tool_usage_events import (
ToolUsageFinishedEvent,
ToolUsageErrorEvent,
ToolExecutionErrorEvent,
ToolSelectionErrorEvent,
ToolUsageErrorEvent,
ToolUsageFinishedEvent,
ToolValidateInputErrorEvent,
)
from crewai.events.types.llm_events import LLMCallStartedEvent, LLMCallCompletedEvent
from crewai.task import Task
class EvaluationTraceCallback(BaseEventListener):
@@ -136,7 +135,7 @@ class EvaluationTraceCallback(BaseEventListener):
def _init_trace(self, trace_key: str, **kwargs: Any):
self.traces[trace_key] = kwargs
def on_agent_start(self, agent: Agent, task: Task):
def on_agent_start(self, agent: BaseAgent, task: Task):
self.current_agent_id = agent.id
self.current_task_id = task.id
@@ -151,7 +150,7 @@ class EvaluationTraceCallback(BaseEventListener):
final_output=None,
)
def on_agent_finish(self, agent: Agent, task: Task, output: Any):
def on_agent_finish(self, agent: BaseAgent, task: Task, output: Any):
trace_key = f"{agent.id}_{task.id}"
if trace_key in self.traces:
self.traces[trace_key]["final_output"] = output
@@ -253,7 +252,7 @@ class EvaluationTraceCallback(BaseEventListener):
if hasattr(self, "current_llm_call"):
self.current_llm_call = {}
def get_trace(self, agent_id: str, task_id: str) -> Optional[Dict[str, Any]]:
def get_trace(self, agent_id: str, task_id: str) -> dict[str, Any] | None:
trace_key = f"{agent_id}_{task_id}"
return self.traces.get(trace_key)

View File

@@ -1,8 +1,7 @@
from crewai.experimental.evaluation.experiment.result import (
ExperimentResult,
ExperimentResults,
)
from crewai.experimental.evaluation.experiment.runner import ExperimentRunner
from crewai.experimental.evaluation.experiment.result import ExperimentResults, ExperimentResult
__all__ = [
"ExperimentRunner",
"ExperimentResults",
"ExperimentResult"
]
__all__ = ["ExperimentResult", "ExperimentResults", "ExperimentRunner"]

View File

@@ -2,45 +2,60 @@ import json
import os
from datetime import datetime, timezone
from typing import Any
from pydantic import BaseModel
class ExperimentResult(BaseModel):
identifier: str
inputs: dict[str, Any]
score: int | dict[str, int | float]
expected_score: int | dict[str, int | float]
score: float | dict[str, float]
expected_score: float | dict[str, float]
passed: bool
agent_evaluations: dict[str, Any] | None = None
class ExperimentResults:
def __init__(self, results: list[ExperimentResult], metadata: dict[str, Any] | None = None):
def __init__(
self, results: list[ExperimentResult], metadata: dict[str, Any] | None = None
):
self.results = results
self.metadata = metadata or {}
self.timestamp = datetime.now(timezone.utc)
from crewai.experimental.evaluation.experiment.result_display import ExperimentResultsDisplay
from crewai.experimental.evaluation.experiment.result_display import (
ExperimentResultsDisplay,
)
self.display = ExperimentResultsDisplay()
def to_json(self, filepath: str | None = None) -> dict[str, Any]:
data = {
"timestamp": self.timestamp.isoformat(),
"metadata": self.metadata,
"results": [r.model_dump(exclude={"agent_evaluations"}) for r in self.results]
"results": [
r.model_dump(exclude={"agent_evaluations"}) for r in self.results
],
}
if filepath:
with open(filepath, 'w') as f:
with open(filepath, "w") as f:
json.dump(data, f, indent=2)
self.display.console.print(f"[green]Results saved to {filepath}[/green]")
return data
def compare_with_baseline(self, baseline_filepath: str, save_current: bool = True, print_summary: bool = False) -> dict[str, Any]:
def compare_with_baseline(
self,
baseline_filepath: str,
save_current: bool = True,
print_summary: bool = False,
) -> dict[str, Any]:
baseline_runs = []
if os.path.exists(baseline_filepath) and os.path.getsize(baseline_filepath) > 0:
try:
with open(baseline_filepath, 'r') as f:
with open(baseline_filepath, "r") as f:
baseline_data = json.load(f)
if isinstance(baseline_data, dict) and "timestamp" in baseline_data:
@@ -48,14 +63,18 @@ class ExperimentResults:
elif isinstance(baseline_data, list):
baseline_runs = baseline_data
except (json.JSONDecodeError, FileNotFoundError) as e:
self.display.console.print(f"[yellow]Warning: Could not load baseline file: {str(e)}[/yellow]")
self.display.console.print(
f"[yellow]Warning: Could not load baseline file: {e!s}[/yellow]"
)
if not baseline_runs:
if save_current:
current_data = self.to_json()
with open(baseline_filepath, 'w') as f:
with open(baseline_filepath, "w") as f:
json.dump([current_data], f, indent=2)
self.display.console.print(f"[green]Saved current results as new baseline to {baseline_filepath}[/green]")
self.display.console.print(
f"[green]Saved current results as new baseline to {baseline_filepath}[/green]"
)
return {"is_baseline": True, "changes": {}}
baseline_runs.sort(key=lambda x: x.get("timestamp", ""), reverse=True)
@@ -69,9 +88,11 @@ class ExperimentResults:
if save_current:
current_data = self.to_json()
baseline_runs.append(current_data)
with open(baseline_filepath, 'w') as f:
with open(baseline_filepath, "w") as f:
json.dump(baseline_runs, f, indent=2)
self.display.console.print(f"[green]Added current results to baseline file {baseline_filepath}[/green]")
self.display.console.print(
f"[green]Added current results to baseline file {baseline_filepath}[/green]"
)
return comparison
@@ -118,5 +139,5 @@ class ExperimentResults:
"new_tests": new_tests,
"missing_tests": missing_tests,
"total_compared": len(improved) + len(regressed) + len(unchanged),
"baseline_timestamp": baseline_run.get("timestamp", "unknown")
"baseline_timestamp": baseline_run.get("timestamp", "unknown"),
}

View File

@@ -1,9 +1,12 @@
from typing import Dict, Any
from typing import Any
from rich.console import Console
from rich.table import Table
from rich.panel import Panel
from rich.table import Table
from crewai.experimental.evaluation.experiment.result import ExperimentResults
class ExperimentResultsDisplay:
def __init__(self):
self.console = Console()
@@ -19,13 +22,19 @@ class ExperimentResultsDisplay:
table.add_row("Total Test Cases", str(total))
table.add_row("Passed", str(passed))
table.add_row("Failed", str(total - passed))
table.add_row("Success Rate", f"{(passed / total * 100):.1f}%" if total > 0 else "N/A")
table.add_row(
"Success Rate", f"{(passed / total * 100):.1f}%" if total > 0 else "N/A"
)
self.console.print(table)
def comparison_summary(self, comparison: Dict[str, Any], baseline_timestamp: str):
self.console.print(Panel(f"[bold]Comparison with baseline run from {baseline_timestamp}[/bold]",
expand=False))
def comparison_summary(self, comparison: dict[str, Any], baseline_timestamp: str):
self.console.print(
Panel(
f"[bold]Comparison with baseline run from {baseline_timestamp}[/bold]",
expand=False,
)
)
table = Table(title="Results Comparison")
table.add_column("Metric", style="cyan")
@@ -34,7 +43,9 @@ class ExperimentResultsDisplay:
improved = comparison.get("improved", [])
if improved:
details = ", ".join([f"{test_identifier}" for test_identifier in improved[:3]])
details = ", ".join(
[f"{test_identifier}" for test_identifier in improved[:3]]
)
if len(improved) > 3:
details += f" and {len(improved) - 3} more"
table.add_row("✅ Improved", str(len(improved)), details)
@@ -43,7 +54,9 @@ class ExperimentResultsDisplay:
regressed = comparison.get("regressed", [])
if regressed:
details = ", ".join([f"{test_identifier}" for test_identifier in regressed[:3]])
details = ", ".join(
[f"{test_identifier}" for test_identifier in regressed[:3]]
)
if len(regressed) > 3:
details += f" and {len(regressed) - 3} more"
table.add_row("❌ Regressed", str(len(regressed)), details, style="red")
@@ -58,13 +71,13 @@ class ExperimentResultsDisplay:
details = ", ".join(new_tests[:3])
if len(new_tests) > 3:
details += f" and {len(new_tests) - 3} more"
table.add_row(" New Tests", str(len(new_tests)), details)
table.add_row("+ New Tests", str(len(new_tests)), details)
missing_tests = comparison.get("missing_tests", [])
if missing_tests:
details = ", ".join(missing_tests[:3])
if len(missing_tests) > 3:
details += f" and {len(missing_tests) - 3} more"
table.add_row(" Missing Tests", str(len(missing_tests)), details)
table.add_row("- Missing Tests", str(len(missing_tests)), details)
self.console.print(table)

View File

@@ -2,11 +2,20 @@ from collections import defaultdict
from hashlib import md5
from typing import Any
from crewai import Crew, Agent
from crewai import Agent, Crew
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.experimental.evaluation import AgentEvaluator, create_default_evaluator
from crewai.experimental.evaluation.experiment.result_display import ExperimentResultsDisplay
from crewai.experimental.evaluation.experiment.result import ExperimentResults, ExperimentResult
from crewai.experimental.evaluation.evaluation_display import AgentAggregatedEvaluationResult
from crewai.experimental.evaluation.evaluation_display import (
AgentAggregatedEvaluationResult,
)
from crewai.experimental.evaluation.experiment.result import (
ExperimentResult,
ExperimentResults,
)
from crewai.experimental.evaluation.experiment.result_display import (
ExperimentResultsDisplay,
)
class ExperimentRunner:
def __init__(self, dataset: list[dict[str, Any]]):
@@ -14,11 +23,17 @@ class ExperimentRunner:
self.evaluator: AgentEvaluator | None = None
self.display = ExperimentResultsDisplay()
def run(self, crew: Crew | None = None, agents: list[Agent] | None = None, print_summary: bool = False) -> ExperimentResults:
def run(
self,
crew: Crew | None = None,
agents: list[Agent] | list[BaseAgent] | None = None,
print_summary: bool = False,
) -> ExperimentResults:
if crew and not agents:
agents = crew.agents
assert agents is not None
if agents is None:
raise ValueError("Agents must be provided either directly or via a crew")
self.evaluator = create_default_evaluator(agents=agents)
results = []
@@ -35,21 +50,37 @@ class ExperimentRunner:
return experiment_results
def _run_test_case(self, test_case: dict[str, Any], agents: list[Agent], crew: Crew | None = None) -> ExperimentResult:
def _run_test_case(
self,
test_case: dict[str, Any],
agents: list[Agent] | list[BaseAgent],
crew: Crew | None = None,
) -> ExperimentResult:
inputs = test_case["inputs"]
expected_score = test_case["expected_score"]
identifier = test_case.get("identifier") or md5(str(test_case).encode(), usedforsecurity=False).hexdigest()
identifier = (
test_case.get("identifier")
or md5(str(test_case).encode(), usedforsecurity=False).hexdigest()
)
try:
self.display.console.print(f"[dim]Running crew with input: {str(inputs)[:50]}...[/dim]")
self.display.console.print(
f"[dim]Running crew with input: {str(inputs)[:50]}...[/dim]"
)
self.display.console.print("\n")
if crew:
crew.kickoff(inputs=inputs)
else:
for agent in agents:
agent.kickoff(**inputs)
if isinstance(agent, Agent):
agent.kickoff(**inputs)
else:
raise TypeError(
f"Agent {agent} is not an instance of Agent and cannot be kicked off directly"
)
assert self.evaluator is not None
if self.evaluator is None:
raise ValueError("Evaluator must be initialized")
agent_evaluations = self.evaluator.get_agent_evaluation()
actual_score = self._extract_scores(agent_evaluations)
@@ -61,35 +92,38 @@ class ExperimentRunner:
score=actual_score,
expected_score=expected_score,
passed=passed,
agent_evaluations=agent_evaluations
agent_evaluations=agent_evaluations,
)
except Exception as e:
self.display.console.print(f"[red]Error running test case: {str(e)}[/red]")
self.display.console.print(f"[red]Error running test case: {e!s}[/red]")
return ExperimentResult(
identifier=identifier,
inputs=inputs,
score=0,
score=0.0,
expected_score=expected_score,
passed=False
passed=False,
)
def _extract_scores(self, agent_evaluations: dict[str, AgentAggregatedEvaluationResult]) -> float | dict[str, float]:
def _extract_scores(
self, agent_evaluations: dict[str, AgentAggregatedEvaluationResult]
) -> float | dict[str, float]:
all_scores: dict[str, list[float]] = defaultdict(list)
for evaluation in agent_evaluations.values():
for metric_name, score in evaluation.metrics.items():
if score.score is not None:
all_scores[metric_name.value].append(score.score)
avg_scores = {m: sum(s)/len(s) for m, s in all_scores.items()}
avg_scores = {m: sum(s) / len(s) for m, s in all_scores.items()}
if len(avg_scores) == 1:
return list(avg_scores.values())[0]
return next(iter(avg_scores.values()))
return avg_scores
def _assert_scores(self, expected: float | dict[str, float],
actual: float | dict[str, float]) -> bool:
def _assert_scores(
self, expected: float | dict[str, float], actual: float | dict[str, float]
) -> bool:
"""
Compare expected and actual scores, and return whether the test case passed.
@@ -122,4 +156,4 @@ class ExperimentRunner:
# All matching keys must have actual >= expected
return all(actual[key] >= expected[key] for key in matching_keys)
return False
return False

View File

@@ -13,11 +13,11 @@ def extract_json_from_llm_response(text: str) -> dict[str, Any]:
json_patterns = [
# Standard markdown code blocks with json
r'```json\s*([\s\S]*?)\s*```',
r"```json\s*([\s\S]*?)\s*```",
# Code blocks without language specifier
r'```\s*([\s\S]*?)\s*```',
r"```\s*([\s\S]*?)\s*```",
# Inline code with JSON
r'`([{\\[].*[}\]])`',
r"`([{\\[].*[}\]])`",
]
for pattern in json_patterns:
@@ -25,6 +25,6 @@ def extract_json_from_llm_response(text: str) -> dict[str, Any]:
for match in matches:
try:
return json.loads(match.strip())
except json.JSONDecodeError:
except json.JSONDecodeError: # noqa: PERF203
continue
raise ValueError("No valid JSON found in the response")

View File

@@ -1,26 +1,21 @@
from crewai.experimental.evaluation.metrics.goal_metrics import GoalAlignmentEvaluator
from crewai.experimental.evaluation.metrics.reasoning_metrics import (
ReasoningEfficiencyEvaluator
ReasoningEfficiencyEvaluator,
)
from crewai.experimental.evaluation.metrics.tools_metrics import (
ToolSelectionEvaluator,
ParameterExtractionEvaluator,
ToolInvocationEvaluator
)
from crewai.experimental.evaluation.metrics.goal_metrics import (
GoalAlignmentEvaluator
)
from crewai.experimental.evaluation.metrics.semantic_quality_metrics import (
SemanticQualityEvaluator
SemanticQualityEvaluator,
)
from crewai.experimental.evaluation.metrics.tools_metrics import (
ParameterExtractionEvaluator,
ToolInvocationEvaluator,
ToolSelectionEvaluator,
)
__all__ = [
"ReasoningEfficiencyEvaluator",
"ToolSelectionEvaluator",
"ParameterExtractionEvaluator",
"ToolInvocationEvaluator",
"GoalAlignmentEvaluator",
"SemanticQualityEvaluator"
]
"ParameterExtractionEvaluator",
"ReasoningEfficiencyEvaluator",
"SemanticQualityEvaluator",
"ToolInvocationEvaluator",
"ToolSelectionEvaluator",
]

View File

@@ -1,10 +1,15 @@
from typing import Any, Dict
from typing import Any
from crewai.agent import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.experimental.evaluation.base_evaluator import (
BaseEvaluator,
EvaluationScore,
MetricCategory,
)
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
from crewai.task import Task
from crewai.experimental.evaluation.base_evaluator import BaseEvaluator, EvaluationScore, MetricCategory
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
class GoalAlignmentEvaluator(BaseEvaluator):
@property
@@ -13,8 +18,8 @@ class GoalAlignmentEvaluator(BaseEvaluator):
def evaluate(
self,
agent: Agent,
execution_trace: Dict[str, Any],
agent: Agent | BaseAgent,
execution_trace: dict[str, Any],
final_output: Any,
task: Task | None = None,
) -> EvaluationScore:
@@ -23,7 +28,9 @@ class GoalAlignmentEvaluator(BaseEvaluator):
task_context = f"Task description: {task.description}\nExpected output: {task.expected_output}\n"
prompt = [
{"role": "system", "content": """You are an expert evaluator assessing how well an AI agent's output aligns with its assigned task goal.
{
"role": "system",
"content": """You are an expert evaluator assessing how well an AI agent's output aligns with its assigned task goal.
Score the agent's goal alignment on a scale from 0-10 where:
- 0: Complete misalignment, agent did not understand or attempt the task goal
@@ -37,8 +44,11 @@ Consider:
4. Did the agent provide all requested information or deliverables?
Return your evaluation as JSON with fields 'score' (number) and 'feedback' (string).
"""},
{"role": "user", "content": f"""
""",
},
{
"role": "user",
"content": f"""
Agent role: {agent.role}
Agent goal: {agent.goal}
{task_context}
@@ -47,23 +57,26 @@ Agent's final output:
{final_output}
Evaluate how well the agent's output aligns with the assigned task goal.
"""}
""",
},
]
assert self.llm is not None
if self.llm is None:
raise ValueError("LLM must be initialized")
response = self.llm.call(prompt)
try:
evaluation_data: dict[str, Any] = extract_json_from_llm_response(response)
assert evaluation_data is not None
if evaluation_data is None:
raise ValueError("Failed to extract evaluation data from LLM response")
return EvaluationScore(
score=evaluation_data.get("score", 0),
feedback=evaluation_data.get("feedback", response),
raw_response=response
raw_response=response,
)
except Exception:
return EvaluationScore(
score=None,
feedback=f"Failed to parse evaluation. Raw response: {response}",
raw_response=response
raw_response=response,
)

View File

@@ -8,18 +8,24 @@ This module provides evaluator implementations for:
import logging
import re
from enum import Enum
from typing import Any, Dict, List, Tuple
import numpy as np
from collections.abc import Sequence
from enum import Enum
from typing import Any
import numpy as np
from crewai.agent import Agent
from crewai.task import Task
from crewai.experimental.evaluation.base_evaluator import BaseEvaluator, EvaluationScore, MetricCategory
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.experimental.evaluation.base_evaluator import (
BaseEvaluator,
EvaluationScore,
MetricCategory,
)
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
from crewai.task import Task
from crewai.tasks.task_output import TaskOutput
class ReasoningPatternType(Enum):
EFFICIENT = "efficient" # Good reasoning flow
LOOP = "loop" # Agent is stuck in a loop
@@ -35,8 +41,8 @@ class ReasoningEfficiencyEvaluator(BaseEvaluator):
def evaluate(
self,
agent: Agent,
execution_trace: Dict[str, Any],
agent: Agent | BaseAgent,
execution_trace: dict[str, Any],
final_output: TaskOutput | str,
task: Task | None = None,
) -> EvaluationScore:
@@ -49,7 +55,7 @@ class ReasoningEfficiencyEvaluator(BaseEvaluator):
if not llm_calls or len(llm_calls) < 2:
return EvaluationScore(
score=None,
feedback="Insufficient LLM calls to evaluate reasoning efficiency."
feedback="Insufficient LLM calls to evaluate reasoning efficiency.",
)
total_calls = len(llm_calls)
@@ -58,12 +64,16 @@ class ReasoningEfficiencyEvaluator(BaseEvaluator):
time_intervals = []
has_reliable_timing = True
for i in range(1, len(llm_calls)):
start_time = llm_calls[i-1].get("end_time")
start_time = llm_calls[i - 1].get("end_time")
end_time = llm_calls[i].get("start_time")
if start_time and end_time and start_time != end_time:
try:
interval = end_time - start_time
time_intervals.append(interval.total_seconds() if hasattr(interval, 'total_seconds') else 0)
time_intervals.append(
interval.total_seconds()
if hasattr(interval, "total_seconds")
else 0
)
except Exception:
has_reliable_timing = False
else:
@@ -83,14 +93,22 @@ class ReasoningEfficiencyEvaluator(BaseEvaluator):
if has_reliable_timing and time_intervals:
efficiency_metrics["avg_time_between_calls"] = np.mean(time_intervals)
loop_info = f"Detected {len(loop_details)} potential reasoning loops." if loop_detected else "No significant reasoning loops detected."
loop_info = (
f"Detected {len(loop_details)} potential reasoning loops."
if loop_detected
else "No significant reasoning loops detected."
)
call_samples = self._get_call_samples(llm_calls)
final_output = final_output.raw if isinstance(final_output, TaskOutput) else final_output
final_output = (
final_output.raw if isinstance(final_output, TaskOutput) else final_output
)
prompt = [
{"role": "system", "content": """You are an expert evaluator assessing the reasoning efficiency of an AI agent's thought process.
{
"role": "system",
"content": """You are an expert evaluator assessing the reasoning efficiency of an AI agent's thought process.
Evaluate the agent's reasoning efficiency across these five key subcategories:
@@ -120,8 +138,11 @@ Return your evaluation as JSON with the following structure:
"feedback": string (general feedback about overall reasoning efficiency),
"optimization_suggestions": string (concrete suggestions for improving reasoning efficiency),
"detected_patterns": string (describe any inefficient reasoning patterns you observe)
}"""},
{"role": "user", "content": f"""
}""",
},
{
"role": "user",
"content": f"""
Agent role: {agent.role}
{task_context}
@@ -140,10 +161,12 @@ Agent's final output:
Evaluate the reasoning efficiency of this agent based on these interaction patterns.
Identify any inefficient reasoning patterns and provide specific suggestions for optimization.
"""}
""",
},
]
assert self.llm is not None
if self.llm is None:
raise ValueError("LLM must be initialized")
response = self.llm.call(prompt)
try:
@@ -156,34 +179,46 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
conciseness = scores.get("conciseness", 5.0)
loop_avoidance = scores.get("loop_avoidance", 5.0)
overall_score = evaluation_data.get("overall_score", evaluation_data.get("score", 5.0))
overall_score = evaluation_data.get(
"overall_score", evaluation_data.get("score", 5.0)
)
feedback = evaluation_data.get("feedback", "No detailed feedback provided.")
optimization_suggestions = evaluation_data.get("optimization_suggestions", "No specific suggestions provided.")
optimization_suggestions = evaluation_data.get(
"optimization_suggestions", "No specific suggestions provided."
)
detailed_feedback = "Reasoning Efficiency Evaluation:\n"
detailed_feedback += f"• Focus: {focus}/10 - Staying on topic without tangents\n"
detailed_feedback += f"• Progression: {progression}/10 - Building on previous thinking\n"
detailed_feedback += (
f"• Focus: {focus}/10 - Staying on topic without tangents\n"
)
detailed_feedback += (
f"• Progression: {progression}/10 - Building on previous thinking\n"
)
detailed_feedback += f"• Decision Quality: {decision_quality}/10 - Making appropriate decisions\n"
detailed_feedback += f"• Conciseness: {conciseness}/10 - Communicating efficiently\n"
detailed_feedback += (
f"• Conciseness: {conciseness}/10 - Communicating efficiently\n"
)
detailed_feedback += f"• Loop Avoidance: {loop_avoidance}/10 - Avoiding repetitive patterns\n\n"
detailed_feedback += f"Feedback:\n{feedback}\n\n"
detailed_feedback += f"Optimization Suggestions:\n{optimization_suggestions}"
detailed_feedback += (
f"Optimization Suggestions:\n{optimization_suggestions}"
)
return EvaluationScore(
score=float(overall_score),
feedback=detailed_feedback,
raw_response=response
raw_response=response,
)
except Exception as e:
logging.warning(f"Failed to parse reasoning efficiency evaluation: {e}")
return EvaluationScore(
score=None,
feedback=f"Failed to parse reasoning efficiency evaluation. Raw response: {response[:200]}...",
raw_response=response
raw_response=response,
)
def _detect_loops(self, llm_calls: List[Dict]) -> Tuple[bool, List[Dict]]:
def _detect_loops(self, llm_calls: list[dict]) -> tuple[bool, list[dict]]:
loop_details = []
messages = []
@@ -193,9 +228,11 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
messages.append(content)
elif isinstance(content, list) and len(content) > 0:
# Handle message list format
for msg in content:
if isinstance(msg, dict) and "content" in msg:
messages.append(msg["content"])
messages.extend(
msg["content"]
for msg in content
if isinstance(msg, dict) and "content" in msg
)
# Simple n-gram based similarity detection
# For a more robust implementation, consider using embedding-based similarity
@@ -205,18 +242,20 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
# A more sophisticated approach would use semantic similarity
similarity = self._calculate_text_similarity(messages[i], messages[j])
if similarity > 0.7: # Arbitrary threshold
loop_details.append({
"first_occurrence": i,
"second_occurrence": j,
"similarity": similarity,
"snippet": messages[i][:100] + "..."
})
loop_details.append(
{
"first_occurrence": i,
"second_occurrence": j,
"similarity": similarity,
"snippet": messages[i][:100] + "...",
}
)
return len(loop_details) > 0, loop_details
def _calculate_text_similarity(self, text1: str, text2: str) -> float:
text1 = re.sub(r'\s+', ' ', text1.lower()).strip()
text2 = re.sub(r'\s+', ' ', text2.lower()).strip()
text1 = re.sub(r"\s+", " ", text1.lower()).strip()
text2 = re.sub(r"\s+", " ", text2.lower()).strip()
# Simple Jaccard similarity on word sets
words1 = set(text1.split())
@@ -227,7 +266,7 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
return intersection / union if union > 0 else 0.0
def _analyze_reasoning_patterns(self, llm_calls: List[Dict]) -> Dict[str, Any]:
def _analyze_reasoning_patterns(self, llm_calls: list[dict]) -> dict[str, Any]:
call_lengths = []
response_times = []
@@ -248,8 +287,8 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
if start_time and end_time:
try:
response_times.append(end_time - start_time)
except Exception:
pass
except Exception as e:
logging.debug(f"Failed to calculate response time: {e}")
avg_length = np.mean(call_lengths) if call_lengths else 0
std_length = np.std(call_lengths) if call_lengths else 0
@@ -267,7 +306,9 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
details = "Agent is consistently verbose across interactions."
elif len(llm_calls) > 10 and length_trend > 0.5:
primary_pattern = ReasoningPatternType.INDECISIVE
details = "Agent shows signs of indecisiveness with increasing message lengths."
details = (
"Agent shows signs of indecisiveness with increasing message lengths."
)
elif std_length / avg_length > 0.8:
primary_pattern = ReasoningPatternType.SCATTERED
details = "Agent shows inconsistent reasoning flow with highly variable responses."
@@ -279,8 +320,8 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
"avg_length": avg_length,
"std_length": std_length,
"length_trend": length_trend,
"loop_score": loop_score
}
"loop_score": loop_score,
},
}
def _calculate_trend(self, values: Sequence[float | int]) -> float:
@@ -303,7 +344,9 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
except Exception:
return 0.0
def _calculate_loop_likelihood(self, call_lengths: Sequence[float], response_times: Sequence[float]) -> float:
def _calculate_loop_likelihood(
self, call_lengths: Sequence[float], response_times: Sequence[float]
) -> float:
if not call_lengths or len(call_lengths) < 3:
return 0.0
@@ -312,7 +355,11 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
if len(call_lengths) >= 4:
repeated_lengths = 0
for i in range(len(call_lengths) - 2):
ratio = call_lengths[i] / call_lengths[i + 2] if call_lengths[i + 2] > 0 else 0
ratio = (
call_lengths[i] / call_lengths[i + 2]
if call_lengths[i + 2] > 0
else 0
)
if 0.85 <= ratio <= 1.15:
repeated_lengths += 1
@@ -324,21 +371,27 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
std_time = np.std(response_times)
mean_time = np.mean(response_times)
if mean_time > 0:
time_consistency = 1.0 - (std_time / mean_time)
indicators.append(max(0, time_consistency - 0.3) * 1.5)
except Exception:
pass
time_consistency = 1.0 - (float(std_time) / float(mean_time))
indicators.append(max(0.0, float(time_consistency - 0.3)) * 1.5)
except Exception as e:
logging.debug(f"Time consistency calculation failed: {e}")
return np.mean(indicators) if indicators else 0.0
return float(np.mean(indicators)) if indicators else 0.0
def _get_call_samples(self, llm_calls: List[Dict]) -> str:
def _get_call_samples(self, llm_calls: list[dict]) -> str:
samples = []
if len(llm_calls) <= 6:
sample_indices = list(range(len(llm_calls)))
else:
sample_indices = [0, 1, len(llm_calls) // 2 - 1, len(llm_calls) // 2,
len(llm_calls) - 2, len(llm_calls) - 1]
sample_indices = [
0,
1,
len(llm_calls) // 2 - 1,
len(llm_calls) // 2,
len(llm_calls) - 2,
len(llm_calls) - 1,
]
for idx in sample_indices:
call = llm_calls[idx]
@@ -347,10 +400,11 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
if isinstance(content, str):
sample = content
elif isinstance(content, list) and len(content) > 0:
sample_parts = []
for msg in content:
if isinstance(msg, dict) and "content" in msg:
sample_parts.append(msg["content"])
sample_parts = [
msg["content"]
for msg in content
if isinstance(msg, dict) and "content" in msg
]
sample = "\n".join(sample_parts)
else:
sample = str(content)

View File

@@ -1,10 +1,15 @@
from typing import Any, Dict
from typing import Any
from crewai.agent import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.experimental.evaluation.base_evaluator import (
BaseEvaluator,
EvaluationScore,
MetricCategory,
)
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
from crewai.task import Task
from crewai.experimental.evaluation.base_evaluator import BaseEvaluator, EvaluationScore, MetricCategory
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
class SemanticQualityEvaluator(BaseEvaluator):
@property
@@ -13,8 +18,8 @@ class SemanticQualityEvaluator(BaseEvaluator):
def evaluate(
self,
agent: Agent,
execution_trace: Dict[str, Any],
agent: Agent | BaseAgent,
execution_trace: dict[str, Any],
final_output: Any,
task: Task | None = None,
) -> EvaluationScore:
@@ -22,7 +27,9 @@ class SemanticQualityEvaluator(BaseEvaluator):
if task is not None:
task_context = f"Task description: {task.description}"
prompt = [
{"role": "system", "content": """You are an expert evaluator assessing the semantic quality of an AI agent's output.
{
"role": "system",
"content": """You are an expert evaluator assessing the semantic quality of an AI agent's output.
Score the semantic quality on a scale from 0-10 where:
- 0: Completely incoherent, confusing, or logically flawed output
@@ -37,8 +44,11 @@ Consider:
5. Is the output free from contradictions and logical fallacies?
Return your evaluation as JSON with fields 'score' (number) and 'feedback' (string).
"""},
{"role": "user", "content": f"""
""",
},
{
"role": "user",
"content": f"""
Agent role: {agent.role}
{task_context}
@@ -46,23 +56,28 @@ Agent's final output:
{final_output}
Evaluate the semantic quality and reasoning of this output.
"""}
""",
},
]
assert self.llm is not None
if self.llm is None:
raise ValueError("LLM must be initialized")
response = self.llm.call(prompt)
try:
evaluation_data: dict[str, Any] = extract_json_from_llm_response(response)
assert evaluation_data is not None
if evaluation_data is None:
raise ValueError("Failed to extract evaluation data from LLM response")
return EvaluationScore(
score=float(evaluation_data["score"]) if evaluation_data.get("score") is not None else None,
score=float(evaluation_data["score"])
if evaluation_data.get("score") is not None
else None,
feedback=evaluation_data.get("feedback", response),
raw_response=response
raw_response=response,
)
except Exception:
return EvaluationScore(
score=None,
feedback=f"Failed to parse evaluation. Raw response: {response}",
raw_response=response
)
raw_response=response,
)

View File

@@ -1,22 +1,26 @@
import json
from typing import Dict, Any
from typing import Any
from crewai.experimental.evaluation.base_evaluator import BaseEvaluator, EvaluationScore, MetricCategory
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
from crewai.agent import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.experimental.evaluation.base_evaluator import (
BaseEvaluator,
EvaluationScore,
MetricCategory,
)
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
from crewai.task import Task
class ToolSelectionEvaluator(BaseEvaluator):
@property
def metric_category(self) -> MetricCategory:
return MetricCategory.TOOL_SELECTION
def evaluate(
self,
agent: Agent,
execution_trace: Dict[str, Any],
agent: Agent | BaseAgent,
execution_trace: dict[str, Any],
final_output: str,
task: Task | None = None,
) -> EvaluationScore:
@@ -26,19 +30,18 @@ class ToolSelectionEvaluator(BaseEvaluator):
tool_uses = execution_trace.get("tool_uses", [])
tool_count = len(tool_uses)
unique_tool_types = set([tool.get("tool", "Unknown tool") for tool in tool_uses])
unique_tool_types = set(
[tool.get("tool", "Unknown tool") for tool in tool_uses]
)
if tool_count == 0:
if not agent.tools:
return EvaluationScore(
score=None,
feedback="Agent had no tools available to use."
)
else:
return EvaluationScore(
score=None,
feedback="Agent had tools available but didn't use any."
score=None, feedback="Agent had no tools available to use."
)
return EvaluationScore(
score=None, feedback="Agent had tools available but didn't use any."
)
available_tools_info = ""
if agent.tools:
@@ -52,7 +55,9 @@ class ToolSelectionEvaluator(BaseEvaluator):
tool_types_summary += f"- {tool_type}\n"
prompt = [
{"role": "system", "content": """You are an expert evaluator assessing if an AI agent selected the most appropriate tools for a given task.
{
"role": "system",
"content": """You are an expert evaluator assessing if an AI agent selected the most appropriate tools for a given task.
You must evaluate based on these 2 criteria:
1. Relevance (0-10): Were the tools chosen directly aligned with the task's goals?
@@ -73,8 +78,11 @@ Return your evaluation as JSON with these fields:
- overall_score: number (average of all scores, 0-10)
- feedback: string (focused ONLY on tool selection decisions from available tools)
- improvement_suggestions: string (ONLY suggest better selection from the AVAILABLE tools list, NOT new tools)
"""},
{"role": "user", "content": f"""
""",
},
{
"role": "user",
"content": f"""
Agent role: {agent.role}
{task_context}
@@ -89,14 +97,17 @@ IMPORTANT:
- ONLY evaluate selection from tools listed as available
- DO NOT suggest new tools that aren't in the available tools list
- DO NOT evaluate tool usage or results
"""}
""",
},
]
assert self.llm is not None
if self.llm is None:
raise ValueError("LLM must be initialized")
response = self.llm.call(prompt)
try:
evaluation_data = extract_json_from_llm_response(response)
assert evaluation_data is not None
if evaluation_data is None:
raise ValueError("Failed to extract evaluation data from LLM response")
scores = evaluation_data.get("scores", {})
relevance = scores.get("relevance", 5.0)
@@ -105,22 +116,24 @@ IMPORTANT:
feedback = "Tool Selection Evaluation:\n"
feedback += f"• Relevance: {relevance}/10 - Selection of appropriate tool types for the task\n"
feedback += f"• Coverage: {coverage}/10 - Selection of all necessary tool types\n"
feedback += (
f"• Coverage: {coverage}/10 - Selection of all necessary tool types\n"
)
if "improvement_suggestions" in evaluation_data:
feedback += f"Improvement Suggestions:\n{evaluation_data['improvement_suggestions']}"
else:
feedback += evaluation_data.get("feedback", "No detailed feedback available.")
feedback += evaluation_data.get(
"feedback", "No detailed feedback available."
)
return EvaluationScore(
score=overall_score,
feedback=feedback,
raw_response=response
score=overall_score, feedback=feedback, raw_response=response
)
except Exception as e:
return EvaluationScore(
score=None,
feedback=f"Error evaluating tool selection: {e}",
raw_response=response
raw_response=response,
)
@@ -131,8 +144,8 @@ class ParameterExtractionEvaluator(BaseEvaluator):
def evaluate(
self,
agent: Agent,
execution_trace: Dict[str, Any],
agent: Agent | BaseAgent,
execution_trace: dict[str, Any],
final_output: str,
task: Task | None = None,
) -> EvaluationScore:
@@ -145,19 +158,23 @@ class ParameterExtractionEvaluator(BaseEvaluator):
if tool_count == 0:
return EvaluationScore(
score=None,
feedback="No tool usage detected. Cannot evaluate parameter extraction."
feedback="No tool usage detected. Cannot evaluate parameter extraction.",
)
validation_errors = []
for tool_use in tool_uses:
if not tool_use.get("success", True) and tool_use.get("error_type") == "validation_error":
validation_errors.append({
"tool": tool_use.get("tool", "Unknown tool"),
"error": tool_use.get("result"),
"args": tool_use.get("args", {})
})
validation_errors = [
{
"tool": tool_use.get("tool", "Unknown tool"),
"error": tool_use.get("result"),
"args": tool_use.get("args", {}),
}
for tool_use in tool_uses
if not tool_use.get("success", True)
and tool_use.get("error_type") == "validation_error"
]
validation_error_rate = len(validation_errors) / tool_count if tool_count > 0 else 0
validation_error_rate = (
len(validation_errors) / tool_count if tool_count > 0 else 0
)
param_samples = []
for i, tool_use in enumerate(tool_uses[:5]):
@@ -168,7 +185,7 @@ class ParameterExtractionEvaluator(BaseEvaluator):
is_validation_error = error_type == "validation_error"
sample = f"Tool use #{i+1} - {tool_name}:\n"
sample = f"Tool use #{i + 1} - {tool_name}:\n"
sample += f"- Parameters: {json.dumps(tool_args, indent=2)}\n"
sample += f"- Success: {'No' if not success else 'Yes'}"
@@ -187,13 +204,17 @@ class ParameterExtractionEvaluator(BaseEvaluator):
tool_name = err.get("tool", "Unknown tool")
error_msg = err.get("error", "Unknown error")
args = err.get("args", {})
validation_errors_info += f"\nValidation Error #{i+1}:\n- Tool: {tool_name}\n- Args: {json.dumps(args, indent=2)}\n- Error: {error_msg}"
validation_errors_info += f"\nValidation Error #{i + 1}:\n- Tool: {tool_name}\n- Args: {json.dumps(args, indent=2)}\n- Error: {error_msg}"
if len(validation_errors) > 3:
validation_errors_info += f"\n...and {len(validation_errors) - 3} more validation errors."
validation_errors_info += (
f"\n...and {len(validation_errors) - 3} more validation errors."
)
param_samples_text = "\n\n".join(param_samples)
prompt = [
{"role": "system", "content": """You are an expert evaluator assessing how well an AI agent extracts and formats PARAMETER VALUES for tool calls.
{
"role": "system",
"content": """You are an expert evaluator assessing how well an AI agent extracts and formats PARAMETER VALUES for tool calls.
Your job is to evaluate ONLY whether the agent used the correct parameter VALUES, not whether the right tools were selected or how the tools were invoked.
@@ -216,8 +237,11 @@ Return your evaluation as JSON with these fields:
- overall_score: number (average of all scores, 0-10)
- feedback: string (focused ONLY on parameter value extraction quality)
- improvement_suggestions: string (concrete suggestions for better parameter VALUE extraction)
"""},
{"role": "user", "content": f"""
""",
},
{
"role": "user",
"content": f"""
Agent role: {agent.role}
{task_context}
@@ -226,15 +250,18 @@ Parameter extraction examples:
{validation_errors_info}
Evaluate the quality of the agent's parameter extraction for this task.
"""}
""",
},
]
assert self.llm is not None
if self.llm is None:
raise ValueError("LLM must be initialized")
response = self.llm.call(prompt)
try:
evaluation_data = extract_json_from_llm_response(response)
assert evaluation_data is not None
if evaluation_data is None:
raise ValueError("Failed to extract evaluation data from LLM response")
scores = evaluation_data.get("scores", {})
accuracy = scores.get("accuracy", 5.0)
@@ -251,18 +278,18 @@ Evaluate the quality of the agent's parameter extraction for this task.
if "improvement_suggestions" in evaluation_data:
feedback += f"Improvement Suggestions:\n{evaluation_data['improvement_suggestions']}"
else:
feedback += evaluation_data.get("feedback", "No detailed feedback available.")
feedback += evaluation_data.get(
"feedback", "No detailed feedback available."
)
return EvaluationScore(
score=overall_score,
feedback=feedback,
raw_response=response
score=overall_score, feedback=feedback, raw_response=response
)
except Exception as e:
return EvaluationScore(
score=None,
feedback=f"Error evaluating parameter extraction: {e}",
raw_response=response
raw_response=response,
)
@@ -273,8 +300,8 @@ class ToolInvocationEvaluator(BaseEvaluator):
def evaluate(
self,
agent: Agent,
execution_trace: Dict[str, Any],
agent: Agent | BaseAgent,
execution_trace: dict[str, Any],
final_output: str,
task: Task | None = None,
) -> EvaluationScore:
@@ -288,7 +315,7 @@ class ToolInvocationEvaluator(BaseEvaluator):
if tool_count == 0:
return EvaluationScore(
score=None,
feedback="No tool usage detected. Cannot evaluate tool invocation."
feedback="No tool usage detected. Cannot evaluate tool invocation.",
)
for tool_use in tool_uses:
@@ -296,7 +323,7 @@ class ToolInvocationEvaluator(BaseEvaluator):
error_info = {
"tool": tool_use.get("tool", "Unknown tool"),
"error": tool_use.get("result"),
"error_type": tool_use.get("error_type", "unknown_error")
"error_type": tool_use.get("error_type", "unknown_error"),
}
tool_errors.append(error_info)
@@ -315,9 +342,11 @@ class ToolInvocationEvaluator(BaseEvaluator):
tool_args = tool_use.get("args", {})
success = tool_use.get("success", True) and not tool_use.get("error", False)
error_type = tool_use.get("error_type", "") if not success else ""
error_msg = tool_use.get("result", "No error") if not success else "No error"
error_msg = (
tool_use.get("result", "No error") if not success else "No error"
)
sample = f"Tool invocation #{i+1}:\n"
sample = f"Tool invocation #{i + 1}:\n"
sample += f"- Tool: {tool_name}\n"
sample += f"- Parameters: {json.dumps(tool_args, indent=2)}\n"
sample += f"- Success: {'No' if not success else 'Yes'}\n"
@@ -330,11 +359,13 @@ class ToolInvocationEvaluator(BaseEvaluator):
if error_types:
error_type_summary = "Error type breakdown:\n"
for error_type, count in error_types.items():
error_type_summary += f"- {error_type}: {count} occurrences ({(count/tool_count):.1%})\n"
error_type_summary += f"- {error_type}: {count} occurrences ({(count / tool_count):.1%})\n"
invocation_samples_text = "\n\n".join(invocation_samples)
prompt = [
{"role": "system", "content": """You are an expert evaluator assessing how correctly an AI agent's tool invocations are STRUCTURED.
{
"role": "system",
"content": """You are an expert evaluator assessing how correctly an AI agent's tool invocations are STRUCTURED.
Your job is to evaluate ONLY the structural and syntactical aspects of how the agent called tools, NOT which tools were selected or what parameter values were used.
@@ -359,8 +390,11 @@ Return your evaluation as JSON with these fields:
- overall_score: number (average of all scores, 0-10)
- feedback: string (focused ONLY on structural aspects of tool invocation)
- improvement_suggestions: string (concrete suggestions for better structuring of tool calls)
"""},
{"role": "user", "content": f"""
""",
},
{
"role": "user",
"content": f"""
Agent role: {agent.role}
{task_context}
@@ -371,15 +405,18 @@ Tool error rate: {error_rate:.2%} ({len(tool_errors)} errors out of {tool_count}
{error_type_summary}
Evaluate the quality of the agent's tool invocation structure during this task.
"""}
""",
},
]
assert self.llm is not None
if self.llm is None:
raise ValueError("LLM must be initialized")
response = self.llm.call(prompt)
try:
evaluation_data = extract_json_from_llm_response(response)
assert evaluation_data is not None
if evaluation_data is None:
raise ValueError("Failed to extract evaluation data from LLM response")
scores = evaluation_data.get("scores", {})
structure = scores.get("structure", 5.0)
error_handling = scores.get("error_handling", 5.0)
@@ -388,23 +425,25 @@ Evaluate the quality of the agent's tool invocation structure during this task.
overall_score = float(evaluation_data.get("overall_score", 5.0))
feedback = "Tool Invocation Evaluation:\n"
feedback += f"• Structure: {structure}/10 - Following proper syntax and format\n"
feedback += (
f"• Structure: {structure}/10 - Following proper syntax and format\n"
)
feedback += f"• Error Handling: {error_handling}/10 - Appropriately handling tool errors\n"
feedback += f"• Invocation Patterns: {invocation_patterns}/10 - Proper sequencing and management of calls\n\n"
if "improvement_suggestions" in evaluation_data:
feedback += f"Improvement Suggestions:\n{evaluation_data['improvement_suggestions']}"
else:
feedback += evaluation_data.get("feedback", "No detailed feedback available.")
feedback += evaluation_data.get(
"feedback", "No detailed feedback available."
)
return EvaluationScore(
score=overall_score,
feedback=feedback,
raw_response=response
score=overall_score, feedback=feedback, raw_response=response
)
except Exception as e:
return EvaluationScore(
score=None,
feedback=f"Error evaluating tool invocation: {e}",
raw_response=response
raw_response=response,
)

View File

@@ -1,12 +1,21 @@
import inspect
import warnings
from typing_extensions import Any
import warnings
from crewai.experimental.evaluation.experiment import ExperimentResults, ExperimentRunner
from crewai import Crew, Agent
def assert_experiment_successfully(experiment_results: ExperimentResults, baseline_filepath: str | None = None) -> None:
failed_tests = [result for result in experiment_results.results if not result.passed]
from crewai import Agent, Crew
from crewai.experimental.evaluation.experiment import (
ExperimentResults,
ExperimentRunner,
)
def assert_experiment_successfully(
experiment_results: ExperimentResults, baseline_filepath: str | None = None
) -> None:
failed_tests = [
result for result in experiment_results.results if not result.passed
]
if failed_tests:
detailed_failures: list[str] = []
@@ -14,39 +23,54 @@ def assert_experiment_successfully(experiment_results: ExperimentResults, baseli
for result in failed_tests:
expected = result.expected_score
actual = result.score
detailed_failures.append(f"- {result.identifier}: expected {expected}, got {actual}")
detailed_failures.append(
f"- {result.identifier}: expected {expected}, got {actual}"
)
failure_details = "\n".join(detailed_failures)
raise AssertionError(f"The following test cases failed:\n{failure_details}")
baseline_filepath = baseline_filepath or _get_baseline_filepath_fallback()
comparison = experiment_results.compare_with_baseline(baseline_filepath=baseline_filepath)
comparison = experiment_results.compare_with_baseline(
baseline_filepath=baseline_filepath
)
assert_experiment_no_regression(comparison)
def assert_experiment_no_regression(comparison_result: dict[str, list[str]]) -> None:
regressed = comparison_result.get("regressed", [])
if regressed:
raise AssertionError(f"Regression detected! The following tests that previously passed now fail: {regressed}")
raise AssertionError(
f"Regression detected! The following tests that previously passed now fail: {regressed}"
)
missing_tests = comparison_result.get("missing_tests", [])
if missing_tests:
warnings.warn(
f"Warning: {len(missing_tests)} tests from the baseline are missing in the current run: {missing_tests}",
UserWarning
UserWarning,
stacklevel=2,
)
def run_experiment(dataset: list[dict[str, Any]], crew: Crew | None = None, agents: list[Agent] | None = None, verbose: bool = False) -> ExperimentResults:
def run_experiment(
dataset: list[dict[str, Any]],
crew: Crew | None = None,
agents: list[Agent] | None = None,
verbose: bool = False,
) -> ExperimentResults:
runner = ExperimentRunner(dataset=dataset)
return runner.run(agents=agents, crew=crew, print_summary=verbose)
def _get_baseline_filepath_fallback() -> str:
test_func_name = "experiment_fallback"
try:
current_frame = inspect.currentframe()
if current_frame is not None:
test_func_name = current_frame.f_back.f_back.f_code.co_name # type: ignore[union-attr]
test_func_name = current_frame.f_back.f_back.f_code.co_name # type: ignore[union-attr]
except Exception:
...
return f"{test_func_name}_results.json"
return f"{test_func_name}_results.json"

View File

@@ -1,5 +1,4 @@
from crewai.flow.flow import Flow, start, listen, or_, and_, router
from crewai.flow.flow import Flow, and_, listen, or_, router, start
from crewai.flow.persistence import persist
__all__ = ["Flow", "start", "listen", "or_", "and_", "router", "persist"]
__all__ = ["Flow", "and_", "listen", "or_", "persist", "router", "start"]

View File

@@ -1,5 +1,4 @@
import inspect
from typing import Optional
from pydantic import BaseModel, Field, InstanceOf, model_validator
@@ -14,7 +13,7 @@ class FlowTrackable(BaseModel):
inspecting the call stack.
"""
parent_flow: Optional[InstanceOf[Flow]] = Field(
parent_flow: InstanceOf[Flow] | None = Field(
default=None,
description="The parent flow of the instance, if it was created inside a flow.",
)

View File

@@ -1,14 +1,13 @@
# flow_visualizer.py
import os
from pathlib import Path
from pyvis.network import Network
from pyvis.network import Network # type: ignore[import-untyped]
from crewai.flow.config import COLORS, NODE_STYLES
from crewai.flow.html_template_handler import HTMLTemplateHandler
from crewai.flow.legend_generator import generate_legend_items_html, get_legend_items
from crewai.flow.path_utils import safe_path_join, validate_path_exists
from crewai.flow.path_utils import safe_path_join
from crewai.flow.utils import calculate_node_levels
from crewai.flow.visualization_utils import (
add_edges,
@@ -34,13 +33,13 @@ class FlowPlot:
ValueError
If flow object is invalid or missing required attributes.
"""
if not hasattr(flow, '_methods'):
if not hasattr(flow, "_methods"):
raise ValueError("Invalid flow object: missing '_methods' attribute")
if not hasattr(flow, '_listeners'):
if not hasattr(flow, "_listeners"):
raise ValueError("Invalid flow object: missing '_listeners' attribute")
if not hasattr(flow, '_start_methods'):
if not hasattr(flow, "_start_methods"):
raise ValueError("Invalid flow object: missing '_start_methods' attribute")
self.flow = flow
self.colors = COLORS
self.node_styles = NODE_STYLES
@@ -65,7 +64,7 @@ class FlowPlot:
"""
if not filename or not isinstance(filename, str):
raise ValueError("Filename must be a non-empty string")
try:
# Initialize network
net = Network(
@@ -96,32 +95,34 @@ class FlowPlot:
try:
node_levels = calculate_node_levels(self.flow)
except Exception as e:
raise ValueError(f"Failed to calculate node levels: {str(e)}")
raise ValueError(f"Failed to calculate node levels: {e!s}") from e
# Compute positions
try:
node_positions = compute_positions(self.flow, node_levels)
except Exception as e:
raise ValueError(f"Failed to compute node positions: {str(e)}")
raise ValueError(f"Failed to compute node positions: {e!s}") from e
# Add nodes to the network
try:
add_nodes_to_network(net, self.flow, node_positions, self.node_styles)
except Exception as e:
raise RuntimeError(f"Failed to add nodes to network: {str(e)}")
raise RuntimeError(f"Failed to add nodes to network: {e!s}") from e
# Add edges to the network
try:
add_edges(net, self.flow, node_positions, self.colors)
except Exception as e:
raise RuntimeError(f"Failed to add edges to network: {str(e)}")
raise RuntimeError(f"Failed to add edges to network: {e!s}") from e
# Generate HTML
try:
network_html = net.generate_html()
final_html_content = self._generate_final_html(network_html)
except Exception as e:
raise RuntimeError(f"Failed to generate network visualization: {str(e)}")
raise RuntimeError(
f"Failed to generate network visualization: {e!s}"
) from e
# Save the final HTML content to the file
try:
@@ -129,12 +130,16 @@ class FlowPlot:
f.write(final_html_content)
print(f"Plot saved as {filename}.html")
except IOError as e:
raise IOError(f"Failed to save flow visualization to {filename}.html: {str(e)}")
raise IOError(
f"Failed to save flow visualization to {filename}.html: {e!s}"
) from e
except (ValueError, RuntimeError, IOError) as e:
raise e
except Exception as e:
raise RuntimeError(f"Unexpected error during flow visualization: {str(e)}")
raise RuntimeError(
f"Unexpected error during flow visualization: {e!s}"
) from e
finally:
self._cleanup_pyvis_lib()
@@ -165,7 +170,9 @@ class FlowPlot:
try:
# Extract just the body content from the generated HTML
current_dir = os.path.dirname(__file__)
template_path = safe_path_join("assets", "crewai_flow_visual_template.html", root=current_dir)
template_path = safe_path_join(
"assets", "crewai_flow_visual_template.html", root=current_dir
)
logo_path = safe_path_join("assets", "crewai_logo.svg", root=current_dir)
if not os.path.exists(template_path):
@@ -179,12 +186,9 @@ class FlowPlot:
# Generate the legend items HTML
legend_items = get_legend_items(self.colors)
legend_items_html = generate_legend_items_html(legend_items)
final_html_content = html_handler.generate_final_html(
network_body, legend_items_html
)
return final_html_content
return html_handler.generate_final_html(network_body, legend_items_html)
except Exception as e:
raise IOError(f"Failed to generate visualization HTML: {str(e)}")
raise IOError(f"Failed to generate visualization HTML: {e!s}") from e
def _cleanup_pyvis_lib(self):
"""
@@ -197,6 +201,7 @@ class FlowPlot:
lib_folder = safe_path_join("lib", root=os.getcwd())
if os.path.exists(lib_folder) and os.path.isdir(lib_folder):
import shutil
shutil.rmtree(lib_folder)
except ValueError as e:
print(f"Error validating lib folder path: {e}")

View File

@@ -1,8 +1,7 @@
import base64
import re
from pathlib import Path
from crewai.flow.path_utils import safe_path_join, validate_path_exists
from crewai.flow.path_utils import validate_path_exists
class HTMLTemplateHandler:
@@ -28,7 +27,7 @@ class HTMLTemplateHandler:
self.template_path = validate_path_exists(template_path, "file")
self.logo_path = validate_path_exists(logo_path, "file")
except ValueError as e:
raise ValueError(f"Invalid template or logo path: {e}")
raise ValueError(f"Invalid template or logo path: {e}") from e
def read_template(self):
"""Read and return the HTML template file contents."""
@@ -53,23 +52,23 @@ class HTMLTemplateHandler:
if "border" in item:
legend_items_html += f"""
<div class="legend-item">
<div class="legend-color-box" style="background-color: {item['color']}; border: 2px dashed {item['border']};"></div>
<div>{item['label']}</div>
<div class="legend-color-box" style="background-color: {item["color"]}; border: 2px dashed {item["border"]};"></div>
<div>{item["label"]}</div>
</div>
"""
elif item.get("dashed") is not None:
style = "dashed" if item["dashed"] else "solid"
legend_items_html += f"""
<div class="legend-item">
<div class="legend-{style}" style="border-bottom: 2px {style} {item['color']};"></div>
<div>{item['label']}</div>
<div class="legend-{style}" style="border-bottom: 2px {style} {item["color"]};"></div>
<div>{item["label"]}</div>
</div>
"""
else:
legend_items_html += f"""
<div class="legend-item">
<div class="legend-color-box" style="background-color: {item['color']};"></div>
<div>{item['label']}</div>
<div class="legend-color-box" style="background-color: {item["color"]};"></div>
<div>{item["label"]}</div>
</div>
"""
return legend_items_html
@@ -79,15 +78,9 @@ class HTMLTemplateHandler:
html_template = self.read_template()
logo_svg_base64 = self.encode_logo()
final_html_content = html_template.replace("{{ title }}", title)
final_html_content = final_html_content.replace(
"{{ network_content }}", network_body
return (
html_template.replace("{{ title }}", title)
.replace("{{ network_content }}", network_body)
.replace("{{ logo_svg_base64 }}", logo_svg_base64)
.replace("<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html)
)
final_html_content = final_html_content.replace(
"{{ logo_svg_base64 }}", logo_svg_base64
)
final_html_content = final_html_content.replace(
"<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html
)
return final_html_content

View File

@@ -5,12 +5,10 @@ This module provides utilities for secure path handling to prevent directory
traversal attacks and ensure paths remain within allowed boundaries.
"""
import os
from pathlib import Path
from typing import List, Union
def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str:
def safe_path_join(*parts: str, root: str | Path | None = None) -> str:
"""
Safely join path components and ensure the result is within allowed boundaries.
@@ -43,25 +41,25 @@ def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str:
# Establish root directory
root_path = Path(root).resolve() if root else Path.cwd()
# Join and resolve the full path
full_path = Path(root_path, *clean_parts).resolve()
# Check if the resolved path is within root
if not str(full_path).startswith(str(root_path)):
raise ValueError(
f"Invalid path: Potential directory traversal. Path must be within {root_path}"
)
return str(full_path)
except Exception as e:
if isinstance(e, ValueError):
raise
raise ValueError(f"Invalid path components: {str(e)}")
raise ValueError(f"Invalid path components: {e!s}") from e
def validate_path_exists(path: Union[str, Path], file_type: str = "file") -> str:
def validate_path_exists(path: str | Path, file_type: str = "file") -> str:
"""
Validate that a path exists and is of the expected type.
@@ -84,24 +82,24 @@ def validate_path_exists(path: Union[str, Path], file_type: str = "file") -> str
"""
try:
path_obj = Path(path).resolve()
if not path_obj.exists():
raise ValueError(f"Path does not exist: {path}")
if file_type == "file" and not path_obj.is_file():
raise ValueError(f"Path is not a file: {path}")
elif file_type == "directory" and not path_obj.is_dir():
if file_type == "directory" and not path_obj.is_dir():
raise ValueError(f"Path is not a directory: {path}")
return str(path_obj)
except Exception as e:
if isinstance(e, ValueError):
raise
raise ValueError(f"Invalid path: {str(e)}")
raise ValueError(f"Invalid path: {e!s}") from e
def list_files(directory: Union[str, Path], pattern: str = "*") -> List[str]:
def list_files(directory: str | Path, pattern: str = "*") -> list[str]:
"""
Safely list files in a directory matching a pattern.
@@ -126,10 +124,10 @@ def list_files(directory: Union[str, Path], pattern: str = "*") -> List[str]:
dir_path = Path(directory).resolve()
if not dir_path.is_dir():
raise ValueError(f"Not a directory: {directory}")
return [str(p) for p in dir_path.glob(pattern) if p.is_file()]
except Exception as e:
if isinstance(e, ValueError):
raise
raise ValueError(f"Error listing files: {str(e)}")
raise ValueError(f"Error listing files: {e!s}") from e

View File

@@ -4,7 +4,7 @@ CrewAI Flow Persistence.
This module provides interfaces and implementations for persisting flow states.
"""
from typing import Any, Dict, TypeVar, Union
from typing import Any, TypeVar
from pydantic import BaseModel
@@ -12,7 +12,7 @@ from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.persistence.decorators import persist
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
__all__ = ["FlowPersistence", "persist", "SQLiteFlowPersistence"]
__all__ = ["FlowPersistence", "SQLiteFlowPersistence", "persist"]
StateType = TypeVar('StateType', bound=Union[Dict[str, Any], BaseModel])
DictStateType = Dict[str, Any]
StateType = TypeVar("StateType", bound=dict[str, Any] | BaseModel)
DictStateType = dict[str, Any]

View File

@@ -1,53 +1,47 @@
"""Base class for flow state persistence."""
import abc
from typing import Any, Dict, Optional, Union
from typing import Any
from pydantic import BaseModel
class FlowPersistence(abc.ABC):
"""Abstract base class for flow state persistence.
This class defines the interface that all persistence implementations must follow.
It supports both structured (Pydantic BaseModel) and unstructured (dict) states.
"""
@abc.abstractmethod
def init_db(self) -> None:
"""Initialize the persistence backend.
This method should handle any necessary setup, such as:
- Creating tables
- Establishing connections
- Setting up indexes
"""
pass
@abc.abstractmethod
def save_state(
self,
flow_uuid: str,
method_name: str,
state_data: Union[Dict[str, Any], BaseModel]
self, flow_uuid: str, method_name: str, state_data: dict[str, Any] | BaseModel
) -> None:
"""Persist the flow state after method completion.
Args:
flow_uuid: Unique identifier for the flow instance
method_name: Name of the method that just completed
state_data: Current state data (either dict or Pydantic model)
"""
pass
@abc.abstractmethod
def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
"""Load the most recent state for a given flow UUID.
Args:
flow_uuid: Unique identifier for the flow instance
Returns:
The most recent state as a dictionary, or None if no state exists
"""
pass

View File

@@ -24,13 +24,10 @@ Example:
import asyncio
import functools
import logging
from collections.abc import Callable
from typing import (
Any,
Callable,
Optional,
Type,
TypeVar,
Union,
cast,
)
@@ -48,7 +45,7 @@ LOG_MESSAGES = {
"save_state": "Saving flow state to memory for ID: {}",
"save_error": "Failed to persist state for method {}: {}",
"state_missing": "Flow instance has no state",
"id_missing": "Flow state must have an 'id' field for persistence"
"id_missing": "Flow state must have an 'id' field for persistence",
}
@@ -58,7 +55,13 @@ class PersistenceDecorator:
_printer = Printer() # Class-level printer instance
@classmethod
def persist_state(cls, flow_instance: Any, method_name: str, persistence_instance: FlowPersistence, verbose: bool = False) -> None:
def persist_state(
cls,
flow_instance: Any,
method_name: str,
persistence_instance: FlowPersistence,
verbose: bool = False,
) -> None:
"""Persist flow state with proper error handling and logging.
This method handles the persistence of flow state data, including proper
@@ -76,22 +79,24 @@ class PersistenceDecorator:
AttributeError: If flow instance lacks required state attributes
"""
try:
state = getattr(flow_instance, 'state', None)
state = getattr(flow_instance, "state", None)
if state is None:
raise ValueError("Flow instance has no state")
flow_uuid: Optional[str] = None
flow_uuid: str | None = None
if isinstance(state, dict):
flow_uuid = state.get('id')
flow_uuid = state.get("id")
elif isinstance(state, BaseModel):
flow_uuid = getattr(state, 'id', None)
flow_uuid = getattr(state, "id", None)
if not flow_uuid:
raise ValueError("Flow state must have an 'id' field for persistence")
# Log state saving only if verbose is True
if verbose:
cls._printer.print(LOG_MESSAGES["save_state"].format(flow_uuid), color="cyan")
cls._printer.print(
LOG_MESSAGES["save_state"].format(flow_uuid), color="cyan"
)
logger.info(LOG_MESSAGES["save_state"].format(flow_uuid))
try:
@@ -104,12 +109,12 @@ class PersistenceDecorator:
error_msg = LOG_MESSAGES["save_error"].format(method_name, str(e))
cls._printer.print(error_msg, color="red")
logger.error(error_msg)
raise RuntimeError(f"State persistence failed: {str(e)}") from e
except AttributeError:
raise RuntimeError(f"State persistence failed: {e!s}") from e
except AttributeError as e:
error_msg = LOG_MESSAGES["state_missing"]
cls._printer.print(error_msg, color="red")
logger.error(error_msg)
raise ValueError(error_msg)
raise ValueError(error_msg) from e
except (TypeError, ValueError) as e:
error_msg = LOG_MESSAGES["id_missing"]
cls._printer.print(error_msg, color="red")
@@ -117,7 +122,7 @@ class PersistenceDecorator:
raise ValueError(error_msg) from e
def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False):
def persist(persistence: FlowPersistence | None = None, verbose: bool = False):
"""Decorator to persist flow state.
This decorator can be applied at either the class level or method level.
@@ -144,111 +149,151 @@ def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False
def begin(self):
pass
"""
def decorator(target: Union[Type, Callable[..., T]]) -> Union[Type, Callable[..., T]]:
def decorator(target: type | Callable[..., T]) -> type | Callable[..., T]:
"""Decorator that handles both class and method decoration."""
actual_persistence = persistence or SQLiteFlowPersistence()
if isinstance(target, type):
# Class decoration
original_init = getattr(target, "__init__")
original_init = target.__init__ # type: ignore[misc]
@functools.wraps(original_init)
def new_init(self: Any, *args: Any, **kwargs: Any) -> None:
if 'persistence' not in kwargs:
kwargs['persistence'] = actual_persistence
if "persistence" not in kwargs:
kwargs["persistence"] = actual_persistence
original_init(self, *args, **kwargs)
setattr(target, "__init__", new_init)
target.__init__ = new_init # type: ignore[misc]
# Store original methods to preserve their decorators
original_methods = {}
for name, method in target.__dict__.items():
if callable(method) and (
hasattr(method, "__is_start_method__") or
hasattr(method, "__trigger_methods__") or
hasattr(method, "__condition_type__") or
hasattr(method, "__is_flow_method__") or
hasattr(method, "__is_router__")
):
original_methods[name] = method
original_methods = {
name: method
for name, method in target.__dict__.items()
if callable(method)
and (
hasattr(method, "__is_start_method__")
or hasattr(method, "__trigger_methods__")
or hasattr(method, "__condition_type__")
or hasattr(method, "__is_flow_method__")
or hasattr(method, "__is_router__")
)
}
# Create wrapped versions of the methods that include persistence
for name, method in original_methods.items():
if asyncio.iscoroutinefunction(method):
# Create a closure to capture the current name and method
def create_async_wrapper(method_name: str, original_method: Callable):
def create_async_wrapper(
method_name: str, original_method: Callable
):
@functools.wraps(original_method)
async def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
async def method_wrapper(
self: Any, *args: Any, **kwargs: Any
) -> Any:
result = await original_method(self, *args, **kwargs)
PersistenceDecorator.persist_state(self, method_name, actual_persistence, verbose)
PersistenceDecorator.persist_state(
self, method_name, actual_persistence, verbose
)
return result
return method_wrapper
wrapped = create_async_wrapper(name, method)
# Preserve all original decorators and attributes
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
for attr in [
"__is_start_method__",
"__trigger_methods__",
"__condition_type__",
"__is_router__",
]:
if hasattr(method, attr):
setattr(wrapped, attr, getattr(method, attr))
setattr(wrapped, "__is_flow_method__", True)
wrapped.__is_flow_method__ = True # type: ignore[attr-defined]
# Update the class with the wrapped method
setattr(target, name, wrapped)
else:
# Create a closure to capture the current name and method
def create_sync_wrapper(method_name: str, original_method: Callable):
def create_sync_wrapper(
method_name: str, original_method: Callable
):
@functools.wraps(original_method)
def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
result = original_method(self, *args, **kwargs)
PersistenceDecorator.persist_state(self, method_name, actual_persistence, verbose)
PersistenceDecorator.persist_state(
self, method_name, actual_persistence, verbose
)
return result
return method_wrapper
wrapped = create_sync_wrapper(name, method)
# Preserve all original decorators and attributes
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
for attr in [
"__is_start_method__",
"__trigger_methods__",
"__condition_type__",
"__is_router__",
]:
if hasattr(method, attr):
setattr(wrapped, attr, getattr(method, attr))
setattr(wrapped, "__is_flow_method__", True)
wrapped.__is_flow_method__ = True # type: ignore[attr-defined]
# Update the class with the wrapped method
setattr(target, name, wrapped)
return target
else:
# Method decoration
method = target
setattr(method, "__is_flow_method__", True)
# Method decoration
method = target
method.__is_flow_method__ = True # type: ignore[attr-defined]
if asyncio.iscoroutinefunction(method):
@functools.wraps(method)
async def method_async_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
method_coro = method(flow_instance, *args, **kwargs)
if asyncio.iscoroutine(method_coro):
result = await method_coro
else:
result = method_coro
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose)
return result
if asyncio.iscoroutinefunction(method):
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(method_async_wrapper, attr, getattr(method, attr))
setattr(method_async_wrapper, "__is_flow_method__", True)
return cast(Callable[..., T], method_async_wrapper)
else:
@functools.wraps(method)
def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
result = method(flow_instance, *args, **kwargs)
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose)
return result
@functools.wraps(method)
async def method_async_wrapper(
flow_instance: Any, *args: Any, **kwargs: Any
) -> T:
method_coro = method(flow_instance, *args, **kwargs)
if asyncio.iscoroutine(method_coro):
result = await method_coro
else:
result = method_coro
PersistenceDecorator.persist_state(
flow_instance, method.__name__, actual_persistence, verbose
)
return result
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
if hasattr(method, attr):
setattr(method_sync_wrapper, attr, getattr(method, attr))
setattr(method_sync_wrapper, "__is_flow_method__", True)
return cast(Callable[..., T], method_sync_wrapper)
for attr in [
"__is_start_method__",
"__trigger_methods__",
"__condition_type__",
"__is_router__",
]:
if hasattr(method, attr):
setattr(method_async_wrapper, attr, getattr(method, attr))
method_async_wrapper.__is_flow_method__ = True # type: ignore[attr-defined]
return cast(Callable[..., T], method_async_wrapper)
@functools.wraps(method)
def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
result = method(flow_instance, *args, **kwargs)
PersistenceDecorator.persist_state(
flow_instance, method.__name__, actual_persistence, verbose
)
return result
for attr in [
"__is_start_method__",
"__trigger_methods__",
"__condition_type__",
"__is_router__",
]:
if hasattr(method, attr):
setattr(method_sync_wrapper, attr, getattr(method, attr))
method_sync_wrapper.__is_flow_method__ = True # type: ignore[attr-defined]
return cast(Callable[..., T], method_sync_wrapper)
return decorator

View File

@@ -6,7 +6,7 @@ import json
import sqlite3
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Optional, Union
from typing import Any
from pydantic import BaseModel
@@ -23,7 +23,7 @@ class SQLiteFlowPersistence(FlowPersistence):
db_path: str
def __init__(self, db_path: Optional[str] = None):
def __init__(self, db_path: str | None = None):
"""Initialize SQLite persistence.
Args:
@@ -70,7 +70,7 @@ class SQLiteFlowPersistence(FlowPersistence):
self,
flow_uuid: str,
method_name: str,
state_data: Union[Dict[str, Any], BaseModel],
state_data: dict[str, Any] | BaseModel,
) -> None:
"""Save the current flow state to SQLite.
@@ -107,7 +107,7 @@ class SQLiteFlowPersistence(FlowPersistence):
),
)
def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
"""Load the most recent state for a given flow UUID.
Args:

View File

@@ -5,6 +5,7 @@ the Flow system.
"""
from typing import Any, TypedDict
from typing_extensions import NotRequired, Required

View File

@@ -17,10 +17,10 @@ import ast
import inspect
import textwrap
from collections import defaultdict, deque
from typing import Any, Deque, Dict, List, Optional, Set, Union
from typing import Any
def get_possible_return_constants(function: Any) -> Optional[List[str]]:
def get_possible_return_constants(function: Any) -> list[str] | None:
try:
source = inspect.getsource(function)
except OSError:
@@ -58,12 +58,12 @@ def get_possible_return_constants(function: Any) -> Optional[List[str]]:
target = node.targets[0]
if isinstance(target, ast.Name):
var_name = target.id
dict_values = []
# Extract string values from the dictionary
for val in node.value.values:
if isinstance(val, ast.Constant) and isinstance(val.value, str):
dict_values.append(val.value)
# If non-string, skip or just ignore
dict_values = [
val.value
for val in node.value.values
if isinstance(val, ast.Constant) and isinstance(val.value, str)
]
if dict_values:
dict_definitions[var_name] = dict_values
self.generic_visit(node)
@@ -94,7 +94,7 @@ def get_possible_return_constants(function: Any) -> Optional[List[str]]:
return list(return_values) if return_values else None
def calculate_node_levels(flow: Any) -> Dict[str, int]:
def calculate_node_levels(flow: Any) -> dict[str, int]:
"""
Calculate the hierarchical level of each node in the flow.
@@ -118,10 +118,10 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]:
- Handles both OR and AND conditions for listeners
- Processes router paths separately
"""
levels: Dict[str, int] = {}
queue: Deque[str] = deque()
visited: Set[str] = set()
pending_and_listeners: Dict[str, Set[str]] = {}
levels: dict[str, int] = {}
queue: deque[str] = deque()
visited: set[str] = set()
pending_and_listeners: dict[str, set[str]] = {}
# Make all start methods at level 0
for method_name, method in flow._methods.items():
@@ -172,7 +172,7 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]:
return levels
def count_outgoing_edges(flow: Any) -> Dict[str, int]:
def count_outgoing_edges(flow: Any) -> dict[str, int]:
"""
Count the number of outgoing edges for each method in the flow.
@@ -197,7 +197,7 @@ def count_outgoing_edges(flow: Any) -> Dict[str, int]:
return counts
def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:
def build_ancestor_dict(flow: Any) -> dict[str, set[str]]:
"""
Build a dictionary mapping each node to its ancestor nodes.
@@ -211,8 +211,8 @@ def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:
Dict[str, Set[str]]
Dictionary mapping each node to a set of its ancestor nodes.
"""
ancestors: Dict[str, Set[str]] = {node: set() for node in flow._methods}
visited: Set[str] = set()
ancestors: dict[str, set[str]] = {node: set() for node in flow._methods}
visited: set[str] = set()
for node in flow._methods:
if node not in visited:
dfs_ancestors(node, ancestors, visited, flow)
@@ -220,7 +220,7 @@ def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:
def dfs_ancestors(
node: str, ancestors: Dict[str, Set[str]], visited: Set[str], flow: Any
node: str, ancestors: dict[str, set[str]], visited: set[str], flow: Any
) -> None:
"""
Perform depth-first search to build ancestor relationships.
@@ -265,7 +265,7 @@ def dfs_ancestors(
def is_ancestor(
node: str, ancestor_candidate: str, ancestors: Dict[str, Set[str]]
node: str, ancestor_candidate: str, ancestors: dict[str, set[str]]
) -> bool:
"""
Check if one node is an ancestor of another.
@@ -287,7 +287,7 @@ def is_ancestor(
return ancestor_candidate in ancestors.get(node, set())
def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
def build_parent_children_dict(flow: Any) -> dict[str, list[str]]:
"""
Build a dictionary mapping parent nodes to their children.
@@ -307,7 +307,7 @@ def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
- Maps router methods to their paths and listeners
- Children lists are sorted for consistent ordering
"""
parent_children: Dict[str, List[str]] = {}
parent_children: dict[str, list[str]] = {}
# Map listeners to their trigger methods
for listener_name, (_, trigger_methods) in flow._listeners.items():
@@ -332,7 +332,7 @@ def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
def get_child_index(
parent: str, child: str, parent_children: Dict[str, List[str]]
parent: str, child: str, parent_children: dict[str, list[str]]
) -> int:
"""
Get the index of a child node in its parent's sorted children list.
@@ -364,7 +364,7 @@ def process_router_paths(flow, current, current_level, levels, queue):
paths = flow._router_paths.get(current, [])
for path in paths:
for listener_name, (
condition_type,
_condition_type,
trigger_methods,
) in flow._listeners.items():
if path in trigger_methods:

View File

@@ -17,7 +17,7 @@ Example
import ast
import inspect
from typing import Any, Dict, List, Tuple, Union
from typing import Any
from .utils import (
build_ancestor_dict,
@@ -56,6 +56,7 @@ def method_calls_crew(method: Any) -> bool:
class CrewCallVisitor(ast.NodeVisitor):
"""AST visitor to detect .crew() method calls."""
def __init__(self):
self.found = False
@@ -73,8 +74,8 @@ def method_calls_crew(method: Any) -> bool:
def add_nodes_to_network(
net: Any,
flow: Any,
node_positions: Dict[str, Tuple[float, float]],
node_styles: Dict[str, Dict[str, Any]]
node_positions: dict[str, tuple[float, float]],
node_styles: dict[str, dict[str, Any]],
) -> None:
"""
Add nodes to the network visualization with appropriate styling.
@@ -98,6 +99,7 @@ def add_nodes_to_network(
- Crew methods
- Regular methods
"""
def human_friendly_label(method_name):
return method_name.replace("_", " ").title()
@@ -138,10 +140,10 @@ def add_nodes_to_network(
def compute_positions(
flow: Any,
node_levels: Dict[str, int],
node_levels: dict[str, int],
y_spacing: float = 150,
x_spacing: float = 300
) -> Dict[str, Tuple[float, float]]:
x_spacing: float = 300,
) -> dict[str, tuple[float, float]]:
"""
Compute the (x, y) positions for each node in the flow graph.
@@ -161,8 +163,8 @@ def compute_positions(
Dict[str, Tuple[float, float]]
Dictionary mapping node names to their (x, y) coordinates.
"""
level_nodes: Dict[int, List[str]] = {}
node_positions: Dict[str, Tuple[float, float]] = {}
level_nodes: dict[int, list[str]] = {}
node_positions: dict[str, tuple[float, float]] = {}
for method_name, level in node_levels.items():
level_nodes.setdefault(level, []).append(method_name)
@@ -180,10 +182,10 @@ def compute_positions(
def add_edges(
net: Any,
flow: Any,
node_positions: Dict[str, Tuple[float, float]],
colors: Dict[str, str]
node_positions: dict[str, tuple[float, float]],
colors: dict[str, str],
) -> None:
edge_smooth: Dict[str, Union[str, float]] = {"type": "continuous"} # Default value
edge_smooth: dict[str, str | float] = {"type": "continuous"} # Default value
"""
Add edges to the network visualization with appropriate styling.
@@ -269,7 +271,7 @@ def add_edges(
for router_method_name, paths in flow._router_paths.items():
for path in paths:
for listener_name, (
condition_type,
_condition_type,
trigger_methods,
) in flow._listeners.items():
if path in trigger_methods:

View File

@@ -8,7 +8,7 @@ from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
from crewai.rag.config.utils import get_rag_client
from crewai.rag.core.base_client import BaseClient
from crewai.rag.embeddings.factory import get_embedding_function
from crewai.rag.embeddings.factory import EmbedderConfig, get_embedding_function
from crewai.rag.factory import create_client
from crewai.rag.types import BaseRecord, SearchResult
from crewai.utilities.logger import Logger
@@ -27,6 +27,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
) -> None:
self.collection_name = collection_name
self._client: BaseClient | None = None
self._embedder_config = embedder # Store embedder config
warnings.filterwarnings(
"ignore",
@@ -35,12 +36,29 @@ class KnowledgeStorage(BaseKnowledgeStorage):
)
if embedder:
embedding_function = get_embedding_function(embedder)
config = ChromaDBConfig(
embedding_function=cast(
ChromaEmbeddingFunctionWrapper, embedding_function
# Cast to EmbedderConfig for type checking
embedder_typed = cast(EmbedderConfig, embedder)
embedding_function = get_embedding_function(embedder_typed)
batch_size = None
if isinstance(embedder, dict) and "config" in embedder:
nested_config = embedder["config"]
if isinstance(nested_config, dict):
batch_size = nested_config.get("batch_size")
# Create config with batch_size if provided
if batch_size is not None:
config = ChromaDBConfig(
embedding_function=cast(
ChromaEmbeddingFunctionWrapper, embedding_function
),
batch_size=batch_size,
)
else:
config = ChromaDBConfig(
embedding_function=cast(
ChromaEmbeddingFunctionWrapper, embedding_function
)
)
)
self._client = create_client(config)
def _get_client(self) -> BaseClient:
@@ -105,9 +123,23 @@ class KnowledgeStorage(BaseKnowledgeStorage):
rag_documents: list[BaseRecord] = [{"content": doc} for doc in documents]
client.add_documents(
collection_name=collection_name, documents=rag_documents
)
batch_size = None
if self._embedder_config and isinstance(self._embedder_config, dict):
if "config" in self._embedder_config:
nested_config = self._embedder_config["config"]
if isinstance(nested_config, dict):
batch_size = nested_config.get("batch_size")
if batch_size is not None:
client.add_documents(
collection_name=collection_name,
documents=rag_documents,
batch_size=batch_size,
)
else:
client.add_documents(
collection_name=collection_name, documents=rag_documents
)
except Exception as e:
if "dimension mismatch" in str(e).lower():
Logger(verbose=True).log(

View File

@@ -367,6 +367,7 @@ class LiteAgent(FlowTrackable, BaseModel):
output=output,
guardrail=self._guardrail,
retry_count=self._guardrail_retry_count,
event_source=self,
)
if not guardrail_result.success:

View File

@@ -1,28 +1,26 @@
import io
import json
import logging
import os
import sys
import threading
import warnings
from collections import defaultdict
from contextlib import contextmanager
from collections.abc import Callable
from datetime import datetime
from typing import (
Any,
DefaultDict,
Dict,
List,
Final,
Literal,
Optional,
Type,
TextIO,
TypedDict,
Union,
cast,
)
from datetime import datetime
from dotenv import load_dotenv
from litellm.types.utils import ChatCompletionDeltaToolCall
from pydantic import BaseModel, Field
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.llm_events import (
LLMCallCompletedEvent,
LLMCallFailedEvent,
@@ -31,15 +29,19 @@ from crewai.events.types.llm_events import (
LLMStreamChunkEvent,
)
from crewai.events.types.tool_usage_events import (
ToolUsageStartedEvent,
ToolUsageFinishedEvent,
ToolUsageErrorEvent,
ToolUsageFinishedEvent,
ToolUsageStartedEvent,
)
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededError,
)
from crewai.utilities.logger_utils import suppress_warnings
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
with suppress_warnings():
import litellm
from litellm import Choices
from litellm import Choices, CustomLogger
from litellm.exceptions import ContextWindowExceededError
from litellm.litellm_core_utils.get_supported_openai_params import (
get_supported_openai_params,
@@ -47,16 +49,6 @@ with warnings.catch_warnings():
from litellm.types.utils import ModelResponse
from litellm.utils import supports_response_schema
import io
from typing import TextIO
from crewai.llms.base_llm import BaseLLM
from crewai.events.event_bus import crewai_event_bus
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededException,
)
load_dotenv()
litellm.suppress_debug_info = True
@@ -126,7 +118,11 @@ if not isinstance(sys.stderr, FilteredStream):
sys.stderr = FilteredStream(sys.stderr)
LLM_CONTEXT_WINDOW_SIZES = {
MIN_CONTEXT: Final[int] = 1024
MAX_CONTEXT: Final[int] = 2097152 # Current max from gemini-1.5-pro
ANTHROPIC_PREFIXES: Final[tuple[str, str, str]] = ("anthropic/", "claude-", "claude/")
LLM_CONTEXT_WINDOW_SIZES: Final[dict[str, int]] = {
# openai
"gpt-4": 8192,
"gpt-4o": 128000,
@@ -252,30 +248,19 @@ LLM_CONTEXT_WINDOW_SIZES = {
"mistral/mistral-large-2402": 32768,
}
DEFAULT_CONTEXT_WINDOW_SIZE = 8192
CONTEXT_WINDOW_USAGE_RATIO = 0.85
@contextmanager
def suppress_warnings():
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
warnings.filterwarnings(
"ignore", message="open_text is deprecated*", category=DeprecationWarning
)
yield
DEFAULT_CONTEXT_WINDOW_SIZE: Final[int] = 8192
CONTEXT_WINDOW_USAGE_RATIO: Final[float] = 0.85
class Delta(TypedDict):
content: Optional[str]
role: Optional[str]
content: str | None
role: str | None
class StreamingChoices(TypedDict):
delta: Delta
index: int
finish_reason: Optional[str]
finish_reason: str | None
class FunctionArgs(BaseModel):
@@ -288,31 +273,31 @@ class AccumulatedToolArgs(BaseModel):
class LLM(BaseLLM):
completion_cost: Optional[float] = None
completion_cost: float | None = None
def __init__(
self,
model: str,
timeout: Optional[Union[float, int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
n: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[int, float]] = None,
response_format: Optional[Type[BaseModel]] = None,
seed: Optional[int] = None,
logprobs: Optional[int] = None,
top_logprobs: Optional[int] = None,
base_url: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
api_key: Optional[str] = None,
callbacks: List[Any] | None = None,
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
timeout: float | int | None = None,
temperature: float | None = None,
top_p: float | None = None,
n: int | None = None,
stop: str | list[str] | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
presence_penalty: float | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[int, float] | None = None,
response_format: type[BaseModel] | None = None,
seed: int | None = None,
logprobs: int | None = None,
top_logprobs: int | None = None,
base_url: str | None = None,
api_base: str | None = None,
api_version: str | None = None,
api_key: str | None = None,
callbacks: list[Any] | None = None,
reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
stream: bool = False,
**kwargs,
):
@@ -345,7 +330,7 @@ class LLM(BaseLLM):
# Normalize self.stop to always be a List[str]
if stop is None:
self.stop: List[str] = []
self.stop: list[str] = []
elif isinstance(stop, str):
self.stop = [stop]
else:
@@ -354,7 +339,8 @@ class LLM(BaseLLM):
self.set_callbacks(callbacks or [])
self.set_env_callbacks()
def _is_anthropic_model(self, model: str) -> bool:
@staticmethod
def _is_anthropic_model(model: str) -> bool:
"""Determine if the model is from Anthropic provider.
Args:
@@ -363,21 +349,18 @@ class LLM(BaseLLM):
Returns:
bool: True if the model is from Anthropic, False otherwise.
"""
ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/")
return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES)
def _prepare_completion_params(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
) -> Dict[str, Any]:
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
) -> dict[str, Any]:
"""Prepare parameters for the completion call.
Args:
messages: Input messages for the LLM
tools: Optional list of tool schemas
callbacks: Optional list of callback functions
available_functions: Optional dict of available functions
Returns:
Dict[str, Any]: Parameters for the completion call
@@ -419,11 +402,11 @@ class LLM(BaseLLM):
def _handle_streaming_response(
self,
params: Dict[str, Any],
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
params: dict[str, Any],
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str:
"""Handle a streaming response from the LLM.
@@ -445,9 +428,8 @@ class LLM(BaseLLM):
last_chunk = None
chunk_count = 0
usage_info = None
tool_calls = None
accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs] = defaultdict(
accumulated_tool_args: defaultdict[int, AccumulatedToolArgs] = defaultdict(
AccumulatedToolArgs
)
@@ -472,16 +454,16 @@ class LLM(BaseLLM):
choices = chunk["choices"]
elif hasattr(chunk, "choices"):
# Check if choices is not a type but an actual attribute with value
if not isinstance(getattr(chunk, "choices"), type):
choices = getattr(chunk, "choices")
if not isinstance(chunk.choices, type):
choices = chunk.choices
# Try to extract usage information if available
if isinstance(chunk, dict) and "usage" in chunk:
usage_info = chunk["usage"]
elif hasattr(chunk, "usage"):
# Check if usage is not a type but an actual attribute with value
if not isinstance(getattr(chunk, "usage"), type):
usage_info = getattr(chunk, "usage")
if not isinstance(chunk.usage, type):
usage_info = chunk.usage
if choices and len(choices) > 0:
choice = choices[0]
@@ -491,7 +473,7 @@ class LLM(BaseLLM):
if isinstance(choice, dict) and "delta" in choice:
delta = choice["delta"]
elif hasattr(choice, "delta"):
delta = getattr(choice, "delta")
delta = choice.delta
# Extract content from delta
if delta:
@@ -501,7 +483,7 @@ class LLM(BaseLLM):
chunk_content = delta["content"]
# Handle object format
elif hasattr(delta, "content"):
chunk_content = getattr(delta, "content")
chunk_content = delta.content
# Handle case where content might be None or empty
if chunk_content is None and isinstance(delta, dict):
@@ -533,7 +515,9 @@ class LLM(BaseLLM):
full_response += chunk_content
# Emit the chunk event
assert hasattr(crewai_event_bus, "emit")
if not hasattr(crewai_event_bus, "emit"):
raise Exception("crewai_event_bus must have an `emit` method")
crewai_event_bus.emit(
self,
event=LLMStreamChunkEvent(
@@ -572,8 +556,8 @@ class LLM(BaseLLM):
if isinstance(last_chunk, dict) and "choices" in last_chunk:
choices = last_chunk["choices"]
elif hasattr(last_chunk, "choices"):
if not isinstance(getattr(last_chunk, "choices"), type):
choices = getattr(last_chunk, "choices")
if not isinstance(last_chunk.choices, type):
choices = last_chunk.choices
if choices and len(choices) > 0:
choice = choices[0]
@@ -583,14 +567,14 @@ class LLM(BaseLLM):
if isinstance(choice, dict) and "message" in choice:
message = choice["message"]
elif hasattr(choice, "message"):
message = getattr(choice, "message")
message = choice.message
if message:
content = None
if isinstance(message, dict) and "content" in message:
content = message["content"]
elif hasattr(message, "content"):
content = getattr(message, "content")
content = message.content
if content:
full_response = content
@@ -617,8 +601,8 @@ class LLM(BaseLLM):
if isinstance(last_chunk, dict) and "choices" in last_chunk:
choices = last_chunk["choices"]
elif hasattr(last_chunk, "choices"):
if not isinstance(getattr(last_chunk, "choices"), type):
choices = getattr(last_chunk, "choices")
if not isinstance(last_chunk.choices, type):
choices = last_chunk.choices
if choices and len(choices) > 0:
choice = choices[0]
@@ -627,13 +611,13 @@ class LLM(BaseLLM):
if isinstance(choice, dict) and "message" in choice:
message = choice["message"]
elif hasattr(choice, "message"):
message = getattr(choice, "message")
message = choice.message
if message:
if isinstance(message, dict) and "tool_calls" in message:
tool_calls = message["tool_calls"]
elif hasattr(message, "tool_calls"):
tool_calls = getattr(message, "tool_calls")
tool_calls = message.tool_calls
except Exception as e:
logging.debug(f"Error checking for tool calls: {e}")
# --- 8) If no tool calls or no available functions, return the text response directly
@@ -673,11 +657,11 @@ class LLM(BaseLLM):
# Catch context window errors from litellm and convert them to our own exception type.
# This exception is handled by CrewAgentExecutor._invoke_loop() which can then
# decide whether to summarize the content or abort based on the respect_context_window flag.
raise LLMContextLengthExceededException(str(e))
raise LLMContextLengthExceededError(str(e)) from e
except Exception as e:
logging.error(f"Error in streaming response: {str(e)}")
logging.error(f"Error in streaming response: {e!s}")
if full_response.strip():
logging.warning(f"Returning partial response despite error: {str(e)}")
logging.warning(f"Returning partial response despite error: {e!s}")
self._handle_emit_call_events(
response=full_response,
call_type=LLMCallType.LLM_CALL,
@@ -688,22 +672,25 @@ class LLM(BaseLLM):
return full_response
# Emit failed event and re-raise the exception
assert hasattr(crewai_event_bus, "emit")
if not hasattr(crewai_event_bus, "emit"):
raise AttributeError(
"crewai_event_bus must have an 'emit' method"
) from e
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(
error=str(e), from_task=from_task, from_agent=from_agent
),
)
raise Exception(f"Failed to get streaming response: {str(e)}")
raise Exception(f"Failed to get streaming response: {e!s}") from e
def _handle_streaming_tool_calls(
self,
tool_calls: List[ChatCompletionDeltaToolCall],
accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs],
available_functions: Optional[Dict[str, Any]] = None,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
tool_calls: list[ChatCompletionDeltaToolCall],
accumulated_tool_args: defaultdict[int, AccumulatedToolArgs],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> None | str:
for tool_call in tool_calls:
current_tool_accumulator = accumulated_tool_args[tool_call.index]
@@ -715,7 +702,8 @@ class LLM(BaseLLM):
current_tool_accumulator.function.arguments += (
tool_call.function.arguments
)
assert hasattr(crewai_event_bus, "emit")
if not hasattr(crewai_event_bus, "emit"):
raise AttributeError("crewai_event_bus must have an 'emit' method")
crewai_event_bus.emit(
self,
event=LLMStreamChunkEvent(
@@ -742,11 +730,11 @@ class LLM(BaseLLM):
continue
return None
@staticmethod
def _handle_streaming_callbacks(
self,
callbacks: Optional[List[Any]],
usage_info: Optional[Dict[str, Any]],
last_chunk: Optional[Any],
callbacks: list[Any] | None,
usage_info: dict[str, Any] | None,
last_chunk: Any | None,
) -> None:
"""Handle callbacks with usage info for streaming responses.
@@ -769,10 +757,8 @@ class LLM(BaseLLM):
):
usage_info = last_chunk["usage"]
elif hasattr(last_chunk, "usage"):
if not isinstance(
getattr(last_chunk, "usage"), type
):
usage_info = getattr(last_chunk, "usage")
if not isinstance(last_chunk.usage, type):
usage_info = last_chunk.usage
except Exception as e:
logging.debug(f"Error extracting usage info: {e}")
@@ -786,11 +772,11 @@ class LLM(BaseLLM):
def _handle_non_streaming_response(
self,
params: Dict[str, Any],
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
params: dict[str, Any],
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str | Any:
"""Handle a non-streaming response from the LLM.
@@ -815,7 +801,7 @@ class LLM(BaseLLM):
except ContextWindowExceededError as e:
# Convert litellm's context window error to our own exception type
# for consistent handling in the rest of the codebase
raise LLMContextLengthExceededException(str(e))
raise LLMContextLengthExceededError(str(e)) from e
# --- 2) Extract response message and content
response_message = cast(Choices, cast(ModelResponse, response).choices)[
0
@@ -847,7 +833,7 @@ class LLM(BaseLLM):
)
return text_response
# --- 6) If there is no text response, no available functions, but there are tool calls, return the tool calls
elif tool_calls and not available_functions and not text_response:
if tool_calls and not available_functions and not text_response:
return tool_calls
# --- 7) Handle tool calls if present
@@ -868,19 +854,21 @@ class LLM(BaseLLM):
def _handle_tool_call(
self,
tool_calls: List[Any],
available_functions: Optional[Dict[str, Any]] = None,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
) -> Optional[str]:
tool_calls: list[Any],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str | None:
"""Handle a tool call from the LLM.
Args:
tool_calls: List of tool calls from the LLM
available_functions: Dict of available functions
from_task: Optional Task that invoked the LLM
from_agent: Optional Agent that invoked the LLM
Returns:
Optional[str]: The result of the tool call, or None if no tool call was made
The result of the tool call, or None if no tool call was made
"""
# --- 1) Validate tool calls and available functions
if not tool_calls or not available_functions:
@@ -899,7 +887,8 @@ class LLM(BaseLLM):
fn = available_functions[function_name]
# --- 3.2) Execute function
assert hasattr(crewai_event_bus, "emit")
if not hasattr(crewai_event_bus, "emit"):
raise AttributeError("crewai_event_bus must have an 'emit' method")
started_at = datetime.now()
crewai_event_bus.emit(
self,
@@ -939,17 +928,20 @@ class LLM(BaseLLM):
function_name, lambda: None
) # Ensure fn is always a callable
logging.error(f"Error executing function '{function_name}': {e}")
assert hasattr(crewai_event_bus, "emit")
if not hasattr(crewai_event_bus, "emit"):
raise AttributeError(
"crewai_event_bus must have an 'emit' method"
) from e
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(error=f"Tool execution error: {str(e)}"),
event=LLMCallFailedEvent(error=f"Tool execution error: {e!s}"),
)
crewai_event_bus.emit(
self,
event=ToolUsageErrorEvent(
tool_name=function_name,
tool_args=function_args,
error=f"Tool execution error: {str(e)}",
error=f"Tool execution error: {e!s}",
from_task=from_task,
from_agent=from_agent,
),
@@ -958,13 +950,13 @@ class LLM(BaseLLM):
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
) -> Union[str, Any]:
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str | Any:
"""High-level LLM call method.
Args:
@@ -988,10 +980,11 @@ class LLM(BaseLLM):
Raises:
TypeError: If messages format is invalid
ValueError: If response format is not supported
LLMContextLengthExceededException: If input exceeds model's context limit
LLMContextLengthExceededError: If input exceeds model's context limit
"""
# --- 1) Emit call started event
assert hasattr(crewai_event_bus, "emit")
if not hasattr(crewai_event_bus, "emit"):
raise AttributeError("crewai_event_bus must have an 'emit' method")
crewai_event_bus.emit(
self,
event=LLMCallStartedEvent(
@@ -1028,13 +1021,12 @@ class LLM(BaseLLM):
return self._handle_streaming_response(
params, callbacks, available_functions, from_task, from_agent
)
else:
return self._handle_non_streaming_response(
params, callbacks, available_functions, from_task, from_agent
)
return self._handle_non_streaming_response(
params, callbacks, available_functions, from_task, from_agent
)
except LLMContextLengthExceededException:
# Re-raise LLMContextLengthExceededException as it should be handled
except LLMContextLengthExceededError:
# Re-raise LLMContextLengthExceededError as it should be handled
# by the CrewAgentExecutor._invoke_loop method, which can then decide
# whether to summarize the content or abort based on the respect_context_window flag
raise
@@ -1065,7 +1057,10 @@ class LLM(BaseLLM):
from_agent=from_agent,
)
assert hasattr(crewai_event_bus, "emit")
if not hasattr(crewai_event_bus, "emit"):
raise AttributeError(
"crewai_event_bus must have an 'emit' method"
) from e
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(
@@ -1078,8 +1073,8 @@ class LLM(BaseLLM):
self,
response: Any,
call_type: LLMCallType,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
from_task: Any | None = None,
from_agent: Any | None = None,
messages: str | list[dict[str, Any]] | None = None,
):
"""Handle the events for the LLM call.
@@ -1091,7 +1086,8 @@ class LLM(BaseLLM):
from_agent: Optional agent object
messages: Optional messages object
"""
assert hasattr(crewai_event_bus, "emit")
if not hasattr(crewai_event_bus, "emit"):
raise AttributeError("crewai_event_bus must have an 'emit' method")
crewai_event_bus.emit(
self,
event=LLMCallCompletedEvent(
@@ -1105,8 +1101,8 @@ class LLM(BaseLLM):
)
def _format_messages_for_provider(
self, messages: List[Dict[str, str]]
) -> List[Dict[str, str]]:
self, messages: list[dict[str, str]]
) -> list[dict[str, str]]:
"""Format messages according to provider requirements.
Args:
@@ -1147,7 +1143,7 @@ class LLM(BaseLLM):
if "mistral" in self.model.lower():
# Check if the last message has a role of 'assistant'
if messages and messages[-1]["role"] == "assistant":
return messages + [{"role": "user", "content": "Please continue."}]
return [*messages, {"role": "user", "content": "Please continue."}]
return messages
# TODO: Remove this code after merging PR https://github.com/BerriAI/litellm/pull/10917
@@ -1157,7 +1153,7 @@ class LLM(BaseLLM):
and messages
and messages[-1]["role"] == "assistant"
):
return messages + [{"role": "user", "content": ""}]
return [*messages, {"role": "user", "content": ""}]
# Handle Anthropic models
if not self.is_anthropic:
@@ -1170,7 +1166,7 @@ class LLM(BaseLLM):
return messages
def _get_custom_llm_provider(self) -> Optional[str]:
def _get_custom_llm_provider(self) -> str | None:
"""
Derives the custom_llm_provider from the model string.
- For example, if the model is "openrouter/deepseek/deepseek-chat", returns "openrouter".
@@ -1207,7 +1203,7 @@ class LLM(BaseLLM):
self.model, custom_llm_provider=provider
)
except Exception as e:
logging.error(f"Failed to check function calling support: {str(e)}")
logging.error(f"Failed to check function calling support: {e!s}")
return False
def supports_stop_words(self) -> bool:
@@ -1215,7 +1211,7 @@ class LLM(BaseLLM):
params = get_supported_openai_params(model=self.model)
return params is not None and "stop" in params
except Exception as e:
logging.error(f"Failed to get supported params: {str(e)}")
logging.error(f"Failed to get supported params: {e!s}")
return False
def get_context_window_size(self) -> int:
@@ -1229,9 +1225,6 @@ class LLM(BaseLLM):
if self.context_window_size != 0:
return self.context_window_size
MIN_CONTEXT = 1024
MAX_CONTEXT = 2097152 # Current max from gemini-1.5-pro
# Validate all context window sizes
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
if value < MIN_CONTEXT or value > MAX_CONTEXT:
@@ -1247,7 +1240,8 @@ class LLM(BaseLLM):
self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
return self.context_window_size
def set_callbacks(self, callbacks: List[Any]):
@staticmethod
def set_callbacks(callbacks: list[Any]):
"""
Attempt to keep a single set of callbacks in litellm by removing old
duplicates and adding new ones.
@@ -1264,9 +1258,9 @@ class LLM(BaseLLM):
litellm.callbacks = callbacks
def set_env_callbacks(self):
"""
Sets the success and failure callbacks for the LiteLLM library from environment variables.
@staticmethod
def set_env_callbacks() -> None:
"""Sets the success and failure callbacks for the LiteLLM library from environment variables.
This method reads the `LITELLM_SUCCESS_CALLBACKS` and `LITELLM_FAILURE_CALLBACKS`
environment variables, which should contain comma-separated lists of callback names.
@@ -1276,7 +1270,7 @@ class LLM(BaseLLM):
If the environment variables are not set or are empty, the corresponding callback lists
will be set to empty lists.
Example:
Examples:
LITELLM_SUCCESS_CALLBACKS="langfuse,langsmith"
LITELLM_FAILURE_CALLBACKS="langfuse"
@@ -1285,16 +1279,15 @@ class LLM(BaseLLM):
"""
with suppress_warnings():
success_callbacks_str = os.environ.get("LITELLM_SUCCESS_CALLBACKS", "")
success_callbacks = []
success_callbacks: list[str | Callable[..., Any] | CustomLogger] = []
if success_callbacks_str:
success_callbacks = [
cb.strip() for cb in success_callbacks_str.split(",") if cb.strip()
]
failure_callbacks_str = os.environ.get("LITELLM_FAILURE_CALLBACKS", "")
failure_callbacks = []
if failure_callbacks_str:
failure_callbacks = [
failure_callbacks: list[str | Callable[..., Any] | CustomLogger] = [
cb.strip() for cb in failure_callbacks_str.split(",") if cb.strip()
]

View File

@@ -7,7 +7,8 @@ from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
from crewai.rag.config.utils import get_rag_client
from crewai.rag.core.base_client import BaseClient
from crewai.rag.embeddings.factory import get_embedding_function
from crewai.rag.embeddings.factory import EmbedderConfig, get_embedding_function
from crewai.rag.embeddings.types import EmbeddingOptions
from crewai.rag.factory import create_client
from crewai.rag.storage.base_rag_storage import BaseRAGStorage
from crewai.rag.types import BaseRecord
@@ -25,7 +26,7 @@ class RAGStorage(BaseRAGStorage):
self,
type: str,
allow_reset: bool = True,
embedder_config: dict[str, Any] | None = None,
embedder_config: EmbeddingOptions | EmbedderConfig | None = None,
crew: Any = None,
path: str | None = None,
) -> None:
@@ -50,11 +51,43 @@ class RAGStorage(BaseRAGStorage):
if self.embedder_config:
embedding_function = get_embedding_function(self.embedder_config)
config = ChromaDBConfig(
embedding_function=cast(
ChromaEmbeddingFunctionWrapper, embedding_function
try:
_ = embedding_function(["test"])
except Exception as e:
provider = (
self.embedder_config.provider
if isinstance(self.embedder_config, EmbeddingOptions)
else self.embedder_config.get("provider", "unknown")
)
raise ValueError(
f"Failed to initialize embedder. Please check your configuration or connection.\n"
f"Provider: {provider}\n"
f"Error: {e}"
) from e
batch_size = None
if (
isinstance(self.embedder_config, dict)
and "config" in self.embedder_config
):
nested_config = self.embedder_config["config"]
if isinstance(nested_config, dict):
batch_size = nested_config.get("batch_size")
if batch_size is not None:
config = ChromaDBConfig(
embedding_function=cast(
ChromaEmbeddingFunctionWrapper, embedding_function
),
batch_size=batch_size,
)
else:
config = ChromaDBConfig(
embedding_function=cast(
ChromaEmbeddingFunctionWrapper, embedding_function
)
)
)
self._client = create_client(config)
def _get_client(self) -> BaseClient:
@@ -95,7 +128,26 @@ class RAGStorage(BaseRAGStorage):
if metadata:
document["metadata"] = metadata
client.add_documents(collection_name=collection_name, documents=[document])
batch_size = None
if (
self.embedder_config
and isinstance(self.embedder_config, dict)
and "config" in self.embedder_config
):
nested_config = self.embedder_config["config"]
if isinstance(nested_config, dict):
batch_size = nested_config.get("batch_size")
if batch_size is not None:
client.add_documents(
collection_name=collection_name,
documents=[document],
batch_size=batch_size,
)
else:
client.add_documents(
collection_name=collection_name, documents=[document]
)
except Exception as e:
logging.error(
f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}"

View File

@@ -1,17 +1,17 @@
"""RAG (Retrieval-Augmented Generation) infrastructure for CrewAI."""
import sys
import importlib
import sys
from types import ModuleType
from typing import Any
from crewai.rag.config.types import RagConfigType
from crewai.rag.config.utils import set_rag_config
_module_path = __path__
_module_file = __file__
class _RagModule(ModuleType):
"""Module wrapper to intercept attribute setting for config."""
@@ -51,8 +51,10 @@ class _RagModule(ModuleType):
"""
try:
return importlib.import_module(f"{self.__name__}.{name}")
except ImportError:
raise AttributeError(f"module '{self.__name__}' has no attribute '{name}'")
except ImportError as e:
raise AttributeError(
f"module '{self.__name__}' has no attribute '{name}'"
) from e
sys.modules[__name__] = _RagModule(__name__)

View File

@@ -17,6 +17,7 @@ from crewai.rag.chromadb.types import (
ChromaDBCollectionSearchParams,
)
from crewai.rag.chromadb.utils import (
_create_batch_slice,
_extract_search_params,
_is_async_client,
_is_sync_client,
@@ -52,6 +53,7 @@ class ChromaDBClient(BaseClient):
embedding_function: ChromaEmbeddingFunction,
default_limit: int = 5,
default_score_threshold: float = 0.6,
default_batch_size: int = 100,
) -> None:
"""Initialize ChromaDBClient with client and embedding function.
@@ -60,11 +62,13 @@ class ChromaDBClient(BaseClient):
embedding_function: Embedding function for text to vector conversion.
default_limit: Default number of results to return in searches.
default_score_threshold: Default minimum score for search results.
default_batch_size: Default batch size for adding documents.
"""
self.client = client
self.embedding_function = embedding_function
self.default_limit = default_limit
self.default_score_threshold = default_score_threshold
self.default_batch_size = default_batch_size
def create_collection(
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
@@ -291,6 +295,7 @@ class ChromaDBClient(BaseClient):
- content: The text content (required)
- doc_id: Optional unique identifier (auto-generated if missing)
- metadata: Optional metadata dictionary
batch_size: Optional batch size for processing documents (default: 100)
Raises:
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
@@ -305,6 +310,7 @@ class ChromaDBClient(BaseClient):
collection_name = kwargs["collection_name"]
documents = kwargs["documents"]
batch_size = kwargs.get("batch_size", self.default_batch_size)
if not documents:
raise ValueError("Documents list cannot be empty")
@@ -315,13 +321,17 @@ class ChromaDBClient(BaseClient):
)
prepared = _prepare_documents_for_chromadb(documents)
# ChromaDB doesn't accept empty metadata dicts, so pass None if all are empty
metadatas = prepared.metadatas if any(m for m in prepared.metadatas) else None
collection.upsert(
ids=prepared.ids,
documents=prepared.texts,
metadatas=metadatas,
)
for i in range(0, len(prepared.ids), batch_size):
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
prepared=prepared, start_index=i, batch_size=batch_size
)
collection.upsert(
ids=batch_ids,
documents=batch_texts,
metadatas=batch_metadatas,
)
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
"""Add documents with their embeddings to a collection asynchronously.
@@ -335,6 +345,7 @@ class ChromaDBClient(BaseClient):
- content: The text content (required)
- doc_id: Optional unique identifier (auto-generated if missing)
- metadata: Optional metadata dictionary
batch_size: Optional batch size for processing documents (default: 100)
Raises:
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
@@ -349,6 +360,7 @@ class ChromaDBClient(BaseClient):
collection_name = kwargs["collection_name"]
documents = kwargs["documents"]
batch_size = kwargs.get("batch_size", self.default_batch_size)
if not documents:
raise ValueError("Documents list cannot be empty")
@@ -358,13 +370,17 @@ class ChromaDBClient(BaseClient):
embedding_function=self.embedding_function,
)
prepared = _prepare_documents_for_chromadb(documents)
# ChromaDB doesn't accept empty metadata dicts, so pass None if all are empty
metadatas = prepared.metadatas if any(m for m in prepared.metadatas) else None
await collection.upsert(
ids=prepared.ids,
documents=prepared.texts,
metadatas=metadatas,
)
for i in range(0, len(prepared.ids), batch_size):
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
prepared=prepared, start_index=i, batch_size=batch_size
)
await collection.upsert(
ids=batch_ids,
documents=batch_texts,
metadatas=batch_metadatas,
)
def search(
self, **kwargs: Unpack[ChromaDBCollectionSearchParams]

View File

@@ -41,4 +41,5 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
embedding_function=config.embedding_function,
default_limit=config.limit,
default_score_threshold=config.score_threshold,
default_batch_size=config.batch_size,
)

View File

@@ -1,6 +1,7 @@
"""Utility functions for ChromaDB client implementation."""
import hashlib
import json
from collections.abc import Mapping
from typing import Literal, TypeGuard, cast
@@ -72,7 +73,15 @@ def _prepare_documents_for_chromadb(
if "doc_id" in doc:
ids.append(doc["doc_id"])
else:
content_hash = hashlib.sha256(doc["content"].encode()).hexdigest()[:16]
content_for_hash = doc["content"]
metadata = doc.get("metadata")
if metadata:
metadata_str = json.dumps(metadata, sort_keys=True)
content_for_hash = f"{content_for_hash}|{metadata_str}"
content_hash = hashlib.blake2b(
content_for_hash.encode(), digest_size=32
).hexdigest()
ids.append(content_hash)
texts.append(doc["content"])
@@ -88,6 +97,32 @@ def _prepare_documents_for_chromadb(
return PreparedDocuments(ids, texts, metadatas)
def _create_batch_slice(
prepared: PreparedDocuments, start_index: int, batch_size: int
) -> tuple[list[str], list[str], list[Mapping[str, str | int | float | bool]] | None]:
"""Create a batch slice from prepared documents.
Args:
prepared: PreparedDocuments containing ids, texts, and metadatas.
start_index: Starting index for the batch.
batch_size: Size of the batch.
Returns:
Tuple of (batch_ids, batch_texts, batch_metadatas).
"""
batch_end = min(start_index + batch_size, len(prepared.ids))
batch_ids = prepared.ids[start_index:batch_end]
batch_texts = prepared.texts[start_index:batch_end]
batch_metadatas = (
prepared.metadatas[start_index:batch_end] if prepared.metadatas else None
)
if batch_metadatas and not any(m for m in batch_metadatas):
batch_metadatas = None
return batch_ids, batch_texts, batch_metadatas
def _extract_search_params(
kwargs: ChromaDBCollectionSearchParams,
) -> ExtractedSearchParams:

View File

@@ -16,3 +16,4 @@ class BaseRagConfig:
embedding_function: Any | None = field(default=None)
limit: int = field(default=5)
score_threshold: float = field(default=0.6)
batch_size: int = field(default=100)

View File

@@ -1 +1 @@
"""Optional imports for RAG configuration providers."""
"""Optional imports for RAG configuration providers."""

View File

@@ -1,7 +1,7 @@
"""Base classes for missing provider configurations."""
from typing import Literal
from dataclasses import field
from typing import Literal
from pydantic import ConfigDict
from pydantic.dataclasses import dataclass as pyd_dataclass

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
from typing import Protocol, TYPE_CHECKING
from typing import TYPE_CHECKING, Protocol
if TYPE_CHECKING:
from crewai.rag.chromadb.client import ChromaDBClient

View File

@@ -1,7 +1,8 @@
"""Provider-specific missing configuration classes."""
from typing import Literal
from dataclasses import field
from typing import Literal
from pydantic import ConfigDict
from pydantic.dataclasses import dataclass as pyd_dataclass

View File

@@ -1,6 +1,7 @@
"""Type definitions for RAG configuration."""
from typing import Annotated, TypeAlias, TYPE_CHECKING
from typing import TYPE_CHECKING, Annotated, TypeAlias
from pydantic import Field
from crewai.rag.config.constants import DISCRIMINATOR

View File

@@ -4,14 +4,14 @@ from contextvars import ContextVar
from pydantic import BaseModel, Field
from crewai.utilities.import_utils import require
from crewai.rag.core.base_client import BaseClient
from crewai.rag.config.types import RagConfigType
from crewai.rag.config.constants import (
DEFAULT_RAG_CONFIG_PATH,
DEFAULT_RAG_CONFIG_CLASS,
DEFAULT_RAG_CONFIG_PATH,
)
from crewai.rag.config.types import RagConfigType
from crewai.rag.core.base_client import BaseClient
from crewai.rag.factory import create_client
from crewai.utilities.import_utils import require
class RagContext(BaseModel):

View File

@@ -29,7 +29,7 @@ class BaseCollectionParams(TypedDict):
]
class BaseCollectionAddParams(BaseCollectionParams):
class BaseCollectionAddParams(BaseCollectionParams, total=False):
"""Parameters for adding documents to a collection.
Extends BaseCollectionParams with document-specific fields.
@@ -37,9 +37,11 @@ class BaseCollectionAddParams(BaseCollectionParams):
Attributes:
collection_name: The name of the collection to add documents to.
documents: List of BaseRecord dictionaries containing document data.
batch_size: Optional batch size for processing documents to avoid token limits.
"""
documents: list[BaseRecord]
documents: Required[list[BaseRecord]]
batch_size: int
class BaseCollectionSearchParams(BaseCollectionParams, total=False):

View File

@@ -1 +1 @@
"""Embedding components for RAG infrastructure."""
"""Embedding components for RAG infrastructure."""

View File

@@ -1,5 +1,5 @@
import os
from typing import Any, Dict, Optional, cast
from typing import Any, cast
from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.api.types import validate_embedding_function
@@ -23,7 +23,7 @@ class EmbeddingConfigurator:
def configure_embedder(
self,
embedder_config: Optional[Dict[str, Any]] = None,
embedder_config: dict[str, Any] | None = None,
) -> EmbeddingFunction:
"""Configures and returns an embedding function based on the provided config."""
if embedder_config is None:
@@ -42,9 +42,9 @@ class EmbeddingConfigurator:
embedding_function = self.embedding_functions[provider]
except ImportError as e:
missing_package = str(e).split()[-1]
raise ImportError(
raise ImportError(
f"{missing_package} is not installed. Please install it with: pip install {missing_package}"
)
) from e
return (
embedding_function(config)
@@ -147,7 +147,7 @@ class EmbeddingConfigurator:
@staticmethod
def _configure_voyageai(config, model_name):
from chromadb.utils.embedding_functions.voyageai_embedding_function import (
from chromadb.utils.embedding_functions.voyageai_embedding_function import ( # type: ignore[import-not-found]
VoyageAIEmbeddingFunction,
)
@@ -181,9 +181,11 @@ class EmbeddingConfigurator:
@staticmethod
def _configure_watson(config, model_name):
try:
import ibm_watsonx_ai.foundation_models as watson_models
from ibm_watsonx_ai import Credentials
from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames as EmbedParams
import ibm_watsonx_ai.foundation_models as watson_models # type: ignore[import-not-found]
from ibm_watsonx_ai import Credentials # type: ignore[import-not-found]
from ibm_watsonx_ai.metanames import ( # type: ignore[import-not-found]
EmbedTextParamsMetaNames as EmbedParams,
)
except ImportError as e:
raise ImportError(
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
@@ -225,7 +227,7 @@ class EmbeddingConfigurator:
validate_embedding_function(custom_embedder)
return custom_embedder
except Exception as e:
raise ValueError(f"Invalid custom embedding function: {str(e)}")
raise ValueError(f"Invalid custom embedding function: {e!s}") from e
elif callable(custom_embedder):
try:
instance = custom_embedder()
@@ -236,7 +238,7 @@ class EmbeddingConfigurator:
"Custom embedder does not create an EmbeddingFunction instance"
)
except Exception as e:
raise ValueError(f"Error instantiating custom embedder: {str(e)}")
raise ValueError(f"Error instantiating custom embedder: {e!s}") from e
else:
raise ValueError(
"Custom embedder must be an instance of `EmbeddingFunction` or a callable that creates one"

View File

@@ -1,6 +1,8 @@
"""Minimal embedding function factory for CrewAI."""
import os
from collections.abc import Callable, MutableMapping
from typing import Any, Final, Literal, TypedDict
from chromadb import EmbeddingFunction
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
@@ -42,19 +44,116 @@ from chromadb.utils.embedding_functions.sentence_transformer_embedding_function
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
Text2VecEmbeddingFunction,
)
from typing_extensions import NotRequired
from crewai.rag.embeddings.types import EmbeddingOptions
AllowedEmbeddingProviders = Literal[
"openai",
"cohere",
"ollama",
"huggingface",
"sentence-transformer",
"instructor",
"google-palm",
"google-generativeai",
"google-vertex",
"amazon-bedrock",
"jina",
"roboflow",
"openclip",
"text2vec",
"onnx",
]
class EmbedderConfig(TypedDict):
"""Configuration for embedding functions with nested format."""
provider: AllowedEmbeddingProviders
config: NotRequired[dict[str, Any]]
EMBEDDING_PROVIDERS: Final[
dict[AllowedEmbeddingProviders, Callable[..., EmbeddingFunction]]
] = {
"openai": OpenAIEmbeddingFunction,
"cohere": CohereEmbeddingFunction,
"ollama": OllamaEmbeddingFunction,
"huggingface": HuggingFaceEmbeddingFunction,
"sentence-transformer": SentenceTransformerEmbeddingFunction,
"instructor": InstructorEmbeddingFunction,
"google-palm": GooglePalmEmbeddingFunction,
"google-generativeai": GoogleGenerativeAiEmbeddingFunction,
"google-vertex": GoogleVertexEmbeddingFunction,
"amazon-bedrock": AmazonBedrockEmbeddingFunction,
"jina": JinaEmbeddingFunction,
"roboflow": RoboflowEmbeddingFunction,
"openclip": OpenCLIPEmbeddingFunction,
"text2vec": Text2VecEmbeddingFunction,
"onnx": ONNXMiniLM_L6_V2,
}
PROVIDER_ENV_MAPPING: Final[dict[AllowedEmbeddingProviders, tuple[str, str]]] = {
"openai": ("OPENAI_API_KEY", "api_key"),
"cohere": ("COHERE_API_KEY", "api_key"),
"huggingface": ("HUGGINGFACE_API_KEY", "api_key"),
"google-palm": ("GOOGLE_API_KEY", "api_key"),
"google-generativeai": ("GOOGLE_API_KEY", "api_key"),
"google-vertex": ("GOOGLE_API_KEY", "api_key"),
"jina": ("JINA_API_KEY", "api_key"),
"roboflow": ("ROBOFLOW_API_KEY", "api_key"),
}
def _inject_api_key_from_env(
provider: AllowedEmbeddingProviders, config_dict: MutableMapping[str, Any]
) -> None:
"""Inject API key or other required configuration from environment if not explicitly provided.
Args:
provider: The embedding provider name
config_dict: The configuration dictionary to modify in-place
Raises:
ImportError: If required libraries for certain providers are not installed
ValueError: If AWS session creation fails for amazon-bedrock
"""
if provider in PROVIDER_ENV_MAPPING:
env_var_name, config_key = PROVIDER_ENV_MAPPING[provider]
if config_key not in config_dict:
env_value = os.getenv(env_var_name)
if env_value:
config_dict[config_key] = env_value
if provider == "amazon-bedrock":
if "session" not in config_dict:
try:
import boto3 # type: ignore[import]
config_dict["session"] = boto3.Session()
except ImportError as e:
raise ImportError(
"boto3 is required for amazon-bedrock embeddings. "
"Install it with: uv add boto3"
) from e
except Exception as e:
raise ValueError(
f"Failed to create AWS session for amazon-bedrock. "
f"Ensure AWS credentials are configured. Error: {e}"
) from e
def get_embedding_function(
config: EmbeddingOptions | dict | None = None,
config: EmbeddingOptions | EmbedderConfig | None = None,
) -> EmbeddingFunction:
"""Get embedding function - delegates to ChromaDB.
Args:
config: Optional configuration - either an EmbeddingOptions object or a dict with:
- provider: The embedding provider to use (default: "openai")
- Any other provider-specific parameters
config: Optional configuration - either:
- EmbeddingOptions: Pydantic model with flat configuration
- EmbedderConfig: TypedDict with nested format {"provider": str, "config": dict}
- None: Uses default OpenAI configuration
Returns:
EmbeddingFunction instance ready for use with ChromaDB
@@ -81,31 +180,33 @@ def get_embedding_function(
>>> embedder = get_embedding_function()
# Use Cohere with dict
>>> embedder = get_embedding_function({
>>> embedder = get_embedding_function(EmbedderConfig(**{
... "provider": "cohere",
... "api_key": "your-key",
... "model_name": "embed-english-v3.0"
... })
... "config": {
... "api_key": "your-key",
... "model_name": "embed-english-v3.0"
... }
... }))
# Use with EmbeddingOptions
>>> embedder = get_embedding_function(
... EmbeddingOptions(provider="sentence-transformer", model_name="all-MiniLM-L6-v2")
... )
# Use local sentence transformers (no API key needed)
>>> embedder = get_embedding_function({
... "provider": "sentence-transformer",
... "model_name": "all-MiniLM-L6-v2"
# Use Azure OpenAI
>>> embedder = get_embedding_function(EmbedderConfig(**{
... "provider": "openai",
... "config": {
... "api_key": "your-azure-key",
... "api_base": "https://your-resource.openai.azure.com/",
... "api_type": "azure",
... "api_version": "2023-05-15",
... "model": "text-embedding-3-small",
... "deployment_id": "your-deployment-name"
... }
... })
# Use Ollama for local embeddings
>>> embedder = get_embedding_function({
... "provider": "ollama",
... "model_name": "nomic-embed-text"
... })
# Use ONNX (no API key needed)
>>> embedder = get_embedding_function({
>>> embedder = get_embedding_function(EmbedderConfig(**{
... "provider": "onnx"
... })
"""
@@ -114,35 +215,35 @@ def get_embedding_function(
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
)
# Handle EmbeddingOptions object
provider: AllowedEmbeddingProviders
config_dict: dict[str, Any]
if isinstance(config, EmbeddingOptions):
config_dict = config.model_dump(exclude_none=True)
provider = config_dict["provider"]
else:
config_dict = config.copy()
provider = config["provider"]
nested: dict[str, Any] = config.get("config", {})
provider = config_dict.pop("provider", "openai")
if not nested and len(config) > 1:
raise ValueError(
"Invalid embedder configuration format. "
"Configuration must be nested under a 'config' key. "
"Example: {'provider': 'openai', 'config': {'api_key': '...', 'model': '...'}}"
)
embedding_functions = {
"openai": OpenAIEmbeddingFunction,
"cohere": CohereEmbeddingFunction,
"ollama": OllamaEmbeddingFunction,
"huggingface": HuggingFaceEmbeddingFunction,
"sentence-transformer": SentenceTransformerEmbeddingFunction,
"instructor": InstructorEmbeddingFunction,
"google-palm": GooglePalmEmbeddingFunction,
"google-generativeai": GoogleGenerativeAiEmbeddingFunction,
"google-vertex": GoogleVertexEmbeddingFunction,
"amazon-bedrock": AmazonBedrockEmbeddingFunction,
"jina": JinaEmbeddingFunction,
"roboflow": RoboflowEmbeddingFunction,
"openclip": OpenCLIPEmbeddingFunction,
"text2vec": Text2VecEmbeddingFunction,
"onnx": ONNXMiniLM_L6_V2,
}
config_dict = dict(nested)
if "model" in config_dict and "model_name" not in config_dict:
config_dict["model_name"] = config_dict.pop("model")
if provider not in embedding_functions:
if provider not in EMBEDDING_PROVIDERS:
raise ValueError(
f"Unsupported provider: {provider}. "
f"Available providers: {list(embedding_functions.keys())}"
f"Available providers: {list(EMBEDDING_PROVIDERS.keys())}"
)
return embedding_functions[provider](**config_dict)
_inject_api_key_from_env(provider, config_dict)
config_dict.pop("batch_size", None)
return EMBEDDING_PROVIDERS[provider](**config_dict)

View File

@@ -1,11 +1,11 @@
"""Type definitions for the embeddings module."""
from typing import Literal
from pydantic import BaseModel, Field, SecretStr
from crewai.rag.types import EmbeddingFunction
EmbeddingProvider = Literal[
"openai",
"cohere",

View File

@@ -6,8 +6,8 @@ from crewai.rag.config.optional_imports.protocols import (
ChromaFactoryModule,
QdrantFactoryModule,
)
from crewai.rag.core.base_client import BaseClient
from crewai.rag.config.types import RagConfigType
from crewai.rag.core.base_client import BaseClient
from crewai.utilities.import_utils import require
@@ -43,3 +43,5 @@ def create_client(config: RagConfigType) -> BaseClient:
),
)
return qdrant_mod.create_client(config)
raise ValueError(f"Unsupported provider: {config.provider}")

View File

@@ -1 +1 @@
"""Qdrant vector database client implementation."""
"""Qdrant vector database client implementation."""

View File

@@ -48,6 +48,7 @@ class QdrantClient(BaseClient):
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction,
default_limit: int = 5,
default_score_threshold: float = 0.6,
default_batch_size: int = 100,
) -> None:
"""Initialize QdrantClient with client and embedding function.
@@ -56,11 +57,13 @@ class QdrantClient(BaseClient):
embedding_function: Embedding function for text to vector conversion.
default_limit: Default number of results to return in searches.
default_score_threshold: Default minimum score for search results.
default_batch_size: Default batch size for adding documents.
"""
self.client = client
self.embedding_function = embedding_function
self.default_limit = default_limit
self.default_score_threshold = default_score_threshold
self.default_batch_size = default_batch_size
def create_collection(self, **kwargs: Unpack[QdrantCollectionCreateParams]) -> None:
"""Create a new collection in Qdrant.
@@ -234,6 +237,7 @@ class QdrantClient(BaseClient):
Keyword Args:
collection_name: The name of the collection to add documents to.
documents: List of BaseRecord dicts containing document data.
batch_size: Optional batch size for processing documents (default: 100)
Raises:
ValueError: If collection doesn't exist or documents list is empty.
@@ -249,6 +253,7 @@ class QdrantClient(BaseClient):
collection_name = kwargs["collection_name"]
documents = kwargs["documents"]
batch_size = kwargs.get("batch_size", self.default_batch_size)
if not documents:
raise ValueError("Documents list cannot be empty")
@@ -256,19 +261,20 @@ class QdrantClient(BaseClient):
if not self.client.collection_exists(collection_name):
raise ValueError(f"Collection '{collection_name}' does not exist")
points = []
for doc in documents:
if _is_async_embedding_function(self.embedding_function):
raise TypeError(
"Async embedding function cannot be used with sync add_documents. "
"Use aadd_documents instead."
)
sync_fn = cast(EmbeddingFunction, self.embedding_function)
embedding = sync_fn(doc["content"])
point = _create_point_from_document(doc, embedding)
points.append(point)
self.client.upsert(collection_name=collection_name, points=points)
for i in range(0, len(documents), batch_size):
batch_docs = documents[i : min(i + batch_size, len(documents))]
points = []
for doc in batch_docs:
if _is_async_embedding_function(self.embedding_function):
raise TypeError(
"Async embedding function cannot be used with sync add_documents. "
"Use aadd_documents instead."
)
sync_fn = cast(EmbeddingFunction, self.embedding_function)
embedding = sync_fn(doc["content"])
point = _create_point_from_document(doc, embedding)
points.append(point)
self.client.upsert(collection_name=collection_name, points=points)
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
"""Add documents with their embeddings to a collection asynchronously.
@@ -276,6 +282,7 @@ class QdrantClient(BaseClient):
Keyword Args:
collection_name: The name of the collection to add documents to.
documents: List of BaseRecord dicts containing document data.
batch_size: Optional batch size for processing documents (default: 100)
Raises:
ValueError: If collection doesn't exist or documents list is empty.
@@ -291,6 +298,7 @@ class QdrantClient(BaseClient):
collection_name = kwargs["collection_name"]
documents = kwargs["documents"]
batch_size = kwargs.get("batch_size", self.default_batch_size)
if not documents:
raise ValueError("Documents list cannot be empty")
@@ -298,18 +306,19 @@ class QdrantClient(BaseClient):
if not await self.client.collection_exists(collection_name):
raise ValueError(f"Collection '{collection_name}' does not exist")
points = []
for doc in documents:
if _is_async_embedding_function(self.embedding_function):
async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
embedding = await async_fn(doc["content"])
else:
sync_fn = cast(EmbeddingFunction, self.embedding_function)
embedding = sync_fn(doc["content"])
point = _create_point_from_document(doc, embedding)
points.append(point)
await self.client.upsert(collection_name=collection_name, points=points)
for i in range(0, len(documents), batch_size):
batch_docs = documents[i : min(i + batch_size, len(documents))]
points = []
for doc in batch_docs:
if _is_async_embedding_function(self.embedding_function):
async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
embedding = await async_fn(doc["content"])
else:
sync_fn = cast(EmbeddingFunction, self.embedding_function)
embedding = sync_fn(doc["content"])
point = _create_point_from_document(doc, embedding)
points.append(point)
await self.client.upsert(collection_name=collection_name, points=points)
def search(
self, **kwargs: Unpack[BaseCollectionSearchParams]

View File

@@ -2,11 +2,12 @@
from dataclasses import field
from typing import Literal, cast
from pydantic.dataclasses import dataclass as pyd_dataclass
from crewai.rag.config.base import BaseRagConfig
from crewai.rag.qdrant.types import QdrantClientParams, QdrantEmbeddingFunctionWrapper
from crewai.rag.qdrant.constants import DEFAULT_EMBEDDING_MODEL, DEFAULT_STORAGE_PATH
from crewai.rag.qdrant.types import QdrantClientParams, QdrantEmbeddingFunctionWrapper
def _default_options() -> QdrantClientParams:
@@ -24,7 +25,7 @@ def _default_embedding_function() -> QdrantEmbeddingFunctionWrapper:
Returns:
Default embedding function using fastembed with all-MiniLM-L6-v2.
"""
from fastembed import TextEmbedding
from fastembed import TextEmbedding # type: ignore[import-not-found]
model = TextEmbedding(model_name=DEFAULT_EMBEDDING_MODEL)

View File

@@ -22,4 +22,5 @@ def create_client(config: QdrantConfig) -> QdrantClient:
embedding_function=config.embedding_function,
default_limit=config.limit,
default_score_threshold=config.score_threshold,
default_batch_size=config.batch_size,
)

View File

@@ -2,13 +2,15 @@
from collections.abc import Awaitable, Callable
from typing import Annotated, Any, Protocol, TypeAlias
from typing_extensions import NotRequired, TypedDict
import numpy as np
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient
from qdrant_client.models import (
from qdrant_client import AsyncQdrantClient # type: ignore[import-not-found]
from qdrant_client import (
QdrantClient as SyncQdrantClient, # type: ignore[import-not-found]
)
from qdrant_client.models import ( # type: ignore[import-not-found]
FieldCondition,
Filter,
HasIdCondition,
@@ -25,6 +27,7 @@ from qdrant_client.models import (
VectorsConfig,
WalConfigDiff,
)
from typing_extensions import NotRequired, TypedDict
from crewai.rag.core.base_client import BaseCollectionParams
@@ -134,8 +137,6 @@ class QdrantCollectionCreateParams(
):
"""High-level parameters for creating a Qdrant collection."""
pass
class CreateCollectionParams(CommonCreateFields, total=False):
"""Parameters for qdrant_client.create_collection."""

View File

@@ -4,8 +4,11 @@ import asyncio
from typing import TypeGuard
from uuid import uuid4
from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient
from qdrant_client.models import (
from qdrant_client import AsyncQdrantClient # type: ignore[import-not-found]
from qdrant_client import (
QdrantClient as SyncQdrantClient, # type: ignore[import-not-found]
)
from qdrant_client.models import ( # type: ignore[import-not-found]
FieldCondition,
Filter,
MatchValue,
@@ -25,7 +28,7 @@ from crewai.rag.qdrant.types import (
QdrantCollectionCreateParams,
QueryEmbedding,
)
from crewai.rag.types import SearchResult, BaseRecord
from crewai.rag.types import BaseRecord, SearchResult
def _ensure_list_embedding(embedding: QueryEmbedding) -> list[float]:
@@ -38,7 +41,8 @@ def _ensure_list_embedding(embedding: QueryEmbedding) -> list[float]:
Embedding as list[float].
"""
if not isinstance(embedding, list):
return embedding.tolist()
result = embedding.tolist()
return result if isinstance(result, list) else [result]
return embedding

View File

@@ -1 +1 @@
"""Storage components for RAG infrastructure."""
"""Storage components for RAG infrastructure."""

View File

@@ -1,6 +1,9 @@
from abc import ABC, abstractmethod
from typing import Any
from crewai.rag.embeddings.factory import EmbedderConfig
from crewai.rag.embeddings.types import EmbeddingOptions
class BaseRAGStorage(ABC):
"""
@@ -13,7 +16,7 @@ class BaseRAGStorage(ABC):
self,
type: str,
allow_reset: bool = True,
embedder_config: dict[str, Any] | None = None,
embedder_config: EmbeddingOptions | EmbedderConfig | None = None,
crew: Any = None,
):
self.type = type

View File

@@ -1,7 +1,7 @@
"""Type definitions for RAG (Retrieval-Augmented Generation) systems."""
from collections.abc import Callable, Mapping
from typing import TypeAlias, Any
from typing import Any, TypeAlias
from typing_extensions import Required, TypedDict

View File

@@ -5,20 +5,14 @@ import logging
import threading
import uuid
import warnings
from collections.abc import Callable
from concurrent.futures import Future
from copy import copy
from hashlib import md5
from pathlib import Path
from typing import (
Any,
Callable,
ClassVar,
Dict,
List,
Optional,
Set,
Tuple,
Type,
Union,
get_args,
get_origin,
@@ -35,20 +29,20 @@ from pydantic import (
from pydantic_core import PydanticCustomError
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.events.event_bus import crewai_event_bus
from crewai.events.event_types import (
TaskCompletedEvent,
TaskFailedEvent,
TaskStartedEvent,
)
from crewai.security import Fingerprint, SecurityConfig
from crewai.tasks.output_format import OutputFormat
from crewai.tasks.task_output import TaskOutput
from crewai.tools.base_tool import BaseTool
from crewai.utilities.config import process_config
from crewai.utilities.constants import NOT_SPECIFIED, _NotSpecified
from crewai.utilities.guardrail import process_guardrail, GuardrailResult
from crewai.utilities.converter import Converter, convert_to_model
from crewai.events.event_types import (
TaskCompletedEvent,
TaskFailedEvent,
TaskStartedEvent,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.utilities.guardrail import process_guardrail
from crewai.utilities.i18n import I18N
from crewai.utilities.printer import Printer
from crewai.utilities.string_utils import interpolate_only
@@ -85,50 +79,50 @@ class Task(BaseModel):
tools_errors: int = 0
delegations: int = 0
i18n: I18N = I18N()
name: Optional[str] = Field(default=None)
prompt_context: Optional[str] = None
name: str | None = Field(default=None)
prompt_context: str | None = None
description: str = Field(description="Description of the actual task.")
expected_output: str = Field(
description="Clear definition of expected output for the task."
)
config: Optional[Dict[str, Any]] = Field(
config: dict[str, Any] | None = Field(
description="Configuration for the agent",
default=None,
)
callback: Optional[Any] = Field(
callback: Any | None = Field(
description="Callback to be executed after the task is completed.", default=None
)
agent: Optional[BaseAgent] = Field(
agent: BaseAgent | None = Field(
description="Agent responsible for execution the task.", default=None
)
context: Union[List["Task"], None, _NotSpecified] = Field(
context: list["Task"] | None | _NotSpecified = Field(
description="Other tasks that will have their output used as context for this task.",
default=NOT_SPECIFIED,
)
async_execution: Optional[bool] = Field(
async_execution: bool | None = Field(
description="Whether the task should be executed asynchronously or not.",
default=False,
)
output_json: Optional[Type[BaseModel]] = Field(
output_json: type[BaseModel] | None = Field(
description="A Pydantic model to be used to create a JSON output.",
default=None,
)
output_pydantic: Optional[Type[BaseModel]] = Field(
output_pydantic: type[BaseModel] | None = Field(
description="A Pydantic model to be used to create a Pydantic output.",
default=None,
)
output_file: Optional[str] = Field(
output_file: str | None = Field(
description="A file path to be used to create a file output.",
default=None,
)
create_directory: Optional[bool] = Field(
create_directory: bool | None = Field(
description="Whether to create the directory for output_file if it doesn't exist.",
default=True,
)
output: Optional[TaskOutput] = Field(
output: TaskOutput | None = Field(
description="Task output, it's final result after being executed", default=None
)
tools: Optional[List[BaseTool]] = Field(
tools: list[BaseTool] | None = Field(
default_factory=list,
description="Tools the agent is limited to use for this task.",
)
@@ -141,24 +135,24 @@ class Task(BaseModel):
frozen=True,
description="Unique identifier for the object, not set by user.",
)
human_input: Optional[bool] = Field(
human_input: bool | None = Field(
description="Whether the task should have a human review the final answer of the agent",
default=False,
)
markdown: Optional[bool] = Field(
markdown: bool | None = Field(
description="Whether the task should instruct the agent to return the final answer formatted in Markdown",
default=False,
)
converter_cls: Optional[Type[Converter]] = Field(
converter_cls: type[Converter] | None = Field(
description="A converter class used to export structured output",
default=None,
)
processed_by_agents: Set[str] = Field(default_factory=set)
guardrail: Optional[Union[Callable[[TaskOutput], Tuple[bool, Any]], str]] = Field(
processed_by_agents: set[str] = Field(default_factory=set)
guardrail: Callable[[TaskOutput], tuple[bool, Any]] | str | None = Field(
default=None,
description="Function or string description of a guardrail to validate task output before proceeding to next task",
)
max_retries: Optional[int] = Field(
max_retries: int | None = Field(
default=None,
description="[DEPRECATED] Maximum number of retries when guardrail fails. Use guardrail_max_retries instead. Will be removed in v1.0.0",
)
@@ -166,13 +160,13 @@ class Task(BaseModel):
default=3, description="Maximum number of retries when guardrail fails"
)
retry_count: int = Field(default=0, description="Current number of retries")
start_time: Optional[datetime.datetime] = Field(
start_time: datetime.datetime | None = Field(
default=None, description="Start time of the task execution"
)
end_time: Optional[datetime.datetime] = Field(
end_time: datetime.datetime | None = Field(
default=None, description="End time of the task execution"
)
allow_crewai_trigger_context: Optional[bool] = Field(
allow_crewai_trigger_context: bool | None = Field(
default=None,
description="Whether this task should append 'Trigger Payload: {crewai_trigger_payload}' to the task description when crewai_trigger_payload exists in crew inputs.",
)
@@ -181,8 +175,8 @@ class Task(BaseModel):
@field_validator("guardrail")
@classmethod
def validate_guardrail_function(
cls, v: Optional[str | Callable]
) -> Optional[str | Callable]:
cls, v: str | Callable | None
) -> str | Callable | None:
"""
If v is a callable, validate that the guardrail function has the correct signature and behavior.
If v is a string, return it as is.
@@ -229,7 +223,7 @@ class Task(BaseModel):
return_annotation_args[1] is Any
or return_annotation_args[1] is str
or return_annotation_args[1] is TaskOutput
or return_annotation_args[1] == Union[str, TaskOutput]
or return_annotation_args[1] == str | TaskOutput
)
):
raise ValueError(
@@ -237,11 +231,11 @@ class Task(BaseModel):
)
return v
_guardrail: Optional[Callable] = PrivateAttr(default=None)
_original_description: Optional[str] = PrivateAttr(default=None)
_original_expected_output: Optional[str] = PrivateAttr(default=None)
_original_output_file: Optional[str] = PrivateAttr(default=None)
_thread: Optional[threading.Thread] = PrivateAttr(default=None)
_guardrail: Callable | None = PrivateAttr(default=None)
_original_description: str | None = PrivateAttr(default=None)
_original_expected_output: str | None = PrivateAttr(default=None)
_original_output_file: str | None = PrivateAttr(default=None)
_thread: threading.Thread | None = PrivateAttr(default=None)
@model_validator(mode="before")
@classmethod
@@ -265,7 +259,9 @@ class Task(BaseModel):
elif isinstance(self.guardrail, str):
from crewai.tasks.llm_guardrail import LLMGuardrail
assert self.agent is not None
if self.agent is None:
raise ValueError("Agent is required to use LLMGuardrail")
self._guardrail = LLMGuardrail(
description=self.guardrail, llm=self.agent.llm
)
@@ -274,7 +270,7 @@ class Task(BaseModel):
@field_validator("id", mode="before")
@classmethod
def _deny_user_set_id(cls, v: Optional[UUID4]) -> None:
def _deny_user_set_id(cls, v: UUID4 | None) -> None:
if v:
raise PydanticCustomError(
"may_not_set_field", "This field is not to be set by the user.", {}
@@ -282,7 +278,7 @@ class Task(BaseModel):
@field_validator("output_file")
@classmethod
def output_file_validation(cls, value: Optional[str]) -> Optional[str]:
def output_file_validation(cls, value: str | None) -> str | None:
"""Validate the output file path.
Args:
@@ -307,7 +303,7 @@ class Task(BaseModel):
)
# Check for shell expansion first
if value.startswith("~") or value.startswith("$"):
if value.startswith(("~", "$")):
raise ValueError(
"Shell expansion characters are not allowed in output_file paths"
)
@@ -373,9 +369,9 @@ class Task(BaseModel):
def execute_sync(
self,
agent: Optional[BaseAgent] = None,
context: Optional[str] = None,
tools: Optional[List[BaseTool]] = None,
agent: BaseAgent | None = None,
context: str | None = None,
tools: list[BaseTool] | None = None,
) -> TaskOutput:
"""Execute the task synchronously."""
return self._execute_core(agent, context, tools)
@@ -397,8 +393,8 @@ class Task(BaseModel):
def execute_async(
self,
agent: BaseAgent | None = None,
context: Optional[str] = None,
tools: Optional[List[BaseTool]] = None,
context: str | None = None,
tools: list[BaseTool] | None = None,
) -> Future[TaskOutput]:
"""Execute the task asynchronously."""
future: Future[TaskOutput] = Future()
@@ -411,9 +407,9 @@ class Task(BaseModel):
def _execute_task_async(
self,
agent: Optional[BaseAgent],
context: Optional[str],
tools: Optional[List[Any]],
agent: BaseAgent | None,
context: str | None,
tools: list[Any] | None,
future: Future[TaskOutput],
) -> None:
"""Execute the task asynchronously with context handling."""
@@ -422,9 +418,9 @@ class Task(BaseModel):
def _execute_core(
self,
agent: Optional[BaseAgent],
context: Optional[str],
tools: Optional[List[Any]],
agent: BaseAgent | None,
context: str | None,
tools: list[Any] | None,
) -> TaskOutput:
"""Run the core execution logic of the task."""
try:
@@ -465,6 +461,7 @@ class Task(BaseModel):
output=task_output,
guardrail=self._guardrail,
retry_count=self.retry_count,
event_source=self,
)
if not guardrail_result.success:
if self.retry_count >= self.guardrail_max_retries:
@@ -528,41 +525,6 @@ class Task(BaseModel):
crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self))
raise e # Re-raise the exception after emitting the event
def _process_guardrail(self, task_output: TaskOutput) -> GuardrailResult:
assert self._guardrail is not None
from crewai.events.event_types import (
LLMGuardrailCompletedEvent,
LLMGuardrailStartedEvent,
)
from crewai.events.event_bus import crewai_event_bus
crewai_event_bus.emit(
self,
LLMGuardrailStartedEvent(
guardrail=self._guardrail, retry_count=self.retry_count
),
)
try:
result = self._guardrail(task_output)
guardrail_result = GuardrailResult.from_tuple(result)
except Exception as e:
guardrail_result = GuardrailResult(
success=False, result=None, error=f"Guardrail execution error: {str(e)}"
)
crewai_event_bus.emit(
self,
LLMGuardrailCompletedEvent(
success=guardrail_result.success,
result=guardrail_result.result,
error=guardrail_result.error,
retry_count=self.retry_count,
),
)
return guardrail_result
def prompt(self) -> str:
"""Generates the task prompt with optional markdown formatting.
@@ -604,7 +566,7 @@ Follow these guidelines:
return "\n".join(tasks_slices)
def interpolate_inputs_and_add_conversation_history(
self, inputs: Dict[str, Union[str, int, float, Dict[str, Any], List[Any]]]
self, inputs: dict[str, str | int | float | dict[str, Any] | list[Any]]
) -> None:
"""Interpolate inputs into the task description, expected output, and output file path.
Add conversation history if present.
@@ -635,14 +597,14 @@ Follow these guidelines:
f"Missing required template variable '{e.args[0]}' in description"
) from e
except ValueError as e:
raise ValueError(f"Error interpolating description: {str(e)}") from e
raise ValueError(f"Error interpolating description: {e!s}") from e
try:
self.expected_output = interpolate_only(
input_string=self._original_expected_output, inputs=inputs
)
except (KeyError, ValueError) as e:
raise ValueError(f"Error interpolating expected_output: {str(e)}") from e
raise ValueError(f"Error interpolating expected_output: {e!s}") from e
if self.output_file is not None:
try:
@@ -650,11 +612,9 @@ Follow these guidelines:
input_string=self._original_output_file, inputs=inputs
)
except (KeyError, ValueError) as e:
raise ValueError(
f"Error interpolating output_file path: {str(e)}"
) from e
raise ValueError(f"Error interpolating output_file path: {e!s}") from e
if "crew_chat_messages" in inputs and inputs["crew_chat_messages"]:
if inputs.get("crew_chat_messages"):
conversation_instruction = self.i18n.slice(
"conversation_history_instruction"
)
@@ -681,14 +641,14 @@ Follow these guidelines:
"""Increment the tools errors counter."""
self.tools_errors += 1
def increment_delegations(self, agent_name: Optional[str]) -> None:
def increment_delegations(self, agent_name: str | None) -> None:
"""Increment the delegations counter."""
if agent_name:
self.processed_by_agents.add(agent_name)
self.delegations += 1
def copy(
self, agents: List["BaseAgent"], task_mapping: Dict[str, "Task"]
def copy( # type: ignore
self, agents: list["BaseAgent"], task_mapping: dict[str, "Task"]
) -> "Task":
"""Creates a deep copy of the Task while preserving its original class type.
@@ -721,20 +681,18 @@ Follow these guidelines:
cloned_agent = get_agent_by_role(self.agent.role) if self.agent else None
cloned_tools = copy(self.tools) if self.tools else []
copied_task = self.__class__(
return self.__class__(
**copied_data,
context=cloned_context,
agent=cloned_agent,
tools=cloned_tools,
)
return copied_task
def _export_output(
self, result: str
) -> Tuple[Optional[BaseModel], Optional[Dict[str, Any]]]:
pydantic_output: Optional[BaseModel] = None
json_output: Optional[Dict[str, Any]] = None
) -> tuple[BaseModel | None, dict[str, Any] | None]:
pydantic_output: BaseModel | None = None
json_output: dict[str, Any] | None = None
if self.output_pydantic or self.output_json:
model_output = convert_to_model(
@@ -764,7 +722,7 @@ Follow these guidelines:
return OutputFormat.PYDANTIC
return OutputFormat.RAW
def _save_file(self, result: Union[Dict, str, Any]) -> None:
def _save_file(self, result: dict | str | Any) -> None:
"""Save task output to a file.
Note:
@@ -785,7 +743,7 @@ Follow these guidelines:
if self.output_file is None:
raise ValueError("output_file is not set.")
FILEWRITER_RECOMMENDATION = (
filewriter_recommendation = (
"For cross-platform file writing, especially on Windows, "
"use FileWriterTool from crewai_tools package."
)
@@ -811,10 +769,10 @@ Follow these guidelines:
except (OSError, IOError) as e:
raise RuntimeError(
"\n".join(
[f"Failed to save output file: {e}", FILEWRITER_RECOMMENDATION]
[f"Failed to save output file: {e}", filewriter_recommendation]
)
)
return None
) from e
return
def __repr__(self):
return f"Task(description={self.description}, expected_output={self.expected_output})"

View File

@@ -6,7 +6,7 @@ Classes:
HallucinationGuardrail: Placeholder guardrail that validates task outputs.
"""
from typing import Any, Optional, Tuple
from typing import Any
from crewai.llm import LLM
from crewai.tasks.task_output import TaskOutput
@@ -48,7 +48,7 @@ class HallucinationGuardrail:
self,
context: str,
llm: LLM,
threshold: Optional[float] = None,
threshold: float | None = None,
tool_response: str = "",
):
"""Initialize the HallucinationGuardrail placeholder.
@@ -75,7 +75,7 @@ class HallucinationGuardrail:
"""Generate a description of this guardrail for event logging."""
return "HallucinationGuardrail (no-op)"
def __call__(self, task_output: TaskOutput) -> Tuple[bool, Any]:
def __call__(self, task_output: TaskOutput) -> tuple[bool, Any]:
"""Validate a task output against hallucination criteria.
In the open source, this method always returns that the output is valid.

View File

@@ -1,4 +1,4 @@
from typing import Any, Tuple
from typing import Any
from pydantic import BaseModel, Field
@@ -53,7 +53,7 @@ class LLMGuardrail:
Guardrail:
{self.description}
Your task:
- Confirm if the Task result complies with the guardrail.
- If not, provide clear feedback explaining what is wrong (e.g., by how much it violates the rule, or what specific part fails).
@@ -61,11 +61,9 @@ class LLMGuardrail:
- If the Task result complies with the guardrail, saying that is valid
"""
result = agent.kickoff(query, response_format=LLMGuardrailResult)
return agent.kickoff(query, response_format=LLMGuardrailResult)
return result
def __call__(self, task_output: TaskOutput) -> Tuple[bool, Any]:
def __call__(self, task_output: TaskOutput) -> tuple[bool, Any]:
"""Validates the output of a task based on specified criteria.
Args:
@@ -79,13 +77,11 @@ class LLMGuardrail:
try:
result = self._validate_output(task_output)
assert isinstance(
result.pydantic, LLMGuardrailResult
), "The guardrail result is not a valid pydantic model"
if not isinstance(result.pydantic, LLMGuardrailResult):
raise ValueError("The guardrail result is not a valid pydantic model")
if result.pydantic.valid:
return True, task_output.raw
else:
return False, result.pydantic.feedback
return False, result.pydantic.feedback
except Exception as e:
return False, f"Error while validating the task output: {str(e)}"
return False, f"Error while validating the task output: {e!s}"

View File

@@ -1,7 +1,7 @@
from .base_tool import BaseTool, tool, EnvVar
from .base_tool import BaseTool, EnvVar, tool
__all__ = [
"BaseTool",
"tool",
"EnvVar",
]
"tool",
]

View File

@@ -1,5 +1,3 @@
from typing import Dict, Optional, Union
from pydantic import BaseModel, Field
from crewai.tools.base_tool import BaseTool
@@ -10,7 +8,7 @@ i18n = I18N()
class AddImageToolSchema(BaseModel):
image_url: str = Field(..., description="The URL or path of the image to add")
action: Optional[str] = Field(
action: str | None = Field(
default=None, description="Optional context or question about the image"
)
@@ -18,14 +16,16 @@ class AddImageToolSchema(BaseModel):
class AddImageTool(BaseTool):
"""Tool for adding images to the content"""
name: str = Field(default_factory=lambda: i18n.tools("add_image")["name"]) # type: ignore
description: str = Field(default_factory=lambda: i18n.tools("add_image")["description"]) # type: ignore
name: str = Field(default_factory=lambda: i18n.tools("add_image")["name"]) # type: ignore[index]
description: str = Field(
default_factory=lambda: i18n.tools("add_image")["description"] # type: ignore[index]
)
args_schema: type[BaseModel] = AddImageToolSchema
def _run(
self,
image_url: str,
action: Optional[str] = None,
action: str | None = None,
**kwargs,
) -> dict:
action = action or i18n.tools("add_image")["default_action"] # type: ignore

View File

@@ -9,9 +9,9 @@ from .delegate_work_tool import DelegateWorkTool
class AgentTools:
"""Manager class for agent-related tools"""
def __init__(self, agents: list[BaseAgent], i18n: I18N = I18N()):
def __init__(self, agents: list[BaseAgent], i18n: I18N | None = None):
self.agents = agents
self.i18n = i18n
self.i18n = i18n if i18n is not None else I18N()
def tools(self) -> list[BaseTool]:
"""Get all available agent tools"""

View File

@@ -1,5 +1,3 @@
from typing import Optional
from pydantic import BaseModel, Field
from crewai.tools.agent_tools.base_agent_tools import BaseAgentTool
@@ -21,7 +19,7 @@ class AskQuestionTool(BaseAgentTool):
self,
question: str,
context: str,
coworker: Optional[str] = None,
coworker: str | None = None,
**kwargs,
) -> str:
coworker = self._get_coworker(coworker, **kwargs)

View File

@@ -1,5 +1,4 @@
import logging
from typing import Optional
from pydantic import Field
@@ -38,7 +37,7 @@ class BaseAgentTool(BaseTool):
# Remove quotes and convert to lowercase
return normalized.replace('"', "").casefold()
def _get_coworker(self, coworker: Optional[str], **kwargs) -> Optional[str]:
def _get_coworker(self, coworker: str | None, **kwargs) -> str | None:
coworker = coworker or kwargs.get("co_worker") or kwargs.get("coworker")
if coworker:
is_list = coworker.startswith("[") and coworker.endswith("]")
@@ -47,10 +46,7 @@ class BaseAgentTool(BaseTool):
return coworker
def _execute(
self,
agent_name: Optional[str],
task: str,
context: Optional[str] = None
self, agent_name: str | None, task: str, context: str | None = None
) -> str:
"""
Execute delegation to an agent with case-insensitive and whitespace-tolerant matching.
@@ -77,7 +73,9 @@ class BaseAgentTool(BaseTool):
# when it should look like this:
# {"task": "....", "coworker": "...."}
sanitized_name = self.sanitize_agent_name(agent_name)
logger.debug(f"Sanitized agent name from '{agent_name}' to '{sanitized_name}'")
logger.debug(
f"Sanitized agent name from '{agent_name}' to '{sanitized_name}'"
)
available_agents = [agent.role for agent in self.agents]
logger.debug(f"Available agents: {available_agents}")
@@ -87,38 +85,47 @@ class BaseAgentTool(BaseTool):
for available_agent in self.agents
if self.sanitize_agent_name(available_agent.role) == sanitized_name
]
logger.debug(f"Found {len(agent)} matching agents for role '{sanitized_name}'")
logger.debug(
f"Found {len(agent)} matching agents for role '{sanitized_name}'"
)
except (AttributeError, ValueError) as e:
# Handle specific exceptions that might occur during role name processing
return self.i18n.errors("agent_tool_unexisting_coworker").format(
coworkers="\n".join(
[f"- {self.sanitize_agent_name(agent.role)}" for agent in self.agents]
[
f"- {self.sanitize_agent_name(agent.role)}"
for agent in self.agents
]
),
error=str(e)
error=str(e),
)
if not agent:
# No matching agent found after sanitization
return self.i18n.errors("agent_tool_unexisting_coworker").format(
coworkers="\n".join(
[f"- {self.sanitize_agent_name(agent.role)}" for agent in self.agents]
[
f"- {self.sanitize_agent_name(agent.role)}"
for agent in self.agents
]
),
error=f"No agent found with role '{sanitized_name}'"
error=f"No agent found with role '{sanitized_name}'",
)
agent = agent[0]
selected_agent = agent[0]
try:
task_with_assigned_agent = Task(
description=task,
agent=agent,
expected_output=agent.i18n.slice("manager_request"),
i18n=agent.i18n,
agent=selected_agent,
expected_output=selected_agent.i18n.slice("manager_request"),
i18n=selected_agent.i18n,
)
logger.debug(f"Created task for agent '{self.sanitize_agent_name(agent.role)}': {task}")
return agent.execute_task(task_with_assigned_agent, context)
logger.debug(
f"Created task for agent '{self.sanitize_agent_name(selected_agent.role)}': {task}"
)
return selected_agent.execute_task(task_with_assigned_agent, context)
except Exception as e:
# Handle task creation or execution errors
return self.i18n.errors("agent_tool_execution_error").format(
agent_role=self.sanitize_agent_name(agent.role),
error=str(e)
agent_role=self.sanitize_agent_name(selected_agent.role), error=str(e)
)

View File

@@ -1,5 +1,3 @@
from typing import Optional
from pydantic import BaseModel, Field
from crewai.tools.agent_tools.base_agent_tools import BaseAgentTool
@@ -23,7 +21,7 @@ class DelegateWorkTool(BaseAgentTool):
self,
task: str,
context: str,
coworker: Optional[str] = None,
coworker: str | None = None,
**kwargs,
) -> str:
coworker = self._get_coworker(coworker, **kwargs)

View File

@@ -1,7 +1,8 @@
import asyncio
from abc import ABC, abstractmethod
from collections.abc import Callable
from inspect import signature
from typing import Any, Callable, Type, get_args, get_origin, Optional, List
from typing import Any, get_args, get_origin
from pydantic import (
BaseModel,
@@ -19,7 +20,7 @@ class EnvVar(BaseModel):
name: str
description: str
required: bool = True
default: Optional[str] = None
default: str | None = None
class BaseTool(BaseModel, ABC):
@@ -32,10 +33,10 @@ class BaseTool(BaseModel, ABC):
"""The unique name of the tool that clearly communicates its purpose."""
description: str
"""Used to tell the model how/when/why to use the tool."""
env_vars: List[EnvVar] = []
env_vars: list[EnvVar] = []
"""List of environment variables used by the tool."""
args_schema: Type[PydanticBaseModel] = Field(
default_factory=_ArgsSchemaPlaceholder, validate_default=True
args_schema: type[PydanticBaseModel] = Field(
default=_ArgsSchemaPlaceholder, validate_default=True
)
"""The schema for the arguments that the tool accepts."""
description_updated: bool = False
@@ -52,9 +53,9 @@ class BaseTool(BaseModel, ABC):
@field_validator("args_schema", mode="before")
@classmethod
def _default_args_schema(
cls, v: Type[PydanticBaseModel]
) -> Type[PydanticBaseModel]:
if not isinstance(v, cls._ArgsSchemaPlaceholder):
cls, v: type[PydanticBaseModel]
) -> type[PydanticBaseModel]:
if v != cls._ArgsSchemaPlaceholder:
return v
return type(
@@ -139,7 +140,7 @@ class BaseTool(BaseModel, ABC):
# Infer args_schema from the function signature if not provided
func_signature = signature(tool.func)
annotations = func_signature.parameters
args_fields = {}
args_fields: dict[str, Any] = {}
for name, param in annotations.items():
if name != "self":
param_annotation = (
@@ -247,7 +248,7 @@ class Tool(BaseTool):
# Infer args_schema from the function signature if not provided
func_signature = signature(tool.func)
annotations = func_signature.parameters
args_fields = {}
args_fields: dict[str, Any] = {}
for name, param in annotations.items():
if name != "self":
param_annotation = (

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel as PydanticBaseModel
@@ -7,7 +7,7 @@ from pydantic import Field as PydanticField
class ToolCalling(BaseModel):
tool_name: str = Field(..., description="The name of the tool to be called.")
arguments: Optional[Dict[str, Any]] = Field(
arguments: dict[str, Any] | None = Field(
..., description="A dictionary of arguments to be passed to the tool."
)
@@ -16,6 +16,6 @@ class InstructorToolCalling(PydanticBaseModel):
tool_name: str = PydanticField(
..., description="The name of the tool to be called."
)
arguments: Optional[Dict[str, Any]] = PydanticField(
arguments: dict[str, Any] | None = PydanticField(
..., description="A dictionary of arguments to be passed to the tool."
)

View File

@@ -1,26 +1,24 @@
from .converter import Converter, ConverterError
from .file_handler import FileHandler
from .i18n import I18N
from .internal_instructor import InternalInstructor
from .logger import Logger
from .parser import YamlParser
from .printer import Printer
from .prompts import Prompts
from .rpm_controller import RPMController
from .exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededException,
from crewai.utilities.converter import Converter, ConverterError
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededError,
)
from crewai.utilities.file_handler import FileHandler
from crewai.utilities.i18n import I18N
from crewai.utilities.internal_instructor import InternalInstructor
from crewai.utilities.logger import Logger
from crewai.utilities.printer import Printer
from crewai.utilities.prompts import Prompts
from crewai.utilities.rpm_controller import RPMController
__all__ = [
"I18N",
"Converter",
"ConverterError",
"FileHandler",
"I18N",
"InternalInstructor",
"LLMContextLengthExceededError",
"Logger",
"Printer",
"Prompts",
"RPMController",
"YamlParser",
"LLMContextLengthExceededException",
]

View File

@@ -1,7 +1,9 @@
from __future__ import annotations
import json
import re
from collections.abc import Callable, Sequence
from typing import Any
from typing import TYPE_CHECKING, Any, Final, Literal, TypedDict
from rich.console import Console
@@ -19,18 +21,47 @@ from crewai.tools import BaseTool as CrewAITool
from crewai.tools.base_tool import BaseTool
from crewai.tools.structured_tool import CrewStructuredTool
from crewai.tools.tool_types import ToolResult
from crewai.utilities import I18N, Printer
from crewai.utilities.errors import AgentRepositoryError
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededException,
LLMContextLengthExceededError,
)
from crewai.utilities.i18n import I18N
from crewai.utilities.printer import ColoredText, Printer
from crewai.utilities.types import LLMMessage
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.task import Task
class SummaryContent(TypedDict):
"""Structure for summary content entries.
Attributes:
content: The summarized content.
"""
content: str
console = Console()
_MULTIPLE_NEWLINES: Final[re.Pattern[str]] = re.compile(r"\n+")
def parse_tools(tools: list[BaseTool]) -> list[CrewStructuredTool]:
"""Parse tools to be used for the task."""
tools_list = []
"""Parse tools to be used for the task.
Args:
tools: List of tools to parse.
Returns:
List of structured tools.
Raises:
ValueError: If a tool is not a CrewStructuredTool or BaseTool.
"""
tools_list: list[CrewStructuredTool] = []
for tool in tools:
if isinstance(tool, CrewAITool):
@@ -42,7 +73,14 @@ def parse_tools(tools: list[BaseTool]) -> list[CrewStructuredTool]:
def get_tool_names(tools: Sequence[CrewStructuredTool | BaseTool]) -> str:
"""Get the names of the tools."""
"""Get the names of the tools.
Args:
tools: List of tools to get names from.
Returns:
Comma-separated string of tool names.
"""
return ", ".join([t.name for t in tools])
@@ -51,16 +89,30 @@ def render_text_description_and_args(
) -> str:
"""Render the tool name, description, and args in plain text.
search: This tool is used for search, args: {"query": {"type": "string"}}
calculator: This tool is used for math, \
args: {"expression": {"type": "string"}}
search: This tool is used for search, args: {"query": {"type": "string"}}
calculator: This tool is used for math, \
args: {"expression": {"type": "string"}}
Args:
tools: List of tools to render.
Returns:
Plain text description of tools.
"""
tool_strings = [tool.description for tool in tools]
return "\n".join(tool_strings)
def has_reached_max_iterations(iterations: int, max_iterations: int) -> bool:
"""Check if the maximum number of iterations has been reached."""
"""Check if the maximum number of iterations has been reached.
Args:
iterations: Current number of iterations.
max_iterations: Maximum allowed iterations.
Returns:
True if maximum iterations reached, False otherwise.
"""
return iterations >= max_iterations
@@ -68,16 +120,19 @@ def handle_max_iterations_exceeded(
formatted_answer: AgentAction | AgentFinish | None,
printer: Printer,
i18n: I18N,
messages: list[dict[str, str]],
messages: list[LLMMessage],
llm: LLM | BaseLLM,
callbacks: list[Any],
callbacks: list[Callable[..., Any]],
) -> AgentAction | AgentFinish:
"""
Handles the case when the maximum number of iterations is exceeded.
Performs one more LLM call to get the final answer.
"""Handles the case when the maximum number of iterations is exceeded. Performs one more LLM call to get the final answer.
Parameters:
Args:
formatted_answer: The last formatted answer from the agent.
printer: Printer instance for output.
i18n: I18N instance for internationalization.
messages: List of messages to send to the LLM.
llm: The LLM instance to call.
callbacks: List of callbacks for the LLM call.
Returns:
The final formatted answer after exceeding max iterations.
@@ -98,7 +153,7 @@ def handle_max_iterations_exceeded(
# Perform one more LLM call to get the final answer
answer = llm.call(
messages,
messages, # type: ignore[arg-type]
callbacks=callbacks,
)
@@ -110,20 +165,38 @@ def handle_max_iterations_exceeded(
raise ValueError("Invalid response from LLM call - None or empty.")
# Return the formatted answer, regardless of its type
return format_answer(answer)
return format_answer(answer=answer)
def format_message_for_llm(prompt: str, role: str = "user") -> dict[str, str]:
def format_message_for_llm(
prompt: str, role: Literal["user", "assistant", "system"] = "user"
) -> LLMMessage:
"""Format a message for the LLM.
Args:
prompt: The message content.
role: The role of the message sender, either 'user' or 'assistant'.
Returns:
A dictionary with 'role' and 'content' keys.
"""
prompt = prompt.rstrip()
return {"role": role, "content": prompt}
def format_answer(answer: str) -> AgentAction | AgentFinish:
"""Format a response from the LLM into an AgentAction or AgentFinish."""
"""Format a response from the LLM into an AgentAction or AgentFinish.
Args:
answer: The raw response from the LLM
Returns:
Either an AgentAction or AgentFinish
"""
try:
return parse(answer)
except Exception:
# If parsing fails, return a default AgentFinish
return AgentFinish(
thought="Failed to parse LLM response",
output=answer,
@@ -134,23 +207,43 @@ def format_answer(answer: str) -> AgentAction | AgentFinish:
def enforce_rpm_limit(
request_within_rpm_limit: Callable[[], bool] | None = None,
) -> None:
"""Enforce the requests per minute (RPM) limit if applicable."""
"""Enforce the requests per minute (RPM) limit if applicable.
Args:
request_within_rpm_limit: Function to enforce RPM limit.
"""
if request_within_rpm_limit:
request_within_rpm_limit()
def get_llm_response(
llm: LLM | BaseLLM,
messages: list[dict[str, str]],
callbacks: list[Any],
messages: list[LLMMessage],
callbacks: list[Callable[..., Any]],
printer: Printer,
from_task: Any | None = None,
from_agent: Any | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
) -> str:
"""Call the LLM and return the response, handling any invalid responses."""
"""Call the LLM and return the response, handling any invalid responses.
Args:
llm: The LLM instance to call
messages: The messages to send to the LLM
callbacks: List of callbacks for the LLM call
printer: Printer instance for output
from_task: Optional task context for the LLM call
from_agent: Optional agent context for the LLM call
Returns:
The response from the LLM as a string
Raises:
Exception: If an error occurs.
ValueError: If the response is None or empty.
"""
try:
answer = llm.call(
messages,
messages, # type: ignore[arg-type]
callbacks=callbacks,
from_task=from_task,
from_agent=from_agent,
@@ -170,7 +263,15 @@ def get_llm_response(
def process_llm_response(
answer: str, use_stop_words: bool
) -> AgentAction | AgentFinish:
"""Process the LLM response and format it into an AgentAction or AgentFinish."""
"""Process the LLM response and format it into an AgentAction or AgentFinish.
Args:
answer: The raw response from the LLM
use_stop_words: Whether to use stop words in the LLM call
Returns:
Either an AgentAction or AgentFinish
"""
if not use_stop_words:
try:
# Preliminary parsing to check for errors.
@@ -200,6 +301,9 @@ def handle_agent_action_core(
Returns:
Either an AgentAction or AgentFinish
Notes:
- TODO: Remove messages parameter and its usage.
"""
if step_callback:
step_callback(tool_result)
@@ -220,7 +324,7 @@ def handle_agent_action_core(
return formatted_answer
def handle_unknown_error(printer: Any, exception: Exception) -> None:
def handle_unknown_error(printer: Printer, exception: Exception) -> None:
"""Handle unknown errors by informing the user.
Args:
@@ -244,10 +348,10 @@ def handle_unknown_error(printer: Any, exception: Exception) -> None:
def handle_output_parser_exception(
e: OutputParserError,
messages: list[dict[str, str]],
messages: list[LLMMessage],
iterations: int,
log_error_after: int = 3,
printer: Any | None = None,
printer: Printer | None = None,
) -> AgentAction:
"""Handle OutputParserError by updating messages and formatted_answer.
@@ -288,18 +392,18 @@ def is_context_length_exceeded(exception: Exception) -> bool:
Returns:
bool: True if the exception is due to context length exceeding
"""
return LLMContextLengthExceededException(str(exception))._is_context_limit_error(
return LLMContextLengthExceededError(str(exception))._is_context_limit_error(
str(exception)
)
def handle_context_length(
respect_context_window: bool,
printer: Any,
messages: list[dict[str, str]],
llm: Any,
callbacks: list[Any],
i18n: Any,
printer: Printer,
messages: list[LLMMessage],
llm: LLM | BaseLLM,
callbacks: list[Callable[..., Any]],
i18n: I18N,
) -> None:
"""Handle context length exceeded by either summarizing or raising an error.
@@ -310,13 +414,16 @@ def handle_context_length(
llm: LLM instance for summarization
callbacks: List of callbacks for LLM
i18n: I18N instance for messages
Raises:
SystemExit: If context length is exceeded and user opts not to summarize
"""
if respect_context_window:
printer.print(
content="Context length exceeded. Summarizing content to fit the model context window. Might take a while...",
color="yellow",
)
summarize_messages(messages, llm, callbacks, i18n)
summarize_messages(messages=messages, llm=llm, callbacks=callbacks, i18n=i18n)
else:
printer.print(
content="Context length exceeded. Consider using smaller text or RAG tools from crewai_tools.",
@@ -328,10 +435,10 @@ def handle_context_length(
def summarize_messages(
messages: list[dict[str, str]],
llm: Any,
callbacks: list[Any],
i18n: Any,
messages: list[LLMMessage],
llm: LLM | BaseLLM,
callbacks: list[Callable[..., Any]],
i18n: I18N,
) -> None:
"""Summarize messages to fit within context window.
@@ -349,7 +456,7 @@ def summarize_messages(
for i in range(0, len(messages_string), cut_size)
]
summarized_contents = []
summarized_contents: list[SummaryContent] = []
total_groups = len(messages_groups)
for idx, group in enumerate(messages_groups, 1):
@@ -357,15 +464,17 @@ def summarize_messages(
content=f"Summarizing {idx}/{total_groups}...",
color="yellow",
)
messages = [
format_message_for_llm(
i18n.slice("summarizer_system_message"), role="system"
),
format_message_for_llm(
i18n.slice("summarize_instruction").format(group=group["content"]),
),
]
summary = llm.call(
[
format_message_for_llm(
i18n.slice("summarizer_system_message"), role="system"
),
format_message_for_llm(
i18n.slice("summarize_instruction").format(group=group["content"]),
),
],
messages, # type: ignore[arg-type]
callbacks=callbacks,
)
summarized_contents.append({"content": str(summary)})
@@ -404,20 +513,29 @@ def show_agent_logs(
if formatted_answer is None:
# Start logs
printer.print(
content=f"\033[1m\033[95m# Agent:\033[00m \033[1m\033[92m{agent_role}\033[00m"
content=[
ColoredText("# Agent: ", "bold_purple"),
ColoredText(agent_role, "bold_green"),
]
)
if task_description:
printer.print(
content=f"\033[95m## Task:\033[00m \033[92m{task_description}\033[00m"
content=[
ColoredText("## Task: ", "purple"),
ColoredText(task_description, "green"),
]
)
else:
# Execution logs
printer.print(
content=f"\n\n\033[1m\033[95m# Agent:\033[00m \033[1m\033[92m{agent_role}\033[00m"
content=[
ColoredText("\n\n# Agent: ", "bold_purple"),
ColoredText(agent_role, "bold_green"),
]
)
if isinstance(formatted_answer, AgentAction):
thought = re.sub(r"\n+", "\n", formatted_answer.thought)
thought = _MULTIPLE_NEWLINES.sub("\n", formatted_answer.thought)
formatted_json = json.dumps(
formatted_answer.tool_input,
indent=2,
@@ -425,24 +543,39 @@ def show_agent_logs(
)
if thought and thought != "":
printer.print(
content=f"\033[95m## Thought:\033[00m \033[92m{thought}\033[00m"
content=[
ColoredText("## Thought: ", "purple"),
ColoredText(thought, "green"),
]
)
printer.print(
content=f"\033[95m## Using tool:\033[00m \033[92m{formatted_answer.tool}\033[00m"
content=[
ColoredText("## Using tool: ", "purple"),
ColoredText(formatted_answer.tool, "green"),
]
)
printer.print(
content=f"\033[95m## Tool Input:\033[00m \033[92m\n{formatted_json}\033[00m"
content=[
ColoredText("## Tool Input: ", "purple"),
ColoredText(f"\n{formatted_json}", "green"),
]
)
printer.print(
content=f"\033[95m## Tool Output:\033[00m \033[92m\n{formatted_answer.result}\033[00m"
content=[
ColoredText("## Tool Output: ", "purple"),
ColoredText(f"\n{formatted_answer.result}", "green"),
]
)
elif isinstance(formatted_answer, AgentFinish):
printer.print(
content=f"\033[95m## Final Answer:\033[00m \033[92m\n{formatted_answer.output}\033[00m\n\n"
content=[
ColoredText("## Final Answer: ", "purple"),
ColoredText(f"\n{formatted_answer.output}\n\n", "green"),
]
)
def _print_current_organization():
def _print_current_organization() -> None:
settings = Settings()
if settings.org_uuid:
console.print(
@@ -457,6 +590,17 @@ def _print_current_organization():
def load_agent_from_repository(from_repository: str) -> dict[str, Any]:
"""Load an agent from the repository.
Args:
from_repository: The name of the agent to load.
Returns:
A dictionary of attributes to use for the agent.
Raises:
AgentRepositoryError: If the agent cannot be loaded.
"""
attributes: dict[str, Any] = {}
if from_repository:
import importlib

View File

@@ -1,20 +1,19 @@
from typing import Any, Dict, Type
from typing import Any
from pydantic import BaseModel
def process_config(
values: Dict[str, Any], model_class: Type[BaseModel]
) -> Dict[str, Any]:
"""
Process the config dictionary and update the values accordingly.
values: dict[str, Any], model_class: type[BaseModel]
) -> dict[str, Any]:
"""Process the config dictionary and update the values accordingly.
Args:
values (Dict[str, Any]): The dictionary of values to update.
model_class (Type[BaseModel]): The Pydantic model class to reference for field validation.
values: The dictionary of values to update.
model_class: The Pydantic model class to reference for field validation.
Returns:
Dict[str, Any]: The updated values dictionary.
The updated values dictionary.
"""
config = values.get("config", {})
if not config:

View File

@@ -1,19 +1,32 @@
TRAINING_DATA_FILE = "training_data.pkl"
TRAINED_AGENTS_DATA_FILE = "trained_agents_data.pkl"
DEFAULT_SCORE_THRESHOLD = 0.35
KNOWLEDGE_DIRECTORY = "knowledge"
MAX_LLM_RETRY = 3
MAX_FILE_NAME_LENGTH = 255
EMITTER_COLOR = "bold_blue"
from typing import Annotated, Final
from crewai.utilities.printer import PrinterColor
TRAINING_DATA_FILE: Final[str] = "training_data.pkl"
TRAINED_AGENTS_DATA_FILE: Final[str] = "trained_agents_data.pkl"
KNOWLEDGE_DIRECTORY: Final[str] = "knowledge"
MAX_FILE_NAME_LENGTH: Final[int] = 255
EMITTER_COLOR: Final[PrinterColor] = "bold_blue"
class _NotSpecified:
def __repr__(self):
"""Sentinel class to detect when no value has been explicitly provided.
Notes:
- TODO: Consider moving this class and NOT_SPECIFIED to types.py
as they are more type-related constructs than business constants.
"""
def __repr__(self) -> str:
return "NOT_SPECIFIED"
# Sentinel value used to detect when no value has been explicitly provided.
# Unlike `None`, which might be a valid value from the user, `NOT_SPECIFIED` allows
# us to distinguish between "not passed at all" and "explicitly passed None" or "[]".
NOT_SPECIFIED = _NotSpecified()
CREWAI_BASE_URL = "https://app.crewai.com"
NOT_SPECIFIED: Final[
Annotated[
_NotSpecified,
"Sentinel value used to detect when no value has been explicitly provided. "
"Unlike `None`, which might be a valid value from the user, `NOT_SPECIFIED` "
"allows us to distinguish between 'not passed at all' and 'explicitly passed None' or '[]'.",
]
] = _NotSpecified()
CREWAI_BASE_URL: Final[str] = "https://app.crewai.com"

View File

@@ -1,18 +1,35 @@
from __future__ import annotations
import json
import re
from typing import Any, Optional, Type, Union, get_args, get_origin
from typing import TYPE_CHECKING, Any, Final, TypedDict, Union, get_args, get_origin
from pydantic import BaseModel, ValidationError
from typing_extensions import Unpack
from crewai.agents.agent_builder.utilities.base_output_converter import OutputConverter
from crewai.utilities.internal_instructor import InternalInstructor
from crewai.utilities.printer import Printer
from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
_JSON_PATTERN: Final[re.Pattern[str]] = re.compile(r"({.*})", re.DOTALL)
class ConverterError(Exception):
"""Error raised when Converter fails to parse the input."""
def __init__(self, message: str, *args: object) -> None:
"""Initialize the ConverterError with a message.
Args:
message: The error message.
*args: Additional arguments for the base Exception class.
"""
super().__init__(message, *args)
self.message = message
@@ -20,8 +37,18 @@ class ConverterError(Exception):
class Converter(OutputConverter):
"""Class that converts text into either pydantic or json."""
def to_pydantic(self, current_attempt=1) -> BaseModel:
"""Convert text to pydantic."""
def to_pydantic(self, current_attempt: int = 1) -> BaseModel:
"""Convert text to pydantic.
Args:
current_attempt: The current attempt number for conversion retries.
Returns:
A Pydantic BaseModel instance.
Raises:
ConverterError: If conversion fails after maximum attempts.
"""
try:
if self.llm.supports_function_calling():
result = self._create_instructor().to_pydantic()
@@ -37,104 +64,124 @@ class Converter(OutputConverter):
result = self.model.model_validate_json(response)
except ValidationError:
# If direct validation fails, attempt to extract valid JSON
result = handle_partial_json(response, self.model, False, None)
result = handle_partial_json(
result=response,
model=self.model,
is_json_output=False,
agent=None,
)
# Ensure result is a BaseModel instance
if not isinstance(result, BaseModel):
if isinstance(result, dict):
result = self.model.parse_obj(result)
result = self.model.model_validate(result)
elif isinstance(result, str):
try:
parsed = json.loads(result)
result = self.model.parse_obj(parsed)
result = self.model.model_validate(parsed)
except Exception as parse_err:
raise ConverterError(
f"Failed to convert partial JSON result into Pydantic: {parse_err}"
)
) from parse_err
else:
raise ConverterError(
"handle_partial_json returned an unexpected type."
)
) from None
return result
except ValidationError as e:
if current_attempt < self.max_attempts:
return self.to_pydantic(current_attempt + 1)
raise ConverterError(
f"Failed to convert text into a Pydantic model due to validation error: {e}"
)
) from e
except Exception as e:
if current_attempt < self.max_attempts:
return self.to_pydantic(current_attempt + 1)
raise ConverterError(
f"Failed to convert text into a Pydantic model due to error: {e}"
)
) from e
def to_json(self, current_attempt=1):
"""Convert text to json."""
def to_json(self, current_attempt: int = 1) -> str | ConverterError | Any: # type: ignore[override]
"""Convert text to json.
Args:
current_attempt: The current attempt number for conversion retries.
Returns:
A JSON string or ConverterError if conversion fails.
Raises:
ConverterError: If conversion fails after maximum attempts.
"""
try:
if self.llm.supports_function_calling():
return self._create_instructor().to_json()
else:
return json.dumps(
self.llm.call(
[
{"role": "system", "content": self.instructions},
{"role": "user", "content": self.text},
]
)
return json.dumps(
self.llm.call(
[
{"role": "system", "content": self.instructions},
{"role": "user", "content": self.text},
]
)
)
except Exception as e:
if current_attempt < self.max_attempts:
return self.to_json(current_attempt + 1)
return ConverterError(f"Failed to convert text into JSON, error: {e}.")
def _create_instructor(self):
def _create_instructor(self) -> InternalInstructor:
"""Create an instructor."""
from crewai.utilities import InternalInstructor
inst = InternalInstructor(
return InternalInstructor(
llm=self.llm,
model=self.model,
content=self.text,
)
return inst
def _convert_with_instructions(self):
"""Create a chain."""
from crewai.utilities.crew_pydantic_output_parser import (
CrewPydanticOutputParser,
)
parser = CrewPydanticOutputParser(pydantic_object=self.model)
result = self.llm.call(
[
{"role": "system", "content": self.instructions},
{"role": "user", "content": self.text},
]
)
return parser.parse_result(result)
def convert_to_model(
result: str,
output_pydantic: Optional[Type[BaseModel]],
output_json: Optional[Type[BaseModel]],
agent: Any,
converter_cls: Optional[Type[Converter]] = None,
) -> Union[dict, BaseModel, str]:
output_pydantic: type[BaseModel] | None,
output_json: type[BaseModel] | None,
agent: Agent | None = None,
converter_cls: type[Converter] | None = None,
) -> dict[str, Any] | BaseModel | str:
"""Convert a result string to a Pydantic model or JSON.
Args:
result: The result string to convert.
output_pydantic: The Pydantic model class to convert to.
output_json: The Pydantic model class to convert to JSON.
agent: The agent instance.
converter_cls: The converter class to use.
Returns:
The converted result as a dict, BaseModel, or original string.
"""
model = output_pydantic or output_json
if model is None:
return result
try:
escaped_result = json.dumps(json.loads(result, strict=False))
return validate_model(escaped_result, model, bool(output_json))
return validate_model(
result=escaped_result, model=model, is_json_output=bool(output_json)
)
except json.JSONDecodeError:
return handle_partial_json(
result, model, bool(output_json), agent, converter_cls
result=result,
model=model,
is_json_output=bool(output_json),
agent=agent,
converter_cls=converter_cls,
)
except ValidationError:
return handle_partial_json(
result, model, bool(output_json), agent, converter_cls
result=result,
model=model,
is_json_output=bool(output_json),
agent=agent,
converter_cls=converter_cls,
)
except Exception as e:
@@ -146,8 +193,18 @@ def convert_to_model(
def validate_model(
result: str, model: Type[BaseModel], is_json_output: bool
) -> Union[dict, BaseModel]:
result: str, model: type[BaseModel], is_json_output: bool
) -> dict[str, Any] | BaseModel:
"""Validate and convert a JSON string to a Pydantic model or dict.
Args:
result: The JSON string to validate and convert.
model: The Pydantic model class to convert to.
is_json_output: Whether to return a dict (True) or Pydantic model (False).
Returns:
The converted result as a dict or BaseModel.
"""
exported_result = model.model_validate_json(result)
if is_json_output:
return exported_result.model_dump()
@@ -156,15 +213,27 @@ def validate_model(
def handle_partial_json(
result: str,
model: Type[BaseModel],
model: type[BaseModel],
is_json_output: bool,
agent: Any,
converter_cls: Optional[Type[Converter]] = None,
) -> Union[dict, BaseModel, str]:
match = re.search(r"({.*})", result, re.DOTALL)
agent: Agent | None,
converter_cls: type[Converter] | None = None,
) -> dict[str, Any] | BaseModel | str:
"""Handle partial JSON in a result string and convert to Pydantic model or dict.
Args:
result: The result string to process.
model: The Pydantic model class to convert to.
is_json_output: Whether to return a dict (True) or Pydantic model (False).
agent: The agent instance.
converter_cls: The converter class to use.
Returns:
The converted result as a dict, BaseModel, or original string.
"""
match = _JSON_PATTERN.search(result)
if match:
try:
exported_result = model.model_validate_json(match.group(0))
exported_result = model.model_validate_json(match.group())
if is_json_output:
return exported_result.model_dump()
return exported_result
@@ -179,19 +248,43 @@ def handle_partial_json(
)
return convert_with_instructions(
result, model, is_json_output, agent, converter_cls
result=result,
model=model,
is_json_output=is_json_output,
agent=agent,
converter_cls=converter_cls,
)
def convert_with_instructions(
result: str,
model: Type[BaseModel],
model: type[BaseModel],
is_json_output: bool,
agent: Any,
converter_cls: Optional[Type[Converter]] = None,
) -> Union[dict, BaseModel, str]:
agent: Agent | None,
converter_cls: type[Converter] | None = None,
) -> dict | BaseModel | str:
"""Convert a result string to a Pydantic model or JSON using instructions.
Args:
result: The result string to convert.
model: The Pydantic model class to convert to.
is_json_output: Whether to return a dict (True) or Pydantic model (False).
agent: The agent instance.
converter_cls: The converter class to use.
Returns:
The converted result as a dict, BaseModel, or original string.
Raises:
TypeError: If neither agent nor converter_cls is provided.
Notes:
- TODO: Fix llm typing issues, return llm should not be able to be str or None.
"""
if agent is None:
raise TypeError("Agent must be provided if converter_cls is not specified.")
llm = agent.function_calling_llm or agent.llm
instructions = get_conversion_instructions(model, llm)
instructions = get_conversion_instructions(model=model, llm=llm)
converter = create_converter(
agent=agent,
converter_cls=converter_cls,
@@ -214,9 +307,25 @@ def convert_with_instructions(
return exported_result
def get_conversion_instructions(model: Type[BaseModel], llm: Any) -> str:
def get_conversion_instructions(
model: type[BaseModel], llm: BaseLLM | LLM | str
) -> str:
"""Generate conversion instructions based on the model and LLM capabilities.
Args:
model: A Pydantic model class.
llm: The language model instance.
Returns:
"""
instructions = "Please convert the following text into valid JSON."
if llm and not isinstance(llm, str) and llm.supports_function_calling():
if (
llm
and not isinstance(llm, str)
and hasattr(llm, "supports_function_calling")
and llm.supports_function_calling()
):
model_schema = PydanticSchemaParser(model=model).get_schema()
instructions += (
f"\n\nOutput ONLY the valid JSON and nothing else.\n\n"
@@ -231,12 +340,45 @@ def get_conversion_instructions(model: Type[BaseModel], llm: Any) -> str:
return instructions
class CreateConverterKwargs(TypedDict, total=False):
"""Keyword arguments for creating a converter.
Attributes:
llm: The language model instance.
text: The text to convert.
model: The Pydantic model class.
instructions: The conversion instructions.
"""
llm: BaseLLM | LLM | str
text: str
model: type[BaseModel]
instructions: str
def create_converter(
agent: Optional[Any] = None,
converter_cls: Optional[Type[Converter]] = None,
*args,
**kwargs,
agent: Agent | None = None,
converter_cls: type[Converter] | None = None,
*args: Any,
**kwargs: Unpack[CreateConverterKwargs],
) -> Converter:
"""Create a converter instance based on the agent or provided class.
Args:
agent: The agent instance.
converter_cls: The converter class to instantiate.
*args: The positional arguments to pass to the converter.
**kwargs: The keyword arguments to pass to the converter.
Returns:
An instance of the specified converter class.
Raises:
ValueError: If neither agent nor converter_cls is provided.
AttributeError: If the agent does not have a 'get_output_converter' method.
Exception: If no converter instance is created.
"""
if agent and not converter_cls:
if hasattr(agent, "get_output_converter"):
converter = agent.get_output_converter(*args, **kwargs)
@@ -253,17 +395,30 @@ def create_converter(
return converter
def generate_model_description(model: Type[BaseModel]) -> str:
"""
Generate a string description of a Pydantic model's fields and their types.
def generate_model_description(model: type[BaseModel]) -> str:
"""Generate a string description of a Pydantic model's fields and their types.
This function takes a Pydantic model class and returns a string that describes
the model's fields and their respective types. The description includes handling
of complex types such as `Optional`, `List`, and `Dict`, as well as nested Pydantic
models.
Args:
model: A Pydantic model class.
Returns:
A string representation of the model's fields and types.
"""
def describe_field(field_type):
def describe_field(field_type: Any) -> str:
"""Recursively describe a field's type.
Args:
field_type: The type of the field to describe.
Returns:
A string representation of the field's type.
"""
origin = get_origin(field_type)
args = get_args(field_type)
@@ -272,20 +427,18 @@ def generate_model_description(model: Type[BaseModel]) -> str:
non_none_args = [arg for arg in args if arg is not type(None)]
if len(non_none_args) == 1:
return f"Optional[{describe_field(non_none_args[0])}]"
else:
return f"Optional[Union[{', '.join(describe_field(arg) for arg in non_none_args)}]]"
elif origin is list:
return f"Optional[Union[{', '.join(describe_field(arg) for arg in non_none_args)}]]"
if origin is list:
return f"List[{describe_field(args[0])}]"
elif origin is dict:
if origin is dict:
key_type = describe_field(args[0])
value_type = describe_field(args[1])
return f"Dict[{key_type}, {value_type}]"
elif isinstance(field_type, type) and issubclass(field_type, BaseModel):
if isinstance(field_type, type) and issubclass(field_type, BaseModel):
return generate_model_description(field_type)
elif hasattr(field_type, "__name__"):
if hasattr(field_type, "__name__"):
return field_type.__name__
else:
return str(field_type)
return str(field_type)
fields = model.model_fields
field_descriptions = [

View File

@@ -1 +1 @@
"""Crew-specific utilities."""
"""Crew-specific utilities."""

View File

@@ -1,16 +1,16 @@
"""Context management utilities for tracking crew and task execution context using OpenTelemetry baggage."""
from typing import Optional
from typing import cast
from opentelemetry import baggage
from crewai.utilities.crew.models import CrewContext
def get_crew_context() -> Optional[CrewContext]:
def get_crew_context() -> CrewContext | None:
"""Get the current crew context from OpenTelemetry baggage.
Returns:
CrewContext instance containing crew context information, or None if no context is set
"""
return baggage.get_baggage("crew_context")
return cast(CrewContext | None, baggage.get_baggage("crew_context"))

View File

@@ -1,16 +1,17 @@
"""Models for crew-related data structures."""
from typing import Optional
from pydantic import BaseModel, Field
class CrewContext(BaseModel):
"""Model representing crew context information."""
"""Model representing crew context information.
id: Optional[str] = Field(
default=None, description="Unique identifier for the crew"
)
key: Optional[str] = Field(
Attributes:
id: Unique identifier for the crew.
key: Optional crew key/name for identification.
"""
id: str | None = Field(default=None, description="Unique identifier for the crew")
key: str | None = Field(
default=None, description="Optional crew key/name for identification"
)

Some files were not shown because too many files have changed in this diff Show More