mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-25 08:08:14 +00:00
Compare commits
3 Commits
devin/1768
...
devin/1740
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
331f7a9fe0 | ||
|
|
a651d7ddd3 | ||
|
|
1225071e00 |
@@ -249,6 +249,7 @@ class Agent(BaseAgent):
|
|||||||
"tool_names": self.agent_executor.tools_names,
|
"tool_names": self.agent_executor.tools_names,
|
||||||
"tools": self.agent_executor.tools_description,
|
"tools": self.agent_executor.tools_description,
|
||||||
"ask_for_human_input": task.human_input,
|
"ask_for_human_input": task.human_input,
|
||||||
|
"max_dialogue_rounds": task.max_dialogue_rounds,
|
||||||
}
|
}
|
||||||
)["output"]
|
)["output"]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -94,10 +94,20 @@ class CrewAgentExecutorMixin:
|
|||||||
print(f"Failed to add to long term memory: {e}")
|
print(f"Failed to add to long term memory: {e}")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _ask_human_input(self, final_answer: str) -> str:
|
def _ask_human_input(self, final_answer: str, current_round: int = 1, max_rounds: int = 10) -> str:
|
||||||
"""Prompt human input with mode-appropriate messaging."""
|
"""Prompt human input with mode-appropriate messaging.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
final_answer: The final answer from the agent
|
||||||
|
current_round: The current dialogue round (default: 1)
|
||||||
|
max_rounds: Maximum number of dialogue rounds (default: 10)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The user's feedback
|
||||||
|
"""
|
||||||
|
round_info = f"\033[1m\033[93mRound {current_round}/{max_rounds}\033[00m"
|
||||||
self._printer.print(
|
self._printer.print(
|
||||||
content=f"\033[1m\033[95m ## Final Result:\033[00m \033[92m{final_answer}\033[00m"
|
content=f"\033[1m\033[95m ## Result {round_info}:\033[00m \033[92m{final_answer}\033[00m"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training mode prompt (single iteration)
|
# Training mode prompt (single iteration)
|
||||||
@@ -113,7 +123,7 @@ class CrewAgentExecutorMixin:
|
|||||||
else:
|
else:
|
||||||
prompt = (
|
prompt = (
|
||||||
"\n\n=====\n"
|
"\n\n=====\n"
|
||||||
"## HUMAN FEEDBACK: Provide feedback on the Final Result and Agent's actions.\n"
|
f"## HUMAN FEEDBACK (Round {current_round}/{max_rounds}): Provide feedback on the Result and Agent's actions.\n"
|
||||||
"Please follow these guidelines:\n"
|
"Please follow these guidelines:\n"
|
||||||
" - If you are happy with the result, simply hit Enter without typing anything.\n"
|
" - If you are happy with the result, simply hit Enter without typing anything.\n"
|
||||||
" - Otherwise, provide specific improvement requests.\n"
|
" - Otherwise, provide specific improvement requests.\n"
|
||||||
|
|||||||
@@ -103,6 +103,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
self._show_start_logs()
|
self._show_start_logs()
|
||||||
|
|
||||||
self.ask_for_human_input = bool(inputs.get("ask_for_human_input", False))
|
self.ask_for_human_input = bool(inputs.get("ask_for_human_input", False))
|
||||||
|
max_rounds = int(inputs.get("max_dialogue_rounds", 10))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
formatted_answer = self._invoke_loop()
|
formatted_answer = self._invoke_loop()
|
||||||
@@ -121,7 +122,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
if self.ask_for_human_input:
|
if self.ask_for_human_input:
|
||||||
formatted_answer = self._handle_human_feedback(formatted_answer)
|
formatted_answer = self._handle_human_feedback(formatted_answer, max_rounds)
|
||||||
|
|
||||||
self._create_short_term_memory(formatted_answer)
|
self._create_short_term_memory(formatted_answer)
|
||||||
self._create_long_term_memory(formatted_answer)
|
self._create_long_term_memory(formatted_answer)
|
||||||
@@ -524,21 +525,22 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
prompt = prompt.rstrip()
|
prompt = prompt.rstrip()
|
||||||
return {"role": role, "content": prompt}
|
return {"role": role, "content": prompt}
|
||||||
|
|
||||||
def _handle_human_feedback(self, formatted_answer: AgentFinish) -> AgentFinish:
|
def _handle_human_feedback(self, formatted_answer: AgentFinish, max_rounds: int = 10) -> AgentFinish:
|
||||||
"""Handle human feedback with different flows for training vs regular use.
|
"""Handle human feedback with different flows for training vs regular use.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
formatted_answer: The initial AgentFinish result to get feedback on
|
formatted_answer: The initial AgentFinish result to get feedback on
|
||||||
|
max_rounds: Maximum number of dialogue rounds (default: 10)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
AgentFinish: The final answer after processing feedback
|
AgentFinish: The final answer after processing feedback
|
||||||
"""
|
"""
|
||||||
human_feedback = self._ask_human_input(formatted_answer.output)
|
human_feedback = self._ask_human_input(formatted_answer.output, 1, max_rounds)
|
||||||
|
|
||||||
if self._is_training_mode():
|
if self._is_training_mode():
|
||||||
return self._handle_training_feedback(formatted_answer, human_feedback)
|
return self._handle_training_feedback(formatted_answer, human_feedback)
|
||||||
|
|
||||||
return self._handle_regular_feedback(formatted_answer, human_feedback)
|
return self._handle_regular_feedback(formatted_answer, human_feedback, max_rounds)
|
||||||
|
|
||||||
def _is_training_mode(self) -> bool:
|
def _is_training_mode(self) -> bool:
|
||||||
"""Check if crew is in training mode."""
|
"""Check if crew is in training mode."""
|
||||||
@@ -560,19 +562,33 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
return improved_answer
|
return improved_answer
|
||||||
|
|
||||||
def _handle_regular_feedback(
|
def _handle_regular_feedback(
|
||||||
self, current_answer: AgentFinish, initial_feedback: str
|
self, current_answer: AgentFinish, initial_feedback: str, max_rounds: int = 10
|
||||||
) -> AgentFinish:
|
) -> AgentFinish:
|
||||||
"""Process feedback for regular use with potential multiple iterations."""
|
"""Process feedback for regular use with potential multiple iterations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
current_answer: The initial AgentFinish result to get feedback on
|
||||||
|
initial_feedback: The initial feedback from the user
|
||||||
|
max_rounds: Maximum number of dialogue rounds (default: 10)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AgentFinish: The final answer after processing feedback
|
||||||
|
"""
|
||||||
|
if max_rounds < 1:
|
||||||
|
raise ValueError("max_rounds must be positive")
|
||||||
|
|
||||||
feedback = initial_feedback
|
feedback = initial_feedback
|
||||||
answer = current_answer
|
answer = current_answer
|
||||||
|
current_round = 1
|
||||||
|
|
||||||
while self.ask_for_human_input:
|
while self.ask_for_human_input and current_round <= max_rounds:
|
||||||
# If the user provides a blank response, assume they are happy with the result
|
# If the user provides a blank response, assume they are happy with the result
|
||||||
if feedback.strip() == "":
|
if feedback.strip() == "":
|
||||||
self.ask_for_human_input = False
|
self.ask_for_human_input = False
|
||||||
else:
|
else:
|
||||||
answer = self._process_feedback_iteration(feedback)
|
answer = self._process_feedback_iteration(feedback)
|
||||||
feedback = self._ask_human_input(answer.output)
|
feedback = self._ask_human_input(answer.output, current_round, max_rounds)
|
||||||
|
current_round += 1
|
||||||
|
|
||||||
return answer
|
return answer
|
||||||
|
|
||||||
|
|||||||
@@ -125,6 +125,12 @@ class Task(BaseModel):
|
|||||||
description="Whether the task should have a human review the final answer of the agent",
|
description="Whether the task should have a human review the final answer of the agent",
|
||||||
default=False,
|
default=False,
|
||||||
)
|
)
|
||||||
|
max_dialogue_rounds: int = Field(
|
||||||
|
default=10,
|
||||||
|
description="Maximum number of dialogue rounds for human input",
|
||||||
|
ge=1, # Ensures positive integer
|
||||||
|
examples=[5, 10, 15],
|
||||||
|
)
|
||||||
converter_cls: Optional[Type[Converter]] = Field(
|
converter_cls: Optional[Type[Converter]] = Field(
|
||||||
description="A converter class used to export structured output",
|
description="A converter class used to export structured output",
|
||||||
default=None,
|
default=None,
|
||||||
|
|||||||
@@ -1206,6 +1206,7 @@ def test_agent_max_retry_limit():
|
|||||||
"tool_names": "",
|
"tool_names": "",
|
||||||
"tools": "",
|
"tools": "",
|
||||||
"ask_for_human_input": True,
|
"ask_for_human_input": True,
|
||||||
|
"max_dialogue_rounds": 10,
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
mock.call(
|
mock.call(
|
||||||
@@ -1214,6 +1215,7 @@ def test_agent_max_retry_limit():
|
|||||||
"tool_names": "",
|
"tool_names": "",
|
||||||
"tools": "",
|
"tools": "",
|
||||||
"ask_for_human_input": True,
|
"ask_for_human_input": True,
|
||||||
|
"max_dialogue_rounds": 10,
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|||||||
77
tests/agents/test_multi_round_dialogue.py
Normal file
77
tests/agents/test_multi_round_dialogue.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
import unittest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from langchain_core.agents import AgentFinish
|
||||||
|
|
||||||
|
from crewai.task import Task
|
||||||
|
|
||||||
|
|
||||||
|
class TestMultiRoundDialogue(unittest.TestCase):
|
||||||
|
"""Test the multi-round dialogue functionality."""
|
||||||
|
|
||||||
|
def test_task_max_dialogue_rounds_default(self):
|
||||||
|
"""Test that Task has a default max_dialogue_rounds of 10."""
|
||||||
|
# Create a task with default max_dialogue_rounds
|
||||||
|
task = Task(
|
||||||
|
description="Test task",
|
||||||
|
expected_output="Test output",
|
||||||
|
human_input=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the default value
|
||||||
|
self.assertEqual(task.max_dialogue_rounds, 10)
|
||||||
|
|
||||||
|
def test_task_max_dialogue_rounds_custom(self):
|
||||||
|
"""Test that Task accepts a custom max_dialogue_rounds."""
|
||||||
|
# Create a task with custom max_dialogue_rounds
|
||||||
|
task = Task(
|
||||||
|
description="Test task",
|
||||||
|
expected_output="Test output",
|
||||||
|
human_input=True,
|
||||||
|
max_dialogue_rounds=5
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the custom value
|
||||||
|
self.assertEqual(task.max_dialogue_rounds, 5)
|
||||||
|
|
||||||
|
def test_task_max_dialogue_rounds_validation(self):
|
||||||
|
"""Test that Task validates max_dialogue_rounds as a positive integer."""
|
||||||
|
# Create a task with invalid max_dialogue_rounds
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
task = Task(
|
||||||
|
description="Test task",
|
||||||
|
expected_output="Test output",
|
||||||
|
human_input=True,
|
||||||
|
max_dialogue_rounds=0
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_handle_regular_feedback_rounds(self):
|
||||||
|
"""Test that _handle_regular_feedback correctly handles multiple rounds."""
|
||||||
|
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||||
|
|
||||||
|
# Create a simple mock executor
|
||||||
|
executor = MagicMock()
|
||||||
|
executor.ask_for_human_input = True
|
||||||
|
executor._ask_human_input = MagicMock(side_effect=["Feedback", ""])
|
||||||
|
executor._process_feedback_iteration = MagicMock(return_value=MagicMock())
|
||||||
|
|
||||||
|
# Create a sample initial answer
|
||||||
|
initial_answer = MagicMock()
|
||||||
|
|
||||||
|
# Call the method directly
|
||||||
|
CrewAgentExecutor._handle_regular_feedback(
|
||||||
|
executor,
|
||||||
|
initial_answer,
|
||||||
|
"Initial feedback",
|
||||||
|
max_rounds=3
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the correct number of iterations occurred
|
||||||
|
# First call for initial feedback, second call for empty feedback to end loop
|
||||||
|
self.assertEqual(executor._ask_human_input.call_count, 2)
|
||||||
|
# The _process_feedback_iteration is called for the initial feedback and the first round
|
||||||
|
self.assertEqual(executor._process_feedback_iteration.call_count, 2)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user