updated reasoning

This commit is contained in:
João Moura
2025-05-29 20:50:38 -07:00
parent 2242545a2e
commit 3161a871b3
7 changed files with 168 additions and 72 deletions

View File

@@ -395,6 +395,41 @@ class Agent(BaseAgent):
else: else:
task_prompt = self._use_trained_data(task_prompt=task_prompt) task_prompt = self._use_trained_data(task_prompt=task_prompt)
if self.reasoning:
try:
from crewai.utilities.reasoning_handler import (
AgentReasoning,
AgentReasoningOutput,
)
reasoning_handler = AgentReasoning(
task=task,
agent=self,
extra_context=context or "",
)
reasoning_output: AgentReasoningOutput = reasoning_handler.handle_agent_reasoning()
plan_text = reasoning_output.plan.plan
internal_plan_msg = (
"### INTERNAL PLAN (do NOT reveal or repeat)\n" + plan_text
)
task_prompt = (
task_prompt
+ "\n\n"
+ internal_plan_msg
)
except Exception as e:
if hasattr(self, "_logger"):
self._logger.log(
"error", f"Error during reasoning process: {str(e)}"
)
else:
print(f"Error during reasoning process: {str(e)}")
try: try:
crewai_event_bus.emit( crewai_event_bus.emit(
self, self,

View File

@@ -220,6 +220,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
llm=self.llm, llm=self.llm,
callbacks=self.callbacks, callbacks=self.callbacks,
i18n=self._i18n, i18n=self._i18n,
task_description=getattr(self.task, "description", None),
expected_output=getattr(self.task, "expected_output", None),
) )
continue continue
else: else:
@@ -297,39 +299,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
or (hasattr(self, "crew") and getattr(self.crew, "verbose", False)), or (hasattr(self, "crew") and getattr(self.crew, "verbose", False)),
) )
def _summarize_messages(self) -> None:
messages_groups = []
for message in self.messages:
content = message["content"]
cut_size = self.llm.get_context_window_size()
for i in range(0, len(content), cut_size):
messages_groups.append({"content": content[i : i + cut_size]})
summarized_contents = []
for group in messages_groups:
summary = self.llm.call(
[
format_message_for_llm(
self._i18n.slice("summarizer_system_message"), role="system"
),
format_message_for_llm(
self._i18n.slice("summarize_instruction").format(
group=group["content"]
),
),
],
callbacks=self.callbacks,
)
summarized_contents.append({"content": str(summary)})
merged_summary = " ".join(content["content"] for content in summarized_contents)
self.messages = [
format_message_for_llm(
self._i18n.slice("summary").format(merged_summary=merged_summary)
)
]
def _handle_crew_training_output( def _handle_crew_training_output(
self, result: AgentFinish, human_feedback: Optional[str] = None self, result: AgentFinish, human_feedback: Optional[str] = None
) -> None: ) -> None:
@@ -470,6 +439,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
Returns: Returns:
bool: True if reasoning should be triggered, False otherwise. bool: True if reasoning should be triggered, False otherwise.
""" """
if self.iterations == 0:
return False
if not hasattr(self.agent, "reasoning") or not self.agent.reasoning: if not hasattr(self.agent, "reasoning") or not self.agent.reasoning:
return False return False
@@ -561,13 +533,15 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
iteration_messages=self.messages iteration_messages=self.messages
) )
self._append_message( updated_plan_msg = (
self._i18n.retrieve("reasoning", "mid_execution_reasoning_update").format( self._i18n.retrieve("reasoning", "mid_execution_reasoning_update").format(
plan=reasoning_output.plan.plan plan=reasoning_output.plan.plan
), ) +
role="assistant", "\n\nRemember: strictly follow the updated plan above and ensure the final answer fully meets the EXPECTED OUTPUT criteria."
) )
self._append_message(updated_plan_msg, role="assistant")
self.steps_since_reasoning = 0 self.steps_since_reasoning = 0
except Exception as e: except Exception as e:

View File

