diff --git a/src/crewai/llm.py b/src/crewai/llm.py index 1bf016d79..b506dd04f 100644 --- a/src/crewai/llm.py +++ b/src/crewai/llm.py @@ -26,13 +26,10 @@ from pydantic import BaseModel, Field try: import google.generativeai as genai - from google.generativeai.types import BatchCreateJobRequest, BatchJob GOOGLE_GENAI_AVAILABLE = True except ImportError: GOOGLE_GENAI_AVAILABLE = False - # Create dummy types for type annotations when genai is not available - BatchJob = object - BatchCreateJobRequest = object + genai = None # type: ignore from crewai.utilities.events.llm_events import ( LLMCallCompletedEvent, @@ -380,8 +377,8 @@ class LLM(BaseLLM): self.batch_mode = batch_mode self.batch_size = batch_size or 10 self.batch_timeout = batch_timeout - self._batch_requests = [] - self._current_batch_job = None + self._batch_requests: List[Dict[str, Any]] = [] + self._current_batch_job: Optional[str] = None litellm.drop_params = True @@ -474,7 +471,7 @@ class LLM(BaseLLM): formatted_messages = self._format_messages_for_provider(messages) - request = { + request: Dict[str, Any] = { "contents": [], "generationConfig": { "temperature": self.temperature, @@ -505,32 +502,37 @@ class LLM(BaseLLM): raise ValueError("API key is required for batch mode") genai.configure(api_key=self.api_key) + client = genai.Client() - batch_request = BatchCreateJobRequest( - requests=requests, - display_name=f"crewai-batch-{int(time.time())}" + batch_job = client.batches.create( + model=f"models/{self.model.replace('gemini/', '')}", + src=requests, + config={ + 'display_name': f"crewai-batch-{int(time.time())}" + } ) - batch_job = genai.create_batch_job(batch_request) return batch_job.name - def _poll_batch_job(self, job_name: str) -> BatchJob: + def _poll_batch_job(self, job_name: str) -> Any: """Poll batch job status until completion.""" if not GOOGLE_GENAI_AVAILABLE: raise ImportError("google-generativeai is required for batch mode") genai.configure(api_key=self.api_key) + client = genai.Client() start_time = time.time() - while time.time() - start_time < self.batch_timeout: - batch_job = genai.get_batch_job(job_name) + timeout = self.batch_timeout or 300 + while time.time() - start_time < timeout: + batch_job = client.batches.get(name=job_name) - if batch_job.state in ["JOB_STATE_SUCCEEDED", "JOB_STATE_FAILED", "JOB_STATE_CANCELLED"]: + if batch_job.state.name in ["JOB_STATE_SUCCEEDED", "JOB_STATE_FAILED", "JOB_STATE_CANCELLED"]: return batch_job time.sleep(5) - raise TimeoutError(f"Batch job {job_name} did not complete within {self.batch_timeout} seconds") + raise TimeoutError(f"Batch job {job_name} did not complete within {timeout} seconds") def _retrieve_batch_results(self, job_name: str) -> List[str]: """Retrieve results from a completed batch job.""" @@ -538,22 +540,29 @@ class LLM(BaseLLM): raise ImportError("google-generativeai is required for batch mode") genai.configure(api_key=self.api_key) + client = genai.Client() - batch_job = genai.get_batch_job(job_name) + batch_job = client.batches.get(name=job_name) - if batch_job.state != "JOB_STATE_SUCCEEDED": - raise RuntimeError(f"Batch job failed with state: {batch_job.state}") + if batch_job.state.name != "JOB_STATE_SUCCEEDED": + raise RuntimeError(f"Batch job failed with state: {batch_job.state.name}") results = [] - for response in genai.list_batch_job_responses(job_name): - if response.response and response.response.candidates: - content = response.response.candidates[0].content - if content and content.parts: - results.append(content.parts[0].text) + if batch_job.dest and batch_job.dest.inlined_responses: + for inline_response in batch_job.dest.inlined_responses: + if inline_response.response and hasattr(inline_response.response, 'text'): + results.append(inline_response.response.text) + elif inline_response.response and hasattr(inline_response.response, 'candidates'): + if inline_response.response.candidates and inline_response.response.candidates[0].content: + content = inline_response.response.candidates[0].content + if content.parts: + results.append(content.parts[0].text) + else: + results.append("") + else: + results.append("") else: results.append("") - else: - results.append("") return results diff --git a/tests/test_batch_mode.py b/tests/test_batch_mode.py index bb721c83b..3d8cd4f8d 100644 --- a/tests/test_batch_mode.py +++ b/tests/test_batch_mode.py @@ -1,7 +1,6 @@ import pytest -import time -from unittest.mock import Mock, patch, MagicMock -from crewai.llm import LLM, BatchJobStartedEvent, BatchJobCompletedEvent, BatchJobFailedEvent +from unittest.mock import Mock, patch +from crewai.llm import LLM class TestBatchMode: