Compare commits

...

1 Commits

Author SHA1 Message Date
Greyson LaLonde
71b4f8402a fix: ensure callbacks are ran/awaited if promise
Some checks are pending
CodeQL Advanced / Analyze (actions) (push) Waiting to run
CodeQL Advanced / Analyze (python) (push) Waiting to run
2026-02-20 13:15:50 -05:00
5 changed files with 85 additions and 17 deletions

View File

@@ -6,8 +6,10 @@ and memory management.
from __future__ import annotations
import asyncio
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor, as_completed
import inspect
import logging
from typing import TYPE_CHECKING, Any, Literal, cast
@@ -736,7 +738,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
] = []
for call_id, func_name, func_args in parsed_calls:
original_tool = original_tools_by_name.get(func_name)
execution_plan.append((call_id, func_name, func_args, original_tool))
execution_plan.append(
(call_id, func_name, func_args, original_tool)
)
self._append_assistant_tool_calls_message(
[
@@ -746,7 +750,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
)
max_workers = min(8, len(execution_plan))
ordered_results: list[dict[str, Any] | None] = [None] * len(execution_plan)
ordered_results: list[dict[str, Any] | None] = [None] * len(
execution_plan
)
with ThreadPoolExecutor(max_workers=max_workers) as pool:
futures = {
pool.submit(
@@ -803,7 +809,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
return tool_finish
reasoning_prompt = self._i18n.slice("post_tool_reasoning")
reasoning_message: LLMMessage = {
reasoning_message = {
"role": "user",
"content": reasoning_prompt,
}
@@ -908,9 +914,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
elif (
should_execute
and original_tool
and getattr(original_tool, "max_usage_count", None) is not None
and getattr(original_tool, "current_usage_count", 0)
>= original_tool.max_usage_count
and (max_count := getattr(original_tool, "max_usage_count", None))
is not None
and getattr(original_tool, "current_usage_count", 0) >= max_count
):
max_usage_reached = True
@@ -989,13 +995,17 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
and hasattr(original_tool, "cache_function")
and callable(original_tool.cache_function)
):
should_cache = original_tool.cache_function(args_dict, raw_result)
should_cache = original_tool.cache_function(
args_dict, raw_result
)
if should_cache:
self.tools_handler.cache.add(
tool=func_name, input=input_str, output=raw_result
)
result = str(raw_result) if not isinstance(raw_result, str) else raw_result
result = (
str(raw_result) if not isinstance(raw_result, str) else raw_result
)
except Exception as e:
result = f"Error executing tool: {e}"
if self.task:
@@ -1490,7 +1500,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
formatted_answer: Current agent response.
"""
if self.step_callback:
self.step_callback(formatted_answer)
cb_result = self.step_callback(formatted_answer)
if inspect.iscoroutine(cb_result):
asyncio.run(cb_result)
def _append_message(
self, text: str, role: Literal["user", "assistant", "system"] = "assistant"

View File

@@ -1,8 +1,10 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable, Coroutine
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
import inspect
import json
import threading
from typing import TYPE_CHECKING, Any, Literal, cast
@@ -778,7 +780,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
from_cache = cast(bool, execution_result["from_cache"])
original_tool = execution_result["original_tool"]
tool_message: LLMMessage = {
tool_message = {
"role": "tool",
"tool_call_id": call_id,
"name": func_name,
@@ -1358,7 +1360,9 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
formatted_answer: Current agent response.
"""
if self.step_callback:
self.step_callback(formatted_answer)
cb_result = self.step_callback(formatted_answer)
if inspect.iscoroutine(cb_result):
asyncio.run(cb_result)
def _append_message_to_state(
self, text: str, role: Literal["user", "assistant", "system"] = "assistant"

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import asyncio
from concurrent.futures import Future
from copy import copy as shallow_copy
import datetime
@@ -624,11 +625,15 @@ class Task(BaseModel):
self.end_time = datetime.datetime.now()
if self.callback:
self.callback(self.output)
cb_result = self.callback(self.output)
if inspect.isawaitable(cb_result):
await cb_result
crew = self.agent.crew # type: ignore[union-attr]
if crew and crew.task_callback and crew.task_callback != self.callback:
crew.task_callback(self.output)
cb_result = crew.task_callback(self.output)
if inspect.isawaitable(cb_result):
await cb_result
if self.output_file:
content = (
@@ -722,11 +727,15 @@ class Task(BaseModel):
self.end_time = datetime.datetime.now()
if self.callback:
self.callback(self.output)
cb_result = self.callback(self.output)
if inspect.iscoroutine(cb_result):
asyncio.run(cb_result)
crew = self.agent.crew # type: ignore[union-attr]
if crew and crew.task_callback and crew.task_callback != self.callback:
crew.task_callback(self.output)
cb_result = crew.task_callback(self.output)
if inspect.iscoroutine(cb_result):
asyncio.run(cb_result)
if self.output_file:
content = (

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
from collections.abc import Callable, Sequence
import concurrent.futures
import inspect
import json
import re
from typing import TYPE_CHECKING, Any, Final, Literal, TypedDict
@@ -501,7 +502,9 @@ def handle_agent_action_core(
- TODO: Remove messages parameter and its usage.
"""
if step_callback:
step_callback(tool_result)
cb_result = step_callback(tool_result)
if inspect.iscoroutine(cb_result):
asyncio.run(cb_result)
formatted_answer.text += f"\nObservation: {tool_result.result}"
formatted_answer.result = tool_result.result

View File

@@ -2,7 +2,7 @@
import asyncio
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
@@ -291,6 +291,46 @@ class TestAsyncAgentExecutor:
assert max_concurrent > 1, f"Expected concurrent execution, max concurrent was {max_concurrent}"
class TestInvokeStepCallback:
"""Tests for _invoke_step_callback with sync and async callbacks."""
def test_invoke_step_callback_with_sync_callback(
self, executor: CrewAgentExecutor
) -> None:
"""Test that a sync step callback is called normally."""
callback = Mock()
executor.step_callback = callback
answer = AgentFinish(thought="thinking", output="test", text="final")
executor._invoke_step_callback(answer)
callback.assert_called_once_with(answer)
def test_invoke_step_callback_with_async_callback(
self, executor: CrewAgentExecutor
) -> None:
"""Test that an async step callback is awaited via asyncio.run."""
async_callback = AsyncMock()
executor.step_callback = async_callback
answer = AgentFinish(thought="thinking", output="test", text="final")
with patch("crewai.agents.crew_agent_executor.asyncio.run") as mock_run:
executor._invoke_step_callback(answer)
async_callback.assert_called_once_with(answer)
mock_run.assert_called_once()
def test_invoke_step_callback_with_none(
self, executor: CrewAgentExecutor
) -> None:
"""Test that no error is raised when step_callback is None."""
executor.step_callback = None
answer = AgentFinish(thought="thinking", output="test", text="final")
# Should not raise
executor._invoke_step_callback(answer)
class TestAsyncLLMResponseHelper:
"""Tests for aget_llm_response helper function."""