mirror of https://github.com/crewAIInc/crewAI.git, synced 2025-12-16 04:18:35 +00:00
Add message_id to LLM events to differentiate stream chunks
- Add message_id field to LLMEventBase for all LLM events
- Generate unique message_id (UUID) for each LLM call in LLM.call()
- Thread message_id through all streaming and non-streaming paths
- Update all event emissions to include message_id
- Add comprehensive tests verifying message_id uniqueness

Fixes #3845

Co-Authored-By: João <joao@crewai.com>
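For consumers of the event bus, the new field makes it possible to attribute stream chunks to the call that produced them. Below is a minimal listener sketch, assuming only the crewai_event_bus.on handler signature and the message_id / chunk fields shown in the diff and tests that follow; the accumulator and handler names are illustrative, not part of this commit:

from collections import defaultdict

from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.llm_events import LLMCallCompletedEvent, LLMStreamChunkEvent

# Illustrative accumulator: streamed chunks keyed by the call that emitted them.
chunks_by_call: dict[str, list[str]] = defaultdict(list)


@crewai_event_bus.on(LLMStreamChunkEvent)
def collect_chunk(source, event):
    # Events emitted by code paths that predate this change carry message_id=None.
    if event.message_id is not None:
        chunks_by_call[event.message_id].append(event.chunk)


@crewai_event_bus.on(LLMCallCompletedEvent)
def report_call(source, event):
    # Reassemble only the chunks belonging to this call, even if calls overlap.
    streamed = "".join(chunks_by_call.pop(event.message_id, []))
    print(f"LLM call {event.message_id} streamed {len(streamed)} characters")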
@@ -9,6 +9,7 @@ from crewai.events.base_events import BaseEvent
 class LLMEventBase(BaseEvent):
     from_task: Any | None = None
     from_agent: Any | None = None
+    message_id: str | None = None
 
     def __init__(self, **data):
         if data.get("from_task"):
@@ -18,6 +18,7 @@ from typing import (
     TypedDict,
     cast,
 )
+import uuid
 
 from dotenv import load_dotenv
 import httpx
@@ -532,6 +533,7 @@ class LLM(BaseLLM):
         from_task: Task | None = None,
         from_agent: Agent | None = None,
         response_model: type[BaseModel] | None = None,
+        message_id: str | None = None,
     ) -> Any:
         """Handle a streaming response from the LLM.
 
@@ -626,6 +628,7 @@ class LLM(BaseLLM):
                 available_functions=available_functions,
                 from_task=from_task,
                 from_agent=from_agent,
+                message_id=message_id,
             )
 
             if result is not None:
@@ -646,6 +649,7 @@ class LLM(BaseLLM):
                     chunk=chunk_content,
                     from_task=from_task,
                     from_agent=from_agent,
+                    message_id=message_id,
                 ),
             )
         # --- 4) Fallback to non-streaming if no content received
@@ -763,6 +767,7 @@ class LLM(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=params["messages"],
+                message_id=message_id,
             )
             return structured_response
 
@@ -772,11 +777,12 @@ class LLM(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=params["messages"],
+                message_id=message_id,
             )
             return full_response
 
         # --- 9) Handle tool calls if present
-        tool_result = self._handle_tool_call(tool_calls, available_functions)
+        tool_result = self._handle_tool_call(tool_calls, available_functions, from_task, from_agent, message_id)
         if tool_result is not None:
             return tool_result
 
@@ -792,6 +798,7 @@ class LLM(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=params["messages"],
+                message_id=message_id,
             )
             return full_response
 
@@ -810,13 +817,14 @@ class LLM(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=params["messages"],
+                message_id=message_id,
             )
             return full_response
 
         crewai_event_bus.emit(
             self,
             event=LLMCallFailedEvent(
-                error=str(e), from_task=from_task, from_agent=from_agent
+                error=str(e), from_task=from_task, from_agent=from_agent, message_id=message_id
             ),
         )
         raise Exception(f"Failed to get streaming response: {e!s}") from e
@@ -828,6 +836,7 @@ class LLM(BaseLLM):
         available_functions: dict[str, Any] | None = None,
         from_task: Task | None = None,
         from_agent: Agent | None = None,
+        message_id: str | None = None,
     ) -> Any:
         for tool_call in tool_calls:
             current_tool_accumulator = accumulated_tool_args[tool_call.index]
@@ -847,6 +856,7 @@ class LLM(BaseLLM):
                     chunk=tool_call.function.arguments,
                     from_task=from_task,
                     from_agent=from_agent,
+                    message_id=message_id,
                 ),
             )
 
@@ -861,6 +871,9 @@ class LLM(BaseLLM):
                     return self._handle_tool_call(
                         [current_tool_accumulator],
                         available_functions,
+                        from_task,
+                        from_agent,
+                        message_id,
                     )
                 except json.JSONDecodeError:
                     continue
@@ -914,6 +927,7 @@ class LLM(BaseLLM):
         from_task: Task | None = None,
         from_agent: Agent | None = None,
         response_model: type[BaseModel] | None = None,
+        message_id: str | None = None,
     ) -> str | Any:
         """Handle a non-streaming response from the LLM.
 
@@ -954,6 +968,7 @@ class LLM(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=params["messages"],
+                message_id=message_id,
             )
             return structured_response
 
@@ -982,6 +997,7 @@ class LLM(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=params["messages"],
+                message_id=message_id,
             )
             return structured_response
 
@@ -1013,6 +1029,7 @@ class LLM(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=params["messages"],
+                message_id=message_id,
             )
             return text_response
 
@@ -1022,7 +1039,7 @@ class LLM(BaseLLM):
 
         # --- 7) Handle tool calls if present
         tool_result = self._handle_tool_call(
-            tool_calls, available_functions, from_task, from_agent
+            tool_calls, available_functions, from_task, from_agent, message_id
        )
         if tool_result is not None:
             return tool_result
@@ -1033,6 +1050,7 @@ class LLM(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=params["messages"],
+                message_id=message_id,
             )
             return text_response
 
@@ -1042,6 +1060,7 @@ class LLM(BaseLLM):
         available_functions: dict[str, Any] | None = None,
         from_task: Task | None = None,
         from_agent: Agent | None = None,
+        message_id: str | None = None,
     ) -> Any:
         """Handle a tool call from the LLM.
 
@@ -1101,6 +1120,7 @@ class LLM(BaseLLM):
                 call_type=LLMCallType.TOOL_CALL,
                 from_task=from_task,
                 from_agent=from_agent,
+                message_id=message_id,
             )
             return result
         except Exception as e:
@@ -1111,7 +1131,7 @@ class LLM(BaseLLM):
             logging.error(f"Error executing function '{function_name}': {e}")
             crewai_event_bus.emit(
                 self,
-                event=LLMCallFailedEvent(error=f"Tool execution error: {e!s}"),
+                event=LLMCallFailedEvent(error=f"Tool execution error: {e!s}", message_id=message_id),
             )
             crewai_event_bus.emit(
                 self,
@@ -1161,6 +1181,8 @@ class LLM(BaseLLM):
             ValueError: If response format is not supported
             LLMContextLengthExceededError: If input exceeds model's context limit
         """
+        message_id = uuid.uuid4().hex
+
         crewai_event_bus.emit(
             self,
             event=LLMCallStartedEvent(
@@ -1171,6 +1193,7 @@ class LLM(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
                 model=self.model,
+                message_id=message_id,
             ),
         )
 
@@ -1202,6 +1225,7 @@ class LLM(BaseLLM):
                     from_task=from_task,
                     from_agent=from_agent,
                     response_model=response_model,
+                    message_id=message_id,
                 )
 
             return self._handle_non_streaming_response(
@@ -1211,6 +1235,7 @@ class LLM(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
                 response_model=response_model,
+                message_id=message_id,
             )
         except LLMContextLengthExceededError:
             # Re-raise LLMContextLengthExceededError as it should be handled
@@ -1248,7 +1273,7 @@ class LLM(BaseLLM):
             crewai_event_bus.emit(
                 self,
                 event=LLMCallFailedEvent(
-                    error=str(e), from_task=from_task, from_agent=from_agent
+                    error=str(e), from_task=from_task, from_agent=from_agent, message_id=message_id
                 ),
             )
             raise
@@ -1260,6 +1285,7 @@ class LLM(BaseLLM):
         from_task: Task | None = None,
         from_agent: Agent | None = None,
         messages: str | list[LLMMessage] | None = None,
+        message_id: str | None = None,
     ) -> None:
         """Handle the events for the LLM call.
 
@@ -1269,6 +1295,7 @@ class LLM(BaseLLM):
             from_task: Optional task object
             from_agent: Optional agent object
            messages: Optional messages object
+            message_id: Optional message identifier
         """
         crewai_event_bus.emit(
             self,
@@ -1279,6 +1306,7 @@ class LLM(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
                 model=self.model,
+                message_id=message_id,
             ),
         )
 
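The hunks above follow a plain correlation-id pattern: LLM.call() generates the id once with uuid.uuid4().hex, and every helper receives it as an explicit parameter instead of recreating it, so each emitted event carries the same value. A stripped-down sketch of that pattern, independent of crewAI's actual classes (Event, emit, and _stream are illustrative names, not the library's API):

import uuid
from dataclasses import dataclass, field
from typing import Any


@dataclass
class Event:
    kind: str
    message_id: str | None = None
    payload: dict[str, Any] = field(default_factory=dict)


EMITTED: list[Event] = []


def emit(event: Event) -> None:
    # Stand-in for crewai_event_bus.emit(); just records events for inspection.
    EMITTED.append(event)


def call(prompt: str) -> str:
    # One id per call; every event below reuses it.
    message_id = uuid.uuid4().hex
    emit(Event("started", message_id, {"prompt": prompt}))
    text = _stream(prompt, message_id)
    emit(Event("completed", message_id, {"response": text}))
    return text


def _stream(prompt: str, message_id: str | None) -> str:
    # Helpers take the id explicitly rather than regenerating it, so chunk
    # events stay attributable to the originating call.
    chunks = [prompt[: len(prompt) // 2], prompt[len(prompt) // 2 :]]
    for chunk in chunks:
        emit(Event("chunk", message_id, {"chunk": chunk}))
    return "".join(chunks)


if __name__ == "__main__":
    call("hello world")
    assert len({e.message_id for e in EMITTED}) == 1  # all events share one id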
lib/crewai/tests/test_llm_message_id.py (new file, 178 lines)
@@ -0,0 +1,178 @@
import threading
from unittest.mock import Mock, patch

import pytest

from crewai.agent import Agent
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.llm_events import (
    LLMCallCompletedEvent,
    LLMCallStartedEvent,
    LLMStreamChunkEvent,
)
from crewai.llm import LLM
from crewai.task import Task


@pytest.fixture
def base_agent():
    return Agent(
        role="test_agent",
        llm="gpt-4o-mini",
        goal="Test message_id",
        backstory="You are a test assistant",
    )


@pytest.fixture
def base_task(base_agent):
    return Task(
        description="Test message_id",
        expected_output="test",
        agent=base_agent,
    )


def test_llm_events_have_unique_message_ids_for_different_calls(base_agent, base_task):
    """Test that different LLM calls have different message_ids"""
    received_events = []
    event_received = threading.Event()

    @crewai_event_bus.on(LLMCallStartedEvent)
    def handle_llm_started(source, event):
        received_events.append(event)
        if len(received_events) >= 2:
            event_received.set()

    llm = LLM(model="gpt-4o-mini")

    with patch("litellm.completion") as mock_completion:
        mock_completion.return_value = Mock(
            choices=[Mock(message=Mock(content="Response 1", tool_calls=None))],
            usage=Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30),
        )

        llm.call("Test message 1", from_task=base_task, from_agent=base_agent)
        llm.call("Test message 2", from_task=base_task, from_agent=base_agent)

    assert event_received.wait(timeout=5), "Timeout waiting for LLM started events"
    assert len(received_events) >= 2
    assert received_events[0].message_id is not None
    assert received_events[1].message_id is not None
    assert received_events[0].message_id != received_events[1].message_id


def test_streaming_chunks_have_same_message_id(base_agent, base_task):
    """Test that all chunks from the same streaming call have the same message_id"""
    received_events = []
    lock = threading.Lock()
    all_events_received = threading.Event()

    @crewai_event_bus.on(LLMStreamChunkEvent)
    def handle_stream_chunk(source, event):
        with lock:
            received_events.append(event)
            if len(received_events) >= 3:
                all_events_received.set()

    llm = LLM(model="gpt-4o-mini", stream=True)

    def mock_stream_generator():
        yield Mock(
            choices=[Mock(delta=Mock(content="Hello", tool_calls=None))],
            usage=None,
        )
        yield Mock(
            choices=[Mock(delta=Mock(content=" ", tool_calls=None))],
            usage=None,
        )
        yield Mock(
            choices=[Mock(delta=Mock(content="World", tool_calls=None))],
            usage=Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30),
        )

    with patch("litellm.completion", return_value=mock_stream_generator()):
        llm.call("Test streaming", from_task=base_task, from_agent=base_agent)

    assert all_events_received.wait(timeout=5), "Timeout waiting for stream chunk events"
    assert len(received_events) >= 3

    message_ids = [event.message_id for event in received_events]
    assert all(mid is not None for mid in message_ids)
    assert len(set(message_ids)) == 1, "All chunks should have the same message_id"


def test_completed_event_has_same_message_id_as_started(base_agent, base_task):
    """Test that Started and Completed events have the same message_id"""
    received_events = {"started": None, "completed": None}
    lock = threading.Lock()
    all_events_received = threading.Event()

    @crewai_event_bus.on(LLMCallStartedEvent)
    def handle_started(source, event):
        with lock:
            received_events["started"] = event
            if received_events["completed"] is not None:
                all_events_received.set()

    @crewai_event_bus.on(LLMCallCompletedEvent)
    def handle_completed(source, event):
        with lock:
            received_events["completed"] = event
            if received_events["started"] is not None:
                all_events_received.set()

    llm = LLM(model="gpt-4o-mini")

    with patch("litellm.completion") as mock_completion:
        mock_completion.return_value = Mock(
            choices=[Mock(message=Mock(content="Response", tool_calls=None))],
            usage=Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30),
        )

        llm.call("Test message", from_task=base_task, from_agent=base_agent)

    assert all_events_received.wait(timeout=5), "Timeout waiting for events"
    assert received_events["started"] is not None
    assert received_events["completed"] is not None
    assert received_events["started"].message_id is not None
    assert received_events["completed"].message_id is not None
    assert received_events["started"].message_id == received_events["completed"].message_id


def test_multiple_calls_same_agent_task_have_different_message_ids(base_agent, base_task):
    """Test that multiple calls from the same agent/task have different message_ids"""
    received_started_events = []
    lock = threading.Lock()
    all_events_received = threading.Event()

    @crewai_event_bus.on(LLMCallStartedEvent)
    def handle_started(source, event):
        with lock:
            received_started_events.append(event)
            if len(received_started_events) >= 3:
                all_events_received.set()

    llm = LLM(model="gpt-4o-mini")

    with patch("litellm.completion") as mock_completion:
        mock_completion.return_value = Mock(
            choices=[Mock(message=Mock(content="Response", tool_calls=None))],
            usage=Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30),
        )

        llm.call("Message 1", from_task=base_task, from_agent=base_agent)
        llm.call("Message 2", from_task=base_task, from_agent=base_agent)
        llm.call("Message 3", from_task=base_task, from_agent=base_agent)

    assert all_events_received.wait(timeout=5), "Timeout waiting for events"
    assert len(received_started_events) >= 3

    message_ids = [event.message_id for event in received_started_events]
    assert all(mid is not None for mid in message_ids)
    assert len(set(message_ids)) == 3, "Each call should have a unique message_id"

    task_ids = [event.task_id for event in received_started_events]
    agent_ids = [event.agent_id for event in received_started_events]
    assert len(set(task_ids)) == 1, "All calls should have the same task_id"
    assert len(set(agent_ids)) == 1, "All calls should have the same agent_id"