diff --git a/lib/crewai/src/crewai/agents/crew_agent_executor.py b/lib/crewai/src/crewai/agents/crew_agent_executor.py index 99991f73b..df51807f7 100644 --- a/lib/crewai/src/crewai/agents/crew_agent_executor.py +++ b/lib/crewai/src/crewai/agents/crew_agent_executor.py @@ -6,8 +6,10 @@ and memory management. from __future__ import annotations +import asyncio from collections.abc import Callable from concurrent.futures import ThreadPoolExecutor, as_completed +import inspect import logging from typing import TYPE_CHECKING, Any, Literal, cast @@ -736,7 +738,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): ] = [] for call_id, func_name, func_args in parsed_calls: original_tool = original_tools_by_name.get(func_name) - execution_plan.append((call_id, func_name, func_args, original_tool)) + execution_plan.append( + (call_id, func_name, func_args, original_tool) + ) self._append_assistant_tool_calls_message( [ @@ -746,7 +750,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): ) max_workers = min(8, len(execution_plan)) - ordered_results: list[dict[str, Any] | None] = [None] * len(execution_plan) + ordered_results: list[dict[str, Any] | None] = [None] * len( + execution_plan + ) with ThreadPoolExecutor(max_workers=max_workers) as pool: futures = { pool.submit( @@ -803,7 +809,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): return tool_finish reasoning_prompt = self._i18n.slice("post_tool_reasoning") - reasoning_message: LLMMessage = { + reasoning_message = { "role": "user", "content": reasoning_prompt, } @@ -908,9 +914,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): elif ( should_execute and original_tool - and getattr(original_tool, "max_usage_count", None) is not None - and getattr(original_tool, "current_usage_count", 0) - >= original_tool.max_usage_count + and (max_count := getattr(original_tool, "max_usage_count", None)) + is not None + and getattr(original_tool, "current_usage_count", 0) >= max_count ): max_usage_reached = True @@ -989,13 +995,17 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): and hasattr(original_tool, "cache_function") and callable(original_tool.cache_function) ): - should_cache = original_tool.cache_function(args_dict, raw_result) + should_cache = original_tool.cache_function( + args_dict, raw_result + ) if should_cache: self.tools_handler.cache.add( tool=func_name, input=input_str, output=raw_result ) - result = str(raw_result) if not isinstance(raw_result, str) else raw_result + result = ( + str(raw_result) if not isinstance(raw_result, str) else raw_result + ) except Exception as e: result = f"Error executing tool: {e}" if self.task: @@ -1490,7 +1500,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): formatted_answer: Current agent response. """ if self.step_callback: - self.step_callback(formatted_answer) + cb_result = self.step_callback(formatted_answer) + if inspect.iscoroutine(cb_result): + asyncio.run(cb_result) def _append_message( self, text: str, role: Literal["user", "assistant", "system"] = "assistant" diff --git a/lib/crewai/src/crewai/experimental/agent_executor.py b/lib/crewai/src/crewai/experimental/agent_executor.py index 5724def3b..56c4da030 100644 --- a/lib/crewai/src/crewai/experimental/agent_executor.py +++ b/lib/crewai/src/crewai/experimental/agent_executor.py @@ -1,8 +1,10 @@ from __future__ import annotations +import asyncio from collections.abc import Callable, Coroutine from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime +import inspect import json import threading from typing import TYPE_CHECKING, Any, Literal, cast @@ -778,7 +780,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin): from_cache = cast(bool, execution_result["from_cache"]) original_tool = execution_result["original_tool"] - tool_message: LLMMessage = { + tool_message = { "role": "tool", "tool_call_id": call_id, "name": func_name, @@ -1358,7 +1360,9 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin): formatted_answer: Current agent response. """ if self.step_callback: - self.step_callback(formatted_answer) + cb_result = self.step_callback(formatted_answer) + if inspect.iscoroutine(cb_result): + asyncio.run(cb_result) def _append_message_to_state( self, text: str, role: Literal["user", "assistant", "system"] = "assistant" diff --git a/lib/crewai/src/crewai/task.py b/lib/crewai/src/crewai/task.py index eac42f956..775e3c94d 100644 --- a/lib/crewai/src/crewai/task.py +++ b/lib/crewai/src/crewai/task.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from concurrent.futures import Future from copy import copy as shallow_copy import datetime @@ -624,11 +625,15 @@ class Task(BaseModel): self.end_time = datetime.datetime.now() if self.callback: - self.callback(self.output) + cb_result = self.callback(self.output) + if inspect.isawaitable(cb_result): + await cb_result crew = self.agent.crew # type: ignore[union-attr] if crew and crew.task_callback and crew.task_callback != self.callback: - crew.task_callback(self.output) + cb_result = crew.task_callback(self.output) + if inspect.isawaitable(cb_result): + await cb_result if self.output_file: content = ( @@ -722,11 +727,15 @@ class Task(BaseModel): self.end_time = datetime.datetime.now() if self.callback: - self.callback(self.output) + cb_result = self.callback(self.output) + if inspect.iscoroutine(cb_result): + asyncio.run(cb_result) crew = self.agent.crew # type: ignore[union-attr] if crew and crew.task_callback and crew.task_callback != self.callback: - crew.task_callback(self.output) + cb_result = crew.task_callback(self.output) + if inspect.iscoroutine(cb_result): + asyncio.run(cb_result) if self.output_file: content = ( diff --git a/lib/crewai/src/crewai/utilities/agent_utils.py b/lib/crewai/src/crewai/utilities/agent_utils.py index 22b498541..80c80dbb6 100644 --- a/lib/crewai/src/crewai/utilities/agent_utils.py +++ b/lib/crewai/src/crewai/utilities/agent_utils.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio from collections.abc import Callable, Sequence import concurrent.futures +import inspect import json import re from typing import TYPE_CHECKING, Any, Final, Literal, TypedDict @@ -501,7 +502,9 @@ def handle_agent_action_core( - TODO: Remove messages parameter and its usage. """ if step_callback: - step_callback(tool_result) + cb_result = step_callback(tool_result) + if inspect.iscoroutine(cb_result): + asyncio.run(cb_result) formatted_answer.text += f"\nObservation: {tool_result.result}" formatted_answer.result = tool_result.result