feat: implement Google Batch Mode support for LLM calls

- Add google-generativeai dependency to pyproject.toml
- Extend LLM class with batch mode parameters (batch_mode, batch_size, batch_timeout)
- Implement batch request management methods for Gemini models
- Add batch-specific event types (BatchJobStartedEvent, BatchJobCompletedEvent, BatchJobFailedEvent)
- Create comprehensive test suite for batch mode functionality
- Add example demonstrating batch mode usage with cost savings
- Support inline batch requests for up to 50% cost reduction on Gemini models

Resolves issue #3116

Co-Authored-By: João <joao@crewai.com>
Devin AI committed on 2025-07-07 22:01:56 +00:00
commit ae59abb052 (parent 34a03f882c)
4 changed files with 519 additions and 0 deletions
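For quick reference, the new constructor surface described in the bullets above looks roughly like the following. This is an illustrative sketch only; the defaults shown (batch_size falling back to 10, batch_timeout to 300 seconds) are taken from the llm.py changes further down, and batch mode only applies to Gemini models.

from crewai.llm import LLM

# Illustrative configuration; non-Gemini models keep the normal LiteLLM path.
llm = LLM(
    model="gemini/gemini-1.5-pro",
    batch_mode=True,     # queue calls and submit them as a Google batch job
    batch_size=10,       # queued requests needed before a batch is submitted
    batch_timeout=300,   # seconds to wait for the batch job to finish
)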


@@ -0,0 +1,63 @@
"""
Example demonstrating Google Batch Mode support in CrewAI.
This example shows how to use batch mode with Gemini models to reduce costs
by up to 50% for non-urgent LLM calls.
"""
import os
from crewai import Agent, Task, Crew
from crewai.llm import LLM
os.environ["GOOGLE_API_KEY"] = "your-google-api-key-here"
def main():
batch_llm = LLM(
model="gemini/gemini-1.5-pro",
batch_mode=True,
batch_size=5, # Process 5 requests at once
batch_timeout=300, # Wait up to 5 minutes for batch completion
temperature=0.7
)
research_agent = Agent(
role="Research Analyst",
goal="Analyze market trends and provide insights",
backstory="You are an expert market analyst with years of experience.",
llm=batch_llm,
verbose=True
)
tasks = []
topics = [
"artificial intelligence market trends",
"renewable energy investment opportunities",
"cryptocurrency regulatory landscape",
"e-commerce growth projections",
"healthcare technology innovations"
]
for topic in topics:
task = Task(
description=f"Research and analyze {topic}. Provide a brief summary of key trends and insights.",
agent=research_agent,
expected_output="A concise analysis with key findings and trends"
)
tasks.append(task)
crew = Crew(
agents=[research_agent],
tasks=tasks,
verbose=True
)
print("Starting batch processing...")
print("Note: Batch requests will be queued until batch_size is reached")
result = crew.kickoff()
print("Batch processing completed!")
print("Results:", result)
if __name__ == "__main__":
main()
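One behavior worth noting when reading the example above: as implemented in the llm.py changes below, a batch-mode call only returns a real completion once batch_size requests have been queued; earlier calls return a placeholder string. A minimal sketch of that behavior with direct LLM.call usage (assuming GOOGLE_API_KEY is configured as in the example above):

from crewai.llm import LLM

llm = LLM(model="gemini/gemini-1.5-pro", batch_mode=True, batch_size=3)

prompts = [
    "Summarize AI market trends in one sentence.",
    "Summarize renewable energy trends in one sentence.",
    "Summarize e-commerce growth trends in one sentence.",
]

for prompt in prompts:
    # The first batch_size - 1 calls return the queued-request placeholder
    # string; the final call submits the batch, polls until it finishes,
    # and returns the first result from the completed job.
    print(llm.call(prompt))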


@@ -39,6 +39,7 @@ dependencies = [
"tomli>=2.0.2", "tomli>=2.0.2",
"blinker>=1.9.0", "blinker>=1.9.0",
"json5>=0.10.0", "json5>=0.10.0",
"google-generativeai>=0.8.0",
] ]
[project.urls] [project.urls]


@@ -3,6 +3,7 @@ import logging
import os
import sys
import threading
import time
import warnings
from collections import defaultdict
from contextlib import contextmanager
@@ -23,6 +24,16 @@ from dotenv import load_dotenv
from litellm.types.utils import ChatCompletionDeltaToolCall
from pydantic import BaseModel, Field
try:
import google.generativeai as genai
from google.generativeai.types import BatchCreateJobRequest, BatchJob
GOOGLE_GENAI_AVAILABLE = True
except ImportError:
GOOGLE_GENAI_AVAILABLE = False
# Create dummy types for type annotations when genai is not available
BatchJob = object
BatchCreateJobRequest = object
from crewai.utilities.events.llm_events import (
LLMCallCompletedEvent,
LLMCallFailedEvent,
@@ -57,6 +68,32 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededException,
)
class BatchJobStartedEvent:
"""Event emitted when a batch job is started."""
def __init__(self, messages, tools=None, from_task=None, from_agent=None):
self.messages = messages
self.tools = tools
self.from_task = from_task
self.from_agent = from_agent
class BatchJobCompletedEvent:
"""Event emitted when a batch job is completed."""
def __init__(self, response, job_name, from_task=None, from_agent=None):
self.response = response
self.job_name = job_name
self.from_task = from_task
self.from_agent = from_agent
class BatchJobFailedEvent:
"""Event emitted when a batch job fails."""
def __init__(self, error, from_task=None, from_agent=None):
self.error = error
self.from_task = from_task
self.from_agent = from_agent
load_dotenv()
@@ -311,6 +348,9 @@ class LLM(BaseLLM):
callbacks: List[Any] = [],
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
stream: bool = False,
batch_mode: bool = False,
batch_size: Optional[int] = None,
batch_timeout: Optional[int] = 300,
**kwargs,
):
self.model = model
@@ -337,6 +377,11 @@ class LLM(BaseLLM):
self.additional_params = kwargs
self.is_anthropic = self._is_anthropic_model(model)
self.stream = stream
self.batch_mode = batch_mode
self.batch_size = batch_size or 10
self.batch_timeout = batch_timeout
self._batch_requests = []
self._current_batch_job = None
litellm.drop_params = True
@@ -363,6 +408,10 @@ class LLM(BaseLLM):
ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/")
return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES)
def _is_gemini_model(self) -> bool:
"""Check if the model is a Gemini model that supports batch mode."""
return "gemini" in self.model.lower() and GOOGLE_GENAI_AVAILABLE
def _prepare_completion_params(
self,
messages: Union[str, List[Dict[str, str]]],
@@ -414,6 +463,100 @@ class LLM(BaseLLM):
# Remove None values from params
return {k: v for k, v in params.items() if v is not None}
def _prepare_batch_request(
self,
messages: List[Dict[str, str]],
tools: Optional[List[dict]] = None
) -> Dict[str, Any]:
"""Prepare a single request for batch processing."""
if not self._is_gemini_model():
raise ValueError("Batch mode is only supported for Gemini models")
formatted_messages = self._format_messages_for_provider(messages)
request = {
"contents": [],
"generationConfig": {
"temperature": self.temperature,
"topP": self.top_p,
"maxOutputTokens": self.max_tokens or self.max_completion_tokens,
"stopSequences": self.stop if isinstance(self.stop, list) else [self.stop] if self.stop else None,
}
}
for message in formatted_messages:
role = "user" if message["role"] == "user" else "model"
request["contents"].append({
"role": role,
"parts": [{"text": message["content"]}]
})
if tools:
request["tools"] = tools
return {"model": self.model.replace("gemini/", ""), "contents": request["contents"], "generationConfig": request["generationConfig"]}
def _submit_batch_job(self, requests: List[Dict[str, Any]]) -> str:
"""Submit a batch job to Google GenAI API."""
if not GOOGLE_GENAI_AVAILABLE:
raise ImportError("google-generativeai is required for batch mode")
if not self.api_key:
raise ValueError("API key is required for batch mode")
genai.configure(api_key=self.api_key)
batch_request = BatchCreateJobRequest(
requests=requests,
display_name=f"crewai-batch-{int(time.time())}"
)
batch_job = genai.create_batch_job(batch_request)
return batch_job.name
def _poll_batch_job(self, job_name: str) -> BatchJob:
"""Poll batch job status until completion."""
if not GOOGLE_GENAI_AVAILABLE:
raise ImportError("google-generativeai is required for batch mode")
genai.configure(api_key=self.api_key)
start_time = time.time()
while time.time() - start_time < self.batch_timeout:
batch_job = genai.get_batch_job(job_name)
if batch_job.state in ["JOB_STATE_SUCCEEDED", "JOB_STATE_FAILED", "JOB_STATE_CANCELLED"]:
return batch_job
time.sleep(5)
raise TimeoutError(f"Batch job {job_name} did not complete within {self.batch_timeout} seconds")
def _retrieve_batch_results(self, job_name: str) -> List[str]:
"""Retrieve results from a completed batch job."""
if not GOOGLE_GENAI_AVAILABLE:
raise ImportError("google-generativeai is required for batch mode")
genai.configure(api_key=self.api_key)
batch_job = genai.get_batch_job(job_name)
if batch_job.state != "JOB_STATE_SUCCEEDED":
raise RuntimeError(f"Batch job failed with state: {batch_job.state}")
results = []
for response in genai.list_batch_job_responses(job_name):
if response.response and response.response.candidates:
content = response.response.candidates[0].content
if content and content.parts:
results.append(content.parts[0].text)
else:
results.append("")
else:
results.append("")
return results
def _handle_streaming_response(
self,
params: Dict[str, Any],
@@ -952,6 +1095,11 @@ class LLM(BaseLLM):
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
if self.batch_mode and self._is_gemini_model():
return self._handle_batch_request(
messages, tools, callbacks, available_functions, from_task, from_agent
)
# --- 4) Handle O1 model special case (system messages not supported)
if "o1" in self.model.lower():
for message in messages:
@@ -991,6 +1139,77 @@ class LLM(BaseLLM):
logging.error(f"LiteLLM call failed: {str(e)}") logging.error(f"LiteLLM call failed: {str(e)}")
raise raise
def _handle_batch_request(
self,
messages: List[Dict[str, str]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
) -> str:
"""Handle batch mode request for Gemini models."""
if not self._is_gemini_model():
raise ValueError("Batch mode is only supported for Gemini models")
assert hasattr(crewai_event_bus, "emit")
crewai_event_bus.emit(
self,
event=BatchJobStartedEvent(
messages=messages,
tools=tools,
from_task=from_task,
from_agent=from_agent,
),
)
try:
batch_request = self._prepare_batch_request(messages, tools)
self._batch_requests.append(batch_request)
if len(self._batch_requests) >= self.batch_size:
job_name = self._submit_batch_job(self._batch_requests)
self._current_batch_job = job_name
self._poll_batch_job(job_name)
results = self._retrieve_batch_results(job_name)
self._batch_requests.clear()
self._current_batch_job = None
if results:
response = results[0]
assert hasattr(crewai_event_bus, "emit")
crewai_event_bus.emit(
self,
event=BatchJobCompletedEvent(
response=response,
job_name=job_name,
from_task=from_task,
from_agent=from_agent,
),
)
return response
else:
raise RuntimeError("No results returned from batch job")
else:
return "Batch request queued. Call with more requests to trigger batch processing."
except Exception as e:
assert hasattr(crewai_event_bus, "emit")
crewai_event_bus.emit(
self,
event=BatchJobFailedEvent(
error=str(e),
from_task=from_task,
from_agent=from_agent,
),
)
logging.error(f"Batch request failed: {str(e)}")
raise
def _handle_emit_call_events(self, response: Any, call_type: LLMCallType, from_task: Optional[Any] = None, from_agent: Optional[Any] = None):
"""Handle the events for the LLM call.

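Taken together, the private helpers added above implement a prepare → submit → poll → retrieve lifecycle. The condensed sketch below is illustrative only: it calls the internal methods directly and assumes google-generativeai actually exposes the batch-job endpoints these helpers rely on.

from crewai.llm import LLM

llm = LLM(model="gemini/gemini-1.5-pro", batch_mode=True, api_key="your-google-api-key")

# 1) Convert chat messages into Gemini-style batch request payloads
requests = [
    llm._prepare_batch_request([{"role": "user", "content": "Hello"}]),
    llm._prepare_batch_request([{"role": "user", "content": "Goodbye"}]),
]

# 2) Submit the batch job and keep its name for polling
job_name = llm._submit_batch_job(requests)

# 3) Block until the job reaches a terminal state or batch_timeout elapses
llm._poll_batch_job(job_name)

# 4) Collect the text of each completed response
results = llm._retrieve_batch_results(job_name)
print(results)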
tests/test_batch_mode.py (new file, 236 lines)

@@ -0,0 +1,236 @@
import pytest
import time
from unittest.mock import Mock, patch, MagicMock
from crewai.llm import LLM, BatchJobStartedEvent, BatchJobCompletedEvent, BatchJobFailedEvent
class TestBatchMode:
"""Test suite for Google Batch Mode functionality."""
def test_batch_mode_initialization(self):
"""Test that batch mode parameters are properly initialized."""
llm = LLM(
model="gemini/gemini-1.5-pro",
batch_mode=True,
batch_size=5,
batch_timeout=600
)
assert llm.batch_mode is True
assert llm.batch_size == 5
assert llm.batch_timeout == 600
assert llm._batch_requests == []
assert llm._current_batch_job is None
def test_batch_mode_defaults(self):
"""Test default values for batch mode parameters."""
llm = LLM(model="gemini/gemini-1.5-pro", batch_mode=True)
assert llm.batch_mode is True
assert llm.batch_size == 10
assert llm.batch_timeout == 300
def test_is_gemini_model_detection(self):
"""Test Gemini model detection for batch mode support."""
with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True):
llm_gemini = LLM(model="gemini/gemini-1.5-pro")
assert llm_gemini._is_gemini_model() is True
llm_openai = LLM(model="gpt-4")
assert llm_openai._is_gemini_model() is False
def test_is_gemini_model_without_genai_available(self):
"""Test Gemini model detection when google-generativeai is not available."""
with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', False):
llm = LLM(model="gemini/gemini-1.5-pro")
assert llm._is_gemini_model() is False
def test_prepare_batch_request(self):
"""Test batch request preparation."""
with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True):
llm = LLM(
model="gemini/gemini-1.5-pro",
temperature=0.7,
top_p=0.9,
max_tokens=1000
)
messages = [{"role": "user", "content": "Hello, world!"}]
batch_request = llm._prepare_batch_request(messages)
assert "model" in batch_request
assert batch_request["model"] == "gemini-1.5-pro"
assert "contents" in batch_request
assert "generationConfig" in batch_request
assert batch_request["generationConfig"]["temperature"] == 0.7
assert batch_request["generationConfig"]["topP"] == 0.9
assert batch_request["generationConfig"]["maxOutputTokens"] == 1000
def test_prepare_batch_request_non_gemini_model(self):
"""Test that batch request preparation fails for non-Gemini models."""
llm = LLM(model="gpt-4")
messages = [{"role": "user", "content": "Hello, world!"}]
with pytest.raises(ValueError, match="Batch mode is only supported for Gemini models"):
llm._prepare_batch_request(messages)
@patch('crewai.llm.genai')
def test_submit_batch_job(self, mock_genai):
"""Test batch job submission."""
with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True):
mock_batch_job = Mock()
mock_batch_job.name = "test-job-123"
mock_genai.create_batch_job.return_value = mock_batch_job
llm = LLM(
model="gemini/gemini-1.5-pro",
api_key="test-key"
)
requests = [{"model": "gemini-1.5-pro", "contents": []}]
job_name = llm._submit_batch_job(requests)
assert job_name == "test-job-123"
mock_genai.configure.assert_called_with(api_key="test-key")
mock_genai.create_batch_job.assert_called_once()
def test_submit_batch_job_without_genai(self):
"""Test batch job submission without google-generativeai available."""
with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', False):
llm = LLM(model="gemini/gemini-1.5-pro")
with pytest.raises(ImportError, match="google-generativeai is required for batch mode"):
llm._submit_batch_job([])
def test_submit_batch_job_without_api_key(self):
"""Test batch job submission without API key."""
with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True):
llm = LLM(model="gemini/gemini-1.5-pro")
with pytest.raises(ValueError, match="API key is required for batch mode"):
llm._submit_batch_job([])
@patch('crewai.llm.genai')
@patch('crewai.llm.time')
def test_poll_batch_job_success(self, mock_time, mock_genai):
"""Test successful batch job polling."""
with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True):
mock_batch_job = Mock()
mock_batch_job.state = "JOB_STATE_SUCCEEDED"
mock_genai.get_batch_job.return_value = mock_batch_job
mock_time.time.side_effect = [0, 1, 2]
mock_time.sleep = Mock()
llm = LLM(
model="gemini/gemini-1.5-pro",
api_key="test-key"
)
result = llm._poll_batch_job("test-job-123")
assert result == mock_batch_job
mock_genai.get_batch_job.assert_called_with("test-job-123")
@patch('crewai.llm.genai')
@patch('crewai.llm.time')
def test_poll_batch_job_timeout(self, mock_time, mock_genai):
"""Test batch job polling timeout."""
with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True):
mock_batch_job = Mock()
mock_batch_job.state = "JOB_STATE_PENDING"
mock_genai.get_batch_job.return_value = mock_batch_job
mock_time.time.side_effect = [0, 400]
mock_time.sleep = Mock()
llm = LLM(
model="gemini/gemini-1.5-pro",
api_key="test-key",
batch_timeout=300
)
with pytest.raises(TimeoutError, match="did not complete within 300 seconds"):
llm._poll_batch_job("test-job-123")
@patch('crewai.llm.genai')
def test_retrieve_batch_results(self, mock_genai):
"""Test batch result retrieval."""
with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True):
mock_batch_job = Mock()
mock_batch_job.state = "JOB_STATE_SUCCEEDED"
mock_genai.get_batch_job.return_value = mock_batch_job
mock_response = Mock()
mock_response.response.candidates = [Mock()]
mock_response.response.candidates[0].content.parts = [Mock()]
mock_response.response.candidates[0].content.parts[0].text = "Test response"
mock_genai.list_batch_job_responses.return_value = [mock_response]
llm = LLM(
model="gemini/gemini-1.5-pro",
api_key="test-key"
)
results = llm._retrieve_batch_results("test-job-123")
assert results == ["Test response"]
mock_genai.get_batch_job.assert_called_with("test-job-123")
mock_genai.list_batch_job_responses.assert_called_with("test-job-123")
@patch('crewai.llm.genai')
def test_retrieve_batch_results_failed_job(self, mock_genai):
"""Test batch result retrieval for failed job."""
with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True):
mock_batch_job = Mock()
mock_batch_job.state = "JOB_STATE_FAILED"
mock_genai.get_batch_job.return_value = mock_batch_job
llm = LLM(
model="gemini/gemini-1.5-pro",
api_key="test-key"
)
with pytest.raises(RuntimeError, match="Batch job failed with state: JOB_STATE_FAILED"):
llm._retrieve_batch_results("test-job-123")
@patch('crewai.llm.crewai_event_bus')
def test_handle_batch_request_non_gemini(self, mock_event_bus):
"""Test batch request handling for non-Gemini models."""
llm = LLM(model="gpt-4", batch_mode=True)
messages = [{"role": "user", "content": "Hello"}]
with pytest.raises(ValueError, match="Batch mode is only supported for Gemini models"):
llm._handle_batch_request(messages)
@patch('crewai.llm.crewai_event_bus')
def test_batch_mode_call_routing(self, mock_event_bus):
"""Test that batch mode calls are routed correctly."""
with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True):
llm = LLM(
model="gemini/gemini-1.5-pro",
batch_mode=True,
api_key="test-key"
)
with patch.object(llm, '_handle_batch_request') as mock_batch_handler:
mock_batch_handler.return_value = "Batch response"
result = llm.call("Hello, world!")
assert result == "Batch response"
mock_batch_handler.assert_called_once()
def test_non_batch_mode_unchanged(self):
"""Test that non-batch mode behavior is unchanged."""
with patch('crewai.llm.litellm') as mock_litellm:
mock_response = Mock()
mock_response.choices = [Mock()]
mock_response.choices[0].message.content = "Regular response"
mock_response.choices[0].message.tool_calls = []
mock_litellm.completion.return_value = mock_response
llm = LLM(model="gemini/gemini-1.5-pro", batch_mode=False)
result = llm.call("Hello, world!")
assert result == "Regular response"
mock_litellm.completion.assert_called_once()
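Downstream code that wants visibility into the new batch events could register handlers on the shared event bus. This is a rough sketch, assuming crewai_event_bus exposes the same on() decorator registration used for the existing LLM events and dispatches these new classes the same way; the import path may need adjusting.

from crewai.utilities.events import crewai_event_bus
from crewai.llm import (
    BatchJobCompletedEvent,
    BatchJobFailedEvent,
    BatchJobStartedEvent,
)

@crewai_event_bus.on(BatchJobStartedEvent)
def on_batch_started(source, event):
    # source is the LLM instance that emitted the event
    print(f"Batch job queued; {len(event.messages)} message(s) in this request")

@crewai_event_bus.on(BatchJobCompletedEvent)
def on_batch_completed(source, event):
    print(f"Batch job {event.job_name} completed")

@crewai_event_bus.on(BatchJobFailedEvent)
def on_batch_failed(source, event):
    print(f"Batch job failed: {event.error}")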