fix: resolve lint and type issues in Google Batch Mode implementation

- Fix unused imports in test_batch_mode.py
- Update Google GenAI API calls to use correct client.batches methods
- Add proper type annotations and error handling
- Ensure compatibility with google-generativeai SDK

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
Devin AI
2025-07-07 22:11:32 +00:00
parent 09e5a829f9
commit 49aa75e622
2 changed files with 37 additions and 29 deletions

View File

@@ -26,13 +26,10 @@ from pydantic import BaseModel, Field
try:
import google.generativeai as genai
from google.generativeai.types import BatchCreateJobRequest, BatchJob
GOOGLE_GENAI_AVAILABLE = True
except ImportError:
GOOGLE_GENAI_AVAILABLE = False
# Create dummy types for type annotations when genai is not available
BatchJob = object
BatchCreateJobRequest = object
genai = None # type: ignore
from crewai.utilities.events.llm_events import (
LLMCallCompletedEvent,
@@ -380,8 +377,8 @@ class LLM(BaseLLM):
self.batch_mode = batch_mode
self.batch_size = batch_size or 10
self.batch_timeout = batch_timeout
self._batch_requests = []
self._current_batch_job = None
self._batch_requests: List[Dict[str, Any]] = []
self._current_batch_job: Optional[str] = None
litellm.drop_params = True
@@ -474,7 +471,7 @@ class LLM(BaseLLM):
formatted_messages = self._format_messages_for_provider(messages)
request = {
request: Dict[str, Any] = {
"contents": [],
"generationConfig": {
"temperature": self.temperature,
@@ -505,32 +502,37 @@ class LLM(BaseLLM):
raise ValueError("API key is required for batch mode")
genai.configure(api_key=self.api_key)
client = genai.Client()
batch_request = BatchCreateJobRequest(
requests=requests,
display_name=f"crewai-batch-{int(time.time())}"
batch_job = client.batches.create(
model=f"models/{self.model.replace('gemini/', '')}",
src=requests,
config={
'display_name': f"crewai-batch-{int(time.time())}"
}
)
batch_job = genai.create_batch_job(batch_request)
return batch_job.name
def _poll_batch_job(self, job_name: str) -> BatchJob:
def _poll_batch_job(self, job_name: str) -> Any:
"""Poll batch job status until completion."""
if not GOOGLE_GENAI_AVAILABLE:
raise ImportError("google-generativeai is required for batch mode")
genai.configure(api_key=self.api_key)
client = genai.Client()
start_time = time.time()
while time.time() - start_time < self.batch_timeout:
batch_job = genai.get_batch_job(job_name)
timeout = self.batch_timeout or 300
while time.time() - start_time < timeout:
batch_job = client.batches.get(name=job_name)
if batch_job.state in ["JOB_STATE_SUCCEEDED", "JOB_STATE_FAILED", "JOB_STATE_CANCELLED"]:
if batch_job.state.name in ["JOB_STATE_SUCCEEDED", "JOB_STATE_FAILED", "JOB_STATE_CANCELLED"]:
return batch_job
time.sleep(5)
raise TimeoutError(f"Batch job {job_name} did not complete within {self.batch_timeout} seconds")
raise TimeoutError(f"Batch job {job_name} did not complete within {timeout} seconds")
def _retrieve_batch_results(self, job_name: str) -> List[str]:
"""Retrieve results from a completed batch job."""
@@ -538,22 +540,29 @@ class LLM(BaseLLM):
raise ImportError("google-generativeai is required for batch mode")
genai.configure(api_key=self.api_key)
client = genai.Client()
batch_job = genai.get_batch_job(job_name)
batch_job = client.batches.get(name=job_name)
if batch_job.state != "JOB_STATE_SUCCEEDED":
raise RuntimeError(f"Batch job failed with state: {batch_job.state}")
if batch_job.state.name != "JOB_STATE_SUCCEEDED":
raise RuntimeError(f"Batch job failed with state: {batch_job.state.name}")
results = []
for response in genai.list_batch_job_responses(job_name):
if response.response and response.response.candidates:
content = response.response.candidates[0].content
if content and content.parts:
results.append(content.parts[0].text)
if batch_job.dest and batch_job.dest.inlined_responses:
for inline_response in batch_job.dest.inlined_responses:
if inline_response.response and hasattr(inline_response.response, 'text'):
results.append(inline_response.response.text)
elif inline_response.response and hasattr(inline_response.response, 'candidates'):
if inline_response.response.candidates and inline_response.response.candidates[0].content:
content = inline_response.response.candidates[0].content
if content.parts:
results.append(content.parts[0].text)
else:
results.append("")
else:
results.append("")
else:
results.append("")
else:
results.append("")
return results

View File

@@ -1,7 +1,6 @@
import pytest
import time
from unittest.mock import Mock, patch, MagicMock
from crewai.llm import LLM, BatchJobStartedEvent, BatchJobCompletedEvent, BatchJobFailedEvent
from unittest.mock import Mock, patch
from crewai.llm import LLM
class TestBatchMode: