Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-08 15:48:29 +00:00
feat: implement Google Batch Mode support for LLM calls
- Add google-generativeai dependency to pyproject.toml
- Extend LLM class with batch mode parameters (batch_mode, batch_size, batch_timeout)
- Implement batch request management methods for Gemini models
- Add batch-specific event types (BatchJobStartedEvent, BatchJobCompletedEvent, BatchJobFailedEvent)
- Create comprehensive test suite for batch mode functionality
- Add example demonstrating batch mode usage with cost savings
- Support inline batch requests for up to 50% cost reduction on Gemini models

Resolves issue #3116

Co-Authored-By: João <joao@crewai.com>
examples/batch_mode_example.py (new file, 63 lines)
@@ -0,0 +1,63 @@
"""
Example demonstrating Google Batch Mode support in CrewAI.

This example shows how to use batch mode with Gemini models to reduce costs
by up to 50% for non-urgent LLM calls.
"""

import os

from crewai import Agent, Task, Crew
from crewai.llm import LLM

os.environ["GOOGLE_API_KEY"] = "your-google-api-key-here"


def main():
    batch_llm = LLM(
        model="gemini/gemini-1.5-pro",
        batch_mode=True,
        batch_size=5,  # Process 5 requests at once
        batch_timeout=300,  # Wait up to 5 minutes for batch completion
        temperature=0.7
    )

    research_agent = Agent(
        role="Research Analyst",
        goal="Analyze market trends and provide insights",
        backstory="You are an expert market analyst with years of experience.",
        llm=batch_llm,
        verbose=True
    )

    tasks = []
    topics = [
        "artificial intelligence market trends",
        "renewable energy investment opportunities",
        "cryptocurrency regulatory landscape",
        "e-commerce growth projections",
        "healthcare technology innovations"
    ]

    for topic in topics:
        task = Task(
            description=f"Research and analyze {topic}. Provide a brief summary of key trends and insights.",
            agent=research_agent,
            expected_output="A concise analysis with key findings and trends"
        )
        tasks.append(task)

    crew = Crew(
        agents=[research_agent],
        tasks=tasks,
        verbose=True
    )

    print("Starting batch processing...")
    print("Note: Batch requests will be queued until batch_size is reached")

    result = crew.kickoff()

    print("Batch processing completed!")
    print("Results:", result)


if __name__ == "__main__":
    main()
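The example above drives batching through a Crew. For reference, the queueing behaviour can also be seen by calling a batch-mode LLM directly. This is a minimal sketch, not part of the commit: it assumes the patched LLM class from this diff and a valid Google API key, and the prompts and batch_size are illustrative only.

import os

from crewai.llm import LLM

# Sketch only: each call queues one request; the call that reaches batch_size
# submits a Gemini batch job, polls it, and returns the first batch result
# (mirroring _handle_batch_request below; there is no separate public batch API).
llm = LLM(
    model="gemini/gemini-1.5-pro",
    batch_mode=True,
    batch_size=2,
    api_key=os.environ["GOOGLE_API_KEY"],  # required by _submit_batch_job
)

print(llm.call("Summarize AI market trends"))         # returns a "queued" notice
print(llm.call("Summarize renewable energy trends"))  # triggers batch submission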
pyproject.toml
@@ -39,6 +39,7 @@ dependencies = [
    "tomli>=2.0.2",
    "blinker>=1.9.0",
    "json5>=0.10.0",
    "google-generativeai>=0.8.0",
]

[project.urls]
src/crewai/llm.py
@@ -3,6 +3,7 @@ import logging
import os
import sys
import threading
import time
import warnings
from collections import defaultdict
from contextlib import contextmanager
@@ -23,6 +24,16 @@ from dotenv import load_dotenv
from litellm.types.utils import ChatCompletionDeltaToolCall
from pydantic import BaseModel, Field

try:
    import google.generativeai as genai
    from google.generativeai.types import BatchCreateJobRequest, BatchJob
    GOOGLE_GENAI_AVAILABLE = True
except ImportError:
    GOOGLE_GENAI_AVAILABLE = False
    # Create dummy types for type annotations when genai is not available
    BatchJob = object
    BatchCreateJobRequest = object

from crewai.utilities.events.llm_events import (
    LLMCallCompletedEvent,
    LLMCallFailedEvent,
@@ -57,6 +68,32 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
    LLMContextLengthExceededException,
)


class BatchJobStartedEvent:
    """Event emitted when a batch job is started."""

    def __init__(self, messages, tools=None, from_task=None, from_agent=None):
        self.messages = messages
        self.tools = tools
        self.from_task = from_task
        self.from_agent = from_agent


class BatchJobCompletedEvent:
    """Event emitted when a batch job is completed."""

    def __init__(self, response, job_name, from_task=None, from_agent=None):
        self.response = response
        self.job_name = job_name
        self.from_task = from_task
        self.from_agent = from_agent


class BatchJobFailedEvent:
    """Event emitted when a batch job fails."""

    def __init__(self, error, from_task=None, from_agent=None):
        self.error = error
        self.from_task = from_task
        self.from_agent = from_agent


load_dotenv()

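These three event classes are emitted through crewai_event_bus by _handle_batch_request further down, but neither the example nor the tests subscribe to them. The following is a hypothetical listener sketch, not part of this diff; it assumes crewai_event_bus.on() registers handlers by event class with a (source, event) handler signature, the same way CrewAI's existing LLM events are consumed.

from crewai.llm import BatchJobCompletedEvent, BatchJobFailedEvent
from crewai.utilities.events import crewai_event_bus  # assumed import path

# Assumed handler signature (source, event), matching CrewAI's other listeners.
@crewai_event_bus.on(BatchJobCompletedEvent)
def log_batch_completed(source, event):
    print(f"Batch job {event.job_name} finished; first 80 chars: {event.response[:80]}")

@crewai_event_bus.on(BatchJobFailedEvent)
def log_batch_failed(source, event):
    print(f"Batch job failed: {event.error}")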
@@ -311,6 +348,9 @@ class LLM(BaseLLM):
        callbacks: List[Any] = [],
        reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
        stream: bool = False,
        batch_mode: bool = False,
        batch_size: Optional[int] = None,
        batch_timeout: Optional[int] = 300,
        **kwargs,
    ):
        self.model = model
@@ -337,6 +377,11 @@ class LLM(BaseLLM):
        self.additional_params = kwargs
        self.is_anthropic = self._is_anthropic_model(model)
        self.stream = stream
        self.batch_mode = batch_mode
        self.batch_size = batch_size or 10
        self.batch_timeout = batch_timeout
        self._batch_requests = []
        self._current_batch_job = None

        litellm.drop_params = True

@@ -363,6 +408,10 @@ class LLM(BaseLLM):
        ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/")
        return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES)

    def _is_gemini_model(self) -> bool:
        """Check if the model is a Gemini model that supports batch mode."""
        return "gemini" in self.model.lower() and GOOGLE_GENAI_AVAILABLE

    def _prepare_completion_params(
        self,
        messages: Union[str, List[Dict[str, str]]],
@@ -414,6 +463,100 @@ class LLM(BaseLLM):
        # Remove None values from params
        return {k: v for k, v in params.items() if v is not None}

    def _prepare_batch_request(
        self,
        messages: List[Dict[str, str]],
        tools: Optional[List[dict]] = None
    ) -> Dict[str, Any]:
        """Prepare a single request for batch processing."""
        if not self._is_gemini_model():
            raise ValueError("Batch mode is only supported for Gemini models")

        formatted_messages = self._format_messages_for_provider(messages)

        request = {
            "contents": [],
            "generationConfig": {
                "temperature": self.temperature,
                "topP": self.top_p,
                "maxOutputTokens": self.max_tokens or self.max_completion_tokens,
                "stopSequences": self.stop if isinstance(self.stop, list) else [self.stop] if self.stop else None,
            }
        }

        for message in formatted_messages:
            role = "user" if message["role"] == "user" else "model"
            request["contents"].append({
                "role": role,
                "parts": [{"text": message["content"]}]
            })

        if tools:
            request["tools"] = tools

        return {"model": self.model.replace("gemini/", ""), "contents": request["contents"], "generationConfig": request["generationConfig"]}

    def _submit_batch_job(self, requests: List[Dict[str, Any]]) -> str:
        """Submit a batch job to Google GenAI API."""
        if not GOOGLE_GENAI_AVAILABLE:
            raise ImportError("google-generativeai is required for batch mode")

        if not self.api_key:
            raise ValueError("API key is required for batch mode")

        genai.configure(api_key=self.api_key)

        batch_request = BatchCreateJobRequest(
            requests=requests,
            display_name=f"crewai-batch-{int(time.time())}"
        )

        batch_job = genai.create_batch_job(batch_request)
        return batch_job.name

    def _poll_batch_job(self, job_name: str) -> BatchJob:
        """Poll batch job status until completion."""
        if not GOOGLE_GENAI_AVAILABLE:
            raise ImportError("google-generativeai is required for batch mode")

        genai.configure(api_key=self.api_key)

        start_time = time.time()
        while time.time() - start_time < self.batch_timeout:
            batch_job = genai.get_batch_job(job_name)

            if batch_job.state in ["JOB_STATE_SUCCEEDED", "JOB_STATE_FAILED", "JOB_STATE_CANCELLED"]:
                return batch_job

            time.sleep(5)

        raise TimeoutError(f"Batch job {job_name} did not complete within {self.batch_timeout} seconds")

    def _retrieve_batch_results(self, job_name: str) -> List[str]:
        """Retrieve results from a completed batch job."""
        if not GOOGLE_GENAI_AVAILABLE:
            raise ImportError("google-generativeai is required for batch mode")

        genai.configure(api_key=self.api_key)

        batch_job = genai.get_batch_job(job_name)

        if batch_job.state != "JOB_STATE_SUCCEEDED":
            raise RuntimeError(f"Batch job failed with state: {batch_job.state}")

        results = []
        for response in genai.list_batch_job_responses(job_name):
            if response.response and response.response.candidates:
                content = response.response.candidates[0].content
                if content and content.parts:
                    results.append(content.parts[0].text)
                else:
                    results.append("")
            else:
                results.append("")

        return results

    def _handle_streaming_response(
        self,
        params: Dict[str, Any],
@@ -952,6 +1095,11 @@ class LLM(BaseLLM):
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]

        if self.batch_mode and self._is_gemini_model():
            return self._handle_batch_request(
                messages, tools, callbacks, available_functions, from_task, from_agent
            )

        # --- 4) Handle O1 model special case (system messages not supported)
        if "o1" in self.model.lower():
            for message in messages:
@@ -991,6 +1139,77 @@ class LLM(BaseLLM):
            logging.error(f"LiteLLM call failed: {str(e)}")
            raise

    def _handle_batch_request(
        self,
        messages: List[Dict[str, str]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
        from_task: Optional[Any] = None,
        from_agent: Optional[Any] = None,
    ) -> str:
        """Handle batch mode request for Gemini models."""
        if not self._is_gemini_model():
            raise ValueError("Batch mode is only supported for Gemini models")

        assert hasattr(crewai_event_bus, "emit")
        crewai_event_bus.emit(
            self,
            event=BatchJobStartedEvent(
                messages=messages,
                tools=tools,
                from_task=from_task,
                from_agent=from_agent,
            ),
        )

        try:
            batch_request = self._prepare_batch_request(messages, tools)
            self._batch_requests.append(batch_request)

            if len(self._batch_requests) >= self.batch_size:
                job_name = self._submit_batch_job(self._batch_requests)
                self._current_batch_job = job_name

                self._poll_batch_job(job_name)
                results = self._retrieve_batch_results(job_name)

                self._batch_requests.clear()
                self._current_batch_job = None

                if results:
                    response = results[0]

                    assert hasattr(crewai_event_bus, "emit")
                    crewai_event_bus.emit(
                        self,
                        event=BatchJobCompletedEvent(
                            response=response,
                            job_name=job_name,
                            from_task=from_task,
                            from_agent=from_agent,
                        ),
                    )

                    return response
                else:
                    raise RuntimeError("No results returned from batch job")
            else:
                return "Batch request queued. Call with more requests to trigger batch processing."

        except Exception as e:
            assert hasattr(crewai_event_bus, "emit")
            crewai_event_bus.emit(
                self,
                event=BatchJobFailedEvent(
                    error=str(e),
                    from_task=from_task,
                    from_agent=from_agent,
                ),
            )
            logging.error(f"Batch request failed: {str(e)}")
            raise

    def _handle_emit_call_events(self, response: Any, call_type: LLMCallType, from_task: Optional[Any] = None, from_agent: Optional[Any] = None):
        """Handle the events for the LLM call.
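Taken together, the batch path added above decomposes into four private helpers. The following hypothetical walk-through shows the call order using the patch's own method names; calling these internal methods directly is for illustration only and assumes a real API key.

from crewai.llm import LLM

llm = LLM(model="gemini/gemini-1.5-pro", batch_mode=True, api_key="your-google-api-key-here")

# 1) Convert chat messages into a Gemini batch request payload.
request = llm._prepare_batch_request([{"role": "user", "content": "Hello"}])

# 2) Submit the accumulated requests as a single batch job; returns the job name.
job_name = llm._submit_batch_job([request])

# 3) Block until the job reaches a terminal state or batch_timeout elapses.
llm._poll_batch_job(job_name)

# 4) Collect the text of every response (empty string where a candidate is missing).
print(llm._retrieve_batch_results(job_name))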
tests/test_batch_mode.py (new file, 236 lines)
@@ -0,0 +1,236 @@
import pytest
import time
from unittest.mock import Mock, patch, MagicMock
from crewai.llm import LLM, BatchJobStartedEvent, BatchJobCompletedEvent, BatchJobFailedEvent


class TestBatchMode:
    """Test suite for Google Batch Mode functionality."""

    def test_batch_mode_initialization(self):
        """Test that batch mode parameters are properly initialized."""
        llm = LLM(
            model="gemini/gemini-1.5-pro",
            batch_mode=True,
            batch_size=5,
            batch_timeout=600
        )

        assert llm.batch_mode is True
        assert llm.batch_size == 5
        assert llm.batch_timeout == 600
        assert llm._batch_requests == []
        assert llm._current_batch_job is None

    def test_batch_mode_defaults(self):
        """Test default values for batch mode parameters."""
        llm = LLM(model="gemini/gemini-1.5-pro", batch_mode=True)

        assert llm.batch_mode is True
        assert llm.batch_size == 10
        assert llm.batch_timeout == 300

    def test_is_gemini_model_detection(self):
        """Test Gemini model detection for batch mode support."""
        with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True):
            llm_gemini = LLM(model="gemini/gemini-1.5-pro")
            assert llm_gemini._is_gemini_model() is True

            llm_openai = LLM(model="gpt-4")
            assert llm_openai._is_gemini_model() is False

    def test_is_gemini_model_without_genai_available(self):
        """Test Gemini model detection when google-generativeai is not available."""
        with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', False):
            llm = LLM(model="gemini/gemini-1.5-pro")
            assert llm._is_gemini_model() is False

    def test_prepare_batch_request(self):
        """Test batch request preparation."""
        with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True):
            llm = LLM(
                model="gemini/gemini-1.5-pro",
                temperature=0.7,
                top_p=0.9,
                max_tokens=1000
            )

            messages = [{"role": "user", "content": "Hello, world!"}]
            batch_request = llm._prepare_batch_request(messages)

            assert "model" in batch_request
            assert batch_request["model"] == "gemini-1.5-pro"
            assert "contents" in batch_request
            assert "generationConfig" in batch_request
            assert batch_request["generationConfig"]["temperature"] == 0.7
            assert batch_request["generationConfig"]["topP"] == 0.9
            assert batch_request["generationConfig"]["maxOutputTokens"] == 1000

    def test_prepare_batch_request_non_gemini_model(self):
        """Test that batch request preparation fails for non-Gemini models."""
        llm = LLM(model="gpt-4")
        messages = [{"role": "user", "content": "Hello, world!"}]

        with pytest.raises(ValueError, match="Batch mode is only supported for Gemini models"):
            llm._prepare_batch_request(messages)

    @patch('crewai.llm.genai')
    def test_submit_batch_job(self, mock_genai):
        """Test batch job submission."""
        with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True):
            mock_batch_job = Mock()
            mock_batch_job.name = "test-job-123"
            mock_genai.create_batch_job.return_value = mock_batch_job

            llm = LLM(
                model="gemini/gemini-1.5-pro",
                api_key="test-key"
            )

            requests = [{"model": "gemini-1.5-pro", "contents": []}]
            job_name = llm._submit_batch_job(requests)

            assert job_name == "test-job-123"
            mock_genai.configure.assert_called_with(api_key="test-key")
            mock_genai.create_batch_job.assert_called_once()

    def test_submit_batch_job_without_genai(self):
        """Test batch job submission without google-generativeai available."""
        with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', False):
            llm = LLM(model="gemini/gemini-1.5-pro")

            with pytest.raises(ImportError, match="google-generativeai is required for batch mode"):
                llm._submit_batch_job([])

    def test_submit_batch_job_without_api_key(self):
        """Test batch job submission without API key."""
        with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True):
            llm = LLM(model="gemini/gemini-1.5-pro")

            with pytest.raises(ValueError, match="API key is required for batch mode"):
                llm._submit_batch_job([])

    @patch('crewai.llm.genai')
    @patch('crewai.llm.time')
    def test_poll_batch_job_success(self, mock_time, mock_genai):
        """Test successful batch job polling."""
        with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True):
            mock_batch_job = Mock()
            mock_batch_job.state = "JOB_STATE_SUCCEEDED"
            mock_genai.get_batch_job.return_value = mock_batch_job
            mock_time.time.side_effect = [0, 1, 2]
            mock_time.sleep = Mock()

            llm = LLM(
                model="gemini/gemini-1.5-pro",
                api_key="test-key"
            )

            result = llm._poll_batch_job("test-job-123")

            assert result == mock_batch_job
            mock_genai.get_batch_job.assert_called_with("test-job-123")

    @patch('crewai.llm.genai')
    @patch('crewai.llm.time')
    def test_poll_batch_job_timeout(self, mock_time, mock_genai):
        """Test batch job polling timeout."""
        with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True):
            mock_batch_job = Mock()
            mock_batch_job.state = "JOB_STATE_PENDING"
            mock_genai.get_batch_job.return_value = mock_batch_job
            mock_time.time.side_effect = [0, 400]
            mock_time.sleep = Mock()

            llm = LLM(
                model="gemini/gemini-1.5-pro",
                api_key="test-key",
                batch_timeout=300
            )

            with pytest.raises(TimeoutError, match="did not complete within 300 seconds"):
                llm._poll_batch_job("test-job-123")

    @patch('crewai.llm.genai')
    def test_retrieve_batch_results(self, mock_genai):
        """Test batch result retrieval."""
        with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True):
            mock_batch_job = Mock()
            mock_batch_job.state = "JOB_STATE_SUCCEEDED"
            mock_genai.get_batch_job.return_value = mock_batch_job

            mock_response = Mock()
            mock_response.response.candidates = [Mock()]
            mock_response.response.candidates[0].content.parts = [Mock()]
            mock_response.response.candidates[0].content.parts[0].text = "Test response"

            mock_genai.list_batch_job_responses.return_value = [mock_response]

            llm = LLM(
                model="gemini/gemini-1.5-pro",
                api_key="test-key"
            )

            results = llm._retrieve_batch_results("test-job-123")

            assert results == ["Test response"]
            mock_genai.get_batch_job.assert_called_with("test-job-123")
            mock_genai.list_batch_job_responses.assert_called_with("test-job-123")

    @patch('crewai.llm.genai')
    def test_retrieve_batch_results_failed_job(self, mock_genai):
        """Test batch result retrieval for failed job."""
        with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True):
            mock_batch_job = Mock()
            mock_batch_job.state = "JOB_STATE_FAILED"
            mock_genai.get_batch_job.return_value = mock_batch_job

            llm = LLM(
                model="gemini/gemini-1.5-pro",
                api_key="test-key"
            )

            with pytest.raises(RuntimeError, match="Batch job failed with state: JOB_STATE_FAILED"):
                llm._retrieve_batch_results("test-job-123")

    @patch('crewai.llm.crewai_event_bus')
    def test_handle_batch_request_non_gemini(self, mock_event_bus):
        """Test batch request handling for non-Gemini models."""
        llm = LLM(model="gpt-4", batch_mode=True)
        messages = [{"role": "user", "content": "Hello"}]

        with pytest.raises(ValueError, match="Batch mode is only supported for Gemini models"):
            llm._handle_batch_request(messages)

    @patch('crewai.llm.crewai_event_bus')
    def test_batch_mode_call_routing(self, mock_event_bus):
        """Test that batch mode calls are routed correctly."""
        with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True):
            llm = LLM(
                model="gemini/gemini-1.5-pro",
                batch_mode=True,
                api_key="test-key"
            )

            with patch.object(llm, '_handle_batch_request') as mock_batch_handler:
                mock_batch_handler.return_value = "Batch response"

                result = llm.call("Hello, world!")

                assert result == "Batch response"
                mock_batch_handler.assert_called_once()

    def test_non_batch_mode_unchanged(self):
        """Test that non-batch mode behavior is unchanged."""
        with patch('crewai.llm.litellm') as mock_litellm:
            mock_response = Mock()
            mock_response.choices = [Mock()]
            mock_response.choices[0].message.content = "Regular response"
            mock_response.choices[0].message.tool_calls = []
            mock_litellm.completion.return_value = mock_response

            llm = LLM(model="gemini/gemini-1.5-pro", batch_mode=False)
            result = llm.call("Hello, world!")

            assert result == "Regular response"
            mock_litellm.completion.assert_called_once()
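The suite replaces the Google client with unittest.mock.patch, so no real API key or network access is needed. A minimal sketch for running just this file, assuming the repository's standard pytest setup:

# Sketch: invoke pytest programmatically on the new suite.
import pytest

raise SystemExit(pytest.main(["-v", "tests/test_batch_mode.py"]))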