diff --git a/src/crewai/crew.py b/src/crewai/crew.py index 75433554c..102f22881 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -6,7 +6,17 @@ import warnings from concurrent.futures import Future from copy import copy as shallow_copy from hashlib import md5 -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union, cast +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Set, + Tuple, + Union, + cast, +) from pydantic import ( UUID4, @@ -24,6 +34,7 @@ from crewai.agent import Agent from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.cache import CacheHandler from crewai.crews.crew_output import CrewOutput +from crewai.flow.flow_trackable import FlowTrackable from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.llm import LLM, BaseLLM @@ -69,7 +80,7 @@ from crewai.utilities.training_handler import CrewTrainingHandler warnings.filterwarnings("ignore", category=SyntaxWarning, module="pysbd") -class Crew(BaseModel): +class Crew(FlowTrackable, BaseModel): """ Represents a group of agents, defining how they should collaborate and the tasks they should perform. diff --git a/src/crewai/flow/flow_trackable.py b/src/crewai/flow/flow_trackable.py new file mode 100644 index 000000000..64e90630c --- /dev/null +++ b/src/crewai/flow/flow_trackable.py @@ -0,0 +1,44 @@ +import inspect +from typing import Optional + +from pydantic import BaseModel, Field, InstanceOf, model_validator + +from crewai.flow import Flow + + +class FlowTrackable(BaseModel): + """Mixin that tracks the Flow instance that instantiated the object, e.g. a + Flow instance that created a Crew or Agent. + + Automatically finds and stores a reference to the parent Flow instance by + inspecting the call stack. + """ + + parent_flow: Optional[InstanceOf[Flow]] = Field( + default=None, + description="The parent flow of the instance, if it was created inside a flow.", + ) + + @model_validator(mode="after") + def _set_parent_flow(self, max_depth: int = 5) -> "FlowTrackable": + frame = inspect.currentframe() + + try: + if frame is None: + return self + + frame = frame.f_back + for _ in range(max_depth): + if frame is None: + break + + candidate = frame.f_locals.get("self") + if isinstance(candidate, Flow): + self.parent_flow = candidate + break + + frame = frame.f_back + finally: + del frame + + return self diff --git a/src/crewai/lite_agent.py b/src/crewai/lite_agent.py index d458e6de0..4cb46c1f0 100644 --- a/src/crewai/lite_agent.py +++ b/src/crewai/lite_agent.py @@ -13,6 +13,7 @@ from crewai.agents.parser import ( AgentFinish, OutputParserException, ) +from crewai.flow.flow_trackable import FlowTrackable from crewai.llm import LLM from crewai.tools.base_tool import BaseTool from crewai.tools.structured_tool import CrewStructuredTool @@ -80,7 +81,7 @@ class LiteAgentOutput(BaseModel): return self.raw -class LiteAgent(BaseModel): +class LiteAgent(FlowTrackable, BaseModel): """ A lightweight agent that can process messages and use tools. @@ -162,7 +163,7 @@ class LiteAgent(BaseModel): _messages: List[Dict[str, str]] = PrivateAttr(default_factory=list) _iterations: int = PrivateAttr(default=0) _printer: Printer = PrivateAttr(default_factory=Printer) - + @model_validator(mode="after") def setup_llm(self): """Set up the LLM and other components after initialization.""" diff --git a/tests/cli/tools/test_main.py b/tests/cli/tools/test_main.py index b06c0b28c..28659a80a 100644 --- a/tests/cli/tools/test_main.py +++ b/tests/cli/tools/test_main.py @@ -3,12 +3,13 @@ import tempfile import unittest import unittest.mock from contextlib import contextmanager -from io import StringIO from unittest import mock from unittest.mock import MagicMock, patch +import pytest from pytest import raises +from crewai.cli.authentication.utils import TokenManager from crewai.cli.tools.main import ToolCommand @@ -23,17 +24,20 @@ def in_temp_dir(): os.chdir(original_dir) -@patch("crewai.cli.tools.main.subprocess.run") -def test_create_success(mock_subprocess): - with in_temp_dir(): - tool_command = ToolCommand() +@pytest.fixture +def tool_command(): + TokenManager().save_tokens("test-token", 36000) + tool_command = ToolCommand() + with patch.object(tool_command, "login"): + yield tool_command - with ( - patch.object(tool_command, "login") as mock_login, - patch("sys.stdout", new=StringIO()) as fake_out, - ): - tool_command.create("test-tool") - output = fake_out.getvalue() + +@patch("crewai.cli.tools.main.subprocess.run") +def test_create_success(mock_subprocess, capsys, tool_command): + with in_temp_dir(): + tool_command.create("test-tool") + output = capsys.readouterr().out + assert "Creating custom tool test_tool..." in output assert os.path.isdir("test_tool") assert os.path.isfile(os.path.join("test_tool", "README.md")) @@ -47,15 +51,12 @@ def test_create_success(mock_subprocess): content = f.read() assert "class TestTool" in content - mock_login.assert_called_once() mock_subprocess.assert_called_once_with(["git", "init"], check=True) - assert "Creating custom tool test_tool..." in output - @patch("crewai.cli.tools.main.subprocess.run") @patch("crewai.cli.plus_api.PlusAPI.get_tool") -def test_install_success(mock_get, mock_subprocess_run): +def test_install_success(mock_get, mock_subprocess_run, capsys, tool_command): mock_get_response = MagicMock() mock_get_response.status_code = 200 mock_get_response.json.return_value = { @@ -65,11 +66,9 @@ def test_install_success(mock_get, mock_subprocess_run): mock_get.return_value = mock_get_response mock_subprocess_run.return_value = MagicMock(stderr=None) - tool_command = ToolCommand() - - with patch("sys.stdout", new=StringIO()) as fake_out: - tool_command.install("sample-tool") - output = fake_out.getvalue() + tool_command.install("sample-tool") + output = capsys.readouterr().out + assert "Successfully installed sample-tool" in output mock_get.assert_has_calls([mock.call("sample-tool"), mock.call().json()]) mock_subprocess_run.assert_any_call( @@ -86,54 +85,42 @@ def test_install_success(mock_get, mock_subprocess_run): env=unittest.mock.ANY, ) - assert "Successfully installed sample-tool" in output - @patch("crewai.cli.plus_api.PlusAPI.get_tool") -def test_install_tool_not_found(mock_get): +def test_install_tool_not_found(mock_get, capsys, tool_command): mock_get_response = MagicMock() mock_get_response.status_code = 404 mock_get.return_value = mock_get_response - tool_command = ToolCommand() - - with patch("sys.stdout", new=StringIO()) as fake_out: - try: - tool_command.install("non-existent-tool") - except SystemExit: - pass - output = fake_out.getvalue() + with raises(SystemExit): + tool_command.install("non-existent-tool") + output = capsys.readouterr().out + assert "No tool found with this name" in output mock_get.assert_called_once_with("non-existent-tool") - assert "No tool found with this name" in output @patch("crewai.cli.plus_api.PlusAPI.get_tool") -def test_install_api_error(mock_get): +def test_install_api_error(mock_get, capsys, tool_command): mock_get_response = MagicMock() mock_get_response.status_code = 500 mock_get.return_value = mock_get_response - tool_command = ToolCommand() - - with patch("sys.stdout", new=StringIO()) as fake_out: - try: - tool_command.install("error-tool") - except SystemExit: - pass - output = fake_out.getvalue() + with raises(SystemExit): + tool_command.install("error-tool") + output = capsys.readouterr().out + assert "Failed to get tool details" in output mock_get.assert_called_once_with("error-tool") - assert "Failed to get tool details" in output @patch("crewai.cli.tools.main.git.Repository.is_synced", return_value=False) -def test_publish_when_not_in_sync(mock_is_synced): - with patch("sys.stdout", new=StringIO()) as fake_out, raises(SystemExit): - tool_command = ToolCommand() +def test_publish_when_not_in_sync(mock_is_synced, capsys, tool_command): + with raises(SystemExit): tool_command.publish(is_public=True) - assert "Local changes need to be resolved before publishing" in fake_out.getvalue() + output = capsys.readouterr().out + assert "Local changes need to be resolved before publishing" in output @patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool") @@ -157,13 +144,13 @@ def test_publish_when_not_in_sync_and_force( mock_get_project_description, mock_get_project_version, mock_get_project_name, + tool_command, ): mock_publish_response = MagicMock() mock_publish_response.status_code = 200 mock_publish_response.json.return_value = {"handle": "sample-tool"} mock_publish.return_value = mock_publish_response - tool_command = ToolCommand() tool_command.publish(is_public=True, force=True) mock_get_project_name.assert_called_with(require=True) @@ -205,13 +192,13 @@ def test_publish_success( mock_get_project_description, mock_get_project_version, mock_get_project_name, + tool_command, ): mock_publish_response = MagicMock() mock_publish_response.status_code = 200 mock_publish_response.json.return_value = {"handle": "sample-tool"} mock_publish.return_value = mock_publish_response - tool_command = ToolCommand() tool_command.publish(is_public=True) mock_get_project_name.assert_called_with(require=True) @@ -251,25 +238,22 @@ def test_publish_failure( mock_get_project_description, mock_get_project_version, mock_get_project_name, + capsys, + tool_command, ): mock_publish_response = MagicMock() mock_publish_response.status_code = 422 mock_publish_response.json.return_value = {"name": ["is already taken"]} mock_publish.return_value = mock_publish_response - tool_command = ToolCommand() - - with patch("sys.stdout", new=StringIO()) as fake_out: - try: - tool_command.publish(is_public=True) - except SystemExit: - pass - output = fake_out.getvalue() - - mock_publish.assert_called_once() + with raises(SystemExit): + tool_command.publish(is_public=True) + output = capsys.readouterr().out assert "Failed to complete operation" in output assert "Name is already taken" in output + mock_publish.assert_called_once() + @patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool") @patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0") @@ -290,6 +274,8 @@ def test_publish_api_error( mock_get_project_description, mock_get_project_version, mock_get_project_name, + capsys, + tool_command, ): mock_response = MagicMock() mock_response.status_code = 500 @@ -297,14 +283,9 @@ def test_publish_api_error( mock_response.ok = False mock_publish.return_value = mock_response - tool_command = ToolCommand() - - with patch("sys.stdout", new=StringIO()) as fake_out: - try: - tool_command.publish(is_public=True) - except SystemExit: - pass - output = fake_out.getvalue() + with raises(SystemExit): + tool_command.publish(is_public=True) + output = capsys.readouterr().out + assert "Request to Enterprise API failed" in output mock_publish.assert_called_once() - assert "Request to Enterprise API failed" in output diff --git a/tests/crew_test.py b/tests/crew_test.py index aa23294fb..a4e4e61df 100644 --- a/tests/crew_test.py +++ b/tests/crew_test.py @@ -17,6 +17,7 @@ from crewai.agents.cache import CacheHandler from crewai.agents.crew_agent_executor import CrewAgentExecutor from crewai.crew import Crew from crewai.crews.crew_output import CrewOutput +from crewai.flow import Flow, listen, start from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource from crewai.llm import LLM from crewai.memory.contextual.contextual_memory import ContextualMemory @@ -2164,7 +2165,6 @@ def test_tools_with_custom_caching(): with patch.object( CacheHandler, "add", wraps=crew._cache_handler.add ) as add_to_cache: - result = crew.kickoff() # Check that add_to_cache was called exactly twice @@ -4351,3 +4351,35 @@ def test_crew_copy_with_memory(): raise e # Re-raise other validation errors except Exception as e: pytest.fail(f"Copying crew raised an unexpected exception: {e}") + + +def test_sets_parent_flow_when_outside_flow(researcher, writer): + crew = Crew( + agents=[researcher, writer], + process=Process.sequential, + tasks=[ + Task(description="Task 1", expected_output="output", agent=researcher), + Task(description="Task 2", expected_output="output", agent=writer), + ], + ) + assert crew.parent_flow is None + + +def test_sets_parent_flow_when_inside_flow(researcher, writer): + class MyFlow(Flow): + @start() + def start(self): + return Crew( + agents=[researcher, writer], + process=Process.sequential, + tasks=[ + Task( + description="Task 1", expected_output="output", agent=researcher + ), + Task(description="Task 2", expected_output="output", agent=writer), + ], + ) + + flow = MyFlow() + result = flow.kickoff() + assert result.parent_flow is flow diff --git a/tests/test_lite_agent.py b/tests/test_lite_agent.py index 06c87319c..7ef23dccb 100644 --- a/tests/test_lite_agent.py +++ b/tests/test_lite_agent.py @@ -1,13 +1,16 @@ import asyncio from typing import cast +from unittest.mock import Mock import pytest from pydantic import BaseModel, Field from crewai import LLM, Agent +from crewai.flow import Flow, start from crewai.lite_agent import LiteAgent, LiteAgentOutput from crewai.tools import BaseTool from crewai.utilities.events import crewai_event_bus +from crewai.utilities.events.agent_events import LiteAgentExecutionStartedEvent from crewai.utilities.events.tool_usage_events import ToolUsageStartedEvent @@ -255,3 +258,60 @@ async def test_lite_agent_returns_usage_metrics_async(): assert "21 million" in result.raw or "37 million" in result.raw assert result.usage_metrics is not None assert result.usage_metrics["total_tokens"] > 0 + + +class TestFlow(Flow): + """A test flow that creates and runs an agent.""" + + def __init__(self, llm, tools): + self.llm = llm + self.tools = tools + super().__init__() + + @start() + def start(self): + agent = Agent( + role="Test Agent", + goal="Test Goal", + backstory="Test Backstory", + llm=self.llm, + tools=self.tools, + ) + return agent.kickoff("Test query") + + +def verify_agent_parent_flow(result, agent, flow): + """Verify that both the result and agent have the correct parent flow.""" + assert result.parent_flow is flow + assert agent is not None + assert agent.parent_flow is flow + + +def test_sets_parent_flow_when_inside_flow(): + captured_agent = None + + mock_llm = Mock(spec=LLM) + mock_llm.call.return_value = "Test response" + + class MyFlow(Flow): + @start() + def start(self): + agent = Agent( + role="Test Agent", + goal="Test Goal", + backstory="Test Backstory", + llm=mock_llm, + tools=[WebSearchTool()], + ) + return agent.kickoff("Test query") + + flow = MyFlow() + with crewai_event_bus.scoped_handlers(): + + @crewai_event_bus.on(LiteAgentExecutionStartedEvent) + def capture_agent(source, event): + nonlocal captured_agent + captured_agent = source + + result = flow.kickoff() + assert captured_agent.parent_flow is flow