Compare commits

...

6 Commits

Author SHA1 Message Date
Devin AI
9af03058fe fix: skip signal handler registration in non-main thread
When CrewAI is initialized from a non-main thread (e.g., in Streamlit,
Flask, Django, Jupyter), the telemetry module was printing multiple
ValueError tracebacks for each signal handler registration attempt.

This fix adds a proactive main thread check in _register_shutdown_handlers()
before attempting signal registration. If not in the main thread, a debug
message is logged and signal handler registration is skipped.
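A minimal sketch of the guard this describes (illustrative only; the `_handle_shutdown` handler below is a hypothetical stand-in, not the actual CrewAI symbol):

```python
import logging
import signal
import threading

logger = logging.getLogger(__name__)


def _handle_shutdown(signum, frame):
    # Hypothetical handler body, shown only so the sketch is self-contained.
    logger.debug("Shutdown signal %s received", signum)


def _register_shutdown_handlers() -> None:
    # signal.signal() may only be called from the main thread of the main
    # interpreter; calling it from any other thread raises ValueError, so
    # check proactively and skip registration instead of crashing.
    if threading.current_thread() is not threading.main_thread():
        logger.debug("Not in main thread; skipping signal handler registration")
        return
    signal.signal(signal.SIGINT, _handle_shutdown)
    signal.signal(signal.SIGTERM, _handle_shutdown)
```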

Fixes #4289

Co-Authored-By: João <joao@crewai.com>
2026-01-27 19:45:52 +00:00
Greyson LaLonde
d52dbc1f4b chore: add missing change logs (#4285)
* chore: add missing change logs

* chore: add translations
2026-01-26 18:26:01 -08:00
Lorenze Jay
6b926b90d0 chore: update version to 1.9.0 across all relevant files (#4284)
- Bumped the version number to 1.9.0 in pyproject.toml files and __init__.py files across the CrewAI library and its tools.
- Updated dependencies to use the new version of crewai-tools (1.9.0) for improved functionality and compatibility.
- Ensured consistency in versioning across the codebase to reflect the latest updates.
2026-01-26 16:36:35 -08:00
Lorenze Jay
fc84daadbb fix: enhance file store with fallback memory cache when aiocache is not installed (#4283)
* fix: enhance file store with fallback memory cache when aiocache is not installed

- Added a simple in-memory cache implementation to serve as a fallback when the aiocache library is unavailable.
- Improved error handling for the aiocache import, ensuring that the application can still function without it.
- This change enhances the robustness of the file store utility by providing a reliable caching mechanism in various environments.
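A sketch of the import-fallback pattern this commit describes (illustrative only, not the actual crewai file-store code; note the follow-up commit below drops the fallback again):

```python
from typing import Any

try:
    from aiocache import SimpleMemoryCache  # preferred cache backend when installed
except ImportError:
    class SimpleMemoryCache:  # minimal async stand-in when aiocache is absent
        def __init__(self) -> None:
            self._store: dict[str, Any] = {}

        async def get(self, key: str, default: Any = None) -> Any:
            return self._store.get(key, default)

        async def set(self, key: str, value: Any, ttl: int | None = None) -> bool:
            # ttl accepted for interface compatibility but not enforced here
            self._store[key] = value
            return True
```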

* drop fallback

* Potential fix for pull request finding 'Unused global variable'

Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com>

---------

Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com>
2026-01-26 15:12:34 -08:00
Lorenze Jay
58b866a83d Lorenze/supporting vertex embeddings (#4282)
* feat: introduce GoogleGenAIVertexEmbeddingFunction for dual SDK support

- Added a new embedding function to support both the legacy vertexai.language_models SDK and the new google-genai SDK for Google Vertex AI.
- Updated factory methods to route to the new embedding function.
- Enhanced VertexAIProvider and related configurations to accommodate the new model options.
- Added integration tests for Google Vertex embeddings with Crew memory, ensuring compatibility and functionality with both authentication methods.

This update improves the flexibility and compatibility of Google Vertex AI embeddings within the CrewAI framework.
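Roughly, the dual-SDK support amounts to choosing a client by model name. A hedged sketch under assumed helper names (this is not the actual factory code in crewai):

```python
LEGACY_MODEL_PREFIX = "textembedding-gecko"


def embed_texts(model_name: str, texts: list[str], project: str, location: str) -> list[list[float]]:
    """Embed texts with Vertex AI, picking the SDK based on the model name."""
    if model_name.startswith(LEGACY_MODEL_PREFIX):
        # Deprecated vertexai.language_models path, kept for backwards compatibility.
        import vertexai
        from vertexai.language_models import TextEmbeddingModel

        vertexai.init(project=project, location=location)
        model = TextEmbeddingModel.from_pretrained(model_name)
        return [e.values for e in model.get_embeddings(texts)]

    # New google-genai SDK path (gemini-embedding-001, text-embedding-005, ...).
    from google import genai

    client = genai.Client(vertexai=True, project=project, location=location)
    response = client.models.embed_content(model=model_name, contents=texts)
    return [e.values for e in response.embeddings]
```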

* fix test count

* rm comment

* regen cassettes

* regen

* drop variable from .envtest

* direct to relevant test only
2026-01-26 14:55:03 -08:00
Greyson LaLonde
9797567342 feat: add structured outputs and response_format support across providers (#4280)
* feat: add response_format parameter to Azure and Gemini providers

* feat: add structured outputs support to Bedrock and Anthropic providers

* chore: bump anthropic dep

* fix: use beta structured output for new models
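In user code the change surfaces roughly as follows (a sketch: the model string is illustrative, and routing `LLM(response_format=...)` to the provider parameter added in the diffs below is an assumption, not verified against the release):

```python
from pydantic import BaseModel
from crewai import LLM


class Weather(BaseModel):
    city: str
    temperature_c: float


# On Claude 4.5-family models the Anthropic provider uses the native
# structured-outputs beta; older models fall back to a forced
# "structured_output" tool call (see the anthropic completion diff below).
llm = LLM(model="anthropic/claude-sonnet-4-5", response_format=Weather)
```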
2026-01-26 11:03:33 -08:00
30 changed files with 15041 additions and 224 deletions

View File

@@ -4,6 +4,74 @@ description: "Product updates, improvements, and bug fixes for CrewAI"
icon: "clock" icon: "clock"
mode: "wide" mode: "wide"
--- ---
<Update label="Jan 26, 2026">
## v1.9.0
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.9.0)
## What's Changed
### Features
- Add structured outputs and response_format support across providers
- Add response ID to streaming responses
- Add event ordering with parent-child hierarchies
- Add Keycloak SSO authentication support
- Add multimodal file handling capabilities
- Add native OpenAI responses API support
- Add A2A task execution utilities
- Add A2A server configuration and agent card generation
- Enhance event system and expand transport options
- Improve tool calling mechanisms
### Bug Fixes
- Enhance file store with fallback memory cache when aiocache is not available
- Ensure document list is not empty
- Handle Bedrock stop sequences properly
- Add Google Vertex API key support
- Enhance Azure model stop word detection
- Improve error handling for HumanFeedbackPending in flow execution
- Fix execution span task unlinking
### Documentation
- Add native file handling documentation
- Add OpenAI responses API documentation
- Add agent card implementation guidance
- Refine A2A documentation
- Update changelog for v1.8.0
### Contributors
@Anaisdg, @GininDenis, @Vidit-Ostwal, @greysonlalonde, @heitorado, @joaomdmoura, @koushiv777, @lorenzejay, @nicoferdi96, @vinibrsl
</Update>
<Update label="Jan 15, 2026">
## v1.8.1
[View release on GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.8.1)
## What's Changed
### Features
- Add A2A task execution utilities
- Add A2A server configuration and agent card generation
- Add additional transport mechanisms
- Add Galileo integration support
### Bug Fixes
- Improve Azure model compatibility
- Expand frame inspection depth to detect parent_flow
- Resolve task execution span management issues
- Enhance error handling for human feedback scenarios during flow execution
### Documentation
- Add A2A agent card documentation
- Add PII redaction feature documentation
### Contributors
@Anaisdg, @GininDenis, @greysonlalonde, @joaomdmoura, @koushiv777, @lorenzejay, @vinibrsl
</Update>
<Update label="Jan 08, 2026"> <Update label="Jan 08, 2026">
## v1.8.0 ## v1.8.0

View File

@@ -401,23 +401,58 @@ crew = Crew(
### Vertex AI Embeddings
-For Google Cloud users with Vertex AI access.
+For Google Cloud users with Vertex AI access. Supports both legacy and new embedding models with automatic SDK selection.
<Note>
**Deprecation Notice:** Legacy models (`textembedding-gecko*`) use the deprecated `vertexai.language_models` SDK which will be removed after June 24, 2026. Consider migrating to newer models like `gemini-embedding-001`. See the [Google migration guide](https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk) for details.
</Note>
```python
# Recommended: Using new models with google-genai SDK
crew = Crew(
    memory=True,
    embedder={
-        "provider": "vertexai",
+        "provider": "google-vertex",
        "config": {
            "project_id": "your-gcp-project-id",
-            "region": "us-central1",  # or your preferred region
-            "api_key": "your-service-account-key",
-            "model_name": "textembedding-gecko"
+            "location": "us-central1",
+            "model_name": "gemini-embedding-001",  # or "text-embedding-005", "text-multilingual-embedding-002"
+            "task_type": "RETRIEVAL_DOCUMENT",  # Optional
+            "output_dimensionality": 768  # Optional
        }
    }
)
# Using API key authentication (Exp)
crew = Crew(
    memory=True,
    embedder={
        "provider": "google-vertex",
        "config": {
            "api_key": "your-google-api-key",
            "model_name": "gemini-embedding-001"
        }
    }
)
# Legacy models (backwards compatible, emits deprecation warning)
crew = Crew(
    memory=True,
    embedder={
        "provider": "google-vertex",
        "config": {
            "project_id": "your-gcp-project-id",
            "region": "us-central1",  # or "location" (region is deprecated)
            "model_name": "textembedding-gecko"  # Legacy model
        }
    }
)
```
**Available models:**
- **New SDK models** (recommended): `gemini-embedding-001`, `text-embedding-005`, `text-multilingual-embedding-002`
- **Legacy models** (deprecated): `textembedding-gecko`, `textembedding-gecko@001`, `textembedding-gecko-multilingual`
### Ollama Embeddings (Local)
Run embeddings locally for privacy and cost savings.
@@ -569,7 +604,7 @@ mem0_client_embedder_config = {
"project_id": "my_project_id", # Optional "project_id": "my_project_id", # Optional
"api_key": "custom-api-key" # Optional - overrides env var "api_key": "custom-api-key" # Optional - overrides env var
"run_id": "my_run_id", # Optional - for short-term memory "run_id": "my_run_id", # Optional - for short-term memory
"includes": "include1", # Optional "includes": "include1", # Optional
"excludes": "exclude1", # Optional "excludes": "exclude1", # Optional
"infer": True # Optional defaults to True "infer": True # Optional defaults to True
"custom_categories": new_categories # Optional - custom categories for user memory "custom_categories": new_categories # Optional - custom categories for user memory
@@ -591,7 +626,7 @@ crew = Crew(
### Choosing the Right Embedding Provider
When selecting an embedding provider, consider factors like performance, privacy, cost, and integration needs.
Below is a comparison to help you decide:
| Provider | Best For | Pros | Cons |
@@ -749,7 +784,7 @@ Entity Memory supports batching when saving multiple entities at once. When you
This improves performance and observability when writing many entities in one operation.
## 2. External Memory
External Memory provides a standalone memory system that operates independently from the crew's built-in memory. This is ideal for specialized memory providers or cross-application memory sharing.
### Basic External Memory with Mem0
@@ -819,7 +854,7 @@ external_memory = ExternalMemory(
"project_id": "my_project_id", # Optional "project_id": "my_project_id", # Optional
"api_key": "custom-api-key" # Optional - overrides env var "api_key": "custom-api-key" # Optional - overrides env var
"run_id": "my_run_id", # Optional - for short-term memory "run_id": "my_run_id", # Optional - for short-term memory
"includes": "include1", # Optional "includes": "include1", # Optional
"excludes": "exclude1", # Optional "excludes": "exclude1", # Optional
"infer": True # Optional defaults to True "infer": True # Optional defaults to True
"custom_categories": new_categories # Optional - custom categories for user memory "custom_categories": new_categories # Optional - custom categories for user memory

View File

@@ -4,6 +4,74 @@ description: "CrewAI의 제품 업데이트, 개선 사항 및 버그 수정"
icon: "clock" icon: "clock"
mode: "wide" mode: "wide"
--- ---
<Update label="2026년 1월 26일">
## v1.9.0
[GitHub 릴리스 보기](https://github.com/crewAIInc/crewAI/releases/tag/1.9.0)
## 변경 사항
### 기능
- 프로바이더 전반에 걸친 구조화된 출력 및 response_format 지원 추가
- 스트리밍 응답에 응답 ID 추가
- 부모-자식 계층 구조를 가진 이벤트 순서 추가
- Keycloak SSO 인증 지원 추가
- 멀티모달 파일 처리 기능 추가
- 네이티브 OpenAI responses API 지원 추가
- A2A 작업 실행 유틸리티 추가
- A2A 서버 구성 및 에이전트 카드 생성 추가
- 이벤트 시스템 향상 및 전송 옵션 확장
- 도구 호출 메커니즘 개선
### 버그 수정
- aiocache를 사용할 수 없을 때 폴백 메모리 캐시로 파일 저장소 향상
- 문서 목록이 비어 있지 않도록 보장
- Bedrock 중지 시퀀스 적절히 처리
- Google Vertex API 키 지원 추가
- Azure 모델 중지 단어 감지 향상
- 흐름 실행 시 HumanFeedbackPending 오류 처리 개선
- 실행 스팬 작업 연결 해제 수정
### 문서
- 네이티브 파일 처리 문서 추가
- OpenAI responses API 문서 추가
- 에이전트 카드 구현 가이드 추가
- A2A 문서 개선
- v1.8.0 변경 로그 업데이트
### 기여자
@Anaisdg, @GininDenis, @Vidit-Ostwal, @greysonlalonde, @heitorado, @joaomdmoura, @koushiv777, @lorenzejay, @nicoferdi96, @vinibrsl
</Update>
<Update label="2026년 1월 15일">
## v1.8.1
[GitHub 릴리스 보기](https://github.com/crewAIInc/crewAI/releases/tag/1.8.1)
## 변경 사항
### 기능
- A2A 작업 실행 유틸리티 추가
- A2A 서버 구성 및 에이전트 카드 생성 추가
- 추가 전송 메커니즘 추가
- Galileo 통합 지원 추가
### 버그 수정
- Azure 모델 호환성 개선
- parent_flow 감지를 위한 프레임 검사 깊이 확장
- 작업 실행 스팬 관리 문제 해결
- 흐름 실행 중 휴먼 피드백 시나리오에 대한 오류 처리 향상
### 문서
- A2A 에이전트 카드 문서 추가
- PII 삭제 기능 문서 추가
### 기여자
@Anaisdg, @GininDenis, @greysonlalonde, @joaomdmoura, @koushiv777, @lorenzejay, @vinibrsl
</Update>
<Update label="2026년 1월 8일"> <Update label="2026년 1월 8일">
## v1.8.0 ## v1.8.0

View File

@@ -4,6 +4,74 @@ description: "Atualizações de produto, melhorias e correções do CrewAI"
icon: "clock" icon: "clock"
mode: "wide" mode: "wide"
--- ---
<Update label="26 jan 2026">
## v1.9.0
[Ver release no GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.9.0)
## O que Mudou
### Funcionalidades
- Adicionar suporte a saídas estruturadas e response_format em vários provedores
- Adicionar ID de resposta às respostas de streaming
- Adicionar ordenação de eventos com hierarquias pai-filho
- Adicionar suporte à autenticação SSO Keycloak
- Adicionar capacidades de manipulação de arquivos multimodais
- Adicionar suporte nativo à API de respostas OpenAI
- Adicionar utilitários de execução de tarefas A2A
- Adicionar configuração de servidor A2A e geração de cartão de agente
- Aprimorar sistema de eventos e expandir opções de transporte
- Melhorar mecanismos de chamada de ferramentas
### Correções de Bugs
- Aprimorar armazenamento de arquivos com cache de memória de fallback quando aiocache não está disponível
- Garantir que lista de documentos não esteja vazia
- Tratar sequências de parada do Bedrock adequadamente
- Adicionar suporte à chave de API do Google Vertex
- Aprimorar detecção de palavras de parada do modelo Azure
- Melhorar tratamento de erros para HumanFeedbackPending na execução de fluxo
- Corrigir desvinculação de tarefa do span de execução
### Documentação
- Adicionar documentação de manipulação nativa de arquivos
- Adicionar documentação da API de respostas OpenAI
- Adicionar orientação de implementação de cartão de agente
- Refinar documentação A2A
- Atualizar changelog para v1.8.0
### Contribuidores
@Anaisdg, @GininDenis, @Vidit-Ostwal, @greysonlalonde, @heitorado, @joaomdmoura, @koushiv777, @lorenzejay, @nicoferdi96, @vinibrsl
</Update>
<Update label="15 jan 2026">
## v1.8.1
[Ver release no GitHub](https://github.com/crewAIInc/crewAI/releases/tag/1.8.1)
## O que Mudou
### Funcionalidades
- Adicionar utilitários de execução de tarefas A2A
- Adicionar configuração de servidor A2A e geração de cartão de agente
- Adicionar mecanismos de transporte adicionais
- Adicionar suporte à integração Galileo
### Correções de Bugs
- Melhorar compatibilidade do modelo Azure
- Expandir profundidade de inspeção de frame para detectar parent_flow
- Resolver problemas de gerenciamento de span de execução de tarefas
- Aprimorar tratamento de erros para cenários de feedback humano durante execução de fluxo
### Documentação
- Adicionar documentação de cartão de agente A2A
- Adicionar documentação de recurso de redação de PII
### Contribuidores
@Anaisdg, @GininDenis, @greysonlalonde, @joaomdmoura, @koushiv777, @lorenzejay, @vinibrsl
</Update>
<Update label="08 jan 2026"> <Update label="08 jan 2026">
## v1.8.0 ## v1.8.0

View File

@@ -152,4 +152,4 @@ __all__ = [
"wrap_file_source", "wrap_file_source",
] ]
__version__ = "1.8.1" __version__ = "1.9.0"

View File

@@ -12,7 +12,7 @@ dependencies = [
"pytube~=15.0.0", "pytube~=15.0.0",
"requests~=2.32.5", "requests~=2.32.5",
"docker~=7.1.0", "docker~=7.1.0",
"crewai==1.8.1", "crewai==1.9.0",
"lancedb~=0.5.4", "lancedb~=0.5.4",
"tiktoken~=0.8.0", "tiktoken~=0.8.0",
"beautifulsoup4~=4.13.4", "beautifulsoup4~=4.13.4",

View File

@@ -291,4 +291,4 @@ __all__ = [
"ZapierActionTools", "ZapierActionTools",
] ]
__version__ = "1.8.1" __version__ = "1.9.0"

View File

@@ -49,7 +49,7 @@ Repository = "https://github.com/crewAIInc/crewAI"
[project.optional-dependencies]
tools = [
-"crewai-tools==1.8.1",
+"crewai-tools==1.9.0",
]
embeddings = [
"tiktoken~=0.8.0"
@@ -90,7 +90,7 @@ azure-ai-inference = [
"azure-ai-inference~=1.0.0b9", "azure-ai-inference~=1.0.0b9",
] ]
anthropic = [ anthropic = [
"anthropic~=0.71.0", "anthropic~=0.73.0",
] ]
a2a = [ a2a = [
"a2a-sdk~=0.3.10", "a2a-sdk~=0.3.10",

View File

@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
_suppress_pydantic_deprecation_warnings()
-__version__ = "1.8.1"
+__version__ = "1.9.0"
_telemetry_submitted = False

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
-"crewai[tools]==1.8.1"
+"crewai[tools]==1.9.0"
]
[project.scripts]

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
-"crewai[tools]==1.8.1"
+"crewai[tools]==1.9.0"
]
[project.scripts]

View File

@@ -3,9 +3,8 @@ from __future__ import annotations
import json
import logging
import os
-from typing import TYPE_CHECKING, Any, Literal, cast
+from typing import TYPE_CHECKING, Any, Final, Literal, TypeGuard, cast
-from anthropic.types import ThinkingBlock
from pydantic import BaseModel
from crewai.events.types.llm_events import LLMCallType
@@ -22,8 +21,9 @@ if TYPE_CHECKING:
from crewai.llms.hooks.base import BaseInterceptor
try:
-from anthropic import Anthropic, AsyncAnthropic
+from anthropic import Anthropic, AsyncAnthropic, transform_schema
from anthropic.types import Message, TextBlock, ThinkingBlock, ToolUseBlock
+from anthropic.types.beta import BetaMessage
import httpx
except ImportError:
raise ImportError(
@@ -31,7 +31,62 @@ except ImportError:
) from None
-ANTHROPIC_FILES_API_BETA = "files-api-2025-04-14"
+ANTHROPIC_FILES_API_BETA: Final = "files-api-2025-04-14"
ANTHROPIC_STRUCTURED_OUTPUTS_BETA: Final = "structured-outputs-2025-11-13"
NATIVE_STRUCTURED_OUTPUT_MODELS: Final[
tuple[
Literal["claude-sonnet-4-5"],
Literal["claude-sonnet-4.5"],
Literal["claude-opus-4-5"],
Literal["claude-opus-4.5"],
Literal["claude-opus-4-1"],
Literal["claude-opus-4.1"],
Literal["claude-haiku-4-5"],
Literal["claude-haiku-4.5"],
]
] = (
"claude-sonnet-4-5",
"claude-sonnet-4.5",
"claude-opus-4-5",
"claude-opus-4.5",
"claude-opus-4-1",
"claude-opus-4.1",
"claude-haiku-4-5",
"claude-haiku-4.5",
)
def _supports_native_structured_outputs(model: str) -> bool:
"""Check if the model supports native structured outputs.
Native structured outputs are only available for Claude 4.5 models
(Sonnet 4.5, Opus 4.5, Opus 4.1, Haiku 4.5).
Other models require the tool-based fallback approach.
Args:
model: The model name/identifier.
Returns:
True if the model supports native structured outputs.
"""
model_lower = model.lower()
return any(prefix in model_lower for prefix in NATIVE_STRUCTURED_OUTPUT_MODELS)
def _is_pydantic_model_class(obj: Any) -> TypeGuard[type[BaseModel]]:
"""Check if an object is a Pydantic model class.
This distinguishes between Pydantic model classes that support structured
outputs (have model_json_schema) and plain dicts like {"type": "json_object"}.
Args:
obj: The object to check.
Returns:
True if obj is a Pydantic model class.
"""
return isinstance(obj, type) and issubclass(obj, BaseModel)
def _contains_file_id_reference(messages: list[dict[str, Any]]) -> bool:
@@ -84,6 +139,7 @@ class AnthropicCompletion(BaseLLM):
client_params: dict[str, Any] | None = None,
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
thinking: AnthropicThinkingConfig | None = None,
response_format: type[BaseModel] | None = None,
**kwargs: Any,
):
"""Initialize Anthropic chat completion client.
@@ -101,6 +157,8 @@ class AnthropicCompletion(BaseLLM):
stream: Enable streaming responses
client_params: Additional parameters for the Anthropic client
interceptor: HTTP interceptor for modifying requests/responses at transport level.
response_format: Pydantic model for structured output. When provided, responses
will be validated against this model schema.
**kwargs: Additional parameters
"""
super().__init__(
@@ -131,6 +189,7 @@ class AnthropicCompletion(BaseLLM):
self.stop_sequences = stop_sequences or []
self.thinking = thinking
self.previous_thinking_blocks: list[ThinkingBlock] = []
self.response_format = response_format
# Model-specific settings
self.is_claude_3 = "claude-3" in model.lower()
self.supports_tools = True
@@ -231,6 +290,8 @@ class AnthropicCompletion(BaseLLM):
formatted_messages, system_message, tools
)
effective_response_model = response_model or self.response_format
# Handle streaming vs non-streaming
if self.stream:
return self._handle_streaming_completion(
@@ -238,7 +299,7 @@ class AnthropicCompletion(BaseLLM):
available_functions,
from_task,
from_agent,
-response_model,
+effective_response_model,
)
return self._handle_completion(
@@ -246,7 +307,7 @@ class AnthropicCompletion(BaseLLM):
available_functions,
from_task,
from_agent,
-response_model,
+effective_response_model,
)
except Exception as e:
@@ -298,13 +359,15 @@ class AnthropicCompletion(BaseLLM):
formatted_messages, system_message, tools
)
effective_response_model = response_model or self.response_format
if self.stream:
return await self._ahandle_streaming_completion(
completion_params,
available_functions,
from_task,
from_agent,
-response_model,
+effective_response_model,
)
return await self._ahandle_completion(
@@ -312,7 +375,7 @@ class AnthropicCompletion(BaseLLM):
available_functions,
from_task,
from_agent,
-response_model,
+effective_response_model,
)
except Exception as e:
@@ -565,22 +628,40 @@ class AnthropicCompletion(BaseLLM):
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle non-streaming message completion."""
-if response_model:
-structured_tool = {
-"name": "structured_output",
-"description": "Returns structured data according to the schema",
-"input_schema": response_model.model_json_schema(),
-}
-params["tools"] = [structured_tool]
-params["tool_choice"] = {"type": "tool", "name": "structured_output"}
uses_file_api = _contains_file_id_reference(params.get("messages", []))
betas: list[str] = []
use_native_structured_output = False
if uses_file_api:
betas.append(ANTHROPIC_FILES_API_BETA)
extra_body: dict[str, Any] | None = None
if _is_pydantic_model_class(response_model):
schema = transform_schema(response_model.model_json_schema())
if _supports_native_structured_outputs(self.model):
use_native_structured_output = True
betas.append(ANTHROPIC_STRUCTURED_OUTPUTS_BETA)
extra_body = {
"output_format": {
"type": "json_schema",
"schema": schema,
}
}
else:
structured_tool = {
"name": "structured_output",
"description": "Output the structured response",
"input_schema": schema,
}
params["tools"] = [structured_tool]
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
try:
-if uses_file_api:
-params["betas"] = [ANTHROPIC_FILES_API_BETA]
-response = self.client.beta.messages.create(**params)
+if betas:
+params["betas"] = betas
+response = self.client.beta.messages.create(
+**params, extra_body=extra_body
+)
else:
response = self.client.messages.create(**params)
@@ -593,22 +674,34 @@ class AnthropicCompletion(BaseLLM):
usage = self._extract_anthropic_token_usage(response)
self._track_token_usage_internal(usage)
-if response_model and response.content:
-tool_uses = [
-block for block in response.content if isinstance(block, ToolUseBlock)
-]
-if tool_uses and tool_uses[0].name == "structured_output":
-structured_data = tool_uses[0].input
-structured_json = json.dumps(structured_data)
-self._emit_call_completed_event(
-response=structured_json,
-call_type=LLMCallType.LLM_CALL,
-from_task=from_task,
-from_agent=from_agent,
-messages=params["messages"],
-)
-return structured_json
+if _is_pydantic_model_class(response_model) and response.content:
+if use_native_structured_output:
+for block in response.content:
+if isinstance(block, TextBlock):
+structured_json = block.text
+self._emit_call_completed_event(
+response=structured_json,
+call_type=LLMCallType.LLM_CALL,
+from_task=from_task,
+from_agent=from_agent,
+messages=params["messages"],
+)
+return structured_json
+else:
+for block in response.content:
+if (
+isinstance(block, ToolUseBlock)
+and block.name == "structured_output"
+):
+structured_json = json.dumps(block.input)
+self._emit_call_completed_event(
+response=structured_json,
+call_type=LLMCallType.LLM_CALL,
+from_task=from_task,
+from_agent=from_agent,
+messages=params["messages"],
+)
+return structured_json
# Check if Claude wants to use tools
if response.content:
@@ -678,17 +771,31 @@ class AnthropicCompletion(BaseLLM):
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
-) -> str:
+) -> str | Any:
"""Handle streaming message completion."""
-if response_model:
-structured_tool = {
-"name": "structured_output",
-"description": "Returns structured data according to the schema",
-"input_schema": response_model.model_json_schema(),
-}
-params["tools"] = [structured_tool]
-params["tool_choice"] = {"type": "tool", "name": "structured_output"}
+betas: list[str] = []
+use_native_structured_output = False
+extra_body: dict[str, Any] | None = None
+if _is_pydantic_model_class(response_model):
schema = transform_schema(response_model.model_json_schema())
if _supports_native_structured_outputs(self.model):
use_native_structured_output = True
betas.append(ANTHROPIC_STRUCTURED_OUTPUTS_BETA)
extra_body = {
"output_format": {
"type": "json_schema",
"schema": schema,
}
}
else:
structured_tool = {
"name": "structured_output",
"description": "Output the structured response",
"input_schema": schema,
}
params["tools"] = [structured_tool]
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
full_response = "" full_response = ""
@@ -696,15 +803,22 @@ class AnthropicCompletion(BaseLLM):
# (the SDK sets it internally) # (the SDK sets it internally)
stream_params = {k: v for k, v in params.items() if k != "stream"} stream_params = {k: v for k, v in params.items() if k != "stream"}
if betas:
stream_params["betas"] = betas
current_tool_calls: dict[int, dict[str, Any]] = {} current_tool_calls: dict[int, dict[str, Any]] = {}
# Make streaming API call stream_context = (
with self.client.messages.stream(**stream_params) as stream: self.client.beta.messages.stream(**stream_params, extra_body=extra_body)
if betas
else self.client.messages.stream(**stream_params)
)
with stream_context as stream:
response_id = None response_id = None
for event in stream: for event in stream:
if hasattr(event, "message") and hasattr(event.message, "id"): if hasattr(event, "message") and hasattr(event.message, "id"):
response_id = event.message.id response_id = event.message.id
if hasattr(event, "delta") and hasattr(event.delta, "text"): if hasattr(event, "delta") and hasattr(event.delta, "text"):
text_delta = event.delta.text text_delta = event.delta.text
full_response += text_delta full_response += text_delta
@@ -712,7 +826,7 @@ class AnthropicCompletion(BaseLLM):
chunk=text_delta, chunk=text_delta,
from_task=from_task, from_task=from_task,
from_agent=from_agent, from_agent=from_agent,
response_id=response_id response_id=response_id,
) )
if event.type == "content_block_start": if event.type == "content_block_start":
@@ -739,7 +853,7 @@ class AnthropicCompletion(BaseLLM):
"index": block_index, "index": block_index,
}, },
call_type=LLMCallType.TOOL_CALL, call_type=LLMCallType.TOOL_CALL,
response_id=response_id response_id=response_id,
) )
elif event.type == "content_block_delta": elif event.type == "content_block_delta":
if event.delta.type == "input_json_delta": if event.delta.type == "input_json_delta":
@@ -763,10 +877,10 @@ class AnthropicCompletion(BaseLLM):
"index": block_index, "index": block_index,
}, },
call_type=LLMCallType.TOOL_CALL, call_type=LLMCallType.TOOL_CALL,
response_id=response_id response_id=response_id,
) )
final_message: Message = stream.get_final_message() final_message = stream.get_final_message()
thinking_blocks: list[ThinkingBlock] = [] thinking_blocks: list[ThinkingBlock] = []
if final_message.content: if final_message.content:
@@ -781,25 +895,30 @@ class AnthropicCompletion(BaseLLM):
usage = self._extract_anthropic_token_usage(final_message) usage = self._extract_anthropic_token_usage(final_message)
self._track_token_usage_internal(usage) self._track_token_usage_internal(usage)
if response_model and final_message.content: if _is_pydantic_model_class(response_model):
tool_uses = [ if use_native_structured_output:
block
for block in final_message.content
if isinstance(block, ToolUseBlock)
]
if tool_uses and tool_uses[0].name == "structured_output":
structured_data = tool_uses[0].input
structured_json = json.dumps(structured_data)
self._emit_call_completed_event( self._emit_call_completed_event(
response=structured_json, response=full_response,
call_type=LLMCallType.LLM_CALL, call_type=LLMCallType.LLM_CALL,
from_task=from_task, from_task=from_task,
from_agent=from_agent, from_agent=from_agent,
messages=params["messages"], messages=params["messages"],
) )
return full_response
return structured_json for block in final_message.content:
if (
isinstance(block, ToolUseBlock)
and block.name == "structured_output"
):
structured_json = json.dumps(block.input)
self._emit_call_completed_event(
response=structured_json,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
if final_message.content: if final_message.content:
tool_uses = [ tool_uses = [
@@ -809,11 +928,9 @@ class AnthropicCompletion(BaseLLM):
] ]
if tool_uses: if tool_uses:
# If no available_functions, return tool calls for executor to handle
if not available_functions: if not available_functions:
return list(tool_uses) return list(tool_uses)
# Handle tool use conversation flow internally
return self._handle_tool_use_conversation( return self._handle_tool_use_conversation(
final_message, final_message,
tool_uses, tool_uses,
@@ -823,10 +940,8 @@ class AnthropicCompletion(BaseLLM):
from_agent, from_agent,
) )
# Apply stop words to full response
full_response = self._apply_stop_words(full_response) full_response = self._apply_stop_words(full_response)
# Emit completion event and return full response
self._emit_call_completed_event( self._emit_call_completed_event(
response=full_response, response=full_response,
call_type=LLMCallType.LLM_CALL, call_type=LLMCallType.LLM_CALL,
@@ -884,7 +999,7 @@ class AnthropicCompletion(BaseLLM):
def _handle_tool_use_conversation( def _handle_tool_use_conversation(
self, self,
initial_response: Message, initial_response: Message | BetaMessage,
tool_uses: list[ToolUseBlock], tool_uses: list[ToolUseBlock],
params: dict[str, Any], params: dict[str, Any],
available_functions: dict[str, Any], available_functions: dict[str, Any],
@@ -1002,22 +1117,40 @@ class AnthropicCompletion(BaseLLM):
response_model: type[BaseModel] | None = None, response_model: type[BaseModel] | None = None,
) -> str | Any: ) -> str | Any:
"""Handle non-streaming async message completion.""" """Handle non-streaming async message completion."""
if response_model:
structured_tool = {
"name": "structured_output",
"description": "Returns structured data according to the schema",
"input_schema": response_model.model_json_schema(),
}
params["tools"] = [structured_tool]
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
uses_file_api = _contains_file_id_reference(params.get("messages", [])) uses_file_api = _contains_file_id_reference(params.get("messages", []))
betas: list[str] = []
use_native_structured_output = False
if uses_file_api:
betas.append(ANTHROPIC_FILES_API_BETA)
extra_body: dict[str, Any] | None = None
if _is_pydantic_model_class(response_model):
schema = transform_schema(response_model.model_json_schema())
if _supports_native_structured_outputs(self.model):
use_native_structured_output = True
betas.append(ANTHROPIC_STRUCTURED_OUTPUTS_BETA)
extra_body = {
"output_format": {
"type": "json_schema",
"schema": schema,
}
}
else:
structured_tool = {
"name": "structured_output",
"description": "Output the structured response",
"input_schema": schema,
}
params["tools"] = [structured_tool]
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
try: try:
if uses_file_api: if betas:
params["betas"] = [ANTHROPIC_FILES_API_BETA] params["betas"] = betas
response = await self.async_client.beta.messages.create(**params) response = await self.async_client.beta.messages.create(
**params, extra_body=extra_body
)
else: else:
response = await self.async_client.messages.create(**params) response = await self.async_client.messages.create(**params)
@@ -1030,23 +1163,34 @@ class AnthropicCompletion(BaseLLM):
usage = self._extract_anthropic_token_usage(response) usage = self._extract_anthropic_token_usage(response)
self._track_token_usage_internal(usage) self._track_token_usage_internal(usage)
if response_model and response.content: if _is_pydantic_model_class(response_model) and response.content:
tool_uses = [ if use_native_structured_output:
block for block in response.content if isinstance(block, ToolUseBlock) for block in response.content:
] if isinstance(block, TextBlock):
if tool_uses and tool_uses[0].name == "structured_output": structured_json = block.text
structured_data = tool_uses[0].input self._emit_call_completed_event(
structured_json = json.dumps(structured_data) response=structured_json,
call_type=LLMCallType.LLM_CALL,
self._emit_call_completed_event( from_task=from_task,
response=structured_json, from_agent=from_agent,
call_type=LLMCallType.LLM_CALL, messages=params["messages"],
from_task=from_task, )
from_agent=from_agent, return structured_json
messages=params["messages"], else:
) for block in response.content:
if (
return structured_json isinstance(block, ToolUseBlock)
and block.name == "structured_output"
):
structured_json = json.dumps(block.input)
self._emit_call_completed_event(
response=structured_json,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
if response.content: if response.content:
tool_uses = [ tool_uses = [
@@ -1102,25 +1246,49 @@ class AnthropicCompletion(BaseLLM):
from_task: Any | None = None, from_task: Any | None = None,
from_agent: Any | None = None, from_agent: Any | None = None,
response_model: type[BaseModel] | None = None, response_model: type[BaseModel] | None = None,
) -> str: ) -> str | Any:
"""Handle async streaming message completion.""" """Handle async streaming message completion."""
if response_model: betas: list[str] = []
structured_tool = { use_native_structured_output = False
"name": "structured_output",
"description": "Returns structured data according to the schema",
"input_schema": response_model.model_json_schema(),
}
params["tools"] = [structured_tool] extra_body: dict[str, Any] | None = None
params["tool_choice"] = {"type": "tool", "name": "structured_output"} if _is_pydantic_model_class(response_model):
schema = transform_schema(response_model.model_json_schema())
if _supports_native_structured_outputs(self.model):
use_native_structured_output = True
betas.append(ANTHROPIC_STRUCTURED_OUTPUTS_BETA)
extra_body = {
"output_format": {
"type": "json_schema",
"schema": schema,
}
}
else:
structured_tool = {
"name": "structured_output",
"description": "Output the structured response",
"input_schema": schema,
}
params["tools"] = [structured_tool]
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
full_response = "" full_response = ""
stream_params = {k: v for k, v in params.items() if k != "stream"} stream_params = {k: v for k, v in params.items() if k != "stream"}
if betas:
stream_params["betas"] = betas
current_tool_calls: dict[int, dict[str, Any]] = {} current_tool_calls: dict[int, dict[str, Any]] = {}
async with self.async_client.messages.stream(**stream_params) as stream: stream_context = (
self.async_client.beta.messages.stream(
**stream_params, extra_body=extra_body
)
if betas
else self.async_client.messages.stream(**stream_params)
)
async with stream_context as stream:
response_id = None response_id = None
async for event in stream: async for event in stream:
if hasattr(event, "message") and hasattr(event.message, "id"): if hasattr(event, "message") and hasattr(event.message, "id"):
@@ -1133,7 +1301,7 @@ class AnthropicCompletion(BaseLLM):
chunk=text_delta, chunk=text_delta,
from_task=from_task, from_task=from_task,
from_agent=from_agent, from_agent=from_agent,
response_id=response_id response_id=response_id,
) )
if event.type == "content_block_start": if event.type == "content_block_start":
@@ -1160,7 +1328,7 @@ class AnthropicCompletion(BaseLLM):
"index": block_index, "index": block_index,
}, },
call_type=LLMCallType.TOOL_CALL, call_type=LLMCallType.TOOL_CALL,
response_id=response_id response_id=response_id,
) )
elif event.type == "content_block_delta": elif event.type == "content_block_delta":
if event.delta.type == "input_json_delta": if event.delta.type == "input_json_delta":
@@ -1184,33 +1352,38 @@ class AnthropicCompletion(BaseLLM):
"index": block_index, "index": block_index,
}, },
call_type=LLMCallType.TOOL_CALL, call_type=LLMCallType.TOOL_CALL,
response_id=response_id response_id=response_id,
) )
final_message: Message = await stream.get_final_message() final_message = await stream.get_final_message()
usage = self._extract_anthropic_token_usage(final_message) usage = self._extract_anthropic_token_usage(final_message)
self._track_token_usage_internal(usage) self._track_token_usage_internal(usage)
if response_model and final_message.content: if _is_pydantic_model_class(response_model):
tool_uses = [ if use_native_structured_output:
block
for block in final_message.content
if isinstance(block, ToolUseBlock)
]
if tool_uses and tool_uses[0].name == "structured_output":
structured_data = tool_uses[0].input
structured_json = json.dumps(structured_data)
self._emit_call_completed_event( self._emit_call_completed_event(
response=structured_json, response=full_response,
call_type=LLMCallType.LLM_CALL, call_type=LLMCallType.LLM_CALL,
from_task=from_task, from_task=from_task,
from_agent=from_agent, from_agent=from_agent,
messages=params["messages"], messages=params["messages"],
) )
return full_response
return structured_json for block in final_message.content:
if (
isinstance(block, ToolUseBlock)
and block.name == "structured_output"
):
structured_json = json.dumps(block.input)
self._emit_call_completed_event(
response=structured_json,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
if final_message.content:
tool_uses = [
@@ -1220,7 +1393,6 @@ class AnthropicCompletion(BaseLLM):
]
if tool_uses:
-# If no available_functions, return tool calls for executor to handle
if not available_functions:
return list(tool_uses)
@@ -1247,7 +1419,7 @@ class AnthropicCompletion(BaseLLM):
async def _ahandle_tool_use_conversation(
self,
-initial_response: Message,
+initial_response: Message | BetaMessage,
tool_uses: list[ToolUseBlock],
params: dict[str, Any],
available_functions: dict[str, Any],
@@ -1356,7 +1528,9 @@ class AnthropicCompletion(BaseLLM):
return int(200000 * CONTEXT_WINDOW_USAGE_RATIO)
@staticmethod
-def _extract_anthropic_token_usage(response: Message) -> dict[str, Any]:
+def _extract_anthropic_token_usage(
+response: Message | BetaMessage,
+) -> dict[str, Any]:
"""Extract token usage from Anthropic response."""
if hasattr(response, "usage") and response.usage:
usage = response.usage

View File

@@ -92,6 +92,7 @@ class AzureCompletion(BaseLLM):
stop: list[str] | None = None,
stream: bool = False,
interceptor: BaseInterceptor[Any, Any] | None = None,
response_format: type[BaseModel] | None = None,
**kwargs: Any,
):
"""Initialize Azure AI Inference chat completion client.
@@ -111,6 +112,9 @@ class AzureCompletion(BaseLLM):
stop: Stop sequences
stream: Enable streaming responses
interceptor: HTTP interceptor (not yet supported for Azure).
response_format: Pydantic model for structured output. Used as default when
response_model is not passed to call()/acall() methods.
Only works with OpenAI models deployed on Azure.
**kwargs: Additional parameters
"""
if interceptor is not None:
@@ -165,6 +169,7 @@ class AzureCompletion(BaseLLM):
self.presence_penalty = presence_penalty
self.max_tokens = max_tokens
self.stream = stream
self.response_format = response_format
self.is_openai_model = any(
prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
@@ -298,6 +303,7 @@ class AzureCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
)
effective_response_model = response_model or self.response_format
# Format messages for Azure
formatted_messages = self._format_messages_for_azure(messages)
@@ -307,7 +313,7 @@ class AzureCompletion(BaseLLM):
# Prepare completion parameters
completion_params = self._prepare_completion_params(
-formatted_messages, tools, response_model
+formatted_messages, tools, effective_response_model
)
# Handle streaming vs non-streaming
@@ -317,7 +323,7 @@ class AzureCompletion(BaseLLM):
available_functions,
from_task,
from_agent,
-response_model,
+effective_response_model,
)
return self._handle_completion(
@@ -325,7 +331,7 @@ class AzureCompletion(BaseLLM):
available_functions,
from_task,
from_agent,
-response_model,
+effective_response_model,
)
except Exception as e:
@@ -364,11 +370,12 @@ class AzureCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
)
effective_response_model = response_model or self.response_format
formatted_messages = self._format_messages_for_azure(messages)
completion_params = self._prepare_completion_params(
-formatted_messages, tools, response_model
+formatted_messages, tools, effective_response_model
)
if self.stream:
@@ -377,7 +384,7 @@ class AzureCompletion(BaseLLM):
available_functions,
from_task,
from_agent,
-response_model,
+effective_response_model,
)
return await self._ahandle_completion(
@@ -385,7 +392,7 @@ class AzureCompletion(BaseLLM):
available_functions,
from_task,
from_agent,
-response_model,
+effective_response_model,
)
except Exception as e:
@@ -726,7 +733,7 @@ class AzureCompletion(BaseLLM):
""" """
if update.choices: if update.choices:
choice = update.choices[0] choice = update.choices[0]
response_id = update.id if hasattr(update,"id") else None response_id = update.id if hasattr(update, "id") else None
if choice.delta and choice.delta.content: if choice.delta and choice.delta.content:
content_delta = choice.delta.content content_delta = choice.delta.content
full_response += content_delta full_response += content_delta
@@ -734,7 +741,7 @@ class AzureCompletion(BaseLLM):
chunk=content_delta,
from_task=from_task,
from_agent=from_agent,
response_id=response_id,
)
if choice.delta and choice.delta.tool_calls:
@@ -769,7 +776,7 @@ class AzureCompletion(BaseLLM):
"index": idx, "index": idx,
}, },
call_type=LLMCallType.TOOL_CALL, call_type=LLMCallType.TOOL_CALL,
response_id=response_id response_id=response_id,
) )
return full_response return full_response

View File

@@ -172,6 +172,7 @@ class BedrockCompletion(BaseLLM):
additional_model_request_fields: dict[str, Any] | None = None,
additional_model_response_field_paths: list[str] | None = None,
interceptor: BaseInterceptor[Any, Any] | None = None,
response_format: type[BaseModel] | None = None,
**kwargs: Any,
) -> None:
"""Initialize AWS Bedrock completion client.
@@ -192,6 +193,8 @@ class BedrockCompletion(BaseLLM):
additional_model_request_fields: Model-specific request parameters
additional_model_response_field_paths: Custom response field paths
interceptor: HTTP interceptor (not yet supported for Bedrock).
response_format: Pydantic model for structured output. Used as default when
response_model is not passed to call()/acall() methods.
**kwargs: Additional parameters
"""
if interceptor is not None:
@@ -248,6 +251,7 @@ class BedrockCompletion(BaseLLM):
self.top_k = top_k self.top_k = top_k
self.stream = stream self.stream = stream
self.stop_sequences = stop_sequences self.stop_sequences = stop_sequences
self.response_format = response_format
# Store advanced features (optional) # Store advanced features (optional)
self.guardrail_config = guardrail_config self.guardrail_config = guardrail_config
@@ -299,6 +303,8 @@ class BedrockCompletion(BaseLLM):
response_model: type[BaseModel] | None = None, response_model: type[BaseModel] | None = None,
) -> str | Any: ) -> str | Any:
"""Call AWS Bedrock Converse API.""" """Call AWS Bedrock Converse API."""
effective_response_model = response_model or self.response_format
try: try:
# Emit call started event # Emit call started event
self._emit_call_started_event( self._emit_call_started_event(
@@ -375,6 +381,7 @@ class BedrockCompletion(BaseLLM):
available_functions, available_functions,
from_task, from_task,
from_agent, from_agent,
effective_response_model,
) )
return self._handle_converse( return self._handle_converse(
@@ -383,6 +390,7 @@ class BedrockCompletion(BaseLLM):
available_functions, available_functions,
from_task, from_task,
from_agent, from_agent,
effective_response_model,
) )
except Exception as e: except Exception as e:
@@ -425,6 +433,8 @@ class BedrockCompletion(BaseLLM):
NotImplementedError: If aiobotocore is not installed. NotImplementedError: If aiobotocore is not installed.
LLMContextLengthExceededError: If context window is exceeded. LLMContextLengthExceededError: If context window is exceeded.
""" """
effective_response_model = response_model or self.response_format
if not AIOBOTOCORE_AVAILABLE: if not AIOBOTOCORE_AVAILABLE:
raise NotImplementedError( raise NotImplementedError(
"Async support for AWS Bedrock requires aiobotocore. " "Async support for AWS Bedrock requires aiobotocore. "
@@ -494,11 +504,21 @@ class BedrockCompletion(BaseLLM):
if self.stream:
return await self._ahandle_streaming_converse(
-formatted_messages, body, available_functions, from_task, from_agent
+formatted_messages,
+body,
+available_functions,
+from_task,
+from_agent,
+effective_response_model,
)
return await self._ahandle_converse(
-formatted_messages, body, available_functions, from_task, from_agent
+formatted_messages,
+body,
+available_functions,
+from_task,
+from_agent,
+effective_response_model,
)
except Exception as e:
@@ -520,10 +540,29 @@ class BedrockCompletion(BaseLLM):
available_functions: Mapping[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
-) -> str:
+response_model: type[BaseModel] | None = None,
+) -> str | Any:
"""Handle non-streaming converse API call following AWS best practices."""
if response_model:
structured_tool: ConverseToolTypeDef = {
"toolSpec": {
"name": "structured_output",
"description": "Returns structured data according to the schema",
"inputSchema": {"json": response_model.model_json_schema()},
}
}
body["toolConfig"] = cast(
"ToolConfigurationTypeDef",
cast(
object,
{
"tools": [structured_tool],
"toolChoice": {"tool": {"name": "structured_output"}},
},
),
)
try:
-# Validate messages format before API call
if not messages:
raise ValueError("Messages cannot be empty")
@@ -571,6 +610,21 @@ class BedrockCompletion(BaseLLM):
# If there are tool uses but no available_functions, return them for the executor to handle # If there are tool uses but no available_functions, return them for the executor to handle
tool_uses = [block["toolUse"] for block in content if "toolUse" in block] tool_uses = [block["toolUse"] for block in content if "toolUse" in block]
if response_model and tool_uses:
for tool_use in tool_uses:
if tool_use.get("name") == "structured_output":
structured_data = tool_use.get("input", {})
result = response_model.model_validate(structured_data)
self._emit_call_completed_event(
response=result.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=messages,
)
return result
if tool_uses and not available_functions: if tool_uses and not available_functions:
self._emit_call_completed_event( self._emit_call_completed_event(
response=tool_uses, response=tool_uses,
@@ -717,8 +771,28 @@ class BedrockCompletion(BaseLLM):
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str:
"""Handle streaming converse API call with comprehensive event handling."""
if response_model:
structured_tool: ConverseToolTypeDef = {
"toolSpec": {
"name": "structured_output",
"description": "Returns structured data according to the schema",
"inputSchema": {"json": response_model.model_json_schema()},
}
}
body["toolConfig"] = cast(
"ToolConfigurationTypeDef",
cast(
object,
{
"tools": [structured_tool],
"toolChoice": {"tool": {"name": "structured_output"}},
},
),
)
full_response = "" full_response = ""
current_tool_use: dict[str, Any] | None = None current_tool_use: dict[str, Any] | None = None
tool_use_id: str | None = None tool_use_id: str | None = None
@@ -805,7 +879,7 @@ class BedrockCompletion(BaseLLM):
"index": tool_use_index, "index": tool_use_index,
}, },
call_type=LLMCallType.TOOL_CALL, call_type=LLMCallType.TOOL_CALL,
response_id=response_id response_id=response_id,
) )
elif "contentBlockStop" in event: elif "contentBlockStop" in event:
logging.debug("Content block stopped in stream") logging.debug("Content block stopped in stream")
@@ -929,8 +1003,28 @@ class BedrockCompletion(BaseLLM):
available_functions: Mapping[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
-) -> str:
+response_model: type[BaseModel] | None = None,
+) -> str | Any:
"""Handle async non-streaming converse API call."""
if response_model:
structured_tool: ConverseToolTypeDef = {
"toolSpec": {
"name": "structured_output",
"description": "Returns structured data according to the schema",
"inputSchema": {"json": response_model.model_json_schema()},
}
}
body["toolConfig"] = cast(
"ToolConfigurationTypeDef",
cast(
object,
{
"tools": [structured_tool],
"toolChoice": {"tool": {"name": "structured_output"}},
},
),
)
try: try:
if not messages: if not messages:
raise ValueError("Messages cannot be empty") raise ValueError("Messages cannot be empty")
@@ -976,6 +1070,21 @@ class BedrockCompletion(BaseLLM):
# If there are tool uses but no available_functions, return them for the executor to handle # If there are tool uses but no available_functions, return them for the executor to handle
tool_uses = [block["toolUse"] for block in content if "toolUse" in block] tool_uses = [block["toolUse"] for block in content if "toolUse" in block]
if response_model and tool_uses:
for tool_use in tool_uses:
if tool_use.get("name") == "structured_output":
structured_data = tool_use.get("input", {})
result = response_model.model_validate(structured_data)
self._emit_call_completed_event(
response=result.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=messages,
)
return result
if tool_uses and not available_functions:
self._emit_call_completed_event(
response=tool_uses,
@@ -1106,8 +1215,28 @@ class BedrockCompletion(BaseLLM):
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str:
"""Handle async streaming converse API call."""
if response_model:
structured_tool: ConverseToolTypeDef = {
"toolSpec": {
"name": "structured_output",
"description": "Returns structured data according to the schema",
"inputSchema": {"json": response_model.model_json_schema()},
}
}
body["toolConfig"] = cast(
"ToolConfigurationTypeDef",
cast(
object,
{
"tools": [structured_tool],
"toolChoice": {"tool": {"name": "structured_output"}},
},
),
)
full_response = ""
current_tool_use: dict[str, Any] | None = None
tool_use_id: str | None = None
@@ -1174,7 +1303,7 @@ class BedrockCompletion(BaseLLM):
chunk=text_chunk,
from_task=from_task,
from_agent=from_agent,
- response_id=response_id
response_id=response_id,
)
elif "toolUse" in delta and current_tool_use:
tool_input = delta["toolUse"].get("input", "")
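Taken together, the hunks above inject a forced structured_output tool into toolConfig whenever a response_model is supplied, then validate the tool's input back into that model. A minimal usage sketch, assuming the public call() forwards response_model to these handlers (the constructor arguments and model id are illustrative, not taken from this diff):

# Sketch only: assumes BedrockCompletion.call() accepts response_model,
# as the handler signatures above suggest; the model id is a placeholder.
from pydantic import BaseModel

class GreetingResponse(BaseModel):
    greeting: str
    language: str

llm = BedrockCompletion(model="anthropic.claude-sonnet-4-20250514-v1:0")
result = llm.call(
    messages=[{"role": "user", "content": "Say hello in French"}],
    response_model=GreetingResponse,
)
# The "structured_output" tool is injected with toolChoice forced, and the
# returned tool input is validated into a GreetingResponse instance.
assert isinstance(result, GreetingResponse)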

View File

@@ -56,6 +56,7 @@ class GeminiCompletion(BaseLLM):
client_params: dict[str, Any] | None = None,
interceptor: BaseInterceptor[Any, Any] | None = None,
use_vertexai: bool | None = None,
response_format: type[BaseModel] | None = None,
**kwargs: Any,
):
"""Initialize Google Gemini chat completion client.
@@ -86,6 +87,8 @@ class GeminiCompletion(BaseLLM):
- None (default): Check GOOGLE_GENAI_USE_VERTEXAI env var
When using Vertex AI with API key (Express mode), http_options with
api_version="v1" is automatically configured.
response_format: Pydantic model for structured output. Used as default when
response_model is not passed to call()/acall() methods.
**kwargs: Additional parameters
"""
if interceptor is not None:
@@ -121,6 +124,7 @@ class GeminiCompletion(BaseLLM):
self.safety_settings = safety_settings or {}
self.stop_sequences = stop_sequences or []
self.tools: list[dict[str, Any]] | None = None
self.response_format = response_format
# Model-specific settings
version_match = re.search(r"gemini-(\d+(?:\.\d+)?)", model.lower())
@@ -292,6 +296,7 @@ class GeminiCompletion(BaseLLM):
from_agent=from_agent,
)
self.tools = tools
effective_response_model = response_model or self.response_format
formatted_content, system_instruction = self._format_messages_for_gemini(
messages
@@ -303,7 +308,7 @@ class GeminiCompletion(BaseLLM):
raise ValueError("LLM call blocked by before_llm_call hook") raise ValueError("LLM call blocked by before_llm_call hook")
config = self._prepare_generation_config( config = self._prepare_generation_config(
system_instruction, tools, response_model system_instruction, tools, effective_response_model
) )
if self.stream: if self.stream:
@@ -313,7 +318,7 @@ class GeminiCompletion(BaseLLM):
available_functions,
from_task,
from_agent,
- response_model,
effective_response_model,
)
return self._handle_completion(
@@ -322,7 +327,7 @@ class GeminiCompletion(BaseLLM):
available_functions,
from_task,
from_agent,
- response_model,
effective_response_model,
)
except APIError as e:
@@ -374,13 +379,14 @@ class GeminiCompletion(BaseLLM):
from_agent=from_agent,
)
self.tools = tools
effective_response_model = response_model or self.response_format
formatted_content, system_instruction = self._format_messages_for_gemini(
messages
)
config = self._prepare_generation_config(
- system_instruction, tools, response_model
system_instruction, tools, effective_response_model
)
if self.stream:
@@ -390,7 +396,7 @@ class GeminiCompletion(BaseLLM):
available_functions,
from_task,
from_agent,
- response_model,
effective_response_model,
)
return await self._ahandle_completion(
@@ -399,7 +405,7 @@ class GeminiCompletion(BaseLLM):
available_functions,
from_task,
from_agent,
- response_model,
effective_response_model,
)
except APIError as e:
@@ -570,10 +576,10 @@ class GeminiCompletion(BaseLLM):
types.Content(role="user", parts=[function_response_part]) types.Content(role="user", parts=[function_response_part])
) )
elif role == "assistant" and message.get("tool_calls"): elif role == "assistant" and message.get("tool_calls"):
parts: list[types.Part] = [] tool_parts: list[types.Part] = []
if text_content: if text_content:
parts.append(types.Part.from_text(text=text_content)) tool_parts.append(types.Part.from_text(text=text_content))
tool_calls: list[dict[str, Any]] = message.get("tool_calls") or [] tool_calls: list[dict[str, Any]] = message.get("tool_calls") or []
for tool_call in tool_calls: for tool_call in tool_calls:
@@ -592,11 +598,11 @@ class GeminiCompletion(BaseLLM):
else:
func_args = func_args_raw
- parts.append(
tool_parts.append(
types.Part.from_function_call(name=func_name, args=func_args)
)
- contents.append(types.Content(role="model", parts=parts))
contents.append(types.Content(role="model", parts=tool_parts))
else:
# Convert role for Gemini (assistant -> model)
gemini_role = "model" if role == "assistant" else "user"
@@ -790,7 +796,7 @@ class GeminiCompletion(BaseLLM):
Returns:
Tuple of (updated full_response, updated function_calls, updated usage_data)
"""
- response_id=chunk.response_id if hasattr(chunk,"response_id") else None
response_id = chunk.response_id if hasattr(chunk, "response_id") else None
if chunk.usage_metadata:
usage_data = self._extract_token_usage(chunk)
@@ -800,7 +806,7 @@ class GeminiCompletion(BaseLLM):
chunk=chunk.text,
from_task=from_task,
from_agent=from_agent,
- response_id=response_id
response_id=response_id,
)
if chunk.candidates:
@@ -837,7 +843,7 @@ class GeminiCompletion(BaseLLM):
"index": call_index, "index": call_index,
}, },
call_type=LLMCallType.TOOL_CALL, call_type=LLMCallType.TOOL_CALL,
response_id=response_id response_id=response_id,
) )
return full_response, function_calls, usage_data return full_response, function_calls, usage_data
@@ -972,7 +978,7 @@ class GeminiCompletion(BaseLLM):
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
- ) -> str:
) -> str | Any:
"""Handle streaming content generation."""
full_response = ""
function_calls: dict[int, dict[str, Any]] = {}
@@ -1050,7 +1056,7 @@ class GeminiCompletion(BaseLLM):
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
- ) -> str:
) -> str | Any:
"""Handle async streaming content generation."""
full_response = ""
function_calls: dict[int, dict[str, Any]] = {}
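With this change a default structured-output model can be fixed at construction time and is picked up whenever call()/acall() receive no explicit response_model. A minimal sketch, assuming GeminiCompletion is constructed directly (the model name is illustrative):

# Sketch only: response_format acts as the fallback for response_model.
from pydantic import BaseModel

class GreetingResponse(BaseModel):
    greeting: str
    language: str

llm = GeminiCompletion(model="gemini-2.0-flash", response_format=GreetingResponse)
# No response_model passed here, so effective_response_model falls back to
# self.response_format and the completion is parsed against GreetingResponse.
result = llm.call(messages=[{"role": "user", "content": "Say hello in French"}])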

View File

@@ -18,7 +18,6 @@ if TYPE_CHECKING:
)
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleGenerativeAiEmbeddingFunction,
- GoogleVertexEmbeddingFunction,
)
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
HuggingFaceEmbeddingFunction,
@@ -52,6 +51,9 @@ if TYPE_CHECKING:
from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec
from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec
from crewai.rag.embeddings.providers.google.genai_vertex_embedding import (
GoogleGenAIVertexEmbeddingFunction,
)
from crewai.rag.embeddings.providers.google.types import (
GenerativeAiProviderSpec,
VertexAIProviderSpec,
@@ -163,7 +165,7 @@ def build_embedder_from_dict(spec: OpenAIProviderSpec) -> OpenAIEmbeddingFunctio
@overload
def build_embedder_from_dict(
spec: VertexAIProviderSpec,
- ) -> GoogleVertexEmbeddingFunction: ...
) -> GoogleGenAIVertexEmbeddingFunction: ...
@overload
@@ -296,7 +298,9 @@ def build_embedder(spec: OpenAIProviderSpec) -> OpenAIEmbeddingFunction: ...
@overload
- def build_embedder(spec: VertexAIProviderSpec) -> GoogleVertexEmbeddingFunction: ...
def build_embedder(
spec: VertexAIProviderSpec,
) -> GoogleGenAIVertexEmbeddingFunction: ...
@overload

View File

@@ -1,5 +1,8 @@
"""Google embedding providers.""" """Google embedding providers."""
from crewai.rag.embeddings.providers.google.genai_vertex_embedding import (
GoogleGenAIVertexEmbeddingFunction,
)
from crewai.rag.embeddings.providers.google.generative_ai import ( from crewai.rag.embeddings.providers.google.generative_ai import (
GenerativeAiProvider, GenerativeAiProvider,
) )
@@ -18,6 +21,7 @@ __all__ = [
"GenerativeAiProvider", "GenerativeAiProvider",
"GenerativeAiProviderConfig", "GenerativeAiProviderConfig",
"GenerativeAiProviderSpec", "GenerativeAiProviderSpec",
"GoogleGenAIVertexEmbeddingFunction",
"VertexAIProvider", "VertexAIProvider",
"VertexAIProviderConfig", "VertexAIProviderConfig",
"VertexAIProviderSpec", "VertexAIProviderSpec",

View File

@@ -0,0 +1,237 @@
"""Google Vertex AI embedding function implementation.
This module supports both the new google-genai SDK and the deprecated
vertexai.language_models module for backwards compatibility.
The deprecated vertexai.language_models module will be removed after June 24, 2026.
Migration guide: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk
"""
from typing import Any, ClassVar, cast
import warnings
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
from typing_extensions import Unpack
from crewai.rag.embeddings.providers.google.types import VertexAIProviderConfig
class GoogleGenAIVertexEmbeddingFunction(EmbeddingFunction[Documents]):
"""Embedding function for Google Vertex AI with dual SDK support.
This class supports both:
- Legacy models (textembedding-gecko*) using the deprecated vertexai.language_models SDK
- New models (gemini-embedding-*, text-embedding-*) using the google-genai SDK
The SDK is automatically selected based on the model name. Legacy models will
emit a deprecation warning.
Supports two authentication modes:
1. Vertex AI backend: Set project_id and location/region (uses Application Default Credentials)
2. API key: Set api_key for direct API access
Example:
# Using legacy model (will emit deprecation warning)
embedder = GoogleGenAIVertexEmbeddingFunction(
project_id="my-project",
region="us-central1",
model_name="textembedding-gecko"
)
# Using new model with google-genai SDK
embedder = GoogleGenAIVertexEmbeddingFunction(
project_id="my-project",
location="us-central1",
model_name="gemini-embedding-001"
)
# Using API key (new SDK only)
embedder = GoogleGenAIVertexEmbeddingFunction(
api_key="your-api-key",
model_name="gemini-embedding-001"
)
"""
# Models that use the legacy vertexai.language_models SDK
LEGACY_MODELS: ClassVar[set[str]] = {
"textembedding-gecko",
"textembedding-gecko@001",
"textembedding-gecko@002",
"textembedding-gecko@003",
"textembedding-gecko@latest",
"textembedding-gecko-multilingual",
"textembedding-gecko-multilingual@001",
"textembedding-gecko-multilingual@latest",
}
# Models that use the new google-genai SDK
GENAI_MODELS: ClassVar[set[str]] = {
"gemini-embedding-001",
"text-embedding-005",
"text-multilingual-embedding-002",
}
def __init__(self, **kwargs: Unpack[VertexAIProviderConfig]) -> None:
"""Initialize Google Vertex AI embedding function.
Args:
**kwargs: Configuration parameters including:
- model_name: Model to use for embeddings (default: "textembedding-gecko")
- api_key: Optional API key for authentication (new SDK only)
- project_id: GCP project ID (for Vertex AI backend)
- location: GCP region (default: "us-central1")
- region: Deprecated alias for location
- task_type: Task type for embeddings (default: "RETRIEVAL_DOCUMENT", new SDK only)
- output_dimensionality: Optional output embedding dimension (new SDK only)
"""
# Handle deprecated 'region' parameter (only if it has a value)
region_value = kwargs.pop("region", None) # type: ignore[typeddict-item]
if region_value is not None:
warnings.warn(
"The 'region' parameter is deprecated, use 'location' instead. "
"See: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk",
DeprecationWarning,
stacklevel=2,
)
if "location" not in kwargs or kwargs.get("location") is None:
kwargs["location"] = region_value # type: ignore[typeddict-unknown-key]
self._config = kwargs
self._model_name = str(kwargs.get("model_name", "textembedding-gecko"))
self._use_legacy = self._is_legacy_model(self._model_name)
if self._use_legacy:
self._init_legacy_client(**kwargs)
else:
self._init_genai_client(**kwargs)
def _is_legacy_model(self, model_name: str) -> bool:
"""Check if the model uses the legacy SDK."""
return model_name in self.LEGACY_MODELS or model_name.startswith(
"textembedding-gecko"
)
def _init_legacy_client(self, **kwargs: Any) -> None:
"""Initialize using the deprecated vertexai.language_models SDK."""
warnings.warn(
f"Model '{self._model_name}' uses the deprecated vertexai.language_models SDK "
"which will be removed after June 24, 2026. Consider migrating to newer models "
"like 'gemini-embedding-001'. "
"See: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk",
DeprecationWarning,
stacklevel=3,
)
try:
import vertexai
from vertexai.language_models import TextEmbeddingModel
except ImportError as e:
raise ImportError(
"vertexai is required for legacy embedding models (textembedding-gecko*). "
"Install it with: pip install google-cloud-aiplatform"
) from e
project_id = kwargs.get("project_id")
location = str(kwargs.get("location", "us-central1"))
if not project_id:
raise ValueError(
"project_id is required for legacy models. "
"For API key authentication, use newer models like 'gemini-embedding-001'."
)
vertexai.init(project=str(project_id), location=location)
self._legacy_model = TextEmbeddingModel.from_pretrained(self._model_name)
def _init_genai_client(self, **kwargs: Any) -> None:
"""Initialize using the new google-genai SDK."""
try:
from google import genai
from google.genai.types import EmbedContentConfig
except ImportError as e:
raise ImportError(
"google-genai is required for Google Gen AI embeddings. "
"Install it with: uv add 'crewai[google-genai]'"
) from e
self._genai = genai
self._EmbedContentConfig = EmbedContentConfig
self._task_type = kwargs.get("task_type", "RETRIEVAL_DOCUMENT")
self._output_dimensionality = kwargs.get("output_dimensionality")
# Initialize client based on authentication mode
api_key = kwargs.get("api_key")
project_id = kwargs.get("project_id")
location: str = str(kwargs.get("location", "us-central1"))
if api_key:
self._client = genai.Client(api_key=api_key)
elif project_id:
self._client = genai.Client(
vertexai=True,
project=str(project_id),
location=location,
)
else:
raise ValueError(
"Either 'api_key' (for API key authentication) or 'project_id' "
"(for Vertex AI backend with ADC) must be provided."
)
@staticmethod
def name() -> str:
"""Return the name of the embedding function for ChromaDB compatibility."""
return "google-vertex"
def __call__(self, input: Documents) -> Embeddings:
"""Generate embeddings for input documents.
Args:
input: List of documents to embed.
Returns:
List of embedding vectors.
"""
if isinstance(input, str):
input = [input]
if self._use_legacy:
return self._call_legacy(input)
return self._call_genai(input)
def _call_legacy(self, input: list[str]) -> Embeddings:
"""Generate embeddings using the legacy SDK."""
import numpy as np
embeddings_list = []
for text in input:
embedding_result = self._legacy_model.get_embeddings([text])
embeddings_list.append(
np.array(embedding_result[0].values, dtype=np.float32)
)
return cast(Embeddings, embeddings_list)
def _call_genai(self, input: list[str]) -> Embeddings:
"""Generate embeddings using the new google-genai SDK."""
# Build config for embed_content
config_kwargs: dict[str, Any] = {
"task_type": self._task_type,
}
if self._output_dimensionality is not None:
config_kwargs["output_dimensionality"] = self._output_dimensionality
config = self._EmbedContentConfig(**config_kwargs)
# Call the embedding API
response = self._client.models.embed_content(
model=self._model_name,
contents=input, # type: ignore[arg-type]
config=config,
)
# Extract embeddings from response
if response.embeddings is None:
raise ValueError("No embeddings returned from the API")
embeddings = [emb.values for emb in response.embeddings]
return cast(Embeddings, embeddings)
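The new module above resolves to one of two clients at construction time and exposes a single callable interface. A short usage sketch of the resulting embedder (placeholder project id; requires google-genai and Application Default Credentials when using the Vertex AI backend):

# Sketch only: the same __call__ path is what ChromaDB invokes when embedding.
embedder = GoogleGenAIVertexEmbeddingFunction(
    project_id="my-project",
    location="us-central1",
    model_name="gemini-embedding-001",
    output_dimensionality=768,
)
vectors = embedder(["CrewAI supports Vertex AI embeddings."])
print(len(vectors), len(vectors[0]))  # 1 document, 768-dimensional vector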

View File

@@ -34,12 +34,47 @@ class GenerativeAiProviderSpec(TypedDict):
class VertexAIProviderConfig(TypedDict, total=False):
- """Configuration for Vertex AI provider."""
"""Configuration for Vertex AI provider with dual SDK support.
Supports both legacy models (textembedding-gecko*) using the deprecated
vertexai.language_models SDK and new models using google-genai SDK.
Attributes:
api_key: Google API key (optional if using project_id with ADC). Only for new SDK models.
model_name: Embedding model name (default: "textembedding-gecko").
Legacy models: textembedding-gecko, textembedding-gecko@001, etc.
New models: gemini-embedding-001, text-embedding-005, text-multilingual-embedding-002
project_id: GCP project ID (required for Vertex AI backend and legacy models).
location: GCP region/location (default: "us-central1").
region: Deprecated alias for location (kept for backwards compatibility).
task_type: Task type for embeddings (default: "RETRIEVAL_DOCUMENT"). Only for new SDK models.
output_dimensionality: Output embedding dimension (optional). Only for new SDK models.
"""
api_key: str
- model_name: Annotated[str, "textembedding-gecko"]
- project_id: Annotated[str, "cloud-large-language-models"]
- region: Annotated[str, "us-central1"]
model_name: Annotated[
Literal[
# Legacy models (deprecated vertexai.language_models SDK)
"textembedding-gecko",
"textembedding-gecko@001",
"textembedding-gecko@002",
"textembedding-gecko@003",
"textembedding-gecko@latest",
"textembedding-gecko-multilingual",
"textembedding-gecko-multilingual@001",
"textembedding-gecko-multilingual@latest",
# New models (google-genai SDK)
"gemini-embedding-001",
"text-embedding-005",
"text-multilingual-embedding-002",
],
"textembedding-gecko",
]
project_id: str
location: Annotated[str, "us-central1"]
region: Annotated[str, "us-central1"]  # Deprecated alias for location
task_type: Annotated[str, "RETRIEVAL_DOCUMENT"]
output_dimensionality: int
class VertexAIProviderSpec(TypedDict, total=False):
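For reference, a provider spec in the shape described by these TypedDicts looks like the following (values are placeholders; this mirrors the dictionaries passed to build_embedder and Crew(embedder=...) in the tests further down):

embedder_config = {
    "provider": "google-vertex",
    "config": {
        "project_id": "my-gcp-project",
        "location": "us-central1",
        "model_name": "gemini-embedding-001",
        "task_type": "RETRIEVAL_DOCUMENT",
        "output_dimensionality": 768,
    },
}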

View File

@@ -1,46 +1,126 @@
"""Google Vertex AI embeddings provider.""" """Google Vertex AI embeddings provider.
This module supports both the new google-genai SDK and the deprecated
vertexai.language_models module for backwards compatibility.
The SDK is automatically selected based on the model name:
- Legacy models (textembedding-gecko*) use vertexai.language_models (deprecated)
- New models (gemini-embedding-*, text-embedding-*) use google-genai
Migration guide: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk
"""
from __future__ import annotations
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleVertexEmbeddingFunction,
)
from pydantic import AliasChoices, Field from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.providers.google.genai_vertex_embedding import (
GoogleGenAIVertexEmbeddingFunction,
)
class VertexAIProvider(BaseEmbeddingsProvider[GoogleVertexEmbeddingFunction]): class VertexAIProvider(BaseEmbeddingsProvider[GoogleGenAIVertexEmbeddingFunction]):
"""Google Vertex AI embeddings provider.""" """Google Vertex AI embeddings provider with dual SDK support.
embedding_callable: type[GoogleVertexEmbeddingFunction] = Field( Supports both legacy models (textembedding-gecko*) using the deprecated
default=GoogleVertexEmbeddingFunction, vertexai.language_models SDK and new models (gemini-embedding-*, text-embedding-*)
description="Vertex AI embedding function class", using the google-genai SDK.
The SDK is automatically selected based on the model name. Legacy models will
emit a deprecation warning.
Authentication modes:
1. Vertex AI backend: Set project_id and location/region (uses Application Default Credentials)
2. API key: Set api_key for direct API access (new SDK models only)
Example:
# Legacy model (backwards compatible, will emit deprecation warning)
provider = VertexAIProvider(
project_id="my-project",
region="us-central1", # or location="us-central1"
model_name="textembedding-gecko"
)
# New model with Vertex AI backend
provider = VertexAIProvider(
project_id="my-project",
location="us-central1",
model_name="gemini-embedding-001"
)
# New model with API key
provider = VertexAIProvider(
api_key="your-api-key",
model_name="gemini-embedding-001"
)
"""
embedding_callable: type[GoogleGenAIVertexEmbeddingFunction] = Field(
default=GoogleGenAIVertexEmbeddingFunction,
description="Google Vertex AI embedding function class",
)
model_name: str = Field(
default="textembedding-gecko",
- description="Model name to use for embeddings",
description=(
"Model name to use for embeddings. Legacy models (textembedding-gecko*) "
"use the deprecated SDK. New models (gemini-embedding-001, text-embedding-005) "
"use the google-genai SDK."
),
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME",
"GOOGLE_VERTEX_MODEL_NAME",
"model",
),
)
- api_key: str = Field(
api_key: str | None = Field(
default=None,
- description="Google API key",
description="Google API key (optional if using project_id with Application Default Credentials)",
validation_alias=AliasChoices(
- "EMBEDDINGS_GOOGLE_CLOUD_API_KEY", "GOOGLE_CLOUD_API_KEY"
"EMBEDDINGS_GOOGLE_CLOUD_API_KEY",
"GOOGLE_CLOUD_API_KEY",
"GOOGLE_API_KEY",
),
)
- project_id: str = Field(
project_id: str | None = Field(
- default="cloud-large-language-models",
default=None,
- description="GCP project ID",
description="GCP project ID (required for Vertex AI backend and legacy models)",
validation_alias=AliasChoices(
- "EMBEDDINGS_GOOGLE_CLOUD_PROJECT", "GOOGLE_CLOUD_PROJECT"
"EMBEDDINGS_GOOGLE_CLOUD_PROJECT",
"GOOGLE_CLOUD_PROJECT",
),
)
- region: str = Field(
location: str = Field(
default="us-central1",
- description="GCP region",
description="GCP region/location",
validation_alias=AliasChoices(
- "EMBEDDINGS_GOOGLE_CLOUD_REGION", "GOOGLE_CLOUD_REGION"
"EMBEDDINGS_GOOGLE_CLOUD_LOCATION",
"EMBEDDINGS_GOOGLE_CLOUD_REGION",
"GOOGLE_CLOUD_LOCATION",
"GOOGLE_CLOUD_REGION",
),
)
region: str | None = Field(
default=None,
description="Deprecated: Use 'location' instead. GCP region (kept for backwards compatibility)",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_VERTEX_REGION",
"GOOGLE_VERTEX_REGION",
),
)
task_type: str = Field(
default="RETRIEVAL_DOCUMENT",
description="Task type for embeddings (e.g., RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY). Only used with new SDK models.",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_VERTEX_TASK_TYPE",
"GOOGLE_VERTEX_TASK_TYPE",
),
)
output_dimensionality: int | None = Field(
default=None,
description="Output embedding dimensionality (optional). Only used with new SDK models.",
validation_alias=AliasChoices(
"EMBEDDINGS_GOOGLE_VERTEX_OUTPUT_DIMENSIONALITY",
"GOOGLE_VERTEX_OUTPUT_DIMENSIONALITY",
),
)

View File

@@ -173,6 +173,13 @@ class Telemetry:
self._original_handlers: dict[int, Any] = {}
if threading.current_thread() is not threading.main_thread():
logger.debug(
"CrewAI telemetry: Skipping signal handler registration "
"(not running in main thread)."
)
return
self._register_signal_handler(signal.SIGTERM, SigTermEvent, shutdown=True)
self._register_signal_handler(signal.SIGINT, SigIntEvent, shutdown=True)
if hasattr(signal, "SIGHUP"):
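The guard works because signal.signal() raises ValueError when invoked outside the main thread, so checking threading.current_thread() first avoids the noisy tracebacks. A minimal standalone sketch of the same pattern (not CrewAI code):

# Illustration of the guard: register handlers only in the main thread.
import logging
import signal
import threading

logger = logging.getLogger(__name__)

def register_handlers(handler) -> None:
    if threading.current_thread() is not threading.main_thread():
        logger.debug("Skipping signal handler registration (not main thread).")
        return
    signal.signal(signal.SIGTERM, handler)
    signal.signal(signal.SIGINT, handler)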

View File

@@ -5,17 +5,29 @@ from __future__ import annotations
import asyncio
from collections.abc import Coroutine
import concurrent.futures
import logging
from typing import TYPE_CHECKING, TypeVar
from uuid import UUID
- from aiocache import Cache # type: ignore[import-untyped]
- from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
if TYPE_CHECKING:
from aiocache import Cache
from crewai_files import FileInput
- _file_store = Cache(Cache.MEMORY, serializer=PickleSerializer())
logger = logging.getLogger(__name__)
_file_store: Cache | None = None
try:
from aiocache import Cache
from aiocache.serializers import PickleSerializer
_file_store = Cache(Cache.MEMORY, serializer=PickleSerializer())
except ImportError:
logger.debug(
"aiocache is not installed. File store features will be disabled. "
"Install with: uv add aiocache"
)
T = TypeVar("T")
@@ -59,6 +71,8 @@ async def astore_files(
files: Dictionary mapping names to file inputs.
ttl: Time-to-live in seconds.
"""
if _file_store is None:
return
await _file_store.set(f"{_CREW_PREFIX}{execution_id}", files, ttl=ttl)
@@ -71,6 +85,8 @@ async def aget_files(execution_id: UUID) -> dict[str, FileInput] | None:
Returns:
Dictionary of files or None if not found.
"""
if _file_store is None:
return None
result: dict[str, FileInput] | None = await _file_store.get(
f"{_CREW_PREFIX}{execution_id}"
)
@@ -83,6 +99,8 @@ async def aclear_files(execution_id: UUID) -> None:
Args:
execution_id: Unique identifier for the crew execution.
"""
if _file_store is None:
return
await _file_store.delete(f"{_CREW_PREFIX}{execution_id}")
@@ -98,6 +116,8 @@ async def astore_task_files(
files: Dictionary mapping names to file inputs.
ttl: Time-to-live in seconds.
"""
if _file_store is None:
return
await _file_store.set(f"{_TASK_PREFIX}{task_id}", files, ttl=ttl)
@@ -110,6 +130,8 @@ async def aget_task_files(task_id: UUID) -> dict[str, FileInput] | None:
Returns:
Dictionary of files or None if not found.
"""
if _file_store is None:
return None
result: dict[str, FileInput] | None = await _file_store.get(
f"{_TASK_PREFIX}{task_id}"
)
@@ -122,6 +144,8 @@ async def aclear_task_files(task_id: UUID) -> None:
Args:
task_id: Unique identifier for the task.
"""
if _file_store is None:
return
await _file_store.delete(f"{_TASK_PREFIX}{task_id}")

View File

@@ -1,6 +1,8 @@
interactions:
- request:
- body: '{"max_tokens":4096,"messages":[{"role":"user","content":"Say hello in French"}],"model":"claude-sonnet-4-0","stream":false,"tool_choice":{"type":"tool","name":"structured_output"},"tools":[{"name":"structured_output","description":"Returns structured data according to the schema","input_schema":{"description":"Response model for greeting test.","properties":{"greeting":{"title":"Greeting","type":"string"},"language":{"title":"Language","type":"string"}},"required":["greeting","language"],"title":"GreetingResponse","type":"object"}}]}'
body: '{"max_tokens":4096,"messages":[{"role":"user","content":"Say hello in French"}],"model":"claude-sonnet-4-0","stream":false,"tool_choice":{"type":"tool","name":"structured_output"},"tools":[{"name":"structured_output","description":"Output
the structured response","input_schema":{"type":"object","description":"Response
model for greeting test.","title":"GreetingResponse","properties":{"greeting":{"type":"string","title":"Greeting"},"language":{"type":"string","title":"Language"}},"additionalProperties":false,"required":["greeting","language"]}}]}'
headers:
User-Agent:
- X-USER-AGENT-XXX
@@ -13,7 +15,7 @@ interactions:
connection:
- keep-alive
content-length:
- - '539'
- '551'
content-type:
- application/json
host:
@@ -29,7 +31,7 @@ interactions:
x-stainless-os:
- X-STAINLESS-OS-XXX
x-stainless-package-version:
- - 0.75.0
- 0.76.0
x-stainless-retry-count:
- '0'
x-stainless-runtime:
@@ -42,7 +44,7 @@ interactions:
uri: https://api.anthropic.com/v1/messages
response:
body:
- string: '{"model":"claude-sonnet-4-20250514","id":"msg_01XjvX2nCho1knuucbwwgCpw","type":"message","role":"assistant","content":[{"type":"tool_use","id":"toolu_019rfPRSDmBb7CyCTdGMv5rK","name":"structured_output","input":{"greeting":"Bonjour","language":"French"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":432,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0},"output_tokens":53,"service_tier":"standard"}}'
string: '{"model":"claude-sonnet-4-20250514","id":"msg_01CKTyVmak15L5oQ36mv4sL9","type":"message","role":"assistant","content":[{"type":"tool_use","id":"toolu_0174BYmn6xiSnUwVhFD8S7EW","name":"structured_output","input":{"greeting":"Bonjour","language":"French"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":436,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0},"output_tokens":53,"service_tier":"standard"}}'
headers:
CF-RAY:
- CF-RAY-XXX
@@ -51,7 +53,7 @@ interactions:
Content-Type:
- application/json
Date:
- - Mon, 01 Dec 2025 11:19:38 GMT
- Mon, 26 Jan 2026 14:59:34 GMT
Server:
- cloudflare
Transfer-Encoding:
@@ -82,12 +84,10 @@ interactions:
- DYNAMIC
request-id:
- REQUEST-ID-XXX
- retry-after:
- - '24'
strict-transport-security:
- STS-XXX
x-envoy-upstream-service-time:
- - '2101'
- '968'
status:
code: 200
message: OK

View File

@@ -272,3 +272,100 @@ class TestEmbeddingFactory:
mock_build_from_provider.assert_called_once_with(mock_provider)
assert result == mock_embedding_function
mock_import.assert_not_called()
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_build_embedder_google_vertex_with_genai_model(self, mock_import):
"""Test routing to Google Vertex provider with new genai model."""
mock_provider_class = MagicMock()
mock_provider_instance = MagicMock()
mock_embedding_function = MagicMock()
mock_import.return_value = mock_provider_class
mock_provider_class.return_value = mock_provider_instance
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
config = {
"provider": "google-vertex",
"config": {
"api_key": "test-google-api-key",
"model_name": "gemini-embedding-001",
},
}
build_embedder(config)
mock_import.assert_called_once_with(
"crewai.rag.embeddings.providers.google.vertex.VertexAIProvider"
)
mock_provider_class.assert_called_once()
call_kwargs = mock_provider_class.call_args.kwargs
assert call_kwargs["api_key"] == "test-google-api-key"
assert call_kwargs["model_name"] == "gemini-embedding-001"
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_build_embedder_google_vertex_with_legacy_model(self, mock_import):
"""Test routing to Google Vertex provider with legacy textembedding-gecko model."""
mock_provider_class = MagicMock()
mock_provider_instance = MagicMock()
mock_embedding_function = MagicMock()
mock_import.return_value = mock_provider_class
mock_provider_class.return_value = mock_provider_instance
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
config = {
"provider": "google-vertex",
"config": {
"project_id": "my-gcp-project",
"region": "us-central1",
"model_name": "textembedding-gecko",
},
}
build_embedder(config)
mock_import.assert_called_once_with(
"crewai.rag.embeddings.providers.google.vertex.VertexAIProvider"
)
mock_provider_class.assert_called_once()
call_kwargs = mock_provider_class.call_args.kwargs
assert call_kwargs["project_id"] == "my-gcp-project"
assert call_kwargs["region"] == "us-central1"
assert call_kwargs["model_name"] == "textembedding-gecko"
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_build_embedder_google_vertex_with_location(self, mock_import):
"""Test routing to Google Vertex provider with location parameter."""
mock_provider_class = MagicMock()
mock_provider_instance = MagicMock()
mock_embedding_function = MagicMock()
mock_import.return_value = mock_provider_class
mock_provider_class.return_value = mock_provider_instance
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
config = {
"provider": "google-vertex",
"config": {
"project_id": "my-gcp-project",
"location": "europe-west1",
"model_name": "gemini-embedding-001",
"task_type": "RETRIEVAL_DOCUMENT",
"output_dimensionality": 768,
},
}
build_embedder(config)
mock_import.assert_called_once_with(
"crewai.rag.embeddings.providers.google.vertex.VertexAIProvider"
)
call_kwargs = mock_provider_class.call_args.kwargs
assert call_kwargs["project_id"] == "my-gcp-project"
assert call_kwargs["location"] == "europe-west1"
assert call_kwargs["model_name"] == "gemini-embedding-001"
assert call_kwargs["task_type"] == "RETRIEVAL_DOCUMENT"
assert call_kwargs["output_dimensionality"] == 768

View File

@@ -0,0 +1,176 @@
"""Integration tests for Google Vertex embeddings with Crew memory.
These tests make real API calls and use VCR to record/replay responses.
"""
import os
import threading
from collections import defaultdict
from unittest.mock import patch
import pytest
from crewai import Agent, Crew, Task
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemorySaveCompletedEvent,
MemorySaveStartedEvent,
)
@pytest.fixture(autouse=True)
def setup_vertex_ai_env():
"""Set up environment for Vertex AI tests.
Sets GOOGLE_GENAI_USE_VERTEXAI=true to ensure the SDK uses the Vertex AI
backend (aiplatform.googleapis.com) which matches the VCR cassettes.
Also mocks GOOGLE_API_KEY if not already set.
"""
env_updates = {"GOOGLE_GENAI_USE_VERTEXAI": "true"}
# Add a mock API key if none exists
if "GOOGLE_API_KEY" not in os.environ and "GEMINI_API_KEY" not in os.environ:
env_updates["GOOGLE_API_KEY"] = "test-key"
with patch.dict(os.environ, env_updates):
yield
@pytest.fixture
def google_vertex_embedder_config():
"""Fixture providing Google Vertex embedder configuration."""
return {
"provider": "google-vertex",
"config": {
"api_key": os.getenv("GOOGLE_API_KEY", "test-key"),
"model_name": "gemini-embedding-001",
},
}
@pytest.fixture
def simple_agent():
"""Fixture providing a simple test agent."""
return Agent(
role="Research Assistant",
goal="Help with research tasks",
backstory="You are a helpful research assistant.",
verbose=False,
)
@pytest.fixture
def simple_task(simple_agent):
"""Fixture providing a simple test task."""
return Task(
description="Summarize the key points about artificial intelligence in one sentence.",
expected_output="A one sentence summary about AI.",
agent=simple_agent,
)
@pytest.mark.vcr()
@pytest.mark.timeout(120) # Longer timeout for VCR recording
def test_crew_memory_with_google_vertex_embedder(
google_vertex_embedder_config, simple_agent, simple_task
) -> None:
"""Test that Crew with memory=True works with google-vertex embedder and memory is used."""
# Track memory events
events: dict[str, list] = defaultdict(list)
condition = threading.Condition()
@crewai_event_bus.on(MemorySaveStartedEvent)
def on_save_started(source, event):
with condition:
events["MemorySaveStartedEvent"].append(event)
condition.notify()
@crewai_event_bus.on(MemorySaveCompletedEvent)
def on_save_completed(source, event):
with condition:
events["MemorySaveCompletedEvent"].append(event)
condition.notify()
crew = Crew(
agents=[simple_agent],
tasks=[simple_task],
memory=True,
embedder=google_vertex_embedder_config,
verbose=False,
)
result = crew.kickoff()
assert result is not None
assert result.raw is not None
assert len(result.raw) > 0
with condition:
success = condition.wait_for(
lambda: len(events["MemorySaveCompletedEvent"]) >= 1,
timeout=10,
)
assert success, "Timeout waiting for memory save events - memory may not be working"
assert len(events["MemorySaveStartedEvent"]) >= 1, "No memory save started events"
assert len(events["MemorySaveCompletedEvent"]) >= 1, "Memory save completed events"
@pytest.mark.vcr()
@pytest.mark.timeout(120)
def test_crew_memory_with_google_vertex_project_id(simple_agent, simple_task) -> None:
"""Test Crew memory with Google Vertex using project_id authentication."""
project_id = os.getenv("GOOGLE_CLOUD_PROJECT")
if not project_id:
pytest.skip("GOOGLE_CLOUD_PROJECT environment variable not set")
# Track memory events
events: dict[str, list] = defaultdict(list)
condition = threading.Condition()
@crewai_event_bus.on(MemorySaveStartedEvent)
def on_save_started(source, event):
with condition:
events["MemorySaveStartedEvent"].append(event)
condition.notify()
@crewai_event_bus.on(MemorySaveCompletedEvent)
def on_save_completed(source, event):
with condition:
events["MemorySaveCompletedEvent"].append(event)
condition.notify()
embedder_config = {
"provider": "google-vertex",
"config": {
"project_id": project_id,
"location": "us-central1",
"model_name": "gemini-embedding-001",
},
}
crew = Crew(
agents=[simple_agent],
tasks=[simple_task],
memory=True,
embedder=embedder_config,
verbose=False,
)
result = crew.kickoff()
# Verify basic result
assert result is not None
assert result.raw is not None
# Wait for memory save events
with condition:
success = condition.wait_for(
lambda: len(events["MemorySaveCompletedEvent"]) >= 1,
timeout=10,
)
# Verify memory was actually used
assert success, "Timeout waiting for memory save events - memory may not be working"
assert len(events["MemorySaveStartedEvent"]) >= 1, "No memory save started events"
assert len(events["MemorySaveCompletedEvent"]) >= 1, "No memory save completed events"

View File

@@ -1,6 +1,6 @@
import os
import threading
- from unittest.mock import patch
from unittest.mock import MagicMock, patch
import pytest
from crewai import Agent, Crew, Task
@@ -121,3 +121,90 @@ def test_telemetry_singleton_pattern():
thread.join()
assert all(instance is telemetry1 for instance in instances)
def test_signal_handler_registration_skipped_in_non_main_thread():
"""Test that signal handler registration is skipped when running from a non-main thread.
This test verifies that when Telemetry is initialized from a non-main thread,
the signal handler registration is skipped without raising noisy ValueError tracebacks.
See: https://github.com/crewAIInc/crewAI/issues/4289
"""
Telemetry._instance = None
result = {"register_signal_handler_called": False, "error": None}
def init_telemetry_in_thread():
try:
with patch("crewai.telemetry.telemetry.TracerProvider"):
with patch.object(
Telemetry,
"_register_signal_handler",
wraps=lambda *args, **kwargs: None,
) as mock_register:
telemetry = Telemetry()
result["register_signal_handler_called"] = mock_register.called
result["telemetry"] = telemetry
except Exception as e:
result["error"] = e
thread = threading.Thread(target=init_telemetry_in_thread)
thread.start()
thread.join()
assert result["error"] is None, f"Unexpected error: {result['error']}"
assert (
result["register_signal_handler_called"] is False
), "Signal handler should not be registered in non-main thread"
def test_signal_handler_registration_skipped_logs_debug_message():
"""Test that a debug message is logged when signal handler registration is skipped.
This test verifies that when Telemetry is initialized from a non-main thread,
a debug message is logged indicating that signal handler registration was skipped.
"""
Telemetry._instance = None
result = {"telemetry": None, "error": None, "debug_calls": []}
mock_logger_debug = MagicMock()
def init_telemetry_in_thread():
try:
with patch("crewai.telemetry.telemetry.TracerProvider"):
with patch(
"crewai.telemetry.telemetry.logger.debug", mock_logger_debug
):
result["telemetry"] = Telemetry()
result["debug_calls"] = [
str(call) for call in mock_logger_debug.call_args_list
]
except Exception as e:
result["error"] = e
thread = threading.Thread(target=init_telemetry_in_thread)
thread.start()
thread.join()
assert result["error"] is None, f"Unexpected error: {result['error']}"
assert result["telemetry"] is not None
debug_calls = result["debug_calls"]
assert any(
"Skipping signal handler registration" in call for call in debug_calls
), f"Expected debug message about skipping signal handler registration, got: {debug_calls}"
def test_signal_handlers_registered_in_main_thread():
"""Test that signal handlers are registered when running from the main thread."""
Telemetry._instance = None
with patch("crewai.telemetry.telemetry.TracerProvider"):
with patch(
"crewai.telemetry.telemetry.Telemetry._register_signal_handler"
) as mock_register:
telemetry = Telemetry()
assert telemetry.ready is True
assert mock_register.call_count >= 2

View File

@@ -1,3 +1,3 @@
"""CrewAI development tools.""" """CrewAI development tools."""
__version__ = "1.8.1" __version__ = "1.9.0"

uv.lock generated
View File

@@ -310,7 +310,7 @@ wheels = [
[[package]]
name = "anthropic"
- version = "0.71.1"
version = "0.73.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "anyio" },
@@ -322,9 +322,9 @@ dependencies = [
{ name = "sniffio" }, { name = "sniffio" },
{ name = "typing-extensions" }, { name = "typing-extensions" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/05/4b/19620875841f692fdc35eb58bf0201c8ad8c47b8443fecbf1b225312175b/anthropic-0.71.1.tar.gz", hash = "sha256:a77d156d3e7d318b84681b59823b2dee48a8ac508a3e54e49f0ab0d074e4b0da", size = 493294, upload-time = "2025-10-28T17:28:42.213Z" } sdist = { url = "https://files.pythonhosted.org/packages/f0/07/f550112c3f5299d02f06580577f602e8a112b1988ad7c98ac1a8f7292d7e/anthropic-0.73.0.tar.gz", hash = "sha256:30f0d7d86390165f86af6ca7c3041f8720bb2e1b0e12a44525c8edfdbd2c5239", size = 425168, upload-time = "2025-11-14T18:47:52.635Z" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/4b/68/b2f988b13325f9ac9921b1e87f0b7994468014e1b5bd3bdbd2472f5baf45/anthropic-0.71.1-py3-none-any.whl", hash = "sha256:6ca6c579f0899a445faeeed9c0eb97aa4bdb751196262f9ccc96edfc0bb12679", size = 355020, upload-time = "2025-10-28T17:28:40.653Z" }, { url = "https://files.pythonhosted.org/packages/15/b1/5d4d3f649e151e58dc938cf19c4d0cd19fca9a986879f30fea08a7b17138/anthropic-0.73.0-py3-none-any.whl", hash = "sha256:0d56cd8b3ca3fea9c9b5162868bdfd053fbc189b8b56d4290bd2d427b56db769", size = 367839, upload-time = "2025-11-14T18:47:51.195Z" },
] ]
[[package]] [[package]]
@@ -1276,7 +1276,7 @@ requires-dist = [
{ name = "aiobotocore", marker = "extra == 'aws'", specifier = "~=2.25.2" }, { name = "aiobotocore", marker = "extra == 'aws'", specifier = "~=2.25.2" },
{ name = "aiocache", extras = ["memcached", "redis"], marker = "extra == 'a2a'", specifier = "~=0.12.3" }, { name = "aiocache", extras = ["memcached", "redis"], marker = "extra == 'a2a'", specifier = "~=0.12.3" },
{ name = "aiosqlite", specifier = "~=0.21.0" }, { name = "aiosqlite", specifier = "~=0.21.0" },
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = "~=0.71.0" }, { name = "anthropic", marker = "extra == 'anthropic'", specifier = "~=0.73.0" },
{ name = "appdirs", specifier = "~=1.4.4" }, { name = "appdirs", specifier = "~=1.4.4" },
{ name = "azure-ai-inference", marker = "extra == 'azure-ai-inference'", specifier = "~=1.0.0b9" }, { name = "azure-ai-inference", marker = "extra == 'azure-ai-inference'", specifier = "~=1.0.0b9" },
{ name = "boto3", marker = "extra == 'aws'", specifier = "~=1.40.38" }, { name = "boto3", marker = "extra == 'aws'", specifier = "~=1.40.38" },