mirror of https://github.com/crewAIInc/crewAI.git
synced 2026-01-27 17:18:13 +00:00

Compare commits: llm-event-...devin/1768 (1 commit)

| Author | SHA1 | Date |
|---|---|---|
|  | 2ada7aba97 |  |

@@ -341,6 +341,7 @@ class AccumulatedToolArgs(BaseModel):
 class LLM(BaseLLM):
     completion_cost: float | None = None
+    _callback_lock: threading.RLock = threading.RLock()

     def __new__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> LLM:
         """Factory method that routes to native SDK or falls back to LiteLLM.
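
The hunk above adds the lock as a class attribute, so every LLM instance shares one reentrant lock object rather than each instance creating its own. A minimal standalone sketch of that pattern (not part of the patch; `Client` and `registry` are made-up names, standard library only):

import threading


class Client:
    # One lock object, created at class-definition time and shared by all
    # instances -- suitable for guarding process-global state such as a
    # module-level callback list.
    _shared_lock = threading.RLock()

    def __init__(self, name: str) -> None:
        self.name = name

    def register(self, registry: list) -> None:
        # Every instance synchronizes on the same lock, so concurrent
        # mutations of the shared registry are serialized.
        with Client._shared_lock:
            registry.append(self.name)


if __name__ == "__main__":
    registry: list = []
    a, b = Client("a"), Client("b")
    assert a._shared_lock is b._shared_lock  # same object for all instances
    a.register(registry)
    b.register(registry)
    print(registry)  # ['a', 'b']

A lock created in `__init__` would only serialize calls made through a single instance, which is why the patch hangs it off the class.
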
@@ -1144,7 +1145,7 @@ class LLM(BaseLLM):
         if response_model:
             params["response_model"] = response_model
         response = litellm.completion(**params)

         if hasattr(response,"usage") and not isinstance(response.usage, type) and response.usage:
             usage_info = response.usage
             self._track_token_usage_internal(usage_info)
@@ -1363,7 +1364,7 @@ class LLM(BaseLLM):
         """
         full_response = ""
         chunk_count = 0

         usage_info = None

         accumulated_tool_args: defaultdict[int, AccumulatedToolArgs] = defaultdict(
@@ -1657,78 +1658,92 @@ class LLM(BaseLLM):
             raise ValueError("LLM call blocked by before_llm_call hook")

         # --- 5) Set up callbacks if provided
+        # Use a class-level lock to synchronize access to global litellm callbacks.
+        # This prevents race conditions when multiple LLM instances call set_callbacks
+        # concurrently, which could cause callbacks to be removed before they fire.
         with suppress_warnings():
-            if callbacks and len(callbacks) > 0:
-                self.set_callbacks(callbacks)
-            try:
-                # --- 6) Prepare parameters for the completion call
-                params = self._prepare_completion_params(messages, tools)
-                # --- 7) Make the completion call and handle response
-                if self.stream:
-                    result = self._handle_streaming_response(
-                        params=params,
-                        callbacks=callbacks,
-                        available_functions=available_functions,
-                        from_task=from_task,
-                        from_agent=from_agent,
-                        response_model=response_model,
-                    )
-                else:
-                    result = self._handle_non_streaming_response(
-                        params=params,
-                        callbacks=callbacks,
-                        available_functions=available_functions,
-                        from_task=from_task,
-                        from_agent=from_agent,
-                        response_model=response_model,
-                    )
-
-                if isinstance(result, str):
-                    result = self._invoke_after_llm_call_hooks(
-                        messages, result, from_agent
-                    )
-
-                return result
-            except LLMContextLengthExceededError:
-                # Re-raise LLMContextLengthExceededError as it should be handled
-                # by the CrewAgentExecutor._invoke_loop method, which can then decide
-                # whether to summarize the content or abort based on the respect_context_window flag
-                raise
-            except Exception as e:
-                unsupported_stop = "Unsupported parameter" in str(
-                    e
-                ) and "'stop'" in str(e)
-
-                if unsupported_stop:
-                    if (
-                        "additional_drop_params" in self.additional_params
-                        and isinstance(
-                            self.additional_params["additional_drop_params"], list
-                        )
-                    ):
-                        self.additional_params["additional_drop_params"].append("stop")
-                    else:
-                        self.additional_params = {"additional_drop_params": ["stop"]}
-
-                    logging.info("Retrying LLM call without the unsupported 'stop'")
-
-                    return self.call(
-                        messages,
-                        tools=tools,
-                        callbacks=callbacks,
-                        available_functions=available_functions,
-                        from_task=from_task,
-                        from_agent=from_agent,
-                        response_model=response_model,
-                    )
-
-                crewai_event_bus.emit(
-                    self,
-                    event=LLMCallFailedEvent(
-                        error=str(e), from_task=from_task, from_agent=from_agent
-                    ),
-                )
-                raise
+            with LLM._callback_lock:
+                if callbacks and len(callbacks) > 0:
+                    self.set_callbacks(callbacks)
+                try:
+                    # --- 6) Prepare parameters for the completion call
+                    params = self._prepare_completion_params(messages, tools)
+                    # --- 7) Make the completion call and handle response
+                    if self.stream:
+                        result = self._handle_streaming_response(
+                            params=params,
+                            callbacks=callbacks,
+                            available_functions=available_functions,
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            response_model=response_model,
+                        )
+                    else:
+                        result = self._handle_non_streaming_response(
+                            params=params,
+                            callbacks=callbacks,
+                            available_functions=available_functions,
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            response_model=response_model,
+                        )
+
+                    if isinstance(result, str):
+                        result = self._invoke_after_llm_call_hooks(
+                            messages, result, from_agent
+                        )
+
+                    return result
+                except LLMContextLengthExceededError:
+                    # Re-raise LLMContextLengthExceededError as it should be handled
+                    # by the CrewAgentExecutor._invoke_loop method, which can then decide
+                    # whether to summarize the content or abort based on the respect_context_window flag
+                    raise
+                except Exception as e:
+                    unsupported_stop = "Unsupported parameter" in str(
+                        e
+                    ) and "'stop'" in str(e)
+
+                    if unsupported_stop:
+                        if (
+                            "additional_drop_params" in self.additional_params
+                            and isinstance(
+                                self.additional_params["additional_drop_params"], list
+                            )
+                        ):
+                            self.additional_params["additional_drop_params"].append(
+                                "stop"
+                            )
+                        else:
+                            self.additional_params = {
+                                "additional_drop_params": ["stop"]
+                            }
+
+                        logging.info(
+                            "Retrying LLM call without the unsupported 'stop'"
+                        )
+
+                        # Recursive call happens inside the lock since we're using
+                        # a reentrant-safe pattern (the lock is released when we
+                        # exit the with block, and the recursive call will acquire
+                        # it again)
+                        return self.call(
+                            messages,
+                            tools=tools,
+                            callbacks=callbacks,
+                            available_functions=available_functions,
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            response_model=response_model,
+                        )
+
+                    crewai_event_bus.emit(
+                        self,
+                        event=LLMCallFailedEvent(
+                            error=str(e), from_task=from_task, from_agent=from_agent
+                        ),
+                    )
+                    raise

     async def acall(
         self,
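
In the retry branch above, `return self.call(...)` runs while the calling thread is still inside the `with LLM._callback_lock:` block, so the lock is acquired a second time by the same thread. That nested acquisition is safe because the attribute is a `threading.RLock` (reentrant); a plain `threading.Lock` would deadlock there. A standalone sketch of that behaviour (not part of the patch; `retry_call` is a hypothetical stand-in for the recursive structure):

import threading

rlock = threading.RLock()


def retry_call(attempt: int = 0) -> str:
    with rlock:
        if attempt == 0:
            # Analogous to the `return self.call(...)` retry: the same thread
            # re-enters the function and re-acquires the lock it already holds.
            return retry_call(attempt + 1)
        return "ok"


if __name__ == "__main__":
    print(retry_call())  # "ok" -- RLock counts acquisitions per thread, so no deadlock
    # Replacing the RLock with threading.Lock() would make the nested
    # `with rlock:` block forever, since the outer call still holds the lock.
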
@@ -1790,14 +1805,27 @@ class LLM(BaseLLM):
         msg_role: Literal["assistant"] = "assistant"
         message["role"] = msg_role

+        # Use a class-level lock to synchronize access to global litellm callbacks.
+        # This prevents race conditions when multiple LLM instances call set_callbacks
+        # concurrently, which could cause callbacks to be removed before they fire.
         with suppress_warnings():
-            if callbacks and len(callbacks) > 0:
-                self.set_callbacks(callbacks)
-            try:
-                params = self._prepare_completion_params(messages, tools)
-
-                if self.stream:
-                    return await self._ahandle_streaming_response(
+            with LLM._callback_lock:
+                if callbacks and len(callbacks) > 0:
+                    self.set_callbacks(callbacks)
+                try:
+                    params = self._prepare_completion_params(messages, tools)
+
+                    if self.stream:
+                        return await self._ahandle_streaming_response(
+                            params=params,
+                            callbacks=callbacks,
+                            available_functions=available_functions,
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            response_model=response_model,
+                        )
+
+                    return await self._ahandle_non_streaming_response(
                         params=params,
                         callbacks=callbacks,
                         available_functions=available_functions,
@@ -1805,52 +1833,49 @@ class LLM(BaseLLM):
                         from_agent=from_agent,
                         response_model=response_model,
                     )
-
-                return await self._ahandle_non_streaming_response(
-                    params=params,
-                    callbacks=callbacks,
-                    available_functions=available_functions,
-                    from_task=from_task,
-                    from_agent=from_agent,
-                    response_model=response_model,
-                )
-            except LLMContextLengthExceededError:
-                raise
-            except Exception as e:
-                unsupported_stop = "Unsupported parameter" in str(
-                    e
-                ) and "'stop'" in str(e)
-
-                if unsupported_stop:
-                    if (
-                        "additional_drop_params" in self.additional_params
-                        and isinstance(
-                            self.additional_params["additional_drop_params"], list
-                        )
-                    ):
-                        self.additional_params["additional_drop_params"].append("stop")
-                    else:
-                        self.additional_params = {"additional_drop_params": ["stop"]}
-
-                    logging.info("Retrying LLM call without the unsupported 'stop'")
-
-                    return await self.acall(
-                        messages,
-                        tools=tools,
-                        callbacks=callbacks,
-                        available_functions=available_functions,
-                        from_task=from_task,
-                        from_agent=from_agent,
-                        response_model=response_model,
-                    )
-
-                crewai_event_bus.emit(
-                    self,
-                    event=LLMCallFailedEvent(
-                        error=str(e), from_task=from_task, from_agent=from_agent
-                    ),
-                )
-                raise
+                except LLMContextLengthExceededError:
+                    raise
+                except Exception as e:
+                    unsupported_stop = "Unsupported parameter" in str(
+                        e
+                    ) and "'stop'" in str(e)
+
+                    if unsupported_stop:
+                        if (
+                            "additional_drop_params" in self.additional_params
+                            and isinstance(
+                                self.additional_params["additional_drop_params"], list
+                            )
+                        ):
+                            self.additional_params["additional_drop_params"].append(
+                                "stop"
+                            )
+                        else:
+                            self.additional_params = {
+                                "additional_drop_params": ["stop"]
+                            }
+
+                        logging.info(
+                            "Retrying LLM call without the unsupported 'stop'"
+                        )
+
+                        return await self.acall(
+                            messages,
+                            tools=tools,
+                            callbacks=callbacks,
+                            available_functions=available_functions,
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            response_model=response_model,
+                        )
+
+                    crewai_event_bus.emit(
+                        self,
+                        event=LLMCallFailedEvent(
+                            error=str(e), from_task=from_task, from_agent=from_agent
+                        ),
+                    )
+                    raise

     def _handle_emit_call_events(
         self,

@@ -1,6 +1,6 @@
 import logging
 import os
-from time import sleep
+import threading
 from unittest.mock import MagicMock, patch

 from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
@@ -18,9 +18,15 @@ from pydantic import BaseModel
 import pytest


-# TODO: This test fails without print statement, which makes me think that something is happening asynchronously that we need to eventually fix and dive deeper into at a later date
 @pytest.mark.vcr()
 def test_llm_callback_replacement():
+    """Test that callbacks are properly isolated between LLM instances.
+
+    This test verifies that the race condition fix (using _callback_lock) works
+    correctly. Previously, this test required a sleep(5) workaround because
+    callbacks were being modified globally without synchronization, causing
+    one LLM instance's callbacks to interfere with another's.
+    """
     llm1 = LLM(model="gpt-4o-mini", is_litellm=True)
     llm2 = LLM(model="gpt-4o-mini", is_litellm=True)

@@ -37,7 +43,6 @@ def test_llm_callback_replacement():
         messages=[{"role": "user", "content": "Hello, world from another agent!"}],
         callbacks=[calc_handler_2],
     )
-    sleep(5)
    usage_metrics_2 = calc_handler_2.token_cost_process.get_summary()

    # The first handler should not have been updated
@@ -46,6 +51,66 @@ def test_llm_callback_replacement():
     assert usage_metrics_1 == calc_handler_1.token_cost_process.get_summary()


+def test_llm_callback_lock_prevents_race_condition():
+    """Test that the _callback_lock prevents race conditions in concurrent LLM calls.
+
+    This test verifies that multiple threads can safely call LLM.call() with
+    different callbacks without interfering with each other. The lock ensures
+    that callbacks are properly isolated between concurrent calls.
+    """
+    num_threads = 5
+    results: list[int] = []
+    errors: list[Exception] = []
+    lock = threading.Lock()
+
+    def make_llm_call(thread_id: int, mock_completion: MagicMock) -> None:
+        try:
+            llm = LLM(model="gpt-4o-mini", is_litellm=True)
+            calc_handler = TokenCalcHandler(token_cost_process=TokenProcess())
+
+            mock_message = MagicMock()
+            mock_message.content = f"Response from thread {thread_id}"
+            mock_choice = MagicMock()
+            mock_choice.message = mock_message
+            mock_response = MagicMock()
+            mock_response.choices = [mock_choice]
+            mock_response.usage = {
+                "prompt_tokens": 10,
+                "completion_tokens": 10,
+                "total_tokens": 20,
+            }
+            mock_completion.return_value = mock_response
+
+            llm.call(
+                messages=[{"role": "user", "content": f"Hello from thread {thread_id}"}],
+                callbacks=[calc_handler],
+            )
+
+            usage = calc_handler.token_cost_process.get_summary()
+            with lock:
+                results.append(usage.successful_requests)
+        except Exception as e:
+            with lock:
+                errors.append(e)
+
+    with patch("litellm.completion") as mock_completion:
+        threads = [
+            threading.Thread(target=make_llm_call, args=(i, mock_completion))
+            for i in range(num_threads)
+        ]
+
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+    assert len(errors) == 0, f"Errors occurred: {errors}"
+    assert len(results) == num_threads
+    assert all(
+        r == 1 for r in results
+    ), f"Expected all callbacks to have 1 successful request, got {results}"
+
+
 @pytest.mark.vcr()
 def test_llm_call_with_string_input():
     llm = LLM(model="gpt-4o-mini")
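
The new test exercises the fix through real `LLM.call(...)` invocations with `litellm.completion` patched out. The hazard it guards against, one caller's handler being dropped from a shared, process-global callback list while another caller is mid-request, can also be shown in isolation. A standalone sketch, assuming only the standard library (`GLOBAL_CALLBACKS`, `worker`, and `run` are made-up names, not crewAI or litellm APIs):

import threading

GLOBAL_CALLBACKS: list = []       # stands in for litellm's module-level callback list
install_lock = threading.RLock()  # stands in for LLM._callback_lock


def worker(hits: dict, key: str, use_lock: bool) -> None:
    def handler() -> None:
        hits[key] = hits.get(key, 0) + 1

    # Without the shared lock each thread uses its own lock object,
    # which provides no mutual exclusion at all.
    lock = install_lock if use_lock else threading.RLock()
    with lock:
        GLOBAL_CALLBACKS.clear()            # "set_callbacks": replace prior handlers
        GLOBAL_CALLBACKS.append(handler)
        for cb in list(GLOBAL_CALLBACKS):   # "completion": fire whatever is registered
            cb()


def run(use_lock: bool) -> dict:
    hits: dict = {}
    threads = [
        threading.Thread(target=worker, args=(hits, f"t{i}", use_lock))
        for i in range(8)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return hits


if __name__ == "__main__":
    print(run(use_lock=True))   # every worker's handler fired exactly once
    print(run(use_lock=False))  # unsynchronized: a handler may be cleared by another
                                # thread before it fires, or fire more than once
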
@@ -989,4 +1054,4 @@ async def test_usage_info_streaming_with_acall():
     assert llm._token_usage["completion_tokens"] > 0
     assert llm._token_usage["total_tokens"] > 0

     assert len(result) > 0