mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
fix: resolve lint and type issues in Google Batch Mode implementation
- Fix unused imports in test_batch_mode.py - Update Google GenAI API calls to use correct client.batches methods - Add proper type annotations and error handling - Ensure compatibility with google-generativeai SDK Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -26,13 +26,10 @@ from pydantic import BaseModel, Field
|
||||
|
||||
try:
|
||||
import google.generativeai as genai
|
||||
from google.generativeai.types import BatchCreateJobRequest, BatchJob
|
||||
GOOGLE_GENAI_AVAILABLE = True
|
||||
except ImportError:
|
||||
GOOGLE_GENAI_AVAILABLE = False
|
||||
# Create dummy types for type annotations when genai is not available
|
||||
BatchJob = object
|
||||
BatchCreateJobRequest = object
|
||||
genai = None # type: ignore
|
||||
|
||||
from crewai.utilities.events.llm_events import (
|
||||
LLMCallCompletedEvent,
|
||||
@@ -380,8 +377,8 @@ class LLM(BaseLLM):
|
||||
self.batch_mode = batch_mode
|
||||
self.batch_size = batch_size or 10
|
||||
self.batch_timeout = batch_timeout
|
||||
self._batch_requests = []
|
||||
self._current_batch_job = None
|
||||
self._batch_requests: List[Dict[str, Any]] = []
|
||||
self._current_batch_job: Optional[str] = None
|
||||
|
||||
litellm.drop_params = True
|
||||
|
||||
@@ -474,7 +471,7 @@ class LLM(BaseLLM):
|
||||
|
||||
formatted_messages = self._format_messages_for_provider(messages)
|
||||
|
||||
request = {
|
||||
request: Dict[str, Any] = {
|
||||
"contents": [],
|
||||
"generationConfig": {
|
||||
"temperature": self.temperature,
|
||||
@@ -505,32 +502,37 @@ class LLM(BaseLLM):
|
||||
raise ValueError("API key is required for batch mode")
|
||||
|
||||
genai.configure(api_key=self.api_key)
|
||||
client = genai.Client()
|
||||
|
||||
batch_request = BatchCreateJobRequest(
|
||||
requests=requests,
|
||||
display_name=f"crewai-batch-{int(time.time())}"
|
||||
batch_job = client.batches.create(
|
||||
model=f"models/{self.model.replace('gemini/', '')}",
|
||||
src=requests,
|
||||
config={
|
||||
'display_name': f"crewai-batch-{int(time.time())}"
|
||||
}
|
||||
)
|
||||
|
||||
batch_job = genai.create_batch_job(batch_request)
|
||||
return batch_job.name
|
||||
|
||||
def _poll_batch_job(self, job_name: str) -> BatchJob:
|
||||
def _poll_batch_job(self, job_name: str) -> Any:
|
||||
"""Poll batch job status until completion."""
|
||||
if not GOOGLE_GENAI_AVAILABLE:
|
||||
raise ImportError("google-generativeai is required for batch mode")
|
||||
|
||||
genai.configure(api_key=self.api_key)
|
||||
client = genai.Client()
|
||||
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < self.batch_timeout:
|
||||
batch_job = genai.get_batch_job(job_name)
|
||||
timeout = self.batch_timeout or 300
|
||||
while time.time() - start_time < timeout:
|
||||
batch_job = client.batches.get(name=job_name)
|
||||
|
||||
if batch_job.state in ["JOB_STATE_SUCCEEDED", "JOB_STATE_FAILED", "JOB_STATE_CANCELLED"]:
|
||||
if batch_job.state.name in ["JOB_STATE_SUCCEEDED", "JOB_STATE_FAILED", "JOB_STATE_CANCELLED"]:
|
||||
return batch_job
|
||||
|
||||
time.sleep(5)
|
||||
|
||||
raise TimeoutError(f"Batch job {job_name} did not complete within {self.batch_timeout} seconds")
|
||||
raise TimeoutError(f"Batch job {job_name} did not complete within {timeout} seconds")
|
||||
|
||||
def _retrieve_batch_results(self, job_name: str) -> List[str]:
|
||||
"""Retrieve results from a completed batch job."""
|
||||
@@ -538,22 +540,29 @@ class LLM(BaseLLM):
|
||||
raise ImportError("google-generativeai is required for batch mode")
|
||||
|
||||
genai.configure(api_key=self.api_key)
|
||||
client = genai.Client()
|
||||
|
||||
batch_job = genai.get_batch_job(job_name)
|
||||
batch_job = client.batches.get(name=job_name)
|
||||
|
||||
if batch_job.state != "JOB_STATE_SUCCEEDED":
|
||||
raise RuntimeError(f"Batch job failed with state: {batch_job.state}")
|
||||
if batch_job.state.name != "JOB_STATE_SUCCEEDED":
|
||||
raise RuntimeError(f"Batch job failed with state: {batch_job.state.name}")
|
||||
|
||||
results = []
|
||||
for response in genai.list_batch_job_responses(job_name):
|
||||
if response.response and response.response.candidates:
|
||||
content = response.response.candidates[0].content
|
||||
if content and content.parts:
|
||||
results.append(content.parts[0].text)
|
||||
if batch_job.dest and batch_job.dest.inlined_responses:
|
||||
for inline_response in batch_job.dest.inlined_responses:
|
||||
if inline_response.response and hasattr(inline_response.response, 'text'):
|
||||
results.append(inline_response.response.text)
|
||||
elif inline_response.response and hasattr(inline_response.response, 'candidates'):
|
||||
if inline_response.response.candidates and inline_response.response.candidates[0].content:
|
||||
content = inline_response.response.candidates[0].content
|
||||
if content.parts:
|
||||
results.append(content.parts[0].text)
|
||||
else:
|
||||
results.append("")
|
||||
else:
|
||||
results.append("")
|
||||
else:
|
||||
results.append("")
|
||||
else:
|
||||
results.append("")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import pytest
|
||||
import time
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from crewai.llm import LLM, BatchJobStartedEvent, BatchJobCompletedEvent, BatchJobFailedEvent
|
||||
from unittest.mock import Mock, patch
|
||||
from crewai.llm import LLM
|
||||
|
||||
|
||||
class TestBatchMode:
|
||||
|
||||
Reference in New Issue
Block a user