max tools per turn wip and ensure we drop print times

This commit is contained in:
lorenzejay
2026-01-20 16:46:38 -08:00
parent 3472cb4f8a
commit b49e42af05
2 changed files with 144 additions and 15 deletions

View File

@@ -250,6 +250,10 @@ class Agent(BaseAgent):
default=CrewAgentExecutor,
description="Class to use for the agent executor. Defaults to CrewAgentExecutor, can optionally use AgentExecutor.",
)
max_tools_per_turn: int = Field(
default=10,
description="Maximum number of tool calls to execute per LLM turn before asking for reflection.",
)
@model_validator(mode="before")
def validate_from_repository(cls, v: Any) -> dict[str, Any] | None | Any: # noqa: N805
@@ -803,6 +807,7 @@ class Agent(BaseAgent):
request_within_rpm_limit=rpm_limit_fn,
callbacks=[TokenCalcHandler(self._token_process)],
response_model=task.response_model if task else None,
# max_tools_per_turn=self.max_tools_per_turn, #TODO: drop this
)
def _update_executor_parameters(
@@ -1698,6 +1703,7 @@ class Agent(BaseAgent):
callbacks=[TokenCalcHandler(self._token_process)],
response_model=response_format,
i18n=self.i18n,
max_tools_per_turn=self.max_tools_per_turn,
)
# Format messages

View File

