mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-02-20 21:08:15 +00:00
Compare commits
1 Commits
lg-isolate
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
71b4f8402a |
@@ -6,8 +6,10 @@ and memory management.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import inspect
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
@@ -736,7 +738,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
] = []
|
||||
for call_id, func_name, func_args in parsed_calls:
|
||||
original_tool = original_tools_by_name.get(func_name)
|
||||
execution_plan.append((call_id, func_name, func_args, original_tool))
|
||||
execution_plan.append(
|
||||
(call_id, func_name, func_args, original_tool)
|
||||
)
|
||||
|
||||
self._append_assistant_tool_calls_message(
|
||||
[
|
||||
@@ -746,7 +750,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
)
|
||||
|
||||
max_workers = min(8, len(execution_plan))
|
||||
ordered_results: list[dict[str, Any] | None] = [None] * len(execution_plan)
|
||||
ordered_results: list[dict[str, Any] | None] = [None] * len(
|
||||
execution_plan
|
||||
)
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||
futures = {
|
||||
pool.submit(
|
||||
@@ -803,7 +809,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
return tool_finish
|
||||
|
||||
reasoning_prompt = self._i18n.slice("post_tool_reasoning")
|
||||
reasoning_message: LLMMessage = {
|
||||
reasoning_message = {
|
||||
"role": "user",
|
||||
"content": reasoning_prompt,
|
||||
}
|
||||
@@ -908,9 +914,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
elif (
|
||||
should_execute
|
||||
and original_tool
|
||||
and getattr(original_tool, "max_usage_count", None) is not None
|
||||
and getattr(original_tool, "current_usage_count", 0)
|
||||
>= original_tool.max_usage_count
|
||||
and (max_count := getattr(original_tool, "max_usage_count", None))
|
||||
is not None
|
||||
and getattr(original_tool, "current_usage_count", 0) >= max_count
|
||||
):
|
||||
max_usage_reached = True
|
||||
|
||||
@@ -989,13 +995,17 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
and hasattr(original_tool, "cache_function")
|
||||
and callable(original_tool.cache_function)
|
||||
):
|
||||
should_cache = original_tool.cache_function(args_dict, raw_result)
|
||||
should_cache = original_tool.cache_function(
|
||||
args_dict, raw_result
|
||||
)
|
||||
if should_cache:
|
||||
self.tools_handler.cache.add(
|
||||
tool=func_name, input=input_str, output=raw_result
|
||||
)
|
||||
|
||||
result = str(raw_result) if not isinstance(raw_result, str) else raw_result
|
||||
result = (
|
||||
str(raw_result) if not isinstance(raw_result, str) else raw_result
|
||||
)
|
||||
except Exception as e:
|
||||
result = f"Error executing tool: {e}"
|
||||
if self.task:
|
||||
@@ -1490,7 +1500,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
formatted_answer: Current agent response.
|
||||
"""
|
||||
if self.step_callback:
|
||||
self.step_callback(formatted_answer)
|
||||
cb_result = self.step_callback(formatted_answer)
|
||||
if inspect.iscoroutine(cb_result):
|
||||
asyncio.run(cb_result)
|
||||
|
||||
def _append_message(
|
||||
self, text: str, role: Literal["user", "assistant", "system"] = "assistant"
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable, Coroutine
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from datetime import datetime
|
||||
import inspect
|
||||
import json
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
@@ -778,7 +780,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
|
||||
from_cache = cast(bool, execution_result["from_cache"])
|
||||
original_tool = execution_result["original_tool"]
|
||||
|
||||
tool_message: LLMMessage = {
|
||||
tool_message = {
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"name": func_name,
|
||||
@@ -1358,7 +1360,9 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
|
||||
formatted_answer: Current agent response.
|
||||
"""
|
||||
if self.step_callback:
|
||||
self.step_callback(formatted_answer)
|
||||
cb_result = self.step_callback(formatted_answer)
|
||||
if inspect.iscoroutine(cb_result):
|
||||
asyncio.run(cb_result)
|
||||
|
||||
def _append_message_to_state(
|
||||
self, text: str, role: Literal["user", "assistant", "system"] = "assistant"
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import Future
|
||||
from copy import copy as shallow_copy
|
||||
import datetime
|
||||
@@ -624,11 +625,15 @@ class Task(BaseModel):
|
||||
self.end_time = datetime.datetime.now()
|
||||
|
||||
if self.callback:
|
||||
self.callback(self.output)
|
||||
cb_result = self.callback(self.output)
|
||||
if inspect.isawaitable(cb_result):
|
||||
await cb_result
|
||||
|
||||
crew = self.agent.crew # type: ignore[union-attr]
|
||||
if crew and crew.task_callback and crew.task_callback != self.callback:
|
||||
crew.task_callback(self.output)
|
||||
cb_result = crew.task_callback(self.output)
|
||||
if inspect.isawaitable(cb_result):
|
||||
await cb_result
|
||||
|
||||
if self.output_file:
|
||||
content = (
|
||||
@@ -722,11 +727,15 @@ class Task(BaseModel):
|
||||
self.end_time = datetime.datetime.now()
|
||||
|
||||
if self.callback:
|
||||
self.callback(self.output)
|
||||
cb_result = self.callback(self.output)
|
||||
if inspect.iscoroutine(cb_result):
|
||||
asyncio.run(cb_result)
|
||||
|
||||
crew = self.agent.crew # type: ignore[union-attr]
|
||||
if crew and crew.task_callback and crew.task_callback != self.callback:
|
||||
crew.task_callback(self.output)
|
||||
cb_result = crew.task_callback(self.output)
|
||||
if inspect.iscoroutine(cb_result):
|
||||
asyncio.run(cb_result)
|
||||
|
||||
if self.output_file:
|
||||
content = (
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from collections.abc import Callable, Sequence
|
||||
import concurrent.futures
|
||||
import inspect
|
||||
import json
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, Final, Literal, TypedDict
|
||||
@@ -501,7 +502,9 @@ def handle_agent_action_core(
|
||||
- TODO: Remove messages parameter and its usage.
|
||||
"""
|
||||
if step_callback:
|
||||
step_callback(tool_result)
|
||||
cb_result = step_callback(tool_result)
|
||||
if inspect.iscoroutine(cb_result):
|
||||
asyncio.run(cb_result)
|
||||
|
||||
formatted_answer.text += f"\nObservation: {tool_result.result}"
|
||||
formatted_answer.result = tool_result.result
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -291,6 +291,46 @@ class TestAsyncAgentExecutor:
|
||||
assert max_concurrent > 1, f"Expected concurrent execution, max concurrent was {max_concurrent}"
|
||||
|
||||
|
||||
class TestInvokeStepCallback:
|
||||
"""Tests for _invoke_step_callback with sync and async callbacks."""
|
||||
|
||||
def test_invoke_step_callback_with_sync_callback(
|
||||
self, executor: CrewAgentExecutor
|
||||
) -> None:
|
||||
"""Test that a sync step callback is called normally."""
|
||||
callback = Mock()
|
||||
executor.step_callback = callback
|
||||
answer = AgentFinish(thought="thinking", output="test", text="final")
|
||||
|
||||
executor._invoke_step_callback(answer)
|
||||
|
||||
callback.assert_called_once_with(answer)
|
||||
|
||||
def test_invoke_step_callback_with_async_callback(
|
||||
self, executor: CrewAgentExecutor
|
||||
) -> None:
|
||||
"""Test that an async step callback is awaited via asyncio.run."""
|
||||
async_callback = AsyncMock()
|
||||
executor.step_callback = async_callback
|
||||
answer = AgentFinish(thought="thinking", output="test", text="final")
|
||||
|
||||
with patch("crewai.agents.crew_agent_executor.asyncio.run") as mock_run:
|
||||
executor._invoke_step_callback(answer)
|
||||
|
||||
async_callback.assert_called_once_with(answer)
|
||||
mock_run.assert_called_once()
|
||||
|
||||
def test_invoke_step_callback_with_none(
|
||||
self, executor: CrewAgentExecutor
|
||||
) -> None:
|
||||
"""Test that no error is raised when step_callback is None."""
|
||||
executor.step_callback = None
|
||||
answer = AgentFinish(thought="thinking", output="test", text="final")
|
||||
|
||||
# Should not raise
|
||||
executor._invoke_step_callback(answer)
|
||||
|
||||
|
||||
class TestAsyncLLMResponseHelper:
|
||||
"""Tests for aget_llm_response helper function."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user