fix: resolve lint and type issues in Google Batch Mode implementation

- Fix unused imports in test_batch_mode.py
- Update Google GenAI API calls to use correct client.batches methods
- Add proper type annotations and error handling
- Ensure compatibility with google-generativeai SDK

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
Devin AI
2025-07-07 22:11:32 +00:00
parent 09e5a829f9
commit 49aa75e622
2 changed files with 37 additions and 29 deletions

View File

@@ -26,13 +26,10 @@ from pydantic import BaseModel, Field
try: try:
import google.generativeai as genai import google.generativeai as genai
from google.generativeai.types import BatchCreateJobRequest, BatchJob
GOOGLE_GENAI_AVAILABLE = True GOOGLE_GENAI_AVAILABLE = True
except ImportError: except ImportError:
GOOGLE_GENAI_AVAILABLE = False GOOGLE_GENAI_AVAILABLE = False
# Create dummy types for type annotations when genai is not available genai = None # type: ignore
BatchJob = object
BatchCreateJobRequest = object
from crewai.utilities.events.llm_events import ( from crewai.utilities.events.llm_events import (
LLMCallCompletedEvent, LLMCallCompletedEvent,
@@ -380,8 +377,8 @@ class LLM(BaseLLM):
self.batch_mode = batch_mode self.batch_mode = batch_mode
self.batch_size = batch_size or 10 self.batch_size = batch_size or 10
self.batch_timeout = batch_timeout self.batch_timeout = batch_timeout
self._batch_requests = [] self._batch_requests: List[Dict[str, Any]] = []
self._current_batch_job = None self._current_batch_job: Optional[str] = None
litellm.drop_params = True litellm.drop_params = True
@@ -474,7 +471,7 @@ class LLM(BaseLLM):
formatted_messages = self._format_messages_for_provider(messages) formatted_messages = self._format_messages_for_provider(messages)
request = { request: Dict[str, Any] = {
"contents": [], "contents": [],
"generationConfig": { "generationConfig": {
"temperature": self.temperature, "temperature": self.temperature,
@@ -505,32 +502,37 @@ class LLM(BaseLLM):
raise ValueError("API key is required for batch mode") raise ValueError("API key is required for batch mode")
genai.configure(api_key=self.api_key) genai.configure(api_key=self.api_key)
client = genai.Client()
batch_request = BatchCreateJobRequest( batch_job = client.batches.create(
requests=requests, model=f"models/{self.model.replace('gemini/', '')}",
display_name=f"crewai-batch-{int(time.time())}" src=requests,
config={
'display_name': f"crewai-batch-{int(time.time())}"
}
) )
batch_job = genai.create_batch_job(batch_request)
return batch_job.name return batch_job.name
def _poll_batch_job(self, job_name: str) -> BatchJob: def _poll_batch_job(self, job_name: str) -> Any:
"""Poll batch job status until completion.""" """Poll batch job status until completion."""
if not GOOGLE_GENAI_AVAILABLE: if not GOOGLE_GENAI_AVAILABLE:
raise ImportError("google-generativeai is required for batch mode") raise ImportError("google-generativeai is required for batch mode")
genai.configure(api_key=self.api_key) genai.configure(api_key=self.api_key)
client = genai.Client()
start_time = time.time() start_time = time.time()
while time.time() - start_time < self.batch_timeout: timeout = self.batch_timeout or 300
batch_job = genai.get_batch_job(job_name) while time.time() - start_time < timeout:
batch_job = client.batches.get(name=job_name)
if batch_job.state in ["JOB_STATE_SUCCEEDED", "JOB_STATE_FAILED", "JOB_STATE_CANCELLED"]: if batch_job.state.name in ["JOB_STATE_SUCCEEDED", "JOB_STATE_FAILED", "JOB_STATE_CANCELLED"]:
return batch_job return batch_job
time.sleep(5) time.sleep(5)
raise TimeoutError(f"Batch job {job_name} did not complete within {self.batch_timeout} seconds") raise TimeoutError(f"Batch job {job_name} did not complete within {timeout} seconds")
def _retrieve_batch_results(self, job_name: str) -> List[str]: def _retrieve_batch_results(self, job_name: str) -> List[str]:
"""Retrieve results from a completed batch job.""" """Retrieve results from a completed batch job."""
@@ -538,22 +540,29 @@ class LLM(BaseLLM):
raise ImportError("google-generativeai is required for batch mode") raise ImportError("google-generativeai is required for batch mode")
genai.configure(api_key=self.api_key) genai.configure(api_key=self.api_key)
client = genai.Client()
batch_job = genai.get_batch_job(job_name) batch_job = client.batches.get(name=job_name)
if batch_job.state != "JOB_STATE_SUCCEEDED": if batch_job.state.name != "JOB_STATE_SUCCEEDED":
raise RuntimeError(f"Batch job failed with state: {batch_job.state}") raise RuntimeError(f"Batch job failed with state: {batch_job.state.name}")
results = [] results = []
for response in genai.list_batch_job_responses(job_name): if batch_job.dest and batch_job.dest.inlined_responses:
if response.response and response.response.candidates: for inline_response in batch_job.dest.inlined_responses:
content = response.response.candidates[0].content if inline_response.response and hasattr(inline_response.response, 'text'):
if content and content.parts: results.append(inline_response.response.text)
results.append(content.parts[0].text) elif inline_response.response and hasattr(inline_response.response, 'candidates'):
if inline_response.response.candidates and inline_response.response.candidates[0].content:
content = inline_response.response.candidates[0].content
if content.parts:
results.append(content.parts[0].text)
else:
results.append("")
else:
results.append("")
else: else:
results.append("") results.append("")
else:
results.append("")
return results return results

View File

@@ -1,7 +1,6 @@
import pytest import pytest
import time from unittest.mock import Mock, patch
from unittest.mock import Mock, patch, MagicMock from crewai.llm import LLM
from crewai.llm import LLM, BatchJobStartedEvent, BatchJobCompletedEvent, BatchJobFailedEvent
class TestBatchMode: class TestBatchMode: