Add message_id to LLM events to differentiate stream chunks

- Add message_id field to LLMEventBase for all LLM events
- Generate unique message_id (UUID) for each LLM call in LLM.call()
- Thread message_id through all streaming and non-streaming paths
- Update all event emissions to include message_id
- Add comprehensive tests verifying message_id uniqueness

Fixes #3845

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
Devin AI
2025-11-06 14:02:39 +00:00
parent e4cc9a664c
commit 04b88bcf88
3 changed files with 212 additions and 5 deletions

View File

@@ -9,6 +9,7 @@ from crewai.events.base_events import BaseEvent
class LLMEventBase(BaseEvent):
from_task: Any | None = None
from_agent: Any | None = None
message_id: str | None = None
def __init__(self, **data):
if data.get("from_task"):

View File

@@ -18,6 +18,7 @@ from typing import (
TypedDict,
cast,
)
import uuid
from dotenv import load_dotenv
import httpx
@@ -532,6 +533,7 @@ class LLM(BaseLLM):
from_task: Task | None = None,
from_agent: Agent | None = None,
response_model: type[BaseModel] | None = None,
message_id: str | None = None,
) -> Any:
"""Handle a streaming response from the LLM.
@@ -626,6 +628,7 @@ class LLM(BaseLLM):
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
message_id=message_id,
)
if result is not None:
@@ -646,6 +649,7 @@ class LLM(BaseLLM):
chunk=chunk_content,
from_task=from_task,
from_agent=from_agent,
message_id=message_id,
),
)
# --- 4) Fallback to non-streaming if no content received
@@ -763,6 +767,7 @@ class LLM(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
message_id=message_id,
)
return structured_response
@@ -772,11 +777,12 @@ class LLM(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
message_id=message_id,
)
return full_response
# --- 9) Handle tool calls if present
tool_result = self._handle_tool_call(tool_calls, available_functions)
tool_result = self._handle_tool_call(tool_calls, available_functions, from_task, from_agent, message_id)
if tool_result is not None:
return tool_result
@@ -792,6 +798,7 @@ class LLM(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
message_id=message_id,
)
return full_response
@@ -810,13 +817,14 @@ class LLM(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
message_id=message_id,
)
return full_response
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(
error=str(e), from_task=from_task, from_agent=from_agent
error=str(e), from_task=from_task, from_agent=from_agent, message_id=message_id
),
)
raise Exception(f"Failed to get streaming response: {e!s}") from e
@@ -828,6 +836,7 @@ class LLM(BaseLLM):
available_functions: dict[str, Any] | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
message_id: str | None = None,
) -> Any:
for tool_call in tool_calls:
current_tool_accumulator = accumulated_tool_args[tool_call.index]
@@ -847,6 +856,7 @@ class LLM(BaseLLM):
chunk=tool_call.function.arguments,
from_task=from_task,
from_agent=from_agent,
message_id=message_id,
),
)
@@ -861,6 +871,9 @@ class LLM(BaseLLM):
return self._handle_tool_call(
[current_tool_accumulator],
available_functions,
from_task,
from_agent,
message_id,
)
except json.JSONDecodeError:
continue
@@ -914,6 +927,7 @@ class LLM(BaseLLM):
from_task: Task | None = None,
from_agent: Agent | None = None,
response_model: type[BaseModel] | None = None,
message_id: str | None = None,
) -> str | Any:
"""Handle a non-streaming response from the LLM.
@@ -954,6 +968,7 @@ class LLM(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
message_id=message_id,
)
return structured_response
@@ -982,6 +997,7 @@ class LLM(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
message_id=message_id,
)
return structured_response
@@ -1013,6 +1029,7 @@ class LLM(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
message_id=message_id,
)
return text_response
@@ -1022,7 +1039,7 @@ class LLM(BaseLLM):
# --- 7) Handle tool calls if present
tool_result = self._handle_tool_call(
tool_calls, available_functions, from_task, from_agent
tool_calls, available_functions, from_task, from_agent, message_id
)
if tool_result is not None:
return tool_result
@@ -1033,6 +1050,7 @@ class LLM(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
message_id=message_id,
)
return text_response
@@ -1042,6 +1060,7 @@ class LLM(BaseLLM):
available_functions: dict[str, Any] | None = None,
from_task: Task | None = None,
from_agent: Agent | None = None,
message_id: str | None = None,
) -> Any:
"""Handle a tool call from the LLM.
@@ -1101,6 +1120,7 @@ class LLM(BaseLLM):
call_type=LLMCallType.TOOL_CALL,
from_task=from_task,
from_agent=from_agent,
message_id=message_id,
)
return result
except Exception as e:
@@ -1111,7 +1131,7 @@ class LLM(BaseLLM):
logging.error(f"Error executing function '{function_name}': {e}")
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(error=f"Tool execution error: {e!s}"),
event=LLMCallFailedEvent(error=f"Tool execution error: {e!s}", message_id=message_id),
)
crewai_event_bus.emit(
self,
@@ -1161,6 +1181,8 @@ class LLM(BaseLLM):
ValueError: If response format is not supported
LLMContextLengthExceededError: If input exceeds model's context limit
"""
message_id = uuid.uuid4().hex
crewai_event_bus.emit(
self,
event=LLMCallStartedEvent(
@@ -1171,6 +1193,7 @@ class LLM(BaseLLM):
from_task=from_task,
from_agent=from_agent,
model=self.model,
message_id=message_id,
),
)
@@ -1202,6 +1225,7 @@ class LLM(BaseLLM):
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
message_id=message_id,
)
return self._handle_non_streaming_response(
@@ -1211,6 +1235,7 @@ class LLM(BaseLLM):
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
message_id=message_id,
)
except LLMContextLengthExceededError:
# Re-raise LLMContextLengthExceededError as it should be handled
@@ -1248,7 +1273,7 @@ class LLM(BaseLLM):
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(
error=str(e), from_task=from_task, from_agent=from_agent
error=str(e), from_task=from_task, from_agent=from_agent, message_id=message_id
),
)
raise
@@ -1260,6 +1285,7 @@ class LLM(BaseLLM):
from_task: Task | None = None,
from_agent: Agent | None = None,
messages: str | list[LLMMessage] | None = None,
message_id: str | None = None,
) -> None:
"""Handle the events for the LLM call.
@@ -1269,6 +1295,7 @@ class LLM(BaseLLM):
from_task: Optional task object
from_agent: Optional agent object
messages: Optional messages object
message_id: Optional message identifier
"""
crewai_event_bus.emit(
self,
@@ -1279,6 +1306,7 @@ class LLM(BaseLLM):
from_task=from_task,
from_agent=from_agent,
model=self.model,
message_id=message_id,
),
)

View File

@@ -0,0 +1,178 @@
import threading
from unittest.mock import Mock, patch
import pytest
from crewai.agent import Agent
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.llm_events import (
LLMCallCompletedEvent,
LLMCallStartedEvent,
LLMStreamChunkEvent,
)
from crewai.llm import LLM
from crewai.task import Task
@pytest.fixture
def base_agent():
return Agent(
role="test_agent",
llm="gpt-4o-mini",
goal="Test message_id",
backstory="You are a test assistant",
)
@pytest.fixture
def base_task(base_agent):
return Task(
description="Test message_id",
expected_output="test",
agent=base_agent,
)
def test_llm_events_have_unique_message_ids_for_different_calls(base_agent, base_task):
"""Test that different LLM calls have different message_ids"""
received_events = []
event_received = threading.Event()
@crewai_event_bus.on(LLMCallStartedEvent)
def handle_llm_started(source, event):
received_events.append(event)
if len(received_events) >= 2:
event_received.set()
llm = LLM(model="gpt-4o-mini")
with patch("litellm.completion") as mock_completion:
mock_completion.return_value = Mock(
choices=[Mock(message=Mock(content="Response 1", tool_calls=None))],
usage=Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30),
)
llm.call("Test message 1", from_task=base_task, from_agent=base_agent)
llm.call("Test message 2", from_task=base_task, from_agent=base_agent)
assert event_received.wait(timeout=5), "Timeout waiting for LLM started events"
assert len(received_events) >= 2
assert received_events[0].message_id is not None
assert received_events[1].message_id is not None
assert received_events[0].message_id != received_events[1].message_id
def test_streaming_chunks_have_same_message_id(base_agent, base_task):
"""Test that all chunks from the same streaming call have the same message_id"""
received_events = []
lock = threading.Lock()
all_events_received = threading.Event()
@crewai_event_bus.on(LLMStreamChunkEvent)
def handle_stream_chunk(source, event):
with lock:
received_events.append(event)
if len(received_events) >= 3:
all_events_received.set()
llm = LLM(model="gpt-4o-mini", stream=True)
def mock_stream_generator():
yield Mock(
choices=[Mock(delta=Mock(content="Hello", tool_calls=None))],
usage=None,
)
yield Mock(
choices=[Mock(delta=Mock(content=" ", tool_calls=None))],
usage=None,
)
yield Mock(
choices=[Mock(delta=Mock(content="World", tool_calls=None))],
usage=Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30),
)
with patch("litellm.completion", return_value=mock_stream_generator()):
llm.call("Test streaming", from_task=base_task, from_agent=base_agent)
assert all_events_received.wait(timeout=5), "Timeout waiting for stream chunk events"
assert len(received_events) >= 3
message_ids = [event.message_id for event in received_events]
assert all(mid is not None for mid in message_ids)
assert len(set(message_ids)) == 1, "All chunks should have the same message_id"
def test_completed_event_has_same_message_id_as_started(base_agent, base_task):
"""Test that Started and Completed events have the same message_id"""
received_events = {"started": None, "completed": None}
lock = threading.Lock()
all_events_received = threading.Event()
@crewai_event_bus.on(LLMCallStartedEvent)
def handle_started(source, event):
with lock:
received_events["started"] = event
if received_events["completed"] is not None:
all_events_received.set()
@crewai_event_bus.on(LLMCallCompletedEvent)
def handle_completed(source, event):
with lock:
received_events["completed"] = event
if received_events["started"] is not None:
all_events_received.set()
llm = LLM(model="gpt-4o-mini")
with patch("litellm.completion") as mock_completion:
mock_completion.return_value = Mock(
choices=[Mock(message=Mock(content="Response", tool_calls=None))],
usage=Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30),
)
llm.call("Test message", from_task=base_task, from_agent=base_agent)
assert all_events_received.wait(timeout=5), "Timeout waiting for events"
assert received_events["started"] is not None
assert received_events["completed"] is not None
assert received_events["started"].message_id is not None
assert received_events["completed"].message_id is not None
assert received_events["started"].message_id == received_events["completed"].message_id
def test_multiple_calls_same_agent_task_have_different_message_ids(base_agent, base_task):
"""Test that multiple calls from the same agent/task have different message_ids"""
received_started_events = []
lock = threading.Lock()
all_events_received = threading.Event()
@crewai_event_bus.on(LLMCallStartedEvent)
def handle_started(source, event):
with lock:
received_started_events.append(event)
if len(received_started_events) >= 3:
all_events_received.set()
llm = LLM(model="gpt-4o-mini")
with patch("litellm.completion") as mock_completion:
mock_completion.return_value = Mock(
choices=[Mock(message=Mock(content="Response", tool_calls=None))],
usage=Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30),
)
llm.call("Message 1", from_task=base_task, from_agent=base_agent)
llm.call("Message 2", from_task=base_task, from_agent=base_agent)
llm.call("Message 3", from_task=base_task, from_agent=base_agent)
assert all_events_received.wait(timeout=5), "Timeout waiting for events"
assert len(received_started_events) >= 3
message_ids = [event.message_id for event in received_started_events]
assert all(mid is not None for mid in message_ids)
assert len(set(message_ids)) == 3, "Each call should have a unique message_id"
task_ids = [event.task_id for event in received_started_events]
agent_ids = [event.agent_id for event in received_started_events]
assert len(set(task_ids)) == 1, "All calls should have the same task_id"
assert len(set(agent_ids)) == 1, "All calls should have the same agent_id"