@@ -4,6 +4,7 @@ from collections.abc import Callable, Coroutine
from datetime import datetime
import json
import threading
import time
from typing import TYPE_CHECKING, Any, Literal, cast
from uuid import uuid4
@@ -123,6 +124,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
callbacks: list[Any] | None = None,
response_model: type[BaseModel] | None = None,
i18n: I18N | None = None,
max_tools_per_turn: int = 10,
) -> None:
"""Initialize the flow-based agent executor.
@@ -146,6 +148,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
callbacks: Optional callbacks list.
response_model: Optional Pydantic model for structured outputs.
"""
print("lorenze using agent executor")
self._i18n: I18N = i18n or get_i18n()
self.llm = llm
self.task: Task | None = task
@@ -166,6 +169,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
self.respect_context_window = respect_context_window
self.request_within_rpm_limit = request_within_rpm_limit
self.response_model = response_model
self.max_tools_per_turn = max_tools_per_turn
self.log_error_after = 3
self._console: Console = Console()
@@ -333,8 +337,21 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
Returns routing decision based on parsing result.
"""
try:
iteration_start = time.time()
print(f"\n{'=' * 60}")
print(
f"[{time.strftime('%H:%M:%S')}] ITERATION {self.state.iterations} - call_llm_and_parse START (ReAct)"
)
print(
f"[{time.strftime('%H:%M:%S')}] Messages count: {len(self.state.messages)}"
)
print(f"{'=' * 60}")
enforce_rpm_limit(self.request_within_rpm_limit)
llm_start = time.time()
print(f"[{time.strftime('%H:%M:%S')}] LLM CALL START")
answer = get_llm_response(
llm=self.llm,
messages=list(self.state.messages),
@@ -346,8 +363,21 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
executor_context=self,
)
llm_elapsed = time.time() - llm_start
print(
f"[{time.strftime('%H:%M:%S')}] LLM CALL END - took {llm_elapsed:.2f}s"
)
print(
f"[{time.strftime('%H:%M:%S')}] Answer length: {len(answer) if answer else 0} chars"
)
# Parse the LLM response
parse_start = time.time()
formatted_answer = process_llm_response(answer, self.use_stop_words)
parse_elapsed = time.time() - parse_start
print(
f"[{time.strftime('%H:%M:%S')}] Parsing took {parse_elapsed:.3f}s -> {type(formatted_answer).__name__}"
)
self.state.current_answer = formatted_answer
if "Final Answer:" in answer and isinstance(formatted_answer, AgentAction):
@@ -363,6 +393,10 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
preview_text.append(f"{answer[:200]}...", style="yellow dim")
self._console.print(preview_text)
iteration_elapsed = time.time() - iteration_start
print(
f"[{time.strftime('%H:%M:%S')}] ITERATION total: {iteration_elapsed:.2f}s"
)
return "parsed"
except OutputParserError as e:
@@ -387,14 +421,49 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
) -> Literal["native_tool_calls", "native_finished", "context_error"]:
"""Execute LLM call with native function calling.
Always calls the LLM so it can read reflection prompts and decide
whether to provide a final answer or request more tools.
Returns routing decision based on whether tool calls or final answer.
"""
try:
iteration_start = time.time()
print(f"\n{'=' * 60}")
print(
f"[{time.strftime('%H:%M:%S')}] ITERATION {self.state.iterations} - call_llm_native_tools START"
)
print(
f"[{time.strftime('%H:%M:%S')}] pending_tool_calls before LLM: {len(self.state.pending_tool_calls)}"
)
print(
f"[{time.strftime('%H:%M:%S')}] Messages count: {len(self.state.messages)}"
)
print(f"{'=' * 60}")
# Clear pending tools - LLM will decide what to do next after reading
# the reflection prompt. It can either:
# 1. Return a final answer (string) if it has enough info
# 2. Return tool calls (possibly same ones, or different ones)
self.state.pending_tool_calls.clear()
enforce_rpm_limit(self.request_within_rpm_limit)
last_msg_content = (
self.state.messages[-1].get("content", "")
if self.state.messages
else ""
)
last_msg_preview = (
last_msg_content[:200] if last_msg_content else "(no content)"
)
print(
f"[{time.strftime('%H:%M:%S')}] Last message to LLM: {last_msg_preview}..."
)
# Call LLM with native tools
# Pass available_functions=None so the LLM returns tool_calls
# without executing them. The executor handles tool execution.
llm_start = time.time()
print(f"[{time.strftime('%H:%M:%S')}] LLM CALL START")
answer = get_llm_response(
llm=self.llm,
messages=list(self.state.messages),
@@ -404,15 +473,27 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
available_functions=None,
from_task=self.task,
from_agent=self.agent,
# response_model=self.response_model,
response_model=None,
executor_context=self,
)
llm_elapsed = time.time() - llm_start
print(
f"[{time.strftime('%H:%M:%S')}] LLM CALL END - took {llm_elapsed:.2f}s"
)
print(f"[{time.strftime('%H:%M:%S')}] Answer type: {type(answer).__name__}")
# Check if the response is a list of tool calls
if isinstance(answer, list) and answer and self._is_tool_call_list(answer):
# Store tool calls for sequential processing
self.state.pending_tool_calls = list(answer)
iteration_elapsed = time.time() - iteration_start
print(
f"[{time.strftime('%H:%M:%S')}] -> Routing to native_tool_calls ({len(answer)} tools)"
)
print(
f"[{time.strftime('%H:%M:%S')}] ITERATION total: {iteration_elapsed:.2f}s"
)
return "native_tool_calls"
# Text response - this is the final answer
@@ -424,6 +505,13 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
)
self._invoke_step_callback(self.state.current_answer)
self._append_message_to_state(answer)
iteration_elapsed = time.time() - iteration_start
print(
f"[{time.strftime('%H:%M:%S')}] -> FINAL ANSWER (string, len={len(answer)})"
)
print(
f"[{time.strftime('%H:%M:%S')}] ITERATION total: {iteration_elapsed:.2f}s"
)
return "native_finished"
# Unexpected response type, treat as final answer
@@ -434,6 +522,11 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
)
self._invoke_step_callback(self.state.current_answer)
self._append_message_to_state(str(answer))
iteration_elapsed = time.time() - iteration_start
print(f"[{time.strftime('%H:%M:%S')}] -> FINAL ANSWER (unexpected type)")
print(
f"[{time.strftime('%H:%M:%S')}] ITERATION total: {iteration_elapsed:.2f}s"
)
return "native_finished"
except Exception as e:
@@ -519,16 +612,20 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
@listen("native_tool_calls")
def execute_native_tool(self) -> Literal["native_tool_completed"]:
"""Execute a single native tool call and inject reasoning prompt.
"""Execute a SINGLE native tool call with reflection after.
Processes only the FIRST tool call from pending_tool_calls for
sequential execution with reflection after each tool.
Processes only the first tool from pending_tool_calls, then asks
the LLM if it can answer the task. Remaining tools stay in the queue
for potential execution on next iteration.
"""
if not self.state.pending_tool_calls:
return "native_tool_completed"
tool_call = self.state.pending_tool_calls[0]
self.state.pending_tool_calls = [] # Clear pending calls
# Pop just the first tool (leave the rest in queue for potential continuation)
tool_call = self.state.pending_tool_calls.pop(0)
print(
f"Executing 1 tool, {len(self.state.pending_tool_calls)} remaining in queue"
)
# Extract tool call info - handle OpenAI, Anthropic, and Gemini formats
if hasattr(tool_call, "function"):
@@ -556,6 +653,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
func_name = func_info.get("name", "") or tool_call.get("name", "")
func_args = func_info.get("arguments", "{}") or tool_call.get("input", {})
else:
# Unrecognized format - skip and try next
return "native_tool_completed"
# Append assistant message with single tool call
@@ -638,16 +736,41 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
color="green",
)
# Inject post-tool reasoning prompt to enforce analysis
reasoning_prompt = self._i18n.slice("post_tool_reasoning")
reasoning_message: LLMMessage = {
"role": "user",
"content": reasoning_prompt,
}
self.state.messages.append(reasoning_message)
# Only add reflection prompt if there are still pending tools
# If no pending tools, skip reflection - LLM will naturally continue
if self.state.pending_tool_calls:
print("--------------------------------")
print(
f"REFLECTION: {len(self.state.pending_tool_calls)} tools pending - asking LLM to decide"
)
print("--------------------------------")
reasoning_prompt = self._i18n.slice("post_tool_reasoning")
reasoning_message: LLMMessage = {
"role": "user",
"content": reasoning_prompt,
}
self.state.messages.append(reasoning_message)
else:
print("--------------------------------")
print("SKIPPING REFLECTION: No pending tools - LLM will continue naturally")
print("--------------------------------")
return "native_tool_completed"
def _extract_tool_name(self, tool_call: Any) -> str:
"""Extract tool name from various tool call formats."""
if hasattr(tool_call, "function"):
return tool_call.function.name
if hasattr(tool_call, "function_call") and tool_call.function_call:
return tool_call.function_call.name
if hasattr(tool_call, "name"):
return tool_call.name
if isinstance(tool_call, dict):
func_info = tool_call.get("function", {})
return func_info.get("name", "") or tool_call.get("name", "unknown")
return "unknown"
@router(execute_native_tool)
def increment_native_and_continue(self) -> Literal["initialized"]:
"""Increment iteration counter after native tool execution."""