Mirror of https://github.com/crewAIInc/crewAI.git, synced 2025-12-30 03:08:29 +00:00

## Compare commits

20 commits
| Author | SHA1 | Date |
|---|---|---|
|  | b5b10a8cde |  |
|  | 2485ed93d6 |  |
|  | ce5ea9be6f |  |
|  | e070c1400c |  |
|  | 6537e3737d |  |
|  | 346faf229f |  |
|  | a0b757a12c |  |
|  | 1dbe8aab52 |  |
|  | 4ac65eb0a6 |  |
|  | 3e97393f58 |  |
|  | 34bed359a6 |  |
|  | feeed505bb |  |
|  | cb0efd05b4 |  |
|  | db5f565dea |  |
|  | 58413b663a |  |
|  | 37636f0dd7 |  |
|  | 0e370593f1 |  |
|  | aa8dc9d77f |  |
|  | 9c1096dbdc |  |
|  | 47044450c0 |  |
@@ -5,6 +5,82 @@ icon: "clock"
mode: "wide"
---

<Update label="Sep 20, 2025">
## v0.193.2

[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.193.2)

## What's Changed

- Updated pyproject templates to use the right version

</Update>

<Update label="Sep 20, 2025">
## v0.193.1

[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.193.1)

## What's Changed

- A series of minor fixes and linter improvements

</Update>

<Update label="Sep 19, 2025">
## v0.193.0

[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.193.0)

## Core Improvements & Fixes

- Fixed handling of the `model` parameter during OpenAI adapter initialization
- Resolved test-duration cache issues in CI workflows
- Fixed a flaky test related to repeated tool usage by agents
- Added missing event exports to `__init__.py` for consistent module behavior
- Dropped message storage from metadata in Mem0 to reduce bloat
- Fixed L2 distance metric support for backward compatibility in vector search

## New Features & Enhancements

- Introduced thread-safe platform context management
- Added test-duration caching for optimized `pytest-split` runs
- Added ephemeral trace improvements for better trace control
- Made search parameters for RAG, knowledge, and memory fully configurable (see the sketch after this list)
- Enabled ChromaDB to use the OpenAI API for embedding functions
- Added deeper observability tools for user-level insights
- Unified RAG storage system with instance-specific client support
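As a rough illustration of the configurable embedding setup, here is a minimal sketch. It uses the dictionary-style `embedder` configuration from the crewAI docs; the model name and the agent/task contents are placeholder examples, not part of this release's API surface.

```python
# Minimal sketch: a Crew with an OpenAI-backed embedder.
# Requires OPENAI_API_KEY in the environment; the model name is an example.
from crewai import Agent, Crew, Task

researcher = Agent(
    role="Researcher",
    goal="Summarize a topic",
    backstory="A careful analyst.",
)
task = Task(
    description="Summarize recent AI news.",
    expected_output="A short bullet-point summary.",
    agent=researcher,
)
crew = Crew(
    agents=[researcher],
    tasks=[task],
    embedder={"provider": "openai", "config": {"model": "text-embedding-3-small"}},
)
```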

## Documentation & Guides

- Updated `RagTool` references to reflect CrewAI's native RAG implementation
- Improved internal docs for the `langgraph` and `openai` agent adapters with type annotations and docstrings

</Update>

<Update label="Sep 11, 2025">
## v0.186.1

[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.186.1)

## What's Changed

- Fixed the version not being found and the reversion failing silently
- Bumped CrewAI version to 0.186.1 and updated dependencies in the CLI

</Update>

<Update label="Sep 10, 2025">
## v0.186.0

[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.186.0)

## What's Changed

- Refer to the GitHub release notes for detailed changes

</Update>

<Update label="Sep 04, 2025">
## v0.177.0
@@ -404,6 +404,10 @@ crewai config reset
After resetting configuration, re-run `crewai login` to authenticate again.
</Tip>

<Tip>
The CrewAI CLI handles authentication to the Tool Repository automatically when adding packages to your project. Just prefix any `uv` command with `crewai` to use it, e.g. `crewai uv add requests`. For more information, see the [Tool Repository](https://docs.crewai.com/enterprise/features/tool-repository) docs.
</Tip>

<Note>
Configuration settings are stored in `~/.config/crewai/settings.json`. Some settings, such as the organization name and UUID, are read-only and managed through the authentication and organization commands. Tool-repository settings are hidden and cannot be set directly by users.
</Note>
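To see what the CLI has stored, you can inspect that file directly. A minimal sketch, assuming the default path; the keys present will vary with your login state, and the snippet itself is illustrative rather than part of the CLI:

```python
# Illustrative: print the CrewAI CLI settings file, if it exists.
import json
from pathlib import Path

settings_path = Path.home() / ".config" / "crewai" / "settings.json"
if settings_path.is_file():
    print(json.dumps(json.loads(settings_path.read_text()), indent=2))
else:
    print("No settings file yet - run `crewai login` first.")
```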

@@ -52,6 +52,36 @@ researcher = Agent(
)
```

## Adding other packages after installing a tool

After installing a tool from the CrewAI Enterprise Tool Repository, use the `crewai uv` command to add other packages to your project.
Plain `uv` commands will fail because authentication to the Tool Repository is handled by the CLI. With `crewai uv`, you can add packages without worrying about authentication.
Any `uv` command works through `crewai uv`, making it a convenient way to manage your project's dependencies without juggling credentials in environment variables or elsewhere.

Say you have installed a custom tool from the CrewAI Enterprise Tool Repository called "my-tool":

```bash
crewai tool install my-tool
```

To add another package to your project, run:

```bash
crewai uv add requests
```

Other commands such as `uv sync` or `uv remove` work the same way:

```bash
crewai uv sync
```

```bash
crewai uv remove requests
```

Each of these commands updates your project and its `pyproject.toml` accordingly.
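If you want to verify the result programmatically, here is a small sketch (assuming Python 3.11+ for the standard-library `tomllib`, and that `crewai uv add requests` was run as above):

```python
# Illustrative: confirm the new dependency appears in pyproject.toml.
import tomllib

with open("pyproject.toml", "rb") as f:
    project = tomllib.load(f)

print([dep for dep in project["project"]["dependencies"] if dep.startswith("requests")])
```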

## Creating and Publishing Tools

To create a new tool project:

@@ -27,7 +27,7 @@ Follow the steps below to get Crewing! 🚣♂️
<Step title="Navigate to your new crew project">
  <CodeGroup>
  ```shell Terminal
  cd latest-ai-development
  cd latest_ai_development
  ```
  </CodeGroup>
</Step>
@@ -5,6 +5,82 @@ icon: "clock"
mode: "wide"
---

<Update label="Sep 20, 2025">
## v0.193.2

[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.193.2)

## What's Changed

- Updated pyproject templates to use the correct version

</Update>

<Update label="Sep 20, 2025">
## v0.193.1

[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.193.1)

## What's Changed

- A series of minor fixes and linter improvements

</Update>

<Update label="Sep 19, 2025">
## v0.193.0

[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.193.0)

## Core Improvements & Fixes

- Fixed `model` parameter handling during OpenAI adapter initialization
- Resolved test-duration cache issues in CI workflows
- Fixed an unstable test related to agents' repeated tool usage
- Added missing event exports to `__init__.py` for consistent module behavior
- Removed message storage from metadata in Mem0 to reduce bloat
- Fixed L2 distance metric support for backward compatibility in vector search

## New Features & Enhancements

- Introduced thread-safe platform context management
- Added test-duration caching to optimize `pytest-split` runs
- Ephemeral trace improvements for better trace control
- Made search parameters for RAG, knowledge, and memory fully configurable
- Enabled ChromaDB to use the OpenAI API for embedding functions
- Added deeper observability tools for user-level insights
- Unified RAG storage system with per-instance client support

## Documentation & Guides

- Updated `RagTool` references to reflect CrewAI's native RAG implementation
- Improved internal docs for the `langgraph` and `openai` agent adapters with type annotations and docstrings

</Update>

<Update label="Sep 11, 2025">
## v0.186.1

[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.186.1)

## What's Changed

- Fixed the silent reversion issue when the version could not be found
- Bumped the CrewAI version to 0.186.1 and updated CLI dependencies

</Update>

<Update label="Sep 10, 2025">
## v0.186.0

[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.186.0)

## What's Changed

- Refer to the GitHub release notes for detailed changes

</Update>

<Update label="Sep 04, 2025">
## v0.177.0

@@ -27,7 +27,7 @@ mode: "wide"
<Step title="Navigate to your new crew project">
  <CodeGroup>
  ```shell Terminal
  cd latest-ai-development
  cd latest_ai_development
  ```
  </CodeGroup>
</Step>
@@ -5,6 +5,82 @@ icon: "clock"
mode: "wide"
---

<Update label="Sep 20, 2025">
## v0.193.2

[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.193.2)

## What's Changed

- Updated pyproject templates to use the correct version

</Update>

<Update label="Sep 20, 2025">
## v0.193.1

[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.193.1)

## What's Changed

- A series of small fixes and linter improvements

</Update>

<Update label="Sep 19, 2025">
## v0.193.0

[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.193.0)

## Core Improvements & Fixes

- Fixed handling of the `model` parameter during OpenAI adapter initialization
- Resolved test-duration cache issues in CI workflows
- Fixed an unstable test related to repeated tool usage by agents
- Added missing event exports to `__init__.py` for consistent module behavior
- Removed message storage from metadata in Mem0 to reduce bloat
- Fixed L2 distance metric support for backward compatibility in vector search

## New Features & Enhancements

- Introduced thread-safe platform context management
- Added test-duration caching for optimized `pytest-split` runs
- Ephemeral trace improvements for better trace control
- Search parameters for RAG, knowledge, and memory are now fully configurable
- Enabled ChromaDB to use the OpenAI API for embedding functions
- Added deeper observability tools for user-level insights
- Unified RAG storage system with per-instance client support

## Documentation & Guides

- Updated `RagTool` references to reflect CrewAI's native RAG implementation
- Improved internal documentation for the `langgraph` and `openai` agent adapters with type annotations and docstrings

</Update>

<Update label="Sep 11, 2025">
## v0.186.1

[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.186.1)

## What's Changed

- Fixed a silent reversion failure when the version was not found
- Bumped the CrewAI version to 0.186.1 and updated CLI dependencies

</Update>

<Update label="Sep 10, 2025">
## v0.186.0

[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/0.186.0)

## What's Changed

- See the release notes on GitHub for full details

</Update>

<Update label="Sep 04, 2025">
## v0.177.0

@@ -27,7 +27,7 @@ Follow the steps below to get Crewing! 🚣♂️
<Step title="Navigate to your new crew project">
  <CodeGroup>
  ```shell Terminal
  cd latest-ai-development
  cd latest_ai_development
  ```
  </CodeGroup>
</Step>

@@ -9,7 +9,7 @@ authors = [
]
dependencies = [
    # Core Dependencies
    "pydantic>=2.4.2",
    "pydantic>=2.11.9",
    "openai>=1.13.3",
    "litellm==1.74.9",
    "instructor>=1.3.3",

@@ -21,13 +21,12 @@ dependencies = [
    "opentelemetry-sdk>=1.30.0",
    "opentelemetry-exporter-otlp-proto-http>=1.30.0",
    # Data Handling
    "chromadb>=0.5.23",
    "chromadb~=1.1.0",
    "tokenizers>=0.20.3",
    "onnxruntime==1.22.0",
    "openpyxl>=3.1.5",
    "pyvis>=0.3.2",
    # Authentication and Security
    "python-dotenv>=1.0.0",
    "python-dotenv>=1.1.1",
    "pyjwt>=2.9.0",
    # Configuration and Utils
    "click>=8.1.7",

@@ -40,6 +39,7 @@ dependencies = [
    "blinker>=1.9.0",
    "json5>=0.10.0",
    "portalocker==2.7.0",
    "pydantic-settings>=2.10.1",
]

[project.urls]

@@ -48,7 +48,9 @@ Documentation = "https://docs.crewai.com"
Repository = "https://github.com/crewAIInc/crewAI"

[project.optional-dependencies]
tools = ["crewai-tools~=0.73.0"]
tools = [
    "crewai-tools>=0.74.0",
]
embeddings = [
    "tiktoken~=0.8.0"
]

@@ -71,24 +73,30 @@ aisuite = [
qdrant = [
    "qdrant-client[fastembed]>=1.14.3",
]
aws = [
    "boto3>=1.40.38",
]
watson = [
    "ibm-watsonx-ai>=1.3.39",
]
voyageai = [
    "voyageai>=0.3.5",
]

[tool.uv]
dev-dependencies = [
    "ruff>=0.12.11",
    "mypy>=1.17.1",
[dependency-groups]
dev = [
    "ruff>=0.13.1",
    "mypy>=1.18.2",
    "pre-commit>=4.3.0",
    "bandit>=1.8.6",
    "pillow>=10.2.0",
    "cairosvg>=2.7.1",
    "pytest>=8.0.0",
    "python-dotenv>=1.0.0",
    "pytest-asyncio>=0.23.7",
    "pytest-subprocess>=1.5.2",
    "pytest-recording>=0.13.2",
    "pytest-randomly>=3.16.0",
    "pytest-timeout>=2.3.1",
    "pytest-xdist>=3.6.1",
    "pytest-split>=0.9.0",
    "pytest>=8.4.2",
    "pytest-asyncio>=1.2.0",
    "pytest-subprocess>=1.5.3",
    "pytest-recording>=0.13.4",
    "pytest-randomly>=4.0.1",
    "pytest-timeout>=2.4.0",
    "pytest-xdist>=3.8.0",
    "pytest-split>=0.10.0",
    "types-requests==2.32.*",
    "types-pyyaml==6.0.*",
    "types-regex==2024.11.6.*",

@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:

_suppress_pydantic_deprecation_warnings()

__version__ = "0.193.2"
__version__ = "0.201.0"
_telemetry_submitted = False

@@ -1,17 +1,10 @@
import shutil
import subprocess
import time
from collections.abc import Callable, Sequence
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
)

from pydantic import Field, InstanceOf, PrivateAttr, model_validator

@@ -19,12 +12,31 @@ from pydantic import Field, InstanceOf, PrivateAttr, model_validator
from crewai.agents import CacheHandler
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.agent_events import (
    AgentExecutionCompletedEvent,
    AgentExecutionErrorEvent,
    AgentExecutionStartedEvent,
)
from crewai.events.types.knowledge_events import (
    KnowledgeQueryCompletedEvent,
    KnowledgeQueryFailedEvent,
    KnowledgeQueryStartedEvent,
    KnowledgeRetrievalCompletedEvent,
    KnowledgeRetrievalStartedEvent,
    KnowledgeSearchQueryFailedEvent,
)
from crewai.events.types.memory_events import (
    MemoryRetrievalCompletedEvent,
    MemoryRetrievalStartedEvent,
)
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
from crewai.lite_agent import LiteAgent, LiteAgentOutput
from crewai.llm import BaseLLM
from crewai.memory.contextual.contextual_memory import ContextualMemory
from crewai.rag.embeddings.types import EmbedderConfig
from crewai.security import Fingerprint
from crewai.task import Task
from crewai.tools import BaseTool

@@ -38,24 +50,6 @@ from crewai.utilities.agent_utils import (
)
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.converter import generate_model_description
from crewai.events.types.agent_events import (
    AgentExecutionCompletedEvent,
    AgentExecutionErrorEvent,
    AgentExecutionStartedEvent,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
    MemoryRetrievalStartedEvent,
    MemoryRetrievalCompletedEvent,
)
from crewai.events.types.knowledge_events import (
    KnowledgeQueryCompletedEvent,
    KnowledgeQueryFailedEvent,
    KnowledgeQueryStartedEvent,
    KnowledgeRetrievalCompletedEvent,
    KnowledgeRetrievalStartedEvent,
    KnowledgeSearchQueryFailedEvent,
)
from crewai.utilities.llm_utils import create_llm
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.training_handler import CrewTrainingHandler

@@ -87,36 +81,36 @@ class Agent(BaseAgent):
    """

    _times_executed: int = PrivateAttr(default=0)
    max_execution_time: Optional[int] = Field(
    max_execution_time: int | None = Field(
        default=None,
        description="Maximum execution time for an agent to execute a task",
    )
    agent_ops_agent_name: str = None  # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
    agent_ops_agent_id: str = None  # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
    step_callback: Optional[Any] = Field(
    step_callback: Any | None = Field(
        default=None,
        description="Callback to be executed after each step of the agent execution.",
    )
    use_system_prompt: Optional[bool] = Field(
    use_system_prompt: bool | None = Field(
        default=True,
        description="Use system prompt for the agent.",
    )
    llm: Union[str, InstanceOf[BaseLLM], Any] = Field(
    llm: str | InstanceOf[BaseLLM] | Any = Field(
        description="Language model that will run the agent.", default=None
    )
    function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
    function_calling_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
        description="Language model that will run the agent.", default=None
    )
    system_template: Optional[str] = Field(
    system_template: str | None = Field(
        default=None, description="System format for the agent."
    )
    prompt_template: Optional[str] = Field(
    prompt_template: str | None = Field(
        default=None, description="Prompt format for the agent."
    )
    response_template: Optional[str] = Field(
    response_template: str | None = Field(
        default=None, description="Response format for the agent."
    )
    allow_code_execution: Optional[bool] = Field(
    allow_code_execution: bool | None = Field(
        default=False, description="Enable code execution for the agent."
    )
    respect_context_window: bool = Field(

@@ -147,31 +141,31 @@ class Agent(BaseAgent):
        default=False,
        description="Whether the agent should reflect and create a plan before executing a task.",
    )
    max_reasoning_attempts: Optional[int] = Field(
    max_reasoning_attempts: int | None = Field(
        default=None,
        description="Maximum number of reasoning attempts before executing the task. If None, will try until ready.",
    )
    embedder: Optional[Dict[str, Any]] = Field(
    embedder: EmbedderConfig | None = Field(
        default=None,
        description="Embedder configuration for the agent.",
    )
    agent_knowledge_context: Optional[str] = Field(
    agent_knowledge_context: str | None = Field(
        default=None,
        description="Knowledge context for the agent.",
    )
    crew_knowledge_context: Optional[str] = Field(
    crew_knowledge_context: str | None = Field(
        default=None,
        description="Knowledge context for the crew.",
    )
    knowledge_search_query: Optional[str] = Field(
    knowledge_search_query: str | None = Field(
        default=None,
        description="Knowledge search query for the agent dynamically generated by the agent.",
    )
    from_repository: Optional[str] = Field(
    from_repository: str | None = Field(
        default=None,
        description="The Agent's role to be used from your repository.",
    )
    guardrail: Optional[Union[Callable[[Any], Tuple[bool, Any]], str]] = Field(
    guardrail: Callable[[Any], tuple[bool, Any]] | str | None = Field(
        default=None,
        description="Function or string description of a guardrail to validate agent output",
    )

@@ -180,7 +174,7 @@ class Agent(BaseAgent):
    )

    @model_validator(mode="before")
    def validate_from_repository(cls, v):
    def validate_from_repository(cls, v):  # noqa: N805
        if v is not None and (from_repository := v.get("from_repository")):
            return load_agent_from_repository(from_repository) | v
        return v

@@ -208,7 +202,7 @@ class Agent(BaseAgent):
        self.cache_handler = CacheHandler()
        self.set_cache_handler(self.cache_handler)

    def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
    def set_knowledge(self, crew_embedder: EmbedderConfig | None = None):
        try:
            if self.embedder is None and crew_embedder:
                self.embedder = crew_embedder

@@ -224,7 +218,7 @@ class Agent(BaseAgent):
            )
            self.knowledge.add_sources()
        except (TypeError, ValueError) as e:
            raise ValueError(f"Invalid Knowledge Configuration: {str(e)}")
            raise ValueError(f"Invalid Knowledge Configuration: {e!s}") from e

    def _is_any_available_memory(self) -> bool:
        """Check if any memory is available."""

@@ -244,8 +238,8 @@ class Agent(BaseAgent):
    def execute_task(
        self,
        task: Task,
        context: Optional[str] = None,
        tools: Optional[List[BaseTool]] = None,
        context: str | None = None,
        tools: list[BaseTool] | None = None,
    ) -> str:
        """Execute a task with the agent.

@@ -278,11 +272,9 @@ class Agent(BaseAgent):
                task.description += f"\n\nReasoning Plan:\n{reasoning_output.plan.plan}"
            except Exception as e:
                if hasattr(self, "_logger"):
                    self._logger.log(
                        "error", f"Error during reasoning process: {str(e)}"
                    )
                    self._logger.log("error", f"Error during reasoning process: {e!s}")
                else:
                    print(f"Error during reasoning process: {str(e)}")
                    print(f"Error during reasoning process: {e!s}")

        self._inject_date_to_task(task)

@@ -335,7 +327,7 @@ class Agent(BaseAgent):
            agent=self,
            task=task,
        )
        memory = contextual_memory.build_context_for_task(task, context)
        memory = contextual_memory.build_context_for_task(task, context)  # type: ignore[arg-type]
        if memory.strip() != "":
            task_prompt += self.i18n.slice("memory").format(memory=memory)

@@ -525,14 +517,14 @@ class Agent(BaseAgent):

        try:
            return future.result(timeout=timeout)
        except concurrent.futures.TimeoutError:
        except concurrent.futures.TimeoutError as e:
            future.cancel()
            raise TimeoutError(
                f"Task '{task.description}' execution timed out after {timeout} seconds. Consider increasing max_execution_time or optimizing the task."
            )
            ) from e
        except Exception as e:
            future.cancel()
            raise RuntimeError(f"Task execution failed: {str(e)}")
            raise RuntimeError(f"Task execution failed: {e!s}") from e

    def _execute_without_timeout(self, task_prompt: str, task: Task) -> str:
        """Execute a task without a timeout.

@@ -554,14 +546,14 @@ class Agent(BaseAgent):
        )["output"]

    def create_agent_executor(
        self, tools: Optional[List[BaseTool]] = None, task=None
        self, tools: list[BaseTool] | None = None, task=None
    ) -> None:
        """Create an agent executor for the agent.

        Returns:
            An instance of the CrewAgentExecutor class.
        """
        raw_tools: List[BaseTool] = tools or self.tools or []
        raw_tools: list[BaseTool] = tools or self.tools or []
        parsed_tools = parse_tools(raw_tools)

        prompt = Prompts(

@@ -587,7 +579,7 @@ class Agent(BaseAgent):
            agent=self,
            crew=self.crew,
            tools=parsed_tools,
            prompt=prompt,
            prompt=prompt,  # type: ignore[arg-type]
            original_tools=raw_tools,
            stop_words=stop_words,
            max_iter=self.max_iter,

@@ -603,10 +595,9 @@ class Agent(BaseAgent):
            callbacks=[TokenCalcHandler(self._token_process)],
        )

    def get_delegation_tools(self, agents: List[BaseAgent]):
    def get_delegation_tools(self, agents: list[BaseAgent]):
        agent_tools = AgentTools(agents=agents)
        tools = agent_tools.tools()
        return tools
        return agent_tools.tools()

    def get_multimodal_tools(self) -> Sequence[BaseTool]:
        from crewai.tools.agent_tools.add_image_tool import AddImageTool

@@ -654,7 +645,7 @@ class Agent(BaseAgent):
        )
        return task_prompt

    def _render_text_description(self, tools: List[Any]) -> str:
    def _render_text_description(self, tools: list[Any]) -> str:
        """Render the tool name and description in plain text.

        Output will be in the format of:

@@ -664,15 +655,13 @@ class Agent(BaseAgent):
        search: This tool is used for search
        calculator: This tool is used for math
        """
        description = "\n".join(
        return "\n".join(
            [
                f"Tool name: {tool.name}\nTool description:\n{tool.description}"
                for tool in tools
            ]
        )

        return description

    def _inject_date_to_task(self, task):
        """Inject the current date into the task description if inject_date is enabled."""
        if self.inject_date:

@@ -696,13 +685,13 @@ class Agent(BaseAgent):
            if not is_valid:
                raise ValueError(f"Invalid date format: {self.date_format}")

            current_date: str = datetime.now().strftime(self.date_format)
            current_date = datetime.now().strftime(self.date_format)
            task.description += f"\n\nCurrent Date: {current_date}"
        except Exception as e:
            if hasattr(self, "_logger"):
                self._logger.log("warning", f"Failed to inject date: {str(e)}")
                self._logger.log("warning", f"Failed to inject date: {e!s}")
            else:
                print(f"Warning: Failed to inject date: {str(e)}")
                print(f"Warning: Failed to inject date: {e!s}")

    def _validate_docker_installation(self) -> None:
        """Check if Docker is installed and running."""

@@ -713,15 +702,15 @@ class Agent(BaseAgent):

        try:
            subprocess.run(
                ["docker", "info"],
                ["/usr/bin/docker", "info"],
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
        except subprocess.CalledProcessError:
        except subprocess.CalledProcessError as e:
            raise RuntimeError(
                f"Docker is not running. Please start Docker to use code execution with agent: {self.role}"
            )
            ) from e

    def __repr__(self):
        return f"Agent(role={self.role}, goal={self.goal}, backstory={self.backstory})"

@@ -796,8 +785,8 @@ class Agent(BaseAgent):

    def kickoff(
        self,
        messages: Union[str, List[Dict[str, str]]],
        response_format: Optional[Type[Any]] = None,
        messages: str | list[dict[str, str]],
        response_format: type[Any] | None = None,
    ) -> LiteAgentOutput:
        """
        Execute the agent with the given messages using a LiteAgent instance.

@@ -836,8 +825,8 @@ class Agent(BaseAgent):

    async def kickoff_async(
        self,
        messages: Union[str, List[Dict[str, str]]],
        response_format: Optional[Type[Any]] = None,
        messages: str | list[dict[str, str]],
        response_format: type[Any] | None = None,
    ) -> LiteAgentOutput:
        """
        Execute the agent asynchronously with the given messages using a LiteAgent instance.

@@ -22,6 +22,7 @@ from crewai.agents.tools_handler import ToolsHandler
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.knowledge_config import KnowledgeConfig
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.rag.embeddings.types import EmbedderConfig
from crewai.security.security_config import SecurityConfig
from crewai.tools.base_tool import BaseTool, Tool
from crewai.utilities import I18N, Logger, RPMController

@@ -359,5 +360,5 @@ class BaseAgent(ABC, BaseModel):
        self._rpm_controller = rpm_controller
        self.create_agent_executor()

    def set_knowledge(self, crew_embedder: dict[str, Any] | None = None):
    def set_knowledge(self, crew_embedder: EmbedderConfig | None = None):
        pass

@@ -18,7 +18,7 @@ from crewai.agents.constants import (
    MISSING_ACTION_INPUT_AFTER_ACTION_ERROR_MESSAGE,
    UNABLE_TO_REPAIR_JSON_RESULTS,
)
from crewai.utilities import I18N
from crewai.utilities.i18n import I18N

_I18N = I18N()

@@ -1,3 +1,5 @@
import os
import subprocess
from importlib.metadata import version as get_version

import click

@@ -8,6 +10,7 @@ from crewai.cli.create_crew import create_crew
from crewai.cli.create_flow import create_flow
from crewai.cli.crew_chat import run_chat
from crewai.cli.settings.main import SettingsCommand
from crewai.cli.utils import build_env_with_tool_repository_credentials, read_toml
from crewai.memory.storage.kickoff_task_outputs_storage import (
    KickoffTaskOutputsSQLiteStorage,
)

@@ -34,6 +37,46 @@ def crewai():
    """Top-level command group for crewai."""


@crewai.command(
    name="uv",
    context_settings=dict(
        ignore_unknown_options=True,
    ),
)
@click.argument("uv_args", nargs=-1, type=click.UNPROCESSED)
def uv(uv_args):
    """A wrapper around uv commands that adds custom tool authentication through env vars."""
    env = os.environ.copy()
    try:
        pyproject_data = read_toml()
        sources = pyproject_data.get("tool", {}).get("uv", {}).get("sources", {})

        for source_config in sources.values():
            if isinstance(source_config, dict):
                index = source_config.get("index")
                if index:
                    index_env = build_env_with_tool_repository_credentials(index)
                    env.update(index_env)
    except (FileNotFoundError, KeyError) as e:
        raise SystemExit(
            "Error. A valid pyproject.toml file is required. Check that a valid pyproject.toml file exists in the current directory."
        ) from e
    except Exception as e:
        raise SystemExit(f"Error: {e}") from e

    try:
        subprocess.run(  # noqa: S603
            ["uv", *uv_args],  # noqa: S607
            capture_output=False,
            env=env,
            text=True,
            check=True,
        )
    except subprocess.CalledProcessError as e:
        click.secho(f"uv command failed with exit code {e.returncode}", fg="red")
        raise SystemExit(e.returncode) from e


@crewai.command()
@click.argument("type", type=click.Choice(["crew", "flow"]))
@click.argument("name")

@@ -239,11 +282,6 @@ def deploy():
    """Deploy the Crew CLI group."""


@crewai.group()
def tool():
    """Tool Repository related commands."""


@deploy.command(name="create")
@click.option("-y", "--yes", is_flag=True, help="Skip the confirmation prompt")
def deploy_create(yes: bool):

@@ -291,6 +329,11 @@ def deploy_remove(uuid: str | None):
    deploy_cmd.remove_crew(uuid=uuid)


@crewai.group()
def tool():
    """Tool Repository related commands."""


@tool.command(name="create")
@click.argument("handle")
def tool_create(handle: str):

@@ -1,4 +1,6 @@
import json
import tempfile
from logging import getLogger
from pathlib import Path

from pydantic import BaseModel, Field

@@ -12,8 +14,48 @@ from crewai.cli.constants import (
)
from crewai.cli.shared.token_manager import TokenManager

logger = getLogger(__name__)

DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json"


def get_writable_config_path() -> Path | None:
    """
    Find a writable location for the config file with fallback options.

    Tries in order:
    1. Default: ~/.config/crewai/settings.json
    2. Temp directory: /tmp/crewai_settings.json (or OS equivalent)
    3. Current directory: ./crewai_settings.json
    4. In-memory only (returns None)

    Returns:
        Path object for writable config location, or None if no writable location found
    """
    fallback_paths = [
        DEFAULT_CONFIG_PATH,  # Default location
        Path(tempfile.gettempdir()) / "crewai_settings.json",  # Temporary directory
        Path.cwd() / "crewai_settings.json",  # Current working directory
    ]

    for config_path in fallback_paths:
        try:
            config_path.parent.mkdir(parents=True, exist_ok=True)
            test_file = config_path.parent / ".crewai_write_test"
            try:
                test_file.write_text("test")
                test_file.unlink()  # Clean up test file
                logger.info(f"Using config path: {config_path}")
                return config_path
            except Exception:  # noqa: S112
                continue

        except Exception:  # noqa: S112
            continue

    return None


# Settings that are related to the user's account
USER_SETTINGS_KEYS = [
    "tool_repository_username",

@@ -93,16 +135,32 @@ class Settings(BaseModel):
        default=DEFAULT_CLI_SETTINGS["oauth2_domain"],
    )

    def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data):
        """Load Settings from config path"""
        config_path.parent.mkdir(parents=True, exist_ok=True)
    def __init__(self, config_path: Path | None = None, **data):
        """Load Settings from config path with fallback support"""
        if config_path is None:
            config_path = get_writable_config_path()

        # If config_path is None, we're in memory-only mode
        if config_path is None:
            merged_data = {**data}
            # Dummy path for memory-only mode
            super().__init__(config_path=Path("/dev/null"), **merged_data)
            return

        try:
            config_path.parent.mkdir(parents=True, exist_ok=True)
        except Exception:
            merged_data = {**data}
            # Dummy path for memory-only mode
            super().__init__(config_path=Path("/dev/null"), **merged_data)
            return

        file_data = {}
        if config_path.is_file():
            try:
                with config_path.open("r") as f:
                    file_data = json.load(f)
            except json.JSONDecodeError:
            except Exception:
                file_data = {}

        merged_data = {**file_data, **data}

@@ -122,15 +180,22 @@ class Settings(BaseModel):

    def dump(self) -> None:
        """Save current settings to settings.json"""
        if self.config_path.is_file():
            with self.config_path.open("r") as f:
                existing_data = json.load(f)
        else:
            existing_data = {}
        if str(self.config_path) == "/dev/null":
            return

        updated_data = {**existing_data, **self.model_dump(exclude_unset=True)}
        with self.config_path.open("w") as f:
            json.dump(updated_data, f, indent=4)
        try:
            if self.config_path.is_file():
                with self.config_path.open("r") as f:
                    existing_data = json.load(f)
            else:
                existing_data = {}

            updated_data = {**existing_data, **self.model_dump(exclude_unset=True)}
            with self.config_path.open("w") as f:
                json.dump(updated_data, f, indent=4)

        except Exception:  # noqa: S110
            pass

    def _reset_user_settings(self) -> None:
        """Reset all user settings to default values"""

@@ -166,3 +166,13 @@ class PlusAPI:
            json=payload,
            timeout=30,
        )

    def mark_trace_batch_as_failed(
        self, trace_batch_id: str, error_message: str
    ) -> requests.Response:
        return self._make_request(
            "PATCH",
            f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}",
            json={"status": "failed", "failure_reason": error_message},
            timeout=30,
        )

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
    "crewai[tools]>=0.193.2,<1.0.0"
    "crewai[tools]>=0.201.0,<1.0.0"
]

[project.scripts]

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
    "crewai[tools]>=0.193.2,<1.0.0",
    "crewai[tools]>=0.201.0,<1.0.0",
]

[project.scripts]

@@ -5,7 +5,7 @@ description = "Power up your crews with {{folder_name}}"
readme = "README.md"
requires-python = ">=3.10,<3.14"
dependencies = [
    "crewai[tools]>=0.193.2"
    "crewai[tools]>=0.201.0"
]

[tool.crewai]

@@ -12,6 +12,7 @@ from crewai.cli import git
from crewai.cli.command import BaseCommand, PlusAPIMixin
from crewai.cli.config import Settings
from crewai.cli.utils import (
    build_env_with_tool_repository_credentials,
    extract_available_exports,
    get_project_description,
    get_project_name,

@@ -42,8 +43,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
        if project_root.exists():
            click.secho(f"Folder {folder_name} already exists.", fg="red")
            raise SystemExit
        else:
            os.makedirs(project_root)
        os.makedirs(project_root)

        click.secho(f"Creating custom tool {folder_name}...", fg="green", bold=True)

@@ -56,7 +56,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
        os.chdir(project_root)
        try:
            self.login()
            subprocess.run(["git", "init"], check=True)
            subprocess.run(["git", "init"], check=True)  # noqa: S607
            console.print(
                f"[green]Created custom tool [bold]{folder_name}[/bold]. Run [bold]cd {project_root}[/bold] to start working.[/green]"
            )

@@ -76,10 +76,10 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
            raise SystemExit()

        project_name = get_project_name(require=True)
        assert isinstance(project_name, str)
        assert isinstance(project_name, str)  # noqa: S101

        project_version = get_project_version(require=True)
        assert isinstance(project_version, str)
        assert isinstance(project_version, str)  # noqa: S101

        project_description = get_project_description(require=False)
        encoded_tarball = None

@@ -94,8 +94,8 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
        self._print_current_organization()

        with tempfile.TemporaryDirectory() as temp_build_dir:
            subprocess.run(
                ["uv", "build", "--sdist", "--out-dir", temp_build_dir],
            subprocess.run(  # noqa: S603
                ["uv", "build", "--sdist", "--out-dir", temp_build_dir],  # noqa: S607
                check=True,
                capture_output=False,
            )

@@ -146,7 +146,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
                style="bold red",
            )
            raise SystemExit
        elif get_response.status_code != 200:
        if get_response.status_code != 200:
            console.print(
                "Failed to get tool details. Please try again later.", style="bold red"
            )

@@ -196,10 +196,10 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
        else:
            add_package_command.extend(["--index", index, tool_handle])

        add_package_result = subprocess.run(
        add_package_result = subprocess.run(  # noqa: S603
            add_package_command,
            capture_output=False,
            env=self._build_env_with_credentials(repository_handle),
            env=build_env_with_tool_repository_credentials(repository_handle),
            text=True,
            check=True,
        )

@@ -221,20 +221,6 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
        )
        raise SystemExit

    def _build_env_with_credentials(self, repository_handle: str):
        repository_handle = repository_handle.upper().replace("-", "_")
        settings = Settings()

        env = os.environ.copy()
        env[f"UV_INDEX_{repository_handle}_USERNAME"] = str(
            settings.tool_repository_username or ""
        )
        env[f"UV_INDEX_{repository_handle}_PASSWORD"] = str(
            settings.tool_repository_password or ""
        )

        return env

    def _print_current_organization(self) -> None:
        settings = Settings()
        if settings.org_uuid:

@@ -11,6 +11,7 @@ import click
import tomli
from rich.console import Console

from crewai.cli.config import Settings
from crewai.cli.constants import ENV_VARS
from crewai.crew import Crew
from crewai.flow import Flow

@@ -417,6 +418,21 @@ def extract_available_exports(dir_path: str = "src"):
        raise SystemExit(1) from e


def build_env_with_tool_repository_credentials(repository_handle: str):
    repository_handle = repository_handle.upper().replace("-", "_")
    settings = Settings()

    env = os.environ.copy()
    env[f"UV_INDEX_{repository_handle}_USERNAME"] = str(
        settings.tool_repository_username or ""
    )
    env[f"UV_INDEX_{repository_handle}_PASSWORD"] = str(
        settings.tool_repository_password or ""
    )

    return env


def _load_tools_from_init(init_file: Path) -> list[dict[str, Any]]:
    """
    Load and validate tools from a given __init__.py file.

@@ -59,6 +59,7 @@ from crewai.memory.external.external_memory import ExternalMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.short_term.short_term_memory import ShortTermMemory
from crewai.process import Process
from crewai.rag.embeddings.types import EmbedderConfig
from crewai.rag.types import SearchResult
from crewai.security import Fingerprint, SecurityConfig
from crewai.task import Task

@@ -168,7 +169,7 @@ class Crew(FlowTrackable, BaseModel):
        default=None,
        description="An Instance of the ExternalMemory to be used by the Crew",
    )
    embedder: dict | None = Field(
    embedder: EmbedderConfig | None = Field(
        default=None,
        description="Configuration for the embedder to be used for the crew.",
    )

@@ -622,7 +623,8 @@ class Crew(FlowTrackable, BaseModel):
                training_data=training_data, agent_id=str(agent.id)
            )
            CrewTrainingHandler(filename).save_trained_data(
                agent_id=str(agent.role), trained_data=result.model_dump()
                agent_id=str(agent.role),
                trained_data=result.model_dump(),  # type: ignore[arg-type]
            )

            crewai_event_bus.emit(

@@ -1057,7 +1059,10 @@ class Crew(FlowTrackable, BaseModel):
    def _log_task_start(self, task: Task, role: str = "None"):
        if self.output_log_file:
            self._file_handler.log(
                task_name=task.name, task=task.description, agent=role, status="started"
                task_name=task.name,  # type: ignore[arg-type]
                task=task.description,
                agent=role,
                status="started",
            )

    def _update_manager_tools(

@@ -1086,7 +1091,7 @@ class Crew(FlowTrackable, BaseModel):
        role = task.agent.role if task.agent is not None else "None"
        if self.output_log_file:
            self._file_handler.log(
                task_name=task.name,
                task_name=task.name,  # type: ignore[arg-type]
                task=task.description,
                agent=role,
                status="completed",

@@ -200,6 +200,9 @@ class TraceBatchManager:
        if self.event_buffer:
            events_sent_to_backend_status = self._send_events_to_backend()
            if events_sent_to_backend_status == 500:
                self.plus_api.mark_trace_batch_as_failed(
                    self.trace_batch_id, "Error sending events to backend"
                )
                return None
        self._finalize_backend_batch()

@@ -273,10 +276,13 @@ class TraceBatchManager:
            logger.error(
                f"❌ Failed to finalize trace batch: {response.status_code} - {response.text}"
            )
            self.plus_api.mark_trace_batch_as_failed(
                self.trace_batch_id, response.text
            )

        except Exception as e:
            logger.error(f"❌ Error finalizing trace batch: {e}")
            # TODO: send error to app marking as failed
            self.plus_api.mark_trace_batch_as_failed(self.trace_batch_id, str(e))

    def _cleanup_batch_data(self):
        """Clean up batch data after successful finalization to free memory"""

@@ -1,40 +1,39 @@
from crewai.experimental.evaluation import (
    AgentEvaluationResult,
    AgentEvaluator,
    BaseEvaluator,
    EvaluationScore,
    MetricCategory,
    AgentEvaluationResult,
    SemanticQualityEvaluator,
    GoalAlignmentEvaluator,
    ReasoningEfficiencyEvaluator,
    ToolSelectionEvaluator,
    ParameterExtractionEvaluator,
    ToolInvocationEvaluator,
    EvaluationTraceCallback,
    create_evaluation_callbacks,
    AgentEvaluator,
    create_default_evaluator,
    ExperimentRunner,
    ExperimentResults,
    ExperimentResult,
    ExperimentResults,
    ExperimentRunner,
    GoalAlignmentEvaluator,
    MetricCategory,
    ParameterExtractionEvaluator,
    ReasoningEfficiencyEvaluator,
    SemanticQualityEvaluator,
    ToolInvocationEvaluator,
    ToolSelectionEvaluator,
    create_default_evaluator,
    create_evaluation_callbacks,
)


__all__ = [
    "AgentEvaluationResult",
    "AgentEvaluator",
    "BaseEvaluator",
    "EvaluationScore",
    "MetricCategory",
    "AgentEvaluationResult",
    "SemanticQualityEvaluator",
    "GoalAlignmentEvaluator",
    "ReasoningEfficiencyEvaluator",
    "ToolSelectionEvaluator",
    "ParameterExtractionEvaluator",
    "ToolInvocationEvaluator",
    "EvaluationTraceCallback",
    "create_evaluation_callbacks",
    "AgentEvaluator",
    "create_default_evaluator",
    "ExperimentRunner",
    "ExperimentResult",
    "ExperimentResults",
    "ExperimentResult"
]
    "ExperimentRunner",
    "GoalAlignmentEvaluator",
    "MetricCategory",
    "ParameterExtractionEvaluator",
    "ReasoningEfficiencyEvaluator",
    "SemanticQualityEvaluator",
    "ToolInvocationEvaluator",
    "ToolSelectionEvaluator",
    "create_default_evaluator",
    "create_evaluation_callbacks",
]

@@ -1,51 +1,47 @@
from crewai.experimental.evaluation.agent_evaluator import (
    AgentEvaluator,
    create_default_evaluator,
)
from crewai.experimental.evaluation.base_evaluator import (
    AgentEvaluationResult,
    BaseEvaluator,
    EvaluationScore,
    MetricCategory,
    AgentEvaluationResult
)

from crewai.experimental.evaluation.metrics import (
    SemanticQualityEvaluator,
    GoalAlignmentEvaluator,
    ReasoningEfficiencyEvaluator,
    ToolSelectionEvaluator,
    ParameterExtractionEvaluator,
    ToolInvocationEvaluator
)

from crewai.experimental.evaluation.evaluation_listener import (
    EvaluationTraceCallback,
    create_evaluation_callbacks
    create_evaluation_callbacks,
)

from crewai.experimental.evaluation.agent_evaluator import (
    AgentEvaluator,
    create_default_evaluator
)

from crewai.experimental.evaluation.experiment import (
    ExperimentRunner,
    ExperimentResult,
    ExperimentResults,
    ExperimentResult
    ExperimentRunner,
)
from crewai.experimental.evaluation.metrics import (
    GoalAlignmentEvaluator,
    ParameterExtractionEvaluator,
    ReasoningEfficiencyEvaluator,
    SemanticQualityEvaluator,
    ToolInvocationEvaluator,
    ToolSelectionEvaluator,
)

__all__ = [
    "AgentEvaluationResult",
    "AgentEvaluator",
    "BaseEvaluator",
    "EvaluationScore",
    "MetricCategory",
    "AgentEvaluationResult",
    "SemanticQualityEvaluator",
    "GoalAlignmentEvaluator",
    "ReasoningEfficiencyEvaluator",
    "ToolSelectionEvaluator",
    "ParameterExtractionEvaluator",
    "ToolInvocationEvaluator",
    "EvaluationTraceCallback",
    "create_evaluation_callbacks",
    "AgentEvaluator",
    "create_default_evaluator",
    "ExperimentRunner",
    "ExperimentResult",
    "ExperimentResults",
    "ExperimentResult"
    "ExperimentRunner",
    "GoalAlignmentEvaluator",
    "MetricCategory",
    "ParameterExtractionEvaluator",
    "ReasoningEfficiencyEvaluator",
    "SemanticQualityEvaluator",
    "ToolInvocationEvaluator",
    "ToolSelectionEvaluator",
    "create_default_evaluator",
    "create_evaluation_callbacks",
]

@@ -1,34 +1,36 @@
import threading
from typing import Any, Optional
from collections.abc import Sequence
from typing import Any

from crewai.experimental.evaluation.base_evaluator import (
    AgentEvaluationResult,
    AggregationStrategy,
)
from crewai.agent import Agent
from crewai.task import Task
from crewai.experimental.evaluation.evaluation_display import EvaluationDisplayFormatter
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.agent_events import (
    AgentEvaluationStartedEvent,
    AgentEvaluationCompletedEvent,
    AgentEvaluationFailedEvent,
    AgentEvaluationStartedEvent,
    LiteAgentExecutionCompletedEvent,
)
from crewai.experimental.evaluation import BaseEvaluator, create_evaluation_callbacks
from collections.abc import Sequence
from crewai.events.event_bus import crewai_event_bus
from crewai.events.utils.console_formatter import ConsoleFormatter
from crewai.events.types.task_events import TaskCompletedEvent
from crewai.events.types.agent_events import LiteAgentExecutionCompletedEvent
from crewai.events.utils.console_formatter import ConsoleFormatter
from crewai.experimental.evaluation.base_evaluator import (
    AgentAggregatedEvaluationResult,
    AgentEvaluationResult,
    AggregationStrategy,
    BaseEvaluator,
    EvaluationScore,
    MetricCategory,
)
from crewai.experimental.evaluation.evaluation_display import EvaluationDisplayFormatter
from crewai.experimental.evaluation.evaluation_listener import (
    create_evaluation_callbacks,
)
from crewai.task import Task


class ExecutionState:
    current_agent_id: Optional[str] = None
    current_task_id: Optional[str] = None
    current_agent_id: str | None = None
    current_task_id: str | None = None

    def __init__(self):
        self.traces = {}

@@ -40,10 +42,10 @@ class ExecutionState:
class AgentEvaluator:
    def __init__(
        self,
        agents: list[Agent],
        agents: list[Agent] | list[BaseAgent],
        evaluators: Sequence[BaseEvaluator] | None = None,
    ):
        self.agents: list[Agent] = agents
        self.agents: list[Agent] | list[BaseAgent] = agents
        self.evaluators: Sequence[BaseEvaluator] | None = evaluators

        self.callback = create_evaluation_callbacks()

@@ -75,7 +77,8 @@ class AgentEvaluator:
        )

    def _handle_task_completed(self, source: Any, event: TaskCompletedEvent) -> None:
        assert event.task is not None
        if event.task is None:
            raise ValueError("TaskCompletedEvent must have a task")
        agent = event.task.agent
        if (
            agent

@@ -92,9 +95,8 @@ class AgentEvaluator:
        state.current_agent_id = str(agent.id)
        state.current_task_id = str(event.task.id)

        assert (
            state.current_agent_id is not None and state.current_task_id is not None
        )
        if state.current_agent_id is None or state.current_task_id is None:
            raise ValueError("Agent ID and Task ID must not be None")
        trace = self.callback.get_trace(
            state.current_agent_id, state.current_task_id
        )

@@ -146,9 +148,8 @@ class AgentEvaluator:
        if not target_agent:
            return

        assert (
            state.current_agent_id is not None and state.current_task_id is not None
        )
        if state.current_agent_id is None or state.current_task_id is None:
            raise ValueError("Agent ID and Task ID must not be None")
        trace = self.callback.get_trace(
            state.current_agent_id, state.current_task_id
        )

@@ -244,7 +245,7 @@ class AgentEvaluator:

    def evaluate(
        self,
        agent: Agent,
        agent: Agent | BaseAgent,
        execution_trace: dict[str, Any],
        final_output: Any,
        state: ExecutionState,

@@ -255,7 +256,8 @@ class AgentEvaluator:
            task_id=state.current_task_id or (str(task.id) if task else "unknown_task"),
        )

        assert self.evaluators is not None
        if self.evaluators is None:
            raise ValueError("Evaluators must be initialized")
        task_id = str(task.id) if task else None
        for evaluator in self.evaluators:
            try:

@@ -276,7 +278,7 @@ class AgentEvaluator:
                    metric_category=evaluator.metric_category,
                    score=score,
                )
            except Exception as e:
            except Exception as e:  # noqa: PERF203
                self.emit_evaluation_failed_event(
                    agent_role=agent.role,
                    agent_id=str(agent.id),

@@ -284,7 +286,7 @@ class AgentEvaluator:
                    error=str(e),
                )
                self.console_formatter.print(
                    f"Error in {evaluator.metric_category.value} evaluator: {str(e)}"
                    f"Error in {evaluator.metric_category.value} evaluator: {e!s}"
                )

        return result

@@ -337,14 +339,14 @@ class AgentEvaluator:
    )


def create_default_evaluator(agents: list[Agent], llm: None = None):
def create_default_evaluator(agents: list[Agent] | list[BaseAgent], llm: None = None):
    from crewai.experimental.evaluation import (
        GoalAlignmentEvaluator,
        SemanticQualityEvaluator,
        ToolSelectionEvaluator,
        ParameterExtractionEvaluator,
        ToolInvocationEvaluator,
        ReasoningEfficiencyEvaluator,
        SemanticQualityEvaluator,
        ToolInvocationEvaluator,
        ToolSelectionEvaluator,
    )

    evaluators = [
@@ -1,15 +1,17 @@
 import abc
 import enum
+from enum import Enum
-from typing import Any, Dict, List, Optional
+from typing import Any

 from pydantic import BaseModel, Field

 from crewai.agent import Agent
-from crewai.task import Task
+from crewai.agents.agent_builder.base_agent import BaseAgent
 from crewai.llm import BaseLLM
+from crewai.task import Task
 from crewai.utilities.llm_utils import create_llm


 class MetricCategory(enum.Enum):
     GOAL_ALIGNMENT = "goal_alignment"
     SEMANTIC_QUALITY = "semantic_quality"
@@ -19,7 +21,7 @@ class MetricCategory(enum.Enum):
     TOOL_INVOCATION = "tool_invocation"

     def title(self):
-        return self.value.replace('_', ' ').title()
+        return self.value.replace("_", " ").title()


 class EvaluationScore(BaseModel):
@@ -27,15 +29,13 @@ class EvaluationScore(BaseModel):
         default=5.0,
         description="Numeric score from 0-10 where 0 is worst and 10 is best, None if not applicable",
         ge=0.0,
-        le=10.0
+        le=10.0,
     )
     feedback: str = Field(
-        default="",
-        description="Detailed feedback explaining the evaluation score"
+        default="", description="Detailed feedback explaining the evaluation score"
     )
     raw_response: str | None = Field(
-        default=None,
-        description="Raw response from the evaluator (e.g., LLM)"
+        default=None, description="Raw response from the evaluator (e.g., LLM)"
     )

     def __str__(self) -> str:
@@ -56,8 +56,8 @@ class BaseEvaluator(abc.ABC):
     @abc.abstractmethod
     def evaluate(
         self,
-        agent: Agent,
-        execution_trace: Dict[str, Any],
+        agent: Agent | BaseAgent,
+        execution_trace: dict[str, Any],
         final_output: Any,
         task: Task | None = None,
     ) -> EvaluationScore:
@@ -67,9 +67,8 @@
 class AgentEvaluationResult(BaseModel):
     agent_id: str = Field(description="ID of the evaluated agent")
     task_id: str = Field(description="ID of the task that was executed")
-    metrics: Dict[MetricCategory, EvaluationScore] = Field(
-        default_factory=dict,
-        description="Evaluation scores for each metric category"
+    metrics: dict[MetricCategory, EvaluationScore] = Field(
+        default_factory=dict, description="Evaluation scores for each metric category"
     )


@@ -81,33 +80,23 @@ class AggregationStrategy(Enum):


 class AgentAggregatedEvaluationResult(BaseModel):
-    agent_id: str = Field(
-        default="",
-        description="ID of the agent"
-    )
-    agent_role: str = Field(
-        default="",
-        description="Role of the agent"
-    )
+    agent_id: str = Field(default="", description="ID of the agent")
+    agent_role: str = Field(default="", description="Role of the agent")
     task_count: int = Field(
-        default=0,
-        description="Number of tasks included in this aggregation"
+        default=0, description="Number of tasks included in this aggregation"
     )
     aggregation_strategy: AggregationStrategy = Field(
         default=AggregationStrategy.SIMPLE_AVERAGE,
-        description="Strategy used for aggregation"
+        description="Strategy used for aggregation",
     )
-    metrics: Dict[MetricCategory, EvaluationScore] = Field(
-        default_factory=dict,
-        description="Aggregated metrics across all tasks"
+    metrics: dict[MetricCategory, EvaluationScore] = Field(
+        default_factory=dict, description="Aggregated metrics across all tasks"
     )
-    task_results: List[str] = Field(
-        default_factory=list,
-        description="IDs of tasks included in this aggregation"
+    task_results: list[str] = Field(
+        default_factory=list, description="IDs of tasks included in this aggregation"
     )
-    overall_score: Optional[float] = Field(
-        default=None,
-        description="Overall score for this agent"
+    overall_score: float | None = Field(
+        default=None, description="Overall score for this agent"
     )

     def __str__(self) -> str:
@@ -119,7 +108,7 @@ class AgentAggregatedEvaluationResult(BaseModel):
             result += f"\n\n- {category.value.upper()}: {score.score}/10\n"

             if score.feedback:
-                detailed_feedback = "\n ".join(score.feedback.split('\n'))
+                detailed_feedback = "\n ".join(score.feedback.split("\n"))
                 result += f" {detailed_feedback}\n"

-        return result
+        return result
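Most of the `base_evaluator` churn above is typing modernization rather than behavior change: PEP 585 builtin generics (`dict`, `list`) replace `typing.Dict`/`typing.List`, and PEP 604 unions (`X | None`) replace `Optional[X]`. A small sketch of the before/after spellings; the variable names are illustrative only:

```python
# Illustrative only: the typing spellings this diff migrates between.
# Builtin generics need Python 3.9+; `X | None` unions need Python 3.10+.
from typing import Any, Dict, List, Optional  # old-style imports

# Before (what the removed lines used):
old_metrics: Dict[str, Any] = {}
old_task_results: List[str] = []
old_overall_score: Optional[float] = None

# After (what the added lines use), equivalent at runtime and for type checkers:
new_metrics: dict[str, Any] = {}
new_task_results: list[str] = []
new_overall_score: float | None = None
```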
@@ -1,16 +1,18 @@
 from collections import defaultdict
-from typing import Dict, Any, List
-from rich.table import Table
-from rich.box import HEAVY_EDGE, ROUNDED
+from collections.abc import Sequence
+from typing import Any
+
+from rich.box import HEAVY_EDGE, ROUNDED
+from rich.table import Table
+
+from crewai.events.utils.console_formatter import ConsoleFormatter
 from crewai.experimental.evaluation.base_evaluator import (
     AgentAggregatedEvaluationResult,
-    AggregationStrategy,
     AgentEvaluationResult,
+    AggregationStrategy,
+    EvaluationScore,
     MetricCategory,
 )
-from crewai.experimental.evaluation import EvaluationScore
-from crewai.events.utils.console_formatter import ConsoleFormatter
 from crewai.utilities.llm_utils import create_llm

@@ -19,7 +21,7 @@ class EvaluationDisplayFormatter:
         self.console_formatter = ConsoleFormatter()

     def display_evaluation_with_feedback(
-        self, iterations_results: Dict[int, Dict[str, List[Any]]]
+        self, iterations_results: dict[int, dict[str, list[Any]]]
     ):
         if not iterations_results:
             self.console_formatter.print(
@@ -99,7 +101,7 @@ class EvaluationDisplayFormatter:

     def display_summary_results(
         self,
-        iterations_results: Dict[int, Dict[str, List[AgentAggregatedEvaluationResult]]],
+        iterations_results: dict[int, dict[str, list[AgentEvaluationResult]]],
     ):
         if not iterations_results:
             self.console_formatter.print(
@@ -280,7 +282,7 @@ class EvaluationDisplayFormatter:
                 feedback_summary = feedbacks[0]

             aggregated_metrics[category] = EvaluationScore(
-                score=avg_score, feedback=feedback_summary
+                score=avg_score, feedback=feedback_summary or ""
             )

         overall_score = None
@@ -304,25 +306,25 @@
         self,
         agent_role: str,
         metric: str,
-        feedbacks: List[str],
-        scores: List[float | None],
+        feedbacks: list[str],
+        scores: list[float | None],
         strategy: AggregationStrategy,
     ) -> str:
         if len(feedbacks) <= 2 and all(len(fb) < 200 for fb in feedbacks):
             return "\n\n".join(
-                [f"Feedback {i+1}: {fb}" for i, fb in enumerate(feedbacks)]
+                [f"Feedback {i + 1}: {fb}" for i, fb in enumerate(feedbacks)]
             )

         try:
             llm = create_llm()

             formatted_feedbacks = []
-            for i, (feedback, score) in enumerate(zip(feedbacks, scores)):
+            for i, (feedback, score) in enumerate(zip(feedbacks, scores, strict=False)):
                 if len(feedback) > 500:
                     feedback = feedback[:500] + "..."
                 score_text = f"{score:.1f}" if score is not None else "N/A"
                 formatted_feedbacks.append(
-                    f"Feedback #{i+1} (Score: {score_text}):\n{feedback}"
+                    f"Feedback #{i + 1} (Score: {score_text}):\n{feedback}"
                 )

             all_feedbacks = "\n\n" + "\n\n---\n\n".join(formatted_feedbacks)
@@ -365,10 +367,9 @@ class EvaluationDisplayFormatter:
                 """,
                 },
             ]
-            assert llm is not None
-            response = llm.call(prompt)
-
-            return response
+            if llm is None:
+                raise ValueError("LLM must be initialized")
+            return llm.call(prompt)

         except Exception:
             return "Synthesized from multiple tasks: " + "\n\n".join(
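One hunk above pins down `zip(feedbacks, scores, strict=False)`. A small sketch of what the `strict` flag (added in Python 3.10) changes; the sample data is made up:

```python
# Made-up data; shows the zip(strict=...) behavior the diff makes explicit.
feedbacks = ["clear and concise", "slightly verbose", "off-topic tangent"]
scores = [8.0, 6.5]  # deliberately one element short

# strict=False keeps the historical behavior: silently truncate to the shorter input.
print(list(zip(feedbacks, scores, strict=False)))

# strict=True would instead raise ValueError on the length mismatch.
try:
    list(zip(feedbacks, scores, strict=True))
except ValueError as exc:
    print(f"length mismatch detected: {exc}")
```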
@@ -1,26 +1,25 @@
-from datetime import datetime
-from typing import Any, Dict, Optional
+from collections.abc import Sequence
+from datetime import datetime
+from typing import Any

 from crewai.agent import Agent
-from crewai.task import Task
+from crewai.agents.agent_builder.base_agent import BaseAgent
 from crewai.events.base_event_listener import BaseEventListener
 from crewai.events.event_bus import CrewAIEventsBus
 from crewai.events.types.agent_events import (
-    AgentExecutionStartedEvent,
     AgentExecutionCompletedEvent,
-    LiteAgentExecutionStartedEvent,
+    AgentExecutionStartedEvent,
     LiteAgentExecutionCompletedEvent,
+    LiteAgentExecutionStartedEvent,
 )
+from crewai.events.types.llm_events import LLMCallCompletedEvent, LLMCallStartedEvent
 from crewai.events.types.tool_usage_events import (
-    ToolUsageFinishedEvent,
-    ToolUsageErrorEvent,
     ToolExecutionErrorEvent,
     ToolSelectionErrorEvent,
+    ToolUsageErrorEvent,
+    ToolUsageFinishedEvent,
     ToolValidateInputErrorEvent,
 )
-from crewai.events.types.llm_events import LLMCallStartedEvent, LLMCallCompletedEvent
+from crewai.task import Task


 class EvaluationTraceCallback(BaseEventListener):
@@ -136,7 +135,7 @@ class EvaluationTraceCallback(BaseEventListener):
     def _init_trace(self, trace_key: str, **kwargs: Any):
         self.traces[trace_key] = kwargs

-    def on_agent_start(self, agent: Agent, task: Task):
+    def on_agent_start(self, agent: BaseAgent, task: Task):
         self.current_agent_id = agent.id
         self.current_task_id = task.id

@@ -151,7 +150,7 @@ class EvaluationTraceCallback(BaseEventListener):
             final_output=None,
         )

-    def on_agent_finish(self, agent: Agent, task: Task, output: Any):
+    def on_agent_finish(self, agent: BaseAgent, task: Task, output: Any):
         trace_key = f"{agent.id}_{task.id}"
         if trace_key in self.traces:
             self.traces[trace_key]["final_output"] = output
@@ -253,7 +252,7 @@ class EvaluationTraceCallback(BaseEventListener):
         if hasattr(self, "current_llm_call"):
             self.current_llm_call = {}

-    def get_trace(self, agent_id: str, task_id: str) -> Optional[Dict[str, Any]]:
+    def get_trace(self, agent_id: str, task_id: str) -> dict[str, Any] | None:
         trace_key = f"{agent_id}_{task_id}"
         return self.traces.get(trace_key)

@@ -1,8 +1,7 @@
+from crewai.experimental.evaluation.experiment.result import (
+    ExperimentResult,
+    ExperimentResults,
+)
 from crewai.experimental.evaluation.experiment.runner import ExperimentRunner
-from crewai.experimental.evaluation.experiment.result import ExperimentResults, ExperimentResult

-__all__ = [
-    "ExperimentRunner",
-    "ExperimentResults",
-    "ExperimentResult"
-]
+__all__ = ["ExperimentResult", "ExperimentResults", "ExperimentRunner"]

@@ -2,45 +2,60 @@ import json
 import os
 from datetime import datetime, timezone
 from typing import Any

 from pydantic import BaseModel


 class ExperimentResult(BaseModel):
     identifier: str
     inputs: dict[str, Any]
-    score: int | dict[str, int | float]
-    expected_score: int | dict[str, int | float]
+    score: float | dict[str, float]
+    expected_score: float | dict[str, float]
     passed: bool
     agent_evaluations: dict[str, Any] | None = None


 class ExperimentResults:
-    def __init__(self, results: list[ExperimentResult], metadata: dict[str, Any] | None = None):
+    def __init__(
+        self, results: list[ExperimentResult], metadata: dict[str, Any] | None = None
+    ):
         self.results = results
         self.metadata = metadata or {}
         self.timestamp = datetime.now(timezone.utc)

-        from crewai.experimental.evaluation.experiment.result_display import ExperimentResultsDisplay
+        from crewai.experimental.evaluation.experiment.result_display import (
+            ExperimentResultsDisplay,
+        )

         self.display = ExperimentResultsDisplay()

     def to_json(self, filepath: str | None = None) -> dict[str, Any]:
         data = {
             "timestamp": self.timestamp.isoformat(),
             "metadata": self.metadata,
-            "results": [r.model_dump(exclude={"agent_evaluations"}) for r in self.results]
+            "results": [
+                r.model_dump(exclude={"agent_evaluations"}) for r in self.results
+            ],
         }

         if filepath:
-            with open(filepath, 'w') as f:
+            with open(filepath, "w") as f:
                 json.dump(data, f, indent=2)
             self.display.console.print(f"[green]Results saved to {filepath}[/green]")

         return data

-    def compare_with_baseline(self, baseline_filepath: str, save_current: bool = True, print_summary: bool = False) -> dict[str, Any]:
+    def compare_with_baseline(
+        self,
+        baseline_filepath: str,
+        save_current: bool = True,
+        print_summary: bool = False,
+    ) -> dict[str, Any]:
         baseline_runs = []

         if os.path.exists(baseline_filepath) and os.path.getsize(baseline_filepath) > 0:
             try:
-                with open(baseline_filepath, 'r') as f:
+                with open(baseline_filepath, "r") as f:
                     baseline_data = json.load(f)

                 if isinstance(baseline_data, dict) and "timestamp" in baseline_data:
@@ -48,14 +63,18 @@ class ExperimentResults:
                 elif isinstance(baseline_data, list):
                     baseline_runs = baseline_data
             except (json.JSONDecodeError, FileNotFoundError) as e:
-                self.display.console.print(f"[yellow]Warning: Could not load baseline file: {str(e)}[/yellow]")
+                self.display.console.print(
+                    f"[yellow]Warning: Could not load baseline file: {e!s}[/yellow]"
+                )

         if not baseline_runs:
             if save_current:
                 current_data = self.to_json()
-                with open(baseline_filepath, 'w') as f:
+                with open(baseline_filepath, "w") as f:
                     json.dump([current_data], f, indent=2)
-                self.display.console.print(f"[green]Saved current results as new baseline to {baseline_filepath}[/green]")
+                self.display.console.print(
+                    f"[green]Saved current results as new baseline to {baseline_filepath}[/green]"
+                )
             return {"is_baseline": True, "changes": {}}

         baseline_runs.sort(key=lambda x: x.get("timestamp", ""), reverse=True)
@@ -69,9 +88,11 @@ class ExperimentResults:
         if save_current:
             current_data = self.to_json()
             baseline_runs.append(current_data)
-            with open(baseline_filepath, 'w') as f:
+            with open(baseline_filepath, "w") as f:
                 json.dump(baseline_runs, f, indent=2)
-            self.display.console.print(f"[green]Added current results to baseline file {baseline_filepath}[/green]")
+            self.display.console.print(
+                f"[green]Added current results to baseline file {baseline_filepath}[/green]"
+            )

         return comparison

@@ -118,5 +139,5 @@ class ExperimentResults:
             "new_tests": new_tests,
             "missing_tests": missing_tests,
             "total_compared": len(improved) + len(regressed) + len(unchanged),
-            "baseline_timestamp": baseline_run.get("timestamp", "unknown")
+            "baseline_timestamp": baseline_run.get("timestamp", "unknown"),
         }
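A hedged usage sketch for the `ExperimentResult`/`ExperimentResults` API as it reads after this diff; the file names and inputs here are hypothetical:

```python
# Hypothetical usage based only on the signatures visible in the diff above.
from crewai.experimental.evaluation.experiment.result import (
    ExperimentResult,
    ExperimentResults,
)

results = ExperimentResults(
    results=[
        ExperimentResult(
            identifier="pricing-case-1",
            inputs={"topic": "pricing"},
            score=7.5,  # now typed float | dict[str, float]
            expected_score=7.0,
            passed=True,
        )
    ],
    metadata={"run": "nightly"},
)

results.to_json("run.json")  # persist this run as JSON
# Append to a rolling baseline file and get an improved/regressed breakdown:
comparison = results.compare_with_baseline("baseline.json", save_current=True)
```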
@@ -1,9 +1,12 @@
-from typing import Dict, Any
+from typing import Any

 from rich.console import Console
-from rich.table import Table
 from rich.panel import Panel
+from rich.table import Table

 from crewai.experimental.evaluation.experiment.result import ExperimentResults


 class ExperimentResultsDisplay:
     def __init__(self):
         self.console = Console()
@@ -19,13 +22,19 @@ class ExperimentResultsDisplay:
         table.add_row("Total Test Cases", str(total))
         table.add_row("Passed", str(passed))
         table.add_row("Failed", str(total - passed))
-        table.add_row("Success Rate", f"{(passed / total * 100):.1f}%" if total > 0 else "N/A")
+        table.add_row(
+            "Success Rate", f"{(passed / total * 100):.1f}%" if total > 0 else "N/A"
+        )

         self.console.print(table)

-    def comparison_summary(self, comparison: Dict[str, Any], baseline_timestamp: str):
-        self.console.print(Panel(f"[bold]Comparison with baseline run from {baseline_timestamp}[/bold]",
-                          expand=False))
+    def comparison_summary(self, comparison: dict[str, Any], baseline_timestamp: str):
+        self.console.print(
+            Panel(
+                f"[bold]Comparison with baseline run from {baseline_timestamp}[/bold]",
+                expand=False,
+            )
+        )

         table = Table(title="Results Comparison")
         table.add_column("Metric", style="cyan")
@@ -34,7 +43,9 @@ class ExperimentResultsDisplay:

         improved = comparison.get("improved", [])
         if improved:
-            details = ", ".join([f"{test_identifier}" for test_identifier in improved[:3]])
+            details = ", ".join(
+                [f"{test_identifier}" for test_identifier in improved[:3]]
+            )
             if len(improved) > 3:
                 details += f" and {len(improved) - 3} more"
             table.add_row("✅ Improved", str(len(improved)), details)
@@ -43,7 +54,9 @@ class ExperimentResultsDisplay:

         regressed = comparison.get("regressed", [])
         if regressed:
-            details = ", ".join([f"{test_identifier}" for test_identifier in regressed[:3]])
+            details = ", ".join(
+                [f"{test_identifier}" for test_identifier in regressed[:3]]
+            )
             if len(regressed) > 3:
                 details += f" and {len(regressed) - 3} more"
             table.add_row("❌ Regressed", str(len(regressed)), details, style="red")
@@ -58,13 +71,13 @@ class ExperimentResultsDisplay:
             details = ", ".join(new_tests[:3])
             if len(new_tests) > 3:
                 details += f" and {len(new_tests) - 3} more"
-            table.add_row("➕ New Tests", str(len(new_tests)), details)
+            table.add_row("+ New Tests", str(len(new_tests)), details)

         missing_tests = comparison.get("missing_tests", [])
         if missing_tests:
             details = ", ".join(missing_tests[:3])
             if len(missing_tests) > 3:
                 details += f" and {len(missing_tests) - 3} more"
-            table.add_row("➖ Missing Tests", str(len(missing_tests)), details)
+            table.add_row("- Missing Tests", str(len(missing_tests)), details)

         self.console.print(table)

@@ -2,11 +2,20 @@ from collections import defaultdict
 from hashlib import md5
 from typing import Any

-from crewai import Crew, Agent
+from crewai import Agent, Crew
+from crewai.agents.agent_builder.base_agent import BaseAgent
 from crewai.experimental.evaluation import AgentEvaluator, create_default_evaluator
-from crewai.experimental.evaluation.experiment.result_display import ExperimentResultsDisplay
-from crewai.experimental.evaluation.experiment.result import ExperimentResults, ExperimentResult
-from crewai.experimental.evaluation.evaluation_display import AgentAggregatedEvaluationResult
+from crewai.experimental.evaluation.evaluation_display import (
+    AgentAggregatedEvaluationResult,
+)
+from crewai.experimental.evaluation.experiment.result import (
+    ExperimentResult,
+    ExperimentResults,
+)
+from crewai.experimental.evaluation.experiment.result_display import (
+    ExperimentResultsDisplay,
+)


 class ExperimentRunner:
     def __init__(self, dataset: list[dict[str, Any]]):
@@ -14,11 +23,17 @@ class ExperimentRunner:
         self.evaluator: AgentEvaluator | None = None
         self.display = ExperimentResultsDisplay()

-    def run(self, crew: Crew | None = None, agents: list[Agent] | None = None, print_summary: bool = False) -> ExperimentResults:
+    def run(
+        self,
+        crew: Crew | None = None,
+        agents: list[Agent] | list[BaseAgent] | None = None,
+        print_summary: bool = False,
+    ) -> ExperimentResults:
         if crew and not agents:
             agents = crew.agents

-        assert agents is not None
+        if agents is None:
+            raise ValueError("Agents must be provided either directly or via a crew")
         self.evaluator = create_default_evaluator(agents=agents)

         results = []
@@ -35,21 +50,37 @@ class ExperimentRunner:

         return experiment_results

-    def _run_test_case(self, test_case: dict[str, Any], agents: list[Agent], crew: Crew | None = None) -> ExperimentResult:
+    def _run_test_case(
+        self,
+        test_case: dict[str, Any],
+        agents: list[Agent] | list[BaseAgent],
+        crew: Crew | None = None,
+    ) -> ExperimentResult:
         inputs = test_case["inputs"]
         expected_score = test_case["expected_score"]
-        identifier = test_case.get("identifier") or md5(str(test_case).encode(), usedforsecurity=False).hexdigest()
+        identifier = (
+            test_case.get("identifier")
+            or md5(str(test_case).encode(), usedforsecurity=False).hexdigest()
+        )

         try:
-            self.display.console.print(f"[dim]Running crew with input: {str(inputs)[:50]}...[/dim]")
+            self.display.console.print(
+                f"[dim]Running crew with input: {str(inputs)[:50]}...[/dim]"
+            )
             self.display.console.print("\n")
             if crew:
                 crew.kickoff(inputs=inputs)
             else:
                 for agent in agents:
-                    agent.kickoff(**inputs)
+                    if isinstance(agent, Agent):
+                        agent.kickoff(**inputs)
+                    else:
+                        raise TypeError(
+                            f"Agent {agent} is not an instance of Agent and cannot be kicked off directly"
+                        )

-            assert self.evaluator is not None
+            if self.evaluator is None:
+                raise ValueError("Evaluator must be initialized")
             agent_evaluations = self.evaluator.get_agent_evaluation()

             actual_score = self._extract_scores(agent_evaluations)
@@ -61,35 +92,38 @@ class ExperimentRunner:
                 score=actual_score,
                 expected_score=expected_score,
                 passed=passed,
-                agent_evaluations=agent_evaluations
+                agent_evaluations=agent_evaluations,
             )

         except Exception as e:
-            self.display.console.print(f"[red]Error running test case: {str(e)}[/red]")
+            self.display.console.print(f"[red]Error running test case: {e!s}[/red]")
             return ExperimentResult(
                 identifier=identifier,
                 inputs=inputs,
-                score=0,
+                score=0.0,
                 expected_score=expected_score,
-                passed=False
+                passed=False,
             )

-    def _extract_scores(self, agent_evaluations: dict[str, AgentAggregatedEvaluationResult]) -> float | dict[str, float]:
+    def _extract_scores(
+        self, agent_evaluations: dict[str, AgentAggregatedEvaluationResult]
+    ) -> float | dict[str, float]:
         all_scores: dict[str, list[float]] = defaultdict(list)
         for evaluation in agent_evaluations.values():
             for metric_name, score in evaluation.metrics.items():
                 if score.score is not None:
                     all_scores[metric_name.value].append(score.score)

-        avg_scores = {m: sum(s)/len(s) for m, s in all_scores.items()}
+        avg_scores = {m: sum(s) / len(s) for m, s in all_scores.items()}

         if len(avg_scores) == 1:
-            return list(avg_scores.values())[0]
+            return next(iter(avg_scores.values()))

         return avg_scores

-    def _assert_scores(self, expected: float | dict[str, float],
-                      actual: float | dict[str, float]) -> bool:
+    def _assert_scores(
+        self, expected: float | dict[str, float], actual: float | dict[str, float]
+    ) -> bool:
         """
         Compare expected and actual scores, and return whether the test case passed.

@@ -122,4 +156,4 @@ class ExperimentRunner:
         # All matching keys must have actual >= expected
         return all(actual[key] >= expected[key] for key in matching_keys)

-        return False
+        return False
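The `_assert_scores` hunks above only reflow the signature, but the visible tail shows the pass rule: a case passes when every compared actual score meets or exceeds its expected counterpart. A standalone re-implementation sketch of that rule, not the library function itself; the scalar branch is an assumption from the surrounding types:

```python
# Standalone sketch of the pass rule visible in the hunk above; not the library API.
def scores_pass(
    expected: float | dict[str, float], actual: float | dict[str, float]
) -> bool:
    if isinstance(expected, dict) and isinstance(actual, dict):
        matching_keys = expected.keys() & actual.keys()
        # All matching keys must have actual >= expected.
        return bool(matching_keys) and all(
            actual[key] >= expected[key] for key in matching_keys
        )
    if isinstance(expected, (int, float)) and isinstance(actual, (int, float)):
        return actual >= expected  # assumed scalar behavior
    return False


print(scores_pass({"goal_alignment": 7.0}, {"goal_alignment": 7.5}))  # True
print(scores_pass(8.0, 7.9))  # False
```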
@@ -13,11 +13,11 @@ def extract_json_from_llm_response(text: str) -> dict[str, Any]:

     json_patterns = [
         # Standard markdown code blocks with json
-        r'```json\s*([\s\S]*?)\s*```',
+        r"```json\s*([\s\S]*?)\s*```",
         # Code blocks without language specifier
-        r'```\s*([\s\S]*?)\s*```',
+        r"```\s*([\s\S]*?)\s*```",
         # Inline code with JSON
-        r'`([{\\[].*[}\]])`',
+        r"`([{\\[].*[}\]])`",
     ]

     for pattern in json_patterns:
@@ -25,6 +25,6 @@ def extract_json_from_llm_response(text: str) -> dict[str, Any]:
         for match in matches:
             try:
                 return json.loads(match.strip())
-            except json.JSONDecodeError:
+            except json.JSONDecodeError:  # noqa: PERF203
                 continue
     raise ValueError("No valid JSON found in the response")
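The `extract_json_from_llm_response` hunks only re-quote the regex patterns and tag the per-iteration `try/except` for ruff's PERF203 rule; behavior is unchanged. A usage sketch with an invented LLM reply:

```python
# Usage sketch: the parser tries fenced json blocks, bare fenced blocks, then
# inline code, and raises ValueError if nothing parses as JSON.
from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response

reply = 'Here is my evaluation:\n```json\n{"score": 8, "feedback": "solid"}\n```'
print(extract_json_from_llm_response(reply))  # {'score': 8, 'feedback': 'solid'}
```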
@@ -1,26 +1,21 @@
+from crewai.experimental.evaluation.metrics.goal_metrics import GoalAlignmentEvaluator
 from crewai.experimental.evaluation.metrics.reasoning_metrics import (
-    ReasoningEfficiencyEvaluator
+    ReasoningEfficiencyEvaluator,
 )
-
-from crewai.experimental.evaluation.metrics.tools_metrics import (
-    ToolSelectionEvaluator,
-    ParameterExtractionEvaluator,
-    ToolInvocationEvaluator
-)
-
-from crewai.experimental.evaluation.metrics.goal_metrics import (
-    GoalAlignmentEvaluator
-)
-
 from crewai.experimental.evaluation.metrics.semantic_quality_metrics import (
-    SemanticQualityEvaluator
+    SemanticQualityEvaluator,
 )
+from crewai.experimental.evaluation.metrics.tools_metrics import (
+    ParameterExtractionEvaluator,
+    ToolInvocationEvaluator,
+    ToolSelectionEvaluator,
+)

 __all__ = [
-    "ReasoningEfficiencyEvaluator",
-    "ToolSelectionEvaluator",
-    "ParameterExtractionEvaluator",
-    "ToolInvocationEvaluator",
     "GoalAlignmentEvaluator",
-    "SemanticQualityEvaluator"
-]
+    "ParameterExtractionEvaluator",
+    "ReasoningEfficiencyEvaluator",
+    "SemanticQualityEvaluator",
+    "ToolInvocationEvaluator",
+    "ToolSelectionEvaluator",
+]

@@ -1,10 +1,15 @@
-from typing import Any, Dict
+from typing import Any

 from crewai.agent import Agent
-from crewai.task import Task
-
-from crewai.experimental.evaluation.base_evaluator import BaseEvaluator, EvaluationScore, MetricCategory
+from crewai.agents.agent_builder.base_agent import BaseAgent
+from crewai.experimental.evaluation.base_evaluator import (
+    BaseEvaluator,
+    EvaluationScore,
+    MetricCategory,
+)
 from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
+from crewai.task import Task
+

 class GoalAlignmentEvaluator(BaseEvaluator):
     @property
@@ -13,8 +18,8 @@ class GoalAlignmentEvaluator(BaseEvaluator):

     def evaluate(
         self,
-        agent: Agent,
-        execution_trace: Dict[str, Any],
+        agent: Agent | BaseAgent,
+        execution_trace: dict[str, Any],
         final_output: Any,
         task: Task | None = None,
     ) -> EvaluationScore:
@@ -23,7 +28,9 @@ class GoalAlignmentEvaluator(BaseEvaluator):
             task_context = f"Task description: {task.description}\nExpected output: {task.expected_output}\n"

         prompt = [
-            {"role": "system", "content": """You are an expert evaluator assessing how well an AI agent's output aligns with its assigned task goal.
+            {
+                "role": "system",
+                "content": """You are an expert evaluator assessing how well an AI agent's output aligns with its assigned task goal.

 Score the agent's goal alignment on a scale from 0-10 where:
 - 0: Complete misalignment, agent did not understand or attempt the task goal
@@ -37,8 +44,11 @@ Consider:
 4. Did the agent provide all requested information or deliverables?

 Return your evaluation as JSON with fields 'score' (number) and 'feedback' (string).
-"""},
-            {"role": "user", "content": f"""
+""",
+            },
+            {
+                "role": "user",
+                "content": f"""
 Agent role: {agent.role}
 Agent goal: {agent.goal}
 {task_context}
@@ -47,23 +57,26 @@ Agent's final output:
 {final_output}

 Evaluate how well the agent's output aligns with the assigned task goal.
-"""}
+""",
+            },
         ]
-        assert self.llm is not None
+        if self.llm is None:
+            raise ValueError("LLM must be initialized")
         response = self.llm.call(prompt)

         try:
             evaluation_data: dict[str, Any] = extract_json_from_llm_response(response)
-            assert evaluation_data is not None
+            if evaluation_data is None:
+                raise ValueError("Failed to extract evaluation data from LLM response")

             return EvaluationScore(
                 score=evaluation_data.get("score", 0),
                 feedback=evaluation_data.get("feedback", response),
-                raw_response=response
+                raw_response=response,
             )
         except Exception:
             return EvaluationScore(
                 score=None,
                 feedback=f"Failed to parse evaluation. Raw response: {response}",
-                raw_response=response
+                raw_response=response,
             )

@@ -8,18 +8,24 @@ This module provides evaluator implementations for:

 import logging
 import re
+from collections.abc import Sequence
 from enum import Enum
-from typing import Any, Dict, List, Tuple
-import numpy as np
+from typing import Any
+
+import numpy as np

 from crewai.agent import Agent
-from crewai.task import Task
-
-from crewai.experimental.evaluation.base_evaluator import BaseEvaluator, EvaluationScore, MetricCategory
+from crewai.agents.agent_builder.base_agent import BaseAgent
+from crewai.experimental.evaluation.base_evaluator import (
+    BaseEvaluator,
+    EvaluationScore,
+    MetricCategory,
+)
 from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
+from crewai.task import Task
 from crewai.tasks.task_output import TaskOutput


 class ReasoningPatternType(Enum):
     EFFICIENT = "efficient"  # Good reasoning flow
     LOOP = "loop"  # Agent is stuck in a loop
@@ -35,8 +41,8 @@ class ReasoningEfficiencyEvaluator(BaseEvaluator):

     def evaluate(
         self,
-        agent: Agent,
-        execution_trace: Dict[str, Any],
+        agent: Agent | BaseAgent,
+        execution_trace: dict[str, Any],
         final_output: TaskOutput | str,
         task: Task | None = None,
     ) -> EvaluationScore:
@@ -49,7 +55,7 @@ class ReasoningEfficiencyEvaluator(BaseEvaluator):
         if not llm_calls or len(llm_calls) < 2:
             return EvaluationScore(
                 score=None,
-                feedback="Insufficient LLM calls to evaluate reasoning efficiency."
+                feedback="Insufficient LLM calls to evaluate reasoning efficiency.",
             )

         total_calls = len(llm_calls)
@@ -58,12 +64,16 @@ class ReasoningEfficiencyEvaluator(BaseEvaluator):
         time_intervals = []
         has_reliable_timing = True
         for i in range(1, len(llm_calls)):
-            start_time = llm_calls[i-1].get("end_time")
+            start_time = llm_calls[i - 1].get("end_time")
             end_time = llm_calls[i].get("start_time")
             if start_time and end_time and start_time != end_time:
                 try:
                     interval = end_time - start_time
-                    time_intervals.append(interval.total_seconds() if hasattr(interval, 'total_seconds') else 0)
+                    time_intervals.append(
+                        interval.total_seconds()
+                        if hasattr(interval, "total_seconds")
+                        else 0
+                    )
                 except Exception:
                     has_reliable_timing = False
             else:
@@ -83,14 +93,22 @@ class ReasoningEfficiencyEvaluator(BaseEvaluator):
         if has_reliable_timing and time_intervals:
             efficiency_metrics["avg_time_between_calls"] = np.mean(time_intervals)

-        loop_info = f"Detected {len(loop_details)} potential reasoning loops." if loop_detected else "No significant reasoning loops detected."
+        loop_info = (
+            f"Detected {len(loop_details)} potential reasoning loops."
+            if loop_detected
+            else "No significant reasoning loops detected."
+        )

         call_samples = self._get_call_samples(llm_calls)

-        final_output = final_output.raw if isinstance(final_output, TaskOutput) else final_output
+        final_output = (
+            final_output.raw if isinstance(final_output, TaskOutput) else final_output
+        )

         prompt = [
-            {"role": "system", "content": """You are an expert evaluator assessing the reasoning efficiency of an AI agent's thought process.
+            {
+                "role": "system",
+                "content": """You are an expert evaluator assessing the reasoning efficiency of an AI agent's thought process.

 Evaluate the agent's reasoning efficiency across these five key subcategories:

@@ -120,8 +138,11 @@ Return your evaluation as JSON with the following structure:
   "feedback": string (general feedback about overall reasoning efficiency),
   "optimization_suggestions": string (concrete suggestions for improving reasoning efficiency),
   "detected_patterns": string (describe any inefficient reasoning patterns you observe)
-}"""},
-            {"role": "user", "content": f"""
+}""",
+            },
+            {
+                "role": "user",
+                "content": f"""
 Agent role: {agent.role}
 {task_context}

@@ -140,10 +161,12 @@ Agent's final output:

 Evaluate the reasoning efficiency of this agent based on these interaction patterns.
 Identify any inefficient reasoning patterns and provide specific suggestions for optimization.
-"""}
+""",
+            },
         ]

-        assert self.llm is not None
+        if self.llm is None:
+            raise ValueError("LLM must be initialized")
         response = self.llm.call(prompt)

         try:
@@ -156,34 +179,46 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
             conciseness = scores.get("conciseness", 5.0)
             loop_avoidance = scores.get("loop_avoidance", 5.0)

-            overall_score = evaluation_data.get("overall_score", evaluation_data.get("score", 5.0))
+            overall_score = evaluation_data.get(
+                "overall_score", evaluation_data.get("score", 5.0)
+            )
             feedback = evaluation_data.get("feedback", "No detailed feedback provided.")
-            optimization_suggestions = evaluation_data.get("optimization_suggestions", "No specific suggestions provided.")
+            optimization_suggestions = evaluation_data.get(
+                "optimization_suggestions", "No specific suggestions provided."
+            )

             detailed_feedback = "Reasoning Efficiency Evaluation:\n"
-            detailed_feedback += f"• Focus: {focus}/10 - Staying on topic without tangents\n"
-            detailed_feedback += f"• Progression: {progression}/10 - Building on previous thinking\n"
+            detailed_feedback += (
+                f"• Focus: {focus}/10 - Staying on topic without tangents\n"
+            )
+            detailed_feedback += (
+                f"• Progression: {progression}/10 - Building on previous thinking\n"
+            )
             detailed_feedback += f"• Decision Quality: {decision_quality}/10 - Making appropriate decisions\n"
-            detailed_feedback += f"• Conciseness: {conciseness}/10 - Communicating efficiently\n"
+            detailed_feedback += (
+                f"• Conciseness: {conciseness}/10 - Communicating efficiently\n"
+            )
             detailed_feedback += f"• Loop Avoidance: {loop_avoidance}/10 - Avoiding repetitive patterns\n\n"

             detailed_feedback += f"Feedback:\n{feedback}\n\n"
-            detailed_feedback += f"Optimization Suggestions:\n{optimization_suggestions}"
+            detailed_feedback += (
+                f"Optimization Suggestions:\n{optimization_suggestions}"
+            )

             return EvaluationScore(
                 score=float(overall_score),
                 feedback=detailed_feedback,
-                raw_response=response
+                raw_response=response,
             )
         except Exception as e:
             logging.warning(f"Failed to parse reasoning efficiency evaluation: {e}")
             return EvaluationScore(
                 score=None,
                 feedback=f"Failed to parse reasoning efficiency evaluation. Raw response: {response[:200]}...",
-                raw_response=response
+                raw_response=response,
             )

-    def _detect_loops(self, llm_calls: List[Dict]) -> Tuple[bool, List[Dict]]:
+    def _detect_loops(self, llm_calls: list[dict]) -> tuple[bool, list[dict]]:
         loop_details = []

         messages = []
@@ -193,9 +228,11 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
                 messages.append(content)
             elif isinstance(content, list) and len(content) > 0:
                 # Handle message list format
-                for msg in content:
-                    if isinstance(msg, dict) and "content" in msg:
-                        messages.append(msg["content"])
+                messages.extend(
+                    msg["content"]
+                    for msg in content
+                    if isinstance(msg, dict) and "content" in msg
+                )

         # Simple n-gram based similarity detection
         # For a more robust implementation, consider using embedding-based similarity
@@ -205,18 +242,20 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
                 # A more sophisticated approach would use semantic similarity
                 similarity = self._calculate_text_similarity(messages[i], messages[j])
                 if similarity > 0.7:  # Arbitrary threshold
-                    loop_details.append({
-                        "first_occurrence": i,
-                        "second_occurrence": j,
-                        "similarity": similarity,
-                        "snippet": messages[i][:100] + "..."
-                    })
+                    loop_details.append(
+                        {
+                            "first_occurrence": i,
+                            "second_occurrence": j,
+                            "similarity": similarity,
+                            "snippet": messages[i][:100] + "...",
+                        }
+                    )

         return len(loop_details) > 0, loop_details

     def _calculate_text_similarity(self, text1: str, text2: str) -> float:
-        text1 = re.sub(r'\s+', ' ', text1.lower()).strip()
-        text2 = re.sub(r'\s+', ' ', text2.lower()).strip()
+        text1 = re.sub(r"\s+", " ", text1.lower()).strip()
+        text2 = re.sub(r"\s+", " ", text2.lower()).strip()

         # Simple Jaccard similarity on word sets
         words1 = set(text1.split())
@@ -227,7 +266,7 @@ Identify any inefficient reasoning patterns and provide specific suggestions for

         return intersection / union if union > 0 else 0.0

-    def _analyze_reasoning_patterns(self, llm_calls: List[Dict]) -> Dict[str, Any]:
+    def _analyze_reasoning_patterns(self, llm_calls: list[dict]) -> dict[str, Any]:
         call_lengths = []
         response_times = []

@@ -248,8 +287,8 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
                 if start_time and end_time:
                     try:
                         response_times.append(end_time - start_time)
-                    except Exception:
-                        pass
+                    except Exception as e:
+                        logging.debug(f"Failed to calculate response time: {e}")

         avg_length = np.mean(call_lengths) if call_lengths else 0
         std_length = np.std(call_lengths) if call_lengths else 0
@@ -267,7 +306,9 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
             details = "Agent is consistently verbose across interactions."
         elif len(llm_calls) > 10 and length_trend > 0.5:
             primary_pattern = ReasoningPatternType.INDECISIVE
-            details = "Agent shows signs of indecisiveness with increasing message lengths."
+            details = (
+                "Agent shows signs of indecisiveness with increasing message lengths."
+            )
         elif std_length / avg_length > 0.8:
             primary_pattern = ReasoningPatternType.SCATTERED
             details = "Agent shows inconsistent reasoning flow with highly variable responses."
@@ -279,8 +320,8 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
                 "avg_length": avg_length,
                 "std_length": std_length,
                 "length_trend": length_trend,
-                "loop_score": loop_score
-            }
+                "loop_score": loop_score,
+            },
         }

     def _calculate_trend(self, values: Sequence[float | int]) -> float:
@@ -303,7 +344,9 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
         except Exception:
             return 0.0

-    def _calculate_loop_likelihood(self, call_lengths: Sequence[float], response_times: Sequence[float]) -> float:
+    def _calculate_loop_likelihood(
+        self, call_lengths: Sequence[float], response_times: Sequence[float]
+    ) -> float:
         if not call_lengths or len(call_lengths) < 3:
             return 0.0

@@ -312,7 +355,11 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
         if len(call_lengths) >= 4:
             repeated_lengths = 0
             for i in range(len(call_lengths) - 2):
-                ratio = call_lengths[i] / call_lengths[i + 2] if call_lengths[i + 2] > 0 else 0
+                ratio = (
+                    call_lengths[i] / call_lengths[i + 2]
+                    if call_lengths[i + 2] > 0
+                    else 0
+                )
                 if 0.85 <= ratio <= 1.15:
                     repeated_lengths += 1

@@ -324,21 +371,27 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
                 std_time = np.std(response_times)
                 mean_time = np.mean(response_times)
                 if mean_time > 0:
-                    time_consistency = 1.0 - (std_time / mean_time)
-                    indicators.append(max(0, time_consistency - 0.3) * 1.5)
-            except Exception:
-                pass
+                    time_consistency = 1.0 - (float(std_time) / float(mean_time))
+                    indicators.append(max(0.0, float(time_consistency - 0.3)) * 1.5)
+            except Exception as e:
+                logging.debug(f"Time consistency calculation failed: {e}")

-        return np.mean(indicators) if indicators else 0.0
+        return float(np.mean(indicators)) if indicators else 0.0

-    def _get_call_samples(self, llm_calls: List[Dict]) -> str:
+    def _get_call_samples(self, llm_calls: list[dict]) -> str:
         samples = []

         if len(llm_calls) <= 6:
             sample_indices = list(range(len(llm_calls)))
         else:
-            sample_indices = [0, 1, len(llm_calls) // 2 - 1, len(llm_calls) // 2,
-                             len(llm_calls) - 2, len(llm_calls) - 1]
+            sample_indices = [
+                0,
+                1,
+                len(llm_calls) // 2 - 1,
+                len(llm_calls) // 2,
+                len(llm_calls) - 2,
+                len(llm_calls) - 1,
+            ]

         for idx in sample_indices:
             call = llm_calls[idx]
@@ -347,10 +400,11 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
             if isinstance(content, str):
                 sample = content
             elif isinstance(content, list) and len(content) > 0:
-                sample_parts = []
-                for msg in content:
-                    if isinstance(msg, dict) and "content" in msg:
-                        sample_parts.append(msg["content"])
+                sample_parts = [
+                    msg["content"]
+                    for msg in content
+                    if isinstance(msg, dict) and "content" in msg
+                ]
                 sample = "\n".join(sample_parts)
             else:
                 sample = str(content)
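The loop detector above leans on `_calculate_text_similarity`, a whitespace-normalized Jaccard similarity over word sets compared against a 0.7 threshold. A standalone sketch of the same computation; the sample strings are invented:

```python
# Standalone sketch of the Jaccard word-set similarity used by the loop detector:
# normalize whitespace, lowercase, then |A ∩ B| / |A ∪ B| over the word sets.
import re


def jaccard_similarity(text1: str, text2: str) -> float:
    text1 = re.sub(r"\s+", " ", text1.lower()).strip()
    text2 = re.sub(r"\s+", " ", text2.lower()).strip()
    words1, words2 = set(text1.split()), set(text2.split())
    union = len(words1 | words2)
    return len(words1 & words2) / union if union > 0 else 0.0


a = "I should search the web for current pricing data"
b = "I should search the   web for current pricing data again"
print(jaccard_similarity(a, b) > 0.7)  # True -> flagged as a potential loop
```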
@@ -1,10 +1,15 @@
-from typing import Any, Dict
+from typing import Any

 from crewai.agent import Agent
-from crewai.task import Task
-
-from crewai.experimental.evaluation.base_evaluator import BaseEvaluator, EvaluationScore, MetricCategory
+from crewai.agents.agent_builder.base_agent import BaseAgent
+from crewai.experimental.evaluation.base_evaluator import (
+    BaseEvaluator,
+    EvaluationScore,
+    MetricCategory,
+)
 from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
+from crewai.task import Task
+

 class SemanticQualityEvaluator(BaseEvaluator):
     @property
@@ -13,8 +18,8 @@ class SemanticQualityEvaluator(BaseEvaluator):

     def evaluate(
         self,
-        agent: Agent,
-        execution_trace: Dict[str, Any],
+        agent: Agent | BaseAgent,
+        execution_trace: dict[str, Any],
         final_output: Any,
         task: Task | None = None,
     ) -> EvaluationScore:
@@ -22,7 +27,9 @@ class SemanticQualityEvaluator(BaseEvaluator):
         if task is not None:
             task_context = f"Task description: {task.description}"
         prompt = [
-            {"role": "system", "content": """You are an expert evaluator assessing the semantic quality of an AI agent's output.
+            {
+                "role": "system",
+                "content": """You are an expert evaluator assessing the semantic quality of an AI agent's output.

 Score the semantic quality on a scale from 0-10 where:
 - 0: Completely incoherent, confusing, or logically flawed output
@@ -37,8 +44,11 @@ Consider:
 5. Is the output free from contradictions and logical fallacies?

 Return your evaluation as JSON with fields 'score' (number) and 'feedback' (string).
-"""},
-            {"role": "user", "content": f"""
+""",
+            },
+            {
+                "role": "user",
+                "content": f"""
 Agent role: {agent.role}
 {task_context}

@@ -46,23 +56,28 @@ Agent's final output:
 {final_output}

 Evaluate the semantic quality and reasoning of this output.
-"""}
+""",
+            },
         ]

-        assert self.llm is not None
+        if self.llm is None:
+            raise ValueError("LLM must be initialized")
         response = self.llm.call(prompt)

         try:
             evaluation_data: dict[str, Any] = extract_json_from_llm_response(response)
-            assert evaluation_data is not None
+            if evaluation_data is None:
+                raise ValueError("Failed to extract evaluation data from LLM response")
             return EvaluationScore(
-                score=float(evaluation_data["score"]) if evaluation_data.get("score") is not None else None,
+                score=float(evaluation_data["score"])
+                if evaluation_data.get("score") is not None
+                else None,
                 feedback=evaluation_data.get("feedback", response),
-                raw_response=response
+                raw_response=response,
             )
         except Exception:
             return EvaluationScore(
                 score=None,
                 feedback=f"Failed to parse evaluation. Raw response: {response}",
-                raw_response=response
-            )
+                raw_response=response,
+            )
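All the evaluators in these files share one failure-handling contract, which the hunks above keep intact while reflowing it: if the LLM reply cannot be parsed, return an `EvaluationScore` with `score=None` and the raw text preserved. A condensed sketch of that contract; the helper name is invented and `json.loads` stands in for the real extractor:

```python
# Invented helper; condenses the parse-or-fallback contract the evaluators share.
import json
from typing import Any


def score_from_llm_reply(response: str) -> dict[str, Any]:
    try:
        data = json.loads(response)  # stand-in for extract_json_from_llm_response
        if data is None:
            raise ValueError("Failed to extract evaluation data from LLM response")
        return {
            "score": float(data["score"]) if data.get("score") is not None else None,
            "feedback": data.get("feedback", response),
            "raw_response": response,
        }
    except Exception:
        # Score of None marks "not evaluable"; the raw reply is kept for debugging.
        return {
            "score": None,
            "feedback": f"Failed to parse evaluation. Raw response: {response}",
            "raw_response": response,
        }


print(score_from_llm_reply('{"score": 9, "feedback": "coherent"}')["score"])  # 9.0
```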
@@ -1,22 +1,26 @@
 import json
-from typing import Dict, Any
+from typing import Any

-from crewai.experimental.evaluation.base_evaluator import BaseEvaluator, EvaluationScore, MetricCategory
-from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
 from crewai.agent import Agent
+from crewai.agents.agent_builder.base_agent import BaseAgent
+from crewai.experimental.evaluation.base_evaluator import (
+    BaseEvaluator,
+    EvaluationScore,
+    MetricCategory,
+)
+from crewai.experimental.evaluation.json_parser import extract_json_from_llm_response
 from crewai.task import Task


 class ToolSelectionEvaluator(BaseEvaluator):
-
     @property
     def metric_category(self) -> MetricCategory:
         return MetricCategory.TOOL_SELECTION

     def evaluate(
         self,
-        agent: Agent,
-        execution_trace: Dict[str, Any],
+        agent: Agent | BaseAgent,
+        execution_trace: dict[str, Any],
         final_output: str,
         task: Task | None = None,
     ) -> EvaluationScore:
@@ -26,19 +30,18 @@ class ToolSelectionEvaluator(BaseEvaluator):

         tool_uses = execution_trace.get("tool_uses", [])
         tool_count = len(tool_uses)
-        unique_tool_types = set([tool.get("tool", "Unknown tool") for tool in tool_uses])
+        unique_tool_types = set(
+            [tool.get("tool", "Unknown tool") for tool in tool_uses]
+        )

         if tool_count == 0:
             if not agent.tools:
                 return EvaluationScore(
-                    score=None,
-                    feedback="Agent had no tools available to use."
-                )
-            else:
-                return EvaluationScore(
-                    score=None,
-                    feedback="Agent had tools available but didn't use any."
+                    score=None, feedback="Agent had no tools available to use."
                 )
+            return EvaluationScore(
+                score=None, feedback="Agent had tools available but didn't use any."
+            )

         available_tools_info = ""
         if agent.tools:
@@ -52,7 +55,9 @@ class ToolSelectionEvaluator(BaseEvaluator):
             tool_types_summary += f"- {tool_type}\n"

         prompt = [
-            {"role": "system", "content": """You are an expert evaluator assessing if an AI agent selected the most appropriate tools for a given task.
+            {
+                "role": "system",
+                "content": """You are an expert evaluator assessing if an AI agent selected the most appropriate tools for a given task.

 You must evaluate based on these 2 criteria:
 1. Relevance (0-10): Were the tools chosen directly aligned with the task's goals?
@@ -73,8 +78,11 @@ Return your evaluation as JSON with these fields:
 - overall_score: number (average of all scores, 0-10)
 - feedback: string (focused ONLY on tool selection decisions from available tools)
 - improvement_suggestions: string (ONLY suggest better selection from the AVAILABLE tools list, NOT new tools)
-"""},
-            {"role": "user", "content": f"""
+""",
+            },
+            {
+                "role": "user",
+                "content": f"""
 Agent role: {agent.role}
 {task_context}

@@ -89,14 +97,17 @@ IMPORTANT:
 - ONLY evaluate selection from tools listed as available
 - DO NOT suggest new tools that aren't in the available tools list
 - DO NOT evaluate tool usage or results
-"""}
+""",
+            },
         ]
-        assert self.llm is not None
+        if self.llm is None:
+            raise ValueError("LLM must be initialized")
         response = self.llm.call(prompt)

         try:
             evaluation_data = extract_json_from_llm_response(response)
-            assert evaluation_data is not None
+            if evaluation_data is None:
+                raise ValueError("Failed to extract evaluation data from LLM response")

             scores = evaluation_data.get("scores", {})
             relevance = scores.get("relevance", 5.0)
@@ -105,22 +116,24 @@ IMPORTANT:

             feedback = "Tool Selection Evaluation:\n"
             feedback += f"• Relevance: {relevance}/10 - Selection of appropriate tool types for the task\n"
-            feedback += f"• Coverage: {coverage}/10 - Selection of all necessary tool types\n"
+            feedback += (
+                f"• Coverage: {coverage}/10 - Selection of all necessary tool types\n"
+            )
             if "improvement_suggestions" in evaluation_data:
                 feedback += f"Improvement Suggestions:\n{evaluation_data['improvement_suggestions']}"
             else:
-                feedback += evaluation_data.get("feedback", "No detailed feedback available.")
+                feedback += evaluation_data.get(
+                    "feedback", "No detailed feedback available."
+                )

             return EvaluationScore(
-                score=overall_score,
-                feedback=feedback,
-                raw_response=response
+                score=overall_score, feedback=feedback, raw_response=response
             )
         except Exception as e:
             return EvaluationScore(
                 score=None,
                 feedback=f"Error evaluating tool selection: {e}",
-                raw_response=response
+                raw_response=response,
             )

@@ -131,8 +144,8 @@ class ParameterExtractionEvaluator(BaseEvaluator):

     def evaluate(
         self,
-        agent: Agent,
-        execution_trace: Dict[str, Any],
+        agent: Agent | BaseAgent,
+        execution_trace: dict[str, Any],
         final_output: str,
         task: Task | None = None,
     ) -> EvaluationScore:
@@ -145,19 +158,23 @@ class ParameterExtractionEvaluator(BaseEvaluator):
         if tool_count == 0:
             return EvaluationScore(
                 score=None,
-                feedback="No tool usage detected. Cannot evaluate parameter extraction."
+                feedback="No tool usage detected. Cannot evaluate parameter extraction.",
             )

-        validation_errors = []
-        for tool_use in tool_uses:
-            if not tool_use.get("success", True) and tool_use.get("error_type") == "validation_error":
-                validation_errors.append({
-                    "tool": tool_use.get("tool", "Unknown tool"),
-                    "error": tool_use.get("result"),
-                    "args": tool_use.get("args", {})
-                })
+        validation_errors = [
+            {
+                "tool": tool_use.get("tool", "Unknown tool"),
+                "error": tool_use.get("result"),
+                "args": tool_use.get("args", {}),
+            }
+            for tool_use in tool_uses
+            if not tool_use.get("success", True)
+            and tool_use.get("error_type") == "validation_error"
+        ]

-        validation_error_rate = len(validation_errors) / tool_count if tool_count > 0 else 0
+        validation_error_rate = (
+            len(validation_errors) / tool_count if tool_count > 0 else 0
+        )

         param_samples = []
         for i, tool_use in enumerate(tool_uses[:5]):
@@ -168,7 +185,7 @@ class ParameterExtractionEvaluator(BaseEvaluator):

             is_validation_error = error_type == "validation_error"

-            sample = f"Tool use #{i+1} - {tool_name}:\n"
+            sample = f"Tool use #{i + 1} - {tool_name}:\n"
             sample += f"- Parameters: {json.dumps(tool_args, indent=2)}\n"
             sample += f"- Success: {'No' if not success else 'Yes'}"

@@ -187,13 +204,17 @@ class ParameterExtractionEvaluator(BaseEvaluator):
             tool_name = err.get("tool", "Unknown tool")
             error_msg = err.get("error", "Unknown error")
             args = err.get("args", {})
-            validation_errors_info += f"\nValidation Error #{i+1}:\n- Tool: {tool_name}\n- Args: {json.dumps(args, indent=2)}\n- Error: {error_msg}"
+            validation_errors_info += f"\nValidation Error #{i + 1}:\n- Tool: {tool_name}\n- Args: {json.dumps(args, indent=2)}\n- Error: {error_msg}"

         if len(validation_errors) > 3:
-            validation_errors_info += f"\n...and {len(validation_errors) - 3} more validation errors."
+            validation_errors_info += (
+                f"\n...and {len(validation_errors) - 3} more validation errors."
+            )
         param_samples_text = "\n\n".join(param_samples)
         prompt = [
-            {"role": "system", "content": """You are an expert evaluator assessing how well an AI agent extracts and formats PARAMETER VALUES for tool calls.
+            {
+                "role": "system",
+                "content": """You are an expert evaluator assessing how well an AI agent extracts and formats PARAMETER VALUES for tool calls.

 Your job is to evaluate ONLY whether the agent used the correct parameter VALUES, not whether the right tools were selected or how the tools were invoked.

@@ -216,8 +237,11 @@ Return your evaluation as JSON with these fields:
 - overall_score: number (average of all scores, 0-10)
 - feedback: string (focused ONLY on parameter value extraction quality)
 - improvement_suggestions: string (concrete suggestions for better parameter VALUE extraction)
-"""},
-            {"role": "user", "content": f"""
+""",
+            },
+            {
+                "role": "user",
+                "content": f"""
 Agent role: {agent.role}
 {task_context}

@@ -226,15 +250,18 @@ Parameter extraction examples:
 {validation_errors_info}

 Evaluate the quality of the agent's parameter extraction for this task.
-"""}
+""",
+            },
         ]

-        assert self.llm is not None
+        if self.llm is None:
+            raise ValueError("LLM must be initialized")
         response = self.llm.call(prompt)

         try:
             evaluation_data = extract_json_from_llm_response(response)
-            assert evaluation_data is not None
+            if evaluation_data is None:
+                raise ValueError("Failed to extract evaluation data from LLM response")

             scores = evaluation_data.get("scores", {})
             accuracy = scores.get("accuracy", 5.0)
@@ -251,18 +278,18 @@ Evaluate the quality of the agent's parameter extraction for this task.
             if "improvement_suggestions" in evaluation_data:
                 feedback += f"Improvement Suggestions:\n{evaluation_data['improvement_suggestions']}"
             else:
-                feedback += evaluation_data.get("feedback", "No detailed feedback available.")
+                feedback += evaluation_data.get(
+                    "feedback", "No detailed feedback available."
+                )

             return EvaluationScore(
-                score=overall_score,
-                feedback=feedback,
-                raw_response=response
+                score=overall_score, feedback=feedback, raw_response=response
             )
         except Exception as e:
             return EvaluationScore(
                 score=None,
                 feedback=f"Error evaluating parameter extraction: {e}",
-                raw_response=response
+                raw_response=response,
             )

@@ -273,8 +300,8 @@ class ToolInvocationEvaluator(BaseEvaluator):

     def evaluate(
         self,
-        agent: Agent,
-        execution_trace: Dict[str, Any],
+        agent: Agent | BaseAgent,
+        execution_trace: dict[str, Any],
         final_output: str,
         task: Task | None = None,
     ) -> EvaluationScore:
@@ -288,7 +315,7 @@ class ToolInvocationEvaluator(BaseEvaluator):
         if tool_count == 0:
             return EvaluationScore(
                 score=None,
-                feedback="No tool usage detected. Cannot evaluate tool invocation."
+                feedback="No tool usage detected. Cannot evaluate tool invocation.",
             )

         for tool_use in tool_uses:
@@ -296,7 +323,7 @@ class ToolInvocationEvaluator(BaseEvaluator):
                 error_info = {
                     "tool": tool_use.get("tool", "Unknown tool"),
                     "error": tool_use.get("result"),
-                    "error_type": tool_use.get("error_type", "unknown_error")
+                    "error_type": tool_use.get("error_type", "unknown_error"),
                 }
                 tool_errors.append(error_info)

@@ -315,9 +342,11 @@ class ToolInvocationEvaluator(BaseEvaluator):
             tool_args = tool_use.get("args", {})
             success = tool_use.get("success", True) and not tool_use.get("error", False)
             error_type = tool_use.get("error_type", "") if not success else ""
-            error_msg = tool_use.get("result", "No error") if not success else "No error"
+            error_msg = (
+                tool_use.get("result", "No error") if not success else "No error"
+            )

-            sample = f"Tool invocation #{i+1}:\n"
+            sample = f"Tool invocation #{i + 1}:\n"
             sample += f"- Tool: {tool_name}\n"
             sample += f"- Parameters: {json.dumps(tool_args, indent=2)}\n"
             sample += f"- Success: {'No' if not success else 'Yes'}\n"
@@ -330,11 +359,13 @@ class ToolInvocationEvaluator(BaseEvaluator):
         if error_types:
             error_type_summary = "Error type breakdown:\n"
             for error_type, count in error_types.items():
-                error_type_summary += f"- {error_type}: {count} occurrences ({(count/tool_count):.1%})\n"
+                error_type_summary += f"- {error_type}: {count} occurrences ({(count / tool_count):.1%})\n"

         invocation_samples_text = "\n\n".join(invocation_samples)
         prompt = [
-            {"role": "system", "content": """You are an expert evaluator assessing how correctly an AI agent's tool invocations are STRUCTURED.
+            {
+                "role": "system",
|
||||
"content": """You are an expert evaluator assessing how correctly an AI agent's tool invocations are STRUCTURED.
|
||||
|
||||
Your job is to evaluate ONLY the structural and syntactical aspects of how the agent called tools, NOT which tools were selected or what parameter values were used.
|
||||
|
||||
@@ -359,8 +390,11 @@ Return your evaluation as JSON with these fields:
|
||||
- overall_score: number (average of all scores, 0-10)
|
||||
- feedback: string (focused ONLY on structural aspects of tool invocation)
|
||||
- improvement_suggestions: string (concrete suggestions for better structuring of tool calls)
|
||||
"""},
|
||||
{"role": "user", "content": f"""
|
||||
""",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"""
|
||||
Agent role: {agent.role}
|
||||
{task_context}
|
||||
|
||||
@@ -371,15 +405,18 @@ Tool error rate: {error_rate:.2%} ({len(tool_errors)} errors out of {tool_count}
|
||||
{error_type_summary}
|
||||
|
||||
Evaluate the quality of the agent's tool invocation structure during this task.
|
||||
"""}
|
||||
""",
|
||||
},
|
||||
]
|
||||
|
||||
assert self.llm is not None
|
||||
if self.llm is None:
|
||||
raise ValueError("LLM must be initialized")
|
||||
response = self.llm.call(prompt)
|
||||
|
||||
try:
|
||||
evaluation_data = extract_json_from_llm_response(response)
|
||||
assert evaluation_data is not None
|
||||
if evaluation_data is None:
|
||||
raise ValueError("Failed to extract evaluation data from LLM response")
|
||||
scores = evaluation_data.get("scores", {})
|
||||
structure = scores.get("structure", 5.0)
|
||||
error_handling = scores.get("error_handling", 5.0)
|
||||
@@ -388,23 +425,25 @@ Evaluate the quality of the agent's tool invocation structure during this task.
|
||||
overall_score = float(evaluation_data.get("overall_score", 5.0))
|
||||
|
||||
feedback = "Tool Invocation Evaluation:\n"
|
||||
feedback += f"• Structure: {structure}/10 - Following proper syntax and format\n"
|
||||
feedback += (
|
||||
f"• Structure: {structure}/10 - Following proper syntax and format\n"
|
||||
)
|
||||
feedback += f"• Error Handling: {error_handling}/10 - Appropriately handling tool errors\n"
|
||||
feedback += f"• Invocation Patterns: {invocation_patterns}/10 - Proper sequencing and management of calls\n\n"
|
||||
|
||||
if "improvement_suggestions" in evaluation_data:
|
||||
feedback += f"Improvement Suggestions:\n{evaluation_data['improvement_suggestions']}"
|
||||
else:
|
||||
feedback += evaluation_data.get("feedback", "No detailed feedback available.")
|
||||
feedback += evaluation_data.get(
|
||||
"feedback", "No detailed feedback available."
|
||||
)
|
||||
|
||||
return EvaluationScore(
|
||||
score=overall_score,
|
||||
feedback=feedback,
|
||||
raw_response=response
|
||||
score=overall_score, feedback=feedback, raw_response=response
|
||||
)
|
||||
except Exception as e:
|
||||
return EvaluationScore(
|
||||
score=None,
|
||||
feedback=f"Error evaluating tool invocation: {e}",
|
||||
raw_response=response
|
||||
raw_response=response,
|
||||
)
|
||||
|
||||
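A recurring change in these hunks is swapping `assert` guards for explicit `raise` statements; `assert` is stripped when Python runs with `-O`, so the checks would silently vanish. A minimal sketch of the pattern, assuming hypothetical stand-ins `extract_json` and an LLM object with a `call` method (not the repo's real helpers):

```python
import json
from typing import Any


def extract_json(response: str) -> dict[str, Any] | None:
    """Best-effort JSON extraction from an LLM response (illustrative only)."""
    try:
        return json.loads(response)
    except json.JSONDecodeError:
        return None


def evaluate(llm: Any, prompt: list[dict[str, str]]) -> dict[str, Any]:
    # Explicit checks survive `python -O`, unlike `assert`.
    if llm is None:
        raise ValueError("LLM must be initialized")
    response = llm.call(prompt)
    data = extract_json(response)
    if data is None:
        raise ValueError("Failed to extract evaluation data from LLM response")
    return data
```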
@@ -1,12 +1,21 @@
 import inspect
+import warnings

 from typing_extensions import Any
-import warnings
-from crewai.experimental.evaluation.experiment import ExperimentResults, ExperimentRunner
-from crewai import Crew, Agent
-
-def assert_experiment_successfully(experiment_results: ExperimentResults, baseline_filepath: str | None = None) -> None:
-    failed_tests = [result for result in experiment_results.results if not result.passed]
+
+from crewai import Agent, Crew
+from crewai.experimental.evaluation.experiment import (
+    ExperimentResults,
+    ExperimentRunner,
+)
+
+
+def assert_experiment_successfully(
+    experiment_results: ExperimentResults, baseline_filepath: str | None = None
+) -> None:
+    failed_tests = [
+        result for result in experiment_results.results if not result.passed
+    ]

     if failed_tests:
         detailed_failures: list[str] = []
@@ -14,39 +23,54 @@ def assert_experiment_successfully(experiment_results: ExperimentResults, baseli
         for result in failed_tests:
             expected = result.expected_score
             actual = result.score
-            detailed_failures.append(f"- {result.identifier}: expected {expected}, got {actual}")
+            detailed_failures.append(
+                f"- {result.identifier}: expected {expected}, got {actual}"
+            )

         failure_details = "\n".join(detailed_failures)
         raise AssertionError(f"The following test cases failed:\n{failure_details}")

     baseline_filepath = baseline_filepath or _get_baseline_filepath_fallback()
-    comparison = experiment_results.compare_with_baseline(baseline_filepath=baseline_filepath)
+    comparison = experiment_results.compare_with_baseline(
+        baseline_filepath=baseline_filepath
+    )
     assert_experiment_no_regression(comparison)


 def assert_experiment_no_regression(comparison_result: dict[str, list[str]]) -> None:
     regressed = comparison_result.get("regressed", [])
     if regressed:
-        raise AssertionError(f"Regression detected! The following tests that previously passed now fail: {regressed}")
+        raise AssertionError(
+            f"Regression detected! The following tests that previously passed now fail: {regressed}"
+        )

     missing_tests = comparison_result.get("missing_tests", [])
     if missing_tests:
         warnings.warn(
             f"Warning: {len(missing_tests)} tests from the baseline are missing in the current run: {missing_tests}",
-            UserWarning
+            UserWarning,
+            stacklevel=2,
         )

-def run_experiment(dataset: list[dict[str, Any]], crew: Crew | None = None, agents: list[Agent] | None = None, verbose: bool = False) -> ExperimentResults:
+
+def run_experiment(
+    dataset: list[dict[str, Any]],
+    crew: Crew | None = None,
+    agents: list[Agent] | None = None,
+    verbose: bool = False,
+) -> ExperimentResults:
     runner = ExperimentRunner(dataset=dataset)

     return runner.run(agents=agents, crew=crew, print_summary=verbose)


 def _get_baseline_filepath_fallback() -> str:
     test_func_name = "experiment_fallback"

     try:
         current_frame = inspect.currentframe()
         if current_frame is not None:
-            test_func_name = current_frame.f_back.f_back.f_code.co_name # type: ignore[union-attr]
+            test_func_name = current_frame.f_back.f_back.f_code.co_name  # type: ignore[union-attr]
     except Exception:
         ...
-    return f"{test_func_name}_results.json"
+    return f"{test_func_name}_results.json"
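For context, these helpers are designed for regression testing agents in a pytest suite: `run_experiment` executes a dataset against agents or a crew, and `assert_experiment_successfully` fails the test on regressions against a baseline file (falling back to `<test_func_name>_results.json`). A hedged usage sketch; the import path and the dataset keys (`identifier`, `inputs`, `expected_score`) are assumptions based on this diff, not a documented schema:

```python
from crewai import Agent
# Module path assumed; adjust to wherever these helpers live in your install.
from crewai.experimental.evaluation.experiment import (
    assert_experiment_successfully,
    run_experiment,
)


def test_research_agent_regressions():
    agent = Agent(role="Researcher", goal="Answer questions", backstory="...")
    dataset = [
        # Dataset item shape is an assumption for illustration.
        {"identifier": "capitals", "inputs": {"query": "Capital of France?"}, "expected_score": 8},
    ]
    results = run_experiment(dataset=dataset, agents=[agent], verbose=True)
    # With no baseline path, the fallback baseline file name is derived
    # from the calling test function's name.
    assert_experiment_successfully(results)
```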
@@ -1,5 +1,4 @@
-from crewai.flow.flow import Flow, start, listen, or_, and_, router
+from crewai.flow.flow import Flow, and_, listen, or_, router, start
 from crewai.flow.persistence import persist

-__all__ = ["Flow", "start", "listen", "or_", "and_", "router", "persist"]
+__all__ = ["Flow", "and_", "listen", "or_", "persist", "router", "start"]
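These exports are the public Flow API. A minimal sketch of how they compose, assuming a dict-state flow (this example is illustrative, not taken from the diff):

```python
from crewai.flow.flow import Flow, listen, start


class GreetingFlow(Flow):
    @start()
    def begin(self) -> str:
        return "hello"

    @listen(begin)
    def shout(self, greeting: str) -> str:
        return greeting.upper()


# GreetingFlow().kickoff() runs begin() and then shout("hello") -> "HELLO".
```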
@@ -1,5 +1,4 @@
 import inspect
-from typing import Optional

 from pydantic import BaseModel, Field, InstanceOf, model_validator

@@ -14,7 +13,7 @@ class FlowTrackable(BaseModel):
     inspecting the call stack.
     """

-    parent_flow: Optional[InstanceOf[Flow]] = Field(
+    parent_flow: InstanceOf[Flow] | None = Field(
         default=None,
         description="The parent flow of the instance, if it was created inside a flow.",
     )
@@ -1,14 +1,13 @@
 # flow_visualizer.py

 import os
-from pathlib import Path

-from pyvis.network import Network
+from pyvis.network import Network  # type: ignore[import-untyped]

 from crewai.flow.config import COLORS, NODE_STYLES
 from crewai.flow.html_template_handler import HTMLTemplateHandler
 from crewai.flow.legend_generator import generate_legend_items_html, get_legend_items
-from crewai.flow.path_utils import safe_path_join, validate_path_exists
+from crewai.flow.path_utils import safe_path_join
 from crewai.flow.utils import calculate_node_levels
 from crewai.flow.visualization_utils import (
     add_edges,
@@ -34,13 +33,13 @@ class FlowPlot:
         ValueError
             If flow object is invalid or missing required attributes.
         """
-        if not hasattr(flow, '_methods'):
+        if not hasattr(flow, "_methods"):
             raise ValueError("Invalid flow object: missing '_methods' attribute")
-        if not hasattr(flow, '_listeners'):
+        if not hasattr(flow, "_listeners"):
             raise ValueError("Invalid flow object: missing '_listeners' attribute")
-        if not hasattr(flow, '_start_methods'):
+        if not hasattr(flow, "_start_methods"):
             raise ValueError("Invalid flow object: missing '_start_methods' attribute")

         self.flow = flow
         self.colors = COLORS
         self.node_styles = NODE_STYLES
@@ -65,7 +64,7 @@ class FlowPlot:
         """
         if not filename or not isinstance(filename, str):
             raise ValueError("Filename must be a non-empty string")

         try:
             # Initialize network
             net = Network(
@@ -96,32 +95,34 @@ class FlowPlot:
             try:
                 node_levels = calculate_node_levels(self.flow)
             except Exception as e:
-                raise ValueError(f"Failed to calculate node levels: {str(e)}")
+                raise ValueError(f"Failed to calculate node levels: {e!s}") from e

             # Compute positions
             try:
                 node_positions = compute_positions(self.flow, node_levels)
             except Exception as e:
-                raise ValueError(f"Failed to compute node positions: {str(e)}")
+                raise ValueError(f"Failed to compute node positions: {e!s}") from e

             # Add nodes to the network
             try:
                 add_nodes_to_network(net, self.flow, node_positions, self.node_styles)
             except Exception as e:
-                raise RuntimeError(f"Failed to add nodes to network: {str(e)}")
+                raise RuntimeError(f"Failed to add nodes to network: {e!s}") from e

             # Add edges to the network
             try:
                 add_edges(net, self.flow, node_positions, self.colors)
             except Exception as e:
-                raise RuntimeError(f"Failed to add edges to network: {str(e)}")
+                raise RuntimeError(f"Failed to add edges to network: {e!s}") from e

             # Generate HTML
             try:
                 network_html = net.generate_html()
                 final_html_content = self._generate_final_html(network_html)
             except Exception as e:
-                raise RuntimeError(f"Failed to generate network visualization: {str(e)}")
+                raise RuntimeError(
+                    f"Failed to generate network visualization: {e!s}"
+                ) from e

             # Save the final HTML content to the file
             try:
@@ -129,12 +130,16 @@ class FlowPlot:
                     f.write(final_html_content)
                 print(f"Plot saved as {filename}.html")
             except IOError as e:
-                raise IOError(f"Failed to save flow visualization to {filename}.html: {str(e)}")
+                raise IOError(
+                    f"Failed to save flow visualization to {filename}.html: {e!s}"
+                ) from e

         except (ValueError, RuntimeError, IOError) as e:
             raise e
         except Exception as e:
-            raise RuntimeError(f"Unexpected error during flow visualization: {str(e)}")
+            raise RuntimeError(
+                f"Unexpected error during flow visualization: {e!s}"
+            ) from e
         finally:
             self._cleanup_pyvis_lib()

@@ -165,7 +170,9 @@ class FlowPlot:
         try:
             # Extract just the body content from the generated HTML
             current_dir = os.path.dirname(__file__)
-            template_path = safe_path_join("assets", "crewai_flow_visual_template.html", root=current_dir)
+            template_path = safe_path_join(
+                "assets", "crewai_flow_visual_template.html", root=current_dir
+            )
             logo_path = safe_path_join("assets", "crewai_logo.svg", root=current_dir)

             if not os.path.exists(template_path):
@@ -179,12 +186,9 @@ class FlowPlot:
             # Generate the legend items HTML
             legend_items = get_legend_items(self.colors)
             legend_items_html = generate_legend_items_html(legend_items)
-            final_html_content = html_handler.generate_final_html(
-                network_body, legend_items_html
-            )
-            return final_html_content
+            return html_handler.generate_final_html(network_body, legend_items_html)
         except Exception as e:
-            raise IOError(f"Failed to generate visualization HTML: {str(e)}")
+            raise IOError(f"Failed to generate visualization HTML: {e!s}") from e

     def _cleanup_pyvis_lib(self):
         """
@@ -197,6 +201,7 @@ class FlowPlot:
             lib_folder = safe_path_join("lib", root=os.getcwd())
             if os.path.exists(lib_folder) and os.path.isdir(lib_folder):
                 import shutil

                 shutil.rmtree(lib_folder)
         except ValueError as e:
             print(f"Error validating lib folder path: {e}")
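`FlowPlot` backs the public plotting entry point on flows. Reusing the hypothetical `GreetingFlow` from the earlier sketch, the rendering path above is exercised like this (the output filename is whatever you pass, with `.html` appended):

```python
# Assumes the GreetingFlow sketch defined earlier in this document.
flow = GreetingFlow()
flow.plot("greeting_flow")  # renders the flow graph and writes greeting_flow.html
```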
@@ -1,8 +1,7 @@
 import base64
 import re
 from pathlib import Path

-from crewai.flow.path_utils import safe_path_join, validate_path_exists
+from crewai.flow.path_utils import validate_path_exists


 class HTMLTemplateHandler:
@@ -28,7 +27,7 @@ class HTMLTemplateHandler:
             self.template_path = validate_path_exists(template_path, "file")
             self.logo_path = validate_path_exists(logo_path, "file")
         except ValueError as e:
-            raise ValueError(f"Invalid template or logo path: {e}")
+            raise ValueError(f"Invalid template or logo path: {e}") from e

     def read_template(self):
         """Read and return the HTML template file contents."""
@@ -53,23 +52,23 @@ class HTMLTemplateHandler:
             if "border" in item:
                 legend_items_html += f"""
                 <div class="legend-item">
-                    <div class="legend-color-box" style="background-color: {item['color']}; border: 2px dashed {item['border']};"></div>
-                    <div>{item['label']}</div>
+                    <div class="legend-color-box" style="background-color: {item["color"]}; border: 2px dashed {item["border"]};"></div>
+                    <div>{item["label"]}</div>
                 </div>
                 """
             elif item.get("dashed") is not None:
                 style = "dashed" if item["dashed"] else "solid"
                 legend_items_html += f"""
                 <div class="legend-item">
-                    <div class="legend-{style}" style="border-bottom: 2px {style} {item['color']};"></div>
-                    <div>{item['label']}</div>
+                    <div class="legend-{style}" style="border-bottom: 2px {style} {item["color"]};"></div>
+                    <div>{item["label"]}</div>
                 </div>
                 """
             else:
                 legend_items_html += f"""
                 <div class="legend-item">
-                    <div class="legend-color-box" style="background-color: {item['color']};"></div>
-                    <div>{item['label']}</div>
+                    <div class="legend-color-box" style="background-color: {item["color"]};"></div>
+                    <div>{item["label"]}</div>
                 </div>
                 """
         return legend_items_html
@@ -79,15 +78,9 @@ class HTMLTemplateHandler:
         html_template = self.read_template()
         logo_svg_base64 = self.encode_logo()

-        final_html_content = html_template.replace("{{ title }}", title)
-        final_html_content = final_html_content.replace(
-            "{{ network_content }}", network_body
-        )
-        final_html_content = final_html_content.replace(
-            "{{ logo_svg_base64 }}", logo_svg_base64
-        )
-        final_html_content = final_html_content.replace(
-            "<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html
-        )
-
-        return final_html_content
+        return (
+            html_template.replace("{{ title }}", title)
+            .replace("{{ network_content }}", network_body)
+            .replace("{{ logo_svg_base64 }}", logo_svg_base64)
+            .replace("<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html)
+        )
@@ -5,12 +5,10 @@ This module provides utilities for secure path handling to prevent directory
 traversal attacks and ensure paths remain within allowed boundaries.
 """

 import os
 from pathlib import Path
-from typing import List, Union


-def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str:
+def safe_path_join(*parts: str, root: str | Path | None = None) -> str:
     """
     Safely join path components and ensure the result is within allowed boundaries.

@@ -43,25 +41,25 @@ def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str:

         # Establish root directory
         root_path = Path(root).resolve() if root else Path.cwd()

         # Join and resolve the full path
         full_path = Path(root_path, *clean_parts).resolve()

         # Check if the resolved path is within root
         if not str(full_path).startswith(str(root_path)):
             raise ValueError(
                 f"Invalid path: Potential directory traversal. Path must be within {root_path}"
             )

         return str(full_path)

     except Exception as e:
         if isinstance(e, ValueError):
             raise
-        raise ValueError(f"Invalid path components: {str(e)}")
+        raise ValueError(f"Invalid path components: {e!s}") from e


-def validate_path_exists(path: Union[str, Path], file_type: str = "file") -> str:
+def validate_path_exists(path: str | Path, file_type: str = "file") -> str:
     """
     Validate that a path exists and is of the expected type.

@@ -84,24 +82,24 @@ def validate_path_exists(path: Union[str, Path], file_type: str = "file") -> str
     """
     try:
         path_obj = Path(path).resolve()

         if not path_obj.exists():
             raise ValueError(f"Path does not exist: {path}")

         if file_type == "file" and not path_obj.is_file():
             raise ValueError(f"Path is not a file: {path}")
-        elif file_type == "directory" and not path_obj.is_dir():
+        if file_type == "directory" and not path_obj.is_dir():
             raise ValueError(f"Path is not a directory: {path}")

         return str(path_obj)

     except Exception as e:
         if isinstance(e, ValueError):
             raise
-        raise ValueError(f"Invalid path: {str(e)}")
+        raise ValueError(f"Invalid path: {e!s}") from e


-def list_files(directory: Union[str, Path], pattern: str = "*") -> List[str]:
+def list_files(directory: str | Path, pattern: str = "*") -> list[str]:
     """
     Safely list files in a directory matching a pattern.

@@ -126,10 +124,10 @@ def list_files(directory: Union[str, Path], pattern: str = "*") -> List[str]:
         dir_path = Path(directory).resolve()
         if not dir_path.is_dir():
             raise ValueError(f"Not a directory: {directory}")

         return [str(p) for p in dir_path.glob(pattern) if p.is_file()]

     except Exception as e:
         if isinstance(e, ValueError):
             raise
-        raise ValueError(f"Error listing files: {str(e)}")
+        raise ValueError(f"Error listing files: {e!s}") from e
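The behavior contract of `safe_path_join` is documented above: join, resolve, and reject anything that escapes the root. A short sketch of both outcomes, using paths chosen for illustration:

```python
import os

from crewai.flow.path_utils import safe_path_join

root = os.getcwd()
print(safe_path_join("assets", "logo.svg", root=root))  # -> "<cwd>/assets/logo.svg"

try:
    # Resolves outside `root`, so the traversal check raises.
    safe_path_join("..", "etc", "passwd", root=root)
except ValueError as e:
    print(f"rejected: {e}")
```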
@@ -4,7 +4,7 @@ CrewAI Flow Persistence.
 This module provides interfaces and implementations for persisting flow states.
 """

-from typing import Any, Dict, TypeVar, Union
+from typing import Any, TypeVar

 from pydantic import BaseModel

@@ -12,7 +12,7 @@ from crewai.flow.persistence.base import FlowPersistence
 from crewai.flow.persistence.decorators import persist
 from crewai.flow.persistence.sqlite import SQLiteFlowPersistence

-__all__ = ["FlowPersistence", "persist", "SQLiteFlowPersistence"]
+__all__ = ["FlowPersistence", "SQLiteFlowPersistence", "persist"]

-StateType = TypeVar('StateType', bound=Union[Dict[str, Any], BaseModel])
-DictStateType = Dict[str, Any]
+StateType = TypeVar("StateType", bound=dict[str, Any] | BaseModel)
+DictStateType = dict[str, Any]
@@ -1,53 +1,47 @@
 """Base class for flow state persistence."""

 import abc
-from typing import Any, Dict, Optional, Union
+from typing import Any

 from pydantic import BaseModel


 class FlowPersistence(abc.ABC):
     """Abstract base class for flow state persistence.

     This class defines the interface that all persistence implementations must follow.
     It supports both structured (Pydantic BaseModel) and unstructured (dict) states.
     """

     @abc.abstractmethod
     def init_db(self) -> None:
         """Initialize the persistence backend.

         This method should handle any necessary setup, such as:
         - Creating tables
         - Establishing connections
         - Setting up indexes
         """
-        pass

     @abc.abstractmethod
     def save_state(
-        self,
-        flow_uuid: str,
-        method_name: str,
-        state_data: Union[Dict[str, Any], BaseModel]
+        self, flow_uuid: str, method_name: str, state_data: dict[str, Any] | BaseModel
     ) -> None:
         """Persist the flow state after method completion.

         Args:
             flow_uuid: Unique identifier for the flow instance
             method_name: Name of the method that just completed
             state_data: Current state data (either dict or Pydantic model)
         """
-        pass

     @abc.abstractmethod
-    def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
+    def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
         """Load the most recent state for a given flow UUID.

         Args:
             flow_uuid: Unique identifier for the flow instance

         Returns:
             The most recent state as a dictionary, or None if no state exists
         """
-        pass
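The abstract interface above is small enough that a custom backend is a few lines. A toy in-memory implementation, sketched under the assumption that `init_db` is called before any `save_state` (as the interface's docstring implies):

```python
from typing import Any

from pydantic import BaseModel

from crewai.flow.persistence.base import FlowPersistence


class InMemoryFlowPersistence(FlowPersistence):
    """Illustrative backend: keeps only the latest state per flow UUID."""

    def init_db(self) -> None:
        # No tables or connections needed; just a dict keyed by flow UUID.
        self._states: dict[str, dict[str, Any]] = {}

    def save_state(
        self, flow_uuid: str, method_name: str, state_data: dict[str, Any] | BaseModel
    ) -> None:
        # Normalize structured (BaseModel) and unstructured (dict) states.
        data = state_data.model_dump() if isinstance(state_data, BaseModel) else state_data
        self._states[flow_uuid] = dict(data)

    def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
        return self._states.get(flow_uuid)
```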
@@ -24,13 +24,10 @@ Example:
 import asyncio
 import functools
 import logging
+from collections.abc import Callable
 from typing import (
     Any,
-    Callable,
-    Optional,
-    Type,
     TypeVar,
-    Union,
     cast,
 )

@@ -48,7 +45,7 @@ LOG_MESSAGES = {
     "save_state": "Saving flow state to memory for ID: {}",
     "save_error": "Failed to persist state for method {}: {}",
     "state_missing": "Flow instance has no state",
-    "id_missing": "Flow state must have an 'id' field for persistence"
+    "id_missing": "Flow state must have an 'id' field for persistence",
 }


@@ -58,7 +55,13 @@ class PersistenceDecorator:
     _printer = Printer()  # Class-level printer instance

     @classmethod
-    def persist_state(cls, flow_instance: Any, method_name: str, persistence_instance: FlowPersistence, verbose: bool = False) -> None:
+    def persist_state(
+        cls,
+        flow_instance: Any,
+        method_name: str,
+        persistence_instance: FlowPersistence,
+        verbose: bool = False,
+    ) -> None:
         """Persist flow state with proper error handling and logging.

         This method handles the persistence of flow state data, including proper
@@ -76,22 +79,24 @@ class PersistenceDecorator:
             AttributeError: If flow instance lacks required state attributes
         """
         try:
-            state = getattr(flow_instance, 'state', None)
+            state = getattr(flow_instance, "state", None)
             if state is None:
                 raise ValueError("Flow instance has no state")

-            flow_uuid: Optional[str] = None
+            flow_uuid: str | None = None
             if isinstance(state, dict):
-                flow_uuid = state.get('id')
+                flow_uuid = state.get("id")
             elif isinstance(state, BaseModel):
-                flow_uuid = getattr(state, 'id', None)
+                flow_uuid = getattr(state, "id", None)

             if not flow_uuid:
                 raise ValueError("Flow state must have an 'id' field for persistence")

             # Log state saving only if verbose is True
             if verbose:
-                cls._printer.print(LOG_MESSAGES["save_state"].format(flow_uuid), color="cyan")
+                cls._printer.print(
+                    LOG_MESSAGES["save_state"].format(flow_uuid), color="cyan"
+                )
                 logger.info(LOG_MESSAGES["save_state"].format(flow_uuid))

             try:
@@ -104,12 +109,12 @@ class PersistenceDecorator:
                 error_msg = LOG_MESSAGES["save_error"].format(method_name, str(e))
                 cls._printer.print(error_msg, color="red")
                 logger.error(error_msg)
-                raise RuntimeError(f"State persistence failed: {str(e)}") from e
-        except AttributeError:
+                raise RuntimeError(f"State persistence failed: {e!s}") from e
+        except AttributeError as e:
             error_msg = LOG_MESSAGES["state_missing"]
             cls._printer.print(error_msg, color="red")
             logger.error(error_msg)
-            raise ValueError(error_msg)
+            raise ValueError(error_msg) from e
         except (TypeError, ValueError) as e:
             error_msg = LOG_MESSAGES["id_missing"]
             cls._printer.print(error_msg, color="red")
@@ -117,7 +122,7 @@ class PersistenceDecorator:
             raise ValueError(error_msg) from e


-def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False):
+def persist(persistence: FlowPersistence | None = None, verbose: bool = False):
     """Decorator to persist flow state.

     This decorator can be applied at either the class level or method level.
@@ -144,111 +149,151 @@ def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False
         def begin(self):
             pass
     """
-    def decorator(target: Union[Type, Callable[..., T]]) -> Union[Type, Callable[..., T]]:
+
+    def decorator(target: type | Callable[..., T]) -> type | Callable[..., T]:
         """Decorator that handles both class and method decoration."""
         actual_persistence = persistence or SQLiteFlowPersistence()

         if isinstance(target, type):
             # Class decoration
-            original_init = getattr(target, "__init__")
+            original_init = target.__init__  # type: ignore[misc]

             @functools.wraps(original_init)
             def new_init(self: Any, *args: Any, **kwargs: Any) -> None:
-                if 'persistence' not in kwargs:
-                    kwargs['persistence'] = actual_persistence
+                if "persistence" not in kwargs:
+                    kwargs["persistence"] = actual_persistence
                 original_init(self, *args, **kwargs)

-            setattr(target, "__init__", new_init)
+            target.__init__ = new_init  # type: ignore[misc]

             # Store original methods to preserve their decorators
-            original_methods = {}
-
-            for name, method in target.__dict__.items():
-                if callable(method) and (
-                    hasattr(method, "__is_start_method__") or
-                    hasattr(method, "__trigger_methods__") or
-                    hasattr(method, "__condition_type__") or
-                    hasattr(method, "__is_flow_method__") or
-                    hasattr(method, "__is_router__")
-                ):
-                    original_methods[name] = method
+            original_methods = {
+                name: method
+                for name, method in target.__dict__.items()
+                if callable(method)
+                and (
+                    hasattr(method, "__is_start_method__")
+                    or hasattr(method, "__trigger_methods__")
+                    or hasattr(method, "__condition_type__")
+                    or hasattr(method, "__is_flow_method__")
+                    or hasattr(method, "__is_router__")
+                )
+            }

             # Create wrapped versions of the methods that include persistence
             for name, method in original_methods.items():
                 if asyncio.iscoroutinefunction(method):
                     # Create a closure to capture the current name and method
-                    def create_async_wrapper(method_name: str, original_method: Callable):
+                    def create_async_wrapper(
+                        method_name: str, original_method: Callable
+                    ):
                         @functools.wraps(original_method)
-                        async def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
+                        async def method_wrapper(
+                            self: Any, *args: Any, **kwargs: Any
+                        ) -> Any:
                             result = await original_method(self, *args, **kwargs)
-                            PersistenceDecorator.persist_state(self, method_name, actual_persistence, verbose)
+                            PersistenceDecorator.persist_state(
+                                self, method_name, actual_persistence, verbose
+                            )
                             return result

                         return method_wrapper

                     wrapped = create_async_wrapper(name, method)

                     # Preserve all original decorators and attributes
-                    for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
+                    for attr in [
+                        "__is_start_method__",
+                        "__trigger_methods__",
+                        "__condition_type__",
+                        "__is_router__",
+                    ]:
                         if hasattr(method, attr):
                             setattr(wrapped, attr, getattr(method, attr))
-                    setattr(wrapped, "__is_flow_method__", True)
+                    wrapped.__is_flow_method__ = True  # type: ignore[attr-defined]

                     # Update the class with the wrapped method
                     setattr(target, name, wrapped)
                 else:
                     # Create a closure to capture the current name and method
-                    def create_sync_wrapper(method_name: str, original_method: Callable):
+                    def create_sync_wrapper(
+                        method_name: str, original_method: Callable
+                    ):
                         @functools.wraps(original_method)
                         def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
                             result = original_method(self, *args, **kwargs)
-                            PersistenceDecorator.persist_state(self, method_name, actual_persistence, verbose)
+                            PersistenceDecorator.persist_state(
+                                self, method_name, actual_persistence, verbose
+                            )
                             return result

                         return method_wrapper

                     wrapped = create_sync_wrapper(name, method)

                     # Preserve all original decorators and attributes
-                    for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
+                    for attr in [
+                        "__is_start_method__",
+                        "__trigger_methods__",
+                        "__condition_type__",
+                        "__is_router__",
+                    ]:
                         if hasattr(method, attr):
                             setattr(wrapped, attr, getattr(method, attr))
-                    setattr(wrapped, "__is_flow_method__", True)
+                    wrapped.__is_flow_method__ = True  # type: ignore[attr-defined]

                     # Update the class with the wrapped method
                     setattr(target, name, wrapped)

             return target
-        else:
-            # Method decoration
-            method = target
-            setattr(method, "__is_flow_method__", True)
-
-            if asyncio.iscoroutinefunction(method):
-                @functools.wraps(method)
-                async def method_async_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
-                    method_coro = method(flow_instance, *args, **kwargs)
-                    if asyncio.iscoroutine(method_coro):
-                        result = await method_coro
-                    else:
-                        result = method_coro
-                    PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose)
-                    return result
-
-                for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
-                    if hasattr(method, attr):
-                        setattr(method_async_wrapper, attr, getattr(method, attr))
-                setattr(method_async_wrapper, "__is_flow_method__", True)
-                return cast(Callable[..., T], method_async_wrapper)
-            else:
-                @functools.wraps(method)
-                def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
-                    result = method(flow_instance, *args, **kwargs)
-                    PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose)
-                    return result
-
-                for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
-                    if hasattr(method, attr):
-                        setattr(method_sync_wrapper, attr, getattr(method, attr))
-                setattr(method_sync_wrapper, "__is_flow_method__", True)
-                return cast(Callable[..., T], method_sync_wrapper)
+        # Method decoration
+        method = target
+        method.__is_flow_method__ = True  # type: ignore[attr-defined]

+        if asyncio.iscoroutinefunction(method):
+
+            @functools.wraps(method)
+            async def method_async_wrapper(
+                flow_instance: Any, *args: Any, **kwargs: Any
+            ) -> T:
+                method_coro = method(flow_instance, *args, **kwargs)
+                if asyncio.iscoroutine(method_coro):
+                    result = await method_coro
+                else:
+                    result = method_coro
+                PersistenceDecorator.persist_state(
+                    flow_instance, method.__name__, actual_persistence, verbose
+                )
+                return result
+
+            for attr in [
+                "__is_start_method__",
+                "__trigger_methods__",
+                "__condition_type__",
+                "__is_router__",
+            ]:
+                if hasattr(method, attr):
+                    setattr(method_async_wrapper, attr, getattr(method, attr))
+            method_async_wrapper.__is_flow_method__ = True  # type: ignore[attr-defined]
+            return cast(Callable[..., T], method_async_wrapper)
+
+        @functools.wraps(method)
+        def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
+            result = method(flow_instance, *args, **kwargs)
+            PersistenceDecorator.persist_state(
+                flow_instance, method.__name__, actual_persistence, verbose
+            )
+            return result
+
+        for attr in [
+            "__is_start_method__",
+            "__trigger_methods__",
+            "__condition_type__",
+            "__is_router__",
+        ]:
+            if hasattr(method, attr):
+                setattr(method_sync_wrapper, attr, getattr(method, attr))
+        method_sync_wrapper.__is_flow_method__ = True  # type: ignore[attr-defined]
+        return cast(Callable[..., T], method_sync_wrapper)

     return decorator
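Per the decorator's own docstring, `@persist` works at class or method level and defaults to `SQLiteFlowPersistence`. A hedged class-level usage sketch; the `db_path` and the dict-state access are illustrative, and it assumes the flow state carries an `id` field (which the decorator requires, raising `ValueError` otherwise):

```python
from crewai.flow.flow import Flow, start
from crewai.flow.persistence import SQLiteFlowPersistence, persist


@persist(SQLiteFlowPersistence(db_path="flow_states.db"), verbose=True)
class CounterFlow(Flow):
    @start()
    def begin(self):
        # State must include an "id" field for persistence; crewAI flows
        # are expected to provide one on their default state.
        self.state["counter"] = self.state.get("counter", 0) + 1


# After each decorated method completes, the state is written to SQLite,
# so a restarted flow can be resumed from its last persisted snapshot.
```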
@@ -6,7 +6,7 @@ import json
 import sqlite3
 from datetime import datetime, timezone
 from pathlib import Path
-from typing import Any, Dict, Optional, Union
+from typing import Any

 from pydantic import BaseModel

@@ -23,7 +23,7 @@ class SQLiteFlowPersistence(FlowPersistence):

     db_path: str

-    def __init__(self, db_path: Optional[str] = None):
+    def __init__(self, db_path: str | None = None):
         """Initialize SQLite persistence.

         Args:
@@ -70,7 +70,7 @@ class SQLiteFlowPersistence(FlowPersistence):
         self,
         flow_uuid: str,
         method_name: str,
-        state_data: Union[Dict[str, Any], BaseModel],
+        state_data: dict[str, Any] | BaseModel,
     ) -> None:
         """Save the current flow state to SQLite.

@@ -107,7 +107,7 @@ class SQLiteFlowPersistence(FlowPersistence):
         ),
     )

-    def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
+    def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
         """Load the most recent state for a given flow UUID.

         Args:
@@ -5,6 +5,7 @@ the Flow system.
 """

 from typing import Any, TypedDict
+
 from typing_extensions import NotRequired, Required
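For readers unfamiliar with these imports: `Required` and `NotRequired` mark individual `TypedDict` keys as mandatory or optional. The flow type definitions themselves are not shown in this diff, so the following is purely illustrative of why both markers are imported:

```python
from typing import Any, TypedDict

from typing_extensions import NotRequired, Required


class FlowMethodInfo(TypedDict):
    """Hypothetical shape, for illustration only."""

    name: Required[str]  # must always be present
    metadata: NotRequired[dict[str, Any]]  # may be omitted entirely
```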
@@ -17,10 +17,10 @@ import ast
 import inspect
 import textwrap
 from collections import defaultdict, deque
-from typing import Any, Deque, Dict, List, Optional, Set, Union
+from typing import Any


-def get_possible_return_constants(function: Any) -> Optional[List[str]]:
+def get_possible_return_constants(function: Any) -> list[str] | None:
     try:
         source = inspect.getsource(function)
     except OSError:
@@ -58,12 +58,12 @@ def get_possible_return_constants(function: Any) -> Optional[List[str]]:
                 target = node.targets[0]
                 if isinstance(target, ast.Name):
                     var_name = target.id
-                    dict_values = []
                     # Extract string values from the dictionary
-                    for val in node.value.values:
-                        if isinstance(val, ast.Constant) and isinstance(val.value, str):
-                            dict_values.append(val.value)
-                        # If non-string, skip or just ignore
+                    dict_values = [
+                        val.value
+                        for val in node.value.values
+                        if isinstance(val, ast.Constant) and isinstance(val.value, str)
+                    ]
                     if dict_values:
                         dict_definitions[var_name] = dict_values
             self.generic_visit(node)
@@ -94,7 +94,7 @@ def get_possible_return_constants(function: Any) -> Optional[List[str]]:
     return list(return_values) if return_values else None


-def calculate_node_levels(flow: Any) -> Dict[str, int]:
+def calculate_node_levels(flow: Any) -> dict[str, int]:
     """
     Calculate the hierarchical level of each node in the flow.

@@ -118,10 +118,10 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]:
     - Handles both OR and AND conditions for listeners
     - Processes router paths separately
     """
-    levels: Dict[str, int] = {}
-    queue: Deque[str] = deque()
-    visited: Set[str] = set()
-    pending_and_listeners: Dict[str, Set[str]] = {}
+    levels: dict[str, int] = {}
+    queue: deque[str] = deque()
+    visited: set[str] = set()
+    pending_and_listeners: dict[str, set[str]] = {}

     # Make all start methods at level 0
     for method_name, method in flow._methods.items():
@@ -172,7 +172,7 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]:
     return levels


-def count_outgoing_edges(flow: Any) -> Dict[str, int]:
+def count_outgoing_edges(flow: Any) -> dict[str, int]:
     """
     Count the number of outgoing edges for each method in the flow.

@@ -197,7 +197,7 @@ def count_outgoing_edges(flow: Any) -> Dict[str, int]:
     return counts


-def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:
+def build_ancestor_dict(flow: Any) -> dict[str, set[str]]:
     """
     Build a dictionary mapping each node to its ancestor nodes.

@@ -211,8 +211,8 @@ def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:
     Dict[str, Set[str]]
         Dictionary mapping each node to a set of its ancestor nodes.
     """
-    ancestors: Dict[str, Set[str]] = {node: set() for node in flow._methods}
-    visited: Set[str] = set()
+    ancestors: dict[str, set[str]] = {node: set() for node in flow._methods}
+    visited: set[str] = set()
     for node in flow._methods:
         if node not in visited:
             dfs_ancestors(node, ancestors, visited, flow)
@@ -220,7 +220,7 @@ def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:


 def dfs_ancestors(
-    node: str, ancestors: Dict[str, Set[str]], visited: Set[str], flow: Any
+    node: str, ancestors: dict[str, set[str]], visited: set[str], flow: Any
 ) -> None:
     """
     Perform depth-first search to build ancestor relationships.

@@ -265,7 +265,7 @@ def dfs_ancestors(


 def is_ancestor(
-    node: str, ancestor_candidate: str, ancestors: Dict[str, Set[str]]
+    node: str, ancestor_candidate: str, ancestors: dict[str, set[str]]
 ) -> bool:
     """
     Check if one node is an ancestor of another.

@@ -287,7 +287,7 @@ def is_ancestor(
     return ancestor_candidate in ancestors.get(node, set())


-def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
+def build_parent_children_dict(flow: Any) -> dict[str, list[str]]:
     """
     Build a dictionary mapping parent nodes to their children.

@@ -307,7 +307,7 @@ def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
     - Maps router methods to their paths and listeners
     - Children lists are sorted for consistent ordering
     """
-    parent_children: Dict[str, List[str]] = {}
+    parent_children: dict[str, list[str]] = {}

     # Map listeners to their trigger methods
     for listener_name, (_, trigger_methods) in flow._listeners.items():
@@ -332,7 +332,7 @@ def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:


 def get_child_index(
-    parent: str, child: str, parent_children: Dict[str, List[str]]
+    parent: str, child: str, parent_children: dict[str, list[str]]
 ) -> int:
     """
     Get the index of a child node in its parent's sorted children list.

@@ -364,7 +364,7 @@ def process_router_paths(flow, current, current_level, levels, queue):
     paths = flow._router_paths.get(current, [])
     for path in paths:
         for listener_name, (
-            condition_type,
+            _condition_type,
             trigger_methods,
         ) in flow._listeners.items():
             if path in trigger_methods:
@@ -17,7 +17,7 @@ Example

 import ast
 import inspect
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any

 from .utils import (
     build_ancestor_dict,
@@ -56,6 +56,7 @@ def method_calls_crew(method: Any) -> bool:

     class CrewCallVisitor(ast.NodeVisitor):
         """AST visitor to detect .crew() method calls."""
+
         def __init__(self):
             self.found = False

@@ -73,8 +74,8 @@ def method_calls_crew(method: Any) -> bool:
 def add_nodes_to_network(
     net: Any,
     flow: Any,
-    node_positions: Dict[str, Tuple[float, float]],
-    node_styles: Dict[str, Dict[str, Any]]
+    node_positions: dict[str, tuple[float, float]],
+    node_styles: dict[str, dict[str, Any]],
 ) -> None:
     """
     Add nodes to the network visualization with appropriate styling.
@@ -98,6 +99,7 @@ def add_nodes_to_network(
     - Crew methods
     - Regular methods
     """
+
     def human_friendly_label(method_name):
         return method_name.replace("_", " ").title()

@@ -138,10 +140,10 @@ def add_nodes_to_network(

 def compute_positions(
     flow: Any,
-    node_levels: Dict[str, int],
+    node_levels: dict[str, int],
     y_spacing: float = 150,
-    x_spacing: float = 300
-) -> Dict[str, Tuple[float, float]]:
+    x_spacing: float = 300,
+) -> dict[str, tuple[float, float]]:
     """
     Compute the (x, y) positions for each node in the flow graph.

@@ -161,8 +163,8 @@ def compute_positions(
     Dict[str, Tuple[float, float]]
         Dictionary mapping node names to their (x, y) coordinates.
     """
-    level_nodes: Dict[int, List[str]] = {}
-    node_positions: Dict[str, Tuple[float, float]] = {}
+    level_nodes: dict[int, list[str]] = {}
+    node_positions: dict[str, tuple[float, float]] = {}

     for method_name, level in node_levels.items():
         level_nodes.setdefault(level, []).append(method_name)
@@ -180,10 +182,10 @@ def compute_positions(
 def add_edges(
     net: Any,
     flow: Any,
-    node_positions: Dict[str, Tuple[float, float]],
-    colors: Dict[str, str]
+    node_positions: dict[str, tuple[float, float]],
+    colors: dict[str, str],
 ) -> None:
-    edge_smooth: Dict[str, Union[str, float]] = {"type": "continuous"}  # Default value
+    edge_smooth: dict[str, str | float] = {"type": "continuous"}  # Default value
     """
     Add edges to the network visualization with appropriate styling.

@@ -269,7 +271,7 @@ def add_edges(
     for router_method_name, paths in flow._router_paths.items():
         for path in paths:
             for listener_name, (
-                condition_type,
+                _condition_type,
                 trigger_methods,
             ) in flow._listeners.items():
                 if path in trigger_methods:
@@ -1,10 +1,10 @@
 import os
-from typing import Any

 from pydantic import BaseModel, ConfigDict, Field

 from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
 from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
+from crewai.rag.embeddings.types import EmbedderConfig
 from crewai.rag.types import SearchResult

 os.environ["TOKENIZERS_PARALLELISM"] = "false"  # removes logging from fastembed
@@ -16,20 +16,20 @@ class Knowledge(BaseModel):
     Args:
         sources: list[BaseKnowledgeSource] = Field(default_factory=list)
         storage: KnowledgeStorage | None = Field(default=None)
-        embedder: dict[str, Any] | None = None
+        embedder: EmbedderConfig | None = None
     """

     sources: list[BaseKnowledgeSource] = Field(default_factory=list)
     model_config = ConfigDict(arbitrary_types_allowed=True)
     storage: KnowledgeStorage | None = Field(default=None)
-    embedder: dict[str, Any] | None = None
+    embedder: EmbedderConfig | None = None
     collection_name: str | None = None

     def __init__(
         self,
         collection_name: str,
         sources: list[BaseKnowledgeSource],
-        embedder: dict[str, Any] | None = None,
+        embedder: EmbedderConfig | None = None,
         storage: KnowledgeStorage | None = None,
         **data,
     ):
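A hedged construction sketch for the typed `Knowledge` API above. `StringKnowledgeSource` is crewAI's built-in string source; the exact `EmbedderConfig` shape is not shown in this diff, so the provider/config dict below is an assumption:

```python
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource

source = StringKnowledgeSource(content="CrewAI ships a native RAG implementation.")
kb = Knowledge(
    collection_name="docs",
    sources=[source],
    # Embedder spec shape assumed for illustration; see EmbedderConfig for
    # the accepted providers and options in your installed version.
    embedder={"provider": "openai", "config": {"model": "text-embedding-3-small"}},
)
```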
@@ -8,7 +8,9 @@ from crewai.rag.chromadb.config import ChromaDBConfig
 from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
 from crewai.rag.config.utils import get_rag_client
 from crewai.rag.core.base_client import BaseClient
-from crewai.rag.embeddings.factory import get_embedding_function
+from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
+from crewai.rag.embeddings.factory import build_embedder
+from crewai.rag.embeddings.types import ProviderSpec
 from crewai.rag.factory import create_client
 from crewai.rag.types import BaseRecord, SearchResult
 from crewai.utilities.logger import Logger
@@ -22,7 +24,10 @@ class KnowledgeStorage(BaseKnowledgeStorage):

     def __init__(
         self,
-        embedder: dict[str, Any] | None = None,
+        embedder: ProviderSpec
+        | BaseEmbeddingsProvider
+        | type[BaseEmbeddingsProvider]
+        | None = None,
         collection_name: str | None = None,
     ) -> None:
         self.collection_name = collection_name
@@ -35,7 +40,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
         )

         if embedder:
-            embedding_function = get_embedding_function(embedder)
+            embedding_function = build_embedder(embedder)  # type: ignore[arg-type]
             config = ChromaDBConfig(
                 embedding_function=cast(
                     ChromaEmbeddingFunctionWrapper, embedding_function
@@ -367,6 +367,7 @@ class LiteAgent(FlowTrackable, BaseModel):
             output=output,
             guardrail=self._guardrail,
             retry_count=self._guardrail_retry_count,
+            event_source=self,
         )

         if not guardrail_result.success:
@@ -1,28 +1,26 @@
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
Any,
|
||||
DefaultDict,
|
||||
Dict,
|
||||
List,
|
||||
Final,
|
||||
Literal,
|
||||
Optional,
|
||||
Type,
|
||||
TextIO,
|
||||
TypedDict,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from datetime import datetime
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from litellm.types.utils import ChatCompletionDeltaToolCall
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.llm_events import (
|
||||
LLMCallCompletedEvent,
|
||||
LLMCallFailedEvent,
|
||||
@@ -31,15 +29,19 @@ from crewai.events.types.llm_events import (
|
||||
LLMStreamChunkEvent,
|
||||
)
|
||||
from crewai.events.types.tool_usage_events import (
|
||||
ToolUsageStartedEvent,
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageErrorEvent,
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageStartedEvent,
|
||||
)
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededError,
|
||||
)
|
||||
from crewai.utilities.logger_utils import suppress_warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", UserWarning)
|
||||
with suppress_warnings():
|
||||
import litellm
|
||||
from litellm import Choices
|
||||
from litellm import Choices, CustomLogger
|
||||
from litellm.exceptions import ContextWindowExceededError
|
||||
from litellm.litellm_core_utils.get_supported_openai_params import (
|
||||
get_supported_openai_params,
|
||||
@@ -47,16 +49,6 @@ with warnings.catch_warnings():
|
||||
from litellm.types.utils import ModelResponse
|
||||
from litellm.utils import supports_response_schema
|
||||
|
||||
|
||||
import io
|
||||
from typing import TextIO
|
||||
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededException,
|
||||
)
|
||||
|
||||
load_dotenv()
|
||||
|
||||
litellm.suppress_debug_info = True
|
||||
@@ -126,7 +118,11 @@ if not isinstance(sys.stderr, FilteredStream):
|
||||
sys.stderr = FilteredStream(sys.stderr)
|
||||
|
||||
|
||||
LLM_CONTEXT_WINDOW_SIZES = {
|
||||
MIN_CONTEXT: Final[int] = 1024
|
||||
MAX_CONTEXT: Final[int] = 2097152 # Current max from gemini-1.5-pro
|
||||
ANTHROPIC_PREFIXES: Final[tuple[str, str, str]] = ("anthropic/", "claude-", "claude/")
|
||||
|
||||
LLM_CONTEXT_WINDOW_SIZES: Final[dict[str, int]] = {
|
||||
# openai
|
||||
"gpt-4": 8192,
|
||||
"gpt-4o": 128000,
|
||||
@@ -252,30 +248,19 @@ LLM_CONTEXT_WINDOW_SIZES = {
|
||||
"mistral/mistral-large-2402": 32768,
|
||||
}
|
||||
|
||||
DEFAULT_CONTEXT_WINDOW_SIZE = 8192
|
||||
CONTEXT_WINDOW_USAGE_RATIO = 0.85
|
||||
|
||||
|
||||
@contextmanager
|
||||
def suppress_warnings():
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore")
|
||||
warnings.filterwarnings(
|
||||
"ignore", message="open_text is deprecated*", category=DeprecationWarning
|
||||
)
|
||||
|
||||
yield
|
||||
DEFAULT_CONTEXT_WINDOW_SIZE: Final[int] = 8192
|
||||
CONTEXT_WINDOW_USAGE_RATIO: Final[float] = 0.85
|
||||
|
||||
|
||||
class Delta(TypedDict):
|
||||
content: Optional[str]
|
||||
role: Optional[str]
|
||||
content: str | None
|
||||
role: str | None
|
||||
|
||||
|
||||
class StreamingChoices(TypedDict):
|
||||
delta: Delta
|
||||
index: int
|
||||
finish_reason: Optional[str]
|
||||
finish_reason: str | None
|
||||
|
||||
|
||||
class FunctionArgs(BaseModel):
|
||||
@@ -288,31 +273,31 @@ class AccumulatedToolArgs(BaseModel):
|
||||
|
||||
|
||||
class LLM(BaseLLM):
|
||||
completion_cost: Optional[float] = None
|
||||
completion_cost: float | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
timeout: Optional[Union[float, int]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
n: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
max_completion_tokens: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
logit_bias: Optional[Dict[int, float]] = None,
|
||||
response_format: Optional[Type[BaseModel]] = None,
|
||||
seed: Optional[int] = None,
|
||||
logprobs: Optional[int] = None,
|
||||
top_logprobs: Optional[int] = None,
|
||||
base_url: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
callbacks: List[Any] | None = None,
|
||||
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
|
||||
timeout: float | int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
n: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[int, float] | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
seed: int | None = None,
|
||||
logprobs: int | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
base_url: str | None = None,
|
||||
api_base: str | None = None,
|
||||
api_version: str | None = None,
|
||||
api_key: str | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
):
@@ -345,7 +330,7 @@ class LLM(BaseLLM):

         # Normalize self.stop to always be a List[str]
         if stop is None:
-            self.stop: List[str] = []
+            self.stop: list[str] = []
         elif isinstance(stop, str):
             self.stop = [stop]
         else:
@@ -354,7 +339,8 @@ class LLM(BaseLLM):
         self.set_callbacks(callbacks or [])
         self.set_env_callbacks()

-    def _is_anthropic_model(self, model: str) -> bool:
+    @staticmethod
+    def _is_anthropic_model(model: str) -> bool:
         """Determine if the model is from Anthropic provider.

         Args:
@@ -363,21 +349,18 @@ class LLM(BaseLLM):
         Returns:
             bool: True if the model is from Anthropic, False otherwise.
         """
-        ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/")
         return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES)
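With `ANTHROPIC_PREFIXES` hoisted to a module-level `Final`, the check no longer needs `self` and becomes a `@staticmethod`. A quick sketch of the matching behavior (model strings are examples, not an exhaustive list):

```python
ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/")

def is_anthropic(model: str) -> bool:
    # Substring containment rather than str.startswith, so routed ids like
    # "openrouter/anthropic/claude-3-opus" are detected as well.
    return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES)

assert is_anthropic("claude-3-5-sonnet-20241022")
assert is_anthropic("openrouter/anthropic/claude-3-opus")
assert not is_anthropic("gpt-4o")
```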

     def _prepare_completion_params(
         self,
-        messages: Union[str, List[Dict[str, str]]],
-        tools: Optional[List[dict]] = None,
-    ) -> Dict[str, Any]:
+        messages: str | list[dict[str, str]],
+        tools: list[dict] | None = None,
+    ) -> dict[str, Any]:
         """Prepare parameters for the completion call.

         Args:
             messages: Input messages for the LLM
             tools: Optional list of tool schemas
-            callbacks: Optional list of callback functions
-            available_functions: Optional dict of available functions

         Returns:
             Dict[str, Any]: Parameters for the completion call
@@ -419,11 +402,11 @@ class LLM(BaseLLM):

     def _handle_streaming_response(
         self,
-        params: Dict[str, Any],
-        callbacks: Optional[List[Any]] = None,
-        available_functions: Optional[Dict[str, Any]] = None,
-        from_task: Optional[Any] = None,
-        from_agent: Optional[Any] = None,
+        params: dict[str, Any],
+        callbacks: list[Any] | None = None,
+        available_functions: dict[str, Any] | None = None,
+        from_task: Any | None = None,
+        from_agent: Any | None = None,
     ) -> str:
         """Handle a streaming response from the LLM.

@@ -445,9 +428,8 @@ class LLM(BaseLLM):
         last_chunk = None
-        chunk_count = 0
         usage_info = None
         tool_calls = None

-        accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs] = defaultdict(
+        accumulated_tool_args: defaultdict[int, AccumulatedToolArgs] = defaultdict(
             AccumulatedToolArgs
         )
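`defaultdict` has been subscriptable directly since Python 3.9, so the `typing.DefaultDict` alias is dropped. The accumulator pattern itself is unchanged: streamed tool-call fragments arrive keyed by index, and the first access creates an empty entry. A self-contained sketch with a stand-in for `AccumulatedToolArgs`:

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass
class ToolArgs:  # stand-in for AccumulatedToolArgs
    name: str = ""
    arguments: str = ""


acc: defaultdict[int, ToolArgs] = defaultdict(ToolArgs)
for index, fragment in [(0, '{"query": "crew'), (0, 'ai"}')]:
    acc[index].arguments += fragment  # first access auto-creates ToolArgs()

assert acc[0].arguments == '{"query": "crewai"}'
```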
@@ -472,16 +454,16 @@ class LLM(BaseLLM):
                     choices = chunk["choices"]
                 elif hasattr(chunk, "choices"):
                     # Check if choices is not a type but an actual attribute with value
-                    if not isinstance(getattr(chunk, "choices"), type):
-                        choices = getattr(chunk, "choices")
+                    if not isinstance(chunk.choices, type):
+                        choices = chunk.choices

                 # Try to extract usage information if available
                 if isinstance(chunk, dict) and "usage" in chunk:
                     usage_info = chunk["usage"]
                 elif hasattr(chunk, "usage"):
                     # Check if usage is not a type but an actual attribute with value
-                    if not isinstance(getattr(chunk, "usage"), type):
-                        usage_info = getattr(chunk, "usage")
+                    if not isinstance(chunk.usage, type):
+                        usage_info = chunk.usage

                 if choices and len(choices) > 0:
                     choice = choices[0]
@@ -491,7 +473,7 @@ class LLM(BaseLLM):
                     if isinstance(choice, dict) and "delta" in choice:
                         delta = choice["delta"]
                     elif hasattr(choice, "delta"):
-                        delta = getattr(choice, "delta")
+                        delta = choice.delta

                     # Extract content from delta
                     if delta:
@@ -501,7 +483,7 @@ class LLM(BaseLLM):
                             chunk_content = delta["content"]
                         # Handle object format
                         elif hasattr(delta, "content"):
-                            chunk_content = getattr(delta, "content")
+                            chunk_content = delta.content

                         # Handle case where content might be None or empty
                         if chunk_content is None and isinstance(delta, dict):
@@ -533,7 +515,9 @@ class LLM(BaseLLM):
                             full_response += chunk_content

                             # Emit the chunk event
-                            assert hasattr(crewai_event_bus, "emit")
+                            if not hasattr(crewai_event_bus, "emit"):
+                                raise Exception("crewai_event_bus must have an `emit` method")
+
                             crewai_event_bus.emit(
                                 self,
                                 event=LLMStreamChunkEvent(
@@ -572,8 +556,8 @@ class LLM(BaseLLM):
                 if isinstance(last_chunk, dict) and "choices" in last_chunk:
                     choices = last_chunk["choices"]
                 elif hasattr(last_chunk, "choices"):
-                    if not isinstance(getattr(last_chunk, "choices"), type):
-                        choices = getattr(last_chunk, "choices")
+                    if not isinstance(last_chunk.choices, type):
+                        choices = last_chunk.choices

                 if choices and len(choices) > 0:
                     choice = choices[0]
@@ -583,14 +567,14 @@ class LLM(BaseLLM):
                     if isinstance(choice, dict) and "message" in choice:
                         message = choice["message"]
                     elif hasattr(choice, "message"):
-                        message = getattr(choice, "message")
+                        message = choice.message

                     if message:
                         content = None
                         if isinstance(message, dict) and "content" in message:
                             content = message["content"]
                         elif hasattr(message, "content"):
-                            content = getattr(message, "content")
+                            content = message.content

                         if content:
                             full_response = content
@@ -617,8 +601,8 @@ class LLM(BaseLLM):
                 if isinstance(last_chunk, dict) and "choices" in last_chunk:
                     choices = last_chunk["choices"]
                 elif hasattr(last_chunk, "choices"):
-                    if not isinstance(getattr(last_chunk, "choices"), type):
-                        choices = getattr(last_chunk, "choices")
+                    if not isinstance(last_chunk.choices, type):
+                        choices = last_chunk.choices

                 if choices and len(choices) > 0:
                     choice = choices[0]
@@ -627,13 +611,13 @@ class LLM(BaseLLM):
                     if isinstance(choice, dict) and "message" in choice:
                         message = choice["message"]
                     elif hasattr(choice, "message"):
-                        message = getattr(choice, "message")
+                        message = choice.message

                     if message:
                         if isinstance(message, dict) and "tool_calls" in message:
                             tool_calls = message["tool_calls"]
                         elif hasattr(message, "tool_calls"):
-                            tool_calls = getattr(message, "tool_calls")
+                            tool_calls = message.tool_calls
             except Exception as e:
                 logging.debug(f"Error checking for tool calls: {e}")
             # --- 8) If no tool calls or no available functions, return the text response directly
@@ -673,11 +657,11 @@ class LLM(BaseLLM):
             # Catch context window errors from litellm and convert them to our own exception type.
             # This exception is handled by CrewAgentExecutor._invoke_loop() which can then
             # decide whether to summarize the content or abort based on the respect_context_window flag.
-            raise LLMContextLengthExceededException(str(e))
+            raise LLMContextLengthExceededError(str(e)) from e
         except Exception as e:
-            logging.error(f"Error in streaming response: {str(e)}")
+            logging.error(f"Error in streaming response: {e!s}")
             if full_response.strip():
-                logging.warning(f"Returning partial response despite error: {str(e)}")
+                logging.warning(f"Returning partial response despite error: {e!s}")
                 self._handle_emit_call_events(
                     response=full_response,
                     call_type=LLMCallType.LLM_CALL,
@@ -688,22 +672,25 @@ class LLM(BaseLLM):
                 return full_response

             # Emit failed event and re-raise the exception
-            assert hasattr(crewai_event_bus, "emit")
+            if not hasattr(crewai_event_bus, "emit"):
+                raise AttributeError(
+                    "crewai_event_bus must have an 'emit' method"
+                ) from e
             crewai_event_bus.emit(
                 self,
                 event=LLMCallFailedEvent(
                     error=str(e), from_task=from_task, from_agent=from_agent
                 ),
             )
-            raise Exception(f"Failed to get streaming response: {str(e)}")
+            raise Exception(f"Failed to get streaming response: {e!s}") from e

     def _handle_streaming_tool_calls(
         self,
-        tool_calls: List[ChatCompletionDeltaToolCall],
-        accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs],
-        available_functions: Optional[Dict[str, Any]] = None,
-        from_task: Optional[Any] = None,
-        from_agent: Optional[Any] = None,
+        tool_calls: list[ChatCompletionDeltaToolCall],
+        accumulated_tool_args: defaultdict[int, AccumulatedToolArgs],
+        available_functions: dict[str, Any] | None = None,
+        from_task: Any | None = None,
+        from_agent: Any | None = None,
     ) -> None | str:
         for tool_call in tool_calls:
             current_tool_accumulator = accumulated_tool_args[tool_call.index]
@@ -715,7 +702,8 @@ class LLM(BaseLLM):
                 current_tool_accumulator.function.arguments += (
                     tool_call.function.arguments
                 )
-            assert hasattr(crewai_event_bus, "emit")
+            if not hasattr(crewai_event_bus, "emit"):
+                raise AttributeError("crewai_event_bus must have an 'emit' method")
             crewai_event_bus.emit(
                 self,
                 event=LLMStreamChunkEvent(
@@ -742,11 +730,11 @@ class LLM(BaseLLM):
                 continue
         return None

+    @staticmethod
     def _handle_streaming_callbacks(
-        self,
-        callbacks: Optional[List[Any]],
-        usage_info: Optional[Dict[str, Any]],
-        last_chunk: Optional[Any],
+        callbacks: list[Any] | None,
+        usage_info: dict[str, Any] | None,
+        last_chunk: Any | None,
     ) -> None:
         """Handle callbacks with usage info for streaming responses.

@@ -769,10 +757,8 @@ class LLM(BaseLLM):
                 ):
                     usage_info = last_chunk["usage"]
                 elif hasattr(last_chunk, "usage"):
-                    if not isinstance(
-                        getattr(last_chunk, "usage"), type
-                    ):
-                        usage_info = getattr(last_chunk, "usage")
+                    if not isinstance(last_chunk.usage, type):
+                        usage_info = last_chunk.usage
             except Exception as e:
                 logging.debug(f"Error extracting usage info: {e}")

@@ -786,11 +772,11 @@ class LLM(BaseLLM):

     def _handle_non_streaming_response(
         self,
-        params: Dict[str, Any],
-        callbacks: Optional[List[Any]] = None,
-        available_functions: Optional[Dict[str, Any]] = None,
-        from_task: Optional[Any] = None,
-        from_agent: Optional[Any] = None,
+        params: dict[str, Any],
+        callbacks: list[Any] | None = None,
+        available_functions: dict[str, Any] | None = None,
+        from_task: Any | None = None,
+        from_agent: Any | None = None,
     ) -> str | Any:
         """Handle a non-streaming response from the LLM.

@@ -815,7 +801,7 @@ class LLM(BaseLLM):
         except ContextWindowExceededError as e:
             # Convert litellm's context window error to our own exception type
             # for consistent handling in the rest of the codebase
-            raise LLMContextLengthExceededException(str(e))
+            raise LLMContextLengthExceededError(str(e)) from e
         # --- 2) Extract response message and content
         response_message = cast(Choices, cast(ModelResponse, response).choices)[
             0
@@ -847,7 +833,7 @@ class LLM(BaseLLM):
             )
             return text_response
         # --- 6) If there is no text response, no available functions, but there are tool calls, return the tool calls
-        elif tool_calls and not available_functions and not text_response:
+        if tool_calls and not available_functions and not text_response:
             return tool_calls

         # --- 7) Handle tool calls if present
@@ -868,19 +854,21 @@ class LLM(BaseLLM):

     def _handle_tool_call(
         self,
-        tool_calls: List[Any],
-        available_functions: Optional[Dict[str, Any]] = None,
-        from_task: Optional[Any] = None,
-        from_agent: Optional[Any] = None,
-    ) -> Optional[str]:
+        tool_calls: list[Any],
+        available_functions: dict[str, Any] | None = None,
+        from_task: Any | None = None,
+        from_agent: Any | None = None,
+    ) -> str | None:
         """Handle a tool call from the LLM.

         Args:
             tool_calls: List of tool calls from the LLM
             available_functions: Dict of available functions
+            from_task: Optional Task that invoked the LLM
+            from_agent: Optional Agent that invoked the LLM

         Returns:
-            Optional[str]: The result of the tool call, or None if no tool call was made
+            The result of the tool call, or None if no tool call was made
         """
         # --- 1) Validate tool calls and available functions
         if not tool_calls or not available_functions:
@@ -899,7 +887,8 @@ class LLM(BaseLLM):
         fn = available_functions[function_name]

         # --- 3.2) Execute function
-        assert hasattr(crewai_event_bus, "emit")
+        if not hasattr(crewai_event_bus, "emit"):
+            raise AttributeError("crewai_event_bus must have an 'emit' method")
         started_at = datetime.now()
         crewai_event_bus.emit(
             self,
@@ -939,17 +928,20 @@ class LLM(BaseLLM):
                 function_name, lambda: None
             )  # Ensure fn is always a callable
             logging.error(f"Error executing function '{function_name}': {e}")
-            assert hasattr(crewai_event_bus, "emit")
+            if not hasattr(crewai_event_bus, "emit"):
+                raise AttributeError(
+                    "crewai_event_bus must have an 'emit' method"
+                ) from e
             crewai_event_bus.emit(
                 self,
-                event=LLMCallFailedEvent(error=f"Tool execution error: {str(e)}"),
+                event=LLMCallFailedEvent(error=f"Tool execution error: {e!s}"),
             )
             crewai_event_bus.emit(
                 self,
                 event=ToolUsageErrorEvent(
                     tool_name=function_name,
                     tool_args=function_args,
-                    error=f"Tool execution error: {str(e)}",
+                    error=f"Tool execution error: {e!s}",
                     from_task=from_task,
                     from_agent=from_agent,
                 ),
@@ -958,13 +950,13 @@ class LLM(BaseLLM):

     def call(
         self,
-        messages: Union[str, List[Dict[str, str]]],
-        tools: Optional[List[dict]] = None,
-        callbacks: Optional[List[Any]] = None,
-        available_functions: Optional[Dict[str, Any]] = None,
-        from_task: Optional[Any] = None,
-        from_agent: Optional[Any] = None,
-    ) -> Union[str, Any]:
+        messages: str | list[dict[str, str]],
+        tools: list[dict] | None = None,
+        callbacks: list[Any] | None = None,
+        available_functions: dict[str, Any] | None = None,
+        from_task: Any | None = None,
+        from_agent: Any | None = None,
+    ) -> str | Any:
         """High-level LLM call method.

         Args:
@@ -988,10 +980,11 @@ class LLM(BaseLLM):
         Raises:
             TypeError: If messages format is invalid
             ValueError: If response format is not supported
-            LLMContextLengthExceededException: If input exceeds model's context limit
+            LLMContextLengthExceededError: If input exceeds model's context limit
         """
         # --- 1) Emit call started event
-        assert hasattr(crewai_event_bus, "emit")
+        if not hasattr(crewai_event_bus, "emit"):
+            raise AttributeError("crewai_event_bus must have an 'emit' method")
         crewai_event_bus.emit(
             self,
             event=LLMCallStartedEvent(
@@ -1028,13 +1021,12 @@ class LLM(BaseLLM):
                 return self._handle_streaming_response(
                     params, callbacks, available_functions, from_task, from_agent
                 )
-            else:
-                return self._handle_non_streaming_response(
-                    params, callbacks, available_functions, from_task, from_agent
-                )
+            return self._handle_non_streaming_response(
+                params, callbacks, available_functions, from_task, from_agent
+            )

-        except LLMContextLengthExceededException:
-            # Re-raise LLMContextLengthExceededException as it should be handled
+        except LLMContextLengthExceededError:
+            # Re-raise LLMContextLengthExceededError as it should be handled
             # by the CrewAgentExecutor._invoke_loop method, which can then decide
             # whether to summarize the content or abort based on the respect_context_window flag
             raise
@@ -1065,7 +1057,10 @@ class LLM(BaseLLM):
                 from_agent=from_agent,
             )

-            assert hasattr(crewai_event_bus, "emit")
+            if not hasattr(crewai_event_bus, "emit"):
+                raise AttributeError(
+                    "crewai_event_bus must have an 'emit' method"
+                ) from e
             crewai_event_bus.emit(
                 self,
                 event=LLMCallFailedEvent(
@@ -1078,8 +1073,8 @@ class LLM(BaseLLM):
         self,
         response: Any,
         call_type: LLMCallType,
-        from_task: Optional[Any] = None,
-        from_agent: Optional[Any] = None,
+        from_task: Any | None = None,
+        from_agent: Any | None = None,
         messages: str | list[dict[str, Any]] | None = None,
     ):
         """Handle the events for the LLM call.
@@ -1091,7 +1086,8 @@ class LLM(BaseLLM):
             from_agent: Optional agent object
             messages: Optional messages object
         """
-        assert hasattr(crewai_event_bus, "emit")
+        if not hasattr(crewai_event_bus, "emit"):
+            raise AttributeError("crewai_event_bus must have an 'emit' method")
         crewai_event_bus.emit(
             self,
             event=LLMCallCompletedEvent(
@@ -1105,8 +1101,8 @@ class LLM(BaseLLM):
         )

     def _format_messages_for_provider(
-        self, messages: List[Dict[str, str]]
-    ) -> List[Dict[str, str]]:
+        self, messages: list[dict[str, str]]
+    ) -> list[dict[str, str]]:
         """Format messages according to provider requirements.

         Args:
@@ -1147,7 +1143,7 @@ class LLM(BaseLLM):
         if "mistral" in self.model.lower():
             # Check if the last message has a role of 'assistant'
             if messages and messages[-1]["role"] == "assistant":
-                return messages + [{"role": "user", "content": "Please continue."}]
+                return [*messages, {"role": "user", "content": "Please continue."}]
             return messages

         # TODO: Remove this code after merging PR https://github.com/BerriAI/litellm/pull/10917
@@ -1157,7 +1153,7 @@ class LLM(BaseLLM):
             and messages
             and messages[-1]["role"] == "assistant"
         ):
-            return messages + [{"role": "user", "content": ""}]
+            return [*messages, {"role": "user", "content": ""}]

         # Handle Anthropic models
         if not self.is_anthropic:
@@ -1170,7 +1166,7 @@ class LLM(BaseLLM):

         return messages

-    def _get_custom_llm_provider(self) -> Optional[str]:
+    def _get_custom_llm_provider(self) -> str | None:
         """
         Derives the custom_llm_provider from the model string.
         - For example, if the model is "openrouter/deepseek/deepseek-chat", returns "openrouter".
@@ -1207,7 +1203,7 @@ class LLM(BaseLLM):
                 self.model, custom_llm_provider=provider
             )
         except Exception as e:
-            logging.error(f"Failed to check function calling support: {str(e)}")
+            logging.error(f"Failed to check function calling support: {e!s}")
             return False

     def supports_stop_words(self) -> bool:
@@ -1215,7 +1211,7 @@ class LLM(BaseLLM):
             params = get_supported_openai_params(model=self.model)
             return params is not None and "stop" in params
         except Exception as e:
-            logging.error(f"Failed to get supported params: {str(e)}")
+            logging.error(f"Failed to get supported params: {e!s}")
             return False

     def get_context_window_size(self) -> int:
@@ -1229,9 +1225,6 @@ class LLM(BaseLLM):
         if self.context_window_size != 0:
             return self.context_window_size

-        MIN_CONTEXT = 1024
-        MAX_CONTEXT = 2097152  # Current max from gemini-1.5-pro
-
         # Validate all context window sizes
         for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
             if value < MIN_CONTEXT or value > MAX_CONTEXT:
@@ -1247,7 +1240,8 @@ class LLM(BaseLLM):
         self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
         return self.context_window_size

-    def set_callbacks(self, callbacks: List[Any]):
+    @staticmethod
+    def set_callbacks(callbacks: list[Any]):
         """
         Attempt to keep a single set of callbacks in litellm by removing old
         duplicates and adding new ones.
@@ -1264,9 +1258,9 @@ class LLM(BaseLLM):

         litellm.callbacks = callbacks

-    def set_env_callbacks(self):
-        """
-        Sets the success and failure callbacks for the LiteLLM library from environment variables.
+    @staticmethod
+    def set_env_callbacks() -> None:
+        """Sets the success and failure callbacks for the LiteLLM library from environment variables.

         This method reads the `LITELLM_SUCCESS_CALLBACKS` and `LITELLM_FAILURE_CALLBACKS`
         environment variables, which should contain comma-separated lists of callback names.
@@ -1276,7 +1270,7 @@ class LLM(BaseLLM):
         If the environment variables are not set or are empty, the corresponding callback lists
         will be set to empty lists.

-        Example:
+        Examples:
             LITELLM_SUCCESS_CALLBACKS="langfuse,langsmith"
             LITELLM_FAILURE_CALLBACKS="langfuse"

@@ -1285,16 +1279,15 @@ class LLM(BaseLLM):
         """
         with suppress_warnings():
             success_callbacks_str = os.environ.get("LITELLM_SUCCESS_CALLBACKS", "")
-            success_callbacks = []
+            success_callbacks: list[str | Callable[..., Any] | CustomLogger] = []
             if success_callbacks_str:
                 success_callbacks = [
                     cb.strip() for cb in success_callbacks_str.split(",") if cb.strip()
                 ]

             failure_callbacks_str = os.environ.get("LITELLM_FAILURE_CALLBACKS", "")
-            failure_callbacks = []
             if failure_callbacks_str:
-                failure_callbacks = [
+                failure_callbacks: list[str | Callable[..., Any] | CustomLogger] = [
                     cb.strip() for cb in failure_callbacks_str.split(",") if cb.strip()
                 ]

@@ -27,7 +27,10 @@ class EntityMemory(Memory):
     _memory_provider: str | None = PrivateAttr()

     def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
-        memory_provider = embedder_config.get("provider") if embedder_config else None
+        memory_provider = None
+        if embedder_config and isinstance(embedder_config, dict):
+            memory_provider = embedder_config.get("provider")
+
         if memory_provider == "mem0":
             try:
                 from crewai.memory.storage.mem0_storage import Mem0Storage
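`EntityMemory` here and `ShortTermMemory` below gain the same guard: `embedder_config` may now be a provider object rather than a plain dict, so calling `.get(...)` unconditionally would raise `AttributeError`. A sketch of the case the `isinstance` check covers (the provider class is a stand-in):

```python
class StubProvider:  # stand-in for a non-dict embedder configuration
    model = "text-embedding-3-small"


embedder_config = StubProvider()

# Old code: embedder_config.get("provider") -> AttributeError for objects.
memory_provider = None
if embedder_config and isinstance(embedder_config, dict):
    memory_provider = embedder_config.get("provider")

assert memory_provider is None  # non-dict configs simply skip the mem0 branch
```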
@@ -35,7 +38,11 @@ class EntityMemory(Memory):
                 raise ImportError(
                     "Mem0 is not installed. Please install it with `pip install mem0ai`."
                 ) from e
-            config = embedder_config.get("config") if embedder_config else None
+            config = (
+                embedder_config.get("config")
+                if embedder_config and isinstance(embedder_config, dict)
+                else None
+            )
             storage = Mem0Storage(type="short_term", crew=crew, config=config)
         else:
             storage = (
@@ -13,6 +13,7 @@ from crewai.events.types.memory_events import (
 from crewai.memory.external.external_memory_item import ExternalMemoryItem
 from crewai.memory.memory import Memory
 from crewai.memory.storage.interface import Storage
+from crewai.rag.embeddings.types import ProviderSpec

 if TYPE_CHECKING:
     from crewai.memory.storage.mem0_storage import Mem0Storage
@@ -35,7 +36,9 @@ class ExternalMemory(Memory):
     }

     @staticmethod
-    def create_storage(crew: Any, embedder_config: dict[str, Any] | None) -> Storage:
+    def create_storage(
+        crew: Any, embedder_config: dict[str, Any] | ProviderSpec | None
+    ) -> Storage:
         if not embedder_config:
             raise ValueError("embedder_config is required")

@@ -159,6 +162,6 @@ class ExternalMemory(Memory):
         super().set_crew(crew)

         if not self.storage:
-            self.storage = self.create_storage(crew, self.embedder_config)
+            self.storage = self.create_storage(crew, self.embedder_config)  # type: ignore[arg-type]

         return self
@@ -2,6 +2,8 @@ from typing import TYPE_CHECKING, Any, Optional

 from pydantic import BaseModel

+from crewai.rag.embeddings.types import EmbedderConfig
+
 if TYPE_CHECKING:
     from crewai.agent import Agent
     from crewai.task import Task
@@ -12,7 +14,7 @@ class Memory(BaseModel):
     Base class for memory, now supporting agent tags and generic metadata.
     """

-    embedder_config: dict[str, Any] | None = None
+    embedder_config: EmbedderConfig | dict[str, Any] | None = None
     crew: Any | None = None

     storage: Any
@@ -29,7 +29,10 @@ class ShortTermMemory(Memory):
     _memory_provider: str | None = PrivateAttr()

     def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
-        memory_provider = embedder_config.get("provider") if embedder_config else None
+        memory_provider = None
+        if embedder_config and isinstance(embedder_config, dict):
+            memory_provider = embedder_config.get("provider")
+
         if memory_provider == "mem0":
             try:
                 from crewai.memory.storage.mem0_storage import Mem0Storage
@@ -37,7 +40,11 @@ class ShortTermMemory(Memory):
                 raise ImportError(
                     "Mem0 is not installed. Please install it with `pip install mem0ai`."
                 ) from e
-            config = embedder_config.get("config") if embedder_config else None
+            config = (
+                embedder_config.get("config")
+                if embedder_config and isinstance(embedder_config, dict)
+                else None
+            )
             storage = Mem0Storage(type="short_term", crew=crew, config=config)
         else:
             storage = (
@@ -7,7 +7,9 @@ from crewai.rag.chromadb.config import ChromaDBConfig
 from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
 from crewai.rag.config.utils import get_rag_client
 from crewai.rag.core.base_client import BaseClient
-from crewai.rag.embeddings.factory import get_embedding_function
+from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
+from crewai.rag.embeddings.factory import build_embedder
+from crewai.rag.embeddings.types import ProviderSpec
 from crewai.rag.factory import create_client
 from crewai.rag.storage.base_rag_storage import BaseRAGStorage
 from crewai.rag.types import BaseRecord
@@ -25,7 +27,7 @@ class RAGStorage(BaseRAGStorage):
         self,
         type: str,
         allow_reset: bool = True,
-        embedder_config: dict[str, Any] | None = None,
+        embedder_config: ProviderSpec | BaseEmbeddingsProvider | None = None,
         crew: Any = None,
         path: str | None = None,
     ) -> None:
@@ -49,12 +51,46 @@ class RAGStorage(BaseRAGStorage):
         )

         if self.embedder_config:
-            embedding_function = get_embedding_function(self.embedder_config)
-            config = ChromaDBConfig(
-                embedding_function=cast(
-                    ChromaEmbeddingFunctionWrapper, embedding_function
-                )
-            )
+            embedding_function = build_embedder(self.embedder_config)
+
+            try:
+                _ = embedding_function(["test"])
+            except Exception as e:
+                provider = (
+                    self.embedder_config["provider"]
+                    if isinstance(self.embedder_config, dict)
+                    else self.embedder_config.__class__.__name__.replace(
+                        "Provider", ""
+                    ).lower()
+                )
+                raise ValueError(
+                    f"Failed to initialize embedder. Please check your configuration or connection.\n"
+                    f"Provider: {provider}\n"
+                    f"Error: {e}"
+                ) from e
+
+            batch_size = None
+            if (
+                isinstance(self.embedder_config, dict)
+                and "config" in self.embedder_config
+            ):
+                nested_config = self.embedder_config["config"]
+                if isinstance(nested_config, dict):
+                    batch_size = nested_config.get("batch_size")
+
+            if batch_size is not None:
+                config = ChromaDBConfig(
+                    embedding_function=cast(
+                        ChromaEmbeddingFunctionWrapper, embedding_function
+                    ),
+                    batch_size=cast(int, batch_size),
+                )
+            else:
+                config = ChromaDBConfig(
+                    embedding_function=cast(
+                        ChromaEmbeddingFunctionWrapper, embedding_function
+                    )
+                )
         self._client = create_client(config)

     def _get_client(self) -> BaseClient:
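The `embedding_function(["test"])` call is a fail-fast probe: a bad API key or unreachable endpoint now surfaces as a single actionable `ValueError` at storage construction rather than during the first save or query. The shape of the pattern in isolation (the failing embedder is contrived):

```python
def broken_embedder(texts: list[str]) -> list[list[float]]:
    raise ConnectionError("could not reach embeddings endpoint")


try:
    _ = broken_embedder(["test"])  # cheap one-document probe
except Exception as e:
    # In the real code this is re-raised as a ValueError with provider
    # context attached, chained via `from e`.
    print(f"would raise ValueError(...) from e here: {e}")
```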
@@ -95,7 +131,26 @@ class RAGStorage(BaseRAGStorage):
             if metadata:
                 document["metadata"] = metadata

-            client.add_documents(collection_name=collection_name, documents=[document])
+            batch_size = None
+            if (
+                self.embedder_config
+                and isinstance(self.embedder_config, dict)
+                and "config" in self.embedder_config
+            ):
+                nested_config = self.embedder_config["config"]
+                if isinstance(nested_config, dict):
+                    batch_size = nested_config.get("batch_size")
+
+            if batch_size is not None:
+                client.add_documents(
+                    collection_name=collection_name,
+                    documents=[document],
+                    batch_size=cast(int, batch_size),
+                )
+            else:
+                client.add_documents(
+                    collection_name=collection_name, documents=[document]
+                )
         except Exception as e:
             logging.error(
                 f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}"

@@ -1,17 +1,17 @@
 """RAG (Retrieval-Augmented Generation) infrastructure for CrewAI."""

-import sys
 import importlib
+import sys
 from types import ModuleType
 from typing import Any

 from crewai.rag.config.types import RagConfigType
 from crewai.rag.config.utils import set_rag_config


 _module_path = __path__
 _module_file = __file__


 class _RagModule(ModuleType):
     """Module wrapper to intercept attribute setting for config."""

@@ -51,8 +51,10 @@ class _RagModule(ModuleType):
         """
         try:
             return importlib.import_module(f"{self.__name__}.{name}")
-        except ImportError:
-            raise AttributeError(f"module '{self.__name__}' has no attribute '{name}'")
+        except ImportError as e:
+            raise AttributeError(
+                f"module '{self.__name__}' has no attribute '{name}'"
+            ) from e


 sys.modules[__name__] = _RagModule(__name__)
@@ -17,6 +17,7 @@ from crewai.rag.chromadb.types import (
     ChromaDBCollectionSearchParams,
 )
 from crewai.rag.chromadb.utils import (
+    _create_batch_slice,
     _extract_search_params,
     _is_async_client,
     _is_sync_client,
@@ -52,6 +53,7 @@ class ChromaDBClient(BaseClient):
         embedding_function: ChromaEmbeddingFunction,
         default_limit: int = 5,
         default_score_threshold: float = 0.6,
+        default_batch_size: int = 100,
     ) -> None:
         """Initialize ChromaDBClient with client and embedding function.

@@ -60,11 +62,13 @@ class ChromaDBClient(BaseClient):
             embedding_function: Embedding function for text to vector conversion.
             default_limit: Default number of results to return in searches.
             default_score_threshold: Default minimum score for search results.
+            default_batch_size: Default batch size for adding documents.
         """
         self.client = client
         self.embedding_function = embedding_function
         self.default_limit = default_limit
         self.default_score_threshold = default_score_threshold
+        self.default_batch_size = default_batch_size

     def create_collection(
         self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
@@ -291,6 +295,7 @@ class ChromaDBClient(BaseClient):
                 - content: The text content (required)
                 - doc_id: Optional unique identifier (auto-generated if missing)
                 - metadata: Optional metadata dictionary
+            batch_size: Optional batch size for processing documents (default: 100)

         Raises:
             TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
@@ -305,6 +310,7 @@ class ChromaDBClient(BaseClient):

         collection_name = kwargs["collection_name"]
         documents = kwargs["documents"]
+        batch_size = kwargs.get("batch_size", self.default_batch_size)

         if not documents:
             raise ValueError("Documents list cannot be empty")
@@ -315,13 +321,17 @@ class ChromaDBClient(BaseClient):
         )

         prepared = _prepare_documents_for_chromadb(documents)
-        # ChromaDB doesn't accept empty metadata dicts, so pass None if all are empty
-        metadatas = prepared.metadatas if any(m for m in prepared.metadatas) else None
-        collection.upsert(
-            ids=prepared.ids,
-            documents=prepared.texts,
-            metadatas=metadatas,
-        )
+
+        for i in range(0, len(prepared.ids), batch_size):
+            batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
+                prepared=prepared, start_index=i, batch_size=batch_size
+            )
+
+            collection.upsert(
+                ids=batch_ids,
+                documents=batch_texts,
+                metadatas=batch_metadatas,
+            )

     async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
         """Add documents with their embeddings to a collection asynchronously.
@@ -335,6 +345,7 @@ class ChromaDBClient(BaseClient):
                 - content: The text content (required)
                 - doc_id: Optional unique identifier (auto-generated if missing)
                 - metadata: Optional metadata dictionary
+            batch_size: Optional batch size for processing documents (default: 100)

         Raises:
             TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
@@ -349,6 +360,7 @@ class ChromaDBClient(BaseClient):

         collection_name = kwargs["collection_name"]
         documents = kwargs["documents"]
+        batch_size = kwargs.get("batch_size", self.default_batch_size)

         if not documents:
             raise ValueError("Documents list cannot be empty")
@@ -358,13 +370,17 @@ class ChromaDBClient(BaseClient):
             embedding_function=self.embedding_function,
         )
         prepared = _prepare_documents_for_chromadb(documents)
-        # ChromaDB doesn't accept empty metadata dicts, so pass None if all are empty
-        metadatas = prepared.metadatas if any(m for m in prepared.metadatas) else None
-        await collection.upsert(
-            ids=prepared.ids,
-            documents=prepared.texts,
-            metadatas=metadatas,
-        )
+
+        for i in range(0, len(prepared.ids), batch_size):
+            batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
+                prepared=prepared, start_index=i, batch_size=batch_size
+            )
+
+            await collection.upsert(
+                ids=batch_ids,
+                documents=batch_texts,
+                metadatas=batch_metadatas,
+            )

     def search(
         self, **kwargs: Unpack[ChromaDBCollectionSearchParams]

@@ -41,4 +41,5 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
         embedding_function=config.embedding_function,
         default_limit=config.limit,
         default_score_threshold=config.score_threshold,
+        default_batch_size=config.batch_size,
     )
@@ -1,6 +1,7 @@
 """Utility functions for ChromaDB client implementation."""

 import hashlib
+import json
 from collections.abc import Mapping
 from typing import Literal, TypeGuard, cast

@@ -9,7 +10,6 @@ from chromadb.api.models.AsyncCollection import AsyncCollection
 from chromadb.api.models.Collection import Collection
 from chromadb.api.types import (
     Include,
-    IncludeEnum,
     QueryResult,
 )

@@ -72,7 +72,15 @@ def _prepare_documents_for_chromadb(
         if "doc_id" in doc:
             ids.append(doc["doc_id"])
         else:
-            content_hash = hashlib.sha256(doc["content"].encode()).hexdigest()[:16]
+            content_for_hash = doc["content"]
+            metadata = doc.get("metadata")
+            if metadata:
+                metadata_str = json.dumps(metadata, sort_keys=True)
+                content_for_hash = f"{content_for_hash}|{metadata_str}"
+
+            content_hash = hashlib.blake2b(
+                content_for_hash.encode(), digest_size=32
+            ).hexdigest()
             ids.append(content_hash)

         texts.append(doc["content"])
@@ -88,6 +96,32 @@ def _prepare_documents_for_chromadb(
     return PreparedDocuments(ids, texts, metadatas)


+def _create_batch_slice(
+    prepared: PreparedDocuments, start_index: int, batch_size: int
+) -> tuple[list[str], list[str], list[Mapping[str, str | int | float | bool]] | None]:
+    """Create a batch slice from prepared documents.
+
+    Args:
+        prepared: PreparedDocuments containing ids, texts, and metadatas.
+        start_index: Starting index for the batch.
+        batch_size: Size of the batch.
+
+    Returns:
+        Tuple of (batch_ids, batch_texts, batch_metadatas).
+    """
+    batch_end = min(start_index + batch_size, len(prepared.ids))
+    batch_ids = prepared.ids[start_index:batch_end]
+    batch_texts = prepared.texts[start_index:batch_end]
+    batch_metadatas = (
+        prepared.metadatas[start_index:batch_end] if prepared.metadatas else None
+    )
+
+    if batch_metadatas and not any(m for m in batch_metadatas):
+        batch_metadatas = None
+
+    return batch_ids, batch_texts, batch_metadatas
+
+
 def _extract_search_params(
     kwargs: ChromaDBCollectionSearchParams,
 ) -> ExtractedSearchParams:
@@ -107,9 +141,12 @@ def _extract_search_params(
         score_threshold=kwargs.get("score_threshold"),
         where=kwargs.get("where"),
         where_document=kwargs.get("where_document"),
-        include=kwargs.get(
-            "include",
-            [IncludeEnum.metadatas, IncludeEnum.documents, IncludeEnum.distances],
+        include=cast(
+            Include,
+            kwargs.get(
+                "include",
+                ["metadatas", "documents", "distances"],
+            ),
         ),
     )

@@ -158,7 +195,7 @@ def _convert_chromadb_results_to_search_results(
     """
     search_results: list[SearchResult] = []

-    include_strings = [item.value for item in include] if include else []
+    include_strings = list(include) if include else []

     ids = results["ids"][0] if results.get("ids") else []

@@ -16,3 +16,4 @@ class BaseRagConfig:
     embedding_function: Any | None = field(default=None)
     limit: int = field(default=5)
     score_threshold: float = field(default=0.6)
+    batch_size: int = field(default=100)

@@ -1 +1 @@
-"""Optional imports for RAG configuration providers."""
+"""Optional imports for RAG configuration providers."""

@@ -1,7 +1,7 @@
 """Base classes for missing provider configurations."""

-from typing import Literal
 from dataclasses import field
+from typing import Literal

 from pydantic import ConfigDict
 from pydantic.dataclasses import dataclass as pyd_dataclass

@@ -2,7 +2,7 @@

 from __future__ import annotations

-from typing import Protocol, TYPE_CHECKING
+from typing import TYPE_CHECKING, Protocol

 if TYPE_CHECKING:
     from crewai.rag.chromadb.client import ChromaDBClient

@@ -1,7 +1,8 @@
 """Provider-specific missing configuration classes."""

-from typing import Literal
+from dataclasses import field
+from typing import Literal

 from pydantic import ConfigDict
 from pydantic.dataclasses import dataclass as pyd_dataclass


@@ -1,6 +1,7 @@
 """Type definitions for RAG configuration."""

-from typing import Annotated, TypeAlias, TYPE_CHECKING
+from typing import TYPE_CHECKING, Annotated, TypeAlias

 from pydantic import Field
+
 from crewai.rag.config.constants import DISCRIMINATOR

@@ -4,14 +4,14 @@ from contextvars import ContextVar

 from pydantic import BaseModel, Field

-from crewai.utilities.import_utils import require
-from crewai.rag.core.base_client import BaseClient
-from crewai.rag.config.types import RagConfigType
 from crewai.rag.config.constants import (
-    DEFAULT_RAG_CONFIG_PATH,
     DEFAULT_RAG_CONFIG_CLASS,
+    DEFAULT_RAG_CONFIG_PATH,
 )
+from crewai.rag.config.types import RagConfigType
+from crewai.rag.core.base_client import BaseClient
 from crewai.rag.factory import create_client
+from crewai.utilities.import_utils import require


 class RagContext(BaseModel):

@@ -29,7 +29,7 @@ class BaseCollectionParams(TypedDict):
     ]


-class BaseCollectionAddParams(BaseCollectionParams):
+class BaseCollectionAddParams(BaseCollectionParams, total=False):
     """Parameters for adding documents to a collection.

     Extends BaseCollectionParams with document-specific fields.
@@ -37,9 +37,11 @@ class BaseCollectionAddParams(BaseCollectionParams):
     Attributes:
         collection_name: The name of the collection to add documents to.
         documents: List of BaseRecord dictionaries containing document data.
+        batch_size: Optional batch size for processing documents to avoid token limits.
     """

-    documents: list[BaseRecord]
+    documents: Required[list[BaseRecord]]
+    batch_size: int


 class BaseCollectionSearchParams(BaseCollectionParams, total=False):
src/crewai/rag/core/base_embeddings_callable.py (new file, 142 lines)
@@ -0,0 +1,142 @@
+"""Base embeddings callable utilities for RAG systems."""
+
+from typing import Protocol, TypeVar, runtime_checkable
+
+import numpy as np
+
+from crewai.rag.core.types import (
+    Embeddable,
+    Embedding,
+    Embeddings,
+    PyEmbedding,
+)
+
+T = TypeVar("T")
+D = TypeVar("D", bound=Embeddable, contravariant=True)
+
+
+def normalize_embeddings(
+    target: Embedding | list[Embedding] | PyEmbedding | list[PyEmbedding],
+) -> Embeddings | None:
+    """Normalize various embedding formats to a standard list of numpy arrays.
+
+    Args:
+        target: Input embeddings in various formats (list of floats, list of lists,
+            numpy array, or list of numpy arrays).
+
+    Returns:
+        Normalized embeddings as a list of numpy arrays, or None if input is None.
+
+    Raises:
+        ValueError: If embeddings are empty or in an unsupported format.
+    """
+    if isinstance(target, np.ndarray):
+        if target.ndim == 1:
+            return [target.astype(np.float32)]
+        if target.ndim == 2:
+            return [row.astype(np.float32) for row in target]
+        raise ValueError(f"Unsupported numpy array shape: {target.shape}")
+
+    first = target[0]
+    if isinstance(first, (int, float)) and not isinstance(first, bool):
+        return [np.array(target, dtype=np.float32)]
+    if isinstance(first, list):
+        return [np.array(emb, dtype=np.float32) for emb in target]
+    if isinstance(first, np.ndarray):
+        return [emb.astype(np.float32) for emb in target]  # type: ignore[union-attr]
+
+    raise ValueError(f"Unsupported embeddings format: {type(first)}")
+
+
+def maybe_cast_one_to_many(target: T | list[T] | None) -> list[T] | None:
+    """Cast a single item to a list if needed.
+
+    Args:
+        target: A single item or list of items.
+
+    Returns:
+        A list of items or None if input is None.
+    """
+    if target is None:
+        return None
+    return target if isinstance(target, list) else [target]
+
+
+def validate_embeddings(embeddings: Embeddings) -> Embeddings:
+    """Validate embeddings format and content.
+
+    Args:
+        embeddings: List of numpy arrays to validate.
+
+    Returns:
+        Validated embeddings.
+
+    Raises:
+        ValueError: If embeddings format or content is invalid.
+    """
+    if not isinstance(embeddings, list):
+        raise ValueError(
+            f"Expected embeddings to be a list, got {type(embeddings).__name__}"
+        )
+    if len(embeddings) == 0:
+        raise ValueError(
+            f"Expected embeddings to be a list with at least one item, got {len(embeddings)} embeddings"
+        )
+    if not all(isinstance(e, np.ndarray) for e in embeddings):
+        raise ValueError(
+            "Expected each embedding in the embeddings to be a numpy array"
+        )
+    for i, embedding in enumerate(embeddings):
+        if embedding.ndim == 0:
+            raise ValueError(
+                f"Expected a 1-dimensional array, got a 0-dimensional array {embedding}"
+            )
+        if embedding.size == 0:
+            raise ValueError(
+                f"Expected each embedding to be a 1-dimensional numpy array with at least 1 value. "
+                f"Got an array with no values at position {i}"
+            )
+        if not all(
+            isinstance(value, (np.integer, float, np.floating))
+            and not isinstance(value, bool)
+            for value in embedding
+        ):
+            raise ValueError(
+                f"Expected embedding to contain numeric values, got non-numeric values at position {i}"
+            )
+    return embeddings
+
+
+@runtime_checkable
+class EmbeddingFunction(Protocol[D]):
+    """Protocol for embedding functions.
+
+    Embedding functions convert input data (documents or images) into vector embeddings.
+    """
+
+    def __call__(self, input: D) -> Embeddings:
+        """Convert input data to embeddings.
+
+        Args:
+            input: Input data to embed (documents or images).
+
+        Returns:
+            List of numpy arrays representing the embeddings.
+        """
+        ...
+
+    def __init_subclass__(cls) -> None:
+        """Wrap __call__ method to normalize and validate embeddings."""
+        super().__init_subclass__()
+        original_call = cls.__call__
+
+        def wrapped_call(self: EmbeddingFunction[D], input: D) -> Embeddings:
+            result = original_call(self, input)
+            if result is None:
+                raise ValueError("Embedding function returned None")
+            normalized = normalize_embeddings(result)
+            if normalized is None:
+                raise ValueError("Normalization returned None for non-None input")
+            return validate_embeddings(normalized)
+
+        cls.__call__ = wrapped_call  # type: ignore[method-assign]
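Because the wrapping happens in `__init_subclass__`, any class that explicitly subclasses the protocol gets normalization and validation for free; a raw return value never reaches callers unchecked. A toy subclass to illustrate, assuming the module above is importable:

```python
import numpy as np

from crewai.rag.core.base_embeddings_callable import EmbeddingFunction


class ToyEmbedder(EmbeddingFunction):
    def __call__(self, input: list[str]):
        # Returns plain lists of floats; the subclass hook normalizes the
        # result to float32 numpy arrays and validates it before returning.
        return [[float(len(text)), 1.0] for text in input]


vectors = ToyEmbedder()(["hello", "hi"])
assert isinstance(vectors[0], np.ndarray) and vectors[0].dtype == np.float32
```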
src/crewai/rag/core/base_embeddings_provider.py (new file, 23 lines)
@@ -0,0 +1,23 @@
+"""Base class for embedding providers."""
+
+from typing import Generic, TypeVar
+
+from pydantic import Field
+from pydantic_settings import BaseSettings, SettingsConfigDict
+
+from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
+
+T = TypeVar("T", bound=EmbeddingFunction)
+
+
+class BaseEmbeddingsProvider(BaseSettings, Generic[T]):
+    """Abstract base class for embedding providers.
+
+    This class provides a common interface for dynamically loading and building
+    embedding functions from various providers.
+    """
+
+    model_config = SettingsConfigDict(extra="allow", populate_by_name=True)
+    embedding_callable: type[T] = Field(
+        ..., description="The embedding function class to use"
+    )
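A concrete provider is then just a settings object that names its embedding-function class and carries that class's constructor fields. A hypothetical provider sketch (none of the names below are shipped providers):

```python
from pydantic import Field

from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider


class StubEmbedder(EmbeddingFunction):
    def __call__(self, input: list[str]):
        return [[0.0, 1.0] for _ in input]


class StubProvider(BaseEmbeddingsProvider[StubEmbedder]):
    """Hypothetical provider; field names are illustrative only."""

    embedding_callable: type[StubEmbedder] = StubEmbedder
    model_name: str = Field(default="stub-model")
```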
src/crewai/rag/core/types.py (new file, 28 lines)
@@ -0,0 +1,28 @@
+"""Core type definitions for RAG systems."""
+
+from collections.abc import Sequence
+from typing import TypeVar
+
+import numpy as np
+from numpy import floating, integer, number
+from numpy.typing import NDArray
+
+T = TypeVar("T")
+
+PyEmbedding = Sequence[float] | Sequence[int]
+PyEmbeddings = list[PyEmbedding]
+Embedding = NDArray[np.int32 | np.float32]
+Embeddings = list[Embedding]
+
+Documents = list[str]
+Images = list[np.ndarray]
+Embeddable = Documents | Images
+
+ScalarType = TypeVar("ScalarType", bound=np.generic)
+IntegerType = TypeVar("IntegerType", bound=integer)
+FloatingType = TypeVar("FloatingType", bound=floating)
+NumberType = TypeVar("NumberType", bound=number)
+
+DType32 = TypeVar("DType32", np.int32, np.float32)
+DType64 = TypeVar("DType64", np.int64, np.float64)
+DTypeCommon = TypeVar("DTypeCommon", np.int32, np.int64, np.float32, np.float64)

@@ -1 +1 @@
-"""Embedding components for RAG infrastructure."""
+"""Embedding components for RAG infrastructure."""
@@ -1,243 +0,0 @@
-import os
-from typing import Any, Dict, Optional, cast
-
-from chromadb import Documents, EmbeddingFunction, Embeddings
-from chromadb.api.types import validate_embedding_function
-
-
-class EmbeddingConfigurator:
-    def __init__(self):
-        self.embedding_functions = {
-            "openai": self._configure_openai,
-            "azure": self._configure_azure,
-            "ollama": self._configure_ollama,
-            "vertexai": self._configure_vertexai,
-            "google": self._configure_google,
-            "cohere": self._configure_cohere,
-            "voyageai": self._configure_voyageai,
-            "bedrock": self._configure_bedrock,
-            "huggingface": self._configure_huggingface,
-            "watson": self._configure_watson,
-            "custom": self._configure_custom,
-        }
-
-    def configure_embedder(
-        self,
-        embedder_config: Optional[Dict[str, Any]] = None,
-    ) -> EmbeddingFunction:
-        """Configures and returns an embedding function based on the provided config."""
-        if embedder_config is None:
-            return self._create_default_embedding_function()
-
-        provider = embedder_config.get("provider")
-        config = embedder_config.get("config", {})
-        model_name = config.get("model") if provider != "custom" else None
-
-        if provider not in self.embedding_functions:
-            raise Exception(
-                f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}"
-            )
-
-        try:
-            embedding_function = self.embedding_functions[provider]
-        except ImportError as e:
-            missing_package = str(e).split()[-1]
-            raise ImportError(
-                f"{missing_package} is not installed. Please install it with: pip install {missing_package}"
-            )
-
-        return (
-            embedding_function(config)
-            if provider == "custom"
-            else embedding_function(config, model_name)
-        )
-
-    @staticmethod
-    def _create_default_embedding_function():
-        from chromadb.utils.embedding_functions.openai_embedding_function import (
-            OpenAIEmbeddingFunction,
-        )
-
-        return OpenAIEmbeddingFunction(
-            api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
-        )
-
-    @staticmethod
-    def _configure_openai(config, model_name):
-        from chromadb.utils.embedding_functions.openai_embedding_function import (
-            OpenAIEmbeddingFunction,
-        )
-
-        return OpenAIEmbeddingFunction(
-            api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
-            model_name=model_name,
-            api_base=config.get("api_base", None),
-            api_type=config.get("api_type", None),
-            api_version=config.get("api_version", None),
-            default_headers=config.get("default_headers", None),
-            dimensions=config.get("dimensions", None),
-            deployment_id=config.get("deployment_id", None),
-            organization_id=config.get("organization_id", None),
-        )
-
-    @staticmethod
-    def _configure_azure(config, model_name):
-        from chromadb.utils.embedding_functions.openai_embedding_function import (
-            OpenAIEmbeddingFunction,
-        )
-
-        return OpenAIEmbeddingFunction(
-            api_key=config.get("api_key"),
-            api_base=config.get("api_base"),
-            api_type=config.get("api_type", "azure"),
-            api_version=config.get("api_version"),
-            model_name=model_name,
-            default_headers=config.get("default_headers"),
-            dimensions=config.get("dimensions"),
-            deployment_id=config.get("deployment_id"),
-            organization_id=config.get("organization_id"),
-        )
-
-    @staticmethod
-    def _configure_ollama(config, model_name):
-        from chromadb.utils.embedding_functions.ollama_embedding_function import (
-            OllamaEmbeddingFunction,
-        )
-
-        return OllamaEmbeddingFunction(
-            url=config.get("url", "http://localhost:11434/api/embeddings"),
-            model_name=model_name,
-        )
-
-    @staticmethod
-    def _configure_vertexai(config, model_name):
-        from chromadb.utils.embedding_functions.google_embedding_function import (
-            GoogleVertexEmbeddingFunction,
-        )
-
-        return GoogleVertexEmbeddingFunction(
-            model_name=model_name,
-            api_key=config.get("api_key"),
-            project_id=config.get("project_id"),
-            region=config.get("region"),
-        )
-
-    @staticmethod
-    def _configure_google(config, model_name):
-        from chromadb.utils.embedding_functions.google_embedding_function import (
-            GoogleGenerativeAiEmbeddingFunction,
-        )
-
-        return GoogleGenerativeAiEmbeddingFunction(
-            model_name=model_name,
-            api_key=config.get("api_key"),
-            task_type=config.get("task_type"),
-        )
-
-    @staticmethod
-    def _configure_cohere(config, model_name):
-        from chromadb.utils.embedding_functions.cohere_embedding_function import (
-            CohereEmbeddingFunction,
-        )
-
-        return CohereEmbeddingFunction(
-            model_name=model_name,
-            api_key=config.get("api_key"),
-        )
-
-    @staticmethod
-    def _configure_voyageai(config, model_name):
-        from chromadb.utils.embedding_functions.voyageai_embedding_function import (
-            VoyageAIEmbeddingFunction,
-        )
-
-        return VoyageAIEmbeddingFunction(
-            model_name=model_name,
-            api_key=config.get("api_key"),
-        )
-
-    @staticmethod
-    def _configure_bedrock(config, model_name):
-        from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
-            AmazonBedrockEmbeddingFunction,
-        )
-
-        # Allow custom model_name override with backwards compatibility
-        kwargs = {"session": config.get("session")}
-        if model_name is not None:
-            kwargs["model_name"] = model_name
-        return AmazonBedrockEmbeddingFunction(**kwargs)
-
-    @staticmethod
-    def _configure_huggingface(config, model_name):
-        from chromadb.utils.embedding_functions.huggingface_embedding_function import (
-            HuggingFaceEmbeddingServer,
-        )
-
-        return HuggingFaceEmbeddingServer(
-            url=config.get("api_url"),
-        )
-
-    @staticmethod
-    def _configure_watson(config, model_name):
-        try:
-            import ibm_watsonx_ai.foundation_models as watson_models
-            from ibm_watsonx_ai import Credentials
-            from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames as EmbedParams
-        except ImportError as e:
-            raise ImportError(
-                "IBM Watson dependencies are not installed. Please install them to use Watson embedding."
-            ) from e
-
-        class WatsonEmbeddingFunction(EmbeddingFunction):
-            def __call__(self, input: Documents) -> Embeddings:
-                if isinstance(input, str):
-                    input = [input]
-
-                embed_params = {
-                    EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
-                    EmbedParams.RETURN_OPTIONS: {"input_text": True},
-                }
-
-                embedding = watson_models.Embeddings(
-                    model_id=config.get("model"),
-                    params=embed_params,
-                    credentials=Credentials(
-                        api_key=config.get("api_key"), url=config.get("api_url")
-                    ),
-                    project_id=config.get("project_id"),
-                )
-
-                try:
-                    embeddings = embedding.embed_documents(input)
-                    return cast(Embeddings, embeddings)
-                except Exception as e:
-                    print("Error during Watson embedding:", e)
-                    raise e
-
-        return WatsonEmbeddingFunction()
-
-    @staticmethod
-    def _configure_custom(config):
-        custom_embedder = config.get("embedder")
-        if isinstance(custom_embedder, EmbeddingFunction):
-            try:
-                validate_embedding_function(custom_embedder)
-                return custom_embedder
-            except Exception as e:
-                raise ValueError(f"Invalid custom embedding function: {str(e)}")
-        elif callable(custom_embedder):
-            try:
-                instance = custom_embedder()
-                if isinstance(instance, EmbeddingFunction):
-                    validate_embedding_function(instance)
-                    return instance
-                raise ValueError(
-                    "Custom embedder does not create an EmbeddingFunction instance"
-                )
-            except Exception as e:
-                raise ValueError(f"Error instantiating custom embedder: {str(e)}")
-        else:
-            raise ValueError(
-                "Custom embedder must be an instance of `EmbeddingFunction` or a callable that creates one"
-            )
@@ -1,148 +1,363 @@
-"""Minimal embedding function factory for CrewAI."""
-
-import os
-
-from chromadb import EmbeddingFunction
-from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
-    AmazonBedrockEmbeddingFunction,
-)
-from chromadb.utils.embedding_functions.cohere_embedding_function import (
-    CohereEmbeddingFunction,
-)
-from chromadb.utils.embedding_functions.google_embedding_function import (
-    GoogleGenerativeAiEmbeddingFunction,
-    GooglePalmEmbeddingFunction,
-    GoogleVertexEmbeddingFunction,
-)
-from chromadb.utils.embedding_functions.huggingface_embedding_function import (
-    HuggingFaceEmbeddingFunction,
-)
-from chromadb.utils.embedding_functions.instructor_embedding_function import (
-    InstructorEmbeddingFunction,
-)
-from chromadb.utils.embedding_functions.jina_embedding_function import (
-    JinaEmbeddingFunction,
-)
-from chromadb.utils.embedding_functions.ollama_embedding_function import (
-    OllamaEmbeddingFunction,
-)
-from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2
-from chromadb.utils.embedding_functions.open_clip_embedding_function import (
-    OpenCLIPEmbeddingFunction,
-)
-from chromadb.utils.embedding_functions.openai_embedding_function import (
-    OpenAIEmbeddingFunction,
-)
-from chromadb.utils.embedding_functions.roboflow_embedding_function import (
-    RoboflowEmbeddingFunction,
-)
-from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
-    SentenceTransformerEmbeddingFunction,
-)
-from chromadb.utils.embedding_functions.text2vec_embedding_function import (
-    Text2VecEmbeddingFunction,
-)
-
-from crewai.rag.embeddings.types import EmbeddingOptions
-
-
-def get_embedding_function(
-    config: EmbeddingOptions | dict | None = None,
-) -> EmbeddingFunction:
-    """Get embedding function - delegates to ChromaDB.
-
-    Args:
-        config: Optional configuration - either an EmbeddingOptions object or a dict with:
-            - provider: The embedding provider to use (default: "openai")
-            - Any other provider-specific parameters
-
-    Returns:
-        EmbeddingFunction instance ready for use with ChromaDB
-
-    Supported providers:
-        - openai: OpenAI embeddings
-        - cohere: Cohere embeddings
-        - ollama: Ollama local embeddings
-        - huggingface: HuggingFace embeddings
-        - sentence-transformer: Local sentence transformers
-        - instructor: Instructor embeddings for specialized tasks
-        - google-palm: Google PaLM embeddings
-        - google-generativeai: Google Generative AI embeddings
-        - google-vertex: Google Vertex AI embeddings
-        - amazon-bedrock: AWS Bedrock embeddings
-        - jina: Jina AI embeddings
-        - roboflow: Roboflow embeddings for vision tasks
-        - openclip: OpenCLIP embeddings for multimodal tasks
-        - text2vec: Text2Vec embeddings
-        - onnx: ONNX MiniLM-L6-v2 (no API key needed, included with ChromaDB)
-
-    Examples:
-        # Use default OpenAI embedding
-        >>> embedder = get_embedding_function()
-
-        # Use Cohere with dict
-        >>> embedder = get_embedding_function({
-        ...     "provider": "cohere",
-        ...     "api_key": "your-key",
-        ...     "model_name": "embed-english-v3.0"
-        ... })
-
-        # Use with EmbeddingOptions
-        >>> embedder = get_embedding_function(
-        ...     EmbeddingOptions(provider="sentence-transformer", model_name="all-MiniLM-L6-v2")
-        ... )
-
-        # Use local sentence transformers (no API key needed)
-        >>> embedder = get_embedding_function({
-        ...     "provider": "sentence-transformer",
-        ...     "model_name": "all-MiniLM-L6-v2"
-        ... })
-
-        # Use Ollama for local embeddings
-        >>> embedder = get_embedding_function({
-        ...     "provider": "ollama",
-        ...     "model_name": "nomic-embed-text"
-        ... })
-
-        # Use ONNX (no API key needed)
-        >>> embedder = get_embedding_function({
-        ...     "provider": "onnx"
-        ... })
-    """
-    if config is None:
-        return OpenAIEmbeddingFunction(
-            api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
-        )
-
-    # Handle EmbeddingOptions object
-    if isinstance(config, EmbeddingOptions):
-        config_dict = config.model_dump(exclude_none=True)
-    else:
-        config_dict = config.copy()
-
-    provider = config_dict.pop("provider", "openai")
-
-    embedding_functions = {
-        "openai": OpenAIEmbeddingFunction,
-        "cohere": CohereEmbeddingFunction,
-        "ollama": OllamaEmbeddingFunction,
-        "huggingface": HuggingFaceEmbeddingFunction,
-        "sentence-transformer": SentenceTransformerEmbeddingFunction,
-        "instructor": InstructorEmbeddingFunction,
-        "google-palm": GooglePalmEmbeddingFunction,
-        "google-generativeai": GoogleGenerativeAiEmbeddingFunction,
-        "google-vertex": GoogleVertexEmbeddingFunction,
-        "amazon-bedrock": AmazonBedrockEmbeddingFunction,
-        "jina": JinaEmbeddingFunction,
-        "roboflow": RoboflowEmbeddingFunction,
-        "openclip": OpenCLIPEmbeddingFunction,
-        "text2vec": Text2VecEmbeddingFunction,
-        "onnx": ONNXMiniLM_L6_V2,
-    }
-
-    if provider not in embedding_functions:
-        raise ValueError(
-            f"Unsupported provider: {provider}. "
-            f"Available providers: {list(embedding_functions.keys())}"
-        )
-
-    return embedding_functions[provider](**config_dict)
+"""Factory functions for creating embedding providers and functions."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, TypeVar, overload
+
+from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
+from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
+from crewai.utilities.import_utils import import_and_validate_definition
+
+if TYPE_CHECKING:
+    from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
+        AmazonBedrockEmbeddingFunction,
+    )
+    from chromadb.utils.embedding_functions.cohere_embedding_function import (
+        CohereEmbeddingFunction,
+    )
+    from chromadb.utils.embedding_functions.google_embedding_function import (
+        GoogleGenerativeAiEmbeddingFunction,
+        GoogleVertexEmbeddingFunction,
+    )
+    from chromadb.utils.embedding_functions.huggingface_embedding_function import (
+        HuggingFaceEmbeddingFunction,
+    )
+    from chromadb.utils.embedding_functions.instructor_embedding_function import (
+        InstructorEmbeddingFunction,
+    )
+    from chromadb.utils.embedding_functions.jina_embedding_function import (
+        JinaEmbeddingFunction,
+    )
+    from chromadb.utils.embedding_functions.ollama_embedding_function import (
+        OllamaEmbeddingFunction,
+    )
+    from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2
+    from chromadb.utils.embedding_functions.open_clip_embedding_function import (
+        OpenCLIPEmbeddingFunction,
+    )
+    from chromadb.utils.embedding_functions.openai_embedding_function import (
+        OpenAIEmbeddingFunction,
+    )
+    from chromadb.utils.embedding_functions.roboflow_embedding_function import (
+        RoboflowEmbeddingFunction,
+    )
+    from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
+        SentenceTransformerEmbeddingFunction,
+    )
+    from chromadb.utils.embedding_functions.text2vec_embedding_function import (
+        Text2VecEmbeddingFunction,
+    )
+
+    from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
+    from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec
+    from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec
+    from crewai.rag.embeddings.providers.google.types import (
+        GenerativeAiProviderSpec,
+        VertexAIProviderSpec,
+    )
+    from crewai.rag.embeddings.providers.huggingface.types import (
+        HuggingFaceProviderSpec,
+    )
+    from crewai.rag.embeddings.providers.ibm.embedding_callable import (
+        WatsonEmbeddingFunction,
+    )
+    from crewai.rag.embeddings.providers.ibm.types import WatsonProviderSpec
+    from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec
+    from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec
+    from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec
+    from crewai.rag.embeddings.providers.ollama.types import OllamaProviderSpec
+    from crewai.rag.embeddings.providers.onnx.types import ONNXProviderSpec
+    from crewai.rag.embeddings.providers.openai.types import OpenAIProviderSpec
+    from crewai.rag.embeddings.providers.openclip.types import OpenCLIPProviderSpec
+    from crewai.rag.embeddings.providers.roboflow.types import RoboflowProviderSpec
+    from crewai.rag.embeddings.providers.sentence_transformer.types import (
+        SentenceTransformerProviderSpec,
+    )
+    from crewai.rag.embeddings.providers.text2vec.types import Text2VecProviderSpec
+    from crewai.rag.embeddings.providers.voyageai.embedding_callable import (
+        VoyageAIEmbeddingFunction,
+    )
+    from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderSpec
+
+T = TypeVar("T", bound=EmbeddingFunction)
+
+PROVIDER_PATHS = {
+    "azure": "crewai.rag.embeddings.providers.microsoft.azure.AzureProvider",
+    "amazon-bedrock": "crewai.rag.embeddings.providers.aws.bedrock.BedrockProvider",
+    "cohere": "crewai.rag.embeddings.providers.cohere.cohere_provider.CohereProvider",
+    "custom": "crewai.rag.embeddings.providers.custom.custom_provider.CustomProvider",
+    "google-generativeai": "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider",
+    "google-vertex": "crewai.rag.embeddings.providers.google.vertex.VertexAIProvider",
+    "huggingface": "crewai.rag.embeddings.providers.huggingface.huggingface_provider.HuggingFaceProvider",
+    "instructor": "crewai.rag.embeddings.providers.instructor.instructor_provider.InstructorProvider",
+    "jina": "crewai.rag.embeddings.providers.jina.jina_provider.JinaProvider",
+    "ollama": "crewai.rag.embeddings.providers.ollama.ollama_provider.OllamaProvider",
+    "onnx": "crewai.rag.embeddings.providers.onnx.onnx_provider.ONNXProvider",
+    "openai": "crewai.rag.embeddings.providers.openai.openai_provider.OpenAIProvider",
+    "openclip": "crewai.rag.embeddings.providers.openclip.openclip_provider.OpenCLIPProvider",
+    "roboflow": "crewai.rag.embeddings.providers.roboflow.roboflow_provider.RoboflowProvider",
+    "sentence-transformer": "crewai.rag.embeddings.providers.sentence_transformer.sentence_transformer_provider.SentenceTransformerProvider",
+    "text2vec": "crewai.rag.embeddings.providers.text2vec.text2vec_provider.Text2VecProvider",
+    "voyageai": "crewai.rag.embeddings.providers.voyageai.voyageai_provider.VoyageAIProvider",
+    "watson": "crewai.rag.embeddings.providers.ibm.watson.WatsonProvider",
+}
+
+
+def build_embedder_from_provider(provider: BaseEmbeddingsProvider[T]) -> T:
+    """Build an embedding function instance from a provider.
+
+    Args:
+        provider: The embedding provider configuration.
+
+    Returns:
+        An instance of the specified embedding function type.
+    """
+    return provider.embedding_callable(
+        **provider.model_dump(exclude={"embedding_callable"})
+    )
+
+
+@overload
+def build_embedder_from_dict(spec: AzureProviderSpec) -> OpenAIEmbeddingFunction: ...
+
+
+@overload
+def build_embedder_from_dict(
+    spec: BedrockProviderSpec,
+) -> AmazonBedrockEmbeddingFunction: ...
+
+
+@overload
+def build_embedder_from_dict(spec: CohereProviderSpec) -> CohereEmbeddingFunction: ...
+
+
+@overload
+def build_embedder_from_dict(spec: CustomProviderSpec) -> EmbeddingFunction: ...
+
+
+@overload
+def build_embedder_from_dict(
+    spec: GenerativeAiProviderSpec,
+) -> GoogleGenerativeAiEmbeddingFunction: ...
+
+
+@overload
+def build_embedder_from_dict(
+    spec: HuggingFaceProviderSpec,
+) -> HuggingFaceEmbeddingFunction: ...
+
+
+@overload
+def build_embedder_from_dict(spec: OllamaProviderSpec) -> OllamaEmbeddingFunction: ...
+
+
+@overload
+def build_embedder_from_dict(spec: OpenAIProviderSpec) -> OpenAIEmbeddingFunction: ...
+
+
+@overload
+def build_embedder_from_dict(
+    spec: VertexAIProviderSpec,
+) -> GoogleVertexEmbeddingFunction: ...
+
+
+@overload
+def build_embedder_from_dict(
+    spec: VoyageAIProviderSpec,
+) -> VoyageAIEmbeddingFunction: ...
+
+
+@overload
+def build_embedder_from_dict(spec: WatsonProviderSpec) -> WatsonEmbeddingFunction: ...
+
+
+@overload
+def build_embedder_from_dict(
+    spec: SentenceTransformerProviderSpec,
+) -> SentenceTransformerEmbeddingFunction: ...
+
+
+@overload
+def build_embedder_from_dict(
+    spec: InstructorProviderSpec,
+) -> InstructorEmbeddingFunction: ...
+
+
+@overload
+def build_embedder_from_dict(spec: JinaProviderSpec) -> JinaEmbeddingFunction: ...
+
+
+@overload
+def build_embedder_from_dict(
+    spec: RoboflowProviderSpec,
+) -> RoboflowEmbeddingFunction: ...
+
+
+@overload
+def build_embedder_from_dict(
+    spec: OpenCLIPProviderSpec,
+) -> OpenCLIPEmbeddingFunction: ...
+
+
+@overload
+def build_embedder_from_dict(
+    spec: Text2VecProviderSpec,
+) -> Text2VecEmbeddingFunction: ...
+
+
+@overload
+def build_embedder_from_dict(spec: ONNXProviderSpec) -> ONNXMiniLM_L6_V2: ...
+
+
+def build_embedder_from_dict(spec):
+    """Build an embedding function instance from a dictionary specification.
+
+    Args:
+        spec: A dictionary with 'provider' and 'config' keys.
+            Example: {
+                "provider": "openai",
+                "config": {
+                    "api_key": "sk-...",
+                    "model_name": "text-embedding-3-small"
+                }
+            }
+
+    Returns:
+        An instance of the appropriate embedding function.
+
+    Raises:
+        ValueError: If the provider is not recognized.
+    """
+    provider_name = spec["provider"]
+    if not provider_name:
+        raise ValueError("Missing 'provider' key in specification")
+
+    if provider_name not in PROVIDER_PATHS:
+        raise ValueError(
+            f"Unknown provider: {provider_name}. Available providers: {list(PROVIDER_PATHS.keys())}"
+        )
+
+    provider_path = PROVIDER_PATHS[provider_name]
+    try:
+        provider_class = import_and_validate_definition(provider_path)
+    except (ImportError, AttributeError, ValueError) as e:
+        raise ImportError(f"Failed to import provider {provider_name}: {e}") from e
+
+    provider_config = spec.get("config", {})
+
+    if provider_name == "custom" and "embedding_callable" not in provider_config:
+        raise ValueError("Custom provider requires 'embedding_callable' in config")
+
+    provider = provider_class(**provider_config)
+    return build_embedder_from_provider(provider)
+
+
+@overload
+def build_embedder(spec: BaseEmbeddingsProvider[T]) -> T: ...
+
+
+@overload
+def build_embedder(spec: AzureProviderSpec) -> OpenAIEmbeddingFunction: ...
+
+
+@overload
+def build_embedder(spec: BedrockProviderSpec) -> AmazonBedrockEmbeddingFunction: ...
+
+
+@overload
+def build_embedder(spec: CohereProviderSpec) -> CohereEmbeddingFunction: ...
+
+
+@overload
+def build_embedder(spec: CustomProviderSpec) -> EmbeddingFunction: ...
+
+
+@overload
+def build_embedder(
+    spec: GenerativeAiProviderSpec,
+) -> GoogleGenerativeAiEmbeddingFunction: ...
+
+
+@overload
+def build_embedder(spec: HuggingFaceProviderSpec) -> HuggingFaceEmbeddingFunction: ...
+
+
+@overload
+def build_embedder(spec: OllamaProviderSpec) -> OllamaEmbeddingFunction: ...
+
+
+@overload
+def build_embedder(spec: OpenAIProviderSpec) -> OpenAIEmbeddingFunction: ...
+
+
+@overload
+def build_embedder(spec: VertexAIProviderSpec) -> GoogleVertexEmbeddingFunction: ...
+
+
+@overload
+def build_embedder(spec: VoyageAIProviderSpec) -> VoyageAIEmbeddingFunction: ...
+
+
+@overload
+def build_embedder(spec: WatsonProviderSpec) -> WatsonEmbeddingFunction: ...
+
+
+@overload
+def build_embedder(
+    spec: SentenceTransformerProviderSpec,
+) -> SentenceTransformerEmbeddingFunction: ...
+
+
+@overload
+def build_embedder(spec: InstructorProviderSpec) -> InstructorEmbeddingFunction: ...
+
+
+@overload
+def build_embedder(spec: JinaProviderSpec) -> JinaEmbeddingFunction: ...
+
+
+@overload
+def build_embedder(spec: RoboflowProviderSpec) -> RoboflowEmbeddingFunction: ...
+
+
+@overload
+def build_embedder(spec: OpenCLIPProviderSpec) -> OpenCLIPEmbeddingFunction: ...
+
+
+@overload
+def build_embedder(spec: Text2VecProviderSpec) -> Text2VecEmbeddingFunction: ...
+
+
+@overload
+def build_embedder(spec: ONNXProviderSpec) -> ONNXMiniLM_L6_V2: ...
+
+
+def build_embedder(spec):
+    """Build an embedding function from either a provider spec or a provider instance.
+
+    Args:
+        spec: Either a provider specification dictionary or a provider instance.
+
+    Returns:
+        An embedding function instance. If a typed provider is passed, returns
+        the specific embedding function type.
+
+    Examples:
+        # From dictionary specification
+        embedder = build_embedder({
+            "provider": "openai",
+            "config": {"api_key": "sk-..."}
+        })
+
+        # From provider instance
+        provider = OpenAIProvider(api_key="sk-...")
+        embedder = build_embedder(provider)
+    """
+    if isinstance(spec, BaseEmbeddingsProvider):
+        return build_embedder_from_provider(spec)
+    return build_embedder_from_dict(spec)
+
+
+# Backward compatibility alias
+get_embedding_function = build_embedder
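The hunk above swaps the old flat config dict for a {"provider": ..., "config": {...}} spec that build_embedder_from_dict routes through PROVIDER_PATHS and imports lazily, while the get_embedding_function alias keeps old call sites working. A minimal usage sketch, assuming the factory lives at crewai.rag.embeddings.factory and the sentence-transformers extra is installed (neither is confirmed by this diff):

# Hedged sketch: the module path and installed extras are assumptions.
from crewai.rag.embeddings.factory import build_embedder

embedder = build_embedder(
    {
        "provider": "sentence-transformer",  # key into PROVIDER_PATHS
        "config": {"model_name": "all-MiniLM-L6-v2"},
    }
)
vectors = embedder(["CrewAI routes embedding specs through provider classes."])
print(len(vectors))  # one vector per input document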
src/crewai/rag/embeddings/providers/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
"""Embedding provider implementations."""
src/crewai/rag/embeddings/providers/aws/__init__.py (new file, 13 lines)
@@ -0,0 +1,13 @@
"""AWS embedding providers."""

from crewai.rag.embeddings.providers.aws.bedrock import BedrockProvider
from crewai.rag.embeddings.providers.aws.types import (
    BedrockProviderConfig,
    BedrockProviderSpec,
)

__all__ = [
    "BedrockProvider",
    "BedrockProviderConfig",
    "BedrockProviderSpec",
]
src/crewai/rag/embeddings/providers/aws/bedrock.py (new file, 53 lines)
@@ -0,0 +1,53 @@
"""Amazon Bedrock embeddings provider."""

from typing import Any

from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
    AmazonBedrockEmbeddingFunction,
)
from pydantic import Field

from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider


def create_aws_session() -> Any:
    """Create an AWS session for Bedrock.

    Returns:
        boto3.Session: AWS session object

    Raises:
        ImportError: If boto3 is not installed
        ValueError: If AWS session creation fails
    """
    try:
        import boto3  # type: ignore[import]

        return boto3.Session()
    except ImportError as e:
        raise ImportError(
            "boto3 is required for amazon-bedrock embeddings. "
            "Install it with: uv add boto3"
        ) from e
    except Exception as e:
        raise ValueError(
            f"Failed to create AWS session for amazon-bedrock. "
            f"Ensure AWS credentials are configured. Error: {e}"
        ) from e


class BedrockProvider(BaseEmbeddingsProvider[AmazonBedrockEmbeddingFunction]):
    """Amazon Bedrock embeddings provider."""

    embedding_callable: type[AmazonBedrockEmbeddingFunction] = Field(
        default=AmazonBedrockEmbeddingFunction,
        description="Amazon Bedrock embedding function class",
    )
    model_name: str = Field(
        default="amazon.titan-embed-text-v1",
        description="Model name to use for embeddings",
        validation_alias="BEDROCK_MODEL_NAME",
    )
    session: Any = Field(
        default_factory=create_aws_session, description="AWS session object"
    )
src/crewai/rag/embeddings/providers/aws/types.py (new file, 19 lines)
@@ -0,0 +1,19 @@
"""Type definitions for AWS embedding providers."""

from typing import Annotated, Any, Literal

from typing_extensions import Required, TypedDict


class BedrockProviderConfig(TypedDict, total=False):
    """Configuration for Bedrock provider."""

    model_name: Annotated[str, "amazon.titan-embed-text-v1"]
    session: Any


class BedrockProviderSpec(TypedDict, total=False):
    """Bedrock provider specification."""

    provider: Required[Literal["amazon-bedrock"]]
    config: BedrockProviderConfig
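Read together with bedrock.py above, this spec type means a Bedrock embedder can be built entirely from a dict. A hedged sketch, assuming boto3 is installed and AWS credentials are discoverable, since session otherwise comes from the create_aws_session default factory:

# Sketch only: requires boto3 and configured AWS credentials; the factory
# module path is an assumption based on PROVIDER_PATHS.
from crewai.rag.embeddings.factory import build_embedder

embedder = build_embedder(
    {
        "provider": "amazon-bedrock",
        "config": {"model_name": "amazon.titan-embed-text-v1"},  # documented default
    }
)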
src/crewai/rag/embeddings/providers/cohere/__init__.py (new file, 13 lines)
@@ -0,0 +1,13 @@
"""Cohere embedding providers."""

from crewai.rag.embeddings.providers.cohere.cohere_provider import CohereProvider
from crewai.rag.embeddings.providers.cohere.types import (
    CohereProviderConfig,
    CohereProviderSpec,
)

__all__ = [
    "CohereProvider",
    "CohereProviderConfig",
    "CohereProviderSpec",
]
src/crewai/rag/embeddings/providers/cohere/cohere_provider.py (new file, 24 lines)
@@ -0,0 +1,24 @@
"""Cohere embeddings provider."""

from chromadb.utils.embedding_functions.cohere_embedding_function import (
    CohereEmbeddingFunction,
)
from pydantic import Field

from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider


class CohereProvider(BaseEmbeddingsProvider[CohereEmbeddingFunction]):
    """Cohere embeddings provider."""

    embedding_callable: type[CohereEmbeddingFunction] = Field(
        default=CohereEmbeddingFunction, description="Cohere embedding function class"
    )
    api_key: str = Field(
        description="Cohere API key", validation_alias="COHERE_API_KEY"
    )
    model_name: str = Field(
        default="large",
        description="Model name to use for embeddings",
        validation_alias="COHERE_MODEL_NAME",
    )
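Note the validation_alias values: since BaseEmbeddingsProvider appears to be a pydantic-settings model (CustomProvider further down uses SettingsConfigDict), provider fields can presumably be resolved from the environment instead of being passed explicitly. A sketch under that assumption:

# Assumes validation_alias fields are filled from the environment,
# pydantic-settings style; not confirmed by this diff.
import os

os.environ["COHERE_API_KEY"] = "your-key"
os.environ["COHERE_MODEL_NAME"] = "embed-english-v3.0"

from crewai.rag.embeddings.factory import build_embedder  # assumed module path
from crewai.rag.embeddings.providers.cohere import CohereProvider

embedder = build_embedder(CohereProvider())  # fields resolved from env vars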
src/crewai/rag/embeddings/providers/cohere/types.py (new file, 19 lines)
@@ -0,0 +1,19 @@
"""Type definitions for Cohere embedding providers."""

from typing import Annotated, Literal

from typing_extensions import Required, TypedDict


class CohereProviderConfig(TypedDict, total=False):
    """Configuration for Cohere provider."""

    api_key: str
    model_name: Annotated[str, "large"]


class CohereProviderSpec(TypedDict, total=False):
    """Cohere provider specification."""

    provider: Required[Literal["cohere"]]
    config: CohereProviderConfig
src/crewai/rag/embeddings/providers/custom/__init__.py (new file, 13 lines)
@@ -0,0 +1,13 @@
"""Custom embedding providers."""

from crewai.rag.embeddings.providers.custom.custom_provider import CustomProvider
from crewai.rag.embeddings.providers.custom.types import (
    CustomProviderConfig,
    CustomProviderSpec,
)

__all__ = [
    "CustomProvider",
    "CustomProviderConfig",
    "CustomProviderSpec",
]
src/crewai/rag/embeddings/providers/custom/custom_provider.py (new file, 19 lines)
@@ -0,0 +1,19 @@
"""Custom embeddings provider for user-defined embedding functions."""

from pydantic import Field
from pydantic_settings import SettingsConfigDict

from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.providers.custom.embedding_callable import (
    CustomEmbeddingFunction,
)


class CustomProvider(BaseEmbeddingsProvider[CustomEmbeddingFunction]):
    """Custom embeddings provider for user-defined embedding functions."""

    embedding_callable: type[CustomEmbeddingFunction] = Field(
        ..., description="Custom embedding function class"
    )

    model_config = SettingsConfigDict(extra="allow")
src/crewai/rag/embeddings/providers/custom/embedding_callable.py (new file, 22 lines)
@@ -0,0 +1,22 @@
"""Custom embedding function base implementation."""

from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
from crewai.rag.core.types import Documents, Embeddings


class CustomEmbeddingFunction(EmbeddingFunction[Documents]):
    """Base class for custom embedding functions.

    This provides a concrete implementation that can be subclassed for custom embeddings.
    """

    def __call__(self, input: Documents) -> Embeddings:
        """Convert input documents to embeddings.

        Args:
            input: List of documents to embed.

        Returns:
            List of numpy arrays representing the embeddings.
        """
        raise NotImplementedError("Subclasses must implement __call__ method")
src/crewai/rag/embeddings/providers/custom/types.py (new file, 19 lines)
@@ -0,0 +1,19 @@
"""Type definitions for custom embedding providers."""

from typing import Literal

from chromadb.api.types import EmbeddingFunction
from typing_extensions import Required, TypedDict


class CustomProviderConfig(TypedDict, total=False):
    """Configuration for Custom provider."""

    embedding_callable: type[EmbeddingFunction]


class CustomProviderSpec(TypedDict, total=False):
    """Custom provider specification."""

    provider: Required[Literal["custom"]]
    config: CustomProviderConfig
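The custom provider closes the loop for user-defined functions: build_embedder_from_dict insists on an embedding_callable in the config, and CustomEmbeddingFunction is the base class to subclass. A sketch, with a zero vector standing in for a real model call:

# Sketch: MyEmbedder and the 384-dim zero vector are illustrative stand-ins.
import numpy as np

from crewai.rag.core.types import Documents, Embeddings
from crewai.rag.embeddings.factory import build_embedder  # assumed module path
from crewai.rag.embeddings.providers.custom.embedding_callable import (
    CustomEmbeddingFunction,
)


class MyEmbedder(CustomEmbeddingFunction):
    def __call__(self, input: Documents) -> Embeddings:
        return [np.zeros(384) for _ in input]  # replace with a real model call


embedder = build_embedder(
    {"provider": "custom", "config": {"embedding_callable": MyEmbedder}}
)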
src/crewai/rag/embeddings/providers/google/__init__.py (new file, 23 lines)
@@ -0,0 +1,23 @@
"""Google embedding providers."""

from crewai.rag.embeddings.providers.google.generative_ai import (
    GenerativeAiProvider,
)
from crewai.rag.embeddings.providers.google.types import (
    GenerativeAiProviderConfig,
    GenerativeAiProviderSpec,
    VertexAIProviderConfig,
    VertexAIProviderSpec,
)
from crewai.rag.embeddings.providers.google.vertex import (
    VertexAIProvider,
)

__all__ = [
    "GenerativeAiProvider",
    "GenerativeAiProviderConfig",
    "GenerativeAiProviderSpec",
    "VertexAIProvider",
    "VertexAIProviderConfig",
    "VertexAIProviderSpec",
]
src/crewai/rag/embeddings/providers/google/generative_ai.py (new file, 30 lines)
@@ -0,0 +1,30 @@
"""Google Generative AI embeddings provider."""

from chromadb.utils.embedding_functions.google_embedding_function import (
    GoogleGenerativeAiEmbeddingFunction,
)
from pydantic import Field

from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider


class GenerativeAiProvider(BaseEmbeddingsProvider[GoogleGenerativeAiEmbeddingFunction]):
    """Google Generative AI embeddings provider."""

    embedding_callable: type[GoogleGenerativeAiEmbeddingFunction] = Field(
        default=GoogleGenerativeAiEmbeddingFunction,
        description="Google Generative AI embedding function class",
    )
    model_name: str = Field(
        default="models/embedding-001",
        description="Model name to use for embeddings",
        validation_alias="GOOGLE_GENERATIVE_AI_MODEL_NAME",
    )
    api_key: str = Field(
        description="Google API key", validation_alias="GOOGLE_API_KEY"
    )
    task_type: str = Field(
        default="RETRIEVAL_DOCUMENT",
        description="Task type for embeddings",
        validation_alias="GOOGLE_GENERATIVE_AI_TASK_TYPE",
    )
src/crewai/rag/embeddings/providers/google/types.py (new file, 36 lines)
@@ -0,0 +1,36 @@
"""Type definitions for Google embedding providers."""

from typing import Annotated, Literal

from typing_extensions import Required, TypedDict


class GenerativeAiProviderConfig(TypedDict, total=False):
    """Configuration for Google Generative AI provider."""

    api_key: str
    model_name: Annotated[str, "models/embedding-001"]
    task_type: Annotated[str, "RETRIEVAL_DOCUMENT"]


class GenerativeAiProviderSpec(TypedDict):
    """Google Generative AI provider specification."""

    provider: Literal["google-generativeai"]
    config: GenerativeAiProviderConfig


class VertexAIProviderConfig(TypedDict, total=False):
    """Configuration for Vertex AI provider."""

    api_key: str
    model_name: Annotated[str, "textembedding-gecko"]
    project_id: Annotated[str, "cloud-large-language-models"]
    region: Annotated[str, "us-central1"]


class VertexAIProviderSpec(TypedDict, total=False):
    """Vertex AI provider specification."""

    provider: Required[Literal["google-vertex"]]
    config: VertexAIProviderConfig
src/crewai/rag/embeddings/providers/google/vertex.py (new file, 35 lines)
@@ -0,0 +1,35 @@
"""Google Vertex AI embeddings provider."""

from chromadb.utils.embedding_functions.google_embedding_function import (
    GoogleVertexEmbeddingFunction,
)
from pydantic import Field

from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider


class VertexAIProvider(BaseEmbeddingsProvider[GoogleVertexEmbeddingFunction]):
    """Google Vertex AI embeddings provider."""

    embedding_callable: type[GoogleVertexEmbeddingFunction] = Field(
        default=GoogleVertexEmbeddingFunction,
        description="Vertex AI embedding function class",
    )
    model_name: str = Field(
        default="textembedding-gecko",
        description="Model name to use for embeddings",
        validation_alias="GOOGLE_VERTEX_MODEL_NAME",
    )
    api_key: str = Field(
        description="Google API key", validation_alias="GOOGLE_CLOUD_API_KEY"
    )
    project_id: str = Field(
        default="cloud-large-language-models",
        description="GCP project ID",
        validation_alias="GOOGLE_CLOUD_PROJECT",
    )
    region: str = Field(
        default="us-central1",
        description="GCP region",
        validation_alias="GOOGLE_CLOUD_REGION",
    )
src/crewai/rag/embeddings/providers/huggingface/__init__.py (new file, 15 lines)
@@ -0,0 +1,15 @@
"""HuggingFace embedding providers."""

from crewai.rag.embeddings.providers.huggingface.huggingface_provider import (
    HuggingFaceProvider,
)
from crewai.rag.embeddings.providers.huggingface.types import (
    HuggingFaceProviderConfig,
    HuggingFaceProviderSpec,
)

__all__ = [
    "HuggingFaceProvider",
    "HuggingFaceProviderConfig",
    "HuggingFaceProviderSpec",
]
src/crewai/rag/embeddings/providers/huggingface/huggingface_provider.py (new file, 20 lines)
@@ -0,0 +1,20 @@
"""HuggingFace embeddings provider."""

from chromadb.utils.embedding_functions.huggingface_embedding_function import (
    HuggingFaceEmbeddingServer,
)
from pydantic import Field

from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider


class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingServer]):
    """HuggingFace embeddings provider."""

    embedding_callable: type[HuggingFaceEmbeddingServer] = Field(
        default=HuggingFaceEmbeddingServer,
        description="HuggingFace embedding function class",
    )
    url: str = Field(
        description="HuggingFace API URL", validation_alias="HUGGINGFACE_URL"
    )
src/crewai/rag/embeddings/providers/huggingface/types.py (new file, 18 lines)
@@ -0,0 +1,18 @@
"""Type definitions for HuggingFace embedding providers."""

from typing import Literal

from typing_extensions import Required, TypedDict


class HuggingFaceProviderConfig(TypedDict, total=False):
    """Configuration for HuggingFace provider."""

    url: str


class HuggingFaceProviderSpec(TypedDict, total=False):
    """HuggingFace provider specification."""

    provider: Required[Literal["huggingface"]]
    config: HuggingFaceProviderConfig
src/crewai/rag/embeddings/providers/ibm/__init__.py (new file, 15 lines)
@@ -0,0 +1,15 @@
"""IBM embedding providers."""

from crewai.rag.embeddings.providers.ibm.types import (
    WatsonProviderConfig,
    WatsonProviderSpec,
)
from crewai.rag.embeddings.providers.ibm.watson import (
    WatsonProvider,
)

__all__ = [
    "WatsonProvider",
    "WatsonProviderConfig",
    "WatsonProviderSpec",
]
src/crewai/rag/embeddings/providers/ibm/embedding_callable.py (new file, 154 lines)
@@ -0,0 +1,154 @@
"""IBM Watson embedding function implementation."""

from typing import cast

from typing_extensions import Unpack

from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
from crewai.rag.core.types import Documents, Embeddings
from crewai.rag.embeddings.providers.ibm.types import WatsonProviderConfig


class WatsonEmbeddingFunction(EmbeddingFunction[Documents]):
    """Embedding function for IBM Watson models."""

    def __init__(self, **kwargs: Unpack[WatsonProviderConfig]) -> None:
        """Initialize Watson embedding function.

        Args:
            **kwargs: Configuration parameters for Watson Embeddings and Credentials.
        """
        self._config = kwargs

    def __call__(self, input: Documents) -> Embeddings:
        """Generate embeddings for input documents.

        Args:
            input: List of documents to embed.

        Returns:
            List of embedding vectors.
        """
        try:
            import ibm_watsonx_ai.foundation_models as watson_models  # type: ignore[import-not-found, import-untyped]
            from ibm_watsonx_ai import (
                Credentials,  # type: ignore[import-not-found, import-untyped]
            )
            from ibm_watsonx_ai.metanames import (  # type: ignore[import-not-found, import-untyped]
                EmbedTextParamsMetaNames as EmbedParams,
            )

        except ImportError as e:
            raise ImportError(
                "ibm-watsonx-ai is required for watson embeddings. "
                "Install it with: uv add ibm-watsonx-ai"
            ) from e

        if isinstance(input, str):
            input = [input]

        embeddings_config: dict = {
            "model_id": self._config["model_id"],
        }
        if "params" in self._config and self._config["params"] is not None:
            embeddings_config["params"] = self._config["params"]
        if "project_id" in self._config and self._config["project_id"] is not None:
            embeddings_config["project_id"] = self._config["project_id"]
        if "space_id" in self._config and self._config["space_id"] is not None:
            embeddings_config["space_id"] = self._config["space_id"]
        if "api_client" in self._config and self._config["api_client"] is not None:
            embeddings_config["api_client"] = self._config["api_client"]
        if "verify" in self._config and self._config["verify"] is not None:
            embeddings_config["verify"] = self._config["verify"]
        if "persistent_connection" in self._config:
            embeddings_config["persistent_connection"] = self._config[
                "persistent_connection"
            ]
        if "batch_size" in self._config:
            embeddings_config["batch_size"] = self._config["batch_size"]
        if "concurrency_limit" in self._config:
            embeddings_config["concurrency_limit"] = self._config["concurrency_limit"]
        if "max_retries" in self._config and self._config["max_retries"] is not None:
            embeddings_config["max_retries"] = self._config["max_retries"]
        if "delay_time" in self._config and self._config["delay_time"] is not None:
            embeddings_config["delay_time"] = self._config["delay_time"]
        if (
            "retry_status_codes" in self._config
            and self._config["retry_status_codes"] is not None
        ):
            embeddings_config["retry_status_codes"] = self._config["retry_status_codes"]

        if "credentials" in self._config and self._config["credentials"] is not None:
            embeddings_config["credentials"] = self._config["credentials"]
        else:
            cred_config: dict = {}
            if "url" in self._config and self._config["url"] is not None:
                cred_config["url"] = self._config["url"]
            if "api_key" in self._config and self._config["api_key"] is not None:
                cred_config["api_key"] = self._config["api_key"]
            if "name" in self._config and self._config["name"] is not None:
                cred_config["name"] = self._config["name"]
            if (
                "iam_serviceid_crn" in self._config
                and self._config["iam_serviceid_crn"] is not None
            ):
                cred_config["iam_serviceid_crn"] = self._config["iam_serviceid_crn"]
            if (
                "trusted_profile_id" in self._config
                and self._config["trusted_profile_id"] is not None
            ):
                cred_config["trusted_profile_id"] = self._config["trusted_profile_id"]
            if "token" in self._config and self._config["token"] is not None:
                cred_config["token"] = self._config["token"]
            if (
                "projects_token" in self._config
                and self._config["projects_token"] is not None
            ):
                cred_config["projects_token"] = self._config["projects_token"]
            if "username" in self._config and self._config["username"] is not None:
                cred_config["username"] = self._config["username"]
            if "password" in self._config and self._config["password"] is not None:
                cred_config["password"] = self._config["password"]
            if (
                "instance_id" in self._config
                and self._config["instance_id"] is not None
            ):
                cred_config["instance_id"] = self._config["instance_id"]
            if "version" in self._config and self._config["version"] is not None:
                cred_config["version"] = self._config["version"]
            if (
                "bedrock_url" in self._config
                and self._config["bedrock_url"] is not None
            ):
                cred_config["bedrock_url"] = self._config["bedrock_url"]
            if (
                "platform_url" in self._config
                and self._config["platform_url"] is not None
            ):
                cred_config["platform_url"] = self._config["platform_url"]
            if "proxies" in self._config and self._config["proxies"] is not None:
                cred_config["proxies"] = self._config["proxies"]
            if (
                "verify" not in embeddings_config
                and "verify" in self._config
                and self._config["verify"] is not None
            ):
                cred_config["verify"] = self._config["verify"]

            if cred_config:
                embeddings_config["credentials"] = Credentials(**cred_config)

        if "params" not in embeddings_config:
            embeddings_config["params"] = {
                EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
                EmbedParams.RETURN_OPTIONS: {"input_text": True},
            }

        embedding = watson_models.Embeddings(**embeddings_config)

        try:
            embeddings = embedding.embed_documents(input)
            return cast(Embeddings, embeddings)
        except Exception as e:
            print(f"Error during Watson embedding: {e}")
            raise
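Compared with the removed _configure_watson, the class above forwards every WatsonProviderConfig key explicitly and only synthesizes a Credentials object when none is supplied. A hedged sketch of direct use, with placeholder values and an assumed model id:

# Sketch: all values are placeholders; requires the ibm-watsonx-ai package.
from crewai.rag.embeddings.providers.ibm.embedding_callable import (
    WatsonEmbeddingFunction,
)

embedder = WatsonEmbeddingFunction(
    model_id="ibm/slate-125m-english-rtrvr",  # assumed model id, for illustration
    api_key="your-ibm-cloud-api-key",
    url="https://us-south.ml.cloud.ibm.com",
    project_id="your-project-id",
)
vectors = embedder(["hello watsonx"])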
src/crewai/rag/embeddings/providers/ibm/types.py (new file, 44 lines)
@@ -0,0 +1,44 @@
"""Type definitions for IBM Watson embedding providers."""

from typing import Annotated, Any, Literal

from typing_extensions import Required, TypedDict


class WatsonProviderConfig(TypedDict, total=False):
    """Configuration for Watson provider."""

    model_id: str
    url: str
    params: dict[str, str | dict[str, str]]
    credentials: Any
    project_id: str
    space_id: str
    api_client: Any
    verify: bool | str
    persistent_connection: Annotated[bool, True]
    batch_size: Annotated[int, 100]
    concurrency_limit: Annotated[int, 10]
    max_retries: int
    delay_time: float
    retry_status_codes: list[int]
    api_key: str
    name: str
    iam_serviceid_crn: str
    trusted_profile_id: str
    token: str
    projects_token: str
    username: str
    password: str
    instance_id: str
    version: str
    bedrock_url: str
    platform_url: str
    proxies: dict


class WatsonProviderSpec(TypedDict, total=False):
    """Watson provider specification."""

    provider: Required[Literal["watson"]]
    config: WatsonProviderConfig
Some files were not shown because too many files have changed in this diff.