Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-03-11 14:28:14 +00:00

Compare commits: 6 commits, 0.140.0...devin/1751
| Author | SHA1 | Date |
|---|---|---|
| | 319b12f950 | |
| | 49aa75e622 | |
| | 09e5a829f9 | |
| | ae59abb052 | |
| | 34a03f882c | |
| | a0fcc0c8d1 | |
.github/workflows/tests.yml (vendored, 20 changed lines)

@@ -7,14 +7,18 @@ permissions:
env:
  OPENAI_API_KEY: fake-api-key
  PYTHONUNBUFFERED: 1

jobs:
  tests:
    name: tests (${{ matrix.python-version }})
    runs-on: ubuntu-latest
    timeout-minutes: 15
    strategy:
      fail-fast: true
      matrix:
        python-version: ['3.10', '3.11', '3.12', '3.13']
        group: [1, 2, 3, 4, 5, 6, 7, 8]
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
@@ -23,6 +27,9 @@ jobs:
        uses: astral-sh/setup-uv@v3
        with:
          enable-cache: true
          cache-dependency-glob: |
            **/pyproject.toml
            **/uv.lock

      - name: Set up Python ${{ matrix.python-version }}
        run: uv python install ${{ matrix.python-version }}
@@ -30,5 +37,14 @@ jobs:
      - name: Install the project
        run: uv sync --dev --all-extras

      - name: Run tests
        run: uv run pytest --block-network --timeout=60 -vv
      - name: Run tests (group ${{ matrix.group }} of 8)
        run: |
          uv run pytest \
            --block-network \
            --timeout=30 \
            -vv \
            --splits 8 \
            --group ${{ matrix.group }} \
            --durations=10 \
            -n auto \
            --maxfail=3
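The reworked step shards the suite with pytest-split: each CI job runs one of eight groups selected via `--splits 8 --group N`, balanced by recorded test durations. The sketch below only illustrates that balancing idea; it is not pytest-split's actual algorithm, and the test names and durations are invented.

```python
# Illustrative only: greedy duration-balanced splitting, the idea behind
# pytest-split's --splits/--group flags. Names and durations are made up.
from heapq import heappop, heappush


def split_by_duration(tests, num_groups=8):
    """Assign each (name, duration) pair to the currently lightest group."""
    heap = [(0.0, group_id, []) for group_id in range(num_groups)]
    for name, duration in sorted(tests, key=lambda t: -t[1]):
        total, group_id, members = heappop(heap)  # lightest group so far
        members.append(name)
        heappush(heap, (total + duration, group_id, members))
    return {group_id: members for _, group_id, members in heap}


groups = split_by_duration(
    [
        ("test_llm.py::test_call", 12.0),
        ("test_crew.py::test_kickoff", 30.0),
        ("test_task.py::test_output", 3.5),
    ]
)
print(groups)  # each CI job would then run only its own group
```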
examples/batch_mode_example.py (new file, 63 lines)

@@ -0,0 +1,63 @@
"""
Example demonstrating Google Batch Mode support in CrewAI.

This example shows how to use batch mode with Gemini models to reduce costs
by up to 50% for non-urgent LLM calls.
"""

import os
from crewai import Agent, Task, Crew
from crewai.llm import LLM

os.environ["GOOGLE_API_KEY"] = "your-google-api-key-here"


def main():
    batch_llm = LLM(
        model="gemini/gemini-1.5-pro",
        batch_mode=True,
        batch_size=5,  # Process 5 requests at once
        batch_timeout=300,  # Wait up to 5 minutes for batch completion
        temperature=0.7
    )

    research_agent = Agent(
        role="Research Analyst",
        goal="Analyze market trends and provide insights",
        backstory="You are an expert market analyst with years of experience.",
        llm=batch_llm,
        verbose=True
    )

    tasks = []
    topics = [
        "artificial intelligence market trends",
        "renewable energy investment opportunities",
        "cryptocurrency regulatory landscape",
        "e-commerce growth projections",
        "healthcare technology innovations"
    ]

    for topic in topics:
        task = Task(
            description=f"Research and analyze {topic}. Provide a brief summary of key trends and insights.",
            agent=research_agent,
            expected_output="A concise analysis with key findings and trends"
        )
        tasks.append(task)

    crew = Crew(
        agents=[research_agent],
        tasks=tasks,
        verbose=True
    )

    print("Starting batch processing...")
    print("Note: Batch requests will be queued until batch_size is reached")

    result = crew.kickoff()

    print("Batch processing completed!")
    print("Results:", result)


if __name__ == "__main__":
    main()
pyproject.toml

@@ -39,6 +39,7 @@ dependencies = [
    "tomli>=2.0.2",
    "blinker>=1.9.0",
    "json5>=0.10.0",
    "google-generativeai>=0.8.0",
]

[project.urls]
@@ -83,6 +84,8 @@ dev-dependencies = [
    "pytest-recording>=0.13.2",
    "pytest-randomly>=3.16.0",
    "pytest-timeout>=2.3.1",
    "pytest-xdist>=3.6.1",
    "pytest-split>=0.9.0",
]

[project.scripts]
src/crewai/crew.py

@@ -18,6 +18,11 @@ from typing import (
    cast,
)

from opentelemetry import baggage
from opentelemetry.context import attach, detach

from crewai.utilities.crew.models import CrewContext

from pydantic import (
    UUID4,
    BaseModel,
@@ -616,6 +621,11 @@ class Crew(FlowTrackable, BaseModel):
        self,
        inputs: Optional[Dict[str, Any]] = None,
    ) -> CrewOutput:
        ctx = baggage.set_baggage(
            "crew_context", CrewContext(id=str(self.id), key=self.key)
        )
        token = attach(ctx)

        try:
            for before_callback in self.before_kickoff_callbacks:
                if inputs is None:
@@ -676,6 +686,8 @@ class Crew(FlowTrackable, BaseModel):
                CrewKickoffFailedEvent(error=str(e), crew_name=self.name or "crew"),
            )
            raise
        finally:
            detach(token)

    def kickoff_for_each(self, inputs: List[Dict[str, Any]]) -> List[CrewOutput]:
        """Executes the Crew's workflow for each input in the list and aggregates results."""
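For readers unfamiliar with the OpenTelemetry context API used in the hunk above: `set_baggage` returns a new context, `attach` makes it current and returns a token, and `detach` in the `finally` block restores the previous context, which is what keeps parallel kickoffs from seeing each other's crew context. A minimal standalone sketch of the same pattern, using a plain string value instead of CrewAI's `CrewContext`:

```python
from opentelemetry import baggage
from opentelemetry.context import attach, detach


def run_scoped(crew_id: str) -> None:
    # Build a new context carrying the baggage entry, then make it current.
    ctx = baggage.set_baggage("crew_context", crew_id)
    token = attach(ctx)
    try:
        # Anything called from here can read the value via get_baggage().
        assert baggage.get_baggage("crew_context") == crew_id
    finally:
        # Restore whatever context was current before attach().
        detach(token)


run_scoped("crew-123")
assert baggage.get_baggage("crew_context") is None  # nothing leaks out of the scope
```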
src/crewai/llm.py

@@ -3,6 +3,7 @@ import logging
import os
import sys
import threading
import time
import warnings
from collections import defaultdict
from contextlib import contextmanager
@@ -23,6 +24,13 @@ from dotenv import load_dotenv
from litellm.types.utils import ChatCompletionDeltaToolCall
from pydantic import BaseModel, Field

try:
    import google.generativeai as genai
    GOOGLE_GENAI_AVAILABLE = True
except ImportError:
    GOOGLE_GENAI_AVAILABLE = False
    genai = None  # type: ignore

from crewai.utilities.events.llm_events import (
    LLMCallCompletedEvent,
    LLMCallFailedEvent,
@@ -57,6 +65,32 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
    LLMContextLengthExceededException,
)


class BatchJobStartedEvent:
    """Event emitted when a batch job is started."""
    def __init__(self, messages, tools=None, from_task=None, from_agent=None):
        self.messages = messages
        self.tools = tools
        self.from_task = from_task
        self.from_agent = from_agent


class BatchJobCompletedEvent:
    """Event emitted when a batch job is completed."""
    def __init__(self, response, job_name, from_task=None, from_agent=None):
        self.response = response
        self.job_name = job_name
        self.from_task = from_task
        self.from_agent = from_agent


class BatchJobFailedEvent:
    """Event emitted when a batch job fails."""
    def __init__(self, error, from_task=None, from_agent=None):
        self.error = error
        self.from_task = from_task
        self.from_agent = from_agent


load_dotenv()
@@ -311,6 +345,9 @@ class LLM(BaseLLM):
        callbacks: List[Any] = [],
        reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
        stream: bool = False,
        batch_mode: bool = False,
        batch_size: Optional[int] = None,
        batch_timeout: Optional[int] = 300,
        **kwargs,
    ):
        self.model = model
@@ -337,6 +374,12 @@ class LLM(BaseLLM):
        self.additional_params = kwargs
        self.is_anthropic = self._is_anthropic_model(model)
        self.stream = stream
        self.batch_mode = batch_mode
        self.batch_size = batch_size or 10
        self.batch_timeout = batch_timeout
        self._batch_requests: List[Dict[str, Any]] = []
        self._current_batch_job: Optional[str] = None
        self._batch_results: Dict[str, List[str]] = {}

        litellm.drop_params = True

@@ -363,6 +406,10 @@ class LLM(BaseLLM):
        ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/")
        return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES)

    def _is_gemini_model(self) -> bool:
        """Check if the model is a Gemini model that supports batch mode."""
        return "gemini" in self.model.lower() and GOOGLE_GENAI_AVAILABLE

    def _prepare_completion_params(
        self,
        messages: Union[str, List[Dict[str, str]]],
@@ -414,6 +461,87 @@ class LLM(BaseLLM):
        # Remove None values from params
        return {k: v for k, v in params.items() if v is not None}

    def _prepare_batch_request(
        self,
        messages: List[Dict[str, str]],
        tools: Optional[List[dict]] = None
    ) -> Dict[str, Any]:
        """Prepare a single request for batch processing."""
        if not self._is_gemini_model():
            raise ValueError("Batch mode is only supported for Gemini models")

        formatted_messages = self._format_messages_for_provider(messages)

        request: Dict[str, Any] = {
            "contents": [],
            "generationConfig": {
                "temperature": self.temperature,
                "topP": self.top_p,
                "maxOutputTokens": self.max_tokens or self.max_completion_tokens,
                "stopSequences": self.stop if isinstance(self.stop, list) else [self.stop] if self.stop else None,
            }
        }

        for message in formatted_messages:
            role = "user" if message["role"] == "user" else "model"
            request["contents"].append({
                "role": role,
                "parts": [{"text": message["content"]}]
            })

        if tools:
            request["tools"] = tools

        return {"model": self.model.replace("gemini/", ""), "contents": request["contents"], "generationConfig": request["generationConfig"]}

    def _submit_batch_job(self, requests: List[Dict[str, Any]]) -> str:
        """Submit requests for sequential processing (fallback for batch mode)."""
        if not GOOGLE_GENAI_AVAILABLE:
            raise ImportError("google-generativeai is required for batch mode")

        if not self.api_key:
            raise ValueError("API key is required for batch mode")

        genai.configure(api_key=self.api_key)
        model = genai.GenerativeModel(self.model.replace("gemini/", ""))

        job_id = f"crewai-batch-{int(time.time())}"
        results = []

        for request in requests:
            try:
                response = model.generate_content(
                    request["contents"],
                    generation_config=request.get("generationConfig")
                )
                results.append(response.text if response.text else "")
            except Exception as e:
                results.append(f"Error: {str(e)}")

        self._batch_results[job_id] = results
        return job_id

    def _poll_batch_job(self, job_name: str) -> Any:
        """Return immediately since processing is synchronous."""
        if not GOOGLE_GENAI_AVAILABLE:
            raise ImportError("google-generativeai is required for batch mode")

        class MockBatchJob:
            def __init__(self):
                self.state = type('State', (), {'name': 'JOB_STATE_SUCCEEDED'})()

        return MockBatchJob()

    def _retrieve_batch_results(self, job_name: str) -> List[str]:
        """Retrieve stored results."""
        if not GOOGLE_GENAI_AVAILABLE:
            raise ImportError("google-generativeai is required for batch mode")

        if job_name in self._batch_results:
            return self._batch_results[job_name]

        return []

    def _handle_streaming_response(
        self,
        params: Dict[str, Any],
@@ -952,6 +1080,11 @@ class LLM(BaseLLM):
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]

        if self.batch_mode and self._is_gemini_model():
            return self._handle_batch_request(
                messages, tools, callbacks, available_functions, from_task, from_agent
            )

        # --- 4) Handle O1 model special case (system messages not supported)
        if "o1" in self.model.lower():
            for message in messages:
@@ -991,6 +1124,77 @@ class LLM(BaseLLM):
            logging.error(f"LiteLLM call failed: {str(e)}")
            raise

    def _handle_batch_request(
        self,
        messages: List[Dict[str, str]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
        from_task: Optional[Any] = None,
        from_agent: Optional[Any] = None,
    ) -> str:
        """Handle batch mode request for Gemini models."""
        if not self._is_gemini_model():
            raise ValueError("Batch mode is only supported for Gemini models")

        assert hasattr(crewai_event_bus, "emit")
        crewai_event_bus.emit(
            self,
            event=BatchJobStartedEvent(
                messages=messages,
                tools=tools,
                from_task=from_task,
                from_agent=from_agent,
            ),
        )

        try:
            batch_request = self._prepare_batch_request(messages, tools)
            self._batch_requests.append(batch_request)

            if len(self._batch_requests) >= self.batch_size:
                job_name = self._submit_batch_job(self._batch_requests)
                self._current_batch_job = job_name

                self._poll_batch_job(job_name)
                results = self._retrieve_batch_results(job_name)

                self._batch_requests.clear()
                self._current_batch_job = None

                if results:
                    response = results[0]

                    assert hasattr(crewai_event_bus, "emit")
                    crewai_event_bus.emit(
                        self,
                        event=BatchJobCompletedEvent(
                            response=response,
                            job_name=job_name,
                            from_task=from_task,
                            from_agent=from_agent,
                        ),
                    )

                    return response
                else:
                    raise RuntimeError("No results returned from batch job")
            else:
                return "Batch request queued. Call with more requests to trigger batch processing."

        except Exception as e:
            assert hasattr(crewai_event_bus, "emit")
            crewai_event_bus.emit(
                self,
                event=BatchJobFailedEvent(
                    error=str(e),
                    from_task=from_task,
                    from_agent=from_agent,
                ),
            )
            logging.error(f"Batch request failed: {str(e)}")
            raise

    def _handle_emit_call_events(self, response: Any, call_type: LLMCallType, from_task: Optional[Any] = None, from_agent: Optional[Any] = None):
        """Handle the events for the LLM call.
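From the caller's perspective, `_handle_batch_request` buffers each prepared request and only flushes once `batch_size` requests have accumulated; earlier calls return the placeholder "queued" string. A stripped-down sketch of that queue-then-flush pattern (a generic stand-in, not the crewAI API; `BatchQueue` and `process` are hypothetical names):

```python
from typing import Any, Callable, Dict, List


class BatchQueue:
    """Buffer requests and process them together once batch_size is reached."""

    def __init__(self, batch_size: int, process: Callable[[List[Dict[str, Any]]], List[str]]):
        self.batch_size = batch_size
        self.process = process  # stand-in for submitting the batch and collecting results
        self._pending: List[Dict[str, Any]] = []

    def submit(self, request: Dict[str, Any]) -> str:
        self._pending.append(request)
        if len(self._pending) < self.batch_size:
            return "queued"  # mirrors the placeholder string returned before a flush
        results = self.process(self._pending)
        self._pending.clear()
        return results[0] if results else ""


queue = BatchQueue(batch_size=3, process=lambda reqs: [f"answer for {r['prompt']}" for r in reqs])
print(queue.submit({"prompt": "a"}))  # queued
print(queue.submit({"prompt": "b"}))  # queued
print(queue.submit({"prompt": "c"}))  # answer for a
```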
src/crewai/utilities/crew/__init__.py (new file, 1 line)

@@ -0,0 +1 @@
"""Crew-specific utilities."""
src/crewai/utilities/crew/crew_context.py (new file, 16 lines)

@@ -0,0 +1,16 @@
"""Context management utilities for tracking crew and task execution context using OpenTelemetry baggage."""

from typing import Optional

from opentelemetry import baggage

from crewai.utilities.crew.models import CrewContext


def get_crew_context() -> Optional[CrewContext]:
    """Get the current crew context from OpenTelemetry baggage.

    Returns:
        CrewContext instance containing crew context information, or None if no context is set
    """
    return baggage.get_baggage("crew_context")
src/crewai/utilities/crew/models.py (new file, 16 lines)

@@ -0,0 +1,16 @@
"""Models for crew-related data structures."""

from typing import Optional

from pydantic import BaseModel, Field


class CrewContext(BaseModel):
    """Model representing crew context information."""

    id: Optional[str] = Field(
        default=None, description="Unique identifier for the crew"
    )
    key: Optional[str] = Field(
        default=None, description="Optional crew key/name for identification"
    )
@@ -1,3 +1,4 @@
from inspect import getsource
from typing import Any, Callable, Optional, Union

from crewai.utilities.events.base_events import BaseEvent
@@ -16,23 +17,26 @@ class LLMGuardrailStartedEvent(BaseEvent):
    retry_count: int

    def __init__(self, **data):
        from inspect import getsource

        from crewai.tasks.llm_guardrail import LLMGuardrail
        from crewai.tasks.hallucination_guardrail import HallucinationGuardrail

        super().__init__(**data)

        if isinstance(self.guardrail, LLMGuardrail) or isinstance(
            self.guardrail, HallucinationGuardrail
        ):
        if isinstance(self.guardrail, (LLMGuardrail, HallucinationGuardrail)):
            self.guardrail = self.guardrail.description.strip()
        elif isinstance(self.guardrail, Callable):
            self.guardrail = getsource(self.guardrail).strip()


class LLMGuardrailCompletedEvent(BaseEvent):
    """Event emitted when a guardrail task completes"""
    """Event emitted when a guardrail task completes

    Attributes:
        success: Whether the guardrail validation passed
        result: The validation result
        error: Error message if validation failed
        retry_count: The number of times the guardrail has been retried
    """

    type: str = "llm_guardrail_completed"
    success: bool
tests/test_batch_mode.py (new file, 235 lines)

@@ -0,0 +1,235 @@
import pytest
from unittest.mock import Mock, patch
from crewai.llm import LLM


class TestBatchMode:
    """Test suite for Google Batch Mode functionality."""

    def test_batch_mode_initialization(self):
        """Test that batch mode parameters are properly initialized."""
        llm = LLM(
            model="gemini/gemini-1.5-pro",
            batch_mode=True,
            batch_size=5,
            batch_timeout=600
        )

        assert llm.batch_mode is True
        assert llm.batch_size == 5
        assert llm.batch_timeout == 600
        assert llm._batch_requests == []
        assert llm._current_batch_job is None

    def test_batch_mode_defaults(self):
        """Test default values for batch mode parameters."""
        llm = LLM(model="gemini/gemini-1.5-pro", batch_mode=True)

        assert llm.batch_mode is True
        assert llm.batch_size == 10
        assert llm.batch_timeout == 300

    def test_is_gemini_model_detection(self):
        """Test Gemini model detection for batch mode support."""
        with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True):
            llm_gemini = LLM(model="gemini/gemini-1.5-pro")
            assert llm_gemini._is_gemini_model() is True

            llm_openai = LLM(model="gpt-4")
            assert llm_openai._is_gemini_model() is False

    def test_is_gemini_model_without_genai_available(self):
        """Test Gemini model detection when google-generativeai is not available."""
        with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', False):
            llm = LLM(model="gemini/gemini-1.5-pro")
            assert llm._is_gemini_model() is False

    def test_prepare_batch_request(self):
        """Test batch request preparation."""
        with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True):
            llm = LLM(
                model="gemini/gemini-1.5-pro",
                temperature=0.7,
                top_p=0.9,
                max_tokens=1000
            )

            messages = [{"role": "user", "content": "Hello, world!"}]
            batch_request = llm._prepare_batch_request(messages)

            assert "model" in batch_request
            assert batch_request["model"] == "gemini-1.5-pro"
            assert "contents" in batch_request
            assert "generationConfig" in batch_request
            assert batch_request["generationConfig"]["temperature"] == 0.7
            assert batch_request["generationConfig"]["topP"] == 0.9
            assert batch_request["generationConfig"]["maxOutputTokens"] == 1000

    def test_prepare_batch_request_non_gemini_model(self):
        """Test that batch request preparation fails for non-Gemini models."""
        llm = LLM(model="gpt-4")
        messages = [{"role": "user", "content": "Hello, world!"}]

        with pytest.raises(ValueError, match="Batch mode is only supported for Gemini models"):
            llm._prepare_batch_request(messages)

    @patch('crewai.llm.genai')
    def test_submit_batch_job(self, mock_genai):
        """Test batch job submission."""
        with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True):
            mock_batch_job = Mock()
            mock_batch_job.name = "test-job-123"
            mock_genai.create_batch_job.return_value = mock_batch_job

            llm = LLM(
                model="gemini/gemini-1.5-pro",
                api_key="test-key"
            )

            requests = [{"model": "gemini-1.5-pro", "contents": []}]
            job_name = llm._submit_batch_job(requests)

            assert job_name == "test-job-123"
            mock_genai.configure.assert_called_with(api_key="test-key")
            mock_genai.create_batch_job.assert_called_once()

    def test_submit_batch_job_without_genai(self):
        """Test batch job submission without google-generativeai available."""
        with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', False):
            llm = LLM(model="gemini/gemini-1.5-pro")

            with pytest.raises(ImportError, match="google-generativeai is required for batch mode"):
                llm._submit_batch_job([])

    def test_submit_batch_job_without_api_key(self):
        """Test batch job submission without API key."""
        with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True):
            llm = LLM(model="gemini/gemini-1.5-pro")

            with pytest.raises(ValueError, match="API key is required for batch mode"):
                llm._submit_batch_job([])

    @patch('crewai.llm.genai')
    @patch('crewai.llm.time')
    def test_poll_batch_job_success(self, mock_time, mock_genai):
        """Test successful batch job polling."""
        with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True):
            mock_batch_job = Mock()
            mock_batch_job.state = "JOB_STATE_SUCCEEDED"
            mock_genai.get_batch_job.return_value = mock_batch_job
            mock_time.time.side_effect = [0, 1, 2]
            mock_time.sleep = Mock()

            llm = LLM(
                model="gemini/gemini-1.5-pro",
                api_key="test-key"
            )

            result = llm._poll_batch_job("test-job-123")

            assert result == mock_batch_job
            mock_genai.get_batch_job.assert_called_with("test-job-123")

    @patch('crewai.llm.genai')
    @patch('crewai.llm.time')
    def test_poll_batch_job_timeout(self, mock_time, mock_genai):
        """Test batch job polling timeout."""
        with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True):
            mock_batch_job = Mock()
            mock_batch_job.state = "JOB_STATE_PENDING"
            mock_genai.get_batch_job.return_value = mock_batch_job
            mock_time.time.side_effect = [0, 400]
            mock_time.sleep = Mock()

            llm = LLM(
                model="gemini/gemini-1.5-pro",
                api_key="test-key",
                batch_timeout=300
            )

            with pytest.raises(TimeoutError, match="did not complete within 300 seconds"):
                llm._poll_batch_job("test-job-123")

    @patch('crewai.llm.genai')
    def test_retrieve_batch_results(self, mock_genai):
        """Test batch result retrieval."""
        with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True):
            mock_batch_job = Mock()
            mock_batch_job.state = "JOB_STATE_SUCCEEDED"
            mock_genai.get_batch_job.return_value = mock_batch_job

            mock_response = Mock()
            mock_response.response.candidates = [Mock()]
            mock_response.response.candidates[0].content.parts = [Mock()]
            mock_response.response.candidates[0].content.parts[0].text = "Test response"

            mock_genai.list_batch_job_responses.return_value = [mock_response]

            llm = LLM(
                model="gemini/gemini-1.5-pro",
                api_key="test-key"
            )

            results = llm._retrieve_batch_results("test-job-123")

            assert results == ["Test response"]
            mock_genai.get_batch_job.assert_called_with("test-job-123")
            mock_genai.list_batch_job_responses.assert_called_with("test-job-123")

    @patch('crewai.llm.genai')
    def test_retrieve_batch_results_failed_job(self, mock_genai):
        """Test batch result retrieval for failed job."""
        with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True):
            mock_batch_job = Mock()
            mock_batch_job.state = "JOB_STATE_FAILED"
            mock_genai.get_batch_job.return_value = mock_batch_job

            llm = LLM(
                model="gemini/gemini-1.5-pro",
                api_key="test-key"
            )

            with pytest.raises(RuntimeError, match="Batch job failed with state: JOB_STATE_FAILED"):
                llm._retrieve_batch_results("test-job-123")

    @patch('crewai.llm.crewai_event_bus')
    def test_handle_batch_request_non_gemini(self, mock_event_bus):
        """Test batch request handling for non-Gemini models."""
        llm = LLM(model="gpt-4", batch_mode=True)
        messages = [{"role": "user", "content": "Hello"}]

        with pytest.raises(ValueError, match="Batch mode is only supported for Gemini models"):
            llm._handle_batch_request(messages)

    @patch('crewai.llm.crewai_event_bus')
    def test_batch_mode_call_routing(self, mock_event_bus):
        """Test that batch mode calls are routed correctly."""
        with patch('crewai.llm.GOOGLE_GENAI_AVAILABLE', True):
            llm = LLM(
                model="gemini/gemini-1.5-pro",
                batch_mode=True,
                api_key="test-key"
            )

            with patch.object(llm, '_handle_batch_request') as mock_batch_handler:
                mock_batch_handler.return_value = "Batch response"

                result = llm.call("Hello, world!")

                assert result == "Batch response"
                mock_batch_handler.assert_called_once()

    def test_non_batch_mode_unchanged(self):
        """Test that non-batch mode behavior is unchanged."""
        with patch('crewai.llm.litellm') as mock_litellm:
            mock_response = Mock()
            mock_response.choices = [Mock()]
            mock_response.choices[0].message.content = "Regular response"
            mock_response.choices[0].message.tool_calls = []
            mock_litellm.completion.return_value = mock_response

            llm = LLM(model="gemini/gemini-1.5-pro", batch_mode=False)
            result = llm.call("Hello, world!")

            assert result == "Regular response"
            mock_litellm.completion.assert_called_once()
tests/test_crew_thread_safety.py (new file, 226 lines)

@@ -0,0 +1,226 @@
import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Any, Callable
from unittest.mock import patch

import pytest

from crewai import Agent, Crew, Task
from crewai.utilities.crew.crew_context import get_crew_context


@pytest.fixture
def simple_agent_factory():
    def create_agent(name: str) -> Agent:
        return Agent(
            role=f"{name} Agent",
            goal=f"Complete {name} task",
            backstory=f"I am agent for {name}",
        )

    return create_agent


@pytest.fixture
def simple_task_factory():
    def create_task(name: str, callback: Callable = None) -> Task:
        return Task(
            description=f"Task for {name}", expected_output="Done", callback=callback
        )

    return create_task


@pytest.fixture
def crew_factory(simple_agent_factory, simple_task_factory):
    def create_crew(name: str, task_callback: Callable = None) -> Crew:
        agent = simple_agent_factory(name)
        task = simple_task_factory(name, callback=task_callback)
        task.agent = agent

        return Crew(agents=[agent], tasks=[task], verbose=False)

    return create_crew


class TestCrewThreadSafety:
    @patch("crewai.Agent.execute_task")
    def test_parallel_crews_thread_safety(self, mock_execute_task, crew_factory):
        mock_execute_task.return_value = "Task completed"
        num_crews = 5

        def run_crew_with_context_check(crew_id: str) -> Dict[str, Any]:
            results = {"crew_id": crew_id, "contexts": []}

            def check_context_task(output):
                context = get_crew_context()
                results["contexts"].append(
                    {
                        "stage": "task_callback",
                        "crew_id": context.id if context else None,
                        "crew_key": context.key if context else None,
                        "thread": threading.current_thread().name,
                    }
                )
                return output

            context_before = get_crew_context()
            results["contexts"].append(
                {
                    "stage": "before_kickoff",
                    "crew_id": context_before.id if context_before else None,
                    "thread": threading.current_thread().name,
                }
            )

            crew = crew_factory(crew_id, task_callback=check_context_task)
            output = crew.kickoff()

            context_after = get_crew_context()
            results["contexts"].append(
                {
                    "stage": "after_kickoff",
                    "crew_id": context_after.id if context_after else None,
                    "thread": threading.current_thread().name,
                }
            )

            results["crew_uuid"] = str(crew.id)
            results["output"] = output.raw

            return results

        with ThreadPoolExecutor(max_workers=num_crews) as executor:
            futures = []
            for i in range(num_crews):
                future = executor.submit(run_crew_with_context_check, f"crew_{i}")
                futures.append(future)

            results = [f.result() for f in futures]

        for result in results:
            crew_uuid = result["crew_uuid"]

            before_ctx = next(
                ctx for ctx in result["contexts"] if ctx["stage"] == "before_kickoff"
            )
            assert (
                before_ctx["crew_id"] is None
            ), f"Context should be None before kickoff for {result['crew_id']}"

            task_ctx = next(
                ctx for ctx in result["contexts"] if ctx["stage"] == "task_callback"
            )
            assert (
                task_ctx["crew_id"] == crew_uuid
            ), f"Context mismatch during task for {result['crew_id']}"

            after_ctx = next(
                ctx for ctx in result["contexts"] if ctx["stage"] == "after_kickoff"
            )
            assert (
                after_ctx["crew_id"] is None
            ), f"Context should be None after kickoff for {result['crew_id']}"

            thread_name = before_ctx["thread"]
            assert (
                "ThreadPoolExecutor" in thread_name
            ), f"Should run in thread pool for {result['crew_id']}"

    @pytest.mark.asyncio
    @patch("crewai.Agent.execute_task")
    async def test_async_crews_thread_safety(self, mock_execute_task, crew_factory):
        mock_execute_task.return_value = "Task completed"
        num_crews = 5

        async def run_crew_async(crew_id: str) -> Dict[str, Any]:
            task_context = {"crew_id": crew_id, "context": None}

            def capture_context(output):
                ctx = get_crew_context()
                task_context["context"] = {
                    "crew_id": ctx.id if ctx else None,
                    "crew_key": ctx.key if ctx else None,
                }
                return output

            crew = crew_factory(crew_id, task_callback=capture_context)
            output = await crew.kickoff_async()

            return {
                "crew_id": crew_id,
                "crew_uuid": str(crew.id),
                "output": output.raw,
                "task_context": task_context,
            }

        tasks = [run_crew_async(f"async_crew_{i}") for i in range(num_crews)]
        results = await asyncio.gather(*tasks)

        for result in results:
            crew_uuid = result["crew_uuid"]
            task_ctx = result["task_context"]["context"]

            assert (
                task_ctx is not None
            ), f"Context should exist during task for {result['crew_id']}"
            assert (
                task_ctx["crew_id"] == crew_uuid
            ), f"Context mismatch for {result['crew_id']}"

    @patch("crewai.Agent.execute_task")
    def test_concurrent_kickoff_for_each(self, mock_execute_task, crew_factory):
        mock_execute_task.return_value = "Task completed"
        contexts_captured = []

        def capture_context(output):
            ctx = get_crew_context()
            contexts_captured.append(
                {
                    "context_id": ctx.id if ctx else None,
                    "thread": threading.current_thread().name,
                }
            )
            return output

        crew = crew_factory("for_each_test", task_callback=capture_context)
        inputs = [{"item": f"input_{i}"} for i in range(3)]

        results = crew.kickoff_for_each(inputs=inputs)

        assert len(results) == len(inputs)
        assert len(contexts_captured) == len(inputs)

        context_ids = [ctx["context_id"] for ctx in contexts_captured]
        assert len(set(context_ids)) == len(
            inputs
        ), "Each execution should have unique context"

    @patch("crewai.Agent.execute_task")
    def test_no_context_leakage_between_crews(self, mock_execute_task, crew_factory):
        mock_execute_task.return_value = "Task completed"
        contexts = []

        def check_context(output):
            ctx = get_crew_context()
            contexts.append(
                {
                    "context_id": ctx.id if ctx else None,
                    "context_key": ctx.key if ctx else None,
                }
            )
            return output

        def run_crew(name: str):
            crew = crew_factory(name, task_callback=check_context)
            crew.kickoff()
            return str(crew.id)

        crew1_id = run_crew("First")
        crew2_id = run_crew("Second")

        assert len(contexts) == 2
        assert contexts[0]["context_id"] == crew1_id
        assert contexts[1]["context_id"] == crew2_id
        assert contexts[0]["context_id"] != contexts[1]["context_id"]
tests/utilities/crew/__init__.py (new file, empty)
tests/utilities/crew/test_crew_context.py (new file, 88 lines)

@@ -0,0 +1,88 @@
import uuid

import pytest
from opentelemetry import baggage
from opentelemetry.context import attach, detach

from crewai.utilities.crew.crew_context import get_crew_context
from crewai.utilities.crew.models import CrewContext


def test_crew_context_creation():
    crew_id = str(uuid.uuid4())
    context = CrewContext(id=crew_id, key="test-crew")
    assert context.id == crew_id
    assert context.key == "test-crew"


def test_get_crew_context_with_baggage():
    crew_id = str(uuid.uuid4())
    assert get_crew_context() is None

    crew_ctx = CrewContext(id=crew_id, key="test-key")
    ctx = baggage.set_baggage("crew_context", crew_ctx)
    token = attach(ctx)

    try:
        context = get_crew_context()
        assert context is not None
        assert context.id == crew_id
        assert context.key == "test-key"
    finally:
        detach(token)

    assert get_crew_context() is None


def test_get_crew_context_empty():
    assert get_crew_context() is None


def test_baggage_nested_contexts():
    crew_id1 = str(uuid.uuid4())
    crew_id2 = str(uuid.uuid4())

    crew_ctx1 = CrewContext(id=crew_id1, key="outer")
    ctx1 = baggage.set_baggage("crew_context", crew_ctx1)
    token1 = attach(ctx1)

    try:
        outer_context = get_crew_context()
        assert outer_context.id == crew_id1
        assert outer_context.key == "outer"

        crew_ctx2 = CrewContext(id=crew_id2, key="inner")
        ctx2 = baggage.set_baggage("crew_context", crew_ctx2)
        token2 = attach(ctx2)

        try:
            inner_context = get_crew_context()
            assert inner_context.id == crew_id2
            assert inner_context.key == "inner"
        finally:
            detach(token2)

        restored_context = get_crew_context()
        assert restored_context.id == crew_id1
        assert restored_context.key == "outer"
    finally:
        detach(token1)

    assert get_crew_context() is None


def test_baggage_exception_handling():
    crew_id = str(uuid.uuid4())

    crew_ctx = CrewContext(id=crew_id, key="test")
    ctx = baggage.set_baggage("crew_context", crew_ctx)
    token = attach(ctx)

    with pytest.raises(ValueError):
        try:
            assert get_crew_context() is not None
            raise ValueError("Test exception")
        finally:
            detach(token)

    assert get_crew_context() is None