Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-09 08:08:32 +00:00)

feat: add async execution support to agent executor
@@ -28,6 +28,7 @@ from crewai.hooks.llm_hooks import (
     get_before_llm_call_hooks,
 )
 from crewai.utilities.agent_utils import (
+    aget_llm_response,
     enforce_rpm_limit,
     format_message_for_llm,
     get_llm_response,
@@ -43,7 +44,10 @@ from crewai.utilities.agent_utils import (
 from crewai.utilities.constants import TRAINING_DATA_FILE
 from crewai.utilities.i18n import I18N, get_i18n
 from crewai.utilities.printer import Printer
-from crewai.utilities.tool_utils import execute_tool_and_check_finality
+from crewai.utilities.tool_utils import (
+    aexecute_tool_and_check_finality,
+    execute_tool_and_check_finality,
+)
 from crewai.utilities.training_handler import CrewTrainingHandler
 
 
@@ -134,8 +138,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         self.messages: list[LLMMessage] = []
         self.iterations = 0
         self.log_error_after = 3
-        self.before_llm_call_hooks: list[Callable] = []
-        self.after_llm_call_hooks: list[Callable] = []
+        self.before_llm_call_hooks: list[Callable[..., Any]] = []
+        self.after_llm_call_hooks: list[Callable[..., Any]] = []
         self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
         self.after_llm_call_hooks.extend(get_after_llm_call_hooks())
         if self.llm:
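
A note on the typing change above: a bare `Callable` is treated as `Callable[..., Any]` anyway, but strict type checkers (e.g. mypy with `disallow_any_generics`) reject the unparameterized form, so this spells the permissive signature out without changing runtime behavior. A hedged sketch of what the annotation admits — the hook name and arguments below are hypothetical; the diff only shows that the lists accept arbitrary callables:

    from typing import Any

    # Hypothetical hook, for illustration only: the arguments crewAI
    # actually passes to before/after LLM-call hooks are not shown here.
    def log_before_llm_call(*args: Any, **kwargs: Any) -> None:
        print("about to call the LLM")

    # Any callable satisfies list[Callable[..., Any]]:
    executor.before_llm_call_hooks.append(log_before_llm_call)
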
@@ -312,6 +316,154 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         self._show_logs(formatted_answer)
         return formatted_answer
 
+    async def ainvoke(self, inputs: dict[str, Any]) -> dict[str, Any]:
+        """Execute the agent asynchronously with given inputs.
+
+        Args:
+            inputs: Input dictionary containing prompt variables.
+
+        Returns:
+            Dictionary with agent output.
+        """
+        if "system" in self.prompt:
+            system_prompt = self._format_prompt(
+                cast(str, self.prompt.get("system", "")), inputs
+            )
+            user_prompt = self._format_prompt(
+                cast(str, self.prompt.get("user", "")), inputs
+            )
+            self.messages.append(format_message_for_llm(system_prompt, role="system"))
+            self.messages.append(format_message_for_llm(user_prompt))
+        else:
+            user_prompt = self._format_prompt(self.prompt.get("prompt", ""), inputs)
+            self.messages.append(format_message_for_llm(user_prompt))
+
+        self._show_start_logs()
+
+        self.ask_for_human_input = bool(inputs.get("ask_for_human_input", False))
+
+        try:
+            formatted_answer = await self._ainvoke_loop()
+        except AssertionError:
+            self._printer.print(
+                content="Agent failed to reach a final answer. This is likely a bug - please report it.",
+                color="red",
+            )
+            raise
+        except Exception as e:
+            handle_unknown_error(self._printer, e)
+            raise
+
+        if self.ask_for_human_input:
+            formatted_answer = self._handle_human_feedback(formatted_answer)
+
+        self._create_short_term_memory(formatted_answer)
+        self._create_long_term_memory(formatted_answer)
+        self._create_external_memory(formatted_answer)
+        return {"output": formatted_answer.output}
+
+    async def _ainvoke_loop(self) -> AgentFinish:
+        """Execute agent loop asynchronously until completion.
+
+        Returns:
+            Final answer from the agent.
+        """
+        formatted_answer = None
+        while not isinstance(formatted_answer, AgentFinish):
+            try:
+                if has_reached_max_iterations(self.iterations, self.max_iter):
+                    formatted_answer = handle_max_iterations_exceeded(
+                        formatted_answer,
+                        printer=self._printer,
+                        i18n=self._i18n,
+                        messages=self.messages,
+                        llm=self.llm,
+                        callbacks=self.callbacks,
+                    )
+                    break
+
+                enforce_rpm_limit(self.request_within_rpm_limit)
+
+                answer = await aget_llm_response(
+                    llm=self.llm,
+                    messages=self.messages,
+                    callbacks=self.callbacks,
+                    printer=self._printer,
+                    from_task=self.task,
+                    from_agent=self.agent,
+                    response_model=self.response_model,
+                    executor_context=self,
+                )
+                formatted_answer = process_llm_response(answer, self.use_stop_words)  # type: ignore[assignment]
+
+                if isinstance(formatted_answer, AgentAction):
+                    fingerprint_context = {}
+                    if (
+                        self.agent
+                        and hasattr(self.agent, "security_config")
+                        and hasattr(self.agent.security_config, "fingerprint")
+                    ):
+                        fingerprint_context = {
+                            "agent_fingerprint": str(
+                                self.agent.security_config.fingerprint
+                            )
+                        }
+
+                    tool_result = await aexecute_tool_and_check_finality(
+                        agent_action=formatted_answer,
+                        fingerprint_context=fingerprint_context,
+                        tools=self.tools,
+                        i18n=self._i18n,
+                        agent_key=self.agent.key if self.agent else None,
+                        agent_role=self.agent.role if self.agent else None,
+                        tools_handler=self.tools_handler,
+                        task=self.task,
+                        agent=self.agent,
+                        function_calling_llm=self.function_calling_llm,
+                        crew=self.crew,
+                    )
+                    formatted_answer = self._handle_agent_action(
+                        formatted_answer, tool_result
+                    )
+
+                self._invoke_step_callback(formatted_answer)  # type: ignore[arg-type]
+                self._append_message(formatted_answer.text)  # type: ignore[union-attr,attr-defined]
+
+            except OutputParserError as e:
+                formatted_answer = handle_output_parser_exception(  # type: ignore[assignment]
+                    e=e,
+                    messages=self.messages,
+                    iterations=self.iterations,
+                    log_error_after=self.log_error_after,
+                    printer=self._printer,
+                )
+
+            except Exception as e:
+                if e.__class__.__module__.startswith("litellm"):
+                    raise e
+                if is_context_length_exceeded(e):
+                    handle_context_length(
+                        respect_context_window=self.respect_context_window,
+                        printer=self._printer,
+                        messages=self.messages,
+                        llm=self.llm,
+                        callbacks=self.callbacks,
+                        i18n=self._i18n,
+                    )
+                    continue
+                handle_unknown_error(self._printer, e)
+                raise e
+            finally:
+                self.iterations += 1
+
+        if not isinstance(formatted_answer, AgentFinish):
+            raise RuntimeError(
+                "Agent execution ended without reaching a final answer. "
+                f"Got {type(formatted_answer).__name__} instead of AgentFinish."
+            )
+        self._show_logs(formatted_answer)
+        return formatted_answer
+
     def _handle_agent_action(
         self, formatted_answer: AgentAction, tool_result: ToolResult
     ) -> AgentAction | AgentFinish:
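
The new `ainvoke` mirrors the synchronous path (prompt formatting, human feedback, memory writes) but awaits `_ainvoke_loop`, which in turn awaits `aget_llm_response` and `aexecute_tool_and_check_finality` instead of their blocking counterparts. A minimal usage sketch, assuming `executor` is an already-configured CrewAgentExecutor and that the input keys match its prompt template (both assumptions, not shown in this diff):

    import asyncio

    async def main() -> None:
        # "ask_for_human_input" is read by ainvoke itself; any other keys
        # are whatever variables the executor's prompt template expects.
        result = await executor.ainvoke(
            {"input": "Summarize the quarterly report", "ask_for_human_input": False}
        )
        print(result["output"])  # ainvoke returns {"output": ...}

    asyncio.run(main())
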
@@ -242,17 +242,17 @@ def get_llm_response(
     """Call the LLM and return the response, handling any invalid responses.
 
     Args:
-        llm: The LLM instance to call
-        messages: The messages to send to the LLM
-        callbacks: List of callbacks for the LLM call
-        printer: Printer instance for output
-        from_task: Optional task context for the LLM call
-        from_agent: Optional agent context for the LLM call
-        response_model: Optional Pydantic model for structured outputs
-        executor_context: Optional executor context for hook invocation
+        llm: The LLM instance to call.
+        messages: The messages to send to the LLM.
+        callbacks: List of callbacks for the LLM call.
+        printer: Printer instance for output.
+        from_task: Optional task context for the LLM call.
+        from_agent: Optional agent context for the LLM call.
+        response_model: Optional Pydantic model for structured outputs.
+        executor_context: Optional executor context for hook invocation.
 
     Returns:
-        The response from the LLM as a string
+        The response from the LLM as a string.
 
     Raises:
         Exception: If an error occurs.
@@ -284,6 +284,60 @@ def get_llm_response(
     return _setup_after_llm_call_hooks(executor_context, answer, printer)
 
 
+async def aget_llm_response(
+    llm: LLM | BaseLLM,
+    messages: list[LLMMessage],
+    callbacks: list[TokenCalcHandler],
+    printer: Printer,
+    from_task: Task | None = None,
+    from_agent: Agent | LiteAgent | None = None,
+    response_model: type[BaseModel] | None = None,
+    executor_context: CrewAgentExecutor | None = None,
+) -> str:
+    """Call the LLM asynchronously and return the response.
+
+    Args:
+        llm: The LLM instance to call.
+        messages: The messages to send to the LLM.
+        callbacks: List of callbacks for the LLM call.
+        printer: Printer instance for output.
+        from_task: Optional task context for the LLM call.
+        from_agent: Optional agent context for the LLM call.
+        response_model: Optional Pydantic model for structured outputs.
+        executor_context: Optional executor context for hook invocation.
+
+    Returns:
+        The response from the LLM as a string.
+
+    Raises:
+        Exception: If an error occurs.
+        ValueError: If the response is None or empty.
+    """
+    if executor_context is not None:
+        if not _setup_before_llm_call_hooks(executor_context, printer):
+            raise ValueError("LLM call blocked by before_llm_call hook")
+        messages = executor_context.messages
+
+    try:
+        answer = await llm.acall(
+            messages,
+            callbacks=callbacks,
+            from_task=from_task,
+            from_agent=from_agent,  # type: ignore[arg-type]
+            response_model=response_model,
+        )
+    except Exception as e:
+        raise e
+    if not answer:
+        printer.print(
+            content="Received None or empty response from LLM call.",
+            color="red",
+        )
+        raise ValueError("Invalid response from LLM call - None or empty.")
+
+    return _setup_after_llm_call_hooks(executor_context, answer, printer)
+
+
 def process_llm_response(
     answer: str, use_stop_words: bool
 ) -> AgentAction | AgentFinish:
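
Since `aget_llm_response` suspends on `llm.acall` rather than blocking a thread, independent calls can be interleaved on a single event loop. A sketch under assumptions: the chat-style message dicts are illustrative, and only the positional `messages` argument of `acall` visible in this diff is relied on.

    import asyncio

    async def fan_out(llm, prompts: list[str]) -> list[str]:
        # Each acall suspends while waiting on the provider, so the event
        # loop can drive the other requests in the meantime.
        return await asyncio.gather(
            *(llm.acall([{"role": "user", "content": p}]) for p in prompts)
        )
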