From 17474a3a0c90a6d79169eb3a3164f622090777f8 Mon Sep 17 00:00:00 2001 From: Vini Brasil Date: Fri, 2 May 2025 14:40:39 -0300 Subject: [PATCH] Identify `parent_flow` of `Crew` and `LiteAgent` (#2723) This commit adds a new crew field called parent_flow, evaluated when the Crew instance is instantiated. The stacktrace is traversed to look up if the caller is an instance of Flow, and if so, it fills in the field. Other alternatives were considered, such as a global context or even a new field to be manually filled, however, this is the most magical solution that was thread-safe and did not require public API changes. --- src/crewai/crew.py | 15 +++- src/crewai/flow/flow_trackable.py | 44 ++++++++++++ src/crewai/lite_agent.py | 5 +- tests/cli/tools/test_main.py | 115 +++++++++++++----------------- tests/crew_test.py | 34 ++++++++- tests/test_lite_agent.py | 60 ++++++++++++++++ 6 files changed, 201 insertions(+), 72 deletions(-) create mode 100644 src/crewai/flow/flow_trackable.py diff --git a/src/crewai/crew.py b/src/crewai/crew.py index 75433554c..102f22881 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -6,7 +6,17 @@ import warnings from concurrent.futures import Future from copy import copy as shallow_copy from hashlib import md5 -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union, cast +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Set, + Tuple, + Union, + cast, +) from pydantic import ( UUID4, @@ -24,6 +34,7 @@ from crewai.agent import Agent from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.cache import CacheHandler from crewai.crews.crew_output import CrewOutput +from crewai.flow.flow_trackable import FlowTrackable from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.llm import LLM, BaseLLM @@ -69,7 +80,7 @@ from crewai.utilities.training_handler import CrewTrainingHandler warnings.filterwarnings("ignore", category=SyntaxWarning, module="pysbd") -class Crew(BaseModel): +class Crew(FlowTrackable, BaseModel): """ Represents a group of agents, defining how they should collaborate and the tasks they should perform. diff --git a/src/crewai/flow/flow_trackable.py b/src/crewai/flow/flow_trackable.py new file mode 100644 index 000000000..64e90630c --- /dev/null +++ b/src/crewai/flow/flow_trackable.py @@ -0,0 +1,44 @@ +import inspect +from typing import Optional + +from pydantic import BaseModel, Field, InstanceOf, model_validator + +from crewai.flow import Flow + + +class FlowTrackable(BaseModel): + """Mixin that tracks the Flow instance that instantiated the object, e.g. a + Flow instance that created a Crew or Agent. + + Automatically finds and stores a reference to the parent Flow instance by + inspecting the call stack. + """ + + parent_flow: Optional[InstanceOf[Flow]] = Field( + default=None, + description="The parent flow of the instance, if it was created inside a flow.", + ) + + @model_validator(mode="after") + def _set_parent_flow(self, max_depth: int = 5) -> "FlowTrackable": + frame = inspect.currentframe() + + try: + if frame is None: + return self + + frame = frame.f_back + for _ in range(max_depth): + if frame is None: + break + + candidate = frame.f_locals.get("self") + if isinstance(candidate, Flow): + self.parent_flow = candidate + break + + frame = frame.f_back + finally: + del frame + + return self diff --git a/src/crewai/lite_agent.py b/src/crewai/lite_agent.py index d458e6de0..4cb46c1f0 100644 --- a/src/crewai/lite_agent.py +++ b/src/crewai/lite_agent.py @@ -13,6 +13,7 @@ from crewai.agents.parser import ( AgentFinish, OutputParserException, ) +from crewai.flow.flow_trackable import FlowTrackable from crewai.llm import LLM from crewai.tools.base_tool import BaseTool from crewai.tools.structured_tool import CrewStructuredTool @@ -80,7 +81,7 @@ class LiteAgentOutput(BaseModel): return self.raw -class LiteAgent(BaseModel): +class LiteAgent(FlowTrackable, BaseModel): """ A lightweight agent that can process messages and use tools. @@ -162,7 +163,7 @@ class LiteAgent(BaseModel): _messages: List[Dict[str, str]] = PrivateAttr(default_factory=list) _iterations: int = PrivateAttr(default=0) _printer: Printer = PrivateAttr(default_factory=Printer) - + @model_validator(mode="after") def setup_llm(self): """Set up the LLM and other components after initialization.""" diff --git a/tests/cli/tools/test_main.py b/tests/cli/tools/test_main.py index b06c0b28c..28659a80a 100644 --- a/tests/cli/tools/test_main.py +++ b/tests/cli/tools/test_main.py @@ -3,12 +3,13 @@ import tempfile import unittest import unittest.mock from contextlib import contextmanager -from io import StringIO from unittest import mock from unittest.mock import MagicMock, patch +import pytest from pytest import raises +from crewai.cli.authentication.utils import TokenManager from crewai.cli.tools.main import ToolCommand @@ -23,17 +24,20 @@ def in_temp_dir(): os.chdir(original_dir) -@patch("crewai.cli.tools.main.subprocess.run") -def test_create_success(mock_subprocess): - with in_temp_dir(): - tool_command = ToolCommand() +@pytest.fixture +def tool_command(): + TokenManager().save_tokens("test-token", 36000) + tool_command = ToolCommand() + with patch.object(tool_command, "login"): + yield tool_command - with ( - patch.object(tool_command, "login") as mock_login, - patch("sys.stdout", new=StringIO()) as fake_out, - ): - tool_command.create("test-tool") - output = fake_out.getvalue() + +@patch("crewai.cli.tools.main.subprocess.run") +def test_create_success(mock_subprocess, capsys, tool_command): + with in_temp_dir(): + tool_command.create("test-tool") + output = capsys.readouterr().out + assert "Creating custom tool test_tool..." in output assert os.path.isdir("test_tool") assert os.path.isfile(os.path.join("test_tool", "README.md")) @@ -47,15 +51,12 @@ def test_create_success(mock_subprocess): content = f.read() assert "class TestTool" in content - mock_login.assert_called_once() mock_subprocess.assert_called_once_with(["git", "init"], check=True) - assert "Creating custom tool test_tool..." in output - @patch("crewai.cli.tools.main.subprocess.run") @patch("crewai.cli.plus_api.PlusAPI.get_tool") -def test_install_success(mock_get, mock_subprocess_run): +def test_install_success(mock_get, mock_subprocess_run, capsys, tool_command): mock_get_response = MagicMock() mock_get_response.status_code = 200 mock_get_response.json.return_value = { @@ -65,11 +66,9 @@ def test_install_success(mock_get, mock_subprocess_run): mock_get.return_value = mock_get_response mock_subprocess_run.return_value = MagicMock(stderr=None) - tool_command = ToolCommand() - - with patch("sys.stdout", new=StringIO()) as fake_out: - tool_command.install("sample-tool") - output = fake_out.getvalue() + tool_command.install("sample-tool") + output = capsys.readouterr().out + assert "Successfully installed sample-tool" in output mock_get.assert_has_calls([mock.call("sample-tool"), mock.call().json()]) mock_subprocess_run.assert_any_call( @@ -86,54 +85,42 @@ def test_install_success(mock_get, mock_subprocess_run): env=unittest.mock.ANY, ) - assert "Successfully installed sample-tool" in output - @patch("crewai.cli.plus_api.PlusAPI.get_tool") -def test_install_tool_not_found(mock_get): +def test_install_tool_not_found(mock_get, capsys, tool_command): mock_get_response = MagicMock() mock_get_response.status_code = 404 mock_get.return_value = mock_get_response - tool_command = ToolCommand() - - with patch("sys.stdout", new=StringIO()) as fake_out: - try: - tool_command.install("non-existent-tool") - except SystemExit: - pass - output = fake_out.getvalue() + with raises(SystemExit): + tool_command.install("non-existent-tool") + output = capsys.readouterr().out + assert "No tool found with this name" in output mock_get.assert_called_once_with("non-existent-tool") - assert "No tool found with this name" in output @patch("crewai.cli.plus_api.PlusAPI.get_tool") -def test_install_api_error(mock_get): +def test_install_api_error(mock_get, capsys, tool_command): mock_get_response = MagicMock() mock_get_response.status_code = 500 mock_get.return_value = mock_get_response - tool_command = ToolCommand() - - with patch("sys.stdout", new=StringIO()) as fake_out: - try: - tool_command.install("error-tool") - except SystemExit: - pass - output = fake_out.getvalue() + with raises(SystemExit): + tool_command.install("error-tool") + output = capsys.readouterr().out + assert "Failed to get tool details" in output mock_get.assert_called_once_with("error-tool") - assert "Failed to get tool details" in output @patch("crewai.cli.tools.main.git.Repository.is_synced", return_value=False) -def test_publish_when_not_in_sync(mock_is_synced): - with patch("sys.stdout", new=StringIO()) as fake_out, raises(SystemExit): - tool_command = ToolCommand() +def test_publish_when_not_in_sync(mock_is_synced, capsys, tool_command): + with raises(SystemExit): tool_command.publish(is_public=True) - assert "Local changes need to be resolved before publishing" in fake_out.getvalue() + output = capsys.readouterr().out + assert "Local changes need to be resolved before publishing" in output @patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool") @@ -157,13 +144,13 @@ def test_publish_when_not_in_sync_and_force( mock_get_project_description, mock_get_project_version, mock_get_project_name, + tool_command, ): mock_publish_response = MagicMock() mock_publish_response.status_code = 200 mock_publish_response.json.return_value = {"handle": "sample-tool"} mock_publish.return_value = mock_publish_response - tool_command = ToolCommand() tool_command.publish(is_public=True, force=True) mock_get_project_name.assert_called_with(require=True) @@ -205,13 +192,13 @@ def test_publish_success( mock_get_project_description, mock_get_project_version, mock_get_project_name, + tool_command, ): mock_publish_response = MagicMock() mock_publish_response.status_code = 200 mock_publish_response.json.return_value = {"handle": "sample-tool"} mock_publish.return_value = mock_publish_response - tool_command = ToolCommand() tool_command.publish(is_public=True) mock_get_project_name.assert_called_with(require=True) @@ -251,25 +238,22 @@ def test_publish_failure( mock_get_project_description, mock_get_project_version, mock_get_project_name, + capsys, + tool_command, ): mock_publish_response = MagicMock() mock_publish_response.status_code = 422 mock_publish_response.json.return_value = {"name": ["is already taken"]} mock_publish.return_value = mock_publish_response - tool_command = ToolCommand() - - with patch("sys.stdout", new=StringIO()) as fake_out: - try: - tool_command.publish(is_public=True) - except SystemExit: - pass - output = fake_out.getvalue() - - mock_publish.assert_called_once() + with raises(SystemExit): + tool_command.publish(is_public=True) + output = capsys.readouterr().out assert "Failed to complete operation" in output assert "Name is already taken" in output + mock_publish.assert_called_once() + @patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool") @patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0") @@ -290,6 +274,8 @@ def test_publish_api_error( mock_get_project_description, mock_get_project_version, mock_get_project_name, + capsys, + tool_command, ): mock_response = MagicMock() mock_response.status_code = 500 @@ -297,14 +283,9 @@ def test_publish_api_error( mock_response.ok = False mock_publish.return_value = mock_response - tool_command = ToolCommand() - - with patch("sys.stdout", new=StringIO()) as fake_out: - try: - tool_command.publish(is_public=True) - except SystemExit: - pass - output = fake_out.getvalue() + with raises(SystemExit): + tool_command.publish(is_public=True) + output = capsys.readouterr().out + assert "Request to Enterprise API failed" in output mock_publish.assert_called_once() - assert "Request to Enterprise API failed" in output diff --git a/tests/crew_test.py b/tests/crew_test.py index aa23294fb..a4e4e61df 100644 --- a/tests/crew_test.py +++ b/tests/crew_test.py @@ -17,6 +17,7 @@ from crewai.agents.cache import CacheHandler from crewai.agents.crew_agent_executor import CrewAgentExecutor from crewai.crew import Crew from crewai.crews.crew_output import CrewOutput +from crewai.flow import Flow, listen, start from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource from crewai.llm import LLM from crewai.memory.contextual.contextual_memory import ContextualMemory @@ -2164,7 +2165,6 @@ def test_tools_with_custom_caching(): with patch.object( CacheHandler, "add", wraps=crew._cache_handler.add ) as add_to_cache: - result = crew.kickoff() # Check that add_to_cache was called exactly twice @@ -4351,3 +4351,35 @@ def test_crew_copy_with_memory(): raise e # Re-raise other validation errors except Exception as e: pytest.fail(f"Copying crew raised an unexpected exception: {e}") + + +def test_sets_parent_flow_when_outside_flow(researcher, writer): + crew = Crew( + agents=[researcher, writer], + process=Process.sequential, + tasks=[ + Task(description="Task 1", expected_output="output", agent=researcher), + Task(description="Task 2", expected_output="output", agent=writer), + ], + ) + assert crew.parent_flow is None + + +def test_sets_parent_flow_when_inside_flow(researcher, writer): + class MyFlow(Flow): + @start() + def start(self): + return Crew( + agents=[researcher, writer], + process=Process.sequential, + tasks=[ + Task( + description="Task 1", expected_output="output", agent=researcher + ), + Task(description="Task 2", expected_output="output", agent=writer), + ], + ) + + flow = MyFlow() + result = flow.kickoff() + assert result.parent_flow is flow diff --git a/tests/test_lite_agent.py b/tests/test_lite_agent.py index 06c87319c..7ef23dccb 100644 --- a/tests/test_lite_agent.py +++ b/tests/test_lite_agent.py @@ -1,13 +1,16 @@ import asyncio from typing import cast +from unittest.mock import Mock import pytest from pydantic import BaseModel, Field from crewai import LLM, Agent +from crewai.flow import Flow, start from crewai.lite_agent import LiteAgent, LiteAgentOutput from crewai.tools import BaseTool from crewai.utilities.events import crewai_event_bus +from crewai.utilities.events.agent_events import LiteAgentExecutionStartedEvent from crewai.utilities.events.tool_usage_events import ToolUsageStartedEvent @@ -255,3 +258,60 @@ async def test_lite_agent_returns_usage_metrics_async(): assert "21 million" in result.raw or "37 million" in result.raw assert result.usage_metrics is not None assert result.usage_metrics["total_tokens"] > 0 + + +class TestFlow(Flow): + """A test flow that creates and runs an agent.""" + + def __init__(self, llm, tools): + self.llm = llm + self.tools = tools + super().__init__() + + @start() + def start(self): + agent = Agent( + role="Test Agent", + goal="Test Goal", + backstory="Test Backstory", + llm=self.llm, + tools=self.tools, + ) + return agent.kickoff("Test query") + + +def verify_agent_parent_flow(result, agent, flow): + """Verify that both the result and agent have the correct parent flow.""" + assert result.parent_flow is flow + assert agent is not None + assert agent.parent_flow is flow + + +def test_sets_parent_flow_when_inside_flow(): + captured_agent = None + + mock_llm = Mock(spec=LLM) + mock_llm.call.return_value = "Test response" + + class MyFlow(Flow): + @start() + def start(self): + agent = Agent( + role="Test Agent", + goal="Test Goal", + backstory="Test Backstory", + llm=mock_llm, + tools=[WebSearchTool()], + ) + return agent.kickoff("Test query") + + flow = MyFlow() + with crewai_event_bus.scoped_handlers(): + + @crewai_event_bus.on(LiteAgentExecutionStartedEvent) + def capture_agent(source, event): + nonlocal captured_agent + captured_agent = source + + result = flow.kickoff() + assert captured_agent.parent_flow is flow