mirror of https://github.com/crewAIInc/crewAI.git
synced 2026-01-25 16:18:13 +00:00

- add input_files parameter to Crew.kickoff(), Flow.kickoff(), Task, and Agent.kickoff()
- add provider-specific file uploaders for OpenAI, Anthropic, Gemini, and Bedrock
- add file type detection, constraint validation, and automatic format conversion
- add URL file source support for multimodal content
- add streaming uploads for large files
- add prompt caching support for Anthropic
- add OpenAI Responses API support

1229 lines · 46 KiB · Python

from __future__ import annotations

import base64
import json
import logging
import os
import re
from typing import TYPE_CHECKING, Any, Literal, cast

from pydantic import BaseModel

from crewai.events.types.llm_events import LLMCallType
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.exceptions.context_window_exceeding_exception import (
    LLMContextLengthExceededError,
)
from crewai.utilities.types import LLMMessage


if TYPE_CHECKING:
    from crewai.llms.hooks.base import BaseInterceptor


try:
    from google import genai
    from google.genai import types
    from google.genai.errors import APIError
    from google.genai.types import GenerateContentResponse
except ImportError:
    raise ImportError(
        'Google Gen AI native provider not available, to install: uv add "crewai[google-genai]"'
    ) from None


class GeminiCompletion(BaseLLM):
    """Google Gemini native completion implementation.

    This class provides direct integration with the Google Gen AI Python SDK,
    offering native function calling, streaming support, and proper Gemini formatting.
    """

    def __init__(
        self,
        model: str = "gemini-2.0-flash-001",
        api_key: str | None = None,
        project: str | None = None,
        location: str | None = None,
        temperature: float | None = None,
        top_p: float | None = None,
        top_k: int | None = None,
        max_output_tokens: int | None = None,
        stop_sequences: list[str] | None = None,
        stream: bool = False,
        safety_settings: dict[str, Any] | None = None,
        client_params: dict[str, Any] | None = None,
        interceptor: BaseInterceptor[Any, Any] | None = None,
        use_vertexai: bool | None = None,
        **kwargs: Any,
    ):
        """Initialize Google Gemini chat completion client.

        Args:
            model: Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro')
            api_key: Google API key for Gemini API authentication.
                Defaults to the GOOGLE_API_KEY or GEMINI_API_KEY env var.
                NOTE: Cannot be used with Vertex AI (project parameter). Use the Gemini API instead.
            project: Google Cloud project ID for Vertex AI with ADC authentication.
                Requires Application Default Credentials (gcloud auth application-default login).
                NOTE: Vertex AI does NOT support API keys, only OAuth2/ADC.
                If both api_key and project are set, api_key takes precedence.
            location: Google Cloud location (for Vertex AI with ADC, defaults to 'us-central1')
            temperature: Sampling temperature (0-2)
            top_p: Nucleus sampling parameter
            top_k: Top-k sampling parameter
            max_output_tokens: Maximum tokens in response
            stop_sequences: Stop sequences
            stream: Enable streaming responses
            safety_settings: Safety filter settings
            client_params: Additional parameters to pass to the Google Gen AI Client constructor.
                Supports parameters like http_options, credentials, debug_config, etc.
            interceptor: HTTP interceptor (not yet supported for Gemini).
            use_vertexai: Whether to use Vertex AI instead of the Gemini API.
                - True: Use Vertex AI (with ADC, or Express mode with an API key)
                - False: Use the Gemini API (explicitly overrides the env var)
                - None (default): Check the GOOGLE_GENAI_USE_VERTEXAI env var
                When using Vertex AI with an API key (Express mode), http_options with
                api_version="v1" is automatically configured.
            **kwargs: Additional parameters
        """
        if interceptor is not None:
            raise NotImplementedError(
                "HTTP interceptors are not yet supported for the Google Gemini provider. "
                "Interceptors are currently supported for the OpenAI and Anthropic providers only."
            )

        super().__init__(
            model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
        )

        # Store client params for later use
        self.client_params = client_params or {}

        # Get API configuration with environment variable fallbacks
        self.api_key = (
            api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
        )
        self.project = project or os.getenv("GOOGLE_CLOUD_PROJECT")
        self.location = location or os.getenv("GOOGLE_CLOUD_LOCATION") or "us-central1"

        if use_vertexai is None:
            use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"

        self.client = self._initialize_client(use_vertexai)

        # Store completion parameters
        self.top_p = top_p
        self.top_k = top_k
        self.max_output_tokens = max_output_tokens
        self.stream = stream
        self.safety_settings = safety_settings or {}
        self.stop_sequences = stop_sequences or []
        self.tools: list[dict[str, Any]] | None = None

        # Model-specific settings
        version_match = re.search(r"gemini-(\d+(?:\.\d+)?)", model.lower())
        self.supports_tools = bool(
            version_match and float(version_match.group(1)) >= 1.5
        )

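    # Illustrative usage (a sketch, not part of the class): the three auth modes
    # described in the docstring above. Model names and credentials are
    # placeholder values.
    #
    #     # Gemini API with an explicit API key:
    #     llm = GeminiCompletion(model="gemini-2.0-flash-001", api_key="AIza...")
    #
    #     # Vertex AI with Application Default Credentials:
    #     llm = GeminiCompletion(project="my-gcp-project", location="us-central1")
    #
    #     # Vertex AI Express mode (API key + use_vertexai=True):
    #     llm = GeminiCompletion(api_key="AIza...", use_vertexai=True)
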
    @property
    def stop(self) -> list[str]:
        """Get stop sequences sent to the API."""
        return self.stop_sequences

    @stop.setter
    def stop(self, value: list[str] | str | None) -> None:
        """Set stop sequences.

        Synchronizes stop_sequences to ensure values set by CrewAgentExecutor
        are properly sent to the Gemini API.

        Args:
            value: Stop sequences as a list, single string, or None
        """
        if value is None:
            self.stop_sequences = []
        elif isinstance(value, str):
            self.stop_sequences = [value]
        elif isinstance(value, list):
            self.stop_sequences = value
        else:
            self.stop_sequences = []

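    # For example, the setter normalizes every accepted input to a list
    # (values below are illustrative):
    #     llm.stop = "\nObservation:"        -> stop_sequences == ["\nObservation:"]
    #     llm.stop = ["\nThought:", "STOP"]  -> stop_sequences == ["\nThought:", "STOP"]
    #     llm.stop = None                    -> stop_sequences == []
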
    def _initialize_client(self, use_vertexai: bool = False) -> genai.Client:
        """Initialize the Google Gen AI client with proper parameter handling.

        Args:
            use_vertexai: Whether to use Vertex AI (from environment variable)

        Returns:
            Initialized Google Gen AI Client

        Note:
            Google Gen AI SDK has two distinct endpoints with different auth requirements:
            - Gemini API (generativelanguage.googleapis.com): Supports API key authentication
            - Vertex AI (aiplatform.googleapis.com): Only supports OAuth2/ADC, NO API keys

            When vertexai=True is set, it routes to aiplatform.googleapis.com which rejects
            API keys. Use the Gemini API endpoint for API key authentication instead.
        """
        client_params = {}

        if self.client_params:
            client_params.update(self.client_params)

        # Determine authentication mode based on available credentials
        has_api_key = bool(self.api_key)
        has_project = bool(self.project)

        if has_api_key and has_project:
            logging.warning(
                "Both API key and project provided. Using API key authentication. "
                "Project/location parameters are ignored when using API keys. "
                "To use Vertex AI with ADC, remove the api_key parameter."
            )
            has_project = False

        # Vertex AI with ADC (project without API key)
        if (use_vertexai or has_project) and not has_api_key:
            client_params.update(
                {
                    "vertexai": True,
                    "project": self.project,
                    "location": self.location,
                }
            )

        # API key authentication (works with both Gemini API and Vertex AI Express)
        elif has_api_key:
            client_params["api_key"] = self.api_key

            # Vertex AI Express mode: API key + vertexai=True + http_options with api_version="v1"
            # See: https://cloud.google.com/vertex-ai/generative-ai/docs/start/quickstart?usertype=apikey
            if use_vertexai:
                client_params["vertexai"] = True
                client_params["http_options"] = types.HttpOptions(api_version="v1")
            else:
                # This ensures we use the Gemini API (generativelanguage.googleapis.com)
                client_params["vertexai"] = False

            # Clean up project/location (not allowed with API key)
            client_params.pop("project", None)
            client_params.pop("location", None)

        else:
            try:
                return genai.Client(**client_params)
            except Exception as e:
                raise ValueError(
                    "Authentication required. Provide one of:\n"
                    " 1. API key via GOOGLE_API_KEY or GEMINI_API_KEY environment variable\n"
                    "    (use_vertexai=True is optional for Vertex AI with API key)\n"
                    " 2. For Vertex AI with ADC: Set GOOGLE_CLOUD_PROJECT and run:\n"
                    "    gcloud auth application-default login\n"
                    " 3. Pass api_key parameter directly to LLM constructor\n"
                ) from e

        return genai.Client(**client_params)

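    # Sketch of the parameter resolution above, with hypothetical inputs:
    #     api_key set, use_vertexai=False -> Client(api_key=..., vertexai=False)
    #     api_key set, use_vertexai=True  -> Client(api_key=..., vertexai=True,
    #                                               http_options=HttpOptions(api_version="v1"))
    #     project set, no api_key         -> Client(vertexai=True, project=..., location=...)
    #     neither                         -> Client() relying on ambient SDK config,
    #                                        else ValueError with setup guidance
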
    def _get_client_params(self) -> dict[str, Any]:
        """Get client parameters for compatibility with base class.

        Note: This method is kept for compatibility but the Google Gen AI SDK
        uses a different initialization pattern via the Client constructor.
        """
        params = {}

        if (
            hasattr(self, "client")
            and hasattr(self.client, "vertexai")
            and self.client.vertexai
        ):
            # Vertex AI configuration
            params.update(
                {
                    "vertexai": True,
                    "project": self.project,
                    "location": self.location,
                }
            )
            if self.api_key:
                params["api_key"] = self.api_key
        elif self.api_key:
            params["api_key"] = self.api_key

        if self.client_params:
            params.update(self.client_params)

        return params

    def call(
        self,
        messages: str | list[LLMMessage],
        tools: list[dict[str, Any]] | None = None,
        callbacks: list[Any] | None = None,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """Call Google Gemini generate content API.

        Args:
            messages: Input messages for the chat completion
            tools: List of tool/function definitions
            callbacks: Callback functions (not used as token counts are handled by the response)
            available_functions: Available functions for tool calling
            from_task: Task that initiated the call
            from_agent: Agent that initiated the call
            response_model: Response model to use.

        Returns:
            Chat completion response or tool call result
        """
        try:
            self._emit_call_started_event(
                messages=messages,
                tools=tools,
                callbacks=callbacks,
                available_functions=available_functions,
                from_task=from_task,
                from_agent=from_agent,
            )
            self.tools = tools

            formatted_content, system_instruction = self._format_messages_for_gemini(
                messages
            )

            messages_for_hooks = self._convert_contents_to_dict(formatted_content)

            if not self._invoke_before_llm_call_hooks(messages_for_hooks, from_agent):
                raise ValueError("LLM call blocked by before_llm_call hook")

            config = self._prepare_generation_config(
                system_instruction, tools, response_model
            )

            if self.stream:
                return self._handle_streaming_completion(
                    formatted_content,
                    config,
                    available_functions,
                    from_task,
                    from_agent,
                    response_model,
                )

            return self._handle_completion(
                formatted_content,
                config,
                available_functions,
                from_task,
                from_agent,
                response_model,
            )

        except APIError as e:
            error_msg = f"Google Gemini API error: {e.code} - {e.message}"
            logging.error(error_msg)
            self._emit_call_failed_event(
                error=error_msg, from_task=from_task, from_agent=from_agent
            )
            raise
        except Exception as e:
            error_msg = f"Google Gemini API call failed: {e!s}"
            logging.error(error_msg)
            self._emit_call_failed_event(
                error=error_msg, from_task=from_task, from_agent=from_agent
            )
            raise

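    # Illustrative call (values are placeholders): a plain string or an
    # OpenAI-style message list is accepted; tool execution and streaming are
    # dispatched to the helpers below.
    #
    #     result = llm.call(
    #         [
    #             {"role": "system", "content": "You are terse."},
    #             {"role": "user", "content": "Name one Gemini model."},
    #         ]
    #     )
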
    async def acall(
        self,
        messages: str | list[LLMMessage],
        tools: list[dict[str, Any]] | None = None,
        callbacks: list[Any] | None = None,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """Async call to Google Gemini generate content API.

        Args:
            messages: Input messages for the chat completion
            tools: List of tool/function definitions
            callbacks: Callback functions (not used as token counts are handled by the response)
            available_functions: Available functions for tool calling
            from_task: Task that initiated the call
            from_agent: Agent that initiated the call
            response_model: Response model to use.

        Returns:
            Chat completion response or tool call result
        """
        try:
            self._emit_call_started_event(
                messages=messages,
                tools=tools,
                callbacks=callbacks,
                available_functions=available_functions,
                from_task=from_task,
                from_agent=from_agent,
            )
            self.tools = tools

            formatted_content, system_instruction = self._format_messages_for_gemini(
                messages
            )

            config = self._prepare_generation_config(
                system_instruction, tools, response_model
            )

            if self.stream:
                return await self._ahandle_streaming_completion(
                    formatted_content,
                    config,
                    available_functions,
                    from_task,
                    from_agent,
                    response_model,
                )

            return await self._ahandle_completion(
                formatted_content,
                config,
                available_functions,
                from_task,
                from_agent,
                response_model,
            )

        except APIError as e:
            error_msg = f"Google Gemini API error: {e.code} - {e.message}"
            logging.error(error_msg)
            self._emit_call_failed_event(
                error=error_msg, from_task=from_task, from_agent=from_agent
            )
            raise
        except Exception as e:
            error_msg = f"Google Gemini API call failed: {e!s}"
            logging.error(error_msg)
            self._emit_call_failed_event(
                error=error_msg, from_task=from_task, from_agent=from_agent
            )
            raise

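    # Async variant, assuming an event loop is available (prompt is a placeholder):
    #
    #     import asyncio
    #     result = asyncio.run(llm.acall("Summarize this in one line."))
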
    def _prepare_generation_config(
        self,
        system_instruction: str | None = None,
        tools: list[dict[str, Any]] | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> types.GenerateContentConfig:
        """Prepare generation config for Google Gemini API.

        Args:
            system_instruction: System instruction for the model
            tools: Tool definitions
            response_model: Pydantic model for structured output

        Returns:
            GenerateContentConfig object for Gemini API
        """
        self.tools = tools
        config_params: dict[str, Any] = {}

        # Add system instruction if present
        if system_instruction:
            # Convert system instruction to Content format
            system_content = types.Content(
                role="user", parts=[types.Part.from_text(text=system_instruction)]
            )
            config_params["system_instruction"] = system_content

        # Add generation config parameters
        if self.temperature is not None:
            config_params["temperature"] = self.temperature
        if self.top_p is not None:
            config_params["top_p"] = self.top_p
        if self.top_k is not None:
            config_params["top_k"] = self.top_k
        if self.max_output_tokens is not None:
            config_params["max_output_tokens"] = self.max_output_tokens
        if self.stop_sequences:
            config_params["stop_sequences"] = self.stop_sequences

        if response_model:
            config_params["response_mime_type"] = "application/json"
            config_params["response_schema"] = response_model.model_json_schema()

        # Handle tools for supported models
        if tools and self.supports_tools:
            config_params["tools"] = self._convert_tools_for_interference(tools)

        if self.safety_settings:
            config_params["safety_settings"] = self.safety_settings

        return types.GenerateContentConfig(**config_params)

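    # For structured output, a Pydantic model is translated into Gemini's JSON
    # schema config. A minimal sketch (the Answer model is illustrative):
    #
    #     class Answer(BaseModel):
    #         city: str
    #         population: int
    #
    #     # _prepare_generation_config(..., response_model=Answer) then sets:
    #     #   response_mime_type="application/json"
    #     #   response_schema=Answer.model_json_schema()
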
    def _convert_tools_for_interference(  # type: ignore[override]
        self, tools: list[dict[str, Any]]
    ) -> list[types.Tool]:
        """Convert CrewAI tool format to Gemini function declaration format."""
        gemini_tools = []

        from crewai.llms.providers.utils.common import safe_tool_conversion

        for tool in tools:
            name, description, parameters = safe_tool_conversion(tool, "Gemini")

            function_declaration = types.FunctionDeclaration(
                name=name,
                description=description,
                parameters=parameters if parameters else None,
            )

            gemini_tool = types.Tool(function_declarations=[function_declaration])
            gemini_tools.append(gemini_tool)

        return gemini_tools

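    # Sketch of the conversion (tool name and fields are illustrative): a CrewAI
    # tool dict such as
    #     {"type": "function", "function": {"name": "get_weather",
    #      "description": "...", "parameters": {...}}}
    # becomes
    #     types.Tool(function_declarations=[types.FunctionDeclaration(
    #         name="get_weather", description="...", parameters={...})])
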
    def _format_messages_for_gemini(
        self, messages: str | list[LLMMessage]
    ) -> tuple[list[types.Content], str | None]:
        """Format messages for Gemini API.

        Gemini has specific requirements:
        - System messages are a separate system_instruction
        - Content is organized as Content objects with Parts
        - Roles are 'user' and 'model' (not 'assistant')

        Args:
            messages: Input messages

        Returns:
            Tuple of (formatted_contents, system_instruction)
        """
        # Use base class formatting first
        base_formatted = super()._format_messages(messages)

        contents: list[types.Content] = []
        system_instruction: str | None = None

        for message in base_formatted:
            role = message["role"]
            content = message["content"]

            # Build parts list from content
            parts: list[types.Part] = []
            if isinstance(content, list):
                for item in content:
                    if isinstance(item, dict):
                        if "text" in item:
                            parts.append(types.Part.from_text(text=str(item["text"])))
                        elif "inlineData" in item:
                            inline = item["inlineData"]
                            parts.append(
                                types.Part.from_bytes(
                                    data=base64.b64decode(inline["data"]),
                                    mime_type=inline["mimeType"],
                                )
                            )
                    else:
                        parts.append(types.Part.from_text(text=str(item)))
            else:
                parts.append(types.Part.from_text(text=str(content) if content else ""))

            # Plain-text view of the parts, shared by the role branches below
            text_content = " ".join(
                p.text for p in parts if hasattr(p, "text") and p.text
            )

            if role == "system":
                # Extract system instruction - Gemini handles it separately
                if system_instruction:
                    system_instruction += f"\n\n{text_content}"
                else:
                    system_instruction = text_content
            elif role == "tool":
                tool_call_id = message.get("tool_call_id")
                if not tool_call_id:
                    raise ValueError("Tool message missing required tool_call_id")

                tool_name = message.get("name", "")

                response_data: dict[str, Any]
                try:
                    response_data = json.loads(text_content) if text_content else {}
                except (json.JSONDecodeError, TypeError):
                    response_data = {"result": text_content}

                function_response_part = types.Part.from_function_response(
                    name=tool_name, response=response_data
                )
                contents.append(
                    types.Content(role="user", parts=[function_response_part])
                )
            elif role == "assistant" and message.get("tool_calls"):
                parts = []

                if text_content:
                    parts.append(types.Part.from_text(text=text_content))

                tool_calls: list[dict[str, Any]] = message.get("tool_calls") or []
                for tool_call in tool_calls:
                    func: dict[str, Any] = tool_call.get("function") or {}
                    func_name: str = str(func.get("name") or "")
                    func_args_raw: str | dict[str, Any] = func.get("arguments") or {}

                    func_args: dict[str, Any]
                    if isinstance(func_args_raw, str):
                        try:
                            func_args = (
                                json.loads(func_args_raw) if func_args_raw else {}
                            )
                        except (json.JSONDecodeError, TypeError):
                            func_args = {}
                    else:
                        func_args = func_args_raw

                    parts.append(
                        types.Part.from_function_call(name=func_name, args=func_args)
                    )

                contents.append(types.Content(role="model", parts=parts))
            else:
                # Convert role for Gemini (assistant -> model)
                gemini_role = "model" if role == "assistant" else "user"

                # Create Content object
                gemini_content = types.Content(role=gemini_role, parts=parts)
                contents.append(gemini_content)

        return contents, system_instruction

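    # Example mapping performed above (contents abbreviated, values illustrative):
    #     [{"role": "system", "content": "Be brief."},
    #      {"role": "user", "content": "Hi"},
    #      {"role": "assistant", "content": "Hello"}]
    # -> system_instruction == "Be brief."
    # -> contents == [Content(role="user", ...), Content(role="model", ...)]
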
    def _validate_and_emit_structured_output(
        self,
        content: str,
        response_model: type[BaseModel],
        messages_for_event: list[LLMMessage],
        from_task: Any | None = None,
        from_agent: Any | None = None,
    ) -> str:
        """Validate content against response model and emit completion event.

        Args:
            content: Response content to validate
            response_model: Pydantic model for validation
            messages_for_event: Messages to include in event
            from_task: Task that initiated the call
            from_agent: Agent that initiated the call

        Returns:
            Validated and serialized JSON string

        Raises:
            ValueError: If validation fails
        """
        try:
            structured_data = response_model.model_validate_json(content)
            structured_json = structured_data.model_dump_json()

            self._emit_call_completed_event(
                response=structured_json,
                call_type=LLMCallType.LLM_CALL,
                from_task=from_task,
                from_agent=from_agent,
                messages=messages_for_event,
            )

            return structured_json
        except Exception as e:
            error_msg = f"Failed to validate structured output with model {response_model.__name__}: {e}"
            logging.error(error_msg)
            raise ValueError(error_msg) from e

    def _finalize_completion_response(
        self,
        content: str,
        contents: list[types.Content],
        response_model: type[BaseModel] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
    ) -> str:
        """Finalize completion response with validation and event emission.

        Args:
            content: The response content
            contents: Original contents for event conversion
            response_model: Pydantic model for structured output validation
            from_task: Task that initiated the call
            from_agent: Agent that initiated the call

        Returns:
            Final response content after processing
        """
        messages_for_event = self._convert_contents_to_dict(contents)

        # Handle structured output validation
        if response_model:
            return self._validate_and_emit_structured_output(
                content=content,
                response_model=response_model,
                messages_for_event=messages_for_event,
                from_task=from_task,
                from_agent=from_agent,
            )

        self._emit_call_completed_event(
            response=content,
            call_type=LLMCallType.LLM_CALL,
            from_task=from_task,
            from_agent=from_agent,
            messages=messages_for_event,
        )

        return self._invoke_after_llm_call_hooks(
            messages_for_event, content, from_agent
        )

    def _process_response_with_tools(
        self,
        response: GenerateContentResponse,
        contents: list[types.Content],
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """Process response, execute function calls, and finalize completion.

        Args:
            response: The completion response
            contents: Original contents for event conversion
            available_functions: Available functions for function calling
            from_task: Task that initiated the call
            from_agent: Agent that initiated the call
            response_model: Pydantic model for structured output validation

        Returns:
            Final response content or function call result
        """
        if response.candidates and (self.tools or available_functions):
            candidate = response.candidates[0]
            if candidate.content and candidate.content.parts:
                # Collect function call parts
                function_call_parts = [
                    part for part in candidate.content.parts if part.function_call
                ]

                # If there are function calls but no available_functions,
                # return them for the executor to handle (like OpenAI/Anthropic)
                if function_call_parts and not available_functions:
                    self._emit_call_completed_event(
                        response=function_call_parts,
                        call_type=LLMCallType.TOOL_CALL,
                        from_task=from_task,
                        from_agent=from_agent,
                        messages=self._convert_contents_to_dict(contents),
                    )
                    return function_call_parts

                # Otherwise execute the tools internally
                for part in candidate.content.parts:
                    if part.function_call:
                        function_name = part.function_call.name
                        if function_name is None:
                            continue
                        function_args = (
                            dict(part.function_call.args)
                            if part.function_call.args
                            else {}
                        )

                        result = self._handle_tool_execution(
                            function_name=function_name,
                            function_args=function_args,
                            available_functions=available_functions or {},
                            from_task=from_task,
                            from_agent=from_agent,
                        )

                        if result is not None:
                            return result

        content = self._extract_text_from_response(response)
        content = self._apply_stop_words(content)

        return self._finalize_completion_response(
            content=content,
            contents=contents,
            response_model=response_model,
            from_task=from_task,
            from_agent=from_agent,
        )

    def _process_stream_chunk(
        self,
        chunk: GenerateContentResponse,
        full_response: str,
        function_calls: dict[int, dict[str, Any]],
        usage_data: dict[str, int],
        from_task: Any | None = None,
        from_agent: Any | None = None,
    ) -> tuple[str, dict[int, dict[str, Any]], dict[str, int]]:
        """Process a single streaming chunk.

        Args:
            chunk: The streaming chunk response
            full_response: Accumulated response text
            function_calls: Accumulated function calls keyed by sequential index
            usage_data: Accumulated usage data
            from_task: Task that initiated the call
            from_agent: Agent that initiated the call

        Returns:
            Tuple of (updated full_response, updated function_calls, updated usage_data)
        """
        if chunk.usage_metadata:
            usage_data = self._extract_token_usage(chunk)

        if chunk.text:
            full_response += chunk.text
            self._emit_stream_chunk_event(
                chunk=chunk.text,
                from_task=from_task,
                from_agent=from_agent,
            )

        if chunk.candidates:
            candidate = chunk.candidates[0]
            if candidate.content and candidate.content.parts:
                for part in candidate.content.parts:
                    if part.function_call:
                        call_index = len(function_calls)
                        call_id = f"call_{call_index}"
                        args_dict = (
                            dict(part.function_call.args)
                            if part.function_call.args
                            else {}
                        )
                        args_json = json.dumps(args_dict)

                        function_calls[call_index] = {
                            "id": call_id,
                            "name": part.function_call.name,
                            "args": args_dict,
                        }

                        self._emit_stream_chunk_event(
                            chunk=args_json,
                            from_task=from_task,
                            from_agent=from_agent,
                            tool_call={
                                "id": call_id,
                                "function": {
                                    "name": part.function_call.name or "",
                                    "arguments": args_json,
                                },
                                "type": "function",
                                "index": call_index,
                            },
                            call_type=LLMCallType.TOOL_CALL,
                        )

        return full_response, function_calls, usage_data

    def _finalize_streaming_response(
        self,
        full_response: str,
        function_calls: dict[int, dict[str, Any]],
        usage_data: dict[str, int],
        contents: list[types.Content],
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | list[dict[str, Any]]:
        """Finalize streaming response with usage tracking, function execution, and events.

        Args:
            full_response: The complete streamed response content
            function_calls: Dictionary of function calls accumulated during streaming
            usage_data: Token usage data from the stream
            contents: Original contents for event conversion
            available_functions: Available functions for function calling
            from_task: Task that initiated the call
            from_agent: Agent that initiated the call
            response_model: Pydantic model for structured output validation

        Returns:
            Final response content after processing
        """
        self._track_token_usage_internal(usage_data)

        # If there are function calls but no available_functions,
        # return them for the executor to handle
        if function_calls and not available_functions:
            formatted_function_calls = [
                {
                    "id": call_data["id"],
                    "function": {
                        "name": call_data["name"],
                        "arguments": json.dumps(call_data["args"]),
                    },
                    "type": "function",
                }
                for call_data in function_calls.values()
            ]
            self._emit_call_completed_event(
                response=formatted_function_calls,
                call_type=LLMCallType.TOOL_CALL,
                from_task=from_task,
                from_agent=from_agent,
                messages=self._convert_contents_to_dict(contents),
            )
            return formatted_function_calls

        # Handle completed function calls
        if function_calls and available_functions:
            for call_data in function_calls.values():
                function_name = call_data["name"]
                function_args = call_data["args"]

                # Skip if function_name is None
                if not isinstance(function_name, str):
                    continue

                # Ensure function_args is a dict
                if not isinstance(function_args, dict):
                    function_args = {}

                # Execute tool
                result = self._handle_tool_execution(
                    function_name=function_name,
                    function_args=function_args,
                    available_functions=available_functions,
                    from_task=from_task,
                    from_agent=from_agent,
                )

                if result is not None:
                    return result

        return self._finalize_completion_response(
            content=full_response,
            contents=contents,
            response_model=response_model,
            from_task=from_task,
            from_agent=from_agent,
        )

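    # When no available_functions are supplied, streamed tool calls are returned
    # in an OpenAI-style shape for the executor, e.g. (illustrative values):
    #     [{"id": "call_0",
    #       "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
    #       "type": "function"}]
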
    def _handle_completion(
        self,
        contents: list[types.Content],
        config: types.GenerateContentConfig,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """Handle non-streaming content generation."""
        try:
            # The API accepts list[Content] but mypy is overly strict about variance
            contents_for_api: Any = contents
            response = self.client.models.generate_content(
                model=self.model,
                contents=contents_for_api,
                config=config,
            )

            usage = self._extract_token_usage(response)
        except Exception as e:
            if is_context_length_exceeded(e):
                logging.error(f"Context window exceeded: {e}")
                raise LLMContextLengthExceededError(str(e)) from e
            raise e from e

        self._track_token_usage_internal(usage)

        return self._process_response_with_tools(
            response=response,
            contents=contents,
            available_functions=available_functions,
            from_task=from_task,
            from_agent=from_agent,
            response_model=response_model,
        )

    def _handle_streaming_completion(
        self,
        contents: list[types.Content],
        config: types.GenerateContentConfig,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | list[dict[str, Any]]:
        """Handle streaming content generation."""
        full_response = ""
        function_calls: dict[int, dict[str, Any]] = {}
        usage_data = {"total_tokens": 0}

        # The API accepts list[Content] but mypy is overly strict about variance
        contents_for_api: Any = contents
        for chunk in self.client.models.generate_content_stream(
            model=self.model,
            contents=contents_for_api,
            config=config,
        ):
            full_response, function_calls, usage_data = self._process_stream_chunk(
                chunk=chunk,
                full_response=full_response,
                function_calls=function_calls,
                usage_data=usage_data,
                from_task=from_task,
                from_agent=from_agent,
            )

        return self._finalize_streaming_response(
            full_response=full_response,
            function_calls=function_calls,
            usage_data=usage_data,
            contents=contents,
            available_functions=available_functions,
            from_task=from_task,
            from_agent=from_agent,
            response_model=response_model,
        )

    async def _ahandle_completion(
        self,
        contents: list[types.Content],
        config: types.GenerateContentConfig,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """Handle async non-streaming content generation."""
        try:
            # The API accepts list[Content] but mypy is overly strict about variance
            contents_for_api: Any = contents
            response = await self.client.aio.models.generate_content(
                model=self.model,
                contents=contents_for_api,
                config=config,
            )

            usage = self._extract_token_usage(response)
        except Exception as e:
            if is_context_length_exceeded(e):
                logging.error(f"Context window exceeded: {e}")
                raise LLMContextLengthExceededError(str(e)) from e
            raise e from e

        self._track_token_usage_internal(usage)

        return self._process_response_with_tools(
            response=response,
            contents=contents,
            available_functions=available_functions,
            from_task=from_task,
            from_agent=from_agent,
            response_model=response_model,
        )

    async def _ahandle_streaming_completion(
        self,
        contents: list[types.Content],
        config: types.GenerateContentConfig,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | list[dict[str, Any]]:
        """Handle async streaming content generation."""
        full_response = ""
        function_calls: dict[int, dict[str, Any]] = {}
        usage_data = {"total_tokens": 0}

        # The API accepts list[Content] but mypy is overly strict about variance
        contents_for_api: Any = contents
        stream = await self.client.aio.models.generate_content_stream(
            model=self.model,
            contents=contents_for_api,
            config=config,
        )
        async for chunk in stream:
            full_response, function_calls, usage_data = self._process_stream_chunk(
                chunk=chunk,
                full_response=full_response,
                function_calls=function_calls,
                usage_data=usage_data,
                from_task=from_task,
                from_agent=from_agent,
            )

        return self._finalize_streaming_response(
            full_response=full_response,
            function_calls=function_calls,
            usage_data=usage_data,
            contents=contents,
            available_functions=available_functions,
            from_task=from_task,
            from_agent=from_agent,
            response_model=response_model,
        )

    def supports_function_calling(self) -> bool:
        """Check if the model supports function calling."""
        return self.supports_tools

    def supports_stop_words(self) -> bool:
        """Check if the model supports stop words."""
        return True

    def get_context_window_size(self) -> int:
        """Get the context window size for the model."""
        from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES

        min_context = 1024
        max_context = 2097152

        for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
            if value < min_context or value > max_context:
                raise ValueError(
                    f"Context window for {key} must be between {min_context} and {max_context}"
                )

        context_windows = {
            "gemini-3-pro-preview": 1048576,  # 1M tokens
            "gemini-2.0-flash": 1048576,  # 1M tokens
            "gemini-2.0-flash-thinking": 32768,
            "gemini-2.0-flash-lite": 1048576,
            "gemini-2.5-flash": 1048576,
            "gemini-2.5-pro": 1048576,
            "gemini-1.5-pro": 2097152,  # 2M tokens
            "gemini-1.5-flash": 1048576,
            "gemini-1.5-flash-8b": 1048576,
            "gemini-1.0-pro": 32768,
            "gemma-3-1b": 32000,
            "gemma-3-4b": 128000,
            "gemma-3-12b": 128000,
            "gemma-3-27b": 128000,
        }

        # Find the best match for the model name
        for model_prefix, size in context_windows.items():
            if self.model.startswith(model_prefix):
                return int(size * CONTEXT_WINDOW_USAGE_RATIO)

        # Default context window size for Gemini models
        return int(1048576 * CONTEXT_WINDOW_USAGE_RATIO)  # 1M tokens

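    # E.g., "gemini-1.5-pro-002" matches the "gemini-1.5-pro" prefix, so the
    # usable window is int(2097152 * CONTEXT_WINDOW_USAGE_RATIO); with a ratio
    # of 0.75 (assumed here purely for illustration) that is 1572864 tokens.
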
    @staticmethod
    def _extract_token_usage(response: GenerateContentResponse) -> dict[str, Any]:
        """Extract token usage from Gemini response."""
        if response.usage_metadata:
            usage = response.usage_metadata
            return {
                "prompt_token_count": getattr(usage, "prompt_token_count", 0),
                "candidates_token_count": getattr(usage, "candidates_token_count", 0),
                "total_token_count": getattr(usage, "total_token_count", 0),
                "total_tokens": getattr(usage, "total_token_count", 0),
            }
        return {"total_tokens": 0}

    @staticmethod
    def _extract_text_from_response(response: GenerateContentResponse) -> str:
        """Extract text content from Gemini response without triggering warnings.

        This method directly accesses the response parts to extract text content,
        avoiding the warning that occurs when using response.text on responses
        containing non-text parts (e.g., 'thought_signature' from thinking models).

        Args:
            response: The Gemini API response

        Returns:
            Concatenated text content from all text parts
        """
        if not response.candidates:
            return ""

        candidate = response.candidates[0]
        if not candidate.content or not candidate.content.parts:
            return ""

        text_parts = [
            part.text
            for part in candidate.content.parts
            if hasattr(part, "text") and part.text
        ]

        return "".join(text_parts)

    @staticmethod
    def _convert_contents_to_dict(
        contents: list[types.Content],
    ) -> list[LLMMessage]:
        """Convert contents to dict format."""
        result: list[LLMMessage] = []
        for content_obj in contents:
            role = content_obj.role
            if role == "model":
                role = "assistant"
            elif role is None:
                role = "user"

            parts = content_obj.parts or []
            content = " ".join(
                part.text for part in parts if hasattr(part, "text") and part.text
            )

            result.append(
                LLMMessage(
                    role=cast(Literal["user", "assistant", "system"], role),
                    content=content,
                )
            )
        return result

    def supports_multimodal(self) -> bool:
        """Check if the model supports multimodal inputs.

        Gemini models support images, audio, video, and PDFs.

        Returns:
            True if the model supports multimodal inputs.
        """
        return True

    def format_text_content(self, text: str) -> dict[str, Any]:
        """Format text as a Gemini content block.

        Gemini uses {"text": "..."} format instead of {"type": "text", "text": "..."}.

        Args:
            text: The text content to format.

        Returns:
            A content block in Gemini's expected format.
        """
        return {"text": text}

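    # E.g., format_text_content("hello") returns {"text": "hello"}, whereas an
    # OpenAI-style block would be {"type": "text", "text": "hello"}.
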
    def get_file_uploader(self) -> Any:
        """Get a Gemini file uploader using this LLM's client.

        Returns:
            GeminiFileUploader instance with pre-configured client.
        """
        try:
            from crewai_files.uploaders.gemini import GeminiFileUploader

            return GeminiFileUploader(client=self.client)
        except ImportError:
            return None