Mirror of https://github.com/crewAIInc/crewAI.git
Synced 2026-04-01 16:48:15 +00:00

Compare commits: gl/refacto ... fix/trace- (15 commits)
| SHA1 |
|---|
| 0cd27790fd |
| 68e943be68 |
| 8388169a56 |
| 3283a00e31 |
| 5de23b867c |
| 8edd8b3355 |
| 2af6a531f5 |
| c0d6d2b63f |
| 3e0c750f51 |
| 416f01fe23 |
| da65ca2502 |
| 47f192e112 |
| 19d1088bab |
| 1faee0c684 |
| 6da1c5f964 |
@@ -43,7 +43,7 @@ dependencies = [
     "uv~=0.9.13",
     "aiosqlite~=0.21.0",
     "pyyaml~=6.0",
-    "lancedb>=0.29.2",
+    "lancedb>=0.29.2,<0.30.1",
 ]

 [project.urls]
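The only dependency change is a new upper bound on lancedb: the floor stays at 0.29.2, while releases from 0.30.1 onward are excluded.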
@@ -6,7 +6,6 @@ import warnings

 from crewai.agent.core import Agent
 from crewai.agent.planning_config import PlanningConfig
-from crewai.agents.crew_agent_executor import CrewAgentExecutor
 from crewai.crew import Crew
 from crewai.crews.crew_output import CrewOutput
 from crewai.flow.flow import Flow
@@ -20,9 +19,6 @@ from crewai.tasks.task_output import TaskOutput
 from crewai.telemetry.telemetry import Telemetry


-CrewAgentExecutor.model_rebuild()
-
-
 def _suppress_pydantic_deprecation_warnings() -> None:
     """Suppress Pydantic deprecation warnings using targeted monkey patch."""
     original_warn = warnings.warn
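Both removals track the executor refactor in the next file: once `CrewAgentExecutor` stops being a Pydantic model, the package `__init__` has no reason to import it eagerly or to call `model_rebuild()` (which exists to resolve forward references on Pydantic models) at import time.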
@@ -14,15 +14,8 @@ import inspect
 import logging
 from typing import TYPE_CHECKING, Any, Literal, cast

-from pydantic import (
-    BaseModel,
-    ConfigDict,
-    Field,
-    PrivateAttr,
-    ValidationError,
-    model_validator,
-)
-from typing_extensions import Self
+from pydantic import BaseModel, GetCoreSchemaHandler, ValidationError
+from pydantic_core import CoreSchema, core_schema

 from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
 from crewai.agents.parser import (
@@ -30,7 +23,6 @@ from crewai.agents.parser import (
     AgentFinish,
     OutputParserError,
 )
-from crewai.agents.tools_handler import ToolsHandler
 from crewai.core.providers.human_input import ExecutorContext, get_provider
 from crewai.events.event_bus import crewai_event_bus
 from crewai.events.types.logging_events import (
@@ -46,9 +38,6 @@ from crewai.hooks.tool_hooks import (
     get_after_tool_call_hooks,
     get_before_tool_call_hooks,
 )
-from crewai.llms.base_llm import BaseLLM
-from crewai.tools.base_tool import BaseTool
-from crewai.tools.structured_tool import CrewStructuredTool
 from crewai.utilities.agent_utils import (
     aget_llm_response,
     convert_tools_to_openai_schema,
@@ -70,65 +59,106 @@ from crewai.utilities.constants import TRAINING_DATA_FILE
 from crewai.utilities.file_store import aget_all_files, get_all_files
 from crewai.utilities.i18n import I18N, get_i18n
 from crewai.utilities.printer import Printer
-from crewai.utilities.prompts import StandardPromptResult, SystemPromptResult
 from crewai.utilities.string_utils import sanitize_tool_name
 from crewai.utilities.tool_utils import (
     aexecute_tool_and_check_finality,
     execute_tool_and_check_finality,
 )
 from crewai.utilities.training_handler import CrewTrainingHandler
-from crewai.utilities.types import LLMMessage


 logger = logging.getLogger(__name__)

 if TYPE_CHECKING:
     from crewai.agent import Agent
+    from crewai.agents.tools_handler import ToolsHandler
     from crewai.crew import Crew
+    from crewai.llms.base_llm import BaseLLM
     from crewai.task import Task
+    from crewai.tools.base_tool import BaseTool
+    from crewai.tools.structured_tool import CrewStructuredTool
     from crewai.tools.tool_types import ToolResult
+    from crewai.utilities.prompts import StandardPromptResult, SystemPromptResult
+    from crewai.utilities.types import LLMMessage


-class CrewAgentExecutor(BaseModel, CrewAgentExecutorMixin):
+class CrewAgentExecutor(CrewAgentExecutorMixin):
     """Executor for crew agents.

     Manages the execution lifecycle of an agent including prompt formatting,
     LLM interactions, tool execution, and feedback handling.
     """

-    model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)
-
-    llm: BaseLLM
-    task: Task | None = None
-    crew: Crew | None = None
-    agent: Agent
-    prompt: SystemPromptResult | StandardPromptResult
-    max_iter: int
-    tools: list[CrewStructuredTool]
-    tools_names: str
-    stop: list[str] = Field(alias="stop_words")
-    tools_description: str
-    tools_handler: ToolsHandler
-    step_callback: Any = None
-    original_tools: list[BaseTool] = Field(default_factory=list)
-    function_calling_llm: BaseLLM | Any | None = None
-    respect_context_window: bool = False
-    request_within_rpm_limit: Callable[[], bool] | None = None
-    callbacks: list[Any] = Field(default_factory=list)
-    response_model: type[BaseModel] | None = None
-    i18n: I18N | None = Field(default=None, exclude=True)
-    ask_for_human_input: bool = False
-    messages: list[LLMMessage] = Field(default_factory=list)
-    iterations: int = 0
-    log_error_after: int = 3
-    before_llm_call_hooks: list[Callable[..., Any]] = Field(default_factory=list)
-    after_llm_call_hooks: list[Callable[..., Any]] = Field(default_factory=list)
-    _i18n: I18N = PrivateAttr()
-    _printer: Printer = PrivateAttr(default_factory=Printer)
-
-    @model_validator(mode="after")
-    def _init_executor(self) -> Self:
-        self._i18n = self.i18n or get_i18n()
+    def __init__(
+        self,
+        llm: BaseLLM,
+        task: Task,
+        crew: Crew,
+        agent: Agent,
+        prompt: SystemPromptResult | StandardPromptResult,
+        max_iter: int,
+        tools: list[CrewStructuredTool],
+        tools_names: str,
+        stop_words: list[str],
+        tools_description: str,
+        tools_handler: ToolsHandler,
+        step_callback: Any = None,
+        original_tools: list[BaseTool] | None = None,
+        function_calling_llm: BaseLLM | Any | None = None,
+        respect_context_window: bool = False,
+        request_within_rpm_limit: Callable[[], bool] | None = None,
+        callbacks: list[Any] | None = None,
+        response_model: type[BaseModel] | None = None,
+        i18n: I18N | None = None,
+    ) -> None:
+        """Initialize executor.
+
+        Args:
+            llm: Language model instance.
+            task: Task to execute.
+            crew: Crew instance.
+            agent: Agent to execute.
+            prompt: Prompt templates.
+            max_iter: Maximum iterations.
+            tools: Available tools.
+            tools_names: Tool names string.
+            stop_words: Stop word list.
+            tools_description: Tool descriptions.
+            tools_handler: Tool handler instance.
+            step_callback: Optional step callback.
+            original_tools: Original tool list.
+            function_calling_llm: Optional function calling LLM.
+            respect_context_window: Respect context limits.
+            request_within_rpm_limit: RPM limit check function.
+            callbacks: Optional callbacks list.
+            response_model: Optional Pydantic model for structured outputs.
+        """
+        self._i18n: I18N = i18n or get_i18n()
+        self.llm = llm
+        self.task = task
+        self.agent = agent
+        self.crew = crew
+        self.prompt = prompt
+        self.tools = tools
+        self.tools_names = tools_names
+        self.stop = stop_words
+        self.max_iter = max_iter
+        self.callbacks = callbacks or []
+        self._printer: Printer = Printer()
+        self.tools_handler = tools_handler
+        self.original_tools = original_tools or []
+        self.step_callback = step_callback
+        self.tools_description = tools_description
+        self.function_calling_llm = function_calling_llm
+        self.respect_context_window = respect_context_window
+        self.request_within_rpm_limit = request_within_rpm_limit
+        self.response_model = response_model
+        self.ask_for_human_input = False
+        self.messages: list[LLMMessage] = []
+        self.iterations = 0
+        self.log_error_after = 3
+        self.before_llm_call_hooks: list[Callable[..., Any]] = []
+        self.after_llm_call_hooks: list[Callable[..., Any]] = []
+        self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
+        self.after_llm_call_hooks.extend(get_after_llm_call_hooks())
         if self.llm:
@@ -141,7 +171,6 @@ class CrewAgentExecutor(BaseModel, CrewAgentExecutorMixin):
                 else self.stop
             )
         )
-        return self

     @property
     def use_stop_words(self) -> bool:
@@ -1658,3 +1687,14 @@ class CrewAgentExecutor(BaseModel, CrewAgentExecutorMixin):
         return format_message_for_llm(
             self._i18n.slice("feedback_instructions").format(feedback=feedback)
         )
+
+    @classmethod
+    def __get_pydantic_core_schema__(
+        cls, _source_type: Any, _handler: GetCoreSchemaHandler
+    ) -> CoreSchema:
+        """Generate Pydantic core schema for BaseClient Protocol.
+
+        This allows the Protocol to be used in Pydantic models without
+        requiring arbitrary_types_allowed=True.
+        """
+        return core_schema.any_schema()
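The new `__get_pydantic_core_schema__` hook is what lets the now-plain executor keep appearing as a field type inside Pydantic models (note the docstring still says "BaseClient Protocol", apparently carried over from wherever this was copied). A minimal, self-contained sketch of the technique, assuming Pydantic v2; `Executor` and `Holder` are illustrative names, not crewAI classes:

```python
from typing import Any

from pydantic import BaseModel, GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema


class Executor:
    """A plain, non-Pydantic class."""

    @classmethod
    def __get_pydantic_core_schema__(
        cls, _source_type: Any, _handler: GetCoreSchemaHandler
    ) -> CoreSchema:
        # Treat values of this type as opaque: accept anything, validate nothing.
        return core_schema.any_schema()


class Holder(BaseModel):
    executor: Executor  # accepted without arbitrary_types_allowed=True


holder = Holder(executor=Executor())  # validates fine
```

The trade-off of `any_schema()` is that Pydantic performs no type check at all on the field, so `Holder(executor="oops")` would also pass; the hook buys convenience, not safety.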
@@ -2,14 +2,56 @@ from collections.abc import Iterator
 import contextvars
 from datetime import datetime, timezone
 import itertools
-from typing import Any
+from typing import Any, TypedDict
 import uuid

-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, SerializationInfo

 from crewai.utilities.serialization import Serializable, to_serializable


+def _is_trace_context(info: SerializationInfo) -> bool:
+    """Check if serialization is happening in trace context."""
+    return bool(info.context and info.context.get("trace"))
+
+
+class AgentRef(TypedDict):
+    id: str
+    role: str
+
+
+class TaskRef(TypedDict):
+    id: str
+    name: str
+
+
+def _trace_agent_ref(agent: Any) -> AgentRef | None:
+    """Return a lightweight agent reference for trace serialization."""
+    if agent is None:
+        return None
+    return AgentRef(
+        id=str(getattr(agent, "id", "")),
+        role=getattr(agent, "role", ""),
+    )
+
+
+def _trace_task_ref(task: Any) -> TaskRef | None:
+    """Return a lightweight task reference for trace serialization."""
+    if task is None:
+        return None
+    return TaskRef(
+        id=str(getattr(task, "id", "")),
+        name=str(getattr(task, "name", None) or getattr(task, "description", "")),
+    )
+
+
+def _trace_tool_names(tools: Any) -> list[str] | None:
+    """Return a list of tool names for trace serialization."""
+    if not tools:
+        return None
+    return [getattr(t, "name", str(t)) for t in tools]
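These helpers exist to serve Pydantic v2's serialization-context mechanism, which the event classes below rely on: a `field_serializer` receives `SerializationInfo`, inspects `info.context`, and substitutes a small reference for a heavyweight nested object only when the dump was requested with `context={"trace": True}`. A runnable sketch of the pattern, assuming Pydantic 2.7+ where `model_dump()` accepts `context`; the `Agent`/`Event` models here are stand-ins, not the crewAI ones:

```python
from typing import Any

from pydantic import BaseModel, SerializationInfo, field_serializer


class Agent(BaseModel):  # stand-in for a large nested model
    id: str
    role: str
    backstory: str


class Event(BaseModel):
    name: str
    agent: Agent | None = None

    @field_serializer("agent")
    def _serialize_agent(self, v: Any, info: SerializationInfo) -> Any:
        # Shrink the payload only when the caller opted in via context.
        if info.context and info.context.get("trace"):
            return None if v is None else {"id": v.id, "role": v.role}
        return v


event = Event(name="demo", agent=Agent(id="a1", role="writer", backstory="..." * 500))
full = event.model_dump()                         # nested agent, full backstory
slim = event.model_dump(context={"trace": True})  # agent -> {'id': 'a1', 'role': 'writer'}
```

The same dump call produces either shape, so no event class needs a second "trace" variant.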
 _emission_counter: contextvars.ContextVar[Iterator[int]] = contextvars.ContextVar(
     "_emission_counter"
 )
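Holding an iterator in a `ContextVar` is the standard recipe for a per-context monotonic sequence: each execution context lazily gets its own `itertools.count()`, so concurrent crews can number their emissions independently. A hedged sketch of how such a counter is typically initialized and consumed; the `next_emission_index` helper is illustrative, not part of this diff:

```python
import contextvars
import itertools
from collections.abc import Iterator

_emission_counter: contextvars.ContextVar[Iterator[int]] = contextvars.ContextVar(
    "_emission_counter"
)


def next_emission_index() -> int:
    """Return 0, 1, 2, ... within the current execution context."""
    try:
        counter = _emission_counter.get()
    except LookupError:  # first use in this context: start a fresh count
        counter = itertools.count()
        _emission_counter.set(counter)
    return next(counter)


print(next_emission_index())  # 0
print(next_emission_index())  # 1
```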
@@ -1,7 +1,7 @@
 """Trace collection listener for orchestrating trace collection."""

 import os
-from typing import Any, ClassVar
+from typing import Any
 import uuid

 from typing_extensions import Self
@@ -126,18 +126,13 @@ from crewai.events.types.tool_usage_events import (
 from crewai.events.utils.console_formatter import ConsoleFormatter


+_TRACE_CONTEXT: dict[str, bool] = {"trace": True}
+"""Serialization context that triggers lightweight field serializers on event models."""
+
+
 class TraceCollectionListener(BaseEventListener):
     """Trace collection listener that orchestrates trace collection."""

-    complex_events: ClassVar[list[str]] = [
-        "task_started",
-        "task_completed",
-        "llm_call_started",
-        "llm_call_completed",
-        "agent_execution_started",
-        "agent_execution_completed",
-    ]
-
     _instance: Self | None = None
     _initialized: bool = False
     _listeners_setup: bool = False
@@ -810,9 +805,19 @@ class TraceCollectionListener(BaseEventListener):
     def _build_event_data(
         self, event_type: str, event: Any, source: Any
     ) -> dict[str, Any]:
-        """Build event data"""
-        if event_type not in self.complex_events:
-            return safe_serialize_to_dict(event)
+        """Build event data with context-based serialization to reduce trace bloat.
+
+        Field serializers on event models check for context={"trace": True} and
+        return lightweight references instead of full nested objects. This replaces
+        the old denylist approach with Pydantic v2's native context mechanism.
+
+        Only crew_kickoff_started gets a full crew structure (built separately).
+        Complex events (task_started, etc.) use custom projections for specific shapes.
+        All other events get context-aware serialization automatically.
+        """
+        if event_type == "crew_kickoff_started":
+            return self._build_crew_started_data(event)
+
         if event_type == "task_started":
             task_name = event.task.name or event.task.description
             task_display_name = (
@@ -853,19 +858,77 @@ class TraceCollectionListener(BaseEventListener):
                 "agent_backstory": event.agent.backstory,
             }
         if event_type == "llm_call_started":
-            event_data = safe_serialize_to_dict(event)
+            event_data = safe_serialize_to_dict(event, context=_TRACE_CONTEXT)
             event_data["task_name"] = event.task_name or getattr(
                 event, "task_description", None
             )
             return event_data
         if event_type == "llm_call_completed":
-            return safe_serialize_to_dict(event)
+            return safe_serialize_to_dict(event, context=_TRACE_CONTEXT)

-        return {
-            "event_type": event_type,
-            "event": safe_serialize_to_dict(event),
-            "source": source,
-        }
+        return safe_serialize_to_dict(event, context=_TRACE_CONTEXT)
+
+    def _build_crew_started_data(self, event: Any) -> dict[str, Any]:
+        """Build comprehensive crew structure for crew_kickoff_started event.
+
+        This is the ONE place where we serialize the full crew structure.
+        Subsequent events use lightweight references via field serializers.
+        """
+        event_data = safe_serialize_to_dict(event, context=_TRACE_CONTEXT)
+
+        crew = getattr(event, "crew", None)
+        if crew is not None:
+            agents_data = []
+            for agent in getattr(crew, "agents", []) or []:
+                agent_data = {
+                    "id": str(getattr(agent, "id", "")),
+                    "role": getattr(agent, "role", ""),
+                    "goal": getattr(agent, "goal", ""),
+                    "backstory": getattr(agent, "backstory", ""),
+                    "verbose": getattr(agent, "verbose", False),
+                    "allow_delegation": getattr(agent, "allow_delegation", False),
+                    "max_iter": getattr(agent, "max_iter", None),
+                    "max_rpm": getattr(agent, "max_rpm", None),
+                }
+                tools = getattr(agent, "tools", None)
+                if tools:
+                    agent_data["tool_names"] = [
+                        getattr(t, "name", str(t)) for t in tools
+                    ]
+                agents_data.append(agent_data)
+
+            tasks_data = []
+            for task in getattr(crew, "tasks", []) or []:
+                task_data = {
+                    "id": str(getattr(task, "id", "")),
+                    "name": getattr(task, "name", None),
+                    "description": getattr(task, "description", ""),
+                    "expected_output": getattr(task, "expected_output", ""),
+                    "async_execution": getattr(task, "async_execution", False),
+                    "human_input": getattr(task, "human_input", False),
+                }
+                task_agent = getattr(task, "agent", None)
+                if task_agent:
+                    task_data["agent_ref"] = {
+                        "id": str(getattr(task_agent, "id", "")),
+                        "role": getattr(task_agent, "role", ""),
+                    }
+                context_tasks = getattr(task, "context", None)
+                if context_tasks:
+                    task_data["context_task_ids"] = [
+                        str(getattr(ct, "id", "")) for ct in context_tasks
+                    ]
+                tasks_data.append(task_data)
+
+            event_data["crew_structure"] = {
+                "agents": agents_data,
+                "tasks": tasks_data,
+                "process": str(getattr(crew, "process", "")),
+                "verbose": getattr(crew, "verbose", False),
+                "memory": getattr(crew, "memory", False),
+            }
+
+        return event_data

     def _show_tracing_disabled_message(self) -> None:
         """Show a message when tracing is disabled."""
@@ -429,10 +429,22 @@ def mark_first_execution_done(user_consented: bool = False) -> None:
     p.write_text(json.dumps(data, indent=2))


-def safe_serialize_to_dict(obj: Any, exclude: set[str] | None = None) -> dict[str, Any]:
-    """Safely serialize an object to a dictionary for event data."""
+def safe_serialize_to_dict(
+    obj: Any,
+    exclude: set[str] | None = None,
+    context: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    """Safely serialize an object to a dictionary for event data.
+
+    Args:
+        obj: Object to serialize.
+        exclude: Set of keys to exclude from the result.
+        context: Optional context dict passed through to Pydantic's model_dump().
+            Field serializers can inspect this to customize output
+            (e.g. context={"trace": True} for lightweight trace serialization).
+    """
     try:
-        serialized = to_serializable(obj, exclude)
+        serialized = to_serializable(obj, exclude, context=context)
         if isinstance(serialized, dict):
             return serialized
         return {"serialized_data": serialized}
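With `context` threaded through `to_serializable()` into Pydantic's `model_dump()`, the listener's `_TRACE_CONTEXT` dict becomes the single switch between the two payload shapes: `safe_serialize_to_dict(event)` keeps full nested objects, while `safe_serialize_to_dict(event, context=_TRACE_CONTEXT)` lets each event's field serializers collapse agents, tasks, and tools to small references. This relies on `model_dump(context=...)`, available since Pydantic 2.7.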
@@ -5,11 +5,17 @@ from __future__ import annotations
 from collections.abc import Sequence
 from typing import Any

-from pydantic import ConfigDict, model_validator
+from pydantic import ConfigDict, SerializationInfo, field_serializer, model_validator
 from typing_extensions import Self

 from crewai.agents.agent_builder.base_agent import BaseAgent
-from crewai.events.base_events import BaseEvent
+from crewai.events.base_events import (
+    BaseEvent,
+    _is_trace_context,
+    _trace_agent_ref,
+    _trace_task_ref,
+    _trace_tool_names,
+)
 from crewai.tools.base_tool import BaseTool
 from crewai.tools.structured_tool import CrewStructuredTool

@@ -31,6 +37,21 @@ class AgentExecutionStartedEvent(BaseEvent):
         _set_agent_fingerprint(self, self.agent)
         return self

+    @field_serializer("agent")
+    @classmethod
+    def _serialize_agent(cls, v: Any, info: SerializationInfo) -> Any:
+        return _trace_agent_ref(v) if _is_trace_context(info) else v
+
+    @field_serializer("task")
+    @classmethod
+    def _serialize_task(cls, v: Any, info: SerializationInfo) -> Any:
+        return _trace_task_ref(v) if _is_trace_context(info) else v
+
+    @field_serializer("tools")
+    @classmethod
+    def _serialize_tools(cls, v: Any, info: SerializationInfo) -> Any:
+        return _trace_tool_names(v) if _is_trace_context(info) else v
+

 class AgentExecutionCompletedEvent(BaseEvent):
     """Event emitted when an agent completes executing a task"""
@@ -48,6 +69,16 @@ class AgentExecutionCompletedEvent(BaseEvent):
         _set_agent_fingerprint(self, self.agent)
         return self

+    @field_serializer("agent")
+    @classmethod
+    def _serialize_agent(cls, v: Any, info: SerializationInfo) -> Any:
+        return _trace_agent_ref(v) if _is_trace_context(info) else v
+
+    @field_serializer("task")
+    @classmethod
+    def _serialize_task(cls, v: Any, info: SerializationInfo) -> Any:
+        return _trace_task_ref(v) if _is_trace_context(info) else v
+

 class AgentExecutionErrorEvent(BaseEvent):
     """Event emitted when an agent encounters an error during execution"""
@@ -65,6 +96,16 @@ class AgentExecutionErrorEvent(BaseEvent):
         _set_agent_fingerprint(self, self.agent)
         return self

+    @field_serializer("agent")
+    @classmethod
+    def _serialize_agent(cls, v: Any, info: SerializationInfo) -> Any:
+        return _trace_agent_ref(v) if _is_trace_context(info) else v
+
+    @field_serializer("task")
+    @classmethod
+    def _serialize_task(cls, v: Any, info: SerializationInfo) -> Any:
+        return _trace_task_ref(v) if _is_trace_context(info) else v
+

 # New event classes for LiteAgent
 class LiteAgentExecutionStartedEvent(BaseEvent):
@@ -77,6 +118,11 @@ class LiteAgentExecutionStartedEvent(BaseEvent):

     model_config = ConfigDict(arbitrary_types_allowed=True)

+    @field_serializer("tools")
+    @classmethod
+    def _serialize_tools(cls, v: Any, info: SerializationInfo) -> Any:
+        return _trace_tool_names(v) if _is_trace_context(info) else v
+

 class LiteAgentExecutionCompletedEvent(BaseEvent):
     """Event emitted when a LiteAgent completes execution"""
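Because `@field_serializer` binds to concrete field names, each event class declares its own short serializer trio (or pair) rather than inheriting one from `BaseEvent`. The effect is uniform: dumped with `context={"trace": True}`, these events carry `{id, role}` and `{id, name}` references plus a flat list of tool names, while a plain `model_dump()` is unchanged.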
@@ -1,6 +1,8 @@
 from typing import TYPE_CHECKING, Any

-from crewai.events.base_events import BaseEvent
+from pydantic import SerializationInfo, field_serializer
+
+from crewai.events.base_events import BaseEvent, _is_trace_context


 if TYPE_CHECKING:
@@ -26,6 +28,14 @@ class CrewBaseEvent(BaseEvent):
         if self.crew.fingerprint.metadata:
             self.fingerprint_metadata = self.crew.fingerprint.metadata

+    @field_serializer("crew")
+    @classmethod
+    def _serialize_crew(cls, v: Any, info: SerializationInfo) -> Any:
+        """Exclude crew in trace context — crew_kickoff_started builds structure separately."""
+        if _is_trace_context(info):
+            return None
+        return v
+
     def to_json(self, exclude: set[str] | None = None) -> Any:
         if exclude is None:
             exclude = set()
@@ -1,9 +1,9 @@
 from enum import Enum
 from typing import Any

-from pydantic import BaseModel
+from pydantic import BaseModel, SerializationInfo, field_serializer

-from crewai.events.base_events import BaseEvent
+from crewai.events.base_events import BaseEvent, _is_trace_context


 class LLMEventBase(BaseEvent):
@@ -49,6 +49,16 @@ class LLMCallStartedEvent(LLMEventBase):
     callbacks: list[Any] | None = None
     available_functions: dict[str, Any] | None = None

+    @field_serializer("callbacks")
+    @classmethod
+    def _serialize_callbacks(cls, v: Any, info: SerializationInfo) -> Any:
+        return None if _is_trace_context(info) else v
+
+    @field_serializer("available_functions")
+    @classmethod
+    def _serialize_available_functions(cls, v: Any, info: SerializationInfo) -> Any:
+        return None if _is_trace_context(info) else v
+

 class LLMCallCompletedEvent(LLMEventBase):
     """Event emitted when a LLM call completes"""
@@ -57,6 +67,7 @@ class LLMCallCompletedEvent(LLMEventBase):
     messages: str | list[dict[str, Any]] | None = None
     response: Any
     call_type: LLMCallType
+    usage: dict[str, Any] | None = None


 class LLMCallFailedEvent(LLMEventBase):
@@ -1,6 +1,8 @@
 from typing import Any

-from crewai.events.base_events import BaseEvent
+from pydantic import SerializationInfo, field_serializer
+
+from crewai.events.base_events import BaseEvent, _is_trace_context, _trace_task_ref
 from crewai.tasks.task_output import TaskOutput


@@ -24,6 +26,11 @@ class TaskStartedEvent(BaseEvent):
         super().__init__(**data)
         _set_task_fingerprint(self, self.task)

+    @field_serializer("task")
+    @classmethod
+    def _serialize_task(cls, v: Any, info: SerializationInfo) -> Any:
+        return _trace_task_ref(v) if _is_trace_context(info) else v
+

 class TaskCompletedEvent(BaseEvent):
     """Event emitted when a task completes"""
@@ -36,6 +43,11 @@ class TaskCompletedEvent(BaseEvent):
         super().__init__(**data)
         _set_task_fingerprint(self, self.task)

+    @field_serializer("task")
+    @classmethod
+    def _serialize_task(cls, v: Any, info: SerializationInfo) -> Any:
+        return _trace_task_ref(v) if _is_trace_context(info) else v
+

 class TaskFailedEvent(BaseEvent):
     """Event emitted when a task fails"""
@@ -48,6 +60,11 @@ class TaskFailedEvent(BaseEvent):
         super().__init__(**data)
         _set_task_fingerprint(self, self.task)

+    @field_serializer("task")
+    @classmethod
+    def _serialize_task(cls, v: Any, info: SerializationInfo) -> Any:
+        return _trace_task_ref(v) if _is_trace_context(info) else v
+

 class TaskEvaluationEvent(BaseEvent):
     """Event emitted when a task evaluation is completed"""
@@ -59,3 +76,8 @@ class TaskEvaluationEvent(BaseEvent):
     def __init__(self, **data: Any) -> None:
         super().__init__(**data)
         _set_task_fingerprint(self, self.task)
+
+    @field_serializer("task")
+    @classmethod
+    def _serialize_task(cls, v: Any, info: SerializationInfo) -> Any:
+        return _trace_task_ref(v) if _is_trace_context(info) else v
@@ -2,9 +2,9 @@ from collections.abc import Callable
 from datetime import datetime
 from typing import Any

-from pydantic import ConfigDict
+from pydantic import ConfigDict, SerializationInfo, field_serializer

-from crewai.events.base_events import BaseEvent
+from crewai.events.base_events import BaseEvent, _is_trace_context, _trace_agent_ref


 class ToolUsageEvent(BaseEvent):
@@ -26,6 +26,11 @@ class ToolUsageEvent(BaseEvent):

     model_config = ConfigDict(arbitrary_types_allowed=True)

+    @field_serializer("agent")
+    @classmethod
+    def _serialize_agent(cls, v: Any, info: SerializationInfo) -> Any:
+        return _trace_agent_ref(v) if _is_trace_context(info) else v
+
     def __init__(self, **data: Any) -> None:
         if data.get("from_task"):
             task = data["from_task"]
@@ -99,6 +104,11 @@ class ToolExecutionErrorEvent(BaseEvent):
     tool_class: Callable[..., Any]
     agent: Any | None = None

+    @field_serializer("agent")
+    @classmethod
+    def _serialize_agent(cls, v: Any, info: SerializationInfo) -> Any:
+        return _trace_agent_ref(v) if _is_trace_context(info) else v
+
     def __init__(self, **data: Any) -> None:
         super().__init__(**data)
         # Set fingerprint data from the agent
@@ -970,21 +970,25 @@ class LLM(BaseLLM):
                 )
                 result = instructor_instance.to_pydantic()
                 structured_response = result.model_dump_json()
+                usage_dict = self._usage_to_dict(usage_info)
                 self._handle_emit_call_events(
                     response=structured_response,
                     call_type=LLMCallType.LLM_CALL,
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
+                    usage=usage_dict,
                 )
                 return structured_response

+            usage_dict = self._usage_to_dict(usage_info)
             self._handle_emit_call_events(
                 response=full_response,
                 call_type=LLMCallType.LLM_CALL,
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=params["messages"],
+                usage=usage_dict,
             )
             return full_response

@@ -994,12 +998,14 @@ class LLM(BaseLLM):
                 return tool_result

         # --- 10) Emit completion event and return response
+        usage_dict = self._usage_to_dict(usage_info)
         self._handle_emit_call_events(
             response=full_response,
             call_type=LLMCallType.LLM_CALL,
             from_task=from_task,
             from_agent=from_agent,
             messages=params["messages"],
+            usage=usage_dict,
         )
         return full_response

@@ -1021,6 +1027,7 @@ class LLM(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=params["messages"],
+            usage=self._usage_to_dict(usage_info),
         )
         return full_response

@@ -1172,6 +1179,7 @@ class LLM(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=params["messages"],
+            usage=None,
         )
         return structured_response

@@ -1202,6 +1210,8 @@ class LLM(BaseLLM):
                 raise LLMContextLengthExceededError(error_msg) from e
             raise

+        response_usage = self._usage_to_dict(getattr(response, "usage", None))
+
         # --- 2) Handle structured output response (when response_model is provided)
         if response_model is not None:
             # When using instructor/response_model, litellm returns a Pydantic model instance
@@ -1213,6 +1223,7 @@ class LLM(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=params["messages"],
+                usage=response_usage,
             )
             return structured_response

@@ -1244,6 +1255,7 @@ class LLM(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=params["messages"],
+                usage=response_usage,
             )
             return text_response

@@ -1267,6 +1279,7 @@ class LLM(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=params["messages"],
+            usage=response_usage,
         )
         return text_response

@@ -1316,6 +1329,7 @@ class LLM(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=params["messages"],
+            usage=None,
         )
         return structured_response

@@ -1342,6 +1356,8 @@ class LLM(BaseLLM):
                 raise LLMContextLengthExceededError(error_msg) from e
             raise

+        response_usage = self._usage_to_dict(getattr(response, "usage", None))
+
         if response_model is not None:
             if isinstance(response, BaseModel):
                 structured_response = response.model_dump_json()
@@ -1351,6 +1367,7 @@ class LLM(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=params["messages"],
+                usage=response_usage,
             )
             return structured_response

@@ -1380,6 +1397,7 @@ class LLM(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=params["messages"],
+            usage=response_usage,
         )
         return text_response

@@ -1402,6 +1420,7 @@ class LLM(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=params["messages"],
+            usage=response_usage,
         )
         return text_response

@@ -1548,12 +1567,14 @@ class LLM(BaseLLM):
         if result is not None:
             return result

+        usage_dict = self._usage_to_dict(usage_info)
         self._handle_emit_call_events(
             response=full_response,
             call_type=LLMCallType.LLM_CALL,
             from_task=from_task,
             from_agent=from_agent,
             messages=params.get("messages"),
+            usage=usage_dict,
         )
         return full_response

@@ -1575,6 +1596,7 @@ class LLM(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=params.get("messages"),
+                usage=self._usage_to_dict(usage_info),
             )
             return full_response
         raise
@@ -1961,6 +1983,19 @@ class LLM(BaseLLM):
             )
             raise

+    @staticmethod
+    def _usage_to_dict(usage: Any) -> dict[str, Any] | None:
+        if usage is None:
+            return None
+        if isinstance(usage, dict):
+            return usage
+        if hasattr(usage, "model_dump"):
+            result: dict[str, Any] = usage.model_dump()
+            return result
+        if hasattr(usage, "__dict__"):
+            return {k: v for k, v in vars(usage).items() if not k.startswith("_")}
+        return None
+
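`_usage_to_dict` normalizes whatever shape a provider reports usage in: None, a plain dict, a Pydantic object, or any object with a `__dict__`. A standalone copy of the same logic, for a runnable illustration; `UsageStats` is a stand-in for a provider SDK object:

```python
from typing import Any

from pydantic import BaseModel


def usage_to_dict(usage: Any) -> dict[str, Any] | None:
    # Same branching as LLM._usage_to_dict above, copied for a standalone demo.
    if usage is None:
        return None
    if isinstance(usage, dict):
        return usage
    if hasattr(usage, "model_dump"):
        return usage.model_dump()
    if hasattr(usage, "__dict__"):
        return {k: v for k, v in vars(usage).items() if not k.startswith("_")}
    return None


class UsageStats(BaseModel):  # stand-in for a provider SDK usage object
    prompt_tokens: int
    completion_tokens: int


print(usage_to_dict(None))                  # None
print(usage_to_dict({"total_tokens": 42}))  # dicts pass through untouched
print(usage_to_dict(UsageStats(prompt_tokens=10, completion_tokens=5)))
# {'prompt_tokens': 10, 'completion_tokens': 5} via model_dump()
```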
     def _handle_emit_call_events(
         self,
         response: Any,
@@ -1968,6 +2003,7 @@ class LLM(BaseLLM):
         from_task: Task | None = None,
         from_agent: Agent | None = None,
         messages: str | list[LLMMessage] | None = None,
+        usage: dict[str, Any] | None = None,
     ) -> None:
         """Handle the events for the LLM call.

@@ -1977,6 +2013,7 @@ class LLM(BaseLLM):
             from_task: Optional task object
             from_agent: Optional agent object
             messages: Optional messages object
+            usage: Optional token usage data
         """
         crewai_event_bus.emit(
             self,
@@ -1988,6 +2025,7 @@ class LLM(BaseLLM):
                 from_agent=from_agent,
                 model=self.model,
                 call_id=get_current_call_id(),
+                usage=usage,
             ),
         )
@@ -460,6 +460,7 @@ class BaseLLM(BaseModel, ABC):
         from_task: Task | None = None,
         from_agent: Agent | None = None,
         messages: str | list[LLMMessage] | None = None,
+        usage: dict[str, Any] | None = None,
     ) -> None:
         """Emit LLM call completed event."""
         from crewai.utilities.serialization import to_serializable
@@ -474,6 +475,7 @@ class BaseLLM(BaseModel, ABC):
                 from_agent=from_agent,
                 model=self.model,
                 call_id=get_current_call_id(),
+                usage=usage,
             ),
         )
@@ -811,6 +811,7 @@ class AnthropicCompletion(BaseLLM):
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
+                    usage=usage,
                 )
                 return structured_data
             else:
@@ -826,6 +827,7 @@ class AnthropicCompletion(BaseLLM):
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
+                    usage=usage,
                 )
                 return structured_data

@@ -848,6 +850,7 @@ class AnthropicCompletion(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=params["messages"],
+                usage=usage,
             )
             return list(tool_uses)

@@ -879,6 +882,7 @@ class AnthropicCompletion(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=params["messages"],
+            usage=usage,
         )

         if usage.get("total_tokens", 0) > 0:
@@ -1028,6 +1032,7 @@ class AnthropicCompletion(BaseLLM):
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
+                    usage=usage,
                 )
                 return structured_data
         for block in final_message.content:
@@ -1042,6 +1047,7 @@ class AnthropicCompletion(BaseLLM):
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
+                    usage=usage,
                 )
                 return structured_data

@@ -1071,6 +1077,7 @@ class AnthropicCompletion(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=params["messages"],
+            usage=usage,
         )

         return self._invoke_after_llm_call_hooks(
@@ -1241,6 +1248,7 @@ class AnthropicCompletion(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=follow_up_params["messages"],
+            usage=follow_up_usage,
         )

         # Log combined token usage
@@ -1332,6 +1340,7 @@ class AnthropicCompletion(BaseLLM):
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
+                    usage=usage,
                 )
                 return structured_data
             else:
@@ -1347,6 +1356,7 @@ class AnthropicCompletion(BaseLLM):
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
+                    usage=usage,
                 )
                 return structured_data

@@ -1367,6 +1377,7 @@ class AnthropicCompletion(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=params["messages"],
+                usage=usage,
             )
             return list(tool_uses)

@@ -1390,6 +1401,7 @@ class AnthropicCompletion(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=params["messages"],
+            usage=usage,
         )

         if usage.get("total_tokens", 0) > 0:
@@ -1527,6 +1539,7 @@ class AnthropicCompletion(BaseLLM):
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
+                    usage=usage,
                 )
                 return structured_data
         for block in final_message.content:
@@ -1541,6 +1554,7 @@ class AnthropicCompletion(BaseLLM):
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=params["messages"],
+                    usage=usage,
                 )
                 return structured_data

@@ -1569,6 +1583,7 @@ class AnthropicCompletion(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=params["messages"],
+            usage=usage,
         )

         return full_response
@@ -1627,6 +1642,7 @@ class AnthropicCompletion(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=follow_up_params["messages"],
+            usage=follow_up_usage,
         )

         total_usage = {
@@ -569,6 +569,7 @@ class AzureCompletion(BaseLLM):
         params: AzureCompletionParams,
         from_task: Any | None = None,
         from_agent: Any | None = None,
+        usage: dict[str, Any] | None = None,
     ) -> BaseModel:
         """Validate content against response model and emit completion event.

@@ -594,6 +595,7 @@ class AzureCompletion(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=params["messages"],
+            usage=usage,
         )

         return structured_data
@@ -643,6 +645,7 @@ class AzureCompletion(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=params["messages"],
+                usage=usage,
             )
             return list(message.tool_calls)

@@ -680,6 +683,7 @@ class AzureCompletion(BaseLLM):
                 params=params,
                 from_task=from_task,
                 from_agent=from_agent,
+                usage=usage,
             )

         content = self._apply_stop_words(content)
@@ -691,6 +695,7 @@ class AzureCompletion(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=params["messages"],
+            usage=usage,
         )

         return self._invoke_after_llm_call_hooks(
@@ -794,7 +799,7 @@ class AzureCompletion(BaseLLM):
         self,
         full_response: str,
         tool_calls: dict[int, dict[str, Any]],
-        usage_data: dict[str, int],
+        usage_data: dict[str, Any] | None,
         params: AzureCompletionParams,
         available_functions: dict[str, Any] | None = None,
         from_task: Any | None = None,
@@ -806,7 +811,7 @@ class AzureCompletion(BaseLLM):
         Args:
             full_response: The complete streamed response content
             tool_calls: Dictionary of tool calls accumulated during streaming
-            usage_data: Token usage data from the stream
+            usage_data: Token usage data from the stream, or None if unavailable
             params: Completion parameters containing messages
             available_functions: Available functions for tool calling
             from_task: Task that initiated the call
@@ -816,7 +821,8 @@ class AzureCompletion(BaseLLM):
         Returns:
             Final response content after processing, or structured output
         """
-        self._track_token_usage_internal(usage_data)
+        if usage_data:
+            self._track_token_usage_internal(usage_data)

         # Handle structured output validation
         if response_model and self.is_openai_model:
@@ -826,6 +832,7 @@ class AzureCompletion(BaseLLM):
                 params=params,
                 from_task=from_task,
                 from_agent=from_agent,
+                usage=usage_data,
             )

         # If there are tool_calls but no available_functions, return them
@@ -848,6 +855,7 @@ class AzureCompletion(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=params["messages"],
+                usage=usage_data,
             )
             return formatted_tool_calls

@@ -884,6 +892,7 @@ class AzureCompletion(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=params["messages"],
+            usage=usage_data,
         )

         return self._invoke_after_llm_call_hooks(
@@ -902,7 +911,7 @@ class AzureCompletion(BaseLLM):
         full_response = ""
         tool_calls: dict[int, dict[str, Any]] = {}

-        usage_data = {"total_tokens": 0}
+        usage_data: dict[str, Any] | None = None
         for update in self._client.complete(**params):
             if isinstance(update, StreamingChatCompletionsUpdate):
                 if update.usage:
@@ -968,7 +977,7 @@ class AzureCompletion(BaseLLM):
         full_response = ""
         tool_calls: dict[int, dict[str, Any]] = {}

-        usage_data = {"total_tokens": 0}
+        usage_data: dict[str, Any] | None = None

         stream = await self._async_client.complete(**params)
         async for update in stream:
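Both Azure streaming paths now start `usage_data` as `None` rather than a zeroed dict, so "the provider never reported usage" is distinguishable from "zero tokens", and token tracking fires only when real data arrived. A minimal sketch of the guard, with a fake chunk stream standing in for the Azure client:

```python
from typing import Any


def consume_stream(chunks: list[dict[str, Any]]) -> dict[str, Any] | None:
    """Accumulate usage only if the provider actually reports it."""
    usage_data: dict[str, Any] | None = None
    for chunk in chunks:
        if chunk.get("usage"):  # most chunks carry no usage block
            usage_data = chunk["usage"]
    if usage_data:  # skip token tracking entirely when nothing was reported
        print(f"tracking {usage_data}")
    return usage_data


consume_stream([{"delta": "hi"}, {"delta": "!", "usage": {"total_tokens": 7}}])
consume_stream([{"delta": "no usage reported"}])  # returns None, tracks nothing
```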
@@ -664,8 +664,9 @@ class BedrockCompletion(BaseLLM):
             )

             # Track token usage according to AWS response format
-            if "usage" in response:
-                self._track_token_usage_internal(response["usage"])
+            usage = response.get("usage")
+            if usage:
+                self._track_token_usage_internal(usage)

             stop_reason = response.get("stopReason")
             if stop_reason:
@@ -705,6 +706,7 @@ class BedrockCompletion(BaseLLM):
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=messages,
+                    usage=usage,
                 )
                 return result
         except Exception as e:
@@ -727,6 +729,7 @@ class BedrockCompletion(BaseLLM):
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=messages,
+                    usage=usage,
                 )
                 return non_structured_output_tool_uses

@@ -806,6 +809,7 @@ class BedrockCompletion(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=messages,
+                usage=usage,
             )

             return self._invoke_after_llm_call_hooks(
@@ -936,6 +940,7 @@ class BedrockCompletion(BaseLLM):
         tool_use_id: str | None = None
         tool_use_index = 0
         accumulated_tool_input = ""
+        usage_data: dict[str, Any] | None = None

         try:
             response = self._client.converse_stream(
@@ -1045,6 +1050,7 @@ class BedrockCompletion(BaseLLM):
                         from_task=from_task,
                         from_agent=from_agent,
                         messages=messages,
+                        usage=usage_data,
                     )
                     return result  # type: ignore[return-value]
         except Exception as e:
@@ -1112,6 +1118,7 @@ class BedrockCompletion(BaseLLM):
                 metadata = event["metadata"]
                 if "usage" in metadata:
                     usage_metrics = metadata["usage"]
+                    usage_data = usage_metrics
                     self._track_token_usage_internal(usage_metrics)
                     logging.debug(f"Token usage: {usage_metrics}")
                 if "trace" in metadata:
@@ -1141,6 +1148,7 @@ class BedrockCompletion(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=messages,
+            usage=usage_data,
         )

         return full_response
@@ -1252,8 +1260,9 @@ class BedrockCompletion(BaseLLM):
                 **body,
             )

-            if "usage" in response:
-                self._track_token_usage_internal(response["usage"])
+            usage = response.get("usage")
+            if usage:
+                self._track_token_usage_internal(usage)

             stop_reason = response.get("stopReason")
             if stop_reason:
@@ -1292,6 +1301,7 @@ class BedrockCompletion(BaseLLM):
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=messages,
+                    usage=usage,
                 )
                 return result
         except Exception as e:
@@ -1314,6 +1324,7 @@ class BedrockCompletion(BaseLLM):
                     from_task=from_task,
                     from_agent=from_agent,
                     messages=messages,
+                    usage=usage,
                 )
                 return non_structured_output_tool_uses

@@ -1388,6 +1399,7 @@ class BedrockCompletion(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=messages,
+                usage=usage,
             )

             return text_content
@@ -1508,6 +1520,7 @@ class BedrockCompletion(BaseLLM):
         tool_use_id: str | None = None
         tool_use_index = 0
         accumulated_tool_input = ""
+        usage_data: dict[str, Any] | None = None

         try:
             async_client = await self._ensure_async_client()
@@ -1619,6 +1632,7 @@ class BedrockCompletion(BaseLLM):
                         from_task=from_task,
                         from_agent=from_agent,
                         messages=messages,
+                        usage=usage_data,
                     )
                     return result  # type: ignore[return-value]
         except Exception as e:
@@ -1691,6 +1705,7 @@ class BedrockCompletion(BaseLLM):
                 metadata = event["metadata"]
                 if "usage" in metadata:
                     usage_metrics = metadata["usage"]
+                    usage_data = usage_metrics
                     self._track_token_usage_internal(usage_metrics)
                     logging.debug(f"Token usage: {usage_metrics}")
                 if "trace" in metadata:
@@ -1720,6 +1735,7 @@ class BedrockCompletion(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=messages,
+            usage=usage_data,
         )

         return self._invoke_after_llm_call_hooks(
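Bedrock gets the same treatment in both its sync and async paths: the membership test `if "usage" in response` becomes `response.get("usage")` with a truthiness guard, and the streaming handlers capture `usage_data` from the stream's metadata events so the completion event carries whatever the Converse API actually reported.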
@@ -665,6 +665,7 @@ class GeminiCompletion(BaseLLM):
         messages_for_event: list[LLMMessage],
         from_task: Any | None = None,
         from_agent: Any | None = None,
+        usage: dict[str, Any] | None = None,
     ) -> BaseModel:
         """Validate content against response model and emit completion event.

@@ -690,6 +691,7 @@ class GeminiCompletion(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=messages_for_event,
+            usage=usage,
         )

         return structured_data
@@ -705,6 +707,7 @@ class GeminiCompletion(BaseLLM):
         response_model: type[BaseModel] | None = None,
         from_task: Any | None = None,
         from_agent: Any | None = None,
+        usage: dict[str, Any] | None = None,
     ) -> str | BaseModel:
         """Finalize completion response with validation and event emission.

@@ -728,6 +731,7 @@ class GeminiCompletion(BaseLLM):
                 messages_for_event=messages_for_event,
                 from_task=from_task,
                 from_agent=from_agent,
+                usage=usage,
             )

         self._emit_call_completed_event(
@@ -736,6 +740,7 @@ class GeminiCompletion(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=messages_for_event,
+            usage=usage,
         )

         return self._invoke_after_llm_call_hooks(
@@ -749,6 +754,7 @@ class GeminiCompletion(BaseLLM):
         contents: list[types.Content],
         from_task: Any | None = None,
         from_agent: Any | None = None,
+        usage: dict[str, Any] | None = None,
     ) -> BaseModel:
         """Validate and emit event for structured_output tool call.

@@ -773,6 +779,7 @@ class GeminiCompletion(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=self._convert_contents_to_dict(contents),
+                usage=usage,
             )
             return validated_data
         except Exception as e:
@@ -791,6 +798,7 @@ class GeminiCompletion(BaseLLM):
         from_task: Any | None = None,
         from_agent: Any | None = None,
         response_model: type[BaseModel] | None = None,
+        usage: dict[str, Any] | None = None,
     ) -> str | Any:
         """Process response, execute function calls, and finalize completion.

@@ -831,6 +839,7 @@ class GeminiCompletion(BaseLLM):
                 contents=contents,
                 from_task=from_task,
                 from_agent=from_agent,
+                usage=usage,
             )

         # Filter out structured_output from function calls returned to executor
@@ -852,6 +861,7 @@ class GeminiCompletion(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=self._convert_contents_to_dict(contents),
+                usage=usage,
             )
             return non_structured_output_parts

@@ -893,6 +903,7 @@ class GeminiCompletion(BaseLLM):
             response_model=effective_response_model,
             from_task=from_task,
             from_agent=from_agent,
+            usage=usage,
         )

     def _process_stream_chunk(
@@ -900,10 +911,10 @@ class GeminiCompletion(BaseLLM):
         chunk: GenerateContentResponse,
         full_response: str,
         function_calls: dict[int, dict[str, Any]],
-        usage_data: dict[str, int],
+        usage_data: dict[str, int] | None,
         from_task: Any | None = None,
         from_agent: Any | None = None,
-    ) -> tuple[str, dict[int, dict[str, Any]], dict[str, int]]:
+    ) -> tuple[str, dict[int, dict[str, Any]], dict[str, int] | None]:
         """Process a single streaming chunk.

         Args:
@@ -979,7 +990,7 @@ class GeminiCompletion(BaseLLM):
         self,
         full_response: str,
         function_calls: dict[int, dict[str, Any]],
-        usage_data: dict[str, int],
+        usage_data: dict[str, int] | None,
         contents: list[types.Content],
         available_functions: dict[str, Any] | None = None,
         from_task: Any | None = None,
@@ -991,7 +1002,7 @@ class GeminiCompletion(BaseLLM):
         Args:
             full_response: The complete streamed response content
             function_calls: Dictionary of function calls accumulated during streaming
-            usage_data: Token usage data from the stream
+            usage_data: Token usage data from the stream, or None if unavailable
             contents: Original contents for event conversion
             available_functions: Available functions for function calling
             from_task: Task that initiated the call
@@ -1001,7 +1012,8 @@ class GeminiCompletion(BaseLLM):
         Returns:
             Final response content after processing
         """
-        self._track_token_usage_internal(usage_data)
+        if usage_data:
+            self._track_token_usage_internal(usage_data)

         if response_model and function_calls:
             for call_data in function_calls.values():
@@ -1013,6 +1025,7 @@ class GeminiCompletion(BaseLLM):
                 contents=contents,
                 from_task=from_task,
                 from_agent=from_agent,
+                usage=usage_data,
             )

         non_structured_output_calls = {
@@ -1041,6 +1054,7 @@ class GeminiCompletion(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=self._convert_contents_to_dict(contents),
+                usage=usage_data,
             )
             return formatted_function_calls

@@ -1081,6 +1095,7 @@ class GeminiCompletion(BaseLLM):
             response_model=effective_response_model,
             from_task=from_task,
             from_agent=from_agent,
+            usage=usage_data,
         )

     def _handle_completion(
@@ -1118,6 +1133,7 @@ class GeminiCompletion(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             response_model=response_model,
+            usage=usage,
         )

     def _handle_streaming_completion(
@@ -1132,7 +1148,7 @@ class GeminiCompletion(BaseLLM):
         """Handle streaming content generation."""
         full_response = ""
         function_calls: dict[int, dict[str, Any]] = {}
-        usage_data = {"total_tokens": 0}
+        usage_data: dict[str, int] | None = None

         # The API accepts list[Content] but mypy is overly strict about variance
         contents_for_api: Any = contents
@@ -1196,6 +1212,7 @@ class GeminiCompletion(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             response_model=response_model,
+            usage=usage,
         )

     async def _ahandle_streaming_completion(
@@ -1210,7 +1227,7 @@ class GeminiCompletion(BaseLLM):
         """Handle async streaming content generation."""
         full_response = ""
         function_calls: dict[int, dict[str, Any]] = {}
-        usage_data = {"total_tokens": 0}
+        usage_data: dict[str, int] | None = None

         # The API accepts list[Content] but mypy is overly strict about variance
         contents_for_api: Any = contents
@@ -809,6 +809,7 @@ class OpenAICompletion(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=params.get("input", []),
+            usage=usage,
         )

         return parsed_result
@@ -821,6 +822,7 @@ class OpenAICompletion(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=params.get("input", []),
+            usage=usage,
         )
         return function_calls

@@ -858,6 +860,7 @@ class OpenAICompletion(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=params.get("input", []),
+                usage=usage,
             )
             return structured_result
         except ValueError as e:
@@ -871,6 +874,7 @@ class OpenAICompletion(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=params.get("input", []),
+            usage=usage,
         )

         content = self._invoke_after_llm_call_hooks(
@@ -941,6 +945,7 @@ class OpenAICompletion(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=params.get("input", []),
+            usage=usage,
         )

         return parsed_result
@@ -953,6 +958,7 @@ class OpenAICompletion(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=params.get("input", []),
+            usage=usage,
         )
         return function_calls

@@ -990,6 +996,7 @@ class OpenAICompletion(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=params.get("input", []),
+                usage=usage,
            )
            return structured_result
        except ValueError as e:
@@ -1003,6 +1010,7 @@ class OpenAICompletion(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=params.get("input", []),
+            usage=usage,
         )

         except NotFoundError as e:
@@ -1045,6 +1053,7 @@ class OpenAICompletion(BaseLLM):
         full_response = ""
         function_calls: list[dict[str, Any]] = []
         final_response: Response | None = None
+        usage: dict[str, Any] | None = None

         stream = self._client.responses.create(**params)
         response_id_stream = None
@@ -1102,6 +1111,7 @@ class OpenAICompletion(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=params.get("input", []),
+            usage=usage,
         )

         return parsed_result
@@ -1138,6 +1148,7 @@ class OpenAICompletion(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=params.get("input", []),
+                usage=usage,
             )
             return structured_result
         except ValueError as e:
@@ -1151,6 +1162,7 @@ class OpenAICompletion(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=params.get("input", []),
+            usage=usage,
         )

         return self._invoke_after_llm_call_hooks(
@@ -1169,6 +1181,7 @@ class OpenAICompletion(BaseLLM):
         full_response = ""
         function_calls: list[dict[str, Any]] = []
         final_response: Response | None = None
+        usage: dict[str, Any] | None = None

         stream = await self._async_client.responses.create(**params)
         response_id_stream = None
@@ -1226,6 +1239,7 @@ class OpenAICompletion(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=params.get("input", []),
+            usage=usage,
         )

         return parsed_result
@@ -1262,6 +1276,7 @@ class OpenAICompletion(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=params.get("input", []),
+                usage=usage,
             )
             return structured_result
         except ValueError as e:
@@ -1275,6 +1290,7 @@ class OpenAICompletion(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=params.get("input", []),
+            usage=usage,
         )

         return full_response
@@ -1580,6 +1596,7 @@ class OpenAICompletion(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=params["messages"],
+            usage=usage,
         )
         return parsed_object

@@ -1601,6 +1618,7 @@ class OpenAICompletion(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=params["messages"],
+            usage=usage,
         )
         return list(message.tool_calls)

@@ -1639,6 +1657,7 @@ class OpenAICompletion(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=params["messages"],
+                usage=usage,
             )
             return structured_result
         except ValueError as e:
@@ -1652,6 +1671,7 @@ class OpenAICompletion(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=params["messages"],
+            usage=usage,
         )

         if usage.get("total_tokens", 0) > 0:
@@ -1693,7 +1713,7 @@ class OpenAICompletion(BaseLLM):
         self,
         full_response: str,
         tool_calls: dict[int, dict[str, Any]],
-        usage_data: dict[str, int],
+        usage_data: dict[str, Any] | None,
         params: dict[str, Any],
         available_functions: dict[str, Any] | None = None,
         from_task: Any | None = None,
@@ -1704,7 +1724,7 @@ class OpenAICompletion(BaseLLM):
         Args:
             full_response: The accumulated text response from the stream.
             tool_calls: Accumulated tool calls from the stream, keyed by index.
-            usage_data: Token usage data from the stream.
+            usage_data: Token usage data from the stream, or None if unavailable.
             params: The completion parameters containing messages.
             available_functions: Available functions for tool calling.
             from_task: Task that initiated the call.
@@ -1715,7 +1735,8 @@ class OpenAICompletion(BaseLLM):
             tool execution result when available_functions is provided,
             or the text response string.
         """
-        self._track_token_usage_internal(usage_data)
+        if usage_data:
+            self._track_token_usage_internal(usage_data)

         if tool_calls and not available_functions:
             tool_calls_list = [
@@ -1736,6 +1757,7 @@ class OpenAICompletion(BaseLLM):
                 from_task=from_task,
                 from_agent=from_agent,
                 messages=params["messages"],
+                usage=usage_data,
             )
             return tool_calls_list

@@ -1778,6 +1800,7 @@ class OpenAICompletion(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=params["messages"],
+            usage=usage_data,
         )

         return full_response
@@ -1831,6 +1854,7 @@ class OpenAICompletion(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=params["messages"],
+            usage=usage,
         )
         return parsed_result

@@ -1841,7 +1865,7 @@ class OpenAICompletion(BaseLLM):
             self._client.chat.completions.create(**params)
         )

-        usage_data = {"total_tokens": 0}
+        usage_data: dict[str, Any] | None = None

         for completion_chunk in completion_stream:
             response_id_stream = (
@@ -1955,6 +1979,7 @@ class OpenAICompletion(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=params["messages"],
+            usage=usage,
         )
         return parsed_object

@@ -1978,6 +2003,7 @@ class OpenAICompletion(BaseLLM):
             from_task=from_task,
             from_agent=from_agent,
             messages=params["messages"],
+            usage=usage,
         )
         return list(message.tool_calls)

@@ -2016,6 +2042,7 @@ class OpenAICompletion(BaseLLM):
             from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
usage=usage,
|
||||
)
|
||||
return structured_result
|
||||
except ValueError as e:
|
||||
@@ -2029,6 +2056,7 @@ class OpenAICompletion(BaseLLM):
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
if usage.get("total_tokens", 0) > 0:
|
||||
@@ -2079,7 +2107,7 @@ class OpenAICompletion(BaseLLM):
|
||||
] = await self._async_client.chat.completions.create(**params)
|
||||
|
||||
accumulated_content = ""
|
||||
usage_data = {"total_tokens": 0}
|
||||
usage_data: dict[str, Any] | None = None
|
||||
async for chunk in completion_stream:
|
||||
response_id_stream = chunk.id if hasattr(chunk, "id") else None
|
||||
|
||||
@@ -2102,7 +2130,8 @@ class OpenAICompletion(BaseLLM):
|
||||
response_id=response_id_stream,
|
||||
)
|
||||
|
||||
self._track_token_usage_internal(usage_data)
|
||||
if usage_data:
|
||||
self._track_token_usage_internal(usage_data)
|
||||
|
||||
try:
|
||||
parsed_object = response_model.model_validate_json(accumulated_content)
|
||||
@@ -2113,6 +2142,7 @@ class OpenAICompletion(BaseLLM):
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
usage=usage_data,
|
||||
)
|
||||
|
||||
return parsed_object
|
||||
@@ -2124,6 +2154,7 @@ class OpenAICompletion(BaseLLM):
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
usage=usage_data,
|
||||
)
|
||||
return accumulated_content
|
||||
|
||||
@@ -2131,7 +2162,7 @@ class OpenAICompletion(BaseLLM):
|
||||
ChatCompletionChunk
|
||||
] = await self._async_client.chat.completions.create(**params)
|
||||
|
||||
usage_data = {"total_tokens": 0}
|
||||
usage_data = None
|
||||
|
||||
async for chunk in stream:
|
||||
response_id_stream = chunk.id if hasattr(chunk, "id") else None
|
||||
|
||||
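
Editor's note: the hunks above apply one recurring change across every completion path — each `_emit_call_completed_event(...)` call now forwards the provider-reported `usage`, and the streaming paths initialize `usage_data` as `None` instead of a `{"total_tokens": 0}` placeholder, guarding `_track_token_usage_internal` behind `if usage_data:`. A minimal sketch of that pattern, assuming chunk objects shaped like OpenAI stream chunks; `consume_stream` and `_track` are illustrative names, not the repo's API:

    from typing import Any

    def _track(usage: dict[str, Any]) -> None:
        """Stand-in for the executor's internal token accounting."""
        print(f"tracked {usage.get('total_tokens', 0)} tokens")

    def consume_stream(chunks: list[Any]) -> dict[str, Any] | None:
        usage_data: dict[str, Any] | None = None
        for chunk in chunks:
            # OpenAI streams report usage only on the final chunk, and only
            # when requested, so usage_data may legitimately stay None.
            usage = getattr(chunk, "usage", None)
            if usage is not None:
                usage_data = {
                    "prompt_tokens": usage.prompt_tokens,
                    "completion_tokens": usage.completion_tokens,
                    "total_tokens": usage.total_tokens,
                }
        if usage_data:  # guard instead of tracking a zeroed placeholder
            _track(usage_data)
        return usage_data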
@@ -1,6 +1,7 @@
from __future__ import annotations

from datetime import date, datetime
from enum import Enum
import json
from typing import Any, TypeAlias
import uuid
@@ -20,6 +21,7 @@ def to_serializable(
max_depth: int = 5,
_current_depth: int = 0,
_ancestors: set[int] | None = None,
context: dict[str, Any] | None = None,
) -> Serializable:
"""Converts a Python object into a JSON-compatible representation.

@@ -33,6 +35,9 @@ def to_serializable(
max_depth: Maximum recursion depth. Defaults to 5.
_current_depth: Current recursion depth (for internal use).
_ancestors: Set of ancestor object ids for cycle detection (for internal use).
context: Optional context dict passed to Pydantic's model_dump(context=...).
Field serializers on the model can inspect this to customize output
(e.g. context={"trace": True} for lightweight trace serialization).

Returns:
Serializable: A JSON-compatible structure.
@@ -48,6 +53,15 @@ def to_serializable(

if isinstance(obj, (str, int, float, bool, type(None))):
return obj
if isinstance(obj, Enum):
return to_serializable(
obj.value,
exclude=exclude,
max_depth=max_depth,
_current_depth=_current_depth,
_ancestors=_ancestors,
context=context,
)
if isinstance(obj, uuid.UUID):
return str(obj)
if isinstance(obj, (date, datetime)):
@@ -66,6 +80,7 @@ def to_serializable(
max_depth=max_depth,
_current_depth=_current_depth + 1,
_ancestors=new_ancestors,
context=context,
)
for item in obj
]
@@ -77,17 +92,24 @@ def to_serializable(
max_depth=max_depth,
_current_depth=_current_depth + 1,
_ancestors=new_ancestors,
context=context,
)
for key, value in obj.items()
if key not in exclude
}
if isinstance(obj, BaseModel):
try:
dump_kwargs: dict[str, Any] = {}
if exclude:
dump_kwargs["exclude"] = exclude
if context is not None:
dump_kwargs["context"] = context
return to_serializable(
obj=obj.model_dump(exclude=exclude),
obj=obj.model_dump(**dump_kwargs),
max_depth=max_depth,
_current_depth=_current_depth + 1,
_ancestors=new_ancestors,
context=context,
)
except Exception:
try:
@@ -97,12 +119,30 @@ def to_serializable(
max_depth=max_depth,
_current_depth=_current_depth + 1,
_ancestors=new_ancestors,
context=context,
)
for k, v in obj.__dict__.items()
if k not in (exclude or set())
}
except Exception:
return repr(obj)
if callable(obj):
return repr(obj)
if hasattr(obj, "__dict__"):
try:
return {
_to_serializable_key(k): to_serializable(
v,
max_depth=max_depth,
_current_depth=_current_depth + 1,
_ancestors=new_ancestors,
context=context,
)
for k, v in obj.__dict__.items()
if not k.startswith("_")
}
except Exception:
return repr(obj)
return repr(obj)
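
Editor's note: the `context` threading above is what lets Pydantic field serializers see the dict passed to `model_dump(context=...)`. A minimal sketch of the mechanism (a standalone example, not the repo's event classes; assumes Pydantic >= 2.7, where `model_dump` accepts `context=`):

    from typing import Any

    from pydantic import BaseModel, SerializationInfo, field_serializer

    class TraceEvent(BaseModel):
        agent: dict[str, Any]

        @field_serializer("agent")
        def _agent_ref(self, agent: dict[str, Any], info: SerializationInfo) -> Any:
            # With context={"trace": True}, emit a lightweight ref; otherwise
            # serialize the full value unchanged.
            if info.context and info.context.get("trace"):
                return {"id": agent["id"], "role": agent["role"]}
            return agent

    event = TraceEvent(agent={"id": "a1", "role": "Researcher", "backstory": "long text"})
    assert event.model_dump(context={"trace": True})["agent"] == {"id": "a1", "role": "Researcher"}
    assert "backstory" in event.model_dump()["agent"]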
@@ -0,0 +1,108 @@
interactions:
- request:
body: '{"messages":[{"role":"user","content":"Say hello"}],"model":"gpt-4o-mini"}'
headers:
User-Agent:
- X-USER-AGENT-XXX
accept:
- application/json
accept-encoding:
- ACCEPT-ENCODING-XXX
authorization:
- AUTHORIZATION-XXX
connection:
- keep-alive
content-length:
- '74'
content-type:
- application/json
host:
- api.openai.com
x-stainless-arch:
- X-STAINLESS-ARCH-XXX
x-stainless-async:
- 'false'
x-stainless-lang:
- python
x-stainless-os:
- X-STAINLESS-OS-XXX
x-stainless-package-version:
- 1.83.0
x-stainless-read-timeout:
- X-STAINLESS-READ-TIMEOUT-XXX
x-stainless-retry-count:
- '0'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.13.2
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: "{\n \"id\": \"chatcmpl-DPS8YQSwQ3pZKZztIoIe1eYodMqh2\",\n \"object\":
\"chat.completion\",\n \"created\": 1774958730,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n
\ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
\"assistant\",\n \"content\": \"Hello! How can I assist you today?\",\n
\ \"refusal\": null,\n \"annotations\": []\n },\n \"logprobs\":
null,\n \"finish_reason\": \"stop\"\n }\n ],\n \"usage\": {\n \"prompt_tokens\":
9,\n \"completion_tokens\": 9,\n \"total_tokens\": 18,\n \"prompt_tokens_details\":
{\n \"cached_tokens\": 0,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\":
{\n \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\":
0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"service_tier\":
\"default\",\n \"system_fingerprint\": \"fp_709f182cb4\"\n}\n"
headers:
CF-Cache-Status:
- DYNAMIC
CF-Ray:
- 9e4f38fc5d9d82e8-GIG
Connection:
- keep-alive
Content-Type:
- application/json
Date:
- Tue, 31 Mar 2026 12:05:30 GMT
Server:
- cloudflare
Strict-Transport-Security:
- STS-XXX
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- X-CONTENT-TYPE-XXX
access-control-expose-headers:
- ACCESS-CONTROL-XXX
alt-svc:
- h3=":443"; ma=86400
content-length:
- '839'
openai-organization:
- OPENAI-ORG-XXX
openai-processing-ms:
- '680'
openai-project:
- OPENAI-PROJECT-XXX
openai-version:
- '2020-10-01'
set-cookie:
- SET-COOKIE-XXX
x-openai-proxy-wasm:
- v0.1
x-ratelimit-limit-requests:
- X-RATELIMIT-LIMIT-REQUESTS-XXX
x-ratelimit-limit-tokens:
- X-RATELIMIT-LIMIT-TOKENS-XXX
x-ratelimit-remaining-requests:
- X-RATELIMIT-REMAINING-REQUESTS-XXX
x-ratelimit-remaining-tokens:
- X-RATELIMIT-REMAINING-TOKENS-XXX
x-ratelimit-reset-requests:
- X-RATELIMIT-RESET-REQUESTS-XXX
x-ratelimit-reset-tokens:
- X-RATELIMIT-RESET-TOKENS-XXX
x-request-id:
- X-REQUEST-ID-XXX
status:
code: 200
message: OK
version: 1
176 lib/crewai/tests/events/test_llm_usage_event.py Normal file
@@ -0,0 +1,176 @@
from typing import Any
from unittest.mock import patch

import pytest
from pydantic import BaseModel

from crewai.events.event_bus import CrewAIEventsBus
from crewai.events.types.llm_events import LLMCallCompletedEvent, LLMCallType
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM


class TestLLMCallCompletedEventUsageField:
def test_accepts_usage_dict(self):
event = LLMCallCompletedEvent(
response="hello",
call_type=LLMCallType.LLM_CALL,
call_id="test-id",
usage={"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
)
assert event.usage == {
"prompt_tokens": 10,
"completion_tokens": 20,
"total_tokens": 30,
}

def test_usage_defaults_to_none(self):
event = LLMCallCompletedEvent(
response="hello",
call_type=LLMCallType.LLM_CALL,
call_id="test-id",
)
assert event.usage is None

def test_accepts_none_usage(self):
event = LLMCallCompletedEvent(
response="hello",
call_type=LLMCallType.LLM_CALL,
call_id="test-id",
usage=None,
)
assert event.usage is None

def test_accepts_nested_usage_dict(self):
usage = {
"prompt_tokens": 100,
"completion_tokens": 200,
"total_tokens": 300,
"prompt_tokens_details": {"cached_tokens": 50},
}
event = LLMCallCompletedEvent(
response="hello",
call_type=LLMCallType.LLM_CALL,
call_id="test-id",
usage=usage,
)
assert event.usage["prompt_tokens_details"]["cached_tokens"] == 50


class TestUsageToDict:
def test_none_returns_none(self):
assert LLM._usage_to_dict(None) is None

def test_dict_passes_through(self):
usage = {"prompt_tokens": 10, "total_tokens": 30}
assert LLM._usage_to_dict(usage) is usage

def test_pydantic_model_uses_model_dump(self):
class Usage(BaseModel):
prompt_tokens: int = 10
completion_tokens: int = 20
total_tokens: int = 30

result = LLM._usage_to_dict(Usage())
assert result == {
"prompt_tokens": 10,
"completion_tokens": 20,
"total_tokens": 30,
}

def test_object_with_dict_attr(self):
class UsageObj:
def __init__(self):
self.prompt_tokens = 5
self.completion_tokens = 15
self.total_tokens = 20

result = LLM._usage_to_dict(UsageObj())
assert result == {
"prompt_tokens": 5,
"completion_tokens": 15,
"total_tokens": 20,
}

def test_object_with_dict_excludes_private_attrs(self):
class UsageObj:
def __init__(self):
self.total_tokens = 42
self._internal = "hidden"

result = LLM._usage_to_dict(UsageObj())
assert result == {"total_tokens": 42}
assert "_internal" not in result

def test_unsupported_type_returns_none(self):
assert LLM._usage_to_dict(42) is None
assert LLM._usage_to_dict("string") is None


class _StubLLM(BaseLLM):
"""Minimal concrete BaseLLM for testing event emission."""

model: str = "test-model"

def call(self, *args: Any, **kwargs: Any) -> str:
return ""

async def acall(self, *args: Any, **kwargs: Any) -> str:
return ""

def supports_function_calling(self) -> bool:
return False

def supports_stop_words(self) -> bool:
return True


class TestEmitCallCompletedEventPassesUsage:
@pytest.fixture
def mock_emit(self):
with patch.object(CrewAIEventsBus, "emit") as mock:
yield mock

@pytest.fixture
def llm(self):
return _StubLLM(model="test-model")

def test_usage_is_passed_to_event(self, mock_emit, llm):
usage_data = {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}

llm._emit_call_completed_event(
response="hello",
call_type=LLMCallType.LLM_CALL,
messages="test prompt",
usage=usage_data,
)

mock_emit.assert_called_once()
event = mock_emit.call_args[1]["event"]
assert isinstance(event, LLMCallCompletedEvent)
assert event.usage == usage_data

def test_none_usage_is_passed_to_event(self, mock_emit, llm):
llm._emit_call_completed_event(
response="hello",
call_type=LLMCallType.LLM_CALL,
messages="test prompt",
usage=None,
)

mock_emit.assert_called_once()
event = mock_emit.call_args[1]["event"]
assert isinstance(event, LLMCallCompletedEvent)
assert event.usage is None

def test_usage_omitted_defaults_to_none(self, mock_emit, llm):
llm._emit_call_completed_event(
response="hello",
call_type=LLMCallType.LLM_CALL,
messages="test prompt",
)

mock_emit.assert_called_once()
event = mock_emit.call_args[1]["event"]
assert isinstance(event, LLMCallCompletedEvent)
assert event.usage is None
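
Editor's note: taken together, the tests above pin down `LLM._usage_to_dict` as a small normalizer from provider-specific usage objects to a plain dict. A sketch consistent with those assertions (a hypothetical reconstruction; the repo's actual implementation may differ in details):

    from typing import Any

    from pydantic import BaseModel

    def usage_to_dict(usage: Any) -> dict[str, Any] | None:
        if usage is None:
            return None
        if isinstance(usage, dict):
            return usage  # passed through by identity, as the tests assert
        if isinstance(usage, BaseModel):
            return usage.model_dump()
        if hasattr(usage, "__dict__"):
            # Duck-typed SDK objects: keep public attributes only.
            return {k: v for k, v in usage.__dict__.items() if not k.startswith("_")}
        return None  # ints, strings, and other unsupported types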
@@ -752,11 +752,7 @@ def test_litellm_retry_catches_litellm_unsupported_params_error(caplog):
raise litellm_error
return MagicMock(
choices=[MagicMock(message=MagicMock(content="Paris", tool_calls=None))],
usage=MagicMock(
prompt_tokens=10,
completion_tokens=5,
total_tokens=15,
),
usage={"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
)

with patch("litellm.completion", side_effect=mock_completion):
@@ -787,11 +783,7 @@ def test_litellm_retry_catches_openai_api_stop_error(caplog):
raise api_error
return MagicMock(
choices=[MagicMock(message=MagicMock(content="Paris", tool_calls=None))],
usage=MagicMock(
prompt_tokens=10,
completion_tokens=5,
total_tokens=15,
),
usage={"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
)

with patch("litellm.completion", side_effect=mock_completion):
612 lib/crewai/tests/tracing/test_trace_serialization.py Normal file
@@ -0,0 +1,612 @@
"""Tests for trace serialization optimization using Pydantic v2 context-based serialization.

These tests verify that trace events use @field_serializer with SerializationInfo.context
to produce lightweight representations, reducing event sizes from 50-100KB to a few KB.
"""

import json
import uuid
from typing import Any
from unittest.mock import MagicMock

import pytest
from pydantic import ConfigDict

from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.events.base_events import _trace_agent_ref, _trace_task_ref, _trace_tool_names
from crewai.events.listeners.tracing.utils import safe_serialize_to_dict
from crewai.utilities.serialization import to_serializable


# ---------------------------------------------------------------------------
# Lightweight BaseAgent subclass for tests (avoids heavy dependencies)
# ---------------------------------------------------------------------------

class _StubAgent(BaseAgent):
"""Minimal BaseAgent subclass that satisfies validation without heavy deps."""

model_config = ConfigDict(arbitrary_types_allowed=True)

def execute_task(self, *a: Any, **kw: Any) -> str:
return ""

def create_agent_executor(self, *a: Any, **kw: Any) -> None:
pass

def _parse_tools(self, *a: Any, **kw: Any) -> list:
return []

def get_delegation_tools(self, *a: Any, **kw: Any) -> list:
return []

def get_output_converter(self, *a: Any, **kw: Any) -> Any:
return None

def get_multimodal_tools(self, *a: Any, **kw: Any) -> list:
return []

async def aexecute_task(self, *a: Any, **kw: Any) -> str:
return ""

def get_mcp_tools(self, *a: Any, **kw: Any) -> list:
return []

def get_platform_tools(self, *a: Any, **kw: Any) -> list:
return []


def _make_stub_agent(**overrides) -> _StubAgent:
"""Create a minimal BaseAgent instance for testing."""
defaults = {
"role": "Researcher",
"goal": "Research things",
"backstory": "Expert researcher",
"tools": [],
}
defaults.update(overrides)
return _StubAgent(**defaults)


# ---------------------------------------------------------------------------
# Helpers to build realistic mock objects for event fields
# ---------------------------------------------------------------------------

def _make_mock_task(**overrides):
task = MagicMock()
task.id = overrides.get("id", uuid.uuid4())
task.name = overrides.get("name", "Research Task")
task.description = overrides.get("description", "Do research")
task.expected_output = overrides.get("expected_output", "Research results")
task.async_execution = overrides.get("async_execution", False)
task.human_input = overrides.get("human_input", False)
task.agent = overrides.get("agent", _make_stub_agent())
task.context = overrides.get("context", None)
task.crew = MagicMock()
task.tools = overrides.get("tools", [MagicMock(), MagicMock()])

fp = MagicMock()
fp.uuid_str = str(uuid.uuid4())
fp.metadata = {"name": task.name}
task.fingerprint = fp

return task


def _make_stub_tool(tool_name="web_search") -> Any:
"""Create a minimal BaseTool instance for testing."""
from crewai.tools.base_tool import BaseTool

class _StubTool(BaseTool):
name: str = "stub"
description: str = "stub tool"

def _run(self, *a: Any, **kw: Any) -> str:
return ""

return _StubTool(name=tool_name, description=f"{tool_name} tool")


# ---------------------------------------------------------------------------
# Unit tests: trace ref helpers
# ---------------------------------------------------------------------------

class TestTraceRefHelpers:
def test_trace_agent_ref(self):
agent = _make_stub_agent(role="Analyst")
ref = _trace_agent_ref(agent)
assert ref["role"] == "Analyst"
assert "id" in ref
assert len(ref) == 2  # only id and role

def test_trace_agent_ref_none(self):
assert _trace_agent_ref(None) is None

def test_trace_task_ref(self):
task = _make_mock_task(name="Write Report")
ref = _trace_task_ref(task)
assert ref["name"] == "Write Report"
assert "id" in ref
assert len(ref) == 2

def test_trace_task_ref_falls_back_to_description(self):
task = _make_mock_task(name=None, description="Describe the report")
ref = _trace_task_ref(task)
assert ref["name"] == "Describe the report"

def test_trace_task_ref_none(self):
assert _trace_task_ref(None) is None

def test_trace_tool_names(self):
tools = [_make_stub_tool("search"), _make_stub_tool("read")]
names = _trace_tool_names(tools)
assert names == ["search", "read"]

def test_trace_tool_names_empty(self):
assert _trace_tool_names([]) is None
assert _trace_tool_names(None) is None


# ---------------------------------------------------------------------------
# Integration tests: field serializers on real event classes
# ---------------------------------------------------------------------------

class TestAgentEventFieldSerializers:
"""Test that agent event field serializers respond to trace context."""

def test_agent_execution_started_trace_context(self):
from crewai.events.types.agent_events import AgentExecutionStartedEvent

agent = _make_stub_agent(role="Researcher")
task = _make_mock_task(name="Research Task")
tools = [_make_stub_tool("search"), _make_stub_tool("read")]

event = AgentExecutionStartedEvent(
agent=agent, task=task, tools=tools, task_prompt="Do research"
)

# With trace context: lightweight refs
trace_dump = event.model_dump(context={"trace": True})
assert trace_dump["agent"] == {"id": str(agent.id), "role": "Researcher"}
assert trace_dump["task"] == {"id": str(task.id), "name": "Research Task"}
assert trace_dump["tools"] == ["search", "read"]

def test_agent_execution_started_no_context(self):
from crewai.events.types.agent_events import AgentExecutionStartedEvent

agent = _make_stub_agent(role="SpecificRole")
task = _make_mock_task()

event = AgentExecutionStartedEvent(
agent=agent, task=task, tools=None, task_prompt="Do research"
)

# Without context: full agent dict (Pydantic model_dump expands it)
normal_dump = event.model_dump()
assert isinstance(normal_dump["agent"], dict)
assert normal_dump["agent"]["role"] == "SpecificRole"
# Should have ALL agent fields, not just the lightweight ref
assert "goal" in normal_dump["agent"]
assert "backstory" in normal_dump["agent"]
assert "max_iter" in normal_dump["agent"]

def test_agent_execution_error_preserves_identification(self):
from crewai.events.types.agent_events import AgentExecutionErrorEvent

agent = _make_stub_agent(role="Analyst")
task = _make_mock_task(name="Analysis Task")

event = AgentExecutionErrorEvent(
agent=agent, task=task, error="Something went wrong"
)

trace_dump = event.model_dump(context={"trace": True})
# Error events should still have agent/task identification as refs
assert trace_dump["agent"]["role"] == "Analyst"
assert trace_dump["task"]["name"] == "Analysis Task"
assert trace_dump["error"] == "Something went wrong"

def test_agent_execution_completed_trace_context(self):
from crewai.events.types.agent_events import AgentExecutionCompletedEvent

agent = _make_stub_agent(role="Writer")
task = _make_mock_task(name="Writing Task")

event = AgentExecutionCompletedEvent(
agent=agent, task=task, output="Final output"
)

trace_dump = event.model_dump(context={"trace": True})
assert trace_dump["agent"]["role"] == "Writer"
assert trace_dump["task"]["name"] == "Writing Task"
assert trace_dump["output"] == "Final output"


class TestTaskEventFieldSerializers:
"""Test that task event field serializers respond to trace context."""

def test_task_started_trace_context(self):
from crewai.events.types.task_events import TaskStartedEvent

task = _make_mock_task(name="Test Task")
event = TaskStartedEvent(task=task, context="some context")

trace_dump = event.model_dump(context={"trace": True})
assert trace_dump["task"] == {"id": str(task.id), "name": "Test Task"}
assert trace_dump["context"] == "some context"

def test_task_failed_trace_context(self):
from crewai.events.types.task_events import TaskFailedEvent

task = _make_mock_task(name="Failing Task")
event = TaskFailedEvent(task=task, error="Task failed")

trace_dump = event.model_dump(context={"trace": True})
assert trace_dump["task"]["name"] == "Failing Task"
assert trace_dump["error"] == "Task failed"


class TestCrewEventFieldSerializers:
"""Test that crew event field serializers respond to trace context."""

def test_crew_kickoff_started_excludes_crew_in_trace(self):
from crewai.events.types.crew_events import CrewKickoffStartedEvent

crew = MagicMock()
crew.fingerprint = MagicMock()
crew.fingerprint.uuid_str = str(uuid.uuid4())
crew.fingerprint.metadata = {}

event = CrewKickoffStartedEvent(
crew=crew, crew_name="TestCrew", inputs={"key": "value"}
)

trace_dump = event.model_dump(context={"trace": True})
# crew field should be None in trace context
assert trace_dump["crew"] is None
# scalar fields preserved
assert trace_dump["crew_name"] == "TestCrew"
assert trace_dump["inputs"] == {"key": "value"}

def test_crew_event_no_context_preserves_crew(self):
from crewai.events.types.crew_events import CrewKickoffStartedEvent

crew = MagicMock()
crew.fingerprint = MagicMock()
crew.fingerprint.uuid_str = str(uuid.uuid4())
crew.fingerprint.metadata = {}

event = CrewKickoffStartedEvent(
crew=crew, crew_name="TestCrew", inputs=None
)

normal_dump = event.model_dump()
# Without trace context, crew should NOT be None (field serializer didn't fire)
assert normal_dump["crew"] is not None


class TestLLMEventFieldSerializers:
"""Test that LLM event field serializers respond to trace context."""

def test_llm_call_started_excludes_callbacks_in_trace(self):
from crewai.events.types.llm_events import LLMCallStartedEvent

event = LLMCallStartedEvent(
call_id="test-call",
messages=[{"role": "user", "content": "Hello"}],
tools=[{"name": "search", "description": "Search tool"}],
callbacks=[MagicMock(), MagicMock()],
available_functions={"search": MagicMock()},
)

trace_dump = event.model_dump(context={"trace": True})
# callbacks and available_functions excluded
assert trace_dump["callbacks"] is None
assert trace_dump["available_functions"] is None
# tools preserved (lightweight list of dicts)
assert trace_dump["tools"] == [{"name": "search", "description": "Search tool"}]
# messages preserved
assert trace_dump["messages"] == [{"role": "user", "content": "Hello"}]


# ---------------------------------------------------------------------------
# Integration tests: safe_serialize_to_dict with context
# ---------------------------------------------------------------------------

class TestSafeSerializeWithContext:
"""Test that safe_serialize_to_dict properly passes context through."""

def test_context_flows_through_to_field_serializers(self):
from crewai.events.types.agent_events import AgentExecutionErrorEvent

agent = _make_stub_agent(role="Worker")
task = _make_mock_task(name="Work Task")

event = AgentExecutionErrorEvent(
agent=agent, task=task, error="error msg"
)

result = safe_serialize_to_dict(event, context={"trace": True})
# Field serializers should have fired
assert result["agent"] == {"id": str(agent.id), "role": "Worker"}
assert result["task"] == {"id": str(task.id), "name": "Work Task"}
assert result["error"] == "error msg"

def test_no_context_preserves_full_serialization(self):
from crewai.events.types.task_events import TaskFailedEvent

task = _make_mock_task(name="Test")
event = TaskFailedEvent(task=task, error="fail")

result = safe_serialize_to_dict(event)
# Without context, task should not be a lightweight ref
assert result.get("task") is not None
# It should be the raw object (model_dump returns it as-is for Any fields)
# to_serializable will then repr() or process it further


# ---------------------------------------------------------------------------
# Integration tests: TraceCollectionListener._build_event_data
# ---------------------------------------------------------------------------

class TestBuildEventData:
@pytest.fixture
def listener(self):
from crewai.events.listeners.tracing.trace_listener import (
TraceCollectionListener,
)
TraceCollectionListener._instance = None
TraceCollectionListener._initialized = False
TraceCollectionListener._listeners_setup = False
return TraceCollectionListener()

def test_crew_kickoff_started_has_crew_structure(self, listener):
agent = _make_stub_agent(role="Researcher")
agent.tools = [_make_stub_tool("search"), _make_stub_tool("read")]

task = _make_mock_task(name="Research Task", agent=agent)
task.context = None

crew = MagicMock()
crew.agents = [agent]
crew.tasks = [task]
crew.process = "sequential"
crew.verbose = True
crew.memory = False
crew.fingerprint = MagicMock()
crew.fingerprint.uuid_str = str(uuid.uuid4())
crew.fingerprint.metadata = {}

from crewai.events.types.crew_events import CrewKickoffStartedEvent
event = CrewKickoffStartedEvent(
crew=crew, crew_name="TestCrew", inputs={"key": "value"}
)

result = listener._build_event_data("crew_kickoff_started", event, None)

assert "crew_structure" in result
cs = result["crew_structure"]
assert len(cs["agents"]) == 1
assert cs["agents"][0]["role"] == "Researcher"
assert cs["agents"][0]["tool_names"] == ["search", "read"]
assert len(cs["tasks"]) == 1
assert cs["tasks"][0]["name"] == "Research Task"
assert "agent_ref" in cs["tasks"][0]
assert cs["tasks"][0]["agent_ref"]["role"] == "Researcher"

def test_crew_kickoff_started_context_task_ids(self, listener):
agent = _make_stub_agent()
task1 = _make_mock_task(name="Task 1", agent=agent)
task1.context = None
task2 = _make_mock_task(name="Task 2", agent=agent)
task2.context = [task1]

crew = MagicMock()
crew.agents = [agent]
crew.tasks = [task1, task2]
crew.process = "sequential"
crew.verbose = False
crew.memory = False
crew.fingerprint = MagicMock()
crew.fingerprint.uuid_str = str(uuid.uuid4())
crew.fingerprint.metadata = {}

from crewai.events.types.crew_events import CrewKickoffStartedEvent
event = CrewKickoffStartedEvent(
crew=crew, crew_name="TestCrew", inputs=None
)

result = listener._build_event_data("crew_kickoff_started", event, None)
task2_data = result["crew_structure"]["tasks"][1]
assert "context_task_ids" in task2_data
assert str(task1.id) in task2_data["context_task_ids"]

def test_generic_event_uses_trace_context(self, listener):
"""Non-complex events should use context-based serialization."""
from crewai.events.types.crew_events import CrewKickoffCompletedEvent

crew = MagicMock()
crew.fingerprint = MagicMock()
crew.fingerprint.uuid_str = str(uuid.uuid4())
crew.fingerprint.metadata = {}

event = CrewKickoffCompletedEvent(
crew=crew, crew_name="TestCrew", output="Final result", total_tokens=5000
)

result = listener._build_event_data("crew_kickoff_completed", event, None)

# Scalar fields preserved
assert result.get("crew_name") == "TestCrew"
assert result.get("total_tokens") == 5000
# crew excluded by field serializer
assert result.get("crew") is None
# No crew_structure (that's only for kickoff_started)
assert "crew_structure" not in result

def test_task_started_custom_projection(self, listener):
task = _make_mock_task(name="Test Task")
from crewai.events.types.task_events import TaskStartedEvent
event = TaskStartedEvent(task=task, context="test context")
source = MagicMock()
source.agent = _make_stub_agent(role="Worker")

result = listener._build_event_data("task_started", event, source)

assert result["task_name"] == "Test Task"
assert result["agent_role"] == "Worker"
assert result["task_id"] == str(task.id)
assert result["context"] == "test context"

def test_llm_call_started_uses_trace_context(self, listener):
from crewai.events.types.llm_events import LLMCallStartedEvent

event = LLMCallStartedEvent(
call_id="test",
messages=[{"role": "user", "content": "Hello"}],
tools=[{"name": "search"}],
callbacks=[MagicMock()],
available_functions={"fn": MagicMock()},
)

result = listener._build_event_data("llm_call_started", event, None)

# callbacks and available_functions excluded via field serializer
assert result.get("callbacks") is None
assert result.get("available_functions") is None
# tools preserved (lightweight schemas)
assert result.get("tools") == [{"name": "search"}]

def test_agent_execution_error_preserves_identification(self, listener):
"""Error events should preserve agent/task identification via field serializers."""
from crewai.events.types.agent_events import AgentExecutionErrorEvent

agent = _make_stub_agent(role="Analyst")
task = _make_mock_task(name="Analysis")

event = AgentExecutionErrorEvent(
agent=agent, task=task, error="Something broke"
)

result = listener._build_event_data("agent_execution_error", event, None)

# Field serializers return lightweight refs, not None
assert result["agent"] == {"id": str(agent.id), "role": "Analyst"}
assert result["task"] == {"id": str(task.id), "name": "Analysis"}
assert result["error"] == "Something broke"

def test_task_failed_preserves_identification(self, listener):
from crewai.events.types.task_events import TaskFailedEvent

task = _make_mock_task(name="Failed Task")
event = TaskFailedEvent(task=task, error="Task failed")

result = listener._build_event_data("task_failed", event, None)

assert result["task"] == {"id": str(task.id), "name": "Failed Task"}
assert result["error"] == "Task failed"


# ---------------------------------------------------------------------------
# Size reduction verification
# ---------------------------------------------------------------------------

class TestSizeReduction:
@pytest.fixture
def listener(self):
from crewai.events.listeners.tracing.trace_listener import (
TraceCollectionListener,
)
TraceCollectionListener._instance = None
TraceCollectionListener._initialized = False
TraceCollectionListener._listeners_setup = False
return TraceCollectionListener()

def test_task_started_event_size(self, listener):
"""task_started event data should be well under 2KB."""
agent = _make_stub_agent(
role="Researcher",
goal="Research" * 50,
backstory="Expert" * 100,
)
agent.tools = [_make_stub_tool(f"tool_{i}") for i in range(5)]

task = _make_mock_task(
name="Research Task",
description="Detailed description" * 20,
expected_output="Expected" * 10,
agent=agent,
)
task.context = [_make_mock_task() for _ in range(3)]
task.tools = [_make_stub_tool(f"t_{i}") for i in range(3)]

from crewai.events.types.task_events import TaskStartedEvent
event = TaskStartedEvent(task=task, context="test context")
source = MagicMock()
source.agent = agent

result = listener._build_event_data("task_started", event, source)
serialized = json.dumps(result, default=str)

assert len(serialized) < 2000, f"task_started too large: {len(serialized)} bytes"
assert "task_name" in result
assert "agent_role" in result

def test_error_event_size(self, listener):
"""Error events should be small despite having agent/task refs."""
from crewai.events.types.agent_events import AgentExecutionErrorEvent

agent = _make_stub_agent(
goal="Very long goal " * 100,
backstory="Very long backstory " * 100,
)
task = _make_mock_task(description="Very long description " * 100)

event = AgentExecutionErrorEvent(
agent=agent, task=task, error="error"
)

result = listener._build_event_data("agent_execution_error", event, None)
serialized = json.dumps(result, default=str)

# Should be small - agent/task are just {id, role/name} refs
assert len(serialized) < 5000, f"error event too large: {len(serialized)} bytes"


# ---------------------------------------------------------------------------
# to_serializable context threading
# ---------------------------------------------------------------------------

class TestToSerializableContext:
"""Test that context parameter flows through to_serializable correctly."""

def test_context_passed_to_model_dump(self):
from crewai.events.types.agent_events import AgentExecutionErrorEvent

agent = _make_stub_agent(role="Tester")
task = _make_mock_task(name="Test Task")

event = AgentExecutionErrorEvent(
agent=agent, task=task, error="test error"
)

# Directly use to_serializable with context
result = to_serializable(event, context={"trace": True})
assert isinstance(result, dict)
assert result["agent"] == {"id": str(agent.id), "role": "Tester"}
assert result["task"] == {"id": str(task.id), "name": "Test Task"}

def test_no_context_does_not_trigger_serializers(self):
from crewai.events.types.crew_events import CrewKickoffStartedEvent

crew = MagicMock()
crew.fingerprint = MagicMock()
crew.fingerprint.uuid_str = str(uuid.uuid4())
crew.fingerprint.metadata = {}

event = CrewKickoffStartedEvent(
crew=crew, crew_name="Test", inputs=None
)

# Without context, crew should NOT be None
result = event.model_dump()
assert result["crew"] is not None
@@ -879,6 +879,35 @@ def test_llm_emits_call_started_event():
assert started_events[0].task_id is None


@pytest.mark.vcr()
def test_llm_completed_event_includes_usage():
completed_events: list[LLMCallCompletedEvent] = []
condition = threading.Condition()

@crewai_event_bus.on(LLMCallCompletedEvent)
def handle_llm_call_completed(source, event):
with condition:
completed_events.append(event)
condition.notify()

llm = LLM(model="gpt-4o-mini")
llm.call("Say hello")

with condition:
success = condition.wait_for(
lambda: len(completed_events) >= 1,
timeout=10,
)
assert success, "Timeout waiting for LLMCallCompletedEvent"

event = completed_events[0]
assert event.usage is not None
assert isinstance(event.usage, dict)
assert event.usage.get("prompt_tokens", 0) > 0
assert event.usage.get("completion_tokens", 0) > 0
assert event.usage.get("total_tokens", 0) > 0


@pytest.mark.vcr()
def test_llm_emits_call_failed_event():
received_events = []
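
Editor's note: the new usage test above doubles as the consumer recipe — register a handler on the event bus and read `event.usage` off `LLMCallCompletedEvent`. A minimal standalone sketch (assumes a configured OpenAI API key and network access; `log_usage` is an illustrative name, not part of the library):

    from crewai.events.event_bus import crewai_event_bus
    from crewai.events.types.llm_events import LLMCallCompletedEvent
    from crewai.llm import LLM

    @crewai_event_bus.on(LLMCallCompletedEvent)
    def log_usage(source, event: LLMCallCompletedEvent) -> None:
        # event.usage is a plain dict, or None when the provider reported nothing.
        if event.usage:
            print(f"call {event.call_id}: {event.usage.get('total_tokens', 0)} total tokens")

    LLM(model="gpt-4o-mini").call("Say hello")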
2 uv.lock generated
@@ -1243,7 +1243,7 @@ requires-dist = [
{ name = "json-repair", specifier = "~=0.25.2" },
{ name = "json5", specifier = "~=0.10.0" },
{ name = "jsonref", specifier = "~=1.1.0" },
{ name = "lancedb", specifier = ">=0.29.2" },
{ name = "lancedb", specifier = ">=0.29.2,<0.30.1" },
{ name = "litellm", marker = "extra == 'litellm'", specifier = ">=1.74.9,<=1.82.6" },
{ name = "mcp", specifier = "~=1.26.0" },
{ name = "mem0ai", marker = "extra == 'mem0'", specifier = "~=0.1.94" },