From ae59abb052e5649f6fbda9e038b84123ca4e88bb Mon Sep 17 00:00:00 2001
From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Date: Mon, 7 Jul 2025 22:01:56 +0000
Subject: [PATCH] feat: implement Google Batch Mode support for LLM calls
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Add google-generativeai dependency to pyproject.toml
- Extend LLM class with batch mode parameters (batch_mode, batch_size, batch_timeout)
- Implement batch request management methods for Gemini models
- Add batch-specific event types (BatchJobStartedEvent, BatchJobCompletedEvent, BatchJobFailedEvent)
- Create comprehensive test suite for batch mode functionality
- Add example demonstrating batch mode usage with cost savings
- Support inline batch requests for up to 50% cost reduction on Gemini models
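
Example usage (a condensed sketch of the workflow in examples/batch_mode_example.py
added by this patch; parameter values are illustrative):

    from crewai.llm import LLM

    # Enable batch mode on a Gemini model; calls are queued and submitted
    # together once batch_size requests have accumulated.
    batch_llm = LLM(
        model="gemini/gemini-1.5-pro",
        batch_mode=True,      # route calls through the Gemini batch API
        batch_size=5,         # submit once 5 requests are queued
        batch_timeout=300,    # wait up to 5 minutes for the batch job
    )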

Resolves issue #3116

Co-Authored-By: João
---
 examples/batch_mode_example.py |  63 +++++++++
 pyproject.toml                 |   1 +
 src/crewai/llm.py              | 219 ++++++++++++++++++++++++++++++
 tests/test_batch_mode.py       | 236 +++++++++++++++++++++++++++++++++
 4 files changed, 519 insertions(+)
 create mode 100644 examples/batch_mode_example.py
 create mode 100644 tests/test_batch_mode.py

diff --git a/examples/batch_mode_example.py b/examples/batch_mode_example.py
new file mode 100644
index 000000000..d451e4510
--- /dev/null
+++ b/examples/batch_mode_example.py
@@ -0,0 +1,63 @@
+"""
+Example demonstrating Google Batch Mode support in CrewAI.
+
+This example shows how to use batch mode with Gemini models to reduce costs
+by up to 50% for non-urgent LLM calls.
+"""
+
+import os
+from crewai import Agent, Task, Crew
+from crewai.llm import LLM
+
+os.environ["GOOGLE_API_KEY"] = "your-google-api-key-here"
+
+def main():
+    batch_llm = LLM(
+        model="gemini/gemini-1.5-pro",
+        batch_mode=True,
+        batch_size=5,  # Process 5 requests at once
+        batch_timeout=300,  # Wait up to 5 minutes for batch completion
+        temperature=0.7
+    )
+
+    research_agent = Agent(
+        role="Research Analyst",
+        goal="Analyze market trends and provide insights",
+        backstory="You are an expert market analyst with years of experience.",
+        llm=batch_llm,
+        verbose=True
+    )
+
+    tasks = []
+    topics = [
+        "artificial intelligence market trends",
+        "renewable energy investment opportunities",
+        "cryptocurrency regulatory landscape",
+        "e-commerce growth projections",
+        "healthcare technology innovations"
+    ]
+
+    for topic in topics:
+        task = Task(
+            description=f"Research and analyze {topic}. Provide a brief summary of key trends and insights.",
+            agent=research_agent,
+            expected_output="A concise analysis with key findings and trends"
+        )
+        tasks.append(task)
+
+    crew = Crew(
+        agents=[research_agent],
+        tasks=tasks,
+        verbose=True
+    )
+
+    print("Starting batch processing...")
+    print("Note: Batch requests will be queued until batch_size is reached")
+
+    result = crew.kickoff()
+
+    print("Batch processing completed!")
+    print("Results:", result)
+
+if __name__ == "__main__":
+    main()
diff --git a/pyproject.toml b/pyproject.toml
index 111d738a4..5b2bba532 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,6 +39,7 @@ dependencies = [
     "tomli>=2.0.2",
     "blinker>=1.9.0",
     "json5>=0.10.0",
+    "google-generativeai>=0.8.0",
 ]
 
 [project.urls]
diff --git a/src/crewai/llm.py b/src/crewai/llm.py
index 88edb5ec5..1bf016d79 100644
--- a/src/crewai/llm.py
+++ b/src/crewai/llm.py
@@ -3,6 +3,7 @@ import logging
 import os
 import sys
 import threading
+import time
 import warnings
 from collections import defaultdict
 from contextlib import contextmanager
@@ -23,6 +24,16 @@ from dotenv import load_dotenv
 from litellm.types.utils import ChatCompletionDeltaToolCall
 from pydantic import BaseModel, Field
 
+try:
+    import google.generativeai as genai
+    from google.generativeai.types import BatchCreateJobRequest, BatchJob
+    GOOGLE_GENAI_AVAILABLE = True
+except ImportError:
+    GOOGLE_GENAI_AVAILABLE = False
+    # Create dummy types for type annotations when genai is not available
+    BatchJob = object
+    BatchCreateJobRequest = object
+
 from crewai.utilities.events.llm_events import (
     LLMCallCompletedEvent,
     LLMCallFailedEvent,
@@ -57,6 +68,32 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
     LLMContextLengthExceededException,
 )
 
+
+class BatchJobStartedEvent:
+    """Event emitted when a batch job is started."""
+    def __init__(self, messages, tools=None, from_task=None, from_agent=None):
+        self.messages = messages
+        self.tools = tools
+        self.from_task = from_task
+        self.from_agent = from_agent
+
+
+class BatchJobCompletedEvent:
+    """Event emitted when a batch job is completed."""
+    def __init__(self, response, job_name, from_task=None, from_agent=None):
+        self.response = response
+        self.job_name = job_name
+        self.from_task = from_task
+        self.from_agent = from_agent
+
+
+class BatchJobFailedEvent:
+    """Event emitted when a batch job fails."""
+    def __init__(self, error, from_task=None, from_agent=None):
+        self.error = error
+        self.from_task = from_task
+        self.from_agent = from_agent
+
 load_dotenv()
 
@@ -311,6 +348,9 @@ class LLM(BaseLLM):
         callbacks: List[Any] = [],
         reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
         stream: bool = False,
+        batch_mode: bool = False,
+        batch_size: Optional[int] = None,
+        batch_timeout: Optional[int] = 300,
         **kwargs,
     ):
         self.model = model
@@ -337,6 +377,11 @@
         self.additional_params = kwargs
         self.is_anthropic = self._is_anthropic_model(model)
         self.stream = stream
+        self.batch_mode = batch_mode
+        self.batch_size = batch_size or 10
+        self.batch_timeout = batch_timeout
+        self._batch_requests = []
+        self._current_batch_job = None
 
         litellm.drop_params = True
 
@@ -363,6 +408,10 @@
         ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/")
         return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES)
 
+    def _is_gemini_model(self) -> bool:
+        """Check if the model is a Gemini model that supports batch mode."""
+        return "gemini" in self.model.lower() and GOOGLE_GENAI_AVAILABLE
+
     def _prepare_completion_params(
         self,
         messages: Union[str, List[Dict[str, str]]],
@@ -414,6 +463,100 @@
         # Remove None values from params
         return {k: v for k, v in params.items() if v is not None}
 
+    def _prepare_batch_request(
+        self,
+        messages: List[Dict[str, str]],
+        tools: Optional[List[dict]] = None
+    ) -> Dict[str, Any]:
+        """Prepare a single request for batch processing."""
+        if not self._is_gemini_model():
+            raise ValueError("Batch mode is only supported for Gemini models")
+
+        formatted_messages = self._format_messages_for_provider(messages)
+
+        request = {
+            "contents": [],
+            "generationConfig": {
+                "temperature": self.temperature,
+                "topP": self.top_p,
+                "maxOutputTokens": self.max_tokens or self.max_completion_tokens,
+                "stopSequences": self.stop if isinstance(self.stop, list) else [self.stop] if self.stop else None,
+            }
+        }
+
+        for message in formatted_messages:
+            role = "user" if message["role"] == "user" else "model"
+            request["contents"].append({
+                "role": role,
+                "parts": [{"text": message["content"]}]
+            })
+
+        if tools:
+            request["tools"] = tools
+
+        return {"model": self.model.replace("gemini/", ""), "contents": request["contents"], "generationConfig": request["generationConfig"]}
+
+    def _submit_batch_job(self, requests: List[Dict[str, Any]]) -> str:
+        """Submit a batch job to Google GenAI API."""
+        if not GOOGLE_GENAI_AVAILABLE:
+            raise ImportError("google-generativeai is required for batch mode")
+
+        if not self.api_key:
+            raise ValueError("API key is required for batch mode")
+
+        genai.configure(api_key=self.api_key)
+
+        batch_request = BatchCreateJobRequest(
+            requests=requests,
+            display_name=f"crewai-batch-{int(time.time())}"
+        )
+
+        batch_job = genai.create_batch_job(batch_request)
+        return batch_job.name
+
+    def _poll_batch_job(self, job_name: str) -> BatchJob:
+        """Poll batch job status until completion."""
+        if not GOOGLE_GENAI_AVAILABLE:
+            raise ImportError("google-generativeai is required for batch mode")
+
+        genai.configure(api_key=self.api_key)
+
+        start_time = time.time()
+        while time.time() - start_time < self.batch_timeout:
+            batch_job = genai.get_batch_job(job_name)
+
+            if batch_job.state in ["JOB_STATE_SUCCEEDED", "JOB_STATE_FAILED", "JOB_STATE_CANCELLED"]:
+                return batch_job
+
+            time.sleep(5)
+
+        raise TimeoutError(f"Batch job {job_name} did not complete within {self.batch_timeout} seconds")
+
+    def _retrieve_batch_results(self, job_name: str) -> List[str]:
+        """Retrieve results from a completed batch job."""
+        if not GOOGLE_GENAI_AVAILABLE:
+            raise ImportError("google-generativeai is required for batch mode")
+
+        genai.configure(api_key=self.api_key)
+
+        batch_job = genai.get_batch_job(job_name)
+
+        if batch_job.state != "JOB_STATE_SUCCEEDED":
+            raise RuntimeError(f"Batch job failed with state: {batch_job.state}")
+
+        results = []
+        for response in genai.list_batch_job_responses(job_name):
+            if response.response and response.response.candidates:
+                content = response.response.candidates[0].content
+                if content and content.parts:
+                    results.append(content.parts[0].text)
+                else:
+                    results.append("")
+            else:
+                results.append("")
+
+        return results
+
     def _handle_streaming_response(
         self,
         params: Dict[str, Any],
@@ -952,6 +1095,11 @@
         if isinstance(messages, str):
             messages = [{"role": "user", "content": messages}]
 
+        if self.batch_mode and self._is_gemini_model():
+            return self._handle_batch_request(
+                messages, tools, callbacks, available_functions, from_task, from_agent
+            )
+
         # --- 4) Handle O1 model special case (system messages not supported)
         if "o1" in self.model.lower():
             for message in messages:
@@ -991,6 +1139,77 @@
             logging.error(f"LiteLLM call failed: {str(e)}")
             raise
 
+    def _handle_batch_request(
+        self,
+        messages: List[Dict[str, str]],
+        tools: Optional[List[dict]] = None,
+        callbacks: Optional[List[Any]] = None,
+        available_functions: Optional[Dict[str, Any]] = None,
+        from_task: Optional[Any] = None,
+        from_agent: Optional[Any] = None,
+    ) -> str:
+        """Handle batch mode request for Gemini models."""
+        if not self._is_gemini_model():
+            raise ValueError("Batch mode is only supported for Gemini models")
+
+        assert hasattr(crewai_event_bus, "emit")
+        crewai_event_bus.emit(
+            self,
+            event=BatchJobStartedEvent(
+                messages=messages,
+                tools=tools,
+                from_task=from_task,
+                from_agent=from_agent,
+            ),
+        )
+
+        try:
+            batch_request = self._prepare_batch_request(messages, tools)
+            self._batch_requests.append(batch_request)
+
+            if len(self._batch_requests) >= self.batch_size:
+                job_name = self._submit_batch_job(self._batch_requests)
+                self._current_batch_job = job_name
+
+                self._poll_batch_job(job_name)
+                results = self._retrieve_batch_results(job_name)
+
+                self._batch_requests.clear()
+                self._current_batch_job = None
+
+                if results:
+                    response = results[0]
+
+                    assert hasattr(crewai_event_bus, "emit")
+                    crewai_event_bus.emit(
+                        self,
+                        event=BatchJobCompletedEvent(
+                            response=response,
+                            job_name=job_name,
+                            from_task=from_task,
+                            from_agent=from_agent,
+                        ),
+                    )
+
+                    return response
+                else:
+                    raise RuntimeError("No results returned from batch job")
+            else:
+                return "Batch request queued. Call with more requests to trigger batch processing."
+
+        except Exception as e:
+            assert hasattr(crewai_event_bus, "emit")
+            crewai_event_bus.emit(
+                self,
+                event=BatchJobFailedEvent(
+                    error=str(e),
+                    from_task=from_task,
+                    from_agent=from_agent,
+                ),
+            )
+            logging.error(f"Batch request failed: {str(e)}")
+            raise
+
     def _handle_emit_call_events(self, response: Any, call_type: LLMCallType, from_task: Optional[Any] = None, from_agent: Optional[Any] = None):
         """Handle the events for the LLM call.
diff --git a/tests/test_batch_mode.py b/tests/test_batch_mode.py
new file mode 100644
index 000000000..bb721c83b
--- /dev/null
+++ b/tests/test_batch_mode.py
@@ -0,0 +1,236 @@
+import pytest
+import time
+from unittest.mock import Mock, patch, MagicMock
+from crewai.llm import LLM, BatchJobStartedEvent, BatchJobCompletedEvent, BatchJobFailedEvent
+
+
+class TestBatchMode:
+    """Test suite for Google Batch Mode functionality."""
+
+    def test_batch_mode_initialization(self):
+        """Test that batch mode parameters are properly initialized."""
+        llm = LLM(
+            model="gemini/gemini-1.5-pro",
+            batch_mode=True,
+            batch_size=5,
+            batch_timeout=600
+        )
+
+        assert llm.batch_mode is True
+        assert llm.batch_size == 5
+        assert llm.batch_timeout == 600
+        assert llm._batch_requests == []
+        assert llm._current_batch_job is None
+
+    def test_batch_mode_defaults(self):
+        """Test default values for batch mode parameters."""
+        llm = LLM(model="gemini/gemini-1.5-pro", batch_mode=True)
+
+        assert llm.batch_mode is True
+        assert llm.batch_size == 10
+        assert llm.batch_timeout == 300
+
+    def test_is_gemini_model_detection(self):
+        """Test Gemini model detection for batch mode support."""
+        with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True):
+            llm_gemini = LLM(model="gemini/gemini-1.5-pro")
+            assert llm_gemini._is_gemini_model() is True
+
+            llm_openai = LLM(model="gpt-4")
+            assert llm_openai._is_gemini_model() is False
+
+    def test_is_gemini_model_without_genai_available(self):
+        """Test Gemini model detection when google-generativeai is not available."""
+        with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', False):
+            llm = LLM(model="gemini/gemini-1.5-pro")
+            assert llm._is_gemini_model() is False
+
+    def test_prepare_batch_request(self):
+        """Test batch request preparation."""
+        with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True):
+            llm = LLM(
+                model="gemini/gemini-1.5-pro",
+                temperature=0.7,
+                top_p=0.9,
+                max_tokens=1000
+            )
+
+            messages = [{"role": "user", "content": "Hello, world!"}]
+            batch_request = llm._prepare_batch_request(messages)
+
+            assert "model" in batch_request
+            assert batch_request["model"] == "gemini-1.5-pro"
+            assert "contents" in batch_request
+            assert "generationConfig" in batch_request
+            assert batch_request["generationConfig"]["temperature"] == 0.7
+            assert batch_request["generationConfig"]["topP"] == 0.9
+            assert batch_request["generationConfig"]["maxOutputTokens"] == 1000
+
+    def test_prepare_batch_request_non_gemini_model(self):
+        """Test that batch request preparation fails for non-Gemini models."""
+        llm = LLM(model="gpt-4")
+        messages = [{"role": "user", "content": "Hello, world!"}]
+
+        with pytest.raises(ValueError, match="Batch mode is only supported for Gemini models"):
+            llm._prepare_batch_request(messages)
+
+    @patch('crewai.llm.genai')
+    def test_submit_batch_job(self, mock_genai):
+        """Test batch job submission."""
+        with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True):
+            mock_batch_job = Mock()
+            mock_batch_job.name = "test-job-123"
+            mock_genai.create_batch_job.return_value = mock_batch_job
+
+            llm = LLM(
+                model="gemini/gemini-1.5-pro",
+                api_key="test-key"
+            )
+
+            requests = [{"model": "gemini-1.5-pro", "contents": []}]
+            job_name = llm._submit_batch_job(requests)
+
+            assert job_name == "test-job-123"
+            mock_genai.configure.assert_called_with(api_key="test-key")
+            mock_genai.create_batch_job.assert_called_once()
+
+    def test_submit_batch_job_without_genai(self):
+        """Test batch job submission without google-generativeai available."""
+        with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', False):
+            llm = LLM(model="gemini/gemini-1.5-pro")
+
+            with pytest.raises(ImportError, match="google-generativeai is required for batch mode"):
+                llm._submit_batch_job([])
+
+    def test_submit_batch_job_without_api_key(self):
+        """Test batch job submission without API key."""
+        with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True):
+            llm = LLM(model="gemini/gemini-1.5-pro")
+
+            with pytest.raises(ValueError, match="API key is required for batch mode"):
+                llm._submit_batch_job([])
+
+    @patch('crewai.llm.genai')
+    @patch('crewai.llm.time')
+    def test_poll_batch_job_success(self, mock_time, mock_genai):
+        """Test successful batch job polling."""
+        with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True):
+            mock_batch_job = Mock()
+            mock_batch_job.state = "JOB_STATE_SUCCEEDED"
+            mock_genai.get_batch_job.return_value = mock_batch_job
+            mock_time.time.side_effect = [0, 1, 2]
+            mock_time.sleep = Mock()
+
+            llm = LLM(
+                model="gemini/gemini-1.5-pro",
+                api_key="test-key"
+            )
+
+            result = llm._poll_batch_job("test-job-123")
+
+            assert result == mock_batch_job
+            mock_genai.get_batch_job.assert_called_with("test-job-123")
+
+    @patch('crewai.llm.genai')
+    @patch('crewai.llm.time')
+    def test_poll_batch_job_timeout(self, mock_time, mock_genai):
+        """Test batch job polling timeout."""
+        with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True):
+            mock_batch_job = Mock()
+            mock_batch_job.state = "JOB_STATE_PENDING"
+            mock_genai.get_batch_job.return_value = mock_batch_job
+            mock_time.time.side_effect = [0, 400]
+            mock_time.sleep = Mock()
+
+            llm = LLM(
+                model="gemini/gemini-1.5-pro",
+                api_key="test-key",
+                batch_timeout=300
+            )
+
+            with pytest.raises(TimeoutError, match="did not complete within 300 seconds"):
+                llm._poll_batch_job("test-job-123")
+
+    @patch('crewai.llm.genai')
+    def test_retrieve_batch_results(self, mock_genai):
+        """Test batch result retrieval."""
+        with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True):
+            mock_batch_job = Mock()
+            mock_batch_job.state = "JOB_STATE_SUCCEEDED"
+            mock_genai.get_batch_job.return_value = mock_batch_job
+
+            mock_response = Mock()
+            mock_response.response.candidates = [Mock()]
+            mock_response.response.candidates[0].content.parts = [Mock()]
+            mock_response.response.candidates[0].content.parts[0].text = "Test response"
+
+            mock_genai.list_batch_job_responses.return_value = [mock_response]
+
+            llm = LLM(
+                model="gemini/gemini-1.5-pro",
+                api_key="test-key"
+            )
+
+            results = llm._retrieve_batch_results("test-job-123")
+
+            assert results == ["Test response"]
+            mock_genai.get_batch_job.assert_called_with("test-job-123")
+            mock_genai.list_batch_job_responses.assert_called_with("test-job-123")
+
+    @patch('crewai.llm.genai')
+    def test_retrieve_batch_results_failed_job(self, mock_genai):
+        """Test batch result retrieval for failed job."""
+        with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True):
+            mock_batch_job = Mock()
+            mock_batch_job.state = "JOB_STATE_FAILED"
+            mock_genai.get_batch_job.return_value = mock_batch_job
+
+            llm = LLM(
+                model="gemini/gemini-1.5-pro",
+                api_key="test-key"
+            )
+
+            with pytest.raises(RuntimeError, match="Batch job failed with state: JOB_STATE_FAILED"):
+                llm._retrieve_batch_results("test-job-123")
+
+    @patch('crewai.llm.crewai_event_bus')
+    def test_handle_batch_request_non_gemini(self, mock_event_bus):
+        """Test batch request handling for non-Gemini models."""
+        llm = LLM(model="gpt-4", batch_mode=True)
+        messages = [{"role": "user", "content": "Hello"}]
+
mode is only supported for Gemini models"): + llm._handle_batch_request(messages) + + @patch('crewai.llm.crewai_event_bus') + def test_batch_mode_call_routing(self, mock_event_bus): + """Test that batch mode calls are routed correctly.""" + with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True): + llm = LLM( + model="gemini/gemini-1.5-pro", + batch_mode=True, + api_key="test-key" + ) + + with patch.object(llm, '_handle_batch_request') as mock_batch_handler: + mock_batch_handler.return_value = "Batch response" + + result = llm.call("Hello, world!") + + assert result == "Batch response" + mock_batch_handler.assert_called_once() + + def test_non_batch_mode_unchanged(self): + """Test that non-batch mode behavior is unchanged.""" + with patch('crewai.llm.litellm') as mock_litellm: + mock_response = Mock() + mock_response.choices = [Mock()] + mock_response.choices[0].message.content = "Regular response" + mock_response.choices[0].message.tool_calls = [] + mock_litellm.completion.return_value = mock_response + + llm = LLM(model="gemini/gemini-1.5-pro", batch_mode=False) + result = llm.call("Hello, world!") + + assert result == "Regular response" + mock_litellm.completion.assert_called_once()