mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 15:22:37 +00:00
Merge branch 'main' into lg-allow-remove-stop
This commit is contained in:
@@ -6,7 +6,17 @@ import warnings
|
|||||||
from concurrent.futures import Future
|
from concurrent.futures import Future
|
||||||
from copy import copy as shallow_copy
|
from copy import copy as shallow_copy
|
||||||
from hashlib import md5
|
from hashlib import md5
|
||||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union, cast
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Set,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
UUID4,
|
UUID4,
|
||||||
@@ -24,6 +34,7 @@ from crewai.agent import Agent
|
|||||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||||
from crewai.agents.cache import CacheHandler
|
from crewai.agents.cache import CacheHandler
|
||||||
from crewai.crews.crew_output import CrewOutput
|
from crewai.crews.crew_output import CrewOutput
|
||||||
|
from crewai.flow.flow_trackable import FlowTrackable
|
||||||
from crewai.knowledge.knowledge import Knowledge
|
from crewai.knowledge.knowledge import Knowledge
|
||||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||||
from crewai.llm import LLM, BaseLLM
|
from crewai.llm import LLM, BaseLLM
|
||||||
@@ -69,7 +80,7 @@ from crewai.utilities.training_handler import CrewTrainingHandler
|
|||||||
warnings.filterwarnings("ignore", category=SyntaxWarning, module="pysbd")
|
warnings.filterwarnings("ignore", category=SyntaxWarning, module="pysbd")
|
||||||
|
|
||||||
|
|
||||||
class Crew(BaseModel):
|
class Crew(FlowTrackable, BaseModel):
|
||||||
"""
|
"""
|
||||||
Represents a group of agents, defining how they should collaborate and the tasks they should perform.
|
Represents a group of agents, defining how they should collaborate and the tasks they should perform.
|
||||||
|
|
||||||
|
|||||||
44
src/crewai/flow/flow_trackable.py
Normal file
44
src/crewai/flow/flow_trackable.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
import inspect
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, InstanceOf, model_validator
|
||||||
|
|
||||||
|
from crewai.flow import Flow
|
||||||
|
|
||||||
|
|
||||||
|
class FlowTrackable(BaseModel):
|
||||||
|
"""Mixin that tracks the Flow instance that instantiated the object, e.g. a
|
||||||
|
Flow instance that created a Crew or Agent.
|
||||||
|
|
||||||
|
Automatically finds and stores a reference to the parent Flow instance by
|
||||||
|
inspecting the call stack.
|
||||||
|
"""
|
||||||
|
|
||||||
|
parent_flow: Optional[InstanceOf[Flow]] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The parent flow of the instance, if it was created inside a flow.",
|
||||||
|
)
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def _set_parent_flow(self, max_depth: int = 5) -> "FlowTrackable":
|
||||||
|
frame = inspect.currentframe()
|
||||||
|
|
||||||
|
try:
|
||||||
|
if frame is None:
|
||||||
|
return self
|
||||||
|
|
||||||
|
frame = frame.f_back
|
||||||
|
for _ in range(max_depth):
|
||||||
|
if frame is None:
|
||||||
|
break
|
||||||
|
|
||||||
|
candidate = frame.f_locals.get("self")
|
||||||
|
if isinstance(candidate, Flow):
|
||||||
|
self.parent_flow = candidate
|
||||||
|
break
|
||||||
|
|
||||||
|
frame = frame.f_back
|
||||||
|
finally:
|
||||||
|
del frame
|
||||||
|
|
||||||
|
return self
|
||||||
@@ -13,6 +13,7 @@ from crewai.agents.parser import (
|
|||||||
AgentFinish,
|
AgentFinish,
|
||||||
OutputParserException,
|
OutputParserException,
|
||||||
)
|
)
|
||||||
|
from crewai.flow.flow_trackable import FlowTrackable
|
||||||
from crewai.llm import LLM
|
from crewai.llm import LLM
|
||||||
from crewai.tools.base_tool import BaseTool
|
from crewai.tools.base_tool import BaseTool
|
||||||
from crewai.tools.structured_tool import CrewStructuredTool
|
from crewai.tools.structured_tool import CrewStructuredTool
|
||||||
@@ -80,7 +81,7 @@ class LiteAgentOutput(BaseModel):
|
|||||||
return self.raw
|
return self.raw
|
||||||
|
|
||||||
|
|
||||||
class LiteAgent(BaseModel):
|
class LiteAgent(FlowTrackable, BaseModel):
|
||||||
"""
|
"""
|
||||||
A lightweight agent that can process messages and use tools.
|
A lightweight agent that can process messages and use tools.
|
||||||
|
|
||||||
@@ -162,7 +163,7 @@ class LiteAgent(BaseModel):
|
|||||||
_messages: List[Dict[str, str]] = PrivateAttr(default_factory=list)
|
_messages: List[Dict[str, str]] = PrivateAttr(default_factory=list)
|
||||||
_iterations: int = PrivateAttr(default=0)
|
_iterations: int = PrivateAttr(default=0)
|
||||||
_printer: Printer = PrivateAttr(default_factory=Printer)
|
_printer: Printer = PrivateAttr(default_factory=Printer)
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def setup_llm(self):
|
def setup_llm(self):
|
||||||
"""Set up the LLM and other components after initialization."""
|
"""Set up the LLM and other components after initialization."""
|
||||||
|
|||||||
@@ -3,12 +3,13 @@ import tempfile
|
|||||||
import unittest
|
import unittest
|
||||||
import unittest.mock
|
import unittest.mock
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from io import StringIO
|
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
from pytest import raises
|
from pytest import raises
|
||||||
|
|
||||||
|
from crewai.cli.authentication.utils import TokenManager
|
||||||
from crewai.cli.tools.main import ToolCommand
|
from crewai.cli.tools.main import ToolCommand
|
||||||
|
|
||||||
|
|
||||||
@@ -23,17 +24,20 @@ def in_temp_dir():
|
|||||||
os.chdir(original_dir)
|
os.chdir(original_dir)
|
||||||
|
|
||||||
|
|
||||||
@patch("crewai.cli.tools.main.subprocess.run")
|
@pytest.fixture
|
||||||
def test_create_success(mock_subprocess):
|
def tool_command():
|
||||||
with in_temp_dir():
|
TokenManager().save_tokens("test-token", 36000)
|
||||||
tool_command = ToolCommand()
|
tool_command = ToolCommand()
|
||||||
|
with patch.object(tool_command, "login"):
|
||||||
|
yield tool_command
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(tool_command, "login") as mock_login,
|
@patch("crewai.cli.tools.main.subprocess.run")
|
||||||
patch("sys.stdout", new=StringIO()) as fake_out,
|
def test_create_success(mock_subprocess, capsys, tool_command):
|
||||||
):
|
with in_temp_dir():
|
||||||
tool_command.create("test-tool")
|
tool_command.create("test-tool")
|
||||||
output = fake_out.getvalue()
|
output = capsys.readouterr().out
|
||||||
|
assert "Creating custom tool test_tool..." in output
|
||||||
|
|
||||||
assert os.path.isdir("test_tool")
|
assert os.path.isdir("test_tool")
|
||||||
assert os.path.isfile(os.path.join("test_tool", "README.md"))
|
assert os.path.isfile(os.path.join("test_tool", "README.md"))
|
||||||
@@ -47,15 +51,12 @@ def test_create_success(mock_subprocess):
|
|||||||
content = f.read()
|
content = f.read()
|
||||||
assert "class TestTool" in content
|
assert "class TestTool" in content
|
||||||
|
|
||||||
mock_login.assert_called_once()
|
|
||||||
mock_subprocess.assert_called_once_with(["git", "init"], check=True)
|
mock_subprocess.assert_called_once_with(["git", "init"], check=True)
|
||||||
|
|
||||||
assert "Creating custom tool test_tool..." in output
|
|
||||||
|
|
||||||
|
|
||||||
@patch("crewai.cli.tools.main.subprocess.run")
|
@patch("crewai.cli.tools.main.subprocess.run")
|
||||||
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
|
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
|
||||||
def test_install_success(mock_get, mock_subprocess_run):
|
def test_install_success(mock_get, mock_subprocess_run, capsys, tool_command):
|
||||||
mock_get_response = MagicMock()
|
mock_get_response = MagicMock()
|
||||||
mock_get_response.status_code = 200
|
mock_get_response.status_code = 200
|
||||||
mock_get_response.json.return_value = {
|
mock_get_response.json.return_value = {
|
||||||
@@ -65,11 +66,9 @@ def test_install_success(mock_get, mock_subprocess_run):
|
|||||||
mock_get.return_value = mock_get_response
|
mock_get.return_value = mock_get_response
|
||||||
mock_subprocess_run.return_value = MagicMock(stderr=None)
|
mock_subprocess_run.return_value = MagicMock(stderr=None)
|
||||||
|
|
||||||
tool_command = ToolCommand()
|
tool_command.install("sample-tool")
|
||||||
|
output = capsys.readouterr().out
|
||||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
assert "Successfully installed sample-tool" in output
|
||||||
tool_command.install("sample-tool")
|
|
||||||
output = fake_out.getvalue()
|
|
||||||
|
|
||||||
mock_get.assert_has_calls([mock.call("sample-tool"), mock.call().json()])
|
mock_get.assert_has_calls([mock.call("sample-tool"), mock.call().json()])
|
||||||
mock_subprocess_run.assert_any_call(
|
mock_subprocess_run.assert_any_call(
|
||||||
@@ -86,54 +85,42 @@ def test_install_success(mock_get, mock_subprocess_run):
|
|||||||
env=unittest.mock.ANY,
|
env=unittest.mock.ANY,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert "Successfully installed sample-tool" in output
|
|
||||||
|
|
||||||
|
|
||||||
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
|
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
|
||||||
def test_install_tool_not_found(mock_get):
|
def test_install_tool_not_found(mock_get, capsys, tool_command):
|
||||||
mock_get_response = MagicMock()
|
mock_get_response = MagicMock()
|
||||||
mock_get_response.status_code = 404
|
mock_get_response.status_code = 404
|
||||||
mock_get.return_value = mock_get_response
|
mock_get.return_value = mock_get_response
|
||||||
|
|
||||||
tool_command = ToolCommand()
|
with raises(SystemExit):
|
||||||
|
tool_command.install("non-existent-tool")
|
||||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
output = capsys.readouterr().out
|
||||||
try:
|
assert "No tool found with this name" in output
|
||||||
tool_command.install("non-existent-tool")
|
|
||||||
except SystemExit:
|
|
||||||
pass
|
|
||||||
output = fake_out.getvalue()
|
|
||||||
|
|
||||||
mock_get.assert_called_once_with("non-existent-tool")
|
mock_get.assert_called_once_with("non-existent-tool")
|
||||||
assert "No tool found with this name" in output
|
|
||||||
|
|
||||||
|
|
||||||
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
|
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
|
||||||
def test_install_api_error(mock_get):
|
def test_install_api_error(mock_get, capsys, tool_command):
|
||||||
mock_get_response = MagicMock()
|
mock_get_response = MagicMock()
|
||||||
mock_get_response.status_code = 500
|
mock_get_response.status_code = 500
|
||||||
mock_get.return_value = mock_get_response
|
mock_get.return_value = mock_get_response
|
||||||
|
|
||||||
tool_command = ToolCommand()
|
with raises(SystemExit):
|
||||||
|
tool_command.install("error-tool")
|
||||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
output = capsys.readouterr().out
|
||||||
try:
|
assert "Failed to get tool details" in output
|
||||||
tool_command.install("error-tool")
|
|
||||||
except SystemExit:
|
|
||||||
pass
|
|
||||||
output = fake_out.getvalue()
|
|
||||||
|
|
||||||
mock_get.assert_called_once_with("error-tool")
|
mock_get.assert_called_once_with("error-tool")
|
||||||
assert "Failed to get tool details" in output
|
|
||||||
|
|
||||||
|
|
||||||
@patch("crewai.cli.tools.main.git.Repository.is_synced", return_value=False)
|
@patch("crewai.cli.tools.main.git.Repository.is_synced", return_value=False)
|
||||||
def test_publish_when_not_in_sync(mock_is_synced):
|
def test_publish_when_not_in_sync(mock_is_synced, capsys, tool_command):
|
||||||
with patch("sys.stdout", new=StringIO()) as fake_out, raises(SystemExit):
|
with raises(SystemExit):
|
||||||
tool_command = ToolCommand()
|
|
||||||
tool_command.publish(is_public=True)
|
tool_command.publish(is_public=True)
|
||||||
|
|
||||||
assert "Local changes need to be resolved before publishing" in fake_out.getvalue()
|
output = capsys.readouterr().out
|
||||||
|
assert "Local changes need to be resolved before publishing" in output
|
||||||
|
|
||||||
|
|
||||||
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
|
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
|
||||||
@@ -157,13 +144,13 @@ def test_publish_when_not_in_sync_and_force(
|
|||||||
mock_get_project_description,
|
mock_get_project_description,
|
||||||
mock_get_project_version,
|
mock_get_project_version,
|
||||||
mock_get_project_name,
|
mock_get_project_name,
|
||||||
|
tool_command,
|
||||||
):
|
):
|
||||||
mock_publish_response = MagicMock()
|
mock_publish_response = MagicMock()
|
||||||
mock_publish_response.status_code = 200
|
mock_publish_response.status_code = 200
|
||||||
mock_publish_response.json.return_value = {"handle": "sample-tool"}
|
mock_publish_response.json.return_value = {"handle": "sample-tool"}
|
||||||
mock_publish.return_value = mock_publish_response
|
mock_publish.return_value = mock_publish_response
|
||||||
|
|
||||||
tool_command = ToolCommand()
|
|
||||||
tool_command.publish(is_public=True, force=True)
|
tool_command.publish(is_public=True, force=True)
|
||||||
|
|
||||||
mock_get_project_name.assert_called_with(require=True)
|
mock_get_project_name.assert_called_with(require=True)
|
||||||
@@ -205,13 +192,13 @@ def test_publish_success(
|
|||||||
mock_get_project_description,
|
mock_get_project_description,
|
||||||
mock_get_project_version,
|
mock_get_project_version,
|
||||||
mock_get_project_name,
|
mock_get_project_name,
|
||||||
|
tool_command,
|
||||||
):
|
):
|
||||||
mock_publish_response = MagicMock()
|
mock_publish_response = MagicMock()
|
||||||
mock_publish_response.status_code = 200
|
mock_publish_response.status_code = 200
|
||||||
mock_publish_response.json.return_value = {"handle": "sample-tool"}
|
mock_publish_response.json.return_value = {"handle": "sample-tool"}
|
||||||
mock_publish.return_value = mock_publish_response
|
mock_publish.return_value = mock_publish_response
|
||||||
|
|
||||||
tool_command = ToolCommand()
|
|
||||||
tool_command.publish(is_public=True)
|
tool_command.publish(is_public=True)
|
||||||
|
|
||||||
mock_get_project_name.assert_called_with(require=True)
|
mock_get_project_name.assert_called_with(require=True)
|
||||||
@@ -251,25 +238,22 @@ def test_publish_failure(
|
|||||||
mock_get_project_description,
|
mock_get_project_description,
|
||||||
mock_get_project_version,
|
mock_get_project_version,
|
||||||
mock_get_project_name,
|
mock_get_project_name,
|
||||||
|
capsys,
|
||||||
|
tool_command,
|
||||||
):
|
):
|
||||||
mock_publish_response = MagicMock()
|
mock_publish_response = MagicMock()
|
||||||
mock_publish_response.status_code = 422
|
mock_publish_response.status_code = 422
|
||||||
mock_publish_response.json.return_value = {"name": ["is already taken"]}
|
mock_publish_response.json.return_value = {"name": ["is already taken"]}
|
||||||
mock_publish.return_value = mock_publish_response
|
mock_publish.return_value = mock_publish_response
|
||||||
|
|
||||||
tool_command = ToolCommand()
|
with raises(SystemExit):
|
||||||
|
tool_command.publish(is_public=True)
|
||||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
output = capsys.readouterr().out
|
||||||
try:
|
|
||||||
tool_command.publish(is_public=True)
|
|
||||||
except SystemExit:
|
|
||||||
pass
|
|
||||||
output = fake_out.getvalue()
|
|
||||||
|
|
||||||
mock_publish.assert_called_once()
|
|
||||||
assert "Failed to complete operation" in output
|
assert "Failed to complete operation" in output
|
||||||
assert "Name is already taken" in output
|
assert "Name is already taken" in output
|
||||||
|
|
||||||
|
mock_publish.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
|
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
|
||||||
@patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0")
|
@patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0")
|
||||||
@@ -290,6 +274,8 @@ def test_publish_api_error(
|
|||||||
mock_get_project_description,
|
mock_get_project_description,
|
||||||
mock_get_project_version,
|
mock_get_project_version,
|
||||||
mock_get_project_name,
|
mock_get_project_name,
|
||||||
|
capsys,
|
||||||
|
tool_command,
|
||||||
):
|
):
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.status_code = 500
|
mock_response.status_code = 500
|
||||||
@@ -297,14 +283,9 @@ def test_publish_api_error(
|
|||||||
mock_response.ok = False
|
mock_response.ok = False
|
||||||
mock_publish.return_value = mock_response
|
mock_publish.return_value = mock_response
|
||||||
|
|
||||||
tool_command = ToolCommand()
|
with raises(SystemExit):
|
||||||
|
tool_command.publish(is_public=True)
|
||||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
output = capsys.readouterr().out
|
||||||
try:
|
assert "Request to Enterprise API failed" in output
|
||||||
tool_command.publish(is_public=True)
|
|
||||||
except SystemExit:
|
|
||||||
pass
|
|
||||||
output = fake_out.getvalue()
|
|
||||||
|
|
||||||
mock_publish.assert_called_once()
|
mock_publish.assert_called_once()
|
||||||
assert "Request to Enterprise API failed" in output
|
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from crewai.agents.cache import CacheHandler
|
|||||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||||
from crewai.crew import Crew
|
from crewai.crew import Crew
|
||||||
from crewai.crews.crew_output import CrewOutput
|
from crewai.crews.crew_output import CrewOutput
|
||||||
|
from crewai.flow import Flow, listen, start
|
||||||
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
|
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
|
||||||
from crewai.llm import LLM
|
from crewai.llm import LLM
|
||||||
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||||
@@ -2164,7 +2165,6 @@ def test_tools_with_custom_caching():
|
|||||||
with patch.object(
|
with patch.object(
|
||||||
CacheHandler, "add", wraps=crew._cache_handler.add
|
CacheHandler, "add", wraps=crew._cache_handler.add
|
||||||
) as add_to_cache:
|
) as add_to_cache:
|
||||||
|
|
||||||
result = crew.kickoff()
|
result = crew.kickoff()
|
||||||
|
|
||||||
# Check that add_to_cache was called exactly twice
|
# Check that add_to_cache was called exactly twice
|
||||||
@@ -4351,3 +4351,35 @@ def test_crew_copy_with_memory():
|
|||||||
raise e # Re-raise other validation errors
|
raise e # Re-raise other validation errors
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"Copying crew raised an unexpected exception: {e}")
|
pytest.fail(f"Copying crew raised an unexpected exception: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_sets_parent_flow_when_outside_flow(researcher, writer):
|
||||||
|
crew = Crew(
|
||||||
|
agents=[researcher, writer],
|
||||||
|
process=Process.sequential,
|
||||||
|
tasks=[
|
||||||
|
Task(description="Task 1", expected_output="output", agent=researcher),
|
||||||
|
Task(description="Task 2", expected_output="output", agent=writer),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assert crew.parent_flow is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_sets_parent_flow_when_inside_flow(researcher, writer):
|
||||||
|
class MyFlow(Flow):
|
||||||
|
@start()
|
||||||
|
def start(self):
|
||||||
|
return Crew(
|
||||||
|
agents=[researcher, writer],
|
||||||
|
process=Process.sequential,
|
||||||
|
tasks=[
|
||||||
|
Task(
|
||||||
|
description="Task 1", expected_output="output", agent=researcher
|
||||||
|
),
|
||||||
|
Task(description="Task 2", expected_output="output", agent=writer),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
flow = MyFlow()
|
||||||
|
result = flow.kickoff()
|
||||||
|
assert result.parent_flow is flow
|
||||||
|
|||||||
@@ -1,13 +1,16 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from crewai import LLM, Agent
|
from crewai import LLM, Agent
|
||||||
|
from crewai.flow import Flow, start
|
||||||
from crewai.lite_agent import LiteAgent, LiteAgentOutput
|
from crewai.lite_agent import LiteAgent, LiteAgentOutput
|
||||||
from crewai.tools import BaseTool
|
from crewai.tools import BaseTool
|
||||||
from crewai.utilities.events import crewai_event_bus
|
from crewai.utilities.events import crewai_event_bus
|
||||||
|
from crewai.utilities.events.agent_events import LiteAgentExecutionStartedEvent
|
||||||
from crewai.utilities.events.tool_usage_events import ToolUsageStartedEvent
|
from crewai.utilities.events.tool_usage_events import ToolUsageStartedEvent
|
||||||
|
|
||||||
|
|
||||||
@@ -255,3 +258,60 @@ async def test_lite_agent_returns_usage_metrics_async():
|
|||||||
assert "21 million" in result.raw or "37 million" in result.raw
|
assert "21 million" in result.raw or "37 million" in result.raw
|
||||||
assert result.usage_metrics is not None
|
assert result.usage_metrics is not None
|
||||||
assert result.usage_metrics["total_tokens"] > 0
|
assert result.usage_metrics["total_tokens"] > 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlow(Flow):
|
||||||
|
"""A test flow that creates and runs an agent."""
|
||||||
|
|
||||||
|
def __init__(self, llm, tools):
|
||||||
|
self.llm = llm
|
||||||
|
self.tools = tools
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@start()
|
||||||
|
def start(self):
|
||||||
|
agent = Agent(
|
||||||
|
role="Test Agent",
|
||||||
|
goal="Test Goal",
|
||||||
|
backstory="Test Backstory",
|
||||||
|
llm=self.llm,
|
||||||
|
tools=self.tools,
|
||||||
|
)
|
||||||
|
return agent.kickoff("Test query")
|
||||||
|
|
||||||
|
|
||||||
|
def verify_agent_parent_flow(result, agent, flow):
|
||||||
|
"""Verify that both the result and agent have the correct parent flow."""
|
||||||
|
assert result.parent_flow is flow
|
||||||
|
assert agent is not None
|
||||||
|
assert agent.parent_flow is flow
|
||||||
|
|
||||||
|
|
||||||
|
def test_sets_parent_flow_when_inside_flow():
|
||||||
|
captured_agent = None
|
||||||
|
|
||||||
|
mock_llm = Mock(spec=LLM)
|
||||||
|
mock_llm.call.return_value = "Test response"
|
||||||
|
|
||||||
|
class MyFlow(Flow):
|
||||||
|
@start()
|
||||||
|
def start(self):
|
||||||
|
agent = Agent(
|
||||||
|
role="Test Agent",
|
||||||
|
goal="Test Goal",
|
||||||
|
backstory="Test Backstory",
|
||||||
|
llm=mock_llm,
|
||||||
|
tools=[WebSearchTool()],
|
||||||
|
)
|
||||||
|
return agent.kickoff("Test query")
|
||||||
|
|
||||||
|
flow = MyFlow()
|
||||||
|
with crewai_event_bus.scoped_handlers():
|
||||||
|
|
||||||
|
@crewai_event_bus.on(LiteAgentExecutionStartedEvent)
|
||||||
|
def capture_agent(source, event):
|
||||||
|
nonlocal captured_agent
|
||||||
|
captured_agent = source
|
||||||
|
|
||||||
|
result = flow.kickoff()
|
||||||
|
assert captured_agent.parent_flow is flow
|
||||||
|
|||||||
Reference in New Issue
Block a user