mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-03-31 16:18:14 +00:00
Compare commits
4 Commits
gl/refacto
...
fix/trace-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fad083ffaa | ||
|
|
b753012fc8 | ||
|
|
b40098b28e | ||
|
|
a4f1164812 |
@@ -6,7 +6,6 @@ import warnings
|
||||
|
||||
from crewai.agent.core import Agent
|
||||
from crewai.agent.planning_config import PlanningConfig
|
||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||
from crewai.crew import Crew
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.flow.flow import Flow
|
||||
@@ -20,9 +19,6 @@ from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.telemetry.telemetry import Telemetry
|
||||
|
||||
|
||||
CrewAgentExecutor.model_rebuild()
|
||||
|
||||
|
||||
def _suppress_pydantic_deprecation_warnings() -> None:
|
||||
"""Suppress Pydantic deprecation warnings using targeted monkey patch."""
|
||||
original_warn = warnings.warn
|
||||
|
||||
@@ -25,6 +25,7 @@ from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
InstanceOf,
|
||||
PrivateAttr,
|
||||
model_validator,
|
||||
)
|
||||
@@ -166,10 +167,10 @@ class Agent(BaseAgent):
|
||||
default=True,
|
||||
description="Use system prompt for the agent.",
|
||||
)
|
||||
llm: str | BaseLLM | None = Field(
|
||||
llm: str | InstanceOf[BaseLLM] | None = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
function_calling_llm: str | BaseLLM | None = Field(
|
||||
function_calling_llm: str | InstanceOf[BaseLLM] | None = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
system_template: str | None = Field(
|
||||
|
||||
@@ -12,6 +12,7 @@ from pydantic import (
|
||||
UUID4,
|
||||
BaseModel,
|
||||
Field,
|
||||
InstanceOf,
|
||||
PrivateAttr,
|
||||
field_validator,
|
||||
model_validator,
|
||||
@@ -184,7 +185,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
default=None,
|
||||
description="Knowledge sources for the agent.",
|
||||
)
|
||||
knowledge_storage: BaseKnowledgeStorage | None = Field(
|
||||
knowledge_storage: InstanceOf[BaseKnowledgeStorage] | None = Field(
|
||||
default=None,
|
||||
description="Custom knowledge storage for the agent.",
|
||||
)
|
||||
|
||||
@@ -14,15 +14,8 @@ import inspect
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
PrivateAttr,
|
||||
ValidationError,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
from pydantic import BaseModel, GetCoreSchemaHandler, ValidationError
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
|
||||
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
|
||||
from crewai.agents.parser import (
|
||||
@@ -30,7 +23,6 @@ from crewai.agents.parser import (
|
||||
AgentFinish,
|
||||
OutputParserError,
|
||||
)
|
||||
from crewai.agents.tools_handler import ToolsHandler
|
||||
from crewai.core.providers.human_input import ExecutorContext, get_provider
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.logging_events import (
|
||||
@@ -46,9 +38,6 @@ from crewai.hooks.tool_hooks import (
|
||||
get_after_tool_call_hooks,
|
||||
get_before_tool_call_hooks,
|
||||
)
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
from crewai.utilities.agent_utils import (
|
||||
aget_llm_response,
|
||||
convert_tools_to_openai_schema,
|
||||
@@ -70,65 +59,106 @@ from crewai.utilities.constants import TRAINING_DATA_FILE
|
||||
from crewai.utilities.file_store import aget_all_files, get_all_files
|
||||
from crewai.utilities.i18n import I18N, get_i18n
|
||||
from crewai.utilities.printer import Printer
|
||||
from crewai.utilities.prompts import StandardPromptResult, SystemPromptResult
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
from crewai.utilities.tool_utils import (
|
||||
aexecute_tool_and_check_finality,
|
||||
execute_tool_and_check_finality,
|
||||
)
|
||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent import Agent
|
||||
from crewai.agents.tools_handler import ToolsHandler
|
||||
from crewai.crew import Crew
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.task import Task
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
from crewai.tools.tool_types import ToolResult
|
||||
from crewai.utilities.prompts import StandardPromptResult, SystemPromptResult
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
class CrewAgentExecutor(BaseModel, CrewAgentExecutorMixin):
|
||||
class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
"""Executor for crew agents.
|
||||
|
||||
Manages the execution lifecycle of an agent including prompt formatting,
|
||||
LLM interactions, tool execution, and feedback handling.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)
|
||||
def __init__(
|
||||
self,
|
||||
llm: BaseLLM,
|
||||
task: Task,
|
||||
crew: Crew,
|
||||
agent: Agent,
|
||||
prompt: SystemPromptResult | StandardPromptResult,
|
||||
max_iter: int,
|
||||
tools: list[CrewStructuredTool],
|
||||
tools_names: str,
|
||||
stop_words: list[str],
|
||||
tools_description: str,
|
||||
tools_handler: ToolsHandler,
|
||||
step_callback: Any = None,
|
||||
original_tools: list[BaseTool] | None = None,
|
||||
function_calling_llm: BaseLLM | Any | None = None,
|
||||
respect_context_window: bool = False,
|
||||
request_within_rpm_limit: Callable[[], bool] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
i18n: I18N | None = None,
|
||||
) -> None:
|
||||
"""Initialize executor.
|
||||
|
||||
llm: BaseLLM
|
||||
task: Task | None = None
|
||||
crew: Crew | None = None
|
||||
agent: Agent
|
||||
prompt: SystemPromptResult | StandardPromptResult
|
||||
max_iter: int
|
||||
tools: list[CrewStructuredTool]
|
||||
tools_names: str
|
||||
stop: list[str] = Field(alias="stop_words")
|
||||
tools_description: str
|
||||
tools_handler: ToolsHandler
|
||||
step_callback: Any = None
|
||||
original_tools: list[BaseTool] = Field(default_factory=list)
|
||||
function_calling_llm: BaseLLM | Any | None = None
|
||||
respect_context_window: bool = False
|
||||
request_within_rpm_limit: Callable[[], bool] | None = None
|
||||
callbacks: list[Any] = Field(default_factory=list)
|
||||
response_model: type[BaseModel] | None = None
|
||||
i18n: I18N | None = Field(default=None, exclude=True)
|
||||
ask_for_human_input: bool = False
|
||||
messages: list[LLMMessage] = Field(default_factory=list)
|
||||
iterations: int = 0
|
||||
log_error_after: int = 3
|
||||
before_llm_call_hooks: list[Callable[..., Any]] = Field(default_factory=list)
|
||||
after_llm_call_hooks: list[Callable[..., Any]] = Field(default_factory=list)
|
||||
_i18n: I18N = PrivateAttr()
|
||||
_printer: Printer = PrivateAttr(default_factory=Printer)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _init_executor(self) -> Self:
|
||||
self._i18n = self.i18n or get_i18n()
|
||||
Args:
|
||||
llm: Language model instance.
|
||||
task: Task to execute.
|
||||
crew: Crew instance.
|
||||
agent: Agent to execute.
|
||||
prompt: Prompt templates.
|
||||
max_iter: Maximum iterations.
|
||||
tools: Available tools.
|
||||
tools_names: Tool names string.
|
||||
stop_words: Stop word list.
|
||||
tools_description: Tool descriptions.
|
||||
tools_handler: Tool handler instance.
|
||||
step_callback: Optional step callback.
|
||||
original_tools: Original tool list.
|
||||
function_calling_llm: Optional function calling LLM.
|
||||
respect_context_window: Respect context limits.
|
||||
request_within_rpm_limit: RPM limit check function.
|
||||
callbacks: Optional callbacks list.
|
||||
response_model: Optional Pydantic model for structured outputs.
|
||||
"""
|
||||
self._i18n: I18N = i18n or get_i18n()
|
||||
self.llm = llm
|
||||
self.task = task
|
||||
self.agent = agent
|
||||
self.crew = crew
|
||||
self.prompt = prompt
|
||||
self.tools = tools
|
||||
self.tools_names = tools_names
|
||||
self.stop = stop_words
|
||||
self.max_iter = max_iter
|
||||
self.callbacks = callbacks or []
|
||||
self._printer: Printer = Printer()
|
||||
self.tools_handler = tools_handler
|
||||
self.original_tools = original_tools or []
|
||||
self.step_callback = step_callback
|
||||
self.tools_description = tools_description
|
||||
self.function_calling_llm = function_calling_llm
|
||||
self.respect_context_window = respect_context_window
|
||||
self.request_within_rpm_limit = request_within_rpm_limit
|
||||
self.response_model = response_model
|
||||
self.ask_for_human_input = False
|
||||
self.messages: list[LLMMessage] = []
|
||||
self.iterations = 0
|
||||
self.log_error_after = 3
|
||||
self.before_llm_call_hooks: list[Callable[..., Any]] = []
|
||||
self.after_llm_call_hooks: list[Callable[..., Any]] = []
|
||||
self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
|
||||
self.after_llm_call_hooks.extend(get_after_llm_call_hooks())
|
||||
if self.llm:
|
||||
@@ -141,7 +171,6 @@ class CrewAgentExecutor(BaseModel, CrewAgentExecutorMixin):
|
||||
else self.stop
|
||||
)
|
||||
)
|
||||
return self
|
||||
|
||||
@property
|
||||
def use_stop_words(self) -> bool:
|
||||
@@ -1658,3 +1687,14 @@ class CrewAgentExecutor(BaseModel, CrewAgentExecutorMixin):
|
||||
return format_message_for_llm(
|
||||
self._i18n.slice("feedback_instructions").format(feedback=feedback)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, _source_type: Any, _handler: GetCoreSchemaHandler
|
||||
) -> CoreSchema:
|
||||
"""Generate Pydantic core schema for BaseClient Protocol.
|
||||
|
||||
This allows the Protocol to be used in Pydantic models without
|
||||
requiring arbitrary_types_allowed=True.
|
||||
"""
|
||||
return core_schema.any_schema()
|
||||
|
||||
@@ -22,6 +22,7 @@ from pydantic import (
|
||||
UUID4,
|
||||
BaseModel,
|
||||
Field,
|
||||
InstanceOf,
|
||||
Json,
|
||||
PrivateAttr,
|
||||
field_validator,
|
||||
@@ -175,7 +176,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
_rpm_controller: RPMController = PrivateAttr()
|
||||
_logger: Logger = PrivateAttr()
|
||||
_file_handler: FileHandler = PrivateAttr()
|
||||
_cache_handler: CacheHandler = PrivateAttr(default_factory=CacheHandler)
|
||||
_cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default_factory=CacheHandler)
|
||||
_memory: Memory | MemoryScope | MemorySlice | None = PrivateAttr(default=None)
|
||||
_train: bool | None = PrivateAttr(default=False)
|
||||
_train_iteration: int | None = PrivateAttr()
|
||||
@@ -209,13 +210,13 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default=None,
|
||||
description="Metrics for the LLM usage during all tasks execution.",
|
||||
)
|
||||
manager_llm: str | BaseLLM | None = Field(
|
||||
manager_llm: str | InstanceOf[BaseLLM] | None = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
manager_agent: BaseAgent | None = Field(
|
||||
description="Custom agent that will be used as manager.", default=None
|
||||
)
|
||||
function_calling_llm: str | LLM | None = Field(
|
||||
function_calling_llm: str | InstanceOf[LLM] | None = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
config: Json[dict[str, Any]] | dict[str, Any] | None = Field(default=None)
|
||||
@@ -266,7 +267,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default=False,
|
||||
description="Plan the crew execution and add the plan to the crew.",
|
||||
)
|
||||
planning_llm: str | BaseLLM | Any | None = Field(
|
||||
planning_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Language model that will run the AgentPlanner if planning is True."
|
||||
@@ -287,7 +288,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
"knowledge object."
|
||||
),
|
||||
)
|
||||
chat_llm: str | BaseLLM | Any | None = Field(
|
||||
chat_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
|
||||
default=None,
|
||||
description="LLM used to handle chatting with the crew.",
|
||||
)
|
||||
@@ -1799,7 +1800,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
def test(
|
||||
self,
|
||||
n_iterations: int,
|
||||
eval_llm: str | BaseLLM,
|
||||
eval_llm: str | InstanceOf[BaseLLM],
|
||||
inputs: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Test and evaluate the Crew with the given inputs for n iterations.
|
||||
|
||||
@@ -126,18 +126,59 @@ from crewai.events.types.tool_usage_events import (
|
||||
from crewai.events.utils.console_formatter import ConsoleFormatter
|
||||
|
||||
|
||||
# Fields to exclude from trace serialization to reduce redundant data.
|
||||
# These back-references and heavy objects create massive bloat when serialized
|
||||
# repeatedly across events (crew->agents->tasks->agent creates circular refs).
|
||||
TRACE_EXCLUDE_FIELDS = {
|
||||
# Back-references that create redundant/circular data
|
||||
"crew",
|
||||
"agent",
|
||||
"agents",
|
||||
"task",
|
||||
"tasks",
|
||||
"context",
|
||||
# Heavy fields not needed in individual trace events
|
||||
# NOTE: "tools" intentionally NOT here - LLMCallStartedEvent.tools is lightweight
|
||||
# (list of tool schemas). Agent.tools is excluded in _build_crew_started_data.
|
||||
"llm",
|
||||
"function_calling_llm",
|
||||
"step_callback",
|
||||
"task_callback",
|
||||
"crew_callback",
|
||||
"callbacks",
|
||||
"_memory",
|
||||
"_cache",
|
||||
"_rpm_controller",
|
||||
"_request_within_rpm_limit",
|
||||
"_token_process",
|
||||
"knowledge_sources",
|
||||
}
|
||||
|
||||
|
||||
def _serialize_for_trace(
|
||||
event: Any, extra_exclude: set[str] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Serialize an event for tracing, excluding redundant back-references.
|
||||
|
||||
Keeps all scalar fields (agent_role, task_name, etc.) that the AMP frontend uses.
|
||||
Replaces heavy nested objects with lightweight ID references to reduce trace bloat.
|
||||
|
||||
Args:
|
||||
event: The event object to serialize.
|
||||
extra_exclude: Additional fields to exclude beyond TRACE_EXCLUDE_FIELDS.
|
||||
|
||||
Returns:
|
||||
A dictionary with the serialized event data.
|
||||
"""
|
||||
exclude = TRACE_EXCLUDE_FIELDS.copy()
|
||||
if extra_exclude:
|
||||
exclude.update(extra_exclude)
|
||||
return safe_serialize_to_dict(event, exclude=exclude)
|
||||
|
||||
|
||||
class TraceCollectionListener(BaseEventListener):
|
||||
"""Trace collection listener that orchestrates trace collection."""
|
||||
|
||||
complex_events: ClassVar[list[str]] = [
|
||||
"task_started",
|
||||
"task_completed",
|
||||
"llm_call_started",
|
||||
"llm_call_completed",
|
||||
"agent_execution_started",
|
||||
"agent_execution_completed",
|
||||
]
|
||||
|
||||
_instance: Self | None = None
|
||||
_initialized: bool = False
|
||||
_listeners_setup: bool = False
|
||||
@@ -810,9 +851,17 @@ class TraceCollectionListener(BaseEventListener):
|
||||
def _build_event_data(
|
||||
self, event_type: str, event: Any, source: Any
|
||||
) -> dict[str, Any]:
|
||||
"""Build event data"""
|
||||
if event_type not in self.complex_events:
|
||||
return safe_serialize_to_dict(event)
|
||||
"""Build event data with optimized serialization to reduce trace bloat.
|
||||
|
||||
For most events, excludes heavy nested objects (crew, agents, tasks, tools)
|
||||
that would create massive redundant data. Only crew_kickoff_started gets
|
||||
the full crew structure as a one-time dump.
|
||||
"""
|
||||
# crew_kickoff_started is special: include full crew structure ONCE
|
||||
if event_type == "crew_kickoff_started":
|
||||
return self._build_crew_started_data(event)
|
||||
|
||||
# Complex events have custom handling that already extracts only needed fields
|
||||
if event_type == "task_started":
|
||||
task_name = event.task.name or event.task.description
|
||||
task_display_name = (
|
||||
@@ -853,19 +902,101 @@ class TraceCollectionListener(BaseEventListener):
|
||||
"agent_backstory": event.agent.backstory,
|
||||
}
|
||||
if event_type == "llm_call_started":
|
||||
event_data = safe_serialize_to_dict(event)
|
||||
event_data = _serialize_for_trace(event)
|
||||
event_data["task_name"] = event.task_name or getattr(
|
||||
event, "task_description", None
|
||||
)
|
||||
return event_data
|
||||
if event_type == "llm_call_completed":
|
||||
return safe_serialize_to_dict(event)
|
||||
return _serialize_for_trace(event)
|
||||
|
||||
return {
|
||||
"event_type": event_type,
|
||||
"event": safe_serialize_to_dict(event),
|
||||
"source": source,
|
||||
}
|
||||
# Error events need agent/task identification extracted before generic
|
||||
# serialization strips them (agent/task are in TRACE_EXCLUDE_FIELDS)
|
||||
if event_type == "agent_execution_error":
|
||||
event_data = _serialize_for_trace(event)
|
||||
if event.agent:
|
||||
event_data["agent_role"] = getattr(event.agent, "role", None)
|
||||
event_data["agent_id"] = str(getattr(event.agent, "id", ""))
|
||||
return event_data
|
||||
if event_type == "task_failed":
|
||||
event_data = _serialize_for_trace(event)
|
||||
if event.task:
|
||||
event_data["task_name"] = getattr(event.task, "name", None) or getattr(
|
||||
event.task, "description", None
|
||||
)
|
||||
event_data["task_id"] = str(getattr(event.task, "id", ""))
|
||||
return event_data
|
||||
|
||||
# For all other events, use lightweight serialization
|
||||
return _serialize_for_trace(event)
|
||||
|
||||
def _build_crew_started_data(self, event: Any) -> dict[str, Any]:
|
||||
"""Build comprehensive crew structure for crew_kickoff_started event.
|
||||
|
||||
This is the ONE place where we serialize the full crew structure.
|
||||
Subsequent events use lightweight references to avoid redundancy.
|
||||
"""
|
||||
event_data = _serialize_for_trace(event)
|
||||
|
||||
# Add full crew structure with optimized agent/task serialization
|
||||
crew = getattr(event, "crew", None)
|
||||
if crew is not None:
|
||||
# Serialize agents with tools (first occurrence only)
|
||||
agents_data = []
|
||||
for agent in getattr(crew, "agents", []) or []:
|
||||
agent_data = {
|
||||
"id": str(getattr(agent, "id", "")),
|
||||
"role": getattr(agent, "role", ""),
|
||||
"goal": getattr(agent, "goal", ""),
|
||||
"backstory": getattr(agent, "backstory", ""),
|
||||
"verbose": getattr(agent, "verbose", False),
|
||||
"allow_delegation": getattr(agent, "allow_delegation", False),
|
||||
"max_iter": getattr(agent, "max_iter", None),
|
||||
"max_rpm": getattr(agent, "max_rpm", None),
|
||||
}
|
||||
# Include tool names (not full tool objects)
|
||||
tools = getattr(agent, "tools", None)
|
||||
if tools:
|
||||
agent_data["tool_names"] = [
|
||||
getattr(t, "name", str(t)) for t in tools
|
||||
]
|
||||
agents_data.append(agent_data)
|
||||
|
||||
# Serialize tasks with lightweight agent references
|
||||
tasks_data = []
|
||||
for task in getattr(crew, "tasks", []) or []:
|
||||
task_data = {
|
||||
"id": str(getattr(task, "id", "")),
|
||||
"name": getattr(task, "name", None),
|
||||
"description": getattr(task, "description", ""),
|
||||
"expected_output": getattr(task, "expected_output", ""),
|
||||
"async_execution": getattr(task, "async_execution", False),
|
||||
"human_input": getattr(task, "human_input", False),
|
||||
}
|
||||
# Replace full agent with lightweight reference
|
||||
task_agent = getattr(task, "agent", None)
|
||||
if task_agent:
|
||||
task_data["agent_ref"] = {
|
||||
"id": str(getattr(task_agent, "id", "")),
|
||||
"role": getattr(task_agent, "role", ""),
|
||||
}
|
||||
# Replace context tasks with lightweight references
|
||||
context_tasks = getattr(task, "context", None)
|
||||
if context_tasks:
|
||||
task_data["context_task_ids"] = [
|
||||
str(getattr(ct, "id", "")) for ct in context_tasks
|
||||
]
|
||||
tasks_data.append(task_data)
|
||||
|
||||
event_data["crew_structure"] = {
|
||||
"agents": agents_data,
|
||||
"tasks": tasks_data,
|
||||
"process": str(getattr(crew, "process", "")),
|
||||
"verbose": getattr(crew, "verbose", False),
|
||||
"memory": getattr(crew, "memory", False),
|
||||
}
|
||||
|
||||
return event_data
|
||||
|
||||
def _show_tracing_disabled_message(self) -> None:
|
||||
"""Show a message when tracing is disabled."""
|
||||
|
||||
@@ -1966,6 +1966,37 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
|
||||
"original_tool": original_tool,
|
||||
}
|
||||
|
||||
def _extract_tool_name(self, tool_call: Any) -> str:
|
||||
"""Extract tool name from various tool call formats."""
|
||||
if hasattr(tool_call, "function"):
|
||||
return sanitize_tool_name(tool_call.function.name)
|
||||
if hasattr(tool_call, "function_call") and tool_call.function_call:
|
||||
return sanitize_tool_name(tool_call.function_call.name)
|
||||
if hasattr(tool_call, "name"):
|
||||
return sanitize_tool_name(tool_call.name)
|
||||
if isinstance(tool_call, dict):
|
||||
func_info = tool_call.get("function", {})
|
||||
return sanitize_tool_name(
|
||||
func_info.get("name", "") or tool_call.get("name", "unknown")
|
||||
)
|
||||
return "unknown"
|
||||
|
||||
@router(execute_native_tool)
|
||||
def check_native_todo_completion(
|
||||
self,
|
||||
) -> Literal["todo_satisfied", "todo_not_satisfied"]:
|
||||
"""Check if the native tool execution satisfied the active todo.
|
||||
|
||||
Similar to check_todo_completion but for native tool execution path.
|
||||
"""
|
||||
current_todo = self.state.todos.current_todo
|
||||
|
||||
if not current_todo:
|
||||
return "todo_not_satisfied"
|
||||
|
||||
# For native tools, any tool execution satisfies the todo
|
||||
return "todo_satisfied"
|
||||
|
||||
@listen("initialized")
|
||||
def continue_iteration(self) -> Literal["check_iteration"]:
|
||||
"""Bridge listener that connects iteration loop back to iteration check."""
|
||||
|
||||
@@ -3,15 +3,12 @@ from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.rag.types import SearchResult
|
||||
|
||||
|
||||
class BaseKnowledgeStorage(BaseModel, ABC):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
class BaseKnowledgeStorage(ABC):
|
||||
"""Abstract base class for knowledge storage implementations."""
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -3,9 +3,6 @@ import traceback
|
||||
from typing import Any, cast
|
||||
import warnings
|
||||
|
||||
from pydantic import Field, PrivateAttr, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
|
||||
@@ -25,32 +22,31 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
search efficiency.
|
||||
"""
|
||||
|
||||
collection_name: str | None = None
|
||||
embedder: (
|
||||
ProviderSpec
|
||||
def __init__(
|
||||
self,
|
||||
embedder: ProviderSpec
|
||||
| BaseEmbeddingsProvider[Any]
|
||||
| type[BaseEmbeddingsProvider[Any]]
|
||||
| None
|
||||
) = Field(default=None, exclude=True)
|
||||
_client: BaseClient | None = PrivateAttr(default=None)
|
||||
| None = None,
|
||||
collection_name: str | None = None,
|
||||
) -> None:
|
||||
self.collection_name = collection_name
|
||||
self._client: BaseClient | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _init_client(self) -> Self:
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message=r".*'model_fields'.*is deprecated.*",
|
||||
module=r"^chromadb(\.|$)",
|
||||
)
|
||||
|
||||
if self.embedder:
|
||||
embedding_function = build_embedder(self.embedder) # type: ignore[arg-type]
|
||||
if embedder:
|
||||
embedding_function = build_embedder(embedder) # type: ignore[arg-type]
|
||||
config = ChromaDBConfig(
|
||||
embedding_function=cast(
|
||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||
)
|
||||
)
|
||||
self._client = create_client(config)
|
||||
return self
|
||||
|
||||
def _get_client(self) -> BaseClient:
|
||||
"""Get the appropriate client - instance-specific or global."""
|
||||
|
||||
@@ -22,6 +22,7 @@ from pydantic import (
|
||||
UUID4,
|
||||
BaseModel,
|
||||
Field,
|
||||
InstanceOf,
|
||||
PrivateAttr,
|
||||
field_validator,
|
||||
model_validator,
|
||||
@@ -203,7 +204,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
role: str = Field(description="Role of the agent")
|
||||
goal: str = Field(description="Goal of the agent")
|
||||
backstory: str = Field(description="Backstory of the agent")
|
||||
llm: str | BaseLLM | Any | None = Field(
|
||||
llm: str | InstanceOf[BaseLLM] | Any | None = Field(
|
||||
default=None, description="Language model that will run the agent"
|
||||
)
|
||||
tools: list[BaseTool] = Field(
|
||||
|
||||
@@ -20,7 +20,8 @@ from typing import (
|
||||
)
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
@@ -36,12 +37,7 @@ from crewai.events.types.tool_usage_events import (
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageStartedEvent,
|
||||
)
|
||||
from crewai.llms.base_llm import (
|
||||
BaseLLM,
|
||||
JsonResponseFormat,
|
||||
get_current_call_id,
|
||||
llm_call_context,
|
||||
)
|
||||
from crewai.llms.base_llm import BaseLLM, get_current_call_id, llm_call_context
|
||||
from crewai.llms.constants import (
|
||||
ANTHROPIC_MODELS,
|
||||
AZURE_MODELS,
|
||||
@@ -67,6 +63,8 @@ except ImportError:
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent.core import Agent
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
from crewai.llms.providers.anthropic.completion import AnthropicThinkingConfig
|
||||
from crewai.task import Task
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.utilities.types import LLMMessage
|
||||
@@ -344,27 +342,6 @@ class AccumulatedToolArgs(BaseModel):
|
||||
|
||||
class LLM(BaseLLM):
|
||||
completion_cost: float | None = None
|
||||
timeout: float | int | None = None
|
||||
top_p: float | None = None
|
||||
n: int | None = None
|
||||
max_completion_tokens: int | None = None
|
||||
max_tokens: int | float | None = None
|
||||
presence_penalty: float | None = None
|
||||
frequency_penalty: float | None = None
|
||||
logit_bias: dict[int, float] | None = None
|
||||
response_format: JsonResponseFormat | type[BaseModel] | None = None
|
||||
seed: int | None = None
|
||||
logprobs: int | None = None
|
||||
top_logprobs: int | None = None
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
callbacks: list[Any] | None = None
|
||||
reasoning_effort: Literal["none", "low", "medium", "high"] | None = None
|
||||
stream: bool = False
|
||||
interceptor: Any = None
|
||||
thinking: Any = None
|
||||
context_window_size: int = 0
|
||||
is_anthropic: bool = False
|
||||
|
||||
def __new__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> LLM:
|
||||
"""Factory method that routes to native SDK or falls back to LiteLLM.
|
||||
@@ -459,7 +436,10 @@ class LLM(BaseLLM):
|
||||
logger.error(error_msg)
|
||||
raise ImportError(error_msg) from None
|
||||
|
||||
return object.__new__(cls)
|
||||
instance = object.__new__(cls)
|
||||
super(LLM, instance).__init__(model=model, is_litellm=True, **kwargs)
|
||||
instance.is_litellm = True
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
def _matches_provider_pattern(cls, model: str, provider: str) -> bool:
|
||||
@@ -644,23 +624,89 @@ class LLM(BaseLLM):
|
||||
|
||||
return None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _validate_llm_fields(cls, data: Any) -> Any:
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
model = data.get("model", "")
|
||||
data["is_anthropic"] = cls._is_anthropic_model(model)
|
||||
return data
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
timeout: float | int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
n: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | float | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[int, float] | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
seed: int | None = None,
|
||||
logprobs: int | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
base_url: str | None = None,
|
||||
api_base: str | None = None,
|
||||
api_version: str | None = None,
|
||||
api_key: str | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
|
||||
stream: bool = False,
|
||||
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
|
||||
thinking: AnthropicThinkingConfig | dict[str, Any] | None = None,
|
||||
prefer_upload: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize LLM instance.
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _init_litellm(self) -> LLM:
|
||||
self.is_litellm = True
|
||||
if LITELLM_AVAILABLE:
|
||||
litellm.drop_params = True
|
||||
self.set_callbacks(self.callbacks or [])
|
||||
self.set_env_callbacks()
|
||||
return self
|
||||
Note: This __init__ method is only called for fallback instances.
|
||||
Native provider instances handle their own initialization in their respective classes.
|
||||
"""
|
||||
super().__init__(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
timeout=timeout,
|
||||
**kwargs,
|
||||
)
|
||||
self.model = model
|
||||
self.timeout = timeout
|
||||
self.temperature = temperature
|
||||
self.top_p = top_p
|
||||
self.n = n
|
||||
self.max_completion_tokens = max_completion_tokens
|
||||
self.max_tokens = max_tokens
|
||||
self.presence_penalty = presence_penalty
|
||||
self.frequency_penalty = frequency_penalty
|
||||
self.logit_bias = logit_bias
|
||||
self.response_format = response_format
|
||||
self.seed = seed
|
||||
self.logprobs = logprobs
|
||||
self.top_logprobs = top_logprobs
|
||||
self.base_url = base_url
|
||||
self.api_base = api_base
|
||||
self.api_version = api_version
|
||||
self.api_key = api_key
|
||||
self.callbacks = callbacks
|
||||
self.context_window_size = 0
|
||||
self.reasoning_effort = reasoning_effort
|
||||
self.prefer_upload = prefer_upload
|
||||
self.additional_params = {
|
||||
k: v for k, v in kwargs.items() if k not in ("is_litellm", "provider")
|
||||
}
|
||||
self.is_anthropic = self._is_anthropic_model(model)
|
||||
self.stream = stream
|
||||
self.interceptor = interceptor
|
||||
|
||||
litellm.drop_params = True
|
||||
|
||||
# Normalize self.stop to always be a list[str]
|
||||
if stop is None:
|
||||
self.stop: list[str] = []
|
||||
elif isinstance(stop, str):
|
||||
self.stop = [stop]
|
||||
else:
|
||||
self.stop = stop
|
||||
|
||||
self.set_callbacks(callbacks or [])
|
||||
self.set_env_callbacks()
|
||||
|
||||
@staticmethod
|
||||
def _is_anthropic_model(model: str) -> bool:
|
||||
@@ -2396,7 +2442,7 @@ class LLM(BaseLLM):
|
||||
**filtered_params,
|
||||
)
|
||||
|
||||
def __deepcopy__(self, memo: dict[int, Any] | None = None) -> LLM:
|
||||
def __deepcopy__(self, memo: dict[int, Any] | None) -> LLM:
|
||||
"""Create a deep copy of the LLM instance."""
|
||||
import copy
|
||||
|
||||
|
||||
@@ -14,18 +14,10 @@ from datetime import datetime
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, Final, Literal
|
||||
from typing import TYPE_CHECKING, Any, Final
|
||||
import uuid
|
||||
|
||||
from pydantic import (
|
||||
AliasChoices,
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
PrivateAttr,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import TypedDict
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.llm_events import (
|
||||
@@ -59,12 +51,6 @@ if TYPE_CHECKING:
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
class JsonResponseFormat(TypedDict):
|
||||
"""Response format requesting raw JSON output (e.g. ``{"type": "json_object"}``)."""
|
||||
|
||||
type: Literal["json_object"]
|
||||
|
||||
|
||||
DEFAULT_CONTEXT_WINDOW_SIZE: Final[int] = 4096
|
||||
DEFAULT_SUPPORTS_STOP_WORDS: Final[bool] = True
|
||||
_JSON_EXTRACTION_PATTERN: Final[re.Pattern[str]] = re.compile(r"\{.*}", re.DOTALL)
|
||||
@@ -96,7 +82,7 @@ def get_current_call_id() -> str:
|
||||
return call_id
|
||||
|
||||
|
||||
class BaseLLM(BaseModel, ABC):
|
||||
class BaseLLM(ABC):
|
||||
"""Abstract base class for LLM implementations.
|
||||
|
||||
This class defines the interface that all LLM implementations must follow.
|
||||
@@ -115,100 +101,56 @@ class BaseLLM(BaseModel, ABC):
|
||||
additional_params: Additional provider-specific parameters.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)
|
||||
|
||||
model: str
|
||||
temperature: float | None = None
|
||||
api_key: str | None = None
|
||||
base_url: str | None = None
|
||||
provider: str = Field(default="openai")
|
||||
prefer_upload: bool = False
|
||||
is_litellm: bool = False
|
||||
stop: list[str] = Field(
|
||||
default_factory=list,
|
||||
validation_alias=AliasChoices("stop", "stop_sequences"),
|
||||
)
|
||||
additional_params: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
if name in ("stop", "stop_sequences"):
|
||||
if value is None:
|
||||
value = []
|
||||
elif isinstance(value, str):
|
||||
value = [value]
|
||||
elif not isinstance(value, list):
|
||||
value = list(value)
|
||||
name = "stop"
|
||||
try:
|
||||
super().__setattr__(name, value)
|
||||
except ValueError:
|
||||
if name in self.model_fields:
|
||||
raise # Re-raise validation errors on declared fields
|
||||
# Fallback for attributes not declared as fields (e.g. mock patching)
|
||||
object.__setattr__(self, name, value)
|
||||
except AttributeError:
|
||||
object.__setattr__(self, name, value)
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
temperature: float | None = None,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
provider: str | None = None,
|
||||
prefer_upload: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the BaseLLM with default attributes.
|
||||
|
||||
def __delattr__(self, name: str) -> None:
|
||||
try:
|
||||
super().__delattr__(name)
|
||||
except AttributeError:
|
||||
object.__delattr__(self, name)
|
||||
|
||||
@property
|
||||
def stop_sequences(self) -> list[str]:
|
||||
"""Alias for ``stop`` — kept for backward compatibility with provider APIs.
|
||||
|
||||
Writes are handled by ``__setattr__``, which normalizes and redirects
|
||||
``stop_sequences`` assignments to the ``stop`` field.
|
||||
Args:
|
||||
model: The model identifier/name.
|
||||
temperature: Optional temperature setting for response generation.
|
||||
stop: Optional list of stop sequences for generation.
|
||||
prefer_upload: Whether to prefer file upload over inline base64.
|
||||
**kwargs: Additional provider-specific parameters.
|
||||
"""
|
||||
return self.stop
|
||||
if not model:
|
||||
raise ValueError("Model name is required and cannot be empty")
|
||||
|
||||
_token_usage: dict[str, int] = PrivateAttr(
|
||||
default_factory=lambda: {
|
||||
self.model = model
|
||||
self.temperature = temperature
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self.prefer_upload = prefer_upload
|
||||
# Store additional parameters for provider-specific use
|
||||
self.additional_params = kwargs
|
||||
self._provider = provider or "openai"
|
||||
|
||||
stop = kwargs.pop("stop", None)
|
||||
if stop is None:
|
||||
self.stop: list[str] = []
|
||||
elif isinstance(stop, str):
|
||||
self.stop = [stop]
|
||||
elif isinstance(stop, list):
|
||||
self.stop = stop
|
||||
else:
|
||||
self.stop = []
|
||||
|
||||
self._token_usage = {
|
||||
"total_tokens": 0,
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"successful_requests": 0,
|
||||
"cached_prompt_tokens": 0,
|
||||
}
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _validate_init_fields(cls, data: Any) -> Any:
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
|
||||
if not data.get("model"):
|
||||
raise ValueError("Model name is required and cannot be empty")
|
||||
|
||||
# Normalize stop: accept str, list, or None; also accept stop_sequences alias
|
||||
stop_seqs = data.pop("stop_sequences", None)
|
||||
stop = stop_seqs if stop_seqs is not None else data.get("stop")
|
||||
if stop is None:
|
||||
data["stop"] = []
|
||||
elif isinstance(stop, str):
|
||||
data["stop"] = [stop]
|
||||
elif isinstance(stop, list):
|
||||
data["stop"] = stop
|
||||
else:
|
||||
data["stop"] = list(stop)
|
||||
|
||||
# Default provider
|
||||
if not data.get("provider"):
|
||||
data["provider"] = "openai"
|
||||
|
||||
# Collect unknown kwargs into additional_params
|
||||
known_fields = set(cls.model_fields.keys())
|
||||
extras = {k: v for k, v in data.items() if k not in known_fields}
|
||||
for k in extras:
|
||||
data.pop(k)
|
||||
existing = data.get("additional_params") or {}
|
||||
existing.update(extras)
|
||||
data["additional_params"] = existing
|
||||
|
||||
return data
|
||||
|
||||
def to_config_dict(self) -> dict[str, Any]:
|
||||
"""Serialize this LLM to a dict that can reconstruct it via ``LLM(**config)``.
|
||||
@@ -232,6 +174,16 @@ class BaseLLM(BaseModel, ABC):
|
||||
|
||||
return config
|
||||
|
||||
@property
|
||||
def provider(self) -> str:
|
||||
"""Get the provider of the LLM."""
|
||||
return self._provider
|
||||
|
||||
@provider.setter
|
||||
def provider(self, value: str) -> None:
|
||||
"""Set the provider of the LLM."""
|
||||
self._provider = value
|
||||
|
||||
@abstractmethod
|
||||
def call(
|
||||
self,
|
||||
|
||||
@@ -3,13 +3,12 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Final, Literal, TypeGuard, cast
|
||||
from typing import TYPE_CHECKING, Any, Final, Literal, TypeGuard, cast
|
||||
|
||||
from pydantic import BaseModel, PrivateAttr, model_validator
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
from crewai.llms.base_llm import BaseLLM, JsonResponseFormat, llm_call_context
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
from crewai.llms.base_llm import BaseLLM, llm_call_context
|
||||
from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport
|
||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
@@ -18,6 +17,9 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
|
||||
try:
|
||||
from anthropic import Anthropic, AsyncAnthropic, transform_schema
|
||||
from anthropic.types import (
|
||||
@@ -148,47 +150,60 @@ class AnthropicCompletion(BaseLLM):
|
||||
offering native tool use, streaming support, and proper message formatting.
|
||||
"""
|
||||
|
||||
model: str = "claude-3-5-sonnet-20241022"
|
||||
timeout: float | None = None
|
||||
max_retries: int = 2
|
||||
max_tokens: int = 4096
|
||||
top_p: float | None = None
|
||||
stream: bool = False
|
||||
client_params: dict[str, Any] | None = None
|
||||
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None
|
||||
thinking: AnthropicThinkingConfig | None = None
|
||||
response_format: JsonResponseFormat | type[BaseModel] | None = None
|
||||
tool_search: AnthropicToolSearchConfig | None = None
|
||||
is_claude_3: bool = False
|
||||
supports_tools: bool = True
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "claude-3-5-sonnet-20241022",
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
timeout: float | None = None,
|
||||
max_retries: int = 2,
|
||||
temperature: float | None = None,
|
||||
max_tokens: int = 4096, # Required for Anthropic
|
||||
top_p: float | None = None,
|
||||
stop_sequences: list[str] | None = None,
|
||||
stream: bool = False,
|
||||
client_params: dict[str, Any] | None = None,
|
||||
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
|
||||
thinking: AnthropicThinkingConfig | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
tool_search: AnthropicToolSearchConfig | bool | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize Anthropic chat completion client.
|
||||
|
||||
_client: Any = PrivateAttr(default=None)
|
||||
_async_client: Any = PrivateAttr(default=None)
|
||||
_previous_thinking_blocks: list[Any] = PrivateAttr(default_factory=list)
|
||||
Args:
|
||||
model: Anthropic model name (e.g., 'claude-3-5-sonnet-20241022')
|
||||
api_key: Anthropic API key (defaults to ANTHROPIC_API_KEY env var)
|
||||
base_url: Custom base URL for Anthropic API
|
||||
timeout: Request timeout in seconds
|
||||
max_retries: Maximum number of retries
|
||||
temperature: Sampling temperature (0-1)
|
||||
max_tokens: Maximum tokens in response (required for Anthropic)
|
||||
top_p: Nucleus sampling parameter
|
||||
stop_sequences: Stop sequences (Anthropic uses stop_sequences, not stop)
|
||||
stream: Enable streaming responses
|
||||
client_params: Additional parameters for the Anthropic client
|
||||
interceptor: HTTP interceptor for modifying requests/responses at transport level.
|
||||
response_format: Pydantic model for structured output. When provided, responses
|
||||
will be validated against this model schema.
|
||||
tool_search: Enable Anthropic's server-side tool search. When True, uses "bm25"
|
||||
variant by default. Pass an AnthropicToolSearchConfig to choose "regex" or
|
||||
"bm25". When enabled, tools are automatically marked with defer_loading=True
|
||||
and a tool search tool is injected into the tools list.
|
||||
**kwargs: Additional parameters
|
||||
"""
|
||||
super().__init__(
|
||||
model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _normalize_anthropic_fields(cls, data: Any) -> Any:
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
# Anthropic uses stop_sequences; normalize from stop kwarg
|
||||
popped = data.pop("stop_sequences", None)
|
||||
seqs = popped if popped is not None else (data.get("stop") or [])
|
||||
if isinstance(seqs, str):
|
||||
seqs = [seqs]
|
||||
data["stop"] = seqs
|
||||
data["is_claude_3"] = "claude-3" in data.get("model", "").lower()
|
||||
# Normalize tool_search
|
||||
ts = data.get("tool_search")
|
||||
if ts is True:
|
||||
data["tool_search"] = AnthropicToolSearchConfig()
|
||||
elif ts is not None and not isinstance(ts, AnthropicToolSearchConfig):
|
||||
data["tool_search"] = None
|
||||
return data
|
||||
# Client params
|
||||
self.interceptor = interceptor
|
||||
self.client_params = client_params
|
||||
self.base_url = base_url
|
||||
self.timeout = timeout
|
||||
self.max_retries = max_retries
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _init_clients(self) -> AnthropicCompletion:
|
||||
self._client = Anthropic(**self._get_client_params())
|
||||
self.client = Anthropic(**self._get_client_params())
|
||||
|
||||
async_client_params = self._get_client_params()
|
||||
if self.interceptor:
|
||||
@@ -196,8 +211,51 @@ class AnthropicCompletion(BaseLLM):
|
||||
async_http_client = httpx.AsyncClient(transport=async_transport)
|
||||
async_client_params["http_client"] = async_http_client
|
||||
|
||||
self._async_client = AsyncAnthropic(**async_client_params)
|
||||
return self
|
||||
self.async_client = AsyncAnthropic(**async_client_params)
|
||||
|
||||
# Store completion parameters
|
||||
self.max_tokens = max_tokens
|
||||
self.top_p = top_p
|
||||
self.stream = stream
|
||||
self.stop_sequences = stop_sequences or []
|
||||
self.thinking = thinking
|
||||
self.previous_thinking_blocks: list[ThinkingBlock] = []
|
||||
self.response_format = response_format
|
||||
# Tool search config
|
||||
self.tool_search: AnthropicToolSearchConfig | None
|
||||
if tool_search is True:
|
||||
self.tool_search = AnthropicToolSearchConfig()
|
||||
elif isinstance(tool_search, AnthropicToolSearchConfig):
|
||||
self.tool_search = tool_search
|
||||
else:
|
||||
self.tool_search = None
|
||||
# Model-specific settings
|
||||
self.is_claude_3 = "claude-3" in model.lower()
|
||||
self.supports_tools = True
|
||||
|
||||
@property
|
||||
def stop(self) -> list[str]:
|
||||
"""Get stop sequences sent to the API."""
|
||||
return self.stop_sequences
|
||||
|
||||
@stop.setter
|
||||
def stop(self, value: list[str] | str | None) -> None:
|
||||
"""Set stop sequences.
|
||||
|
||||
Synchronizes stop_sequences to ensure values set by CrewAgentExecutor
|
||||
are properly sent to the Anthropic API.
|
||||
|
||||
Args:
|
||||
value: Stop sequences as a list, single string, or None
|
||||
"""
|
||||
if value is None:
|
||||
self.stop_sequences = []
|
||||
elif isinstance(value, str):
|
||||
self.stop_sequences = [value]
|
||||
elif isinstance(value, list):
|
||||
self.stop_sequences = value
|
||||
else:
|
||||
self.stop_sequences = []
|
||||
|
||||
def to_config_dict(self) -> dict[str, Any]:
|
||||
"""Extend base config with Anthropic-specific fields."""
|
||||
@@ -693,11 +751,11 @@ class AnthropicCompletion(BaseLLM):
|
||||
)
|
||||
elif isinstance(content, list):
|
||||
formatted_messages.append({"role": "assistant", "content": content})
|
||||
elif self.thinking and self._previous_thinking_blocks:
|
||||
elif self.thinking and self.previous_thinking_blocks:
|
||||
structured_content = cast(
|
||||
list[dict[str, Any]],
|
||||
[
|
||||
*self._previous_thinking_blocks,
|
||||
*self.previous_thinking_blocks,
|
||||
{"type": "text", "text": content if content else ""},
|
||||
],
|
||||
)
|
||||
@@ -751,7 +809,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: JsonResponseFormat | type[BaseModel] | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle non-streaming message completion."""
|
||||
uses_file_api = _contains_file_id_reference(params.get("messages", []))
|
||||
@@ -785,11 +843,11 @@ class AnthropicCompletion(BaseLLM):
|
||||
try:
|
||||
if betas:
|
||||
params["betas"] = betas
|
||||
response = self._client.beta.messages.create(
|
||||
response = self.client.beta.messages.create(
|
||||
**params, extra_body=extra_body
|
||||
)
|
||||
else:
|
||||
response = self._client.messages.create(**params)
|
||||
response = self.client.messages.create(**params)
|
||||
|
||||
except Exception as e:
|
||||
if is_context_length_exceeded(e):
|
||||
@@ -870,7 +928,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
thinking_blocks.append(cast(ThinkingBlock, thinking_block))
|
||||
|
||||
if thinking_blocks:
|
||||
self._previous_thinking_blocks = thinking_blocks
|
||||
self.previous_thinking_blocks = thinking_blocks
|
||||
|
||||
content = self._apply_stop_words(content)
|
||||
self._emit_call_completed_event(
|
||||
@@ -894,7 +952,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: JsonResponseFormat | type[BaseModel] | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle streaming message completion."""
|
||||
betas: list[str] = []
|
||||
@@ -933,9 +991,9 @@ class AnthropicCompletion(BaseLLM):
|
||||
current_tool_calls: dict[int, dict[str, Any]] = {}
|
||||
|
||||
stream_context = (
|
||||
self._client.beta.messages.stream(**stream_params, extra_body=extra_body)
|
||||
self.client.beta.messages.stream(**stream_params, extra_body=extra_body)
|
||||
if betas
|
||||
else self._client.messages.stream(**stream_params)
|
||||
else self.client.messages.stream(**stream_params)
|
||||
)
|
||||
with stream_context as stream:
|
||||
response_id = None
|
||||
@@ -1014,7 +1072,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
thinking_blocks.append(cast(ThinkingBlock, thinking_block))
|
||||
|
||||
if thinking_blocks:
|
||||
self._previous_thinking_blocks = thinking_blocks
|
||||
self.previous_thinking_blocks = thinking_blocks
|
||||
|
||||
usage = self._extract_anthropic_token_usage(final_message)
|
||||
self._track_token_usage_internal(usage)
|
||||
@@ -1211,7 +1269,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
|
||||
try:
|
||||
# Send tool results back to Claude for final response
|
||||
final_response: Message = self._client.messages.create(**follow_up_params)
|
||||
final_response: Message = self.client.messages.create(**follow_up_params)
|
||||
|
||||
# Track token usage for follow-up call
|
||||
follow_up_usage = self._extract_anthropic_token_usage(final_response)
|
||||
@@ -1230,7 +1288,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
thinking_blocks.append(cast(ThinkingBlock, thinking_block))
|
||||
|
||||
if thinking_blocks:
|
||||
self._previous_thinking_blocks = thinking_blocks
|
||||
self.previous_thinking_blocks = thinking_blocks
|
||||
|
||||
final_content = self._apply_stop_words(final_content)
|
||||
|
||||
@@ -1272,7 +1330,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: JsonResponseFormat | type[BaseModel] | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle non-streaming async message completion."""
|
||||
uses_file_api = _contains_file_id_reference(params.get("messages", []))
|
||||
@@ -1306,11 +1364,11 @@ class AnthropicCompletion(BaseLLM):
|
||||
try:
|
||||
if betas:
|
||||
params["betas"] = betas
|
||||
response = await self._async_client.beta.messages.create(
|
||||
response = await self.async_client.beta.messages.create(
|
||||
**params, extra_body=extra_body
|
||||
)
|
||||
else:
|
||||
response = await self._async_client.messages.create(**params)
|
||||
response = await self.async_client.messages.create(**params)
|
||||
|
||||
except Exception as e:
|
||||
if is_context_length_exceeded(e):
|
||||
@@ -1403,7 +1461,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: JsonResponseFormat | type[BaseModel] | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle async streaming message completion."""
|
||||
betas: list[str] = []
|
||||
@@ -1440,11 +1498,11 @@ class AnthropicCompletion(BaseLLM):
|
||||
current_tool_calls: dict[int, dict[str, Any]] = {}
|
||||
|
||||
stream_context = (
|
||||
self._async_client.beta.messages.stream(
|
||||
self.async_client.beta.messages.stream(
|
||||
**stream_params, extra_body=extra_body
|
||||
)
|
||||
if betas
|
||||
else self._async_client.messages.stream(**stream_params)
|
||||
else self.async_client.messages.stream(**stream_params)
|
||||
)
|
||||
async with stream_context as stream:
|
||||
response_id = None
|
||||
@@ -1606,7 +1664,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
]
|
||||
|
||||
try:
|
||||
final_response: Message = await self._async_client.messages.create(
|
||||
final_response: Message = await self.async_client.messages.create(
|
||||
**follow_up_params
|
||||
)
|
||||
|
||||
@@ -1728,8 +1786,8 @@ class AnthropicCompletion(BaseLLM):
|
||||
from crewai_files.uploaders.anthropic import AnthropicFileUploader
|
||||
|
||||
return AnthropicFileUploader(
|
||||
client=self._client,
|
||||
async_client=self._async_client,
|
||||
client=self.client,
|
||||
async_client=self.async_client,
|
||||
)
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
@@ -3,13 +3,11 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, TypedDict
|
||||
from urllib.parse import urlparse
|
||||
from typing import TYPE_CHECKING, Any, TypedDict
|
||||
|
||||
from pydantic import BaseModel, PrivateAttr, model_validator
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededError,
|
||||
@@ -18,6 +16,10 @@ from crewai.utilities.pydantic_schema_utils import generate_model_description
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
|
||||
|
||||
try:
|
||||
from azure.ai.inference import (
|
||||
ChatCompletionsClient,
|
||||
@@ -74,84 +76,109 @@ class AzureCompletion(BaseLLM):
|
||||
offering native function calling, streaming support, and proper Azure authentication.
|
||||
"""
|
||||
|
||||
endpoint: str | None = None
|
||||
api_version: str | None = None
|
||||
timeout: float | None = None
|
||||
max_retries: int = 2
|
||||
top_p: float | None = None
|
||||
frequency_penalty: float | None = None
|
||||
presence_penalty: float | None = None
|
||||
max_tokens: int | None = None
|
||||
stream: bool = False
|
||||
interceptor: BaseInterceptor[Any, Any] | None = None
|
||||
response_format: type[BaseModel] | None = None
|
||||
is_openai_model: bool = False
|
||||
is_azure_openai_endpoint: bool = False
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
api_key: str | None = None,
|
||||
endpoint: str | None = None,
|
||||
api_version: str | None = None,
|
||||
timeout: float | None = None,
|
||||
max_retries: int = 2,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = False,
|
||||
interceptor: BaseInterceptor[Any, Any] | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize Azure AI Inference chat completion client.
|
||||
|
||||
_client: Any = PrivateAttr(default=None)
|
||||
_async_client: Any = PrivateAttr(default=None)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _normalize_azure_fields(cls, data: Any) -> Any:
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
|
||||
if data.get("interceptor") is not None:
|
||||
Args:
|
||||
model: Azure deployment name or model name
|
||||
api_key: Azure API key (defaults to AZURE_API_KEY env var)
|
||||
endpoint: Azure endpoint URL (defaults to AZURE_ENDPOINT env var)
|
||||
api_version: Azure API version (defaults to AZURE_API_VERSION env var)
|
||||
timeout: Request timeout in seconds
|
||||
max_retries: Maximum number of retries
|
||||
temperature: Sampling temperature (0-2)
|
||||
top_p: Nucleus sampling parameter
|
||||
frequency_penalty: Frequency penalty (-2 to 2)
|
||||
presence_penalty: Presence penalty (-2 to 2)
|
||||
max_tokens: Maximum tokens in response
|
||||
stop: Stop sequences
|
||||
stream: Enable streaming responses
|
||||
interceptor: HTTP interceptor (not yet supported for Azure).
|
||||
response_format: Pydantic model for structured output. Used as default when
|
||||
response_model is not passed to call()/acall() methods.
|
||||
Only works with OpenAI models deployed on Azure.
|
||||
**kwargs: Additional parameters
|
||||
"""
|
||||
if interceptor is not None:
|
||||
raise NotImplementedError(
|
||||
"HTTP interceptors are not yet supported for Azure AI Inference provider. "
|
||||
"Interceptors are currently supported for OpenAI and Anthropic providers only."
|
||||
)
|
||||
|
||||
# Resolve env vars
|
||||
data["api_key"] = data.get("api_key") or os.getenv("AZURE_API_KEY")
|
||||
data["endpoint"] = (
|
||||
data.get("endpoint")
|
||||
super().__init__(
|
||||
model=model, temperature=temperature, stop=stop or [], **kwargs
|
||||
)
|
||||
|
||||
self.api_key = api_key or os.getenv("AZURE_API_KEY")
|
||||
self.endpoint = (
|
||||
endpoint
|
||||
or os.getenv("AZURE_ENDPOINT")
|
||||
or os.getenv("AZURE_OPENAI_ENDPOINT")
|
||||
or os.getenv("AZURE_API_BASE")
|
||||
)
|
||||
data["api_version"] = (
|
||||
data.get("api_version") or os.getenv("AZURE_API_VERSION") or "2024-06-01"
|
||||
)
|
||||
self.api_version = api_version or os.getenv("AZURE_API_VERSION") or "2024-06-01"
|
||||
self.timeout = timeout
|
||||
self.max_retries = max_retries
|
||||
|
||||
if not data["api_key"]:
|
||||
if not self.api_key:
|
||||
raise ValueError(
|
||||
"Azure API key is required. Set AZURE_API_KEY environment variable or pass api_key parameter."
|
||||
)
|
||||
if not data["endpoint"]:
|
||||
if not self.endpoint:
|
||||
raise ValueError(
|
||||
"Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter."
|
||||
)
|
||||
|
||||
model = data.get("model", "")
|
||||
data["endpoint"] = AzureCompletion._validate_and_fix_endpoint(
|
||||
data["endpoint"], model
|
||||
)
|
||||
data["is_openai_model"] = any(
|
||||
prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
|
||||
)
|
||||
parsed = urlparse(data["endpoint"])
|
||||
hostname = parsed.hostname or ""
|
||||
data["is_azure_openai_endpoint"] = (
|
||||
hostname == "openai.azure.com" or hostname.endswith(".openai.azure.com")
|
||||
) and "/openai/deployments/" in data["endpoint"]
|
||||
return data
|
||||
# Validate and potentially fix Azure OpenAI endpoint URL
|
||||
self.endpoint = self._validate_and_fix_endpoint(self.endpoint, model)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _init_clients(self) -> AzureCompletion:
|
||||
if not self.api_key:
|
||||
raise ValueError("Azure API key is required.")
|
||||
client_kwargs: dict[str, Any] = {
|
||||
# Build client kwargs
|
||||
client_kwargs = {
|
||||
"endpoint": self.endpoint,
|
||||
"credential": AzureKeyCredential(self.api_key),
|
||||
}
|
||||
|
||||
# Add api_version if specified (primarily for Azure OpenAI endpoints)
|
||||
if self.api_version:
|
||||
client_kwargs["api_version"] = self.api_version
|
||||
|
||||
self._client = ChatCompletionsClient(**client_kwargs)
|
||||
self._async_client = AsyncChatCompletionsClient(**client_kwargs)
|
||||
return self
|
||||
self.client = ChatCompletionsClient(**client_kwargs) # type: ignore[arg-type]
|
||||
|
||||
self.async_client = AsyncChatCompletionsClient(**client_kwargs) # type: ignore[arg-type]
|
||||
|
||||
self.top_p = top_p
|
||||
self.frequency_penalty = frequency_penalty
|
||||
self.presence_penalty = presence_penalty
|
||||
self.max_tokens = max_tokens
|
||||
self.stream = stream
|
||||
self.response_format = response_format
|
||||
|
||||
self.is_openai_model = any(
|
||||
prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
|
||||
)
|
||||
|
||||
self.is_azure_openai_endpoint = (
|
||||
"openai.azure.com" in self.endpoint
|
||||
and "/openai/deployments/" in self.endpoint
|
||||
)
|
||||
|
||||
def to_config_dict(self) -> dict[str, Any]:
|
||||
"""Extend base config with Azure-specific fields."""
|
||||
@@ -188,11 +215,7 @@ class AzureCompletion(BaseLLM):
|
||||
Returns:
|
||||
Validated and potentially corrected endpoint URL
|
||||
"""
|
||||
ep_host = urlparse(endpoint).hostname or ""
|
||||
is_azure_openai = ep_host == "openai.azure.com" or ep_host.endswith(
|
||||
".openai.azure.com"
|
||||
)
|
||||
if is_azure_openai and "/openai/deployments/" not in endpoint:
|
||||
if "openai.azure.com" in endpoint and "/openai/deployments/" not in endpoint:
|
||||
endpoint = endpoint.rstrip("/")
|
||||
|
||||
if not endpoint.endswith("/openai/deployments"):
|
||||
@@ -708,7 +731,7 @@ class AzureCompletion(BaseLLM):
|
||||
"""Handle non-streaming chat completion."""
|
||||
try:
|
||||
# Cast params to Any to avoid type checking issues with TypedDict unpacking
|
||||
response: ChatCompletions = self._client.complete(**params)
|
||||
response: ChatCompletions = self.client.complete(**params) # type: ignore[assignment,arg-type]
|
||||
return self._process_completion_response(
|
||||
response=response,
|
||||
params=params,
|
||||
@@ -903,7 +926,7 @@ class AzureCompletion(BaseLLM):
|
||||
tool_calls: dict[int, dict[str, Any]] = {}
|
||||
|
||||
usage_data = {"total_tokens": 0}
|
||||
for update in self._client.complete(**params):
|
||||
for update in self.client.complete(**params): # type: ignore[arg-type]
|
||||
if isinstance(update, StreamingChatCompletionsUpdate):
|
||||
if update.usage:
|
||||
usage = update.usage
|
||||
@@ -944,7 +967,7 @@ class AzureCompletion(BaseLLM):
|
||||
"""Handle non-streaming chat completion asynchronously."""
|
||||
try:
|
||||
# Cast params to Any to avoid type checking issues with TypedDict unpacking
|
||||
response: ChatCompletions = await self._async_client.complete(**params)
|
||||
response: ChatCompletions = await self.async_client.complete(**params) # type: ignore[assignment,arg-type]
|
||||
return self._process_completion_response(
|
||||
response=response,
|
||||
params=params,
|
||||
@@ -970,8 +993,8 @@ class AzureCompletion(BaseLLM):
|
||||
|
||||
usage_data = {"total_tokens": 0}
|
||||
|
||||
stream = await self._async_client.complete(**params)
|
||||
async for update in stream:
|
||||
stream = await self.async_client.complete(**params) # type: ignore[arg-type]
|
||||
async for update in stream: # type: ignore[union-attr]
|
||||
if isinstance(update, StreamingChatCompletionsUpdate):
|
||||
if hasattr(update, "usage") and update.usage:
|
||||
usage = update.usage
|
||||
@@ -1087,8 +1110,8 @@ class AzureCompletion(BaseLLM):
|
||||
This ensures proper cleanup of the underlying aiohttp session
|
||||
to avoid unclosed connector warnings.
|
||||
"""
|
||||
if hasattr(self._async_client, "close"):
|
||||
await self._async_client.close()
|
||||
if hasattr(self.async_client, "close"):
|
||||
await self.async_client.close()
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
"""Async context manager entry."""
|
||||
|
||||
@@ -7,7 +7,7 @@ import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, TypedDict, cast
|
||||
|
||||
from pydantic import BaseModel, PrivateAttr, model_validator
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Required
|
||||
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
@@ -33,7 +33,7 @@ if TYPE_CHECKING:
|
||||
ToolTypeDef,
|
||||
)
|
||||
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
|
||||
|
||||
try:
|
||||
@@ -228,97 +228,129 @@ class BedrockCompletion(BaseLLM):
|
||||
- Model-specific conversation format handling (e.g., Cohere requirements)
|
||||
"""
|
||||
|
||||
model: str = "anthropic.claude-3-5-sonnet-20241022-v2:0"
|
||||
aws_access_key_id: str | None = None
|
||||
aws_secret_access_key: str | None = None
|
||||
aws_session_token: str | None = None
|
||||
region_name: str | None = None
|
||||
max_tokens: int | None = None
|
||||
top_p: float | None = None
|
||||
top_k: int | None = None
|
||||
stream: bool = False
|
||||
guardrail_config: dict[str, Any] | None = None
|
||||
additional_model_request_fields: dict[str, Any] | None = None
|
||||
additional_model_response_field_paths: list[str] | None = None
|
||||
interceptor: BaseInterceptor[Any, Any] | None = None
|
||||
response_format: type[BaseModel] | None = None
|
||||
is_claude_model: bool = False
|
||||
supports_tools: bool = True
|
||||
supports_streaming: bool = True
|
||||
model_id: str = ""
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
aws_access_key_id: str | None = None,
|
||||
aws_secret_access_key: str | None = None,
|
||||
aws_session_token: str | None = None,
|
||||
region_name: str | None = None,
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
top_p: float | None = None,
|
||||
top_k: int | None = None,
|
||||
stop_sequences: Sequence[str] | None = None,
|
||||
stream: bool = False,
|
||||
guardrail_config: dict[str, Any] | None = None,
|
||||
additional_model_request_fields: dict[str, Any] | None = None,
|
||||
additional_model_response_field_paths: list[str] | None = None,
|
||||
interceptor: BaseInterceptor[Any, Any] | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize AWS Bedrock completion client.
|
||||
|
||||
_client: Any = PrivateAttr(default=None)
|
||||
_async_exit_stack: Any = PrivateAttr(default=None)
|
||||
_async_client_initialized: bool = PrivateAttr(default=False)
|
||||
_async_client: Any = PrivateAttr(default=None)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _normalize_bedrock_fields(cls, data: Any) -> Any:
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
|
||||
if data.get("interceptor") is not None:
|
||||
Args:
|
||||
model: The Bedrock model ID to use
|
||||
aws_access_key_id: AWS access key (defaults to environment variable)
|
||||
aws_secret_access_key: AWS secret key (defaults to environment variable)
|
||||
aws_session_token: AWS session token for temporary credentials
|
||||
region_name: AWS region name
|
||||
temperature: Sampling temperature for response generation
|
||||
max_tokens: Maximum tokens to generate
|
||||
top_p: Nucleus sampling parameter
|
||||
top_k: Top-k sampling parameter (Claude models only)
|
||||
stop_sequences: List of sequences that stop generation
|
||||
stream: Whether to use streaming responses
|
||||
guardrail_config: Guardrail configuration for content filtering
|
||||
additional_model_request_fields: Model-specific request parameters
|
||||
additional_model_response_field_paths: Custom response field paths
|
||||
interceptor: HTTP interceptor (not yet supported for Bedrock).
|
||||
response_format: Pydantic model for structured output. Used as default when
|
||||
response_model is not passed to call()/acall() methods.
|
||||
**kwargs: Additional parameters
|
||||
"""
|
||||
if interceptor is not None:
|
||||
raise NotImplementedError(
|
||||
"HTTP interceptors are not yet supported for AWS Bedrock provider. "
|
||||
"Interceptors are currently supported for OpenAI and Anthropic providers only."
|
||||
)
|
||||
|
||||
# Force provider to bedrock
|
||||
data.pop("provider", None)
|
||||
data["provider"] = "bedrock"
|
||||
# Extract provider from kwargs to avoid duplicate argument
|
||||
kwargs.pop("provider", None)
|
||||
|
||||
# Normalize stop_sequences from stop kwarg
|
||||
popped = data.pop("stop_sequences", None)
|
||||
seqs = popped if popped is not None else (data.get("stop") or [])
|
||||
if isinstance(seqs, str):
|
||||
seqs = [seqs]
|
||||
elif isinstance(seqs, Sequence) and not isinstance(seqs, list):
|
||||
seqs = list(seqs)
|
||||
data["stop"] = seqs
|
||||
super().__init__(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
stop=stop_sequences or [],
|
||||
provider="bedrock",
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Resolve env vars
|
||||
data["aws_access_key_id"] = data.get("aws_access_key_id") or os.getenv(
|
||||
"AWS_ACCESS_KEY_ID"
|
||||
# Configure client with timeouts and retries following AWS best practices
|
||||
config = Config(
|
||||
read_timeout=300,
|
||||
retries={
|
||||
"max_attempts": 3,
|
||||
"mode": "adaptive",
|
||||
},
|
||||
tcp_keepalive=True,
|
||||
)
|
||||
data["aws_secret_access_key"] = data.get("aws_secret_access_key") or os.getenv(
|
||||
"AWS_SECRET_ACCESS_KEY"
|
||||
)
|
||||
data["aws_session_token"] = data.get("aws_session_token") or os.getenv(
|
||||
"AWS_SESSION_TOKEN"
|
||||
)
|
||||
data["region_name"] = (
|
||||
data.get("region_name")
|
||||
|
||||
self.region_name = (
|
||||
region_name
|
||||
or os.getenv("AWS_DEFAULT_REGION")
|
||||
or os.getenv("AWS_REGION_NAME")
|
||||
or "us-east-1"
|
||||
)
|
||||
|
||||
model = data.get("model", "anthropic.claude-3-5-sonnet-20241022-v2:0")
|
||||
data["is_claude_model"] = "claude" in model.lower()
|
||||
data["model_id"] = model
|
||||
return data
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _init_clients(self) -> BedrockCompletion:
|
||||
config = Config(
|
||||
read_timeout=300,
|
||||
retries={"max_attempts": 3, "mode": "adaptive"},
|
||||
tcp_keepalive=True,
|
||||
self.aws_access_key_id = aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID")
|
||||
self.aws_secret_access_key = aws_secret_access_key or os.getenv(
|
||||
"AWS_SECRET_ACCESS_KEY"
|
||||
)
|
||||
self.aws_session_token = aws_session_token or os.getenv("AWS_SESSION_TOKEN")
|
||||
|
||||
# Initialize Bedrock client with proper configuration
|
||||
session = Session(
|
||||
aws_access_key_id=self.aws_access_key_id,
|
||||
aws_secret_access_key=self.aws_secret_access_key,
|
||||
aws_session_token=self.aws_session_token,
|
||||
region_name=self.region_name,
|
||||
)
|
||||
self._client = session.client("bedrock-runtime", config=config)
|
||||
|
||||
self.client = session.client("bedrock-runtime", config=config)
|
||||
|
||||
self._async_exit_stack = AsyncExitStack() if AIOBOTOCORE_AVAILABLE else None
|
||||
return self
|
||||
self._async_client_initialized = False
|
||||
|
||||
# Store completion parameters
|
||||
self.max_tokens = max_tokens
|
||||
self.top_p = top_p
|
||||
self.top_k = top_k
|
||||
self.stream = stream
|
||||
self.stop_sequences = stop_sequences
|
||||
self.response_format = response_format
|
||||
|
||||
# Store advanced features (optional)
|
||||
self.guardrail_config = guardrail_config
|
||||
self.additional_model_request_fields = additional_model_request_fields
|
||||
self.additional_model_response_field_paths = (
|
||||
additional_model_response_field_paths
|
||||
)
|
||||
|
||||
# Model-specific settings
|
||||
self.is_claude_model = "claude" in model.lower()
|
||||
self.supports_tools = True # Converse API supports tools for most models
|
||||
self.supports_streaming = True
|
||||
|
||||
# Handle inference profiles for newer models
|
||||
self.model_id = model
|
||||
|
||||
def to_config_dict(self) -> dict[str, Any]:
|
||||
"""Extend base config with Bedrock-specific fields."""
|
||||
config = super().to_config_dict()
|
||||
# NOTE: AWS credentials (access_key, secret_key, session_token) are
|
||||
# intentionally excluded — they must come from env on resume.
|
||||
if self.region_name and self.region_name != "us-east-1":
|
||||
config["region_name"] = self.region_name
|
||||
if self.max_tokens is not None:
|
||||
@@ -331,6 +363,30 @@ class BedrockCompletion(BaseLLM):
|
||||
config["guardrail_config"] = self.guardrail_config
|
||||
return config
|
||||
|
||||
@property
|
||||
def stop(self) -> list[str]:
|
||||
"""Get stop sequences sent to the API."""
|
||||
return [] if self.stop_sequences is None else list(self.stop_sequences)
|
||||
|
||||
@stop.setter
|
||||
def stop(self, value: Sequence[str] | str | None) -> None:
|
||||
"""Set stop sequences.
|
||||
|
||||
Synchronizes stop_sequences to ensure values set by CrewAgentExecutor
|
||||
are properly sent to the Bedrock API.
|
||||
|
||||
Args:
|
||||
value: Stop sequences as a Sequence, single string, or None
|
||||
"""
|
||||
if value is None:
|
||||
self.stop_sequences = []
|
||||
elif isinstance(value, str):
|
||||
self.stop_sequences = [value]
|
||||
elif isinstance(value, Sequence):
|
||||
self.stop_sequences = list(value)
|
||||
else:
|
||||
self.stop_sequences = []
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: str | list[LLMMessage],
|
||||
@@ -654,7 +710,7 @@ class BedrockCompletion(BaseLLM):
|
||||
raise ValueError(f"Invalid message format at index {i}")
|
||||
|
||||
# Call Bedrock Converse API with proper error handling
|
||||
response = self._client.converse(
|
||||
response = self.client.converse(
|
||||
modelId=self.model_id,
|
||||
messages=cast(
|
||||
"Sequence[MessageTypeDef | MessageOutputTypeDef]",
|
||||
@@ -938,13 +994,13 @@ class BedrockCompletion(BaseLLM):
|
||||
accumulated_tool_input = ""
|
||||
|
||||
try:
|
||||
response = self._client.converse_stream(
|
||||
response = self.client.converse_stream(
|
||||
modelId=self.model_id,
|
||||
messages=cast(
|
||||
"Sequence[MessageTypeDef | MessageOutputTypeDef]",
|
||||
cast(object, messages),
|
||||
),
|
||||
**body,
|
||||
**body, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
stream = response.get("stream")
|
||||
|
||||
@@ -5,13 +5,12 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Literal, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from pydantic import BaseModel, Field, PrivateAttr, model_validator
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
from crewai.llms.base_llm import BaseLLM, llm_call_context
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededError,
|
||||
@@ -20,6 +19,10 @@ from crewai.utilities.pydantic_schema_utils import generate_model_description
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
|
||||
|
||||
try:
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
@@ -41,84 +44,137 @@ class GeminiCompletion(BaseLLM):
|
||||
offering native function calling, streaming support, and proper Gemini formatting.
|
||||
"""
|
||||
|
||||
model: str = "gemini-2.0-flash-001"
|
||||
project: str | None = None
|
||||
location: str | None = None
|
||||
top_p: float | None = None
|
||||
top_k: int | None = None
|
||||
max_output_tokens: int | None = None
|
||||
stream: bool = False
|
||||
safety_settings: dict[str, Any] = Field(default_factory=dict)
|
||||
client_params: dict[str, Any] = Field(default_factory=dict)
|
||||
interceptor: BaseInterceptor[Any, Any] | None = None
|
||||
use_vertexai: bool = False
|
||||
response_format: type[BaseModel] | None = None
|
||||
thinking_config: Any = None
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
supports_tools: bool = False
|
||||
is_gemini_2_0: bool = False
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "gemini-2.0-flash-001",
|
||||
api_key: str | None = None,
|
||||
project: str | None = None,
|
||||
location: str | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
top_k: int | None = None,
|
||||
max_output_tokens: int | None = None,
|
||||
stop_sequences: list[str] | None = None,
|
||||
stream: bool = False,
|
||||
safety_settings: dict[str, Any] | None = None,
|
||||
client_params: dict[str, Any] | None = None,
|
||||
interceptor: BaseInterceptor[Any, Any] | None = None,
|
||||
use_vertexai: bool | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
thinking_config: types.ThinkingConfig | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize Google Gemini chat completion client.
|
||||
|
||||
_client: Any = PrivateAttr(default=None)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _normalize_gemini_fields(cls, data: Any) -> Any:
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
|
||||
if data.get("interceptor") is not None:
|
||||
Args:
|
||||
model: Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro')
|
||||
api_key: Google API key for Gemini API authentication.
|
||||
Defaults to GOOGLE_API_KEY or GEMINI_API_KEY env var.
|
||||
NOTE: Cannot be used with Vertex AI (project parameter). Use Gemini API instead.
|
||||
project: Google Cloud project ID for Vertex AI with ADC authentication.
|
||||
Requires Application Default Credentials (gcloud auth application-default login).
|
||||
NOTE: Vertex AI does NOT support API keys, only OAuth2/ADC.
|
||||
If both api_key and project are set, api_key takes precedence.
|
||||
location: Google Cloud location (for Vertex AI with ADC, defaults to 'us-central1')
|
||||
temperature: Sampling temperature (0-2)
|
||||
top_p: Nucleus sampling parameter
|
||||
top_k: Top-k sampling parameter
|
||||
max_output_tokens: Maximum tokens in response
|
||||
stop_sequences: Stop sequences
|
||||
stream: Enable streaming responses
|
||||
safety_settings: Safety filter settings
|
||||
client_params: Additional parameters to pass to the Google Gen AI Client constructor.
|
||||
Supports parameters like http_options, credentials, debug_config, etc.
|
||||
interceptor: HTTP interceptor (not yet supported for Gemini).
|
||||
use_vertexai: Whether to use Vertex AI instead of Gemini API.
|
||||
- True: Use Vertex AI (with ADC or Express mode with API key)
|
||||
- False: Use Gemini API (explicitly override env var)
|
||||
- None (default): Check GOOGLE_GENAI_USE_VERTEXAI env var
|
||||
When using Vertex AI with API key (Express mode), http_options with
|
||||
api_version="v1" is automatically configured.
|
||||
response_format: Pydantic model for structured output. Used as default when
|
||||
response_model is not passed to call()/acall() methods.
|
||||
thinking_config: ThinkingConfig for thinking models (gemini-2.5+, gemini-3+).
|
||||
Controls thought output via include_thoughts, thinking_budget,
|
||||
and thinking_level. When None, thinking models automatically
|
||||
get include_thoughts=True so thought content is surfaced.
|
||||
**kwargs: Additional parameters
|
||||
"""
|
||||
if interceptor is not None:
|
||||
raise NotImplementedError(
|
||||
"HTTP interceptors are not yet supported for Google Gemini provider. "
|
||||
"Interceptors are currently supported for OpenAI and Anthropic providers only."
|
||||
)
|
||||
|
||||
# Normalize stop_sequences from stop kwarg
|
||||
popped = data.pop("stop_sequences", None)
|
||||
seqs = popped if popped is not None else (data.get("stop") or [])
|
||||
if isinstance(seqs, str):
|
||||
seqs = [seqs]
|
||||
data["stop"] = seqs
|
||||
|
||||
# Resolve env vars
|
||||
data["api_key"] = (
|
||||
data.get("api_key")
|
||||
or os.getenv("GOOGLE_API_KEY")
|
||||
or os.getenv("GEMINI_API_KEY")
|
||||
)
|
||||
data["project"] = data.get("project") or os.getenv("GOOGLE_CLOUD_PROJECT")
|
||||
data["location"] = (
|
||||
data.get("location") or os.getenv("GOOGLE_CLOUD_LOCATION") or "us-central1"
|
||||
super().__init__(
|
||||
model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
|
||||
)
|
||||
|
||||
use_vx = data.get("use_vertexai")
|
||||
if use_vx is None:
|
||||
use_vx = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"
|
||||
data["use_vertexai"] = use_vx
|
||||
# Store client params for later use
|
||||
self.client_params = client_params or {}
|
||||
|
||||
# Get API configuration with environment variable fallbacks
|
||||
self.api_key = (
|
||||
api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
|
||||
)
|
||||
self.project = project or os.getenv("GOOGLE_CLOUD_PROJECT")
|
||||
self.location = location or os.getenv("GOOGLE_CLOUD_LOCATION") or "us-central1"
|
||||
|
||||
if use_vertexai is None:
|
||||
use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"
|
||||
|
||||
self.client = self._initialize_client(use_vertexai)
|
||||
|
||||
# Store completion parameters
|
||||
self.top_p = top_p
|
||||
self.top_k = top_k
|
||||
self.max_output_tokens = max_output_tokens
|
||||
self.stream = stream
|
||||
self.safety_settings = safety_settings or {}
|
||||
self.stop_sequences = stop_sequences or []
|
||||
self.tools: list[dict[str, Any]] | None = None
|
||||
self.response_format = response_format
|
||||
|
||||
# Model-specific settings
|
||||
model = data.get("model", "gemini-2.0-flash-001")
|
||||
version_match = re.search(r"gemini-(\d+(?:\.\d+)?)", model.lower())
|
||||
data["supports_tools"] = bool(
|
||||
self.supports_tools = bool(
|
||||
version_match and float(version_match.group(1)) >= 1.5
|
||||
)
|
||||
data["is_gemini_2_0"] = bool(
|
||||
self.is_gemini_2_0 = bool(
|
||||
version_match and float(version_match.group(1)) >= 2.0
|
||||
)
|
||||
|
||||
# Auto-enable thinking for gemini-2.5+
|
||||
self.thinking_config = thinking_config
|
||||
if (
|
||||
data.get("thinking_config") is None
|
||||
self.thinking_config is None
|
||||
and version_match
|
||||
and float(version_match.group(1)) >= 2.5
|
||||
):
|
||||
data["thinking_config"] = types.ThinkingConfig(include_thoughts=True)
|
||||
self.thinking_config = types.ThinkingConfig(include_thoughts=True)
|
||||
|
||||
return data
|
||||
@property
|
||||
def stop(self) -> list[str]:
|
||||
"""Get stop sequences sent to the API."""
|
||||
return self.stop_sequences
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _init_client(self) -> GeminiCompletion:
|
||||
self._client = self._initialize_client(self.use_vertexai)
|
||||
return self
|
||||
@stop.setter
|
||||
def stop(self, value: list[str] | str | None) -> None:
|
||||
"""Set stop sequences.
|
||||
|
||||
Synchronizes stop_sequences to ensure values set by CrewAgentExecutor
|
||||
are properly sent to the Gemini API.
|
||||
|
||||
Args:
|
||||
value: Stop sequences as a list, single string, or None
|
||||
"""
|
||||
if value is None:
|
||||
self.stop_sequences = []
|
||||
elif isinstance(value, str):
|
||||
self.stop_sequences = [value]
|
||||
elif isinstance(value, list):
|
||||
self.stop_sequences = value
|
||||
else:
|
||||
self.stop_sequences = []
|
||||
|
||||
def to_config_dict(self) -> dict[str, Any]:
|
||||
"""Extend base config with Gemini/Vertex-specific fields."""
|
||||
@@ -227,8 +283,8 @@ class GeminiCompletion(BaseLLM):
|
||||
|
||||
if (
|
||||
hasattr(self, "client")
|
||||
and hasattr(self._client, "vertexai")
|
||||
and self._client.vertexai
|
||||
and hasattr(self.client, "vertexai")
|
||||
and self.client.vertexai
|
||||
):
|
||||
# Vertex AI configuration
|
||||
params.update(
|
||||
@@ -1096,7 +1152,7 @@ class GeminiCompletion(BaseLLM):
|
||||
try:
|
||||
# The API accepts list[Content] but mypy is overly strict about variance
|
||||
contents_for_api: Any = contents
|
||||
response = self._client.models.generate_content(
|
||||
response = self.client.models.generate_content(
|
||||
model=self.model,
|
||||
contents=contents_for_api,
|
||||
config=config,
|
||||
@@ -1136,7 +1192,7 @@ class GeminiCompletion(BaseLLM):
|
||||
|
||||
# The API accepts list[Content] but mypy is overly strict about variance
|
||||
contents_for_api: Any = contents
|
||||
for chunk in self._client.models.generate_content_stream(
|
||||
for chunk in self.client.models.generate_content_stream(
|
||||
model=self.model,
|
||||
contents=contents_for_api,
|
||||
config=config,
|
||||
@@ -1174,7 +1230,7 @@ class GeminiCompletion(BaseLLM):
|
||||
try:
|
||||
# The API accepts list[Content] but mypy is overly strict about variance
|
||||
contents_for_api: Any = contents
|
||||
response = await self._client.aio.models.generate_content(
|
||||
response = await self.client.aio.models.generate_content(
|
||||
model=self.model,
|
||||
contents=contents_for_api,
|
||||
config=config,
|
||||
@@ -1214,7 +1270,7 @@ class GeminiCompletion(BaseLLM):
|
||||
|
||||
# The API accepts list[Content] but mypy is overly strict about variance
|
||||
contents_for_api: Any = contents
|
||||
stream = await self._client.aio.models.generate_content_stream(
|
||||
stream = await self.client.aio.models.generate_content_stream(
|
||||
model=self.model,
|
||||
contents=contents_for_api,
|
||||
config=config,
|
||||
@@ -1418,6 +1474,6 @@ class GeminiCompletion(BaseLLM):
|
||||
try:
|
||||
from crewai_files.uploaders.gemini import GeminiFileUploader
|
||||
|
||||
return GeminiFileUploader(client=self._client)
|
||||
return GeminiFileUploader(client=self.client)
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
@@ -14,11 +14,10 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion import Choice
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
from openai.types.responses import Response
|
||||
from pydantic import BaseModel, PrivateAttr, model_validator
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
from crewai.llms.base_llm import BaseLLM, JsonResponseFormat, llm_call_context
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
from crewai.llms.base_llm import BaseLLM, llm_call_context
|
||||
from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport
|
||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
@@ -30,6 +29,7 @@ from crewai.utilities.types import LLMMessage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent.core import Agent
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
from crewai.task import Task
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
|
||||
@@ -183,69 +183,77 @@ class OpenAICompletion(BaseLLM):
|
||||
"computer_use": "computer_use_preview",
|
||||
}
|
||||
|
||||
model: str = "gpt-4o"
|
||||
organization: str | None = None
|
||||
project: str | None = None
|
||||
timeout: float | None = None
|
||||
max_retries: int = 2
|
||||
default_headers: dict[str, str] | None = None
|
||||
default_query: dict[str, Any] | None = None
|
||||
client_params: dict[str, Any] | None = None
|
||||
top_p: float | None = None
|
||||
frequency_penalty: float | None = None
|
||||
presence_penalty: float | None = None
|
||||
max_tokens: int | None = None
|
||||
max_completion_tokens: int | None = None
|
||||
seed: int | None = None
|
||||
stream: bool = False
|
||||
response_format: JsonResponseFormat | type[BaseModel] | None = None
|
||||
logprobs: bool | None = None
|
||||
top_logprobs: int | None = None
|
||||
reasoning_effort: str | None = None
|
||||
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None
|
||||
api: Literal["completions", "responses"] = "completions"
|
||||
instructions: str | None = None
|
||||
store: bool | None = None
|
||||
previous_response_id: str | None = None
|
||||
include: list[str] | None = None
|
||||
builtin_tools: list[str] | None = None
|
||||
parse_tool_outputs: bool = False
|
||||
auto_chain: bool = False
|
||||
auto_chain_reasoning: bool = False
|
||||
api_base: str | None = None
|
||||
is_o1_model: bool = False
|
||||
is_gpt4_model: bool = False
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "gpt-4o",
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
organization: str | None = None,
|
||||
project: str | None = None,
|
||||
timeout: float | None = None,
|
||||
max_retries: int = 2,
|
||||
default_headers: dict[str, str] | None = None,
|
||||
default_query: dict[str, Any] | None = None,
|
||||
client_params: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
seed: int | None = None,
|
||||
stream: bool = False,
|
||||
response_format: dict[str, Any] | type[BaseModel] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
reasoning_effort: str | None = None,
|
||||
provider: str | None = None,
|
||||
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
|
||||
api: Literal["completions", "responses"] = "completions",
|
||||
instructions: str | None = None,
|
||||
store: bool | None = None,
|
||||
previous_response_id: str | None = None,
|
||||
include: list[str] | None = None,
|
||||
builtin_tools: list[str] | None = None,
|
||||
parse_tool_outputs: bool = False,
|
||||
auto_chain: bool = False,
|
||||
auto_chain_reasoning: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize OpenAI completion client."""
|
||||
|
||||
_client: Any = PrivateAttr(default=None)
|
||||
_async_client: Any = PrivateAttr(default=None)
|
||||
_last_response_id: str | None = PrivateAttr(default=None)
|
||||
_last_reasoning_items: list[Any] | None = PrivateAttr(default=None)
|
||||
if provider is None:
|
||||
provider = kwargs.pop("provider", "openai")
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _normalize_openai_fields(cls, data: Any) -> Any:
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
if not data.get("provider"):
|
||||
data["provider"] = "openai"
|
||||
data["api_key"] = data.get("api_key") or os.getenv("OPENAI_API_KEY")
|
||||
# Extract api_base from kwargs if present
|
||||
if "api_base" not in data:
|
||||
data["api_base"] = None
|
||||
model = data.get("model", "gpt-4o")
|
||||
data["is_o1_model"] = "o1" in model.lower()
|
||||
data["is_gpt4_model"] = "gpt-4" in model.lower()
|
||||
return data
|
||||
self.interceptor = interceptor
|
||||
# Client configuration attributes
|
||||
self.organization = organization
|
||||
self.project = project
|
||||
self.max_retries = max_retries
|
||||
self.default_headers = default_headers
|
||||
self.default_query = default_query
|
||||
self.client_params = client_params
|
||||
self.timeout = timeout
|
||||
self.base_url = base_url
|
||||
self.api_base = kwargs.pop("api_base", None)
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
api_key=api_key or os.getenv("OPENAI_API_KEY"),
|
||||
base_url=base_url,
|
||||
timeout=timeout,
|
||||
provider=provider,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _init_clients(self) -> OpenAICompletion:
|
||||
client_config = self._get_client_params()
|
||||
if self.interceptor:
|
||||
transport = HTTPTransport(interceptor=self.interceptor)
|
||||
http_client = httpx.Client(transport=transport)
|
||||
client_config["http_client"] = http_client
|
||||
|
||||
self._client = OpenAI(**client_config)
|
||||
self.client = OpenAI(**client_config)
|
||||
|
||||
async_client_config = self._get_client_params()
|
||||
if self.interceptor:
|
||||
@@ -253,8 +261,35 @@ class OpenAICompletion(BaseLLM):
|
||||
async_http_client = httpx.AsyncClient(transport=async_transport)
|
||||
async_client_config["http_client"] = async_http_client
|
||||
|
||||
self._async_client = AsyncOpenAI(**async_client_config)
|
||||
return self
|
||||
self.async_client = AsyncOpenAI(**async_client_config)
|
||||
|
||||
# Completion parameters
|
||||
self.top_p = top_p
|
||||
self.frequency_penalty = frequency_penalty
|
||||
self.presence_penalty = presence_penalty
|
||||
self.max_tokens = max_tokens
|
||||
self.max_completion_tokens = max_completion_tokens
|
||||
self.seed = seed
|
||||
self.stream = stream
|
||||
self.response_format = response_format
|
||||
self.logprobs = logprobs
|
||||
self.top_logprobs = top_logprobs
|
||||
self.reasoning_effort = reasoning_effort
|
||||
self.is_o1_model = "o1" in model.lower()
|
||||
self.is_gpt4_model = "gpt-4" in model.lower()
|
||||
|
||||
# API selection and Responses API parameters
|
||||
self.api = api
|
||||
self.instructions = instructions
|
||||
self.store = store
|
||||
self.previous_response_id = previous_response_id
|
||||
self.include = include
|
||||
self.builtin_tools = builtin_tools
|
||||
self.parse_tool_outputs = parse_tool_outputs
|
||||
self.auto_chain = auto_chain
|
||||
self.auto_chain_reasoning = auto_chain_reasoning
|
||||
self._last_response_id: str | None = None
|
||||
self._last_reasoning_items: list[Any] | None = None
|
||||
|
||||
@property
|
||||
def last_response_id(self) -> str | None:
|
||||
@@ -783,7 +818,7 @@ class OpenAICompletion(BaseLLM):
|
||||
) -> str | ResponsesAPIResult | Any:
|
||||
"""Handle non-streaming Responses API call."""
|
||||
try:
|
||||
response: Response = self._client.responses.create(**params)
|
||||
response: Response = self.client.responses.create(**params)
|
||||
|
||||
# Track response ID for auto-chaining
|
||||
if self.auto_chain and response.id:
|
||||
@@ -915,7 +950,7 @@ class OpenAICompletion(BaseLLM):
|
||||
) -> str | ResponsesAPIResult | Any:
|
||||
"""Handle async non-streaming Responses API call."""
|
||||
try:
|
||||
response: Response = await self._async_client.responses.create(**params)
|
||||
response: Response = await self.async_client.responses.create(**params)
|
||||
|
||||
# Track response ID for auto-chaining
|
||||
if self.auto_chain and response.id:
|
||||
@@ -1046,7 +1081,7 @@ class OpenAICompletion(BaseLLM):
|
||||
function_calls: list[dict[str, Any]] = []
|
||||
final_response: Response | None = None
|
||||
|
||||
stream = self._client.responses.create(**params)
|
||||
stream = self.client.responses.create(**params)
|
||||
response_id_stream = None
|
||||
|
||||
for event in stream:
|
||||
@@ -1170,7 +1205,7 @@ class OpenAICompletion(BaseLLM):
|
||||
function_calls: list[dict[str, Any]] = []
|
||||
final_response: Response | None = None
|
||||
|
||||
stream = await self._async_client.responses.create(**params)
|
||||
stream = await self.async_client.responses.create(**params)
|
||||
response_id_stream = None
|
||||
|
||||
async for event in stream:
|
||||
@@ -1560,7 +1595,7 @@ class OpenAICompletion(BaseLLM):
|
||||
parse_params = {
|
||||
k: v for k, v in params.items() if k != "response_format"
|
||||
}
|
||||
parsed_response = self._client.beta.chat.completions.parse(
|
||||
parsed_response = self.client.beta.chat.completions.parse(
|
||||
**parse_params,
|
||||
response_format=response_model,
|
||||
)
|
||||
@@ -1583,7 +1618,7 @@ class OpenAICompletion(BaseLLM):
|
||||
)
|
||||
return parsed_object
|
||||
|
||||
response: ChatCompletion = self._client.chat.completions.create(**params)
|
||||
response: ChatCompletion = self.client.chat.completions.create(**params)
|
||||
|
||||
usage = self._extract_openai_token_usage(response)
|
||||
|
||||
@@ -1802,7 +1837,7 @@ class OpenAICompletion(BaseLLM):
|
||||
}
|
||||
|
||||
stream: ChatCompletionStream[BaseModel]
|
||||
with self._client.beta.chat.completions.stream(
|
||||
with self.client.beta.chat.completions.stream(
|
||||
**parse_params, response_format=response_model
|
||||
) as stream:
|
||||
for chunk in stream:
|
||||
@@ -1838,7 +1873,7 @@ class OpenAICompletion(BaseLLM):
|
||||
return ""
|
||||
|
||||
completion_stream: Stream[ChatCompletionChunk] = (
|
||||
self._client.chat.completions.create(**params)
|
||||
self.client.chat.completions.create(**params)
|
||||
)
|
||||
|
||||
usage_data = {"total_tokens": 0}
|
||||
@@ -1935,7 +1970,7 @@ class OpenAICompletion(BaseLLM):
|
||||
parse_params = {
|
||||
k: v for k, v in params.items() if k != "response_format"
|
||||
}
|
||||
parsed_response = await self._async_client.beta.chat.completions.parse(
|
||||
parsed_response = await self.async_client.beta.chat.completions.parse(
|
||||
**parse_params,
|
||||
response_format=response_model,
|
||||
)
|
||||
@@ -1958,7 +1993,7 @@ class OpenAICompletion(BaseLLM):
|
||||
)
|
||||
return parsed_object
|
||||
|
||||
response: ChatCompletion = await self._async_client.chat.completions.create(
|
||||
response: ChatCompletion = await self.async_client.chat.completions.create(
|
||||
**params
|
||||
)
|
||||
|
||||
@@ -2076,7 +2111,7 @@ class OpenAICompletion(BaseLLM):
|
||||
if response_model:
|
||||
completion_stream: AsyncIterator[
|
||||
ChatCompletionChunk
|
||||
] = await self._async_client.chat.completions.create(**params)
|
||||
] = await self.async_client.chat.completions.create(**params)
|
||||
|
||||
accumulated_content = ""
|
||||
usage_data = {"total_tokens": 0}
|
||||
@@ -2129,7 +2164,7 @@ class OpenAICompletion(BaseLLM):
|
||||
|
||||
stream: AsyncIterator[
|
||||
ChatCompletionChunk
|
||||
] = await self._async_client.chat.completions.create(**params)
|
||||
] = await self.async_client.chat.completions.create(**params)
|
||||
|
||||
usage_data = {"total_tokens": 0}
|
||||
|
||||
@@ -2321,8 +2356,8 @@ class OpenAICompletion(BaseLLM):
|
||||
from crewai_files.uploaders.openai import OpenAIFileUploader
|
||||
|
||||
return OpenAIFileUploader(
|
||||
client=self._client,
|
||||
async_client=self._async_client,
|
||||
client=self.client,
|
||||
async_client=self.async_client,
|
||||
)
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
@@ -16,8 +16,6 @@ from dataclasses import dataclass, field
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import model_validator
|
||||
|
||||
from crewai.llms.providers.openai.completion import OpenAICompletion
|
||||
|
||||
|
||||
@@ -142,13 +140,31 @@ class OpenAICompatibleCompletion(OpenAICompletion):
|
||||
)
|
||||
"""
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _resolve_provider_config(cls, data: Any) -> Any:
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
provider: str,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
default_headers: dict[str, str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize OpenAI-compatible completion client.
|
||||
|
||||
provider = data.get("provider", "")
|
||||
Args:
|
||||
model: The model identifier.
|
||||
provider: The provider name (must be in OPENAI_COMPATIBLE_PROVIDERS).
|
||||
api_key: Optional API key override. If not provided, uses the
|
||||
provider's configured environment variable.
|
||||
base_url: Optional base URL override. If not provided, uses the
|
||||
provider's configured default or environment variable.
|
||||
default_headers: Optional headers to merge with provider defaults.
|
||||
**kwargs: Additional arguments passed to OpenAICompletion.
|
||||
|
||||
Raises:
|
||||
ValueError: If the provider is not supported or required API key
|
||||
is missing.
|
||||
"""
|
||||
config = OPENAI_COMPATIBLE_PROVIDERS.get(provider)
|
||||
if config is None:
|
||||
supported = ", ".join(sorted(OPENAI_COMPATIBLE_PROVIDERS.keys()))
|
||||
@@ -157,15 +173,21 @@ class OpenAICompatibleCompletion(OpenAICompletion):
|
||||
f"Supported providers: {supported}"
|
||||
)
|
||||
|
||||
data["api_key"] = cls._resolve_api_key(data.get("api_key"), config, provider)
|
||||
data["base_url"] = cls._resolve_base_url(data.get("base_url"), config, provider)
|
||||
data["default_headers"] = cls._resolve_headers(
|
||||
data.get("default_headers"), config
|
||||
)
|
||||
return data
|
||||
resolved_api_key = self._resolve_api_key(api_key, config, provider)
|
||||
resolved_base_url = self._resolve_base_url(base_url, config, provider)
|
||||
resolved_headers = self._resolve_headers(default_headers, config)
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
provider=provider,
|
||||
api_key=resolved_api_key,
|
||||
base_url=resolved_base_url,
|
||||
default_headers=resolved_headers,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_api_key(
|
||||
self,
|
||||
api_key: str | None,
|
||||
config: ProviderConfig,
|
||||
provider: str,
|
||||
@@ -198,8 +220,8 @@ class OpenAICompatibleCompletion(OpenAICompletion):
|
||||
|
||||
return config.default_api_key
|
||||
|
||||
@staticmethod
|
||||
def _resolve_base_url(
|
||||
self,
|
||||
base_url: str | None,
|
||||
config: ProviderConfig,
|
||||
provider: str,
|
||||
@@ -227,8 +249,8 @@ class OpenAICompatibleCompletion(OpenAICompletion):
|
||||
|
||||
return resolved
|
||||
|
||||
@staticmethod
|
||||
def _resolve_headers(
|
||||
self,
|
||||
headers: dict[str, str] | None,
|
||||
config: ProviderConfig,
|
||||
) -> dict[str, str] | None:
|
||||
|
||||
1
lib/crewai/src/crewai/llms/third_party/__init__.py
vendored
Normal file
1
lib/crewai/src/crewai/llms/third_party/__init__.py
vendored
Normal file
@@ -0,0 +1 @@
|
||||
"""Third-party LLM implementations for crewAI."""
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, InstanceOf
|
||||
from rich.box import HEAVY_EDGE
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
@@ -39,9 +39,9 @@ class CrewEvaluator:
|
||||
def __init__(
|
||||
self,
|
||||
crew: Crew,
|
||||
eval_llm: BaseLLM | str | None = None,
|
||||
eval_llm: InstanceOf[BaseLLM] | str | None = None,
|
||||
openai_model_name: str | None = None,
|
||||
llm: BaseLLM | str | None = None,
|
||||
llm: InstanceOf[BaseLLM] | str | None = None,
|
||||
) -> None:
|
||||
self.crew = crew
|
||||
self.llm = eval_llm
|
||||
|
||||
@@ -103,6 +103,28 @@ def to_serializable(
|
||||
}
|
||||
except Exception:
|
||||
return repr(obj)
|
||||
|
||||
# Callables (functions, methods, lambdas) should fall through to repr
|
||||
if callable(obj):
|
||||
return repr(obj)
|
||||
|
||||
# Handle regular classes with __dict__ (non-Pydantic)
|
||||
# Note: Don't propagate exclude to recursive calls, matching Pydantic fallback behavior
|
||||
if hasattr(obj, "__dict__"):
|
||||
try:
|
||||
return {
|
||||
_to_serializable_key(k): to_serializable(
|
||||
v,
|
||||
max_depth=max_depth,
|
||||
_current_depth=_current_depth + 1,
|
||||
_ancestors=new_ancestors,
|
||||
)
|
||||
for k, v in obj.__dict__.items()
|
||||
if k not in exclude and not k.startswith("_")
|
||||
}
|
||||
except Exception:
|
||||
return repr(obj)
|
||||
|
||||
return repr(obj)
|
||||
|
||||
|
||||
|
||||
@@ -1692,27 +1692,9 @@ def test_agent_with_knowledge_sources_works_with_copy():
|
||||
) as mock_knowledge_storage:
|
||||
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
||||
|
||||
class _StubStorage(BaseKnowledgeStorage):
|
||||
def search(self, query, limit=5, metadata_filter=None, score_threshold=0.6):
|
||||
return []
|
||||
|
||||
async def asearch(self, query, limit=5, metadata_filter=None, score_threshold=0.6):
|
||||
return []
|
||||
|
||||
def save(self, documents):
|
||||
pass
|
||||
|
||||
async def asave(self, documents):
|
||||
pass
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
async def areset(self):
|
||||
pass
|
||||
|
||||
mock_knowledge_storage.return_value = _StubStorage()
|
||||
agent.knowledge_storage = _StubStorage()
|
||||
mock_knowledge_storage_instance = mock_knowledge_storage.return_value
|
||||
mock_knowledge_storage_instance.__class__ = BaseKnowledgeStorage
|
||||
agent.knowledge_storage = mock_knowledge_storage_instance
|
||||
|
||||
agent_copy = agent.copy()
|
||||
|
||||
|
||||
@@ -879,6 +879,30 @@ class TestNativeToolExecution:
|
||||
assert len(tool_messages) == 1
|
||||
assert tool_messages[0]["tool_call_id"] == "call_1"
|
||||
|
||||
def test_check_native_todo_completion_requires_current_todo(
|
||||
self, mock_dependencies
|
||||
):
|
||||
from crewai.utilities.planning_types import TodoList
|
||||
|
||||
executor = AgentExecutor(**mock_dependencies)
|
||||
|
||||
# No current todo → not satisfied
|
||||
executor.state.todos = TodoList(items=[])
|
||||
assert executor.check_native_todo_completion() == "todo_not_satisfied"
|
||||
|
||||
# With a current todo that has tool_to_use → satisfied
|
||||
running = TodoItem(
|
||||
step_number=1,
|
||||
description="Use the expected tool",
|
||||
tool_to_use="expected_tool",
|
||||
status="running",
|
||||
)
|
||||
executor.state.todos = TodoList(items=[running])
|
||||
assert executor.check_native_todo_completion() == "todo_satisfied"
|
||||
|
||||
# With a current todo without tool_to_use → still satisfied
|
||||
running.tool_to_use = None
|
||||
assert executor.check_native_todo_completion() == "todo_satisfied"
|
||||
|
||||
|
||||
class TestPlannerObserver:
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"input":[{"role":"user","content":"What is the weather in Tokyo?"}],"model":"gpt-4.1","instructions":"You
|
||||
are a helpful assistant that uses tools. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
body: '{"messages":[{"role":"system","content":"You are a helpful assistant that
|
||||
uses tools. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
@@ -72,9 +68,13 @@ interactions:
|
||||
for caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. ","tools":[{"type":"function","name":"get_weather","description":"Get
|
||||
the current weather for a location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The
|
||||
city name"}},"required":["location"]}}]}'
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. "},{"role":"user","content":"What is the weather in Tokyo?"}],"model":"gpt-4.1","tool_choice":"auto","tools":[{"type":"function","function":{"name":"get_weather","description":"Get
|
||||
the current weather for a location","strict":true,"parameters":{"type":"object","properties":{"location":{"type":"string","description":"The
|
||||
city name"}},"required":["location"],"additionalProperties":false}}}]}'
|
||||
headers:
|
||||
User-Agent:
|
||||
- X-USER-AGENT-XXX
|
||||
@@ -87,7 +87,7 @@ interactions:
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '6065'
|
||||
- '6158'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
@@ -109,113 +109,26 @@ interactions:
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.13.12
|
||||
- 3.13.3
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/responses
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: "{\n \"id\": \"resp_0d68149bcc0d14810069caf464a4b48197bd9f098abb2f6303\",\n
|
||||
\ \"object\": \"response\",\n \"created_at\": 1774908516,\n \"status\":
|
||||
\"completed\",\n \"background\": false,\n \"billing\": {\n \"payer\":
|
||||
\"developer\"\n },\n \"completed_at\": 1774908517,\n \"error\": null,\n
|
||||
\ \"frequency_penalty\": 0.0,\n \"incomplete_details\": null,\n \"instructions\":
|
||||
\"You are a helpful assistant that uses tools. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is
|
||||
padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to
|
||||
ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is
|
||||
padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to
|
||||
ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is
|
||||
padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to
|
||||
ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is
|
||||
padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to
|
||||
ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is
|
||||
padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to
|
||||
ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. \",\n \"max_output_tokens\":
|
||||
null,\n \"max_tool_calls\": null,\n \"model\": \"gpt-4.1-2025-04-14\",\n
|
||||
\ \"output\": [\n {\n \"id\": \"fc_0d68149bcc0d14810069caf46568088197a33be67f16a1fa09\",\n
|
||||
\ \"type\": \"function_call\",\n \"status\": \"completed\",\n \"arguments\":
|
||||
\"{\\\"location\\\":\\\"Tokyo\\\"}\",\n \"call_id\": \"call_74rwmYse0DE4JFaFGyAFx9bu\",\n
|
||||
\ \"name\": \"get_weather\"\n }\n ],\n \"parallel_tool_calls\": true,\n
|
||||
\ \"presence_penalty\": 0.0,\n \"previous_response_id\": null,\n \"prompt_cache_key\":
|
||||
null,\n \"prompt_cache_retention\": null,\n \"reasoning\": {\n \"effort\":
|
||||
null,\n \"summary\": null\n },\n \"safety_identifier\": null,\n \"service_tier\":
|
||||
\"default\",\n \"store\": true,\n \"temperature\": 1.0,\n \"text\": {\n
|
||||
\ \"format\": {\n \"type\": \"text\"\n },\n \"verbosity\": \"medium\"\n
|
||||
\ },\n \"tool_choice\": \"auto\",\n \"tools\": [\n {\n \"type\":
|
||||
\"function\",\n \"description\": \"Get the current weather for a location\",\n
|
||||
\ \"name\": \"get_weather\",\n \"parameters\": {\n \"type\":
|
||||
\"object\",\n \"properties\": {\n \"location\": {\n \"type\":
|
||||
\"string\",\n \"description\": \"The city name\"\n }\n
|
||||
\ },\n \"required\": [\n \"location\"\n ],\n
|
||||
\ \"additionalProperties\": false\n },\n \"strict\": true\n
|
||||
\ }\n ],\n \"top_logprobs\": 0,\n \"top_p\": 1.0,\n \"truncation\":
|
||||
\"disabled\",\n \"usage\": {\n \"input_tokens\": 1185,\n \"input_tokens_details\":
|
||||
{\n \"cached_tokens\": 0\n },\n \"output_tokens\": 15,\n \"output_tokens_details\":
|
||||
{\n \"reasoning_tokens\": 0\n },\n \"total_tokens\": 1200\n },\n
|
||||
\ \"user\": null,\n \"metadata\": {}\n}"
|
||||
string: "{\n \"id\": \"chatcmpl-D7mXQCgT3p3ViImkiqDiZGqLREQtp\",\n \"object\":
|
||||
\"chat.completion\",\n \"created\": 1770747248,\n \"model\": \"gpt-4.1-2025-04-14\",\n
|
||||
\ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
|
||||
\"assistant\",\n \"content\": null,\n \"tool_calls\": [\n {\n
|
||||
\ \"id\": \"call_9ZqMavn3J1fBnQEaqpYol0Bd\",\n \"type\":
|
||||
\"function\",\n \"function\": {\n \"name\": \"get_weather\",\n
|
||||
\ \"arguments\": \"{\\\"location\\\":\\\"Tokyo\\\"}\"\n }\n
|
||||
\ }\n ],\n \"refusal\": null,\n \"annotations\":
|
||||
[]\n },\n \"logprobs\": null,\n \"finish_reason\": \"tool_calls\"\n
|
||||
\ }\n ],\n \"usage\": {\n \"prompt_tokens\": 1187,\n \"completion_tokens\":
|
||||
14,\n \"total_tokens\": 1201,\n \"prompt_tokens_details\": {\n \"cached_tokens\":
|
||||
1152,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\":
|
||||
{\n \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\":
|
||||
0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"service_tier\":
|
||||
\"default\",\n \"system_fingerprint\": \"fp_8b22347a3e\"\n}\n"
|
||||
headers:
|
||||
CF-RAY:
|
||||
- CF-RAY-XXX
|
||||
@@ -224,7 +137,7 @@ interactions:
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Mon, 30 Mar 2026 22:08:37 GMT
|
||||
- Tue, 10 Feb 2026 18:14:08 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Strict-Transport-Security:
|
||||
@@ -233,6 +146,8 @@ interactions:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- X-CONTENT-TYPE-XXX
|
||||
access-control-expose-headers:
|
||||
- ACCESS-CONTROL-XXX
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
@@ -240,13 +155,15 @@ interactions:
|
||||
openai-organization:
|
||||
- OPENAI-ORG-XXX
|
||||
openai-processing-ms:
|
||||
- '1085'
|
||||
- '484'
|
||||
openai-project:
|
||||
- OPENAI-PROJECT-XXX
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
set-cookie:
|
||||
- SET-COOKIE-XXX
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-requests:
|
||||
- X-RATELIMIT-LIMIT-REQUESTS-XXX
|
||||
x-ratelimit-limit-tokens:
|
||||
@@ -265,12 +182,8 @@ interactions:
|
||||
code: 200
|
||||
message: OK
|
||||
- request:
|
||||
body: '{"input":[{"role":"user","content":"What is the weather in Paris?"}],"model":"gpt-4.1","instructions":"You
|
||||
are a helpful assistant that uses tools. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
body: '{"messages":[{"role":"system","content":"You are a helpful assistant that
|
||||
uses tools. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
@@ -337,9 +250,13 @@ interactions:
|
||||
for caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. ","tools":[{"type":"function","name":"get_weather","description":"Get
|
||||
the current weather for a location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The
|
||||
city name"}},"required":["location"]}}]}'
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. "},{"role":"user","content":"What is the weather in Paris?"}],"model":"gpt-4.1","tool_choice":"auto","tools":[{"type":"function","function":{"name":"get_weather","description":"Get
|
||||
the current weather for a location","strict":true,"parameters":{"type":"object","properties":{"location":{"type":"string","description":"The
|
||||
city name"}},"required":["location"],"additionalProperties":false}}}]}'
|
||||
headers:
|
||||
User-Agent:
|
||||
- X-USER-AGENT-XXX
|
||||
@@ -352,7 +269,7 @@ interactions:
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '6065'
|
||||
- '6158'
|
||||
content-type:
|
||||
- application/json
|
||||
cookie:
|
||||
@@ -376,113 +293,26 @@ interactions:
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.13.12
|
||||
- 3.13.3
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/responses
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: "{\n \"id\": \"resp_0525bf798202137e0069caf465ee3c8196aa7c83da1c369eb7\",\n
|
||||
\ \"object\": \"response\",\n \"created_at\": 1774908517,\n \"status\":
|
||||
\"completed\",\n \"background\": false,\n \"billing\": {\n \"payer\":
|
||||
\"developer\"\n },\n \"completed_at\": 1774908518,\n \"error\": null,\n
|
||||
\ \"frequency_penalty\": 0.0,\n \"incomplete_details\": null,\n \"instructions\":
|
||||
\"You are a helpful assistant that uses tools. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is
|
||||
padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to
|
||||
ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is
|
||||
padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to
|
||||
ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is
|
||||
padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to
|
||||
ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is
|
||||
padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to
|
||||
ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is
|
||||
padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to
|
||||
ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. \",\n \"max_output_tokens\":
|
||||
null,\n \"max_tool_calls\": null,\n \"model\": \"gpt-4.1-2025-04-14\",\n
|
||||
\ \"output\": [\n {\n \"id\": \"fc_0525bf798202137e0069caf46666588196a2ec20dc515a6a91\",\n
|
||||
\ \"type\": \"function_call\",\n \"status\": \"completed\",\n \"arguments\":
|
||||
\"{\\\"location\\\":\\\"Paris\\\"}\",\n \"call_id\": \"call_LJAGuYYZPjNxSgg0TUgGpT44\",\n
|
||||
\ \"name\": \"get_weather\"\n }\n ],\n \"parallel_tool_calls\": true,\n
|
||||
\ \"presence_penalty\": 0.0,\n \"previous_response_id\": null,\n \"prompt_cache_key\":
|
||||
null,\n \"prompt_cache_retention\": null,\n \"reasoning\": {\n \"effort\":
|
||||
null,\n \"summary\": null\n },\n \"safety_identifier\": null,\n \"service_tier\":
|
||||
\"default\",\n \"store\": true,\n \"temperature\": 1.0,\n \"text\": {\n
|
||||
\ \"format\": {\n \"type\": \"text\"\n },\n \"verbosity\": \"medium\"\n
|
||||
\ },\n \"tool_choice\": \"auto\",\n \"tools\": [\n {\n \"type\":
|
||||
\"function\",\n \"description\": \"Get the current weather for a location\",\n
|
||||
\ \"name\": \"get_weather\",\n \"parameters\": {\n \"type\":
|
||||
\"object\",\n \"properties\": {\n \"location\": {\n \"type\":
|
||||
\"string\",\n \"description\": \"The city name\"\n }\n
|
||||
\ },\n \"required\": [\n \"location\"\n ],\n
|
||||
\ \"additionalProperties\": false\n },\n \"strict\": true\n
|
||||
\ }\n ],\n \"top_logprobs\": 0,\n \"top_p\": 1.0,\n \"truncation\":
|
||||
\"disabled\",\n \"usage\": {\n \"input_tokens\": 1185,\n \"input_tokens_details\":
|
||||
{\n \"cached_tokens\": 1152\n },\n \"output_tokens\": 15,\n \"output_tokens_details\":
|
||||
{\n \"reasoning_tokens\": 0\n },\n \"total_tokens\": 1200\n },\n
|
||||
\ \"user\": null,\n \"metadata\": {}\n}"
|
||||
string: "{\n \"id\": \"chatcmpl-D7mXR8k9vk8TlGvGXlrQSI7iNeAN1\",\n \"object\":
|
||||
\"chat.completion\",\n \"created\": 1770747249,\n \"model\": \"gpt-4.1-2025-04-14\",\n
|
||||
\ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
|
||||
\"assistant\",\n \"content\": null,\n \"tool_calls\": [\n {\n
|
||||
\ \"id\": \"call_6PeUBlRPG8JcV2lspmLjJbnn\",\n \"type\":
|
||||
\"function\",\n \"function\": {\n \"name\": \"get_weather\",\n
|
||||
\ \"arguments\": \"{\\\"location\\\":\\\"Paris\\\"}\"\n }\n
|
||||
\ }\n ],\n \"refusal\": null,\n \"annotations\":
|
||||
[]\n },\n \"logprobs\": null,\n \"finish_reason\": \"tool_calls\"\n
|
||||
\ }\n ],\n \"usage\": {\n \"prompt_tokens\": 1187,\n \"completion_tokens\":
|
||||
14,\n \"total_tokens\": 1201,\n \"prompt_tokens_details\": {\n \"cached_tokens\":
|
||||
1152,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\":
|
||||
{\n \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\":
|
||||
0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"service_tier\":
|
||||
\"default\",\n \"system_fingerprint\": \"fp_8b22347a3e\"\n}\n"
|
||||
headers:
|
||||
CF-RAY:
|
||||
- CF-RAY-XXX
|
||||
@@ -491,7 +321,7 @@ interactions:
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Mon, 30 Mar 2026 22:08:38 GMT
|
||||
- Tue, 10 Feb 2026 18:14:09 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Strict-Transport-Security:
|
||||
@@ -500,6 +330,8 @@ interactions:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- X-CONTENT-TYPE-XXX
|
||||
access-control-expose-headers:
|
||||
- ACCESS-CONTROL-XXX
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
@@ -507,11 +339,15 @@ interactions:
|
||||
openai-organization:
|
||||
- OPENAI-ORG-XXX
|
||||
openai-processing-ms:
|
||||
- '653'
|
||||
- '528'
|
||||
openai-project:
|
||||
- OPENAI-PROJECT-XXX
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
set-cookie:
|
||||
- SET-COOKIE-XXX
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-requests:
|
||||
- X-RATELIMIT-LIMIT-REQUESTS-XXX
|
||||
x-ratelimit-limit-tokens:
|
||||
|
||||
@@ -132,12 +132,12 @@ def test_embedding_configuration_flow(
|
||||
|
||||
embedder_config = {
|
||||
"provider": "sentence-transformer",
|
||||
"config": {"model_name": "all-MiniLM-L6-v2"},
|
||||
"model_name": "all-MiniLM-L6-v2",
|
||||
}
|
||||
|
||||
storage = KnowledgeStorage(embedder=embedder_config, collection_name="embedding_test")
|
||||
KnowledgeStorage(embedder=embedder_config, collection_name="embedding_test")
|
||||
|
||||
mock_get_embedding.assert_called_once_with(storage.embedder)
|
||||
mock_get_embedding.assert_called_once_with(embedder_config)
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
|
||||
|
||||
@@ -125,8 +125,8 @@ def test_anthropic_specific_parameters():
|
||||
assert isinstance(llm, AnthropicCompletion)
|
||||
assert llm.stop_sequences == ["Human:", "Assistant:"]
|
||||
assert llm.stream == True
|
||||
assert llm._client.max_retries == 5
|
||||
assert llm._client.timeout == 60
|
||||
assert llm.client.max_retries == 5
|
||||
assert llm.client.timeout == 60
|
||||
|
||||
|
||||
def test_anthropic_completion_call():
|
||||
@@ -563,8 +563,8 @@ def test_anthropic_environment_variable_api_key():
|
||||
with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-anthropic-key"}):
|
||||
llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
|
||||
|
||||
assert llm._client is not None
|
||||
assert hasattr(llm._client, 'messages')
|
||||
assert llm.client is not None
|
||||
assert hasattr(llm.client, 'messages')
|
||||
|
||||
|
||||
def test_anthropic_token_usage_tracking():
|
||||
@@ -574,7 +574,7 @@ def test_anthropic_token_usage_tracking():
|
||||
llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
|
||||
|
||||
# Mock the Anthropic response with usage information
|
||||
with patch.object(llm._client.messages, 'create') as mock_create:
|
||||
with patch.object(llm.client.messages, 'create') as mock_create:
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock(text="test response")]
|
||||
mock_response.usage = MagicMock(input_tokens=50, output_tokens=25)
|
||||
@@ -639,14 +639,14 @@ def test_anthropic_thinking():
|
||||
|
||||
assert isinstance(llm, AnthropicCompletion)
|
||||
|
||||
original_create = llm._client.messages.create
|
||||
original_create = llm.client.messages.create
|
||||
captured_params = {}
|
||||
|
||||
def capture_and_call(**kwargs):
|
||||
captured_params.update(kwargs)
|
||||
return original_create(**kwargs)
|
||||
|
||||
with patch.object(llm._client.messages, 'create', side_effect=capture_and_call):
|
||||
with patch.object(llm.client.messages, 'create', side_effect=capture_and_call):
|
||||
result = llm.call("What is the weather in Tokyo?")
|
||||
|
||||
assert result is not None
|
||||
@@ -677,14 +677,14 @@ def test_anthropic_thinking_blocks_preserved_across_turns():
|
||||
assert isinstance(llm, AnthropicCompletion)
|
||||
|
||||
# Capture all messages.create calls to verify thinking blocks are included
|
||||
original_create = llm._client.messages.create
|
||||
original_create = llm.client.messages.create
|
||||
captured_calls = []
|
||||
|
||||
def capture_and_call(**kwargs):
|
||||
captured_calls.append(kwargs)
|
||||
return original_create(**kwargs)
|
||||
|
||||
with patch.object(llm._client.messages, 'create', side_effect=capture_and_call):
|
||||
with patch.object(llm.client.messages, 'create', side_effect=capture_and_call):
|
||||
# First call - establishes context and generates thinking blocks
|
||||
messages = [{"role": "user", "content": "What is 2+2?"}]
|
||||
first_result = llm.call(messages)
|
||||
@@ -695,8 +695,8 @@ def test_anthropic_thinking_blocks_preserved_across_turns():
|
||||
assert len(first_result) > 0
|
||||
|
||||
# Verify thinking blocks were stored after first response
|
||||
assert len(llm._previous_thinking_blocks) > 0, "No thinking blocks stored after first call"
|
||||
first_thinking = llm._previous_thinking_blocks[0]
|
||||
assert len(llm.previous_thinking_blocks) > 0, "No thinking blocks stored after first call"
|
||||
first_thinking = llm.previous_thinking_blocks[0]
|
||||
assert first_thinking["type"] == "thinking"
|
||||
assert "thinking" in first_thinking
|
||||
assert "signature" in first_thinking
|
||||
|
||||
@@ -66,7 +66,7 @@ def test_azure_tool_use_conversation_flow():
|
||||
available_functions = {"get_weather": mock_weather_tool}
|
||||
|
||||
# Mock the Azure client responses
|
||||
with patch.object(completion._client, 'complete') as mock_complete:
|
||||
with patch.object(completion.client, 'complete') as mock_complete:
|
||||
# Mock tool call in response with proper type
|
||||
mock_tool_call = MagicMock(spec=ChatCompletionsToolCall)
|
||||
mock_tool_call.function.name = "get_weather"
|
||||
@@ -698,7 +698,7 @@ def test_azure_environment_variable_endpoint():
|
||||
}):
|
||||
llm = LLM(model="azure/gpt-4")
|
||||
|
||||
assert llm._client is not None
|
||||
assert llm.client is not None
|
||||
assert llm.endpoint == "https://test.openai.azure.com/openai/deployments/gpt-4"
|
||||
|
||||
|
||||
@@ -709,7 +709,7 @@ def test_azure_token_usage_tracking():
|
||||
llm = LLM(model="azure/gpt-4")
|
||||
|
||||
# Mock the Azure response with usage information
|
||||
with patch.object(llm._client, 'complete') as mock_complete:
|
||||
with patch.object(llm.client, 'complete') as mock_complete:
|
||||
mock_message = MagicMock()
|
||||
mock_message.content = "test response"
|
||||
mock_message.tool_calls = None
|
||||
@@ -747,7 +747,7 @@ def test_azure_http_error_handling():
|
||||
llm = LLM(model="azure/gpt-4")
|
||||
|
||||
# Mock an HTTP error
|
||||
with patch.object(llm._client, 'complete') as mock_complete:
|
||||
with patch.object(llm.client, 'complete') as mock_complete:
|
||||
mock_complete.side_effect = HttpResponseError(message="Rate limit exceeded", response=MagicMock(status_code=429))
|
||||
|
||||
with pytest.raises(HttpResponseError):
|
||||
@@ -966,7 +966,7 @@ def test_azure_improved_error_messages():
|
||||
|
||||
llm = LLM(model="azure/gpt-4")
|
||||
|
||||
with patch.object(llm._client, 'complete') as mock_complete:
|
||||
with patch.object(llm.client, 'complete') as mock_complete:
|
||||
error_401 = HttpResponseError(message="Unauthorized")
|
||||
error_401.status_code = 401
|
||||
mock_complete.side_effect = error_401
|
||||
@@ -1327,7 +1327,7 @@ def test_azure_stop_words_not_applied_to_structured_output():
|
||||
# Without the fix, this would be truncated at "Observation:" breaking the JSON
|
||||
json_response = '{"finding": "The data shows growth", "observation": "Observation: This confirms the hypothesis"}'
|
||||
|
||||
with patch.object(llm._client, 'complete') as mock_complete:
|
||||
with patch.object(llm.client, 'complete') as mock_complete:
|
||||
mock_message = MagicMock()
|
||||
mock_message.content = json_response
|
||||
mock_message.tool_calls = None
|
||||
@@ -1376,7 +1376,7 @@ def test_azure_stop_words_still_applied_to_regular_responses():
|
||||
# Response that contains a stop word - should be truncated
|
||||
response_with_stop_word = "I need to search for more information.\n\nAction: search\nObservation: Found results"
|
||||
|
||||
with patch.object(llm._client, 'complete') as mock_complete:
|
||||
with patch.object(llm.client, 'complete') as mock_complete:
|
||||
mock_message = MagicMock()
|
||||
mock_message.content = response_with_stop_word
|
||||
mock_message.tool_calls = None
|
||||
|
||||
@@ -674,7 +674,7 @@ def test_bedrock_token_usage_tracking():
|
||||
llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")
|
||||
|
||||
# Mock the Bedrock response with usage information
|
||||
with patch.object(llm._client, 'converse') as mock_converse:
|
||||
with patch.object(llm.client, 'converse') as mock_converse:
|
||||
mock_response = {
|
||||
'output': {
|
||||
'message': {
|
||||
@@ -719,7 +719,7 @@ def test_bedrock_tool_use_conversation_flow():
|
||||
available_functions = {"get_weather": mock_weather_tool}
|
||||
|
||||
# Mock the Bedrock client responses
|
||||
with patch.object(llm._client, 'converse') as mock_converse:
|
||||
with patch.object(llm.client, 'converse') as mock_converse:
|
||||
# First response: tool use request
|
||||
tool_use_response = {
|
||||
'output': {
|
||||
@@ -805,7 +805,7 @@ def test_bedrock_client_error_handling():
|
||||
llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")
|
||||
|
||||
# Test ValidationException
|
||||
with patch.object(llm._client, 'converse') as mock_converse:
|
||||
with patch.object(llm.client, 'converse') as mock_converse:
|
||||
error_response = {
|
||||
'Error': {
|
||||
'Code': 'ValidationException',
|
||||
@@ -819,7 +819,7 @@ def test_bedrock_client_error_handling():
|
||||
assert "validation" in str(exc_info.value).lower()
|
||||
|
||||
# Test ThrottlingException
|
||||
with patch.object(llm._client, 'converse') as mock_converse:
|
||||
with patch.object(llm.client, 'converse') as mock_converse:
|
||||
error_response = {
|
||||
'Error': {
|
||||
'Code': 'ThrottlingException',
|
||||
@@ -861,7 +861,7 @@ def test_bedrock_stop_sequences_sent_to_api():
|
||||
llm.stop = ["\nObservation:", "\nThought:"]
|
||||
|
||||
# Patch the API call to capture parameters without making real call
|
||||
with patch.object(llm._client, 'converse') as mock_converse:
|
||||
with patch.object(llm.client, 'converse') as mock_converse:
|
||||
mock_response = {
|
||||
'output': {
|
||||
'message': {
|
||||
|
||||
@@ -556,8 +556,8 @@ def test_gemini_environment_variable_api_key():
|
||||
with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-google-key"}):
|
||||
llm = LLM(model="google/gemini-2.0-flash-001")
|
||||
|
||||
assert llm._client is not None
|
||||
assert hasattr(llm._client, 'models')
|
||||
assert llm.client is not None
|
||||
assert hasattr(llm.client, 'models')
|
||||
assert llm.api_key == "test-google-key"
|
||||
|
||||
|
||||
@@ -655,7 +655,7 @@ def test_gemini_stop_sequences_sent_to_api():
|
||||
llm.stop = ["\nObservation:", "\nThought:"]
|
||||
|
||||
# Patch the API call to capture parameters without making real call
|
||||
with patch.object(llm._client.models, 'generate_content') as mock_generate:
|
||||
with patch.object(llm.client.models, 'generate_content') as mock_generate:
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Hello"
|
||||
mock_response.candidates = []
|
||||
|
||||
@@ -371,11 +371,11 @@ def test_openai_client_setup_with_extra_arguments():
|
||||
assert llm.top_p == 0.5
|
||||
|
||||
# Check that client parameters are properly configured
|
||||
assert llm._client.max_retries == 3
|
||||
assert llm._client.timeout == 30
|
||||
assert llm.client.max_retries == 3
|
||||
assert llm.client.timeout == 30
|
||||
|
||||
# Test that parameters are properly used in API calls
|
||||
with patch.object(llm._client.chat.completions, 'create') as mock_create:
|
||||
with patch.object(llm.client.chat.completions, 'create') as mock_create:
|
||||
mock_create.return_value = MagicMock(
|
||||
choices=[MagicMock(message=MagicMock(content="test response", tool_calls=None))],
|
||||
usage=MagicMock(prompt_tokens=10, completion_tokens=20, total_tokens=30)
|
||||
@@ -396,7 +396,7 @@ def test_extra_arguments_are_passed_to_openai_completion():
|
||||
"""
|
||||
llm = LLM(model="gpt-4o", temperature=0.7, max_tokens=1000, top_p=0.5, max_retries=3)
|
||||
|
||||
with patch.object(llm._client.chat.completions, 'create') as mock_create:
|
||||
with patch.object(llm.client.chat.completions, 'create') as mock_create:
|
||||
mock_create.return_value = MagicMock(
|
||||
choices=[MagicMock(message=MagicMock(content="test response", tool_calls=None))],
|
||||
usage=MagicMock(prompt_tokens=10, completion_tokens=20, total_tokens=30)
|
||||
@@ -507,7 +507,7 @@ def test_openai_streaming_with_response_model():
|
||||
|
||||
llm = LLM(model="openai/gpt-4o", stream=True)
|
||||
|
||||
with patch.object(llm._client.beta.chat.completions, "stream") as mock_stream:
|
||||
with patch.object(llm.client.beta.chat.completions, "stream") as mock_stream:
|
||||
# Create mock chunks with content.delta event structure
|
||||
mock_chunk1 = MagicMock()
|
||||
mock_chunk1.type = "content.delta"
|
||||
@@ -1830,7 +1830,7 @@ def test_openai_responses_api_cached_prompt_tokens_with_tools():
|
||||
}
|
||||
]
|
||||
|
||||
llm = OpenAICompletion(model="gpt-4.1", api='responses')
|
||||
llm = OpenAICompletion(model="gpt-4.1", api='response')
|
||||
|
||||
# First call with tool
|
||||
llm.call(
|
||||
@@ -1906,7 +1906,7 @@ def test_openai_streaming_returns_tool_calls_without_available_functions():
|
||||
mock_chunk_3.id = "chatcmpl-1"
|
||||
|
||||
with patch.object(
|
||||
llm._client.chat.completions, "create", return_value=iter([mock_chunk_1, mock_chunk_2, mock_chunk_3])
|
||||
llm.client.chat.completions, "create", return_value=iter([mock_chunk_1, mock_chunk_2, mock_chunk_3])
|
||||
):
|
||||
result = llm.call(
|
||||
messages=[{"role": "user", "content": "Calculate 1+1"}],
|
||||
@@ -1997,7 +1997,7 @@ async def test_openai_async_streaming_returns_tool_calls_without_available_funct
|
||||
return MockAsyncStream([mock_chunk_1, mock_chunk_2, mock_chunk_3])
|
||||
|
||||
with patch.object(
|
||||
llm._async_client.chat.completions, "create", side_effect=mock_create
|
||||
llm.async_client.chat.completions, "create", side_effect=mock_create
|
||||
):
|
||||
result = await llm.acall(
|
||||
messages=[{"role": "user", "content": "Calculate 1+1"}],
|
||||
|
||||
@@ -3,8 +3,6 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from crewai.knowledge.storage.knowledge_storage import ( # type: ignore[import-untyped]
|
||||
KnowledgeStorage,
|
||||
)
|
||||
@@ -61,7 +59,7 @@ def test_knowledge_storage_invalid_embedding_config(mock_get_client: MagicMock)
|
||||
"Unsupported provider: invalid_provider"
|
||||
)
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
with pytest.raises(ValueError, match="Unsupported provider: invalid_provider"):
|
||||
KnowledgeStorage(
|
||||
embedder={"provider": "invalid_provider"},
|
||||
collection_name="invalid_embedding_test",
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from typing import Any, ClassVar
|
||||
from unittest.mock import Mock, create_autospec, patch
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from crewai.agent import Agent
|
||||
@@ -372,11 +372,8 @@ def test_internal_crew_with_mcp():
|
||||
mock_adapter = Mock()
|
||||
mock_adapter.tools = ToolCollection([simple_tool, another_simple_tool])
|
||||
|
||||
class _StubLLM(BaseLLM):
|
||||
def call(self, *a: Any, **kw: Any) -> str:
|
||||
return ""
|
||||
|
||||
mock_llm = create_autospec(_StubLLM(model="stub"), instance=True)
|
||||
mock_llm = Mock()
|
||||
mock_llm.__class__ = BaseLLM
|
||||
|
||||
with (
|
||||
patch("crewai_tools.MCPServerAdapter", return_value=mock_adapter) as adapter_mock,
|
||||
|
||||
586
lib/crewai/tests/tracing/test_trace_serialization.py
Normal file
586
lib/crewai/tests/tracing/test_trace_serialization.py
Normal file
@@ -0,0 +1,586 @@
|
||||
"""Tests for trace serialization optimization to prevent trace table bloat.
|
||||
|
||||
These tests verify that trace events don't contain redundant full crew/task/agent
|
||||
objects, reducing event sizes from 50-100KB to a few KB per event.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.events.listeners.tracing.trace_listener import (
|
||||
TRACE_EXCLUDE_FIELDS,
|
||||
TraceCollectionListener,
|
||||
_serialize_for_trace,
|
||||
)
|
||||
|
||||
|
||||
class TestTraceExcludeFields:
|
||||
"""Test that TRACE_EXCLUDE_FIELDS contains all the heavy/redundant fields."""
|
||||
|
||||
def test_contains_back_references(self):
|
||||
"""Verify back-reference fields are excluded."""
|
||||
back_refs = {"crew", "agent", "agents", "tasks", "context"}
|
||||
assert back_refs.issubset(TRACE_EXCLUDE_FIELDS)
|
||||
|
||||
def test_contains_heavy_fields(self):
|
||||
"""Verify heavy objects are excluded.
|
||||
|
||||
Note: 'tools' is NOT in TRACE_EXCLUDE_FIELDS because LLMCallStartedEvent.tools
|
||||
is a lightweight list of tool schemas. Agent.tools exclusion is handled
|
||||
explicitly in _build_crew_started_data.
|
||||
"""
|
||||
heavy_fields = {
|
||||
"llm",
|
||||
"function_calling_llm",
|
||||
"step_callback",
|
||||
"task_callback",
|
||||
"crew_callback",
|
||||
"callbacks",
|
||||
"_memory",
|
||||
"_cache",
|
||||
"knowledge_sources",
|
||||
}
|
||||
assert heavy_fields.issubset(TRACE_EXCLUDE_FIELDS)
|
||||
# tools is NOT excluded globally - LLM events need it
|
||||
assert "tools" not in TRACE_EXCLUDE_FIELDS
|
||||
|
||||
|
||||
class TestSerializeForTrace:
|
||||
"""Test the _serialize_for_trace helper function."""
|
||||
|
||||
def test_excludes_crew_field(self):
|
||||
"""Verify crew field is excluded from serialization."""
|
||||
event = MagicMock()
|
||||
event.crew = MagicMock(name="TestCrew")
|
||||
event.crew_name = "TestCrew"
|
||||
event.timestamp = None
|
||||
|
||||
result = _serialize_for_trace(event)
|
||||
|
||||
# crew_name should be present (scalar field)
|
||||
# crew should be excluded (back-reference)
|
||||
assert "crew" not in result or result.get("crew") is None
|
||||
|
||||
def test_excludes_agent_field(self):
|
||||
"""Verify agent field is excluded from serialization."""
|
||||
event = MagicMock()
|
||||
event.agent = MagicMock(role="TestAgent")
|
||||
event.agent_role = "TestAgent"
|
||||
|
||||
result = _serialize_for_trace(event)
|
||||
|
||||
assert "agent" not in result or result.get("agent") is None
|
||||
|
||||
def test_preserves_tools_field(self):
|
||||
"""Verify tools field is preserved for LLM events (lightweight schemas)."""
|
||||
|
||||
class EventWithTools:
|
||||
def __init__(self):
|
||||
self.tools = [{"name": "search", "description": "Search tool"}]
|
||||
self.tool_name = "test_tool"
|
||||
|
||||
event = EventWithTools()
|
||||
result = _serialize_for_trace(event)
|
||||
|
||||
# tools should be preserved (lightweight for LLM events)
|
||||
assert "tools" in result
|
||||
assert result["tools"] == [{"name": "search", "description": "Search tool"}]
|
||||
|
||||
def test_preserves_scalar_fields(self):
|
||||
"""Verify scalar fields needed by AMP frontend are preserved."""
|
||||
|
||||
class SimpleEvent:
|
||||
def __init__(self):
|
||||
self.agent_role = "Researcher"
|
||||
self.task_name = "Research Task"
|
||||
self.task_id = str(uuid.uuid4())
|
||||
self.duration_ms = 1500
|
||||
self.tokens_used = 500
|
||||
|
||||
event = SimpleEvent()
|
||||
result = _serialize_for_trace(event)
|
||||
|
||||
# Scalar fields should be preserved
|
||||
assert result.get("agent_role") == "Researcher"
|
||||
assert result.get("task_name") == "Research Task"
|
||||
assert result.get("duration_ms") == 1500
|
||||
assert result.get("tokens_used") == 500
|
||||
|
||||
def test_extra_exclude_parameter(self):
|
||||
"""Verify extra_exclude adds to the default exclusions."""
|
||||
|
||||
class EventWithCustomField:
|
||||
def __init__(self):
|
||||
self.custom_heavy_field = {"large": "data" * 1000}
|
||||
self.keep_this = "small"
|
||||
|
||||
event = EventWithCustomField()
|
||||
result = _serialize_for_trace(event, extra_exclude={"custom_heavy_field"})
|
||||
|
||||
assert "custom_heavy_field" not in result
|
||||
assert result.get("keep_this") == "small"
|
||||
|
||||
|
||||
class TestBuildEventData:
|
||||
"""Test _build_event_data method for different event types."""
|
||||
|
||||
@pytest.fixture
|
||||
def listener(self):
|
||||
"""Create a trace listener for testing."""
|
||||
# Reset singleton
|
||||
TraceCollectionListener._instance = None
|
||||
TraceCollectionListener._initialized = False
|
||||
TraceCollectionListener._listeners_setup = False
|
||||
return TraceCollectionListener()
|
||||
|
||||
def test_task_started_no_full_task_object(self, listener):
|
||||
"""Verify task_started event doesn't include full task object."""
|
||||
mock_task = MagicMock()
|
||||
mock_task.name = "Test Task"
|
||||
mock_task.description = "A test task description"
|
||||
mock_task.expected_output = "Expected result"
|
||||
mock_task.id = uuid.uuid4()
|
||||
# Add heavy fields that should NOT appear in output
|
||||
mock_task.crew = MagicMock(name="HeavyCrew")
|
||||
mock_task.agent = MagicMock(role="HeavyAgent")
|
||||
mock_task.context = [MagicMock(), MagicMock()]
|
||||
mock_task.tools = [MagicMock(), MagicMock()]
|
||||
|
||||
mock_event = MagicMock()
|
||||
mock_event.task = mock_task
|
||||
mock_event.context = "test context"
|
||||
|
||||
mock_source = MagicMock()
|
||||
mock_source.agent = MagicMock()
|
||||
mock_source.agent.role = "Worker"
|
||||
|
||||
result = listener._build_event_data("task_started", mock_event, mock_source)
|
||||
|
||||
# Should have scalar fields
|
||||
assert result["task_name"] == "Test Task"
|
||||
assert result["task_description"] == "A test task description"
|
||||
assert result["agent_role"] == "Worker"
|
||||
assert result["task_id"] == str(mock_task.id)
|
||||
|
||||
# Should NOT have full objects
|
||||
assert "crew" not in result
|
||||
assert "tools" not in result
|
||||
# task and agent should not be full objects
|
||||
assert result.get("task") is None or not hasattr(result.get("task"), "crew")
|
||||
|
||||
def test_task_completed_no_full_task_object(self, listener):
|
||||
"""Verify task_completed event doesn't include full task object."""
|
||||
mock_task = MagicMock()
|
||||
mock_task.name = "Completed Task"
|
||||
mock_task.description = "Task description"
|
||||
mock_task.id = uuid.uuid4()
|
||||
|
||||
mock_output = MagicMock()
|
||||
mock_output.raw = "Task result"
|
||||
mock_output.output_format = "text"
|
||||
mock_output.agent = "Worker"
|
||||
|
||||
mock_event = MagicMock()
|
||||
mock_event.task = mock_task
|
||||
mock_event.output = mock_output
|
||||
|
||||
result = listener._build_event_data("task_completed", mock_event, None)
|
||||
|
||||
# Should have scalar fields
|
||||
assert result["task_name"] == "Completed Task"
|
||||
assert result["output_raw"] == "Task result"
|
||||
assert result["agent_role"] == "Worker"
|
||||
|
||||
# Should NOT have full task object
|
||||
assert "crew" not in result
|
||||
assert "tools" not in result
|
||||
|
||||
def test_agent_execution_started_no_full_agent(self, listener):
|
||||
"""Verify agent_execution_started extracts only scalar fields."""
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.role = "Analyst"
|
||||
mock_agent.goal = "Analyze data"
|
||||
mock_agent.backstory = "Expert analyst"
|
||||
# Heavy fields
|
||||
mock_agent.tools = [MagicMock(), MagicMock()]
|
||||
mock_agent.llm = MagicMock()
|
||||
mock_agent.crew = MagicMock()
|
||||
|
||||
mock_event = MagicMock()
|
||||
mock_event.agent = mock_agent
|
||||
|
||||
result = listener._build_event_data(
|
||||
"agent_execution_started", mock_event, None
|
||||
)
|
||||
|
||||
# Should have scalar fields
|
||||
assert result["agent_role"] == "Analyst"
|
||||
assert result["agent_goal"] == "Analyze data"
|
||||
assert result["agent_backstory"] == "Expert analyst"
|
||||
|
||||
# Should NOT have heavy objects
|
||||
assert "tools" not in result
|
||||
assert "llm" not in result
|
||||
assert "crew" not in result
|
||||
|
||||
def test_llm_call_started_excludes_heavy_fields(self, listener):
|
||||
"""Verify llm_call_started uses lightweight serialization.
|
||||
|
||||
LLMCallStartedEvent.tools is a lightweight list of tool schemas (dicts),
|
||||
not heavy Agent.tools objects, so it should be preserved.
|
||||
"""
|
||||
|
||||
class MockLLMEvent:
|
||||
def __init__(self):
|
||||
self.task_name = "LLM Task"
|
||||
self.model = "gpt-4"
|
||||
self.tokens = 100
|
||||
# Heavy fields that should be excluded
|
||||
self.crew = MagicMock()
|
||||
self.agent = MagicMock()
|
||||
# LLM event tools are lightweight schemas (dicts), should be kept
|
||||
self.tools = [{"name": "search", "description": "Search tool"}]
|
||||
|
||||
mock_event = MockLLMEvent()
|
||||
|
||||
result = listener._build_event_data("llm_call_started", mock_event, None)
|
||||
|
||||
# task_name should be present
|
||||
assert result["task_name"] == "LLM Task"
|
||||
|
||||
# Heavy fields should be excluded
|
||||
assert "crew" not in result or result.get("crew") is None
|
||||
assert "agent" not in result or result.get("agent") is None
|
||||
# LLM tools (lightweight schemas) should be preserved
|
||||
assert result.get("tools") == [{"name": "search", "description": "Search tool"}]
|
||||
|
||||
def test_llm_call_completed_excludes_heavy_fields(self, listener):
|
||||
"""Verify llm_call_completed uses lightweight serialization."""
|
||||
|
||||
class MockLLMCompletedEvent:
|
||||
def __init__(self):
|
||||
self.response = "LLM response"
|
||||
self.tokens_used = 150
|
||||
self.duration_ms = 500
|
||||
# Heavy fields
|
||||
self.crew = MagicMock()
|
||||
self.agent = MagicMock()
|
||||
|
||||
mock_event = MockLLMCompletedEvent()
|
||||
|
||||
result = listener._build_event_data("llm_call_completed", mock_event, None)
|
||||
|
||||
# Scalar fields preserved
|
||||
assert result.get("response") == "LLM response"
|
||||
assert result.get("tokens_used") == 150
|
||||
|
||||
# Heavy fields excluded
|
||||
assert "crew" not in result or result.get("crew") is None
|
||||
assert "agent" not in result or result.get("agent") is None
|
||||
|
||||
|
||||
class TestCrewKickoffStartedEvent:
|
||||
"""Test that crew_kickoff_started event has full structure."""
|
||||
|
||||
@pytest.fixture
|
||||
def listener(self):
|
||||
"""Create a trace listener for testing."""
|
||||
TraceCollectionListener._instance = None
|
||||
TraceCollectionListener._initialized = False
|
||||
TraceCollectionListener._listeners_setup = False
|
||||
return TraceCollectionListener()
|
||||
|
||||
def test_crew_started_has_crew_structure(self, listener):
|
||||
"""Verify crew_kickoff_started includes the crew_structure field."""
|
||||
# Create mock crew with agents and tasks
|
||||
mock_agent1 = MagicMock()
|
||||
mock_agent1.id = uuid.uuid4()
|
||||
mock_agent1.role = "Researcher"
|
||||
mock_agent1.goal = "Research things"
|
||||
mock_agent1.backstory = "Expert researcher"
|
||||
mock_agent1.verbose = True
|
||||
mock_agent1.allow_delegation = False
|
||||
mock_agent1.max_iter = 10
|
||||
mock_agent1.max_rpm = None
|
||||
mock_agent1.tools = [MagicMock(name="search_tool"), MagicMock(name="read_tool")]
|
||||
|
||||
mock_agent2 = MagicMock()
|
||||
mock_agent2.id = uuid.uuid4()
|
||||
mock_agent2.role = "Writer"
|
||||
mock_agent2.goal = "Write content"
|
||||
mock_agent2.backstory = "Expert writer"
|
||||
mock_agent2.verbose = False
|
||||
mock_agent2.allow_delegation = True
|
||||
mock_agent2.max_iter = 5
|
||||
mock_agent2.max_rpm = 10
|
||||
mock_agent2.tools = []
|
||||
|
||||
mock_task1 = MagicMock()
|
||||
mock_task1.id = uuid.uuid4()
|
||||
mock_task1.name = "Research Task"
|
||||
mock_task1.description = "Do research"
|
||||
mock_task1.expected_output = "Research results"
|
||||
mock_task1.async_execution = False
|
||||
mock_task1.human_input = False
|
||||
mock_task1.agent = mock_agent1
|
||||
mock_task1.context = None
|
||||
|
||||
mock_task2 = MagicMock()
|
||||
mock_task2.id = uuid.uuid4()
|
||||
mock_task2.name = "Writing Task"
|
||||
mock_task2.description = "Write report"
|
||||
mock_task2.expected_output = "Written report"
|
||||
mock_task2.async_execution = True
|
||||
mock_task2.human_input = True
|
||||
mock_task2.agent = mock_agent2
|
||||
mock_task2.context = [mock_task1]
|
||||
|
||||
mock_crew = MagicMock()
|
||||
mock_crew.agents = [mock_agent1, mock_agent2]
|
||||
mock_crew.tasks = [mock_task1, mock_task2]
|
||||
mock_crew.process = "sequential"
|
||||
mock_crew.verbose = True
|
||||
mock_crew.memory = False
|
||||
|
||||
mock_event = MagicMock()
|
||||
mock_event.crew = mock_crew
|
||||
mock_event.crew_name = "TestCrew"
|
||||
mock_event.inputs = {"key": "value"}
|
||||
|
||||
result = listener._build_event_data("crew_kickoff_started", mock_event, None)
|
||||
|
||||
# Should have crew_structure
|
||||
assert "crew_structure" in result
|
||||
crew_structure = result["crew_structure"]
|
||||
|
||||
# Verify agents are serialized with tool names
|
||||
assert len(crew_structure["agents"]) == 2
|
||||
agent1_data = crew_structure["agents"][0]
|
||||
assert agent1_data["role"] == "Researcher"
|
||||
assert agent1_data["goal"] == "Research things"
|
||||
assert "tool_names" in agent1_data
|
||||
assert len(agent1_data["tool_names"]) == 2
|
||||
|
||||
# Verify tasks have lightweight agent references
|
||||
assert len(crew_structure["tasks"]) == 2
|
||||
task2_data = crew_structure["tasks"][1]
|
||||
assert task2_data["name"] == "Writing Task"
|
||||
assert "agent_ref" in task2_data
|
||||
assert task2_data["agent_ref"]["role"] == "Writer"
|
||||
|
||||
# Verify context uses task IDs
|
||||
assert "context_task_ids" in task2_data
|
||||
assert str(mock_task1.id) in task2_data["context_task_ids"]
|
||||
|
||||
def test_crew_started_agents_no_full_tools(self, listener):
|
||||
"""Verify agents in crew_structure have tool_names, not full tool objects."""
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.name = "web_search"
|
||||
mock_tool.description = "Search the web"
|
||||
mock_tool.func = lambda x: x # Heavy callable
|
||||
mock_tool.args_schema = {"type": "object"} # Schema
|
||||
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.id = uuid.uuid4()
|
||||
mock_agent.role = "Searcher"
|
||||
mock_agent.goal = "Search"
|
||||
mock_agent.backstory = "Expert"
|
||||
mock_agent.verbose = False
|
||||
mock_agent.allow_delegation = False
|
||||
mock_agent.max_iter = 5
|
||||
mock_agent.max_rpm = None
|
||||
mock_agent.tools = [mock_tool]
|
||||
|
||||
mock_crew = MagicMock()
|
||||
mock_crew.agents = [mock_agent]
|
||||
mock_crew.tasks = []
|
||||
mock_crew.process = "sequential"
|
||||
mock_crew.verbose = False
|
||||
mock_crew.memory = False
|
||||
|
||||
mock_event = MagicMock()
|
||||
mock_event.crew = mock_crew
|
||||
|
||||
result = listener._build_event_data("crew_kickoff_started", mock_event, None)
|
||||
|
||||
agent_data = result["crew_structure"]["agents"][0]
|
||||
|
||||
# Should have tool_names (list of strings)
|
||||
assert "tool_names" in agent_data
|
||||
assert agent_data["tool_names"] == ["web_search"]
|
||||
|
||||
# Should NOT have full tools array
|
||||
assert "tools" not in agent_data
|
||||
|
||||
def test_crew_started_tasks_no_full_agent(self, listener):
|
||||
"""Verify tasks have agent_ref, not full agent object."""
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.id = uuid.uuid4()
|
||||
mock_agent.role = "Worker"
|
||||
mock_agent.goal = "Work hard"
|
||||
mock_agent.backstory = "Dedicated worker"
|
||||
mock_agent.tools = [MagicMock(), MagicMock()]
|
||||
mock_agent.llm = MagicMock()
|
||||
|
||||
mock_task = MagicMock()
|
||||
mock_task.id = uuid.uuid4()
|
||||
mock_task.name = "Work Task"
|
||||
mock_task.description = "Do work"
|
||||
mock_task.expected_output = "Work done"
|
||||
mock_task.async_execution = False
|
||||
mock_task.human_input = False
|
||||
mock_task.agent = mock_agent
|
||||
mock_task.context = None
|
||||
|
||||
mock_crew = MagicMock()
|
||||
mock_crew.agents = [mock_agent]
|
||||
mock_crew.tasks = [mock_task]
|
||||
mock_crew.process = "sequential"
|
||||
mock_crew.verbose = False
|
||||
mock_crew.memory = False
|
||||
|
||||
mock_event = MagicMock()
|
||||
mock_event.crew = mock_crew
|
||||
|
||||
result = listener._build_event_data("crew_kickoff_started", mock_event, None)
|
||||
|
||||
task_data = result["crew_structure"]["tasks"][0]
|
||||
|
||||
# Should have lightweight agent_ref
|
||||
assert "agent_ref" in task_data
|
||||
assert task_data["agent_ref"]["id"] == str(mock_agent.id)
|
||||
assert task_data["agent_ref"]["role"] == "Worker"
|
||||
|
||||
# agent_ref should ONLY have id and role (not tools, llm, etc.)
|
||||
assert len(task_data["agent_ref"]) == 2
|
||||
|
||||
# Should NOT have full agent
|
||||
assert "agent" not in task_data
|
||||
|
||||
|
||||
class TestNonCrewStartedEvents:
|
||||
"""Test that non-crew_started events don't have redundant data."""
|
||||
|
||||
@pytest.fixture
|
||||
def listener(self):
|
||||
"""Create a trace listener for testing."""
|
||||
TraceCollectionListener._instance = None
|
||||
TraceCollectionListener._initialized = False
|
||||
TraceCollectionListener._listeners_setup = False
|
||||
return TraceCollectionListener()
|
||||
|
||||
def test_generic_event_no_crew(self, listener):
|
||||
"""Verify generic events exclude crew object.
|
||||
|
||||
Note: 'tools' is now preserved since LLMCallStartedEvent.tools is lightweight.
|
||||
"""
|
||||
|
||||
class GenericEvent:
|
||||
def __init__(self):
|
||||
self.event_type = "some_event"
|
||||
self.data = "some_data"
|
||||
# These should be excluded
|
||||
self.crew = MagicMock()
|
||||
self.agents = [MagicMock()]
|
||||
self.tasks = [MagicMock()]
|
||||
# tools is now preserved (for LLM events it's lightweight)
|
||||
self.tools = [{"name": "search"}]
|
||||
|
||||
mock_event = GenericEvent()
|
||||
|
||||
result = listener._build_event_data("some_event", mock_event, None)
|
||||
|
||||
# Scalar fields preserved
|
||||
assert result.get("event_type") == "some_event"
|
||||
assert result.get("data") == "some_data"
|
||||
|
||||
# Heavy fields excluded
|
||||
assert "crew" not in result or result.get("crew") is None
|
||||
assert "agents" not in result or result.get("agents") is None
|
||||
assert "tasks" not in result or result.get("tasks") is None
|
||||
# tools is now preserved (lightweight for LLM events)
|
||||
assert result.get("tools") == [{"name": "search"}]
|
||||
|
||||
def test_crew_kickoff_completed_no_full_crew(self, listener):
|
||||
"""Verify crew_kickoff_completed doesn't repeat full crew structure."""
|
||||
|
||||
class CrewCompletedEvent:
|
||||
def __init__(self):
|
||||
self.crew_name = "TestCrew"
|
||||
self.total_tokens = 5000
|
||||
self.output = "Final output"
|
||||
# Should be excluded
|
||||
self.crew = MagicMock()
|
||||
self.crew.agents = [MagicMock(), MagicMock()]
|
||||
self.crew.tasks = [MagicMock()]
|
||||
|
||||
mock_event = CrewCompletedEvent()
|
||||
|
||||
result = listener._build_event_data("crew_kickoff_completed", mock_event, None)
|
||||
|
||||
# Scalar fields preserved
|
||||
assert result.get("crew_name") == "TestCrew"
|
||||
assert result.get("total_tokens") == 5000
|
||||
|
||||
# Should NOT have full crew object
|
||||
assert "crew" not in result or result.get("crew") is None
|
||||
# Should NOT have crew_structure (that's only for crew_started)
|
||||
assert "crew_structure" not in result
|
||||
|
||||
|
||||
class TestSizeReduction:
|
||||
"""Test that the optimization actually reduces serialized size."""
|
||||
|
||||
@pytest.fixture
|
||||
def listener(self):
|
||||
"""Create a trace listener for testing."""
|
||||
TraceCollectionListener._instance = None
|
||||
TraceCollectionListener._initialized = False
|
||||
TraceCollectionListener._listeners_setup = False
|
||||
return TraceCollectionListener()
|
||||
|
||||
def test_task_event_size_reduction(self, listener):
|
||||
"""Verify task events are much smaller than naive serialization."""
|
||||
import json
|
||||
|
||||
# Create a realistic task with many fields
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.id = uuid.uuid4()
|
||||
mock_agent.role = "Researcher"
|
||||
mock_agent.goal = "Research" * 50 # Longer goal
|
||||
mock_agent.backstory = "Expert" * 100 # Longer backstory
|
||||
mock_agent.tools = [MagicMock() for _ in range(5)]
|
||||
mock_agent.llm = MagicMock()
|
||||
mock_agent.crew = MagicMock()
|
||||
|
||||
mock_task = MagicMock()
|
||||
mock_task.name = "Research Task"
|
||||
mock_task.description = "Detailed description" * 20
|
||||
mock_task.expected_output = "Expected" * 10
|
||||
mock_task.id = uuid.uuid4()
|
||||
mock_task.agent = mock_agent
|
||||
mock_task.context = [MagicMock() for _ in range(3)]
|
||||
mock_task.crew = MagicMock()
|
||||
mock_task.tools = [MagicMock() for _ in range(3)]
|
||||
|
||||
mock_event = MagicMock()
|
||||
mock_event.task = mock_task
|
||||
mock_event.context = "test context"
|
||||
|
||||
mock_source = MagicMock()
|
||||
mock_source.agent = mock_agent
|
||||
|
||||
result = listener._build_event_data("task_started", mock_event, mock_source)
|
||||
|
||||
# The result should be relatively small
|
||||
serialized = json.dumps(result, default=str)
|
||||
|
||||
# Should be under 2KB for task_started (was potentially 50-100KB before)
|
||||
assert len(serialized) < 2000, f"task_started too large: {len(serialized)} bytes"
|
||||
|
||||
# Should have the essential fields
|
||||
assert "task_name" in result
|
||||
assert "task_id" in result
|
||||
assert "agent_role" in result
|
||||
Reference in New Issue
Block a user