Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-10 08:38:30 +00:00)
feat: add capability to track LLM calls by task and agent (#3087)
* feat: add capability to track LLM calls by task and agent. This makes it possible to filter or scope LLM events by specific agents or tasks, which is useful for debugging or analytics in real-time applications.
* feat: add docs about LLM tracking by Agents and Tasks
* fix: incompatible BaseLLM.call method signature
* feat: support filtering LLM events from LiteAgent
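With the new from_task / from_agent metadata, event listeners can be scoped to a single agent or task. A minimal sketch of how this might be consumed (the handler name and role string are illustrative, and it assumes the decorator-based listener registration exposed by crewai_event_bus):

from crewai.utilities.events import crewai_event_bus
from crewai.utilities.events.llm_events import LLMCallCompletedEvent

@crewai_event_bus.on(LLMCallCompletedEvent)
def on_researcher_llm_call(source, event):
    # agent_role / agent_id / task_id are populated from from_agent / from_task.
    if event.agent_role == "Researcher":
        print(f"LLM call completed for task {event.task_id} by agent {event.agent_id}")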
@@ -775,6 +775,7 @@ class Agent(BaseAgent):
             LiteAgentOutput: The result of the agent execution.
         """
         lite_agent = LiteAgent(
+            id=self.id,
             role=self.role,
             goal=self.goal,
             backstory=self.backstory,
@@ -159,6 +159,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
                 messages=self.messages,
                 callbacks=self.callbacks,
                 printer=self._printer,
+                from_task=self.task
             )
             formatted_answer = process_llm_response(answer, self.use_stop_words)
@@ -15,12 +15,14 @@ from typing import (
     get_origin,
 )

 try:
     from typing import Self
 except ImportError:
     from typing_extensions import Self

 from pydantic import (
+    UUID4,
     BaseModel,
     Field,
     InstanceOf,
@@ -129,6 +131,7 @@ class LiteAgent(FlowTrackable, BaseModel):
     model_config = {"arbitrary_types_allowed": True}

     # Core Agent Properties
+    id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
     role: str = Field(description="Role of the agent")
     goal: str = Field(description="Goal of the agent")
     backstory: str = Field(description="Backstory of the agent")
@@ -517,6 +520,7 @@ class LiteAgent(FlowTrackable, BaseModel):
                 messages=self._messages,
                 tools=None,
                 callbacks=self._callbacks,
+                from_agent=self,
             ),
         )
@@ -526,6 +530,7 @@ class LiteAgent(FlowTrackable, BaseModel):
                 messages=self._messages,
                 callbacks=self._callbacks,
                 printer=self._printer,
+                from_agent=self,
             )

             # Emit LLM call completed event
@@ -534,13 +539,14 @@ class LiteAgent(FlowTrackable, BaseModel):
                 event=LLMCallCompletedEvent(
                     response=answer,
                     call_type=LLMCallType.LLM_CALL,
+                    from_agent=self,
                 ),
             )
         except Exception as e:
             # Emit LLM call failed event
             crewai_event_bus.emit(
                 self,
-                event=LLMCallFailedEvent(error=str(e)),
+                event=LLMCallFailedEvent(error=str(e), from_agent=self),
             )
             raise e
@@ -419,6 +419,8 @@ class LLM(BaseLLM):
         params: Dict[str, Any],
         callbacks: Optional[List[Any]] = None,
         available_functions: Optional[Dict[str, Any]] = None,
+        from_task: Optional[Any] = None,
+        from_agent: Optional[Any] = None,
     ) -> str:
         """Handle a streaming response from the LLM.
@@ -426,6 +428,8 @@ class LLM(BaseLLM):
             params: Parameters for the completion call
             callbacks: Optional list of callback functions
             available_functions: Dict of available functions
+            from_task: Optional task object
+            from_agent: Optional agent object

         Returns:
             str: The complete response text
@@ -510,6 +514,8 @@ class LLM(BaseLLM):
                         tool_calls=tool_calls,
                         accumulated_tool_args=accumulated_tool_args,
                         available_functions=available_functions,
+                        from_task=from_task,
+                        from_agent=from_agent,
                     )
                     if result is not None:
                         chunk_content = result
@@ -527,7 +533,7 @@ class LLM(BaseLLM):
                     assert hasattr(crewai_event_bus, "emit")
                     crewai_event_bus.emit(
                         self,
-                        event=LLMStreamChunkEvent(chunk=chunk_content),
+                        event=LLMStreamChunkEvent(chunk=chunk_content, from_task=from_task, from_agent=from_agent),
                     )
             # --- 4) Fallback to non-streaming if no content received
             if not full_response.strip() and chunk_count == 0:
@@ -540,7 +546,7 @@ class LLM(BaseLLM):
                     "stream_options", None
                 )  # Remove stream_options for non-streaming call
                 return self._handle_non_streaming_response(
-                    non_streaming_params, callbacks, available_functions
+                    non_streaming_params, callbacks, available_functions, from_task, from_agent
                 )

             # --- 5) Handle empty response with chunks
@@ -625,7 +631,7 @@ class LLM(BaseLLM):
                 # Log token usage if available in streaming mode
                 self._handle_streaming_callbacks(callbacks, usage_info, last_chunk)
                 # Emit completion event and return response
-                self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL)
+                self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL, from_task, from_agent)
                 return full_response

             # --- 9) Handle tool calls if present
@@ -637,7 +643,7 @@ class LLM(BaseLLM):
             self._handle_streaming_callbacks(callbacks, usage_info, last_chunk)

             # --- 11) Emit completion event and return response
-            self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL)
+            self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL, from_task, from_agent)
             return full_response

         except ContextWindowExceededError as e:
@@ -649,14 +655,14 @@ class LLM(BaseLLM):
             logging.error(f"Error in streaming response: {str(e)}")
             if full_response.strip():
                 logging.warning(f"Returning partial response despite error: {str(e)}")
-                self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL)
+                self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL, from_task, from_agent)
                 return full_response

             # Emit failed event and re-raise the exception
             assert hasattr(crewai_event_bus, "emit")
             crewai_event_bus.emit(
                 self,
-                event=LLMCallFailedEvent(error=str(e)),
+                event=LLMCallFailedEvent(error=str(e), from_task=from_task, from_agent=from_agent),
             )
             raise Exception(f"Failed to get streaming response: {str(e)}")
@@ -665,6 +671,8 @@ class LLM(BaseLLM):
         tool_calls: List[ChatCompletionDeltaToolCall],
         accumulated_tool_args: DefaultDict[int, AccumulatedToolArgs],
         available_functions: Optional[Dict[str, Any]] = None,
+        from_task: Optional[Any] = None,
+        from_agent: Optional[Any] = None,
     ) -> None | str:
         for tool_call in tool_calls:
             current_tool_accumulator = accumulated_tool_args[tool_call.index]
@@ -682,6 +690,8 @@ class LLM(BaseLLM):
                     event=LLMStreamChunkEvent(
                         tool_call=tool_call.to_dict(),
                         chunk=tool_call.function.arguments,
+                        from_task=from_task,
+                        from_agent=from_agent,
                     ),
                 )
@@ -748,6 +758,8 @@ class LLM(BaseLLM):
         params: Dict[str, Any],
         callbacks: Optional[List[Any]] = None,
         available_functions: Optional[Dict[str, Any]] = None,
+        from_task: Optional[Any] = None,
+        from_agent: Optional[Any] = None,
     ) -> str:
         """Handle a non-streaming response from the LLM.
@@ -755,6 +767,8 @@ class LLM(BaseLLM):
             params: Parameters for the completion call
             callbacks: Optional list of callback functions
             available_functions: Dict of available functions
+            from_task: Optional Task that invoked the LLM
+            from_agent: Optional Agent that invoked the LLM

         Returns:
             str: The response text
@@ -795,7 +809,7 @@ class LLM(BaseLLM):

         # --- 5) If no tool calls or no available functions, return the text response directly
         if not tool_calls or not available_functions:
-            self._handle_emit_call_events(text_response, LLMCallType.LLM_CALL)
+            self._handle_emit_call_events(text_response, LLMCallType.LLM_CALL, from_task, from_agent)
             return text_response

         # --- 6) Handle tool calls if present
@@ -804,7 +818,7 @@ class LLM(BaseLLM):
             return tool_result

         # --- 7) If tool call handling didn't return a result, emit completion event and return text response
-        self._handle_emit_call_events(text_response, LLMCallType.LLM_CALL)
+        self._handle_emit_call_events(text_response, LLMCallType.LLM_CALL, from_task, from_agent)
         return text_response

     def _handle_tool_call(
@@ -889,6 +903,8 @@ class LLM(BaseLLM):
         tools: Optional[List[dict]] = None,
         callbacks: Optional[List[Any]] = None,
         available_functions: Optional[Dict[str, Any]] = None,
+        from_task: Optional[Any] = None,
+        from_agent: Optional[Any] = None,
     ) -> Union[str, Any]:
         """High-level LLM call method.
@@ -903,6 +919,8 @@ class LLM(BaseLLM):
                 during and after the LLM call.
             available_functions: Optional dict mapping function names to callables
                 that can be invoked by the LLM.
+            from_task: Optional Task that invoked the LLM
+            from_agent: Optional Agent that invoked the LLM

         Returns:
             Union[str, Any]: Either a text response from the LLM (str) or
@@ -922,6 +940,8 @@ class LLM(BaseLLM):
                 tools=tools,
                 callbacks=callbacks,
                 available_functions=available_functions,
+                from_task=from_task,
+                from_agent=from_agent,
             ),
         )
@@ -950,11 +970,11 @@ class LLM(BaseLLM):
             # --- 7) Make the completion call and handle response
             if self.stream:
                 return self._handle_streaming_response(
-                    params, callbacks, available_functions
+                    params, callbacks, available_functions, from_task, from_agent
                 )
             else:
                 return self._handle_non_streaming_response(
-                    params, callbacks, available_functions
+                    params, callbacks, available_functions, from_task, from_agent
                 )

         except LLMContextLengthExceededException:
@@ -966,12 +986,12 @@ class LLM(BaseLLM):
             assert hasattr(crewai_event_bus, "emit")
             crewai_event_bus.emit(
                 self,
-                event=LLMCallFailedEvent(error=str(e)),
+                event=LLMCallFailedEvent(error=str(e), from_task=from_task, from_agent=from_agent),
             )
             logging.error(f"LiteLLM call failed: {str(e)}")
             raise

-    def _handle_emit_call_events(self, response: Any, call_type: LLMCallType):
+    def _handle_emit_call_events(self, response: Any, call_type: LLMCallType, from_task: Optional[Any] = None, from_agent: Optional[Any] = None):
         """Handle the events for the LLM call.

         Args:
@@ -981,7 +1001,7 @@ class LLM(BaseLLM):
         assert hasattr(crewai_event_bus, "emit")
         crewai_event_bus.emit(
             self,
-            event=LLMCallCompletedEvent(response=response, call_type=call_type),
+            event=LLMCallCompletedEvent(response=response, call_type=call_type, from_task=from_task, from_agent=from_agent),
         )

     def _format_messages_for_provider(
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union


 class BaseLLM(ABC):
@@ -47,6 +47,8 @@ class BaseLLM(ABC):
         tools: Optional[List[dict]] = None,
         callbacks: Optional[List[Any]] = None,
         available_functions: Optional[Dict[str, Any]] = None,
+        from_task: Optional[Any] = None,
+        from_agent: Optional[Any] = None,
     ) -> Union[str, Any]:
         """Call the LLM with the given messages.
@@ -61,6 +63,7 @@ class BaseLLM(ABC):
                 during and after the LLM call.
             available_functions: Optional dict mapping function names to callables
                 that can be invoked by the LLM.
+            from_task: Optional task caller to be used for the LLM call.

         Returns:
             Either a text response from the LLM (str) or
src/crewai/llms/third_party/ai_suite.py (vendored, 2 lines changed)
@@ -16,6 +16,8 @@ class AISuiteLLM(BaseLLM):
         tools: Optional[List[dict]] = None,
         callbacks: Optional[List[Any]] = None,
         available_functions: Optional[Dict[str, Any]] = None,
+        from_task: Optional[Any] = None,
+        from_agent: Optional[Any] = None,
     ) -> Union[str, Any]:
         completion_params = self._prepare_completion_params(messages, tools)
         response = self.client.chat.completions.create(**completion_params)
@@ -145,12 +145,16 @@ def get_llm_response(
     messages: List[Dict[str, str]],
     callbacks: List[Any],
     printer: Printer,
+    from_task: Optional[Any] = None,
+    from_agent: Optional[Any] = None,
 ) -> str:
     """Call the LLM and return the response, handling any invalid responses."""
     try:
         answer = llm.call(
             messages,
             callbacks=callbacks,
+            from_task=from_task,
+            from_agent=from_agent,
         )
     except Exception as e:
         printer.print(
@@ -5,6 +5,32 @@ from pydantic import BaseModel

 from crewai.utilities.events.base_events import BaseEvent

+class LLMEventBase(BaseEvent):
+    task_name: Optional[str] = None
+    task_id: Optional[str] = None
+
+    agent_id: Optional[str] = None
+    agent_role: Optional[str] = None
+
+    def __init__(self, **data):
+        super().__init__(**data)
+        self._set_agent_params(data)
+        self._set_task_params(data)
+
+    def _set_agent_params(self, data: Dict[str, Any]):
+        task = data.get("from_task", None)
+        agent = task.agent if task else data.get("from_agent", None)
+
+        if not agent:
+            return
+
+        self.agent_id = agent.id
+        self.agent_role = agent.role
+
+    def _set_task_params(self, data: Dict[str, Any]):
+        if "from_task" in data and (task := data["from_task"]):
+            self.task_id = task.id
+            self.task_name = task.name
+

 class LLMCallType(Enum):
     """Type of LLM call being made"""
@@ -13,7 +39,7 @@ class LLMCallType(Enum):
     LLM_CALL = "llm_call"


-class LLMCallStartedEvent(BaseEvent):
+class LLMCallStartedEvent(LLMEventBase):
     """Event emitted when a LLM call starts

     Attributes:
@@ -28,7 +54,7 @@ class LLMCallStartedEvent(BaseEvent):
     available_functions: Optional[Dict[str, Any]] = None


-class LLMCallCompletedEvent(BaseEvent):
+class LLMCallCompletedEvent(LLMEventBase):
     """Event emitted when a LLM call completes"""

     type: str = "llm_call_completed"
@@ -36,7 +62,7 @@ class LLMCallCompletedEvent(BaseEvent):
     call_type: LLMCallType


-class LLMCallFailedEvent(BaseEvent):
+class LLMCallFailedEvent(LLMEventBase):
     """Event emitted when a LLM call fails"""

     error: str
@@ -55,7 +81,7 @@ class ToolCall(BaseModel):
     index: int


-class LLMStreamChunkEvent(BaseEvent):
+class LLMStreamChunkEvent(LLMEventBase):
     """Event emitted when a streaming chunk is received"""

     type: str = "llm_stream_chunk"
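Because LLMStreamChunkEvent now inherits the same task/agent metadata, streaming output can be filtered per task as well. A rough usage sketch (the task name and handler are illustrative, assuming the same decorator-based listener registration as in the example above):

from crewai.utilities.events import crewai_event_bus
from crewai.utilities.events.llm_events import LLMStreamChunkEvent

@crewai_event_bus.on(LLMStreamChunkEvent)
def stream_research_task_only(source, event):
    # Ignore chunks emitted on behalf of other tasks.
    if event.task_name == "research_task":
        print(event.chunk, end="", flush=True)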