fix: ensure callbacks are ran/awaited if promise

This commit is contained in:
Greyson LaLonde
2026-02-20 03:11:34 -05:00
parent 4a4c99d8a2
commit 53b5fdc2ff
4 changed files with 44 additions and 16 deletions

View File

@@ -6,8 +6,10 @@ and memory management.
from __future__ import annotations
import asyncio
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor, as_completed
import inspect
import logging
from typing import TYPE_CHECKING, Any, Literal, cast
@@ -736,7 +738,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
] = []
for call_id, func_name, func_args in parsed_calls:
original_tool = original_tools_by_name.get(func_name)
execution_plan.append((call_id, func_name, func_args, original_tool))
execution_plan.append(
(call_id, func_name, func_args, original_tool)
)
self._append_assistant_tool_calls_message(
[
@@ -746,7 +750,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
)
max_workers = min(8, len(execution_plan))
ordered_results: list[dict[str, Any] | None] = [None] * len(execution_plan)
ordered_results: list[dict[str, Any] | None] = [None] * len(
execution_plan
)
with ThreadPoolExecutor(max_workers=max_workers) as pool:
futures = {
pool.submit(
@@ -803,7 +809,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
return tool_finish
reasoning_prompt = self._i18n.slice("post_tool_reasoning")
reasoning_message: LLMMessage = {
reasoning_message = {
"role": "user",
"content": reasoning_prompt,
}
@@ -908,9 +914,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
elif (
should_execute
and original_tool
and getattr(original_tool, "max_usage_count", None) is not None
and getattr(original_tool, "current_usage_count", 0)
>= original_tool.max_usage_count
and (max_count := getattr(original_tool, "max_usage_count", None))
is not None
and getattr(original_tool, "current_usage_count", 0) >= max_count
):
max_usage_reached = True
@@ -989,13 +995,17 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
and hasattr(original_tool, "cache_function")
and callable(original_tool.cache_function)
):
should_cache = original_tool.cache_function(args_dict, raw_result)
should_cache = original_tool.cache_function(
args_dict, raw_result
)
if should_cache:
self.tools_handler.cache.add(
tool=func_name, input=input_str, output=raw_result
)
result = str(raw_result) if not isinstance(raw_result, str) else raw_result
result = (
str(raw_result) if not isinstance(raw_result, str) else raw_result
)
except Exception as e:
result = f"Error executing tool: {e}"
if self.task:
@@ -1490,7 +1500,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
formatted_answer: Current agent response.
"""
if self.step_callback:
self.step_callback(formatted_answer)
cb_result = self.step_callback(formatted_answer)
if inspect.iscoroutine(cb_result):
asyncio.run(cb_result)
def _append_message(
self, text: str, role: Literal["user", "assistant", "system"] = "assistant"

View File

@@ -1,8 +1,10 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable, Coroutine
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
import inspect
import json
import threading
from typing import TYPE_CHECKING, Any, Literal, cast
@@ -778,7 +780,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
from_cache = cast(bool, execution_result["from_cache"])
original_tool = execution_result["original_tool"]
tool_message: LLMMessage = {
tool_message = {
"role": "tool",
"tool_call_id": call_id,
"name": func_name,
@@ -1358,7 +1360,9 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
formatted_answer: Current agent response.
"""
if self.step_callback:
self.step_callback(formatted_answer)
cb_result = self.step_callback(formatted_answer)
if inspect.iscoroutine(cb_result):
asyncio.run(cb_result)
def _append_message_to_state(
self, text: str, role: Literal["user", "assistant", "system"] = "assistant"

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import asyncio
from concurrent.futures import Future
from copy import copy as shallow_copy
import datetime
@@ -624,11 +625,15 @@ class Task(BaseModel):
self.end_time = datetime.datetime.now()
if self.callback:
self.callback(self.output)
cb_result = self.callback(self.output)
if inspect.isawaitable(cb_result):
await cb_result
crew = self.agent.crew # type: ignore[union-attr]
if crew and crew.task_callback and crew.task_callback != self.callback:
crew.task_callback(self.output)
cb_result = crew.task_callback(self.output)
if inspect.isawaitable(cb_result):
await cb_result
if self.output_file:
content = (
@@ -722,11 +727,15 @@ class Task(BaseModel):
self.end_time = datetime.datetime.now()
if self.callback:
self.callback(self.output)
cb_result = self.callback(self.output)
if inspect.iscoroutine(cb_result):
asyncio.run(cb_result)
crew = self.agent.crew # type: ignore[union-attr]
if crew and crew.task_callback and crew.task_callback != self.callback:
crew.task_callback(self.output)
cb_result = crew.task_callback(self.output)
if inspect.iscoroutine(cb_result):
asyncio.run(cb_result)
if self.output_file:
content = (

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
from collections.abc import Callable, Sequence
import concurrent.futures
import inspect
import json
import re
from typing import TYPE_CHECKING, Any, Final, Literal, TypedDict
@@ -501,7 +502,9 @@ def handle_agent_action_core(
- TODO: Remove messages parameter and its usage.
"""
if step_callback:
step_callback(tool_result)
cb_result = step_callback(tool_result)
if inspect.iscoroutine(cb_result):
asyncio.run(cb_result)
formatted_answer.text += f"\nObservation: {tool_result.result}"
formatted_answer.result = tool_result.result