Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-30 10:38:14 +00:00)

Compare commits: `llm-event-...` to `1.9.0` (4 commits)

Commits in this compare:
- 6b926b90d0
- fc84daadbb
- 58b866a83d
- 9797567342
@@ -401,23 +401,58 @@ crew = Crew(

### Vertex AI Embeddings

-For Google Cloud users with Vertex AI access.
+For Google Cloud users with Vertex AI access. Supports both legacy and new embedding models with automatic SDK selection.
+
+<Note>
+**Deprecation Notice:** Legacy models (`textembedding-gecko*`) use the deprecated `vertexai.language_models` SDK, which will be removed after June 24, 2026. Consider migrating to newer models like `gemini-embedding-001`. See the [Google migration guide](https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk) for details.
+</Note>

```python
+# Recommended: Using new models with the google-genai SDK
crew = Crew(
    memory=True,
    embedder={
-        "provider": "vertexai",
+        "provider": "google-vertex",
        "config": {
            "project_id": "your-gcp-project-id",
-            "region": "us-central1",  # or your preferred region
-            "api_key": "your-service-account-key",
-            "model_name": "textembedding-gecko"
+            "location": "us-central1",
+            "model_name": "gemini-embedding-001",  # or "text-embedding-005", "text-multilingual-embedding-002"
+            "task_type": "RETRIEVAL_DOCUMENT",  # Optional
+            "output_dimensionality": 768  # Optional
        }
    }
)
+
+# Using API key authentication (Exp)
+crew = Crew(
+    memory=True,
+    embedder={
+        "provider": "google-vertex",
+        "config": {
+            "api_key": "your-google-api-key",
+            "model_name": "gemini-embedding-001"
+        }
+    }
+)
+
+# Legacy models (backwards compatible, emits a deprecation warning)
+crew = Crew(
+    memory=True,
+    embedder={
+        "provider": "google-vertex",
+        "config": {
+            "project_id": "your-gcp-project-id",
+            "region": "us-central1",  # or "location" (region is deprecated)
+            "model_name": "textembedding-gecko"  # Legacy model
+        }
+    }
+)
```
+
+**Available models:**
+- **New SDK models** (recommended): `gemini-embedding-001`, `text-embedding-005`, `text-multilingual-embedding-002`
+- **Legacy models** (deprecated): `textembedding-gecko`, `textembedding-gecko@001`, `textembedding-gecko-multilingual`

### Ollama Embeddings (Local)

Run embeddings locally for privacy and cost savings.
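For reference, a typical local configuration looks roughly like the following sketch. It assumes an Ollama server running on its default port with an embedding model such as `mxbai-embed-large` already pulled; adjust the model name and URL to your setup.

```python
from crewai import Crew

# Minimal sketch of a local Ollama embedder configuration (assumed values)
crew = Crew(
    memory=True,
    embedder={
        "provider": "ollama",
        "config": {
            "model": "mxbai-embed-large",  # any embedding model available in your Ollama instance
            "url": "http://localhost:11434/api/embeddings",  # default local endpoint
        },
    },
)
```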
@@ -569,7 +604,7 @@ mem0_client_embedder_config = {
    "project_id": "my_project_id",  # Optional
    "api_key": "custom-api-key",  # Optional - overrides env var
    "run_id": "my_run_id",  # Optional - for short-term memory
    "includes": "include1",  # Optional
    "excludes": "exclude1",  # Optional
    "infer": True,  # Optional - defaults to True
    "custom_categories": new_categories  # Optional - custom categories for user memory
@@ -591,7 +626,7 @@ crew = Crew(

### Choosing the Right Embedding Provider

When selecting an embedding provider, consider factors like performance, privacy, cost, and integration needs.
Below is a comparison to help you decide:

| Provider | Best For | Pros | Cons |
@@ -749,7 +784,7 @@ Entity Memory supports batching when saving multiple entities at once. When you

This improves performance and observability when writing many entities in one operation.
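As an illustration only, a batched save might look like the sketch below. The exact field names and `save()` signature are assumptions for this example, not taken from the diff above.

```python
from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.entity.entity_memory_item import EntityMemoryItem

entity_memory = EntityMemory()  # assumes default storage and embedder configuration

# Passing a list saves all entities in one batched write (illustrative call)
entity_memory.save([
    EntityMemoryItem(
        name="CrewAI",
        type="framework",
        description="Multi-agent orchestration framework",
        relationships="used by the research crew",
    ),
    EntityMemoryItem(
        name="Mem0",
        type="service",
        description="External memory provider",
        relationships="integrates with CrewAI memory",
    ),
])
```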

## 2. External Memory

External Memory provides a standalone memory system that operates independently from the crew's built-in memory. This is ideal for specialized memory providers or cross-application memory sharing.

### Basic External Memory with Mem0
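A minimal setup might look like the following sketch. It assumes the `MEM0_API_KEY` environment variable is set; the `user_id` value is a placeholder.

```python
from crewai import Crew, Process
from crewai.memory.external.external_memory import ExternalMemory

external_memory = ExternalMemory(
    embedder_config={
        "provider": "mem0",
        "config": {"user_id": "U-123"},  # placeholder user id
    }
)

crew = Crew(
    agents=[...],
    tasks=[...],
    external_memory=external_memory,  # standalone memory, independent of memory=True
    process=Process.sequential,
)
```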
@@ -819,7 +854,7 @@ external_memory = ExternalMemory(
    "project_id": "my_project_id",  # Optional
    "api_key": "custom-api-key",  # Optional - overrides env var
    "run_id": "my_run_id",  # Optional - for short-term memory
    "includes": "include1",  # Optional
    "excludes": "exclude1",  # Optional
    "infer": True,  # Optional - defaults to True
    "custom_categories": new_categories  # Optional - custom categories for user memory
@@ -152,4 +152,4 @@ __all__ = [
    "wrap_file_source",
]

-__version__ = "1.8.1"
+__version__ = "1.9.0"
@@ -12,7 +12,7 @@ dependencies = [
    "pytube~=15.0.0",
    "requests~=2.32.5",
    "docker~=7.1.0",
-    "crewai==1.8.1",
+    "crewai==1.9.0",
    "lancedb~=0.5.4",
    "tiktoken~=0.8.0",
    "beautifulsoup4~=4.13.4",
@@ -291,4 +291,4 @@ __all__ = [
    "ZapierActionTools",
]

-__version__ = "1.8.1"
+__version__ = "1.9.0"
@@ -49,7 +49,7 @@ Repository = "https://github.com/crewAIInc/crewAI"

[project.optional-dependencies]
tools = [
-    "crewai-tools==1.8.1",
+    "crewai-tools==1.9.0",
]
embeddings = [
    "tiktoken~=0.8.0"
@@ -90,7 +90,7 @@ azure-ai-inference = [
    "azure-ai-inference~=1.0.0b9",
]
anthropic = [
-    "anthropic~=0.71.0",
+    "anthropic~=0.73.0",
]
a2a = [
    "a2a-sdk~=0.3.10",
@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:

_suppress_pydantic_deprecation_warnings()

-__version__ = "1.8.1"
+__version__ = "1.9.0"
_telemetry_submitted = False
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
-    "crewai[tools]==1.8.1"
+    "crewai[tools]==1.9.0"
]

[project.scripts]
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
-    "crewai[tools]==1.8.1"
+    "crewai[tools]==1.9.0"
]

[project.scripts]
@@ -10,7 +10,6 @@ class LLMEventBase(BaseEvent):
    from_task: Any | None = None
    from_agent: Any | None = None
    model: str | None = None
-    call_id: str

    def __init__(self, **data: Any) -> None:
        if data.get("from_task"):
@@ -37,7 +37,7 @@ from crewai.events.types.tool_usage_events import (
    ToolUsageFinishedEvent,
    ToolUsageStartedEvent,
)
-from crewai.llms.base_llm import BaseLLM, get_current_call_id, llm_call_context
+from crewai.llms.base_llm import BaseLLM
from crewai.llms.constants import (
    ANTHROPIC_MODELS,
    AZURE_MODELS,
@@ -770,7 +770,7 @@ class LLM(BaseLLM):
            chunk_content = None
            response_id = None

-            if hasattr(chunk, "id"):
+            if hasattr(chunk,'id'):
                response_id = chunk.id

            # Safely extract content from various chunk formats
@@ -827,7 +827,7 @@ class LLM(BaseLLM):
                    available_functions=available_functions,
                    from_task=from_task,
                    from_agent=from_agent,
-                    response_id=response_id,
+                    response_id=response_id
                )

                if result is not None:
@@ -849,8 +849,7 @@ class LLM(BaseLLM):
                        from_task=from_task,
                        from_agent=from_agent,
                        call_type=LLMCallType.LLM_CALL,
-                        response_id=response_id,
-                        call_id=get_current_call_id(),
+                        response_id=response_id
                    ),
                )
            # --- 4) Fallback to non-streaming if no content received
@@ -1016,10 +1015,7 @@ class LLM(BaseLLM):
            crewai_event_bus.emit(
                self,
                event=LLMCallFailedEvent(
-                    error=str(e),
-                    from_task=from_task,
-                    from_agent=from_agent,
-                    call_id=get_current_call_id(),
+                    error=str(e), from_task=from_task, from_agent=from_agent
                ),
            )
            raise Exception(f"Failed to get streaming response: {e!s}") from e
@@ -1052,8 +1048,7 @@ class LLM(BaseLLM):
                    from_task=from_task,
                    from_agent=from_agent,
                    call_type=LLMCallType.TOOL_CALL,
-                    response_id=response_id,
-                    call_id=get_current_call_id(),
+                    response_id=response_id
                ),
            )
@@ -1481,8 +1476,7 @@ class LLM(BaseLLM):
                        chunk=chunk_content,
                        from_task=from_task,
                        from_agent=from_agent,
-                        response_id=response_id,
-                        call_id=get_current_call_id(),
+                        response_id=response_id
                    ),
                )
@@ -1625,12 +1619,7 @@ class LLM(BaseLLM):
                logging.error(f"Error executing function '{function_name}': {e}")
                crewai_event_bus.emit(
                    self,
-                    event=LLMCallFailedEvent(
-                        error=f"Tool execution error: {e!s}",
-                        from_task=from_task,
-                        from_agent=from_agent,
-                        call_id=get_current_call_id(),
-                    ),
+                    event=LLMCallFailedEvent(error=f"Tool execution error: {e!s}"),
                )
                crewai_event_bus.emit(
                    self,
@@ -1680,117 +1669,108 @@ class LLM(BaseLLM):
|
||||
ValueError: If response format is not supported
|
||||
LLMContextLengthExceededError: If input exceeds model's context limit
|
||||
"""
|
||||
with llm_call_context() as call_id:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallStartedEvent(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
model=self.model,
|
||||
call_id=call_id,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallStartedEvent(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
model=self.model,
|
||||
),
|
||||
)
|
||||
|
||||
# --- 2) Validate parameters before proceeding with the call
|
||||
self._validate_call_params()
|
||||
# --- 2) Validate parameters before proceeding with the call
|
||||
self._validate_call_params()
|
||||
|
||||
# --- 3) Convert string messages to proper format if needed
|
||||
if isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
# --- 4) Handle O1 model special case (system messages not supported)
|
||||
if "o1" in self.model.lower():
|
||||
for message in messages:
|
||||
if message.get("role") == "system":
|
||||
msg_role: Literal["assistant"] = "assistant"
|
||||
message["role"] = msg_role
|
||||
# --- 3) Convert string messages to proper format if needed
|
||||
if isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
# --- 4) Handle O1 model special case (system messages not supported)
|
||||
if "o1" in self.model.lower():
|
||||
for message in messages:
|
||||
if message.get("role") == "system":
|
||||
msg_role: Literal["assistant"] = "assistant"
|
||||
message["role"] = msg_role
|
||||
|
||||
if not self._invoke_before_llm_call_hooks(messages, from_agent):
|
||||
raise ValueError("LLM call blocked by before_llm_call hook")
|
||||
if not self._invoke_before_llm_call_hooks(messages, from_agent):
|
||||
raise ValueError("LLM call blocked by before_llm_call hook")
|
||||
|
||||
# --- 5) Set up callbacks if provided
|
||||
with suppress_warnings():
|
||||
if callbacks and len(callbacks) > 0:
|
||||
self.set_callbacks(callbacks)
|
||||
try:
|
||||
# --- 6) Prepare parameters for the completion call
|
||||
params = self._prepare_completion_params(messages, tools)
|
||||
# --- 7) Make the completion call and handle response
|
||||
if self.stream:
|
||||
result = self._handle_streaming_response(
|
||||
params=params,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
else:
|
||||
result = self._handle_non_streaming_response(
|
||||
params=params,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
if isinstance(result, str):
|
||||
result = self._invoke_after_llm_call_hooks(
|
||||
messages, result, from_agent
|
||||
)
|
||||
|
||||
return result
|
||||
except LLMContextLengthExceededError:
|
||||
# Re-raise LLMContextLengthExceededError as it should be handled
|
||||
# by the CrewAgentExecutor._invoke_loop method, which can then decide
|
||||
# whether to summarize the content or abort based on the respect_context_window flag
|
||||
raise
|
||||
except Exception as e:
|
||||
unsupported_stop = "Unsupported parameter" in str(
|
||||
e
|
||||
) and "'stop'" in str(e)
|
||||
|
||||
if unsupported_stop:
|
||||
if (
|
||||
"additional_drop_params" in self.additional_params
|
||||
and isinstance(
|
||||
self.additional_params["additional_drop_params"], list
|
||||
)
|
||||
):
|
||||
self.additional_params["additional_drop_params"].append(
|
||||
"stop"
|
||||
)
|
||||
else:
|
||||
self.additional_params = {
|
||||
"additional_drop_params": ["stop"]
|
||||
}
|
||||
|
||||
logging.info("Retrying LLM call without the unsupported 'stop'")
|
||||
|
||||
return self.call(
|
||||
messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallFailedEvent(
|
||||
error=str(e),
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
call_id=get_current_call_id(),
|
||||
),
|
||||
# --- 5) Set up callbacks if provided
|
||||
with suppress_warnings():
|
||||
if callbacks and len(callbacks) > 0:
|
||||
self.set_callbacks(callbacks)
|
||||
try:
|
||||
# --- 6) Prepare parameters for the completion call
|
||||
params = self._prepare_completion_params(messages, tools)
|
||||
# --- 7) Make the completion call and handle response
|
||||
if self.stream:
|
||||
result = self._handle_streaming_response(
|
||||
params=params,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
raise
|
||||
else:
|
||||
result = self._handle_non_streaming_response(
|
||||
params=params,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
if isinstance(result, str):
|
||||
result = self._invoke_after_llm_call_hooks(
|
||||
messages, result, from_agent
|
||||
)
|
||||
|
||||
return result
|
||||
except LLMContextLengthExceededError:
|
||||
# Re-raise LLMContextLengthExceededError as it should be handled
|
||||
# by the CrewAgentExecutor._invoke_loop method, which can then decide
|
||||
# whether to summarize the content or abort based on the respect_context_window flag
|
||||
raise
|
||||
except Exception as e:
|
||||
unsupported_stop = "Unsupported parameter" in str(
|
||||
e
|
||||
) and "'stop'" in str(e)
|
||||
|
||||
if unsupported_stop:
|
||||
if (
|
||||
"additional_drop_params" in self.additional_params
|
||||
and isinstance(
|
||||
self.additional_params["additional_drop_params"], list
|
||||
)
|
||||
):
|
||||
self.additional_params["additional_drop_params"].append("stop")
|
||||
else:
|
||||
self.additional_params = {"additional_drop_params": ["stop"]}
|
||||
|
||||
logging.info("Retrying LLM call without the unsupported 'stop'")
|
||||
|
||||
return self.call(
|
||||
messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallFailedEvent(
|
||||
error=str(e), from_task=from_task, from_agent=from_agent
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
async def acall(
|
||||
self,
|
||||
@@ -1828,54 +1808,43 @@ class LLM(BaseLLM):
|
||||
ValueError: If response format is not supported
|
||||
LLMContextLengthExceededError: If input exceeds model's context limit
|
||||
"""
|
||||
with llm_call_context() as call_id:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallStartedEvent(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
model=self.model,
|
||||
call_id=call_id,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallStartedEvent(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
model=self.model,
|
||||
),
|
||||
)
|
||||
|
||||
self._validate_call_params()
|
||||
self._validate_call_params()
|
||||
|
||||
if isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
if isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
|
||||
# Process file attachments asynchronously before preparing params
|
||||
messages = await self._aprocess_message_files(messages)
|
||||
# Process file attachments asynchronously before preparing params
|
||||
messages = await self._aprocess_message_files(messages)
|
||||
|
||||
if "o1" in self.model.lower():
|
||||
for message in messages:
|
||||
if message.get("role") == "system":
|
||||
msg_role: Literal["assistant"] = "assistant"
|
||||
message["role"] = msg_role
|
||||
if "o1" in self.model.lower():
|
||||
for message in messages:
|
||||
if message.get("role") == "system":
|
||||
msg_role: Literal["assistant"] = "assistant"
|
||||
message["role"] = msg_role
|
||||
|
||||
with suppress_warnings():
|
||||
if callbacks and len(callbacks) > 0:
|
||||
self.set_callbacks(callbacks)
|
||||
try:
|
||||
params = self._prepare_completion_params(
|
||||
messages, tools, skip_file_processing=True
|
||||
)
|
||||
with suppress_warnings():
|
||||
if callbacks and len(callbacks) > 0:
|
||||
self.set_callbacks(callbacks)
|
||||
try:
|
||||
params = self._prepare_completion_params(
|
||||
messages, tools, skip_file_processing=True
|
||||
)
|
||||
|
||||
if self.stream:
|
||||
return await self._ahandle_streaming_response(
|
||||
params=params,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
return await self._ahandle_non_streaming_response(
|
||||
if self.stream:
|
||||
return await self._ahandle_streaming_response(
|
||||
params=params,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
@@ -1883,50 +1852,52 @@ class LLM(BaseLLM):
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
except LLMContextLengthExceededError:
|
||||
raise
|
||||
except Exception as e:
|
||||
unsupported_stop = "Unsupported parameter" in str(
|
||||
e
|
||||
) and "'stop'" in str(e)
|
||||
|
||||
if unsupported_stop:
|
||||
if (
|
||||
"additional_drop_params" in self.additional_params
|
||||
and isinstance(
|
||||
self.additional_params["additional_drop_params"], list
|
||||
)
|
||||
):
|
||||
self.additional_params["additional_drop_params"].append(
|
||||
"stop"
|
||||
)
|
||||
else:
|
||||
self.additional_params = {
|
||||
"additional_drop_params": ["stop"]
|
||||
}
|
||||
return await self._ahandle_non_streaming_response(
|
||||
params=params,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
except LLMContextLengthExceededError:
|
||||
raise
|
||||
except Exception as e:
|
||||
unsupported_stop = "Unsupported parameter" in str(
|
||||
e
|
||||
) and "'stop'" in str(e)
|
||||
|
||||
logging.info("Retrying LLM call without the unsupported 'stop'")
|
||||
|
||||
return await self.acall(
|
||||
messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
if unsupported_stop:
|
||||
if (
|
||||
"additional_drop_params" in self.additional_params
|
||||
and isinstance(
|
||||
self.additional_params["additional_drop_params"], list
|
||||
)
|
||||
):
|
||||
self.additional_params["additional_drop_params"].append("stop")
|
||||
else:
|
||||
self.additional_params = {"additional_drop_params": ["stop"]}
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallFailedEvent(
|
||||
error=str(e),
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
call_id=get_current_call_id(),
|
||||
),
|
||||
logging.info("Retrying LLM call without the unsupported 'stop'")
|
||||
|
||||
return await self.acall(
|
||||
messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
raise
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallFailedEvent(
|
||||
error=str(e), from_task=from_task, from_agent=from_agent
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
def _handle_emit_call_events(
|
||||
self,
|
||||
@@ -1954,7 +1925,6 @@ class LLM(BaseLLM):
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
model=self.model,
|
||||
call_id=get_current_call_id(),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -7,15 +7,11 @@ in CrewAI, including common functionality for native SDK implementations.

from __future__ import annotations

from abc import ABC, abstractmethod
-from collections.abc import Generator
-from contextlib import contextmanager
-import contextvars
from datetime import datetime
import json
import logging
import re
from typing import TYPE_CHECKING, Any, Final
-import uuid

from pydantic import BaseModel
@@ -54,32 +50,6 @@ DEFAULT_CONTEXT_WINDOW_SIZE: Final[int] = 4096
DEFAULT_SUPPORTS_STOP_WORDS: Final[bool] = True
_JSON_EXTRACTION_PATTERN: Final[re.Pattern[str]] = re.compile(r"\{.*}", re.DOTALL)

-_current_call_id: contextvars.ContextVar[str | None] = contextvars.ContextVar(
-    "_current_call_id", default=None
-)
-
-
-@contextmanager
-def llm_call_context() -> Generator[str, None, None]:
-    """Context manager that establishes an LLM call scope with a unique call_id."""
-    call_id = str(uuid.uuid4())
-    token = _current_call_id.set(call_id)
-    try:
-        yield call_id
-    finally:
-        _current_call_id.reset(token)
-
-
-def get_current_call_id() -> str:
-    """Get current call_id from context"""
-    call_id = _current_call_id.get()
-    if call_id is None:
-        logging.warning(
-            "LLM event emitted outside call context - generating fallback call_id"
-        )
-        return str(uuid.uuid4())
-    return call_id
-

class BaseLLM(ABC):
    """Abstract base class for LLM implementations.
@@ -381,7 +351,6 @@ class BaseLLM(ABC):
                from_task=from_task,
                from_agent=from_agent,
                model=self.model,
-                call_id=get_current_call_id(),
            ),
        )
@@ -405,7 +374,6 @@ class BaseLLM(ABC):
                from_task=from_task,
                from_agent=from_agent,
                model=self.model,
-                call_id=get_current_call_id(),
            ),
        )
@@ -426,7 +394,6 @@ class BaseLLM(ABC):
                from_task=from_task,
                from_agent=from_agent,
                model=self.model,
-                call_id=get_current_call_id(),
            ),
        )
@@ -437,7 +404,7 @@ class BaseLLM(ABC):
        from_agent: Agent | None = None,
        tool_call: dict[str, Any] | None = None,
        call_type: LLMCallType | None = None,
-        response_id: str | None = None,
+        response_id: str | None = None
    ) -> None:
        """Emit stream chunk event.
@@ -460,8 +427,7 @@ class BaseLLM(ABC):
                from_task=from_task,
                from_agent=from_agent,
                call_type=call_type,
-                response_id=response_id,
-                call_id=get_current_call_id(),
+                response_id=response_id
            ),
        )
@@ -3,13 +3,12 @@ from __future__ import annotations

import json
import logging
import os
-from typing import TYPE_CHECKING, Any, Literal, cast
+from typing import TYPE_CHECKING, Any, Final, Literal, TypeGuard, cast

from anthropic.types import ThinkingBlock
from pydantic import BaseModel

from crewai.events.types.llm_events import LLMCallType
-from crewai.llms.base_llm import BaseLLM, llm_call_context
+from crewai.llms.base_llm import BaseLLM
from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport
from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.exceptions.context_window_exceeding_exception import (
@@ -22,8 +21,9 @@ if TYPE_CHECKING:
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
|
||||
try:
|
||||
from anthropic import Anthropic, AsyncAnthropic
|
||||
from anthropic import Anthropic, AsyncAnthropic, transform_schema
|
||||
from anthropic.types import Message, TextBlock, ThinkingBlock, ToolUseBlock
|
||||
from anthropic.types.beta import BetaMessage
|
||||
import httpx
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
@@ -31,7 +31,62 @@ except ImportError:
|
||||
) from None
|
||||
|
||||
|
||||
ANTHROPIC_FILES_API_BETA = "files-api-2025-04-14"
|
||||
ANTHROPIC_FILES_API_BETA: Final = "files-api-2025-04-14"
|
||||
ANTHROPIC_STRUCTURED_OUTPUTS_BETA: Final = "structured-outputs-2025-11-13"
|
||||
|
||||
NATIVE_STRUCTURED_OUTPUT_MODELS: Final[
|
||||
tuple[
|
||||
Literal["claude-sonnet-4-5"],
|
||||
Literal["claude-sonnet-4.5"],
|
||||
Literal["claude-opus-4-5"],
|
||||
Literal["claude-opus-4.5"],
|
||||
Literal["claude-opus-4-1"],
|
||||
Literal["claude-opus-4.1"],
|
||||
Literal["claude-haiku-4-5"],
|
||||
Literal["claude-haiku-4.5"],
|
||||
]
|
||||
] = (
|
||||
"claude-sonnet-4-5",
|
||||
"claude-sonnet-4.5",
|
||||
"claude-opus-4-5",
|
||||
"claude-opus-4.5",
|
||||
"claude-opus-4-1",
|
||||
"claude-opus-4.1",
|
||||
"claude-haiku-4-5",
|
||||
"claude-haiku-4.5",
|
||||
)
|
||||
|
||||
|
||||
def _supports_native_structured_outputs(model: str) -> bool:
|
||||
"""Check if the model supports native structured outputs.
|
||||
|
||||
Native structured outputs are only available for Claude 4.5 models
|
||||
(Sonnet 4.5, Opus 4.5, Opus 4.1, Haiku 4.5).
|
||||
Other models require the tool-based fallback approach.
|
||||
|
||||
Args:
|
||||
model: The model name/identifier.
|
||||
|
||||
Returns:
|
||||
True if the model supports native structured outputs.
|
||||
"""
|
||||
model_lower = model.lower()
|
||||
return any(prefix in model_lower for prefix in NATIVE_STRUCTURED_OUTPUT_MODELS)
|
||||
|
||||
|
||||
def _is_pydantic_model_class(obj: Any) -> TypeGuard[type[BaseModel]]:
|
||||
"""Check if an object is a Pydantic model class.
|
||||
|
||||
This distinguishes between Pydantic model classes that support structured
|
||||
outputs (have model_json_schema) and plain dicts like {"type": "json_object"}.
|
||||
|
||||
Args:
|
||||
obj: The object to check.
|
||||
|
||||
Returns:
|
||||
True if obj is a Pydantic model class.
|
||||
"""
|
||||
return isinstance(obj, type) and issubclass(obj, BaseModel)
|
||||
|
||||
|
||||
def _contains_file_id_reference(messages: list[dict[str, Any]]) -> bool:
|
||||
@@ -84,6 +139,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
client_params: dict[str, Any] | None = None,
|
||||
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
|
||||
thinking: AnthropicThinkingConfig | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize Anthropic chat completion client.
|
||||
@@ -101,6 +157,8 @@ class AnthropicCompletion(BaseLLM):
|
||||
stream: Enable streaming responses
|
||||
client_params: Additional parameters for the Anthropic client
|
||||
interceptor: HTTP interceptor for modifying requests/responses at transport level.
|
||||
response_format: Pydantic model for structured output. When provided, responses
|
||||
will be validated against this model schema.
|
||||
**kwargs: Additional parameters
|
||||
"""
|
||||
super().__init__(
|
||||
@@ -131,6 +189,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
self.stop_sequences = stop_sequences or []
|
||||
self.thinking = thinking
|
||||
self.previous_thinking_blocks: list[ThinkingBlock] = []
|
||||
self.response_format = response_format
|
||||
# Model-specific settings
|
||||
self.is_claude_3 = "claude-3" in model.lower()
|
||||
self.supports_tools = True
|
||||
@@ -207,58 +266,57 @@ class AnthropicCompletion(BaseLLM):
|
||||
Returns:
|
||||
Chat completion response or tool call result
|
||||
"""
|
||||
with llm_call_context():
|
||||
try:
|
||||
# Emit call started event
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
try:
|
||||
# Emit call started event
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
# Format messages for Anthropic
|
||||
formatted_messages, system_message = (
|
||||
self._format_messages_for_anthropic(messages)
|
||||
)
|
||||
# Format messages for Anthropic
|
||||
formatted_messages, system_message = self._format_messages_for_anthropic(
|
||||
messages
|
||||
)
|
||||
|
||||
if not self._invoke_before_llm_call_hooks(
|
||||
formatted_messages, from_agent
|
||||
):
|
||||
raise ValueError("LLM call blocked by before_llm_call hook")
|
||||
if not self._invoke_before_llm_call_hooks(formatted_messages, from_agent):
|
||||
raise ValueError("LLM call blocked by before_llm_call hook")
|
||||
|
||||
# Prepare completion parameters
|
||||
completion_params = self._prepare_completion_params(
|
||||
formatted_messages, system_message, tools
|
||||
)
|
||||
# Prepare completion parameters
|
||||
completion_params = self._prepare_completion_params(
|
||||
formatted_messages, system_message, tools
|
||||
)
|
||||
|
||||
# Handle streaming vs non-streaming
|
||||
if self.stream:
|
||||
return self._handle_streaming_completion(
|
||||
completion_params,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
)
|
||||
effective_response_model = response_model or self.response_format
|
||||
|
||||
return self._handle_completion(
|
||||
# Handle streaming vs non-streaming
|
||||
if self.stream:
|
||||
return self._handle_streaming_completion(
|
||||
completion_params,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
effective_response_model,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Anthropic API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
return self._handle_completion(
|
||||
completion_params,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
effective_response_model,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Anthropic API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
|
||||
async def acall(
|
||||
self,
|
||||
@@ -283,49 +341,50 @@ class AnthropicCompletion(BaseLLM):
|
||||
Returns:
|
||||
Chat completion response or tool call result
|
||||
"""
|
||||
with llm_call_context():
|
||||
try:
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
try:
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
formatted_messages, system_message = (
|
||||
self._format_messages_for_anthropic(messages)
|
||||
)
|
||||
formatted_messages, system_message = self._format_messages_for_anthropic(
|
||||
messages
|
||||
)
|
||||
|
||||
completion_params = self._prepare_completion_params(
|
||||
formatted_messages, system_message, tools
|
||||
)
|
||||
completion_params = self._prepare_completion_params(
|
||||
formatted_messages, system_message, tools
|
||||
)
|
||||
|
||||
if self.stream:
|
||||
return await self._ahandle_streaming_completion(
|
||||
completion_params,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
)
|
||||
effective_response_model = response_model or self.response_format
|
||||
|
||||
return await self._ahandle_completion(
|
||||
if self.stream:
|
||||
return await self._ahandle_streaming_completion(
|
||||
completion_params,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
effective_response_model,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Anthropic API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
return await self._ahandle_completion(
|
||||
completion_params,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
effective_response_model,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Anthropic API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
|
||||
def _prepare_completion_params(
|
||||
self,
|
||||
@@ -569,22 +628,40 @@ class AnthropicCompletion(BaseLLM):
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle non-streaming message completion."""
|
||||
if response_model:
|
||||
structured_tool = {
|
||||
"name": "structured_output",
|
||||
"description": "Returns structured data according to the schema",
|
||||
"input_schema": response_model.model_json_schema(),
|
||||
}
|
||||
|
||||
params["tools"] = [structured_tool]
|
||||
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
||||
|
||||
uses_file_api = _contains_file_id_reference(params.get("messages", []))
|
||||
betas: list[str] = []
|
||||
use_native_structured_output = False
|
||||
|
||||
if uses_file_api:
|
||||
betas.append(ANTHROPIC_FILES_API_BETA)
|
||||
|
||||
extra_body: dict[str, Any] | None = None
|
||||
if _is_pydantic_model_class(response_model):
|
||||
schema = transform_schema(response_model.model_json_schema())
|
||||
if _supports_native_structured_outputs(self.model):
|
||||
use_native_structured_output = True
|
||||
betas.append(ANTHROPIC_STRUCTURED_OUTPUTS_BETA)
|
||||
extra_body = {
|
||||
"output_format": {
|
||||
"type": "json_schema",
|
||||
"schema": schema,
|
||||
}
|
||||
}
|
||||
else:
|
||||
structured_tool = {
|
||||
"name": "structured_output",
|
||||
"description": "Output the structured response",
|
||||
"input_schema": schema,
|
||||
}
|
||||
params["tools"] = [structured_tool]
|
||||
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
||||
|
||||
try:
|
||||
if uses_file_api:
|
||||
params["betas"] = [ANTHROPIC_FILES_API_BETA]
|
||||
response = self.client.beta.messages.create(**params)
|
||||
if betas:
|
||||
params["betas"] = betas
|
||||
response = self.client.beta.messages.create(
|
||||
**params, extra_body=extra_body
|
||||
)
|
||||
else:
|
||||
response = self.client.messages.create(**params)
|
||||
|
||||
@@ -597,22 +674,34 @@ class AnthropicCompletion(BaseLLM):
|
||||
usage = self._extract_anthropic_token_usage(response)
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
if response_model and response.content:
|
||||
tool_uses = [
|
||||
block for block in response.content if isinstance(block, ToolUseBlock)
|
||||
]
|
||||
if tool_uses and tool_uses[0].name == "structured_output":
|
||||
structured_data = tool_uses[0].input
|
||||
structured_json = json.dumps(structured_data)
|
||||
self._emit_call_completed_event(
|
||||
response=structured_json,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
return structured_json
|
||||
if _is_pydantic_model_class(response_model) and response.content:
|
||||
if use_native_structured_output:
|
||||
for block in response.content:
|
||||
if isinstance(block, TextBlock):
|
||||
structured_json = block.text
|
||||
self._emit_call_completed_event(
|
||||
response=structured_json,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
return structured_json
|
||||
else:
|
||||
for block in response.content:
|
||||
if (
|
||||
isinstance(block, ToolUseBlock)
|
||||
and block.name == "structured_output"
|
||||
):
|
||||
structured_json = json.dumps(block.input)
|
||||
self._emit_call_completed_event(
|
||||
response=structured_json,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
return structured_json
|
||||
|
||||
# Check if Claude wants to use tools
|
||||
if response.content:
|
||||
@@ -682,17 +771,31 @@ class AnthropicCompletion(BaseLLM):
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str:
|
||||
) -> str | Any:
|
||||
"""Handle streaming message completion."""
|
||||
if response_model:
|
||||
structured_tool = {
|
||||
"name": "structured_output",
|
||||
"description": "Returns structured data according to the schema",
|
||||
"input_schema": response_model.model_json_schema(),
|
||||
}
|
||||
betas: list[str] = []
|
||||
use_native_structured_output = False
|
||||
|
||||
params["tools"] = [structured_tool]
|
||||
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
||||
extra_body: dict[str, Any] | None = None
|
||||
if _is_pydantic_model_class(response_model):
|
||||
schema = transform_schema(response_model.model_json_schema())
|
||||
if _supports_native_structured_outputs(self.model):
|
||||
use_native_structured_output = True
|
||||
betas.append(ANTHROPIC_STRUCTURED_OUTPUTS_BETA)
|
||||
extra_body = {
|
||||
"output_format": {
|
||||
"type": "json_schema",
|
||||
"schema": schema,
|
||||
}
|
||||
}
|
||||
else:
|
||||
structured_tool = {
|
||||
"name": "structured_output",
|
||||
"description": "Output the structured response",
|
||||
"input_schema": schema,
|
||||
}
|
||||
params["tools"] = [structured_tool]
|
||||
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
||||
|
||||
full_response = ""
|
||||
|
||||
@@ -700,10 +803,17 @@ class AnthropicCompletion(BaseLLM):
|
||||
# (the SDK sets it internally)
|
||||
stream_params = {k: v for k, v in params.items() if k != "stream"}
|
||||
|
||||
if betas:
|
||||
stream_params["betas"] = betas
|
||||
|
||||
current_tool_calls: dict[int, dict[str, Any]] = {}
|
||||
|
||||
# Make streaming API call
|
||||
with self.client.messages.stream(**stream_params) as stream:
|
||||
stream_context = (
|
||||
self.client.beta.messages.stream(**stream_params, extra_body=extra_body)
|
||||
if betas
|
||||
else self.client.messages.stream(**stream_params)
|
||||
)
|
||||
with stream_context as stream:
|
||||
response_id = None
|
||||
for event in stream:
|
||||
if hasattr(event, "message") and hasattr(event.message, "id"):
|
||||
@@ -770,7 +880,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
response_id=response_id,
|
||||
)
|
||||
|
||||
final_message: Message = stream.get_final_message()
|
||||
final_message = stream.get_final_message()
|
||||
|
||||
thinking_blocks: list[ThinkingBlock] = []
|
||||
if final_message.content:
|
||||
@@ -785,25 +895,30 @@ class AnthropicCompletion(BaseLLM):
|
||||
usage = self._extract_anthropic_token_usage(final_message)
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
if response_model and final_message.content:
|
||||
tool_uses = [
|
||||
block
|
||||
for block in final_message.content
|
||||
if isinstance(block, ToolUseBlock)
|
||||
]
|
||||
if tool_uses and tool_uses[0].name == "structured_output":
|
||||
structured_data = tool_uses[0].input
|
||||
structured_json = json.dumps(structured_data)
|
||||
|
||||
if _is_pydantic_model_class(response_model):
|
||||
if use_native_structured_output:
|
||||
self._emit_call_completed_event(
|
||||
response=structured_json,
|
||||
response=full_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
return structured_json
|
||||
return full_response
|
||||
for block in final_message.content:
|
||||
if (
|
||||
isinstance(block, ToolUseBlock)
|
||||
and block.name == "structured_output"
|
||||
):
|
||||
structured_json = json.dumps(block.input)
|
||||
self._emit_call_completed_event(
|
||||
response=structured_json,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
return structured_json
|
||||
|
||||
if final_message.content:
|
||||
tool_uses = [
|
||||
@@ -813,11 +928,9 @@ class AnthropicCompletion(BaseLLM):
|
||||
]
|
||||
|
||||
if tool_uses:
|
||||
# If no available_functions, return tool calls for executor to handle
|
||||
if not available_functions:
|
||||
return list(tool_uses)
|
||||
|
||||
# Handle tool use conversation flow internally
|
||||
return self._handle_tool_use_conversation(
|
||||
final_message,
|
||||
tool_uses,
|
||||
@@ -827,10 +940,8 @@ class AnthropicCompletion(BaseLLM):
|
||||
from_agent,
|
||||
)
|
||||
|
||||
# Apply stop words to full response
|
||||
full_response = self._apply_stop_words(full_response)
|
||||
|
||||
# Emit completion event and return full response
|
||||
self._emit_call_completed_event(
|
||||
response=full_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
@@ -888,7 +999,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
|
||||
def _handle_tool_use_conversation(
|
||||
self,
|
||||
initial_response: Message,
|
||||
initial_response: Message | BetaMessage,
|
||||
tool_uses: list[ToolUseBlock],
|
||||
params: dict[str, Any],
|
||||
available_functions: dict[str, Any],
|
||||
@@ -1006,22 +1117,40 @@ class AnthropicCompletion(BaseLLM):
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle non-streaming async message completion."""
|
||||
if response_model:
|
||||
structured_tool = {
|
||||
"name": "structured_output",
|
||||
"description": "Returns structured data according to the schema",
|
||||
"input_schema": response_model.model_json_schema(),
|
||||
}
|
||||
|
||||
params["tools"] = [structured_tool]
|
||||
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
||||
|
||||
uses_file_api = _contains_file_id_reference(params.get("messages", []))
|
||||
betas: list[str] = []
|
||||
use_native_structured_output = False
|
||||
|
||||
if uses_file_api:
|
||||
betas.append(ANTHROPIC_FILES_API_BETA)
|
||||
|
||||
extra_body: dict[str, Any] | None = None
|
||||
if _is_pydantic_model_class(response_model):
|
||||
schema = transform_schema(response_model.model_json_schema())
|
||||
if _supports_native_structured_outputs(self.model):
|
||||
use_native_structured_output = True
|
||||
betas.append(ANTHROPIC_STRUCTURED_OUTPUTS_BETA)
|
||||
extra_body = {
|
||||
"output_format": {
|
||||
"type": "json_schema",
|
||||
"schema": schema,
|
||||
}
|
||||
}
|
||||
else:
|
||||
structured_tool = {
|
||||
"name": "structured_output",
|
||||
"description": "Output the structured response",
|
||||
"input_schema": schema,
|
||||
}
|
||||
params["tools"] = [structured_tool]
|
||||
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
||||
|
||||
try:
|
||||
if uses_file_api:
|
||||
params["betas"] = [ANTHROPIC_FILES_API_BETA]
|
||||
response = await self.async_client.beta.messages.create(**params)
|
||||
if betas:
|
||||
params["betas"] = betas
|
||||
response = await self.async_client.beta.messages.create(
|
||||
**params, extra_body=extra_body
|
||||
)
|
||||
else:
|
||||
response = await self.async_client.messages.create(**params)
|
||||
|
||||
@@ -1034,23 +1163,34 @@ class AnthropicCompletion(BaseLLM):
|
||||
usage = self._extract_anthropic_token_usage(response)
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
if response_model and response.content:
|
||||
tool_uses = [
|
||||
block for block in response.content if isinstance(block, ToolUseBlock)
|
||||
]
|
||||
if tool_uses and tool_uses[0].name == "structured_output":
|
||||
structured_data = tool_uses[0].input
|
||||
structured_json = json.dumps(structured_data)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=structured_json,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
return structured_json
|
||||
if _is_pydantic_model_class(response_model) and response.content:
|
||||
if use_native_structured_output:
|
||||
for block in response.content:
|
||||
if isinstance(block, TextBlock):
|
||||
structured_json = block.text
|
||||
self._emit_call_completed_event(
|
||||
response=structured_json,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
return structured_json
|
||||
else:
|
||||
for block in response.content:
|
||||
if (
|
||||
isinstance(block, ToolUseBlock)
|
||||
and block.name == "structured_output"
|
||||
):
|
||||
structured_json = json.dumps(block.input)
|
||||
self._emit_call_completed_event(
|
||||
response=structured_json,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
return structured_json
|
||||
|
||||
if response.content:
|
||||
tool_uses = [
|
||||
@@ -1106,25 +1246,49 @@ class AnthropicCompletion(BaseLLM):
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str:
|
||||
) -> str | Any:
|
||||
"""Handle async streaming message completion."""
|
||||
if response_model:
|
||||
structured_tool = {
|
||||
"name": "structured_output",
|
||||
"description": "Returns structured data according to the schema",
|
||||
"input_schema": response_model.model_json_schema(),
|
||||
}
|
||||
betas: list[str] = []
|
||||
use_native_structured_output = False
|
||||
|
||||
params["tools"] = [structured_tool]
|
||||
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
||||
extra_body: dict[str, Any] | None = None
|
||||
if _is_pydantic_model_class(response_model):
|
||||
schema = transform_schema(response_model.model_json_schema())
|
||||
if _supports_native_structured_outputs(self.model):
|
||||
use_native_structured_output = True
|
||||
betas.append(ANTHROPIC_STRUCTURED_OUTPUTS_BETA)
|
||||
extra_body = {
|
||||
"output_format": {
|
||||
"type": "json_schema",
|
||||
"schema": schema,
|
||||
}
|
||||
}
|
||||
else:
|
||||
structured_tool = {
|
||||
"name": "structured_output",
|
||||
"description": "Output the structured response",
|
||||
"input_schema": schema,
|
||||
}
|
||||
params["tools"] = [structured_tool]
|
||||
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
|
||||
|
||||
full_response = ""
|
||||
|
||||
stream_params = {k: v for k, v in params.items() if k != "stream"}
|
||||
|
||||
if betas:
|
||||
stream_params["betas"] = betas
|
||||
|
||||
current_tool_calls: dict[int, dict[str, Any]] = {}
|
||||
|
||||
async with self.async_client.messages.stream(**stream_params) as stream:
|
||||
stream_context = (
|
||||
self.async_client.beta.messages.stream(
|
||||
**stream_params, extra_body=extra_body
|
||||
)
|
||||
if betas
|
||||
else self.async_client.messages.stream(**stream_params)
|
||||
)
|
||||
async with stream_context as stream:
|
||||
response_id = None
|
||||
async for event in stream:
|
||||
if hasattr(event, "message") and hasattr(event.message, "id"):
|
||||
@@ -1191,30 +1355,35 @@ class AnthropicCompletion(BaseLLM):
|
||||
response_id=response_id,
|
||||
)
|
||||
|
||||
final_message: Message = await stream.get_final_message()
|
||||
final_message = await stream.get_final_message()
|
||||
|
||||
usage = self._extract_anthropic_token_usage(final_message)
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
if response_model and final_message.content:
|
||||
tool_uses = [
|
||||
block
|
||||
for block in final_message.content
|
||||
if isinstance(block, ToolUseBlock)
|
||||
]
|
||||
if tool_uses and tool_uses[0].name == "structured_output":
|
||||
structured_data = tool_uses[0].input
|
||||
structured_json = json.dumps(structured_data)
|
||||
|
||||
if _is_pydantic_model_class(response_model):
|
||||
if use_native_structured_output:
|
||||
self._emit_call_completed_event(
|
||||
response=structured_json,
|
||||
response=full_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
return structured_json
|
||||
return full_response
|
||||
for block in final_message.content:
|
||||
if (
|
||||
isinstance(block, ToolUseBlock)
|
||||
and block.name == "structured_output"
|
||||
):
|
||||
structured_json = json.dumps(block.input)
|
||||
self._emit_call_completed_event(
|
||||
response=structured_json,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
return structured_json
|
||||
|
||||
if final_message.content:
|
||||
tool_uses = [
|
||||
@@ -1224,7 +1393,6 @@ class AnthropicCompletion(BaseLLM):
|
||||
]
|
||||
|
||||
if tool_uses:
|
||||
# If no available_functions, return tool calls for executor to handle
|
||||
if not available_functions:
|
||||
return list(tool_uses)
|
||||
|
||||
@@ -1251,7 +1419,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
|
||||
async def _ahandle_tool_use_conversation(
|
||||
self,
|
||||
initial_response: Message,
|
||||
initial_response: Message | BetaMessage,
|
||||
tool_uses: list[ToolUseBlock],
|
||||
params: dict[str, Any],
|
||||
available_functions: dict[str, Any],
|
||||
@@ -1360,7 +1528,9 @@ class AnthropicCompletion(BaseLLM):
|
||||
return int(200000 * CONTEXT_WINDOW_USAGE_RATIO)
|
||||
|
||||
@staticmethod
|
||||
def _extract_anthropic_token_usage(response: Message) -> dict[str, Any]:
|
||||
def _extract_anthropic_token_usage(
|
||||
response: Message | BetaMessage,
|
||||
) -> dict[str, Any]:
|
||||
"""Extract token usage from Anthropic response."""
|
||||
if hasattr(response, "usage") and response.usage:
|
||||
usage = response.usage
|
||||
|
||||
@@ -43,7 +43,7 @@ try:
|
||||
)
|
||||
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
from crewai.llms.base_llm import BaseLLM, llm_call_context
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
@@ -92,6 +92,7 @@ class AzureCompletion(BaseLLM):
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = False,
|
||||
interceptor: BaseInterceptor[Any, Any] | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize Azure AI Inference chat completion client.
|
||||
@@ -111,6 +112,9 @@ class AzureCompletion(BaseLLM):
|
||||
stop: Stop sequences
|
||||
stream: Enable streaming responses
|
||||
interceptor: HTTP interceptor (not yet supported for Azure).
|
||||
response_format: Pydantic model for structured output. Used as default when
|
||||
response_model is not passed to call()/acall() methods.
|
||||
Only works with OpenAI models deployed on Azure.
|
||||
**kwargs: Additional parameters
|
||||
"""
|
||||
if interceptor is not None:
|
||||
@@ -165,6 +169,7 @@ class AzureCompletion(BaseLLM):
|
||||
self.presence_penalty = presence_penalty
|
||||
self.max_tokens = max_tokens
|
||||
self.stream = stream
|
||||
self.response_format = response_format
|
||||
|
||||
self.is_openai_model = any(
|
||||
prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
|
||||
@@ -288,51 +293,49 @@ class AzureCompletion(BaseLLM):
|
||||
Returns:
|
||||
Chat completion response or tool call result
|
||||
"""
|
||||
with llm_call_context():
|
||||
try:
|
||||
# Emit call started event
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
try:
|
||||
# Emit call started event
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
effective_response_model = response_model or self.response_format
|
||||
|
||||
# Format messages for Azure
|
||||
formatted_messages = self._format_messages_for_azure(messages)
|
||||
# Format messages for Azure
|
||||
formatted_messages = self._format_messages_for_azure(messages)
|
||||
|
||||
if not self._invoke_before_llm_call_hooks(
|
||||
formatted_messages, from_agent
|
||||
):
|
||||
raise ValueError("LLM call blocked by before_llm_call hook")
|
||||
if not self._invoke_before_llm_call_hooks(formatted_messages, from_agent):
|
||||
raise ValueError("LLM call blocked by before_llm_call hook")
|
||||
|
||||
# Prepare completion parameters
|
||||
completion_params = self._prepare_completion_params(
|
||||
formatted_messages, tools, response_model
|
||||
)
|
||||
# Prepare completion parameters
|
||||
completion_params = self._prepare_completion_params(
|
||||
formatted_messages, tools, effective_response_model
|
||||
)
|
||||
|
||||
# Handle streaming vs non-streaming
|
||||
if self.stream:
|
||||
return self._handle_streaming_completion(
|
||||
completion_params,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
)
|
||||
|
||||
return self._handle_completion(
|
||||
# Handle streaming vs non-streaming
|
||||
if self.stream:
|
||||
return self._handle_streaming_completion(
|
||||
completion_params,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
effective_response_model,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return self._handle_api_error(e, from_task, from_agent) # type: ignore[func-returns-value]
|
||||
return self._handle_completion(
|
||||
completion_params,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
effective_response_model,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return self._handle_api_error(e, from_task, from_agent) # type: ignore[func-returns-value]
|
||||
|
||||
async def acall( # type: ignore[return]
|
||||
self,
|
||||
@@ -358,42 +361,42 @@ class AzureCompletion(BaseLLM):
|
||||
Returns:
|
||||
Chat completion response or tool call result
|
||||
"""
|
||||
with llm_call_context():
|
||||
try:
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
try:
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
effective_response_model = response_model or self.response_format
|
||||
|
||||
formatted_messages = self._format_messages_for_azure(messages)
|
||||
formatted_messages = self._format_messages_for_azure(messages)
|
||||
|
||||
completion_params = self._prepare_completion_params(
|
||||
formatted_messages, tools, response_model
|
||||
)
|
||||
completion_params = self._prepare_completion_params(
|
||||
formatted_messages, tools, effective_response_model
|
||||
)
|
||||
|
||||
if self.stream:
|
||||
return await self._ahandle_streaming_completion(
|
||||
completion_params,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
)
|
||||
|
||||
return await self._ahandle_completion(
|
||||
if self.stream:
|
||||
return await self._ahandle_streaming_completion(
|
||||
completion_params,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
effective_response_model,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self._handle_api_error(e, from_task, from_agent)
|
||||
return await self._ahandle_completion(
|
||||
completion_params,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
effective_response_model,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self._handle_api_error(e, from_task, from_agent)
|
||||
|
||||
def _prepare_completion_params(
|
||||
self,
|
||||
|
||||
@@ -11,7 +11,7 @@ from pydantic import BaseModel
 from typing_extensions import Required

 from crewai.events.types.llm_events import LLMCallType
-from crewai.llms.base_llm import BaseLLM, llm_call_context
+from crewai.llms.base_llm import BaseLLM
 from crewai.utilities.agent_utils import is_context_length_exceeded
 from crewai.utilities.exceptions.context_window_exceeding_exception import (
     LLMContextLengthExceededError,
@@ -172,6 +172,7 @@ class BedrockCompletion(BaseLLM):
         additional_model_request_fields: dict[str, Any] | None = None,
         additional_model_response_field_paths: list[str] | None = None,
         interceptor: BaseInterceptor[Any, Any] | None = None,
+        response_format: type[BaseModel] | None = None,
         **kwargs: Any,
     ) -> None:
         """Initialize AWS Bedrock completion client.
@@ -192,6 +193,8 @@ class BedrockCompletion(BaseLLM):
             additional_model_request_fields: Model-specific request parameters
             additional_model_response_field_paths: Custom response field paths
             interceptor: HTTP interceptor (not yet supported for Bedrock).
+            response_format: Pydantic model for structured output. Used as default when
+                response_model is not passed to call()/acall() methods.
             **kwargs: Additional parameters
         """
         if interceptor is not None:
@@ -248,6 +251,7 @@ class BedrockCompletion(BaseLLM):
         self.top_k = top_k
         self.stream = stream
         self.stop_sequences = stop_sequences
+        self.response_format = response_format

         # Store advanced features (optional)
         self.guardrail_config = guardrail_config
|
||||
@@ -299,107 +303,107 @@ class BedrockCompletion(BaseLLM):
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Call AWS Bedrock Converse API."""
|
||||
with llm_call_context():
|
||||
try:
|
||||
# Emit call started event
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
effective_response_model = response_model or self.response_format
|
||||
|
||||
try:
|
||||
# Emit call started event
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
# Format messages for Converse API
|
||||
formatted_messages, system_message = self._format_messages_for_converse(
|
||||
messages
|
||||
)
|
||||
|
||||
if not self._invoke_before_llm_call_hooks(formatted_messages, from_agent):
|
||||
raise ValueError("LLM call blocked by before_llm_call hook")
|
||||
|
||||
# Prepare request body
|
||||
body: BedrockConverseRequestBody = {
|
||||
"inferenceConfig": self._get_inference_config(),
|
||||
}
|
||||
|
||||
# Add system message if present
|
||||
if system_message:
|
||||
body["system"] = cast(
|
||||
"list[SystemContentBlockTypeDef]",
|
||||
cast(object, [{"text": system_message}]),
|
||||
)
|
||||
|
||||
# Format messages for Converse API
|
||||
formatted_messages, system_message = self._format_messages_for_converse(
|
||||
messages
|
||||
)
|
||||
|
||||
if not self._invoke_before_llm_call_hooks(
|
||||
formatted_messages, from_agent
|
||||
):
|
||||
raise ValueError("LLM call blocked by before_llm_call hook")
|
||||
|
||||
# Prepare request body
|
||||
body: BedrockConverseRequestBody = {
|
||||
"inferenceConfig": self._get_inference_config(),
|
||||
# Add tool config if present or if messages contain tool content
|
||||
# Bedrock requires toolConfig when messages have toolUse/toolResult
|
||||
if tools:
|
||||
tool_config: ToolConfigurationTypeDef = {
|
||||
"tools": cast(
|
||||
"Sequence[ToolTypeDef]",
|
||||
cast(object, self._format_tools_for_converse(tools)),
|
||||
)
|
||||
}
|
||||
|
||||
# Add system message if present
|
||||
if system_message:
|
||||
body["system"] = cast(
|
||||
"list[SystemContentBlockTypeDef]",
|
||||
cast(object, [{"text": system_message}]),
|
||||
body["toolConfig"] = tool_config
|
||||
elif self._messages_contain_tool_content(formatted_messages):
|
||||
# Create minimal toolConfig from tool history in messages
|
||||
tools_from_history = self._extract_tools_from_message_history(
|
||||
formatted_messages
|
||||
)
|
||||
if tools_from_history:
|
||||
body["toolConfig"] = cast(
|
||||
"ToolConfigurationTypeDef",
|
||||
cast(object, {"tools": tools_from_history}),
|
||||
)
|
||||
|
||||
# Add tool config if present or if messages contain tool content
|
||||
# Bedrock requires toolConfig when messages have toolUse/toolResult
|
||||
if tools:
|
||||
tool_config: ToolConfigurationTypeDef = {
|
||||
"tools": cast(
|
||||
"Sequence[ToolTypeDef]",
|
||||
cast(object, self._format_tools_for_converse(tools)),
|
||||
)
|
||||
}
|
||||
body["toolConfig"] = tool_config
|
||||
elif self._messages_contain_tool_content(formatted_messages):
|
||||
# Create minimal toolConfig from tool history in messages
|
||||
tools_from_history = self._extract_tools_from_message_history(
|
||||
formatted_messages
|
||||
)
|
||||
if tools_from_history:
|
||||
body["toolConfig"] = cast(
|
||||
"ToolConfigurationTypeDef",
|
||||
cast(object, {"tools": tools_from_history}),
|
||||
)
|
||||
# Add optional advanced features if configured
|
||||
if self.guardrail_config:
|
||||
guardrail_config: GuardrailConfigurationTypeDef = cast(
|
||||
"GuardrailConfigurationTypeDef", cast(object, self.guardrail_config)
|
||||
)
|
||||
body["guardrailConfig"] = guardrail_config
|
||||
|
||||
# Add optional advanced features if configured
|
||||
if self.guardrail_config:
|
||||
guardrail_config: GuardrailConfigurationTypeDef = cast(
|
||||
"GuardrailConfigurationTypeDef",
|
||||
cast(object, self.guardrail_config),
|
||||
)
|
||||
body["guardrailConfig"] = guardrail_config
|
||||
if self.additional_model_request_fields:
|
||||
body["additionalModelRequestFields"] = (
|
||||
self.additional_model_request_fields
|
||||
)
|
||||
|
||||
if self.additional_model_request_fields:
|
||||
body["additionalModelRequestFields"] = (
|
||||
self.additional_model_request_fields
|
||||
)
|
||||
if self.additional_model_response_field_paths:
|
||||
body["additionalModelResponseFieldPaths"] = (
|
||||
self.additional_model_response_field_paths
|
||||
)
|
||||
|
||||
if self.additional_model_response_field_paths:
|
||||
body["additionalModelResponseFieldPaths"] = (
|
||||
self.additional_model_response_field_paths
|
||||
)
|
||||
|
||||
if self.stream:
|
||||
return self._handle_streaming_converse(
|
||||
formatted_messages,
|
||||
body,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
)
|
||||
|
||||
return self._handle_converse(
|
||||
if self.stream:
|
||||
return self._handle_streaming_converse(
|
||||
formatted_messages,
|
||||
body,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
effective_response_model,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
if is_context_length_exceeded(e):
|
||||
logging.error(f"Context window exceeded: {e}")
|
||||
raise LLMContextLengthExceededError(str(e)) from e
|
||||
return self._handle_converse(
|
||||
formatted_messages,
|
||||
body,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
effective_response_model,
|
||||
)
|
||||
|
||||
error_msg = f"AWS Bedrock API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
if is_context_length_exceeded(e):
|
||||
logging.error(f"Context window exceeded: {e}")
|
||||
raise LLMContextLengthExceededError(str(e)) from e
|
||||
|
||||
error_msg = f"AWS Bedrock API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
|
||||
async def acall(
|
||||
self,
|
||||
@@ -429,99 +433,105 @@ class BedrockCompletion(BaseLLM):
|
||||
NotImplementedError: If aiobotocore is not installed.
|
||||
LLMContextLengthExceededError: If context window is exceeded.
|
||||
"""
|
||||
effective_response_model = response_model or self.response_format
|
||||
|
||||
if not AIOBOTOCORE_AVAILABLE:
|
||||
raise NotImplementedError(
|
||||
"Async support for AWS Bedrock requires aiobotocore. "
|
||||
'Install with: uv add "crewai[bedrock-async]"'
|
||||
)
|
||||
|
||||
with llm_call_context():
|
||||
try:
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
try:
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
formatted_messages, system_message = self._format_messages_for_converse(
|
||||
messages
|
||||
)
|
||||
|
||||
body: BedrockConverseRequestBody = {
|
||||
"inferenceConfig": self._get_inference_config(),
|
||||
}
|
||||
|
||||
if system_message:
|
||||
body["system"] = cast(
|
||||
"list[SystemContentBlockTypeDef]",
|
||||
cast(object, [{"text": system_message}]),
|
||||
)
|
||||
|
||||
formatted_messages, system_message = self._format_messages_for_converse(
|
||||
messages
|
||||
)
|
||||
|
||||
body: BedrockConverseRequestBody = {
|
||||
"inferenceConfig": self._get_inference_config(),
|
||||
# Add tool config if present or if messages contain tool content
|
||||
# Bedrock requires toolConfig when messages have toolUse/toolResult
|
||||
if tools:
|
||||
tool_config: ToolConfigurationTypeDef = {
|
||||
"tools": cast(
|
||||
"Sequence[ToolTypeDef]",
|
||||
cast(object, self._format_tools_for_converse(tools)),
|
||||
)
|
||||
}
|
||||
|
||||
if system_message:
|
||||
body["system"] = cast(
|
||||
"list[SystemContentBlockTypeDef]",
|
||||
cast(object, [{"text": system_message}]),
|
||||
body["toolConfig"] = tool_config
|
||||
elif self._messages_contain_tool_content(formatted_messages):
|
||||
# Create minimal toolConfig from tool history in messages
|
||||
tools_from_history = self._extract_tools_from_message_history(
|
||||
formatted_messages
|
||||
)
|
||||
if tools_from_history:
|
||||
body["toolConfig"] = cast(
|
||||
"ToolConfigurationTypeDef",
|
||||
cast(object, {"tools": tools_from_history}),
|
||||
)
|
||||
|
||||
# Add tool config if present or if messages contain tool content
|
||||
# Bedrock requires toolConfig when messages have toolUse/toolResult
|
||||
if tools:
|
||||
tool_config: ToolConfigurationTypeDef = {
|
||||
"tools": cast(
|
||||
"Sequence[ToolTypeDef]",
|
||||
cast(object, self._format_tools_for_converse(tools)),
|
||||
)
|
||||
}
|
||||
body["toolConfig"] = tool_config
|
||||
elif self._messages_contain_tool_content(formatted_messages):
|
||||
# Create minimal toolConfig from tool history in messages
|
||||
tools_from_history = self._extract_tools_from_message_history(
|
||||
formatted_messages
|
||||
)
|
||||
if tools_from_history:
|
||||
body["toolConfig"] = cast(
|
||||
"ToolConfigurationTypeDef",
|
||||
cast(object, {"tools": tools_from_history}),
|
||||
)
|
||||
if self.guardrail_config:
|
||||
guardrail_config: GuardrailConfigurationTypeDef = cast(
|
||||
"GuardrailConfigurationTypeDef", cast(object, self.guardrail_config)
|
||||
)
|
||||
body["guardrailConfig"] = guardrail_config
|
||||
|
||||
if self.guardrail_config:
|
||||
guardrail_config: GuardrailConfigurationTypeDef = cast(
|
||||
"GuardrailConfigurationTypeDef",
|
||||
cast(object, self.guardrail_config),
|
||||
)
|
||||
body["guardrailConfig"] = guardrail_config
|
||||
|
||||
if self.additional_model_request_fields:
|
||||
body["additionalModelRequestFields"] = (
|
||||
self.additional_model_request_fields
|
||||
)
|
||||
|
||||
if self.additional_model_response_field_paths:
|
||||
body["additionalModelResponseFieldPaths"] = (
|
||||
self.additional_model_response_field_paths
|
||||
)
|
||||
|
||||
if self.stream:
|
||||
return await self._ahandle_streaming_converse(
|
||||
formatted_messages,
|
||||
body,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
)
|
||||
|
||||
return await self._ahandle_converse(
|
||||
formatted_messages, body, available_functions, from_task, from_agent
|
||||
if self.additional_model_request_fields:
|
||||
body["additionalModelRequestFields"] = (
|
||||
self.additional_model_request_fields
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
if is_context_length_exceeded(e):
|
||||
logging.error(f"Context window exceeded: {e}")
|
||||
raise LLMContextLengthExceededError(str(e)) from e
|
||||
|
||||
error_msg = f"AWS Bedrock API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
if self.additional_model_response_field_paths:
|
||||
body["additionalModelResponseFieldPaths"] = (
|
||||
self.additional_model_response_field_paths
|
||||
)
|
||||
raise
|
||||
|
||||
if self.stream:
|
||||
return await self._ahandle_streaming_converse(
|
||||
formatted_messages,
|
||||
body,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
effective_response_model,
|
||||
)
|
||||
|
||||
return await self._ahandle_converse(
|
||||
formatted_messages,
|
||||
body,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
effective_response_model,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
if is_context_length_exceeded(e):
|
||||
logging.error(f"Context window exceeded: {e}")
|
||||
raise LLMContextLengthExceededError(str(e)) from e
|
||||
|
||||
error_msg = f"AWS Bedrock API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
|
||||
def _handle_converse(
|
||||
self,
|
||||
@@ -530,10 +540,29 @@ class BedrockCompletion(BaseLLM):
|
||||
available_functions: Mapping[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str:
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle non-streaming converse API call following AWS best practices."""
|
||||
if response_model:
|
||||
structured_tool: ConverseToolTypeDef = {
|
||||
"toolSpec": {
|
||||
"name": "structured_output",
|
||||
"description": "Returns structured data according to the schema",
|
||||
"inputSchema": {"json": response_model.model_json_schema()},
|
||||
}
|
||||
}
|
||||
body["toolConfig"] = cast(
|
||||
"ToolConfigurationTypeDef",
|
||||
cast(
|
||||
object,
|
||||
{
|
||||
"tools": [structured_tool],
|
||||
"toolChoice": {"tool": {"name": "structured_output"}},
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
# Validate messages format before API call
|
||||
if not messages:
|
||||
raise ValueError("Messages cannot be empty")
|
||||
|
||||
@@ -581,6 +610,21 @@ class BedrockCompletion(BaseLLM):
|
||||
|
||||
# If there are tool uses but no available_functions, return them for the executor to handle
|
||||
tool_uses = [block["toolUse"] for block in content if "toolUse" in block]
|
||||
|
||||
if response_model and tool_uses:
|
||||
for tool_use in tool_uses:
|
||||
if tool_use.get("name") == "structured_output":
|
||||
structured_data = tool_use.get("input", {})
|
||||
result = response_model.model_validate(structured_data)
|
||||
self._emit_call_completed_event(
|
||||
response=result.model_dump_json(),
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=messages,
|
||||
)
|
||||
return result
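Reviewer note: the structured-output path added here leans on Bedrock's tool-use mechanism rather than a dedicated response-format field — the Pydantic schema is exposed as a single `structured_output` tool, `toolChoice` forces the model to call it, and the tool input is validated back into the model. A self-contained sketch of that round trip (the `fake_tool_use` dict stands in for a toolUse block the Converse API would return):

```python
from pydantic import BaseModel


class Report(BaseModel):
    title: str
    score: int


# Request side: the schema becomes a forced tool, mirroring the toolConfig built above.
tool_config = {
    "tools": [
        {
            "toolSpec": {
                "name": "structured_output",
                "description": "Returns structured data according to the schema",
                "inputSchema": {"json": Report.model_json_schema()},
            }
        }
    ],
    "toolChoice": {"tool": {"name": "structured_output"}},
}

# Response side: a stand-in for the toolUse content returned by the model.
fake_tool_use = {"name": "structured_output", "input": {"title": "Q3", "score": 87}}

if fake_tool_use.get("name") == "structured_output":
    result = Report.model_validate(fake_tool_use.get("input", {}))
    print(result.model_dump_json())  # {"title":"Q3","score":87}
```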
|
||||
|
||||
if tool_uses and not available_functions:
|
||||
self._emit_call_completed_event(
|
||||
response=tool_uses,
|
||||
@@ -727,8 +771,28 @@ class BedrockCompletion(BaseLLM):
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str:
|
||||
"""Handle streaming converse API call with comprehensive event handling."""
|
||||
if response_model:
|
||||
structured_tool: ConverseToolTypeDef = {
|
||||
"toolSpec": {
|
||||
"name": "structured_output",
|
||||
"description": "Returns structured data according to the schema",
|
||||
"inputSchema": {"json": response_model.model_json_schema()},
|
||||
}
|
||||
}
|
||||
body["toolConfig"] = cast(
|
||||
"ToolConfigurationTypeDef",
|
||||
cast(
|
||||
object,
|
||||
{
|
||||
"tools": [structured_tool],
|
||||
"toolChoice": {"tool": {"name": "structured_output"}},
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
full_response = ""
|
||||
current_tool_use: dict[str, Any] | None = None
|
||||
tool_use_id: str | None = None
|
||||
@@ -939,8 +1003,28 @@ class BedrockCompletion(BaseLLM):
|
||||
available_functions: Mapping[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str:
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle async non-streaming converse API call."""
|
||||
if response_model:
|
||||
structured_tool: ConverseToolTypeDef = {
|
||||
"toolSpec": {
|
||||
"name": "structured_output",
|
||||
"description": "Returns structured data according to the schema",
|
||||
"inputSchema": {"json": response_model.model_json_schema()},
|
||||
}
|
||||
}
|
||||
body["toolConfig"] = cast(
|
||||
"ToolConfigurationTypeDef",
|
||||
cast(
|
||||
object,
|
||||
{
|
||||
"tools": [structured_tool],
|
||||
"toolChoice": {"tool": {"name": "structured_output"}},
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
if not messages:
|
||||
raise ValueError("Messages cannot be empty")
|
||||
@@ -986,6 +1070,21 @@ class BedrockCompletion(BaseLLM):
|
||||
|
||||
# If there are tool uses but no available_functions, return them for the executor to handle
|
||||
tool_uses = [block["toolUse"] for block in content if "toolUse" in block]
|
||||
|
||||
if response_model and tool_uses:
|
||||
for tool_use in tool_uses:
|
||||
if tool_use.get("name") == "structured_output":
|
||||
structured_data = tool_use.get("input", {})
|
||||
result = response_model.model_validate(structured_data)
|
||||
self._emit_call_completed_event(
|
||||
response=result.model_dump_json(),
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=messages,
|
||||
)
|
||||
return result
|
||||
|
||||
if tool_uses and not available_functions:
|
||||
self._emit_call_completed_event(
|
||||
response=tool_uses,
|
||||
@@ -1116,8 +1215,28 @@ class BedrockCompletion(BaseLLM):
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str:
|
||||
"""Handle async streaming converse API call."""
|
||||
if response_model:
|
||||
structured_tool: ConverseToolTypeDef = {
|
||||
"toolSpec": {
|
||||
"name": "structured_output",
|
||||
"description": "Returns structured data according to the schema",
|
||||
"inputSchema": {"json": response_model.model_json_schema()},
|
||||
}
|
||||
}
|
||||
body["toolConfig"] = cast(
|
||||
"ToolConfigurationTypeDef",
|
||||
cast(
|
||||
object,
|
||||
{
|
||||
"tools": [structured_tool],
|
||||
"toolChoice": {"tool": {"name": "structured_output"}},
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
full_response = ""
|
||||
current_tool_use: dict[str, Any] | None = None
|
||||
tool_use_id: str | None = None
|
||||
|
||||
@@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any, Literal, cast
 from pydantic import BaseModel

 from crewai.events.types.llm_events import LLMCallType
-from crewai.llms.base_llm import BaseLLM, llm_call_context
+from crewai.llms.base_llm import BaseLLM
 from crewai.utilities.agent_utils import is_context_length_exceeded
 from crewai.utilities.exceptions.context_window_exceeding_exception import (
     LLMContextLengthExceededError,
@@ -56,6 +56,7 @@ class GeminiCompletion(BaseLLM):
         client_params: dict[str, Any] | None = None,
         interceptor: BaseInterceptor[Any, Any] | None = None,
         use_vertexai: bool | None = None,
+        response_format: type[BaseModel] | None = None,
         **kwargs: Any,
     ):
         """Initialize Google Gemini chat completion client.
@@ -86,6 +87,8 @@ class GeminiCompletion(BaseLLM):
                 - None (default): Check GOOGLE_GENAI_USE_VERTEXAI env var
                 When using Vertex AI with API key (Express mode), http_options with
                 api_version="v1" is automatically configured.
+            response_format: Pydantic model for structured output. Used as default when
+                response_model is not passed to call()/acall() methods.
             **kwargs: Additional parameters
         """
         if interceptor is not None:
@@ -121,6 +124,7 @@ class GeminiCompletion(BaseLLM):
         self.safety_settings = safety_settings or {}
         self.stop_sequences = stop_sequences or []
         self.tools: list[dict[str, Any]] | None = None
+        self.response_format = response_format

         # Model-specific settings
         version_match = re.search(r"gemini-(\d+(?:\.\d+)?)", model.lower())
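Reviewer note: the `version_match` regex above only needs the numeric version from the model id; a quick illustration of the pattern on its own (sample model ids chosen just for the demo):

```python
import re

pattern = r"gemini-(\d+(?:\.\d+)?)"

for model in ("gemini-2.0-flash", "gemini-1.5-pro", "not-a-gemini-model"):
    match = re.search(pattern, model.lower())
    # Prints the captured version, or None when the id does not match the pattern.
    print(model, "->", match.group(1) if match else None)
```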
|
||||
@@ -282,66 +286,64 @@ class GeminiCompletion(BaseLLM):
|
||||
Returns:
|
||||
Chat completion response or tool call result
|
||||
"""
|
||||
with llm_call_context():
|
||||
try:
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
self.tools = tools
|
||||
try:
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
self.tools = tools
|
||||
effective_response_model = response_model or self.response_format
|
||||
|
||||
formatted_content, system_instruction = (
|
||||
self._format_messages_for_gemini(messages)
|
||||
)
|
||||
formatted_content, system_instruction = self._format_messages_for_gemini(
|
||||
messages
|
||||
)
|
||||
|
||||
messages_for_hooks = self._convert_contents_to_dict(formatted_content)
|
||||
messages_for_hooks = self._convert_contents_to_dict(formatted_content)
|
||||
|
||||
if not self._invoke_before_llm_call_hooks(
|
||||
messages_for_hooks, from_agent
|
||||
):
|
||||
raise ValueError("LLM call blocked by before_llm_call hook")
|
||||
if not self._invoke_before_llm_call_hooks(messages_for_hooks, from_agent):
|
||||
raise ValueError("LLM call blocked by before_llm_call hook")
|
||||
|
||||
config = self._prepare_generation_config(
|
||||
system_instruction, tools, response_model
|
||||
)
|
||||
config = self._prepare_generation_config(
|
||||
system_instruction, tools, effective_response_model
|
||||
)
|
||||
|
||||
if self.stream:
|
||||
return self._handle_streaming_completion(
|
||||
formatted_content,
|
||||
config,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
)
|
||||
|
||||
return self._handle_completion(
|
||||
if self.stream:
|
||||
return self._handle_streaming_completion(
|
||||
formatted_content,
|
||||
config,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
effective_response_model,
|
||||
)
|
||||
|
||||
except APIError as e:
|
||||
error_msg = f"Google Gemini API error: {e.code} - {e.message}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"Google Gemini API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
return self._handle_completion(
|
||||
formatted_content,
|
||||
config,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
effective_response_model,
|
||||
)
|
||||
|
||||
except APIError as e:
|
||||
error_msg = f"Google Gemini API error: {e.code} - {e.message}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"Google Gemini API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
|
||||
async def acall(
|
||||
self,
|
||||
@@ -367,59 +369,59 @@ class GeminiCompletion(BaseLLM):
|
||||
Returns:
|
||||
Chat completion response or tool call result
|
||||
"""
|
||||
with llm_call_context():
|
||||
try:
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
self.tools = tools
|
||||
try:
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
self.tools = tools
|
||||
effective_response_model = response_model or self.response_format
|
||||
|
||||
formatted_content, system_instruction = (
|
||||
self._format_messages_for_gemini(messages)
|
||||
)
|
||||
formatted_content, system_instruction = self._format_messages_for_gemini(
|
||||
messages
|
||||
)
|
||||
|
||||
config = self._prepare_generation_config(
|
||||
system_instruction, tools, response_model
|
||||
)
|
||||
config = self._prepare_generation_config(
|
||||
system_instruction, tools, effective_response_model
|
||||
)
|
||||
|
||||
if self.stream:
|
||||
return await self._ahandle_streaming_completion(
|
||||
formatted_content,
|
||||
config,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
)
|
||||
|
||||
return await self._ahandle_completion(
|
||||
if self.stream:
|
||||
return await self._ahandle_streaming_completion(
|
||||
formatted_content,
|
||||
config,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
response_model,
|
||||
effective_response_model,
|
||||
)
|
||||
|
||||
except APIError as e:
|
||||
error_msg = f"Google Gemini API error: {e.code} - {e.message}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"Google Gemini API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
return await self._ahandle_completion(
|
||||
formatted_content,
|
||||
config,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
effective_response_model,
|
||||
)
|
||||
|
||||
except APIError as e:
|
||||
error_msg = f"Google Gemini API error: {e.code} - {e.message}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"Google Gemini API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
|
||||
def _prepare_generation_config(
|
||||
self,
|
||||
@@ -574,10 +576,10 @@ class GeminiCompletion(BaseLLM):
|
||||
types.Content(role="user", parts=[function_response_part])
|
||||
)
|
||||
elif role == "assistant" and message.get("tool_calls"):
|
||||
parts: list[types.Part] = []
|
||||
tool_parts: list[types.Part] = []
|
||||
|
||||
if text_content:
|
||||
parts.append(types.Part.from_text(text=text_content))
|
||||
tool_parts.append(types.Part.from_text(text=text_content))
|
||||
|
||||
tool_calls: list[dict[str, Any]] = message.get("tool_calls") or []
|
||||
for tool_call in tool_calls:
|
||||
@@ -596,11 +598,11 @@ class GeminiCompletion(BaseLLM):
|
||||
else:
|
||||
func_args = func_args_raw
|
||||
|
||||
parts.append(
|
||||
tool_parts.append(
|
||||
types.Part.from_function_call(name=func_name, args=func_args)
|
||||
)
|
||||
|
||||
contents.append(types.Content(role="model", parts=parts))
|
||||
contents.append(types.Content(role="model", parts=tool_parts))
|
||||
else:
|
||||
# Convert role for Gemini (assistant -> model)
|
||||
gemini_role = "model" if role == "assistant" else "user"
|
||||
@@ -976,7 +978,7 @@ class GeminiCompletion(BaseLLM):
         from_task: Any | None = None,
         from_agent: Any | None = None,
         response_model: type[BaseModel] | None = None,
-    ) -> str:
+    ) -> str | Any:
         """Handle streaming content generation."""
         full_response = ""
         function_calls: dict[int, dict[str, Any]] = {}
@@ -1054,7 +1056,7 @@ class GeminiCompletion(BaseLLM):
         from_task: Any | None = None,
         from_agent: Any | None = None,
         response_model: type[BaseModel] | None = None,
-    ) -> str:
+    ) -> str | Any:
         """Handle async streaming content generation."""
         full_response = ""
         function_calls: dict[int, dict[str, Any]] = {}
|
||||
|
||||
@@ -17,7 +17,7 @@ from openai.types.responses import Response
 from pydantic import BaseModel

 from crewai.events.types.llm_events import LLMCallType
-from crewai.llms.base_llm import BaseLLM, llm_call_context
+from crewai.llms.base_llm import BaseLLM
 from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport
 from crewai.utilities.agent_utils import is_context_length_exceeded
 from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
@@ -382,35 +382,23 @@ class OpenAICompletion(BaseLLM):
|
||||
Returns:
|
||||
Completion response or tool call result.
|
||||
"""
|
||||
with llm_call_context():
|
||||
try:
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
try:
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
formatted_messages = self._format_messages(messages)
|
||||
formatted_messages = self._format_messages(messages)
|
||||
|
||||
if not self._invoke_before_llm_call_hooks(
|
||||
formatted_messages, from_agent
|
||||
):
|
||||
raise ValueError("LLM call blocked by before_llm_call hook")
|
||||
if not self._invoke_before_llm_call_hooks(formatted_messages, from_agent):
|
||||
raise ValueError("LLM call blocked by before_llm_call hook")
|
||||
|
||||
if self.api == "responses":
|
||||
return self._call_responses(
|
||||
messages=formatted_messages,
|
||||
tools=tools,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
return self._call_completions(
|
||||
if self.api == "responses":
|
||||
return self._call_responses(
|
||||
messages=formatted_messages,
|
||||
tools=tools,
|
||||
available_functions=available_functions,
|
||||
@@ -419,13 +407,22 @@ class OpenAICompletion(BaseLLM):
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"OpenAI API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
return self._call_completions(
|
||||
messages=formatted_messages,
|
||||
tools=tools,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"OpenAI API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
|
||||
def _call_completions(
|
||||
self,
|
||||
@@ -482,30 +479,20 @@ class OpenAICompletion(BaseLLM):
|
||||
Returns:
|
||||
Completion response or tool call result.
|
||||
"""
|
||||
with llm_call_context():
|
||||
try:
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
try:
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
formatted_messages = self._format_messages(messages)
|
||||
formatted_messages = self._format_messages(messages)
|
||||
|
||||
if self.api == "responses":
|
||||
return await self._acall_responses(
|
||||
messages=formatted_messages,
|
||||
tools=tools,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
return await self._acall_completions(
|
||||
if self.api == "responses":
|
||||
return await self._acall_responses(
|
||||
messages=formatted_messages,
|
||||
tools=tools,
|
||||
available_functions=available_functions,
|
||||
@@ -514,13 +501,22 @@ class OpenAICompletion(BaseLLM):
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"OpenAI API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
return await self._acall_completions(
|
||||
messages=formatted_messages,
|
||||
tools=tools,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"OpenAI API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
|
||||
async def _acall_completions(
|
||||
self,
|
||||
@@ -1064,7 +1060,7 @@ class OpenAICompletion(BaseLLM):
|
||||
chunk=delta_text,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_id=response_id_stream,
|
||||
response_id=response_id_stream
|
||||
)
|
||||
|
||||
elif event.type == "response.function_call_arguments.delta":
|
||||
@@ -1713,7 +1709,7 @@ class OpenAICompletion(BaseLLM):
|
||||
**parse_params, response_format=response_model
|
||||
) as stream:
|
||||
for chunk in stream:
|
||||
response_id_stream = chunk.id if hasattr(chunk, "id") else None
|
||||
response_id_stream=chunk.id if hasattr(chunk,"id") else None
|
||||
|
||||
if chunk.type == "content.delta":
|
||||
delta_content = chunk.delta
|
||||
@@ -1722,7 +1718,7 @@ class OpenAICompletion(BaseLLM):
|
||||
chunk=delta_content,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_id=response_id_stream,
|
||||
response_id=response_id_stream
|
||||
)
|
||||
|
||||
final_completion = stream.get_final_completion()
|
||||
@@ -1752,9 +1748,7 @@ class OpenAICompletion(BaseLLM):
|
||||
usage_data = {"total_tokens": 0}
|
||||
|
||||
for completion_chunk in completion_stream:
|
||||
response_id_stream = (
|
||||
completion_chunk.id if hasattr(completion_chunk, "id") else None
|
||||
)
|
||||
response_id_stream=completion_chunk.id if hasattr(completion_chunk,"id") else None
|
||||
|
||||
if hasattr(completion_chunk, "usage") and completion_chunk.usage:
|
||||
usage_data = self._extract_openai_token_usage(completion_chunk)
|
||||
@@ -1772,7 +1766,7 @@ class OpenAICompletion(BaseLLM):
|
||||
chunk=chunk_delta.content,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_id=response_id_stream,
|
||||
response_id=response_id_stream
|
||||
)
|
||||
|
||||
if chunk_delta.tool_calls:
|
||||
@@ -1811,7 +1805,7 @@ class OpenAICompletion(BaseLLM):
|
||||
"index": tool_calls[tool_index]["index"],
|
||||
},
|
||||
call_type=LLMCallType.TOOL_CALL,
|
||||
response_id=response_id_stream,
|
||||
response_id=response_id_stream
|
||||
)
|
||||
|
||||
self._track_token_usage_internal(usage_data)
|
||||
@@ -2023,7 +2017,7 @@ class OpenAICompletion(BaseLLM):
|
||||
accumulated_content = ""
|
||||
usage_data = {"total_tokens": 0}
|
||||
async for chunk in completion_stream:
|
||||
response_id_stream = chunk.id if hasattr(chunk, "id") else None
|
||||
response_id_stream=chunk.id if hasattr(chunk,"id") else None
|
||||
|
||||
if hasattr(chunk, "usage") and chunk.usage:
|
||||
usage_data = self._extract_openai_token_usage(chunk)
|
||||
@@ -2041,7 +2035,7 @@ class OpenAICompletion(BaseLLM):
|
||||
chunk=delta.content,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_id=response_id_stream,
|
||||
response_id=response_id_stream
|
||||
)
|
||||
|
||||
self._track_token_usage_internal(usage_data)
|
||||
@@ -2077,7 +2071,7 @@ class OpenAICompletion(BaseLLM):
|
||||
usage_data = {"total_tokens": 0}
|
||||
|
||||
async for chunk in stream:
|
||||
response_id_stream = chunk.id if hasattr(chunk, "id") else None
|
||||
response_id_stream=chunk.id if hasattr(chunk,"id") else None
|
||||
|
||||
if hasattr(chunk, "usage") and chunk.usage:
|
||||
usage_data = self._extract_openai_token_usage(chunk)
|
||||
@@ -2095,7 +2089,7 @@ class OpenAICompletion(BaseLLM):
|
||||
chunk=chunk_delta.content,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_id=response_id_stream,
|
||||
response_id=response_id_stream
|
||||
)
|
||||
|
||||
if chunk_delta.tool_calls:
|
||||
@@ -2134,7 +2128,7 @@ class OpenAICompletion(BaseLLM):
|
||||
"index": tool_calls[tool_index]["index"],
|
||||
},
|
||||
call_type=LLMCallType.TOOL_CALL,
|
||||
response_id=response_id_stream,
|
||||
response_id=response_id_stream
|
||||
)
|
||||
|
||||
self._track_token_usage_internal(usage_data)
|
||||
|
||||
@@ -18,7 +18,6 @@ if TYPE_CHECKING:
     )
     from chromadb.utils.embedding_functions.google_embedding_function import (
         GoogleGenerativeAiEmbeddingFunction,
-        GoogleVertexEmbeddingFunction,
     )
     from chromadb.utils.embedding_functions.huggingface_embedding_function import (
         HuggingFaceEmbeddingFunction,
@@ -52,6 +51,9 @@ if TYPE_CHECKING:
     from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
     from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec
     from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec
+    from crewai.rag.embeddings.providers.google.genai_vertex_embedding import (
+        GoogleGenAIVertexEmbeddingFunction,
+    )
     from crewai.rag.embeddings.providers.google.types import (
         GenerativeAiProviderSpec,
         VertexAIProviderSpec,
@@ -163,7 +165,7 @@ def build_embedder_from_dict(spec: OpenAIProviderSpec) -> OpenAIEmbeddingFunctio
 @overload
 def build_embedder_from_dict(
     spec: VertexAIProviderSpec,
-) -> GoogleVertexEmbeddingFunction: ...
+) -> GoogleGenAIVertexEmbeddingFunction: ...


 @overload
@@ -296,7 +298,9 @@ def build_embedder(spec: OpenAIProviderSpec) -> OpenAIEmbeddingFunction: ...


 @overload
-def build_embedder(spec: VertexAIProviderSpec) -> GoogleVertexEmbeddingFunction: ...
+def build_embedder(
+    spec: VertexAIProviderSpec,
+) -> GoogleGenAIVertexEmbeddingFunction: ...


 @overload
|
||||
|
||||
@@ -1,5 +1,8 @@
 """Google embedding providers."""

+from crewai.rag.embeddings.providers.google.genai_vertex_embedding import (
+    GoogleGenAIVertexEmbeddingFunction,
+)
 from crewai.rag.embeddings.providers.google.generative_ai import (
     GenerativeAiProvider,
 )
@@ -18,6 +21,7 @@ __all__ = [
     "GenerativeAiProvider",
     "GenerativeAiProviderConfig",
     "GenerativeAiProviderSpec",
+    "GoogleGenAIVertexEmbeddingFunction",
     "VertexAIProvider",
     "VertexAIProviderConfig",
     "VertexAIProviderSpec",
|
||||
|
||||
@@ -0,0 +1,237 @@
|
||||
"""Google Vertex AI embedding function implementation.
|
||||
|
||||
This module supports both the new google-genai SDK and the deprecated
|
||||
vertexai.language_models module for backwards compatibility.
|
||||
|
||||
The deprecated vertexai.language_models module will be removed after June 24, 2026.
|
||||
Migration guide: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk
|
||||
"""
|
||||
|
||||
from typing import Any, ClassVar, cast
|
||||
import warnings
|
||||
|
||||
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.rag.embeddings.providers.google.types import VertexAIProviderConfig
|
||||
|
||||
|
||||
class GoogleGenAIVertexEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
"""Embedding function for Google Vertex AI with dual SDK support.
|
||||
|
||||
This class supports both:
|
||||
- Legacy models (textembedding-gecko*) using the deprecated vertexai.language_models SDK
|
||||
- New models (gemini-embedding-*, text-embedding-*) using the google-genai SDK
|
||||
|
||||
The SDK is automatically selected based on the model name. Legacy models will
|
||||
emit a deprecation warning.
|
||||
|
||||
Supports two authentication modes:
|
||||
1. Vertex AI backend: Set project_id and location/region (uses Application Default Credentials)
|
||||
2. API key: Set api_key for direct API access
|
||||
|
||||
Example:
|
||||
# Using legacy model (will emit deprecation warning)
|
||||
embedder = GoogleGenAIVertexEmbeddingFunction(
|
||||
project_id="my-project",
|
||||
region="us-central1",
|
||||
model_name="textembedding-gecko"
|
||||
)
|
||||
|
||||
# Using new model with google-genai SDK
|
||||
embedder = GoogleGenAIVertexEmbeddingFunction(
|
||||
project_id="my-project",
|
||||
location="us-central1",
|
||||
model_name="gemini-embedding-001"
|
||||
)
|
||||
|
||||
# Using API key (new SDK only)
|
||||
embedder = GoogleGenAIVertexEmbeddingFunction(
|
||||
api_key="your-api-key",
|
||||
model_name="gemini-embedding-001"
|
||||
)
|
||||
"""
|
||||
|
||||
# Models that use the legacy vertexai.language_models SDK
|
||||
LEGACY_MODELS: ClassVar[set[str]] = {
|
||||
"textembedding-gecko",
|
||||
"textembedding-gecko@001",
|
||||
"textembedding-gecko@002",
|
||||
"textembedding-gecko@003",
|
||||
"textembedding-gecko@latest",
|
||||
"textembedding-gecko-multilingual",
|
||||
"textembedding-gecko-multilingual@001",
|
||||
"textembedding-gecko-multilingual@latest",
|
||||
}
|
||||
|
||||
# Models that use the new google-genai SDK
|
||||
GENAI_MODELS: ClassVar[set[str]] = {
|
||||
"gemini-embedding-001",
|
||||
"text-embedding-005",
|
||||
"text-multilingual-embedding-002",
|
||||
}
|
||||
|
||||
def __init__(self, **kwargs: Unpack[VertexAIProviderConfig]) -> None:
|
||||
"""Initialize Google Vertex AI embedding function.
|
||||
|
||||
Args:
|
||||
**kwargs: Configuration parameters including:
|
||||
- model_name: Model to use for embeddings (default: "textembedding-gecko")
|
||||
- api_key: Optional API key for authentication (new SDK only)
|
||||
- project_id: GCP project ID (for Vertex AI backend)
|
||||
- location: GCP region (default: "us-central1")
|
||||
- region: Deprecated alias for location
|
||||
- task_type: Task type for embeddings (default: "RETRIEVAL_DOCUMENT", new SDK only)
|
||||
- output_dimensionality: Optional output embedding dimension (new SDK only)
|
||||
"""
|
||||
# Handle deprecated 'region' parameter (only if it has a value)
|
||||
region_value = kwargs.pop("region", None) # type: ignore[typeddict-item]
|
||||
if region_value is not None:
|
||||
warnings.warn(
|
||||
"The 'region' parameter is deprecated, use 'location' instead. "
|
||||
"See: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if "location" not in kwargs or kwargs.get("location") is None:
|
||||
kwargs["location"] = region_value # type: ignore[typeddict-unknown-key]
|
||||
|
||||
self._config = kwargs
|
||||
self._model_name = str(kwargs.get("model_name", "textembedding-gecko"))
|
||||
self._use_legacy = self._is_legacy_model(self._model_name)
|
||||
|
||||
if self._use_legacy:
|
||||
self._init_legacy_client(**kwargs)
|
||||
else:
|
||||
self._init_genai_client(**kwargs)
|
||||
|
||||
def _is_legacy_model(self, model_name: str) -> bool:
|
||||
"""Check if the model uses the legacy SDK."""
|
||||
return model_name in self.LEGACY_MODELS or model_name.startswith(
|
||||
"textembedding-gecko"
|
||||
)
|
||||
|
||||
def _init_legacy_client(self, **kwargs: Any) -> None:
|
||||
"""Initialize using the deprecated vertexai.language_models SDK."""
|
||||
warnings.warn(
|
||||
f"Model '{self._model_name}' uses the deprecated vertexai.language_models SDK "
|
||||
"which will be removed after June 24, 2026. Consider migrating to newer models "
|
||||
"like 'gemini-embedding-001'. "
|
||||
"See: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk",
|
||||
DeprecationWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
|
||||
try:
|
||||
import vertexai
|
||||
from vertexai.language_models import TextEmbeddingModel
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"vertexai is required for legacy embedding models (textembedding-gecko*). "
|
||||
"Install it with: pip install google-cloud-aiplatform"
|
||||
) from e
|
||||
|
||||
project_id = kwargs.get("project_id")
|
||||
location = str(kwargs.get("location", "us-central1"))
|
||||
|
||||
if not project_id:
|
||||
raise ValueError(
|
||||
"project_id is required for legacy models. "
|
||||
"For API key authentication, use newer models like 'gemini-embedding-001'."
|
||||
)
|
||||
|
||||
vertexai.init(project=str(project_id), location=location)
|
||||
self._legacy_model = TextEmbeddingModel.from_pretrained(self._model_name)
|
||||
|
||||
def _init_genai_client(self, **kwargs: Any) -> None:
|
||||
"""Initialize using the new google-genai SDK."""
|
||||
try:
|
||||
from google import genai
|
||||
from google.genai.types import EmbedContentConfig
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"google-genai is required for Google Gen AI embeddings. "
|
||||
"Install it with: uv add 'crewai[google-genai]'"
|
||||
) from e
|
||||
|
||||
self._genai = genai
|
||||
self._EmbedContentConfig = EmbedContentConfig
|
||||
self._task_type = kwargs.get("task_type", "RETRIEVAL_DOCUMENT")
|
||||
self._output_dimensionality = kwargs.get("output_dimensionality")
|
||||
|
||||
# Initialize client based on authentication mode
|
||||
api_key = kwargs.get("api_key")
|
||||
project_id = kwargs.get("project_id")
|
||||
location: str = str(kwargs.get("location", "us-central1"))
|
||||
|
||||
if api_key:
|
||||
self._client = genai.Client(api_key=api_key)
|
||||
elif project_id:
|
||||
self._client = genai.Client(
|
||||
vertexai=True,
|
||||
project=str(project_id),
|
||||
location=location,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Either 'api_key' (for API key authentication) or 'project_id' "
|
||||
"(for Vertex AI backend with ADC) must be provided."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def name() -> str:
|
||||
"""Return the name of the embedding function for ChromaDB compatibility."""
|
||||
return "google-vertex"
|
||||
|
||||
def __call__(self, input: Documents) -> Embeddings:
|
||||
"""Generate embeddings for input documents.
|
||||
|
||||
Args:
|
||||
input: List of documents to embed.
|
||||
|
||||
Returns:
|
||||
List of embedding vectors.
|
||||
"""
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
|
||||
if self._use_legacy:
|
||||
return self._call_legacy(input)
|
||||
return self._call_genai(input)
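Reviewer note: putting the dispatch together, constructing the function with a new-generation model routes through the google-genai client while a gecko model keeps the legacy path. A usage sketch based on the constructor examples in the docstring above (the project id is a placeholder, and google-genai plus valid credentials are assumed):

```python
# Placeholder project id; Application Default Credentials are assumed to be configured.
embedder = GoogleGenAIVertexEmbeddingFunction(
    project_id="my-project",
    location="us-central1",
    model_name="gemini-embedding-001",  # new model -> google-genai SDK path
    output_dimensionality=768,
)

vectors = embedder(["CrewAI memory works best with good embeddings."])
print(len(vectors), len(vectors[0]))  # 1 document, 768-dimensional vector
```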
|
||||
|
||||
def _call_legacy(self, input: list[str]) -> Embeddings:
|
||||
"""Generate embeddings using the legacy SDK."""
|
||||
import numpy as np
|
||||
|
||||
embeddings_list = []
|
||||
for text in input:
|
||||
embedding_result = self._legacy_model.get_embeddings([text])
|
||||
embeddings_list.append(
|
||||
np.array(embedding_result[0].values, dtype=np.float32)
|
||||
)
|
||||
|
||||
return cast(Embeddings, embeddings_list)
|
||||
|
||||
def _call_genai(self, input: list[str]) -> Embeddings:
|
||||
"""Generate embeddings using the new google-genai SDK."""
|
||||
# Build config for embed_content
|
||||
config_kwargs: dict[str, Any] = {
|
||||
"task_type": self._task_type,
|
||||
}
|
||||
if self._output_dimensionality is not None:
|
||||
config_kwargs["output_dimensionality"] = self._output_dimensionality
|
||||
|
||||
config = self._EmbedContentConfig(**config_kwargs)
|
||||
|
||||
# Call the embedding API
|
||||
response = self._client.models.embed_content(
|
||||
model=self._model_name,
|
||||
contents=input, # type: ignore[arg-type]
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Extract embeddings from response
|
||||
if response.embeddings is None:
|
||||
raise ValueError("No embeddings returned from the API")
|
||||
embeddings = [emb.values for emb in response.embeddings]
|
||||
return cast(Embeddings, embeddings)
|
||||
@@ -34,12 +34,47 @@ class GenerativeAiProviderSpec(TypedDict):
|
||||
|
||||
|
||||
class VertexAIProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for Vertex AI provider."""
|
||||
"""Configuration for Vertex AI provider with dual SDK support.
|
||||
|
||||
Supports both legacy models (textembedding-gecko*) using the deprecated
|
||||
vertexai.language_models SDK and new models using google-genai SDK.
|
||||
|
||||
Attributes:
|
||||
api_key: Google API key (optional if using project_id with ADC). Only for new SDK models.
|
||||
model_name: Embedding model name (default: "textembedding-gecko").
|
||||
Legacy models: textembedding-gecko, textembedding-gecko@001, etc.
|
||||
New models: gemini-embedding-001, text-embedding-005, text-multilingual-embedding-002
|
||||
project_id: GCP project ID (required for Vertex AI backend and legacy models).
|
||||
location: GCP region/location (default: "us-central1").
|
||||
region: Deprecated alias for location (kept for backwards compatibility).
|
||||
task_type: Task type for embeddings (default: "RETRIEVAL_DOCUMENT"). Only for new SDK models.
|
||||
output_dimensionality: Output embedding dimension (optional). Only for new SDK models.
|
||||
"""
|
||||
|
||||
api_key: str
|
||||
model_name: Annotated[str, "textembedding-gecko"]
|
||||
project_id: Annotated[str, "cloud-large-language-models"]
|
||||
region: Annotated[str, "us-central1"]
|
||||
model_name: Annotated[
|
||||
Literal[
|
||||
# Legacy models (deprecated vertexai.language_models SDK)
|
||||
"textembedding-gecko",
|
||||
"textembedding-gecko@001",
|
||||
"textembedding-gecko@002",
|
||||
"textembedding-gecko@003",
|
||||
"textembedding-gecko@latest",
|
||||
"textembedding-gecko-multilingual",
|
||||
"textembedding-gecko-multilingual@001",
|
||||
"textembedding-gecko-multilingual@latest",
|
||||
# New models (google-genai SDK)
|
||||
"gemini-embedding-001",
|
||||
"text-embedding-005",
|
||||
"text-multilingual-embedding-002",
|
||||
],
|
||||
"textembedding-gecko",
|
||||
]
|
||||
project_id: str
|
||||
location: Annotated[str, "us-central1"]
|
||||
region: Annotated[str, "us-central1"] # Deprecated alias for location
|
||||
task_type: Annotated[str, "RETRIEVAL_DOCUMENT"]
|
||||
output_dimensionality: int
|
||||
|
||||
|
||||
class VertexAIProviderSpec(TypedDict, total=False):
|
||||
|
||||
@@ -1,46 +1,126 @@
|
||||
"""Google Vertex AI embeddings provider."""
|
||||
"""Google Vertex AI embeddings provider.
|
||||
|
||||
This module supports both the new google-genai SDK and the deprecated
|
||||
vertexai.language_models module for backwards compatibility.
|
||||
|
||||
The SDK is automatically selected based on the model name:
|
||||
- Legacy models (textembedding-gecko*) use vertexai.language_models (deprecated)
|
||||
- New models (gemini-embedding-*, text-embedding-*) use google-genai
|
||||
|
||||
Migration guide: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/deprecations/genai-vertexai-sdk
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GoogleVertexEmbeddingFunction,
|
||||
)
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.providers.google.genai_vertex_embedding import (
|
||||
GoogleGenAIVertexEmbeddingFunction,
|
||||
)
|
||||
|
||||
|
||||
class VertexAIProvider(BaseEmbeddingsProvider[GoogleVertexEmbeddingFunction]):
|
||||
"""Google Vertex AI embeddings provider."""
|
||||
class VertexAIProvider(BaseEmbeddingsProvider[GoogleGenAIVertexEmbeddingFunction]):
|
||||
"""Google Vertex AI embeddings provider with dual SDK support.
|
||||
|
||||
embedding_callable: type[GoogleVertexEmbeddingFunction] = Field(
|
||||
default=GoogleVertexEmbeddingFunction,
|
||||
description="Vertex AI embedding function class",
|
||||
Supports both legacy models (textembedding-gecko*) using the deprecated
|
||||
vertexai.language_models SDK and new models (gemini-embedding-*, text-embedding-*)
|
||||
using the google-genai SDK.
|
||||
|
||||
The SDK is automatically selected based on the model name. Legacy models will
|
||||
emit a deprecation warning.
|
||||
|
||||
Authentication modes:
|
||||
1. Vertex AI backend: Set project_id and location/region (uses Application Default Credentials)
|
||||
2. API key: Set api_key for direct API access (new SDK models only)
|
||||
|
||||
Example:
|
||||
# Legacy model (backwards compatible, will emit deprecation warning)
|
||||
provider = VertexAIProvider(
|
||||
project_id="my-project",
|
||||
region="us-central1", # or location="us-central1"
|
||||
model_name="textembedding-gecko"
|
||||
)
|
||||
|
||||
# New model with Vertex AI backend
|
||||
provider = VertexAIProvider(
|
||||
project_id="my-project",
|
||||
location="us-central1",
|
||||
model_name="gemini-embedding-001"
|
||||
)
|
||||
|
||||
# New model with API key
|
||||
provider = VertexAIProvider(
|
||||
api_key="your-api-key",
|
||||
model_name="gemini-embedding-001"
|
||||
)
|
||||
"""
|
||||
|
||||
embedding_callable: type[GoogleGenAIVertexEmbeddingFunction] = Field(
|
||||
default=GoogleGenAIVertexEmbeddingFunction,
|
||||
description="Google Vertex AI embedding function class",
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="textembedding-gecko",
|
||||
description="Model name to use for embeddings",
|
||||
description=(
|
||||
"Model name to use for embeddings. Legacy models (textembedding-gecko*) "
|
||||
"use the deprecated SDK. New models (gemini-embedding-001, text-embedding-005) "
|
||||
"use the google-genai SDK."
|
||||
),
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME",
|
||||
"GOOGLE_VERTEX_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Google API key",
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="Google API key (optional if using project_id with Application Default Credentials)",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_CLOUD_API_KEY", "GOOGLE_CLOUD_API_KEY"
|
||||
"EMBEDDINGS_GOOGLE_CLOUD_API_KEY",
|
||||
"GOOGLE_CLOUD_API_KEY",
|
||||
"GOOGLE_API_KEY",
|
||||
),
|
||||
)
|
||||
project_id: str = Field(
|
||||
default="cloud-large-language-models",
|
||||
description="GCP project ID",
|
||||
project_id: str | None = Field(
|
||||
default=None,
|
||||
description="GCP project ID (required for Vertex AI backend and legacy models)",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_CLOUD_PROJECT", "GOOGLE_CLOUD_PROJECT"
|
||||
"EMBEDDINGS_GOOGLE_CLOUD_PROJECT",
|
||||
"GOOGLE_CLOUD_PROJECT",
|
||||
),
|
||||
)
|
||||
region: str = Field(
|
||||
location: str = Field(
|
||||
default="us-central1",
|
||||
description="GCP region",
|
||||
description="GCP region/location",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_CLOUD_REGION", "GOOGLE_CLOUD_REGION"
|
||||
"EMBEDDINGS_GOOGLE_CLOUD_LOCATION",
|
||||
"EMBEDDINGS_GOOGLE_CLOUD_REGION",
|
||||
"GOOGLE_CLOUD_LOCATION",
|
||||
"GOOGLE_CLOUD_REGION",
|
||||
),
|
||||
)
|
||||
region: str | None = Field(
|
||||
default=None,
|
||||
description="Deprecated: Use 'location' instead. GCP region (kept for backwards compatibility)",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_VERTEX_REGION",
|
||||
"GOOGLE_VERTEX_REGION",
|
||||
),
|
||||
)
|
||||
task_type: str = Field(
|
||||
default="RETRIEVAL_DOCUMENT",
|
||||
description="Task type for embeddings (e.g., RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY). Only used with new SDK models.",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_VERTEX_TASK_TYPE",
|
||||
"GOOGLE_VERTEX_TASK_TYPE",
|
||||
),
|
||||
)
|
||||
output_dimensionality: int | None = Field(
|
||||
default=None,
|
||||
description="Output embedding dimensionality (optional). Only used with new SDK models.",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_GOOGLE_VERTEX_OUTPUT_DIMENSIONALITY",
|
||||
"GOOGLE_VERTEX_OUTPUT_DIMENSIONALITY",
|
||||
),
|
||||
)
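The module and class docstrings above describe routing between two SDKs purely by model name: `textembedding-gecko*` models go through the deprecated `vertexai.language_models` path, everything else through `google-genai`. As a rough illustration of that idea only (this is not the actual crewAI implementation; the helper names and prefix check below are assumptions), a name-based dispatch would behave like this:

```python
# Hypothetical sketch of model-name-based SDK selection; illustrative names only.
LEGACY_PREFIX = "textembedding-gecko"  # assumed marker for deprecated models


def uses_legacy_sdk(model_name: str) -> bool:
    """Return True when the model would go through vertexai.language_models."""
    return model_name.startswith(LEGACY_PREFIX)


def describe_route(model_name: str) -> str:
    """Show which backend a given model name would be routed to."""
    if uses_legacy_sdk(model_name):
        return f"{model_name}: vertexai.language_models (deprecated, removal after June 2026)"
    return f"{model_name}: google-genai SDK"


if __name__ == "__main__":
    for name in ("textembedding-gecko", "gemini-embedding-001", "text-embedding-005"):
        print(describe_route(name))
```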
|
||||
|
||||
@@ -5,17 +5,29 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from collections.abc import Coroutine
|
||||
import concurrent.futures
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
from uuid import UUID
|
||||
|
||||
from aiocache import Cache # type: ignore[import-untyped]
|
||||
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from aiocache import Cache
|
||||
from crewai_files import FileInput
|
||||
|
||||
_file_store = Cache(Cache.MEMORY, serializer=PickleSerializer())
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_file_store: Cache | None = None
|
||||
|
||||
try:
|
||||
from aiocache import Cache
|
||||
from aiocache.serializers import PickleSerializer
|
||||
|
||||
_file_store = Cache(Cache.MEMORY, serializer=PickleSerializer())
|
||||
except ImportError:
|
||||
logger.debug(
|
||||
"aiocache is not installed. File store features will be disabled. "
|
||||
"Install with: uv add aiocache"
|
||||
)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -59,6 +71,8 @@ async def astore_files(
|
||||
files: Dictionary mapping names to file inputs.
|
||||
ttl: Time-to-live in seconds.
|
||||
"""
|
||||
if _file_store is None:
|
||||
return
|
||||
await _file_store.set(f"{_CREW_PREFIX}{execution_id}", files, ttl=ttl)
|
||||
|
||||
|
||||
@@ -71,6 +85,8 @@ async def aget_files(execution_id: UUID) -> dict[str, FileInput] | None:
|
||||
Returns:
|
||||
Dictionary of files or None if not found.
|
||||
"""
|
||||
if _file_store is None:
|
||||
return None
|
||||
result: dict[str, FileInput] | None = await _file_store.get(
|
||||
f"{_CREW_PREFIX}{execution_id}"
|
||||
)
|
||||
@@ -83,6 +99,8 @@ async def aclear_files(execution_id: UUID) -> None:
|
||||
Args:
|
||||
execution_id: Unique identifier for the crew execution.
|
||||
"""
|
||||
if _file_store is None:
|
||||
return
|
||||
await _file_store.delete(f"{_CREW_PREFIX}{execution_id}")
|
||||
|
||||
|
||||
@@ -98,6 +116,8 @@ async def astore_task_files(
|
||||
files: Dictionary mapping names to file inputs.
|
||||
ttl: Time-to-live in seconds.
|
||||
"""
|
||||
if _file_store is None:
|
||||
return
|
||||
await _file_store.set(f"{_TASK_PREFIX}{task_id}", files, ttl=ttl)
|
||||
|
||||
|
||||
@@ -110,6 +130,8 @@ async def aget_task_files(task_id: UUID) -> dict[str, FileInput] | None:
|
||||
Returns:
|
||||
Dictionary of files or None if not found.
|
||||
"""
|
||||
if _file_store is None:
|
||||
return None
|
||||
result: dict[str, FileInput] | None = await _file_store.get(
|
||||
f"{_TASK_PREFIX}{task_id}"
|
||||
)
|
||||
@@ -122,6 +144,8 @@ async def aclear_task_files(task_id: UUID) -> None:
|
||||
Args:
|
||||
task_id: Unique identifier for the task.
|
||||
"""
|
||||
if _file_store is None:
|
||||
return
|
||||
await _file_store.delete(f"{_TASK_PREFIX}{task_id}")
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"max_tokens":4096,"messages":[{"role":"user","content":"Say hello in French"}],"model":"claude-sonnet-4-0","stream":false,"tool_choice":{"type":"tool","name":"structured_output"},"tools":[{"name":"structured_output","description":"Returns structured data according to the schema","input_schema":{"description":"Response model for greeting test.","properties":{"greeting":{"title":"Greeting","type":"string"},"language":{"title":"Language","type":"string"}},"required":["greeting","language"],"title":"GreetingResponse","type":"object"}}]}'
|
||||
body: '{"max_tokens":4096,"messages":[{"role":"user","content":"Say hello in French"}],"model":"claude-sonnet-4-0","stream":false,"tool_choice":{"type":"tool","name":"structured_output"},"tools":[{"name":"structured_output","description":"Output
|
||||
the structured response","input_schema":{"type":"object","description":"Response
|
||||
model for greeting test.","title":"GreetingResponse","properties":{"greeting":{"type":"string","title":"Greeting"},"language":{"type":"string","title":"Language"}},"additionalProperties":false,"required":["greeting","language"]}}]}'
|
||||
headers:
|
||||
User-Agent:
|
||||
- X-USER-AGENT-XXX
|
||||
@@ -13,7 +15,7 @@ interactions:
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '539'
|
||||
- '551'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
@@ -29,7 +31,7 @@ interactions:
|
||||
x-stainless-os:
|
||||
- X-STAINLESS-OS-XXX
|
||||
x-stainless-package-version:
|
||||
- 0.75.0
|
||||
- 0.76.0
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
@@ -42,7 +44,7 @@ interactions:
|
||||
uri: https://api.anthropic.com/v1/messages
|
||||
response:
|
||||
body:
|
||||
string: '{"model":"claude-sonnet-4-20250514","id":"msg_01XjvX2nCho1knuucbwwgCpw","type":"message","role":"assistant","content":[{"type":"tool_use","id":"toolu_019rfPRSDmBb7CyCTdGMv5rK","name":"structured_output","input":{"greeting":"Bonjour","language":"French"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":432,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0},"output_tokens":53,"service_tier":"standard"}}'
|
||||
string: '{"model":"claude-sonnet-4-20250514","id":"msg_01CKTyVmak15L5oQ36mv4sL9","type":"message","role":"assistant","content":[{"type":"tool_use","id":"toolu_0174BYmn6xiSnUwVhFD8S7EW","name":"structured_output","input":{"greeting":"Bonjour","language":"French"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":436,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0},"output_tokens":53,"service_tier":"standard"}}'
|
||||
headers:
|
||||
CF-RAY:
|
||||
- CF-RAY-XXX
|
||||
@@ -51,7 +53,7 @@ interactions:
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Mon, 01 Dec 2025 11:19:38 GMT
|
||||
- Mon, 26 Jan 2026 14:59:34 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Transfer-Encoding:
|
||||
@@ -82,12 +84,10 @@ interactions:
|
||||
- DYNAMIC
|
||||
request-id:
|
||||
- REQUEST-ID-XXX
|
||||
retry-after:
|
||||
- '24'
|
||||
strict-transport-security:
|
||||
- STS-XXX
|
||||
x-envoy-upstream-service-time:
|
||||
- '2101'
|
||||
- '968'
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
|
||||
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -1,108 +0,0 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"messages":[{"role":"user","content":"Say hi"}],"model":"gpt-4o-mini"}'
|
||||
headers:
|
||||
User-Agent:
|
||||
- X-USER-AGENT-XXX
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- ACCEPT-ENCODING-XXX
|
||||
authorization:
|
||||
- AUTHORIZATION-XXX
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '71'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- api.openai.com
|
||||
x-stainless-arch:
|
||||
- X-STAINLESS-ARCH-XXX
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- X-STAINLESS-OS-XXX
|
||||
x-stainless-package-version:
|
||||
- 1.83.0
|
||||
x-stainless-read-timeout:
|
||||
- X-STAINLESS-READ-TIMEOUT-XXX
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.13.0
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: "{\n \"id\": \"chatcmpl-D2HpUSxS5LeHwDTELElWlC5CDMzmr\",\n \"object\":
|
||||
\"chat.completion\",\n \"created\": 1769437564,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n
|
||||
\ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
|
||||
\"assistant\",\n \"content\": \"Hi there! How can I assist you today?\",\n
|
||||
\ \"refusal\": null,\n \"annotations\": []\n },\n \"logprobs\":
|
||||
null,\n \"finish_reason\": \"stop\"\n }\n ],\n \"usage\": {\n \"prompt_tokens\":
|
||||
9,\n \"completion_tokens\": 10,\n \"total_tokens\": 19,\n \"prompt_tokens_details\":
|
||||
{\n \"cached_tokens\": 0,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\":
|
||||
{\n \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\":
|
||||
0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"service_tier\":
|
||||
\"default\",\n \"system_fingerprint\": \"fp_29330a9688\"\n}\n"
|
||||
headers:
|
||||
CF-RAY:
|
||||
- CF-RAY-XXX
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Mon, 26 Jan 2026 14:26:05 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Set-Cookie:
|
||||
- SET-COOKIE-XXX
|
||||
Strict-Transport-Security:
|
||||
- STS-XXX
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- X-CONTENT-TYPE-XXX
|
||||
access-control-expose-headers:
|
||||
- ACCESS-CONTROL-XXX
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
openai-organization:
|
||||
- OPENAI-ORG-XXX
|
||||
openai-processing-ms:
|
||||
- '460'
|
||||
openai-project:
|
||||
- OPENAI-PROJECT-XXX
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
x-envoy-upstream-service-time:
|
||||
- '477'
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-requests:
|
||||
- X-RATELIMIT-LIMIT-REQUESTS-XXX
|
||||
x-ratelimit-limit-tokens:
|
||||
- X-RATELIMIT-LIMIT-TOKENS-XXX
|
||||
x-ratelimit-remaining-requests:
|
||||
- X-RATELIMIT-REMAINING-REQUESTS-XXX
|
||||
x-ratelimit-remaining-tokens:
|
||||
- X-RATELIMIT-REMAINING-TOKENS-XXX
|
||||
x-ratelimit-reset-requests:
|
||||
- X-RATELIMIT-RESET-REQUESTS-XXX
|
||||
x-ratelimit-reset-tokens:
|
||||
- X-RATELIMIT-RESET-TOKENS-XXX
|
||||
x-request-id:
|
||||
- X-REQUEST-ID-XXX
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
@@ -1,215 +0,0 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"messages":[{"role":"user","content":"Say hi"}],"model":"gpt-4o-mini"}'
|
||||
headers:
|
||||
User-Agent:
|
||||
- X-USER-AGENT-XXX
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- ACCEPT-ENCODING-XXX
|
||||
authorization:
|
||||
- AUTHORIZATION-XXX
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '71'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- api.openai.com
|
||||
x-stainless-arch:
|
||||
- X-STAINLESS-ARCH-XXX
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- X-STAINLESS-OS-XXX
|
||||
x-stainless-package-version:
|
||||
- 1.83.0
|
||||
x-stainless-read-timeout:
|
||||
- X-STAINLESS-READ-TIMEOUT-XXX
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.13.0
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: "{\n \"id\": \"chatcmpl-D2HpStmyOpe9DrthWBlDdMZfVMJ1u\",\n \"object\":
|
||||
\"chat.completion\",\n \"created\": 1769437562,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n
|
||||
\ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
|
||||
\"assistant\",\n \"content\": \"Hi! How can I assist you today?\",\n
|
||||
\ \"refusal\": null,\n \"annotations\": []\n },\n \"logprobs\":
|
||||
null,\n \"finish_reason\": \"stop\"\n }\n ],\n \"usage\": {\n \"prompt_tokens\":
|
||||
9,\n \"completion_tokens\": 9,\n \"total_tokens\": 18,\n \"prompt_tokens_details\":
|
||||
{\n \"cached_tokens\": 0,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\":
|
||||
{\n \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\":
|
||||
0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"service_tier\":
|
||||
\"default\",\n \"system_fingerprint\": \"fp_29330a9688\"\n}\n"
|
||||
headers:
|
||||
CF-RAY:
|
||||
- CF-RAY-XXX
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Mon, 26 Jan 2026 14:26:02 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Set-Cookie:
|
||||
- SET-COOKIE-XXX
|
||||
Strict-Transport-Security:
|
||||
- STS-XXX
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- X-CONTENT-TYPE-XXX
|
||||
access-control-expose-headers:
|
||||
- ACCESS-CONTROL-XXX
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
openai-organization:
|
||||
- OPENAI-ORG-XXX
|
||||
openai-processing-ms:
|
||||
- '415'
|
||||
openai-project:
|
||||
- OPENAI-PROJECT-XXX
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
x-envoy-upstream-service-time:
|
||||
- '434'
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-requests:
|
||||
- X-RATELIMIT-LIMIT-REQUESTS-XXX
|
||||
x-ratelimit-limit-tokens:
|
||||
- X-RATELIMIT-LIMIT-TOKENS-XXX
|
||||
x-ratelimit-remaining-requests:
|
||||
- X-RATELIMIT-REMAINING-REQUESTS-XXX
|
||||
x-ratelimit-remaining-tokens:
|
||||
- X-RATELIMIT-REMAINING-TOKENS-XXX
|
||||
x-ratelimit-reset-requests:
|
||||
- X-RATELIMIT-RESET-REQUESTS-XXX
|
||||
x-ratelimit-reset-tokens:
|
||||
- X-RATELIMIT-RESET-TOKENS-XXX
|
||||
x-request-id:
|
||||
- X-REQUEST-ID-XXX
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
- request:
|
||||
body: '{"messages":[{"role":"user","content":"Say bye"}],"model":"gpt-4o-mini"}'
|
||||
headers:
|
||||
User-Agent:
|
||||
- X-USER-AGENT-XXX
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- ACCEPT-ENCODING-XXX
|
||||
authorization:
|
||||
- AUTHORIZATION-XXX
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '72'
|
||||
content-type:
|
||||
- application/json
|
||||
cookie:
|
||||
- COOKIE-XXX
|
||||
host:
|
||||
- api.openai.com
|
||||
x-stainless-arch:
|
||||
- X-STAINLESS-ARCH-XXX
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- X-STAINLESS-OS-XXX
|
||||
x-stainless-package-version:
|
||||
- 1.83.0
|
||||
x-stainless-read-timeout:
|
||||
- X-STAINLESS-READ-TIMEOUT-XXX
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.13.0
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: "{\n \"id\": \"chatcmpl-D2HpS1DP0Xd3tmWt5PBincVrdU7yw\",\n \"object\":
|
||||
\"chat.completion\",\n \"created\": 1769437562,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n
|
||||
\ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
|
||||
\"assistant\",\n \"content\": \"Goodbye! If you have more questions
|
||||
in the future, feel free to reach out. Have a great day!\",\n \"refusal\":
|
||||
null,\n \"annotations\": []\n },\n \"logprobs\": null,\n
|
||||
\ \"finish_reason\": \"stop\"\n }\n ],\n \"usage\": {\n \"prompt_tokens\":
|
||||
9,\n \"completion_tokens\": 23,\n \"total_tokens\": 32,\n \"prompt_tokens_details\":
|
||||
{\n \"cached_tokens\": 0,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\":
|
||||
{\n \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\":
|
||||
0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"service_tier\":
|
||||
\"default\",\n \"system_fingerprint\": \"fp_29330a9688\"\n}\n"
|
||||
headers:
|
||||
CF-RAY:
|
||||
- CF-RAY-XXX
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Mon, 26 Jan 2026 14:26:03 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Strict-Transport-Security:
|
||||
- STS-XXX
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- X-CONTENT-TYPE-XXX
|
||||
access-control-expose-headers:
|
||||
- ACCESS-CONTROL-XXX
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
openai-organization:
|
||||
- OPENAI-ORG-XXX
|
||||
openai-processing-ms:
|
||||
- '964'
|
||||
openai-project:
|
||||
- OPENAI-PROJECT-XXX
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
x-envoy-upstream-service-time:
|
||||
- '979'
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-requests:
|
||||
- X-RATELIMIT-LIMIT-REQUESTS-XXX
|
||||
x-ratelimit-limit-tokens:
|
||||
- X-RATELIMIT-LIMIT-TOKENS-XXX
|
||||
x-ratelimit-remaining-requests:
|
||||
- X-RATELIMIT-REMAINING-REQUESTS-XXX
|
||||
x-ratelimit-remaining-tokens:
|
||||
- X-RATELIMIT-REMAINING-TOKENS-XXX
|
||||
x-ratelimit-reset-requests:
|
||||
- X-RATELIMIT-RESET-REQUESTS-XXX
|
||||
x-ratelimit-reset-tokens:
|
||||
- X-RATELIMIT-RESET-TOKENS-XXX
|
||||
x-request-id:
|
||||
- X-REQUEST-ID-XXX
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
@@ -1,143 +0,0 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"messages":[{"role":"user","content":"Say hi"}],"model":"gpt-4o-mini","stream":true,"stream_options":{"include_usage":true}}'
|
||||
headers:
|
||||
User-Agent:
|
||||
- X-USER-AGENT-XXX
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- ACCEPT-ENCODING-XXX
|
||||
authorization:
|
||||
- AUTHORIZATION-XXX
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '125'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- api.openai.com
|
||||
x-stainless-arch:
|
||||
- X-STAINLESS-ARCH-XXX
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- X-STAINLESS-OS-XXX
|
||||
x-stainless-package-version:
|
||||
- 1.83.0
|
||||
x-stainless-read-timeout:
|
||||
- X-STAINLESS-READ-TIMEOUT-XXX
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.13.0
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: 'data: {"id":"chatcmpl-D2HpUGTvIFKBsR9Xd6XRT4AuFXzbz","object":"chat.completion.chunk","created":1769437564,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_29330a9688","choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"rVIyGQF2E"}
|
||||
|
||||
|
||||
data: {"id":"chatcmpl-D2HpUGTvIFKBsR9Xd6XRT4AuFXzbz","object":"chat.completion.chunk","created":1769437564,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_29330a9688","choices":[{"index":0,"delta":{"content":"Hi"},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"ZGVqV7ZDm"}
|
||||
|
||||
|
||||
data: {"id":"chatcmpl-D2HpUGTvIFKBsR9Xd6XRT4AuFXzbz","object":"chat.completion.chunk","created":1769437564,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_29330a9688","choices":[{"index":0,"delta":{"content":"!"},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"vnfm7IxlIB"}
|
||||
|
||||
|
||||
data: {"id":"chatcmpl-D2HpUGTvIFKBsR9Xd6XRT4AuFXzbz","object":"chat.completion.chunk","created":1769437564,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_29330a9688","choices":[{"index":0,"delta":{"content":"
|
||||
How"},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"o8F35ZZ"}
|
||||
|
||||
|
||||
data: {"id":"chatcmpl-D2HpUGTvIFKBsR9Xd6XRT4AuFXzbz","object":"chat.completion.chunk","created":1769437564,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_29330a9688","choices":[{"index":0,"delta":{"content":"
|
||||
can"},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"kiBzGe3"}
|
||||
|
||||
|
||||
data: {"id":"chatcmpl-D2HpUGTvIFKBsR9Xd6XRT4AuFXzbz","object":"chat.completion.chunk","created":1769437564,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_29330a9688","choices":[{"index":0,"delta":{"content":"
|
||||
I"},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"cbGT2RWgx"}
|
||||
|
||||
|
||||
data: {"id":"chatcmpl-D2HpUGTvIFKBsR9Xd6XRT4AuFXzbz","object":"chat.completion.chunk","created":1769437564,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_29330a9688","choices":[{"index":0,"delta":{"content":"
|
||||
assist"},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"DtxR"}
|
||||
|
||||
|
||||
data: {"id":"chatcmpl-D2HpUGTvIFKBsR9Xd6XRT4AuFXzbz","object":"chat.completion.chunk","created":1769437564,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_29330a9688","choices":[{"index":0,"delta":{"content":"
|
||||
you"},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"6y6Co8J"}
|
||||
|
||||
|
||||
data: {"id":"chatcmpl-D2HpUGTvIFKBsR9Xd6XRT4AuFXzbz","object":"chat.completion.chunk","created":1769437564,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_29330a9688","choices":[{"index":0,"delta":{"content":"
|
||||
today"},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"SZOmm"}
|
||||
|
||||
|
||||
data: {"id":"chatcmpl-D2HpUGTvIFKBsR9Xd6XRT4AuFXzbz","object":"chat.completion.chunk","created":1769437564,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_29330a9688","choices":[{"index":0,"delta":{"content":"?"},"logprobs":null,"finish_reason":null}],"usage":null,"obfuscation":"s9Bc0HqlPg"}
|
||||
|
||||
|
||||
data: {"id":"chatcmpl-D2HpUGTvIFKBsR9Xd6XRT4AuFXzbz","object":"chat.completion.chunk","created":1769437564,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_29330a9688","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null,"obfuscation":"u9aar"}
|
||||
|
||||
|
||||
data: {"id":"chatcmpl-D2HpUGTvIFKBsR9Xd6XRT4AuFXzbz","object":"chat.completion.chunk","created":1769437564,"model":"gpt-4o-mini-2024-07-18","service_tier":"default","system_fingerprint":"fp_29330a9688","choices":[],"usage":{"prompt_tokens":9,"completion_tokens":9,"total_tokens":18,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}},"obfuscation":"5hudm8ySqh39"}
|
||||
|
||||
|
||||
data: [DONE]
|
||||
|
||||
|
||||
'
|
||||
headers:
|
||||
CF-RAY:
|
||||
- CF-RAY-XXX
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Type:
|
||||
- text/event-stream; charset=utf-8
|
||||
Date:
|
||||
- Mon, 26 Jan 2026 14:26:04 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Set-Cookie:
|
||||
- SET-COOKIE-XXX
|
||||
Strict-Transport-Security:
|
||||
- STS-XXX
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- X-CONTENT-TYPE-XXX
|
||||
access-control-expose-headers:
|
||||
- ACCESS-CONTROL-XXX
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
openai-organization:
|
||||
- OPENAI-ORG-XXX
|
||||
openai-processing-ms:
|
||||
- '260'
|
||||
openai-project:
|
||||
- OPENAI-PROJECT-XXX
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
x-envoy-upstream-service-time:
|
||||
- '275'
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-requests:
|
||||
- X-RATELIMIT-LIMIT-REQUESTS-XXX
|
||||
x-ratelimit-limit-tokens:
|
||||
- X-RATELIMIT-LIMIT-TOKENS-XXX
|
||||
x-ratelimit-remaining-requests:
|
||||
- X-RATELIMIT-REMAINING-REQUESTS-XXX
|
||||
x-ratelimit-remaining-tokens:
|
||||
- X-RATELIMIT-REMAINING-TOKENS-XXX
|
||||
x-ratelimit-reset-requests:
|
||||
- X-RATELIMIT-RESET-REQUESTS-XXX
|
||||
x-ratelimit-reset-tokens:
|
||||
- X-RATELIMIT-RESET-TOKENS-XXX
|
||||
x-request-id:
|
||||
- X-REQUEST-ID-XXX
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
@@ -272,3 +272,100 @@ class TestEmbeddingFactory:
|
||||
mock_build_from_provider.assert_called_once_with(mock_provider)
|
||||
assert result == mock_embedding_function
|
||||
mock_import.assert_not_called()
|
||||
|
||||
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||
def test_build_embedder_google_vertex_with_genai_model(self, mock_import):
|
||||
"""Test routing to Google Vertex provider with new genai model."""
|
||||
mock_provider_class = MagicMock()
|
||||
mock_provider_instance = MagicMock()
|
||||
mock_embedding_function = MagicMock()
|
||||
|
||||
mock_import.return_value = mock_provider_class
|
||||
mock_provider_class.return_value = mock_provider_instance
|
||||
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
|
||||
|
||||
config = {
|
||||
"provider": "google-vertex",
|
||||
"config": {
|
||||
"api_key": "test-google-api-key",
|
||||
"model_name": "gemini-embedding-001",
|
||||
},
|
||||
}
|
||||
|
||||
build_embedder(config)
|
||||
|
||||
mock_import.assert_called_once_with(
|
||||
"crewai.rag.embeddings.providers.google.vertex.VertexAIProvider"
|
||||
)
|
||||
mock_provider_class.assert_called_once()
|
||||
|
||||
call_kwargs = mock_provider_class.call_args.kwargs
|
||||
assert call_kwargs["api_key"] == "test-google-api-key"
|
||||
assert call_kwargs["model_name"] == "gemini-embedding-001"
|
||||
|
||||
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||
def test_build_embedder_google_vertex_with_legacy_model(self, mock_import):
|
||||
"""Test routing to Google Vertex provider with legacy textembedding-gecko model."""
|
||||
mock_provider_class = MagicMock()
|
||||
mock_provider_instance = MagicMock()
|
||||
mock_embedding_function = MagicMock()
|
||||
|
||||
mock_import.return_value = mock_provider_class
|
||||
mock_provider_class.return_value = mock_provider_instance
|
||||
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
|
||||
|
||||
config = {
|
||||
"provider": "google-vertex",
|
||||
"config": {
|
||||
"project_id": "my-gcp-project",
|
||||
"region": "us-central1",
|
||||
"model_name": "textembedding-gecko",
|
||||
},
|
||||
}
|
||||
|
||||
build_embedder(config)
|
||||
|
||||
mock_import.assert_called_once_with(
|
||||
"crewai.rag.embeddings.providers.google.vertex.VertexAIProvider"
|
||||
)
|
||||
mock_provider_class.assert_called_once()
|
||||
|
||||
call_kwargs = mock_provider_class.call_args.kwargs
|
||||
assert call_kwargs["project_id"] == "my-gcp-project"
|
||||
assert call_kwargs["region"] == "us-central1"
|
||||
assert call_kwargs["model_name"] == "textembedding-gecko"
|
||||
|
||||
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
|
||||
def test_build_embedder_google_vertex_with_location(self, mock_import):
|
||||
"""Test routing to Google Vertex provider with location parameter."""
|
||||
mock_provider_class = MagicMock()
|
||||
mock_provider_instance = MagicMock()
|
||||
mock_embedding_function = MagicMock()
|
||||
|
||||
mock_import.return_value = mock_provider_class
|
||||
mock_provider_class.return_value = mock_provider_instance
|
||||
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
|
||||
|
||||
config = {
|
||||
"provider": "google-vertex",
|
||||
"config": {
|
||||
"project_id": "my-gcp-project",
|
||||
"location": "europe-west1",
|
||||
"model_name": "gemini-embedding-001",
|
||||
"task_type": "RETRIEVAL_DOCUMENT",
|
||||
"output_dimensionality": 768,
|
||||
},
|
||||
}
|
||||
|
||||
build_embedder(config)
|
||||
|
||||
mock_import.assert_called_once_with(
|
||||
"crewai.rag.embeddings.providers.google.vertex.VertexAIProvider"
|
||||
)
|
||||
|
||||
call_kwargs = mock_provider_class.call_args.kwargs
|
||||
assert call_kwargs["project_id"] == "my-gcp-project"
|
||||
assert call_kwargs["location"] == "europe-west1"
|
||||
assert call_kwargs["model_name"] == "gemini-embedding-001"
|
||||
assert call_kwargs["task_type"] == "RETRIEVAL_DOCUMENT"
|
||||
assert call_kwargs["output_dimensionality"] == 768
|
||||
|
||||
@@ -0,0 +1,176 @@
|
||||
"""Integration tests for Google Vertex embeddings with Crew memory.
|
||||
|
||||
These tests make real API calls and use VCR to record/replay responses.
|
||||
"""
|
||||
|
||||
import os
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai import Agent, Crew, Task
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_vertex_ai_env():
|
||||
"""Set up environment for Vertex AI tests.
|
||||
|
||||
Sets GOOGLE_GENAI_USE_VERTEXAI=true to ensure the SDK uses the Vertex AI
|
||||
backend (aiplatform.googleapis.com) which matches the VCR cassettes.
|
||||
Also mocks GOOGLE_API_KEY if not already set.
|
||||
"""
|
||||
env_updates = {"GOOGLE_GENAI_USE_VERTEXAI": "true"}
|
||||
|
||||
# Add a mock API key if none exists
|
||||
if "GOOGLE_API_KEY" not in os.environ and "GEMINI_API_KEY" not in os.environ:
|
||||
env_updates["GOOGLE_API_KEY"] = "test-key"
|
||||
|
||||
with patch.dict(os.environ, env_updates):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def google_vertex_embedder_config():
|
||||
"""Fixture providing Google Vertex embedder configuration."""
|
||||
return {
|
||||
"provider": "google-vertex",
|
||||
"config": {
|
||||
"api_key": os.getenv("GOOGLE_API_KEY", "test-key"),
|
||||
"model_name": "gemini-embedding-001",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def simple_agent():
|
||||
"""Fixture providing a simple test agent."""
|
||||
return Agent(
|
||||
role="Research Assistant",
|
||||
goal="Help with research tasks",
|
||||
backstory="You are a helpful research assistant.",
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def simple_task(simple_agent):
|
||||
"""Fixture providing a simple test task."""
|
||||
return Task(
|
||||
description="Summarize the key points about artificial intelligence in one sentence.",
|
||||
expected_output="A one sentence summary about AI.",
|
||||
agent=simple_agent,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
@pytest.mark.timeout(120) # Longer timeout for VCR recording
|
||||
def test_crew_memory_with_google_vertex_embedder(
|
||||
google_vertex_embedder_config, simple_agent, simple_task
|
||||
) -> None:
|
||||
"""Test that Crew with memory=True works with google-vertex embedder and memory is used."""
|
||||
# Track memory events
|
||||
events: dict[str, list] = defaultdict(list)
|
||||
condition = threading.Condition()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||
def on_save_started(source, event):
|
||||
with condition:
|
||||
events["MemorySaveStartedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||
def on_save_completed(source, event):
|
||||
with condition:
|
||||
events["MemorySaveCompletedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
crew = Crew(
|
||||
agents=[simple_agent],
|
||||
tasks=[simple_task],
|
||||
memory=True,
|
||||
embedder=google_vertex_embedder_config,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
result = crew.kickoff()
|
||||
|
||||
assert result is not None
|
||||
assert result.raw is not None
|
||||
assert len(result.raw) > 0
|
||||
|
||||
with condition:
|
||||
success = condition.wait_for(
|
||||
lambda: len(events["MemorySaveCompletedEvent"]) >= 1,
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
assert success, "Timeout waiting for memory save events - memory may not be working"
|
||||
assert len(events["MemorySaveStartedEvent"]) >= 1, "No memory save started events"
|
||||
assert len(events["MemorySaveCompletedEvent"]) >= 1, "Memory save completed events"
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
@pytest.mark.timeout(120)
|
||||
def test_crew_memory_with_google_vertex_project_id(simple_agent, simple_task) -> None:
|
||||
"""Test Crew memory with Google Vertex using project_id authentication."""
|
||||
project_id = os.getenv("GOOGLE_CLOUD_PROJECT")
|
||||
if not project_id:
|
||||
pytest.skip("GOOGLE_CLOUD_PROJECT environment variable not set")
|
||||
|
||||
# Track memory events
|
||||
events: dict[str, list] = defaultdict(list)
|
||||
condition = threading.Condition()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||
def on_save_started(source, event):
|
||||
with condition:
|
||||
events["MemorySaveStartedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||
def on_save_completed(source, event):
|
||||
with condition:
|
||||
events["MemorySaveCompletedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
embedder_config = {
|
||||
"provider": "google-vertex",
|
||||
"config": {
|
||||
"project_id": project_id,
|
||||
"location": "us-central1",
|
||||
"model_name": "gemini-embedding-001",
|
||||
},
|
||||
}
|
||||
|
||||
crew = Crew(
|
||||
agents=[simple_agent],
|
||||
tasks=[simple_task],
|
||||
memory=True,
|
||||
embedder=embedder_config,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
result = crew.kickoff()
|
||||
|
||||
# Verify basic result
|
||||
assert result is not None
|
||||
assert result.raw is not None
|
||||
|
||||
# Wait for memory save events
|
||||
with condition:
|
||||
success = condition.wait_for(
|
||||
lambda: len(events["MemorySaveCompletedEvent"]) >= 1,
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
# Verify memory was actually used
|
||||
assert success, "Timeout waiting for memory save events - memory may not be working"
|
||||
assert len(events["MemorySaveStartedEvent"]) >= 1, "No memory save started events"
|
||||
assert len(events["MemorySaveCompletedEvent"]) >= 1, "No memory save completed events"
|
||||
@@ -217,7 +217,6 @@ class TestCrewKickoffStreaming:
|
||||
LLMStreamChunkEvent(
|
||||
type="llm_stream_chunk",
|
||||
chunk="Hello ",
|
||||
call_id="test-call-id",
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
@@ -225,7 +224,6 @@ class TestCrewKickoffStreaming:
|
||||
LLMStreamChunkEvent(
|
||||
type="llm_stream_chunk",
|
||||
chunk="World!",
|
||||
call_id="test-call-id",
|
||||
),
|
||||
)
|
||||
return mock_output
|
||||
@@ -286,7 +284,6 @@ class TestCrewKickoffStreaming:
|
||||
LLMStreamChunkEvent(
|
||||
type="llm_stream_chunk",
|
||||
chunk="",
|
||||
call_id="test-call-id",
|
||||
tool_call=ToolCall(
|
||||
id="call-123",
|
||||
function=FunctionCall(
|
||||
@@ -367,7 +364,6 @@ class TestCrewKickoffStreamingAsync:
|
||||
LLMStreamChunkEvent(
|
||||
type="llm_stream_chunk",
|
||||
chunk="Async ",
|
||||
call_id="test-call-id",
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
@@ -375,7 +371,6 @@ class TestCrewKickoffStreamingAsync:
|
||||
LLMStreamChunkEvent(
|
||||
type="llm_stream_chunk",
|
||||
chunk="Stream!",
|
||||
call_id="test-call-id",
|
||||
),
|
||||
)
|
||||
return mock_output
|
||||
@@ -456,7 +451,6 @@ class TestFlowKickoffStreaming:
|
||||
LLMStreamChunkEvent(
|
||||
type="llm_stream_chunk",
|
||||
chunk="Flow ",
|
||||
call_id="test-call-id",
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
@@ -464,7 +458,6 @@ class TestFlowKickoffStreaming:
|
||||
LLMStreamChunkEvent(
|
||||
type="llm_stream_chunk",
|
||||
chunk="output!",
|
||||
call_id="test-call-id",
|
||||
),
|
||||
)
|
||||
return "done"
|
||||
@@ -552,7 +545,6 @@ class TestFlowKickoffStreamingAsync:
|
||||
LLMStreamChunkEvent(
|
||||
type="llm_stream_chunk",
|
||||
chunk="Async flow ",
|
||||
call_id="test-call-id",
|
||||
),
|
||||
)
|
||||
await asyncio.sleep(0.01)
|
||||
@@ -561,7 +553,6 @@ class TestFlowKickoffStreamingAsync:
|
||||
LLMStreamChunkEvent(
|
||||
type="llm_stream_chunk",
|
||||
chunk="stream!",
|
||||
call_id="test-call-id",
|
||||
),
|
||||
)
|
||||
await asyncio.sleep(0.01)
|
||||
@@ -695,7 +686,6 @@ class TestStreamingEdgeCases:
|
||||
type="llm_stream_chunk",
|
||||
chunk="Task 1",
|
||||
task_name="First task",
|
||||
call_id="test-call-id",
|
||||
),
|
||||
)
|
||||
return mock_output
|
||||
|
||||
@@ -984,8 +984,8 @@ def test_streaming_fallback_to_non_streaming():
|
||||
def mock_call(messages, tools=None, callbacks=None, available_functions=None):
|
||||
nonlocal fallback_called
|
||||
# Emit a couple of chunks to simulate partial streaming
|
||||
crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Test chunk 1", response_id="Id", call_id="test-call-id"))
|
||||
crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Test chunk 2", response_id="Id", call_id="test-call-id"))
|
||||
crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Test chunk 1", response_id = "Id"))
|
||||
crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="Test chunk 2", response_id = "Id"))
|
||||
|
||||
# Mark that fallback would be called
|
||||
fallback_called = True
|
||||
@@ -1041,7 +1041,7 @@ def test_streaming_empty_response_handling():
|
||||
def mock_call(messages, tools=None, callbacks=None, available_functions=None):
|
||||
# Emit a few empty chunks
|
||||
for _ in range(3):
|
||||
crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="", response_id="id", call_id="test-call-id"))
|
||||
crewai_event_bus.emit(llm, event=LLMStreamChunkEvent(chunk="",response_id="id"))
|
||||
|
||||
# Return the default message for empty responses
|
||||
return "I apologize, but I couldn't generate a proper response. Please try again or rephrase your request."
|
||||
@@ -1280,105 +1280,6 @@ def test_llm_emits_event_with_lite_agent():
|
||||
assert set(all_agent_id) == {str(agent.id)}
|
||||
|
||||
|
||||
# ----------- CALL_ID CORRELATION TESTS -----------
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_llm_call_events_share_call_id():
|
||||
"""All events from a single LLM call should share the same call_id."""
|
||||
import uuid
|
||||
|
||||
events = []
|
||||
condition = threading.Condition()
|
||||
|
||||
@crewai_event_bus.on(LLMCallStartedEvent)
|
||||
def on_start(source, event):
|
||||
with condition:
|
||||
events.append(event)
|
||||
condition.notify()
|
||||
|
||||
@crewai_event_bus.on(LLMCallCompletedEvent)
|
||||
def on_complete(source, event):
|
||||
with condition:
|
||||
events.append(event)
|
||||
condition.notify()
|
||||
|
||||
llm = LLM(model="gpt-4o-mini")
|
||||
llm.call("Say hi")
|
||||
|
||||
with condition:
|
||||
success = condition.wait_for(lambda: len(events) >= 2, timeout=10)
|
||||
assert success, "Timeout waiting for LLM events"
|
||||
|
||||
# Behavior: all events from the call share the same call_id
|
||||
assert len(events) == 2
|
||||
assert events[0].call_id == events[1].call_id
|
||||
# call_id should be a valid UUID
|
||||
uuid.UUID(events[0].call_id)
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_streaming_chunks_share_call_id_with_call():
|
||||
"""Streaming chunks should share call_id with started/completed events."""
|
||||
events = []
|
||||
condition = threading.Condition()
|
||||
|
||||
@crewai_event_bus.on(LLMCallStartedEvent)
|
||||
def on_start(source, event):
|
||||
with condition:
|
||||
events.append(event)
|
||||
condition.notify()
|
||||
|
||||
@crewai_event_bus.on(LLMStreamChunkEvent)
|
||||
def on_chunk(source, event):
|
||||
with condition:
|
||||
events.append(event)
|
||||
condition.notify()
|
||||
|
||||
@crewai_event_bus.on(LLMCallCompletedEvent)
|
||||
def on_complete(source, event):
|
||||
with condition:
|
||||
events.append(event)
|
||||
condition.notify()
|
||||
|
||||
llm = LLM(model="gpt-4o-mini", stream=True)
|
||||
llm.call("Say hi")
|
||||
|
||||
with condition:
|
||||
# Wait for at least started, some chunks, and completed
|
||||
success = condition.wait_for(lambda: len(events) >= 3, timeout=10)
|
||||
assert success, "Timeout waiting for streaming events"
|
||||
|
||||
# Behavior: all events (started, chunks, completed) share the same call_id
|
||||
call_ids = {e.call_id for e in events}
|
||||
assert len(call_ids) == 1
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_separate_llm_calls_have_different_call_ids():
|
||||
"""Different LLM calls should have different call_ids."""
|
||||
call_ids = []
|
||||
condition = threading.Condition()
|
||||
|
||||
@crewai_event_bus.on(LLMCallStartedEvent)
|
||||
def on_start(source, event):
|
||||
with condition:
|
||||
call_ids.append(event.call_id)
|
||||
condition.notify()
|
||||
|
||||
llm = LLM(model="gpt-4o-mini")
|
||||
llm.call("Say hi")
|
||||
llm.call("Say bye")
|
||||
|
||||
with condition:
|
||||
success = condition.wait_for(lambda: len(call_ids) >= 2, timeout=10)
|
||||
assert success, "Timeout waiting for LLM call events"
|
||||
|
||||
# Behavior: each call has its own call_id
|
||||
assert len(call_ids) == 2
|
||||
assert call_ids[0] != call_ids[1]
|
||||
|
||||
|
||||
# ----------- HUMAN FEEDBACK EVENTS -----------
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
"""CrewAI development tools."""
|
||||
|
||||
__version__ = "1.8.1"
|
||||
__version__ = "1.9.0"
|
||||
|
||||
8 uv.lock generated
@@ -310,7 +310,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "anthropic"
|
||||
version = "0.71.1"
|
||||
version = "0.73.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "anyio" },
|
||||
@@ -322,9 +322,9 @@ dependencies = [
|
||||
{ name = "sniffio" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/05/4b/19620875841f692fdc35eb58bf0201c8ad8c47b8443fecbf1b225312175b/anthropic-0.71.1.tar.gz", hash = "sha256:a77d156d3e7d318b84681b59823b2dee48a8ac508a3e54e49f0ab0d074e4b0da", size = 493294, upload-time = "2025-10-28T17:28:42.213Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f0/07/f550112c3f5299d02f06580577f602e8a112b1988ad7c98ac1a8f7292d7e/anthropic-0.73.0.tar.gz", hash = "sha256:30f0d7d86390165f86af6ca7c3041f8720bb2e1b0e12a44525c8edfdbd2c5239", size = 425168, upload-time = "2025-11-14T18:47:52.635Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/4b/68/b2f988b13325f9ac9921b1e87f0b7994468014e1b5bd3bdbd2472f5baf45/anthropic-0.71.1-py3-none-any.whl", hash = "sha256:6ca6c579f0899a445faeeed9c0eb97aa4bdb751196262f9ccc96edfc0bb12679", size = 355020, upload-time = "2025-10-28T17:28:40.653Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/15/b1/5d4d3f649e151e58dc938cf19c4d0cd19fca9a986879f30fea08a7b17138/anthropic-0.73.0-py3-none-any.whl", hash = "sha256:0d56cd8b3ca3fea9c9b5162868bdfd053fbc189b8b56d4290bd2d427b56db769", size = 367839, upload-time = "2025-11-14T18:47:51.195Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1276,7 +1276,7 @@ requires-dist = [
|
||||
{ name = "aiobotocore", marker = "extra == 'aws'", specifier = "~=2.25.2" },
|
||||
{ name = "aiocache", extras = ["memcached", "redis"], marker = "extra == 'a2a'", specifier = "~=0.12.3" },
|
||||
{ name = "aiosqlite", specifier = "~=0.21.0" },
|
||||
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = "~=0.71.0" },
|
||||
{ name = "anthropic", marker = "extra == 'anthropic'", specifier = "~=0.73.0" },
|
||||
{ name = "appdirs", specifier = "~=1.4.4" },
|
||||
{ name = "azure-ai-inference", marker = "extra == 'azure-ai-inference'", specifier = "~=1.0.0b9" },
|
||||
{ name = "boto3", marker = "extra == 'aws'", specifier = "~=1.40.38" },
|
||||
|
||||