max tools per turn wip and ensure we drop print times

This commit is contained in:
lorenzejay
2026-01-20 16:46:38 -08:00
parent 3472cb4f8a
commit b49e42af05
2 changed files with 144 additions and 15 deletions

View File

@@ -250,6 +250,10 @@ class Agent(BaseAgent):
default=CrewAgentExecutor, default=CrewAgentExecutor,
description="Class to use for the agent executor. Defaults to CrewAgentExecutor, can optionally use AgentExecutor.", description="Class to use for the agent executor. Defaults to CrewAgentExecutor, can optionally use AgentExecutor.",
) )
# Per-agent cap on tool calls per LLM turn; this value is forwarded to the
# executor constructor through its `max_tools_per_turn` keyword argument.
# NOTE(review): no lower bound is declared on this field — confirm whether
# zero or negative values should be rejected.
max_tools_per_turn: int = Field(
default=10,
description="Maximum number of tool calls to execute per LLM turn before asking for reflection.",
)
@model_validator(mode="before") @model_validator(mode="before")
def validate_from_repository(cls, v: Any) -> dict[str, Any] | None | Any: # noqa: N805 def validate_from_repository(cls, v: Any) -> dict[str, Any] | None | Any: # noqa: N805
@@ -803,6 +807,7 @@ class Agent(BaseAgent):
request_within_rpm_limit=rpm_limit_fn, request_within_rpm_limit=rpm_limit_fn,
callbacks=[TokenCalcHandler(self._token_process)], callbacks=[TokenCalcHandler(self._token_process)],
response_model=task.response_model if task else None, response_model=task.response_model if task else None,
# max_tools_per_turn=self.max_tools_per_turn, #TODO: drop this
) )
def _update_executor_parameters( def _update_executor_parameters(
@@ -1698,6 +1703,7 @@ class Agent(BaseAgent):
callbacks=[TokenCalcHandler(self._token_process)], callbacks=[TokenCalcHandler(self._token_process)],
response_model=response_format, response_model=response_format,
i18n=self.i18n, i18n=self.i18n,
max_tools_per_turn=self.max_tools_per_turn,
) )
# Format messages # Format messages

View File