@@ -527,10 +527,10 @@ class Task(BaseModel):
def prompt(self) -> str: def prompt(self) -> str:
"""Generates the task prompt with optional markdown formatting. """Generates the task prompt with optional markdown formatting.
When the markdown attribute is True, instructions for formatting the When the markdown attribute is True, instructions for formatting the
response in Markdown syntax will be added to the prompt. response in Markdown syntax will be added to the prompt.
Returns: Returns:
str: The formatted prompt string containing the task description, str: The formatted prompt string containing the task description,
expected output, and optional markdown formatting instructions. expected output, and optional markdown formatting instructions.
@@ -541,7 +541,7 @@ class Task(BaseModel):
expected_output=self.expected_output expected_output=self.expected_output
) )
tasks_slices = [self.description, output] tasks_slices = [self.description, output]
if self.markdown: if self.markdown:
markdown_instruction = """Your final answer MUST be formatted in Markdown syntax. markdown_instruction = """Your final answer MUST be formatted in Markdown syntax.
Follow these guidelines: Follow these guidelines:
@@ -550,7 +550,8 @@ Follow these guidelines:
- Use * for italic text - Use * for italic text
- Use - or * for bullet points - Use - or * for bullet points
- Use `code` for inline code - Use `code` for inline code
- Use ```language for code blocks""" - Use ```language for code blocks
- Don't start your answer with a code block"""
tasks_slices.append(markdown_instruction) tasks_slices.append(markdown_instruction)
return "\n".join(tasks_slices) return "\n".join(tasks_slices)

View File

