mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-28 17:48:13 +00:00
Compare commits
6 Commits
llm-event-
...
devin/1769
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9af03058fe | ||
|
|
d52dbc1f4b | ||
|
|
6b926b90d0 | ||
|
|
fc84daadbb | ||
|
|
58b866a83d | ||
|
|
9797567342 |
@@ -4,6 +4,74 @@ description: "Product updates, improvements, and bug fixes for CrewAI"
|
|||||||
icon: "clock"
|
icon: "clock"
|
||||||
mode: "wide"
|
mode: "wide"
|
||||||
---
|
---
|
||||||
|
<Update label="Jan 26, 2026">
|
||||||
|
## v1.9.0
|
||||||
|
|
||||||
|
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.9.0)
|
||||||
|
|
||||||
|
## What's Changed
|
||||||
|
|
||||||
|
### Features
|
||||||
|
- Add structured outputs and response_format support across providers
|
||||||
|
- Add response ID to streaming responses
|
||||||
|
- Add event ordering with parent-child hierarchies
|
||||||
|
- Add Keycloak SSO authentication support
|
||||||
|
- Add multimodal file handling capabilities
|
||||||
|
- Add native OpenAI responses API support
|
||||||
|
- Add A2A task execution utilities
|
||||||
|
- Add A2A server configuration and agent card generation
|
||||||
|
- Enhance event system and expand transport options
|
||||||
|
- Improve tool calling mechanisms
|
||||||
|
|
||||||
|
### Bug Fixes
|
||||||
|
- Enhance file store with fallback memory cache when aiocache is not available
|
||||||
|
- Ensure document list is not empty
|
||||||
|
- Handle Bedrock stop sequences properly
|
||||||
|
- Add Google Vertex API key support
|
||||||
|
- Enhance Azure model stop word detection
|
||||||
|
- Improve error handling for HumanFeedbackPending in flow execution
|
||||||
|
- Fix execution span task unlinking
|
||||||
|
|
||||||
|
### Documentation
|
||||||
|
- Add native file handling documentation
|
||||||
|
- Add OpenAI responses API documentation
|
||||||
|
- Add agent card implementation guidance
|
||||||
|
- Refine A2A documentation
|
||||||
|
- Update changelog for v1.8.0
|
||||||
|
|
||||||
|
### Contributors
|
||||||
|
@Anaisdg, @GininDenis, @Vidit-Ostwal, @greysonlalonde, @heitorado, @joaomdmoura, @koushiv777, @lorenzejay, @nicoferdi96, @vinibrsl
|
||||||
|
|
||||||
|
</Update>
|
||||||
|
|
||||||
|
<Update label="Jan 15, 2026">
|
||||||
|
## v1.8.1
|
||||||
|
|
||||||
|
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.8.1)
|
||||||
|
|
||||||
|
## What's Changed
|
||||||
|
|
||||||
|
### Features
|
||||||
|
- Add A2A task execution utilities
|
||||||
|
- Add A2A server configuration and agent card generation
|
||||||
|
- Add additional transport mechanisms
|
||||||
|
- Add Galileo integration support
|
||||||
|
|
||||||
|
### Bug Fixes
|
||||||
|
- Improve Azure model compatibility
|
||||||
|
- Expand frame inspection depth to detect parent_flow
|
||||||
|
- Resolve task execution span management issues
|
||||||
|
- Enhance error handling for human feedback scenarios during flow execution
|
||||||
|
|
||||||
|
### Documentation
|
||||||
|
- Add A2A agent card documentation
|
||||||
|
- Add PII redaction feature documentation
|
||||||
|
|
||||||
|
### Contributors
|
||||||
|
@Anaisdg, @GininDenis, @greysonlalonde, @joaomdmoura, @koushiv777, @lorenzejay, @vinibrsl
|
||||||
|
|
||||||
|
</Update>
|
||||||
|
|
||||||
<Update label="Jan 08, 2026">
|
<Update label="Jan 08, 2026">
|
||||||
## v1.8.0
|
## v1.8.0
|
||||||
|
|
||||||
|
|||||||
@@ -401,23 +401,58 @@ crew = Crew(
|
|||||||
|
|
||||||
### Vertex AI Embeddings
|
### Vertex AI Embeddings
|
||||||
|
|
||||||
For Google Cloud users with Vertex AI access.
|
For Google Cloud users with Vertex AI access. Supports both legacy and new embedding models with automatic SDK selection.
|
||||||
|
|
||||||
|
<Note>
|
||||||
|
**Deprecation Notice:** Legacy models (`textembedding-gecko*`) use the deprecated `vertexai.language_models` SDK which will be removed after June 24, 2026. Consider migrating to newer models like `gemini-embedding-001`. See the [Google migration guide](https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk) for details.
|
||||||
|
</Note>
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
# Recommended: Using new models with google-genai SDK
|
||||||
crew = Crew(
|
crew = Crew(
|
||||||
memory=True,
|
memory=True,
|
||||||
embedder={
|
embedder={
|
||||||
"provider": "vertexai",
|
"provider": "google-vertex",
|
||||||
"config": {
|
"config": {
|
||||||
"project_id": "your-gcp-project-id",
|
"project_id": "your-gcp-project-id",
|
||||||
"region": "us-central1", # or your preferred region
|
"location": "us-central1",
|
||||||
"api_key": "your-service-account-key",
|
"model_name": "gemini-embedding-001", # or "text-embedding-005", "text-multilingual-embedding-002"
|
||||||
"model_name": "textembedding-gecko"
|
"task_type": "RETRIEVAL_DOCUMENT", # Optional
|
||||||
|
"output_dimensionality": 768 # Optional
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Using API key authentication (Exp)
|
||||||
|
crew = Crew(
|
||||||
|
memory=True,
|
||||||
|
embedder={
|
||||||
|
"provider": "google-vertex",
|
||||||
|
"config": {
|
||||||
|
"api_key": "your-google-api-key",
|
||||||
|
"model_name": "gemini-embedding-001"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Legacy models (backwards compatible, emits deprecation warning)
|
||||||
|
crew = Crew(
|
||||||
|
memory=True,
|
||||||
|
embedder={
|
||||||
|
"provider": "google-vertex",
|
||||||
|
"config": {
|
||||||
|
"project_id": "your-gcp-project-id",
|
||||||
|
"region": "us-central1", # or "location" (region is deprecated)
|
||||||
|
"model_name": "textembedding-gecko" # Legacy model
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**Available models:**
|
||||||
|
- **New SDK models** (recommended): `gemini-embedding-001`, `text-embedding-005`, `text-multilingual-embedding-002`
|
||||||
|
- **Legacy models** (deprecated): `textembedding-gecko`, `textembedding-gecko@001`, `textembedding-gecko-multilingual`
|
||||||
|
|
||||||
### Ollama Embeddings (Local)
|
### Ollama Embeddings (Local)
|
||||||
|
|
||||||
Run embeddings locally for privacy and cost savings.
|
Run embeddings locally for privacy and cost savings.
|
||||||
@@ -569,7 +604,7 @@ mem0_client_embedder_config = {
|
|||||||
"project_id": "my_project_id", # Optional
|
"project_id": "my_project_id", # Optional
|
||||||
"api_key": "custom-api-key" # Optional - overrides env var
|
"api_key": "custom-api-key" # Optional - overrides env var
|
||||||
"run_id": "my_run_id", # Optional - for short-term memory
|
"run_id": "my_run_id", # Optional - for short-term memory
|
||||||
"includes": "include1", # Optional
|
"includes": "include1", # Optional
|
||||||
"excludes": "exclude1", # Optional
|
"excludes": "exclude1", # Optional
|
||||||
"infer": True # Optional defaults to True
|
"infer": True # Optional defaults to True
|
||||||
"custom_categories": new_categories # Optional - custom categories for user memory
|
"custom_categories": new_categories # Optional - custom categories for user memory
|
||||||
@@ -591,7 +626,7 @@ crew = Crew(
|
|||||||
|
|
||||||
### Choosing the Right Embedding Provider
|
### Choosing the Right Embedding Provider
|
||||||
|
|
||||||
When selecting an embedding provider, consider factors like performance, privacy, cost, and integration needs.
|
When selecting an embedding provider, consider factors like performance, privacy, cost, and integration needs.
|
||||||
Below is a comparison to help you decide:
|
Below is a comparison to help you decide:
|
||||||
|
|
||||||
| Provider | Best For | Pros | Cons |
|
| Provider | Best For | Pros | Cons |
|
||||||
@@ -749,7 +784,7 @@ Entity Memory supports batching when saving multiple entities at once. When you
|
|||||||
|
|
||||||
This improves performance and observability when writing many entities in one operation.
|
This improves performance and observability when writing many entities in one operation.
|
||||||
|
|
||||||
## 2. External Memory
|
## 2. External Memory
|
||||||
External Memory provides a standalone memory system that operates independently from the crew's built-in memory. This is ideal for specialized memory providers or cross-application memory sharing.
|
External Memory provides a standalone memory system that operates independently from the crew's built-in memory. This is ideal for specialized memory providers or cross-application memory sharing.
|
||||||
|
|
||||||
### Basic External Memory with Mem0
|
### Basic External Memory with Mem0
|
||||||
@@ -819,7 +854,7 @@ external_memory = ExternalMemory(
|
|||||||
"project_id": "my_project_id", # Optional
|
"project_id": "my_project_id", # Optional
|
||||||
"api_key": "custom-api-key" # Optional - overrides env var
|
"api_key": "custom-api-key" # Optional - overrides env var
|
||||||
"run_id": "my_run_id", # Optional - for short-term memory
|
"run_id": "my_run_id", # Optional - for short-term memory
|
||||||
"includes": "include1", # Optional
|
"includes": "include1", # Optional
|
||||||
"excludes": "exclude1", # Optional
|
"excludes": "exclude1", # Optional
|
||||||
"infer": True # Optional defaults to True
|
"infer": True # Optional defaults to True
|
||||||
"custom_categories": new_categories # Optional - custom categories for user memory
|
"custom_categories": new_categories # Optional - custom categories for user memory
|
||||||
|
|||||||
@@ -4,6 +4,74 @@ description: "CrewAI의 제품 업데이트, 개선 사항 및 버그 수정"
|
|||||||
icon: "clock"
|
icon: "clock"
|
||||||
mode: "wide"
|
mode: "wide"
|
||||||
---
|
---
|
||||||
|
<Update label="2026년 1월 26일">
|
||||||
|
## v1.9.0
|
||||||
|
|
||||||
|
[GitHub 릴리스 보기](https://github.com/crewAIInc/crewAI/releases/tag/1.9.0)
|
||||||
|
|
||||||
|
## 변경 사항
|
||||||
|
|
||||||
|
### 기능
|
||||||
|
- 프로바이더 전반에 걸친 구조화된 출력 및 response_format 지원 추가
|
||||||
|
- 스트리밍 응답에 응답 ID 추가
|
||||||
|
- 부모-자식 계층 구조를 가진 이벤트 순서 추가
|
||||||
|
- Keycloak SSO 인증 지원 추가
|
||||||
|
- 멀티모달 파일 처리 기능 추가
|
||||||
|
- 네이티브 OpenAI responses API 지원 추가
|
||||||
|
- A2A 작업 실행 유틸리티 추가
|
||||||
|
- A2A 서버 구성 및 에이전트 카드 생성 추가
|
||||||
|
- 이벤트 시스템 향상 및 전송 옵션 확장
|
||||||
|
- 도구 호출 메커니즘 개선
|
||||||
|
|
||||||
|
### 버그 수정
|
||||||
|
- aiocache를 사용할 수 없을 때 폴백 메모리 캐시로 파일 저장소 향상
|
||||||
|
- 문서 목록이 비어 있지 않도록 보장
|
||||||
|
- Bedrock 중지 시퀀스 적절히 처리
|
||||||
|
- Google Vertex API 키 지원 추가
|
||||||
|
- Azure 모델 중지 단어 감지 향상
|
||||||
|
- 흐름 실행 시 HumanFeedbackPending 오류 처리 개선
|
||||||
|
- 실행 스팬 작업 연결 해제 수정
|
||||||
|
|
||||||
|
### 문서
|
||||||
|
- 네이티브 파일 처리 문서 추가
|
||||||
|
- OpenAI responses API 문서 추가
|
||||||
|
- 에이전트 카드 구현 가이드 추가
|
||||||
|
- A2A 문서 개선
|
||||||
|
- v1.8.0 변경 로그 업데이트
|
||||||
|
|
||||||
|
### 기여자
|
||||||
|
@Anaisdg, @GininDenis, @Vidit-Ostwal, @greysonlalonde, @heitorado, @joaomdmoura, @koushiv777, @lorenzejay, @nicoferdi96, @vinibrsl
|
||||||
|
|
||||||
|
</Update>
|
||||||
|
|
||||||
|
<Update label="2026년 1월 15일">
|
||||||
|
## v1.8.1
|
||||||
|
|
||||||
|
[GitHub 릴리스 보기](https://github.com/crewAIInc/crewAI/releases/tag/1.8.1)
|
||||||
|
|
||||||
|
## 변경 사항
|
||||||
|
|
||||||
|
### 기능
|
||||||
|
- A2A 작업 실행 유틸리티 추가
|
||||||
|
- A2A 서버 구성 및 에이전트 카드 생성 추가
|
||||||
|
- 추가 전송 메커니즘 추가
|
||||||
|
- Galileo 통합 지원 추가
|
||||||
|
|
||||||
|
### 버그 수정
|
||||||
|
- Azure 모델 호환성 개선
|
||||||
|
- parent_flow 감지를 위한 프레임 검사 깊이 확장
|
||||||
|
- 작업 실행 스팬 관리 문제 해결
|
||||||
|
- 흐름 실행 중 휴먼 피드백 시나리오에 대한 오류 처리 향상
|
||||||
|
|
||||||
|
### 문서
|
||||||
|
- A2A 에이전트 카드 문서 추가
|
||||||
|
- PII 삭제 기능 문서 추가
|
||||||
|
|
||||||
|
### 기여자
|
||||||
|
@Anaisdg, @GininDenis, @greysonlalonde, @joaomdmoura, @koushiv777, @lorenzejay, @vinibrsl
|
||||||
|
|
||||||
|
</Update>
|
||||||
|
|
||||||
<Update label="2026년 1월 8일">
|
<Update label="2026년 1월 8일">
|
||||||
## v1.8.0
|
## v1.8.0
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,74 @@ description: "Atualizações de produto, melhorias e correções do CrewAI"
|
|||||||
icon: "clock"
|
icon: "clock"
|
||||||
mode: "wide"
|
mode: "wide"
|
||||||
---
|
---
|
||||||
|
<Update label="26 jan 2026">
|
||||||
|
## v1.9.0
|
||||||
|
|
||||||
|
[Ver release no GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.9.0)
|
||||||
|
|
||||||
|
## O que Mudou
|
||||||
|
|
||||||
|
### Funcionalidades
|
||||||
|
- Adicionar suporte a saídas estruturadas e response_format em vários provedores
|
||||||
|
- Adicionar ID de resposta às respostas de streaming
|
||||||
|
- Adicionar ordenação de eventos com hierarquias pai-filho
|
||||||
|
- Adicionar suporte à autenticação SSO Keycloak
|
||||||
|
- Adicionar capacidades de manipulação de arquivos multimodais
|
||||||
|
- Adicionar suporte nativo à API de respostas OpenAI
|
||||||
|
- Adicionar utilitários de execução de tarefas A2A
|
||||||
|
- Adicionar configuração de servidor A2A e geração de cartão de agente
|
||||||
|
- Aprimorar sistema de eventos e expandir opções de transporte
|
||||||
|
- Melhorar mecanismos de chamada de ferramentas
|
||||||
|
|
||||||
|
### Correções de Bugs
|
||||||
|
- Aprimorar armazenamento de arquivos com cache de memória de fallback quando aiocache não está disponível
|
||||||
|
- Garantir que lista de documentos não esteja vazia
|
||||||
|
- Tratar sequências de parada do Bedrock adequadamente
|
||||||
|
- Adicionar suporte à chave de API do Google Vertex
|
||||||
|
- Aprimorar detecção de palavras de parada do modelo Azure
|
||||||
|
- Melhorar tratamento de erros para HumanFeedbackPending na execução de fluxo
|
||||||
|
- Corrigir desvinculação de tarefa do span de execução
|
||||||
|
|
||||||
|
### Documentação
|
||||||
|
- Adicionar documentação de manipulação nativa de arquivos
|
||||||
|
- Adicionar documentação da API de respostas OpenAI
|
||||||
|
- Adicionar orientação de implementação de cartão de agente
|
||||||
|
- Refinar documentação A2A
|
||||||
|
- Atualizar changelog para v1.8.0
|
||||||
|
|
||||||
|
### Contribuidores
|
||||||
|
@Anaisdg, @GininDenis, @Vidit-Ostwal, @greysonlalonde, @heitorado, @joaomdmoura, @koushiv777, @lorenzejay, @nicoferdi96, @vinibrsl
|
||||||
|
|
||||||
|
</Update>
|
||||||
|
|
||||||
|
<Update label="15 jan 2026">
|
||||||
|
## v1.8.1
|
||||||
|
|
||||||
|
[Ver release no GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.8.1)
|
||||||
|
|
||||||
|
## O que Mudou
|
||||||
|
|
||||||
|
### Funcionalidades
|
||||||
|
- Adicionar utilitários de execução de tarefas A2A
|
||||||
|
- Adicionar configuração de servidor A2A e geração de cartão de agente
|
||||||
|
- Adicionar mecanismos de transporte adicionais
|
||||||
|
- Adicionar suporte à integração Galileo
|
||||||
|
|
||||||
|
### Correções de Bugs
|
||||||
|
- Melhorar compatibilidade do modelo Azure
|
||||||
|
- Expandir profundidade de inspeção de frame para detectar parent_flow
|
||||||
|
- Resolver problemas de gerenciamento de span de execução de tarefas
|
||||||
|
- Aprimorar tratamento de erros para cenários de feedback humano durante execução de fluxo
|
||||||
|
|
||||||
|
### Documentação
|
||||||
|
- Adicionar documentação de cartão de agente A2A
|
||||||
|
- Adicionar documentação de recurso de redação de PII
|
||||||
|
|
||||||
|
### Contribuidores
|
||||||
|
@Anaisdg, @GininDenis, @greysonlalonde, @joaomdmoura, @koushiv777, @lorenzejay, @vinibrsl
|
||||||
|
|
||||||
|
</Update>
|
||||||
|
|
||||||
<Update label="08 jan 2026">
|
<Update label="08 jan 2026">
|
||||||
## v1.8.0
|
## v1.8.0
|
||||||
|
|
||||||
|
|||||||
@@ -152,4 +152,4 @@ __all__ = [
|
|||||||
"wrap_file_source",
|
"wrap_file_source",
|
||||||
]
|
]
|
||||||
|
|
||||||
__version__ = "1.8.1"
|
__version__ = "1.9.0"
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ dependencies = [
|
|||||||
"pytube~=15.0.0",
|
"pytube~=15.0.0",
|
||||||
"requests~=2.32.5",
|
"requests~=2.32.5",
|
||||||
"docker~=7.1.0",
|
"docker~=7.1.0",
|
||||||
"crewai==1.8.1",
|
"crewai==1.9.0",
|
||||||
"lancedb~=0.5.4",
|
"lancedb~=0.5.4",
|
||||||
"tiktoken~=0.8.0",
|
"tiktoken~=0.8.0",
|
||||||
"beautifulsoup4~=4.13.4",
|
"beautifulsoup4~=4.13.4",
|
||||||
|
|||||||
@@ -291,4 +291,4 @@ __all__ = [
|
|||||||
"ZapierActionTools",
|
"ZapierActionTools",
|
||||||
]
|
]
|
||||||
|
|
||||||
__version__ = "1.8.1"
|
__version__ = "1.9.0"
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ Repository = "https://github.com/crewAIInc/crewAI"
|
|||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
tools = [
|
tools = [
|
||||||
"crewai-tools==1.8.1",
|
"crewai-tools==1.9.0",
|
||||||
]
|
]
|
||||||
embeddings = [
|
embeddings = [
|
||||||
"tiktoken~=0.8.0"
|
"tiktoken~=0.8.0"
|
||||||
@@ -90,7 +90,7 @@ azure-ai-inference = [
|
|||||||
"azure-ai-inference~=1.0.0b9",
|
"azure-ai-inference~=1.0.0b9",
|
||||||
]
|
]
|
||||||
anthropic = [
|
anthropic = [
|
||||||
"anthropic~=0.71.0",
|
"anthropic~=0.73.0",
|
||||||
]
|
]
|
||||||
a2a = [
|
a2a = [
|
||||||
"a2a-sdk~=0.3.10",
|
"a2a-sdk~=0.3.10",
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
|
|||||||
|
|
||||||
_suppress_pydantic_deprecation_warnings()
|
_suppress_pydantic_deprecation_warnings()
|
||||||
|
|
||||||
__version__ = "1.8.1"
|
__version__ = "1.9.0"
|
||||||
_telemetry_submitted = False
|
_telemetry_submitted = False
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
|||||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||||
requires-python = ">=3.10,<3.14"
|
requires-python = ">=3.10,<3.14"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"crewai[tools]==1.8.1"
|
"crewai[tools]==1.9.0"
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
|||||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||||
requires-python = ">=3.10,<3.14"
|
requires-python = ">=3.10,<3.14"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"crewai[tools]==1.8.1"
|
"crewai[tools]==1.9.0"
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|||||||
@@ -3,9 +3,8 @@ from __future__ import annotations
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
from typing import TYPE_CHECKING, Any, Final, Literal, TypeGuard, cast
|
||||||
|
|
||||||
from anthropic.types import ThinkingBlock
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from crewai.events.types.llm_events import LLMCallType
|
from crewai.events.types.llm_events import LLMCallType
|
||||||
@@ -22,8 +21,9 @@ if TYPE_CHECKING:
|
|||||||
from crewai.llms.hooks.base import BaseInterceptor
|
from crewai.llms.hooks.base import BaseInterceptor
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from anthropic import Anthropic, AsyncAnthropic
|
from anthropic import Anthropic, AsyncAnthropic, transform_schema
|
||||||
from anthropic.types import Message, TextBlock, ThinkingBlock, ToolUseBlock
|
from anthropic.types import Message, TextBlock, ThinkingBlock, ToolUseBlock
|
||||||
|
from anthropic.types.beta import BetaMessage
|
||||||
import httpx
|
import httpx
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
@@ -31,7 +31,62 @@ except ImportError:
|
|||||||
) from None
|
) from None
|
||||||
|
|
||||||
|
|
||||||
ANTHROPIC_FILES_API_BETA = "files-api-2025-04-14"
|
ANTHROPIC_FILES_API_BETA: Final = "files-api-2025-04-14"
|
||||||
|
ANTHROPIC_STRUCTURED_OUTPUTS_BETA: Final = "structured-outputs-2025-11-13"
|
||||||
|
|
||||||
|
NATIVE_STRUCTURED_OUTPUT_MODELS: Final[
|
||||||
|
tuple[
|
||||||
|
Literal["claude-sonnet-4-5"],
|
||||||
|
Literal["claude-sonnet-4.5"],
|
||||||
|
Literal["claude-opus-4-5"],
|
||||||
|
Literal["claude-opus-4.5"],
|
||||||
|
Literal["claude-opus-4-1"],
|
||||||
|
Literal["claude-opus-4.1"],
|
||||||
|
Literal["claude-haiku-4-5"],
|
||||||
|
Literal["claude-haiku-4.5"],
|
||||||
|
]
|
||||||
|
] = (
|
||||||
|
"claude-sonnet-4-5",
|
||||||
|
"claude-sonnet-4.5",
|
||||||
|
"claude-opus-4-5",
|
||||||
|
"claude-opus-4.5",
|
||||||
|
"claude-opus-4-1",
|
||||||
|
"claude-opus-4.1",
|
||||||
|
"claude-haiku-4-5",
|
||||||
|
"claude-haiku-4.5",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _supports_native_structured_outputs(model: str) -> bool:
|
||||||
|
"""Check if the model supports native structured outputs.
|
||||||
|
|
||||||
|
Native structured outputs are only available for Claude 4.5 models
|
||||||
|
(Sonnet 4.5, Opus 4.5, Opus 4.1, Haiku 4.5).
|
||||||
|
Other models require the tool-based fallback approach.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The model name/identifier.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the model supports native structured outputs.
|
||||||
|
"""
|
||||||
|
model_lower = model.lower()
|
||||||
|
return any(prefix in model_lower for prefix in NATIVE_STRUCTURED_OUTPUT_MODELS)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_pydantic_model_class(obj: Any) -> TypeGuard[type[BaseModel]]:
|
||||||
|
"""Check if an object is a Pydantic model class.
|
||||||
|
|
||||||
|
This distinguishes between Pydantic model classes that support structured
|
||||||
|
outputs (have model_json_schema) and plain dicts like {"type": "json_object"}.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obj: The object to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if obj is a Pydantic model class.
|
||||||
|
"""
|
||||||
|
return isinstance(obj, type) and issubclass(obj, BaseModel)
|
||||||
|
|
||||||
|
|
||||||
def _contains_file_id_reference(messages: list[dict[str, Any]]) -> bool:
|
def _contains_file_id_reference(messages: list[dict[str, Any]]) -> bool:
|
||||||
@@ -84,6 +139,7 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
client_params: dict[str, Any] | None = None,
|
client_params: dict[str, Any] | None = None,
|
||||||
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
|
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
|
||||||
thinking: AnthropicThinkingConfig | None = None,
|
thinking: AnthropicThinkingConfig | None = None,
|
||||||
|
response_format: type[BaseModel] | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
"""Initialize Anthropic chat completion client.
|
"""Initialize Anthropic chat completion client.
|
||||||
@@ -101,6 +157,8 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
stream: Enable streaming responses
|
stream: Enable streaming responses
|
||||||
client_params: Additional parameters for the Anthropic client
|
client_params: Additional parameters for the Anthropic client
|
||||||
interceptor: HTTP interceptor for modifying requests/responses at transport level.
|
interceptor: HTTP interceptor for modifying requests/responses at transport level.
|
||||||
|
response_format: Pydantic model for structured output. When provided, responses
|
||||||
|
will be validated against this model schema.
|
||||||
**kwargs: Additional parameters
|
**kwargs: Additional parameters
|
||||||
"""
|
"""
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -131,6 +189,7 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
self.stop_sequences = stop_sequences or []
|
self.stop_sequences = stop_sequences or []
|
||||||
self.thinking = thinking
|
self.thinking = thinking
|
||||||
self.previous_thinking_blocks: list[ThinkingBlock] = []
|
self.previous_thinking_blocks: list[ThinkingBlock] = []
|
||||||
|
self.response_format = response_format
|
||||||
# Model-specific settings
|
# Model-specific settings
|
||||||
self.is_claude_3 = "claude-3" in model.lower()
|
self.is_claude_3 = "claude-3" in model.lower()
|
||||||
self.supports_tools = True
|
self.supports_tools = True
|
||||||
@@ -231,6 +290,8 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
formatted_messages, system_message, tools
|
formatted_messages, system_message, tools
|
||||||
)
|
)
|
||||||
|
|
||||||
|
effective_response_model = response_model or self.response_format
|
||||||
|
|
||||||
# Handle streaming vs non-streaming
|
# Handle streaming vs non-streaming
|
||||||
if self.stream:
|
if self.stream:
|
||||||
return self._handle_streaming_completion(
|
return self._handle_streaming_completion(
|
||||||
@@ -238,7 +299,7 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
available_functions,
|
available_functions,
|
||||||
from_task,
|
from_task,
|
||||||
from_agent,
|
from_agent,
|
||||||
response_model,
|
effective_response_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._handle_completion(
|
return self._handle_completion(
|
||||||
@@ -246,7 +307,7 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
available_functions,
|
available_functions,
|
||||||
from_task,
|
from_task,
|
||||||
from_agent,
|
from_agent,
|
||||||
response_model,
|
effective_response_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -298,13 +359,15 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
formatted_messages, system_message, tools
|
formatted_messages, system_message, tools
|
||||||
)
|
)
|
||||||
|
|
||||||
|
effective_response_model = response_model or self.response_format
|
||||||
|
|
||||||
if self.stream:
|
if self.stream:
|
||||||
return await self._ahandle_streaming_completion(
|
return await self._ahandle_streaming_completion(
|
||||||
completion_params,
|
completion_params,
|
||||||
available_functions,
|
available_functions,
|
||||||
from_task,
|
from_task,
|
||||||
from_agent,
|
from_agent,
|
||||||
response_model,
|
effective_response_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
return await self._ahandle_completion(
|
return await self._ahandle_completion(
|
||||||
@@ -312,7 +375,7 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
available_functions,
|
available_functions,
|
||||||
from_task,
|
from_task,
|
||||||
from_agent,
|
from_agent,
|
||||||
response_model,
|
effective_response_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -565,22 +628,40 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
) -> str | Any:
|
) -> str | Any:
|
||||||
"""Handle non-streaming message completion."""
|
"""Handle non-streaming message completion."""
|
||||||
if response_model:
|
|
||||||
structured_tool = {
|
|
||||||
"name": "structured_output",
|
|
||||||
"description": "Returns structured data according to the schema",
|
|
||||||
"input_schema": response_model.model_json_schema(),
|
|
||||||
}
|
|
||||||
|
|
||||||
params["tools"] = [structured_tool]
|
|
||||||
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
|
||||||
|
|
||||||
uses_file_api = _contains_file_id_reference(params.get("messages", []))
|
uses_file_api = _contains_file_id_reference(params.get("messages", []))
|
||||||
|
betas: list[str] = []
|
||||||
|
use_native_structured_output = False
|
||||||
|
|
||||||
|
if uses_file_api:
|
||||||
|
betas.append(ANTHROPIC_FILES_API_BETA)
|
||||||
|
|
||||||
|
extra_body: dict[str, Any] | None = None
|
||||||
|
if _is_pydantic_model_class(response_model):
|
||||||
|
schema = transform_schema(response_model.model_json_schema())
|
||||||
|
if _supports_native_structured_outputs(self.model):
|
||||||
|
use_native_structured_output = True
|
||||||
|
betas.append(ANTHROPIC_STRUCTURED_OUTPUTS_BETA)
|
||||||
|
extra_body = {
|
||||||
|
"output_format": {
|
||||||
|
"type": "json_schema",
|
||||||
|
"schema": schema,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
structured_tool = {
|
||||||
|
"name": "structured_output",
|
||||||
|
"description": "Output the structured response",
|
||||||
|
"input_schema": schema,
|
||||||
|
}
|
||||||
|
params["tools"] = [structured_tool]
|
||||||
|
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if uses_file_api:
|
if betas:
|
||||||
params["betas"] = [ANTHROPIC_FILES_API_BETA]
|
params["betas"] = betas
|
||||||
response = self.client.beta.messages.create(**params)
|
response = self.client.beta.messages.create(
|
||||||
|
**params, extra_body=extra_body
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
response = self.client.messages.create(**params)
|
response = self.client.messages.create(**params)
|
||||||
|
|
||||||
@@ -593,22 +674,34 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
usage = self._extract_anthropic_token_usage(response)
|
usage = self._extract_anthropic_token_usage(response)
|
||||||
self._track_token_usage_internal(usage)
|
self._track_token_usage_internal(usage)
|
||||||
|
|
||||||
if response_model and response.content:
|
if _is_pydantic_model_class(response_model) and response.content:
|
||||||
tool_uses = [
|
if use_native_structured_output:
|
||||||
block for block in response.content if isinstance(block, ToolUseBlock)
|
for block in response.content:
|
||||||
]
|
if isinstance(block, TextBlock):
|
||||||
if tool_uses and tool_uses[0].name == "structured_output":
|
structured_json = block.text
|
||||||
structured_data = tool_uses[0].input
|
self._emit_call_completed_event(
|
||||||
structured_json = json.dumps(structured_data)
|
response=structured_json,
|
||||||
self._emit_call_completed_event(
|
call_type=LLMCallType.LLM_CALL,
|
||||||
response=structured_json,
|
from_task=from_task,
|
||||||
call_type=LLMCallType.LLM_CALL,
|
from_agent=from_agent,
|
||||||
from_task=from_task,
|
messages=params["messages"],
|
||||||
from_agent=from_agent,
|
)
|
||||||
messages=params["messages"],
|
return structured_json
|
||||||
)
|
else:
|
||||||
|
for block in response.content:
|
||||||
return structured_json
|
if (
|
||||||
|
isinstance(block, ToolUseBlock)
|
||||||
|
and block.name == "structured_output"
|
||||||
|
):
|
||||||
|
structured_json = json.dumps(block.input)
|
||||||
|
self._emit_call_completed_event(
|
||||||
|
response=structured_json,
|
||||||
|
call_type=LLMCallType.LLM_CALL,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
messages=params["messages"],
|
||||||
|
)
|
||||||
|
return structured_json
|
||||||
|
|
||||||
# Check if Claude wants to use tools
|
# Check if Claude wants to use tools
|
||||||
if response.content:
|
if response.content:
|
||||||
@@ -678,17 +771,31 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
from_task: Any | None = None,
|
from_task: Any | None = None,
|
||||||
from_agent: Any | None = None,
|
from_agent: Any | None = None,
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
) -> str:
|
) -> str | Any:
|
||||||
"""Handle streaming message completion."""
|
"""Handle streaming message completion."""
|
||||||
if response_model:
|
betas: list[str] = []
|
||||||
structured_tool = {
|
use_native_structured_output = False
|
||||||
"name": "structured_output",
|
|
||||||
"description": "Returns structured data according to the schema",
|
|
||||||
"input_schema": response_model.model_json_schema(),
|
|
||||||
}
|
|
||||||
|
|
||||||
params["tools"] = [structured_tool]
|
extra_body: dict[str, Any] | None = None
|
||||||
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
if _is_pydantic_model_class(response_model):
|
||||||
|
schema = transform_schema(response_model.model_json_schema())
|
||||||
|
if _supports_native_structured_outputs(self.model):
|
||||||
|
use_native_structured_output = True
|
||||||
|
betas.append(ANTHROPIC_STRUCTURED_OUTPUTS_BETA)
|
||||||
|
extra_body = {
|
||||||
|
"output_format": {
|
||||||
|
"type": "json_schema",
|
||||||
|
"schema": schema,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
structured_tool = {
|
||||||
|
"name": "structured_output",
|
||||||
|
"description": "Output the structured response",
|
||||||
|
"input_schema": schema,
|
||||||
|
}
|
||||||
|
params["tools"] = [structured_tool]
|
||||||
|
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
||||||
|
|
||||||
full_response = ""
|
full_response = ""
|
||||||
|
|
||||||
@@ -696,15 +803,22 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
# (the SDK sets it internally)
|
# (the SDK sets it internally)
|
||||||
stream_params = {k: v for k, v in params.items() if k != "stream"}
|
stream_params = {k: v for k, v in params.items() if k != "stream"}
|
||||||
|
|
||||||
|
if betas:
|
||||||
|
stream_params["betas"] = betas
|
||||||
|
|
||||||
current_tool_calls: dict[int, dict[str, Any]] = {}
|
current_tool_calls: dict[int, dict[str, Any]] = {}
|
||||||
|
|
||||||
# Make streaming API call
|
stream_context = (
|
||||||
with self.client.messages.stream(**stream_params) as stream:
|
self.client.beta.messages.stream(**stream_params, extra_body=extra_body)
|
||||||
|
if betas
|
||||||
|
else self.client.messages.stream(**stream_params)
|
||||||
|
)
|
||||||
|
with stream_context as stream:
|
||||||
response_id = None
|
response_id = None
|
||||||
for event in stream:
|
for event in stream:
|
||||||
if hasattr(event, "message") and hasattr(event.message, "id"):
|
if hasattr(event, "message") and hasattr(event.message, "id"):
|
||||||
response_id = event.message.id
|
response_id = event.message.id
|
||||||
|
|
||||||
if hasattr(event, "delta") and hasattr(event.delta, "text"):
|
if hasattr(event, "delta") and hasattr(event.delta, "text"):
|
||||||
text_delta = event.delta.text
|
text_delta = event.delta.text
|
||||||
full_response += text_delta
|
full_response += text_delta
|
||||||
@@ -712,7 +826,7 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
chunk=text_delta,
|
chunk=text_delta,
|
||||||
from_task=from_task,
|
from_task=from_task,
|
||||||
from_agent=from_agent,
|
from_agent=from_agent,
|
||||||
response_id=response_id
|
response_id=response_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if event.type == "content_block_start":
|
if event.type == "content_block_start":
|
||||||
@@ -739,7 +853,7 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
"index": block_index,
|
"index": block_index,
|
||||||
},
|
},
|
||||||
call_type=LLMCallType.TOOL_CALL,
|
call_type=LLMCallType.TOOL_CALL,
|
||||||
response_id=response_id
|
response_id=response_id,
|
||||||
)
|
)
|
||||||
elif event.type == "content_block_delta":
|
elif event.type == "content_block_delta":
|
||||||
if event.delta.type == "input_json_delta":
|
if event.delta.type == "input_json_delta":
|
||||||
@@ -763,10 +877,10 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
"index": block_index,
|
"index": block_index,
|
||||||
},
|
},
|
||||||
call_type=LLMCallType.TOOL_CALL,
|
call_type=LLMCallType.TOOL_CALL,
|
||||||
response_id=response_id
|
response_id=response_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
final_message: Message = stream.get_final_message()
|
final_message = stream.get_final_message()
|
||||||
|
|
||||||
thinking_blocks: list[ThinkingBlock] = []
|
thinking_blocks: list[ThinkingBlock] = []
|
||||||
if final_message.content:
|
if final_message.content:
|
||||||
@@ -781,25 +895,30 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
usage = self._extract_anthropic_token_usage(final_message)
|
usage = self._extract_anthropic_token_usage(final_message)
|
||||||
self._track_token_usage_internal(usage)
|
self._track_token_usage_internal(usage)
|
||||||
|
|
||||||
if response_model and final_message.content:
|
if _is_pydantic_model_class(response_model):
|
||||||
tool_uses = [
|
if use_native_structured_output:
|
||||||
block
|
|
||||||
for block in final_message.content
|
|
||||||
if isinstance(block, ToolUseBlock)
|
|
||||||
]
|
|
||||||
if tool_uses and tool_uses[0].name == "structured_output":
|
|
||||||
structured_data = tool_uses[0].input
|
|
||||||
structured_json = json.dumps(structured_data)
|
|
||||||
|
|
||||||
self._emit_call_completed_event(
|
self._emit_call_completed_event(
|
||||||
response=structured_json,
|
response=full_response,
|
||||||
call_type=LLMCallType.LLM_CALL,
|
call_type=LLMCallType.LLM_CALL,
|
||||||
from_task=from_task,
|
from_task=from_task,
|
||||||
from_agent=from_agent,
|
from_agent=from_agent,
|
||||||
messages=params["messages"],
|
messages=params["messages"],
|
||||||
)
|
)
|
||||||
|
return full_response
|
||||||
return structured_json
|
for block in final_message.content:
|
||||||
|
if (
|
||||||
|
isinstance(block, ToolUseBlock)
|
||||||
|
and block.name == "structured_output"
|
||||||
|
):
|
||||||
|
structured_json = json.dumps(block.input)
|
||||||
|
self._emit_call_completed_event(
|
||||||
|
response=structured_json,
|
||||||
|
call_type=LLMCallType.LLM_CALL,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
messages=params["messages"],
|
||||||
|
)
|
||||||
|
return structured_json
|
||||||
|
|
||||||
if final_message.content:
|
if final_message.content:
|
||||||
tool_uses = [
|
tool_uses = [
|
||||||
@@ -809,11 +928,9 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
]
|
]
|
||||||
|
|
||||||
if tool_uses:
|
if tool_uses:
|
||||||
# If no available_functions, return tool calls for executor to handle
|
|
||||||
if not available_functions:
|
if not available_functions:
|
||||||
return list(tool_uses)
|
return list(tool_uses)
|
||||||
|
|
||||||
# Handle tool use conversation flow internally
|
|
||||||
return self._handle_tool_use_conversation(
|
return self._handle_tool_use_conversation(
|
||||||
final_message,
|
final_message,
|
||||||
tool_uses,
|
tool_uses,
|
||||||
@@ -823,10 +940,8 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
from_agent,
|
from_agent,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply stop words to full response
|
|
||||||
full_response = self._apply_stop_words(full_response)
|
full_response = self._apply_stop_words(full_response)
|
||||||
|
|
||||||
# Emit completion event and return full response
|
|
||||||
self._emit_call_completed_event(
|
self._emit_call_completed_event(
|
||||||
response=full_response,
|
response=full_response,
|
||||||
call_type=LLMCallType.LLM_CALL,
|
call_type=LLMCallType.LLM_CALL,
|
||||||
@@ -884,7 +999,7 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
|
|
||||||
def _handle_tool_use_conversation(
|
def _handle_tool_use_conversation(
|
||||||
self,
|
self,
|
||||||
initial_response: Message,
|
initial_response: Message | BetaMessage,
|
||||||
tool_uses: list[ToolUseBlock],
|
tool_uses: list[ToolUseBlock],
|
||||||
params: dict[str, Any],
|
params: dict[str, Any],
|
||||||
available_functions: dict[str, Any],
|
available_functions: dict[str, Any],
|
||||||
@@ -1002,22 +1117,40 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
) -> str | Any:
|
) -> str | Any:
|
||||||
"""Handle non-streaming async message completion."""
|
"""Handle non-streaming async message completion."""
|
||||||
if response_model:
|
|
||||||
structured_tool = {
|
|
||||||
"name": "structured_output",
|
|
||||||
"description": "Returns structured data according to the schema",
|
|
||||||
"input_schema": response_model.model_json_schema(),
|
|
||||||
}
|
|
||||||
|
|
||||||
params["tools"] = [structured_tool]
|
|
||||||
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
|
||||||
|
|
||||||
uses_file_api = _contains_file_id_reference(params.get("messages", []))
|
uses_file_api = _contains_file_id_reference(params.get("messages", []))
|
||||||
|
betas: list[str] = []
|
||||||
|
use_native_structured_output = False
|
||||||
|
|
||||||
|
if uses_file_api:
|
||||||
|
betas.append(ANTHROPIC_FILES_API_BETA)
|
||||||
|
|
||||||
|
extra_body: dict[str, Any] | None = None
|
||||||
|
if _is_pydantic_model_class(response_model):
|
||||||
|
schema = transform_schema(response_model.model_json_schema())
|
||||||
|
if _supports_native_structured_outputs(self.model):
|
||||||
|
use_native_structured_output = True
|
||||||
|
betas.append(ANTHROPIC_STRUCTURED_OUTPUTS_BETA)
|
||||||
|
extra_body = {
|
||||||
|
"output_format": {
|
||||||
|
"type": "json_schema",
|
||||||
|
"schema": schema,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
structured_tool = {
|
||||||
|
"name": "structured_output",
|
||||||
|
"description": "Output the structured response",
|
||||||
|
"input_schema": schema,
|
||||||
|
}
|
||||||
|
params["tools"] = [structured_tool]
|
||||||
|
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if uses_file_api:
|
if betas:
|
||||||
params["betas"] = [ANTHROPIC_FILES_API_BETA]
|
params["betas"] = betas
|
||||||
response = await self.async_client.beta.messages.create(**params)
|
response = await self.async_client.beta.messages.create(
|
||||||
|
**params, extra_body=extra_body
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
response = await self.async_client.messages.create(**params)
|
response = await self.async_client.messages.create(**params)
|
||||||
|
|
||||||
@@ -1030,23 +1163,34 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
usage = self._extract_anthropic_token_usage(response)
|
usage = self._extract_anthropic_token_usage(response)
|
||||||
self._track_token_usage_internal(usage)
|
self._track_token_usage_internal(usage)
|
||||||
|
|
||||||
if response_model and response.content:
|
if _is_pydantic_model_class(response_model) and response.content:
|
||||||
tool_uses = [
|
if use_native_structured_output:
|
||||||
block for block in response.content if isinstance(block, ToolUseBlock)
|
for block in response.content:
|
||||||
]
|
if isinstance(block, TextBlock):
|
||||||
if tool_uses and tool_uses[0].name == "structured_output":
|
structured_json = block.text
|
||||||
structured_data = tool_uses[0].input
|
self._emit_call_completed_event(
|
||||||
structured_json = json.dumps(structured_data)
|
response=structured_json,
|
||||||
|
call_type=LLMCallType.LLM_CALL,
|
||||||
self._emit_call_completed_event(
|
from_task=from_task,
|
||||||
response=structured_json,
|
from_agent=from_agent,
|
||||||
call_type=LLMCallType.LLM_CALL,
|
messages=params["messages"],
|
||||||
from_task=from_task,
|
)
|
||||||
from_agent=from_agent,
|
return structured_json
|
||||||
messages=params["messages"],
|
else:
|
||||||
)
|
for block in response.content:
|
||||||
|
if (
|
||||||
return structured_json
|
isinstance(block, ToolUseBlock)
|
||||||
|
and block.name == "structured_output"
|
||||||
|
):
|
||||||
|
structured_json = json.dumps(block.input)
|
||||||
|
self._emit_call_completed_event(
|
||||||
|
response=structured_json,
|
||||||
|
call_type=LLMCallType.LLM_CALL,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
messages=params["messages"],
|
||||||
|
)
|
||||||
|
return structured_json
|
||||||
|
|
||||||
if response.content:
|
if response.content:
|
||||||
tool_uses = [
|
tool_uses = [
|
||||||
@@ -1102,25 +1246,49 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
from_task: Any | None = None,
|
from_task: Any | None = None,
|
||||||
from_agent: Any | None = None,
|
from_agent: Any | None = None,
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
) -> str:
|
) -> str | Any:
|
||||||
"""Handle async streaming message completion."""
|
"""Handle async streaming message completion."""
|
||||||
if response_model:
|
betas: list[str] = []
|
||||||
structured_tool = {
|
use_native_structured_output = False
|
||||||
"name": "structured_output",
|
|
||||||
"description": "Returns structured data according to the schema",
|
|
||||||
"input_schema": response_model.model_json_schema(),
|
|
||||||
}
|
|
||||||
|
|
||||||
params["tools"] = [structured_tool]
|
extra_body: dict[str, Any] | None = None
|
||||||
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
if _is_pydantic_model_class(response_model):
|
||||||
|
schema = transform_schema(response_model.model_json_schema())
|
||||||
|
if _supports_native_structured_outputs(self.model):
|
||||||
|
use_native_structured_output = True
|
||||||
|
betas.append(ANTHROPIC_STRUCTURED_OUTPUTS_BETA)
|
||||||
|
extra_body = {
|
||||||
|
"output_format": {
|
||||||
|
"type": "json_schema",
|
||||||
|
"schema": schema,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
structured_tool = {
|
||||||
|
"name": "structured_output",
|
||||||
|
"description": "Output the structured response",
|
||||||
|
"input_schema": schema,
|
||||||
|
}
|
||||||
|
params["tools"] = [structured_tool]
|
||||||
|
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
||||||
|
|
||||||
full_response = ""
|
full_response = ""
|
||||||
|
|
||||||
stream_params = {k: v for k, v in params.items() if k != "stream"}
|
stream_params = {k: v for k, v in params.items() if k != "stream"}
|
||||||
|
|
||||||
|
if betas:
|
||||||
|
stream_params["betas"] = betas
|
||||||
|
|
||||||
current_tool_calls: dict[int, dict[str, Any]] = {}
|
current_tool_calls: dict[int, dict[str, Any]] = {}
|
||||||
|
|
||||||
async with self.async_client.messages.stream(**stream_params) as stream:
|
stream_context = (
|
||||||
|
self.async_client.beta.messages.stream(
|
||||||
|
**stream_params, extra_body=extra_body
|
||||||
|
)
|
||||||
|
if betas
|
||||||
|
else self.async_client.messages.stream(**stream_params)
|
||||||
|
)
|
||||||
|
async with stream_context as stream:
|
||||||
response_id = None
|
response_id = None
|
||||||
async for event in stream:
|
async for event in stream:
|
||||||
if hasattr(event, "message") and hasattr(event.message, "id"):
|
if hasattr(event, "message") and hasattr(event.message, "id"):
|
||||||
@@ -1133,7 +1301,7 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
chunk=text_delta,
|
chunk=text_delta,
|
||||||
from_task=from_task,
|
from_task=from_task,
|
||||||
from_agent=from_agent,
|
from_agent=from_agent,
|
||||||
response_id=response_id
|
response_id=response_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if event.type == "content_block_start":
|
if event.type == "content_block_start":
|
||||||
@@ -1160,7 +1328,7 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
"index": block_index,
|
"index": block_index,
|
||||||
},
|
},
|
||||||
call_type=LLMCallType.TOOL_CALL,
|
call_type=LLMCallType.TOOL_CALL,
|
||||||
response_id=response_id
|
response_id=response_id,
|
||||||
)
|
)
|
||||||
elif event.type == "content_block_delta":
|
elif event.type == "content_block_delta":
|
||||||
if event.delta.type == "input_json_delta":
|
if event.delta.type == "input_json_delta":
|
||||||
@@ -1184,33 +1352,38 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
"index": block_index,
|
"index": block_index,
|
||||||
},
|
},
|
||||||
call_type=LLMCallType.TOOL_CALL,
|
call_type=LLMCallType.TOOL_CALL,
|
||||||
response_id=response_id
|
response_id=response_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
final_message: Message = await stream.get_final_message()
|
final_message = await stream.get_final_message()
|
||||||
|
|
||||||
usage = self._extract_anthropic_token_usage(final_message)
|
usage = self._extract_anthropic_token_usage(final_message)
|
||||||
self._track_token_usage_internal(usage)
|
self._track_token_usage_internal(usage)
|
||||||
|
|
||||||
if response_model and final_message.content:
|
if _is_pydantic_model_class(response_model):
|
||||||
tool_uses = [
|
if use_native_structured_output:
|
||||||
block
|
|
||||||
for block in final_message.content
|
|
||||||
if isinstance(block, ToolUseBlock)
|
|
||||||
]
|
|
||||||
if tool_uses and tool_uses[0].name == "structured_output":
|
|
||||||
structured_data = tool_uses[0].input
|
|
||||||
structured_json = json.dumps(structured_data)
|
|
||||||
|
|
||||||
self._emit_call_completed_event(
|
self._emit_call_completed_event(
|
||||||
response=structured_json,
|
response=full_response,
|
||||||
call_type=LLMCallType.LLM_CALL,
|
call_type=LLMCallType.LLM_CALL,
|
||||||
from_task=from_task,
|
from_task=from_task,
|
||||||
from_agent=from_agent,
|
from_agent=from_agent,
|
||||||
messages=params["messages"],
|
messages=params["messages"],
|
||||||
)
|
)
|
||||||
|
return full_response
|
||||||
return structured_json
|
for block in final_message.content:
|
||||||
|
if (
|
||||||
|
isinstance(block, ToolUseBlock)
|
||||||
|
and block.name == "structured_output"
|
||||||
|
):
|
||||||
|
structured_json = json.dumps(block.input)
|
||||||
|
self._emit_call_completed_event(
|
||||||
|
response=structured_json,
|
||||||
|
call_type=LLMCallType.LLM_CALL,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
messages=params["messages"],
|
||||||
|
)
|
||||||
|
return structured_json
|
||||||
|
|
||||||
if final_message.content:
|
if final_message.content:
|
||||||
tool_uses = [
|
tool_uses = [
|
||||||
@@ -1220,7 +1393,6 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
]
|
]
|
||||||
|
|
||||||
if tool_uses:
|
if tool_uses:
|
||||||
# If no available_functions, return tool calls for executor to handle
|
|
||||||
if not available_functions:
|
if not available_functions:
|
||||||
return list(tool_uses)
|
return list(tool_uses)
|
||||||
|
|
||||||
@@ -1247,7 +1419,7 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
|
|
||||||
async def _ahandle_tool_use_conversation(
|
async def _ahandle_tool_use_conversation(
|
||||||
self,
|
self,
|
||||||
initial_response: Message,
|
initial_response: Message | BetaMessage,
|
||||||
tool_uses: list[ToolUseBlock],
|
tool_uses: list[ToolUseBlock],
|
||||||
params: dict[str, Any],
|
params: dict[str, Any],
|
||||||
available_functions: dict[str, Any],
|
available_functions: dict[str, Any],
|
||||||
@@ -1356,7 +1528,9 @@ class AnthropicCompletion(BaseLLM):
|
|||||||
return int(200000 * CONTEXT_WINDOW_USAGE_RATIO)
|
return int(200000 * CONTEXT_WINDOW_USAGE_RATIO)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_anthropic_token_usage(response: Message) -> dict[str, Any]:
|
def _extract_anthropic_token_usage(
|
||||||
|
response: Message | BetaMessage,
|
||||||
|
) -> dict[str, Any]:
|
||||||
"""Extract token usage from Anthropic response."""
|
"""Extract token usage from Anthropic response."""
|
||||||
if hasattr(response, "usage") and response.usage:
|
if hasattr(response, "usage") and response.usage:
|
||||||
usage = response.usage
|
usage = response.usage
|
||||||
|
|||||||
@@ -92,6 +92,7 @@ class AzureCompletion(BaseLLM):
|
|||||||
stop: list[str] | None = None,
|
stop: list[str] | None = None,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
interceptor: BaseInterceptor[Any, Any] | None = None,
|
interceptor: BaseInterceptor[Any, Any] | None = None,
|
||||||
|
response_format: type[BaseModel] | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
"""Initialize Azure AI Inference chat completion client.
|
"""Initialize Azure AI Inference chat completion client.
|
||||||
@@ -111,6 +112,9 @@ class AzureCompletion(BaseLLM):
|
|||||||
stop: Stop sequences
|
stop: Stop sequences
|
||||||
stream: Enable streaming responses
|
stream: Enable streaming responses
|
||||||
interceptor: HTTP interceptor (not yet supported for Azure).
|
interceptor: HTTP interceptor (not yet supported for Azure).
|
||||||
|
response_format: Pydantic model for structured output. Used as default when
|
||||||
|
response_model is not passed to call()/acall() methods.
|
||||||
|
Only works with OpenAI models deployed on Azure.
|
||||||
**kwargs: Additional parameters
|
**kwargs: Additional parameters
|
||||||
"""
|
"""
|
||||||
if interceptor is not None:
|
if interceptor is not None:
|
||||||
@@ -165,6 +169,7 @@ class AzureCompletion(BaseLLM):
|
|||||||
self.presence_penalty = presence_penalty
|
self.presence_penalty = presence_penalty
|
||||||
self.max_tokens = max_tokens
|
self.max_tokens = max_tokens
|
||||||
self.stream = stream
|
self.stream = stream
|
||||||
|
self.response_format = response_format
|
||||||
|
|
||||||
self.is_openai_model = any(
|
self.is_openai_model = any(
|
||||||
prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
|
prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
|
||||||
@@ -298,6 +303,7 @@ class AzureCompletion(BaseLLM):
|
|||||||
from_task=from_task,
|
from_task=from_task,
|
||||||
from_agent=from_agent,
|
from_agent=from_agent,
|
||||||
)
|
)
|
||||||
|
effective_response_model = response_model or self.response_format
|
||||||
|
|
||||||
# Format messages for Azure
|
# Format messages for Azure
|
||||||
formatted_messages = self._format_messages_for_azure(messages)
|
formatted_messages = self._format_messages_for_azure(messages)
|
||||||
@@ -307,7 +313,7 @@ class AzureCompletion(BaseLLM):
|
|||||||
|
|
||||||
# Prepare completion parameters
|
# Prepare completion parameters
|
||||||
completion_params = self._prepare_completion_params(
|
completion_params = self._prepare_completion_params(
|
||||||
formatted_messages, tools, response_model
|
formatted_messages, tools, effective_response_model
|
||||||
)
|
)
|
||||||
|
|
||||||
# Handle streaming vs non-streaming
|
# Handle streaming vs non-streaming
|
||||||
@@ -317,7 +323,7 @@ class AzureCompletion(BaseLLM):
|
|||||||
available_functions,
|
available_functions,
|
||||||
from_task,
|
from_task,
|
||||||
from_agent,
|
from_agent,
|
||||||
response_model,
|
effective_response_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._handle_completion(
|
return self._handle_completion(
|
||||||
@@ -325,7 +331,7 @@ class AzureCompletion(BaseLLM):
|
|||||||
available_functions,
|
available_functions,
|
||||||
from_task,
|
from_task,
|
||||||
from_agent,
|
from_agent,
|
||||||
response_model,
|
effective_response_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -364,11 +370,12 @@ class AzureCompletion(BaseLLM):
|
|||||||
from_task=from_task,
|
from_task=from_task,
|
||||||
from_agent=from_agent,
|
from_agent=from_agent,
|
||||||
)
|
)
|
||||||
|
effective_response_model = response_model or self.response_format
|
||||||
|
|
||||||
formatted_messages = self._format_messages_for_azure(messages)
|
formatted_messages = self._format_messages_for_azure(messages)
|
||||||
|
|
||||||
completion_params = self._prepare_completion_params(
|
completion_params = self._prepare_completion_params(
|
||||||
formatted_messages, tools, response_model
|
formatted_messages, tools, effective_response_model
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.stream:
|
if self.stream:
|
||||||
@@ -377,7 +384,7 @@ class AzureCompletion(BaseLLM):
|
|||||||
available_functions,
|
available_functions,
|
||||||
from_task,
|
from_task,
|
||||||
from_agent,
|
from_agent,
|
||||||
response_model,
|
effective_response_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
return await self._ahandle_completion(
|
return await self._ahandle_completion(
|
||||||
@@ -385,7 +392,7 @@ class AzureCompletion(BaseLLM):
|
|||||||
available_functions,
|
available_functions,
|
||||||
from_task,
|
from_task,
|
||||||
from_agent,
|
from_agent,
|
||||||
response_model,
|
effective_response_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -726,7 +733,7 @@ class AzureCompletion(BaseLLM):
|
|||||||
"""
|
"""
|
||||||
if update.choices:
|
if update.choices:
|
||||||
choice = update.choices[0]
|
choice = update.choices[0]
|
||||||
response_id = update.id if hasattr(update,"id") else None
|
response_id = update.id if hasattr(update, "id") else None
|
||||||
if choice.delta and choice.delta.content:
|
if choice.delta and choice.delta.content:
|
||||||
content_delta = choice.delta.content
|
content_delta = choice.delta.content
|
||||||
full_response += content_delta
|
full_response += content_delta
|
||||||
@@ -734,7 +741,7 @@ class AzureCompletion(BaseLLM):
|
|||||||
chunk=content_delta,
|
chunk=content_delta,
|
||||||
from_task=from_task,
|
from_task=from_task,
|
||||||
from_agent=from_agent,
|
from_agent=from_agent,
|
||||||
response_id=response_id
|
response_id=response_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if choice.delta and choice.delta.tool_calls:
|
if choice.delta and choice.delta.tool_calls:
|
||||||
@@ -769,7 +776,7 @@ class AzureCompletion(BaseLLM):
|
|||||||
"index": idx,
|
"index": idx,
|
||||||
},
|
},
|
||||||
call_type=LLMCallType.TOOL_CALL,
|
call_type=LLMCallType.TOOL_CALL,
|
||||||
response_id=response_id
|
response_id=response_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return full_response
|
return full_response
|
||||||
|
|||||||
@@ -172,6 +172,7 @@ class BedrockCompletion(BaseLLM):
|
|||||||
additional_model_request_fields: dict[str, Any] | None = None,
|
additional_model_request_fields: dict[str, Any] | None = None,
|
||||||
additional_model_response_field_paths: list[str] | None = None,
|
additional_model_response_field_paths: list[str] | None = None,
|
||||||
interceptor: BaseInterceptor[Any, Any] | None = None,
|
interceptor: BaseInterceptor[Any, Any] | None = None,
|
||||||
|
response_format: type[BaseModel] | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize AWS Bedrock completion client.
|
"""Initialize AWS Bedrock completion client.
|
||||||
@@ -192,6 +193,8 @@ class BedrockCompletion(BaseLLM):
|
|||||||
additional_model_request_fields: Model-specific request parameters
|
additional_model_request_fields: Model-specific request parameters
|
||||||
additional_model_response_field_paths: Custom response field paths
|
additional_model_response_field_paths: Custom response field paths
|
||||||
interceptor: HTTP interceptor (not yet supported for Bedrock).
|
interceptor: HTTP interceptor (not yet supported for Bedrock).
|
||||||
|
response_format: Pydantic model for structured output. Used as default when
|
||||||
|
response_model is not passed to call()/acall() methods.
|
||||||
**kwargs: Additional parameters
|
**kwargs: Additional parameters
|
||||||
"""
|
"""
|
||||||
if interceptor is not None:
|
if interceptor is not None:
|
||||||
@@ -248,6 +251,7 @@ class BedrockCompletion(BaseLLM):
|
|||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
self.stream = stream
|
self.stream = stream
|
||||||
self.stop_sequences = stop_sequences
|
self.stop_sequences = stop_sequences
|
||||||
|
self.response_format = response_format
|
||||||
|
|
||||||
# Store advanced features (optional)
|
# Store advanced features (optional)
|
||||||
self.guardrail_config = guardrail_config
|
self.guardrail_config = guardrail_config
|
||||||
@@ -299,6 +303,8 @@ class BedrockCompletion(BaseLLM):
|
|||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
) -> str | Any:
|
) -> str | Any:
|
||||||
"""Call AWS Bedrock Converse API."""
|
"""Call AWS Bedrock Converse API."""
|
||||||
|
effective_response_model = response_model or self.response_format
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Emit call started event
|
# Emit call started event
|
||||||
self._emit_call_started_event(
|
self._emit_call_started_event(
|
||||||
@@ -375,6 +381,7 @@ class BedrockCompletion(BaseLLM):
|
|||||||
available_functions,
|
available_functions,
|
||||||
from_task,
|
from_task,
|
||||||
from_agent,
|
from_agent,
|
||||||
|
effective_response_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._handle_converse(
|
return self._handle_converse(
|
||||||
@@ -383,6 +390,7 @@ class BedrockCompletion(BaseLLM):
|
|||||||
available_functions,
|
available_functions,
|
||||||
from_task,
|
from_task,
|
||||||
from_agent,
|
from_agent,
|
||||||
|
effective_response_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -425,6 +433,8 @@ class BedrockCompletion(BaseLLM):
|
|||||||
NotImplementedError: If aiobotocore is not installed.
|
NotImplementedError: If aiobotocore is not installed.
|
||||||
LLMContextLengthExceededError: If context window is exceeded.
|
LLMContextLengthExceededError: If context window is exceeded.
|
||||||
"""
|
"""
|
||||||
|
effective_response_model = response_model or self.response_format
|
||||||
|
|
||||||
if not AIOBOTOCORE_AVAILABLE:
|
if not AIOBOTOCORE_AVAILABLE:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Async support for AWS Bedrock requires aiobotocore. "
|
"Async support for AWS Bedrock requires aiobotocore. "
|
||||||
@@ -494,11 +504,21 @@ class BedrockCompletion(BaseLLM):
|
|||||||
|
|
||||||
if self.stream:
|
if self.stream:
|
||||||
return await self._ahandle_streaming_converse(
|
return await self._ahandle_streaming_converse(
|
||||||
formatted_messages, body, available_functions, from_task, from_agent
|
formatted_messages,
|
||||||
|
body,
|
||||||
|
available_functions,
|
||||||
|
from_task,
|
||||||
|
from_agent,
|
||||||
|
effective_response_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
return await self._ahandle_converse(
|
return await self._ahandle_converse(
|
||||||
formatted_messages, body, available_functions, from_task, from_agent
|
formatted_messages,
|
||||||
|
body,
|
||||||
|
available_functions,
|
||||||
|
from_task,
|
||||||
|
from_agent,
|
||||||
|
effective_response_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -520,10 +540,29 @@ class BedrockCompletion(BaseLLM):
|
|||||||
available_functions: Mapping[str, Any] | None = None,
|
available_functions: Mapping[str, Any] | None = None,
|
||||||
from_task: Any | None = None,
|
from_task: Any | None = None,
|
||||||
from_agent: Any | None = None,
|
from_agent: Any | None = None,
|
||||||
) -> str:
|
response_model: type[BaseModel] | None = None,
|
||||||
|
) -> str | Any:
|
||||||
"""Handle non-streaming converse API call following AWS best practices."""
|
"""Handle non-streaming converse API call following AWS best practices."""
|
||||||
|
if response_model:
|
||||||
|
structured_tool: ConverseToolTypeDef = {
|
||||||
|
"toolSpec": {
|
||||||
|
"name": "structured_output",
|
||||||
|
"description": "Returns structured data according to the schema",
|
||||||
|
"inputSchema": {"json": response_model.model_json_schema()},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
body["toolConfig"] = cast(
|
||||||
|
"ToolConfigurationTypeDef",
|
||||||
|
cast(
|
||||||
|
object,
|
||||||
|
{
|
||||||
|
"tools": [structured_tool],
|
||||||
|
"toolChoice": {"tool": {"name": "structured_output"}},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Validate messages format before API call
|
|
||||||
if not messages:
|
if not messages:
|
||||||
raise ValueError("Messages cannot be empty")
|
raise ValueError("Messages cannot be empty")
|
||||||
|
|
||||||
@@ -571,6 +610,21 @@ class BedrockCompletion(BaseLLM):
|
|||||||
|
|
||||||
# If there are tool uses but no available_functions, return them for the executor to handle
|
# If there are tool uses but no available_functions, return them for the executor to handle
|
||||||
tool_uses = [block["toolUse"] for block in content if "toolUse" in block]
|
tool_uses = [block["toolUse"] for block in content if "toolUse" in block]
|
||||||
|
|
||||||
|
if response_model and tool_uses:
|
||||||
|
for tool_use in tool_uses:
|
||||||
|
if tool_use.get("name") == "structured_output":
|
||||||
|
structured_data = tool_use.get("input", {})
|
||||||
|
result = response_model.model_validate(structured_data)
|
||||||
|
self._emit_call_completed_event(
|
||||||
|
response=result.model_dump_json(),
|
||||||
|
call_type=LLMCallType.LLM_CALL,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
if tool_uses and not available_functions:
|
if tool_uses and not available_functions:
|
||||||
self._emit_call_completed_event(
|
self._emit_call_completed_event(
|
||||||
response=tool_uses,
|
response=tool_uses,
|
||||||
@@ -717,8 +771,28 @@ class BedrockCompletion(BaseLLM):
|
|||||||
available_functions: dict[str, Any] | None = None,
|
available_functions: dict[str, Any] | None = None,
|
||||||
from_task: Any | None = None,
|
from_task: Any | None = None,
|
||||||
from_agent: Any | None = None,
|
from_agent: Any | None = None,
|
||||||
|
response_model: type[BaseModel] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Handle streaming converse API call with comprehensive event handling."""
|
"""Handle streaming converse API call with comprehensive event handling."""
|
||||||
|
if response_model:
|
||||||
|
structured_tool: ConverseToolTypeDef = {
|
||||||
|
"toolSpec": {
|
||||||
|
"name": "structured_output",
|
||||||
|
"description": "Returns structured data according to the schema",
|
||||||
|
"inputSchema": {"json": response_model.model_json_schema()},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
body["toolConfig"] = cast(
|
||||||
|
"ToolConfigurationTypeDef",
|
||||||
|
cast(
|
||||||
|
object,
|
||||||
|
{
|
||||||
|
"tools": [structured_tool],
|
||||||
|
"toolChoice": {"tool": {"name": "structured_output"}},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
full_response = ""
|
full_response = ""
|
||||||
current_tool_use: dict[str, Any] | None = None
|
current_tool_use: dict[str, Any] | None = None
|
||||||
tool_use_id: str | None = None
|
tool_use_id: str | None = None
|
||||||
@@ -805,7 +879,7 @@ class BedrockCompletion(BaseLLM):
|
|||||||
"index": tool_use_index,
|
"index": tool_use_index,
|
||||||
},
|
},
|
||||||
call_type=LLMCallType.TOOL_CALL,
|
call_type=LLMCallType.TOOL_CALL,
|
||||||
response_id=response_id
|
response_id=response_id,
|
||||||
)
|
)
|
||||||
elif "contentBlockStop" in event:
|
elif "contentBlockStop" in event:
|
||||||
logging.debug("Content block stopped in stream")
|
logging.debug("Content block stopped in stream")
|
||||||
@@ -929,8 +1003,28 @@ class BedrockCompletion(BaseLLM):
|
|||||||
available_functions: Mapping[str, Any] | None = None,
|
available_functions: Mapping[str, Any] | None = None,
|
||||||
from_task: Any | None = None,
|
from_task: Any | None = None,
|
||||||
from_agent: Any | None = None,
|
from_agent: Any | None = None,
|
||||||
) -> str:
|
response_model: type[BaseModel] | None = None,
|
||||||
|
) -> str | Any:
|
||||||
"""Handle async non-streaming converse API call."""
|
"""Handle async non-streaming converse API call."""
|
||||||
|
if response_model:
|
||||||
|
structured_tool: ConverseToolTypeDef = {
|
||||||
|
"toolSpec": {
|
||||||
|
"name": "structured_output",
|
||||||
|
"description": "Returns structured data according to the schema",
|
||||||
|
"inputSchema": {"json": response_model.model_json_schema()},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
body["toolConfig"] = cast(
|
||||||
|
"ToolConfigurationTypeDef",
|
||||||
|
cast(
|
||||||
|
object,
|
||||||
|
{
|
||||||
|
"tools": [structured_tool],
|
||||||
|
"toolChoice": {"tool": {"name": "structured_output"}},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not messages:
|
if not messages:
|
||||||
raise ValueError("Messages cannot be empty")
|
raise ValueError("Messages cannot be empty")
|
||||||
@@ -976,6 +1070,21 @@ class BedrockCompletion(BaseLLM):
|
|||||||
|
|
||||||
# If there are tool uses but no available_functions, return them for the executor to handle
|
# If there are tool uses but no available_functions, return them for the executor to handle
|
||||||
tool_uses = [block["toolUse"] for block in content if "toolUse" in block]
|
tool_uses = [block["toolUse"] for block in content if "toolUse" in block]
|
||||||
|
|
||||||
|
if response_model and tool_uses:
|
||||||
|
for tool_use in tool_uses:
|
||||||
|
if tool_use.get("name") == "structured_output":
|
||||||
|
structured_data = tool_use.get("input", {})
|
||||||
|
result = response_model.model_validate(structured_data)
|
||||||
|
self._emit_call_completed_event(
|
||||||
|
response=result.model_dump_json(),
|
||||||
|
call_type=LLMCallType.LLM_CALL,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
if tool_uses and not available_functions:
|
if tool_uses and not available_functions:
|
||||||
self._emit_call_completed_event(
|
self._emit_call_completed_event(
|
||||||
response=tool_uses,
|
response=tool_uses,
|
||||||
@@ -1106,8 +1215,28 @@ class BedrockCompletion(BaseLLM):
|
|||||||
available_functions: dict[str, Any] | None = None,
|
available_functions: dict[str, Any] | None = None,
|
||||||
from_task: Any | None = None,
|
from_task: Any | None = None,
|
||||||
from_agent: Any | None = None,
|
from_agent: Any | None = None,
|
||||||
|
response_model: type[BaseModel] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Handle async streaming converse API call."""
|
"""Handle async streaming converse API call."""
|
||||||
|
if response_model:
|
||||||
|
structured_tool: ConverseToolTypeDef = {
|
||||||
|
"toolSpec": {
|
||||||
|
"name": "structured_output",
|
||||||
|
"description": "Returns structured data according to the schema",
|
||||||
|
"inputSchema": {"json": response_model.model_json_schema()},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
body["toolConfig"] = cast(
|
||||||
|
"ToolConfigurationTypeDef",
|
||||||
|
cast(
|
||||||
|
object,
|
||||||
|
{
|
||||||
|
"tools": [structured_tool],
|
||||||
|
"toolChoice": {"tool": {"name": "structured_output"}},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
full_response = ""
|
full_response = ""
|
||||||
current_tool_use: dict[str, Any] | None = None
|
current_tool_use: dict[str, Any] | None = None
|
||||||
tool_use_id: str | None = None
|
tool_use_id: str | None = None
|
||||||
@@ -1174,7 +1303,7 @@ class BedrockCompletion(BaseLLM):
|
|||||||
chunk=text_chunk,
|
chunk=text_chunk,
|
||||||
from_task=from_task,
|
from_task=from_task,
|
||||||
from_agent=from_agent,
|
from_agent=from_agent,
|
||||||
response_id=response_id
|
response_id=response_id,
|
||||||
)
|
)
|
||||||
elif "toolUse" in delta and current_tool_use:
|
elif "toolUse" in delta and current_tool_use:
|
||||||
tool_input = delta["toolUse"].get("input", "")
|
tool_input = delta["toolUse"].get("input", "")
|
||||||
|
|||||||
@@ -56,6 +56,7 @@ class GeminiCompletion(BaseLLM):
|
|||||||
client_params: dict[str, Any] | None = None,
|
client_params: dict[str, Any] | None = None,
|
||||||
interceptor: BaseInterceptor[Any, Any] | None = None,
|
interceptor: BaseInterceptor[Any, Any] | None = None,
|
||||||
use_vertexai: bool | None = None,
|
use_vertexai: bool | None = None,
|
||||||
|
response_format: type[BaseModel] | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
"""Initialize Google Gemini chat completion client.
|
"""Initialize Google Gemini chat completion client.
|
||||||
@@ -86,6 +87,8 @@ class GeminiCompletion(BaseLLM):
|
|||||||
- None (default): Check GOOGLE_GENAI_USE_VERTEXAI env var
|
- None (default): Check GOOGLE_GENAI_USE_VERTEXAI env var
|
||||||
When using Vertex AI with API key (Express mode), http_options with
|
When using Vertex AI with API key (Express mode), http_options with
|
||||||
api_version="v1" is automatically configured.
|
api_version="v1" is automatically configured.
|
||||||
|
response_format: Pydantic model for structured output. Used as default when
|
||||||
|
response_model is not passed to call()/acall() methods.
|
||||||
**kwargs: Additional parameters
|
**kwargs: Additional parameters
|
||||||
"""
|
"""
|
||||||
if interceptor is not None:
|
if interceptor is not None:
|
||||||
@@ -121,6 +124,7 @@ class GeminiCompletion(BaseLLM):
|
|||||||
self.safety_settings = safety_settings or {}
|
self.safety_settings = safety_settings or {}
|
||||||
self.stop_sequences = stop_sequences or []
|
self.stop_sequences = stop_sequences or []
|
||||||
self.tools: list[dict[str, Any]] | None = None
|
self.tools: list[dict[str, Any]] | None = None
|
||||||
|
self.response_format = response_format
|
||||||
|
|
||||||
# Model-specific settings
|
# Model-specific settings
|
||||||
version_match = re.search(r"gemini-(\d+(?:\.\d+)?)", model.lower())
|
version_match = re.search(r"gemini-(\d+(?:\.\d+)?)", model.lower())
|
||||||
@@ -292,6 +296,7 @@ class GeminiCompletion(BaseLLM):
|
|||||||
from_agent=from_agent,
|
from_agent=from_agent,
|
||||||
)
|
)
|
||||||
self.tools = tools
|
self.tools = tools
|
||||||
|
effective_response_model = response_model or self.response_format
|
||||||
|
|
||||||
formatted_content, system_instruction = self._format_messages_for_gemini(
|
formatted_content, system_instruction = self._format_messages_for_gemini(
|
||||||
messages
|
messages
|
||||||
@@ -303,7 +308,7 @@ class GeminiCompletion(BaseLLM):
|
|||||||
raise ValueError("LLM call blocked by before_llm_call hook")
|
raise ValueError("LLM call blocked by before_llm_call hook")
|
||||||
|
|
||||||
config = self._prepare_generation_config(
|
config = self._prepare_generation_config(
|
||||||
system_instruction, tools, response_model
|
system_instruction, tools, effective_response_model
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.stream:
|
if self.stream:
|
||||||
@@ -313,7 +318,7 @@ class GeminiCompletion(BaseLLM):
|
|||||||
available_functions,
|
available_functions,
|
||||||
from_task,
|
from_task,
|
||||||
from_agent,
|
from_agent,
|
||||||
response_model,
|
effective_response_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._handle_completion(
|
return self._handle_completion(
|
||||||
@@ -322,7 +327,7 @@ class GeminiCompletion(BaseLLM):
|
|||||||
available_functions,
|
available_functions,
|
||||||
from_task,
|
from_task,
|
||||||
from_agent,
|
from_agent,
|
||||||
response_model,
|
effective_response_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
except APIError as e:
|
except APIError as e:
|
||||||
@@ -374,13 +379,14 @@ class GeminiCompletion(BaseLLM):
|
|||||||
from_agent=from_agent,
|
from_agent=from_agent,
|
||||||
)
|
)
|
||||||
self.tools = tools
|
self.tools = tools
|
||||||
|
effective_response_model = response_model or self.response_format
|
||||||
|
|
||||||
formatted_content, system_instruction = self._format_messages_for_gemini(
|
formatted_content, system_instruction = self._format_messages_for_gemini(
|
||||||
messages
|
messages
|
||||||
)
|
)
|
||||||
|
|
||||||
config = self._prepare_generation_config(
|
config = self._prepare_generation_config(
|
||||||
system_instruction, tools, response_model
|
system_instruction, tools, effective_response_model
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.stream:
|
if self.stream:
|
||||||
@@ -390,7 +396,7 @@ class GeminiCompletion(BaseLLM):
|
|||||||
available_functions,
|
available_functions,
|
||||||
from_task,
|
from_task,
|
||||||
from_agent,
|
from_agent,
|
||||||
response_model,
|
effective_response_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
return await self._ahandle_completion(
|
return await self._ahandle_completion(
|
||||||
@@ -399,7 +405,7 @@ class GeminiCompletion(BaseLLM):
|
|||||||
available_functions,
|
available_functions,
|
||||||
from_task,
|
from_task,
|
||||||
from_agent,
|
from_agent,
|
||||||
response_model,
|
effective_response_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
except APIError as e:
|
except APIError as e:
|
||||||
@@ -570,10 +576,10 @@ class GeminiCompletion(BaseLLM):
|
|||||||
types.Content(role="user", parts=[function_response_part])
|
types.Content(role="user", parts=[function_response_part])
|
||||||
)
|
)
|
||||||
elif role == "assistant" and message.get("tool_calls"):
|
elif role == "assistant" and message.get("tool_calls"):
|
||||||
parts: list[types.Part] = []
|
tool_parts: list[types.Part] = []
|
||||||
|
|
||||||
if text_content:
|
if text_content:
|
||||||
parts.append(types.Part.from_text(text=text_content))
|
tool_parts.append(types.Part.from_text(text=text_content))
|
||||||
|
|
||||||
tool_calls: list[dict[str, Any]] = message.get("tool_calls") or []
|
tool_calls: list[dict[str, Any]] = message.get("tool_calls") or []
|
||||||
for tool_call in tool_calls:
|
for tool_call in tool_calls:
|
||||||
@@ -592,11 +598,11 @@ class GeminiCompletion(BaseLLM):
|
|||||||
else:
|
else:
|
||||||
func_args = func_args_raw
|
func_args = func_args_raw
|
||||||
|
|
||||||
parts.append(
|
tool_parts.append(
|
||||||
types.Part.from_function_call(name=func_name, args=func_args)
|
types.Part.from_function_call(name=func_name, args=func_args)
|
||||||
)
|
)
|
||||||
|
|
||||||
contents.append(types.Content(role="model", parts=parts))
|
contents.append(types.Content(role="model", parts=tool_parts))
|
||||||
else:
|
else:
|
||||||
# Convert role for Gemini (assistant -> model)
|
# Convert role for Gemini (assistant -> model)
|
||||||
gemini_role = "model" if role == "assistant" else "user"
|
gemini_role = "model" if role == "assistant" else "user"
|
||||||
@@ -790,7 +796,7 @@ class GeminiCompletion(BaseLLM):
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple of (updated full_response, updated function_calls, updated usage_data)
|
Tuple of (updated full_response, updated function_calls, updated usage_data)
|
||||||
"""
|
"""
|
||||||
response_id=chunk.response_id if hasattr(chunk,"response_id") else None
|
response_id = chunk.response_id if hasattr(chunk, "response_id") else None
|
||||||
if chunk.usage_metadata:
|
if chunk.usage_metadata:
|
||||||
usage_data = self._extract_token_usage(chunk)
|
usage_data = self._extract_token_usage(chunk)
|
||||||
|
|
||||||
@@ -800,7 +806,7 @@ class GeminiCompletion(BaseLLM):
|
|||||||
chunk=chunk.text,
|
chunk=chunk.text,
|
||||||
from_task=from_task,
|
from_task=from_task,
|
||||||
from_agent=from_agent,
|
from_agent=from_agent,
|
||||||
response_id=response_id
|
response_id=response_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if chunk.candidates:
|
if chunk.candidates:
|
||||||
@@ -837,7 +843,7 @@ class GeminiCompletion(BaseLLM):
|
|||||||
"index": call_index,
|
"index": call_index,
|
||||||
},
|
},
|
||||||
call_type=LLMCallType.TOOL_CALL,
|
call_type=LLMCallType.TOOL_CALL,
|
||||||
response_id=response_id
|
response_id=response_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return full_response, function_calls, usage_data
|
return full_response, function_calls, usage_data
|
||||||
@@ -972,7 +978,7 @@ class GeminiCompletion(BaseLLM):
|
|||||||
from_task: Any | None = None,
|
from_task: Any | None = None,
|
||||||
from_agent: Any | None = None,
|
from_agent: Any | None = None,
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
) -> str:
|
) -> str | Any:
|
||||||
"""Handle streaming content generation."""
|
"""Handle streaming content generation."""
|
||||||
full_response = ""
|
full_response = ""
|
||||||
function_calls: dict[int, dict[str, Any]] = {}
|
function_calls: dict[int, dict[str, Any]] = {}
|
||||||
@@ -1050,7 +1056,7 @@ class GeminiCompletion(BaseLLM):
|
|||||||
from_task: Any | None = None,
|
from_task: Any | None = None,
|
||||||
from_agent: Any | None = None,
|
from_agent: Any | None = None,
|
||||||
response_model: type[BaseModel] | None = None,
|
response_model: type[BaseModel] | None = None,
|
||||||
) -> str:
|
) -> str | Any:
|
||||||
"""Handle async streaming content generation."""
|
"""Handle async streaming content generation."""
|
||||||
full_response = ""
|
full_response = ""
|
||||||
function_calls: dict[int, dict[str, Any]] = {}
|
function_calls: dict[int, dict[str, Any]] = {}
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||||
GoogleGenerativeAiEmbeddingFunction,
|
GoogleGenerativeAiEmbeddingFunction,
|
||||||
GoogleVertexEmbeddingFunction,
|
|
||||||
)
|
)
|
||||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||||
HuggingFaceEmbeddingFunction,
|
HuggingFaceEmbeddingFunction,
|
||||||
@@ -52,6 +51,9 @@ if TYPE_CHECKING:
|
|||||||
from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
|
from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
|
||||||
from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec
|
from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec
|
||||||
from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec
|
from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec
|
||||||
|
from crewai.rag.embeddings.providers.google.genai_vertex_embedding import (
|
||||||
|
GoogleGenAIVertexEmbeddingFunction,
|
||||||
|
)
|
||||||
from crewai.rag.embeddings.providers.google.types import (
|
from crewai.rag.embeddings.providers.google.types import (
|
||||||
GenerativeAiProviderSpec,
|
GenerativeAiProviderSpec,
|
||||||
VertexAIProviderSpec,
|
VertexAIProviderSpec,
|
||||||
@@ -163,7 +165,7 @@ def build_embedder_from_dict(spec: OpenAIProviderSpec) -> OpenAIEmbeddingFunctio
|
|||||||
@overload
|
@overload
|
||||||
def build_embedder_from_dict(
|
def build_embedder_from_dict(
|
||||||
spec: VertexAIProviderSpec,
|
spec: VertexAIProviderSpec,
|
||||||
) -> GoogleVertexEmbeddingFunction: ...
|
) -> GoogleGenAIVertexEmbeddingFunction: ...
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
@@ -296,7 +298,9 @@ def build_embedder(spec: OpenAIProviderSpec) -> OpenAIEmbeddingFunction: ...
|
|||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def build_embedder(spec: VertexAIProviderSpec) -> GoogleVertexEmbeddingFunction: ...
|
def build_embedder(
|
||||||
|
spec: VertexAIProviderSpec,
|
||||||
|
) -> GoogleGenAIVertexEmbeddingFunction: ...
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
"""Google embedding providers."""
|
"""Google embedding providers."""
|
||||||
|
|
||||||
|
from crewai.rag.embeddings.providers.google.genai_vertex_embedding import (
|
||||||
|
GoogleGenAIVertexEmbeddingFunction,
|
||||||
|
)
|
||||||
from crewai.rag.embeddings.providers.google.generative_ai import (
|
from crewai.rag.embeddings.providers.google.generative_ai import (
|
||||||
GenerativeAiProvider,
|
GenerativeAiProvider,
|
||||||
)
|
)
|
||||||
@@ -18,6 +21,7 @@ __all__ = [
|
|||||||
"GenerativeAiProvider",
|
"GenerativeAiProvider",
|
||||||
"GenerativeAiProviderConfig",
|
"GenerativeAiProviderConfig",
|
||||||
"GenerativeAiProviderSpec",
|
"GenerativeAiProviderSpec",
|
||||||
|
"GoogleGenAIVertexEmbeddingFunction",
|
||||||
"VertexAIProvider",
|
"VertexAIProvider",
|
||||||
"VertexAIProviderConfig",
|
"VertexAIProviderConfig",
|
||||||
"VertexAIProviderSpec",
|
"VertexAIProviderSpec",
|
||||||
|
|||||||
@@ -0,0 +1,237 @@
|
|||||||
|
"""Google Vertex AI embedding function implementation.
|
||||||
|
|
||||||
|
This module supports both the new google-genai SDK and the deprecated
|
||||||
|
vertexai.language_models module for backwards compatibility.
|
||||||
|
|
||||||
|
The deprecated vertexai.language_models module will be removed after June 24, 2026.
|
||||||
|
Migration guide: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, ClassVar, cast
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
|
||||||
|
from typing_extensions import Unpack
|
||||||
|
|
||||||
|
from crewai.rag.embeddings.providers.google.types import VertexAIProviderConfig
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleGenAIVertexEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||||
|
"""Embedding function for Google Vertex AI with dual SDK support.
|
||||||
|
|
||||||
|
This class supports both:
|
||||||
|
- Legacy models (textembedding-gecko*) using the deprecated vertexai.language_models SDK
|
||||||
|
- New models (gemini-embedding-*, text-embedding-*) using the google-genai SDK
|
||||||
|
|
||||||
|
The SDK is automatically selected based on the model name. Legacy models will
|
||||||
|
emit a deprecation warning.
|
||||||
|
|
||||||
|
Supports two authentication modes:
|
||||||
|
1. Vertex AI backend: Set project_id and location/region (uses Application Default Credentials)
|
||||||
|
2. API key: Set api_key for direct API access
|
||||||
|
|
||||||
|
Example:
|
||||||
|
# Using legacy model (will emit deprecation warning)
|
||||||
|
embedder = GoogleGenAIVertexEmbeddingFunction(
|
||||||
|
project_id="my-project",
|
||||||
|
region="us-central1",
|
||||||
|
model_name="textembedding-gecko"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Using new model with google-genai SDK
|
||||||
|
embedder = GoogleGenAIVertexEmbeddingFunction(
|
||||||
|
project_id="my-project",
|
||||||
|
location="us-central1",
|
||||||
|
model_name="gemini-embedding-001"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Using API key (new SDK only)
|
||||||
|
embedder = GoogleGenAIVertexEmbeddingFunction(
|
||||||
|
api_key="your-api-key",
|
||||||
|
model_name="gemini-embedding-001"
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Models that use the legacy vertexai.language_models SDK
|
||||||
|
LEGACY_MODELS: ClassVar[set[str]] = {
|
||||||
|
"textembedding-gecko",
|
||||||
|
"textembedding-gecko@001",
|
||||||
|
"textembedding-gecko@002",
|
||||||
|
"textembedding-gecko@003",
|
||||||
|
"textembedding-gecko@latest",
|
||||||
|
"textembedding-gecko-multilingual",
|
||||||
|
"textembedding-gecko-multilingual@001",
|
||||||
|
"textembedding-gecko-multilingual@latest",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Models that use the new google-genai SDK
|
||||||
|
GENAI_MODELS: ClassVar[set[str]] = {
|
||||||
|
"gemini-embedding-001",
|
||||||
|
"text-embedding-005",
|
||||||
|
"text-multilingual-embedding-002",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, **kwargs: Unpack[VertexAIProviderConfig]) -> None:
|
||||||
|
"""Initialize Google Vertex AI embedding function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: Configuration parameters including:
|
||||||
|
- model_name: Model to use for embeddings (default: "textembedding-gecko")
|
||||||
|
- api_key: Optional API key for authentication (new SDK only)
|
||||||
|
- project_id: GCP project ID (for Vertex AI backend)
|
||||||
|
- location: GCP region (default: "us-central1")
|
||||||
|
- region: Deprecated alias for location
|
||||||
|
- task_type: Task type for embeddings (default: "RETRIEVAL_DOCUMENT", new SDK only)
|
||||||
|
- output_dimensionality: Optional output embedding dimension (new SDK only)
|
||||||
|
"""
|
||||||
|
# Handle deprecated 'region' parameter (only if it has a value)
|
||||||
|
region_value = kwargs.pop("region", None) # type: ignore[typeddict-item]
|
||||||
|
if region_value is not None:
|
||||||
|
warnings.warn(
|
||||||
|
"The 'region' parameter is deprecated, use 'location' instead. "
|
||||||
|
"See: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
if "location" not in kwargs or kwargs.get("location") is None:
|
||||||
|
kwargs["location"] = region_value # type: ignore[typeddict-unknown-key]
|
||||||
|
|
||||||
|
self._config = kwargs
|
||||||
|
self._model_name = str(kwargs.get("model_name", "textembedding-gecko"))
|
||||||
|
self._use_legacy = self._is_legacy_model(self._model_name)
|
||||||
|
|
||||||
|
if self._use_legacy:
|
||||||
|
self._init_legacy_client(**kwargs)
|
||||||
|
else:
|
||||||
|
self._init_genai_client(**kwargs)
|
||||||
|
|
||||||
|
def _is_legacy_model(self, model_name: str) -> bool:
|
||||||
|
"""Check if the model uses the legacy SDK."""
|
||||||
|
return model_name in self.LEGACY_MODELS or model_name.startswith(
|
||||||
|
"textembedding-gecko"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _init_legacy_client(self, **kwargs: Any) -> None:
|
||||||
|
"""Initialize using the deprecated vertexai.language_models SDK."""
|
||||||
|
warnings.warn(
|
||||||
|
f"Model '{self._model_name}' uses the deprecated vertexai.language_models SDK "
|
||||||
|
"which will be removed after June 24, 2026. Consider migrating to newer models "
|
||||||
|
"like 'gemini-embedding-001'. "
|
||||||
|
"See: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import vertexai
|
||||||
|
from vertexai.language_models import TextEmbeddingModel
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"vertexai is required for legacy embedding models (textembedding-gecko*). "
|
||||||
|
"Install it with: pip install google-cloud-aiplatform"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
project_id = kwargs.get("project_id")
|
||||||
|
location = str(kwargs.get("location", "us-central1"))
|
||||||
|
|
||||||
|
if not project_id:
|
||||||
|
raise ValueError(
|
||||||
|
"project_id is required for legacy models. "
|
||||||
|
"For API key authentication, use newer models like 'gemini-embedding-001'."
|
||||||
|
)
|
||||||
|
|
||||||
|
vertexai.init(project=str(project_id), location=location)
|
||||||
|
self._legacy_model = TextEmbeddingModel.from_pretrained(self._model_name)
|
||||||
|
|
||||||
|
def _init_genai_client(self, **kwargs: Any) -> None:
|
||||||
|
"""Initialize using the new google-genai SDK."""
|
||||||
|
try:
|
||||||
|
from google import genai
|
||||||
|
from google.genai.types import EmbedContentConfig
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"google-genai is required for Google Gen AI embeddings. "
|
||||||
|
"Install it with: uv add 'crewai[google-genai]'"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
self._genai = genai
|
||||||
|
self._EmbedContentConfig = EmbedContentConfig
|
||||||
|
self._task_type = kwargs.get("task_type", "RETRIEVAL_DOCUMENT")
|
||||||
|
self._output_dimensionality = kwargs.get("output_dimensionality")
|
||||||
|
|
||||||
|
# Initialize client based on authentication mode
|
||||||
|
api_key = kwargs.get("api_key")
|
||||||
|
project_id = kwargs.get("project_id")
|
||||||
|
location: str = str(kwargs.get("location", "us-central1"))
|
||||||
|
|
||||||
|
if api_key:
|
||||||
|
self._client = genai.Client(api_key=api_key)
|
||||||
|
elif project_id:
|
||||||
|
self._client = genai.Client(
|
||||||
|
vertexai=True,
|
||||||
|
project=str(project_id),
|
||||||
|
location=location,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Either 'api_key' (for API key authentication) or 'project_id' "
|
||||||
|
"(for Vertex AI backend with ADC) must be provided."
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def name() -> str:
|
||||||
|
"""Return the name of the embedding function for ChromaDB compatibility."""
|
||||||
|
return "google-vertex"
|
||||||
|
|
||||||
|
def __call__(self, input: Documents) -> Embeddings:
|
||||||
|
"""Generate embeddings for input documents.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input: List of documents to embed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of embedding vectors.
|
||||||
|
"""
|
||||||
|
if isinstance(input, str):
|
||||||
|
input = [input]
|
||||||
|
|
||||||
|
if self._use_legacy:
|
||||||
|
return self._call_legacy(input)
|
||||||
|
return self._call_genai(input)
|
||||||
|
|
||||||
|
def _call_legacy(self, input: list[str]) -> Embeddings:
|
||||||
|
"""Generate embeddings using the legacy SDK."""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
embeddings_list = []
|
||||||
|
for text in input:
|
||||||
|
embedding_result = self._legacy_model.get_embeddings([text])
|
||||||
|
embeddings_list.append(
|
||||||
|
np.array(embedding_result[0].values, dtype=np.float32)
|
||||||
|
)
|
||||||
|
|
||||||
|
return cast(Embeddings, embeddings_list)
|
||||||
|
|
||||||
|
def _call_genai(self, input: list[str]) -> Embeddings:
|
||||||
|
"""Generate embeddings using the new google-genai SDK."""
|
||||||
|
# Build config for embed_content
|
||||||
|
config_kwargs: dict[str, Any] = {
|
||||||
|
"task_type": self._task_type,
|
||||||
|
}
|
||||||
|
if self._output_dimensionality is not None:
|
||||||
|
config_kwargs["output_dimensionality"] = self._output_dimensionality
|
||||||
|
|
||||||
|
config = self._EmbedContentConfig(**config_kwargs)
|
||||||
|
|
||||||
|
# Call the embedding API
|
||||||
|
response = self._client.models.embed_content(
|
||||||
|
model=self._model_name,
|
||||||
|
contents=input, # type: ignore[arg-type]
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract embeddings from response
|
||||||
|
if response.embeddings is None:
|
||||||
|
raise ValueError("No embeddings returned from the API")
|
||||||
|
embeddings = [emb.values for emb in response.embeddings]
|
||||||
|
return cast(Embeddings, embeddings)
|
||||||
@@ -34,12 +34,47 @@ class GenerativeAiProviderSpec(TypedDict):
|
|||||||
|
|
||||||
|
|
||||||
class VertexAIProviderConfig(TypedDict, total=False):
|
class VertexAIProviderConfig(TypedDict, total=False):
|
||||||
"""Configuration for Vertex AI provider."""
|
"""Configuration for Vertex AI provider with dual SDK support.
|
||||||
|
|
||||||
|
Supports both legacy models (textembedding-gecko*) using the deprecated
|
||||||
|
vertexai.language_models SDK and new models using google-genai SDK.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
api_key: Google API key (optional if using project_id with ADC). Only for new SDK models.
|
||||||
|
model_name: Embedding model name (default: "textembedding-gecko").
|
||||||
|
Legacy models: textembedding-gecko, textembedding-gecko@001, etc.
|
||||||
|
New models: gemini-embedding-001, text-embedding-005, text-multilingual-embedding-002
|
||||||
|
project_id: GCP project ID (required for Vertex AI backend and legacy models).
|
||||||
|
location: GCP region/location (default: "us-central1").
|
||||||
|
region: Deprecated alias for location (kept for backwards compatibility).
|
||||||
|
task_type: Task type for embeddings (default: "RETRIEVAL_DOCUMENT"). Only for new SDK models.
|
||||||
|
output_dimensionality: Output embedding dimension (optional). Only for new SDK models.
|
||||||
|
"""
|
||||||
|
|
||||||
api_key: str
|
api_key: str
|
||||||
model_name: Annotated[str, "textembedding-gecko"]
|
model_name: Annotated[
|
||||||
project_id: Annotated[str, "cloud-large-language-models"]
|
Literal[
|
||||||
region: Annotated[str, "us-central1"]
|
# Legacy models (deprecated vertexai.language_models SDK)
|
||||||
|
"textembedding-gecko",
|
||||||
|
"textembedding-gecko@001",
|
||||||
|
"textembedding-gecko@002",
|
||||||
|
"textembedding-gecko@003",
|
||||||
|
"textembedding-gecko@latest",
|
||||||
|
"textembedding-gecko-multilingual",
|
||||||
|
"textembedding-gecko-multilingual@001",
|
||||||
|
"textembedding-gecko-multilingual@latest",
|
||||||
|
# New models (google-genai SDK)
|
||||||
|
"gemini-embedding-001",
|
||||||
|
"text-embedding-005",
|
||||||
|
"text-multilingual-embedding-002",
|
||||||
|
],
|
||||||
|
"textembedding-gecko",
|
||||||
|
]
|
||||||
|
project_id: str
|
||||||
|
location: Annotated[str, "us-central1"]
|
||||||
|
region: Annotated[str, "us-central1"] # Deprecated alias for location
|
||||||
|
task_type: Annotated[str, "RETRIEVAL_DOCUMENT"]
|
||||||
|
output_dimensionality: int
|
||||||
|
|
||||||
|
|
||||||
class VertexAIProviderSpec(TypedDict, total=False):
|
class VertexAIProviderSpec(TypedDict, total=False):
|
||||||
|
|||||||
@@ -1,46 +1,126 @@
|
|||||||
"""Google Vertex AI embeddings provider."""
|
"""Google Vertex AI embeddings provider.
|
||||||
|
|
||||||
|
This module supports both the new google-genai SDK and the deprecated
|
||||||
|
vertexai.language_models module for backwards compatibility.
|
||||||
|
|
||||||
|
The SDK is automatically selected based on the model name:
|
||||||
|
- Legacy models (textembedding-gecko*) use vertexai.language_models (deprecated)
|
||||||
|
- New models (gemini-embedding-*, text-embedding-*) use google-genai
|
||||||
|
|
||||||
|
Migration guide: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
|
||||||
GoogleVertexEmbeddingFunction,
|
|
||||||
)
|
|
||||||
from pydantic import AliasChoices, Field
|
from pydantic import AliasChoices, Field
|
||||||
|
|
||||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||||
|
from crewai.rag.embeddings.providers.google.genai_vertex_embedding import (
|
||||||
|
GoogleGenAIVertexEmbeddingFunction,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class VertexAIProvider(BaseEmbeddingsProvider[GoogleVertexEmbeddingFunction]):
|
class VertexAIProvider(BaseEmbeddingsProvider[GoogleGenAIVertexEmbeddingFunction]):
|
||||||
"""Google Vertex AI embeddings provider."""
|
"""Google Vertex AI embeddings provider with dual SDK support.
|
||||||
|
|
||||||
embedding_callable: type[GoogleVertexEmbeddingFunction] = Field(
|
Supports both legacy models (textembedding-gecko*) using the deprecated
|
||||||
default=GoogleVertexEmbeddingFunction,
|
vertexai.language_models SDK and new models (gemini-embedding-*, text-embedding-*)
|
||||||
description="Vertex AI embedding function class",
|
using the google-genai SDK.
|
||||||
|
|
||||||
|
The SDK is automatically selected based on the model name. Legacy models will
|
||||||
|
emit a deprecation warning.
|
||||||
|
|
||||||
|
Authentication modes:
|
||||||
|
1. Vertex AI backend: Set project_id and location/region (uses Application Default Credentials)
|
||||||
|
2. API key: Set api_key for direct API access (new SDK models only)
|
||||||
|
|
||||||
|
Example:
|
||||||
|
# Legacy model (backwards compatible, will emit deprecation warning)
|
||||||
|
provider = VertexAIProvider(
|
||||||
|
project_id="my-project",
|
||||||
|
region="us-central1", # or location="us-central1"
|
||||||
|
model_name="textembedding-gecko"
|
||||||
|
)
|
||||||
|
|
||||||
|
# New model with Vertex AI backend
|
||||||
|
provider = VertexAIProvider(
|
||||||
|
project_id="my-project",
|
||||||
|
location="us-central1",
|
||||||
|
model_name="gemini-embedding-001"
|
||||||
|
)
|
||||||
|
|
||||||
|
# New model with API key
|
||||||
|
provider = VertexAIProvider(
|
||||||
|
api_key="your-api-key",
|
||||||
|
model_name="gemini-embedding-001"
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
embedding_callable: type[GoogleGenAIVertexEmbeddingFunction] = Field(
|
||||||
|
default=GoogleGenAIVertexEmbeddingFunction,
|
||||||
|
description="Google Vertex AI embedding function class",
|
||||||
)
|
)
|
||||||
model_name: str = Field(
|
model_name: str = Field(
|
||||||
default="textembedding-gecko",
|
default="textembedding-gecko",
|
||||||
description="Model name to use for embeddings",
|
description=(
|
||||||
|
"Model name to use for embeddings. Legacy models (textembedding-gecko*) "
|
||||||
|
"use the deprecated SDK. New models (gemini-embedding-001, text-embedding-005) "
|
||||||
|
"use the google-genai SDK."
|
||||||
|
),
|
||||||
validation_alias=AliasChoices(
|
validation_alias=AliasChoices(
|
||||||
"EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME",
|
"EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME",
|
||||||
"GOOGLE_VERTEX_MODEL_NAME",
|
"GOOGLE_VERTEX_MODEL_NAME",
|
||||||
"model",
|
"model",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
api_key: str = Field(
|
api_key: str | None = Field(
|
||||||
description="Google API key",
|
default=None,
|
||||||
|
description="Google API key (optional if using project_id with Application Default Credentials)",
|
||||||
validation_alias=AliasChoices(
|
validation_alias=AliasChoices(
|
||||||
"EMBEDDINGS_GOOGLE_CLOUD_API_KEY", "GOOGLE_CLOUD_API_KEY"
|
"EMBEDDINGS_GOOGLE_CLOUD_API_KEY",
|
||||||
|
"GOOGLE_CLOUD_API_KEY",
|
||||||
|
"GOOGLE_API_KEY",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
project_id: str = Field(
|
project_id: str | None = Field(
|
||||||
default="cloud-large-language-models",
|
default=None,
|
||||||
description="GCP project ID",
|
description="GCP project ID (required for Vertex AI backend and legacy models)",
|
||||||
validation_alias=AliasChoices(
|
validation_alias=AliasChoices(
|
||||||
"EMBEDDINGS_GOOGLE_CLOUD_PROJECT", "GOOGLE_CLOUD_PROJECT"
|
"EMBEDDINGS_GOOGLE_CLOUD_PROJECT",
|
||||||
|
"GOOGLE_CLOUD_PROJECT",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
region: str = Field(
|
location: str = Field(
|
||||||
default="us-central1",
|
default="us-central1",
|
||||||
description="GCP region",
|
description="GCP region/location",
|
||||||
validation_alias=AliasChoices(
|
validation_alias=AliasChoices(
|
||||||
"EMBEDDINGS_GOOGLE_CLOUD_REGION", "GOOGLE_CLOUD_REGION"
|
"EMBEDDINGS_GOOGLE_CLOUD_LOCATION",
|
||||||
|
"EMBEDDINGS_GOOGLE_CLOUD_REGION",
|
||||||
|
"GOOGLE_CLOUD_LOCATION",
|
||||||
|
"GOOGLE_CLOUD_REGION",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
region: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Deprecated: Use 'location' instead. GCP region (kept for backwards compatibility)",
|
||||||
|
validation_alias=AliasChoices(
|
||||||
|
"EMBEDDINGS_GOOGLE_VERTEX_REGION",
|
||||||
|
"GOOGLE_VERTEX_REGION",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
task_type: str = Field(
|
||||||
|
default="RETRIEVAL_DOCUMENT",
|
||||||
|
description="Task type for embeddings (e.g., RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY). Only used with new SDK models.",
|
||||||
|
validation_alias=AliasChoices(
|
||||||
|
"EMBEDDINGS_GOOGLE_VERTEX_TASK_TYPE",
|
||||||
|
"GOOGLE_VERTEX_TASK_TYPE",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
output_dimensionality: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Output embedding dimensionality (optional). Only used with new SDK models.",
|
||||||
|
validation_alias=AliasChoices(
|
||||||
|
"EMBEDDINGS_GOOGLE_VERTEX_OUTPUT_DIMENSIONALITY",
|
||||||
|
"GOOGLE_VERTEX_OUTPUT_DIMENSIONALITY",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -173,6 +173,13 @@ class Telemetry:
|
|||||||
|
|
||||||
self._original_handlers: dict[int, Any] = {}
|
self._original_handlers: dict[int, Any] = {}
|
||||||
|
|
||||||
|
if threading.current_thread() is not threading.main_thread():
|
||||||
|
logger.debug(
|
||||||
|
"CrewAI telemetry: Skipping signal handler registration "
|
||||||
|
"(not running in main thread)."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
self._register_signal_handler(signal.SIGTERM, SigTermEvent, shutdown=True)
|
self._register_signal_handler(signal.SIGTERM, SigTermEvent, shutdown=True)
|
||||||
self._register_signal_handler(signal.SIGINT, SigIntEvent, shutdown=True)
|
self._register_signal_handler(signal.SIGINT, SigIntEvent, shutdown=True)
|
||||||
if hasattr(signal, "SIGHUP"):
|
if hasattr(signal, "SIGHUP"):
|
||||||
|
|||||||
@@ -5,17 +5,29 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Coroutine
|
from collections.abc import Coroutine
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
|
import logging
|
||||||
from typing import TYPE_CHECKING, TypeVar
|
from typing import TYPE_CHECKING, TypeVar
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from aiocache import Cache # type: ignore[import-untyped]
|
|
||||||
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from aiocache import Cache
|
||||||
from crewai_files import FileInput
|
from crewai_files import FileInput
|
||||||
|
|
||||||
_file_store = Cache(Cache.MEMORY, serializer=PickleSerializer())
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_file_store: Cache | None = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
from aiocache import Cache
|
||||||
|
from aiocache.serializers import PickleSerializer
|
||||||
|
|
||||||
|
_file_store = Cache(Cache.MEMORY, serializer=PickleSerializer())
|
||||||
|
except ImportError:
|
||||||
|
logger.debug(
|
||||||
|
"aiocache is not installed. File store features will be disabled. "
|
||||||
|
"Install with: uv add aiocache"
|
||||||
|
)
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
@@ -59,6 +71,8 @@ async def astore_files(
|
|||||||
files: Dictionary mapping names to file inputs.
|
files: Dictionary mapping names to file inputs.
|
||||||
ttl: Time-to-live in seconds.
|
ttl: Time-to-live in seconds.
|
||||||
"""
|
"""
|
||||||
|
if _file_store is None:
|
||||||
|
return
|
||||||
await _file_store.set(f"{_CREW_PREFIX}{execution_id}", files, ttl=ttl)
|
await _file_store.set(f"{_CREW_PREFIX}{execution_id}", files, ttl=ttl)
|
||||||
|
|
||||||
|
|
||||||
@@ -71,6 +85,8 @@ async def aget_files(execution_id: UUID) -> dict[str, FileInput] | None:
|
|||||||
Returns:
|
Returns:
|
||||||
Dictionary of files or None if not found.
|
Dictionary of files or None if not found.
|
||||||
"""
|
"""
|
||||||
|
if _file_store is None:
|
||||||
|
return None
|
||||||
result: dict[str, FileInput] | None = await _file_store.get(
|
result: dict[str, FileInput] | None = await _file_store.get(
|
||||||
f"{_CREW_PREFIX}{execution_id}"
|
f"{_CREW_PREFIX}{execution_id}"
|
||||||
)
|
)
|
||||||
@@ -83,6 +99,8 @@ async def aclear_files(execution_id: UUID) -> None:
|
|||||||
Args:
|
Args:
|
||||||
execution_id: Unique identifier for the crew execution.
|
execution_id: Unique identifier for the crew execution.
|
||||||
"""
|
"""
|
||||||
|
if _file_store is None:
|
||||||
|
return
|
||||||
await _file_store.delete(f"{_CREW_PREFIX}{execution_id}")
|
await _file_store.delete(f"{_CREW_PREFIX}{execution_id}")
|
||||||
|
|
||||||
|
|
||||||
@@ -98,6 +116,8 @@ async def astore_task_files(
|
|||||||
files: Dictionary mapping names to file inputs.
|
files: Dictionary mapping names to file inputs.
|
||||||
ttl: Time-to-live in seconds.
|
ttl: Time-to-live in seconds.
|
||||||
"""
|
"""
|
||||||
|
if _file_store is None:
|
||||||
|
return
|
||||||
await _file_store.set(f"{_TASK_PREFIX}{task_id}", files, ttl=ttl)
|
await _file_store.set(f"{_TASK_PREFIX}{task_id}", files, ttl=ttl)
|
||||||
|
|
||||||
|
|
||||||
@@ -110,6 +130,8 @@ async def aget_task_files(task_id: UUID) -> dict[str, FileInput] | None:
|
|||||||
Returns:
|
Returns:
|
||||||
Dictionary of files or None if not found.
|
Dictionary of files or None if not found.
|
||||||
"""
|
"""
|
||||||
|
if _file_store is None:
|
||||||
|
return None
|
||||||
result: dict[str, FileInput] | None = await _file_store.get(
|
result: dict[str, FileInput] | None = await _file_store.get(
|
||||||
f"{_TASK_PREFIX}{task_id}"
|
f"{_TASK_PREFIX}{task_id}"
|
||||||
)
|
)
|
||||||
@@ -122,6 +144,8 @@ async def aclear_task_files(task_id: UUID) -> None:
|
|||||||
Args:
|
Args:
|
||||||
task_id: Unique identifier for the task.
|
task_id: Unique identifier for the task.
|
||||||
"""
|
"""
|
||||||
|
if _file_store is None:
|
||||||
|
return
|
||||||
await _file_store.delete(f"{_TASK_PREFIX}{task_id}")
|
await _file_store.delete(f"{_TASK_PREFIX}{task_id}")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
interactions:
|
interactions:
|
||||||
- request:
|
- request:
|
||||||
body: '{"max_tokens":4096,"messages":[{"role":"user","content":"Say hello in French"}],"model":"claude-sonnet-4-0","stream":false,"tool_choice":{"type":"tool","name":"structured_output"},"tools":[{"name":"structured_output","description":"Returns structured data according to the schema","input_schema":{"description":"Response model for greeting test.","properties":{"greeting":{"title":"Greeting","type":"string"},"language":{"title":"Language","type":"string"}},"required":["greeting","language"],"title":"GreetingResponse","type":"object"}}]}'
|
body: '{"max_tokens":4096,"messages":[{"role":"user","content":"Say hello in French"}],"model":"claude-sonnet-4-0","stream":false,"tool_choice":{"type":"tool","name":"structured_output"},"tools":[{"name":"structured_output","description":"Output
|
||||||
|
the structured response","input_schema":{"type":"object","description":"Response
|
||||||
|
model for greeting test.","title":"GreetingResponse","properties":{"greeting":{"type":"string","title":"Greeting"},"language":{"type":"string","title":"Language"}},"additionalProperties":false,"required":["greeting","language"]}}]}'
|
||||||
headers:
|
headers:
|
||||||
User-Agent:
|
User-Agent:
|
||||||
- X-USER-AGENT-XXX
|
- X-USER-AGENT-XXX
|
||||||
@@ -13,7 +15,7 @@ interactions:
|
|||||||
connection:
|
connection:
|
||||||
- keep-alive
|
- keep-alive
|
||||||
content-length:
|
content-length:
|
||||||
- '539'
|
- '551'
|
||||||
content-type:
|
content-type:
|
||||||
- application/json
|
- application/json
|
||||||
host:
|
host:
|
||||||
@@ -29,7 +31,7 @@ interactions:
|
|||||||
x-stainless-os:
|
x-stainless-os:
|
||||||
- X-STAINLESS-OS-XXX
|
- X-STAINLESS-OS-XXX
|
||||||
x-stainless-package-version:
|
x-stainless-package-version:
|
||||||
- 0.75.0
|
- 0.76.0
|
||||||
x-stainless-retry-count:
|
x-stainless-retry-count:
|
||||||
- '0'
|
- '0'
|
||||||
x-stainless-runtime:
|
x-stainless-runtime:
|
||||||
@@ -42,7 +44,7 @@ interactions:
|
|||||||
uri: https://api.anthropic.com/v1/messages
|
uri: https://api.anthropic.com/v1/messages
|
||||||
response:
|
response:
|
||||||
body:
|
body:
|
||||||
string: '{"model":"claude-sonnet-4-20250514","id":"msg_01XjvX2nCho1knuucbwwgCpw","type":"message","role":"assistant","content":[{"type":"tool_use","id":"toolu_019rfPRSDmBb7CyCTdGMv5rK","name":"structured_output","input":{"greeting":"Bonjour","language":"French"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":432,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0},"output_tokens":53,"service_tier":"standard"}}'
|
string: '{"model":"claude-sonnet-4-20250514","id":"msg_01CKTyVmak15L5oQ36mv4sL9","type":"message","role":"assistant","content":[{"type":"tool_use","id":"toolu_0174BYmn6xiSnUwVhFD8S7EW","name":"structured_output","input":{"greeting":"Bonjour","language":"French"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":436,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0},"output_tokens":53,"service_tier":"standard"}}'
|
||||||
headers:
|
headers:
|
||||||
CF-RAY:
|
CF-RAY:
|
||||||
- CF-RAY-XXX
|
- CF-RAY-XXX
|
||||||
@@ -51,7 +53,7 @@ interactions:
|
|||||||
Content-Type:
|
Content-Type:
|
||||||
- application/json
|
- application/json
|
||||||
Date:
|
Date:
|
||||||
- Mon, 01 Dec 2025 11:19:38 GMT
|
- Mon, 26 Jan 2026 14:59:34 GMT
|
||||||
Server:
|
Server:
|
||||||
- cloudflare
|
- cloudflare
|
||||||
Transfer-Encoding:
|
Transfer-Encoding:
|
||||||
@@ -82,12 +84,10 @@ interactions:
|
|||||||
- DYNAMIC
|
- DYNAMIC
|
||||||
request-id:
|
request-id:
|
||||||
- REQUEST-ID-XXX
|
- REQUEST-ID-XXX
|
||||||
retry-after:
|
|
||||||
- '24'
|
|
||||||
strict-transport-security:
|
strict-transport-security:
|
||||||
- STS-XXX
|
- STS-XXX
|
||||||
x-envoy-upstream-service-time:
|
x-envoy-upstream-service-time:
|
||||||
- '2101'
|
- '968'
|
||||||
status:
|
status:
|
||||||
code: 200
|
code: 200
|
||||||
message: OK
|
message: OK
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -272,3 +272,100 @@ class TestEmbeddingFactory:
|
|||||||
mock_build_from_provider.assert_called_once_with(mock_provider)
|
mock_build_from_provider.assert_called_once_with(mock_provider)
|
||||||
assert result == mock_embedding_function
|
assert result == mock_embedding_function
|
||||||
mock_import.assert_not_called()
|
mock_import.assert_not_called()
|
||||||
|
|
||||||
|
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||||
|
def test_build_embedder_google_vertex_with_genai_model(self, mock_import):
|
||||||
|
"""Test routing to Google Vertex provider with new genai model."""
|
||||||
|
mock_provider_class = MagicMock()
|
||||||
|
mock_provider_instance = MagicMock()
|
||||||
|
mock_embedding_function = MagicMock()
|
||||||
|
|
||||||
|
mock_import.return_value = mock_provider_class
|
||||||
|
mock_provider_class.return_value = mock_provider_instance
|
||||||
|
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"provider": "google-vertex",
|
||||||
|
"config": {
|
||||||
|
"api_key": "test-google-api-key",
|
||||||
|
"model_name": "gemini-embedding-001",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
build_embedder(config)
|
||||||
|
|
||||||
|
mock_import.assert_called_once_with(
|
||||||
|
"crewai.rag.embeddings.providers.google.vertex.VertexAIProvider"
|
||||||
|
)
|
||||||
|
mock_provider_class.assert_called_once()
|
||||||
|
|
||||||
|
call_kwargs = mock_provider_class.call_args.kwargs
|
||||||
|
assert call_kwargs["api_key"] == "test-google-api-key"
|
||||||
|
assert call_kwargs["model_name"] == "gemini-embedding-001"
|
||||||
|
|
||||||
|
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||||
|
def test_build_embedder_google_vertex_with_legacy_model(self, mock_import):
|
||||||
|
"""Test routing to Google Vertex provider with legacy textembedding-gecko model."""
|
||||||
|
mock_provider_class = MagicMock()
|
||||||
|
mock_provider_instance = MagicMock()
|
||||||
|
mock_embedding_function = MagicMock()
|
||||||
|
|
||||||
|
mock_import.return_value = mock_provider_class
|
||||||
|
mock_provider_class.return_value = mock_provider_instance
|
||||||
|
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"provider": "google-vertex",
|
||||||
|
"config": {
|
||||||
|
"project_id": "my-gcp-project",
|
||||||
|
"region": "us-central1",
|
||||||
|
"model_name": "textembedding-gecko",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
build_embedder(config)
|
||||||
|
|
||||||
|
mock_import.assert_called_once_with(
|
||||||
|
"crewai.rag.embeddings.providers.google.vertex.VertexAIProvider"
|
||||||
|
)
|
||||||
|
mock_provider_class.assert_called_once()
|
||||||
|
|
||||||
|
call_kwargs = mock_provider_class.call_args.kwargs
|
||||||
|
assert call_kwargs["project_id"] == "my-gcp-project"
|
||||||
|
assert call_kwargs["region"] == "us-central1"
|
||||||
|
assert call_kwargs["model_name"] == "textembedding-gecko"
|
||||||
|
|
||||||
|
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||||
|
def test_build_embedder_google_vertex_with_location(self, mock_import):
|
||||||
|
"""Test routing to Google Vertex provider with location parameter."""
|
||||||
|
mock_provider_class = MagicMock()
|
||||||
|
mock_provider_instance = MagicMock()
|
||||||
|
mock_embedding_function = MagicMock()
|
||||||
|
|
||||||
|
mock_import.return_value = mock_provider_class
|
||||||
|
mock_provider_class.return_value = mock_provider_instance
|
||||||
|
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
|
||||||
|
|
||||||
|
config = {
|
||||||
|
"provider": "google-vertex",
|
||||||
|
"config": {
|
||||||
|
"project_id": "my-gcp-project",
|
||||||
|
"location": "europe-west1",
|
||||||
|
"model_name": "gemini-embedding-001",
|
||||||
|
"task_type": "RETRIEVAL_DOCUMENT",
|
||||||
|
"output_dimensionality": 768,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
build_embedder(config)
|
||||||
|
|
||||||
|
mock_import.assert_called_once_with(
|
||||||
|
"crewai.rag.embeddings.providers.google.vertex.VertexAIProvider"
|
||||||
|
)
|
||||||
|
|
||||||
|
call_kwargs = mock_provider_class.call_args.kwargs
|
||||||
|
assert call_kwargs["project_id"] == "my-gcp-project"
|
||||||
|
assert call_kwargs["location"] == "europe-west1"
|
||||||
|
assert call_kwargs["model_name"] == "gemini-embedding-001"
|
||||||
|
assert call_kwargs["task_type"] == "RETRIEVAL_DOCUMENT"
|
||||||
|
assert call_kwargs["output_dimensionality"] == 768
|
||||||
|
|||||||
@@ -0,0 +1,176 @@
|
|||||||
|
"""Integration tests for Google Vertex embeddings with Crew memory.
|
||||||
|
|
||||||
|
These tests make real API calls and use VCR to record/replay responses.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
from collections import defaultdict
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from crewai import Agent, Crew, Task
|
||||||
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
|
from crewai.events.types.memory_events import (
|
||||||
|
MemorySaveCompletedEvent,
|
||||||
|
MemorySaveStartedEvent,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def setup_vertex_ai_env():
|
||||||
|
"""Set up environment for Vertex AI tests.
|
||||||
|
|
||||||
|
Sets GOOGLE_GENAI_USE_VERTEXAI=true to ensure the SDK uses the Vertex AI
|
||||||
|
backend (aiplatform.googleapis.com) which matches the VCR cassettes.
|
||||||
|
Also mocks GOOGLE_API_KEY if not already set.
|
||||||
|
"""
|
||||||
|
env_updates = {"GOOGLE_GENAI_USE_VERTEXAI": "true"}
|
||||||
|
|
||||||
|
# Add a mock API key if none exists
|
||||||
|
if "GOOGLE_API_KEY" not in os.environ and "GEMINI_API_KEY" not in os.environ:
|
||||||
|
env_updates["GOOGLE_API_KEY"] = "test-key"
|
||||||
|
|
||||||
|
with patch.dict(os.environ, env_updates):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def google_vertex_embedder_config():
|
||||||
|
"""Fixture providing Google Vertex embedder configuration."""
|
||||||
|
return {
|
||||||
|
"provider": "google-vertex",
|
||||||
|
"config": {
|
||||||
|
"api_key": os.getenv("GOOGLE_API_KEY", "test-key"),
|
||||||
|
"model_name": "gemini-embedding-001",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def simple_agent():
|
||||||
|
"""Fixture providing a simple test agent."""
|
||||||
|
return Agent(
|
||||||
|
role="Research Assistant",
|
||||||
|
goal="Help with research tasks",
|
||||||
|
backstory="You are a helpful research assistant.",
|
||||||
|
verbose=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def simple_task(simple_agent):
|
||||||
|
"""Fixture providing a simple test task."""
|
||||||
|
return Task(
|
||||||
|
description="Summarize the key points about artificial intelligence in one sentence.",
|
||||||
|
expected_output="A one sentence summary about AI.",
|
||||||
|
agent=simple_agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.vcr()
|
||||||
|
@pytest.mark.timeout(120) # Longer timeout for VCR recording
|
||||||
|
def test_crew_memory_with_google_vertex_embedder(
|
||||||
|
google_vertex_embedder_config, simple_agent, simple_task
|
||||||
|
) -> None:
|
||||||
|
"""Test that Crew with memory=True works with google-vertex embedder and memory is used."""
|
||||||
|
# Track memory events
|
||||||
|
events: dict[str, list] = defaultdict(list)
|
||||||
|
condition = threading.Condition()
|
||||||
|
|
||||||
|
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||||
|
def on_save_started(source, event):
|
||||||
|
with condition:
|
||||||
|
events["MemorySaveStartedEvent"].append(event)
|
||||||
|
condition.notify()
|
||||||
|
|
||||||
|
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||||
|
def on_save_completed(source, event):
|
||||||
|
with condition:
|
||||||
|
events["MemorySaveCompletedEvent"].append(event)
|
||||||
|
condition.notify()
|
||||||
|
|
||||||
|
crew = Crew(
|
||||||
|
agents=[simple_agent],
|
||||||
|
tasks=[simple_task],
|
||||||
|
memory=True,
|
||||||
|
embedder=google_vertex_embedder_config,
|
||||||
|
verbose=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = crew.kickoff()
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.raw is not None
|
||||||
|
assert len(result.raw) > 0
|
||||||
|
|
||||||
|
with condition:
|
||||||
|
success = condition.wait_for(
|
||||||
|
lambda: len(events["MemorySaveCompletedEvent"]) >= 1,
|
||||||
|
timeout=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert success, "Timeout waiting for memory save events - memory may not be working"
|
||||||
|
assert len(events["MemorySaveStartedEvent"]) >= 1, "No memory save started events"
|
||||||
|
assert len(events["MemorySaveCompletedEvent"]) >= 1, "Memory save completed events"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.vcr()
|
||||||
|
@pytest.mark.timeout(120)
|
||||||
|
def test_crew_memory_with_google_vertex_project_id(simple_agent, simple_task) -> None:
|
||||||
|
"""Test Crew memory with Google Vertex using project_id authentication."""
|
||||||
|
project_id = os.getenv("GOOGLE_CLOUD_PROJECT")
|
||||||
|
if not project_id:
|
||||||
|
pytest.skip("GOOGLE_CLOUD_PROJECT environment variable not set")
|
||||||
|
|
||||||
|
# Track memory events
|
||||||
|
events: dict[str, list] = defaultdict(list)
|
||||||
|
condition = threading.Condition()
|
||||||
|
|
||||||
|
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||||
|
def on_save_started(source, event):
|
||||||
|
with condition:
|
||||||
|
events["MemorySaveStartedEvent"].append(event)
|
||||||
|
condition.notify()
|
||||||
|
|
||||||
|
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||||
|
def on_save_completed(source, event):
|
||||||
|
with condition:
|
||||||
|
events["MemorySaveCompletedEvent"].append(event)
|
||||||
|
condition.notify()
|
||||||
|
|
||||||
|
embedder_config = {
|
||||||
|
"provider": "google-vertex",
|
||||||
|
"config": {
|
||||||
|
"project_id": project_id,
|
||||||
|
"location": "us-central1",
|
||||||
|
"model_name": "gemini-embedding-001",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
crew = Crew(
|
||||||
|
agents=[simple_agent],
|
||||||
|
tasks=[simple_task],
|
||||||
|
memory=True,
|
||||||
|
embedder=embedder_config,
|
||||||
|
verbose=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = crew.kickoff()
|
||||||
|
|
||||||
|
# Verify basic result
|
||||||
|
assert result is not None
|
||||||
|
assert result.raw is not None
|
||||||
|
|
||||||
|
# Wait for memory save events
|
||||||
|
with condition:
|
||||||
|
success = condition.wait_for(
|
||||||
|
lambda: len(events["MemorySaveCompletedEvent"]) >= 1,
|
||||||
|
timeout=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify memory was actually used
|
||||||
|
assert success, "Timeout waiting for memory save events - memory may not be working"
|
||||||
|
assert len(events["MemorySaveStartedEvent"]) >= 1, "No memory save started events"
|
||||||
|
assert len(events["MemorySaveCompletedEvent"]) >= 1, "No memory save completed events"
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
from unittest.mock import patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from crewai import Agent, Crew, Task
|
from crewai import Agent, Crew, Task
|
||||||
@@ -121,3 +121,90 @@ def test_telemetry_singleton_pattern():
|
|||||||
thread.join()
|
thread.join()
|
||||||
|
|
||||||
assert all(instance is telemetry1 for instance in instances)
|
assert all(instance is telemetry1 for instance in instances)
|
||||||
|
|
||||||
|
|
||||||
|
def test_signal_handler_registration_skipped_in_non_main_thread():
|
||||||
|
"""Test that signal handler registration is skipped when running from a non-main thread.
|
||||||
|
|
||||||
|
This test verifies that when Telemetry is initialized from a non-main thread,
|
||||||
|
the signal handler registration is skipped without raising noisy ValueError tracebacks.
|
||||||
|
See: https://github.com/crewAIInc/crewAI/issues/4289
|
||||||
|
"""
|
||||||
|
Telemetry._instance = None
|
||||||
|
|
||||||
|
result = {"register_signal_handler_called": False, "error": None}
|
||||||
|
|
||||||
|
def init_telemetry_in_thread():
|
||||||
|
try:
|
||||||
|
with patch("crewai.telemetry.telemetry.TracerProvider"):
|
||||||
|
with patch.object(
|
||||||
|
Telemetry,
|
||||||
|
"_register_signal_handler",
|
||||||
|
wraps=lambda *args, **kwargs: None,
|
||||||
|
) as mock_register:
|
||||||
|
telemetry = Telemetry()
|
||||||
|
result["register_signal_handler_called"] = mock_register.called
|
||||||
|
result["telemetry"] = telemetry
|
||||||
|
except Exception as e:
|
||||||
|
result["error"] = e
|
||||||
|
|
||||||
|
thread = threading.Thread(target=init_telemetry_in_thread)
|
||||||
|
thread.start()
|
||||||
|
thread.join()
|
||||||
|
|
||||||
|
assert result["error"] is None, f"Unexpected error: {result['error']}"
|
||||||
|
assert (
|
||||||
|
result["register_signal_handler_called"] is False
|
||||||
|
), "Signal handler should not be registered in non-main thread"
|
||||||
|
|
||||||
|
|
||||||
|
def test_signal_handler_registration_skipped_logs_debug_message():
|
||||||
|
"""Test that a debug message is logged when signal handler registration is skipped.
|
||||||
|
|
||||||
|
This test verifies that when Telemetry is initialized from a non-main thread,
|
||||||
|
a debug message is logged indicating that signal handler registration was skipped.
|
||||||
|
"""
|
||||||
|
Telemetry._instance = None
|
||||||
|
|
||||||
|
result = {"telemetry": None, "error": None, "debug_calls": []}
|
||||||
|
|
||||||
|
mock_logger_debug = MagicMock()
|
||||||
|
|
||||||
|
def init_telemetry_in_thread():
|
||||||
|
try:
|
||||||
|
with patch("crewai.telemetry.telemetry.TracerProvider"):
|
||||||
|
with patch(
|
||||||
|
"crewai.telemetry.telemetry.logger.debug", mock_logger_debug
|
||||||
|
):
|
||||||
|
result["telemetry"] = Telemetry()
|
||||||
|
result["debug_calls"] = [
|
||||||
|
str(call) for call in mock_logger_debug.call_args_list
|
||||||
|
]
|
||||||
|
except Exception as e:
|
||||||
|
result["error"] = e
|
||||||
|
|
||||||
|
thread = threading.Thread(target=init_telemetry_in_thread)
|
||||||
|
thread.start()
|
||||||
|
thread.join()
|
||||||
|
|
||||||
|
assert result["error"] is None, f"Unexpected error: {result['error']}"
|
||||||
|
assert result["telemetry"] is not None
|
||||||
|
|
||||||
|
debug_calls = result["debug_calls"]
|
||||||
|
assert any(
|
||||||
|
"Skipping signal handler registration" in call for call in debug_calls
|
||||||
|
), f"Expected debug message about skipping signal handler registration, got: {debug_calls}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_signal_handlers_registered_in_main_thread():
|
||||||
|
"""Test that signal handlers are registered when running from the main thread."""
|
||||||
|
Telemetry._instance = None
|
||||||
|
|
||||||
|
with patch("crewai.telemetry.telemetry.TracerProvider"):
|
||||||
|
with patch(
|
||||||
|
"crewai.telemetry.telemetry.Telemetry._register_signal_handler"
|
||||||
|
) as mock_register:
|
||||||
|
telemetry = Telemetry()
|
||||||
|
|
||||||
|
assert telemetry.ready is True
|
||||||
|
assert mock_register.call_count >= 2
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
"""CrewAI development tools."""
|
"""CrewAI development tools."""
|
||||||
|
|
||||||
__version__ = "1.8.1"
|
__version__ = "1.9.0"
|
||||||
|
|||||||
8
uv.lock
generated
8
uv.lock
generated
@@ -310,7 +310,7 @@ wheels = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "anthropic"
|
name = "anthropic"
|
||||||
version = "0.71.1"
|
version = "0.73.0"
|
||||||
source = { registry = "https://pypi.org/simple" }
|
source = { registry = "https://pypi.org/simple" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "anyio" },
|
{ name = "anyio" },
|
||||||
@@ -322,9 +322,9 @@ dependencies = [
|
|||||||
{ name = "sniffio" },
|
{ name = "sniffio" },
|
||||||
{ name = "typing-extensions" },
|
{ name = "typing-extensions" },
|
||||||
]
|
]
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/05/4b/19620875841f692fdc35eb58bf0201c8ad8c47b8443fecbf1b225312175b/anthropic-0.71.1.tar.gz", hash = "sha256:a77d156d3e7d318b84681b59823b2dee48a8ac508a3e54e49f0ab0d074e4b0da", size = 493294, upload-time = "2025-10-28T17:28:42.213Z" }
|
sdist = { url = "https://files.pythonhosted.org/packages/f0/07/f550112c3f5299d02f06580577f602e8a112b1988ad7c98ac1a8f7292d7e/anthropic-0.73.0.tar.gz", hash = "sha256:30f0d7d86390165f86af6ca7c3041f8720bb2e1b0e12a44525c8edfdbd2c5239", size = 425168, upload-time = "2025-11-14T18:47:52.635Z" }
|
||||||
wheels = [
|
wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/4b/68/b2f988b13325f9ac9921b1e87f0b7994468014e1b5bd3bdbd2472f5baf45/anthropic-0.71.1-py3-none-any.whl", hash = "sha256:6ca6c579f0899a445faeeed9c0eb97aa4bdb751196262f9ccc96edfc0bb12679", size = 355020, upload-time = "2025-10-28T17:28:40.653Z" },
|
{ url = "https://files.pythonhosted.org/packages/15/b1/5d4d3f649e151e58dc938cf19c4d0cd19fca9a986879f30fea08a7b17138/anthropic-0.73.0-py3-none-any.whl", hash = "sha256:0d56cd8b3ca3fea9c9b5162868bdfd053fbc189b8b56d4290bd2d427b56db769", size = 367839, upload-time = "2025-11-14T18:47:51.195Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -1276,7 +1276,7 @@ requires-dist = [
|
|||||||
{ name = "aiobotocore", marker = "extra == 'aws'", specifier = "~=2.25.2" },
|
{ name = "aiobotocore", marker = "extra == 'aws'", specifier = "~=2.25.2" },
|
||||||
{ name = "aiocache", extras = ["memcached", "redis"], marker = "extra == 'a2a'", specifier = "~=0.12.3" },
|
{ name = "aiocache", extras = ["memcached", "redis"], marker = "extra == 'a2a'", specifier = "~=0.12.3" },
|
||||||
{ name = "aiosqlite", specifier = "~=0.21.0" },
|
{ name = "aiosqlite", specifier = "~=0.21.0" },
|
||||||
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = "~=0.71.0" },
|
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = "~=0.73.0" },
|
||||||
{ name = "appdirs", specifier = "~=1.4.4" },
|
{ name = "appdirs", specifier = "~=1.4.4" },
|
||||||
{ name = "azure-ai-inference", marker = "extra == 'azure-ai-inference'", specifier = "~=1.0.0b9" },
|
{ name = "azure-ai-inference", marker = "extra == 'azure-ai-inference'", specifier = "~=1.0.0b9" },
|
||||||
{ name = "boto3", marker = "extra == 'aws'", specifier = "~=1.40.38" },
|
{ name = "boto3", marker = "extra == 'aws'", specifier = "~=1.40.38" },
|
||||||
|
|||||||
Reference in New Issue
Block a user