@@ -4,6 +4,7 @@ from collections.abc import Callable, Coroutine
from datetime import datetime from datetime import datetime
import json import json
import threading import threading
import time
from typing import TYPE_CHECKING, Any, Literal, cast from typing import TYPE_CHECKING, Any, Literal, cast
from uuid import uuid4 from uuid import uuid4
@@ -123,6 +124,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
callbacks: list[Any] | None = None, callbacks: list[Any] | None = None,
response_model: type[BaseModel] | None = None, response_model: type[BaseModel] | None = None,
i18n: I18N | None = None, i18n: I18N | None = None,
max_tools_per_turn: int = 10,
) -> None: ) -> None:
"""Initialize the flow-based agent executor. """Initialize the flow-based agent executor.
@@ -146,6 +148,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
callbacks: Optional callbacks list. callbacks: Optional callbacks list.
response_model: Optional Pydantic model for structured outputs. response_model: Optional Pydantic model for structured outputs.
""" """
print("lorenze using agent executor")
self._i18n: I18N = i18n or get_i18n() self._i18n: I18N = i18n or get_i18n()
self.llm = llm self.llm = llm
self.task: Task | None = task self.task: Task | None = task
@@ -166,6 +169,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
self.respect_context_window = respect_context_window self.respect_context_window = respect_context_window
self.request_within_rpm_limit = request_within_rpm_limit self.request_within_rpm_limit = request_within_rpm_limit
self.response_model = response_model self.response_model = response_model
self.max_tools_per_turn = max_tools_per_turn
self.log_error_after = 3 self.log_error_after = 3
self._console: Console = Console() self._console: Console = Console()
@@ -333,8 +337,21 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
Returns routing decision based on parsing result. Returns routing decision based on parsing result.
""" """
try: try:
iteration_start = time.time()
print(f"\n{'=' * 60}")
print(
f"[{time.strftime('%H:%M:%S')}] ITERATION {self.state.iterations} - call_llm_and_parse START (ReAct)"
)
print(
f"[{time.strftime('%H:%M:%S')}] Messages count: {len(self.state.messages)}"
)
print(f"{'=' * 60}")
enforce_rpm_limit(self.request_within_rpm_limit) enforce_rpm_limit(self.request_within_rpm_limit)
llm_start = time.time()
print(f"[{time.strftime('%H:%M:%S')}] LLM CALL START")
answer = get_llm_response( answer = get_llm_response(
llm=self.llm, llm=self.llm,
messages=list(self.state.messages), messages=list(self.state.messages),
@@ -346,8 +363,21 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
executor_context=self, executor_context=self,
) )
llm_elapsed = time.time() - llm_start
print(
f"[{time.strftime('%H:%M:%S')}] LLM CALL END - took {llm_elapsed:.2f}s"
)
print(
f"[{time.strftime('%H:%M:%S')}] Answer length: {len(answer) if answer else 0} chars"
)
# Parse the LLM response # Parse the LLM response
parse_start = time.time()
formatted_answer = process_llm_response(answer, self.use_stop_words) formatted_answer = process_llm_response(answer, self.use_stop_words)
parse_elapsed = time.time() - parse_start
print(
f"[{time.strftime('%H:%M:%S')}] Parsing took {parse_elapsed:.3f}s -> {type(formatted_answer).__name__}"
)
self.state.current_answer = formatted_answer self.state.current_answer = formatted_answer
if "Final Answer:" in answer and isinstance(formatted_answer, AgentAction): if "Final Answer:" in answer and isinstance(formatted_answer, AgentAction):
@@ -363,6 +393,10 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
preview_text.append(f"{answer[:200]}...", style="yellow dim") preview_text.append(f"{answer[:200]}...", style="yellow dim")
self._console.print(preview_text) self._console.print(preview_text)
iteration_elapsed = time.time() - iteration_start
print(
f"[{time.strftime('%H:%M:%S')}] ITERATION total: {iteration_elapsed:.2f}s"
)
return "parsed" return "parsed"
except OutputParserError as e: except OutputParserError as e:
@@ -387,14 +421,49 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
) -> Literal["native_tool_calls", "native_finished", "context_error"]: ) -> Literal["native_tool_calls", "native_finished", "context_error"]:
"""Execute LLM call with native function calling. """Execute LLM call with native function calling.
Always calls the LLM so it can read reflection prompts and decide
whether to provide a final answer or request more tools.
Returns routing decision based on whether tool calls or final answer. Returns routing decision based on whether tool calls or final answer.
""" """
try: try:
iteration_start = time.time()
print(f"\n{'=' * 60}")
print(
f"[{time.strftime('%H:%M:%S')}] ITERATION {self.state.iterations} - call_llm_native_tools START"
)
print(
f"[{time.strftime('%H:%M:%S')}] pending_tool_calls before LLM: {len(self.state.pending_tool_calls)}"
)
print(
f"[{time.strftime('%H:%M:%S')}] Messages count: {len(self.state.messages)}"
)
print(f"{'=' * 60}")
# Clear pending tools - LLM will decide what to do next after reading
# the reflection prompt. It can either:
# 1. Return a final answer (string) if it has enough info
# 2. Return tool calls (possibly same ones, or different ones)
self.state.pending_tool_calls.clear()
enforce_rpm_limit(self.request_within_rpm_limit) enforce_rpm_limit(self.request_within_rpm_limit)
last_msg_content = (
self.state.messages[-1].get("content", "")
if self.state.messages
else ""
)
last_msg_preview = (
last_msg_content[:200] if last_msg_content else "(no content)"
)
print(
f"[{time.strftime('%H:%M:%S')}] Last message to LLM: {last_msg_preview}..."
)
# Call LLM with native tools # Call LLM with native tools
# Pass available_functions=None so the LLM returns tool_calls llm_start = time.time()
# without executing them. The executor handles tool execution. print(f"[{time.strftime('%H:%M:%S')}] LLM CALL START")
answer = get_llm_response( answer = get_llm_response(
llm=self.llm, llm=self.llm,
messages=list(self.state.messages), messages=list(self.state.messages),
@@ -404,15 +473,27 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
available_functions=None, available_functions=None,
from_task=self.task, from_task=self.task,
from_agent=self.agent, from_agent=self.agent,
# response_model=self.response_model,
response_model=None, response_model=None,
executor_context=self, executor_context=self,
) )
llm_elapsed = time.time() - llm_start
print(
f"[{time.strftime('%H:%M:%S')}] LLM CALL END - took {llm_elapsed:.2f}s"
)
print(f"[{time.strftime('%H:%M:%S')}] Answer type: {type(answer).__name__}")
# Check if the response is a list of tool calls # Check if the response is a list of tool calls
if isinstance(answer, list) and answer and self._is_tool_call_list(answer): if isinstance(answer, list) and answer and self._is_tool_call_list(answer):
# Store tool calls for sequential processing # Store tool calls for sequential processing
self.state.pending_tool_calls = list(answer) self.state.pending_tool_calls = list(answer)
iteration_elapsed = time.time() - iteration_start
print(
f"[{time.strftime('%H:%M:%S')}] -> Routing to native_tool_calls ({len(answer)} tools)"
)
print(
f"[{time.strftime('%H:%M:%S')}] ITERATION total: {iteration_elapsed:.2f}s"
)
return "native_tool_calls" return "native_tool_calls"
# Text response - this is the final answer # Text response - this is the final answer
@@ -424,6 +505,13 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
) )
self._invoke_step_callback(self.state.current_answer) self._invoke_step_callback(self.state.current_answer)
self._append_message_to_state(answer) self._append_message_to_state(answer)
iteration_elapsed = time.time() - iteration_start
print(
f"[{time.strftime('%H:%M:%S')}] -> FINAL ANSWER (string, len={len(answer)})"
)
print(
f"[{time.strftime('%H:%M:%S')}] ITERATION total: {iteration_elapsed:.2f}s"
)
return "native_finished" return "native_finished"
# Unexpected response type, treat as final answer # Unexpected response type, treat as final answer
@@ -434,6 +522,11 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
) )
self._invoke_step_callback(self.state.current_answer) self._invoke_step_callback(self.state.current_answer)
self._append_message_to_state(str(answer)) self._append_message_to_state(str(answer))
iteration_elapsed = time.time() - iteration_start
print(f"[{time.strftime('%H:%M:%S')}] -> FINAL ANSWER (unexpected type)")
print(
f"[{time.strftime('%H:%M:%S')}] ITERATION total: {iteration_elapsed:.2f}s"
)
return "native_finished" return "native_finished"
except Exception as e: except Exception as e:
@@ -519,16 +612,20 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
@listen("native_tool_calls") @listen("native_tool_calls")
def execute_native_tool(self) -> Literal["native_tool_completed"]: def execute_native_tool(self) -> Literal["native_tool_completed"]:
"""Execute a single native tool call and inject reasoning prompt. """Execute a SINGLE native tool call with reflection after.
Processes only the FIRST tool call from pending_tool_calls for Processes only the first tool from pending_tool_calls, then asks
sequential execution with reflection after each tool. the LLM if it can answer the task. Remaining tools stay in the queue
for potential execution on next iteration.
""" """
if not self.state.pending_tool_calls: if not self.state.pending_tool_calls:
return "native_tool_completed" return "native_tool_completed"
tool_call = self.state.pending_tool_calls[0] # Pop just the first tool (leave the rest in queue for potential continuation)
self.state.pending_tool_calls = [] # Clear pending calls tool_call = self.state.pending_tool_calls.pop(0)
print(
f"Executing 1 tool, {len(self.state.pending_tool_calls)} remaining in queue"
)
# Extract tool call info - handle OpenAI, Anthropic, and Gemini formats # Extract tool call info - handle OpenAI, Anthropic, and Gemini formats
if hasattr(tool_call, "function"): if hasattr(tool_call, "function"):
@@ -556,6 +653,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
func_name = func_info.get("name", "") or tool_call.get("name", "") func_name = func_info.get("name", "") or tool_call.get("name", "")
func_args = func_info.get("arguments", "{}") or tool_call.get("input", {}) func_args = func_info.get("arguments", "{}") or tool_call.get("input", {})
else: else:
# Unrecognized format - skip and try next
return "native_tool_completed" return "native_tool_completed"
# Append assistant message with single tool call # Append assistant message with single tool call
@@ -638,16 +736,41 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
color="green", color="green",
) )
# Inject post-tool reasoning prompt to enforce analysis # Only add reflection prompt if there are still pending tools
reasoning_prompt = self._i18n.slice("post_tool_reasoning") # If no pending tools, skip reflection - LLM will naturally continue
reasoning_message: LLMMessage = { if self.state.pending_tool_calls:
"role": "user", print("--------------------------------")
"content": reasoning_prompt, print(
} f"REFLECTION: {len(self.state.pending_tool_calls)} tools pending - asking LLM to decide"
self.state.messages.append(reasoning_message) )
print("--------------------------------")
reasoning_prompt = self._i18n.slice("post_tool_reasoning")
reasoning_message: LLMMessage = {
"role": "user",
"content": reasoning_prompt,
}
self.state.messages.append(reasoning_message)
else:
print("--------------------------------")
print("SKIPPING REFLECTION: No pending tools - LLM will continue naturally")
print("--------------------------------")
return "native_tool_completed" return "native_tool_completed"
def _extract_tool_name(self, tool_call: Any) -> str:
"""Extract tool name from various tool call formats."""
if hasattr(tool_call, "function"):
return tool_call.function.name
if hasattr(tool_call, "function_call") and tool_call.function_call:
return tool_call.function_call.name
if hasattr(tool_call, "name"):
return tool_call.name
if isinstance(tool_call, dict):
func_info = tool_call.get("function", {})
return func_info.get("name", "") or tool_call.get("name", "unknown")
return "unknown"
@router(execute_native_tool) @router(execute_native_tool)
def increment_native_and_continue(self) -> Literal["initialized"]: def increment_native_and_continue(self) -> Literal["initialized"]:
"""Increment iteration counter after native tool execution.""" """Increment iteration counter after native tool execution."""