@@ -293,6 +293,8 @@ def handle_context_length(
llm: Any, llm: Any,
callbacks: List[Any], callbacks: List[Any],
i18n: Any, i18n: Any,
task_description: Optional[str] = None,
expected_output: Optional[str] = None,
) -> None: ) -> None:
"""Handle context length exceeded by either summarizing or raising an error. """Handle context length exceeded by either summarizing or raising an error.
@@ -303,13 +305,22 @@ def handle_context_length(
llm: LLM instance for summarization llm: LLM instance for summarization
callbacks: List of callbacks for LLM callbacks: List of callbacks for LLM
i18n: I18N instance for messages i18n: I18N instance for messages
task_description: Optional original task description
expected_output: Optional expected output
""" """
if respect_context_window: if respect_context_window:
printer.print( printer.print(
content="Context length exceeded. Summarizing content to fit the model context window. Might take a while...", content="Context length exceeded. Summarizing content to fit the model context window. Might take a while...",
color="yellow", color="yellow",
) )
summarize_messages(messages, llm, callbacks, i18n) summarize_messages(
messages,
llm,
callbacks,
i18n,
task_description=task_description,
expected_output=expected_output,
)
else: else:
printer.print( printer.print(
content="Context length exceeded. Consider using smaller text or RAG tools from crewai_tools.", content="Context length exceeded. Consider using smaller text or RAG tools from crewai_tools.",
@@ -325,6 +336,8 @@ def summarize_messages(
llm: Any, llm: Any,
callbacks: List[Any], callbacks: List[Any],
i18n: Any, i18n: Any,
task_description: Optional[str] = None,
expected_output: Optional[str] = None,
) -> None: ) -> None:
"""Summarize messages to fit within context window. """Summarize messages to fit within context window.
@@ -333,6 +346,8 @@ def summarize_messages(
llm: LLM instance for summarization llm: LLM instance for summarization
callbacks: List of callbacks for LLM callbacks: List of callbacks for LLM
i18n: I18N instance for messages i18n: I18N instance for messages
task_description: Optional original task description
expected_output: Optional expected output
""" """
messages_string = " ".join([message["content"] for message in messages]) messages_string = " ".join([message["content"] for message in messages])
messages_groups = [] messages_groups = []
@@ -365,12 +380,19 @@ def summarize_messages(
merged_summary = " ".join(content["content"] for content in summarized_contents) merged_summary = " ".join(content["content"] for content in summarized_contents)
# Build the summary message and optionally inject the task reminder.
summary_message = i18n.slice("summary").format(merged_summary=merged_summary)
if task_description or expected_output:
summary_message += "\n\n" # blank line before the reminder
if task_description:
summary_message += f"Original task: {task_description}\n"
if expected_output:
summary_message += f"Expected output: {expected_output}"
# Replace the conversation with the new summary message.
messages.clear() messages.clear()
messages.append( messages.append(format_message_for_llm(summary_message))
format_message_for_llm(
i18n.slice("summary").format(merged_summary=merged_summary)
)
)
def show_agent_logs( def show_agent_logs(

View File

@@ -110,6 +110,7 @@ class EventListener(BaseEventListener):
event.crew_name or "Crew", event.crew_name or "Crew",
source.id, source.id,
"completed", "completed",
final_result=final_string_output,
) )
@crewai_event_bus.on(CrewKickoffFailedEvent) @crewai_event_bus.on(CrewKickoffFailedEvent)

View File

@@ -1,4 +1,5 @@
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import threading
from rich.console import Console from rich.console import Console
from rich.panel import Panel from rich.panel import Panel
@@ -19,9 +20,13 @@ class ConsoleFormatter:
current_reasoning_branch: Optional[Tree] = None # Track reasoning status current_reasoning_branch: Optional[Tree] = None # Track reasoning status
current_llm_tool_tree: Optional[Tree] = None current_llm_tool_tree: Optional[Tree] = None
current_adaptive_decision_branch: Optional[Tree] = None # Track last adaptive decision branch current_adaptive_decision_branch: Optional[Tree] = None # Track last adaptive decision branch
# Spinner support # Spinner support ---------------------------------------------------
_spinner_frames = ["", "", "", "", "", "", "", "", "", ""] _spinner_frames = ["", "", "", "", "", "", "", "", "", ""]
_spinner_index: int = 0 _spinner_index: int = 0
_spinner_branches: Dict[Tree, tuple[str, str, str]] = {} # branch -> (icon, name, style)
_spinner_thread: Optional[threading.Thread] = None
_stop_spinner_event: Optional[threading.Event] = None
_spinner_running: bool = False
def __init__(self, verbose: bool = False): def __init__(self, verbose: bool = False):
self.console = Console(width=None) self.console = Console(width=None)
@@ -53,6 +58,8 @@ class ConsoleFormatter:
for label, value in fields.items(): for label, value in fields.items():
content.append(f"{label}: ", style="white") content.append(f"{label}: ", style="white")
if label == "Result":
content.append("\n")
content.append( content.append(
f"{value}\n", style=fields.get(f"{label}_style", status_style) f"{value}\n", style=fields.get(f"{label}_style", status_style)
) )
@@ -142,6 +149,7 @@ class ConsoleFormatter:
crew_name: str, crew_name: str,
source_id: str, source_id: str,
status: str = "completed", status: str = "completed",
final_result: Optional[str] = None,
) -> None: ) -> None:
"""Handle crew tree updates with consistent formatting.""" """Handle crew tree updates with consistent formatting."""
if not self.verbose or tree is None: if not self.verbose or tree is None:
@@ -167,11 +175,18 @@ class ConsoleFormatter:
style, style,
) )
# Prepare additional fields for the completion panel
additional_fields: Dict[str, Any] = {"ID": source_id}
# Include the final result if provided and the status is completed
if status == "completed" and final_result is not None:
additional_fields["Result"] = final_result
content = self.create_status_content( content = self.create_status_content(
content_title, content_title,
crew_name or "Crew", crew_name or "Crew",
style, style,
ID=source_id, **additional_fields,
) )
self.print_panel(content, title, style) self.print_panel(content, title, style)
@@ -227,7 +242,7 @@ class ConsoleFormatter:
# and tool branches so that any upcoming Reasoning / Tool logs attach # and tool branches so that any upcoming Reasoning / Tool logs attach
# to the correct task. # to the correct task.
self.current_agent_branch = None self.current_agent_branch = None
self.current_reasoning_branch = None # Keep current_reasoning_branch; reasoning may still be in progress
self.current_tool_branch = None self.current_tool_branch = None
return task_branch return task_branch
@@ -283,7 +298,10 @@ class ConsoleFormatter:
self.current_task_branch = None self.current_task_branch = None
self.current_agent_branch = None self.current_agent_branch = None
self.current_tool_branch = None self.current_tool_branch = None
self.current_reasoning_branch = None # Ensure spinner is stopped if reasoning branch exists
if self.current_reasoning_branch is not None:
self._unregister_spinner_branch(self.current_reasoning_branch)
self.current_reasoning_branch = None
def create_agent_branch( def create_agent_branch(
self, task_branch: Optional[Tree], agent_role: str, crew_tree: Optional[Tree] self, task_branch: Optional[Tree], agent_role: str, crew_tree: Optional[Tree]
@@ -521,20 +539,20 @@ class ConsoleFormatter:
# Update tool usage count # Update tool usage count
self.tool_usage_counts[tool_name] = self.tool_usage_counts.get(tool_name, 0) + 1 self.tool_usage_counts[tool_name] = self.tool_usage_counts.get(tool_name, 0) + 1
# Find or create tool node # Always create a new branch for each tool invocation so that previous
tool_branch = self.current_tool_branch # tool usages remain visible in the tree.
if tool_branch is None: tool_branch = branch_to_use.add("")
tool_branch = branch_to_use.add("") self.current_tool_branch = tool_branch
self.current_tool_branch = tool_branch
# Update label with current count # Update label with current count
spinner = self._next_spinner() spinner_char = self._next_spinner()
self.update_tree_label( self.update_tree_label(
tool_branch, tool_branch,
f"🔧 {spinner}", f"🔧 {spinner_char}",
f"Using {tool_name} ({self.tool_usage_counts[tool_name]})", f"Using {tool_name} ({self.tool_usage_counts[tool_name]})",
"yellow", "yellow",
) )
self._register_spinner_branch(tool_branch, "🔧", f"Using {tool_name} ({self.tool_usage_counts[tool_name]})", "yellow")
# Print updated tree immediately # Print updated tree immediately
self.print(tree_to_use) self.print(tree_to_use)
@@ -560,13 +578,11 @@ class ConsoleFormatter:
# Update the existing tool node's label # Update the existing tool node's label
self.update_tree_label( self.update_tree_label(
tool_branch, tool_branch,
"🔧", "🔧",
f"Used {tool_name} ({self.tool_usage_counts[tool_name]})", f"Used {tool_name} ({self.tool_usage_counts[tool_name]})",
"green", "green",
) )
self._unregister_spinner_branch(tool_branch)
# Clear the current tool branch as we're done with it
self.current_tool_branch = None
# Only print if we have a valid tree and the tool node is still in it # Only print if we have a valid tree and the tool node is still in it
if isinstance(tree_to_use, Tree) and tool_branch in tree_to_use.children: if isinstance(tree_to_use, Tree) and tool_branch in tree_to_use.children:
@@ -633,8 +649,9 @@ class ConsoleFormatter:
# Only add thinking status if we don't have a current tool branch # Only add thinking status if we don't have a current tool branch
if self.current_tool_branch is None: if self.current_tool_branch is None:
tool_branch = branch_to_use.add("") tool_branch = branch_to_use.add("")
spinner = self._next_spinner() spinner_char = self._next_spinner()
self.update_tree_label(tool_branch, f"🧠 {spinner}", "Thinking...", "blue") self.update_tree_label(tool_branch, f"🧠 {spinner_char}", "Thinking...", "blue")
self._register_spinner_branch(tool_branch, "🧠", "Thinking...", "blue")
self.current_tool_branch = tool_branch self.current_tool_branch = tool_branch
self.print(tree_to_use) self.print(tree_to_use)
self.print() self.print()
@@ -668,6 +685,8 @@ class ConsoleFormatter:
for parent in parents: for parent in parents:
if isinstance(parent, Tree) and tool_branch in parent.children: if isinstance(parent, Tree) and tool_branch in parent.children:
parent.children.remove(tool_branch) parent.children.remove(tool_branch)
# Stop spinner for the thinking branch before removing
self._unregister_spinner_branch(tool_branch)
removed = True removed = True
break break
@@ -1161,8 +1180,7 @@ class ConsoleFormatter:
# Build label text depending on attempt and whether it's mid-execution # Build label text depending on attempt and whether it's mid-execution
if current_step is not None: if current_step is not None:
trigger_text = f" ({reasoning_trigger})" if reasoning_trigger else "" status_text = "Mid-Execution Reasoning"
status_text = f"Mid-Execution Reasoning{trigger_text}"
else: else:
status_text = ( status_text = (
f"Reasoning (Attempt {attempt})" if attempt > 1 else "Reasoning..." f"Reasoning (Attempt {attempt})" if attempt > 1 else "Reasoning..."
@@ -1170,8 +1188,11 @@ class ConsoleFormatter:
# ⠋ is the first frame of a braille spinner visually hints progress even # ⠋ is the first frame of a braille spinner visually hints progress even
# without true animation. # without true animation.
spinner = self._next_spinner() spinner_char = self._next_spinner()
self.update_tree_label(reasoning_branch, f"🧠 {spinner}", status_text, "yellow") self.update_tree_label(reasoning_branch, f"🧠 {spinner_char}", status_text, "yellow")
# Register branch for continuous spinner
self._register_spinner_branch(reasoning_branch, "🧠", status_text, "yellow")
self.print(tree_to_use) self.print(tree_to_use)
self.print() self.print()
@@ -1199,7 +1220,8 @@ class ConsoleFormatter:
or crew_tree or crew_tree
) )
style = "green" if ready else "yellow" # Completed reasoning should always display in green.
style = "green"
# Build duration part separately for cleaner formatting # Build duration part separately for cleaner formatting
duration_part = f"{duration_seconds:.2f}s" if duration_seconds > 0 else "" duration_part = f"{duration_seconds:.2f}s" if duration_seconds > 0 else ""
@@ -1356,3 +1378,43 @@ class ConsoleFormatter:
frame = self._spinner_frames[self._spinner_index] frame = self._spinner_frames[self._spinner_index]
self._spinner_index = (self._spinner_index + 1) % len(self._spinner_frames) self._spinner_index = (self._spinner_index + 1) % len(self._spinner_frames)
return frame return frame
def _register_spinner_branch(self, branch: Tree, icon: str, name: str, style: str):
"""Start animating spinner for given branch."""
self._spinner_branches[branch] = (icon, name, style)
if not self._spinner_running:
self._start_spinner_thread()
def _unregister_spinner_branch(self, branch: Optional[Tree]):
if branch is None:
return
self._spinner_branches.pop(branch, None)
if not self._spinner_branches:
self._stop_spinner_thread()
def _start_spinner_thread(self):
if self._spinner_running:
return
self._stop_spinner_event = threading.Event()
self._spinner_thread = threading.Thread(target=self._spinner_loop, daemon=True)
self._spinner_thread.start()
self._spinner_running = True
def _stop_spinner_thread(self):
if self._stop_spinner_event:
self._stop_spinner_event.set()
self._spinner_running = False
def _spinner_loop(self):
import time
while self._stop_spinner_event and not self._stop_spinner_event.is_set():
if self._live and self._spinner_branches:
for branch, (icon, name, style) in list(self._spinner_branches.items()):
spinner_char = self._next_spinner()
self.update_tree_label(branch, f"{icon} {spinner_char}", name, style)
# Refresh live view
try:
self._live.update(self._live.renderable, refresh=True)
except Exception:
pass
time.sleep(0.15)

View File

@@ -38,7 +38,7 @@ class AgentReasoning:
Handles the agent reasoning process, enabling an agent to reflect and create a plan Handles the agent reasoning process, enabling an agent to reflect and create a plan
before executing a task. before executing a task.
""" """
def __init__(self, task: Task, agent: Agent): def __init__(self, task: Task, agent: Agent, extra_context: str | None = None):
if not task or not agent: if not task or not agent:
raise ValueError("Both task and agent must be provided.") raise ValueError("Both task and agent must be provided.")
self.task = task self.task = task
@@ -46,6 +46,7 @@ class AgentReasoning:
self.llm = cast(LLM, agent.llm) self.llm = cast(LLM, agent.llm)
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
self.i18n = I18N() self.i18n = I18N()
self.extra_context = extra_context or ""
def handle_agent_reasoning(self) -> AgentReasoningOutput: def handle_agent_reasoning(self) -> AgentReasoningOutput:
""" """
@@ -323,7 +324,7 @@ class AgentReasoning:
role=self.agent.role, role=self.agent.role,
goal=self.agent.goal, goal=self.agent.goal,
backstory=self.__get_agent_backstory(), backstory=self.__get_agent_backstory(),
description=self.task.description, description=self.task.description + (f"\n\nContext:\n{self.extra_context}" if self.extra_context else ""),
expected_output=self.task.expected_output, expected_output=self.task.expected_output,
tools=available_tools tools=available_tools
) )
@@ -547,7 +548,7 @@ class AgentReasoning:
recent_messages += f"{role.upper()}: {content[:200]}...\n\n" recent_messages += f"{role.upper()}: {content[:200]}...\n\n"
return self.i18n.retrieve("reasoning", "mid_execution_reasoning").format( return self.i18n.retrieve("reasoning", "mid_execution_reasoning").format(
description=self.task.description, description=self.task.description + (f"\n\nContext:\n{self.extra_context}" if self.extra_context else ""),
expected_output=self.task.expected_output, expected_output=self.task.expected_output,
current_steps=current_steps, current_steps=current_steps,
tools_used=tools_used_str, tools_used=tools_used_str,
@@ -681,7 +682,7 @@ class AgentReasoning:
) )
context_prompt = self.i18n.retrieve("reasoning", "adaptive_reasoning_context").format( context_prompt = self.i18n.retrieve("reasoning", "adaptive_reasoning_context").format(
description=self.task.description, description=self.task.description + (f"\n\nContext:\n{self.extra_context}" if self.extra_context else ""),
expected_output=self.task.expected_output, expected_output=self.task.expected_output,
current_steps=current_steps, current_steps=current_steps,
tools_used=tools_used_str, tools_used=tools_used_str,