Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-10 00:28:31 +00:00
fix: resolve lint and type issues in Google Batch Mode implementation
- Fix unused imports in test_batch_mode.py
- Update Google GenAI API calls to use correct client.batches methods
- Add proper type annotations and error handling
- Ensure compatibility with google-generativeai SDK

Co-Authored-By: João <joao@crewai.com>
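As context for the diff, the caller-facing surface this commit touches looks roughly as follows. This is a hedged sketch: the constructor parameters are inferred from the attributes assigned in the hunks below, not from a documented public API.

from crewai.llm import LLM

# Hypothetical usage, inferred from the fields this commit sets
# (batch_mode, batch_size, batch_timeout); not a documented API.
llm = LLM(
    model="gemini/gemini-1.5-flash",  # assumed model id; the code strips the "gemini/" prefix
    batch_mode=True,                  # route calls through Google's batch endpoint
    batch_size=10,                    # the diff defaults to 10 when unset
    batch_timeout=300,                # seconds; the diff falls back to 300 when None
)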
In the LLM module (crewai.llm):

@@ -26,13 +26,10 @@ from pydantic import BaseModel, Field
 try:
     import google.generativeai as genai
-    from google.generativeai.types import BatchCreateJobRequest, BatchJob
 
     GOOGLE_GENAI_AVAILABLE = True
 except ImportError:
     GOOGLE_GENAI_AVAILABLE = False
-    # Create dummy types for type annotations when genai is not available
-    BatchJob = object
-    BatchCreateJobRequest = object
+    genai = None  # type: ignore
 
 from crewai.utilities.events.llm_events import (
     LLMCallCompletedEvent,
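With the dummy types gone, downstream code must gate on the availability flag at runtime instead of relying on stand-in classes. A minimal self-contained sketch of that guard pattern; the require_genai helper is illustrative, not part of the commit.

from typing import Any

try:
    import google.generativeai as genai
    GOOGLE_GENAI_AVAILABLE = True
except ImportError:
    GOOGLE_GENAI_AVAILABLE = False
    genai = None  # type: ignore


def require_genai() -> Any:
    # Fail fast when the optional dependency is missing, mirroring the
    # runtime checks the batch methods below perform.
    if not GOOGLE_GENAI_AVAILABLE:
        raise ImportError("google-generativeai is required for batch mode")
    return genai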
@@ -380,8 +377,8 @@ class LLM(BaseLLM):
         self.batch_mode = batch_mode
         self.batch_size = batch_size or 10
         self.batch_timeout = batch_timeout
-        self._batch_requests = []
-        self._current_batch_job = None
+        self._batch_requests: List[Dict[str, Any]] = []
+        self._current_batch_job: Optional[str] = None
 
         litellm.drop_params = True
 
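Read together, these annotations pin down the batch state each LLM instance carries. A self-contained sketch of the same shape, with the class name and constructor invented for illustration:

from typing import Any, Dict, List, Optional


class BatchState:
    # Illustrative container mirroring the fields typed in this hunk.
    def __init__(self, batch_size: Optional[int] = None,
                 batch_timeout: Optional[int] = None) -> None:
        self.batch_size: int = batch_size or 10
        self.batch_timeout: Optional[int] = batch_timeout
        # Pending request payloads, one dict per queued prompt.
        self._batch_requests: List[Dict[str, Any]] = []
        # Fully qualified job name returned by the batch API, once submitted.
        self._current_batch_job: Optional[str] = None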
@@ -474,7 +471,7 @@ class LLM(BaseLLM):
 
         formatted_messages = self._format_messages_for_provider(messages)
 
-        request = {
+        request: Dict[str, Any] = {
             "contents": [],
             "generationConfig": {
                 "temperature": self.temperature,
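The request dict is a Gemini-style generateContent payload. Below is a hedged sketch of one complete entry built in this shape; only the top-level keys come from the diff, while the helper name and the message-to-contents mapping are assumptions.

from typing import Any, Dict, List


def build_batch_request(messages: List[Dict[str, str]],
                        temperature: float = 0.7) -> Dict[str, Any]:
    # Assemble one Gemini-style request payload, as the hunk above begins to do.
    request: Dict[str, Any] = {
        "contents": [],
        "generationConfig": {
            "temperature": temperature,
        },
    }
    for message in messages:
        # Assumed mapping: OpenAI-style roles onto Gemini's user/model roles.
        role = "model" if message.get("role") == "assistant" else "user"
        request["contents"].append({
            "role": role,
            "parts": [{"text": message.get("content", "")}],
        })
    return request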
@@ -505,32 +502,37 @@ class LLM(BaseLLM):
             raise ValueError("API key is required for batch mode")
 
         genai.configure(api_key=self.api_key)
+        client = genai.Client()
 
-        batch_request = BatchCreateJobRequest(
-            requests=requests,
-            display_name=f"crewai-batch-{int(time.time())}"
+        batch_job = client.batches.create(
+            model=f"models/{self.model.replace('gemini/', '')}",
+            src=requests,
+            config={
+                'display_name': f"crewai-batch-{int(time.time())}"
+            }
         )
 
-        batch_job = genai.create_batch_job(batch_request)
         return batch_job.name
 
-    def _poll_batch_job(self, job_name: str) -> BatchJob:
+    def _poll_batch_job(self, job_name: str) -> Any:
         """Poll batch job status until completion."""
         if not GOOGLE_GENAI_AVAILABLE:
             raise ImportError("google-generativeai is required for batch mode")
 
         genai.configure(api_key=self.api_key)
+        client = genai.Client()
 
         start_time = time.time()
-        while time.time() - start_time < self.batch_timeout:
-            batch_job = genai.get_batch_job(job_name)
+        timeout = self.batch_timeout or 300
+        while time.time() - start_time < timeout:
+            batch_job = client.batches.get(name=job_name)
 
-            if batch_job.state in ["JOB_STATE_SUCCEEDED", "JOB_STATE_FAILED", "JOB_STATE_CANCELLED"]:
+            if batch_job.state.name in ["JOB_STATE_SUCCEEDED", "JOB_STATE_FAILED", "JOB_STATE_CANCELLED"]:
                 return batch_job
 
             time.sleep(5)
 
-        raise TimeoutError(f"Batch job {job_name} did not complete within {self.batch_timeout} seconds")
+        raise TimeoutError(f"Batch job {job_name} did not complete within {timeout} seconds")
 
     def _retrieve_batch_results(self, job_name: str) -> List[str]:
         """Retrieve results from a completed batch job."""
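Condensed, the submit-and-poll flow these two methods implement looks like the sketch below. It reuses the diff's own client.batches calls; whether google.generativeai actually exposes Client() depends on the installed SDK version, so treat this as illustrative rather than verified API usage.

import time
from typing import Any, Dict, List

import google.generativeai as genai  # assumes this SDK exposes genai.Client(), as the diff does


def run_batch(api_key: str, model: str, requests: List[Dict[str, Any]],
              timeout: int = 300) -> Any:
    # Submit a batch job and poll until it reaches a terminal state.
    genai.configure(api_key=api_key)
    client = genai.Client()

    job = client.batches.create(
        model=f"models/{model.replace('gemini/', '')}",
        src=requests,
        config={"display_name": f"crewai-batch-{int(time.time())}"},
    )

    start = time.time()
    while time.time() - start < timeout:
        job = client.batches.get(name=job.name)
        if job.state.name in ("JOB_STATE_SUCCEEDED", "JOB_STATE_FAILED",
                              "JOB_STATE_CANCELLED"):
            return job
        time.sleep(5)  # same 5-second poll interval as the diff
    raise TimeoutError(f"Batch job {job.name} did not complete within {timeout} seconds")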
@@ -538,22 +540,29 @@ class LLM(BaseLLM):
             raise ImportError("google-generativeai is required for batch mode")
 
         genai.configure(api_key=self.api_key)
+        client = genai.Client()
 
-        batch_job = genai.get_batch_job(job_name)
+        batch_job = client.batches.get(name=job_name)
 
-        if batch_job.state != "JOB_STATE_SUCCEEDED":
-            raise RuntimeError(f"Batch job failed with state: {batch_job.state}")
+        if batch_job.state.name != "JOB_STATE_SUCCEEDED":
+            raise RuntimeError(f"Batch job failed with state: {batch_job.state.name}")
 
         results = []
-        for response in genai.list_batch_job_responses(job_name):
-            if response.response and response.response.candidates:
-                content = response.response.candidates[0].content
-                if content and content.parts:
-                    results.append(content.parts[0].text)
-                else:
-                    results.append("")
-            else:
-                results.append("")
+        if batch_job.dest and batch_job.dest.inlined_responses:
+            for inline_response in batch_job.dest.inlined_responses:
+                if inline_response.response and hasattr(inline_response.response, 'text'):
+                    results.append(inline_response.response.text)
+                elif inline_response.response and hasattr(inline_response.response, 'candidates'):
+                    if inline_response.response.candidates and inline_response.response.candidates[0].content:
+                        content = inline_response.response.candidates[0].content
+                        if content.parts:
+                            results.append(content.parts[0].text)
+                        else:
+                            results.append("")
+                    else:
+                        results.append("")
+                else:
+                    results.append("")
 
         return results
 
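The nested branching amounts to one extraction rule per response shape. A flattened, independently testable sketch of the same logic; the helper name is illustrative.

from typing import Any


def extract_text(response: Any) -> str:
    # Return the first text payload from a batch response, else "", mirroring the hunk above.
    if response is None:
        return ""
    # Some response objects expose a convenience .text accessor.
    if hasattr(response, "text"):
        return response.text
    # Otherwise fall back to the candidates -> content -> parts chain.
    if hasattr(response, "candidates") and response.candidates:
        content = response.candidates[0].content
        if content and content.parts:
            return content.parts[0].text
    return ""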
In test_batch_mode.py:

@@ -1,7 +1,6 @@
 import pytest
-import time
-from unittest.mock import Mock, patch, MagicMock
-from crewai.llm import LLM, BatchJobStartedEvent, BatchJobCompletedEvent, BatchJobFailedEvent
+from unittest.mock import Mock, patch
+from crewai.llm import LLM
 
 
 class TestBatchMode:
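With the imports trimmed to Mock and patch, tests in this file can exercise the flow without network access. A hedged sketch; the patch targets and the private method name assume the module-level layout introduced in this commit.

import pytest
from unittest.mock import Mock, patch

from crewai.llm import LLM


class TestBatchMode:
    def test_poll_times_out(self):
        # Hypothetical test: a job stuck in a non-terminal state should
        # raise TimeoutError once batch_timeout elapses.
        llm = LLM(model="gemini/gemini-1.5-flash", batch_mode=True, batch_timeout=1)
        running_job = Mock()
        running_job.state.name = "JOB_STATE_RUNNING"
        fake_client = Mock()
        fake_client.batches.get.return_value = running_job
        with patch("crewai.llm.GOOGLE_GENAI_AVAILABLE", True), \
             patch("crewai.llm.genai") as fake_genai, \
             patch("crewai.llm.time.sleep"):  # skip the 5-second poll delay
            fake_genai.Client.return_value = fake_client
            with pytest.raises(TimeoutError):
                llm._poll_batch_job("batches/fake-job")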